An Implementation of Transformer in Transformer in TensorFlow for image classification, attention inside local patches

Overview

Transformer-in-Transformer Twitter

PyPI Open In Colab Upload Python Package Lint Code Base Code style: black

GitHub License GitHub stars GitHub followers Twitter Follow

An Implementation of the Transformer in Transformer paper by Han et al. for image classification, attention inside local patches. Transformer in Transformer uses pixel level attention paired with patch level attention for image classification, in TensorFlow.

PyTorch Implementation

Installation

Run the following to install:

pip install tnt-tensorflow

Developing tnt-tensorflow

To install tnt-tensorflow, along with tools you need to develop and test, run the following in your virtualenv:

git clone https://github.com/Rishit-dagli/Transformer-in-Transformer.git
# or clone your own fork

cd tnt
pip install -e .[dev]

Usage

import tensorflow as tf
from tnt import TNT

tnt = TNT(
    image_size=256,  # size of image
    patch_dim=512,  # dimension of patch token
    pixel_dim=24,  # dimension of pixel token
    patch_size=16,  # patch size
    pixel_size=4,  # pixel size
    depth=5,  # depth
    num_classes=1000,  # output number of classes
    attn_dropout=0.1,  # attention dropout
    ff_dropout=0.1,  # feedforward dropout
)

img = tf.random.uniform(shape=[5, 3, 256, 256])
logits = tnt(img) # (5, 1000)

Want to Contribute 🙋‍♂️ ?

Awesome! If you want to contribute to this project, you're always welcome! See Contributing Guidelines. You can also take a look at open issues for getting more information about current or upcoming tasks.

Want to discuss? 💬

Have any questions, doubts or want to present your opinions, views? You're always welcome. You can start discussions.

Citation

@misc{han2021transformer,
      title={Transformer in Transformer}, 
      author={Kai Han and An Xiao and Enhua Wu and Jianyuan Guo and Chunjing Xu and Yunhe Wang},
      year={2021},
      eprint={2103.00112},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

License

Copyright 2020 Rishit Dagli

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Comments
  • Add Unit Tests

    Add Unit Tests

    The tests should check for the rank and shape of the output tensors, the test should override tf.test.TestCase base class.

    • [x] #15
    • [x] #16
    • [x] #18
    • [x] #17

    Feel free to take inspiration from:

    • https://github.com/Rishit-dagli/Fast-Transformer/blob/main/fast_transformer/test_fast_transformer.py
    • For parametrization feel free to follow https://stackoverflow.com/a/34094/11878567, can be used in the exact same way with subTest in TensorFlow
    enhancement good first issue 
    opened by Rishit-dagli 3
  • Update Workflows to run tests

    Update Workflows to run tests

    This issue follows #11

    Update GitHub Workflows to:

    • [ ] Run Tests before uploading to PyPI
    • [ ] Create a workflow to run tests on commits

    Feel free to take inspiration from https://github.com/Rishit-dagli/Fast-Transformer/tree/main/.github/workflows

    enhancement good first issue 
    opened by Rishit-dagli 0
  • Creates an Attention layer

    Creates an Attention layer

    Verify output shapes just from the attention layer:

    import tensorflow as tf
    Attention(dim=256)(tf.random.normal([3,256,256]))
    
    # <tf.Tensor: shape=(3, 256, 256), dtype=float32,
    

    Closes #3

    opened by Rishit-dagli 0
  • Put together a TNT class

    Put together a TNT class

    Verify shapes:

    tnt = TNT(
        image_size=256,  # size of image
        patch_dim=512,  # dimension of patch token
        pixel_dim=24,  # dimension of pixel token
        patch_size=16,  # patch size
        pixel_size=4,  # pixel size
        depth=5,  # depth
        num_classes=1000,  # output number of classes
        attn_dropout=0.1,  # attention dropout
        ff_dropout=0.1,  # feedforward dropout
    )
    
    img = tf.random.uniform(shape=[1, 3, 256, 256])
    print(tnt(img).shape)
    
    # (1, 1000)
    ```
    opened by Rishit-dagli 0
  • Create an Attention layerr

    Create an Attention layerr

    Verify output shapes just from the attention layer:

    import tensorflow as tf
    Attention(dim=256)(tf.random.normal([3,256,256]))
    
    # <tf.Tensor: shape=(3, 256, 256), dtype=float32,
    
    opened by Rishit-dagli 0
  • Create a PreNorm layer

    Create a PreNorm layer

    Verify output shapes from this layer:

    import tensorflow as tf
    PreNorm(dim=1, fn=tf.keras.layers.Dense(5))(tf.random.normal([10, 1]))
    
    # <tf.Tensor: shape=(10, 1), dtype=float32,
    
    opened by Rishit-dagli 0
Releases(v0.2.0)
  • v0.2.0(Feb 2, 2022)

    This is an interesting release for the project, including a pre-trained model on ImageNet, reproducibility of paper results, tests, and end-to-end training.

    ✅ Bug Fixes / Improvements

    • Create an end-to-end training example demonstrating how to train a TNT model for image classification through a custom training loop on the TF Flowers dataset (#14)
    • Pre-trained model to reproduce the paper results have been made available (in this release as well as on TensorFlow Hub)
    • Create an off-the-shelf inference example, that highlights how you can directly use the pre-trained model made available
    • Unit Tests for the Attention class (#19)
    • Unit Tests for the main TNT class (#20)

    Full Changelog: https://github.com/Rishit-dagli/Transformer-in-Transformer/compare/v0.1.0...v0.2.0

    Source code(tar.gz)
    Source code(zip)
    tnt_s_patch16_224.tar.gz(84.42 MB)
  • v0.1.0(Dec 3, 2021)

    This is the initial release of TNT TensorFlow and implements Transformers in Transformers as a subclassed TensorFlow model.

    Classes

    • Attention: Implements attention as a TensorFlow Keras Layer making some modifications.
    • PreNorm: Normalize the activations of the previous layer for each given example in a batch independently and apply some function to it, implemented as a TensorFlow Keras Layer.
    • FeedForward: Create a FeedForward neural net with two Dense layers and GELU activation, implemented as a TensorFlow Keras Layer.
    • TNT: Implements the Transformers in Transformers model using all the other classes, and converts to logits. Implemented as a TensorFlow Keras Model.
    Source code(tar.gz)
    Source code(zip)
    tnt_s_patch16_224.tar.gz(84.42 MB)
Owner
Rishit Dagli
High School,TEDx,2xTED-Ed speaker | International Speaker | Microsoft Student Ambassador | Mentor, @TFUGMumbai | Organize @KotlinMumbai
Rishit Dagli
Reduce end to end training time from days to hours (or hours to minutes), and energy requirements/costs by an order of magnitude using coresets and data selection.

COResets and Data Subset selection Reduce end to end training time from days to hours (or hours to minutes), and energy requirements/costs by an order

decile-team 244 Jan 09, 2023
Pytorch implementation of our paper accepted by NeurIPS 2021 -- Revisiting Discriminator in GAN Compression: A Generator-discriminator Cooperative Compression Scheme

Revisiting Discriminator in GAN Compression: A Generator-discriminator Cooperative Compression Scheme (NeurIPS2021) (Link) Overview Prerequisites Linu

Shaojie Li 34 Mar 31, 2022
Attentive Implicit Representation Networks (AIR-Nets)

Attentive Implicit Representation Networks (AIR-Nets) Preprint | Supplementary | Accepted at the International Conference on 3D Vision (3DV) teaser.mo

29 Dec 07, 2022
Official codebase for Legged Robots that Keep on Learning: Fine-Tuning Locomotion Policies in the Real World

Legged Robots that Keep on Learning Official codebase for Legged Robots that Keep on Learning: Fine-Tuning Locomotion Policies in the Real World, whic

Laura Smith 70 Dec 07, 2022
Контрольная работа по математическим методам машинного обучения

ML-MathMethods-Test Контрольная работа по математическим методам машинного обучения. Вычисление основных статистик, диаграмм и графиков, проверка разл

Stas Ivanovskii 1 Jan 06, 2022
Code of PVTv2 is released! PVTv2 largely improves PVTv1 and works better than Swin Transformer with ImageNet-1K pre-training.

Updates (2020/06/21) Code of PVTv2 is released! PVTv2 largely improves PVTv1 and works better than Swin Transformer with ImageNet-1K pre-training. Pyr

1.3k Jan 04, 2023
PyTorch implementation of ARM-Net: Adaptive Relation Modeling Network for Structured Data.

A ready-to-use framework of latest models for structured (tabular) data learning with PyTorch. Applications include recommendation, CRT prediction, healthcare analytics, and etc.

48 Nov 30, 2022
The Turing Change Point Detection Benchmark: An Extensive Benchmark Evaluation of Change Point Detection Algorithms on real-world data

Turing Change Point Detection Benchmark Welcome to the repository for the Turing Change Point Detection Benchmark, a benchmark evaluation of change po

The Alan Turing Institute 85 Dec 28, 2022
Image-based Navigation in Real-World Environments via Multiple Mid-level Representations: Fusion Models Benchmark and Efficient Evaluation

Image-based Navigation in Real-World Environments via Multiple Mid-level Representations: Fusion Models Benchmark and Efficient Evaluation This reposi

First Person Vision @ Image Processing Laboratory - University of Catania 1 Aug 21, 2022
Tackling the Class Imbalance Problem of Deep Learning Based Head and Neck Organ Segmentation

Info This is the code repository of the work Tackling the Class Imbalance Problem of Deep Learning Based Head and Neck Organ Segmentation from Elias T

2 Apr 20, 2022
Official repository for the ISBI 2021 paper Transformer Assisted Convolutional Neural Network for Cell Instance Segmentation

SegPC-2021 This is the official repository for the ISBI 2021 paper Transformer Assisted Convolutional Neural Network for Cell Instance Segmentation by

Datascience IIT-ISM 13 Dec 14, 2022
Repository for the AugmentedPCA Python package.

Overview This Python package provides implementations of Augmented Principal Component Analysis (AugmentedPCA) - a family of linear factor models that

Billy Carson 6 Dec 07, 2022
Unified file system operation experience for different backend

megfile - Megvii FILE library Docs: http://megvii-research.github.io/megfile megfile provides a silky operation experience with different backends (cu

MEGVII Research 76 Dec 14, 2022
A pytorch-version implementation codes of paper: "BSN++: Complementary Boundary Regressor with Scale-Balanced Relation Modeling for Temporal Action Proposal Generation"

BSN++: Complementary Boundary Regressor with Scale-Balanced Relation Modeling for Temporal Action Proposal Generation A pytorch-version implementation

11 Oct 08, 2022
My implementation of transformers related papers for computer vision in pytorch

vision_transformers This is my personnal repo to implement new transofrmers based and other computer vision DL models I am currenlty working without a

samsja 1 Nov 10, 2021
A PyTorch implementation: "LASAFT-Net-v2: Listen, Attend and Separate by Attentively aggregating Frequency Transformation"

LASAFT-Net-v2 Listen, Attend and Separate by Attentively aggregating Frequency Transformation Woosung Choi, Yeong-Seok Jeong, Jinsung Kim, Jaehwa Chun

Woosung Choi 29 Jun 04, 2022
Trax — Deep Learning with Clear Code and Speed

Trax — Deep Learning with Clear Code and Speed Trax is an end-to-end library for deep learning that focuses on clear code and speed. It is actively us

Google 7.3k Dec 26, 2022
OCR Streamlit App is used to extract text from images using python's easyocr, pytorch and streamlit packages

OCR-Streamlit-App OCR Streamlit App is used to extract text from images using python's easyocr, pytorch and streamlit packages OCR app gets an image a

Siva Prakash 5 Apr 05, 2022
Official implementation of Neural Bellman-Ford Networks (NeurIPS 2021)

NBFNet: Neural Bellman-Ford Networks This is the official codebase of the paper Neural Bellman-Ford Networks: A General Graph Neural Network Framework

MilaGraph 136 Dec 21, 2022
Kaggle | 9th place (part of) solution for the Bristol-Myers Squibb – Molecular Translation challenge

Part of the 9th place solution for the Bristol-Myers Squibb – Molecular Translation challenge translating images containing chemical structures into I

Erdene-Ochir Tuguldur 22 Nov 30, 2022