PyTorch practice: handwritten digit recognition on the MNIST dataset
2022-07-05 20:59:00 【Sol-itude】
Original video: "Easy to learn PyTorch: handwritten digit recognition on MNIST"
1. Load the necessary libraries
# 1 Load the necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader  # needed to batch the datasets in step 4
from torchvision import datasets, transforms
2. Define the hyperparameters
# 2 Define the hyperparameters
BATCH_SIZE = 64  # number of samples in each training batch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU if one is available, otherwise the CPU
EPOCHS = 20  # number of training epochs
3. Build the pipeline for image preprocessing
# 3 Build the pipeline for image preprocessing
pipeline = transforms.Compose([
    transforms.ToTensor(),  # convert the image to a tensor with values in [0, 1]
    transforms.Normalize((0.1307,), (0.3081,))  # normalization (not regularization): subtract the MNIST mean and divide by the standard deviation
])
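To make the effect of the normalization step concrete, here is a minimal sketch in plain Python of the arithmetic that `transforms.Normalize` applies to each pixel, `(pixel - mean) / std`. The helper name `normalize` is made up for illustration; the mean and standard deviation are the values from the pipeline above.

```python
# Per-pixel standardization as applied by transforms.Normalize:
# output = (input - mean) / std
MNIST_MEAN = 0.1307  # value taken from the pipeline above
MNIST_STD = 0.3081   # value taken from the pipeline above


def normalize(pixel, mean=MNIST_MEAN, std=MNIST_STD):
    """Standardize a pixel that ToTensor has already scaled to [0, 1]."""
    return (pixel - mean) / std


black = normalize(0.0)  # background pixel
white = normalize(1.0)  # brightest stroke pixel
print(round(black, 4), round(white, 4))
```

A black background pixel ends up at roughly -0.42 and a fully white pixel at roughly 2.82, so the inputs are centred near zero instead of lying in [0, 1].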
4. Download and load the dataset
# 4 Download the datasets
train_set = datasets.MNIST("data", train=True, download=True, transform=pipeline)  # root folder, training split, download if missing, apply the pipeline
test_set = datasets.MNIST("data", train=False, download=True, transform=pipeline)  # test split
# Load the data
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)  # batches of 64 samples, shuffled every epoch
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)  # the order does not matter for evaluation
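After downloading, take a look at an image in the dataset
with open("absolute path of the MNIST file", "rb") as f:  # placeholder: replace with the path to the raw image file
    file = f.read()
image1 = list(file[16:16 + 784])  # skip the 16-byte header, then take the 28*28 = 784 pixel bytes of the first image
print(image1)
import cv2
import numpy as np
image1_np = np.array(image1, dtype=np.uint8).reshape(28, 28, 1)
print(image1_np.shape)
cv2.imwrite("digit.jpg", image1_np)  # save the image
Output: the pixel values and the shape (28, 28, 1) are printed, and the digit is saved as digit.jpg.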
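The offset 16 used above comes from the IDX image file layout: four big-endian 32-bit integers (magic number 2051, image count, rows, columns) followed by the raw pixel bytes. The sketch below parses that header instead of hard-coding the numbers. It assumes the file is the uncompressed image file; the function names and the in-memory `fake_file` are made up for illustration so the example runs without the dataset.

```python
import struct


def read_idx3_header(raw):
    """Parse the 16-byte header of an IDX3 image file (big-endian uint32 fields)."""
    magic, count, rows, cols = struct.unpack(">IIII", raw[:16])
    if magic != 2051:  # 0x00000803 marks unsigned-byte data with 3 dimensions
        raise ValueError("not an IDX3 image file")
    return count, rows, cols


def image_at(raw, index):
    """Return the pixels of image `index` as a flat list of ints in 0..255."""
    count, rows, cols = read_idx3_header(raw)
    if not 0 <= index < count:
        raise IndexError(index)
    size = rows * cols
    start = 16 + index * size
    return list(raw[start:start + size])


# Synthetic stand-in for a real file: two 28x28 images, one all 0 and one all 255.
fake_file = struct.pack(">IIII", 2051, 2, 28, 28) + bytes(784) + bytes([255]) * 784
header = read_idx3_header(fake_file)
first = image_at(fake_file, 0)
second = image_at(fake_file, 1)
print(header, len(first), second[0])
```

With a real file, `image_at(file, 0)` would return the same 784 values as the slice `file[16:16 + 784]` used above.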
5. Build the network model
# 5 Build the network model
class Digit(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)  # 1 input channel, 10 output channels, 5x5 kernel
        self.conv2 = nn.Conv2d(10, 20, kernel_size=3)  # 10 input channels, 20 output channels, 3x3 kernel
        self.fc1 = nn.Linear(20 * 10 * 10, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        input_size = x.size(0)  # batch size
        x = self.conv1(x)  # in: batch*1*28*28, out: batch*10*24*24 (28-5+1=24)
        x = F.relu(x)  # activation function, shape unchanged
        x = F.max_pool2d(x, 2, 2)  # 2x2 pooling window with stride 2, out: batch*10*12*12
        x = self.conv2(x)  # in: batch*10*12*12, out: batch*20*10*10 (12-3+1=10)
        x = F.relu(x)
        x = x.view(input_size, -1)  # flatten; -1 infers the dimension, 20*10*10=2000
        x = self.fc1(x)  # in: batch*2000, out: batch*500
        x = F.relu(x)
        x = self.fc2(x)  # in: batch*500, out: batch*10
        output = F.log_softmax(x, dim=1)  # log-probability of each digit class
        return output
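The sizes in the comments of step 5 follow from the standard output-size formula for an unpadded convolution or pooling window, `(size + 2*padding - kernel) // stride + 1`. As a quick check, the sketch below traces one side length through the layers of this network; the helper names are made up for illustration.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a convolution or pooling window (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_first_network(side=28):
    """Follow one spatial side length through the layers listed in step 5."""
    after_conv1 = conv_out(side, kernel=5)                  # 28 -> 24
    after_pool = conv_out(after_conv1, kernel=2, stride=2)  # 24 -> 12
    after_conv2 = conv_out(after_pool, kernel=3)            # 12 -> 10
    flat = 20 * after_conv2 * after_conv2                   # 20 channels -> 2000 features
    return after_conv1, after_pool, after_conv2, flat


sizes = trace_first_network()
print(sizes)  # (24, 12, 10, 2000)
```

The last number is the input size that `fc1` must accept, which is why it is declared as `nn.Linear(20*10*10, 500)`.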
6. Define the optimizer
# 6 Create the model and define the optimizer
model = Digit().to(DEVICE)  # the constant is named DEVICE, not device
optimizer = optim.Adam(model.parameters())  # use the Adam optimizer
7. Define the training method
# 7 Define the training method
def train_model(model, device, train_loader, optimizer, epoch):
    # Put the model in training mode
    model.train()
    for batch_index, (data, target) in enumerate(train_loader):
        # Move the batch to the device
        data, target = data.to(device), target.to(device)
        # Reset the gradients to 0
        optimizer.zero_grad()
        # Forward pass: compute the predictions
        output = model(data)
        # Compute the loss
        loss = F.cross_entropy(output, target)  # cross-entropy loss for multi-class classification
        # Backpropagation
        loss.backward()
        # Update the parameters
        optimizer.step()
        if batch_index % 3000 == 0:
            print("Train Epoch :{} \t Loss :{:.6f}".format(epoch, loss.item()))
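One detail of the training step is worth spelling out: the network already returns `log_softmax` output, and `F.cross_entropy` applies `log_softmax` again internally before taking the negative log-likelihood. The loss is still correct because `log_softmax` is idempotent, as the plain-Python sketch below shows for a single sample. The function names and the example scores are made up for illustration; pairing `log_softmax` with `F.nll_loss` would be the more common way to write the same thing.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax over a list of class scores."""
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]


def cross_entropy(scores, target):
    """Single-sample cross-entropy: log-softmax followed by negative log-likelihood."""
    return -log_softmax(scores)[target]


logits = [2.0, -1.0, 0.5]        # hypothetical raw scores for three classes
log_probs = log_softmax(logits)  # what the network's forward pass returns

loss_from_logits = cross_entropy(logits, 0)
loss_from_log_probs = cross_entropy(log_probs, 0)  # log_softmax applied twice
print(abs(loss_from_logits - loss_from_log_probs) < 1e-12)  # True
```

The second application changes nothing because the log-probabilities already sum to 1 after exponentiation, so their log-sum-exp is 0.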
8. Define the test method
# 8 Define the test method
def test_model(model, device, test_loader):
    # Put the model in evaluation mode
    model.eval()
    # Number of correct predictions
    correct = 0.0
    # Test loss
    test_loss = 0.0
    with torch.no_grad():  # no gradients are computed, so there is no backpropagation
        for data, target in test_loader:
            # Move the batch to the device
            data, target = data.to(device), target.to(device)
            # Run the model on the test data
            output = model(data)
            # Accumulate the test loss (cross_entropy returns the mean over the batch)
            test_loss += F.cross_entropy(output, target).item()
            # Index of the class with the highest probability
            pred = output.argmax(1)
            # Accumulate the number of correct predictions
            correct += pred.eq(target.view_as(pred)).sum().item()
    # Note: a sum of batch means is divided by the number of samples here, so the
    # reported value is roughly BATCH_SIZE times smaller than the per-sample mean loss
    test_loss /= len(test_loader.dataset)
    print("Test--Average Loss:{:.4f},Accuracy:{:.3f}\n".format(test_loss, 100.0 * correct / len(test_loader.dataset)))
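The "Average Loss" printed by the test method adds up one mean per batch and then divides by the number of samples, which scales the value down by about the batch size. The plain-Python sketch below shows the difference on made-up per-sample losses; in the PyTorch code, passing `reduction="sum"` to `F.cross_entropy` would accumulate per-sample sums instead, so that the final division gives a true average.

```python
def batches(items, batch_size):
    """Split a list into consecutive batches."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


def mean(values):
    return sum(values) / len(values)


per_sample = [0.5] * 8  # hypothetical per-sample losses
batch_size = 4

# As in the test method: sum of batch means divided by the number of samples.
sum_of_means = sum(mean(b) for b in batches(per_sample, batch_size)) / len(per_sample)

# With per-batch sums (reduction="sum"): a true per-sample average.
true_mean = sum(sum(b) for b in batches(per_sample, batch_size)) / len(per_sample)

print(sum_of_means, true_mean)  # 0.125 0.5
```

The accuracy figure is not affected, because it counts correct predictions directly.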
9. Run training and testing
# 9 Run training and testing
for epoch in range(1, EPOCHS + 1):
    train_model(model, DEVICE, train_loader, optimizer, epoch)
    test_model(model, DEVICE, test_loader)
Output results
Train Epoch :1 Loss :2.296158
Train Epoch :1 Loss :0.023645
Test--Average Loss:0.0027,Accuracy:98.690
Train Epoch :2 Loss :0.035262
Train Epoch :2 Loss :0.002957
Test--Average Loss:0.0027,Accuracy:98.750
Train Epoch :3 Loss :0.029884
Train Epoch :3 Loss :0.000642
Test--Average Loss:0.0032,Accuracy:98.460
Train Epoch :4 Loss :0.002866
Train Epoch :4 Loss :0.003708
Test--Average Loss:0.0033,Accuracy:98.720
Train Epoch :5 Loss :0.000039
Train Epoch :5 Loss :0.000145
Test--Average Loss:0.0026,Accuracy:98.840
Train Epoch :6 Loss :0.000124
Train Epoch :6 Loss :0.035326
Test--Average Loss:0.0054,Accuracy:98.450
Train Epoch :7 Loss :0.000014
Train Epoch :7 Loss :0.000001
Test--Average Loss:0.0044,Accuracy:98.510
Train Epoch :8 Loss :0.001491
Train Epoch :8 Loss :0.000045
Test--Average Loss:0.0031,Accuracy:99.140
Train Epoch :9 Loss :0.000428
Train Epoch :9 Loss :0.000000
Test--Average Loss:0.0056,Accuracy:98.500
Train Epoch :10 Loss :0.000001
Train Epoch :10 Loss :0.000377
Test--Average Loss:0.0042,Accuracy:98.930
Summary and improvement
The teacher in the video explains things well, but does not say why the network is built this way, so I went back and reviewed CNNs. The following network structure also works:
# 5 Build the network model
class Digit(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.fc1 = nn.Linear(20 * 4 * 4, 10)

    def forward(self, x):
        input_size = x.size(0)  # batch size
        x = self.conv1(x)  # in: batch*1*28*28, out: batch*10*24*24 (28-5+1=24)
        x = F.relu(x)  # activation function, shape unchanged
        x = F.max_pool2d(x, 2, 2)  # 2x2 pooling window with stride 2, out: batch*10*12*12
        x = self.conv2(x)  # in: batch*10*12*12, out: batch*20*8*8 (12-5+1=8)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)  # out: batch*20*4*4
        x = x.view(input_size, -1)  # flatten; -1 infers the dimension, 20*4*4=320
        x = self.fc1(x)  # in: batch*320, out: batch*10
        x = F.relu(x)  # note: this ReLU acts on the final class scores, clamping negative scores to 0
        output = F.log_softmax(x, dim=1)  # log-probability of each digit class
        return output
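In this second network, `F.relu` is applied to the output of the last linear layer, just before `log_softmax`. One consequence can be checked numerically: when the score of the correct class is negative before the ReLU, the loss no longer changes with that score, so no gradient flows back through it. The plain-Python sketch below estimates the slope with a finite difference, using made-up scores and hypothetical helper names. Whether this plays a part in the behaviour reported below is not established here; it is only one thing that may be worth checking.

```python
import math


def loss(pre_activations, target, final_relu):
    """Negative log-likelihood of `target`, optionally with ReLU on the final scores."""
    scores = [max(0.0, p) for p in pre_activations] if final_relu else list(pre_activations)
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return -(scores[target] - log_sum)


def slope(pre_activations, target, index, final_relu, h=1e-4):
    """Central finite-difference estimate of d(loss)/d(pre_activations[index])."""
    up = list(pre_activations)
    down = list(pre_activations)
    up[index] += h
    down[index] -= h
    return (loss(up, target, final_relu) - loss(down, target, final_relu)) / (2 * h)


pre = [-1.0, 2.0]  # hypothetical scores; the correct class 0 has a negative score
slope_with_relu = slope(pre, target=0, index=0, final_relu=True)
slope_without_relu = slope(pre, target=0, index=0, final_relu=False)
print(slope_with_relu, round(slope_without_relu, 3))
```

Without the final ReLU the slope is about -0.95, meaning gradient descent would raise the correct class's score; with it, the slope is exactly 0.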
With this network, the accuracy suddenly dropped to 50%, which puzzled me. I guessed that the optimizer might be the problem, so I replaced it with SGD, and the results improved:
Train Epoch :17 Loss :0.014693
Train Epoch :17 Loss :0.000051
Test--Average Loss:0.0026,Accuracy:99.010
By epoch 17 the accuracy had reached 99%. I do not know why yet, so I am leaving this as an open question and will fill it in once I understand it.
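As background for the optimizer swap, the sketch below compares the size of a single parameter update under plain SGD and under Adam, in plain Python. It is not an explanation of the result above. The Adam constants are the PyTorch defaults (`lr=1e-3`, `betas=(0.9, 0.999)`, `eps=1e-8`); the SGD learning rate of 0.01 is an assumption, since the post does not state which value was used, and the function names are made up for illustration.

```python
import math


def sgd_step(w, g, lr=0.01):
    """Plain SGD update; lr=0.01 is an assumed value, not taken from the post."""
    return w - lr * g


def adam_first_step(w, g, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """First Adam update (t=1) for one parameter, with bias correction."""
    m = (1 - beta1) * g        # first-moment estimate
    v = (1 - beta2) * g * g    # second-moment estimate
    m_hat = m / (1 - beta1)    # bias-corrected, equals g at t=1
    v_hat = v / (1 - beta2)    # bias-corrected, equals g*g at t=1
    return w - lr * m_hat / (math.sqrt(v_hat) + eps)


for g in (100.0, 0.01):
    sgd_move = abs(sgd_step(0.0, g))
    adam_move = abs(adam_first_step(0.0, g))
    print(g, sgd_move, round(adam_move, 6))
```

The SGD step scales with the gradient, while the first Adam step is close to `lr` whatever the gradient's size, so the two optimizers can move the same parameters very differently from the same starting point.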