【Semantic Segmentation】2021 PVT (ICCV)
Paper title: Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions
Paper link: https://arxiv.org/abs/2102.12122
Paper code: https://github.com/whai362/PVT
Paper translation: PVT, PVTv2 – Jianshu (jianshu.com)
1. Introduction
1.1 Overview
The previous post summarized the ViT backbone, which was not designed with dense prediction tasks in vision, such as segmentation and detection, in mind. Follow-up work such as SETR simply uses ViT as the encoder and processes the single-scale features it extracts with a fairly simple decoder, which verified that Transformers work for semantic segmentation. However, we know that multi-scale features are very important for semantic segmentation, so PVT proposes a vision Transformer backbone that extracts multi-scale features.
1.2 Problems of ViT in semantic segmentation
In ViT's design, the output feature map keeps essentially the same resolution as the input token grid throughout the network. Applying it to dense prediction tasks such as segmentation and detection runs into two problems:
1) The computational cost explodes
Compared with classification, segmentation and detection usually require inputs at a much higher resolution.
Therefore, on the one hand, we need to split the image into more patches than in a classification task to obtain features of the same granularity. If we instead keep the same number of patches, the features become coarser, which degrades performance.
On the other hand, the cost of the Transformer's self-attention grows quadratically with the number of tokens (i.e. patches): the more patches, the higher the cost. Increasing the number of patches therefore makes an already tight compute budget even tighter; a rough estimate is given below.
This is the first drawback of applying ViT to dense prediction tasks.
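As a back-of-the-envelope sketch (my own estimate, not a figure from the paper): for an $H \times W$ input split into patches of side $P$, self-attention operates over $N = HW/P^2$ tokens of dimension $C$, and its cost is roughly

$$\Omega(\mathrm{MHA}) \approx 2N^2C + 4NC^2, \qquad N = \frac{HW}{P^2}.$$

Halving the patch side $P$ quadruples $N$ and multiplies the quadratic $N^2C$ term by 16; doubling the input resolution at a fixed $P$ has the same effect.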
2) Lack of multi-scale features
ViT's output feature map keeps the same resolution as its input, so when ViT is used as the encoder it can only provide single-scale features.
In CNNs, multi-scale features have long been shown to play an important role in segmentation, detection and similar tasks; classic works such as the DeepLab series and PSPNet make effective use of multi-scale features to improve performance.
Therefore, how to obtain multi-scale features from a vision Transformer is the other problem.
Improvement
After years of development of CNN backbones in computer vision, some general design patterns have crystallized. The most typical one is the pyramid structure.
A brief summary:
1) the resolution of the feature maps gradually decreases as the network gets deeper;
2) the number of channels of the feature maps gradually increases as the network gets deeper.
Almost all dense prediction algorithms are designed around such a feature pyramid.
How can this structure be introduced into a Transformer?
The experiments show that simply stacking multiple independent Transformer encoders, one per stage, works best.
This yields PVT, shown in the figure below: at each stage, a patch embedding progressively reduces the resolution of the input.
2. Network
2.1 The overall architecture

The model consists of 4 stages overall. Each stage is made up of a patch embedding and a number of Transformer blocks (modified with respect to the original Transformer).
2.2 Patch embedding
At the beginning of each stage, the input is first tokenized in the same way as in ViT, i.e. patch embedding. The patch size is $4\times 4$ in the first stage and $2\times 2$ in all the others. The idea is similar to pooling or strided convolution: it reduces the resolution so that the model can extract more abstract information. This means that in every stage except the first, the side length of the feature map is halved and the number of tokens drops by a factor of 4. Each patch is then fed through a linear layer that adjusts the number of channels, and the result is reshaped back into a sequence of patch tokens.
This makes PVT look much like a ResNet overall: the feature maps produced by the 4 stages are 1/4, 1/8, 1/16 and 1/32 of the original image resolution, which also means that PVT can produce features at different scales. A quick sketch of these shapes is given below.
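As a minimal sanity check (my own sketch, assuming a 224×224 input and the channel widths used by pvt_small in the code of section 3):

img_size = 224
strides = [4, 8, 16, 32]            # cumulative downsampling after each stage
embed_dims = [64, 128, 320, 512]    # channels per stage (pvt_small below)
for i, (s, c) in enumerate(zip(strides, embed_dims), start=1):
    h = w = img_size // s
    print(f"stage {i}: {h}x{w}x{c} ({h * w} tokens)")
# stage 1: 56x56x64  (3136 tokens)
# stage 2: 28x28x128 (784 tokens)
# stage 3: 14x14x320 (196 tokens)
# stage 4: 7x7x512   (49 tokens)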
Note: because the number of tokens differs between stages, each stage uses its own position embedding, added right after the patch embedding. When the input image size changes, the position embeddings can be adapted by interpolation (see the sketch below).
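A minimal sketch of that interpolation (the helper below is illustrative and not part of the posted PVT code; its name and details are my own):

import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed, old_hw, new_hw):
    # pos_embed: (1, H*W, C) learned position embedding of one stage
    old_h, old_w = old_hw
    new_h, new_w = new_hw
    c = pos_embed.shape[-1]
    # (1, H*W, C) -> (1, C, H, W) so it can be resized like an image
    pe = pos_embed.reshape(1, old_h, old_w, c).permute(0, 3, 1, 2)
    pe = F.interpolate(pe, size=(new_h, new_w), mode='bilinear', align_corners=False)
    # back to (1, H'*W', C)
    return pe.permute(0, 2, 3, 1).reshape(1, new_h * new_w, c)

# e.g. adapting a stage-1 embedding trained at 224x224 (56x56 tokens) to a 512x512 input (128x128 tokens):
# pos = resize_pos_embed(model.pos_embed1, (56, 56), (128, 128))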
2.3 Spatial-reduction attention(SRA)
After the patch embedding, the token sequence is fed into several Transformer blocks. To further reduce the computational cost, the authors replace multi-head attention (MHA) with the proposed spatial-reduction attention (SRA). The name says what the replacement does: it reduces the spatial resolution of K and V to cut down the computation. Concretely, the resolution of K and V is reduced by a factor of R (so the number of key/value tokens shrinks by R²). The schematic is shown below.

In the implementation, the $(HW, C)$ key/value sequence is first reshaped into an $(H, W, C)$ 3-D feature map, which is then divided into patches of size $R \times R$. Each patch is turned by a linear transformation into patch embeddings of shape $(HW/R^2, C)$ (essentially the same operation as the patch embedding, equivalent to a strided convolution), and finally a layer norm is applied. This greatly reduces the number of K and V tokens.
At the end of each stage, after the sequence has passed through its blocks, the resulting features are reshaped back into a 3-D feature map and fed into the next stage. A quick numeric check of how much SRA shrinks K and V follows.
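A minimal sanity check (my own sketch, reusing the stage-1 numbers from above and the same Conv2d-based reduction as the Attention class in section 3):

import torch
import torch.nn as nn

# Stage 1 of pvt_small: a 56x56 grid of tokens, C=64, reduction ratio R=8.
B, H, W, C, R = 1, 56, 56, 64, 8
x = torch.randn(B, H * W, C)                    # (B, N, C) token sequence

sr = nn.Conv2d(C, C, kernel_size=R, stride=R)   # the spatial-reduction step of SRA
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)     # sequence -> 2-D feature map
x_ = sr(x_).reshape(B, C, -1).permute(0, 2, 1)  # (B, N/R^2, C)

print(x.shape, '->', x_.shape)                  # 3136 tokens -> 49 key/value tokens
# The attention matrix per head is therefore 3136x49 instead of 3136x3136.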
2.4 General overview

Here P is the patch size, C is the feature dimension, R is the reduction ratio explained above, N is the number of heads in multi-head attention, and E is the expansion ratio of the FFN. A summary of the concrete settings, assembled from the code in section 3, follows.
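Since the original table is an image that is not reproduced here, the values below are assembled from the model definitions in section 3 (so they reflect that code rather than the paper's table directly):

Stage            1     2     3     4
P (patch size)   4     2     2     2
C (channels)     64    128   320   512
R (reduction)    8     4     2     1
N (heads)        1     2     5     8
E (FFN ratio)    8     8     4     4

Blocks per stage: pvt_tiny [2, 2, 2, 2], pvt_small [3, 4, 6, 3], pvt_medium [3, 4, 18, 3], pvt_large [3, 8, 27, 3].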
3. Code
Image classification (1000 classes)
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
__all__ = [
'pvt_tiny', 'pvt_small', 'pvt_medium', 'pvt_large'
]
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.sr_ratio = sr_ratio
if sr_ratio > 1:
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = nn.LayerNorm(dim)
def forward(self, x, H, W):
B, N, C = x.shape
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        if self.sr_ratio > 1:
            # SRA: spatially reduce the K/V tokens by sr_ratio in each dimension before projecting them
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
x_ = self.norm(x_)
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
else:
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x, H, W):
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding """
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
            f"img_size {img_size} should be divided by patch_size {patch_size}."
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
self.num_patches = self.H * self.W
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x):
B, C, H, W = x.shape
x = self.proj(x).flatten(2).transpose(1, 2)
x = self.norm(x)
H, W = H // self.patch_size[0], W // self.patch_size[1]
return x, (H, W)
class PyramidVisionTransformer(nn.Module):
def __init__(self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dims=[64, 128, 256, 512],
num_heads=[1, 2, 4, 8],
mlp_ratios=[4, 4, 4, 4],
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
depths=[3, 4, 6, 3],
sr_ratios=[8, 4, 2, 1]):
super().__init__()
self.num_classes = num_classes
        self.depths = depths
        self.embed_dims = embed_dims
# patch_embed
self.patch_embed1 = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans,
embed_dim=embed_dims[0])
self.patch_embed2 = PatchEmbed(img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0],
embed_dim=embed_dims[1])
self.patch_embed3 = PatchEmbed(img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1],
embed_dim=embed_dims[2])
self.patch_embed4 = PatchEmbed(img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2],
embed_dim=embed_dims[3])
# pos_embed
self.pos_embed1 = nn.Parameter(torch.zeros(1, self.patch_embed1.num_patches, embed_dims[0]))
self.pos_drop1 = nn.Dropout(p=drop_rate)
self.pos_embed2 = nn.Parameter(torch.zeros(1, self.patch_embed2.num_patches, embed_dims[1]))
self.pos_drop2 = nn.Dropout(p=drop_rate)
self.pos_embed3 = nn.Parameter(torch.zeros(1, self.patch_embed3.num_patches, embed_dims[2]))
self.pos_drop3 = nn.Dropout(p=drop_rate)
self.pos_embed4 = nn.Parameter(torch.zeros(1, self.patch_embed4.num_patches + 1, embed_dims[3]))
self.pos_drop4 = nn.Dropout(p=drop_rate)
# transformer encoder
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
cur = 0
self.block1 = nn.ModuleList([Block(
dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[0])
for i in range(depths[0])])
cur += depths[0]
self.block2 = nn.ModuleList([Block(
dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[1])
for i in range(depths[1])])
cur += depths[1]
self.block3 = nn.ModuleList([Block(
dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[2])
for i in range(depths[2])])
cur += depths[2]
self.block4 = nn.ModuleList([Block(
dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[3])
for i in range(depths[3])])
self.norm = norm_layer(embed_dims[3])
# cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[3]))
# classification head
self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
# init weights
trunc_normal_(self.pos_embed1, std=.02)
trunc_normal_(self.pos_embed2, std=.02)
trunc_normal_(self.pos_embed3, std=.02)
trunc_normal_(self.pos_embed4, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
def reset_drop_path(self, drop_path_rate):
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
cur = 0
for i in range(self.depths[0]):
self.block1[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[0]
for i in range(self.depths[1]):
self.block2[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[1]
for i in range(self.depths[2]):
self.block3[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[2]
for i in range(self.depths[3]):
self.block4[i].drop_path.drop_prob = dpr[cur + i]
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
# return {'pos_embed', 'cls_token'} # has pos_embed may be better
        return {'cls_token'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
B = x.shape[0]
# stage 1
x, (H, W) = self.patch_embed1(x)
x = x + self.pos_embed1
x = self.pos_drop1(x)
for blk in self.block1:
x = blk(x, H, W)
        # reshape the token sequence back into a 2-D feature map for the next stage's patch embedding
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
# stage 2
x, (H, W) = self.patch_embed2(x)
x = x + self.pos_embed2
x = self.pos_drop2(x)
for blk in self.block2:
x = blk(x, H, W)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
# stage 3
x, (H, W) = self.patch_embed3(x)
x = x + self.pos_embed3
x = self.pos_drop3(x)
for blk in self.block3:
x = blk(x, H, W)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
# stage 4
x, (H, W) = self.patch_embed4(x)
        # the classification variant adds a cls token only in this last stage
        cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed4
x = self.pos_drop4(x)
for blk in self.block4:
x = blk(x, H, W)
x = self.norm(x)
return x[:, 0]
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
def _conv_filter(state_dict, patch_size=16):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k:
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
out_dict[k] = v
return out_dict
@register_model
def pvt_tiny(pretrained=False, **kwargs):
model = PyramidVisionTransformer(
patch_size=4,
embed_dims=[64, 128, 320, 512],
num_heads=[1, 2, 5, 8],
mlp_ratios=[8, 8, 4, 4],
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2],
sr_ratios=[8, 4, 2, 1],
**kwargs)
model.default_cfg = _cfg()
return model
@register_model
def pvt_small(pretrained=False, **kwargs):
model = PyramidVisionTransformer(
patch_size=4,
embed_dims=[64, 128, 320, 512],
num_heads=[1, 2, 5, 8],
mlp_ratios=[8, 8, 4, 4],
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs)
model.default_cfg = _cfg()
return model
@register_model
def pvt_medium(pretrained=False, **kwargs):
model = PyramidVisionTransformer(
patch_size=4,
embed_dims=[64, 128, 320, 512],
num_heads=[1, 2, 5, 8],
mlp_ratios=[8, 8, 4, 4],
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
depths=[3, 4, 18, 3],
sr_ratios=[8, 4, 2, 1],
**kwargs)
model.default_cfg = _cfg()
return model
@register_model
def pvt_large(pretrained=False, **kwargs):
model = PyramidVisionTransformer(
patch_size=4,
embed_dims=[64, 128, 320, 512],
num_heads=[1, 2, 5, 8],
mlp_ratios=[8, 8, 4, 4],
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
**kwargs)
model.default_cfg = _cfg()
return model
@register_model
def pvt_huge_v2(pretrained=False, **kwargs):
model = PyramidVisionTransformer(
patch_size=4,
embed_dims=[128, 256, 512, 768],
num_heads=[2, 4, 8, 12],
mlp_ratios=[8, 8, 4, 4],
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
depths=[3, 10, 60, 3],
sr_ratios=[8, 4, 2, 1],
# drop_rate=0.0, drop_path_rate=0.02)
**kwargs)
model.default_cfg = _cfg()
return model
if __name__ == '__main__':
x=torch.randn(1,3,224,224)
model=pvt_small()
y=model(x)
print(y.shape)
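The classification model above only returns the class-token features of the last stage. For dense prediction, what we actually want are the four per-stage feature maps. Below is a minimal sketch of how one could collect them from the model as defined above; the helper name and structure are my own and not part of the official repository:

def extract_pyramid(model, x):
    # Hypothetical helper: return the four stage feature maps at 1/4, 1/8,
    # 1/16 and 1/32 of the input resolution (what a segmentation head consumes).
    B = x.shape[0]
    outs = []
    for i in range(1, 5):
        patch_embed = getattr(model, f'patch_embed{i}')
        pos_embed = getattr(model, f'pos_embed{i}')
        pos_drop = getattr(model, f'pos_drop{i}')
        blocks = getattr(model, f'block{i}')
        x, (H, W) = patch_embed(x)
        if i == 4:
            # the classification model carries an extra cls token in the last stage
            cls_tokens = model.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)
        x = pos_drop(x + pos_embed)
        for blk in blocks:
            x = blk(x, H, W)
        if i == 4:
            x = x[:, 1:]  # drop the cls token before reshaping
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)
    return outs

feats = extract_pyramid(pvt_small(), torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])
# [(1, 64, 56, 56), (1, 128, 28, 28), (1, 320, 14, 14), (1, 512, 7, 7)]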