[Long-Term Time Series Forecasting] Autoformer Code Walkthrough, Part 4: The Auto-Correlation Mechanism
2022-06-12 05:30:00 【理心炼丹】
First, look at the figure from the paper:

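Clearly, AutoCorrelation can serve as a drop-in replacement for self-attention: both take q, k, and v as input and return a single tensor V.
Let's go straight to the code: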
def forward(self, queries, keys, values, attn_mask):
    B, L, H, E = queries.shape
    _, S, _, D = values.shape
    if L > S:
        # keys/values are shorter than queries: zero-pad them along the time axis
        zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
        values = torch.cat([values, zeros], dim=1)
        keys = torch.cat([keys, zeros], dim=1)
    else:
        # otherwise truncate keys/values to the query length
        values = values[:, :L, :, :]
        keys = keys[:, :L, :, :]
    # period-based dependencies
    q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)  # 32, 4, 128, 49
    k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)  # 32, 4, 128, 49
    res = q_fft * torch.conj(k_fft)  # 32, 4, 128, 49
    corr = torch.fft.irfft(res, dim=-1)  # 32, 4, 128, 96

In the usual case, queries, keys, and values all have the same shape, for example [32, 96, 4, 128].
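The exception is the second AutoCorrelation block in the decoder. There, keys and values share one shape, for example [32, 96, 4, 128], while queries have shape [32, 48+192, 4, 128]. The if statement above zero-pads keys and values until they are as long as queries.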
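The padding branch can be checked in isolation. The sketch below copies the if/else logic from forward into a standalone helper; the name align_to_queries and the small batch and channel sizes are assumptions made for the demo, not part of the original code.

```python
import torch


def align_to_queries(queries, keys, values):
    """Pad or truncate keys/values along the time axis (dim=1) to match queries.

    Hypothetical helper: it only repackages the if/else branch of forward.
    """
    L = queries.shape[1]
    S = values.shape[1]
    if L > S:
        zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
        values = torch.cat([values, zeros], dim=1)
        keys = torch.cat([keys, zeros], dim=1)
    else:
        values = values[:, :L, :, :]
        keys = keys[:, :L, :, :]
    return keys, values


# Decoder-style shapes: queries are 48 + 192 steps long, keys/values are 96.
queries = torch.randn(2, 48 + 192, 4, 8)
keys = torch.randn(2, 96, 4, 8)
values = torch.randn(2, 96, 4, 8)

keys_p, values_p = align_to_queries(queries, keys, values)
print(keys_p.shape, values_p.shape)  # both now have 240 time steps
```

The first 96 steps are the original keys and values; the remaining 144 steps are zeros.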
This post uses queries, keys, and values of shape [32, 96, 4, 128] throughout.
The FFT step itself needs little explanation. The part worth studying is the time_delay_agg_training function.
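For readers who want to confirm what the FFT lines compute, here is a minimal sketch on 1-D tensors. It compares the rfft / conj / irfft route with a direct circular correlation built from torch.roll. The variable names and the length of 8 are arbitrary choices for the demo.

```python
import torch

torch.manual_seed(0)
length = 8
q = torch.randn(length)
k = torch.randn(length)

# FFT route, mirroring the forward pass (n is passed explicitly here).
corr_fft = torch.fft.irfft(
    torch.fft.rfft(q) * torch.conj(torch.fft.rfft(k)), n=length
)

# Direct route: corr[tau] = sum_t q[(t + tau) mod length] * k[t]
corr_direct = torch.stack(
    [(torch.roll(q, -tau) * k).sum() for tau in range(length)]
)

print(torch.allclose(corr_fft, corr_direct, atol=1e-5))
```

Each entry corr[tau] scores how well q, shifted by tau steps, lines up with k. Note that torch.fft.irfft without n returns 2 * (m - 1) samples for m frequency bins, so the call in forward recovers the original length only when that length is even, as it is for 96.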
V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)

def time_delay_agg_training(self, values, corr):  # 32, 4, 128, 96
    """
    SpeedUp version of Autocorrelation (a batch-normalization style design)
    This is for the training phase.
    """
    # Read the code as if the batch size were 1; everything becomes much clearer
    head = values.shape[1]  # 4
    channel = values.shape[2]  # 128
    length = values.shape[3]  # 96
    # find top k
    top_k = int(self.factor * math.log(length))  # int(1 * ln(96)) = 4
    # 32, 96: each mean removes the dimension it averages over (heads, then channels)
    mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
    # debug print: the 4 largest values and their indices
    print(torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1))
    # Averaging over the batch gives shape 96; topk(...)[1] returns the indices.
    # Shape 4; suppose the result is [3, 4, 5, 2]
    index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
    # 32, 4: the correlation values at the 4 selected delays
    weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)
    # update corr
    tmp_corr = torch.softmax(weights, dim=-1)  # 32, 4: normalize these 4 values
    # aggregation
    tmp_values = values  # 32, 4, 128, 96
    delays_agg = torch.zeros_like(values).float()  # 32, 4, 128, 96, all zeros
    for i in range(top_k):
        # 32, 4, 128, 96: shift along the sequence dimension; index[i] is the shift size
        pattern = torch.roll(tmp_values, -int(index[i]), -1)
        # pattern is multiplied element-wise by its weight, the softmax-normalized R_q,k.
        # unsqueeze and repeat only make the shapes match so the product is valid.
        # Ignoring the batch dimension, tmp_corr[:, i] is a single number that is
        # repeated across heads, channels, and all 96 positions.
        delays_agg = delays_agg + pattern * \
            (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
    return delays_agg  # 32, 4, 128, 96
The figure above assumes top_k = 2 and a sequence length of 10.
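The same setting, top_k = 2 and length 10, can be traced numerically. In the sketch below the delays [3, 7] and the raw scores [2.0, 1.0] are made-up values chosen for illustration, not output from a trained model. It also uses view with broadcasting in place of the unsqueeze and repeat chain, which gives the same product.

```python
import torch

length = 10
# One batch item, one head, one channel: values are simply 0, 1, ..., 9.
values = torch.arange(length, dtype=torch.float32).reshape(1, 1, 1, length)

index = torch.tensor([3, 7])              # hypothetical top-2 delays
weights = torch.tensor([[2.0, 1.0]])      # hypothetical scores, shape [B, top_k]
tmp_corr = torch.softmax(weights, dim=-1)  # normalized weights, shape [B, top_k]

delays_agg = torch.zeros_like(values)
for i in range(index.numel()):
    # Rolling by -delay moves the element at position `delay` to position 0.
    pattern = torch.roll(values, -int(index[i]), -1)
    # Broadcast the per-sample weight over head, channel, and length.
    delays_agg = delays_agg + pattern * tmp_corr[:, i].view(-1, 1, 1, 1)

shifted = torch.roll(values, -3, -1).flatten().tolist()
print(shifted)  # [3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 1.0, 2.0]
print(delays_agg.flatten())
```

Each output position is a weighted average of the values found 3 and 7 steps later, wrapping around at the end of the sequence.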
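This concludes the series on this model.
Further reading: see the references listed in the earlier paper-reading post, where the authors explain why torch.roll is used. The purpose of top_k is, again, to reduce computational complexity.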
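To see how few delays top_k keeps, the formula from the code can be evaluated for several sequence lengths. This is a small sketch: the helper name top_k_for is hypothetical, factor = 1 follows the "1 * ln(96)" comment in the code, and the lengths other than 96 are chosen only for illustration.

```python
import math


def top_k_for(length, factor=1):
    """Number of delays kept, using the same formula as time_delay_agg_training."""
    return int(factor * math.log(length))


for length in (96, 192, 336, 720):
    print(length, top_k_for(length))
# 96 4
# 192 5
# 336 5
# 720 6
```

Because top_k grows only logarithmically, the aggregation loop runs a handful of times even for long sequences.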