TensorFlow 2.4 implementation of RepVGG
2022-06-30 20:10:00 【Haohao+++】
Preface
RepVGG is a novel CNN design paradigm proposed by Tsinghua University, Megvii Technology, and others. It avoids the low training accuracy of plain VGG-style networks while keeping the efficient inference that VGG is known for.
Paper: https://arxiv.org/abs/2101.03697
This post summarizes two key ideas from the paper.
A multi-branch structure for training
The ResBlock in ResNet computes y = x + f(x). Although a multi-branch structure is unfriendly to inference, it is friendly to training, so the author designed RepVGG to be multi-branch at training time and single-branch at inference time. Borrowing ResNet's identity and 1×1 branches, the training-time block takes the following form:
y = x + g(x) + f(x)
where:
- g(x) is a 1×1 convolution.
- f(x) is a 3×3 convolution.
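At inference time the three branches are fused into a single 3×3 convolution, which the paper calls structural re-parameterization. Folding a BN layer with parameters γ, β, μ, σ² into the preceding bias-free convolution W gives
W' = W · γ/√(σ² + ε), b' = β − μ · γ/√(σ² + ε)
The fused 1×1 kernel is then zero-padded to 3×3, the identity branch is rewritten as a 3×3 kernel with a 1 at the center of each channel's own filter, and the three kernels (and biases) are summed into one convolution. A runnable sketch of this fusion follows the code listing below.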
Simple is fast, memory-economical, flexible
- Fast: compared with VGG, existing multi-branch architectures have lower FLOPs but do not necessarily infer faster. For example, VGG-16 has 8.4 times the FLOPs of EfficientNet-B3 yet runs 1.8 times faster on a 1080Ti, which means the former's computational density is about 15 times the latter's. The gap between FLOPs and inference speed stems from two key factors: (1) MAC (memory access cost): multi-branch operations such as Add and Concat involve little computation but have high MAC; (2) parallelism: studies show that models with high parallelism infer faster than models with low parallelism.
- Memory-economical: a multi-branch architecture is memory-inefficient because the output of every branch must be kept until the Add/Concat, which raises the peak memory footprint; a plain model uses memory more efficiently.
- Flexible: a multi-branch structure limits the flexibility of a CNN; for example, a ResBlock constrains the tensors of its two branches to have the same shape. Multi-branch structures are also unfriendly to model pruning.
Network structure
RepVGG consists of a stride-2 stem block followed by four stages of stacked RepVGG blocks; the first block of each stage downsamples with stride 2, and the head is a global average pooling layer plus a dense classifier (see the code below).
Code implementation
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Conv2D, BatchNormalization, GlobalAvgPool2D, Activation, Multiply,
    Add, Dense, Input
)

# ----------------- #
# Convolution + batch normalization
# ----------------- #
def conv_bn(filters, kernel_size, strides, padding, groups=1):
    def _conv_bn(x):
        # use_bias=False: the beta of the BN that follows absorbs any bias term
        x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides,
                   padding=padding, groups=groups, use_bias=False)(x)
        x = BatchNormalization()(x)
        return x
    return _conv_bn
# ----------------- #
# SE (squeeze-and-excitation) module
# ----------------- #
def SE_block(x_0, r=16):
    channels = x_0.shape[-1]
    # Squeeze: (batch, H, W, C) -> (batch, C)
    x = GlobalAvgPool2D()(x_0)
    # (batch, C) -> (batch, 1, 1, C)
    x = x[:, None, None, :]
    # Two 1x1 convolutions in place of fully connected layers
    x = Conv2D(filters=channels // r, kernel_size=1, strides=1)(x)
    x = Activation('relu')(x)
    x = Conv2D(filters=channels, kernel_size=1, strides=1)(x)
    x = Activation('sigmoid')(x)
    # Excite: rescale the input channel-wise
    x = Multiply()([x_0, x])
    return x
# ----------------- #
# RepVGG module
# ----------------- #
def RepVGGBlock(filters, kernel_size, strides=1, padding='valid', dilation=1, groups=1, deploy=False, use_se=False):
    def _RepVGGBlock(inputs):
        # Deploy mode: a single re-parameterized convolution with bias
        if deploy:
            x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides,
                       padding=padding, dilation_rate=dilation, groups=groups,
                       use_bias=True)(inputs)
            if use_se:
                x = SE_block(x)
            return Activation('relu')(x)
        # Training mode: 3x3 conv+BN and 1x1 conv+BN branches, plus a BN
        # identity branch when shapes allow it (same channels, stride 1)
        x1 = conv_bn(filters=filters, kernel_size=kernel_size, strides=strides,
                     padding=padding, groups=groups)(inputs)
        x2 = conv_bn(filters=filters, kernel_size=1, strides=strides,
                     padding='valid', groups=groups)(inputs)
        if inputs.shape[-1] == filters and strides == 1:
            id_out = BatchNormalization()(inputs)
            x = Add()([id_out, x1, x2])
        else:
            x = Add()([x1, x2])
        if use_se:
            x = SE_block(x)
        return Activation('relu')(x)
    return _RepVGGBlock
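# Usage sketch (shapes are illustrative, not from the original post): with
# matching channels and stride 1 the training-mode block builds three
# parallel branches that are added, while deploy=True builds one fused
# 3x3 convolution behind the same interface:
#   x = Input(shape=(56, 56, 64))
#   y_train  = RepVGGBlock(filters=64, kernel_size=3, padding='same')(x)
#   y_deploy = RepVGGBlock(filters=64, kernel_size=3, padding='same', deploy=True)(x)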
# ----------------- #
# Stacking of RepVGG modules
# ----------------- #
def make_stage(planes, num_blocks, stride_1, first_layer_id, deploy, use_se, override_groups_map=None):
    def _make_stage(x):
        override_map = override_groups_map or dict()
        # The first block of a stage downsamples; the rest keep the resolution
        strides = [stride_1] + [1] * (num_blocks - 1)
        for i, stride in enumerate(strides):
            # Layer ids are counted globally across stages so that they line
            # up with the g2/g4 groupwise maps below
            cur_groups = override_map.get(first_layer_id + i, 1)
            x = RepVGGBlock(filters=planes, kernel_size=3, strides=stride, padding='same',
                            groups=cur_groups, deploy=deploy, use_se=use_se)(x)
        return x
    return _make_stage
# ----------------- #
# RepVGG network
# ----------------- #
def RepVGG(x, num_blocks, classes=1000, width_multiplier=None, override_groups_map=None, deploy=False, use_se=False):
    override_groups_map = override_groups_map or dict()
    in_planes = min(64, int(64 * width_multiplier[0]))
    # Stem: one stride-2 block
    out = RepVGGBlock(filters=in_planes, kernel_size=3, strides=2, padding='same',
                      deploy=deploy, use_se=use_se)(x)
    # Four stages; layer ids continue across stage boundaries
    stage_ids = [1,
                 1 + num_blocks[0],
                 1 + num_blocks[0] + num_blocks[1],
                 1 + num_blocks[0] + num_blocks[1] + num_blocks[2]]
    out = make_stage(int(64 * width_multiplier[0]), num_blocks[0], stride_1=2, first_layer_id=stage_ids[0],
                     deploy=deploy, use_se=use_se, override_groups_map=override_groups_map)(out)
    out = make_stage(int(128 * width_multiplier[1]), num_blocks[1], stride_1=2, first_layer_id=stage_ids[1],
                     deploy=deploy, use_se=use_se, override_groups_map=override_groups_map)(out)
    out = make_stage(int(256 * width_multiplier[2]), num_blocks[2], stride_1=2, first_layer_id=stage_ids[2],
                     deploy=deploy, use_se=use_se, override_groups_map=override_groups_map)(out)
    out = make_stage(int(512 * width_multiplier[3]), num_blocks[3], stride_1=2, first_layer_id=stage_ids[3],
                     deploy=deploy, use_se=use_se, override_groups_map=override_groups_map)(out)
    out = GlobalAvgPool2D()(out)
    out = Dense(classes)(out)
    return out
optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
g2_map = {l: 2 for l in optional_groupwise_layers}
g4_map = {l: 4 for l in optional_groupwise_layers}
def create_RepVGG_A0(inputs, classes=1000, deploy=False):
    return RepVGG(inputs, num_blocks=[2, 4, 14, 1], classes=classes,
                  width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None, deploy=deploy)

def create_RepVGG_A1(x, deploy=False):
    return RepVGG(x, num_blocks=[2, 4, 14, 1], classes=1000,
                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy)

def create_RepVGG_A2(x, deploy=False):
    return RepVGG(x, num_blocks=[2, 4, 14, 1], classes=1000,
                  width_multiplier=[1.5, 1.5, 1.5, 2.75], override_groups_map=None, deploy=deploy)

def create_RepVGG_B0(x, deploy=False):
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=1000,
                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy)

def create_RepVGG_B1(x, deploy=False):
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=1000,
                  width_multiplier=[2, 2, 2, 4], override_groups_map=None, deploy=deploy)

def create_RepVGG_B1g2(x, deploy=False):
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=1000,
                  width_multiplier=[2, 2, 2, 4], override_groups_map=g2_map, deploy=deploy)

def create_RepVGG_B1g4(x, deploy=False):
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=1000,
                  width_multiplier=[2, 2, 2, 4], override_groups_map=g4_map, deploy=deploy)

def create_RepVGG_B2(x, deploy=False):
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=1000,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy)

def create_RepVGG_B2g2(x, deploy=False):
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=1000,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g2_map, deploy=deploy)

def create_RepVGG_B2g4(x, deploy=False):
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=1000,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g4_map, deploy=deploy)

def create_RepVGG_B3(x, deploy=False):
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=1000,
                  width_multiplier=[3, 3, 3, 5], override_groups_map=None, deploy=deploy)

def create_RepVGG_B3g2(x, deploy=False):
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=1000,
                  width_multiplier=[3, 3, 3, 5], override_groups_map=g2_map, deploy=deploy)

def create_RepVGG_B3g4(x, deploy=False):
    return RepVGG(x, num_blocks=[4, 6, 16, 1], classes=1000,
                  width_multiplier=[3, 3, 3, 5], override_groups_map=g4_map, deploy=deploy)

def create_RepVGG_D2se(x, deploy=False):
    return RepVGG(x, num_blocks=[8, 14, 24, 1], classes=1000,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy, use_se=True)
if __name__ == '__main__':
    inputs = Input(shape=(224, 224, 3))
    classes = 1000
    model = Model(inputs=inputs, outputs=create_RepVGG_A0(inputs, classes=classes))
    model.summary()
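The listing above only builds the deploy-mode graph; it does not convert trained multi-branch weights into the single 3×3 kernel and bias. Below is a minimal NumPy sketch of that conversion for a groups=1 block, following the fusion formulas from the paper (the helper names here are mine, not part of the original post):

import numpy as np

def fuse_conv_bn(kernel, gamma, beta, mean, var, eps=1e-3):
    # Fold a BN layer into the preceding bias-free conv.
    # kernel: (kh, kw, c_in, c_out); gamma/beta/mean/var: (c_out,)
    std = np.sqrt(var + eps)
    return kernel * (gamma / std), beta - mean * gamma / std

def pad_1x1_to_3x3(kernel):
    # Zero-pad a (1, 1, c_in, c_out) kernel to (3, 3, c_in, c_out), centered
    return np.pad(kernel, ((1, 1), (1, 1), (0, 0), (0, 0)))

def identity_as_3x3(channels):
    # The identity branch as a 3x3 kernel: a 1 at the center of each
    # channel's own filter, zeros elsewhere
    kernel = np.zeros((3, 3, channels, channels), dtype=np.float32)
    for i in range(channels):
        kernel[1, 1, i, i] = 1.0
    return kernel

def reparameterize(k3, bn3, k1, bn1, bn_id=None):
    # bn* are (gamma, beta, mean, var) tuples; bn_id is None when the block
    # has no identity branch (stride 2 or a channel change)
    k3f, b3f = fuse_conv_bn(k3, *bn3)
    k1f, b1f = fuse_conv_bn(k1, *bn1)
    kernel, bias = k3f + pad_1x1_to_3x3(k1f), b3f + b1f
    if bn_id is not None:
        k0f, b0f = fuse_conv_bn(identity_as_3x3(k3.shape[-1]), *bn_id)
        kernel, bias = kernel + k0f, bias + b0f
    return kernel, bias

The returned kernel and bias can then be loaded into the Conv2D of the corresponding deploy=True block with layer.set_weights([kernel, bias]); eps=1e-3 matches the Keras BatchNormalization default.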
Corrections are welcome.