Trax — Deep Learning with Clear Code and Speed

Overview

Trax is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained in the Google Brain team. This notebook (run it in Colab) shows how to use Trax and where you can find more information.

  1. Run a pre-trained Transformer: create a translator in a few lines of code
  2. Features and resources: API docs, where to talk to us, how to open an issue and more
  3. Walkthrough: how Trax works, how to make new models and train on your own data

We welcome contributions to Trax! PRs with code for new models and layers are welcome, as are improvements to our code and documentation. We especially love notebooks that explain how models work and show how to use them to solve problems!

Here are a few example notebooks:

General Setup

Execute the following cell (once) before running any of the code samples.

import os
import numpy as np

!pip install -q -U trax
import trax

1. Run a pre-trained Transformer

Here is how you create an English-German translator in a few lines of code:

# Create a Transformer model.
# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
model = trax.models.Transformer(
    input_vocab_size=33300,
    d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048, mode='predict')

# Initialize using pre-trained weights.
model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
                     weights_only=True)

# Tokenize a sentence.
sentence = 'It is nice to learn new things today!'
tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                    vocab_dir='gs://trax-ml/vocabs/',
                                    vocab_file='ende_32k.subword'))[0]

# Decode from the Transformer.
tokenized = tokenized[None, :]  # Add batch dimension.
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
    model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.

# De-tokenize.
tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
translation = trax.data.detokenize(tokenized_translation,
                                   vocab_dir='gs://trax-ml/vocabs/',
                                   vocab_file='ende_32k.subword')
print(translation)
Es ist schön, heute neue Dinge zu lernen!

2. Features and resources

Trax includes basic models (like ResNet, LSTM, Transformer) and RL algorithms (like REINFORCE, A2C, PPO). It is also actively used for research and includes new models like the Reformer and new RL algorithms like AWR. Trax has bindings to a large number of deep learning datasets, including Tensor2Tensor and TensorFlow datasets.

You can use Trax either as a library from your own python scripts and notebooks or as a binary from the shell, which can be more convenient for training large models. It runs without any changes on CPUs, GPUs and TPUs.
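
For example, a training run from the shell is configured with a gin file. Here is a minimal sketch; the config file name is illustrative (Trax ships gin configs under trax/configs/):

python -m trax.trainer --config_file=$PWD/trax/configs/reformer_imagenet64.gin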

3. Walkthrough

Here you can learn how Trax works, how to create new models, and how to train them on your own data.

Tensors and Fast Math

The basic units flowing through Trax models are tensors: multi-dimensional arrays, often called numpy arrays after numpy, the most widely used package for tensor operations. If you don't know how to operate on tensors, take a look at the numpy guide; Trax uses the numpy API for them as well.

In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the trax.fastmath package thanks to its backends -- JAX and TensorFlow numpy.

from trax.fastmath import numpy as fastnp
trax.fastmath.use_backend('jax')  # Can be 'jax' or 'tensorflow-numpy'.

matrix  = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(f'matrix = \n{matrix}')
vector = fastnp.ones(3)
print(f'vector = {vector}')
product = fastnp.dot(vector, matrix)
print(f'product = {product}')
tanh = fastnp.tanh(product)
print(f'tanh(product) = {tanh}')
matrix = 
[[1 2 3]
 [4 5 6]
 [7 8 9]]
vector = [1. 1. 1.]
product = [12. 15. 18.]
tanh(product) = [0.99999994 0.99999994 0.99999994]
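
The backends can also compile functions for speed. Here is a minimal sketch, assuming the JAX backend, where trax.fastmath.jit mirrors jax.jit:

from trax import fastmath

def affine_tanh(v, m):
  return fastnp.tanh(fastnp.dot(v, m))

fast_affine_tanh = fastmath.jit(affine_tanh)  # Compiled on first call.
print(fast_affine_tanh(vector, matrix))       # Same values as tanh(product) above.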

Gradients can be calculated using trax.fastmath.grad.

def f(x):
  return 2.0 * x * x

grad_f = trax.fastmath.grad(f)

print(f'grad(2x^2) at 1 = {grad_f(1.0)}')
grad(2x^2) at 1 = 4.0

Layers

Layers are basic building blocks of Trax models. You will learn all about them in the layers intro but for now, just take a look at the implementation of one core Trax layer, Embedding:

class Embedding(base.Layer):
  """Trainable layer that maps discrete tokens/IDs to vectors."""

  def __init__(self,
               vocab_size,
               d_feature,
               kernel_initializer=init.RandomNormalInitializer(1.0)):
    """Returns an embedding layer with given vocabulary size and vector size.

    Args:
      vocab_size: Size of the input vocabulary. The layer will assign a unique
          vector to each ID in `range(vocab_size)`.
      d_feature: Dimensionality/depth of the output vectors.
      kernel_initializer: Function that creates (random) initial vectors for
          the embedding.
    """
    super().__init__(name=f'Embedding_{vocab_size}_{d_feature}')
    self._d_feature = d_feature  # feature dimensionality
    self._vocab_size = vocab_size
    self._kernel_initializer = kernel_initializer

  def forward(self, x):
    """Returns embedding vectors corresponding to input token IDs.

    Args:
      x: Tensor of token IDs.

    Returns:
      Tensor of embedding vectors.
    """
    return jnp.take(self.weights, x, axis=0, mode='clip')

  def init_weights_and_state(self, input_signature):
    """Returns tensor of newly initialized embedding vectors."""
    del input_signature
    shape_w = (self._vocab_size, self._d_feature)
    w = self._kernel_initializer(shape_w, self.rng)
    self.weights = w

Layers with trainable weights like Embedding need to be initialized with the signature (shape and dtype) of the input, and then can be run by calling them.

from trax import layers as tl

# Create an input tensor x.
x = np.arange(15)
print(f'x = {x}')

# Create the embedding layer.
embedding = tl.Embedding(vocab_size=20, d_feature=32)
embedding.init(trax.shapes.signature(x))

# Run the layer -- y = embedding(x).
y = embedding(x)
print(f'shape of y = {y.shape}')
x = [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14]
shape of y = (15, 32)

Models

Models in Trax are built from layers most often using the Serial and Branch combinators. You can read more about those combinators in the layers intro and see the code for many models in trax/models/, e.g., this is how the Transformer Language Model is implemented. Below is an example of how to build a sentiment classification model.

model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),  # Average on axis 1 (length of sentence).
    tl.Dense(2),      # Classify 2 classes.
    tl.LogSoftmax()   # Produce log-probabilities.
)

# You can print model structure.
print(model)
Serial[
  Embedding_8192_256
  Mean
  Dense_2
  LogSoftmax
]
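
Branch runs several sublayers in parallel on copies of the same input. Below is a hedged sketch of a variant of the model above; it assumes tl.Max and tl.Concatenate, which follow the same style as tl.Mean:

# A sketch: combine mean- and max-pooled views of the embedded sequence.
branch_model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Branch(tl.Mean(axis=1), tl.Max(axis=1)),  # Two pooled views in parallel.
    tl.Concatenate(),  # Merge the views: depth becomes 2 * 256.
    tl.Dense(2),
    tl.LogSoftmax()
)
print(branch_model)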

Data

To train your model, you need data. In Trax, data streams are represented as Python iterators, so you can call next(data_stream) and get a tuple, e.g., (inputs, targets). Trax makes it easy to use TensorFlow Datasets, and you can also build an iterator over your own text file using the standard open('my_file.txt') (a sketch follows the TFDS example below).

train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()
eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()
print(next(train_stream))  # See one example.
(b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it.", 0)

Using the trax.data module you can create input processing pipelines, e.g., to tokenize and shuffle your data. You create data pipelines using trax.data.Serial and they are functions that you apply to streams to create processed streams.

data_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
    trax.data.Shuffle(),
    trax.data.FilterByLength(max_length=2048, length_keys=[0]),
    trax.data.BucketByLength(boundaries=[  32, 128, 512, 2048],
                             batch_sizes=[256,  64,  16,    4, 1],
                             length_keys=[0]),
    trax.data.AddLossWeights()
  )
train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)
example_batch = next(train_batches_stream)
print(f'shapes = {[x.shape for x in example_batch]}')  # Check the shapes.
shapes = [(4, 1024), (4,), (4,)]

Supervised training

When you have the model and the data, use trax.supervised.training to define training and eval tasks and create a training loop. The Trax training loop optimizes training and will create TensorBoard logs and model checkpoints for you.

from trax.supervised import training

# Training task.
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=500,
)

# Evaluation task.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
    n_eval_batches=20  # For less variance in eval numbers.
)

# Training loop saves checkpoints to output_dir.
output_dir = os.path.expanduser('~/output_dir/')
!rm -rf {output_dir}
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)

# Run 2000 steps (batches).
training_loop.run(2000)
Step      1: Ran 1 train steps in 0.78 secs
Step      1: train WeightedCategoryCrossEntropy |  1.33800304
Step      1: eval  WeightedCategoryCrossEntropy |  0.71843582
Step      1: eval      WeightedCategoryAccuracy |  0.56562500

Step    500: Ran 499 train steps in 5.77 secs
Step    500: train WeightedCategoryCrossEntropy |  0.62914723
Step    500: eval  WeightedCategoryCrossEntropy |  0.49253047
Step    500: eval      WeightedCategoryAccuracy |  0.74062500

Step   1000: Ran 500 train steps in 5.03 secs
Step   1000: train WeightedCategoryCrossEntropy |  0.42949259
Step   1000: eval  WeightedCategoryCrossEntropy |  0.35451687
Step   1000: eval      WeightedCategoryAccuracy |  0.83750000

Step   1500: Ran 500 train steps in 4.80 secs
Step   1500: train WeightedCategoryCrossEntropy |  0.41843575
Step   1500: eval  WeightedCategoryCrossEntropy |  0.35207348
Step   1500: eval      WeightedCategoryAccuracy |  0.82109375

Step   2000: Ran 500 train steps in 5.35 secs
Step   2000: train WeightedCategoryCrossEntropy |  0.38129005
Step   2000: eval  WeightedCategoryCrossEntropy |  0.33760912
Step   2000: eval      WeightedCategoryAccuracy |  0.85312500

After training the model, run it like any layer to get results.

example_input = next(eval_batches_stream)[0][0]
example_input_str = trax.data.detokenize(example_input, vocab_file='en_8k.subword')
print(f'example input_str: {example_input_str}')
sentiment_log_probs = model(example_input[None, :])  # Add batch dimension.
print(f'Model returned sentiment probabilities: {np.exp(sentiment_log_probs)}')
example input_str: I first saw this when I was a teen in my last year of Junior High. I was riveted to it! I loved the special effects, the fantastic places and the trial-aspect and flashback method of telling the story.<br /><br />Several years later I read the book and while it was interesting and I could definitely see what Swift was trying to say, I think that while it's not as perfect as the book for social commentary, as a story the movie is better. It makes more sense to have it be one long adventure than having Gulliver return after each voyage and making a profit by selling the tiny Lilliput sheep or whatever.<br /><br />It's much more arresting when everyone thinks he's crazy and the sheep DO make a cameo anyway. As a side note, when I saw Laputa I was stunned. It looks very much like the Kingdom of Zeal from the Chrono Trigger video game (1995) that also made me like this mini-series even more.<br /><br />I saw it again about 4 years ago, and realized that I still enjoyed it just as much. Really high quality stuff and began an excellent run of Sweeps mini-series for NBC who followed it up with the solid Merlin and interesting Alice in Wonderland.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
Model returned sentiment probabilities: [[3.984500e-04 9.996014e-01]]
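
To turn the log-probabilities into a class prediction, take the argmax over the last axis (in imdb_reviews, label 1 is positive):

sentiment = int(np.argmax(sentiment_log_probs[0]))
print(f'Predicted sentiment: {sentiment}')  # 1 = positive, 0 = negative.
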
Comments
  • Example code to create a machine translation using trax transformer model.

    Description

    Hi, I am a little new to this area. Please excuse me if I am asking a stupid question. I want to build a machine translation model using Transformers, trained on my own data. Can I get example code where I can pass my own data (both input and target) to the trax transformers so that the model can predict the translated value?

    I will be using my own laptop for preparing and running the model. It does not have GPU.

    Also, if possible, could I get example code for speech recognition as well, using the transformer model?

    Thanks Nagaraju

    ...

    Environment information

    OS: <win10>
    
    $ pip freeze | grep trax
    # your output here
    
    $ pip freeze | grep tensor
    # your output here
    
    $ pip freeze | grep jax
    # your output here
    
    $ python -V
    # your output here
    

    For bugs: reproduction and error logs

    # Steps to reproduce:
    ...
    
    # Error logs:
    ...
    
    opened by nag0811 12
  • TFDS From Master Branch Raises NoneType Not Subscriptable Error

    Description

    This isn't from the current release on pip, but on February 11 a change was made to master that causes TFDS to crash with a "NoneType is not subscriptable" error on my computer.

    In trax.data.tf_inputs.TFDS there are these lines:

      host_id = jax.host_id() if host_id is None else host_id
      n_hosts = n_hosts or jax.host_count()
      if n_hosts > 1:
        subsplit = (host_id / n_hosts, (host_id + 1) / n_hosts)
      else:
        subsplit = None
    

    On my computer n_hosts = 1, so subsplit is None, which gets passed to the _train_and_eval_dataset function; inside that function are these lines:

      if eval_holdout_examples > 0 or subsplit is not None:
        n_train = train_examples - eval_holdout_examples
        train_start = int(n_train * subsplit[0])
        train_end = int(n_train * subsplit[1])
    

    Because the conditional uses an or and eval_holdout_examples is greater than 0, the branch is entered even though subsplit is None, so subscripting it as subsplit[0] raises an exception.

    I don't know if now is the time to report this, since I'm pulling from master (reverting to the last February 10 commit fixes it for me), but I thought it might be helpful to know if it's not already.
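
    One possible guard, sketched here only as an illustration (not necessarily the maintainers' fix), is to default subsplit to the full train split instead of indexing None:

      # Sketch: avoid subscripting None when only eval_holdout is set.
      if eval_holdout_examples > 0 or subsplit is not None:
        if subsplit is None:
          subsplit = (0, 1)  # Use the whole training split.
        n_train = train_examples - eval_holdout_examples
        train_start = int(n_train * subsplit[0])
        train_end = int(n_train * subsplit[1])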

    Environment information

    OS: Ubuntu 20.04 (using the nvidia docker container)
    
    $ pip freeze | grep trax
    -e trax==1.3.7
    
    $ pip freeze | grep tensor
    mesh-tensorflow==0.1.18
    tensorboard==2.4.1
    tensorboard-plugin-wit==1.8.0
    tensorflow==2.4.1
    tensorflow-datasets==4.2.0
    tensorflow-estimator==2.4.0
    tensorflow-hub==0.11.0
    tensorflow-metadata==0.28.0
    tensorflow-text==2.4.3
    
    $ pip freeze | grep jax
    jax==0.2.10
    jaxlib==0.1.61+cuda111
    
    $ python -V
    Python 3.8.5
    

    For bugs: reproduction and error logs

    Steps to reproduce:

    import trax
    path = "data"
    data_set = "opus/medical"
    train_stream_fn = trax.data.TFDS(data_set,
                                     data_dir=path,
                                     keys=('en', 'de'),
                                     eval_holdout_size=0.01,
                                     train=True)
    

    Error logs:

    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    <ipython-input-6-fb62d04026f5> in <module>
          4 # data_set = "para_crawl/ende"
          5 
    ----> 6 train_stream_fn = trax.data.TFDS(data_set,
          7                                  data_dir=path,
          8                                  keys=('en', 'de'),
    
    /usr/local/lib/python3.8/dist-packages/gin/config.py in gin_wrapper(*args, **kwargs)
       1067       scope_info = " in scope '{}'".format(scope_str) if scope_str else ''
       1068       err_str = err_str.format(name, fn_or_cls, scope_info)
    -> 1069       utils.augment_exception_message_and_reraise(e, err_str)
       1070 
       1071   return gin_wrapper
    
    /usr/local/lib/python3.8/dist-packages/gin/utils.py in augment_exception_message_and_reraise(exception, message)
         39   proxy = ExceptionProxy()
         40   ExceptionProxy.__qualname__ = type(exception).__qualname__
    ---> 41   raise proxy.with_traceback(exception.__traceback__) from None
         42 
         43 
    
    /usr/local/lib/python3.8/dist-packages/gin/config.py in gin_wrapper(*args, **kwargs)
       1044 
       1045     try:
    -> 1046       return fn(*new_args, **new_kwargs)
       1047     except Exception as e:  # pylint: disable=broad-except
       1048       err_str = ''
    
    /usr/local/lib/python3.8/dist-packages/gin/config.py in gin_wrapper(*args, **kwargs)
       1067       scope_info = " in scope '{}'".format(scope_str) if scope_str else ''
       1068       err_str = err_str.format(name, fn_or_cls, scope_info)
    -> 1069       utils.augment_exception_message_and_reraise(e, err_str)
       1070 
       1071   return gin_wrapper
    
    /usr/local/lib/python3.8/dist-packages/gin/utils.py in augment_exception_message_and_reraise(exception, message)
         39   proxy = ExceptionProxy()
         40   ExceptionProxy.__qualname__ = type(exception).__qualname__
    ---> 41   raise proxy.with_traceback(exception.__traceback__) from None
         42 
         43 
    
    /usr/local/lib/python3.8/dist-packages/gin/config.py in gin_wrapper(*args, **kwargs)
       1044 
       1045     try:
    -> 1046       return fn(*new_args, **new_kwargs)
       1047     except Exception as e:  # pylint: disable=broad-except
       1048       err_str = ''
    
    ~/trax/trax/data/tf_inputs.py in TFDS(dataset_name, data_dir, tfds_preprocess_fn, keys, train, shuffle_train, host_id, n_hosts, eval_holdout_size)
        279   else:
        280     subsplit = None
    --> 281   (train_data, eval_data, _) = _train_and_eval_dataset(
        282       dataset_name, data_dir, eval_holdout_size,
        283       train_shuffle_files=shuffle_train, subsplit=subsplit)
    
    ~/trax/trax/data/tf_inputs.py in _train_and_eval_dataset(dataset_name, data_dir, eval_holdout_size, train_shuffle_files, eval_shuffle_files, subsplit)
        224   if eval_holdout_examples > 0 or subsplit is not None:
        225     n_train = train_examples - eval_holdout_examples
    --> 226     train_start = int(n_train * subsplit[0])
        227     train_end = int(n_train * subsplit[1])
        228     if train_end - train_start < 1:
    
    TypeError: 'NoneType' object is not subscriptable
      In call to configurable 'TFDS' (<function TFDS at 0x7f960c527280>)
      In call to configurable 'TFDS' (<function TFDS at 0x7f960c526f70>)
    
    opened by necromuralist 10
  • Add the general convolution operation to extensions

    TF XLA version: https://www.tensorflow.org/xla/operation_semantics?hl=en#conv_convolution; JAX version: https://jax.readthedocs.io/en/stable/_autosummary/jax.lax.conv_general_dilated.html

    cla: yes ready to pull 
    opened by DarrenZhang01 9
  • Add the general dot operation to Trax NumPy extensions

    opened by DarrenZhang01 8
  • Not able to install trax on local computer (only cpu)

    Description

    Hi, I am not able to install trax on my laptop. I am getting the following error: ERROR: Could not find a version that satisfies the requirement jaxlib (from trax) (from versions: none) ERROR: No matching distribution found for jaxlib (from trax)

    I tried to install jaxlib but it is failing

    I tried to install jaxlib directly, but it fails too:

    (base) C:\Users\nb185041>pip install --upgrade jaxlib --user
    ERROR: Could not find a version that satisfies the requirement jaxlib (from versions: none)
    ERROR: No matching distribution found for jaxlib
    ...

    Environment information

    OS: <Windows 10>
    
    $ pip freeze | grep trax
    # your output here
    ERROR: Could not find a version that satisfies the requirement jaxlib (from trax) (from versions: none)
    ERROR: No matching distribution found for jaxlib (from trax)
    $ pip freeze | grep tensor
    # your output here
    
    $ pip freeze | grep jax
    # your output here
    
    $ python -V
    # your output here
    

    For bugs: reproduction and error logs

    # Steps to reproduce:
    ...
    
    # Error logs:
    ...
    (base) C:\Users\nb185041>pip install trax --upgrade --user
    Collecting trax
      Using cached trax-1.3.4-py2.py3-none-any.whl (366 kB)
    Requirement already satisfied, skipping upgrade: gym in c:\programdata\anaconda3\lib\site-packages (from trax) (0.17.2)
    ERROR: Could not find a version that satisfies the requirement jaxlib (from trax) (from versions: none)
    ERROR: No matching distribution found for jaxlib (from trax)
    
    opened by nag0811 8
  • Tensor2Tensor Transformer is not Trax Transformer

    Description

    Hello. I've been playing around with both T2T and Trax libraries for a while. Since Trax has several bugs during inference, I've decided to switch to T2T. However, it seems to me that Transformer in Tensor2Tensor is not the same as in Trax.

    In Tensor2Tensor I create my Transformer model this way:

    hparams_my = {
        'batch_size': 128,
        'batch_shuffle_size': 128,
        'use_fixed_batch_size': True,
        'num_hidden_layers': 1,
        'max_input_seq_length': 252,
        'max_target_seq_length': 252,
        'max_length': 252,
        'symbol_modality_num_shards': 1,
        'filter_size': 2048,
        'dropout': 0.1
    }
    

    In Trax:

    Transformer(input_vocab_size=127,
                    output_vocab_size=127,
                    d_model=512,
                    d_ff=2048,
                    n_encoder_layers=1,
                    n_decoder_layers=1,
                    n_heads=8,
                    dropout=0.1,
                    max_len=2048,
                    mode='train',
                    ff_activation=tl.Relu):
    

    After I run training with T2T, I get this message (by the way, printed twice):

    INFO:tensorflow:Trainable Variables Total size: 7433728
    INFO:tensorflow:Trainable Variables Total size: 7433728
    

    Whereas in Trax, after I call trainer.print_n_weights(), I get

    Step      0: Total number of trainable weights: 7614591
    

    I would like to note that when I train my Transformer model in Trax, I reach convergence almost immediately (considering the nature of the task: just simple sequence copying with little changes), while with T2T I reach loss values of about 3-4 and no convergence at all.

    Could anybody tell me what I have to do? It seems like a common problem with T2T Transformer convergence, but I want to emphasise that in Trax it is another Transformer...

    opened by DevKretov 7
  • beam_search.Search() in a single-GPU environment

    Description

    Training a Transformer converges.

    Then beam_search fails, though: when n_devices == 1, some reshapes crash in decode().

    Environment information

    OS: 
    ubuntu 18.04 
    
    CUDA 10.1
    1 GPU environment
    
    $ pip freeze | grep tensor
    mesh-tensorflow==0.1.11
    tensor2tensor==1.15.4
    tensorboard==2.1.0
    tensorflow==2.1.0
    tensorflow-datasets==2.1.0
    tensorflow-estimator==2.1.0
    tensorflow-gan==2.0.0
    tensorflow-hub==0.7.0
    tensorflow-metadata==0.21.1
    tensorflow-probability==0.7.0
    
    
    $ pip freeze | grep jax
    jax==0.1.59
    jaxlib==0.1.39
    
    
    $ python -V
    Python 3.6.9
    

    For bugs: reproduction and error logs

    Traceback (most recent call last):
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 442, in pure_fn
        x, weights=weights, state=state, rng=rng)
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 220, in forward_with_state
        return self.forward(inputs, weights), state
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 580, in _forward
        raw_output = raw_fn(x, weights=weights, **self._kwargs)  # pylint: disable=protected-access
      File "/root/.local/lib/python3.6/site-packages/trax/layers/attention.py", line 51, in PaddingMask
        return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))
      File "/root/.local/lib/python3.6/site-packages/jax/numpy/lax_numpy.py", line 921, in reshape
        return a.reshape(newshape, order=order)  # forward to method for ndarrays
      File "/root/.local/lib/python3.6/site-packages/jax/numpy/lax_numpy.py", line 959, in _reshape_method
        return _reshape(a, newshape, order=order)
      File "/root/.local/lib/python3.6/site-packages/jax/numpy/lax_numpy.py", line 938, in _reshape
        return lax.reshape(a, computed_newshape, None)
      File "/root/.local/lib/python3.6/site-packages/jax/lax/lax.py", line 640, in reshape
        old_sizes=onp.shape(operand))
      File "/root/.local/lib/python3.6/site-packages/jax/core.py", line 182, in bind
        out_tracer = top_trace.process_primitive(self, tracers, kwargs)
      File "/root/.local/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 98, in process_primitive
        return self.default_process_primitive(primitive, tracers, params)
      File "/root/.local/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 106, in default_process_primitive
        out_aval = primitive.abstract_eval(*avals, **params)
      File "/root/.local/lib/python3.6/site-packages/jax/lax/lax.py", line 1523, in standard_abstract_eval
        return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
      File "/root/.local/lib/python3.6/site-packages/jax/lax/lax.py", line 2582, in _reshape_shape_rule
        raise TypeError(msg.format(new_sizes, onp.shape(operand)))
    TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 30) for shape (1, 2, 30).

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 480, in _forward_abstract
        input_signature, weight_signature, self.state, rng)
      File "/root/.local/lib/python3.6/site-packages/trax/math/jax.py", line 175, in shape_fun
        jax_shapes = jax.eval_shape(f, *args, **kwargs)
      File "/root/.local/lib/python3.6/site-packages/jax/api.py", line 2042, in eval_shape
        out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))
      File "/root/.local/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 273, in abstract_eval_fun
        instantiate=True)
      File "/root/.local/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 354, in trace_to_jaxpr
        jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
      File "/root/.local/lib/python3.6/site-packages/jax/linear_util.py", line 149, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/root/.local/lib/python3.6/site-packages/jax/linear_util.py", line 149, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 477, in call_on_input
        return self.forward_with_state(x, weights=weights, state=state, rng=rng)
      File "/root/.local/lib/python3.6/site-packages/trax/layers/combinators.py", line 238, in forward_with_state
        sub_outputs, sub_state = layer.pure_fn(x, w, s, r)
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 451, in pure_fn
        self._caller, signature(x), trace)
    trax.layers.base.LayerError: Exception passing through layer PaddingMask (in pure_fn):
      layer created in file [...]/trax/models/transformer.py, line 286
      layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}

    File [...]/trax/layers/base.py, line 220, in forward_with_state return self.forward(inputs, weights), state

    File [...]/trax/layers/base.py, line 580, in _forward raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access

    File [...]/trax/layers/attention.py, line 51, in PaddingMask return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))

    File [...]/jax/numpy/lax_numpy.py, line 921, in reshape return a.reshape(newshape, order=order) # forward to method for ndarrays

    File [...]/jax/numpy/lax_numpy.py, line 959, in _reshape_method return _reshape(a, newshape, order=order)

    File [...]/jax/numpy/lax_numpy.py, line 938, in _reshape return lax.reshape(a, computed_newshape, None)

    File [...]/jax/lax/lax.py, line 640, in reshape old_sizes=onp.shape(operand))

    File [...]/site-packages/jax/core.py, line 182, in bind out_tracer = top_trace.process_primitive(self, tracers, kwargs)

    File [...]/jax/interpreters/partial_eval.py, line 98, in process_primitive return self.default_process_primitive(primitive, tracers, params)

    File [...]/jax/interpreters/partial_eval.py, line 106, in default_process_primitive out_aval = primitive.abstract_eval(*avals, **params)

    File [...]/jax/lax/lax.py, line 1523, in standard_abstract_eval return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))

    File [...]/jax/lax/lax.py, line 2582, in _reshape_shape_rule raise TypeError(msg.format(new_sizes, onp.shape(operand)))

    TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 30) for shape (1, 2, 30).

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 310, in init
        weights, state = self.new_weights_and_state(input_signature)
      File "/root/.local/lib/python3.6/site-packages/trax/layers/combinators.py", line 92, in new_weights_and_state
        outputs, _ = sublayer._forward_abstract(inputs)
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 485, in _forward_abstract
        trace)
    trax.layers.base.LayerError: Exception passing through layer Parallel (in _forward_abstract):
      layer created in file [...]/trax/layers/combinators.py, line 468
      layer input shapes: (ShapeDtype{shape:(1, 2, 30), dtype:int32}, ShapeDtype{shape:(1, 2, 30), dtype:int32})

    File [...]/trax/math/jax.py, line 175, in shape_fun jax_shapes = jax.eval_shape(f, *args, **kwargs)

    File [...]/site-packages/jax/api.py, line 2042, in eval_shape out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))

    File [...]/jax/interpreters/partial_eval.py, line 273, in abstract_eval_fun instantiate=True)

    File [...]/jax/interpreters/partial_eval.py, line 354, in trace_to_jaxpr jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)

    File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped ans = self.f(*args, **dict(self.params, **kwargs))

    File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped ans = self.f(*args, **dict(self.params, **kwargs))

    File [...]/trax/layers/base.py, line 477, in call_on_input return self.forward_with_state(x, weights=weights, state=state, rng=rng)

    File [...]/trax/layers/combinators.py, line 238, in forward_with_state sub_outputs, sub_state = layer.pure_fn(x, w, s, r)

    LayerError: Exception passing through layer PaddingMask (in pure_fn): layer created in file [...]/trax/models/transformer.py, line 286 layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}

    File [...]/trax/layers/base.py, line 220, in forward_with_state return self.forward(inputs, weights), state

    File [...]/trax/layers/base.py, line 580, in _forward raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access

    File [...]/trax/layers/attention.py, line 51, in PaddingMask return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))

    File [...]/jax/numpy/lax_numpy.py, line 921, in reshape return a.reshape(newshape, order=order) # forward to method for ndarrays

    File [...]/jax/numpy/lax_numpy.py, line 959, in _reshape_method return _reshape(a, newshape, order=order)

    File [...]/jax/numpy/lax_numpy.py, line 938, in _reshape return lax.reshape(a, computed_newshape, None)

    File [...]/jax/lax/lax.py, line 640, in reshape old_sizes=onp.shape(operand))

    File [...]/site-packages/jax/core.py, line 182, in bind out_tracer = top_trace.process_primitive(self, tracers, kwargs)

    File [...]/jax/interpreters/partial_eval.py, line 98, in process_primitive return self.default_process_primitive(primitive, tracers, params)

    File [...]/jax/interpreters/partial_eval.py, line 106, in default_process_primitive out_aval = primitive.abstract_eval(*avals, **params)

    File [...]/jax/lax/lax.py, line 1523, in standard_abstract_eval return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))

    File [...]/jax/lax/lax.py, line 2582, in _reshape_shape_rule raise TypeError(msg.format(new_sizes, onp.shape(operand)))

    TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 30) for shape (1, 2, 30).

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 310, in init
        weights, state = self.new_weights_and_state(input_signature)
      File "/root/.local/lib/python3.6/site-packages/trax/layers/combinators.py", line 91, in new_weights_and_state
        weights_or_empty, state = sublayer.init(inputs)
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 321, in init
        input_signature, trace)
    trax.layers.base.LayerError: Exception passing through layer Serial (in init):
      layer created in file [...]/trax/layers/combinators.py, line 470
      layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}

    File [...]/trax/layers/combinators.py, line 92, in new_weights_and_state outputs, _ = sublayer._forward_abstract(inputs)

    LayerError: Exception passing through layer Parallel (in _forward_abstract): layer created in file [...]/trax/layers/combinators.py, line 468 layer input shapes: (ShapeDtype{shape:(1, 2, 30), dtype:int32}, ShapeDtype{shape:(1, 2, 30), dtype:int32})

    File [...]/trax/math/jax.py, line 175, in shape_fun jax_shapes = jax.eval_shape(f, *args, **kwargs)

    File [...]/site-packages/jax/api.py, line 2042, in eval_shape out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))

    File [...]/jax/interpreters/partial_eval.py, line 273, in abstract_eval_fun instantiate=True)

    File [...]/jax/interpreters/partial_eval.py, line 354, in trace_to_jaxpr jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)

    File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped ans = self.f(*args, **dict(self.params, **kwargs))

    File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped ans = self.f(*args, **dict(self.params, **kwargs))

    File [...]/trax/layers/base.py, line 477, in call_on_input return self.forward_with_state(x, weights=weights, state=state, rng=rng)

    File [...]/trax/layers/combinators.py, line 238, in forward_with_state sub_outputs, sub_state = layer.pure_fn(x, w, s, r)

    LayerError: Exception passing through layer PaddingMask (in pure_fn): layer created in file [...]/trax/models/transformer.py, line 286 layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}

    File [...]/trax/layers/base.py, line 220, in forward_with_state return self.forward(inputs, weights), state

    File [...]/trax/layers/base.py, line 580, in _forward raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access

    File [...]/trax/layers/attention.py, line 51, in PaddingMask return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))

    File [...]/jax/numpy/lax_numpy.py, line 921, in reshape return a.reshape(newshape, order=order) # forward to method for ndarrays

    File [...]/jax/numpy/lax_numpy.py, line 959, in _reshape_method return _reshape(a, newshape, order=order)

    File [...]/jax/numpy/lax_numpy.py, line 938, in _reshape return lax.reshape(a, computed_newshape, None)

    File [...]/jax/lax/lax.py, line 640, in reshape old_sizes=onp.shape(operand))

    File [...]/site-packages/jax/core.py, line 182, in bind out_tracer = top_trace.process_primitive(self, tracers, kwargs)

    File [...]/jax/interpreters/partial_eval.py, line 98, in process_primitive return self.default_process_primitive(primitive, tracers, params)

    File [...]/jax/interpreters/partial_eval.py, line 106, in default_process_primitive out_aval = primitive.abstract_eval(*avals, **params)

    File [...]/jax/lax/lax.py, line 1523, in standard_abstract_eval return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))

    File [...]/jax/lax/lax.py, line 2582, in _reshape_shape_rule raise TypeError(msg.format(new_sizes, onp.shape(operand)))

    TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 30) for shape (1, 2, 30).

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "math_trax.py", line 565
        seqs, scores = beam_decoder.decode(inputs=batch, batch_size=iBatch_size)#, )targets_prefix=prefix_for_bs,
      File "/root/.local/lib/python3.6/site-packages/trax/models/beam_search.py", line 602, in decode
        dummy=np.zeros(n_devices))
      File "/root/.local/lib/python3.6/site-packages/jax/api.py", line 146, in f_jitted
        name=flat_fun.name)
      File "/root/.local/lib/python3.6/site-packages/jax/core.py", line 642, in call_bind
        outs = primitive.impl(f, *args, **params)
      File "/root/.local/lib/python3.6/site-packages/jax/interpreters/xla.py", line 448, in _xla_call_impl
        compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
      File "/root/.local/lib/python3.6/site-packages/jax/linear_util.py", line 220, in memoized_fun
        ans = call(fun, *args)
      File "/root/.local/lib/python3.6/site-packages/jax/interpreters/xla.py", line 465, in _xla_callable
        jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
      File "/root/.local/lib/python3.6/site-packages/jax/linear_util.py", line 149, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/root/.local/lib/python3.6/site-packages/trax/models/beam_search.py", line 535, in _unreplicated_beam_search
        self._get_initial_state(inputs, targets_prefix, batch_size),
      File "/root/.local/lib/python3.6/site-packages/trax/models/beam_search.py", line 490, in _get_initial_state
        _, initial_state = self.model(mode='predict').init(signature)
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 321, in init
        input_signature, trace)
    trax.layers.base.LayerError: Exception passing through layer Serial (in init):
      layer created in file [...]/trax/models/transformer.py, line 301
      layer input shapes: (ShapeDtype{shape:(1, 2, 30), dtype:int32}, ShapeDtype{shape:(2, 1), dtype:int32})

    File [...]/trax/layers/combinators.py, line 91, in new_weights_and_state weights_or_empty, state = sublayer.init(inputs)

    LayerError: Exception passing through layer Serial (in init): layer created in file [...]/trax/layers/combinators.py, line 470 layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}

    File [...]/trax/layers/combinators.py, line 92, in new_weights_and_state outputs, _ = sublayer._forward_abstract(inputs)

    LayerError: Exception passing through layer Parallel (in _forward_abstract): layer created in file [...]/trax/layers/combinators.py, line 468 layer input shapes: (ShapeDtype{shape:(1, 2, 30), dtype:int32}, ShapeDtype{shape:(1, 2, 30), dtype:int32})

    File [...]/trax/math/jax.py, line 175, in shape_fun jax_shapes = jax.eval_shape(f, *args, **kwargs)

    File [...]/site-packages/jax/api.py, line 2042, in eval_shape out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))

    File [...]/jax/interpreters/partial_eval.py, line 273, in abstract_eval_fun instantiate=True)

    File [...]/jax/interpreters/partial_eval.py, line 354, in trace_to_jaxpr jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)

    File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped ans = self.f(*args, **dict(self.params, **kwargs))

    File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped ans = self.f(*args, **dict(self.params, **kwargs))

    File [...]/trax/layers/base.py, line 477, in call_on_input return self.forward_with_state(x, weights=weights, state=state, rng=rng)

    File [...]/trax/layers/combinators.py, line 238, in forward_with_state sub_outputs, sub_state = layer.pure_fn(x, w, s, r)

    LayerError: Exception passing through layer PaddingMask (in pure_fn): layer created in file [...]/trax/models/transformer.py, line 286 layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}

    File [...]/trax/layers/base.py, line 220, in forward_with_state return self.forward(inputs, weights), state

    File [...]/trax/layers/base.py, line 580, in _forward raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access

    File [...]/trax/layers/attention.py, line 51, in PaddingMask return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))

    File [...]/jax/numpy/lax_numpy.py, line 921, in reshape return a.reshape(newshape, order=order) # forward to method for ndarrays

    File [...]/jax/numpy/lax_numpy.py, line 959, in _reshape_method return _reshape(a, newshape, order=order)

    File [...]/jax/numpy/lax_numpy.py, line 938, in _reshape return lax.reshape(a, computed_newshape, None)

    File [...]/jax/lax/lax.py, line 640, in reshape old_sizes=onp.shape(operand))

    File [...]/site-packages/jax/core.py, line 182, in bind out_tracer = top_trace.process_primitive(self, tracers, kwargs)

    File [...]/jax/interpreters/partial_eval.py, line 98, in process_primitive return self.default_process_primitive(primitive, tracers, params)

    File [...]/jax/interpreters/partial_eval.py, line 106, in default_process_primitive out_aval = primitive.abstract_eval(*avals, **params)

    File [...]/jax/lax/lax.py, line 1523, in standard_abstract_eval return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))

    File [...]/jax/lax/lax.py, line 2582, in _reshape_shape_rule raise TypeError(msg.format(new_sizes, onp.shape(operand)))

    TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 30) for shape (1, 2, 30).

    REMARK: 1 is n_devices, 2 is batch size, 30 is max_len

    # Steps to reproduce:
    I tried to force your machine_translation.ipynb in Colab to use the GPU but didn't succeed. But maybe for you it's fastest to check what happens with only 1 GPU, as the Colab itself runs smoothly (on a TPU).
    
    # Error logs:
    ...
    
    opened by friesel 7
  • [Feature request] Examples for advanced machine learning tasks

    Since Trax is a successor of tensor2tensor (according to the release notes of tensor2tensor v1.15.0), it would be helpful if you could provide examples for more advanced machine learning tasks. An outstanding feature of tensor2tensor is its numerous (and useful) examples, which Trax currently lacks. Such examples would be especially helpful for machine learning tasks with complex input transformations, like speech recognition or translation with subword encodings.

    documentation 
    opened by cantwbr 7
  • Update core.py -

    Added this note to the LogSoftmax function: Note that the implementation actually computes x - LogSumExp(x), which is mathematically equal to LogSoftmax(x).
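
    As a quick numeric check of that identity (a sketch in plain numpy):

      import numpy as np

      x = np.array([1.0, 2.0, 3.0])
      log_softmax = np.log(np.exp(x) / np.sum(np.exp(x)))  # Definition of LogSoftmax.
      stable = x - np.log(np.sum(np.exp(x)))               # x - LogSumExp(x).
      print(np.allclose(log_softmax, stable))  # True: the two forms agree.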

    cla: yes ready to pull 
    opened by MichalRyszardWojcik 6
  • Error while retrieving transformer weights in intro notebook

    Description

    Initializing weights from the gcloud source leads to a NotFoundError ...

    Environment information

    OS: Pop!_OS 20.04 (Based on Ubuntu 20.04)
    
    $ pip freeze | grep trax
    trax==1.3.4
    
    $ pip freeze | grep tensor
    mesh-tensorflow==0.1.16
    tensor2tensor==1.15.7
    tensorboard==2.3.0
    tensorboard-plugin-wit==1.6.0.post3
    tensorflow==2.3.0
    tensorflow-addons==0.10.0
    tensorflow-datasets==3.2.1
    tensorflow-estimator==2.3.0
    tensorflow-gan==2.0.0
    tensorflow-gpu==2.3.0
    tensorflow-hub==0.8.0
    tensorflow-metadata==0.22.2
    tensorflow-probability==0.7.0
    tensorflow-text==2.3.0
    
    $ pip freeze | grep jax
    jax==0.1.74
    jaxlib==0.1.52
    
    $ python -V
    Python 3.8.2
    

    For bugs: reproduction and error logs

    # Steps to reproduce:
    ...
    model = trax.models.Transformer(
        input_vocab_size=33300, 
        d_model=512, d_ff=2048, 
        n_heads=8, n_encoder_layers=6, n_decoder_layers=6, 
        max_len=2048, mode='predict') 
    # Initialize using pre-trained weights. 
    model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',  weights_only=True)
    
    # Error logs:
    ...
    NotFoundError: Error executing an HTTP request: HTTP response code 404 with body '<?xml version='1.0' encoding='UTF-8'?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message><Details>No such object: trax-ml/models/translation/end_wmt32k.pkl.gz</Details></Error>'
    	 when reading gs://trax-ml/models/translation/end_wmt32k.pkl.gz
    
    
    opened by Cquential 6
  • [Bug] loss_fn argument for Trainer must not be a function since 1.2.4

    Description

    As shown in previous examples, loss_fn should be a callable passed like this:

    trainer = trax.supervised.Trainer(
        model=eval(train_model.selector),
        loss_fn=trax.layers.CrossEntropyLoss,
        optimizer=trax.optimizers.Adam,
        lr_schedule=trax.lr.MultifactorSchedule,
        inputs=trax.supervised.inputs.Inputs(train_stream),
        output_dir=output_dir,
    )
    

    However, since the upgrade to 1.2.4 this no longer works.

    In the trainer_lib the loss_fn gets passed to a Serial constructor:

    https://github.com/google/trax/blob/93f2bd47f5f17aacafe3f312ae56ce6f98d93ee7/trax/supervised/trainer_lib.py#L130

    This in turn runs _ensure_flat in its constructor:

    https://github.com/google/trax/blob/5b1565910a53d0d1175f647cc67db48e334d8f90/trax/layers/combinators.py#L47

    However, all objects in layers have to be of type base.Layer:

    def _ensure_flat(layers):
      """Ensures that layers is a single flat list of Layer instances."""
      if len(layers) == 1 and layers[0] is None:
        layers = ()
      else:
        layers = _deep_flatten(layers)
      for obj in layers:
        if not isinstance(obj, base.Layer):
          raise ValueError(
              f'Found nonlayer object ({obj}) in layers: {layers}')
      return layers
    

    See

    https://github.com/google/trax/blob/5b1565910a53d0d1175f647cc67db48e334d8f90/trax/layers/combinators.py#L775

    Thus we'll see an exception:

    ValueError: Found nonlayer object (<function CrossEntropyLoss at 0x7fc5be59a9e0>) in layers:
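
    A sketch of the likely adjustment (assuming CrossEntropyLoss returns a Layer instance when called): pass an instance rather than the function itself.

      from trax import layers as tl

      loss_layer = tl.CrossEntropyLoss()  # Note the parentheses: a Layer instance.
      print(type(loss_layer))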
    
    opened by stefan-falk 6
  • ImportError: cannot import name 'MergeHeads' from 'trax.layers.attention'

    Description

    ImportError thrown after importing libraries ...

    Environment information

    trax 1.4.1

    OS: Ubuntu 
    
    $ pip freeze | grep trax
    trax                         1.4.1
    
    $ pip freeze | grep tensor
    mesh-tensorflow==0.1.21
    tensor2tensor==1.15.7
    tensorboard==2.11.0
    tensorboard-data-server==0.6.1
    tensorboard-plugin-wit==1.8.1
    tensorflow==2.11.0
    tensorflow-addons==0.18.0
    tensorflow-datasets==4.7.0
    tensorflow-estimator==2.11.0
    tensorflow-gan==2.1.0
    tensorflow-hub==0.12.0
    tensorflow-io-gcs-filesystem==0.28.0
    tensorflow-metadata==1.12.0
    tensorflow-probability==0.7.0
    tensorflow-text==2.11.0
    tensorstore==0.1.28
    
    $ pip freeze | grep jax
    jax==0.3.25
    jaxlib==0.3.25
    
    $ python -V
    Python 3.8.10
    

    For bugs: reproduction and error logs

    # Steps to reproduce:
    !pip install -q -U trax
    
    import numpy as np  # regular ol' numpy
    
    from trax import fastmath
    from trax import layers as tl
    from trax import shapes
    from trax.fastmath import numpy as jnp  # For use in defining new layer types.
    from trax.shapes import ShapeDtype
    from trax.shapes import signature
    # Error logs:
    ---------------------------------------------------------------------------
    ImportError                               Traceback (most recent call last)
    Cell In[22], line 3
          1 import numpy as np  # regular ol' numpy
    ----> 3 from trax import fastmath
          4 from trax import layers as tl
          5 from trax import shapes
    
    File ~/NovaceneAI/trax_projects/.venv/lib/python3.8/site-packages/trax/__init__.py:18
          1 # coding=utf-8
          2 # Copyright 2021 The Trax Authors.
          3 #
       (...)
         13 # See the License for the specific language governing permissions and
         14 # limitations under the License.
         16 """Trax top level import."""
    ---> 18 from trax import data
         19 from trax import fastmath
         20 from trax import layers
    
    File ~/NovaceneAI/trax_projects/.venv/lib/python3.8/site-packages/trax/data/__init__.py:70
         67 from trax.data.inputs import UnBatch
         68 from trax.data.inputs import UniformlySeek
    ---> 70 from trax.data.tf_inputs import add_eos_to_output_features
         71 from trax.data.tf_inputs import BertGlueEvalStream
    ...
         35 from trax.layers.attention import SplitIntoHeads
         38 # Layers are always CamelCase, but functions in general are snake_case
         39 # pylint: disable=invalid-name
    
    ImportError: cannot import name 'MergeHeads' from 'trax.layers.attention'
    
    opened by jpontalba 0
  • Machine Translation Reformer model.pkl for trax 1.4.1?

    Description

    I am trying to translate the Reformer machine_translation code https://github.com/google/trax/blob/master/trax/models/reformer/machine_translation.ipynb that was written for trax 1.2.9 so that it works with trax 1.4.1. However, my program crashes when I try to load model.pkl, with this error: "ModuleNotFoundError: No module named 'trax.history'". My understanding is that trax.history in 1.2.9 was moved to trax.supervised.history in 1.4.1. I believe I need a new model.pkl for 1.4.1 to make it work. Where can I download the new model.pkl? It would also be great if I could download the new config.gin as well. Thanks a lot in advance. ...
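
    A possible unpickling workaround, sketched here as an unverified assumption based on the module move described above: alias the old module path before loading the pickle.

      import importlib
      import sys

      # Let pickles that reference 'trax.history' resolve to the new location.
      sys.modules['trax.history'] = importlib.import_module('trax.supervised.history')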

    Environment information

    OS: <your answer here>
    Ubuntu 20.04
    
    $ pip freeze | grep trax
    # your output here
    trax==1.4.1
    
    $ pip freeze | grep tensor
    # your output here
    tensorboard @ file:///home/conda/feedstock_root/build_artifacts/tensorboard_1664238338171/work/tensorboard-2.10.1-py3-none-any.whl
    tensorboard-data-server @ file:///home/conda/feedstock_root/build_artifacts/tensorboard-data-server_1649932776625/work/tensorboard_data_server-0.6.0-py3-none-manylinux2010_x86_64.whl
    tensorboard-plugin-wit @ file:///home/conda/feedstock_root/build_artifacts/tensorboard-plugin-wit_1641458951060/work/tensorboard_plugin_wit-1.8.1-py3-none-any.whl
    tensorflow==2.10.0
    tensorflow-datasets==4.7.0
    tensorflow-estimator @ file:///home/conda/feedstock_root/build_artifacts/tensorflow-split_1663957899180/work/tensorflow-estimator/wheel_dir/tensorflow_estimator-2.10.0-py2.py3-none-any.whl
    tensorflow-hub==0.12.0
    tensorflow-io-gcs-filesystem==0.27.0
    tensorflow-metadata==1.10.0
    tensorflow-text==2.10.0
    
    $ pip freeze | grep jax
    # your output here
    jax @ file:///home/conda/feedstock_root/build_artifacts/jax_1665610009116/work
    jaxlib==0.3.22
    
    $ python -V
    # your output here
    Python 3.8.10
    

    For bugs: reproduction and error logs

    # Steps to reproduce:
    ...
    

    import sys
    import gin
    import os
    import pickle
    import jax
    import trax
    import numpy as np
    import jax.numpy as jnp
    import sacrebleu
    from trax.data.text_encoder import SubwordTextEncoder
    from tensorflow.io.gfile import GFile

    Load the source text and reference translations into Python

    refs = []
    for lineno, line in enumerate(sacrebleu.smart_open('wmt14-en-de.ref'), 1):
      if line.endswith('\n'):
        line = line[:-1]
      refs.append(line)
    srcs = []
    for lineno, line in enumerate(sacrebleu.smart_open('wmt14-en-de.src'), 1):
      if line.endswith('\n'):
        line = line[:-1]
      srcs.append(line)

    Set up our sub-word tokenizer

    tokenizer = SubwordTextEncoder(
        'gs://trax-ml/reformer/mt/vocab.translate_ende_wmt32k.32768.subwords')

    Encode source sentences using the tokenizer

    input_ids = np.zeros((len(srcs), 128), dtype=jnp.int64)
    for i, x in enumerate(srcs):
      x = tokenizer.encode(x)
      assert len(x) <= 127
      input_ids[i, :len(x)] = x
      input_ids[i, len(x)] = 1

    We'll be using a pre-trained reversible transformer-base model.

    First, load the config (which sets all needed hyperparameters).

    !gsutil cp gs://trax-ml/reformer/mt/config.gin ./config.gin

    gin.parse_config_file('./config.gin')

    Now we load the pre-trained model weights.

    with GFile('gs://trax-ml/reformer/mt/model.pkl', 'rb') as f:
      model_weights = pickle.load(f)['weights']

    # Error logs:
    ...
    

    Traceback (most recent call last):
      File "reformer.py", line 65
        model_weights = pickle.load(f)['weights']
    ModuleNotFoundError: No module named 'trax.history'

    opened by ymcki 0
  • TypeError: float() argument must be a string or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer'

    Description

    Hi, I am trying to follow this tutorial: https://github.com/google/trax/blob/master/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb. Setting the runtime to TPU on Colab used to work a couple of days ago, but now it crashes with this error:

    TypeError: float() argument must be a string or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer'

    This happens at this step: training_loop = training.Loop(model, ...)
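    For context, the traceback below shows training.Loop replicating the optimizer state across devices and then moving it to CPU via jax.device_put. A hypothetical minimal reproduction of that path on a TPU runtime, reconstructed from the traceback rather than taken from the report:

    import jax
    import jax.numpy as jnp

    # Replicate a value across devices, as Trax's for_n_devices does for
    # optimizer state, then move it to CPU the way tl.on_cpu does. On the
    # affected jax/jaxlib versions on TPU, converting the sharded TPU
    # buffers to a NumPy array raises the TypeError above.
    n = jax.device_count()
    replicated = jax.pmap(lambda x: x)(jnp.ones((n, 4)))
    cpu_copy = jax.device_put(replicated, jax.devices('cpu')[0])  # TypeError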

    Environment information

    OS: 
    NAME="Ubuntu"
    VERSION="18.04.6 LTS (Bionic Beaver)"
    ID=ubuntu
    ID_LIKE=debian
    PRETTY_NAME="Ubuntu 18.04.6 LTS"
    VERSION_ID="18.04"
    HOME_URL="https://www.ubuntu.com/"
    SUPPORT_URL="https://help.ubuntu.com/"
    BUG_REPORT_URL="https://bugs.launchpad.net/ubuntu/"
    PRIVACY_POLICY_URL="https://www.ubuntu.com/legal/terms-and-policies/privacy-policy"
    VERSION_CODENAME=bionic
    UBUNTU_CODENAME=bionic
    
    $ pip freeze | grep trax
    trax==1.4.1
    
    
    $ pip freeze | grep tensor
    tensorboard==2.10.0
    tensorboard-data-server==0.6.1
    tensorboard-plugin-wit==1.8.1
    tensorflow==2.10.0
    tensorflow-datasets==4.6.0
    tensorflow-estimator==2.10.0
    tensorflow-gcs-config==2.8.0
    tensorflow-hub==0.12.0
    tensorflow-io-gcs-filesystem==0.26.0
    tensorflow-metadata==1.10.0
    tensorflow-probability==0.16.0
    tensorflow-text==2.10.0
    
    $ pip freeze | grep jax
    jax==0.3.17
    jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.15+cuda11.cudnn805-cp37-none-manylinux2014_x86_64.whl
    
    $ python -V
    Python 3.7.13
    
    

    For bugs: reproduction and error logs

    # Steps to reproduce:
    https://github.com/google/trax/blob/master/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb
    
    ...
    
    # Error logs:
    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    <ipython-input-8-2021642a85f0> in <module>
          9                               train_task,
         10                               eval_tasks=[eval_task],
    ---> 11                               output_dir=output_dir)
    
    16 frames
    /usr/local/lib/python3.7/dist-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
        278 
        279     # Create the optimizer for the training loss function.
    --> 280     self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
        281 
        282     # Sync layers weights/state in memory effcient trainer layers.
    
    /usr/local/lib/python3.7/dist-packages/trax/supervised/training.py in <genexpr>(.0)
        278 
        279     # Create the optimizer for the training loss function.
    --> 280     self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
        281 
        282     # Sync layers weights/state in memory effcient trainer layers.
    
    /usr/local/lib/python3.7/dist-packages/trax/supervised/training.py in _init_trainer(self, task)
        348         task.optimizer.tree_init(model_in_training.weights)
        349       return optimizers.Trainer(
    --> 350           model_in_training, task.optimizer, adasum=self._adasum)
        351     # In the memory-efficient path, we initialize the model here.
        352     blocks, loss_layer = optimizers.trainer.extract_reversible_blocks(
    
    /usr/local/lib/python3.7/dist-packages/trax/optimizers/trainer.py in __init__(self, model_with_loss, optimizer, n_devices, adasum)
         57     # optimizer slots and opt_params may need to be replicated
         58     self._slots, self._opt_params = tl.on_cpu(tl.for_n_devices(
    ---> 59         (self._optimizer.slots, self._optimizer.opt_params), self._n_devices))
         60 
         61     # accelerated version of model+loss to replicate weights and state
    
    /usr/local/lib/python3.7/dist-packages/trax/layers/acceleration.py in on_cpu(x)
        250   """Puts ``x`` in CPU memory in JAX."""
        251   if fastmath.is_backend(fastmath.Backend.JAX):
    --> 252     return jax.device_put(x, jax.devices('cpu')[0])
        253   else:
        254     return x
    
    /usr/local/lib/python3.7/dist-packages/jax/_src/api.py in device_put(x, device)
       2722   """
       2723   with config_explicit_device_put_scope():
    -> 2724     return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)
       2725 
       2726 
    
    /usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py in tree_map(f, tree, is_leaf, *rest)
        203   leaves, treedef = tree_flatten(tree, is_leaf)
        204   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    --> 205   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
        206 
        207 def build_tree(treedef, xs):
    
    /usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py in <genexpr>(.0)
        203   leaves, treedef = tree_flatten(tree, is_leaf)
        204   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    --> 205   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
        206 
        207 def build_tree(treedef, xs):
    
    /usr/local/lib/python3.7/dist-packages/jax/_src/api.py in <lambda>(y)
       2722   """
       2723   with config_explicit_device_put_scope():
    -> 2724     return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)
       2725 
       2726 
    
    /usr/local/lib/python3.7/dist-packages/jax/core.py in bind(self, *args, **params)
        323     assert (not config.jax_enable_checks or
        324             all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
    --> 325     return self.bind_with_trace(find_top_trace(args), args, params)
        326 
        327   def bind_with_trace(self, trace, args, params):
    
    /usr/local/lib/python3.7/dist-packages/jax/core.py in bind_with_trace(self, trace, args, params)
        326 
        327   def bind_with_trace(self, trace, args, params):
    --> 328     out = trace.process_primitive(self, map(trace.full_raise, args), params)
        329     return map(full_lower, out) if self.multiple_results else full_lower(out)
        330 
    
    /usr/local/lib/python3.7/dist-packages/jax/core.py in process_primitive(self, primitive, tracers, params)
        684 
        685   def process_primitive(self, primitive, tracers, params):
    --> 686     return primitive.impl(*tracers, **params)
        687 
        688   def process_call(self, primitive, f, tracers, params):
    
    /usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in _device_put_impl(x, device)
       1219     raise TypeError(
       1220         f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
    -> 1221   return aval_to_result_handler(device, a)(None, *device_put(x, device))
       1222 
       1223 
    
    /usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in device_put(x, device)
       1113   x = xla.canonicalize_dtype(x)
       1114   try:
    -> 1115     return device_put_handlers[type(x)](x, device)
       1116   except KeyError as err:
       1117     raise TypeError(f"No device_put handler for type: {type(x)}") from err
    
    /usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in _device_put_array(x, device)
       1124   if x.dtype == dtypes.float0:
       1125     x = np.zeros(x.shape, dtype=np.dtype(bool))
    -> 1126   return (backend.buffer_from_pyval(x, device),)
       1127 
       1128 def _device_put_scalar(x, device):
    
    /usr/local/lib/python3.7/dist-packages/jax/_src/device_array.py in __array__(self, dtype, context)
        264 
        265   def __array__(self, dtype=None, context=None):
    --> 266     return np.asarray(self._value, dtype=dtype)
        267 
        268   setattr(device_array, "__array__", __array__)
    
    /usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in _sda_value(self)
        803     npy_value = np.empty(self.aval.shape, self.aval.dtype)
        804     for i in self.one_replica_buffer_indices:
    --> 805       npy_value[self.indices[i]] = np.asarray(self.device_buffers[i])
        806     self._npy_value = npy_value
        807   return self._npy_value
    
    TypeError: float() argument must be a string or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer'
    ...
    
    opened by agoliaei 0
  • The Colab button on Knowledge_Tracing_Transformer.ipynb does not open

    The Colab button on Knowledge_Tracing_Transformer.ipynb does not open

    Description

    Small issue here: the "Open in Colab" button in the notebook Knowledge_Tracing_Transformer.ipynb (in the directory "trax/trax/examples") leads to a private file on Google Drive.

    opened by haytamdon 0
Releases (latest: v1.4.1)
  • v1.4.1(Oct 26, 2021)

  • v1.4.0(Oct 26, 2021)

  • v1.3.9(May 21, 2021)

  • v1.3.8(Apr 26, 2021)

  • v1.3.7(Dec 18, 2020)

    Lots of documentation, bug fixes, and miscellaneous changes.

    Features

    • Load checkpoint for fine-tuning in https://github.com/google/trax/commit/0349c3f2c32ed83d0ba54a0f6652766ea5467cca by @henrykmichalewski !
    • Added generic input pipeline for GLUE tasks in https://github.com/google/trax/commit/2f490e83a9be1b802dccf866c52c8908645e10d6 by @henrykmichalewski - thanks a lot!
    • Some nascent support for bfloat16s!

    Models

    • Performer's Favor and CausalFavor - https://github.com/google/trax/commit/77db199392ff967e49156ccb24c916b270b47eca thanks @lukaszkaiser !
    • Funnel-Transformer in https://github.com/google/trax/pull/1156 thanks a lot @mvxxx !
    • BERT for Trax in https://github.com/google/trax/pull/1254 , https://github.com/google/trax/pull/1223, etc by @piotrekp1 - thanks a lot!
    • Residual Exchange Network by @kkanska in https://github.com/google/trax/commit/3a8f402b722a0d9ae934ed399e6532dba8401f9e ! Thanks!

    PRs Merged

    • Fixing broken example links in https://github.com/google/trax/pull/1263 thanks @amtagrwl !
    • Added WideResnet, Deconv, etc. example notebooks in https://github.com/google/trax/pull/1259 , https://github.com/google/trax/pull/1202, https://github.com/google/trax/pull/1232 -- thanks a lot, @SauravMaheshkar !
    • Remove implicit object from the base class in https://github.com/google/trax/pull/1228 thanks @HarshCasper !
    • Fashion MNIST example in https://github.com/google/trax/pull/1199 thanks @Jimexist !
    • Fix PretrainedBERT init in https://github.com/google/trax/pull/1135 by @hepaajan !
    • Typo in the TransformerDecoder input parameters description in https://github.com/google/trax/pull/1100 by @kujaomega !
  • v1.3.6(Oct 21, 2020)

  • v1.3.5(Sep 19, 2020)

    PRs:

    • Thanks @NathanHowell for fixing a bunch of typos in #962
    • Thanks @DarrenZhang01 for contributing to the TF-Numpy extensions code in #954

    ReversibleSerialTrainer for memory-efficient, layer-by-layer training.

    Miscellaneous other issues.

  • v1.3.3(Jul 26, 2020)

    Minor Fixes:

    • Rename gumbel_sample to logsoftmax_sample in a2497cbd8477a11f2b96e87a4d53ce46b845ffa8 (see the sketch after this list)
    • A fix to storing checkpoints in Loop in 88b033c804a4da9021505f2ce8fbb5c7d6500574
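    For context, a minimal sketch of the Gumbel-trick sampling such a helper performs on unnormalized log-probabilities; this is illustrative, not Trax's exact implementation:

    import numpy as np

    # Sample from a categorical distribution given log-probabilities by
    # adding Gumbel(0, 1) noise and taking the argmax; temperature scales
    # the noise (0.0 would reduce to greedy argmax).
    def logsoftmax_sample(log_probs, temperature=1.0, rng=np.random):
        u = rng.uniform(low=1e-9, high=1.0, size=log_probs.shape)
        gumbel = -np.log(-np.log(u))
        return np.argmax(log_probs + temperature * gumbel, axis=-1)

    token = logsoftmax_sample(np.log(np.array([[0.1, 0.7, 0.2]])))  # usually index 1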
  • v1.3.2(Jul 24, 2020)

    Framework:

    • Data pipeline combinators in 17d44710bee74e8bb6e3b34baa4114b0c78160af (see the sketch after this list)
    • Multi-host training in the Loop api 524321e0f77afa58a9fba470e5bccb19ae6bdb92
    • Auto-regressive sampling in fe0fa7886ee33bfac831a969a82b351ccb02941c
    • T2T Tokenizers 6306231ace01ab55c5df0c401b1f3cf4c7b6a32f
    • Early work on multi-task training in 423d664b470612ebbc0e5041e8288eb1a47c30ef thanks to @koz4k
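    A sketch of how such pipeline combinators compose; the dataset, vocab file, lengths, and batch sizes below are illustrative, and the combinator names are as they appear in later Trax versions:

    import trax

    # Illustrative pipeline: stream a TFDS dataset, tokenize, shuffle,
    # filter and bucket by length, and attach loss weights. All concrete
    # values here are example choices, not this commit's defaults.
    data_pipeline = trax.data.Serial(
        trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True),
        trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
        trax.data.Shuffle(),
        trax.data.FilterByLength(max_length=2048, length_keys=[0]),
        trax.data.BucketByLength(boundaries=[32, 128, 512, 2048],
                                 batch_sizes=[512, 128, 32, 8, 1],
                                 length_keys=[0]),
        trax.data.AddLossWeights())
    batches = data_pipeline()          # a generator of batched examples
    example_batch = next(batches)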

    Tasks / Gin files:

    • IMDB dataset in 08bdb50db41a089b6e12c61a8a29981225bf3566
    • Machine Translation en-pl in 5fb8aa8c5cb86dabb2338938c745996d5d87d996
  • v1.3.1(Jul 2, 2020)

    Miscellaneous fixes.

    • tl.Embedding now has the same signature as PyTorch/TF (see the sketch after this list)
    • train.lr_schedule (function object) -> train.lr_schedule_fn (function)
    • Report loss back to training.Loop
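    A sketch of the resulting signature, with illustrative sizes: tl.Embedding now takes (vocab_size, d_feature), matching torch.nn.Embedding(num_embeddings, embedding_dim).

    import numpy as np
    import trax.layers as tl
    from trax import shapes

    # Illustrative sizes; the point is the argument order (vocab_size first).
    emb = tl.Embedding(vocab_size=8192, d_feature=256)
    token_ids = np.arange(10, dtype=np.int32)[None, :]   # shape (1, 10)
    emb.init(shapes.signature(token_ids))
    vectors = emb(token_ids)                             # shape (1, 10, 256)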
  • v1.3.0(Jun 30, 2020)

    Trax now has docs at https://trax-ml.readthedocs.io/en/latest/trax.html thanks to @j2i2 !

    Many usability changes, especially in trax.supervised.training.TrainTask/EvalTask/Loop, docs, comments etc.

    • flat saved model/checkpoint representation
    • lr schedules simplified: they now just take the step number (see the sketch after this list).
    • configs are now in supervised/configs and rl/configs.
    • Cleanup of obsolete RL code.
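    A sketch of the simplified interface: a schedule is now just a function from the step number to a learning rate. The warmup length and peak value below are made-up examples:

    # Illustrative schedule under the simplified interface: linear warmup
    # followed by inverse-square-root decay. All constants are examples.
    def lr_schedule(step):
        warmup_steps = 1000.0
        peak_lr = 1e-3
        step = max(float(step), 1.0)
        return peak_lr * min(step / warmup_steps, (warmup_steps / step) ** 0.5)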

    Also rapid development of the tf-numpy codebase!

  • v1.2.4(Apr 18, 2020)

    Merged PRs:

    • #459 by @w4-sjcho - adding names to layers, aiding debuggability thanks a lot!
    • #256 and #300 by @stephenjfox and @satyarohith refining the README.md language, thanks a lot folks!
    • #313 #312 #436 #396 from @pkol with lots of bugfixes, thanks a lot @pkol !
    • #409 by @pkol -- a special shoutout to this PR: it fixes a long-standing issue that prevented the process from terminating, by tracking down the zombie threads -- thanks a lot for this, @pkol !
    • #386 another shoutout to @pkol for an amazing speedup in the RL code -- thanks a lot again !
    • #344 a psum bugfix with tf backend from @fsx950223 - thanks a lot !
    • #335 a bugfix from @friesel - thanks a lot Phillip !
    • #315 better exception handling by @cool-RR - thanks a lot !

    Reformer:

    • BERT initialization and finetuning by Nikita!
    • Many changes including ReformerLM on C4 dataset.

    RL:

    • New 'light' RL code in the Trax supervised style, check it out!
    • AWR in the old code working with MuJoCo tasks.

    And many more changes to the Trax framework!

  • v1.2.3(Feb 25, 2020)

    Reformer

    • Reversible Transformer model for machine translation and other encoder-decoder tasks
    • Add code for beam search, sampling, and greedy decoding (see trax.models.beam_search.Search)
    • Memory-efficient attention classes have been re-written to use even less memory and to support faster decoding (see the new SelfAttention, LSHSelfAttention and EncDecAttention classes)

    RL

    • Implemented the Advantage-Weighted Regression (AWR) algorithm, a simple off-policy reinforcement learning algorithm (see the sketch after this list).
    • Extracted out a PolicyBasedTrainer, so ppo_trainer.PPO and awr_trainer.AwrTrainer now both inherit from it.
    • Refactoring of the serialization code in the RL part, thanks to @koz4k !
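    A sketch of the idea at the core of AWR: the policy is regressed toward observed actions, weighted by exponentiated advantages. The hyperparameters below are typical illustrative values, not Trax's exact ones:

    import numpy as np

    # Advantage-weighted regression weights: exponentiate the advantage
    # (return minus value baseline) with temperature beta, then clip.
    def awr_weights(returns, values, beta=1.0, w_max=20.0):
        advantages = returns - values
        return np.minimum(np.exp(advantages / beta), w_max)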

    Framework

    • A lot of code cleanup and refactoring of the core abstractions by Jonni, thanks Jonni!

    TF Numpy

    • More ops added by @wangpengmit !
  • v1.2.2(Jan 17, 2020)

  • v1.2.1(Jan 16, 2020)

  • v1.2.0(Jan 15, 2020)

    New Models

    • Reformer Implementation - https://arxiv.org/abs/2001.04451 Thanks Nikita Kitaev, @lukaszkaiser and @levskaya !

    Colabs

    • Reformer Colabs https://github.com/google/trax/commit/9dbf8636a34cffbd421bedb2a1e3d7fe006346c0
    • Transformer Colab https://github.com/google/trax/commit/d2c5b84b8db53888013c56980ccbafba077ee8f2

    Framework Changes

    • Ongoing cleanups and API simplifications.
    • Optimization by @jekbradbury - thanks James!

    PRs

    • Consistent logging via absl.logging and setup.py fixes, thanks to @lkhphuc in #198
    • Code cleanups by @cclauss in #196
    • Code cleanup by @pkol in #134 - thanks!
    • Bug fix by @pzielinski-nyc in #151 - thanks!