PyTorch model tuning and training techniques
2022-06-28 13:23:00 【Gu_ NN】
Adjusting the learning rate
PyTorch already provides several methods for dynamically adjusting the learning rate in torch.optim.lr_scheduler. They are used as follows:
# Choose an optimizer
optimizer = torch.optim.Adam(...)
# Choose one or more schedulers from torch.optim.lr_scheduler to adjust the learning rate dynamically
scheduler1 = torch.optim.lr_scheduler....
scheduler2 = torch.optim.lr_scheduler....
...
schedulern = torch.optim.lr_scheduler....
# Training
for epoch in range(100):
    train(...)
    validate(...)
    optimizer.step()
    # Step the schedulers only after the optimizer has updated the parameters
    scheduler1.step()
    ...
    schedulern.step()
You can also define how the learning rate changes with a custom function.
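As a minimal sketch of the custom-function approach, the example below uses torch.optim.lr_scheduler.LambdaLR. The decay rule (multiply the base learning rate by 0.1 every 10 epochs), the toy model, and the name lr_lambda are assumptions chosen for illustration and are not taken from the original post.

```python
import torch

# Hypothetical custom rule: multiply the base learning rate by 0.1 every 10 epochs
def lr_lambda(epoch):
    return 0.1 ** (epoch // 10)

model = torch.nn.Linear(4, 2)  # toy model, for illustration only
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

lrs = []
for epoch in range(30):
    # Record the learning rate used during this epoch
    lrs.append(optimizer.param_groups[0]["lr"])
    # A real loop would run train(...) and validate(...) here
    optimizer.step()
    scheduler.step()
```

LambdaLR multiplies the optimizer's initial learning rate by the value the function returns for the current epoch, so the function only has to describe the relative schedule.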
Fine-tuning
Replace a specified layer and keep the remaining parameters unchanged.
import torch.nn as nn
import torchvision.models as models

# Freeze the parameters of the pretrained model
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

# Whether to freeze the pretrained parameters (no gradients are computed for them)
feature_extract = True
# Load the pretrained model (newer torchvision releases use the weights= argument instead of pretrained=True)
model = models.resnet18(pretrained=True)
set_parameter_requires_grad(model, feature_extract)
# Replace the fully connected output layer; the new layer's parameters have requires_grad=True by default
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=4, bias=True)
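The effect of freezing can be checked by listing which parameters still require gradients. The sketch below uses a small nn.Sequential as a hypothetical stand-in for the pretrained network, so that nothing has to be downloaded; the layer sizes are arbitrary assumptions.

```python
import torch
import torch.nn as nn

# Toy stand-in for a pretrained network (hypothetical, no weights are downloaded)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 10))

# Freeze everything, then replace the last layer with a new 4-class head
for param in model.parameters():
    param.requires_grad = False
model[2] = nn.Linear(16, 4)

# Only the parameters of the new head remain trainable
trainable = [name for name, p in model.named_parameters() if p.requires_grad]

# Pass only the trainable parameters to the optimizer
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=1e-3
)
print(trainable)
```

Handing only the trainable parameters to the optimizer is one common choice; it keeps the optimizer state small and makes it explicit which layers are being updated.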
The third-party library timm
timm is an extension library to torchvision created by Ross Wightman. It provides many state-of-the-art (SOTA) computer vision models. The list of models with pretrained weights can be obtained as follows:
import timm
# Wildcard filters are also supported, e.g. timm.list_models('*resnet*', pretrained=True)
avail_pretrained_models = timm.list_models(pretrained=True)
- Adjusting the model
import timm
import torch
# Change the 1000-class output to 10 classes
model = timm.create_model('resnet34', num_classes=10, pretrained=True)
# Change the number of input channels
model = timm.create_model('resnet34', num_classes=10, pretrained=True, in_chans=1)
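- Saving and loading the model
import os
os.makedirs('./checkpoint', exist_ok=True)  # torch.save does not create missing directories
torch.save(model.state_dict(), './checkpoint/timm_model.pth')
model.load_state_dict(torch.load('./checkpoint/timm_model.pth'))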
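The save and load calls above can be tried end to end without timm. The following sketch assumes a toy nn.Linear model and a temporary file named toy_model.pth, both hypothetical, and checks that the loaded parameters match the saved ones.

```python
import os
import tempfile

import torch
import torch.nn as nn

source = nn.Linear(3, 2)  # toy model standing in for the timm model
target = nn.Linear(3, 2)  # freshly initialized model with the same architecture

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy_model.pth")  # hypothetical file name
    torch.save(source.state_dict(), path)
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine
    state_dict = torch.load(path, map_location="cpu")

target.load_state_dict(state_dict)
same = all(
    torch.equal(a, b) for a, b in zip(source.parameters(), target.parameters())
)
```

Saving the state_dict rather than the whole model object means the model class must be constructed first when loading, as target is here.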
Half precision
- Definition: changing PyTorch's floating-point storage from torch.float32 to torch.float16 is called half precision.
- Purpose: to reduce memory usage in practice while still meeting the precision the data requires.
- Use cases: data that is large in itself, such as 3D images and video.
- Setup
from torch.cuda.amp import autocast

# Decorate the model's forward method with autocast (this goes inside your nn.Module subclass)
@autocast()
def forward(self, x):
    ...
    return x

# Training
for x in train_loader:
    x = x.cuda()
    # Run the model and the steps that follow it inside the autocast context
    with autocast():
        output = model(x)
        ...
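To make the memory-saving purpose described above concrete, this small sketch compares the storage size of the same tensor in torch.float32 and torch.float16. The tensor shape is an arbitrary assumption, and the example runs on CPU.

```python
import torch

x32 = torch.zeros(1024, 1024, dtype=torch.float32)
x16 = x32.to(torch.float16)  # same shape, half-precision storage

# element_size() is the number of bytes per element: 4 for float32, 2 for float16
bytes32 = x32.element_size() * x32.numel()
bytes16 = x16.element_size() * x16.numel()
print(bytes32, bytes16)
```

Each element takes half as many bytes in float16, which is where the savings for large inputs such as 3D images or video come from.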
Data augmentation
Image data can be augmented with the imgaug and Albumentations libraries.
Adjustable parameters
- Passing parameters on the command line
import argparse  # part of the Python standard library, no installation required

# Create an ArgumentParser object
parser = argparse.ArgumentParser()
# Add arguments
parser.add_argument('-o', '--output', action='store_true',
                    help="shows output")
# action='store_true' sets output to True when the flag is given (False otherwise)
parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=3e-5')
# type specifies the type the argument is converted to
# default specifies the default value
parser.add_argument('--batch_size', type=int, required=True, help='input batch size')
# required=True marks the argument as mandatory
# Parse the command-line arguments with parse_args()
args = parser.parse_args()
if args.output:
    print("This is some output")
    print(f"learning rate: {args.lr}")
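The parser above can also be exercised without a terminal by passing an explicit argument list to parse_args(). The sketch below repeats the same three arguments; the script name demo.py in the comment is hypothetical.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-o', '--output', action='store_true', help="shows output")
parser.add_argument('--lr', type=float, default=3e-5, help='learning rate')
parser.add_argument('--batch_size', type=int, required=True, help='input batch size')

# Equivalent to running: python demo.py --batch_size 32 -o
args = parser.parse_args(['--batch_size', '32', '-o'])
print(args.batch_size, args.output, args.lr)
```

Passing a list this way is convenient in notebooks and unit tests, where sys.argv does not hold the training arguments.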
- A hyperparameter configuration file, config.py
import argparse

def get_options(parser=None):
    # Create the parser inside the function: an ArgumentParser() default in the signature
    # is shared across calls, so a second call would fail on duplicate arguments
    if parser is None:
        parser = argparse.ArgumentParser()
    parser.add_argument('--workers', type=int, default=0,
                        help='number of data loading workers, ideally '
                             '4 times the number of GPUs')
    parser.add_argument('--batch_size', type=int, default=4, help='input batch size, default=4')
    parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for, default=10')
    parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=3e-5')
    parser.add_argument('--seed', type=int, default=118, help="random seed")
    # Note: with action='store_true' and default=True, this flag is True whether or not it is passed
    parser.add_argument('--cuda', action='store_true', default=True, help='enables cuda')
    parser.add_argument('--checkpoint_path', type=str, default='',
                        help='Path to load a previous trained model if not empty (default empty)')
    parser.add_argument('--output', action='store_true', default=True, help="shows output")
    opt = parser.parse_args()
    if opt.output:
        print(f'num_workers: {opt.workers}')
        print(f'batch_size: {opt.batch_size}')
        print(f'epochs (niters) : {opt.niter}')
        print(f'learning rate : {opt.lr}')
        print(f'manual_seed: {opt.seed}')
        print(f'cuda enable: {opt.cuda}')
        print(f'checkpoint_path: {opt.checkpoint_path}')
    return opt

if __name__ == '__main__':
    opt = get_options()
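In config.py the --cuda and --output flags combine action='store_true' with default=True, so they cannot be switched off from the command line. One possible alternative, sketched here as an assumption rather than as the original author's approach, is argparse.BooleanOptionalAction, which requires Python 3.9 or later.

```python
import argparse

parser = argparse.ArgumentParser()
# BooleanOptionalAction (Python 3.9+) registers both --cuda and --no-cuda
parser.add_argument('--cuda', action=argparse.BooleanOptionalAction, default=True,
                    help='enables cuda')

default_opt = parser.parse_args([])             # no flag given: default applies
disabled_opt = parser.parse_args(['--no-cuda'])  # explicitly switched off
print(default_opt.cuda, disabled_opt.cuda)
```

With this action the default stays True, while --no-cuda gives users a way to turn the option off.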
- Calling it from the training script
import random

import numpy as np
import torch

import config

opt = config.get_options()
manual_seed = opt.seed
num_workers = opt.workers
batch_size = opt.batch_size
lr = opt.lr
niters = opt.niter  # the argument is defined as --niter in config.py
checkpoint_path = opt.checkpoint_path

# Fix the random seeds so that results are reproducible
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

...

if __name__ == '__main__':
    set_seed(manual_seed)
    for epoch in range(niters):
        train(model, lr, batch_size, num_workers, checkpoint_path)
        val(model, lr, batch_size, num_workers, checkpoint_path)
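The effect of seeding can be verified directly: resetting the seed should reproduce the same random draws. This sketch uses a reduced variant of set_seed that covers only the Python and PyTorch CPU generators, an assumption made to keep the example dependency-light.

```python
import random

import torch

def set_seed(seed):
    # Reduced variant of the set_seed above: Python and PyTorch CPU generators only
    random.seed(seed)
    torch.manual_seed(seed)

set_seed(118)
first = (random.random(), torch.rand(3))

set_seed(118)
second = (random.random(), torch.rand(3))

# Both draws match because the generators were reset to the same state
print(first[0] == second[0], torch.equal(first[1], second[1]))
```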
References
Datawhale, 深入浅出PyTorch