The most detailed explanation in history of the Swin Transformer window attention mask -- Shaoshuai
2022-06-13 03:57:00 【cfsongbj】
0. Preface
I read the Swin-Transformer paper recently and could not make sense of the mask mechanism, especially the code that implements it. The mask mechanism is one of the highlights of this paper, and only after consulting various sources did I finally understand how it works.
1. What is the mask mechanism?
1.1 The shifted window mechanism (Shifted Windows)
To understand what the mask mechanism is, we first need to understand why it is needed, which comes down to why Swin shifts its windows in the first place.
Figure 1.1.1
As is well known, the self-attention in Swin is window based, as shown in Figure 1.1.1. With window-based self-attention, the connections between windows are lost. This actually goes against the global self-attention of the Transformer architecture and gives up the Transformer's biggest advantage. To address this, the Swin authors propose a second self-attention mechanism based on shifted windows.
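To make window-based self-attention concrete, here is a minimal sketch (an illustration only, not the official implementation: it uses single-head attention with Q = K = V and no learned projections, which is a simplification assumed here):

import torch
import torch.nn.functional as F

def window_self_attention(x, window_size):
    # x: (B, H, W, C); H and W are assumed to be divisible by window_size
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # every window becomes its own "sequence" of window_size*window_size tokens
    windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size, C)
    attn = (windows @ windows.transpose(1, 2)) / C ** 0.5  # scores only between tokens of the same window
    attn = F.softmax(attn, dim=-1)
    return attn @ windows  # no token ever attends outside its own window

Because attention is computed per window, tokens in different windows never exchange information, which is exactly what the shifted windows are meant to fix.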
Figure 1.1.2
The other self-attention mechanism in Swin is based on shifted windows, as shown in Figure 1.1.2.
Figure 1.1.3: the left image shows the original, unshifted partition; the right image shows the partition after the shift
In the left image of Figure 1.1.3 there is no shift: there are four windows, and self-attention is computed inside each of them independently. The right image of Figure 1.1.3 shows the situation after the window shift: self-attention is still computed inside four windows, but these four windows are different from the four on the left. As Figure 1.1.3 shows, the shift moves the left part and the top part of the image to the right and to the bottom respectively, and the size of the move is shift_size.
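In the Swin code this shift is implemented as a cyclic roll of the feature map; a minimal sketch of the idea (the dummy tensor below is an assumption for illustration):

import torch

shift_size = 3
x = torch.arange(14 * 14, dtype=torch.float32).view(1, 14, 14, 1)  # dummy (B, H, W, C) feature map
# negative shifts move the top shift_size rows to the bottom and the left shift_size columns to the right
shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
# after attention, the opposite roll (+shift_size) moves everything back to its original place
restored_x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
assert torch.equal(x, restored_x)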
Figure 1.1.4
Figure 1.1.4 divides the four shifted windows into 9 small regions (to distinguish them from the windows of Figure 1.1.3 they are labelled with numbers in parentheses); region (2), for example, is the block B that was moved over from the left side. Why split into 9 small regions? Let's go through them one by one.
- First consider window 1 of Figure 1.1.3 (right image). In Figure 1.1.4 it corresponds to region (0); this window can do self-attention directly after the shift.
- Next consider window 2 of Figure 1.1.3 (right image). In Figure 1.1.4 it is split into regions (1) and (2). Regions (1) and (2) must not attend to each other: (2) is the block B moved over from the left side, so (2) and (1) are not adjacent in the original image and should not be mixed in the same self-attention computation.
- Window 3 of Figure 1.1.3 is split into (3) and (6); these two should not attend to each other either.
- Window 4 of Figure 1.1.3 is split into (4), (5), (7) and (8); there should be no self-attention across these regions.
Because of the above, the author introduces a mask mechanism: after the shift, the pairs that are allowed to attend to each other still do, while the pairs that should not attend to each other are masked out.
2. The concrete implementation of the mask mechanism
2.1 The author's explanation
Many people find the mask mechanism hard to understand, so questions about it have been raised to the author on GitHub, and the author gave an excellent answer. Here is the address: https://github.com/microsoft/Swin-Transformer/issues/38
2.2 Code reading
The answer posted by the author is already quite detailed, but in case some readers still don't follow it, let's walk through the code. The code below is the snippet the author posted in that answer; I paste it here:
import torch
import matplotlib.pyplot as plt


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


window_size = 7
shift_size = 3
H, W = 14, 14

img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

plt.matshow(img_mask[0, :, :, 0].numpy())
plt.matshow(attn_mask[0].numpy())
plt.matshow(attn_mask[1].numpy())
plt.matshow(attn_mask[2].numpy())
plt.matshow(attn_mask[3].numpy())
plt.show()
Here are the resulting plots:
Figure 2.2.1
Let's look at the code directly. The author first defines a few variables: the window size, the shift size and the image size. Then an img_mask of shape (1, H, W, 1) is created; the dimensions correspond to (B, H, W, C), i.e. batch size, image height and width, and channels. This matches the dimensions used in the real project, so img_mask can be thought of as a mask "image" of the same size as the feature map.
Next the author defines two tuples of slices, h_slices and w_slices, which cut the image into three pieces along the H direction and three pieces along the W direction, nine pieces in total. Two nested for loops then assign a different value to each piece. The result looks like this:
Figure 2.2.2
In Figure 2.2.2 the numbers 0-8 refer to the 9 small regions and are also the values assigned to them; Window 0-3 refer to the four large windows mentioned earlier.
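To see concretely which rows and columns each slice covers, here is a quick check (using the same window_size = 7, shift_size = 3, H = W = 14 as in the code above):

# slice(0, -window_size)           -> indices 0..6
# slice(-window_size, -shift_size) -> indices 7..10
# slice(-shift_size, None)         -> indices 11..13
print(list(range(14))[slice(0, -7)])     # [0, 1, 2, 3, 4, 5, 6]
print(list(range(14))[slice(-7, -3)])    # [7, 8, 9, 10]
print(list(range(14))[slice(-3, None)])  # [11, 12, 13]

Each of the 3 x 3 combinations of one h slice and one w slice is one of the 9 regions, and the counter cnt assigns it a value from 0 to 8.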
Figure 2.2.3
For easier display we drop two dimensions and show img_mask directly (this is what the red box in Figure 2.2.3 does). As shown in Figure 2.2.3, img_mask is divided into 9 parts, and each part is assigned a different value.
Next, img_mask is passed through window_partition, after which its shape is [4, 7, 7, 1]: the 14x14 image has been split into four windows of height and width 7. The tensor is then further reshaped to [4, 49]. These are all routine operations.
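A quick sanity check of the shapes at this point (just prints for inspection, using the variables from the code above):

print(img_mask.shape)                                 # torch.Size([1, 14, 14, 1]) -- one mask "image"
print(window_partition(img_mask, window_size).shape)  # torch.Size([4, 7, 7, 1])   -- four 7x7 windows
print(mask_windows.shape)                             # torch.Size([4, 49])        -- each window flattened to 49 values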
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
This is the most critical line. At this point mask_windows has shape [4, 49]: four windows in total, each holding 7*7 values. unsqueeze adds a dimension, and the subtraction relies on PyTorch broadcasting: mask_windows.unsqueeze(1) becomes [4, 1, 49], mask_windows.unsqueeze(2) becomes [4, 49, 1], and subtracting the two yields a [4, 49, 49] tensor.
To make this concrete, define two temporary variables:
temp1 = mask_windows.unsqueeze(2)  # [4, 49, 1]
temp2 = mask_windows.unsqueeze(1)  # [4, 1, 49]
temp1_1 = temp1.repeat(1, 1, 49)   # [4, 49, 49], same effect as np.repeat(temp1, 49, axis=2)
temp2_2 = temp2.repeat(1, 49, 1)   # [4, 49, 49], same effect as np.repeat(temp2, 49, axis=1)
temp1 has shape [4, 49, 1] and temp2 has shape [4, 1, 49]. Subtracting them is equivalent to first copying temp1 49 times along dimension 2 to get [4, 49, 49], copying temp2 49 times along dimension 1 to also get [4, 49, 49], and then subtracting element-wise.
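The broadcasting trick is easier to see on a tiny made-up example (a sketch with a single "window" of 4 values coming from two regions, 0 and 1):

import torch

v = torch.tensor([[0., 0., 1., 1.]])    # shape [1, 4]: one window, 4 tokens
diff = v.unsqueeze(1) - v.unsqueeze(2)  # [1, 1, 4] - [1, 4, 1] -> [1, 4, 4] pairwise differences
print(diff[0])
# tensor([[ 0.,  0.,  1.,  1.],
#         [ 0.,  0.,  1.,  1.],
#         [-1., -1.,  0.,  0.],
#         [-1., -1.,  0.,  0.]])
# entry (i, j) equals v[j] - v[i]: it is 0 exactly when tokens i and j come from the same region

The [4, 49, 49] attn_mask above is the same construction, just with 49 tokens per window.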
The leading dimension 4 stands for the four windows. Take the second window as an example: temp1[1] is the second window, and writing it out flat gives [1,1,1,1,2,2,2,.......,1,1,1,1,2,2,2], which corresponds to the red box in the figure below:
Figure 2.2.4: window 1 of Figure 2.2.3
Copying temp1 49 times along dimension 2 gives temp1_1 of shape [4, 49, 49]; here we print its second window, temp1_1[1]:
Figure 2.2.5: temp1_1[1]. The screenshot is truncated; the full matrix is [49, 49], the rest follows the same pattern.
Only part of the matrix is shown, but you can see it is the [49, 1] column copied 49 times: in the resulting [49, 49] matrix all values within a row are equal, and the values of the 49 rows correspond, one by one, to the 49 values of window 1.
Similarly, we print the second window of temp2_2, i.e. temp2_2[1]:
Figure 2.2.6: temp2_2[1]
In short, temp1_1 takes the 49 values of the window and copies each value 49 times across its own row, while temp2_2 copies the whole set of 49 values 49 times downwards, so its 49 rows are identical and each row contains all 49 values of window 1.
After this operation, row 1 of the [49, 49] matrix temp1_1 is the value of the first pixel of window 1 copied 49 times, row 2 is the value of the second pixel copied 49 times, and so on. In temp2_2, row 1 is all 49 values of window 1, row 2 is again the same 49 values, and so on. (If this is hard to follow, compare Figures 2.2.5 and 2.2.6.)
temp1_1 - temp2_2 therefore subtracts every value in the window from every other value, forming all pairwise differences. Remember that the 9 small regions were assigned different values, so when two pixels belong to different small regions the difference is non-zero, and when they belong to the same small region the difference is 0.
At this point the mask mechanism is essentially done: the non-zero entries are filled with -100 (a very negative value), marking the pairs that should not attend to each other, while the zero entries are left at 0.
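For completeness, here is a sketch of how such a mask is then used: it is added to the raw attention scores before the softmax, so the -100 entries push the corresponding attention weights to (almost) zero. The single-head attention without projections below is a simplification assumed for brevity, not the exact Swin code; it reuses attn_mask and window_size from the snippet above:

import torch
import torch.nn.functional as F

nW, N, C = attn_mask.shape[0], window_size * window_size, 32  # 4 windows, 49 tokens each, toy channel dim
tokens = torch.randn(nW, N, C)                       # pretend token features of the 4 shifted windows
scores = tokens @ tokens.transpose(1, 2) / C ** 0.5  # [4, 49, 49] raw attention scores
scores = scores + attn_mask                          # -100 wherever two tokens come from different regions
weights = F.softmax(scores, dim=-1)                  # masked pairs get (almost) zero attention weight
out = weights @ tokens                               # each token only aggregates tokens from its own region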
3. Summary
This completes the explanation of the Swin mask mechanism; I hope it helps with the points that are hard to understand.