当前位置:网站首页>PyTorch搭建Transformer实现多变量多步长时间序列预测(负荷预测)
PyTorch搭建Transformer实现多变量多步长时间序列预测(负荷预测)
2022-06-28 21:56:00 【Cyril_KI】
I. 前言
前面已经写了很多关于时间序列预测的文章:
- 深入理解PyTorch中LSTM的输入和输出(从input输入到Linear输出)
- PyTorch搭建LSTM实现时间序列预测(负荷预测)
- PyTorch搭建LSTM实现多变量时间序列预测(负荷预测)
- PyTorch搭建双向LSTM实现时间序列预测(负荷预测)
- PyTorch搭建LSTM实现多变量多步长时间序列预测(一):直接多输出
- PyTorch搭建LSTM实现多变量多步长时间序列预测(二):单步滚动预测
- PyTorch搭建LSTM实现多变量多步长时间序列预测(三):多模型单步预测
- PyTorch搭建LSTM实现多变量多步长时间序列预测(四):多模型滚动预测
- PyTorch搭建LSTM实现多变量多步长时间序列预测(五):seq2seq
- PyTorch中实现LSTM多步长时间序列预测的几种方法总结(负荷预测)
- PyTorch-LSTM时间序列预测中如何预测真正的未来值
- PyTorch搭建LSTM实现多变量输入多变量输出时间序列预测(多任务学习)
- PyTorch搭建ANN实现时间序列预测(风速预测)
- PyTorch搭建CNN实现时间序列预测(风速预测)
- PyTorch搭建CNN-LSTM混合模型实现多变量多步长时间序列预测(负荷预测)
- PyTorch搭建Transformer实现多变量多步长时间序列预测(负荷预测)
- 时间序列预测系列文章总结(代码使用方法)
上述文章中都没有涉及到近些年来比较火的Attention机制,随Attention机制一起提出的是transformer模型,关于transformer模型的原理网上各种讲解很多,这里就不具体描述了,有机会再写。
II. Transformer
PyTorch封装了Transformer的具体实现,如果导入失败可以参考:torch.nn.Transformer导入失败。
Transformer模型搭建如下:
class TransformerModel(nn.Module):
def __init__(self, args):
super(TransformerModel, self).__init__()
self.args = args
# embed_dim = head_dim * num_heads?
self.input_fc = nn.Linear(args.input_size, args.d_model)
self.output_fc = nn.Linear(args.input_size, args.d_model)
self.pos_emb = PositionalEncoding(args.d_model)
encoder_layer = nn.TransformerEncoderLayer(
d_model=args.d_model,
nhead=8,
dim_feedforward=4 * args.input_size,
batch_first=True,
dropout=0.1,
device=device
)
decoder_layer = nn.TransformerDecoderLayer(
d_model=args.d_model,
nhead=8,
dropout=0.1,
dim_feedforward=4 * args.input_size,
batch_first=True,
device=device
)
self.encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=8)
self.decoder = torch.nn.TransformerDecoder(decoder_layer, num_layers=8)
self.fc = nn.Linear(args.output_size * args.d_model, args.output_size)
def forward(self, x, y):
# print(x.size()) # (256, 24, 7)
x = self.input_fc(x) # (256, 24, 128)
x = self.pos_emb(x) # (256, 24, 128)
x = self.encoder(x)
# print(y.size()) # (256, 4, 7)
y = self.output_fc(y) # (256, 4, 128)
out = self.decoder(y, x) # (256, 4, 128)
out = out.view(out.shape[0], -1) # (256, 4 * 128)
out = self.fc(out) # (256, 4)
return out
初始时的数据输入维度为7,也就是每个时刻的负荷值以及6个环境变量。在Transformer的原始论文中,文本的嵌入维度为512,而且PyTorch规定nhead数和d_model也就是嵌入维度必须满足整除关系,因此首先将原始数据从7维映射到d_model维度:
x = self.input_fc(x)
其中input_fc:
self.input_fc = nn.Linear(args.input_size, args.d_model)
然后对原始输入进行位置编码:
x = self.pos_emb(x)
然后经过编码层:
x = self.encoder(x)
得到的输出和输入维度一致。
接着将编码器输出x和标签y同时输入解码器进行解码:
y = self.output_fc(y) # (256, 4, 128)
out = self.decoder(y, x)
标签y在进入解码器前同样需要将其维度由7映射到d_model。
值得注意的是,在前面的文章中,y的维度都是(batch_size, output_size),而在Transformer中,y的维度为(batch_size, output_size, d_model)。
III. 代码实现
3.1 数据处理
利用前24小时的负荷值+环境变量预测后4个时刻的负荷值,数据处理和前面一致,只是需要注意的是,y中不再只含有负荷值这1个变量,而是和x一样,都含有7个变量。
3.2 模型训练/测试
和前文一致。
3.3 实验结果
相关参数如下所示:
def args_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=50, help='input dimension')
parser.add_argument('--seq_len', type=int, default=24, help='seq len')
parser.add_argument('--input_size', type=int, default=7, help='input dimension')
parser.add_argument('--d_model', type=int, default=128, help='input dimension')
parser.add_argument('--output_size', type=int, default=4, help='output dimension')
parser.add_argument('--lr', type=float, default=2e-4, help='learning rate')
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
parser.add_argument('--optimizer', type=str, default='adam', help='type of optimizer')
parser.add_argument('--device', default=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
parser.add_argument('--weight_decay', type=float, default=1e-9, help='weight decay')
parser.add_argument('--bidirectional', type=bool, default=False, help='LSTM direction')
parser.add_argument('--step_size', type=int, default=10, help='step size')
parser.add_argument('--gamma', type=float, default=0.5, help='gamma')
args = parser.parse_args()
return args
训练50轮,MAPE为5.04%:
IV. 源码及数据
边栏推荐
- 17 `bs对象.节点名h3.parent` parents 获取父节点 祖先节点
- 代码复查
- 职业问诊 | 面试中被问到意向薪资时,该怎么回答?
- 加刚干的前提
- ROS 2 Humble Hawksbill 之 f1tenth gym
- 河狸生存记:90后女博士与AI开发者们
- 科技巨头成立元宇宙标准论坛,走向开放还是建立围城?
- Sword finger offer:[day 1 stack and queue (simple)] --- > stack containing min function
- Security dilemma of NFT liquidity agreement - Analysis of the hacked event of NFT loan agreement xcarnival
- Hardware development notes (VII): basic process of hardware development, making a USB to RS232 module (VI): creating 0603 package and associating principle graphic devices
猜你喜欢

Laravel文档阅读笔记-Adding a Markdown editor to Laravel

How to advance data analysis from 1 to 10?

CVPR 2022 𞓜 a creative and aesthetic text generation method! Support any input

Hardware development notes (VII): basic process of hardware development, making a USB to RS232 module (VI): creating 0603 package and associating principle graphic devices

Use of dynamic panels

这个简单的小功能,半年为我们产研团队省下213个小时

视觉弱监督学习研究进展

IDC:阿里云获2021中国数据治理平台市场份额第一

爱数SMART 2022峰会开启,分享数据战略与建设数据驱动型组织方法论

数据可视化中柱状图的实例应用,让乘风破浪公演结果一目了然
随机推荐
Rule engine development experience sharing - reddit
[linq] execute SQL like in statements using EF to LINQ
代码复查
6月底了,让我康康有多少准备跳槽的
go-cryptobin 常用加密解密库
Is it safe to open a stock trading account? Is it reliable?
华为云的AI深潜之旅
常用工具类与commons 类库
Analysis of CSRF Cross Site Request Forgery vulnerability
共探数字技术与信息安全,第四届中俄数字论坛成功举办
Native implementation Net 5.0+ custom log
【电子实验2】简单电子门铃
17 `bs对象.节点名h3.parent` parents 获取父节点 祖先节点
Wechat applet realizes left sliding deletion
软件测试的三个沟通技巧
Laravel文档阅读笔记-Adding a Markdown editor to Laravel
彪马携手10KTF Shop启动其迄今为止规模首屈一指的Web3合作项目
Akamai acquires linode
CVPR 2022 𞓜 a creative and aesthetic text generation method! Support any input
河狸生存记:90后女博士与AI开发者们