当前位置:网站首页>《PyTorch深度学习实践》第九课多分类问题(手写数字MNIST)
《PyTorch深度学习实践》第九课多分类问题(手写数字MNIST)
2022-07-28 08:50:00 【falldeep】
b站刘二视频,地址:
《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili
代码(课中作业)MNIST手写数字识别,采用全连接的方式

运行结果
loss

准确率

代码
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
BATCH_SIZE = 64
TRANSFORM = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# tenser类型 归一化
train_set = datasets.MNIST(root='mnist', train=True, transform=TRANSFORM, download=True)
# 保存地址 训练集 若不存在则下载
train_loader = DataLoader(dataset=train_set, shuffle=True, batch_size=BATCH_SIZE)#loader
# 设置数据源 乱序
test_set = datasets.MNIST(root='mnist', train=False, transform=TRANSFORM, download=True)
test_loader = DataLoader(dataset=test_set, shuffle=False, batch_size=BATCH_SIZE)
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.linear1 = torch.nn.Linear(784, 512)
self.linear2 = torch.nn.Linear(512, 256)
self.linear3 = torch.nn.Linear(256, 128)
self.linear4 = torch.nn.Linear(128, 64)
self.linear5 = torch.nn.Linear(64, 10)
def forward(self, x):
x = x.view(-1, 784)#将其变成向量
x = F.relu(self.linear1(x))
x = F.relu(self.linear2(x))
x = F.relu(self.linear3(x))
x = F.relu(self.linear4(x))#采用relu做激活函数
return self.linear5(x) #最后一层不需要做激活
model = Net()
criteration = torch.nn.CrossEntropyLoss() #交叉熵损失函数
optimizor = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
#动量 优化
def train(): #训练
sum = 0
for i, data in enumerate(train_loader, 0):
inputs, lables = data
y_pred = model(inputs)
loss = criteration(y_pred, lables)
sum += loss.item()
optimizor.zero_grad()
loss.backward()
optimizor.step()
if i % 300 == 299:#每300次输出一次
sum /= 300
loss_lst.append(sum)
sum = 0
def test():
total = 0
correct = 0
with torch.no_grad():
for i, data in enumerate(test_loader, 0):
inputs, lables = data
outputs = model(inputs)
_, predicted = torch.max(outputs.data, dim=1) #找出每行最大的那个数字(就是预测结果)
correct += (predicted == lables).sum().item()#正确预测的数量
total += lables.size(0)
accurate_lst.append(correct / total * 100)
if __name__ == '__main__':
loss_lst = []
accurate_lst = []
for epoch in range(10):
train()
test()
#可视化画图
num_lst = [i for i in range(len(loss_lst))]
plt.plot(num_lst, loss_lst)
plt.xlabel('i')
plt.ylabel('loss')
plt.show()
num_lst = [i for i in range(len(accurate_lst))]
plt.plot(num_lst, accurate_lst)
plt.xlabel('i')
plt.ylabel('accurate')
plt.show()
边栏推荐
- 从开发转测试:我从零开始,一干就是6年的自动化测试历程
- Face warp - hand tear code
- The new mode of 3D panoramic display has become the key to breaking the game
- A perfect cross compilation environment records the shell scripts generated by PETA
- v-bind指令的详细介绍
- Informatics Olympiad all in one 1617: circle game | 1875: [13noip improvement group] circle game | Luogu p1965 [noip2013 improvement group] circle game
- 2022安全员-C证特种作业证考试题库及答案
- 12 common design ideas of design for failure
- DAPP safety summary and typical safety incident analysis
- Dn-detr paper accuracy, and analyze its model structure & 2022 CVPR paper
猜你喜欢
![[附下载]推荐几款暴力破解和字典生成的工具](/img/c6/f4a9c566ff21a8e133a8a991108201.png)
[附下载]推荐几款暴力破解和字典生成的工具

js数组去重,id相同对某值相加合并

训练一个自己的分类 | 【包教包会,数据都准备好了】

IDC脚本文件运行

golang升级到1.18.4版本 遇到的问题

2022 safety officer-b certificate examination simulated 100 questions and answers

MQTT. JS introductory tutorial: learning notes

Detailed introduction of v-bind instruction

Alibaba cloud server setup and pagoda panel connection

From development to testing: I started from scratch and worked for six years of automated testing
随机推荐
关闭页面时向后台发送消息
[附下载]推荐几款暴力破解和字典生成的工具
力扣题(1)—— 两数之和
376. Swing sequence [greedy, dynamic planning -----]
golang升级到1.18.4版本 遇到的问题
剑指offer
ES6 let and Const
数据泄漏、删除事件频发,企业应如何构建安全防线?
【592. 分数加减运算】
【多线程】println方法底层原理
一款入门神器TensorFlowPlayground
Design for failure常见的12种设计思想
【高数】高数平面立体几何
Title and answer of work permit for safety management personnel of hazardous chemical business units in 2022
快速上手Flask(一) 认识框架Flask、项目结构、开发环境
10. Learn MySQL like clause
DN-DETR 论文精度,并解析其模型结构 & 2022年CVPR论文
[solution] error in [eslint] eslint is not a constructor
Activiti启报错: Cannot create PoolableConnectionFactory (Could not create connection to database server
Oracle-11gR2默认的系统JOB