当前位置:网站首页>先有网络模型的使用及修改
先有网络模型的使用及修改
2022-07-01 04:35:00 【booze-J】
先有网络模型的使用
使用示例代码:
import torchvision
from torch import nn
# 加载网络
# 这一句话(当pretrained设置为False时)就相当与把网络架构在这里替换了一下,网络模型的参数都是初始化的,是默认的一些参数
vgg16_false = torchvision.models.vgg16(pretrained=False)
# 这一句话(当pretrained设置为True时)网络模型的参数都是在ImageNet数据集上训练好的,就是在ImageNet数据集上能够达到一个比较好的效果
vgg16_true = torchvision.models.vgg16(pretrained=True)
vgg16的使用有两个常用参数,分别是pretrained和process。
- pretrained - 为True的话,说明这个网络是已经训练好的在训练数据集上有比较好的效果 若为False则说明这个网络是没训练的
- process - 为True则显示下载神经网络参数的进度条若为False则不显示下载神经网络参数的进度条
通俗来理解pretrained,就相当于什么呢?比如搭建神经网络卷积层时,你给了一个kernel_size但是并没有kernel_size中的参数,pretrained=True时相当于你得到了一个带参数的卷积核,pretrained=False时相当于你只知道这个卷积核的大小。
先有网络模型的修改(如何利用现有的网络去改动它的一个结构)
1.添加网络层
示例代码如下:
import torchvision
from torch import nn
# 加载网络
vgg16_false = torchvision.models.vgg16(pretrained=False)
print("vgg16_false:\n",vgg16_false)
vgg16_true = torchvision.models.vgg16(pretrained=True)
print("vgg16_true:\n",vgg16_true)
# 如何利用现有的网络去改动他的一个结构
# 1.添加网络层
# 加载CIFAR10数据集
train_data = torchvision.datasets.CIFAR10("./CIFAR10",train=True,transform=torchvision.transforms.ToTensor())
# 将vgg16_true模型应用到CIFAR10数据集上,为什么要添加一个in_feature=1000,out_feature=10的线性层呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要添加一个in_feature=1000,out_feature=10的线性层。
# 方式1:在整个网络中直接添加
# vgg16_true.add_module("add_linear",nn.Linear(1000,10))
# 方式2:在相应的模块中添加
vgg16_true.classifier.add_module("add_linear",nn.Linear(1000,10))
print("vgg16_true:\n",vgg16_true)
运行结果:

讲解:将vgg16_true模型应用到CIFAR10数据集上,为什么要添加一个in_feature=1000,out_feature=10的线性层呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要添加一个in_feature=1000,out_feature=10的线性层。
2.直接修改网络
示例代码如下:
import torchvision
from torch import nn
# 加载网络模型
vgg16_false = torchvision.models.vgg16(pretrained=False)
print("vgg16_false:\n",vgg16_false)
vgg16_true = torchvision.models.vgg16(pretrained=True)
print("vgg16_true:\n",vgg16_true)
# 如何利用现有的网络去改动他的一个结构
# 2.直接修改网络
# 加载CIFAR10数据集
train_data = torchvision.datasets.CIFAR10("./CIFAR10",train=True,transform=torchvision.transforms.ToTensor())
# 将vgg16_true模型应用到CIFAR10数据集上,为什么修改最后的线性层out_feature=10呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要修改最后的线性层out_feature=10。
# 按顺序对网络进行索引,修改最后的线性层
vgg16_false.classifier[6] = nn.Linear(4096,10)
print("vgg16_false",vgg16_false)
运行结果:

讲解:将vgg16_true模型应用到CIFAR10数据集上,为什么修改最后的线性层out_feature=10呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要修改最后的线性层out_feature=10。
边栏推荐
- Strategic suggestions and future development trend of global and Chinese vibration isolator market investment report 2022 Edition
- JVM栈和堆简介
- Difference between cookie and session
- Dede collection plug-in does not need to write rules
- Custom components in applets
- MySQL function variable stored procedure
- Maixll-Dock 使用方法
- How to view the changes and opportunities in the construction of smart cities?
- I also gave you the MySQL interview questions of Boda factory. If you need to come in and take your own
- 扩展-Fragment
猜你喜欢

Maixll dock quick start

2022 Shanghai safety officer C certificate examination question simulation examination question bank and answers

How to do the performance pressure test of "Health Code"

TASK04|數理統計

OdeInt与GPU

Applications and features of VR online exhibition

Threejs opening

Tencent has five years of testing experience. It came to the interview to ask for 30K, and saw the so-called software testing ceiling

Task04 | statistiques mathématiques

2022年煤气考试题库及在线模拟考试
随机推荐
[difficult] sqlserver2008r2, can you recover only some files when recovering the database?
[recommended algorithm] C interview question of a small factory
2022危险化学品生产单位安全生产管理人员题库及答案
2022 t elevator repair new version test questions and t elevator repair simulation test question bank
25.k sets of flipped linked lists
All in all, the low code still needs to solve these four problems
Task04 mathematical statistics
Embedded System Development Notes 79: why should I get the IP address of the local network card
数据加载及预处理
为什么香港服务器最适合海外建站使用
Software testing needs more and more talents. Why do you still not want to take this path?
如何看待智慧城市建设中的改变和机遇?
【硬十宝典】——2.【基础知识】开关电源各种拓扑结构的特点
Shell之分析服务器日志命令集锦
尺取法:有效三角形的个数
离线安装wireshark2.6.10
LM小型可编程控制器软件(基于CoDeSys)笔记二十:plc通过驱动器控制步进电机
OSPF notes [multiple access, two multicast addresses with OSPF]
Grey correlation cases and codes
One click shell to automatically deploy any version of redis