当前位置:网站首页>用全连接层替代掉卷积 -- RepMLP
用全连接层替代掉卷积 -- RepMLP
2022-07-02 06:26:00 【MezereonXP】
用全连接层替代掉卷积 – RepMLP
这次给大家介绍一个工作, “RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition”,是最近MLP热潮中的一篇有代表性的文章。
其github链接为https://github.com/DingXiaoH/RepMLP,有精力的朋友可以去跑一跑,看一看代码。
我们先回顾一下,先前的基于卷积网络的工作。之所以卷积网络能够有效,一定程度上是其对空间上的信息进行捕捉,通过多次的卷积提取到了空间上的特征,并且基本上覆盖了整张图片。假如我们将图片“拍平”然后用MLP进行训练,则失去了空间中的特征信息。
这篇文章的贡献在于:
- 利用了全连接(FC)的全局能力(global capacity) 以及 位置感知 (positional perception),将其应用到了图像识别上
- 提出了一种简单的、无关平台的 (platform-agnostic)、可差分的算法,来将卷积和BN合并成FC
- 充分的实验分析,验证了RepMLP的可行性
整体框架
整个RepMLP分为两个阶段:
- 训练阶段
- 测试阶段
针对这两个阶段,如下图所示:

看上去有些复杂,我们先单独看看训练阶段的部分。
首先是全局感知(global perceptron)

主要分为两条路径:
- 路径1: 平均池化 + BN + FC1 + ReLU + FC2
- 路径2: 分块
我们记输入张量的形状为 ( N , C , H , W ) (N,C,H,W) (N,C,H,W)
路径1
对于路径1,首先平均池化将输入转换成 ( N , C , H h , W w ) (N,C,\frac{H}{h},\frac{W}{w}) (N,C,hH,wW), 相当于缩放,然后绿色的部分表示将张量“拍平”
也就是变成 ( N , C H W h w ) (N,\frac{CHW}{hw}) (N,hwCHW) 形状的张量,经过两层FC层之后,维度仍然保持,因为整个FC就相当于左乘一个方阵。
最终对 ( N , C H W h w ) (N,\frac{CHW}{hw}) (N,hwCHW) 形状的输出进行reshape,得到一个形状是 ( N H W h w , C , 1 , 1 ) (\frac{NHW}{hw}, C, 1, 1) (hwNHW,C,1,1) 的输出
路径2
对于路径2,直接将输入 ( N , C , H , W ) (N,C,H,W) (N,C,H,W) 转换成 N H W h w \frac{NHW}{hw} hwNHW 个 ( h , w ) (h,w) (h,w) 的小块,其形状也就是 ( N H W h w , C , h , w ) (\frac{NHW}{hw},C,h,w) (hwNHW,C,h,w)
最后将路径1和路径2的结果做加法,由于维度对不上,不过在PyTorch中,会进行自动的copy操作,也就是所有的 ( h , w ) (h,w) (h,w) 大小的块的每一个像素,都会加上一个值。
这一个部分的输出形状为 ( N H W h w , C , h , w ) (\frac{NHW}{hw},C,h,w) (hwNHW,C,h,w)
然后进入局部感知和分块感知的部分,如下图所示:

对于分块感知(partition perceptron)
首先,将4维的张量拍成2维,即 ( N H W h w , C , h , w ) (\frac{NHW}{hw},C,h,w) (hwNHW,C,h,w) 变成 ( N H W h w , C h w ) (\frac{NHW}{hw},Chw) (hwNHW,Chw)
然后FC3是一个参照 分组卷积(groupwise conv) 的操作,其中 g g g 是组的数目
原本FC3应该是 ( O h w , C h w ) (Ohw,Chw) (Ohw,Chw) 的一个矩阵,但是为了降低参数量,使用了分组的FC(groupwise FC)
分组卷积本质上就是对通道进行分组,我举个例子:
假设输入是一个 ( C , H , W ) (C,H,W) (C,H,W) 的张量,如果我们希望输出是 ( N , H ′ , W ′ ) (N,H',W') (N,H′,W′)
通常我们的卷积核形状为 ( N , C , K , K ) (N,C,K,K) (N,C,K,K) ,其中 K K K 是卷积核的大小
我们对通道 C C C 进行分组,每 g g g 个通道为一组,那么就有 C g \frac{C}{g} gC 个组
对于单独每一个组,进行卷积操作,我们的卷积核形状就会缩小成 ( N , C g , K , K ) (N,\frac{C}{g},K,K) (N,gC,K,K)
在这里,分组FC也就是对通道数 C h w Chw Chw 进行分组然后每一个组过FC,最终得到 ( N H W h w , O , h , w ) (\frac{NHW}{hw}, O,h,w) (hwNHW,O,h,w) 的张量
再经过BN层,张量形状不变。
而对于局部感知(local perceptron)

类似FPN的思想,进行了不同尺度的分组卷积,得到了4个形状为 ( N H W h w , O , h , w ) (\frac{NHW}{hw}, O,h,w) (hwNHW,O,h,w) 的张量
把局部感知的结果和分块感知的结果相加,就得到了 ( N , O , H , W ) (N,O,H,W) (N,O,H,W) 的输出
到这里你可能会问,这不是还存在着卷积吗?
这只是训练阶段,在推理阶段,便会把卷积都扔掉,如下图所示:

至此,我们用MLP替代掉了一个卷积的操作
实验分析
首先是一系列消融实验(Ablation Study), 在CIFAR-10数据集上进行测试

A条件是在推断的时候保留BN层和conv层,结果没有变化
D,E条件分别是用一个9x9的卷积层替代掉FC3和整个RepMLP
Wide ConvNet是将本来的网络结构的通道数翻倍
结果说明局部感知和全局感知的重要性,同时推断的时候去除卷积部分没有影响,实现了MLP的替换
然后作者替换掉了ResNet50的一些block,进行了测试

只替换掉倒数第二个残差块,参数量多了一些,但是正确率有小幅度的增加
倘若我们完全替换掉更多的卷积部分

参数量会增加,正确率也会有小幅度的增加
边栏推荐
- How to clean up logs on notebook computers to improve the response speed of web pages
- MoCO ——Momentum Contrast for Unsupervised Visual Representation Learning
- 基于pytorch的YOLOv5单张图片检测实现
- 解决latex图片浮动的问题
- PPT的技巧
- 【Batch】learning notes
- conda常用命令
- Thesis writing tip2
- 基于onnxruntime的YOLOv5单张图片检测实现
- The difference and understanding between generative model and discriminant model
猜你喜欢

常见的机器学习相关评价指标

TimeCLR: A self-supervised contrastive learning framework for univariate time series representation

Regular expressions in MySQL

Execution of procedures
![[model distillation] tinybert: distilling Bert for natural language understanding](/img/c1/e1c1a3cf039c4df1b59ef4b4afbcb2.png)
[model distillation] tinybert: distilling Bert for natural language understanding

【Programming】

MMDetection安装问题

点云数据理解(PointNet实现第3步)
![[mixup] mixup: Beyond Imperial Risk Minimization](/img/14/8d6a76b79a2317fa619e6b7bf87f88.png)
[mixup] mixup: Beyond Imperial Risk Minimization

How do vision transformer work?【论文解读】
随机推荐
【Sparse-to-Dense】《Sparse-to-Dense:Depth Prediction from Sparse Depth Samples and a Single Image》
ABM论文翻译
程序的内存模型
【Wing Loss】《Wing Loss for Robust Facial Landmark Localisation with Convolutional Neural Networks》
【BiSeNet】《BiSeNet:Bilateral Segmentation Network for Real-time Semantic Segmentation》
ModuleNotFoundError: No module named ‘pytest‘
Conversion of numerical amount into capital figures in PHP
Two dimensional array de duplication in PHP
PHP returns the corresponding key value according to the value in the two-dimensional array
【双目视觉】双目立体匹配
CONDA common commands
基于pytorch的YOLOv5单张图片检测实现
《Handwritten Mathematical Expression Recognition with Bidirectionally Trained Transformer》论文翻译
Common machine learning related evaluation indicators
Regular expressions in MySQL
Point cloud data understanding (step 3 of pointnet Implementation)
Calculate the total in the tree structure data in PHP
点云数据理解(PointNet实现第3步)
open3d学习笔记三【采样与体素化】
[mixup] mixup: Beyond Imperial Risk Minimization