JAX-based neural network library

Haiku: Sonnet for JAX


What is Haiku?

Haiku is a tool
For building neural networks
Think: "Sonnet for JAX"

Haiku is a simple neural network library for JAX developed by some of the authors of Sonnet, a neural network library for TensorFlow.

Documentation on Haiku can be found at https://dm-haiku.readthedocs.io/.

Disambiguation: if you are looking for Haiku the operating system then please see https://haiku-os.org/.

Overview

JAX is a numerical computing library that combines NumPy, automatic differentiation, and first-class GPU/TPU support.

Haiku is a simple neural network library for JAX that enables users to use familiar object-oriented programming models while allowing full access to JAX's pure function transformations.

Haiku provides two core tools: a module abstraction, hk.Module, and a simple function transformation, hk.transform.

hk.Modules are Python objects that hold references to their own parameters, other modules, and methods that apply functions on user inputs.

hk.transform turns functions that use these object-oriented, functionally "impure" modules into pure functions that can be used with jax.jit, jax.grad, jax.pmap, etc.

Why Haiku?

There are a number of neural network libraries for JAX. Why should you choose Haiku?

Haiku has been tested by researchers at DeepMind at scale.

  • DeepMind has reproduced a number of experiments in Haiku and JAX with relative ease. These include large-scale results in image and language processing, generative models, and reinforcement learning.

Haiku is a library, not a framework.

  • Haiku is designed to make specific things simpler: managing model parameters and other model state.
  • Haiku can be expected to compose with other libraries and work well with the rest of JAX.
  • Otherwise, Haiku is designed to get out of your way: it does not define custom optimizers, checkpointing formats, or replication APIs.

Haiku does not reinvent the wheel.

  • Haiku builds on the programming model and APIs of Sonnet, a neural network library with near universal adoption at DeepMind. It preserves Sonnet's Module-based programming model for state management while retaining access to JAX's function transformations.
  • Haiku APIs and abstractions are as close as reasonable to Sonnet. Many users have found Sonnet to be a productive programming model in TensorFlow; Haiku enables the same experience in JAX.

Transitioning to Haiku is easy.

  • By design, transitioning from TensorFlow and Sonnet to JAX and Haiku is easy.
  • Outside of new features (e.g. hk.transform), Haiku aims to match the API of Sonnet 2. Modules, methods, argument names, defaults, and initialization schemes should match.

Haiku makes other aspects of JAX simpler.

  • Haiku offers a simple model for working with random numbers. Within a transformed function, hk.next_rng_key() returns a unique RNG key.
  • These unique keys are deterministically derived from an initial random key passed into the top-level transformed function, and are thus safe to use with JAX program transformations.

Quickstart

Let's take a look at an example neural network and loss function.

import haiku as hk
import jax
import jax.numpy as jnp

def softmax_cross_entropy(logits, labels):
  one_hot = jax.nn.one_hot(labels, logits.shape[-1])
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)

def loss_fn(images, labels):
  mlp = hk.Sequential([
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(10),
  ])
  logits = mlp(images)
  return jnp.mean(softmax_cross_entropy(logits, labels))

# There are two transforms in Haiku, hk.transform and hk.transform_with_state.
# If our network updated state during the forward pass (e.g. like the moving
# averages in hk.BatchNorm) we would need hk.transform_with_state, but for our
# simple MLP we can just use hk.transform.
loss_fn_t = hk.transform(loss_fn)

# MLP is deterministic once we have our parameters, as such we will not need to
# pass an RNG key to apply. without_apply_rng is a convenience wrapper that will
# make the rng argument to `loss_fn_t.apply` default to `None`.
loss_fn_t = hk.without_apply_rng(loss_fn_t)

hk.transform allows us to turn this function into a pair of pure functions: init and apply. All JAX transformations (e.g. jax.grad) require you to pass in a pure function for correct behaviour. Haiku makes it easy to write them.

The init function returned by hk.transform allows you to collect the initial value of any parameters in the network. Haiku does this by running your function, keeping track of any parameters requested through hk.get_parameter and returning them to you:

# Initial parameter values are typically random. In JAX you need a key in order
# to generate random numbers and so Haiku requires you to pass one in.
rng = jax.random.PRNGKey(42)

# `init` runs your function, as such we need an example input. Typically you can
# pass "dummy" inputs (e.g. ones of the same shape and dtype) since initialization
# is not usually data dependent.
images, labels = next(input_dataset)

# The result of `init` is a nested data structure of all the parameters in your
# network. You can pass this into `apply`.
params = loss_fn_t.init(rng, images, labels)

The params object is designed for you to inspect and manipulate. It is a mapping from module name to that module's parameters, where each module's parameters are in turn a mapping from parameter name to parameter value. For example:

{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
            'w': ndarray(..., shape=(28, 300), dtype=float32)},
 'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
              'w': ndarray(..., shape=(300, 100), dtype=float32)},
 'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
              'w': ndarray(..., shape=(100, 10), dtype=float32)}}

The apply function allows you to inject parameter values into your function. Whenever hk.get_parameter is called the value returned will come from the params you provide as input to apply:

loss = loss_fn_t.apply(params, images, labels)

Since apply is a pure function we can pass it to jax.grad (or any of JAX's other transforms):

grads = jax.grad(loss_fn_t.apply)(params, images, labels)

Finally, we put this all together into a simple training loop:

def sgd(param, update):
  return param - 0.01 * update

for images, labels in input_dataset:
  grads = jax.grad(loss_fn_t.apply)(params, images, labels)
  params = jax.tree_util.tree_map(sgd, params, grads)

Here we used jax.tree_util.tree_map (the replacement for the deprecated jax.tree_multimap) to apply the sgd function across all matching entries in params and grads. The result has the same structure as the previous params and can again be used with apply.

For more, see our examples directory. The MNIST example is a good place to start.

Installation

Haiku is written in pure Python, but depends on C++ code via JAX.

Because JAX installation is different depending on your CUDA version, Haiku does not list JAX as a dependency in requirements.txt.

First, follow the JAX installation instructions to install JAX with the relevant accelerator support.

Then, install Haiku using pip:

$ pip install git+https://github.com/deepmind/dm-haiku

Our examples rely on additional libraries (e.g. bsuite). You can install the full set of additional requirements using pip:

$ pip install -r examples/requirements.txt

User manual

Writing your own modules

In Haiku, all modules are subclasses of hk.Module. You can implement any method you like (nothing is special-cased), but typically modules implement __init__ and __call__.

Let's work through implementing a linear layer:

import haiku as hk
import jax.numpy as jnp
import numpy as np

class MyLinear(hk.Module):

  def __init__(self, output_size, name=None):
    super().__init__(name=name)
    self.output_size = output_size

  def __call__(self, x):
    j, k = x.shape[-1], self.output_size
    w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j))
    w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
    b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.zeros)
    return jnp.dot(x, w) + b

All modules have a name. When no name argument is passed to the module, its name is inferred from the name of the Python class (for example MyLinear becomes my_linear). Modules can have named parameters that are accessed using hk.get_parameter(param_name, ...). We use this API (rather than just using object properties) so that we can convert your code into a pure function using hk.transform.

When using modules you need to define functions and transform them into a pair of pure functions using hk.transform. See our quickstart for more details about the functions returned from transform:

def forward_fn(x):
  model = MyLinear(10)
  return model(x)

# Turn `forward_fn` into an object with `init` and `apply` methods. By default,
# the `apply` will require an rng (which can be None), to be used with
# `hk.next_rng_key`.
forward = hk.transform(forward_fn)

x = jnp.ones([1, 1])

# When we run `forward.init`, Haiku will run `forward_fn(x)` and collect initial
# parameter values. Haiku requires you to pass an RNG key to `init`, since
# parameters are typically initialized randomly:
key = hk.PRNGSequence(42)
params = forward.init(next(key), x)

# When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
# values from the `params` that are passed as the first argument.  Note that
# models transformed using `hk.transform(f)` must be called with an additional
# `rng` argument: `forward.apply(params, rng, x)`. Use
# `hk.without_apply_rng(hk.transform(f))` if this is undesirable.
y = forward.apply(params, None, x)

Working with stochastic models

Some models may require random sampling as part of the computation. For example, in variational autoencoders with the reparametrization trick, a random sample from the standard normal distribution is needed. For dropout we need a random mask to drop units from the input. The main hurdle in making this work with JAX is in management of PRNG keys.

In Haiku we provide a simple API for maintaining a PRNG key sequence associated with modules: hk.next_rng_key() (or hk.next_rng_keys() for multiple keys):

class MyDropout(hk.Module):

  def __init__(self, rate=0.5, name=None):
    super().__init__(name=name)
    self.rate = rate

  def __call__(self, x):
    key = hk.next_rng_key()
    p = jax.random.bernoulli(key, 1.0 - self.rate, shape=x.shape)
    return x * p / (1.0 - self.rate)

forward = hk.transform(lambda x: MyDropout()(x))

key1, key2 = jax.random.split(jax.random.PRNGKey(42), 2)
params = forward.init(key1, x)
prediction = forward.apply(params, key2, x)

For a more complete look at working with stochastic models, please see our VAE example.

Note: hk.next_rng_key() is not functionally pure, which means you should avoid using it alongside JAX transformations inside hk.transform. For more information and possible workarounds, please consult the docs on Haiku transforms and the available wrappers for JAX transforms inside Haiku networks.

Working with non-trainable state

Some models may want to maintain some internal, mutable state. For example, in batch normalization a moving average of values encountered during training is maintained.

In Haiku we provide a simple API for maintaining mutable state that is associated with modules: hk.set_state and hk.get_state. When using these functions you need to transform your function using hk.transform_with_state since the signature of the returned pair of functions is different:

def forward(x, is_training):
  net = hk.nets.ResNet50(1000)
  return net(x, is_training)

forward = hk.transform_with_state(forward)

# The `init` function now returns parameters **and** state. State contains
# anything that was created using `hk.set_state`. The structure is the same as
# params (e.g. it is a per-module mapping of named values).
params, state = forward.init(rng, x, is_training=True)

# The apply function now takes both params **and** state. Additionally it will
# return updated values for state. In the resnet example this will be the
# updated values for moving averages used in the batch norm layers.
logits, state = forward.apply(params, state, rng, x, is_training=True)

If you forget to use hk.transform_with_state, don't worry: we will print a clear error pointing you to hk.transform_with_state rather than silently dropping your state.

Distributed training with jax.pmap

The pure functions returned from hk.transform (or hk.transform_with_state) are fully compatible with jax.pmap. For more details on SPMD programming with jax.pmap, see the jax.pmap section of the JAX documentation.

One common use of jax.pmap with Haiku is for data-parallel training on many accelerators, potentially across multiple hosts. With Haiku, that might look like this:

def loss_fn(inputs, labels):
  logits = hk.nets.MLP([8, 4, 2])(inputs)
  return jnp.mean(softmax_cross_entropy(logits, labels))

loss_fn_t = hk.transform(loss_fn)
loss_fn_t = hk.without_apply_rng(loss_fn_t)

# Initialize the model on a single device.
rng = jax.random.PRNGKey(428)
sample_image, sample_label = next(input_dataset)
params = loss_fn_t.init(rng, sample_image, sample_label)

# Replicate params onto all devices.
num_devices = jax.local_device_count()
params = jax.tree_util.tree_map(lambda x: np.stack([x] * num_devices), params)

def make_superbatch():
  """Constructs a superbatch, i.e. one batch of data per device."""
  # Get N batches, then split into list-of-images and list-of-labels.
  superbatch = [next(input_dataset) for _ in range(num_devices)]
  superbatch_images, superbatch_labels = zip(*superbatch)
  # Stack the superbatches to be one array with a leading dimension, rather than
  # a python list. This is what `jax.pmap` expects as input.
  superbatch_images = np.stack(superbatch_images)
  superbatch_labels = np.stack(superbatch_labels)
  return superbatch_images, superbatch_labels

def update(params, inputs, labels, axis_name='i'):
  """Updates params based on performance on inputs and labels."""
  grads = jax.grad(loss_fn_t.apply)(params, inputs, labels)
  # Take the mean of the gradients across all data-parallel replicas.
  grads = jax.lax.pmean(grads, axis_name)
  # Update parameters using SGD or Adam or ...
  new_params = my_update_rule(params, grads)
  return new_params

# Run several training updates. Wrap `update` with jax.pmap once, outside the
# loop, and reuse the result.
p_update = jax.pmap(update, axis_name='i')
for _ in range(10):
  superbatch_images, superbatch_labels = make_superbatch()
  params = p_update(params, superbatch_images, superbatch_labels)

For a more complete look at distributed Haiku training, take a look at our ResNet-50 on ImageNet example.

Citing Haiku

To cite this repository:

@software{haiku2020github,
  author = {Tom Hennigan and Trevor Cai and Tamara Norman and Igor Babuschkin},
  title = {{H}aiku: {S}onnet for {JAX}},
  url = {http://github.com/deepmind/dm-haiku},
  version = {0.0.3},
  year = {2020},
}

In this bibtex entry, the version number is intended to be from haiku/__init__.py, and the year corresponds to the project's open-source release.

Comments
  • Is there a good way to save/load & compress/decompress model weights?

    Hey- This is Chris. I'm using this open-source for my project.

    https://github.com/chris-chris/haiku-scalable-example

    Since I'm new to JAX and haiku, I have some questions.

    Is there a good way to save/load & compress/decompress & serialize model weights?

    • save/load model (network only or weight only)
    • compress/decompress weights
    • serialize

    I think serialization is an important issue on scalability. Can you give me some keywords or hints about this issue?

    Thanks!

    opened by chris-chris 10
  • Is there a way to share parameters between methods?

    You can write and transform multiple methods on the same module, but it doesn't seem possible to share parameters between them without manually merging the two parameter FlatMappings. It's particularly cumbersome if the shared parameters are used several submodules deep. Is there any more convenient way to accomplish something like this?

    opened by davisyoshida 7
  • Adds Identity initializer

    This PR adds the initializers that Sonnet has but Haiku is missing. I haven't written tests for them yet since I don't know if there is interest in this PR; as soon as the Haiku team gives me the green light, I will add the tests.

    cla: yes 
    opened by joaogui1 7
  • He initialization

    The default initialization for linear and convolutional modules seems to be Glorot initialization, but for the commonly used ReLU activation function He initialization is superior, and it only requires a quick change to the stddev definition. Should we implement better defaults? I know that there are many initialization schemes; I only suggest this one because it wouldn't be computationally expensive and would be only a minor code change.

    enhancement 
    opened by joaogui1 7
  • Feeding in dictionary of data?

    Hey all!

    One thing I really enjoyed about TensorFlow was the feed_dict option, where I could access the data by key to easily process it in chunks, e.g.:

    {
        "data_to_be_embedded": ... #Some (batch_size, N , M) matrix
        "timeseries data": ... # (batch_size, timeseries_window)
    }
    

    I suppose one option would be to define multiple models and wrap them all into a single function which applies them key-value by key-value. Is there a more "idiomatic" way of doing this in Haiku?

    opened by IanQS 6
  • Correct way to transform and init a `hk.Module` with non-default parameter?

    Hey all!

    I'm trying to run a linear regression example and I've got the following

    import jax.numpy as jnp
    from sklearn.datasets import load_boston
    import haiku as hk
    import optax
    import jax
    
    
    X, y = load_boston(return_X_y=True)
    train_X = jnp.asarray(X.tolist())
    train_y = jnp.asarray(y.tolist())
        
    class Model(hk.Module):
        def __init__(self, input_dims):
            super().__init__()
            self.input_dims = input_dims
        
        def __call__(self, X: jnp.ndarray) -> jnp.ndarray:
            l1 = hk.Linear(self.input_dims)
            return l1(X)
        
    model = hk.transform(lambda x: Model()(x))  # <-- where I would specify the model shape if at all? 
    

    So I'm running into an issue where I'm not able to specify the model shape. If I do not specify it as in the above, I get the error of

    __init__() missing 1 required positional argument: 'input_dims'

    but if I do specify the shape via

    model = hk.transform(lambda x: Model(train_X.shape[1])(x))
    

    I get Argument '<function without_state.<locals>.init_fn at 0x7f1e5c616430>' of type <class 'function'> is not a valid JAX type.


    What is the recommended way of addressing this? I'm reading through hk.transform but I'm not sure. Looking at the code examples, there are __init__ functions without default args so I know it's possible.

    opened by IanQS 6
  • Efficient Ways for Saving and Loading weights

    I'm sorry if it's not the right place as I could not find the discussions or forum page.

    I was wondering what are some of the most efficient ways to save and load models (also verify it's properly loaded into GPU)?

    In the docs, it's given as save_the_model using TensorFlow 2. I also understand that the weights of a Haiku network are stored in a dictionary, for example:

    {'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
                'w': ndarray(..., shape=(28, 300), dtype=float32)},
     'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
                  'w': ndarray(..., shape=(1000, 100), dtype=float32)}}
    

    Does haiku have some inbuilt function to save and load models? It becomes crucial in transfer learning tasks. Thanks in advance,

    opened by VIGNESHinZONE 6
  • "NCHW" data_format in Conv not working with latest CUDA

    I'm not able to use the NCHW data format in conv layers:

    import os
    import numpy as np
    import jax
    import haiku as hk
    
    # os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/opt/cuda"
    
    def net(x):
        model = hk.Sequential([hk.Conv2D(2, 5, padding="VALID", data_format="NCHW")])
        return model(x)
    
    key = jax.random.PRNGKey(42)
    net_transformed = hk.without_apply_rng(hk.transform(net))
    params = net_transformed.init(key, np.zeros((1, 1, 28, 28)))
    

    The snippet above works fine on the CPU, but on the GPU it gives the TensorFlow-style spew of errors below. The problem goes away if I change data_format to NHWC. I'm running pretty recent versions of the NVIDIA driver and CUDA, and the same snippet seems to run on older versions (according to a few people I sent it to), so I'm pretty sure it's related to those. My versions are:

    cuda 11.1.0-2
    nvidia driver: 455.38
    jax 0.2.5
    jaxlib 0.1.57+cuda111 
    dm-haiku 0.0.2
    

    Error:

    2020-11-17 12:18:03.717098: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:349] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
    2020-11-17 12:18:03.718623: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:349] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
    2020-11-17 12:18:03.719796: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:349] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
    2020-11-17 12:18:03.719969: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc:772] Failed to determine best cudnn convolution algorithm: Internal: All algorithms tried for convolution %custom-call = (f32[1,20,24,24]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,1,28,28]{3,2,1,0} %parameter.1, f32[5,5,1,20]{1,0,2,3} %copy.1), window={size=5x5}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="conv_general_dilated[ batch_group_count=1\n                      dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(3, 2, 0, 1), out_spec=(0, 1, 2, 3))\n                      feature_group_count=1\n                      lhs_dilation=(1, 1)\n                      lhs_shape=(1, 1, 28, 28)\n                      padding=((0, 0), (0, 0))\n                      precision=None\n                      rhs_dilation=(1, 1)\n                      rhs_shape=(5, 5, 1, 20)\n                      window_strides=(1, 1) ]"}, backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm. 
    
    Convolution performance may be suboptimal.
    2020-11-17 12:18:03.800681: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:349] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
    2020-11-17 12:18:03.800721: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_client.cc:1809] Execution of replica 0 failed: Unimplemented: DNN library is not found.
    Traceback (most recent call last):
      File "scratch.py", line 17, in <module>
        params = net_transformed.init(key, np.zeros((1, 1, 28, 28)))
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/haiku/_src/transform.py", line 111, in init_fn
        params, state = f.init(*args, **kwargs)
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/haiku/_src/transform.py", line 277, in init_fn
        f(*args, **kwargs)
      File "scratch.py", line 12, in net
        return model(x)
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/haiku/_src/module.py", line 406, in wrapped
        out = f(*args, **kwargs)
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/haiku/_src/module.py", line 263, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/haiku/_src/basic.py", line 124, in __call__
        out = layer(out, *args, **kwargs)
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/haiku/_src/module.py", line 406, in wrapped
        out = f(*args, **kwargs)
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/haiku/_src/module.py", line 263, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/haiku/_src/conv.py", line 195, in __call__
        out = lax.conv_general_dilated(inputs,
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 571, in conv_general_dilated
        return conv_general_dilated_p.bind(
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/jax/core.py", line 266, in bind
        out = top_trace.process_primitive(self, tracers, params)
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/jax/core.py", line 576, in process_primitive
        return primitive.impl(*tracers, **params)
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/jax/interpreters/xla.py", line 234, in apply_primitive
        return compiled_fun(*args)
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/jax/interpreters/xla.py", line 349, in _execute_compiled_primitive
        out_bufs = compiled.execute(input_bufs)
    RuntimeError: Unimplemented: DNN library is not found.
    
    opened by mil-ad 6
  • Iterating through hk modules

    Let's say I want to iterate through all modules inside a Haiku model and replace all hk.Linear modules with my own custom module, or monkey-patch some of their properties. Does Haiku currently support something along these lines?

    opened by mil-ad 6
  • Dealing with conditionally constant state

    How could I add a constant state to my haiku module? Specifically I would want something like this:

    class MyModule(hk.Module):
      def __init__(self, output_size, const, name=None):
        if const:
          self.b = hk.conts(jnp.ones(output_size))  # won't be updated when adding gradient
        else:
          self.b = jnp.zeros(output_size)  # will get updated when adding gradient
    
    opened by joaogui1 6
  • FutureWarning: jax.tree_util.tree_multimap() is deprecated

    Looks like dm-haiku is still using tree_multimap() which is now deprecated (resulting in annoying "future warning" messages with the latest jax)

    /usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py:189: FutureWarning: jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() instead as a drop-in replacement.
      'instead as a drop-in replacement.', FutureWarning)
    
    opened by sokrypton 5
  • Argument `init` in `get_parameter` is not optional

    Hi,

    According to the docs, the init argument of get_parameter is optional, while in reality it raises the error: ValueError: Initializer must be specified. (See line 496 in base.py.)

    Example: init=None might occur if the initialisation is done outside Haiku. For example:

    import haiku as hk
    import jax.numpy as jnp
    
    @hk.without_apply_rng
    @hk.transform
    def foo(x):
        w = hk.get_parameter("w", [1], init=None)
        return x + w
    
    # Initialise params outside haiku, without `foo.init`.
    params = {'~': {'w': jnp.array([2.], dtype=jnp.float32)}}
    
    x = jnp.array([1])
    foo.apply(params, x)
    

    Kind regards,

    Hylke

    opened by hylkedonker 0
  • How to reinitialize the hidden states of RNNs?

    I want to use initial_state in this way, but I get an error: AttributeError: 'Transformed' object has no attribute 'init_hidden_state'. What is the best way to do this?

    import haiku as hk
    
    class RNN(hk.Module):
      def __init__(self, hidden_size=4, name=None):
        super().__init__(name=name)
        self.rnn = hk.LSTM(hidden_size)
    
      def __call__(self, h, x):
        out, h = self.rnn(x, h)
        return h, out
    
      def init_hidden_state(self, batch_size=1):
        return self.rnn.initial_state(batch_size)
    
    model = hk.without_apply_rng(hk.transform(lambda h, x: RNN(4)(h, x)))
    h = model.init_hidden_state(1)
    
    opened by qlan3 0
  • Correct way to integrate tf2jax output with a hk.Module

    I'm looking at the tf2jax project, and the ability to take TensorFlow pretrained modules and convert them to haiku would be a really useful functionality, since there aren't a lot of available Haiku checkpoints. A typical application is something like

    import tf2jax
    import tensorflow as tf
    import jax.numpy as jnp
    jax_func, jax_params = tf2jax.convert(tf.function(tf.keras.applications.resnet50.ResNet50()), jnp.zeros((1, 224, 224, 3)))
    

    So now I have a function and parameters to do what I want, but I need to insert them into a Haiku module. How should I do this? I'm hoping for some way to eventually be able to

    class MyModule(hk.Module):
        def __call__(self, x):
            x = ResNet50Jax()(x)
            x = # some other module specific stuff
            return x
    

    that I can then proceed with hk.transform as usual. I wasn't able to find an obvious way to do this. Any thoughts?

    More broadly, is it a bad idea to rely on tf2jax for checkpoints, versus perhaps making the model directly in Haiku and manually copying over weights from PyTorch/tensorflow?

    opened by rdilip 1
  • Bump certifi from 2021.10.8 to 2022.12.7 in /docs

    Bumps certifi from 2021.10.8 to 2022.12.7.


    Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting @dependabot rebase.


    Dependabot commands and options

    You can trigger Dependabot actions by commenting on this PR:

    • @dependabot rebase will rebase this PR
    • @dependabot recreate will recreate this PR, overwriting any edits that have been made to it
    • @dependabot merge will merge this PR after your CI passes on it
    • @dependabot squash and merge will squash and merge this PR after your CI passes on it
    • @dependabot cancel merge will cancel a previously requested merge and block automerging
    • @dependabot reopen will reopen this PR if it is closed
    • @dependabot close will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
    • @dependabot ignore this major version will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
    • @dependabot ignore this minor version will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
    • @dependabot ignore this dependency will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
    • @dependabot use these labels will set the current labels as the default for future PRs for this repo and language
    • @dependabot use these reviewers will set the current reviewers as the default for future PRs for this repo and language
    • @dependabot use these assignees will set the current assignees as the default for future PRs for this repo and language
    • @dependabot use this milestone will set the current milestone as the default for future PRs for this repo and language

    You can disable automated security fix PRs for this repo from the Security Alerts page.

    dependencies 
    opened by dependabot[bot] 0
  • Mypy error from `next_rng_key` type inconsistency with jax `PRNGKeyArray`

    Hi,

    It seems that mypy (version 0.942) reports that Haiku's random key, as generated by hk.next_rng_key(), is not compatible with JAX's PRNGKeyArray type. The latter is the type of the key argument in various jax.random samplers.

    Example

    import jax
    import haiku as hk
    
    def sample_phi(alpha: float):
        phi = jax.random.gamma(hk.next_rng_key(), a=alpha)
        return phi
    

    Error

    example.py:5: error: Argument 1 to "gamma" has incompatible type "ndarray"; expected "Union[Array, PRNGKeyArray]"
    

    Apart from explicitly silencing these errors in mypy, are there any other suggestions to fix these errors?

    Thanks in advance,

    Hylke

    Environment

    dm-haiku==0.0.9
    jax==0.3.25
    jaxlib==0.3.25
    mypy==0.942
    
    opened by hylkedonker 0
Releases(v0.0.9)
  • v0.0.9(Nov 16, 2022)

    What's Changed

    • Support vmap where in_axes is a list rather than a tuple in https://github.com/deepmind/dm-haiku/commit/307cf7dbda64d637ca423cacc9978f0ca19dc8a6
    • Pass pmap axis specs optionally to make_model_info in https://github.com/deepmind/dm-haiku/commit/d0ba451c96a6ac4f44fb9457e252b1d675a5416a
    • Remove use of jax_experimental_name_stack flag in https://github.com/deepmind/dm-haiku/commit/dbc0b1f2ffee9b348a3cb67460f28f9cc4667f08
    • Add param_axis argument to RMSNorm to allow setting scale param shape in https://github.com/deepmind/dm-haiku/commit/a4998a02bc4e8303f9897e5c32ded90cc38fa84f
    • Add documentation and error messages for w_init and w_init_scale to avoid confusion in https://github.com/deepmind/dm-haiku/pull/541
    • Fix hk.while_loop carrying state when reserving variable sizes of rng keys. by @copybara-service in https://github.com/deepmind/dm-haiku/pull/551
    • Add ensemble example to hk.lift documentation. by @copybara-service in https://github.com/deepmind/dm-haiku/pull/556

    Full Changelog: https://github.com/deepmind/dm-haiku/compare/v0.0.8...v0.0.9

  • v0.0.8(Sep 21, 2022)

    • Added experimental.force_name.
    • Added ability to simulate a method name in experimental.name_scope.
    • Added a config option for PRNG key block size.
    • Added unroll parameter to dynamic_unroll.
    • Remove use of deprecated jax.tree_* functions.
    • Many improvements to our examples.
    • Improve error messages in vmap.
    • Support jax_experimental_name_stack in jaxpr_info.
    • transform_and_run now supports a map on PRNG keys.
    • remat now uses the new JAX remat implementation.
    • Scale parameter is now optional in RMSNorm.
  • v0.0.7(Jul 4, 2022)

  • v0.0.6(Feb 14, 2022)

  • v0.0.5(Nov 1, 2021)

    • Added support for mixed precision training (dba1fd9) via jmp
    • Added hk.with_empty_state(..).
    • Added hk.multi_transform(..) (#137), supporting transforming multiple functions that share parameters.
    • Added hk.data_structures.is_subset(..) to test whether parameters are a subset of another.
    • Minimum Python version is now 3.7.
    • Multiple changes in preparation for a future version of Haiku changing to plain dicts.
    • hk.next_rng_keys(..) now returns a stacked array rather than a collection.
    • hk.MultiHeadAttention now supports distinct sequence lengths in query and key/value.
    • hk.LayerNorm now optionally supports faster (but less stable) variance computation.
    • hk.nets.MLP now has an output_shape property.
    • hk.nets.ResNet now supports changing strides.
    • UnexpectedTracerError inside a Haiku transform now has a more useful error message.
    • hk.{lift,custom_creator,custom_getter} are no longer experimental.
    • Haiku now supports JAX's pluggable RNGs.
    • We have made multiple improvements to our docs and error messages.

    And many other small fixes and improvements.

  • v0.0.4(Apr 12, 2021)

    Changelog:

    • (Important Fix) Fixed strides in basic block (300e6a40be3).
    • Added map, partition_n and traverse to data_structures.
    • Added "build your own Haiku" to the docs.
    • Added summarise utility to Haiku.
    • Added visualisation section to docs.
    • Added precision arg to Linear, Conv and ConvTranspose.
    • Added RMSNorm.
    • Added module_name and name to GetterContext.
    • Added hk.eval_shape.
    • Improved performance of non cross-replica BN variance.
    • Haiku branch functions are only traced once (mirroring JAX).
    • Attention logits are rescaled before the softmax now.
    • ModuleMetaclass now inherits from Protocol.
    • Removed "dot access" to FlatMapping.
    • Removed query_size from MultiHeadAttention constructor.

    And many other small fixes and improvements.

  • v0.0.3(Nov 24, 2020)

    Changelog:

    • Added hk.experimental.intercept_methods.
    • Added hk.running_init.
    • Added hk.experimental.name_scope.
    • Added optional support for state in custom_creator and custom_getter.
    • Added index groups to BatchNorm.
    • Added interactive notebooks to documentation, including basics guide.
    • Added support for batch major unrolls in static_unroll and dynamic_unroll.
    • Added hk.experimental.abstract_to_dot.
    • Added step markers in imagenet example.
    • Added hk.MultiHeadAttention.
    • Added option to remove double bias from VanillaRNN.
    • Added support for feature_group_count in ConvND.
    • Added logits config to resnet models.
    • Added various control flow primitives (fori_loop, switch, while_loop).
    • Added cross_replica_axis to VectorQuantizerEMA.
    • Added original_shape to ParamContext.
    • Added hk.SeparableDepthwiseConv2D.
    • Added support for unroll kwarg to hk.scan.
    • Added output_shape argument to ConvTranspose modules.
    • Replaced frozendict with FlatMapping, significantly reducing the overhead of calling jitted computations.
    • Misc changes to ensure parameter dtype follows input dtype.
    • Multiple changes to support JAX omnistaging.
    • ExponentialMovingAverage.initialize now takes shape/dtype not value.
    • Replaced optix with optax in examples.
    • hk.Embed embeddings now created lazily.
    • Re-indexed documentation for easier navigation.
  • v0.0.2(Jul 29, 2020)

    Changelog:

    • Changed the default value of apply_rng to True in hk.transform to simplify the apply_fn signature.
    • Made ConvND, ConvNDTranspose, ResetCore and pooling modules optionally batched.
    • Added hk.GroupNorm.
    • Added hk.scan.
    • Changed hk.BatchNorm to always create state for moving averages.
    • Changed use_projection in hk.nets.ResNet to take a sequence of bools.
    • Exposed hk.nets.ResNet.{BlockGroup, BlockV1, BlockV2}.
    • Added original_dtype to ParamContext to expose the original parameter dtype to custom_getters.
    • Added GAN example notebook.
  • v0.0.1(Jun 4, 2020)

    Haiku is a simple neural network library for JAX developed by some of the authors of Sonnet, a neural network library for TensorFlow.

    Changelog:

    Features:

    • Exposed hk.nets.ResNet and added hk.nets.ResNet{18,34,101,152,200}
    • Added IdentityCore.
    • Added custom_getter API for advanced parameter manipulation.
    • Added ConvND and lifted N<=3 restriction.
    • Added tree_size and tree_bytes to easily compute parameter counts.
    • hk.remat now only threads changed values (faster compilation).
    • Added support for @dataclass to define modules.
    • Added support for splitting >1 key at a time, e.g. k1, k2 = hk.next_rng_keys(2).
    • Experimental: Added profiler_name_scopes API to add Haiku names to XProf.
    • Experimental: Added optimize_rng_use to improve compilation time for models with lots of RNG keys.

    Examples:

    • Added language model example.
    • Added VQVAE example.

    Bug fixes:

    • LayerNorm now correctly handles bf16 inputs.
    • TruncatedNormal initializer now respects dtype.

    Usability:

    • Improved error messages for get_parameter, to_module and others.
    • Reimplemented core modules with "public" API (easier to read and fork).
    • Added tests that ensure all public symbols are included in documentation.
    • Added type annotations to more internal code.
  • v0.0.1-beta(Mar 26, 2020)

    Changes

    Examples

    • Added VAE example.
    • Added pruning example (https://arxiv.org/abs/1710.01878).
    • MNIST example uses 300-100-10 MLP.
    • Updated imagenet dataset to return correctly scaled examples.

    Breaking changes

    • State arg to hk.transform dropped in favor of transform_with_state.
    • Decay argument is now required in BatchNorm.

    Features

    • Added hk.maybe_next_rng_key().
    • BatchNorm and LayerNorm speed improvements.
    • Added support for partition/filter/merge params.
    • Haiku now allows running with jax_numpy_rank_promotion.

    Experimental features

    • hk.experimental.to_dot - experimental visualisation support.
    • hk.experimental.lift - experimental purification support.

    Usability

    • Improved error message when RNG arg is not an RNG.
    • Improved documentation.
    • Improved test coverage.
  • v0.0.1-alpha(Feb 20, 2020)
