Swin-transformer --relative positional Bias
2022-06-30 19:06:00 【Gy Zhao】

The paper's description of this part is not very clear, so I'm recording my learning process here.
This blog post explains it very clearly; please refer to it:
https://blog.csdn.net/qq_37541097/article/details/121119988
The following walks through the computation as a code demo.
1. Suppose the window's H and W are both 2. First, construct a two-dimensional coordinate grid.
import torch
x = torch.arange(2)
y = torch.arange(2)
# Input: one-dimensional sequences; output: two 2D grids, commonly used to generate coordinates
ox, oy = torch.meshgrid([x, y])
# Stack along a given dimension; the input tensors' shapes must match, default is dim 0
o2 = torch.stack((ox, oy))
print(ox, oy)
print(o2, o2.shape)
coords = torch.flatten(o2, 1)
print(coords, coords.shape)
Output
tensor([[0, 0],
[1, 1]]) tensor([[0, 1],
[0, 1]])
tensor([[[0, 0],
[1, 1]],
[[0, 1],
[0, 1]]]) torch.Size([2, 2, 2])
# The flattened result has 2 rows, corresponding to the x- and y-axis coordinates
tensor([[0, 0, 1, 1],
[0, 1, 0, 1]])
torch.Size([2, 4])
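As a sanity check (my own addition, not part of the original demo), the flattened coordinates can also be built directly: token i in row-major order sits at row i // W, column i % W.

```python
import torch

H, W = 2, 2
# Build the flattened (2, H*W) coordinate tensor directly from the row-major token index
rows = torch.tensor([i // W for i in range(H * W)])
cols = torch.tensor([i % W for i in range(H * W)])
coords_direct = torch.stack((rows, cols))
print(coords_direct)
# Matches the meshgrid + stack + flatten result: [[0, 0, 1, 1], [0, 1, 0, 1]]
```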
When computing the relative coordinate index, the code uses a concise and efficient dimension-expansion trick I had not seen before: indexing with None.
print(coords[:,:,None].shape)  # equivalent to adding a dimension
print(coords[:,None,:], coords[:,None,:].shape)
print(coords[:,None,:,None].shape)
# Functionally identical to unsqueeze()
print(coords.unsqueeze(1) == coords[:,None,:])
Output
torch.Size([2, 4, 1])
tensor([[[0, 0, 1, 1]],
[[0, 1, 0, 1]]])
torch.Size([2, 1, 4])
torch.Size([2, 1, 4, 1])
tensor([[[True, True, True, True]],
[[True, True, True, True]]])
print(coords[:,:,None]) # It is equivalent to adding a dimension
print(coords[:,None,:])
Output
tensor([[[0],
[0],
[1],
[1]],
[[0],
[1],
[0],
[1]]])
tensor([[[0, 0, 1, 1]],
[[0, 1, 0, 1]]])
2. Compute the relative position index
relative_coords = coords[:,:,None] - coords[:,None,:]  # (2,4,1) - (2,1,4): broadcasting subtraction
print(f"relative_coords: {relative_coords.shape} = {coords[:,:,None].shape} - {coords[:,None,:].shape}")
print(relative_coords)
Output
# The subtraction relies on broadcasting: both operands are expanded to the same shape, then subtracted element-wise
relative_coords: torch.Size([2, 4, 4]) = torch.Size([2, 4, 1]) - torch.Size([2, 1, 4])
tensor([[[ 0,  0, -1, -1],
[ 0,  0, -1, -1],
[ 1,  1,  0,  0],
[ 1,  1,  0,  0]],
[[ 0, -1,  0, -1],
[ 1,  0,  1,  0],
[ 0, -1,  0, -1],
[ 1,  0,  1,  0]]])
Permute to shape [4, 4, 2]: for each of the 4 query positions we get 4 coordinate pairs (row offset, column offset).
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
print(relative_coords.shape)
print(relative_coords)
Output
torch.Size([4, 4, 2])
tensor([[[ 0, 0],
[ 0, -1],
[-1, 0],
[-1, -1]],
[[ 0, 1],
[ 0, 0],
[-1, 1],
[-1, 0]],
[[ 1, 0],
[ 1, -1],
[ 0, 0],
[ 0, -1]],
[[ 1, 1],
[ 1, 0],
[ 0, 1],
[ 0, 0]]])
print(relative_coords[:,:,0])  # the first element of every coordinate pair: the row offsets
print(relative_coords[:,:,1])  # the second element of every pair: the column offsets
Output
tensor([[ 0, 0, -1, -1],
[ 0, 0, -1, -1],
[ 1, 1, 0, 0],
[ 1, 1, 0, 0]])
tensor([[ 0, -1, 0, -1],
[ 1, 0, 1, 0],
[ 0, -1, 0, -1],
[ 1, 0, 1, 0]])
window_size = (2, 2)
# Add M - 1 to the row and column offsets; here M = 2
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
print(relative_coords)
relative_coords[:, :, 1] += window_size[1] - 1
print(relative_coords)
# Multiply the row offsets by 2M - 1 so every (row, col) pair maps to a unique index
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
print(relative_coords)
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
print(relative_position_index)
Output
# after adding M - 1 to the row offsets
tensor([[[ 1, 0],
[ 1, -1],
[ 0, 0],
[ 0, -1]],
[[ 1, 1],
[ 1, 0],
[ 0, 1],
[ 0, 0]],
[[ 2, 0],
[ 2, -1],
[ 1, 0],
[ 1, -1]],
[[ 2, 1],
[ 2, 0],
[ 1, 1],
[ 1, 0]]])
# after adding M - 1 to the column offsets
tensor([[[1, 1],
[1, 0],
[0, 1],
[0, 0]],
[[1, 2],
[1, 1],
[0, 2],
[0, 1]],
[[2, 1],
[2, 0],
[1, 1],
[1, 0]],
[[2, 2],
[2, 1],
[1, 2],
[1, 1]]])
# after multiplying the row offsets by 2M - 1 (= 3)
tensor([[[3, 1],
[3, 0],
[0, 1],
[0, 0]],
[[3, 2],
[3, 1],
[0, 2],
[0, 1]],
[[6, 1],
[6, 0],
[3, 1],
[3, 0]],
[[6, 2],
[6, 1],
[3, 2],
[3, 1]]])
# after summing row and column offsets: the relative position index
tensor([[4, 3, 1, 0],
[5, 4, 2, 1],
[7, 6, 4, 3],
[8, 7, 5, 4]])
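A quick check (my own addition, not from the original post) confirms why the shift-and-multiply trick works: it maps each distinct 2D offset in [-(M-1), M-1] x [-(M-1), M-1] to a unique index in [0, (2M-1)^2 - 1].

```python
import torch

M = 2  # window size
x = torch.arange(M)
coords = torch.flatten(torch.stack(torch.meshgrid([x, x])), 1)
rel = (coords[:, :, None] - coords[:, None, :]).permute(1, 2, 0)
idx = (rel[:, :, 0] + M - 1) * (2 * M - 1) + (rel[:, :, 1] + M - 1)

# Every index is in range, and equal indices always come from equal (row, col) offsets
assert idx.min() == 0 and idx.max() == (2 * M - 1) ** 2 - 1
offsets = {}
for i in range(M * M):
    for j in range(M * M):
        k = idx[i, j].item()
        pair = tuple(rel[i, j].tolist())
        assert offsets.setdefault(k, pair) == pair  # one offset per index
print(idx)
```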
This gives the relative position indices. The actual bias values are looked up from the relative positional bias table: at the start of the module a learnable table of length (2M-1)*(2M-1) is created. Here M = 2, so the table has 9 entries, matching the indices 0-8 above.
# define a parameter table of relative position bias
# A learnable table of shape ((2*Wh - 1) * (2*Ww - 1), num_heads)
self.relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
Suppose there are two attention heads:
from torch import nn
from timm.models.layers import trunc_normal_
relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), 2))  # 2*Wh-1 * 2*Ww-1, nH; assume two attention heads
print(relative_position_bias_table.shape, "\n", relative_position_bias_table)
trunc_normal_(relative_position_bias_table, std=.02)  # initialize the bias table
print(relative_position_bias_table)
Output
torch.Size([9, 2]) # two attention heads, (2M-1)*(2M-1) = 9 values per head
Parameter containing:
tensor([[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.]], requires_grad=True)
Parameter containing: # data after trunc_normal_ initialization
tensor([[-0.0340, 0.0181],
[-0.0033, -0.0055],
[ 0.0045, 0.0193],
[ 0.0412, -0.0031],
[ 0.0004, -0.0032],
[ 0.0201, -0.0161],
[ 0.0067, 0.0079],
[ 0.0241, -0.0279],
[-0.0125, -0.0291]], requires_grad=True)
relative_position_bias = relative_position_bias_table[relative_position_index.view(-1)]
print("index:\n", relative_position_index.view(-1).shape, "\n", relative_position_index.view(-1))
print("values gathered from the bias table by index:\n", relative_position_bias.shape, "\n", relative_position_bias)
relative_position_bias = relative_position_bias.view(window_size[0] * window_size[1], window_size[0] * window_size[1], -1)  # Wh*Ww, Wh*Ww, nH
print("after reshaping:\n", relative_position_bias.shape, "\n", relative_position_bias)
# Permute so the shape matches the attention map: (nH, Wh*Ww, Wh*Ww)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
index:
torch.Size([16])
tensor([4, 3, 1, 0, 5, 4, 2, 1, 7, 6, 4, 3, 8, 7, 5, 4])  # the index flattened to one dimension
values gathered from the bias table by index:
torch.Size([16, 2])
tensor([[ 0.0004, -0.0032],
[ 0.0412, -0.0031],
[-0.0033, -0.0055],
[-0.0340, 0.0181],
[ 0.0201, -0.0161],
[ 0.0004, -0.0032],
[ 0.0045, 0.0193],
[-0.0033, -0.0055],
[ 0.0241, -0.0279],
[ 0.0067, 0.0079],
[ 0.0004, -0.0032],
[ 0.0412, -0.0031],
[-0.0125, -0.0291],
[ 0.0241, -0.0279],
[ 0.0201, -0.0161],
[ 0.0004, -0.0032]], grad_fn=<IndexBackward>)
after reshaping:
torch.Size([4, 4, 2])
tensor([[[ 0.0004, -0.0032],
[ 0.0412, -0.0031],
[-0.0033, -0.0055],
[-0.0340, 0.0181]],
[[ 0.0201, -0.0161],
[ 0.0004, -0.0032],
[ 0.0045, 0.0193],
[-0.0033, -0.0055]],
[[ 0.0241, -0.0279],
[ 0.0067, 0.0079],
[ 0.0004, -0.0032],
[ 0.0412, -0.0031]],
[[-0.0125, -0.0291],
[ 0.0241, -0.0279],
[ 0.0201, -0.0161],
[ 0.0004, -0.0032]]], grad_fn=<ViewBackward>)
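In WindowAttention the permuted bias is simply added to the attention logits before the softmax. Here is a minimal sketch of that step (shapes and names are illustrative, not the full Swin implementation):

```python
import torch

num_heads, N = 2, 4      # N = Wh * Ww tokens per window
B_windows = 1            # number of windows in the batch
attn = torch.randn(B_windows, num_heads, N, N)         # stand-in for q @ k^T / sqrt(d)
relative_position_bias = torch.randn(num_heads, N, N)  # stand-in for the table lookup above
# Broadcast the (nH, N, N) bias over the window/batch dimension
attn = attn + relative_position_bias.unsqueeze(0)
attn = attn.softmax(dim=-1)
print(attn.shape)  # torch.Size([1, 2, 4, 4])
```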

The code above covers the full computation of the relative positional bias.
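Putting the steps together, here is a self-contained helper (my own consolidation of the demo above) that computes the relative position index for an arbitrary window size:

```python
import torch

def get_relative_position_index(window_size):
    """Compute the (Wh*Ww, Wh*Ww) relative position index for a window."""
    Wh, Ww = window_size
    coords_h = torch.arange(Wh)
    coords_w = torch.arange(Ww)
    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)                   # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()            # Wh*Ww, Wh*Ww, 2
    relative_coords[:, :, 0] += Wh - 1  # shift row offsets to start from 0
    relative_coords[:, :, 1] += Ww - 1  # shift column offsets to start from 0
    relative_coords[:, :, 0] *= 2 * Ww - 1
    return relative_coords.sum(-1)      # Wh*Ww, Wh*Ww

print(get_relative_position_index((2, 2)))
# Reproduces the 2x2 result derived step by step above
```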