[NeurIPS 2021]: Are Transformers More Robust Than CNNs? (PyTorch implementation & checkpoints)

Overview

PyTorch implementation for the NeurIPS 2021 paper: Are Transformers More Robust Than CNNs?

Our implementation is based on DeiT.

Introduction

Transformers have emerged as a powerful tool for visual recognition. In addition to demonstrating competitive performance on a broad range of visual benchmarks, recent works also argue that Transformers are much more robust than Convolutional Neural Networks (CNNs). Surprisingly, however, we find that these conclusions are drawn from unfair experimental settings, in which Transformers and CNNs are compared at different scales and trained with distinct frameworks. In this paper, we aim to provide the first fair and in-depth comparison between Transformers and CNNs, focusing on robustness evaluations.

With our unified training setup, we first challenge the previous belief that Transformers outshine CNNs in adversarial robustness. More surprisingly, we find that CNNs can easily be as robust as Transformers against adversarial attacks if they properly adopt Transformers' training recipes. Regarding generalization to out-of-distribution samples, we show that pre-training on (external) large-scale datasets is not a fundamental requirement for Transformers to outperform CNNs. Moreover, our ablations suggest that this stronger generalization largely stems from the Transformer's self-attention-like architecture per se, rather than from other training setups. We hope this work helps the community better understand and benchmark the robustness of Transformers and CNNs.

Pretrained models

We provide both pretrained vanilla models and adversarially trained models.

Vanilla Training

Main Results

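The ImageNet, ImageNet-A and Stylized-ImageNet columns report top-1 accuracy (%, higher is better); the ImageNet-C column reports error (%, lower is better).

Model | Checkpoint | ImageNet | ImageNet-A | ImageNet-C | Stylized-ImageNet
Res50-Ori | download link | 76.9 | 3.2 | 57.9 | 8.3
Res50-Align | download link | 76.3 | 4.5 | 55.6 | 8.2
Res50-Best | download link | 75.7 | 6.3 | 52.3 | 10.8
DeiT-Small | download link | 76.8 | 12.2 | 48.0 | 13.0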

Model Size

ResNets:

  • ResNets fully aligned with DeiT's training recipe, denoted as Res*:
Model | Size | Checkpoint | ImageNet | ImageNet-A | ImageNet-C | Stylized-ImageNet
Res18* | 11.69M | download link | 67.83 | 1.92 | 64.14 | 7.92
Res50* | 25.56M | download link | 76.28 | 4.53 | 55.62 | 8.17
Res101* | 44.55M | download link | 77.97 | 8.84 | 49.19 | 11.60
  • Best ResNet models for out-of-distribution (OOD) generalization, denoted as Res-best:
Model | Size | Checkpoint | ImageNet | ImageNet-A | ImageNet-C | Stylized-ImageNet
Res18-best | 11.69M | download link | 66.81 | 2.03 | 62.65 | 9.45
Res50-best | 25.56M | download link | 75.74 | 6.32 | 52.25 | 10.77
Res101-best | 44.55M | download link | 77.83 | 11.49 | 47.35 | 13.28

DeiTs:

Model | Size | Checkpoint | ImageNet | ImageNet-A | ImageNet-C | Stylized-ImageNet
DeiT-Mini | 9.98M | download link | 72.89 | 8.19 | 54.68 | 9.88
DeiT-Small | 22.05M | download link | 76.82 | 12.21 | 47.99 | 12.98
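The DeiT sizes can be checked the same way. This sketch counts parameters of a DeiT/ViT-style classifier with 16x16 patches at 224px input, an MLP ratio of 4 and no distillation token. DeiT-Small uses embedding dimension 384 and depth 12. The DeiT-Mini configuration is not given here; dimension 256 with depth 12 is an assumption that happens to reproduce the 9.98M figure, and the repo may define it differently.

```python
def vit_params(dim, depth, mlp_ratio=4, patch=16, img=224, in_ch=3, num_classes=1000):
    """Count learnable parameters of a DeiT/ViT-style classifier (no distillation token).

    The number of attention heads does not change the count, so it is not an argument.
    """
    tokens = (img // patch) ** 2 + 1  # patches + 1 class token
    hidden = dim * mlp_ratio
    total = in_ch * patch * patch * dim + dim  # patch-embedding conv weight + bias
    total += dim  # class token
    total += tokens * dim  # learned position embeddings
    per_block = (
        2 * dim  # LayerNorm before attention
        + dim * 3 * dim + 3 * dim  # fused q/k/v projection
        + dim * dim + dim  # attention output projection
        + 2 * dim  # LayerNorm before MLP
        + dim * hidden + hidden  # MLP fc1
        + hidden * dim + dim  # MLP fc2
    )
    total += depth * per_block
    total += 2 * dim  # final LayerNorm
    total += dim * num_classes + num_classes  # linear classification head
    return total


if __name__ == "__main__":
    print(f"DeiT-Small: {vit_params(dim=384, depth=12) / 1e6:.2f}M")
    # Hypothetical DeiT-Mini configuration (assumption, see lead-in).
    print(f"DeiT-Mini (dim=256, depth=12): {vit_params(dim=256, depth=12) / 1e6:.2f}M")
```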

Model Distillation

Role | Model | Checkpoint | ImageNet | ImageNet-A | ImageNet-C | Stylized-ImageNet
Teacher | DeiT-Small | download link | 76.8 | 12.2 | 48.0 | 13.0
Student | Res50*-Distill | download link | 76.7 | 5.2 | 54.2 | 9.8
Teacher | Res50* | download link | 76.3 | 4.5 | 55.6 | 8.2
Student | DeiT-S-Distill | download link | 76.2 | 10.9 | 49.3 | 11.9
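Each teacher above is paired with a student of the other architecture. The distillation objective is not spelled out here; as an assumption, this sketch uses hard-label distillation in the style of DeiT, where the student is trained on both the ground-truth label and the teacher's argmax prediction. DeiT itself routes the second term through a dedicated distillation token; for a single-head CNN student both terms share the same logits, which is what this sketch does. The weight alpha=0.5 is hypothetical.

```python
import math


def cross_entropy(logits, label):
    # Numerically stable -log softmax(logits)[label].
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def hard_distillation_loss(student_logits, teacher_logits, label, alpha=0.5):
    """Hypothetical hard-label distillation objective (DeiT-style).

    The teacher's argmax prediction acts as a second target for the student;
    alpha weights that teacher term against the ground-truth term.
    """
    teacher_label = max(range(len(teacher_logits)), key=teacher_logits.__getitem__)
    gt_term = cross_entropy(student_logits, label)
    teacher_term = cross_entropy(student_logits, teacher_label)
    return (1 - alpha) * gt_term + alpha * teacher_term
```

When the teacher agrees with the label, the loss reduces to plain cross-entropy; when it disagrees, the student is pulled toward the teacher's prediction.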

Adversarial Training

Model | Checkpoint | Clean Acc | PGD-100 | AutoAttack
Res50-ReLU | download link | 66.77 | 32.26 | 26.41
Res50-GELU | download link | 67.38 | 40.27 | 35.51
DeiT-Small | download link | 66.50 | 40.32 | 35.50

Vanilla Training

Data preparation

Download and extract the ImageNet train and val images from http://image-net.org/. The directory structure follows the standard layout for torchvision's datasets.ImageFolder, with the training and validation data expected in the train/ and val/ folders respectively:

/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
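torchvision's ImageFolder treats each sub-directory as one class and assigns class indices in sorted order of the folder names, so train/ and val/ must contain the same class folders for the labels to line up. This stdlib-only sketch mirrors that convention and checks the layout before a long training run; check_imagenet_layout is a hypothetical helper, not part of the repo.

```python
import os
from pathlib import Path


def find_classes(split_dir):
    # Mirrors torchvision's ImageFolder convention: every sub-directory is a class,
    # and class indices follow the sorted directory names.
    classes = sorted(e.name for e in os.scandir(split_dir) if e.is_dir())
    return {name: idx for idx, name in enumerate(classes)}


def check_imagenet_layout(root):
    """Check that train/ and val/ exist and expose the same class folders."""
    root = Path(root)
    train = find_classes(root / "train")
    val = find_classes(root / "val")
    if train != val:
        differing = sorted(set(train) ^ set(val))
        raise ValueError(f"train/ and val/ class folders differ, e.g. {differing[:5]}")
    return train  # class name -> label index


if __name__ == "__main__":
    mapping = check_imagenet_layout("/path/to/imagenet")
    print(f"{len(mapping)} classes found")
```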

Environment

Install dependencies:

pip3 install -r requirements.txt

Training Scripts

To train a ResNet model on ImageNet run:

bash script/res.sh

To train a DeiT model on ImageNet run:

bash script/deit.sh

Generalization to Out-of-Distribution Samples

Data Preparation

Download and extract the ImageNet-A, ImageNet-C, and Stylized-ImageNet validation images into the following layout:

/path/to/datasets/
  val/
    class1/
      img1.jpeg
    class2/
      img2.jpeg

Evaluation Scripts

To evaluate pre-trained models, run:

bash script/generation_to_ood.sh

Note that for ImageNet-C evaluation, the error rate is computed over the Noise, Blur, Weather, and Digital corruption categories.
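Assuming the standard ImageNet-C benchmark (15 corruptions at 5 severities, as defined by Hendrycks and Dietterich), the four categories group the corruptions as below. This sketch averages error per category and over all 15 corruptions. The optional baseline argument switches to the normalized corruption error (CE) used for mCE, where each model's summed error is divided by a baseline model's (AlexNet in the original benchmark). Whether the repo's script reports plain error or normalized mCE is not stated here, so treat both paths as assumptions.

```python
from statistics import mean

# Standard ImageNet-C corruption groups (15 corruptions, 5 severities each).
CATEGORIES = {
    "Noise": ["gaussian_noise", "shot_noise", "impulse_noise"],
    "Blur": ["defocus_blur", "glass_blur", "motion_blur", "zoom_blur"],
    "Weather": ["snow", "frost", "fog", "brightness"],
    "Digital": ["contrast", "elastic_transform", "pixelate", "jpeg_compression"],
}


def imagenet_c_summary(errors, baseline=None):
    """errors: {corruption: [error at severity 1..5]}.

    Without a baseline, each corruption scores its mean error over severities.
    With a baseline, it scores CE = 100 * sum(errors) / sum(baseline errors),
    and the overall mean is the mCE.
    Returns (per-category means, overall mean over all 15 corruptions).
    """
    per_corruption = {}
    for names in CATEGORIES.values():
        for name in names:
            errs = errors[name]
            if baseline is None:
                per_corruption[name] = mean(errs)
            else:
                per_corruption[name] = 100 * sum(errs) / sum(baseline[name])
    per_category = {
        cat: mean(per_corruption[n] for n in names) for cat, names in CATEGORIES.items()
    }
    return per_category, mean(per_corruption.values())
```

Note that the overall figure averages the 15 corruptions, not the four category means, so categories with more corruptions weigh more.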

Adversarial Training

To adversarially train a ResNet, run:

bash script/advres.sh

To adversarially train a DeiT model, run:

bash script/advdeit.sh
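Adversarial training solves a min-max problem: an inner step finds a worst-case perturbation within an L-inf ball of radius eps, and an outer step updates the weights on that perturbed input. The scripts above do this for deep networks, where the inner step has to be found by an iterative attack. As a self-contained illustration only, this sketch applies the same min-max loop to logistic regression, where the inner maximization has a closed form. All names and hyperparameters are hypothetical.

```python
import math


def sign(v: float) -> int:
    return (v > 0) - (v < 0)


def score(w, b, x):
    return sum(wi * xi for wi, xi in zip(w, x)) + b


def worst_case_input(w, x, y, eps):
    # Inner maximization, solved in closed form for a linear score: the L-inf
    # perturbation of radius eps that most reduces the margin y * (w.x + b)
    # is -eps * y * sign(w). No [0, 1] pixel clipping in this toy setting.
    return [xi - eps * y * sign(wi) for xi, wi in zip(x, w)]


def adversarial_train(data, eps, lr=0.1, epochs=50):
    """Toy min-max training of logistic regression; labels y are +1 or -1."""
    dim = len(data[0][0])
    w, b = [0.0] * dim, 0.0
    for _ in range(epochs):
        for x, y in data:
            x_adv = worst_case_input(w, x, y, eps)
            # d/ds of log(1 + exp(-y * s)), evaluated at the adversarial input.
            g = -y / (1.0 + math.exp(y * score(w, b, x_adv)))
            # Outer minimization: SGD step on the adversarial loss.
            w = [wi - lr * g * xi for wi, xi in zip(w, x_adv)]
            b -= lr * g
    return w, b


if __name__ == "__main__":
    data = [([1.0], 1), ([-1.0], -1), ([0.8], 1), ([-0.8], -1)]
    w, b = adversarial_train(data, eps=0.3)
    for x, y in data:
        print(x, y, "robust margin:", round(y * score(w, b, worst_case_input(w, x, y, 0.3)), 3))
```

Training on the worst-case input rather than the clean one is what makes the learned margin hold under any perturbation of size eps.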

Robustness to Adversarial Examples

PGD Attack Evaluation

To evaluate the adversarially trained models under PGD attack, run:

bash script/eval_advtraining.sh
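PGD under an L-inf budget repeatedly takes a signed-gradient ascent step on the loss and projects the result back onto the eps-ball around the clean input and onto the valid pixel range. The PGD-100 column in the table above presumably denotes 100 such iterations. This is a framework-free sketch of the update rule; linear_logistic is a hypothetical toy model standing in for a network, and eps, alpha and steps are illustrative values, not the repo's settings.

```python
import math
import random


def pgd_linf(x0, grad_fn, eps, alpha, steps, random_start=True, seed=0):
    """L-inf PGD: signed-gradient ascent on the loss, projected after every step."""
    rng = random.Random(seed)
    x = [xi + rng.uniform(-eps, eps) if random_start else xi for xi in x0]
    x = [min(max(xi, 0.0), 1.0) for xi in x]
    for _ in range(steps):
        g = grad_fn(x)
        x = [xi + alpha * ((gi > 0) - (gi < 0)) for xi, gi in zip(x, g)]
        # Projection: clip to [x0 - eps, x0 + eps], then to the pixel range [0, 1].
        x = [min(max(xi, oi - eps), oi + eps) for xi, oi in zip(x, x0)]
        x = [min(max(xi, 0.0), 1.0) for xi in x]
    return x


def linear_logistic(w, b, y):
    """Hypothetical toy 'model': logistic loss of a linear score, label y in {+1, -1}."""

    def loss(x):
        return math.log1p(math.exp(-y * (sum(wi * xi for wi, xi in zip(w, x)) + b)))

    def grad(x):
        s = sum(wi * xi for wi, xi in zip(w, x)) + b
        coef = -y / (1.0 + math.exp(y * s))  # d loss / d score
        return [coef * wi for wi in w]

    return loss, grad


if __name__ == "__main__":
    loss, grad = linear_logistic([1.0, -2.0, 0.5], 0.0, 1)
    x0 = [0.5, 0.5, 0.5]
    x_adv = pgd_linf(x0, grad, eps=0.1, alpha=0.025, steps=10)
    print("clean loss:", loss(x0), "adversarial loss:", loss(x_adv))
```

Robust accuracy is then the fraction of evaluation samples the model still classifies correctly on x_adv.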

AutoAttack Evaluation

The ./autoattack directory contains the public AutoAttack package, slightly modified to better support ImageNet evaluation. To run it:

cd autoattack/
bash autoattack.sh
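AutoAttack reports worst-case robust accuracy over an ensemble of attacks (in its standard version: APGD-CE, APGD-T, FAB-T and Square): a sample counts as robust only if it is classified correctly on the clean input and every attack fails on it. In the public package the entry point is AutoAttack(model, norm='Linf', eps=..., version='standard') followed by run_standard_evaluation(x, y, bs=...); the modified copy in this repo may differ. This framework-free sketch shows only the aggregation step; the function name and input format are hypothetical.

```python
def worst_case_robust_accuracy(clean_correct, survived_by_attack):
    """Fraction of samples correct on clean input AND under every attack.

    clean_correct: list of bools, one per sample.
    survived_by_attack: {attack_name: list of bools}, True if the model still
    predicts the true label on that attack's adversarial example.
    """
    robust = list(clean_correct)
    if not robust:
        raise ValueError("no samples")
    for name, survived in survived_by_attack.items():
        if len(survived) != len(robust):
            raise ValueError(f"{name}: expected {len(robust)} results, got {len(survived)}")
        robust = [r and s for r, s in zip(robust, survived)]
    return sum(robust) / len(robust)


if __name__ == "__main__":
    acc = worst_case_robust_accuracy(
        [True, True, True, False],
        {"apgd-ce": [True, False, True, True], "apgd-t": [True, True, False, True]},
    )
    print(f"robust accuracy: {acc:.2%}")
```

Because each attack can only remove samples from the robust set, the ensemble figure is never higher than any single attack's, which is why the AutoAttack column sits below PGD-100 in the table.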

Patch Attack Evaluation

Please refer to PatchAttack.

Citation

If you use our code, models or wish to refer to our results, please use the following BibTex entry:

@inproceedings{bai2021transformers,
  title     = {Are Transformers More Robust Than CNNs?},
  author    = {Bai, Yutong and Mei, Jieru and Yuille, Alan and Xie, Cihang},
  booktitle = {Thirty-Fifth Conference on Neural Information Processing Systems},
  year      = {2021},
}
Owner
Yutong Bai
CS Ph.D student @ JHU, CCVL