Chex

Chex is a library of utilities for helping to write reliable JAX code.

This includes utils to help:

  • Instrument your code (e.g. assertions)
  • Debug (e.g. transforming pmaps into vmaps within a context manager).
  • Test JAX code across many variants (e.g. jitted vs non-jitted).

Installation

Chex can be installed with pip directly from GitHub, with the following command:

pip install git+https://github.com/deepmind/chex.git

or from PyPI:

pip install chex

Modules Overview

Dataclass (dataclass.py)

Dataclasses are a popular construct introduced in Python 3.7 that makes it easy to specify typed data structures with minimal boilerplate code. They are not, however, compatible with JAX and dm-tree out of the box.

In Chex we provide a JAX-friendly dataclass implementation that reuses Python dataclasses.

The Chex implementation of dataclass registers dataclasses as internal PyTree nodes to ensure compatibility with JAX data structures.

In addition, we provide a class wrapper that exposes dataclasses as collections.abc.Mapping descendants, which allows dm-tree methods to process them (e.g. (un-)flatten) like ordinary Python dictionaries. See the @mappable_dataclass docstring for more details.

Example:

import chex
import jax
import jax.numpy as jnp
import tree  # dm-tree

@chex.dataclass
class Parameters:
  x: chex.ArrayDevice
  y: chex.ArrayDevice

parameters = Parameters(
    x=jnp.ones((2, 2)),
    y=jnp.ones((1, 2)),
)

# Dataclasses can be treated as JAX pytrees
jax.tree_util.tree_map(lambda x: 2.0 * x, parameters)

# and as mappings by dm-tree
tree.flatten(parameters)

NOTE: Unlike standard Python 3.7 dataclasses, Chex dataclasses cannot be constructed using positional arguments. They support construction arguments provided in the same format as the Python dict constructor. Dataclasses can be converted to tuples with the from_tuple and to_tuple methods if necessary.

parameters = Parameters(
    jnp.ones((2, 2)),
    jnp.ones((1, 2)),
)
# ValueError: Mappable dataclass constructor doesn't support positional args.

Assertions (asserts.py)

One limitation of PyType annotations for JAX is that they do not support the specification of DeviceArray ranks, shapes or dtypes. Chex includes a number of functions that allow flexible and concise specification of these properties.

E.g. suppose you want to ensure that all tensors t1, t2, t3 have the same shape, and that tensors t4, t5 have rank 2 and (3 or 4), respectively.

chex.assert_equal_shape([t1, t2, t3])
chex.assert_rank([t4, t5], [2, {3, 4}])

More examples:

from chex import assert_shape, assert_rank, ...

assert_shape(x, (2, 3))                # x has shape (2, 3)
assert_shape([x, y], [(), (2,3)])      # x is scalar and y has shape (2, 3)

assert_rank(x, 0)                      # x is scalar
assert_rank([x, y], [0, 2])            # x is scalar and y is a rank-2 array
assert_rank([x, y], {0, 2})            # x and y are scalar OR rank-2 arrays

assert_type(x, int)                    # x has type `int` (x can be an array)
assert_type([x, y], [int, float])      # x has type `int` and y has type `float`

assert_equal_shape([x, y, z])          # x, y, and z have equal shapes

assert_trees_all_close(tree_x, tree_y) # values and structure of trees match
assert_tree_all_finite(tree_x)         # all tree_x leaves are finite

assert_devices_available(2, 'gpu')     # 2 GPUs available
assert_tpu_available()                 # at least 1 TPU available

assert_numerical_grads(f, (x, y), j)   # f^{(j)}(x, y) matches numerical grads

All chex assertions support the following optional kwargs for manipulating the emitted exception messages:

  • custom_message: A string to include in the emitted exception messages.
  • include_default_message: Whether to include the default Chex message in the emitted exception messages.
  • exception_type: An exception type to use. AssertionError by default.

For example, the following code:

dataset = load_dataset()
params = init_params()
for i in range(num_steps):
  params = update_params(params, dataset.sample())
  chex.assert_tree_all_finite(params,
                              custom_message=f'Failed at iteration {i}.',
                              exception_type=ValueError)

will raise a ValueError that includes a step number when params get polluted with NaNs or Nones.

JAX re-traces a JIT'ted function every time the structure of the passed arguments changes. Often this behavior is inadvertent and leads to a significant performance drop that is hard to debug. The @chex.assert_max_traces decorator asserts that the function is not re-traced more than n times during program execution.

The global trace counter can be cleared by calling chex.clear_trace_counter(). This function can be used to isolate unit tests that rely on @chex.assert_max_traces.

Examples:

  @jax.jit
  @chex.assert_max_traces(n=1)
  def fn_sum_jitted(x, y):
    return x + y

  z = fn_sum_jitted(jnp.zeros(3), jnp.zeros(3))
  t = fn_sum_jitted(jnp.zeros((6, 7)), jnp.zeros((6, 7)))  # AssertionError!

Can be used with jax.pmap() as well:

  def fn_sub(x, y):
    return x - y

  fn_sub_pmapped = jax.pmap(chex.assert_max_traces(fn_sub, n=10))

More about tracing

See documentation of asserts.py for details on all supported assertions.

Test variants (variants.py)

JAX relies extensively on code transformation and compilation, which makes it hard to ensure that code is properly tested. For instance, testing a Python function that uses JAX will not cover the code path that is actually executed when the function is jitted, and that path also differs depending on whether the code is jitted for CPU, GPU, or TPU. This has been a source of obscure and hard-to-catch bugs, where XLA changes led to undesirable behaviours that only manifested under one specific code transformation.

Variants make it easy to ensure that unit tests cover different ‘variations’ of a function, by providing a simple decorator that can be used to repeat any test under all (or a subset) of the relevant code transformations.

E.g. suppose you want to test the output of a function fn with or without jit. You can use chex.variants to run the test with both the jitted and non-jitted version of the function by simply decorating a test method with @chex.variants, and then using self.variant(fn) in place of fn in the body of the test.

def fn(x, y):
  return x + y
...

class ExampleTest(chex.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  def test(self):
    var_fn = self.variant(fn)
    self.assertEqual(fn(1, 2), 3)
    self.assertEqual(var_fn(1, 2), fn(1, 2))

If you define the function in the test method, you may also use self.variant as a decorator in the function definition. For example:

class ExampleTest(chex.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  def test(self):
    @self.variant
    def var_fn(x, y):
      return x + y

    self.assertEqual(var_fn(1, 2), 3)

Example of parameterized test:

from absl.testing import parameterized

# Could also be:
#  `class ExampleParameterizedTest(chex.TestCase, parameterized.TestCase):`
#  `class ExampleParameterizedTest(chex.TestCase):`
class ExampleParameterizedTest(parameterized.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  @parameterized.named_parameters(
      ('case_positive', 1, 2, 3),
      ('case_negative', -1, -2, -3),
  )
  def test(self, arg_1, arg_2, expected):
    @self.variant
    def var_fn(x, y):
      return x + y

    self.assertEqual(var_fn(arg_1, arg_2), expected)

Chex currently supports the following variants:

  • with_jit -- applies jax.jit() transformation to the function.
  • without_jit -- uses the function as is, i.e. identity transformation.
  • with_device -- places all arguments (except those specified in the ignore_argnums argument) into device memory before applying the function.
  • without_device -- places all arguments in RAM before applying the function.
  • with_pmap -- applies jax.pmap() transformation to the function (see notes below).

See documentation in variants.py for more details on the supported variants. More examples can be found in variants_test.py.

Variants notes

  • Test classes that use @chex.variants must inherit from chex.TestCase (or any other base class that unrolls test generators within TestCase, e.g. absl.testing.parameterized.TestCase).

  • [jax.vmap] All variants can be applied to a vmapped function; please see an example in variants_test.py (test_vmapped_fn_named_params and test_pmap_vmapped_fn).

  • [@chex.all_variants] You can get all supported variants by using the decorator @chex.all_variants.

  • [with_pmap variant] jax.pmap(fn) performs a parallel map of fn onto multiple devices. Most tests run in a single-device environment (i.e. with access to a single CPU or GPU), where jax.pmap is functionally equivalent to jax.jit, so the with_pmap variant is skipped by default (although it works fine with a single device). Below we describe a way to properly test fn if it is supposed to be used in multi-device environments (TPUs or multiple CPUs/GPUs). To disable skipping with_pmap variants in the single-device case, add --chex_skip_pmap_variant_if_single_device=false to your test command.

Fakes (fake.py)

Debugging in JAX is made more difficult by code transformations such as jit and pmap, which introduce optimizations that make code hard to inspect and trace. It can also be difficult to disable those transformations during debugging as they can be called at several places in the underlying code. Chex provides tools to globally replace jax.jit with a no-op transformation and jax.pmap with a (non-parallel) jax.vmap, in order to more easily debug code in a single-device context.

For example, you can use Chex to fake pmap and have it replaced with a vmap. This can be achieved by wrapping your code with a context manager:

with chex.fake_pmap():
  @jax.pmap
  def fn(inputs):
    ...

  # Function will be vmapped over inputs
  fn(inputs)

The same functionality can also be invoked with start and stop:

fake_pmap = chex.fake_pmap()
fake_pmap.start()
... your jax code ...
fake_pmap.stop()

In addition, you can fake a real multi-device test environment with a multi-threaded CPU. See section Faking multi-device test environments for more details.

See documentation in fake.py and examples in fake_test.py for more details.

Faking multi-device test environments

In situations where you do not have easy access to multiple devices, you can still test parallel computation using single-device multi-threading.

In particular, one can force XLA to use a single CPU's threads as separate devices, i.e. to fake a real multi-device environment with a multi-threaded one. These two options are theoretically equivalent from the XLA perspective because they expose the same interface and use identical abstractions.

Chex has a flag chex_n_cpu_devices that specifies the number of CPU threads to use as XLA devices.

To set up a multi-threaded XLA environment for absl tests, define a setUpModule function in your test module:

def setUpModule():
  chex.set_n_cpu_devices()

Now you can launch your test with python test.py --chex_n_cpu_devices=N to run it in a multi-device regime. Note that all tests within the module will have access to N devices.

More examples can be found in variants_test.py, fake_test.py and fake_set_n_cpu_devices_test.py.

Citing Chex

To cite this repository:

@software{chex2020github,
  author = {David Budden and Matteo Hessel and Iurii Kemaev and Stephen Spencer
            and Fabio Viola},
  title = {Chex: Testing made fun, in JAX!},
  url = {http://github.com/deepmind/chex},
  version = {0.0.1},
  year = {2020},
}

In this bibtex entry, the version number is intended to be from chex/__init__.py, and the year corresponds to the project's open-source release.

Comments
  • [chex] Allow an ellipsis in the expected shape passed to `assert_shape`.

    This allows things like:

    chex.assert_shape(a, [..., seq_len, features])
    

    This is particularly useful for situations like variable numbers of batch dimensions.

    cla: yes 
    opened by copybara-service[bot] 8
  • CpuDevice no longer in jax

    Hello,

    Seems like the newest version of jax (0.3.7) removed some classes that are used here in chex. Should chex upper bound the jax version? I see this conflicting code is not currently on the main branch -- alternatively, maybe a new release can be made?

    https://github.com/google/jax/pull/10326

    opened by adamgayoso 4
  • `AssertsChexifyTest.test_uninspected_checks` test failure

    I'm seeing the following test failure when running the test suite:

    ============================= test session starts ==============================
    platform linux -- Python 3.10.7, pytest-7.1.3, pluggy-1.0.0
    rootdir: /build/source
    collected 548 items                                                            
    
    chex/chex_test.py .                                                      [  0%]
    chex/_src/asserts_chexify_test.py ......F.....                           [  2%]
    chex/_src/asserts_internal_test.py .s.s.........                         [  4%]
    chex/_src/asserts_test.py ..s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s. [ 13%]
    s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s..................... [ 26%]
    ........................................................................ [ 39%]
    ........................................................................ [ 52%]
    .................................                                        [ 58%]
    chex/_src/dataclass_test.py ...........................................  [ 66%]
    chex/_src/dimensions_test.py .................                           [ 69%]
    chex/_src/fake_set_n_cpu_devices_test.py s                               [ 69%]
    chex/_src/fake_test.py ................................                  [ 75%]
    chex/_src/restrict_backends_test.py ssssssssss                           [ 77%]
    chex/_src/variants_test.py .....................s....s............s....s [ 85%]
    ..........................................................ssssssssssssss [ 98%]
    sssssss                                                                  [100%]
    
    =================================== FAILURES ===================================
    __________________ AssertsChexifyTest.test_uninspected_checks __________________
    
    self = <chex._src.asserts_chexify_test.AssertsChexifyTest testMethod=test_uninspected_checks>
    
        def test_uninspected_checks(self):
        
          @jax.jit
          def _pos_sum(x):
            chex_value_assert_positive(x, custom_message='err_label')
            return x.sum()
        
          invalid_x = -jnp.ones(3)
          chexify_async(_pos_sum)(invalid_x)  # async error
        
    >     with self.assertRaisesRegex(AssertionError, 'err_label'):
    E     AssertionError: AssertionError not raised
    
    chex/_src/asserts_chexify_test.py:179: AssertionError
    ------------------------------ Captured log call -------------------------------
    WARNING  absl:asserts_chexify.py:57 [Chex] Some of chexify assetion statuses were not inspected due to async exec (https://jax.readthedocs.io/en/latest/async_dispatch.html). Consider calling `chex.block_until_chexify_assertions_complete()` at the end of computations that rely on jitted chex assetions.
    =============================== warnings summary ===============================
    chex/_src/asserts_chexify_test.py: 12 warnings
      /build/source/chex/_src/asserts_chexify_test.py:58: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
        return jnp.all(jnp.array([(x > 0).all() for x in jax.tree_leaves(tree)]))
    
    chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_static_assertion__with_jit
    chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_static_assertion__without_jit
      /build/source/chex/_src/asserts_chexify_test.py:86: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
        return sum(x.sum() for x in jax.tree_leaves(tree))
    
    chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_uninspected_checks
      /nix/store/4y9j6xdkgqwkdx5ki508l175smcjgs9l-python3.10-pytest-7.1.3/lib/python3.10/site-packages/_pytest/unraisableexception.py:78: PytestUnraisableExceptionWarning: Exception ignored in atexit callback: <function _check_if_hanging_assertions at 0x7ffddfe66d40>
      
      Traceback (most recent call last):
        File "/build/source/chex/_src/asserts_chexify.py", line 32, in _check_error
          checkify.check_error(err)
        File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/_src/checkify.py", line 476, in check_error
          return assert_p.bind(err, code, payload, msgs=error.msgs)
        File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/core.py", line 328, in bind
          return self.bind_with_trace(find_top_trace(args), args, params)
        File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/core.py", line 331, in bind_with_trace
          out = trace.process_primitive(self, map(trace.full_raise, args), params)
        File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/core.py", line 698, in process_primitive
          return primitive.impl(*tracers, **params)
        File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/_src/checkify.py", line 483, in assert_impl
          raise_error(Error(err, code, msgs, payload))
        File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/_src/checkify.py", line 123, in raise_error
          raise ValueError(err)
      ValueError: [Chex] chexify assertion failed [err_label] [failed at /build/source/chex/_src/asserts_chexify_test.py:173] (check failed at /build/source/chex/_src/asserts_internal.py:229 (_chex_assert_fn))
      
      During handling of the above exception, another exception occurred:
      
      Traceback (most recent call last):
        File "/build/source/chex/_src/asserts_chexify.py", line 62, in _check_if_hanging_assertions
          block_until_chexify_assertions_complete()
        File "/build/source/chex/_src/asserts_chexify.py", line 51, in block_until_chexify_assertions_complete
          wait_fn()
        File "/build/source/chex/_src/asserts_chexify.py", line 180, in _wait_checks
          _check_error(async_check_futures.popleft().result(async_timeout))
        File "/build/source/chex/_src/asserts_chexify.py", line 40, in _check_error
          raise AssertionError(msg)  # pylint:disable=raise-missing-from
      AssertionError: [Chex] chexify assertion failed [err_label] [failed at /build/source/chex/_src/asserts_chexify_test.py:173] 
      
        warnings.warn(pytest.PytestUnraisableExceptionWarning(msg))
    
    chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted
    chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted
    chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted
    chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted_vmapped
    chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted_vmapped
      /build/source/chex/_src/asserts_chexify_test.py:52: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
        if not all((x > 0).all() for x in jax.tree_leaves(tree)):
    
    -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
    =========================== short test summary info ============================
    FAILED chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_uninspected_checks
    ====== 1 failed, 461 passed, 86 skipped, 20 warnings in 84.47s (0:01:24) =======
    error: builder for '/nix/store/f9icjsb9pbz4p8qpsyhp9gq1fvjvwwhz-python3.10-chex-0.1.5.drv' failed with exit code 1;
           last 10 log lines:
           > chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted
           > chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted_vmapped
           > chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted_vmapped
           >   /build/source/chex/_src/asserts_chexify_test.py:52: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
           >     if not all((x > 0).all() for x in jax.tree_leaves(tree)):
           >
           > -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
           > =========================== short test summary info ============================
           > FAILED chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_uninspected_checks
           > ====== 1 failed, 461 passed, 86 skipped, 20 warnings in 84.47s (0:01:24) =======
    

    I'm using

    • jax v0.3.23
    • jaxlib v0.3.22
    • absl-py v1.2.0
    • dm-tree from commit https://github.com/deepmind/tree/commit/b452e5c2743e7489b4ba7f16ecd51c516d7cd8e3
    • numpy 1.23.3
    • toolz 0.12.0
    opened by samuela 3
  • AttributeError: module 'jax' has no attribute '_src'

    trying to import optax and getting an error AttributeError: module 'jax' has no attribute '_src' for jax versions > 0.3.17

    optax version == 0.1.3 chex version == 0.1.3

    In [1]: import optax
    /home/penn/anaconda3/envs/jax/lib/python3.10/site-packages/chex/_src/pytypes.py:37: FutureWarning: jax.tree_structure is deprecated, and will be removed in a future release. Use jax.tree_util.tree_structure instead.
      PyTreeDef = type(jax.tree_structure(None))
    ---------------------------------------------------------------------------
    AttributeError                            Traceback (most recent call last)
    Input In [1], in <cell line: 1>()
    ----> 1 import optax
    
    File ~/anaconda3/envs/jax/lib/python3.10/site-packages/optax/__init__.py:17, in <module>
          1 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
          2 #
          3 # Licensed under the Apache License, Version 2.0 (the "License");
       (...)
         13 # limitations under the License.
         14 # ==============================================================================
         15 """Optax: composable gradient processing and optimization, in JAX."""
    ---> 17 from optax import experimental
         18 from optax._src.alias import adabelief
         19 from optax._src.alias import adafactor
    
    File ~/anaconda3/envs/jax/lib/python3.10/site-packages/optax/experimental/__init__.py:20, in <module>
          1 # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
          2 #
          3 # Licensed under the Apache License, Version 2.0 (the "License");
       (...)
         13 # limitations under the License.
         14 # ==============================================================================
         15 """Experimental features in Optax.
         16 
         17 Features may be removed or modified at any time.
         18 """
    ---> 20 from optax._src.experimental.complex_valued import split_real_and_imaginary
         21 from optax._src.experimental.complex_valued import SplitRealAndImaginaryState
    
    File ~/anaconda3/envs/jax/lib/python3.10/site-packages/optax/_src/experimental/complex_valued.py:32, in <module>
         15 """Complex-valued optimization.
         16 
         17 When using `split_real_and_imaginary` to wrap an optimizer, we split the complex
       (...)
         27 See details at https://github.com/deepmind/optax/issues/196
         28 """
         30 from typing import NamedTuple, Union
    ---> 32 import chex
         33 import jax
         34 import jax.numpy as jnp
    
    File ~/anaconda3/envs/jax/lib/python3.10/site-packages/chex/__init__.py:17, in <module>
          1 # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
          2 #
          3 # Licensed under the Apache License, Version 2.0 (the "License");
       (...)
         13 # limitations under the License.
         14 # ==============================================================================
         15 """Chex: Testing made fun, in JAX!"""
    ---> 17 from chex._src.asserts import assert_axis_dimension
         18 from chex._src.asserts import assert_axis_dimension_comparator
         19 from chex._src.asserts import assert_axis_dimension_gt
    
    File ~/anaconda3/envs/jax/lib/python3.10/site-packages/chex/_src/asserts.py:26, in <module>
         23 import unittest
         24 from unittest import mock
    ---> 26 from chex._src import asserts_internal as _ai
         27 from chex._src import pytypes
         28 import jax
    
    File ~/anaconda3/envs/jax/lib/python3.10/site-packages/chex/_src/asserts_internal.py:32, in <module>
         29 from typing import Any, Sequence, Union, Callable, Optional, Set, Tuple, Type
         31 from absl import logging
    ---> 32 from chex._src import pytypes
         33 import jax
         34 import jax.numpy as jnp
    
    File ~/anaconda3/envs/jax/lib/python3.10/site-packages/chex/_src/pytypes.py:44, in <module>
         40 Device = jax.lib.xla_extension.Device
         42 ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]
    ---> 44 ArrayDType = jax._src.numpy.lax_numpy._ScalarMeta
    
    AttributeError: module 'jax' has no attribute '_src'
    
    opened by jenkspt 3
  • Use dataclass_transform to help type checkers with @chex.dataclass

    closes #155

    I basically copied this example: https://peps.python.org/pep-0681/#id1

    Tested with pyright/pylance.

    I had to specify a return type for chex.dataclass because otherwise pyright/pylance is ignoring it completely if arguments are passed to it (like chex.dataclass(eq=False)), but if it's used bare (just chex.dataclass without any parentheses) then it also works without the return type annotation.

    There is a (expected) test failure from pytype:

    FAILED: /home/tmk/dev/python/chex/.pytype/pyi/chex/_src/dataclass.pyi 
    /tmp/chex-env/bin/python3 -m pytype.single --imports_info /home/tmk/dev/python/chex/.pytype/imports/chex._src.dataclass.imports --module-name chex._src.dataclass --platform linux -V 3.9 -o /home/tmk/dev/python/chex/.pytype/pyi/chex/_src/dataclass.pyi --analyze-annotated --nofail --quick /home/tmk/dev/python/chex/chex/_src/dataclass.py
    File "/home/tmk/dev/python/chex/chex/_src/dataclass.py", line 90, in <module>: typing_extensions.dataclass_transform not supported yet [not-supported-yet]
    

    I'm not sure how to deal with that.

    There is also not really a way to write tests for this...

    cc @hbq1

    opened by thomkeh 3
  • Prevent chex.fake_pmap|jit function signature inspection from following through wrappers, otherwise if a wrapper changes the signature in some way, the fakes choke on those.

    cla: no 
    opened by copybara-service[bot] 3
  • Using variants with pytest

    Hi,

    First of all, thank you for this very useful library!

    I have a project in JAX in which I already implemented my tests using pytest. However, the possibilities that chex.variants offers are too nice to ignore. At the same time, I would prefer not to rewrite all my tests.

    Is there a way to reconcile pytest and chex?

    Thank you again for all the work! Best,

    opened by pablo2909 2
  • [chex] Allow `set`s of alternatives in expected shape for `assert_shape`.

    This extends the behavior allowed by assert_rank to assert_shape, enabling things like:

    chex.assert_shape(mask, (batch_size, {num_heads, 1}, q_len, kv_len))
    

    In this example, axis 1 can either be num_heads or 1, which is helpful, for example, in situations where you want to allow only particular dimensions to be broadcastable.

    cla: yes 
    opened by copybara-service[bot] 2
  • [REQ] Conda recipe

    Hi, I'm the lead developer of NetKet, an established machine learning / quantum physics package.

    We have recently finished rewriting our core to be based on Jax (and flax), and recently released a beta version. Since many physicists seem to use anaconda, we would also like to update our conda recipe. However, since we depend on optax (and therefore on Chex), we would need Chex to have a Conda recipe.

    Is that something you'd consider? I am willing to volunteer some work to help you.

    I tried creating a recipe starting from your PyPI source distribution, but that is problematic because you don't bundle your requirements.txt file, which is required to run setup.py. I could create a recipe from the tag tarballs on GitHub, but that sometimes prevents the conda packages from auto-updating the recipe for later releases.

    opened by PhilipVinc 2
  • Chex dataclass defaulting mappable_dataclass=True

    To start with, thanks for open sourcing your work on Chex, it's a great tooling library for building robust Jax applications!

    As I was upgrading to the latest release 0.0.3, I noticed quite a few of my tests breaking. It happens that the default option mappable_dataclass=True in chex.dataclass is breaking the usual interface of dataclasses (which is clearly expected reading the code documentation!)

    I guess from the perspective of DeepMind usage, it makes sense to default this option. But from an external user's point of view, it is rather surprising to have a dataclass decorator that does not behave like a dataclass. I think it would be great to make it clear in the library readme that this option needs to be turned off to get the full dataclass behaviour (or to turn it off by default).

    opened by balancap 2
  • Add ability to check shapes with wildcards

    I often find myself writing the following sort of thing:

    chex.assert_rank(x, 2)
    assert x.shape[1] == num_actions, "some custom message ..."
    

    It would be nice to be able to simply check the shape with a wildcard, i.e.

    chex.assert_shape(x, (None, num_actions))
    

    What do you think?

    cla: yes 
    opened by KristianHolsheimer 2
  • [chex] Add `assert_trees_all_equal_shapes_and_dtypes`

    This is purely a convenience function, asserting both assert_trees_all_equal_shapes and assert_trees_all_equal_dtypes.

    opened by copybara-service[bot] 0
  • Improve support for custom `__init__` methods in dataclasses.

    Chex dataclasses assume the dataclass has a default constructor, which is necessary for flatten/unflatten. This change allows custom initializers by keeping an internal reference to a default initializer for use with flatten/unflatten.
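
    The idea can be sketched with the standard library alone. This is only an illustration under stated assumptions, not the code from this change: the decorator `flattenable_dataclass`, the attribute `_default_init` and the `flatten`/`unflatten` helpers are hypothetical names. The generated field-wise initializer is saved before the user's `__init__` is restored, so unflattening can rebuild an instance from field values without going through the custom constructor.

```python
import dataclasses


def flattenable_dataclass(cls):
    """Make `cls` a dataclass while keeping any custom __init__ it defines."""
    custom_init = cls.__dict__.get("__init__")
    if custom_init is not None:
        # Remove the custom initializer so dataclass() generates the default one.
        delattr(cls, "__init__")
    cls = dataclasses.dataclass(cls)
    cls._default_init = cls.__init__  # field-wise initializer, kept for unflatten
    if custom_init is not None:
        cls.__init__ = custom_init  # users still construct via their own __init__
    return cls


def flatten(obj):
    """Return (field values, field names) of a dataclass instance."""
    names = [f.name for f in dataclasses.fields(obj)]
    return [getattr(obj, n) for n in names], names


def unflatten(cls, names, values):
    """Rebuild an instance from field values, bypassing the custom __init__."""
    obj = cls.__new__(cls)
    cls._default_init(obj, **dict(zip(names, values)))
    return obj


@flattenable_dataclass
class Interval:
    low: float
    high: float

    def __init__(self, center, width):
        self.low = center - width / 2
        self.high = center + width / 2


iv = Interval(center=1.0, width=2.0)
values, names = flatten(iv)
rebuilt = unflatten(Interval, names, [2 * v for v in values])
```

    Here `rebuilt` is constructed directly from its `low` and `high` fields, which the custom `(center, width)` signature could not accept.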

    opened by copybara-service[bot] 0
  • post_init error in inherited dataclass

    When one dataclass inherits from another, Chex's dataclass does not allow a super() call to be made in __post_init__. This is something you can do with Python's standard dataclasses module.

    A minimal working example:

    from chex import dataclass
    
    @dataclass
    class ChexBase:
        a: int
    
        def __post_init__(self):
            self.b = self.a + 1
    
    @dataclass
    class ChexSub(ChexBase):
        a: int
    
        def __post_init__(self):
            super().__post_init__()
            self.c = self.a + 2
    
    temp = ChexSub(a=1)
    temp.b
    

    With dataclass imported from the standard dataclasses module instead, the same code runs without error and temp.b returns 2, as expected.

    Environment

    • Chex version 0.1.5
    • Ubuntu 20.04
    • Python 3.9
    opened by thomaspinder 1
Releases(v0.1.5)
  • v0.1.5(Sep 13, 2022)

    What's Changed

    • Add support for value assertions in jitted functions. by @copybara-service in https://github.com/deepmind/chex/pull/178
    • [JAX] Avoid private implementation detail _ScalarMeta. by @copybara-service in https://github.com/deepmind/chex/pull/180
    • [JAX] Avoid implicit references to jax._src. by @copybara-service in https://github.com/deepmind/chex/pull/181
    • Release v0.1.5 by @copybara-service in https://github.com/deepmind/chex/pull/184

    Full Changelog: https://github.com/deepmind/chex/compare/v0.1.4...v0.1.5

  • v0.1.4(Aug 4, 2022)

    What's Changed

    • Add an InitVar field in the dataclass tests. by @copybara-service in https://github.com/deepmind/chex/pull/161
    • Download latest .pylintrc version in tests. by @copybara-service in https://github.com/deepmind/chex/pull/167
    • Fix assert_axis_dimension_comparator usages. by @copybara-service in https://github.com/deepmind/chex/pull/168
    • Update "jax.tree_util" functions by @copybara-service in https://github.com/deepmind/chex/pull/171
    • Use jax.tree_util.tree_map in place of deprecated tree_multimap. by @copybara-service in https://github.com/deepmind/chex/pull/175
    • Silence some pytype errors. by @copybara-service in https://github.com/deepmind/chex/pull/174
    • Add chex.Dimensions utility for readable shape asserts. by @copybara-service in https://github.com/deepmind/chex/pull/169

    Full Changelog: https://github.com/deepmind/chex/compare/v0.1.3...v0.1.4

  • v0.1.3(Apr 19, 2022)

    What's Changed

    • Slight helping clarification to clear_trace_counter. by @lucasb-eyer in https://github.com/deepmind/chex/pull/148
    • Add new JAX-specific pytypes to chex pytypes. by @copybara-service in https://github.com/deepmind/chex/pull/153
    • Remove chex.{C,G,T}puDevice in favour of chex.Device. by @copybara-service in https://github.com/deepmind/chex/pull/154

    New Contributors

    • @lucasb-eyer made their first contribution in https://github.com/deepmind/chex/pull/148

    Full Changelog: https://github.com/deepmind/chex/compare/v0.1.2...v0.1.3

  • v0.1.2(Mar 31, 2022)

    What's Changed

    • Support JAX parallel operations in chex.fake_pmap contexts by @copybara-service in https://github.com/deepmind/chex/pull/142
    • Remove references to jax.numpy.lax_numpy. by @copybara-service in https://github.com/deepmind/chex/pull/150

    Full Changelog: https://github.com/deepmind/chex/compare/v0.1.1...v0.1.2

  • v0.1.1(Feb 25, 2022)

    What's Changed

    • Move dataclass registration to init so that it's invoked after deserialization. by @copybara-service in https://github.com/deepmind/chex/pull/111
    • Add pytype for jax array's dtypes. by @copybara-service in https://github.com/deepmind/chex/pull/112
    • Fix dataclass registration on deserialization. by @copybara-service in https://github.com/deepmind/chex/pull/114
    • Fix restrict_backends after jax.xla.backend_compile was moved by @copybara-service in https://github.com/deepmind/chex/pull/116
    • Refactor asserts.py and warn users not to rely on asserts_internal's functionality. by @copybara-service in https://github.com/deepmind/chex/pull/117
    • Set up ReadTheDoc pages and add a few examples. by @copybara-service in https://github.com/deepmind/chex/pull/118
    • Include Sphinx builds into CI tests. by @copybara-service in https://github.com/deepmind/chex/pull/119
    • Adds internal functionality by @copybara-service in https://github.com/deepmind/chex/pull/122
    • Update Chex citation. by @copybara-service in https://github.com/deepmind/chex/pull/125
    • Refactor assertions in preparation for including them into the RTD docs. by @copybara-service in https://github.com/deepmind/chex/pull/126
    • Add asserts, variants, and pytypes modules to the RTD docs. by @copybara-service in https://github.com/deepmind/chex/pull/127
    • Fix references to collections.abc.Mappable -> collections.abc.Mapping in docs and comments. collections.abc.Mappable does not exist. by @copybara-service in https://github.com/deepmind/chex/pull/129
    • Document the rationale behind the mappability of chex.dataclasses. by @copybara-service in https://github.com/deepmind/chex/pull/130
    • Add 3 new tree assertions: by @copybara-service in https://github.com/deepmind/chex/pull/131
    • Add assert_tree_is_sharded for asserting that a tree is sharded across the specified devices. by @copybara-service in https://github.com/deepmind/chex/pull/132
    • Add PyTreeDef to pytypes. by @copybara-service in https://github.com/deepmind/chex/pull/134
    • Disallow ShardedDeviceArrays in assert_tree_is_on_host and assert_tree_is_on_device. by @copybara-service in https://github.com/deepmind/chex/pull/133
    • Bump ipython from 7.16.1 to 7.16.3 in /requirements by @dependabot in https://github.com/deepmind/chex/pull/135
    • Remove the old venv directory before testing the package. by @copybara-service in https://github.com/deepmind/chex/pull/138
    • Refactor asserts.py in preparation for experimental device assertions. by @copybara-service in https://github.com/deepmind/chex/pull/137
    • Fix minor typo in docs. by @copybara-service in https://github.com/deepmind/chex/pull/139
    • Improve exception message for assert_tree_shape_prefix. by @copybara-service in https://github.com/deepmind/chex/pull/143
    • Release v0.1.1 by @copybara-service in https://github.com/deepmind/chex/pull/146

    Full Changelog: https://github.com/deepmind/chex/compare/v0.1.0...v0.1.1

  • v0.1.0(Nov 18, 2021)

  • v0.0.9(Nov 16, 2021)

    This is the last version compatible with Python 3.6. See https://github.com/deepmind/optax/issues/222 for more details.

    Changes since 0.0.8:

    • Use rtol=1e-6 in asserts.assert_tree_close;
    • Added asserts.assert_trees_all_equal;
    • Removed restricted_inheritance option from Chex dataclasses;
    • Added dims= option to assert_equal_shape, to check a subset of dims;
    • Added test.sh for launching CI tests on a local machine;
    • Added support for default exception messages and types to assertions;
    • Added support for jnp.bfloat16 to asserts.assert_trees_all_close();
    • Added support for static_argnames to variants.with_jit;
    • Added a restrict_backends module for constraining the set of backends that a region of code can use;
    • Added asserts.assert_trees_all_equal_dtypes assertion;
    • Exposed asserts.assert_tree_shape_suffix to the public API;
    • Added asserts.assert_tree_shape_suffix to check whether arrays share the same suffix.
  • v0.0.8(Jul 2, 2021)

    Changes:

    • Add support for static_broadcasted_argnums to fake_pmap;
    • Allows sets of alternatives and ellipsis in assert_shape;
    • Format @variant test names to use only underscores and lowercase letters;
    • Fix incorrect type annotation in asserts.py;
    • Fix dataclass (un-)flatten functions;
    • Add more tests for dataclasses;
    • Raise ValueError when no variants are selected;
    • Exclude Chex's internal frames from AssertionError tracebacks;
    • Add '[Chex] ' prefix to AssertionError messages;
    • Include path to leaves that failed the equality check in assert_tree_all_close;
    • Clean up asserts.py;
    • Assertions that only make sense for more than one tree now require it (this can break existing code).
  • v0.0.7(May 4, 2021)

    Changelog

    Full Changelog

    Closed issues:

    • [REQ] Conda recipe #37

    Merged pull requests:

    * This Changelog was automatically generated by github_changelog_generator

  • v0.0.6(Mar 25, 2021)

  • v0.0.5(Mar 22, 2021)

    Changelog

    Note: this is a first GitHub release of Chex. It includes all changes since the repo was created.

    Full Changelog

    Closed issues:

    • Chex dataclass throws an exception in Python 3.9 #10
    • 'jax.interpreters.xla' has no attribute '_DeviceArray' for jax <= 0.2.5 #9
    • Chex dataclass defaulting mappable_dataclass=True #8
    • DeprecationWarning for importing toolz #4
    • Fake contexts by calling .start() not working #3

    Merged pull requests:

    * This Changelog was automatically generated by github_changelog_generator

Owner
DeepMind