当前位置:网站首页>Tensorflow2.4实现RepVGG
Tensorflow2.4实现RepVGG
2022-06-30 19:33:00 【Haohao+++】
前言
RepVGG是清华大学&旷视科技等提出的一种新颖的CNN设计范式,避免了VGG类方法训练所得精度低的问题,又保持了VGG方案的高效推理优点。
论文地址: https://arxiv.org/abs/2101.03697
博客拿出我觉得重要的两点方法。
训练多分支结构
ResNet结构中ResBlock使用的是 y = x + f ( x ) y=x+f(x) y=x+f(x),尽管多分支结构对于推理不友好,但对于训练友好,作者将RepVGG设计为训练时的多分支,推理时单分支结构。作者参考ResNet的identity与 1 × 1 1\times1 1×1分支,设计了如下形式模块:
y = x + g ( x ) + f ( x ) y=x+g(x)+f(x) y=x+g(x)+f(x)
其中:
- g ( x ) g(x) g(x)为1x1卷积。
- f ( x ) f(x) f(x)为3x3卷积。
简单是快,内存经济型,灵活
- Fast:相比VGG,现有的多分支架构理论上具有更低的Flops,但推理速度并未更快。比如VGG16的参数量为EfficientNetB3的8.4倍,但在1080Ti上推理速度反而快1.8倍。这就意味着前者的计算密度是后者的15倍。Flops与推理速度的矛盾主要源自两个关键因素:(1) MAC(memory access cose),比如多分支结构的Add与Cat的计算很小,但MAC很高; (2)并行度,已有研究表明:并行度高的模型要比并行度低的模型推理速度更快。
- Memory-economical:多分支结构是一种内存低效的架构,这是因为每个分支的结构都需要在Add/Concat之前保存,这会导致更大的峰值内存占用;而plain模型则具有更好的内存高效特征。
- Flexible:多分支结构会限制CNN的灵活性,比如ResBlock会约束两个分支的tensor具有相同的形状;与此同时,多分支结构对于模型剪枝不够友好。
网络结构


代码实现
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
Conv2D, BatchNormalization, GlobalAvgPool2D, Activation, Multiply,
Add, Dense, Input
)
# ----------------- #
# 卷积+标准化
# ----------------- #
def conv_bn(filters, kernel_size, strides, padding, groups=1):
def _conv_bn(x):
x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides,
padding=padding, groups=groups, use_bias=False)(x)
x = BatchNormalization()(x)
return x
return _conv_bn
# ----------------- #
# SE模块
# ----------------- #
def SE_block(x_0, r = 16):
channels = x_0.shape[-1]
x = GlobalAvgPool2D()(x_0)
# (?, ?) -> (?, 1, 1, ?)
x = x[:, None, None, :]
# 用2个1x1卷积代替全连接
x = Conv2D(filters=channels//r, kernel_size=1, strides=1)(x)
x = Activation('relu')(x)
x = Conv2D(filters=channels, kernel_size=1, strides=1)(x)
x = Activation('sigmoid')(x)
x = Multiply()([x_0, x])
return x
# ----------------- #
# RepVGG模块
# ----------------- #
def RepVGGBlock(filters, kernel_size, strides=1, padding='valid', dilation=1, groups=1, deploy=False, use_se=False):
def _RepVGGBlock(inputs):
if deploy:
if use_se:
x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides,
padding=padding, dilation_rate=dilation, groups=groups, use_bias=True)(inputs)
x = SE_block(x)
x = Activation('relu')(x)
else:
x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides,
padding=padding, dilation_rate=dilation, groups=groups, use_bias=True)(inputs)
x = Activation('relu')(x)
return x
if inputs.shape[-1] == filters and strides == 1:
if use_se:
id_out = BatchNormalization()(inputs)
x1 = conv_bn(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, groups=groups)(inputs)
x2 = conv_bn(filters=filters, kernel_size=1, strides=strides, padding=padding, groups=groups)(inputs)
x3 = Add()([id_out, x1, x2])
x4 = SE_block(x3)
return Activation('relu')(x4)
else:
id_out = BatchNormalization()(inputs)
x1 = conv_bn(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, groups=groups)(inputs)
x2 = conv_bn(filters=filters, kernel_size=1, strides=strides, padding=padding, groups=groups)(inputs)
x3 = Add()([id_out, x1, x2])
return Activation('relu')(x3)
else:
if use_se:
x1 = conv_bn(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, groups=groups)(inputs)
x2 = conv_bn(filters=filters, kernel_size=1, strides=strides, padding='valid', groups=groups)(inputs)
x3 = Add()([x1, x2])
x4 = SE_block(x3)
return Activation('relu')(x4)
else:
x1 = conv_bn(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, groups=groups)(inputs)
x2 = conv_bn(filters=filters, kernel_size=1, strides=strides, padding='valid', groups=groups)(inputs)
x3 = Add()([x1, x2])
return Activation('relu')(x3)
return _RepVGGBlock
# ----------------- #
# RepVGG模块的堆叠
# ----------------- #
def make_stage(planes, num_blocks, stride_1,deploy,use_se, override_groups_map=None):
def _make_stage(x):
cur_layer_id=1
strides = [stride_1] + [1]*(num_blocks-1)
for stride in strides:
cur_groups = override_groups_map.get(cur_layer_id, 1)
x = RepVGGBlock(filters=planes, kernel_size=3, strides=stride, padding='same',
groups=cur_groups, deploy=deploy, use_se=use_se)(x)
cur_layer_id += 1
return x
return _make_stage
# ----------------- #
# RepVGG网络
# ----------------- #
def RepVGG(x, num_blocks, classes=1000, width_multiplier=None, override_groups_map=None, deploy=False, use_se=False):
override_groups_map = override_groups_map or dict()
in_planes = min(64, int(64 * width_multiplier[0]))
out = RepVGGBlock(filters=in_planes, kernel_size=3, strides=2, padding='same', deploy=deploy, use_se=use_se)(x)
out = make_stage(int(64 * width_multiplier[0]), num_blocks[0], stride_1=2, deploy=deploy, use_se=use_se, override_groups_map=override_groups_map)(out)
out = make_stage(int(128 * width_multiplier[1]), num_blocks[1], stride_1=2, deploy=deploy, use_se=use_se, override_groups_map=override_groups_map)(out)
out = make_stage(int(256 * width_multiplier[2]), num_blocks[2], stride_1=2, deploy=deploy, use_se=use_se, override_groups_map=override_groups_map)(out)
out = make_stage(int(512 * width_multiplier[3]), num_blocks[3], stride_1=2, deploy=deploy, use_se=use_se, override_groups_map=override_groups_map)(out)
out = GlobalAvgPool2D()(out)
out = Dense(classes)(out)
return out
optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
g2_map = {
l: 2 for l in optional_groupwise_layers}
g4_map = {
l: 4 for l in optional_groupwise_layers}
def RepVGG_A0(inputs,classes=1000, deploy=False):
return RepVGG(inputs, num_blocks=[2, 4, 14, 1], classes=classes,
width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None, deploy=deploy)
def create_RepVGG_A1(x, deploy=False):
return RepVGG(x, num_blocks=[2, 4, 14, 1], classes=1000,
width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy)
def create_RepVGG_A2(x, deploy=False):
return RepVGG(x, num_blocks=[2, 4, 14, 1], num_classes=1000,
width_multiplier=[1.5, 1.5, 1.5, 2.75], override_groups_map=None, deploy=deploy)
def create_RepVGG_B0(x, deploy=False):
return RepVGG(x, num_blocks=[4, 6, 16, 1], num_classes=1000,
width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy)
def create_RepVGG_B1(x, deploy=False):
return RepVGG(x, num_blocks=[4, 6, 16, 1], num_classes=1000,
width_multiplier=[2, 2, 2, 4], override_groups_map=None, deploy=deploy)
def create_RepVGG_B1g2(x, deploy=False):
return RepVGG(x, num_blocks=[4, 6, 16, 1], num_classes=1000,
width_multiplier=[2, 2, 2, 4], override_groups_map=g2_map, deploy=deploy)
def create_RepVGG_B1g4(x, deploy=False):
return RepVGG(x, num_blocks=[4, 6, 16, 1], num_classes=1000,
width_multiplier=[2, 2, 2, 4], override_groups_map=g4_map, deploy=deploy)
def create_RepVGG_B2(x, deploy=False):
return RepVGG(x, num_blocks=[4, 6, 16, 1], num_classes=1000,
width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy)
def create_RepVGG_B2g2(x, deploy=False):
return RepVGG(x, num_blocks=[4, 6, 16, 1], num_classes=1000,
width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g2_map, deploy=deploy)
def create_RepVGG_B2g4(x, deploy=False):
return RepVGG(x, num_blocks=[4, 6, 16, 1], num_classes=1000,
width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g4_map, deploy=deploy)
def create_RepVGG_B3(x, deploy=False):
return RepVGG(x, num_blocks=[4, 6, 16, 1], num_classes=1000,
width_multiplier=[3, 3, 3, 5], override_groups_map=None, deploy=deploy)
def create_RepVGG_B3g2(x, deploy=False):
return RepVGG(x, num_blocks=[4, 6, 16, 1], num_classes=1000,
width_multiplier=[3, 3, 3, 5], override_groups_map=g2_map, deploy=deploy)
def create_RepVGG_B3g4(x, deploy=False):
return RepVGG(x, num_blocks=[4, 6, 16, 1], num_classes=1000,
width_multiplier=[3, 3, 3, 5], override_groups_map=g4_map, deploy=deploy)
def create_RepVGG_D2se(x, deploy=False):
return RepVGG(x, num_blocks=[8, 14, 24, 1], num_classes=1000,
width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy, use_se=True)
if __name__ == '__main__':
inputs = Input(shape=(224,224,3))
classes = 1000
model = Model(inputs=inputs, outputs=RepVGG_A0(inputs))
model.summary()
~欢迎更正
边栏推荐
- Playwright - 滚动条操作
- 《微信小程序-基础篇》带你了解小程序中的生命周期(二)
- VR全景中特效是如何编辑的?细节功能如何展示?
- dataloader 源码_DataLoader
- Unity 如何拖拉多个组件中的一个
- Is it safe to open an account for mobile phone stock trading!?
- DNS服务器搭建、转发、主从配置
- What securities dealers recommend? In addition, is it safe to open a mobile account?
- 线下门店为什么要做新零售?
- Smarter! Airiot accelerates the upgrading of energy conservation and emission reduction in the coal industry
猜你喜欢
Redis ziplist 压缩列表的源码解析

派尔特医疗在港交所招股书二次“失效”,上市计划实质性延迟

Cartoon | has Oracle been abandoned by the new era?

将 EMQX Cloud 数据通过公网桥接到 AWS IoT

TorchDrug--药物属性预测

VR全景中特效是如何编辑的?细节功能如何展示?

科大讯飞活跃竞赛汇总!(12个)

2022年高考都结束了,还有人真觉得程序员下班后不需要学习吗?

FH6908A负极关断同步整流模拟低压降二极管控制IC芯片TSOT23-6超低功耗整流器 1w功耗 <100uA静态 替代MP6908

MySQL billing Statistics (Part 1): MySQL installation and client dbeaver connection
随机推荐
传输层 使用滑动窗口实现流量控制
英语没学好到底能不能做coder,别再纠结了先学起来
Kubernetes为什么会赢,容器圈的风云变幻!
Tupu software has passed CMMI5 certification| High authority and high-level certification in the international software field
Ten percent of the time, the tar command can't parse the English bracket "()" when decompressing the file
如何使用robots.txt及其详解
c语言数组截取,C# 字符串按数组截取方法(C/S)
8 - function
《微信小程序-基础篇》带你了解小程序中的生命周期(二)
DNS服务器搭建、转发、主从配置
CADD课程学习(1)-- 药物设计基础知识
暑期实训21组第一周个人工作总结
为什么一定要从DevOps走向BizDevOps?
S7-1500 PLC之间进行TCP通信的具体方法和步骤详解(图文)
闲鱼难“翻身”
neo4j load csv 配置和使用
VR云展厅如何给线下实体带来活力?有哪些功能?
昔日果汁大王,16个亿卖了
VR全景中特效是如何编辑的?细节功能如何展示?
This morning, investors began to travel collectively