当前位置:网站首页>pytorch优化器设置
pytorch优化器设置
2022-07-28 01:06:00 【Mick..】
深度学习训练过程中学习率的大小十分重要。学习率过低会导致学习太慢,学习率过高会导致难以收敛。通常情况下,初始学习率会比较大,后来逐渐缩小学习率。
通常情况下模型优化器设置
首先定义两层全连接层模型
import torch
from torch import nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.layer1 = nn.Linear(10, 2)
self.layer2 = nn.Linear(2, 10)
def forward(self, input):
return self.layer2(self.layer1(input))神经网络的执行步骤。首先神经网络进过前向传播,这是神经网络框架会搭建好计算图(这里会保存操作和对应参与计算的张量,因为在根据计算图计算梯度时需要这些信息)。然后是误差反向传播,loss.backward() ,这时会计算梯度信息。最后根据梯度信息,更新参数。
loss.backward()
optimizer.step()
optimizer.zero_grad()optimizer.zero_grad() 是将这一轮的梯度清零,防止影响下一轮参数的更新。这里曾问过面试的问题:什么时候不使用这一步进行清零。
model = Net()
# 只传入想要训练层的参数。其他未传入的参数不参与更新
optimizer_Adam = torch.optim.Adam(model.parameters(), lr=0.1)model.parameters()会返回模型的所有参数
只训练模型的部分参数
也就是说只传入模型待优化的参数,为传入的参数不参与更新。
model = Net()
# 只传入待优化的参数
optimizer_Adam = torch.optim.Adam(model.layer1.parameters(), lr=0.1) 不同部分设置不同的学习率
params_dict = [{'params': model.layer1.parameters(), 'lr': 0.01},
{'params': model.layer2.parameters(), 'lr': 0.001}]
optimizer = torch.optim.Adam(params_dict)动态修改学习率
优化器的param_group属性
-param_groups
-0(dict) # 第一组参数
params: # 维护要更新的参数
lr: # 该组参数的学习率
betas:
eps: # 该组参数的学习率最小值
weight_decay: # 该组参数的权重衰减系数
amsgrad:
-1(dict) # 第二组参数
-2(dict) # 第三组参数 parm_group是一个列表,其中每个元素都是一个字典
model = Net() # 生成网络
optimizer = torch.optim.Adam(model.parameters(), lr=0.1) # 生成优化器
for epoch in range(100): # 假设迭代100个epoch
if epoch % 5 == 0: # 每迭代5次,更新一次学习率
for params in optimizer.param_groups:
# 遍历Optimizer中的每一组参数,将该组参数的学习率 * 0.9
params['lr'] *= 0.9
边栏推荐
- Day6 函数和模块的使用
- Implementation of mongodb/mongotemplate.upsert batch inserting update data
- Promise from introduction to mastery (Chapter 4 async and await)
- [database data recovery] data recovery case of insufficient disk space of SQL Server database
- 「冒死上传」Proe/Creo产品结构设计-止口与扣位
- Product interpretation - Design and distributed expansion of metersphere UI test module
- 微信小程序实现动态横向步骤条的两种方式
- 你所不知道的WMS
- Data output - image annotation and annotation
- zkrollup学习资料汇总
猜你喜欢

The cooperation between starfish OS and metabell is just the beginning

LeetCode 热题 HOT 100 -> 2.两数相加

OBS keyboard plug-in custom DIY

IT这个岗位,人才缺口百万,薪资水涨船高,上不封顶

智能合约安全——selfdestruct攻击

Record a production deadlock

小程序毕设作品之微信校园维修报修小程序毕业设计成品(4)开题报告

Uniapp summary (applet)

Redis design specification

学会这招再也不怕手误让代码崩掉
随机推荐
测试/开发程序员的级别“陷阱“,级别不是衡量单维度的能力......
小程序毕设作品之微信校园浴室预约小程序毕业设计成品(3)后台功能
Class notes (5) (1) - 593. Binary search
一种比读写锁更快的锁,还不赶紧认识一下
埃睿迪再度亮相数字中国峰会 持续深化用科技守护绿水青山
Vxe table/grid cell grouping and merging
一文读懂Plato Farm的ePLATO,以及其高溢价缘由
Two ways for wechat applet to realize dynamic horizontal step bar
Flume (5 demos easy to get started)
Understand the "next big trend" in the encryption industry - ventures Dao
Flex layout learning completed on PC side
synchronized详解
Eredi reappeared at the digital China Summit and continued to deepen the protection of green waters and mountains with science and technology
Traversal and properties of binary trees
正则表达式
Four common post data submission methods
Principle and implementation of cross entropy
feign调用get和post记录
LeetCode 热题 HOT 100 -> 3. 无重复字符的最长子串
Flume(5个demo轻松入门)