Using and modifying existing network models
2022-07-01 04:46:00 【booze-J】
I. Using an existing network model
Sample code:
import torchvision
from torch import nn
# Load network
# With pretrained=False, only the network architecture is built; its parameters use the default (random) initialization
vgg16_false = torchvision.models.vgg16(pretrained=False)
# With pretrained=True, the parameters are the ones trained on the ImageNet dataset, so the model already performs well on ImageNet
vgg16_true = torchvision.models.vgg16(pretrained=True)
vgg16 has two commonly used parameters: pretrained and progress.
- pretrained: if True, returns a model whose parameters were trained on ImageNet (downloading them if needed); if False, returns an untrained model with randomly initialized parameters.
- progress: if True, a progress bar is displayed while the parameters are downloaded; if False, no progress bar is shown.
How should pretrained be understood? An analogy: when building a convolutional layer, you specify kernel_size but not the values inside the kernel. With pretrained=True you get a kernel that already holds trained values; with pretrained=False you only know the kernel's size, and its values are randomly initialized.
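The analogy can be made concrete with a single nn.Conv2d: the constructor arguments fix the kernel's shape, while its values start out random until trained parameters are copied in. A small sketch; the reference layer is just another randomly initialized layer standing in for one that had actually been trained:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Like pretrained=False: kernel_size=3 fixes the weight shape
# (out_channels, in_channels, 3, 3), but the values are random.
conv = nn.Conv2d(3, 8, kernel_size=3)

# Stand-in for a layer whose parameters were learned elsewhere.
reference = nn.Conv2d(3, 8, kernel_size=3)

print(conv.weight.shape)                         # torch.Size([8, 3, 3, 3])
print(torch.equal(conv.weight, reference.weight))  # False: different random values

# Like pretrained=True: same shape, but the parameter values are loaded.
conv.load_state_dict(reference.state_dict())
print(torch.equal(conv.weight, reference.weight))  # True
```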
II. Modifying an existing network model (changing the structure of an existing network)
1. Adding a network layer
Sample code:
import torchvision
from torch import nn
# Load network
vgg16_false = torchvision.models.vgg16(pretrained=False)
print("vgg16_false:\n",vgg16_false)
vgg16_true = torchvision.models.vgg16(pretrained=True)
print("vgg16_true:\n",vgg16_true)
# Changing the structure of an existing network
# 1. Add a network layer
# Load the CIFAR10 dataset (download=True fetches it on the first run)
train_data = torchvision.datasets.CIFAR10("./CIFAR10",train=True,transform=torchvision.transforms.ToTensor(),download=True)
# ImageNet has 1000 classes but CIFAR10 has only 10, so append a Linear(1000, 10) layer that maps the 1000 outputs to 10
# Way 1: add the layer to the network as a whole
# (this only registers the module: VGG's forward() never calls it, so the output stays 1000-dimensional)
# vgg16_true.add_module("add_linear",nn.Linear(1000,10))
# Way 2: append the layer to the end of the classifier (an nn.Sequential), so it runs last
vgg16_true.classifier.add_module("add_linear",nn.Linear(1000,10))
print("vgg16_true:\n",vgg16_true)
Explanation: why add a linear layer with in_features=1000 and out_features=10 to apply vgg16_true to CIFAR10? vgg16_true was trained on ImageNet, which has 1000 classes, while CIFAR10 has only 10. The extra layer maps the network's 1000 outputs to 10 class scores.
2. Modifying a layer directly
Sample code:
import torchvision
from torch import nn
# Load network model
vgg16_false = torchvision.models.vgg16(pretrained=False)
print("vgg16_false:\n",vgg16_false)
vgg16_true = torchvision.models.vgg16(pretrained=True)
print("vgg16_true:\n",vgg16_true)
# Changing the structure of an existing network
# 2. Modify the network directly
# Load the CIFAR10 dataset (download=True fetches it on the first run)
train_data = torchvision.datasets.CIFAR10("./CIFAR10",train=True,transform=torchvision.transforms.ToTensor(),download=True)
# ImageNet has 1000 classes but CIFAR10 has only 10, so the last linear layer must output 10 values instead of 1000
# classifier is an nn.Sequential: index into it and replace its last linear layer (index 6)
vgg16_false.classifier[6] = nn.Linear(4096,10)
print("vgg16_false",vgg16_false)
Explanation: why change the last linear layer to out_features=10 when applying the model to CIFAR10? The model was trained on ImageNet, which has 1000 classes, while CIFAR10 has only 10, so the last layer must produce 10 outputs instead of 1000. The example modifies vgg16_false; vgg16_true can be modified in exactly the same way.