当前位置:网站首页>先有网络模型的使用及修改
先有网络模型的使用及修改
2022-07-01 04:35:00 【booze-J】
先有网络模型的使用
使用示例代码:
import torchvision
from torch import nn
# 加载网络
# 这一句话(当pretrained设置为False时)就相当与把网络架构在这里替换了一下,网络模型的参数都是初始化的,是默认的一些参数
vgg16_false = torchvision.models.vgg16(pretrained=False)
# 这一句话(当pretrained设置为True时)网络模型的参数都是在ImageNet数据集上训练好的,就是在ImageNet数据集上能够达到一个比较好的效果
vgg16_true = torchvision.models.vgg16(pretrained=True)
vgg16的使用有两个常用参数,分别是pretrained和process。
- pretrained - 为True的话,说明这个网络是已经训练好的在训练数据集上有比较好的效果 若为False则说明这个网络是没训练的
- process - 为True则显示下载神经网络参数的进度条若为False则不显示下载神经网络参数的进度条
通俗来理解pretrained,就相当于什么呢?比如搭建神经网络卷积层时,你给了一个kernel_size但是并没有kernel_size中的参数,pretrained=True时相当于你得到了一个带参数的卷积核,pretrained=False时相当于你只知道这个卷积核的大小。
先有网络模型的修改(如何利用现有的网络去改动它的一个结构)
1.添加网络层
示例代码如下:
import torchvision
from torch import nn
# 加载网络
vgg16_false = torchvision.models.vgg16(pretrained=False)
print("vgg16_false:\n",vgg16_false)
vgg16_true = torchvision.models.vgg16(pretrained=True)
print("vgg16_true:\n",vgg16_true)
# 如何利用现有的网络去改动他的一个结构
# 1.添加网络层
# 加载CIFAR10数据集
train_data = torchvision.datasets.CIFAR10("./CIFAR10",train=True,transform=torchvision.transforms.ToTensor())
# 将vgg16_true模型应用到CIFAR10数据集上,为什么要添加一个in_feature=1000,out_feature=10的线性层呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要添加一个in_feature=1000,out_feature=10的线性层。
# 方式1:在整个网络中直接添加
# vgg16_true.add_module("add_linear",nn.Linear(1000,10))
# 方式2:在相应的模块中添加
vgg16_true.classifier.add_module("add_linear",nn.Linear(1000,10))
print("vgg16_true:\n",vgg16_true)
运行结果:

讲解:将vgg16_true模型应用到CIFAR10数据集上,为什么要添加一个in_feature=1000,out_feature=10的线性层呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要添加一个in_feature=1000,out_feature=10的线性层。
2.直接修改网络
示例代码如下:
import torchvision
from torch import nn
# 加载网络模型
vgg16_false = torchvision.models.vgg16(pretrained=False)
print("vgg16_false:\n",vgg16_false)
vgg16_true = torchvision.models.vgg16(pretrained=True)
print("vgg16_true:\n",vgg16_true)
# 如何利用现有的网络去改动他的一个结构
# 2.直接修改网络
# 加载CIFAR10数据集
train_data = torchvision.datasets.CIFAR10("./CIFAR10",train=True,transform=torchvision.transforms.ToTensor())
# 将vgg16_true模型应用到CIFAR10数据集上,为什么修改最后的线性层out_feature=10呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要修改最后的线性层out_feature=10。
# 按顺序对网络进行索引,修改最后的线性层
vgg16_false.classifier[6] = nn.Linear(4096,10)
print("vgg16_false",vgg16_false)
运行结果:

讲解:将vgg16_true模型应用到CIFAR10数据集上,为什么修改最后的线性层out_feature=10呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要修改最后的线性层out_feature=10。
边栏推荐
- (12) Somersault cloud case (navigation bar highlights follow)
- Seven crimes of counting software R & D Efficiency
- 2022 G2 power station boiler stoker examination question bank and G2 power station boiler stoker simulation examination question bank
- 使用WinMTR软件简单分析跟踪检测网络路由情况
- 2022 tea master (intermediate) examination question bank and tea master (intermediate) examination questions and analysis
- Common interview questions ①
- Day 52 - tree problem
- Tencent has five years of testing experience. It came to the interview to ask for 30K, and saw the so-called software testing ceiling
- Dual Contrastive Learning: Text Classification via Label-Aware Data Augmentation 阅读笔记
- Basic usage, principle and details of session
猜你喜欢

VR线上展览所具备应用及特色

Section 27 remote access virtual private network workflow and experimental demonstration

Cmake selecting compilers and setting compiler options

Pytorch(四) —— 可视化工具 Visdom

使用WinMTR软件简单分析跟踪检测网络路由情况

2022危险化学品生产单位安全生产管理人员题库及答案

Grey correlation cases and codes

LM小型可编程控制器软件(基于CoDeSys)笔记十九:报错does not match the profile of the target

25.k sets of flipped linked lists

Obtain detailed ideas for ABCDEF questions of 2022 American Games
随机推荐
Web server: how to choose a good web server these five aspects should be paid attention to
Software testing needs more and more talents. Why do you still not want to take this path?
Maixll-Dock 快速上手
Leetcode learning - day 36
LM small programmable controller software (based on CoDeSys) note 20: PLC controls stepping motor through driver
JS image path conversion Base64 format
Threejs opening
[deep learning] (4) decoder mechanism in transformer, complete pytoch code attached
Tencent has five years of testing experience. It came to the interview to ask for 30K, and saw the so-called software testing ceiling
How to do the performance pressure test of "Health Code"
Learn Chapter 20 of vue3 (keep alive cache component)
Common UNIX Operation and maintenance commands of shell
[send email with error] 535 error:authentication failed
Some small knowledge points
Account sharing technology enables the farmers' market and reshapes the efficiency of transaction management services
Openresty rewrites the location of 302
做网站数据采集,怎么选择合适的服务器呢?
One job hopping up 8K, three times in five years
嵌入式系統開發筆記80:應用Qt Designer進行主界面設計
2022 gas examination question bank and online simulation examination