[PyTorch: modifying a pretrained model. In practice, loading the pretrained weights made little difference compared with random initialization]
2022-07-05 11:48:00 【Network starry sky (LUOC)】
1. PyTorch pretrained models
Training a convolutional neural network is time-consuming, and in many cases it is impractical to train the network from randomly initialized parameters every time. PyTorch (through torchvision) includes pretrained versions of several commonly used deep learning networks, such as VGG and ResNet. To speed up learning, training often starts from the parameters of a pretrained model. The models are loaded as follows:
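import torchvision.models as models
# Note: torchvision 0.13 and later deprecate pretrained=True in favor of the weights argument.
# resnet: use the factory functions; models.ResNet is the underlying class
# and does not accept a pretrained argument
model = models.resnet18(pretrained=True)
model = models.resnet34(pretrained=True)
model = models.resnet50(pretrained=True)
# vgg: likewise, models.VGG is the underlying class, not a loader
model = models.vgg11(pretrained=True)
model = models.vgg16(pretrained=True)
model = models.vgg16_bn(pretrained=True)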
2. Change the number of classes of the fully connected layer
The pretrained model used as the example is resnet50.
import torch.nn as nn
import torchvision

model = torchvision.models.resnet50(pretrained=True)
# Read the number of input features of the existing fc layer
fc_features = model.fc.in_features
# Redefine the last layer with 10 output classes
model.fc = nn.Linear(fc_features, 10)
print(model.fc)
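Alternatively, pass the number of classes in directly. This only works with pretrained=False, because the pretrained fc weights are shaped for 1000 classes:
model = torchvision.models.resnet50(pretrained=False, num_classes=10)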
3. Replace a single convolution layer
The pretrained model used as the example is resnet50.
model = torchvision.models.resnet50(pretrained=True)
# Redefine the first convolution so that it accepts 4 input channels.
# The new layer is randomly initialized; the pretrained conv1 weights are discarded.
model.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
4. Modify several layers
4.1 Remove the last two layers (the fc layer and the pooling layer)
The pretrained model used as the example is resnet50.
nn.Module provides a method called children(), which iterates over the top-level submodules of a model in order. The network structure can be modified on this basis. For example, to remove the last two layers:
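# pretrained=False builds the architecture with random weights;
# pass pretrained=True to start from the pretrained weights.
resnet_50_s = torchvision.models.resnet50(pretrained=False)
resnet_layer = nn.Sequential(*list(resnet_50_s.children())[:-2])
self.resnet = resnet_layer  # inside the __init__ of your own nn.Module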
After removing the last two layers of the ResNet model (the fc layer and the pooling layer), add a new upsampling layer, a pooling layer, and classification layers. The code that builds the network is as follows:
class Net_resnet50_upsample(nn.Module):
    def __init__(self):
        super(Net_resnet50_upsample, self).__init__()
        # Map a 1-channel input to the 3 channels that ResNet expects
        self.conv = nn.Conv2d(1, 3, kernel_size=1)
        resnet_50_s = torchvision.models.resnet50(pretrained=False)
        # Keep everything except the final avgpool and fc layers
        resnet_layer = nn.Sequential(*list(resnet_50_s.children())[:-2])
        self.resnet = resnet_layer
        # print(self.resnet)
        self.up7to14 = nn.UpsamplingNearest2d(scale_factor=2)
        self.avgpool = nn.AvgPool2d(7, stride=2)
        # 2048 * 4 * 4 assumes a 224x224 input: 7x7 -> 14x14 -> 4x4
        self.fc = nn.Sequential(
            nn.Linear(2048 * 4 * 4, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 10))

    def forward(self, x):
        x = self.conv(x)
        x = self.resnet(x)
        x = self.up7to14(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
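The input width of the first Linear layer, 2048 * 4 * 4, follows from the spatial sizes: a 7x7 feature map is upsampled to 14x14, and a 7x7 average pool with stride 2 and no padding gives floor((14 - 7) / 2) + 1 = 4. The sketch below checks this arithmetic on a dummy tensor; the 7x7 starting size is an assumption that corresponds to a 224x224 network input.

```python
import torch
import torch.nn as nn

up = nn.UpsamplingNearest2d(scale_factor=2)
pool = nn.AvgPool2d(7, stride=2)

# Stand-in for the backbone output of a 224x224 image: 2048 channels, 7x7.
features = torch.randn(1, 2048, 7, 7)

upsampled = up(features)                 # spatial size doubles
pooled = pool(upsampled)                 # floor((14 - 7) / 2) + 1 per side
flat = pooled.view(pooled.size(0), -1)   # flatten for the Linear layers

print(tuple(upsampled.shape), tuple(pooled.shape), flat.shape[1])
```

With a different input resolution the flattened width changes, and the first Linear layer would have to be resized accordingly.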
4.2 Add or remove several convolution layers
Sometimes the layer structure inside the network has to change. In that case the only option is to overlay parameters: first define a similar network of your own, then copy the matching pretrained parameters into it. A ResNet pretrained model is used as the example here.
# coding=UTF-8
import torchvision.models as models
import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
# Bottleneck is a class that defines a residual block using 1x1 convolutions to reduce and then
# restore the channel dimension; see the PyTorch ResNet source on GitHub
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
# Layers that are kept unchanged must keep their original names;
# otherwise the pretrained weights cannot be matched to them
class CNN(nn.Module):
    def __init__(self, block, layers, num_classes=9):
        super(CNN, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        # Add a new transposed convolution (deconvolution) layer
        self.convtranspose1 = nn.ConvTranspose2d(2048, 2048, kernel_size=3, stride=1, padding=1, output_padding=0,
                                                 groups=1, bias=False, dilation=1)
        # Add a max pooling layer
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        # Remove the original fc layer and add a new fclass layer
        self.fclass = nn.Linear(2048, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        # Forward through the newly added layers. x must stay a 4-D feature map
        # (N, 2048, 1, 1) here, so it is not flattened before the transposed convolution.
        x = self.convtranspose1(x)
        x = self.maxpool2(x)
        x = x.view(x.size(0), -1)
        x = self.fclass(x)
        return x
# Load the model. pretrained=True is needed so that the state_dict copied below
# actually holds pretrained weights rather than a random initialization.
resnet50 = models.resnet50(pretrained=True)
print(resnet50)
# 3, 4, 6, 3 are the numbers of Bottleneck blocks in layer1 to layer4; for resnet101 they are 3, 4, 23, 3
cnn = CNN(Bottleneck, [3, 4, 6, 3])
# Read the parameters
pretrained_dict = resnet50.state_dict()
model_dict = cnn.state_dict()
# Drop the keys of pretrained_dict that do not exist in model_dict
pretrained_dict = {
    k: v for k, v in pretrained_dict.items() if k in model_dict}
# Update the existing model_dict
model_dict.update(pretrained_dict)
# Load the state_dict that we actually need
cnn.load_state_dict(model_dict)
# print(resnet50)
print(cnn)
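The overlay steps above can be tried on a toy pair of networks, which also makes it easy to check which entries were really copied. This is a minimal sketch rather than the article's code: the Source and Target classes are hypothetical stand-ins for the pretrained and the modified network, and the extra shape check in the filter is an addition that the original does not have.

```python
import torch
import torch.nn as nn

class Source(nn.Module):
    """Hypothetical stand-in for the pretrained network."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.fc = nn.Linear(8, 1000)

class Target(nn.Module):
    """Hypothetical stand-in for the modified network: same conv1, new head."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.fclass = nn.Linear(8, 9)

source, target = Source(), Target()
source_dict = source.state_dict()
target_dict = target.state_dict()

# Keep only entries whose name and shape both match the target network.
matched = {k: v for k, v in source_dict.items()
           if k in target_dict and v.shape == target_dict[k].shape}
skipped = sorted(set(source_dict) - set(matched))

target_dict.update(matched)
target.load_state_dict(target_dict)

print("loaded: ", sorted(matched))
print("skipped:", skipped)
print("conv1 copied:", torch.equal(target.conv1.weight, source.conv1.weight))
```

Printing the loaded and skipped keys is a cheap sanity check: if the loaded list is empty or the source network was itself randomly initialized, the modified network effectively starts from random weights.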
Comparison of results: [figure]