当前位置:网站首页>12.RNN应用于手写数字识别
12.RNN应用于手写数字识别
2022-07-07 23:11:00 【booze-J】
代码运行平台为jupyter-notebook,文章中的代码块,也是按照jupyter-notebook中的划分顺序进行书写的,运行文章代码,直接分单元粘入到jupyter-notebook即可。整体代码给出的注释还是挺简单明了的。这里我们以使用
SimpleRNN为例。 1.导入第三方库
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.layers.recurrent import SimpleRNN
from tensorflow.keras.optimizers import Adam
2.加载数据及数据预处理
# 载入数据
# 数据长度-一行有28个像素
input_size=28
# 序列长度-一共有28行
time_steps=28
# 隐藏层cell个数
cell_size=50
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
x_train = x_train/255.0
x_test = x_test/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
3.训练模型
# 创建模型
model = Sequential()
# 循环神经网络
model.add(SimpleRNN(
units=cell_size,# 输出
input_shape=(time_steps,input_size),# 输入
))
# 输出层
model.add(Dense(10,activation="softmax"))
# 定义优化器 设置学习率为1e-4
adam = Adam(lr=1e-4)
# 定义优化器,loss function,训练过程中计算准确率
model.compile(optimizer=adam,loss="categorical_crossentropy",metrics=["accuracy"])
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=10)
# 评估模型
loss,accuracy=model.evaluate(x_test,y_test)
print("test loss:",loss)
print("test accuracy:",accuracy)
代码运行结果:
代码中需要注意的一些点,在代码注释中也给出了解释和提醒。从运行结果中可以看到RNN训练出来的模型在测试集上的准确率相对于10.CNN应用于手写数字识别中CNN训练出来的模型在测试集上的准确率效果要更差一些。
边栏推荐
- 【obs】Impossible to find entrance point CreateDirect3D11DeviceFromDXGIDevice
- 玩转Sonar
- 爬虫实战(八):爬表情包
- Solution to prompt configure: error: curses library not found when configuring and installing crosstool ng tool
- Service Mesh介绍,Istio概述
- RPA cloud computer, let RPA out of the box with unlimited computing power?
- ThinkPHP kernel work order system source code commercial open source version multi user + multi customer service + SMS + email notification
- Solution to the problem of unserialize3 in the advanced web area of the attack and defense world
- 丸子官网小程序配置教程来了(附详细步骤)
- 基于卷积神经网络的恶意软件检测方法
猜你喜欢

5g NR system messages

Reentrantlock fair lock source code Chapter 0

"An excellent programmer is worth five ordinary programmers", and the gap lies in these seven key points

《因果性Causality》教程,哥本哈根大学Jonas Peters讲授

玩轉Sonar

新库上线 | CnOpenData中华老字号企业名录

RPA cloud computer, let RPA out of the box with unlimited computing power?

Class head up rate detection based on face recognition

爬虫实战(八):爬表情包

Langchao Yunxi distributed database tracing (II) -- source code analysis
随机推荐
Basic principle and usage of dynamic library, -fpic option context
Class head up rate detection based on face recognition
51与蓝牙模块通讯,51驱动蓝牙APP点灯
1293_ Implementation analysis of xtask resumeall() interface in FreeRTOS
攻防演练中沙盘推演的4个阶段
Deep dive kotlin collaboration (the end of 23): sharedflow and stateflow
AI遮天传 ML-初识决策树
他们齐聚 2022 ECUG Con,只为「中国技术力量」
Password recovery vulnerability of foreign public testing
DNS series (I): why does the updated DNS record not take effect?
[Yugong series] go teaching course 006 in July 2022 - automatic derivation of types and input and output
Semantic segmentation model base segmentation_ models_ Detailed introduction to pytorch
[necessary for R & D personnel] how to make your own dataset and display it.
CVE-2022-28346:Django SQL注入漏洞
"An excellent programmer is worth five ordinary programmers", and the gap lies in these seven key points
大数据开源项目,一站式全自动化全生命周期运维管家ChengYing(承影)走向何方?
Solution to the problem of unserialize3 in the advanced web area of the attack and defense world
Four stages of sand table deduction in attack and defense drill
服务器防御DDOS的方法,杭州高防IP段103.219.39.x
詹姆斯·格雷克《信息简史》读后感记录