当前位置:网站首页>如何替换模型的骨干网络(backbone)

如何替换模型的骨干网络(backbone)

2022-07-06 20:08:00 马少爷

1、替换骨干网络的原则
为什么一些模型能够把其内部的模块进行替换?理由很简单,就是把里面的模块看作一个黑盒子,黑盒子有着输入和输出,那么作为网络中的一个部分,前面有着输入过来,后面也需要输出到其他模块,要想替换该模块而让模型能够运行起来,关键是输入输出的匹配问题,下面就以我自己实验过模型作为例子。
我要替换的骨干网络是3D目标检测的一个方法叫Group-Free 3d,它使用的骨干网络是PointNet++,就是下面的图中用红色框框出来的部分。
在这里插入图片描述
而我想用一个transformer的骨干网络Pointformer替换掉PointNet++。首先这里我说一下为什么我想用Pointformer来替换PointNet++,因为在Pointformer这个论文中,它说Pointformer这个骨干网络可以替换PointNet++来获得更好的性能,我观察到Group-Free 3d中使用的是PointNet++作为骨干网络,但是没有人将其中的PointNet++换为Pointformer,所以我想把Pointformer替换PointNet++看能否提高Group-Free 3d的性能。**所以,在这里强调一下,我所替换的骨干网络是别人的方法所使用过的,但是在新的模型中没有实践过的。**但是我觉得原理都是一样的,就是把输入和输出匹配对应上就可以了。

2、查看网络参数的设置
首先根据Pointformer论文的介绍,Pointforemer是不用经过修改就可以直接替换PointNet++的,但是不保证有人用PointNet++的时候会进行一些输入输出层数的修改,所以这里就需要查看输入和输出层的网络设置。就比如我这次的实验就中用到的Group-Free 3d中PointNet++的最后一层输出大小为288,而一般的都是输出256,如下图所示:
在这里插入图片描述
因此要更改输出256为288。如果说自己使用的代码中没有像这么规范的形式把输入和输出都集中在一个文件上的,可以直接从骨干网络的代码里面找第一层网络的数据输入要求和最后一层的网络的输出数据格式进行修改。
比如:
输入:
在这里插入图片描述
这里的真正的开始对数据进行处理的是下面的那个红色的框,这个时候就可以根据self.sa1()这个函数所涉及的输入以及其网络的参数进行查看了。

在这里插入图片描述

输出:
在这里插入图片描述
骨干网络的最后的一层网络是self.fp2()这个函数,同样的可以去找到它的网络参数设计:
在这里插入图片描述
修改完之后就可以替换骨干网络了:

3、查看输入输出是否匹配
修改完设置之后就可以进行骨干网络的替换了,替换后先查看替换后的骨干网络(Pointfomer)在新的模型中的输入输出是否与原始的骨干网络(PointNet++)的输入输出对应。首先查看原始方法中的骨干网络的输入输出是多少,这时就要用到debug模式了,如下图,PointNet++的输入为:

在这里插入图片描述
注意是要在forward函数里面查看输入的大小,比如这里PointNet++的输入大小为torch.Size([6, 20000, 4]),格式为(batch_size, 点云数目,点云的向量长度),记住这个输入的大小,替换为Pointformer的时候也要让Pointformer能够接受这个大小的数据。
接下来查看PointNet++的输出,直接拉到forward函数的返回语句下面,然后同样的设置断点打印输出包含的东西以及大小。

在这里插入图片描述
抓住主要输出,主要关注在骨干网络输出后下一模块需要哪些输出,比如这里有三个输出集合在一个字典里面,再看看骨干网络之后的模块的输入也是需要这三个输出,如下图所示:
在这里插入图片描述
这里的骨干网络的输出为:

torch.Size([6, 288, 1024])
torch.Size([6, 1024, 3])
torch.Size([6, 1024])

同样的记住这三个数据的大小,在替换为Pointformer之后查看Pointformer的输出是不是与这些数据的大小相匹配。

总结
首先要记住的原则是,替换的骨干网络和原始的骨干网络有着相同大小的数据输入输出。然后是查看骨干网络第一层的数据输入大小和骨干网络最后一层的输出大小,对参数进行修改,修改完之后参数后进行骨干网络的替换,然后查看网络的输入输出是否与原始骨干网络的输入输出相匹配。中间涉及到很多细节,每个人遇到的问题都不一样,本篇文章旨在原理和一些经验的说明,不可能详细到每个细节,总之一句话就是记录输入输出,然后进行修改,多debug就行了。
至于替换后的效果怎么样,老实说,这个是个玄学,涉及到很多东西,也许是你的学习率不够好,连替换之前的方法的精度都不如。也有可能是你加进去的模块和后面的模块有着冲突导致性能的下降,等等等。

参考文献:https://blog.csdn.net/weixin_44715117/article/details/125322327

原网站

版权声明
本文为[马少爷]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_27353621/article/details/125635355