当前位置:网站首页>【PyTorch预训练模型修改、增删特定层】
【PyTorch预训练模型修改、增删特定层】
2022-07-05 11:34:00 【网络星空(luoc)】
一、绪论
在构建深度学习网络的过程中,经常会遇到需要对预训练模型进行修改和增删特定层的操作。
torchvision.models提供了丰富的模型满足不同任务的选择,因此在构建网络结构时,无需从头开始复现某个网络结构,只需在官方库的基础上进行修改即可。
二、官方模型库
pytorch提供的模型可以通过以下链接查询:https://pytorch.org/vision/stable/models.html,分为分类、分割、目标检测实例分割与关键点检测和视频分类4个分类,可按需寻找需要的模型。
下面以分类任务为例,使用到的是resnet。torchvision.models提供了resnet18,resnet34,resnet50,resnet101,resnet152。右侧两列分别是它们在ImageNet上的top1 Accuracy和top5 Accuracy。
这里以resnet50为例。函数说明如下:
import torchvision.models as models
def Net(nn.Module):
def __init__(self, input_ch, num_class,pretrained=True):
super(Net,self).__init__()
self.model = models.resnet50(pretrained=pretrained)
def forward(self,x):
x = self.model(x)
return x
这样,我们就定义了一个Net,这个Net是一个使用了预训练权重的resnet50.
三、修改特定层
使用过程中,我们可能经常会遇到的一个问题是,输入的通道数和网络首层通道数不一致的问题。这里就需要对首层conv进行修改。如果我们初始初始化了一个conv层,又想使用预训练的权重,这时候怎么办呢?我们可以通过以下方式来实现。
resnet50的conv1权重维度为[64,3,7,7],以为着输入图像需为3通道。假设我们要输入的图像为灰度图,那么conv1的输入通道数就应该修改为1。
将原先的 nn.Conv2d(3, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False),替换为 nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False)。
def Net(nn.Module):
def __init__(self, input_ch, num_class,pretrained=True):
super(Net,self).__init__()
self.model = models.resnet50(pretrained=pretrained)
conv1 = nn.Conv2d(1, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False) #新的conv1层
self.model.conv1 = conv1 #替换原来的conv1
def forward(self,x):
x = self.model(x)
return x
按照上方的操作,则conv1的预训练权重无法被利用。为了能够利用到conv1的预训练权重,我们沿着dim=1取平局,拓展平均后的权重至与新conv1权重维度一致。
def Net(nn.Module):
def __init__(self, input_ch, num_class,pretrained=True):
super(Net,self).__init__()
self.model = models.resnet50(pretrained=pretrained)
conv1_weight = torch.mean(self.model.conv1.weight,dim=1,keepdim=True).repeat(1,input_ch,1,1)#取出从conv1权重并进行平均和拓展
conv1 = nn.Conv2d(input_ch, 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=False) #新的conv1层
model_dict = self.model.state_dict()#获取整个网络的预训练权重
self.model.conv1 = conv1 #替换原来的conv1
model_dict['conv1.weight'] = conv1_weight #将conv1权重替换为新conv1权重
model_dict.update(model_dict)#更新整个网络的预训练权重
self.model.load_state_dict(model_dict)#载入新预训练权重
def forward(self,x):
x = self.model(x)
return x
四、增删特定层
我们还经常遇到需要对网络结构的最后几层进行删改的问题。还是以resnet50为例。假设要完成一个多标签的分类任务,要增加classifier。
import torchvision.models as models
class classifer(nn.Module):
def __init__(self,in_ch,num_classes):
super(classification_head,self).__init__()
self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
self.fc = nn.Linear(in_ch,num_classes)
def forward(self, x):
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
# import pdb;pdb.set_trace()
return x
class Net(nn.Module):
def __init__(self, input_ch, num_class,pretrained=True):
super(Net,self).__init__()
model = models.resnet50(pretrained=pretrained)
self.backbone = nn.Sequential(*list(model.children())[:-3])#把最后的layer4,Avgpool和Fully Connected Layer去除
self.classification_head1 = nn.Sequential(*list(model.children())[-3],
classifier(2048,3))
self.classification_head2 = nn.Sequential(*list(model.children())[-3],
classifier(2048,5))
def forward(self,x):
x = self.backbone(x)
output1 = self.classification_head1(x)
output2 = self.classification_head2(x)
return [output1,putput2]
将layer4也从backbone中分离出来归属到两个classifer是为了避免两个分类任务的相互干扰,仅保留较低层级、共通性高的网络部分进行特征提取,较高层级的网络则对二者分别进行。
边栏推荐
猜你喜欢
Idea set the number of open file windows
无密码身份验证如何保障用户隐私安全?
一次生产环境redis内存占用居高不下问题排查
How did the situation that NFT trading market mainly uses eth standard for trading come into being?
【Win11 多用户同时登录远程桌面配置方法】
NFT 交易市场主要使用 ETH 本位进行交易的局面是如何形成的?
MySQL giant pit: update updates should be judged with caution by affecting the number of rows!!!
Ziguang zhanrui's first 5g R17 IOT NTN satellite in the world has been measured on the Internet of things
liunx禁ping 详解traceroute的不同用法
Pytorch training process was interrupted
随机推荐
Web API configuration custom route
CDGA|数据治理不得不坚持的六个原则
【使用TensorRT通过ONNX部署Pytorch项目】
POJ 3176-Cow Bowling(DP||记忆化搜索)
What about SSL certificate errors? Solutions to common SSL certificate errors in browsers
How does redis implement multiple zones?
comsol--三维图形随便画----回转
c#操作xml文件
[office] eight usages of if function in Excel
liunx禁ping 详解traceroute的不同用法
ACID事务理论
Prevent browser backward operation
TSQL – identity column, guid, sequence
How to get a token from tokenstream based on Lucene 3.5.0
紫光展锐全球首个5G R17 IoT NTN卫星物联网上星实测完成
基于Lucene3.5.0怎样从TokenStream获得Token
FFmpeg调用avformat_open_input时返回错误 -22(Invalid argument)
C # implements WinForm DataGridView control to support overlay data binding
【爬虫】wasm遇到的bug
Manage multiple instagram accounts and share anti Association tips