Mixed Precision Training: From the Basics to the Internals
2022-06-24 【PaperWeekly】

Author | serendipity
Affiliation | Tongji University
Research direction | Person search
As of 2022, mixed precision (Automatic Mixed Precision, AMP) training has become a standard tool for deep learning practitioners: with just a few lines of code, you can halve memory usage and roughly double training speed.
AMP was proposed by the Baidu and NVIDIA teams in 2017 (Mixed Precision Training [1]) and published at ICLR. Before PyTorch 1.6, we all used NVIDIA's apex [2] library for AMP training; from version 1.6 on, PyTorch ships with AMP built in.
This article covers three things: how to use AMP in PyTorch, how AMP works, and how AMP is implemented in code.

1. How to Use AMP in PyTorch
If you are new to this and just want to give AMP a quick try, take the relevant training code:

output = net(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
optimizer.zero_grad()

and modify it as follows:
scaler = torch.cuda.amp.GradScaler()  # created once, before the training loop

with torch.cuda.amp.autocast():
    output = net(input)
    loss = loss_fn(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()

If your GPU supports Tensor Cores (the Volta, Turing, or Ampere architectures), AMP will greatly reduce memory consumption and speed up training. On other GPUs, memory usage still drops, but training may actually slow down.
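For reference, here is a complete, runnable toy training loop using native AMP. The model, data, and hyperparameters are invented purely for illustration; the AMP calls themselves are the standard torch.cuda.amp API:

import torch

device = "cuda"  # AMP's speedups require a CUDA GPU
net = torch.nn.Linear(10, 2).to(device)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for step in range(100):
    input = torch.randn(32, 10, device=device)
    target = torch.randint(0, 2, (32,), device=device)
    with torch.cuda.amp.autocast():    # run the forward pass in mixed precision
        output = net(input)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()      # scale the loss, backprop scaled gradients
    scaler.step(optimizer)             # unscale gradients; skip the step on inf/nan
    scaler.update()                    # adjust the scale factor for the next step
    optimizer.zero_grad()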

2. How AMP Works
2.1 What Is FP16?
Half-precision floating point (FP16) is a binary floating-point format that uses 2 bytes (16 bits) of storage; counting subnormals, it represents magnitudes roughly in [2^-24, 65504]. PyTorch, by default, uses single-precision floating point (FP32) to compute the network and store its weights. FP32 uses 4 bytes (32 bits) and covers roughly [2^-149, 3.4 x 10^38], a far larger range than FP16. Floating-point numbers also have another curious property: when two numbers differ too much in magnitude, adding the smaller one has no effect at all. This is known as rounding error [3].
Let's demonstrate this with a short snippet:
>>> # In FP32, the addition works fine.
>>> torch.tensor(2**-3) + torch.tensor(2**-14)
tensor(0.1251)
>>> # In FP16 the smaller number is simply ignored: within [2**-3, 2**-2],
>>> # the spacing between representable FP16 values is 2**-13. That is, the
>>> # next number after 2**-3 is 2**-3 + 2**-13, so adding 2**-14 does nothing.
>>> # (half() converts FP32 to FP16.)
>>> torch.tensor(2**-3).half() + torch.tensor(2**-14).half()
tensor(0.1250, dtype=torch.float16)
>>> # Bump 2**-14 up to 2**-13 and the addition works again.
>>> torch.tensor(2**-3).half() + torch.tensor(2**-13).half()
tensor(0.1251, dtype=torch.float16)

2.2 Why Use FP16
Replacing FP32 with FP16 brings two benefits:
1. Lower memory usage: FP16 takes only half the memory of FP32, which lets us train with a larger batch size (see the quick check below);
2. Faster training: with FP16, model training can be nearly twice as fast.
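The memory half of this claim is easy to verify in a REPL; element_size() reports the bytes per element of a tensor:

>>> torch.randn(1024, 1024).element_size()         # FP32: 4 bytes per element
4
>>> torch.randn(1024, 1024).half().element_size()  # FP16: 2 bytes per element
2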
2.3 Why Pure FP16 Causes Problems
If we simply cast the model weights and inputs from FP32 to FP16, training does roughly double in speed, but the model's accuracy suffers badly. There are two reasons:
Overflow/underflow: FP16's representable range is narrow. Values above 65504 overflow to inf, and values below 2^-24 underflow to 0. Underflow is the more common case: late in training, the model's gradients are often tiny, even below FP16's lower bound, at which point the gradients become 0 and the parameters stop updating. The figure below shows the gradient statistics of an SSD network during training: 67% of the values underflow to 0.

Rounding error: even when a gradient neither overflows nor underflows, an update can still be lost if it is too small relative to the parameter it modifies. Reusing the numbers from the snippet above: suppose a weight is 2^-3 and learning_rate * gradient works out to 2^-14; in FP16, weight + learning_rate * gradient == weight, and the update silently disappears.
2.4 The Solutions
Loss scaling
To combat underflow, the paper scales the computed loss value; by the chain rule, the scaling propagates to every gradient. The scaled gradients shift up into FP16's representable range and can therefore be stored in FP16 without underflowing. Before the weight update, the scaled gradients must first be converted to FP32 and then unscaled (scaled back down).
Note that the conversion to FP32 must come first; otherwise the unscale step would just underflow again.
The scale factor (loss_scale) is usually chosen automatically by the framework: as long as no inf or nan appears, the larger loss_scale is, the better. As training progresses, the network's gradients keep shrinking, and a larger loss_scale makes fuller use of FP16's representable range.
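To make the mechanics concrete, here is a minimal manual loss-scaling step on a toy FP32 model. The model, learning rate, and the fixed scale 2**10 are all invented for illustration; real AMP implementations choose and adjust loss_scale dynamically, as just described:

import torch

loss_scale = 2.0 ** 10                       # hypothetical fixed scale factor
model = torch.nn.Linear(4, 1)                # toy model
x = torch.randn(8, 4)
loss = model(x).pow(2).mean()
(loss * loss_scale).backward()               # gradients come out scaled

lr = 1e-3
with torch.no_grad():
    for p in model.parameters():
        grad = p.grad.float() / loss_scale   # convert to FP32 first, then unscale
        p -= lr * grad                       # plain SGD update
        p.grad = None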
FP32 weight backup
To train in FP16, we cast the model weights and input data to FP16, so back-propagation produces FP16 gradients. If we updated the weights directly at that point, the update term gradient * learning_rate would often be so small relative to the weight that rounding error would swallow it.
The solution: store the model weights, activations, gradients, and other data in FP16, while also maintaining an FP32 copy of the model weights used only for the update. After back-propagation yields FP16 gradients, convert them to FP32 and unscale them, then update the FP32 copy of the weights. Because the whole update runs in FP32, no rounding error occurs.
FP32 weight backup solves the rounding-error problem of the weight update; a minimal sketch follows.
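The sketch below assumes a CUDA GPU, a toy model, and an invented learning rate and fixed loss_scale; a real implementation such as apex folds all of this into the optimizer:

import torch

model = torch.nn.Linear(4, 1).half().cuda()  # FP16 model
master = [p.detach().clone().float() for p in model.parameters()]  # FP32 copy

loss_scale, lr = 2.0 ** 10, 1e-3
x = torch.randn(8, 4, device="cuda").half()
loss = model(x).pow(2).mean()
(loss * loss_scale).backward()               # FP16 scaled gradients

with torch.no_grad():
    for p, m in zip(model.parameters(), master):
        grad32 = p.grad.float() / loss_scale # FP16 -> FP32, then unscale
        m -= lr * grad32                     # update the FP32 master copy
        p.copy_(m.half())                    # write back to the FP16 model
        p.grad = None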

The blacklist
Modules that are numerically unstable in FP16 go on a blacklist and are forced to run in FP32. For example, a BN layer, which computes a mean over the batch, must run in FP32, or rounding error creeps in. Some functions that demand high arithmetic precision, such as torch.acos(), should also run in FP32. The blacklist in the paper contains only the BN layer.
How do we guarantee that a blacklisted module runs in FP32? Taking the BN layer as an example: keep its weights in FP32 and cast its input from FP16 to FP32; this ensures the entire module runs in FP32.
The blacklist solves the problem of functions whose arithmetic is unstable in FP16.
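PyTorch's native autocast implements exactly this kind of per-operator policy, and it is easy to observe on a GPU. According to the autocast op reference [24], mm autocasts to FP16 while acos is kept in FP32:

import torch

a = torch.randn(4, 4, device="cuda")
with torch.cuda.amp.autocast():
    print(torch.mm(a, a).dtype)   # torch.float16: matmul is FP16-friendly
    print(torch.acos(a).dtype)    # torch.float32: acos needs full precision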
Tensor Cores

A Tensor Core multiplies FP16 matrices and accumulates the result into an FP32 matrix. You get the high speed of FP16 matrix multiplication while relying on FP32 accumulation to suppress rounding error.
I still don't understand how Tensor Cores are applied within AMP. Some say Tensor Cores let us update the FP32 model weights from FP16 gradients, but after reading the apex source code I found that FP16 gradients are first converted to FP32 and only then applied, so the weight update has nothing to do with Tensor Cores. I'll revisit this once I've figured it out.
2.5 Some Thoughts
Mixing FP16 and FP32 is really the inevitable outcome, for several reasons:
1. Late in training, gradient values are tiny and may underflow in FP16. Without FP32, even if scaling temporarily sidesteps the problem, the unscale step before the weight update would underflow the gradients all over again;
2. Building on point 1: even when a gradient is representable in FP16, gradient * learning_rate may underflow, so the weight update still has to run in FP32;
3. Building on point 2: even when gradient * learning_rate does not underflow, its value is tiny relative to the weight itself, so the operation weight + gradient * learning_rate can hit rounding error;
4. Building on point 3: even when weight + gradient * learning_rate avoids rounding error, some operators, such as BN and torch.acos, are simply unstable in FP16.

3. Reading the NVIDIA apex Source Code
First, apex offers several opt-levels: O0, O1, O2, O3. Note that these are spelled with the letter "O", not the digit "0".

Image source: "The Most Complete Guide on the Web: Principles of Mixed Precision Training", https://zhuanlan.zhihu.com/p/441591808
O0 is pure FP32 and serves as the accuracy baseline; O3 is pure FP16 and serves as the speed baseline.
The interesting levels are O1 and O2. The AMP strategy described earlier is actually O2: except for the BN layers, whose weights and inputs stay in FP32, all model weights and inputs are cast to FP16, and an additional FP32 weight copy is maintained for the update.
O1 is different: it needs no FP32 weight backup, because the O1 model stays in FP32 throughout. Some readers may wonder: if the model parameters are FP32, how does FP16 enter the picture during training? The answer is that O1 sets up a black-and-white list of PyTorch functions. A whitelisted function is forced to FP16: its arguments are cast to FP16 before the function itself executes. Blacklisted functions are forced to FP32 in the same way.
Take nn.Linear as an example. The module has two weight parameters, weight and bias, and its forward pass on an input calls torch.nn.functional.linear(input, weight, bias). In O1 mode, input, weight, and bias are first cast to FP16 as input_fp16, weight_fp16, bias_fp16, and then torch.nn.functional.linear(input_fp16, weight_fp16, bias_fp16) is called. The model parameters remain FP32, yet FP16 still accelerates training; a toy sketch of this patching idea follows.
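This is not apex's actual code, just a simplified monkey-patch in the same spirit; it needs a GPU to execute the FP16 matmul:

import torch
import torch.nn.functional as F

_orig_linear = F.linear

def fp16_linear(input, weight, bias=None):
    # Cast the arguments of a whitelisted function to FP16 before calling it.
    cast = lambda t: t.half() if t is not None else None
    return _orig_linear(cast(input), cast(weight), cast(bias))

F.linear = fp16_linear  # patch the function, conceptually what apex arranges

x = torch.randn(2, 4, device="cuda")   # FP32 input
w = torch.randn(3, 4, device="cuda")   # FP32 weight
print(F.linear(x, w).dtype)            # torch.float16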
One more O1 detail: although the whitelisted PyTorch functions run in FP16, the resulting gradients are FP32, so there is no need to manually convert them to FP32 before unscaling; they can be unscaled directly.
My guess is that PyTorch keeps each Tensor's gradient dtype consistent with the tensor itself: the computation produces FP16 gradients, but because the weights are FP32, the framework converts the gradients to FP32 as well.
If O1 is FP16 + FP32, then the more aggressive O2 is "almost FP16". Generally speaking, O1 is more stable than O2: the usual advice is to try O1 first, then O2, and stick with O2 only if accuracy does not drop.
3.1 apex's O1 Implementation
1. Wrap PyTorch's built-in functions according to the black-and-white list [4]. Whitelisted functions are forced to FP16, blacklisted functions to FP32. Other functions choose automatically from their argument types: if all arguments are FP16, the function runs in FP16; if any argument is FP32, it runs in FP32.
2. Initialize loss_scale to a large value [5].
3. For each iteration:
(a) Forward pass: the model weights are FP32; operator precision is chosen automatically per the black-and-white list.
(b) Multiply loss by loss_scale [6].
(c) Backward pass: because the model weights are FP32, the gradients are FP32 even for ops that ran as FP16 functions.
(d) Unscale the gradients [7], i.e. divide them by loss_scale.
(e) If inf or nan is detected [8]:
i. loss_scale /= 2 [9];
ii. skip this update [10].
(f) optimizer.step(): apply the update.
(g) If 2000 consecutive iterations pass without inf or nan, then loss_scale *= 2 [11].
A compact sketch of this dynamic scaling loop follows.
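This is a simplified sketch of the idea, not apex's actual class (apex's LossScaler tracks more state):

import torch

class DynamicLossScaler:
    def __init__(self, init_scale=2.0 ** 16, growth_interval=2000):
        self.scale = init_scale
        self.good_steps = 0
        self.growth_interval = growth_interval

    def step(self, optimizer):
        grads = [p.grad for group in optimizer.param_groups
                 for p in group["params"] if p.grad is not None]
        for g in grads:
            g.div_(self.scale)                    # (d) unscale in place
        if any(not torch.isfinite(g).all() for g in grads):
            self.scale /= 2                       # (e)i. halve the scale
            self.good_steps = 0
            optimizer.zero_grad()                 # (e)ii. skip this update
            return
        optimizer.step()                          # (f) apply the update
        self.good_steps += 1
        if self.good_steps % self.growth_interval == 0:
            self.scale *= 2                       # (g) 2000 clean steps: double

In the training loop, one would call (loss * scaler.scale).backward() for step (b) and then scaler.step(optimizer).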
3.2 apex's O2 Implementation
1. Convert all model weights except the BN layers to FP16 [12], and wrap the forward function [13] so that its arguments are cast to FP16 as well (see the sketch after this list);
2. Maintain an FP32 copy of the model weights for the update [14];
3. Initialize loss_scale to a large value [15];
4. For each iteration:
(a) Forward pass: the BN layers run in FP32; the rest of the model runs in FP16.
(b) Multiply loss by loss_scale [16].
(c) Backward pass, producing FP16 gradients.
(d) Convert the FP16 gradients to FP32 and unscale them [17].
(e) If inf or nan is detected [18]:
i. loss_scale /= 2 [19];
ii. skip this update [20].
(f) optimizer.step(): apply the update.
(g) If 2000 consecutive iterations pass without inf or nan, then loss_scale *= 2 [21].
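To make step 1 concrete, here is a simplified sketch of the "everything but BN to FP16" conversion. It is an invented toy version, not apex's code; apex additionally wraps forward() so the inputs are cast too:

import torch

def network_to_half(model):
    for m in model.modules():
        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
            continue                              # BN layers stay in FP32
        for p in m.parameters(recurse=False):
            p.data = p.data.half()                # cast weights to FP16
    return model

net = network_to_half(torch.nn.Sequential(
    torch.nn.Linear(4, 4), torch.nn.BatchNorm1d(4)))
print([p.dtype for p in net.parameters()])
# [torch.float16, torch.float16, torch.float32, torch.float32]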
In addition, I recommend reading MMCV's O2-style AMP implementation [22]; the code is clearer than apex's. Since this article set out to explain O1 and O2 in apex, I chose not to walk through the MMCV code, but interested readers can study it further.

References

[1] https://arxiv.org/abs/1710.03740
[2] https://github.com/NVIDIA/apex
[3] https://en.wikipedia.org/wiki/Round-off_error#Addition
[4] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/amp.py#L68
[5] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/scaler.py#L40
[6] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/handle.py#L113
[7] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/_process_optimizer.py#L123
[8] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/scaler.py#L202
[9] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/scaler.py#L207
[10] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/handle.py#L128
[11] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/scaler.py#L213
[12] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/_initialize.py#L179
[13] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/_initialize.py#L194
[14] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/_process_optimizer.py#L44
[15] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/scaler.py#L40
[16] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/handle.py#L113
[17] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/scaler.py#L94
[18] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/scaler.py#L202
[19] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/scaler.py#L207
[20] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/handle.py#L128
[21] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/scaler.py#L213
[22] https://github.com/open-mmlab/mmcv/blob/f5425ab7611ab2376ddb478b57cb2f46f6054e13/mmcv/runner/hooks/optimizer.py#L344
[23] https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html
[24] https://pytorch.org/docs/stable/amp.html#autocast-op-reference
[26] https://zhuanlan.zhihu.com/p/103685761
[27] https://zhuanlan.zhihu.com/p/441591808