当前位置:网站首页>Pytorch实战——MNIST数据集手写数字识别
Pytorch实战——MNIST数据集手写数字识别
2022-07-05 20:55:00 【Sol-itude】
原视频链接:轻松学Pytorch手写字体识别MNIST
1.加载必要的库
#1 加载必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms
2.定义超参数
#2 定义超参数
BATCH_SIZE= 64 #每次训练的数据的个数
DEVICE=torch.device("cuda"if torch.cuda.is_available()else"cpu")#做一个判断,如果有gpu就用gpu,如果没有的话就用cpu
EPOCHS=20 #训练轮数
3.构建pipeline,对图像做处理
#3 构建pipeline,对图像做处理
pipeline=transforms.Compose([
transforms.ToTensor(),#将图片转换成tensor类型
transforms.Normalize((0.1307,),(0.3081,))#正则化,模型过拟合时,降低模型复杂度
])
4.下载、加载数据集
#下载数据集
train_set=datasets.MNIST("data",train=True,download=True,transform=pipeline)#文件夹,需要训练,需要下载,要转换成tensor
test_set=datasets.MNIST("data",train=False,download=True,transform=pipeline)
#加载数据
train_loader=DataLoader(train_set,batch_size=BATCH_SIZE,shuffle=True)#加载训练集,数据个数为64,需要打乱
test_loader=DataLoader(test_set,batch_size=BATCH_SIZE,shuffle=True)
下载完毕之后,看一眼数据集内的图片
with open("MNIST的绝对路径","rb") as f:
file=f.read()
image1=[int(str(item).encode('ascii'),10)for item in file[16: 16+784]]
print(image1)
import cv2
import numpy as np
image1_np=np.array(image1,dtype=np.uint8).reshape(28,28,1)
print(image1_np.shape)
cv2.imwrite("digit.jpg",image1_np)#保存图片
输出结果
5.构建网络模型
#5 构建网络模型
class Digit(nn.Module):
def __init__(self):
super().__init__()
self.conv1= nn.Conv2d(1,10,kernel_size=5)
self.conv2=nn.Conv2d(10,20,kernel_size=3)
self.fc1=nn.Linear(20*10*10,500)
self.fc2=nn.Linear(500,10)
def forward(self,x):
input_size=x.size(0)
x=self.conv1(x)#输入batch_size 1*28*28 ,输出batch 10*24*24(28-5+1=24)
x=F.relu(x) #激活函数,保持shape不变
x=F.max_pool2d(x,2,2)#2是池化步长的大小,输出batch20*12*12
x=self.conv2(x)#输入batch20*12*12 输出batch 20*10*10 (12-3+1=10)
x=F.relu(x)
x=x.view(input_size,-1)#拉平,-1自动计算维度 20*10*10=2000
x=self.fc1(x)#输入batch 2000 输出batch 500
x=F.relu(x)
x=self.fc2(x)#输出batch 500 输出batch 10
ouput=F.log_softmax(x,dim=1)#计算分类后每个数字概率
return ouput
6.定义优化器
model =Digit().to(device)
optimizer=optim.Adam(model.parameters())##选择adam优化器
7.定义训练方法
#7 定义训练方法
def train_model(model,device,train_loader,optimizer,epoch):
#模型训练
model.train()
for batch_index,(data,target) in enumerate(train_loader):
#部署到DEVICE上去
data,target=data.to(device), target.to(device)
#梯度初始化为0
optimizer.zero_grad()
#预测,训练后结果
output=model(data)
#计算损失
loss = F.cross_entropy(output,target)#多分类用交叉验证
#反向传播
loss.backward()
#参数优化
optimizer.step()
if batch_index%3000==0:
print("Train Epoch :{} \t Loss :{:.6f}".format(epoch,loss.item()))
8.定义测试方法
#8 定义测试方法
def test_model(model,device,test_loader):
#模型验证
model.eval()
#正确率
corrcet=0.0
#测试损失
test_loss=0.0
with torch.no_grad(): #不会计算梯度,也不会进行反向传播
for data,target in test_loader:
#部署到device上
data,target=data.to(device),target.to(device)
#测试数据
output=model(data)
#计算测试损失
test_loss+=F.cross_entropy(output,target).item()
#找到概率最大的下标
pred=output.argmax(1)
#累计正确的值
corrcet+=pred.eq(target.view_as(pred)).sum().item()
test_loss/=len(test_loader.dataset)
print("Test--Average Loss:{:.4f},Accuarcy:{:.3f}\n".format(test_loss,100.0 * corrcet / len(test_loader.dataset)))
9.调用方法
#9 调用方法
for epoch in range(1,EPOCHS+1):
train_model(model,DEVICE,train_loader,optimizer,epoch)
test_model(model,DEVICE,test_loader)
输出结果
Train Epoch :1 Loss :2.296158
Train Epoch :1 Loss :0.023645
Test--Average Loss:0.0027,Accuarcy:98.690
Train Epoch :2 Loss :0.035262
Train Epoch :2 Loss :0.002957
Test--Average Loss:0.0027,Accuarcy:98.750
Train Epoch :3 Loss :0.029884
Train Epoch :3 Loss :0.000642
Test--Average Loss:0.0032,Accuarcy:98.460
Train Epoch :4 Loss :0.002866
Train Epoch :4 Loss :0.003708
Test--Average Loss:0.0033,Accuarcy:98.720
Train Epoch :5 Loss :0.000039
Train Epoch :5 Loss :0.000145
Test--Average Loss:0.0026,Accuarcy:98.840
Train Epoch :6 Loss :0.000124
Train Epoch :6 Loss :0.035326
Test--Average Loss:0.0054,Accuarcy:98.450
Train Epoch :7 Loss :0.000014
Train Epoch :7 Loss :0.000001
Test--Average Loss:0.0044,Accuarcy:98.510
Train Epoch :8 Loss :0.001491
Train Epoch :8 Loss :0.000045
Test--Average Loss:0.0031,Accuarcy:99.140
Train Epoch :9 Loss :0.000428
Train Epoch :9 Loss :0.000000
Test--Average Loss:0.0056,Accuarcy:98.500
Train Epoch :10 Loss :0.000001
Train Epoch :10 Loss :0.000377
Test--Average Loss:0.0042,Accuarcy:98.930
总结和改进
看完视频之后,老师确实讲得好,但是却没有讲明白为什么网络结构为什么要这样搭建,于是我又去看了看CNN,这个网络结构也能实现
#5 构建网络模型
class Digit(nn.Module):
def __init__(self):
super().__init__()
self.conv1= nn.Conv2d(1,10,5)
self.conv2=nn.Conv2d(10,20,5)
self.fc1=nn.Linear(20*4*4,10)
def forward(self,x):
input_size=x.size(0)
x=self.conv1(x)#输入batch_size 1*28*28 ,输出batch 10*24*24(28-5+1=24)
x=F.relu(x) #激活函数,保持shape不变
x=F.max_pool2d(x,2,2)#2是池化步长的大小,输出batch20*12*12
x=self.conv2(x)#输入batch20*12*12 输出batch 20*10*10 (12-3+1=10)
x=F.relu(x)
x=F.max_pool2d(x,2,2)
x=x.view(input_size,-1)#拉平,-1自动计算维度 20*10*10=2000
x=self.fc1(x)#输入batch 2000 输出batch 500
x=F.relu(x)
ouput=F.log_softmax(x,dim=1)#计算分类后每个数字概率
return ouput
用CNN之后,发现准确度一下子下降到了50%,百思不得其解,我猜可能是优化器的问题,就把优化器换成了SGD,结果果然效果更好
Train Epoch :17 Loss :0.014693
Train Epoch :17 Loss :0.000051
Test--Average Loss:0.0026,Accuarcy:99.010
在第17轮准确率居然到了99%,不知道为什么,先挖个坑,等我以后研究明白再来填
边栏推荐
- 示波器探头对测量带宽的影响
- 基于flask写一个接口
- 获取前一天的js(时间戳转换)
- wpf 获取datagrid 中指定行列的DataGridTemplateColumn中的控件
- Binary search
- Abnova e (diii) (WNV) recombinant protein Chinese and English instructions
- Abnova丨CRISPR SpCas9 多克隆抗体方案
- Monorepo management methodology and dependency security
- Interpreting the daily application functions of cooperative robots
- Clear app data and get Icon
猜你喜欢
Abnova丨血液总核酸纯化试剂盒预装相关说明书
CADD course learning (7) -- Simulation of target and small molecule interaction (semi flexible docking autodock)
示波器探头对信号源阻抗的影响
Abbkine丨TraKine F-actin染色试剂盒(绿色荧光)方案
Abnova丨DNA 标记高质量控制测试方案
AI 从代码中自动生成注释文档
教你自己训练的pytorch模型转caffe(二)
2.<tag-哈希表, 字符串>补充: 剑指 Offer 50. 第一个只出现一次的字符 dbc
Use of form text box (II) input filtering (synthetic event)
Specification of protein quantitative kit for abbkine BCA method
随机推荐
使用WebAssembly在浏览器端操作Excel
NPDP如何续证?操作指南来了!
渗透创客精神文化转化的创客教育
LeetCode: Distinct Subsequences [115]
基于flask写一个接口
获取前一天的js(时间戳转换)
Abnova e (diii) (WNV) recombinant protein Chinese and English instructions
Abnova丨 CD81单克隆抗体相关参数和应用
SQL series (basic) - Chapter 2 limiting and sorting data
Abnova maxpab mouse derived polyclonal antibody solution
Is the securities account given by the school of Finance and business safe? Can I open an account?
最长摆动序列[贪心练习]
解析五育融合之下的steam教育模式
Prior knowledge of machine learning in probability theory (Part 1)
树莓派4B上ncnn转换出来的模型调用时总是崩溃(Segment Fault)的原因
从架构上详解技术(SLB,Redis,Mysql,Kafka,Clickhouse)的各类热点问题
示波器探头对信号源阻抗的影响
浅聊我和一些编程语言的缘分
Make Jar, Not War
Binary search