Implementation of handwritten numeral code recognition (pytorch)
2022-07-23 04:58:00 【Blameless.lsy】
Data preview:
import pandas
df=pandas.read_csv('C:\\Users\\HP\\Desktop\\mnist_train.csv',header=None)
df.head()
Each row of the MNIST data contains 785 values. The first value is the digit the image represents; the remaining 784 values are the pixel values of the image (28 pixels × 28 pixels).
We can use the info() function to view an overview of the DataFrame:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 60000 entries, 0 to 59999
Columns: 785 entries, 0 to 784
dtypes: int64(785)
memory usage: 359.3 MB

The output tells us that the DataFrame has 60,000 rows, corresponding to 60,000 training images, and confirms that each row has 785 values.
Let's convert one row of pixel values into an actual image so we can view it. We use the ubiquitous matplotlib library to display images; in the code below we import matplotlib's pyplot module.
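A minimal sketch of that conversion. To keep it self-contained it builds a synthetic 785-value row (an assumption for illustration); with the real data you would take a row of df instead, e.g. df.iloc[0].values:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

# a synthetic MNIST-style row: the first value is the label,
# the remaining 784 values are pixel intensities in 0..255
row = np.concatenate(([5], np.random.randint(0, 256, size=784)))

label = int(row[0])               # first value: the digit the image represents
pixels = row[1:].reshape(28, 28)  # remaining 784 values -> a 28x28 image

plt.title("label = " + str(label))
plt.imshow(pixels, interpolation='none', cmap='Blues')
```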
Complete code :
#### Import library
import torch
import torch.nn as nn
import pandas
import matplotlib.pyplot as plt
from torch.utils.data import Dataset # PyTorch's mechanism for loading and feeding data
'''------------ Build the neural network class ------------'''
class Classifier(nn.Module):  # nn.Module is the parent class of all network modules
    """ classifier """

    def __init__(self):
        # initialise the PyTorch parent class
        super().__init__()

        # define the neural network layers
        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.Sigmoid(),
            nn.Linear(200, 10),
            nn.Sigmoid()
        )

        # create the loss function (mean squared error)
        self.loss_function = nn.MSELoss()

        # create the optimiser, using simple stochastic gradient descent
        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)

        # counter and list for recording training progress (visualisation)
        self.counter = 0
        self.progress = []

    def forward(self, inputs):
        # simply run the model
        return self.model(inputs)

    def train(self, inputs, targets):
        """ one training step (note: this shadows nn.Module.train()) """
        # calculate the output of the network
        outputs = self.forward(inputs)

        # calculate the loss value
        loss = self.loss_function(outputs, targets)

        # zero the gradients, back-propagate, and update the weights
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

        # increase the counter for every training example; every 10 examples,
        # append the loss to the list (item() unwraps a single-value tensor into a plain number)
        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())

        # print the counter every 10,000 examples so we can see the training speed
        if self.counter % 10000 == 0:
            print("counter=", self.counter)

    def plot_progress(self):
        """ plot the recorded loss values """
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0, 1.0), figsize=(16, 8), alpha=0.1,
                marker='.', grid=True, yticks=(0, 0.25, 0.5))
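As a quick sanity check (not part of the original listing), one training step of this class can be exercised on random tensors. The class is repeated here so the sketch runs on its own, and the method is renamed train_step to avoid shadowing nn.Module.train():

```python
import torch
import torch.nn as nn

# same structure as the Classifier above, condensed for a self-contained check
class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 200), nn.Sigmoid(),
            nn.Linear(200, 10), nn.Sigmoid()
        )
        self.loss_function = nn.MSELoss()
        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)

    def forward(self, inputs):
        return self.model(inputs)

    def train_step(self, inputs, targets):
        # forward pass, loss, zero gradients, back-propagate, update weights
        outputs = self.forward(inputs)
        loss = self.loss_function(outputs, targets)
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
        return loss.item()

C = Classifier()
loss = C.train_step(torch.rand(784), torch.zeros(10))
# loss is a plain Python float; MSE is always non-negative
```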
'''------------- Create the MnistDataset class --------------'''
class MnistDataset(Dataset):

    def __init__(self, csv_file):
        self.data_df = pandas.read_csv(csv_file, header=None)

    def __len__(self):
        return len(self.data_df)

    def __getitem__(self, index):
        # target label, one-hot encoded
        label = self.data_df.iloc[index, 0]
        target = torch.zeros((10))
        target[label] = 1.0

        # image data: pixel values in 0~255, normalised to 0~1
        image_values = torch.FloatTensor(self.data_df.iloc[index, 1:].values) / 255.0

        # return the label, the image data tensor, and the target tensor
        return label, image_values, target

    def plot_image(self, index):
        """ visualise one image with its label """
        arr = self.data_df.iloc[index, 1:].values.reshape(28, 28)
        plt.title("label = " + str(self.data_df.iloc[index, 0]))
        plt.imshow(arr, interpolation='none', cmap='Blues')
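To see exactly what __getitem__ returns, here is a small self-contained check. It feeds the class a tiny in-memory "MNIST" CSV (an assumption for illustration) instead of the real file:

```python
import io
import torch
import pandas
from torch.utils.data import Dataset

# same class as above, repeated so this sketch runs on its own
class MnistDataset(Dataset):
    def __init__(self, csv_file):
        self.data_df = pandas.read_csv(csv_file, header=None)
    def __len__(self):
        return len(self.data_df)
    def __getitem__(self, index):
        label = self.data_df.iloc[index, 0]
        target = torch.zeros((10))
        target[label] = 1.0
        image_values = torch.FloatTensor(self.data_df.iloc[index, 1:].values) / 255.0
        return label, image_values, target

# a tiny fake MNIST CSV: one row, label 3, followed by 784 pixel values of 255
fake_csv = io.StringIO("3," + ",".join(["255"] * 784))
ds = MnistDataset(fake_csv)

label, image_values, target = ds[0]
# label is 3; image_values has 784 entries, all normalised to 1.0;
# target is a one-hot vector with a 1.0 at index 3
```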
# Check whether everything is normal so far
mnist_dataset=MnistDataset('C:\\Users\\HP\\Desktop\\mnist_train.csv')
#mnist_dataset.plot_image(9)
""" Training classifier """
# Creating neural networks
C=Classifier()
# stay MNIST Data sets train neural networks
for label, image_data_tensor, target_tensor in mnist_dataset:
C.train(image_data_tensor,target_tensor)
pass
counter= 10000
counter= 20000
counter= 30000
counter= 40000
counter= 50000
counter= 60000
# plot the classifier's loss values
C.plot_progress()
We now have a trained network and can use it to classify images. We will switch to the MNIST test dataset, which contains 10,000 images that our neural network has never seen. Let's load it with a new Dataset object:
# load MNIST Test data
mnist_test_dataset = MnistDataset('C:\\Users\\HP\\Desktop\\mnist_test.csv')
# pick an image
record = 19

# plot the image and its label
mnist_test_dataset.plot_image(record)
Let's see how the trained neural network judges this image. The following code continues with the 20th image (index 19) and extracts its pixel values as image_data. We use the forward() function to pass the image through the neural network:
image_data = mnist_test_dataset[record][1]

# query the trained neural network
output = C.forward(image_data)

# plot the output tensor
pandas.DataFrame(output.detach().numpy()).plot(kind='bar', legend=False, ylim=(0, 1))
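The bar chart shows ten output values, one per digit, and the index of the largest is the network's answer. A self-contained sketch of that argmax step, using a hypothetical output tensor (an assumption, so it runs without the trained network):

```python
import torch

# a hypothetical 10-value output tensor like the one the network produces
output = torch.tensor([0.02, 0.01, 0.05, 0.10, 0.03,
                       0.04, 0.01, 0.80, 0.02, 0.06])

# the predicted digit is the index of the largest output value
predicted_digit = int(output.argmax())
# predicted_digit is 7, because index 7 holds the largest value (0.80)

# over the whole test set one would count how often this matches the label, e.g.:
# score = sum(int(C.forward(img).argmax()) == lbl
#             for lbl, img, _ in mnist_test_dataset)
```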