
[Long Time Series Forecasting] [4] Autoformer's Auto-Correlation Mechanism: Code Walkthrough

2022-06-12 05:35:00 Heart regulating and pill refining

First, take a look at the figure in the paper:

Clearly, AutoCorrelation can be dropped in as a complete replacement for self-attention: it takes q, k, v as input and produces a new V as output.

Now, straight to the code:

    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        if L > S:
            # decoder cross-correlation: zero-pad keys/values up to the query length
            zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
            values = torch.cat([values, zeros], dim=1)
            keys = torch.cat([keys, zeros], dim=1)
        else:
            # otherwise truncate keys/values down to the query length
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        # period-based dependencies: correlation computed in the frequency domain
        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)  # 32, 4, 128, 49
        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)     # 32, 4, 128, 49
        res = q_fft * torch.conj(k_fft)                                           # 32, 4, 128, 49
        corr = torch.fft.irfft(res, dim=-1)                                       # 32, 4, 128, 96

In the usual case, queries, keys, and values all have the same shape, e.g. [32, 96, 4, 128].

In the decoder's second AutoCorrelation block, however, keys and values share a shape, e.g. [32, 96, 4, 128], while queries has shape [32, 48+192, 4, 128]. The if branch above zero-pads keys and values along the time dimension so that they match the query length.
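
For concreteness, here is a minimal standalone sketch of that alignment step, using the decoder shapes from above (toy random tensors of my own, not from the original post):

    import torch

    B, H, E = 32, 4, 128
    queries = torch.randn(B, 48 + 192, H, E)  # L = 240 (decoder queries)
    keys    = torch.randn(B, 96, H, E)        # S = 96
    values  = torch.randn(B, 96, H, E)

    L, S = queries.shape[1], values.shape[1]
    if L > S:
        # zero-pad keys/values along the time dimension up to length L
        zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
        values = torch.cat([values, zeros], dim=1)
        keys = torch.cat([keys, zeros], dim=1)
    else:
        # otherwise truncate keys/values down to length L
        values = values[:, :L, :, :]
        keys = keys[:, :L, :, :]

    print(values.shape)  # torch.Size([32, 240, 4, 128])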

This post uses queries, keys, values of shape [32, 96, 4, 128] for the walkthrough.

The FFT lines themselves need little comment: by the Wiener-Khinchin theorem, multiplying q_fft by the conjugate of k_fft and inverse-transforming yields the correlation between the two series at every possible lag, in O(L log L) instead of the O(L^2) that a naive lag-by-lag computation would cost.
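
Here is a minimal standalone check of that equivalence (toy 1-D tensors of my own, not from the original post): the frequency-domain product reproduces the circular lag correlation computed with torch.roll.

    import torch

    L = 96
    q = torch.randn(L)
    k = torch.randn(L)

    # frequency-domain route, exactly the four lines from forward()
    corr_fft = torch.fft.irfft(torch.fft.rfft(q) * torch.conj(torch.fft.rfft(k)), n=L)

    # time-domain route: circular correlation at every lag tau;
    # rolling q by -tau matches how time_delay_agg_training shifts values
    corr_roll = torch.stack([(torch.roll(q, -tau) * k).sum() for tau in range(L)])

    print(torch.allclose(corr_fft, corr_roll, atol=1e-3))  # True

With that settled, the main concern is the time_delay_agg_training function.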

            # call site inside forward():
            V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)

    def time_delay_agg_training(self, values, corr):  # both 32, 4, 128, 96
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the training phase.
        """
        # for clarity, reason about this as if the batch size were 1
        head = values.shape[1]     # 4
        channel = values.shape[2]  # 128
        length = values.shape[3]   # 96
        # find top k
        top_k = int(self.factor * math.log(length))  # 1 * ln(96) ≈ 4.56 -> 4
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)  # 32, 96 -- each mean removes one dimension
        print(torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1))  # debug print: the top_k (4) correlation values
        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]  # averaging over the batch gives shape 96; topk picks the top_k lags and [1] keeps their indices -- suppose we get [3, 4, 5, 2]
        weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)  # 32, 4 -- pull the 4 largest of the 96 values out per sample
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)  # 32, 4 -- normalize these 4 values
        # aggregation
        tmp_values = values  # 32, 4, 128, 96
        delays_agg = torch.zeros_like(values).float()  # 32, 4, 128, 96 -- all zeros
        for i in range(top_k):
            pattern = torch.roll(tmp_values, -int(index[i]), -1)  # 32, 4, 128, 96 -- roll along the sequence dimension; index[i] is the shift
            delays_agg = delays_agg + pattern * \
                         (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))  # pattern is the delayed series and the second factor is its R(q,k) weight, unsqueezed and repeated to the same shape so the elementwise product works -- ignoring the batch dimension, tmp_corr[:, i] is a single number repeated across head, channel and length
        return delays_agg  # 32, 4, 128, 96

(Figure: the roll-and-aggregate step, drawn with top_k = 2 and sequence length = 10.)
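
Here is a minimal code sketch of the same idea (my own toy numbers; the two winning lags and their weights are simply assumed rather than computed from a correlation):

    import torch

    length, top_k = 10, 2
    values = torch.arange(length, dtype=torch.float)  # one 1-D "series": 0..9

    index = torch.tensor([3, 5])  # assume these two lags won the top_k selection
    tmp_corr = torch.softmax(torch.tensor([0.8, 0.2]), dim=-1)  # their normalized weights

    delays_agg = torch.zeros_like(values)
    for i in range(top_k):
        # roll the series left by the lag, then add it in with its weight
        pattern = torch.roll(values, -int(index[i]), -1)
        delays_agg = delays_agg + pattern * tmp_corr[i]

    print(delays_agg)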


The model has since been updated.

Further reading: see the references in the previous paper-reading post in this series. The authors have answered why torch.roll is used, and the point of keeping only the top_k lags is to reduce computational complexity.
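
As a quick standalone reminder of what torch.roll does (a negative shift moves elements toward the front and wraps them around, i.e. a circular "delay"):

    import torch

    x = torch.arange(6)       # tensor([0, 1, 2, 3, 4, 5])
    print(torch.roll(x, -2))  # tensor([2, 3, 4, 5, 0, 1]) -- shifted left by 2, wrapped around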
