Pytorch -- use and modification of existing network model
2022-06-29 00:55:00 【Herding cattle】

Take the VGG network as an example. In the PyTorch Docs, select the image-related torchvision package. In the left sidebar, torchvision.models lists ready-made network models, and the page on the right groups them by task; the first group is the classification models.
The most commonly used VGG variants are vgg16 and vgg19.

The pretrained parameter: if True, the model is loaded with parameters that have already been trained on the ImageNet dataset; if False, the parameters are untrained. The progress parameter: if True, a download progress bar is displayed.

The ImageNet dataset class needs the scipy package, which can be installed separately. root is the path of the dataset, split selects the training set or the validation set, transform is applied to the PIL image, target_transform is applied to the target, and loader is the function used to load an image when the dataset is read.
train_data = torchvision.datasets.ImageNet("./data_ImageNet", split="train", download=True, transform=torchvision.transforms.ToTensor())
Downloading the dataset this way raises an error:
RuntimeError: The dataset is no longer publicly accessible. You need to download the archives externally and place them in the root directory.
The dataset is no longer publicly accessible, so the archives have to be downloaded separately and placed in the root directory. They can be found by searching online; the training set is over 100 GB.
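Before constructing the dataset, it can help to check that the archives are actually in place. The helper below is a minimal sketch using only the standard library; the function name is hypothetical, and the archive file names are the ones torchvision.datasets.ImageNet expected as far as I know, so treat them as an assumption and verify them against your torchvision version.

```python
import os

# Archive names assumed to match what torchvision.datasets.ImageNet looks for.
# This is an assumption; check the torchvision version you have installed.
IMAGENET_ARCHIVES = {
    "devkit": "ILSVRC2012_devkit_t_12.tar.gz",
    "train": "ILSVRC2012_img_train.tar",
    "val": "ILSVRC2012_img_val.tar",
}


def missing_imagenet_archives(root, split="train"):
    """Return the archive names needed for `split` that are not found in `root`."""
    needed = [IMAGENET_ARCHIVES["devkit"], IMAGENET_ARCHIVES[split]]
    return [name for name in needed
            if not os.path.isfile(os.path.join(root, name))]


# Hypothetical usage with the root directory from the text.
missing = missing_imagenet_archives("./data_ImageNet", split="train")
if missing:
    print("Place these files in ./data_ImageNet first:", missing)
```

Once nothing is reported missing, the ImageNet constructor can be called with the same root and split, leaving out download=True.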
We load the model directly:
import torchvision
vgg16_False = torchvision.models.vgg16(pretrained=False)  # progress defaults to True
vgg16_True = torchvision.models.vgg16(pretrained=True)
Viewing the parameters of the two models, vgg16_False and vgg16_True, shows the difference: with pretrained=False the parameters are at their initial default values, and with pretrained=True they have already been trained and give good results. Print out the trained network model:
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7):...
Look at the last layer :
Linear(in_features=4096, out_features=1000, bias=True)
The output has 1000 features: the network is a classification model that distinguishes 1000 classes. The CIFAR10 dataset mentioned in the previous blog post divides images into 10 classes, so to use vgg16 on it we can either change the 1000 outputs of the last linear layer to 10, or add a new linear layer with 1000 inputs and 10 outputs.
from torch import nn
vgg16_True.add_module('add_linear', nn.Linear(1000, 10))
The source code of add_module:
def add_module(self, name: str, module: Optional['Module']) -> None:
    r"""Adds a child module to the current module.

    The module can be accessed as an attribute using the given name.

    Args:
        name (string): name of the child module. The child module can be
            accessed from this module using the given name
        module (Module): child module to be added to the module.
    """
name is the name of the new layer, and module is the layer to add. The new network is:
...
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
  (add_linear): Linear(in_features=1000, out_features=10, bias=True)
To add the layer inside classifier instead, change the call to:
vgg16_True.classifier.add_module('add_linear', nn.Linear(1000, 10))
To change the last Linear layer from 4096 inputs and 1000 outputs to 4096 inputs and 10 outputs:
vgg16_True.classifier[6] = nn.Linear(4096, 10)
(The last layer is inside classifier, and its index is 6.)