当前位置:网站首页>【PyTorch预训练模型修改、增删特定层】
【PyTorch预训练模型修改、增删特定层】
2022-07-05 11:34:00 【网络星空(luoc)】
一、绪论
在构建深度学习网络的过程中,经常会遇到需要对预训练模型进行修改和增删特定层的操作。
torchvision.models提供了丰富的模型满足不同任务的选择,因此在构建网络结构时,无需从头开始复现某个网络结构,只需在官方库的基础上进行修改即可。
二、官方模型库
pytorch提供的模型可以通过以下链接查询:https://pytorch.org/vision/stable/models.html,分为分类、分割、目标检测实例分割与关键点检测和视频分类4个分类,可按需寻找需要的模型。
下面以分类任务为例,使用到的是resnet。torchvision.models提供了resnet18,resnet34,resnet50,resnet101,resnet152。右侧两列分别是它们在ImageNet上的top1 Accuracy和top5 Accuracy。

这里以resnet50为例。函数说明如下:

import torchvision.models as models
def Net(nn.Module):
def __init__(self, input_ch, num_class,pretrained=True):
super(Net,self).__init__()
self.model = models.resnet50(pretrained=pretrained)
def forward(self,x):
x = self.model(x)
return x
这样,我们就定义了一个Net,这个Net是一个使用了预训练权重的resnet50.
三、修改特定层
使用过程中,我们可能经常会遇到的一个问题是,输入的通道数和网络首层通道数不一致的问题。这里就需要对首层conv进行修改。如果我们初始初始化了一个conv层,又想使用预训练的权重,这时候怎么办呢?我们可以通过以下方式来实现。
resnet50的conv1权重维度为[64,3,7,7],以为着输入图像需为3通道。假设我们要输入的图像为灰度图,那么conv1的输入通道数就应该修改为1。
将原先的 nn.Conv2d(3, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False),替换为 nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)。
def Net(nn.Module):
def __init__(self, input_ch, num_class,pretrained=True):
super(Net,self).__init__()
self.model = models.resnet50(pretrained=pretrained)
conv1 = nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False) #新的conv1层
self.model.conv1 = conv1 #替换原来的conv1
def forward(self,x):
x = self.model(x)
return x
按照上方的操作,则conv1的预训练权重无法被利用。为了能够利用到conv1的预训练权重,我们沿着dim=1取平局,拓展平均后的权重至与新conv1权重维度一致。
def Net(nn.Module):
def __init__(self, input_ch, num_class,pretrained=True):
super(Net,self).__init__()
self.model = models.resnet50(pretrained=pretrained)
conv1_weight = torch.mean(self.model.conv1.weight,dim=1,keepdim=True).repeat(1,input_ch,1,1)#取出从conv1权重并进行平均和拓展
conv1 = nn.Conv2d(input_ch, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False) #新的conv1层
model_dict = self.model.state_dict()#获取整个网络的预训练权重
self.model.conv1 = conv1 #替换原来的conv1
model_dict['conv1.weight'] = conv1_weight #将conv1权重替换为新conv1权重
model_dict.update(model_dict)#更新整个网络的预训练权重
self.model.load_state_dict(model_dict)#载入新预训练权重
def forward(self,x):
x = self.model(x)
return x
四、增删特定层
我们还经常遇到需要对网络结构的最后几层进行删改的问题。还是以resnet50为例。假设要完成一个多标签的分类任务,要增加classifier。
import torchvision.models as models
class classifer(nn.Module):
def __init__(self,in_ch,num_classes):
super(classification_head,self).__init__()
self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
self.fc = nn.Linear(in_ch,num_classes)
def forward(self, x):
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
# import pdb;pdb.set_trace()
return x
class Net(nn.Module):
def __init__(self, input_ch, num_class,pretrained=True):
super(Net,self).__init__()
model = models.resnet50(pretrained=pretrained)
self.backbone = nn.Sequential(*list(model.children())[:-3])#把最后的layer4,Avgpool和Fully Connected Layer去除
self.classification_head1 = nn.Sequential(*list(model.children())[-3],
classifier(2048,3))
self.classification_head2 = nn.Sequential(*list(model.children())[-3],
classifier(2048,5))
def forward(self,x):
x = self.backbone(x)
output1 = self.classification_head1(x)
output2 = self.classification_head2(x)
return [output1,putput2]
将layer4也从backbone中分离出来归属到两个classifer是为了避免两个分类任务的相互干扰,仅保留较低层级、共通性高的网络部分进行特征提取,较高层级的网络则对二者分别进行。
边栏推荐
- [crawler] Charles unknown error
- [leetcode] wild card matching
- 一次生产环境redis内存占用居高不下问题排查
- FFmpeg调用avformat_open_input时返回错误 -22(Invalid argument)
- redis 集群模式原理
- How to understand super browser? What scenarios can it be used in? What brands are there?
- What about SSL certificate errors? Solutions to common SSL certificate errors in browsers
- [SWT component] content scrolledcomposite
- Solve the problem of slow access to foreign public static resources
- Technology sharing | common interface protocol analysis
猜你喜欢

【无标题】

pytorch-权重衰退(weight decay)和丢弃法(dropout)

网络五连鞭

How to make your products as expensive as possible

龙蜥社区第九次运营委员会会议顺利召开

comsol--三维图形随便画----回转

Redis集群(主从)脑裂及解决方案

Harbor image warehouse construction

Differences between IPv6 and IPv4 three departments including the office of network information technology promote IPv6 scale deployment

redis主从中的Master自动选举之Sentinel哨兵机制
随机推荐
查看多台机器所有进程
CDGA|数据治理不得不坚持的六个原则
How to understand super browser? What scenarios can it be used in? What brands are there?
Web API配置自定义路由
一次生产环境redis内存占用居高不下问题排查
【使用TensorRT通过ONNX部署Pytorch项目】
Implementation of array hash function in PHP
Redis集群的重定向
13. (map data) conversion between Baidu coordinate (bd09), national survey of China coordinate (Mars coordinate, gcj02), and WGS84 coordinate system
石油化工企业安全生产智能化管控系统平台建设思考和建议
Error assembling WAR: webxml attribute is required (or pre-existing WEB-INF/web.xml if executing in
Guys, I tested three threads to write to three MySQL tables at the same time. Each thread writes 100000 pieces of data respectively, using F
阻止浏览器后退操作
Mysql统计技巧:ON DUPLICATE KEY UPDATE用法
项目总结笔记系列 wsTax KT Session2 代码分析
PHP中Array的hash函数实现
Acid transaction theory
Project summary notes series wstax kt session2 code analysis
基于Lucene3.5.0怎样从TokenStream获得Token
Evolution of multi-objective sorting model for classified tab commodity flow