The mechanism and principle of MultiHead Attention and Masked Attention
2022-07-25 05:42:00 【iioSnail】
1. Prerequisites
Before reading this article, you should already have a thorough understanding of Self-Attention. I recommend reading my other blog post, a layer-by-layer analysis that explains Self-Attention, MultiHead-Attention, and Masked-Attention in full. The content of this article is also covered there, so you can read them together.
2. MultiHead Attention
2.1 MultiHead Attention: theory
The Transformer uses MultiHead Attention, which is actually not very different from Self-Attention. Let's first clarify a few points, then start the explanation:
- No matter how many heads MultiHead Attention has, the number of parameters is the same. More heads do not mean more parameters.
- When MultiHead Attention has only 1 head, it is still not equivalent to Self-Attention; MultiHead Attention and Self-Attention are different things.
- MultiHead Attention also uses the Self-Attention formula.
- Besides the three matrices $W^q, W^k, W^v$, MultiHead Attention needs to define one more matrix $W^o$.
Okay, with these points in mind, we can start explaining MultiHead Attention.
Most of the logic of MultiHead Attention is identical to Self-Attention; it starts to differ after Q, K, V have been computed, so let's start from there.
Now that we have the Q, K, V matrices, for Self-Attention we could already plug them into the formula; represented as a diagram, it looks like this:

For simplicity, the figure omits the Softmax and $d_k$ calculations.
MultiHead Attention, on the other hand, does one more thing before entering the formula: splitting. Along the direction of the word-vector dimension, it splits Q, K, V into multiple heads, as shown in the figure:

Here the number of heads is 4. Since Q, K, V have been split into several heads, the subsequent computation is also done separately inside each head, as shown in the figure:

However, the per-head Attention results merged with a simple Concat do not work very well, so at the end an additional $W^o$ matrix is used to apply one more linear transformation to the Attention output, as shown in the figure:

From this you can also see that more heads are not necessarily better. As for why MultiHead Attention is used at all, the explanation the Transformer paper gives is: multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. In any case, using it works better than not.
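To make the split concrete, here is a minimal shape-only sketch (mine, not part of the original post's code) of how a (batch, number of words, d_model) tensor is split into heads and merged back:

import torch

batch, n_words, d_model, h = 1, 7, 128, 4  # 4 heads, 128-dimensional word vectors
q = torch.rand(batch, n_words, d_model)

# Split along the word-vector dimension: (1, 7, 128) -> (1, 7, 4, 32) -> (1, 4, 7, 32)
q_heads = q.view(batch, n_words, h, d_model // h).transpose(1, 2)
print(q_heads.shape)   # torch.Size([1, 4, 7, 32])

# Merge back: (1, 4, 7, 32) -> (1, 7, 4, 32) -> (1, 7, 128)
q_merged = q_heads.transpose(1, 2).contiguous().view(batch, n_words, d_model)
print(q_merged.shape)  # torch.Size([1, 7, 128])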
2.2 Implementing MultiHead Attention in PyTorch
This code is adapted from the annotated-transformer project.
First, define a general attention function:
# Imports needed by the code below
import math

import torch
import torch.nn as nn


def attention(query, key, value):
    """
    Compute the Attention result.

    Q, K, V are passed in directly; computing them from the input is done inside
    the model (see the MultiHeadedAttention class below).

    Q, K, V can have two kinds of shape:
    - For Self-Attention, the shape is (batch, number of words, d_model),
      e.g. (1, 7, 128): batch_size 1, a 7-word sentence, 128 dimensions per word.
    - For MultiHead Attention, the shape is (batch, number of heads, number of words,
      d_model / number of heads), e.g. (1, 8, 7, 16): batch_size 1, 8 heads,
      a 7-word sentence, 128 / 8 = 16 dimensions per head.

    From this you can see that so-called MultiHead Attention really just splits the
    128 dimensions apart. In the Transformer, because MultiHead Attention is used,
    the shape of Q, K, V will only ever be the second kind.
    """
    # Get d_k. We can read it from query because query has the same shape as the input:
    # - for Self-Attention, the last dimension is the word-vector dimension, i.e. d_model;
    # - for MultiHead Attention, the last dimension is d_model / h, where h is the number of heads.
    d_k = query.size(-1)
    # Compute QK^T / sqrt(d_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    # Apply Softmax. p_attn is a square matrix over the word positions:
    # - for Self-Attention, its shape is (batch, number of words, number of words), e.g. (1, 7, 7);
    # - for MultiHead Attention, its shape is (batch, number of heads, number of words, number of words).
    p_attn = scores.softmax(dim=-1)
    # Finally multiply by V.
    # For Self-Attention, the result shape is (batch, number of words, d_model), which is the final result.
    # For MultiHead Attention, the result shape is (batch, number of heads, number of words, d_model / number of heads).
    # That is not the final result yet: the heads still have to be merged back into
    # (batch, number of words, d_model), but that is MultiHeadedAttention's job.
    return torch.matmul(p_attn, value)
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model):
        """h: the number of heads"""
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        # Define the W^q, W^k, W^v and W^o matrices.
        # If you are not sure why nn.Linear is used to define a matrix, see:
        # https://blog.csdn.net/zhaohongfei_358/article/details/122797190
        # nn.ModuleList (rather than a plain Python list) is used so that the four
        # Linear layers are properly registered as parameters of this module.
        self.linears = nn.ModuleList(
            [nn.Linear(d_model, d_model) for _ in range(4)]
        )

    def forward(self, x):
        # Get the batch size
        nbatches = x.size(0)

        # 1. Compute Q, K, V. These are the MultiHead versions, so their shape is
        #    (batch, number of heads, number of words, d_model / number of heads).
        #    1.1 First compute the Self-Attention Q, K, V via the defined W^q, W^k, W^v;
        #        at this point the shape of Q, K, V is (batch, number of words, d_model).
        #        Corresponding code: `linear(x)`
        #    1.2 Split into multiple heads, changing the shape from (batch, number of words, d_model)
        #        to (batch, number of words, number of heads, d_model / number of heads).
        #        Corresponding code: `view(nbatches, -1, self.h, self.d_k)`
        #    1.3 Finally swap the "number of words" and "number of heads" dimensions so the heads
        #        come first; the final shape is (batch, number of heads, number of words, d_model / number of heads).
        #        Corresponding code: `transpose(1, 2)`
        query, key, value = [
            linear(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for linear, x in zip(self.linears, (x, x, x))
        ]

        # 2. With Q, K, V computed, use the attention function to get the Attention result.
        #    Here x has shape (batch, number of heads, number of words, d_model / number of heads);
        #    the attention weights inside attention() have shape (batch, number of heads, number of words, number of words).
        x = attention(query, key, value)

        # 3. Merge the heads back together, changing the shape of x from
        #    (batch, number of heads, number of words, d_model / number of heads) to (batch, number of words, d_model).
        #    3.1 First swap the "number of heads" and "number of words" dimensions, giving
        #        (batch, number of words, number of heads, d_model / number of heads).
        #        Corresponding code: `x.transpose(1, 2).contiguous()`
        #    3.2 Then merge the "number of heads" and "d_model / number of heads" dimensions,
        #        giving (batch, number of words, d_model).
        x = (
            x.transpose(1, 2)
            .contiguous()
            .view(nbatches, -1, self.h * self.d_k)
        )

        # Finally apply one more linear transformation with the W^o matrix to get the final result.
        return self.linears[-1](x)
Next, let's try using it:

# Define 8 heads, with word-vector dimension 512
model = MultiHeadedAttention(8, 512)
# Pass in a batch with batch_size 2, 7 words, each word 512-dimensional
x = torch.rand(2, 7, 512)
# Print the shape of the Attention result
print(model(x).size())

The output is:
torch.Size([2, 7, 512])
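As a quick check of the earlier claim that the number of heads does not affect the number of parameters, here is a small sketch (mine, not from the original post) comparing two head counts:

model_8_heads = MultiHeadedAttention(8, 512)
model_4_heads = MultiHeadedAttention(4, 512)

count = lambda m: sum(p.numel() for p in m.parameters())
# Both print 1050624, i.e. 4 * (512 * 512 + 512): the four Linear layers are the only
# parameters, and their sizes do not depend on the number of heads.
print(count(model_8_heads), count(model_4_heads))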
3. Masked Attention
3.1 Why a mask is needed
The Decoder in the Transformer contains a Masked MultiHead Attention, and this section explains it in detail.
First, let's review the Attention formula:
$$
\begin{aligned}
O_{n\times d_v} = \text{Attention}(Q_{n\times d_k}, K_{n\times d_k}, V_{n\times d_v})
&= \operatorname{softmax}\left(\frac{Q_{n\times d_k} K^{T}_{d_k\times n}}{\sqrt{d_k}}\right) V_{n\times d_v} \\
&= A'_{n\times n} V_{n\times d_v}
\end{aligned}
$$
where:
$$
O_{n\times d_v}=
\begin{bmatrix} o_1\\ o_2\\ \vdots \\ o_n \end{bmatrix},\quad
A'_{n\times n} =
\begin{bmatrix}
\alpha'_{1,1} & \alpha'_{2,1} & \cdots & \alpha'_{n,1} \\
\alpha'_{1,2} & \alpha'_{2,2} & \cdots & \alpha'_{n,2} \\
\vdots & \vdots & & \vdots \\
\alpha'_{1,n} & \alpha'_{2,n} & \cdots & \alpha'_{n,n}
\end{bmatrix},\quad
V_{n\times d_v}=
\begin{bmatrix} v_1\\ v_2\\ \vdots \\ v_n \end{bmatrix}
$$
Suppose $(v_1, v_2, ..., v_n)$ corresponds to the characters (机, 器, 学, 习, 真, 好, 玩) of the sentence "机器学习真好玩" ("machine learning is really fun"). Then $(o_1, o_2, ..., o_n)$ corresponds to (机′, 器′, 学′, 习′, 真′, 好′, 玩′), where 机′ contains attention information from all of $v_1$ through $v_n$, and the weights assigned to (机, 器, ...) when computing 机′ are the first row of $A'$, i.e. $(\alpha'_{1,1}, \alpha'_{2,1}, ...)$.
With the above in mind, let's look at how the Transformer is used. Suppose we want to use a Transformer to translate the sentence "Machine learning is fun" (into Chinese).
First, we feed "Machine learning is fun" into the Encoder, which outputs a tensor called Memory, as shown in the figure:

After that, we pass Memory to the Decoder as one of its inputs and use the Decoder to predict. The Decoder does not produce "机器学习真好玩" all at once, but one character at a time (or one word at a time, depending on how you tokenize), as shown in the figure:
Then we call the Decoder again, this time passing in "<bos> 机":
And so on, until it finally outputs <eos> and stops:

When the Transformer outputs <eos>, the prediction is finished.
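To make the word-by-word decoding loop concrete, here is a rough sketch of greedy decoding. It is not the post's own model: it uses torch.nn.Transformer with random (untrained) weights, and the vocabulary size and the <bos>/<eos> token ids are made up purely for illustration.

import torch
import torch.nn as nn

d_model, vocab_size, bos_id, eos_id = 512, 10, 8, 9   # made-up vocabulary and special-token ids
embed = nn.Embedding(vocab_size, d_model)
generator = nn.Linear(d_model, vocab_size)            # maps decoder outputs to vocabulary logits
transformer = nn.Transformer(d_model=d_model, nhead=8, batch_first=True)

src = embed(torch.randint(0, vocab_size, (1, 4)))     # stand-in for the embedded source sentence
memory = transformer.encoder(src)                     # the "Memory" tensor produced by the Encoder

ys = torch.tensor([[bos_id]])                         # start with <bos>
for _ in range(10):                                   # cap the number of decoding steps
    tgt_mask = transformer.generate_square_subsequent_mask(ys.size(1))
    out = transformer.decoder(embed(ys), memory, tgt_mask=tgt_mask)
    next_id = generator(out[:, -1]).argmax(dim=-1)    # greedily take the most likely next token
    ys = torch.cat([ys, next_id.unsqueeze(0)], dim=1)
    if next_id.item() == eos_id:                      # stop once <eos> is produced
        break
print(ys)                                             # the predicted token ids, one appended per step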
Here we notice that the Decoder predicts character by character. So when the Decoder's input was "机器学习", the character "习" could only see the three preceding characters "机器学", which means that at that point "习" only carried attention information from the four characters "机器学习" (the three before it plus itself).
However, even when the last step arrives and the input is "<bos> 机器学习真好玩", we still must not let "习" see the three characters "真好玩" that come after it, so we cover them with a mask. Why? Because if "习" were allowed to see the characters after it, the encoding of "习" would change.
Let's analyze this:
The first time, we only pass in "机" (ignoring <bos>). Applying the attention mechanism, "机" is encoded as [0.13, 0.73, ...].
The second time, we pass in "机器". Now, if we do not cover up "器" when applying attention, the encoding of "机" will change: it is no longer [0.13, 0.73, ...], but perhaps becomes [0.95, 0.81, ...].
So the encoding of "机" would be [0.13, 0.73, ...] the first time and [0.95, 0.81, ...] the second time, which can cause problems for the network. Therefore, to keep the encoding of "机" from changing, we use a mask to cover up the characters that come after "机": even though it could attend to the later characters, we do not let it.
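The following small numerical sketch (mine, not from the post) reproduces this effect: without a mask, the encoding of the first token changes once a second token is appended; with a lower-triangular (causal) mask it stays exactly the same. It uses the masked_fill / -1e9 trick that section 3.3 below discusses.

import math

import torch

torch.manual_seed(0)
d = 4
q1, k1, v1 = torch.rand(1, d), torch.rand(1, d), torch.rand(1, d)
q2, k2, v2 = torch.rand(1, d), torch.rand(1, d), torch.rand(1, d)

def attn(q, k, v, mask=None):
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    return scores.softmax(dim=-1) @ v

# Step 1: only the first token exists, so o_1 is simply v_1.
o_step1 = attn(q1, k1, v1)

# Step 2 without a mask: o_1 now also attends to the second token, so it changes.
q, k, v = torch.cat([q1, q2]), torch.cat([k1, k2]), torch.cat([v1, v2])
o_step2_no_mask = attn(q, k, v)

# Step 2 with a causal mask: position 1 cannot see position 2, so o_1 is unchanged.
causal = torch.tril(torch.ones(2, 2))
o_step2_masked = attn(q, k, v, mask=causal)

print(torch.allclose(o_step1[0], o_step2_no_mask[0]))  # False
print(torch.allclose(o_step1[0], o_step2_masked[0]))   # True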
Many articles explain the mask as preventing the Transformer from leaking future information it should not see during training. I think this explanation is wrong: ① the Transformer's Decoder makes no distinction between training and inference, so if the mask only existed to prevent leakage during training, why would we still need to mask during inference? ② What is passed to the Decoder is exactly what the Decoder itself has already inferred; saying we must hide its own output from it "to prevent information leakage" makes no sense.
Of course, this is just my personal view; I may have misunderstood it.
3.2 How the mask is applied
To apply the mask, we only need to operate on the scores, that is, on $A'_{n\times n}$. Let's go straight to an example:
The first time, we only have the variable $v_1$, so:
$$
\begin{bmatrix} o_1 \end{bmatrix}=\begin{bmatrix} \alpha'_{1,1} \end{bmatrix} \cdot \begin{bmatrix} v_1 \end{bmatrix}
$$
The second time, we have the two variables $v_1, v_2$:
$$
\begin{bmatrix} o_1\\ o_2 \end{bmatrix} =
\begin{bmatrix} \alpha'_{1,1} & \alpha'_{2,1} \\ \alpha'_{1,2} & \alpha'_{2,2} \end{bmatrix}
\begin{bmatrix} v_1\\ v_2 \end{bmatrix}
$$
At this point, if we do not mask $A'_{2\times 2}$, the value of $o_1$ will change (the first time it is $\alpha'_{1,1}v_1$, the second time it becomes $\alpha'_{1,1}v_1+\alpha'_{2,1}v_2$). Seen this way, we only need to cover up $\alpha'_{2,1}$ to guarantee that $o_1$ is the same both times.
So the second time is actually:
$$
\begin{bmatrix} o_1\\ o_2 \end{bmatrix} =
\begin{bmatrix} \alpha'_{1,1} & 0 \\ \alpha'_{1,2} & \alpha'_{2,2} \end{bmatrix}
\begin{bmatrix} v_1\\ v_2 \end{bmatrix}
$$
By analogy, by the time we reach the $n$-th step, it should become:
$$
\begin{bmatrix} o_1\\ o_2\\ \vdots \\ o_n \end{bmatrix} =
\begin{bmatrix}
\alpha'_{1,1} & 0 & \cdots & 0 \\
\alpha'_{1,2} & \alpha'_{2,2} & \cdots & 0 \\
\vdots & \vdots & & \vdots \\
\alpha'_{1,n} & \alpha'_{2,n} & \cdots & \alpha'_{n,n}
\end{bmatrix}
\begin{bmatrix} v_1\\ v_2\\ \vdots \\ v_n \end{bmatrix}
$$
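In code this lower-triangular pattern is usually built once as a boolean mask and reused. A minimal sketch, assuming PyTorch (the annotated-transformer project defines a similar subsequent_mask helper):

import torch

def subsequent_mask(size):
    """Lower-triangular mask: position i may attend only to positions <= i."""
    return torch.tril(torch.ones(1, size, size)).bool()

print(subsequent_mask(4))
# tensor([[[ True, False, False, False],
#          [ True,  True, False, False],
#          [ True,  True,  True, False],
#          [ True,  True,  True,  True]]])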
3.3 Why negative infinity instead of 0
According to the above, the masked entries should be 0, so why does the source code fill them with $-1e9$ (effectively negative infinity)? Part of the Attention source code looks like this:
if mask is not None:
    scores = scores.masked_fill(mask == 0, -1e9)
p_attn = scores.softmax(dim=-1)
Look carefully: the $A'_{n\times n}$ we were talking about above is the matrix after softmax, whereas in the source code the mask is applied before softmax. That is why negative infinity is used: after softmax, a negative-infinity score becomes 0.
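A tiny sketch (mine) of the difference: overwriting the masked scores with 0 before softmax does not zero out their attention weights, while overwriting them with -1e9 does.

import torch

scores = torch.tensor([[2.0, 1.0, 3.0]])
mask = torch.tensor([[1, 1, 0]])  # the last position should be hidden

# Filling with 0 before softmax: the hidden position still gets ~9% of the attention.
print(scores.masked_fill(mask == 0, 0.0).softmax(dim=-1))   # ~[[0.665, 0.245, 0.090]]
# Filling with -1e9 before softmax: its weight becomes 0, as desired.
print(scores.masked_fill(mask == 0, -1e9).softmax(dim=-1))  # ~[[0.731, 0.269, 0.000]]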