当前位置:网站首页>pytorch转onnx相关问题
pytorch转onnx相关问题
2022-07-27 05:13:00 【Mr_health】
这些问题是在转谱归一化spectral_norm中遇到的。
首先遇到的就是torch.mv算子和torch.dot算子不支持的问题。
目前pytorch已经官方实现了谱归一化:spectral_norm,其中包含了torch.mv、 torch.dot算子,转onnx会出现错误
解决办法:将torch.mv和torch.dot用torch.matmul代替,不过可能需要自己改变一下tensor的维度。(通过unsqueeze之类的)
我再解决了上述两个算子后,能够跑torch.onnx.export函数,但是转换推断的时候会出现:
RuntimeError: invalid argument 0: Tensors must have same number of dimensions: got 2 and 1貌似是维度出现了问题,但是我找了几个小时都没有找到问题所在。
后来解决的办法是,在转onnx之前,除去spectral_norm。
具体参考了:https://github.com/pytorch/pytorch/issues/27723
官方已经实现了如何移除spectral_norm的函数:
def remove_spectral_norm(module, name='weight'):
r"""Removes the spectral normalization reparameterization from a module.
Args:
module (Module): containing module
name (str, optional): name of weight parameter
Example:
>>> m = spectral_norm(nn.Linear(40, 10))
>>> remove_spectral_norm(m)
"""
for k, hook in module._forward_pre_hooks.items():
if isinstance(hook, SpectralNorm) and hook.name == name:
hook.remove(module)
del module._forward_pre_hooks[k]
break
else:
raise ValueError("spectral_norm of '{}' not found in {}".format(
name, module))
for k, hook in module._state_dict_hooks.items():
if isinstance(hook, SpectralNormStateDictHook) and hook.fn.name == name:
del module._state_dict_hooks[k]
break
for k, hook in module._load_state_dict_pre_hooks.items():
if isinstance(hook, SpectralNormLoadStateDictPreHook) and hook.fn.name == name:
del module._load_state_dict_pre_hooks[k]
break
return module具体的步骤是:
1.按照训练的时候构建模型model(此时还是含有spectral_norm),并且装载pretrained model,这个pretrained model中含有spectral_norm的相关参数:weight_orig、weight_u以及weight_v。
2.之后利用以下函数,这个函数的输入是构建的model,完成的是递归模型的结构,当遇见spectral_norm时,会调用上面的remove_spectral_norm移除spectral_norm。
def remove_all_spectral_norm(item):
if isinstance(item, nn.Module):
try:
remove_spectral_norm(item)
except Exception:
pass
for child in item.children():
remove_all_spectral_norm(child)
if isinstance(item, nn.ModuleList):
for module in item:
remove_all_spectral_norm(module)
if isinstance(item, nn.Sequential):
modules = item.children()
for module in modules:
remove_all_spectral_norm(module)3.最后跑torch.onnx.export
这里稍微解释一下,普通的卷积操作后面增加spectral_norm后,训练的参数会从卷积的weight会变为weight_orig、weight_u、weight_v这三类,也就是在保存模型的时候保存的都是这些参数。
通过上述的移除操作,会从weight_orig、weight_u、weight_v恢复出weight。
边栏推荐
猜你喜欢

leetcode系列(一):买卖股票

Day 7. Towards Preemptive Detection of Depression and Anxiety in Twitter

Day 7. Towards Preemptive Detection of Depression and Anxiety in Twitter

【并发编程系列9】阻塞队列之PriorityBlockingQueue,DelayQueue原理分析

数字图像处理——第三章 灰度变换与空间滤波

Day 6. Analysis of the energy transmission process of network public opinion in major medical injury events * -- Taking the "Wei Zexi incident" as an example

数字图像处理——第六章 彩色图像处理

7.合并与分割

Jenkins build image automatic deployment

DSGAN退化网络
随机推荐
Rk3288 board HDMI displays logo images of uboot and kernel
MySQL索引优化相关原理
Day 17.The role of news sentiment in oil futures returns and volatility forecasting
golang控制goroutine数量以及获取处理结果
vscode打造golang开发环境以及golang的debug单元测试
2.简单回归问题
Gbase 8C - SQL reference 6 SQL syntax (6)
【好文种草】根域名的知识 - 阮一峰的网络日志
视觉横向课题bug1:FileNotFoundError: Could not find module ‘MvCameraControl.dll‘ (or one of it
【mysql学习】8
【高并发】面试官
GBASE 8C——SQL参考6 sql语法(5)
Uboot中支持lcd和hdmi显示不同的logo图片
rk3399 gpio口 如何查找是哪个gpio口
Gbase 8C - SQL reference 6 SQL syntax (3)
GBASE 8C——SQL参考6 sql语法(3)
Global evidence of expressed sentimental alterations during the covid-19 pandemics
8.数学运算与属性统计
【MVC架构】MVC模型
数字图像处理第四章——频率域滤波