当前位置:网站首页>Pytorch模型训练实用教程学习笔记:四、优化器与学习率调整
Pytorch模型训练实用教程学习笔记:四、优化器与学习率调整
2022-08-01 19:16:00 【zstar-_】
前言
最近在重温Pytorch基础,然而Pytorch官方文档的各种API是根据字母排列的,并不适合学习阅读。
于是在gayhub上找到了这样一份教程《Pytorch模型训练实用教程》,写得不错,特此根据它来再学习一下Pytorch。
仓库地址:https://github.com/TingsongYu/PyTorch_Tutorial
优化器概念
Pytorch提供了优化器Optimizer,作为基类,它有下面几种方法
param_groups
调用param_groups,可以查看一个优化器的参数组,其包含了每一层的权值,偏置,学习率等参数。
调用实例:
# coding: utf-8
import torch
import torch.optim as optim
w1 = torch.randn(2, 2)
w1.requires_grad = True
w2 = torch.randn(2, 2)
w2.requires_grad = True
w3 = torch.randn(2, 2)
w3.requires_grad = True
# 一个参数组
optimizer_1 = optim.SGD([w1, w3], lr=0.1)
print('len(optimizer.param_groups): ', len(optimizer_1.param_groups))
print(optimizer_1.param_groups, '\n')
# 两个参数组
optimizer_2 = optim.SGD([{
'params': w1, 'lr': 0.1},
{
'params': w2, 'lr': 0.001}])
print('len(optimizer.param_groups): ', len(optimizer_2.param_groups))
print(optimizer_2.param_groups)
zero_grad()
功能:将梯度清零。
调用示例:
# coding: utf-8
import torch
import torch.optim as optim
# ----------------------------------- zero_grad
w1 = torch.randn(2, 2)
w1.requires_grad = True
w2 = torch.randn(2, 2)
w2.requires_grad = True
optimizer = optim.SGD([w1, w2], lr=0.001, momentum=0.9)
optimizer.param_groups[0]['params'][0].grad = torch.randn(2, 2)
print('参数w1的梯度:')
print(optimizer.param_groups[0]['params'][0].grad, '\n') # 参数组,第一个参数(w1)的梯度
optimizer.zero_grad()
print('执行zero_grad()之后,参数w1的梯度:')
print(optimizer.param_groups[0]['params'][0].grad) # 参数组,第一个参数(w1)的梯度
state_dict()
功能:获取模型当前的参数,以一个有序字典形式返回。
调用示例:
# coding: utf-8
import torch.nn as nn
import torch.nn.functional as F
# ----------------------------------- state_dict
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 1, 3)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(1 * 3 * 3, 2)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = x.view(-1, 1 * 3 * 3)
x = F.relu(self.fc1(x))
return x
net = Net()
# 获取网络当前参数
net_state_dict = net.state_dict()
print('net_state_dict类型:', type(net_state_dict))
print('net_state_dict管理的参数: ', net_state_dict.keys())
for key, value in net_state_dict.items():
print('参数名: ', key, '\t大小: ', value.shape)
load_state_dict(state_dict)
功能:将 state_dict 中的参数加载到当前网络。
调用示例:
# coding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
# ----------------------------------- load_state_dict
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 1, 3)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(1 * 3 * 3, 2)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = x.view(-1, 1 * 3 * 3)
x = F.relu(self.fc1(x))
return x
def zero_param(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch.nn.init.constant_(m.weight.data, 0)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
torch.nn.init.constant_(m.weight.data, 0)
m.bias.data.zero_()
net = Net()
# 保存,并加载模型参数(仅保存模型参数)
torch.save(net.state_dict(), 'net_params.pkl') # 假设训练好了一个模型net
pretrained_dict = torch.load('net_params.pkl')
# 将net的参数全部置0,方便对比
net.zero_param()
net_state_dict = net.state_dict()
print('conv1层的权值为:\n', net_state_dict['conv1.weight'], '\n')
# 通过load_state_dict 加载参数
net.load_state_dict(pretrained_dict)
print('加载之后,conv1层的权值变为:\n', net_state_dict['conv1.weight'])
add_param_group()
功能:给 optimizer 管理的参数组中增加一组参数,可为该组参数定制 lr, momentum, weight_decay等。
调用示例:
# coding: utf-8
import torch
import torch.optim as optim
# ----------------------------------- add_param_group
w1 = torch.randn(2, 2)
w1.requires_grad = True
w2 = torch.randn(2, 2)
w2.requires_grad = True
w3 = torch.randn(2, 2)
w3.requires_grad = True
# 一个参数组
optimizer_1 = optim.SGD([w1, w2], lr=0.1)
print('当前参数组个数: ', len(optimizer_1.param_groups))
print(optimizer_1.param_groups, '\n')
# 增加一个参数组
print('增加一组参数 w3\n')
optimizer_1.add_param_group({
'params': w3, 'lr': 0.001, 'momentum': 0.8})
print('当前参数组个数: ', len(optimizer_1.param_groups))
print(optimizer_1.param_groups, '\n')
step()
功能:执行一步权值更新。
优化器汇总
torch.optim.SGD
torch.optim.SGD(params, lr=, momentum=0, dampening=0, weight_decay=0, nesterov=False)
功能:
可实现 SGD 优化算法,带动量 SGD 优化算法,带 NAG(Nesterov accelerated gradient)动量 SGD 优化算法。
参数:
params(iterable)- 参数组(参数组的概念请查看 3.2 优化器基类:Optimizer),优化器
要管理的那部分参数。
lr(float)- 初始学习率,可按需随着训练过程不断调整学习率。
momentum(float)- 动量,通常设置为 0.9,0.8
dampening(float)- dampening for momentum ,暂时不了其功能,在源码中是这样用的:buf.mul_(momentum).add_(1 - dampening, d_p),值得注意的是,若采用nesterov,dampening 必须为 0.
weight_decay(float)- 权值衰减系数,也就是 L2 正则项的系数
nesterov(bool)- bool 选项,是否使用 NAG(Nesterov accelerated gradient)
torch.optim.ASGD
torch.optim.ASGD(params, lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0, weight_decay=0)
功能:
ASGD 也成为 SAG,均表示随机平均梯度下降
参数:
params(iterable)- 参数组(参数组的概念请查看 3.1 优化器基类:Optimizer),优化器要优化的那些参数。
lr(float)- 初始学习率,可按需随着训练过程不断调整学习率。
lambd(float)- 衰减项,默认值 1e-4。
alpha(float)- power for eta update ,默认值 0.75。
t0(float)- point at which to start averaging,默认值 1e6。
weight_decay(float)- 权值衰减系数,也就是 L2 正则项的系数。
torch.optim.Rprop
torch.optim.Rprop(params, lr=0.01, etas=(0.5, 1.2), step_sizes=(1e-06, 50))
功能:
实现 Rprop 优化方法(弹性反向传播),该优化方法适用于 full-batch,不适用于 mini-batch。
torch.optim.Adagrad
torch.optim.Adagrad(params, lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0)
功能:
实现 Adagrad 优化方法(Adaptive Gradient),Adagrad 是一种自适应优化方法,是自适应的为各个参数分配不同的学习率。这个学习率的变化,会受到梯度的大小和迭代次数的影响。梯度越大,学习率越小;梯度越小,学习率越大。缺点是训练后期,学习率过小,因为 Adagrad 累加之前所有的梯度平方作为分母。
torch.optim.Adadelta
torch.optim.Adadelta(params, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0)
功能:
实现 Adadelta 优化方法。Adadelta 是 Adagrad 的改进。Adadelta 分母中采用距离当前时间点比较近的累计项,这可以避免在训练后期,学习率过小。
torch.optim.RMSprop
torch.optim.RMSprop(params, lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False)
功能:
实现 RMSprop 优化方法,RMS 是均方根(root meam square)的意思。RMSprop 采用均方根作为分
母,可缓解 Adagrad 学习率下降较快的问题。
torch.optim.Adam(AMSGrad)
torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
功能:
实现 Adam(Adaptive Moment Estimation)优化方法。Adam 是一种自适应学习率的优化方法,Adam 利用梯度的一阶矩估计和二阶矩估计动态的调整学习率。
torch.optim.Adamax
torch.optim.Adamax(params, lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
功能:
实现 Adamax 优化方法。Adamax 是对 Adam 增加了一个学习率上限的概念,所以也称之为 Adamax。
torch.optim.SparseAdam
torch.optim.SparseAdam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08)
功能:
针对稀疏张量的一种Adam优化方法。
torch.optim.LBFGS
torch.optim.LBFGS(params, lr=1, max_iter=20, max_eval=None, tolerance_grad=1e-05, tolerance_change=1e-09, history_size=100, line_search_fn=None)
功能:
实现 L-BFGS(Limited-memory Broyden–Fletcher–Goldfarb–Shanno)优化方法。L-BFGS 属于拟牛顿算法。L-BFGS 是对 BFGS 的改进,特点就是节省内存。
学习率调整
为了让学习率能够随着模型的训练进行动态调整,Pytorch提供了下列一些学习率调整方法。
lr_scheduler.StepLR
torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)
功能:
等间隔调整学习率,调整倍数为 gamma 倍,调整间隔为 step_size。间隔单位是step。
参数:
step_size(int)- 学习率下降间隔数,若为 30,则会在 30、60、90…个 step 时,将学习率调整为 lr*gamma。
gamma(float)- 学习率调整倍数,默认为0.1倍,即下降10倍。
last_epoch(int)- 上一个 epoch 数,这个变量用来指示学习率是否需要调整。当last_epoch 符合设定的间隔时,就会对学习率进行调整。当为-1 时,学习率设置为初始值。
lr_scheduler.MultiStepLR
torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1)
功能:
按设定的间隔调整学习率。这个方法适合后期调试使用,观察 loss 曲线,为每个实验定制学习率调整时机。
参数:
milestones(list)- 一个 list,每一个元素代表何时调整学习率,list 元素必须是递增的。如 milestones=[30,80,120]
gamma(float)- 学习率调整倍数,默认为 0.1 倍,即下降 10 倍。
last_epoch(int)- 上一个 epoch 数,这个变量用来指示学习率是否需要调整。当last_epoch 符合设定的间隔时,就会对学习率进行调整。当为-1 时,学习率设置为初始值。
lr_scheduler.ExponentialLR
torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1)
功能:
按指数衰减调整学习率,调整公式: lr = lr * gammaepoch
参数:
gamma- 学习率调整倍数的底,指数为 epoch,即 gammaepoch
last_epoch(int)- 上一个 epoch 数,这个变量用来指示学习率是否需要调整。当last_epoch 符合设定的间隔时,就会对学习率进行调整。当为-1 时,学习率设置为初始值。
lr_scheduler.CosineAnnealingLR
torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=0, last_epoch=-1)
功能:
以余弦函数为周期,并在每个周期最大值时重新设置学习率。
参数:
T_max(int)- 一次学习率周期的迭代次数,即 T_max 个 epoch 之后重新设置学习率。
eta_min(float)- 最小学习率,即在一个周期中,学习率最小会下降到 eta_min,默认值为 0。
lr_scheduler.ReduceLROnPlateau
torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=‘min’, factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode=‘rel’, cooldown=0, min_lr=0, eps=1e-08)
功能:
当某指标不再变化(下降或升高),调整学习率,这是非常实用的学习率调整策略。
参数:
mode(str)- 模式选择,有 min 和 max 两种模式,min 表示当指标不再降低(如监测loss),max 表示当指标不再升高(如监测 accuracy)。
factor(float)- 学习率调整倍数(等同于其它方法的 gamma),即学习率更新为 lr = lr * factor
patience(int)- 直译——“耐心”,即忍受该指标多少个 step 不变化,当忍无可忍时,调整学习率。
verbose(bool)- 是否打印学习率信息
threshold_mode(str)- 选择判断指标是否达最优的模式,有两种模式,rel 和 abs
cooldown(int)- “冷却时间“,当调整学习率之后,让学习率调整策略冷静一下,让模型再训练一段时间,再重启监测模式。
min_lr(float or list)- 学习率下限,可为 float,或者 list,当有多个参数组时,可用 list 进行设置。
eps(float)- 学习率衰减的最小值,当学习率变化小于 eps 时,则不调整学习率。
lr_scheduler.LambdaLR
torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)
功能:
为不同参数组设定不同学习率调整策略。
参数:
lr_lambda(function or list)- 一个计算学习率调整倍数的函数,输入通常为 step,当有多个参数组时,设为 list。
last_epoch(int)- 上一个 epoch 数,这个变量用来指示学习率是否需要调整。当last_epoch 符合设定的间隔时,就会对学习率进行调整。当为-1 时,学习率设置为初始值。
边栏推荐
- Win11如何开启剪贴板自动复制?Win11开启剪贴板自动复制的方法
- Redis启动时提示Creating Server TCP listening socket *:6379: bind: No error
- The XML configuration
- 首篇 NLP 领域图神经网络综述:127 页,从图构建到实际应用面面观
- 突破边界,华为存储的破壁之旅
- Heavy cover special | build the first line of defense, cloud firewall offensive and defensive drills best practices
- Screenshot of Selenium in Remote
- 【神经网络】一文带你轻松解析神经网络(附实例恶搞女友)
- Ha ha!A print function, quite good at playing!
- Heavy cover special | intercept 99% malicious traffic, reveal WAF offensive and defensive drills best practices
猜你喜欢

cf:D. Magical Array【数学直觉 + 前缀和的和】

kubernetes - deploy nfs storage class

Try compiling QT test on Allwinner V853 development board

Combining two ordered arrays

首篇 NLP 领域图神经网络综述:127 页,从图构建到实际应用面面观

Win11校园网无法连接怎么办?Win11连接不到校园网的解决方法

The life cycle and scope

Hardware Bear Original Collection (Updated 2022/07)

使用常见问题解答软件的好处有哪些?

LeetCode 0152. Product Maximum Subarray: dp + Roll in Place
随机推荐
[Server data recovery] Data recovery case of offline multiple disks in mdisk group of server Raid5 array
app直播源码,点击搜索栏自动弹出下拉框
金鱼哥RHCA回忆录:CL210管理OPENSTACK网络--网络配置选项
面试必问的HashCode技术内幕
Win11怎么安装语音包?Win11语音包安装教程
在GBase 8c数据库后台,使用什么样的命令来对gtm、dn节点进行主备切换的操作?
The XML configuration
在Map传值与对象传值中模糊查询
Screenshot of Selenium in Remote
C#/VB.NET 从PDF中提取表格
#yyds干货盘点# 面试必刷TOP101: 链表中倒数最后k个结点
mysql函数的作用有哪些
In the background of the GBase 8c database, what command is used to perform the master-slave switchover operation for the gtm and dn nodes?
[Neural Network] This article will take you to easily analyze the neural network (with an example of spoofing your girlfriend)
SQL的 ISNULL 函数
Library website construction source code sharing
mysql解压版简洁式本地配置方式
即时通讯开发移动端弱网络优化方法总结
Tencent Cloud Hosting Security x Lightweight Application Server | Powerful Joint Hosting Security Pratt & Whitney Version Released
Redis启动时提示Creating Server TCP listening socket *:6379: bind: No error