PyTorch practice: handwritten digit recognition on the MNIST dataset
2022-07-05 20:59:00 【Sol-itude】
Original video link: Easy to learn Pytorch Handwritten font recognition MNIST
1. Load the necessary libraries
#1 Load the necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
2. Define hyperparameters
#2 Define hyperparameters
BATCH_SIZE = 64  # number of samples in each training batch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU if one is available, otherwise the CPU
EPOCHS = 20  # number of training epochs
3. Build the transform pipeline (image preprocessing)
#3 Build the transform pipeline (image preprocessing)
pipeline = transforms.Compose([
    transforms.ToTensor(),  # convert the image to a float tensor with values in [0, 1]
    transforms.Normalize((0.1307,), (0.3081,))  # standardize with the MNIST mean (0.1307) and standard deviation (0.3081)
])
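To make the Normalize step concrete, here is a plain-Python sketch of the arithmetic the two transforms apply to a single grayscale pixel. The helper names `to_tensor_value` and `normalize_value` are hypothetical; 0.1307 and 0.3081 are the mean and standard deviation already used in the pipeline.

```python
# Plain-Python sketch of the per-pixel arithmetic behind ToTensor + Normalize.
MEAN, STD = 0.1307, 0.3081  # the MNIST statistics used in the pipeline above


def to_tensor_value(pixel):
    """ToTensor maps a uint8 pixel in [0, 255] to a float in [0, 1]."""
    return pixel / 255.0


def normalize_value(x, mean=MEAN, std=STD):
    """Normalize computes (x - mean) / std for each channel."""
    return (x - mean) / std


# A black pixel, a mid-gray pixel and a white pixel
for pixel in (0, 128, 255):
    print(pixel, round(normalize_value(to_tensor_value(pixel)), 4))
```

After this step the background (black) pixels sit slightly below zero and the strokes (white) sit well above it, which centers the inputs around zero; this is standardization of the inputs, not regularization of the model.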
4. Download and load the dataset
# Download the datasets
train_set = datasets.MNIST("data", train=True, download=True, transform=pipeline)  # root folder, training split, download if missing, apply the pipeline
test_set = datasets.MNIST("data", train=False, download=True, transform=pipeline)
# Load the data
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)  # load the training set in batches of 64, shuffled every epoch
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True)
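As a quick sanity check on what the loaders produce, the following sketch works out how many batches each one yields per epoch and how large the final, partial batch is. It assumes the standard MNIST split sizes (60,000 training and 10,000 test images) and the DataLoader default `drop_last=False`; the helper `batch_sizes` is hypothetical.

```python
def batch_sizes(num_samples, batch_size):
    """Sizes of the batches a DataLoader yields with the default drop_last=False."""
    full, rest = divmod(num_samples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


train_batches = batch_sizes(60000, 64)  # MNIST training split: 60,000 images
test_batches = batch_sizes(10000, 64)   # MNIST test split: 10,000 images
print(len(train_batches), train_batches[-1])
print(len(test_batches), test_batches[-1])
```

So one training epoch is 938 optimizer steps, and the last batch of each loader is smaller than `BATCH_SIZE`.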
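After the download finishes, take a look at one of the images in the dataset:
import cv2
import numpy as np

# Path to a raw MNIST image file (for example the train-images-idx3-ubyte file in the download folder)
with open("<absolute path to the MNIST file>", "rb") as f:
    file = f.read()
# The first 16 bytes are the file header; the next 28*28 = 784 bytes are the pixels of the first image
image1 = list(file[16:16 + 784])  # iterating over bytes already yields ints in 0..255
print(image1)

image1_np = np.array(image1, dtype=np.uint8).reshape(28, 28, 1)
print(image1_np.shape)
cv2.imwrite("digit.jpg", image1_np)  # save the image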
Output result (figure: the extracted digit saved as digit.jpg)
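The snippet above skips exactly 16 bytes because MNIST image files use the IDX3 format: a header of four big-endian unsigned 32-bit integers (magic number 2051, image count, rows, columns) followed by one byte per pixel. Here is a minimal stdlib-only parser that reads the header instead of hard-coding the offset; `read_idx_images` is a hypothetical helper, and it is exercised on a small in-memory fake file rather than the real download.

```python
import struct


def read_idx_images(data):
    """Parse IDX3 image data (the format of MNIST's *-images-idx3-ubyte files).

    Header: four big-endian unsigned 32-bit integers (magic number, image count,
    rows, columns), followed by one unsigned byte per pixel.
    """
    magic, count, rows, cols = struct.unpack(">IIII", data[:16])
    if magic != 2051:
        raise ValueError(f"not an IDX3 image file (magic={magic})")
    size = rows * cols
    images = []
    for i in range(count):
        start = 16 + i * size
        pixels = list(data[start:start + size])  # bytes iterate as ints 0..255
        images.append([pixels[r * cols:(r + 1) * cols] for r in range(rows)])
    return images


# Hypothetical in-memory file holding one 28x28 image with pixel values 0..255 repeating
fake = struct.pack(">IIII", 2051, 1, 28, 28) + bytes(i % 256 for i in range(784))
images = read_idx_images(fake)
print(len(images), len(images[0]), len(images[0][0]))
```

With the real file you would pass the bytes returned by `f.read()` instead of `fake`.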
5. Build the network model
#5 Build the network model
class Digit(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)   # 1 input channel (grayscale) -> 10 channels, 5x5 kernel
        self.conv2 = nn.Conv2d(10, 20, kernel_size=3)  # 10 -> 20 channels, 3x3 kernel
        self.fc1 = nn.Linear(20 * 10 * 10, 500)        # fully connected: 2000 -> 500
        self.fc2 = nn.Linear(500, 10)                  # fully connected: 500 -> 10 classes

    def forward(self, x):
        input_size = x.size(0)  # batch size
        x = self.conv1(x)  # input: batch*1*28*28, output: batch*10*24*24 (28-5+1=24)
        x = F.relu(x)  # activation function; shape unchanged
        x = F.max_pool2d(x, 2, 2)  # 2x2 pooling window with stride 2; output: batch*10*12*12
        x = self.conv2(x)  # input: batch*10*12*12, output: batch*20*10*10 (12-3+1=10)
        x = F.relu(x)
        x = x.view(input_size, -1)  # flatten; -1 infers the size 20*10*10 = 2000
        x = self.fc1(x)  # input: batch*2000, output: batch*500
        x = F.relu(x)
        x = self.fc2(x)  # input: batch*500, output: batch*10
        output = F.log_softmax(x, dim=1)  # log-probability of each digit class
        return output
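The shape comments in `forward` follow from the output-size formula for an unpadded convolution or pooling layer, (size - kernel) // stride + 1. A small sketch, assuming the default padding=0 and dilation=1 used in this model (`conv_out` is a hypothetical helper), reproduces the 2000 inputs expected by `fc1`:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a Conv2d / max-pool layer with dilation=1."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28
side = conv_out(side, 5)            # conv1, 5x5 kernel: 28 -> 24
side = conv_out(side, 2, stride=2)  # max_pool2d(x, 2, 2): 24 -> 12
side = conv_out(side, 3)            # conv2, 3x3 kernel: 12 -> 10
flat = 20 * side * side             # conv2 has 20 output channels
print(side, flat)
```

If you change a kernel size, rerunning this arithmetic tells you the new `in_features` for `fc1`.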
6. Define the optimizer
model = Digit().to(DEVICE)
optimizer = optim.Adam(model.parameters())  # use the Adam optimizer with its default settings
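Adam keeps two running averages for every parameter it receives from `model.parameters()`, so it helps to know how many parameters `Digit` has. This hand count (weights plus biases per layer) is a sketch; in PyTorch the same total should come out of `sum(p.numel() for p in model.parameters())`. The helpers `conv_params` and `linear_params` are hypothetical.

```python
def conv_params(in_ch, out_ch, k):
    """Conv2d: out_ch * in_ch * k * k weights plus one bias per output channel."""
    return out_ch * in_ch * k * k + out_ch


def linear_params(in_f, out_f):
    """Linear: out_f * in_f weights plus one bias per output feature."""
    return out_f * in_f + out_f


layers = {
    "conv1": conv_params(1, 10, 5),
    "conv2": conv_params(10, 20, 3),
    "fc1": linear_params(20 * 10 * 10, 500),
    "fc2": linear_params(500, 10),
}
total = sum(layers.values())
for name, n in layers.items():
    print(f"{name}: {n} ({100 * n / total:.1f}%)")
print("total:", total)
```

Almost all of the roughly one million parameters sit in `fc1`, which is what flattening a 20*10*10 feature map into 500 units costs.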
7. Define the training function
#7 Define the training function
def train_model(model, device, train_loader, optimizer, epoch):
    # Put the model in training mode
    model.train()
    for batch_index, (data, target) in enumerate(train_loader):
        # Move the batch to the device
        data, target = data.to(device), target.to(device)
        # Reset the gradients to 0
        optimizer.zero_grad()
        # Forward pass: predictions for this batch
        output = model(data)
        # Compute the loss: cross-entropy for multi-class classification
        # (the model already returns log_softmax, so F.nll_loss(output, target) gives the same value)
        loss = F.cross_entropy(output, target)
        # Backpropagation
        loss.backward()
        # Update the parameters
        optimizer.step()
        if batch_index % 3000 == 0:
            print("Train Epoch :{} \t Loss :{:.6f}".format(epoch, loss.item()))
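`Digit.forward` already ends in `log_softmax`, and `F.cross_entropy` applies `log_softmax` again before taking the negative log-likelihood. This plain-Python sketch (hypothetical helpers mirroring the per-sample math, not the PyTorch implementation) shows why that is harmless: the exponentials of log-probabilities already sum to 1, so a second `log_softmax` leaves them unchanged and the loss equals what `F.nll_loss` would return.

```python
import math


def log_softmax(xs):
    """Numerically stable log-softmax: x_i - log(sum_j exp(x_j))."""
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


def nll(log_probs, target):
    """Negative log-likelihood of the target class (F.nll_loss for one sample)."""
    return -log_probs[target]


def cross_entropy(logits, target):
    """NLL of log_softmax(logits) (F.cross_entropy for one sample)."""
    return nll(log_softmax(logits), target)


logits = [2.0, -1.0, 0.5]
log_probs = log_softmax(logits)     # what Digit.forward returns for one sample
print(cross_entropy(log_probs, 0))  # cross-entropy applied to log-probabilities...
print(nll(log_probs, 0))            # ...matches NLL on the same values
```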
8. Define the test function
#8 Define the test function
def test_model(model, device, test_loader):
    # Put the model in evaluation mode
    model.eval()
    # Number of correct predictions
    correct = 0.0
    # Test loss
    test_loss = 0.0
    with torch.no_grad():  # no gradients are computed, so there is no backpropagation
        for data, target in test_loader:
            # Move the batch to the device
            data, target = data.to(device), target.to(device)
            # Forward pass on the test batch
            output = model(data)
            # Accumulate the test loss (this adds each batch's mean loss)
            test_loss += F.cross_entropy(output, target).item()
            # Index of the class with the highest probability
            pred = output.argmax(1)
            # Count the correct predictions
            correct += pred.eq(target.view_as(pred)).sum().item()
    # Dividing a sum of per-batch means by the number of samples (not batches)
    # makes the printed value roughly BATCH_SIZE times smaller than the per-sample average loss
    test_loss /= len(test_loader.dataset)
    print("Test--Average Loss:{:.4f},Accuracy:{:.3f}\n".format(test_loss, 100.0 * correct / len(test_loader.dataset)))
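The bookkeeping in `test_model` (argmax, counting matches, averaging the loss) can be checked by hand on a toy example. The two batches below are made-up data, and the sketch compares accumulating per-batch mean losses (the default `reduction="mean"`) with accumulating per-sample losses (what `reduction="sum"` would give) before dividing by the number of samples.

```python
def argmax(row):
    """Index of the largest score, like output.argmax(1) for one sample."""
    return max(range(len(row)), key=lambda i: row[i])


# Two hypothetical test batches: (scores per sample, true labels, per-sample losses)
batches = [
    ([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]], [1, 0, 0], [0.2, 0.4, 1.2]),
    ([[0.6, 0.4]], [0], [0.5]),
]

correct = 0
sum_of_batch_means = 0.0    # accumulated with the default reduction="mean"
sum_of_sample_losses = 0.0  # accumulated with reduction="sum"
num_samples = 0
for scores, targets, losses in batches:
    preds = [argmax(row) for row in scores]
    correct += sum(p == t for p, t in zip(preds, targets))  # like pred.eq(target).sum()
    sum_of_batch_means += sum(losses) / len(losses)
    sum_of_sample_losses += sum(losses)
    num_samples += len(targets)

accuracy = 100.0 * correct / num_samples
print(accuracy)
print(sum_of_batch_means / num_samples, sum_of_sample_losses / num_samples)
```

Only the second quotient is the true per-sample average; the first shrinks as batches get larger.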
9. Run training and testing
#9 Run training and testing for each epoch
for epoch in range(1, EPOCHS + 1):
    train_model(model, DEVICE, train_loader, optimizer, epoch)
    test_model(model, DEVICE, test_loader)
Output:
Train Epoch :1 Loss :2.296158
Train Epoch :1 Loss :0.023645
Test--Average Loss:0.0027,Accuarcy:98.690
Train Epoch :2 Loss :0.035262
Train Epoch :2 Loss :0.002957
Test--Average Loss:0.0027,Accuarcy:98.750
Train Epoch :3 Loss :0.029884
Train Epoch :3 Loss :0.000642
Test--Average Loss:0.0032,Accuarcy:98.460
Train Epoch :4 Loss :0.002866
Train Epoch :4 Loss :0.003708
Test--Average Loss:0.0033,Accuarcy:98.720
Train Epoch :5 Loss :0.000039
Train Epoch :5 Loss :0.000145
Test--Average Loss:0.0026,Accuarcy:98.840
Train Epoch :6 Loss :0.000124
Train Epoch :6 Loss :0.035326
Test--Average Loss:0.0054,Accuarcy:98.450
Train Epoch :7 Loss :0.000014
Train Epoch :7 Loss :0.000001
Test--Average Loss:0.0044,Accuarcy:98.510
Train Epoch :8 Loss :0.001491
Train Epoch :8 Loss :0.000045
Test--Average Loss:0.0031,Accuarcy:99.140
Train Epoch :9 Loss :0.000428
Train Epoch :9 Loss :0.000000
Test--Average Loss:0.0056,Accuarcy:98.500
Train Epoch :10 Loss :0.000001
Train Epoch :10 Loss :0.000377
Test--Average Loss:0.0042,Accuarcy:98.930
Summary and improvements
After watching the video: the teacher explains things very well, but does not explain why the network is built this way, so I went back and reviewed CNNs. The following network structure also works:
#5 Build the network model (alternative structure)
class Digit(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)      # 1 -> 10 channels, 5x5 kernel
        self.conv2 = nn.Conv2d(10, 20, 5)     # 10 -> 20 channels, 5x5 kernel
        self.fc1 = nn.Linear(20 * 4 * 4, 10)  # fully connected: 320 -> 10 classes

    def forward(self, x):
        input_size = x.size(0)  # batch size
        x = self.conv1(x)  # input: batch*1*28*28, output: batch*10*24*24 (28-5+1=24)
        x = F.relu(x)  # activation function; shape unchanged
        x = F.max_pool2d(x, 2, 2)  # 2x2 pooling with stride 2; output: batch*10*12*12
        x = self.conv2(x)  # input: batch*10*12*12, output: batch*20*8*8 (12-5+1=8)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)  # output: batch*20*4*4
        x = x.view(input_size, -1)  # flatten; -1 infers the size 20*4*4 = 320
        x = self.fc1(x)  # input: batch*320, output: batch*10
        x = F.relu(x)  # note: this clips every negative class score to 0 before log_softmax
        output = F.log_softmax(x, dim=1)  # log-probability of each digit class
        return output
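The 20*4*4 input size of `fc1` in this version comes from the same output-size arithmetic, now with a 5x5 kernel in `conv2` and a second pooling layer. A sketch assuming no padding and dilation=1 (`conv_out` is a hypothetical helper):

```python
def conv_out(size, kernel, stride=1):
    """Output side length of an unpadded Conv2d / max-pool layer with dilation=1."""
    return (size - kernel) // stride + 1


side = 28
for name, kernel, stride in [("conv1", 5, 1), ("pool1", 2, 2), ("conv2", 5, 1), ("pool2", 2, 2)]:
    side = conv_out(side, kernel, stride)
    print(name, side)
print("flattened:", 20 * side * side)
```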
After switching to this CNN, the accuracy suddenly dropped to 50%, which puzzled me. I guessed the optimizer might be the problem, so I replaced it with SGD, and the results were better:
Train Epoch :17 Loss :0.014693
Train Epoch :17 Loss :0.000051
Test--Average Loss:0.0026,Accuarcy:99.010
By epoch 17 the accuracy had reached 99%. I don't know why yet, so I'm leaving this as an open question and will come back to it once I understand.
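The SGD settings used for this run are not recorded here, so rather than guess them, here is a plain-Python sketch of the momentum-SGD update rule in the form PyTorch documents for `torch.optim.SGD` (dampening 0, no Nesterov): buf = momentum * buf + grad, with buf = grad on the first step, then w = w - lr * buf. The helper `sgd_minimize`, the toy objective, and the lr / momentum / steps values are illustrative assumptions. In the training code the optimizer would be created as `optim.SGD(model.parameters(), lr=..., momentum=...)`.

```python
def sgd_minimize(grad_fn, w, lr=0.01, momentum=0.9, steps=100):
    """Momentum SGD on a single scalar parameter.

    Update (as documented for torch.optim.SGD with dampening=0, nesterov=False):
        buf = grad                   on the first step
        buf = momentum * buf + grad  afterwards
        w   = w - lr * buf
    lr, momentum and steps are illustrative assumptions.
    """
    buf = None
    for _ in range(steps):
        g = grad_fn(w)
        buf = g if buf is None else momentum * buf + g
        w = w - lr * buf
    return w


# Toy problem: minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3)
w_final = sgd_minimize(lambda w: 2 * (w - 3), w=0.0, lr=0.05, momentum=0.5, steps=200)
print(w_final)
```

Unlike Adam, plain SGD applies one global learning rate to every parameter, so its behavior depends heavily on the chosen `lr` and `momentum`.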