当前位置:网站首页>先有网络模型的使用及修改
先有网络模型的使用及修改
2022-07-01 04:35:00 【booze-J】
先有网络模型的使用
使用示例代码:
import torchvision
from torch import nn
# 加载网络
# 这一句话(当pretrained设置为False时)就相当与把网络架构在这里替换了一下,网络模型的参数都是初始化的,是默认的一些参数
vgg16_false = torchvision.models.vgg16(pretrained=False)
# 这一句话(当pretrained设置为True时)网络模型的参数都是在ImageNet数据集上训练好的,就是在ImageNet数据集上能够达到一个比较好的效果
vgg16_true = torchvision.models.vgg16(pretrained=True)
vgg16的使用有两个常用参数,分别是pretrained和process。
- pretrained - 为True的话,说明这个网络是已经训练好的在训练数据集上有比较好的效果 若为False则说明这个网络是没训练的
- process - 为True则显示下载神经网络参数的进度条若为False则不显示下载神经网络参数的进度条
通俗来理解pretrained,就相当于什么呢?比如搭建神经网络卷积层时,你给了一个kernel_size但是并没有kernel_size中的参数,pretrained=True时相当于你得到了一个带参数的卷积核,pretrained=False时相当于你只知道这个卷积核的大小。
先有网络模型的修改(如何利用现有的网络去改动它的一个结构)
1.添加网络层
示例代码如下:
import torchvision
from torch import nn
# 加载网络
vgg16_false = torchvision.models.vgg16(pretrained=False)
print("vgg16_false:\n",vgg16_false)
vgg16_true = torchvision.models.vgg16(pretrained=True)
print("vgg16_true:\n",vgg16_true)
# 如何利用现有的网络去改动他的一个结构
# 1.添加网络层
# 加载CIFAR10数据集
train_data = torchvision.datasets.CIFAR10("./CIFAR10",train=True,transform=torchvision.transforms.ToTensor())
# 将vgg16_true模型应用到CIFAR10数据集上,为什么要添加一个in_feature=1000,out_feature=10的线性层呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要添加一个in_feature=1000,out_feature=10的线性层。
# 方式1:在整个网络中直接添加
# vgg16_true.add_module("add_linear",nn.Linear(1000,10))
# 方式2:在相应的模块中添加
vgg16_true.classifier.add_module("add_linear",nn.Linear(1000,10))
print("vgg16_true:\n",vgg16_true)
运行结果:

讲解:将vgg16_true模型应用到CIFAR10数据集上,为什么要添加一个in_feature=1000,out_feature=10的线性层呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要添加一个in_feature=1000,out_feature=10的线性层。
2.直接修改网络
示例代码如下:
import torchvision
from torch import nn
# 加载网络模型
vgg16_false = torchvision.models.vgg16(pretrained=False)
print("vgg16_false:\n",vgg16_false)
vgg16_true = torchvision.models.vgg16(pretrained=True)
print("vgg16_true:\n",vgg16_true)
# 如何利用现有的网络去改动他的一个结构
# 2.直接修改网络
# 加载CIFAR10数据集
train_data = torchvision.datasets.CIFAR10("./CIFAR10",train=True,transform=torchvision.transforms.ToTensor())
# 将vgg16_true模型应用到CIFAR10数据集上,为什么修改最后的线性层out_feature=10呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要修改最后的线性层out_feature=10。
# 按顺序对网络进行索引,修改最后的线性层
vgg16_false.classifier[6] = nn.Linear(4096,10)
print("vgg16_false",vgg16_false)
运行结果:

讲解:将vgg16_true模型应用到CIFAR10数据集上,为什么修改最后的线性层out_feature=10呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要修改最后的线性层out_feature=10。
边栏推荐
- 为什么香港服务器最适合海外建站使用
- Question bank and answers for chemical automation control instrument operation certificate examination in 2022
- slf4j 简单实现
- Basic exercise of test questions hexadecimal to decimal
- Maixll dock quick start
- Simple implementation of slf4j
- Threejs opening
- OSPF notes [dr and bdr]
- Collect the annual summary of laws, regulations, policies and plans related to trusted computing of large market points (national, ministerial, provincial and municipal)
- How to choose the right server for website data collection?
猜你喜欢

测量三相永磁同步电机的交轴直轴电感

TCP server communication flow

Possible problems and solutions of using scroll view to implement slider view

TASK04|数理统计

Knowledge supplement: basic usage of redis based on docker

Offline installation of Wireshark 2.6.10

Maixll-Dock 使用方法

LM small programmable controller software (based on CoDeSys) note 19: errors do not match the profile of the target

Applications and features of VR online exhibition

This sideline workload is small, 10-15k, free unlimited massage
随机推荐
LM small programmable controller software (based on CoDeSys) note 20: PLC controls stepping motor through driver
软件研发的十大浪费:研发效能的另一面
网站服务器:好用的网站服务器怎么选这五方面要关注
How do I sort a list of strings in dart- How can I sort a list of strings in Dart?
MySQL winter vacation self-study 2022 12 (5)
Shell之Unix运维常用命令
Kodori tree board
206. reverse linked list
2022 hoisting machinery command registration examination and hoisting machinery command examination registration
VIM easy to use tutorial
Day 52 - tree problem
Odeint and GPU
2022 a special equipment related management (elevator) simulation test and a special equipment related management (elevator) certificate examination
2022危险化学品生产单位安全生产管理人员题库及答案
尺取法:有效三角形的个数
CF1638E. Colorful operations Kodori tree + differential tree array
【深度学习】(4) Transformer 中的 Decoder 机制,附Pytorch完整代码
LM small programmable controller software (based on CoDeSys) note 19: errors do not match the profile of the target
[send email with error] 535 error:authentication failed
What is uid? What is auth? What is a verifier?