PyTorch in Practice: Handwritten Digit Recognition on the MNIST Dataset
2022-07-05 20:55:00 【Sol-itude】
Original video: 轻松学Pytorch手写字体识别MNIST (roughly, "Learn PyTorch Easily: MNIST Handwritten Digit Recognition")
1. Load the required libraries
# 1 Load the required libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader  # DataLoader is used in step 4
from torchvision import datasets, transforms
2. Define the hyperparameters
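# 2 Define the hyperparameters
BATCH_SIZE = 64  # number of samples per training batch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU if one is available, otherwise the CPU
EPOCHS = 20  # number of training epochs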
3. Build a transform pipeline to preprocess the images
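# 3 Build a transform pipeline to preprocess the images
pipeline = transforms.Compose([
    transforms.ToTensor(),  # convert the image to a float tensor with values in [0, 1]
    transforms.Normalize((0.1307,), (0.3081,))  # standardize: (x - mean) / std, using the MNIST mean 0.1307 and std 0.3081
])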
4. Download and load the dataset
# Download the dataset
train_set = datasets.MNIST("data", train=True, download=True, transform=pipeline)  # root folder, training split, download if missing, apply the pipeline
test_set = datasets.MNIST("data", train=False, download=True, transform=pipeline)
# Load the data
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)  # batches of 64 samples, reshuffled every epoch
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)  # evaluation does not need shuffling
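A DataLoader groups samples into batches of shape [batch, channels, height, width]. The sketch below illustrates this with a hypothetical stand-in dataset (1000 random 1x28x28 tensors wrapped in TensorDataset) so that it runs without downloading MNIST.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE = 64
N = 1000  # size of the hypothetical stand-in dataset

images = torch.rand(N, 1, 28, 28)      # random "images"
labels = torch.randint(0, 10, (N,))    # random digit labels 0-9
loader = DataLoader(TensorDataset(images, labels), batch_size=BATCH_SIZE, shuffle=True)

data, target = next(iter(loader))
print(data.shape, target.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])

# Without drop_last, the number of batches is ceil(N / BATCH_SIZE),
# and the last batch holds the remainder: 1000 - 15 * 64 = 40 samples.
sizes = [x.shape[0] for x, _ in loader]
print(len(loader), math.ceil(N / BATCH_SIZE), sizes[-1])  # 16 16 40
```

Applied to the real training set, the same rule gives ceil(60000 / 64) = 938 batches per epoch.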
Once the download finishes, take a look at one of the images in the dataset:
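import cv2
import numpy as np

# Raw MNIST image file, e.g. train-images-idx3-ubyte under data/MNIST/raw
with open("<absolute path to the MNIST image file>", "rb") as f:
    file = f.read()
# Skip the 16-byte header; the next 28*28 = 784 bytes are the pixels of the first image
image1 = list(file[16:16 + 784])
print(image1)
image1_np = np.array(image1, dtype=np.uint8).reshape(28, 28, 1)
print(image1_np.shape)
cv2.imwrite("digit.jpg", image1_np)  # save the image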
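The offset of 16 used above comes from the IDX file header. Assuming the standard IDX image layout (four big-endian 32-bit integers: magic number 2051, image count, rows, columns, then one unsigned byte per pixel), a minimal sketch that parses the header with Python's struct module looks like this. The helper name parse_idx_images and the 2-image fake file are hypothetical, built in memory so the sketch runs on its own.

```python
import struct

def parse_idx_images(buf: bytes):
    """Hypothetical helper: parse an in-memory IDX image file (raw MNIST format)."""
    # Header: four big-endian unsigned 32-bit ints = 16 bytes,
    # which is why the snippet above starts reading at offset 16.
    magic, count, rows, cols = struct.unpack(">IIII", buf[:16])
    if magic != 2051:
        raise ValueError(f"not an IDX image file (magic={magic})")
    size = rows * cols
    # Pixels follow the header, one unsigned byte per pixel, image after image.
    images = [list(buf[16 + i * size:16 + (i + 1) * size]) for i in range(count)]
    return count, rows, cols, images

# Tiny fake file: 2 images of 2x2 pixels.
fake = struct.pack(">IIII", 2051, 2, 2, 2) + bytes([0, 255, 128, 64, 1, 2, 3, 4])
count, rows, cols, images = parse_idx_images(fake)
print(count, rows, cols, images)  # 2 2 2 [[0, 255, 128, 64], [1, 2, 3, 4]]
```

For the real training image file the header reads 60000 images of 28x28 pixels, so each image occupies 784 bytes.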
5. Build the network model
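# 5 Build the network model
class Digit(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)  # 1 input channel (grayscale), 10 output channels, 5x5 kernel
        self.conv2 = nn.Conv2d(10, 20, kernel_size=3)  # 10 input channels, 20 output channels, 3x3 kernel
        self.fc1 = nn.Linear(20 * 10 * 10, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        input_size = x.size(0)  # batch size
        x = self.conv1(x)  # input: batch*1*28*28, output: batch*10*24*24 (28-5+1=24)
        x = F.relu(x)  # activation function, shape unchanged
        x = F.max_pool2d(x, 2, 2)  # 2x2 pooling window with stride 2, output: batch*10*12*12
        x = self.conv2(x)  # input: batch*10*12*12, output: batch*20*10*10 (12-3+1=10)
        x = F.relu(x)
        x = x.view(input_size, -1)  # flatten; -1 infers the size: 20*10*10 = 2000
        x = self.fc1(x)  # input: batch*2000, output: batch*500
        x = F.relu(x)
        x = self.fc2(x)  # input: batch*500, output: batch*10
        output = F.log_softmax(x, dim=1)  # log-probability of each digit class
        return output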
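With no padding and stride 1, a convolution with kernel size k shrinks each spatial side from n to n - k + 1, and a 2x2 max-pool with stride 2 halves it. A quick way to check the shape comments in forward, and the 2000 input features of fc1, is to push a dummy batch through the same layers (re-created here so the sketch is self-contained) and record the shapes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# The two conv layers of Digit, re-created so the sketch runs on its own.
conv1 = nn.Conv2d(1, 10, kernel_size=5)
conv2 = nn.Conv2d(10, 20, kernel_size=3)

x = torch.zeros(2, 1, 28, 28)  # dummy batch: 2 grayscale 28x28 images
shapes = [tuple(x.shape)]
x = F.relu(conv1(x))           # 28 - 5 + 1 = 24
shapes.append(tuple(x.shape))
x = F.max_pool2d(x, 2, 2)      # 24 / 2 = 12
shapes.append(tuple(x.shape))
x = F.relu(conv2(x))           # 12 - 3 + 1 = 10
shapes.append(tuple(x.shape))
x = x.view(x.size(0), -1)      # 20 * 10 * 10 = 2000
shapes.append(tuple(x.shape))

for s in shapes:
    print(s)
```

The printed shapes are (2, 1, 28, 28), (2, 10, 24, 24), (2, 10, 12, 12), (2, 20, 10, 10) and (2, 2000).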
6. Define the optimizer
model = Digit().to(DEVICE)  # DEVICE is the constant defined in step 2
optimizer = optim.Adam(model.parameters())  # use the Adam optimizer (default lr=0.001)
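model.parameters() hands the optimizer every weight and bias in the network. Counting them per layer (with the layers of Digit re-created below so the sketch stands alone) shows where the parameters live:

```python
import torch.nn as nn

# Layers of the Digit model; only these hold trainable parameters.
layers = {
    "conv1": nn.Conv2d(1, 10, kernel_size=5),    # 10*1*5*5 + 10
    "conv2": nn.Conv2d(10, 20, kernel_size=3),   # 20*10*3*3 + 20
    "fc1": nn.Linear(20 * 10 * 10, 500),         # 2000*500 + 500
    "fc2": nn.Linear(500, 10),                   # 500*10 + 10
}
counts = {name: sum(p.numel() for p in layer.parameters()) for name, layer in layers.items()}
total = sum(counts.values())
print(counts)  # {'conv1': 260, 'conv2': 1820, 'fc1': 1000500, 'fc2': 5010}
print(total)   # 1007590
```

Almost all of the roughly one million parameters sit in fc1, the fully connected layer that receives the 2000 flattened features.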
7. Define the training function
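# 7 Define the training function
def train_model(model, device, train_loader, optimizer, epoch):
    # Put the model in training mode
    model.train()
    for batch_index, (data, target) in enumerate(train_loader):
        # Move the batch to the device
        data, target = data.to(device), target.to(device)
        # Reset the gradients to zero
        optimizer.zero_grad()
        # Forward pass: predictions for this batch
        output = model(data)
        # Compute the loss; cross-entropy is the standard loss for multi-class classification.
        # The model already returns log_softmax, so F.nll_loss is the matching choice;
        # F.cross_entropy gives the same value here because log_softmax is idempotent.
        loss = F.cross_entropy(output, target)
        # Backpropagation
        loss.backward()
        # Update the parameters
        optimizer.step()
        # With BATCH_SIZE = 64 there are ceil(60000 / 64) = 938 batches per epoch,
        # so this condition is only true at batch 0
        if batch_index % 3000 == 0:
            print("Train Epoch :{} \t Loss :{:.6f}".format(epoch, loss.item()))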
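Digit.forward already applies log_softmax, and F.cross_entropy applies log_softmax internally as well. Applying it twice changes nothing, because log-probabilities already have a logsumexp of 0, so the three losses below agree. Random logits stand in for real model outputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)                # raw scores: batch of 4, 10 classes
target = torch.tensor([3, 0, 9, 1])
log_probs = F.log_softmax(logits, dim=1)   # what Digit.forward returns

loss_nll = F.nll_loss(log_probs, target)                  # the loss designed for log-probabilities
loss_ce_on_logprobs = F.cross_entropy(log_probs, target)  # what train_model computes
loss_ce_on_logits = F.cross_entropy(logits, target)       # cross_entropy on raw logits

print(loss_nll.item(), loss_ce_on_logprobs.item(), loss_ce_on_logits.item())
```

Because the values coincide, the training code works as written; using F.nll_loss simply makes the pairing with log_softmax explicit.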
8. Define the test function
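# 8 Define the test function
def test_model(model, device, test_loader):
    # Put the model in evaluation mode
    model.eval()
    # Number of correct predictions
    correct = 0.0
    # Accumulated test loss
    test_loss = 0.0
    with torch.no_grad():  # no gradients are computed and no backpropagation happens
        for data, target in test_loader:
            # Move the batch to the device
            data, target = data.to(device), target.to(device)
            # Forward pass on the test data
            output = model(data)
            # Sum the per-sample losses so that dividing by the dataset size gives a true average
            test_loss += F.cross_entropy(output, target, reduction="sum").item()
            # Index of the highest probability = predicted digit
            pred = output.argmax(1)
            # Count the correct predictions
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print("Test--Average Loss:{:.4f},Accuracy:{:.3f}\n".format(test_loss, 100.0 * correct / len(test_loader.dataset)))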
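The evaluation loop counts matches with argmax and eq, and summing per-sample losses before dividing by the dataset size yields the true per-sample average even when the last batch is smaller. A sketch with two hand-made, hypothetical batches of 3 and 1 samples:

```python
import torch
import torch.nn.functional as F

# Two hypothetical batches of unequal size, like the last batch of an epoch.
batches = [
    (F.log_softmax(torch.tensor([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]), dim=1),
     torch.tensor([0, 1, 0])),
    (F.log_softmax(torch.tensor([[0.0, 0.0, 3.0]]), dim=1), torch.tensor([2])),
]
n = sum(t.numel() for _, t in batches)

correct = 0
sum_loss = 0.0
for output, target in batches:
    pred = output.argmax(1)                   # index of the largest log-probability
    correct += pred.eq(target).sum().item()   # count matches in this batch
    sum_loss += F.cross_entropy(output, target, reduction="sum").item()

# Reference: the mean loss over all samples at once.
all_out = torch.cat([o for o, _ in batches])
all_tgt = torch.cat([t for _, t in batches])
print(correct, n)  # 3 4
print(abs(sum_loss / n - F.cross_entropy(all_out, all_tgt).item()) < 1e-6)  # True
```

Summing each batch's mean loss and then dividing by the number of samples, as the default reduction="mean" would do inside such a loop, does not give this per-sample average.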
9. Run training and testing
# 9 Run training and testing
for epoch in range(1, EPOCHS + 1):
    train_model(model, DEVICE, train_loader, optimizer, epoch)
    test_model(model, DEVICE, test_loader)
Output:
Train Epoch :1 Loss :2.296158
Train Epoch :1 Loss :0.023645
Test--Average Loss:0.0027,Accuarcy:98.690
Train Epoch :2 Loss :0.035262
Train Epoch :2 Loss :0.002957
Test--Average Loss:0.0027,Accuarcy:98.750
Train Epoch :3 Loss :0.029884
Train Epoch :3 Loss :0.000642
Test--Average Loss:0.0032,Accuarcy:98.460
Train Epoch :4 Loss :0.002866
Train Epoch :4 Loss :0.003708
Test--Average Loss:0.0033,Accuarcy:98.720
Train Epoch :5 Loss :0.000039
Train Epoch :5 Loss :0.000145
Test--Average Loss:0.0026,Accuarcy:98.840
Train Epoch :6 Loss :0.000124
Train Epoch :6 Loss :0.035326
Test--Average Loss:0.0054,Accuarcy:98.450
Train Epoch :7 Loss :0.000014
Train Epoch :7 Loss :0.000001
Test--Average Loss:0.0044,Accuarcy:98.510
Train Epoch :8 Loss :0.001491
Train Epoch :8 Loss :0.000045
Test--Average Loss:0.0031,Accuarcy:99.140
Train Epoch :9 Loss :0.000428
Train Epoch :9 Loss :0.000000
Test--Average Loss:0.0056,Accuarcy:98.500
Train Epoch :10 Loss :0.000001
Train Epoch :10 Loss :0.000377
Test--Average Loss:0.0042,Accuarcy:98.930
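The test accuracy in this log does not rise steadily from epoch to epoch. A small sketch that extracts the test lines (copied verbatim from the log, including the original spelling "Accuarcy") and finds the best epochs:

```python
import re

# Test lines copied from the 10-epoch log above.
LOG = """\
Test--Average Loss:0.0027,Accuarcy:98.690
Test--Average Loss:0.0027,Accuarcy:98.750
Test--Average Loss:0.0032,Accuarcy:98.460
Test--Average Loss:0.0033,Accuarcy:98.720
Test--Average Loss:0.0026,Accuarcy:98.840
Test--Average Loss:0.0054,Accuarcy:98.450
Test--Average Loss:0.0044,Accuarcy:98.510
Test--Average Loss:0.0031,Accuarcy:99.140
Test--Average Loss:0.0056,Accuarcy:98.500
Test--Average Loss:0.0042,Accuarcy:98.930
"""

pattern = re.compile(r"Average Loss:([\d.]+),Accuarcy:([\d.]+)")
# (epoch, test loss, test accuracy) for every test line, in order
results = [(epoch, float(m.group(1)), float(m.group(2)))
           for epoch, m in enumerate(pattern.finditer(LOG), start=1)]

best_acc = max(results, key=lambda r: r[2])
best_loss = min(results, key=lambda r: r[1])
print(best_acc)   # (8, 0.0031, 99.14)
print(best_loss)  # (5, 0.0026, 98.84)
```

The highest accuracy (99.14%) appears at epoch 8 and the lowest test loss at epoch 5, while the training losses printed in later epochs are already close to 0.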
Summary and improvements
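After watching the video: the instructor does explain things well, but never makes clear why the network is built this way. So I went and read up on CNNs, and this network structure works too: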
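# 5 Build the network model
class Digit(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.fc1 = nn.Linear(20 * 4 * 4, 10)

    def forward(self, x):
        input_size = x.size(0)
        x = self.conv1(x)  # input: batch*1*28*28, output: batch*10*24*24 (28-5+1=24)
        x = F.relu(x)  # activation function, shape unchanged
        x = F.max_pool2d(x, 2, 2)  # 2x2 pooling with stride 2, output: batch*10*12*12
        x = self.conv2(x)  # input: batch*10*12*12, output: batch*20*8*8 (12-5+1=8)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)  # output: batch*20*4*4
        x = x.view(input_size, -1)  # flatten: 20*4*4 = 320
        x = self.fc1(x)  # input: batch*320, output: batch*10
        x = F.relu(x)  # note: ReLU on the final scores zeroes every negative score and its gradient; output layers usually omit it
        output = F.log_softmax(x, dim=1)  # log-probability of each digit class
        return output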
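The same shape check applied to this second network confirms the 20 * 4 * 4 = 320 input features of fc1 and shows how much smaller the model is. The layers are re-created below so the sketch runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Layers of the second Digit model.
conv1 = nn.Conv2d(1, 10, 5)
conv2 = nn.Conv2d(10, 20, 5)
fc1 = nn.Linear(20 * 4 * 4, 10)

x = torch.zeros(1, 1, 28, 28)               # dummy batch with one image
x = F.max_pool2d(F.relu(conv1(x)), 2, 2)   # 28 - 5 + 1 = 24, pooled to 12
print(tuple(x.shape))                       # (1, 10, 12, 12)
x = F.max_pool2d(F.relu(conv2(x)), 2, 2)   # 12 - 5 + 1 = 8, pooled to 4
print(tuple(x.shape))                       # (1, 20, 4, 4)
flat = x.view(x.size(0), -1)                # 20 * 4 * 4 = 320
out = fc1(flat)
print(tuple(flat.shape), tuple(out.shape))  # (1, 320) (1, 10)

# Parameters: 260 (conv1) + 5020 (conv2) + 3210 (fc1)
n_params = sum(p.numel() for layer in (conv1, conv2, fc1) for p in layer.parameters())
print(n_params)                             # 8490
```

For comparison, the fc1 layer of the first network alone has 2000 * 500 + 500 = 1,000,500 parameters.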
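With this CNN the accuracy suddenly dropped to 50%, and I could not figure out why. I guessed the optimizer might be the problem, so I switched to SGD, and the results were indeed better: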
Train Epoch :17 Loss :0.014693
Train Epoch :17 Loss :0.000051
Test--Average Loss:0.0026,Accuarcy:99.010
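By epoch 17 the accuracy had actually reached 99%. I don't know why yet, so I'll leave this as an open question and come back to it once I've figured it out.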
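Switching to SGD only changes the optimizer line from step 6. The post does not say which learning rate or momentum was used, so lr=0.01 and momentum=0.9 below are hypothetical values, and a tiny nn.Linear stands in for Digit so the sketch runs on its own. Passing lr explicitly keeps the call valid across PyTorch versions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Stand-in model; in the post this would be Digit().to(DEVICE).
model = nn.Linear(4, 3)

# Hypothetical hyperparameters, not taken from the post.
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# One training step, following the same pattern as train_model.
torch.manual_seed(0)
data = torch.randn(8, 4)
target = torch.randint(0, 3, (8,))
optimizer.zero_grad()
loss = F.cross_entropy(model(data), target)
loss.backward()
optimizer.step()
print(type(optimizer).__name__, optimizer.param_groups[0]["lr"])  # SGD 0.01
```

The rest of the training and test code stays the same; only the optimizer object passed to train_model changes.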