(PyTorch Advanced Path, Part 3) Encoder self-attention mask
2022-06-29 08:11:00 【likeGhee】
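The mask is usually applied to the attention scores right before the softmax. Softmax is built on exp, which is monotonically increasing, so an input of negative infinity produces an output of (essentially) 0. Each entry of the mask therefore either leaves a score unchanged or pushes it to negative infinity. Here we build a boolean mask and use masked_fill to write a large negative number (-1e9) into the masked positions; as an additive mask, the same idea would be 0 for kept positions and negative infinity for masked ones.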
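To see concretely why a large negative value works, here is a small sketch with made-up scores (the values and the names `scores`, `mask`, `all_masked` are only for illustration). It also shows one practical reason to prefer -1e9 over a true -inf: if every position in a row is masked, -inf makes softmax return NaN, while -1e9 just gives a uniform distribution.

```python
import torch
import torch.nn.functional as F

# One row of raw attention scores; assume the last two positions are padding.
scores = torch.tensor([1.0, 2.0, 3.0, 4.0])
mask = torch.tensor([False, False, True, True])  # True = masked

# Masked positions receive (essentially) zero weight after softmax.
w = F.softmax(scores.masked_fill(mask, -1e9), dim=-1)
print(w)

# If an entire row is masked, -inf gives NaN (softmax subtracts the row max, and -inf - (-inf) is NaN),
# whereas -1e9 only makes all entries equal, so the weights come out uniform.
all_masked = torch.ones(4, dtype=torch.bool)
print(F.softmax(scores.masked_fill(all_masked, float("-inf")), dim=-1))
print(F.softmax(scores.masked_fill(all_masked, -1e9), dim=-1))
```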
The mask has shape [batch_size, max_src_len, max_src_len], where max_src_len is the length of the longest sentence in the batch.
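We first build a vector of valid positions for each sentence and pad it to max_src_len. Then unsqueeze, cat, transpose and bmm bring it to the mask's shape, we turn the result into a boolean mask, and finally masked_fill produces masked_score.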
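Written per element (0-based positions $i, j$), multiplying the [batch_size, max_src_len, 1] tensor of valid positions by its transpose with bmm is an outer product: entry $(i, j)$ is 1 only when both word $i$ and word $j$ are real (non-padding) words.

$$
v_{b,i} =
\begin{cases}
1, & i < \mathrm{src\_len}_b \\
0, & \text{otherwise}
\end{cases}
\qquad
V_{b,i,j} = v_{b,i}\, v_{b,j}
\qquad
\mathrm{mask}_{b,i,j} = 1 - V_{b,i,j}
$$

For example, with src_len = [2, 4] and max_src_len = 4, the first sentence has $v_0 = (1, 1, 0, 0)$:

$$
V_0 = v_0 v_0^{\top} =
\begin{pmatrix}1\\1\\0\\0\end{pmatrix}
\begin{pmatrix}1 & 1 & 0 & 0\end{pmatrix}
=
\begin{pmatrix}
1 & 1 & 0 & 0\\
1 & 1 & 0 & 0\\
0 & 0 & 0 & 0\\
0 & 0 & 0 & 0
\end{pmatrix}
$$

while $v_1 = (1, 1, 1, 1)$ gives an all-ones $V_1$, so nothing in the second sentence is masked.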
```python
import torch
import torch.nn.functional as F

# Suppose we have two sentences
batch_size = 2
# Each sentence gets a random length in [2, 4] (the upper bound of randint is exclusive)
src_len = torch.randint(2, 5, (batch_size,))
tgt_len = torch.randint(2, 5, (batch_size,))
print(src_len)
print(tgt_len)
# For easier inspection, hard-code the lengths instead
src_len = torch.tensor([2, 4], dtype=torch.int32)
tgt_len = torch.tensor([4, 3], dtype=torch.int32)
max_src_len = int(src_len.max())

# One vector of ones per sentence, marking its valid (non-padding) positions
valid_encoder_pos = [torch.ones(int(L)) for L in src_len]
# Pad with zeros up to the maximum sentence length
valid_encoder_pos = [F.pad(x, (0, max_src_len - len(x))) for x in valid_encoder_pos]
# Add a leading dimension to each vector: [max_src_len] -> [1, max_src_len]
valid_encoder_pos = [torch.unsqueeze(x, 0) for x in valid_encoder_pos]
# Concatenate along the batch dimension -> [2, 4]
valid_encoder_pos = torch.cat(valid_encoder_pos, 0)
# Add a trailing dimension -> [2, 4, 1]
valid_encoder_pos = torch.unsqueeze(valid_encoder_pos, 2)
print(valid_encoder_pos.shape, "# valid_encoder_pos")

# bmm: batched matrix multiply, [2, 4, 1] @ [2, 1, 4] -> [2, 4, 4]
valid_encoder_pos_matrix = torch.bmm(valid_encoder_pos, valid_encoder_pos.transpose(1, 2))
print(valid_encoder_pos_matrix.shape, "# valid_encoder_pos_matrix")
print(valid_encoder_pos_matrix, "# one 4x4 matrix per sentence; row i shows which words word i may attend to (1 = valid)")
invalid_encoder_pos_matrix = 1 - valid_encoder_pos_matrix  # invert
print(invalid_encoder_pos_matrix, "# invalid_encoder_pos_matrix: 0 = valid position, 1 = invalid position")
mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool)
print(mask_encoder_self_attention, "# mask_encoder_self_attention: True marks positions to mask")

# Usage: generate a random score tensor
score = torch.randn(batch_size, max_src_len, max_src_len)
# masked_fill takes a boolean tensor; masked positions are set to a large negative number
masked_score = score.masked_fill(mask_encoder_self_attention, -1e9)
# Softmax over the last dimension turns the masked scores into attention weights
prob = F.softmax(masked_score, -1)
print(prob, "# attention weights")
```
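The pad/unsqueeze/cat/bmm construction above can also be expressed with broadcasting. The sketch below is an alternative formulation, not the original author's code: it compares `torch.arange` against the lengths to get a [batch_size, max_src_len] boolean of valid positions, then combines it with itself via broadcasting, which is the boolean counterpart of the bmm outer product. One thing worth noticing when running either version: the rows for padded query positions (rows 2 and 3 of the first sentence) are masked entirely, so with -1e9 their softmax comes out uniform (0.25 each); those rows belong to padding and are normally ignored downstream.

```python
import torch
import torch.nn.functional as F

src_len = torch.tensor([2, 4])
max_src_len = int(src_len.max())

# valid[b, i] is True when position i is a real word of sentence b: shape [batch, max_src_len]
valid = torch.arange(max_src_len).unsqueeze(0) < src_len.unsqueeze(1)
# [B, L, 1] & [B, 1, L] broadcasts to [B, L, L]; True only where both words are real
valid_matrix = valid.unsqueeze(2) & valid.unsqueeze(1)
mask = ~valid_matrix  # True = mask this position

score = torch.randn(len(src_len), max_src_len, max_src_len)
prob = F.softmax(score.masked_fill(mask, -1e9), dim=-1)
print(mask[0].int())  # first sentence: rows and columns 2 and 3 are masked
```

Besides being shorter, this avoids building Python lists of per-sentence tensors, so it stays the same amount of code for any batch size.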