【Semantic Segmentation】2021 PVTv2 (CVMJ)
Paper title: PVT v2: Improved Baselines with Pyramid Vision Transformer
Paper link: https://arxiv.org/abs/2106.13797
Paper code: https://github.com/whai362/PVT
Paper translation (Chinese): PVT, PVTv2 - Jianshu (jianshu.com)
1. Introduction
Transformers have recently made encouraging progress in computer vision. In this work, the authors add three improved designs to the original Pyramid Vision Transformer (PVTv1):
- locally continuous features extracted with convolution;
- positional information provided by zero padding;
- attention based on average pooling (linear complexity).
With these simple modifications, PVTv2 is significantly better than PVTv1 on classification, detection, and segmentation. Moreover, with ImageNet-1K pre-training, PVTv2 achieves better performance than recent works, including Swin Transformer.
2. Network
The main limitations of PVTv1 [33] lie in the following three aspects:
(1) Similar to ViT, when processing high-resolution inputs (e.g., a short side of 800 pixels), the computational cost of PVTv1 is relatively high.
(2) PVTv1 treats an image as a sequence of non-overlapping patches, which loses the local continuity of the image to some extent.
(3) The positional encoding in PVTv1 has a fixed size, which is inflexible when processing images of arbitrary size.
These problems limit the performance of PVTv1 on vision tasks.
2.1 Overall framework
PVTv2 keeps the four-stage pyramid design of PVTv1: each stage consists of an overlapping patch embedding layer followed by a stack of Transformer encoder blocks, producing feature maps at strides 4, 8, 16, and 32 with respect to the input.
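For reference, the per-stage configuration used by the implementation in Section 3 can be summarized as a small Python dict. The name stage_settings is only illustrative and does not appear in the code below; the values are taken directly from the Block arguments in Section 3.
# Per-stage settings of the PVTv2 backbone implemented in Section 3
# (strides w.r.t. the input are 4 / 8 / 16 / 32)
stage_settings = {
    'stage1': dict(embed_dim=64,  heads=1, sr_ratio=8, mlp_ratio=8),
    'stage2': dict(embed_dim=128, heads=2, sr_ratio=4, mlp_ratio=8),
    'stage3': dict(embed_dim=320, heads=5, sr_ratio=2, mlp_ratio=4),
    'stage4': dict(embed_dim=512, heads=8, sr_ratio=1, mlp_ratio=4),
}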
2.2 Linear Spatial Reduction Attention
Linear SRA replaces SRA. One point needs clarification here: the PVTv1 paper does not say that convolution is used, but the released code compresses K and V with a Conv2D (see the code on GitHub). In PVTv2, Linear SRA uses average pooling instead of that Conv2D, which makes the attention cost linear in the number of tokens; a sketch is given below.
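The following is a minimal sketch of Linear SRA, assuming the fixed 7x7 pooled size used in the paper; the class name LinearSRAttention is only illustrative. Note that the Attention class in the code of Section 3 keeps the original strided-Conv2D SRA rather than this linear variant.
import torch
from torch import nn

class LinearSRAttention(nn.Module):
    """Attention with linear spatial reduction: K/V are computed from a pooled 7x7 grid."""
    def __init__(self, dim, head, pool_size=7):
        super().__init__()
        assert dim % head == 0
        self.head = head
        self.scale = (dim // head) ** -0.5
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.pool = nn.AdaptiveAvgPool2d(pool_size)  # average pooling instead of strided conv
        self.sr = nn.Conv2d(dim, dim, 1)             # 1x1 projection after pooling
        self.norm = nn.LayerNorm(dim)
        self.act = nn.GELU()

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.head, C // self.head).permute(0, 2, 1, 3)
        # reduce the K/V token set to pool_size**2 tokens, regardless of H and W
        x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
        x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1)
        x_ = self.act(self.norm(x_))
        k, v = self.kv(x_).reshape(B, -1, 2, self.head, C // self.head).permute(2, 0, 3, 1, 4)
        attn = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(out)

# quick shape check
tokens = torch.randn(1, 56 * 56, 64)
print(LinearSRAttention(dim=64, head=1)(tokens, 56, 56).shape)  # torch.Size([1, 3136, 64])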
2.3 Overlapping Patch Embedding
Instead of splitting the image into non-overlapping patches, PVTv2 tokenizes it with a convolution whose kernel is larger than its stride (7x7 with stride 4 in the first stage, 3x3 with stride 2 afterwards) and zero padding, so neighbouring patches overlap and local continuity is preserved. This corresponds to the PatchEmbed class in the code below.
2.4 Convolutional FeedForward
PVTv2 removes the fixed-size positional encoding and instead inserts a 3x3 depth-wise convolution between the first fully connected layer and the GELU in the feed-forward network; the zero padding of this convolution supplies positional information, so the model can handle inputs of arbitrary resolution. This corresponds to the MLP and DWConv classes in the code below.
3. Code pvt2-upernet
import torch
from torch import nn, Tensor
from torch.nn import functional as F
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample, applied in the main path of residual blocks.

    Copied from timm. The original name 'DropConnect' is misleading, since DropConnect is a
    different form of dropout from a separate paper; see
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 for the discussion.
    """
    def __init__(self, p: float = 0.):
        super().__init__()
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        if self.p == 0. or not self.training:
            return x
        kp = 1 - self.p
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = kp + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize
        return x.div(kp) * random_tensor
class DWConv(nn.Module):
    """3x3 depth-wise convolution applied on the (H, W) token grid (used inside the FFN)."""
    def __init__(self, dim):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)

    def forward(self, x: Tensor, H: int, W: int) -> Tensor:
        B, _, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        return x.flatten(2).transpose(1, 2)
class MLP(nn.Module):
    """Convolutional feed-forward network: FC -> DWConv 3x3 -> GELU -> FC."""
    def __init__(self, dim, hidden_dim, out_dim=None) -> None:
        super().__init__()
        out_dim = out_dim or dim
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.dwconv = DWConv(hidden_dim)

    def forward(self, x: Tensor, H: int, W: int) -> Tensor:
        return self.fc2(F.gelu(self.dwconv(self.fc1(x), H, W)))
class Attention(nn.Module):
    """Multi-head self-attention with spatial-reduction attention (SRA):
    for sr_ratio > 1, K and V are computed from a feature map downsampled by a strided conv."""
    def __init__(self, dim, head, sr_ratio):
        super().__init__()
        self.head = head
        self.sr_ratio = sr_ratio
        self.scale = (dim // head) ** -0.5
        self.q = nn.Linear(dim, dim, bias=True)
        self.kv = nn.Linear(dim, dim * 2, bias=True)
        self.proj = nn.Linear(dim, dim)
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, sr_ratio, sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor, H, W) -> Tensor:
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.head, C // self.head).permute(0, 2, 1, 3)
        if self.sr_ratio > 1:
            # spatial reduction: downsample the token map before computing K and V
            x = x.permute(0, 2, 1).reshape(B, C, H, W)
            x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)
            x = self.norm(x)
        k, v = self.kv(x).reshape(B, -1, 2, self.head, C // self.head).permute(2, 0, 3, 1, 4)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x
class Block(nn.Module):
    """Pre-norm Transformer block: LN -> Attention -> residual, LN -> MLP -> residual."""
    def __init__(self, dim, head, sr_ratio=1, mlp_ratio=4, dpr=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, head, sr_ratio)
        self.drop_path = DropPath(dpr) if dpr > 0. else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim * mlp_ratio))

    def forward(self, x: Tensor, H, W) -> Tensor:
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
        return x
class PatchEmbed(nn.Module):
    """Overlapping patch embedding: a conv with kernel > stride and zero padding (patch_size // 2)."""
    def __init__(self, c1=3, c2=64, patch_size=7, stride=4):
        super().__init__()
        self.proj = nn.Conv2d(c1, c2, patch_size, stride, patch_size // 2)
        self.norm = nn.LayerNorm(c2)

    def forward(self, x: Tensor) -> Tensor:
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W
pvtv2_settings = {
'B1': [2, 2, 2, 2], # depths
'B2': [3, 4, 6, 3],
'B3': [3, 4, 18, 3],
'B4': [3, 8, 27, 3],
'B5': [3, 6, 40, 3]
}
class PVTv2(nn.Module):
    def __init__(self, model_name: str = 'B1') -> None:
        super().__init__()
        assert model_name in pvtv2_settings.keys(), f"PVTv2 model name should be in {list(pvtv2_settings.keys())}"
        depths = pvtv2_settings[model_name]
        embed_dims = [64, 128, 320, 512]
        drop_path_rate = 0.1
        self.embed_dims = embed_dims

        # patch embeddings (overlapping): stage 1 downsamples by 4, later stages by 2 each
        self.patch_embed1 = PatchEmbed(3, embed_dims[0], 7, 4)
        self.patch_embed2 = PatchEmbed(embed_dims[0], embed_dims[1], 3, 2)
        self.patch_embed3 = PatchEmbed(embed_dims[1], embed_dims[2], 3, 2)
        self.patch_embed4 = PatchEmbed(embed_dims[2], embed_dims[3], 3, 2)

        # stochastic depth decay rule: drop-path rate increases linearly over all blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # transformer encoder: Block(dim, heads, sr_ratio, mlp_ratio, drop_path)
        cur = 0
        self.block1 = nn.ModuleList([Block(embed_dims[0], 1, 8, 8, dpr[cur + i]) for i in range(depths[0])])
        self.norm1 = nn.LayerNorm(embed_dims[0])

        cur += depths[0]
        self.block2 = nn.ModuleList([Block(embed_dims[1], 2, 4, 8, dpr[cur + i]) for i in range(depths[1])])
        self.norm2 = nn.LayerNorm(embed_dims[1])

        cur += depths[1]
        self.block3 = nn.ModuleList([Block(embed_dims[2], 5, 2, 4, dpr[cur + i]) for i in range(depths[2])])
        self.norm3 = nn.LayerNorm(embed_dims[2])

        cur += depths[2]
        self.block4 = nn.ModuleList([Block(embed_dims[3], 8, 1, 4, dpr[cur + i]) for i in range(depths[3])])
        self.norm4 = nn.LayerNorm(embed_dims[3])
    def forward(self, x: Tensor) -> Tensor:
        B = x.shape[0]

        # stage 1
        x, H, W = self.patch_embed1(x)
        for blk in self.block1:
            x = blk(x, H, W)
        x1 = self.norm1(x).reshape(B, H, W, -1).permute(0, 3, 1, 2)

        # stage 2
        x, H, W = self.patch_embed2(x1)
        for blk in self.block2:
            x = blk(x, H, W)
        x2 = self.norm2(x).reshape(B, H, W, -1).permute(0, 3, 1, 2)

        # stage 3
        x, H, W = self.patch_embed3(x2)
        for blk in self.block3:
            x = blk(x, H, W)
        x3 = self.norm3(x).reshape(B, H, W, -1).permute(0, 3, 1, 2)

        # stage 4
        x, H, W = self.patch_embed4(x3)
        for blk in self.block4:
            x = blk(x, H, W)
        x4 = self.norm4(x).reshape(B, H, W, -1).permute(0, 3, 1, 2)

        return x1, x2, x3, x4
class PPM(nn.ModuleList):
    """Pyramid Pooling Module (PSPNet, https://arxiv.org/abs/1612.01105, CVPR 2017).

    This variant pools the input to several fixed sizes with adaptive max pooling,
    projects each pooled map with a 1x1 convolution, and upsamples it back to the
    input resolution.
    """
    def __init__(self, pool_sizes, in_channels, out_channels):
        super(PPM, self).__init__()
        self.pool_sizes = pool_sizes
        self.in_channels = in_channels
        self.out_channels = out_channels
        for pool_size in pool_sizes:
            self.append(
                nn.Sequential(
                    nn.AdaptiveMaxPool2d(pool_size),
                    nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1),
                )
            )

    def forward(self, x):
        out_puts = []
        for ppm in self:
            ppm_out = nn.functional.interpolate(ppm(x), size=(x.size(2), x.size(3)), mode='bilinear',
                                                align_corners=True)
            out_puts.append(ppm_out)
        return out_puts
class PPMHEAD(nn.Module):
    """PPM head: concatenate the pooled branches with the input and fuse them with a 1x1 conv."""
    def __init__(self, in_channels, out_channels, pool_sizes=[1, 2, 3, 6]):
        super(PPMHEAD, self).__init__()
        self.pool_sizes = pool_sizes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.psp_modules = PPM(self.pool_sizes, self.in_channels, self.out_channels)
        self.final = nn.Sequential(
            nn.Conv2d(self.in_channels + len(self.pool_sizes) * self.out_channels, 4 * self.out_channels,
                      kernel_size=1),
            nn.BatchNorm2d(4 * self.out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        out = self.psp_modules(x)
        out.append(x)
        out = torch.cat(out, 1)
        out = self.final(out)
        return out
class FPNHEAD(nn.Module):
    def __init__(self, out_channels=512, num_classes=19, channels=[64, 128, 320, 512]):
        """
        Args:
            out_channels: number of channels after the final fusion (before classification)
            num_classes: number of output classes
            channels: channel dimensions of the four backbone stages
        """
        super(FPNHEAD, self).__init__()
        self.num_classes = num_classes
        self.PPMHead = PPMHEAD(in_channels=channels[-1], out_channels=channels[-1] // 4)

        self.Conv_fuse1 = nn.Sequential(
            nn.Conv2d(channels[-2], channels[-2], 1),
            nn.BatchNorm2d(channels[-2]),
            nn.ReLU()
        )
        self.Conv_fuse1_ = nn.Sequential(
            nn.Conv2d(channels[-2] + channels[-1], channels[-2], 1),
            nn.BatchNorm2d(channels[-2]),
            nn.ReLU()
        )
        self.Conv_fuse2 = nn.Sequential(
            nn.Conv2d(channels[-3], channels[-3], 1),
            nn.BatchNorm2d(channels[-3]),
            nn.ReLU()
        )
        self.Conv_fuse2_ = nn.Sequential(
            nn.Conv2d(channels[-3] + channels[-2], channels[-3], 1),
            nn.BatchNorm2d(channels[-3]),
            nn.ReLU()
        )
        self.Conv_fuse3 = nn.Sequential(
            nn.Conv2d(channels[-4], channels[-4], 1),
            nn.BatchNorm2d(channels[-4]),
            nn.ReLU()
        )
        self.Conv_fuse3_ = nn.Sequential(
            nn.Conv2d(channels[-4] + channels[-3], channels[-4], 1),
            nn.BatchNorm2d(channels[-4]),
            nn.ReLU()
        )

        self.fuse_all = nn.Sequential(
            nn.Conv2d(sum(channels), out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        self.cls_seg = nn.Sequential(
            nn.Conv2d(out_channels, self.num_classes, kernel_size=3, padding=1),
        )
    def forward(self, input_fpn):
        """
        Args:
            input_fpn: the four backbone feature maps, e.g. for a 224x224 input:
                x1 = (1, 64, 56, 56), x2 = (1, 128, 28, 28), x3 = (1, 320, 14, 14), x4 = (1, 512, 7, 7)
        """
        # the 1/32 feature map goes through the PPM head: x1 = (1, 512, 7, 7)
        x1 = self.PPMHead(input_fpn[-1])

        # (1, 512, 7, 7) -> (1, 512, 14, 14)
        x = F.interpolate(x1,
                          size=(x1.size(2) * 2, x1.size(3) * 2),
                          mode='bilinear',
                          align_corners=True)
        # fuse with the 1/16 feature map: simple concatenation along the channel dimension
        # (1, 512, 14, 14) + (1, 320, 14, 14) = (1, 832, 14, 14)
        x = torch.cat([x, self.Conv_fuse1(input_fpn[-2])], dim=1)

        # (1, 832, 14, 14) -> (1, 320, 14, 14), reduce the number of channels
        x2 = self.Conv_fuse1_(x)
        # (1, 320, 14, 14) -> (1, 320, 28, 28)
        x = F.interpolate(x2,
                          size=(x2.size(2) * 2, x2.size(3) * 2),
                          mode='bilinear',
                          align_corners=True)
        # fuse with the 1/8 feature map
        # (1, 320, 28, 28) + (1, 128, 28, 28) = (1, 448, 28, 28)
        x = torch.cat([x, self.Conv_fuse2(input_fpn[-3])], dim=1)

        # (1, 448, 28, 28) -> (1, 128, 28, 28), reduce the number of channels
        x3 = self.Conv_fuse2_(x)
        # upsample 1/8 -> 1/4: (1, 128, 28, 28) -> (1, 128, 56, 56)
        x = F.interpolate(x3,
                          size=(x3.size(2) * 2, x3.size(3) * 2),
                          mode='bilinear',
                          align_corners=True)
        # fuse with the 1/4 feature map
        # (1, 128, 56, 56) + (1, 64, 56, 56) = (1, 192, 56, 56)
        x = torch.cat([x, self.Conv_fuse3(input_fpn[-4])], dim=1)

        # (1, 192, 56, 56) -> (1, 64, 56, 56)
        x4 = self.Conv_fuse3_(x)

        # upsample every level to 1/4 resolution and fuse them all
        x1 = F.interpolate(x1, x4.size()[-2:], mode='bilinear', align_corners=True)
        x2 = F.interpolate(x2, x4.size()[-2:], mode='bilinear', align_corners=True)
        x3 = F.interpolate(x3, x4.size()[-2:], mode='bilinear', align_corners=True)
        x = self.fuse_all(torch.cat([x1, x2, x3, x4], 1))

        # upsample by 4x back to the input resolution and classify
        x = F.interpolate(x, size=(x.size(2) * 4, x.size(3) * 4), mode='bilinear', align_corners=True)
        x = self.cls_seg(x)
        return x
class pvt2_upernet(nn.Module):
    def __init__(self, num_classes, channels, size="B1"):
        """
        Args:
            num_classes: number of output classes
            channels: channel dimensions of the four backbone stages
            size: PVTv2 variant name, e.g. "B1"
        """
        super(pvt2_upernet, self).__init__()
        self.backbone = PVTv2(size)
        self.decoder = FPNHEAD(num_classes=num_classes, channels=channels)

    def forward(self, x):
        x = self.backbone(x)
        x = self.decoder(x)
        return x
def pvt2_B1_upernet(num_classes):
    model = pvt2_upernet(num_classes=num_classes, size="B1", channels=[64, 128, 320, 512])
    return model


def pvt2_B2_upernet(num_classes):
    model = pvt2_upernet(num_classes=num_classes, size="B2", channels=[64, 128, 320, 512])
    return model


def pvt2_B3_upernet(num_classes):
    model = pvt2_upernet(num_classes=num_classes, size="B3", channels=[64, 128, 320, 512])
    return model


def pvt2_B4_upernet(num_classes):
    model = pvt2_upernet(num_classes=num_classes, size="B4", channels=[64, 128, 320, 512])
    return model
if __name__ == '__main__':
    x = torch.randn(1, 3, 224, 224)
    model = pvt2_B2_upernet(num_classes=19)
    y = model(x)
    print(y.shape)  # torch.Size([1, 19, 224, 224])
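The smoke test above only checks shapes. As a rough illustration that is not part of the original post, a single supervised training step with this model could look like the sketch below; the batch, label tensor, and optimizer settings are placeholders, and labels are assumed to be class indices with the same spatial size as the model output.
# Illustrative training step (placeholder data and arbitrary hyper-parameters)
import torch
from torch import nn

model = pvt2_B2_upernet(num_classes=19)
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-5, weight_decay=0.01)
criterion = nn.CrossEntropyLoss(ignore_index=255)

images = torch.randn(2, 3, 224, 224)           # placeholder image batch
labels = torch.randint(0, 19, (2, 224, 224))   # placeholder ground-truth masks

logits = model(images)                         # (2, 19, 224, 224)
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())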
References
PVTv2: Improved Baselines with Pyramid Vision Transformer — PVT2 paper notes - Zhihu (zhihu.com)