Multi-Head-Attention principle and code implementation
2022-08-03 07:04:00 【WGS。】
For a detailed explanation of Attention, see:
For a detailed explanation of the Transformer, see:
Hand-drawn sketch of the multi-head attention process:
Here we jump straight into an example; for a detailed explanation, see the links at the beginning.
We have 3 records with two features, where x1 is "gender" and x2 is "device brand":

x1        x2
Male      Huawei
Male      Xiaomi
Female    Apple
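Before embedding, the two categorical features are label-encoded to integer ids (男→0, 女→1; 华为→0, 小米→1, 苹果→2, matching the docstring of dt2() below). A minimal sketch of one way to produce such codes with pandas; the exact id assignment depends on the category ordering, so treat the mapping as illustrative:

import pandas as pd

data = pd.DataFrame({'x1': ['男', '男', '女'], 'x2': ['华为', '小米', '苹果']})
# Convert each column to categorical and take the integer codes.
# Note: codes follow category ordering, so the mapping may differ from the one above.
encoded = data.apply(lambda col: col.astype('category').cat.codes)
print(encoded)  # one integer id per category, e.g. x1 in {0, 1}, x2 in {0, 1, 2}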
● batch_size = 3
● fields = 2
● emb_dim = 6
● head_num = 2
Then the input dimension is [3, 2, 6].
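For reference, each head computes the standard scaled dot-product attention that the code below implements:

Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V

where d_k is the per-head dimension (att_emb_size = emb_dim / head_num = 3 in this example).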
Taking one batch as an example to walk through the multi-head attention process, as shown in the figure below:
The tensor before and after the head split:
tensor([[[-1.7435, -1.0348, -0.8986, -0.3036,  2.5530,  0.0273],
         [ 2.0777,  0.9267,  1.0873,  0.4455, -1.9582, -0.0131]]])

tensor([[[[-1.7435, -1.0348, -0.8986],
          [ 2.0777,  0.9267,  1.0873]]],

        [[[-0.3036,  2.5530,  0.0273],
          [ 0.4455, -1.9582, -0.0131]]]])
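A minimal sketch of the split itself (variable names here are illustrative, not from the post's code): splitting the last dimension into head_num chunks of att_emb_size and stacking them adds a leading head dimension.

import torch

x = torch.randn(1, 2, 6)                       # [batch_size, fields, emb_dim]
heads = torch.stack(torch.split(x, 3, dim=2))  # 2 chunks of size 3 -> [2, 1, 2, 3]
print(x.shape)      # torch.Size([1, 2, 6])
print(heads.shape)  # torch.Size([2, 1, 2, 3])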
Implementing multi-head attention in torch
# coding:utf-8
# @Email: [email protected]
# @Time: 2022/7/25 2:45 PM
# @File: multi_att_demo.py
import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F


class MultiheadAttention(nn.Module):
    def __init__(self, emb_dim, head_num, scaling=True):
        super(MultiheadAttention, self).__init__()
        assert emb_dim % head_num == 0, "emb_dim must be divisible by head_num"
        self.emb_dim = emb_dim
        self.head_num = head_num
        self.scaling = scaling
        self.att_emb_size = emb_dim // head_num

        self.W_Q = nn.Parameter(torch.Tensor(emb_dim, emb_dim))
        self.W_K = nn.Parameter(torch.Tensor(emb_dim, emb_dim))
        self.W_V = nn.Parameter(torch.Tensor(emb_dim, emb_dim))

        # Initialize the weights; uninitialized tensors can produce NaN
        for weight in self.parameters():
            nn.init.xavier_uniform_(weight)

    def forward(self, inputs):
        # inputs: [3, 2, 6]
        '''1. Linear projections producing Q, K, V'''
        # dim: [batch_size, fields, emb_size]
        # [3, 2, 6] * [6, 6] = [3, 2, 6]
        querys = torch.tensordot(inputs, self.W_Q, dims=([-1], [0]))
        keys = torch.tensordot(inputs, self.W_K, dims=([-1], [0]))
        values = torch.tensordot(inputs, self.W_V, dims=([-1], [0]))
        # # Equivalent to matmul:
        # querys = torch.matmul(inputs, self.W_Q)
        # keys = torch.matmul(inputs, self.W_K)
        # values = torch.matmul(inputs, self.W_V)

        '''2. Split into heads'''
        # dim: [head_num, batch_size, fields, emb_size // head_num]
        # [3, 2, 6] --> [2, 3, 2, 3]
        querys = torch.stack(torch.split(querys, self.att_emb_size, dim=2))
        keys = torch.stack(torch.split(keys, self.att_emb_size, dim=2))
        values = torch.stack(torch.split(values, self.att_emb_size, dim=2))

        '''3. Scaled dot-product attention'''
        # dim: [head_num, batch_size, fields, fields]
        # Q * K^T / scale: [2, 3, 2, 3] * [2, 3, 3, 2] = [2, 3, 2, 2]
        inner_product = torch.matmul(querys, keys.transpose(-2, -1))
        # # Equivalent:
        # inner_product = torch.einsum('bnik,bnjk->bnij', querys, keys)
        if self.scaling:
            inner_product /= self.att_emb_size ** 0.5
        # Softmax normalizes the attention weights
        attn_w = F.softmax(inner_product, dim=-1)
        # Weighted sum: multiply the attention weights with V to get the per-head results
        # [2, 3, 2, 2] * [2, 3, 2, 3] = [2, 3, 2, 3]
        results = torch.matmul(attn_w, values)

        '''4. Concatenate the heads back together'''
        # dim: [batch_size, fields, emb_size]
        # [2, 3, 2, 3] --> [1, 3, 2, 6] --> [3, 2, 6]
        results = torch.cat(torch.split(results, 1, dim=0), dim=-1)
        results = torch.squeeze(results, dim=0)
        results = F.relu(results)
        return results
def dt2():
    '''
    Raw data:            Label-encoded:
    x1    x2             x1    x2
    男    华为            0     0
    男    小米            0     1
    女    苹果            1     2

    + batch_size = 3
    + fields = 2 (two features)
    + emb_dim = 6
    + head_num = 2 (two heads, each with att_emb_size = 3)
    The input is therefore [3, 2, 6].
    '''
    # data = pd.DataFrame({'x1': [0, 0, 1], 'x2': [0, 1, 2]})  # the full 3-record version
    data = pd.DataFrame({'x1': [0], 'x2': [0]})  # single record, as in the printed tensors above
    sparse_fields = data.max().values + 1
    sparse_fields = sparse_fields.astype(np.int32)  # [2, 3] for the 3-record data; [1, 1] here
    tensor = torch.Tensor(data.values).long()
    print(tensor)
    # Offset each field so its ids index a disjoint block of one shared embedding table
    offsets = np.array((0, *np.cumsum(sparse_fields)[:-1]), dtype=np.longlong)  # [0, 2] for 3 records
    tensor = tensor + tensor.new_tensor(offsets).unsqueeze(0)
    print(tensor)
    emb_layer = nn.Embedding(sum(sparse_fields) + 1, embedding_dim=6)
    tensor_emb = emb_layer(tensor)
    print(tensor_emb.shape)
    net = MultiheadAttention(emb_dim=6, head_num=2, scaling=True)
    output = net(tensor_emb)
    print(output.shape)
    print(output)


if __name__ == '__main__':
    dt2()
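With the single-record data above, dt2() prints a [1, 2, 6] embedding and a [1, 2, 6] attention output; restoring the commented 3-record DataFrame gives the [3, 2, 6] case from the walkthrough. For comparison, PyTorch also ships a built-in nn.MultiheadAttention; below is a minimal self-attention sketch. Note it uses its own input/output projections and no ReLU, so the values will differ from the custom module above (batch_first requires PyTorch >= 1.9):

import torch
from torch import nn

x = torch.randn(3, 2, 6)  # [batch_size, fields, emb_dim]
mha = nn.MultiheadAttention(embed_dim=6, num_heads=2, batch_first=True)
out, attn_w = mha(x, x, x)  # self-attention: query = key = value
print(out.shape)     # torch.Size([3, 2, 6])
print(attn_w.shape)  # torch.Size([3, 2, 2]), weights averaged over heads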