当前位置:网站首页>Pytorch框架学校记录11——搭建小实战完整细节
Pytorch框架学校记录11——搭建小实战完整细节
2022-08-01 20:36:00 【柚子Roo】
Pytorch框架学校记录11——搭建小实战完整细节
1. 搭建小实战和Sequential的使用
我们搭建了一个CIFAR10模型,下面的代码是未使用Sequential的情况。
import torch
from torch import nn
class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2)
self.maxpool1 = nn.MaxPool2d(kernel_size=2)
self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2)
self.maxpool2 = nn.MaxPool2d(kernel_size=2)
self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
self.maxpool3 = nn.MaxPool2d(kernel_size=2)
self.flatten = nn.Flatten()
self.hidden = nn.Linear(in_features=1024, out_features=64)
self.fc = nn.Linear(in_features=64, out_features=10)
def forward(self, input):
input = self.conv1(input)
input = self.maxpool1(input)
input = self.conv2(input)
input = self.maxpool2(input)
input = self.conv3(input)
input = self.maxpool3(input)
input = self.flatten(input)
input = self.hidden(input)
output = self.fc(input)
return output
x = torch.tensor([0.1, 0.2, 0.3])
print(x.shape)
y = torch.tensor([1])
x = torch.reshape(x, (1, 3))
print(x)
2. 损失函数和反向传播
torch.nn.CrossEntropyLoss
(weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction=‘mean’, label_smoothing=0.0)
参数:
- Input: Shape (C), (N, C)(N,C) or (N, C, d_1, d_2, …, d_K)(N,C,d1,d2,…,d**K) with K \geq 1K≥1 in the case of K-dimensional loss.
- Target: If containing class indices, shape ()(), (N)(N) or (N, d_1, d_2, …, d_K)(N,d1,d2,…,d**K) with K \geq 1K≥1 in the case of K-dimensional loss where each value should be between [0, C)[0,C). If containing class probabilities, same shape as the input and each value should be between [0, 1][0,1].
在这里我们以交叉熵损失函数为例,backward()
方法为反向传播算法。
from torch import nn
import torchvision
from torch.utils.data import DataLoader
dataset = torchvision.datasets.CIFAR10(root='dataset', train=False, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, 64)
class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2)
self.maxpool1 = nn.MaxPool2d(kernel_size=2)
self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2)
self.maxpool2 = nn.MaxPool2d(kernel_size=2)
self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
self.maxpool3 = nn.MaxPool2d(kernel_size=2)
self.flatten = nn.Flatten()
self.hidden = nn.Linear(in_features=1024, out_features=64)
self.fc = nn.Linear(in_features=64, out_features=10)
def forward(self, input):
input = self.conv1(input)
input = self.maxpool1(input)
input = self.conv2(input)
input = self.maxpool2(input)
input = self.conv3(input)
input = self.maxpool3(input)
input = self.flatten(input)
input = self.hidden(input)
output = self.fc(input)
return output
test = Test()
loss = nn.CrossEntropyLoss()
step = 0
for data in dataloader:
imgs, target = data
output = test(imgs)
res = loss(output, target)
res.backward()
print(res)
3. 优化器
优化器的作用:将模型的中的参数根据要求进行实时调整更新,使得模型变得更加优良。
在这里我们使用的是随机梯度下降法(SGD)作为优化器的优化依据。
from torch import nn
import torchvision
import torch
from torch.utils.data import DataLoader
dataset = torchvision.datasets.CIFAR10(root='dataset', train=False, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, 64)
class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2)
self.maxpool1 = nn.MaxPool2d(kernel_size=2)
self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2)
self.maxpool2 = nn.MaxPool2d(kernel_size=2)
self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
self.maxpool3 = nn.MaxPool2d(kernel_size=2)
self.flatten = nn.Flatten()
self.hidden = nn.Linear(in_features=1024, out_features=64)
self.fc = nn.Linear(in_features=64, out_features=10)
def forward(self, input):
input = self.conv1(input)
input = self.maxpool1(input)
input = self.conv2(input)
input = self.maxpool2(input)
input = self.conv3(input)
input = self.maxpool3(input)
input = self.flatten(input)
input = self.hidden(input)
output = self.fc(input)
return output
test = Test()
loss = nn.CrossEntropyLoss()
step = 0
optimer = torch.optim.SGD(params=test.parameters(), lr=0.01)
for epoch in range(20):
loss_sum = 0.0
for data in dataloader:
imgs, target = data
output = test(imgs)
res = loss(output, target)
optimer.zero_grad()
res.backward()
optimer.step()
loss_sum += res
print(loss_sum)
4. 现有网络模型的使用及修改
我们使用Pytorch框架中的VGG16模型,并将该模型的全连接层的输出特征的个数设置为10,
torchvision.models.vgg16
(***, weights: Optional[torchvision.models.vgg.VGG16_Weights] = None, progress: bool = True, **kwargs: Any)
参数:
pretrained
:设置为True代表加载预训练模型
对现有网络模型的修改可分为两种方式,一种为添加,另一种为修改。
详细的操作方法如下:
import torch
from torch import nn
import torchvision
vgg_true = torchvision.models.vgg16(pretrained=True)
vgg_false = torchvision.models.vgg16(pretrained=False)
vgg_true.add_module("add_model", nn.Linear(in_features=1000, out_features=10))
print(vgg_true)
vgg_false.classifier[6] = nn.Linear(in_features=4096, out_features=10)
print(vgg_false)
5. 模型的保存与读取
模型的保存有两种方式:
import torchvision
import torch
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1 模型结构+模型参数
torch.save(vgg16, "vgg16_methods1.pth")
# 保存方式2 模型参数(官方推荐)
torch.save(vgg16.state_dict(), "vgg16_methods2.pth")
模型的加载方式也有两种方式
import torch
import torchvision
# 方式1 加载模型
vgg16_model1 = torch.load("vgg16_methods1.pth")
print(vgg16_model1)
# 方式2 加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.state_dict(torch.load("vgg16_methods2.pth"))
print(vgg16)
边栏推荐
- AQS原理和介绍
- vant实现Select效果--单选和多选
- WeChat applet cloud development | personal blog applet
- 线上问题排查常用命令,总结太全了,建议收藏!!
- Determine a binary tree given inorder traversal and another traversal method
- [Energy Conservation Institute] Ankerui Food and Beverage Fume Monitoring Cloud Platform Helps Fight Air Pollution
- tiup mirror grant
- Debug一个ECC的ODP数据源
- "Torch" tensor multiplication: matmul, einsum
- 小数据如何学习?吉大最新《小数据学习》综述,26页pdf涵盖269页文献阐述小数据学习理论、方法与应用
猜你喜欢
【Social Media Marketing】How to know if your WhatsApp is blocked?
【多任务模型】Progressive Layered Extraction: A Novel Multi-Task Learning Model for Personalized(RecSys‘20)
【nn.Parameter()】生成和为什么要初始化
Godaddy domain name resolution is slow and how to use DNSPod resolution to solve it
To promote energy conservation institute 】 【 the opinions of the agricultural water price reform
启明云端分享|盘点ESP8684开发板有哪些功能
用户体验好的Button,在手机上不应该有Hover态
【Untitled】
[Multi-task learning] Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts KDD18
【社媒营销】如何知道自己的WhatsApp是否被屏蔽了?
随机推荐
【Social Media Marketing】How to know if your WhatsApp is blocked?
Multithreaded producers and consumers
我的驾照考试笔记(2)
【nn.Parameter()】生成和为什么要初始化
【Dart】dart构造函数学习记录(含dart单例模式写法)
解除360对默认浏览器的检测与修改
Excel advanced drawing techniques, 100 (22) - how to respectively the irregular data
C语言实现-直接插入排序(带图详解)
使用微信公众号给指定微信用户发送信息
Application of Acrel-5010 online monitoring system for key energy consumption unit energy consumption in Hunan Sanli Group
string
面试突击70:什么是粘包和半包?怎么解决?
自定义指令,获取焦点
数据库单字段存储多个标签(位移操作)
[Energy Conservation Institute] Comparative analysis of smart small busbar and column head cabinet solutions in data room
Go 语言中常见的坑
Redis 做签到统计
[Personal Work] Remember - Serial Logging Tool
大整数相加,相减,相乘,大整数与普通整数的相乘,相除
【Dart】dart之mixin探究