A simple understanding of torch.nn.Parameter (to be continued): working through this code snippet

2022-06-10 20:57:00 Thinking and Practice
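The FourierBlock code below calls get_frequency_modes, which the post does not show. To run the block on its own, here is a hypothetical stand-in. It assumes that 'random' samples `modes` distinct indices from the lower half of the rfft spectrum and that any other value keeps the lowest `modes` bins; the real helper in the original source may differ.

```python
import random


def get_frequency_modes(seq_len, modes=64, mode_select_method="random"):
    """Hypothetical stand-in for the helper used by FourierBlock.

    Assumption: 'random' samples `modes` distinct indices from the lower half
    of the rfft spectrum (0 .. seq_len // 2 - 1); any other value keeps the
    lowest `modes` frequencies. Returns a sorted list of ints.
    """
    # Never ask for more modes than the lower half of the spectrum holds
    modes = min(modes, seq_len // 2)
    candidates = list(range(seq_len // 2))
    if mode_select_method == "random":
        index = random.sample(candidates, modes)
    else:
        index = candidates[:modes]
    return sorted(index)


print(get_frequency_modes(96, modes=8, mode_select_method="lowest"))
# [0, 1, 2, 3, 4, 5, 6, 7]
```

With modes=0 (the default in FourierBlock's signature) this stand-in returns an empty list, so no frequencies would be kept.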


# ########## Fourier layer #############
import torch
import torch.nn as nn

# get_frequency_modes is a helper defined elsewhere in the original source file (not shown in this post).


class FourierBlock(nn.Module):
    """
    1D Fourier block. It performs representation learning in the frequency domain:
    FFT, a learned linear transform on the selected modes, then inverse FFT.
    """

    def __init__(self, in_channels, out_channels, seq_len, modes=0, mode_select_method='random'):
        super(FourierBlock, self).__init__()
        print('fourier enhanced block used!')
        # Select which frequency modes (rfft bins) to keep; the DFT step below only uses these
        self.index = get_frequency_modes(seq_len, modes=modes, mode_select_method=mode_select_method)
        print('modes={}, index={}'.format(modes, self.index))

        self.scale = (1 / (in_channels * out_channels))
        # nn.Parameter registers the complex weights as learnable: 8 heads,
        # an (in_channels // 8) x (out_channels // 8) matrix per head, one per selected mode
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index), dtype=torch.cfloat))

    # Complex multiplication via einsum: each head has its own (in, out) matrix
    def compl_mul1d(self, input, weights):
        # (batch, head, in_channel), (head, in_channel, out_channel) -> (batch, head, out_channel)
        return torch.einsum("bhi,hio->bho", input, weights)

    def forward(self, q, k, v, mask):
        # q: [B, L, H, E]; k, v and mask are unused (kept for the attention interface)
        B, L, H, E = q.shape
        x = q.permute(0, 2, 3, 1)  # [B, H, E, L]
        # Compute Fourier coefficients: [B, H, E, L // 2 + 1]
        x_ft = torch.fft.rfft(x, dim=-1)
        # Perform Fourier neural operations on the selected modes only
        out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)
        for wi, i in enumerate(self.index):
            # The result for input mode i is written to output bin wi
            out_ft[:, :, :, wi] = self.compl_mul1d(x_ft[:, :, :, i], self.weights1[:, :, :, wi])
        # Return to the time domain with the original length L
        x = torch.fft.irfft(out_ft, n=x.size(-1))
        return (x, None)
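To understand the torch.einsum call in compl_mul1d: for one frequency mode, "bhi,hio->bho" multiplies each head's (batch, in) slice by that head's own (in, out) matrix. The sketch below uses arbitrary shapes and only stock PyTorch; it checks the einsum against an explicit per-head loop and a broadcast torch.matmul, and also checks that rfft followed by irfft with n=L returns the original signal, which is why forward passes n=x.size(-1).

```python
import torch

B, H, I, O = 2, 8, 4, 5
x = torch.randn(B, H, I, dtype=torch.cfloat)  # one frequency mode: (batch, head, in)
w = torch.randn(H, I, O, dtype=torch.cfloat)  # per-head weights:   (head, in, out)

out_einsum = torch.einsum("bhi,hio->bho", x, w)

# Same result by hand: for each head h, multiply (B, I) by (I, O)
out_loop = torch.stack([x[:, h, :] @ w[h] for h in range(H)], dim=1)

# Same result as a matmul broadcast over (batch, head)
out_bmm = torch.matmul(x.unsqueeze(2), w.unsqueeze(0)).squeeze(2)

print(out_einsum.shape)  # torch.Size([2, 8, 5])
print(torch.allclose(out_einsum, out_loop, atol=1e-5),
      torch.allclose(out_einsum, out_bmm, atol=1e-5))  # True True

# FFT round trip on a [B, H, E, L] tensor, as after q.permute(0, 2, 3, 1)
L = 16
sig = torch.randn(2, 8, 4, L)
spec = torch.fft.rfft(sig, dim=-1)   # [B, H, E, L // 2 + 1], complex
back = torch.fft.irfft(spec, n=L)    # n=L restores the original length
print(spec.shape)                             # torch.Size([2, 8, 4, 9])
print(torch.allclose(back, sig, atol=1e-5))   # True
```

The einsum form avoids a Python loop over heads; the matmul form is equivalent but needs the unsqueeze/squeeze bookkeeping.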

 

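Back to the title topic: wrapping a tensor in nn.Parameter, as done for self.weights1, is what makes it a learnable weight. A minimal comparison with two toy modules (the class names are illustrative only): assigning an nn.Parameter to a module attribute registers it in named_parameters() and state_dict() with requires_grad=True, while a plain tensor attribute is not registered, so an optimizer built from module.parameters() would never update it.

```python
import torch
import torch.nn as nn


class WithParameter(nn.Module):
    def __init__(self):
        super().__init__()
        # Wrapped in nn.Parameter: registered automatically, requires_grad=True
        self.w = nn.Parameter(torch.rand(3, 4, dtype=torch.cfloat))


class WithPlainTensor(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain tensor attribute: NOT registered, so optimizers never see it
        self.w = torch.rand(3, 4, dtype=torch.cfloat)


a, b = WithParameter(), WithPlainTensor()
print([name for name, _ in a.named_parameters()])  # ['w']
print([name for name, _ in b.named_parameters()])  # []
print(a.w.requires_grad, b.w.requires_grad)        # True False

# Gradients flow into the registered complex parameter (loss reduced to a real scalar)
loss = (a.w * a.w.conj()).real.sum()
loss.backward()
print(a.w.grad.shape)  # torch.Size([3, 4])
```

Since nn.Parameter is a subclass of torch.Tensor, it can be used anywhere a tensor can, for example in the einsum inside compl_mul1d.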

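Original site

Copyright notice
This article was written by [Thinking and Practice]. If you repost it, please include a link to the original. Thank you.
https://yzsam.com/2022/161/202206101906369931.html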