Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch

Overview

Vision Transformer - Pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in Yannic Kilcher's video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.

For a Pytorch implementation with pretrained models, please see Ross Wightman's repository here.

The official Jax repository is here.

Install

$ pip install vit-pytorch

Usage

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

Parameters

  • image_size: int.
    Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
  • patch_size: int.
    Size of each patch. image_size must be divisible by patch_size.
    The number of patches is: n = (image_size // patch_size) ** 2 and n must be greater than 16.
  • num_classes: int.
    Number of classes to classify.
  • dim: int.
    Last dimension of output tensor after linear transformation nn.Linear(..., dim).
  • depth: int.
    Number of Transformer blocks.
  • heads: int.
    Number of heads in Multi-head Attention layer.
  • mlp_dim: int.
    Dimension of the MLP (FeedForward) layer.
  • channels: int, default 3.
    Number of image's channels.
  • dropout: float between [0, 1], default 0.
    Dropout rate.
  • emb_dropout: float between [0, 1], default 0.
    Embedding dropout rate.
  • pool: string, either 'cls' (CLS token pooling) or 'mean' (mean pooling).
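
To make the patch arithmetic in the parameter list concrete, here is a minimal sketch in plain Python. The helper name patch_grid is hypothetical and not part of vit_pytorch; it only restates the formula n = (image_size // patch_size) ** 2 and the divisibility requirement, and it assumes the flattened patch length is channels * patch_size ** 2.

```python
def patch_grid(image_size: int, patch_size: int, channels: int = 3):
    """Return (num_patches, flattened_patch_dim) for a square image and patch.

    Hypothetical helper for illustration only; not part of vit_pytorch.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    num_patches = (image_size // patch_size) ** 2
    patch_dim = channels * patch_size * patch_size  # values in one flattened patch
    return num_patches, patch_dim


num_patches, patch_dim = patch_grid(256, 32)
print(num_patches)      # 64 patches for the Usage example
print(num_patches + 1)  # 65 tokens once a CLS token is prepended
print(patch_dim)        # 3072 values per flattened patch
```

The 65 here matches the attention map size reported in the Accessing Attention section.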

Distillation

A recent paper has shown that use of a distillation token for distilling knowledge from convolutional nets to vision transformer can yield small and efficient vision transformers. This repository offers the means to do distillation easily.

ex. distilling from Resnet50 (or any teacher) to a vision transformer

import torch
from torchvision.models import resnet50

from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)

v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,           # temperature of distillation
    alpha = 0.5,               # trade between main loss and distillation loss
    hard = False               # whether to use soft or hard distillation
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()

# after lots of training above ...

pred = v(img) # (2, 1000)

The DistillableViT class is identical to ViT except for how the forward pass is handled, so you should be able to load the parameters back to ViT after you have completed distillation training.

You can also use the handy .to_vit method on the DistillableViT instance to get back a ViT instance.

v = v.to_vit()
type(v) # the ViT class
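
For readers wondering what temperature, alpha and hard control, the following is a rough sketch of the usual distillation objective, written in plain Python for a single example. It assumes the common formulation (cross entropy on the labels, plus either a temperature-scaled KL term or a cross entropy against the teacher's argmax); whether DistillWrapper weights the terms exactly this way is an assumption, and the function names are hypothetical.

```python
import math


def softmax(logits, temperature=1.0):
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target_index):
    return -math.log(softmax(logits)[target_index])


def distillation_loss(student_logits, distill_logits, teacher_logits, label,
                      temperature=3.0, alpha=0.5, hard=False):
    """ASSUMED formulation: (1 - alpha) * main loss + alpha * distillation loss."""
    main = cross_entropy(student_logits, label)
    if hard:
        # hard distillation: the teacher's top class becomes the target
        teacher_label = max(range(len(teacher_logits)), key=teacher_logits.__getitem__)
        distill = cross_entropy(distill_logits, teacher_label)
    else:
        # soft distillation: KL(teacher || student) on temperature-softened outputs
        p_teacher = softmax(teacher_logits, temperature)
        p_student = softmax(distill_logits, temperature)
        kl = sum(t * math.log(t / s) for t, s in zip(p_teacher, p_student) if t > 0)
        distill = kl * temperature ** 2
    return (1 - alpha) * main + alpha * distill


student = [2.0, 0.5, -1.0]   # logits from the classification head
distill = [1.5, 0.2, -0.5]   # logits from the distillation token
teacher = [3.0, 0.0, -2.0]   # logits from the teacher network

soft = distillation_loss(student, distill, teacher, label=0)
hard = distillation_loss(student, distill, teacher, label=0, hard=True)
print(soft, hard)
```

With alpha = 0 the distillation term vanishes and only the label loss remains, which is one way to sanity check a training run.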

Deep ViT

This paper notes that ViT struggles to attend at greater depths (past 12 layers), and suggests mixing the attention of each head post-softmax as a solution, dubbed Re-attention. The results line up with the Talking Heads paper from NLP.

You can use it as follows

import torch
from vit_pytorch.deepvit import DeepViT

v = DeepViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)
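
The head mixing described above can be sketched in plain Python. This is only an illustration of mixing post-softmax attention maps across heads with a heads x heads matrix; the function name is hypothetical, the mixing matrix is fixed here rather than learned, and the normalization that the paper applies after mixing is left out.

```python
def reattention(attn, theta):
    """Mix post-softmax attention maps across heads.

    attn:  [heads][queries][keys] attention weights, already softmaxed
    theta: [heads][heads] mixing matrix (learned in the paper, fixed here)
    """
    heads = len(attn)
    queries = len(attn[0])
    keys = len(attn[0][0])
    return [
        [
            [sum(theta[g][h] * attn[h][i][j] for h in range(heads)) for j in range(keys)]
            for i in range(queries)
        ]
        for g in range(heads)
    ]


attn = [
    [[0.9, 0.1], [0.2, 0.8]],  # head 0
    [[0.5, 0.5], [0.6, 0.4]],  # head 1
]
identity = [[1.0, 0.0], [0.0, 1.0]]
average = [[0.5, 0.5], [0.5, 0.5]]

unchanged = reattention(attn, identity)  # identity mixing leaves the maps as they were
mixed = reattention(attn, average)       # every head becomes the average of both maps
print(mixed[0])
```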

CaiT

This paper also notes difficulty in training vision transformers at greater depths and proposes two solutions. First it proposes to do per-channel multiplication of the output of the residual block. Second, it proposes to have the patches attend to one another, and only allow the CLS token to attend to the patches in the last few layers.

They also add Talking Heads, noting improvements.

You can use this scheme as follows

import torch
from vit_pytorch.cait import CaiT

v = CaiT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 12,             # depth of transformer for patch to patch attention only
    cls_depth = 2,          # depth of cross attention of CLS tokens to patch
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    layer_dropout = 0.05    # randomly dropout 5% of the layers
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)
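
The first proposal, per-channel multiplication of the residual branch output, can be illustrated with a few lines of plain Python. This is a sketch for a single token vector; the function name and the initial scale value are assumptions chosen for illustration, not values taken from this repository.

```python
def layer_scale_residual(x, branch_out, scale):
    """Return x + scale * branch_out, with one learnable scale per channel."""
    return [xi + si * bi for xi, bi, si in zip(x, branch_out, scale)]


dim = 4
init_eps = 1e-5            # ASSUMED small initial value for the per-channel scale
scale = [init_eps] * dim

x = [1.0, 2.0, 3.0, 4.0]                  # token entering the block
branch_out = [10.0, -10.0, 10.0, -10.0]   # output of attention or feedforward

y = layer_scale_residual(x, branch_out, scale)
print(y)  # stays very close to x while the scale is small
```

Starting with a small scale means each block begins close to the identity, which is the intuition for why deeper stacks become easier to train.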

Token-to-Token ViT

This paper proposes that the first couple of layers should downsample the image sequence by unfolding, leading to overlapping image data in each token. You can use this variant of the ViT as follows.

import torch
from vit_pytorch.t2t import T2TViT

v = T2TViT(
    dim = 512,
    image_size = 224,
    depth = 5,
    heads = 8,
    mlp_dim = 512,
    num_classes = 1000,
    t2t_layers = ((7, 4), (3, 2), (3, 2)) # tuples of the kernel size and stride of each consecutive layer of the initial token to token module
)

img = torch.randn(1, 3, 224, 224)

preds = v(img) # (1, 1000)
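
To see how the t2t_layers setting shrinks the token grid, the sketch below applies the standard output-size formula for an unfold with dilation 1. The padding of stride // 2 is an assumption made for this illustration (the README does not state it), and both helper names are hypothetical.

```python
def unfold_output_size(size, kernel, stride, padding):
    """Side length after a sliding-window unfold with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def t2t_token_grid(image_size, t2t_layers):
    """Side length of the token grid after each token-to-token layer."""
    size = image_size
    sizes = []
    for kernel, stride in t2t_layers:
        # ASSUMPTION: padding of stride // 2, not stated in this README
        size = unfold_output_size(size, kernel, stride, stride // 2)
        sizes.append(size)
    return sizes


sizes = t2t_token_grid(224, ((7, 4), (3, 2), (3, 2)))
print(sizes)           # [56, 28, 14]
print(sizes[-1] ** 2)  # 196 tokens reach the transformer
```

Because each kernel is larger than its stride, neighbouring windows overlap, which is where the overlapping image data in each token comes from.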

Cross ViT

This paper proposes to have two vision transformers processing the image at different scales, cross attending to one another every so often. They show improvements on top of the base vision transformer.

import torch
from vit_pytorch.cross_vit import CrossViT

v = CrossViT(
    image_size = 256,
    num_classes = 1000,
    depth = 4,               # number of multi-scale encoding blocks
    sm_dim = 192,            # high res dimension
    sm_patch_size = 16,      # high res patch size (should be smaller than lg_patch_size)
    sm_enc_depth = 2,        # high res depth
    sm_enc_heads = 8,        # high res heads
    sm_enc_mlp_dim = 2048,   # high res feedforward dimension
    lg_dim = 384,            # low res dimension
    lg_patch_size = 64,      # low res patch size
    lg_enc_depth = 3,        # low res depth
    lg_enc_heads = 8,        # low res heads
    lg_enc_mlp_dim = 2048,   # low res feedforward dimensions
    cross_attn_depth = 2,    # cross attention rounds
    cross_attn_heads = 8,    # cross attention heads
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

pred = v(img) # (1, 1000)

PiT

This paper proposes to downsample the tokens through a pooling procedure using depth-wise convolutions.

import torch
from vit_pytorch.pit import PiT

v = PiT(
    image_size = 224,
    patch_size = 14,
    dim = 256,
    num_classes = 1000,
    depth = (3, 3, 3),     # list of depths, indicating the number of rounds of each stage before a downsample
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# forward pass returns the class predictions

img = torch.randn(1, 3, 224, 224)

preds = v(img) # (1, 1000)

LeViT

This paper proposes a number of changes, including (1) convolutional embedding instead of patch-wise projection (2) downsampling in stages (3) extra non-linearity in attention (4) 2d relative positional biases instead of initial absolute positional bias (5) batchnorm in place of layernorm.

Official repository

import torch
from vit_pytorch.levit import LeViT

levit = LeViT(
    image_size = 224,
    num_classes = 1000,
    stages = 3,             # number of stages
    dim = (256, 384, 512),  # dimensions at each stage
    depth = 4,              # transformer of depth 4 at each stage
    heads = (4, 6, 8),      # heads at each stage
    mlp_mult = 2,
    dropout = 0.1
)

img = torch.randn(1, 3, 224, 224)

levit(img) # (1, 1000)

CvT

This paper proposes mixing convolutions and attention. Specifically, convolutions are used to embed and downsample the image / feature map in three stages. Depthwise convolution is also used to project the queries, keys, and values for attention.

import torch
from vit_pytorch.cvt import CvT

v = CvT(
    num_classes = 1000,
    s1_emb_dim = 64,        # stage 1 - dimension
    s1_emb_kernel = 7,      # stage 1 - conv kernel
    s1_emb_stride = 4,      # stage 1 - conv stride
    s1_proj_kernel = 3,     # stage 1 - attention ds-conv kernel size
    s1_kv_proj_stride = 2,  # stage 1 - attention key / value projection stride
    s1_heads = 1,           # stage 1 - heads
    s1_depth = 1,           # stage 1 - depth
    s1_mlp_mult = 4,        # stage 1 - feedforward expansion factor
    s2_emb_dim = 192,       # stage 2 - (same as above)
    s2_emb_kernel = 3,
    s2_emb_stride = 2,
    s2_proj_kernel = 3,
    s2_kv_proj_stride = 2,
    s2_heads = 3,
    s2_depth = 2,
    s2_mlp_mult = 4,
    s3_emb_dim = 384,       # stage 3 - (same as above)
    s3_emb_kernel = 3,
    s3_emb_stride = 2,
    s3_proj_kernel = 3,
    s3_kv_proj_stride = 2,
    s3_heads = 4,
    s3_depth = 10,
    s3_mlp_mult = 4,
    dropout = 0.
)

img = torch.randn(1, 3, 224, 224)

pred = v(img) # (1, 1000)

Twins SVT

This paper proposes mixing local and global attention, along with a position encoding generator (proposed in CPVT) and global average pooling, to achieve the same results as Swin, without the extra complexity of shifted windows, CLS tokens, or positional embeddings.

import torch
from vit_pytorch.twins_svt import TwinsSVT

model = TwinsSVT(
    num_classes = 1000,       # number of output classes
    s1_emb_dim = 64,          # stage 1 - patch embedding projected dimension
    s1_patch_size = 4,        # stage 1 - patch size for patch embedding
    s1_local_patch_size = 7,  # stage 1 - patch size for local attention
    s1_global_k = 7,          # stage 1 - global attention key / value reduction factor, defaults to 7 as specified in paper
    s1_depth = 1,             # stage 1 - number of transformer blocks (local attn -> ff -> global attn -> ff)
    s2_emb_dim = 128,         # stage 2 (same as above)
    s2_patch_size = 2,
    s2_local_patch_size = 7,
    s2_global_k = 7,
    s2_depth = 1,
    s3_emb_dim = 256,         # stage 3 (same as above)
    s3_patch_size = 2,
    s3_local_patch_size = 7,
    s3_global_k = 7,
    s3_depth = 5,
    s4_emb_dim = 512,         # stage 4 (same as above)
    s4_patch_size = 2,
    s4_local_patch_size = 7,
    s4_global_k = 7,
    s4_depth = 4,
    peg_kernel_size = 3,      # positional encoding generator kernel size
    dropout = 0.              # dropout
)

img = torch.randn(1, 3, 224, 224)

pred = model(img) # (1, 1000)
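
A quick way to check a TwinsSVT configuration is to trace the feature map size through the stages. The sketch below assumes each stage's patch embedding divides the side length by its patch size without overlap, and that the local attention windows must tile the resulting map exactly; both points are assumptions for illustration, and the helper name is hypothetical.

```python
def twins_stage_grids(image_size, patch_sizes, local_patch_size=7):
    """Return (feature_map_side, local_windows_per_side) for each stage."""
    grids = []
    size = image_size
    for stage, patch in enumerate(patch_sizes, start=1):
        if size % patch != 0:
            raise ValueError(f"stage {stage}: side {size} not divisible by patch size {patch}")
        size //= patch
        if size % local_patch_size != 0:
            raise ValueError(f"stage {stage}: side {size} not divisible by local patch size")
        grids.append((size, size // local_patch_size))
    return grids


# patch sizes of stages 1 to 4 from the example above
print(twins_stage_grids(224, (4, 2, 2, 2)))
# [(56, 8), (28, 4), (14, 2), (7, 1)]
```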

NesT

This paper proposes processing the image in hierarchical stages, with attention only within tokens of local blocks, which aggregate as they move up the hierarchy. The aggregation is done in the image plane, and contains a convolution to allow it to pass information across the boundary.

You can use it with the following code (ex. NesT-T)

import torch
from vit_pytorch.nest import NesT

nest = NesT(
    image_size = 224,
    patch_size = 4,
    dim = 96,
    heads = 3,
    num_heirarchies = 3,        # number of hierarchies
    block_repeats = (8, 4, 1),  # the number of transformer blocks at each hierarchy, starting from the bottom
    num_classes = 1000
)

img = torch.randn(1, 3, 224, 224)
pred = nest(img) # (1, 1000)
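
As a rough picture of the nesting, the sketch below lays out how many local blocks exist at each level for the settings above. It assumes the bottom level is split into 4 ** (levels - 1) blocks and that every aggregation step halves the token grid per side, which follows the paper's description rather than anything stated in this README; the helper name is hypothetical.

```python
def nest_layout(image_size, patch_size, num_levels):
    """Return (tokens_per_side, num_blocks, block_side) per level, bottom first."""
    tokens_per_side = image_size // patch_size
    blocks_per_side = 2 ** (num_levels - 1)
    layout = []
    for _ in range(num_levels):
        block_side = tokens_per_side // blocks_per_side
        layout.append((tokens_per_side, blocks_per_side ** 2, block_side))
        # ASSUMPTION: aggregation halves the resolution and merges 2 x 2 blocks
        tokens_per_side //= 2
        blocks_per_side //= 2
    return layout


print(nest_layout(224, 4, 3))
# [(56, 16, 14), (28, 4, 14), (14, 1, 14)]
```

Under these assumptions attention always runs inside 14 x 14 blocks, while the number of blocks shrinks from 16 to 1.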

Masked Patch Prediction

Thanks to Zach, you can train using the original masked patch prediction task presented in the paper, with the following code.

import torch
from vit_pytorch import ViT
from vit_pytorch.mpp import MPP

model = ViT(
    image_size=256,
    patch_size=32,
    num_classes=1000,
    dim=1024,
    depth=6,
    heads=8,
    mlp_dim=2048,
    dropout=0.1,
    emb_dropout=0.1
)

mpp_trainer = MPP(
    transformer=model,
    patch_size=32,
    dim=1024,
    mask_prob=0.15,          # probability of using token in masked prediction task
    random_patch_prob=0.30,  # probability of randomly replacing a token being used for mpp
    replace_prob=0.50,       # probability of replacing a token being used for mpp with the mask token
)

opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = mpp_trainer(images)
    opt.zero_grad()
    loss.backward()
    opt.step()

# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')

Dino

You can train ViT with the recent SOTA self-supervised learning technique, Dino, with the following code.

Yannic Kilcher video

import torch
from vit_pytorch import ViT, Dino

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

learner = Dino(
    model,
    image_size = 256,
    hidden_layer = 'to_latent',        # hidden layer name or index, from which to extract the embedding
    projection_hidden_size = 256,      # projector network hidden dimension
    projection_layers = 4,             # number of layers in projection network
    num_classes_K = 65336,             # output logits dimensions (referenced as K in paper)
    student_temp = 0.9,                # student temperature
    teacher_temp = 0.04,               # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
    local_upper_crop_scale = 0.4,      # upper bound for local crop - 0.4 was recommended in the paper 
    global_lower_crop_scale = 0.5,     # lower bound for global crop - 0.5 was recommended in the paper
    moving_average_decay = 0.9,        # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
    center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
)

opt = torch.optim.Adam(learner.parameters(), lr = 3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of teacher encoder and teacher centers

# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')
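
The comments above mention annealing the teacher temperature from 0.04 to 0.07 over 30 epochs and keeping moving averages with a decay between 0.9 and 0.999. A minimal sketch of both pieces in plain Python follows. The linear shape of the schedule is an assumption, the function names are hypothetical, and how the scheduled value is passed to the learner is not shown because this README does not document it.

```python
def teacher_temp_at(epoch, start=0.04, end=0.07, warmup_epochs=30):
    """Teacher temperature for a given epoch (ASSUMED linear warmup)."""
    if epoch >= warmup_epochs:
        return end
    return start + (end - start) * epoch / warmup_epochs


def ema_update(teacher_value, student_value, decay=0.9):
    """Exponential moving average step for a single scalar parameter."""
    return decay * teacher_value + (1 - decay) * student_value


for epoch in (0, 15, 30, 45):
    print(epoch, teacher_temp_at(epoch))

# a teacher weight drifting toward a student weight of 1.0
weight = 0.0
for _ in range(3):
    weight = ema_update(weight, 1.0, decay=0.9)
print(weight)
```

A higher decay makes the teacher change more slowly, which is the trade-off behind the 0.9 to 0.999 range quoted in the comments.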

Accessing Attention

If you would like to visualize the attention weights (post-softmax) for your research, just follow the procedure below

import torch
from vit_pytorch.vit import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# import Recorder and wrap the ViT

from vit_pytorch.recorder import Recorder
v = Recorder(v)

# forward pass now returns predictions and the attention maps

img = torch.randn(1, 3, 256, 256)
preds, attns = v(img)

# there is one extra patch due to the CLS token

attns # (1, 6, 16, 65, 65) - (batch x layers x heads x patch x patch)

To clean up the class and the hooks once you have collected enough data:

v = v.eject()  # wrapper is discarded and original ViT instance is returned

Research Ideas

Efficient Attention

There may be some coming from computer vision who think attention still suffers from quadratic costs. Fortunately, we have a lot of new techniques that may help. This repository offers a way for you to plug in your own sparse attention transformer.

An example with Nystromformer

$ pip install nystrom-attention

import torch
from vit_pytorch.efficient import ViT
from nystrom_attention import Nystromformer

efficient_transformer = Nystromformer(
    dim = 512,
    depth = 12,
    heads = 8,
    num_landmarks = 256
)

v = ViT(
    dim = 512,
    image_size = 2048,
    patch_size = 32,
    num_classes = 1000,
    transformer = efficient_transformer
)

img = torch.randn(1, 3, 2048, 2048) # your high resolution picture
v(img) # (1, 1000)

Other sparse attention frameworks I would highly recommend are Routing Transformer and Sinkhorn Transformer.
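
To get a feel for why an efficient attention variant matters at this resolution, here is a back-of-the-envelope count in plain Python. It compares the n x n attention matrix of full attention with a single n x m matrix as a rough proxy for a landmark-based method; a real Nystrom-style layer builds several such matrices, so treat the ratio as indicative only. The CLS token is ignored and the function name is hypothetical.

```python
def attention_entries(image_size, patch_size, num_landmarks=None):
    """Entries in one attention matrix per head, ignoring the CLS token."""
    n = (image_size // patch_size) ** 2  # number of patch tokens
    if num_landmarks is None:
        return n * n                     # full attention: every token pair
    return n * num_landmarks             # rough proxy: tokens x landmarks


full = attention_entries(2048, 32)
approx = attention_entries(2048, 32, num_landmarks=256)
print(full)            # 16777216
print(approx)          # 1048576
print(full // approx)  # 16
```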

Combining with other Transformer improvements

This paper purposely used the most vanilla of attention networks to make a statement. If you would like to use some of the latest improvements for attention nets, please use the Encoder from this repository.

ex.

$ pip install x-transformers

import torch
from vit_pytorch.efficient import ViT
from x_transformers import Encoder

v = ViT(
    dim = 512,
    image_size = 224,
    patch_size = 16,
    num_classes = 1000,
    transformer = Encoder(
        dim = 512,                  # set to be the same as the wrapper
        depth = 12,
        heads = 8,
        ff_glu = True,              # ex. feed forward GLU variant https://arxiv.org/abs/2002.05202
        residual_attn = True        # ex. residual attention https://arxiv.org/abs/2012.11747
    )
)

img = torch.randn(1, 3, 224, 224)
v(img) # (1, 1000)

FAQ

  • How do I pass in non-square images?

You can already pass in non-square images - you just have to make sure your height and width are less than or equal to the image_size, and both divisible by the patch_size.

ex.

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 128) # <-- not a square

preds = v(img) # (1, 1000)

  • How do I pass in non-square patches?

import torch
from vit_pytorch import ViT

v = ViT(
    num_classes = 1000,
    image_size = (256, 128),  # image size is a tuple of (height, width)
    patch_size = (32, 16),    # patch size is a tuple of (height, width)
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 128)

preds = v(img)
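
The constraints from both FAQ answers can be checked before building a model. The sketch below is a plain Python helper (hypothetical, not part of vit_pytorch) that accepts either an int or a (height, width) tuple for image_size and patch_size and returns the number of patches an input of a given shape would produce.

```python
def pair(value):
    """Turn an int into (value, value); leave tuples untouched."""
    return value if isinstance(value, tuple) else (value, value)


def count_patches(image_hw, image_size, patch_size):
    """Validate an input shape against the FAQ rules and count its patches."""
    height, width = image_hw
    max_h, max_w = pair(image_size)
    patch_h, patch_w = pair(patch_size)
    if height > max_h or width > max_w:
        raise ValueError("image is larger than image_size")
    if height % patch_h or width % patch_w:
        raise ValueError("height and width must be divisible by the patch size")
    return (height // patch_h) * (width // patch_w)


print(count_patches((256, 128), 256, 32))               # 32 square patches
print(count_patches((256, 128), (256, 128), (32, 16)))  # 64 non-square patches
```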

Resources

Coming from computer vision and new to transformers? Here are some resources that greatly accelerated my learning.

  1. Illustrated Transformer - Jay Alammar

  2. Transformers from Scratch - Peter Bloem

  3. The Annotated Transformer - Harvard NLP

Citations

@misc{dosovitskiy2020image,
    title   = {An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
    author  = {Alexey Dosovitskiy and Lucas Beyer and Alexander Kolesnikov and Dirk Weissenborn and Xiaohua Zhai and Thomas Unterthiner and Mostafa Dehghani and Matthias Minderer and Georg Heigold and Sylvain Gelly and Jakob Uszkoreit and Neil Houlsby},
    year    = {2020},
    eprint  = {2010.11929},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{touvron2020training,
    title   = {Training data-efficient image transformers & distillation through attention}, 
    author  = {Hugo Touvron and Matthieu Cord and Matthijs Douze and Francisco Massa and Alexandre Sablayrolles and Hervé Jégou},
    year    = {2020},
    eprint  = {2012.12877},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{yuan2021tokenstotoken,
    title     = {Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet},
    author    = {Li Yuan and Yunpeng Chen and Tao Wang and Weihao Yu and Yujun Shi and Francis EH Tay and Jiashi Feng and Shuicheng Yan},
    year      = {2021},
    eprint    = {2101.11986},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{zhou2021deepvit,
    title   = {DeepViT: Towards Deeper Vision Transformer},
    author  = {Daquan Zhou and Bingyi Kang and Xiaojie Jin and Linjie Yang and Xiaochen Lian and Qibin Hou and Jiashi Feng},
    year    = {2021},
    eprint  = {2103.11886},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{touvron2021going,
    title   = {Going deeper with Image Transformers}, 
    author  = {Hugo Touvron and Matthieu Cord and Alexandre Sablayrolles and Gabriel Synnaeve and Hervé Jégou},
    year    = {2021},
    eprint  = {2103.17239},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{chen2021crossvit,
    title   = {CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification},
    author  = {Chun-Fu Chen and Quanfu Fan and Rameswar Panda},
    year    = {2021},
    eprint  = {2103.14899},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{wu2021cvt,
    title   = {CvT: Introducing Convolutions to Vision Transformers},
    author  = {Haiping Wu and Bin Xiao and Noel Codella and Mengchen Liu and Xiyang Dai and Lu Yuan and Lei Zhang},
    year    = {2021},
    eprint  = {2103.15808},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{heo2021rethinking,
    title   = {Rethinking Spatial Dimensions of Vision Transformers}, 
    author  = {Byeongho Heo and Sangdoo Yun and Dongyoon Han and Sanghyuk Chun and Junsuk Choe and Seong Joon Oh},
    year    = {2021},
    eprint  = {2103.16302},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{graham2021levit,
    title   = {LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference},
    author  = {Ben Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Hervé Jégou and Matthijs Douze},
    year    = {2021},
    eprint  = {2104.01136},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{li2021localvit,
    title   = {LocalViT: Bringing Locality to Vision Transformers},
    author  = {Yawei Li and Kai Zhang and Jiezhang Cao and Radu Timofte and Luc Van Gool},
    year    = {2021},
    eprint  = {2104.05707},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{chu2021twins,
    title   = {Twins: Revisiting Spatial Attention Design in Vision Transformers},
    author  = {Xiangxiang Chu and Zhi Tian and Yuqing Wang and Bo Zhang and Haibing Ren and Xiaolin Wei and Huaxia Xia and Chunhua Shen},
    year    = {2021},
    eprint  = {2104.13840},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding}, 
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{zhang2021aggregating,
    title   = {Aggregating Nested Transformers},
    author  = {Zizhao Zhang and Han Zhang and Long Zhao and Ting Chen and Tomas Pfister},
    year    = {2021},
    eprint  = {2105.12723},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{caron2021emerging,
    title   = {Emerging Properties in Self-Supervised Vision Transformers},
    author  = {Mathilde Caron and Hugo Touvron and Ishan Misra and Hervé Jégou and Julien Mairal and Piotr Bojanowski and Armand Joulin},
    year    = {2021},
    eprint  = {2104.14294},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{vaswani2017attention,
    title   = {Attention Is All You Need},
    author  = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
    year    = {2017},
    eprint  = {1706.03762},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines. — Claude Shannon

Comments
  • Masked tokens

    Just a quick question. I was wondering if the masks which get passed to the model at inference are for the purposes of masked tokens for self-supervision or if they are a different mask? Thanks

    opened by zankner 15
  • `AttributeError: 'JpegImageFile' object has no attribute 'shape'`

    I was trying the code with the random image but I need to print the image I am passing, not torch.rand

    my_image = load_img("127074.jpg")
    v = Vit()
    img = torch.randn(1, 3, 256, 256)
    preds = v(img)
    print(preds)
    p=v(my_image)
    print(p)
    

    the code worked with using random but when i pass my_image i got AttributeError: 'JpegImageFile' object has no attribute 'shape'

    opened by ersamo 14
  • The model doesn't converge

    Hi,

    Thank you for your work.

    I have tested to train your implementation for action classification on Kinetics400, but find the training not convergent. Note that the learning rate is calculated based on the paper: 4096 ~ 6e-4 by Linear Scaling Rule. I also applied warmup but the loss plateaus while warming up.

    I have also tested the pretrained model in timm. Also not convergent.

    Do you have any suggestion for better training such stand-alone transformers? Thanks.

    opened by SuX97 14
  • handle non square patch sizes

    Hello, would it be difficult to allow for nonsquare patches to be used, I'm examining the use of Vit on some non-image data where it cannot be represented in a square format. Preferably by having a tuple as an input to select width and height. Is there anything currently that would block this from being done?

    opened by FilipAndersson245 11
  • Tensors must have same number of dimensions : got 5 and 3

    Hello @lucidrains @stevenwalton, I have been trying to implement the standard ViT in 3D space and have worked on part of the ViT code. I changed the Rearrange in the patch embedding as follows: Rearrange('b e (h p1) (w p2) (d p3) -> b (e p1 p2 p3) h w d',p1=patch_size,p2=patch_size,p3=patch_size). These patch embeddings are then combined with cls_tokens, cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b), which throws an error due to a dimensionality mismatch. How can I change the shape of cls_tokens to match the dimensionality of the patch embeddings?

    Can you help me find a solution to this problem? Thanks & Regards, Satwik Sunnam

    opened by satwiksunnam19 8
  • Accessing last layer hidden states or embeddings for models like CrossViT, RegionViT (Extractor doesn't seem to work)

    How can I access the last layer hidden states aka embeddings of an image from models like CrossViT and RegionViT? The extractor option works only on vanilla ViT.

    Please advise

    opened by PrithivirajDamodaran 8
  • Hilach/dropout

    Thanks for the useful resource @lucidrains! I saw this part from the paper: Dropout, when used, is applied after every dense layer except for the qkv-projections and directly after adding positional- to patch embeddings. So 3 things regarding the dropouts you added:

    1. I think there's a dropout missing right after adding the positional to the embedding.
    2. I don't see why there should be a dropout after the dot product in the attention- line 57 (it's not a dense layer and it's not mentioned in the above description).
    3. considering they refer to the qkv-projections as dense layers, I think they mean that all linear layers are dense layers, so I added dropouts after each linear layer.

    Kindly let me know if you have any comments/ disagreements.

    opened by hila-chefer 8
  • Using masks as preprocessing for classification [FR]

    Maybe it's a little bit too early to ask for this, but would it be possible to specify regions within an image for ViT to perform the prediction? I was thinking of a binary mask, for example, which could be used for the tiling step in order to obtain different image sequences.

    I am thinking of a pipeline where, in order to increase resolution, you could specify the regions to perform the training on, based on whatever reason you find suitable (previous attention maps for example :smile:).

    opened by Tato14 8
  • Loss cannot drop

    Thank you so much for sharing your code. I tried to use ViT as the encoder, followed by a common decoder, to build a segmentation network. I trained it from scratch but found the loss doesn't drop from the beginning of training, and the results stay near 0. Is there any trick for training ViT correctly? Is it very important to load a pretrained model to fine-tune? Here is my configuration:

    patch_size=16
    hidden_size=16*16*3
    mlp_dim = 3072
    dropout_rate = 0.1
    num_heads = 12
    num_layers = 12
    lr=3e-4
    opt=Adam
    weight_decay=0.0

    opened by QiushiYang 7
  • Regression

    Hi, I am doing a regression task on images, predicting 6 numbers from images. Does it make sense to use the CLS token or can I just pool the last layer of the transformer and connect it to the mlp head?

    opened by aslarsen 7
  • How to set appropriate learning rate ?

    vit = ViT(
        image_size=448,
        patch_size=32,
        num_classes=180,
        dim=1024,
        depth=8,
        heads=8,
        mlp_dim=2048,
        dropout=0.5,
        emb_dropout=0.5
    ).cuda()
    optimizer = torch.optim.Adam(vit.parameters(), lr=5e-3, weight_decay=0.1)

    I tried to train ViT on a 180-class dataset with the config shown above, but the loss doesn't decrease during training. Any suggestions?

    opened by sleeplessai 7
  • add an interpolate_embeddings helper function

    This library is awesome. However I want to pretrain on small patches and then transfer learn on larger images. torchvision has a really great helper function for their vision transformer class : interpolate_embeddings.

    """This function helps interpolating positional embeddings during checkpoint loading,
        especially when you want to apply a pre-trained model on images with different resolution.
    

    it would be awesome if your library also had something like that :)

    opened by DanTaranis 0
  • Question about attention's qkv matrix

    Hello thanks for this great repo!

    I am confused about the details in vit.py. In the attention section, when computing the q, k, v matrices, you project x from ( b, n, d ) to ( b, n, 3d) and then split it into 3 parts using chunk(), like:

    qkv = self.to_qkv(x).chunk(3, dim = -1)

    However, I think that this way the q, k and v matrices each contain only part of the information of the original matrix x, which is not what the transformer paper describes. In the original paper, q, k, v each contain all the information of the input matrix, and the dot product is then performed to compute attention. Please check xD.

    PS: I am a beginner in this topic; if I have misunderstood anything, please point it out, and sorry for any possible inconvenience.

    opened by JearSYY 0
  • Issues loading RegionVIT pre-trained checkpoints

    Hello there - I am trying to load pre-trained checkpoints from original author's page. Are there changes in the implementation of the model here? None of the checkpoints are loading citing differences in layers.

    opened by PrithivirajDamodaran 0
  • Visualize the attention weights

    Hello,

    I am wondering how can we visualize the attention maps as in the original ViT paper. I tried several strategies but I failed unfortunately.

    Thanks in advance.

    opened by code-pan94 1
  • Distillation RuntimeError

    Hi,

    I am using distiller where teacher network is resnet34.

    I am getting this error while training the model:

    distiller(data, target) Traceback (most recent call last): File "/home/neel/miniconda3/envs/process/lib/python3.7/code.py", line 90, in runcode exec(code, self.locals) File "", line 1, in File "/home/neel/miniconda3/envs/process/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl return forward_call(*input, **kwargs) File "/home/neel/miniconda3/envs/process/lib/python3.7/site-packages/vit_pytorch/distill.py", line 146, in forward reduction = 'batchmean') File "/home/neel/miniconda3/envs/process/lib/python3.7/site-packages/torch/nn/functional.py", line 2916, in kl_div reduced = torch.kl_div(input, target, reduction_enum, log_target=log_target) RuntimeError: The size of tensor a (1000) must match the size of tensor b (2) at non-singleton dimension 1

    The shape of my data is torch.Size([24, 3, 224, 224]) and the shape of the target is torch.Size([24]).

    Is there something wrong? I am using the same instructions from the GitHub homepage. Has anyone else experienced this?

    Neel

    opened by NeelKanwal 1
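The error message in the issue above says that the two tensors passed to `kl_div` have 1000 and 2 entries along the class dimension. A plausible reading, which is an assumption and not confirmed by the issue, is that the student was built with 1000 output classes while the teacher's head has 2. The sketch below reproduces that kind of mismatch with plain `torch`; the helper `distill_kl` and the temperature value are hypothetical and are not the repository's code.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch = 24
student_logits = torch.randn(batch, 1000)  # assumed: student with 1000 classes
teacher_logits = torch.randn(batch, 2)     # assumed: teacher head with 2 outputs

def distill_kl(student, teacher, temperature=3.0):
    # Hypothetical helper: KL divergence between the softened teacher
    # distribution and the softened student distribution.
    return F.kl_div(
        F.log_softmax(student / temperature, dim=-1),
        F.softmax(teacher / temperature, dim=-1),
        reduction="batchmean",
    )

# Mismatched class counts cannot be broadcast, so kl_div raises.
try:
    distill_kl(student_logits, teacher_logits)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# With matching class counts the loss is a scalar tensor.
loss = distill_kl(torch.randn(batch, 2), teacher_logits)
print(mismatch_raised, loss.dim())
```

If this reading is right, making the number of classes of the student match the number of outputs of the teacher would remove the shape mismatch.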
Releases(0.40.2)
Owner
Phil Wang
Working with Attention
pix2pix in tensorflow.js

pix2pix in tensorflow.js This repo is moved to https://github.com/yining1023/pix2pix_tensorflowjs_lite See a live demo here: https://yining1023.github

Yining Shi 47 Oct 04, 2022
Official PyTorch implementation of the NeurIPS 2021 paper StyleGAN3

Alias-Free Generative Adversarial Networks (StyleGAN3) Official PyTorch implementation of the NeurIPS 2021 paper Alias-Free Generative Adversarial Net

Eugenio Herrera 92 Nov 18, 2022
How to Become More Salient? Surfacing Representation Biases of the Saliency Prediction Model

How to Become More Salient? Surfacing Representation Biases of the Saliency Prediction Model

Bogdan Kulynych 49 Nov 05, 2022
On Size-Oriented Long-Tailed Graph Classification of Graph Neural Networks

On Size-Oriented Long-Tailed Graph Classification of Graph Neural Networks We provide the code (in PyTorch) and datasets for our paper "On Size-Orient

Zemin Liu 4 Jun 18, 2022
fcn by tensorflow

Update An example on how to integrate this code into your own semantic segmentation pipeline can be found in my KittiSeg project repository. tensorflo

9 May 22, 2022
Official implementation for TTT++: When Does Self-supervised Test-time Training Fail or Thrive

TTT++ This is an official implementation for TTT++: When Does Self-supervised Test-time Training Fail or Thrive? TL;DR: Online Feature Alignment + Str

VITA lab at EPFL 39 Dec 25, 2022
Qimera: Data-free Quantization with Synthetic Boundary Supporting Samples

Qimera: Data-free Quantization with Synthetic Boundary Supporting Samples This repository is the official implementation of paper [Qimera: Data-free Q

Kanghyun Choi 21 Nov 03, 2022
PolyGlot, a fuzzing framework for language processors

PolyGlot, a fuzzing framework for language processors Build We tested PolyGlot on Ubuntu 18.04. Get the source code: git clone https://github.com/s3te

Software Systems Security Team at Penn State University 79 Dec 27, 2022
Extension to fastai for volumetric medical data

FAIMED 3D use fastai to quickly train fully three-dimensional models on radiological data Classification from faimed3d.all import * Load data in vari

Keno 26 Aug 22, 2022
A toolkit for document-level event extraction, containing some SOTA model implementations

❤️ A Toolkit for Document-level Event Extraction with & without Triggers Hi, there 👋 . Thanks for your stay in this repo. This project aims at buildi

Tong Zhu(朱桐) 159 Dec 22, 2022
The dynamics of representation learning in shallow, non-linear autoencoders

The dynamics of representation learning in shallow, non-linear autoencoders The package is written in python and uses the pytorch implementation to ML

Maria Refinetti 4 Jun 08, 2022
Graph Neural Networks with Keras and Tensorflow 2.

Welcome to Spektral Spektral is a Python library for graph deep learning, based on the Keras API and TensorFlow 2. The main goal of this project is to

Daniele Grattarola 2.2k Jan 08, 2023
TorchGRL is the source code for our paper Graph Convolution-Based Deep Reinforcement Learning for Multi-Agent Decision-Making in Mixed Traffic Environments for IV 2022.

TorchGRL TorchGRL is the source code for our paper Graph Convolution-Based Deep Reinforcement Learning for Multi-Agent Decision-Making in Mixed Traffi

XXQQ 42 Dec 09, 2022
Learning from History: Modeling Temporal Knowledge Graphs with Sequential Copy-Generation Networks

CyGNet This repository reproduces the AAAI'21 paper “Learning from History: Modeling Temporal Knowledge Graphs with Sequential Copy-Generation Network

CunchaoZ 89 Jan 03, 2023
BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond

BasicVSR BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond Ported from https://github.com/xinntao/BasicSR Dependencie

Holy Wu 8 Jun 07, 2022
QQ Browser 2021 AI Algorithm Competition Track 1 1st Place Program

QQ Browser 2021 AI Algorithm Competition Track 1 1st Place Program

249 Jan 03, 2023
Motion Reconstruction Code and Data for Skills from Videos (SFV)

Motion Reconstruction Code and Data for Skills from Videos (SFV) This repo contains the data and the code for motion reconstruction component of the S

268 Dec 01, 2022
Implementation of MeMOT - Multi-Object Tracking with Memory - in Pytorch

MeMOT - Pytorch (wip) Implementation of MeMOT - Multi-Object Tracking with Memory - in Pytorch. This paper is just one in a line of work, but importan

Phil Wang 15 May 09, 2022
This repository contains demos I made with the Transformers library by HuggingFace.

Transformers-Tutorials Hi there! This repository contains demos I made with the Transformers library by 🤗 HuggingFace. Currently, all of them are imp

3.5k Jan 01, 2023
Deepfake Scanner by Deepware.

Deepware Scanner (CLI) This repository contains the command-line deepfake scanner tool with the pre-trained models that are currently used at deepware

deepware 110 Jan 02, 2023