Multivariate Time Series Prediction with CNN+BiLSTM+Attention in Keras
2022-07-24 05:22:00 【一只小EZ】
Dataset
First, a quick introduction to the dataset, which can be downloaded from my GitHub.
It is an air-pollution dataset; we will use this multivariate time series to predict the pollution dimension.
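To get oriented, the snippet below takes a quick look at the raw file (a minimal sketch; ./pollution.csv is the same file loaded in the next section):

```python
import pandas as pd

# Inspect the raw data: column names and overall size
raw = pd.read_csv("./pollution.csv")
print(raw.columns.tolist())
print(raw.shape)
print(raw.head())
```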
Building the Training Data
First, drop the date and wnd_dir columns (note: wnd_dir is dropped here only to keep the demo simple; it could instead be encoded as a numeric sequence).

```python
import numpy as np
import pandas as pd

data = pd.read_csv("./pollution.csv")
data = data.drop(['date', 'wnd_dir'], axis=1)
```
Next, normalize the data. Because of project requirements I wrote my own min-max normalization; sklearn's scalers could be used instead (see the sketch after the block below).
```python
# Column-wise min-max normalization; returns the scaled data and each column's min/max
def NormalizeMult(data):
    data = np.array(data)
    normalize = np.arange(2 * data.shape[1], dtype='float64')
    normalize = normalize.reshape(data.shape[1], 2)
    print(normalize.shape)
    for i in range(0, data.shape[1]):
        # i-th column
        col = data[:, i]
        listlow, listhigh = np.percentile(col, [0, 100])
        # print(i)
        normalize[i, 0] = listlow
        normalize[i, 1] = listhigh
        delta = listhigh - listlow
        if delta != 0:
            # j-th row
            for j in range(0, data.shape[0]):
                data[j, i] = (data[j, i] - listlow) / delta
    # np.save("./normalize.npy", normalize)
    return data, normalize

data, normalize = NormalizeMult(data)
```
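For reference, sklearn's MinMaxScaler produces an equivalent result (a minimal sketch, not the code used in the rest of this post; the fitted scaler plays the role of the normalize array when inverting the transform later):

```python
from sklearn.preprocessing import MinMaxScaler

# Scale every column to [0, 1]; keep the fitted scaler to map predictions back later
scaler = MinMaxScaler()
data_scaled = scaler.fit_transform(data)
```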
For time-series forecasting we need a sliding-window scheme to turn the raw series into supervised samples that the network can consume. The details of this dataset construction are summarized in "Multivariate, multi-step time series prediction with LSTM".
```python
def create_dataset(dataset, look_back):
    '''Build sliding-window samples: each sample is look_back consecutive steps,
    and the target is the value at the following step.'''
    dataX, dataY = [], []
    for i in range(len(dataset) - look_back - 1):
        a = dataset[i:(i + look_back), :]
        dataX.append(a)
        dataY.append(dataset[i + look_back, :])
    TrainX = np.array(dataX)
    Train_Y = np.array(dataY)
    return TrainX, Train_Y

pollution_data = data[:, 0].reshape(len(data), 1)

INPUT_DIMS = 7
TIME_STEPS = 20
lstm_units = 64

train_X, _ = create_dataset(data, TIME_STEPS)
_, train_Y = create_dataset(pollution_data, TIME_STEPS)
```
The network takes input of shape [samples, timesteps, input_dims].
Since the multivariate data is used to predict the pollution dimension, TIME_STEPS = 20 means that a single sample uses the previous 20 steps of the 7 input dimensions to predict the next step of the pollution dimension.
| Data | shape |
|---|---|
| train_X | (43699, 20, 7) |
| train_Y | (43699, 1) |
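A quick sanity check reproduces the shapes in the table above (the exact sample count depends on the length of your pollution.csv):

```python
# Verify the supervised-learning shapes before building the model
print(train_X.shape)  # e.g. (43699, 20, 7)
print(train_Y.shape)  # e.g. (43699, 1)
```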
Building the CNN+BiLSTM+Attention Prediction Network
The overall network structure follows:
CoupletAI: an automatic couplet-generation system based on CNN + Bi-LSTM + Attention
Predicting the gold futures closing price with a CNN+LSTM+Attention deep learning model in Keras
For the implementation of the attention mechanism, see my post "Attention-based LSTM time series prediction with Keras".
Here attention is applied along the input-dimension axis, i.e. dimensions of different importance receive different weights.
TensorFlow version: 1.9.0
Keras version: 2.0.2
```python
from keras.layers import (Input, Dense, Dropout, Conv1D, LSTM, Bidirectional,
                          Flatten, Lambda, Permute, RepeatVector, Multiply)
from keras.models import Model
from keras import backend as K

SINGLE_ATTENTION_VECTOR = False

def attention_3d_block(inputs):
    # inputs.shape = (batch_size, time_steps, input_dim)
    input_dim = int(inputs.shape[2])
    a = inputs
    # Unlike the original time-axis attention, no Permute/Reshape is applied here:
    # the softmax weights are computed along the feature axis.
    a = Dense(input_dim, activation='softmax')(a)
    if SINGLE_ATTENTION_VECTOR:
        # Share a single attention vector across all time steps
        a = Lambda(lambda x: K.mean(x, axis=1), name='dim_reduction')(a)
        a = RepeatVector(TIME_STEPS)(a)  # repeat over time steps so the shape matches the input
    a_probs = Permute((1, 2), name='attention_vec')(a)  # identity permutation, kept only to name the layer
    # Element-wise weighting of the features; Multiply() replaces the legacy merge(..., mode='mul')
    output_attention_mul = Multiply(name='attention_mul')([inputs, a_probs])
    return output_attention_mul
```
```python
def attention_model():
    inputs = Input(shape=(TIME_STEPS, INPUT_DIMS))

    # 1D convolution to extract local features from the input sequence
    x = Conv1D(filters=64, kernel_size=1, activation='relu')(inputs)  # , padding='same'
    x = Dropout(0.3)(x)

    # lstm_out = Bidirectional(LSTM(lstm_units, activation='relu'), name='bilstm')(x)
    lstm_out = Bidirectional(LSTM(lstm_units, return_sequences=True))(x)

    # Attention over the feature dimension, then flatten for the regression head
    attention_mul = attention_3d_block(lstm_out)
    attention_mul = Flatten()(attention_mul)

    # A sigmoid output works here because the target has been scaled to [0, 1]
    output = Dense(1, activation='sigmoid')(attention_mul)
    model = Model(inputs=[inputs], outputs=output)
    return model
```
Training

```python
m = attention_model()
m.summary()
m.compile(optimizer='adam', loss='mse')
m.fit([train_X], train_Y, epochs=10, batch_size=64, validation_split=0.1)
```
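Predictions come out in the normalized [0, 1] range. Below is a minimal sketch of mapping them back to the original pollution scale, using the normalize array returned by NormalizeMult above (row 0 holds the pollution column's min and max):

```python
# Predict on the training windows and undo the min-max scaling for the pollution column
pred_scaled = m.predict(train_X)
low, high = normalize[0, 0], normalize[0, 1]
pred_pollution = pred_scaled * (high - low) + low
print(pred_pollution[:5])
```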
The network structure is shown in the figure.
Note: the network parameters above can be scaled down as appropriate.
The code has been uploaded to my GitHub.
References:
CoupletAI: an automatic couplet-generation system based on CNN + Bi-LSTM + Attention
Predicting the gold futures closing price with a CNN+LSTM+Attention deep learning model in Keras
Attention-based LSTM time series prediction with Keras
Multivariate, multi-step time series prediction with LSTM