Flatten in CNN-LSTM
2022-06-28 20:03:00 【seven_不是赛文】
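How should a CNN be connected to an LSTM?
I have seen it said before that both options work: you can flatten the CNN output, or leave it unflattened.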
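Approach 1: flatten
Suppose the maximum length of your raw sequences is 100 and the embedding dimension is 16, so the input is [N, 100, 16]. If the CNN uses 'same' convolutions with 64 output channels, its output is [N, 100, 64]. You can then flatten this to [N, 6400]. Because an LSTM expects 3D input [batch, timesteps, features], the flattened 2D tensor cannot go into the LSTM as-is; the code below first turns it back into a sequence with RepeatVector.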
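The shape arithmetic in the paragraph above can be checked with a few lines of plain Python. This is a minimal sketch: `conv1d_out_len` is a hypothetical helper that follows Keras's Conv1D length rules, and the batch size N = 32, kernel size 3 and n_outputs = 7 are assumed values, not taken from the original.

```python
import math

def conv1d_out_len(length, kernel_size, padding="valid", stride=1):
    """Output length of a 1D convolution (hypothetical helper, Keras Conv1D conventions)."""
    if padding == "same":
        # 'same' padding keeps the length (for stride 1)
        return math.ceil(length / stride)
    # 'valid': only positions where the kernel fits entirely inside the input
    return (length - kernel_size) // stride + 1

# Sizes from the text: length 100, embedding dim 16, 64 filters.
# N = 32, kernel_size = 3 and n_outputs = 7 are assumptions for illustration.
N, seq_len, emb_dim, filters = 32, 100, 16, 64
n_outputs = 7

embedded = (N, seq_len, emb_dim)                                      # [N, 100, 16]
conv_out = (N, conv1d_out_len(seq_len, 3, padding="same"), filters)  # [N, 100, 64]
flat = (N, conv_out[1] * conv_out[2])                                 # [N, 6400]
# An LSTM needs [batch, timesteps, features], so the 2D flat tensor is
# turned back into a sequence, e.g. by RepeatVector(n_outputs):
lstm_in = (N, n_outputs, flat[1])                                     # [N, 7, 6400]

for name, shape in [("embedded", embedded), ("conv", conv_out),
                    ("flatten", flat), ("lstm input", lstm_in)]:
    print(f"{name:>10}: {shape}")
```

With 'valid' padding (the Conv1D default, used in the listing below) the length shrinks instead: `conv1d_out_len(100, 3)` gives 98.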
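from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Conv1D, MaxPooling1D, Flatten, RepeatVector,
                                     LSTM, TimeDistributed, Dense)

# Example sizes (assumed; set these to match your data)
n_timesteps, n_features, n_outputs = 100, 16, 7

model = Sequential()
model.add(Conv1D(filters=64, kernel_size=3, activation='relu',
                 input_shape=(n_timesteps, n_features)))
model.add(Conv1D(filters=64, kernel_size=3, activation='relu'))
model.add(MaxPooling1D(pool_size=2))
model.add(Flatten())
# RepeatVector(n) takes a 2D input and adds a time axis of length n:
# an input of shape (None, 32) becomes (None, 3, 32) after RepeatVector(3).
# The per-step feature dimension (32) is unchanged; the same vector
# is simply copied to each of the n time steps.
model.add(RepeatVector(n_outputs))
model.add(LSTM(200, activation='relu', return_sequences=True))
# TimeDistributed(Dense) applies the same Dense layer at every time step;
# it is mainly used in one-to-many and many-to-many setups.
model.add(TimeDistributed(Dense(100, activation='relu')))
# E.g. with input_shape=(10, 16) (10 steps, 16 features per step),
# TimeDistributed(Dense(8)) maps each step's 16 features to 8
# without changing the number of steps: (50, 10, 16) -> (50, 10, 8).
model.add(TimeDistributed(Dense(1)))
# 'accuracy' is not meaningful for an MSE regression loss; track MAE instead
model.compile(loss='mse', optimizer='adam', metrics=['mae'])
model.summary()  # summary() prints the table itself and returns None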
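What the code comments say about RepeatVector and TimeDistributed(Dense) can be reproduced with plain NumPy. This is a sketch of the equivalent tensor operations, not Keras code: the random inputs, the batch size 4 and the weight matrix `W` are assumptions chosen only to match the shapes in the comments.

```python
import numpy as np

rng = np.random.default_rng(0)

# RepeatVector(3): (batch, 32) -> (batch, 3, 32); the same vector is copied to every step
x = rng.standard_normal((4, 32))
repeated = np.repeat(x[:, np.newaxis, :], 3, axis=1)

# TimeDistributed(Dense(8)) on (50, 10, 16): one shared weight matrix applied at every step
seq = rng.standard_normal((50, 10, 16))
W = rng.standard_normal((16, 8))   # hypothetical Dense kernel
b = np.zeros(8)                    # Dense bias
td_dense = seq @ W + b             # matmul broadcasts over the batch and time axes

# Same result as looping over the time steps with the same weights
looped = np.stack([seq[:, t, :] @ W + b for t in range(seq.shape[1])], axis=1)

print(repeated.shape, td_dense.shape)
```

Because the weights are shared across steps, the loop and the broadcast matmul agree exactly. In Keras, a plain `Dense` layer applied to a 3D input likewise acts on the last axis.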

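Approach 2: no flattening
Alternatively, you can feed [N, 100, 64] directly into the LSTM, since it is already a 3D [batch, timesteps, features] tensor: the LSTM then reads 100 steps with 64 features each.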
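from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Conv1D, MaxPooling1D, RepeatVector,
                                     LSTM, TimeDistributed, Dense)

# Example sizes (assumed; set these to match your data)
n_timesteps, n_features, n_outputs = 100, 16, 7

model = Sequential()
model.add(Conv1D(filters=64, kernel_size=3, activation='relu',
                 input_shape=(n_timesteps, n_features)))
model.add(Conv1D(filters=64, kernel_size=6, activation='relu'))
model.add(MaxPooling1D(pool_size=1))  # pool_size=1 leaves the sequence length unchanged
# No Flatten: this LSTM reads the 3D conv output [N, steps, 64] directly and
# returns [N, 200]. RepeatVector only accepts 2D input, so it cannot be
# applied to the 3D conv output itself.
model.add(LSTM(200, activation='relu'))
model.add(RepeatVector(n_outputs))
model.add(LSTM(200, activation='relu', return_sequences=True))
model.add(TimeDistributed(Dense(100, activation='relu')))
model.add(TimeDistributed(Dense(1)))
model.compile(loss='mse', optimizer='adam', metrics=['mae'])
model.summary()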
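To show what "feeding [N, 100, 64] directly into the LSTM" means mechanically, here is a minimal NumPy sketch of an LSTM forward pass. It is an illustration, not Keras's implementation: `lstm_forward` is a hypothetical helper, the gates are stacked in the order input, forget, candidate, output, it uses tanh (the Keras default) while the models above set `activation='relu'`, and the hidden size 8 and random weights are assumptions.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_forward(x, W, U, b, return_sequences=False):
    """Minimal LSTM forward pass (hypothetical helper).

    x: (N, T, F) input; W: (F, 4H) input kernel;
    U: (H, 4H) recurrent kernel; b: (4H,) bias.
    """
    n, steps, _ = x.shape
    hidden = U.shape[0]
    h = np.zeros((n, hidden))
    c = np.zeros((n, hidden))
    outputs = []
    for t in range(steps):                 # one time step (F features) at a time
        z = x[:, t, :] @ W + h @ U + b     # (N, 4H): all four gates at once
        i = sigmoid(z[:, :hidden])                    # input gate
        f = sigmoid(z[:, hidden:2 * hidden])          # forget gate
        g = np.tanh(z[:, 2 * hidden:3 * hidden])      # candidate cell state
        o = sigmoid(z[:, 3 * hidden:])                # output gate
        c = f * c + i * g                  # update cell state
        h = o * np.tanh(c)                 # new hidden state
        outputs.append(h)
    # return_sequences=True -> (N, T, H); False -> last hidden state (N, H)
    return np.stack(outputs, axis=1) if return_sequences else h

rng = np.random.default_rng(0)
N, T, F, H = 2, 100, 64, 8                 # [N, 100, 64] input, 8 hidden units (assumed)
x = rng.standard_normal((N, T, F))
W = rng.standard_normal((F, 4 * H)) * 0.1
U = rng.standard_normal((H, 4 * H)) * 0.1
b = np.zeros(4 * H)

last_h = lstm_forward(x, W, U, b)
all_h = lstm_forward(x, W, U, b, return_sequences=True)
print(last_h.shape, all_h.shape)
```

The number of weights, `W.size + U.size + b.size = 4H(F + H + 1)`, does not depend on the sequence length T, only on the per-step feature size F.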

Results
The results show that the two approaches do differ.
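As one concrete, computable difference (an illustration, not the results referred to above), this sketch computes what the first LSTM sees in each approach and how many parameters it has, assuming n_timesteps = 100 and the layer settings of the two listings. The helpers `conv1d_valid_len`, `maxpool1d_len` and `lstm_params` are hypothetical; `lstm_params` uses the standard count of input kernel, recurrent kernel and bias for four gates.

```python
def conv1d_valid_len(length, kernel_size):
    # Conv1D defaults: padding='valid', strides=1
    return length - kernel_size + 1

def maxpool1d_len(length, pool_size):
    # MaxPooling1D defaults: padding='valid', strides=pool_size
    return (length - pool_size) // pool_size + 1

def lstm_params(input_dim, units):
    # input kernel + recurrent kernel + bias, each split into 4 gates
    return 4 * ((input_dim + units) * units + units)

n_timesteps, filters, units = 100, 64, 200   # assumed sizes (100 steps as in the text)

# Approach 1: conv(3) -> conv(3) -> pool(2) -> Flatten -> RepeatVector -> LSTM
steps1 = maxpool1d_len(conv1d_valid_len(conv1d_valid_len(n_timesteps, 3), 3), 2)
lstm_in_1 = steps1 * filters   # every step of every filter becomes one feature

# Approach 2: conv(3) -> conv(6) -> pool(1) -> LSTM reads the 3D tensor directly
steps2 = maxpool1d_len(conv1d_valid_len(conv1d_valid_len(n_timesteps, 3), 6), 1)
lstm_in_2 = filters            # each step keeps its 64 channels as features

print(f"approach 1: LSTM input dim {lstm_in_1}, params {lstm_params(lstm_in_1, units):,}")
print(f"approach 2: {steps2} steps x {lstm_in_2} features, params {lstm_params(lstm_in_2, units):,}")
```

Under these assumptions the flattened version hands the LSTM 3072 features per step (about 2.6M weights in that layer), while the unflattened version gives it 93 real time steps of 64 features (212,000 weights), so the two models differ in both capacity and in what the LSTM treats as "time".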