ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators
2022-08-04 19:34:00 【Live up to your youth ღ】
Model overview
ELECTRA proposes a new pre-training task, Replaced Token Detection (RTD), whose goal is to learn to tell which input tokens are original and which have been replaced. Instead of corrupting the input with [MASK] tokens, the method uses a generator network to produce plausible replacement tokens. A discriminator is then trained to predict, for every token, whether it was replaced by the generator (a small language model). One benefit of this discriminative task is that the model learns from all input tokens rather than only the masked ones as in MLM, so training is more compute-efficient.
ELECTRA's structure resembles a GAN and consists of two neural networks: a generator G and a discriminator D. Both are Transformer encoders that map the input sequence x to a vector representation h(x).
The generator is trained as a masked language model. Given an input sequence x, a certain proportion of tokens (typically 15%) is first replaced with [MASK]; the network then produces the representation h_G(x), and a softmax layer predicts a token for each masked position. The training objective is to maximize the likelihood of the masked tokens. The discriminator's goal is to decide, for each position of the input sequence, whether the token was replaced by the generator: a position counts as replaced if its token differs from the token at the same position in the original input.
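The corruption step described above can be sketched in a few lines of plain Python. This is a minimal illustration, not the paper's implementation: `corrupt_for_rtd` and the `sample_replacement` callback (which stands in for sampling from the generator) are hypothetical names, and positions are chosen uniformly at random.

```python
import random


def corrupt_for_rtd(tokens, sample_replacement, mask_prob=0.15, rng=None):
    """Return (corrupted_tokens, labels) for replaced token detection.

    sample_replacement(tokens, pos, rng) is a stand-in for the generator:
    it returns the token to place at a masked position.
    labels[i] is 1 if position i differs from the original token, else 0.
    """
    rng = rng or random.Random()
    n_mask = max(1, round(len(tokens) * mask_prob))
    masked_positions = rng.sample(range(len(tokens)), n_mask)
    corrupted = list(tokens)
    for pos in masked_positions:
        corrupted[pos] = sample_replacement(tokens, pos, rng)
    # A sampled token that equals the original is labeled as original (0).
    labels = [int(c != t) for c, t in zip(corrupted, tokens)]
    return corrupted, labels


tokens = ["the", "chef", "cooked", "the", "meal"]
corrupted, labels = corrupt_for_rtd(
    tokens, lambda toks, pos, rng: "ate", mask_prob=0.2, rng=random.Random(0)
)
print(corrupted, labels)
```

Note that the labels are derived by comparing against the original sequence, so a generator sample that happens to reproduce the original token is treated as "not replaced".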
On the differences between ELECTRA and GANs, the authors list the following points:
Model | ELECTRA | GAN |
---|---|---|
Input | Real text | Random noise |
Objective | The generator learns a language model; the discriminator learns to tell original tokens from replaced ones | The generator tries to fool the discriminator as much as possible; the discriminator tries to tell real images from fake ones |
Backpropagation | Gradients cannot flow from D to G | Gradients can flow from D to G |
Special case | When the generator produces the original token, it is labeled as real (positive) | All generated samples are negatives (fake images) |
Model structure
Both the ELECTRA generator and discriminator are based on the BERT encoder. [Figure: encoder model structure]
Because the tokens sampled from the generator are discrete, gradients cannot be back-propagated through the sampling step, so the discriminator's gradient cannot reach the generator. The generator's objective is therefore MLM, and the discriminator's objective is sequence labeling (deciding for each token whether it is original or replaced). The two are trained jointly, but the discriminator's gradient is not passed to the generator. The objective function of the whole model for pre-training is therefore: $\min_{\theta_G,\theta_D}\sum_{x\in X} L_{MLM}(x,\theta_G)+\lambda L_{Disc}(x,\theta_D)$.
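For reference, the two terms of this objective can be written out as below. This is a sketch that follows the notation of the ELECTRA paper rather than anything stated in this post: $m$ is the set of masked positions, $x^{masked}$ the masked input, $x^{corrupt}$ the input after the generator's samples are substituted, $n$ the sequence length, and $D(x^{corrupt},t)$ the discriminator's predicted probability that token $t$ is original.

```latex
\begin{aligned}
L_{MLM}(x,\theta_G) &= \mathbb{E}\Big[\sum_{i\in m} -\log p_G\big(x_i \mid x^{masked}\big)\Big] \\
L_{Disc}(x,\theta_D) &= \mathbb{E}\Big[\sum_{t=1}^{n} -\mathbb{1}\big(x^{corrupt}_t = x_t\big)\log D\big(x^{corrupt},t\big) - \mathbb{1}\big(x^{corrupt}_t \neq x_t\big)\log\big(1-D(x^{corrupt},t)\big)\Big]
\end{aligned}
```

The MLM term sums only over masked positions, while the discriminator term sums over every position of the sequence.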
Because the discrimination task is relatively easy, the RTD loss is much smaller than the MLM loss, so a weighting coefficient λ is added; the paper uses 50. After pre-training, the generator is discarded and only the discriminator is fine-tuned on downstream tasks. In addition, the discriminator loss is computed over all tokens, whereas BERT's MLM loss ignores the tokens that were not masked.
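The weighting can be illustrated with a small numeric sketch. The function names are hypothetical, the losses are summed rather than averaged for simplicity, and the discriminator output is assumed to be the probability that a token is original (all probabilities strictly between 0 and 1).

```python
import math


def mlm_loss(p_true_at_masked):
    """Generator loss: negative log-likelihood over masked positions only."""
    return -sum(math.log(p) for p in p_true_at_masked)


def disc_loss(p_original, is_original):
    """Discriminator loss: binary cross-entropy summed over every position."""
    total = 0.0
    for p, y in zip(p_original, is_original):
        total -= math.log(p) if y else math.log(1.0 - p)
    return total


def electra_loss(p_true_at_masked, p_original, is_original, lam=50.0):
    """Combined objective L_MLM + lam * L_Disc for one sequence."""
    return mlm_loss(p_true_at_masked) + lam * disc_loss(p_original, is_original)


# Toy sequence of 5 tokens; one position was masked and then replaced.
loss = electra_loss(
    p_true_at_masked=[0.4],
    p_original=[0.9, 0.8, 0.3, 0.95, 0.9],
    is_original=[1, 1, 0, 1, 1],
)
print(round(loss, 3))
```

The MLM term sees one position here while the discriminator term sees all five, which is the "learn from every token" property described above.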
Regarding weight sharing between the generator and the discriminator, the authors first make the two networks the same size. The GLUE score is 83.6 without weight sharing, 84.3 when only the token embedding layer is shared, and 84.4 when all weights are shared. The generator learns the embeddings better: the discriminator only updates the embeddings of tokens that appear in the input or are sampled by the generator, whereas the generator's softmax is computed over the whole vocabulary and therefore updates all embeddings. In the end the authors use embedding sharing only.
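The argument about which embeddings get updated can be checked on a toy example. The sketch below assumes a tied output embedding matrix `E` with logits `E @ hidden` and a cross-entropy loss; the vocabulary, vectors, and function names are made up for illustration.

```python
import math


def softmax(logits):
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_embedding_grad(embeddings, hidden, target_id):
    """Gradient of -log softmax(E @ hidden)[target_id] w.r.t. each row of E."""
    logits = [sum(e * h for e, h in zip(row, hidden)) for row in embeddings]
    probs = softmax(logits)
    return [
        [(p - (1.0 if v == target_id else 0.0)) * h for h in hidden]
        for v, p in enumerate(probs)
    ]


def rows_with_gradient(grad):
    """Indices of embedding rows whose gradient is not all zeros."""
    return [v for v, row in enumerate(grad) if any(g != 0.0 for g in row)]


# Hypothetical 4-token vocabulary with 2-dimensional embeddings.
E = [[0.1, 0.2], [0.0, -0.3], [0.5, 0.1], [-0.2, 0.4]]
generator_rows = rows_with_gradient(softmax_embedding_grad(E, [1.0, -0.5], target_id=2))
# An input embedding lookup only touches the rows of tokens present in the input.
discriminator_rows = sorted(set([2, 2, 0]))
print(generator_rows, discriminator_rows)
```

Every vocabulary row receives a gradient through the softmax, while a pure lookup only touches the rows of the tokens that actually occur in the input.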
Regarding model size: the weight-sharing experiments show that it is enough for the generator and discriminator to share only the embedding weights. In that case, can the generator be made smaller to improve training efficiency? The authors shrink the generator (in the paper, by reducing the layer sizes while keeping the other hyperparameters constant) and plot the results. [Figure: downstream score versus generator size] The figure shows that a generator 1/4 to 1/2 the size of the discriminator works best. The reason given is that an overly strong generator makes the discriminator's task too hard.
Regarding training strategy, the authors also tried two other approaches:
1. Adversarial Contrastive Estimation: because of the problems described above, ELECTRA cannot be trained as a GAN, but it can still be trained in an adversarial spirit. The authors change the generator's objective from minimizing the MLM loss to maximizing the discriminator's RTD loss on the replaced tokens. A remaining problem is that the generator cannot be updated by gradient ascent through the discrete sampling step, so the authors use policy gradient from reinforcement learning. The resulting generator reaches 54% accuracy on the MLM task, compared with 65% when it is trained with MLE.
2. Two-stage training: train the generator first, then freeze it, initialize the discriminator with the generator's weights, and train the discriminator for the same number of steps.
The three training strategies are compared in the following figure: [Figure: comparison of the three training strategies]
The "isolated" training strategy (joint training without passing the discriminator's gradient to the generator) is still the best. Two-stage training is weaker; the author guesses that the generator becomes too strong, which makes the discrimination task harder. Even so, two-stage training ends up better than BERT itself, which is further evidence for the benefit of discriminative pre-training.
The ELECTRA model code follows. Compared with the BERT code there is essentially no difference; only some parameters of the pre-training task differ:
import math
from dataclasses import dataclass
from typing import Optional, List, Tuple
import tensorflow as tf
from tensorflow.keras import layers
from transformers import shape_list
from transformers.activations_tf import get_tf_activation
from transformers.modeling_tf_utils import get_initializer
from transformers.tf_utils import stable_softmax
from transformers.utils import ModelOutput

class TFElectraModel(tf.keras.Model):
    def __init__(self, config, *inputs, **kwargs):
        # tf.keras.Model does not accept the config object, so it is not forwarded
        super().__init__(*inputs, **kwargs)
        self.electra = TFElectraMainLayer(config, name="electra")

    def call(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False
    ):
        # The keyword must be `input_ids` to match TFElectraMainLayer.call
        outputs = self.electra(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training
        )
        return outputs

class TFElectraMainLayer(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.is_decoder = config.is_decoder
        self.embeddings = TFElectraEmbeddings(config, name="embeddings")
        if config.embedding_size != config.hidden_size:
            self.embeddings_project = layers.Dense(config.hidden_size, name="embeddings_project")
        self.encoder = TFElectraEncoder(config, name="encoder")

    def call(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False
    ):
        # input_ids is the input token-id sequence, shape: (batch_size, seq_length)
        # inputs_embeds is the corresponding embedding sequence, shape: (batch_size, seq_length, embed_size)
        # Only one of the two may be given; if input_ids is given, the embedding layer produces inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[: 2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        batch_size, seq_length = input_shape

        # past_key_values is used by the decoder to speed up inference.
        # It caches the key and value tensors of earlier positions, which inference
        # would otherwise recompute at every step; reusing the cache avoids that work.
        # shape: (n_layers, 4, batch_size, num_heads, seq_length - 1, head_dim)
        # When past_key_values is given, decoder_input_ids should have shape (batch_size, 1)
        # rather than (batch_size, seq_length): only the last token of each sequence is fed in.
        if past_key_values is None:
            past_key_value_length = 0
            past_key_values = [None] * len(self.encoder.layer)
        else:
            past_key_value_length = shape_list(past_key_values[0][0])[2]

        if attention_mask is None:
            attention_mask = tf.fill((batch_size, seq_length + past_key_value_length), value=1)
        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, value=0)

        # hidden_states = word_embeds + token_type_embeds + position_embeds
        # shape: (batch_size, seq_length, embed_size)
        hidden_states = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_value_length=past_key_value_length,
            training=training
        )

        # attention_mask -> extended_attention_mask
        # shape: (batch_size, mask_seq_length) -> (batch_size, 1, 1, mask_seq_length)
        # Three cases:
        # 1. Decoder self-attention: the attention_mask must become a causal attention mask.
        #    It is expanded to a 4-D tensor (batch_size, 1, mask_seq_length, mask_seq_length), which
        #    broadcasts to (batch_size, num_heads, mask_seq_length, mask_seq_length) for multi-head attention.
        # 2. Decoder with past_key_values: seq_length = 1, so for decoder self-attention
        #    the mask has to broadcast to (batch_size, num_heads, 1, mask_seq_length).
        # 3. Encoder: reshape directly to (batch_size, 1, 1, mask_seq_length).
        attention_mask_shape = shape_list(attention_mask)
        mask_seq_length = seq_length + past_key_value_length
        if self.is_decoder:
            seq_ids = tf.range(mask_seq_length)
            causal_mask = tf.less_equal(
                tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
                seq_ids[None, :, None]
            )
            causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
            # Multiplying by attention_mask[:, None, :] also applies the padding information
            extended_attention_mask = causal_mask * attention_mask[:, None, :]
            attention_mask_shape = shape_list(extended_attention_mask)
            extended_attention_mask = tf.reshape(
                extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
            )
            if past_key_value_length > 0:
                # Keep only the rows of the new query positions (slice, not a single index)
                extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
        else:
            extended_attention_mask = tf.reshape(
                attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
            )

        # Cast to the dtype of hidden_states so the float arithmetic below is valid
        extended_attention_mask = tf.cast(extended_attention_mask, dtype=hidden_states.dtype)
        # Equivalent to:
        # extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        one_cst = tf.constant(1.0, dtype=hidden_states.dtype)
        ten_thousand_cst = tf.constant(-10000.0, dtype=hidden_states.dtype)
        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)

        # encoder_attention_mask is used in the decoder's cross-attention
        if self.is_decoder and encoder_attention_mask is not None:
            encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype)
            num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
            if num_dims_encoder_attention_mask == 3:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
            if num_dims_encoder_attention_mask == 2:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
        else:
            encoder_extended_attention_mask = None

        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.config.num_hidden_layers

        if hasattr(self, "embeddings_project"):
            hidden_states = self.embeddings_project(hidden_states, training=training)

        hidden_states = self.encoder(
            hidden_states=hidden_states,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training
        )
        return hidden_states

class TFElectraEmbeddings(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.vocab_size = config.vocab_size
        self.type_vocab_size = config.type_vocab_size
        self.embedding_size = config.embedding_size
        self.max_position_embeddings = config.max_position_embeddings
        self.initializer_range = config.initializer_range
        self.layer_norm = layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
        # The dropout rate is read from hidden_dropout_prob, as in the other layers of this listing
        self.dropout = layers.Dropout(rate=config.hidden_dropout_prob)

    def build(self, input_shape):
        with tf.name_scope("word_embeddings"):
            self.word_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.vocab_size, self.embedding_size],
                initializer=get_initializer(self.initializer_range)
            )
        with tf.name_scope("token_type_embeddings"):
            self.token_type_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.type_vocab_size, self.embedding_size],
                initializer=get_initializer(self.initializer_range)
            )
        with tf.name_scope("position_embeddings"):
            self.position_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.max_position_embeddings, self.embedding_size],
                initializer=get_initializer(self.initializer_range)
            )
        super().build(input_shape)

    def call(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        past_key_value_length=0,
        training=False
    ):
        if input_ids is None and inputs_embeds is None:
            raise ValueError("Need to provide either `input_ids` or `input_embeds`.")
        if input_ids is not None:
            inputs_embeds = tf.gather(self.word_embeddings, input_ids)
        input_shape = shape_list(inputs_embeds)[: -1]
        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, value=0)
        if position_ids is None:
            # Positions continue after any cached (past) positions
            position_ids = tf.expand_dims(
                tf.range(past_key_value_length, input_shape[1] + past_key_value_length),
                axis=0
            )
        position_embeds = tf.gather(self.position_embeddings, position_ids)
        token_type_embeds = tf.gather(self.token_type_embeddings, token_type_ids)
        final_embeddings = inputs_embeds + token_type_embeds + position_embeds
        final_embeddings = self.layer_norm(final_embeddings)
        final_embeddings = self.dropout(final_embeddings, training=training)
        return final_embeddings

class TFElectraEncoder(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.layer = [TFElectraLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]

    def call(
        self,
        hidden_states=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False
    ):
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        next_decoder_cache = () if use_cache else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states, )
            past_key_value = past_key_values[i] if past_key_values is not None else None

            # For a decoder:
            #   if output_attentions=True, layer_outputs is a 4-tuple
            #   (hidden_states, self_attention, cross_attention, past_key_value)
            #   if output_attentions=False, layer_outputs is a 2-tuple
            #   (hidden_states, past_key_value)
            #   where past_key_value is itself a 4-tuple
            #   (self_attn_key_layer, self_attn_value_layer, cross_attn_key_layer, cross_attn_value_layer)
            # For an encoder:
            #   if output_attentions=True, layer_outputs is a 2-tuple
            #   (hidden_states, attention)
            #   if output_attentions=False, layer_outputs is a 1-tuple
            #   (hidden_states,)
            layer_outputs = layer_module(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                training=training
            )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[-1], )
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1], )
                if self.config.add_cross_attention and encoder_hidden_states is not None:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2], )

        # Add the output of the last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states, )

        if not return_dict:
            return tuple(
                v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
            )
        return TFBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )

class TFElectraLayer(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.attention = TFElectraAttention(config, name="attention")
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.cross_attention = TFElectraAttention(config, name="cross_attention")
        self.intermediate = TFElectraIntermediate(config, name="intermediate")
        self.bert_output = TFElectraOutput(config, name="output")

    def call(
        self,
        hidden_states=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=None,
        training=False
    ):
        # The first two cached tensors belong to self-attention
        self_attn_past_key_value = past_key_value[: 2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            input_tensor=hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            past_key_value=self_attn_past_key_value,
            output_attentions=output_attentions,
            training=training
        )
        attention_output = self_attention_outputs[0]

        if self.is_decoder:
            # The last element is the self-attention key/value cache
            outputs = self_attention_outputs[1: -1]
            present_key_value = self_attention_outputs[-1]
        else:
            # Keep a tuple (possibly empty) so it can be concatenated below
            outputs = self_attention_outputs[1:]

        if self.is_decoder and encoder_hidden_states is not None:
            # The last two cached tensors belong to cross-attention
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.cross_attention(
                input_tensor=attention_output,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
                training=training
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1: -1]
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        intermediate_output = self.intermediate(hidden_states=attention_output)
        layer_output = self.bert_output(
            hidden_states=intermediate_output, input_tensor=attention_output, training=training
        )
        outputs = (layer_output,) + outputs
        if self.is_decoder:
            outputs = outputs + (present_key_value,)
        return outputs

class TFElectraAttention(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.self_attention = TFElectraSelfAttention(config, name="self_attention")
        self.dense_output = TFElectraSelfOutput(config, name="dense_output")

    def call(
        self,
        input_tensor=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=None,
        training=False
    ):
        self_outputs = self.self_attention(
            hidden_states=input_tensor,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            training=training
        )
        attention_output = self.dense_output(
            hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
        )
        outputs = (attention_output,) + self_outputs[1:]
        return outputs

class TFElectraIntermediate(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = get_tf_activation(config.hidden_act)
        else:
            self.intermediate_act_fn = config.hidden_act

    def call(self, hidden_states) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states

class TFElectraOutput(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)

    def call(self, hidden_states, input_tensor, training=False):
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.dropout(inputs=hidden_states, training=training)
        hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
        return hidden_states

class TFElectraSelfAttention(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number "
                f"of attention heads ({config.num_attention_heads})"
            )
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
        self.query = tf.keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = tf.keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = tf.keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )
        self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
        # (batch_size, seq_len, all_head_size) -> (batch_size, num_heads, seq_len, head_size)
        tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
        return tf.transpose(tensor, perm=[0, 2, 1, 3])

    def call(
        self,
        hidden_states=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=None,
        training=False,
    ):
        batch_size = shape_list(hidden_states)[0]
        mixed_query_layer = self.query(inputs=hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
            value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
            value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
            key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
            value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
        else:
            key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
            value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)

        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)

        if self.is_decoder:
            # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # (batch size, num_heads, seq_len_q, seq_len_k)
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
        attention_scores = tf.divide(attention_scores, dk)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the TFElectraMainLayer call() function)
            attention_scores = tf.add(attention_scores, attention_mask)

        # Normalize the attention scores to probabilities.
        attention_probs = stable_softmax(logits=attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(inputs=attention_probs, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = tf.multiply(attention_probs, head_mask)

        attention_output = tf.matmul(attention_probs, value_layer)
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])

        # (batch_size, seq_len_q, all_head_size)
        attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
        outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs

class TFElectraSelfOutput(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.dropout(inputs=hidden_states, training=training)
        hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
        return hidden_states

@dataclass
class TFBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
    last_hidden_state: tf.Tensor = None
    past_key_values: Optional[List[tf.Tensor]] = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None
    cross_attentions: Optional[Tuple[tf.Tensor]] = None
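The attention-mask handling in the listing above, `(1.0 - mask) * -10000.0` plus the causal mask for decoders, can be illustrated without TensorFlow. The following is a minimal pure-Python sketch for a single sequence; the helper names are hypothetical and not part of the model code.

```python
import math


def additive_padding_mask(attention_mask, neg=-10000.0):
    """Map 1 (attend) -> 0.0 and 0 (padding) -> neg, i.e. (1 - mask) * neg."""
    return [(1.0 - m) * neg for m in attention_mask]


def causal_mask(seq_len):
    """mask[q][k] is 1 when key position k <= query position q, else 0."""
    return [[1 if k <= q else 0 for k in range(seq_len)] for q in range(seq_len)]


def masked_softmax(scores, additive_mask):
    """Softmax over scores + additive_mask, computed in a numerically stable way."""
    shifted = [s + a for s, a in zip(scores, additive_mask)]
    top = max(shifted)
    exps = [math.exp(z - top) for z in shifted]
    total = sum(exps)
    return [e / total for e in exps]


padding = [1, 1, 0]  # the last position is padding
probs = masked_softmax([1.0, 2.0, 3.0], additive_padding_mask(padding))
print(probs)           # the padded position receives almost no attention
print(causal_mask(3))  # lower-triangular pattern used for decoder self-attention
```

Adding a large negative number before the softmax, rather than multiplying probabilities by zero afterwards, keeps the remaining attention weights normalized to one.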