当前位置:网站首页>Tensorflow2.4实现RepVGG
Tensorflow2.4实现RepVGG
2022-06-30 19:33:00 【Haohao+++】
前言
RepVGG是清华大学&旷视科技等提出的一种新颖的CNN设计范式,避免了VGG类方法训练所得精度低的问题,又保持了VGG方案的高效推理优点。
论文地址: https://arxiv.org/abs/2101.03697
博客拿出我觉得重要的两点方法。
训练多分支结构
ResNet结构中ResBlock使用的是 $y = x + f(x)$。尽管多分支结构对于推理不友好,但对于训练友好,因此作者将RepVGG设计为训练时为多分支、推理时为单分支的结构。作者参考ResNet的identity与 $1\times1$ 分支,设计了如下形式的模块:
$$y = x + g(x) + f(x)$$
其中:
- $g(x)$ 为 $1\times1$ 卷积(+BN)。
- $f(x)$ 为 $3\times3$ 卷积(+BN)。
简单即快速、节省内存、灵活(Simple is Fast, Memory-economical, Flexible)
- Fast:相比VGG,现有的多分支架构理论上具有更低的FLOPs,但推理速度并不更快。比如VGG16的FLOPs是EfficientNet-B3的8.4倍,但在1080Ti上推理速度反而快1.8倍,这意味着前者的计算密度是后者的15倍。FLOPs与推理速度的矛盾主要源自两个关键因素:(1) MAC(memory access cost,内存访问成本),比如多分支结构中Add与Concat的计算量很小,但MAC很高;(2) 并行度,已有研究表明:并行度高的模型要比并行度低的模型推理速度更快。
- Memory-economical:多分支结构是一种内存低效的架构,这是因为每个分支的结果都需要一直保存到Add/Concat完成,这会导致更大的峰值内存占用;而plain模型则具有更好的内存效率。
- Flexible:多分支结构会限制CNN的灵活性,比如ResBlock会约束两个分支的tensor具有相同的形状;与此同时,多分支结构对于模型剪枝不够友好。
网络结构
代码实现
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Conv2D, BatchNormalization, GlobalAvgPool2D, Activation, Multiply,
    Add, Dense, Input, Reshape, ZeroPadding2D,
)
# ----------------- #
# Convolution + batch normalization
# ----------------- #
def conv_bn(filters, kernel_size, strides, padding, groups=1):
    """Return a layer factory: a bias-free Conv2D followed by BatchNormalization.

    The conv has no bias because the following BatchNormalization supplies
    the shift term.

    Args:
        filters: Number of output channels of the convolution.
        kernel_size: Convolution kernel size.
        strides: Convolution stride.
        padding: Keras padding mode ('same' or 'valid').
        groups: Number of convolution groups.

    Returns:
        A callable mapping an input tensor to the normalized conv output.
    """
    def apply(tensor):
        conv = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides,
                      padding=padding, groups=groups, use_bias=False)
        return BatchNormalization()(conv(tensor))

    return apply
# ----------------- #
# SE (squeeze-and-excitation) block
# ----------------- #
def SE_block(x_0, r = 16):
    """Apply a squeeze-and-excitation gate to `x_0` and return the gated tensor.

    Global-average-pools `x_0` to one value per channel, passes it through two
    1x1 convs (standing in for fully connected layers) with ReLU then sigmoid,
    and multiplies `x_0` channel-wise by the resulting gate.

    Args:
        x_0: Input feature tensor; the last axis is treated as channels and
            must have a static size (assumes channels-last layout).
        r: Channel reduction ratio of the bottleneck conv.

    Returns:
        A tensor with the same shape as `x_0`.
    """
    channels = x_0.shape[-1]
    # Keep at least one hidden unit; channels < r would otherwise request a
    # 0-filter convolution.
    hidden = max(1, channels // r)
    x = GlobalAvgPool2D()(x_0)
    # (batch, C) -> (batch, 1, 1, C) so the 1x1 convs act as dense layers.
    x = Reshape((1, 1, channels))(x)
    x = Conv2D(filters=hidden, kernel_size=1, strides=1)(x)
    x = Activation('relu')(x)
    x = Conv2D(filters=channels, kernel_size=1, strides=1)(x)
    x = Activation('sigmoid')(x)
    # Broadcast the (batch, 1, 1, C) gate over the spatial dims of x_0.
    return Multiply()([x_0, x])
# ----------------- #
# RepVGG block
# ----------------- #
def _pad_like_torch(x, kernel_size, strides, padding, dilation=1):
    """Swap Keras 'same' padding for explicit symmetric padding when strided.

    With stride > 1 and an even input size, TF's 'same' padding adds the extra
    row/column only at the bottom/right, so 3x3 window centres land on rows
    1, 3, 5, ... while a stride-2 1x1 'valid' conv samples rows 0, 2, 4, ...
    Padding (k - 1) // 2 on every side and convolving with 'valid' keeps the
    kernel centre on the same pixels as the 1x1 branch (like PyTorch's
    symmetric padding=1), which the branch fusion of RepVGG relies on.
    Output spatial size is identical to 'same' for odd kernels.

    Returns:
        (tensor, padding) to feed the convolution.
    """
    if (padding != 'same' or strides == 1
            or not isinstance(kernel_size, int) or kernel_size % 2 == 0):
        return x, padding
    pad = dilation * (kernel_size - 1) // 2
    return ZeroPadding2D(padding=pad)(x), 'valid'


def RepVGGBlock(filters, kernel_size, strides=1, padding='valid', dilation=1, groups=1, deploy=False, use_se=False):
    """Return a layer factory building one RepVGG block.

    Training form (deploy=False): sums a `kernel_size` conv+BN branch, a 1x1
    conv+BN branch and, only when the input already has `filters` channels and
    `strides == 1`, an identity BatchNormalization branch. Deploy form
    (deploy=True): a single biased convolution, intended to hold fused
    weights (no conversion routine is included here). Both forms then apply
    an optional SE block and a ReLU.

    Args:
        filters: Number of output channels.
        kernel_size: Kernel size of the main convolution.
        strides: Convolution stride.
        padding: Keras padding mode of the main conv. With the identity branch
            present, 'same' is required for the branch shapes to match.
        dilation: Dilation rate; only used by the deploy-form convolution.
        groups: Number of convolution groups.
        deploy: Build the single-conv form instead of the multi-branch form.
        use_se: Insert an SE block before the final ReLU.

    Returns:
        A callable mapping an input tensor to the block's output tensor.
    """
    def _RepVGGBlock(inputs):
        if deploy:
            x, conv_padding = _pad_like_torch(inputs, kernel_size, strides, padding, dilation)
            x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides,
                       padding=conv_padding, dilation_rate=dilation, groups=groups,
                       use_bias=True)(x)
        else:
            branches = []
            # Identity (BN-only) branch exists only when input and output shapes agree.
            if inputs.shape[-1] == filters and strides == 1:
                branches.append(BatchNormalization()(inputs))
            dense_in, dense_padding = _pad_like_torch(inputs, kernel_size, strides, padding)
            branches.append(conv_bn(filters=filters, kernel_size=kernel_size, strides=strides,
                                    padding=dense_padding, groups=groups)(dense_in))
            # 'same' and 'valid' are identical for a 1x1 kernel.
            branches.append(conv_bn(filters=filters, kernel_size=1, strides=strides,
                                    padding='valid', groups=groups)(inputs))
            x = Add()(branches)
        if use_se:
            x = SE_block(x)
        return Activation('relu')(x)
    return _RepVGGBlock
# ----------------- #
# Stack of RepVGG blocks
# ----------------- #
def make_stage(planes, num_blocks, stride_1, deploy, use_se, override_groups_map=None, start_layer_id=1):
    """Return a layer factory stacking `num_blocks` RepVGG blocks.

    The first block uses stride `stride_1`, the rest stride 1; all use a 3x3
    kernel with 'same' padding. Note that num_blocks < 1 still yields one
    block, since the stride list always starts with `stride_1`.

    Args:
        planes: Output channels of every block in the stage.
        num_blocks: Number of blocks in the stage.
        stride_1: Stride of the first block.
        deploy: Forwarded to RepVGGBlock.
        use_se: Forwarded to RepVGGBlock.
        override_groups_map: Optional {layer_id: groups}; missing ids use 1.
        start_layer_id: Layer id of the stage's first block. Ids are meant to
            be global across stages so group overrides hit the right layers.

    Returns:
        A callable mapping an input tensor to the stage output tensor.
    """
    groups_map = override_groups_map or {}

    def _make_stage(x):
        strides = [stride_1] + [1] * (num_blocks - 1)
        for layer_id, stride in enumerate(strides, start=start_layer_id):
            x = RepVGGBlock(filters=planes, kernel_size=3, strides=stride, padding='same',
                            groups=groups_map.get(layer_id, 1), deploy=deploy,
                            use_se=use_se)(x)
        return x
    return _make_stage
# ----------------- #
# RepVGG network
# ----------------- #
def RepVGG(x, num_blocks, classes=1000, width_multiplier=None, override_groups_map=None, deploy=False, use_se=False):
    """Build the RepVGG graph on input tensor `x` and return its logits tensor.

    A stride-2 stem block is followed by four stages (each starting with a
    stride-2 block), global average pooling and a linear Dense head (no
    softmax is applied).

    Args:
        x: Input tensor.
        num_blocks: Four per-stage block counts.
        classes: Number of output units of the Dense head.
        width_multiplier: Four per-stage width multipliers applied to the
            base widths 64/128/256/512.
        override_groups_map: Optional {layer_id: groups}. Layer ids count
            blocks globally across the four stages starting at 1; the stem
            is never grouped.
        deploy: Build the single-conv (deploy) form of every block.
        use_se: Add SE blocks to every RepVGG block.

    Raises:
        ValueError: If `num_blocks` or `width_multiplier` does not have 4 entries.
    """
    if width_multiplier is None or len(width_multiplier) != 4:
        raise ValueError('width_multiplier must contain 4 per-stage multipliers')
    if len(num_blocks) != 4:
        raise ValueError('num_blocks must contain 4 per-stage block counts')
    override_groups_map = override_groups_map or dict()
    in_planes = min(64, int(64 * width_multiplier[0]))
    out = RepVGGBlock(filters=in_planes, kernel_size=3, strides=2, padding='same',
                      deploy=deploy, use_se=use_se)(x)
    layer_id = 1
    for base_width, depth, multiplier in zip((64, 128, 256, 512), num_blocks, width_multiplier):
        out = make_stage(int(base_width * multiplier), depth, stride_1=2, deploy=deploy,
                         use_se=use_se, override_groups_map=override_groups_map,
                         start_layer_id=layer_id)(out)
        # make_stage always builds at least one block.
        layer_id += max(depth, 1)
    out = GlobalAvgPool2D()(out)
    out = Dense(classes)(out)
    return out
# Block ids made group-wise in the "g2"/"g4" variants (every even id 2..26);
# the maps are passed to RepVGG as override_groups_map.
optional_groupwise_layers = list(range(2, 27, 2))
g2_map = dict.fromkeys(optional_groupwise_layers, 2)
g4_map = dict.fromkeys(optional_groupwise_layers, 4)
# Named RepVGG variants. Each one fixes num_blocks / width_multiplier / groups
# and returns the logits tensor built on the given input. The head size is
# forwarded via RepVGG's `classes` keyword (RepVGG has no `num_classes`).
def RepVGG_A0(inputs, classes=1000, deploy=False):
    """Build RepVGG-A0 on `inputs`."""
    return RepVGG(inputs, num_blocks=[2, 4, 14, 1], classes=classes,
                  width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None, deploy=deploy)


def create_RepVGG_A1(x, deploy=False, classes=1000):
    """Build RepVGG-A1 on `x`."""
    return RepVGG(x, num_blocks=[2, 4, 14, 1], classes=classes,
                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy)


def create_RepVGG_A2(x, deploy=False, classes=1000):
    """Build RepVGG-A2 on `x`."""
    return RepVGG(x, num_blocks=[2, 4, 14, 1], classes=classes,
                  width_multiplier=[1.5, 1.5, 1.5, 2.75], override_groups_map=None, deploy=deploy)


def create_RepVGG_B0(x, deploy=False, classes=1000):
    """Build RepVGG-B0 on `x`."""
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=classes,
                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy)


def create_RepVGG_B1(x, deploy=False, classes=1000):
    """Build RepVGG-B1 on `x`."""
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=classes,
                  width_multiplier=[2, 2, 2, 4], override_groups_map=None, deploy=deploy)


def create_RepVGG_B1g2(x, deploy=False, classes=1000):
    """Build RepVGG-B1g2 (groups=2 on g2_map layers) on `x`."""
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=classes,
                  width_multiplier=[2, 2, 2, 4], override_groups_map=g2_map, deploy=deploy)


def create_RepVGG_B1g4(x, deploy=False, classes=1000):
    """Build RepVGG-B1g4 (groups=4 on g4_map layers) on `x`."""
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=classes,
                  width_multiplier=[2, 2, 2, 4], override_groups_map=g4_map, deploy=deploy)


def create_RepVGG_B2(x, deploy=False, classes=1000):
    """Build RepVGG-B2 on `x`."""
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=classes,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy)


def create_RepVGG_B2g2(x, deploy=False, classes=1000):
    """Build RepVGG-B2g2 (groups=2 on g2_map layers) on `x`."""
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=classes,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g2_map, deploy=deploy)


def create_RepVGG_B2g4(x, deploy=False, classes=1000):
    """Build RepVGG-B2g4 (groups=4 on g4_map layers) on `x`."""
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=classes,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g4_map, deploy=deploy)


def create_RepVGG_B3(x, deploy=False, classes=1000):
    """Build RepVGG-B3 on `x`."""
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=classes,
                  width_multiplier=[3, 3, 3, 5], override_groups_map=None, deploy=deploy)


def create_RepVGG_B3g2(x, deploy=False, classes=1000):
    """Build RepVGG-B3g2 (groups=2 on g2_map layers) on `x`."""
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=classes,
                  width_multiplier=[3, 3, 3, 5], override_groups_map=g2_map, deploy=deploy)


def create_RepVGG_B3g4(x, deploy=False, classes=1000):
    """Build RepVGG-B3g4 (groups=4 on g4_map layers) on `x`."""
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=classes,
                  width_multiplier=[3, 3, 3, 5], override_groups_map=g4_map, deploy=deploy)


def create_RepVGG_D2se(x, deploy=False, classes=1000):
    """Build RepVGG-D2se (SE blocks enabled) on `x`."""
    return RepVGG(x, num_blocks=[8, 14, 24, 1], classes=classes,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy,
                  use_se=True)
if __name__ == '__main__':
    # Build RepVGG-A0 for 224x224 RGB inputs and print the layer summary.
    inputs = Input(shape=(224, 224, 3))
    classes = 1000
    # Pass `classes` through so the head size follows the variable above.
    model = Model(inputs=inputs, outputs=RepVGG_A0(inputs, classes=classes))
    model.summary()
~欢迎更正
边栏推荐
猜你喜欢
WordPress 博客使用火山引擎 veImageX 进行静态资源 CDN 加速(免费)
科大讯飞活跃竞赛汇总!(12个)
Primary school, session 3 - afternoon: Web_ sessionlfi
正则系列之字符类
S7-1500 PLC之间进行TCP通信的具体方法和步骤详解(图文)
4.3-inch touch screen 12 channel control port programmable network central control supports mutual backup of 5 central control hosts
屏幕显示技术进化史
A necessary tool for testing -- postman practical tutorial
Source code analysis of redis ziplist compressed list
CADD课程学习(1)-- 药物设计基础知识
随机推荐
Detailed steps for Django to upload excel tables and write data to the database
MySQL billing Statistics (Part 1): MySQL installation and client dbeaver connection
小学期,第三场-下午:WEB_sessionlfi
SM2246EN+闪迪15131
Audio and video architecture construction in the super video era | science and Intel jointly launched the second season of "architect growth plan"
composer
Why should offline stores do new retail?
新出生的机器狗,打滚1小时后自己掌握走路,吴恩达开山大弟子最新成果
This morning, investors began to travel collectively
Enterprise middle office planning and it architecture microservice transformation
Ten percent of the time, the tar command can't parse the English bracket "()" when decompressing the file
基于slate构建文档编辑器
Primary school, session 3 - afternoon: Web_ sessionlfi
Lombok
盘点华为云GaussDB(for Redis)六大秒级能力
将 EMQX Cloud 数据通过公网桥接到 AWS IoT
Detailed explanation of specific methods and steps for TCP communication between s7-1500 PLCs (picture and text)
Transport layer uses sliding window to realize flow control
超视频时代的音视频架构建设|Science和英特尔联袂推出“架构师成长计划”第二季
Django上传excel表格并将数据写入数据库的详细步骤