Swin-transformer --relative positional Bias
2022-06-30 17:35:00 【GY-赵】

The paper's description of this part is not very clear, so I am recording my learning process here.
This blog post explains it very well; please refer to it:
https://blog.csdn.net/qq_37541097/article/details/121119988
Below is a demo worked through in code.
1. Assume the window's H and W are both 2, and first construct a 2-D coordinate grid
import torch

x = torch.arange(2)
y = torch.arange(2)
# Input: 1-D sequences; output: two 2-D grids. Commonly used to generate coordinates.
ox, oy = torch.meshgrid([x, y])
# Stack along a new dimension; the input shapes must match; defaults to dim 0.
o2 = torch.stack((ox, oy))
print(ox, oy)
print(o2, o2.shape)
coords = torch.flatten(o2, 1)
print(coords, coords.shape)
Output

tensor([[0, 0],
        [1, 1]]) tensor([[0, 1],
        [0, 1]])
tensor([[[0, 0],
         [1, 1]],

        [[0, 1],
         [0, 1]]]) torch.Size([2, 2, 2])
# Two rows, giving the x- and y-axis coordinates of each token
tensor([[0, 0, 1, 1],
        [0, 1, 0, 1]])
torch.Size([2, 4])
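As a sanity check on this flattened layout, column k of `coords` is exactly the (row, col) position of the k-th token when the window is scanned row by row. A small sketch rebuilding step 1:

```python
import torch

# Rebuild step 1's flattened coordinates for a 2x2 window.
x = torch.arange(2)
y = torch.arange(2)
ox, oy = torch.meshgrid([x, y])            # two 2-D grids in row-major ('ij') order
coords = torch.stack((ox, oy)).flatten(1)  # shape (2, 4)

# Column k of coords holds (row, col) of the k-th token in row-major order.
for k in range(4):
    assert coords[0, k].item() == k // 2   # row index of token k
    assert coords[1, k].item() == k % 2    # column index of token k
```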
When computing the relative coordinate index, the code uses a dimension-expansion trick I had not seen before; it is concise and efficient.
print(coords[:, :, None].shape)  # equivalent to adding a dimension
print(coords[:, None, :], coords[:, None, :].shape)
print(coords[:, None, :, None].shape)
# Same effect as unsqueeze()
coords.unsqueeze(1) == coords[:, None, :]
Output

torch.Size([2, 4, 1])
tensor([[[0, 0, 1, 1]],

        [[0, 1, 0, 1]]]) torch.Size([2, 1, 4])
torch.Size([2, 1, 4, 1])
tensor([[[True, True, True, True]],

        [[True, True, True, True]]])
print(coords[:, :, None])  # equivalent to adding a dimension
print(coords[:, None, :])
Output

tensor([[[0],
         [0],
         [1],
         [1]],

        [[0],
         [1],
         [0],
         [1]]])
tensor([[[0, 0, 1, 1]],

        [[0, 1, 0, 1]]])
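The None-indexing trick is easiest to see in 1-D: inserting a size-1 axis and letting broadcasting expand it pairs every element with every other element. A minimal sketch:

```python
import torch

# a[:, None] has shape (3, 1) and a[None, :] has shape (1, 3); broadcasting
# expands both to (3, 3), so diff[i, j] == a[i] - a[j].
a = torch.tensor([0, 1, 2])
diff = a[:, None] - a[None, :]
```

This is exactly the pattern used below to get all pairwise coordinate differences in one subtraction.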
2. Compute the relative index
relative_coords = coords[:, :, None] - coords[:, None, :]  # (2,4,1) - (2,1,4), subtracted via broadcasting
print(f"relative_coords: {relative_coords.shape} = {coords[:, :, None].shape} - {coords[:, None, :].shape}", "\n", relative_coords)
Output

# The subtraction relies on broadcasting: both operands are expanded to the same shape, then subtracted element-wise
relative_coords: torch.Size([2, 4, 4]) = torch.Size([2, 4, 1]) - torch.Size([2, 1, 4])
tensor([[[ 0,  0, -1, -1],
         [ 0,  0, -1, -1],
         [ 1,  1,  0,  0],
         [ 1,  1,  0,  0]],

        [[ 0, -1,  0, -1],
         [ 1,  0,  1,  0],
         [ 0, -1,  0, -1],
         [ 1,  0,  1,  0]]])
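To confirm the broadcast did what we expect, every entry of `relative_coords` is a pairwise difference of columns of `coords`:

```python
import torch

coords = torch.tensor([[0, 0, 1, 1],
                       [0, 1, 0, 1]])
relative_coords = coords[:, :, None] - coords[:, None, :]  # (2, 4, 4)

# relative_coords[:, i, j] is the coordinate of token i minus that of token j
for i in range(4):
    for j in range(4):
        assert torch.equal(relative_coords[:, i, j], coords[:, i] - coords[:, j])
```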
Permute to [4, 4, 2]: for each of the 4 tokens we get 4 coordinate pairs, with the row offset first and the column offset second.
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
print(relative_coords.shape)
print(relative_coords)
Output

torch.Size([4, 4, 2])
tensor([[[ 0,  0],
         [ 0, -1],
         [-1,  0],
         [-1, -1]],

        [[ 0,  1],
         [ 0,  0],
         [-1,  1],
         [-1,  0]],

        [[ 1,  0],
         [ 1, -1],
         [ 0,  0],
         [ 0, -1]],

        [[ 1,  1],
         [ 1,  0],
         [ 0,  1],
         [ 0,  0]]])
print(relative_coords[:, :, 0])  # the first element of every pair: all row offsets
print(relative_coords[:, :, 1])  # the second element of every pair: all column offsets
Output

tensor([[ 0,  0, -1, -1],
        [ 0,  0, -1, -1],
        [ 1,  1,  0,  0],
        [ 1,  1,  0,  0]])
tensor([[ 0, -1,  0, -1],
        [ 1,  0,  1,  0],
        [ 0, -1,  0, -1],
        [ 1,  0,  1,  0]])
window_size = (2, 2)
# Add M-1 to both the row and column offsets; here M = 2
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
print(relative_coords)
relative_coords[:, :, 1] += window_size[1] - 1
print(relative_coords)
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
print(relative_coords)
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
print(relative_position_index)
Output

# Add M-1 to the first channel (the row offsets)
tensor([[[ 1,  0],
         [ 1, -1],
         [ 0,  0],
         [ 0, -1]],

        [[ 1,  1],
         [ 1,  0],
         [ 0,  1],
         [ 0,  0]],

        [[ 2,  0],
         [ 2, -1],
         [ 1,  0],
         [ 1, -1]],

        [[ 2,  1],
         [ 2,  0],
         [ 1,  1],
         [ 1,  0]]])
# Then add M-1 to the second channel (the column offsets)
tensor([[[1, 1],
         [1, 0],
         [0, 1],
         [0, 0]],

        [[1, 2],
         [1, 1],
         [0, 2],
         [0, 1]],

        [[2, 1],
         [2, 0],
         [1, 1],
         [1, 0]],

        [[2, 2],
         [2, 1],
         [1, 2],
         [1, 1]]])
# Multiply the first channel (the row offsets) by 2M-1 = 3
tensor([[[3, 1],
         [3, 0],
         [0, 1],
         [0, 0]],

        [[3, 2],
         [3, 1],
         [0, 2],
         [0, 1]],

        [[6, 1],
         [6, 0],
         [3, 1],
         [3, 0]],

        [[6, 2],
         [6, 1],
         [3, 2],
         [3, 1]]])
# Sum the two channels
tensor([[4, 3, 1, 0],
        [5, 4, 2, 1],
        [7, 6, 4, 3],
        [8, 7, 5, 4]])
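The shift-then-scale steps are just a row-major encoding: each 2-D offset (dy, dx) with dy, dx in [-(M-1), M-1] maps to the unique scalar index (dy + M-1) * (2M-1) + (dx + M-1). A quick check in plain Python:

```python
# For M = 2 the (2M-1)^2 = 9 possible relative offsets map onto
# exactly the indices 0..8, with no collisions.
M = 2
indices = [(dy + M - 1) * (2 * M - 1) + (dx + M - 1)
           for dy in range(-(M - 1), M)
           for dx in range(-(M - 1), M)]
assert sorted(indices) == list(range((2 * M - 1) ** 2))  # a bijection
```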
This yields the relative position index matrix. Each index is used to look a value up in the relative positional bias table: the program defines this learnable table up front with length (2M-1)*(2M-1). Here M = 2, so the table has 9 entries, which exactly matches the indices 0-8 above.
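The whole walkthrough can be wrapped in one helper following the steps above (the function name is mine, not from the Swin code):

```python
import torch

def relative_position_index(window_size):
    """Relative position index matrix of shape (Wh*Ww, Wh*Ww)."""
    Wh, Ww = window_size
    ys, xs = torch.meshgrid([torch.arange(Wh), torch.arange(Ww)])
    coords = torch.stack((ys, xs)).flatten(1)      # (2, Wh*Ww)
    rel = coords[:, :, None] - coords[:, None, :]  # (2, N, N) pairwise offsets
    rel = rel.permute(1, 2, 0).contiguous()        # (N, N, 2)
    rel[:, :, 0] += Wh - 1                         # shift row offsets to start from 0
    rel[:, :, 1] += Ww - 1                         # shift col offsets to start from 0
    rel[:, :, 0] *= 2 * Ww - 1                     # row-major encoding
    return rel.sum(-1)                             # (N, N)
```

For `window_size=(2, 2)` this reproduces the 4x4 index matrix above.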
# define a parameter table of relative position bias
# Build the learnable relative position bias table of shape ((2H-1)*(2W-1), num_heads)
self.relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # (2*Wh-1) * (2*Ww-1), nH
Here we assume two attention heads:
import torch
from torch import nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), 2))  # (2*Wh-1) * (2*Ww-1), nH; assume 2 attention heads
print(relative_position_bias_table.shape, "\n", relative_position_bias_table)
trunc_normal_(relative_position_bias_table, std=.02)  # initialize the bias table
Output

torch.Size([9, 2])  # two attention heads, (2M-1)*(2M-1) = 9 entries per head
Parameter containing:
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]], requires_grad=True)
Parameter containing:  # the table after initialization
tensor([[-0.0340,  0.0181],
        [-0.0033, -0.0055],
        [ 0.0045,  0.0193],
        [ 0.0412, -0.0031],
        [ 0.0004, -0.0032],
        [ 0.0201, -0.0161],
        [ 0.0067,  0.0079],
        [ 0.0241, -0.0279],
        [-0.0125, -0.0291]], requires_grad=True)
relative_position_bias = relative_position_bias_table[relative_position_index.view(-1)]
print("index:\n", relative_position_index.view(-1).shape, "\n", relative_position_index.view(-1))
print("bias table values gathered by index:\n", relative_position_bias.shape, "\n", relative_position_bias)
relative_position_bias = relative_position_bias.view(
    window_size[0] * window_size[1], window_size[0] * window_size[1], -1)  # Wh*Ww, Wh*Ww, nH
print("after reshaping:\n", relative_position_bias.shape, "\n", relative_position_bias)
# Permute to match the attention map's shape
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
index:
torch.Size([16])
tensor([4, 3, 1, 0, 5, 4, 2, 1, 7, 6, 4, 3, 8, 7, 5, 4])  # the index matrix flattened to 1-D
bias table values gathered by index:
torch.Size([16, 2])
tensor([[ 0.0004, -0.0032],
        [ 0.0412, -0.0031],
        [-0.0033, -0.0055],
        [-0.0340,  0.0181],
        [ 0.0201, -0.0161],
        [ 0.0004, -0.0032],
        [ 0.0045,  0.0193],
        [-0.0033, -0.0055],
        [ 0.0241, -0.0279],
        [ 0.0067,  0.0079],
        [ 0.0004, -0.0032],
        [ 0.0412, -0.0031],
        [-0.0125, -0.0291],
        [ 0.0241, -0.0279],
        [ 0.0201, -0.0161],
        [ 0.0004, -0.0032]], grad_fn=<IndexBackward>)
after reshaping:
torch.Size([4, 4, 2])
tensor([[[ 0.0004, -0.0032],
         [ 0.0412, -0.0031],
         [-0.0033, -0.0055],
         [-0.0340,  0.0181]],

        [[ 0.0201, -0.0161],
         [ 0.0004, -0.0032],
         [ 0.0045,  0.0193],
         [-0.0033, -0.0055]],

        [[ 0.0241, -0.0279],
         [ 0.0067,  0.0079],
         [ 0.0004, -0.0032],
         [ 0.0412, -0.0031]],

        [[-0.0125, -0.0291],
         [ 0.0241, -0.0279],
         [ 0.0201, -0.0161],
         [ 0.0004, -0.0032]]], grad_fn=<ViewBackward>)
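For context, in Swin's window attention the permuted (nH, N, N) bias is added to the raw attention scores before the softmax. A sketch with random stand-in tensors (the tensor values are placeholders; the shapes follow the demo above):

```python
import torch

B_, nH, N = 1, 2, 4                             # windows, heads, tokens per window
attn = torch.randn(B_, nH, N, N)                # stand-in for q @ k^T / sqrt(d) scores
relative_position_bias = torch.randn(nH, N, N)  # stand-in for the gathered (nH, N, N) bias

# unsqueeze(0) lets the per-head bias broadcast over every window in the batch
attn = attn + relative_position_bias.unsqueeze(0)
attn = attn.softmax(dim=-1)
```

The same learned bias is shared by all windows, since the relative offsets inside every window are identical.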

The code above is everything there is to the relative positional bias.