The principle of the attention mechanism and its application in seq2seq (Bahdanau attention)
2022-07-05 04:32:00 【SaltyFish_ Go】
Overview of the attention mechanism
1. Additive Attention
2. Scaled Dot-Product Attention
Attention in seq2seq (Bahdanau attention)
Motivation for using attention in seq2seq
What changes after adding attention
Concrete implementation process
Code implementation of Bahdanau attention
Overview of the attention mechanism
Psychological definition: attending to the noteworthy points in a complex scene.
Whether volitional (autonomous) cues are involved is what distinguishes the attention mechanism from fully connected or pooling layers:
Convolution, fully connected, and pooling layers consider only nonvolitional cues;
The attention mechanism also considers volitional cues. A volitional (autonomous) cue is called a query. Given any query, the attention mechanism guides selection over the sensory inputs (for example, intermediate feature representations) through attention pooling. In the attention mechanism, these sensory inputs are called values. More intuitively, each value is paired with a key, which can be thought of as the nonvolitional cue of that sensory input. As shown in Fig. 10.1.3 of the d2l book, attention pooling can be designed so that a given query (volitional cue) is matched against the keys (nonvolitional cues), which guides selection toward the best-matched value (sensory input).
In Nadaraya–Watson attention pooling, f(x) = Σ_i α(x, x_i) y_i, where α(x, x_i) is the attention weight. The volitional cue (query) is the x in f(x), and the nonvolitional cues x_i are the keys. Simply put, the query takes a weighted average over the values, and the weights encode the query's preference among the keys.
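This weighted-average view can be made concrete with Nadaraya–Watson kernel regression, the classic non-parametric instance of attention pooling. The sketch below uses made-up data, and a Gaussian kernel is one possible (assumed) choice of weighting:

```python
import torch

# Keys x_i (nonvolitional cues) paired with values y_i
x_train = torch.tensor([0.0, 1.0, 2.0, 3.0])
y_train = torch.tensor([1.0, 2.0, 1.5, 3.0])
x_query = torch.tensor([1.2])  # the query (volitional cue)

# Attention weights from a Gaussian kernel: softmax over -(x - x_i)^2 / 2
scores = -(x_query[:, None] - x_train[None, :]) ** 2 / 2
alpha = torch.softmax(scores, dim=-1)  # one weight per key, rows sum to 1
f_x = alpha @ y_train                  # f(x) = sum_i alpha(x, x_i) * y_i
print(f_x)  # a weighted average, dominated by the y_i whose x_i is nearest 1.2
```

The key nearest the query (x_i = 1.0) receives the largest weight, so f(x) is pulled toward its value y_i = 2.0.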
Attention scoring functions
The attention score measures the similarity between a query and a key. (For example: to guess my salary when applying for a job, the most informative comparisons are people with a similar major, similar job content, and similar working hours; those people would get high attention scores.) The unnormalized α is the attention score a.
The attention weight (a scalar) of query q with respect to key k_i is obtained by mapping the two vectors to a scalar with an attention scoring function a, and then applying a softmax: α(q, k_i) = softmax(a(q, k_i)) = exp(a(q, k_i)) / Σ_j exp(a(q, k_j)). Different choices of the scoring function a lead to different attention pooling operations.
1. Additive Attention
Given a query q ∈ R^q and a key k ∈ R^k, the additive attention scoring function is a(q, k) = w_v^T tanh(W_q q + W_k k), with learnable parameters W_q ∈ R^(h×q), W_k ∈ R^(h×k), and w_v ∈ R^h.
In effect, the query and key are concatenated and fed into a multilayer perceptron (MLP) with a single hidden layer, whose number of hidden units h is a hyperparameter. tanh is the activation function, and bias terms are disabled. The query and key may have different lengths; after the matrix products, the final score a is a scalar, since (1×h)·(h×1) = 1×1.
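The code below calls a masked_softmax helper, which is defined earlier in the d2l chapter but not reproduced in this post. A minimal sketch of what it does, under the same shape conventions (an assumption-based reconstruction, not the original source):

```python
import torch
from torch import nn

def masked_softmax(X, valid_lens):
    """Softmax on the last axis, masking out entries beyond each valid length."""
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    shape = X.shape
    if valid_lens.dim() == 1:
        # One valid length per batch example: repeat it for every query row
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        valid_lens = valid_lens.reshape(-1)
    # Replace masked-out scores with a large negative value so that
    # their softmax output is (numerically) zero
    X = X.reshape(-1, shape[-1])
    mask = torch.arange(shape[-1], device=X.device)[None, :] < valid_lens[:, None]
    X = X.masked_fill(~mask, -1e6)
    return nn.functional.softmax(X.reshape(shape), dim=-1)

scores = torch.rand(2, 2, 4)
w = masked_softmax(scores, torch.tensor([2, 3]))
print(w[0, 0])  # last two entries are ~0; the row still sums to 1
```

Masking matters in seq2seq because sequences in a minibatch are padded to the same length, and padding positions should receive zero attention weight.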
import math
import torch
from torch import nn

#@save
class AdditiveAttention(nn.Module):
    """Additive attention"""
    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # After dimension expansion:
        # queries: (batch_size, no. of queries, 1, num_hiddens)
        # keys: (batch_size, 1, no. of key-value pairs, num_hiddens)
        # Sum them up with broadcasting
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # self.w_v has a single output, so remove the last dimension:
        # scores: (batch_size, no. of queries, no. of key-value pairs)
        scores = self.w_v(features).squeeze(-1)
        # masked_softmax is defined earlier in the d2l chapter
        self.attention_weights = masked_softmax(scores, valid_lens)
        # values: (batch_size, no. of key-value pairs, value dimension)
        return torch.bmm(self.dropout(self.attention_weights), values)
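A quick functional shape-check of the additive-attention formula, written with plain torch tensors (the dimensions are made up for illustration, and masking is omitted, so a plain softmax stands in for masked_softmax):

```python
import torch

torch.manual_seed(0)
batch, n_q, n_kv, q_dim, k_dim, v_dim, h = 2, 1, 10, 20, 2, 4, 8

queries = torch.randn(batch, n_q, q_dim)
keys = torch.randn(batch, n_kv, k_dim)
values = torch.randn(batch, n_kv, v_dim)

W_q = torch.randn(q_dim, h)
W_k = torch.randn(k_dim, h)
w_v = torch.randn(h)

# a(q, k) = w_v^T tanh(W_q q + W_k k), broadcast over all query/key pairs
features = torch.tanh((queries @ W_q).unsqueeze(2) + (keys @ W_k).unsqueeze(1))
scores = features @ w_v                      # (batch, n_q, n_kv)
weights = torch.softmax(scores, dim=-1)      # each row sums to 1
output = torch.bmm(weights, values)          # (batch, n_q, v_dim)
print(output.shape)                          # torch.Size([2, 1, 4])
```

Note that the query (length 20) and key (length 2) have different dimensions; the two projections W_q and W_k map both into the shared hidden size h, which is exactly why additive attention handles mismatched lengths.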
2. Scaled Dot-Product Attention
When the n queries and m keys all have the same length d, the scoring function is a(q, k) = q·k / √d. In minibatch form, with queries Q ∈ R^(n×d), keys K ∈ R^(m×d), and values V ∈ R^(m×v), the attention output is softmax(QK^T / √d) V ∈ R^(n×v).
#@save
class DotProductAttention(nn.Module):
    """Scaled dot-product attention"""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # queries: (batch_size, no. of queries, d)
    # keys: (batch_size, no. of key-value pairs, d)
    # values: (batch_size, no. of key-value pairs, value dimension)
    # valid_lens: (batch_size,) or (batch_size, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Swap the last two dimensions of keys with transpose(1, 2)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
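The same shape-check for scaled dot-product attention, with made-up dimensions and no masking. The 1/√d factor keeps the score variance roughly constant as d grows; without it, the softmax saturates toward one-hot and gradients shrink:

```python
import math
import torch

torch.manual_seed(0)
batch, n_q, n_kv, d, v_dim = 2, 3, 10, 16, 4

queries = torch.randn(batch, n_q, d)
keys = torch.randn(batch, n_kv, d)
values = torch.randn(batch, n_kv, v_dim)

# a(q, k) = q . k / sqrt(d), computed for all pairs with one batch matmul
scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
weights = torch.softmax(scores, dim=-1)   # (batch, n_q, n_kv), rows sum to 1
output = torch.bmm(weights, values)       # (batch, n_q, v_dim)
print(output.shape)                       # torch.Size([2, 3, 4])
```

Unlike additive attention, this variant has no learnable parameters in the score itself, but it requires queries and keys to share the same length d.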
Attention in seq2seq (Bahdanau attention)
Note what the query, key, and value are in this model:
key — the encoder's RNN output at each word (time step).
query — the decoder state from predicting the previous word (used to match the relevant keys).
value — the same encoder outputs as the keys; their weighted combination forms the context.
Motivation for using attention in seq2seq:
In machine translation, each generated word may be related to a different word in the source sentence. Vanilla seq2seq cannot model this directly: the decoder's input is only the encoder's final hidden state H, which contains the earlier information but makes it hard to recover the alignment, i.e., which source position matters for the current word.
What changes after adding attention
Originally, the input to the decoder RNN was always the encoder RNN's final hidden state, passed in as the context together with the embedding.
With attention added on top (using only the last-step state as input is suboptimal; depending on the word being translated, we should look back at the hidden states of earlier time steps), the decoder effectively takes a weighted average over all the encoder's hidden states and picks out the most relevant ones.
Concrete implementation process:
The encoder's output at each word serves as both the key and the value in the attention (the two are identical). Then, when the decoder predicts, say, "hello", it uses its previous output state as the query to the attention, retrieving the source words relevant to "hello" (the matching keys).
"The encoder builds the index; the decoder's prediction locates the focus."
The final decoder input is the attention output concatenated with the embedding.
Code implementation of Bahdanau attention
Compared with the plain seq2seq decoder, an attention_weights property is added (to record the attention weights).
class Seq2SeqAttentionDecoder(AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
        self.attention = d2l.AdditiveAttention(
            num_hiddens, num_hiddens, num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(
            embed_size + num_hiddens, num_hiddens, num_layers,
            dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        # outputs: (batch_size, num_steps, num_hiddens)
        # hidden_state: (num_layers, batch_size, num_hiddens)
        outputs, hidden_state = enc_outputs
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)

    def forward(self, X, state):
        # enc_outputs: (batch_size, num_steps, num_hiddens)
        # hidden_state: (num_layers, batch_size, num_hiddens)
        enc_outputs, hidden_state, enc_valid_lens = state
        # After embedding, X has shape (num_steps, batch_size, embed_size)
        X = self.embedding(X).permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        for x in X:
            # query: (batch_size, 1, num_hiddens)
            query = torch.unsqueeze(hidden_state[-1], dim=1)
            # context: (batch_size, 1, num_hiddens)
            context = self.attention(
                query, enc_outputs, enc_outputs, enc_valid_lens)
            # Concatenate on the feature dimension
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
            # Reshape x to (1, batch_size, embed_size + num_hiddens)
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        # After the fully connected layer, outputs has shape
        # (num_steps, batch_size, vocab_size)
        outputs = self.dense(torch.cat(outputs, dim=0))
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state,
                                          enc_valid_lens]

    @property
    def attention_weights(self):
        return self._attention_weights
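One iteration of the decoding loop above can be sketched in isolation with plain torch (sizes are made up, and uniform attention weights stand in for the learned AdditiveAttention), to show how the context is concatenated with the embedding before entering the GRU:

```python
import torch
from torch import nn

torch.manual_seed(0)
batch, num_steps, embed_size, num_hiddens, num_layers = 2, 7, 8, 16, 2

enc_outputs = torch.randn(batch, num_steps, num_hiddens)  # keys = values
hidden_state = torch.zeros(num_layers, batch, num_hiddens)
x = torch.randn(batch, embed_size)  # embedding of the current target token

rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers)

# Query: top-layer decoder state from the previous step, (batch, 1, num_hiddens)
query = hidden_state[-1].unsqueeze(1)
# Stand-in for attention pooling: uniform weights over the encoder steps
weights = torch.full((batch, 1, num_steps), 1.0 / num_steps)
context = torch.bmm(weights, enc_outputs)              # (batch, 1, num_hiddens)

# Concatenate context and embedding on the feature dimension, run one GRU step
rnn_in = torch.cat((context, x.unsqueeze(1)), dim=-1)  # (batch, 1, embed+hiddens)
out, hidden_state = rnn(rnn_in.permute(1, 0, 2), hidden_state)
print(out.shape)  # torch.Size([1, 2, 16])
```

This is why the GRU's input size is embed_size + num_hiddens in the decoder above: every step consumes the token embedding plus a freshly computed context, rather than a fixed context reused at all steps.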