当前位置:网站首页>PyTorch混合精度原理及如何开启该方法
PyTorch混合精度原理及如何开启该方法
2022-06-26 06:01:00 【Le0v1n】
1. 前置知识
在PyTorch中默认的精度的Float32,即32位浮点数。
使用自动混合精度(Automatic Mixed Precision)的目的是让模型在训练时,Tensor的精度设置为16而不是32。因为32的精度对于模型学习来说,没什么必要。
在PyTorch1.6中已经内置的混合精度的包,如下:
from torch.cuda.amp import Scaler, autocast
自动混合精度,也就是torch.FloatTensor和torch.HalfTensor的混合。
1.1 关于数据类型的思考
思考一个问题:为什么不使用纯torch.FloatTensor或者纯torch.HalfTensor?
要想回答这个问题,我们首先需要知道这两种数据类型有什么特点,确切来说它俩各自有什么优势和劣势。
torch.HalfTensor的优势就是存储小、计算快、更好的利用CUDA设备。因此训练的时候可以减少显存的占用。由于计算简单,训练速度更快。根据NVIDIA官方的介绍,在某些设备上,使用此精度模型的训练速度可以加速一倍。torch.HalfTensor的劣势就是:- 数值范围小(更容易Overflow / Underflow)-> 有时会导致
loss变为nan,从而无法训练 - 存在舍入误差(Rounding Error,导致一些微小的梯度信息达不到16bit精度的最低分辨率,从而丢失)
- 数值范围小(更容易Overflow / Underflow)-> 有时会导致
torch.FloatTensor的优点是没有torch.HalfTensor那样的缺点torch.FloatTensor的缺点就是占用显存大,训练速度慢
可见,当有优势的场景尽量使用torch.HalfTensor可以加速训练。
1.2 消除torch.HalfTensor缺点的两种方案
为了消除torch.HalfTensor的劣势,一般会有两种方案。
1.2.1 方案1:torch.cuda.amp.GradScaler
梯度scale(缩放),使用的工具为torch.cuda.amp.GradScaler,通过放大loss的值来防止梯度的underflow(这仅仅使用在梯度反向传播时,在optimizer进行更新权重时还是要把放大的梯度再unscale回去)
- 梯度回传 -> scale -> 放大梯度
- 更新参数 -> 不使用scaled的梯度
1.2.2 方案2:autocast()的上下文管理器或装饰器
- 上下文管理器 ->
with autocast(): - 装饰器 ->
@autocast()
回落到torch.FloatTensor,这就是混合一词的由来。那怎么知道什么时候用torch.FloatTensor,什么时候用半精度浮点型呢?这是PyTorch框架决定的,在PyTorch 1.6的AMP上下文(或装饰器)中,如下操作中tensor会被自动转化为半精度浮点型的torch.HalfTensor:
| 操作 | 说明 |
|---|---|
__matmul__ | ⊙ \odot ⊙ |
addbmm | 批次的矩阵 ⊗ \otimes ⊗ |
addmm | torch.addmm(input, mat1, mat2),mat1和mat2执行矩阵乘法,结果与input相加 |
addmv | torch.addmv(input, mat, vec) -> mat和vec执行 ⊙ \odot ⊙,再将结果与input相加 |
addr | torch.addr(input, vec1, vec2) -> vec1 ⊗ \otimes ⊗ vec2 + input |
baddbmm | torch.baddbmm(input, batch1, batch2) -> batch1 ⊙ \odot ⊙batch2 + input |
bmm | torch.bmm(input, mat2) -> input ⊙ \odot ⊙ mat2 |
chain_matmul | torch.chain_matmul(*matrices) -> 返回NN二维张量的矩阵乘积。该乘积使用矩阵链序算法有效计算,该算法选择在算术运算方面产生最低成本的顺序 |
conv1d | 一维卷积 |
conv2d | 二维卷积 |
conv3d | 三维卷积 |
conv_transpose1d | 一维转置卷积 |
conv_transpose2d | 二维转置卷积 |
conv_transpose3d | 三维转置卷积 |
linear | 线性层 |
matmul | torch.matmul(input, other) -> 两个tensor执行 ⊙ \odot ⊙ |
mm | torch.mm(input, mat2) -> input ⊗ \otimes ⊗ mat2 |
mv | torch.mv(input, vec) -> input ⊙ \odot ⊙ vec |
prelu | 自学习的ReLU激活函数(ReLU的变体) |
- mm -> matrix & matrix
- mv -> matrix & vector
- 不理解 ⊙ \odot ⊙符号的含义可以看这篇博文:Computer Vision的论文中“圈加”、“圈乘”和“点乘”的解释
- 转置卷积推荐博文:转置卷积(Transposed Convolution)的介绍以及理论讲解
- PReLU不明白可以看博文:深度学习中常用激活函数分析
2. PyTorch如何使用AMP(自动混合精度)
说白了就是:
- autocast
- GradScaler
掌握这两部分的使用就可以了。
2.1 autocast
正如前文所说,AMP需要使用torch.cuda.amp模块中的autocast类。
下面是一个标准的分类网络训练过程(不包含预测阶段)
# 创建model,默认是torch.FloatTensor
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), ...)
for epoch in range(1, args.epochs+1): # epoch -> [1, epochs]
if rank == 0: # 在主进程中使用tqdm对dataloader进行包装
data_loader = tqdm(data_loader, file=sys.stdout)
for step, inputs, labels in enumerate(data_loader): # 迭代data_loader
optimizer.zero_grad() # 首先清空优化器中的梯度残留
pred = model(input.to(device)) # 网络正向传播获取预测结果
loss = loss_fn(pred, labels.to(device)) # 使用loss函数计算预测值和GT直接的差距,从而计算出loss
loss.backward() # 计算完loss后进行反向传播
if rank == 0: # 在主进程中打印训练信息
data_loader.desc = f"[train]epoch {
epoch}/{
opt.n_epochs} | lr: {
optimizer.param_groups[0]['lr']:.4f} | mloss: {
round(mean_loss.item(), 4):.4f}"
# 每张卡判断自己求出来的loss是否为有限数据
if not torch.isfinite(loss):
print(f"WARNING: non-finite loss, ending training! loss -> {
loss}")
sys.exit(1) # 如果loss为无穷 -> 退出训练
# 每张卡判断自己求出来的判断loss是否为nan
if torch.isnan(loss):
print(f"WARNING: nan loss, ending training! loss -> {
loss}")
sys.exit(1) # 如果nan为无穷 -> 退出训练
optimizer.step() # 最后优化器更新参数
如果要使用AMP,则也是非常简单的,如下所示:
from torch.cuda.amp import autocast
# 创建model,默认是torch.FloatTensor
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), ...)
for epoch in range(1, args.epochs+1): # epoch -> [1, epochs]
if rank == 0: # 在主进程中使用tqdm对dataloader进行包装
data_loader = tqdm(data_loader, file=sys.stdout)
for step, inputs, labels in enumerate(data_loader): # 迭代data_loader
optimizer.zero_grad() # 首先清空优化器中的梯度残留
""" 仅仅在前向推理和求loss的时候开启autocast即可! """
with autocast(): # 建立autocast的上下文语句
pred = model(input.to(device)) # 网络正向传播获取预测结果
loss = loss_fn(pred, labels.to(device)) # 使用loss函数计算预测值和GT直接的差距,从而计算出loss
loss.backward() # 计算完loss后进行反向传播
if rank == 0: # 在主进程中打印训练信息
data_loader.desc = f"[train]epoch {
epoch}/{
opt.n_epochs} | lr: {
optimizer.param_groups[0]['lr']:.4f} | mloss: {
round(mean_loss.item(), 4):.4f}"
# 每张卡判断自己求出来的loss是否为有限数据
if not torch.isfinite(loss):
print(f"WARNING: non-finite loss, ending training! loss -> {
loss}")
sys.exit(1) # 如果loss为无穷 -> 退出训练
# 每张卡判断自己求出来的判断loss是否为nan
if torch.isnan(loss):
print(f"WARNING: nan loss, ending training! loss -> {
loss}")
sys.exit(1) # 如果nan为无穷 -> 退出训练
optimizer.step() # 最后优化器更新参数
当进入autocast的上下文后,上面列出来的那些CUDA操作(1.2.2 方案2中的表格)会把tensor的dtype转换为torch.HalfTensor,从而在不损失训练精度的情况下加快运算。
刚进入autocast的上下文时,tensor可以是任何类型,不需要在model或者inputs上手工调用.half(),PyTorch框架会自动帮你完成,这也是自动混合精度中 自动 一词的由来。
另外一点就是,autocast上下文应该只包含网络的前向过程(包括loss的计算),而不要包含反向传播,因为反向传播的操作(operations)会使用和前向操作相同的类型。
2.2 autocast报错
有的时候,代码在autocast上下文中会报如下的错误:
Traceback (most recent call last):
......
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
......
RuntimeError: expected scalar type float but found c10::Half
在模型的forward函数上面添加autocast()装饰器,以MobileNet v3 Small为例:
from torch.cuda.amp import autocast
class MobileNetV3(nn.Module):
def __init__(self, num_classes=27, sample_size=112, dropout=0.2, width_mult=1.0):
super(MobileNetV3, self).__init__()
input_channel = 16
last_channel = 1024
# 各种网络定义...
# 各种网络定义...
@autocast()
def forward(self, x): # 这时MobileNet v3 Small总的forward,在这个函数上面加上autocast()装饰器即可
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
需要注意的是:
- 仅在最终的
forward函数上加上即可,不需要在内部其他的forward函数上加autocast()装饰器了
2.3 GradScaler
别忘了前面提到的梯度scaler模块呀,需要在训练最开始之前实例化一个GradScaler对象。因此PyTorch中经典的AMP使用方式如下:
from torch.cuda.amp import autocast, GradScaler
# 在训练最开始之前实例化一个GradScaler对象
scaler = GradScaler()
for epoch in range(1, args.epochs+1): # epoch -> [1, epochs]
if rank == 0: # 在主进程中使用tqdm对dataloader进行包装
data_loader = tqdm(data_loader, file=sys.stdout)
for step, inputs, labels in enumerate(data_loader): # 迭代data_loader
optimizer.zero_grad() # 首先清空优化器中的梯度残留(不变)
""" 仅仅在前向推理和求loss的时候开启autocast即可! """
with autocast(): # 建立autocast的上下文语句
pred = model(input.to(device)) # 网络正向传播获取预测结果
loss = loss_fn(pred, labels.to(device)) # 使用loss函数计算预测值和GT直接的差距,从而计算出loss
# 使用scaler先对loss进行放大,再反向传播放大后的梯度
scaler.scale(loss).backward()
if rank == 0: # 在主进程中打印训练信息
data_loader.desc = f"[train]epoch {
epoch}/{
opt.n_epochs} | lr: {
optimizer.param_groups[0]['lr']:.4f} | mloss: {
round(mean_loss.item(), 4):.4f}"
# 每张卡判断自己求出来的loss是否为有限数据
if not torch.isfinite(loss):
print(f"WARNING: non-finite loss, ending training! loss -> {
loss}")
sys.exit(1) # 如果loss为无穷 -> 退出训练
# 每张卡判断自己求出来的判断loss是否为nan
if torch.isnan(loss):
print(f"WARNING: nan loss, ending training! loss -> {
loss}")
sys.exit(1) # 如果nan为无穷 -> 退出训练
# scaler.step() 首先把梯度的值unscale回来.
# 如果梯度的值不是 infs 或者 NaNs, 那么调用optimizer.step()来更新权重,
# 否则,忽略step调用,从而保证权重不更新(不被破坏)
""" 这句话是这样理解的: 1. 首先把梯度的值缩放回原来的样子 2. 如果梯度值不是infs或nan,那么就会自动调用optimizer.step()来更新权重 3. 如果梯度值是infs或nan,则不进行optimizer.step() -> 保证权重不被破环(这样明显错误的梯度会让权重直接损坏!) 而scaler.update()这句话的含义是:根据loss的情况让scaler的放大系数动态调整 """
scaler.step(optimizer)
scaler.update()
scaler的大小在每次迭代中动态的估计,为了尽可能的减少梯度underflow,scaler应该更大;但是如果太大的话,半精度浮点型的tensor又容易overflow(变成inf或者nan)。所以动态估计的原理就是在不出现inf或者NaN梯度值的情况下尽可能的增大scaler的值——在每次scaler.step(optimizer)中,都会检查是否有inf或NaN的梯度出现:
- 如果出现了
inf或者nan,scaler.step(optimizer)会忽略此次的权重更新(不调用optimizer.step()),并且将scaler的大小缩小(乘上backoff_factor) - 如果没有出现
inf或者nan,那么权重正常更新(调用optimizer.step()),并且当连续多次(growth_interval指定)没有出现inf或者nan,则scaler.update()会将scaler的大小增加(乘上growth_factor)
可以使用PyTorch项目规范来简化开发:https://github.com/deepVAC/deepvac/。
3. 注意事项
3.1 loss出现inf或nan
- 可以不使用
GradScaler-> 直接使用autocast()上下文管理器,然后loss.backward()->optimizer.step() loss scale时梯度偶尔overflow可以忽略,因为amp会检测溢出情况并跳过该次更新(如果自定义了optimizer.step的返回值,会发现溢出时step返回值永远是None),scaler下次会自动缩减倍率,如果长时间稳定更新,scaler又会尝试放大倍数- 一直显示overflow而且
loss很不稳定的话就需要适当调小学习率(建议10倍往下调),如果loss还是一直在波动,那可能是网络深层问题了。
3.2 使用AMP后速度变慢
可能的原因如下:
- 单精度和半精度之间的转换开销,不过这部分开销比较小,相比之下半精度减少的后续计算量可以cover住
- 梯度回传时的数值放大和缩小,即加了scaler会变慢,这部分开销应该是蛮大的,本身需要回传的参数梯度就很多,再加上乘法和除法操作,但是如果不加scaler,梯度回传的时候就容易出现underflow(
16 bit能表示的精度有限,梯度值太小丢失信息会很大),所以不加scaler最后的结果可能会变差。整体来讲这是一个balance问题,属于时间换空间。
3.3 推荐文章
参考
- https://zhuanlan.zhihu.com/p/165152789
- https://blog.csdn.net/weixin_44878336/article/details/124501040
- https://blog.csdn.net/weixin_44878336/article/details/124754484
- https://blog.csdn.net/weixin_44878336/article/details/125119242
- https://zhuanlan.zhihu.com/p/516996892
边栏推荐
- tf.nn.top_k()
- REUSE_ ALV_ GRID_ Display event implementation (data_changed)
- 电商借助小程序技术发力寻找增长突破口
- Selective search for object recognition paper notes [image object segmentation]
- 从新东方直播来探究下小程序音视频通话及互动直播
- MySQL database-01 database overview
- Getting to know concurrency problems
- MySQL-06
- 302. minimum rectangular BFS with all black pixels
- Younger sister Juan takes you to learn JDBC -- two days' Sprint Day2
猜你喜欢
![[C language] deep analysis of data storage in memory](/img/2e/ff0b5326d796b9436f4a10c10cfe22.png)
[C language] deep analysis of data storage in memory

The purpose of writing programs is to solve problems

Unicloud cloud development obtains applet user openid

MySQL database-01 database overview

小程序如何关联微信小程序二维码,实现二码聚合

原型模式,咩咩乱叫

canal部署、原理和使用介绍

Tencent WXG internship experience (has offered), I hope it will help you!

Logstash -- send an alert message to the nail using the throttle filter

Tortoise and rabbit race example
随机推荐
numpy.log
Application of cow read / write replication mechanism in Linux, redis and file systems
Pytorch (environment, tensorboard, transforms, torchvision, dataloader)
Func < T, tresult > Commission - learning record
Data visualization practice: Data Visualization
去哪儿网BI平台建设演进史
numpy.tile()
实时数仓方案如何选型和构建
Class and object learning
小程序第三方微信授权登录的实现
How to associate wechat applet QR code to realize two code aggregation
在web页面播放rtsp流视频(webrtc)
The difference between abstract and interface interface
421- binary tree (226. reversed binary tree, 101. symmetric binary tree, 104. maximum depth of binary tree, 222. number of nodes of complete binary tree)
Logstash——Logstash向Email发送告警邮件
Unicloud cloud development obtains applet user openid
Volatile application scenarios
冒泡排序(Bubble Sort)
String class learning
机器学习 07:PCA 及其 sklearn 源码解读