当前位置:网站首页>torch 网络模型转换onnx格式,并可视化
torch 网络模型转换onnx格式,并可视化
2022-06-12 19:30:00 【佐倉】
1. 构建lenet5 网络
import torch.nn as nn
import torch.nn.functional as F
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class LeNet(nn.Module):
def __init__(self,class_num=10,input_shape=(1,32,32)):
super(LeNet, self).__init__()
self.conv1 = nn.Sequential( #input_size=(1*28*28)
nn.Conv2d(1, 6, 5, 1, 2), #padding=2保证输入输出尺寸相同
nn.ReLU(), #input_size=(6*28*28)
nn.MaxPool2d(kernel_size=2, stride=2), #output_size=(6*14*14)
)
self.conv2 = nn.Sequential(
nn.Conv2d(6, 16, 5), #padding=1输出尺寸变化
nn.ReLU(), #input_size=(16*10*10)
nn.MaxPool2d(2, 2) #output_size=(16*5*5)
)
self.fc1 = nn.Sequential(
nn.Linear(16 * ((input_shape[1]//2-4)//2) * ((input_shape[2]//2-4)//2), 120),
nn.ReLU()
)
self.fc2 = nn.Sequential(
nn.Linear(120, 84),
nn.ReLU()
)
self.fc3 = nn.Linear(84, class_num)
# 定义前向传播过程,输入为x
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
# nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维
x = x.view(x.size()[0], -1)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
2. 转为onnx格式
input_shape = (1,100,100) #输入数据
model = LeNet(input_shape=input_shape)
torch.save(model, './model_para.pth')
torch_model = torch.load("./model_para.pth") # pytorch模型加载
batch_size = 1 #批处理大小
# set the model to inference mode
torch_model.eval()
x = torch.randn(batch_size,*input_shape) # 生成张量
print (x.shape)
export_onnx_file = "lenet5.onnx" # 目的ONNX文件名
torch.onnx.export(torch_model,
x,
export_onnx_file,
opset_version=10,
do_constant_folding=True, # 是否执行常量折叠优化
input_names=["input"], # 输入名
output_names=["output"], # 输出名
dynamic_axes={"input":{0:"batch_size"}, # 批处理变量
"output":{0:"batch_size"}})
3. 通过netron查看网络结构
3.1 netron安装
pip install netron
3.2 netron可视化
import netron
onnx_path = "lenet5.onnx"
netron.start(file=onnx_path, log=False, browse=True)

边栏推荐
- Jenkins中pipeline对接CMDB接口获取主机列表的发布实践原创
- Shell programming regular expressions and metacharacters
- [5gc] Introduction to three SSC (session and service continuity) modes
- EASYCODE one click plug-in custom template
- Méthode de sauvegarde programmée basée sur la base de données distribuée elle - même
- 基于微信电子书阅读小程序毕业设计毕设作品(2)小程序功能
- asp. Net using JSON to interact with API data
- RT thread simulator builds lvgl development and debugging environment
- 基于微信电子书阅读小程序毕业设计毕设作品(1)开发概要
- What are meta-inf and WEB-INF respectively?
猜你喜欢

WinCC7.5 SP1调整画面尺寸以适应显示分辨率的方法

The Bean Validation API is on the classpath but no implementation could be found

存储体系概述
![[digital ic/fpga] data accumulation output](/img/58/8d10e41a7bc837feba677f1e0b1ceb.png)
[digital ic/fpga] data accumulation output

Leetcodesql: count the number of students in each major

基于微信电子书阅读小程序毕业设计毕设作品(4)开题报告

5G R17标准冻结,主要讲了些啥?

Implementation of VGA protocol based on FPGA

【图像去噪】基于正则化实现图像去噪附matlab代码

The component style set by uniapp takes effect in H5 and app, but does not take effect in wechat applet. The problem is solved
随机推荐
【生成对抗网络学习 其三】BiGAN论文阅读笔记及其原理理解
QT -- how to get the contents of selected cells in qtableview
typescript的装饰器(Decorotor)基本使用
Lua record
Module 8 fonctionnement
嵌入式开发:固件工程师的6项必备技能
[5gc] Introduction to three SSC (session and service continuity) modes
什么是数据驱动
Mode of most elements (map, sort, random, Boyer Moore voting method)
Leetcodesql: count the number of students in each major
Uniapp uses the Ali Icon
从16页PPT里看懂Jack Dorsey的Web5
Wincc7.5 SP1 method for adjusting picture size to display resolution
mysql的增删改查,mysql常用命令
vc hacon 聯合編程 GenImage3Extern WriteImage
Wangxuegang room+paging3
RT thread simulator builds lvgl development and debugging environment
VC hacon joint programming genimage3extern writeimage
In 2021, the global revenue of chlorinated polyvinyl chloride (CPVC) was about $1809.9 million, and it is expected to reach $3691.5 million in 2028
Shell arrays and functions