Week 3 weekly report: ResNet + ResNeXt
2022-07-29 08:19:00 【_ Salt Baked Chicken】
Video learning + Paper reading
ResNet
ResNet (Residual Network): its main contribution is the idea of the residual block, which alleviates the vanishing-gradient and degradation problems of deep neural networks and makes it practical to train very deep networks.
There are generally two kinds of residual blocks, the basic block and the bottleneck; in the figure above, the basic block is on the left and the bottleneck on the right.
At the same time, to keep the input and output shapes consistent, there is also another variant of the block.
The dotted-line connection there is usually called the shortcut.
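To make the block concrete, here is a minimal basic-block sketch in PyTorch (my own illustration, not part of the assignment code): the shortcut adds the input x back onto the residual branch F(x), and a 1x1 convolution stands in for the dotted-line shortcut whenever the stride or channel count changes.

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.shortcut = nn.Sequential()  # identity shortcut by default
        if stride != 1 or in_ch != out_ch:
            # the "dotted line": project x so it matches the branch output shape
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch))

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + self.shortcut(x))  # H(x) = F(x) + x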
ResNet trains well for three reasons:
- The gradients stay well-behaved, which suits SGD training.
- The intrinsic model complexity does not really increase, so overfitting is not as severe.
- It uses BN, which improves training further.
ResNeXt
My first impression of ResNeXt was that it simply changes the residual blocks inside ResNet into a grouped form, adopting grouped convolution and thereby absorbing the advantages of both VGG and ResNet.
[Figure: three equivalent forms of the ResNeXt block]
The three forms of the residual block are equivalent, so the overall architecture of ResNeXt is essentially the same as ResNet's; only the residual block changes.
This structure adopts grouped convolution, which reduces the parameter count, and that is how the better results are obtained.
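As a quick check of the parameter savings (my own sketch; the 32 groups match the cardinality of ResNeXt-50 32x4d), compare a plain 3x3 convolution with its grouped counterpart:

import torch.nn as nn

dense = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False)
grouped = nn.Conv2d(256, 256, kernel_size=3, padding=1, groups=32, bias=False)
print(sum(p.numel() for p in dense.parameters()))    # 589824 = 256*256*3*3
print(sum(p.numel() for p in grouped.parameters()))  # 18432 = 256*(256/32)*3*3, 32x fewer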
Coding assignment
LeNet
Imports and dataset download
import os
import sys
import time
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from PIL import Image
import json
import shutil
from torch.optim import lr_scheduler
from tqdm import tqdm  # used by the ResNet training loop below
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import datasets, models, transforms
Next, mount Google Drive and check its files.
# Mount Google Drive to use files stored there
from google.colab import drive
drive.mount('/content/drive/')
path = "/content/drive/"
os.listdir(path)
Download the dataset from the network and extract it; I also install the rar package here to unpack it.
# Download the dataset online
! wget https://static.leiphone.com/cat_dog.rar
! apt-get install rar
! unrar x cat_dog.rar
Check GPU availability
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('gpu: %s' % torch.cuda.is_available())
Labeling
Because I read the data with ImageFolder, the folder names serve as labels, so the images in the dataset need to be moved into per-class folders.
# Move each image into a subfolder named after its class ('dog' or 'cat'),
# so that ImageFolder can use the folder names as labels
# (same logic as the original, folded into one helper for train and val).
def sort_into_class_dirs(src_dir_path, keys=('dog', 'cat')):
    for i in keys:
        class_dir = src_dir_path + "/" + i
        if not os.path.exists(class_dir):
            print("class dir does not exist, creating " + class_dir)
            os.mkdir(class_dir)
        for file in os.listdir(src_dir_path):
            # only move plain files whose name contains the class key
            if os.path.isfile(src_dir_path + '/' + file) and i in file:
                print('found file containing "' + i + '": ' + src_dir_path + '/' + file)
                shutil.move(src_dir_path + '/' + file, class_dir + '/' + file)

sort_into_class_dirs('/content/sample_data/cat_dog/train')  # training set
sort_into_class_dirs('/content/sample_data/cat_dog/val')    # validation set
# The test set has no labels, so move every image into a single
# 'catordogs' subfolder (every filename contains 'j' from '.jpg').
src_dir_path = '/content/sample_data/cat_dog/test'
to_dir_path = '/content/sample_data/cat_dog/test/catordogs'
key = 'j'  # files whose names contain this character are moved to to_dir_path
if not os.path.exists(to_dir_path):
    print("to_dir_path does not exist, creating it")
    os.mkdir(to_dir_path)
if os.path.exists(src_dir_path):
    print("src_dir_path exists")
    for file in os.listdir(src_dir_path):
        # only move plain files
        if os.path.isfile(src_dir_path + '/' + file) and key in file:
            print('found file containing "' + key + '": ' + src_dir_path + '/' + file)
            print('moved to ' + to_dir_path + '/' + file)
            shutil.move(src_dir_path + '/' + file, to_dir_path + '/' + file)
Data processing
I set the image size to 128 here; change `img_size` to use a different size. After setting it up, the dataset statistics are printed.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
img_size = 128  # renamed from `type`, which shadows the Python builtin
LeNet_format = transforms.Compose([transforms.Resize((img_size, img_size)),
                                   transforms.ToTensor(),
                                   normalize
                                   ])
data_dir = './'
dsets = {
x: datasets.ImageFolder(os.path.join(data_dir, x), LeNet_format)for x in ['test', 'train', 'val']}
dset_sizes = {
x: len(dsets[x]) for x in ['test', 'train', 'val']}
dset_classes = dsets['train'].classes
print(dsets['train'].classes)
print(dsets['train'].class_to_idx)
print(dsets['train'].imgs[:5])
print('dset_sizes: ', dset_sizes)
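The training loop further below uses loader_train, but its construction is not shown in this section (DataLoaders only appear in the ResNet part). Presumably it was created along these lines; the batch size here is my guess:

loader_train = torch.utils.data.DataLoader(dsets['train'], batch_size=32, shuffle=True, num_workers=2)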
model
This LeNet uses ReLU, so it is not a strict LeNet; it is really just a plain CNN.
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.conv3 = nn.Conv2d(16, 32, 4)
        self.conv4 = nn.Conv2d(32, 32, 4)
        self.conv5 = nn.Conv2d(32, 32, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32, 16)
        self.fc2 = nn.Linear(16, 2)

    def forward(self, x):
        # spatial size per stage for a 128x128 input: 62 -> 29 -> 13 -> 5 -> 1
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.pool(F.relu(self.conv4(x)))
        x = F.relu(self.conv5(x))
        x = x.view(-1, 32)
        x = F.relu(self.fc1(x))
        # return raw logits: CrossEntropyLoss applies log-softmax internally,
        # so an explicit softmax here would squash the gradients
        return self.fc2(x)
# Put the network on the GPU
net = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
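A quick shape sanity check (my addition): with a 128x128 input, the five conv/pool stages shrink the feature map to 1x1 with 32 channels, which is why view(-1, 32) works.

dummy = torch.zeros(1, 3, 128, 128).to(device)
print(net(dummy).shape)  # torch.Size([1, 2])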
Training
This training loop is based on code practice 2; there is nothing special to say about it.
for epoch in range(30):  # iterate over the dataset multiple times
    for i, (inputs, labels) in enumerate(loader_train):
        inputs = inputs.to(device)
        labels = labels.to(device)
        # zero the optimizer's gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # print statistics
        if i % 100 == 0:
            print('Epoch: %d Minibatch: %5d loss: %.3f' % (epoch + 1, i + 1, loss.item()))
print('Finished Training')
Output results
net.eval()  # switch to inference mode
resfile = open('LeNet.csv', 'w')
for i in range(0, 2000):
    img_PIL = Image.open('./test/catordogs/' + str(i) + '.jpg')
    # apply the same normalization as at training time
    img_tensor = transforms.Compose([transforms.Resize((128, 128)),
                                     transforms.ToTensor(),
                                     normalize])(img_PIL)
    img_tensor = img_tensor.unsqueeze(0).to(device)  # add the batch dimension
    out = net(img_tensor).cpu().detach().numpy()
    if out[0, 0] < out[0, 1]:
        resfile.write(str(i) + ',' + str(1) + '\n')
    else:
        resfile.write(str(i) + ',' + str(0) + '\n')
resfile.close()
result
ResNet
Much of the ResNet code reuses the LeNet code, so I only describe the differences here.
Data processing
There are two differences from the LeNet setup. First, the images are cropped to 224x224 to match the network's expected input. Second, the batch_size: if it is too large, GPU memory blows up, and I burned through several Colab accounts this way over the past few days....
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
resnet_format = transforms.Compose([
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])
data_dir = '/content/sample_data/cat_dog'
dsets = {
x: datasets.ImageFolder(os.path.join(data_dir, x), resnet_format)
for x in ['test','train', 'val']}
dset_sizes = {
x: len(dsets[x]) for x in ['test', 'train', 'val']}
dset_classes = dsets['train'].classes
loader_train = torch.utils.data.DataLoader(dsets['train'], batch_size=32, shuffle=True, num_workers=6)
loader_valid = torch.utils.data.DataLoader(dsets['val'], batch_size=5, shuffle=False, num_workers=6)
loader_test = torch.utils.data.DataLoader(dsets['test'], batch_size=5, shuffle=False, num_workers=6)
ResNet50/152
Here I use transfer learning directly, running the ready-made pretrained models from PyTorch. I use net_0 because of an error earlier, possibly from a missing instantiation?
For the optimizer I tried both Adam and SGD; SGD works better.
net = models.resnet152(pretrained=True)
net_0 = net  # note: this is an alias, not a copy - net and net_0 are the same object
print(net_0)  # inspect the network architecture
net_0.fc = nn.Linear(2048, 2, bias=True)  # replace the last FC layer for binary classification
net_0 = net_0.to(device)
# construct an optimizer
# params = [p for p in net_0.parameters() if p.requires_grad]
# optimizer = optim.Adam(params, lr=0.0001)
train and val
This is mainly adapted from an existing piece of code: the validation results after each epoch decide which weights to keep, and I added a learning-rate schedule on top.
lr = 0.0001
optimizer = torch.optim.SGD(net_0.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
save_path = 'resnet152_best.pth'  # checkpoint path; not shown in the original, any filename works
val_num = len(dsets['val'])
loss_function = nn.CrossEntropyLoss()
epochs = 3
best_acc = 0.0
train_steps = len(loader_train)
for epoch in range(epochs):
    # train
    net.train()
    running_loss = 0.0
    train_bar = tqdm(loader_train, file=sys.stdout)
    for step, data in enumerate(train_bar):
        images, labels = data
        optimizer.zero_grad()
        logits = net(images.to(device))
        loss = loss_function(logits, labels.to(device))
        loss.backward()
        optimizer.step()
        # print statistics
        running_loss += loss.item()
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

    # validate
    net.eval()
    acc = 0.0  # accumulate the number of correct predictions per epoch
    with torch.no_grad():
        val_bar = tqdm(loader_valid, file=sys.stdout)
        for val_data in val_bar:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            # loss = loss_function(outputs, test_labels)
            predict_y = torch.max(outputs, dim=1)[1]
            acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
            val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

    val_accurate = acc / val_num
    print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
          (epoch + 1, running_loss / train_steps, val_accurate))
    if val_accurate > best_acc:
        best_acc = val_accurate
        torch.save(net.state_dict(), save_path)
    scheduler.step()

print('Finished Training')
test
net.eval()  # switch to inference mode so BatchNorm uses its running statistics
resfile = open('resnet1523.csv', 'w')
with torch.no_grad():
    for i in range(0, 2000):
        img_PIL = Image.open('/content/sample_data/cat_dog/test/catordogs/' + str(i) + '.jpg')
        # apply the same normalization as at training time
        img_tensor = transforms.Compose([transforms.Resize(224),
                                         transforms.CenterCrop(224),
                                         transforms.ToTensor(),
                                         normalize])(img_PIL)
        img_tensor = img_tensor.unsqueeze(0).to(device)  # add the batch dimension
        out = net(img_tensor).cpu().numpy()
        if out[0, 0] < out[0, 1]:
            resfile.write(str(i) + ',' + str(1) + '\n')
        else:
            resfile.write(str(i) + ',' + str(0) + '\n')
resfile.close()
result
The original author's result: [screenshot]
Mine: [screenshot]
You can see a fair improvement; updating the learning rate is effective.
ResNeXt
I also tried ResNeXt.
model
This requires importing the weights enum.
from torchvision.models import ResNeXt101_32X8D_Weights
net = models.resnext101_32x8d(weights=ResNeXt101_32X8D_Weights.IMAGENET1K_V1)
net_0 = net
net_0.fc = nn.Linear(2048, 2, bias=True)
net_0 = net_0.to(device)
# print(net_0)
result
epoch=3, batch_size=32.
I notice that ResNeXt does worse than ResNet here. It may be because I used only three epochs, or a batch_size of only 16 (anything larger blows up GPU memory), or perhaps there is a problem in my CSV-output code.
Thinking questions
- Residual learning
Problem solved: the vanishing-gradient and degradation problems, i.e. the gradients and the accuracy of the model no longer weaken as the network deepens.
Main idea: add new layers on top of a shallow network and have those layers learn the residual F(x) = H(x) - x, so the block outputs F(x) + x.
The residual branch only adds x back in, so the model's complexity does not change. Gradients propagate better, and no matter how many layers are added, the useful shallow network in front always remains useful; this also suits SGD very well.
- Batch Normalization principle
BN is a normalization operation: each feature (channel) within a batch is rescaled to a zero-mean, unit-variance distribution, followed by a learned affine transform, which gives the network inputs it can digest.
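A small sketch of what that means in code (illustrative, not from the assignment): in training mode, nn.BatchNorm2d normalizes each channel over the batch and spatial dimensions, so the output is roughly zero-mean, unit-variance per channel before the affine transform (which is the identity at initialization).

import torch
import torch.nn as nn

x = torch.randn(8, 3, 16, 16) * 5 + 2        # a batch with mean ~2, std ~5
bn = nn.BatchNorm2d(3)                       # training mode by default
y = bn(x)
print(y.mean(dim=(0, 2, 3)))                 # ~0 for each channel
print(y.var(dim=(0, 2, 3), unbiased=False))  # ~1 for each channel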
- Why can grouped convolution improve accuracy? That is, if grouped convolution improves accuracy while also reducing computation, why not use as many groups as possible?
Grouped convolution reduces the parameter count, and it can be viewed as a sparse version of an ordinary convolution, which has a regularizing effect.
But in grouped convolution, information only flows within each group. With too many groups, exchanging information across channels becomes too hard, which hurts accuracy. (The sketch below illustrates this isolation.)
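A tiny sketch of that isolation (my illustration): perturbing the input channels of one group leaves the other group's outputs untouched, because no convolution weight connects across groups.

import torch
import torch.nn as nn

conv = nn.Conv2d(8, 8, kernel_size=3, padding=1, groups=2, bias=False)
x = torch.randn(1, 8, 16, 16)
x2 = x.clone()
x2[:, 4:] += 1.0                         # perturb only the second group's inputs
y, y2 = conv(x), conv(x2)
print(torch.equal(y[:, :4], y2[:, :4]))  # True: the first group never sees the change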
- How does Res2Net use split ("block") convolution to reduce computation while improving network performance?
With the Res2Net structure, the network has fewer parameters, while its receptive fields can capture both fine details and global features. (A sketch of the block follows.)
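A rough sketch of the Res2Net split, based on my reading of the paper (simplified; the surrounding 1x1 convs of the bottleneck are omitted): the single 3x3 conv is replaced by smaller convs over channel splits, each split also receiving the previous split's output, so several receptive-field scales coexist in one block with fewer 3x3 parameters.

import torch
import torch.nn as nn

class Res2NetSplit(nn.Module):
    def __init__(self, channels, scales=4):
        super().__init__()
        self.width = channels // scales
        self.convs = nn.ModuleList(
            nn.Conv2d(self.width, self.width, 3, padding=1, bias=False)
            for _ in range(scales - 1))

    def forward(self, x):
        splits = torch.split(x, self.width, dim=1)
        outs, prev = [splits[0]], None    # the first split passes straight through
        for conv, s in zip(self.convs, splits[1:]):
            prev = conv(s if prev is None else s + prev)  # hierarchical connection
            outs.append(prev)
        return torch.cat(outs, dim=1)     # merge all scales back together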
- For the attention inside Vision Transformer, compare multi-head attention and grouped convolution: differences and connections.
Multi-head attention first splits the input into heads, runs self-attention on each head in the same way, and finally fuses the heads into the final result.
Grouped convolution can be understood in the same way: split the channels into groups, convolve within each group, then merge the groups back together.
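A sketch of the shared split-transform-merge pattern (my illustration; the learned Q/K/V and output projections are omitted for brevity): multi-head attention splits the embedding dimension into heads just as grouped convolution splits channels into groups, transforms each part independently, and merges the results.

import torch

B, T, D, H = 2, 5, 64, 8                               # batch, tokens, embed dim, heads
x = torch.randn(B, T, D)
heads = x.view(B, T, H, D // H).transpose(1, 2)        # split: (B, H, T, D/H)
attn = torch.softmax(heads @ heads.transpose(-2, -1) / (D // H) ** 0.5, dim=-1)
out = (attn @ heads).transpose(1, 2).reshape(B, T, D)  # merge back to (B, T, D)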
Problems encountered
Colab feels a bit inadequate: I exhausted two accounts these days, all because of the GPU limit. Also, when training ResNeXt and the like, the GPU batch size cannot be too large, or the video memory blows up.
ResNet ran a few too few epochs; my Colab quota exploded and both accounts' resources are used up, otherwise the accuracy could improve with more training.
Why is ResNet doing better than ResNeXt? Is it because my batch_size is too small? Or because there are not enough training epochs? I don't really understand this yet.
I don't quite understand ViT; I didn't have enough time and only skimmed it. I will study it carefully later.