RETRO-pytorch - Implementation of RETRO, Deepmind's Retrieval based Attention net, in Pytorch

Overview

RETRO - Pytorch (wip)

Implementation of RETRO, Deepmind's Retrieval based Attention net, in Pytorch. This will deviate from the paper slightly, using rotary embeddings for relative positional encoding, as well as the Faiss library instead of ScaNN.

If you are interested, please join this Discord for discussions.

Install

$ pip install retro-pytorch

Usage

import torch
from retro_pytorch import RETRO

retro = RETRO(
    num_tokens = 20000,                      # number of tokens
    chunk_size = 64,                         # the chunk size that is indexed and retrieved (needed for proper relative positions as well as causal chunked cross attention)
    max_seq_len = 2048,                      # max sequence length
    enc_dim = 896,                           # encoder model dim
    enc_depth = 2,                           # encoder depth
    dec_dim = 796,                           # decoder model dim
    dec_depth = 12,                          # decoder depth
    dec_cross_attn_layers = (3, 6, 9, 12),   # decoder cross attention layers (with causal chunk cross attention)
    heads = 8,                               # attention heads
    dim_head = 64,                           # dimension per head
    dec_attn_dropout = 0.25,                 # decoder attention dropout
    dec_ff_dropout = 0.25                    # decoder feedforward dropout
)

seq = torch.randint(0, 20000, (2, 2048 + 1))      # plus one since it is split into input and labels for training
retrieved = torch.randint(0, 20000, (2, 32, 2, 128)) # retrieved tokens - (batch, num chunks, num retrieved neighbors, retrieved chunk with continuation)

loss = retro(seq, retrieved, return_loss = True)
loss.backward()

# do above for many steps
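The shapes in the example above follow from the configuration: 2048 / 64 = 32 chunks per sequence, and each retrieved neighbor holds a chunk plus its continuation, 2 * 64 = 128 tokens. A minimal sketch of a helper that builds random inputs of the right shape for other settings might look like this. `mock_retro_batch` is a hypothetical name, not part of retro-pytorch, and it assumes the continuation is always one chunk long, as in the example.

```python
import torch

def mock_retro_batch(batch, seq_len, chunk_size, num_neighbors, num_tokens):
    # hypothetical helper: the sequence must split evenly into chunks for chunked cross attention
    assert seq_len % chunk_size == 0, "seq_len must be a multiple of chunk_size"
    num_chunks = seq_len // chunk_size
    # +1 token because the sequence is split into inputs and shifted labels
    seq = torch.randint(0, num_tokens, (batch, seq_len + 1))
    # each neighbor = retrieved chunk + its continuation (assumed to be one chunk long)
    retrieved = torch.randint(0, num_tokens, (batch, num_chunks, num_neighbors, 2 * chunk_size))
    return seq, retrieved

seq, retrieved = mock_retro_batch(batch=2, seq_len=2048, chunk_size=64, num_neighbors=2, num_tokens=20000)
print(tuple(seq.shape), tuple(retrieved.shape))  # (2, 2049) (2, 32, 2, 128)
```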

Todo

  • handle indexing of corpus of text with faiss
  • handle reindexing of all nearest neighbors
  • function for getting frozen BERT embeddings for batch of chunks
  • handle partially filled chunks with mask and null tokens as a safeguard
  • inference code, autoretrieving at chunk boundaries
  • autohandle retrieved chunks for last chunk in sequence, whether it is given or not

Citations

@misc{borgeaud2022improving,
    title   = {Improving language models by retrieving from trillions of tokens},
    author  = {Sebastian Borgeaud and Arthur Mensch and Jordan Hoffmann and Trevor Cai and Eliza Rutherford and Katie Millican and George van den Driessche and Jean-Baptiste Lespiau and Bogdan Damoc and Aidan Clark and Diego de Las Casas and Aurelia Guy and Jacob Menick and Roman Ring and Tom Hennigan and Saffron Huang and Loren Maggiore and Chris Jones and Albin Cassirer and Andy Brock and Michela Paganini and Geoffrey Irving and Oriol Vinyals and Simon Osindero and Karen Simonyan and Jack W. Rae and Erich Elsen and Laurent Sifre},
    year    = {2022},
    eprint  = {2112.04426},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

I consider always the adult life to be the continuous retrieval of childhood. - Umberto Eco

Comments
  • Error Reconstructing FAISS Index

    Hiya! Thanks for making this library out in the open!

    I've been trying to get your training wrapper working, but when it tries to generate the index, I get the following error:

    RuntimeError: Error in virtual void faiss::Index::reconstruct(faiss::Index::idx_t, float*) const at /project/faiss/faiss/Index.cpp:48: reconstruct not implemented for this type of index
    

    To reproduce, you can use this google colab: https://colab.research.google.com/drive/1BcOtBpWBGmXX_tOC7WKcHOa9SukWEKpf?usp=sharing

    Any help with this would be greatly appreciated!

    opened by ncoop57 18
  • Why are there so many position embeddings?

    Hi! Thanks for your great work; it's very helpful for my project! I was just curious why there are so many position embeddings. It looks like an absolute (1 to n) position embedding is first added to the sequence in the RETRO class, and then rotary embeddings are applied again in each attention module. I thought the rotary embeddings in the Attention and CCA modules alone would be enough. Thanks in advance!

    opened by jasperhyp 5
  • `doc_ids_memmap` shape

    https://github.com/lucidrains/RETRO-pytorch/blob/7d305379b72232c54262742d3f80326ed5fdac9e/retro_pytorch/retrieval.py#L138

    Is there a reason doc_ids_memmap is shape (max_docs, )? Shouldn't it be (max_chunks, ) since it's mapping chunks to doc ids?

    opened by josephcappadona 5
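The reasoning in this question can be checked with a toy example: an array that maps chunks to document ids needs one slot per chunk, so its length is the total number of chunks rather than the number of documents. The corpus below is hypothetical.

```python
import numpy as np

# hypothetical corpus: number of chunks produced by each of three documents
chunks_per_doc = [3, 1, 2]

# one entry per chunk, holding the id of the document that chunk came from
doc_ids = np.repeat(np.arange(len(chunks_per_doc)), chunks_per_doc)

print(doc_ids)        # [0 0 0 1 2 2]
print(doc_ids.shape)  # (6,)
```

A mapping of this form is what lets neighbors from the same document be filtered out, since any chunk index can be looked up directly.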
  • rotary embedding question

    I have two questions about the rotary embedding implementation.

    1. Divide the d-dimensional space into d/2 sub-spaces

    In rotary embedding, head_dim is divided by 2 to utilize the conjugate space with sin and cos.

    from rotary_embedding_torch import RotaryEmbedding
    
    head_dim = 64
    rotary_emb = RotaryEmbedding(dim=head_dim)
    
    class RotaryEmbedding(nn.Module):
        def __init__(
            self,
            dim,
            custom_freqs = None,
            freqs_for = 'lang',
            theta = 10000,
            max_freq = 10,
            num_freqs = 1,
            learned_freq = False
        ):
            super().__init__()
            if exists(custom_freqs):
                freqs = custom_freqs
            elif freqs_for == 'lang':
                # freqs.shape == (head_dim // 2)
                freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
            ...
    

    However, the rotary freqs in RETRO look odd. RETRO's Encoder and Decoder divide head_dim by 2 in advance and pass the result to the rotary embedding as its dim.

    https://github.com/lucidrains/RETRO-pytorch/blob/4f99e316458fb13a5e4f881586f8436458cf4ead/retro_pytorch/retro_pytorch.py#L380-L381

    The initializer then halves it once more (arange with a step of 2), as shown below.

    https://github.com/lucidrains/RETRO-pytorch/blob/4f99e316458fb13a5e4f881586f8436458cf4ead/retro_pytorch/retro_pytorch.py#L92-L96

    In this way, when head_dim=48, freqs ends up with only 12 entries.

    Because the apply_rotary_emb function concatenates the part of the tensor beyond rot_dim back on unchanged, the resulting tensor has the same shape, but the rotary positions do not seem to be applied to all dimensions.

    Hence, I think you need to modify the two lines of code as below.

    • https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py#L76
      • The resulting tensor has the same shape.
    >>> ASIS
                freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    <<< TOBE
                freqs = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
    
    • https://github.com/lucidrains/RETRO-pytorch/blob/main/retro_pytorch/retro_pytorch.py#L95
      • As shown in the confirmation code below, the above modification is the same as the existing rotary embedding implementation.
      import torch
      hid_dim, n_heads = 384, 8  # example values (not given in the issue), so head_dim = 48
      dim1 = hid_dim // n_heads
      dim2 = (hid_dim // n_heads) // 2
      freqs1 = 1. / (10000 ** (torch.arange(0, dim1, 2).float() / dim1))
      freqs2 = 1. / (10000 ** (torch.arange(0, dim2, 1).float() / dim2))
      assert torch.equal(freqs1, freqs2)
      
    >>> ASIS
            inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    <<< TOBE
            inv_freq = 1. / (10000 ** (torch.arange(0, dim, 1).float() / dim))
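The shape argument in this issue can be checked with a small sketch. It reimplements only the frequency formula quoted above; `inv_freqs` is a hypothetical helper, not a function from either library. It assumes, as the issue describes, that head_dim is halved before being passed in, and that the frequencies are then concatenated with themselves to give the rotated width.

```python
import torch

def inv_freqs(dim, step=2, theta=10000):
    # rotary inverse frequencies, using the formula quoted in this issue
    return 1. / (theta ** (torch.arange(0, dim, step).float() / dim))

head_dim = 48

# as the issue describes: head_dim is halved before being passed in as `dim`
freqs = inv_freqs(head_dim // 2)
# assumption: frequencies are concatenated with themselves to form the rotated width
rot_dim = torch.cat((freqs, freqs)).shape[-1]
print(freqs.shape, rot_dim)  # torch.Size([12]) 24

# the proposed TOBE change (step of 1 on the halved dim) matches the full-head_dim frequencies
print(torch.allclose(inv_freqs(head_dim // 2, step=1), inv_freqs(head_dim)))  # True
```

Under these assumptions, only 24 of the 48 head dimensions are rotated with the current formula, while the proposed change would rotate all 48.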
    

    2. rotate_half function

    The rotate_half implementations of RETRO-pytorch and rotary-embedding-torch are slightly different.

    import torch
    from einops import rearrange
    
    # In rotary-embedding-torch
    # https://github.com/lucidrains/rotary-embedding-torch/blob/517ee2cfeb10602032ef9d282c19851e19dd8943/rotary_embedding_torch/rotary_embedding_torch.py#L34
    def rotate_half(x):
        x = rearrange(x, '... (d r) -> ... d r', r = 2)
        x1, x2 = x.unbind(dim = -1)
        x = torch.stack((-x2, x1), dim = -1)
        return rearrange(x, '... d r -> ... (d r)')
    
    # In RETRO-pytorch
    # https://github.com/lucidrains/RETRO-pytorch/blob/4f99e316458fb13a5e4f881586f8436458cf4ead/retro_pytorch/retro_pytorch.py#L104
    def rotate_half(x):
        x = rearrange(x, '... (j d) -> ... j d', j = 2)
        x1, x2 = x.unbind(dim = -2)
        return torch.cat((-x2, x1), dim = -1)
    

    In rotary-embedding-torch, the two rotated halves are interleaved as [0 1 0 1 0 1 0 1], whereas in RETRO-pytorch they are concatenated as [0 0 0 0 1 1 1 1].

    • [0 0 0 0] is pre-half
    • [1 1 1 1] is post-half
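The two layouts can be compared with einops-free rewrites of the functions quoted above; the names `rotate_half_interleaved` and `rotate_half_split` are hypothetical. The first rotates adjacent pairs (x0, x1), (x2, x3), ..., while the second pairs dimension i with dimension i + d/2.

```python
import torch

def rotate_half_interleaved(x):
    # rotary-embedding-torch layout: (x0, x1), (x2, x3), ... -> (-x1, x0), (-x3, x2), ...
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def rotate_half_split(x):
    # RETRO-pytorch layout: halves [a | b] -> [-b | a], pairing dim i with dim i + d/2
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

x = torch.arange(8.)
print(rotate_half_interleaved(x).tolist())  # [-1.0, 0.0, -3.0, 2.0, -5.0, 4.0, -7.0, 6.0]
print(rotate_half_split(x).tolist())        # [-4.0, -5.0, -6.0, -7.0, 0.0, 1.0, 2.0, 3.0]
```

Either convention gives a valid rotary embedding as long as the cos/sin tables use the matching frequency layout: each frequency repeated pairwise for the interleaved version, or the frequency vector concatenated with itself for the split-half version.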

    I wonder why it was implemented with this change! (just curious)

    I am studying the paper by matching it against your implementation. Thank you as always :)

    opened by jinmang2 3
  • Autoregressivity

    I had a question about Figure 2 and equation 3 from the paper. Why does letting the last token of each chunk C_u attend to the retrieved content E_u not break autoregressivity?

    opened by sdpmas 3
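Based on the paper's description of chunked cross-attention (hidden states are shifted by m - 1, so the last token of C_u and the first m - 1 tokens of C_{u+1} attend to E_u, which is retrieved using C_u), the alignment can be sketched as below. `attended_neighbor_chunk` is a hypothetical helper using 0-indexed positions and chunk size m; it shows that a token only ever attends to neighbors retrieved from a chunk it has already fully seen.

```python
def attended_neighbor_chunk(i, m):
    """Index u of the retrieval set E_u that token i (0-indexed) attends to,
    or -1 if it attends to none, for chunk size m."""
    return (i + 1) // m - 1

m = 4
print([attended_neighbor_chunk(i, m) for i in range(12)])
# [-1, -1, -1, 0, 0, 0, 0, 1, 1, 1, 1, 2]

# E_u depends only on chunk C_u, whose last token sits at position (u + 1) * m - 1
for i in range(12):
    u = attended_neighbor_chunk(i, m)
    if u >= 0:
        assert (u + 1) * m - 1 <= i  # the whole of C_u is already in the past
```

The first m - 1 tokens have no earlier chunk to draw on, which is why they attend to nothing.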
  • Extra layer encoder_output_to_decoder_dim causes an issue with distributed training

    Hiya, hope Ice Cream is doing well, as well as you of course!

    I've been trying to get distributed training working with your library, and I discovered that this additional Linear layer, encoder_output_to_decoder_dim, is not used anywhere:

    https://github.com/lucidrains/RETRO-pytorch/blob/main/retro_pytorch/retro_pytorch.py#L491

    It seems to be a copy of the layer defined right above it, to_decoder_model_dim, which does get used. Having this extra layer that is not part of the loss calculation causes the following error with data parallelism:

    RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. (https://github.com/pytorch/pytorch/issues/43259#)
    

    Not sure if this layer is supposed to be there and it just didn't get used or if it is there by accident, so wanted to ask 🤓

    opened by ncoop57 2
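DistributedDataParallel expects every parameter to receive a gradient each step, and a layer that never enters the forward pass never does. A minimal sketch for spotting such parameters on a single process: `ToyModel` is a hypothetical stand-in that only reuses the layer names from this issue, and `params_without_grad` is a hypothetical helper.

```python
import torch
from torch import nn

class ToyModel(nn.Module):
    """Hypothetical stand-in: `unused_proj` is defined but never called in forward()."""
    def __init__(self, dim=4):
        super().__init__()
        self.to_decoder_model_dim = nn.Linear(dim, dim)
        self.unused_proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.to_decoder_model_dim(x)

def params_without_grad(model):
    # after backward(), parameters that never entered the graph still have grad None
    return [name for name, p in model.named_parameters() if p.requires_grad and p.grad is None]

model = ToyModel()
model(torch.randn(2, 4)).sum().backward()
unused = params_without_grad(model)
print(unused)  # ['unused_proj.weight', 'unused_proj.bias']
```

Constructing DistributedDataParallel with find_unused_parameters=True is the usual workaround when some parameters legitimately skip a step, at some extra cost per iteration; deleting a layer that is never used avoids the problem outright.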
  • Question about the right position to encode `retrieved`

    Hi, I am currently reading through the code and got confused when I reached this line:

    https://github.com/lucidrains/RETRO-pytorch/blob/92ff28755df53352547b1868fb03feae9931c1dd/retro_pytorch/retro_pytorch.py#L598

    According to Algorithm 1 in the paper, shouldn't this line go inside the decoder, under this line? https://github.com/lucidrains/RETRO-pytorch/blob/92ff28755df53352547b1868fb03feae9931c1dd/retro_pytorch/retro_pytorch.py#L406

    This is an example of how I think the code of decoder.forward should be.

    def forward(self, x, *, context_mask = None, retrieved = None):
      encoded = False  # flag to know if p = min(P) (in the algorithm)
      ...
        if exists(cross_attn) and exists(retrieved):
          if not encoded:
            ...
            # use x (H at layer p where p = min(P)), not embed (Emb(X))
            x_as_context = repeat(x[:, :seq_index], 'b (k n) d -> (b k r) n d', n = self.chunk_size, r = num_neighbors)
            retrieved = self.encoder(retrieved, mask = encoder_retrieved_mask, chunked_seq = x_as_context)
            encoded = True
    
    opened by soheeyang 2
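The ordering proposed in this question (encode the neighbors once, at the first cross-attention layer p = min(P), conditioned on the decoder hidden state at that layer) can be sketched as plain control flow. This is a hypothetical toy, not RETRO-pytorch's decoder: `decoder_forward` and the blocks passed into it are stand-ins operating on plain numbers so the order of calls is easy to check.

```python
def decoder_forward(x, retrieved, num_layers, cross_attn_layers, self_block, cross_block, encode):
    encoded = None  # retrieval encoder output, computed once at layer p = min(P)
    for layer in range(1, num_layers + 1):
        x = self_block(x)
        if layer in cross_attn_layers:
            if encoded is None:
                # condition the neighbor encoding on the decoder hidden state at this layer
                encoded = encode(retrieved, x)
            x = cross_block(x, encoded)
    return x

# toy stand-ins, recording the hidden state the encoder sees
seen_by_encoder = []

def toy_encode(retrieved, h):
    seen_by_encoder.append(h)
    return retrieved + h

out = decoder_forward(
    x=0, retrieved=100, num_layers=4, cross_attn_layers={2, 4},
    self_block=lambda h: h + 1, cross_block=lambda h, e: h + e, encode=toy_encode,
)
print(out, seen_by_encoder)  # 208 [2]
```

The encoder runs exactly once and sees the hidden state after layer 2, and later cross-attention layers reuse the same encoding.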
  • Confusions about cross attentions in encoder

    In your code https://github.com/lucidrains/RETRO-pytorch/blob/5260d70fae085ed0cc5cbe3e2d1b35947fee475d/retro_pytorch/retro_pytorch.py#L115-L119

    When this class is called by the Encoder, x is the retrieved chunks, so in the attention mechanism it produces the q matrix, but I think it should produce the k and v matrices. In the encoder, the input sequence should only serve to guide attention over the words in the retrieved chunks.

    https://github.com/lucidrains/RETRO-pytorch/blob/5260d70fae085ed0cc5cbe3e2d1b35947fee475d/retro_pytorch/retro_pytorch.py#L288-L294

    opened by Hi-archers 2
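Attention returns one output row per query, so whichever sequence supplies q is the one being re-encoded, and the other only supplies keys and values. A minimal single-head sketch (no masking; the names and shapes are hypothetical) makes the two options in this question concrete.

```python
import torch
import torch.nn.functional as F

def cross_attend(query_seq, context_seq, w_q, w_k, w_v):
    # queries come from one sequence, keys/values from the other
    q = query_seq @ w_q
    k = context_seq @ w_k
    v = context_seq @ w_v
    attn = F.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
    return attn @ v  # one output row per query

torch.manual_seed(0)
dim = 16
retrieved = torch.randn(128, dim)  # hypothetical: one neighbor (chunk + continuation)
chunk = torch.randn(64, dim)       # hypothetical: hidden states of the input chunk
w_q, w_k, w_v = (torch.randn(dim, dim) for _ in range(3))

print(cross_attend(retrieved, chunk, w_q, w_k, w_v).shape)  # torch.Size([128, 16])
print(cross_attend(chunk, retrieved, w_q, w_k, w_v).shape)  # torch.Size([64, 16])
```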
  • 'NoneType' object is not callable

    When I run the "RETRO Datasets" example, I get a TypeError:

    Traceback (most recent call last):
      File "/home/fgq/all/RETRO/fuxian_2.py", line 58, in <module>
        retro = RETRO(
      File "/home/fgq/all/RETRO/retro_pytorch/retro_pytorch.py", line 507, in __init__
        self.encoder = Encoder(
      File "/home/fgq/all/RETRO/retro_pytorch/retro_pytorch.py", line 337, in __init__
        wrapper(Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, causal = causal)),
      File "/home/fgq/all/RETRO/retro_pytorch/retro_pytorch.py", line 73, in __init__
        self.norm = norm_klass(dim)
    TypeError: 'NoneType' object is not callable

    Code:

    save_memmap('./train.chunks.dat', np.int32(np.random.randint(0, 8192, size = (NUM_CHUNKS, CHUNK_SIZE + 1))))

    # generate nearest neighbors for each chunk
    save_memmap('./train.chunks.knn.dat', np.int32(np.random.randint(0, 1000, size = (NUM_CHUNKS, NUM_NEIGHBORS))))

    # generate seq data
    save_memmap('./train.seq.dat', np.int32(np.random.randint(0, 128, size = (NUM_SEQS,))))

    # instantiate dataset class
    train_ds = RETRODataset(
        num_sequences = NUM_SEQS,
        num_chunks = NUM_CHUNKS,
        num_neighbors = NUM_NEIGHBORS,
        chunk_size = CHUNK_SIZE,
        seq_len = 2048,
        chunk_memmap_path = './train.chunks.dat',
        chunk_nn_memmap_path = './train.chunks.knn.dat',
        seq_memmap_path = './train.seq.dat'
    )

    opened by f-guoqiang 1
  • Fix reconstruction error discussed in #15

    This PR fixes the issue with reconstruction of the faiss index. One caveat is that we can no longer do memmapping to reduce RAM overhead. Maybe this will be fixed in faiss soon, but for now memory will be an issue for extremely large indices.

    opened by ncoop57 1
  • Update retrieval.py

    The build_index command

    The autofaiss documentation describes "--embeddings" as: "Source path of the directory containing your .npy embedding files. If there are several files, they are read in the lexicographical order. This can be a local path or a path in another Filesystem e.g. hdfs://root/… or s3://…"

    build_index reads the embedding files in lexicographical order, but they are currently saved as "0.npy, 1.npy, 2.npy, ..., n.npy", so build_index reads them in the order "0.npy, 1.npy, 10.npy, ..., 2.npy, ..., n.npy". To make build_index work correctly, I zero-pad the embedding file names.

    opened by Hi-archers 1
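The ordering problem described in this PR can be reproduced with Python's default string sort, assuming it matches the lexicographical order autofaiss documents (character-by-character comparison). The file names are hypothetical, and the 5-digit width is an arbitrary choice for illustration, not necessarily the width used in the PR.

```python
# hypothetical embedding file names, saved without padding: 0.npy, 1.npy, ..., 11.npy
names = [f"{i}.npy" for i in range(12)]
print(sorted(names)[:5])  # ['0.npy', '1.npy', '10.npy', '11.npy', '2.npy']

# zero-padding to a fixed width makes lexicographic order match numeric order
padded = [f"{i:05d}.npy" for i in range(12)]
print(sorted(padded)[:3])  # ['00000.npy', '00001.npy', '00002.npy']
```

Any fixed width works as long as it has at least as many digits as the largest file index.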
  • Causal mask in Chunked Cross Attention

    When computing the chunked cross-attention (line 259 here: https://github.com/lucidrains/RETRO-pytorch/blob/main/retro_pytorch/retro_pytorch.py), a causal mask is used. If I'm not mistaken, the code applies the causal mask to the last dimension of x (the last words). However, my understanding was that the mask should be applied to the first dimensions, as in the figure in the repo.

    Is this intended?

    opened by Jonor127-OP 0
  • How to give Prompt to trained RETRO Model?

    I am following the instructions in the RETRO-pytorch GitHub repo. After training my model, how do I use it to generate responses?

    import torch
    from retro_pytorch import RETRO, TrainingWrapper
    
    retro = RETRO(
        chunk_size = 64,                         # the chunk size that is indexed and retrieved (needed for proper relative positions as well as causal chunked cross attention)
        max_seq_len = 2048,                      # max sequence length
        enc_dim = 896,                           # encoder model dim
        enc_depth = 2,                           # encoder depth
        dec_dim = 796,                           # decoder model dim
        dec_depth = 12,                          # decoder depth
        dec_cross_attn_layers = (3, 6, 9, 12),   # decoder cross attention layers (with causal chunk cross attention)
        heads = 8,                               # attention heads
        dim_head = 64,                           # dimension per head
        dec_attn_dropout = 0.25,                 # decoder attention dropout
        dec_ff_dropout = 0.25,                   # decoder feedforward dropout
        use_deepnet = True                       # turn on post-normalization with DeepNet residual scaling and initialization, for scaling to 1000 layers
    )
    
    seq = torch.randint(0, 20000, (2, 2048 + 1))      # plus one since it is split into input and labels for training
    retrieved = torch.randint(0, 20000, (2, 32, 2, 128)) # retrieved tokens - (batch, num chunks, num retrieved neighbors, retrieved chunk with continuation)
    
    loss = retro(seq, retrieved, return_loss = True)
    loss.backward()
    
    wrapper = TrainingWrapper(
        retro = retro,                                 # retro instance to train
        knn = 2,                                       # knn (2 in paper was sufficient)
        chunk_size = 64,                               # chunk size (64 in paper)
        documents_path = './retro_training_set/',              # path to folder of text
        glob = '**/*.txt',                             # text glob
        chunks_memmap_path = './train.chunks.dat',     # path to chunks
        seqs_memmap_path = './train.seq.dat',          # path to sequence data
        doc_ids_memmap_path = './train.doc_ids.dat',   # path to document ids per chunk (used for filtering neighbors belonging to same document)
        max_chunks = 1_000_000,                        # maximum cap to chunks
        max_seqs = 100_000,                            # maximum seqs
        knn_extra_neighbors = 100,                     # num extra neighbors to fetch
        max_index_memory_usage = '100m',
        current_memory_available = '1G'    
    )
    

    Now when I want to give this model a text input (any prompt), how would I go about doing that? Which method or function would I use? Which model/tokenizer should I use to encode the input prompt and then decode the model output tensor? Is there a method for that?

    Example Prompt: "The movie Dune was released in"

    opened by shahmeer99 1
  • Scann vs faiss

    Could you elaborate on the decision to use faiss instead of scann? In theory scann is open source too, but I'm wondering if you found it easier to get the needed performance from faiss.

    opened by afcruzs 5
  • Clarification on Architecture

    Reading the original paper, I took it that RETRO was a standard transformer (i.e., a 12-layer encoder and a 12-layer decoder) augmented with a DB retrieval system that included a second, smaller (2-layer) encoder for the frozen Bart-encoded neighbors, where the 2-layer encoder acted as a sort of translator between the Bart model and the main transformer.

    Looking at the model here, it looks like there is only the 2 layer retrieval encoder and not a full-size main encoder. Is that correct?

    Going back and re-reading the paper it doesn't seem to explicitly say one way or the other. It seems odd to me that the model would only have the 2 layer retrieval encoder. Not only would this mean that the encoder is only 2 layers but it also means that most decoder layers have no standard cross attention to the encoder, only layers 6, 9, 12 with the new CCA setup.

    Has anyone trained the model from this repo and demonstrated that it can produce the results from the original paper?

    opened by bjascob 0
  • Retro-fitting a pretrained model

    Hey,

    Thank you for your implementation! Is it possible to use your library to "retro-fit" a pretrained model?

    I guess it would mean freezing the model during training, only fine-tuning the retrieval and cross-attention? How would you recommend doing that?

    Thanks!

    opened by dean-sh 6
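The freezing approach guessed at in this question can be sketched in plain PyTorch by toggling requires_grad by parameter name. `ToyRetroFit` and `freeze_all_but` are hypothetical, and the substrings "cross_attn" and "encoder" are illustrative only; the real parameter names of a RETRO-pytorch model would need to be checked (for example by printing `model.named_parameters()`) before relying on them.

```python
import torch
from torch import nn

class ToyRetroFit(nn.Module):
    """Hypothetical stand-in: a 'pretrained' decoder plus new retrieval modules."""
    def __init__(self, num_tokens=100, dim=8):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.decoder_ff = nn.Linear(dim, dim)
        self.cross_attn = nn.Linear(dim, dim)
        self.encoder = nn.Linear(dim, dim)

def freeze_all_but(model, trainable_substrings):
    # a parameter stays trainable only if its name contains one of the substrings
    for name, p in model.named_parameters():
        p.requires_grad = any(s in name for s in trainable_substrings)
    return [name for name, p in model.named_parameters() if p.requires_grad]

model = ToyRetroFit()
trainable = freeze_all_but(model, ("cross_attn", "encoder"))
print(trainable)  # ['cross_attn.weight', 'cross_attn.bias', 'encoder.weight', 'encoder.bias']

# only the unfrozen parameters are handed to the optimizer
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-4)
```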
Releases (v0.3.8a)
Owner: Phil Wang (Working with Attention. It's all we need)