PyTorch model tuning and training notes
2022-06-28 13:23:00 【Gu_ NN】
Adjusting the learning rate
PyTorch already provides several methods for dynamically adjusting the learning rate in torch.optim.lr_scheduler. They are used as follows:
# Choose an optimizer
optimizer = torch.optim.Adam(...)
# Choose one or more schedulers from torch.optim.lr_scheduler to adjust the learning rate dynamically
scheduler1 = torch.optim.lr_scheduler....
scheduler2 = torch.optim.lr_scheduler....
...
schedulern = torch.optim.lr_scheduler....
# Training
for epoch in range(100):
    train(...)
    validate(...)
    optimizer.step()
    # Step the schedulers only after the optimizer has updated the parameters
    scheduler1.step()
    ...
    schedulern.step()
The learning-rate schedule can also be defined with a custom function.
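One possible way to do this is torch.optim.lr_scheduler.LambdaLR, which accepts a function mapping the epoch index to a multiplicative factor on the initial learning rate. The sketch below is only an illustration: the tiny linear model, the base rate of 0.1 and the halve-every-10-epochs rule in the hypothetical lr_lambda function are assumptions chosen for the example.

```python
import torch
from torch import nn

# Hypothetical toy model and optimizer, used only to have parameters to schedule
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def lr_lambda(epoch):
    # Multiplicative factor applied to the initial learning rate:
    # halve the rate every 10 epochs (an assumed rule for illustration)
    return 0.5 ** (epoch // 10)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

lrs = []
for epoch in range(30):
    # Record the learning rate used during this epoch
    lrs.append(optimizer.param_groups[0]["lr"])
    # The forward/backward pass would go here; without gradients this step changes nothing
    optimizer.step()
    # Step the scheduler after the optimizer, once per epoch
    scheduler.step()
```

With these assumed settings the recorded rate stays at 0.1 for epochs 0 to 9, drops to 0.05 for epochs 10 to 19, and to 0.025 afterwards.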
Fine-tuning
Replace a specified layer and keep the rest of the parameters unchanged
import torch.nn as nn
import torchvision.models as models

# Freeze the parameters of the original pre-trained model
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

# Whether to freeze the gradients of the pre-trained parameters
feature_extract = True
# Load the pre-trained model (torchvision >= 0.13 prefers weights=models.ResNet18_Weights.DEFAULT)
model = models.resnet18(pretrained=True)
set_parameter_requires_grad(model, feature_extract)
# Replace the fully connected output layer; a new layer has requires_grad=True by default
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=4, bias=True)
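After freezing, only the replaced layer still requires gradients, so it is common to hand the optimizer just those parameters. The sketch below shows that on a small stand-in network, so it runs without downloading pre-trained weights. The TinyNet class, its layer sizes and the freeze_all helper are hypothetical names introduced for illustration, not the ResNet-18 setup itself.

```python
import torch
from torch import nn

class TinyNet(nn.Module):
    """Hypothetical stand-in for a pre-trained network with a final fc layer."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.fc = nn.Linear(16, 1000)

    def forward(self, x):
        return self.fc(self.backbone(x))

def freeze_all(model):
    # Same idea as set_parameter_requires_grad with feature_extracting=True
    for param in model.parameters():
        param.requires_grad = False

model = TinyNet()
freeze_all(model)
# A freshly created layer has requires_grad=True, so only it will be trained
model.fc = nn.Linear(model.fc.in_features, 4)

# Names of the parameters that will actually be updated
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=1e-3
)

# One training step on random data: only fc receives gradients
loss = model(torch.randn(2, 8)).sum()
loss.backward()
optimizer.step()
```

Printing trainable is a quick sanity check before a long training run that the intended layers, and only those, are being updated.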
Third-party library timm
timm is a library created by Ross Wightman that extends torchvision and provides many state-of-the-art (SOTA) computer vision models. The list of pre-trained models can be obtained with the following command:
import timm
avail_pretrained_models = timm.list_models(pretrained=True)  # fuzzy (wildcard) queries are also supported
- Modifying the model
import timm
import torch
# Change the 1000-class output to a 10-class output
model = timm.create_model('resnet34', num_classes=10, pretrained=True)
# Change the number of input channels
model = timm.create_model('resnet34', num_classes=10, pretrained=True, in_chans=1)
- Saving and loading the model
torch.save(model.state_dict(),'./checkpoint/timm_model.pth')
model.load_state_dict(torch.load('./checkpoint/timm_model.pth'))
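A slightly fuller round trip might look like the sketch below. It assumes a small stand-in model instead of a timm model (the hypothetical build_model helper) and writes to a temporary directory rather than ./checkpoint. The map_location argument of torch.load is optional; it helps when a checkpoint saved on a GPU is loaded on a machine without one.

```python
import os
import tempfile

import torch
from torch import nn

def build_model():
    # Hypothetical stand-in for the timm model; the architecture must match on save and load
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

model = build_model()

with tempfile.TemporaryDirectory() as tmp_dir:
    checkpoint_dir = os.path.join(tmp_dir, "checkpoint")
    # torch.save does not create missing directories, so create it first
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, "model.pth")

    # Save only the parameters (state_dict), not the whole model object
    torch.save(model.state_dict(), path)

    # Rebuild the same architecture, then load the parameters onto the CPU
    restored = build_model()
    state_dict = torch.load(path, map_location="cpu")
    restored.load_state_dict(state_dict)

# Switch both models to inference mode and compare their outputs on the same input
model.eval()
restored.eval()
x = torch.randn(3, 4)
same_output = torch.equal(model(x), restored(x))
```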
Half precision
- Definition: changing PyTorch's floating-point storage from torch.float32 to torch.float16 is called half precision.
- Purpose: in practice, reduce memory usage while still meeting the precision requirements of the data.
- Applications: data that is itself large, such as 3D images and video.
- Setup
from torch.cuda.amp import autocast

# Decorate the model's forward function with autocast
@autocast()
def forward(self, x):
    ...
    return x

# Training
for x in train_loader:
    x = x.cuda()
    # Run the model and the steps that follow it inside with autocast()
    with autocast():
        output = model(x)
        ...
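The memory saving can be checked directly on a tensor. The sketch below is a CPU-only illustration with an assumed tensor shape (a hypothetical small 3D volume) and a hypothetical size_in_bytes helper. It only compares storage sizes and does not use autocast.

```python
import torch

# Hypothetical 3D volume: batch=1, channel=1, depth=64, height=128, width=128
volume_fp32 = torch.zeros(1, 1, 64, 128, 128, dtype=torch.float32)
# .half() converts the tensor to torch.float16
volume_fp16 = volume_fp32.half()

def size_in_bytes(t):
    # Number of elements times bytes per element
    return t.nelement() * t.element_size()

bytes_fp32 = size_in_bytes(volume_fp32)
bytes_fp16 = size_in_bytes(volume_fp16)
print(f"float32: {bytes_fp32 / 2**20:.1f} MiB, float16: {bytes_fp16 / 2**20:.1f} MiB")
```

The float16 copy occupies half the storage of the float32 original, which is where the saving for large inputs comes from.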
Data augmentation
Image data can be augmented with the imgaug and Albumentations libraries.
Adjustable parameters with argparse
- Passing parameters
import argparse  # Python built-in, no installation required

# Create an ArgumentParser object
parser = argparse.ArgumentParser()
# Add arguments
# action='store_true' records output as True when the flag is given
parser.add_argument('-o', '--output', action='store_true',
                    help="shows output")
# type specifies the argument type; default specifies the default value
parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=3e-5')
# required=True marks a required argument
parser.add_argument('--batch_size', type=int, required=True, help='input batch size')
# Parse the arguments with parse_args()
args = parser.parse_args()

if args.output:
    print("This is some output")
    print(f"learning rate: {args.lr}")
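To try the parser without a command line, parse_args also accepts an explicit list of strings. The sketch below rebuilds the same three arguments inside a hypothetical build_parser helper; the helper name, the script name in the comment and the sample values are assumptions for illustration.

```python
import argparse

def build_parser():
    # Same three arguments as in the parameter-passing example
    parser = argparse.ArgumentParser()
    parser.add_argument('-o', '--output', action='store_true', help='shows output')
    parser.add_argument('--lr', type=float, default=3e-5, help='learning rate')
    parser.add_argument('--batch_size', type=int, required=True, help='input batch size')
    return parser

parser = build_parser()

# Equivalent to running: python script.py -o --batch_size 32
# (script.py is a placeholder name)
args = parser.parse_args(['-o', '--batch_size', '32'])

# Without -o the flag stays False; --lr overrides the default value
args_quiet = parser.parse_args(['--batch_size', '8', '--lr', '0.001'])
```

Omitting --batch_size makes argparse print a usage message and exit, because the argument is declared with required=True.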
- Hyperparameter configuration file config.py
import argparse

def get_options(parser=None):
    # Create the parser inside the function so that repeated calls do not re-register arguments
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument('--workers', type=int, default=0,
                        help='number of data loading workers, '
                             'a common choice is 4 times the number of GPUs')
    parser.add_argument('--batch_size', type=int, default=4, help='input batch size, default=4')
    parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for, default=10')
    parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=3e-5')
    parser.add_argument('--seed', type=int, default=118, help="random seed")
    parser.add_argument('--cuda', action='store_true', default=True, help='enables cuda')
    parser.add_argument('--checkpoint_path', type=str, default='',
                        help='Path to load a previous trained model if not empty (default empty)')
    parser.add_argument('--output', action='store_true', default=True, help="shows output")
    opt = parser.parse_args()

    if opt.output:
        print(f'num_workers: {opt.workers}')
        print(f'batch_size: {opt.batch_size}')
        print(f'epochs (niters) : {opt.niter}')
        print(f'learning rate : {opt.lr}')
        print(f'manual_seed: {opt.seed}')
        print(f'cuda enable: {opt.cuda}')
        print(f'checkpoint_path: {opt.checkpoint_path}')
    return opt

if __name__ == '__main__':
    opt = get_options()
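One point worth noting about this configuration: a flag declared with action='store_true' and default=True is True whether or not it is passed, so --cuda and --output cannot be switched off from the command line. If an on/off switch is wanted, one option (assuming Python 3.9 or later) is argparse.BooleanOptionalAction, which also generates a --no-... form. The sketch below is an illustrative alternative, not the original author's configuration.

```python
import argparse

# Behaviour of the store_true + default=True pattern used in config.py
always_on = argparse.ArgumentParser()
always_on.add_argument('--cuda', action='store_true', default=True, help='enables cuda')
opt_default = always_on.parse_args([])
opt_flag = always_on.parse_args(['--cuda'])

# Alternative sketch: BooleanOptionalAction (Python 3.9+) accepts --cuda and --no-cuda
switchable = argparse.ArgumentParser()
switchable.add_argument('--cuda', action=argparse.BooleanOptionalAction, default=True,
                        help='enables cuda')
opt_on = switchable.parse_args([])
opt_off = switchable.parse_args(['--no-cuda'])
```

Here opt_default.cuda and opt_flag.cuda are both True, while opt_off.cuda is False, so only the second parser lets the user disable the option.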
- Using the configuration in the training script
import random

import numpy as np
import torch

import config

opt = config.get_options()

manual_seed = opt.seed
num_workers = opt.workers
batch_size = opt.batch_size
lr = opt.lr
niters = opt.niter  # the argument is defined as --niter in config.py
checkpoint_path = opt.checkpoint_path

# Set the random seeds so that results are reproducible
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
...
if __name__ == '__main__':
    set_seed(manual_seed)
    for epoch in range(niters):
        train(model, lr, batch_size, num_workers, checkpoint_path)
        val(model, lr, batch_size, num_workers, checkpoint_path)
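A quick way to check that seeding works is to draw random numbers twice with the same seed and compare them. The sketch below repeats set_seed so that it is self-contained; the hypothetical draw helper and the seed value 118 (the default from config.py) are used only for illustration.

```python
import random

import numpy as np
import torch

def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when no GPU is present
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def draw():
    # One sample from each random number generator that set_seed controls
    return random.random(), np.random.rand(), torch.rand(3)

set_seed(118)
first = draw()
set_seed(118)
second = draw()

# True when all three generators produced identical samples after re-seeding
reproducible = (
    first[0] == second[0]
    and first[1] == second[1]
    and torch.equal(first[2], second[2])
)
```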
References
Datawhale, "PyTorch Explained in Simple Terms" (深入浅出PyTorch)