当前位置:网站首页>先有网络模型的使用及修改
先有网络模型的使用及修改
2022-07-01 04:35:00 【booze-J】
先有网络模型的使用
使用示例代码:
import torchvision
from torch import nn
# 加载网络
# 这一句话(当pretrained设置为False时)就相当与把网络架构在这里替换了一下,网络模型的参数都是初始化的,是默认的一些参数
vgg16_false = torchvision.models.vgg16(pretrained=False)
# 这一句话(当pretrained设置为True时)网络模型的参数都是在ImageNet数据集上训练好的,就是在ImageNet数据集上能够达到一个比较好的效果
vgg16_true = torchvision.models.vgg16(pretrained=True)
vgg16的使用有两个常用参数,分别是pretrained和process。
- pretrained - 为True的话,说明这个网络是已经训练好的在训练数据集上有比较好的效果 若为False则说明这个网络是没训练的
- process - 为True则显示下载神经网络参数的进度条若为False则不显示下载神经网络参数的进度条
通俗来理解pretrained,就相当于什么呢?比如搭建神经网络卷积层时,你给了一个kernel_size但是并没有kernel_size中的参数,pretrained=True时相当于你得到了一个带参数的卷积核,pretrained=False时相当于你只知道这个卷积核的大小。
先有网络模型的修改(如何利用现有的网络去改动它的一个结构)
1.添加网络层
示例代码如下:
import torchvision
from torch import nn
# 加载网络
vgg16_false = torchvision.models.vgg16(pretrained=False)
print("vgg16_false:\n",vgg16_false)
vgg16_true = torchvision.models.vgg16(pretrained=True)
print("vgg16_true:\n",vgg16_true)
# 如何利用现有的网络去改动他的一个结构
# 1.添加网络层
# 加载CIFAR10数据集
train_data = torchvision.datasets.CIFAR10("./CIFAR10",train=True,transform=torchvision.transforms.ToTensor())
# 将vgg16_true模型应用到CIFAR10数据集上,为什么要添加一个in_feature=1000,out_feature=10的线性层呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要添加一个in_feature=1000,out_feature=10的线性层。
# 方式1:在整个网络中直接添加
# vgg16_true.add_module("add_linear",nn.Linear(1000,10))
# 方式2:在相应的模块中添加
vgg16_true.classifier.add_module("add_linear",nn.Linear(1000,10))
print("vgg16_true:\n",vgg16_true)
运行结果:

讲解:将vgg16_true模型应用到CIFAR10数据集上,为什么要添加一个in_feature=1000,out_feature=10的线性层呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要添加一个in_feature=1000,out_feature=10的线性层。
2.直接修改网络
示例代码如下:
import torchvision
from torch import nn
# 加载网络模型
vgg16_false = torchvision.models.vgg16(pretrained=False)
print("vgg16_false:\n",vgg16_false)
vgg16_true = torchvision.models.vgg16(pretrained=True)
print("vgg16_true:\n",vgg16_true)
# 如何利用现有的网络去改动他的一个结构
# 2.直接修改网络
# 加载CIFAR10数据集
train_data = torchvision.datasets.CIFAR10("./CIFAR10",train=True,transform=torchvision.transforms.ToTensor())
# 将vgg16_true模型应用到CIFAR10数据集上,为什么修改最后的线性层out_feature=10呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要修改最后的线性层out_feature=10。
# 按顺序对网络进行索引,修改最后的线性层
vgg16_false.classifier[6] = nn.Linear(4096,10)
print("vgg16_false",vgg16_false)
运行结果:

讲解:将vgg16_true模型应用到CIFAR10数据集上,为什么修改最后的线性层out_feature=10呢?因为vgg16_true网络训练的ImageNet数据集有1000个分类,而CIFAR10只有10分类,所以要将vgg16_true网络应用在CIFAR10上的话,需要修改最后的线性层out_feature=10。
边栏推荐
- 2022年煤气考试题库及在线模拟考试
- 一些小知识点
- Account sharing technology enables the farmers' market and reshapes the efficiency of transaction management services
- Ospfb notes - five messages [ultra detailed] [Hello message, DD message, LSR message, LSU message, lsack message]
- Learn Chapter 20 of vue3 (keep alive cache component)
- Leetcode learning - day 36
- Tencent has five years of testing experience. It came to the interview to ask for 30K, and saw the so-called software testing ceiling
- One job hopping up 8K, three times in five years
- 为什么香港服务器最适合海外建站使用
- Tcp/ip explanation (version 2) notes / 3 link layer / 3.4 bridge and switch / 3.4.2 multiple registration protocol (MRP)
猜你喜欢

LM小型可编程控制器软件(基于CoDeSys)笔记十九:报错does not match the profile of the target

Introduction to JVM stack and heap

2022年化工自动化控制仪表操作证考试题库及答案

2022 t elevator repair question bank and simulation test

Embedded System Development Notes 79: why should I get the IP address of the local network card

25.k sets of flipped linked lists

Odeint and GPU

Obtain detailed ideas for ABCDEF questions of 2022 American Games

Offline installation of Wireshark 2.6.10

CF1638E colorful operations
随机推荐
[human version] Web3 privacy game in the dark forest
One click shell to automatically deploy any version of redis
Learn Chapter 20 of vue3 (keep alive cache component)
Kodori tree board
Obtain detailed ideas for ABCDEF questions of 2022 American Games
Strategic suggestions and future development trend of global and Chinese vibration isolator market investment report 2022 Edition
Tcp/ip explanation (version 2) notes / 3 link layer / 3.4 bridge and switch / 3.4.2 multiple registration protocol (MRP)
Tip of edge browser: enter+ctrl can automatically convert the address bar into a web address
Pytest automated testing - compare robotframework framework
Advanced application of ES6 modular and asynchronous programming
嵌入式系統開發筆記80:應用Qt Designer進行主界面設計
2. Use of classlist (element class name)
Task04 | statistiques mathématiques
Common UNIX Operation and maintenance commands of shell
selenium打开chrome浏览器时弹出设置页面:Mircrosoft Defender 防病毒要重置您的设置
Seven crimes of counting software R & D Efficiency
Pytorch(四) —— 可视化工具 Visdom
Summary of testing experience - Testing Theory
2022 Shanghai safety officer C certificate examination question simulation examination question bank and answers
How to choose the right server for website data collection?