当前位置:网站首页>torch.optim.Adam()
torch.optim.Adam()
2022-07-30 05:38:00 【向大厂出发】
torch.optim
torch.optim是一个实现了多种优化算法的包,大多数通用的方法都已支持,提供了丰富的接口调用,未来更多精炼的优化算法也将整合进来。
为了使用torch.optim,需先构造一个优化器对象Optimizer,用来保存当前的状态,并能够根据计算得到的梯度来更新参数。
要构建一个优化器optimizer,你必须给它一个可进行迭代优化的包含了所有参数(所有的参数必须是变量s)的列表。 然后,您可以指定程序优化特定的选项,例如学习速率,权重衰减等。
optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr = 0.0001)
self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))Optimizer还支持指定每个参数选项。 只需传递一个可迭代的dict来替换先前可迭代的Variable(变量)。dict中的每一项都可以定义为一个单独的参数组,参数组用一个params键来包含属于它的参数列表。其他键应该与优化器接受的关键字参数相匹配,才能用作此组的优化选项。
optim.SGD([
{'params': model.base.parameters()},
{'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)如上,model.base.parameters()将使用1e-2的学习率,model.classifier.parameters()将使用1e-3的学习率。0.9的momentum作用于所有的parameters。
优化步骤:
所有的优化器Optimizer都实现了step()方法来对所有的参数进行更新,它有两种调用方法:
方法一:
optimizer.step()这是大多数优化器都支持的简化版本,使用如下的backward()方法来计算梯度的时候会调用它。
for input, target in dataset:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
方法二:
optimizer.step(closure)一些优化算法,如共轭梯度和LBFGS需要重新评估目标函数多次,所以你必须传递一个closure以重新计算模型。 closure必须清除梯度,计算并返回损失。
for input, target in dataset:
def closure():
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
return loss
optimizer.step(closure)
class torch.optim.Adam()
class torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)params(iterable):可用于迭代优化的参数或者定义参数组的dicts。
lr (float, optional) :学习率(默认: 1e-3)
betas (Tuple[float, float], optional):用于计算梯度的平均和平方的系数(默认: (0.9, 0.999))
eps (float, optional):为了提高数值稳定性而添加到分母的一个项(默认: 1e-8)
weight_decay (float, optional):权重衰减(如L2惩罚)(默认: 0)
参考:torch.optim优化算法理解之optim.Adam()_shuaiqidexiaojiejie的博客-CSDN博客_optim.adam
边栏推荐
- PyCharm usage tutorial (more detailed, picture + text)
- 坠落的蚂蚁(北京大学考研机试题)
- Different lower_case_table_names settings for server ('1') and data dictionary ('0') solution
- What is SOA (Service Oriented Architecture)?
- leetcode刷题
- Programmers make money and practice, teach you how to do paid courses, self-media, paid articles and paid technical courses to make money
- 1475. 商品折扣后的最终价格
- 4461. 范围分区(Google Kickstart2022 Round C Problem B)
- 4、nerf(pytorch)
- 破纪录者(Google Kickstart2020 Round D Problem A)
猜你喜欢

JVM 类加载机制 超详细学习笔记(三)
![[GO语言基础] 一.为什么我要学习Golang以及GO语言入门普及](/img/ac/80ab67505f7df52d92a206bc3dd50e.png)
[GO语言基础] 一.为什么我要学习Golang以及GO语言入门普及

Teach you to completely uninstall MySQL

CISP-PTE Zhenti Demonstration

从驱动表和被驱动表来快速理解MySQL中的内连接和外连接

Teach you how to design a CSDN system
![[Mysql] DATEDIFF function](/img/cd/7d19e668701cdd5542b6e43f4c2ad4.png)
[Mysql] DATEDIFF function

idea 编译protobuf 文件的设置使用

2022 SQL big factory high-frequency practical interview questions (detailed analysis)

Navicat cannot connect to mysql super detailed processing method
随机推荐
net start mysql MySQL 服务正在启动 . MySQL 服务无法启动。 服务没有报告任何错误。
JVM面试总结
839. 模拟堆
破纪录者(Google Kickstart2020 Round D Problem A)
Countdown (Source: Google Kickstart2020 Round C Problem A) (DAY 88)
[Mysql] DATEDIFF函数
分布式事务之 LCN框架的原理和使用(二)
The difference between asyncawait and promise
[Koltin Flow (2)] The end operator of the Flow operator
[Image processing] Image skeleton extraction based on central axis transformation with matlab code
分布式事务之 Atomikos 原理和使用(一)
ClickHouse 数据插入、更新与删除操作 SQL
cmd(命令行)操作或连接mysql数据库,以及创建数据库与表
从驱动表和被驱动表来快速理解MySQL中的内连接和外连接
子查询作为检索表时的不同使用场景以及是否需要添加别名的问题
2022 Pengcheng Cup web
成绩排序(华中科技大学考研机试题)(DAY 87)
坠落的蚂蚁(北京大学考研机试题)
Summary of SQL classic interview questions in 2022 (with analysis)
MySql的初识感悟,以及sql语句中的DDL和DML和DQL的基本语法