当前位置:网站首页>Pytorch框架学习记录12——完整的模型训练套路
Pytorch框架学习记录12——完整的模型训练套路
2022-08-01 20:36:00 【柚子Roo】
Pytorch框架学习记录12——完整的模型训练套路
本次模型的主要使用CIFIAR10数据集,搭建了CIFIAR 10模型。
首先,需要对数据集进行下载读取,并进行分组。
# 读取数据集
trainset = torchvision.datasets.CIFAR10("dataset", train=True, transform=torchvision.transforms.ToTensor(), download=True)
testset = torchvision.datasets.CIFAR10("dataset", train=False, transform=torchvision.transforms.ToTensor(), download=True)
# 数据分组
train_loader = DataLoader(trainset, 64)
test_loader = DataLoader(testset, 64)
搭建模型,并在train.py文件中导入已构建好的模型,定义优化器、损失函数和相关的一些变量。
model.py
import torch
from torch import nn
class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
nn.MaxPool2d(kernel_size=2),
nn.Flatten(),
nn.Linear(in_features=1024, out_features=64),
nn.Linear(in_features=64, out_features=10)
)
def forward(self, input):
output = self.model(input)
return output
if __name__ == '__main__':
input = torch.ones((64, 3, 32, 32))
test = Test()
output = test(input)
print(output.shape)
train.py文件
# 创建网络模型
test = Test()
# 定义损失函数
loss_fn = torch.nn.CrossEntropyLoss()
# 定义优化器
learning_rate = 0.001
optimizer = torch.optim.SGD(params=test.parameters(), lr=learning_rate)
# 记录训练次数
train_step = 0
# 记录测试次数
test_step = 0
writer = SummaryWriter("logs")
接下来,对模型进行训练,
# 训练
epoch = 20
for i in range(epoch):
# 训练步骤
print("-------第 {} 轮训练-------".format(i+1))
test.train()
for train_data in train_loader:
imgs, target = train_data
output = test(imgs)
loss = loss_fn(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_step += 1
if train_step % 100 == 0:
print("第 {} 次训练完成 训练损失:{}".format(train_step, loss.item()))
writer.add_scalar("train_loss", loss.item(), train_step)
在测试集上进行测试,
# 测试步骤
test.eval()
test_loss_sum = 0.0
total_accuracy = 0
with torch.no_grad():
for test_data in test_loader:
imgs, target = test_data
output = test(imgs)
loss = loss_fn(output, target)
accuracy = (output.argmax(1) == target).sum()
test_loss_sum += loss.item()
total_accuracy += accuracy
writer.add_scalar("test_loss", test_loss_sum, test_step)
print("在测试集上的Loss:{}, 正确率:{}".format(test_loss_sum, total_accuracy/len(testset)))
test_step += 1
【完整源代码】
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from Pytorch_Learning.model import *
# 读取数据集
trainset = torchvision.datasets.CIFAR10("dataset", train=True, transform=torchvision.transforms.ToTensor(), download=True)
testset = torchvision.datasets.CIFAR10("dataset", train=False, transform=torchvision.transforms.ToTensor(), download=True)
# 数据分组
train_loader = DataLoader(trainset, 64)
test_loader = DataLoader(testset, 64)
# 创建网络模型
test = Test()
# 定义损失函数
loss_fn = torch.nn.CrossEntropyLoss()
# 定义优化器
learning_rate = 0.001
optimizer = torch.optim.SGD(params=test.parameters(), lr=learning_rate)
# 记录训练次数
train_step = 0
# 记录测试次数
test_step = 0
writer = SummaryWriter("logs")
# 训练
epoch = 20
for i in range(epoch):
# 训练步骤
print("-------第 {} 轮训练-------".format(i+1))
test.train()
for train_data in train_loader:
imgs, target = train_data
output = test(imgs)
loss = loss_fn(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_step += 1
if train_step % 100 == 0:
print("第 {} 次训练完成 训练损失:{}".format(train_step, loss.item()))
writer.add_scalar("train_loss", loss.item(), train_step)
# 测试步骤
test.eval()
test_loss_sum = 0.0
total_accuracy = 0
with torch.no_grad():
for test_data in test_loader:
imgs, target = test_data
output = test(imgs)
loss = loss_fn(output, target)
accuracy = (output.argmax(1) == target).sum()
test_loss_sum += loss.item()
total_accuracy += accuracy
writer.add_scalar("test_loss", test_loss_sum, test_step)
print("在测试集上的Loss:{}, 正确率:{}".format(test_loss_sum, total_accuracy/len(testset)))
test_step += 1
边栏推荐
猜你喜欢
【Dart】dart之mixin探究
[Energy Conservation Institute] Ankerui Food and Beverage Fume Monitoring Cloud Platform Helps Fight Air Pollution
C语言实现-直接插入排序(带图详解)
【多任务学习】Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts KDD18
】 【 nn. The Parameter () to generate and why do you want to initialize
AQS原理和介绍
【torch】张量乘法:matmul,einsum
启明云端分享|盘点ESP8684开发板有哪些功能
vant实现Select效果--单选和多选
Get started quickly with MongoDB
随机推荐
string
用户身份标识与账号体系实践
通俗解释:什么是临床预测模型
数据库单字段存储多个标签(位移操作)
徒步,治好了我的精神内耗
SIPp 安装及使用
专利检索常用的网站有哪些?
【无标题】
Interview Blitz 70: What are sticky packs and half packs?How to deal with it?
我的驾照考试笔记(3)
tiup mirror genkey
Postman 批量测试接口详细教程
Redis 做网页UV统计
宝塔搭建PESCMS-Ticket开源客服工单系统源码实测
What is the difference between a utility model patent and an invention patent?Understand in seconds!
仿牛客论坛项目
Zheng Xiangling, Chairman of Tide Pharmaceuticals, won the "2022 Outstanding Influential Entrepreneur Award" Tide Pharmaceuticals won the "Corporate Social Responsibility Model Award"
实用新型专利和发明专利的区别?秒懂!
Excel advanced drawing techniques, 100 (22) - how to respectively the irregular data
idea插件generateAllSetMethod一键生成set/get方法以及bean对象转换