PyTorch reimplementation of the Vision Transformer (An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale)

Overview

Vision Transformer

PyTorch reimplementation of Google's repository for the ViT model released with the paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.

The paper shows that Transformers applied directly to image patches and pre-trained on large datasets work very well on image recognition tasks.

fig1

Vision Transformer achieves state-of-the-art results on image recognition tasks using a standard Transformer encoder and fixed-size patches. To perform classification, the authors use the standard approach of adding an extra learnable "classification token" to the sequence.
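To make the patch-and-token idea concrete, here is a minimal NumPy sketch of how an image becomes the encoder's input sequence: split it into non-overlapping P x P patches, flatten and linearly project each patch, prepend a class token, and add position embeddings. The weights are random placeholders and `embed_image` is a hypothetical helper, not this repository's code; implementations commonly realize the projection as a convolution with kernel size and stride P, which is equivalent to a linear map over flattened patches.

```python
import numpy as np


def embed_image(image, patch_size, weight, cls_token, pos_embed):
    """Turn an (H, W, C) image into a (num_patches + 1, D) token sequence.

    weight:    (patch_size * patch_size * C, D) patch projection (placeholder values)
    cls_token: (D,) class token
    pos_embed: (num_patches + 1, D) position embeddings
    """
    h, w, c = image.shape
    if h % patch_size or w % patch_size:
        raise ValueError("image size must be divisible by patch_size")
    gh, gw = h // patch_size, w // patch_size
    # (gh, P, gw, P, C) -> (gh, gw, P, P, C) -> (gh * gw, P * P * C), row-major over the patch grid
    patches = image.reshape(gh, patch_size, gw, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4).reshape(gh * gw, -1)
    tokens = patches @ weight                                      # linear patch embedding
    tokens = np.concatenate([cls_token[None, :], tokens], axis=0)  # prepend class token
    return tokens + pos_embed                                      # add position embeddings


rng = np.random.default_rng(0)
P, D = 16, 768  # ViT-B_16: 16x16 patches, hidden size 768
img = rng.standard_normal((224, 224, 3))
W = rng.standard_normal((P * P * 3, D)) * 0.02
cls = np.zeros(D)
pos = rng.standard_normal(((224 // P) ** 2 + 1, D)) * 0.02
seq = embed_image(img, P, W, cls, pos)
print(seq.shape)  # (197, 768): 14 * 14 patches + 1 class token
```

For classification, the head reads the encoder output at position 0, the class-token position.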

fig2

Usage

1. Download a pre-trained model (Google's official checkpoint)

  • Available models (parameter counts in parentheses): ViT-B_16 (85.8M), R50+ViT-B_16 (97.96M), ViT-B_32 (87.5M), ViT-L_16 (303.4M), ViT-L_32 (305.5M), ViT-H_14 (630.8M)
    • imagenet21k pre-trained models
      • ViT-B_16, ViT-B_32, ViT-L_16, ViT-L_32, ViT-H_14
    • imagenet21k pre-trained + imagenet2012 fine-tuned models
      • ViT-B_16-224, ViT-B_16, ViT-B_32, ViT-L_16-224, ViT-L_16, ViT-L_32
    • Hybrid model (ResNet-50 + Transformer)
      • R50-ViT-B_16
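Several of the parameter counts above can be reproduced by hand. The sketch below assumes the standard ViT configurations (hidden size 768, MLP size 3072 and 12 layers for ViT-B; 1024, 4096 and 24 for ViT-L), a 224x224 input, biases on every linear layer, and no classification head in the count. Under those assumptions the arithmetic lands on the listed 85.8M, 87.5M and 305.5M; the hybrid R50 model and ViT-H_14 are not covered, and `vit_param_count` is a hypothetical helper.

```python
def vit_param_count(hidden, mlp_dim, depth, patch, image=224, channels=3):
    """Backbone parameter count of a plain ViT, excluding any classification head (assumption)."""
    n_tokens = (image // patch) ** 2 + 1                       # patches + class token
    patch_embed = channels * patch * patch * hidden + hidden   # projection weight + bias
    cls_token = hidden
    pos_embed = n_tokens * hidden
    layer_norms = 2 * (2 * hidden)                             # two LayerNorms per block (scale + bias)
    attention = 4 * (hidden * hidden + hidden)                 # query, key, value, output projections
    mlp = (hidden * mlp_dim + mlp_dim) + (mlp_dim * hidden + hidden)
    block = layer_norms + attention + mlp
    final_norm = 2 * hidden
    return patch_embed + cls_token + pos_embed + depth * block + final_norm


configs = {
    "ViT-B_16": (768, 3072, 12, 16),
    "ViT-B_32": (768, 3072, 12, 32),
    "ViT-L_32": (1024, 4096, 24, 32),
}
for name, cfg in configs.items():
    print(f"{name}: {vit_param_count(*cfg) / 1e6:.1f}M")  # 85.8M, 87.5M, 305.5M
```

Note that the patch size only affects the patch projection and the number of position embeddings, which is why ViT-B_32 is slightly larger than ViT-B_16.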
# imagenet21k pre-train
wget https://storage.googleapis.com/vit_models/imagenet21k/{MODEL_NAME}.npz

# imagenet21k pre-train + imagenet2012 fine-tuning
wget https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/{MODEL_NAME}.npz
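After downloading, one quick sanity check is to open the .npz archive with NumPy and list what it contains. This is a minimal sketch; `summarize_checkpoint` is a hypothetical helper, and because the total counts every stored array (which may include a classification head), it need not match the parameter counts listed above. The demo builds a tiny stand-in archive with made-up array names so it runs without a download; for a real file, pass a path such as checkpoint/ViT-B_16.npz instead.

```python
import os
import tempfile

import numpy as np


def summarize_checkpoint(path):
    """Return ({array name: shape}, total number of stored values) for an .npz file."""
    shapes, total = {}, 0
    with np.load(path) as ckpt:  # NpzFile closes the underlying zip file on exit
        for name in ckpt.files:
            arr = ckpt[name]
            shapes[name] = arr.shape
            total += arr.size
    return shapes, total


# Demo on a tiny stand-in archive; the array names here are made up.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "demo.npz")
    np.savez(demo_path, kernel=np.zeros((16, 16, 3, 8)), bias=np.zeros(8))
    shapes, total = summarize_checkpoint(demo_path)

for name, shape in sorted(shapes.items()):
    print(name, shape)
print(f"{total:,} values")
```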

2. Train Model

python3 train.py --name cifar10-100_500 --dataset cifar10 --model_type ViT-B_16 --pretrained_dir checkpoint/ViT-B_16.npz

CIFAR-10 and CIFAR-100 are downloaded automatically and used for training. To use a different dataset, you need to customize data_utils.py.

The default batch size is 512. If GPU memory is insufficient, you can still train by increasing --gradient_accumulation_steps.
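Gradient accumulation trades memory for time: the batch is processed as several smaller micro-batches, and their gradients are combined before a single optimizer step. The NumPy sketch below uses a toy least-squares model rather than the repository's training loop, and checks that averaging the gradients of 4 micro-batches of 128 samples reproduces the full-batch gradient over 512 samples. How exactly train.py splits the configured batch across accumulation steps is an assumption worth confirming in the script; the helper names are hypothetical.

```python
import numpy as np


def mse_grad(w, x, y):
    """Gradient of mean((x @ w - y) ** 2) with respect to w."""
    return 2.0 * x.T @ (x @ w - y) / len(x)


def accumulated_grad(w, x, y, accumulation_steps):
    """Sum per-micro-batch gradients, each scaled by 1 / accumulation_steps."""
    grad = np.zeros_like(w)
    for xb, yb in zip(np.array_split(x, accumulation_steps),
                      np.array_split(y, accumulation_steps)):
        grad += mse_grad(w, xb, yb) / accumulation_steps
    return grad


rng = np.random.default_rng(0)
x = rng.standard_normal((512, 8))
y = rng.standard_normal(512)
w = rng.standard_normal(8)

full = mse_grad(w, x, y)                                  # one batch of 512
accum = accumulated_grad(w, x, y, accumulation_steps=4)   # 4 micro-batches of 128
print(np.allclose(full, accum))
```

The equality relies on equally sized micro-batches; with uneven splits the per-micro-batch scaling would need to be weighted by micro-batch size.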

You can also use Automatic Mixed Precision (AMP) to reduce memory usage and speed up training:

python3 train.py --name cifar10-100_500 --dataset cifar10 --model_type ViT-B_16 --pretrained_dir checkpoint/ViT-B_16.npz --fp16 --fp16_opt_level O2
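One source of the memory saving is storing activations in 16-bit instead of 32-bit floating point, which halves their size. The arithmetic below uses NumPy arrays as stand-ins for one block's activations of shape (batch, tokens, hidden); the batch of 16 is a hypothetical micro-batch, while 197 tokens and hidden size 768 correspond to ViT-B_16 at 224x224. The O-levels appear to follow NVIDIA Apex's amp naming, and at O2 Apex keeps FP32 master weights, so weight memory is not simply halved.

```python
import numpy as np

# One block's activations: (batch, tokens, hidden). Batch 16 is a hypothetical micro-batch;
# 197 tokens (14 * 14 patches + class token) and hidden size 768 match ViT-B_16 at 224x224.
shape = (16, 197, 768)
act_fp32 = np.zeros(shape, dtype=np.float32)
act_fp16 = np.zeros(shape, dtype=np.float16)

print(f"fp32: {act_fp32.nbytes / 2**20:.2f} MiB")  # 4 bytes per value
print(f"fp16: {act_fp16.nbytes / 2**20:.2f} MiB")  # 2 bytes per value
```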

Results

To verify that the converted model weights are correct, we compare our results with the authors' reported results. All runs were trained with mixed precision, with --fp16_opt_level set to O2.

imagenet-21k

model         dataset    resolution  acc (official)  acc (this repo)  time
ViT-B_16      CIFAR-10   224x224     -               0.9908           3h 13m
ViT-B_16      CIFAR-10   384x384     0.9903          0.9906           12h 25m
ViT-B_16      CIFAR-100  224x224     -               0.923            3h 9m
ViT-B_16      CIFAR-100  384x384     0.9264          0.9228           12h 31m
R50-ViT-B_16  CIFAR-10   224x224     -               0.9892           4h 23m
R50-ViT-B_16  CIFAR-10   384x384     0.99            0.9904           15h 40m
R50-ViT-B_16  CIFAR-100  224x224     -               0.9231           4h 18m
R50-ViT-B_16  CIFAR-100  384x384     0.9231          0.9197           15h 53m
ViT-L_32      CIFAR-10   224x224     -               0.9903           2h 11m
ViT-L_32      CIFAR-100  224x224     -               0.9276           2h 9m
ViT-H_14      CIFAR-100  224x224     -               WIP

imagenet-21k + imagenet2012

model         dataset    resolution  acc
ViT-B_16-224  CIFAR-10   224x224     0.99
ViT-B_16-224  CIFAR-100  224x224     0.9245
ViT-L_32      CIFAR-10   224x224     0.9903
ViT-L_32      CIFAR-100  224x224     0.9285

Shorter training

  • The experiments below use a resolution of 224x224.
  • tensorboard
upstream     model     dataset    total_steps/warmup_steps  acc (official)  acc (this repo)
imagenet21k  ViT-B_16  CIFAR-10   500/100                   0.9859          0.9859
imagenet21k  ViT-B_16  CIFAR-10   1000/100                  0.9886          0.9878
imagenet21k  ViT-B_16  CIFAR-100  500/100                   0.8917          0.9072
imagenet21k  ViT-B_16  CIFAR-100  1000/100                  0.9115          0.9216
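The total_steps/warmup_steps column describes the length of the learning-rate schedule. As a sketch, assuming linear warmup followed by cosine decay to zero (a common choice for ViT fine-tuning; check train.py for the schedule it actually uses), the learning-rate multiplier for a 500/100 run looks like this. `lr_multiplier` and the peak learning rate are hypothetical.

```python
import math


def lr_multiplier(step, warmup_steps, total_steps):
    """Linear warmup from 0 to 1, then cosine decay from 1 to 0 at total_steps."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))


base_lr = 3e-2  # hypothetical peak learning rate
for step in (0, 50, 100, 300, 500):
    print(step, base_lr * lr_multiplier(step, warmup_steps=100, total_steps=500))
```

With only 500 total steps, 20% of training is spent in warmup, which is one reason the short runs are sensitive to these two numbers.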

Visualization

ViT consists of a standard Transformer encoder, which in turn is built from self-attention and MLP modules. The attention map for an input image can be visualized from the attention scores of the self-attention layers.

Visualization code can be found at visualize_attention_map.
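One common way to turn per-layer attention scores into a single map is attention rollout (Abnar and Zuidema, "Quantifying Attention Flow in Transformers", 2020): average each layer's attention over heads, add the identity to account for residual connections, re-normalize the rows, and multiply the layers together; the class token's row then gives a per-patch map. The NumPy sketch below implements that technique on random attention matrices. We have not verified that visualize_attention_map follows exactly these steps, and `attention_rollout` is a hypothetical name.

```python
import numpy as np


def attention_rollout(attentions):
    """attentions: list of per-layer arrays shaped (heads, tokens, tokens), rows summing to 1.

    Returns a (grid, grid) map of how much the class token (index 0) attends to each patch.
    """
    tokens = attentions[0].shape[-1]
    joint = np.eye(tokens)
    for att in attentions:
        a = att.mean(axis=0)                   # average over heads
        a = a + np.eye(tokens)                 # account for the residual connection
        a = a / a.sum(axis=-1, keepdims=True)  # re-normalize rows
        joint = a @ joint                      # compose with the earlier layers
    grid = int(np.sqrt(tokens - 1))
    return joint[0, 1:].reshape(grid, grid)    # drop the class token's own column


rng = np.random.default_rng(0)
logits = rng.standard_normal((12, 4, 197, 197))  # hypothetical: 12 layers, 4 heads, 197 tokens
atts = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # row-wise softmax
mask = attention_rollout(list(atts))
print(mask.shape)  # (14, 14)
```

In practice the 14x14 map would be upsampled to the input resolution and overlaid on the image.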

fig3

Reference

Citations

@article{dosovitskiy2020,
  title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
  author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and Uszkoreit, Jakob and Houlsby, Neil},
  journal={arXiv preprint arXiv:2010.11929},
  year={2020}
}
Owner
Eunkwang Jeon