当前位置:网站首页>Pytorch学习笔记13——Basic_RNN
Pytorch学习笔记13——Basic_RNN
2022-07-31 05:16:00 【qq_50749521】
Pytorch学习笔记13——Basic_RNN
今天讨论一下基本的RNN,RNN叫做循环神经网络,是神经网络的一种。本质是对线性层的复用。
现在考虑一种场景:
已知前四天的温度、气压及是否下雨,我们需要预测第五天的气象情况。显然易见,如果我们只在第五天拿到当天的温度气压值去判断是否下雨,这是没有意义的,因为我们需要提前判断。于是这个问题就变成了:
第一天:温度 气压 是否下雨
第二天:温度 气压 是否下雨
第三天:温度 气压 是否下雨
第四天:温度 气压 是否下雨
一共四个样本,每个样本有三个特征,作为input,去预测第五天的情况。我们可以把x={x1,x2,x3,x4}拼接成一个序列,x2的数值会部分依赖于x1,x3的数值部分依赖于x2,x4的数值部分依赖于x3。
RNN主要用于处理这样具有序列性质的输入数据。

xt表示第t时刻的输入数据,在这里就是第t天的所有数据,特征为三维,因此他是3D的。经过RNN cell后,得到另一个维度的向量,假设变成5D。RNN Cell的本质就是一个线性层,把一个维度映射到另一个维度的空间内。
和我们平常用的线性层不同的是,这里的线性层是共享的。
展开来看,x1为第一天的3D数据,和初始值h0融合之后经过线性层得到h1(hidden),h1又和x2融合再经过同一个线性层得到h2,h2又和x3融合再经过同一个线性层得到h3…以此类推,这也是RNN叫做循环神经网络的原因。
这里面的初始值h0,如果没有先验,那就可以先设成一个全0向量。
用代码来表示这个过程就是:
h = 0
for x in X:
h = Linear(x, h)
那x(t) 和h(t-1)是怎么融合的呢?具体计算过程如下:
输入x(t)是(1, input_size)大小的,设h(t-1)是(1, hidden_size)大小。
x(t)经过一个self.Linear(input_size, hidden_size)进行线性变换后变成 (1, hidden_size)大小,
h(t-1)经过一个self.Linear(hidden_size, hidden_size)线性变换后变成 (1, hidden_size)大小。
目的就是使两输出矩阵大小相同,从而进行相加融合,最后通过一个tanh激活函数即可。在RNN里更多用tanh而不是sigmoid是因为tanh输出(-1, 1),往往效果更好。
而RNN Cell实现的就是中间部分的运算。
我们可以把这俩线性运算拼一块,进行向量化操作:
实际上,我们就是做了这样一件拼接的工作,把h和x拼成一个(hidden_size
+input_size)大小的矩阵,Whh和Wih拼成一个hidden_size*(hidden_size + input_size)大小的矩阵。从而变成一个线性层的工作,这就是RNN本质实现方式。
在pytorch中共有两种实现方式:创建RNNCell或直接创建RNN。
1. 创建RNNCell
cell = torch.nn.RNNCell(input_size = input_size, hidden_size = hidden_size)
hidden = cell(input, hidden)

这里把每一时刻的input和前一时刻的hidden送进去,这也注定了用这种创建方法需要我们自己写循环,遍历序列长度。
对于批量训练来说,这里的输入维度是:
input(batch_size, input_size)
hidden(batch_size, hidden_size)
输出维度:
hidden(batch_size, hidden_size)
举个例子,假设现在batch_size就等于1,每次处理1个样本。seqlen = 3,每个样本的序列长为3。输入特征维度为4,即input_size = 4;输出特征维度为2, 即hidden_size = 2
从而,RNNCell的输入矩阵是
input = (batch_size, input_size) = (1, 4)
输出矩阵是
output = (batch_size, hidden_size) = (1, 2)
整个数据集就可以写成:
dataset.shape = (seqlen, batch_size, input_size) = (3, 1, 4)
代表每次处理batch_size个样本的数据,每个样本都是由seqlen长度的序列构成的,在当前时刻有input_size个特征。
import torch
batch_size = 1
seq_len = 3
input_size = 4
hidden_size = 2
cell = torch.nn.RNNCell(input_size = input_size, hidden_size = hidden_size)
dataset = torch.randn(seq_len, batch_size, input_size)
hidden = torch.zeros(batch_size, hidden_size)
for idx, input in enumerate(dataset):
print('='*20, idx, '='*20)#按序列长度seq_len来循环,每次取出input(batch_size, input_size)
print('Input_size = ', input.shape)
hidden = cell(input, hidden)
print('Output_size = ', hidden.shape)
print(hidden) #打印每次输出
==================== 0 ====================
Input_size = torch.Size([1, 4])
Output_size = torch.Size([1, 2])
tensor([[-0.2002, -0.1809]], grad_fn=<TanhBackward0>)
==================== 1 ====================
Input_size = torch.Size([1, 4])
Output_size = torch.Size([1, 2])
tensor([[ 0.6019, -0.6742]], grad_fn=<TanhBackward0>)
==================== 2 ====================
Input_size = torch.Size([1, 4])
Output_size = torch.Size([1, 2])
tensor([[ 0.4382, -0.5559]], grad_fn=<TanhBackward0>)
2. 创建 RNN
cell = torch.nn.RNN(input_size = input_size, hidden_size = hidden_size, num_layers = num_layers)
out, hidden = cell(inputs, hidden)

从左到右:
out表示整个输出{h1, h2 ,…, hN}
hidden表示最终的输出hN
inputs表示整个输入{x1, x2, x3, x4},
hidden表示初始h0
我们把整个序列送进去,RNN内部自己完成循环,返回每一次的输出以及最终输出。
INPUT:
input.shape = (seqSize, batch_size, input_size)
hidden.shape = (numLayers, batch_size, hidden_size)
OUTPUT:
output.shape = (seqSize,batch_size, hidden_size)
hidden.shape = (numLayers, batch_size, hidden_size)
这里的numlayers指的是层数。
import torch
batch_size = 1
seq_len = 3
input_size = 4
hidden_size = 2
num_layers = 1
cell = torch.nn.RNN(input_size = input_size, hidden_size = hidden_size,
num_layers = num_layers)
inputs = torch.randn(seq_len, batch_size, input_size)
hidden = torch.zeros(num_layers, batch_size, hidden_size)
out, hidden = cell(inputs, hidden)
print('Output size = ', out.shape)
print(out)
print('Hidden size = ', hidden.shape)
print(hidden)
Output size = torch.Size([3, 1, 2])
tensor([[[ 0.7941, 0.2922]],
[[ 0.0260, 0.0201]],
[[ 0.6268, -0.0600]]], grad_fn=<StackBackward0>)
Hidden size = torch.Size([1, 1, 2])
tensor([[[ 0.6268, -0.0600]]], grad_fn=<StackBackward0>)
RNN里面还有一些其他的参数,比如batch_first,如果设置了为True,那么input时需要变成inputs = torch.randn(batch_size, seq_len, input_size).
RNN应用:
进行一个学习任务,将序列"hello"序列学习变换为“ohlol”
①根据字符构造词典,给每个字符分配索引,对词进行one-hot编码
input_size就是所有字符种类,这里等于4。
②送到RNN里,很明显是一个多分类问题,输出得到4向量,接softmax得到P(y = e), P(y = h), P(y = l), P(y = o)。与标签的one-hot编码向量进行损失计算。
import torch
batch_size = 1
input_size = 4
hidden_size = 4
idx2Char = ['e', 'h', 'l', 'o']
x_data = [1, 0, 2, 2, 3]#hello
y_data = [3, 1, 2, 3 ,2]#ohlol
one_hot_lookup = [[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]]
x_one_hot = [one_hot_lookup[x] for x in x_data]
y_one_hot = [one_hot_lookup[y] for y in y_data]
print(x_one_hot)
print(y_one_hot)
[[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
[[0, 0, 0, 1], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 1, 0]]
inputs = torch.Tensor(x_one_hot).view(-1, batch_size, input_size)
labels = torch.Tensor(x_one_hot).view(-1, batch_size, input_size)
print(inputs)
print(labels)
tensor([[[0., 1., 0., 0.]],
[[1., 0., 0., 0.]],
[[0., 0., 1., 0.]],
[[0., 0., 1., 0.]],
[[0., 0., 0., 1.]]])
tensor([[[0., 1., 0., 0.]],
[[1., 0., 0., 0.]],
[[0., 0., 1., 0.]],
[[0., 0., 1., 0.]],
[[0., 0., 0., 1.]]])
class Model(torch.nn.Module):
def __init__(self, input_size, hidden_size, batch_size):
super(Model, self).__init__()
self.batch_size = batch_size
self.input_size = input_size
self.hidden_size = hidden_size
self.rnncell = torch.nn.RNNCell(input_size = self.input_size,
hidden_size = self.hidden_size)
def forward(self, input, hidden):
hidden = self.rnncell(input, hidden)
return hidden
def init_hidden(self):
return torch.zeros(self.batch_size, self.hidden_size)
net = Model(input_size, hidden_size, batch_size)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr = 0.1)
for epoch in range(25):
loss = 0
optimizer.zero_grad()
hidden = net.init_hidden()
for input, label in zip(inputs, labels):
hidden = net(input, hidden)
loss += criterion(hidden, label)
_, idx = hidden.max(dim=1)
print(idx.item(), end='')
loss.backward()
optimizer.step()
print(", Epoch: [%d/25] loss = %.4f" % (epoch+1, loss.item()))
最终输出:
12131, Epoch: [1/25] loss = 7.0161
12333, Epoch: [2/25] loss = 6.0797
12223, Epoch: [3/25] loss = 5.3174
12223, Epoch: [4/25] loss = 4.7188
10223, Epoch: [5/25] loss = 4.2018
10223, Epoch: [6/25] loss = 3.7728
10223, Epoch: [7/25] loss = 3.4347
10223, Epoch: [8/25] loss = 3.1307
10223, Epoch: [9/25] loss = 2.8780
10223, Epoch: [10/25] loss = 2.6863
10223, Epoch: [11/25] loss = 2.5337
10223, Epoch: [12/25] loss = 2.4025
10223, Epoch: [13/25] loss = 2.2894
10223, Epoch: [14/25] loss = 2.1993
10223, Epoch: [15/25] loss = 2.1284
10223, Epoch: [16/25] loss = 2.0683
10223, Epoch: [17/25] loss = 2.0166
10223, Epoch: [18/25] loss = 1.9729
10223, Epoch: [19/25] loss = 1.9360
10223, Epoch: [20/25] loss = 1.9051
10223, Epoch: [21/25] loss = 1.8795
10223, Epoch: [22/25] loss = 1.8586
10223, Epoch: [23/25] loss = 1.8415
10223, Epoch: [24/25] loss = 1.8273
10223, Epoch: [25/25] loss = 1.8152
10223代表ohlol,可见损失逐步减小,模型学习到了从hello到ohlol的序列变换。
边栏推荐
猜你喜欢

MySql to create data tables

动态规划(一)| 斐波那契数列和归递

对js的数组的理解

RuntimeError: CUDA error: no kernel image is available for execution on the device问题记录

Hyper-V新建虚拟机注意事项

flutter arr 依赖

The browser looks for events bound or listened to by js

为数学而歌之伯努利家族

2021美赛C题M奖思路

The feign call fails, JSON parse error Illegal character ((CTRL-CHAR, code 31)) only regular white space (r
随机推荐
js中的this指向与原型对象
UiBot has an open Microsoft Edge browser and cannot perform the installation
对js的数组的理解
MySQL错误-this is incompatible with sql_mode=only_full_group_by完美解决方案
After unicloud is released, the applet prompts that the connection to the local debugging service failed. Please check whether the client and the host are under the same local area network.
数据库 | SQL查询进阶语法
ERROR Error: No module factory availabl at Object.PROJECT_CONFIG_JSON_NOT_VALID_OR_NOT_EXIST ‘Error
quick-3.6源码修改纪录
kotlin 插件更新到1.3.21
Attribute Changer的几种形态
Navicat从本地文件中导入sql文件
Access database query
人脸识别AdaFace学习笔记
Podspec verification dependency error problem pod lib lint , need to specify the source
cocos2d-x-3.2 不能混合颜色修改
Talking about the understanding of CAP in distributed mode
DC-CDN学习笔记
Android software security and reverse analysis reading notes
Pytorch实现ResNet
一文速学-玩转MySQL获取时间、格式转换各类操作方法详解