Implementation of Nyström Self-attention, from the paper Nyströmformer

Overview

Nyström Attention

Implementation of Nyström Self-attention, from the paper Nyströmformer.

Yannic Kilcher video

Install

$ pip install nystrom-attention

Usage

import torch
from nystrom_attention import NystromAttention

attn = NystromAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    num_landmarks = 256,    # number of landmarks
    pinv_iterations = 6,    # number of moore-penrose iterations for approximating pinverse. 6 was recommended by the paper
    residual = True         # whether to do an extra residual with the value or not. supposedly faster convergence if turned on
)

x = torch.randn(1, 16384, 512)
mask = torch.ones(1, 16384).bool()

attn(x, mask = mask) # (1, 16384, 512)

Nyströmformer, layers of Nyström attention

import torch
from nystrom_attention import Nystromformer

model = Nystromformer(
    dim = 512,
    dim_head = 64,
    heads = 8,
    depth = 6,
    num_landmarks = 256,
    pinv_iterations = 6
)

x = torch.randn(1, 16384, 512)
mask = torch.ones(1, 16384).bool()

model(x, mask = mask) # (1, 16384, 512)

You can also import it as Nyströmer if you wish

from nystrom_attention import Nystromer

Citations

@misc{xiong2021nystromformer,
    title   = {Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention},
    author  = {Yunyang Xiong and Zhanpeng Zeng and Rudrasis Chakraborty and Mingxing Tan and Glenn Fung and Yin Li and Vikas Singh},
    year    = {2021},
    eprint  = {2102.03902},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
Comments
  • Clarification on masking

    Clarification on masking

    Given the dimensionality of the mask argument, (N, T), I'm assuming this is a boolean mask for masking out padding tokens. I created the following function to generate such a mask given an input tensor:

    def _create_pad_mask(self, x: torch.LongTensor) -> torch.BoolTensor:
        mask = torch.ones_like(x).to(torch.bool)
        mask[x==0] = False
        return mask
    

    where 0 is the padding token, setting positions to False so not to attend to them.

    However, I am unsure how to apply a causal mask to the attention layers so to prevent my decoder from accessing future elements. I couldn't see an example of this in the full Nystromformer module. How can I achieve this?

    For context, I am trying to apply the causal mask generated by the following function:

    def _create_causal_mask(self, x: torch.LongTensor) -> torch.FloatTensor:
        size = x.shape[1]
        mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill_(mask == 0, float('-inf')).masked_fill_(mask==1, 0.0)
        return mask
    

    One way I can think of is to set return_attn to True, apply the mask on the returned attention weights then matmul with the value tensor. But this has a few issues:

    • Having to return v
    • Computing the full attention matrix (I think), defeating the entire point of linear attention
    • Needlessly calculating out only to discard it.

    Is this just a limitation of Nystrom attention? Or am I overlooking something obvious?

    Thanks

    opened by vvvm23 3
  • Possible bug with padding

    Possible bug with padding

    Hey there,

    I was going through the code and I noticed the following, which I found curious.

    In Line 75, you pad the input tensor to a multiple of num_landmarks from the front:

    x = F.pad(x, (0, 0, padding, 0), value = 0)
    

    In Line 144 you trim the extra padding elements you inserted in the output tensor from the end.

    out = out[:, :n]
    

    Am I not getting something, or should we be removing the front elements of out?

    out = out[:, out.size(1) - n:]
    
    opened by georgepar 2
  • Nystrom for Image processing

    Nystrom for Image processing

    thank you for sharing the wondeful code. I am working on image processing and wanted to try your code for the same. I have 2 doubts:

    1. How to select residual_conv_kernel? I could not find any details for the same. also, it is enabled by a flag. When should we enable it and when to disable it?
    2. Is there any guideline for deciding num_landmarks for image processing task?

    Thanks

    opened by paragon1234 1
  • Error when mask is of the same size as that of the input X

    Error when mask is of the same size as that of the input X

    Hi,

    First of all, thank you for putting such an easy to use implementation on GitHub. I'm trying to incorporate the nystrom attention into a legacy codebase, it previously used to provide the input X and the mask (off the same dimensions as X) to a Multi headed Attention Layer.

    When I'm trying to integrate nystrom attention with it, it runs alright without the mask. But, when I pass the mask alongside it, it throws einops rearrange error.

    Sorry, if this is a very basic question, but how would you recommend I deal with handling 3D mask (same dimensions as the size of input) in the codebase.

    Best, VB

    opened by Vaibhavs10 1
  • ViewBackward inplace deprecation warning

    ViewBackward inplace deprecation warning

    Hello again,

    The following code results in a UserWarning in PyTorch 1.8.1.

    In [1]: from nystrom_attention.nystrom_attention import NystromAttention
    
    In [2]: import torch
    
    In [3]: attn = NystromAttention(256)
    
    In [4]: x = torch.randn(1, 8192, 256)
    
    In [5]: attn(x)
    /home/alex/.tmp/nystrom-attention/nystrom_attention/nystrom_attention.py:91: UserWarning: Output 0 of ViewBackward is a view and is being modified inplace. This view is an output of a function that returns multiple views. Inplace operators on such views are being deprecated and will be forbidden starting from version 1.8. Consider using `unsafe_` version of the function that produced this view or don't modify this view inplace. (Triggered internally at  ../torch/csrc/autograd/variable.cpp:547.)
      q *= self.scale
    Out[5]:
    tensor([[[-0.0449, -0.1726,  0.1409,  ...,  0.0127,  0.2287, -0.2437],
             [-0.1132,  0.3229, -0.1279,  ...,  0.0084, -0.3307, -0.2351],
             [ 0.0361,  0.1013,  0.0828,  ...,  0.1045, -0.1627,  0.0736],
             ...,
             [ 0.0018,  0.1385, -0.1716,  ..., -0.0366, -0.0682,  0.0241],
             [ 0.1497,  0.0149, -0.0020,  ..., -0.0352, -0.1126,  0.0193],
             [ 0.1341,  0.0077,  0.1627,  ..., -0.0363,  0.1057, -0.2071]]],
           grad_fn=<SliceBackward>)
    

    Not a huge issue, but worth mentioning

    opened by vvvm23 1
  • Relative position encoding

    Relative position encoding

    Similar to the question raised for the performer architecture , is it possible to implement a relative position encoding given the methodology in which attention is calculated?

    opened by jdcla 1
  • How can we implement

    How can we implement "batch_first" in Nystrom attention?

    Hi,

    Thanks a lot for implementing the nystromformer attention algorithm! Very nice job!

    I am wondering whether it is feasible to add the "batch_first" option in the nystrom attention algorithm? This allow the algorithm to be integrated in the existing pytorch transformer encoder architecture.

    opened by mark0935git 0
  • x-transformers

    x-transformers

    Hi @lucidrains - just wondering if we can plug in Nystrom Attention with x-transformers?

    I've been plugging in Vision Transformers with X-transformers but am wondering if its possible to have a Nystrom transformer with x-transformer improvements to plug into a ViT?

    opened by robbohua 0
Owner
Phil Wang
Working with Attention. It's all we need.
Phil Wang
simple demo codes for Learning to Teach with Dynamic Loss Functions

Learning to Teach with Dynamic Loss Functions This repo contains the simple demo for the NeurIPS-18 paper: Learning to Teach with Dynamic Loss Functio

Lijun Wu 15 Dec 30, 2021
Code and models used in "MUSS Multilingual Unsupervised Sentence Simplification by Mining Paraphrases".

Multilingual Unsupervised Sentence Simplification Code and pretrained models to reproduce experiments in "MUSS: Multilingual Unsupervised Sentence Sim

Facebook Research 81 Dec 29, 2022
Few-shot NLP benchmark for unified, rigorous eval

FLEX FLEX is a benchmark and framework for unified, rigorous few-shot NLP evaluation. FLEX enables: First-class NLP support Support for meta-training

AI2 85 Dec 03, 2022
Code release for The Devil is in the Channels: Mutual-Channel Loss for Fine-Grained Image Classification (TIP 2020)

The Devil is in the Channels: Mutual-Channel Loss for Fine-Grained Image Classification Code release for The Devil is in the Channels: Mutual-Channel

PRIS-CV: Computer Vision Group 230 Dec 31, 2022
Learning multiple gaits of quadruped robot using hierarchical reinforcement learning

Learning multiple gaits of quadruped robot using hierarchical reinforcement learning We propose a method to learn multiple gaits of quadruped robot us

Yunho Kim 17 Dec 11, 2022
We will release the code of "ConTNet: Why not use convolution and transformer at the same time?" in this repo

ConTNet Introduction ConTNet (Convlution-Tranformer Network) is proposed mainly in response to the following two issues: (1) ConvNets lack a large rec

93 Nov 08, 2022
Continual reinforcement learning baselines: experiment specifications, implementation of existing methods, and common metrics. Easily extensible to new methods.

Continual Reinforcement Learning This repository provides a simple way to run continual reinforcement learning experiments in PyTorch, including evalu

55 Dec 24, 2022
LLVIP: A Visible-infrared Paired Dataset for Low-light Vision

LLVIP: A Visible-infrared Paired Dataset for Low-light Vision Project | Arxiv | Abstract It is very challenging for various visual tasks such as image

CVSM Group - email: <a href=[email protected]"> 377 Jan 07, 2023
This repository contains a pytorch implementation of "StereoPIFu: Depth Aware Clothed Human Digitization via Stereo Vision".

StereoPIFu: Depth Aware Clothed Human Digitization via Stereo Vision | Project Page | Paper | This repository contains a pytorch implementation of "St

87 Dec 09, 2022
This is the official implement of paper "ActionCLIP: A New Paradigm for Action Recognition"

This is an official pytorch implementation of ActionCLIP: A New Paradigm for Video Action Recognition [arXiv] Overview Content Prerequisites Data Prep

268 Jan 09, 2023
Conditional Generative Adversarial Networks (CGAN) for Mobility Data Fusion

This code implements the paper, Kim et al. (2021). Imputing Qualitative Attributes for Trip Chains Extracted from Smart Card Data Using a Conditional Generative Adversarial Network. Transportation Re

Eui-Jin Kim 2 Feb 03, 2022
This is the official implementation of VaxNeRF (Voxel-Accelearated NeRF).

VaxNeRF Paper | Google Colab This is the official implementation of VaxNeRF (Voxel-Accelearated NeRF). This codebase is implemented using JAX, buildin

naruya 132 Nov 21, 2022
An open source implementation of CLIP.

OpenCLIP Welcome to an open source implementation of OpenAI's CLIP (Contrastive Language-Image Pre-training). The goal of this repository is to enable

2.7k Dec 31, 2022
PowerGridworld: A Framework for Multi-Agent Reinforcement Learning in Power Systems

PowerGridworld provides users with a lightweight, modular, and customizable framework for creating power-systems-focused, multi-agent Gym environments that readily integrate with existing training fr

National Renewable Energy Laboratory 37 Dec 17, 2022
The Surprising Effectiveness of Visual Odometry Techniques for Embodied PointGoal Navigation

PointNav-VO The Surprising Effectiveness of Visual Odometry Techniques for Embodied PointGoal Navigation Project Page | Paper Table of Contents Setup

Xiaoming Zhao 41 Dec 15, 2022
Code for "Single-view robot pose and joint angle estimation via render & compare", CVPR 2021 (Oral).

Single-view robot pose and joint angle estimation via render & compare Yann Labbé, Justin Carpentier, Mathieu Aubry, Josef Sivic CVPR: Conference on C

Yann Labbé 51 Oct 14, 2022
[CVPR'21] DeepSurfels: Learning Online Appearance Fusion

DeepSurfels: Learning Online Appearance Fusion Paper | Video | Project Page This is the official implementation of the CVPR 2021 submission DeepSurfel

Online Reconstruction 52 Nov 14, 2022
Code for Overinterpretation paper Overinterpretation reveals image classification model pathologies

Overinterpretation This repository contains the code for the paper: Overinterpretation reveals image classification model pathologies Authors: Brandon

Gifford Lab, MIT CSAIL 17 Dec 10, 2022
Script utilizando OpenCV e modelo Machine Learning para detectar o uso de máscaras.

Reconhecendo máscaras Este repositório contém um script em Python3 que reconhece se um rosto está ou não portando uma máscara! O código utiliza da bib

Maria Eduarda de Azevedo Silva 168 Oct 20, 2022
Scikit-event-correlation - Event Correlation and Forecasting over High Dimensional Streaming Sensor Data algorithms

scikit-event-correlation Event Correlation and Changing Detection Algorithm Theo

Intellia ICT 5 Oct 30, 2022