当前位置:网站首页>《PyTorch深度学习实践》第九课多分类问题(手写数字MNIST)
《PyTorch深度学习实践》第九课多分类问题(手写数字MNIST)
2022-07-28 08:50:00 【falldeep】
b站刘二视频,地址:
《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili
代码(课中作业)MNIST手写数字识别,采用全连接的方式

运行结果
loss

准确率

代码
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
BATCH_SIZE = 64
TRANSFORM = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# tenser类型 归一化
train_set = datasets.MNIST(root='mnist', train=True, transform=TRANSFORM, download=True)
# 保存地址 训练集 若不存在则下载
train_loader = DataLoader(dataset=train_set, shuffle=True, batch_size=BATCH_SIZE)#loader
# 设置数据源 乱序
test_set = datasets.MNIST(root='mnist', train=False, transform=TRANSFORM, download=True)
test_loader = DataLoader(dataset=test_set, shuffle=False, batch_size=BATCH_SIZE)
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.linear1 = torch.nn.Linear(784, 512)
self.linear2 = torch.nn.Linear(512, 256)
self.linear3 = torch.nn.Linear(256, 128)
self.linear4 = torch.nn.Linear(128, 64)
self.linear5 = torch.nn.Linear(64, 10)
def forward(self, x):
x = x.view(-1, 784)#将其变成向量
x = F.relu(self.linear1(x))
x = F.relu(self.linear2(x))
x = F.relu(self.linear3(x))
x = F.relu(self.linear4(x))#采用relu做激活函数
return self.linear5(x) #最后一层不需要做激活
model = Net()
criteration = torch.nn.CrossEntropyLoss() #交叉熵损失函数
optimizor = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
#动量 优化
def train(): #训练
sum = 0
for i, data in enumerate(train_loader, 0):
inputs, lables = data
y_pred = model(inputs)
loss = criteration(y_pred, lables)
sum += loss.item()
optimizor.zero_grad()
loss.backward()
optimizor.step()
if i % 300 == 299:#每300次输出一次
sum /= 300
loss_lst.append(sum)
sum = 0
def test():
total = 0
correct = 0
with torch.no_grad():
for i, data in enumerate(test_loader, 0):
inputs, lables = data
outputs = model(inputs)
_, predicted = torch.max(outputs.data, dim=1) #找出每行最大的那个数字(就是预测结果)
correct += (predicted == lables).sum().item()#正确预测的数量
total += lables.size(0)
accurate_lst.append(correct / total * 100)
if __name__ == '__main__':
loss_lst = []
accurate_lst = []
for epoch in range(10):
train()
test()
#可视化画图
num_lst = [i for i in range(len(loss_lst))]
plt.plot(num_lst, loss_lst)
plt.xlabel('i')
plt.ylabel('loss')
plt.show()
num_lst = [i for i in range(len(accurate_lst))]
plt.plot(num_lst, accurate_lst)
plt.xlabel('i')
plt.ylabel('accurate')
plt.show()
边栏推荐
- Realize batch data enhancement | use of keras imagedatagenerator
- RGB-T追踪——【多模态融合】Visible-Thermal UAV Tracking: A Large-Scale Benchmark and New Baseline
- FPGA development learning open source website summary
- The cooperation between starfish OS and metabell is just the beginning
- [English postgraduate entrance examination vocabulary training camp] day 15 - analyze, general, avoid, surveillance, compared
- Sentinel
- How to use gbase C API in multithreaded environment?
- MySQL 8.0.30 GA
- An entry artifact tensorflowplayground
- Map of China province > City > level > District > Town > village 5-level linkage download [2019 and 2021]
猜你喜欢

Bluetooth technology | the total scale of charging piles in Beijing will reach 700000 in 2025. Talk about the indissoluble relationship between Bluetooth and charging piles

Detailed introduction of v-bind instruction

Implementation principle of golang synergy

Recommend an artifact to get rid of the entanglement of variable names and a method to modify file names in batches

Sentinel
![[one flower, one world - Professor Zheng Yi - the way of simplicity] interpretable neural network](/img/fd/8ae7c00061491ad78a0fd68b7c21b0.png)
[one flower, one world - Professor Zheng Yi - the way of simplicity] interpretable neural network

Promise实例如何解决地狱回调

An entry artifact tensorflowplayground

mysql 最大建议行数2000w,靠谱吗?

【C语言】详解顺序表(SeqList)
随机推荐
【高数】高数平面立体几何
Which system table is the keyword of SQL Server in?
【多线程】println方法底层原理
Deconstruction assignment of ES6 variables
[附下载]推荐几款暴力破解和字典生成的工具
Machine learning (11) -- time series analysis
【杂谈】程序员的发展最需要两点能力
2022高压电工考试模拟100题及模拟考试
19c SYSAUX表空间SQLOBJ$PLAN表过大,如何清理
VR全景拍摄,助力民宿多元化宣传
对话MySQL之父:代码一次性完成才是优秀程序员
【leetcode周赛总结】LeetCode第 83场双周赛(7.23)
Rgb-t tracking: [multimodal fusion] visible thermal UAV tracking: a large scale benchmark and new baseline
Magic brace- [group theory] [Burnside lemma] [matrix fast power]
Technology sharing | quick intercom integrated dispatching system
Activiti startup error: cannot create poolableconnectionfactory (could not create connection to database server
数据库核心体系
Hou Jie STL standard library and generic programming
LeetCode_ 406_ Rebuild the queue based on height
IP protocol of network layer