pytorch with Automatic Mixed Precision(AMP)
2022-06-09 05:18:00 【kaims】
References:
PyTorch source-code reading of torch.cuda.amp: automatic mixed precision explained - 知乎
Automatic Mixed Precision examples — PyTorch 1.9.1 documentation
torch.cuda.amp provides a fairly convenient mechanism for mixed-precision training:
Users do not need to convert model parameter dtypes by hand; amp automatically chooses a suitable numerical precision for each operator.
For the backward pass, where FP16 gradient values can underflow or overflow, amp provides gradient scaling, and it automatically unscales the gradients before the optimizer updates the parameters, so the hyperparameters used for optimization are not affected at all.
These two features are provided by amp.autocast and amp.GradScaler, respectively.
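As a quick illustration of the first point (a minimal sketch, not part of the original post; the layer and shapes are arbitrary): inside an autocast region, autocast-eligible ops such as linear layers run in float16, while the parameters themselves stay in float32.

import torch
from torch.cuda.amp import autocast

linear = torch.nn.Linear(8, 8).cuda()   # parameters are created in default float32
x = torch.randn(4, 8, device="cuda")

with autocast():
    y = linear(x)                        # autocast runs this matmul in half precision

print(linear.weight.dtype)               # torch.float32 -- parameters are untouched
print(y.dtype)                           # torch.float16 -- output dtype chosen by autocast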
basic
from torch.cuda.amp import autocast, GradScaler

# Creates model and optimizer in default precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # Runs the forward pass with autocasting.
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        # Scales loss. Calls backward() on scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for corresponding forward ops.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()
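One practical addition that is not in the original example: GradScaler does not directly report whether optimizer.step() was skipped, but update() lowers the scale after a step whose gradients contained inf/NaN, so comparing the scale before and after is a common heuristic (e.g. to decide whether to advance an LR scheduler). A sketch under that assumption:

scale_before = scaler.get_scale()
scaler.step(optimizer)
scaler.update()
# If the scale decreased, this iteration's gradients contained inf/NaN
# and optimizer.step() was skipped.
step_was_skipped = scaler.get_scale() < scale_before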
gradient clipping

scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        scaler.scale(loss).backward()

        # Unscales the gradients of optimizer's assigned params in-place
        scaler.unscale_(optimizer)

        # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        # optimizer's gradients are already unscaled, so scaler.step does not unscale them,
        # although it still skips optimizer.step() if the gradients contain infs or NaNs.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()

gradient accumulation
scaler = GradScaler()

for epoch in epochs:
    for i, (input, target) in enumerate(data):
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
            loss = loss / accumulate_steps

        # Accumulates scaled gradients.
        scaler.scale(loss).backward()

        # Step only once every accumulate_steps micro-batches.
        if (i + 1) % accumulate_steps == 0:
            # may unscale_ here if desired (e.g., to allow clipping unscaled gradients),
            # so the clipping threshold is not affected by the scale factor
            scaler.unscale_(optimizer)
            # clip the gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
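A related practical point not covered in the original post: when checkpointing an AMP run, the scaler state should be saved and restored alongside the model and optimizer (GradScaler exposes state_dict()/load_state_dict()); otherwise the scale factor restarts from its default after resuming. The path and dictionary keys below are placeholders:

# saving
torch.save({
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict(),
}, "checkpoint.pt")

# resuming
checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
scaler.load_state_dict(checkpoint["scaler"])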
AMP in DDP

autocast is designed to be thread local, so setting the autocast region only in the main thread does not work when the forward pass runs in side threads (e.g. DataParallel, or DistributedDataParallel with multiple GPUs per process). One option is to decorate the model's forward:

class MyModel(nn.Module):
    ...
    @autocast()
    def forward(self, input):
        ...

Or set an autocast region inside forward:

class MyModel(nn.Module):
    ...
    def forward(self, input):
        with autocast():
            ...

The first approach raised an error for me when used with DDP (some of forward's arguments were not received correctly; still unresolved...).
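For reference, here is a fuller sketch of the second approach combined with DistributedDataParallel. It assumes one GPU per process and that the process group has already been initialized; the model, loss choice, and the data loader are placeholders, not from the original post:

import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(16, 1)

    def forward(self, input):
        # the autocast region lives inside forward, so it is active in
        # whichever thread/process actually executes the forward pass
        with autocast():
            return self.net(input)

def train(local_rank, data):
    # assumes dist.init_process_group(...) was already called, one GPU per process
    torch.cuda.set_device(local_rank)
    model = MyModel().cuda(local_rank)
    ddp_model = DDP(model, device_ids=[local_rank])
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.1)
    scaler = GradScaler()

    for input, target in data:                   # data: placeholder loader
        input = input.cuda(local_rank)
        target = target.cuda(local_rank)
        optimizer.zero_grad()
        output = ddp_model(input)                # forward already runs under autocast
        loss = nn.functional.mse_loss(output.float(), target)  # compute loss in float32
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()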