PyTorch: Building a Transformer for Multivariate, Multi-Step Time Series Forecasting (Load Forecasting)
2022-06-28 22:24:00 【Cyril_KI】
I. Preface
I have written many articles on time series forecasting:
- In-Depth Understanding of LSTM Input and Output in PyTorch (from Input to Linear Output)
- Building an LSTM in PyTorch for Time Series Forecasting (Load Forecasting)
- Building an LSTM in PyTorch for Multivariate Time Series Forecasting (Load Forecasting)
- Building a Bidirectional LSTM in PyTorch for Time Series Forecasting (Load Forecasting)
- Building an LSTM in PyTorch for Multivariate, Multi-Step Time Series Forecasting (1): Direct Multi-Output
- Building an LSTM in PyTorch for Multivariate, Multi-Step Time Series Forecasting (2): Single-Step Rolling Prediction
- Building an LSTM in PyTorch for Multivariate, Multi-Step Time Series Forecasting (3): Multi-Model Single-Step Prediction
- Building an LSTM in PyTorch for Multivariate, Multi-Step Time Series Forecasting (4): Multi-Model Rolling Prediction
- Building an LSTM in PyTorch for Multivariate, Multi-Step Time Series Forecasting (5): seq2seq
- A Summary of Several Multi-Step Time Series Forecasting Methods with LSTM in PyTorch (Load Forecasting)
- How to Predict Real Future Values with PyTorch-LSTM in Time Series Forecasting
- Building an LSTM in PyTorch for Multivariate-Input, Multivariate-Output Time Series Forecasting (Multi-Task Learning)
- Building an ANN in PyTorch for Time Series Forecasting (Wind Speed Prediction)
- Building a CNN in PyTorch for Time Series Forecasting (Wind Speed Prediction)
- Building a CNN-LSTM Hybrid Model in PyTorch for Multivariate, Multi-Step Time Series Forecasting (Load Forecasting)
- Building a Transformer in PyTorch for Multivariate, Multi-Step Time Series Forecasting (Load Forecasting)
- Time Series Forecasting Series Summary (How to Use the Code)
None of the articles above touches on the attention mechanism, which has become popular in recent years. The Transformer model was proposed together with the attention mechanism. There are already plenty of explanations of the Transformer's principles online, so I will not repeat them here; I may write one if I get the chance.
II. Transformer
PyTorch already encapsulates a concrete implementation of the Transformer. If the import fails, see: torch.nn.Transformer import failure.
The Transformer model is built as follows:
```python
import torch
from torch import nn


class TransformerModel(nn.Module):
    def __init__(self, args):
        super(TransformerModel, self).__init__()
        self.args = args
        # Map the 7-dimensional input (load value + 6 environment variables) to d_model,
        # since the embedding dimension must be divisible by nhead.
        self.input_fc = nn.Linear(args.input_size, args.d_model)
        # The decoder input (label y) also has input_size features, so it gets its own mapping.
        self.output_fc = nn.Linear(args.input_size, args.d_model)
        self.pos_emb = PositionalEncoding(args.d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=args.d_model,
            nhead=8,
            dim_feedforward=4 * args.input_size,
            batch_first=True,
            dropout=0.1,
            device=device  # `device` is assumed to be defined globally (e.g. args.device)
        )
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=args.d_model,
            nhead=8,
            dropout=0.1,
            dim_feedforward=4 * args.input_size,
            batch_first=True,
            device=device
        )
        self.encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=8)
        self.decoder = torch.nn.TransformerDecoder(decoder_layer, num_layers=8)
        # Flatten the decoder output (output_size * d_model) and map it to output_size load values.
        self.fc = nn.Linear(args.output_size * args.d_model, args.output_size)

    def forward(self, x, y):
        # x: (batch_size, seq_len, input_size), e.g. (256, 24, 7)
        x = self.input_fc(x)              # (256, 24, 128)
        x = self.pos_emb(x)               # (256, 24, 128)
        x = self.encoder(x)               # (256, 24, 128)
        # y: (batch_size, output_size, input_size), e.g. (256, 4, 7)
        y = self.output_fc(y)             # (256, 4, 128)
        out = self.decoder(y, x)          # (256, 4, 128)
        out = out.view(out.shape[0], -1)  # (256, 4 * 128)
        out = self.fc(out)                # (256, 4)
        return out
```
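To check that the tensor shapes annotated in the comments line up, the model can be driven with random tensors. This is a minimal sketch: the `args` values mirror the parser in Section 3.3, and `PositionalEncoding` and `device` are assumed to be defined (a sinusoidal `PositionalEncoding` sketch is given further below).

```python
import argparse

# Hypothetical stand-in for the parsed arguments; only the fields used by the model are set.
args = argparse.Namespace(input_size=7, d_model=128, output_size=4)
model = TransformerModel(args).to(device)

x = torch.randn(256, 24, 7).to(device)  # 24 historical steps, 7 features each
y = torch.randn(256, 4, 7).to(device)   # 4 future steps, 7 features each
print(model(x, y).shape)                 # torch.Size([256, 4])
```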
The raw input has 7 dimensions: the load value at each time step plus 6 environment variables. In the original Transformer paper, the text embedding dimension is 512, and PyTorch requires that d_model (the embedding dimension) be divisible by nhead, so the raw data are first mapped from 7 dimensions to d_model dimensions:
x = self.input_fc(x)
where input_fc is defined as:
self.input_fc = nn.Linear(args.input_size, args.d_model)
The positional encoding is then applied to the input:
x = self.pos_emb(x)
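The article does not show the `PositionalEncoding` module itself. A common sinusoidal implementation, written here as a sketch for batch_first inputs of shape (batch_size, seq_len, d_model), looks like this:

```python
import math

import torch
from torch import nn


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding, added to batch_first inputs."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        self.register_buffer('pe', pe.unsqueeze(0))   # (1, max_len, d_model)

    def forward(self, x):
        # x: (batch_size, seq_len, d_model)
        return x + self.pe[:, :x.size(1), :]
```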
The result then passes through the encoder layers:
x = self.encoder(x)
The encoder output has the same dimensions as its input.
The encoder output x and the label y are then fed into the decoder together for decoding:
y = self.output_fc(y) # (256, 4, 128)
out = self.decoder(y, x)
Before entering the decoder, the label y also needs to be mapped from 7 dimensions to d_model dimensions.
It is worth noting that in the previous articles the dimensions of y were (batch_size, output_size), whereas here y has dimensions (batch_size, output_size, input_size) and is mapped to (batch_size, output_size, d_model) inside the model.
III. Code implementation
3.1 Data processing
We use the load values and environment variables of the previous 24 hours to predict the load at the next 4 time steps. Data processing is the same as before; just note that y no longer contains only the load value (a single variable) but, like x, contains all 7 variables. A minimal sliding-window sketch is given below.
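The sketch assumes the raw data have already been scaled into a single array of shape (num_steps, 7); the exact column layout and DataLoader settings are assumptions, but it shows that both x and y keep all 7 variables:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_windows(data, seq_len=24, output_size=4):
    """data: (num_steps, 7) array of load value + 6 environment variables."""
    xs, ys = [], []
    for i in range(len(data) - seq_len - output_size + 1):
        xs.append(data[i:i + seq_len])                          # (24, 7) history window
        ys.append(data[i + seq_len:i + seq_len + output_size])  # (4, 7) -- all 7 variables
    return (torch.tensor(np.array(xs), dtype=torch.float32),
            torch.tensor(np.array(ys), dtype=torch.float32))


# x_all: (N, 24, 7), y_all: (N, 4, 7)
# x_all, y_all = make_windows(scaled_data)
# train_loader = DataLoader(TensorDataset(x_all, y_all), batch_size=256, shuffle=True)
```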
3.2 Model training/testing
Training and testing follow the same procedure as in the previous articles; a minimal training-loop sketch is given below.
3.3 Experimental results
The relevant parameters are as follows:
```python
import argparse

import torch


def args_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=50, help='number of training epochs')
    parser.add_argument('--seq_len', type=int, default=24, help='seq len')
    parser.add_argument('--input_size', type=int, default=7, help='input dimension')
    parser.add_argument('--d_model', type=int, default=128, help='embedding dimension')
    parser.add_argument('--output_size', type=int, default=4, help='output dimension')
    parser.add_argument('--lr', type=float, default=2e-4, help='learning rate')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--optimizer', type=str, default='adam', help='type of optimizer')
    parser.add_argument('--device', default=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    parser.add_argument('--weight_decay', type=float, default=1e-9, help='weight decay')
    parser.add_argument('--bidirectional', type=bool, default=False, help='LSTM direction')
    parser.add_argument('--step_size', type=int, default=10, help='step size')
    parser.add_argument('--gamma', type=float, default=0.5, help='gamma')
    args = parser.parse_args()
    return args
```
After training for 50 epochs, the MAPE is 5.04%.
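For reference, MAPE can be computed on the (inverse-scaled) test predictions with a short helper; `trues` and `preds` here are hypothetical flattened arrays of true and predicted load values:

```python
import numpy as np


def mape(trues, preds):
    trues, preds = np.array(trues), np.array(preds)
    return np.mean(np.abs((trues - preds) / trues)) * 100

# e.g. mape(trues, preds) -> 5.04 after 50 training epochs, as reported above
```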
IV. Source code and data
Transformer time series forecasting code based on PyTorch.