[pytorch] CNN practice - flower species identification
2022-07-26 06:16:00 【Li Junfeng】
Dataset
This example uses a public dataset from Kaggle; the download link is as follows:
https://www.kaggle.com/datasets/alxmamaev/flowers-recognition
The dataset contains photos of flowers in 5 classes, with more than 4,000 images in total.
Data processing
The entire dataset is not large, so it can be loaded into memory up front instead of being read from disk every time it is needed, which noticeably improves running speed.
Since the number of images is limited, image augmentation is used.
Reading the dataset
The Kaggle dataset already groups the images into one folder per class, so the folder name is used as the label when reading the images.
import pathlib
import cv2 as cv
from torch.utils.data import Dataset

class Flower_Dataset(Dataset):
    def __init__(self, path, is_train, augs):
        data_root = pathlib.Path(path)
        all_image_paths = list(data_root.glob('*/*'))
        self.all_image_paths = [str(p) for p in all_image_paths]
        # Folder names serve as the class labels
        label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
        label_to_index = dict((label, index) for index, label in enumerate(label_names))
        # Read every image into memory up front
        self.all_image = [cv.imread(p) for p in self.all_image_paths]
        self.all_image_labels = [label_to_index[p.parent.name] for p in all_image_paths]
        self.is_train = is_train
        self.augs = augs
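The excerpt only shows the constructor. A minimal sketch of the remaining Dataset methods is given below; it assumes torch is imported, resizes each image to 224x224 (the input size used in the model summary later), and applies the augmentations only on the training split. These details are assumptions, not taken from the original post.

    # (continuation of Flower_Dataset; hypothetical, not from the original post)
    def __len__(self):
        return len(self.all_image_labels)

    def __getitem__(self, index):
        # OpenCV loads BGR HWC uint8; convert to an RGB CHW float tensor in [0, 1]
        img = cv.cvtColor(self.all_image[index], cv.COLOR_BGR2RGB)
        img = cv.resize(img, (224, 224))
        img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
        # Apply augmentations only when building the training split
        if self.is_train and self.augs is not None:
            img = self.augs(img)
        return img, self.all_image_labels[index]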
Image augmentation
A flower photo flipped horizontally is still a flower, so random horizontal flipping is a safe augmentation.
Likewise, adjustments to brightness, contrast, saturation, and hue can be applied.
color_aug = torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
augs = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(), color_aug])
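With the augmentation pipeline defined, the two dataset splits and the batch size used by the iterators below could be set up as follows; the dataset path and the batch size here are assumed values, not given in the excerpt.

# Hypothetical setup: the dataset path and batch size are assumptions
batch_size = 64
train_set = Flower_Dataset('./flowers', is_train=True, augs=augs)
test_set = Flower_Dataset('./flowers', is_train=False, augs=None)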
Iterator
Each iteration draws one batch from the dataset.
The training set is shuffled; the test set keeps its original order.
from torch.utils.data import DataLoader

train_iter = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4)
test_iter = DataLoader(test_set, batch_size=batch_size, num_workers=4)
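As a quick sanity check, one batch can be drawn from the training iterator to confirm the tensor shapes (illustrative only):

X, Y = next(iter(train_iter))
print(X.shape, Y.shape)   # e.g. torch.Size([64, 3, 224, 224]) torch.Size([64])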
CNN Model
The classic ResNet architecture is used. Because the dataset is small, an overly complex network is not appropriate, so ResNet-18 is chosen; the summary below lists 68 layers in total, which is not too deep. The specific structure is as follows:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,408
BatchNorm2d-2 [-1, 64, 112, 112] 128
ReLU-3 [-1, 64, 112, 112] 0
MaxPool2d-4 [-1, 64, 56, 56] 0
Conv2d-5 [-1, 64, 56, 56] 36,864
BatchNorm2d-6 [-1, 64, 56, 56] 128
ReLU-7 [-1, 64, 56, 56] 0
Conv2d-8 [-1, 64, 56, 56] 36,864
BatchNorm2d-9 [-1, 64, 56, 56] 128
ReLU-10 [-1, 64, 56, 56] 0
BasicBlock-11 [-1, 64, 56, 56] 0
Conv2d-12 [-1, 64, 56, 56] 36,864
BatchNorm2d-13 [-1, 64, 56, 56] 128
ReLU-14 [-1, 64, 56, 56] 0
Conv2d-15 [-1, 64, 56, 56] 36,864
BatchNorm2d-16 [-1, 64, 56, 56] 128
ReLU-17 [-1, 64, 56, 56] 0
BasicBlock-18 [-1, 64, 56, 56] 0
Conv2d-19 [-1, 128, 28, 28] 73,728
BatchNorm2d-20 [-1, 128, 28, 28] 256
ReLU-21 [-1, 128, 28, 28] 0
Conv2d-22 [-1, 128, 28, 28] 147,456
BatchNorm2d-23 [-1, 128, 28, 28] 256
Conv2d-24 [-1, 128, 28, 28] 8,192
BatchNorm2d-25 [-1, 128, 28, 28] 256
ReLU-26 [-1, 128, 28, 28] 0
BasicBlock-27 [-1, 128, 28, 28] 0
Conv2d-28 [-1, 128, 28, 28] 147,456
BatchNorm2d-29 [-1, 128, 28, 28] 256
ReLU-30 [-1, 128, 28, 28] 0
Conv2d-31 [-1, 128, 28, 28] 147,456
BatchNorm2d-32 [-1, 128, 28, 28] 256
ReLU-33 [-1, 128, 28, 28] 0
BasicBlock-34 [-1, 128, 28, 28] 0
Conv2d-35 [-1, 256, 14, 14] 294,912
BatchNorm2d-36 [-1, 256, 14, 14] 512
ReLU-37 [-1, 256, 14, 14] 0
Conv2d-38 [-1, 256, 14, 14] 589,824
BatchNorm2d-39 [-1, 256, 14, 14] 512
Conv2d-40 [-1, 256, 14, 14] 32,768
BatchNorm2d-41 [-1, 256, 14, 14] 512
ReLU-42 [-1, 256, 14, 14] 0
BasicBlock-43 [-1, 256, 14, 14] 0
Conv2d-44 [-1, 256, 14, 14] 589,824
BatchNorm2d-45 [-1, 256, 14, 14] 512
ReLU-46 [-1, 256, 14, 14] 0
Conv2d-47 [-1, 256, 14, 14] 589,824
BatchNorm2d-48 [-1, 256, 14, 14] 512
ReLU-49 [-1, 256, 14, 14] 0
BasicBlock-50 [-1, 256, 14, 14] 0
Conv2d-51 [-1, 512, 7, 7] 1,179,648
BatchNorm2d-52 [-1, 512, 7, 7] 1,024
ReLU-53 [-1, 512, 7, 7] 0
Conv2d-54 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-55 [-1, 512, 7, 7] 1,024
Conv2d-56 [-1, 512, 7, 7] 131,072
BatchNorm2d-57 [-1, 512, 7, 7] 1,024
ReLU-58 [-1, 512, 7, 7] 0
BasicBlock-59 [-1, 512, 7, 7] 0
Conv2d-60 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-61 [-1, 512, 7, 7] 1,024
ReLU-62 [-1, 512, 7, 7] 0
Conv2d-63 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-64 [-1, 512, 7, 7] 1,024
ReLU-65 [-1, 512, 7, 7] 0
BasicBlock-66 [-1, 512, 7, 7] 0
AdaptiveAvgPool2d-67 [-1, 512, 1, 1] 0
Linear-68 [-1, 5] 2,565
================================================================
Total params: 11,179,077
Trainable params: 11,179,077
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 62.79
Params size (MB): 42.64
Estimated Total Size (MB): 106.00
----------------------------------------------------------------
Fine-tuning
The images in this dataset are similar to those in ImageNet, so fine-tuning a pretrained model is applicable.
The only layer that needs to be replaced is the last fully connected layer, whose output size is changed to 5.
In addition, different parameter groups are given different learning rates: the pretrained layers use a small base learning rate, while the newly initialized output layer uses a much larger one.
import torch
import torchvision
from torch import nn
from torchsummary import summary

net = torchvision.models.resnet18(pretrained=True)
net.fc = nn.Linear(net.fc.in_features, 5)   # replace the output layer: 5 flower classes
nn.init.xavier_uniform_(net.fc.weight)
summary(net, input_size=(3, 224, 224), device="cpu")

lr = 0.0005
loss = nn.CrossEntropyLoss(reduction="mean")
# Pretrained parameters use the base learning rate; the new fc layer uses 80x that
params_1x = [param for name, param in net.named_parameters()
             if name not in ["fc.weight", "fc.bias"]]
trainer = torch.optim.SGD([{'params': params_1x},
                           {'params': net.fc.parameters(), 'lr': lr * 80}],
                          lr=lr, weight_decay=0.001)
Training
This part is similar to training any other neural network, so only the main loop is shown.
from tqdm import tqdm
import numpy as np

# Training loop (epochs and the testing() helper are defined elsewhere in the full code)
Accuracies = []
Losses = []
T_Accuracies = []
T_Losses = []
for epoch in range(epochs):
    net.train()
    loop = tqdm(enumerate(train_iter), total=len(train_iter))  # progress bar
    loop.set_description(f'Epoch [{epoch + 1}/{epochs}]')
    T_Accuracies.append(0)
    T_Losses.append(0)
    for index, (X, Y) in loop:
        scores = net(X)
        l = loss(scores, Y)
        trainer.zero_grad()
        l.backward()
        trainer.step()
        # Exponentially smoothed running training accuracy and loss
        _, predictions = scores.max(1)
        num_correct = (predictions == Y).sum()
        running_train_acc = float(num_correct) / float(X.shape[0])
        if index == 0:
            T_Accuracies[-1] = running_train_acc
            T_Losses[-1] = l.item()
        else:
            T_Accuracies[-1] = T_Accuracies[-1] * 0.9 + 0.1 * running_train_acc
            T_Losses[-1] = T_Losses[-1] * 0.9 + 0.1 * l.item()
        loop.set_postfix(loss='{:.3f}'.format(T_Losses[-1]),
                         accuracy='{:.3f}'.format(T_Accuracies[-1]))
    a, b = testing()
    Accuracies.append(a)
    Losses.append(b)
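The testing() helper called at the end of each epoch is not shown in the excerpt. A minimal sketch, assuming it runs net on test_iter and returns the overall test accuracy and mean loss, could look like this:

def testing():
    # Hypothetical evaluation helper; the original post does not show its body
    net.eval()
    total_correct, total_samples, total_loss = 0, 0, 0.0
    with torch.no_grad():
        for X, Y in test_iter:
            scores = net(X)
            total_loss += loss(scores, Y).item() * X.shape[0]
            total_correct += (scores.argmax(dim=1) == Y).sum().item()
            total_samples += X.shape[0]
    return total_correct / total_samples, total_loss / total_samples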
Results
Using the recorded training-set and test-set metrics, the accuracy and loss curves are plotted (see the plotting sketch at the end of this section):
Both training and test accuracy are high, which shows that fine-tuning is effective.
The test accuracy exceeds 90% by the fifth epoch, so the model reaches a good level in a short time.
In addition, the training accuracy is lower than the test accuracy, because image augmentation is applied to the training set but not to the test set.
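The plotting code is not included in the excerpt; a simple matplotlib sketch over the recorded lists could look like this:

import matplotlib.pyplot as plt

# Plot smoothed training metrics against per-epoch test metrics
epoch_axis = range(1, len(Accuracies) + 1)
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(epoch_axis, T_Accuracies, label='train accuracy')
plt.plot(epoch_axis, Accuracies, label='test accuracy')
plt.xlabel('epoch')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(epoch_axis, T_Losses, label='train loss')
plt.plot(epoch_axis, Losses, label='test loss')
plt.xlabel('epoch')
plt.legend()
plt.show()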
Complete code