Building a CNN-LSTM Hybrid Model in PyTorch for Multivariate Multi-Step Time Series Forecasting (Load Forecasting)
2022-06-26 06:43:00 【Cyril_KI】
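I. Introduction
For the principles behind LSTM, see: 人工智能教程 (an AI tutorial site). Besides LSTM, the site explains most other machine learning and deep learning models, with clear illustrations and an accessible style.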
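I have already written a number of articles on time series forecasting:
- Understanding LSTM Inputs and Outputs in PyTorch (from the input tensor to the Linear output)
- Building an LSTM in PyTorch for Time Series Forecasting (Load Forecasting)
- Building an LSTM in PyTorch for Multivariate Time Series Forecasting (Load Forecasting)
- Building a Bidirectional LSTM in PyTorch for Time Series Forecasting (Load Forecasting)
- Building an LSTM in PyTorch for Multivariate Multi-Step Time Series Forecasting (1): Direct Multi-Output
- Building an LSTM in PyTorch for Multivariate Multi-Step Time Series Forecasting (2): Single-Step Rolling Forecasting
- Building an LSTM in PyTorch for Multivariate Multi-Step Time Series Forecasting (3): Multi-Model Single-Step Forecasting
- Building an LSTM in PyTorch for Multivariate Multi-Step Time Series Forecasting (4): Multi-Model Rolling Forecasting
- Building an LSTM in PyTorch for Multivariate Multi-Step Time Series Forecasting (5): seq2seq
- A Summary of Approaches to Multi-Step LSTM Time Series Forecasting in PyTorch (Load Forecasting)
- How to Predict Genuinely Future Values in PyTorch LSTM Time Series Forecasting
- Building an LSTM in PyTorch for Multivariate-Input, Multivariate-Output Time Series Forecasting (Multi-Task Learning)
- Building an ANN in PyTorch for Time Series Forecasting (Wind Speed Forecasting)
- Building a CNN in PyTorch for Time Series Forecasting (Wind Speed Forecasting)
- Building a CNN-LSTM Hybrid Model in PyTorch for Multivariate Multi-Step Time Series Forecasting (Load Forecasting)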
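The articles above use three kinds of models (LSTM, ANN, and CNN) for time series forecasting. CNNs are effective at extracting local features, so many recent papers combine a CNN with an LSTM for time series forecasting. This article uses PyTorch to build a simple CNN-LSTM hybrid model for load forecasting.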
II. CNN-LSTM
The CNN-LSTM model is built as follows:
import torch.nn as nn


class CNN_LSTM(nn.Module):
    def __init__(self, args):
        super(CNN_LSTM, self).__init__()
        self.args = args
        self.relu = nn.ReLU(inplace=True)
        # input: (batch_size=30, seq_len=24, input_size=7) ---> permute(0, 2, 1)
        # ---> (30, 7, 24), the layout Conv1d expects
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels=args.in_channels, out_channels=args.out_channels, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=1)
        )
        # conv output: (batch_size=30, out_channels=32, seq_len-4=20) ---> permute(0, 2, 1)
        # ---> (30, 20, 32), the layout a batch_first LSTM expects
        self.lstm = nn.LSTM(input_size=args.out_channels, hidden_size=args.hidden_size,
                            num_layers=args.num_layers, batch_first=True)
        self.fc = nn.Linear(args.hidden_size, args.output_size)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (batch, seq_len, features) -> (batch, features, seq_len)
        x = self.conv(x)
        x = x.permute(0, 2, 1)  # back to (batch, seq_len, features)
        x, _ = self.lstm(x)
        x = self.fc(x)
        x = x[:, -1, :]  # keep only the output at the last time step
        return x
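As shown, this CNN-LSTM consists of a single 1D convolution layer followed by an LSTM.
From Building a CNN in PyTorch for Time Series Forecasting (Wind Speed Forecasting), we know that the 1D convolution constructor has the following signature:
nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
The 1D convolution in this article's model is defined as:
nn.Conv1d(in_channels=args.in_channels, out_channels=args.out_channels, kernel_size=3)
Here in_channels is analogous to the embedding dimension in natural language processing: it is the number of features per time step. The number of input channels is therefore 7, i.e. the load plus 6 environmental variables. out_channels can be chosen freely and is set to 32 in this article; kernel_size is set to 3.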
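A 1D convolution in PyTorch expects input of shape:
input(batch_size, input_size, seq_len)=(30, 7, 24)
whereas the data produced by preprocessing has shape:
input(batch_size, seq_len, input_size)=(30, 24, 7)
so we need to swap the last two dimensions:
x = x.permute(0, 2, 1)
After the swap, the data matches the input layout the CNN expects.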
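To make the effect of permute(0, 2, 1) concrete, here is a minimal sketch in plain Python (nested lists instead of tensors, so it runs without PyTorch). The helper names swap_last_two_dims and shape are hypothetical and exist only for this illustration; the toy batch is much smaller than the (30, 24, 7) batch used in the article.

```python
def swap_last_two_dims(batch):
    """Transpose each (seq_len, input_size) sample to (input_size, seq_len)."""
    return [[list(channel) for channel in zip(*sample)] for sample in batch]


def shape(nested):
    """Return the shape of a regular nested list as a tuple."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


# Toy batch: 2 samples, 3 time steps, 2 variables per time step.
batch = [
    [[1, 10], [2, 20], [3, 30]],
    [[4, 40], [5, 50], [6, 60]],
]
swapped = swap_last_two_dims(batch)

print(shape(batch), "->", shape(swapped))  # (2, 3, 2) -> (2, 2, 3)
print(swapped[0])  # [[1, 2, 3], [10, 20, 30]]
```

After the swap, each row of a sample holds one variable across all time steps, which is the per-channel layout the convolution slides over.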
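A 1D convolution slides along the seq_len dimension, i.e. the last dimension of (30, 7, 24). Therefore, after:
nn.Conv1d(in_channels=args.in_channels, out_channels=args.out_channels, kernel_size=3)
the data has shape:
(30, 32, 24-3+1)=(30, 32, 22)
The first dimension, batch_size, is unchanged; the second dimension changes from in_channels=7 to out_channels=32; the third dimension shrinks to 22 under the convolution.
The max pooling layer then gives:
(30, 32, 22-3+1)=(30, 32, 20)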
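The two length calculations above are special cases of the general output-length formula given in the PyTorch documentation for Conv1d and MaxPool1d. The sketch below evaluates that formula in plain Python; the function name conv1d_out_len is hypothetical, and the padding and default-stride variants at the end are illustrations that the article's model does not use.

```python
def conv1d_out_len(l_in, kernel_size, stride=1, padding=0, dilation=1):
    """Output length along the sequence dimension, using floor division."""
    return (l_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# The settings used in this article: kernel_size=3 for both layers, stride=1.
after_conv = conv1d_out_len(24, kernel_size=3)
after_pool = conv1d_out_len(after_conv, kernel_size=3, stride=1)
print(after_conv, after_pool)  # 22 20

# Illustration only: padding=1 would keep the sequence length at 24.
print(conv1d_out_len(24, kernel_size=3, padding=1))  # 24

# Illustration only: MaxPool1d defaults to stride=kernel_size when stride is
# omitted, which would shrink the sequence far more than stride=1 does.
print(conv1d_out_len(22, kernel_size=3, stride=3))  # 7
```

This is why the model passes stride=1 to MaxPool1d explicitly: the pooled sequence keeps 20 time steps for the LSTM to consume.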
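This (30, 32, 20) tensor is what gets fed to the LSTM. Because we set batch_first=True in the LSTM, it accepts input of shape:
input(batch_size, seq_len, input_size)
whereas the data after convolution and pooling has shape:
input(batch_size=30, input_size=32, seq_len=20)
so we again swap the last two dimensions:
x = x.permute(0, 2, 1)
What follows is the usual LSTM input/output handling, which I will not go into again.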
The complete forward function is therefore:
def forward(self, x):
    x = x.permute(0, 2, 1)
    x = self.conv(x)
    x = x.permute(0, 2, 1)
    x, _ = self.lstm(x)
    x = self.fc(x)
    x = x[:, -1, :]
    return x
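The shape of the tensor at each step of forward can be traced by hand. The sketch below does this with plain tuples rather than tensors. It assumes hidden_size=64, a hypothetical value since the article does not state it, and output_size=4 to match the 4-step prediction described in the data processing section; the function name trace_shapes is also hypothetical.

```python
def trace_shapes(batch_size=30, seq_len=24, input_size=7,
                 out_channels=32, hidden_size=64, output_size=4):
    """Return (step name, shape) pairs for each step of the forward pass."""
    steps = []
    shape = (batch_size, seq_len, input_size)
    steps.append(("input", shape))
    shape = (shape[0], shape[2], shape[1])
    steps.append(("permute(0, 2, 1)", shape))
    shape = (shape[0], out_channels, shape[2] - 3 + 1)
    steps.append(("Conv1d(kernel_size=3)", shape))
    shape = (shape[0], shape[1], shape[2] - 3 + 1)
    steps.append(("MaxPool1d(kernel_size=3, stride=1)", shape))
    shape = (shape[0], shape[2], shape[1])
    steps.append(("permute(0, 2, 1)", shape))
    # A batch_first LSTM returns one hidden vector per time step.
    shape = (shape[0], shape[1], hidden_size)
    steps.append(("LSTM output", shape))
    # Linear acts on the last dimension only.
    shape = (shape[0], shape[1], output_size)
    steps.append(("Linear", shape))
    # Keep only the last time step.
    shape = (shape[0], shape[2])
    steps.append(("x[:, -1, :]", shape))
    return steps


steps = trace_shapes()
for name, shape in steps:
    print(f"{name:38s}{shape}")
```

The final shape is (30, 4): one row per sample, one column per predicted time step, which is what the direct multi-output strategy needs.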
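III. Implementation
3.1 Data Processing
We use the load at the previous 24 time steps, together with the environmental variables at each of those time steps, to predict the load at the next 4 time steps. This uses the direct multi-output strategy; changing output_size changes the number of predicted steps.
Code: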
import numpy as np
import torch
from torch.utils.data import DataLoader

# load_data() and MyDataset are helpers that are not shown in this article.


def nn_seq(args):
    seq_len, B, num = args.seq_len, args.batch_size, args.output_size
    print('data processing...')
    dataset = load_data()
    # split into train / validation / test (60% / 20% / 20%)
    train = dataset[:int(len(dataset) * 0.6)]
    val = dataset[int(len(dataset) * 0.6):int(len(dataset) * 0.8)]
    test = dataset[int(len(dataset) * 0.8):len(dataset)]
    # max and min of the load column, computed on the training set only
    m, n = np.max(train[train.columns[1]]), np.min(train[train.columns[1]])

    def process(data, batch_size, step_size):
        load = data[data.columns[1]]
        data = data.values.tolist()
        load = (load - n) / (m - n)  # min-max normalization of the load
        load = load.tolist()
        seq = []
        for i in range(0, len(data) - seq_len - num, step_size):
            train_seq = []
            train_label = []
            # inputs: normalized load plus 6 environmental variables per time step
            for j in range(i, i + seq_len):
                x = [load[j]]
                for c in range(2, 8):
                    x.append(data[j][c])
                train_seq.append(x)
            # labels: normalized load at the next num time steps
            for j in range(i + seq_len, i + seq_len + num):
                train_label.append(load[j])
            train_seq = torch.FloatTensor(train_seq)
            train_label = torch.FloatTensor(train_label).view(-1)
            seq.append((train_seq, train_label))
        # print(seq[-1])
        seq = MyDataset(seq)
        seq = DataLoader(dataset=seq, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=False)
        return seq

    Dtr = process(train, B, step_size=1)
    Val = process(val, B, step_size=1)
    # step_size=num on the test set, so the predicted windows do not overlap
    Dte = process(test, B, step_size=num)
    return Dtr, Val, Dte, m, n
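The windowing logic in process can be checked on a toy series. The sketch below is a simplified stand-in in plain Python: it keeps only the load column (the environmental variables are omitted) and uses the integers 0..39 as data so that each index equals its value. The function names window_starts and make_sample are hypothetical.

```python
def window_starts(n_rows, seq_len, num, step_size):
    """Start indices produced by range(0, n_rows - seq_len - num, step_size)."""
    return list(range(0, n_rows - seq_len - num, step_size))


def make_sample(load, start, seq_len, num):
    """One (inputs, labels) pair: seq_len past values, then the next num values."""
    inputs = load[start:start + seq_len]
    labels = load[start + seq_len:start + seq_len + num]
    return inputs, labels


load = list(range(40))  # toy series standing in for the normalized load

# Training/validation style: step_size=1, consecutive windows overlap.
train_starts = window_starts(len(load), seq_len=24, num=4, step_size=1)
# Test style: step_size=num, so the label windows are disjoint.
test_starts = window_starts(len(load), seq_len=24, num=4, step_size=4)

print(len(train_starts))  # 12
print(test_starts)        # [0, 4, 8]
for start in test_starts:
    print(make_sample(load, start, seq_len=24, num=4)[1])
# [24, 25, 26, 27]
# [28, 29, 30, 31]
# [32, 33, 34, 35]
```

With step_size=num, each test label window begins where the previous one ended, so concatenating the test predictions yields one continuous forecast without duplicated time steps.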
3.2 Model Training/Testing
Same as in the previous articles:
import copy

import numpy as np
import torch
import torch.nn as nn

# get_val_loss() is a helper that is not shown in this article.


def train(args, Dtr, Val, path):
    model = CNN_LSTM(args).to(args.device)
    loss_function = nn.MSELoss().to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    print('training...')
    epochs = 50
    min_epochs = 10
    best_model = None
    min_val_loss = 5
    for epoch in range(epochs):
        train_loss = []
        for batch_idx, (seq, target) in enumerate(Dtr, 0):
            seq, target = seq.to(args.device), target.to(args.device)
            optimizer.zero_grad()
            y_pred = model(seq)
            loss = loss_function(y_pred, target)
            train_loss.append(loss.item())
            loss.backward()
            optimizer.step()
        # validation
        val_loss = get_val_loss(args, model, Val)
        # keep the model with the lowest validation loss, ignoring early epochs
        if epoch + 1 >= min_epochs and val_loss < min_val_loss:
            min_val_loss = val_loss
            best_model = copy.deepcopy(model)
        print('epoch {:03d} train_loss {:.8f} val_loss {:.8f}'.format(epoch, np.mean(train_loss), val_loss))
        model.train()

    state = {
        'model': best_model.state_dict(), 'optimizer': optimizer.state_dict()}
    torch.save(state, path)
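The model selection rule inside train can be isolated from the training loop. The sketch below replays that rule over a list of made-up validation losses; the function name select_best_epoch and the loss values are hypothetical and serve only to show which epoch the condition picks.

```python
def select_best_epoch(val_losses, min_epochs=10, initial_best=5.0):
    """Replay the rule: epoch + 1 >= min_epochs and val_loss < best so far."""
    best_epoch, best_loss = None, initial_best
    for epoch, val_loss in enumerate(val_losses):
        if epoch + 1 >= min_epochs and val_loss < best_loss:
            best_epoch, best_loss = epoch, val_loss
    return best_epoch, best_loss


# Made-up validation losses, one per epoch (zero-based epoch index).
losses = [0.9, 0.5, 0.1, 0.3, 0.25, 0.2]

# The 0.1 at epoch 2 is ignored because it occurs before min_epochs is reached.
print(select_best_epoch(losses, min_epochs=4))  # (5, 0.2)

# If no epoch qualifies, nothing is selected.
print(select_best_epoch([0.9, 0.5], min_epochs=4))  # (None, 5.0)
```

The second call shows a case worth guarding against: when no epoch satisfies the condition, best_model in train would remain None and the call to best_model.state_dict() would fail. With epochs=50, min_epochs=10, and a normalized MSE loss well below 5, this should not occur in the article's setting.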
from itertools import chain

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.interpolate import make_interp_spline

# get_mape() is a helper that is not shown in this article.


def test(args, Dte, path, m, n):
    print('loading model...')
    model = CNN_LSTM(args).to(args.device)
    model.load_state_dict(torch.load(path)['model'])
    model.eval()
    pred = []
    y = []
    for batch_idx, (seq, target) in enumerate(Dte, 0):
        seq = seq.to(args.device)
        with torch.no_grad():
            # flatten (batch, output_size) into one continuous series
            target = list(chain.from_iterable(target.tolist()))
            y.extend(target)
            y_pred = model(seq)
            y_pred = list(chain.from_iterable(y_pred.data.tolist()))
            pred.extend(y_pred)

    y, pred = np.array(y), np.array(pred)
    # undo the min-max normalization
    y = (m - n) * y + n
    pred = (m - n) * pred + n
    print('mape:', get_mape(y, pred))
    # plot 150 points (indices 150 to 299) of the true and predicted series
    x = [i for i in range(1, 151)]
    x_smooth = np.linspace(np.min(x), np.max(x), 900)
    y_smooth = make_interp_spline(x, y[150:300])(x_smooth)
    plt.plot(x_smooth, y_smooth, c='green', marker='*', ms=1, alpha=0.75, label='true')
    y_smooth = make_interp_spline(x, pred[150:300])(x_smooth)
    plt.plot(x_smooth, y_smooth, c='red', marker='o', ms=1, alpha=0.75, label='pred')
    plt.grid(axis='y')
    plt.legend()
    plt.show()
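Two pieces of the evaluation step are easy to check by hand: the inverse of the min-max normalization and the MAPE itself. The sketch below covers both in plain Python. Since get_mape is not shown in the article, the mape function here assumes the standard definition (mean of absolute relative errors), and all numbers are made up for illustration.

```python
def min_max_scale(values, lo, hi):
    """Map values to [0, 1] using the training-set minimum and maximum."""
    return [(v - lo) / (hi - lo) for v in values]


def inverse_min_max(scaled, lo, hi):
    """Undo min_max_scale, as in `(m - n) * y + n` in the test function."""
    return [(hi - lo) * s + lo for s in scaled]


def mape(y_true, y_pred):
    """Mean absolute percentage error as a fraction (assumed standard definition)."""
    return sum(abs((t - p) / t) for t, p in zip(y_true, y_pred)) / len(y_true)


# Made-up loads; lo and hi play the roles of n and m in the article.
y_true = [100.0, 200.0, 400.0]
y_pred = [110.0, 180.0, 400.0]
lo, hi = 100.0, 400.0

restored = inverse_min_max(min_max_scale(y_true, lo, hi), lo, hi)
print(restored)  # close to [100.0, 200.0, 400.0], up to floating-point rounding

# Relative errors are 10%, 10%, and 0%, so the mean is about 6.67%.
print("mape: {:.2%}".format(mape(y_true, y_pred)))  # mape: 6.67%
```

MAPE is computed on the denormalized values because it divides by the true value, and normalized loads can be at or near zero.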
3.3 Experimental Results
Using the previous 24 time steps to predict the next 4, the MAPE is 7.41%.
IV. Source Code and Data
I may make these public later.