当前位置:网站首页>从零搭建Pytorch模型教程(五)编写训练过程--一些基本的配置
从零搭建Pytorch模型教程(五)编写训练过程--一些基本的配置
2022-06-29 12:21:00 【CV技术指南(公众号)】
前言 本文介绍了训练日志的配置方法,为什么需要设置随机数种子,设置随机数种子的方法,加载数据的配置,学习的配置和调整方法,损失函数的配置和自定义损失函数的写法。
欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。
训练日志的配置
训练日志是用于保存训练过程中的一些信息,方便事后查看模型的训练情况。
首先是准备好基本的配置。
import logging
def train_logger(num):
logger = logging.getLogger(__name__)
#设置打印的级别,一共有6个级别,从低到高分别为:
#NOTEST、DEBUG、INFO、WARNING、ERROR、CRITICAL。
#setLevel设置的是最低打印的级别,低于该级别的将不会打印。
logger.setLevel(level=logging.INFO)
#打印到文件,并设置打印的文件名称和路径
file_log = logging.FileHandler('./run/{}/train.log'.format(num))
#打印到终端
print_log = logging.StreamHandler()
#设置打印格式
#%(asctime)表示当前时间,%(message)表示要打印的信息,用的时候会介绍。
formatter = logging.Formatter('%(asctime)s %(message)s')
file_log.setFormatter(formatter)
print_log.setFormatter(formatter)
logger.addHandler(file_log)
logger.addHandler(print_log)
return logger使用方法
logger = train_logger(0)
logger.info("project_name: {}".format(project_name))
logger.info("batchsize: {}".format(opt.batchsize))
logger.warning("warning message")设置随机数种子
随机数种子是为了固定数据集每次的打乱顺序,让模型变得可复现。不设置随机数种子的话,由于数据集每次打乱的顺序都不一样,导致模型会略有浮动。
咱们这里稍微介绍一下浮动的原因。
模型浮动的原因
目前基本都是使用mini-batch梯度下降,也就是说每次都是前传一个batch的数据后,才会更新权重,与此同时,模型基本都是有使用BN,即对每个batch做归一化。因此,batch数据对模型的性能会有一定的影响。
如果每次随机顺序都不一样,可能会存在某几次的batch组合得非常好,以至于模型训练效果不错,而其它时候的batch的组合不是很合适,以至于达不到组合得很好的时候的效果。
因此,所谓的原因就是每次的不同顺序产生了batch样本的多样性,batch样本多样性对模型的结果有一定的影响。数据集越小,这个影响可能越大,因为造成的batch样本之间的差异性很大,而数据集越大时,batch样本之间的差异性可能受到随机顺序的影响越小。
因此,设置随机数种子时一个必要的事情,设置随机数种子后,每次训练生成顺序都是一样的。
计算机上的随机数都是人工模拟出来的,因此,我们可以任意地设置随机数的范围等。
设置随机数种子的方法
import random
#基本配置
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
cudnn.deterministic = True
#使用方法
#在train文件下调用下面这行命令,参数随意设置,
#只要这个参数数值一样,每次生成的顺序也就一样
setup_seed(2022)加载数据的配置
加载数据的配置比较容易,仅通过以下几行代码即可。
from torch.utils.data import DataLoader
from dataloader import MyDataset
train_folder = opt.data_dir + '/train'
train_dataset = MyDataset(data_folder=train_folder,opt=opt)
train_loader = DataLoader(train_dataset, batch_size=opt.batchsize, shuffle=True, num_workers=8)主要是设置DataLoader,其中shuffler基本默认为True,如果是多级多卡分布式训练,则shuffle为false,而sampler通过下面的代码来获取,num_workers表示加载数据使用的进程数量。
from torch.utils.data.distributed import DistributedSampler
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = DataLoader(train_dataset, batch_size=opt.batchsize, sampler=train_sampler)
网络的初始化
网络初始化主要是两件事,一个是初始化网络,另一个是加载预训练模型,由于前面第二篇中我们已经介绍过了如何加载指定层的预训练参数,因此,这里就不多介绍了。这部分也比较简单。
from model import Yolo_v1
net = Yolo_v1()
net.load_state_dict(torch.load(trained_path))
net = net.cuda()学习率的设置
学习的设置主要是介绍一下如何在不同的层设定不同的学习率。例如backbone使用的是预训练模型,而全连接层是使用随机初始化的,因此backbone需要小学习率,全连接层需要大学习率。
import torch.optim as optim
from torch.optim import lr_scheduler
optim_params = [{'params': net.backbone.parameters(), 'lr': 0.1 * opt.lr},
{'params': net.interaction.parameters(), 'lr': opt.lr},
{'params': net.large_conv1_1.parameters(), 'lr': opt.lr},
{'params': net.large_conv1_2.parameters(), 'lr': opt.lr},
{'params': net.classifier.parameters(), 'lr': opt.lr}
]
optimizer = optim.SGD(optim_params, weight_decay=5e-4, momentum=0.9, nesterov=True)
scheduler = lr_scheduler.StepLR(optimizer, step_size=80, gamma=0.1)首先是构建一个参数和对应学习率的列表,然后作为参数传给optim的这些优化器,例子中用的是SGD。
此外还设置了一个学习率的调度器 lr_scheduler,用于训练到不同的epoch时调整学习率。这里的step_size表示在每80个epoch时,学习率乘以gamma值。
还有一个lr_scheduler.MultiStepLR(optimizer, milestones=[30,50,60], gamma=0.1, last_epoch=-1)。这与StepLR的区别是MultiStepLR是根据milestones中的时间来调整学习率的,这个例子中表示的是在第30、50、60epoch时调整。
上面这里是学习率的基本配置,下面还涉及到学习率在训练过程中的调整。
学习率调整的方式
在PyTorch 1.1.0之前的版本,学习率的调整应该被放在optimizer更新之前。1.1.0版本后用在optimizer更新后。
for epoch in range(opt.num_epochs):
#省略一部分代码
loss.backward()
optimizer.step()
scheduler.step()损失函数的设置
通用的损失函数是简单的,直接通过下面几行代码即可。
import torch.nn as nn
cls_criterion = nn.CrossEntropyLoss()
dist_criterion = nn.MSELoss() # Use L2 loss function
hinge_criterion = nn.HingeEmbeddingLoss()
但有些损失函数是自己设计的,因此需要自己来实现。下面以TripletLoss为例。
import torch
import torch.nn as nn
import torch.nn.functional as F
class TripletLoss(nn.Module):
def __init__(self, t1, t2, beta):
super(TripletLoss, self).__init__()
self.t1 = t1
self.t2 = t2
self.beta = beta
return
def forward(self, anchor, positive, negative):
matched = torch.pow(F.pairwise_distance(anchor, positive), 2)
mismatched = torch.pow(F.pairwise_distance(anchor, negative), 2)
part_1 = torch.clamp(matched - mismatched, min=self.t1)
part_2 = torch.clamp(matched, min=self.t2)
dist_hinge = part_1 + self.beta * part_2
loss = torch.mean(dist_hinge)
return loss简单介绍一下用法,跟定义网络一样,继承nn.Module,然后完成__init__和__forward__函数,但与网络不同的是,损失函数中没有可训练的参数,因此通常直接使用torch.nn.functional中的函数即可。
这也是torch.nn.functional中函数与nn中的函数的区别,前者需要自己设置权重,且不会随训练过程更新,而后者不需要自己设置权重,权重会更新。
由于篇幅有限,这里先介绍这么多,下篇我们继续介绍编写训练过程中的一些其它配置,如Tensorboard,训练过程的搭建等。
欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。
CV技术指南创建了一个免费的知识星球。关注公众号添加编辑的微信号可邀请加入。
其它文章
招聘 | 迁移科技招聘深度学习、视觉、3D视觉、机器人算法工程师等多个职位
Attention Mechanism in Computer Vision
从零搭建Pytorch模型教程(四)编写训练过程--参数解析
从零搭建Pytorch模型教程(三)搭建Transformer网络
边栏推荐
- 服务器监控netdata面板配置邮件服务
- 3D model downloading and animation control
- huffman编码
- Beifu PLC controls servo through CANopen communication
- 2022.6.28-----leetcode. three hundred and twenty-four
- Yunlong fire version aircraft battle (full version)
- 倍福TwinCAT配置、调试第三方伺服详细讲解--以汇川IS620N为例子
- OPC of Beifu twincat3_ UA communication test case
- C#实现顺序表定义、插入、删除、查找操作
- Golang image/png processing image rotation writing
猜你喜欢

倍福TwinCAT3 的OPC_UA通信测试案例

Matlab简单入门

LR、CR纽扣电池对照表

Comparison table of LR and Cr button batteries

QT custom control: value range

推荐模型复现(四):多任务模型ESMM、MMOE

oracle 19c : change the user sys/system username pasword under Linux

Paper reproduction - ac-fpn:attention-guided context feature pyramid network for object detection

Go Senior Engineer required course | I sincerely suggest you listen to it. Don't miss it~

Definition of C # clue binary tree
随机推荐
360数科新能源专项产品规模突破60亿
YOLO系列梳理(九)初尝新鲜出炉的YOLOv6
SCHIEDERWERK电源维修SMPS12/50 PFC3800解析
Testing -- automated testing: about the unittest framework
推荐模型复现(二):精排模型DeepFM、DIN
C#通過中序遍曆對二叉樹進行線索化
C # realize the definition, stack entry and stack exit of stack structure
C#实现二叉树非递归中序遍历程序
趣谈网络协议(二)传输层
Golang image/png processing image rotation writing
huffman编码
Proteus软件初学笔记
Murphy safety was selected for signing 24 key projects of Zhongguancun Science City
ZALSM_ EXCEL_ TO_ INTERNAL_ Solving the big problem of importing data from table
Matlab简单入门
Earth observation satellite data
C # output the middle order traversal through the clue binary tree
安装typescript环境并开启VSCode自动监视编译ts文件为js文件
C # realizes the first order traversal, middle order traversal and second order traversal of binary tree
1. Opencv实现简单颜色识别
