Building a Transformer with PyTorch for multivariate, multi-step time series forecasting (load forecasting)
2022-06-28 22:24:00 【Cyril_KI】
I. Preface
I have written many articles on time series forecasting:
- A deep dive into the inputs and outputs of LSTM in PyTorch (from the input to the Linear output)
- Building an LSTM with PyTorch for time series forecasting (load forecasting)
- Building an LSTM with PyTorch for multivariate time series forecasting (load forecasting)
- Building a bidirectional LSTM with PyTorch for time series forecasting (load forecasting)
- Building an LSTM with PyTorch for multivariate, multi-step time series forecasting (I): direct multi-output
- Building an LSTM with PyTorch for multivariate, multi-step time series forecasting (II): single-step rolling forecasting
- Building an LSTM with PyTorch for multivariate, multi-step time series forecasting (III): multi-model single-step forecasting
- Building an LSTM with PyTorch for multivariate, multi-step time series forecasting (IV): multi-model rolling forecasting
- Building an LSTM with PyTorch for multivariate, multi-step time series forecasting (V): seq2seq
- A summary of several LSTM methods for multi-step time series forecasting in PyTorch (load forecasting)
- How to predict truly future values with PyTorch-LSTM in time series forecasting
- Building an LSTM with PyTorch for multivariate-input, multivariate-output time series forecasting (multi-task learning)
- Building an ANN with PyTorch for time series forecasting (wind speed prediction)
- Building a CNN with PyTorch for time series forecasting (wind speed prediction)
- Building a CNN-LSTM hybrid model with PyTorch for multivariate, multi-step time series forecasting (load forecasting)
- Building a Transformer with PyTorch for multivariate, multi-step time series forecasting (load forecasting)
- A summary of the time series forecasting series (how to use the code)
None of the above articles touches on the attention mechanism, which has become popular in recent years. The Transformer model was proposed together with the attention mechanism. There are already plenty of explanations of the Transformer's principles online, so I won't repeat them here; I may write one if the opportunity arises.
II. Transformer
PyTorch ships with a concrete implementation of the Transformer. If the import fails, you can refer to: torch.nn.Transformer import failure.
The Transformer model is built as follows:
import torch
from torch import nn

# device is assumed to be defined at module level, e.g.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
# and PositionalEncoding is defined separately (see the sketch after this block).
class TransformerModel(nn.Module):
    def __init__(self, args):
        super(TransformerModel, self).__init__()
        self.args = args
        # embed_dim = head_dim * num_heads?
        # Map the 7-dimensional input and label to d_model dimensions.
        self.input_fc = nn.Linear(args.input_size, args.d_model)
        self.output_fc = nn.Linear(args.input_size, args.d_model)
        self.pos_emb = PositionalEncoding(args.d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=args.d_model,
            nhead=8,
            dim_feedforward=4 * args.input_size,
            batch_first=True,
            dropout=0.1,
            device=device
        )
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=args.d_model,
            nhead=8,
            dropout=0.1,
            dim_feedforward=4 * args.input_size,
            batch_first=True,
            device=device
        )
        self.encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=8)
        self.decoder = torch.nn.TransformerDecoder(decoder_layer, num_layers=8)
        # Flatten the decoder output and map it to the output_size load values.
        self.fc = nn.Linear(args.output_size * args.d_model, args.output_size)

    def forward(self, x, y):
        # x: (batch_size, seq_len, input_size), e.g. (256, 24, 7)
        x = self.input_fc(x)    # (256, 24, 128)
        x = self.pos_emb(x)     # (256, 24, 128)
        x = self.encoder(x)
        # y: (batch_size, output_size, input_size), e.g. (256, 4, 7)
        y = self.output_fc(y)   # (256, 4, 128)
        out = self.decoder(y, x)  # (256, 4, 128), encoder output serves as memory
        out = out.view(out.shape[0], -1)  # (256, 4 * 128)
        out = self.fc(out)      # (256, 4)
        return out
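The PositionalEncoding module used above is not shown in the snippet. A minimal sketch of the standard sine/cosine positional encoding, assuming batch_first inputs of shape (batch_size, seq_len, d_model), could look like this (dropout and max_len are illustrative assumptions, not taken from the original code):

import math
import torch
from torch import nn

class PositionalEncoding(nn.Module):
    # Standard sinusoidal positional encoding (a sketch; dropout and max_len are assumptions).
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)          # (1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch_size, seq_len, d_model), batch_first layout
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)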
The raw input dimension is 7, i.e., the load value at each time step plus 6 environment variables. In the original Transformer paper the text embedding dimension is 512, and PyTorch requires that d_model (the embedding dimension) be divisible by nhead, so the raw data is first mapped from 7 dimensions to d_model dimensions:
x = self.input_fc(x)
where input_fc is:
self.input_fc = nn.Linear(args.input_size, args.d_model)
Positional encoding is then applied to the mapped input:
x = self.pos_emb(x)
It then passes through the encoder:
x = self.encoder(x)
The encoder output has the same shape as its input.
The encoder output x and the label y are then fed into the decoder together for decoding:
y = self.output_fc(y) # (256, 4, 128)
out = self.decoder(y, x)
Before entering the decoder, the label y also needs to be mapped from 7 dimensions to d_model dimensions.
It is worth noting that in the previous articles y always has shape (batch_size, output_size), whereas here y is three-dimensional: it enters the model with shape (batch_size, output_size, input_size) and is mapped to (batch_size, output_size, d_model) before being fed to the decoder.
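As a quick sanity check of the shapes described above, the model can be run on dummy tensors. This is only a sketch: the args namespace below is an illustrative stand-in for the argparse arguments of Section 3.3, and it assumes TransformerModel and device live in the same script:

from types import SimpleNamespace
import torch

# Illustrative hyperparameters matching Section 3.3 (a stand-in for the real argparse args).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args = SimpleNamespace(input_size=7, d_model=128, output_size=4)

model = TransformerModel(args).to(device)
x = torch.randn(256, 24, 7, device=device)   # previous 24 time steps, 7 variables each
y = torch.randn(256, 4, 7, device=device)    # next 4 time steps, also 7 variables
out = model(x, y)
print(out.shape)  # torch.Size([256, 4])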
III. Code implementation
3.1 Data processing
The load values and environment variables of the previous 24 time steps are used to predict the load values of the next 4 time steps. Data processing is the same as in the earlier articles; the only thing to note is that y no longer contains just the single load variable but, like x, contains all 7 variables.
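A minimal sketch of this sliding-window construction, assuming a normalized 2-D array named data of shape (num_samples, 7) with the load in the first column (the names and DataLoader settings are illustrative, not the author's exact preprocessing code):

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

def make_windows(data, seq_len=24, output_size=4):
    # data: (num_samples, 7) array, already normalized, load in column 0
    xs, ys = [], []
    for i in range(len(data) - seq_len - output_size + 1):
        xs.append(data[i:i + seq_len])                           # (24, 7)
        ys.append(data[i + seq_len:i + seq_len + output_size])   # (4, 7): all 7 variables, not just the load
    xs = torch.tensor(np.array(xs), dtype=torch.float32)
    ys = torch.tensor(np.array(ys), dtype=torch.float32)
    return xs, ys

# Example with random data standing in for the real load series:
data = np.random.rand(1000, 7).astype(np.float32)
x, y = make_windows(data)                                        # x: (973, 24, 7), y: (973, 4, 7)
loader = DataLoader(TensorDataset(x, y), batch_size=256, shuffle=False)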
3.2 Model training/testing
Same as in the previous articles; a sketch of the training loop is given below.
3.3 Experimental results
The relevant parameters are as follows:
import argparse
import torch

def args_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=50, help='number of training epochs')
    parser.add_argument('--seq_len', type=int, default=24, help='sequence length')
    parser.add_argument('--input_size', type=int, default=7, help='input dimension')
    parser.add_argument('--d_model', type=int, default=128, help='embedding dimension')
    parser.add_argument('--output_size', type=int, default=4, help='output dimension')
    parser.add_argument('--lr', type=float, default=2e-4, help='learning rate')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--optimizer', type=str, default='adam', help='type of optimizer')
    parser.add_argument('--device', default=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    parser.add_argument('--weight_decay', type=float, default=1e-9, help='weight decay')
    parser.add_argument('--bidirectional', type=bool, default=False, help='LSTM direction')
    parser.add_argument('--step_size', type=int, default=10, help='step size of the LR scheduler')
    parser.add_argument('--gamma', type=float, default=0.5, help='gamma of the LR scheduler')
    args = parser.parse_args()
    return args
After training for 50 epochs, the MAPE is 5.04%.
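For reference, a sketch of how MAPE could be computed here (assuming y_true and y_pred are the de-normalized true and predicted load values):

import numpy as np

def mape(y_true, y_pred):
    # Mean absolute percentage error, in percent.
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return np.mean(np.abs((y_true - y_pred) / y_true)) * 100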
IV. Source code and data
Transformer time series forecasting code based on PyTorch