Code implementation additive attention
2022-07-28 17:11:00 【InfoQ】
import math
import torch
from torch import nn
from d2l import torch as d2l
def masked_softmax(X, valid_lens):
    """Softmax over the last axis, masking out elements beyond the valid length."""
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # Masked elements get a large negative value, so their softmax output is ~0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
                              value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)
- The function takes two parameters, `X` and `valid_lens`. `X` is the tensor to apply softmax to, and `valid_lens` stores the valid length along the last axis. `valid_lens` may be one-dimensional or two-dimensional, but its size must be compatible with `X` so that the masking comparison can broadcast.
- The body of the function is an if-else statement.
  - `if valid_lens is None` covers the case where no `valid_lens` is given, meaning the whole tensor is valid. No masking is needed before the softmax, so this branch returns an ordinary softmax and the function ends.
  - When `valid_lens` is given, execution enters the `else` branch.
- In the `else` branch, `shape` first stores the shape of the tensor `X` to be masked.
- A second if-else statement then handles the shape of `valid_lens`, turning it into a vector with one valid length per row of the flattened `X`.
  - When `valid_lens` is one-dimensional, the `if` branch expands it. Because of mini-batching, the incoming `X` is usually three-dimensional: the first dimension is the batch size, and the second and third are the dimensions of each matrix. `shape[1]` is the number of rows in each matrix of `X`, so each entry of `valid_lens` is repeated `shape[1]` times and every row of a matrix gets the valid length of its batch element.
  - For details on `torch.repeat_interleave`, see the article: Comparison of repeat operations in PyTorch
  - When `valid_lens` is not one-dimensional, the `else` branch flattens it directly from a matrix into a vector.
- The masking itself is done by a `d2l` function (`d2l.sequence_mask`), which I won't go into here. For the dimensions, remember:
  - If `valid_lens` is one-dimensional, its length must equal the batch size, the first dimension of `X` (`shape[0]`). Each entry is then repeated `shape[1]` times, once per row.
  - If `valid_lens` is two-dimensional, its first dimension must equal the batch size and its second dimension must equal the number of rows in each matrix of `X` (`shape[1]`).
- Concrete examples can be found in the article: Code implementation: scaled dot-product attention
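The way `valid_lens` is handled can be tried out without `d2l`. The following is a minimal sketch, not the library's code: `sequence_mask` is a local stand-in written for illustration and only assumed to behave like `d2l.sequence_mask`, and the tensor sizes are arbitrary. A 1-D `valid_lens` has one entry per batch element, while a 2-D one has shape (batch_size, rows).

```python
import torch
from torch import nn


def sequence_mask(X, valid_len, value=0.0):
    """Illustrative stand-in for d2l.sequence_mask (an assumption, not the d2l source).

    X: (rows, max_len), valid_len: (rows,). In each row, entries at
    positions >= valid_len are overwritten with `value`.
    """
    max_len = X.size(1)
    mask = torch.arange(max_len, device=X.device)[None, :] < valid_len[:, None]
    X = X.clone()  # do not modify the caller's tensor in place
    X[~mask] = value
    return X


def masked_softmax(X, valid_lens):
    """Softmax over the last axis, ignoring positions beyond the valid length."""
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    shape = X.shape
    if valid_lens.dim() == 1:
        # One length per batch element: repeat it for every row of that element
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        # One length per row: flatten to a vector
        valid_lens = valid_lens.reshape(-1)
    X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
    return nn.functional.softmax(X.reshape(shape), dim=-1)


# repeat_interleave repeats each element in turn: [2, 3] -> [2, 2, 3, 3]
repeated = torch.repeat_interleave(torch.tensor([2, 3]), 2)

X = torch.rand(2, 2, 4)  # (batch_size=2, rows=2, columns=4)

# 1-D valid_lens: one entry per batch element
w1 = masked_softmax(X, torch.tensor([2, 3]))

# 2-D valid_lens: one entry per row, shape (batch_size, rows)
w2 = masked_softmax(X, torch.tensor([[1, 3], [2, 4]]))

print(repeated)
print(w1)
print(w2)
```

In `w1`, the last two columns of the first batch element and the last column of the second are zero. In `w2`, each row has its own number of non-zero weights.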
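class AdditiveAttention(nn.Module):
    """Additive attention."""
    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # queries: (batch_size, no. of queries, 1, num_hiddens)
        # keys: (batch_size, 1, no. of key-value pairs, num_hiddens)
        # Broadcasting sums every query with every key
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # w_v has a single output, so drop the last dimension:
        # scores: (batch_size, no. of queries, no. of key-value pairs)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # values: (batch_size, no. of key-value pairs, value dimension)
        return torch.bmm(self.dropout(self.attention_weights), values)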
- There are three main parameters: `key_size` is the length of the keys, `query_size` is the length of the queries, and `num_hiddens` is the size of the hidden layer. Additive attention is designed for the case where keys and queries have different lengths.
- There are three small linear layers: `self.W_k` and `self.W_q` project the keys and queries into the hidden layer, and `self.w_v` maps the hidden layer to a single output.
- None of these layers uses a bias.
- Finally, a dropout layer is added; it is applied to the attention weights in the forward pass.
- The forward function then computes the scoring function a(q, k) = w_v^T tanh(W_q q + W_k k):
  - Pass `queries` and `keys` through the first two linear layers to get the projected queries and keys, then adjust their dimensions so that they can be added by broadcasting.
    `queries` shape: (batch_size, number of queries, 1, num_hiddens)
    `keys` shape: (batch_size, 1, number of key-value pairs, num_hiddens)
  - Evaluate the formula. When computing `scores`, `self.w_v` has only one output, so the last dimension is removed from the shape. `scores` shape: (batch_size, number of queries, number of key-value pairs)
  - Finally, the weights are multiplied with `values`.
    `values` shape: (batch_size, number of key-value pairs, value dimension)
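The broadcasting step is easier to follow with concrete shapes. The sketch below uses made-up sizes and freshly initialized layers (all names and sizes are illustrative, not taken from a trained model). It checks that the broadcast sum yields one feature vector per (query, key) pair, and that one entry of `scores` matches the formula applied to that single pair.

```python
import torch
from torch import nn

torch.manual_seed(0)

batch_size, num_queries, num_kv = 2, 3, 5      # illustrative sizes
query_size, key_size, num_hiddens = 20, 2, 8

W_q = nn.Linear(query_size, num_hiddens, bias=False)
W_k = nn.Linear(key_size, num_hiddens, bias=False)
w_v = nn.Linear(num_hiddens, 1, bias=False)

queries = torch.randn(batch_size, num_queries, query_size)
keys = torch.randn(batch_size, num_kv, key_size)

q = W_q(queries)  # (batch_size, num_queries, num_hiddens)
k = W_k(keys)     # (batch_size, num_kv, num_hiddens)

# (B, Q, 1, H) + (B, 1, K, H) broadcasts to (B, Q, K, H)
features = torch.tanh(q.unsqueeze(2) + k.unsqueeze(1))

# w_v maps H -> 1; squeeze(-1) drops that trailing dimension
scores = w_v(features).squeeze(-1)  # (B, Q, K)

# The same score for one (query, key) pair, computed without broadcasting
b, i, j = 1, 2, 4
single = w_v(torch.tanh(q[b, i] + k[b, j])).squeeze(-1)

print(features.shape)  # torch.Size([2, 3, 5, 8])
print(scores.shape)    # torch.Size([2, 3, 5])
print(torch.allclose(scores[b, i, j], single, atol=1e-6))
```

Broadcasting avoids an explicit double loop over queries and keys, at the cost of materializing a (B, Q, K, H) tensor.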
queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
# The minibatch of `values`: the two value matrices are identical
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(
    2, 1, 1)
valid_lens = torch.tensor([2, 6])
attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8, dropout=0.1)
attention.eval()
attention(queries, keys, values, valid_lens)
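`.eval()` switches the module to evaluation mode, so dropout is disabled. The attention weights can then be shown as a heatmap:
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')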
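The result of this example can be cross-checked by hand. Every key is the same vector of ones, so all scores within a row are equal and the masked softmax reduces to a uniform average over the valid positions. The sketch below rebuilds those weights directly, without `AdditiveAttention` or `d2l`. Here `weights` is a hand-constructed stand-in for what `attention.attention_weights` should contain in evaluation mode (an assumption based on the reasoning above), and it is applied to the same `values`.

```python
import torch

values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])

# Uniform weights over the first valid_len keys, zero elsewhere
positions = torch.arange(10)[None, :]                        # (1, 10)
mask = positions < valid_lens[:, None]                       # (2, 10)
weights = (mask.float() / valid_lens[:, None]).unsqueeze(1)  # (2, 1, 10)

# Weighted sum of the value rows for each batch element
output = torch.bmm(weights, values)                          # (2, 1, 4)
print(output)

# Batch 0 averages the first 2 rows of values, batch 1 the first 6 rows
expected = torch.tensor([[[2.0, 3.0, 4.0, 5.0]],
                         [[10.0, 11.0, 12.0, 13.0]]])
print(torch.allclose(output, expected, atol=1e-4))
```

The first batch element averages the first two value rows and the second averages the first six, which is also the pattern the heatmap should show: two equal cells in the first row and six in the second.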