当前位置:网站首页>【长时间序列预测】Aotoformer 代码详解之[3]模型整体架构分析
【长时间序列预测】Aotoformer 代码详解之[3]模型整体架构分析
2022-06-11 02:26:00 【理心炼丹】
1. 模型整体架构图:

显然,一个编码器-解码器架构。先不考虑batch_sizes,编码器将输入的序列(L,d)编码为(L, d_model),d_model 为设置的参数表示将输入嵌入到多少维度。
编码器输入为:
enc_out:[32, 96, 512] # 假设batch_sizes=32, 输入序列长度=96,输入序列嵌入维度=512
编码器的输出:
enc_out:[32, 96, 512] 。编码器只进行了输入的维度进行了一系列变换,并没有改变长度96。
解码器的输入为:
dec_out:[32, 48 + 192, 512] # lable_len = 48,pred_len = 192。
解码器的输出为:
seasonal_part:[32, 48 + 192, c_out];trend_part:[32, 48 + 192, c_out] # c_out 为数据的实际维度。
最终输出:
outputs:[32, 192, c_out]
序列分解模块已在前文中进行了详细描述。
图中所谓的 前馈模块是集成在 EncoderLayer 和 DecoderLayer 中的,不过是一些Conv1d -> activation -> dropout -> Conv1d -> dropout 。
所需注意的只有一点:解码器的第二个自相关模块的输入是来自于编码器的输出作为(k, v),以及前一个解码器的输出作为 q。显然,二者的长度是不一样的。(图中信号的引出表示同一个东西,比如q, k, v来自一个输出就是一个信号复制的3次)
k = v :[32,96,512]
q: [32, 48 + 192, 512]
所以,AutoCorrelation 类的 forward 函数中:
B, L, H, E = queries.shape # 32 48+192 4 512/4 # 假设多头=4
_, S, _, D = values.shape # 32 96 4 512/4
if L > S:
zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
values = torch.cat([values, zeros], dim=1)
keys = torch.cat([keys, zeros], dim=1)
else:
values = values[:, :L, :, :]
keys = keys[:, :L, :, :]因此,采用的方案是把 k, v 后面进行了补0到和q 一样长。
因此,我们知道了编码器和解码器并不会对长度进行修改。
2. 模型的输入
2.1 编码器的输入
编码器输入为:enc_out:[32, 96, 512],是如何得到的呢?
答案通过enc_embedding,
输入包括两部分:
除了时间戳以外的数据 x_enc:[32, 96, enc_in] #enc_in 就是输入数据的维度
时间戳信息 x_mark_enc:[32, 96, 4] # 如果 args.embed == 'timeF' ,freq='h' 请看前文的数据处理
enc_embedding是DataEmbedding_wo_pos类的对象。该类的forward 函数中将 x_enc(Conv1d) 与 x_mark_enc(Linear) 的shape都编码为 [32, 96, 512] ,然后相加,即得到编码器的输入 enc_out:[32, 96, 512] 。
2.1 解码器的输入
对 x_enc:[32, 96, enc_in] 进行前文的序列分解得到两个序列:
seasonal_init:[32, 96, enc_in]
trend_init:[32, 96, enc_in]
对 x_enc 求均值,和造一个 0 值序列。
mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1) # mean 会把dim=1 那个维度干掉[32, enc_in] -> [32, 1, enc_in] -> [32, 192, enc_in]
zeros = torch.zeros([x_enc.shape[0], self.pred_len, x_enc.shape[2]], device=x_enc.device) # 源代码中的 x_dec 完全没必要,替换为 x_enc 就行了。之所以那么做,作者是为了统一接口让别的模型一样传递输入。 [32, 192, enc_in] # decoder input
trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1)
seasonal_init = torch.cat([seasonal_init[:, -self.label_len:, :], zeros], dim=1)seasonal_init:[32, 48 + 192, enc_in]
trend_init:[32, 48 + 192, enc_in]
dec_embedding 嵌入 与上文的 enc_embedding 一样。输入:
seasonal_init:[32, 48 + 192, enc_in]
x_mark_dec:[32, 48 + 192, 4]
dec_out: [32, 48 + 192, 512]
3. 解码器的加法怎么做的?
其实论文中的架构图简化了。因此,我重画了解码器部分。

参考:
计划更新:[4]模型部件之自相关层-AC
边栏推荐
- Les produits financiers de l'assurance - rente peuvent - ils être composés? Quel est le taux d'intérêt?
- A数位dp
- Looking at the ups and downs of the mobile phone accessories market from the green Union's sprint for IPO
- What is the relationship between precious metal silver and spot Silver
- [MySQL 45 lecture -12] lecture 12 the reason why MySQL has a wind attack from time to time
- The Google search console webmaster tool cannot read the sitemap?
- Unity animator rewind
- 【新晋开源项目】动态配置化任务编排框架 Gobrs-Async 加入Dromara开源社区
- [resolved] how to fix another update in progress WordPress upgrade error
- ShaderGraphs
猜你喜欢

Manon's advanced road - Daily anecdotes

蓝桥杯_小蓝吃糖果_鸽巢原理 / 抽屉原理

银行选择电子招标采购的必要性

To view the data in redis, in addition to the command line and client, you have a third option

弄懂了采矿业与碳中和的逻辑,就读懂了矿区无人驾驶的千亿市场

同一个用户的两次请求SessionId竟然不一致-----记录问题

重磅直播!ORB-SLAM3系列之特征匹配(MLPnP、词袋模型等)。
![[AI weekly] AI and freeze electron microscopy reveal the structure of](/img/2e/e986a5bc44526f686c407378a9492f.png)
[AI weekly] AI and freeze electron microscopy reveal the structure of "atomic level" NPC; Tsinghua and Shangtang proposed the "SIM" method, which takes into account semantic alignment and spatial reso

App test_ Summary of test points

MySQL备份与恢复
随机推荐
GCC C内联汇编
【面试题 17.04. 消失的数字】
Unity HTC and Pico are the same
Kotlin apply method
[C language classic]: inverted string
Forest v1.5.22 发布!支持Kotlin
Jetpack compose scaffold and bottomappbar (bottom navigation)
To view the data in redis, in addition to the command line and client, you have a third option
Istio安装与使用
Unity animator rewind
CPT 102_LEC 15
Stc8a8k64d4 EEPROM read / write failure
Write my Ini configuration file error
P4338 [ZJOI2018]历史(树剖)(暴力)
App test_ Summary of test points
码农的进阶之路 | 每日趣闻
MySQL备份与恢复
Jetpack Compose Scaffold和TopAppBar(顶部导航)
Unity3d model skin changing technology
Jetpack Compose Scaffold和BottomAppBar(底部导航)
https://github.com/thuml/Autoformer