当前位置:网站首页>【PyTorch预训练模型修改、增删特定层】
【PyTorch预训练模型修改、增删特定层】
2022-07-05 11:34:00 【网络星空(luoc)】
一、绪论
在构建深度学习网络的过程中,经常会遇到需要对预训练模型进行修改和增删特定层的操作。
torchvision.models提供了丰富的模型满足不同任务的选择,因此在构建网络结构时,无需从头开始复现某个网络结构,只需在官方库的基础上进行修改即可。
二、官方模型库
pytorch提供的模型可以通过以下链接查询:https://pytorch.org/vision/stable/models.html,分为分类、分割、目标检测实例分割与关键点检测和视频分类4个分类,可按需寻找需要的模型。
下面以分类任务为例,使用到的是resnet。torchvision.models提供了resnet18,resnet34,resnet50,resnet101,resnet152。右侧两列分别是它们在ImageNet上的top1 Accuracy和top5 Accuracy。
这里以resnet50为例。函数说明如下:
import torchvision.models as models
def Net(nn.Module):
def __init__(self, input_ch, num_class,pretrained=True):
super(Net,self).__init__()
self.model = models.resnet50(pretrained=pretrained)
def forward(self,x):
x = self.model(x)
return x
这样,我们就定义了一个Net,这个Net是一个使用了预训练权重的resnet50.
三、修改特定层
使用过程中,我们可能经常会遇到的一个问题是,输入的通道数和网络首层通道数不一致的问题。这里就需要对首层conv进行修改。如果我们初始初始化了一个conv层,又想使用预训练的权重,这时候怎么办呢?我们可以通过以下方式来实现。
resnet50的conv1权重维度为[64,3,7,7],以为着输入图像需为3通道。假设我们要输入的图像为灰度图,那么conv1的输入通道数就应该修改为1。
将原先的 nn.Conv2d(3, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False),替换为 nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)。
def Net(nn.Module):
def __init__(self, input_ch, num_class,pretrained=True):
super(Net,self).__init__()
self.model = models.resnet50(pretrained=pretrained)
conv1 = nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False) #新的conv1层
self.model.conv1 = conv1 #替换原来的conv1
def forward(self,x):
x = self.model(x)
return x
按照上方的操作,则conv1的预训练权重无法被利用。为了能够利用到conv1的预训练权重,我们沿着dim=1取平局,拓展平均后的权重至与新conv1权重维度一致。
def Net(nn.Module):
def __init__(self, input_ch, num_class,pretrained=True):
super(Net,self).__init__()
self.model = models.resnet50(pretrained=pretrained)
conv1_weight = torch.mean(self.model.conv1.weight,dim=1,keepdim=True).repeat(1,input_ch,1,1)#取出从conv1权重并进行平均和拓展
conv1 = nn.Conv2d(input_ch, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False) #新的conv1层
model_dict = self.model.state_dict()#获取整个网络的预训练权重
self.model.conv1 = conv1 #替换原来的conv1
model_dict['conv1.weight'] = conv1_weight #将conv1权重替换为新conv1权重
model_dict.update(model_dict)#更新整个网络的预训练权重
self.model.load_state_dict(model_dict)#载入新预训练权重
def forward(self,x):
x = self.model(x)
return x
四、增删特定层
我们还经常遇到需要对网络结构的最后几层进行删改的问题。还是以resnet50为例。假设要完成一个多标签的分类任务,要增加classifier。
import torchvision.models as models
class classifer(nn.Module):
def __init__(self,in_ch,num_classes):
super(classification_head,self).__init__()
self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
self.fc = nn.Linear(in_ch,num_classes)
def forward(self, x):
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
# import pdb;pdb.set_trace()
return x
class Net(nn.Module):
def __init__(self, input_ch, num_class,pretrained=True):
super(Net,self).__init__()
model = models.resnet50(pretrained=pretrained)
self.backbone = nn.Sequential(*list(model.children())[:-3])#把最后的layer4,Avgpool和Fully Connected Layer去除
self.classification_head1 = nn.Sequential(*list(model.children())[-3],
classifier(2048,3))
self.classification_head2 = nn.Sequential(*list(model.children())[-3],
classifier(2048,5))
def forward(self,x):
x = self.backbone(x)
output1 = self.classification_head1(x)
output2 = self.classification_head2(x)
return [output1,putput2]
将layer4也从backbone中分离出来归属到两个classifer是为了避免两个分类任务的相互干扰,仅保留较低层级、共通性高的网络部分进行特征提取,较高层级的网络则对二者分别进行。
边栏推荐
- pytorch-多层感知机MLP
- Summary of thread and thread synchronization under window
- 如何通俗理解超级浏览器?可以用于哪些场景?有哪些品牌?
- pytorch-线性回归
- 网络五连鞭
- 紫光展锐全球首个5G R17 IoT NTN卫星物联网上星实测完成
- Solve readobjectstart: expect {or N, but found n, error found in 1 byte of
- Crawler (9) - scrape framework (1) | scrape asynchronous web crawler framework
- Project summary notes series wstax kt session2 code analysis
- [LeetCode] Wildcard Matching 外卡匹配
猜你喜欢
【爬虫】charles unknown错误
Is it difficult to apply for a job after graduation? "Hundreds of days and tens of millions" online recruitment activities to solve your problems
Advanced technology management - what is the physical, mental and mental strength of managers
分类TAB商品流多目标排序模型的演进
[office] eight usages of if function in Excel
12.(地图数据篇)cesium城市建筑物贴图
Evolution of multi-objective sorting model for classified tab commodity flow
网络五连鞭
redis 集群模式原理
技术管理进阶——什么是管理者之体力、脑力、心力
随机推荐
POJ 3176 cow bowling (DP | memory search)
go语言学习笔记-初识Go语言
ibatis的动态sql
谜语1
Home office things community essay
1个插件搞定网页中的广告
Open3D 欧式聚类
Harbor image warehouse construction
【Win11 多用户同时登录远程桌面配置方法】
Redis集群(主从)脑裂及解决方案
【云原生 | Kubernetes篇】Ingress案例实战(十三)
What about SSL certificate errors? Solutions to common SSL certificate errors in browsers
13. (map data) conversion between Baidu coordinate (bd09), national survey of China coordinate (Mars coordinate, gcj02), and WGS84 coordinate system
解决访问国外公共静态资源速度慢的问题
Guys, I tested three threads to write to three MySQL tables at the same time. Each thread writes 100000 pieces of data respectively, using F
Error assembling WAR: webxml attribute is required (or pre-existing WEB-INF/web.xml if executing in
Question and answer 45: application of performance probe monitoring principle node JS probe
Empêcher le navigateur de reculer
COMSOL--建立几何模型---二维图形的建立
pytorch-多层感知机MLP