Building a PyTorch Model from Scratch (V): Writing the Training Process -- Some Basic Configurations
2022-06-29 13:08:00 【CV Technical Guide (official account)】
Preface: This article covers how to configure the training log, why and how to set the random seed, how to configure data loading, how to configure and adjust the learning rate, and how to configure loss functions and write a custom loss function.
Welcome to follow the official account CV Technical Guide, which focuses on computer vision technique summaries, tracking of the latest techniques, interpretation of classic papers, and CV recruitment information.
Configuring the training log
The training log saves information produced during training, which makes it easy to review afterwards how the model trained.
The first step is to prepare the basic configuration.
import logging

def train_logger(num):
    logger = logging.getLogger(__name__)
    # Set the logging level. There are 6 levels in total, from low to high:
    # NOTSET, DEBUG, INFO, WARNING, ERROR, CRITICAL.
    # setLevel sets the minimum level; records below this level are not emitted.
    logger.setLevel(level=logging.INFO)
    # Log to a file, with the file name and path set here
    file_log = logging.FileHandler('./run/{}/train.log'.format(num))
    # Log to the terminal
    print_log = logging.StreamHandler()
    # Set the output format:
    # %(asctime)s is the current time, %(message)s is the message to log.
    formatter = logging.Formatter('%(asctime)s %(message)s')
    file_log.setFormatter(formatter)
    print_log.setFormatter(formatter)
    logger.addHandler(file_log)
    logger.addHandler(print_log)
    return logger

Usage:
logger = train_logger(0)
logger.info("project_name: {}".format(project_name))
logger.info("batchsize: {}".format(opt.batchsize))
logger.warning("warning message")

Setting the random seed
Setting a random seed fixes the order in which the dataset is shuffled in every run, which makes the model reproducible. If no seed is set, the dataset is shuffled in a different order each run, and the model's results fluctuate slightly.
Let's briefly look at why the results fluctuate.
Why the results fluctuate
Training today basically uses mini-batch gradient descent: after each forward pass over a batch of data, the weights are updated. At the same time, most models use BN, which normalizes each batch. The composition of each batch therefore has some effect on the model's performance.
If the shuffle order differs from run to run, in some runs the batches may happen to be combined particularly well, so training turns out well, while in other runs the batch combinations are less suitable and the model cannot match that result.
So the root cause is the variability in batch composition produced by shuffling in a different order each run; this variability has some effect on the model's results. The smaller the dataset, the larger the effect, because the samples within a batch differ more; the larger the dataset, the smaller the differences between batches, and the less the shuffle order matters.
Therefore, setting a random seed is necessary: once it is set, the data order generated in every training run is the same.
Random numbers on a computer are pseudo-random, generated by a deterministic algorithm, so fixing the seed fixes the entire sequence they produce.
How to set the random seed
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn

# Basic configuration
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.deterministic = True
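Two further settings are often combined with the function above for stricter reproducibility (an addition here, following common PyTorch practice, and building on the imports above): disabling the cuDNN autotuner, and re-seeding DataLoader worker processes.
cudnn.benchmark = False  # stop cuDNN from benchmarking and picking different kernels per run

def seed_worker(worker_id):
    # derive a per-worker seed so augmentation in worker processes is reproducible;
    # pass worker_init_fn=seed_worker to DataLoader when num_workers > 0
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)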
# Usage: call the line below in the training file. The seed value can be
# anything; as long as it stays the same, the generated order is the same
# in every run.
setup_seed(2022)
Configuring data loading

Configuring data loading is fairly easy; the following few lines of code are all it takes.
from torch.utils.data import DataLoader
from dataloader import MyDataset
train_folder = opt.data_dir + '/train'
train_dataset = MyDataset(data_folder=train_folder,opt=opt)
train_loader = DataLoader(train_dataset, batch_size=opt.batchsize, shuffle=True, num_workers=8)

The main work is setting up the DataLoader. shuffle is normally set to True here; for multi-machine, multi-GPU distributed training, shuffle must be False and a sampler obtained with the code below is passed instead. num_workers is the number of worker processes used for loading data.
from torch.utils.data.distributed import DistributedSampler
train_sampler = DistributedSampler(train_dataset)
train_loader = DataLoader(train_dataset, batch_size=opt.batchsize, sampler=train_sampler)
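One detail worth adding here (not in the snippet above, but part of standard DistributedSampler usage): the sampler shuffles deterministically per epoch, so set_epoch should be called at the start of each epoch, otherwise every epoch sees the same order. A minimal sketch:
for epoch in range(opt.num_epochs):
    # re-seed the sampler so each epoch gets a different shuffle
    train_sampler.set_epoch(epoch)
    for data, target in train_loader:
        ...  # training step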
Network initialization
Network initialization involves two main things: instantiating the network, and loading a pre-trained model. Since loading the pre-trained parameters of specific layers was already covered in part two of this series, it is not repeated here. This part is also fairly simple.
from model import Yolo_v1
net = Yolo_v1()
net.load_state_dict(torch.load(trained_path))
net = net.cuda()
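One practical aside, beyond the snippet above: torch.load assumes the checkpoint's original device is available. If the checkpoint was saved on a GPU but must be loaded on a CPU-only machine, pass map_location. A hedged variant of the loading lines:
# load onto whatever device is actually available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net.load_state_dict(torch.load(trained_path, map_location=device))
net = net.to(device)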
Setting the learning rate

This part mainly covers how to set different learning rates for different layers. For example, when the backbone uses a pre-trained model and the fully connected layer is randomly initialized, the backbone needs a small learning rate while the fully connected layer needs a large one.
import torch.optim as optim
from torch.optim import lr_scheduler
optim_params = [{'params': net.backbone.parameters(), 'lr': 0.1 * opt.lr},
{'params': net.interaction.parameters(), 'lr': opt.lr},
{'params': net.large_conv1_1.parameters(), 'lr': opt.lr},
{'params': net.large_conv1_2.parameters(), 'lr': opt.lr},
{'params': net.classifier.parameters(), 'lr': opt.lr}
]
optimizer = optim.SGD(optim_params, weight_decay=5e-4, momentum=0.9, nesterov=True)
scheduler = lr_scheduler.StepLR(optimizer, step_size=80, gamma=0.1)

First build a list of parameter groups with their corresponding learning rates, then pass it as an argument to one of the optimizers in optim; this example uses SGD.
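As a quick sanity check (a sketch added here, not part of the original code), each entry in optimizer.param_groups keeps its own learning rate:
for group in optimizer.param_groups:
    print(len(group['params']), group['lr'])  # number of tensors and lr per group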
A learning rate scheduler from lr_scheduler is also set up, used to adjust the learning rate as training reaches different epochs. Here step_size means that every 80 epochs, the learning rate is multiplied by the gamma value.
There is also lr_scheduler.MultiStepLR(optimizer, milestones=[30,50,60], gamma=0.1, last_epoch=-1). The difference from StepLR is that MultiStepLR adjusts the learning rate according to milestones; in this example, at epochs 30, 50, and 60.
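To make the milestone behavior concrete, here is a minimal sketch (get_last_lr is available in recent PyTorch releases; the loop stands in for real training):
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[30, 50, 60], gamma=0.1)
for epoch in range(70):
    optimizer.step()                        # the actual training step is omitted
    scheduler.step()
    print(epoch, scheduler.get_last_lr())   # the lr drops by 10x after epochs 30, 50, 60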
That covers the basic learning rate configuration; what follows concerns adjusting the learning rate during training.
How to adjust the learning rate
In PyTorch versions before 1.1.0, the learning rate adjustment (scheduler.step()) had to be placed before the optimizer update; from version 1.1.0 on, it is called after optimizer.step().
for epoch in range(opt.num_epochs):
    # part of the training code omitted
    loss.backward()
    optimizer.step()
    scheduler.step()

Setting the loss function
Common loss functions are simple to use; the following lines of code are all it takes.
import torch.nn as nn
cls_criterion = nn.CrossEntropyLoss()
dist_criterion = nn.MSELoss() # Use L2 loss function
hinge_criterion = nn.HingeEmbeddingLoss()
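As a quick example of calling one of these criteria (a sketch; the shapes are made up for illustration), nn.CrossEntropyLoss expects raw logits and integer class labels, and applies the softmax internally:
import torch

logits = torch.randn(4, 10, requires_grad=True)  # a batch of 4 samples, 10 classes
labels = torch.tensor([1, 0, 9, 3])              # integer class indices, not one-hot
loss = cls_criterion(logits, labels)
loss.backward()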
But some loss functions are of our own design, so we have to implement them ourselves. Take a TripletLoss variant as an example.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TripletLoss(nn.Module):
    def __init__(self, t1, t2, beta):
        super(TripletLoss, self).__init__()
        self.t1 = t1
        self.t2 = t2
        self.beta = beta

    def forward(self, anchor, positive, negative):
        # squared distance between the anchor and the positive sample
        matched = torch.pow(F.pairwise_distance(anchor, positive), 2)
        # squared distance between the anchor and the negative sample
        mismatched = torch.pow(F.pairwise_distance(anchor, negative), 2)
        part_1 = torch.clamp(matched - mismatched, min=self.t1)
        part_2 = torch.clamp(matched, min=self.t2)
        dist_hinge = part_1 + self.beta * part_2
        loss = torch.mean(dist_hinge)
        return loss

To briefly introduce the usage: just as when defining a network, inherit from nn.Module and implement the __init__ and forward functions. Unlike a network, however, a loss function usually has no trainable parameters, so the functions in torch.nn.functional are normally used directly.
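A hedged usage sketch of the class above (the threshold, beta, batch and embedding sizes are arbitrary illustration values):
criterion = TripletLoss(t1=0.0, t2=0.01, beta=0.002)
anchor = torch.randn(8, 128)    # 8 embeddings of dimension 128
positive = torch.randn(8, 128)
negative = torch.randn(8, 128)
loss = criterion(anchor, positive, negative)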
This is also the difference between the functions in torch.nn.functional and the modules in nn: with the former, you provide and manage the weights yourself, and they are not updated automatically during training; with the latter, you do not manage the weights yourself, and they are updated as training proceeds.
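A small sketch of that difference using convolution (the shapes are illustrative):
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 3, 32, 32)

# nn version: the layer owns its weight as a registered Parameter,
# which an optimizer will update during training
conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
y1 = conv(x)

# functional version: we create and manage the weight tensor ourselves;
# nothing updates it unless we arrange for that explicitly
weight = torch.randn(16, 3, 3, 3)
y2 = F.conv2d(x, weight, padding=1)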
Due to limited space, that is all for this part. The next part continues with other configurations involved in writing the training process, such as TensorBoard and building the training loop.