Understanding the LSTM pipeline (inputs and outputs)
2022-08-04 05:29:00 【TigerZ*】
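step1, raw text:
接触LSTM模型不久,简单看了一些相关的论文,还没有动手实现过。然而至今仍然想不通LSTM神经网络究竟是怎么工作的。……
(Translation: "I only recently started working with LSTM models. I have skimmed a few related papers but have not implemented one yet. To this day I still cannot work out how an LSTM network actually operates. ...")

step2, tokenize (Chinese text has to be word-segmented first):
sentence1: 接触 LSTM 模型 不久 ,简单 看了 一些 相关的 论文 , 还 没有 动手 实现过 。
sentence2: 然而 至今 仍然 想不通 LSTM 神经网络 究竟是 怎么 工作的。
……

step3, dictionarize (replace each token with its integer id from the vocabulary):
sentence1: 1 34 21 98 10 23 9 23
sentence2: 17 12 21 12 8 10 13 79 31 44 9 23
……

step4, pad every sentence to a fixed length (0 is the padding id):
sentence1: 1 34 21 98 10 23 9 23 0 0 0 0 0
sentence2: 17 12 21 12 8 10 13 79 31 44 9 23 0
……

step5, map each token to an embedding:
sentence1: (matrix figure) each column is one word vector. You choose the word-vector dimension yourself; the number of columns is fixed at the time_step length.
sentence2: ……

step6, feed into the RNN as input:
Suppose the RNN's time_step is fixed at l. Then the padded sentence length (the number of matrix columns in step5) is also fixed at l. One run of the RNN processes exactly one sentence. The embedding of each token in the sentence is the input at the corresponding time step t. A single run processes the whole sentence from the first time step to the last without interruption.

step7, get the output:
As the figure shows, every time_step can emit the hidden state for the current time step t. The overall output of the RNN, however, is taken at the last time_step, t = l; only that one is the complete final result.

step8, further processing of the output:
Depending on whether the task is classification or regression, the output is processed further in different ways. For example, pass it to softmax & cross_entropy for classification ……, or collect the hidden state at every time_step and build a seq2seq network ……, or try something new ……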
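Steps 3 to 5 can be sketched in a few lines of NumPy. This is only an illustrative sketch, not the author's code: the helper names `build_vocab` and `pad_ids`, the shortened example sentences, the maximum length of 8 and the random embedding table are all assumptions. Note that the result is laid out as (batch, time_step, embedding_dim), so each token is a row rather than a column; that is the same layout as the input `X` in the TensorFlow example.

```python
import numpy as np

# Hypothetical output of step 2: two already-segmented sentences (shortened).
sentences = [
    ["接触", "LSTM", "模型", "不久"],
    ["然而", "至今", "仍然", "想不通", "LSTM", "神经网络"],
]

def build_vocab(tokenized):
    """Step 3: give every distinct token an integer id; id 0 is reserved for padding."""
    vocab = {}
    for sent in tokenized:
        for tok in sent:
            if tok not in vocab:
                vocab[tok] = len(vocab) + 1
    return vocab

def pad_ids(tokenized, vocab, max_len):
    """Step 4: map tokens to ids, then right-pad with 0 (or truncate) to max_len."""
    batch = np.zeros((len(tokenized), max_len), dtype=np.int64)
    lengths = []
    for i, sent in enumerate(tokenized):
        ids = [vocab[t] for t in sent][:max_len]
        batch[i, :len(ids)] = ids
        lengths.append(len(ids))
    return batch, lengths

vocab = build_vocab(sentences)
ids, lengths = pad_ids(sentences, vocab, max_len=8)

# Step 5: embedding lookup. The table is random here purely for illustration;
# in a real model it is learned or loaded from pretrained vectors.
embed_dim = 4
rng = np.random.default_rng(0)
table = rng.normal(size=(len(vocab) + 1, embed_dim))
table[0] = 0.0                      # the padding id maps to the zero vector
embedded = table[ids]               # shape: (batch, time_step, embed_dim)

print(ids.shape, lengths, embedded.shape)   # (2, 8) [4, 6] (2, 8, 4)
```

The `lengths` list is worth keeping alongside the padded batch: it is what an argument such as `sequence_length` needs in order to ignore the padded positions.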
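# Requires TensorFlow 1.x (tf.contrib and tf.Session were removed in 2.x).
import tensorflow as tf
import numpy as np


def dynamic_rnn(rnn_type='lstm'):
    # Input batch: 3 sentences, 6 time steps, embedding dimension 4.
    X = np.random.rand(3, 6, 4)
    # The second sentence has only 4 real tokens; the rest is zero padding.
    X[1, 4:] = 0
    # True (unpadded) length of each sentence.
    X_length = [6, 4, 6]
    rnn_hidden_size = 5

    if rnn_type == 'lstm':
        cell = tf.contrib.rnn.BasicLSTMCell(num_units=rnn_hidden_size, state_is_tuple=True)
    else:
        cell = tf.contrib.rnn.GRUCell(num_units=rnn_hidden_size)
    num = cell.output_size

    # outputs: hidden state at every time step, shape (batch, time_step, hidden).
    # last_states: state at the last valid step of each sentence
    # (an LSTMStateTuple (c, h) for LSTM, a single tensor for GRU).
    outputs, last_states = tf.nn.dynamic_rnn(
        cell=cell,
        dtype=tf.float64,
        sequence_length=X_length,
        inputs=X
    )

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        o1, s1 = session.run([outputs, last_states])
        print(X)
        print(np.shape(o1))
        print(o1)
        print(np.shape(s1))
        print(s1)
        print(num)


if __name__ == '__main__':
    dynamic_rnn(rnn_type='lstm')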
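To make steps 6 and 7 concrete without TensorFlow, here is a rough NumPy sketch of what happens inside the time loop when sequence lengths are supplied. It is an assumption-laden illustration, not the library's implementation: a plain tanh RNN cell is swapped in for the LSTM cell to keep it short, the weights are random, and the function name `simple_rnn` is hypothetical. It mimics the behaviour of `tf.nn.dynamic_rnn` with `sequence_length`: outputs past a sentence's true length are zero, and the returned state is the one from that sentence's last valid time step.

```python
import numpy as np

def simple_rnn(X, lengths, hidden_size, seed=0):
    """Run a plain tanh RNN over X of shape (batch, time_step, dim)."""
    rng = np.random.default_rng(seed)
    batch, steps, dim = X.shape
    W_x = rng.normal(scale=0.1, size=(dim, hidden_size))          # input weights
    W_h = rng.normal(scale=0.1, size=(hidden_size, hidden_size))  # recurrent weights
    b = np.zeros(hidden_size)

    h = np.zeros((batch, hidden_size))            # initial hidden state
    outputs = np.zeros((batch, steps, hidden_size))
    lengths = np.asarray(lengths)

    for t in range(steps):
        # Hidden state for time step t from the input at t and the previous state.
        h_new = np.tanh(X[:, t] @ W_x + h @ W_h + b)
        # Sentences that have already ended keep their old state.
        active = (t < lengths)[:, None]
        h = np.where(active, h_new, h)
        # Per-step output is zero once a sentence is past its true length.
        outputs[:, t] = np.where(active, h_new, 0.0)

    return outputs, h

# Same shapes as the TensorFlow example: 3 sentences, 6 steps, dimension 4.
X = np.random.default_rng(1).random((3, 6, 4))
X[1, 4:] = 0
outputs, last_state = simple_rnn(X, [6, 4, 6], hidden_size=5)
print(outputs.shape, last_state.shape)   # (3, 6, 5) (3, 5)
```

For the second sentence, `last_state[1]` equals `outputs[1, 3]` (the fourth step), not `outputs[1, 5]`. With padded batches, the "final result" of step7 is therefore the state at the last real token rather than at t = l.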