当前位置:网站首页>【PyTorch预训练模型修改、增删特定层】
【PyTorch预训练模型修改、增删特定层】
2022-07-05 11:34:00 【网络星空(luoc)】
一、绪论
在构建深度学习网络的过程中,经常会遇到需要对预训练模型进行修改和增删特定层的操作。
torchvision.models提供了丰富的模型满足不同任务的选择,因此在构建网络结构时,无需从头开始复现某个网络结构,只需在官方库的基础上进行修改即可。
二、官方模型库
pytorch提供的模型可以通过以下链接查询:https://pytorch.org/vision/stable/models.html,分为分类、分割、目标检测实例分割与关键点检测和视频分类4个分类,可按需寻找需要的模型。
下面以分类任务为例,使用到的是resnet。torchvision.models提供了resnet18,resnet34,resnet50,resnet101,resnet152。右侧两列分别是它们在ImageNet上的top1 Accuracy和top5 Accuracy。

这里以resnet50为例。函数说明如下:

import torchvision.models as models
def Net(nn.Module):
def __init__(self, input_ch, num_class,pretrained=True):
super(Net,self).__init__()
self.model = models.resnet50(pretrained=pretrained)
def forward(self,x):
x = self.model(x)
return x
这样,我们就定义了一个Net,这个Net是一个使用了预训练权重的resnet50.
三、修改特定层
使用过程中,我们可能经常会遇到的一个问题是,输入的通道数和网络首层通道数不一致的问题。这里就需要对首层conv进行修改。如果我们初始初始化了一个conv层,又想使用预训练的权重,这时候怎么办呢?我们可以通过以下方式来实现。
resnet50的conv1权重维度为[64,3,7,7],以为着输入图像需为3通道。假设我们要输入的图像为灰度图,那么conv1的输入通道数就应该修改为1。
将原先的 nn.Conv2d(3, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False),替换为 nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)。
def Net(nn.Module):
def __init__(self, input_ch, num_class,pretrained=True):
super(Net,self).__init__()
self.model = models.resnet50(pretrained=pretrained)
conv1 = nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False) #新的conv1层
self.model.conv1 = conv1 #替换原来的conv1
def forward(self,x):
x = self.model(x)
return x
按照上方的操作,则conv1的预训练权重无法被利用。为了能够利用到conv1的预训练权重,我们沿着dim=1取平局,拓展平均后的权重至与新conv1权重维度一致。
def Net(nn.Module):
def __init__(self, input_ch, num_class,pretrained=True):
super(Net,self).__init__()
self.model = models.resnet50(pretrained=pretrained)
conv1_weight = torch.mean(self.model.conv1.weight,dim=1,keepdim=True).repeat(1,input_ch,1,1)#取出从conv1权重并进行平均和拓展
conv1 = nn.Conv2d(input_ch, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False) #新的conv1层
model_dict = self.model.state_dict()#获取整个网络的预训练权重
self.model.conv1 = conv1 #替换原来的conv1
model_dict['conv1.weight'] = conv1_weight #将conv1权重替换为新conv1权重
model_dict.update(model_dict)#更新整个网络的预训练权重
self.model.load_state_dict(model_dict)#载入新预训练权重
def forward(self,x):
x = self.model(x)
return x
四、增删特定层
我们还经常遇到需要对网络结构的最后几层进行删改的问题。还是以resnet50为例。假设要完成一个多标签的分类任务,要增加classifier。
import torchvision.models as models
class classifer(nn.Module):
def __init__(self,in_ch,num_classes):
super(classification_head,self).__init__()
self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
self.fc = nn.Linear(in_ch,num_classes)
def forward(self, x):
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
# import pdb;pdb.set_trace()
return x
class Net(nn.Module):
def __init__(self, input_ch, num_class,pretrained=True):
super(Net,self).__init__()
model = models.resnet50(pretrained=pretrained)
self.backbone = nn.Sequential(*list(model.children())[:-3])#把最后的layer4,Avgpool和Fully Connected Layer去除
self.classification_head1 = nn.Sequential(*list(model.children())[-3],
classifier(2048,3))
self.classification_head2 = nn.Sequential(*list(model.children())[-3],
classifier(2048,5))
def forward(self,x):
x = self.backbone(x)
output1 = self.classification_head1(x)
output2 = self.classification_head2(x)
return [output1,putput2]
将layer4也从backbone中分离出来归属到两个classifer是为了避免两个分类任务的相互干扰,仅保留较低层级、共通性高的网络部分进行特征提取,较高层级的网络则对二者分别进行。
边栏推荐
- [office] eight usages of if function in Excel
- 如何通俗理解超级浏览器?可以用于哪些场景?有哪些品牌?
- Go language learning notes - analyze the first program
- How can China Africa diamond accessory stones be inlaid to be safe and beautiful?
- Solve readobjectstart: expect {or N, but found n, error found in 1 byte of
- Pytorch training process was interrupted
- I used Kaitian platform to build an urban epidemic prevention policy inquiry system [Kaitian apaas battle]
- 高校毕业求职难?“百日千万”网络招聘活动解决你的难题
- How does redis implement multiple zones?
- Web API配置自定义路由
猜你喜欢

Go language learning notes - analyze the first program

OneForAll安装使用

Evolution of multi-objective sorting model for classified tab commodity flow

11.(地图数据篇)OSM数据如何下载使用

简单解决redis cluster中从节点读取不了数据(error) MOVED

Redis集群(主从)脑裂及解决方案

AUTOCAD——遮罩命令、如何使用CAD对图纸进行局部放大

12.(地图数据篇)cesium城市建筑物贴图

MySQL 巨坑:update 更新慎用影响行数做判断!!!

Harbor image warehouse construction
随机推荐
汉诺塔问题思路的证明
Differences between IPv6 and IPv4 three departments including the office of network information technology promote IPv6 scale deployment
CDGA|数据治理不得不坚持的六个原则
Spark Tuning (I): from HQL to code
Prevent browser backward operation
全网最全的新型数据库、多维表格平台盘点 Notion、FlowUs、Airtable、SeaTable、维格表 Vika、飞书多维表格、黑帕云、织信 Informat、语雀
An error is reported in the process of using gbase 8C database: 80000305, host IPS long to different cluster. How to solve it?
管理多个Instagram帐户防关联小技巧大分享
pytorch-softmax回归
爬虫(9) - Scrapy框架(1) | Scrapy 异步网络爬虫框架
The art of communication III: Listening between people
COMSOL -- 3D casual painting -- sweeping
技术分享 | 常见接口协议解析
Solve the problem of slow access to foreign public static resources
Error assembling WAR: webxml attribute is required (or pre-existing WEB-INF/web.xml if executing in
Ffmpeg calls avformat_ open_ Error -22 returned during input (invalid argument)
How does redis implement multiple zones?
[leetcode] wild card matching
Open3D 欧式聚类
MySQL statistical skills: on duplicate key update usage