
Converting Keras code to PyTorch for CNN image classification: approach and implementation

2022-06-22 07:08:00 zorchp


tags: Python DL

Preface

A few days ago I converted some Python code for a convolutional neural network (CNN) that solves a classification problem. The code was written with TensorFlow's Keras API, and the task was to translate it into PyTorch. Since the two APIs are similar, the translation is not very difficult, but a few details need attention, so I am recording them here for reference.

Importing library functions

First, let's look at how these two popular deep learning frameworks differ in their imports, which requires a brief understanding of how each is organized. For brevity, TF below always means TensorFlow 2.x with Keras, and Torch always means PyTorch.

Model building

Let's start with building the model. In TF, a model can be built easily with the Sequential API, which requires this import:

from tensorflow.keras.models import Sequential

In Torch, the model can also be built with Sequential (although the official documentation favours an object-oriented approach, i.e. subclassing nn.Module).

We need to import:

from torch.nn import Sequential
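For comparison, here is a minimal sketch of that object-oriented style: subclass `torch.nn.Module`, create the layers in `__init__`, and connect them in `forward`. The class name `LeNet`, the `num_classes=4` default and the 32x32 input size are assumptions chosen to match the Sequential model used later in this post.

```python
import torch
from torch import nn


class LeNet(nn.Module):
    """LeNet-style CNN written as an nn.Module subclass (hypothetical name)."""

    def __init__(self, num_classes: int = 4):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5)   # no padding, like Keras 'valid'
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)         # assumes 32x32 input images
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(torch.tanh(self.conv1(x)))
        x = self.pool(torch.tanh(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)             # keep the batch dimension
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return self.fc3(x)                            # raw logits, no softmax


model = LeNet()
logits = model(torch.zeros(2, 3, 32, 32))
print(logits.shape)  # torch.Size([2, 4])
```

The layers are only registered in `__init__`; the data flow lives in `forward`, which is where this style gains flexibility over Sequential (branches, skip connections, and so on).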

Speaking of model building, we have to mention the layers commonly used in convolutional neural networks: the convolutional layer, the max-pooling layer and the fully connected layer (with softmax at the output). All of them are readily available in both frameworks. Here is how to import them:

In TF:

from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.layers import Activation, Dropout, Flatten, Dense

And in Torch:

from torch.nn import Conv2d, MaxPool2d
from torch.nn import Flatten, Linear, CrossEntropyLoss
from torch.optim import SGD

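As you can see, the two differ only slightly: TF passes activation functions as layer parameters (e.g. activation='tanh'), while Torch provides them as separate library classes. Torch also imports the loss function and optimizer explicitly, whereas Keras can refer to them by name strings.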
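To make the activation difference concrete: where Keras writes `Dense(120, activation='tanh')`, PyTorch uses a `Linear` layer followed either by an `nn.Tanh` module (convenient inside `Sequential`) or by the functional `torch.tanh` (convenient inside `forward`). This sketch, which assumes only PyTorch and random input, checks that both forms give the same result.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(8, 400)

# Module form: the activation is its own layer, similar to Keras's Activation('tanh')
dense_tanh = nn.Sequential(nn.Linear(400, 120), nn.Tanh())

# Functional form: reuse the same Linear weights and apply tanh by hand
linear = dense_tanh[0]
out_module = dense_tanh(x)
out_functional = torch.tanh(linear(x))

print(torch.allclose(out_module, out_functional))  # True
```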

Data loading

Finally, let's look at data import. In TF, images can be processed and read conveniently with the following:

from tensorflow.keras import backend
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In Torch, the corresponding imports are:

from torchvision import transforms, datasets
from torch.utils.data import DataLoader

Data reading and processing: API differences

For data reading, I still find Keras more convenient [1]. Torch takes a modular approach: you first instantiate a transform class, then use that object to process the images.

Here is the TF code for reading image data:

# Determine the input shape from Keras's image data format
if backend.image_data_format() == 'channels_first':
    input_shape = (3, img_width, img_height)
else:
    input_shape = (img_width, img_height, 3)

# Training set: rescaling plus data augmentation
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

# Test set: rescaling only
test_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical')  # multi-class, one-hot labels

validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical')  # multi-class, one-hot labels
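The image data format matters because Keras defaults to channels-last batches shaped (N, H, W, C), while PyTorch layers expect channels-first batches shaped (N, C, H, W). If arrays prepared for Keras are fed into a PyTorch model, the axes must be permuted. This sketch uses a random NumPy batch with an assumed size of four 32x32 images.

```python
import numpy as np
import torch

# A channels-last batch as Keras would produce it: (N, H, W, C)
batch_nhwc = np.random.rand(4, 32, 32, 3).astype(np.float32)

# Move the channel axis to position 1 for PyTorch: (N, C, H, W)
batch_nchw = torch.from_numpy(batch_nhwc).permute(0, 3, 1, 2).contiguous()

print(batch_nchw.shape)  # torch.Size([4, 3, 32, 32])

# Pixel values are unchanged; only the axis order differs
assert torch.equal(batch_nchw[0, :, 5, 7], torch.from_numpy(batch_nhwc[0, 5, 7, :]))
```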

Next is the Torch code:

# PyTorch tensors are channels-first: (channels, height, width)
# (for reference only; Conv2d just needs in_channels)
input_shape = (3, img_width, img_height)

# Training set: resizing plus data augmentation
train_datagen = transforms.Compose([
    transforms.Resize((img_width, img_height)),
    # Rough counterpart of Keras's shear_range / zoom_range (not an exact equivalent)
    transforms.RandomAffine(degrees=0, shear=0.2, scale=(0.8, 1.2)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()  # equivalent to Keras's rescale=1/255, also HWC -> CHW
])

# Test set: resizing and rescaling only
test_datagen = transforms.Compose([  # applied to every image that is read
    transforms.Resize((img_width, img_height)),
    transforms.ToTensor()  # equivalent to Keras's rescale=1/255
])

# ImageFolder expects one sub-directory per class, like flow_from_directory
train_dataset = datasets.ImageFolder(train_data_dir,
                                     transform=train_datagen)

validation_dataset = datasets.ImageFolder(validation_data_dir,
                                          transform=test_datagen)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True)

test_loader = DataLoader(validation_dataset,
                         batch_size=batch_size,
                         shuffle=False)

Model building: API differences

Now for the most important part: the API differences in building the model. In TF you simply call model.add layer by layer to create a CNN classifier; just make sure the data dimensions match from one layer to the next. The code is concise and intuitive:

# Create the model
model = Sequential()
model.add(Conv2D(filters=6, 
                 kernel_size=(5, 5), 
                 padding='valid', 
                 input_shape=input_shape, 
                 activation='tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(filters=16, 
                 kernel_size=(5, 5), 
                 padding='valid', 
                 activation='tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(120, activation='tanh'))
model.add(Dense(84, activation='tanh'))
model.add(Dense(4, activation='softmax'))

# Compile the model: loss, optimizer and metrics
model.compile(loss='categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'])

Torch offers a similar approach, but the model does not need to be compiled; instead, the loss function and optimizer are created as separate objects. The code is as follows:

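# Create the model
from torch.nn import Tanh  # activation layer, not included in the imports above

model = Sequential(
    Conv2d(in_channels=3,
           out_channels=6,
           kernel_size=(5, 5),
           padding='valid'),
    Tanh(),  # Keras passes activation='tanh' inside Conv2D; Torch needs a separate layer
    MaxPool2d(kernel_size=(2, 2)),
    Conv2d(in_channels=6,
           out_channels=16,
           kernel_size=(5, 5),
           padding='valid'),
    Tanh(),
    MaxPool2d(kernel_size=(2, 2)),
    Flatten(),
    Linear(400, 120),  # 400 = 16 * 5 * 5, which assumes 32x32 input images
    Tanh(),
    Linear(120, 84),
    Tanh(),
    Linear(84, 4)  # raw logits: no softmax, CrossEntropyLoss applies log-softmax itself
)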

# Loss: cross-entropy on raw logits and integer class labels
criterion = CrossEntropyLoss()
# Optimizer: stochastic gradient descent (Keras's 'sgd' defaults to learning_rate=0.01)
optimizer = SGD(model.parameters(), lr=0.001)
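The PyTorch model can end in a plain `Linear` layer because `CrossEntropyLoss` applies log-softmax itself and expects raw logits plus integer class indices, whereas Keras's `categorical_crossentropy` takes softmax probabilities and one-hot labels. This sketch, using made-up logits for three samples and four classes, checks that both formulations give the same loss.

```python
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

logits = torch.tensor([[2.0, 0.5, -1.0, 0.1],
                       [0.3, 1.5, 0.2, -0.4],
                       [-0.5, 0.0, 0.7, 2.2]])
labels = torch.tensor([0, 1, 3])  # integer class indices, as ImageFolder yields

# PyTorch style: logits in, log-softmax applied internally
torch_loss = CrossEntropyLoss()(logits, labels)

# Keras style: softmax probabilities and one-hot labels
probs = torch.softmax(logits, dim=1)
one_hot = F.one_hot(labels, num_classes=4).float()
keras_style_loss = -(one_hot * torch.log(probs)).sum(dim=1).mean()

print(torch.allclose(torch_loss, keras_style_loss))  # True
```

Adding a softmax layer to the Torch model would therefore apply softmax twice, which still trains but distorts the loss and slows learning.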

There are still some API differences here, for example in how fully connected layers are written and parameterized: Keras's Dense only needs the output size, whereas Torch's Linear needs both the input and output sizes. Convolution differs in the same way, since Conv2d needs in_channels explicitly. Again, pay close attention to the data dimensions.
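One practical way to get the dimensions right is to push a dummy batch through the layers and print each output shape. With an assumed 32x32 RGB input, the trace shows where the 400 in `Linear(400, 120)` comes from: each 5x5 convolution without padding shrinks the side by 4 (32 to 28, 14 to 10), each pooling halves it, leaving 16 channels of 5x5.

```python
import torch
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear, Tanh

model = Sequential(
    Conv2d(3, 6, kernel_size=5), Tanh(), MaxPool2d(2),
    Conv2d(6, 16, kernel_size=5), Tanh(), MaxPool2d(2),
    Flatten(),
    Linear(400, 120), Tanh(),
    Linear(120, 84), Tanh(),
    Linear(84, 4),
)

x = torch.zeros(1, 3, 32, 32)  # dummy batch: one 32x32 RGB image
shapes = []
for layer in model:
    x = layer(x)
    shapes.append(tuple(x.shape))
    print(f"{layer.__class__.__name__:>10}: {tuple(x.shape)}")
```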

Model training: API differences

In TF, thanks to Keras's powerful and concise API, training the model is also very simple. The code is as follows:

# Train the model (fit_generator is deprecated since TF 2.1; fit accepts generators directly)
history = model.fit(
    train_generator,
    steps_per_epoch=nb_train_samples // batch_size,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=nb_validation_samples // batch_size)

In Torch, however, you have to write the training loop yourself step by step, which is a bit more tedious:

import torch

n_total_steps = len(train_loader)
for epoch in range(num_epochs):
    model.train()  # training mode (matters for dropout / batch-norm layers)
    for i, (images, labels) in enumerate(train_loader):
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward pass and optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log progress every 5 batches
        if (i + 1) % 5 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], '
                  f'Step [{i + 1}/{n_total_steps}], '
                  f'Loss: {loss.item():.4f}')

# Save only the learned parameters
torch.save(model.state_dict(), './ckpt')
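The loop above only trains, while Keras's `fit` also evaluates on `validation_data` after every epoch and returns a `History` object. Below is a minimal sketch of that behaviour, assuming a model that outputs logits and loaders that yield `(images, labels)` pairs. The helper names `evaluate` and `fit` are hypothetical, and random tensors with a tiny placeholder model stand in for the real data and CNN in the usage part.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader, criterion):
    """Return (mean loss, accuracy) over a loader without tracking gradients."""
    model.eval()  # switches off dropout / batch-norm updates, if any
    total_loss, correct, count = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images)
            total_loss += criterion(outputs, labels).item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            count += labels.size(0)
    return total_loss / count, correct / count


def fit(model, train_loader, val_loader, criterion, optimizer, epochs):
    """Keras-like fit: train each epoch, then validate; returns a history dict."""
    history = {"loss": [], "val_loss": [], "val_accuracy": []}
    for epoch in range(epochs):
        model.train()
        running_loss, count = 0.0, 0
        for images, labels in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * labels.size(0)
            count += labels.size(0)
        val_loss, val_acc = evaluate(model, val_loader, criterion)
        history["loss"].append(running_loss / count)
        history["val_loss"].append(val_loss)
        history["val_accuracy"].append(val_acc)
        print(f"Epoch {epoch + 1}/{epochs}: loss={running_loss / count:.4f} "
              f"val_loss={val_loss:.4f} val_acc={val_acc:.3f}")
    return history


# Usage with random stand-in data: 64 "images" of 3x32x32 from 4 classes
torch.manual_seed(0)
data = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 4, (64,)))
train_loader = DataLoader(data, batch_size=16, shuffle=True)
val_loader = DataLoader(data, batch_size=16)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 4))  # tiny placeholder model
history = fit(model, train_loader, val_loader, nn.CrossEntropyLoss(),
              torch.optim.SGD(model.parameters(), lr=0.01), epochs=2)
```

Computing accuracy by hand with `argmax` plays the role of Keras's `metrics=['accuracy']`, which has no built-in counterpart in core PyTorch.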

Summary

Make good use of search engines; both frameworks document their APIs in detail in their official documentation.

References


  1. Image preprocessing - Keras documentation (Chinese edition).


Copyright notice
This article was written by [zorchp]; please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/02/202202220540410302.html