当前位置:网站首页>双向RNN与堆叠的双向RNN
双向RNN与堆叠的双向RNN
2022-07-05 10:20:00 【别团等shy哥发育】
双向RNN与堆叠的双向RNN
1、双向RNN
双向RNN(Bidirectional RNN)的结构如下图所示。
h t → = f ( W → x t + V → h t − 1 → + b → ) h t ← = f ( W ← x t + V ← h t − 1 ← + b ← ) y t = g ( U [ h t → ; h t ← ] + c ) \overrightarrow{h_t}=f(\overrightarrow{W}x_t+\overrightarrow{V}\overrightarrow{h_{t-1}}+\overrightarrow{b})\\ \overleftarrow{h_t}=f(\overleftarrow{W}x_t+\overleftarrow{V}\overleftarrow{h_{t-1}}+\overleftarrow{b})\\ y_t=g(U[\overrightarrow{h_t};\overleftarrow{h_t}]+c) ht=f(Wxt+Vht−1+b)ht=f(Wxt+Vht−1+b)yt=g(U[ht;ht]+c)
这里的 RNN 可以使用任意一种 RNN 结构 SimpleRNN,LSTM 或 GRU。这里箭头表示从左到右或从右到左传播,对于每个时刻的预测,都需要来自双向的特征向量,拼接 (Concatenate)后进行结果预测。箭头虽然不同,但参数还是同一套参数。有些模型中也 可以使用两套不同的参数。f,g表示激活函数, [ h t → ; h t ← ] [\overrightarrow{h_t};\overleftarrow{h_t}] [ht;ht]表示数据拼接(Concatenate)。
双向的 RNN 是同时考虑“过去”和“未来”的信息。上图是一个序列长度为 4 的双向RNN 结构。
比如输入 x 1 x_1 x1沿着实线箭头传输到隐层得到 h 1 h_1 h1,然后还需要再利用 x t x_t xt计算得到 h t ′ h_t' ht′,利用 x 3 x_3 x3和 h t ′ h_t' ht′计算得到 h 3 ′ h_3' h3′,利用 x 2 x_2 x2和 h 3 ′ h_3' h3′计算得到 h 2 ′ h_2' h2′,利用 x 1 x_1 x1和’h_2’计算得到 h 1 ′ h_1' h1′,最后再把 h 1 h_1 h1和 h 1 ′ h_1' h1′进行数据拼接(Concatenate),得到输出结果 y 1 y_1 y1。以此类推,同时利用前向传递和反向传递的数据进行结果的预测。
双向RNN就像是我们做阅读理解的时候从头向后读一遍文章,然后又从后往前读一遍文章,然后再做题。有可能从后往前再读一遍文章的时候会有新的不一样的理解,最后模型可能会得到更好的结果。
2、堆叠的双向RNN
堆叠的双向RNN(Stacked Bidirectional RNN)的结构如上图所示。上图是一个堆叠了3个隐藏层的RNN网络。
注意,这里的堆叠的双向RNN并不是只有双向的RNN才可以堆叠,其实任意的RNN都可以堆叠,如SimpleRNN、LSTM和GRU这些循环神经网络也可以进行堆叠。
堆叠指的是在RNN的结构中叠加多层,类似于BP神经网络中可以叠加多层,增加网络的非线性。
3、双向LSTM实现MNIST数据集分类
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import LSTM,Dropout,Bidirectional
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
# 载入数据集
mnist = tf.keras.datasets.mnist
# 载入数据,数据载入的时候就已经划分好训练集和测试集
# 训练集数据x_train的数据形状为(60000,28,28)
# 训练集标签y_train的数据形状为(60000)
# 测试集数据x_test的数据形状为(10000,28,28)
# 测试集标签y_test的数据形状为(10000)
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 对训练集和测试集的数据进行归一化处理,有助于提升模型训练速度
x_train, x_test = x_train / 255.0, x_test / 255.0
# 把训练集和测试集的标签转为独热编码
y_train = tf.keras.utils.to_categorical(y_train,num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test,num_classes=10)
# 数据大小-一行有28个像素
input_size = 28
# 序列长度-一共有28行
time_steps = 28
# 隐藏层memory block个数
cell_size = 50
# 创建模型
# 循环神经网络的数据输入必须是3维数据
# 数据格式为(数据数量,序列长度,数据大小)
# 载入的mnist数据的格式刚好符合要求
# 注意这里的input_shape设置模型数据输入时不需要设置数据的数量
model = Sequential([
Bidirectional(LSTM(units=cell_size,input_shape=(time_steps,input_size),return_sequences=True)),
Dropout(0.2),
Bidirectional(LSTM(cell_size)),
Dropout(0.2),
# 50个memory block输出的50个值跟输出层10个神经元全连接
Dense(10,activation=tf.keras.activations.softmax)
])
# 循环神经网络的数据输入必须是3维数据
# 数据格式为(数据数量,序列长度,数据大小)
# 载入的mnist数据的格式刚好符合要求
# 注意这里的input_shape设置模型数据输入时不需要设置数据的数量
# model.add(LSTM(
# units = cell_size,
# input_shape = (time_steps,input_size),
# ))
# 50个memory block输出的50个值跟输出层10个神经元全连接
# model.add(Dense(10,activation='softmax'))
# 定义优化器
adam = Adam(lr=1e-3)
# 定义优化器,loss function,训练过程中计算准确率 使用交叉熵损失函数
model.compile(optimizer=adam,loss='categorical_crossentropy',metrics=['accuracy'])
# 训练模型
history=model.fit(x_train,y_train,batch_size=64,epochs=10,validation_data=(x_test,y_test))
#打印模型摘要
model.summary()
loss=history.history['loss']
val_loss=history.history['val_loss']
accuracy=history.history['accuracy']
val_accuracy=history.history['val_accuracy']
# 绘制loss曲线
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
# 绘制acc曲线
plt.plot(accuracy, label='Training accuracy')
plt.plot(val_accuracy, label='Validation accuracy')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
这个可能对文本数据比较容易处理,这里用这个模型有点勉强,只是简单测试下。
模型摘要:
acc曲线:
loss曲线:
边栏推荐
猜你喜欢
Secteur non technique, comment participer à devops?
ConstraintLayout的流式布局Flow
The most complete is an I2C summary
What is the most suitable book for programmers to engage in open source?
Timed disappearance pop-up
Learning Note 6 - satellite positioning technology (Part 1)
【观察】跨境电商“独立站”模式崛起,如何抓住下一个红利爆发时代?
Usage differences between isempty and isblank
“军备竞赛”时期的对比学习
[paper reading] ckan: collaborative knowledge aware autonomous network for adviser systems
随机推荐
WorkManager的学习二
Shortcut keys for vscode
PHP solves the problems of cache avalanche, cache penetration and cache breakdown of redis
Redis如何实现多可用区?
The horizontally scrolling recycleview displays five and a half on one screen, lower than the average distribution of five
How did automated specification inspection software develop?
ModuleNotFoundError: No module named ‘scrapy‘ 终极解决方式
非技術部門,如何參與 DevOps?
CSDN always jumps to other positions when editing articles_ CSDN sends articles without moving the mouse
Have you learned to make money in Dingding, enterprise micro and Feishu?
In the year of "mutual entanglement" of mobile phone manufacturers, the "machine sea tactics" failed, and the "slow pace" playing method rose
【DNS】“Can‘t resolve host“ as non-root user, but works fine as root
AtCoder Beginner Contest 254「E bfs」「F st表维护差分数组gcd」
How to judge that the thread pool has completed all tasks?
各位大佬,我测试起了3条线程同时往3个mysql表中写入,每条线程分别写入100000条数据,用了f
Z-blog template installation and use tutorial
【观察】跨境电商“独立站”模式崛起,如何抓住下一个红利爆发时代?
Livedata interview question bank and answers -- 7 consecutive questions in livedata interview~
报错:Module not found: Error: Can‘t resolve ‘XXX‘ in ‘XXXX‘
Glide advanced level