当前位置:网站首页>Transformer中position encoding实践

Transformer中position encoding实践

2022-07-04 14:54:00 初学者chris

近年来,transformer由于其可以实现并行计算且可以解决长序列的依赖问题在nlp领域和cv领域大放异彩。
原理图如下所示:
在这里插入图片描述
这里我们主要关注一个小部分,即position encoding部分,因为transformer取消了循环依赖,为了体现位置属性,所以给每个元素进行位置编码。
代码如下所示,至于为什么会这么写,可以参考作者原文,或者参考一下文章。https://zhuanlan.zhihu.com/p/338592312
代码如下:

class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0).transpose(0, 1)#(max-len,1,d_model)
        
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(1), :].squeeze(1)
        #x = x + self.pe[:x.size(1), :]
        return x

为了测试,我们定义两个输入矩阵,分别为全0、全1tensor。

d_model = 4   
a=torch.zeros(2,3,4)
pos=PositionalEncoding(d_model)
b=pos(a)
c=torch.ones(2,3,4)
b1=pos(c)

很明显,输入矩阵为

在这里插入图片描述
输出为b,b1,如下所示:;

在这里插入图片描述
在这里插入图片描述
可以看出,都是在输入的基础之上,加上了固定值,而那些固定值就是编码得到的,与输入无关,与d_model有关,d_model可以理解为单词的embedding大小。

原网站

版权声明
本文为[初学者chris]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_42282231/article/details/125492890