当前位置:网站首页>pytorch优化器设置
pytorch优化器设置
2022-07-28 01:06:00 【Mick..】
深度学习训练过程中学习率的大小十分重要。学习率过低会导致学习太慢,学习率过高会导致难以收敛。通常情况下,初始学习率会比较大,后来逐渐缩小学习率。
通常情况下模型优化器设置
首先定义两层全连接层模型
import torch
from torch import nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.layer1 = nn.Linear(10, 2)
self.layer2 = nn.Linear(2, 10)
def forward(self, input):
return self.layer2(self.layer1(input))神经网络的执行步骤。首先神经网络进过前向传播,这是神经网络框架会搭建好计算图(这里会保存操作和对应参与计算的张量,因为在根据计算图计算梯度时需要这些信息)。然后是误差反向传播,loss.backward() ,这时会计算梯度信息。最后根据梯度信息,更新参数。
loss.backward()
optimizer.step()
optimizer.zero_grad()optimizer.zero_grad() 是将这一轮的梯度清零,防止影响下一轮参数的更新。这里曾问过面试的问题:什么时候不使用这一步进行清零。
model = Net()
# 只传入想要训练层的参数。其他未传入的参数不参与更新
optimizer_Adam = torch.optim.Adam(model.parameters(), lr=0.1)model.parameters()会返回模型的所有参数
只训练模型的部分参数
也就是说只传入模型待优化的参数,为传入的参数不参与更新。
model = Net()
# 只传入待优化的参数
optimizer_Adam = torch.optim.Adam(model.layer1.parameters(), lr=0.1) 不同部分设置不同的学习率
params_dict = [{'params': model.layer1.parameters(), 'lr': 0.01},
{'params': model.layer2.parameters(), 'lr': 0.001}]
optimizer = torch.optim.Adam(params_dict)动态修改学习率
优化器的param_group属性
-param_groups
-0(dict) # 第一组参数
params: # 维护要更新的参数
lr: # 该组参数的学习率
betas:
eps: # 该组参数的学习率最小值
weight_decay: # 该组参数的权重衰减系数
amsgrad:
-1(dict) # 第二组参数
-2(dict) # 第三组参数 parm_group是一个列表,其中每个元素都是一个字典
model = Net() # 生成网络
optimizer = torch.optim.Adam(model.parameters(), lr=0.1) # 生成优化器
for epoch in range(100): # 假设迭代100个epoch
if epoch % 5 == 0: # 每迭代5次,更新一次学习率
for params in optimizer.param_groups:
# 遍历Optimizer中的每一组参数,将该组参数的学习率 * 0.9
params['lr'] *= 0.9
边栏推荐
- This operation may not be worth money, but it is worth learning | [batch cutting of pictures]
- 新零售业态下,零售电商RPA助力重塑增长
- 轻量版项目管理系统
- How to evaluate the effectiveness of R & D personnel? Software Engineer reports help you see everyone's contribution
- Promise from introduction to mastery (Chapter 1 Introduction and basic use of promise)
- Principle and implementation of focal loss
- Sample imbalance - entry 0
- 软件测试面试题:常见的 POST 提交数据方式
- [Star Project] small hat aircraft War (VI)
- Use of classes in typescript
猜你喜欢
![[Yugong series] use of tabby integrated terminal in July 2022](/img/df/bf01fc77ae019200d1bf57be783cb9.png)
[Yugong series] use of tabby integrated terminal in July 2022

Flex layout - fixed positioning + flow layout - main axis alignment - side axis alignment - expansion ratio

Under the new retail format, retail e-commerce RPA helps reshape growth

Go learning 01

结构伪类选择器—查找单个—查找多个—nth-of-type和伪元素

Talk to ye Yanxiu, an atlassian certification expert: where should Chinese users go when atlassian products enter the post server era?

【愚公系列】2022年07月 Tabby集成终端的使用

Appium click operation sorting

Starfish Os X MetaBell战略合作,元宇宙商业生态更进一步

Uniapp summary (applet)
随机推荐
Product interpretation - Design and distributed expansion of metersphere UI test module
小程序毕设作品之微信校园浴室预约小程序毕业设计成品(1)开发概要
Unittest单元测试框架全栈知识
支付宝小程序授权/获取用户信息
Unity 保存图片到相册以及权限管理
如何评估研发人员效能?软件工程师报告帮你看见每个人的贡献
考研数学一元微分学证明题常见题型方法
[advanced ROS chapter] Lecture 10 gadf integrated simulation process and examples based on gazebo
Ceresdao: the world's first decentralized digital asset management protocol based on Dao enabled Web3.0
Software test interview questions: common post data submission methods
【ROS进阶篇】第十讲 基于Gazebo的URDF集成仿真流程及实例
上课笔记(5)(1)——#593. 二分查找(binary)
CeresDAO:全球首个基于DAO赋能Web3.0的去中心化数字资产管理协议
Four common post data submission methods
MySQL pymysql operation
C# 使用Abp仓储访问数据库时报错记录集
视频常用分辨率
Flex layout learning completed on PC side
cn+dt
Skywalking distributed system application performance monitoring tool - medium