当前位置:网站首页>Pytorch框架学习记录12——完整的模型训练套路
Pytorch框架学习记录12——完整的模型训练套路
2022-08-01 20:36:00 【柚子Roo】
Pytorch框架学习记录12——完整的模型训练套路
本次模型的主要使用CIFIAR10数据集,搭建了CIFIAR 10模型。
首先,需要对数据集进行下载读取,并进行分组。
# 读取数据集
trainset = torchvision.datasets.CIFAR10("dataset", train=True, transform=torchvision.transforms.ToTensor(), download=True)
testset = torchvision.datasets.CIFAR10("dataset", train=False, transform=torchvision.transforms.ToTensor(), download=True)
# 数据分组
train_loader = DataLoader(trainset, 64)
test_loader = DataLoader(testset, 64)
搭建模型,并在train.py文件中导入已构建好的模型,定义优化器、损失函数和相关的一些变量。
model.py
import torch
from torch import nn
class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
nn.MaxPool2d(kernel_size=2),
nn.Flatten(),
nn.Linear(in_features=1024, out_features=64),
nn.Linear(in_features=64, out_features=10)
)
def forward(self, input):
output = self.model(input)
return output
if __name__ == '__main__':
input = torch.ones((64, 3, 32, 32))
test = Test()
output = test(input)
print(output.shape)
train.py文件
# 创建网络模型
test = Test()
# 定义损失函数
loss_fn = torch.nn.CrossEntropyLoss()
# 定义优化器
learning_rate = 0.001
optimizer = torch.optim.SGD(params=test.parameters(), lr=learning_rate)
# 记录训练次数
train_step = 0
# 记录测试次数
test_step = 0
writer = SummaryWriter("logs")
接下来,对模型进行训练,
# 训练
epoch = 20
for i in range(epoch):
# 训练步骤
print("-------第 {} 轮训练-------".format(i+1))
test.train()
for train_data in train_loader:
imgs, target = train_data
output = test(imgs)
loss = loss_fn(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_step += 1
if train_step % 100 == 0:
print("第 {} 次训练完成 训练损失:{}".format(train_step, loss.item()))
writer.add_scalar("train_loss", loss.item(), train_step)
在测试集上进行测试,
# 测试步骤
test.eval()
test_loss_sum = 0.0
total_accuracy = 0
with torch.no_grad():
for test_data in test_loader:
imgs, target = test_data
output = test(imgs)
loss = loss_fn(output, target)
accuracy = (output.argmax(1) == target).sum()
test_loss_sum += loss.item()
total_accuracy += accuracy
writer.add_scalar("test_loss", test_loss_sum, test_step)
print("在测试集上的Loss:{}, 正确率:{}".format(test_loss_sum, total_accuracy/len(testset)))
test_step += 1
【完整源代码】
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from Pytorch_Learning.model import *
# 读取数据集
trainset = torchvision.datasets.CIFAR10("dataset", train=True, transform=torchvision.transforms.ToTensor(), download=True)
testset = torchvision.datasets.CIFAR10("dataset", train=False, transform=torchvision.transforms.ToTensor(), download=True)
# 数据分组
train_loader = DataLoader(trainset, 64)
test_loader = DataLoader(testset, 64)
# 创建网络模型
test = Test()
# 定义损失函数
loss_fn = torch.nn.CrossEntropyLoss()
# 定义优化器
learning_rate = 0.001
optimizer = torch.optim.SGD(params=test.parameters(), lr=learning_rate)
# 记录训练次数
train_step = 0
# 记录测试次数
test_step = 0
writer = SummaryWriter("logs")
# 训练
epoch = 20
for i in range(epoch):
# 训练步骤
print("-------第 {} 轮训练-------".format(i+1))
test.train()
for train_data in train_loader:
imgs, target = train_data
output = test(imgs)
loss = loss_fn(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_step += 1
if train_step % 100 == 0:
print("第 {} 次训练完成 训练损失:{}".format(train_step, loss.item()))
writer.add_scalar("train_loss", loss.item(), train_step)
# 测试步骤
test.eval()
test_loss_sum = 0.0
total_accuracy = 0
with torch.no_grad():
for test_data in test_loader:
imgs, target = test_data
output = test(imgs)
loss = loss_fn(output, target)
accuracy = (output.argmax(1) == target).sum()
test_loss_sum += loss.item()
total_accuracy += accuracy
writer.add_scalar("test_loss", test_loss_sum, test_step)
print("在测试集上的Loss:{}, 正确率:{}".format(test_loss_sum, total_accuracy/len(testset)))
test_step += 1
边栏推荐
- [Energy Conservation Institute] Application of Intelligent Control Device in High Voltage Switchgear
- Software you should know as a programmer
- 任务调度线程池基本介绍
- 有点奇怪!访问目的网址,主机能容器却不行
- LTE time domain and frequency domain resources
- Batch get protein .pdb files based on Uniprot ID/PDB ID
- 数据库内核面试中我不会的问题(1)
- 1374. 生成每种字符都是奇数个的字符串 : 简单构造模拟题
- Fork/Join线程池
- Redis does web page UV statistics
猜你喜欢

有点奇怪!访问目的网址,主机能容器却不行

SIPp installation and use

数据库单字段存储多个标签(位移操作)

Use WeChat official account to send information to designated WeChat users

启明云端分享|盘点ESP8684开发板有哪些功能
![[Energy Conservation Institute] Ankerui Food and Beverage Fume Monitoring Cloud Platform Helps Fight Air Pollution](/img/ca/e67c8e2196adb5a078acc44ba5ad6f.jpg)
[Energy Conservation Institute] Ankerui Food and Beverage Fume Monitoring Cloud Platform Helps Fight Air Pollution

Debug一个ECC的ODP数据源

【节能学院】智能操控装置在高压开关柜的应用
![[Energy Conservation Institute] Application of Intelligent Control Device in High Voltage Switchgear](/img/6d/05233ce5c91a612b6247ea07d7982e.jpg)
[Energy Conservation Institute] Application of Intelligent Control Device in High Voltage Switchgear

【多任务学习】Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts KDD18
随机推荐
tiup mirror genkey
The Internet giant development process
面试突击70:什么是粘包和半包?怎么解决?
仿牛客论坛项目
【Untitled】
Little data on how to learn?Jida latest small learning data review, 26 PDF page covers the 269 - page document small data learning theory, method and application are expounded
大神经验:软件测试的自我发展规划
The configuration manual for the secondary development of the XE training system of the missing moment document system
有点奇怪!访问目的网址,主机能容器却不行
Imitation cattle forum project
"No title"
线程池处理异常的方法
数据库单字段存储多个标签(位移操作)
Protocol Buffer 使用
Goroutine Leaks - The Forgotten Sender
微信小程序云开发|个人博客小程序
【ES】ES2021 我学不动了,这次只学 3 个。
使用常见问题解答软件的好处有哪些?
数据库内核面试中我不会的问题(1)
字符串