当前位置:网站首页>24.循环神经网络RNN
24.循环神经网络RNN
2022-08-04 07:03:00 【派大星的最爱海绵宝宝】
目录
RNN原理
对于一个句子,可能会因为句子过长而导致参数太多,或者上下不联系,我们无法像平常说话一样根据前文推理后面。
我们通过权值共享减少参数,每个单词共用同一个w和
b。
使用memory,持续的单元(consistent memory)贯穿整个网络,能保存语境信息。
最开始的memory初始化为h0,与feature送入的xt合在一起考虑,会生成新的memory为ht,类似于语境信息,ht会根据上一个语境或者新的输入不停的更新,这也是RNN与CNN不同的地方。

对于ht的选取,根据自己的需要,可以选取最后一个,也可以选取综合的ht。

Whh是对h的一个特征提取,Wxh对x的特征提取,整体合起来来经过激活函数tanh。
yt可以理解为一个Linear层,ht根据当前的ht做一个变换得到yt。
how to train

因为权值共享,使用同一个w和h,当前时刻的w和h对当前的y进行影响。
推导过程


RNNLayer使用
rnn=nn.RNN(100,10)
print(rnn._parameters.keys())
print(rnn.weight_ih_l0.shape)
print(rnn.bias_ih_l0.shape)
100是word dim即input dim,10是memory即hidden dim

nn.RNN
init
input_size,表示用几个向量来表示我们的单词
hidden_size,即memory size
num_layers,默认是1
out,ht=forward(x,h0)
x:[seq len,b,word vec],单词的数量、句子的数量、input_size。
h0/ht:[num layers,b,h dim],是特定时间的,h是最后一个时间的所有状态
out:[seq len,b,h dim],out是返回所有聚合过的信息。是所有时间点上的最后一个状态。
单层RNN

rnn=nn.RNN(input_size=100,hidden_size=20,num_layers=1)
print(rnn)
x=torch.randn(10,3,100)
out,h=rnn(x,torch.zeros(1,3,20))
print('out:',out.shape)
print('h:',h.shape)

多层RNN

原本的h是[1,b,20],out是[10,b,20,]现在h变成了[2,b,20],out是[10,b,20]。对于多层来说,out是不变的,永远都是所有的。
rnn=nn.RNN(input_size=100,hidden_size=20,num_layers=4)
print(rnn)
x=torch.randn(10,3,100)
out,h=rnn(x)
print('out:',out.shape)
print('h:',h.shape)

nn.RNNCell
nn.RNN都是句子一次性送入,nn.RNNCell不会循环。
两者的初始化一样,forward不一样。
ht=rnncell(xt,ht_1)
xt:[b,word vec]
ht_1/ht:[num layers,b,h dim]
out=torch.stack([h1,h2,…,ht])
单层
我们需要进行人为的循环。
有几个时间戳就for几次。
x=torch.randn(10,3,100)
cell1=nn.RNNCell(100,20)
h1=torch.zeros(3,20)
for xt in x:
h1=cell1(xt,h1)
print(h1.shape)

多层
cell1和cell2的30位置必须匹配起来。
cell1=nn.RNNCell(100,30)
cell2=nn.RNNCell(30,20)
h1=torch.zeros(3,30)
h2=torch.zeros(3,20)
for xt in x:
h1=cell1(xt,h1)
h2=cell2(h1,h2)
print(h2.shape)

边栏推荐
猜你喜欢

The national vocational skills contest competition of network security emergency response
![[Paper Notes] - Low Illumination Image Enhancement - Supervised - RetinexNet - 2018-BMVC](/img/54/685fb2620aa53416437943705d3d38.png)
[Paper Notes] - Low Illumination Image Enhancement - Supervised - RetinexNet - 2018-BMVC

打破千篇一律,DIY属于自己独一无二的商城

data:image/jpg;base64格式数据转化为图片

Distributed Computing Experiment 1 Load Balancing

Verilog“七宗罪”

Error EPERM operation not permitted, mkdir ‘Dsoftwarenodejsnode_cache_cacach两种解决办法

一天搞定JDBC02:开启事务

登录拦截实现过程

RT-Thread Studio学习(十二)W25Q128(SPI)的读写
随机推荐
redis stream 实现消息队列
QT + msvc2017编译器
likeshop单商户高级版企业源码发布了新的版本1.8.1
unity 循环选择器
力扣每日一题-第47天-15. 三数之和
【selenium自动化】第四篇,结合testNg
adb无法桥接夜神模拟器
高等代数_证明_对称矩阵一定能够相似对角化
unity3d-Animation&&Animator接口(基本使用)
ContrstrainLayout的动画之ConstraintSet
海康VisionMaster与西门子Smart 200进行S7通信
Provide 和 Inject 的用法
MYSQL JDBC图书管理系统
SQL去重的三种方法汇总
【学习笔记】AGC036
解决循环依赖import cycle not allowed的最佳解决办法
分布式计算实验1 负载均衡
GIS数据与CAD数据间带属性字段互相转换还原工具,解决ArcGIS等软件进行GIS数据转CAD数据无法保留属性字段问题
使用腾讯云发送短信 ---- 手把手教你搞定所有步骤
一天学会JDBC04:ResultSet的用法