当前位置:网站首页>12.RNN应用于手写数字识别
12.RNN应用于手写数字识别
2022-07-07 23:11:00 【booze-J】
代码运行平台为jupyter-notebook,文章中的代码块,也是按照jupyter-notebook中的划分顺序进行书写的,运行文章代码,直接分单元粘入到jupyter-notebook即可。整体代码给出的注释还是挺简单明了的。这里我们以使用
SimpleRNN
为例。 1.导入第三方库
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.layers.recurrent import SimpleRNN
from tensorflow.keras.optimizers import Adam
2.加载数据及数据预处理
# 载入数据
# 数据长度-一行有28个像素
input_size=28
# 序列长度-一共有28行
time_steps=28
# 隐藏层cell个数
cell_size=50
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
x_train = x_train/255.0
x_test = x_test/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
3.训练模型
# 创建模型
model = Sequential()
# 循环神经网络
model.add(SimpleRNN(
units=cell_size,# 输出
input_shape=(time_steps,input_size),# 输入
))
# 输出层
model.add(Dense(10,activation="softmax"))
# 定义优化器 设置学习率为1e-4
adam = Adam(lr=1e-4)
# 定义优化器,loss function,训练过程中计算准确率
model.compile(optimizer=adam,loss="categorical_crossentropy",metrics=["accuracy"])
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=10)
# 评估模型
loss,accuracy=model.evaluate(x_test,y_test)
print("test loss:",loss)
print("test accuracy:",accuracy)
代码运行结果:
代码中需要注意的一些点,在代码注释中也给出了解释和提醒。从运行结果中可以看到RNN训练出来的模型在测试集上的准确率相对于10.CNN应用于手写数字识别中CNN训练出来的模型在测试集上的准确率效果要更差一些。
边栏推荐
- 华为交换机S5735S-L24T4S-QA2无法telnet远程访问
- 取消select的默认样式的向下箭头和设置select默认字样
- Service mesh introduction, istio overview
- [necessary for R & D personnel] how to make your own dataset and display it.
- 5G NR 系统消息
- Introduction to paddle - using lenet to realize image classification method II in MNIST
- 华泰证券官方网站开户安全吗?
- 8道经典C语言指针笔试题解析
- Reentrantlock fair lock source code Chapter 0
- Experience of autumn recruitment in 22 years
猜你喜欢
国外众测之密码找回漏洞
"An excellent programmer is worth five ordinary programmers", and the gap lies in these seven key points
Fofa attack and defense challenge record
Deep dive kotlin synergy (XXII): flow treatment
1293_ Implementation analysis of xtask resumeall() interface in FreeRTOS
新库上线 | CnOpenData中国星级酒店数据
[go record] start go language from scratch -- make an oscilloscope with go language (I) go language foundation
Langchao Yunxi distributed database tracing (II) -- source code analysis
Operating system principle --- summary of interview knowledge points
《因果性Causality》教程,哥本哈根大学Jonas Peters讲授
随机推荐
A network composed of three convolution layers completes the image classification task of cifar10 data set
【笔记】常见组合滤波电路
Is it safe to open an account on the official website of Huatai Securities?
Which securities company has a low, safe and reliable account opening commission
韦东山第三期课程内容概要
什么是负载均衡?DNS如何实现负载均衡?
v-for遍历元素样式失效
Cascade-LSTM: A Tree-Structured Neural Classifier for Detecting Misinformation Cascades(KDD20)
Introduction to paddle - using lenet to realize image classification method II in MNIST
Cancel the down arrow of the default style of select and set the default word of select
Leetcode brush questions
Analysis of 8 classic C language pointer written test questions
Basic mode of service mesh
Codeforces Round #804 (Div. 2)(A~D)
Application practice | the efficiency of the data warehouse system has been comprehensively improved! Data warehouse construction based on Apache Doris in Tongcheng digital Department
They gathered at the 2022 ecug con just for "China's technological power"
Hotel
【GO记录】从零开始GO语言——用GO语言做一个示波器(一)GO语言基础
[go record] start go language from scratch -- make an oscilloscope with go language (I) go language foundation
Qt添加资源文件,为QAction添加图标,建立信号槽函数并实现