Issues encountered when converting PyTorch models to ONNX
2022-07-27 05:13:00 【Mr_health】
These problems came up while converting a model that uses spectral normalization (spectral_norm) to ONNX.
The first issue is that the torch.mv and torch.dot operators are not supported by the ONNX exporter.
PyTorch now ships an official implementation of spectral normalization, spectral_norm, which uses torch.mv and torch.dot internally, so exporting such a model to ONNX raises an error.
Workaround: replace torch.mv and torch.dot with torch.matmul. You may need to adjust the tensor dimensions yourself (with unsqueeze and the like), as in the sketch below.
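For reference, a minimal sketch of the matmul equivalents. The shapes here are just illustrative placeholders, not the actual spectral_norm tensors:

import torch

W = torch.randn(10, 20)   # placeholder weight matrix
v = torch.randn(20)       # placeholder right singular vector estimate
u = torch.randn(10)       # placeholder left singular vector estimate

# torch.mv(W, v): promote v to a column vector, matmul, then drop the extra dim
mv_equiv = torch.matmul(W, v.unsqueeze(1)).squeeze(1)
assert torch.allclose(mv_equiv, torch.mv(W, v))

# torch.dot(u, w): row vector times column vector, then squeeze back to a scalar
w = torch.mv(W, v)
dot_equiv = torch.matmul(u.unsqueeze(0), w.unsqueeze(1)).squeeze()
assert torch.allclose(dot_equiv, torch.dot(u, w))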
After replacing these two operators I could run torch.onnx.export, but inference with the converted model failed with:
RuntimeError: invalid argument 0: Tensors must have same number of dimensions: got 2 and 1
This looks like a dimension problem, but after several hours of searching I could not find the cause.
The solution in the end was to remove spectral_norm before converting to ONNX.
Reference: https://github.com/pytorch/pytorch/issues/27723
PyTorch already provides a function that removes spectral_norm:
def remove_spectral_norm(module, name='weight'):
    r"""Removes the spectral normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            break
    else:
        raise ValueError("spectral_norm of '{}' not found in {}".format(
            name, module))

    for k, hook in module._state_dict_hooks.items():
        if isinstance(hook, SpectralNormStateDictHook) and hook.fn.name == name:
            del module._state_dict_hooks[k]
            break

    for k, hook in module._load_state_dict_pre_hooks.items():
        if isinstance(hook, SpectralNormLoadStateDictPreHook) and hook.fn.name == name:
            del module._load_state_dict_pre_hooks[k]
            break
    return module

The concrete steps are:
1. Build the model exactly as during training (so it still contains spectral_norm) and load the pretrained model; this pretrained checkpoint contains the spectral_norm-related parameters weight_orig, weight_u and weight_v.
2. Then apply the following function. It takes the constructed model as input and recurses over the model structure; whenever it encounters spectral_norm, it calls the remove_spectral_norm above to strip it.
def remove_all_spectral_norm(item):
    if isinstance(item, nn.Module):
        try:
            remove_spectral_norm(item)
        except Exception:
            pass
        for child in item.children():
            remove_all_spectral_norm(child)

    if isinstance(item, nn.ModuleList):
        for module in item:
            remove_all_spectral_norm(module)

    if isinstance(item, nn.Sequential):
        modules = item.children()
        for module in modules:
            remove_all_spectral_norm(module)

3. Finally, run torch.onnx.export (see the sketch below).
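Putting the three steps together, roughly like this. MyModel, the checkpoint path and the input shape are hypothetical placeholders, not from the original post:

import torch
import torch.nn as nn

# 1. Build the model as in training (still wrapped with spectral_norm) and load
#    the pretrained weights, which contain weight_orig / weight_u / weight_v.
model = MyModel()                                     # hypothetical model class
state_dict = torch.load('pretrained.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()

# 2. Recursively strip spectral_norm so only plain 'weight' parameters remain.
remove_all_spectral_norm(model)

# 3. Export to ONNX with a dummy input of the right shape.
dummy_input = torch.randn(1, 3, 256, 256)             # adjust to your model's input
torch.onnx.export(model, dummy_input, 'model.onnx',
                  input_names=['input'], output_names=['output'],
                  opset_version=11)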
A brief explanation: when spectral_norm is added to an ordinary convolution, the trainable parameter changes from the convolution's weight to the three tensors weight_orig, weight_u and weight_v, so these are what end up in the saved checkpoint.
The removal step above reconstructs weight from weight_orig, weight_u and weight_v.
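A quick sanity check on a toy layer shows the parameter names before and after removal; this is just an illustrative sketch, not part of the original workflow:

import torch.nn as nn
from torch.nn.utils import spectral_norm, remove_spectral_norm

conv = spectral_norm(nn.Conv2d(3, 16, kernel_size=3))
print(sorted(conv.state_dict().keys()))
# -> ['bias', 'weight_orig', 'weight_u', 'weight_v']

remove_spectral_norm(conv)   # recomputes the normalized weight and restores a plain 'weight'
print(sorted(conv.state_dict().keys()))
# -> ['bias', 'weight']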