Trax — Deep Learning with Clear Code and Speed

Overview

Trax is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained by the Google Brain team. This notebook (run it in Colab) shows how to use Trax and where to find more information.

  1. Run a pre-trained Transformer: create a translator in a few lines of code
  2. Features and resources: API docs, where to talk to us, how to open an issue and more
  3. Walkthrough: how Trax works, how to make new models and train on your own data

We welcome contributions to Trax! We welcome PRs with code for new models and layers as well as improvements to our code and documentation. We especially love notebooks that explain how models work and show how to use them to solve problems!

General Setup

Execute the following cell (once) before running any of the code samples.

import os
import numpy as np

!pip install -q -U trax
import trax

1. Run a pre-trained Transformer

Here is how you create an English-German translator in a few lines of code:

# Create a Transformer model.
# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
model = trax.models.Transformer(
    input_vocab_size=33300,
    d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048, mode='predict')

# Initialize using pre-trained weights.
model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
                     weights_only=True)

# Tokenize a sentence.
sentence = 'It is nice to learn new things today!'
tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                    vocab_dir='gs://trax-ml/vocabs/',
                                    vocab_file='ende_32k.subword'))[0]

# Decode from the Transformer.
tokenized = tokenized[None, :]  # Add batch dimension.
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
    model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.

# De-tokenize.
tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
translation = trax.data.detokenize(tokenized_translation,
                                   vocab_dir='gs://trax-ml/vocabs/',
                                   vocab_file='ende_32k.subword')
print(translation)
Es ist schön, heute neue Dinge zu lernen!
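The `temperature` comment in the decoding call can be made concrete. Below is a generic NumPy sketch of temperature-scaled sampling from a vector of logits; it is not Trax's `autoregressive_sample` code, and `sample_with_temperature` is a hypothetical helper. In this sketch a temperature of 0 falls back to greedy argmax, the usual convention, while higher temperatures flatten the distribution.

```python
import numpy as np


def sample_with_temperature(logits, temperature, rng):
    """Samples a token ID from logits after temperature scaling (hypothetical helper)."""
    logits = np.asarray(logits, dtype=np.float64)
    if temperature == 0.0:
        # Greedy decoding: always pick the most likely token.
        return int(np.argmax(logits))
    scaled = logits / temperature
    scaled -= scaled.max()  # Subtract the max for numerical stability.
    probs = np.exp(scaled)
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))


rng = np.random.default_rng(0)
logits = [2.0, 1.0, 0.1]
print(sample_with_temperature(logits, 0.0, rng))  # Always 0 (the argmax).
# A high temperature makes less likely tokens appear much more often.
samples = [sample_with_temperature(logits, 5.0, rng) for _ in range(1000)]
print(np.bincount(samples, minlength=3))
```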

2. Features and resources

Trax includes basic models (like ResNet, LSTM, Transformer) and RL algorithms (like REINFORCE, A2C, PPO). It is also actively used for research and includes new models like the Reformer and new RL algorithms like AWR. Trax has bindings to a large number of deep learning datasets, including Tensor2Tensor and TensorFlow Datasets.

You can use Trax either as a library from your own Python scripts and notebooks or as a binary from the shell, which can be more convenient for training large models. It runs without any changes on CPUs, GPUs and TPUs.

3. Walkthrough

You can learn here how Trax works, how to create new models and how to train them on your own data.

Tensors and Fast Math

The basic units flowing through Trax models are tensors: multi-dimensional arrays, sometimes also called numpy arrays after numpy, the most widely used package for tensor operations. If you don't know how to operate on tensors, take a look at the numpy guide; Trax uses the numpy API for that too.

In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. The trax.fastmath package provides both through its backends: JAX and TensorFlow NumPy.

from trax.fastmath import numpy as fastnp
trax.fastmath.use_backend('jax')  # Can be 'jax' or 'tensorflow-numpy'.

matrix = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(f'matrix = \n{matrix}')
vector = fastnp.ones(3)
print(f'vector = {vector}')
product = fastnp.dot(vector, matrix)
print(f'product = {product}')
tanh = fastnp.tanh(product)
print(f'tanh(product) = {tanh}')
matrix = 
[[1 2 3]
 [4 5 6]
 [7 8 9]]
vector = [1. 1. 1.]
product = [12. 15. 18.]
tanh(product) = [0.99999994 0.99999994 0.99999994]
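Because `trax.fastmath.numpy` follows the NumPy API, code written against that API can run on different array modules. The sketch below passes the array module in as a parameter (`xp` and `dense_tanh` are hypothetical names chosen for illustration) and runs it with plain NumPy; passing `fastnp` from the cell above should behave the same way, though only NumPy is exercised here.

```python
import numpy as np


def dense_tanh(xp, vector, matrix):
    """Computes tanh(vector . matrix) with whatever NumPy-like module `xp` is."""
    return xp.tanh(xp.dot(vector, matrix))


matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
vector = np.ones(3)
print(np.dot(vector, matrix))          # [12. 15. 18.]: the column sums.
print(dense_tanh(np, vector, matrix))  # Values very close to 1.
```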

Gradients can be calculated using trax.fastmath.grad.

def f(x):
  return 2.0 * x * x

grad_f = trax.fastmath.grad(f)

print(f'grad(2x^2) at 1 = {grad_f(1.0)}')
grad(2x^2) at 1 = 4.0
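A quick way to sanity-check a gradient like the one above is a central finite difference. The sketch uses plain Python and the analytic derivative of 2x², which is 4x; `numerical_grad` is a hypothetical helper, not part of `trax.fastmath`.

```python
def f(x):
    return 2.0 * x * x


def numerical_grad(fn, x, eps=1e-5):
    """Central-difference approximation of fn'(x)."""
    return (fn(x + eps) - fn(x - eps)) / (2 * eps)


for x in [1.0, -0.5, 3.0]:
    # The analytic derivative of 2x^2 is 4x; the two columns should agree closely.
    print(x, numerical_grad(f, x), 4 * x)
```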

Layers

Layers are basic building blocks of Trax models. You will learn all about them in the layers intro but for now, just take a look at the implementation of one core Trax layer, Embedding:

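from trax.fastmath import numpy as jnp
from trax.layers import base
from trax.layers import initializers as init


class Embedding(base.Layer):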
  """Trainable layer that maps discrete tokens/IDs to vectors."""

  def __init__(self,
               vocab_size,
               d_feature,
               kernel_initializer=init.RandomNormalInitializer(1.0)):
    """Returns an embedding layer with given vocabulary size and vector size.

    Args:
      vocab_size: Size of the input vocabulary. The layer will assign a unique
          vector to each ID in `range(vocab_size)`.
      d_feature: Dimensionality/depth of the output vectors.
      kernel_initializer: Function that creates (random) initial vectors for
          the embedding.
    """
    super().__init__(name=f'Embedding_{vocab_size}_{d_feature}')
    self._d_feature = d_feature  # feature dimensionality
    self._vocab_size = vocab_size
    self._kernel_initializer = kernel_initializer

  def forward(self, x):
    """Returns embedding vectors corresponding to input token IDs.

    Args:
      x: Tensor of token IDs.

    Returns:
      Tensor of embedding vectors.
    """
    return jnp.take(self.weights, x, axis=0, mode='clip')

  def init_weights_and_state(self, input_signature):
    """Returns tensor of newly initialized embedding vectors."""
    del input_signature
    shape_w = (self._vocab_size, self._d_feature)
    w = self._kernel_initializer(shape_w, self.rng)
    self.weights = w

Layers with trainable weights like Embedding need to be initialized with the signature (shape and dtype) of the input, and then can be run by calling them.

from trax import layers as tl

# Create an input tensor x.
x = np.arange(15)
print(f'x = {x}')

# Create the embedding layer.
embedding = tl.Embedding(vocab_size=20, d_feature=32)
embedding.init(trax.shapes.signature(x))

# Run the layer -- y = embedding(x).
y = embedding(x)
print(f'shape of y = {y.shape}')
x = [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14]
shape of y = (15, 32)
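Embedding.forward above is a row lookup: `jnp.take(self.weights, x, axis=0, mode='clip')`. The same operation in plain NumPy, with a random matrix standing in for the initialized weights (an assumption for illustration), shows both the output shape and what `mode='clip'` does with out-of-range IDs.

```python
import numpy as np

vocab_size, d_feature = 20, 32
rng = np.random.default_rng(0)
# Stand-in for the weights created by init_weights_and_state.
weights = rng.normal(size=(vocab_size, d_feature))

x = np.arange(15)
y = np.take(weights, x, axis=0, mode='clip')
print(y.shape)  # (15, 32)

# With mode='clip', IDs outside [0, vocab_size) are clipped to the nearest valid row.
out_of_range = np.take(weights, np.array([25]), axis=0, mode='clip')
print(np.array_equal(out_of_range[0], weights[vocab_size - 1]))  # True
```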

Models

Models in Trax are built from layers, most often using the Serial and Branch combinators. You can read more about those combinators in the layers intro and see the code for many models in trax/models/, e.g., the implementation of the Transformer language model. Below is an example of how to build a sentiment classification model.

model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),  # Average on axis 1 (length of sentence).
    tl.Dense(2),      # Classify 2 classes.
    tl.LogSoftmax()   # Produce log-probabilities.
)

# You can print model structure.
print(model)
Serial[
  Embedding_8192_256
  Mean
  Dense_2
  LogSoftmax
]
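To see what each layer in this Serial model does to the data, here is a NumPy sketch of the same forward pass: embedding lookup, mean over the sequence axis, a dense layer, then log-softmax. The random weights, the smaller vocabulary, and the helper names are hypothetical; Trax's own layers handle initialization and batching for you.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, d_feature, n_classes = 100, 16, 2  # Smaller than the model above.

embedding = rng.normal(size=(vocab_size, d_feature))
dense_w = rng.normal(size=(d_feature, n_classes))
dense_b = np.zeros(n_classes)


def log_softmax(z):
    """Numerically stable log-softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def forward(token_ids):
    x = embedding[token_ids]        # (batch, length) -> (batch, length, d_feature)
    x = x.mean(axis=1)              # Mean over the sentence: (batch, d_feature)
    logits = x @ dense_w + dense_b  # Dense(2): (batch, n_classes)
    return log_softmax(logits)      # Log-probabilities


batch = rng.integers(0, vocab_size, size=(4, 10))
log_probs = forward(batch)
print(log_probs.shape)                # (4, 2)
print(np.exp(log_probs).sum(axis=1))  # Each row sums to 1.
```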

Data

To train your model, you need data. In Trax, data streams are represented as Python iterators, so you can call next(data_stream) and get a tuple, e.g., (inputs, targets). Trax makes it easy to use TensorFlow Datasets, and you can also get an iterator from your own text file using the standard open('my_file.txt').

train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()
eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()
print(next(train_stream))  # See one example.
(b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it.", 0)
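Since a data stream is just a Python iterator of tuples, you can build one for your own data with a generator. The sketch below reads tab-separated `text<TAB>label` lines; the file format and the `text_label_stream` name are assumptions for illustration, and `io.StringIO` stands in for `open('my_file.txt')`. Like a training stream, it loops over the data forever.

```python
import io


def text_label_stream(file_obj):
    """Yields (text, label) tuples from 'text<TAB>label' lines, cycling forever."""
    lines = [line.rstrip('\n') for line in file_obj if line.strip()]
    if not lines:
        return  # Nothing to stream.
    while True:
        for line in lines:
            text, label = line.rsplit('\t', 1)
            yield text, int(label)


fake_file = io.StringIO('A wonderful film.\t1\nA dull, tedious mess.\t0\n')
stream = text_label_stream(fake_file)
print(next(stream))  # ('A wonderful film.', 1)
print(next(stream))  # ('A dull, tedious mess.', 0)
print(next(stream))  # ('A wonderful film.', 1) again: the stream cycles.
```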

Using the trax.data module you can create input processing pipelines, e.g., to tokenize and shuffle your data. You build a pipeline with trax.data.Serial; the result is a function that you apply to a stream to get a processed stream.

data_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
    trax.data.Shuffle(),
    trax.data.FilterByLength(max_length=2048, length_keys=[0]),
    trax.data.BucketByLength(boundaries=[  32, 128, 512, 2048],
                             batch_sizes=[256,  64,  16,    4, 1],
                             length_keys=[0]),
    trax.data.AddLossWeights()
  )
train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)
example_batch = next(train_batches_stream)
print(f'shapes = {[x.shape for x in example_batch]}')  # Check the shapes.
shapes = [(4, 1024), (4,), (4,)]
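BucketByLength groups examples of similar length so that each batch needs little padding, and uses larger batches for shorter examples. The sketch below shows only the bucket-selection rule, under a simple assumption: an example goes into the first bucket whose boundary is at least its length, with one extra bucket for anything longer. It is not Trax's implementation, and `pick_bucket` is a hypothetical name; the exact comparison and padding rules in Trax may differ.

```python
import bisect


def pick_bucket(length, boundaries, batch_sizes):
    """Returns (bucket_index, batch_size) for an example of the given length."""
    # len(batch_sizes) == len(boundaries) + 1: the last bucket holds anything
    # longer than the largest boundary.
    assert len(batch_sizes) == len(boundaries) + 1
    index = bisect.bisect_left(boundaries, length)  # First boundary >= length.
    return index, batch_sizes[index]


boundaries = [32, 128, 512, 2048]
batch_sizes = [256, 64, 16, 4, 1]
for length in [10, 32, 100, 600, 3000]:
    print(length, pick_bucket(length, boundaries, batch_sizes))
```

Under this rule a 600-token review lands in a batch of 4, matching the batch dimension printed above; the padded length of 1024 comes from Trax's own padding logic, which this sketch does not model. Because FilterByLength(max_length=2048) runs first, the last bucket would rarely if ever be used here.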

Supervised training

When you have the model and the data, use trax.supervised.training to define training and eval tasks and create a training loop. The Trax training loop runs the optimization and creates TensorBoard logs and model checkpoints for you.

from trax.supervised import training

# Training task.
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    loss_layer=tl.WeightedCategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=500,
)

# Evaluation task.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
    n_eval_batches=20  # For less variance in eval numbers.
)

# Training loop saves checkpoints to output_dir.
output_dir = os.path.expanduser('~/output_dir/')
!rm -rf {output_dir}
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)

# Run 2000 steps (batches).
training_loop.run(2000)
Step      1: Ran 1 train steps in 0.78 secs
Step      1: train WeightedCategoryCrossEntropy |  1.33800304
Step      1: eval  WeightedCategoryCrossEntropy |  0.71843582
Step      1: eval      WeightedCategoryAccuracy |  0.56562500

Step    500: Ran 499 train steps in 5.77 secs
Step    500: train WeightedCategoryCrossEntropy |  0.62914723
Step    500: eval  WeightedCategoryCrossEntropy |  0.49253047
Step    500: eval      WeightedCategoryAccuracy |  0.74062500

Step   1000: Ran 500 train steps in 5.03 secs
Step   1000: train WeightedCategoryCrossEntropy |  0.42949259
Step   1000: eval  WeightedCategoryCrossEntropy |  0.35451687
Step   1000: eval      WeightedCategoryAccuracy |  0.83750000

Step   1500: Ran 500 train steps in 4.80 secs
Step   1500: train WeightedCategoryCrossEntropy |  0.41843575
Step   1500: eval  WeightedCategoryCrossEntropy |  0.35207348
Step   1500: eval      WeightedCategoryAccuracy |  0.82109375

Step   2000: Ran 500 train steps in 5.35 secs
Step   2000: train WeightedCategoryCrossEntropy |  0.38129005
Step   2000: eval  WeightedCategoryCrossEntropy |  0.33760912
Step   2000: eval      WeightedCategoryAccuracy |  0.85312500
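The loss and metrics above weight each example; AddLossWeights adds the weights array that appears as the third element of the batch shapes printed earlier, so an example can be down-weighted or masked out with weight 0. Below is a NumPy sketch of weighted category cross-entropy and accuracy, assuming a weighted mean normalized by the sum of the weights. It shows the idea and is not Trax's exact code; the function names are our own.

```python
import numpy as np


def weighted_category_cross_entropy(log_probs, targets, weights):
    """Weighted mean of -log p(target) over examples."""
    picked = np.take_along_axis(log_probs, targets[:, None], axis=-1)[:, 0]
    return -np.sum(picked * weights) / np.sum(weights)


def weighted_category_accuracy(log_probs, targets, weights):
    """Weighted fraction of examples whose argmax matches the target."""
    correct = (np.argmax(log_probs, axis=-1) == targets).astype(np.float64)
    return np.sum(correct * weights) / np.sum(weights)


log_probs = np.log(np.array([[0.9, 0.1],
                             [0.2, 0.8],
                             [0.6, 0.4],
                             [0.3, 0.7]]))
targets = np.array([0, 1, 1, 1])
weights = np.array([1.0, 1.0, 1.0, 0.0])  # The last example is masked out.

print(weighted_category_accuracy(log_probs, targets, weights))  # 2 of 3 weighted examples correct.
print(weighted_category_cross_entropy(log_probs, targets, weights))
```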

After training the model, run it like any layer to get results.

example_input = next(eval_batches_stream)[0][0]
example_input_str = trax.data.detokenize(example_input, vocab_file='en_8k.subword')
print(f'example input_str: {example_input_str}')
sentiment_log_probs = model(example_input[None, :])  # Add batch dimension.
print(f'Model returned sentiment probabilities: {np.exp(sentiment_log_probs)}')
example input_str: I first saw this when I was a teen in my last year of Junior High. I was riveted to it! I loved the special effects, the fantastic places and the trial-aspect and flashback method of telling the story.<br /><br />Several years later I read the book and while it was interesting and I could definitely see what Swift was trying to say, I think that while it's not as perfect as the book for social commentary, as a story the movie is better. It makes more sense to have it be one long adventure than having Gulliver return after each voyage and making a profit by selling the tiny Lilliput sheep or whatever.<br /><br />It's much more arresting when everyone thinks he's crazy and the sheep DO make a cameo anyway. As a side note, when I saw Laputa I was stunned. It looks very much like the Kingdom of Zeal from the Chrono Trigger video game (1995) that also made me like this mini-series even more.<br /><br />I saw it again about 4 years ago, and realized that I still enjoyed it just as much. Really high quality stuff and began an excellent run of Sweeps mini-series for NBC who followed it up with the solid Merlin and interesting Alice in Wonderland.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
Model returned sentiment probabilities: [[3.984500e-04 9.996014e-01]]
Comments
  • Example code to create a machine translation using trax transformer model.

    Description

    Hi, I am a little new to this area, so please excuse me if I am asking a stupid question. I want a machine translation model using Transformers, trained on my own data. Could I get example code showing how to pass my own data (both inputs and targets) to the Trax Transformer so that it can predict the translation?

    I will be using my own laptop, which does not have a GPU, for preparing and running the model.

    Also, if possible, could I get example code for speech recognition using the Transformer model as well?

    Thanks Nagaraju

    Environment information

    OS: Windows 10

    opened by nag0811 12
  • TFDS From Master Branch Raises NoneType Not Subscriptable Error

    Description

    This isn't from the current release on pip, but on February 11 a change was made to master that causes TFDS to crash with a "NoneType is not subscriptable" error on my computer.

    In trax.trax.data.tf_inputs.TFDS there are these lines:

      host_id = jax.host_id() if host_id is None else host_id
      n_hosts = n_hosts or jax.host_count()
      if n_hosts > 1:
        subsplit = (host_id / n_hosts, (host_id + 1) / n_hosts)
      else:
        subsplit = None
    

    On my computer n_hosts = 1, so subsplit is None, and it gets passed to the _train_and_eval_dataset function. Inside that function are these lines:

      if eval_holdout_examples > 0 or subsplit is not None:
        n_train = train_examples - eval_holdout_examples
        train_start = int(n_train * subsplit[0])
        train_end = int(n_train * subsplit[1])
    

    Because the conditional uses or and eval_holdout_examples is greater than 0, the branch is entered even though subsplit is None, so subscripting subsplit[0] raises an exception.

    I don't know if now is the right time to report this, since I'm pulling from master (reverting to the last February 10 commit fixes it for me), but I thought it might be helpful in case it isn't already known.

    Environment information

    OS: Ubuntu 20.04 (using the nvidia docker container)
    
    $ pip freeze | grep trax
    -e trax==1.3.7
    
    $ pip freeze | grep tensor
    mesh-tensorflow==0.1.18
    tensorboard==2.4.1
    tensorboard-plugin-wit==1.8.0
    tensorflow==2.4.1
    tensorflow-datasets==4.2.0
    tensorflow-estimator==2.4.0
    tensorflow-hub==0.11.0
    tensorflow-metadata==0.28.0
    tensorflow-text==2.4.3
    
    $ pip freeze | grep jax
    jax==0.2.10
    jaxlib==0.1.61+cuda111
    
    $ python -V
    Python 3.8.5
    

    For bugs: reproduction and error logs

    Steps to reproduce:

    import trax
    path = "data"
    data_set = "opus/medical"
    train_stream_fn = trax.data.TFDS(data_set,
                                     data_dir=path,
                                     keys=('en', 'de'),
                                     eval_holdout_size=0.01,
                                     train=True)
    

    Error logs:

    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    <ipython-input-6-fb62d04026f5> in <module>
          4 # data_set = "para_crawl/ende"
          5 
    ----> 6 train_stream_fn = trax.data.TFDS(data_set,
          7                                  data_dir=path,
          8                                  keys=('en', 'de'),
    
    /usr/local/lib/python3.8/dist-packages/gin/config.py in gin_wrapper(*args, **kwargs)
       1067       scope_info = " in scope '{}'".format(scope_str) if scope_str else ''
       1068       err_str = err_str.format(name, fn_or_cls, scope_info)
    -> 1069       utils.augment_exception_message_and_reraise(e, err_str)
       1070 
       1071   return gin_wrapper
    
    /usr/local/lib/python3.8/dist-packages/gin/utils.py in augment_exception_message_and_reraise(exception, message)
         39   proxy = ExceptionProxy()
         40   ExceptionProxy.__qualname__ = type(exception).__qualname__
    ---> 41   raise proxy.with_traceback(exception.__traceback__) from None
         42 
         43 
    
    /usr/local/lib/python3.8/dist-packages/gin/config.py in gin_wrapper(*args, **kwargs)
       1044 
       1045     try:
    -> 1046       return fn(*new_args, **new_kwargs)
       1047     except Exception as e:  # pylint: disable=broad-except
       1048       err_str = ''
    
    /usr/local/lib/python3.8/dist-packages/gin/config.py in gin_wrapper(*args, **kwargs)
       1067       scope_info = " in scope '{}'".format(scope_str) if scope_str else ''
       1068       err_str = err_str.format(name, fn_or_cls, scope_info)
    -> 1069       utils.augment_exception_message_and_reraise(e, err_str)
       1070 
       1071   return gin_wrapper
    
    /usr/local/lib/python3.8/dist-packages/gin/utils.py in augment_exception_message_and_reraise(exception, message)
         39   proxy = ExceptionProxy()
         40   ExceptionProxy.__qualname__ = type(exception).__qualname__
    ---> 41   raise proxy.with_traceback(exception.__traceback__) from None
         42 
         43 
    
    /usr/local/lib/python3.8/dist-packages/gin/config.py in gin_wrapper(*args, **kwargs)
       1044 
       1045     try:
    -> 1046       return fn(*new_args, **new_kwargs)
       1047     except Exception as e:  # pylint: disable=broad-except
       1048       err_str = ''
    
    ~/trax/trax/data/tf_inputs.py in TFDS(dataset_name, data_dir, tfds_preprocess_fn, keys, train, shuffle_train, host_id, n_hosts, eval_holdout_size)
        279   else:
        280     subsplit = None
    --> 281   (train_data, eval_data, _) = _train_and_eval_dataset(
        282       dataset_name, data_dir, eval_holdout_size,
        283       train_shuffle_files=shuffle_train, subsplit=subsplit)
    
    ~/trax/trax/data/tf_inputs.py in _train_and_eval_dataset(dataset_name, data_dir, eval_holdout_size, train_shuffle_files, eval_shuffle_files, subsplit)
        224   if eval_holdout_examples > 0 or subsplit is not None:
        225     n_train = train_examples - eval_holdout_examples
    --> 226     train_start = int(n_train * subsplit[0])
        227     train_end = int(n_train * subsplit[1])
        228     if train_end - train_start < 1:
    
    TypeError: 'NoneType' object is not subscriptable
      In call to configurable 'TFDS' (<function TFDS at 0x7f960c527280>)
      In call to configurable 'TFDS' (<function TFDS at 0x7f960c526f70>)
    
    opened by necromuralist 10
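The failure in this issue comes from treating `subsplit=None` as if it were a `(start, end)` pair of fractions. One way to avoid it, sketched below as a standalone function rather than the patch actually applied to Trax, is to default a missing subsplit to the whole training range. The function name `train_range` and its arguments are hypothetical.

```python
def train_range(train_examples, eval_holdout_examples, subsplit=None):
    """Returns (start, end) indices of the training examples to use.

    subsplit is an optional (start_fraction, end_fraction) pair used when the
    data is sharded across hosts; None means "use the whole training range".
    """
    if subsplit is None:
        subsplit = (0.0, 1.0)
    n_train = train_examples - eval_holdout_examples
    train_start = int(n_train * subsplit[0])
    train_end = int(n_train * subsplit[1])
    if train_end - train_start < 1:
        raise ValueError('Requested train subsplit contains no examples.')
    return train_start, train_end


# Single host with a 1% eval holdout: no crash, all non-holdout examples are used.
print(train_range(1000, 10))                       # (0, 990)
# Two hosts: the second host takes the second half of the training examples.
print(train_range(1000, 10, subsplit=(0.5, 1.0)))  # (495, 990)
```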
  • Add the general convolution operation to extensions

    TF XLA version: https://www.tensorflow.org/xla/operation_semantics?hl=en#conv_convolution; JAX version: https://jax.readthedocs.io/en/stable/_autosummary/jax.lax.conv_general_dilated.html

    cla: yes ready to pull 
    opened by DarrenZhang01 9
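A general convolution combines window strides, padding, and dilation of the input (lhs) and kernel (rhs), as described in the XLA and JAX documentation linked in this PR. The NumPy sketch below covers only a 2-D case with NCHW input, OIHW kernel, 'VALID' padding, window strides and kernel (rhs) dilation; it illustrates the semantics and is not the Trax extension itself. `conv2d_valid` is a hypothetical name. Like XLA's convolution, it computes a cross-correlation (the kernel is not flipped).

```python
import numpy as np


def conv2d_valid(lhs, rhs, strides=(1, 1), rhs_dilation=(1, 1)):
    """lhs: (N, C, H, W) input; rhs: (O, C, KH, KW) kernel. 'VALID' padding only."""
    n, c, h, w = lhs.shape
    o, c_k, kh, kw = rhs.shape
    assert c == c_k, 'input and kernel channel counts must match'
    sh, sw = strides
    dh, dw = rhs_dilation
    # Effective kernel extent after inserting (dilation - 1) gaps between taps.
    eff_kh = (kh - 1) * dh + 1
    eff_kw = (kw - 1) * dw + 1
    out_h = (h - eff_kh) // sh + 1
    out_w = (w - eff_kw) // sw + 1
    out = np.zeros((n, o, out_h, out_w), dtype=np.result_type(lhs, rhs))
    for i in range(out_h):
        for j in range(out_w):
            # Gather the dilated window: (N, C, KH, KW).
            window = lhs[:, :, i * sh:i * sh + eff_kh:dh, j * sw:j * sw + eff_kw:dw]
            # Contract over channels and kernel positions for every output feature.
            out[:, :, i, j] = np.einsum('nchw,ochw->no', window, rhs)
    return out


lhs = np.arange(16, dtype=np.float64).reshape(1, 1, 4, 4)
rhs = np.ones((1, 1, 2, 2))
print(conv2d_valid(lhs, rhs)[0, 0])                       # Sums of each 2x2 window.
print(conv2d_valid(lhs, rhs, rhs_dilation=(2, 2))[0, 0])  # Sums of the corners of each 3x3 window.
```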
  • Add the general dot operation to Trax NumPy extensions

    opened by DarrenZhang01 8
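The "general dot" operation generalizes matrix multiplication: as in `jax.lax.dot_general`, the caller names which axes are contracted and which are batch axes. The sketch below implements a simplified version with `np.einsum`; the function name and argument layout are hypothetical. It returns batch axes first, then the remaining lhs axes, then the remaining rhs axes, which is the output ordering XLA's DotGeneral uses as we understand it.

```python
import string

import numpy as np


def dot_general(lhs, rhs, contracting, batch=((), ())):
    """Simplified dot_general.

    contracting: (lhs_axes, rhs_axes) to sum over, paired position by position.
    batch: (lhs_axes, rhs_axes) that are shared, paired position by position.
    Output axes: batch axes, then free lhs axes, then free rhs axes.
    """
    (lhs_c, rhs_c), (lhs_b, rhs_b) = contracting, batch
    letters = iter(string.ascii_letters)
    lhs_sub = [next(letters) for _ in range(lhs.ndim)]
    rhs_sub = [next(letters) for _ in range(rhs.ndim)]
    # Paired axes share a subscript letter so that einsum aligns them.
    for l_ax, r_ax in zip(tuple(lhs_c) + tuple(lhs_b), tuple(rhs_c) + tuple(rhs_b)):
        rhs_sub[r_ax] = lhs_sub[l_ax]
    batch_sub = [lhs_sub[a] for a in lhs_b]
    lhs_free = [lhs_sub[a] for a in range(lhs.ndim) if a not in lhs_c and a not in lhs_b]
    rhs_free = [rhs_sub[a] for a in range(rhs.ndim) if a not in rhs_c and a not in rhs_b]
    spec = '{},{}->{}'.format(''.join(lhs_sub), ''.join(rhs_sub),
                              ''.join(batch_sub + lhs_free + rhs_free))
    return np.einsum(spec, lhs, rhs)


a = np.arange(6.0).reshape(2, 3)
b = np.arange(12.0).reshape(3, 4)
# Plain matrix multiplication: contract lhs axis 1 with rhs axis 0.
print(np.allclose(dot_general(a, b, ((1,), (0,))), a @ b))  # True

# Batched matrix multiplication: axis 0 is a batch axis on both sides.
x = np.random.default_rng(0).normal(size=(5, 2, 3))
y = np.random.default_rng(1).normal(size=(5, 3, 4))
print(dot_general(x, y, ((2,), (1,)), ((0,), (0,))).shape)  # (5, 2, 4)
```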
  • Not able to install trax on local computer (only cpu)

    Description

    Hi, I'm not able to install trax on my laptop. I am getting the following error:

    ERROR: Could not find a version that satisfies the requirement jaxlib (from trax) (from versions: none)
    ERROR: No matching distribution found for jaxlib (from trax)

    I tried to install jaxlib directly, but that fails too:

    (base) C:\Users\nb185041>pip install --upgrade jaxlib --user
    ERROR: Could not find a version that satisfies the requirement jaxlib (from versions: none)
    ERROR: No matching distribution found for jaxlib ...

    Environment information

    OS: Windows 10

    For bugs: reproduction and error logs

    Error logs:

    (base) C:\Users\nb185041>pip install trax --upgrade --user
    Collecting trax
      Using cached trax-1.3.4-py2.py3-none-any.whl (366 kB)
    Requirement already satisfied, skipping upgrade: gym in c:\programdata\anaconda3\lib\site-packages (from trax) (0.17.2)
    ERROR: Could not find a version that satisfies the requirement jaxlib (from trax) (from versions: none)
    ERROR: No matching distribution found for jaxlib (from trax)
    
    opened by nag0811 8
  • Tensor2Tensor Transformer is not Trax Transformer

    Description

    Hello. I've been playing around with both T2T and Trax libraries for a while. Since Trax has several bugs during inference, I've decided to switch to T2T. However, it seems to me that Transformer in Tensor2Tensor is not the same as in Trax.

    In Tensor2Tensor I create my Transformer model this way:

    hparams_my = {
        'batch_size': 128,
        'batch_shuffle_size': 128,
        'use_fixed_batch_size': True,
        'num_hidden_layers': 1,
        'max_input_seq_length': 252,
        'max_target_seq_length': 252,
        'max_length': 252,
        'symbol_modality_num_shards': 1,
        'filter_size': 2048,
        'dropout': 0.1
    }
    

    In Trax:

    from trax import layers as tl
    from trax.models import Transformer

    model = Transformer(input_vocab_size=127,
                        output_vocab_size=127,
                        d_model=512,
                        d_ff=2048,
                        n_encoder_layers=1,
                        n_decoder_layers=1,
                        n_heads=8,
                        dropout=0.1,
                        max_len=2048,
                        mode='train',
                        ff_activation=tl.Relu)
    

    After I run training with T2T, I get this message (printed twice):

    INFO:tensorflow:Trainable Variables Total size: 7433728
    INFO:tensorflow:Trainable Variables Total size: 7433728
    

    Whereas in Trax, after I call trainer.print_n_weights(), I get:

    Step      0: Total number of trainable weights: 7614591
    

    I would like to note that when I train my Transformer model in Trax, it converges almost immediately (given the nature of the task: simple sequence copying with small changes), while with T2T the loss stays around 3-4 and never converges.

    Could anybody tell me what I have to do? This seems like a common problem with T2T Transformer convergence, but I want to emphasise that the Trax Transformer is a different implementation.

    opened by DevKretov 7
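The two reported counts differ by 180,863 parameters, and conventions that vary between implementations can shift a total by amounts of that order. The sketch below counts the parameters of a standard encoder-decoder Transformer under two switchable assumptions: whether the attention projections have biases, and whether one matrix is shared between the input embedding, target embedding and output softmax. The formula and flag names are our own illustration. It does not reproduce either reported number exactly, because details such as vocabulary padding and each library's exact layer layout are not given in the issue.

```python
def attention_params(d_model, use_bias):
    """Query, key, value and output projections, each d_model x d_model."""
    return 4 * (d_model * d_model + (d_model if use_bias else 0))


def feed_forward_params(d_model, d_ff):
    """Two dense layers with biases: d_model -> d_ff -> d_model."""
    return (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)


def layer_norm_params(d_model):
    """One scale and one bias vector."""
    return 2 * d_model


def transformer_params(vocab_size, d_model, d_ff, n_encoder_layers, n_decoder_layers,
                       attention_bias=True, share_embeddings_and_softmax=False):
    # Encoder layer: self-attention + feed-forward, each with a layer norm.
    encoder_layer = (attention_params(d_model, attention_bias)
                     + feed_forward_params(d_model, d_ff)
                     + 2 * layer_norm_params(d_model))
    # Decoder layer: self-attention + encoder-decoder attention + feed-forward.
    decoder_layer = (2 * attention_params(d_model, attention_bias)
                     + feed_forward_params(d_model, d_ff)
                     + 3 * layer_norm_params(d_model))
    # Both stacks, plus a final layer norm after each stack.
    total = (n_encoder_layers * encoder_layer + n_decoder_layers * decoder_layer
             + 2 * layer_norm_params(d_model))
    if share_embeddings_and_softmax:
        total += vocab_size * d_model  # One matrix reused for all three roles.
    else:
        # Separate input and target embeddings plus an output dense layer with bias.
        total += 2 * vocab_size * d_model + (d_model * vocab_size + vocab_size)
    return total


separate = transformer_params(127, 512, 2048, 1, 1)
shared = transformer_params(127, 512, 2048, 1, 1,
                            attention_bias=False, share_embeddings_and_softmax=True)
print(separate, shared, separate - shared)  # 7553663 7417344 136319
```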
  • beam_search.Search() in a single-GPU environment

    Description

    Training a Transformer converges.

    Beam search then fails, though: when n_devices == 1, some reshapes in decode() crash.
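The crash in the traceback below comes from PaddingMask, whose body (quoted in the traceback) is `np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))`. That reshape preserves the element count only when `x` is 2-D, `(batch, length)`; here the layer received `(1, 2, 30)`, i.e. 60 elements that cannot fill a `(1, 1, 1, 30)` array. The NumPy reproduction below copies that one line into a hypothetical `padding_mask` helper with pad ID 0 (an assumption). It illustrates the shape requirement but does not diagnose why beam search produced the extra axis.

```python
import numpy as np


def padding_mask(x, pad=0):
    """Same reshape as the PaddingMask line quoted in the traceback."""
    return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))


ok = np.array([[5, 7, 0, 0]])  # (batch=1, length=4)
print(padding_mask(ok).shape)  # (1, 1, 1, 4)

bad = np.zeros((1, 2, 30), dtype=np.int32)  # The shape from the error message.
try:
    padding_mask(bad)
except ValueError as err:
    # NumPy raises ValueError here; JAX raised a TypeError with a similar message.
    print('reshape failed:', err)
```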

    Environment information

    OS: Ubuntu 18.04
    
    CUDA 10.1
    1 GPU environment
    
    $ pip freeze | grep tensor
    mesh-tensorflow==0.1.11
    tensor2tensor==1.15.4
    tensorboard==2.1.0
    tensorflow==2.1.0
    tensorflow-datasets==2.1.0
    tensorflow-estimator==2.1.0
    tensorflow-gan==2.0.0
    tensorflow-hub==0.7.0
    tensorflow-metadata==0.21.1
    tensorflow-probability==0.7.0
    
    
    $ pip freeze | grep jax
    jax==0.1.59
    jaxlib==0.1.39
    
    
    $ python -V
    Python 3.6.9
    

    For bugs: reproduction and error logs

    Traceback (most recent call last):
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 442, in pure_fn
        x, weights=weights, state=state, rng=rng)
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 220, in forward_with_state
        return self.forward(inputs, weights), state
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 580, in _forward
        raw_output = raw_fn(x, weights=weights, **self._kwargs)  # pylint: disable=protected-access
      File "/root/.local/lib/python3.6/site-packages/trax/layers/attention.py", line 51, in PaddingMask
        return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))
      File "/root/.local/lib/python3.6/site-packages/jax/numpy/lax_numpy.py", line 921, in reshape
        return a.reshape(newshape, order=order)  # forward to method for ndarrays
      File "/root/.local/lib/python3.6/site-packages/jax/numpy/lax_numpy.py", line 959, in _reshape_method
        return _reshape(a, newshape, order=order)
      File "/root/.local/lib/python3.6/site-packages/jax/numpy/lax_numpy.py", line 938, in _reshape
        return lax.reshape(a, computed_newshape, None)
      File "/root/.local/lib/python3.6/site-packages/jax/lax/lax.py", line 640, in reshape
        old_sizes=onp.shape(operand))
      File "/root/.local/lib/python3.6/site-packages/jax/core.py", line 182, in bind
        out_tracer = top_trace.process_primitive(self, tracers, kwargs)
      File "/root/.local/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 98, in process_primitive
        return self.default_process_primitive(primitive, tracers, params)
      File "/root/.local/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 106, in default_process_primitive
        out_aval = primitive.abstract_eval(*avals, **params)
      File "/root/.local/lib/python3.6/site-packages/jax/lax/lax.py", line 1523, in standard_abstract_eval
        return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
      File "/root/.local/lib/python3.6/site-packages/jax/lax/lax.py", line 2582, in _reshape_shape_rule
        raise TypeError(msg.format(new_sizes, onp.shape(operand)))
    TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 30) for shape (1, 2, 30).

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 480, in _forward_abstract
        input_signature, weight_signature, self.state, rng)
      File "/root/.local/lib/python3.6/site-packages/trax/math/jax.py", line 175, in shape_fun
        jax_shapes = jax.eval_shape(f, *args, **kwargs)
      File "/root/.local/lib/python3.6/site-packages/jax/api.py", line 2042, in eval_shape
        out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))
      File "/root/.local/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 273, in abstract_eval_fun
        instantiate=True)
      File "/root/.local/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 354, in trace_to_jaxpr
        jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
      File "/root/.local/lib/python3.6/site-packages/jax/linear_util.py", line 149, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/root/.local/lib/python3.6/site-packages/jax/linear_util.py", line 149, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 477, in call_on_input
        return self.forward_with_state(x, weights=weights, state=state, rng=rng)
      File "/root/.local/lib/python3.6/site-packages/trax/layers/combinators.py", line 238, in forward_with_state
        sub_outputs, sub_state = layer.pure_fn(x, w, s, r)
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 451, in pure_fn
        self._caller, signature(x), trace)
    trax.layers.base.LayerError: Exception passing through layer PaddingMask (in pure_fn):
      layer created in file [...]/trax/models/transformer.py, line 286
      layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}

    File [...]/trax/layers/base.py, line 220, in forward_with_state return self.forward(inputs, weights), state

    File [...]/trax/layers/base.py, line 580, in _forward raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access

    File [...]/trax/layers/attention.py, line 51, in PaddingMask return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))

    File [...]/jax/numpy/lax_numpy.py, line 921, in reshape return a.reshape(newshape, order=order) # forward to method for ndarrays

    File [...]/jax/numpy/lax_numpy.py, line 959, in _reshape_method return _reshape(a, newshape, order=order)

    File [...]/jax/numpy/lax_numpy.py, line 938, in _reshape return lax.reshape(a, computed_newshape, None)

    File [...]/jax/lax/lax.py, line 640, in reshape old_sizes=onp.shape(operand))

    File [...]/site-packages/jax/core.py, line 182, in bind out_tracer = top_trace.process_primitive(self, tracers, kwargs)

    File [...]/jax/interpreters/partial_eval.py, line 98, in process_primitive return self.default_process_primitive(primitive, tracers, params)

    File [...]/jax/interpreters/partial_eval.py, line 106, in default_process_primitive out_aval = primitive.abstract_eval(*avals, **params)

    File [...]/jax/lax/lax.py, line 1523, in standard_abstract_eval return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))

    File [...]/jax/lax/lax.py, line 2582, in _reshape_shape_rule raise TypeError(msg.format(new_sizes, onp.shape(operand)))

    TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 30) for shape (1, 2, 30).

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 310, in init
        weights, state = self.new_weights_and_state(input_signature)
      File "/root/.local/lib/python3.6/site-packages/trax/layers/combinators.py", line 92, in new_weights_and_state
        outputs, _ = sublayer._forward_abstract(inputs)
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 485, in _forward_abstract
        trace)
    trax.layers.base.LayerError: Exception passing through layer Parallel (in _forward_abstract):
      layer created in file [...]/trax/layers/combinators.py, line 468
      layer input shapes: (ShapeDtype{shape:(1, 2, 30), dtype:int32}, ShapeDtype{shape:(1, 2, 30), dtype:int32})

    File [...]/trax/math/jax.py, line 175, in shape_fun jax_shapes = jax.eval_shape(f, *args, **kwargs)

    File [...]/site-packages/jax/api.py, line 2042, in eval_shape out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))

    File [...]/jax/interpreters/partial_eval.py, line 273, in abstract_eval_fun instantiate=True)

    File [...]/jax/interpreters/partial_eval.py, line 354, in trace_to_jaxpr jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)

    File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped ans = self.f(*args, **dict(self.params, **kwargs))

    File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped ans = self.f(*args, **dict(self.params, **kwargs))

    File [...]/trax/layers/base.py, line 477, in call_on_input return self.forward_with_state(x, weights=weights, state=state, rng=rng)

    File [...]/trax/layers/combinators.py, line 238, in forward_with_state sub_outputs, sub_state = layer.pure_fn(x, w, s, r)

    LayerError: Exception passing through layer PaddingMask (in pure_fn): layer created in file [...]/trax/models/transformer.py, line 286 layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}

    File [...]/trax/layers/base.py, line 220, in forward_with_state return self.forward(inputs, weights), state

    File [...]/trax/layers/base.py, line 580, in _forward raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access

    File [...]/trax/layers/attention.py, line 51, in PaddingMask return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))

    File [...]/jax/numpy/lax_numpy.py, line 921, in reshape return a.reshape(newshape, order=order) # forward to method for ndarrays

    File [...]/jax/numpy/lax_numpy.py, line 959, in _reshape_method return _reshape(a, newshape, order=order)

    File [...]/jax/numpy/lax_numpy.py, line 938, in _reshape return lax.reshape(a, computed_newshape, None)

    File [...]/jax/lax/lax.py, line 640, in reshape old_sizes=onp.shape(operand))

    File [...]/site-packages/jax/core.py, line 182, in bind out_tracer = top_trace.process_primitive(self, tracers, kwargs)

    File [...]/jax/interpreters/partial_eval.py, line 98, in process_primitive return self.default_process_primitive(primitive, tracers, params)

    File [...]/jax/interpreters/partial_eval.py, line 106, in default_process_primitive out_aval = primitive.abstract_eval(*avals, **params)

    File [...]/jax/lax/lax.py, line 1523, in standard_abstract_eval return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))

    File [...]/jax/lax/lax.py, line 2582, in _reshape_shape_rule raise TypeError(msg.format(new_sizes, onp.shape(operand)))

    TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 30) for shape (1, 2, 30).

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 310, in init
        weights, state = self.new_weights_and_state(input_signature)
      File "/root/.local/lib/python3.6/site-packages/trax/layers/combinators.py", line 91, in new_weights_and_state
        weights_or_empty, state = sublayer.init(inputs)
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 321, in init
        input_signature, trace)
    trax.layers.base.LayerError: Exception passing through layer Serial (in init): layer created in file [...]/trax/layers/combinators.py, line 470 layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}

    File [...]/trax/layers/combinators.py, line 92, in new_weights_and_state outputs, _ = sublayer._forward_abstract(inputs)

    LayerError: Exception passing through layer Parallel (in _forward_abstract): layer created in file [...]/trax/layers/combinators.py, line 468 layer input shapes: (ShapeDtype{shape:(1, 2, 30), dtype:int32}, ShapeDtype{shape:(1, 2, 30), dtype:int32})

    File [...]/trax/math/jax.py, line 175, in shape_fun jax_shapes = jax.eval_shape(f, *args, **kwargs)

    File [...]/site-packages/jax/api.py, line 2042, in eval_shape out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))

    File [...]/jax/interpreters/partial_eval.py, line 273, in abstract_eval_fun instantiate=True)

    File [...]/jax/interpreters/partial_eval.py, line 354, in trace_to_jaxpr jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)

    File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped ans = self.f(*args, **dict(self.params, **kwargs))

    File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped ans = self.f(*args, **dict(self.params, **kwargs))

    File [...]/trax/layers/base.py, line 477, in call_on_input return self.forward_with_state(x, weights=weights, state=state, rng=rng)

    File [...]/trax/layers/combinators.py, line 238, in forward_with_state sub_outputs, sub_state = layer.pure_fn(x, w, s, r)

    LayerError: Exception passing through layer PaddingMask (in pure_fn): layer created in file [...]/trax/models/transformer.py, line 286 layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}

    File [...]/trax/layers/base.py, line 220, in forward_with_state return self.forward(inputs, weights), state

    File [...]/trax/layers/base.py, line 580, in _forward raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access

    File [...]/trax/layers/attention.py, line 51, in PaddingMask return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))

    File [...]/jax/numpy/lax_numpy.py, line 921, in reshape return a.reshape(newshape, order=order) # forward to method for ndarrays

    File [...]/jax/numpy/lax_numpy.py, line 959, in _reshape_method return _reshape(a, newshape, order=order)

    File [...]/jax/numpy/lax_numpy.py, line 938, in _reshape return lax.reshape(a, computed_newshape, None)

    File [...]/jax/lax/lax.py, line 640, in reshape old_sizes=onp.shape(operand))

    File [...]/site-packages/jax/core.py, line 182, in bind out_tracer = top_trace.process_primitive(self, tracers, kwargs)

    File [...]/jax/interpreters/partial_eval.py, line 98, in process_primitive return self.default_process_primitive(primitive, tracers, params)

    File [...]/jax/interpreters/partial_eval.py, line 106, in default_process_primitive out_aval = primitive.abstract_eval(*avals, **params)

    File [...]/jax/lax/lax.py, line 1523, in standard_abstract_eval return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))

    File [...]/jax/lax/lax.py, line 2582, in _reshape_shape_rule raise TypeError(msg.format(new_sizes, onp.shape(operand)))

    TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 30) for shape (1, 2, 30).

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "math_trax.py", line 565, in
        seqs, scores = beam_decoder.decode(inputs=batch, batch_size=iBatch_size)#, )targets_prefix=prefix_for_bs,
      File "/root/.local/lib/python3.6/site-packages/trax/models/beam_search.py", line 602, in decode
        dummy=np.zeros(n_devices))
      File "/root/.local/lib/python3.6/site-packages/jax/api.py", line 146, in f_jitted
        name=flat_fun.name)
      File "/root/.local/lib/python3.6/site-packages/jax/core.py", line 642, in call_bind
        outs = primitive.impl(f, *args, **params)
      File "/root/.local/lib/python3.6/site-packages/jax/interpreters/xla.py", line 448, in _xla_call_impl
        compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
      File "/root/.local/lib/python3.6/site-packages/jax/linear_util.py", line 220, in memoized_fun
        ans = call(fun, *args)
      File "/root/.local/lib/python3.6/site-packages/jax/interpreters/xla.py", line 465, in _xla_callable
        jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
      File "/root/.local/lib/python3.6/site-packages/jax/linear_util.py", line 149, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/root/.local/lib/python3.6/site-packages/trax/models/beam_search.py", line 535, in _unreplicated_beam_search
        self._get_initial_state(inputs, targets_prefix, batch_size),
      File "/root/.local/lib/python3.6/site-packages/trax/models/beam_search.py", line 490, in _get_initial_state
        _, initial_state = self.model(mode='predict').init(signature)
      File "/root/.local/lib/python3.6/site-packages/trax/layers/base.py", line 321, in init
        input_signature, trace)
    trax.layers.base.LayerError: Exception passing through layer Serial (in init): layer created in file [...]/trax/models/transformer.py, line 301 layer input shapes: (ShapeDtype{shape:(1, 2, 30), dtype:int32}, ShapeDtype{shape:(2, 1), dtype:int32})

    File [...]/trax/layers/combinators.py, line 91, in new_weights_and_state weights_or_empty, state = sublayer.init(inputs)

    LayerError: Exception passing through layer Serial (in init): layer created in file [...]/trax/layers/combinators.py, line 470 layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}

    File [...]/trax/layers/combinators.py, line 92, in new_weights_and_state outputs, _ = sublayer._forward_abstract(inputs)

    LayerError: Exception passing through layer Parallel (in _forward_abstract): layer created in file [...]/trax/layers/combinators.py, line 468 layer input shapes: (ShapeDtype{shape:(1, 2, 30), dtype:int32}, ShapeDtype{shape:(1, 2, 30), dtype:int32})

    File [...]/trax/math/jax.py, line 175, in shape_fun jax_shapes = jax.eval_shape(f, *args, **kwargs)

    File [...]/site-packages/jax/api.py, line 2042, in eval_shape out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))

    File [...]/jax/interpreters/partial_eval.py, line 273, in abstract_eval_fun instantiate=True)

    File [...]/jax/interpreters/partial_eval.py, line 354, in trace_to_jaxpr jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)

    File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped ans = self.f(*args, **dict(self.params, **kwargs))

    File [...]/site-packages/jax/linear_util.py, line 149, in call_wrapped ans = self.f(*args, **dict(self.params, **kwargs))

    File [...]/trax/layers/base.py, line 477, in call_on_input return self.forward_with_state(x, weights=weights, state=state, rng=rng)

    File [...]/trax/layers/combinators.py, line 238, in forward_with_state sub_outputs, sub_state = layer.pure_fn(x, w, s, r)

    LayerError: Exception passing through layer PaddingMask (in pure_fn): layer created in file [...]/trax/models/transformer.py, line 286 layer input shapes: ShapeDtype{shape:(1, 2, 30), dtype:int32}

    File [...]/trax/layers/base.py, line 220, in forward_with_state return self.forward(inputs, weights), state

    File [...]/trax/layers/base.py, line 580, in _forward raw_output = raw_fn(x, weights=weights, **self._kwargs) # pylint: disable=protected-access

    File [...]/trax/layers/attention.py, line 51, in PaddingMask return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))

    File [...]/jax/numpy/lax_numpy.py, line 921, in reshape return a.reshape(newshape, order=order) # forward to method for ndarrays

    File [...]/jax/numpy/lax_numpy.py, line 959, in _reshape_method return _reshape(a, newshape, order=order)

    File [...]/jax/numpy/lax_numpy.py, line 938, in _reshape return lax.reshape(a, computed_newshape, None)

    File [...]/jax/lax/lax.py, line 640, in reshape old_sizes=onp.shape(operand))

    File [...]/site-packages/jax/core.py, line 182, in bind out_tracer = top_trace.process_primitive(self, tracers, kwargs)

    File [...]/jax/interpreters/partial_eval.py, line 98, in process_primitive return self.default_process_primitive(primitive, tracers, params)

    File [...]/jax/interpreters/partial_eval.py, line 106, in default_process_primitive out_aval = primitive.abstract_eval(*avals, **params)

    File [...]/jax/lax/lax.py, line 1523, in standard_abstract_eval return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))

    File [...]/jax/lax/lax.py, line 2582, in _reshape_shape_rule raise TypeError(msg.format(new_sizes, onp.shape(operand)))

    TypeError: reshape total size must be unchanged, got new_sizes (1, 1, 1, 30) for shape (1, 2, 30).

    REMARK: 1 is n_devices, 2 is batch size, 30 is max_len
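
    The repeated TypeError comes down to element counts. With the leading device axis described in the REMARK, the input holds 1 x 2 x 30 = 60 values. The reshape target (x.shape[0], 1, 1, x.shape[-1]) = (1, 1, 1, 30) holds only 30. The NumPy sketch below reproduces that arithmetic. `padding_mask` is a hypothetical stand-in that copies only the reshape line from trax/layers/attention.py as printed in the traceback; it is not the library's layer. The sketch assumes the pad id is 0. Note that NumPy raises ValueError where jax.numpy raised TypeError.

```python
import numpy as np


def padding_mask(x, pad=0):
    """Hypothetical stand-in copying only the reshape shown in the traceback.

    Expects x of shape (batch, length) and returns a boolean mask of shape
    (batch, 1, 1, length) that is True wherever x is not the pad id.
    """
    return np.reshape(x != pad, (x.shape[0], 1, 1, x.shape[-1]))


max_len = 30

# Layout the reshape expects: (batch_size, max_len) = (2, 30), 60 elements.
per_device = np.ones((2, max_len), dtype=np.int32)
mask = padding_mask(per_device)
print(mask.shape)  # (2, 1, 1, 30)

# Layout from the traceback: (n_devices, batch_size, max_len) = (1, 2, 30).
# The target shape becomes (1, 1, 1, 30), i.e. only 30 elements, so it fails.
with_device_axis = np.ones((1, 2, max_len), dtype=np.int32)
reshape_failed = False
try:
    padding_mask(with_device_axis)
except ValueError:  # NumPy's error type; jax.numpy raised TypeError above.
    reshape_failed = True
print(reshape_failed)  # True
```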

    # Steps to reproduce:
    I tried to force your machine_translation.ipynb in Colab to use the GPU but didn't succeed. Perhaps the fastest check on your side is to see what happens with only 1 GPU, since the Colab itself runs smoothly (on a TPU).
    
    # Error logs:
    ...
    
    opened by friesel 7
  • [Feature request] Examples for advanced machine learning tasks

    Since Trax is a successor to tensor2tensor (according to the release notes of tensor2tensor v1.15.0), it would be helpful if you could provide examples for more advanced machine learning tasks. An outstanding feature of tensor2tensor is its numerous (and useful) examples, which Trax currently lacks. Such examples would be especially helpful for machine learning tasks with complex input transformations, such as speech recognition or translation with subword encodings.

    documentation 
    opened by cantwbr 7
  • Update core.py -

    Added this note to the LogSoftmax function: Note that the implementation actually computes x - LogSumExp(x), which is mathematically equal to LogSoftmax(x).
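
    The identity in the note follows from log(exp(x_i) / sum_j exp(x_j)) = x_i - log(sum_j exp(x_j)). The NumPy sketch below checks it numerically. It is not Trax's code; `log_sum_exp` and `log_softmax` are illustrative names. It also shows why the x - LogSumExp(x) form is useful in practice when LogSumExp subtracts the maximum before exponentiating. That max-shift is an assumption of this sketch, not a description of Trax's implementation. With the shift, the result stays finite for large logits, where the textbook formula overflows.

```python
import numpy as np


def log_sum_exp(x, axis=-1):
    # Shift by the max before exponentiating so exp() cannot overflow.
    m = np.max(x, axis=axis, keepdims=True)
    return m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))


def log_softmax(x, axis=-1):
    # The form described in the note: x - LogSumExp(x).
    return x - log_sum_exp(x, axis=axis)


x = np.array([1.0, 2.0, 3.0])
# Textbook definition: log(exp(x_i) / sum_j exp(x_j)).
naive = np.log(np.exp(x) / np.sum(np.exp(x)))
print(np.allclose(log_softmax(x), naive))  # True

# Large logits: the naive form overflows to inf/inf = nan; the shifted form does not.
x_big = x + 1000.0
with np.errstate(over='ignore', invalid='ignore'):
    naive_big = np.log(np.exp(x_big) / np.sum(np.exp(x_big)))
print(np.isnan(naive_big).all())  # True
print(np.allclose(log_softmax(x_big), log_softmax(x)))  # True
```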

    cla: yes ready to pull 
    opened by MichalRyszardWojcik 6
  • Error while retrieving transformer weights in intro notebook

    Description

    Initializing weights from the gcloud source leads to a NotFoundError ...

    Environment information

    OS: Pop!_OS 20.04 (Based on Ubuntu 20.04)
    
    $ pip freeze | grep trax
    trax==1.3.4
    
    $ pip freeze | grep tensor
    mesh-tensorflow==0.1.16
    tensor2tensor==1.15.7
    tensorboard==2.3.0
    tensorboard-plugin-wit==1.6.0.post3
    tensorflow==2.3.0
    tensorflow-addons==0.10.0
    tensorflow-datasets==3.2.1
    tensorflow-estimator==2.3.0
    tensorflow-gan==2.0.0
    tensorflow-gpu==2.3.0
    tensorflow-hub==0.8.0
    tensorflow-metadata==0.22.2
    tensorflow-probability==0.7.0
    tensorflow-text==2.3.0
    
    $ pip freeze | grep jax
    jax==0.1.74
    jaxlib==0.1.52
    
    $ python -V
    Python 3.8.2
    

    For bugs: reproduction and error logs

    # Steps to reproduce:
    ...
    model = trax.models.Transformer(
        input_vocab_size=33300, 
        d_model=512, d_ff=2048, 
        n_heads=8, n_encoder_layers=6, n_decoder_layers=6, 
        max_len=2048, mode='predict') 
    # Initialize using pre-trained weights. 
    model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',  weights_only=True)
    
    # Error logs:
    ...
    NotFoundError: Error executing an HTTP request: HTTP response code 404 with body '<?xml version='1.0' encoding='UTF-8'?><Error><Code>NoSuchKey</Code><Message>The specified key does not exist.</Message><Details>No such object: trax-ml/models/translation/end_wmt32k.pkl.gz</Details></Error>'
    	 when reading gs://trax-ml/models/translation/end_wmt32k.pkl.gz
    
    
    opened by Cquential 6
  • [Bug] loss_fn argument for Trainer must not be a function since 1.2.4

    Description

    As shown in previous examples, it should be possible to pass loss_fn as a function like this:

    trainer = trax.supervised.Trainer(
        model=eval(train_model.selector),
        loss_fn=trax.layers.CrossEntropyLoss,
        optimizer=trax.optimizers.Adam,
        lr_schedule=trax.lr.MultifactorSchedule,
        inputs=trax.supervised.inputs.Inputs(train_stream),
        output_dir=output_dir,
    )
    

    However, since the latest upgrade to 1.2.4, this no longer works.

    In the trainer_lib the loss_fn gets passed to a Serial constructor:

    https://github.com/google/trax/blob/93f2bd47f5f17aacafe3f312ae56ce6f98d93ee7/trax/supervised/trainer_lib.py#L130

    This in turn runs _ensure_flat in its constructor:

    https://github.com/google/trax/blob/5b1565910a53d0d1175f647cc67db48e334d8f90/trax/layers/combinators.py#L47

    However, all objects in layers have to be of type base.Layer:

    def _ensure_flat(layers):
      """Ensures that layers is a single flat list of Layer instances."""
      if len(layers) == 1 and layers[0] is None:
        layers = ()
      else:
        layers = _deep_flatten(layers)
      for obj in layers:
        if not isinstance(obj, base.Layer):
          raise ValueError(
              f'Found nonlayer object ({obj}) in layers: {layers}')
      return layers
    

    See

    https://github.com/google/trax/blob/5b1565910a53d0d1175f647cc67db48e334d8f90/trax/layers/combinators.py#L775

    Thus we'll see an exception:

    ValueError: Found nonlayer object (<function CrossEntropyLoss at 0x7fc5be59a9e0>) in layers:
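
    The failure mode can be reproduced without Trax. A check like the _ensure_flat excerpt rejects a function that builds a layer, but it accepts the layer that function returns. Everything below (`Layer`, `ensure_flat`, `CrossEntropyLoss`) is a hypothetical, self-contained stand-in, not Trax code, and the flattening step is omitted. The ValueError message shows that trax.layers.CrossEntropyLoss is a function in the affected version. If it is a zero-argument factory, then passing loss_fn=trax.layers.CrossEntropyLoss() instead of the bare function would satisfy the check. That is an assumption and has not been verified against 1.2.4.

```python
class Layer:
    """Hypothetical stand-in for trax.layers.base.Layer (illustration only)."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f'Layer({self.name})'


def ensure_flat(layers):
    """Same isinstance check as the quoted _ensure_flat (flattening omitted)."""
    for obj in layers:
        if not isinstance(obj, Layer):
            raise ValueError(f'Found nonlayer object ({obj}) in layers: {layers}')
    return layers


def CrossEntropyLoss():
    """Hypothetical layer factory: a plain function that returns a Layer."""
    return Layer('CrossEntropyLoss')


# Passing the factory function itself is rejected, as in the report.
try:
    ensure_flat([CrossEntropyLoss])
    function_rejected = False
except ValueError:
    function_rejected = True

# Passing the result of calling the factory (a Layer instance) is accepted.
accepted = ensure_flat([CrossEntropyLoss()])
print(function_rejected, accepted)  # True [Layer(CrossEntropyLoss)]
```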
    
    opened by stefan-falk 6
  • ImportError: cannot import name 'MergeHeads' from 'trax.layers.attention'

    Description

    ImportError thrown after importing libraries ...

    Environment information

    trax 1.4.1

    OS: Ubuntu 
    
    $ pip freeze | grep trax
    trax                         1.4.1
    
    $ pip freeze | grep tensor
    mesh-tensorflow==0.1.21
    tensor2tensor==1.15.7
    tensorboard==2.11.0
    tensorboard-data-server==0.6.1
    tensorboard-plugin-wit==1.8.1
    tensorflow==2.11.0
    tensorflow-addons==0.18.0
    tensorflow-datasets==4.7.0
    tensorflow-estimator==2.11.0
    tensorflow-gan==2.1.0
    tensorflow-hub==0.12.0
    tensorflow-io-gcs-filesystem==0.28.0
    tensorflow-metadata==1.12.0
    tensorflow-probability==0.7.0
    tensorflow-text==2.11.0
    tensorstore==0.1.28
    
    $ pip freeze | grep jax
    jax==0.3.25
    jaxlib==0.3.25
    
    $ python -V
    Python 3.8.10
    

    For bugs: reproduction and error logs

    # Steps to reproduce:
    !pip install -q -U trax
    
    import numpy as np  # regular ol' numpy
    
    from trax import fastmath
    from trax import layers as tl
    from trax import shapes
    from trax.fastmath import numpy as jnp  # For use in defining new layer types.
    from trax.shapes import ShapeDtype
    from trax.shapes import signature
    # Error logs:
    ---------------------------------------------------------------------------
    ImportError                               Traceback (most recent call last)
    Cell In[22], line 3
          1 import numpy as np  # regular ol' numpy
    ----> 3 from trax import fastmath
          4 from trax import layers as tl
          5 from trax import shapes
    
    File ~/NovaceneAI/trax_projects/.venv/lib/python3.8/site-packages/trax/__init__.py:18
          1 # coding=utf-8
          2 # Copyright 2021 The Trax Authors.
          3 #
       (...)
         13 # See the License for the specific language governing permissions and
         14 # limitations under the License.
         16 """Trax top level import."""
    ---> 18 from trax import data
         19 from trax import fastmath
         20 from trax import layers
    
    File ~/NovaceneAI/trax_projects/.venv/lib/python3.8/site-packages/trax/data/__init__.py:70
         67 from trax.data.inputs import UnBatch
         68 from trax.data.inputs import UniformlySeek
    ---> 70 from trax.data.tf_inputs import add_eos_to_output_features
         71 from trax.data.tf_inputs import BertGlueEvalStream
    ...
         35 from trax.layers.attention import SplitIntoHeads
         38 # Layers are always CamelCase, but functions in general are snake_case
         39 # pylint: disable=invalid-name
    
    ImportError: cannot import name 'MergeHeads' from 'trax.layers.attention'
    
    opened by jpontalba 0
  • Machine Translation Reformer model.pkl for trax 1.4.1?

    Description

    I am trying to port the Reformer machine_translation code (https://github.com/google/trax/blob/master/trax/models/reformer/machine_translation.ipynb), which was written for trax 1.2.9, so that it works with trax 1.4.1. However, my program crashes when I try to load model.pkl, with this error: "ModuleNotFoundError: No module named 'trax.history'". My understanding is that trax.history in 1.2.9 has been moved to trax.supervised.history in 1.4.1, so I believe I need a new model.pkl for 1.4.1. Where can I download the new model.pkl? It would also be great if I could download the new config.gin. Thanks a lot in advance. ...

    Environment information

    OS: Ubuntu 20.04
    
    $ pip freeze | grep trax
    trax==1.4.1
    
    $ pip freeze | grep tensor
    tensorboard @ file:///home/conda/feedstock_root/build_artifacts/tensorboard_1664238338171/work/tensorboard-2.10.1-py3-none-any.whl
    tensorboard-data-server @ file:///home/conda/feedstock_root/build_artifacts/tensorboard-data-server_1649932776625/work/tensorboard_data_server-0.6.0-py3-none-manylinux2010_x86_64.whl
    tensorboard-plugin-wit @ file:///home/conda/feedstock_root/build_artifacts/tensorboard-plugin-wit_1641458951060/work/tensorboard_plugin_wit-1.8.1-py3-none-any.whl
    tensorflow==2.10.0
    tensorflow-datasets==4.7.0
    tensorflow-estimator @ file:///home/conda/feedstock_root/build_artifacts/tensorflow-split_1663957899180/work/tensorflow-estimator/wheel_dir/tensorflow_estimator-2.10.0-py2.py3-none-any.whl
    tensorflow-hub==0.12.0
    tensorflow-io-gcs-filesystem==0.27.0
    tensorflow-metadata==1.10.0
    tensorflow-text==2.10.0
    
    $ pip freeze | grep jax
    jax @ file:///home/conda/feedstock_root/build_artifacts/jax_1665610009116/work
    jaxlib==0.3.22
    
    $ python -V
    Python 3.8.10
    

    For bugs: reproduction and error logs

    # Steps to reproduce:
    ...
    

    import sys
    import gin
    import os
    import pickle
    import jax
    import trax
    import numpy as np
    import jax.numpy as jnp
    import sacrebleu
    from trax.data.text_encoder import SubwordTextEncoder
    from tensorflow.io.gfile import GFile

    Load the source text and reference translations into Python

    refs = []
    for lineno, line in enumerate(sacrebleu.smart_open('wmt14-en-de.ref'), 1):
      if line.endswith('\n'):
        line = line[:-1]
      refs.append(line)
    srcs = []
    for lineno, line in enumerate(sacrebleu.smart_open('wmt14-en-de.src'), 1):
      if line.endswith('\n'):
        line = line[:-1]
      srcs.append(line)

    Set up our sub-word tokenizer

    tokenizer = SubwordTextEncoder( 'gs://trax-ml/reformer/mt/vocab.translate_ende_wmt32k.32768.subwords')

    Encode source sentences using the tokenizer

    input_ids = np.zeros((len(srcs), 128), dtype=jnp.int64)
    for i, x in enumerate(srcs):
      x = tokenizer.encode(x)
      assert len(x) <= 127
      input_ids[i, :len(x)] = x
      input_ids[i, len(x)] = 1

    We'll be using a pre-trained reversible transformer-base model.

    First, load the config (which sets all needed hyperparameters).

    !gsutil cp gs://trax-ml/reformer/mt/config.gin ./config.gin

    gin.parse_config_file('./config.gin')

    Now we load the pre-trained model weights.

    with GFile('gs://trax-ml/reformer/mt/model.pkl', 'rb') as f:
      model_weights = pickle.load(f)['weights']

    # Error logs:
    ...
    

    Traceback (most recent call last):
      File "reformer.py", line 65, in
        model_weights = pickle.load(f)['weights']
    ModuleNotFoundError: No module named 'trax.history'
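
    A pickle records each object's class by module path. A checkpoint written while a class lived in trax.history therefore cannot be loaded once that module no longer exists. One general Python workaround is a pickle.Unpickler subclass that overrides find_class to redirect renamed modules. The sketch below demonstrates the mechanism with hypothetical in-memory modules `legacy_history` and `modern_history`. For this checkpoint the mapping would presumably be {'trax.history': 'trax.supervised.history'}, following the reporter's understanding. Two things are not verified: whether the remapped objects are compatible with trax 1.4.1, and whether other classes in the pickle moved as well.

```python
import io
import pickle
import sys
import types

# Hypothetical rename table. For the checkpoint in this issue the entry would
# presumably be {'trax.history': 'trax.supervised.history'} (unverified).
RENAMED_MODULES = {'legacy_history': 'modern_history'}


class RenamingUnpickler(pickle.Unpickler):
    """Unpickler that looks up classes under their new module path."""

    def find_class(self, module, name):
        return super().find_class(RENAMED_MODULES.get(module, module), name)


# --- Self-contained demo: a class that moves from one module to another. ---
class History:
    def __init__(self, values):
        self.values = values


legacy = types.ModuleType('legacy_history')
History.__module__ = 'legacy_history'
legacy.History = History
sys.modules['legacy_history'] = legacy
blob = pickle.dumps(History([1, 2, 3]))  # blob records 'legacy_history'

# Simulate the upgrade: the old module is gone, the class lives elsewhere now.
del sys.modules['legacy_history']
modern = types.ModuleType('modern_history')
History.__module__ = 'modern_history'
modern.History = History
sys.modules['modern_history'] = modern

# A plain load fails the same way as in the report.
try:
    pickle.loads(blob)
    plain_load_failed = False
except ModuleNotFoundError as err:
    plain_load_failed = True
    print(err)  # No module named 'legacy_history'

restored = RenamingUnpickler(io.BytesIO(blob)).load()
print(restored.values)  # [1, 2, 3]
```

    Applied to the snippet above, this would mean replacing pickle.load(f) with RenamingUnpickler(f).load() and using the trax entry in the rename table. That change is untested here.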

    opened by ymcki 0
  • TypeError: float() argument must be a string or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer'

    Description

    Hi, I am trying to follow this tutorial (https://github.com/google/trax/blob/master/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb). Setting the runtime to TPU on Colab worked a couple of days ago, but now it crashes with this error:

    TypeError: float() argument must be a string or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer'

    This happens at this step: training_loop = training.Loop(model,.....

    Environment information

    OS: 
    NAME="Ubuntu"
    VERSION="18.04.6 LTS (Bionic Beaver)"
    ID=ubuntu
    ID_LIKE=debian
    PRETTY_NAME="Ubuntu 18.04.6 LTS"
    VERSION_ID="18.04"
    HOME_URL="https://www.ubuntu.com/"
    SUPPORT_URL="https://help.ubuntu.com/"
    BUG_REPORT_URL="https://bugs.launchpad.net/ubuntu/"
    PRIVACY_POLICY_URL="https://www.ubuntu.com/legal/terms-and-policies/privacy-policy"
    VERSION_CODENAME=bionic
    UBUNTU_CODENAME=bionic
    
    $ pip freeze | grep trax
    trax==1.4.1
    
    
    $ pip freeze | grep tensor
    tensorboard==2.10.0
    tensorboard-data-server==0.6.1
    tensorboard-plugin-wit==1.8.1
    tensorflow==2.10.0
    tensorflow-datasets==4.6.0
    tensorflow-estimator==2.10.0
    tensorflow-gcs-config==2.8.0
    tensorflow-hub==0.12.0
    tensorflow-io-gcs-filesystem==0.26.0
    tensorflow-metadata==1.10.0
    tensorflow-probability==0.16.0
    tensorflow-text==2.10.0
    
    $ pip freeze | grep jax
    jax==0.3.17
    jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.15+cuda11.cudnn805-cp37-none-manylinux2014_x86_64.whl
    
    $ python -V
    Python 3.7.13
    
    

    For bugs: reproduction and error logs

    # Steps to reproduce:
    https://github.com/google/trax/blob/master/trax/examples/NMT_with_Transformers_Reformers_using_Trax.ipynb
    
    ...
    
    # Error logs:
    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    <ipython-input-8-2021642a85f0> in <module>
          9                               train_task,
         10                               eval_tasks=[eval_task],
    ---> 11                               output_dir=output_dir)
    
    16 frames
    /usr/local/lib/python3.7/dist-packages/trax/supervised/training.py in __init__(self, model, tasks, eval_model, eval_tasks, output_dir, checkpoint_at, checkpoint_low_metric, checkpoint_high_metric, permanent_checkpoint_at, eval_at, which_task, n_devices, random_seed, loss_chunk_size, use_memory_efficient_trainer, adasum, callbacks)
        278 
        279     # Create the optimizer for the training loss function.
    --> 280     self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
        281 
        282     # Sync layers weights/state in memory effcient trainer layers.
    
    /usr/local/lib/python3.7/dist-packages/trax/supervised/training.py in <genexpr>(.0)
        278 
        279     # Create the optimizer for the training loss function.
    --> 280     self._trainer_per_task = tuple(self._init_trainer(task) for task in tasks)
        281 
        282     # Sync layers weights/state in memory effcient trainer layers.
    
    /usr/local/lib/python3.7/dist-packages/trax/supervised/training.py in _init_trainer(self, task)
        348         task.optimizer.tree_init(model_in_training.weights)
        349       return optimizers.Trainer(
    --> 350           model_in_training, task.optimizer, adasum=self._adasum)
        351     # In the memory-efficient path, we initialize the model here.
        352     blocks, loss_layer = optimizers.trainer.extract_reversible_blocks(
    
    /usr/local/lib/python3.7/dist-packages/trax/optimizers/trainer.py in __init__(self, model_with_loss, optimizer, n_devices, adasum)
         57     # optimizer slots and opt_params may need to be replicated
         58     self._slots, self._opt_params = tl.on_cpu(tl.for_n_devices(
    ---> 59         (self._optimizer.slots, self._optimizer.opt_params), self._n_devices))
         60 
         61     # accelerated version of model+loss to replicate weights and state
    
    /usr/local/lib/python3.7/dist-packages/trax/layers/acceleration.py in on_cpu(x)
        250   """Puts ``x`` in CPU memory in JAX."""
        251   if fastmath.is_backend(fastmath.Backend.JAX):
    --> 252     return jax.device_put(x, jax.devices('cpu')[0])
        253   else:
        254     return x
    
    /usr/local/lib/python3.7/dist-packages/jax/_src/api.py in device_put(x, device)
       2722   """
       2723   with config_explicit_device_put_scope():
    -> 2724     return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)
       2725 
       2726 
    
    /usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py in tree_map(f, tree, is_leaf, *rest)
        203   leaves, treedef = tree_flatten(tree, is_leaf)
        204   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    --> 205   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
        206 
        207 def build_tree(treedef, xs):
    
    /usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py in <genexpr>(.0)
        203   leaves, treedef = tree_flatten(tree, is_leaf)
        204   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    --> 205   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
        206 
        207 def build_tree(treedef, xs):
    
    /usr/local/lib/python3.7/dist-packages/jax/_src/api.py in <lambda>(y)
       2722   """
       2723   with config_explicit_device_put_scope():
    -> 2724     return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)
       2725 
       2726 
    
    /usr/local/lib/python3.7/dist-packages/jax/core.py in bind(self, *args, **params)
        323     assert (not config.jax_enable_checks or
        324             all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
    --> 325     return self.bind_with_trace(find_top_trace(args), args, params)
        326 
        327   def bind_with_trace(self, trace, args, params):
    
    /usr/local/lib/python3.7/dist-packages/jax/core.py in bind_with_trace(self, trace, args, params)
        326 
        327   def bind_with_trace(self, trace, args, params):
    --> 328     out = trace.process_primitive(self, map(trace.full_raise, args), params)
        329     return map(full_lower, out) if self.multiple_results else full_lower(out)
        330 
    
    /usr/local/lib/python3.7/dist-packages/jax/core.py in process_primitive(self, primitive, tracers, params)
        684 
        685   def process_primitive(self, primitive, tracers, params):
    --> 686     return primitive.impl(*tracers, **params)
        687 
        688   def process_call(self, primitive, f, tracers, params):
    
    /usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in _device_put_impl(x, device)
       1219     raise TypeError(
       1220         f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
    -> 1221   return aval_to_result_handler(device, a)(None, *device_put(x, device))
       1222 
       1223 
    
    /usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in device_put(x, device)
       1113   x = xla.canonicalize_dtype(x)
       1114   try:
    -> 1115     return device_put_handlers[type(x)](x, device)
       1116   except KeyError as err:
       1117     raise TypeError(f"No device_put handler for type: {type(x)}") from err
    
    /usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in _device_put_array(x, device)
       1124   if x.dtype == dtypes.float0:
       1125     x = np.zeros(x.shape, dtype=np.dtype(bool))
    -> 1126   return (backend.buffer_from_pyval(x, device),)
       1127 
       1128 def _device_put_scalar(x, device):
    
    /usr/local/lib/python3.7/dist-packages/jax/_src/device_array.py in __array__(self, dtype, context)
        264 
        265   def __array__(self, dtype=None, context=None):
    --> 266     return np.asarray(self._value, dtype=dtype)
        267 
        268   setattr(device_array, "__array__", __array__)
    
    /usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in _sda_value(self)
        803     npy_value = np.empty(self.aval.shape, self.aval.dtype)
        804     for i in self.one_replica_buffer_indices:
    --> 805       npy_value[self.indices[i]] = np.asarray(self.device_buffers[i])
        806     self._npy_value = npy_value
        807   return self._npy_value
    
    TypeError: float() argument must be a string or a number, not 'jaxlib.tpu_client_extension.PyTpuBuffer'
    ...
    
    opened by agoliaei
  • The colab button on Knowledge_Tracing_Transformer.ipynb is not open

    Description

    Small issue here: The Open in Colab button in the notebook Knowledge_Tracing_Transformer.ipynb in the directory "trax/trax/examples" leads to a private file on Google Drive.

    opened by haytamdon
Releases (v1.4.1)
  • v1.4.1(Oct 26, 2021)

  • v1.4.0(Oct 26, 2021)

  • v1.3.9(May 21, 2021)

  • v1.3.8(Apr 26, 2021)

  • v1.3.7(Dec 18, 2020)

    Lots of documentation, bugs squashed and misc changes.

    Features

    • Load checkpoint for fine-tuning in https://github.com/google/trax/commit/0349c3f2c32ed83d0ba54a0f6652766ea5467cca by @henrykmichalewski !
    • Added generic input pipeline for GLUE tasks in https://github.com/google/trax/commit/2f490e83a9be1b802dccf866c52c8908645e10d6 by @henrykmichalewski - thanks a lot!
    • Some nascent support for bfloat16s!

    Models

    • Performer's Favor and CausalFavor - https://github.com/google/trax/commit/77db199392ff967e49156ccb24c916b270b47eca thanks @lukaszkaiser !
    • Funnel-Transformer in https://github.com/google/trax/pull/1156 thanks a lot @mvxxx !
    • BERT for Trax in https://github.com/google/trax/pull/1254, https://github.com/google/trax/pull/1223, etc., by @piotrekp1 - thanks a lot!
    • Residual Exchange Network by @kkanska in https://github.com/google/trax/commit/3a8f402b722a0d9ae934ed399e6532dba8401f9e ! Thanks!
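The Performer's Favor attention replaces the softmax with a random-feature approximation, so attention can be computed without forming the full query-by-key matrix. Below is a minimal NumPy sketch of that idea under stated assumptions: it uses positive random features with i.i.d. Gaussian projections (the Performer paper's FAVOR+ additionally uses orthogonal features), and all function names are hypothetical, not Trax's `Favor`/`CausalFavor` code.

```python
import numpy as np


def softmax_attention(q, k, v):
    """Exact softmax attention: softmax(q k^T / sqrt(d)) v."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


def positive_features(x, w):
    """phi(x) = exp(w.x - |x|^2 / 2) / sqrt(m), so E[phi(q) . phi(k)] = exp(q . k)."""
    sq_norm = 0.5 * np.sum(x * x, axis=-1, keepdims=True)
    return np.exp(x @ w.T - sq_norm) / np.sqrt(w.shape[0])


def favor_attention(q, k, v, n_features=10000, seed=0):
    """Approximate softmax attention in time linear in sequence length (sketch)."""
    d = q.shape[-1]
    # Assumption: i.i.d. Gaussian projections w ~ N(0, I), not orthogonal ones.
    w = np.random.default_rng(seed).standard_normal((n_features, d))
    scale = d ** -0.25  # split the 1/sqrt(d) temperature between q and k
    q_feat = positive_features(q * scale, w)  # (n_q, m)
    k_feat = positive_features(k * scale, w)  # (n_k, m)
    kv = k_feat.T @ v                         # (m, d_v): no n_q x n_k matrix is formed
    normalizer = q_feat @ k_feat.sum(axis=0)  # (n_q,)
    return (q_feat @ kv) / normalizer[:, None]


# Usage: small random inputs; the approximation error shrinks as n_features grows.
rng = np.random.default_rng(1)
q = 0.2 * rng.standard_normal((8, 4))
k = 0.2 * rng.standard_normal((8, 4))
v = rng.standard_normal((8, 4))
approx = favor_attention(q, k, v)
exact = softmax_attention(q, k, v)
print(np.max(np.abs(approx - exact)))
```

Because `kv` and the normalizer are sums over keys, cost grows linearly with sequence length rather than quadratically; the price is a stochastic approximation whose accuracy depends on `n_features`.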

    PRs Merged

    • Fixing broken example links in https://github.com/google/trax/pull/1263 thanks @amtagrwl !
    • Added WideResnet, Deconv, etc. example notebooks in https://github.com/google/trax/pull/1259, https://github.com/google/trax/pull/1202, https://github.com/google/trax/pull/1232 - thanks a lot @SauravMaheshkar!
    • Remove implicit object from the base class in https://github.com/google/trax/pull/1228 thanks @HarshCasper !
    • Fashion MNIST example in https://github.com/google/trax/pull/1199 thanks @Jimexist !
    • Fix PretrainedBERT init in https://github.com/google/trax/pull/1135 by @hepaajan !
    • Typo in the TransformerDecoder input parameters description in https://github.com/google/trax/pull/1100 by @kujaomega !
  • v1.3.6(Oct 21, 2020)

  • v1.3.5(Sep 19, 2020)

    PRs: Thanks @NathanHowell for fixing a bunch of typos in #962. Thanks @DarrenZhang01 for contributing to the TF-Numpy extensions code in #954.

    ReversibleSerialTrainer for memory-efficient, layer-by-layer training.
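Reversible layers are what make this kind of memory-efficient training possible: a block's inputs can be recomputed from its outputs, so intermediate activations need not be kept for the backward pass. A minimal NumPy sketch of a RevNet-style reversible residual block (the coupling the Reformer uses), assuming two arbitrary deterministic sub-layers `f` and `g`; the helper names are hypothetical and this is not the `ReversibleSerialTrainer` implementation.

```python
import numpy as np


def reversible_forward(x1, x2, f, g):
    """Reversible residual coupling: y1 = x1 + f(x2); y2 = x2 + g(y1)."""
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2


def reversible_inverse(y1, y2, f, g):
    """Recover the inputs from the outputs, so they need not be stored."""
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2


# Hypothetical sub-layers: any deterministic functions work here.
rng = np.random.default_rng(0)
w_f = rng.standard_normal((4, 4))
w_g = rng.standard_normal((4, 4))


def f(h):
    return np.tanh(h @ w_f)


def g(h):
    return np.maximum(h @ w_g, 0.0)


x1 = rng.standard_normal((2, 4))
x2 = rng.standard_normal((2, 4))
y1, y2 = reversible_forward(x1, x2, f, g)
r1, r2 = reversible_inverse(y1, y2, f, g)  # reconstructs x1, x2 up to float error
```

Note that `f` and `g` themselves need not be invertible; only the coupling structure is.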

    Miscellaneous other issues.

  • v1.3.3(Jul 26, 2020)

    Minor Fixes:

    • Rename gumbel_sample to logsoftmax_sample in a2497cbd8477a11f2b96e87a4d53ce46b845ffa8
    • A fix to storing checkpoints in Loop in 88b033c804a4da9021505f2ce8fbb5c7d6500574
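The old name hints at how sampling from log-probabilities works: add Gumbel noise to the log-softmax outputs and take the argmax (the Gumbel-max trick). A minimal NumPy sketch, using a hypothetical `gumbel_max_sample` helper rather than Trax's actual `logsoftmax_sample` source. Since argmax(log p + T·g) = argmax(log p / T + g), a temperature T > 0 samples from softmax(log p / T), and T = 0 reduces to greedy argmax.

```python
import numpy as np


def gumbel_max_sample(log_probs, temperature=1.0, rng=None):
    """Sample indices along the last axis from softmax(log_probs / temperature).

    Uses argmax(log_probs + temperature * Gumbel noise); temperature=0 is greedy.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Clip u away from 0 and 1 so the double log stays finite.
    u = rng.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
    gumbel = -np.log(-np.log(u))  # standard Gumbel(0, 1) noise
    return np.argmax(log_probs + temperature * gumbel, axis=-1)


# Usage: empirical frequencies approach the underlying probabilities.
rng = np.random.default_rng(0)
log_probs = np.log(np.array([0.7, 0.2, 0.1]))
draws = gumbel_max_sample(np.tile(log_probs, (20000, 1)), rng=rng)  # 20000 draws
freqs = np.bincount(draws, minlength=3) / len(draws)
greedy = gumbel_max_sample(log_probs[None, :], temperature=0.0)  # always index 0
```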
  • v1.3.2(Jul 24, 2020)

    Framework:

    • Data pipeline combinators in 17d44710bee74e8bb6e3b34baa4114b0c78160af
    • Multi-host training in the Loop API in 524321e0f77afa58a9fba470e5bccb19ae6bdb92
    • Auto-regressive sampling in fe0fa7886ee33bfac831a969a82b351ccb02941c
    • T2T Tokenizers in 6306231ace01ab55c5df0c401b1f3cf4c7b6a32f
    • Early work on multi-task training in 423d664b470612ebbc0e5041e8288eb1a47c30ef thanks to @koz4k
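Trax's input pipelines are built from functions that take a stream and return a stream (for example, `trax.data.tokenize` consumes an iterator), which Trax composes with combinators such as `trax.data.Serial`. The pure-Python sketch below illustrates the composition idea only; `serial`, `lowercase`, `to_char_ids`, `filter_by_length` and `batch` are hypothetical stand-ins, not the `trax.data` API.

```python
from typing import Callable, Iterable, Iterator

StreamFn = Callable[[Iterable], Iterator]


def serial(*fns: StreamFn) -> StreamFn:
    """Compose stream functions left to right into a single pipeline."""
    def pipeline(stream: Iterable) -> Iterator:
        for fn in fns:
            stream = fn(stream)
        return iter(stream)
    return pipeline


def lowercase(stream):
    for text in stream:
        yield text.lower()


def to_char_ids(stream):
    # Hypothetical "tokenizer": one id per character.
    for text in stream:
        yield [ord(c) for c in text]


def filter_by_length(max_length):
    def fn(stream):
        for ids in stream:
            if len(ids) <= max_length:
                yield ids
    return fn


def batch(batch_size):
    def fn(stream):
        buf = []
        for item in stream:
            buf.append(item)
            if len(buf) == batch_size:
                yield buf
                buf = []
        if buf:  # emit the final partial batch
            yield buf
    return fn


pipeline = serial(lowercase, to_char_ids, filter_by_length(5), batch(2))
batches = list(pipeline(iter(["Hi", "Trax", "too long here", "OK"])))
```

Because every stage is lazy, nothing is read from the source until the consumer asks for the next batch.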

    Tasks / Gin files:

    • IMDB dataset in 08bdb50db41a089b6e12c61a8a29981225bf3566
    • Machine Translation en-pl in 5fb8aa8c5cb86dabb2338938c745996d5d87d996
  • v1.3.1(Jul 2, 2020)

    Miscellaneous fixes.

    • tl.Embedding now has the same signature as PyTorch/TF
    • train.lr_schedule (function object) -> train.lr_schedule_fn (function)
    • Report loss back to training.Loop
  • v1.3.0(Jun 30, 2020)

    Trax now has docs at https://trax-ml.readthedocs.io/en/latest/trax.html thanks to @j2i2!

    Many usability changes, especially in trax.supervised.training.TrainTask/EvalTask/Loop, docs, comments etc.

    • flat saved model/checkpoint representation
    • LR schedules simplified: they now just take the step number.
    • configs are now in supervised/configs and rl/configs.
    • RL obsolete code cleanup.
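Once schedules take only the step number, a learning-rate schedule is just a function from step to rate. A minimal sketch of linear warmup followed by inverse-square-root decay, a common choice for Transformers; the factory name `warmup_rsqrt_schedule` and its parameters are hypothetical and not one of Trax's built-in schedules.

```python
import math
from typing import Callable


def warmup_rsqrt_schedule(max_value: float, n_warmup_steps: int) -> Callable[[int], float]:
    """Return a function mapping a step number to a learning rate.

    Rises linearly to max_value over n_warmup_steps, then decays as 1/sqrt(step).
    """
    def lr(step: int) -> float:
        step = max(step, 1)  # avoid a zero rate at step 0
        if step < n_warmup_steps:
            return max_value * step / n_warmup_steps
        return max_value * math.sqrt(n_warmup_steps / step)
    return lr


schedule = warmup_rsqrt_schedule(max_value=1e-3, n_warmup_steps=1000)
print([schedule(s) for s in (500, 1000, 4000)])
```

The two branches meet at `n_warmup_steps`, so the schedule is continuous at the peak.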

    Also, rapid development of the tf-numpy codebase!

  • v1.2.4(Apr 18, 2020)

    Merged PRs:

    • #459 by @w4-sjcho - adding names to layers, aiding debuggability thanks a lot!
    • #256 and #300 by @stephenjfox and @satyarohith refining the README.md language, thanks a lot folks!
    • #313 #312 #436 #396 from @pkol with lots of bugfixes, thanks a lot @pkol !
    • #409 by @pkol: a special shoutout to this PR, which fixes a long-standing issue that prevented the process from terminating, by tracking the zombie threads. Special thanks to @pkol for this one!
    • #386 another shoutout to @pkol for an amazing speedup in the RL code -- thanks a lot again !
    • #344 a psum bugfix with tf backend from @fsx950223 - thanks a lot !
    • #335 a bugfix from @friesel - thanks a lot Phillip !
    • #315 better exception handling by @cool-RR - thanks a lot !

    Reformer:

    • BERT initialization and finetuning by Nikita!
    • Many changes including ReformerLM on C4 dataset.

    RL:

    • New 'light' RL code in the Trax supervised style, check it out!
    • AWR in the old code working with MuJoCo tasks.

    And many more changes to the Trax framework!

  • v1.2.3(Feb 25, 2020)

    Reformer

    • Reversible Transformer model for machine translation and other encoder-decoder tasks
    • Add code for beam search, sampling, and greedy decoding (see trax.models.beam_search.Search)
    • Memory-efficient attention classes have been re-written to use even less memory and to support faster decoding (see the new SelfAttention, LSHSelfAttention and EncDecAttention classes)
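Greedy decoding, sampling and beam search differ only in how the next token is chosen from the model's log-probabilities. A compact beam-search sketch over a hypothetical `next_log_probs(prefix)` callback (a stand-in for a trained model); it is not the `trax.models.beam_search.Search` API. With `beam_size=1` it reduces to greedy decoding.

```python
import math
from typing import Callable, List, Sequence, Tuple


def beam_search(next_log_probs: Callable[[Sequence[int]], List[float]],
                beam_size: int, max_len: int,
                eos_id: int) -> List[Tuple[List[int], float]]:
    """Keep the beam_size highest-scoring prefixes at each step.

    Scores are summed log-probabilities; finished hypotheses (ending in eos_id)
    are carried along unchanged. No length normalization is applied.
    """
    beams: List[Tuple[List[int], float]] = [([], 0.0)]
    for _ in range(max_len):
        candidates = []
        for prefix, score in beams:
            if prefix and prefix[-1] == eos_id:
                candidates.append((prefix, score))  # already finished
                continue
            for token, logp in enumerate(next_log_probs(prefix)):
                candidates.append((prefix + [token], score + logp))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]
    return beams


def toy_model(prefix):
    """Hypothetical stand-in for a trained model; token 0 is EOS."""
    table = {(): [0.1, 0.5, 0.4], (1,): [0.34, 0.33, 0.33], (2,): [0.9, 0.05, 0.05]}
    probs = table.get(tuple(prefix), [1.0 / 3] * 3)
    return [math.log(p) for p in probs]


best_tokens, best_score = beam_search(toy_model, beam_size=2, max_len=2, eos_id=0)[0]
greedy_tokens, _ = beam_search(toy_model, beam_size=1, max_len=2, eos_id=0)[0]
```

In the toy model, greedy decoding commits to token 1 (p = 0.5) and ends with probability 0.5 × 0.34 = 0.17, while a beam of two keeps token 2 alive and finds [2, 0] with probability 0.4 × 0.9 = 0.36.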

    RL

    • Implemented the Advantage-Weighted Regression algorithm, a simple off-policy reinforcement learning algorithm.
    • Extracted out a PolicyBasedTrainer, so ppo_trainer.PPO and awr_trainer.AwrTrainer now both inherit from it.
    • Refactoring of the serialization code in the RL part, thanks to @koz4k !
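Advantage-Weighted Regression fits the policy by supervised regression onto logged actions, weighting each sample by its exponentiated advantage exp((R - V(s)) / beta). A minimal NumPy sketch of the weights and the weighted log-likelihood loss; the clipping value `max_weight` and the function names are assumptions for illustration, not Trax's AWR trainer code.

```python
import numpy as np


def awr_weights(returns, values, beta=1.0, max_weight=20.0):
    """Per-sample weights exp((R - V(s)) / beta), clipped at max_weight (assumed value)."""
    advantages = np.asarray(returns) - np.asarray(values)
    return np.minimum(np.exp(advantages / beta), max_weight)


def awr_policy_loss(log_probs_of_actions, weights):
    """Weighted negative log-likelihood of the logged actions.

    In an autodiff framework, the weights should be treated as constants
    (no gradient through them); plain NumPy has no gradients anyway.
    """
    return -np.mean(weights * log_probs_of_actions)


returns = np.array([1.0, 0.0, 5.0])
values = np.array([0.5, 0.5, 0.5])
weights = awr_weights(returns, values, beta=1.0)  # third weight exp(4.5) is clipped
loss = awr_policy_loss(np.log(np.full(3, 0.5)), weights)
```

Smaller `beta` makes the weighting greedier toward high-advantage samples; clipping keeps a few outliers from dominating the loss.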

    Framework

    • A lot of code cleanup and refactoring of the core abstractions by Jonni, thanks Jonni!

    TF Numpy

    • More ops added by @wangpengmit !
  • v1.2.2(Jan 17, 2020)

  • v1.2.1(Jan 16, 2020)

  • v1.2.0(Jan 15, 2020)

    New Models

    • Reformer Implementation - https://arxiv.org/abs/2001.04451 Thanks Nikita Kitaev, @lukaszkaiser and @levskaya !
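The Reformer's LSH attention groups similar queries and keys into buckets with angular locality-sensitive hashing: project with a random matrix R, take h(x) = argmax([xR; -xR]), and attend only within buckets. A minimal NumPy sketch of the hashing step alone, assuming a single hash round; `lsh_buckets` is a hypothetical name, not Trax's `LSHSelfAttention` code.

```python
import numpy as np


def lsh_buckets(x, n_buckets, seed=0):
    """Angular LSH: project with a random matrix R, then argmax over [xR, -xR].

    x: (n, d) array; n_buckets must be even. Vectors pointing in similar
    directions tend to land in the same bucket.
    """
    assert n_buckets % 2 == 0, "n_buckets must be even"
    rng = np.random.default_rng(seed)
    r = rng.standard_normal((x.shape[-1], n_buckets // 2))
    rotated = x @ r
    return np.argmax(np.concatenate([rotated, -rotated], axis=-1), axis=-1)


rng = np.random.default_rng(1)
base = rng.standard_normal(64)
same_direction = 3.0 * base  # scaling does not change the direction
buckets = lsh_buckets(np.stack([base, same_direction, -base]), n_buckets=8)
```

Scaling a vector never changes its bucket, and negating it moves it exactly `n_buckets / 2` positions, which is the sense in which the hash is angular.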

    Colabs

    • Reformer Colabs https://github.com/google/trax/commit/9dbf8636a34cffbd421bedb2a1e3d7fe006346c0
    • Transformer Colab https://github.com/google/trax/commit/d2c5b84b8db53888013c56980ccbafba077ee8f2

    Framework Changes

    • Ongoing cleanups and API simplifications.
    • Optimization by @jekbradbury - thanks James!

    PRs

    • Consistent logging with absl.logging and setup.py fixes, thanks to @lkhphuc in #198
    • Code cleanups by @cclauss in #196
    • Code cleanup by @pkol in #134 - thanks!
    • Bug fix by @pzielinski-nyc in #151 - thanks!
Owner
Google
Google ❤️ Open Source
Fully Convolutional Refined Auto Encoding Generative Adversarial Networks for 3D Multi Object Scenes

Fully Convolutional Refined Auto-Encoding Generative Adversarial Networks for 3D Multi Object Scenes This repository contains the source code for Full

Yu Nishimura 106 Nov 21, 2022
PECOS - Prediction for Enormous and Correlated Spaces

PECOS - Predictions for Enormous and Correlated Output Spaces PECOS is a versatile and modular machine learning (ML) framework for fast learning and i

Amazon 387 Jan 04, 2023
Qcover is an open source effort to help explore combinatorial optimization problems in Noisy Intermediate-scale Quantum (NISQ) processors.

Qcover is an open source effort to help explore combinatorial optimization problems in Noisy Intermediate-scale Quantum (NISQ) processors. It is devel

33 Nov 11, 2022
Neural-PIL: Neural Pre-Integrated Lighting for Reflectance Decomposition - NeurIPS2021

Neural-PIL: Neural Pre-Integrated Lighting for Reflectance Decomposition Project Page | Video | Paper Implementation for Neural-PIL. A novel method wh

Computergraphics (University of Tübingen) 64 Dec 29, 2022
Understanding Convolutional Neural Networks from Theoretical Perspective via Volterra Convolution

nnvolterra Run Code Compile first: make compile Run all codes: make all Test xconv: make npxconv_test MNIST dataset needs to be downloaded, converted

1 May 24, 2022
The code succinctly shows how our ensemble learning based on deep learning CNN is used for LAM-avulsion-diagnosis.

deep-learning-LAM-avulsion-diagnosis The code succinctly shows how our ensemble learning based on deep learning CNN is used for LAM-avulsion-diagnosis

1 Jan 12, 2022
Einshape: DSL-based reshaping library for JAX and other frameworks.

Einshape: DSL-based reshaping library for JAX and other frameworks. The jnp.einsum op provides a DSL-based unified interface to matmul and tensordot o

DeepMind 62 Nov 30, 2022
StrongSORT: Make DeepSORT Great Again

StrongSORT StrongSORT: Make DeepSORT Great Again StrongSORT: Make DeepSORT Great Again Yunhao Du, Yang Song, Bo Yang, Yanyun Zhao arxiv 2202.13514 Abs

369 Jan 04, 2023
PyTorch Implementation of Region Similarity Representation Learning (ReSim)

ReSim This repository provides the PyTorch implementation of Region Similarity Representation Learning (ReSim) described in this paper: @Article{xiao2

Tete Xiao 74 Jan 03, 2023
This repository contains the code for the CVPR 2021 paper "GIRAFFE: Representing Scenes as Compositional Generative Neural Feature Fields"

GIRAFFE: Representing Scenes as Compositional Generative Neural Feature Fields Project Page | Paper | Supplementary | Video | Slides | Blog | Talk If

1.1k Dec 30, 2022
PyTorch implementation of SimSiam: Exploring Simple Siamese Representation Learning

SimSiam: Exploring Simple Siamese Representation Learning This is a PyTorch implementation of the SimSiam paper: @Article{chen2020simsiam, author =

Facebook Research 834 Dec 30, 2022
Fast and scalable uncertainty quantification for neural molecular property prediction, accelerated optimization, and guided virtual screening.

Evidential Deep Learning for Guided Molecular Property Prediction and Discovery Ava Soleimany*, Alexander Amini*, Samuel Goldman*, Daniela Rus, Sangee

Alexander Amini 75 Dec 15, 2022
face2comics by Sxela (Alex Spirin) - face2comics datasets

This is a paired face to comics dataset, which can be used to train pix2pix or similar networks.

Alex 164 Nov 13, 2022
An open source python library for automated feature engineering

"One of the holy grails of machine learning is to automate more and more of the feature engineering process." ― Pedro Domingos, A Few Useful Things to

alteryx 6.4k Jan 03, 2023
Asymmetric metric learning for knowledge transfer

Asymmetric metric learning This is the official code that enables the reproduction of the results from our paper: Asymmetric metric learning for knowl

20 Dec 06, 2022
Reading list for research topics in Masked Image Modeling

awesome-MIM Reading list for research topics in Masked Image Modeling(MIM). We list the most popular methods for MIM, if I missed something, please su

ligang 231 Dec 07, 2022
PaddleRobotics is an open-source algorithm library for robots based on Paddle, including open-source parts such as human-robot interaction, complex motion control, environment perception, SLAM positioning, and navigation.

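Simplified Chinese | English PaddleRobotics: PaddleRobotics is an open-source collection of robotics algorithm libraries based on Paddle, including open-source algorithms for human-robot interaction, complex motion control, environment perception, and SLAM localization and navigation. Human-robot interaction: active multimodal interaction technology TFVT-HRI. Active multimodal interaction technology, through inputs such as vision, speech, and touch sensors, the robot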

185 Dec 26, 2022
Official implementation of the ICCV 2021 paper "Joint Inductive and Transductive Learning for Video Object Segmentation"

JOINT This is the official implementation of Joint Inductive and Transductive learning for Video Object Segmentation, to appear in ICCV 2021. @inproce

Yunyao 35 Oct 16, 2022
Code for "Neural 3D Scene Reconstruction with the Manhattan-world Assumption" CVPR 2022 Oral

News 05/10/2022 To make the comparison on ScanNet easier, we provide all quantitative and qualitative results of baselines here, including COLMAP, COL

ZJU3DV 365 Dec 30, 2022
The code of Zero-shot learning for low-light image enhancement based on dual iteration

Zero-shot-dual-iter-LLE The code of Zero-shot learning for low-light image enhancement based on dual iteration. You can get the real night image tests

1 Mar 18, 2022