Sinkhorn Transformer - Practical implementation of Sparse Sinkhorn Attention

Overview

Sinkhorn Transformer



This is a reproduction of the work outlined in Sparse Sinkhorn Attention, with additional enhancements.

It includes a parameterized sorting network, using sinkhorn normalization to sample a permutation matrix that matches the most relevant buckets of keys to the buckets of queries.
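As a rough illustration of the Sinkhorn normalization step mentioned above, here is a minimal pure-Python sketch. It is not the code used in this repository: the function names, the toy score matrix, and the iteration count are assumptions chosen for readability. It alternates row and column normalization in log space until the matrix is approximately doubly stochastic, i.e. a relaxed permutation matrix.

```python
import math

def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def sinkhorn_normalize(log_scores, n_iters=7):
    """Alternately normalize the rows and columns of a square matrix in log space.

    Returns a matrix whose rows and columns each sum to roughly 1.
    """
    n = len(log_scores)
    s = [row[:] for row in log_scores]  # work on a copy
    for _ in range(n_iters):
        # row normalization
        for i in range(n):
            z = logsumexp(s[i])
            s[i] = [v - z for v in s[i]]
        # column normalization
        for j in range(n):
            z = logsumexp([s[i][j] for i in range(n)])
            for i in range(n):
                s[i][j] -= z
    return [[math.exp(v) for v in row] for row in s]

# hypothetical relevance scores: 3 query buckets (rows) x 3 key buckets (columns)
scores = [[2.0, 0.5, 0.1],
          [0.3, 1.5, 0.2],
          [0.1, 0.4, 1.8]]
soft_perm = sinkhorn_normalize(scores, n_iters=50)
```

Each row of `soft_perm` can be read as a soft assignment of key buckets to one query bucket; the largest entry in a row marks the key bucket considered most relevant.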

This work also brings in reversible networks and feed forward chunking (concepts introduced in Reformer) for further memory savings.
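The idea behind feed forward chunking can be sketched in a few lines of plain Python. This is only an illustration under simplifying assumptions (lists instead of tensors, and a made-up `toy_feedforward`): since a feedforward layer treats every position independently, applying it chunk by chunk gives the same output as applying it to the whole sequence, while only one chunk's intermediate values need to exist at a time.

```python
def chunked_apply(fn, tokens, n_chunks):
    """Apply a position-wise function to a sequence in at most n_chunks pieces."""
    if not tokens:
        return []
    size = -(-len(tokens) // n_chunks)  # ceiling division
    out = []
    for start in range(0, len(tokens), size):
        out.extend(fn(tokens[start:start + size]))
    return out

def toy_feedforward(chunk):
    # hypothetical stand-in for a position-wise feedforward layer
    return [max(0.0, 2.0 * x - 1.0) for x in chunk]

tokens = [0.1 * i for i in range(10)]
chunked = chunked_apply(toy_feedforward, tokens, n_chunks=3)
same = chunked == toy_feedforward(tokens)  # chunking does not change the result
```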

Colab notebook: 204k tokens (demonstration purposes)

Install

$ pip install sinkhorn_transformer

Use

A Sinkhorn Transformer based language model

import torch
from sinkhorn_transformer import SinkhornTransformerLM

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    max_seq_len = 8192,
    bucket_size = 128,        # size of the buckets
    causal = False,           # auto-regressive or not
    n_sortcut = 2,            # use sortcut to reduce memory complexity to linear
    n_top_buckets = 2,        # sort specified number of key/value buckets to one query bucket. paper is at 1, defaults to 2
    ff_chunks = 10,           # feedforward chunking, from Reformer paper
    reversible = True,        # make network reversible, from Reformer paper
    emb_dropout = 0.1,        # embedding dropout
    ff_dropout = 0.1,         # feedforward dropout
    attn_dropout = 0.1,       # post attention dropout
    attn_layer_dropout = 0.1, # post attention layer dropout
    layer_dropout = 0.1,      # add layer dropout, from 'Reducing Transformer Depth on Demand' paper
    weight_tie = True,        # tie layer parameters, from Albert paper
    emb_dim = 128,            # embedding factorization, from Albert paper
    dim_head = 64,            # be able to fix the dimension of each head, making it independent of the embedding dimension and the number of heads
    ff_glu = True,            # use GLU in feedforward, from paper 'GLU Variants Improve Transformer'
    n_local_attn_heads = 2,   # replace N heads with local attention, suggested to work well from Routing Transformer paper
    pkm_layers = (4,7),       # specify layers to use product key memory. paper shows 1 or 2 modules near the middle of the transformer is best
    pkm_num_keys = 128,       # defaults to 128, but can be increased to 256 or 512 as memory allows
)

x = torch.randint(0, 20000, (1, 2048))
model(x) # (1, 2048, 20000)

A plain Sinkhorn Transformer, layers of sinkhorn attention

import torch
from sinkhorn_transformer import SinkhornTransformer

model = SinkhornTransformer(
    dim = 1024,
    heads = 8,
    depth = 12,
    bucket_size = 128
)

x = torch.randn(1, 2048, 1024)
model(x) # (1, 2048, 1024)

Sinkhorn Encoder / Decoder Transformer

import torch
from sinkhorn_transformer import SinkhornTransformerLM

DE_SEQ_LEN = 4096
EN_SEQ_LEN = 4096

enc = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    heads = 8,
    bucket_size = 128,
    max_seq_len = DE_SEQ_LEN,
    reversible = True,
    return_embeddings = True
).cuda()

dec = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    causal = True,
    bucket_size = 128,
    max_seq_len = EN_SEQ_LEN,
    receives_context = True,
    context_bucket_size = 128,  # context key / values can be bucketed differently
    reversible = True
).cuda()

x = torch.randint(0, 20000, (1, DE_SEQ_LEN)).cuda()
y = torch.randint(0, 20000, (1, EN_SEQ_LEN)).cuda()

x_mask = torch.ones_like(x).bool().cuda()
y_mask = torch.ones_like(y).bool().cuda()

context = enc(x, input_mask=x_mask)
dec(y, context=context, input_mask=y_mask, context_mask=x_mask) # (1, 4096, 20000)

Autopadder

By default the model will complain if given an input whose length is not a multiple of the bucket size. To avoid having to make the same padding calculations each time, you can use the helper Autopadder class. It will take care of the input_mask for you as well, if given. Contextual key/values and mask are also supported.

import torch
from sinkhorn_transformer import SinkhornTransformerLM
from sinkhorn_transformer import Autopadder

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    max_seq_len = 2048,
    bucket_size = 128,
    causal = True
)

model = Autopadder(model, pad_left=True) # autopadder will fetch the bucket size and autopad input

x = torch.randint(0, 20000, (1, 1117)) # odd sequence length
model(x) # (1, 1117, 20000)
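For reference, the padding arithmetic that Autopadder saves you from can be sketched as follows. This is a plain-Python illustration, not the class's actual implementation; `pad_to_bucket_multiple` and its arguments are hypothetical names.

```python
def pad_to_bucket_multiple(tokens, bucket_size, pad_value=0, pad_left=True):
    """Pad a token list so its length is a multiple of bucket_size.

    Returns the padded tokens and a boolean mask that is False on padding.
    """
    remainder = len(tokens) % bucket_size
    pad_len = 0 if remainder == 0 else bucket_size - remainder
    padding = [pad_value] * pad_len
    pad_mask = [False] * pad_len
    mask = [True] * len(tokens)
    if pad_left:
        return padding + tokens, pad_mask + mask
    return tokens + padding, mask + pad_mask

tokens = list(range(1, 1118))  # 1117 tokens, as in the example above
padded, mask = pad_to_bucket_multiple(tokens, bucket_size=128)
# 1117 is rounded up to 1152 (9 buckets of 128), with 35 padding positions on the left
```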

Sinkhorn

This repository has diverged from the paper and is now using attention in place of the original sorting net + gumbel sinkhorn sampling. I have not found a noticeable difference in performance yet, and the new scheme allows me to generalize the network to flexible sequence lengths. If you would like to try Sinkhorn, please use the following settings, which only work for non-causal networks.

import torch
from sinkhorn_transformer import SinkhornTransformerLM

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    bucket_size = 128,
    max_seq_len = 8192,
    use_simple_sort_net = True, # turn off attention sort net
    sinkhorn_iter = 7,          # number of sinkhorn iterations - default is set at reported best in paper
    n_sortcut = 2,              # use sortcut to reduce complexity to linear time
    temperature = 0.75,         # gumbel temperature - default is set at reported best in paper
    non_permutative = False,    # allow buckets of keys to be sorted to queries more than once
)

x = torch.randint(0, 20000, (1, 8192))
model(x) # (1, 8192, 20000)

Product Key Memory

To see the benefits of using PKM, the learning rate of the values must be set higher than the rest of the parameters. (Recommended to be 1e-2)

You can follow the instructions here to set it correctly: https://github.com/lucidrains/product-key-memory#learning-rates
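The general pattern is to give the optimizer two parameter groups with different learning rates. The sketch below shows only that grouping logic; `split_param_groups`, the name filter, and the toy parameter names are all assumptions for illustration, and the linked instructions remain the authoritative way to select the PKM value parameters.

```python
def split_param_groups(named_params, is_pkm_value, base_lr=1e-4, pkm_lr=1e-2):
    """Build optimizer parameter groups with a higher learning rate for PKM values.

    named_params : iterable of (name, parameter) pairs, e.g. model.named_parameters()
    is_pkm_value : predicate on the parameter name (hypothetical; depends on the model)
    """
    pkm, rest = [], []
    for name, param in named_params:
        (pkm if is_pkm_value(name) else rest).append(param)
    return [
        {'params': rest, 'lr': base_lr},
        {'params': pkm, 'lr': pkm_lr},
    ]

# toy stand-in for model.named_parameters(); these names are made up
named = [
    ('layers.0.attn.weight', 'w0'),
    ('layers.4.pkm.values.weight', 'v4'),
    ('layers.7.pkm.values.weight', 'v7'),
]
groups = split_param_groups(named, is_pkm_value=lambda n: 'pkm.values' in n)
```

With real parameters, a list of dictionaries in this form can be passed to a PyTorch optimizer such as `torch.optim.Adam(groups)`.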

Issues

Decoding and sequence lengths

Sinkhorn, when trained on fixed length sequences, seems to have trouble decoding sequences from scratch, mainly due to the fact that the sorting net has trouble generalizing when the buckets are partially filled with padding tokens.

Fortunately, I think I have found a simple solution. During training, for causal networks, randomly truncate the sequences and force the sorting net to generalize. I have provided a flag (randomly_truncate_sequence) for the AutoregressiveWrapper instance to make this easy.

import torch
from sinkhorn_transformer import SinkhornTransformerLM, AutoregressiveWrapper

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    bucket_size = 75,
    max_seq_len = 8192,
    causal = True
)

model = AutoregressiveWrapper(model)

x = torch.randint(0, 20000, (1, 8192))
loss = model(x, return_loss = True, randomly_truncate_sequence = True) # scalar loss

I am open to suggestions if someone has found a better solution.
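The random truncation described above amounts to cutting each training batch to a random prefix length. A minimal sketch of that idea follows, using plain lists; the helper name, the minimum length of 2, and the seeded generator are assumptions for the example, not the wrapper's actual code.

```python
import random

def randomly_truncate(batch, min_len=2, rng=random):
    """Cut every sequence in a batch to the same randomly chosen prefix length."""
    full_len = len(batch[0])
    new_len = rng.randint(min_len, full_len)  # inclusive on both ends
    return [seq[:new_len] for seq in batch]

rng = random.Random(0)  # seeded only to make the example repeatable
batch = [list(range(16)), list(range(100, 116))]
truncated = randomly_truncate(batch, rng=rng)
```

Because the prefix length changes from step to step, the last bucket is often only partially filled during training, which is the situation the sorting net has to cope with when decoding from scratch.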

Causal sorting net

There is a potential problem with the causal sorting network: the decision of which past key/value buckets are sorted to a given bucket depends only on that bucket's first token and not the rest (a consequence of the bucketing scheme and of preventing leakage from future to past).

I have attempted to alleviate this problem by rotating half the heads to the left by bucket size - 1, thereby promoting the last token to be first. This is also the reason why the AutoregressiveWrapper defaults to left padding during training, to always make sure that the last token in the sequence has a say in what to retrieve.

If anyone has found a cleaner solution, please let me know in the issues.
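The effect of rotating left by bucket size - 1 can be seen on a toy sequence of positions. This sketch only shows the index arithmetic; it uses plain lists and hypothetical helper names, and it ignores how the wrapped-around positions would need to be masked in a real causal model.

```python
def rotate_left(tokens, shift):
    """Rotate a sequence left by shift positions (wrapping around)."""
    shift %= len(tokens)
    return tokens[shift:] + tokens[:shift]

def to_buckets(tokens, bucket_size):
    """Split a sequence into consecutive buckets of bucket_size."""
    return [tokens[i:i + bucket_size] for i in range(0, len(tokens), bucket_size)]

bucket_size = 4
tokens = list(range(8))                  # positions 0..7, two buckets of 4
plain = to_buckets(tokens, bucket_size)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
rotated = to_buckets(rotate_left(tokens, bucket_size - 1), bucket_size)
# rotated == [[3, 4, 5, 6], [7, 0, 1, 2]]: each bucket now starts with the
# token that was last in the corresponding unrotated bucket
```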

Alternatives

  1. Routing Transformer - https://github.com/lucidrains/routing-transformer
  2. Reformer - https://github.com/lucidrains/reformer-pytorch

Citations

@misc{tay2020sparse,
    title   = {Sparse Sinkhorn Attention},
    author  = {Yi Tay and Dara Bahri and Liu Yang and Donald Metzler and Da-Cheng Juan},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.11296}
}
@inproceedings{kitaev2020reformer,
    title       = {Reformer: The Efficient Transformer},
    author      = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
    booktitle   = {International Conference on Learning Representations},
    year        = {2020},
    url         = {https://openreview.net/forum?id=rkgNKkHtvB}
}
@misc{lan2019albert,
    title       = {ALBERT: A Lite BERT for Self-supervised Learning of Language Representations},
    author      = {Zhenzhong Lan and Mingda Chen and Sebastian Goodman and Kevin Gimpel and Piyush Sharma and Radu Soricut},
    year        = {2019},
    url         = {https://arxiv.org/abs/1909.11942}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{roy*2020efficient,
    title   = {Efficient Content-Based Sparse Attention with Routing Transformers},
    author  = {Aurko Roy* and Mohammad Taghi Saffar* and David Grangier and Ashish Vaswani},
    year    = {2020},
    url     = {https://openreview.net/forum?id=B1gjs6EtDr}
}
@inproceedings{fan2020reducing,
    title     = {Reducing Transformer Depth on Demand with Structured Dropout},
    author    = {Angela Fan and Edouard Grave and Armand Joulin},
    booktitle = {International Conference on Learning Representations},
    year      = {2020},
    url       = {https://openreview.net/forum?id=SylO2yStDr}
}
@misc{lample2019large,
    title   = {Large Memory Layers with Product Keys},
    author  = {Guillaume Lample and Alexandre Sablayrolles and Marc'Aurelio Ranzato and Ludovic Denoyer and Hervé Jégou},
    year    = {2019},
    eprint  = {1907.05242},
    archivePrefix = {arXiv}
}
@misc{bhojanapalli2020lowrank,
    title   = {Low-Rank Bottleneck in Multi-head Attention Models},
    author  = {Srinadh Bhojanapalli and Chulhee Yun and Ankit Singh Rawat and Sashank J. Reddi and Sanjiv Kumar},
    year    = {2020},
    eprint  = {2002.07028}
}

Comments
  • Training failing on versions 0.0.14 and 0.0.15

    Hi, I am testing model training on the new versions of the repo, and I am having trouble with 0.0.14 and 0.0.15. On 0.0.14, the model always returns nan on the forward pass; version 0.0.15 leads to a CUDA error:

    RuntimeError: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`
    

    Full error listing:

    <ipython-input-7-1329da5363de> in forward(self, inputs, labels)
          7   def forward(self, inputs, labels=None):
          8     loss_mx = labels != -100
    ----> 9     output = self.model(inputs)
         10     output = output[loss_mx].view(-1, tokenizer.vocab_size)
         11     labels = labels[loss_mx].view(-1)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, input_mask)
        376         x = self.to_token_emb(x)
        377         x = self.pos_emb(torch.arange(t, device=device)) + x
    --> 378         x = self.sinkhorn_transformer(x)
        379         return self.to_logits(x)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, input_mask)
        359 
        360     def forward(self, x, input_mask = None):
    --> 361         return self.layers(x)
        362 
        363 class SinkhornTransformerLM(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, **kwargs)
        330     def forward(self, x, **kwargs):
        331         x = torch.cat([x, x], dim=-1)
    --> 332         x = self.layers(x, **kwargs)
        333         return torch.stack(x.chunk(2, dim=-1)).sum(dim=0)
        334 
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(self, x, arg_route, **kwargs)
        128         block_kwargs = {'f_args': f_args, 'g_args': g_args}
        129 
    --> 130         return _ReversibleFunction.apply(x, blocks, block_kwargs)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(ctx, x, blocks, kwargs)
         98         ctx.kwargs = kwargs
         99         for block in blocks:
    --> 100             x = block(x, **kwargs)
        101         ctx.y = x.detach()
        102         ctx.blocks = blocks
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(self, x, f_args, g_args)
         51         with torch.no_grad():
         52             y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
    ---> 53             y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
         54 
         55         return torch.cat([y1, y2], dim=2)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(self, record_rng, set_rng, *args, **kwargs)
         25 
         26         if not set_rng:
    ---> 27             return self.net(*args, **kwargs)
         28 
         29         rng_devices = []
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x)
         91     def forward(self, x):
         92         chunks = x.chunk(self.chunks, dim = self.dim)
    ---> 93         return torch.cat([self.fn(c) for c in chunks], dim = self.dim)
         94 
         95 class FeedForward(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in <listcomp>(.0)
         91     def forward(self, x):
         92         chunks = x.chunk(self.chunks, dim = self.dim)
    ---> 93         return torch.cat([self.fn(c) for c in chunks], dim = self.dim)
         94 
         95 class FeedForward(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, **kwargs)
        112     def forward(self, x, **kwargs):
        113         x = self.norm(x)
    --> 114         return self.fn(x, **kwargs)
        115 
        116 class SortNet(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x)
        103 
        104     def forward(self, x):
    --> 105         return self.net(x)
        106 
        107 class PreNorm(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
         98     def forward(self, input):
         99         for module in self:
    --> 100             input = module(input)
        101         return input
        102 
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/linear.py in forward(self, input)
         85 
         86     def forward(self, input):
    ---> 87         return F.linear(input, self.weight, self.bias)
         88 
         89     def extra_repr(self):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/functional.py in linear(input, weight, bias)
       1591         ret = torch.addmm(bias, input, weight.t())
       1592     else:
    -> 1593         output = input.matmul(weight.t())
       1594         if bias is not None:
       1595             output += bias
    
    RuntimeError: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`
    

    Also, version 0.0.11 (and all other versions from 0.0.8) works stably.

    opened by blizda 30
  • generation problem in a toy task

    Here is the full script for my toy task (x -> xx like "abc" to "abcabc")

    from sinkhorn_transformer import SinkhornTransformerLM
    from sinkhorn_transformer.autoregressive_wrapper import AutoregressiveWrapper
    
    import random
    import tqdm
    import gzip
    import numpy as np
    import torch
    import torch.optim as optim
    from torch import nn
    from torch.nn import functional as F
    from torch.utils.data import DataLoader, Dataset
    
    # constants
    
    NUM_BATCHES = int(1e5)
    BATCH_SIZE = 4
    GRADIENT_ACCUMULATE_EVERY = 4
    LEARNING_RATE = 1e-4
    VALIDATE_EVERY  = 100
    GENERATE_EVERY  = 100
    ENC_SEQ_LEN=16
    DEC_SEQ_LEN=40
    NUM_TOKENS = 256 + 2
    BUCKET_SIZE = 8
    
    # helpers
    
    def top_k(logits, thres = 0.9):
        k = int((1 - thres) * logits.shape[-1])
        val, ind = torch.topk(logits, k)
        probs = torch.full_like(logits, float('-inf'))
        probs.scatter_(1, ind, val)
        return probs
    
    
    def cycle():
        while True:
            source = torch.randint(2, 258, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
    
            target = torch.cat((source, source), 1)
            prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
            target = torch.cat((prefix, target), axis=1)
    
            x_mask = torch.ones(BATCH_SIZE, ENC_SEQ_LEN).bool().cuda()
            y_mask = torch.ones(BATCH_SIZE, target.shape[1]).bool().cuda()
    
    
            yield (source, target, x_mask, y_mask)
    
    # instantiate model
    
    class MySinkhornTransformer(nn.Module):
        def __init__(self, num_tokens, dim, depth, heads, bucket_size, enc_max_seq_len, dec_max_seq_len):
            super().__init__()
            
            self.pad_token = 0
            self.sos_token = 1
    
            self.enc = SinkhornTransformerLM(num_tokens=num_tokens, dim=dim, depth=depth, heads=heads, bucket_size=bucket_size, max_seq_len=enc_max_seq_len,
                                             reversible=True, return_embeddings=True)
            self.dec = SinkhornTransformerLM(num_tokens=num_tokens, dim=dim, depth=depth, heads=heads, causal=True, bucket_size=bucket_size, max_seq_len=dec_max_seq_len, 
                                             receives_context=True, context_bucket_size=bucket_size, reversible=True)
            self.dec = AutoregressiveWrapper(self.dec, pad_value=num_tokens-2)
        
        @torch.no_grad()
        def generate(self, x, x_mask):
            context = self.enc(x, input_mask=x_mask)
            start_tokens = (torch.ones((x.shape[0],1)) * self.sos_token).long().cuda()
    
            return self.dec.generate(start_tokens, 32, context=context, context_mask=x_mask)
    
        def forward(self, x, y, x_mask, y_mask, return_loss):
            context = self.enc(x, input_mask=x_mask)
            return self.dec(y, context=context, input_mask=y_mask, context_mask=x_mask, return_loss=True)
    
    
    model = MySinkhornTransformer(num_tokens=NUM_TOKENS, dim=512, depth=1, heads=1, bucket_size=BUCKET_SIZE, enc_max_seq_len=ENC_SEQ_LEN, dec_max_seq_len=DEC_SEQ_LEN)
    model.cuda()
    # optimizer
    
    optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    
    # training
    
    for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
        model.train()
    
        for __ in range(GRADIENT_ACCUMULATE_EVERY):
            source, target, x_mask, y_mask = next(cycle())
            loss = model(x=source, y=target, x_mask=x_mask, y_mask=y_mask, return_loss=True)
            loss.backward()
    
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optim.step()
        optim.zero_grad()
    
        if i % VALIDATE_EVERY == 0:
            model.eval()
            with torch.no_grad():
                source, target, x_mask, y_mask = next(cycle())
                loss = model(x=source, y=target, x_mask=x_mask, y_mask=y_mask, return_loss=True)
                print(f'validation loss: {loss.item()}')
    
        if i % GENERATE_EVERY == 0:
            model.eval()
            
            source, target, x_mask, y_mask = next(cycle())
            
            sample = model.generate(x=source, x_mask=x_mask)
            print("input:  ", source[0])
            print("model output:  ", sample[0])
    

    After a few steps the loss becomes practically zero. I checked the logits during training and they seem to be OK, but during the generation phase the model outputs a pattern like "x,x,x,x,x,y,y,y,y,y" (e.g. "aaaabbbb" instead of "abcdabcd"). I was wondering what the underlying issue might be. Do you have any idea?

    opened by py4 18
  • Sortcut variant and bucket size

    From the paper, it looks to me that for the sortcut variant, the queries are allowed to attend to all the key buckets post truncation in the non-causal case. If this is correct, won't the output be the same no matter what the query bucket size is?

    Based on my understanding of the paper, the authors only report results selecting 2 top key/value blocks with bucket sizes 8,16,32. They do not mention having different bucket sizes for query and key/value which I found in this repo.

    opened by jsmith1915 16
  • A noobish question about training...

    @lucidrains

    First of all, everything seems to be working now in Google Colab so thank you very much for fixing it.

    I have a quick question about training if you do not mind...

    I get very high loss results. Here is the example:

    training: 1%| | 509/100000 [1:55:22<387:23:57, 14.02s/it]training loss: 2.495532512664795

    Is this normal and I simply need to train more? Or does it mean that there is a problem somewhere?

    2.49 in 2 hours is way too much IMHO. I am pretty sure I am not doing it right, so your advice would be really appreciated.

    Thank you.

    P.S. I am running your wiki8 example in Google Colab Pro.

    opened by asigalov61 9
  • several question about implementation

    I am currently trying to implement Sparse Sinkhorn attention (non-causal, self-attention, pretrain for MLM task) with tensorflow and I would appreciate it if you could answer several questions about your code.

    1. I did not quite understand from the paper what happens to queries and keys in SortCut when we cut off most of the blocks. In your code it seems like you broadcast keys/queries like this: [number_blocks, block_len, size_per_head] -> [:n_sortcut, block_len, size_per_head] -> [1, n_sortcut * block_len, size_per_head] -> [number_blocks, n_sortcut * block_len, size_per_head]. I understand that PyTorch handles expand_dim automatically and memory is preserved, but I do not understand how the memory consumption stays linear: the resulting attention scores (dots) are still [number_blocks, block_len, num_sortcut * block_len], which is good but quadratic. Am I missing something?

    2. In the case of SortCut what is the meaning of the softmax(A)*Q? For example, for full attention outputs are vectors which are weighted average of other vectors; For local block attention outputs are vectors that are weighted average of other vectors in the block. How would you interpret the output of SortCut attention layer?

    3. In your code you concatenate regular keys and values with permuted/sorted keys and values, but in the paper (section 3.2) it seems like simple addition. Why is it different from the paper? Actually, I can understand concatenation more than summation (because I am confused about which attention masks to use in that case: from the regular block or the permuted one).

    4. It seems like you are using concatenated queries and keys as the input to SortNet. Maybe I completely missed the point, but I did not find anything about this in the article: the input sequence is used for SortNet instead of its projections. In fact, for some reason, I cannot make it converge if I use queries and keys (maybe for completely unrelated reasons). Have you tried comparing the input sequence vs concat(q,k)?
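One way to look at the element count in question 1 is to work it out directly from the shapes quoted there. The sketch below is plain arithmetic under the assumption that block_len and n_sortcut are held fixed while the sequence grows; the function names are hypothetical.

```python
def sortcut_score_elements(seq_len, block_len, n_sortcut):
    """Entries in attention scores of shape
    [number_blocks, block_len, n_sortcut * block_len]."""
    number_blocks = seq_len // block_len
    return number_blocks * block_len * (n_sortcut * block_len)

def full_score_elements(seq_len):
    """Entries in a dense [seq_len, seq_len] score matrix."""
    return seq_len * seq_len

# with block_len and n_sortcut fixed, doubling seq_len doubles the sortcut count
a = sortcut_score_elements(4096, block_len=128, n_sortcut=2)  # 4096 * 2 * 128
b = sortcut_score_elements(8192, block_len=128, n_sortcut=2)
# the dense count quadruples over the same doubling
dense_ratio = full_score_elements(8192) // full_score_elements(4096)
```

The product simplifies to seq_len * n_sortcut * block_len, so the count grows linearly in seq_len only as long as the bucket size does not scale with the sequence.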

    opened by w4-magnes 9
  • TypeError exception in AxialPositionalEncoding when using DataParallel

    Hello,

    I want to run SinkhornTransformerLM using multiple GPUs, so I'm wrapping the model into torch.nn.DataParallel. However, when I do this, I get an exception:

    Traceback (most recent call last):
      File "script.py", line 27, in <module>
        model(x)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 155, in forward
        outputs = self.parallel_apply(replicas, inputs, kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 165, in parallel_apply
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
        output.reraise()
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/_utils.py", line 395, in reraise
        raise self.exc_type(msg)
    TypeError: Caught TypeError in replica 0 on device 0.
    Original Traceback (most recent call last):
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
        output = module(*input, **kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py", line 792, in forward
        x = self.axial_pos_emb(x) + x
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py", line 243, in forward
        return pos_emb[:, :t]
    TypeError: 'int' object is not subscriptable
    

    Looking at the code, it would seem that self.weights does not get populated. To reproduce this error, I took the first example in README.md and changed

    model(x) # (1, 2048, 20000)
    

    to

    model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count()))).to('cuda')
    model(x)
    
    opened by kl0211 8
  • Do I need to pad?

    Excuse me for the noob question, I have sequences of different lengths and I use:

    model = AutoregressiveWrapper(model, ignore_index=PAD_ID, pad_value=PAD_ID)
    

    Do I need to pad as in this example or not?

    def __getitem__(self, index):
        seq_tokens = self.examples[index]
        input_ids = torch.tensor(seq_tokens, dtype=torch.long)
        input_ids = F.pad(input_ids, (seq_len - len(input_ids), 0), value=self.pad_value)
        return input_ids
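
For reference, the `F.pad` call in the snippet above pads on the left, because the pad tuple for the last dimension is ordered `(left, right)`. The following is a minimal, library-free sketch of that padding logic, not code from this repository. The helper names `pad_left` and `pad_right` and the value `PAD_ID = 0` are assumptions made for illustration.

```python
def pad_left(tokens, seq_len, pad_value):
    """Left-pad a token list to seq_len, mirroring
    F.pad(x, (seq_len - len(x), 0), value=pad_value)."""
    if len(tokens) > seq_len:
        raise ValueError("sequence is longer than seq_len")
    return [pad_value] * (seq_len - len(tokens)) + list(tokens)


def pad_right(tokens, seq_len, pad_value):
    """Right-pad a token list to seq_len, mirroring
    F.pad(x, (0, seq_len - len(x)), value=pad_value)."""
    if len(tokens) > seq_len:
        raise ValueError("sequence is longer than seq_len")
    return list(tokens) + [pad_value] * (seq_len - len(tokens))


PAD_ID = 0  # hypothetical padding id, chosen only for this example
seq_len = 4
batch = [[5, 6, 7], [8, 9]]

left = [pad_left(seq, seq_len, PAD_ID) for seq in batch]
right = [pad_right(seq, seq_len, PAD_ID) for seq in batch]

print(left)   # [[0, 5, 6, 7], [0, 0, 8, 9]]
print(right)  # [[5, 6, 7, 0], [8, 9, 0, 0]]
```

Either way, every sequence in a batch ends up with the same length, which is what stacking examples into one tensor requires.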
    
    opened by timsoraro 8
  • Performance of the different attention variants in the repo


    Really great work!!

    I have a few questions.

    - Does the attention sort net work better than, or on par with, the simple sort net?
    - Does the attention sort net also work well for the sortcut variant?
    - How does the sortcut variant perform compared to vanilla sparse Sinkhorn attention?

    It would be very helpful if you could share some plots/numbers from your experiments comparing the performance of the different variants such as attention sort net, routing based attention etc.

    opened by jsmith1915 7
  • Some questions about dropout


    Hi again @lucidrains, I have some quick questions about dropout in the Sinkhorn Transformer. I was using my Linformer implementation (which, as you know, is based on this repo), and it was overfitting my dataset, so I wanted to ask whether some of the design choices here were intentional:

    1. In the original Transformer, dropout was performed after each sublayer, before the residual connection. I noticed that you only have this after the SinkhornSelfAttention class, but not after the FeedForward class. Is this intentional?
    2. Speaking of the FeedForward class, you insert dropout after the first linear layer. I couldn't find this anywhere in the literature. Do you have a reference for why this is effective? I put it into my implementation and it seems to help, but I don't know where the idea came from.
    3. On a similar note, do you know why the dots tensor in the self-attention classes is dropped out? Again, I put it in my Linformer and it seems to work, but I can't find a reference to this in the literature.
    4. Finally, the original transformer also dropped out the input tokens, like so (From the SinkhornTransformerLM class):
        def forward(self, x, **kwargs):
            _, t, device = *x.shape, x.device
            assert t <= self.max_seq_len, f'sequence length {t} is greater than maximum sequence length {self.max_seq_len}'
    
            x = self.to_token_emb(x)
            x = self.axial_pos_emb(x) + x
            # Dropout would go here
            x = self.sinkhorn_transformer(x, **kwargs)
            return self.to_logits(x)
    

    Should they also be dropped out here as well?

    I now updated my repo such that all 4 of these dropout possibilities exist. I'll let you know if this helps overfitting.
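
As a point of reference for question 1, here is a minimal, library-free sketch of inverted dropout applied to a sublayer output before the residual addition, i.e. `x + dropout(sublayer(x))`. It is an illustration only, not this repository's code. The names `dropout`, `residual_sublayer`, and `double` are hypothetical, and `double` is just a stand-in for an attention or feedforward sublayer.

```python
import random


def dropout(values, p, rng, training=True):
    """Inverted dropout: zero each element with probability p and
    rescale the survivors by 1 / (1 - p). Identity when not training."""
    if not training or p == 0.0:
        return list(values)
    keep = 1.0 - p
    return [v / keep if rng.random() < keep else 0.0 for v in values]


def residual_sublayer(x, sublayer, p, rng, training=True):
    """Dropout on the sublayer output, then the residual addition."""
    out = dropout(sublayer(x), p, rng, training)
    return [a + b for a, b in zip(x, out)]


def double(values):
    """Hypothetical stand-in for an attention or feedforward sublayer."""
    return [2.0 * v for v in values]


rng = random.Random(0)
x = [1.0, 2.0, 3.0, 4.0]

print(residual_sublayer(x, double, p=0.0, rng=rng))                  # [3.0, 6.0, 9.0, 12.0]
print(residual_sublayer(x, double, p=0.5, rng=rng, training=False))  # [3.0, 6.0, 9.0, 12.0]
print(residual_sublayer(x, double, p=0.5, rng=rng))                  # each entry is x_i or 5 * x_i
```

With `p=0.5`, a dropped element leaves only the residual `x_i`, while a kept element contributes `2 * x_i / 0.5`, giving `5 * x_i`. The rescaling keeps the expected sublayer output unchanged between training and evaluation.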

    Thank you for your time!

    opened by tatp22 6
  • Training crashes due to inplace operation


    Hi there, I am trying to train this model on my custom dataset. Training starts, but after many iterations it crashes with this error:

    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2, 128, 64]], which is output 0 of ViewBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
    

    After enabling anomaly detection, here is the error:

    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2, 128, 64]], which is output 0 of ViewBackward, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
    

    The traceback is not useful, as it only points to the loss.backward() line and the torch.autograd module. Can you help me find where the issue might be?

    Thanks

    opened by NaxAlpha 6
  • Positional Embedding


    Hi! I found out that SinkhornTransformerLM uses two positional encodings simultaneously:

    https://github.com/lucidrains/sinkhorn-transformer/blob/2b65e24085562a4e308251398007e2ca9b86d7cc/sinkhorn_transformer/sinkhorn_transformer.py#L778-L779

    https://github.com/lucidrains/sinkhorn-transformer/blob/2b65e24085562a4e308251398007e2ca9b86d7cc/sinkhorn_transformer/sinkhorn_transformer.py#L792-L793

    I guess pos_emb can be removed, since it introduces memory overhead and defeats the purpose of Axial Positional Encoding, which is designed to reduce the number of positional encoding parameters.

    Is there a special reason for this?
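
To make the parameter argument concrete, here is a small arithmetic sketch comparing a full learned position table with an axial factorization. It uses `max_seq_len = 8192` and `dim = 1024` from the README example. The axial shape `(64, 128)` and the per-axis dimensions are assumptions picked for illustration, not values read from this repository, and the function names are hypothetical.

```python
def absolute_pos_emb_params(max_seq_len, dim):
    """One learned vector of size dim per position."""
    return max_seq_len * dim


def axial_pos_emb_params(axial_shape, axial_dims):
    """One small table per axis: n_i positions, each of size d_i."""
    return sum(n * d for n, d in zip(axial_shape, axial_dims))


max_seq_len, dim = 8192, 1024
axial_shape = (64, 128)  # assumed factorization, 64 * 128 == 8192
assert axial_shape[0] * axial_shape[1] == max_seq_len

absolute = absolute_pos_emb_params(max_seq_len, dim)
# Summed variant: every axis carries the full model dimension.
summed = axial_pos_emb_params(axial_shape, (dim, dim))
# Concatenated variant: per-axis dimensions add up to the model dimension.
concat = axial_pos_emb_params(axial_shape, (512, 512))

print(absolute)  # 8388608
print(summed)    # 196608
print(concat)    # 98304
```

Under these assumptions the full table is roughly 43 times larger than the summed axial tables, so keeping both embeddings means the absolute table dominates the positional parameter count.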

    opened by ilya16 4
  • A wrapper of SinkhornTransformerEncDec


    Could you please code up a wrapper that removes a lot of the manual work in writing a generic SinkhornTransformer encoder/decoder architecture?

    Thanks a lot! halexan

    opened by halexan 0
Owner
Phil Wang
Working with Attention. It's all we need