PyTorch Practice -- Handwritten Digit Recognition on the MNIST Dataset
2022-07-05 20:59:00 【Sol-itude】
Original video: Easy-to-learn PyTorch handwritten digit recognition (MNIST)
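1. Load the necessary libraries
```python
# 1. Load the necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader  # needed for the data loaders in step 4
from torchvision import datasets, transforms
```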
2. Define hyperparameters
```python
# 2. Define hyperparameters
BATCH_SIZE = 64  # number of samples in each training batch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU if one is available, otherwise the CPU
EPOCHS = 20  # number of training epochs
```
3. Build the preprocessing pipeline
```python
# 3. Build the preprocessing pipeline
pipeline = transforms.Compose([
    transforms.ToTensor(),  # convert the image to a float tensor with values in [0, 1]
    transforms.Normalize((0.1307,), (0.3081,)),  # standardize with the MNIST mean (0.1307) and standard deviation (0.3081)
])
```
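For each channel, `transforms.Normalize(mean, std)` computes `(x - mean) / std` on the [0, 1] values that `ToTensor` produces. The plain-Python sketch below is not torchvision code. It applies the same arithmetic to single raw pixel values to show the range MNIST pixels end up in. `normalize_pixel` is a helper name introduced here for illustration.

```python
MEAN, STD = 0.1307, 0.3081  # the constants passed to transforms.Normalize above

def normalize_pixel(raw: int) -> float:
    """Mimic ToTensor + Normalize for one 8-bit grayscale pixel."""
    scaled = raw / 255.0           # ToTensor: 0..255 -> 0.0..1.0
    return (scaled - MEAN) / STD   # Normalize: subtract the mean, divide by the std

print(round(normalize_pixel(0), 4))    # black background: -0.4242
print(round(normalize_pixel(255), 4))  # white stroke: 2.8215
```

After normalization the background sits slightly below zero and the strokes sit well above it, so the inputs are roughly zero-centred with unit spread.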
4. Download and load the dataset
```python
# Download the datasets into the "data" folder and apply the pipeline to every image
train_set = datasets.MNIST("data", train=True, download=True, transform=pipeline)
test_set = datasets.MNIST("data", train=False, download=True, transform=pipeline)

# Load the data
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)  # 64 samples per batch, reshuffled every epoch
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)  # evaluation does not need shuffling
```
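With the default `drop_last=False`, `DataLoader` yields ceil(N / batch_size) batches per epoch, and the last batch may be smaller than the others. Here is a quick check of that arithmetic. It assumes the standard MNIST split of 60,000 training and 10,000 test images. `num_batches` is a helper name introduced here.

```python
import math

BATCH_SIZE = 64
TRAIN_SIZE, TEST_SIZE = 60_000, 10_000  # standard MNIST split

def num_batches(n_samples: int, batch_size: int) -> int:
    """Batches per epoch when the last, partial batch is kept (drop_last=False)."""
    return math.ceil(n_samples / batch_size)

train_batches = num_batches(TRAIN_SIZE, BATCH_SIZE)
test_batches = num_batches(TEST_SIZE, BATCH_SIZE)
last_train_batch = TRAIN_SIZE - (train_batches - 1) * BATCH_SIZE
print(train_batches, test_batches, last_train_batch)  # 938 157 32
```

With the loaders defined above, `len(train_loader)` and `len(test_loader)` should report the same 938 and 157.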
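After downloading, take a look at one of the images in the dataset:
```python
import cv2
import numpy as np

# Read a raw MNIST image file (replace the placeholder with the absolute path on your machine)
with open("<absolute path of the MNIST file>", "rb") as f:
    file = f.read()

# The image file starts with a 16-byte header; the next 28*28 = 784 bytes are the first image
image1 = list(file[16:16 + 784])  # iterating over bytes yields ints in 0..255
print(image1)

image1_np = np.array(image1, dtype=np.uint8).reshape(28, 28, 1)
print(image1_np.shape)  # (28, 28, 1)
cv2.imwrite("digit.jpg", image1_np)  # save the picture
```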
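The `file[16:16 + 784]` slice works because an IDX image file begins with a 16-byte header. The header is four big-endian 32-bit integers: the magic number 2051, the number of images, the row count and the column count. One unsigned byte per pixel follows. The stdlib-only sketch below parses that header instead of hard-coding the offset. `read_idx_images` is a hypothetical helper name. It is run here on a tiny synthetic buffer, not on the real file.

```python
import struct

def read_idx_images(data: bytes) -> list[list[list[int]]]:
    """Parse an IDX3 image buffer into a list of row-major 2-D images."""
    magic, count, rows, cols = struct.unpack(">IIII", data[:16])  # big-endian header
    if magic != 2051:
        raise ValueError(f"not an IDX3 image file (magic={magic})")
    size = rows * cols
    images = []
    for i in range(count):
        start = 16 + i * size
        pixels = data[start:start + size]  # iterating over bytes yields ints 0..255
        images.append([list(pixels[r * cols:(r + 1) * cols]) for r in range(rows)])
    return images

# Synthetic example: two 2x3 "images" instead of real 28x28 MNIST digits
header = struct.pack(">IIII", 2051, 2, 2, 3)
buffer = header + bytes([0, 255, 0, 255, 0, 255]) + bytes([1, 2, 3, 4, 5, 6])
imgs = read_idx_images(buffer)
print(imgs[1])  # [[1, 2, 3], [4, 5, 6]]
```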
5. Build the network model
```python
# 5. Build the network model
class Digit(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)   # 1 input channel -> 10 channels, 5x5 kernel
        self.conv2 = nn.Conv2d(10, 20, kernel_size=3)  # 10 channels -> 20 channels, 3x3 kernel
        self.fc1 = nn.Linear(20 * 10 * 10, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        input_size = x.size(0)  # batch size
        x = self.conv1(x)  # input: batch*1*28*28, output: batch*10*24*24 (28-5+1=24)
        x = F.relu(x)  # activation function, shape unchanged
        x = F.max_pool2d(x, 2, 2)  # 2x2 pooling window with stride 2, output: batch*10*12*12
        x = self.conv2(x)  # input: batch*10*12*12, output: batch*20*10*10 (12-3+1=10)
        x = F.relu(x)
        x = x.view(input_size, -1)  # flatten; -1 infers 20*10*10 = 2000
        x = self.fc1(x)  # input: batch*2000, output: batch*500
        x = F.relu(x)
        x = self.fc2(x)  # input: batch*500, output: batch*10
        output = F.log_softmax(x, dim=1)  # log-probability of each of the 10 digits
        return output
```
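The shape comments follow the standard output-size formula for convolution and pooling (dilation 1): out = floor((in + 2*padding - kernel) / stride) + 1. The small plain-Python sketch below reproduces the numbers for this model, which is handy for checking the `nn.Linear(20*10*10, 500)` input size by hand. `conv_out` is a helper name introduced here.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of Conv2d / MaxPool2d along one dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

side = 28
side = conv_out(side, kernel=5)            # conv1: 28 -> 24
side = conv_out(side, kernel=2, stride=2)  # max_pool2d(x, 2, 2): 24 -> 12
side = conv_out(side, kernel=3)            # conv2: 12 -> 10
flat_features = 20 * side * side           # conv2 has 20 output channels
print(side, flat_features)  # 10 2000
```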
6. Define the optimizer
```python
model = Digit().to(DEVICE)  # create the model and move it to the selected device
optimizer = optim.Adam(model.parameters())  # Adam optimizer with its default settings
```
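`optim.Adam(model.parameters())` uses PyTorch's defaults: lr=0.001, betas=(0.9, 0.999), eps=1e-8. The plain-Python sketch below shows what one `optimizer.step()` does to each parameter. It applies the Adam update rule to a single scalar weight and minimizes the toy loss (w - 3)^2. It illustrates the algorithm only and is not PyTorch's implementation. `adam_minimize` and the toy loss are introduced here for illustration.

```python
import math

def adam_minimize(grad_fn, w: float, steps: int, lr: float = 1e-3,
                  beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8) -> float:
    """Run `steps` Adam updates on a single scalar parameter w."""
    m = v = 0.0  # first and second moment estimates
    for t in range(1, steps + 1):
        g = grad_fn(w)
        m = beta1 * m + (1 - beta1) * g      # moving average of gradients
        v = beta2 * v + (1 - beta2) * g * g  # moving average of squared gradients
        m_hat = m / (1 - beta1 ** t)         # bias correction for the zero initialization
        v_hat = v / (1 - beta2 ** t)
        w -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return w

# Toy loss (w - 3)^2 with gradient 2 * (w - 3); a larger lr than the default so it converges quickly
w_final = adam_minimize(lambda w: 2 * (w - 3), w=0.0, steps=2000, lr=0.05)
print(round(w_final, 3))
```

On the first step m_hat / sqrt(v_hat) is ±1, so w moves by almost exactly lr whatever the gradient's scale. This scale-invariance is often cited as a reason Adam's default learning rate works across many models.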
7. Define the training function
```python
# 7. Define the training function
def train_model(model, device, train_loader, optimizer, epoch):
    model.train()  # put the model in training mode
    for batch_index, (data, target) in enumerate(train_loader):
        # Move the batch to the device
        data, target = data.to(device), target.to(device)
        # Reset the gradients to 0
        optimizer.zero_grad()
        # Forward pass: predictions for this batch
        output = model(data)
        # Compute the loss. The model already outputs log_softmax values, so the
        # negative log-likelihood loss gives the multi-class cross-entropy
        loss = F.nll_loss(output, target)
        # Backpropagation
        loss.backward()
        # Update the parameters
        optimizer.step()
        # With BATCH_SIZE=64 an epoch has 938 batches, so only batch 0 satisfies this condition
        if batch_index % 3000 == 0:
            print("Train Epoch :{} \t Loss :{:.6f}".format(epoch, loss.item()))
```
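`forward` already ends with `F.log_softmax`, so `F.nll_loss` is all the loss needs. `F.cross_entropy` applies `log_softmax` internally and then `nll_loss`, so calling it on these outputs applies log_softmax a second time. That turns out to be harmless. Taking log_softmax of a log-probability vector returns the same vector, because its exponentials already sum to 1. The plain-Python check below demonstrates this. The helper names `log_softmax` and `nll` are introduced here and are not the PyTorch functions.

```python
import math

def log_softmax(xs: list[float]) -> list[float]:
    """Numerically stable log-softmax: x_i - logsumexp(x)."""
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]

def nll(log_probs: list[float], target: int) -> float:
    """Negative log-likelihood of the target class."""
    return -log_probs[target]

logits = [2.0, -1.0, 0.5, 3.0]
once = log_softmax(logits)
twice = log_softmax(once)  # what cross_entropy effectively computes on log_softmax outputs
print(nll(once, 3), nll(twice, 3))  # equal up to floating-point rounding
```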
8. Define the test function
```python
# 8. Define the test function
def test_model(model, device, test_loader):
    model.eval()  # put the model in evaluation mode
    correct = 0.0  # number of correct predictions
    test_loss = 0.0  # accumulated test loss
    with torch.no_grad():  # no gradients are computed, so there is no backpropagation
        for data, target in test_loader:
            # Move the batch to the device
            data, target = data.to(device), target.to(device)
            # Predict
            output = model(data)
            # Sum the per-sample losses so that dividing by the dataset size gives the average
            test_loss += F.nll_loss(output, target, reduction="sum").item()
            # Index of the highest log-probability = predicted digit
            pred = output.argmax(dim=1)
            # Count the correct predictions
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print("Test--Average Loss:{:.4f},Accuracy:{:.3f}\n".format(test_loss, 100.0 * correct / len(test_loader.dataset)))
```
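The test loop keeps two running totals over the whole test set: the summed loss and the number of correct predictions. It divides each by the dataset size at the end. Suppose each per-batch loss were the batch mean (PyTorch's default `reduction="mean"`). Dividing the sum of those means by the dataset size would then understate the average by roughly a factor of the batch size. Summing per-sample losses (`reduction="sum"`) keeps the division correct. The framework-free sketch below does the same bookkeeping on made-up log-probabilities. All the data is synthetic, and `evaluate` is a helper name introduced here.

```python
import math

def evaluate(batches):
    """batches: list of (log_prob_rows, targets); returns (average loss, accuracy in %)."""
    total_loss, correct, n = 0.0, 0, 0
    for log_prob_rows, targets in batches:
        for row, target in zip(log_prob_rows, targets):
            total_loss += -row[target]                        # per-sample NLL, summed
            pred = max(range(len(row)), key=row.__getitem__)  # argmax over classes
            correct += int(pred == target)
            n += 1
    return total_loss / n, 100.0 * correct / n

# Two synthetic batches of 3-class log-probabilities
ln = math.log
batches = [
    ([[ln(0.7), ln(0.2), ln(0.1)], [ln(0.1), ln(0.8), ln(0.1)]], [0, 2]),
    ([[ln(0.25), ln(0.25), ln(0.5)]], [2]),
]
avg_loss, accuracy = evaluate(batches)
print(round(avg_loss, 4), round(accuracy, 2))  # 1.1175 66.67
```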
9. Run training and testing
```python
# 9. Train, then evaluate on the test set, once per epoch
for epoch in range(1, EPOCHS + 1):
    train_model(model, DEVICE, train_loader, optimizer, epoch)
    test_model(model, DEVICE, test_loader)
```
Output:
```
Train Epoch :1 Loss :2.296158
Train Epoch :1 Loss :0.023645
Test--Average Loss:0.0027,Accuarcy:98.690
Train Epoch :2 Loss :0.035262
Train Epoch :2 Loss :0.002957
Test--Average Loss:0.0027,Accuarcy:98.750
Train Epoch :3 Loss :0.029884
Train Epoch :3 Loss :0.000642
Test--Average Loss:0.0032,Accuarcy:98.460
Train Epoch :4 Loss :0.002866
Train Epoch :4 Loss :0.003708
Test--Average Loss:0.0033,Accuarcy:98.720
Train Epoch :5 Loss :0.000039
Train Epoch :5 Loss :0.000145
Test--Average Loss:0.0026,Accuarcy:98.840
Train Epoch :6 Loss :0.000124
Train Epoch :6 Loss :0.035326
Test--Average Loss:0.0054,Accuarcy:98.450
Train Epoch :7 Loss :0.000014
Train Epoch :7 Loss :0.000001
Test--Average Loss:0.0044,Accuarcy:98.510
Train Epoch :8 Loss :0.001491
Train Epoch :8 Loss :0.000045
Test--Average Loss:0.0031,Accuarcy:99.140
Train Epoch :9 Loss :0.000428
Train Epoch :9 Loss :0.000000
Test--Average Loss:0.0056,Accuarcy:98.500
Train Epoch :10 Loss :0.000001
Train Epoch :10 Loss :0.000377
Test--Average Loss:0.0042,Accuarcy:98.930
```
Summary and improvements
After watching the video: the teacher explains things very well, but does not explain why the network is built this way. So I went back and studied CNNs again, and found that the following network structure also works:
```python
# 5. Build the network model (alternative CNN)
class Digit(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)   # 1 -> 10 channels, 5x5 kernel
        self.conv2 = nn.Conv2d(10, 20, 5)  # 10 -> 20 channels, 5x5 kernel
        self.fc1 = nn.Linear(20 * 4 * 4, 10)

    def forward(self, x):
        input_size = x.size(0)  # batch size
        x = self.conv1(x)  # input: batch*1*28*28, output: batch*10*24*24 (28-5+1=24)
        x = F.relu(x)  # activation function, shape unchanged
        x = F.max_pool2d(x, 2, 2)  # 2x2 pooling with stride 2, output: batch*10*12*12
        x = self.conv2(x)  # input: batch*10*12*12, output: batch*20*8*8 (12-5+1=8)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)  # output: batch*20*4*4
        x = x.view(input_size, -1)  # flatten; -1 infers 20*4*4 = 320
        x = self.fc1(x)  # input: batch*320, output: batch*10
        x = F.relu(x)  # note: ReLU on the final scores zeroes every negative logit before log_softmax, which is unusual for an output layer
        output = F.log_softmax(x, dim=1)  # log-probability of each of the 10 digits
        return output
```
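The same size formula gives the flatten size for this second model. For these unpadded layers it reduces to out = floor((in - kernel) / stride) + 1. Two 5x5 convolutions, each followed by 2x2 pooling, take 28 down to 4, hence `nn.Linear(20*4*4, 10)`. A self-contained check follows. `conv_out` is a helper name introduced here.

```python
def conv_out(size: int, kernel: int, stride: int = 1) -> int:
    """Spatial output size of an unpadded Conv2d / MaxPool2d along one dimension."""
    return (size - kernel) // stride + 1

side = 28
for layer, (kernel, stride) in [("conv1", (5, 1)), ("pool1", (2, 2)),
                                ("conv2", (5, 1)), ("pool2", (2, 2))]:
    side = conv_out(side, kernel, stride)
    print(layer, side)  # successive lines: conv1 24, pool1 12, conv2 8, pool2 4
flat_features = 20 * side * side
print(flat_features)  # 320
```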
After switching to this CNN, the accuracy suddenly dropped to about 50%, which puzzled me. I guessed the optimizer might be the problem, so I replaced Adam with SGD, and the results were better:
```
Train Epoch :17 Loss :0.014693
Train Epoch :17 Loss :0.000051
Test--Average Loss:0.0026,Accuarcy:99.010
```
By epoch 17 the accuracy had reached 99%. I don't know why yet, so I'm leaving this as an open question for now and will come back to it once I understand it.
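The post does not show the exact SGD settings that were used. In PyTorch the swap is a one-line change such as `optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)`. Those lr and momentum values are placeholders chosen for illustration, not the author's. With momentum, and without dampening, Nesterov or weight decay, `optim.SGD` updates each parameter as v = momentum * v + grad, then w = w - lr * v. The plain-Python sketch below applies that rule to a toy loss (w - 3)^2. `sgd_momentum_minimize` and the toy loss are hypothetical, introduced here for illustration.

```python
def sgd_momentum_minimize(grad_fn, w: float, steps: int,
                          lr: float = 0.01, momentum: float = 0.9) -> float:
    """Plain SGD with (heavy-ball) momentum on a single scalar parameter."""
    velocity = 0.0
    for _ in range(steps):
        g = grad_fn(w)
        velocity = momentum * velocity + g  # accumulate past gradients
        w -= lr * velocity                  # step against the accumulated direction
    return w

# Toy loss (w - 3)^2 with gradient 2 * (w - 3)
w_final = sgd_momentum_minimize(lambda w: 2 * (w - 3), w=0.0, steps=500)
print(round(w_final, 4))  # 3.0
```

Unlike Adam, the step size here scales directly with the gradient. The learning rate therefore has to be tuned to the loss scale.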