MNIST Handwritten Digit Recognition with an RNN
2022-06-12 06:06:00 【Singing under the hedge】
Implementing MNIST Handwritten Digit Recognition with an RNN
1. Code
import torch
import torchvision.datasets
import torchvision.transforms
from torch import nn
import torch.utils.data as Data

EPOCH = 1               # number of passes over the training set
BATCH_SIZE = 64         # mini-batch size
TIME_STEP = 28          # RNN time steps = image height (one row per step)
INPUT_SIZE = 28         # RNN input size per step = pixels per image row
LR = 0.01               # learning rate
DOWNLOAD_MNIST = False  # set to True if MNIST is not yet in ./mnist

# Training data, loaded in shuffled mini-batches
train_data = torchvision.datasets.MNIST(root='./mnist', train=True, transform=torchvision.transforms.ToTensor(), download=DOWNLOAD_MNIST)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

# Test data: the first 2000 images, scaled from [0, 255] to [0, 1]
test_data = torchvision.datasets.MNIST(root='./mnist', train=False, transform=torchvision.transforms.ToTensor(), download=DOWNLOAD_MNIST)
test_x = test_data.data.type(torch.FloatTensor)[:2000] / 255   # shape (2000, 28, 28)
test_y = test_data.targets.numpy()[:2000]

class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.rnn = nn.LSTM(
            input_size=INPUT_SIZE,
            hidden_size=64,      # number of LSTM hidden units
            num_layers=1,
            batch_first=True,    # input/output shape (batch, time_step, input_size)
        )
        self.out = nn.Linear(64, 10)   # map the last hidden state to 10 digit classes

    def forward(self, x):
        # r_out: (batch, time_step, 64); h_n, h_c: final hidden state and cell state
        r_out, (h_n, h_c) = self.rnn(x, None)   # None = zero initial state
        out = self.out(r_out[:, -1, :])         # classify from the last time step's output
        return out

rnn = RNN()
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()

for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):
        b_x = b_x.view(-1, TIME_STEP, INPUT_SIZE)   # (batch, 1, 28, 28) -> (batch, 28, 28)
        output = rnn(b_x)
        loss = loss_func(output, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 50 == 0:
            with torch.no_grad():              # no gradients needed for evaluation
                test_output = rnn(test_x)      # input: (samples, time_step, input_size)
            pred_y = torch.max(test_output, 1)[1].numpy()
            accuracy = float((pred_y == test_y).astype(int).sum()) / float(test_y.size)
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item(), '| test accuracy: %.2f' % accuracy)

# Print 10 predictions from the test data
with torch.no_grad():
    test_output = rnn(test_x[:10].view(-1, TIME_STEP, INPUT_SIZE))
pred_y = torch.max(test_output, 1)[1].numpy()
print(pred_y, 'predicted digits')
print(test_y[:10], 'true digits')
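To make the "image as a sequence" idea concrete, here is a minimal shape check, assuming only PyTorch is installed. It uses a random batch of four fake 28x28 images instead of MNIST, and the names `lstm`, `head`, `images` and `seq` are illustrative, not from the post. Each image is fed one row (28 pixels) per time step, so the LSTM sees 28 steps; the model above classifies from the output at the last step. For a single-layer, unidirectional LSTM that output is the same as the final hidden state `h_n[-1]`, and note that `h_n` is not batch-first even when `batch_first=True`.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Same sizes as the post: 28 pixels per row, 64 hidden units, 10 classes.
lstm = nn.LSTM(input_size=28, hidden_size=64, num_layers=1, batch_first=True)
head = nn.Linear(64, 10)

# Fake batch shaped like a DataLoader batch: (batch, channel, height, width).
images = torch.rand(4, 1, 28, 28)

# Drop the channel axis: each image becomes a sequence of 28 rows.
seq = images.view(-1, 28, 28)           # (batch, time_step, input_size)

r_out, (h_n, c_n) = lstm(seq)           # initial state defaults to zeros
print(r_out.shape)                      # torch.Size([4, 28, 64]): output at every row
print(h_n.shape)                        # torch.Size([1, 4, 64]): (layers, batch, hidden)

last = r_out[:, -1, :]                  # output after the 28th row
logits = head(last)
print(logits.shape)                     # torch.Size([4, 10])

# Single-layer, unidirectional LSTM: last output step equals h_n[-1].
print(torch.allclose(last, h_n[-1]))    # True
```

Taking `r_out[:, -1, :]` or `h_n[-1]` is therefore a matter of taste here; with a bidirectional LSTM the two would differ.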
2. Results

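The test accuracy printed every 50 steps is the fraction of the 2000 held-out images whose predicted digit matches the label. A small sketch with made-up arrays (hypothetical values standing in for `pred_y` and `test_y`, assuming NumPy) shows that the post's expression is equivalent to taking the mean of the boolean match array:

```python
import numpy as np

# Hypothetical predictions and labels (not real model output).
pred_y = np.array([7, 2, 1, 0, 4, 1, 4, 9, 6, 9])
test_y = np.array([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])

# Form used in the post: count matches, divide by the number of samples.
acc_post = float((pred_y == test_y).astype(int).sum()) / float(test_y.size)

# Equivalent shorter form: mean of the boolean match array.
acc_mean = float((pred_y == test_y).mean())

print(acc_post, acc_mean)  # 0.9 0.9
```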