Usage of the torch.optim.Adam() function
2022-07-30 16:41:00 【Mick..】
Adam: A method for stochastic optimization
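Adam adaptively controls the effective learning rate of each parameter using running estimates of the first moment (mean) and second moment (uncentered variance) of that parameter's gradient.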
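To make the role of the two moments concrete, here is a minimal sketch of one Adam update for a single scalar parameter, following the update rule in the paper. The function name `adam_step` and the toy objective f(x) = x^2 are assumptions made for illustration; this is not PyTorch's code, and it leaves out weight decay and AMSGrad.

```python
import math

def adam_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter (illustrative helper, not PyTorch API)."""
    t += 1
    # Running average of the gradient (first moment)
    m = beta1 * m + (1 - beta1) * grad
    # Running average of the squared gradient (second moment)
    v = beta2 * v + (1 - beta2) * grad * grad
    # Bias correction: both averages start at zero, so early values are too small
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # The step is scaled per parameter by the root of the second moment
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v, t

# Toy usage: minimize f(x) = x^2, whose gradient is 2x
x, m, v, t = 1.0, 0.0, 0.0, 0
for _ in range(5):
    x, m, v, t = adam_step(x, 2 * x, m, v, t, lr=0.1)
print(x)
```

On the first step the bias-corrected ratio m_hat / sqrt(v_hat) has magnitude close to 1 regardless of the gradient's scale, so the parameter moves by roughly lr. This is why the step size in Adam depends much less on the raw gradient magnitude than it does in plain SGD.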
Initializing Adam
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
             weight_decay=0, amsgrad=False):
Args:
    params (iterable): iterable of parameters to optimize or dicts defining
        parameter groups
    lr (float, optional): learning rate (default: 1e-3)
    betas (Tuple[float, float], optional): coefficients used for computing
        running averages of gradient and its square (default: (0.9, 0.999))
    eps (float, optional): term added to the denominator to improve
        numerical stability (default: 1e-8)
    weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
    amsgrad (boolean, optional): whether to use the AMSGrad variant of this
        algorithm from the paper `On the Convergence of Adam and Beyond`_
        (default: False)
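The arguments above can be passed either for all parameters at once or per parameter group. The following is a small sketch of both forms; the two-layer `model` and the specific values chosen for the second group are assumptions made for illustration only.

```python
import torch
from torch import nn

# Hypothetical small model, used only to have some parameters to optimize
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

# Simplest form: a single parameter group, with the defaults written out
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                             eps=1e-8, weight_decay=0, amsgrad=False)

# Dict form: each dict is one parameter group; options missing from a
# group fall back to the keyword arguments given to the constructor
grouped = torch.optim.Adam(
    [
        {"params": model[0].parameters()},
        {"params": model[2].parameters(), "lr": 1e-4, "weight_decay": 1e-2},
    ],
    lr=1e-3,
)

for group in grouped.param_groups:
    print(group["lr"], group["weight_decay"])
```

Each entry of `param_groups` is the `group` dict that `step()` iterates over, which is where `group['lr']`, `group['betas']`, and the other options are read from.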
@torch.no_grad()
def step(self, closure=None):
    """Performs a single optimization step.

    Args:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        params_with_grad = []
        grads = []
        exp_avgs = []
        exp_avg_sqs = []
        max_exp_avg_sqs = []
        state_steps = []
        beta1, beta2 = group['betas']

        for p in group['params']:
            if p.grad is not None:
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                grads.append(p.grad)

                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if group['amsgrad']:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])

                if group['amsgrad']:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                # update the steps for each param group update
                state['step'] += 1
                # record the step after step update
                state_steps.append(state['step'])

        F.adam(params_with_grad,
               grads,
               exp_avgs,
               exp_avg_sqs,
               max_exp_avg_sqs,
               state_steps,
               amsgrad=group['amsgrad'],
               beta1=beta1,
               beta2=beta2,
               lr=group['lr'],
               weight_decay=group['weight_decay'],
               eps=group['eps'])
    return loss
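As a usage sketch for `step()`, the example below calls it with a closure and then inspects the per-parameter state that the method creates lazily on the first call. The linear `model`, the random `inputs` and `targets`, and the MSE loss are assumptions made for illustration. Depending on the installed PyTorch version, `state['step']` may be stored as a plain number or as a tensor, so the example only converts it with `int()`.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy regression problem
model = nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
inputs = torch.randn(16, 3)
targets = torch.randn(16, 1)

def closure():
    # step() re-enables gradients around this call, so backward() works here
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    return loss

# No state exists yet: it is initialized lazily inside the first step()
state_before = len(optimizer.state)

# step() evaluates the closure, updates the parameters, and returns the loss
loss = optimizer.step(closure)

# After one step, each parameter with a gradient has its moment buffers
weight_state = optimizer.state[model.weight]
print(state_before, sorted(weight_state.keys()), int(weight_state["step"]))
```

Passing a closure is optional for Adam. The more common pattern is to call `optimizer.zero_grad()`, compute the loss, call `loss.backward()`, and then call `optimizer.step()` with no arguments.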