Swin-Transformer(2021-08)
2022-06-30 19:06:00 【Gy Zhao】
Brief introduction
As of the time of writing this note, Swin-based models still top the Object Detection leaderboard and several others.
Many blogs have already introduced it in great detail, so here I only record the points that puzzled me during my study.
Comparing Swin with ViT: ViT splits the image into fixed-size patches and computes attention with patches as the unit, so the feature-map resolution stays unchanged throughout the computation. In addition, to stay consistent with NLP practice, ViT adds an extra class token that is ultimately used for classification. ViT is therefore ill-suited to downstream tasks such as detection, because it cannot extract multi-scale features.
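To make the token counting concrete, here is a toy calculation of my own (assuming a 224x224 input with 16x16 patches, as in ViT-B/16):

img_size, patch_size = 224, 16
num_patches = (img_size // patch_size) ** 2  # 14 * 14 = 196 patch tokens
seq_len = num_patches + 1                    # plus 1 class token -> 197
print(num_patches, seq_len)                  # 196 197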
To make transformers work for visual tasks such as detection, one route is to imitate CNNs and turn the transformer into a hierarchical structure; the other is to explore pure transformer structures.
Swin clearly belongs to the former: it cleverly combines local window self-attention with shifted windows into a hierarchical structure, which lets it serve as a general-purpose backbone for vision.
The figure above is the Swin-T structure diagram. The input image first passes through Patch Partition and Linear Embedding to become a sequence of token vectors, which is then fed into the Swin Transformer Blocks. The blocks come in pairs, one Window Multi-Head Self-Attention (W-MSA) block followed by one Shifted-Window Multi-Head Self-Attention (SW-MSA) block, so the number of blocks in a stage is always even.
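A minimal sketch of this alternation, following the convention in the official implementation (even-indexed blocks use W-MSA with no shift, odd-indexed blocks use SW-MSA with a shift of window_size // 2):

window_size = 7
depth = 6  # e.g. the third stage of Swin-T has 6 blocks
for i in range(depth):
    shift_size = 0 if i % 2 == 0 else window_size // 2
    kind = "W-MSA" if shift_size == 0 else "SW-MSA"
    print(f"block {i}: {kind}, shift_size={shift_size}")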
Window partition
Splits a (B, H, W, C) tensor into windows of shape (num_windows*B, window_size, window_size, C).
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # view() requires contiguous storage, hence the .contiguous() after permute()
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
# Restore the original input x
def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
The test runs as follows:
import torch

x = torch.randn(1, 224, 224, 3)
p = window_partition(x, 7)
print(p.size())
o = window_reverse(p, 7, 224, 224)
print(o.size())
output:
torch.Size([1024, 7, 7, 3])
torch.Size([1, 224, 224, 3])
window_partition converts the input image into patches of the specified window size: here a (1, 224, 224, 3) batch is converted into 1024 windows of size (7, 7).
window_reverse is the inverse function of window_partition.
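A quick round-trip check on a toy tensor (my own sanity check, not from the original repo):

import torch

x = torch.randn(2, 14, 14, 8)
windows = window_partition(x, 7)
assert torch.equal(window_reverse(windows, 7, 14, 14), x)  # reverse undoes partition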
PatchEmbedding
[B, C, H, W] -> [B, Ph*Pw, 96]
import torch
from torch import nn
from timm.models.layers import to_2tuple
class PatchEmbed(nn.Module):
    r"""Image to Patch Embedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)      # (224, 224)
        patch_size = to_2tuple(patch_size)  # patch size (4, 4)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]  # (56, 56)
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution  # "resolution" here means the number of patches per side
        self.num_patches = patches_resolution[0] * patches_resolution[1]  # 56 * 56 = 3136
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        # use a 2D convolution to perform the patch split; input channels default to 3
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)  # output: (batch, 96, 56, 56)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None
    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        # restrict the input image size to exactly 224*224
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # B C H W -> B embed H*W -> B H*W embed: the patch-embedding shape
        x = self.proj(x)  # patch split: [1, 96, 56, 56]
        x = x.flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x
    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
x = torch.randn(1, 3, 224, 224)
PM = PatchEmbed()
out = PM(x)
print(out.shape)
output:
torch.Size([1, 3136, 96])
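Incidentally, the strided Conv2d used for the projection is equivalent to cutting non-overlapping 4x4 patches and applying a single Linear layer. A small check of this equivalence (my own illustration):

import torch
from torch import nn

conv = nn.Conv2d(3, 96, kernel_size=4, stride=4)
linear = nn.Linear(3 * 4 * 4, 96)
# copy the conv weights into the linear layer (flattened in C, kh, kw order)
linear.weight.data = conv.weight.data.view(96, -1)
linear.bias.data = conv.bias.data

x = torch.randn(1, 3, 224, 224)
out_conv = conv(x).flatten(2).transpose(1, 2)              # (1, 3136, 96)
patches = x.unfold(2, 4, 4).unfold(3, 4, 4)                # (1, 3, 56, 56, 4, 4)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, -1, 3 * 4 * 4)
out_linear = linear(patches)
print(torch.allclose(out_conv, out_linear, atol=1e-5))     # True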
PatchMerging
[B, H*W, C] -> [B, (H/2)*(W/2), 2C]
PatchMerging plays the role of downsampling in a CNN: it reduces the feature-map resolution while increasing the channel dimension. Here H and W are each halved (so the number of tokens drops by 4x overall) and the channel count is doubled.
class PatchMerging(nn.Module):
    r"""Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)
    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"  # the sequence length must match H*W
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
        x = x.view(B, H, W, C)
        # sample H and W with stride 2: two offsets per axis give 4 groups
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        print("\n", x0, "\n", x1, "\n", x2, "\n", x3)  # debug print kept for the demo below
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C: the spatial resolution is halved
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C: the channel dimension is now 4x the original
        x = self.norm(x)
        x = self.reduction(x)  # 4C -> 2C: a linear layer brings the channels from 4x down to 2x
        return x
    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"
    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops
PM = PatchMerging(input_resolution=(4, 4), dim=3)
x = torch.arange(48, dtype=torch.float).view(1, 16, 3)
o = PM(x)
print(o.shape)
output:
tensor([[[[ 0.,  1.,  2.],
          [ 6.,  7.,  8.]],
         [[24., 25., 26.],
          [30., 31., 32.]]]])
tensor([[[[12., 13., 14.],
          [18., 19., 20.]],
         [[36., 37., 38.],
          [42., 43., 44.]]]])
tensor([[[[ 3.,  4.,  5.],
          [ 9., 10., 11.]],
         [[27., 28., 29.],
          [33., 34., 35.]]]])
tensor([[[[15., 16., 17.],
          [21., 22., 23.]],
         [[39., 40., 41.],
          [45., 46., 47.]]]])
torch.Size([1, 4, 6])
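To make the four slicing patterns concrete: on a 4x4 grid of token ids, x0..x3 pick the top-left, bottom-left, top-right and bottom-right token of every 2x2 neighbourhood (a toy illustration of my own):

import torch

ids = torch.arange(16).view(1, 4, 4, 1)
x0 = ids[:, 0::2, 0::2, :]  # top-left of each 2x2 block
x1 = ids[:, 1::2, 0::2, :]  # bottom-left
x2 = ids[:, 0::2, 1::2, :]  # top-right
x3 = ids[:, 1::2, 1::2, :]  # bottom-right
print(torch.cat([x0, x1, x2, x3], -1).view(1, -1, 4))
# tensor([[[ 0,  4,  1,  5],
#          [ 2,  6,  3,  7],
#          [ 8, 12,  9, 13],
#          [10, 14, 11, 15]]])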
Window Attention
Window self-attention is computed within each window only. Compared with standard attention, the formula adds an extra term B, the relative positional bias: Attention(Q, K, V) = SoftMax(QK^T / sqrt(d) + B)V, where for a window of M x M tokens B is looked up from a learnable bias table B-hat of size (2M-1) x (2M-1).
According to the results given in the paper, relative position bias works better than the other choices.
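For reference, this is how the official implementation builds the relative-position index used to look bias values up from the (2M-1)x(2M-1) table, shown here on a toy 2x2 window (the indexing argument to torch.meshgrid assumes a recent PyTorch):

import torch

window_size = (2, 2)  # toy window, M = 2
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)                                # 2, Wh*Ww
# pairwise offsets between every pair of tokens in the window
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()          # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1  # shift offsets to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww x Wh*Ww indices into B-hat
print(relative_position_index)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])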
Mask shift-window visualization
Put more intuitively: the colored regions in the figure are where the self-attention results must be kept, while the white parts are positions to be ignored (masked). Masked positions are filled with -100, which effectively zeroes them out after softmax. The results generated by the code below are consistent with this figure.
# Test code for the attn_mask part
# Assume the input image is 4x4, split into 4 windows of size 2x2 each
import torch

input_resolution = (4, 4)
window_size = 2
shift_size = 1  # shift_size is floor(M/2), with M the window size

def window_partition(x, window_size):  # [B, H, W, C] -> (num_windows*B, window_size, window_size, C)
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

if shift_size > 0:
    # calculate attention mask for SW-MSA
    H, W = input_resolution
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    # pre-planned region indices: slice(start, stop, step)
    h_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    w_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1
    mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, window_size * window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    # non-zero differences mark token pairs from different regions; -100 acts like -inf under softmax
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
    attn_mask = None
print(attn_mask)
output:
tensor([[[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.]],
[[ 0., -100., 0., -100.],
[-100., 0., -100., 0.],
[ 0., -100., 0., -100.],
[-100., 0., -100., 0.]],
[[ 0., 0., -100., -100.],
[ 0., 0., -100., -100.],
[-100., -100., 0., 0.],
[-100., -100., 0., 0.]],
[[ 0., -100., -100., -100.],
[-100., 0., -100., -100.],
[-100., -100., 0., -100.],
[-100., -100., -100., 0.]]])
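This mask works together with a cyclic shift of the feature map before the windowed attention. A minimal sketch of how the shift is applied and then undone (the official code uses torch.roll in the same way):

import torch

x = torch.randn(1, 4, 4, 3)  # B, H, W, C toy feature map
shift_size = 1
# roll the map up and to the left before SW-MSA
shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
# ... window_partition + masked attention would run here ...
# roll back afterwards to restore the original spatial layout
x_restored = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
assert torch.equal(x_restored, x)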
# Visualization code for the masked shift window
import torch
import matplotlib.pyplot as plt

def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

window_size = 7
shift_size = 3
H, W = 14, 14
img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1
mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
plt.matshow(img_mask[0, :, :, 0].numpy())
plt.matshow(attn_mask[0].numpy())
plt.matshow(attn_mask[1].numpy())
plt.matshow(attn_mask[2].numpy())
plt.matshow(attn_mask[3].numpy())
plt.show()
Swin-X Parameter configuration list
Containment relation: Swin_transformer(Basic_layer(Swin_Block))
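For reference, the variant settings from the paper's architecture table, summarized here as config dicts (the values are from the Swin paper; the dict form is my own summary):

swin_configs = {
    "Swin-T": dict(embed_dim=96,  depths=(2, 2, 6, 2),  num_heads=(3, 6, 12, 24)),
    "Swin-S": dict(embed_dim=96,  depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24)),
    "Swin-B": dict(embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32)),
    "Swin-L": dict(embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48)),
}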
References (recommended reading):
1. https://zhuanlan.zhihu.com/p/367111046
2. https://hub.fastgit.xyz/microsoft/Swin-Transformer/issues/38
3. https://zhuanlan.zhihu.com/p/430047908 (very clear)
4. https://blog.csdn.net/qq_37541097/article/details/121119988
5. https://blog.csdn.net/qq_39478403/article/details/120042232