当前位置:网站首页>Pytorch框架学校记录11——搭建小实战完整细节
Pytorch框架学校记录11——搭建小实战完整细节
2022-08-01 20:36:00 【柚子Roo】
Pytorch框架学校记录11——搭建小实战完整细节
1. 搭建小实战和Sequential的使用
我们搭建了一个CIFAR10模型,下面的代码是未使用Sequential的情况。
import torch
from torch import nn
class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2)
self.maxpool1 = nn.MaxPool2d(kernel_size=2)
self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2)
self.maxpool2 = nn.MaxPool2d(kernel_size=2)
self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
self.maxpool3 = nn.MaxPool2d(kernel_size=2)
self.flatten = nn.Flatten()
self.hidden = nn.Linear(in_features=1024, out_features=64)
self.fc = nn.Linear(in_features=64, out_features=10)
def forward(self, input):
input = self.conv1(input)
input = self.maxpool1(input)
input = self.conv2(input)
input = self.maxpool2(input)
input = self.conv3(input)
input = self.maxpool3(input)
input = self.flatten(input)
input = self.hidden(input)
output = self.fc(input)
return output
x = torch.tensor([0.1, 0.2, 0.3])
print(x.shape)
y = torch.tensor([1])
x = torch.reshape(x, (1, 3))
print(x)
2. 损失函数和反向传播
torch.nn.CrossEntropyLoss
(weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction=‘mean’, label_smoothing=0.0)
参数:
- Input: Shape (C), (N, C)(N,C) or (N, C, d_1, d_2, …, d_K)(N,C,d1,d2,…,d**K) with K \geq 1K≥1 in the case of K-dimensional loss.
- Target: If containing class indices, shape ()(), (N)(N) or (N, d_1, d_2, …, d_K)(N,d1,d2,…,d**K) with K \geq 1K≥1 in the case of K-dimensional loss where each value should be between [0, C)[0,C). If containing class probabilities, same shape as the input and each value should be between [0, 1][0,1].
在这里我们以交叉熵损失函数为例,backward()
方法为反向传播算法。
from torch import nn
import torchvision
from torch.utils.data import DataLoader
dataset = torchvision.datasets.CIFAR10(root='dataset', train=False, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, 64)
class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2)
self.maxpool1 = nn.MaxPool2d(kernel_size=2)
self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2)
self.maxpool2 = nn.MaxPool2d(kernel_size=2)
self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
self.maxpool3 = nn.MaxPool2d(kernel_size=2)
self.flatten = nn.Flatten()
self.hidden = nn.Linear(in_features=1024, out_features=64)
self.fc = nn.Linear(in_features=64, out_features=10)
def forward(self, input):
input = self.conv1(input)
input = self.maxpool1(input)
input = self.conv2(input)
input = self.maxpool2(input)
input = self.conv3(input)
input = self.maxpool3(input)
input = self.flatten(input)
input = self.hidden(input)
output = self.fc(input)
return output
test = Test()
loss = nn.CrossEntropyLoss()
step = 0
for data in dataloader:
imgs, target = data
output = test(imgs)
res = loss(output, target)
res.backward()
print(res)
3. 优化器
优化器的作用:将模型的中的参数根据要求进行实时调整更新,使得模型变得更加优良。
在这里我们使用的是随机梯度下降法(SGD)作为优化器的优化依据。
from torch import nn
import torchvision
import torch
from torch.utils.data import DataLoader
dataset = torchvision.datasets.CIFAR10(root='dataset', train=False, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, 64)
class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2)
self.maxpool1 = nn.MaxPool2d(kernel_size=2)
self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2)
self.maxpool2 = nn.MaxPool2d(kernel_size=2)
self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
self.maxpool3 = nn.MaxPool2d(kernel_size=2)
self.flatten = nn.Flatten()
self.hidden = nn.Linear(in_features=1024, out_features=64)
self.fc = nn.Linear(in_features=64, out_features=10)
def forward(self, input):
input = self.conv1(input)
input = self.maxpool1(input)
input = self.conv2(input)
input = self.maxpool2(input)
input = self.conv3(input)
input = self.maxpool3(input)
input = self.flatten(input)
input = self.hidden(input)
output = self.fc(input)
return output
test = Test()
loss = nn.CrossEntropyLoss()
step = 0
optimer = torch.optim.SGD(params=test.parameters(), lr=0.01)
for epoch in range(20):
loss_sum = 0.0
for data in dataloader:
imgs, target = data
output = test(imgs)
res = loss(output, target)
optimer.zero_grad()
res.backward()
optimer.step()
loss_sum += res
print(loss_sum)
4. 现有网络模型的使用及修改
我们使用Pytorch框架中的VGG16模型,并将该模型的全连接层的输出特征的个数设置为10,
torchvision.models.vgg16
(***, weights: Optional[torchvision.models.vgg.VGG16_Weights] = None, progress: bool = True, **kwargs: Any)
参数:
pretrained
:设置为True代表加载预训练模型
对现有网络模型的修改可分为两种方式,一种为添加,另一种为修改。
详细的操作方法如下:
import torch
from torch import nn
import torchvision
vgg_true = torchvision.models.vgg16(pretrained=True)
vgg_false = torchvision.models.vgg16(pretrained=False)
vgg_true.add_module("add_model", nn.Linear(in_features=1000, out_features=10))
print(vgg_true)
vgg_false.classifier[6] = nn.Linear(in_features=4096, out_features=10)
print(vgg_false)
5. 模型的保存与读取
模型的保存有两种方式:
import torchvision
import torch
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1 模型结构+模型参数
torch.save(vgg16, "vgg16_methods1.pth")
# 保存方式2 模型参数(官方推荐)
torch.save(vgg16.state_dict(), "vgg16_methods2.pth")
模型的加载方式也有两种方式
import torch
import torchvision
# 方式1 加载模型
vgg16_model1 = torch.load("vgg16_methods1.pth")
print(vgg16_model1)
# 方式2 加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.state_dict(torch.load("vgg16_methods2.pth"))
print(vgg16)
边栏推荐
猜你喜欢
Qt设置应用程序开机自启 解决设置失败原因
【节能学院】智能操控装置在高压开关柜的应用
【nn.Parameter()】生成和为什么要初始化
C语言实现-直接插入排序(带图详解)
Imitation cattle forum project
idea插件generateAllSetMethod一键生成set/get方法以及bean对象转换
我的驾照考试笔记(2)
小数据如何学习?吉大最新《小数据学习》综述,26页pdf涵盖269页文献阐述小数据学习理论、方法与应用
【节能学院】推进农业水价综合改革的意见解读
Zheng Xiangling, Chairman of Tide Pharmaceuticals, won the "2022 Outstanding Influential Entrepreneur Award" Tide Pharmaceuticals won the "Corporate Social Responsibility Model Award"
随机推荐
【节能学院】推进农业水价综合改革的意见解读
latex paper artifact -- server deployment overleaf
WhatsApp群发实战分享——WhatsApp Business API账号
任务调度线程池基本介绍
瀚高数据导入
【Untitled】
SIPp 安装及使用
用户身份标识与账号体系实践
Addition, Subtraction, Multiplication of Large Integers, Multiplication and Division of Large Integers and Ordinary Integers
Qt设置应用程序开机自启 解决设置失败原因
tiup mirror genkey
New graduate students, great experience in reading English literature, worthy of your collection
扣减库存方案
The configuration manual for the secondary development of the XE training system of the missing moment document system
Remove 360's detection and modification of the default browser
LinkedList源码分享
Little data on how to learn?Jida latest small learning data review, 26 PDF page covers the 269 - page document small data learning theory, method and application are expounded
【Social Media Marketing】How to know if your WhatsApp is blocked?
LinkedList source code sharing
线程池处理异常的方法