Implementation of the state-of-the-art vision transformers with tensorflow

Overview

ViT Tensorflow

This repository contains the tensorflow implementation of the state-of-the-art vision transformers (a category of computer vision models first introduced in An Image is worth 16 x 16 words). This repository is inspired from the work of lucidrains which is vit-pytorch. I hope you enjoy these implementations :)

Models

Requirements

pip install tensorflow

Vision Transformer

Vision transformer was introduced in An Image is worth 16 x 16 words. This model uses a Transformer encoder to classify images with pure attention and no convolution.

Usage

Defining the Model

from vit import ViT
import tensorflow as tf

vitClassifier = ViT(
                    num_classes=1000,
                    patch_size=16,
                    num_of_patches=(224//16)**2,
                    d_model=128,
                    heads=2,
                    num_layers=4,
                    mlp_rate=2,
                    dropout_rate=0.1,
                    prediction_dropout=0.3,
)
Params
  • num_classes: int
    number of classes used for the final classification head
  • patch_size: int
    patch_size used for the tokenization
  • num_of_patches: int
    number of patches after the tokenization which is used for the positional encoding, Generally it can be computed by the following formula (((h-patch_size)//patch_size) + 1)*(((w-patch_size)//patch_size) + 1) where h is the height of the image and w is the width of the image. In addition, when height and width of the image are devisable by the patch_size the following formula can be used as well (h//patch_size)*(w//patch_size)
  • d_model: int
    hidden dimension of the transformer encoder and the demnesion used for patch embedding
  • heads: int
    number of heads used for the multi-head attention mechanism
  • num_layers: int
    number of blocks in encoder transformer
  • mlp_rate: int
    the rate of expansion in the feed-forward block of each transformer block (the dimension after expansion is mlp_rate * d_model)
  • dropout_rate: float
    dropout rate used in the multi-head attention mechanism
  • prediction_dropout: float
    dropout rate used in the final prediction head of the model

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = vitClassifier(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

vitClassifier.compile(
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              metrics=[
                       tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                       tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5 , name="top_5_accuracy"),
              ])

vitClassifier.fit(
              trainingData, #Tensorflow dataset of images and labels in shape of ((b , h , w , 3) , (b,))
              validation_data=valData, #The same as training
              epochs=100,)

Convolutional Vision Transformer

Convolutional Vision Transformer was introduced in here. This model uses a hierarchical (multi-stage) architecture with convolutional embeddings in the begining of each stage. it also uses Convolutional Transformer Blocks to improve the orginal vision transformer by adding CNNs inductive bias into the architecture.

Usage

Defining the Model

from cvt import CvT , CvTStage
import tensorflow as tf

cvtModel = CvT(
num_of_classes=1000, 
stages=[
        CvTStage(projectionDim=64, 
                 heads=1, 
                 embeddingWindowSize=(7 , 7), 
                 embeddingStrides=(4 , 4), 
                 layers=1,
                 projectionWindowSize=(3 , 3), 
                 projectionStrides=(2 , 2), 
                 ffnRate=4,
                 dropoutRate=0.1),
        CvTStage(projectionDim=192,
                 heads=3,
                 embeddingWindowSize=(3 , 3), 
                 embeddingStrides=(2 , 2),
                 layers=1, 
                 projectionWindowSize=(3 , 3), 
                 projectionStrides=(2 , 2), 
                 ffnRate=4,
                 dropoutRate=0.1),
        CvTStage(projectionDim=384,
                 heads=6,
                 embeddingWindowSize=(3 , 3),
                 embeddingStrides=(2 , 2),
                 layers=1,
                 projectionWindowSize=(3 , 3),
                 projectionStrides=(2 , 2), 
                 ffnRate=4,
                 dropoutRate=0.1)
],
dropout=0.5)
CvT Params
  • num_of_classes: int
    number of classes used in the final prediction layer
  • stages: list of CvTStage
    list of cvt stages
  • dropout: float
    dropout rate used for the prediction head
CvTStage Params
  • projectionDim: int
    dimension used for the multi-head attention mechanism and the convolutional embedding
  • heads: int
    number of heads in the multi-head attention mechanism
  • embeddingWindowSize: tuple(int , int)
    window size used for the convolutional emebdding
  • embeddingStrides: tuple(int , int)
    strides used for the convolutional embedding
  • layers: int
    number of convolutional transformer blocks
  • projectionWindowSize: tuple(int , int)
    window size used for the convolutional projection in each convolutional transformer block
  • projectionStrides: tuple(int , int)
    strides used for the convolutional projection in each convolutional transformer block
  • ffnRate: int
    expansion rate of the mlp block in each convolutional transformer block
  • dropoutRate: float
    dropout rate used in each convolutional transformer block

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = cvtModel(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

cvtModel.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        metrics=[
                 tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                 tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5 , name="top_5_accuracy"),
        ])

cvtModel.fit(
        trainingData, #Tensorflow dataset of images and labels in shape of ((b , h , w , 3) , (b,))
        validation_data=valData, #The same as training
        epochs=100,)

Pyramid Vision Transformer V1

Pyramid Vision Transformer V1 was introduced in here. This model stacks multiple Transformer Encoders to form the first convolution-free multi-scale backbone for various visual tasks including Image Segmentation , Object Detection and etc. In addition to this a new attention mechanism called Spatial Reduction Attention (SRA) is also introduced in this paper to reduce the quadratic complexity of the multi-head attention mechansim.

Usage

Defining the Model

from pvt_v1 import PVT , PVTStage
import tensorflow as tf

pvtModel = PVT(
num_of_classes=1000, 
stages=[
        PVTStage(d_model=64,
                 patch_size=(2 , 2),
                 heads=1,
                 reductionFactor=2,
                 mlp_rate=2,
                 layers=2, 
                 dropout_rate=0.1),
        PVTStage(d_model=128,
                 patch_size=(2 , 2),
                 heads=2, 
                 reductionFactor=2, 
                 mlp_rate=2, 
                 layers=2, 
                 dropout_rate=0.1),
        PVTStage(d_model=320,
                 patch_size=(2 , 2),
                 heads=5, 
                 reductionFactor=2, 
                 mlp_rate=2, 
                 layers=2, 
                 dropout_rate=0.1),
],
dropout=0.5)
PVT Params
  • num_of_classes: int
    number of classes used in the final prediction layer
  • stages: list of PVTStage
    list of pvt stages
  • dropout: float
    dropout rate used for the prediction head
PVTStage Params
  • d_model: int
    dimension used for the SRA mechanism and the patch embedding
  • patch_size: tuple(int , int)
    window size used for the patch emebdding
  • heads: int
    number of heads in the SRA mechanism
  • reductionFactor: int
    reduction factor used for the down sampling of the K and V in the SRA mechanism
  • mlp_rate: int
    expansion rate used in the feed-forward block
  • layers: int
    number of transformer encoders
  • dropout_rate: float
    dropout rate used in each transformer encoder

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = pvtModel(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

pvtModel.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        metrics=[
                 tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                 tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5 , name="top_5_accuracy"),
        ])

pvtModel.fit(
        trainingData, #Tensorflow dataset of images and labels in shape of ((b , h , w , 3) , (b,))
        validation_data=valData, #The same as training
        epochs=100,)

Pyramid Vision Transformer V2

Pyramid Vision Transformer V2 was introduced in here. This model is an improved version of the PVT V1. The improvements of this version are as follows:

  1. It uses overlapping patch embedding by using padded convolutions
  2. It uses convolutional feed-forward blocks which have a depth-wise convolution after the first fully-connected layer
  3. It uses a fixed pooling instead of convolutions for down sampling the K and V in the SRA attention mechanism (The new attention mechanism is called Linear SRA)

Usage

Defining the Model

from pvt_v2 import PVTV2 , PVTV2Stage
import tensorflow as tf

pvtV2Model = PVTV2(
num_of_classes=1000, 
stages=[
        PVTV2Stage(d_model=64,
                   windowSize=(2 , 2), 
                   heads=1,
                   poolingSize=(7 , 7), 
                   mlp_rate=2, 
                   mlp_windowSize=(3 , 3), 
                   layers=2, 
                   dropout_rate=0.1),
        PVTV2Stage(d_model=128, 
                   windowSize=(2 , 2),
                   heads=2,
                   poolingSize=(7 , 7), 
                   mlp_rate=2, 
                   mlp_windowSize=(3 , 3), 
                   layers=2,
                   dropout_rate=0.1),
        PVTV2Stage(d_model=320,
                   windowSize=(2 , 2), 
                   heads=5, 
                   poolingSize=(7 , 7), 
                   mlp_rate=2, 
                   mlp_windowSize=(3 , 3), 
                   layers=2, 
                   dropout_rate=0.1),
],
dropout=0.5)
PVT Params
  • num_of_classes: int
    number of classes used in the final prediction layer
  • stages: list of PVTV2Stage
    list of pvt v2 stages
  • dropout: float
    dropout rate used for the prediction head
PVTStage Params
  • d_model: int
    dimension used for the Linear SRA mechanism and the convolutional patch embedding
  • windowSize: tuple(int , int)
    window size used for the convolutional patch emebdding
  • heads: int
    number of heads in the Linear SRA mechanism
  • poolingSize: tuple(int , int)
    size of the K and V after the fixed pooling
  • mlp_rate: int
    expansion rate used in the convolutional feed-forward block
  • mlp_windowSize: tuple(int , int)
    the window size used for the depth-wise convolution in the convolutional feed-forward block
  • layers: int
    number of transformer encoders
  • dropout_rate: float
    dropout rate used in each transformer encoder

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = pvtV2Model(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

pvtV2Model.compile(
          loss=tf.keras.losses.SparseCategoricalCrossentropy(),
          optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
          metrics=[
                   tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                   tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5 , name="top_5_accuracy"),
          ])

pvtV2Model.fit(
          trainingData, #Tensorflow dataset of images and labels in shape of ((b , h , w , 3) , (b,))
          validation_data=valData, #The same as training
          epochs=100,)

DeiT

DeiT was introduced in Training Data-Efficient Image Transformers & Distillation Through Attention. Since original vision transformer is data hungry due to the lack of existance of any inductive bias (unlike CNNs) a lot of data is required to train original vision transformer in order to surpass the state-of-the-art CNNs such as Resnet. Therefore, in this paper authors used a pre-trained CNN such as resent during training and used a sepcial loss function to perform distillation through attention.

Usage

Defining the Model

from deit import DeiT
import tensorflow as tf

teacherModel = tf.keras.applications.ResNet50(include_top=True, 
                                              weights="imagenet", 
                                              input_shape=(224 , 224 , 3))

deitModel = DeiT(
                 num_classes=1000,
                 patch_size=16,
                 num_of_patches=(224//16)**2,
                 d_model=128,
                 heads=2,
                 num_layers=4,
                 mlp_rate=2,
                 teacherModel=teacherModel,
                 temperature=1.0, 
                 alpha=0.5,
                 hard=False, 
                 dropout_rate=0.1,
                 prediction_dropout=0.3,
)
Params
  • num_classes: int
    number of classes used for the final classification head
  • patch_size: int
    patch_size used for the tokenization
  • num_of_patches: int
    number of patches after the tokenization which is used for the positional encoding, Generally it can be computed by the following formula (((h-patch_size)//patch_size) + 1)*(((w-patch_size)//patch_size) + 1) where h is the height of the image and w is the width of the image. In addition, when height and width of the image are devisable by the patch_size the following formula can be used as well (h//patch_size)*(w//patch_size)
  • d_model: int
    hidden dimension of the transformer encoder and the demnesion used for patch embedding
  • heads: int
    number of heads used for the multi-head attention mechanism
  • num_layers: int
    number of blocks in encoder transformer
  • mlp_rate: int
    the rate of expansion in the feed-forward block of each transformer block (the dimension after expansion is mlp_rate * d_model)
  • teacherModel: Tensorflow Model
    the teacherModel used for the distillation during training, This model is a pre-trained CNN model with the same input_shape and output_shape as the Transformer
  • temperature: float
    the temperature parameter in the loss
  • alpha: float
    the coefficient balancing the Kullback–Leibler divergence loss (KL) and the cross-entropy loss
  • hard: bool
    indicates using Hard-label distillation or Soft distillation
  • dropout_rate: float
    dropout rate used in the multi-head attention mechanism
  • prediction_dropout: float
    dropout rate used in the final prediction head of the model

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = deitModel(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

#Note that the loss is defined inside the model and no loss should be passed here
deitModel.compile(
         optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
         metrics=[
                  tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                  tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5 , name="top_5_accuracy"),
         ])

deitModel.fit(
         trainingData, #Tensorflow dataset of images and labels in shape of ((b , h , w , 3) , (b , num_classes))
         validation_data=valData, #The same as training
         epochs=100,)
Owner
Mohammadmahdi NouriBorji
Mohammadmahdi NouriBorji
Chainer implementation of recent GAN variants

Chainer-GAN-lib This repository collects chainer implementation of state-of-the-art GAN algorithms. These codes are evaluated with the inception score

399 Oct 23, 2022
Implementation for "Manga Filling Style Conversion with Screentone Variational Autoencoder" (SIGGRAPH ASIA 2020 issue)

Manga Filling with ScreenVAE SIGGRAPH ASIA 2020 | Project Website | BibTex This repository is for ScreenVAE introduced in the following paper "Manga F

30 Dec 24, 2022
Repository for "Space-Time Correspondence as a Contrastive Random Walk" (NeurIPS 2020)

Space-Time Correspondence as a Contrastive Random Walk This is the repository for Space-Time Correspondence as a Contrastive Random Walk, published at

A. Jabri 239 Dec 27, 2022
Minimal fastai code needed for working with pytorch

fastai_minima A mimal version of fastai with the barebones needed to work with Pytorch #all_slow Install pip install fastai_minima How to use This lib

Zachary Mueller 14 Oct 21, 2022
Code associated with the paper "Towards Understanding the Data Dependency of Mixup-style Training".

Mixup-Data-Dependency Code associated with the paper "Towards Understanding the Data Dependency of Mixup-style Training". Running Alternating Line Exp

Muthu Chidambaram 0 Nov 11, 2021
Learn the Deep Learning for Computer Vision in three steps: theory from base to SotA, code in PyTorch, and space-repetition with Anki

DeepCourse: Deep Learning for Computer Vision arthurdouillard.com/deepcourse/ This is a course I'm giving to the French engineering school EPITA each

Arthur Douillard 113 Nov 29, 2022
Pytorch code for ICRA'21 paper: "Hierarchical Cross-Modal Agent for Robotics Vision-and-Language Navigation"

Hierarchical Cross-Modal Agent for Robotics Vision-and-Language Navigation This repository is the pytorch implementation of our paper: Hierarchical Cr

43 Nov 21, 2022
CTRL-C: Camera calibration TRansformer with Line-Classification

CTRL-C: Camera calibration TRansformer with Line-Classification This repository contains the official code and pretrained models for CTRL-C (Camera ca

57 Nov 14, 2022
Panoptic SegFormer: Delving Deeper into Panoptic Segmentation with Transformers

Panoptic SegFormer: Delving Deeper into Panoptic Segmentation with Transformers Results results on COCO val Backbone Method Lr Schd PQ Config Download

155 Dec 20, 2022
Unpaired Caricature Generation with Multiple Exaggerations

CariMe-pytorch The official pytorch implementation of the paper "CariMe: Unpaired Caricature Generation with Multiple Exaggerations" CariMe: Unpaired

Gu Zheng 37 Dec 30, 2022
AutoVideo: An Automated Video Action Recognition System

AutoVideo is a system for automated video analysis. It is developed based on D3M infrastructure, which describes machine learning with generic pipeline languages. Currently, it focuses on video actio

Data Analytics Lab at Texas A&M University 267 Dec 17, 2022
Parameterized Explainer for Graph Neural Network

PGExplainer This is a Tensorflow implementation of the paper: Parameterized Explainer for Graph Neural Network https://arxiv.org/abs/2011.04573 NeurIP

Dongsheng Luo 89 Dec 12, 2022
[CVPR 2021 Oral] Variational Relational Point Completion Network

VRCNet: Variational Relational Point Completion Network This repository contains the PyTorch implementation of the paper: Variational Relational Point

PL 121 Dec 12, 2022
A Broad Study on the Transferability of Visual Representations with Contrastive Learning

A Broad Study on the Transferability of Visual Representations with Contrastive Learning This repository contains code for the paper: A Broad Study on

Ashraful Islam 29 Nov 09, 2022
This repository contains a toolkit for collecting, labeling and tracking object keypoints

This repository contains a toolkit for collecting, labeling and tracking object keypoints. Object keypoints are semantic points in an object's coordinate frame.

ETHZ ASL 13 Dec 12, 2022
Metrics to evaluate quality and efficacy of synthetic datasets.

An Open Source Project from the Data to AI Lab, at MIT Metrics for Synthetic Data Generation Projects Website: https://sdv.dev Documentation: https://

The Synthetic Data Vault Project 129 Jan 03, 2023
TensorFlow Tutorials with YouTube Videos

TensorFlow Tutorials Original repository on GitHub Original author is Magnus Erik Hvass Pedersen Introduction These tutorials are intended for beginne

9.1k Jan 02, 2023
BC3407-Group-5-Project - BC3407 Group Project With Python

BC3407-Group-5-Project As the world struggles to contain the ever-changing varia

1 Jan 26, 2022
A "gym" style toolkit for building lightweight Neural Architecture Search systems

A "gym" style toolkit for building lightweight Neural Architecture Search systems

Jack Turner 12 Nov 05, 2022
AnimationKit: AI Upscaling & Interpolation using Real-ESRGAN+RIFE

ALPHA 2.5: Frostbite Revival (Released 12/23/21) Changelog: [ UI ] Chained design. All steps link to one another! Use the master override toggles to s

87 Nov 16, 2022