MNIST Handwritten Digit Recognition with an RNN
2022-06-12 06:06:00 【Singing under the hedge】
Implementing MNIST handwritten digit recognition with an RNN
1. Code
```python
import torch
import torchvision
from torch import nn
import torch.utils.data as Data

EPOCH = 1               # number of passes over the training set
BATCH_SIZE = 64         # mini-batch size
TIME_STEP = 28          # RNN time steps = image height (rows)
INPUT_SIZE = 28         # RNN input size per step = pixels per row
LR = 0.01               # learning rate
DOWNLOAD_MNIST = False  # set to True on the first run to download MNIST into ./mnist

# Training data, loaded in shuffled mini-batches
train_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

# Test data: the first 2000 samples, scaled to [0, 1]
test_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=False,
    transform=torchvision.transforms.ToTensor(),
)
test_x = test_data.data.type(torch.FloatTensor)[:2000] / 255  # shape (2000, 28, 28)
test_y = test_data.targets.numpy()[:2000]


class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.rnn = nn.LSTM(
            input_size=INPUT_SIZE,
            hidden_size=64,
            num_layers=1,
            batch_first=True,  # input/output shape: (batch_size, time_step, input_size)
        )
        self.out = nn.Linear(64, 10)  # map the last hidden state to 10 digit classes

    def forward(self, x):
        # r_out: (batch, time_step, hidden_size); h_n, h_c: final hidden and cell states.
        # Passing None makes the initial hidden and cell states all zeros.
        r_out, (h_n, h_c) = self.rnn(x, None)
        # Classify using the output at the last time step
        out = self.out(r_out[:, -1, :])
        return out


rnn = RNN()
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()

for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):
        b_x = b_x.view(-1, 28, 28)  # (batch, 1, 28, 28) -> (batch, time_step, input_size)
        output = rnn(b_x)
        loss = loss_func(output, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 50 == 0:
            with torch.no_grad():  # no gradients are needed for evaluation
                test_output = rnn(test_x)  # test_x: (samples, time_step, input_size)
            pred_y = torch.max(test_output, 1)[1].numpy()
            accuracy = float((pred_y == test_y).astype(int).sum()) / float(test_y.size)
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item(), '| test accuracy: %.2f' % accuracy)

# Print 10 predictions from the test data
with torch.no_grad():
    test_output = rnn(test_x[:10].view(-1, 28, 28))
pred_y = torch.max(test_output, 1)[1].numpy()
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')
```
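The key idea in the code above is to read each 28×28 image as a sequence of 28 rows (TIME_STEP), each row being a 28-pixel vector (INPUT_SIZE). The sketch below traces the tensor shapes through the same `nn.LSTM` configuration. It uses random tensors as a hypothetical stand-in for a real MNIST batch, so it runs without downloading any data. It also checks that, for a single-layer, unidirectional LSTM, `r_out[:, -1, :]` (what the model feeds to its linear layer) is the same as the final hidden state `h_n[-1]`.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for one DataLoader batch: (batch, channel, height, width)
b_x = torch.rand(64, 1, 28, 28)

# Treat each image as 28 time steps (rows) of 28 features (pixels per row)
seq = b_x.view(-1, 28, 28)  # (64, 28, 28) = (batch, time_step, input_size)

lstm = nn.LSTM(input_size=28, hidden_size=64, num_layers=1, batch_first=True)
head = nn.Linear(64, 10)

# Omitting the initial state gives zeros, the same as passing None
r_out, (h_n, h_c) = lstm(seq)
print(r_out.shape)  # torch.Size([64, 28, 64]): one output per time step
print(h_n.shape)    # torch.Size([1, 64, 64]): (num_layers, batch, hidden_size)

last_step = r_out[:, -1, :]  # (64, 64): output after reading the last row
# With one layer and one direction, the last-step output equals the final hidden state
print(torch.allclose(last_step, h_n[-1]))  # True

logits = head(last_step)  # (64, 10): one score per digit class
print(logits.shape)       # torch.Size([64, 10])
```

Note that `batch_first=True` only changes the layout of the input and `r_out`; `h_n` and `h_c` keep the shape `(num_layers, batch, hidden_size)`.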
2. Results
