PyTorch practice: MNIST handwritten digit recognition
2022-07-05 20:59:00 【Sol-itude】
Original video link: Easy-to-learn PyTorch handwritten digit recognition (MNIST)
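1. Load the necessary libraries
#1 Load the necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader  # needed for train_loader and test_loader in step 4
from torchvision import datasets, transforms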
2. Define hyperparameters
#2 Define hyperparameters
BATCH_SIZE = 64  # number of samples in each training batch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU if one is available, otherwise the CPU
EPOCHS = 20  # number of training epochs
3. Build the pipeline for image preprocessing
#3 Build the pipeline for image preprocessing
pipeline = transforms.Compose([
    transforms.ToTensor(),  # convert the image to a tensor with values in [0, 1]
    transforms.Normalize((0.1307,), (0.3081,))  # normalization (not regularization): subtract the mean 0.1307 and divide by the standard deviation 0.3081
])
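To make the Normalize step concrete: ToTensor scales each 8-bit pixel to the range [0, 1], and Normalize then computes (x - mean) / std per channel. The sketch below reproduces that arithmetic in plain Python for a single pixel. `normalize_pixel` is a hypothetical helper written for illustration, not a torchvision function.

```python
MNIST_MEAN = 0.1307  # mean passed to transforms.Normalize above
MNIST_STD = 0.3081   # standard deviation passed to transforms.Normalize above


def normalize_pixel(raw_value: int) -> float:
    """Mimic ToTensor followed by Normalize for one 8-bit grayscale pixel."""
    scaled = raw_value / 255.0  # ToTensor maps 0..255 to 0.0..1.0
    return (scaled - MNIST_MEAN) / MNIST_STD


for raw in (0, 128, 255):
    print(raw, round(normalize_pixel(raw), 4))
```

A black background pixel (0) maps to about -0.42 and a fully white pixel (255) to about 2.82, so the network sees inputs centered near zero.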
4. Download and load the dataset
# Download the datasets
train_set = datasets.MNIST("data", train=True, download=True, transform=pipeline)  # root folder, training split, download if missing, apply the pipeline
test_set = datasets.MNIST("data", train=False, download=True, transform=pipeline)
# Load the data
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)  # training set in batches of 64, shuffled
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True)
After downloading, take a look at a picture in the dataset
import cv2
import numpy as np
with open("absolute path to the MNIST image file", "rb") as f:
    file = f.read()
image1 = list(file[16:16 + 784])  # skip the 16-byte header, then take the 28*28 = 784 pixel bytes of the first image
print(image1)
image1_np = np.array(image1, dtype=np.uint8).reshape(28, 28, 1)
print(image1_np.shape)
cv2.imwrite("digit.jpg", image1_np)  # save the picture
Output results: the 784 pixel values of the first image, the shape (28, 28, 1), and the saved picture digit.jpg
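The offset 16 in the slice above comes from the IDX layout of the raw MNIST image files: four big-endian 32-bit integers (the magic number 2051, the image count, the row count and the column count) come before the pixel bytes. The sketch below parses such a header. It builds a tiny synthetic buffer instead of reading a real file, and `parse_idx_images_header` is a hypothetical helper name.

```python
import struct


def parse_idx_images_header(buffer: bytes):
    """Return (count, rows, cols) from the 16-byte header of an IDX image file."""
    magic, count, rows, cols = struct.unpack(">IIII", buffer[:16])
    if magic != 2051:  # 0x00000803: unsigned-byte data with 3 dimensions
        raise ValueError(f"unexpected magic number: {magic}")
    return count, rows, cols


# Synthetic stand-in for a real file: header for 2 images of 28x28, then zeroed pixels.
fake_file = struct.pack(">IIII", 2051, 2, 28, 28) + bytes(2 * 28 * 28)
count, rows, cols = parse_idx_images_header(fake_file)
first_image = list(fake_file[16:16 + rows * cols])  # same kind of slice as in the code above
print(count, rows, cols, len(first_image))
```

Reading the row and column counts from the header avoids hard-coding 784, at the cost of a few extra lines.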
5. Build the network model
#5 Build the network model
class Digit(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=3)
        self.fc1 = nn.Linear(20*10*10, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        input_size = x.size(0)  # batch size
        x = self.conv1(x)  # input: batch*1*28*28, output: batch*10*24*24 (28-5+1=24)
        x = F.relu(x)  # activation function, shape unchanged
        x = F.max_pool2d(x, 2, 2)  # pooling window 2, stride 2; output: batch*10*12*12
        x = self.conv2(x)  # input: batch*10*12*12, output: batch*20*10*10 (12-3+1=10)
        x = F.relu(x)
        x = x.view(input_size, -1)  # flatten; -1 infers the dimension automatically: 20*10*10=2000
        x = self.fc1(x)  # input: batch*2000, output: batch*500
        x = F.relu(x)
        x = self.fc2(x)  # input: batch*500, output: batch*10
        output = F.log_softmax(x, dim=1)  # log-probability of each digit class
        return output
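The shape comments in forward follow the usual output-size formula for a convolution or pooling layer, floor((size + 2*padding - kernel) / stride) + 1. As a quick check, the sketch below traces the spatial size through the layers in plain Python. `out_size` and `flattened_features` are hypothetical helpers, not PyTorch calls.

```python
def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(image_size: int = 28) -> int:
    """Trace conv1 -> max_pool2d -> conv2 and return the input width expected by fc1."""
    s = out_size(image_size, kernel=5)    # conv1: 28 -> 24
    s = out_size(s, kernel=2, stride=2)   # max_pool2d(x, 2, 2): 24 -> 12
    s = out_size(s, kernel=3)             # conv2: 12 -> 10
    return 20 * s * s                     # conv2 produces 20 channels


print(flattened_features())
```

The result, 2000, matches the first argument of nn.Linear in fc1. Changing a kernel size or adding a pooling layer changes this number, and fc1 has to change with it.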
6. Define the optimizer
model = Digit().to(DEVICE)
optimizer = optim.Adam(model.parameters())  # use the Adam optimizer
7. Define the training method
#7 Define the training method
def train_model(model, device, train_loader, optimizer, epoch):
    # Switch the model to training mode
    model.train()
    for batch_index, (data, target) in enumerate(train_loader):
        # Move the batch to the device
        data, target = data.to(device), target.to(device)
        # Reset the gradients to 0
        optimizer.zero_grad()
        # Forward pass: compute the predictions
        output = model(data)
        # Compute the loss
        loss = F.cross_entropy(output, target)  # cross-entropy loss for multi-class classification
        # Backpropagation
        loss.backward()
        # Update the parameters
        optimizer.step()
        if batch_index % 3000 == 0:
            print("Train Epoch :{} \t Loss :{:.6f}".format(epoch, loss.item()))
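How often the progress line prints depends on how many batches one epoch contains. The sketch below works this out in plain Python, assuming the standard MNIST training split of 60,000 images. `logged_batches` is a hypothetical helper.

```python
import math


def logged_batches(num_samples: int, batch_size: int, interval: int = 3000):
    """Return the batch indices for which `batch_index % interval == 0` holds."""
    # DataLoader keeps the last partial batch by default, hence the ceiling.
    num_batches = math.ceil(num_samples / batch_size)
    return [i for i in range(num_batches) if i % interval == 0]


print(logged_batches(60000, 64))
```

With BATCH_SIZE = 64 an epoch has 938 batches, so the condition holds only for batch 0. A smaller interval, such as 300, would print more often.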
8. Define the test method
#8 Define the test method
def test_model(model, device, test_loader):
    # Switch the model to evaluation mode
    model.eval()
    # Number of correct predictions
    correct = 0.0
    # Test loss
    test_loss = 0.0
    with torch.no_grad():  # no gradients are computed, so no backpropagation happens
        for data, target in test_loader:
            # Move the batch to the device
            data, target = data.to(device), target.to(device)
            # Run the model on the test data
            output = model(data)
            # Accumulate the test loss (cross_entropy returns the mean over the batch)
            test_loss += F.cross_entropy(output, target).item()
            # Find the index of the class with the highest probability
            pred = output.argmax(1)
            # Accumulate the number of correct predictions
            correct += pred.eq(target.view_as(pred)).sum().item()
    # Sum of the per-batch mean losses divided by the number of test samples
    test_loss /= len(test_loader.dataset)
    print("Test--Average Loss:{:.4f},Accuracy:{:.3f}\n".format(test_loss, 100.0 * correct / len(test_loader.dataset)))
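A note on the averaged loss: F.cross_entropy uses reduction="mean" by default, so each call already returns a per-batch average. Summing those averages and dividing by the number of samples therefore gives a value roughly batch-size times smaller than the per-sample mean. The toy sketch below shows the effect with made-up loss values. Both helper names are hypothetical.

```python
def sum_of_batch_means_over_samples(losses, batch_size):
    """What test_model computes: sum the per-batch means, then divide by the sample count."""
    batches = [losses[i:i + batch_size] for i in range(0, len(losses), batch_size)]
    total = sum(sum(batch) / len(batch) for batch in batches)
    return total / len(losses)


def per_sample_mean(losses):
    """Plain mean over all samples."""
    return sum(losses) / len(losses)


toy_losses = [1.0] * 8  # made-up per-sample losses, not real results
print(sum_of_batch_means_over_samples(toy_losses, batch_size=4))  # 0.25
print(per_sample_mean(toy_losses))                                 # 1.0
```

To report a true per-sample average, one option is to pass reduction="sum" to F.cross_entropy before dividing by the dataset size. The accuracy figure is unaffected either way.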
9. Call the methods
#9 Call the methods
for epoch in range(1, EPOCHS + 1):
    train_model(model, DEVICE, train_loader, optimizer, epoch)
    test_model(model, DEVICE, test_loader)
Output results
Train Epoch :1 Loss :2.296158
Train Epoch :1 Loss :0.023645
Test--Average Loss:0.0027,Accuracy:98.690
Train Epoch :2 Loss :0.035262
Train Epoch :2 Loss :0.002957
Test--Average Loss:0.0027,Accuracy:98.750
Train Epoch :3 Loss :0.029884
Train Epoch :3 Loss :0.000642
Test--Average Loss:0.0032,Accuracy:98.460
Train Epoch :4 Loss :0.002866
Train Epoch :4 Loss :0.003708
Test--Average Loss:0.0033,Accuracy:98.720
Train Epoch :5 Loss :0.000039
Train Epoch :5 Loss :0.000145
Test--Average Loss:0.0026,Accuracy:98.840
Train Epoch :6 Loss :0.000124
Train Epoch :6 Loss :0.035326
Test--Average Loss:0.0054,Accuracy:98.450
Train Epoch :7 Loss :0.000014
Train Epoch :7 Loss :0.000001
Test--Average Loss:0.0044,Accuracy:98.510
Train Epoch :8 Loss :0.001491
Train Epoch :8 Loss :0.000045
Test--Average Loss:0.0031,Accuracy:99.140
Train Epoch :9 Loss :0.000428
Train Epoch :9 Loss :0.000000
Test--Average Loss:0.0056,Accuracy:98.500
Train Epoch :10 Loss :0.000001
Train Epoch :10 Loss :0.000377
Test--Average Loss:0.0042,Accuracy:98.930
Summary and improvement
The teacher in the video explains things well, but never says why the network is structured this way. So I went back and reviewed CNNs, and found that the following network structure also works:
#5 Build the network model
class Digit(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.fc1 = nn.Linear(20*4*4, 10)

    def forward(self, x):
        input_size = x.size(0)  # batch size
        x = self.conv1(x)  # input: batch*1*28*28, output: batch*10*24*24 (28-5+1=24)
        x = F.relu(x)  # activation function, shape unchanged
        x = F.max_pool2d(x, 2, 2)  # pooling window 2, stride 2; output: batch*10*12*12
        x = self.conv2(x)  # input: batch*10*12*12, output: batch*20*8*8 (12-5+1=8)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)  # output: batch*20*4*4
        x = x.view(input_size, -1)  # flatten; -1 infers the dimension automatically: 20*4*4=320
        x = self.fc1(x)  # input: batch*320, output: batch*10
        x = F.relu(x)  # note: this ReLU acts on the 10 final class scores
        output = F.log_softmax(x, dim=1)  # log-probability of each digit class
        return output
After switching to this CNN, the accuracy suddenly dropped to about 50%. I was puzzled and guessed that the optimizer might be the problem, so I replaced it with SGD, and the results improved:
Train Epoch :17 Loss :0.014693
Train Epoch :17 Loss :0.000051
Test--Average Loss:0.0026,Accuracy:99.010
By epoch 17 the accuracy had reached 99%. I don't know why yet, so I'll leave this as an open question and fill it in once I understand.
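One detail that may be worth checking, offered as an untested assumption rather than an explanation of the results above: in the second network, F.relu is applied to the output of fc1, which holds the final class scores. ReLU clamps every negative score to 0, so classes whose scores are all negative become indistinguishable. The plain-Python sketch below illustrates this on made-up scores. `relu` and `log_softmax` here are small stand-ins for the PyTorch functions.

```python
import math


def relu(values):
    """Element-wise max(0, v)."""
    return [max(0.0, v) for v in values]


def log_softmax(values):
    """Numerically stable log-softmax over a list of scores."""
    m = max(values)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in values))
    return [v - log_sum_exp for v in values]


logits = [-2.0, -0.5, -3.0]          # hypothetical class scores for a 3-class toy case
plain = log_softmax(logits)          # class 1 clearly has the highest log-probability
clamped = log_softmax(relu(logits))  # all scores become 0, so every class ties

print(plain.index(max(plain)))
print(clamped)
```

Whether this plays a role in the accuracy drop would need an experiment, for example training the same network with the ReLU after fc1 removed and comparing Adam and SGD again.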