当前位置:网站首页>pytorch优化器设置
pytorch优化器设置
2022-07-28 01:06:00 【Mick..】
深度学习训练过程中学习率的大小十分重要。学习率过低会导致学习太慢,学习率过高会导致难以收敛。通常情况下,初始学习率会比较大,后来逐渐缩小学习率。
通常情况下模型优化器设置
首先定义两层全连接层模型
import torch
from torch import nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.layer1 = nn.Linear(10, 2)
self.layer2 = nn.Linear(2, 10)
def forward(self, input):
return self.layer2(self.layer1(input))神经网络的执行步骤。首先神经网络进过前向传播,这是神经网络框架会搭建好计算图(这里会保存操作和对应参与计算的张量,因为在根据计算图计算梯度时需要这些信息)。然后是误差反向传播,loss.backward() ,这时会计算梯度信息。最后根据梯度信息,更新参数。
loss.backward()
optimizer.step()
optimizer.zero_grad()optimizer.zero_grad() 是将这一轮的梯度清零,防止影响下一轮参数的更新。这里曾问过面试的问题:什么时候不使用这一步进行清零。
model = Net()
# 只传入想要训练层的参数。其他未传入的参数不参与更新
optimizer_Adam = torch.optim.Adam(model.parameters(), lr=0.1)model.parameters()会返回模型的所有参数
只训练模型的部分参数
也就是说只传入模型待优化的参数,为传入的参数不参与更新。
model = Net()
# 只传入待优化的参数
optimizer_Adam = torch.optim.Adam(model.layer1.parameters(), lr=0.1) 不同部分设置不同的学习率
params_dict = [{'params': model.layer1.parameters(), 'lr': 0.01},
{'params': model.layer2.parameters(), 'lr': 0.001}]
optimizer = torch.optim.Adam(params_dict)动态修改学习率
优化器的param_group属性
-param_groups
-0(dict) # 第一组参数
params: # 维护要更新的参数
lr: # 该组参数的学习率
betas:
eps: # 该组参数的学习率最小值
weight_decay: # 该组参数的权重衰减系数
amsgrad:
-1(dict) # 第二组参数
-2(dict) # 第三组参数 parm_group是一个列表,其中每个元素都是一个字典
model = Net() # 生成网络
optimizer = torch.optim.Adam(model.parameters(), lr=0.1) # 生成优化器
for epoch in range(100): # 假设迭代100个epoch
if epoch % 5 == 0: # 每迭代5次,更新一次学习率
for params in optimizer.param_groups:
# 遍历Optimizer中的每一组参数,将该组参数的学习率 * 0.9
params['lr'] *= 0.9
边栏推荐
- 支付宝小程序授权/获取用户信息
- MySQL的pymysql操作
- Promise从入门到精通(第4章 async 和 await)
- feign调用get和post记录
- Codeforces Round #810 (Div. 2)A~C题解
- [Yugong series] July 2022 go teaching course 019 - for circular structure
- 对话Atlassian认证专家叶燕秀:Atlassian产品进入后Server时代,中国用户应当何去何从?
- Leetcode hot topic Hot 100 - > 2. Add two numbers
- Appium click operation sorting
- Redis design specification
猜你喜欢

产品解读丨MeterSphere UI测试模块的设计与分布式扩展

Codeworks round 807 (Div. 2) a-c problem solution

清除浮动的原因和六种方法(解决浮动飞起影响父元素和全局的问题)

小程序毕设作品之微信校园浴室预约小程序毕业设计成品(3)后台功能

How to evaluate the effectiveness of R & D personnel? Software Engineer reports help you see everyone's contribution

Flex布局学习完成PC端

都在说DevOps,你真正了解它吗?
![[database data recovery] data recovery case of insufficient disk space of SQL Server database](/img/0e/908db40e1e8b7dd62e12558c1c6dc4.png)
[database data recovery] data recovery case of insufficient disk space of SQL Server database

OBS键盘插件自定义diy

LeetCode 热题 HOT 100 -> 1.两数之和
随机推荐
Product interpretation - Design and distributed expansion of metersphere UI test module
网络必知题目
Plato Farm在Elephant Swap上铸造的ePLATO是什么?
MySQL high availability and master-slave synchronization
【网站搭建】使用acme.sh更新ssl证书:将zerossl改为letsencrypt
获取两个集合相差数据
Common video resolution
【愚公系列】2022年07月 Tabby集成终端的使用
Aike AI frontier promotion (7.14)
synchronized详解
Vxe table/grid cell grouping and merging
Promise from introduction to mastery (Chapter 4 async and await)
Starfish Os X MetaBell战略合作,元宇宙商业生态更进一步
Appium 点击操作梳理
Understand the "next big trend" in the encryption industry - ventures Dao
小程序毕设作品之微信校园浴室预约小程序毕业设计成品(2)小程序功能
产品解读丨MeterSphere UI测试模块的设计与分布式扩展
Likeshop takeout ordering system [100% open source, no encryption]
Talk to ye Yanxiu, an atlassian certification expert: where should Chinese users go when atlassian products enter the post server era?
这个操作可能不值钱,但却值得学习 | 【图片批量裁剪】