Reformer, the efficient Transformer, in Pytorch

Overview

This is a Pytorch implementation of Reformer https://openreview.net/pdf?id=rkgNKkHtvB

It includes LSH attention, reversible network, and chunking. It has been validated with an auto-regressive task (enwik8).

Open In Colab 32k tokens

Open In Colab 81k tokens with half precision

Install

$ pip install reformer_pytorch

Usage

A simple Reformer language model

# should fit in ~5 GB with 8k tokens

import torch
from reformer_pytorch import ReformerLM

model = ReformerLM(
    num_tokens = 20000,
    dim = 1024,
    depth = 12,
    max_seq_len = 8192,
    heads = 8,
    lsh_dropout = 0.1,
    ff_dropout = 0.1,
    post_attn_dropout = 0.1,
    layer_dropout = 0.1,  # layer dropout from 'Reducing Transformer Depth on Demand' paper
    causal = True,        # auto-regressive or not
    bucket_size = 64,     # average size of qk per bucket, 64 was recommended in paper
    n_hashes = 4,         # 4 is permissible per author, 8 is the best but slower
    emb_dim = 128,        # embedding factorization for further memory savings
    dim_head = 64,        # fix the dimension of each head, making it independent of the embedding dimension and the number of heads
    ff_chunks = 200,      # number of chunks for feedforward layer, make higher if there are memory issues
    attn_chunks = 8,      # process lsh attention in chunks, only way for memory to fit when scaling to 16k tokens
    num_mem_kv = 128,       # persistent learned memory key values, from all-attention paper
    full_attn_thres = 1024, # use full attention if context length is less than set value
    reverse_thres = 1024,   # turn off reversibility for 2x speed for sequence lengths shorter than or equal to the designated value
    use_scale_norm = False,  # use scale norm from 'Transformers without Tears' paper
    use_rezero = False,      # remove normalization and use rezero from 'ReZero is All You Need'
    one_value_head = False,  # use one set of values for all heads from 'One Write-Head Is All You Need'
    weight_tie = False,           # tie parameters of each layer so additional depth costs no extra parameter memory
    weight_tie_embedding = False, # use token embedding for projection of output, some papers report better results
    n_local_attn_heads = 2,       # many papers suggest mixing local attention heads aids specialization and improves on certain tasks
    pkm_layers = (4,7),           # specify layers to use product key memory. paper shows 1 or 2 modules near the middle of the transformer are best
    pkm_num_keys = 128,           # defaults to 128, but can be increased to 256 or 512 as memory allows
    use_full_attn = False    # only turn on this flag to override and use full attention for all sequence lengths, e.g. to compare against LSH and show that it is working
).cuda()

x = torch.randint(0, 20000, (1, 8192)).long().cuda()
y = model(x) # (1, 8192, 20000)
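The ff_chunks option works because the feedforward acts on each position independently. Below is a minimal sketch in plain PyTorch (chunked_feedforward is a hypothetical helper, not the library's API) showing that running a feedforward over sequence chunks gives the same output as running it in one go, while only one chunk's wide hidden activation is built at a time during the forward pass.

```python
import torch
from torch import nn

def chunked_feedforward(ff: nn.Module, x: torch.Tensor, chunks: int) -> torch.Tensor:
    """Apply a position-wise feedforward to `x` one slice of the sequence at a time."""
    if chunks <= 1:
        return ff(x)
    # Split along the sequence dimension (dim=-2); each position is processed
    # independently, so concatenating the pieces reproduces the full result.
    pieces = x.chunk(chunks, dim=-2)
    return torch.cat([ff(piece) for piece in pieces], dim=-2)

torch.manual_seed(0)
dim = 64
ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
x = torch.randn(2, 1024, dim)

full = ff(x)
chunked = chunked_feedforward(ff, x, chunks=8)
print(torch.allclose(full, chunked, atol=1e-5))
```

During training, the memory benefit also depends on not storing these activations for the backward pass, which Reformer gets from recomputation in its reversible layers.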

The Reformer (just a stack of reversible LSH attention)

# should fit in ~5 GB with 8k embeddings

import torch
from reformer_pytorch import Reformer

model = Reformer(
    dim = 512,
    depth = 12,
    max_seq_len = 8192,
    heads = 8,
    lsh_dropout = 0.1,
    causal = True
).cuda()

x = torch.randn(1, 8192, 512).cuda()
y = model(x) # (1, 8192, 512)
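A reversible layer can recompute its inputs from its outputs, so activations do not have to be stored for the backward pass. Below is a minimal sketch of the RevNet-style coupling described in the Reformer paper, where F would be the attention sublayer and G the feedforward. ReversibleBlockSketch is a hypothetical name and the Linear layers are stand-ins, not the library's internals.

```python
import torch
from torch import nn

class ReversibleBlockSketch(nn.Module):
    """y1 = x1 + F(x2); y2 = x2 + G(y1)."""
    def __init__(self, f: nn.Module, g: nn.Module):
        super().__init__()
        self.f, self.g = f, g

    def forward(self, x1, x2):
        y1 = x1 + self.f(x2)
        y2 = x2 + self.g(y1)
        return y1, y2

    @torch.no_grad()
    def inverse(self, y1, y2):
        # Undo the coupling in reverse order to recover the inputs exactly
        # (up to floating point error), without having stored them.
        x2 = y2 - self.g(y1)
        x1 = y1 - self.f(x2)
        return x1, x2

torch.manual_seed(0)
dim = 32
block = ReversibleBlockSketch(nn.Linear(dim, dim), nn.Sequential(nn.Linear(dim, dim), nn.GELU()))
x1, x2 = torch.randn(4, 16, dim), torch.randn(4, 16, dim)
y1, y2 = block(x1, x2)
r1, r2 = block.inverse(y1, y2)
print(torch.allclose(r1, x1, atol=1e-5), torch.allclose(r2, x2, atol=1e-5))
```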

Self Attention with LSH

import torch
from reformer_pytorch import LSHSelfAttention

attn = LSHSelfAttention(
    dim = 128,
    heads = 8,
    bucket_size = 64,
    n_hashes = 8,
    causal = False
)

x = torch.randn(10, 1024, 128)
y = attn(x) # (10, 1024, 128)

LSH (locality sensitive hashing) Attention

import torch
from reformer_pytorch import LSHAttention

attn = LSHAttention(
    bucket_size = 64,
    n_hashes = 16,
    causal = True
)

qk = torch.randn(10, 1024, 128)
v = torch.randn(10, 1024, 128)

out, attn_weights, buckets = attn(qk, v) # out: (10, 1024, 128)
# attn_weights contains the unsorted attention weights, provided return_attn is set to True (costly otherwise)
# buckets will contain the bucket number (post-argmax) of each token of each batch
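The hashing behind those buckets follows the angular LSH scheme from the Reformer paper: project each shared query/key vector with a random matrix R and take h(x) = argmax([xR; -xR]). The sketch below is a plain PyTorch illustration, assuming the number of buckets is roughly seq_len / bucket_size (matching the description of bucket_size as the average bucket size). It is not the library's implementation, and lsh_buckets is a hypothetical name.

```python
import torch

def lsh_buckets(qk: torch.Tensor, n_buckets: int, n_hashes: int, seed: int = 0) -> torch.Tensor:
    """Angular LSH bucket ids, shape (n_hashes, batch, seq_len)."""
    assert n_buckets % 2 == 0, "angular LSH needs an even number of buckets"
    gen = torch.Generator().manual_seed(seed)
    _, _, dim = qk.shape
    # One random projection per hash round: (n_hashes, dim, n_buckets // 2)
    rotations = torch.randn(n_hashes, dim, n_buckets // 2, generator=gen)
    rotated = torch.einsum('bsd,hdr->hbsr', qk, rotations)
    # h(x) = argmax([xR ; -xR]); the argmax ignores positive scaling of x,
    # so vectors pointing in similar directions tend to share a bucket.
    return torch.cat([rotated, -rotated], dim=-1).argmax(dim=-1)

qk = torch.randn(10, 1024, 128)
buckets = lsh_buckets(qk, n_buckets=1024 // 64, n_hashes=4)
print(buckets.shape)  # torch.Size([4, 10, 1024])
```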

Masking

This repository supports masks on the input sequence input_mask (b x i_seq), the context sequence context_mask (b x c_seq), as well as the rarely used full attention matrix itself input_attn_mask (b x i_seq x i_seq), all made compatible with LSH attention. Masks are made of booleans where False denotes masking out prior to the softmax.

The causal triangular mask is all taken care of for you if you set causal = True.

import torch
from reformer_pytorch import ReformerLM

CONTEXT_LEN = 512
SEQ_LEN = 8192

model = ReformerLM(
    num_tokens= 20000,
    dim = 1024,
    depth = 1,
    max_seq_len = SEQ_LEN,
    ff_chunks = 8,
    causal = True
)

c = torch.randn(1, CONTEXT_LEN, 1024)
x = torch.randint(0, 20000, (1, SEQ_LEN)).long()

i_mask = torch.ones(1, SEQ_LEN).bool()
c_mask = torch.ones(1, CONTEXT_LEN).bool()

y = model(x, keys = c, input_mask = i_mask, context_mask = c_mask)
# masking done correctly in LSH attention
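To make the mask convention concrete (True keeps a position, False masks it out before the softmax), here is a minimal sketch for ordinary full attention in plain PyTorch. masked_softmax and causal_mask are hypothetical helper names; this does not show how the library applies masks inside LSH buckets.

```python
import torch

def masked_softmax(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """scores: (b, i, j); mask: bool (b, j), True = keep, False = mask out."""
    # Masked positions get -inf before the softmax, so they receive zero weight.
    scores = scores.masked_fill(~mask[:, None, :], float('-inf'))
    return scores.softmax(dim=-1)

def causal_mask(n: int) -> torch.Tensor:
    """Lower-triangular bool mask: position i may attend to positions j <= i."""
    return torch.ones(n, n, dtype=torch.bool).tril()

torch.manual_seed(0)
scores = torch.randn(1, 4, 4)
input_mask = torch.tensor([[True, True, True, False]])  # last token is padding
weights = masked_softmax(scores, input_mask)
print(weights[0, :, 3])  # all zeros: nothing attends to the padded position

causal_weights = scores.masked_fill(~causal_mask(4), float('-inf')).softmax(dim=-1)
print(causal_weights[0].triu(1).abs().sum())  # tensor(0.): no attention to future positions
```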

Positional Embeddings

Aran has informed me that the Reformer team used axial position embeddings with great results on longer sequences. I tested it out and indeed it works very well! So well in fact that I have decided to make this the default. You can adjust the shape and dimension of the axial embeddings by following the instructions below.

import torch
from reformer_pytorch import ReformerLM

model = ReformerLM(
    num_tokens= 20000,
    dim = 1024,
    depth = 12,
    max_seq_len = 8192,
    ff_chunks = 8,
    attn_chunks = 2,
    causal = True,
    axial_position_shape = (128, 64),  # the shape must multiply up to the max_seq_len (128 x 64 = 8192)
    axial_position_dims = (512, 512)   # the dims must sum up to the model dimensions (512 + 512 = 1024)
)

x = torch.randint(0, 20000, (1, 8192)).long()
y = model(x) # (1, 8192, 20000)

If you would rather use absolute positional embeddings, you can turn them on with the absolute_position_emb = True flag on initialization.
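A minimal sketch of what axial_position_shape and axial_position_dims describe, assuming the per-axis embeddings are broadcast over the grid and concatenated along the feature dimension (which is why the dims must sum to dim). AxialPositionalEmbeddingSketch is a hypothetical name and the 0.02 initialization scale is an arbitrary choice; the library's own module may differ. The payoff is the parameter count: 128 x 512 + 64 x 512 = 98,304 values instead of 8192 x 1024 for a full table.

```python
import torch
from torch import nn

class AxialPositionalEmbeddingSketch(nn.Module):
    """Positions on a (rows, cols) grid, one small table per axis."""
    def __init__(self, shape=(128, 64), dims=(512, 512)):
        super().__init__()
        rows, cols = shape
        self.shape = shape
        self.row_emb = nn.Parameter(torch.randn(rows, 1, dims[0]) * 0.02)
        self.col_emb = nn.Parameter(torch.randn(1, cols, dims[1]) * 0.02)

    def forward(self, seq_len: int) -> torch.Tensor:
        rows, cols = self.shape
        assert seq_len <= rows * cols, "sequence longer than the axial grid"
        # Broadcast each axis table over the full grid, then concatenate features.
        row = self.row_emb.expand(rows, cols, -1)
        col = self.col_emb.expand(rows, cols, -1)
        grid = torch.cat([row, col], dim=-1)            # (rows, cols, d0 + d1)
        return grid.reshape(rows * cols, -1)[:seq_len]  # (seq_len, d0 + d1)

emb = AxialPositionalEmbeddingSketch()
pos = emb(8192)
print(pos.shape)                                 # torch.Size([8192, 1024])
print(sum(p.numel() for p in emb.parameters()))  # 98304
```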

Training

Since version 0.17.0, and some corrections to the reversible network, Reformer Pytorch is compatible with Microsoft's Deepspeed! If you have multiple local GPUs, you can follow the instructions / example here.

Examples

A full Reformer sequence → sequence, say translation

import torch
from reformer_pytorch import ReformerLM

DE_SEQ_LEN = 4096
EN_SEQ_LEN = 4096

encoder = ReformerLM(
    num_tokens = 20000,
    emb_dim = 128,
    dim = 1024,
    depth = 12,
    heads = 8,
    max_seq_len = DE_SEQ_LEN,
    fixed_position_emb = True,
    return_embeddings = True # return output of last attention layer
).cuda()

decoder = ReformerLM(
    num_tokens = 20000,
    emb_dim = 128,
    dim = 1024,
    depth = 12,
    heads = 8,
    max_seq_len = EN_SEQ_LEN,
    fixed_position_emb = True,
    causal = True
).cuda()

x  = torch.randint(0, 20000, (1, DE_SEQ_LEN)).long().cuda()
yi = torch.randint(0, 20000, (1, EN_SEQ_LEN)).long().cuda()

enc_keys = encoder(x)               # (1, 4096, 1024)
yo = decoder(yi, keys = enc_keys)   # (1, 4096, 20000)

A full Reformer image → caption

import torch
from torch.nn import Sequential
from torchvision import models
from reformer_pytorch import Reformer, ReformerLM

resnet = models.resnet50(pretrained=True)
resnet = Sequential(*list(resnet.children())[:-4])

SEQ_LEN = 4096

encoder = Reformer(
    dim = 512,
    depth = 6,
    heads = 8,
    max_seq_len = 4096
)

decoder = ReformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    heads = 8,
    max_seq_len = SEQ_LEN,
    causal = True
)

x  = torch.randn(1, 3, 512, 512)
yi = torch.randint(0, 20000, (1, SEQ_LEN)).long()

visual_emb = resnet(x)
b, c, h, w = visual_emb.shape
visual_emb = visual_emb.view(b, c, h * w).transpose(1, 2) # (b, c, h, w) -> (b, h * w, c)

enc_keys = encoder(visual_emb)
yo = decoder(yi, keys = enc_keys) # (1, 4096, 20000)
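Why the shapes line up: resnet.children()[:-4] keeps conv1, bn1, relu, maxpool, layer1 and layer2 of ResNet-50, which downsample by a total factor of 8 and output 512 channels. A 512 x 512 image therefore becomes a 512 x 64 x 64 feature map, i.e. 64 x 64 = 4096 tokens of size 512, matching the encoder's dim and max_seq_len. The sketch below uses a random stand-in tensor for that feature map (no torchvision needed) to show the channels-last flattening.

```python
import torch

# Stand-in for the truncated ResNet-50 output on a 1 x 3 x 512 x 512 image:
# 512 channels at 1/8 resolution -> 64 x 64 spatial grid.
feature_map = torch.randn(1, 512, 512 // 8, 512 // 8)

b, c, h, w = feature_map.shape
tokens = feature_map.flatten(2).transpose(1, 2)  # (b, h * w, c)
print(tokens.shape)  # torch.Size([1, 4096, 512])

# Equivalent to the view/transpose used in the example, with the batch size b.
same = feature_map.view(b, c, h * w).transpose(1, 2)
print(torch.equal(tokens, same))  # True
```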

Reformer Encoder Decoder Architecture

There is a bug in versions < 0.21.0. Please upgrade to at least 0.21.0 for a working encoder / decoder Reformer.

By popular demand, I have coded up a wrapper that removes a lot of the manual work in writing up a generic Reformer encoder / decoder architecture. To use it, import the ReformerEncDec class. Encoder keyword arguments are passed with an enc_ prefix and decoder keyword arguments with a dec_ prefix. The model dimension (dim) must be passed without a prefix and is shared between the encoder and decoder. The framework also takes care of passing the encoder input mask to the decoder context mask, unless explicitly overridden.

import torch
from reformer_pytorch import ReformerEncDec

DE_SEQ_LEN = 4096
EN_SEQ_LEN = 4096

enc_dec = ReformerEncDec(
    dim = 512,
    enc_num_tokens = 20000,
    enc_depth = 6,
    enc_max_seq_len = DE_SEQ_LEN,
    dec_num_tokens = 20000,
    dec_depth = 6,
    dec_max_seq_len = EN_SEQ_LEN
).cuda()

train_seq_in = torch.randint(0, 20000, (1, DE_SEQ_LEN)).long().cuda()
train_seq_out = torch.randint(0, 20000, (1, EN_SEQ_LEN)).long().cuda()
input_mask = torch.ones(1, DE_SEQ_LEN).bool().cuda()

loss = enc_dec(train_seq_in, train_seq_out, return_loss = True, enc_input_mask = input_mask)
loss.backward()
# learn

# evaluate with the following
eval_seq_in = torch.randint(0, 20000, (1, DE_SEQ_LEN)).long().cuda()
eval_seq_out_start = torch.tensor([[0]]).long().cuda() # assume 0 is id of start token
samples = enc_dec.generate(eval_seq_in, eval_seq_out_start, seq_len = EN_SEQ_LEN, eos_token = 1) # assume 1 is id of stop token
print(samples.shape) # (1, <= EN_SEQ_LEN) decode the tokens
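The prefix convention described above can be sketched as plain dictionary routing. The helper names (split_kwargs_by_prefix, route_enc_dec_kwargs, route_call_kwargs) are hypothetical and this is not ReformerEncDec's actual code; it only illustrates how enc_/dec_ prefixes, the shared dim, and the input-mask-to-context-mask default could fit together.

```python
def split_kwargs_by_prefix(prefix, kwargs):
    """Split kwargs into (prefixed ones with the prefix stripped, everything else)."""
    matched = {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
    rest = {k: v for k, v in kwargs.items() if not k.startswith(prefix)}
    return matched, rest

def route_enc_dec_kwargs(**kwargs):
    """enc_* -> encoder, dec_* -> decoder, un-prefixed (e.g. dim) -> both."""
    enc_kwargs, rest = split_kwargs_by_prefix('enc_', kwargs)
    dec_kwargs, shared = split_kwargs_by_prefix('dec_', rest)
    return {**shared, **enc_kwargs}, {**shared, **dec_kwargs}

def route_call_kwargs(**kwargs):
    """Forward the encoder input mask as the decoder context mask unless one is given."""
    enc_kwargs, dec_kwargs = route_enc_dec_kwargs(**kwargs)
    if 'input_mask' in enc_kwargs:
        dec_kwargs.setdefault('context_mask', enc_kwargs['input_mask'])
    return enc_kwargs, dec_kwargs

enc, dec = route_enc_dec_kwargs(dim=512, enc_num_tokens=20000, enc_depth=6,
                                dec_num_tokens=20000, dec_depth=6)
print(enc)  # {'dim': 512, 'num_tokens': 20000, 'depth': 6}

enc_call, dec_call = route_call_kwargs(enc_input_mask=[True, True])
print(dec_call)  # {'context_mask': [True, True]}
```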

Product Key Memory

To see the benefits of using PKM, the learning rate of the values must be set higher than that of the rest of the parameters (1e-2 is recommended).

You can follow the instructions here to set it correctly https://github.com/lucidrains/product-key-memory#learning-rates
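The linked product-key-memory instructions are the authoritative reference. As a generic illustration only, PyTorch optimizers accept parameter groups with their own learning rates. In the sketch below, ToyModelWithMemory and its memory_values table are hypothetical stand-ins for a model containing PKM value parameters, and the 1e-4 base learning rate is an assumed value; locating the actual PKM value parameters inside a ReformerLM is not shown.

```python
import torch
from torch import nn

class ToyModelWithMemory(nn.Module):
    """Hypothetical stand-in: `memory_values` plays the role of the PKM value table."""
    def __init__(self):
        super().__init__()
        self.body = nn.Linear(16, 16)
        self.memory_values = nn.EmbeddingBag(1024, 16, mode='sum')

model = ToyModelWithMemory()

# Separate the memory value parameters from everything else.
value_params = list(model.memory_values.parameters())
value_ids = {id(p) for p in value_params}
other_params = [p for p in model.parameters() if id(p) not in value_ids]

optimizer = torch.optim.Adam([
    {'params': other_params, 'lr': 1e-4},  # assumed base learning rate
    {'params': value_params, 'lr': 1e-2},  # higher learning rate for the memory values
])
print([group['lr'] for group in optimizer.param_groups])  # [0.0001, 0.01]
```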

Customizing Feedforward

By default, the activation function is GELU. If you would like an alternative activation function, you can pass in the class to the keyword ff_activation.

import torch
from reformer_pytorch import ReformerLM
from torch import nn

model = ReformerLM(
    num_tokens= 20000,
    dim = 512,
    depth = 6,
    max_seq_len = 8192,
    ff_chunks = 8,
    ff_dropout = 0.1,
    ff_mult = 6,
    ff_activation = nn.LeakyReLU,
    ff_glu = True # use GLU in feedforward, from paper 'GLU Variants Improve Transformer'
)

x = torch.randint(0, 20000, (1, 8192)).long()
y = model(x) # (1, 8192, 20000)
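For reference, a minimal sketch of a GLU-style feedforward in the spirit of 'GLU Variants Improve Transformer': one projection produces a value half and a gate half, the activation is applied to the gate, and the product is projected back to dim. GLUFeedForwardSketch is a hypothetical name showing how ff_mult and ff_activation could plug in; it is not the library's code.

```python
import torch
from torch import nn

class GLUFeedForwardSketch(nn.Module):
    """(x W_value) * act(x W_gate), then project back to dim."""
    def __init__(self, dim: int, mult: int = 4, activation=nn.GELU):
        super().__init__()
        hidden = dim * mult
        # A single projection yields both the value and the gate halves.
        self.proj_in = nn.Linear(dim, hidden * 2)
        self.act = activation()
        self.proj_out = nn.Linear(hidden, dim)

    def forward(self, x):
        value, gate = self.proj_in(x).chunk(2, dim=-1)
        return self.proj_out(value * self.act(gate))

ff = GLUFeedForwardSketch(dim=512, mult=6, activation=nn.LeakyReLU)
y = ff(torch.randn(1, 128, 512))
print(y.shape)  # torch.Size([1, 128, 512])
```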

Research

To access the attention weights and bucket distribution, simply wrap the instantiated model with the Recorder wrapper class.

import torch
from reformer_pytorch import Reformer, Recorder

model = Reformer(
    dim = 512,
    depth = 12,
    max_seq_len = 8192,
    heads = 8,
    lsh_dropout = 0.1,
    causal = True
).cuda()

model = Recorder(model)

x = torch.randn(1, 8192, 512).cuda()
y = model(x)

model.recordings[0] # a list of attention weights and buckets for the first forward pass

model.turn_off() # stop recording
model.turn_on() # start recording
model.clear() # clear the recordings

model = model.eject() # recover the original model and remove all listeners

Additional Helpers

Reformer comes with a slight drawback: the sequence length must be divisible by bucket size * 2. I have provided a small helper tool, Autopadder, that automatically pads the sequence length up to the next valid multiple.

import torch
from reformer_pytorch import ReformerLM, Autopadder

model = ReformerLM(
    num_tokens= 20000,
    dim = 1024,
    depth = 12,
    max_seq_len = 8192,
    heads = 8,
    lsh_dropout = 0.1,
    causal = True,
    bucket_size = 63,   # odd bucket size
    num_mem_kv = 77     # odd memory key length
).cuda()

model = Autopadder(model)

SEQ_LEN = 7777 # odd sequence length
keys = torch.randn(1, 137, 1024).cuda() # odd keys length

x = torch.randint(0, 20000, (1, SEQ_LEN)).long().cuda()
y = model(x, keys = keys) # (1, 7777, 20000)
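The core idea behind that padding can be sketched in a few lines: pad the token ids on the right up to the next multiple of bucket_size * 2 and build a boolean mask marking the padding as False, following the input_mask convention from the Masking section. pad_to_multiple is a hypothetical helper; this sketch only pads input ids and is not how Autopadder is implemented.

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(x: torch.Tensor, multiple: int, pad_value: int = 0):
    """Right-pad token ids of shape (batch, seq_len) to the next multiple of `multiple`.

    Returns the padded ids and a bool mask that is True for real tokens.
    """
    seq_len = x.shape[-1]
    padding = (-seq_len) % multiple  # 0 if already divisible
    padded = F.pad(x, (0, padding), value=pad_value)
    mask = torch.arange(padded.shape[-1], device=x.device) < seq_len
    return padded, mask.expand(x.shape[0], -1)

bucket_size = 63
x = torch.randint(0, 20000, (1, 7777))
padded, mask = pad_to_multiple(x, bucket_size * 2)
print(padded.shape)     # torch.Size([1, 7812])
print(int(mask.sum()))  # 7777
# After running a model on `padded`, slice the output back: y[:, :x.shape[-1]]
```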

Helpers for training auto-regressive models

A lot of users are only interested in an auto-regressive language model (like GPT-2). Here is a training wrapper that makes it easy to both train and evaluate on sequences of encoded tokens of arbitrary length. You will have to take care of the encoding and decoding yourself.

import torch
from torch import randint

from reformer_pytorch import ReformerLM
from reformer_pytorch.generative_tools import TrainingWrapper

model = ReformerLM(
    num_tokens= 20000,
    dim = 1024,
    depth = 12,
    max_seq_len = 4096,
    lsh_dropout = 0.1,
    causal = True,
    full_attn_thres = 1024
)

# 0 is used for padding and no loss to be calculated on it
model = TrainingWrapper(model, ignore_index = 0, pad_value = 0)

# the wrapper can handle evenly packed sequences
x_train = randint(0, 20000, (3, 357))

# or if you have a list of uneven sequences, it will be padded for you
x_train = [
    randint(0, 20000, (120,)),
    randint(0, 20000, (253,)),
    randint(0, 20000, (846,))
]

# when training, set return_loss equal to True
model.train()
loss = model(x_train, return_loss = True)
loss.backward()

# when evaluating, just use the generate function, which will default to top_k sampling with temperature of 1.
initial = torch.tensor([[0]]).long() # assume 0 is start token
sample = model.generate(initial, 100, temperature=1., filter_thres = 0.9, eos_token = 1) # assume end token is 1, or omit and it will sample up to 100
print(sample.shape) # (1, <=100) token ids
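A minimal sketch of the sampling loop such a generate function performs, assuming filter_thres = 0.9 means "keep the top 10% of logits" (an interpretation, not confirmed by this README). top_k_filter, generate and dummy_model are hypothetical names, and dummy_model returns random logits in place of a trained ReformerLM.

```python
import torch
import torch.nn.functional as F

def top_k_filter(logits: torch.Tensor, filter_thres: float = 0.9) -> torch.Tensor:
    """Keep the top (1 - filter_thres) fraction of logits; set the rest to -inf."""
    # round() guards against float error, e.g. (1 - 0.9) * 20 == 1.9999999999999996
    k = max(1, round((1 - filter_thres) * logits.shape[-1]))
    values, indices = logits.topk(k, dim=-1)
    filtered = torch.full_like(logits, float('-inf'))
    return filtered.scatter(-1, indices, values)

@torch.no_grad()
def generate(model, start, steps, temperature=1.0, filter_thres=0.9, eos_token=None):
    """Sample up to `steps` tokens after `start`; returns only the new tokens."""
    out = start
    for _ in range(steps):
        logits = model(out)[:, -1, :]  # logits for the next position
        probs = F.softmax(top_k_filter(logits, filter_thres) / temperature, dim=-1)
        next_token = torch.multinomial(probs, 1)  # (batch, 1)
        out = torch.cat([out, next_token], dim=-1)
        if eos_token is not None and (next_token == eos_token).all():
            break
    return out[:, start.shape[-1]:]

def dummy_model(ids):
    # Hypothetical stand-in: random logits over a 20000-token vocabulary.
    return torch.randn(ids.shape[0], ids.shape[1], 20000)

initial = torch.tensor([[0]]).long()
sample = generate(dummy_model, initial, 100, eos_token=1)
print(sample.shape)  # (1, <=100)
```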

Issues

Andrea has found that using the O2 optimization level when training with mixed precision can lead to instability. Please use O1 instead, which can be set with amp_level in Pytorch Lightning, or opt_level in Nvidia's Apex library.

Alternatives

  1. Routing Transformer - https://github.com/lucidrains/routing-transformer
  2. Sinkhorn Transformer - https://github.com/lucidrains/sinkhorn-transformer
  3. Performer - https://github.com/lucidrains/performer-pytorch
  4. Linear Transformer - https://github.com/lucidrains/linear-attention-transformer/
  5. Compressive Transformer - https://github.com/lucidrains/compressive-transformer-pytorch

Citations

@inproceedings{kitaev2020reformer,
    title       = {Reformer: The Efficient Transformer},
    author      = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
    booktitle   = {International Conference on Learning Representations},
    year        = {2020},
    url         = {https://openreview.net/forum?id=rkgNKkHtvB}
}
@article{DBLP:journals/corr/abs-1907-01470,
    author    = {Sainbayar Sukhbaatar and
               Edouard Grave and
               Guillaume Lample and
               Herv{\'{e}} J{\'{e}}gou and
               Armand Joulin},
    title     = {Augmenting Self-attention with Persistent Memory},
    journal   = {CoRR},
    volume    = {abs/1907.01470},
    year      = {2019},
    url       = {http://arxiv.org/abs/1907.01470}
}
@article{1910.05895,
    author  = {Toan Q. Nguyen and Julian Salazar},
    title   = {Transformers without Tears: Improving the Normalization of Self-Attention},
    year    = {2019},
    eprint  = {arXiv:1910.05895},
    doi     = {10.5281/zenodo.3525484},
}
@inproceedings{fan2020reducing,
    title     = {Reducing Transformer Depth on Demand with Structured Dropout},
    author    = {Angela Fan and Edouard Grave and Armand Joulin},
    booktitle = {International Conference on Learning Representations},
    year      = {2020},
    url       = {https://openreview.net/forum?id=SylO2yStDr}
}
@article{Shazeer2019FastTD,
    title   = {Fast Transformer Decoding: One Write-Head is All You Need},
    author  = {Noam Shazeer},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1911.02150}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}    
}
@misc{roy*2020efficient,
    title   = {Efficient Content-Based Sparse Attention with Routing Transformers},
    author  = {Aurko Roy* and Mohammad Taghi Saffar* and David Grangier and Ashish Vaswani},
    year    = {2020},
    url     = {https://openreview.net/forum?id=B1gjs6EtDr}
}
@misc{bachlechner2020rezero,
    title   = {ReZero is All You Need: Fast Convergence at Large Depth},
    author  = {Thomas Bachlechner and Bodhisattwa Prasad Majumder and Huanru Henry Mao and Garrison W. Cottrell and Julian McAuley},
    year    = {2020},
    url     = {https://arxiv.org/abs/2003.04887}
}
@misc{lample2019large,
    title   = {Large Memory Layers with Product Keys},
    author  = {Guillaume Lample and Alexandre Sablayrolles and Marc'Aurelio Ranzato and Ludovic Denoyer and Hervé Jégou},
    year    = {2019},
    eprint  = {1907.05242},
    archivePrefix = {arXiv}
}
@misc{bhojanapalli2020lowrank,
    title   = {Low-Rank Bottleneck in Multi-head Attention Models},
    author  = {Srinadh Bhojanapalli and Chulhee Yun and Ankit Singh Rawat and Sashank J. Reddi and Sanjiv Kumar},
    year    = {2020},
    eprint  = {2002.07028}
}
@misc{dong2021attention,
    title   = {Attention is Not All You Need: Pure Attention Loses Rank Doubly Exponentially with Depth}, 
    author  = {Yihe Dong and Jean-Baptiste Cordonnier and Andreas Loukas},
    year    = {2021},
    eprint  = {2103.03404}
}
@misc{vaswani2017attention,
    title   = {Attention Is All You Need},
    author  = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
    year    = {2017},
    eprint  = {1706.03762},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

Comments
  • why do we have negative self-attention score?

    Hey lucid! Sorry to ask this dumb question.

    What I understand about self-attention is that it gives the relative information among all words (representing the relevance of each word?).

    So I expected to see non-negative values, but apparently there were many negative self-attention scores, since we initialized Q and K from uniform -0.01~0.01 at the beginning.

    What is the math or reasoning behind initializing all learnable params with uniform -0.01~0.01? Do you intend to assume that we are living on a sphere (geometrical reasoning)?

    How would you interpret a negative self-attention score?

    I think this will depend on your thoughts/advice for my Q2, but I am thinking that a negative value means it is MORE negatively important as relevance info, i.e. "BAD TO HAVE"? Then how would you interpret a self-attention score of 0? I would guess it is not important at all?

    Thanks again a lot...really...
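    As general background on the question above (not a reply from the thread): in standard dot-product attention, the raw score between a query and a key is a dot product and can take any sign. Softmax then maps each row of scores to positive weights that sum to 1, so a negative score only means a smaller weight, and a score of 0 does not mean zero weight. The small example below omits the 1/sqrt(d) scaling for readability.

```python
import torch

q = torch.tensor([[1.0, 0.0]])    # one query
k = torch.tensor([[1.0, 0.0],     # aligned with q
                  [-1.0, 0.0],    # opposite direction
                  [0.0, 1.0]])    # orthogonal
scores = q @ k.t()                # tensor([[ 1., -1.,  0.]])
weights = scores.softmax(dim=-1)  # about tensor([[0.6652, 0.0900, 0.2447]])
print(weights)

# Softmax ignores a constant shift, so only differences between scores matter.
shifted = (scores + 5.0).softmax(dim=-1)
print(torch.allclose(weights, shifted))  # True
```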

    opened by muiPomeranian 31
  • Runtime error when attempting to use data distributed parallel

    Thank you for putting in the time to do this. I have a bunch of ideas for it.

    I crudely ported your example training script to use the pytorch-lightning library, and when I attempted to use distributed data parallel I ran into a crash. The problem may be down in the revtorch library, but I want to hand the script off to you so you can play with it and decide where the issue is.

    You can get the crash by supplying the --distributed flag to the script with any number of GPUs.

    Epoch 1:   0%|          | 0/1451 [00:00<?, ?batch/s]
    Traceback (most recent call last):
      File "example/train_lightning.py", line 166, in <module>
        main()
      File "example/train_lightning.py", line 161, in main
        trainer.fit(model)
      File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 687, in fit
        mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,))
      File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 171, in spawn
        while not spawn_context.join():
      File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 118, in join
        raise Exception(msg)
    Exception: 
    
    -- Process 0 terminated with the following error:
    Traceback (most recent call last):
      File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
        fn(i, *args)
      File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/distrib_data_parallel.py", line 331, in ddp_train
        self.run_pretrain_routine(model)
      File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 829, in run_pretrain_routine
        self.train()
      File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 332, in train
        self.run_training_epoch()
      File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 386, in run_training_epoch
        output = self.run_training_batch(batch, batch_idx)
      File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 506, in run_training_batch
        loss = optimizer_closure()
      File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 489, in optimizer_closure
        model_ref.backward(self.use_amp, closure_loss, optimizer)
      File "/opt/conda/lib/python3.6/site-packages/pytorch_lightning/core/hooks.py", line 154, in backward
        loss.backward()
      File "/opt/conda/lib/python3.6/site-packages/torch/tensor.py", line 195, in backward
        torch.autograd.backward(self, gradient, retain_graph, create_graph)
      File "/opt/conda/lib/python3.6/site-packages/torch/autograd/__init__.py", line 99, in backward
        allow_unreachable=True)  # allow_unreachable flag
      File "/opt/conda/lib/python3.6/site-packages/torch/autograd/function.py", line 77, in apply
        return self._forward_cls.backward(self, *args)
      File "/opt/conda/lib/python3.6/site-packages/revtorch/revtorch.py", line 161, in backward
        y, dy = ctx.reversible_blocks[i].backward_pass(y, dy)
      File "/opt/conda/lib/python3.6/site-packages/revtorch/revtorch.py", line 89, in backward_pass
        gy1.backward(dy2)
      File "/opt/conda/lib/python3.6/site-packages/torch/tensor.py", line 195, in backward
        torch.autograd.backward(self, gradient, retain_graph, create_graph)
      File "/opt/conda/lib/python3.6/site-packages/torch/autograd/__init__.py", line 99, in backward
        allow_unreachable=True)  # allow_unreachable flag
    RuntimeError: Expected to mark a variable ready only once. This error is caused by use of a module parameter outside the `forward` function. The return value of the `forward` function is inspected by the distributed data parallel wrapper to figure out if any of the module's parameters went unused. If this is the case, it knows they won't receive gradients in a backward pass. If any of those parameters are then used outside `forward`, this error condition is triggered. You can disable unused parameter detection by passing the keyword argument `find_unused_parameters=False` to `torch.nn.parallel.DistributedDataParallel`.
    

    script:

    from reformer_pytorch import ReformerLM
    
    import tqdm
    import gzip
    import numpy as np
    import torch.optim as optim
    from torch.nn import functional as F
    from torch.utils.data import DataLoader, Dataset
    from pytorch_lightning import Trainer
    
    import os
    
    import torch
    from torch import nn
    from torchvision import transforms
    
    import argparse
    
    import pytorch_lightning as pl
    
    # constants
    
    NUM_BATCHES = int(1e5)
    BATCH_SIZE = 4
    GRADIENT_ACCUMULATE_EVERY = 4
    LEARNING_RATE = 1e-4
    VALIDATE_EVERY = 100
    
    SEQ_LEN = 4096
    
    # helpers
    
    def cycle(loader):
        while True:
            for data in loader:
                yield data
    
    with gzip.open('./data/enwik8.gz') as file:
        X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
        trX, vaX = np.split(X, [int(90e6)])
        data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
    
    class TextSamplerDataset(Dataset):
        def __init__(self, data, seq_len):
            super().__init__()
            self.data = data
            self.seq_len = seq_len
    
        def __getitem__(self, index):
            rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
            full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
            return full_seq[0:-1], full_seq[1:]
    
        def __len__(self):
            return self.data.size(0) // self.seq_len
    
    class ReformerTrainer(pl.LightningModule):
    
        def __init__(self, batch_size=4, distributed_mode=False):
            super(ReformerTrainer, self).__init__()
            self.batch_size = batch_size
            self.distributed_mode = distributed_mode
            # instantiate model
            self.model = ReformerLM(
                emb = 512,
                depth = 6,
                max_seq_len = SEQ_LEN,
                num_tokens = 256,
                heads = 8,
                bucket_size = 64,
                n_hashes = 4,
                ff_chunks = 10,
                lsh_dropout = 0.1,
                weight_tie = True,
                causal = True,
                use_full_attn = False # set this to true for comparison with full attention
            )
    
        def forward(self, x):
            pred = self.model(x).transpose(1, 2)
            return pred
    
        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.forward(x)
            loss = F.cross_entropy(y_hat, y, reduction='mean')
            tensorboard_logs = {'train_loss': loss}
            return {'loss': loss, 'log': tensorboard_logs}
    
        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.forward(x)
            return {'val_loss': F.cross_entropy(y_hat, y)}
        
        def validation_end(self, outputs):
            avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
            tensorboard_logs = {'val_loss': avg_loss}
            return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
            
        def test_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.forward(x)
            return {'test_loss': F.cross_entropy(y_hat, y)}
        
        def test_end(self, outputs):
            avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
            tensorboard_logs = {'test_loss': avg_loss}
            return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}
    
        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=LEARNING_RATE)
    
        @pl.data_loader
        def train_dataloader(self):
            # REQUIRED
            dataset = TextSamplerDataset(data_train, SEQ_LEN)
            if self.distributed_mode:
                dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
                dataloader = DataLoader(dataset, sampler=dist_sampler, batch_size=self.batch_size)
            else:
                dataloader = DataLoader(dataset, batch_size=self.batch_size)
            return dataloader
    
        @pl.data_loader
        def val_dataloader(self):
            # OPTIONAL
            dataset = TextSamplerDataset(data_val, SEQ_LEN)
            if self.distributed_mode:
                dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
                dataloader = DataLoader(dataset, sampler=dist_sampler, batch_size=self.batch_size)
            else:
                dataloader = DataLoader(dataset, batch_size=self.batch_size)
            return dataloader
    
        @pl.data_loader
        def test_dataloader(self):
            dataset = TextSamplerDataset(data_val, SEQ_LEN)
            if self.distributed_mode:
                dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
                dataloader = DataLoader(dataset, sampler=dist_sampler, batch_size=self.batch_size)
            else:
                dataloader = DataLoader(dataset, batch_size=self.batch_size)
            return dataloader
    
    def main():
        
        parser = argparse.ArgumentParser("reformer-lightning example")
        parser.add_argument("--gpus", default=1, help="gpus to use")
        parser.add_argument("-d", "--distributed", default=False, action="store_true",
                            help="activates distributed using data distributed parallel")
        parser.add_argument("-b", "--batch_size", type=int, default=4, help="batch_size")
        args = parser.parse_args()
    
        model = ReformerTrainer(args.batch_size, args.distributed)
    
        # most basic trainer, uses good defaults
        if args.distributed:
            trainer = Trainer(gpus=args.gpus, distributed_backend='ddp', accumulate_grad_batches=GRADIENT_ACCUMULATE_EVERY)
        else:
            trainer = Trainer(gpus=args.gpus, distributed_backend='dp', accumulate_grad_batches=GRADIENT_ACCUMULATE_EVERY)
        trainer.fit(model)
        trainer.test()
    
    
    if __name__ == "__main__":
        main()
    
    solved - pending response 
    opened by Phirefly9 28
  • Request for help for LSHSelfAttention()

    Hi @lucidrains, thank you for your excellent work (I starred it).

    I am trying to use the LSHSelfAttention() layer in my network instead of my transformer encoder layer.

    Here is pseudocode of what I am doing:

    embedded = self.word_embeddings(input_ids)       # (batch, seq_len, emb_dim)
    lsh_encoded = self.lsh_self_attention(embedded)  # (batch, seq_len, emb_dim)
    

    I keep getting a vector of NaN values. To avoid it I decreased my learning rate from 1e-3 to 1e-5, but nothing changed.

    1. Am I using the correct layer?
    2. Should I use Reformer() instead of LSHSelfAttention()? I tried to use Reformer() but I also get an error there, which tells me that my sequence length must be divisible by the number of buckets (I'm still working on it).
    opened by andreabac3 22
  • Script to easily train text generation models a la gpt-2-simple repo

    Greetings! Your repository is a very welcome contribution. I tried to follow the examples in this repo but faced some problems. While trying to modify the enwik8_simple example, I couldn't figure out how to:

    1. Load my custom data into examples (I have a poetry dataset).
    2. Generate output from a start prefix and until an end token.

    Thanks a lot!

    enhancement 
    opened by timsoraro 22
  • Using a pre-trained REFORMER for fine-tuning takes soooo looong

    Hi there, I've pre-trained a REFORMER for 4 days on 500MB of text data, just to see how it works. Now I'm trying to use it for fine-tuning and each epoch is taking a huge amount of time... I'm using a nice GPU (the one you were jealous about :P ) but it's still taking too long, as you can see below. Compared to a normal BERT, for example, there's no comparison: the latter needs only a couple of seconds for fine-tuning while this one is taking hours.

    EPOCH: 0%| | 0/40 [00:00<?, ?it/s]
    Training epoch 0: 0%| | 1/1041 [00:13<3:46:44, 13.08s/it]
    Training epoch 0: 0%| | 2/1041 [00:24<3:39:14, 12.66s/it]
    Training epoch 0: 0%| | 3/1041 [00:36<3:33:28, 12.34s/it]
    ...
    Training epoch 0: 2%|▏ | 20/1041 [03:57<3:20:54, 11.81s/it]
    ...
    Training epoch 0: 4%|▍ | 40/1041 [07:47<3:09:20, 11.35s/it]
    Training epoch 0: 4%|▍ | 41/1041 [07:58<3:08:17, 11.30s/it]

    Do you know what the problem might be? I've created this class for NER:

    import torch
    from torch import nn
    from reformer_pytorch import ReformerLM

    class ReformerForTokenClassification(nn.Module):

        def __init__(self, num_labels, model_dim, depth,
                     n_tokens, maxlen, heads, weights_file, n_hashes, dropout=0.2):
            super().__init__()
            self.num_labels = num_labels
            self.model_dim = model_dim
            # keyword arguments avoid silently binding values to the wrong ReformerLM parameters
            self.reformer = ReformerLM(num_tokens=n_tokens, dim=model_dim, depth=depth,
                                       max_seq_len=maxlen, heads=heads, n_hashes=n_hashes,
                                       return_embeddings=True)
            # load the pretrained LM weights, skipping the language-modeling head
            pretrained_dict = torch.load(weights_file, map_location='cpu')
            weights_dict = {k: v for k, v in pretrained_dict.items() if 'to_logits' not in k}
            self.reformer.load_state_dict(weights_dict)
            self.dropout = nn.Dropout(dropout)
            self.classifier = nn.Linear(self.model_dim, self.num_labels)

        def forward(self, input_ids=None, labels=None):
            embeddings = self.reformer(input_ids)               # (batch, seq_len, model_dim)
            logits = self.classifier(self.dropout(embeddings))  # (batch, seq_len, num_labels)

            if labels is None:
                return (logits,)

            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return (loss, logits)
    

    model = ReformerForTokenClassification(num_labels=9, model_dim=768, depth=12, maxlen=512, n_tokens=tokenizer.vocab_size, heads=8, n_hashes=4, weights_file='ckpts_pequeño_oscar/model_state_dict.pt')

    question solved - pending response 
    opened by avacaondata 20
  • Possible bug in enc-dec attention?

    In the encoder-decoder architecture, the encoder output is passed to the decoder as keys to be used in attention. Here (https://github.com/lucidrains/reformer-pytorch/blob/5f5bbf4fd5806f45d2cb3b7373021786b3b34e5b/reformer_pytorch/reformer_pytorch.py#L598) you are concatenating the keys with x (where x is the decoder input) and then applying self-attention. Does it make sense to do self-attention over the decoder input and the encoder outputs together? Even in the trax code these two are handled separately (https://github.com/google/trax/blob/c7c47a14ef8ea5b260ac78c22cbadd6dc1fb605b/trax/models/reformer/reformer.py#L968): self-attention is first applied to the decoder input, and then a separate encoder-decoder attention is applied between the new decoder representation and the keys.

    I don't know if this is the reason, but I have a simple copy-reverse task where the loss plateaus at 2.08, whereas with the trax code the loss gets close to 0 after a few steps.
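    To make the two orderings in this question concrete, here is a minimal single-head NumPy sketch (no learned projections, no causal masking, no LSH bucketing; all function names are hypothetical). The first variant lets decoder queries attend over the concatenation of encoder outputs and decoder inputs in one call, as described for reformer_pytorch; the second applies self-attention to the decoder input and then a separate cross-attention to the encoder output, the order attributed to trax.

```python
import numpy as np


def softmax(x, axis=-1):
    # numerically stable softmax
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attend(q, k, v):
    # single-head scaled dot-product attention without projections
    scores = q @ k.T / np.sqrt(q.shape[-1])
    return softmax(scores) @ v


def concat_keys_attention(x, enc_out):
    # one attention call: decoder queries see [encoder outputs; decoder inputs]
    kv = np.concatenate([enc_out, x], axis=0)
    return x + attend(x, kv, kv)


def separate_cross_attention(x, enc_out):
    # two calls: self-attention on the decoder input,
    # then cross-attention from that result to the encoder output
    h = x + attend(x, x, x)
    return h + attend(h, enc_out, enc_out)


rng = np.random.default_rng(0)
dec_in = rng.standard_normal((5, 8))   # (target_len, dim)
enc_out = rng.standard_normal((7, 8))  # (source_len, dim)

print(concat_keys_attention(dec_in, enc_out).shape)     # (5, 8)
print(separate_cross_attention(dec_in, enc_out).shape)  # (5, 8)
```

    Both produce one vector per decoder position; they differ in whether source and target positions compete inside a single softmax or are attended to in separate steps.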

    import torch
    import tqdm
    from reformer_pytorch import ReformerEncDec

    # LEARNING_RATE, NUM_BATCHES and GRADIENT_ACCUMULATE_EVERY are defined elsewhere (not shown in the issue)

    def cycle():
        while True:
            source = torch.randint(2, 10, (32, 768)).long().cuda()
            target = torch.flip(source, dims=(1,))  # reversed copy of each source sequence
            mask = torch.ones(32, 768).bool().cuda()
            yield (source, target, mask)

    # First example: Copy Reverse: 768 tokens - vocab size: 256

    model = ReformerEncDec(
        dim = 512,
        enc_num_tokens = 256,
        enc_depth = 1,
        enc_max_seq_len = 768,
        enc_heads=1,
        dec_num_tokens = 256,
        dec_depth = 1,
        dec_max_seq_len = 768,
        dec_heads=1,
    ).cuda()

    # optimizer

    optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # training

    data = cycle()  # create the generator once and reuse it

    for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
        model.train()

        for __ in range(GRADIENT_ACCUMULATE_EVERY):
            source, target, mask = next(data)
            loss = model(seq_in=source, seq_out=target, return_loss = True, enc_input_mask=mask)
            loss.backward()

        print(f'training loss: {loss.item()}')
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optim.step()
        optim.zero_grad()
    
    opened by py4 18
  • DeepSpeed and Generate Method

    Hi @lucidrains

    I'm currently testing the generate function of the TrainingWrapper class. When I use DeepSpeed and try to generate a sequence, I get the following error: AttributeError: 'DeepSpeedLight' object has no attribute 'generate'

    Is it because generation can only be done outside the DeepSpeed engine?

    Thank you very much, once again! :)
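    The error suggests the engine object wraps the model rather than being it. A common workaround, assuming the DeepSpeed engine keeps the original model as its `.module` attribute (as `torch.nn.DataParallel` does), is to call `generate` on that inner module. The runnable sketch below uses `torch.nn.DataParallel` as a stand-in for the engine and a hypothetical `TinyGenerator` as a stand-in for `TrainingWrapper`.

```python
import torch
from torch import nn


class TinyGenerator(nn.Module):
    # hypothetical stand-in for a model that, like TrainingWrapper, adds a generate() method
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 4)

    def forward(self, x):
        return self.net(x)

    @torch.no_grad()
    def generate(self, x):
        # toy "generation": pick the highest-scoring output index per row
        return self.forward(x).argmax(dim=-1)


wrapped = nn.DataParallel(TinyGenerator())

# the wrapper exposes forward(), but custom methods are not forwarded to it
print(hasattr(wrapped, 'generate'))  # False

# the original model is still reachable through .module
device = next(wrapped.module.parameters()).device  # DataParallel may move the model to a GPU
tokens = wrapped.module.generate(torch.zeros(2, 4, device=device))
print(tokens.shape)  # torch.Size([2])
```

    With DeepSpeed, the analogous call would be `model_engine.module.generate(...)`, where `model_engine` is a hypothetical name for the engine returned by `deepspeed.initialize`.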

    opened by CalogeroZarbo 14
  • definition of input_attn_mask  and context_mask

    Hi @lucidrains, in an encoder-decoder setting, treat the decoder input as the target and denote the encoder input length S and the decoder input length T. Then input_mask and input_attn_mask should be NxT and TxT. It is unclear whether context_mask should be NxS (padding) or TxS (memory).
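    Under the first reading in the question, both input_mask and context_mask are per-position padding masks (NxT and NxS), and a per-pair TxS mask, if one is ever needed, follows by broadcasting. A minimal NumPy sketch under that assumption (the helper `padding_mask` and the lengths are hypothetical, and this does not confirm which shape reformer_pytorch expects):

```python
import numpy as np


def padding_mask(lengths, max_len):
    # True for real tokens, False for padding; shape (batch, max_len)
    positions = np.arange(max_len)
    return positions[None, :] < np.asarray(lengths)[:, None]


# hypothetical batch: N = 2 sequences, decoder length T = 4, encoder length S = 5
dec_lengths = [4, 2]
enc_lengths = [5, 3]

input_mask = padding_mask(dec_lengths, 4)    # (N, T) padding mask for decoder tokens
context_mask = padding_mask(enc_lengths, 5)  # (N, S) padding mask for encoder keys

# a full query-by-key mask can be formed from the two padding masks: (N, T, S)
pairwise = input_mask[:, :, None] & context_mask[:, None, :]
print(pairwise.shape)  # (2, 4, 5)
```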

    solved - pending response 
    opened by opsuisppn 14
  • Generation doesn't seem right

    First of all, the DeepSpeed implementation is awesome! Training on 4 V100s gave me an 8.5x speedup, and 20x with fp16 turned on, compared to a single GPU.

    I trained a model on a 300 MB dialogue dataset for 2 epochs, but the generated samples weren't good. I suspect I made a mistake in the code, since I come from a programming background rather than ML.

    Here's my code: https://pastebin.com/V1t5Ctg7 (lr = 0.0004, bs = 32, vocab_size = 2000)

    Here are some samples: https://pastebin.com/yCL0vVdv

    From my experiments with other architectures (GPT-2 from scratch, LSTM), it should generate decent samples after training on this data, so something must be wrong somewhere.
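    Decoding settings are one thing worth checking alongside the training code. As a reference point, here is a NumPy sketch of temperature scaling plus top-k filtering, where `thres` keeps roughly the top `(1 - thres)` fraction of the vocabulary. This illustrates the technique only; it is not reformer_pytorch's code, and the parameter names and defaults are assumptions.

```python
import numpy as np


def top_k_filter(logits, thres=0.9):
    # keep roughly the top (1 - thres) fraction of the vocabulary, mask the rest to -inf
    k = max(1, int((1 - thres) * logits.shape[-1]))
    cutoff = np.sort(logits)[-k]
    return np.where(logits >= cutoff, logits, -np.inf)


def sample(logits, temperature=1.0, thres=0.9, rng=None):
    # sample one token id from a 1-D vector of logits
    if rng is None:
        rng = np.random.default_rng()
    filtered = top_k_filter(logits / temperature, thres)
    probs = np.exp(filtered - filtered.max())  # masked entries become exactly 0
    probs /= probs.sum()
    return rng.choice(len(probs), p=probs)


logits = np.arange(20, dtype=float)  # toy logits for a 20-token vocabulary
print(sample(logits, thres=0.5, rng=np.random.default_rng(0)) >= 10)  # True
```

    Note that `(1 - thres) * vocab_size` is subject to floating-point rounding, so the kept count can be one less than expected for some values.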

    opened by timsoraro 13
  • torch.nn.DataParallel causes strange GPU memory overflow

    Thanks for the great work! When I test this model with the following code:

    import torch
    from reformer_pytorch import ReformerLM
    from torch.nn import functional as F
    
    model = ReformerLM(
        num_tokens=20000,
        dim=1024,
        depth=24,
        max_seq_len=1024,
        heads=16,
        lsh_dropout=0.1,
        emb_dim=1024,  # embedding factorization for further memory savings
        causal=True,  # auto-regressive or not
        bucket_size=64,  # average size of qk per bucket, 64 was recommended in paper
        n_hashes=8,  # 4 is permissible per author, 8 is the best but slower
        ff_chunks=200,  # number of chunks for feedforward layer, make higher if there are memory issues
        weight_tie=False,  # tie parameters of each layer for no memory per additional depth
        attn_chunks=8,  # process lsh attention in chunks, only way for memory to fit when scaling to 16k tokens
        num_mem_kv=0,  # persistent learned memory key values, from all-attention paper
        twin_attention=False,  # both branches of the reversible network will be attention
        use_full_attn=True,  # use full self attention, for comparison
        full_attn_thres=128,  # use full attention if context length is less than set value
        use_scale_norm=False  # use scale norm from 'Transformers without tears' paper
    ).cuda()
    
    model = torch.nn.DataParallel(model)
    model.train()
    x = torch.randint(0, 20000, (8, 1024)).long().cuda()
    y = torch.randint(0, 20000, (8, 1024)).long().cuda()
    pred = model(x)
    loss = F.cross_entropy(pred.transpose(1, 2), y, reduction='mean')
    loss.backward()
    import ipdb
    ipdb.set_trace()
    

    Without model = torch.nn.DataParallel(model), 7616 MB of memory is used. But after I add model = torch.nn.DataParallel(model), it runs out of memory (OOM), even though each of the 8 GPUs has 16 GB. Could the problem be in revtorch?

    opened by ewrfcas 13
  • Visualizing attention weights?

    Thanks for the cool library!

    I'm working on a seq2seq demo using it, and I'd like to visualize the attention weights, but it isn't clear how to get them out of the ReformerLM class. Can you point me in the right direction?
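    One generic PyTorch mechanism for pulling intermediate values out of a model is a forward hook. The sketch below captures attention weights from `torch.nn.MultiheadAttention`, which returns them as its second output; it only illustrates the hook pattern, and whether ReformerLM's attention layers expose weights the same way is not established here.

```python
import torch
from torch import nn

attn = nn.MultiheadAttention(embed_dim=16, num_heads=2, batch_first=True)
captured = []


def save_weights(module, inputs, outputs):
    # nn.MultiheadAttention returns (attn_output, attn_weights)
    captured.append(outputs[1].detach())


handle = attn.register_forward_hook(save_weights)
x = torch.randn(1, 6, 16)        # (batch, seq_len, embed_dim)
attn(x, x, x, need_weights=True)
handle.remove()                  # stop recording after the forward pass

print(captured[0].shape)  # torch.Size([1, 6, 6]), weights averaged over heads
```

    The captured (seq_len x seq_len) matrix can then be plotted as a heatmap.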

    opened by justindujardin 13
  • FLOPs calculation for LSHSelfAttention in LSH mode and Full attention mode

    As I understand it, FLOPs are usually counted for a complete model, but I am trying to compare the computational cost of only the LSH attention module of Reformer by feeding it random input vectors. This module switches between LSH hashing and full dot-product attention via the flag use_full_attn=False / use_full_attn=True.

    The problem is that whatever input sizes I set for qk and v, the FLOP count comes out the same for both modes.

    I have verified in the Spyder IDE debugger that setting use_full_attn=False or use_full_attn=True does switch the module between LSH attention and full attention.

    Am I missing something?

    How can I verify this? I would be grateful if someone can help me.

    Code: (From Reformer Github website)

    
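    import torch
    from reformer_pytorch import LSHSelfAttention

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = LSHSelfAttention(
        dim = 128,
        heads = 8,
        bucket_size = 64,
        n_hashes = 16,
        causal = True,
        use_full_attn = False,  # set to True for the full-attention comparison
        return_attn = False
    ).to(device)

    # not used by LSHSelfAttention below (leftover from the LSHAttention example)
    qk = torch.randn(10, 1024, 128)
    v = torch.randn(10, 1024, 128)

    x = torch.randn(1, 1024, 128).to(device)

    y = model(x) # (1, 1024, 128)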
    
    Code for FLOPs calculation: (https://github.com/cszn/KAIR/blob/master/utils/utils_modelsummary.py)
    
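    from utils.utils_modelsummary import get_model_flops  # helper from the KAIR repository linked above

    with torch.no_grad():
        input_dim = (1, 16384, 128)  # set the input dimension
        flops = get_model_flops(model, input_dim, False)
        print('{:>16s} : {:<.4f} [G]'.format('FLOPs', flops/10**9))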
    
    Result in both cases:
    
    FLOPs : 0.8053 [G]
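    Hook-based FLOP counters typically count only registered layer types such as nn.Linear, so matrix products computed inside forward (e.g. with torch.matmul or einsum) may go uncounted, which could explain identical totals; this is an assumption about the KAIR helper, not verified here. A back-of-the-envelope count of just the two attention matmuls, assuming full attention scores all n x n pairs and LSH attention lets each query see 2 x bucket_size keys per hash round (its own chunk and the previous one, as in the Reformer paper), shows the expected gap. Hashing, sorting and softmax costs are ignored, and the per-head dimension of 16 is an assumption (128 / 8 heads).

```python
def full_attention_flops(seq_len, dim_head):
    # QK^T and (weights @ V): two n x n x d matmuls, 2 FLOPs per multiply-add
    return 2 * 2 * seq_len * seq_len * dim_head


def lsh_attention_flops(seq_len, dim_head, bucket_size, n_hashes):
    # per hash round, each query attends to its own chunk and the previous one
    keys_per_query = 2 * bucket_size
    return n_hashes * 2 * 2 * seq_len * keys_per_query * dim_head


d = 16  # assumed per-head dimension
print(full_attention_flops(16384, d) / lsh_attention_flops(16384, d, 64, 16))  # 8.0
print(full_attention_flops(1024, d) / lsh_attention_flops(1024, d, 64, 16))    # 0.5
```

    The ratio reduces to n / (2 x bucket_size x n_hashes), so with 16 hashes LSH only pays off once the sequence is long: at 1024 tokens it does twice the work of full attention.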
    opened by zaidilyas89 0
  • error in eval loss & ppl

    Hey! https://github.com/lucidrains/reformer-pytorch/blob/2cbc36bb280c0a6de46d838baad11a533a015fc3/pretraining/self-supervised.py#L327

    you are dividing the eval_loss & perplexity each time you pass through the loop

    Cheers, Guy :)
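    The fix the report points to is to accumulate inside the loop and normalize once afterward. A minimal sketch over a list of per-batch losses (a hypothetical input; the script's actual variables are at the linked line), taking perplexity as exp of the mean loss, a common convention:

```python
import math


def evaluate(batch_losses):
    # accumulate inside the loop ...
    total = 0.0
    steps = 0
    for loss in batch_losses:
        total += loss
        steps += 1
    # ... and divide once, after the loop
    eval_loss = total / steps
    perplexity = math.exp(eval_loss)
    return eval_loss, perplexity


print(evaluate([1.0, 2.0, 3.0])[0])  # 2.0
```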

    opened by dar-tau 0
  • Reversible layers increase memory usage

    I'm checking memory usage using nvidia-smi. When I turn on reversibility (setting reverse_thres to two times the input length) it's using 8.8 GB memory. When I turn it off (setting reverse_thres to half of the input length), it's using 3.1 GB memory, and it is (naturally) faster. But the memory part doesn't make sense. What can be the problem here?
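    Going by the usage comment at the top of this page (reverse_thres: "turn off reversibility for 2x speed for sequence lengths shorter or equal to the designated value"), the two settings in this report may be swapped: a threshold above the input length would disable reversibility rather than enable it. A tiny sketch of that rule (the helper name is hypothetical):

```python
def uses_reversible(seq_len, reverse_thres):
    # reversibility is off for seq_len <= reverse_thres, so it is on only when seq_len > reverse_thres
    return seq_len > reverse_thres


seq_len = 4096  # hypothetical input length
print(uses_reversible(seq_len, reverse_thres=2 * seq_len))   # False
print(uses_reversible(seq_len, reverse_thres=seq_len // 2))  # True
```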

    opened by serkansulun 1
  • Class token implementation

    Hi,

    I was wondering how the class token is supposed to be handled in the reversible design, since replicating the token across the two residual paths is perhaps not optimal.

    Any thoughts or pointers to code are appreciated.
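    For context, a RevNet-style reversible block (as used by Reformer) carries two streams: y1 = x1 + F(x2) and y2 = x2 + G(y1). The NumPy sketch below shows that the inputs can be recovered exactly from the outputs, and that a token placed in the input (such as a class token in row 0) ends up in both streams when the input is duplicated. F and G are hypothetical stand-ins for the attention and feed-forward sublayers, and duplicating then averaging the streams is an assumed choice, not necessarily this repository's code.

```python
import numpy as np

rng = np.random.default_rng(0)
W_f = rng.standard_normal((8, 8)) * 0.1
W_g = rng.standard_normal((8, 8)) * 0.1


def F(x):
    # hypothetical stand-in for the attention sublayer
    return np.tanh(x @ W_f)


def G(x):
    # hypothetical stand-in for the feed-forward sublayer
    return np.tanh(x @ W_g)


def reversible_forward(x1, x2):
    y1 = x1 + F(x2)
    y2 = x2 + G(y1)
    return y1, y2


def reversible_inverse(y1, y2):
    # inputs are recomputed from outputs, so activations need not be stored
    x2 = y2 - G(y1)
    x1 = y1 - F(x2)
    return x1, x2


tokens = rng.standard_normal((5, 8))   # row 0 could be a class token
x1, x2 = tokens.copy(), tokens.copy()  # input duplicated into both streams
y1, y2 = reversible_forward(x1, x2)
out = (y1 + y2) / 2                    # one way to merge the streams at the end
r1, r2 = reversible_inverse(y1, y2)
print(np.allclose(r1, x1) and np.allclose(r2, x2))  # True
```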

    opened by karttikeya 1
  • Image as input

    I'm looking at the image -> captioning example. What is the variable yi supposed to contain? In image captioning, the only input to the model is the image. Looking at how yi is generated, it seems it must contain the labels, but then we would be giving the labels to the model as input, which isn't right. I'm confused: what is yi, and why? Thanks in advance.
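    Without seeing the example's code, one common reason a captioning model receives the caption during training is teacher forcing: the causal decoder, which cannot look ahead, is fed the caption shifted by one position and trained to predict the next token, while at inference only the image and a start token are given. A NumPy sketch of the shift, assuming yi holds caption token ids (an assumption about the example; the ids 1 = start, 2 = end, 0 = padding are hypothetical):

```python
import numpy as np

# hypothetical caption batch of token ids: (batch, caption_len)
captions = np.array([[1, 15, 27, 9, 2],
                     [1, 44, 8, 2, 0]])

# the decoder sees tokens up to step t and is trained to predict token t + 1
decoder_input = captions[:, :-1]   # [[1, 15, 27, 9], [1, 44, 8, 2]]
decoder_target = captions[:, 1:]   # [[15, 27, 9, 2], [44, 8, 2, 0]]
```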

    opened by fedeloper 0
Releases(1.4.4)
Owner
Phil Wang
Working with Attention. It's all we need.