Reproduction of Vision Transformer in Tensorflow2. Train from scratch and Finetune.

Overview

Vision Transformer(ViT) in Tensorflow2

Tensorflow2 implementation of the Vision Transformer(ViT).

This repository is for An image is worth 16x16 words: Transformers for image recognition at scale and How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers.

Limitations.

  • Due to memory limitations, only the ti/16, s/16, and b/16 models were tested.
  • Due to memory limitations, batch_size 2048 in s16 and 1024 in b/16 (in paper, 4096).
  • Due to computational resource limitations, only reproduce using imagenet1k.

All experimental results and graphs are opend in Wandb.

Model weights

Since this is personal project, it is hard to train with large datasets like imagenet21k. For a pretrain model with good performance, see the official repo. But if you really need it, contact me.

Install dependencies

pip install -r requirements

All experiments were done on tpu_v3-8 with the support of TRC. But you can experiment on GPU. Check conf/config.yaml and conf/downstream.yaml

  # TPU options
  env:
    mode: tpu
    gcp_project: {your_project}
    tpu_name: node-1
    tpu_zone: europe-west4-a
    mixed_precision: True
  # GPU options
  # env:
  #   mode: gpu
  #   mixed_precision: True

Train from scratch

python run.py experiment=vit-s16-aug_light1-bs_2048-wd_0.1-do_0.1-dp_0.1-lr_1e-3 base.project_name=vit-s16-aug_light1-bs_2048-wd_0.1-do_0.1-dp_0.1-lr_1e-3 base.save_dir={your_save_dir} base.env.gcp_project={your_gcp_project} base.env.tpu_name={your_tpu_name} base.debug=False

Downstream

python run.py --config-name=downstream experiment=downstream-imagenet-ti16_384 base.pretrained={your_checkpoint} base.project_name={your_project_name} base.save_dir={your_save_dir} base.env.gcp_project={your_gcp_project} base.env.tpu_name={your_tpu_name} base.debug=False

Board

To track metics, you can use wandb or tensorboard (default: wandb). You can change in conf/callbacks/{filename.yaml}.

modules:
  - type: MonitorCallback
  - type: TerminateOnNaN
  - type: ProgbarLogger
    params:
      count_mode: steps
  - type: ModelCheckpoint
    params:
      filepath: ???
      save_weights_only: True
  - type: Wandb
    project: vit
    nested_dict: False
    hide_config: True
    params: 
      monitor: val_loss
      save_model: False
  # - type: TensorBoard
  #   params:
  #     log_dir: ???
  #     histogram_freq: 1

TFC

This open source was assisted by TPU Research Cloud (TRC) program

Thank you for providing the TPU.

Citations

@article{dosovitskiy2020image,
  title={An image is worth 16x16 words: Transformers for image recognition at scale},
  author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and others},
  journal={arXiv preprint arXiv:2010.11929},
  year={2020}
}
@article{steiner2021train,
  title={How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers},
  author={Steiner, Andreas and Kolesnikov, Alexander and Zhai, Xiaohua and Wightman, Ross and Uszkoreit, Jakob and Beyer, Lucas},
  journal={arXiv preprint arXiv:2106.10270},
  year={2021}
}
Owner
sungjun lee
AI Researcher
sungjun lee
This implementation contains the application of GPlearn's symbolic transformer on a commodity futures sector of the financial market.

GPlearn_finiance_stock_futures_extension This implementation contains the application of GPlearn's symbolic transformer on a commodity futures sector

Chengwei <a href=[email protected]"> 189 Dec 25, 2022
PyTorch implementation of SwAV (Swapping Assignments between Views)

Unsupervised Learning of Visual Features by Contrasting Cluster Assignments This code provides a PyTorch implementation and pretrained models for SwAV

Meta Research 1.7k Jan 04, 2023
OptNet: Differentiable Optimization as a Layer in Neural Networks

OptNet: Differentiable Optimization as a Layer in Neural Networks This repository is by Brandon Amos and J. Zico Kolter and contains the PyTorch sourc

CMU Locus Lab 428 Dec 24, 2022
Official implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification

CrossViT This repository is the official implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification. ArXiv If

International Business Machines 168 Dec 29, 2022
Towards Multi-Camera 3D Human Pose Estimation in Wild Environment

PanopticStudio Toolbox This repository has a toolbox to download, process, and visualize the Panoptic Studio (Panoptic) data. Note: Sep-21-2020: Curre

335 Jan 09, 2023
Use .csv files to record, play and evaluate motion capture data.

Purpose These scripts allow you to record mocap data to, and play from .csv files. This approach facilitates parsing of body movement data in statisti

21 Dec 12, 2022
Defocus Map Estimation and Deblurring from a Single Dual-Pixel Image

Defocus Map Estimation and Deblurring from a Single Dual-Pixel Image This repository is an implementation of the method described in the following pap

21 Dec 15, 2022
Little tool in python to watch anime from the terminal (the better way to watch anime)

ani-cli Script working again :), thanks to the fork by Dink4n for the alternative approach to by pass the captcha on gogoanime A cli to browse and wat

Harshith 4.5k Dec 31, 2022
Sharpened cosine similarity torch - A Sharpened Cosine Similarity layer for PyTorch

Sharpened Cosine Similarity A layer implementation for PyTorch Install At your c

Brandon Rohrer 203 Nov 30, 2022
PyTorch implementation of the Deep SLDA method from our CVPRW-2020 paper "Lifelong Machine Learning with Deep Streaming Linear Discriminant Analysis"

Lifelong Machine Learning with Deep Streaming Linear Discriminant Analysis This is a PyTorch implementation of the Deep Streaming Linear Discriminant

Tyler Hayes 41 Dec 25, 2022
Face-Recognition-based-Attendance-System - An implementation of Attendance System in python.

Face-Recognition-based-Attendance-System A real time implementation of Attendance System in python. Pre-requisites To understand the implentation of F

Muhammad Zain Ul Haque 1 Dec 31, 2021
Code for the USENIX 2017 paper: kAFL: Hardware-Assisted Feedback Fuzzing for OS Kernels

kAFL: Hardware-Assisted Feedback Fuzzing for OS Kernels Blazing fast x86-64 VM kernel fuzzing framework with performant VM reloads for Linux, MacOS an

Chair for Sys­tems Se­cu­ri­ty 541 Nov 27, 2022
Complete* list of autonomous driving related datasets

AD Datasets Complete* and curated list of autonomous driving related datasets Contributing Contributions are very welcome! To add or update a dataset:

Daniel Bogdoll 13 Dec 19, 2022
A diff tool for language models

LMdiff Qualitative comparison of large language models. Demo & Paper: http://lmdiff.net LMdiff is a MIT-IBM Watson AI Lab collaboration between: Hendr

Hendrik Strobelt 27 Dec 29, 2022
A Tensorflow based library for Time Series Modelling with Gaussian Processes

Markovflow Documentation | Tutorials | API reference | Slack What does Markovflow do? Markovflow is a Python library for time-series analysis via prob

Secondmind Labs 24 Dec 12, 2022
Pytorch code for semantic segmentation using ERFNet

ERFNet (PyTorch version) This code is a toolbox that uses PyTorch for training and evaluating the ERFNet architecture for semantic segmentation. For t

Edu 394 Jan 01, 2023
NFNets and Adaptive Gradient Clipping for SGD implemented in PyTorch

PyTorch implementation of Normalizer-Free Networks and SGD - Adaptive Gradient Clipping Paper: https://arxiv.org/abs/2102.06171.pdf Original code: htt

Vaibhav Balloli 320 Jan 02, 2023
Implementation of Convolutional LSTM in PyTorch.

ConvLSTM_pytorch This file contains the implementation of Convolutional LSTM in PyTorch made by me and DavideA. We started from this implementation an

Andrea Palazzi 1.3k Dec 29, 2022
PyTorch implementation for "HyperSPNs: Compact and Expressive Probabilistic Circuits", NeurIPS 2021

HyperSPN This repository contains code for the paper: HyperSPNs: Compact and Expressive Probabilistic Circuits "HyperSPNs: Compact and Expressive Prob

8 Nov 08, 2022