Understanding the LSTM pipeline (inputs and outputs)
2022-08-04 05:29:00 【TigerZ*】




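step1, raw text:
接触LSTM模型不久,简单看了一些相关的论文,还没有动手实现过。然而至今仍然想不通LSTM神经网络究竟是怎么工作的。……
(English: "I have only recently come across LSTM models. I have skimmed a few related papers but have not implemented one yet. Still, I cannot figure out how an LSTM network actually works. ……")

step2, tokenize (Chinese text has to be segmented into words first):
sentence1: 接触 LSTM 模型 不久 , 简单 看了 一些 相关的 论文 , 还 没有 动手 实现过 。
sentence2: 然而 至今 仍然 想不通 LSTM 神经网络 究竟是 怎么 工作的 。
……

step3, dictionarize (replace each token with its integer id in the vocabulary):
sentence1: 1 34 21 98 10 23 9 23
sentence2: 17 12 21 12 8 10 13 79 31 44 9 23
……

step4, pad every sentence to a fixed length:
sentence1: 1 34 21 98 10 23 9 23 0 0 0 0 0
sentence2: 17 12 21 12 8 10 13 79 31 44 9 23 0
……

step5, map each token to an embedding:
sentence1: [figure: embedding matrix of sentence1] Each column is one word vector. The dimension of the word vectors is up to you; the number of columns is fixed to the time_step length.
sentence2: ……

step6, feed into the RNN as input:
Suppose the RNN's time_step is fixed at l. The padded sentence length (the number of columns of the matrix in step5) is then fixed at l as well. One run of the RNN processes exactly one sentence. The embedding of each token in the sentence is the input at the corresponding time step t, and a single run works through the whole sentence step by step.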
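Steps 3 to 5 can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions: the two short token lists, `max_len`, `embedding_dim` and the randomly initialised `embedding_table` are made up for the example, and a real pipeline would learn the embeddings during training. Unlike the matrix described in step5, the array here stores one word vector per row, in the (batch, time_step, embedding) layout that the TensorFlow example later in the post uses.

```python
import numpy as np

# Hypothetical output of step 2: sentences already segmented into tokens.
sentences = [
    ["接触", "LSTM", "模型", "不久"],
    ["然而", "至今", "仍然", "想不通", "LSTM", "神经网络"],
]

PAD_ID = 0  # id 0 is reserved for padding, as in step 4

# Step 3: build a vocabulary that maps each distinct token to an id (1, 2, ...).
vocab = {}
for sentence in sentences:
    for token in sentence:
        if token not in vocab:
            vocab[token] = len(vocab) + 1

ids = [[vocab[token] for token in sentence] for sentence in sentences]

# Step 4: pad (or truncate) every sentence to the same fixed length.
max_len = 8  # assumed time_step length l
padded = np.full((len(ids), max_len), PAD_ID, dtype=np.int64)
lengths = []
for row, seq in enumerate(ids):
    seq = seq[:max_len]
    padded[row, :len(seq)] = seq
    lengths.append(len(seq))

# Step 5: look up one embedding vector per token id.
embedding_dim = 4  # assumed word-vector dimension
rng = np.random.default_rng(0)
embedding_table = rng.normal(size=(len(vocab) + 1, embedding_dim))
embedding_table[PAD_ID] = 0.0  # padding positions map to the zero vector
embedded = embedding_table[padded]  # shape: (batch, max_len, embedding_dim)

print(padded)
print(lengths)
print(embedded.shape)
```

The `lengths` list records how many real tokens each sentence has, which is the kind of information a `sequence_length` argument needs later on.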
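step7, get output:
[figure] Every time_step can output the hidden state at its time t, but the output of the RNN as a whole is taken at the last time_step, t = l. Only that one is the complete final result.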
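To make steps 6 and 7 concrete, here is a minimal NumPy sketch of a single LSTM layer run over one sentence. It uses the standard LSTM gate equations; the weight names `W`, `U`, `b`, the gate ordering inside the weight matrices and the random initialisation are assumptions made for this example and do not reproduce any particular library's internals.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_forward(x, W, U, b):
    """Run one sentence through a single LSTM layer.

    x: (time_steps, input_dim) embeddings of one sentence.
    W: (input_dim, 4 * hidden), U: (hidden, 4 * hidden), b: (4 * hidden,).
    Returns the hidden state of every time step and the final (h, c).
    """
    hidden = U.shape[0]
    h = np.zeros(hidden)  # hidden state
    c = np.zeros(hidden)  # cell state
    outputs = []
    for x_t in x:  # one token embedding per time step
        z = x_t @ W + h @ U + b
        i, f, g, o = np.split(z, 4)  # assumed order: input, forget, candidate, output
        i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
        g = np.tanh(g)
        c = f * c + i * g
        h = o * np.tanh(c)
        outputs.append(h)
    return np.stack(outputs), (h, c)


rng = np.random.default_rng(0)
time_steps, input_dim, hidden_size = 6, 4, 5
x = rng.normal(size=(time_steps, input_dim))
W = rng.normal(scale=0.1, size=(input_dim, 4 * hidden_size))
U = rng.normal(scale=0.1, size=(hidden_size, 4 * hidden_size))
b = np.zeros(4 * hidden_size)

outputs, (h_last, c_last) = lstm_forward(x, W, U, b)
print(outputs.shape)                     # one hidden state per time step
print(np.allclose(outputs[-1], h_last))  # the final output is the last hidden state
```

`outputs` holds the per-step hidden states, while `h_last` is the state at t = l that step7 treats as the result for the whole sentence.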
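step8, further processing with the output:
Depending on whether the task is classification or regression, the output is processed further in different ways. For example, pass it to softmax and cross_entropy for classification …… or take the hidden state of every time_step to build a seq2seq network …… or try something new ……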
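As one possible reading of the classification case in step8, the sketch below puts a linear layer, softmax and cross-entropy on top of the final hidden state. The names `h_last`, `W_out`, `b_out` and `labels` are hypothetical, and random values stand in for a trained RNN's output.

```python
import numpy as np


def softmax(logits):
    # Subtract the row maximum for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def cross_entropy(probs, labels):
    # Mean negative log-probability assigned to the true class.
    picked = probs[np.arange(len(labels)), labels]
    return float(-np.log(picked).mean())


rng = np.random.default_rng(0)
batch, hidden_size, num_classes = 3, 5, 2

h_last = rng.normal(size=(batch, hidden_size))  # stand-in for the RNN's final hidden states
W_out = rng.normal(scale=0.1, size=(hidden_size, num_classes))
b_out = np.zeros(num_classes)
labels = np.array([0, 1, 1])  # made-up class labels

logits = h_last @ W_out + b_out
probs = softmax(logits)
loss = cross_entropy(probs, labels)

print(probs)
print(loss)
```

For a seq2seq-style model, the same idea would be applied to the hidden state of every time step rather than only the last one.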
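import tensorflow as tf  # TensorFlow 1.x API (tf.contrib, tf.Session)
import numpy as np


def dynamic_rnn(rnn_type='lstm'):
    # Batch of 3 sequences, 6 time steps each, 4 features per step.
    X = np.random.rand(3, 6, 4)
    # The second sequence has only 4 real steps; zero out its padding.
    X[1, 4:] = 0
    X_length = [6, 4, 6]

    rnn_hidden_size = 5
    if rnn_type == 'lstm':
        cell = tf.contrib.rnn.BasicLSTMCell(num_units=rnn_hidden_size, state_is_tuple=True)
    else:
        cell = tf.contrib.rnn.GRUCell(num_units=rnn_hidden_size)
    num = cell.output_size

    # outputs: hidden state at every time step.
    # last_states: final state of each sequence.
    outputs, last_states = tf.nn.dynamic_rnn(
        cell=cell,
        dtype=tf.float64,
        sequence_length=X_length,
        inputs=X
    )

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        o1, s1 = session.run([outputs, last_states])
        print(X)
        print(np.shape(o1))  # (3, 6, 5)
        print(o1)
        print(np.shape(s1))  # LSTM: (2, 3, 5) for the (c, h) pair; GRU: (3, 5)
        print(s1)
        print(num)  # 5


if __name__ == '__main__':
    dynamic_rnn(rnn_type='lstm')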
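In the example above, `sequence_length=X_length` tells `tf.nn.dynamic_rnn` how many steps of each sequence are real: outputs past that length are zeroed, and the state is copied through unchanged. The NumPy sketch below imitates that behaviour. To keep it short it swaps in a plain tanh RNN cell instead of an LSTM, and the function name `masked_rnn` and the weights `W` and `U` are hypothetical.

```python
import numpy as np


def masked_rnn(X, lengths, W, U):
    """Simple tanh RNN over a padded batch, honouring per-sequence lengths.

    X: (batch, time_steps, input_dim), lengths: real length of each sequence.
    Returns outputs (batch, time_steps, hidden) and the last valid state.
    """
    batch, time_steps, _ = X.shape
    hidden = U.shape[0]
    lengths = np.asarray(lengths)
    h = np.zeros((batch, hidden))
    outputs = np.zeros((batch, time_steps, hidden))
    for t in range(time_steps):
        h_new = np.tanh(X[:, t] @ W + h @ U)
        active = (t < lengths)[:, None]               # True while a sequence is still running
        h = np.where(active, h_new, h)                # finished sequences keep their state
        outputs[:, t] = np.where(active, h_new, 0.0)  # and emit zeros afterwards
    return outputs, h


rng = np.random.default_rng(0)
X = rng.random((3, 6, 4))
X[1, 4:] = 0          # second sequence is padded after 4 steps
lengths = [6, 4, 6]
W = rng.normal(scale=0.5, size=(4, 5))
U = rng.normal(scale=0.5, size=(5, 5))

outputs, last_state = masked_rnn(X, lengths, W, U)
print(outputs.shape)
print(outputs[1, 4:])                                # zeros for the padded steps
print(np.allclose(last_state[1], outputs[1, 3]))     # state frozen at the last real step
```

This is why the second row of `o1` in the TensorFlow run ends in two all-zero steps, while its entry in `s1` matches the hidden state at step 4.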