当前位置:网站首页>CNN-LSTM的flatten
CNN-LSTM的flatten
2022-06-28 20:03:00 【seven_不是赛文】
CNN和LSTM之间该如何连接?
之前有看到过说,展平也行不展平也行
第一种做法,展平
假设你最原始的序列的最大长度为100,词嵌入之后,维度为16,那就是[N, 100,16]。假设你CNN相同卷积并输出64通道,那CNN之后的输出就是[N, 100, 64]。然后你可以把它flatten成[N, 6400]输入LSTM。
model = Sequential()
model.add(Conv1D(filters=64, kernel_size=3, activation='relu',
input_shape=(n_timesteps, n_features)))
model.add(Conv1D(filters=64, kernel_size=3, activation='relu'))
model.add(MaxPooling1D(pool_size=2))
model.add(Flatten())
#如果输入的形状为(None,32),
# 经过添加RepeatVector(3)层之后,
# 输出变为(None,3,32),RepeatVector不改变我们的步长,
# 改变我们的每一步的维数(即:属性长度)
model.add(RepeatVector(n_outputs))
model.add(LSTM(200, activation='relu', return_sequences=True))
# TimeDistributed和Dense一起配合使用,
# 主要应用于一对多,多对多的情况。
model.add(TimeDistributed(Dense(100, activation='relu')))
# input_shape = (10, 16),表示步长是10,
# 每一步的维度为16,(即:每一个数据的属性长度为16))
# 首先使用TimeDistributed(Dense(8),input_shape = (10, 16))
# 把每一步的维度为16变成8,不改变步长的大小
# eg:输入形状然后(50, 10, 16),则这一层之后的输出为(50, 10, 8)
model.add(TimeDistributed(Dense(1)))
model.compile(loss='mse', optimizer='adam', metrics=['accuracy'])
print(model.summary())

第二种方法,不展平
比如,也可以直接把[N, 100, 64]输入LSTM。
model = Sequential()
model.add(Conv1D(filters=64, kernel_size=3, activation='relu',
input_shape=(n_timesteps, n_features)))
model.add(Conv1D(filters=64, kernel_size=6, activation='relu'))
model.add(MaxPooling1D(pool_size=1))
model.add(RepeatVector(n_outputs))
model.add(LSTM(200, activation='relu', return_sequences=True))
model.add(TimeDistributed(Dense(100, activation='relu')))
model.add(TimeDistributed(Dense(1)))
model.compile(loss='mse', optimizer='adam', metrics=['accuracy'])
print(model.summary())

结果显示
还是有区别的:
边栏推荐
- Markdown mermaid种草(1)_ mermaid简介
- 修复一次flutter 无法选中模拟器
- Markdown Mermaid Grass (1) Introduction à Mermaid
- Relevant calculation of sphere, etc
- Echart: category text position adjustment of horizontal histogram
- 市值1200亿美金,老牌财税巨头Intuit是如何做到的?
- redisTemplate
- 2022年T电梯修理考试题库模拟考试平台操作
- Racher add / delete node
- 【毕业季·进击的技术er】努力只能及格,拼命才能优秀!
猜你喜欢

Windows 64 bit download install my SQL

CSDN salary increase technology selenium automated test stack summary

5G NR MBS架构介绍

Racher add / delete node

rsync远程同步

SQL server2019 create a new SQL server authentication user name and log in

ThreadLocal原理

Markdown Mermaid planting grass (1)_ Introduction to Mermaid

ArrayList of collection

2022年P气瓶充装考试练习题及在线模拟考试
随机推荐
算力时代怎么「算」?「算网融合」先发优势很重要!
核芯物联蓝牙aoa定位系统服务器配置估算
Kaggle gastrointestinal image segmentation competition baseline
【算法篇】刷了两道大厂面试题,含泪 ”重学数组“
2022 P cylinder filling test exercises and online simulation test
Day88.七牛云: 房源图片、用户头像上传
C # application interface development foundation - form control
2022 welder (elementary) special operation certificate examination question bank and answers
bluecmsv1.6代码审计
Use of WC command
Jenkins pipeline's handling of job parameters
Grep text search tool
Kettle (VI): full database backup based on kettle
3. integrate listener
管道 | 与重定向 >
Rsync remote synchronization
522. 最长特殊序列 II(贪心&双指针)
grep文本搜索工具
Real number operation
rsync远程同步