Implementation of LambdaNetworks, a new approach to image recognition that reaches SOTA with less compute

Overview

Lambda Networks - Pytorch

Implementation of λ Networks, a new approach to image recognition that reaches SOTA on ImageNet. The method uses a λ layer, which captures interactions by transforming contexts into linear functions, termed lambdas, and applying these linear functions to each input separately.
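
To make the idea of "transforming contexts into linear functions" concrete, here is a minimal sketch of the content lambda only. It assumes a single head, dim_u = 1 and no position lambda; the function name content_lambda_layer and the tensor layout are illustrative choices, not the library's API.

```python
import torch

def content_lambda_layer(q, k, v):
    """Apply a content lambda to every query position (illustrative helper).

    q: (batch, n, dim_k)  queries, one per input position
    k: (batch, m, dim_k)  keys, one per context position
    v: (batch, m, dim_v)  values, one per context position
    """
    # Normalise the keys over the context positions (dim=1)
    k = k.softmax(dim=1)
    # Summarise the whole context into one linear map of shape (dim_k, dim_v)
    lam = torch.einsum('bmk,bmv->bkv', k, v)
    # Apply that linear map to each query independently
    return torch.einsum('bnk,bkv->bnv', q, lam)

q = torch.randn(2, 10, 16)
k = torch.randn(2, 10, 16)
v = torch.randn(2, 10, 8)
out = content_lambda_layer(q, k, v)
print(out.shape)  # torch.Size([2, 10, 8])
```

The lambda has shape (dim_k, dim_v) regardless of how many context positions there are, so no n x m attention map is ever materialised.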

Yannic Kilcher's paper review

Install

$ pip install lambda-networks

Usage

Global context

import torch
from lambda_networks import LambdaLayer

layer = LambdaLayer(
    dim = 32,       # channels going in
    dim_out = 32,   # channels out
    n = 64,         # size of the receptive window - max(height, width)
    dim_k = 16,     # key dimension
    heads = 4,      # number of heads, for multi-query
    dim_u = 1       # 'intra-depth' dimension
)

x = torch.randn(1, 32, 64, 64)
layer(x) # (1, 32, 64, 64)

Localized context

import torch
from lambda_networks import LambdaLayer

layer = LambdaLayer(
    dim = 32,
    dim_out = 32,
    r = 23,         # the receptive field for relative positional encoding (23 x 23)
    dim_k = 16,
    heads = 4,
    dim_u = 4
)

x = torch.randn(1, 32, 64, 64)
layer(x) # (1, 32, 64, 64)

For fun, you can also import this as follows

from lambda_networks import λLayer

Tensorflow / Keras version

Shinel94 has added a Keras implementation! It won't be officially supported in this repository, so either copy / paste the code under ./lambda_networks/tfkeras.py or make sure to install tensorflow and keras before running the following.

import tensorflow as tf
from lambda_networks.tfkeras import LambdaLayer

layer = LambdaLayer(
    dim_out = 32,
    r = 23,
    dim_k = 16,
    heads = 4,
    dim_u = 1
)

x = tf.random.normal((1, 64, 64, 16)) # channel last format
layer(x) # (1, 64, 64, 32)

Citations

@inproceedings{
    anonymous2021lambdanetworks,
    title={LambdaNetworks: Modeling long-range Interactions without Attention},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=xTJEN-ggl1b},
    note={under review}
}
Comments
  • Contiguity problem: "RuntimeError: cuDNN error: CUDNN_STATUS_NOT_SUPPORTED. This error may appear if you passed in a non-contiguous input."

    It seems that LambdaLayer breaks contiguity when I try it.

    layer(x).is_contiguous()
    >> False
    

    I have to call .contiguous() on the output wherever I train with it. Is that normal?

    opened by Whiax 5
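
For context on the question above, the following generic sketch shows how a dimension reordering produces a non-contiguous tensor and how .contiguous() restores a dense layout. It does not reproduce the layer's internals; the shapes are arbitrary.

```python
import torch

x = torch.randn(1, 32, 8, 8)

# permute only changes the strides, so the result is a non-contiguous view
y = x.permute(0, 2, 3, 1)
print(y.is_contiguous())   # False

# .contiguous() copies the data into a densely packed layout
z = y.contiguous()
print(z.is_contiguous())   # True
print(torch.equal(y, z))   # True, the values are unchanged
```

Calling .contiguous() on an already contiguous tensor returns the tensor itself, so adding it defensively after the layer costs a copy only when one is needed.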
  • Warning: Mixed memory format inputs detected while calling the operator.

    I have added lambda layers to every block of ResNet, but the following warning appears. Will it affect the result?

    Warning: Mixed memory format inputs detected while calling the operator. The operator will output channels_last tensor even if some of the inputs are not in channels_last format. (function operator())

    opened by Pluto1314 4
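
The warning text refers to the memory layout of the tensors rather than to their values. As a hedged illustration (not a diagnosis of this specific model), the sketch below shows how to inspect and convert a tensor's memory format in PyTorch.

```python
import torch

x = torch.randn(2, 32, 16, 16)                   # default NCHW (contiguous) layout
x_cl = x.to(memory_format=torch.channels_last)   # same shape, NHWC-ordered strides

print(x.is_contiguous())                                      # True
print(x_cl.is_contiguous(memory_format=torch.channels_last))  # True
print(x_cl.is_contiguous())                                   # False

# Shape and values are identical; only the strides differ
print(x_cl.shape == x.shape, torch.equal(x, x_cl))            # True True
```

Keeping every input to an operator in the same memory format is one way to avoid mixing layouts; whether that silences the warning here is an assumption.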
  • Implementation of Lambda convolution

    Thanks for the great work in the implementations!

    I would like to ask whether there is a difference in using 'Conv2d' as suggested in Eq. 3 in the paper and your implementation of 'Conv3d'. These two convs treat the (h x w)-dimension as a 1-d sequence and a 2-d image, respectively. I believe they are quite different in concept.

    Point out if I misunderstood.

    Thx a lot.

    opened by romulus0914 3
  • How does the lambda layer handle downsampling in LambdaResNet?

    Hi, thanks for your clear code. I am trying to implement LambdaResNet. Does the lambda layer replace every conv2d layer? If so, how does it handle the downsampling that conv2d performs, e.g. with stride=2? Or should I keep conv2d where stride=2 and replace only the conv2d layers with stride=1?

    opened by qiaoran-dawnlight 2
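
One option for the stride question above is to keep the layer stride-free and downsample with average pooling afterwards. The sketch below assumes this design; StridedWrapper is a hypothetical helper, a 1x1 convolution stands in for a lambda layer, and whether this matches the paper's LambdaResNet exactly is not confirmed here.

```python
import torch
from torch import nn

class StridedWrapper(nn.Module):
    """Hypothetical helper: run a shape-preserving layer, then downsample."""
    def __init__(self, layer: nn.Module, stride: int = 1):
        super().__init__()
        self.layer = layer
        # Average pooling takes over the role of the convolution stride
        self.pool = (
            nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
            if stride > 1 else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.pool(self.layer(x))

# Stand-in for a lambda layer: any module that keeps (H, W) unchanged
stand_in = nn.Conv2d(32, 32, kernel_size=1)
block = StridedWrapper(stand_in, stride=2)

x = torch.randn(1, 32, 64, 64)
print(block(x).shape)  # torch.Size([1, 32, 32, 32])
```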
  • Question about hybrid LambdaResNet

    Hi,

    In the paper, there is this paragraph:

    When working with hybrid LambdaNetworks, we use a single lambda layer in c4 for LambdaResNet50, 3 lambda layers for LambdaResNet101, 6 lambda layers for LambdaResNet-152/200/270/350 and 8 lambda layers for LambdaResNet-420.

    I have several questions about constructing the hybrid LambdaResNet:

    1. Do we only need to replace the 3x3 conv with a lambda layer in the C4 stage, rather than in both C4 and C5 (as in the ablation study)?
    2. When there is more than one lambda layer, as in LambdaResNet101, are we replacing the 3x3 conv with 3 lambda layers? And in the ResNet50 case, do we replace the 3x3 conv with 1 lambda layer?
    opened by CoinCheung 2
  • Question: Is there an easy way to visualise lambdas?

    I want to train a classifier and tell which regions it pays the most... well... attention to :) I would like to do this during inference, without using Grad-CAM etc. Can I do this?

    opened by lebionick 1
  • Fix relative positional attention (position lambda)

    Hi lucidrains, thanks for your nice work.

    I found an error in the position lambda (relative positional attention) implementation. The position lambda, λp, should be translation equivariant, as written in Sec. 3.2 of the paper. This means the positional embedding has a constraint, E[n, m] = E[t(n), t(m)], which is missing from the current implementation. This PR fixes it by adding the translation equivariance constraint. I checked that this PR improves the result in my experiment.

    NOTE that this PR changes the meaning of the function parameter n, from the total area (n=w*h) to the length of each side (n=w=h).

    opened by khanrc 1
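
The constraint E[n, m] = E[t(n), t(m)] can be met by letting the embedding depend only on the offset between two positions. The sketch below illustrates that idea on an n x n grid; relative_position_index is a hypothetical helper and the code is not taken from the PR.

```python
import torch

def relative_position_index(n: int) -> torch.Tensor:
    """Offsets between every pair of positions on an n x n grid, shape (n*n, n*n, 2),
    shifted to be non-negative so they can index a (2n-1, 2n-1) table."""
    idx = torch.arange(n)
    rows = idx.repeat_interleave(n)              # row of each flattened position
    cols = idx.repeat(n)                         # column of each flattened position
    coords = torch.stack((rows, cols), dim=-1)   # (n*n, 2)
    rel = coords[None, :, :] - coords[:, None, :]  # rel[i, j] = coords[j] - coords[i]
    return rel + (n - 1)                         # shift into [0, 2n-2]

n, dim_k = 4, 8
table = torch.randn(2 * n - 1, 2 * n - 1, dim_k)  # one embedding per relative offset
rel = relative_position_index(n)
E = table[rel[..., 0], rel[..., 1]]               # (n*n, n*n, dim_k)

# Shifting both positions by one row and one column keeps the offset (0, +1),
# so the pair (0,0)->(0,1) shares its embedding with (1,1)->(1,2)
a, b = 0 * n + 0, 0 * n + 1
c, d = 1 * n + 1, 1 * n + 2
print(torch.equal(E[a, b], E[c, d]))  # True
```

With this indexing the table holds (2n-1)^2 embeddings instead of one per pair of positions, which is also why n is naturally the side length of the grid.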
  • Use of Keras Lambda

    Hey! Thanks for the awesome implementations :D

    I was wondering why tf.keras.layers.Lambda is used? It seems unnecessary; regular calls to TF operations work and are more readable.

    https://github.com/lucidrains/lambda-networks/blob/06a48f2a5b41f3cd278aee67838c32051a0a9bed/lambda_networks/tfkeras.py#L73

    You can also call the functional version of the softmax instead.

    opened by cgarciae 1
  • Lambda for a sequence of images

    Thanks for the quick implementation!

    I have a problem where I have a sequence of images (a video) rather than one. So instead of the dimensions batch, channels, height, width, I also have a length dimension after batch that gives the sequence length.

    Given a known max_length (for the positional embedding), in forward, should conv4d be used instead of 3d to allow interaction between frames?

    In the paper they do mention this could serve as a general framework for sequences of images, so I wonder if you explored that in implementation (where obviously a single image is just a case where length=1)

    opened by AmitMY 1
  • Why are the FLOPs so high?

    I used ResNet50 and changed the C4 layer into LambdaBottleNeck, but the FLOPs are very high, about 20G with an input size of 224*244. Is that right, or is something wrong with my implementation?

    opened by zisuina 0
  • How to load_model correctly

    Hi everyone, I am struggling to load this model from a saved .h5 file.

    What is the correct way to load this network? If I use custom_objects I get: __init__() got an unexpected keyword argument 'name'

    opened by JNaranjo-Alcazar 0
  • Please add clarity to code

    Phil, I love your work, and I wish you could go a few extra steps to help out users. I found this class by François-Guillaume (@frgfm), which adds clear math comments. I want to merge it, but there is a bit of code drift and I don't want to introduce any bugs. I beseech you to go the extra step to help users bridge from papers to code.

    https://github.com/frgfm/Holocron/blob/bcc3ea19a477e4b28dc5973cdbe92a9b05c690bb/holocron/nn/modules/lambda_layer.py

    eg. please articulate return types def forward(self, x: torch.Tensor) -> torch.Tensor:

    Please give any clarity in arguments. # Project input and context to get queries, keys & values

    Throw in some maths as a comment / this is great as it bridges the paper to the code.

    B x (num_heads * dim_k) * H * W -> B x num_heads x dim_k x (H * W)

    import torch
    from torch import nn, einsum
    import torch.nn.functional as F
    from typing import Optional
    
    __all__ = ['LambdaLayer']
    
    
    class LambdaLayer(nn.Module):
        """Lambda layer from `"LambdaNetworks: Modeling long-range interactions without attention"
        <https://openreview.net/pdf?id=xTJEN-ggl1b>`_. The implementation was adapted from `lucidrains'
        <https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py>`.
        Args:
            in_channels (int): input channels
            out_channels (int, optional): output channels
            dim_k (int): key dimension
            n (int, optional): number of input pixels
            r (int, optional): receptive field for relative positional encoding
            num_heads (int, optional): number of attention heads
            dim_u (int, optional): intra-depth dimension
        """
        def __init__(
            self,
            in_channels: int,
            out_channels: int,
            dim_k: int,
            n: Optional[int] = None,
            r: Optional[int] = None,
            num_heads: int = 4,
            dim_u: int = 1
        ) -> None:
            super().__init__()
            self.u = dim_u
            self.num_heads = num_heads
    
            if out_channels % num_heads != 0:
                raise AssertionError('values dimension must be divisible by number of heads for multi-head query')
            dim_v = out_channels // num_heads
    
            # Project input and context to get queries, keys & values
            self.to_q = nn.Conv2d(in_channels, dim_k * num_heads, 1, bias=False)
            self.to_k = nn.Conv2d(in_channels, dim_k * dim_u, 1, bias=False)
            self.to_v = nn.Conv2d(in_channels, dim_v * dim_u, 1, bias=False)
    
            self.norm_q = nn.BatchNorm2d(dim_k * num_heads)
            self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
    
            self.local_contexts = r is not None
            if r is not None:
                if r % 2 != 1:
                    raise AssertionError('Receptive kernel size should be odd')
                self.padding = r // 2
                self.R = nn.Parameter(torch.randn(dim_k, dim_u, 1, r, r))
            else:
                if n is None:
                    raise AssertionError('You must specify the total sequence length (h x w)')
                self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
    
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            b, c, h, w = x.shape
    
            # Project inputs & context to retrieve queries, keys and values
            q = self.to_q(x)
            k = self.to_k(x)
            v = self.to_v(x)
    
            # Normalize queries & values
            q = self.norm_q(q)
            v = self.norm_v(v)
    
            # B x (num_heads * dim_k) * H * W -> B x num_heads x dim_k x (H * W)
            q = q.reshape(b, self.num_heads, -1, h * w)
            # B x (dim_k * dim_u) * H * W -> B x dim_u x dim_k x (H * W)
            k = k.reshape(b, -1, self.u, h * w).permute(0, 2, 1, 3)
            # B x (dim_v * dim_u) * H * W -> B x dim_u x dim_v x (H * W)
            v = v.reshape(b, -1, self.u, h * w).permute(0, 2, 1, 3)
    
            # Normalized keys
            k = k.softmax(dim=-1)
    
            # Content function
            λc = einsum('b u k m, b u v m -> b k v', k, v)
            Yc = einsum('b h k n, b k v -> b n h v', q, λc)
    
            # Position function
            if self.local_contexts:
                # B x dim_u x dim_v x (H * W) -> B x dim_u x dim_v x H x W
                v = v.reshape(b, self.u, v.shape[2], h, w)
                λp = F.conv3d(v, self.R, padding=(0, self.padding, self.padding))
                Yp = einsum('b h k n, b k v n -> b n h v', q, λp.flatten(3))
            else:
                λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
                Yp = einsum('b h k n, b n k v -> b n h v', q, λp)
    
            Y = Yc + Yp
            # B x (H * W) x num_heads x dim_v -> B x (num_heads * dim_v) x H x W
            out = Y.permute(0, 2, 3, 1).reshape(b, self.num_heads * v.shape[2], h, w)
            return out
    
    opened by johndpope 1
  • Image Size

    Are non-square image blocks allowed for context? Using global context and non-square dimensions (96, 128), I get an error on this line about dimension size.

    λp = einsum('n m k u, b u v m -> b n k v', rel_pos_emb, v)

    opened by anklebreaker 0
  • LambdaResNet Implementation?

    I have been looking around and found one implementation of LambdaResNets, although there seem to be some metric performance problems and I've found wall-clock performance problems (runs ~7x slower than normal resnets).

    Do you plan on putting out a lambdaresnet model in this repository?

    opened by nollied 4
Owner
Phil Wang
Working with Attention. It's all we need.