Converting Keras code to PyTorch for CNN image classification: approach and implementation
2022-06-22 07:08:00 【zorchp】
tags: Python DL
Preface
A few days ago I converted a piece of Python code for a convolutional neural network, a deep learning model used here to solve a classification problem. The code was written with TensorFlow's Keras interface, and the task was to translate it into PyTorch. Since the two APIs are similar, the conversion is not too difficult, but a few details need attention. I record them here for reference.
Library imports
First, let's look at how the two popular deep learning frameworks differ in their imports, which requires a brief look at their main structure. For convenience, TF below always refers to TensorFlow 2.x with Keras, and Torch always refers to PyTorch.
Model building
First comes the construction of the model. In TF, a model can be built easily with Sequential, which needs to be imported:
from tensorflow.keras.models import Sequential
In Torch, a model can of course also be built with Sequential (although the official documentation still recommends the object-oriented approach of subclassing nn.Module).
The import is:
from torch.nn import Sequential
Speaking of model building, several layers are commonly used in convolutional neural networks: convolution layers, max pooling layers, and fully connected layers (with softmax at the output). Both frameworks provide them out of the box. Let's see how to import them.
In TF:
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.layers import Activation, Dropout, Flatten, Dense
And in Torch:
from torch.nn import Conv2d, MaxPool2d
from torch.nn import Flatten, Linear, CrossEntropyLoss
from torch.optim import SGD
As you can see, the two differ only slightly: TF passes some activation functions as layer arguments, while Torch provides them as separate modules.
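To make the activation difference concrete, here is a minimal sketch. The layer sizes (4 inputs, 3 outputs) are arbitrary assumptions chosen only for illustration. A Keras layer such as Dense(3, activation='tanh') would typically become a Linear module followed by a separate Tanh module in Torch:

```python
import torch
from torch.nn import Linear, Sequential, Tanh

# Keras: Dense(3, activation='tanh') applied to 4 input features.
# Torch: the activation is its own module placed after the linear layer.
# The sizes 4 and 3 are illustrative assumptions, not taken from the article.
dense_tanh = Sequential(
    Linear(in_features=4, out_features=3),
    Tanh(),
)

x = torch.randn(2, 4)   # a batch of 2 samples with 4 features each
y = dense_tanh(x)
print(y.shape)          # torch.Size([2, 3])
```

Forgetting the separate activation module is an easy mistake when converting, because the Torch code still runs without it.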
Data loading
Finally, let's look at data loading. In TF, image data can be read and preprocessed conveniently with the following:
from tensorflow.keras import backend
from tensorflow.keras.preprocessing.image import ImageDataGenerator
In Torch, similar imports are needed:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
API differences in data loading and preprocessing
For reading data, I still find Keras more convenient. Torch takes a more modular approach: you first instantiate a transform object and then use it to process the images.
Let's look at the TF code for reading image data:
# Determine the input shape from the backend's channel ordering
if backend.image_data_format() == 'channels_first':
    input_shape = (3, img_width, img_height)
else:
    input_shape = (img_width, img_height, 3)
# Training set image augmentation
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)
# Test set preprocessing (only rescaling)
test_datagen = ImageDataGenerator(rescale=1. / 255)
# Note: target_size is (height, width); the order only matters for non-square images
train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical')  # multi-class classification (one-hot labels)
validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical')  # multi-class classification (one-hot labels)
Next is the Torch code:
# Import data
# Torch models take channels-first batches (N, C, H, W), so no input_shape is needed
# Training set image augmentation (only the horizontal flip is carried over from Keras)
train_datagen = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.Resize((img_width, img_height))
])
# Test set preprocessing (only rescaling and resizing)
test_datagen = transforms.Compose([  # apply the following operations to each image read
    transforms.ToTensor(),  # this step is equivalent to Keras's rescale=1./255
    transforms.Resize((img_width, img_height))
])
train_generator = datasets.ImageFolder(train_data_dir,
                                       transform=train_datagen)
validation_generator = datasets.ImageFolder(validation_data_dir,
                                            transform=test_datagen)
# DataLoader is imported directly, so it is called without the torch.utils.data prefix
train_loader = DataLoader(train_generator,
                          batch_size=batch_size,
                          shuffle=True)
test_loader = DataLoader(validation_generator,
                         batch_size=batch_size,
                         shuffle=False)
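The ToTensor transform used above does two things that matter for the conversion: it scales pixel values from [0, 255] to [0, 1], and it moves the channel axis to the front. The following is a rough NumPy re-implementation of that behavior for illustration only. The helper name to_tensor_like and the 32x32 image size are assumptions of this sketch, and the real transform returns a torch.Tensor rather than a NumPy array:

```python
import numpy as np

def to_tensor_like(image_hwc_uint8):
    """Mimic ToTensor's scaling and layout on a NumPy array (illustration only)."""
    scaled = image_hwc_uint8.astype(np.float32) / 255.0   # same effect as rescale=1./255
    return np.transpose(scaled, (2, 0, 1))                 # HWC -> CHW

# A fake all-white 32x32 RGB image; the size is an assumption for this sketch.
image = np.full((32, 32, 3), 255, dtype=np.uint8)
chw = to_tensor_like(image)
print(chw.shape, chw.max())
```

This is why the Torch side needs no equivalent of the channels_first check in the Keras code: the tensors always arrive channels-first.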
API differences in model building
Now for the most important part: the API differences in building the model. In TF, layers are added directly with model.add, which makes it easy to create a CNN classification model; just make sure the data dimensions match from layer to layer. The code is concise and intuitive:
# Create the model
model = Sequential()
model.add(Conv2D(filters=6,
                 kernel_size=(5, 5),
                 padding='valid',
                 input_shape=input_shape,
                 activation='tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(filters=16,
                 kernel_size=(5, 5),
                 padding='valid',
                 activation='tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(120, activation='tanh'))
model.add(Dense(84, activation='tanh'))
model.add(Dense(4, activation='softmax'))
# Compile the model
model.compile(loss='categorical_crossentropy',
              optimizer='sgd',
              metrics=['accuracy'])
Torch offers a similar approach, but there is no compile step; the loss function and the optimizer are created separately. The code is as follows:
from torch.nn import Tanh  # activations are separate modules in Torch

# Create the model (Tanh modules mirror activation='tanh' in the Keras version)
model = Sequential(
    Conv2d(in_channels=3,
           out_channels=6,
           kernel_size=(5, 5),
           padding='valid'),
    Tanh(),
    MaxPool2d(kernel_size=(2, 2)),
    Conv2d(in_channels=6,
           out_channels=16,
           kernel_size=(5, 5),
           padding='valid'),
    Tanh(),
    MaxPool2d(kernel_size=(2, 2)),
    Flatten(),
    Linear(400, 120),
    Tanh(),
    Linear(120, 84),
    Tanh(),
    Linear(84, 4)  # no softmax here: CrossEntropyLoss applies log-softmax internally
)
# Use cross entropy as the loss function
criterion = CrossEntropyLoss()
# Use stochastic gradient descent as the optimizer
optimizer = SGD(model.parameters(), lr=0.001)
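The criterion defined above differs from Keras's categorical_crossentropy in what it expects as input. The Keras setup compares softmax probabilities with one-hot labels, whereas CrossEntropyLoss takes raw scores (logits) and integer class indices and applies log-softmax internally, which is why the Torch model ends with Linear(84, 4) and no softmax. The NumPy sketch below, using made-up scores, suggests that the two formulations give the same value:

```python
import numpy as np

def log_softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)   # for numerical stability
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

# Made-up scores for 2 samples and 4 classes (illustrative values only).
logits = np.array([[2.0, 0.5, -1.0, 0.1],
                   [0.3, 0.2, 1.5, -0.7]])
labels = np.array([0, 2])            # integer class indices, as ImageFolder yields
one_hot = np.eye(4)[labels]          # one-hot labels, as class_mode='categorical' yields

# Keras-style: softmax probabilities compared against one-hot labels.
probs = np.exp(log_softmax(logits))
keras_style = -(one_hot * np.log(probs)).sum(axis=1).mean()

# Torch-style: log-softmax of the logits, indexed by the class index.
torch_style = -log_softmax(logits)[np.arange(len(labels)), labels].mean()

print(np.isclose(keras_style, torch_style))   # True
```

Both frameworks average the per-sample loss over the batch by default, so the values are directly comparable.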
There are still some API differences here. For example, the fully connected layer is written and parameterized differently (Linear takes both the input and the output size, while Dense takes only the output size), and the convolution layer likewise needs the number of input channels. As before, pay close attention to the data dimensions.
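The dimension bookkeeping can be checked with a few lines of arithmetic. The sketch below assumes 32x32 input images, which is the size implied by Linear(400, 120) but is not stated in the article. The helper names are hypothetical:

```python
def conv_output_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1

def flattened_features(img_size, out_channels=16):
    """Number of features reaching the first Linear layer of the model above."""
    size = conv_output_size(img_size, kernel=5)            # first 5x5 'valid' convolution
    size = conv_output_size(size, kernel=2, stride=2)      # first 2x2 max pooling
    size = conv_output_size(size, kernel=5)                # second 5x5 'valid' convolution
    size = conv_output_size(size, kernel=2, stride=2)      # second 2x2 max pooling
    return out_channels * size * size

print(flattened_features(32))   # 400
```

For any other image size, the first argument of the first Linear layer has to change accordingly; Keras infers this value automatically, Torch does not.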
API differences in model training
In TF, thanks to Keras's powerful and concise API, training the model is also very simple. The code is as follows:
# Train the model
# (fit_generator is deprecated in TF 2.x; fit accepts generators directly)
history = model.fit(
    train_generator,
    steps_per_epoch=nb_train_samples // batch_size,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=nb_validation_samples // batch_size)
In Torch, however, the training loop has to be written out step by step, which is a bit more tedious:
import torch  # needed for torch.save

n_total_steps = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Report the loss every 5 steps
        if (i + 1) % 5 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], '
                  f'Step [{i+1}/{n_total_steps}], '
                  f'Loss: {loss.item():.4f}')
# Save the trained weights
torch.save(model.state_dict(), './ckpt')
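Keras reports validation accuracy during training through the validation_data argument, while the Torch loop above only trains. A minimal sketch of an evaluation pass follows. The evaluate helper is hypothetical and not part of the original code, and a tiny stand-in model with random data is used only to keep the example self-contained; in the setting of this article one would pass model and test_loader instead:

```python
import torch
from torch.nn import Linear
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, loader):
    """Return the classification accuracy of `model` over all batches in `loader`."""
    model.eval()                       # switch layers such as dropout to inference mode
    correct, total = 0, 0
    with torch.no_grad():              # gradients are not needed for evaluation
        for inputs, labels in loader:
            predicted = model(inputs).argmax(dim=1)   # class with the highest score
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total

# Stand-in model and random data, used only to make the sketch runnable.
toy_model = Linear(8, 4)
toy_data = TensorDataset(torch.randn(20, 8), torch.randint(0, 4, (20,)))
accuracy = evaluate(toy_model, DataLoader(toy_data, batch_size=5))
print(f"accuracy: {accuracy:.2f}")
```

Calling such a function at the end of each epoch would give roughly the same per-epoch feedback that Keras prints; remember to call model.train() again before resuming training.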
Summary
Make good use of search engines; the official documentation of both frameworks describes the API usage in detail.