Implementation of the Transformer variant proposed in "Transformer Quality in Linear Time"

Overview

FLASH - Pytorch

Implementation of the Transformer variant proposed in the paper Transformer Quality in Linear Time

Install

$ pip install FLASH-pytorch

Usage

The main novel circuit in this paper is the "Gated Attention Unit", which they claim can replace multi-headed attention while reducing it to just one head.

It uses a relu squared activation in place of the softmax, the activation of which was first seen in the Primer paper, and the use of ReLU in ReLA Transformer. The gating style seems mostly inspired by gMLPs.

import torch
from flash_pytorch import GAU

gau = GAU(
    dim = 512,
    query_key_dim = 128,     # query / key dimension
    causal = True,           # autoregressive or not
    expansion_factor = 2,    # hidden dimension = dim * expansion_factor
)

x = torch.randn(1, 1024, 512)
out = gau(x) # (1, 1024, 512)

The authors then combine GAU with Katharopoulos linear attention, using grouping of the sequences to overcome a known issue with autoregressive linear attention.

This combination of the quadratic gated attention unit with grouped linear attention they named FLASH

You can also use this quite easily

import torch
from flash_pytorch import FLASH

flash = FLASH(
    dim = 512,
    group_size = 256,             # group size
    causal = True,                # autoregressive or not
    query_key_dim = 128,          # query / key dimension
    expansion_factor = 2.         # hidden dimension = dim * expansion_factor
)

x = torch.randn(1, 1111, 512)     # sequence will be auto-padded to nearest group size
out = flash(x) # (1, 1111, 512)

Finally, you can use the full FLASH transformer as mentioned in the paper. This contains all the positional embeddings mentioned in the paper. Absolute positional embedding uses scaled sinusoidal. GAU quadratic attention will get one-headed T5 relative positional bias. On top of all this, both GAU attention as well as the linear attention will be rotary embedded (RoPE).

import torch
from flash_pytorch import FLASHTransformer

model = FLASHTransformer(
    num_tokens = 20000,          # number of tokens
    dim = 512,                   # model dimension
    depth = 12,                  # depth
    causal = True,               # autoregressive or not
    group_size = 256,            # size of the groups
    query_key_dim = 128,         # dimension of queries / keys
    expansion_factor = 2.,       # hidden dimension = dim * expansion_factor
    norm_type = 'scalenorm',     # in the paper, they claimed scalenorm led to faster training at no performance hit. the other option is 'layernorm' (also default)
    shift_tokens = True          # discovered by an independent researcher in Shenzhen @BlinkDL, this simply shifts half of the feature space forward one step along the sequence dimension - greatly improved convergence even more in my local experiments
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x) # (1, 1024, 20000)

Test on Autoregressive Enwik8

$ python train.py

Citations

@article{Hua2022TransformerQI,
    title   = {Transformer Quality in Linear Time},
    author  = {Weizhe Hua and Zihang Dai and Hanxiao Liu and Quoc V. Le},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2202.10447}
}
@software{peng_bo_2021_5196578,
    author    = {PENG Bo},
    title     = {BlinkDL/RWKV-LM: 0.01},
    month     = {aug},
    year      = {2021},
    publisher = {Zenodo},
    version   = {0.01},
    doi       = {10.5281/zenodo.5196578},
    url       = {https://doi.org/10.5281/zenodo.5196578}
}
Comments
  • einsum operation in Linear Attention Part

    einsum operation in Linear Attention Part

    Hi, Thanks a lot for your FLASH_pytorch, which helps a lot. I found that there are some differences from the paper in the Linear Attention Part: https://github.com/lucidrains/FLASH-pytorch/blob/main/flash_pytorch/flash_pytorch.py#L342-L343

    lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
    lin_out = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
    

    the lin_kv is three-dim (bde) And the code in the paper is

    lin_kv = tf.einsum('bhke,bgh→bgke', lin_kv, mask) 
    linear = tf.einsum('bgnk,bgke→bgne', lin_q, lin_kv)
    

    the lin_kv is four-dim (bgke) It seems that the two ways are not equivalent.

    Looking forward to your reply. Best,

    opened by ShomyLiu 5
  • mask error

    mask error

    x = torch.randint(0, 20000, (1, 1024))
    mask = x.ne(0)
    logits = model(x, mask=mask)
    

    RuntimeError: The size of tensor a (1024) must match the size of tensor b (128) at non-singleton dimension 2

    opened by keyunluo 1
  • Speed on TPU

    Speed on TPU

    Hi, Thanks for the code! I test it on Google TPU v3, the training speed seems slower than my expectation. Maybe there is some operation which is not lower on TPU.

    opened by magicknight 0
  • About the

    About the "shift_tokens"

    Thank you for your amazing code.

    In the class of FLASH, I find a flag: shift_tokens, and the corresponding code is as following: if self.shift_tokens: x_shift, x_pass = normed_x.chunk(2, dim = -1) x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.) normed_x = torch.cat((x_shift, x_pass), dim = -1)

    Assume we have normed_x in the shape [1024, 512], the x_shift/x_pass is the shape of [1024, 256]. Then it adds a row (with all 0 value) and remove the last row in the x_shift, and concat x_shift and x_pass to get the normed_x.

    In my opinion, the F.pad operation will make the row in x_shift and x_pass do not match again.

    May I know why it works?

    Kang

    opened by kangzhao2 1
  • Cross-Attention?

    Cross-Attention?

    Hi, @lucidrains. Thank you for sharing this excellent implementation with us all! Do you have any thoughts as to what changes would need to be made to make cross-attention possible with your FLASH model?

    opened by amorehead 2
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
Prefix-Tuning: Optimizing Continuous Prompts for Generation

Prefix Tuning Files: . ├── gpt2 # Code for GPT2 style autoregressive LM │ ├── train_e2e.py # high-level script

530 Jan 04, 2023
A fast poisson image editing implementation that can utilize multi-core CPU or GPU to handle a high-resolution image input.

Poisson Image Editing - A Parallel Implementation Jiayi Weng (jiayiwen), Zixu Chen (zixuc) Poisson Image Editing is a technique that can fuse two imag

Jiayi Weng 110 Dec 27, 2022
这是一个mobilenet-yolov4-lite的库,把yolov4主干网络修改成了mobilenet,修改了Panet的卷积组成,使参数量大幅度缩小。

YOLOV4:You Only Look Once目标检测模型-修改mobilenet系列主干网络-在Keras当中的实现 2021年2月8日更新: 加入letterbox_image的选项,关闭letterbox_image后网络的map一般可以得到提升。

Bubbliiiing 65 Dec 01, 2022
A library for preparing, training, and evaluating scalable deep learning hybrid recommender systems using PyTorch.

collie_recs Collie is a library for preparing, training, and evaluating implicit deep learning hybrid recommender systems, named after the Border Coll

ShopRunner 97 Jan 03, 2023
PyTorch implementation for our AAAI 2022 Paper "Graph-wise Common Latent Factor Extraction for Unsupervised Graph Representation Learning"

deepGCFX PyTorch implementation for our AAAI 2022 Paper "Graph-wise Common Latent Factor Extraction for Unsupervised Graph Representation Learning" Pr

Thilini Cooray 4 Aug 11, 2022
Self-training for Few-shot Transfer Across Extreme Task Differences

Self-training for Few-shot Transfer Across Extreme Task Differences (STARTUP) Introduction This repo contains the official implementation of the follo

Cheng Perng Phoo 33 Oct 31, 2022
Internship Assessment Task for BaggageAI.

BaggageAI Internship Task Problem Statement: You are given two sets of images:- background and threat objects. Background images are the background x-

Arya Shah 10 Nov 14, 2022
Dynamic Neural Representational Decoders for High-Resolution Semantic Segmentation

Dynamic Neural Representational Decoders for High-Resolution Semantic Segmentation Requirements This repository needs mmsegmentation Training To train

20 May 28, 2022
Imbalanced Gradients: A Subtle Cause of Overestimated Adversarial Robustness

Imbalanced Gradients: A Subtle Cause of Overestimated Adversarial Robustness Code for Paper "Imbalanced Gradients: A Subtle Cause of Overestimated Adv

Hanxun Huang 11 Nov 30, 2022
Fully convolutional deep neural network to remove transparent overlays from images

Fully convolutional deep neural network to remove transparent overlays from images

Marc Belmont 1.1k Jan 06, 2023
A Factor Model for Persistence in Investment Manager Performance

Factor-Model-Manager-Performance A Factor Model for Persistence in Investment Manager Performance I apply methods and processes similar to those used

Omid Arhami 1 Dec 01, 2021
The implementation for paper Joint t-SNE for Comparable Projections of Multiple High-Dimensional Datasets.

Joint t-sne This is the implementation for paper Joint t-SNE for Comparable Projections of Multiple High-Dimensional Datasets. abstract: We present Jo

IDEAS Lab 7 Dec 18, 2022
Voxel Set Transformer: A Set-to-Set Approach to 3D Object Detection from Point Clouds (CVPR 2022)

Voxel Set Transformer: A Set-to-Set Approach to 3D Object Detection from Point Clouds (CVPR2022)[paper] Authors: Chenhang He, Ruihuang Li, Shuai Li, L

Billy HE 141 Dec 30, 2022
Benchmarks for the Optimal Power Flow Problem

Power Grid Lib - Optimal Power Flow This benchmark library is curated and maintained by the IEEE PES Task Force on Benchmarks for Validation of Emergi

A Library of IEEE PES Power Grid Benchmarks 207 Dec 08, 2022
A simple pytorch pipeline for semantic segmentation.

SegmentationPipeline -- Pytorch A simple pytorch pipeline for semantic segmentation. Requirements : torch=1.9.0 tqdm albumentations=1.0.3 opencv-pyt

petite7 4 Feb 22, 2022
[ICCV'21] NEAT: Neural Attention Fields for End-to-End Autonomous Driving

NEAT: Neural Attention Fields for End-to-End Autonomous Driving Paper | Supplementary | Video | Poster | Blog This repository is for the ICCV 2021 pap

254 Jan 02, 2023
Code for Multiple Instance Active Learning for Object Detection, CVPR 2021

Language: 简体中文 | English Introduction This is the code for Multiple Instance Active Learning for Object Detection, CVPR 2021. Installation A Linux pla

Tianning Yuan 269 Dec 21, 2022
Automatically replace ONNX's RandomNormal node with Constant node.

onnx-remove-random-normal This is a script to replace RandomNormal node with Constant node. Example Imagine that we have something ONNX model like the

Masashi Shibata 1 Dec 11, 2021
CAUSE: Causality from AttribUtions on Sequence of Events

CAUSE: Causality from AttribUtions on Sequence of Events

Wei Zhang 21 Dec 01, 2022
Some pre-commit hooks for OpenMMLab projects

pre-commit-hooks Some pre-commit hooks for OpenMMLab projects. Using pre-commit-hooks with pre-commit Add this to your .pre-commit-config.yaml - rep

OpenMMLab 16 Nov 29, 2022