当前位置:网站首页>Pytorch模型训练实用教程学习笔记:四、优化器与学习率调整
Pytorch模型训练实用教程学习笔记:四、优化器与学习率调整
2022-08-01 19:16:00 【zstar-_】
前言
最近在重温Pytorch基础,然而Pytorch官方文档的各种API是根据字母排列的,并不适合学习阅读。
于是在gayhub上找到了这样一份教程《Pytorch模型训练实用教程》,写得不错,特此根据它来再学习一下Pytorch。
仓库地址:https://github.com/TingsongYu/PyTorch_Tutorial
优化器概念
Pytorch提供了优化器Optimizer,作为基类,它有下面几种方法
param_groups
调用param_groups,可以查看一个优化器的参数组,其包含了每一层的权值,偏置,学习率等参数。
调用实例:
# coding: utf-8
import torch
import torch.optim as optim
w1 = torch.randn(2, 2)
w1.requires_grad = True
w2 = torch.randn(2, 2)
w2.requires_grad = True
w3 = torch.randn(2, 2)
w3.requires_grad = True
# 一个参数组
optimizer_1 = optim.SGD([w1, w3], lr=0.1)
print('len(optimizer.param_groups): ', len(optimizer_1.param_groups))
print(optimizer_1.param_groups, '\n')
# 两个参数组
optimizer_2 = optim.SGD([{
'params': w1, 'lr': 0.1},
{
'params': w2, 'lr': 0.001}])
print('len(optimizer.param_groups): ', len(optimizer_2.param_groups))
print(optimizer_2.param_groups)
zero_grad()
功能:将梯度清零。
调用示例:
# coding: utf-8
import torch
import torch.optim as optim
# ----------------------------------- zero_grad
w1 = torch.randn(2, 2)
w1.requires_grad = True
w2 = torch.randn(2, 2)
w2.requires_grad = True
optimizer = optim.SGD([w1, w2], lr=0.001, momentum=0.9)
optimizer.param_groups[0]['params'][0].grad = torch.randn(2, 2)
print('参数w1的梯度:')
print(optimizer.param_groups[0]['params'][0].grad, '\n') # 参数组,第一个参数(w1)的梯度
optimizer.zero_grad()
print('执行zero_grad()之后,参数w1的梯度:')
print(optimizer.param_groups[0]['params'][0].grad) # 参数组,第一个参数(w1)的梯度
state_dict()
功能:获取模型当前的参数,以一个有序字典形式返回。
调用示例:
# coding: utf-8
import torch.nn as nn
import torch.nn.functional as F
# ----------------------------------- state_dict
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 1, 3)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(1 * 3 * 3, 2)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = x.view(-1, 1 * 3 * 3)
x = F.relu(self.fc1(x))
return x
net = Net()
# 获取网络当前参数
net_state_dict = net.state_dict()
print('net_state_dict类型:', type(net_state_dict))
print('net_state_dict管理的参数: ', net_state_dict.keys())
for key, value in net_state_dict.items():
print('参数名: ', key, '\t大小: ', value.shape)
load_state_dict(state_dict)
功能:将 state_dict 中的参数加载到当前网络。
调用示例:
# coding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
# ----------------------------------- load_state_dict
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 1, 3)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(1 * 3 * 3, 2)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = x.view(-1, 1 * 3 * 3)
x = F.relu(self.fc1(x))
return x
def zero_param(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch.nn.init.constant_(m.weight.data, 0)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
torch.nn.init.constant_(m.weight.data, 0)
m.bias.data.zero_()
net = Net()
# 保存,并加载模型参数(仅保存模型参数)
torch.save(net.state_dict(), 'net_params.pkl') # 假设训练好了一个模型net
pretrained_dict = torch.load('net_params.pkl')
# 将net的参数全部置0,方便对比
net.zero_param()
net_state_dict = net.state_dict()
print('conv1层的权值为:\n', net_state_dict['conv1.weight'], '\n')
# 通过load_state_dict 加载参数
net.load_state_dict(pretrained_dict)
print('加载之后,conv1层的权值变为:\n', net_state_dict['conv1.weight'])
add_param_group()
功能:给 optimizer 管理的参数组中增加一组参数,可为该组参数定制 lr, momentum, weight_decay等。
调用示例:
# coding: utf-8
import torch
import torch.optim as optim
# ----------------------------------- add_param_group
w1 = torch.randn(2, 2)
w1.requires_grad = True
w2 = torch.randn(2, 2)
w2.requires_grad = True
w3 = torch.randn(2, 2)
w3.requires_grad = True
# 一个参数组
optimizer_1 = optim.SGD([w1, w2], lr=0.1)
print('当前参数组个数: ', len(optimizer_1.param_groups))
print(optimizer_1.param_groups, '\n')
# 增加一个参数组
print('增加一组参数 w3\n')
optimizer_1.add_param_group({
'params': w3, 'lr': 0.001, 'momentum': 0.8})
print('当前参数组个数: ', len(optimizer_1.param_groups))
print(optimizer_1.param_groups, '\n')
step()
功能:执行一步权值更新。
优化器汇总
torch.optim.SGD
torch.optim.SGD(params, lr=, momentum=0, dampening=0, weight_decay=0, nesterov=False)
功能:
可实现 SGD 优化算法,带动量 SGD 优化算法,带 NAG(Nesterov accelerated gradient)动量 SGD 优化算法。
参数:
params(iterable)- 参数组(参数组的概念请查看 3.2 优化器基类:Optimizer),优化器
要管理的那部分参数。
lr(float)- 初始学习率,可按需随着训练过程不断调整学习率。
momentum(float)- 动量,通常设置为 0.9,0.8
dampening(float)- dampening for momentum ,暂时不了其功能,在源码中是这样用的:buf.mul_(momentum).add_(1 - dampening, d_p),值得注意的是,若采用nesterov,dampening 必须为 0.
weight_decay(float)- 权值衰减系数,也就是 L2 正则项的系数
nesterov(bool)- bool 选项,是否使用 NAG(Nesterov accelerated gradient)
torch.optim.ASGD
torch.optim.ASGD(params, lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0, weight_decay=0)
功能:
ASGD 也成为 SAG,均表示随机平均梯度下降
参数:
params(iterable)- 参数组(参数组的概念请查看 3.1 优化器基类:Optimizer),优化器要优化的那些参数。
lr(float)- 初始学习率,可按需随着训练过程不断调整学习率。
lambd(float)- 衰减项,默认值 1e-4。
alpha(float)- power for eta update ,默认值 0.75。
t0(float)- point at which to start averaging,默认值 1e6。
weight_decay(float)- 权值衰减系数,也就是 L2 正则项的系数。
torch.optim.Rprop
torch.optim.Rprop(params, lr=0.01, etas=(0.5, 1.2), step_sizes=(1e-06, 50))
功能:
实现 Rprop 优化方法(弹性反向传播),该优化方法适用于 full-batch,不适用于 mini-batch。
torch.optim.Adagrad
torch.optim.Adagrad(params, lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0)
功能:
实现 Adagrad 优化方法(Adaptive Gradient),Adagrad 是一种自适应优化方法,是自适应的为各个参数分配不同的学习率。这个学习率的变化,会受到梯度的大小和迭代次数的影响。梯度越大,学习率越小;梯度越小,学习率越大。缺点是训练后期,学习率过小,因为 Adagrad 累加之前所有的梯度平方作为分母。
torch.optim.Adadelta
torch.optim.Adadelta(params, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0)
功能:
实现 Adadelta 优化方法。Adadelta 是 Adagrad 的改进。Adadelta 分母中采用距离当前时间点比较近的累计项,这可以避免在训练后期,学习率过小。
torch.optim.RMSprop
torch.optim.RMSprop(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False)
功能:
实现 RMSprop 优化方法,RMS 是均方根(root meam square)的意思。RMSprop 采用均方根作为分
母,可缓解 Adagrad 学习率下降较快的问题。
torch.optim.Adam(AMSGrad)
torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
功能:
实现 Adam(Adaptive Moment Estimation)优化方法。Adam 是一种自适应学习率的优化方法,Adam 利用梯度的一阶矩估计和二阶矩估计动态的调整学习率。
torch.optim.Adamax
torch.optim.Adamax(params, lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
功能:
实现 Adamax 优化方法。Adamax 是对 Adam 增加了一个学习率上限的概念,所以也称之为 Adamax。
torch.optim.SparseAdam
torch.optim.SparseAdam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08)
功能:
针对稀疏张量的一种Adam优化方法。
torch.optim.LBFGS
torch.optim.LBFGS(params, lr=1, max_iter=20, max_eval=None, tolerance_grad=1e-05, tolerance_change=1e-09, history_size=100, line_search_fn=None)
功能:
实现 L-BFGS(Limited-memory Broyden–Fletcher–Goldfarb–Shanno)优化方法。L-BFGS 属于拟牛顿算法。L-BFGS 是对 BFGS 的改进,特点就是节省内存。
学习率调整
为了让学习率能够随着模型的训练进行动态调整,Pytorch提供了下列一些学习率调整方法。
lr_scheduler.StepLR
torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)
功能:
等间隔调整学习率,调整倍数为 gamma 倍,调整间隔为 step_size。间隔单位是step。
参数:
step_size(int)- 学习率下降间隔数,若为 30,则会在 30、60、90…个 step 时,将学习率调整为 lr*gamma。
gamma(float)- 学习率调整倍数,默认为0.1倍,即下降10倍。
last_epoch(int)- 上一个 epoch 数,这个变量用来指示学习率是否需要调整。当last_epoch 符合设定的间隔时,就会对学习率进行调整。当为-1 时,学习率设置为初始值。
lr_scheduler.MultiStepLR
torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1)
功能:
按设定的间隔调整学习率。这个方法适合后期调试使用,观察 loss 曲线,为每个实验定制学习率调整时机。
参数:
milestones(list)- 一个 list,每一个元素代表何时调整学习率,list 元素必须是递增的。如 milestones=[30,80,120]
gamma(float)- 学习率调整倍数,默认为 0.1 倍,即下降 10 倍。
last_epoch(int)- 上一个 epoch 数,这个变量用来指示学习率是否需要调整。当last_epoch 符合设定的间隔时,就会对学习率进行调整。当为-1 时,学习率设置为初始值。
lr_scheduler.ExponentialLR
torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1)
功能:
按指数衰减调整学习率,调整公式: lr = lr * gammaepoch
参数:
gamma- 学习率调整倍数的底,指数为 epoch,即 gammaepoch
last_epoch(int)- 上一个 epoch 数,这个变量用来指示学习率是否需要调整。当last_epoch 符合设定的间隔时,就会对学习率进行调整。当为-1 时,学习率设置为初始值。
lr_scheduler.CosineAnnealingLR
torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=0, last_epoch=-1)
功能:
以余弦函数为周期,并在每个周期最大值时重新设置学习率。
参数:
T_max(int)- 一次学习率周期的迭代次数,即 T_max 个 epoch 之后重新设置学习率。
eta_min(float)- 最小学习率,即在一个周期中,学习率最小会下降到 eta_min,默认值为 0。
lr_scheduler.ReduceLROnPlateau
torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=‘min’, factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode=‘rel’, cooldown=0, min_lr=0, eps=1e-08)
功能:
当某指标不再变化(下降或升高),调整学习率,这是非常实用的学习率调整策略。
参数:
mode(str)- 模式选择,有 min 和 max 两种模式,min 表示当指标不再降低(如监测loss),max 表示当指标不再升高(如监测 accuracy)。
factor(float)- 学习率调整倍数(等同于其它方法的 gamma),即学习率更新为 lr = lr * factor
patience(int)- 直译——“耐心”,即忍受该指标多少个 step 不变化,当忍无可忍时,调整学习率。
verbose(bool)- 是否打印学习率信息
threshold_mode(str)- 选择判断指标是否达最优的模式,有两种模式,rel 和 abs
cooldown(int)- “冷却时间“,当调整学习率之后,让学习率调整策略冷静一下,让模型再训练一段时间,再重启监测模式。
min_lr(float or list)- 学习率下限,可为 float,或者 list,当有多个参数组时,可用 list 进行设置。
eps(float)- 学习率衰减的最小值,当学习率变化小于 eps 时,则不调整学习率。
lr_scheduler.LambdaLR
torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)
功能:
为不同参数组设定不同学习率调整策略。
参数:
lr_lambda(function or list)- 一个计算学习率调整倍数的函数,输入通常为 step,当有多个参数组时,设为 list。
last_epoch(int)- 上一个 epoch 数,这个变量用来指示学习率是否需要调整。当last_epoch 符合设定的间隔时,就会对学习率进行调整。当为-1 时,学习率设置为初始值。
边栏推荐
- When installing the GBase 8c database, the error message "Resource: gbase8c already in use" is displayed. How to deal with this?
- 力扣刷题之合并两个有序数组
- 升哲科技携全域数字化方案亮相2022全球数字经济大会
- Win11如何开启剪贴板自动复制?Win11开启剪贴板自动复制的方法
- 在GBase 8c数据库后台,使用什么样的命令来对gtm、dn节点进行主备切换的操作?
- 基于flowable的upp(统一流程平台)运行性能优化
- 如何看待腾讯云数据库负责人林晓斌借了一个亿炒股?
- 网站建设流程
- DAO开发教程【WEB3.0】
- 如何记录分析你的炼丹流程—可视化神器Wandb使用笔记【1】
猜你喜欢

Shell script topic (07): file from cfs to bos

力扣刷题之合并两个有序数组

开源视界 | StreamNative 盛宇帆:和浪漫的人一起做最浪漫的事

Hardware Bear Original Collection (Updated 2022/07)

Source code analysis of GZIPOutputStream class

kubernetes-部署nfs存储类

BN BatchNorm + BatchNorm的替代新方法KNConvNets

SENSORO成长伙伴计划 x 怀柔黑马科技加速实验室丨以品牌力打造To B企业影响力

#yyds干货盘点# 面试必刷TOP101: 链表中倒数最后k个结点

shell脚本专题(07):文件由cfs到bos
随机推荐
From ordinary advanced to excellent test/development programmer, all the way through
使用常见问题解答软件的好处有哪些?
在表格数据上,为什么基于树的模型仍然优于深度学习?
ThreadLocal讲义
How to query database configuration parameters in GBase 8c, such as datestyle.What function or syntax to use?
Ha ha!A print function, quite good at playing!
mysql解压版简洁式本地配置方式
Break the performance ceiling!AsiaInfo database supports more than 1 billion users, with a peak of one million transactions per second
Mobile Zero of Likou Brush Questions
选择合适的 DevOps 工具,从理解 DevOps 开始
MySQL database - stored procedures and functions
Find the sum of two numbers
A simple Flask PIN
GZIPOutputStream 类源码分析
【木棉花】#夏日挑战赛# 鸿蒙小游戏项目——数独Sudoku(3)
SENSORO成长伙伴计划 x 怀柔黑马科技加速实验室丨以品牌力打造To B企业影响力
工作5年,测试用例都设计不好?来看看大神的用例设计总结
短视频软件开发,Android开发,使用Kotlin实现WebView
硬件大熊原创合集(2022/07更新)
COS User Practice Call for Papers