
The Most Detailed Explanation Ever of the Swin Transformer Window-Attention Mask -- Shaoshuai

2022-06-13 03:57:00 cfsongbj

0. Preface


Over the past few days I have been reading the Swin-Transformer paper. When I went through the code I could not understand the mask mechanism, and the mask is one of the highlights of Swin. After consulting many sources I finally understood how it works.

1. What is the mask mechanism?

1.1 The shifted-window mechanism (shifted windows)

To understand what the mask mechanism is, we first need to know why a mask is needed at all, and that comes from the way Swin shifts its windows.

Figure 1.1.1

As is well known, self-attention in Swin is computed inside windows, as shown in Figure 1.1.1. With window-based self-attention, tokens in different windows are no longer connected. This runs against the global self-attention of the Transformer architecture and gives up the Transformer's biggest advantage. The authors of Swin therefore propose a second form of self-attention.

Figure 1.1.2

That second form is self-attention based on shifted windows, as shown in Figure 1.1.2.

Figure 1.1.3: left, the original layout without shifting; right, the layout after shifting

The left image in Figure 1.1.3 is the unshifted case: there are four windows, and self-attention is computed within each of them separately. The right image shows the layout after shifting. Self-attention is still computed within four windows, but these are different from the four windows on the left. As Figure 1.1.3 shows, shifting means moving the leftmost columns of the image to the right edge and the topmost rows to the bottom edge; the amount moved is Shift-Size.
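The shift described above is a cyclic shift: whatever leaves the top and the left re-enters at the bottom and the right. As far as I know the official Swin code does this with torch.roll; the sketch below instead uses NumPy's np.roll on a small 4x4 grid whose cells hold their original indices. The helpers cyclic_shift and reverse_cyclic_shift and the shift size of 1 are hypothetical, chosen only for illustration.

```python
import numpy as np


def cyclic_shift(x, shift_size):
    """Cyclically shift a 2-D grid toward the top-left.

    The top `shift_size` rows wrap around to the bottom and the left
    `shift_size` columns wrap around to the right.
    """
    return np.roll(x, shift=(-shift_size, -shift_size), axis=(0, 1))


def reverse_cyclic_shift(x, shift_size):
    """Undo cyclic_shift by rolling the same amount in the opposite direction."""
    return np.roll(x, shift=(shift_size, shift_size), axis=(0, 1))


grid = np.arange(16).reshape(4, 4)  # each cell holds its original index
shifted = cyclic_shift(grid, shift_size=1)
print(shifted)
# [[ 5  6  7  4]
#  [ 9 10 11  8]
#  [13 14 15 12]
#  [ 1  2  3  0]]

restored = reverse_cyclic_shift(shifted, shift_size=1)
print((restored == grid).all())  # True
```

The original top-left cell (0) ends up in the bottom-right corner. Rolling back by the same amount restores every token to its original position after attention has been computed.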

 

Figure 1.1.4

Figure 1.1.4 splits the four windows into 9 small regions (to distinguish them from the windows of Figure 1.1.3, they are numbered in parentheses). For example, region (2) is the part B that was moved over from the left. Why split them into 9 regions? Let us go through the windows one by one.

  • Window 1 of Figure 1.1.3 (right image) corresponds to region (0) in Figure 1.1.4. This window can compute self-attention directly after the shift.
  • Window 2 of Figure 1.1.3 (right image) is split into regions (1) and (2) in Figure 1.1.4. (1) and (2) must not attend to each other: (2) is the part B that came from the left, so in the original image (2) and (1) are not adjacent.
  • Window 3 of Figure 1.1.3 is split into (3) and (6), which likewise must not attend to each other.
  • Window 4 of Figure 1.1.3 is split into (4), (5), (7), and (8), none of which should attend to one another.
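The bullet points above can be checked mechanically. The following NumPy sketch labels a shifted 14x14 feature map (window size 7, shift size 3, the same values the author's code uses later in this post) with nine region IDs and lists which regions fall inside each window. The `bounds` list is my own shorthand for the three strips along each axis, and I am assuming that code windows 0-3 correspond to windows 1-4 of Figure 1.1.3 and that region IDs 0-8 correspond to (0)-(8) of Figure 1.1.4.

```python
import numpy as np

H = W = 14          # square feature map, as in the author's example
window_size = 7
shift_size = 3

# After the cyclic shift, each axis splits into three strips:
# [0, H - window_size), [H - window_size, H - shift_size), [H - shift_size, H)
# (the same bounds serve both axes because H == W)
bounds = [0, H - window_size, H - shift_size, H]

labels = np.zeros((H, W), dtype=int)
for i in range(3):
    for j in range(3):
        # Region ID 3*i + j, i.e. 0..8 in row-major order
        labels[bounds[i]:bounds[i + 1], bounds[j]:bounds[j + 1]] = 3 * i + j

regions_per_window = {}
n_win = W // window_size  # windows per row
for wi in range(H // window_size):
    for wj in range(n_win):
        block = labels[wi * window_size:(wi + 1) * window_size,
                       wj * window_size:(wj + 1) * window_size]
        regions_per_window[wi * n_win + wj] = sorted(set(block.ravel().tolist()))

for win, regions in regions_per_window.items():
    print(f"window {win}: regions {regions}")
# window 0: regions [0]
# window 1: regions [1, 2]
# window 2: regions [3, 6]
# window 3: regions [4, 5, 7, 8]
```

Only the first window holds a single region. Each of the other three mixes regions that were not adjacent before the shift, and those are exactly the windows that need a mask.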

To handle these cases, the authors introduce a mask mechanism: after the shift, regions that should attend to each other still do, and regions that should not are blocked.

2. Implementing the mask mechanism

2.1 The author's explanation

Many people find the mask mechanism hard to follow, so someone asked the author about it on GitHub, and the author gave an excellent answer. Here is the address: https://github.com/microsoft/Swin-Transformer/issues/38

2.2 Code walkthrough

The author's answer is already detailed, but in case some readers still find it unclear, let us walk through the code. It is the code the author posted in that answer, reproduced here:

import torch

import matplotlib.pyplot as plt


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


window_size = 7
shift_size = 3
H, W = 14, 14
img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, window_size * window_size)

attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

plt.matshow(img_mask[0, :, :, 0].numpy())
plt.matshow(attn_mask[0].numpy())
plt.matshow(attn_mask[1].numpy())
plt.matshow(attn_mask[2].numpy())
plt.matshow(attn_mask[3].numpy())

plt.show()

The resulting plots are shown below:

Figure 2.2.1

Looking at the code itself: the author first defines a few variables, namely the window size, the shift size, and the image size. Next, the author creates an img_mask of shape (1, H, W, 1). The dimensions correspond to (B, H, W, C), that is, batch size, image height and width, and number of channels, matching the tensor layout used in the real model. You can think of img_mask as an image of the same size as the feature map.

Next, the author defines two tuples of slices, h_slices and w_slices, which cut the image into three strips along h and three along w, giving nine regions in total. Two for loops then assign a different value to each region. The result looks like this:

Figure 2.2.2

In Figure 2.2.2, 0-8 denote the 9 small regions and also the values assigned to them. Window0-3 denote the large windows, as discussed earlier.
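Python slices with negative bounds count from the end of the axis. With H = 14, window_size = 7 and shift_size = 3, the three slices in h_slices (and, identically, in w_slices) should therefore cover rows 0-6, 7-10 and 11-13. A quick pure-Python check:

```python
H = 14
window_size, shift_size = 7, 3

h_slices = (slice(0, -window_size),            # rows 0 .. H-window_size-1
            slice(-window_size, -shift_size),  # rows H-window_size .. H-shift_size-1
            slice(-shift_size, None))          # the last shift_size rows

rows = list(range(H))
covered = [rows[s] for s in h_slices]
for idx in covered:
    print(idx)
# [0, 1, 2, 3, 4, 5, 6]
# [7, 8, 9, 10]
# [11, 12, 13]
```

The first strip always has H - window_size rows, the middle one window_size - shift_size rows, and the last one shift_size rows. Crossing the three strips of h with the three strips of w gives the nine regions.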

Figure 2.2.3

For display purposes we drop the two singleton dimensions of img_mask (this is what the red box in Figure 2.2.3 does). As Figure 2.2.3 shows, img_mask is divided into 9 parts, each assigned a different value.

Next, img_mask goes through window_partition and gets the shape [4, 7, 7, 1]: the 14×14 image is split into four 7×7 windows. It is then reshaped to [4, 49]. These are routine operations.
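To check these shapes without PyTorch, the steps can be reproduced in NumPy. window_partition_np below is a hypothetical NumPy port of the author's window_partition (reshape and transpose in place of view and permute); img_mask is built exactly as in the author's code.

```python
import numpy as np


def window_partition_np(x, window_size):
    """NumPy port of window_partition: (B, H, W, C) -> (num_windows*B, ws, ws, C)."""
    B, H, W, C = x.shape
    x = x.reshape(B, H // window_size, window_size, W // window_size, window_size, C)
    # Bring the two window-index axes next to each other, then flatten them
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)


window_size, shift_size = 7, 3
H = W = 14
img_mask = np.zeros((1, H, W, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt   # region IDs 0..8
        cnt += 1

windows_4d = window_partition_np(img_mask, window_size)
print(windows_4d.shape)    # (4, 7, 7, 1)
mask_windows = windows_4d.reshape(-1, window_size * window_size)
print(mask_windows.shape)  # (4, 49)
```

Row w of mask_windows holds the 49 region IDs of window w, read row by row inside the window.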

attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)

This is the key line. Here mask_windows has shape [4, 49]: four windows, each with 7×7 = 49 values. unsqueeze adds a dimension, and the subtraction relies on PyTorch broadcasting: mask_windows.unsqueeze(1) has shape [4, 1, 49], mask_windows.unsqueeze(2) has shape [4, 49, 1], and their difference has shape [4, 49, 49].
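Broadcasting is easiest to see on a toy input. The sketch below uses NumPy, whose broadcasting rules match PyTorch's for this case, and a made-up mask_windows with 2 windows of 3 pixels each. The result has shape (2, 3, 3), and entry [w, i, j] equals mask_windows[w, j] - mask_windows[w, i].

```python
import numpy as np

# Toy stand-in for mask_windows: 2 windows, 3 pixels each (values = region IDs)
mask_windows = np.array([[0, 0, 0],
                         [1, 1, 2]])

a = mask_windows[:, np.newaxis, :]   # like unsqueeze(1): shape (2, 1, 3)
b = mask_windows[:, :, np.newaxis]   # like unsqueeze(2): shape (2, 3, 1)
diff = a - b                         # broadcasts to (2, 3, 3)

print(diff.shape)  # (2, 3, 3)
print(diff[1])
# [[ 0  0  1]
#  [ 0  0  1]
#  [-1 -1  0]]
```

The first toy window contains a single region, so its difference matrix is all zeros. In the second one, only pairs that mix region 1 and region 2 are nonzero.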

To see what broadcasting does here, define a few temporary variables:

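temp1 = mask_windows.unsqueeze(2)    # [4, 49, 1]
temp2 = mask_windows.unsqueeze(1)    # [4, 1, 49]
temp1_1 = temp1.repeat(1, 1, 49)     # copy 49 times along dim 2 -> [4, 49, 49]
temp2_2 = temp2.repeat(1, 49, 1)     # copy 49 times along dim 1 -> [4, 49, 49]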

temp1 has shape [4, 49, 1] and temp2 has shape [4, 1, 49]. Subtracting them is equivalent to first copying temp1 49 times along dimension 2 to get a [4, 49, 49] tensor (temp1_1), copying temp2 49 times along dimension 1 to get another [4, 49, 49] tensor (temp2_2), and then subtracting these two equal-shaped tensors.
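The claim that broadcasting behaves like first copying each operand 49 times can be checked directly. This NumPy sketch uses random region IDs in place of the real mask_windows (an assumption made only for the check) and compares the broadcast difference with the explicitly repeated one.

```python
import numpy as np

rng = np.random.default_rng(0)
mask_windows = rng.integers(0, 9, size=(4, 49))   # stand-in values, shape [4, 49]

temp1 = mask_windows[:, :, np.newaxis]            # [4, 49, 1]
temp2 = mask_windows[:, np.newaxis, :]            # [4, 1, 49]

temp1_1 = np.repeat(temp1, 49, axis=2)            # copy along dim 2 -> [4, 49, 49]
temp2_2 = np.repeat(temp2, 49, axis=1)            # copy along dim 1 -> [4, 49, 49]

broadcast = temp2 - temp1                         # what the author's code computes
explicit = temp2_2 - temp1_1                      # same thing with explicit copies
print(broadcast.shape, np.array_equal(broadcast, explicit))  # (4, 49, 49) True
```

Broadcasting gives the same result without materializing the copies, which is why the author's one-liner needs no explicit repeat.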

The leading dimension of 4 indexes the four windows. Take the second window as an example: temp1[1] holds the values of the second window, flattened as [1,1,1,1,2,2,2, ..., 1,1,1,1,2,2,2]. Unflattened, it corresponds to the red box in the figure below:

Figure 2.2.4: window 1 in Figure 2.2.3
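The flattened values of window 1 can be reproduced without the figure. Assuming the same 14x14 label map as the author's img_mask, this pure-Python sketch rebuilds it and prints the 7 rows of window 1 (rows 0-6, columns 7-13). The helper list `strip` is my own device for recording which of the three strips each index falls into.

```python
H = W = 14
window_size, shift_size = 7, 3
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))

# Strip index (0, 1 or 2) of every row/column position along one axis
strip = [0] * H
for k, s in enumerate(slices):
    for idx in range(H)[s]:
        strip[idx] = k

# Same labelling as img_mask: region = 3 * (row strip) + (column strip)
labels = [[3 * strip[r] + strip[c] for c in range(W)] for r in range(H)]

# Window 1 = rows 0..6, columns 7..13, flattened row by row like mask_windows[1]
window1 = [labels[r][c]
           for r in range(0, window_size)
           for c in range(window_size, 2 * window_size)]
for r in range(window_size):
    print(window1[r * window_size:(r + 1) * window_size])
# every row prints [1, 1, 1, 1, 2, 2, 2]
```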

Copying temp1 49 times along dimension 2 gives temp1_1 with shape [4, 49, 49]. Here we print the second window of temp1_1, i.e. temp1_1[1]:

Figure 2.2.5: temp1_1[1] (the screenshot is cropped; the full matrix is [49, 49], so print it yourself to see the rest)

Only part of it is shown here. It amounts to copying a column 49 times: within each row of the [49, 49] matrix all values are equal, and the 49 rows hold, in order, the 49 values of window 1.

Similarly, we print the second window of temp2_2, i.e. temp2_2[1]:

Figure 2.2.6: temp2_2[1]

temp1_1 copies the 49 values of a window 49 times along the column direction, whereas temp2_2 copies them 49 times along the row direction. So the 49 rows of temp2_2[1] are identical, and each row contains all 49 values of window 1.

After this step, in temp1_1[1] the first row is the value of window 1's first pixel repeated 49 times, the second row is the second pixel's value repeated 49 times, and so on. In temp2_2[1], the first row holds all 49 values of window 1, and so does the second row and every other row. (This is hard to grasp in words; compare Figures 2.2.5 and 2.2.6.)
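The two patterns in Figures 2.2.5 and 2.2.6 can be stated as checks: every row of temp1_1[1] is constant (row i is pixel i's value), and every row of temp2_2[1] is the whole window. The sketch below rebuilds window 1 as a NumPy array, assuming the [1,1,1,1,2,2,2] x 7 pattern of the 14x14 example, and verifies both.

```python
import numpy as np

window1 = np.array([1, 1, 1, 1, 2, 2, 2] * 7)              # mask_windows[1], 49 values

temp1_1 = np.repeat(window1[:, np.newaxis], 49, axis=1)    # [49, 49]: column copied 49 times
temp2_2 = np.repeat(window1[np.newaxis, :], 49, axis=0)    # [49, 49]: row copied 49 times

# Row i of temp1_1 is pixel i's value repeated 49 times
rows_constant = all((temp1_1[i] == window1[i]).all() for i in range(49))
# Every row of temp2_2 is the full window
rows_identical = all((temp2_2[i] == window1).all() for i in range(49))
print(rows_constant, rows_identical)          # True True

# Equivalently, temp2_2 is the transpose of temp1_1
print(np.array_equal(temp2_2, temp1_1.T))     # True
```

The transpose relation is a compact way to remember the two figures: one matrix varies down the rows, the other across the columns.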

temp1_1 - temp2_2 therefore subtracts every value in the window from every other value in the same window. Recall that the 9 small regions were assigned different values. For a pair of pixels from different regions the difference is therefore nonzero, while for a pair from the same region it is 0. (The author's code computes the opposite difference, unsqueeze(1) - unsqueeze(2), but only whether the result is zero matters.)
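For window 1 the outcome can be counted by hand. 28 pixels belong to region 1 (4 columns x 7 rows) and 21 to region 2 (3 x 7), so 28^2 + 21^2 = 1225 of the 49 x 49 = 2401 entries should be zero and the remaining 1176 nonzero. A NumPy check, again assuming the 14x14 example:

```python
import numpy as np

window1 = np.array([1, 1, 1, 1, 2, 2, 2] * 7)            # region IDs of window 1's 49 pixels
diff = window1[np.newaxis, :] - window1[:, np.newaxis]     # [49, 49], entry [i, j] = id[j] - id[i]

same_region = window1[:, np.newaxis] == window1[np.newaxis, :]
print(np.count_nonzero(diff == 0), np.count_nonzero(diff))  # 1225 1176

# Because region IDs are distinct integers, "difference is zero" means exactly "same region"
print(np.array_equal(diff == 0, same_region))               # True
```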

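At this point the mask is essentially done: every nonzero value is set to -100 (a large negative number) and every zero stays 0. In Swin this mask is added to the attention scores before the softmax, so pixel pairs marked -100 receive an attention weight of practically zero, which means they do not attend to each other.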
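Why -100 is enough: a masked score is 100 lower than an unmasked one, so after the softmax its weight is about e^-100 (roughly 3.7e-44) times that of an allowed position. The sketch below assumes all-zero attention scores (a made-up simplification), builds the -100/0 mask for window 1 with the same rule as masked_fill, and applies a hand-written softmax for one query pixel.

```python
import numpy as np

window1 = np.array([1, 1, 1, 1, 2, 2, 2] * 7)            # region IDs of window 1
diff = window1[np.newaxis, :] - window1[:, np.newaxis]     # [49, 49]
attn_mask = np.where(diff != 0, -100.0, 0.0)               # nonzero -> -100, zero -> 0


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


scores = np.zeros((49, 49))               # stand-in attention scores (assumption)
weights = softmax(scores + attn_mask)     # the mask is added before the softmax

q = 0                                     # query pixel 0 lies in region 1
allowed = window1 == window1[q]
print(weights[q, allowed].sum())          # close to 1.0: all weight stays in region 1
print(weights[q, ~allowed].max())         # tiny (about 1e-45): masked keys are ignored
```

With real scores the picture is the same as long as they differ by much less than 100, which is why a finite constant works in place of negative infinity.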

3. Summary

That completes the explanation of Swin's mask mechanism. I hope it helps with the parts that are hard to understand.


Copyright notice
This article was written by [cfsongbj]; please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/164/202206130346323767.html