当前位置:网站首页>双向RNN与堆叠的双向RNN
双向RNN与堆叠的双向RNN
2022-07-05 10:20:00 【别团等shy哥发育】
双向RNN与堆叠的双向RNN
1、双向RNN
双向RNN(Bidirectional RNN)的结构如下图所示。

h t → = f ( W → x t + V → h t − 1 → + b → ) h t ← = f ( W ← x t + V ← h t − 1 ← + b ← ) y t = g ( U [ h t → ; h t ← ] + c ) \overrightarrow{h_t}=f(\overrightarrow{W}x_t+\overrightarrow{V}\overrightarrow{h_{t-1}}+\overrightarrow{b})\\ \overleftarrow{h_t}=f(\overleftarrow{W}x_t+\overleftarrow{V}\overleftarrow{h_{t-1}}+\overleftarrow{b})\\ y_t=g(U[\overrightarrow{h_t};\overleftarrow{h_t}]+c) ht=f(Wxt+Vht−1+b)ht=f(Wxt+Vht−1+b)yt=g(U[ht;ht]+c)
这里的 RNN 可以使用任意一种 RNN 结构 SimpleRNN,LSTM 或 GRU。这里箭头表示从左到右或从右到左传播,对于每个时刻的预测,都需要来自双向的特征向量,拼接 (Concatenate)后进行结果预测。箭头虽然不同,但参数还是同一套参数。有些模型中也 可以使用两套不同的参数。f,g表示激活函数, [ h t → ; h t ← ] [\overrightarrow{h_t};\overleftarrow{h_t}] [ht;ht]表示数据拼接(Concatenate)。
双向的 RNN 是同时考虑“过去”和“未来”的信息。上图是一个序列长度为 4 的双向RNN 结构。
比如输入 x 1 x_1 x1沿着实线箭头传输到隐层得到 h 1 h_1 h1,然后还需要再利用 x t x_t xt计算得到 h t ′ h_t' ht′,利用 x 3 x_3 x3和 h t ′ h_t' ht′计算得到 h 3 ′ h_3' h3′,利用 x 2 x_2 x2和 h 3 ′ h_3' h3′计算得到 h 2 ′ h_2' h2′,利用 x 1 x_1 x1和’h_2’计算得到 h 1 ′ h_1' h1′,最后再把 h 1 h_1 h1和 h 1 ′ h_1' h1′进行数据拼接(Concatenate),得到输出结果 y 1 y_1 y1。以此类推,同时利用前向传递和反向传递的数据进行结果的预测。
双向RNN就像是我们做阅读理解的时候从头向后读一遍文章,然后又从后往前读一遍文章,然后再做题。有可能从后往前再读一遍文章的时候会有新的不一样的理解,最后模型可能会得到更好的结果。
2、堆叠的双向RNN

堆叠的双向RNN(Stacked Bidirectional RNN)的结构如上图所示。上图是一个堆叠了3个隐藏层的RNN网络。

注意,这里的堆叠的双向RNN并不是只有双向的RNN才可以堆叠,其实任意的RNN都可以堆叠,如SimpleRNN、LSTM和GRU这些循环神经网络也可以进行堆叠。
堆叠指的是在RNN的结构中叠加多层,类似于BP神经网络中可以叠加多层,增加网络的非线性。
3、双向LSTM实现MNIST数据集分类
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import LSTM,Dropout,Bidirectional
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
# 载入数据集
mnist = tf.keras.datasets.mnist
# 载入数据,数据载入的时候就已经划分好训练集和测试集
# 训练集数据x_train的数据形状为(60000,28,28)
# 训练集标签y_train的数据形状为(60000)
# 测试集数据x_test的数据形状为(10000,28,28)
# 测试集标签y_test的数据形状为(10000)
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 对训练集和测试集的数据进行归一化处理,有助于提升模型训练速度
x_train, x_test = x_train / 255.0, x_test / 255.0
# 把训练集和测试集的标签转为独热编码
y_train = tf.keras.utils.to_categorical(y_train,num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test,num_classes=10)
# 数据大小-一行有28个像素
input_size = 28
# 序列长度-一共有28行
time_steps = 28
# 隐藏层memory block个数
cell_size = 50
# 创建模型
# 循环神经网络的数据输入必须是3维数据
# 数据格式为(数据数量,序列长度,数据大小)
# 载入的mnist数据的格式刚好符合要求
# 注意这里的input_shape设置模型数据输入时不需要设置数据的数量
model = Sequential([
Bidirectional(LSTM(units=cell_size,input_shape=(time_steps,input_size),return_sequences=True)),
Dropout(0.2),
Bidirectional(LSTM(cell_size)),
Dropout(0.2),
# 50个memory block输出的50个值跟输出层10个神经元全连接
Dense(10,activation=tf.keras.activations.softmax)
])
# 循环神经网络的数据输入必须是3维数据
# 数据格式为(数据数量,序列长度,数据大小)
# 载入的mnist数据的格式刚好符合要求
# 注意这里的input_shape设置模型数据输入时不需要设置数据的数量
# model.add(LSTM(
# units = cell_size,
# input_shape = (time_steps,input_size),
# ))
# 50个memory block输出的50个值跟输出层10个神经元全连接
# model.add(Dense(10,activation='softmax'))
# 定义优化器
adam = Adam(lr=1e-3)
# 定义优化器,loss function,训练过程中计算准确率 使用交叉熵损失函数
model.compile(optimizer=adam,loss='categorical_crossentropy',metrics=['accuracy'])
# 训练模型
history=model.fit(x_train,y_train,batch_size=64,epochs=10,validation_data=(x_test,y_test))
#打印模型摘要
model.summary()
loss=history.history['loss']
val_loss=history.history['val_loss']
accuracy=history.history['accuracy']
val_accuracy=history.history['val_accuracy']
# 绘制loss曲线
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
# 绘制acc曲线
plt.plot(accuracy, label='Training accuracy')
plt.plot(val_accuracy, label='Validation accuracy')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
这个可能对文本数据比较容易处理,这里用这个模型有点勉强,只是简单测试下。
模型摘要:

acc曲线:

loss曲线:


边栏推荐
- 【js学习笔记五十四】BFC方式
- 钉钉、企微、飞书学会赚钱了吗?
- How can non-technical departments participate in Devops?
- IDEA新建sprintboot项目
- 微信核酸检测预约小程序系统毕业设计毕设(6)开题答辩PPT
- [vite] 1371 - develop vite plug-ins by hand
- 微信小程序中,从一个页面跳转到另一个页面后,在返回后发现页面同步滚动了
- 5G NR系统架构
- Constrained layout flow
- A large number of virtual anchors in station B were collectively forced to refund: revenue evaporated, but they still owe station B; Jobs was posthumously awarded the U.S. presidential medal of freedo
猜你喜欢

Events and bubbles in the applet of "wechat applet - Basics"

IDEA新建sprintboot项目

如何判断线程池已经执行完所有任务了?

ByteDance Interviewer: how to calculate the memory size occupied by a picture

Apple 5g chip research and development failure? It's too early to get rid of Qualcomm

微信核酸检测预约小程序系统毕业设计毕设(8)毕业设计论文模板

SAP UI5 ObjectPageLayout 控件使用方法分享

ConstraintLayout官方提供圆角ImageFilterView

SAP ui5 objectpagelayout control usage sharing

The most complete is an I2C summary
随机推荐
A large number of virtual anchors in station B were collectively forced to refund: revenue evaporated, but they still owe station B; Jobs was posthumously awarded the U.S. presidential medal of freedo
C language QQ chat room small project [complete source code]
IDEA新建sprintboot项目
WorkManager學習一
微信小程序触底加载与下拉刷新的实现
Timed disappearance pop-up
Z-blog template installation and use tutorial
Activity jump encapsulation
SQL Server 监控统计阻塞脚本信息
请问大佬们 有遇到过flink cdc mongdb 执行flinksql 遇到这样的问题的么?
【SWT组件】内容滚动组件 ScrolledComposite
如何判断线程池已经执行完所有任务了?
Apple 5g chip research and development failure? It's too early to get rid of Qualcomm
Usage differences between isempty and isblank
Detailed explanation of the use of staticlayout
SLAM 01.人类识别环境&路径的模型建立
Pseudo class elements -- before and after
【js学习笔记五十四】BFC方式
Events and bubbles in the applet of "wechat applet - Basics"
Excerpt from "sword comes" (VII)