当前位置:网站首页>pytorch模型转libtorch和onnx格式的通用代码
pytorch模型转libtorch和onnx格式的通用代码
2022-08-02 14:09:00 【虹夭】
依赖
- torch
- onnx
- onnx simplifer
需要自己设置的重要参数
- model_path 模型权重路径
- model 网络实例
- inp 样例输入,就是一个shape合法的tensor,batchsize(第一维)设置为1就行
下面以torchvision自带的resnet101模型为例。权重是使用官方的预训练模型,调用resnet101(pretrained=True)时会自动下载到%USERPROFILE%/.cache/torch/hub下面
import onnx
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile
from torchvision.models.resnet import resnet101
from utils.func import file_size, colorstr
model_path = './weights/resnet101.pth' # 模型权重路径
model = resnet101() # 模型对象
height, width = 640, 640
inp = torch.zeros([1, 3, height, width]) # 样例输入,用于trace
# common
half = True # fp16量化
# onnx profile
onnx_export = True # 是否输出onnx格式
opset_version = 13 # 算子集版本
dynamic = False # 是否动态输入batchsize,需要设置下面两个选项
input_names = ['inputs']
dynamic_axes = {
'inputs': {
0: 'batch', 1: 'kp28'}, # 动态batchsize设置
'output': {
0: 'batch', 1: 'classes'}}
simplify = True # 是否简化
# libtorch profile
libtorch_export = True # 是否输出libtorch格式
optimize = False # 针对移动端优化,不是移动端别用
strict = False # 严格模式,设置False就行
if __name__ == '__main__':
model.load_state_dict(torch.load(model_path))
model.cpu().eval()
if half:
inp, model = inp.half(), model.half()
if onnx_export:
prefix = colorstr('ONNX:')
f = model_path.replace('.pth', '.onnx') # filename
torch.onnx.export(model, inp, f, verbose=False, opset_version=opset_version, input_names=input_names,
training=torch.onnx.TrainingMode.EVAL,
do_constant_folding=True,
dynamic_axes=dynamic_axes if dynamic else None)
# Checks
model_onnx = onnx.load(f) # load onnx model
onnx.checker.check_model(model_onnx) # check onnx model
# print(onnx.helper.printable_graph(model_onnx.graph)) # print
# Simplify
if simplify:
try:
import onnxsim
print(f'simplifying with onnx-simplifier {
onnxsim.__version__}...')
model_onnx, check = onnxsim.simplify(
model_onnx,
dynamic_input_shape=dynamic,
input_shapes={
'images': list(inp.shape)} if dynamic else None)
assert check, 'assert check failed'
onnx.save(model_onnx, f)
except Exception as e:
print(f'{
prefix} simplifier failure: {
e}')
print(f'{
prefix} export success, saved as {
f} ({
file_size(f):.1f} MB)')
if libtorch_export:
prefix = colorstr('TorchScript:')
try:
print(f'\n{
prefix} starting export with torch {
torch.__version__}...')
f = model_path.replace('.pt', '.torchscript.pt') # filename
ts = torch.jit.trace(model, inp, strict=strict)
(optimize_for_mobile(ts) if optimize else ts).save(f)
print(f'{
prefix} export success, saved as {
f} ({
file_size(f):.1f} MB)')
except Exception as e:
print(f'{
prefix} export failure: {
e}')
边栏推荐
猜你喜欢

Word2vec词向量

深度学习之文本分类总结

利用红外-可见光图像数据集OTCBVS打通图像融合、目标检测和目标跟踪

How to set the win10 taskbar does not merge icons

FP5207电池升压 5V9V12V24V36V42V大功率方案

小T成长记-网络篇-1-什么是网络?

Binder ServiceManager解析

Bert系列之 Transformer详解

A clean start Windows 7?How to load only the basic service start Windows 7 system

Win10系统设置application identity自动提示拒绝访问怎么办
随机推荐
PyTorch④---DataLoader的使用
The overlapping effect of the two surfaceviews is similar to the video and handout practice in the live effect
7. How to add the Click to RecyclerView and LongClick events
PyTorch②---transforms结构及用法、常见的Transforms
ASR6601牛羊定位器芯片GPS国内首颗支持LoRa的LPWAN SoC
How to set the win10 taskbar does not merge icons
HAL框架
TypeScript 快速进阶
Pytorch(16)---搭建一个完整的模型
利用红外-可见光图像数据集OTCBVS打通图像融合、目标检测和目标跟踪
图像配置分类及名词解释
FP7195芯片PWM转模拟调光至0.1%低亮度时恒流一致性的控制原理
Win7遇到错误无法正常开机进桌面怎么解决?
DP1332E内置c8051的mcu内核NFC刷卡芯片国产兼容NXP
jest测试,组件测试
PyTorch⑥---卷积神经网络_池化层
语言模型(NNLM)
LLVM系列第三章:函数Function
PyTorch(15)---模型保存和加载
PyTorch②---transforms结构及用法