Using Keras in TensorFlow to build a convolutional neural network

2022-07-07 10:03:00 guluC

Six steps

1、 Import the TensorFlow library

import tensorflow.keras as keras

2、 Prepare training data
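The original leaves this step empty. The code below is a minimal sketch of typical preparation. It uses a hypothetical stand-in dataset: 100 random 28x28 grayscale images with pixel values 0-255 and integer labels 0-9. The sketch scales the pixels to [0, 1] and adds a channel axis, so that each sample has the (height, width, channels) shape that Conv2D expects. With real data you would load arrays instead. One option is keras.datasets.mnist.load_data(), which downloads the MNIST digits.

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical stand-in for a real dataset: 100 grayscale 28x28 images
# with integer pixel values 0-255 and integer class labels 0-9
x_raw = rng.integers(0, 256, size=(100, 28, 28)).astype(np.uint8)
y_train = rng.integers(0, 10, size=100)

# Scale pixel values from 0-255 to the range [0, 1]
x_train = x_raw.astype("float32") / 255.0

# Add a channel axis: (samples, height, width) -> (samples, height, width, channels)
x_train = x_train[..., np.newaxis]

print(x_train.shape, y_train.shape)  # (100, 28, 28, 1) (100,)
```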


3、 Build a network structure

  • Create a container (a Sequential model) that stores the layers in order
model = keras.models.Sequential()  # Each layer added to this model becomes one layer of the network
  • Convolution layer: keras.layers.Conv2D(
    filters,  # Number of convolution kernels (output channels)
    kernel_size,  # Kernel size, usually (3, 3)
    strides=(1, 1),  # Sliding step, default (1, 1)
    padding='valid',  # Zero-padding strategy: 'valid' (no padding) or 'same' (pad so the output keeps the input's width and height when strides=(1, 1))
    activation=None,  # Activation function, commonly 'relu', 'softmax' or 'selu'
    input_shape=(64, 64, 3)  # Shape of one input sample, e.g. a 64x64 image with 3 color channels; only needed on the first layer
    )
  • Pooling layer, e.g. keras.layers.MaxPooling2D(
    pool_size=(2, 2),  # Size of the pooling window
    strides=None,  # Step size; None means it defaults to pool_size
    padding='valid',  # Zero-padding strategy: 'valid' or 'same'
    )
  • Flatten: keras.layers.Flatten() (flattens the data into one dimension; typically used in the transition from the convolutional layers to the fully connected layers)
  • Fully connected layer: keras.layers.Dense(
    units,  # Dimensionality of the output space (number of neurons)
    activation=None,  # Activation function, commonly 'relu', 'softmax' or 'selu'
    use_bias=True,  # Boolean: whether the layer uses a bias vector
    )
  • Dropout layer: keras.layers.Dropout( (reduces overfitting and improves the model's ability to generalize)
    rate,  # Float between 0 and 1: fraction of input units set to zero during training
    seed=None  # Random seed
    )
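The layers above can be combined into a small CNN. The code below is a minimal sketch of that. Its sizes are illustrative assumptions and do not come from the original: 28x28x1 inputs, 8 filters and 10 classes. It also passes keras.Input(shape=...) as the first element in place of the input_shape argument, because newer Keras versions recommend that style.

```python
import numpy as np
from tensorflow import keras

# Container that stores the layers in order
model = keras.models.Sequential()

# Shape of one input sample: a 28x28 image with 1 channel (assumed size)
model.add(keras.Input(shape=(28, 28, 1)))

# Convolution layer: 8 kernels of size 3x3, stride 1; 'same' padding keeps 28x28
model.add(keras.layers.Conv2D(filters=8, kernel_size=(3, 3), strides=(1, 1),
                              padding="same", activation="relu"))

# Pooling layer: 2x2 max pooling; strides=None means the stride equals pool_size
model.add(keras.layers.MaxPooling2D(pool_size=(2, 2), strides=None, padding="valid"))

# Flatten the 14x14x8 feature maps into a vector of 1568 values
model.add(keras.layers.Flatten())

# Dropout: randomly zero 20% of the inputs, during training only
model.add(keras.layers.Dropout(rate=0.2, seed=1))

# Fully connected output layer: one probability per class (10 assumed classes)
model.add(keras.layers.Dense(units=10, activation="softmax", use_bias=True))

# Run a dummy batch of 2 images through the untrained network
probs = model.predict(np.zeros((2, 28, 28, 1), dtype="float32"), verbose=0)
print(probs.shape)  # (2, 10)
```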

4、 Print the network structure and parameter statistics

The model's summary() method prints the network structure and parameter statistics: model.summary()
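The code below is a sketch of summary() on a small model, with the printed total checked against count_params(). The layer sizes are assumptions, chosen so the parameter count is easy to verify by hand. A Conv2D layer has kernel_height x kernel_width x input_channels x filters weights, plus one bias per filter. A Dense layer has inputs x units weights, plus one bias per unit.

```python
from tensorflow import keras

model = keras.models.Sequential([
    keras.Input(shape=(28, 28, 1)),
    # 3*3*1*8 weights + 8 biases = 80 parameters; 'same' padding keeps 28x28
    keras.layers.Conv2D(8, (3, 3), padding="same", activation="relu"),
    # No parameters; output 14x14x8
    keras.layers.MaxPooling2D((2, 2)),
    # No parameters; output 14*14*8 = 1568 values
    keras.layers.Flatten(),
    # 1568*10 weights + 10 biases = 15690 parameters
    keras.layers.Dense(10, activation="softmax"),
])

model.summary()              # prints each layer's output shape and parameter count
total = model.count_params()
print(total)                 # 15770
```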

5、 Configure the optimizer, loss function and evaluation metrics for training

model.compile(
    optimizer,  # Optimizer
    loss,  # Loss function
    metrics  # Evaluation metrics of the network
)
  • optimizer can be an optimizer name given as a string, or an optimizer object; the object form lets you set the learning rate, momentum and other hyperparameters
    • "sgd" or keras.optimizers.SGD(learning_rate = learning rate, momentum = momentum parameter)
    • "adagrad" or keras.optimizers.Adagrad(learning_rate = learning rate)
    • "adadelta" or keras.optimizers.Adadelta(learning_rate = learning rate)
    • "adam" or keras.optimizers.Adam(learning_rate = learning rate)
    • Older Keras versions name the learning rate lr and also accept decay = learning rate decay rate; recent versions use learning_rate and replace decay with a learning-rate schedule (keras.optimizers.schedules)
  • loss can be the name of a loss function given as a string, or a loss object
    • "mse" or keras.losses.MeanSquaredError()
    • "sparse_categorical_crossentropy" or keras.losses.SparseCategoricalCrossentropy(from_logits = False) (from_logits=False means the network output is already a probability distribution, e.g. after softmax)
  • metrics lists the evaluation metrics of the network (y_ is the true label, y is the prediction)
    • "accuracy": y_ and y are both plain numbers, e.g. y_ = [1], y = [1]
    • "categorical_accuracy": y_ is one-hot encoded and y is a probability distribution, e.g. y_ = [0, 1, 0], y = [0.256, 0.695, 0.048]
    • "sparse_categorical_accuracy": y_ is given as a number (the class index) and y is a probability distribution, e.g. y_ = [1], y = [0.256, 0.695, 0.048]
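The code below is a sketch of this step. It assumes a tiny placeholder model. It passes objects rather than strings so that the learning rate and momentum can be set. It also uses keras.optimizers.schedules.ExponentialDecay where older code would pass decay. Finally, it evaluates the two accuracy variants on the example values from the list above. The prediction is the same in both cases; only the label encoding differs, and both metrics give the same result.

```python
import numpy as np
from tensorflow import keras

# Placeholder model (assumed): flattens a 28x28x1 input into 10 class probabilities
model = keras.models.Sequential([
    keras.Input(shape=(28, 28, 1)),
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation="softmax"),
])

# Learning rate starts at 0.01 and decays smoothly by a factor of 0.9 per 1000 steps
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.01, decay_steps=1000, decay_rate=0.9)

model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9),
    # from_logits=False: the model already outputs probabilities (softmax)
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=["sparse_categorical_accuracy"],
)

# The two accuracy variants on the same prediction, with differently encoded labels
y_pred = np.array([[0.256, 0.695, 0.048]])
sparse_acc = keras.metrics.SparseCategoricalAccuracy()
sparse_acc.update_state(np.array([1]), y_pred)        # label as a class index
cat_acc = keras.metrics.CategoricalAccuracy()
cat_acc.update_state(np.array([[0, 1, 0]]), y_pred)   # label as a one-hot vector

print(float(sparse_acc.result()), float(cat_acc.result()))  # 1.0 1.0
```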


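6、 Train the network

model.fit(
    x, y, batch_size, epochs, verbose, callbacks, validation_split,
    validation_data, shuffle, class_weight, sample_weight, initial_epoch
)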
  • x: input data. If the model has a single input, x is a numpy array; if the model has multiple inputs, x should be a list of numpy arrays, one per input
  • y: labels, as a numpy array
  • batch_size: integer, the number of samples in each batch used for gradient descent. Each batch produces one gradient computation and one optimization step of the objective function. Defaults to 32 if not specified.
  • epochs: integer, the epoch index at which training stops. If initial_epoch is not set, this is the total number of training epochs; otherwise the model trains for epochs - initial_epoch epochs.
  • verbose: logging mode: 0 = silent, 1 = progress bar, 2 = one line per epoch
  • callbacks: list of keras.callbacks.Callback instances; each callback in the list is invoked at the appropriate points during training (see the callbacks documentation)
  • validation_split: float between 0 and 1, the fraction of the training data held out as a validation set. The model is not trained on the validation set; metrics such as the loss and accuracy are evaluated on it at the end of each epoch. Note that the split is taken from the last samples of the data before shuffling, so if your data is ordered you must shuffle it manually before using validation_split; otherwise the validation set may be unrepresentative.
  • validation_data: a tuple (x_val, y_val) that specifies the validation set explicitly. This parameter overrides validation_split.
  • shuffle: boolean or string, usually a boolean indicating whether to shuffle the training samples before each epoch. The string "batch" is a special option for dealing with HDF5 data: it shuffles in batch-sized chunks.
  • class_weight: dictionary mapping class indices to weights, used to weight the loss function during training only
  • sample_weight: numpy array of weights, used to weight the loss function during training only. Pass a 1D array with the same length as the number of samples to weight the samples 1:1, or, for temporal data, a 2D array of shape (samples, sequence_length) to give a different weight to every timestep of every sample. In the latter case, older Keras versions also require sample_weight_mode='temporal' when compiling the model.
  • initial_epoch: the epoch at which to start training; useful for resuming a previous training run.
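The code below is a sketch that exercises several of these parameters. It runs on random placeholder data, an assumption; substitute your own arrays. It shuffles manually before fit because validation_split takes the last samples. The equal class weights are for illustration only. It then resumes training with initial_epoch, to show that only epochs - initial_epoch additional epochs are run.

```python
import numpy as np
from tensorflow import keras

rng = np.random.default_rng(0)

# Hypothetical random data standing in for a real training set
x = rng.random((64, 28, 28, 1)).astype("float32")
y = rng.integers(0, 10, size=64)

# validation_split takes the LAST samples before any shuffling,
# so ordered data should be shuffled manually first
perm = rng.permutation(len(x))
x, y = x[perm], y[perm]

model = keras.models.Sequential([
    keras.Input(shape=(28, 28, 1)),
    keras.layers.Conv2D(4, (3, 3), activation="relu"),
    keras.layers.MaxPooling2D((2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["sparse_categorical_accuracy"])

history = model.fit(
    x, y,
    batch_size=16,          # 16 samples per gradient step
    epochs=3,               # initial_epoch defaults to 0, so 3 epochs in total
    verbose=0,              # silent
    validation_split=0.25,  # the last 16 of the 64 samples form the validation set
    shuffle=True,           # shuffle the 48 training samples before each epoch
    class_weight={c: 1.0 for c in range(10)},  # equal weights, for illustration only
)

# Resume training: with initial_epoch=3 and epochs=5 only epochs 3 and 4 run
resumed = model.fit(x, y, batch_size=16, epochs=5, initial_epoch=3,
                    verbose=0, validation_split=0.25)

print(len(history.history["loss"]), len(resumed.history["loss"]))  # 3 2
```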

fit returns a History object. Its History.history attribute records how the loss and the other metrics change from epoch to epoch; if a validation set is used, it also records the same metrics for the validation set.
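The code below is a sketch that illustrates the History object. It trains a deliberately tiny dense model on hypothetical random data, an assumption made only to keep it fast; a CNN returns the same kind of object. It passes an explicit validation_data tuple, prints every recorded series and picks the epoch with the lowest validation loss.

```python
import numpy as np
from tensorflow import keras

rng = np.random.default_rng(1)

# Hypothetical data: 8 features per sample, 3 classes, separate validation arrays
x_train = rng.random((48, 8)).astype("float32")
y_train = rng.integers(0, 3, size=48)
x_val = rng.random((16, 8)).astype("float32")
y_val = rng.integers(0, 3, size=16)

model = keras.models.Sequential([
    keras.Input(shape=(8,)),
    keras.layers.Dense(3, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["sparse_categorical_accuracy"])

history = model.fit(x_train, y_train, batch_size=16, epochs=4, verbose=0,
                    validation_data=(x_val, y_val))

# One list per metric, one value per epoch; validation metrics get a "val_" prefix
for name, values in history.history.items():
    print(name, [round(float(v), 4) for v in values])

# Epoch (0-based) with the lowest validation loss
best_epoch = int(np.argmin(history.history["val_loss"]))
print("lowest validation loss at epoch", best_epoch)
```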

