Building a CNN-LSTM Hybrid Model in PyTorch for Multivariate Multi-Step Time Series Forecasting (Load Forecasting)
2022-06-26 06:43:00 【Cyril_KI】
I. Preface
For the underlying principles of LSTM, you can refer to 人工智能教程 (an AI tutorial site). Besides LSTM, that site covers detailed explanations of most other machine learning and deep learning models, with vivid illustrations that make them easy to follow.
I have already written quite a few articles on time series forecasting:
- Understanding LSTM Inputs and Outputs in PyTorch (from the input to the Linear output)
- Building an LSTM in PyTorch for Time Series Forecasting (Load Forecasting)
- Building an LSTM in PyTorch for Multivariate Time Series Forecasting (Load Forecasting)
- Building a Bidirectional LSTM in PyTorch for Time Series Forecasting (Load Forecasting)
- Building an LSTM in PyTorch for Multivariate Multi-Step Time Series Forecasting (1): Direct Multi-Output
- Building an LSTM in PyTorch for Multivariate Multi-Step Time Series Forecasting (2): Single-Step Rolling Forecast
- Building an LSTM in PyTorch for Multivariate Multi-Step Time Series Forecasting (3): Multiple Models, Single-Step Forecast
- Building an LSTM in PyTorch for Multivariate Multi-Step Time Series Forecasting (4): Multiple Models, Rolling Forecast
- Building an LSTM in PyTorch for Multivariate Multi-Step Time Series Forecasting (5): Seq2Seq
- A Summary of Multi-Step Time Series Forecasting Approaches with LSTM in PyTorch (Load Forecasting)
- How to Forecast Truly Unseen Future Values with PyTorch-LSTM Time Series Models
- Building an LSTM in PyTorch for Multivariate-Input, Multivariate-Output Time Series Forecasting (Multi-Task Learning)
- Building an ANN in PyTorch for Time Series Forecasting (Wind Speed Forecasting)
- Building a CNN in PyTorch for Time Series Forecasting (Wind Speed Forecasting)
- Building a CNN-LSTM Hybrid Model in PyTorch for Multivariate Multi-Step Time Series Forecasting (Load Forecasting)
The articles above use three kinds of models for time series forecasting: LSTM, ANN, and CNN. CNNs are well known for their strong feature-extraction ability, so many recent papers combine a CNN with an LSTM for time series forecasting. In this article we use PyTorch to build a simple CNN-LSTM hybrid model for load forecasting.
II. CNN-LSTM
The CNN-LSTM model is defined as follows:
from torch import nn


class CNN_LSTM(nn.Module):
    def __init__(self, args):
        super(CNN_LSTM, self).__init__()
        self.args = args
        self.relu = nn.ReLU(inplace=True)
        # input: (batch_size=30, seq_len=24, input_size=7) ---> permute(0, 2, 1)
        # (30, 7, 24)
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels=args.in_channels, out_channels=args.out_channels, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=1)
        )
        # output: (batch_size=30, out_channels=32, seq_len-4=20) ---> permute(0, 2, 1)
        # (30, 20, 32)
        self.lstm = nn.LSTM(input_size=args.out_channels, hidden_size=args.hidden_size,
                            num_layers=args.num_layers, batch_first=True)
        self.fc = nn.Linear(args.hidden_size, args.output_size)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (batch, seq_len, features) -> (batch, channels, seq_len)
        x = self.conv(x)        # Conv1d + ReLU + MaxPool1d along the time axis
        x = x.permute(0, 2, 1)  # back to (batch, seq_len, channels) for the LSTM
        x, _ = self.lstm(x)     # per-step hidden states: (batch, seq_len', hidden_size)
        x = self.fc(x)          # project every step to output_size
        x = x[:, -1, :]         # keep only the last time step as the multi-step forecast
        return x
As you can see, this CNN-LSTM consists of a single 1-D convolution block followed by an LSTM.
From the earlier article "Building a CNN in PyTorch for Time Series Forecasting (Wind Speed Forecasting)" we know the full signature of a 1-D convolution:
nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
The 1-D convolution in this article's model is defined as:
nn.Conv1d(in_channels=args.in_channels, out_channels=args.out_channels, kernel_size=3)
Here in_channels plays the same role as the embedding dimension in NLP: it is the number of features per time step. The input therefore has 7 channels, i.e. the load plus 6 other environmental variables; out_channels can be set freely and is 32 in this article; kernel_size is set to 3.
In PyTorch, a 1-D convolution expects input of shape:
input(batch_size, input_size, seq_len)=(30, 7, 24)
whereas the data produced by our preprocessing has shape:
input(batch_size, seq_len, input_size)=(30, 24, 7)
Therefore, we need to swap dimensions:
x = x.permute(0, 2, 1)
After the swap, the data matches the input format the CNN expects.
The 1-D convolution slides along the seq_len dimension, i.e. the last dimension of (30, 7, 24). Therefore, after passing through:
nn.Conv1d(in_channels=args.in_channels, out_channels=args.out_channels, kernel_size=3)
the tensor shape becomes:
(30, 32, 24-3+1)=(30, 32, 22)
The batch_size in the first dimension stays the same, the second dimension changes from in_channels=7 to out_channels=32, and the third dimension shrinks to 22 because of the convolution (24-3+1).
The subsequent max pooling then turns it into:
(30, 32, 22-3+1)=(30, 32, 20)
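As a quick sanity check of this shape arithmetic (not part of the original post), we can pass a dummy tensor through the same Conv1d and MaxPool1d layers:

import torch
from torch import nn

x = torch.randn(30, 7, 24)  # (batch_size, in_channels, seq_len)
conv = nn.Conv1d(in_channels=7, out_channels=32, kernel_size=3)
pool = nn.MaxPool1d(kernel_size=3, stride=1)
print(conv(x).shape)        # torch.Size([30, 32, 22])
print(pool(conv(x)).shape)  # torch.Size([30, 32, 20])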
This (30, 32, 20) tensor is fed into the LSTM. Since we set batch_first=True, the LSTM expects input of shape:
input(batch_size, seq_len, input_size)
while the tensor obtained after convolution and pooling has shape:
input(batch_size=30, input_size=32, seq_len=20)
so, once again, we need to swap dimensions:
x = x.permute(0, 2, 1)
After that it is the usual LSTM input/output handling, which we will not repeat here.
The complete forward function is therefore:
def forward(self, x):
    x = x.permute(0, 2, 1)
    x = self.conv(x)
    x = x.permute(0, 2, 1)
    x, _ = self.lstm(x)
    x = self.fc(x)
    x = x[:, -1, :]
    return x
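To double-check the end-to-end shapes, the following sketch (not from the original post) feeds a dummy batch through the model; in_channels, out_channels and output_size follow the shape comments above, while hidden_size and num_layers are arbitrary assumed values:

from types import SimpleNamespace

import torch

# assumed hyperparameters, for this demo only
args = SimpleNamespace(in_channels=7, out_channels=32, hidden_size=64,
                       num_layers=1, output_size=4)
model = CNN_LSTM(args)
x = torch.randn(30, 24, 7)  # (batch_size, seq_len, input_size)
print(model(x).shape)       # torch.Size([30, 4]): one 4-step forecast per sample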
III. Implementation
3.1 Data Processing
We use the load of the previous 24 time steps, together with the environmental variables at those time steps, to predict the load of the next 4 time steps. A direct multi-output strategy is adopted here; adjusting output_size changes the number of forecast steps.
Implementation:
import numpy as np
import torch
from torch.utils.data import DataLoader


def nn_seq(args):
    seq_len, B, num = args.seq_len, args.batch_size, args.output_size
    print('data processing...')
    dataset = load_data()  # reads the raw data; defined elsewhere in the project
    # split: 60% train / 20% validation / 20% test
    train = dataset[:int(len(dataset) * 0.6)]
    val = dataset[int(len(dataset) * 0.6):int(len(dataset) * 0.8)]
    test = dataset[int(len(dataset) * 0.8):len(dataset)]
    # max / min of the load column, computed on the training set only
    m, n = np.max(train[train.columns[1]]), np.min(train[train.columns[1]])

    def process(data, batch_size, step_size):
        load = data[data.columns[1]]
        data = data.values.tolist()
        load = (load - n) / (m - n)  # min-max normalization of the load
        load = load.tolist()
        seq = []
        for i in range(0, len(data) - seq_len - num, step_size):
            train_seq = []
            train_label = []
            for j in range(i, i + seq_len):
                # one time step: normalized load + 6 environmental variables
                x = [load[j]]
                for c in range(2, 8):
                    x.append(data[j][c])
                train_seq.append(x)
            for j in range(i + seq_len, i + seq_len + num):
                train_label.append(load[j])
            train_seq = torch.FloatTensor(train_seq)
            train_label = torch.FloatTensor(train_label).view(-1)
            seq.append((train_seq, train_label))
        # print(seq[-1])
        seq = MyDataset(seq)
        seq = DataLoader(dataset=seq, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=False)
        return seq

    Dtr = process(train, B, step_size=1)
    Val = process(val, B, step_size=1)
    Dte = process(test, B, step_size=num)  # non-overlapping windows for testing

    return Dtr, Val, Dte, m, n
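The MyDataset class used in process() is not shown in this post. A minimal sketch, assuming it simply wraps the list of (seq, label) tensor pairs, could look like this:

from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Wraps the list of (seq, label) tensor pairs built in process()."""

    def __init__(self, data):
        self.data = data

    def __getitem__(self, item):
        return self.data[item]

    def __len__(self):
        return len(self.data)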
3.2 Model Training / Testing
Same as in the previous articles:
import copy
from itertools import chain

import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
from scipy.interpolate import make_interp_spline


def train(args, Dtr, Val, path):
    model = CNN_LSTM(args).to(args.device)
    loss_function = nn.MSELoss().to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    print('training...')
    epochs = 50
    min_epochs = 10
    best_model = None
    min_val_loss = 5
    for epoch in range(epochs):
        train_loss = []
        for batch_idx, (seq, target) in enumerate(Dtr, 0):
            seq, target = seq.to(args.device), target.to(args.device)
            optimizer.zero_grad()
            y_pred = model(seq)
            loss = loss_function(y_pred, target)
            train_loss.append(loss.item())
            loss.backward()
            optimizer.step()
        # validation: keep the model with the lowest validation loss
        val_loss = get_val_loss(args, model, Val)
        if epoch + 1 >= min_epochs and val_loss < min_val_loss:
            min_val_loss = val_loss
            best_model = copy.deepcopy(model)
        print('epoch {:03d} train_loss {:.8f} val_loss {:.8f}'.format(epoch, np.mean(train_loss), val_loss))
        model.train()

    state = {'model': best_model.state_dict(), 'optimizer': optimizer.state_dict()}
    torch.save(state, path)
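The get_val_loss helper called above is not shown in the post either. A minimal sketch consistent with how it is used here (switch the model to eval mode and return the mean MSE over the validation loader) might be:

def get_val_loss(args, model, Val):
    model.eval()
    loss_function = nn.MSELoss().to(args.device)
    val_loss = []
    with torch.no_grad():
        for seq, label in Val:
            seq, label = seq.to(args.device), label.to(args.device)
            y_pred = model(seq)
            val_loss.append(loss_function(y_pred, label).item())
    return np.mean(val_loss)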
def test(args, Dte, path, m, n):
    print('loading model...')
    model = CNN_LSTM(args).to(args.device)
    model.load_state_dict(torch.load(path)['model'])
    model.eval()
    pred = []
    y = []
    for batch_idx, (seq, target) in enumerate(Dte, 0):
        seq = seq.to(args.device)
        with torch.no_grad():
            target = list(chain.from_iterable(target.tolist()))
            y.extend(target)
            y_pred = model(seq)
            y_pred = list(chain.from_iterable(y_pred.data.tolist()))
            pred.extend(y_pred)
    y, pred = np.array(y), np.array(pred)
    # undo the min-max normalization (m = max, n = min of the training load)
    y = (m - n) * y + n
    pred = (m - n) * pred + n
    print('mape:', get_mape(y, pred))
    # plot 150 points of ground truth vs. prediction with spline smoothing
    x = [i for i in range(1, 151)]
    x_smooth = np.linspace(np.min(x), np.max(x), 900)
    y_smooth = make_interp_spline(x, y[150:300])(x_smooth)
    plt.plot(x_smooth, y_smooth, c='green', marker='*', ms=1, alpha=0.75, label='true')
    y_smooth = make_interp_spline(x, pred[150:300])(x_smooth)
    plt.plot(x_smooth, y_smooth, c='red', marker='o', ms=1, alpha=0.75, label='pred')
    plt.grid(axis='y')
    plt.legend()
    plt.show()
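Likewise, get_mape is not defined in the post. A common implementation (assuming it returns the error as a fraction) is:

def get_mape(y, pred):
    # mean absolute percentage error
    return np.mean(np.abs((y - pred) / y))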
3.3 Results
Predicting the next 4 time steps from the previous 24 gives a MAPE of 7.41%.
IV. Source Code and Data
I may release them later.