当前位置:网站首页>【SwinTransformer源码阅读二】Window Attention和Shifted Window Attention部分
【SwinTransformer源码阅读二】Window Attention和Shifted Window Attention部分
2022-07-28 08:25:00 【qq_29750461】
先放一下SwinTransformer的整体结构,图片源于原论文,可以发现,在Transformer的Block中 W-MSA(Window based multi-head self attention) 和 SW-MSA是关键组成部分。W-MSA出现在某阶段的奇数层,SW-MSA出现在某阶段的偶数层,W-MSA考虑的是单个窗口的信息,SW-MSA考虑的是不同窗口间的信息。

虽然从网络架构图里看,W-MSA和SW-MSA为两个不同的模块,但是在代码层面,两者是同一个代码片段,只是在计算SW-MSA时候,在计算完W-MSA后,然后通过代码进行滑动窗口,即cyclic shift操作,多计算了一个mask的操作。下面将针对代码进行分析。
W-MSA的代码
【注意】注释第一句话:Window based multi-head self attention (W-MSA) module with relative position bias.It supports both of shifted and non-shifted window.
代码注释中的中文,是以配置文件中 swin-tiny 相关的量 来进行注释的。
#窗口注意力
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias. It supports both of shifted and non-shifted window. Args: dim (int): Number of input channels. window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. Default: 0.0 """
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim#96*(2^layer_index 0,1,2,3...)
self.window_size = window_size # Wh, Ww (7,7)
self.num_heads = num_heads#[3, 6, 12, 24]
head_dim = dim // num_heads#(96//3=32,96*2^1 // 6=32,...)
self.scale = qk_scale or head_dim ** -0.5#default:head_dim ** -0.5
# define a parameter table of relative position bias
#定义相对位置偏置表格
#[(2*7-1)*(2*7-1),3]
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
#得到一对在窗口中的相对位置索引
coords_h = torch.arange(self.window_size[0])#[0,1,2,3,4,5,6]
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
#让相对坐标从0开始
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
#relative_coords[:, :, 0] * (2*7-1)
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
#为位置偏置表中索引值,位置偏移表(13*13,nHeads)索引0-168
#索引值为 (49,49) 值在0-168对应位置偏移表的索引
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
#dim*(dim*3)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
#attn_drop=0.0
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
#初始化相对位置偏置值表(截断正态分布)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
#模块的前向传播
def forward(self, x, mask=None):
""" Args: x: input features with shape of (num_windows*B, N, C) mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None """
B_, N, C = x.shape#输入特征的尺寸
#(3, B_, num_heads, N, C // num_heads)
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
# q/k/v: [B_, num_heads, N, C // num_heads]
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
# q*head_dim ** -0.5
q = q * self.scale
# attn:B_, num_heads,N,N
attn = (q @ k.transpose(-2, -1))
# 在 随机在relative_position_bias_table中的第一维(169)选择position_index对应的值,共49*49个
#由于relative_position_bias_table第二维为 nHeads所以最终表变为了 49*49*nHead 的随机表
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
#attn每一个批次,加上随机的相对位置偏移 说民attn.shape=B_,num_heads,Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
#mask 在某阶段的奇数层为None 偶数层才存在
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
#进行 dropout
attn = self.attn_drop(attn)
#attn @ v:B_, num_heads, N, C/num_heads
#x: B_, N, C 其中
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
#经过一层全连接
x = self.proj(x)
#进行drop out
x = self.proj_drop(x)
return x
关于W-MSA中的注意力机制的运算,其实就是按照下面这个公式来进行的,在这个公式里,其实 QKV 三者均是又输入经过一个全连接层(nn.Linear())得到的,这个在代码里很好看明白。关键是在W-MSA中,增加了一个位置偏移量 B,这里的B相关计算也是W-MSA中的关键一步,下面进行记录下。


位置偏移量 B 的代码详解
这里关键是理解 relative_position_bias_table 和 relative_position_index ,这两个矩阵的对应关系设计的比较巧妙,即relative_position_index 刚好设计为 relative_position_bias_table 所对应的网格数量
# define a parameter table of relative position bias
#定义相对位置偏置表格
#[(2*7-1)*(2*7-1),3]
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
#得到一对在窗口中的相对位置索引
coords_h = torch.arange(self.window_size[0])#[0,1,2,3,4,5,6]
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
#让相对坐标从0开始
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
#relative_coords[:, :, 0] * (2*7-1)
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
#为位置偏置表中索引值,位置偏移表(13*13,nHeads)索引0-168
#索引值为 (49,49) 值在0-168对应位置偏移表的索引
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
#注册为不可学习变量
self.register_buffer("relative_position_index", relative_position_index)
#dim*(dim*3)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
#attn_drop=0.0
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
#初始化相对位置偏置值表(截断正态分布)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
relative_position_bias_table :设置的一个可学习的 (2 x window_size[0]-1)x(2 x window_size[1]-1) x nHeads 的随机变量(利用截断正态分布赋值),如果以代码中第一个阶段的参数量为例,则 window_size[0]=window_size[1]=7, 在第一个阶段 nHeads=3 。即该表中存储的时候一系列的随机数,用于位置编码,提升模型的性能。

下面展示一个relative_position_bias_table的例子
relative_position_index :相对位置编码表的索引表,即存储的值,用来取得相对位置偏移量表relative_position_bias_table 中某个位置的值,relative_position_index中存的值所取范围为 [0,168],即relative_position_bias_table 的大小为 169(13 x 13)个单元格。通过下面的图片,可以看到 relative_position_index中0 和 168 位置的编码只取一次,其实符合传统transformer中对于位置编码的运用,即开头和结尾的位置编码只用一次。

关注到最后计算出的 attn 需要加上位置偏移量,则这里需要看一下 relative_position_bias的计算策略,即下面的图示
relative_position_bias的计算策略:
最终的 relative_position_bias, 即经过转置后和 attn 的后三维一致,进而就可以进行直接位置相加了。

SW-MSA (Shifted Window based multi-head self attention (SW-MSA) module )
SW-MSA的代码中关键步骤为 attn_mask 和 shift windows的操作,即通过对特征图移位,并给Attention设置mask来间接实现的,在保持原有的window个数下,节省计算。
首先来看 attn_mask
代码如下:
#奇数层没有shift_size 偶数层有 shift_size
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution#(56/(2^layer_index),56/(2^layer_index))
#zero_init:img_mask (1,H,W,1)
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
#h_slices :(slice(0, -7, None), slice(-7, -3, None), slice(-3, None, None))
#>>> c=range(0, 10)
#>>> c[h_slices[0]]
# range(0, 3)
#>>> c[h_slices[1]]
# range(3, 7)
#>>> c[h_slices[2]]
# range(7, 10)
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
#将一个H*W的输入按照切片分为9块
#按照H维进行切片
for h in h_slices:
#按照W维进行切片
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
#将img_mask shape 1,H,W,1-> nW, window_size, window_size, 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
#nW, window_size, window_size
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
#attn_mask:[nW, window_size * window_size, window_size * window_size]
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
#矩阵中为0的置0 不为0的置 -100
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
关于程序中的数组切片,slice 部分,代码注释中有说明,注意 这里 window_size=7, shift_size=3,就不详细说明了,这里先针对 img_mask 来说明下,即下图的步骤,具体完成了哪些内容?在例子中,我让img_mask的 H=W=14
首先 img_mask 为和输入大小一致的张量,经过上面的slice代码的切片后,则形成了下面形状(1,14,14,1)的张量
注意到
如果直接将img_mask转为(14,14)的张量,我们可以看到其形状,相当于将张量根据slice切片,分为了9个部分,其中红色部分,不论img|_mask的 H W 为多少,始终为矩形,且大小为 (H-window_size)*(W-window_size),其余黄色的框基本大小是确定的,和 window_size 和 shift_size有关系。

其中attn_mask代码中的 mask_windows - > attn_mask的变换是关键的一步,这一步主要是让以 7*7 为单元的窗口中,块索引值相同的位置,置0,不同的位置 置为 -100 即直接屏蔽掉。

我们可以用以下代码模拟下,比如,a和b为shape为[2,3]的张量,则可以发现,a.unsqueeze(1) shape 为 [2,1,3], b.unsequeeze(2).shape 为 [2,3,1] ,最后经过 c = a.unsqueeze(1) - b.unsequeeze(2),c 变为了 shape [2,3,3],可以根据图中的计算过程,发现其实是 a [2,1,3] 中 b [2,3,1],就是 a 第一个 [1,3] 和 b 中 第一个[3,1] 中的每个元素进行减法操作,形成一个[3,3]的矩阵,然后再让 a 第二个 [1,3] 和 b 中 第二个[3,1] 中的每个元素进行减法操作,形成二个[3,3]的矩阵,最终形成了 [2,3,3] 的矩阵。
即经过上面的 mask_windows - > attn_mask 的运算,可以将不同窗口中,对应位置的索引相同值置0,不同值为两者的差值。

然后根据下面的代码进行同则赋 0,异则赋 -100。
#矩阵中为0的置0 不为0的置 -100
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
最后将得到的 attn_mask 与 得到的特征图 attn 进行相加
#mask 在某阶段的奇数层为None 偶数层才存在
if mask is not None:
#nW=B*H/7*W/7
#mask.shape:[B*H/7*W/7 , 49, 49]
nW = mask.shape[0]
#mask:torch.Size([1, 4, 1, 49, 49])
#attn.view:[B_ // nW, nW(4), self.num_heads, N, N]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
我们用程序模拟下面这个步骤,即假设 attn.view:[B_ // nW, nW(4), self.num_heads, N, N]=[1,2,2,2,2],而mask.unsqueeze(1).unsqueeze(0).shape=[1,2,1,2,2] , 如下面的图

所以根据代码 展开 attn_mask的计算过程,可以用图示表示:
通过图示可以发现,相当于强行将某些模块的样本用来计算对应mask的注意力值,这个属于对网络的一种约束了。且是强行分了 B_/nw 个模块,每个模块中交替进行计算对应那几个(nw)个mask的注意力。
说完了 attn_mask,再来看看 shift windows的操作,具体来讲,应该是一个特征图循环移位的操作,不过只移动了一次,所以直接用 shift 也可以理解。相关代码如下:
进行移动窗口
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
if not self.fused_window_process:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
else:
x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
由于代码中,默认 fused_window_process 为 False,所以进行移动窗口主要代码是:
#这里的 x = x.view(B, H, W, C)
if not self.fused_window_process:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
为与上面的例子对应,这里我们假设 shift_size = 3,由于 X的 shape为 [B,H,W,C] 所以可以看出,这个移位是在 H 和 W的维度分别移动 3
最后的shift恢复过程,就是上面 roll 和 partition 的反过程,代码中:
# reverse cyclic shift
if self.shift_size > 0:
if not self.fused_window_process:
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
else:
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
x = shifted_x
最后再来看一下 SwinTransformerBlock 的前向传播代码,即如果 shift_size>0 ,整体过程是对输入的整个特征图进行 循环移位 - > 然后进行带mask的注意力机制计算(SW-MSA)->再进行一系列后操作,这里并看不到针对某个窗口进行特征图移位和针对某个窗口进行 mask 均是针对整张特征图进行的相关操作。
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
if not self.fused_window_process:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
else:
x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
# reverse cyclic shift
if self.shift_size > 0:
if not self.fused_window_process:
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
else:
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
x = shifted_x
x = x.view(B, H * W, C)
x = shortcut + self.drop_path(x)
# FFN
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
所以看了代码,不禁怀疑网上有些关于swintransformer的教程,其实有些问题的,具体还是要看代码,如果有问题,欢迎留言指正!但为什么swintransformer仅仅进行了特征图循环移位和限制性的mask注意力机制,就有效果,其实还需要深究,个人感觉是多阶段连续后,其实特征图的循环移位,让越靠后的层,越考虑了全局的特征,这个还需要再看看代码了
边栏推荐
- Different HR labels
- Path and attribute labels of picture labels
- What content does the new version of network security level protection evaluation report template contain? Where can I find it?
- Mobaxtermsession synchronization
- Eight ways to solve EMC and EMI conducted interference
- Go interface advanced
- Go interface Foundation
- Quickly build a gateway service, dynamic routing and authentication process, and watch the second meeting (including the flow chart)
- Linux initializes MySQL with fatal error: could not find my-default.cnf
- TXT text file storage
猜你喜欢

Two dimensional array and operation

Kubernetes technology and Architecture (VII)

Learn to draw with nature communications -- complex violin drawing

站在大佬的肩膀上,你可以看的更远

Network interface network crystal head RJ45, Poe interface definition line sequence

Explain cache consistency and memory barrier

Div tags and span Tags

IDC脚本文件运行

C #, introductory tutorial -- debugging skills and logical error probe technology and source code when the program is running

How to obtain the subordinate / annotation information of KEGG channel
随机推荐
修改虚拟机IP地址
从开发转测试:我从零开始,一干就是6年的自动化测试历程
01-TensorFlow计算模型(一)——计算图
Basic syntax of jquey
I am a 27 year old technical manager, whose income is too high, and my heart is in a panic
Eight ways to solve EMC and EMI conducted interference
leetcode 452. Minimum Number of Arrows to Burst Balloons 用最少数量的箭引爆气球(中等)
This flick SQL timestamp_ Can ltz be used in create DDL
Bluetooth technology | it is reported that apple, meta and other manufacturers will promote new wearable devices, and Bluetooth will help the development of intelligent wearable devices
What are the main uses of digital factory management system
台大林轩田《机器学习基石》习题解答和代码实现 | 【你值得拥有】
mysql主从架构 ,主库挂掉重启后,从库怎么自动连接主库
Argocd Web UI loading is slow? A trick to teach you to solve
Why can ThreadLocal achieve thread isolation?
看完这12个面试问题,新媒体运营岗位就是你的了
Huid learning 7: Hudi and Flink integration
XMIND Zen installation tutorial
mysql5.7.38容器里启动keepalived
[cloud computing] several mistakes that enterprises need to avoid after going to the cloud
51单片机存储篇:EEPROM(I2C)