
12. Applying an RNN to handwritten digit recognition

2022-07-08 00:55:00 booze-J


The code runs on jupyter-notebook. The code blocks in this article follow the cell order used in jupyter-notebook, so to run the code, paste each block directly into a jupyter-notebook cell. The comments in the code are fairly brief. SimpleRNN is used as the example here.
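A SimpleRNN layer reads each image one row per time step and updates a hidden state with h_t = tanh(x_t·W + h_{t-1}·U + b). The NumPy sketch below is a minimal illustration of that recurrence, not the Keras source. It assumes the Keras defaults for SimpleRNN: tanh activation, a zero initial state, and only the last hidden state returned because return_sequences=False. The function simple_rnn_forward, the random weights, and the random stand-in image are hypothetical.

```python
import numpy as np

def simple_rnn_forward(x, W, U, b):
    """Forward pass of a single SimpleRNN layer (illustrative sketch).

    x: (time_steps, input_size) sequence, e.g. one 28x28 image
    W: (input_size, units) input-to-hidden weights
    U: (units, units) hidden-to-hidden (recurrent) weights
    b: (units,) bias
    Returns the final hidden state, shape (units,).
    """
    h = np.zeros(W.shape[1])            # initial hidden state h_0 = 0
    for x_t in x:                       # one image row per time step
        h = np.tanh(x_t @ W + h @ U + b)
    return h

rng = np.random.default_rng(0)
input_size, time_steps, cell_size = 28, 28, 50
W = rng.normal(scale=0.1, size=(input_size, cell_size))
U = rng.normal(scale=0.1, size=(cell_size, cell_size))
b = np.zeros(cell_size)

image = rng.random((time_steps, input_size))   # hypothetical stand-in for a normalized digit
h_last = simple_rnn_forward(image, W, U, b)
print(h_last.shape)                  # (50,)
print(W.size + U.size + b.size)      # 3950
```

The three weight arrays hold 28x50 + 50x50 + 50 = 3950 values. This is the parameter count expected for the SimpleRNN layer built in section 3. The Dense output layer adds 50x10 + 10 = 510.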

1. Import third-party libraries

import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, SimpleRNN
from tensorflow.keras.optimizers import Adam

2. Load and preprocess data

#  Input size: each row of an image has 28 pixels
input_size=28
#  Sequence length: each image has 28 rows in total
time_steps=28
#  Number of units (cells) in the hidden layer
cell_size=50

#  Load data
(x_train,y_train),(x_test,y_test) = mnist.load_data()
#  x_train has shape (60000,28,28); scale pixel values to [0,1]
x_train = x_train/255.0
x_test = x_test/255.0
#  Convert labels to one-hot format
y_train = to_categorical(y_train,num_classes=10)
y_test = to_categorical(y_test,num_classes=10)
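x_train has shape (60000, 28, 28), so each image can be fed to the RNN unchanged. The 28 rows act as the time steps, and the 28 pixels in each row form the input vector at that step. The sketch below uses small random stand-in arrays instead of the real dataset, which is an assumption made so it runs without downloading MNIST. It reproduces the one-hot encoding with np.eye. For integer labels this should give the same 0/1 values as to_categorical, but it is an illustration, not the Keras implementation.

```python
import numpy as np

time_steps, input_size = 28, 28

# Hypothetical stand-ins for a few MNIST images and labels (uint8 pixels 0-255).
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(3, 28, 28), dtype=np.uint8)
labels = np.array([5, 0, 4])

# Scale to [0, 1], as in the preprocessing step.
x = images / 255.0

# Axis 1 indexes the 28 rows (time steps), axis 2 the 28 pixels of a row (features per step).
assert x.shape[1:] == (time_steps, input_size)
first_step = x[0, 0]   # top row of the first image, fed to the RNN at t = 0

# One-hot encoding: row k of the 10x10 identity matrix encodes label k.
one_hot = np.eye(10)[labels]
print(one_hot[0])      # [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
```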

3. Train and evaluate the model

#  Create the model
model = Sequential()

#  Recurrent neural network layer
model.add(SimpleRNN(
    units=cell_size,#  Number of hidden units (output dimension)
    input_shape=(time_steps,input_size),#  Input: (time steps, features per step)
))
#  Output layer 
model.add(Dense(10,activation="softmax"))

#  Define the optimizer and set the learning rate to 1e-4
adam = Adam(learning_rate=1e-4)

#  Set the optimizer and loss function, and compute accuracy during training
model.compile(optimizer=adam,loss="categorical_crossentropy",metrics=["accuracy"])

#  Train the model
model.fit(x_train,y_train,batch_size=64,epochs=10)

#  Evaluate the model on the test set
loss,accuracy=model.evaluate(x_test,y_test)

print("test loss:",loss)
print("test accuracy:",accuracy)
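model.evaluate returns the loss and the metric passed to compile. With loss="categorical_crossentropy" and one-hot labels, Keras treats "accuracy" as categorical accuracy, so the predicted class is the argmax of the softmax output. The NumPy sketch below computes both numbers on a tiny made-up batch with 3 classes instead of 10. The function names, the example probabilities, and the clipping constant eps are assumptions. The sketch illustrates the definitions only, not Keras's internal implementation.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean cross-entropy between one-hot labels and softmax outputs."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)   # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

def categorical_accuracy(y_true, y_pred):
    """Fraction of samples whose highest-probability class is the true class."""
    return float(np.mean(np.argmax(y_pred, axis=1) == np.argmax(y_true, axis=1)))

# Hypothetical softmax outputs for 4 test images (each row sums to 1).
y_pred = np.array([
    [0.90, 0.05, 0.05],
    [0.20, 0.70, 0.10],
    [0.30, 0.30, 0.40],
    [0.60, 0.30, 0.10],
])
y_true = np.eye(3)[[0, 1, 2, 1]]   # true classes: 0, 1, 2, 1

print(categorical_accuracy(y_true, y_pred))                # 0.75
print(round(categorical_crossentropy(y_true, y_pred), 4))  # 0.6456
```

The same argmax, np.argmax(model.predict(x_test), axis=1), turns the trained model's softmax outputs into predicted digits.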

Code run results:
[Figure: code run results]

The points that need attention are explained in the code comments. The run results show that the RNN model reaches a lower test-set accuracy than the CNN model trained in "10. CNN Applied to Handwritten Digit Recognition".


Copyright notice
This article was written by [booze-J]. Please include a link to the original article when reposting. Thanks.
https://yzsam.com/2022/189/202207072310361333.html