(PyTorch Advanced Path, Part 3) Encoder self-attention mask
2022-06-29 08:11:00 【likeGhee】
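The mask is usually applied to the attention scores that are fed into softmax. Softmax is monotonic, and an input of negative infinity produces an output of 0, so at each position the mask does one of two things: it leaves the score unchanged, or it replaces the score with negative infinity (in practice, a very large negative number).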
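To make the arithmetic concrete, here is a minimal sketch in plain Python (no PyTorch needed) of what softmax does to a masked score. The helper `softmax` and the example scores are illustrative assumptions, not part of the original post.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats (illustrative helper)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

scores = [1.0, 2.0, 3.0]
masked_inf = [1.0, 2.0, float("-inf")]  # third position masked with -inf
masked_big = [1.0, 2.0, -1e9]           # third position masked with -1e9

print(softmax(scores))      # all three positions receive some weight
print(softmax(masked_inf))  # the masked position receives weight 0.0
print(softmax(masked_big))  # -1e9 underflows to the same result
```

In both masked cases the remaining weights are renormalised over the unmasked positions only, which is why a large negative constant such as -1e9 can stand in for negative infinity.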
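The mask has shape [batch_size, max_src_len, max_src_len], where max_src_len is the length of the longest sentence in the batch.
We first build the valid positions pos for each sentence and pad them to max_src_len. Then unsqueeze, cat, and bmm bring the result to the shape of the mask, from which we derive a boolean mask matrix. Finally, masked_fill produces masked_score.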
import torch
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
#%%
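# Assume a batch of two sentences
batch_size = 2
# Each sentence has a random length from 2 to 4 (the upper bound of randint is exclusive)
src_len = T.randint(2, 5, (batch_size, ))
tgt_len = T.randint(2, 5, (batch_size, ))
print(src_len)
print(tgt_len)
# Hard-code the lengths so the results are easy to study
src_len = T.Tensor([2, 4]).to(T.int32)
tgt_len = T.Tensor([4, 3]).to(T.int32)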
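# One vector of ones per sentence, marking its valid (non-padding) positions
valid_encoder_pos = [torch.ones(int(L)) for L in src_len]
# Pad each vector with zeros up to the longest sentence length
valid_encoder_pos = list(map(lambda x: F.pad(x, (0, int(max(src_len)) - len(x))), valid_encoder_pos))
# Add a leading dimension -> [1, 4]
valid_encoder_pos = list(map(lambda x: T.unsqueeze(x, 0), valid_encoder_pos))
# Concatenate along the batch dimension -> [2, 4]
valid_encoder_pos = T.cat(valid_encoder_pos, 0)
# Add a trailing dimension -> [2, 4, 1]
valid_encoder_pos = T.unsqueeze(valid_encoder_pos, 2)
print(valid_encoder_pos.shape, "# valid_encoder_pos")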
# bmm: batched matrix multiplication, [2,4,1] x [2,1,4] -> [2,4,4]
valid_encoder_pos_matrix = torch.bmm(valid_encoder_pos, valid_encoder_pos.transpose(1, 2))
print(valid_encoder_pos_matrix.shape, "# valid_encoder_pos_matrix")
print(valid_encoder_pos_matrix, "# one 4x4 matrix per sentence; the first row shows which words the first word can validly attend to")
invalid_encoder_pos_matrix = 1 - valid_encoder_pos_matrix  # invert
print(invalid_encoder_pos_matrix, "# invalid_encoder_pos_matrix: 0 marks a valid position, 1 an invalid one")
mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool)
print(mask_encoder_self_attention, "# mask_encoder_self_attention: True marks positions to mask")
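# Usage: generate a random score tensor
score = torch.randn(batch_size, int(max(src_len)), int(max(src_len)))
# masked_fill takes a boolean tensor; positions where it is True are set to -1e9, standing in for negative infinity
masked_score = score.masked_fill(mask_encoder_self_attention, -1e9)
# Apply softmax to the masked scores to obtain the attention weights
prob = F.softmax(masked_score, -1)
print(prob, "# attention weights")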
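As a possible alternative to the pad / unsqueeze / cat / bmm steps, the same boolean mask can be sketched with broadcasting. This is only an illustration under the same hard-coded lengths; the names `valid`, `valid_matrix`, and `mask`, and the fixed random seed, are assumptions introduced here and are not part of the original code.

```python
import torch
import torch.nn.functional as F

src_len = torch.tensor([2, 4])   # same sentence lengths as in the walkthrough
max_len = int(src_len.max())

# valid[b, i] is True when position i is a real token of sentence b -> [2, 4]
valid = torch.arange(max_len).unsqueeze(0) < src_len.unsqueeze(1)

# A pair (i, j) is valid only if both positions are real tokens -> [2, 4, 4]
valid_matrix = valid.unsqueeze(2) & valid.unsqueeze(1)
mask = ~valid_matrix             # True marks positions to mask

torch.manual_seed(0)             # fixed seed only to make the run repeatable
score = torch.randn(len(src_len), max_len, max_len)
prob = F.softmax(score.masked_fill(mask, -1e9), dim=-1)
print(mask)
print(prob)
```

One thing to watch for: the rows that belong to padding positions (rows 3 and 4 of the first sentence) are masked in every column. Because every entry in such a row is the same -1e9, softmax returns a uniform 0.25 per column there. With float("-inf") as the fill value, those rows would come out as NaN instead.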