当前位置:网站首页>PyTorch量化感知训练(QAT)步骤
PyTorch量化感知训练(QAT)步骤
2022-06-30 21:30:00 【小风_】
# QAT follows the same steps as PTQ, with the exception of the training loop before you actually convert the model to its quantized version
# QAT遵循与PTQ相同的步骤,除了在实际将模型转换为量化版本之前进行训练循环
''''''
'''量化感知训练步骤: step1.搭建模型 step2.融合(可选步骤) step3.插入stubs(1和3可合在一起) step4.准备(主要是选择架构) step5.训练 step6.模型转换 '''
import torch
from torch import nn
backend = "fbgemm" # running on a x86 CPU. Use "qnnpack" if running on ARM.
'''step1.搭建模型build model'''
m = nn.Sequential(
nn.Conv2d(2,64,8),
nn.ReLU(),
nn.Conv2d(64, 128, 8),
nn.ReLU(),
)
"""step2.融合Fuse(可选步骤)"""
torch.quantization.fuse_modules(m, ['0','1'], inplace=True) # fuse first Conv-ReLU pair
torch.quantization.fuse_modules(m, ['2','3'], inplace=True) # fuse second Conv-ReLU pair
"""step3.插入stubs于模型,Insert stubs"""
m = nn.Sequential(torch.quantization.QuantStub(),
*m,
torch.quantization.DeQuantStub())
"""step4.准备Prepare"""
m.train()
m.qconfig = torch.quantization.get_default_qconfig(backend)
torch.quantization.prepare_qat(m, inplace=True)
"""step5.训练Training Loop"""
n_epochs = 10
opt = torch.optim.SGD(m.parameters(), lr=0.1)
loss_fn = lambda out, tgt: torch.pow(tgt-out, 2).mean()
for epoch in range(n_epochs):
x = torch.rand(10,2,24,24)
out = m(x)
loss = loss_fn(out, torch.rand_like(out))
opt.zero_grad()
loss.backward()
opt.step()
print(loss)
"""step6.模型转换Convert"""
m.eval()
torch.quantization.convert(m, inplace=True)
边栏推荐
- 1-18 创建最基本的express服务器&创建路由的API模块
- Side sleep ha ha ha
- Adobe Photoshop (PS) - script development - remove file bloated script
- Analysis and proposal on the "sour Fox" vulnerability attack weapon platform of the US National Security Agency
- How to move forward when facing confusion in scientific research? How to give full play to women's advantages in scientific research?
- Export El table as is to excel table
- 1-10 respond to client content according to different URLs
- ArcGIS construction and release of simple road network data service and rest call test
- Sqlserver gets the data of numbers, Chinese and characters in the string
- SQL server extracts pure numbers from strings
猜你喜欢

布隆过滤器

Why have the intelligent investment advisory products collectively taken off the shelves of banks become "chicken ribs"?
Understand what MySQL index push down (ICP) is in one article

jenkins下载插件下载不了,解决办法

银行集体下架的智能投顾产品,为何成了“鸡肋”?

开源实习经验分享:openEuler软件包加固测试

Text recognition svtr paper interpretation

科研中遇到迷茫困惑如何向前一步?如何在科研中发挥女性优势?

文本生成模型退化怎麼辦?SimCTG 告訴你答案

Five years after graduation, I wondered if I would still be so anxious if I hadn't taken the test
随机推荐
双立体柱状图/双y轴
银行集体下架的智能投顾产品,为何成了“鸡肋”?
How to run jenkins build, in multiple servers with ssh-key
激发新动能 多地发力数字经济
Who are you and I
Clickhouse distributed table engine
Double solid histogram / double y-axis
Adobe Photoshop (PS) - script development - remove file bloated script
Phoenix architecture: an architect's perspective
1-10 respond to client content according to different URLs
To the Sultanate of Anderson
Two skylines
Four Misunderstandings of Internet Marketing
《ClickHouse原理解析与应用实践》读书笔记(2)
The 16th Heilongjiang Provincial Collegiate Programming Contest
1-2 安装并配置MySQL相关的软件
Test medal 1234
Understand what MySQL index push down (ICP) is in one article
.netcore redis GEO类型
文本生成模型退化怎麼辦?SimCTG 告訴你答案