A deep learning code base for beginners: one line of code implements 30+ attention mechanisms.
2022-08-05 05:38:00 【FightingCV】

Hello everyone, I'm Xiaoma. I recently created a deep learning code library, and everyone is welcome to come and play with it! The repository lives at https://github.com/xmu-xiaoma666/External-Attention-pytorch and currently implements nearly 40 common deep learning algorithms.
For beginners (like me): When reading papers I kept running into the same problem. The core idea of a paper is often very simple, and the core code may be only a dozen lines, but in the source code the authors release, the proposed module is embedded in a classification, detection, or segmentation framework. The code becomes redundant, and for someone unfamiliar with that task framework it is hard to locate the core code, which makes the paper and the network's ideas harder to understand.
For intermediate users (like you): If you regard basic units such as Conv, FC, and RNN as small Lego bricks, and structures such as Transformer and ResNet as Lego castles that have already been built, then the modules in this project are Lego components, each with complete semantic meaning. They let researchers avoid reinventing the wheel and instead think about how to assemble these "Lego components" into ever more colorful creations.
For experts (maybe like you): My abilities are limited, so please go easy on the criticism!
For everyone: The goal of this project is a code library that a deep learning beginner can understand and that also serves the research and industrial communities. Its guiding principle: from the perspective of code, make no paper in the world hard to read.
(Researchers are also warmly welcome to contribute the core code of their own work to this project and help the research community move forward; authors will be credited in the README.)
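Getting started is a clone-and-import affair. Here is a minimal sketch, assuming you run the script from the repository root so that the model package is importable (ExternalAttention is just one example; every module under model/ follows the same pattern):

# git clone https://github.com/xmu-xiaoma666/External-Attention-pytorch
# cd External-Attention-pytorch
import torch
from model.attention.ExternalAttention import ExternalAttention

input = torch.randn(50, 49, 512)        # (batch, tokens, d_model)
ea = ExternalAttention(d_model=512, S=8)  # the "one line" that builds the attention module
output = ea(input)                        # attention modules preserve the input shape
print(output.shape)                       # torch.Size([50, 49, 512])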
Attention Series
1. External Attention Usage
1.1. Paper
"Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks"
1.2. Overview

1.3. Usage Code
from model.attention.ExternalAttention import ExternalAttention
import torch

input = torch.randn(50, 49, 512)
ea = ExternalAttention(d_model=512, S=8)
output = ea(input)
print(output.shape)
2. Self Attention Usage
2.1. Paper
"Attention Is All You Need"
2.2. Overview

2.3. Usage Code
from model.attention.SelfAttention import ScaledDotProductAttention
import torch

input = torch.randn(50, 49, 512)
sa = ScaledDotProductAttention(d_model=512, d_k=512, d_v=512, h=8)
output = sa(input, input, input)
print(output.shape)
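For reference, each of the h heads in this module computes the scaled dot-product attention defined in the paper:

Attention(Q, K, V) = softmax(Q·Kᵀ / √d_k) · V

where the 1/√d_k scaling keeps the dot products in a range where the softmax gradients remain well-behaved.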
3. Simplified Self Attention Usage
3.1. Paper
None
3.2. Overview

3.3. Usage Code
from model.attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttention
import torch

input = torch.randn(50, 49, 512)
ssa = SimplifiedScaledDotProductAttention(d_model=512, h=8)
output = ssa(input, input, input)
print(output.shape)
4. Squeeze-and-Excitation Attention Usage
4.1. Paper
"Squeeze-and-Excitation Networks"
4.2. Overview

4.3. Usage Code
from model.attention.SEAttention import SEAttention
import torch

input = torch.randn(50, 512, 7, 7)
se = SEAttention(channel=512, reduction=8)
output = se(input)
print(output.shape)
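To make the mechanism concrete, here is a minimal sketch of the squeeze-and-excitation idea from the paper (illustrative only; the repo's SEAttention may differ in details such as weight initialization):

import torch
from torch import nn

class SESketch(nn.Module):
    """Squeeze (global average pool) -> excite (FC bottleneck) -> rescale channels."""
    def __init__(self, channel, reduction=8):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),  # squeeze to a bottleneck
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),  # restore channel dimension
            nn.Sigmoid(),                              # per-channel gate in (0, 1)
        )

    def forward(self, x):               # x: (B, C, H, W)
        w = x.mean(dim=(2, 3))          # squeeze: global average pool -> (B, C)
        w = self.fc(w)                  # excite: channel weights (B, C)
        return x * w[:, :, None, None]  # rescale each channel map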
5. SK Attention Usage
5.1. Paper
"Selective Kernel Networks"
5.2. Overview

5.3. Usage Code
from model.attention.SKAttention import SKAttention
import torch

input = torch.randn(50, 512, 7, 7)
sk = SKAttention(channel=512, reduction=8)
output = sk(input)
print(output.shape)
6. CBAM Attention Usage
6.1. Paper
"CBAM: Convolutional Block Attention Module"
6.2. Overview


6.3. Usage Code
from model.attention.CBAM import CBAMBlock
import torch

input = torch.randn(50, 512, 7, 7)
kernel_size = input.shape[2]  # kernel size of the spatial-attention convolution
cbam = CBAMBlock(channel=512, reduction=16, kernel_size=kernel_size)
output = cbam(input)
print(output.shape)
7. BAM Attention Usage
7.1. Paper
"BAM: Bottleneck Attention Module"
7.2. Overview

7.3. Usage Code
from model.attention.BAM import BAMBlock
import torch

input = torch.randn(50, 512, 7, 7)
bam = BAMBlock(channel=512, reduction=16, dia_val=2)
output = bam(input)
print(output.shape)
8. ECA Attention Usage
8.1. Paper
"ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks"
8.2. Overview

8.3. Usage Code
from model.attention.ECAAttention import ECAAttention
import torch

input = torch.randn(50, 512, 7, 7)
eca = ECAAttention(kernel_size=3)
output = eca(input)
print(output.shape)
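The efficiency comes from replacing SE's fully connected bottleneck with a single 1-D convolution over the pooled channel descriptors. A minimal sketch of that idea (illustrative, not the repo's exact code):

import torch
from torch import nn

class ECASketch(nn.Module):
    def __init__(self, kernel_size=3):
        super().__init__()
        # one shared 1-D conv over the channel axis: only kernel_size parameters
        self.conv = nn.Conv1d(1, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):                 # x: (B, C, H, W)
        y = x.mean(dim=(2, 3))            # global average pool: (B, C)
        y = self.conv(y.unsqueeze(1))     # local cross-channel interaction: (B, 1, C)
        w = self.sigmoid(y).squeeze(1)    # channel weights: (B, C)
        return x * w[:, :, None, None]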
9. DANet Attention Usage
9.1. Paper
"Dual Attention Network for Scene Segmentation"
9.2. Overview

9.3. Usage Code
from model.attention.DANet import DAModule
import torch

input = torch.randn(50, 512, 7, 7)
danet = DAModule(d_model=512, kernel_size=3, H=7, W=7)
print(danet(input).shape)
10. Pyramid Split Attention Usage
10.1. Paper
"EPSANet: An Efficient Pyramid Split Attention Block on Convolutional Neural Network"
10.2. Overview

10.3. Usage Code
from model.attention.PSA import PSA
import torch

input = torch.randn(50, 512, 7, 7)
psa = PSA(channel=512, reduction=8)
output = psa(input)
print(output.shape)
11. Efficient Multi-Head Self-Attention Usage
11.1. Paper
"ResT: An Efficient Transformer for Visual Recognition"
11.2. Overview

11.3. Usage Code
from model.attention.EMSA import EMSA
import torch

input = torch.randn(50, 64, 512)  # 64 tokens = an 8 x 8 feature map
emsa = EMSA(d_model=512, d_k=512, d_v=512, h=8, H=8, W=8, ratio=2, apply_transform=True)
output = emsa(input, input, input)
print(output.shape)
12. Shuffle Attention Usage
12.1. Paper
"SA-NET: SHUFFLE ATTENTION FOR DEEP CONVOLUTIONAL NEURAL NETWORKS"
12.2. Overview

12.3. Usage Code
from model.attention.ShuffleAttention import ShuffleAttention
import torch

input = torch.randn(50, 512, 7, 7)
sa = ShuffleAttention(channel=512, G=8)
output = sa(input)
print(output.shape)
13. MUSE Attention Usage
13.1. Paper
"MUSE: Parallel Multi-Scale Attention for Sequence to Sequence Learning"
13.2. Overview

13.3. Usage Code
from model.attention.MUSEAttention import MUSEAttention
import torch

input = torch.randn(50, 49, 512)
sa = MUSEAttention(d_model=512, d_k=512, d_v=512, h=8)
output = sa(input, input, input)
print(output.shape)
14. SGE Attention Usage
14.1. Paper
"Spatial Group-wise Enhance: Improving Semantic Feature Learning in Convolutional Networks"
14.2. Overview

14.3. Usage Code
from model.attention.SGE import SpatialGroupEnhance
import torch

input = torch.randn(50, 512, 7, 7)
sge = SpatialGroupEnhance(groups=8)
output = sge(input)
print(output.shape)
15. A2 Attention Usage
15.1. Paper
"A2-Nets: Double Attention Networks"
15.2. Overview

15.3. Usage Code
from model.attention.A2Atttention import DoubleAttention
import torch

input = torch.randn(50, 512, 7, 7)
a2 = DoubleAttention(512, 128, 128, True)
output = a2(input)
print(output.shape)
16. AFT Attention Usage
16.1. Paper
"An Attention Free Transformer"
16.2. Overview

16.3. Usage Code
from model.attention.AFT import AFT_FULL
import torch

input = torch.randn(50, 49, 512)
aft_full = AFT_FULL(d_model=512, n=49)
output = aft_full(input)
print(output.shape)
17. Outlook Attention Usage
17.1. Paper
"VOLO: Vision Outlooker for Visual Recognition"
17.2. Overview

17.3. Usage Code
from model.attention.OutlookAttention import OutlookAttention
import torch

input = torch.randn(50, 28, 28, 512)  # channels-last: (B, H, W, C)
outlook = OutlookAttention(dim=512)
output = outlook(input)
print(output.shape)
18. ViP Attention Usage
18.1. Paper
"Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition"
18.2. Overview

18.3. Usage Code
from model.attention.ViP import WeightedPermuteMLP
import torch

input = torch.randn(64, 8, 8, 512)  # channels-last: (B, H, W, C)
seg_dim = 8
vip = WeightedPermuteMLP(512, seg_dim)
out = vip(input)
print(out.shape)
19. CoAtNet Attention Usage
19.1. Paper
"CoAtNet: Marrying Convolution and Attention for All Data Sizes"
19.2. Overview
None
19.3. Usage Code
from model.attention.CoAtNet import CoAtNet
import torch

input = torch.randn(1, 3, 224, 224)
coatnet = CoAtNet(in_ch=3, image_size=224)
out = coatnet(input)
print(out.shape)
20. HaloNet Attention Usage
20.1. Paper
"Scaling Local Self-Attention for Parameter Efficient Visual Backbones"
20.2. Overview

20.3. Usage Code
from model.attention.HaloAttention import HaloAttention
import torch

input = torch.randn(1, 512, 8, 8)
halo = HaloAttention(dim=512, block_size=2, halo_size=1)
output = halo(input)
print(output.shape)
21. Polarized Self-Attention Usage
21.1. Paper
"Polarized Self-Attention: Towards High-quality Pixel-wise Regression"
21.2. Overview

21.3. Usage Code
from model.attention.PolarizedSelfAttention import ParallelPolarizedSelfAttention, SequentialPolarizedSelfAttention
import torch

input = torch.randn(1, 512, 7, 7)
psa = SequentialPolarizedSelfAttention(channel=512)  # or ParallelPolarizedSelfAttention
output = psa(input)
print(output.shape)
22. CoTAttention Usage
22.1. Paper
"Contextual Transformer Networks for Visual Recognition"---arXiv 2021.07.26
22.2. Overview

22.3. Usage Code
from model.attention.CoTAttention import CoTAttention
import torch

input = torch.randn(50, 512, 7, 7)
cot = CoTAttention(dim=512, kernel_size=3)
output = cot(input)
print(output.shape)
23. Residual Attention Usage
23.1. Paper
"Residual Attention: A Simple but Effective Method for Multi-Label Recognition"---ICCV 2021
23.2. Overview

23.3. Usage Code
from model.attention.ResidualAttention import ResidualAttention
import torch

input = torch.randn(50, 512, 7, 7)
resatt = ResidualAttention(channel=512, num_class=1000, la=0.2)
output = resatt(input)
print(output.shape)
24. S2 Attention Usage
24.1. Paper
"S²-MLPv2: Improved Spatial-Shift MLP Architecture for Vision"---arXiv 2021.08.02
24.2. Overview

24.3. Usage Code
from model.attention.S2Attention import S2Attention
import torch

input = torch.randn(50, 512, 7, 7)
s2att = S2Attention(channels=512)
output = s2att(input)
print(output.shape)
25. GFNet Attention Usage
25.1. Paper
"Global Filter Networks for Image Classification"---arXiv 2021.07.01
25.2. Overview

25.3. Usage Code - Implemented by Wenliang Zhao (Author)
from model.attention.gfnet import GFNet
import torch

x = torch.randn(1, 3, 224, 224)
gfnet = GFNet(embed_dim=384, img_size=224, patch_size=16, num_classes=1000)
out = gfnet(x)
print(out.shape)
26. TripletAttention Usage
26.1. Paper
"Rotate to Attend: Convolutional Triplet Attention Module"---CVPR 2021
26.2. Overview

26.3. Usage Code - Implemented by digantamisra98
from model.attention.TripletAttention import TripletAttention
import torch

input = torch.randn(50, 512, 7, 7)
triplet = TripletAttention()
output = triplet(input)
print(output.shape)
27. Coordinate Attention Usage
27.1. Paper
"Coordinate Attention for Efficient Mobile Network Design"---CVPR 2021
27.2. Overview

27.3. Usage Code - Implemented by Andrew-Qibin
from model.attention.CoordAttention import CoordAtt
import torch

inp = torch.rand([2, 96, 56, 56])
inp_dim, oup_dim = 96, 96
reduction = 32
coord_attention = CoordAtt(inp_dim, oup_dim, reduction=reduction)
output = coord_attention(inp)
print(output.shape)
28. MobileViT Attention Usage
28.1. Paper
"MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer"---arXiv 2021.10.05
28.2. Overview

28.3. Usage Code
from model.attention.MobileViTAttention import MobileViTAttention
import torch

if __name__ == '__main__':
    m = MobileViTAttention()
    input = torch.randn(1, 3, 49, 49)
    output = m(input)
    print(output.shape)  # output: (1, 3, 49, 49)
29. ParNet Attention Usage
29.1. Paper
"Non-deep Networks"---arXiv 2021.10.20
29.2. Overview

29.3. Usage Code
from model.attention.ParNetAttention import *
import torch

if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    pna = ParNetAttention(channel=512)
    output = pna(input)
    print(output.shape)  # (50, 512, 7, 7)
30. UFO Attention Usage
30.1. Paper
"UFO-ViT: High Performance Linear Vision Transformer without Softmax"---arXiv 2021.09.29
30.2. Overview

30.3. Usage Code
from model.attention.UFOAttention import *
import torch

if __name__ == '__main__':
    input = torch.randn(50, 49, 512)
    ufo = UFOAttention(d_model=512, d_k=512, d_v=512, h=8)
    output = ufo(input, input, input)
    print(output.shape)  # [50, 49, 512]
31. MobileViTv2 Attention Usage
31.1. Paper
"Separable Self-attention for Mobile Vision Transformers"---arXiv 2022.06.06
31.2. Overview

31.3. Usage Code
from model.attention.MobileViTv2Attention import MobileViTv2Attention
import torch

if __name__ == '__main__':
    input = torch.randn(50, 49, 512)
    sa = MobileViTv2Attention(d_model=512)
    output = sa(input)
    print(output.shape)  # [50, 49, 512]
Backbone Series
1. ResNet Usage
1.1. Paper
"Deep Residual Learning for Image Recognition---CVPR2016 Best Paper"
1.2. Overview


1.3. Usage Code
from model.backbone.resnet import ResNet50, ResNet101, ResNet152
import torch

if __name__ == '__main__':
    input = torch.randn(50, 3, 224, 224)
    resnet50 = ResNet50(1000)
    # resnet101 = ResNet101(1000)
    # resnet152 = ResNet152(1000)
    out = resnet50(input)
    print(out.shape)
2. ResNeXt Usage
2.1. Paper
"Aggregated Residual Transformations for Deep Neural Networks---CVPR2017"
2.2. Overview

2.3. Usage Code
from model.backbone.resnext import ResNeXt50, ResNeXt101, ResNeXt152
import torch

if __name__ == '__main__':
    input = torch.randn(50, 3, 224, 224)
    resnext50 = ResNeXt50(1000)
    # resnext101 = ResNeXt101(1000)
    # resnext152 = ResNeXt152(1000)
    out = resnext50(input)
    print(out.shape)
3. MobileViT Usage
3.1. Paper
"MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer"---arXiv 2021.10.05
3.2. Overview

3.3. Usage Code
from model.backbone.MobileViT import *
import torch

if __name__ == '__main__':
    input = torch.randn(1, 3, 224, 224)

    ### mobilevit_xxs
    mvit_xxs = mobilevit_xxs()
    out = mvit_xxs(input)
    print(out.shape)

    ### mobilevit_xs
    mvit_xs = mobilevit_xs()
    out = mvit_xs(input)
    print(out.shape)

    ### mobilevit_s
    mvit_s = mobilevit_s()
    out = mvit_s(input)
    print(out.shape)
4. ConvMixer Usage
4.1. Paper
"Patches Are All You Need?"---ICLR 2022 (under review)
4.2. Overview

4.3. Usage Code
from model.backbone.ConvMixer import *
import torch

if __name__ == '__main__':
    x = torch.randn(1, 3, 224, 224)
    convmixer = ConvMixer(dim=512, depth=12)
    out = convmixer(x)
    print(out.shape)  # [1, 1000]
MLP Series
1. RepMLP Usage
1.1. Paper
"RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition"
1.2. Overview

1.3. Usage Code
from model.mlp.repmlp import RepMLP
import torch
from torch import nn

N = 4                           # batch size
C = 512                         # input dim
O = 1024                        # output dim
H = 14                          # image height
W = 14                          # image width
h = 7                           # patch height
w = 7                           # patch width
fc1_fc2_reduction = 1           # reduction ratio
fc3_groups = 8                  # groups
repconv_kernels = [1, 3, 5, 7]  # kernel list

repmlp = RepMLP(C, O, H, W, h, w, fc1_fc2_reduction, fc3_groups, repconv_kernels=repconv_kernels)
x = torch.randn(N, C, H, W)
repmlp.eval()
# randomize the BN statistics so the two outputs below are a meaningful comparison
for module in repmlp.modules():
    if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
        nn.init.uniform_(module.running_mean, 0, 0.1)
        nn.init.uniform_(module.running_var, 0, 0.1)
        nn.init.uniform_(module.weight, 0, 0.1)
        nn.init.uniform_(module.bias, 0, 0.1)

# training-structure result
out = repmlp(x)
# inference (re-parameterized) result
repmlp.switch_to_deploy()
deployout = repmlp(x)
# the two should match up to floating-point error
print(((deployout - out) ** 2).sum())
2. MLP-Mixer Usage
2.1. Paper
"MLP-Mixer: An all-MLP Architecture for Vision"
2.2. Overview

2.3. Usage Code
from model.mlp.mlp_mixer import MlpMixer
import torch

mlp_mixer = MlpMixer(num_classes=1000, num_blocks=10, patch_size=10, tokens_hidden_dim=32,
                     channels_hidden_dim=1024, tokens_mlp_dim=16, channels_mlp_dim=1024)
input = torch.randn(50, 3, 40, 40)
output = mlp_mixer(input)
print(output.shape)
3. ResMLP Usage
3.1. Paper
"ResMLP: Feedforward networks for image classification with data-efficient training"
3.2. Overview

3.3. Usage Code
from model.mlp.resmlp import ResMLP
import torch

input = torch.randn(50, 3, 14, 14)
resmlp = ResMLP(dim=128, image_size=14, patch_size=7, class_num=1000)
out = resmlp(input)
print(out.shape)  # the last dimension is class_num
4. gMLP Usage
4.1. Paper
"Pay Attention to MLPs"
4.2. Overview

4.3. Usage Code
from model.mlp.g_mlp import gMLP
import torch

num_tokens = 10000
bs = 50
len_sen = 49
num_layers = 6  # (unused in this snippet)
input = torch.randint(num_tokens, (bs, len_sen))  # (bs, len_sen)
gmlp = gMLP(num_tokens=num_tokens, len_sen=len_sen, dim=512, d_ff=1024)
output = gmlp(input)
print(output.shape)
5. sMLP Usage
5.1. Paper
"Sparse MLP for Image Recognition: Is Self-Attention Really Necessary?"
5.2. Overview

5.3. Usage Code
from model.mlp.sMLP_block import sMLPBlock
import torch

if __name__ == '__main__':
    input = torch.randn(50, 3, 224, 224)
    smlp = sMLPBlock(h=224, w=224)
    out = smlp(input)
    print(out.shape)
Re-Parameter Series
1. RepVGG Usage
1.1. Paper
"RepVGG: Making VGG-style ConvNets Great Again"
1.2. Overview

1.3. Usage Code
from model.rep.repvgg import RepBlock
import torch

input = torch.randn(50, 512, 49, 49)
repblock = RepBlock(512, 512)
repblock.eval()
out = repblock(input)   # multi-branch (training-time) structure
repblock._switch_to_deploy()
out2 = repblock(input)  # single-conv (deploy-time) structure
print('difference between vgg and repvgg')
print(((out2 - out) ** 2).sum())
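The equivalence rests on a simple kernel identity: a 1x1 branch is just a 3x3 convolution whose kernel is zero everywhere except the center, so parallel branches can be summed into a single kernel. A small self-contained check of that idea (illustrative, separate from the repo's RepBlock):

import torch
from torch import nn
from torch.nn import functional as F

x = torch.randn(1, 8, 16, 16)
conv3 = nn.Conv2d(8, 8, 3, padding=1)
conv1 = nn.Conv2d(8, 8, 1)
out_branches = conv3(x) + conv1(x)  # two parallel branches

fused = nn.Conv2d(8, 8, 3, padding=1)
fused.weight.data = conv3.weight.data + F.pad(conv1.weight.data, [1, 1, 1, 1])  # zero-pad 1x1 -> 3x3
fused.bias.data = conv3.bias.data + conv1.bias.data
out_fused = fused(x)
print(((out_branches - out_fused) ** 2).sum().item())  # ~0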
2. ACNet Usage
2.1. Paper
"ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks"
2.2. Overview

2.3. Usage Code
from model.rep.acnet import ACNet
import torch

input = torch.randn(50, 512, 49, 49)
acnet = ACNet(512, 512)
acnet.eval()
out = acnet(input)
acnet._switch_to_deploy()
out2 = acnet(input)
print('difference:')
print(((out2 - out) ** 2).sum())
3. Diverse Branch Block Usage
3.1. Paper
"Diverse Branch Block: Building a Convolution as an Inception-like Unit"
3.2. Overview

3.3. Usage Code
3.3.1 Transform I
from model.rep.ddb import transI_conv_bn
import torch
from torch import nn

input = torch.randn(1, 64, 7, 7)

# conv + bn
conv1 = nn.Conv2d(64, 64, 3, padding=1)
bn1 = nn.BatchNorm2d(64)
bn1.eval()
out1 = bn1(conv1(input))

# fused conv
conv_fuse = nn.Conv2d(64, 64, 3, padding=1)
conv_fuse.weight.data, conv_fuse.bias.data = transI_conv_bn(conv1, bn1)
out2 = conv_fuse(input)

print("difference:", ((out2 - out1) ** 2).sum().item())
3.3.2 Transform II
from model.rep.ddb import transII_conv_branch
import torch
from torch import nn

input = torch.randn(1, 64, 7, 7)

# two parallel conv branches
conv1 = nn.Conv2d(64, 64, 3, padding=1)
conv2 = nn.Conv2d(64, 64, 3, padding=1)
out1 = conv1(input) + conv2(input)

# fused conv
conv_fuse = nn.Conv2d(64, 64, 3, padding=1)
conv_fuse.weight.data, conv_fuse.bias.data = transII_conv_branch(conv1, conv2)
out2 = conv_fuse(input)

print("difference:", ((out2 - out1) ** 2).sum().item())
3.3.3 Transform III
from model.rep.ddb import transIII_conv_sequential
import torch
from torch import nn

input = torch.randn(1, 64, 7, 7)

# sequential 1x1 then 3x3 conv
conv1 = nn.Conv2d(64, 64, 1, padding=0, bias=False)
conv2 = nn.Conv2d(64, 64, 3, padding=1, bias=False)
out1 = conv2(conv1(input))

# fused conv
conv_fuse = nn.Conv2d(64, 64, 3, padding=1, bias=False)
conv_fuse.weight.data = transIII_conv_sequential(conv1, conv2)
out2 = conv_fuse(input)

print("difference:", ((out2 - out1) ** 2).sum().item())
3.3.4 Transform IV
from model.rep.ddb import transIV_conv_concat
import torch
from torch import nn

input = torch.randn(1, 64, 7, 7)

# two branches concatenated along the channel axis
conv1 = nn.Conv2d(64, 32, 3, padding=1)
conv2 = nn.Conv2d(64, 32, 3, padding=1)
out1 = torch.cat([conv1(input), conv2(input)], dim=1)

# fused conv
conv_fuse = nn.Conv2d(64, 64, 3, padding=1)
conv_fuse.weight.data, conv_fuse.bias.data = transIV_conv_concat(conv1, conv2)
out2 = conv_fuse(input)

print("difference:", ((out2 - out1) ** 2).sum().item())
3.3.5 Transform V
from model.rep.ddb import transV_avg
import torch
from torch import nn

input = torch.randn(1, 64, 7, 7)
avg = nn.AvgPool2d(kernel_size=3, stride=1)
out1 = avg(input)

conv = transV_avg(64, 3)  # the equivalent convolution
out2 = conv(input)

print("difference:", ((out2 - out1) ** 2).sum().item())
3.3.6 Transform VI
from model.rep.ddb import transVI_conv_scale
import torch
from torch import nn

input = torch.randn(1, 64, 7, 7)

# parallel 1x1, 1x3 and 3x1 branches
conv1x1 = nn.Conv2d(64, 64, 1)
conv1x3 = nn.Conv2d(64, 64, (1, 3), padding=(0, 1))
conv3x1 = nn.Conv2d(64, 64, (3, 1), padding=(1, 0))
out1 = conv1x1(input) + conv1x3(input) + conv3x1(input)

# fused conv
conv_fuse = nn.Conv2d(64, 64, 3, padding=1)
conv_fuse.weight.data, conv_fuse.bias.data = transVI_conv_scale(conv1x1, conv1x3, conv3x1)
out2 = conv_fuse(input)

print("difference:", ((out2 - out1) ** 2).sum().item())
Convolution Series
1. Depthwise Separable Convolution Usage
1.1. Paper
"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications"
1.2. Overview

1.3. Usage Code
from model.conv.DepthwiseSeparableConvolution import DepthwiseSeparableConvolution
import torch

input = torch.randn(1, 3, 224, 224)
dsconv = DepthwiseSeparableConvolution(3, 64)
out = dsconv(input)
print(out.shape)
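Conceptually, the module factorizes a standard convolution into a per-channel (depthwise) 3x3 convolution followed by a 1x1 pointwise convolution. A minimal sketch of the idea (illustrative; the repo's module may add BatchNorm or activations):

import torch
from torch import nn

class DWSeparableSketch(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.depthwise = nn.Conv2d(in_ch, in_ch, 3, padding=1, groups=in_ch)  # spatial filtering per channel
        self.pointwise = nn.Conv2d(in_ch, out_ch, 1)                          # cross-channel mixing

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

For 3 -> 64 channels with a 3x3 kernel this needs 3·9 + 3·64 = 219 weights versus 64·3·9 = 1728 for a standard convolution, which is where the mobile-friendly savings come from.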
2. MBConv Usage
2.1. Paper
"Efficientnet: Rethinking model scaling for convolutional neural networks"
2.2. Overview

2.3. Usage Code
from model.conv.MBConv import MBConvBlock
import torch

input = torch.randn(1, 3, 224, 224)
mbconv = MBConvBlock(ksize=3, input_filters=3, output_filters=512, image_size=224)
out = mbconv(input)
print(out.shape)
3. Involution Usage
3.1. Paper
"Involution: Inverting the Inherence of Convolution for Visual Recognition"
3.2. Overview

3.3. Usage Code
from model.conv.Involution import Involution
import torch

input = torch.randn(1, 4, 64, 64)
involution = Involution(kernel_size=3, in_channel=4, stride=2)
out = involution(input)
print(out.shape)
4. DynamicConv Usage
4.1. Paper
"Dynamic Convolution: Attention over Convolution Kernels"
4.2. Overview

4.3. Usage Code
from model.conv.DynamicConv import *
import torch

if __name__ == '__main__':
    input = torch.randn(2, 32, 64, 64)
    m = DynamicConv(in_planes=32, out_planes=64, kernel_size=3, stride=1, padding=1, bias=False)
    out = m(input)
    print(out.shape)  # (2, 64, 64, 64)
5. CondConv Usage
5.1. Paper
"CondConv: Conditionally Parameterized Convolut ions for Efficient Inference"
5.2. Overview

5.3. Usage Code
from model.conv.CondConv import *
import torch

if __name__ == '__main__':
    input = torch.randn(2, 32, 64, 64)
    m = CondConv(in_planes=32, out_planes=64, kernel_size=3, stride=1, padding=1, bias=False)
    out = m(input)
    print(out.shape)
We have also started a deep learning WeChat official account, FightingCV. Everyone is welcome to follow it!
Digests of ICCV, CVPR, NeurIPS, and ICML papers: https://github.com/xmu-xiaoma666/FightingCV-Paper-Reading
Beginner-friendly core code for attention, re-parameterization, MLP, and convolution: https://github.com/xmu-xiaoma666/External-Attention-pytorch
To join the discussion group, add the assistant on WeChat: FightngCV666