当前位置:网站首页>【PyTorch预训练模型修改、增删特定层】
【PyTorch预训练模型修改、增删特定层】
2022-07-05 11:34:00 【网络星空(luoc)】
一、绪论
在构建深度学习网络的过程中,经常会遇到需要对预训练模型进行修改和增删特定层的操作。
torchvision.models提供了丰富的模型满足不同任务的选择,因此在构建网络结构时,无需从头开始复现某个网络结构,只需在官方库的基础上进行修改即可。
二、官方模型库
pytorch提供的模型可以通过以下链接查询:https://pytorch.org/vision/stable/models.html,分为分类、分割、目标检测实例分割与关键点检测和视频分类4个分类,可按需寻找需要的模型。
下面以分类任务为例,使用到的是resnet。torchvision.models提供了resnet18,resnet34,resnet50,resnet101,resnet152。右侧两列分别是它们在ImageNet上的top1 Accuracy和top5 Accuracy。

这里以resnet50为例。函数说明如下:

import torchvision.models as models
def Net(nn.Module):
def __init__(self, input_ch, num_class,pretrained=True):
super(Net,self).__init__()
self.model = models.resnet50(pretrained=pretrained)
def forward(self,x):
x = self.model(x)
return x
这样,我们就定义了一个Net,这个Net是一个使用了预训练权重的resnet50.
三、修改特定层
使用过程中,我们可能经常会遇到的一个问题是,输入的通道数和网络首层通道数不一致的问题。这里就需要对首层conv进行修改。如果我们初始初始化了一个conv层,又想使用预训练的权重,这时候怎么办呢?我们可以通过以下方式来实现。
resnet50的conv1权重维度为[64,3,7,7],以为着输入图像需为3通道。假设我们要输入的图像为灰度图,那么conv1的输入通道数就应该修改为1。
将原先的 nn.Conv2d(3, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False),替换为 nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)。
def Net(nn.Module):
def __init__(self, input_ch, num_class,pretrained=True):
super(Net,self).__init__()
self.model = models.resnet50(pretrained=pretrained)
conv1 = nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False) #新的conv1层
self.model.conv1 = conv1 #替换原来的conv1
def forward(self,x):
x = self.model(x)
return x
按照上方的操作,则conv1的预训练权重无法被利用。为了能够利用到conv1的预训练权重,我们沿着dim=1取平局,拓展平均后的权重至与新conv1权重维度一致。
def Net(nn.Module):
def __init__(self, input_ch, num_class,pretrained=True):
super(Net,self).__init__()
self.model = models.resnet50(pretrained=pretrained)
conv1_weight = torch.mean(self.model.conv1.weight,dim=1,keepdim=True).repeat(1,input_ch,1,1)#取出从conv1权重并进行平均和拓展
conv1 = nn.Conv2d(input_ch, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False) #新的conv1层
model_dict = self.model.state_dict()#获取整个网络的预训练权重
self.model.conv1 = conv1 #替换原来的conv1
model_dict['conv1.weight'] = conv1_weight #将conv1权重替换为新conv1权重
model_dict.update(model_dict)#更新整个网络的预训练权重
self.model.load_state_dict(model_dict)#载入新预训练权重
def forward(self,x):
x = self.model(x)
return x
四、增删特定层
我们还经常遇到需要对网络结构的最后几层进行删改的问题。还是以resnet50为例。假设要完成一个多标签的分类任务,要增加classifier。
import torchvision.models as models
class classifer(nn.Module):
def __init__(self,in_ch,num_classes):
super(classification_head,self).__init__()
self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
self.fc = nn.Linear(in_ch,num_classes)
def forward(self, x):
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
# import pdb;pdb.set_trace()
return x
class Net(nn.Module):
def __init__(self, input_ch, num_class,pretrained=True):
super(Net,self).__init__()
model = models.resnet50(pretrained=pretrained)
self.backbone = nn.Sequential(*list(model.children())[:-3])#把最后的layer4,Avgpool和Fully Connected Layer去除
self.classification_head1 = nn.Sequential(*list(model.children())[-3],
classifier(2048,3))
self.classification_head2 = nn.Sequential(*list(model.children())[-3],
classifier(2048,5))
def forward(self,x):
x = self.backbone(x)
output1 = self.classification_head1(x)
output2 = self.classification_head2(x)
return [output1,putput2]
将layer4也从backbone中分离出来归属到两个classifer是为了避免两个分类任务的相互干扰,仅保留较低层级、共通性高的网络部分进行特征提取,较高层级的网络则对二者分别进行。
边栏推荐
- Open3D 网格(曲面)赋色
- Implementation of array hash function in PHP
- 分类TAB商品流多目标排序模型的演进
- idea设置打开文件窗口个数
- An error is reported in the process of using gbase 8C database: 80000305, host IPS long to different cluster. How to solve it?
- Go language learning notes - analyze the first program
- Project summary notes series wstax kt session2 code analysis
- Shell script file traversal STR to array string splicing
- 【使用TensorRT通过ONNX部署Pytorch项目】
- 7 themes and 9 technology masters! Dragon Dragon lecture hall hard core live broadcast preview in July, see you tomorrow
猜你喜欢

12.(地图数据篇)cesium城市建筑物贴图

Redis集群的重定向

How can China Africa diamond accessory stones be inlaid to be safe and beautiful?

谜语1

如何让你的产品越贵越好卖

11. (map data section) how to download and use OSM data

Oneforall installation and use

7.2 daily study 4

7 themes and 9 technology masters! Dragon Dragon lecture hall hard core live broadcast preview in July, see you tomorrow
![[crawler] bugs encountered by wasm](/img/29/6782bda4c149b7b2b334238936e211.png)
[crawler] bugs encountered by wasm
随机推荐
SLAM 01. Modeling of human recognition Environment & path
查看多台机器所有进程
POJ 3176-Cow Bowling(DP||记忆化搜索)
【上采样方式-OpenCV插值】
Ziguang zhanrui's first 5g R17 IOT NTN satellite in the world has been measured on the Internet of things
Crawler (9) - scrape framework (1) | scrape asynchronous web crawler framework
7 themes and 9 technology masters! Dragon Dragon lecture hall hard core live broadcast preview in July, see you tomorrow
Open3D 欧式聚类
项目总结笔记系列 wsTax KT Session2 代码分析
redis集群中hash tag 使用
Pytorch training process was interrupted
2048游戏逻辑
一次生产环境redis内存占用居高不下问题排查
Ffmpeg calls avformat_ open_ Error -22 returned during input (invalid argument)
shell脚本文件遍历 str转数组 字符串拼接
汉诺塔问题思路的证明
居家办公那些事|社区征文
c#操作xml文件
Manage multiple instagram accounts and share anti Association tips
高校毕业求职难?“百日千万”网络招聘活动解决你的难题