Tutorial on building a PyTorch model from zero (V): Writing the training process -- some basic configurations
2022-06-29 13:08:00 【CV technical guide (official account)】
Preface: This article introduces how to configure the training log, why a random seed is needed and how to set one, how to configure data loading, how to configure and adjust the learning rate, and how to configure loss functions and write custom ones.
Welcome to follow the official account CV Technical Guide, which focuses on computer vision technique summaries, tracking of the latest technologies, interpretations of classic papers, and CV job information.
Configuration of training log
The training log saves information during training so that the model's training process can be reviewed afterwards.
The first step is the basic configuration.
import logging
import os

def train_logger(num):
    logger = logging.getLogger(__name__)
    # Set the logging level. There are six levels, from low to high:
    # NOTSET, DEBUG, INFO, WARNING, ERROR, CRITICAL.
    # setLevel sets the lowest level that gets logged; anything below it is dropped.
    logger.setLevel(level=logging.INFO)
    # Make sure the output directory exists, otherwise FileHandler raises an error.
    os.makedirs('./run/{}'.format(num), exist_ok=True)
    # Log to a file, with the file's name and path set here.
    file_log = logging.FileHandler('./run/{}/train.log'.format(num))
    # Log to the terminal.
    print_log = logging.StreamHandler()
    # Set the log format.
    # %(asctime)s is the current time, %(message)s is the message to log.
    formatter = logging.Formatter('%(asctime)s %(message)s')
    file_log.setFormatter(formatter)
    print_log.setFormatter(formatter)
    # Note: calling train_logger more than once adds duplicate handlers,
    # so every record would then be printed multiple times.
    logger.addHandler(file_log)
    logger.addHandler(print_log)
    return logger

Usage method
logger = train_logger(0)
logger.info("project_name: {}".format(project_name))
logger.info("batchsize: {}".format(opt.batchsize))
logger.warning("warning message")
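As a concrete illustration, here is a minimal sketch of logging inside a training loop; num_epochs and loss are placeholders for whatever the real loop defines:

# Minimal sketch: num_epochs and loss come from the actual training loop.
for epoch in range(num_epochs):
    ...  # forward pass, loss computation, backward pass omitted
    logger.info("epoch: {}, loss: {:.4f}".format(epoch, loss.item()))

Each record goes both to the terminal and to ./run/0/train.log, prefixed with the timestamp defined by the formatter.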
Set random number seed

A random seed fixes the shuffle order of the dataset across runs, making the model reproducible. Without a fixed seed, the dataset is shuffled in a different order each run, which causes small fluctuations in the model's results.
Here is a brief explanation of why the results fluctuate.

Why the model's results fluctuate
Training today almost always uses mini-batch gradient descent: after each forward pass over a batch of data, the weights are updated. Meanwhile, most models use BN, which normalizes each batch. Therefore, how the data is grouped into batches has some impact on the model's performance.
If the shuffle order differs from run to run, some runs may happen to group the samples into particularly favorable batches, so training works well, while other runs get less suitable batch compositions and cannot reach the same result.
In short, shuffling in a different order each time changes the composition of the batches, and that composition affects the model's results. The smaller the dataset, the larger the effect, because batches can differ greatly from each other; the larger the dataset, the smaller the differences between batches, and the less the random order matters.
Therefore, setting a random seed is necessary: once the seed is fixed, every training run generates the same sequence.
Random numbers on a computer are pseudorandom: they are generated deterministically from a seed, which is exactly why fixing the seed fixes the sequence.
How to seed random numbers
import random
import numpy as np
import torch
from torch.backends import cudnn

# Basic configuration
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Make cuDNN pick deterministic algorithms.
    cudnn.deterministic = True
# Usage method
# Call this in the training script. The seed value can be anything;
# as long as it stays the same, the generated sequence stays the same.
setup_seed(2022)
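A quick sanity check, as a minimal sketch: reseeding and rerunning the same call produces the same numbers.

setup_seed(2022)
print(torch.rand(3))  # three pseudorandom values
setup_seed(2022)
print(torch.rand(3))  # exactly the same three values: the sequence restarted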
Configuration of loading data

The configuration of loading data is relatively easy; the following few lines of code are enough.
from torch.utils.data import DataLoader
from dataloader import MyDataset
train_folder = opt.data_dir + '/train'
train_dataset = MyDataset(data_folder=train_folder,opt=opt)
train_loader = DataLoader(train_dataset, batch_size=opt.batchsize, shuffle=True, num_workers=8)

The main point is setting up the DataLoader. shuffle here is basically always set to True; for multi-machine, multi-GPU distributed training, however, shuffle must be False and a sampler, obtained with the code below, is used instead. num_workers is the number of worker processes used to load data.
from torch.utils.data.distributed import DistributedSampler
train_sampler = DistributedSampler(train_dataset)
train_loader = DataLoader(train_dataset, batch_size=opt.batchsize, sampler=train_sampler)
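One detail worth noting with DistributedSampler: to get a different (but process-consistent) shuffle each epoch, call set_epoch at the start of every epoch. A minimal sketch, assuming opt.num_epochs exists as in the rest of this tutorial:

for epoch in range(opt.num_epochs):
    # Reseeds the sampler so each epoch gets a different shuffle,
    # consistent across all processes.
    train_sampler.set_epoch(epoch)
    for batch in train_loader:
        ...  # training step omitted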
Network initialization
There are two main things in network initialization: one is instantiating the network, the other is loading a pretrained model. Since chapter two already covered how to load pretrained parameters for specified layers, that is not repeated here. This part is quite simple.
from model import Yolo_v1
net = Yolo_v1()
net.load_state_dict(torch.load(trained_path))
net = net.cuda()
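If the checkpoint was saved on GPU and is loaded on a CPU-only machine, or if only some layers match the current model, the loading can be relaxed. A small hedged sketch, not from the original tutorial:

net = Yolo_v1()
# map_location avoids device mismatches when the checkpoint was saved on another device.
state = torch.load(trained_path, map_location='cpu')
# strict=False skips keys that are missing from or unexpected in the current model.
net.load_state_dict(state, strict=False)
net = net.cuda()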
The setting of the learning rate

This part mainly explains how to set different learning rates for different layers. For example, the backbone uses a pretrained model while the fully connected layer is randomly initialized, so the backbone needs a small learning rate and the fully connected layer needs a large one.
import torch.optim as optim
from torch.optim import lr_scheduler
optim_params = [{'params': net.backbone.parameters(), 'lr': 0.1 * opt.lr},
{'params': net.interaction.parameters(), 'lr': opt.lr},
{'params': net.large_conv1_1.parameters(), 'lr': opt.lr},
{'params': net.large_conv1_2.parameters(), 'lr': opt.lr},
{'params': net.classifier.parameters(), 'lr': opt.lr}
]
optimizer = optim.SGD(optim_params, weight_decay=5e-4, momentum=0.9, nesterov=True)
scheduler = lr_scheduler.StepLR(optimizer, step_size=80, gamma=0.1)

First build a list of parameter groups and their corresponding learning rates, then pass it to one of the optimizers in optim; this example uses SGD.
A learning rate scheduler, lr_scheduler, is also set up to adjust the learning rate as training reaches different epochs. Here step_size=80 means that every 80 epochs the learning rate is multiplied by the gamma value.
There is also lr_scheduler.MultiStepLR(optimizer, milestones=[30,50,60], gamma=0.1, last_epoch=-1). The difference from StepLR is that MultiStepLR adjusts the learning rate at the given milestones, in this example at epochs 30, 50, and 60.
That covers the basic configuration of the learning rate; adjusting it during training is covered next.
How to adjust the learning rate
In PyTorch versions before 1.1.0, the learning rate adjustment was placed before the optimizer update; from version 1.1.0 on, it goes after the optimizer update.
for epoch in range(opt.num_epochs):
    # Part of the code is omitted
    loss.backward()
    optimizer.step()
    scheduler.step()
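To verify that the schedule is doing what you expect, the current learning rate can be read back from the optimizer. A minimal sketch:

for epoch in range(opt.num_epochs):
    ...  # forward/backward omitted
    optimizer.step()
    scheduler.step()
    # Each param group keeps its own lr; index 0 is the backbone group here.
    logger.info("epoch: {}, lr: {}".format(epoch, optimizer.param_groups[0]['lr']))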
Loss function settings

Common loss functions are simple; the following few lines of code are enough.
import torch.nn as nn
cls_criterion = nn.CrossEntropyLoss()
dist_criterion = nn.MSELoss() # Use L2 loss function
hinge_criterion = nn.HingeEmbeddingLoss()
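As a usage sketch with made-up shapes: nn.CrossEntropyLoss takes raw logits and integer class labels.

import torch
logits = torch.randn(8, 10)           # batch of 8, 10 classes (raw logits, no softmax)
targets = torch.randint(0, 10, (8,))  # integer class labels in [0, 10)
loss = cls_criterion(logits, targets)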
But some loss functions are custom-designed, so you have to implement them yourself. Take TripletLoss as an example.
import torch
import torch.nn as nn
import torch.nn.functional as F
class TripletLoss(nn.Module):
    def __init__(self, t1, t2, beta):
        super(TripletLoss, self).__init__()
        self.t1 = t1
        self.t2 = t2
        self.beta = beta

    def forward(self, anchor, positive, negative):
        # Squared distances between anchor-positive and anchor-negative pairs
        matched = torch.pow(F.pairwise_distance(anchor, positive), 2)
        mismatched = torch.pow(F.pairwise_distance(anchor, negative), 2)
        part_1 = torch.clamp(matched - mismatched, min=self.t1)
        part_2 = torch.clamp(matched, min=self.t2)
        dist_hinge = part_1 + self.beta * part_2
        loss = torch.mean(dist_hinge)
        return loss

A brief word on usage: just like defining a network, inherit from nn.Module and implement the __init__ and forward functions. Unlike a network, though, a loss function usually has no trainable parameters, so functions from torch.nn.functional are generally used directly.
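A usage sketch with made-up hyperparameters and random embeddings:

criterion = TripletLoss(t1=0.0, t2=0.0, beta=0.01)  # hypothetical hyperparameter values
anchor = torch.randn(8, 128, requires_grad=True)    # batch of 8 embeddings, dim 128
positive = torch.randn(8, 128)
negative = torch.randn(8, 128)
loss = criterion(anchor, positive, negative)        # scalar tensor
loss.backward()                                     # gradients flow back to anchor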
This is also the difference between functions in torch.nn.functional and the modules in nn: with the former you must create and manage the weights yourself, and they are not updated during training unless you explicitly make them trainable; with the latter the module holds its own weights, which are updated automatically as training proceeds.
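A minimal sketch of that difference, using a linear layer:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(2, 10)

# nn.Linear owns its weight and bias; they are registered as parameters,
# picked up by the optimizer, and updated during training.
layer = nn.Linear(10, 5)
out1 = layer(x)

# F.linear is stateless: you create and manage the weight yourself,
# and it only trains if you hand it to an optimizer explicitly.
weight = torch.randn(5, 10, requires_grad=True)
out2 = F.linear(x, weight)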
Due to limited space, that is all for this article. The next article continues with other configurations for writing the training process, such as TensorBoard and building the training loop.