当前位置:网站首页>用全连接层替代掉卷积 -- RepMLP
用全连接层替代掉卷积 -- RepMLP
2022-07-02 06:26:00 【MezereonXP】
用全连接层替代掉卷积 – RepMLP
这次给大家介绍一个工作, “RepMLP: Re-parameterizing Convolutions into Fully-connected Layers for Image Recognition”,是最近MLP热潮中的一篇有代表性的文章。
其github链接为https://github.com/DingXiaoH/RepMLP,有精力的朋友可以去跑一跑,看一看代码。
我们先回顾一下,先前的基于卷积网络的工作。之所以卷积网络能够有效,一定程度上是其对空间上的信息进行捕捉,通过多次的卷积提取到了空间上的特征,并且基本上覆盖了整张图片。假如我们将图片“拍平”然后用MLP进行训练,则失去了空间中的特征信息。
这篇文章的贡献在于:
- 利用了全连接(FC)的全局能力(global capacity) 以及 位置感知 (positional perception),将其应用到了图像识别上
- 提出了一种简单的、无关平台的 (platform-agnostic)、可差分的算法,来将卷积和BN合并成FC
- 充分的实验分析,验证了RepMLP的可行性
整体框架
整个RepMLP分为两个阶段:
- 训练阶段
- 测试阶段
针对这两个阶段,如下图所示:
看上去有些复杂,我们先单独看看训练阶段的部分。
首先是全局感知(global perceptron)
主要分为两条路径:
- 路径1: 平均池化 + BN + FC1 + ReLU + FC2
- 路径2: 分块
我们记输入张量的形状为 ( N , C , H , W ) (N,C,H,W) (N,C,H,W)
路径1
对于路径1,首先平均池化将输入转换成 ( N , C , H h , W w ) (N,C,\frac{H}{h},\frac{W}{w}) (N,C,hH,wW), 相当于缩放,然后绿色的部分表示将张量“拍平”
也就是变成 ( N , C H W h w ) (N,\frac{CHW}{hw}) (N,hwCHW) 形状的张量,经过两层FC层之后,维度仍然保持,因为整个FC就相当于左乘一个方阵。
最终对 ( N , C H W h w ) (N,\frac{CHW}{hw}) (N,hwCHW) 形状的输出进行reshape,得到一个形状是 ( N H W h w , C , 1 , 1 ) (\frac{NHW}{hw}, C, 1, 1) (hwNHW,C,1,1) 的输出
路径2
对于路径2,直接将输入 ( N , C , H , W ) (N,C,H,W) (N,C,H,W) 转换成 N H W h w \frac{NHW}{hw} hwNHW 个 ( h , w ) (h,w) (h,w) 的小块,其形状也就是 ( N H W h w , C , h , w ) (\frac{NHW}{hw},C,h,w) (hwNHW,C,h,w)
最后将路径1和路径2的结果做加法,由于维度对不上,不过在PyTorch中,会进行自动的copy操作,也就是所有的 ( h , w ) (h,w) (h,w) 大小的块的每一个像素,都会加上一个值。
这一个部分的输出形状为 ( N H W h w , C , h , w ) (\frac{NHW}{hw},C,h,w) (hwNHW,C,h,w)
然后进入局部感知和分块感知的部分,如下图所示:
对于分块感知(partition perceptron)
首先,将4维的张量拍成2维,即 ( N H W h w , C , h , w ) (\frac{NHW}{hw},C,h,w) (hwNHW,C,h,w) 变成 ( N H W h w , C h w ) (\frac{NHW}{hw},Chw) (hwNHW,Chw)
然后FC3是一个参照 分组卷积(groupwise conv) 的操作,其中 g g g 是组的数目
原本FC3应该是 ( O h w , C h w ) (Ohw,Chw) (Ohw,Chw) 的一个矩阵,但是为了降低参数量,使用了分组的FC(groupwise FC)
分组卷积本质上就是对通道进行分组,我举个例子:
假设输入是一个 ( C , H , W ) (C,H,W) (C,H,W) 的张量,如果我们希望输出是 ( N , H ′ , W ′ ) (N,H',W') (N,H′,W′)
通常我们的卷积核形状为 ( N , C , K , K ) (N,C,K,K) (N,C,K,K) ,其中 K K K 是卷积核的大小
我们对通道 C C C 进行分组,每 g g g 个通道为一组,那么就有 C g \frac{C}{g} gC 个组
对于单独每一个组,进行卷积操作,我们的卷积核形状就会缩小成 ( N , C g , K , K ) (N,\frac{C}{g},K,K) (N,gC,K,K)
在这里,分组FC也就是对通道数 C h w Chw Chw 进行分组然后每一个组过FC,最终得到 ( N H W h w , O , h , w ) (\frac{NHW}{hw}, O,h,w) (hwNHW,O,h,w) 的张量
再经过BN层,张量形状不变。
而对于局部感知(local perceptron)
类似FPN的思想,进行了不同尺度的分组卷积,得到了4个形状为 ( N H W h w , O , h , w ) (\frac{NHW}{hw}, O,h,w) (hwNHW,O,h,w) 的张量
把局部感知的结果和分块感知的结果相加,就得到了 ( N , O , H , W ) (N,O,H,W) (N,O,H,W) 的输出
到这里你可能会问,这不是还存在着卷积吗?
这只是训练阶段,在推理阶段,便会把卷积都扔掉,如下图所示:
至此,我们用MLP替代掉了一个卷积的操作
实验分析
首先是一系列消融实验(Ablation Study), 在CIFAR-10数据集上进行测试
A条件是在推断的时候保留BN层和conv层,结果没有变化
D,E条件分别是用一个9x9的卷积层替代掉FC3和整个RepMLP
Wide ConvNet是将本来的网络结构的通道数翻倍
结果说明局部感知和全局感知的重要性,同时推断的时候去除卷积部分没有影响,实现了MLP的替换
然后作者替换掉了ResNet50的一些block,进行了测试
只替换掉倒数第二个残差块,参数量多了一些,但是正确率有小幅度的增加
倘若我们完全替换掉更多的卷积部分
参数量会增加,正确率也会有小幅度的增加
边栏推荐
- ModuleNotFoundError: No module named ‘pytest‘
- Implementation of yolov5 single image detection based on pytorch
- 基于pytorch的YOLOv5单张图片检测实现
- [tricks] whiteningbert: an easy unsupervised sentence embedding approach
- Using compose to realize visible scrollbar
- ModuleNotFoundError: No module named ‘pytest‘
- 【TCDCN】《Facial landmark detection by deep multi-task learning》
- Apple added the first iPad with lightning interface to the list of retro products
- [CVPR‘22 Oral2] TAN: Temporal Alignment Networks for Long-term Video
- One book 1078: sum of fractional sequences
猜你喜欢
【BiSeNet】《BiSeNet:Bilateral Segmentation Network for Real-time Semantic Segmentation》
《Handwritten Mathematical Expression Recognition with Bidirectionally Trained Transformer》论文翻译
Faster-ILOD、maskrcnn_benchmark安装过程及遇到问题
【BiSeNet】《BiSeNet:Bilateral Segmentation Network for Real-time Semantic Segmentation》
win10+vs2017+denseflow编译
Using compose to realize visible scrollbar
【Sparse-to-Dense】《Sparse-to-Dense:Depth Prediction from Sparse Depth Samples and a Single Image》
【Random Erasing】《Random Erasing Data Augmentation》
Point cloud data understanding (step 3 of pointnet Implementation)
[CVPR‘22 Oral2] TAN: Temporal Alignment Networks for Long-term Video
随机推荐
Win10+vs2017+denseflow compilation
【Cutout】《Improved Regularization of Convolutional Neural Networks with Cutout》
TimeCLR: A self-supervised contrastive learning framework for univariate time series representation
点云数据理解(PointNet实现第3步)
Faster-ILOD、maskrcnn_benchmark训练自己的voc数据集及问题汇总
【Mixed Pooling】《Mixed Pooling for Convolutional Neural Networks》
MMDetection模型微调
超时停靠视频生成
【双目视觉】双目矫正
基于onnxruntime的YOLOv5单张图片检测实现
Apple added the first iPad with lightning interface to the list of retro products
【AutoAugment】《AutoAugment:Learning Augmentation Policies from Data》
Win10 solves the problem that Internet Explorer cannot be installed
Use Baidu network disk to upload data to the server
聊天中文语料库对比(附上各资源链接)
Traditional target detection notes 1__ Viola Jones
【MagNet】《Progressive Semantic Segmentation》
【AutoAugment】《AutoAugment:Learning Augmentation Policies from Data》
【Mixed Pooling】《Mixed Pooling for Convolutional Neural Networks》
ModuleNotFoundError: No module named ‘pytest‘