Swin-Transformer -- Relative Positional Bias
2022-06-30 19:06:00 【Gy Zhao】

The description of this part in the paper is not very clear, so I am recording my learning process here.
This blog post explains it very clearly and is worth reading:
https://blog.csdn.net/qq_37541097/article/details/121119988
The following walks through the computation in the form of a code demo.
1. Assume the window's H and W are both 2. First, construct a 2D coordinate grid.
import torch

x = torch.arange(2)
y = torch.arange(2)
# input: two 1D sequences; output: two 2D grids, commonly used to generate coordinates
ox, oy = torch.meshgrid([x, y])
# stack along a given dimension (dim=0 by default); the inputs must have the same shape
o2 = torch.stack((ox, oy))
print(ox, oy)
print(o2, o2.shape)
# flatten everything after dim 0: [2, 2, 2] -> [2, 4]
coords = torch.flatten(o2, 1)
print(coords, coords.shape)
Output
tensor([[0, 0],
[1, 1]]) tensor([[0, 1],
[0, 1]])
tensor([[[0, 0],
[1, 1]],
[[0, 1],
[0, 1]]]) torch.Size([2, 2, 2])
# two rows of coordinates: the first row holds the x (row) coordinates, the second the y (column) coordinates
tensor([[0, 0, 1, 1],
[0, 1, 0, 1]])
torch.Size([2, 4])
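As an aside (my own minimal sketch, not part of the original post; the function name is made up), the same coordinate construction works for any window size. Note that on PyTorch 1.10+ torch.meshgrid accepts an indexing argument; drop it on older versions.
def build_window_coords(window_size):
    # window_size: a (Wh, Ww) tuple; returns a [2, Wh*Ww] tensor of (row, col) coordinates
    coords_h = torch.arange(window_size[0])
    coords_w = torch.arange(window_size[1])
    grid = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # [2, Wh, Ww]
    return torch.flatten(grid, 1)                                          # [2, Wh*Ww]

print(build_window_coords((2, 2)))         # reproduces the coords tensor above
print(build_window_coords((3, 3)).shape)   # torch.Size([2, 9])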
When calculating the relative coordinate index, the code uses a way of adding dimensions (indexing with None) that I had not seen before; it is concise and efficient.
print(coords[:, :, None].shape)   # indexing with None inserts a new dimension of size 1
print(coords[:, None, :], coords[:, None, :].shape)
print(coords[:, None, :, None].shape)
# this is equivalent to unsqueeze()
coords.unsqueeze(1) == coords[:, None, :]
Output
torch.Size([2, 4, 1])
tensor([[[0, 0, 1, 1]],
[[0, 1, 0, 1]]])
torch.Size([2, 1, 4])
torch.Size([2, 1, 4, 1])
tensor([[[True, True, True, True]],
[[True, True, True, True]]])
print(coords[:, :, None])   # shape [2, 4, 1]: each coordinate becomes a length-1 column
print(coords[:, None, :])   # shape [2, 1, 4]
Output
tensor([[[0],
[0],
[1],
[1]],
[[0],
[1],
[0],
[1]]])
tensor([[[0, 0, 1, 1]],
[[0, 1, 0, 1]]])
2. Calculate the relative position index
relative_coords = coords[:, :, None] - coords[:, None, :]   # (2,4,1) - (2,1,4): subtraction via broadcasting
print(f"relative_coords:{relative_coords.shape}={coords[:, :, None].shape}-{coords[:, None, :].shape}", "\n", relative_coords)
Output
# the broadcast mechanism expands both operands to the same shape, then subtracts element-wise
relative_coords:torch.Size([2, 4, 4])=torch.Size([2, 4, 1])-torch.Size([2, 1, 4])
tensor([[[ 0, 0, -1, -1],
[ 0, 0, -1, -1],
[ 1, 1, 0, 0],
[ 1, 1, 0, 0]],
[[ 0, -1, 0, -1],
[ 1, 0, 1, 0],
[ 0, -1, 0, -1],
[ 1, 0, 1, 0]]])
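The broadcasting used here is the same mechanism as in this tiny standalone example (purely illustrative, not from the original post): a column vector minus a row vector expands into a full matrix of pairwise differences.
a = torch.arange(4)                # tensor([0, 1, 2, 3])
diff = a[:, None] - a[None, :]     # [4, 1] - [1, 4] broadcasts to [4, 4]
print(diff)                        # entry (i, j) is a[i] - a[j]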
Permute to shape [4, 4, 2]: this is equivalent to getting 4 groups of 4x2 coordinate pairs, where the first column holds the row offsets and the second the column offsets.
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
print(relative_coords.shape)
print(relative_coords)
Output
torch.Size([4, 4, 2])
tensor([[[ 0, 0],
[ 0, -1],
[-1, 0],
[-1, -1]],
[[ 0, 1],
[ 0, 0],
[-1, 1],
[-1, 0]],
[[ 1, 0],
[ 1, -1],
[ 0, 0],
[ 0, -1]],
[[ 1, 1],
[ 1, 0],
[ 0, 1],
[ 0, 0]]])
print(relative_coords[:, :, 0])   # the row-offset (Δrow) component of every coordinate pair
print(relative_coords[:, :, 1])   # the column-offset (Δcol) component of every coordinate pair
Output
tensor([[ 0, 0, -1, -1],
[ 0, 0, -1, -1],
[ 1, 1, 0, 0],
[ 1, 1, 0, 0]])
tensor([[ 0, -1, 0, -1],
[ 1, 0, 1, 0],
[ 0, -1, 0, -1],
[ 1, 0, 1, 0]])
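Before the shift-and-scale step that follows, here is a quick check (my own illustration, not from the original post) of why it is needed: simply summing the raw (Δrow, Δcol) pairs would map different relative positions to the same value, whereas the shifted and scaled version gives each of the (2M-1)*(2M-1) = 9 possible offsets its own index.
raw_sum = relative_coords.sum(-1)   # sum the raw offsets without shifting/scaling
print(raw_sum.unique().numel())     # only 5 distinct values for 9 distinct offset pairs -> collisions
# e.g. (Δrow, Δcol) = (1, -1) and (0, 0) both sum to 0
# the mapping used below, (Δrow + M - 1) * (2M - 1) + (Δcol + M - 1), is collision-free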
window_size = (2, 2)
# add M-1 to the row and column offsets so they start from 0; here M = 2
relative_coords[:, :, 0] += window_size[0] - 1   # shift row offsets to start from 0
print(relative_coords)
relative_coords[:, :, 1] += window_size[1] - 1   # shift column offsets to start from 0
print(relative_coords)
relative_coords[:, :, 0] *= 2 * window_size[1] - 1   # scale row offsets by 2M-1 so the final sum is unique
print(relative_coords)
relative_position_index = relative_coords.sum(-1)    # Wh*Ww, Wh*Ww
print(relative_position_index)
Output
# after adding M-1 to the row offsets (first component)
tensor([[[ 1, 0],
[ 1, -1],
[ 0, 0],
[ 0, -1]],
[[ 1, 1],
[ 1, 0],
[ 0, 1],
[ 0, 0]],
[[ 2, 0],
[ 2, -1],
[ 1, 0],
[ 1, -1]],
[[ 2, 1],
[ 2, 0],
[ 1, 1],
[ 1, 0]]])
# after adding M-1 to the column offsets (second component)
tensor([[[1, 1],
[1, 0],
[0, 1],
[0, 0]],
[[1, 2],
[1, 1],
[0, 2],
[0, 1]],
[[2, 1],
[2, 0],
[1, 1],
[1, 0]],
[[2, 2],
[2, 1],
[1, 2],
[1, 1]]])
# after multiplying the row offsets by 2M-1 = 3
tensor([[[3, 1],
[3, 0],
[0, 1],
[0, 0]],
[[3, 2],
[3, 1],
[0, 2],
[0, 1]],
[[6, 1],
[6, 0],
[3, 1],
[3, 0]],
[[6, 2],
[6, 1],
[3, 2],
[3, 1]]])
# after summing the two components: the relative position index
tensor([[4, 3, 1, 0],
[5, 4, 2, 1],
[7, 6, 4, 3],
[8, 7, 5, 4]])
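Collecting the steps above into one helper (a sketch of what WindowAttention effectively does during initialization; the function name is mine, and the indexing argument requires PyTorch 1.10+):
def get_relative_position_index(window_size):
    # window_size: (Wh, Ww); returns a [Wh*Ww, Wh*Ww] tensor of indices into the bias table
    coords_h = torch.arange(window_size[0])
    coords_w = torch.arange(window_size[1])
    coords = torch.flatten(torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")), 1)
    rel = coords[:, :, None] - coords[:, None, :]   # [2, Wh*Ww, Wh*Ww]
    rel = rel.permute(1, 2, 0).contiguous()         # [Wh*Ww, Wh*Ww, 2]
    rel[:, :, 0] += window_size[0] - 1              # shift row offsets to start from 0
    rel[:, :, 1] += window_size[1] - 1              # shift column offsets to start from 0
    rel[:, :, 0] *= 2 * window_size[1] - 1          # scale row offsets so the sum is unique
    return rel.sum(-1)                              # [Wh*Ww, Wh*Ww]

print(get_relative_position_index((2, 2)))   # reproduces the 4x4 index matrix above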
This gives us the relative position index. To turn the index into actual bias values, it is used to look up a learnable relative position bias table that is defined at the start of the module. The table has (2M-1)*(2M-1) entries per attention head; here M = 2, so its length is 9, matching the index values 0-8 above.
# define a parameter table of relative position bias
# a learnable relative position bias table of length (2*Wh-1)*(2*Ww-1), with one column per head
self.relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # (2*Wh-1)*(2*Ww-1), nH
Suppose there are two attention heads:
from torch import nn
from timm.models.layers import trunc_normal_

relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), 2))  # (2*Wh-1)*(2*Ww-1), nH; assume two attention heads
print(relative_position_bias_table.shape, "\n", relative_position_bias_table)
trunc_normal_(relative_position_bias_table, std=.02)   # initialize the bias table with a truncated normal
print(relative_position_bias_table)                    # values after initialization
Output
torch.Size([9, 2]) # two attention heads, (2M-1)*(2M-1) = 9 entries per head
Parameter containing:
tensor([[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.]], requires_grad=True)
Parameter containing: # values after initialization
tensor([[-0.0340, 0.0181],
[-0.0033, -0.0055],
[ 0.0045, 0.0193],
[ 0.0412, -0.0031],
[ 0.0004, -0.0032],
[ 0.0201, -0.0161],
[ 0.0067, 0.0079],
[ 0.0241, -0.0279],
[-0.0125, -0.0291]], requires_grad=True)
relative_position_bias = relative_position_bias_table[relative_position_index.view(-1)]
print("index:\n", relative_position_index.view(-1).shape, "\n", relative_position_index.view(-1))
print("values gathered from the bias table by index:\n", relative_position_bias.shape, "\n", relative_position_bias)
relative_position_bias = relative_position_bias.view(window_size[0] * window_size[1], window_size[0] * window_size[1], -1)  # Wh*Ww, Wh*Ww, nH
print("after reshaping:\n", relative_position_bias.shape, "\n", relative_position_bias)
# permute so the shape agrees with the attention tensor: nH, Wh*Ww, Wh*Ww
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
index:
torch.Size([16])
tensor([4, 3, 1, 0, 5, 4, 2, 1, 7, 6, 4, 3, 8, 7, 5, 4])   # the index flattened to 1D
values gathered from the bias table by index:
torch.Size([16, 2])
tensor([[ 0.0004, -0.0032],
[ 0.0412, -0.0031],
[-0.0033, -0.0055],
[-0.0340, 0.0181],
[ 0.0201, -0.0161],
[ 0.0004, -0.0032],
[ 0.0045, 0.0193],
[-0.0033, -0.0055],
[ 0.0241, -0.0279],
[ 0.0067, 0.0079],
[ 0.0004, -0.0032],
[ 0.0412, -0.0031],
[-0.0125, -0.0291],
[ 0.0241, -0.0279],
[ 0.0201, -0.0161],
[ 0.0004, -0.0032]], grad_fn=<IndexBackward>)
after reshaping:
torch.Size([4, 4, 2])
tensor([[[ 0.0004, -0.0032],
[ 0.0412, -0.0031],
[-0.0033, -0.0055],
[-0.0340, 0.0181]],
[[ 0.0201, -0.0161],
[ 0.0004, -0.0032],
[ 0.0045, 0.0193],
[-0.0033, -0.0055]],
[[ 0.0241, -0.0279],
[ 0.0067, 0.0079],
[ 0.0004, -0.0032],
[ 0.0412, -0.0031]],
[[-0.0125, -0.0291],
[ 0.0241, -0.0279],
[ 0.0201, -0.0161],
[ 0.0004, -0.0032]]], grad_fn=<ViewBackward>)
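For context, the permuted bias of shape [nH, Wh*Ww, Wh*Ww] is what gets added to the attention logits inside WindowAttention. A minimal sketch of that step with made-up q and k tensors (the real module also applies the shifted-window mask and dropout):
num_heads, head_dim = 2, 32
N = window_size[0] * window_size[1]                  # tokens per window, here 4
q = torch.randn(1, num_heads, N, head_dim)           # (num_windows*B, nH, N, head_dim)
k = torch.randn(1, num_heads, N, head_dim)
attn = (q @ k.transpose(-2, -1)) * head_dim ** -0.5  # raw attention logits, [1, nH, N, N]
attn = attn + relative_position_bias.unsqueeze(0)    # broadcast the bias over the batch/window dim
attn = attn.softmax(dim=-1)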

All of the code above covers the computation of the relative position bias.