Swin-Transformer -- Relative Positional Bias
2022-06-30 17:35:00 【GY-赵】

The paper's description of this part is not very clear, so I am recording my learning process here.
This blog post explains it very clearly; please read it for reference:
https://blog.csdn.net/qq_37541097/article/details/121119988
Below is a demo in code form.
1. Assume the window's H and W are both 2. First, construct the 2D coordinates of each position inside the window.
import torch
x = torch.arange(2)
y = torch.arange(2)
# meshgrid: takes 1-D sequences, returns 2-D grids; commonly used to generate coordinates
ox, oy = torch.meshgrid([x, y])
# stack: concatenates along a new dimension (default dim=0); inputs must have identical shapes
o2 = torch.stack((ox, oy))
print(ox, oy)
print(o2, o2.shape)
coords = torch.flatten(o2, 1)
print(coords, coords.shape)
Output:
tensor([[0, 0],
[1, 1]]) tensor([[0, 1],
[0, 1]])
tensor([[[0, 0],
[1, 1]],
[[0, 1],
[0, 1]]]) torch.Size([2, 2, 2])
# Two rows: the x-axis and y-axis coordinates of the 4 positions
tensor([[0, 0, 1, 1],
[0, 1, 0, 1]])
torch.Size([2, 4])
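For reference, here is the same construction generalized to an arbitrary window size. This is a sketch of mine; build_coords is a hypothetical helper, not a name from the Swin code:

# Generalized version of the steps above for any window size (Wh, Ww).
def build_coords(Wh, Ww):
    coords_h = torch.arange(Wh)
    coords_w = torch.arange(Ww)
    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
    return torch.flatten(coords, 1)                             # 2, Wh*Ww

print(build_coords(2, 2))  # reproduces the coords above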
When computing the relative coordinate indices, the code uses a dimension-expansion trick I had not seen before (indexing with None); it is concise and efficient.
print(coords[:,:,None].shape)  # equivalent to adding a trailing dimension
print(coords[:,None,:], coords[:,None,:].shape)
print(coords[:,None,:,None].shape)
# Indexing with None has the same effect as unsqueeze()
coords.unsqueeze(1) == coords[:,None,:]
Output:
torch.Size([2, 4, 1])
tensor([[[0, 0, 1, 1]],
[[0, 1, 0, 1]]])
torch.Size([2, 1, 4])
torch.Size([2, 1, 4, 1])
tensor([[[True, True, True, True]],
[[True, True, True, True]]])
print(coords[:,:,None])  # equivalent to adding a trailing dimension
print(coords[:,None,:])
Output:
tensor([[[0],
[0],
[1],
[1]],
[[0],
[1],
[0],
[1]]])
tensor([[[0, 0, 1, 1]],
[[0, 1, 0, 1]]])
2. Compute the relative position indices
relative_coords = coords[:,:,None] - coords[:,None,:]  # (2,4,1)-(2,1,4), broadcast subtraction
print(f"relative_coords:{relative_coords.shape}={coords[:,:,None].shape}-{coords[:,None,:].shape}",
      "\n", relative_coords)
Output:
# The subtraction relies on broadcasting: both operands are virtually expanded to the same shape, then subtracted element-wise.
relative_coords:torch.Size([2, 4, 4])=torch.Size([2, 4, 1])-torch.Size([2, 1, 4])
tensor([[[ 0, 0, -1, -1],
[ 0, 0, -1, -1],
[ 1, 1, 0, 0],
[ 1, 1, 0, 0]],
[[ 0, -1, 0, -1],
[ 1, 0, 1, 0],
[ 0, -1, 0, -1],
[ 1, 0, 1, 0]]])
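To convince yourself that broadcasting does what the comment says, you can expand both operands to (2, 4, 4) explicitly and subtract; the result is identical. A quick check of mine, not from the original post:

# Broadcasting is equivalent to expanding both sides first, then subtracting element-wise.
lhs = coords[:, :, None].expand(2, 4, 4)
rhs = coords[:, None, :].expand(2, 4, 4)
assert torch.equal(lhs - rhs, relative_coords)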
Permute to [4, 4, 2]: each of the 4 query positions now has a 4x2 block of relative coordinates against all 4 key positions, with column 0 holding the row offsets and column 1 the column offsets.
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
print(relative_coords.shape)
print(relative_coords)
Output:
torch.Size([4, 4, 2])
tensor([[[ 0, 0],
[ 0, -1],
[-1, 0],
[-1, -1]],
[[ 0, 1],
[ 0, 0],
[-1, 1],
[-1, 0]],
[[ 1, 0],
[ 1, -1],
[ 0, 0],
[ 0, -1]],
[[ 1, 1],
[ 1, 0],
[ 0, 1],
[ 0, 0]]])
print(relative_coords[:,:,0])  # channel 0: row offsets (the first row of the pre-permute tensor)
print(relative_coords[:,:,1])  # channel 1: column offsets (the second row of the pre-permute tensor)
Output:
tensor([[ 0, 0, -1, -1],
[ 0, 0, -1, -1],
[ 1, 1, 0, 0],
[ 1, 1, 0, 0]])
tensor([[ 0, -1, 0, -1],
[ 1, 0, 1, 0],
[ 0, -1, 0, -1],
[ 1, 0, 1, 0]])
window_size = (2, 2)
# Shift both row and column offsets by M-1 so they start from 0; here M = 2
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
print(relative_coords)
relative_coords[:, :, 1] += window_size[1] - 1
print(relative_coords)
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
print(relative_coords)
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
print(relative_position_index)
Output:
# Channel 0 (row offsets) plus M-1
tensor([[[ 1, 0],
[ 1, -1],
[ 0, 0],
[ 0, -1]],
[[ 1, 1],
[ 1, 0],
[ 0, 1],
[ 0, 0]],
[[ 2, 0],
[ 2, -1],
[ 1, 0],
[ 1, -1]],
[[ 2, 1],
[ 2, 0],
[ 1, 1],
[ 1, 0]]])
# Then channel 1 (column offsets) plus M-1
tensor([[[1, 1],
[1, 0],
[0, 1],
[0, 0]],
[[1, 2],
[1, 1],
[0, 2],
[0, 1]],
[[2, 1],
[2, 0],
[1, 1],
[1, 0]],
[[2, 2],
[2, 1],
[1, 2],
[1, 1]]])
# Channel 0 (row offsets) times 2M-1 = 3
tensor([[[3, 1],
[3, 0],
[0, 1],
[0, 0]],
[[3, 2],
[3, 1],
[0, 2],
[0, 1]],
[[6, 1],
[6, 0],
[3, 1],
[3, 0]],
[[6, 2],
[6, 1],
[3, 2],
[3, 1]]])
# Sum the two channels to get the final indices
tensor([[4, 3, 1, 0],
[5, 4, 2, 1],
[7, 6, 4, 3],
[8, 7, 5, 4]])
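The whole shift-multiply-sum pipeline collapses to a closed-form formula: index(i, j) = (dh + M-1) * (2M-1) + (dw + M-1), where (dh, dw) is the coordinate difference between positions i and j. For example, (dh, dw) = (1, 1) gives (1+1)*3 + (1+1) = 8, the bottom-left entry above. A quick sanity check of mine, not from the original post:

# Verify the closed-form formula against relative_position_index.
M = 2
for i in range(M * M):
    for j in range(M * M):
        dh = coords[0, i] - coords[0, j]  # row offset between query i and key j
        dw = coords[1, i] - coords[1, j]  # column offset
        assert relative_position_index[i, j] == (dh + M - 1) * (2 * M - 1) + (dw + M - 1)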
These are the relative position indices. Each index is used to look up a value in the relative position bias table: a learnable table defined at construction time with (2M-1)*(2M-1) entries. Here M = 2, so the table has 9 entries, matching the indices 0-8 above exactly.
# define a parameter table of relative position bias
# Build the learnable relative position bias table, of length (2*Wh-1)*(2*Ww-1), with one column per head
self.relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
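In the official implementation, the index matrix computed above is likewise built once in __init__ and saved as a non-learnable buffer, roughly:

# Stored once at construction time; not a learnable parameter
self.register_buffer("relative_position_index", relative_position_index)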
Here we assume two attention heads.
from torch import nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), 2))  # 2*Wh-1 * 2*Ww-1, nH; assume 2 attention heads
print(relative_position_bias_table.shape, "\n", relative_position_bias_table)
trunc_normal_(relative_position_bias_table, std=.02)  # initialize the bias table
Output:
torch.Size([9, 2])  # two attention heads, (2M-1)*(2M-1) values per head
Parameter containing:
tensor([[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.]], requires_grad=True)
Parameter containing:  # values after initialization
tensor([[-0.0340, 0.0181],
[-0.0033, -0.0055],
[ 0.0045, 0.0193],
[ 0.0412, -0.0031],
[ 0.0004, -0.0032],
[ 0.0201, -0.0161],
[ 0.0067, 0.0079],
[ 0.0241, -0.0279],
[-0.0125, -0.0291]], requires_grad=True)
relative_position_bias = relative_position_bias_table[relative_position_index.view(-1)]
print("index:\n", relative_position_index.view(-1).shape, "\n", relative_position_index.view(-1))
print("bias table values gathered by index:\n", relative_position_bias.shape, "\n", relative_position_bias)
relative_position_bias = relative_position_bias.view(
    window_size[0] * window_size[1], window_size[0] * window_size[1], -1)  # Wh*Ww, Wh*Ww, nH
print("after reshape:\n", relative_position_bias.shape, "\n", relative_position_bias)
# Permute so the shape matches the attention map: nH, Wh*Ww, Wh*Ww
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
Output:
index:
torch.Size([16])
tensor([4, 3, 1, 0, 5, 4, 2, 1, 7, 6, 4, 3, 8, 7, 5, 4])  # indices flattened to 1-D
bias table values gathered by index:
torch.Size([16, 2])
tensor([[ 0.0004, -0.0032],
[ 0.0412, -0.0031],
[-0.0033, -0.0055],
[-0.0340, 0.0181],
[ 0.0201, -0.0161],
[ 0.0004, -0.0032],
[ 0.0045, 0.0193],
[-0.0033, -0.0055],
[ 0.0241, -0.0279],
[ 0.0067, 0.0079],
[ 0.0004, -0.0032],
[ 0.0412, -0.0031],
[-0.0125, -0.0291],
[ 0.0241, -0.0279],
[ 0.0201, -0.0161],
[ 0.0004, -0.0032]], grad_fn=<IndexBackward>)
after reshape:
torch.Size([4, 4, 2])
tensor([[[ 0.0004, -0.0032],
[ 0.0412, -0.0031],
[-0.0033, -0.0055],
[-0.0340, 0.0181]],
[[ 0.0201, -0.0161],
[ 0.0004, -0.0032],
[ 0.0045, 0.0193],
[-0.0033, -0.0055]],
[[ 0.0241, -0.0279],
[ 0.0067, 0.0079],
[ 0.0004, -0.0032],
[ 0.0412, -0.0031]],
[[-0.0125, -0.0291],
[ 0.0241, -0.0279],
[ 0.0201, -0.0161],
[ 0.0004, -0.0032]]], grad_fn=<ViewBackward>)
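The only step left, which this demo stops short of, is adding the bias to the attention logits inside WindowAttention's forward pass. A minimal sketch of that step, with random q and k standing in for the real projections (shapes match the demo: 2 heads, 4 tokens per window; head_dim is my assumption):

# How the bias enters attention; this mirrors the official Swin forward pass.
B_, nH, N, d = 1, 2, 4, 8  # one window, 2 heads, Wh*Ww = 4 tokens, head_dim = 8
q = torch.randn(B_, nH, N, d)
k = torch.randn(B_, nH, N, d)
attn = (q * d ** -0.5) @ k.transpose(-2, -1)       # B_, nH, N, N
attn = attn + relative_position_bias.unsqueeze(0)  # broadcast the bias over windows
attn = attn.softmax(dim=-1)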

That covers all of the code related to the relative position bias.