PyTorch practice: handwritten digit recognition on the MNIST dataset
2022-07-05 20:59:00 【Sol-itude】
Original video link: Easy to learn PyTorch handwritten font recognition MNIST
1. Load the necessary libraries
# 1. Load the necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader  # needed for the data loaders in step 4
from torchvision import datasets, transforms
2. Define hyperparameters
# 2. Define hyperparameters
BATCH_SIZE = 64  # number of samples in each training batch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU if one is available, otherwise the CPU
EPOCHS = 20  # number of training epochs
3. Build the transform pipeline (image preprocessing)
# 3. Build the transform pipeline (image preprocessing)
pipeline = transforms.Compose([
    transforms.ToTensor(),  # convert the image to a float tensor with values in [0, 1]
    transforms.Normalize((0.1307,), (0.3081,))  # standardize with the MNIST mean and std: (x - 0.1307) / 0.3081
])
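Normalize is standardization, not regularization: it shifts and rescales every pixel so the inputs are roughly zero-mean and unit-variance. The values 0.1307 and 0.3081 are the commonly quoted mean and standard deviation of MNIST training pixels after scaling to [0, 1]. A minimal plain-Python sketch of the arithmetic on single pixel values (no torch needed; the helper names are hypothetical, not torchvision functions):

```python
# Plain-Python sketch of what ToTensor + Normalize do to a single MNIST pixel.
MNIST_MEAN = 0.1307  # mean passed to transforms.Normalize above
MNIST_STD = 0.3081   # std passed to transforms.Normalize above


def to_unit_range(pixel: int) -> float:
    """Mimic ToTensor for one uint8 pixel: scale 0..255 to 0.0..1.0."""
    return pixel / 255.0


def standardize(value: float, mean: float = MNIST_MEAN, std: float = MNIST_STD) -> float:
    """Mimic Normalize for one channel: (x - mean) / std."""
    return (value - mean) / std


for pixel in (0, 128, 255):
    print(pixel, round(standardize(to_unit_range(pixel)), 4))
```

With these constants the black background (0) maps to about -0.42 and full white (255) to about 2.82.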
4. Download and load the dataset
# Download the datasets
train_set = datasets.MNIST("data", train=True, download=True, transform=pipeline)  # root folder, training split, download if missing, apply the pipeline
test_set = datasets.MNIST("data", train=False, download=True, transform=pipeline)
# Load the data
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)  # batches of 64, reshuffled every epoch
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)  # no need to shuffle for evaluation
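With BATCH_SIZE = 64, each loader yields ceil(N / 64) batches per epoch, because DataLoader keeps the smaller final batch unless drop_last=True. A small plain-Python sketch of that arithmetic for the 60,000 training and 10,000 test images (the helper names are hypothetical):

```python
import math

MNIST_TRAIN_SIZE = 60_000  # images in the MNIST training split
MNIST_TEST_SIZE = 10_000   # images in the MNIST test split


def batches_per_epoch(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields in one pass over the data."""
    if drop_last:
        return n_samples // batch_size        # incomplete final batch is discarded
    return math.ceil(n_samples / batch_size)  # final batch may be smaller


def last_batch_size(n_samples: int, batch_size: int) -> int:
    """Size of the final batch when drop_last=False."""
    remainder = n_samples % batch_size
    return remainder if remainder else batch_size


for name, n in (("train", MNIST_TRAIN_SIZE), ("test", MNIST_TEST_SIZE)):
    print(name, batches_per_epoch(n, 64), last_batch_size(n, 64))
```

For the training split this gives 938 batches, the last one holding 32 images.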
After downloading, take a look at one of the images in the dataset:
import cv2
import numpy as np

with open("<absolute path of the MNIST image file>", "rb") as f:
    file = f.read()

# Skip the 16-byte IDX header; the next 28*28 = 784 bytes are the first image's pixels (0-255)
image1 = list(file[16:16 + 784])
print(image1)

image1_np = np.array(image1, dtype=np.uint8).reshape(28, 28, 1)
print(image1_np.shape)
cv2.imwrite("digit.jpg", image1_np)  # save the image
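The 16 skipped bytes are the IDX header: four big-endian 32-bit integers holding the magic number (2051 for image files), the image count, the row count and the column count. A stdlib-only sketch that reads the header instead of hard-coding the offsets; read_idx_images is a hypothetical helper, and the example feeds it a tiny synthetic file rather than real MNIST data:

```python
import struct


def read_idx_images(raw: bytes) -> list:
    """Parse an in-memory IDX3 image file into a list of row-major 2-D pixel lists."""
    magic, count, rows, cols = struct.unpack(">IIII", raw[:16])  # big-endian header
    if magic != 2051:  # 0x00000803: unsigned-byte data, 3 dimensions
        raise ValueError(f"not an IDX3 image file (magic={magic})")
    size = rows * cols
    images = []
    for i in range(count):
        start = 16 + i * size
        pixels = list(raw[start:start + size])  # each byte is already an int 0..255
        images.append([pixels[r * cols:(r + 1) * cols] for r in range(rows)])
    return images


# Tiny synthetic file: 2 images of 2x3 pixels
sample = struct.pack(">IIII", 2051, 2, 2, 3) + bytes(range(12))
print(read_idx_images(sample))
```

On the real file, read_idx_images(f.read())[0] is the same 28x28 digit as image1 above, just split into rows.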
5. Build the network model
# 5. Build the network model
class Digit(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)   # 1 input channel (grayscale), 10 output channels, 5x5 kernel
        self.conv2 = nn.Conv2d(10, 20, kernel_size=3)  # 10 input channels, 20 output channels, 3x3 kernel
        self.fc1 = nn.Linear(20 * 10 * 10, 500)        # fully connected: 2000 features -> 500
        self.fc2 = nn.Linear(500, 10)                  # 500 -> 10 classes (digits 0-9)

    def forward(self, x):
        input_size = x.size(0)  # batch size
        x = self.conv1(x)  # input: batch*1*28*28, output: batch*10*24*24 (28-5+1=24)
        x = F.relu(x)  # activation function, shape unchanged
        x = F.max_pool2d(x, 2, 2)  # 2x2 pooling with stride 2, output: batch*10*12*12
        x = self.conv2(x)  # input: batch*10*12*12, output: batch*20*10*10 (12-3+1=10)
        x = F.relu(x)
        x = x.view(input_size, -1)  # flatten; -1 infers 20*10*10 = 2000
        x = self.fc1(x)  # input: batch*2000, output: batch*500
        x = F.relu(x)
        x = self.fc2(x)  # input: batch*500, output: batch*10
        output = F.log_softmax(x, dim=1)  # log-probability of each digit class
        return output
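The shape comments follow the unpadded convolution and pooling formula out = (in + 2*padding - kernel) // stride + 1. A plain-Python sketch that recomputes the in_features of fc1 for this network (the helper names are hypothetical; padding is 0 and stride is 1 for the convolutions, as in the model):

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a square Conv2d (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def max_pool_out(size: int, kernel: int = 2, stride: int = 2) -> int:
    """Output side length of max_pool2d without padding."""
    return (size - kernel) // stride + 1


def digit_flat_features(side: int = 28, channels: int = 20) -> int:
    side = conv2d_out(side, 5)  # conv1: 28 -> 24
    side = max_pool_out(side)   # pool:  24 -> 12
    side = conv2d_out(side, 3)  # conv2: 12 -> 10
    return channels * side * side  # 20 * 10 * 10 = 2000


print(digit_flat_features())
```

If the kernel sizes change, recomputing this number tells you what nn.Linear(in_features, ...) must be.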
6. Define the optimizer
model = Digit().to(DEVICE)  # move the model to the selected device
optimizer = optim.Adam(model.parameters())  # Adam optimizer with its default settings
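optim.Adam(model.parameters()) uses PyTorch's defaults lr=1e-3, betas=(0.9, 0.999) and eps=1e-8. The update it applies to each parameter can be sketched for a single scalar in plain Python (no weight decay, no amsgrad); adam_step is a hypothetical helper, not a torch API:

```python
import math


def adam_step(param: float, grad: float, state: dict,
              lr: float = 1e-3, betas=(0.9, 0.999), eps: float = 1e-8) -> float:
    """One Adam update for a single scalar parameter; `state` holds t, m and v."""
    beta1, beta2 = betas
    state["t"] += 1
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad         # running mean of gradients
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad * grad  # running mean of squared gradients
    m_hat = state["m"] / (1 - beta1 ** state["t"])  # bias correction
    v_hat = state["v"] / (1 - beta2 ** state["t"])
    return param - lr * m_hat / (math.sqrt(v_hat) + eps)


state = {"t": 0, "m": 0.0, "v": 0.0}
p = adam_step(1.0, 0.5, state)
print(p)
```

On the first step the bias-corrected ratio m_hat / sqrt(v_hat) is about plus or minus 1, so each parameter moves by roughly lr regardless of the gradient's magnitude.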
7. Define the training method
# 7. Define the training method
def train_model(model, device, train_loader, optimizer, epoch):
    # Put the model in training mode
    model.train()
    for batch_index, (data, target) in enumerate(train_loader):
        # Move the batch to the device
        data, target = data.to(device), target.to(device)
        # Reset the gradients to 0
        optimizer.zero_grad()
        # Forward pass: predictions for this batch
        output = model(data)
        # Compute the loss
        loss = F.cross_entropy(output, target)  # cross-entropy loss for multi-class classification
        # Backpropagation
        loss.backward()
        # Update the parameters
        optimizer.step()
        if batch_index % 3000 == 0:
            print("Train Epoch :{} \t Loss :{:.6f}".format(epoch, loss.item()))
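Because forward() already ends in F.log_softmax, passing its output to F.cross_entropy (which applies log_softmax internally) applies log_softmax a second time. log_softmax is idempotent, since log-probabilities already log-sum-exp to 0, so the loss value is unchanged; F.nll_loss is the direct pairing for log-probability outputs. A plain-Python sketch for one sample (the functions mirror, but are not, the torch ones):

```python
import math


def log_softmax(xs):
    """Numerically stable log-softmax of a list of scores."""
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))  # log-sum-exp
    return [x - lse for x in xs]


def nll_loss(log_probs, target):
    """Negative log-likelihood of the target class (per-sample F.nll_loss)."""
    return -log_probs[target]


def cross_entropy(scores, target):
    """Cross-entropy = log_softmax followed by NLL (per-sample F.cross_entropy)."""
    return nll_loss(log_softmax(scores), target)


logits = [2.0, -1.0, 0.5]
log_probs = log_softmax(logits)  # what the model's forward() returns
print(cross_entropy(log_probs, 0), nll_loss(log_probs, 0))
```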
8. Define the test method
# 8. Define the test method
def test_model(model, device, test_loader):
    # Put the model in evaluation mode
    model.eval()
    # Number of correct predictions
    correct = 0.0
    # Test loss
    test_loss = 0.0
    with torch.no_grad():  # no gradients are computed, so no backpropagation happens
        for data, target in test_loader:
            # Move the batch to the device
            data, target = data.to(device), target.to(device)
            # Forward pass on the test data
            output = model(data)
            # Sum the per-sample losses so that dividing by the dataset size below gives the average loss
            test_loss += F.cross_entropy(output, target, reduction="sum").item()
            # Index of the highest probability = predicted digit
            pred = output.argmax(1)
            # Accumulate the number of correct predictions
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print("Test--Average Loss:{:.4f},Accuracy:{:.3f}\n".format(test_loss, 100.0 * correct / len(test_loader.dataset)))
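The average loss printed by test_model is a true per-sample average only when the per-sample losses are summed before dividing by len(test_loader.dataset); adding up batch means and then dividing by the dataset size again yields a number far smaller than the real average. A plain-Python sketch of the accumulation (evaluate is a hypothetical helper that takes batches of log-probability rows and targets):

```python
import math


def evaluate(batches):
    """Average per-sample NLL loss and accuracy (%) over batches of (log_probs_rows, targets)."""
    total_loss = 0.0
    correct = 0
    seen = 0
    for rows, targets in batches:
        for log_probs, target in zip(rows, targets):
            pred = max(range(len(log_probs)), key=log_probs.__getitem__)  # argmax
            correct += int(pred == target)
            total_loss += -log_probs[target]  # summed, like reduction="sum"
            seen += 1
    return total_loss / seen, 100.0 * correct / seen


batch = ([[math.log(0.7), math.log(0.2), math.log(0.1)],
          [math.log(0.1), math.log(0.8), math.log(0.1)]],
         [0, 2])
print(evaluate([batch]))
```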
9. Run training and testing
# 9. Run training and testing
for epoch in range(1, EPOCHS + 1):
    train_model(model, DEVICE, train_loader, optimizer, epoch)
    test_model(model, DEVICE, test_loader)
Output:
Train Epoch :1 Loss :2.296158
Train Epoch :1 Loss :0.023645
Test--Average Loss:0.0027,Accuarcy:98.690
Train Epoch :2 Loss :0.035262
Train Epoch :2 Loss :0.002957
Test--Average Loss:0.0027,Accuarcy:98.750
Train Epoch :3 Loss :0.029884
Train Epoch :3 Loss :0.000642
Test--Average Loss:0.0032,Accuarcy:98.460
Train Epoch :4 Loss :0.002866
Train Epoch :4 Loss :0.003708
Test--Average Loss:0.0033,Accuarcy:98.720
Train Epoch :5 Loss :0.000039
Train Epoch :5 Loss :0.000145
Test--Average Loss:0.0026,Accuarcy:98.840
Train Epoch :6 Loss :0.000124
Train Epoch :6 Loss :0.035326
Test--Average Loss:0.0054,Accuarcy:98.450
Train Epoch :7 Loss :0.000014
Train Epoch :7 Loss :0.000001
Test--Average Loss:0.0044,Accuarcy:98.510
Train Epoch :8 Loss :0.001491
Train Epoch :8 Loss :0.000045
Test--Average Loss:0.0031,Accuarcy:99.140
Train Epoch :9 Loss :0.000428
Train Epoch :9 Loss :0.000000
Test--Average Loss:0.0056,Accuarcy:98.500
Train Epoch :10 Loss :0.000001
Train Epoch :10 Loss :0.000377
Test--Average Loss:0.0042,Accuarcy:98.930
Summary and improvements
The teacher in the video explains things really well, but doesn't explain why the network is built this way, so I went back and studied CNNs again. The following network structure also works:
# 5. Build the network model (revised CNN)
class Digit(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)      # 1 -> 10 channels, 5x5 kernel
        self.conv2 = nn.Conv2d(10, 20, 5)     # 10 -> 20 channels, 5x5 kernel
        self.fc1 = nn.Linear(20 * 4 * 4, 10)  # 320 features -> 10 classes

    def forward(self, x):
        input_size = x.size(0)  # batch size
        x = self.conv1(x)  # input: batch*1*28*28, output: batch*10*24*24 (28-5+1=24)
        x = F.relu(x)  # activation function, shape unchanged
        x = F.max_pool2d(x, 2, 2)  # output: batch*10*12*12
        x = self.conv2(x)  # input: batch*10*12*12, output: batch*20*8*8 (12-5+1=8)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)  # output: batch*20*4*4
        x = x.view(input_size, -1)  # flatten; -1 infers 20*4*4 = 320
        x = self.fc1(x)  # input: batch*320, output: batch*10
        x = F.relu(x)  # note: ReLU on the final layer clamps negative logits to 0 before log_softmax, which may hamper learning
        output = F.log_softmax(x, dim=1)  # log-probability of each digit class
        return output
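For the revised network the same unpadded arithmetic gives 28 → 24 → 12 → 8 → 4 per side, so fc1 receives 20*4*4 = 320 features. A plain-Python check (hypothetical helper names; stride-1 convolutions and 2x2 pooling with stride 2, as in the model):

```python
def unpadded_conv_out(size: int, kernel: int) -> int:
    """Side length after a stride-1, unpadded square convolution."""
    return size - kernel + 1


def pool2_out(size: int) -> int:
    """Side length after max_pool2d(x, 2, 2)."""
    return (size - 2) // 2 + 1


def revised_flat_features(side: int = 28, channels: int = 20) -> int:
    side = pool2_out(unpadded_conv_out(side, 5))  # conv1 + pool: 28 -> 24 -> 12
    side = pool2_out(unpadded_conv_out(side, 5))  # conv2 + pool: 12 -> 8 -> 4
    return channels * side * side  # 20 * 4 * 4 = 320


print(revised_flat_features())
```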
After switching to this CNN, the accuracy suddenly dropped to 50%, which puzzled me. I guessed the optimizer might be the problem, so I replaced it with SGD, and the results were better:
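The post does not show the SGD settings that were used. As an assumption for illustration only, a common choice is optim.SGD(model.parameters(), lr=0.01, momentum=0.9); these values are not the author's. The per-parameter rule PyTorch's SGD applies with momentum (no dampening, weight decay or Nesterov) can be sketched for a scalar in plain Python; sgd_momentum_step is a hypothetical helper:

```python
def sgd_momentum_step(param: float, grad: float, state: dict,
                      lr: float = 0.01, momentum: float = 0.9) -> float:
    """One SGD-with-momentum update for a scalar; state["buf"] is the momentum buffer."""
    if state.get("buf") is None:
        state["buf"] = grad  # first step: the buffer starts as the gradient itself
    else:
        state["buf"] = momentum * state["buf"] + grad  # accumulate past gradients
    return param - lr * state["buf"]


state = {}
p = 1.0
for _ in range(2):
    p = sgd_momentum_step(p, 0.5, state)
print(p)
```

Unlike Adam, the step size here scales directly with the gradient, so lr usually needs tuning per problem.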
Train Epoch :17 Loss :0.014693
Train Epoch :17 Loss :0.000051
Test--Average Loss:0.0026,Accuarcy:99.010
By epoch 17 the accuracy had reached 99%. I don't know why yet, so I'll leave this as an open question and fill it in once I understand it.