当前位置:网站首页>【长时间序列预测】Aotoformer 代码详解之[3]模型整体架构分析
【长时间序列预测】Aotoformer 代码详解之[3]模型整体架构分析
2022-06-11 02:26:00 【理心炼丹】
1. 模型整体架构图:

显然,一个编码器-解码器架构。先不考虑batch_sizes,编码器将输入的序列(L,d)编码为(L, d_model),d_model 为设置的参数表示将输入嵌入到多少维度。
编码器输入为:
enc_out:[32, 96, 512] # 假设batch_sizes=32, 输入序列长度=96,输入序列嵌入维度=512
编码器的输出:
enc_out:[32, 96, 512] 。编码器只进行了输入的维度进行了一系列变换,并没有改变长度96。
解码器的输入为:
dec_out:[32, 48 + 192, 512] # lable_len = 48,pred_len = 192。
解码器的输出为:
seasonal_part:[32, 48 + 192, c_out];trend_part:[32, 48 + 192, c_out] # c_out 为数据的实际维度。
最终输出:
outputs:[32, 192, c_out]
序列分解模块已在前文中进行了详细描述。
图中所谓的 前馈模块是集成在 EncoderLayer 和 DecoderLayer 中的,不过是一些Conv1d -> activation -> dropout -> Conv1d -> dropout 。
所需注意的只有一点:解码器的第二个自相关模块的输入是来自于编码器的输出作为(k, v),以及前一个解码器的输出作为 q。显然,二者的长度是不一样的。(图中信号的引出表示同一个东西,比如q, k, v来自一个输出就是一个信号复制的3次)
k = v :[32,96,512]
q: [32, 48 + 192, 512]
所以,AutoCorrelation 类的 forward 函数中:
B, L, H, E = queries.shape # 32 48+192 4 512/4 # 假设多头=4
_, S, _, D = values.shape # 32 96 4 512/4
if L > S:
zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
values = torch.cat([values, zeros], dim=1)
keys = torch.cat([keys, zeros], dim=1)
else:
values = values[:, :L, :, :]
keys = keys[:, :L, :, :]因此,采用的方案是把 k, v 后面进行了补0到和q 一样长。
因此,我们知道了编码器和解码器并不会对长度进行修改。
2. 模型的输入
2.1 编码器的输入
编码器输入为:enc_out:[32, 96, 512],是如何得到的呢?
答案通过enc_embedding,
输入包括两部分:
除了时间戳以外的数据 x_enc:[32, 96, enc_in] #enc_in 就是输入数据的维度
时间戳信息 x_mark_enc:[32, 96, 4] # 如果 args.embed == 'timeF' ,freq='h' 请看前文的数据处理
enc_embedding是DataEmbedding_wo_pos类的对象。该类的forward 函数中将 x_enc(Conv1d) 与 x_mark_enc(Linear) 的shape都编码为 [32, 96, 512] ,然后相加,即得到编码器的输入 enc_out:[32, 96, 512] 。
2.1 解码器的输入
对 x_enc:[32, 96, enc_in] 进行前文的序列分解得到两个序列:
seasonal_init:[32, 96, enc_in]
trend_init:[32, 96, enc_in]
对 x_enc 求均值,和造一个 0 值序列。
mean = torch.mean(x_enc, dim=1).unsqueeze(1).repeat(1, self.pred_len, 1) # mean 会把dim=1 那个维度干掉[32, enc_in] -> [32, 1, enc_in] -> [32, 192, enc_in]
zeros = torch.zeros([x_enc.shape[0], self.pred_len, x_enc.shape[2]], device=x_enc.device) # 源代码中的 x_dec 完全没必要,替换为 x_enc 就行了。之所以那么做,作者是为了统一接口让别的模型一样传递输入。 [32, 192, enc_in] # decoder input
trend_init = torch.cat([trend_init[:, -self.label_len:, :], mean], dim=1)
seasonal_init = torch.cat([seasonal_init[:, -self.label_len:, :], zeros], dim=1)seasonal_init:[32, 48 + 192, enc_in]
trend_init:[32, 48 + 192, enc_in]
dec_embedding 嵌入 与上文的 enc_embedding 一样。输入:
seasonal_init:[32, 48 + 192, enc_in]
x_mark_dec:[32, 48 + 192, 4]
dec_out: [32, 48 + 192, 512]
3. 解码器的加法怎么做的?
其实论文中的架构图简化了。因此,我重画了解码器部分。

参考:
计划更新:[4]模型部件之自相关层-AC
边栏推荐
- PHP starts OpenSSL and reports OpenSSL support=> disabled (install ext/openssl)
- AOSP ~ modify WebView default implementation
- A collection of common ADB commands for app testing
- Unity animator rewind
- WordPress article directory plug-in luckywp table of contents setup tutorial
- [189. rotation array]
- Metal organic framework materials (fe-mil-53, mg-mof-74, ti-kumof-1, fe-mil-100, fe-mil-101) supported on isoflurane / methotrexate / doxorubicin (DOX) / paclitaxel / ibuprofen / camptothecin
- app 测试 常用 adb 命令集合
- GCC C inline assembly
- Byte beating | the first batch of written examination for game R & D post (question solution)
猜你喜欢

RS232/RS485转4G DTU 上传基于Modbus协议的温湿度传感器数据到远程TCP服务器

Baidu submits sitemap to prompt the solution of "index type is not handled"

Jetpack compose scaffold and topappbar (top navigation)
![[implementation of bubble sorting]](/img/c9/5e4aa246c89fd03a184dbd00161f97.png)
[implementation of bubble sorting]

ADVANCE. AI CEO Shoudong will share the compliance of cross-border e-commerce using AI technology at the 2022 emerging market brands online Summit

Write my Ini configuration file error

Necessity for banks to choose electronic bidding procurement

The Google search console webmaster tool cannot read the sitemap?

Add SQL formatter to vscode to format SQL

Uni app - one click access to user information
随机推荐
Three special data types, day3 and redis (geographic location, cardinality statistics and bitmap scene usage)
Navicat premium 15 tool is automatically deleted by anti-virus protection software solution
深度学习基础篇【4】从0开始搭建EasyOCR并进行简单文字识别
Google Gmail mailbox marks all unread messages as read at once
軟件測試英語常見詞匯
ADVANCE. AI CEO Shoudong will share the compliance of cross-border e-commerce using AI technology at the 2022 emerging market brands online Summit
Sd3.0 notes
Unity3d model skin changing technology
Manon's advanced road - Daily anecdotes
基于互联网架构演进, 构建秒杀系统
Cyclodextrin metal organic framework( β- Cd-mof) loaded with dimercaptosuccinic acid / emodin / quercetin / sucralose / diflunisal / omeprazole (OME)
AOSP ~ WIFI默认开启 + GPS默认关闭 + 蓝牙默认关闭 + 旋转屏幕关闭
牛客网:数组中只出现一次的两个数字
[implementation of bubble sorting]
新来的同事问我 where 1=1 是什么意思???
AOSP ~ modify WebView default implementation
Write my Ini configuration file error
CPT 102_LEC 15
【斐波那契数列】
Istio installation and use
https://github.com/thuml/Autoformer