Convolutional MLP
ConvMLP: Hierarchical Convolutional MLPs for Vision
Preprint: ConvMLP: Hierarchical Convolutional MLPs for Vision (arXiv:2109.04454)
By Jiachen Li[1,2], Ali Hassani[1]*, Steven Walton[1]*, and Humphrey Shi[1,2,3]
In association with SHI Lab @ University of Oregon[1], University of Illinois Urbana-Champaign[2], and Picsart AI Research (PAIR)[3]
Abstract
MLP-based architectures, which consist of a sequence of consecutive multi-layer perceptron blocks, have recently been found to reach results comparable to convolutional and transformer-based methods. However, most adopt spatial MLPs, which take fixed-dimension inputs and are therefore difficult to apply to downstream tasks such as object detection and semantic segmentation. Moreover, single-stage designs further limit performance on other computer vision tasks, and fully connected layers bear heavy computation. To tackle these problems, we propose ConvMLP: a hierarchical convolutional MLP for visual recognition, which is a light-weight, stage-wise co-design of convolution layers and MLPs. In particular, ConvMLP-S achieves 76.8% top-1 accuracy on ImageNet-1k with 9M parameters and 2.4 GMACs (15% and 19% of MLP-Mixer-B/16, respectively). Experiments on object detection and semantic segmentation further show that the visual representations learned by ConvMLP can be seamlessly transferred and achieve competitive results with fewer parameters.
How to run
Getting Started
Our base model is in pure PyTorch and Torchvision. No extra packages are required. Please refer to PyTorch's Getting Started page for detailed instructions.
You can start off with `src.convmlp`, which contains the three variants: `convmlp_s`, `convmlp_m`, and `convmlp_l`:

```python
from src.convmlp import convmlp_l, convmlp_s

model = convmlp_l(pretrained=True, progress=True)
model_sm = convmlp_s(num_classes=10)
```
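As a quick sanity check, the models can be called like any other `torch.nn.Module`; below is a minimal inference sketch, assuming the standard 224×224 ImageNet input resolution:

```python
import torch
from src.convmlp import convmlp_s

# Load the small variant with ImageNet-1k pretrained weights.
model = convmlp_s(pretrained=True, progress=True)
model.eval()

# One random RGB image; 224x224 is the assumed input resolution.
x = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    logits = model(x)  # shape (1, 1000): ImageNet-1k class scores

print(logits.argmax(dim=-1))
```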
Image Classification
timm is recommended for image classification training and required for the training script provided in this repository:
```bash
./dist_classification.sh $NUM_GPUS -c $CONFIG_FILE /path/to/dataset
```
You can use our training configurations provided in `configs/classification`:

```bash
./dist_classification.sh 8 -c configs/classification/convmlp_s_imagenet.yml /path/to/ImageNet
./dist_classification.sh 8 -c configs/classification/convmlp_m_imagenet.yml /path/to/ImageNet
./dist_classification.sh 8 -c configs/classification/convmlp_l_imagenet.yml /path/to/ImageNet
```
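If you just want to sanity-check a pretrained checkpoint without timm, the sketch below runs a standard top-1 evaluation in plain PyTorch/Torchvision. The resize/crop and normalization constants are the usual ImageNet defaults and are our assumption here; check the classification configs for the exact evaluation pipeline:

```python
import torch
from torchvision import datasets, transforms
from src.convmlp import convmlp_s

device = "cuda" if torch.cuda.is_available() else "cpu"

# Standard ImageNet evaluation preprocessing (assumed; see the configs
# for the pipeline actually used in the paper).
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

val_set = datasets.ImageFolder("/path/to/ImageNet/val", transform=transform)
loader = torch.utils.data.DataLoader(val_set, batch_size=64, num_workers=4)

model = convmlp_s(pretrained=True).to(device).eval()

correct = total = 0
with torch.no_grad():
    for images, targets in loader:
        preds = model(images.to(device)).argmax(dim=-1).cpu()
        correct += (preds == targets).sum().item()
        total += targets.numel()

print(f"top-1 accuracy: {100 * correct / total:.1f}%")
```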
Object Detection
mmdetection is recommended for object detection training and required for the training script provided in this repository:
```bash
./dist_detection.sh $CONFIG_FILE $NUM_GPUS /path/to/dataset
```
You can use our training configurations provided in `configs/detection`:

```bash
./dist_detection.sh configs/detection/retinanet_convmlp_s_fpn_1x_coco.py 8 /path/to/COCO
./dist_detection.sh configs/detection/retinanet_convmlp_m_fpn_1x_coco.py 8 /path/to/COCO
./dist_detection.sh configs/detection/retinanet_convmlp_l_fpn_1x_coco.py 8 /path/to/COCO
```
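Once trained (or with a downloaded checkpoint), you can run inference through mmdetection's high-level API. The snippet below is a minimal sketch assuming the mmdetection 2.x API; the checkpoint path and `demo.jpg` are placeholders:

```python
from mmdet.apis import init_detector, inference_detector

# Config from this repo; the checkpoint path is a placeholder for one of
# the downloadable checkpoints listed under Results.
config = "configs/detection/retinanet_convmlp_s_fpn_1x_coco.py"
checkpoint = "/path/to/retinanet_convmlp_s_checkpoint.pth"

model = init_detector(config, checkpoint, device="cuda:0")

# For RetinaNet, the result is a list of per-class arrays of
# [x1, y1, x2, y2, score] boxes.
result = inference_detector(model, "demo.jpg")
```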
Object Detection & Instance Segmentation
mmdetection is recommended for training Mask R-CNN and required for the training script provided in this repository (same as above).
You can use our training configurations provided in `configs/detection`:

```bash
./dist_detection.sh configs/detection/maskrcnn_convmlp_s_fpn_1x_coco.py 8 /path/to/COCO
./dist_detection.sh configs/detection/maskrcnn_convmlp_m_fpn_1x_coco.py 8 /path/to/COCO
./dist_detection.sh configs/detection/maskrcnn_convmlp_l_fpn_1x_coco.py 8 /path/to/COCO
```
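Inference works the same way for Mask R-CNN; the only difference is that mmdetection 2.x returns a `(bbox_results, segm_results)` pair when the model predicts masks. Paths below are again placeholders:

```python
from mmdet.apis import init_detector, inference_detector

config = "configs/detection/maskrcnn_convmlp_s_fpn_1x_coco.py"
checkpoint = "/path/to/maskrcnn_convmlp_s_checkpoint.pth"  # placeholder

model = init_detector(config, checkpoint, device="cuda:0")

# Mask R-CNN returns per-class box arrays plus per-class binary masks.
bbox_results, segm_results = inference_detector(model, "demo.jpg")
```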
Semantic Segmentation
mmsegmentation is recommended for semantic segmentation training and required for the training script provided in this repository:
```bash
./dist_segmentation.sh $CONFIG_FILE $NUM_GPUS /path/to/dataset
```
You can use our training configurations provided in `configs/segmentation`:

```bash
./dist_segmentation.sh configs/segmentation/fpn_convmlp_s_512x512_40k_ade20k.py 8 /path/to/ADE20k
./dist_segmentation.sh configs/segmentation/fpn_convmlp_m_512x512_40k_ade20k.py 8 /path/to/ADE20k
./dist_segmentation.sh configs/segmentation/fpn_convmlp_l_512x512_40k_ade20k.py 8 /path/to/ADE20k
```
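For a quick qualitative check, mmsegmentation's high-level API can run a trained model on a single image. This sketch assumes the mmsegmentation 0.x API (`init_segmentor`/`inference_segmentor`) and placeholder paths:

```python
from mmseg.apis import init_segmentor, inference_segmentor

config = "configs/segmentation/fpn_convmlp_s_512x512_40k_ade20k.py"
checkpoint = "/path/to/fpn_convmlp_s_checkpoint.pth"  # placeholder

model = init_segmentor(config, checkpoint, device="cuda:0")

# Returns a list with one H x W array of ADE20k class indices per image.
result = inference_segmentor(model, "demo.jpg")
```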
Results
Image Classification
Feature maps from ResNet50, MLP-Mixer-B/16, our Pure-MLP Baseline, and ConvMLP-M are presented in the image below. It can be observed that the representations learned by ConvMLP involve more low-level features, such as edges and textures, than the rest.
| Dataset  | Model     | Top-1 Accuracy | # Params | MACs |
|----------|-----------|----------------|----------|------|
| ImageNet | ConvMLP-S | 76.8%          | 9.0M     | 2.4G |
| ImageNet | ConvMLP-M | 79.0%          | 17.4M    | 3.9G |
| ImageNet | ConvMLP-L | 80.2%          | 42.7M    | 9.9G |
If importing the classification models, you can pass `pretrained=True` to download and load these checkpoints. The same holds for the training script (`classification.py` and `dist_classification.sh`): pass `--pretrained`. The segmentation/detection training scripts also download the pretrained backbone if you pass the correct config files.
Downstream tasks
The image below summarizes the results of applying our model to object detection, instance segmentation, and semantic segmentation, compared to ResNet.
Object Detection
| Dataset | Model      | Backbone  | # Params | AP^b | AP^b_50 | AP^b_75 | Checkpoint |
|---------|------------|-----------|----------|------|---------|---------|------------|
| MS COCO | Mask R-CNN | ConvMLP-S | 28.7M    | 38.4 | 59.8    | 41.8    | Download   |
| MS COCO | Mask R-CNN | ConvMLP-M | 37.1M    | 40.6 | 61.7    | 44.5    | Download   |
| MS COCO | Mask R-CNN | ConvMLP-L | 62.2M    | 41.7 | 62.8    | 45.5    | Download   |
| MS COCO | RetinaNet  | ConvMLP-S | 18.7M    | 37.2 | 56.4    | 39.8    | Download   |
| MS COCO | RetinaNet  | ConvMLP-M | 27.1M    | 39.4 | 58.7    | 42.0    | Download   |
| MS COCO | RetinaNet  | ConvMLP-L | 52.9M    | 40.2 | 59.3    | 43.3    | Download   |
Instance Segmentation
| Dataset | Model      | Backbone  | # Params | AP^m | AP^m_50 | AP^m_75 | Checkpoint |
|---------|------------|-----------|----------|------|---------|---------|------------|
| MS COCO | Mask R-CNN | ConvMLP-S | 28.7M    | 35.7 | 56.7    | 38.2    | Download   |
| MS COCO | Mask R-CNN | ConvMLP-M | 37.1M    | 37.2 | 58.8    | 39.8    | Download   |
| MS COCO | Mask R-CNN | ConvMLP-L | 62.2M    | 38.2 | 59.9    | 41.1    | Download   |
Semantic Segmentation
| Dataset | Model        | Backbone  | # Params | mIoU | Checkpoint |
|---------|--------------|-----------|----------|------|------------|
| ADE20k  | Semantic FPN | ConvMLP-S | 12.8M    | 35.8 | Download   |
| ADE20k  | Semantic FPN | ConvMLP-M | 21.1M    | 38.6 | Download   |
| ADE20k  | Semantic FPN | ConvMLP-L | 46.3M    | 40.0 | Download   |
Transfer
| Dataset     | Model     | Top-1 Accuracy | # Params |
|-------------|-----------|----------------|----------|
| CIFAR-10    | ConvMLP-S | 98.0%          | 8.51M    |
| CIFAR-10    | ConvMLP-M | 98.6%          | 16.90M   |
| CIFAR-10    | ConvMLP-L | 98.6%          | 41.97M   |
| CIFAR-100   | ConvMLP-S | 87.4%          | 8.56M    |
| CIFAR-100   | ConvMLP-M | 89.1%          | 16.95M   |
| CIFAR-100   | ConvMLP-L | 88.6%          | 42.04M   |
| Flowers-102 | ConvMLP-S | 99.5%          | 8.56M    |
| Flowers-102 | ConvMLP-M | 99.5%          | 16.95M   |
| Flowers-102 | ConvMLP-L | 99.5%          | 42.04M   |
Citation
```bibtex
@article{li2021convmlp,
    title={ConvMLP: Hierarchical Convolutional MLPs for Vision},
    author={Jiachen Li and Ali Hassani and Steven Walton and Humphrey Shi},
    year={2021},
    eprint={2109.04454},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```