PyTorch in Practice: Handwritten Digit Recognition on the MNIST Dataset
2022-07-05 20:55:00 【Sol-itude】
Original video: 轻松学Pytorch手写字体识别MNIST
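1. Load the required libraries
# 1 Load the required libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader  # needed to build the data loaders in step 4
from torchvision import datasets, transforms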
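2. Define the hyperparameters
# 2 Define the hyperparameters
BATCH_SIZE = 64  # number of samples per training batch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU if one is available, otherwise the CPU
EPOCHS = 20  # number of training epochs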
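3. Build the pipeline for image preprocessing
# 3 Build the pipeline for image preprocessing
pipeline = transforms.Compose([
    transforms.ToTensor(),  # convert the image to a tensor with values in [0, 1]
    transforms.Normalize((0.1307,), (0.3081,))  # normalization (not regularization): subtract the mean 0.1307, divide by the std 0.3081
])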
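To make the `Normalize` step concrete, the sketch below reproduces its per-pixel arithmetic in plain Python. The helper name `normalize_pixel` is made up for illustration and is not part of torchvision; it assumes the pixel has already been scaled to [0, 1], as `ToTensor()` does for 8-bit images.

```python
# Per-pixel arithmetic of Normalize((0.1307,), (0.3081,)): (value - mean) / std.
# normalize_pixel is a hypothetical helper for illustration, not a torchvision function.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

def normalize_pixel(value, mean=MNIST_MEAN, std=MNIST_STD):
    """Standardize one pixel that is already scaled to the range [0, 1]."""
    return (value - mean) / std

if __name__ == "__main__":
    for raw in (0, 128, 255):
        scaled = raw / 255  # ToTensor() maps 8-bit values to [0, 1]
        print(raw, round(normalize_pixel(scaled), 4))
```

A pixel equal to the dataset mean maps to 0, black background pixels become slightly negative, and white strokes become positive.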
4. Download and load the dataset
# Download the dataset
train_set = datasets.MNIST("data", train=True, download=True, transform=pipeline)  # target folder, training split, download if missing, apply the pipeline
test_set = datasets.MNIST("data", train=False, download=True, transform=pipeline)
# Load the data
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)  # training set in shuffled batches of 64
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True)
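Once the download finishes, take a look at one image from the dataset
with open("absolute path to the uncompressed MNIST image file", "rb") as f:
    file = f.read()
# Skip the 16-byte header; the next 28 * 28 = 784 bytes are the pixels of the first image
image1 = list(file[16:16 + 784])
print(image1)
import cv2
import numpy as np
image1_np = np.array(image1, dtype=np.uint8).reshape(28, 28, 1)
print(image1_np.shape)
cv2.imwrite("digit.jpg", image1_np)  # save the image
Output: the pixel list and the shape (28, 28, 1) are printed, and the first digit is saved as digit.jpg.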
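The offset of 16 used when reading the file comes from the IDX image layout: four big-endian 32-bit integers (magic number, image count, rows, columns) precede the pixel data. As a rough sketch, the header can be parsed with the standard `struct` module. The function names and the synthetic bytes are illustrative assumptions, chosen so the example runs without the real file.

```python
import struct

def parse_idx_image_header(raw):
    """Read the 16-byte header of an IDX image file (hypothetical helper)."""
    # Four big-endian unsigned 32-bit integers: magic, count, rows, cols.
    magic, count, rows, cols = struct.unpack(">IIII", raw[:16])
    if magic != 2051:  # 0x00000803 marks an unsigned-byte, 3-dimensional IDX file
        raise ValueError("not an IDX image file")
    return count, rows, cols

def first_image_pixels(raw):
    """Return the pixel values of the first image as a flat list of ints."""
    _, rows, cols = parse_idx_image_header(raw)
    return list(raw[16:16 + rows * cols])

# Synthetic stand-in for the real file: a header plus one all-zero 28x28 image.
fake_file = struct.pack(">IIII", 2051, 1, 28, 28) + bytes(28 * 28)
print(parse_idx_image_header(fake_file))
print(len(first_image_pixels(fake_file)))
```

Reading the row and column counts from the header avoids hard-coding 784, at the cost of a few extra lines.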
5. Build the network model
# 5 Build the network model
class Digit(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=3)
        self.fc1 = nn.Linear(20 * 10 * 10, 500)
        self.fc2 = nn.Linear(500, 10)
    def forward(self, x):
        input_size = x.size(0)
        x = self.conv1(x)  # in: batch x 1 x 28 x 28, out: batch x 10 x 24 x 24 (28 - 5 + 1 = 24)
        x = F.relu(x)  # activation, shape unchanged
        x = F.max_pool2d(x, 2, 2)  # 2x2 pooling with stride 2, out: batch x 10 x 12 x 12
        x = self.conv2(x)  # in: batch x 10 x 12 x 12, out: batch x 20 x 10 x 10 (12 - 3 + 1 = 10)
        x = F.relu(x)
        x = x.view(input_size, -1)  # flatten; -1 infers the dimension, 20 * 10 * 10 = 2000
        x = self.fc1(x)  # in: batch x 2000, out: batch x 500
        x = F.relu(x)
        x = self.fc2(x)  # in: batch x 500, out: batch x 10
        output = F.log_softmax(x, dim=1)  # log-probability of each digit class
        return output
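The shape comments in the network follow the usual output-size formula for a convolution or pooling window, floor((size + 2 * padding - kernel) / stride) + 1. The sketch below traces that arithmetic for this network in plain Python; `conv_out` and `flattened_features` are illustrative names, not PyTorch functions.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a convolution or pooling window on a square input."""
    return (size + 2 * padding - kernel) // stride + 1

def flattened_features(side=28):
    """Trace the spatial size through conv1 -> pool -> conv2 of the network above."""
    side = conv_out(side, kernel=5)            # conv1: 28 -> 24
    side = conv_out(side, kernel=2, stride=2)  # max pool: 24 -> 12
    side = conv_out(side, kernel=3)            # conv2: 12 -> 10
    channels = 20                              # output channels of conv2
    return channels * side * side

print(flattened_features())
```

The result is the input width that `fc1` has to accept, which is why it is declared as `nn.Linear(20*10*10, 500)`.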
6. Define the optimizer
model = Digit().to(DEVICE)  # move the model to the selected device
optimizer = optim.Adam(model.parameters())  # use the Adam optimizer
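7. Define the training function
# 7 Define the training function
def train_model(model, device, train_loader, optimizer, epoch):
    # Put the model in training mode
    model.train()
    for batch_index, (data, target) in enumerate(train_loader):
        # Move the batch to the device
        data, target = data.to(device), target.to(device)
        # Reset the gradients to zero
        optimizer.zero_grad()
        # Forward pass: predictions for this batch
        output = model(data)
        # Compute the loss
        loss = F.cross_entropy(output, target)  # cross-entropy (not cross-validation) for multi-class classification
        # Backpropagation
        loss.backward()
        # Update the parameters
        optimizer.step()
        if batch_index % 3000 == 0:
            print("Train Epoch :{} \t Loss :{:.6f}".format(epoch, loss.item()))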
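How often the loss is printed depends on the batch size. MNIST has 60,000 training images, so with `BATCH_SIZE = 64` an epoch has 938 batches and the `% 3000` check fires only on batch 0. The run logged later in this post prints two loss lines per epoch, which would be consistent with a smaller batch size; the value 16 below is only a guess used for illustration, and `logged_batches` is a made-up helper.

```python
import math

def logged_batches(dataset_size, batch_size, every=3000):
    """Batch indices for which `batch_index % every == 0` within one epoch."""
    # DataLoader keeps the last partial batch by default, hence the ceiling.
    batches_per_epoch = math.ceil(dataset_size / batch_size)
    return [i for i in range(batches_per_epoch) if i % every == 0]

print(logged_batches(60000, 64))  # 938 batches per epoch
print(logged_batches(60000, 16))  # 3750 batches per epoch (16 is an assumed value)
```

Lowering the interval, for example to every 100 batches, gives a more useful view of the loss curve at batch size 64.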
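8. Define the test function
# 8 Define the test function
def test_model(model, device, test_loader):
    # Put the model in evaluation mode
    model.eval()
    # Number of correct predictions
    correct = 0.0
    # Accumulated test loss
    test_loss = 0.0
    with torch.no_grad():  # no gradients are computed and no backpropagation happens
        for data, target in test_loader:
            # Move the batch to the device
            data, target = data.to(device), target.to(device)
            # Run the model on the test data
            output = model(data)
            # Accumulate the test loss (cross_entropy returns the mean over the batch by default)
            test_loss += F.cross_entropy(output, target).item()
            # Index of the highest log-probability
            pred = output.argmax(1)
            # Accumulate the number of correct predictions
            correct += pred.eq(target.view_as(pred)).sum().item()
    # Summed batch means divided by the dataset size understate the per-sample loss by roughly the batch size
    test_loss /= len(test_loader.dataset)
    print("Test--Average Loss:{:.4f},Accuracy:{:.3f}\n".format(test_loss, 100.0 * correct / len(test_loader.dataset)))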
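One detail of the test loop is worth spelling out: `F.cross_entropy` returns the mean loss of a batch by default, so adding those means and dividing by the number of samples yields a value smaller than the true per-sample average by about a factor of the batch size. The plain-Python sketch below shows the effect on made-up numbers; the function names are illustrative. In PyTorch, passing `reduction="sum"` to `F.cross_entropy` would make the division by the dataset size a true average.

```python
def sum_of_batch_means_over_n(losses, batch_size):
    """What the test loop computes: sum of per-batch mean losses divided by N."""
    batches = [losses[i:i + batch_size] for i in range(0, len(losses), batch_size)]
    return sum(sum(b) / len(b) for b in batches) / len(losses)

def true_mean(losses):
    """Per-sample average: sum of all losses divided by N."""
    return sum(losses) / len(losses)

sample_losses = [0.2, 0.4, 0.6, 0.8]  # made-up per-sample losses
print(sum_of_batch_means_over_n(sample_losses, batch_size=2))
print(true_mean(sample_losses))
```

This is why the "Average Loss" values in the log are so small compared with the training loss; the accuracy figure is not affected.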
9. Run training and testing
# 9 Run training and testing
for epoch in range(1, EPOCHS + 1):
    train_model(model, DEVICE, train_loader, optimizer, epoch)
    test_model(model, DEVICE, test_loader)
Output
Train Epoch :1 Loss :2.296158
Train Epoch :1 Loss :0.023645
Test--Average Loss:0.0027,Accuracy:98.690
Train Epoch :2 Loss :0.035262
Train Epoch :2 Loss :0.002957
Test--Average Loss:0.0027,Accuracy:98.750
Train Epoch :3 Loss :0.029884
Train Epoch :3 Loss :0.000642
Test--Average Loss:0.0032,Accuracy:98.460
Train Epoch :4 Loss :0.002866
Train Epoch :4 Loss :0.003708
Test--Average Loss:0.0033,Accuracy:98.720
Train Epoch :5 Loss :0.000039
Train Epoch :5 Loss :0.000145
Test--Average Loss:0.0026,Accuracy:98.840
Train Epoch :6 Loss :0.000124
Train Epoch :6 Loss :0.035326
Test--Average Loss:0.0054,Accuracy:98.450
Train Epoch :7 Loss :0.000014
Train Epoch :7 Loss :0.000001
Test--Average Loss:0.0044,Accuracy:98.510
Train Epoch :8 Loss :0.001491
Train Epoch :8 Loss :0.000045
Test--Average Loss:0.0031,Accuracy:99.140
Train Epoch :9 Loss :0.000428
Train Epoch :9 Loss :0.000000
Test--Average Loss:0.0056,Accuracy:98.500
Train Epoch :10 Loss :0.000001
Train Epoch :10 Loss :0.000377
Test--Average Loss:0.0042,Accuracy:98.930
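Summary and improvements
The instructor in the video explains things well, but never makes clear why the network is structured this way. So I went and read up on CNNs; the following network structure also works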
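# 5 Build the network model
class Digit(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.fc1 = nn.Linear(20 * 4 * 4, 10)
    def forward(self, x):
        input_size = x.size(0)
        x = self.conv1(x)  # in: batch x 1 x 28 x 28, out: batch x 10 x 24 x 24 (28 - 5 + 1 = 24)
        x = F.relu(x)  # activation, shape unchanged
        x = F.max_pool2d(x, 2, 2)  # 2x2 pooling with stride 2, out: batch x 10 x 12 x 12
        x = self.conv2(x)  # in: batch x 10 x 12 x 12, out: batch x 20 x 8 x 8 (12 - 5 + 1 = 8)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)  # out: batch x 20 x 4 x 4
        x = x.view(input_size, -1)  # flatten; -1 infers the dimension, 20 * 4 * 4 = 320
        x = self.fc1(x)  # in: batch x 320, out: batch x 10
        x = F.relu(x)  # note: this ReLU clamps negative class scores to 0 before log_softmax
        output = F.log_softmax(x, dim=1)  # log-probability of each digit class
        return output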
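After switching to this CNN, accuracy abruptly dropped to about 50%, which baffled me. I guessed the optimizer might be the problem, so I replaced it with SGD, and the results were indeed better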
Train Epoch :17 Loss :0.014693
Train Epoch :17 Loss :0.000051
Test--Average Loss:0.0026,Accuracy:99.010
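By epoch 17 the accuracy had reached 99%. I don't know why yet, so I'll leave this as an open question and come back to it once I've figured it out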
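One thing that may be worth checking, offered only as a hypothesis and not something confirmed here: the second network applies `F.relu` to the 10 class scores right before `log_softmax`. ReLU clamps every negative score to 0, so the model can no longer push a class below the others with a negative score. The plain-Python sketch below shows the effect on made-up scores; `log_softmax` and `relu` are small stand-ins written for this example, not the PyTorch functions.

```python
import math

def log_softmax(scores):
    """Numerically stable log-softmax over a list of floats."""
    peak = max(scores)
    log_sum = peak + math.log(sum(math.exp(s - peak) for s in scores))
    return [s - log_sum for s in scores]

def relu(scores):
    """Clamp negative values to zero, as F.relu does element-wise."""
    return [max(0.0, s) for s in scores]

logits = [-2.0, -0.5, -3.0]  # made-up scores where class 1 is clearly preferred
plain = log_softmax(logits)
clamped = log_softmax(relu(logits))

print(plain.index(max(plain)))  # class chosen without the extra ReLU
print(clamped)                  # the classes tie once negatives are clamped to 0
```

Whether this explains the drop to about 50% with Adam, or the recovery with SGD, would need to be tested by removing that final ReLU and retraining.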