当前位置:网站首页>【过一下 17】pytorch 改写 keras
【过一下 17】pytorch 改写 keras
2022-08-05 05:12:00 【墨苏玩电脑】
任务
pytorch自建模型转keras
想法
- 有pytorch的自建模型代码
有自建模型导出的onnx,作为中间件可以onnx转keras(自动转- 搞不出来
查到的博客
首先,我们必须有清楚的认识,网上以及github上一些所谓的pytorch转换Keras或者Keras转换成Pytorch的工具代码几乎不能运行或者有使用的局限性(比如仅仅能转换某一些模型),但是我们是可以用这些转换代码中看出一些端倪来,比如二者的参数的尺寸(shape)的形式、channel的排序(first or last)是否一样,掌握到差异性,就能根据这些差异自己编写转换代码,没错,自己编写转换代码,是最稳妥的办法。整个过程也就分为两个部分。笔者将会以Nvidia开源的FlowNet为例,将开源的Pytorch代码转化为Keras模型。
按照Pytorch中模型的结构,编写对应的Keras代码,用keras的函数式API,构建起来会非常方便。
把Pytorch的模型参数,按照层的名称依次赋值给Keras的模型
以上两步虽然看上去简单,但实际我也走了不少弯路。这里一个关键的地方,就是参数的shape在两个框架中是否统一,那当然是不统一的。查到的博客
pytorch 到 tensorflow 可以用onnx作为中间工具转换,将pytorch转为onnx,再从onnx转为tensorflow,但是中间可能出现一些乱七八糟的问题。其实手动读参数再填充的对应的模型中也很方便,本文就总结一下手动模型转换。
过程
我肯定是先尝试 自动转
pytorch模型转keras模型_AI算法-图哥的博客-CSDN博客_pytorch2keras
ONNX系列二 — 使用ONNX使Keras模型可移植_寒冰屋的博客-CSDN博客_keras onnx
安装一下
接触到了(load和load_state_dict)的问题
Pytorch保存和加载模型(load和load_state_dict)_木盏的博客-CSDN博客_pytorch 加载模型
插曲
安装whl
可以通过查看python版本和安装对应python版本的离线安装包(附.whl安装网页链接)_Az_z的博客-CSDN博客_离线安装指定版本python
真是服了,onnx1.8.1之后移除了了optimizer那个包,然后pytorch2keras这个包要用的就是版本就是1.8.1及之前的,onnx1.8.1支持py3.8,然后我Py3.9。所以只得在另一个虚拟环境(py3.6)里面重新安装了相应的包(400M的tensorflow还有torch等等)。安好了,以为现在可以用了
然后报 modulelist的错
然后我导出成onnx,再load,就报“TypeError: ‘ModelProto’ object is not callable”的错
我真是服了。
看来只能手工写一下了
python关于onnx模型的一些基本操作_一杯盐水的博客-CSDN博客_onnx 静态量化
学习keras,主要是 两边参数不一样
改倒是好改。(不问对错,就是搭积木)
结果
input_data = keras.layers.Input(shape=(4,1), dtype='float64')
cnn_out = keras.layers.Conv1D(1, kernel_size=2, strides=1)(input_data)
cnn_out = keras.layers.MaxPool1D(2)(cnn_out)
lstm_out =keras.layers.LSTM(4)(cnn_out)
dense_out = keras.layers.Dense(3,activation='softmax')(lstm_out)
model = keras.Model(input_data, dense_out)
model.compile(optimizer='adam',loss='mean_absolute_error',metrics=['accuracy'])
model.summary()
改写成功了,但是准确率什么的,很垃
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-OSjRcl88-1659412492791)(C:\Users\Administrator\AppData\Roaming\Typora\typora-user-images\image-20220802115449998.png)]](/img/a2/7f0c7eebd119373bf20c44de9f7947.png)
边栏推荐
- [cesium] 3D Tileset model is loaded and associated with the model tree
- server disk array
- Flex layout frog game clearance strategy
- 逆向理论知识4
- 【学生毕业设计】基于web学生信息管理系统网站的设计与实现(13个页面)
- Flutter TapGestureRecognizer 如何工作
- [Decoding tools] Some online tools for Bitcoin
- 软件管理rpm
- Error creating bean with name ‘configDataContextRefresher‘ defined in class path resource
- Geek卸载工具
猜你喜欢
随机推荐
[Surveying] Quick Summary - Excerpt from Gaoshu Gang
The underlying mechanism of the class
mutillidae download and installation
使用二维码解决固定资产管理的难题
coppercam primer [6]
Redis - 13、开发规范
2022 Hangzhou Electric Multi-School 1st Session 01
Flutter TapGestureRecognizer 如何工作
Qt produces 18 frames of Cupid to express his love, is it your Cupid!!!
Dephi逆向工具Dede导出函数名MAP导入到IDA中
【cesium】Load and locate 3D Tileset
C#关于set()和get()方法的理解及使用
8.04 Day35-----MVC三层架构
雷克萨斯lm的安全性到底体现在哪里?一起来看看吧
UVA10827
phone call function
2022杭电多校第一场01
span标签和p标签的区别
jvm three heap and stack
淘宝账号如何快速提升到更高等级








![coppercam primer [6]](/img/d3/a7d44aa19acfb18c5a8cacdc8176e9.png)
