当前位置:网站首页>长短期记忆网络 LSTM
长短期记忆网络 LSTM
2022-08-03 08:15:00 【OPTree412】
这里写目录标题
1. LSTM介绍
1.1 什么是LSTM
长短期记忆(Long short-term memory, LSTM)是一种特殊的RNN,主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题
。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。
1.2 LSTM相较于RNN的优势
举个表现更好的例子:
我们想预测’full’之前系动词的单复数情况,显然full是取决于第二个单词’cat‘的单复数情况,而非其前面的单词food。
- 根据RNN的结构,随着数据时间片的增加,RNN丧失了学习连接如此远的信息的能力。如下图所示,越到后面,前面单词的信息就越来越少了。
- 根据LSTM的结构,能够解决RNN的长期依赖问题,是因为
LSTM引入了门(gate)机制用于控制特征的流通和损失
。对于上面的例子,LSTM可以做到在t9时刻将t2时刻的特征传过来,这样就可以非常有效的判断t9时刻使用单数还是复数了。
1.3 LSTM的结构图
从下图我们可以看到相较于RNN模型,LSTM好想复杂了好多。这正式因为引入了门机制来控制特征的流通和损失,因此LSTM能够在更长的序列中,相较于RNN有更好的表现。
其中,Ct-1 表示过去的记忆, Ct 表示当前时刻神经元的状态,ht-1 表示输入,ht表示输出,Xt表示当前输入。C表示长期记忆,h可以看做是短期记忆,x代表事件信息,也就是输入。这个些很重要!!!!!
现在来看看里面详细的样子
1.3.1 LSTM的核心思想
LSTM主线就是这条顶部水平贯穿的线,也就是长期记忆C线
(细胞状态),达到了序列学习的目的。而h可以看做是短期记忆
,x代表事件信息
,也就是输入。
LSTM具有其中三个门,以保护和控制单元状态。LSTM通过“门”(gate)来控制丢弃或者增加信息,从而实现遗忘或记忆的功能。
“门”是一种使信息选择性通过的结构,由一个sigmoid函数和一个点乘操作组成。sigmoid函数的输出值在[0,1]区间,0代表完全丢弃,1代表完全通过。
1.3.2 LSTM的遗忘门
顾名思义,这个遗忘门就是决定要不要以往的信息,遗忘门控制之前记忆状态的输入幅度。
遗忘门的任务就是接受一个长期记忆 Ct-1(上一个单元模块传过来的输出)并决定要保留和遗忘Ct-1的哪个部分。
例如,在语言模型中,我们想要通过一个词来预测下一个词,单元模块cell 中可能包含某个属性在下一个模块单元不需要,就可以将这个属性在单元模块cell 中遗忘(丢弃)。
1.3.3 LSTM的输入门
输入门控制输入(新记忆)的输入幅度,决定要让多少新的信息加入到cell 状态中。
输入门包括两个部分:
1、sigmoid层,决定那些信息需要更新。(图中第一个公式)
2、tanh层,创建一个新的候选值(candidate)向量,生成候选记忆。(图中第二个公式)
下面,我们把这两个部分联合起来对 cell 状态进行更新。
我们将旧状态Ct-1乘以ft,忘把一些不想保留的信息忘掉,然后加上输入的Ct∗it来形成新的状态Ct。根据我们决定更新每个状态值的程度进行缩放。
1.3.4 LSTM的输出门
输出门控制最终记忆的输出幅度,这个输出主要是依赖于 cell 状态 Ct。
输出门也包括两个部分:
1、经过 sigmoid 层,它决定Ct中的哪些部分将会被输出。
2、tanh层,把Ct的数值控制到(-1,1)之间,然后输出并与 simoid 层计算出来的权重相乘,这样就得到了最后的输出结果。
1.4 LSTM的优缺点
优点:
- 面对对时间序列敏感的问题和任务,RNN(如LSTM)通常会比较合适。RNN用于序列数据,并且有了一定的记忆效应;
- 采用了特殊隐式单元的LSTM是为了长期的保存输入。一种称作记忆细胞的特殊单元类似累加器和门控神经元:它在下一个时间步长将拥有一个权值并联接到自身,拷贝自身状态的真实值和累积的外部信号,但这种自联接是由另一个单元学习并决定何时清除记忆内容的乘法门控制的;
- LSTM是RNN的一个优秀的变种模型,继承了大部分RNN模型的特性,同时解决了梯度反传过程由于逐步缩减而产生的梯度消失问题。
缺点:
- 并行处理上存在劣势。与一些最新的网络相对效果一般;
- RNN的梯度问题在LSTM及其变种里面得到了一定程度的解决,但还是不够。它可以处理100个量级的序列,而对于1000个量级,或者更长的序列则依然会显得很棘手;
- 计算费时。每一个LSTM的cell里面都意味着有4个全连接层(MLP),如果LSTM的时间跨度很大,并且网络又很深,这个计算量会很大,很耗时。
1.5 如何计算 LSTM 的参数量?
步一步来看,实际上这里面有 4 个非线性变换(3 个 sigmoid + 1 个 tanh),每一个非线性变换说白了就是一个全连接网络。
怕您看不太懂,我们先来看一下全连接层的参数个数和输出维度
如图所示,每一个输入都要与所有的神经元进行连接,因此是n_input * n_neur = 2 * 4,然后再加上bias,参数就是n_input * n_neur + n_b= 2 * 4 + 4
那么重点来了,输入是 xt和 ht-1的结合,维度就是 embedding_size + hidden_size。输出层的维度为 hidden_size,然后再加上bias的维度hidden_size。所以该网络的参数量就是:(embedding_size + hidden_size)* hidden_size + hidden_size
一个 cell 有 4 个这样结构相同的网络,那么一个 cell 的总参数量就是直接 × 4:[(embedding_size + hidden_size)* hidden_size + hidden_size ] * 4
1.6 LSTM与GRU相比怎么样呢?
与GRU 相比,LSTM的隐层节点的门的数量和工作方式貌似是非常灵活的,那么是否存在一个最好的结构模型或者比LSTM和GRU性能更好的模型呢?
Rafal 等人采集了能采集到的100个最好模型,然后在这100个模型的基础上通过变异的形式产生了10000个新的模型。然后通过在字符串,结构化文档,语言模型,音频4个场景的实验比较了这10000多个模型,得出的重要结论总结如下:
- GRU,LSTM是表现最好的模型。
- GRU的在除了语言模型的场景中表现均超过LSTM。
- LSTM的输出门的偏置的均值初始化为1时,LSTM的性能接近GRU。
- 在LSTM中,门的重要性排序是遗忘门 > 输入门 > 输出门。
2. nn.LSTM()
2.1 nn.LSTM()的参数解释
nn.LSTM()中的参数和nn.RNN()的参数基本一样,可以参考我的循环神经网络(RNN)当中对参数的解释。
2.2 nn.LSTM()的输入输出
输入数据包括input,(h_0,c_0):
input
就是shape=(seq_length,batch_size,input_size)的张量h_0
是shape=(num_layers*num_directions,batch_size,hidden_size)的张量,它包含了在当前这个batch_size中每个句子的初始隐藏状态。其中num_layers就是LSTM的层数。如果bidirectional=True,num_directions=2,否则就是1,表示只有一个方向。c_0
和h_0的形状相同,它包含的是在当前这个batch_size中的每个句子的初始细胞状态。h_0,c_0如果不提供,那么默认是0。
输出数据包括output,(h_n,c_n):
output
的shape=(seq_length,batch_size,num_directions*hidden_size),
它包含的是LSTM的最后一时间步的输出特征(h_t),t是batch_size中每个句子的长度。h_n
包含的是句子的最后一个单词(也就是最后一个时间步)的隐藏状态,c_n
包含的是句子的最后一个单词的细胞状态,所以它们都与句子的长度seq_length无关。h_n
.shape==(num_directions * num_layers,batch,hidden_size)c_n
.shape==h_n.shape- output[-1]与h_n是相等的,因为
output[-1]包含的正是batch_size个句子中每一个句子的最后一个单词的隐藏状态
。注意LSTM中的隐藏状态其实就是输出,cell state细胞状态才是LSTM中一直隐藏的,记录着信息。
边栏推荐
- 进程信息
- IDEA2021.2安装与配置(持续更新)
- Eject stubborn hard drives with diskpart's offline command
- ArcEngine(八)用IWorkspaceFactory加载矢量数据
- Add Modulo 10 (规律循环节,代码实现细节)
- Docker starts mysql
- “唯一索引允许为空“ 的说法是不严谨的
- vs 2022无法安装 vc_runtimeMinmum_x86错误
- PowerShell:执行 Install-Module 时,不能从 URI 下载
- 使用pipreqs导出项目所需的requirements.txt(而非整个环境)
猜你喜欢
随机推荐
PowerShell:执行 Install-Module 时,不能从 URI 下载
【论文笔记】基于动作空间划分的MAXQ自动分层方法
dflow入门2——Slices
frp: open source intranet penetration tool
The use of the database table structure document generation tool screw
WPS EXCEL 筛选指定长度的文本 内容 字符串
BOM系列之localStorage
uniapp swiper 卡片轮播 修改指示点样式效果demo(整理)
Evaluate:huggingface评价指标模块入门详细介绍
【愚公系列】2022年07月 Go教学课程 026-结构体
IDEA2021.2安装与配置(持续更新)
Laya中关于摄像机跟随人物移动或者点击人物碰撞器触发事件的Demo
mysql服务器上的mysql这个实例中表的介绍
ArcEngine (5) use the ICommand interface to achieve zoom in and zoom out
LAN技术-2免费ARP
【TPC-DS】25张表的详细介绍,SQL的查询特征
流行和声基础大笔记
行业洞察 | 如何更好的实现与虚拟人的互动体验?
ArcEngine(二)加载地图文档
dflow入门5——Big step & Big parameter