当前位置:网站首页>双向RNN与堆叠的双向RNN
双向RNN与堆叠的双向RNN
2022-07-05 10:20:00 【别团等shy哥发育】
双向RNN与堆叠的双向RNN
1、双向RNN
双向RNN(Bidirectional RNN)的结构如下图所示。
h t → = f ( W → x t + V → h t − 1 → + b → ) h t ← = f ( W ← x t + V ← h t − 1 ← + b ← ) y t = g ( U [ h t → ; h t ← ] + c ) \overrightarrow{h_t}=f(\overrightarrow{W}x_t+\overrightarrow{V}\overrightarrow{h_{t-1}}+\overrightarrow{b})\\ \overleftarrow{h_t}=f(\overleftarrow{W}x_t+\overleftarrow{V}\overleftarrow{h_{t-1}}+\overleftarrow{b})\\ y_t=g(U[\overrightarrow{h_t};\overleftarrow{h_t}]+c) ht=f(Wxt+Vht−1+b)ht=f(Wxt+Vht−1+b)yt=g(U[ht;ht]+c)
这里的 RNN 可以使用任意一种 RNN 结构 SimpleRNN,LSTM 或 GRU。这里箭头表示从左到右或从右到左传播,对于每个时刻的预测,都需要来自双向的特征向量,拼接 (Concatenate)后进行结果预测。箭头虽然不同,但参数还是同一套参数。有些模型中也 可以使用两套不同的参数。f,g表示激活函数, [ h t → ; h t ← ] [\overrightarrow{h_t};\overleftarrow{h_t}] [ht;ht]表示数据拼接(Concatenate)。
双向的 RNN 是同时考虑“过去”和“未来”的信息。上图是一个序列长度为 4 的双向RNN 结构。
比如输入 x 1 x_1 x1沿着实线箭头传输到隐层得到 h 1 h_1 h1,然后还需要再利用 x t x_t xt计算得到 h t ′ h_t' ht′,利用 x 3 x_3 x3和 h t ′ h_t' ht′计算得到 h 3 ′ h_3' h3′,利用 x 2 x_2 x2和 h 3 ′ h_3' h3′计算得到 h 2 ′ h_2' h2′,利用 x 1 x_1 x1和’h_2’计算得到 h 1 ′ h_1' h1′,最后再把 h 1 h_1 h1和 h 1 ′ h_1' h1′进行数据拼接(Concatenate),得到输出结果 y 1 y_1 y1。以此类推,同时利用前向传递和反向传递的数据进行结果的预测。
双向RNN就像是我们做阅读理解的时候从头向后读一遍文章,然后又从后往前读一遍文章,然后再做题。有可能从后往前再读一遍文章的时候会有新的不一样的理解,最后模型可能会得到更好的结果。
2、堆叠的双向RNN
堆叠的双向RNN(Stacked Bidirectional RNN)的结构如上图所示。上图是一个堆叠了3个隐藏层的RNN网络。
注意,这里的堆叠的双向RNN并不是只有双向的RNN才可以堆叠,其实任意的RNN都可以堆叠,如SimpleRNN、LSTM和GRU这些循环神经网络也可以进行堆叠。
堆叠指的是在RNN的结构中叠加多层,类似于BP神经网络中可以叠加多层,增加网络的非线性。
3、双向LSTM实现MNIST数据集分类
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import LSTM,Dropout,Bidirectional
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
# 载入数据集
mnist = tf.keras.datasets.mnist
# 载入数据,数据载入的时候就已经划分好训练集和测试集
# 训练集数据x_train的数据形状为(60000,28,28)
# 训练集标签y_train的数据形状为(60000)
# 测试集数据x_test的数据形状为(10000,28,28)
# 测试集标签y_test的数据形状为(10000)
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 对训练集和测试集的数据进行归一化处理,有助于提升模型训练速度
x_train, x_test = x_train / 255.0, x_test / 255.0
# 把训练集和测试集的标签转为独热编码
y_train = tf.keras.utils.to_categorical(y_train,num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test,num_classes=10)
# 数据大小-一行有28个像素
input_size = 28
# 序列长度-一共有28行
time_steps = 28
# 隐藏层memory block个数
cell_size = 50
# 创建模型
# 循环神经网络的数据输入必须是3维数据
# 数据格式为(数据数量,序列长度,数据大小)
# 载入的mnist数据的格式刚好符合要求
# 注意这里的input_shape设置模型数据输入时不需要设置数据的数量
model = Sequential([
Bidirectional(LSTM(units=cell_size,input_shape=(time_steps,input_size),return_sequences=True)),
Dropout(0.2),
Bidirectional(LSTM(cell_size)),
Dropout(0.2),
# 50个memory block输出的50个值跟输出层10个神经元全连接
Dense(10,activation=tf.keras.activations.softmax)
])
# 循环神经网络的数据输入必须是3维数据
# 数据格式为(数据数量,序列长度,数据大小)
# 载入的mnist数据的格式刚好符合要求
# 注意这里的input_shape设置模型数据输入时不需要设置数据的数量
# model.add(LSTM(
# units = cell_size,
# input_shape = (time_steps,input_size),
# ))
# 50个memory block输出的50个值跟输出层10个神经元全连接
# model.add(Dense(10,activation='softmax'))
# 定义优化器
adam = Adam(lr=1e-3)
# 定义优化器,loss function,训练过程中计算准确率 使用交叉熵损失函数
model.compile(optimizer=adam,loss='categorical_crossentropy',metrics=['accuracy'])
# 训练模型
history=model.fit(x_train,y_train,batch_size=64,epochs=10,validation_data=(x_test,y_test))
#打印模型摘要
model.summary()
loss=history.history['loss']
val_loss=history.history['val_loss']
accuracy=history.history['accuracy']
val_accuracy=history.history['val_accuracy']
# 绘制loss曲线
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
# 绘制acc曲线
plt.plot(accuracy, label='Training accuracy')
plt.plot(val_accuracy, label='Validation accuracy')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
这个可能对文本数据比较容易处理,这里用这个模型有点勉强,只是简单测试下。
模型摘要:
acc曲线:
loss曲线:
边栏推荐
- Pseudo class elements -- before and after
- 微信核酸检测预约小程序系统毕业设计毕设(7)中期检查报告
- [论文阅读] CKAN: Collaborative Knowledge-aware Atentive Network for Recommender Systems
- DDOS攻击原理,被ddos攻击的现象
- 【DNS】“Can‘t resolve host“ as non-root user, but works fine as root
- Excerpt from "sword comes" (VII)
- What is the most suitable book for programmers to engage in open source?
- Constraintlayout officially provides rounded imagefilterview
- 风控模型启用前的最后一道工序,80%的童鞋在这都踩坑
- @SerializedName注解使用
猜你喜欢
【js学习笔记五十四】BFC方式
What is the origin of the domain knowledge network that drives the new idea of manufacturing industry upgrading?
AtCoder Beginner Contest 254「E bfs」「F st表维护差分数组gcd」
SAP UI5 ObjectPageLayout 控件使用方法分享
Workmanager learning 1
[observation] with the rise of the "independent station" model of cross-border e-commerce, how to seize the next dividend explosion era?
How can non-technical departments participate in Devops?
重磅:国产IDE发布,由阿里研发,完全开源!
What is the most suitable book for programmers to engage in open source?
[dark horse morning post] Luo Yonghao responded to ridicule Oriental selection; Dong Qing's husband Mi Chunlei was executed for more than 700million; Geely officially acquired Meizu; Huawei releases M
随机推荐
QT implements JSON parsing
橫向滾動的RecycleView一屏顯示五個半,低於五個平均分布
【js学习笔记五十四】BFC方式
Who is the "conscience" domestic brand?
What are the top ten securities companies? Is it safe to open an account online?
WorkManager的学习二
@Jsonadapter annotation usage
SLAM 01.人类识别环境&路径的模型建立
Activity jump encapsulation
How did automated specification inspection software develop?
Today in history: the first e-book came out; The inventor of magnetic stripe card was born; The pioneer of handheld computer was born
leetcode:1200. 最小绝对差
Interview: is bitmap pixel memory allocated in heap memory or native
Flink CDC cannot monitor MySQL logs. Have you ever encountered this problem?
isEmpty 和 isBlank 的用法区别
《天天数学》连载58:二月二十七日
Error: module not found: error: can't resolve 'xxx' in 'XXXX‘
手机厂商“互卷”之年:“机海战术”失灵,“慢节奏”打法崛起
如何判断线程池已经执行完所有任务了?
Z-blog template installation and use tutorial