当前位置:网站首页>torch.optim.Adam()
torch.optim.Adam()
2022-07-30 05:38:00 【向大厂出发】
torch.optim
torch.optim是一个实现了多种优化算法的包,大多数通用的方法都已支持,提供了丰富的接口调用,未来更多精炼的优化算法也将整合进来。
为了使用torch.optim,需先构造一个优化器对象Optimizer,用来保存当前的状态,并能够根据计算得到的梯度来更新参数。
要构建一个优化器optimizer,你必须给它一个可进行迭代优化的包含了所有参数(所有的参数必须是变量s)的列表。 然后,您可以指定程序优化特定的选项,例如学习速率,权重衰减等。
optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr = 0.0001)
self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))Optimizer还支持指定每个参数选项。 只需传递一个可迭代的dict来替换先前可迭代的Variable(变量)。dict中的每一项都可以定义为一个单独的参数组,参数组用一个params键来包含属于它的参数列表。其他键应该与优化器接受的关键字参数相匹配,才能用作此组的优化选项。
optim.SGD([
{'params': model.base.parameters()},
{'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)如上,model.base.parameters()将使用1e-2的学习率,model.classifier.parameters()将使用1e-3的学习率。0.9的momentum作用于所有的parameters。
优化步骤:
所有的优化器Optimizer都实现了step()方法来对所有的参数进行更新,它有两种调用方法:
方法一:
optimizer.step()这是大多数优化器都支持的简化版本,使用如下的backward()方法来计算梯度的时候会调用它。
for input, target in dataset:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
方法二:
optimizer.step(closure)一些优化算法,如共轭梯度和LBFGS需要重新评估目标函数多次,所以你必须传递一个closure以重新计算模型。 closure必须清除梯度,计算并返回损失。
for input, target in dataset:
def closure():
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
return loss
optimizer.step(closure)
class torch.optim.Adam()
class torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)params(iterable):可用于迭代优化的参数或者定义参数组的dicts。
lr (float, optional) :学习率(默认: 1e-3)
betas (Tuple[float, float], optional):用于计算梯度的平均和平方的系数(默认: (0.9, 0.999))
eps (float, optional):为了提高数值稳定性而添加到分母的一个项(默认: 1e-8)
weight_decay (float, optional):权重衰减(如L2惩罚)(默认: 0)
参考:torch.optim优化算法理解之optim.Adam()_shuaiqidexiaojiejie的博客-CSDN博客_optim.adam
边栏推荐
- More fragrant open source projects than Ruoyi in 2022
- [Redis Master Cultivation Road] Jedis - the basic use of Jedis
- [Koltin Flow (1)] Five ways to create flow
- 【飞控开发基础教程9】疯壳·开源编队无人机-PWM(电机控制)
- MySQL 灵魂 16 问,你能撑到第几问?
- 2022 SQL big factory high-frequency practical interview questions (detailed analysis)
- MySQL 有这一篇就够(呕心狂敲37k字,只为博君一点赞!!!)
- MySQL-Explain详解
- Teach you to completely uninstall MySQL
- 【Koltin Flow(一)】五种创建flow的方式
猜你喜欢

Navicat connection MySQL error: 1045 - Access denied for user 'root'@'localhost' (using password YES)

MySQL Soul 16 Questions, how many questions can you last?

MYSQL-InnoDB的线程模型

Navicat new database

每日练习------输出一个整数的二进制数、八进制数、十六进制数。

从字节码角度带你彻底理解异常中catch,return和finally,再也不用死记硬背了

idea 编译protobuf 文件的设置使用
![[Mysql] DATEDIFF function](/img/cd/7d19e668701cdd5542b6e43f4c2ad4.png)
[Mysql] DATEDIFF function

JVM 垃圾回收 超详细学习笔记(二)

mysql 中 in 的用法
随机推荐
cmd (command line) to operate or connect to the mysql database, and to create databases and tables
最新版MySQL 8.0 的下载与安装(详细教程)
图形镜像对称(示意图)
MySQL kills 10 questions, how many questions can you stick to?
Navicat connection MySQL error: 1045 - Access denied for user 'root'@'localhost' (using password YES)
Mysql8.+学习笔记
瑞吉外卖项目:新增菜品与菜品分页查询
2022年比若依更香的开源项目
mysql 中 in 的用法
Basic syntax of MySQL DDL and DML and DQL
解决phpstudy无法启动MySQL服务
分布式事务之 Seata框架的原理和实战使用(三)
MySql模糊查询大全
net start mysql MySQL 服务正在启动 . MySQL 服务无法启动。 服务没有报告任何错误。
postman 请求 post 调用 传 复合 json数据
MySQL(4)
What is SOA (Service Oriented Architecture)?
net start mysql MySQL service is starting. MySQL service failed to start.The service did not report any errors.
asyncawait和promise的区别
JVM 内存结构 超详细学习笔记(一)