The official implementation of ELSA: Enhanced Local Self-Attention for Vision Transformer

ELSA: Enhanced Local Self-Attention for Vision Transformer

By Jingkai Zhou, Pichao Wang*, Fan Wang, Qiong Liu, Hao Li, Rong Jin

This repo is the official implementation of "ELSA: Enhanced Local Self-Attention for Vision Transformer".

Introduction

Self-attention is powerful in modeling long-range dependencies, but it is weak in learning local, finer-level features. As shown in Figure 1, the performance of local self-attention (LSA) is only on par with convolution and inferior to dynamic filters. This leaves researchers unsure whether to use LSA or its counterparts, which one is better, and what makes LSA mediocre. In this work, we comprehensively investigate LSA and its counterparts. We find that the devil lies in the generation and application of spatial attention.
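
For readers new to the term, here is a toy pure-Python sketch of plain local self-attention (the LSA baseline discussed above), in 1-D for brevity. It is not the ELSA operator and is not taken from this repo; the function name and the `radius` parameter are illustrative assumptions.

```python
import math


def local_self_attention_1d(q, k, v, radius=1):
    """Toy 1-D local self-attention.

    Each position attends only to neighbours within `radius` steps.
    q, k, v are lists of equal length holding per-position vectors.
    """
    n, dim = len(q), len(q[0])
    vdim = len(v[0])
    scale = 1.0 / math.sqrt(dim)
    out = []
    for i in range(n):
        window = range(max(0, i - radius), min(n, i + radius + 1))
        # Scaled dot-product scores against the local window only.
        scores = [scale * sum(a * b for a, b in zip(q[i], k[j])) for j in window]
        # Numerically stable softmax over the window.
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of the values inside the window.
        out.append([sum(w * v[j][d] for w, j in zip(weights, window))
                    for d in range(vdim)])
    return out
```

With all-zero queries the softmax is uniform, so each output is simply the mean of the values in its window.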

Based on these findings, we propose the enhanced local self-attention (ELSA) with Hadamard attention and the ghost head, as illustrated in Figure 2. Experiments demonstrate the effectiveness of ELSA. Without any architecture or hyperparameter modification, using ELSA as a drop-in replacement consistently boosts baseline methods in both upstream and downstream tasks.

Please refer to our paper for more details.

Model zoo

ImageNet Classification

Model        #Params  Pretrain     Resolution  Top-1 Acc (%)  Download
ELSA-Swin-T  28M      ImageNet-1K  224         82.7           google / baidu
ELSA-Swin-S  53M      ImageNet-1K  224         83.5           google / baidu
ELSA-Swin-B  93M      ImageNet-1K  224         84.0           google / baidu

COCO Object Detection

Backbone     Method              Pretrain     Lr Schd  Box mAP  Mask mAP  #Params  Download
ELSA-Swin-T  Mask R-CNN          ImageNet-1K  1x       45.7     41.1      49M      google / baidu
ELSA-Swin-T  Mask R-CNN          ImageNet-1K  3x       47.5     42.7      49M      google / baidu
ELSA-Swin-S  Mask R-CNN          ImageNet-1K  1x       48.3     43.0      72M      google / baidu
ELSA-Swin-S  Mask R-CNN          ImageNet-1K  3x       49.2     43.6      72M      google / baidu
ELSA-Swin-T  Cascade Mask R-CNN  ImageNet-1K  1x       49.8     43.0      86M      google / baidu
ELSA-Swin-T  Cascade Mask R-CNN  ImageNet-1K  3x       51.0     44.2      86M      google / baidu
ELSA-Swin-S  Cascade Mask R-CNN  ImageNet-1K  1x       51.6     44.4      110M     google / baidu
ELSA-Swin-S  Cascade Mask R-CNN  ImageNet-1K  3x       52.3     45.2      110M     google / baidu

ADE20K Semantic Segmentation

Backbone     Method   Pretrain     Crop Size  Lr Schd  mIoU (ms+flip)  #Params  Download
ELSA-Swin-T  UPerNet  ImageNet-1K  512x512    160K     47.9            61M      google / baidu
ELSA-Swin-S  UPerNet  ImageNet-1K  512x512    160K     50.4            85M      google / baidu

Install

  • Clone this repo:
git clone https://github.com/damo-cv/ELSA.git elsa
cd elsa
  • Create a conda virtual environment and activate it:
conda create -n elsa python=3.7 -y
conda activate elsa
  • Install PyTorch==1.8.0 and torchvision==0.9.0 with CUDA==10.1:
conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=10.1 -c pytorch
  • Install Apex:
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
cd ../
  • Install mmcv-full==1.3.0:
pip install mmcv-full==1.3.0 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.8.0/index.html
  • Install other requirements:
pip install -r requirements.txt
  • Install mmdet and mmseg:
cd ./det
pip install -v -e .
cd ../seg
pip install -v -e .
cd ../
  • Build the elsa operation:
cd ./cls/models/elsa
python setup.py install
mv build/lib*/* .
cp *.so ../../../det/mmdet/models/backbones/elsa/
cp *.so ../../../seg/mmseg/models/backbones/elsa/
cd ../../../
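
After installing, it can help to confirm that the pinned versions from the steps above are the ones actually importable. The following is a minimal sketch, not part of this repo: `check_versions` and its `import_module` parameter are hypothetical names, and it assumes each package exposes a `__version__` attribute (import names `torch`, `torchvision`, `mmcv`).

```python
import importlib

# Versions pinned by the install steps above (import name -> version).
PINNED = {"torch": "1.8.0", "torchvision": "0.9.0", "mmcv": "1.3.0"}


def check_versions(pinned=PINNED, import_module=importlib.import_module):
    """Return {package: (expected, found)} for each mismatch or missing package."""
    problems = {}
    for name, expected in pinned.items():
        try:
            found = str(import_module(name).__version__)
        except ImportError:
            found = None
        # Builds may carry a local suffix such as "1.8.0+cu101";
        # compare only the part before the "+".
        if found is None or found.split("+")[0] != expected:
            problems[name] = (expected, found)
    return problems
```

Calling `check_versions()` inside the `elsa` environment returns an empty dict when everything matches; any entry it does return names the package to reinstall.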

Data preparation

We use the standard ImageNet dataset, which you can download from http://image-net.org/. Please prepare it with the following file structure:

$ tree data
imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...
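
To sanity-check the layout before launching a long run, a short script along these lines can count class folders and images per split. This is an illustrative sketch, not a script shipped with this repo; `summarize_imagenet` is a hypothetical helper name.

```python
from pathlib import Path


def summarize_imagenet(root):
    """Count class folders and image files under <root>/train and <root>/val."""
    summary = {}
    for split in ("train", "val"):
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        # Every sub-directory of a split is treated as one class.
        classes = sorted(p for p in split_dir.iterdir() if p.is_dir())
        images = sum(1 for c in classes for f in c.iterdir() if f.is_file())
        summary[split] = {"classes": len(classes), "images": images}
    return summary
```

For the full ImageNet-1K dataset, both splits should report 1000 classes.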

Also, please prepare the COCO and ADE20K datasets following the instructions on their official sites, then link them to det/data and seg/data.

Evaluation

ImageNet Classification

Run the following scripts to evaluate pre-trained models on ImageNet-1K:

cd cls

python validate.py <PATH_TO_IMAGENET> --model elsa_swin_tiny --checkpoint <CHECKPOINT_FILE> \
  --no-test-pool --apex-amp --img-size 224 -b 128

python validate.py <PATH_TO_IMAGENET> --model elsa_swin_small --checkpoint <CHECKPOINT_FILE> \
  --no-test-pool --apex-amp --img-size 224 -b 128

python validate.py <PATH_TO_IMAGENET> --model elsa_swin_base --checkpoint <CHECKPOINT_FILE> \
  --no-test-pool --apex-amp --img-size 224 -b 128 --use-ema
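
For reference, the Top-1 accuracy reported in the model zoo is the percentage of images whose highest-scoring class equals the ground-truth label. A minimal pure-Python sketch of that metric follows; it is not the code in validate.py, and `top1_accuracy` is a hypothetical name.

```python
def top1_accuracy(logits, labels):
    """Percentage of samples whose arg-max class equals the label."""
    correct = 0
    for scores, label in zip(logits, labels):
        # Index of the largest score is the predicted class.
        predicted = max(range(len(scores)), key=scores.__getitem__)
        correct += int(predicted == label)
    return 100.0 * correct / len(labels)
```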

COCO Detection

Run the following scripts to evaluate a detector on COCO:

cd det

# single-gpu testing
python tools/test.py <CONFIG_FILE> <DET_CHECKPOINT_FILE> --eval bbox segm

# multi-gpu testing
tools/dist_test.sh <CONFIG_FILE> <DET_CHECKPOINT_FILE> <GPU_NUM> --eval bbox segm

ADE20K Semantic Segmentation

Run the following scripts to evaluate a model on ADE20K:

cd seg

# single-gpu testing
python tools/test.py <CONFIG_FILE> <SEG_CHECKPOINT_FILE> --aug-test --eval mIoU

# multi-gpu testing
tools/dist_test.sh <CONFIG_FILE> <SEG_CHECKPOINT_FILE> <GPU_NUM> --aug-test --eval mIoU
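
As a reminder of what `--eval mIoU` measures, here is a simplified sketch of mean intersection-over-union on flat label sequences. It is an assumption-laden illustration rather than the mmseg implementation: it has no ignore label, no multi-scale or flip augmentation, and `mean_iou` is a hypothetical name.

```python
def mean_iou(pred, target, num_classes):
    """Mean IoU over classes; classes absent from both pred and target are skipped."""
    ious = []
    for c in range(num_classes):
        # Pixels where both agree on class c, and pixels where either says c.
        inter = sum(1 for p, t in zip(pred, target) if p == c and t == c)
        union = sum(1 for p, t in zip(pred, target) if p == c or t == c)
        if union:
            ious.append(inter / union)
    return sum(ious) / len(ious) if ious else 0.0
```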

Training from scratch

Due to randomness, re-training results may differ from the numbers in the paper by about 0.1~0.2%.

ImageNet Classification

Run the following scripts to train classifiers on ImageNet-1K:

cd cls

bash ./distributed_train.sh 8 <PATH_TO_IMAGENET> --model elsa_swin_tiny \
  --epochs 300 -b 128 -j 8 --opt adamw --lr 1e-3 --sched cosine --weight-decay 5e-2 \
  --warmup-epochs 20 --warmup-lr 1e-6 --min-lr 1e-5 --drop-path 0.1 --aa rand-m9-mstd0.5-inc1 \
  --mixup 0.8 --cutmix 1. --remode pixel --reprob 0.25 --clip-grad 5. --amp

bash ./distributed_train.sh 8 <PATH_TO_IMAGENET> --model elsa_swin_small \
  --epochs 300 -b 128 -j 8 --opt adamw --lr 1e-3 --sched cosine --weight-decay 5e-2 \
  --warmup-epochs 20 --warmup-lr 1e-6 --min-lr 1e-5 --drop-path 0.3 --aa rand-m9-mstd0.5-inc1 \
  --mixup 0.8 --cutmix 1. --remode pixel --reprob 0.25 --clip-grad 5. --amp

bash ./distributed_train.sh 8 <PATH_TO_IMAGENET> --model elsa_swin_base \
  --epochs 300 -b 128 -j 8 --opt adamw --lr 1e-3 --sched cosine --weight-decay 5e-2 \
  --warmup-epochs 20 --warmup-lr 1e-6 --min-lr 1e-5 --drop-path 0.5 --aa rand-m9-mstd0.5-inc1 \
  --mixup 0.8 --cutmix 1. --remode pixel --reprob 0.25 --clip-grad 5. --amp --model-ema
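
The schedule flags above (`--sched cosine --lr 1e-3 --warmup-epochs 20 --warmup-lr 1e-6 --min-lr 1e-5`) describe a warmup followed by cosine decay. The sketch below shows one common form of such a schedule with those values as defaults. It is a simplified assumption: the scheduler actually used by the training script may differ in details such as per-step updates or how warmup epochs are counted, and `lr_at_epoch` is a hypothetical name.

```python
import math


def lr_at_epoch(epoch, epochs=300, base_lr=1e-3, warmup_epochs=20,
                warmup_lr=1e-6, min_lr=1e-5):
    """Linear warmup from warmup_lr to base_lr, then cosine decay to min_lr."""
    if epoch < warmup_epochs:
        # Linear ramp during warmup.
        return warmup_lr + (base_lr - warmup_lr) * epoch / warmup_epochs
    # Fraction of the post-warmup training that has elapsed, in [0, 1].
    progress = (epoch - warmup_epochs) / max(1, epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

Under these assumptions the learning rate starts at 1e-6, peaks at 1e-3 when warmup ends, and reaches 1e-5 at the final epoch.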

If GPU memory is insufficient when training elsa_swin_base, you can use two nodes (2 * 8 GPUs) with a batch size of 64 images per GPU.
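
The two-node setting keeps the global batch size unchanged, which is why the other hyperparameters can stay as they are. A quick check of the arithmetic, assuming `-b` is the per-GPU batch size (as the per-GPU figure for the two-node setting suggests):

```python
def global_batch_size(nodes, gpus_per_node, per_gpu_batch):
    """Total number of images processed per optimizer step."""
    return nodes * gpus_per_node * per_gpu_batch


single_node = global_batch_size(nodes=1, gpus_per_node=8, per_gpu_batch=128)
two_nodes = global_batch_size(nodes=2, gpus_per_node=8, per_gpu_batch=64)
print(single_node, two_nodes)  # both are 1024
```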

COCO Detection / ADE20K Semantic Segmentation

Run the following scripts to train models on COCO / ADE20K:

cd det 
# (or cd seg)

# multi-gpu training
tools/dist_train.sh <CONFIG_FILE> <GPU_NUM> --cfg-options model.pretrained=<PRETRAIN_MODEL> [model.backbone.use_checkpoint=True] [other optional arguments] 
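
The `--cfg-options` flag takes dotted `key=value` pairs that override entries in the nested config. The sketch below illustrates the idea on a plain dict. It is not the parser used by the training tools; `apply_overrides` is a hypothetical helper, and it only converts the literals `True` and `False`, leaving every other value as a string.

```python
def apply_overrides(config, overrides):
    """Apply 'a.b.c=value' strings to a nested dict in place and return it."""
    literals = {"True": True, "False": False}
    for item in overrides:
        key, _, raw = item.partition("=")
        value = literals.get(raw, raw)
        *parents, leaf = key.split(".")
        node = config
        for part in parents:
            # Create intermediate dicts when a parent key is missing.
            node = node.setdefault(part, {})
        node[leaf] = value
    return config
```

For example, applying `["model.pretrained=ckpt.pth", "model.backbone.use_checkpoint=True"]` sets `config["model"]["pretrained"]` and `config["model"]["backbone"]["use_checkpoint"]` while leaving the other keys untouched.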

Acknowledgement

This work was supported by Alibaba Group through the Alibaba Research Intern Program and by the National Natural Science Foundation of China (No. 61976094).

The codebase builds on pytorch-image-models, ddfnet, VOLO, Swin-Transformer, Swin-Transformer-Detection, and Swin-Transformer-Semantic-Segmentation.

Citing ELSA

@article{zhou2021ELSA,
  title={ELSA: Enhanced Local Self-Attention for Vision Transformer},
  author={Zhou, Jingkai and Wang, Pichao and Wang, Fan and Liu, Qiong and Li, Hao and Jin, Rong},
  journal={arXiv preprint arXiv:2112.12786},
  year={2021}
}