[PyTorch: modifying pretrained models (measured: loading a pretrained model made little difference compared with random initialization)]
2022-07-05 11:48:00 【Network starry sky (LUOC)】
1. PyTorch pretrained models
Training a convolutional neural network is time-consuming, and in many cases it is impractical to train the network from randomly initialized parameters every time. torchvision ships pretrained weights for several commonly used deep learning networks, such as VGG and ResNet. To speed up training, it is common to load these pretrained parameters directly at the start of training. The models are loaded as follows:
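import torchvision.models as models

# ResNet (use the factory functions; the ResNet class itself has no pretrained argument)
model = models.resnet18(pretrained=True)
model = models.resnet34(pretrained=True)
model = models.resnet50(pretrained=True)

# VGG (likewise, use the factory functions rather than the VGG class)
model = models.vgg11(pretrained=True)
model = models.vgg16(pretrained=True)
model = models.vgg16_bn(pretrained=True)

# torchvision >= 0.13 deprecates pretrained=True in favour of weights=, e.g.:
# model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)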
2. Changing the number of classes in the fully connected layer
ResNet-50 is used as the example pretrained model.
import torch.nn as nn
import torchvision

model = torchvision.models.resnet50(pretrained=True)
# Read the number of input features of the existing fc layer (2048 for ResNet-50)
fc_features = model.fc.in_features
# Redefine the last layer with 10 output classes
model.fc = nn.Linear(fc_features, 10)
print(model.fc)
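Alternatively, pass the number of classes directly to the constructor. This only works for a randomly initialized network: the fc layer of the ImageNet checkpoint has 1000 outputs, so combining pretrained=True with num_classes=10 fails.
model = torchvision.models.resnet50(pretrained=False, num_classes=10)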
3. Modifying a single convolution layer
ResNet-50 is used as the example pretrained model.
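import torch.nn as nn
import torchvision

model = torchvision.models.resnet50(pretrained=True)
# Redefine the first convolution to accept 4 input channels.
# The new layer is randomly initialized; the pretrained conv1 weights are discarded.
model.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)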
4. Modifying multiple layers
4.1 Removing the last two layers (the fc layer and the pooling layer)
ResNet-50 is used as the example pretrained model.
nn.Module provides a children() method that returns the model's top-level submodules in order. These can be sliced and reassembled to modify the network; for example, to remove the last two layers:
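import torch.nn as nn
import torchvision

resnet_50_s = torchvision.models.resnet50(pretrained=False)
# Keep every top-level layer except the last two (avgpool and fc)
resnet_layer = nn.Sequential(*list(resnet_50_s.children())[:-2])
# Inside a module's __init__: self.resnet = resnet_layer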
After removing the last two layers of the ResNet model (the fc layer and the pooling layer), a new upsampling layer, pooling layer and classification head are added. The network is built as follows:
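import torch.nn as nn
import torchvision


class Net_resnet50_upsample(nn.Module):
    # Expects single-channel 224x224 inputs; the fc input size below depends on it
    def __init__(self):
        super(Net_resnet50_upsample, self).__init__()
        # 1x1 convolution mapping 1 input channel to the 3 channels ResNet expects
        self.conv = nn.Conv2d(1, 3, kernel_size=1)
        # Set pretrained=True to start from the ImageNet weights
        resnet_50_s = torchvision.models.resnet50(pretrained=False)
        resnet_layer = nn.Sequential(*list(resnet_50_s.children())[:-2])
        self.resnet = resnet_layer
        # print(self.resnet)
        self.up7to14 = nn.UpsamplingNearest2d(scale_factor=2)
        self.avgpool = nn.AvgPool2d(7, stride=2)
        self.fc = nn.Sequential(
            nn.Linear(2048 * 4 * 4, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 10))

    def forward(self, x):
        x = self.conv(x)        # (N, 1, 224, 224) -> (N, 3, 224, 224)
        x = self.resnet(x)      # -> (N, 2048, 7, 7)
        x = self.up7to14(x)     # -> (N, 2048, 14, 14)
        x = self.avgpool(x)     # -> (N, 2048, 4, 4)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x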
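The `2048 * 4 * 4` input size of the first `nn.Linear` follows from the layer arithmetic. A worked sketch, assuming 224×224 inputs (other input sizes change the numbers, and the `fc` layer would then need resizing); `pool_out` is a hypothetical helper implementing the usual floor-mode output-size formula:

```python
def pool_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output size of a pooling/conv layer in floor mode with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


feat = 224 // 32                            # ResNet-50 downsamples by 32: 224 -> 7
up = feat * 2                               # UpsamplingNearest2d(scale_factor=2): 7 -> 14
pooled = pool_out(up, kernel=7, stride=2)   # AvgPool2d(7, stride=2): 14 -> 4
print(feat, up, pooled, 2048 * pooled * pooled)  # 7 14 4 32768
```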
4.2 Adding or removing multiple layers
Sometimes the layer structure of the network itself has to change. In that case the only option is parameter overwriting: first define a network similar to the original, then copy the pretrained parameters into it. A ResNet pretrained model is used as the example here.
# coding=UTF-8
import math

import torch.nn as nn
import torchvision.models as models


# Bottleneck is the residual block that uses 1x1 convolutions to reduce and then
# restore the channel dimension; see the ResNet implementation in torchvision on GitHub
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


# Layers that are not modified must keep their original names;
# otherwise their pretrained weights cannot be loaded
class CNN(nn.Module):
    def __init__(self, block, layers, num_classes=9):
        super(CNN, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        # New transposed convolution (deconvolution) layer
        self.convtranspose1 = nn.ConvTranspose2d(2048, 2048, kernel_size=3, stride=1, padding=1,
                                                 output_padding=0, groups=1, bias=False, dilation=1)
        # New max pooling layer
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        # Replace the original fc layer with a new layer named fclass
        self.fclass = nn.Linear(2048, num_classes)
        # He initialization for convolutions, constant initialization for BatchNorm
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)          # (N, 2048, 1, 1)
        # Newly added layers. Keep the 4-D shape here: ConvTranspose2d needs
        # (N, C, H, W) input, so flatten only before the linear layer
        x = self.convtranspose1(x)   # (N, 2048, 1, 1)
        x = self.maxpool2(x)         # (N, 2048, 1, 1)
        x = x.view(x.size(0), -1)
        x = self.fclass(x)
        return x
Next, load ResNet-50 and copy the parameters whose names match into the new network:
# Load the model. pretrained=True is needed here: with pretrained=False
# the weights copied below would themselves be random
resnet50 = models.resnet50(pretrained=True)
print(resnet50)
# [3, 4, 6, 3] is the number of Bottleneck blocks in layer1 to layer4; ResNet-101 uses [3, 4, 23, 3]
cnn = CNN(Bottleneck, [3, 4, 6, 3])
# Read the parameters
pretrained_dict = resnet50.state_dict()
model_dict = cnn.state_dict()
# Drop the keys of pretrained_dict that do not exist in model_dict (here fc.weight and fc.bias)
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# Overwrite the matching entries of model_dict
model_dict.update(pretrained_dict)
# Load the merged state_dict
cnn.load_state_dict(model_dict)
print(cnn)
Results comparison: (figure not available)