当前位置:网站首页>ShuffleNet v2 network structure reproduction (PyTorch version)
ShuffleNet v2 network structure reproduction (PyTorch version)
2022-08-04 08:02:00 【Diffie Herman】
ShuffleNet v2网络结构复现
from torch import nn
from torch.nn import functional
import torch
from torchsummary import summary
# ---------------------------- ShuffleBlock start -------------------------------
# Channel shuffle: lets information flow across channel groups
def channel_shuffle(x, groups):
    """Interleave the channels of ``x`` across ``groups`` channel groups.

    The channel axis (dim 1) is viewed as ``(groups, channels_per_group)``,
    those two axes are swapped, and the result is flattened back. Input
    channel ``g * channels_per_group + c`` therefore lands at output
    position ``c * groups + g``.

    Args:
        x: 4-D tensor, unpacked as ``(batch, channels, height, width)``.
        groups: Number of groups the channel axis is split into. The
            channel count must be a multiple of it, otherwise the first
            ``view`` below fails on the element-count mismatch.

    Returns:
        A tensor with the same shape as ``x`` and shuffled channels.
    """
    # Query the shape on the tensor itself; going through the legacy
    # ``x.data`` accessor is unnecessary for a pure shape read.
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups

    # Split the channel axis into (groups, channels_per_group).
    x = x.view(batchsize, groups, channels_per_group, height, width)
    # Swap the two channel sub-axes; ``contiguous`` materialises the new
    # layout so the following ``view`` is legal.
    x = torch.transpose(x, 1, 2).contiguous()
    # Merge the two sub-axes back into a single channel axis.
    return x.view(batchsize, -1, height, width)
class CBRM(nn.Module):
    """Stem block: 3x3 stride-2 conv -> BatchNorm -> ReLU -> 3x3 stride-2 max-pool."""

    def __init__(self, c1, c2):
        """Create the stem layers.

        Args:
            c1: Number of input channels of the convolution.
            c2: Number of output channels of the convolution.
        """
        super().__init__()
        stem_layers = [
            nn.Conv2d(c1, c2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(c2),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stem_layers)
        self.maxpool = nn.MaxPool2d(
            kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False
        )

    def forward(self, x):
        """Apply conv/BN/ReLU, then max-pooling, and return the result."""
        features = self.conv(x)
        return self.maxpool(features)
class Shuffle_Block(nn.Module):
    """Two-branch block whose outputs are concatenated and channel-shuffled.

    With ``stride == 1`` the input is split in half along the channel axis:
    the first half is passed through untouched and the second half goes
    through ``branch2``. With ``stride > 1`` the full input is fed to both
    ``branch1`` and ``branch2``. In both cases the two results are
    concatenated on dim 1 and shuffled with ``channel_shuffle(out, 2)``.
    """

    def __init__(self, inp, oup, stride):
        """Build the branches.

        Args:
            inp: Number of input channels.
            oup: Requested number of output channels; each branch produces
                ``oup // 2`` channels.
            stride: Stride of the depthwise 3x3 convolutions, 1 to 3.

        Raises:
            ValueError: If ``stride`` is outside ``[1, 3]``, or if
                ``stride == 1`` and ``inp != 2 * (oup // 2)``.
        """
        super(Shuffle_Block, self).__init__()
        if not (1 <= stride <= 3):
            raise ValueError('illegal stride value')
        self.stride = stride

        branch_features = oup // 2
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O`` and matches the stride validation above. In the
        # stride-1 path half of the input is concatenated unchanged with
        # the branch2 output, so the input must have 2 * branch_features
        # channels.
        if self.stride == 1 and inp != branch_features << 1:
            raise ValueError(
                'illegal channel configuration: stride 1 requires '
                'inp == 2 * (oup // 2), got inp={}, oup={}'.format(inp, oup)
            )

        # branch1 only exists for strided blocks; the stride-1 path uses
        # the untouched first half of the input in its place.
        if self.stride > 1:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )

        # branch2 sees the full input when striding, otherwise one half.
        self.branch2 = nn.Sequential(
            nn.Conv2d(inp if (self.stride > 1) else branch_features,
                      branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        """Return a ``Conv2d`` with ``groups=i`` (one group per input channel)."""
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x):
        """Run the block and return the shuffled concatenation of both paths."""
        if self.stride == 1:
            # Split along the channel axis (dim 1) into two halves.
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
        out = channel_shuffle(out, 2)
        return out
class ShuffleNetV2(nn.Module):
    """Backbone: a CBRM stem followed by three (stride-2, stride-1) Shuffle_Block pairs.

    The spatial sizes noted in the comments are for a 640x640 input.
    """

    def __init__(self):
        super().__init__()
        # (input channels, output channels) of each down-sampling stage.
        stage_channels = ((32, 128), (128, 256), (256, 512))

        layers = [CBRM(3, 32)]  # 160x160
        for stage_in, stage_out in stage_channels:
            # The stride-2 block halves the resolution (80x80, 40x40,
            # 20x20); the stride-1 block that follows keeps it.
            layers.append(Shuffle_Block(stage_in, stage_out, 2))
            layers.append(Shuffle_Block(stage_out, stage_out, 1))

        self.MobileNet_01 = nn.Sequential(*layers)

    def forward(self, x):
        """Run the input through the stem and all shuffle blocks."""
        return self.MobileNet_01(x)
if __name__ == '__main__':
    # Use the first GPU when one is available, otherwise the CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = ShuffleNetV2().to(device)
    # (3, 640, 640) is the per-sample input size: channels, height, width.
    summary(model, (3, 640, 640), batch_size=1, device="cuda")
    # print(model)
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [1, 32, 320, 320] 864
BatchNorm2d-2 [1, 32, 320, 320] 64
ReLU-3 [1, 32, 320, 320] 0
MaxPool2d-4 [1, 32, 160, 160] 0
CBRM-5 [1, 32, 160, 160] 0
Conv2d-6 [1, 32, 80, 80] 288
BatchNorm2d-7 [1, 32, 80, 80] 64
Conv2d-8 [1, 64, 80, 80] 2,048
BatchNorm2d-9 [1, 64, 80, 80] 128
ReLU-10 [1, 64, 80, 80] 0
Conv2d-11 [1, 64, 160, 160] 2,048
BatchNorm2d-12 [1, 64, 160, 160] 128
ReLU-13 [1, 64, 160, 160] 0
Conv2d-14 [1, 64, 80, 80] 576
BatchNorm2d-15 [1, 64, 80, 80] 128
Conv2d-16 [1, 64, 80, 80] 4,096
BatchNorm2d-17 [1, 64, 80, 80] 128
ReLU-18 [1, 64, 80, 80] 0
Shuffle_Block-19 [1, 128, 80, 80] 0
Conv2d-20 [1, 64, 80, 80] 4,096
BatchNorm2d-21 [1, 64, 80, 80] 128
ReLU-22 [1, 64, 80, 80] 0
Conv2d-23 [1, 64, 80, 80] 576
BatchNorm2d-24 [1, 64, 80, 80] 128
Conv2d-25 [1, 64, 80, 80] 4,096
BatchNorm2d-26 [1, 64, 80, 80] 128
ReLU-27 [1, 64, 80, 80] 0
Shuffle_Block-28 [1, 128, 80, 80] 0
Conv2d-29 [1, 128, 40, 40] 1,152
BatchNorm2d-30 [1, 128, 40, 40] 256
Conv2d-31 [1, 128, 40, 40] 16,384
BatchNorm2d-32 [1, 128, 40, 40] 256
ReLU-33 [1, 128, 40, 40] 0
Conv2d-34 [1, 128, 80, 80] 16,384
BatchNorm2d-35 [1, 128, 80, 80] 256
ReLU-36 [1, 128, 80, 80] 0
Conv2d-37 [1, 128, 40, 40] 1,152
BatchNorm2d-38 [1, 128, 40, 40] 256
Conv2d-39 [1, 128, 40, 40] 16,384
BatchNorm2d-40 [1, 128, 40, 40] 256
ReLU-41 [1, 128, 40, 40] 0
Shuffle_Block-42 [1, 256, 40, 40] 0
Conv2d-43 [1, 128, 40, 40] 16,384
BatchNorm2d-44 [1, 128, 40, 40] 256
ReLU-45 [1, 128, 40, 40] 0
Conv2d-46 [1, 128, 40, 40] 1,152
BatchNorm2d-47 [1, 128, 40, 40] 256
Conv2d-48 [1, 128, 40, 40] 16,384
BatchNorm2d-49 [1, 128, 40, 40] 256
ReLU-50 [1, 128, 40, 40] 0
Shuffle_Block-51 [1, 256, 40, 40] 0
Conv2d-52 [1, 256, 20, 20] 2,304
BatchNorm2d-53 [1, 256, 20, 20] 512
Conv2d-54 [1, 256, 20, 20] 65,536
BatchNorm2d-55 [1, 256, 20, 20] 512
ReLU-56 [1, 256, 20, 20] 0
Conv2d-57 [1, 256, 40, 40] 65,536
BatchNorm2d-58 [1, 256, 40, 40] 512
ReLU-59 [1, 256, 40, 40] 0
Conv2d-60 [1, 256, 20, 20] 2,304
BatchNorm2d-61 [1, 256, 20, 20] 512
Conv2d-62 [1, 256, 20, 20] 65,536
BatchNorm2d-63 [1, 256, 20, 20] 512
ReLU-64 [1, 256, 20, 20] 0
Shuffle_Block-65 [1, 512, 20, 20] 0
Conv2d-66 [1, 256, 20, 20] 65,536
BatchNorm2d-67 [1, 256, 20, 20] 512
ReLU-68 [1, 256, 20, 20] 0
Conv2d-69 [1, 256, 20, 20] 2,304
BatchNorm2d-70 [1, 256, 20, 20] 512
Conv2d-71 [1, 256, 20, 20] 65,536
BatchNorm2d-72 [1, 256, 20, 20] 512
ReLU-73 [1, 256, 20, 20] 0
Shuffle_Block-74 [1, 512, 20, 20] 0
================================================================
Total params: 445,824
Trainable params: 445,824
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 4.69
Forward/backward pass size (MB): 270.31
Params size (MB): 1.70
Estimated Total Size (MB): 276.70
----------------------------------------------------------------
ShuffleNetV2(
(MobileNet_01): Sequential(
(0): CBRM(
(conv): Sequential(
(0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
)
(1): Shuffle_Block(
(branch1): Sequential(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): ReLU(inplace=True)
)
(branch2): Sequential(
(0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64, bias=False)
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): ReLU(inplace=True)
)
)
(2): Shuffle_Block(
(branch2): Sequential(
(0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): ReLU(inplace=True)
)
)
(3): Shuffle_Block(
(branch1): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=128, bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): ReLU(inplace=True)
)
(branch2): Sequential(
(0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=128, bias=False)
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): ReLU(inplace=True)
)
)
(4): Shuffle_Block(
(branch2): Sequential(
(0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): ReLU(inplace=True)
)
)
(5): Shuffle_Block(
(branch1): Sequential(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=256, bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): ReLU(inplace=True)
)
(branch2): Sequential(
(0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=256, bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): ReLU(inplace=True)
)
)
(6): Shuffle_Block(
(branch2): Sequential(
(0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): ReLU(inplace=True)
)
)
)
)
边栏推荐
猜你喜欢
随机推荐
【CNN基础】转置卷积学习笔记
安装GBase 8c数据库的时候,报错显示“Resource:gbase8c already in use”,这怎么处理呢?
金仓数据库KingbaseES客户端编程接口指南-JDBC(5. JDBC 查询结果集处理)
ExoPlayer添加Ffmpeg扩展实现软解功能
The difference between character stream and byte stream
给Unity Behavior Designer(Unity行为树) 的Can See Object 画圆锥辅助图
MYSQL JDBC图书管理系统
占位,稍后补上
redis分布式锁的实现
【剑指Offer】二分法例题
MySQL BIGINT 数据类型
Redis分布式锁的应用
金仓数据库的单节点如何转集群?
ContrstrainLayout的动画之ConstraintSet
经典递归回溯问题之——解数独(LeetCode 37)
轻量化Backbone VGNetG成就“不做选择,全都要”轻量化主干网络
「PHP基础知识」转换数据类型
form表单提交到数据库储存
JMeter 常用的几种断言方法,你会几种呢?
Lightweight Backbone VGNetG Achieves "No Choice, All" Lightweight Backbone Network