Implementation of Invariant Point Attention, used for coordinate refinement in the structure module of AlphaFold2, as a standalone PyTorch module

Overview

Invariant Point Attention - PyTorch

Implementation of Invariant Point Attention as a standalone module, which was used in the structure module of AlphaFold2 for coordinate refinement.

  • write up a test for invariance under rotation (a sketch appears after the first usage example below)
  • enforce float32 for certain operations (a hypothetical sketch follows this list)
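
For the float32 item, a minimal, hypothetical sketch of one common approach (not this repository's actual code) is to locally disable autocast and upcast the operands of numerically sensitive steps, e.g. the attention logits:

import torch

def attn_logits_fp32(q, k, scale):
    # hypothetical helper: compute attention logits in float32 even when the
    # surrounding model runs under mixed precision (autocast)
    with torch.autocast(device_type = q.device.type, enabled = False):
        return torch.einsum('b h i d, b h j d -> b h i j', q.float(), k.float()) * scale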

Install

$ pip install invariant-point-attention

Usage

import torch
from einops import repeat
from invariant_point_attention import InvariantPointAttention

attn = InvariantPointAttention(
    dim = 64,                  # single (and pairwise) representation dimension
    heads = 8,                 # number of attention heads
    scalar_key_dim = 16,       # scalar query-key dimension
    scalar_value_dim = 16,     # scalar value dimension
    point_key_dim = 4,         # point query-key dimension
    point_value_dim = 4        # point value dimension
)

single_repr   = torch.randn(1, 256, 64)      # (batch x seq x dim)
pairwise_repr = torch.randn(1, 256, 256, 64) # (batch x seq x seq x dim)
mask          = torch.ones(1, 256).bool()    # (batch x seq)

rotations     = repeat(torch.eye(3), '... -> b n ...', b = 1, n = 256)  # (batch x seq x rot1 x rot2) - example is identity
translations  = torch.zeros(1, 256, 3) # translation, also identity for example

attn_out = attn(
    single_repr,
    pairwise_repr,
    rotations = rotations,
    translations = translations,
    mask = mask
)

attn_out.shape # (1, 256, 64)
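
The to-do list above mentions a test for invariance under rotation. Below is a minimal sketch of such a check; it assumes that frames act on points as x @ R + t, so a global rigid motion composes on the right of the per-residue rotations and translations, and the tolerance is chosen loosely.

import torch
from einops import repeat
from invariant_point_attention import InvariantPointAttention

attn = InvariantPointAttention(
    dim = 64,
    heads = 8,
    scalar_key_dim = 16,
    scalar_value_dim = 16,
    point_key_dim = 4,
    point_value_dim = 4
)

single_repr   = torch.randn(1, 256, 64)
pairwise_repr = torch.randn(1, 256, 256, 64)
mask          = torch.ones(1, 256).bool()

rotations     = repeat(torch.eye(3), 'r1 r2 -> b n r1 r2', b = 1, n = 256)
translations  = torch.zeros(1, 256, 3)

# random proper rotation from the QR decomposition of a Gaussian matrix
# (negating a 3 x 3 orthogonal matrix flips the sign of its determinant)
q, _ = torch.linalg.qr(torch.randn(3, 3))
random_rot   = q if torch.det(q) > 0 else -q
random_trans = torch.randn(3)

out = attn(single_repr, pairwise_repr, rotations = rotations, translations = translations, mask = mask)

out_transformed = attn(
    single_repr,
    pairwise_repr,
    rotations = rotations @ random_rot,
    translations = translations @ random_rot + random_trans,
    mask = mask
)

# the scalar output should be unchanged by a global rigid motion of the frames
assert torch.allclose(out, out_transformed, atol = 1e-5)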

You can also use this module without the pairwise representations, which are very specific to the AlphaFold2 architecture.

import torch
from einops import repeat
from invariant_point_attention import InvariantPointAttention

attn = InvariantPointAttention(
    dim = 64,
    heads = 8,
    require_pairwise_repr = False   # set this to False to use the module without pairwise representations
)

seq           = torch.randn(1, 256, 64)
mask          = torch.ones(1, 256).bool()

rotations     = repeat(torch.eye(3), '... -> b n ...', b = 1, n = 256)
translations  = torch.randn(1, 256, 3)

attn_out = attn(
    seq,
    rotations = rotations,
    translations = translations,
    mask = mask
)

attn_out.shape # (1, 256, 64)

You can also use one IPA-based transformer block, which is IPA followed by a feedforward. By default it will use post-layernorm, as done in the official code, but you can also try pre-layernorm by setting post_norm = False.

import torch
from torch import nn
from einops import repeat
from invariant_point_attention import IPABlock

block = IPABlock(
    dim = 64,
    heads = 8,
    scalar_key_dim = 16,
    scalar_value_dim = 16,
    point_key_dim = 4,
    point_value_dim = 4
)

seq           = torch.randn(1, 256, 64)
pairwise_repr = torch.randn(1, 256, 256, 64)
mask          = torch.ones(1, 256).bool()

rotations     = repeat(torch.eye(3), 'r1 r2 -> b n r1 r2', b = 1, n = 256)
translations  = torch.randn(1, 256, 3)

block_out = block(
    seq,
    pairwise_repr = pairwise_repr,
    rotations = rotations,
    translations = translations,
    mask = mask
)

updates = nn.Linear(64, 6)(block_out)
quaternion_update, translation_update = updates.chunk(2, dim = -1) # (1, 256, 3), (1, 256, 3)

# apply updates to rotations and translations for the next iteration
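
What "apply updates" means depends on the frame conventions. Continuing the example above, the following is a hedged sketch only: it assumes (1, a, b, c) style quaternions (see the comments further down) and frames acting on points as x @ R + t, with the update applied in the local frame before the existing frame. quaternion_to_matrix is a hypothetical helper, not part of this package, and the official structure module may compose the update differently.

import torch
import torch.nn.functional as F

def quaternion_to_matrix(quat):
    # hypothetical helper: real-first quaternion (..., 4), assumed normalized,
    # converted to a rotation matrix (..., 3, 3)
    w, x, y, z = quat.unbind(dim = -1)
    rot = torch.stack((
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z),     2 * (x * z + w * y),
        2 * (x * y + w * z),     1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y),     2 * (y * z + w * x),     1 - 2 * (x * x + y * y)
    ), dim = -1)
    return rot.reshape(*quat.shape[:-1], 3, 3)

# treat the 3-vector update as the imaginary part of a (1, a, b, c) quaternion, then normalize
quat = torch.cat((torch.ones_like(quaternion_update[..., :1]), quaternion_update), dim = -1)
rot_update = quaternion_to_matrix(F.normalize(quat, dim = -1))   # (1, 256, 3, 3)

# compose the update (local frame first, then the existing frame) - the order is an assumption
new_rotations    = torch.einsum('b n i j, b n j k -> b n i k', rot_update, rotations)
new_translations = torch.einsum('b n j, b n j k -> b n k', translation_update, rotations) + translations

# detach the rotations before feeding them into the next iteration
# (see the training-instability comment further down)
new_rotations = new_rotations.detach()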

Citations

@Article{AlphaFold2021,
    author  = {Jumper, John and Evans, Richard and Pritzel, Alexander and Green, Tim and Figurnov, Michael and Ronneberger, Olaf and Tunyasuvunakool, Kathryn and Bates, Russ and {\v{Z}}{\'\i}dek, Augustin and Potapenko, Anna and Bridgland, Alex and Meyer, Clemens and Kohl, Simon A A and Ballard, Andrew J and Cowie, Andrew and Romera-Paredes, Bernardino and Nikolov, Stanislav and Jain, Rishub and Adler, Jonas and Back, Trevor and Petersen, Stig and Reiman, David and Clancy, Ellen and Zielinski, Michal and Steinegger, Martin and Pacholska, Michalina and Berghammer, Tamas and Bodenstein, Sebastian and Silver, David and Vinyals, Oriol and Senior, Andrew W and Kavukcuoglu, Koray and Kohli, Pushmeet and Hassabis, Demis},
    journal = {Nature},
    title   = {Highly accurate protein structure prediction with {AlphaFold}},
    year    = {2021},
    doi     = {10.1038/s41586-021-03819-2},
    note    = {(Accelerated article preview)},
}
Comments
  • Computing point dist - use cartesian dimension instead of hidden dimension

    https://github.com/lucidrains/invariant-point-attention/blob/2f1fb7ca003d9c94d4144d1f281f8cbc914c01c2/invariant_point_attention/invariant_point_attention.py#L130

    I think it should be dim=-1, thus summing over the cartesian (xyz) axis, rather than dim=-2, which sums over the hidden dimension (a small illustration follows this item).

    opened by aced125 3
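
    As an editorial illustration of the distinction (the shapes here are hypothetical, not the module's internals): summing the squared differences over the last, cartesian (xyz) axis gives a true squared distance per point, whereas summing over the second-to-last axis would mix across the points instead.

    import torch

    # hypothetical shapes: (batch, heads, seq_i, seq_j, num_points, xyz)
    q_points = torch.randn(1, 8, 16, 1, 4, 3)
    k_points = torch.randn(1, 8, 1, 16, 4, 3)

    point_dist = ((q_points - k_points) ** 2).sum(dim = -1)   # (1, 8, 16, 16, 4): squared distance per point
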
  • In-place rotation detach not allowed

    Hi, this is probably highly version-dependent (I have pytorch=1.11.0, pytorch3d=0.7.0 nightly), but I thought I'd report it. Torch doesn't like the in-place detach of the rotation tensor. Full stack trace (from denoise.py):

    Traceback (most recent call last):
      File "denoise.py", line 56, in <module>
        denoised_coords = net(
      File "/home/pi-user/miniconda3/envs/piai/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/pi-user/invariant-point-attention/invariant_point_attention/invariant_point_attention.py", line 336, in forward
        rotations.detach_()
    RuntimeError: Can't detach views in-place. Use detach() instead. If you are using DistributedDataParallel (DDP) for training, and gradient_as_bucket_view is set as True, gradients are views of DDP buckets, and hence detach_() cannot be called on these gradients. To fix this error, please refer to the Optimizer.zero_grad() function in torch/optim/optimizer.py as the solution.
    

    Switching to rotations = rotations.detach() seems to behave correctly (tested in denoise.py and my own code). I'm not totally sure whether this allocates a separate tensor or just creates a new node pointing to the same data (a small standalone check follows this item).

    opened by sidnarayanan 1
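
    For reference, detach() does not copy the data: it returns a new tensor that shares storage with the original but is cut out of the autograd graph, as the small standalone check below shows.

    import torch

    x = torch.randn(3, requires_grad = True)
    y = x.detach()

    print(y.requires_grad)                # False - no longer part of the autograd graph
    print(y.data_ptr() == x.data_ptr())   # True  - the storage is shared, no copy is made
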
  • Report a bug that causes instability in training

    Hi, I would like to report a bug in the rotation update that causes instability in training: https://github.com/lucidrains/invariant-point-attention/blob/de337568959eb7611ba56eace2f642ca41e26216/invariant_point_attention/invariant_point_attention.py#L322

    The IPA Transformer is similar to the structure module in AF2, where recycling is used. Note that the gradient of the rotation is usually detached; otherwise the gradient keeps updating the rotation during backpropagation, which in our experiments leads to instability. Therefore we usually detach the rotation to remove this effect of gradient descent. I have seen you do this in your alphafold2 repo (https://github.com/lucidrains/alphafold2).

    If you think this is a problem, please let me know. I am happy to submit a pr to fix that.

    Best, Zhangzhi Peng

    opened by pengzhangzhi 1
  • Subtle mistake in the implementation

    Hi. Thanks for your implementation, it is very helpful. However, I noticed that the dropout is missing in the IPA module.

    https://github.com/lucidrains/invariant-point-attention/blob/de337568959eb7611ba56eace2f642ca41e26216/invariant_point_attention/invariant_point_attention.py#L239

    In the AlphaFold2 supplementary, the dropout is nested inside the layer norm, which also holds true for the layer norm at the transition layer (line 9 of the algorithm figure attached to the issue, not reproduced here).

    If you think this is a problem, please let me know. I will submit a pr to fix it. Thanks again for sharing such an amazing repo.

    Best, Zhangzhi Peng

    opened by pengzhangzhi 1
  • change quaternion update to match the original AlphaFold2

    In the original AlphaFold2 IPA module, a pure-quaternion (no real part) description is used for the quaternion update, which can be broken down into a residual-update-like formulation. But this code uses (1, a, b, c) style quaternions, so I believe the quaternion update should be done as a simple quaternion multiplication. As far as I have tested, the loss goes down more efficiently with this modification.

    opened by ShintaroMinami 1
  • #126 - maybe the 'self.point_attn_logits_scale' was omitted?

    Hi luci:

    I read the original paper and compared it to your implementation, and found one place that might be a mistake:

    #126. attn_logits_points = -0.5 * (point_dist * point_weights).sum(dim = -1),

    I thought it should be attn_logits_points = -0.5 * (point_dist * point_weights * self.point_attn_logits_scale).sum(dim = -1)

    Thanks for sharing!

    opened by CiaoHe 1
  • Application of Invariant Point Attention: preserving part of a structure

    Hi lucidrains, first of all, thanks a lot for your work!

    I have a question: how can I change (denoise) the structure only in a region I choose? How do I do that with denoise.py?

    opened by hw-protein 0
  • Equivariance test for IPA Transformer

    @lucidrains I would like to ask about the equivariance of the transformer (not the IPA blocks). I wonder whether you checked that the output is equivariant when the local points are transformed to global points using the updated quaternions and translations. I am not sure why this test fails in my case.

    opened by amrhamedp 1
Owner

Phil Wang (lucidrains) - Working with Attention