TensorFlow raw_rnn: feeding the previous timestep's output as the next timestep's input in seq2seq
2022-07-25 09:23:00 【Tobi_Obito】
### The core problem
In most cases, an RNN's input sequence is defined in advance; the most common example is a sentence from the training corpus. In sequence-generation tasks, however, we sometimes want the prediction at time t (after some transformation) to serve as the input at time t+1. In other words, we do not start with a complete sentence: at the very beginning (t = 0) all we have is the start marker "<START>". We feed <START> into the RNN to get the initial output o_0, then use o_0 (possibly after some transformation) as the input at the next step (t = 1), i.e. x_1 = f(o_0); feeding x_1 into the RNN yields the output o_1, and so on, until a specified length is reached (or the end marker "<END>" is predicted).
Clearly this is a dynamic process. The key to implementing it is to do some processing between timesteps (turning the output at time t into the input at time t+1). But the common RNN wrappers provide no hook for processing between steps of the computation (including dynamic_rnn, whose "dynamic" only means that it runs as a loop rather than unrolling a predefined input sequence; we won't go into that here, more details are easy to look up). tf.nn.raw_rnn, by contrast, exposes exactly this kind of lower-level control; a conceptual sketch of the loop we want follows below.
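To make the process concrete, here is a minimal feed-previous loop in plain Python. This is a conceptual sketch only, not TensorFlow code: `cell` and `transform` below are dummy stand-ins for the RNN step and the output-to-input mapping.

```python
def cell(x, state):
    # Dummy RNN step: pretend the "output" is just x + 1.
    return x + 1, state

def transform(output):
    # Dummy f: the output of step t becomes the input of step t+1.
    return output

START, END, max_len = 0, 5, 10
x, state, outputs = START, None, []
while len(outputs) < max_len:
    output, state = cell(x, state)  # o_t from x_t
    outputs.append(output)
    x = transform(output)           # x_{t+1} = f(o_t)
    if x == END:                    # stop at the end marker
        break

print(outputs)  # [1, 2, 3, 4, 5]
```

Ordinary RNN wrappers run this loop for you but give you no hook at the `x = transform(output)` step; raw_rnn does.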
### A brief look at the tf.nn.raw_rnn API
First, part of the source (essentially its docstring); the idea is laid out clearly:
````python
def raw_rnn(cell, loop_fn,
            parallel_iterations=None, swap_memory=False, scope=None):
  """Creates an `RNN` specified by RNNCell `cell` and loop function `loop_fn`.

  **NOTE: This method is still in testing, and the API may change.**

  This function is a more primitive version of `dynamic_rnn` that provides
  more direct access to the inputs each iteration.  It also provides more
  control over when to start and finish reading the sequence, and
  what to emit for the output.

  For example, it can be used to implement the dynamic decoder of a seq2seq
  model.

  Instead of working with `Tensor` objects, most operations work with
  `TensorArray` objects directly.

  The operation of `raw_rnn`, in pseudo-code, is basically the following:
  ```python
  time = tf.constant(0, dtype=tf.int32)
  (finished, next_input, initial_state, emit_structure, loop_state) = loop_fn(
      time=time, cell_output=None, cell_state=None, loop_state=None)
  emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype)
  state = initial_state
  while not all(finished):
    (output, cell_state) = cell(next_input, state)
    (next_finished, next_input, next_state, emit, loop_state) = loop_fn(
        time=time + 1, cell_output=output, cell_state=cell_state,
        loop_state=loop_state)
    # Emit zeros and copy forward state for minibatch entries that are finished.
    state = tf.where(finished, state, next_state)
    emit = tf.where(finished, tf.zeros_like(emit_structure), emit)
    emit_ta = emit_ta.write(time, emit)
    # If any new minibatch entries are marked as finished, mark these.
    finished = tf.logical_or(finished, next_finished)
    time += 1
  return (emit_ta, state, loop_state)
  ```
  with the additional properties that output and state may be (possibly nested)
  tuples, as determined by `cell.output_size` and `cell.state_size`, and
  as a result the final `state` and `emit_ta` may themselves be tuples.

  A simple implementation of `dynamic_rnn` via `raw_rnn` looks like this:
  ```python
  inputs = tf.placeholder(shape=(max_time, batch_size, input_depth),
                          dtype=tf.float32)
  sequence_length = tf.placeholder(shape=(batch_size,), dtype=tf.int32)
  inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time)
  inputs_ta = inputs_ta.unstack(inputs)

  cell = tf.contrib.rnn.LSTMCell(num_units)

  def loop_fn(time, cell_output, cell_state, loop_state):
    emit_output = cell_output  # == None for time == 0
    if cell_output is None:  # time == 0
      next_cell_state = cell.zero_state(batch_size, tf.float32)
    else:
      next_cell_state = cell_state
    elements_finished = (time >= sequence_length)
    finished = tf.reduce_all(elements_finished)
    next_input = tf.cond(
        finished,
        lambda: tf.zeros([batch_size, input_depth], dtype=tf.float32),
        lambda: inputs_ta.read(time))
    next_loop_state = None
    return (elements_finished, next_input, next_cell_state,
            emit_output, next_loop_state)

  outputs_ta, final_state, _ = raw_rnn(cell, loop_fn)
  outputs = outputs_ta.stack()
  ```
  ...
````

The key here is the `loop_fn()` function: it is precisely this function that handles converting one timestep's output into the next timestep's input.
loop_fn is a function that raw_rnn calls between adjacent timesteps of the RNN. The overall call sequence is:

1. At the initial time, loop_fn is called once to obtain the cell's input for the first timestep; inside loop_fn you read the input for time 0.
2. The cell steps once: (output, cell_state) = cell(next_input, state).
3. When the RNN computation at time t finishes, the cell has produced an output cell_output and a state cell_state, both tensors.
4. Before the computation at time t+1 starts, loop_fn is called as loop_fn(t, cell_output, cell_state, loop_state), and it is expected to return (finished, next_input, next_cell_state, emit_output, loop_state); on the very first call (when cell_output is None) the returned state serves as the initial state.
5. The RNN then uses the next_input returned by loop_fn as its input and next_cell_state as its state to compute the next output.
In other words, after every execution of (output, cell_state) = cell(next_input, state), loop_fn() runs to prepare and post-process the data for the next step.

emit_structure (the emit_output above) is written into emit_ta timestep by timestep.

loop_state carries whatever per-loop bookkeeping state you want to thread from one call of loop_fn to the next.

tf.where is what implements the "dynamic" behavior: minibatch entries that have finished keep their old state and emit zeros, as the toy snippet below illustrates.
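A toy snippet showing the masking tf.where performs here (TF 1.x; the values are made up purely for illustration):

```python
import tensorflow as tf

# Rows whose entry in `finished` is True keep the old state;
# the remaining rows take the newly computed state.
finished = tf.constant([True, False])
state = tf.constant([[1.0], [1.0]])
next_state = tf.constant([[2.0], [2.0]])
kept = tf.where(finished, state, next_state)

with tf.Session() as sess:
    print(sess.run(kept))  # [[1.], [2.]]
```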
### loop_fn()
```python
(elements_finished, next_input, next_cell_state, emit_output, next_loop_state) = loop_fn(
    time, cell_output, cell_state, loop_state)
```

At this point the use of raw_rnn should be clear from the code: the work is mainly to customize the operations inside loop_fn() to your own needs, for example to implement the feed-previous decoding described at the start of this article, as sketched below.
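As a closing example, here is a minimal sketch of such a feed-previous decoder built on raw_rnn. This is not from the original post: every name below (embedding, W_out, b_out, START_ID, END_ID, the sizes, and the greedy argmax decoding) is an assumption chosen for illustration.

```python
import tensorflow as tf

# Assumed sizes and token ids -- placeholders for illustration only.
batch_size, num_units, vocab_size, embed_dim, max_len = 32, 128, 10000, 64, 20
START_ID, END_ID = 1, 2

embedding = tf.get_variable("embedding", [vocab_size, embed_dim])
W_out = tf.get_variable("W_out", [num_units, vocab_size])
b_out = tf.get_variable("b_out", [vocab_size])
cell = tf.nn.rnn_cell.LSTMCell(num_units)
start_tokens = tf.fill([batch_size], START_ID)

def loop_fn(time, cell_output, cell_state, loop_state):
    if cell_output is None:
        # time == 0: feed "<START>" and the zero state.
        next_cell_state = cell.zero_state(batch_size, tf.float32)
        next_input = tf.nn.embedding_lookup(embedding, start_tokens)
        elements_finished = tf.zeros([batch_size], dtype=tf.bool)
    else:
        next_cell_state = cell_state
        # Transform the output of step t: project to the vocabulary
        # and greedily pick the most likely token...
        logits = tf.matmul(cell_output, W_out) + b_out
        next_ids = tf.argmax(logits, axis=1, output_type=tf.int32)
        # ...then embed it as the input of step t+1.
        next_input = tf.nn.embedding_lookup(embedding, next_ids)
        # Stop on "<END>" or once the maximum length is reached.
        elements_finished = tf.logical_or(
            tf.equal(next_ids, END_ID), time >= max_len)
    emit_output = cell_output  # None at time 0, as raw_rnn expects
    next_loop_state = None     # no extra bookkeeping needed here
    return (elements_finished, next_input, next_cell_state,
            emit_output, next_loop_state)

outputs_ta, final_state, _ = tf.nn.raw_rnn(cell, loop_fn)
outputs = outputs_ta.stack()  # [generated_time, batch_size, num_units]
```

The step that ordinary wrappers cannot express is exactly the else-branch: the output of time t is transformed (projected, argmaxed, re-embedded) into the input of time t+1.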