An implementation of Performer, a linear attention-based transformer, in PyTorch

Overview

Performer - PyTorch


An implementation of Performer, a linear attention-based transformer variant with a Fast Attention Via positive Orthogonal Random features approach (FAVOR+).

Install

$ pip install performer-pytorch

Usage

Performer Language Model

import torch
from torch import nn
from performer_pytorch import PerformerLM

model = PerformerLM(
    num_tokens = 20000,
    max_seq_len = 2048,             # max sequence length
    dim = 512,                      # dimension
    depth = 12,                     # layers
    heads = 8,                      # heads
    causal = False,                 # auto-regressive or not
    nb_features = 256,              # number of random features, if not set, will default to (d * log(d)), where d is the dimension of each head
    feature_redraw_interval = 1000, # how frequently to redraw the projection matrix, the more frequent, the slower the training
    generalized_attention = False,  # defaults to softmax approximation, but can be set to True for generalized attention
    kernel_fn = nn.ReLU(),          # the kernel function to be used, if generalized attention is turned on, defaults to ReLU
    reversible = True,              # reversible layers, from Reformer paper
    ff_chunks = 10,                 # chunk feedforward layer, from Reformer paper
    use_scalenorm = False,          # use scale norm, from 'Transformers without Tears' paper
    use_rezero = False,             # use rezero, from 'Rezero is all you need' paper
    tie_embedding = False,          # multiply final embeddings with token weights for logits, like gpt decoder
    ff_glu = True,                  # use GLU variant for feedforward
    emb_dropout = 0.1,              # embedding dropout
    ff_dropout = 0.1,               # feedforward dropout
    attn_dropout = 0.1,             # post-attn dropout
    local_attn_heads = 4,           # 4 heads are local attention, 4 others are global performers
    local_window_size = 256,        # window size of local attention
    rotary_position_emb = True      # use rotary positional embedding, which endows linear attention with relative positional encoding with no learned parameters. should always be turned on unless you want to go back to old absolute positional encoding
)

x = torch.randint(0, 20000, (1, 2048))
mask = torch.ones_like(x).bool()

model(x, mask = mask) # (1, 2048, 20000)
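
The language model returns per-position logits over the vocabulary, so a plain cross-entropy objective is enough to train it. Below is a minimal, hypothetical next-token training step; it assumes a model built with causal = True, and the optimizer and learning rate are arbitrary choices rather than anything prescribed by the library.

import torch
import torch.nn.functional as F

opt = torch.optim.Adam(model.parameters(), lr = 1e-4)

seq = torch.randint(0, 20000, (1, 2049))
inp, labels = seq[:, :-1], seq[:, 1:]                   # shift by one token to build (input, target) pairs

logits = model(inp)                                     # (1, 2048, 20000)
loss = F.cross_entropy(logits.transpose(1, 2), labels)  # class dimension second, as F.cross_entropy expects
loss.backward()
opt.step()
opt.zero_grad()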

Plain Performer, if you are working with, say, images or other modalities

import torch
from performer_pytorch import Performer

model = Performer(
    dim = 512,
    depth = 1,
    heads = 8,
    causal = True
)

x = torch.randn(1, 2048, 512)
model(x) # (1, 2048, 512)
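
For images or other grid-shaped data, one common pattern is to flatten the spatial grid into a sequence of feature vectors before handing it to the Performer. A hedged sketch follows; the feature-map shape and the reshape convention are illustrative choices, not anything the library prescribes.

import torch
from performer_pytorch import Performer

model = Performer(dim = 512, depth = 2, heads = 8, causal = False)

feats = torch.randn(1, 512, 32, 32)               # (batch, channels, height, width)
b, c, h, w = feats.shape

tokens = feats.flatten(2).transpose(1, 2)         # (1, 1024, 512), one token per spatial position
out = model(tokens)                               # (1, 1024, 512)
out = out.transpose(1, 2).reshape(b, c, h, w)     # back to (1, 512, 32, 32)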

Encoder / Decoder - Made possible by Thomas Melistas

import torch
from performer_pytorch import PerformerEncDec

SRC_SEQ_LEN = 4096
TGT_SEQ_LEN = 4096
GENERATE_LEN = 512

enc_dec = PerformerEncDec(
    dim = 512,
    tie_token_embed = True,
    enc_num_tokens = 20000,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = SRC_SEQ_LEN,
    dec_num_tokens = 20000,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = TGT_SEQ_LEN,
)

src = torch.randint(0, 20000, (1, SRC_SEQ_LEN))
tgt = torch.randint(0, 20000, (1, TGT_SEQ_LEN))
src_mask = torch.ones_like(src).bool()
tgt_mask = torch.ones_like(tgt).bool()

# train
enc_dec.train()
loss = enc_dec(src, tgt, enc_mask = src_mask, dec_mask = tgt_mask)
loss.backward()

# generate
generate_in = torch.randint(0, 20000, (1, SRC_SEQ_LEN)).long()
generate_out_prime = torch.tensor([[0.]]).long() # prime with <bos> token
samples = enc_dec.generate(generate_in, generate_out_prime, seq_len = GENERATE_LEN, eos_token = 1) # assume 1 is id of stop token
print(samples.shape) # (1, <= GENERATE_LEN); decode the tokens

Standalone self-attention layer with linear complexity with respect to sequence length, for replacing trained full-attention transformer self-attention layers.

import torch
from performer_pytorch import SelfAttention

attn = SelfAttention(
    dim = 512,
    heads = 8,
    causal = False,
).cuda()

x = torch.randn(1, 1024, 512).cuda()
attn(x) # (1, 1024, 512)

To minimize model surgery, you could also simply rewrite the code so that the attention step is done by the FastAttention module, as follows.

import torch
from performer_pytorch import FastAttention

# queries / keys / values with heads already split and transposed to first dimension
# 8 heads, dimension of head is 64, sequence length of 512
q = torch.randn(1, 8, 512, 64)
k = torch.randn(1, 8, 512, 64)
v = torch.randn(1, 8, 512, 64)

attn_fn = FastAttention(
    dim_heads = 64,
    nb_features = 256,
    causal = False
)

out = attn_fn(q, k, v) # (1, 8, 512, 64)
# now merge heads and combine outputs with Wo
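
FastAttention only computes the attention itself; merging the heads and projecting with the output matrix Wo stay outside of it. A minimal sketch of that glue, with dimensions matching the example above (the Linear layer here is a hypothetical stand-in for Wo, not an object exposed by the library):

from torch import nn

to_out = nn.Linear(8 * 64, 512)                    # Wo: heads * dim_head -> model dimension

b, h, n, d = out.shape                             # (1, 8, 512, 64)
merged = out.transpose(1, 2).reshape(b, n, h * d)  # (1, 512, 512)
projected = to_out(merged)                         # (1, 512, 512)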

Advanced

At the end of training, if you wish to fix the projection matrices to get the model to output deterministically, you can invoke the following

model.fix_projection_matrices_()

Now your model will have fixed projection matrices across all layers.
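
A quick sanity check of the effect, assuming dropout is disabled via eval(): once the projections are fixed, two forward passes over the same input should match exactly.

model.eval()
model.fix_projection_matrices_()

x = torch.randint(0, 20000, (1, 2048))

with torch.no_grad():
    out1 = model(x)
    out2 = model(x)

assert torch.allclose(out1, out2)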

Citations

@misc{choromanski2020rethinking,
    title   = {Rethinking Attention with Performers},
    author  = {Krzysztof Choromanski and Valerii Likhosherstov and David Dohan and Xingyou Song and Andreea Gane and Tamas Sarlos and Peter Hawkins and Jared Davis and Afroz Mohiuddin and Lukasz Kaiser and David Belanger and Lucy Colwell and Adrian Weller},
    year    = {2020},
    eprint  = {2009.14794},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@inproceedings{kitaev2020reformer,
    title       = {Reformer: The Efficient Transformer},
    author      = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
    booktitle   = {International Conference on Learning Representations},
    year        = {2020},
    url         = {https://openreview.net/forum?id=rkgNKkHtvB}
}
@inproceedings{katharopoulos_et_al_2020,
    author  = {Katharopoulos, A. and Vyas, A. and Pappas, N. and Fleuret, F.},
    title   = {Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention},
    booktitle = {Proceedings of the International Conference on Machine Learning (ICML)},
    year    = {2020}
}
@misc{bachlechner2020rezero,
    title   = {ReZero is All You Need: Fast Convergence at Large Depth},
    author  = {Thomas Bachlechner and Bodhisattwa Prasad Majumder and Huanru Henry Mao and Garrison W. Cottrell and Julian McAuley},
    year    = {2020},
    url     = {https://arxiv.org/abs/2003.04887}
}
@article{1910.05895,
    author  = {Toan Q. Nguyen and Julian Salazar},
    title   = {Transformers without Tears: Improving the Normalization of Self-Attention},
    year    = {2019},
    eprint  = {arXiv:1910.05895},
    doi     = {10.5281/zenodo.3525484},
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{roy*2020efficient,
    title   = {Efficient Content-Based Sparse Attention with Routing Transformers},
    author  = {Aurko Roy* and Mohammad Taghi Saffar* and David Grangier and Ashish Vaswani},
    year    = {2020},
    url     = {https://arxiv.org/pdf/2003.05997.pdf}
}
@techreport{zhuiyiroformer,
    title   = {RoFormer: Transformer with Rotary Position Embeddings - ZhuiyiAI},
    author  = {Jianlin Su},
    year    = {2021},
    url     = {https://github.com/ZhuiyiTechnology/roformer}
}
Comments
  • [Feature] EncoderDecoder framework, similar to ReformerEncDec

    Hello Phil,

    Nice job on this great architecture. I want to use it as an Encoder-Decoder within DeepSpeed, so I am thinking of writing a wrapper similar to the one you did for Reformer. Do you have any tips on what to pay attention to (no pun intended), and do I need to use padding as in Autopadder?

    Thanks

    opened by gulnazaki 22
  • Causal linear attention benchmark

    First, thanks for this awesome repo!!

    Based on the T5 model classes from Hugging Face's transformers, I was trying to use Performer attention instead of the original T5 attention. We finetuned t5-large on a summarization task, profiled both time and memory usage, and compared the Performer attention with the original attention. I have only benchmarked with an input size of 1024.

    The result clearly showed that Performer attention uses a lot less memory than the original transformer. I know from the paper that the Performer outperforms the original transformer when the input size is bigger than 1024. However, finetuning and generation with the Performer actually took longer, so I profiled the forward call of both the original T5 attention and the Performer attention. The forward of the T5 Performer took twice as long, and the main bottleneck was causal_dot_product_kernel from fast-transformers.

    Is this normal performance for the Performer's causal attention calculation, or will the Performer attention be faster with a bigger input size?

    opened by ice-americano 13
  • Regarding DDP and reversible networks

    Hi, I'm trying to figure out how to combine DDP with setting the network to be reversible.

    My code basically looks like this:

    import pytorch_lightning as pl
    from performer_pytorch import Performer
    ...
    model = nn.Sequential([...,Performer(...,reversible=True)])
    trainer = pl.Trainer(...
                        distributed_backend='ddp',
                        ...)
    trainer.fit(model,train_loader,val_loader)
    

    Now all combinations work for me (ddp/not reversible, not ddp/reversible, not ddp/not reversible) except for ddp and reversible.

    The error I get is:

    RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons:

    1. Use of a module parameter outside the forward function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes
    2. Reused parameters in multiple reentrant backward passes. For example, if you use multiple checkpoint functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases yet.

    I've seen multiple people have similar issues: https://github.com/huggingface/transformers/issues/7160 ,https://github.com/pytorch/pytorch/issues/46166 , https://github.com/tatp22/linformer-pytorch/issues/23

    Do you have any suggestions for how to deal with this issue? I'm not really familiar with the inner workings of DDP and the autograd engine, so I'm not sure how to fix this myself.

    opened by Parskatt 11
  • Triangular matrices?

    Does the current implementation provide triangular matrices (to constrain the attention to always look at the "left" of the sequence, both for input and encoded values), as described in the last section of the original paper?
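
    For context, the causal flag appears to cover this unidirectional (lower-triangular) constraint for self-attention; a minimal sketch:

    from performer_pytorch import Performer

    # causal = True restricts each position to attend only to itself and positions to its left
    model = Performer(dim = 512, depth = 1, heads = 8, causal = True)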

    opened by jeremycochoy 10
  • wrong implementation for autoregressive self-attention

    Hi, I found that you used fast_transformers' CUDA kernel, but it does not contain the normalization part, which needs a cumsum outside the CausalDotProduct (in causal_linear_attention). If I didn't miss something, the result of your code should be wrong... but I am not 100% sure.
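
    For reference, a hedged pure-PyTorch sketch of causal linear attention including the cumulative-sum normalizer under discussion; it is slow and memory-hungry, and is meant only as something to check outputs against, not as the library's actual implementation.

    import torch

    def causal_linear_attention_reference(q, k, v, eps = 1e-6):
        # q and k are assumed to have already gone through the positive feature map
        kv = torch.einsum('...nd,...ne->...nde', k, v).cumsum(dim = -3)   # running sums of k_i v_i^T
        num = torch.einsum('...nde,...nd->...ne', kv, q)                  # q_i^T (sum_{j<=i} k_j v_j^T)
        den = torch.einsum('...nd,...nd->...n', q, k.cumsum(dim = -2))    # normalizer: q_i . (sum_{j<=i} k_j)
        return num / (den.unsqueeze(-1) + eps)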

    opened by Sleepychord 10
  • There are no tests in this project, use_rezero=True is non-functional

    Tests are needed to validate that models can train in various configurations. I built and ran simple tests (while trying to get authorization to contribute them as a PR) and found that use_rezero=True kills the gradient and results in a Performer model that cannot learn. The fix consists of initializing the rezero parameter with a small but non-zero value (e.g., 1e-3 works in my tests). Zero prevents any signal from passing through the module, so the parameter never changes from zero.
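
    A hedged sketch of the workaround described above; treating every zero-initialized scalar parameter as a ReZero gain is an assumption of this sketch, not something the library guarantees.

    import torch

    with torch.no_grad():
        for name, param in model.named_parameters():
            if param.numel() == 1 and param.item() == 0.0:   # assumed to be a ReZero gain
                param.fill_(1e-3)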

    opened by fcampagne 10
  • Show what the performance on enwik8 is across your projects

    Hello @lucidrains, I'm a very big fan of your work. It is of such high quality that I lose sleep trying every new project you release.

    You have many different versions of transformers, such as reformer, memory-xl, performer... and apparently you already test them with enwik8.

    Would it be possible to post in the README a table with the enwik8 runtime, memory usage, and some performance metric? That would make it easy to compare the different implementations.

    Thanks again for your work!!

    opened by bratao 10
  • Issue with biased estimates from QR decomposition

    Hi again :)

    See the issue https://github.com/google-research/google-research/issues/436 that I posted on the main repository. Using the QR decomposition incorrectly produces results with significantly higher variance. There is quite an easy fix, simply:

     q, r = torch.qr(flattened)
     # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
     d = torch.diag(r, 0)
     ph = d.sign()
     q *= ph
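
    A self-contained, hedged version of the same fix using torch.linalg.qr (the block size below is arbitrary):

    import torch

    def orthogonal_block(size):
        block = torch.randn(size, size)
        q, r = torch.linalg.qr(block)
        d = torch.diagonal(r, 0)
        return q * d.sign()   # absorb the signs of R's diagonal so Q is Haar-uniform

    proj = orthogonal_block(64)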
    
    opened by Parskatt 9
  • Collaborate on Implementation?

    I was planning on implementing this in PyTorch as well and started a repo: https://github.com/calclavia/Performer-Pytorch. I have implemented the kernel so far. If the author(s) of this repo want to collaborate, I would be happy to contribute.

    opened by calclavia 9
  • Extra FF when using cross attention

    Hello Phil,

    I have noticed that when using cross attention, a new block (with an attention and a FeedForward layer) is added, while only an attention layer should be added between the self attention and the FF layer.

    Is there any reason for this?

    opened by gulnazaki 8
  • Add feature_redraw_interval option

    This fork allows the user to select a number of forward passes after which the random features will be redrawn. This allows us to avoid doing QR decomposition on the CPU every forward pass. By default it is set to redraw every 1000 passes.

    opened by norabelrose 8
  • Performer Pytorch Slower than Expected and Please Help with Understanding Parameter Count

    Hi,

    First of all, this is a great package from lucidrains and I find it very helpful in my research.

    A quick question: I noticed that the ViT-Performer is slower than the regular ViT from lucidrains. For example, running on MNIST from PyTorch takes 15 sec/epoch for the regular ViT with the configuration below, while the ViT-Performer takes 23 sec/epoch.

    Checking the parameter count also shows that the ViT-Performer has double the size of the regular ViT.

    [screenshots: ViT and ViT-Performer configurations and parameter counts]

    I am hoping that someone has intuition about the speed of the ViT-Performer vs. the regular ViT and their parameter counts.

    Thank you very much in advance!

    opened by weihaosong 1
  • Replicating nn.MultiheadAttention with multiple Performer SelfAttention modules

    As the title says, has anyone tried replacing the multi-head attention in a typical transformer with the SelfAttention described in this library?

    My thought was that I could essentially concat multiple self-attention elements together to replicate this, per the attached image from the torch website.

    I'm relatively new to transformers as a whole so hopefully this question makes some sense.

    For reference, considering the interest in a previous post, I've been attempting to explore Performer's effectiveness with DETR (https://github.com/facebookresearch/detr).

    thanks!
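
    A hedged sketch of the drop-in idea: SelfAttention already runs all of its heads internally, so a single module stands in for a single nn.MultiheadAttention; there is no need to concatenate several of them (the dimensions below are arbitrary).

    import torch
    from torch import nn
    from performer_pytorch import SelfAttention

    dim, heads = 256, 8

    full_attn   = nn.MultiheadAttention(dim, heads, batch_first = True)
    linear_attn = SelfAttention(dim = dim, heads = heads, causal = False)

    x = torch.randn(2, 100, dim)
    out_full, _ = full_attn(x, x, x)   # (2, 100, 256)
    out_linear  = linear_attn(x)       # (2, 100, 256)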

    opened by JGittles 0
  • Question about masking

    Hi, thanks for the wonderful repo. I am new to BERT, so I'd like to make sure about your example:

    model = PerformerLM()
    x = torch.randint(0, 20000, (1, 2048))
    mask = torch.ones_like(x).bool()
    model(x, mask = mask) # (1, 2048, 20000)

    Is this 'mask' the attention mask, i.e., TRUE (1) for normal tokens and FALSE (0) for padding tokens? Or is 1 used to indicate a padding token? Thanks a lot!
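
    A hedged sketch of the masking convention, matching the README example above: True marks real tokens to attend to and False marks padding positions to ignore. This reading is an assumption based on the example rather than a documented guarantee.

    import torch
    from performer_pytorch import PerformerLM

    model = PerformerLM(num_tokens = 20000, max_seq_len = 2048, dim = 512, depth = 1, heads = 8)

    x = torch.randint(0, 20000, (1, 2048))
    mask = torch.ones_like(x).bool()
    mask[:, -100:] = False             # treat the last 100 positions as padding

    logits = model(x, mask = mask)     # (1, 2048, 20000)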

    opened by Microbiods 1
  • Question: Is Performer order equivariant? (can it transform an unordered set of tensors)

    Hi,

    Thanks for the amazing implementation. I'm wondering if Performer can be used like a set operator (i.e., whether it is order equivariant).

    For instance, say I have a point cloud and I want to apply self-attention across all the point features. Can Performer be used here (note equivariance: points can be arbitrarily shuffled, but we expect the corresponding transformed features to be identical regardless of the shuffling)?

    Thanks!

    opened by nmakes 0
  • Using Performer with GNNs

    My understanding of "Rethinking Attention with Performers" is that FAVOR+ is used to approximate the attention matrix and to avoid the use of the softmax function. In the README.md file, you note that the plain Performer can be used for images or other modalities, just as the authors allude to Performer's use in other areas.

    I am interested in using Performer to approximate attention between nodes in a graph neural network. The graph neural network contains vectors characterizing each node's features and boolean edge indices indicating connections between pairs of nodes.

    Do you have any recommendations on how this is feasible with the current Performer model? I see that Attention.forward() accepts a mask input.

    opened by jah377 0