(PyTorch Advanced Path, Part 3) Encoder Self-Attention Mask
2022-06-29 08:11:00 【likeGhee】
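The mask is applied to the attention scores just before the softmax. Softmax exponentiates every score, and exp(-∞) = 0, so a position whose score is set to negative infinity (in practice a large negative number such as -1e9) ends up with essentially zero attention weight. The mask therefore only has to distinguish two cases: positions whose scores are kept unchanged, and positions whose scores are overwritten with -∞. In this post the mask is a boolean matrix, where True marks the positions to overwrite.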
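A minimal sketch with made-up toy scores (not from the original post) that checks this claim. It compares two equivalent ways of masking: filling masked scores with -1e9 via `masked_fill`, and adding an additive mask that is 0 at kept positions and -inf at masked ones. Both give zero weight to the masked positions.

```python
import torch
import torch.nn.functional as F

# Toy attention scores for one query over 4 keys (values chosen arbitrarily).
scores = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
# Boolean mask: True marks keys that must be ignored (e.g. padding).
mask = torch.tensor([[False, False, True, True]])

# Approach 1: overwrite masked scores with a large negative number.
probs_fill = F.softmax(scores.masked_fill(mask, -1e9), dim=-1)

# Approach 2: add a mask that is 0 where kept and -inf where masked.
additive = torch.zeros_like(scores).masked_fill(mask, float("-inf"))
probs_add = F.softmax(scores + additive, dim=-1)

print(probs_fill)  # the last two weights are 0
print(torch.allclose(probs_fill, probs_add))
```

One caveat: if every entry in a row is -inf, softmax returns NaN for that whole row, whereas with -1e9 it returns a uniform distribution. That is one practical reason to prefer a large finite negative number.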
The mask has shape [batch_size, max_src_len, max_src_len], where max_src_len is the length of the longest source sentence in the batch.
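The plan: first build a vector of valid positions for each sentence (1 for every real token), pad it with zeros to max_src_len, then use unsqueeze, cat, and bmm to turn it into a matrix of the mask's shape. Inverting that matrix and casting it to bool gives the mask, and masked_fill applies it to the scores to produce masked_score.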
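The same mask can also be built without the per-sentence list: compare a position index against each sentence length to get a [batch, max_len] validity matrix, then combine query validity and key validity with a logical AND (a boolean outer product). This is a sketch of an alternative, not the author's code; the function name `build_encoder_self_attention_mask` is hypothetical.

```python
import torch

def build_encoder_self_attention_mask(src_len: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: bool mask of shape [batch, max_len, max_len], True = masked.

    Position j of sentence b is valid iff j < src_len[b].
    """
    max_len = int(src_len.max())
    # valid[b, j] is True for real tokens, False for padding -> [batch, max_len]
    valid = torch.arange(max_len).unsqueeze(0) < src_len.unsqueeze(1)
    # Pair (i, j) is valid only if both query i and key j are real tokens.
    # [batch, max_len, 1] & [batch, 1, max_len] broadcasts to [batch, max_len, max_len]
    valid_pairs = valid.unsqueeze(2) & valid.unsqueeze(1)
    # Invert: True now marks the pairs whose scores must be masked.
    return ~valid_pairs

src_len = torch.tensor([2, 4], dtype=torch.int32)
mask = build_encoder_self_attention_mask(src_len)
print(mask.shape)
print(mask[0].int())
```

Working with booleans directly avoids the float bmm and the `1 - x` inversion; for src_len = [2, 4] the result matches `mask_encoder_self_attention` from the post.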
```python
import torch
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
#%%
# Suppose the batch contains two sentences
batch_size = 2
# Random sentence lengths; randint's upper bound is exclusive, so lengths are 2 to 4
src_len = T.randint(2, 5, (batch_size, ))
tgt_len = T.randint(2, 5, (batch_size, ))
print(src_len)
print(tgt_len)
# For easier inspection, hard-code the lengths instead
src_len = T.tensor([2, 4], dtype=T.int32)
tgt_len = T.tensor([4, 3], dtype=T.int32)
max_src_len = int(src_len.max())
# One vector of ones per sentence: 1 marks a valid (non-padding) position
valid_encoder_pos = [torch.ones(int(L)) for L in src_len]
# Pad each vector with zeros up to the maximum sentence length
valid_encoder_pos = [F.pad(x, (0, max_src_len - len(x))) for x in valid_encoder_pos]
# Add a leading dim to each vector: [max_src_len] -> [1, max_src_len]
valid_encoder_pos = [T.unsqueeze(x, 0) for x in valid_encoder_pos]
# Concatenate along dim 0 -> [batch_size, max_src_len]
valid_encoder_pos = T.cat(valid_encoder_pos, 0)
# Add a trailing dim -> [2, 4, 1]
valid_encoder_pos = T.unsqueeze(valid_encoder_pos, 2)
print(valid_encoder_pos.shape, "# valid_encoder_pos")
# bmm: batched matrix multiply, [2, 4, 1] @ [2, 1, 4] -> [2, 4, 4]
valid_encoder_pos_matrix = torch.bmm(valid_encoder_pos, valid_encoder_pos.transpose(1, 2))
print(valid_encoder_pos_matrix.shape, "# valid_encoder_pos_matrix")
print(valid_encoder_pos_matrix, "# 4x4 per sentence: row i is 1 where word i and word j are both valid")
invalid_encoder_pos_matrix = 1 - valid_encoder_pos_matrix  # invert
print(invalid_encoder_pos_matrix, "# invalid_encoder_pos_matrix: 0 = valid position, 1 = invalid position")
mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool)
print(mask_encoder_self_attention, "# mask_encoder_self_attention: True marks positions to mask")
# Usage: generate a random score tensor
score = torch.randn(batch_size, max_src_len, max_src_len)
# masked_fill takes a bool tensor; where it is True, the score is set to -1e9 (effectively -inf)
masked_score = score.masked_fill(mask_encoder_self_attention, -1e9)
# Apply softmax over the last dim of the masked scores to get the attention weights
prob = F.softmax(masked_score, -1)
print(prob, "# attention weights")
```
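One thing worth noticing in the printed weights: for the first sentence (length 2), rows 3 and 4 belong to padding queries and are masked entirely, so every entry is -1e9 and softmax spreads the weight uniformly (0.25 each) over all four positions, padding included. The sketch below shows one common variation, assuming padded outputs are later discarded or ignored by the loss: mask only the keys with a [batch, 1, max_len] mask that broadcasts across queries, then zero the padded query rows after the softmax. This is not what the original post does.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
src_len = torch.tensor([2, 4])
batch_size, max_len = len(src_len), int(src_len.max())

# valid[b, j]: True for real tokens, False for padding -> [batch, max_len]
valid = torch.arange(max_len).unsqueeze(0) < src_len.unsqueeze(1)
score = torch.randn(batch_size, max_len, max_len)

# Full [B, L, L] mask as in the post: rows of padded queries are fully masked.
full_mask = ~(valid.unsqueeze(2) & valid.unsqueeze(1))
prob_full = F.softmax(score.masked_fill(full_mask, -1e9), dim=-1)
# A fully masked row degenerates to a uniform distribution (1 / max_len each).
print(prob_full[0, 2])

# Variation: mask only keys ([B, 1, L] broadcasts over the query dim) ...
key_mask = ~valid.unsqueeze(1)
prob = F.softmax(score.masked_fill(key_mask, -1e9), dim=-1)
# ... then zero the rows that belong to padded queries ([B, L, 1] broadcasts over keys).
prob = prob * valid.unsqueeze(2)
print(prob[0])
```

For reference, `torch.nn.MultiheadAttention` exposes a similar key-only mask through its `key_padding_mask` argument, of shape [batch, src_len], where True marks keys to ignore.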