Elegy is a framework-agnostic Trainer interface for the Jax ecosystem.

Overview

Elegy




Main Features

  • Easy-to-use: Elegy provides a Keras-like high-level API that makes it very easy to do common tasks.
  • Flexible: Elegy provides a functional Pytorch Lightning-like low-level API that provides maximal flexibility when needed.
  • Agnostic: Elegy supports a variety of frameworks including Flax, Haiku, and Optax on the high-level API, and it is 100% framework-agnostic on the low-level API.
  • Compatible: Elegy can consume a wide variety of common data sources including TensorFlow Datasets, Pytorch DataLoaders, Python generators, and Numpy pytrees.

For more information take a look at the Documentation.

Installation

Install Elegy using pip:

pip install elegy

For Windows users, we recommend the Windows Subsystem for Linux 2 (WSL2), since JAX does not support Windows natively yet.

Quick Start: High-level API

Elegy's high-level API provides a very simple interface you can use by implementing the following steps:

1. Define the architecture inside a Module. We will use Flax Linen for this example:

import flax.linen as nn
import jax

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(300)(x)
        x = jax.nn.relu(x)
        x = nn.Dense(10)(x)
        return x

2. Create a Model from this module and specify additional things like losses, metrics, and optimizers:

import elegy, optax

model = elegy.Model(
    module=MLP(),
    loss=[
        elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
        elegy.regularizers.GlobalL2(l=1e-5),
    ],
    metrics=elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optax.rmsprop(1e-3),
)

3. Train the model using the fit method:

model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")]
)
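Since fit also accepts Python generators (see Main Features), the arrays above could instead be streamed as shuffled minibatches. The sketch below uses a hypothetical helper, `minibatches`, which is not part of Elegy: it yields `(x, y)` tuples indefinitely and reshuffles on every pass. With an infinite generator, `steps_per_epoch` is what bounds each epoch (assuming Keras-like semantics).

```python
import numpy as np


def minibatches(x, y, batch_size, seed=0):
    """Hypothetical helper: yield shuffled (x, y) minibatches forever.

    Each pass over the data uses a fresh permutation; the last partial batch
    of a pass is dropped so every batch has exactly `batch_size` rows.
    """
    rng = np.random.default_rng(seed)
    n = len(x)
    while True:
        perm = rng.permutation(n)
        for start in range(0, n - batch_size + 1, batch_size):
            idx = perm[start:start + batch_size]
            yield x[idx], y[idx]


# Usage sketch (X_train / y_train as in the example above):
# model.fit(minibatches(X_train, y_train, batch_size=64), epochs=100, steps_per_epoch=200)
```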

Quick Start: Low-level API

Elegy's low-level API lets you define exactly what goes on during training, testing, and inference. Let's define the test_step to implement a linear classifier in pure JAX:

1. Calculate our loss, logs, and states:

import jax
import jax.numpy as jnp
import numpy as np
import elegy

class LinearClassifier(elegy.Model):
    # request parameters by name via dependency injection.
    # names: x, y_true, sample_weight, class_weight, states, initializing
    def test_step(
        self,
        x,  # inputs
        y_true,  # labels
        states: elegy.States,  # model state
        initializing: bool,  # if True we should initialize our parameters
    ):
        rng: elegy.RNGSeq = states.rng
        # flatten + scale
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        # initialize or use existing parameters
        if initializing:
            w = jax.random.uniform(
                rng.next(), shape=[np.prod(x.shape[1:]), 10]
            )
            b = jax.random.uniform(rng.next(), shape=[1])
        else:
            w, b = states.net_params
        # model
        logits = jnp.dot(x, w) + b
        # categorical crossentropy loss
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
        # metrics
        logs = dict(
            accuracy=accuracy,
            loss=loss,
        )
        return loss, logs, states.update(net_params=(w, b))
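The loss and accuracy computed in test_step are ordinary categorical cross-entropy and top-1 accuracy. As a sanity check, here is the same arithmetic in plain NumPy (a sketch, not Elegy code). With all-zero logits over 10 classes every class gets probability 1/10, so the loss is log(10) ≈ 2.3026, and since argmax then picks class 0 for every row, only the first of the four labels below counts as correct (accuracy 0.25).

```python
import numpy as np


def log_softmax(logits):
    # subtract the row max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def loss_and_accuracy(logits, y_true, num_classes=10):
    labels = np.eye(num_classes)[y_true]  # one-hot, like jax.nn.one_hot
    loss = np.mean(-np.sum(labels * log_softmax(logits), axis=-1))
    accuracy = np.mean(np.argmax(logits, axis=-1) == y_true)
    return loss, accuracy


loss, acc = loss_and_accuracy(np.zeros((4, 10)), np.array([0, 1, 2, 3]))
```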

2. Instantiate our LinearClassifier with an optimizer:

model = LinearClassifier(
    optimizer=optax.rmsprop(1e-3),
)

3. Train the model using the fit method:

model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    steps_per_epoch=200,
    batch_size=64,
    validation_data=(X_test, y_test),
    shuffle=True,
    callbacks=[elegy.callbacks.TensorBoard("summaries")]
)

Using Jax Frameworks

It is straightforward to integrate other functional JAX libraries with this low-level API:

class LinearClassifier(elegy.Model):
    def test_step(
        self, x, y_true, states: elegy.States, initializing: bool
    ):
        rng: elegy.RNGSeq = states.rng
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        if initializing:
            logits, variables = self.module.init_with_output(
                {"params": rng.next(), "dropout": rng.next()}, x
            )
        else:
            variables = dict(params=states.net_params, **states.net_states)
            logits, variables = self.module.apply(
                variables, x, rngs={"dropout": rng.next()}, mutable=True
            )
        net_states, net_params = variables.pop("params")
        
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        logs = dict(accuracy=accuracy, loss=loss)
        return loss, logs, states.update(net_params=net_params, net_states=net_states)
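In this example, `variables.pop("params")` unpacks into two values because a Flax `FrozenDict.pop` returns a `(remaining, popped)` pair. A plain Python dict's `pop` returns only the popped value, so on Flax versions where `init_with_output`/`apply` return regular dicts that line would need adjusting. A framework-neutral way to do the same split, using a hypothetical `split_params` helper:

```python
def split_params(variables):
    """Hypothetical helper: separate the 'params' collection from all other
    collections (e.g. 'batch_stats'), mirroring
    (rest, params) = FrozenDict.pop('params') without mutating the input."""
    net_params = variables["params"]
    net_states = {k: v for k, v in variables.items() if k != "params"}
    return net_states, net_params
```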

More Info

Examples

To run the examples, first install the required packages:

pip install -r examples/requirements.txt

Now run the example:

python examples/flax_mnist_vae.py 

Contributing

Deep learning is evolving at an incredible pace; there is so much to do and so few hands. If you wish to contribute anything, from a loss or metric to a new awesome feature for Elegy, just open an issue or send a PR! For more information, check out our Contributing Guide.

About Us

We are some friends passionate about ML.

License

Apache

Citing Elegy

To cite this project:

BibTeX

@software{elegy2020repository,
author = {PoetsAI},
title = {Elegy: A framework-agnostic Trainer interface for the Jax ecosystem},
url = {https://github.com/poets-ai/elegy},
version = {0.5.0},
year = {2020},
}

The current version can be retrieved from either the Release tag or the file elegy/__init__.py, and the year corresponds to the project's release year.

Comments
  • Framework Agnostic API: Introduces a new low-level API, removes the dependency between Model and Module, adds support for Flax and Haiku, simplifies hooks.

    This PR contains the following features:

    • It turns Elegy into a framework agnostic library by removing the dependencies between elegy.Model and elegy.Module, it proposes the GeneralizedModule API and implements it for Flax, Haiku, Elegy Module types, and regular python functions.
    • It introduces a new low-level API similar to Pytorch Lightning that lets users manually override the core parts of the training loop when maximal flexibility is required.
    • General changes that enable the framework-agnostic mindset.
    • Many quality of life changes like standardization of hooks, simplification of the Module system, etc.

    Tasks:

    • [x] Create hooks module
    • [x] Refactor Model with low-level API and remove Module dependencies
    • [x] Refactor Module to use hooks
    • [x] Create GeneralizedModule and GeneralizedOptimizer Interfaces
    • [x] Implement GeneralizedModule for flax.linen.Module
    • [x] Implement GeneralizedModule for elegy.Module
    • [x] Implement GeneralizedModule for haiku.Module
    • [x] Implement GeneralizedOptimizer for optax.GradientTransformation
    • [x] Implement GeneralizedOptimizer for elegy.Optimizer
    • [x] Fix Model.summary
    • [x] Fix tests
    • [x] Fix examples
    • [ ] Fix README
    • [ ] Fix guides
    • [ ] Fix docstrings
    opened by cgarciae 27
  • WGAN-GP low-level API example

    A more extensive example using the new low-level API: Wasserstein-GAN with Gradient Penalty (WGAN-GP) trained on the CelebA dataset.

    Some good generated images: epoch-0079 epoch-0084 epoch-0089

    Some notes:

    • I first tried to train a DCGAN, which uses binary cross-entropy, but I ran into balancing issues: the discriminator quickly becomes so good that the generator does not learn anything. The same model implemented in PyTorch or TensorFlow works. Most modern GANs don't use the WGAN loss anymore; most use BCE.
    • I'm still in favor of making Module.apply() return init(). It's just too much boilerplate to use an if-else every time. I avoided it by manually calling wgan.states = wgan.init(...) after model instantiation, which I think is also not nice.
    • Can we make Module.apply() accept params and states separately instead of collections? It's annoying having to construct a dict {'params':params, 'states':states} every time.
    • It would be nice if elegy.States were a dict so that users can decide for themselves what to put into it. With GANs, where you have to manage generator and discriminator states separately, one always has to split them like (g_states, d_states) = net_states, which is again annoying.
    • Model.save() fails on this model. Partially due to the extra jitted functions but even when I remove them, cloudpickle chokes on _HooksContext

    @cgarciae I'm not completely sure I've used the low-level API correctly, maybe you can take a closer look?

    opened by alexander-g 11
  • Add learning rate logging

    Implements the same functionality from #131 using only minor modifications to elegy.Optimizer.

    • [x] Add lr_schedule and steps_per_epoch to Optimizer.
    • [x] Implement Optimizer.get_effective_learning_rate
    • [x] Copy logging code from #131
    • [x] Add documentation

    @alexander-g Here is a proposal that is a bit simpler, closer to what I mentioned in #124. What do you think? @charlielito should we log the learning rate automatically if available or should we create a Callback?
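To make the idea behind `get_effective_learning_rate` concrete, here is a sketch under stated assumptions: that `lr_schedule` is a function of the global step and the epoch, and that the epoch is derived as `step // steps_per_epoch`. Both helpers below are hypothetical, not Elegy's implementation; the returned value is what a logging hook would record each epoch.

```python
def effective_learning_rate(lr_schedule, step, steps_per_epoch=None):
    """Hypothetical sketch: evaluate a learning-rate schedule at a given
    optimizer step, deriving the epoch when steps_per_epoch is known."""
    epoch = step // steps_per_epoch if steps_per_epoch else None
    return lr_schedule(step, epoch)


def step_decay(step, epoch):
    # hypothetical schedule: halve a base rate of 1e-3 after every epoch
    return 1e-3 * 0.5 ** epoch


lrs = [effective_learning_rate(step_decay, s, steps_per_epoch=100) for s in (0, 99, 100, 250)]
```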

    opened by cgarciae 9
  • Question: how to set the random state when calling model.predict(...)

    Not sure if this is the right place to post this...

    I have built and trained a VAE. When calling model.predict(x=test_set), I would like to make multiple predictions for each item in the test set (because VAEs are probabilistic). That way I can look at the distribution of predictions for each item in the test_set.

    The call() for the VAE includes the line
    intrinsic_latents = mean + stds * jax.random.normal(self.next_key(), mean.shape).

    I haven't been able to find an explanation of how self.next_key() works or how to change the random seed on each call so that I get different predictions. I could rewrite the code so that random seeds are passed explicitly, but I assume there is some functionality built into Elegy to make this easy?

    Could someone explain how this works, or point me to the documentation explaining it?

    Thanks!
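Independently of how Elegy threads keys through self.next_key(), the mechanics the question asks about come down to two facts: every sample needs fresh noise, and that noise is fully determined by the seed. A NumPy sketch of the reparameterization step with a hypothetical `sample_latents` helper (in JAX, the per-sample keys would typically come from `jax.random.split`):

```python
import numpy as np


def sample_latents(mean, stds, num_samples, seed):
    """Hypothetical helper: draw num_samples latents per item via the
    reparameterization z = mean + stds * eps, with eps ~ N(0, 1)."""
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal((num_samples,) + mean.shape)
    return mean + stds * eps  # broadcasts to (num_samples, *mean.shape)


mean = np.zeros((3, 2))
stds = np.ones((3, 2))
z_a = sample_latents(mean, stds, 5, seed=0)
z_b = sample_latents(mean, stds, 5, seed=0)  # same seed: identical samples
z_c = sample_latents(mean, stds, 5, seed=1)  # new seed: different samples
```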

    opened by jfcrenshaw 8
  • Examples Cleanup

    • refactored examples/imagenet/resnet_imagenet.py to accept parameters instead of modifying them inside the script
    • added README.md for examples/imagenet/
    • removed unnecessary Lambda class from examples/mnist.py
    • moved global average pooling in examples/mnist_conv.py before the Linear layer
    opened by alexander-g 7
  • Resnet

    • ResNet model architecture and an example for training on ImageNet
      • code is mostly adapted from the flax library
      • pretrained ResNet50 with 76.5% accuracy
      • pretrained ResNet18 with 68.7% accuracy
    • Experimental support for mixed precision: previously, all layers set their parameters' dtype to the input's dtype. This is incorrect: for numerical stability, all parameters should be float32 even when performing float16 computations. See more here.
    • Some issues I had during training:
      • There seems to be a memory leak during training; RAM usage constantly increased.
      • I had to use smaller batch sizes than when training with flax or with TensorFlow before maxing out GPU memory (64 instead of 128 for ResNet50 on an RTX 2080 Ti). This might of course be due to a mistake in my code, but the number of parameters is identical to the flax and PyTorch versions, so I think the cause might be elsewhere.
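A minimal NumPy sketch of the mixed-precision rule described above: parameters stay float32 and are cast to float16 only for the computation. The `dense_forward` helper is hypothetical. The last two lines show why the stored copy matters: an update of 1e-4 is rounded away in float16 (whose spacing near 1.0 is about 9.8e-4) but survives in float32.

```python
import numpy as np


def dense_forward(x, w, compute_dtype=np.float16):
    """Hypothetical helper: cast params to the compute dtype only for the
    matmul; the stored parameters (w) keep their float32 dtype."""
    return x.astype(compute_dtype) @ w.astype(compute_dtype)


rng = np.random.default_rng(0)
w = rng.standard_normal((4, 3)).astype(np.float32)  # float32 "master" weights
x = rng.standard_normal((2, 4)).astype(np.float16)  # float16 activations
y = dense_forward(x, w)

# A small update is lost in float16 but kept in float32:
lost_in_fp16 = np.float16(1.0) + np.float16(1e-4) == np.float16(1.0)  # True
lost_in_fp32 = np.float32(1.0) + np.float32(1e-4) == np.float32(1.0)  # False
```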
    opened by alexander-g 7
  • [Bug] Problem with computing metrics

    Describe the bug: Hi, when I use the fit function, I get an error message saying that the update function is not provided with y_true and y_pred. It seems to come from the model's metrics, because if I comment out the metrics line, there is no error.

    TypeError: update() missing 2 required positional arguments: 'y_true' and 'y_pred'
    

    Minimal code to reproduce:

    import jax
    import jax.numpy as jnp
    import ml_collections
    import numpy as np
    import optax
    import elegy as eg
    
    
    class eCNN(eg.Module):
        """A simple CNN model."""
    
        @eg.compact
        def __call__(self, x):
            x=eg.Conv(10,kernel_size=(10,))(x)
            x=jax.nn.relu(x)
            x = eg.Linear(1)(x)
            x=jax.nn.sigmoid(x)
            return x
    
    n=200
    X_train = np.random.rand(n*100).reshape(n,100)
    y_train = np.random.rand(n).reshape(n,1)
    print(X_train.shape)
    print(y_train.shape)
    
    model = eg.Model(
        module=eCNN(),
        loss=[
            eg.losses.MeanSquaredError(),
        ],
        metrics=eg.metrics.MeanSquareError(),  #Line to be commented to get rid of the error
        optimizer=optax.rmsprop(1e-3),
    )
    
    model.fit(X_train,y_train,
        epochs=10,
        batch_size=20,
        #validation_data=0.1,
        shuffle=False,
        callbacks=[eg.callbacks.TensorBoard("summaries")]
        )
    

    Library Info:

    import elegy
    print(elegy.__version__) 
    # 0.8.4
    
    bug 
    opened by organic-chemistry 6
  • Multi-gpu with pmap docs

    One of the selling points of JAX is the pmap transformation, but best practices for actually making your training loop parallelizable are still confusing. What is Elegy's story around multi-GPU training? Is it possible to get a PyTorch Lightning-like API, where this is a single argument to model.fit?

    opened by sooheon 6
  • SCCE fix for bug in Jax<0.2.7

    Small fix for a bug in Jax<0.2.7 where jax.numpy.take_along_axis gives incorrect results for uint8 indices. This is very relevant for semantic segmentation, where label masks are often stored as uint8.

    Alternatively, consider updating Jax.
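The PR's exact change isn't reproduced here. A hedged sketch of the general workaround is to widen the index dtype before gathering; it is written with NumPy's `take_along_axis`, whose signature `jax.numpy.take_along_axis` mirrors. In NumPy the cast is harmless rather than necessary.

```python
import numpy as np


def sparse_categorical_crossentropy(y_true, log_probs):
    """y_true: integer class ids (possibly uint8, e.g. segmentation masks);
    log_probs: (..., num_classes) log-probabilities. Returns per-element loss."""
    # Cast indices to a wide signed integer before gathering: the kind of
    # workaround needed for the uint8 take_along_axis bug in Jax<0.2.7.
    idx = y_true.astype(np.int32)[..., None]
    picked = np.take_along_axis(log_probs, idx, axis=-1)[..., 0]
    return -picked
```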

    opened by alexander-g 6
  • Dataset & DataLoader

    Dataset and parallel DataLoader API similar to PyTorch's. It can be used with Model.fit().

    import numpy as np
    import elegy
    
    class MyDataset(elegy.data.Dataset):
        def __len__(self):
            return 128
    
        def __getitem__(self, i):
            #dummy data
            return np.random.random([224, 224, 3]),  np.random.randint(10)
    
    ds     = MyDataset()
    loader = elegy.data.DataLoader(ds, batch_size=8, n_workers=8, worker_type='thread', shuffle=True)
    
    batch = next(iter(loader))
    assert batch[0].shape == (8,224,224,3)
    assert batch[1].shape == (8,)
    assert len(loader) == 16
    
    model.fit(loader, epochs=10)
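The `len(loader) == 16` assertion follows from 128 samples split into batches of 8. A sketch of that batching arithmetic, plus stacking per-item `(image, label)` tuples into batch arrays, written as hypothetical helpers rather than the DataLoader's actual code (whether the real loader keeps or drops a final partial batch isn't stated, so both options are shown):

```python
import math

import numpy as np


def num_batches(dataset_len, batch_size, drop_last=False):
    """Hypothetical helper: number of batches a loader yields per epoch."""
    if drop_last:
        return dataset_len // batch_size
    return math.ceil(dataset_len / batch_size)


def collate(samples):
    """Hypothetical helper: stack a list of (image, label) tuples into
    (images, labels) arrays with a leading batch axis."""
    images, labels = zip(*samples)
    return np.stack(images), np.array(labels)
```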
    
    opened by alexander-g 6
  • Implemented BinaryCrossentropy metric

    Updates:

    • Created BinaryCrossentropy metric
    • Created basic tests for BinaryCrossentropy metric (passing)
    • Created docs for BinaryCrossentropy metric
    • Refactored main docs by balancing files and correcting language typos
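For reference, the quantity such a metric averages is binary cross-entropy. A minimal NumPy sketch (not the PR's implementation; a streaming metric would accumulate this value over batches, and the clipping epsilon is an assumption):

```python
import numpy as np


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy for probabilities y_pred in (0, 1).

    Predictions are clipped to [eps, 1 - eps] so log(0) never occurs;
    the epsilon value is an assumption, not necessarily Elegy's."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    per_example = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    return per_example.mean()
```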
    documentation 
    opened by sebasarango1180 6
  • use poetry-core

    poetry-core is intended to be a lightweight, fully compliant, self-contained package allowing PEP 517 compatible build frontends to build Poetry managed projects.

    Using poetry-core allows distribution packages to depend only on the build backend.
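Concretely, switching to poetry-core usually amounts to the PEP 517 `[build-system]` table in pyproject.toml. The version bound below is illustrative, not taken from the PR:

```toml
[build-system]
# build frontends only need poetry-core, not the full Poetry CLI
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
```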

    opened by dotlambda 0
  • Documentation/API reference not accessible via project website [Bug]

    Hi, it looks like I can't really access the API reference for Elegy. The corresponding link on the project's website simply takes me back to the main page (https://poets-ai.github.io/elegy/). Any idea what's up?

    bug 
    opened by geomlyd 0
  • [Bug] elegy does not work with latest haiku version

    Describe the bug: When I type 'import elegy', I get this error:

     File "/home/kpmurphy/mambaforge/lib/python3.10/site-packages/elegy/generalized_module/haiku_module.py", line 4, in <module>
        from haiku._src.base import current_bundle_name
    

    Minimal code to reproduce

    import elegy
    

    Library Info:

    >> 
    >>> jax.__version__
    '0.2.28'
    >>> haiku.__version__
    '0.0.9.dev'
    >>> elegy.__version__. #  elegy-0.5.0-py3-none-any.whl 
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    NameError: name 'elegy' is not defined
    >>> 
    

    Screenshots

    Screen Shot 2022-10-03 at 2 33 21 PM

    bug 
    opened by murphyk 5
  • CSVLogger iteration over a 0-d array

    Describe the bug: When using the CSVLogger callback, elegy crashes at the end of the first epoch.

    Minimal code to reproduce

    import elegy as eg
    import optax
    import numpy as np
    
    x = np.random.randn(64, 1)
    y = np.random.randn(64, 1)
    
    model = eg.Model(
        eg.nn.Linear(1),
        loss=eg.losses.MeanSquaredError(),
        optimizer=optax.adam(1e-3),
    )
    
    hist = model.fit(
        x,
        y,
        epochs=10,
        callbacks=[
            eg.callbacks.CSVLogger("train.csv"),  # <-- commenting out this line avoids the error
        ]
    )
    

    Stack trace:

    Epoch 1/10
    2/2 [==============================] - ETA: 0s - loss: 1.3408 - mean_squared_error_loss: 1.3408
    Traceback (most recent call last):
      File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 194, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/home/scott/Documents/phd/geom/pde/csv.py", line 14, in <module>
        hist = model.fit(
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_base.py", line 465, in fit
        callbacks.on_epoch_end(epoch, epoch_logs)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/callback_list.py", line 221, in on_epoch_end
        callback.on_epoch_end(epoch, logs)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py", line 93, in on_epoch_end
        row_dict.update((key, handle_value(logs[key])) for key in self.keys)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py", line 93, in <genexpr>
        row_dict.update((key, handle_value(logs[key])) for key in self.keys)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py", line 68, in handle_value
        return '"[%s]"' % (", ".join(map(str, k)))
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/jax/_src/device_array.py", line 245, in __iter__
        raise TypeError("iteration over a 0-d array")  # same as numpy error
    TypeError: iteration over a 0-d array
    

    Expected behavior: It should treat a 0-d array as a scalar.

    Library Info: Python version 3.8.13, elegy version 0.8.6, treex version 0.6.10.

    Additional context: More detailed error information shows that the error occurs because the array is a JAX DeviceArray rather than an np.ndarray, so it is not recognized by the zero-dimensional array test, which uses the line

    is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
    
    │ /home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/callbacks/csv_logger.py:6 │
    │ 8 in handle_value                                                                                │
    │                                                                                                  │
    │    65 │   │   │   if isinstance(k, six.string_types):                                            │
    │    66 │   │   │   │   return k                                                                   │
    │    67 │   │   │   elif isinstance(k, tp.Iterable) and not is_zero_dim_ndarray:                   │
    │ ❱  68 │   │   │   │   return '"[%s]"' % (", ".join(map(str, k)))                                 │
    │    69 │   │   │   else:                                                                          │
    │    70 │   │   │   │   return k                                                                   │
    │    71                                                                                            │
    │                                                                                                  │
    │ ╭──────────────────────────── locals ─────────────────────────────╮                              │
    │ │ is_zero_dim_ndarray = False                                     │                              │
    │ │                   k = DeviceArray(4.8264385e-05, dtype=float32) │                              │
    │ ╰─────────────────────────────────────────────────────────────────╯                              │
    │                                                                                                  │
    │ /home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/jax/_src/device_array.py:245 in │
    │ __iter__                                                                                         │
    │                                                                                                  │
    │   242                                                                                            │
    │   243   def __iter__(self):                                                                      │
    │   244 │   if self.ndim == 0:                                                                     │
    │ ❱ 245 │     raise TypeError("iteration over a 0-d array")  # same as numpy error                 │
    │   246 │   else:                                                                                  │
    │   247 │     return (sl for chunk in self._chunk_iter(100) for sl in chunk._unstack())            │
    │   248                                                                                            │
    │                                                                                                  │
    │ ╭───────────────────── locals ─────────────────────╮                                             │
    │ │ self = DeviceArray(4.8264385e-05, dtype=float32) │                                             │
    │ ╰──────────────────────────────────────────────────╯                                             │
    ╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
    TypeError: iteration over a 0-d array
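One possible fix (a sketch, not the upstream patch) is to decide scalar-ness with `np.ndim`, which reads the `ndim` attribute of any array-like, JAX arrays included, instead of checking `isinstance(k, np.ndarray)`. The function below is a standalone version of the `handle_value` logic shown in the traceback; unwrapping 0-d values with `.item()` is a design choice of this sketch.

```python
from collections.abc import Iterable

import numpy as np


def handle_value(k):
    """Format one log value for a CSV row (sketch of a patched handle_value).

    np.ndim(k) reads k.ndim when present, so it treats NumPy arrays, JAX
    arrays and Python scalars uniformly; 0-d values are written as scalars.
    """
    if isinstance(k, str):
        return k
    if np.ndim(k) == 0:
        # unwrap 0-d arrays / NumPy scalars into plain Python numbers
        return k.item() if hasattr(k, "item") else k
    if isinstance(k, Iterable):
        return '"[%s]"' % (", ".join(map(str, k)))
    return k
```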
    
    bug 
    opened by ScottAlexanderCameron 0
  • Metrics ignore "on" keyword arg

    Describe the bug: I have an application where I need to output multiple values from a network, which I do by returning a dictionary and passing the on keyword argument. This works fine for the loss functions but not for the metrics.

    Minimal code to reproduce:

    import elegy as eg
    import optax
    import numpy as np
    
    
    def data_generator():
        while True:
            yield (
                np.random.randn(10, 1),
                {"target": {"y": np.random.randn(10, 1)}},
            )
    
    
    class MyModule(eg.Module):
        @eg.compact
        def __call__(self, x):
            return {"y": eg.nn.Linear(1)(x)}
    
    
    model = eg.Model(
        MyModule(),
        loss=eg.losses.MeanSquaredError(on="y"),
        metrics=eg.metrics.MeanAbsoluteError(on="y"),  #  <-- works fine without this line
        optimizer=optax.adam(1e-3),
    )
    
    hist = model.fit(
        data_generator(),
        steps_per_epoch=10,
        epochs=10,
    )
    

    Stack trace:

    Traceback (most recent call last):
      File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 194, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/home/scott/.pyenv/versions/3.8.13/lib/python3.8/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/home/scott/Documents/phd/geom/pde/metric.py", line 27, in <module>
        hist = model.fit(
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_base.py", line 417, in fit
        tmp_logs = self.train_on_batch(
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_core.py", line 617, in train_on_batch
        logs, model = train_step_fn(self, inputs, labels)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model_core.py", line 412, in _static_train_step
        return model.train_step(inputs, labels)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model.py", line 306, in train_step
        grads, (logs, model) = grad_fn(params, model, inputs, labels)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model.py", line 278, in loss_fn
        loss, logs, model = model.test_step(inputs, labels)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/elegy/model/model.py", line 248, in test_step
        batch_loss_and_logs.update(
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/loss_and_logs.py", line 78, in update
        self.metrics.update(**metrics_kwargs)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/metrics.py", line 44, in update
        metric.update(**metric_kwargs)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/mean_absolute_error.py", line 83, in update
        values = _mean_absolute_error(preds, target)
      File "/home/scott/Documents/phd/geom/.venv/lib/python3.8/site-packages/treex/metrics/mean_absolute_error.py", line 20, in _mean_absolute_error
        target = target.astype(preds.dtype)
    AttributeError: 'dict' object has no attribute 'astype'
    

    Expected behavior: It should produce the same result as when the dictionaries are removed and the on argument is not specified.

    Library Info: Python version 3.8.13, elegy version 0.8.6, treex version 0.6.10.

    Additional context: From my digging, the cause seems to be that the Metric.update() method is called instead of the __call__ method.
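The expected behavior amounts to applying `on` to both the predictions and the target before the metric's computation runs; the reporter's diagnosis suggests that this selection lives in `__call__` and is skipped when `update()` is called directly. A sketch with hypothetical helpers (not Treex's code), where `on` may be a single key or a tuple path of keys:

```python
import numpy as np


def select_on(value, on):
    """Hypothetical helper: index into a (possibly nested) dict/list using
    `on`, which may be a single key or a tuple of keys forming a path."""
    if on is None:
        return value
    for key in (on if isinstance(on, tuple) else (on,)):
        value = value[key]
    return value


def mean_absolute_error(preds, target):
    return float(np.mean(np.abs(np.asarray(preds) - np.asarray(target))))


preds = {"y": np.array([1.0, 2.0])}
target = {"y": np.array([0.0, 4.0])}
mae = mean_absolute_error(select_on(preds, "y"), select_on(target, "y"))
```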

    bug 
    opened by ScottAlexanderCameron 0
  • [Bug] Elegy crash on GPU

    Describe the bug

    Running mnist_cnn.py in the examples directory crashes the instance at the end of the first epoch.

    This was previously reported on a Colab GPU instance, but I can reproduce it from the CLI too.

    Running on CPU does not have this problem.

    Running in eager mode on GPU does not have this problem.

    Minimal code to reproduce

    python mnist_cnn.py
    

    Expected behavior: Not stuck.

    Library Info: CentOS Linux release 7.6.1810, elegy 0.8.6

    Additional context (installed packages): absl-py==1.2.0 aiohttp==3.8.1 aiosignal==1.2.0 async-timeout==4.0.2 attrs==22.1.0 certifi==2021.10.8 charset-normalizer==2.1.1 chex==0.1.4 click==8.1.3 cloudpickle==1.6.0 colorama==0.4.5 commonmark==0.9.1 cycler==0.11.0 datasets==2.4.0 dill==0.3.5.1 dm-tree==0.1.7 docker-pycreds==0.4.0 einops==0.4.1 elegy==0.8.6 etils==0.7.1 filelock==3.8.0 flax==0.4.2 fonttools==4.36.0 frozenlist==1.3.1 fsspec==2022.7.1 gitdb==4.0.9 GitPython==3.1.27 h5py==3.6.0 huggingface-hub==0.8.1 idna==3.3 importlib-resources==5.9.0 jax==0.3.16 jaxlib==0.3.15+cuda11.cudnn82 kiwisolver==1.4.4 matplotlib==3.5.3 msgpack==1.0.4 multidict==6.0.2 multiprocess==0.70.13 numpy==1.22.3 opt-einsum==3.3.0 optax==0.1.3 packaging==21.3 pandas==1.4.3 pathtools==0.1.2 Pillow==9.2.0 promise==2.3 protobuf==3.20.1 psutil==5.9.1 pyarrow==9.0.0 Pygments==2.13.0 pyparsing==3.0.9 python-dateutil==2.8.2 pytz==2022.2.1 PyYAML==6.0 requests==2.28.1 responses==0.18.0 rich==11.2.0 scipy==1.8.0 sentry-sdk==1.9.5 setproctitle==1.3.2 shortuuid==1.0.9 six==1.16.0 smmap==5.0.0 tensorboardX==2.5.1 toolz==0.12.0 tqdm==4.64.0 treeo==0.0.10 treex==0.6.10 typing_extensions==4.3.0 urllib3==1.26.11 wandb==0.12.21 xxhash==3.0.0 yarl==1.8.1 zipp==3.8.1

    bug 
    opened by jiyuuchc 2
Releases(0.8.6)
  • 0.8.6(Mar 23, 2022)

    🚀 Features

    • Weights and Biases Callback for Elegy
      • PR: #220

    🐛 Fixes

    • Docs typos
      • PR: #222
    • Donate model's memory buffer to jit/pmap functions.
      • PR: #226
    • Lazy load wandb
      • PR: #228
  • 0.8.5(Feb 23, 2022)

  • 0.8.4(Dec 14, 2021)

  • 0.8.3(Dec 13, 2021)

  • 0.8.2(Dec 13, 2021)

  • 0.8.1(Nov 8, 2021)

    Elegy is now based on Treex 🎉

    Changes

    • Removes the module, nn, metrics, and losses modules from Elegy; instead, Elegy re-exports these modules from Treex.
    • GeneralizedModule and friends are gone; to use Flax Modules, use the elegy.nn.FlaxModule wrapper.
    • The low-level API is massively simplified:
      • States is removed: since Model is a pytree, all parameters are tracked automatically thanks to Treex / Treeo.
      • All static state arguments (training, initializing) are removed; Modules can simply use self.training to pick their training state and self.initializing() to check whether they are initializing.
      • The signature for pred_step, test_step, and train_step now simply consists of inputs and labels, where labels is a dict that can contain additional keys like sample_weight or class_weight as required by the losses and metrics.
    • Adds the DistributedStrategy class, which currently has 3 instances:
      • Eager: Runs the model on a single device in eager mode (no jit)
      • JIT: Runs the model on a single device with jit
      • DataParallel: Runs the model on multiple devices using pmap.
    • Adds methods to change the model's distributed strategy:
      • .distributed(strategy = DataParallel): changes the distributed strategy, DataParallel used by default.
      • .local(): changes the distributed strategy to JIT.
      • .eager(): changes the distributed strategy to Eager.
    • Removes the .eager field in favor of the .eager() method.
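DataParallel runs the model on multiple devices using pmap, which maps over a leading device axis. As a hedged sketch of what that implies for inputs (the helper is hypothetical, not Elegy's code), a global batch has to be split evenly across devices before a call such as `model.distributed()` can use them all:

```python
import numpy as np


def shard_batch(batch, num_devices):
    """Hypothetical helper: reshape a (batch, ...) array into
    (num_devices, batch // num_devices, ...), the leading-device-axis layout
    that jax.pmap maps over."""
    global_batch = batch.shape[0]
    if global_batch % num_devices != 0:
        raise ValueError(
            f"batch size {global_batch} is not divisible by {num_devices} devices"
        )
    per_device = global_batch // num_devices
    return batch.reshape((num_devices, per_device) + batch.shape[1:])


# e.g. a batch of 64 MNIST-sized images split across 8 devices
sharded = shard_batch(np.zeros((64, 28, 28)), num_devices=8)
```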
  • 0.7.4(Jun 1, 2021)

  • 0.7.2(Mar 10, 2021)

  • 0.7.1(Mar 1, 2021)

  • 0.7.0(Feb 22, 2021)

    Features

    • init is now called only once internally, and must be called explicitly by the user under certain circumstances.
    • summary now uses jax.eval_shape under the hood, so it's very fast, since it doesn't allocate memory or perform any computations on the device.

    Merged pull requests:

    • Fix notebook #166 (cgarciae)
    • Single Initialization: Removes the current progressive initialization in favor of a single underlying call to init_step. #165 (cgarciae)
  • 0.6.0 (Feb 14, 2021)

  • 0.5.0 (Feb 8, 2021)

    This version simplifies parts of the low-level API in the spirit of what was introduced in 0.4.0, to provide a more homogeneous and simpler experience.

    Merged pull requests:

    • Improve States: uses __dict__ so States works with vars #159 (cgarciae)
    • Simplify API: Cleans-up some API details around Model and Module #158 (cgarciae)
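Python's built-in vars(obj) simply returns obj.__dict__, so a container that stores its entries there works with vars out of the box. A minimal sketch with a hypothetical States-like class (not Elegy's actual States, which 0.8.0 later removed); the update method is an illustrative assumption.

```python
class States:
    """Hypothetical container that keeps its fields in __dict__."""

    def __init__(self, **kwargs):
        # Writing straight into __dict__ is what makes vars(states) work.
        self.__dict__.update(kwargs)

    def update(self, **kwargs):
        # Return a new container with the given fields replaced, leaving self untouched.
        return States(**{**vars(self), **kwargs})


states = States(net_params={"w": 1.0}, rng=0)
new_states = states.update(rng=1)

print(vars(new_states))  # {'net_params': {'w': 1.0}, 'rng': 1}
print(vars(states)["rng"])  # 0
```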
  • 0.4.1 (Feb 3, 2021)

  • 0.4.0 (Feb 1, 2021)

    Implemented enhancements:

    • [Feature Request] Monitoring learning rates #124

    Merged pull requests:

  • 0.3.0 (Dec 17, 2020)

    Implemented enhancements:

    • elegy.nn.Sequential docs not clear #107
    • [Feature Request] Community example repo. #98

    Fixed bugs:

    • [Bug] Accuracy from Model.evaluate() is inconsistent with manually computed accuracy #109
    • Exceptions in "Getting Started" colab notebook #104

    Closed issues:

    • l2_normalize #102
    • Need some help for contributing new losses. #93
    • Document Sum #62
    • Binary Accuracy Metric #58
    • Automate generation of API Reference folder structure #19
    • Implement Model.summary #3

    Merged pull requests:

  • 0.2.2 (Aug 31, 2020)

  • 0.2.1 (Aug 25, 2020)

  • 0.2.0 (Aug 17, 2020)

  • 0.1.5 (Jul 28, 2020)

    • Mean Absolute Percentage Error Implementation @Ciroye
    • Adds elegy.nn.Linear, elegy.nn.Conv2D, elegy.nn.Flatten, elegy.nn.Sequential @cgarciae
    • Add Elegy hooks @cgarciae
    • Improves Tensorboard support @Davidnet
    • Added coverage metrics to CI @charlielito
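The Mean Absolute Percentage Error added in this release measures the average relative error in percent, 100 · mean(|y_true − y_pred| / |y_true|). A minimal JAX sketch of that formula, not Elegy's implementation: the epsilon floor on the denominator (to avoid division by zero) and its default value are assumptions.

```python
import jax.numpy as jnp


def mean_absolute_percentage_error(y_true, y_pred, epsilon=1e-7):
    """Per-sample MAPE in percent; epsilon guards against zero targets (assumed value)."""
    y_true = jnp.asarray(y_true, dtype=jnp.float32)
    y_pred = jnp.asarray(y_pred, dtype=jnp.float32)
    # Relative error of each element, with the denominator floored at epsilon.
    rel = jnp.abs(y_true - y_pred) / jnp.maximum(jnp.abs(y_true), epsilon)
    # Average over the last axis to get one value per sample.
    return 100.0 * jnp.mean(rel, axis=-1)


y_true = jnp.array([[2.0, 4.0], [10.0, 5.0]])
y_pred = jnp.array([[1.0, 5.0], [12.0, 5.0]])
result = mean_absolute_percentage_error(y_true, y_pred)  # per-sample values 37.5 and 10.0
```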
  • 0.1.4 (Jul 24, 2020)

    • Adds elegy.metrics.BinaryCrossentropy @sebasarango1180
    • Adds elegy.nn.Dropout and elegy.nn.BatchNormalization @cgarciae
    • Improves documentation
    • Fixes a bug that caused an error when using is_training via dependency injection in Model.predict.
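The binary cross-entropy tracked by elegy.metrics.BinaryCrossentropy is −[y·log(p) + (1 − y)·log(1 − p)], averaged over elements. A minimal JAX sketch of that per-sample quantity on probabilities, not Elegy's metric implementation (a metric would additionally average it across batches); the clipping epsilon is an assumed default.

```python
import jax.numpy as jnp


def binary_crossentropy(y_true, y_pred, epsilon=1e-7):
    """Per-sample binary cross-entropy on probabilities; epsilon clipping is assumed."""
    y_true = jnp.asarray(y_true, dtype=jnp.float32)
    # Clip predictions away from 0 and 1 so the logs stay finite.
    y_pred = jnp.clip(jnp.asarray(y_pred, dtype=jnp.float32), epsilon, 1.0 - epsilon)
    per_element = -(y_true * jnp.log(y_pred) + (1.0 - y_true) * jnp.log(1.0 - y_pred))
    # Average over the last axis to get one value per sample.
    return jnp.mean(per_element, axis=-1)


y_true = jnp.array([[1.0, 0.0]])
y_pred = jnp.array([[0.5, 0.5]])
loss = binary_crossentropy(y_true, y_pred)  # ln 2, about 0.6931, for the single sample
```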
  • 0.1.3 (Jul 23, 2020)
