PyTorch Quantization in Practice (2)
2022-06-30 21:43:00 【Breeze_】
Translated from: https://pytorch.org/blog/quantization-in-practice/
Quantization is a cheap and easy way to make deep neural network models run faster and with lower memory requirements. PyTorch offers several different approaches to quantize a model. In this blog post we will (quickly) lay the groundwork for quantization in deep learning, then look at how each technique works in practice. Finally, we will close with recommendations from the literature for using quantization in your workflows.
Quantization Methods
PyTorch allows several different approaches to quantize a model, depending on:
- whether you prefer a flexible but manual process or a restricted automated one (Eager Mode vs. FX Graph Mode)
- whether the qparams for quantizing activations (layer outputs) are precomputed for all inputs, or recomputed for each input (static vs. dynamic)
- whether the qparams are computed with or without retraining (quantization-aware training vs. post-training quantization)
FX Graph Mode automatically fuses eligible modules, inserts Quant/DeQuant stubs, calibrates the model and returns a quantized module, all in two method calls, but it only works on networks that are symbolically traceable. The examples below include both the Eager Mode and the FX Graph Mode calls for comparison.
In DNNs, the candidates eligible for quantization are the FP32 weights (layer parameters) and activations (layer outputs). Quantizing the weights reduces the model size. Quantizing the activations generally leads to faster inference. For example, the 50-layer ResNet network has about 26 million weight parameters and computes about 16 million activations in the forward pass.
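As a quick illustration of that weight count, here is a minimal sketch (not part of the original post, assuming torchvision is installed); counting activations would additionally require hooking each layer's outputs, so it is omitted.
import torch
from torchvision.models import resnet50

# Count the weight parameters of a ResNet-50; no pretrained weights are needed just to count them.
model = resnet50()
n_params = sum(p.numel() for p in model.parameters())
print(f"ResNet-50 parameters: {n_params / 1e6:.1f}M")  # roughly 25.6M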
Post-Training Dynamic/Weight-only Quantization
Here the model's weights are pre-quantized. The activations are quantized on the fly ("dynamically") during inference. This is the simplest of all the approaches and takes a single API call to torch.quantization.quantize_dynamic. Currently only Linear and recurrent (LSTM, GRU, RNN) layers are supported for dynamic quantization.
Advantages:
- Can result in higher accuracy, since the clipping range is calibrated exactly for each input
- Dynamic quantization is preferred for models like LSTMs and Transformers, where writing/retrieving the model's weights from memory dominates bandwidth
Disadvantages:
- Calibrating and quantizing the activations of each layer at runtime adds computational overhead.
import torch
from torch import nn

# toy model
m = nn.Sequential(
    nn.Conv2d(2, 64, (8,)),
    nn.ReLU(),
    nn.Linear(16, 10),
    nn.LSTM(10, 10))

m.eval()

## EAGER MODE
from torch.quantization import quantize_dynamic
model_quantized = quantize_dynamic(
    model=m, qconfig_spec={nn.LSTM, nn.Linear}, dtype=torch.qint8, inplace=False
)

## FX MODE
from torch.quantization import quantize_fx
qconfig_dict = {"": torch.quantization.default_dynamic_qconfig}  # An empty key denotes the default applied to all modules
model_prepared = quantize_fx.prepare_fx(m, qconfig_dict)
model_quantized = quantize_fx.convert_fx(model_prepared)
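To see the weight-only memory savings, one quick check (a sketch, not part of the original post; it reuses m and model_quantized from the snippet above) is to serialize both state dicts and compare file sizes. Only the Linear and LSTM weights shrink here, so the reduction is less than the ideal 4x.
import os

def model_size_mb(model, path="tmp_state_dict.pt"):
    # Serialize the state_dict to disk and report its size in MB (hypothetical helper for illustration).
    torch.save(model.state_dict(), path)
    size = os.path.getsize(path) / 1e6
    os.remove(path)
    return size

print(f"FP32 model:         {model_size_mb(m):.2f} MB")
print(f"Dynamic INT8 model: {model_size_mb(model_quantized):.2f} MB")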
Post-Training Static Quantization (PTQ)
PTQ also pre-quantizes the model weights, but instead of calibrating the activations on the fly, the clipping range is pre-calibrated and fixed ("static") using validation data. During inference, activations stay in quantized precision between operations. About 100 mini-batches of representative data are sufficient to calibrate the observers [2]. For convenience, the examples below use random data during calibration; doing this in your application will result in bad qparams.
[Figure: PTQ flowchart (https://pytorch.org/assets/images/quantization-practice/ptq-flowchart.svg)]
Module fusion combines multiple sequential modules (e.g. [Conv2d, BatchNorm, ReLU]) into one. Fusing modules means the compiler needs to run only one kernel instead of many; this speeds things up and improves accuracy by reducing quantization error.
Advantages:
- Static quantization has faster inference than dynamic quantization because it eliminates the float<->int conversion overhead between layers.
Disadvantages:
- Statically quantized models may need regular re-calibration to stay robust against distribution drift.
# Static quantization of a model consists of the following steps:
# Fuse modules
# Insert Quant/DeQuant Stubs
# Prepare the fused module (insert observers before and after layers)
# Calibrate the prepared module (pass it representative data)
# Convert the calibrated module (replace with quantized version)
import torch
from torch import nn

backend = "fbgemm"  # running on a x86 CPU. Use "qnnpack" if running on ARM.

m = nn.Sequential(
    nn.Conv2d(2, 64, 3),
    nn.ReLU(),
    nn.Conv2d(64, 128, 3),
    nn.ReLU()
)
m.eval()  # fusion and post-training quantization expect the model in eval mode

## EAGER MODE
"""Fuse - Inplace fusion replaces the first module in the sequence with the fused module, and the rest with identity modules"""
torch.quantization.fuse_modules(m, ['0', '1'], inplace=True)  # fuse first Conv-ReLU pair
torch.quantization.fuse_modules(m, ['2', '3'], inplace=True)  # fuse second Conv-ReLU pair

"""Insert stubs"""
m = nn.Sequential(torch.quantization.QuantStub(),
                  *m,
                  torch.quantization.DeQuantStub())

"""Prepare"""
m.qconfig = torch.quantization.get_default_qconfig(backend)
torch.quantization.prepare(m, inplace=True)

"""Calibrate - This example uses random data for convenience. Use representative (validation) data instead."""
with torch.no_grad():
    for _ in range(10):
        x = torch.rand(1, 2, 28, 28)
        m(x)

"""Convert"""
torch.quantization.convert(m, inplace=True)

"""Check"""
print(m[1].weight().element_size())  # 1 byte instead of 4 bytes for FP32
## FX GRAPH
from torch.quantization import quantize_fx

model_to_quantize = nn.Sequential(
    nn.Conv2d(2, 64, 3),
    nn.ReLU(),
    nn.Conv2d(64, 128, 3),
    nn.ReLU()
)
model_to_quantize.eval()

qconfig_dict = {"": torch.quantization.get_default_qconfig(backend)}

# Prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)

# Calibrate - Use representative (validation) data.
with torch.no_grad():
    for _ in range(10):
        x = torch.rand(1, 2, 28, 28)
        model_prepared(x)

# Quantize
model_quantized = quantize_fx.convert_fx(model_prepared)
print(model_quantized)
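As a follow-up sanity check (a rough sketch, not part of the original post; it reuses model_to_quantize and model_quantized from the snippet above), you can compare the FP32 and INT8 outputs and time both models. On a toy model this small the speedup may be negligible, but the measurement pattern carries over to real models.
import time

x = torch.rand(1, 2, 28, 28)
with torch.no_grad():
    y_fp32 = model_to_quantize(x)
    y_int8 = model_quantized(x)
print("mean abs difference:", (y_fp32 - y_int8).abs().mean().item())

def avg_latency(fn, n=100):
    # Average wall-clock time per forward pass over n runs.
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(n):
            fn(x)
    return (time.perf_counter() - start) / n

print(f"FP32: {avg_latency(model_to_quantize) * 1e3:.2f} ms/iter")
print(f"INT8: {avg_latency(model_quantized) * 1e3:.2f} ms/iter")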
Quantization-aware Training (QAT)
[Figure: QAT flowchart (https://pytorch.org/assets/images/quantization-practice/qat-flowchart.svg)]
The PTQ approaches work well for large models, but accuracy suffers in smaller models. This is, of course, due to the loss of numerical precision when adapting a model trained in FP32 to the INT8 domain (figure (a) in the original post). QAT tackles this by including the quantization error in the training loss, thereby training an INT8-first model.

All weights and biases are stored in FP32, and backpropagation happens as usual. However, in the forward pass, quantization is internally simulated via FakeQuantize modules. They are called fake because they quantize and immediately dequantize the data, adding quantization noise similar to what might be encountered during quantized inference. The final loss thus accounts for the expected quantization errors. Optimizing on this basis allows the model to identify a wider region in the loss function (figure (b) in the original post) and to find FP32 parameters that can be quantized to INT8 without a significant drop in accuracy.
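To make the "fake" part concrete, here is a minimal sketch (not from the original post) using torch.fake_quantize_per_tensor_affine with hand-picked qparams; inside a QAT model the FakeQuantize modules derive scale and zero_point from observers instead.
import torch

x = torch.randn(4)

# Arguments: input, scale, zero_point, quant_min, quant_max (qparams chosen by hand for illustration).
x_fq = torch.fake_quantize_per_tensor_affine(x, 0.1, 0, -128, 127)

print(x)         # original FP32 values
print(x_fq)      # still FP32, but rounded onto the INT8 grid (quantize -> dequantize)
print(x - x_fq)  # the simulated quantization noise that flows into the training loss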
Advantages:
- QAT yields higher accuracy than PTQ.
- Qparams can be learned during model training for more fine-grained accuracy (see LearnableFakeQuantize).
Disadvantages:
- The computational cost of retraining a model with QAT can be several hundred epochs.
# QAT follows the same steps as PTQ, with the exception of the training loop before you actually convert the model to its quantized version
import torch
from torch import nn

backend = "fbgemm"  # running on a x86 CPU. Use "qnnpack" if running on ARM.

m = nn.Sequential(
    nn.Conv2d(2, 64, 8),
    nn.ReLU(),
    nn.Conv2d(64, 128, 8),
    nn.ReLU()
)

"""Fuse"""
m.eval()  # fuse in eval mode
torch.quantization.fuse_modules(m, ['0', '1'], inplace=True)  # fuse first Conv-ReLU pair
torch.quantization.fuse_modules(m, ['2', '3'], inplace=True)  # fuse second Conv-ReLU pair

"""Insert stubs"""
m = nn.Sequential(torch.quantization.QuantStub(),
                  *m,
                  torch.quantization.DeQuantStub())

"""Prepare"""
m.train()
m.qconfig = torch.quantization.get_default_qat_qconfig(backend)  # QAT qconfig with FakeQuantize modules
torch.quantization.prepare_qat(m, inplace=True)

"""Training Loop"""
n_epochs = 10
opt = torch.optim.SGD(m.parameters(), lr=0.1)
loss_fn = lambda out, tgt: torch.pow(tgt - out, 2).mean()
for epoch in range(n_epochs):
    x = torch.rand(10, 2, 24, 24)
    out = m(x)
    loss = loss_fn(out, torch.rand_like(out))
    opt.zero_grad()
    loss.backward()
    opt.step()
print(loss)

"""Convert"""
m.eval()
torch.quantization.convert(m, inplace=True)
Sensitivity analysis
Not all layers respond to quantization the same way; some are more sensitive to the drop in precision than others. Identifying the optimal combination of layers that minimizes the accuracy drop is time-consuming, so [3] suggest running a one-at-a-time sensitivity analysis to identify which layers are the most sensitive and retaining FP32 precision on those. In their experiments, skipping just 2 conv layers (out of 28 in MobileNet v1) gave them near-FP32 accuracy. Using FX Graph Mode, we can easily create custom qconfigs for this:
# ONE-AT-A-TIME SENSITIVITY ANALYSIS
for quantized_layer, _ in model.named_modules():
    print("Only quantizing layer: ", quantized_layer)

    # The module_name key allows module-specific qconfigs.
    qconfig_dict = {
        "": None,
        "module_name": [(quantized_layer, torch.quantization.get_default_qconfig(backend))]}

    model_prepared = quantize_fx.prepare_fx(model, qconfig_dict)

    # calibrate
    model_quantized = quantize_fx.convert_fx(model_prepared)

    # evaluate(model)
Another approach is to compare the statistics of the FP32 and INT8 layers; commonly used metrics are SQNR (signal-to-quantization-noise ratio) and mean squared error. Such a comparative analysis also helps guide further optimization.

PyTorch provides tools to help with this analysis under the Numeric Suite. Learn more about using the Numeric Suite from the full tutorial.
# extract from https://pytorch.org/tutorials/prototype/numeric_suite_tutorial.html
import torch.quantization._numeric_suite as ns

def SQNR(x, y):
    # Higher is better
    Ps = torch.norm(x)
    Pn = torch.norm(x - y)
    return 20 * torch.log10(Ps / Pn)

# fp32_model / int8_model: your float model and its quantized counterpart; input_data: sample inputs
wt_compare_dict = ns.compare_weights(fp32_model.state_dict(), int8_model.state_dict())
for key in wt_compare_dict:
    print(key, SQNR(wt_compare_dict[key]['float'], wt_compare_dict[key]['quantized'].dequantize()))

act_compare_dict = ns.compare_model_outputs(fp32_model, int8_model, input_data)
for key in act_compare_dict:
    print(key, SQNR(act_compare_dict[key]['float'][0], act_compare_dict[key]['quantized'][0].dequantize()))
Recommended Quantization Workflow
[Figure: suggested quantization workflow (https://pytorch.org/assets/images/quantization-practice/quantization-flowchart2.png)]
Key points:
- Large (10M+ parameter) models are more robust to quantization error.
- Quantizing a model from an FP32 checkpoint provides better accuracy than training an INT8 model from scratch.
- Profiling the model runtime is optional, but it can help identify layers that bottleneck inference.
- Dynamic quantization is an easy first step, especially if your model has many Linear or recurrent layers.
- Use symmetric per-channel quantization with a MinMax observer for quantizing weights. Use affine per-tensor quantization with a MovingAverageMinMax observer for quantizing activations (see the sketch after this list).
- Use a metric like SQNR to identify which layers are most susceptible to quantization error, and turn off quantization on those layers.
- Use QAT to fine-tune for around 10% of the original training schedule, with an annealing learning-rate schedule starting at 1% of the initial training learning rate.
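As a sketch of the observer recommendation above (the exact arguments here are illustrative choices, not prescribed by the original post), a custom QConfig could be assembled like this and used wherever get_default_qconfig appears in the earlier examples:
import torch
from torch.quantization import QConfig, MovingAverageMinMaxObserver, PerChannelMinMaxObserver

qconfig = QConfig(
    # Activations: affine per-tensor quantization with a moving-average min/max observer.
    activation=MovingAverageMinMaxObserver.with_args(
        dtype=torch.quint8, qscheme=torch.per_tensor_affine),
    # Weights: symmetric per-channel quantization with a min/max observer.
    weight=PerChannelMinMaxObserver.with_args(
        dtype=torch.qint8, qscheme=torch.per_channel_symmetric),
)

# e.g. in FX Graph Mode:
# qconfig_dict = {"": qconfig}
# model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)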