当前位置:网站首页>基于GPT的隐变量表征解码结构

基于GPT的隐变量表征解码结构

2022-08-02 14:07:00 3A是个坏同志

有时候我们想要用GPT(的一部分)作为预训练的隐变量解码器,应该怎么办呢?最近看论文,总结了三种隐变量注入(code injection)的方式。

1. Cheng X ,  Xu W ,  Wang T , et al. Variational Semi-Supervised Aspect-Term Sentiment Analysis via Transformer[C]// Proceedings of the 23rd Conference on Computational Natural Language Learning (CoNLL). 2019.

这篇文章用的方法是将隐变量添加到每个输入token中。原文是将每个token表示为隐变量和原嵌入的和,然后使用reconstruction任务微调解码器。

2. Ziegler Z M ,  Melas-Kyriazi L ,  Gehrmann S , et al. Encoder-Agnostic Adaptation for Conditional Language Generation[J].  2019.

这篇文章探索了使自然语言预训练语言模型适应任意条件输入的方法。作者观察到,预训练transformer模型对微调过程中的大参数变化非常敏感。因此提出了一种直接将任意条件(arbitrary conditioning)注入self-attention的适应方法,并称之为pseudo self-attention(伪自注意)。主要思路是在每层基础上注入隐变量。其与其它几种注入结构区别简图如下:

a训练了一个新的transformer解码器;b训练了一个新的context attention layer;伪自注意只修改了self-attention层的一部分(图中绿色表示只使用预训练权重初始化,灰色表示从头训练,红色向量表示每一层的target activation,蓝色向量表示编码器输出的源特征)

具体来讲,作者通过一个线性层将隐变量z\in R^d投影到z_l\in R^{d\times L},这样就可以将其拆分成L个向量[z_1,......,z_l]。其中z_l被送入第l个自注意力block。如下图所示:

浅红色为隐变量的键和值

浅红色为隐变量的键和值

伪自注意可以通过以下方式将隐变量吸收到预训练的GPT2自注意力结构中,对于每个自注意层,有:

PSA(Q,K',V')=softmax(\frac{Q{K'}^T}{\sqrt{d_k}})V'

为了表示方便,这里的z指的是z_l(行向量)。其中Q,K,V\in R^{l\times d}为原始的输入嵌入。K'=\begin{pmatrix} z_k \\ K \end{pmatrix} \in R^{(1+l)\times d}V'=\begin{pmatrix} z_v \\ V \end{pmatrix} \in R^{(1+l)\times d}为具有被投影的隐变量z_kz_v(来自z(即z_l)填充的第一行?)的增广的键和值。\begin{pmatrix} . \\ . \end{pmatrix}表示按行连接。

3. Wang T ,  Wan X . T-CVAE: Transformer-Based Conditioned Variational Autoencoder for Story Completion[C]// Twenty-Eighth International Joint Conference on Artificial Intelligence (IJCAI-19). 2019.

该文章使用的方法在最后一层注入,文章的模型结构如下图:

在原始GPT2中,来自最后一个注意力层的嵌入向量h \in R^d通过线性头投影到pre-softmax logit向量p\in R^V(V是词表大小)。想在这个时候注入隐变量,需要学习一个新的共享线性头,以便将z\in R^d投影到p_z\in R^V。最终把p+p_z送入softmax。原文中的具体过程为:

C_t=tanh([z,D_{out,t}^{L}]W_c)

O_t=C_tW_o+b_o

P_t=softmax(O_t)

其中D为解码器输出,W_c为参数矩阵,C_t为合并z与解码器状态的组合层在时间步t上的输出。并进一步馈送到线性头和softmax获得概率分布。

扩展阅读:基于预训练语言模型的文本生成研究综述_zenRRan的博客-CSDN博客 第五节,使用预训练语言模型进行文本生成的常用微调策略_3A的奇奇怪怪圣地-CSDN博客

原网站

版权声明
本文为[3A是个坏同志]所创,转载请带上原文链接,感谢
https://blog.csdn.net/FYZDMMCpp/article/details/121581969