Plug-n-Play Reinforcement Learning in Python with OpenAI Gym and JAX

Overview

coax

For the full documentation, including many examples, go to https://coax.readthedocs.io/

Install

coax is built on top of JAX, but it doesn't declare an explicit dependency on the jax Python package, because the jaxlib build you need depends on your CUDA version. To install without CUDA support, simply run:

$ pip install jaxlib jax coax --upgrade

If you do require CUDA support, please check out the Installation Guide.
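After installing, you can check which backend JAX picked up by listing its devices. This shows whether you got CPU only, or a GPU if a matching CUDA-enabled jaxlib is installed. This is a small sanity check using JAX's public device API; it is not a coax-specific step.

```python
import jax

# List the devices JAX can see; a CPU-only jaxlib reports CPU devices only.
devices = jax.devices()
backend = jax.default_backend()  # e.g. "cpu" or "gpu"
print(backend, devices)
```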

Getting Started

Have a look at the Getting Started page to train your first RL agent.


Comments
  • Quantile Q-Learning Implementation

    Quantile Q-Learning Implementation

    This PR adds a QuantileQ class with function types 3 and 4 that accept a number of quantiles together with the state (and action), as well as a QuantileQLearning class. The QuantileQ function could be merged into the Q class which would simplify the user-facing API. However, some more work needs to be done to incorporate the QuantileQLearning class into the QLearning class. I just wanted to validate that this is the correct approach to take to implement the IQN.

    Documentation for the quantile Huber loss is still missing, and the notebooks still need to be added and tuned.
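For reference, the quantile Huber loss used in QR-DQN and IQN takes an ordinary Huber loss of the TD error u. It weights that loss asymmetrically by |tau - 1{u < 0}| and divides by the threshold kappa. The sketch below is a simplified element-wise version that pairs each TD error with a single quantile fraction. IQN instead uses the full pairwise matrix of TD errors between sampled quantiles. The sketch is an illustration only, not the implementation in this PR.

```python
import jax.numpy as jnp

def quantile_huber_loss(td_errors, taus, kappa=1.0):
    """Element-wise quantile Huber loss, averaged over all entries.

    td_errors: TD errors u = target - predicted quantile value
    taus: quantile fractions in (0, 1), same shape as td_errors
    kappa: threshold between the quadratic and linear regimes of the Huber loss
    """
    abs_u = jnp.abs(td_errors)
    # Huber loss: quadratic for |u| <= kappa, linear beyond.
    huber = jnp.where(abs_u <= kappa, 0.5 * td_errors ** 2, kappa * (abs_u - 0.5 * kappa))
    # Asymmetric quantile weight |tau - 1{u < 0}|.
    weight = jnp.abs(taus - (td_errors < 0).astype(td_errors.dtype))
    return jnp.mean(weight * huber / kappa)
```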

    Closes https://github.com/coax-dev/coax/issues/3

    opened by frederikschubert 11
  • Add DeepMind Control Suite Example

    Add DeepMind Control Suite Example

    This PR is a rework of https://github.com/coax-dev/coax/pull/26 and adds an example for using SAC on the Walker.walk task from the DeepMind Control Suite.

    Depends on https://github.com/coax-dev/coax/pull/27 and https://github.com/coax-dev/coax/pull/28

    opened by frederikschubert 6
  • Assertion assert_equal_shape failed for MultiDiscrete action space

    Assertion assert_equal_shape failed for MultiDiscrete action space

    First of all, thank you for developing this package; I really like the modular design. I am a bit new to RL and the JAX ecosystem, so my question may be a bit naive. I am running a baseline study with my customized gym environment and VanillaPG, but I hit the error shown below and could not figure it out. My understanding is that the assertion complains that log_pi should not have shape (4,). But I do have a MultiDiscrete action space with 4 components, so I would expect its log_pi to have shape (4,) or (1, 4). I have also attached the output of coax.Policy.example_data(env) and my policy function definition below, in case that helps explain the situation.

    So my questions are:

    1. Do you think this error is related to the fact that I have a MultiDiscrete action space?
    2. Did I declare my policy function properly?
    3. Any general ideas on how to debug JAX functions?

    I would appreciate any feedback. Thank you!
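On the third question, two standard JAX tools help when debugging jitted code. A plain Python print inside a jitted function runs only while JAX traces it, so it shows static shapes but not values. jax.debug.print shows runtime values from inside compiled code, but it only exists in newer JAX releases. jax.disable_jit() runs everything eagerly so that ordinary debugging works. The function below is a hypothetical stand-in shaped like the VanillaPG objective in the traceback; it is not coax code.

```python
import jax
import jax.numpy as jnp

@jax.jit
def objective(W, Adv, log_pi):
    # Runs at trace time: shapes are static, so a plain print shows them.
    print("traced shapes:", W.shape, Adv.shape, log_pi.shape)
    # Prints runtime values, even from inside compiled code (newer JAX only).
    jax.debug.print("log_pi = {x}", x=log_pi)
    return jnp.sum(W * Adv * log_pi)

# Temporarily switch jit off so breakpoints and ordinary prints behave as usual.
with jax.disable_jit():
    value = objective(jnp.ones(1), jnp.ones(1), jnp.full(1, -0.5))
```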

    Error message

    ---------------------------------------------------------------------------
    AssertionError                            Traceback (most recent call last)
    Input In [25], in <cell line: 5>()
         13     transition_batch = tracer.pop()
         14     Gn = transition_batch.Rn
    ---> 15     metrics = vanilla_pg.update(transition_batch, Adv=Gn)
         16     env.record_metrics(metrics)
         17 if done:
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:149, in PolicyObjective.update(self, transition_batch, Adv)
        127 def update(self, transition_batch, Adv):
        128     r"""
        129 
        130     Update the model parameters (weights) of the underlying function approximator.
       (...)
        147 
        148     """
    --> 149     grads, function_state, metrics = self.grads_and_metrics(transition_batch, Adv)
        150     if any(jnp.any(jnp.isnan(g)) for g in jax.tree_leaves(grads)):
        151         raise RuntimeError(f"found nan's in grads: {grads}")
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:218, in PolicyObjective.grads_and_metrics(self, transition_batch, Adv)
        212 if self.REQUIRES_PROPENSITIES and jnp.all(transition_batch.logP == 0):
        213     warnings.warn(
        214         f"In order for {self.__class__.__name__} to work properly, transition_batch.logP "
        215         "should be non-zero. Please sample actions with their propensities: "
        216         "a, logp = pi(s, return_logp=True) and then add logp to your reward tracer, "
        217         "e.g. nstep_tracer.add(s, a, r, done, logp)")
    --> 218 return self._grad_and_metrics_func(
        219     self._pi.params, self._pi.function_state, self.hyperparams, self._pi.rng,
        220     transition_batch, Adv)
    
    File ~/opt/python3.9/site-packages/coax/utils/_jit.py:59, in JittedFunc.__call__(self, *args, **kwargs)
         58 def __call__(self, *args, **kwargs):
    ---> 59     return self._jitted_func(*args, **kwargs)
    
        [... skipping hidden 14 frame]
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:80, in PolicyObjective.__init__.<locals>.grads_and_metrics_func(params, state, hyperparams, rng, transition_batch, Adv)
         77 def grads_and_metrics_func(params, state, hyperparams, rng, transition_batch, Adv):
         78     grads_func = jax.grad(loss_func, has_aux=True)
         79     grads, (metrics, state_new) = \
    ---> 80         grads_func(params, state, hyperparams, rng, transition_batch, Adv)
         82     # add some diagnostics of the gradients
         83     metrics.update(get_grads_diagnostics(grads, f'{self.__class__.__name__}/grads_'))
    
        [... skipping hidden 10 frame]
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:47, in PolicyObjective.__init__.<locals>.loss_func(params, state, hyperparams, rng, transition_batch, Adv)
         45 def loss_func(params, state, hyperparams, rng, transition_batch, Adv):
         46     objective, (dist_params, log_pi, state_new) = \
    ---> 47         self.objective_func(params, state, hyperparams, rng, transition_batch, Adv)
         49     # flip sign to turn objective into loss
         50     loss = -objective
    
    File ~/opt/python3.9/site-packages/coax/policy_objectives/_vanilla_pg.py:52, in VanillaPG.objective_func(self, params, state, hyperparams, rng, transition_batch, Adv)
         49 W = jnp.clip(transition_batch.W, 0.1, 10.)
         51 # some consistency checks
    ---> 52 chex.assert_equal_shape([W, Adv, log_pi])
         53 chex.assert_rank([W, Adv, log_pi], 1)
         54 objective = W * Adv * log_pi
    
    File ~/opt/python3.9/site-packages/chex/_src/asserts_internal.py:197, in chex_assertion.<locals>._chex_assert_fn(*args, **kwargs)
        195 else:
        196   try:
    --> 197     host_assertion(*args, **kwargs)
        198   except jax.errors.ConcretizationTypeError as exc:
        199     msg = ("Chex assertion detected `ConcretizationTypeError`: it is very "
        200            "likely that it tried to access tensors' values during tracing. "
        201            "Make sure that you defined a jittable version of this Chex "
        202            "assertion.")
    
    File ~/opt/python3.9/site-packages/chex/_src/asserts_internal.py:157, in make_static_assertion.<locals>._static_assert(custom_message, custom_message_format_vars, include_default_message, exception_type, *args, **kwargs)
        154     custom_message = custom_message.format(*custom_message_format_vars)
        155   error_msg = f"{error_msg} [{custom_message}]"
    --> 157 raise exception_type(error_msg)
    
    AssertionError: [Chex] Assertion assert_equal_shape failed: Arrays have different shapes: [(1,), (1,), (4,)].
    

    Example data

    ExampleData(
      inputs=Inputs(
        args=ArgsType2(
          S={
            'features': array(shape=(1, 1000), dtype=float32, min=0.008, median=2.13, max=2.77)
          is_training=True)
        static_argnums=(
          1))
      output=(
        {
          'logits': array(shape=(1, 10), dtype=float32, min=-2.31, median=0.152, max=0.732)},
        {
          'logits': array(shape=(1, 10), dtype=float32, min=-1.54, median=-0.138, max=0.994)},
        {
          'logits': array(shape=(1, 10), dtype=float32, min=-0.984, median=0.0808, max=1.73)},
        {
          'logits': array(shape=(1, 10), dtype=float32, min=-2.74, median=-0.289, max=1.74)}))
    

    Policy function

    def pi(S, is_training):
        module = CustomizedModule()
        res = tuple([{"logits": item} for item in module(S["features"])])
        return res
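The failing check compares the shapes of W, Adv and log_pi, and it expects one log-propensity per transition. If a MultiDiscrete action is modelled as independent categoricals, its joint log-probability is the sum of the per-component log-probabilities. A batch of one transition would then give shape (1,), not (4,). The plain-JAX sketch below illustrates that reduction. The function name is hypothetical, and the sketch does not claim to be how coax computes log_pi internally.

```python
import jax
import jax.numpy as jnp

def multidiscrete_log_prob(logits_per_component, actions):
    """Joint log-probability of MultiDiscrete actions under independent categoricals.

    logits_per_component: tuple of arrays, each of shape (batch, n_i)
    actions: int array of shape (batch, num_components)
    returns: array of shape (batch,)
    """
    per_component = []
    for i, logits in enumerate(logits_per_component):
        log_p = jax.nn.log_softmax(logits, axis=-1)  # (batch, n_i)
        a_i = actions[:, i]                          # (batch,)
        per_component.append(jnp.take_along_axis(log_p, a_i[:, None], axis=-1)[:, 0])
    # Sum over components: one log-probability per transition.
    return jnp.sum(jnp.stack(per_component, axis=-1), axis=-1)

logits = tuple(jnp.zeros((1, 10)) for _ in range(4))  # 4 components, 10 choices each
actions = jnp.zeros((1, 4), dtype=jnp.int32)
log_pi = multidiscrete_log_prob(logits, actions)      # shape (1,)
```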
    
    question 
    opened by xiangyuy 5
  • 'linear/w' does not match shape

    'linear/w' does not match shape

    I've started learning about RL and have been trying to get coax up and running, but I've run into an issue I'm not sure how to resolve. I'm doing Q-learning on a custom gym environment, and I can run the following pieces successfully:

    q = coax.Q(func_q, env)
    pi = coax.Policy(func_pi, env)
    
    qlearning = coax.td_learning.QLearning(q, pi_targ=pi, optimizer=optax.adam(0.001))
    cache = coax.reward_tracing.NStep(n=1, gamma=0.9)
    

    Additionally, my setup passes the simple checks of:

    data = coax.Q.example_data(env) # Looks good
    ...
    s = env.observation_space.sample()
    a = env.action_space.sample()
    print(q(s,a)) # 0.0
    ...
    a = pi(s)
    print(a) # [0, 0, 0, 0, 0] as I have a MultiDiscrete action space
    

    However, once I get to actually running the training loop:

    for ep in range(50):
      pi.epsilon = 0.1
      s = env.reset()
    
      for t in range(env.maxGuesses):
        a = pi(s)
        s_next, r, done, info = env.step(a)
    
        # update
        cache.add(s, a, r, done)
    
        while cache:
          transition_batch = cache.pop()
          metrics = qlearning.update(transition_batch)
          env.record_metrics(metrics)
    
        if done:
          break
    
        s = s_next
    
        # early stopping
        if env.avg_G > env.reward_threshold:
          break
    

    I get a bunch of errors with the most human-readable of them saying:

    ValueError: 'linear/w' with retrieved shape (420, 30) does not match shape=[940, 30] dtype=dtype('float32')
    

    By changing the environment's parameters, I can change which numbers are mismatched. I can't get them to match, and in any case that seems like the wrong fix, since something more fundamental seems to be wrong.

    For reference, here are my functions for q and pi:

    def func_pi(S, is_training):
      logits = hk.Sequential((
        hk.Linear(30), jax.nn.relu, 
        hk.Linear(30), jax.nn.relu, 
        hk.Linear(30), jax.nn.relu,
        hk.Linear(Wordle.wordLength*len(alphabet), w_init=jnp.zeros) # This many possible actions
      ))
      # First, convert to a vector:
      sVec = state_to_vec(S)
    
      # Now get the output:
      logitVec = logits(sVec)
    
      # Now chunk the output into alphabet-sized pieces (definitionally an integral
      # number of them). There will be Wordle.wordLength chunks of this length
      chunks = jnp.split(logitVec, Wordle.wordLength)
    
      # Now format our output array:
      ret = []
      for chunk in chunks:
        ret.append({'logits': jnp.reshape(chunk,(1,len(alphabet)))})
    
      return tuple(ret)
    
    # and for actual state:
    def func_q(S, A, is_training):
      value = hk.Sequential((
        hk.Linear(30), jax.nn.relu, 
        hk.Linear(30), jax.nn.relu,
        hk.Linear(30), jax.nn.relu,
        hk.Linear(1, w_init=jnp.zeros), jnp.ravel
      ))
    
      sVec = state_to_vec(S)
      aVec = action_to_vec(A)
    
      X = jnp.concatenate((sVec, aVec))
      return value(X)
    

    Note that state_to_vec(S) and action_to_vec(A) just convert from my internal types to jnp.array's for use with Haiku.
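One plausible cause, not confirmed in this issue, is that state_to_vec and action_to_vec flatten away the batch axis. coax calls the function on batches of transitions. Concatenating flattened arrays along axis 0 therefore gives a feature vector whose length grows with the batch size. hk.Linear infers its input size from the last axis of its input, which would explain a weight shape that no longer matches. The plain-JAX sketch below contrasts the two behaviours using made-up shapes. The same reasoning would apply to func_pi, where jnp.split (axis 0) and the reshape to (1, len(alphabet)) both assume a batch size of one.

```python
import jax.numpy as jnp

def features_flattened(S, A):
    # Flattens away the batch axis: the feature length now depends on the batch size.
    return jnp.concatenate((S.ravel(), A.ravel()))

def features_batched(S, A):
    # Keeps axis 0 as the batch axis and concatenates along the last (feature) axis,
    # so the last dimension, which hk.Linear uses as its input size, stays constant.
    batch_size = S.shape[0]
    return jnp.concatenate((S.reshape(batch_size, -1), A.reshape(batch_size, -1)), axis=-1)

# Made-up shapes: 6 state features and 5 action features per transition.
for batch_size in (1, 2):
    S = jnp.ones((batch_size, 6))
    A = jnp.ones((batch_size, 5))
    print(batch_size, features_flattened(S, A).shape, features_batched(S, A).shape)
```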

    I'm quite new to coax/JAX/Haiku so it's entirely possible I've set something up wrong. For completeness here's the full text of the error:

    Traceback (most recent call last):
      File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 314, in <module>
        metrics = qlearning.update(transition_batch)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 87, in update
        grads, function_state, metrics, td_error = self.grads_and_metrics(transition_batch)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 149, in grads_and_metrics
        return self._grads_and_metrics_func(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
        return self._jitted_func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 426, in cache_miss
        out_flat = xla.xla_call(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
        outs = top_trace.process_call(primitive, fun, tracers, params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 596, in process_call
        return primitive.impl(f, *tracers, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/dispatch.py", line 142, in _xla_call_impl
        compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 272, in memoized_fun
        ans = call(fun, *args)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/dispatch.py", line 169, in _xla_callable_uncached
        return lower_xla_callable(fun, device, backend, name, donated_invars,
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
        return func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/dispatch.py", line 197, in lower_xla_callable
        jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
        return func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1623, in trace_to_jaxpr_final
        jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1594, in trace_to_subjaxpr_dynamic
        ans = fun.call_wrapped(*in_tracers_)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 462, in grads_and_metrics_func
        grads, (td_error, state_new, metrics) = jax.grad(loss_func, has_aux=True)(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 996, in grad_f_aux
        (_, aux), g = value_and_grad_f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 1067, in value_and_grad_f
        ans, vjp_py, aux = _vjp(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 2478, in _vjp
        out_primal, out_vjp, aux = ad.vjp(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/ad.py", line 118, in vjp
        out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/ad.py", line 103, in linearize
        jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
        return func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 520, in trace_to_jaxpr
        jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 436, in loss_func
        Q, state_new = self.q.function_type1(params, state, next(rngs), S, A, True)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
        return self._jitted_func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 426, in cache_miss
        out_flat = xla.xla_call(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
        outs = top_trace.process_call(primitive, fun, tracers, params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/ad.py", line 324, in process_call
        result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
        outs = top_trace.process_call(primitive, fun, tracers, params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 204, in process_call
        jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 317, in partial_eval
        out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
        outs = top_trace.process_call(primitive, fun, tracers, params)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1364, in process_call
        jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1594, in trace_to_subjaxpr_dynamic
        ans = fun.call_wrapped(*in_tracers_)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/transform.py", line 383, in apply_fn
        out = f(*args, **kwargs)
      File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 264, in func_1
        return value(X)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 125, in __call__
        out = layer(out, *args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 178, in __call__
        w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/base.py", line 319, in get_parameter
        raise ValueError(
    jax._src.traceback_util.UnfilteredStackTrace: ValueError: 'linear/w' with retrieved shape (420, 30) does not match shape=[940, 30] dtype=dtype('float32')
    
    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.
    
    --------------------
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 314, in <module>
        metrics = qlearning.update(transition_batch)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 87, in update
        grads, function_state, metrics, td_error = self.grads_and_metrics(transition_batch)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 149, in grads_and_metrics
        return self._grads_and_metrics_func(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
        return self._jitted_func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 462, in grads_and_metrics_func
        grads, (td_error, state_new, metrics) = jax.grad(loss_func, has_aux=True)(
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 436, in loss_func
        Q, state_new = self.q.function_type1(params, state, next(rngs), S, A, True)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
        return self._jitted_func(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/transform.py", line 383, in apply_fn
        out = f(*args, **kwargs)
      File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 264, in func_1
        return value(X)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 125, in __call__
        out = layer(out, *args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
        out = f(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 178, in __call__
        w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
      File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/base.py", line 319, in get_parameter
        raise ValueError(
    ValueError: 'linear/w' with retrieved shape (420, 30) does not match shape=[940, 30] dtype=dtype('float32')
    

    Please let me know if other information would be useful or relevant (or let me know if this isn't actually a coax issue...).

    Thanks for your help and the neat package.

    bug good first issue question 
    opened by bcerjan 5
  • DQN pong example doesn't work off the shelf

    DQN pong example doesn't work off the shelf

    Describe the bug

    Running the DQN example on pong generates the following error when generating a gif:

      File ".../lib/python3.9/site-packages/coax/utils/_misc.py", line 475, in generate_gif
        assert env.render_mode == 'rgb_array', "env.render_mode must be 'rgb_array'"
    

    This is likely due to some recent updates to gym. Currently, on gym==0.26.2 I observe the following:

    import gym
    env = gym.make('PongNoFrameskip-v4', render_mode="rgb_array")
    print(env.render_mode) # prints None
    
    opened by thisiscam 4
  • Add dm_control example for SAC

    Add dm_control example for SAC

    This PR introduces the common squashed normal distribution for the SAC policy on dm_control and provides an example that solves the walker.walk task. Interestingly, training diverges when the actions are clipped to the range [-1, 1].

    @KristianHolsheimer How would you go about changing the installation script for this notebook to add dm_control as a dependency?
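A squashed normal samples u from a diagonal Gaussian and returns a = tanh(u), so actions already lie in (-1, 1). Its log-density therefore needs a change-of-variables correction of -log(1 - tanh(u)^2). The sketch below uses the standard numerically stable form of that term. It is a generic illustration of the technique, not the code from this PR.

```python
import jax
import jax.numpy as jnp

def sample_squashed_normal(key, mu, log_std):
    """Reparameterized tanh-squashed Gaussian sample and its log-density.

    mu, log_std: arrays of shape (batch, action_dim)
    returns: actions in (-1, 1) of shape (batch, action_dim), log-probs of shape (batch,)
    """
    std = jnp.exp(log_std)
    u = mu + std * jax.random.normal(key, mu.shape)  # pre-squash Gaussian sample
    a = jnp.tanh(u)
    # Gaussian log-density of u, summed over action dimensions.
    log_prob_u = jnp.sum(-0.5 * ((u - mu) / std) ** 2 - log_std - 0.5 * jnp.log(2 * jnp.pi), axis=-1)
    # log|d tanh(u)/du| = log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)), computed stably.
    log_det = jnp.sum(2.0 * (jnp.log(2.0) - u - jax.nn.softplus(-2.0 * u)), axis=-1)
    return a, log_prob_u - log_det

key = jax.random.PRNGKey(0)
mu = jnp.zeros((3, 2))        # batch of 3, 2-dimensional actions
log_std = jnp.zeros((3, 2))
a, logp = sample_squashed_normal(key, mu, log_std)
```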

    opened by frederikschubert 4
  • Frozen Lake example has an invalid gym signature.

    Frozen Lake example has an invalid gym signature.

    Describe the bug

    The Frozen Lake example on the main branch of the docs hasn't been fully updated for gym's new step signature.

    ValueError                                Traceback (most recent call last)
         77
         78 a = pi.mode(s)
    ---> 79 s, r, done, info = env.step(a)
         80
         81 env.render()

    ValueError: too many values to unpack (expected 4)

    Expected behavior

    Executing the notebook should not result in a ValueError.

    To Reproduce

    Colab notebook to repro the bug:

    - https://colab.research.google.com/...

    Runtime used for this Colab notebook: any (CPU/GPU/TPU).

    Additional context

    Simple fix, happy to contribute a pull request.
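From gym 0.26 onwards, env.step returns five values (observation, reward, terminated, truncated, info) instead of four, and env.reset returns (observation, info). The helper below is a hypothetical compatibility shim, not part of coax or gym. It accepts either step signature and folds terminated and truncated into a single done flag, which is what the old loop expects.

```python
def step_compat(env, action):
    """Return (s_next, r, done, info) for both the old 4-tuple and new 5-tuple step APIs."""
    result = env.step(action)
    if len(result) == 5:
        # gym >= 0.26: (obs, reward, terminated, truncated, info)
        s_next, r, terminated, truncated, info = result
        return s_next, r, terminated or truncated, info
    # older gym: (obs, reward, done, info)
    s_next, r, done, info = result
    return s_next, r, done, info


class _FakeNewStyleEnv:
    """Stand-in for a gym >= 0.26 environment, used only to exercise step_compat."""

    def step(self, action):
        return "obs", 1.0, False, True, {}  # episode cut off by truncation


s_next, r, done, info = step_compat(_FakeNewStyleEnv(), action=0)
```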

    opened by dbleyl 3
  • Incorporating jax.jit into a custom policy

    Incorporating jax.jit into a custom policy

    I'm a bit new to JAX, so my question might sound naive. Suppose we are solving a policy optimization problem with the REINFORCE algorithm and already have our environment at hand (env). We define our custom policy as follows:

    class CustomPolicy(hk.Module):
        def __init__(self, name = None):
            super().__init__(name = name)
        
    
        def __call__(self, x):
            w = hk.get_parameter("w", shape= ... , dtype = x.dtype, init=jnp.zeros)
            # some computation
            return out
    

    Following the documentation, we then define

    def custom_policy(S, is_training=True):
        logits = CustomPolicy()
        return {'logits': logits(S)}
    

    and finally construct the policy:

    pi = coax.Policy(custom_policy, env)

    Is there any way to incorporate @jax.jit into this structure to speed things up further? Thanks.
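The tracebacks in other issues on this page pass through coax/utils/_jit.py (JittedFunc). That suggests coax already compiles the function approximator with jax.jit internally, so an extra decorator is probably redundant. If you do want to jit your own code, jax.jit expects a pure function of arrays and pytrees. The sketch below uses plain JAX with hypothetical names rather than Haiku. It shows that calling a jitted function from inside another jitted function is allowed and gives the same result, because the inner function becomes part of the outer compiled computation.

```python
import jax
import jax.numpy as jnp

def policy_logits(params, x):
    # Pure function: the output depends only on the arguments, as jax.jit requires.
    return jnp.tanh(x @ params["w"] + params["b"])

# Hypothetical user-level jit of the forward pass.
policy_logits_jit = jax.jit(policy_logits)

@jax.jit
def loss(params, x):
    # Nested jit is allowed; the inner call is traced into the outer computation.
    return jnp.sum(policy_logits_jit(params, x) ** 2)

params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}
x = jnp.arange(6.0).reshape(2, 3)
```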

    question 
    opened by UweGensheimer 3
  • Multi-Step Entropy Regularization for SAC

    Multi-Step Entropy Regularization for SAC

    • Add record_extra_info flag to the NStep tracer that records the intermediate states in the new extra_info field to TransitionBatch
    • Add support for the NStepEntropyRegularizer in SoftPG

    This PR contains an initial working implementation of the mechanism. It sums up the discounted entropy bonuses of the states s_t, s_{t + 1}, ..., s_{t + n - 1} for the soft policy gradient regularization.
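The quantity described above can be sketched as a discounted sum of per-state entropies, beta * sum_k gamma^k * H(pi(.|s_{t+k})) for k = 0, ..., n - 1. The discounting starting at gamma^0 and the coefficient name beta are assumptions made for illustration; the PR does not spell them out, and this is not the NStepEntropyRegularizer code.

```python
import jax.numpy as jnp

def nstep_entropy_bonus(entropies, gamma, beta):
    """Discounted sum of entropy bonuses over n consecutive states.

    entropies: array of shape (..., n) holding H(pi(.|s_{t+k})) for k = 0, ..., n - 1
    gamma: discount factor
    beta: entropy coefficient (hypothetical name for the regularization strength)
    """
    n = entropies.shape[-1]
    discounts = gamma ** jnp.arange(n)  # 1, gamma, gamma^2, ...
    return beta * jnp.sum(discounts * entropies, axis=-1)
```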

    opened by frederikschubert 3
  • Implementation of SAC

    Implementation of SAC

    Since SAC is very similar to TD3, we can reuse most of its components. The differences are:

    • The actions to update the q-functions and policy are sampled using the current policy (instead of taking the mode).
    • There is no target policy.
    • The log variance of the policy depends on the state.
    • The policy is entropy regularized.

    The current implementation does not support multi-step TD learning.
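The entropy-regularized, single-step TD target implied by these differences can be sketched as follows. It assumes the clipped double-Q target that TD3 uses and a fixed temperature alpha. Both are assumptions, since the PR does not state the exact target. The next action a' and its log-probability are taken to be sampled from the current policy, as described in the list above.

```python
import jax.numpy as jnp

def soft_td_target(r, done, gamma, q1_next, q2_next, logp_next, alpha):
    """Single-step soft TD target: r + gamma * (min Q(s', a') - alpha * log pi(a'|s'))."""
    # Clipped double-Q estimate plus the entropy bonus -alpha * log pi(a'|s').
    q_next = jnp.minimum(q1_next, q2_next) - alpha * logp_next
    return r + gamma * (1.0 - done) * q_next
```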

    opened by frederikschubert 3
  • AttributeError: module 'jax.api' has no attribute '_jit_is_disabled'

    AttributeError: module 'jax.api' has no attribute '_jit_is_disabled'

    Hi, I'm unsure whether this is due to coax or JAX, but I get this error when running the Pendulum PPO example. DQN runs fine, however.

    A fix I found online for a similar error recommended changing the jaxlib version, so I switched to the jaxlib version given in the coax Getting Started guide, but that seemed to have no effect. My versions are:

    jax: 0.2.13
    jaxlib: 0.1.65+cuda111
    coax: 0.1.6

    question 
    opened by mmcaulif 3
  • Recurrent Experience Replay

    Recurrent Experience Replay

    Is your feature request related to a problem? Please describe.

    It seems that the implemented replay buffers only operate over transitions, with no ability to operate over entire sequences. This prevents the use of recurrent policies for tackling POMDPs.

    Describe the solution you'd like

    A SequenceReplayBuffer that returns contiguous episodes instead of shuffled transitions.
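A minimal sketch of such a buffer, assuming transitions arrive in order with a done flag and that fixed-length windows (rather than whole episodes) are wanted for training a recurrent policy. The class name matches the request, but it is hypothetical and not an existing coax API. Capacity is counted in episodes.

```python
import random
from collections import deque

class SequenceReplayBuffer:
    """Hypothetical buffer that stores whole episodes and samples contiguous windows."""

    def __init__(self, capacity, seq_len, seed=None):
        self.episodes = deque(maxlen=capacity)  # oldest episodes are evicted first
        self.seq_len = seq_len
        self.rng = random.Random(seed)
        self._current = []  # transitions of the episode in progress

    def add(self, transition, done):
        self._current.append(transition)
        if done:
            self.episodes.append(self._current)
            self._current = []

    def sample(self, batch_size):
        # Only episodes long enough to contain a full window are eligible.
        eligible = [ep for ep in self.episodes if len(ep) >= self.seq_len]
        if not eligible:
            raise ValueError("no stored episode is at least seq_len long")
        batch = []
        for _ in range(batch_size):
            ep = self.rng.choice(eligible)
            start = self.rng.randrange(len(ep) - self.seq_len + 1)
            batch.append(ep[start:start + self.seq_len])  # contiguous, in temporal order
        return batch
```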


    enhancement 
    opened by smorad 3
  • MiniMax Algorithm?

    MiniMax Algorithm?

    How would you implement a minimax q-learner with coax?

    Hi there! I love the package and how accessible it is to relative newbies. The tutorials are pretty great and the accompanying videos are very helpful!

    I was wondering what the best way to implement a minimax algorithm would be. Would you recommend using two policies, pi1 and pi2, or is there something better suited for this?

    I'd like to re-implement something like this old blogpost of mine in coax to get a better feel of the library.

    Any help would be greatly appreciated :)
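One common alternative to two separate policies applies to two-player, zero-sum games with alternating turns: a single Q-function evaluated from the perspective of whichever player is to move, trained with a negamax-style target. The opponent's best value at s' is then our loss. The sketch below assumes that setting and that rewards are given from the mover's perspective. It is only an illustration of the idea, not a coax recipe.

```python
import jax.numpy as jnp

def negamax_q_target(r, done, gamma, q_next_all):
    """Negamax-style Q target for alternating-turn, zero-sum games.

    r, done: arrays of shape (batch,), reward and terminal flag for the player who moved at s
    q_next_all: array of shape (batch, num_actions), Q(s', .) from the opponent's perspective
    """
    # The opponent picks its best action at s'; in a zero-sum game that value is our loss.
    v_next_opponent = jnp.max(q_next_all, axis=-1)
    return r - gamma * (1.0 - done) * v_next_opponent
```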

    question 
    opened by flaport 1
  • Convert Numpy Docstrings to Google Style

    Convert Numpy Docstrings to Google Style

    This issue tracks the progress of converting the numpy style docstrings to the more concise Google style.

    • [ ] _core
    • [ ] experience_replay
    • [ ] model_updaters
    • [ ] policy_objectives
    • [ ] proba_dists
    • [ ] reward_tracing
    • [ ] td_learning
    • [ ] utils
    • [ ] value_transforms
    • [ ] wrappers

    This depends on the type annotations (https://github.com/coax-dev/coax/issues/13), which make automatic conversion easier.
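To illustrate the difference, here is one hypothetical helper documented in both styles. With type annotations in place, the Google-style docstring can drop the type information that the NumPy style repeats in its Parameters section. That is one reason the conversion is easier once the annotations exist.

```python
from typing import Sequence


def discounted_return_numpy(rewards, gamma):
    """
    Compute the discounted return of a sequence of rewards.

    Parameters
    ----------
    rewards : sequence of float
        The rewards r_0, ..., r_{T-1}.
    gamma : float
        The discount factor, between 0 and 1.

    Returns
    -------
    G : float
        The sum of gamma**t * r_t over all time steps t.
    """
    G = 0.0
    for r in reversed(rewards):
        G = r + gamma * G
    return G


def discounted_return_google(rewards: Sequence[float], gamma: float) -> float:
    """Compute the discounted return of a sequence of rewards.

    Args:
        rewards: The rewards r_0, ..., r_{T-1}.
        gamma: The discount factor, between 0 and 1.

    Returns:
        The sum of gamma**t * r_t over all time steps t.
    """
    return discounted_return_numpy(rewards, gamma)
```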

    enhancement 
    opened by frederikschubert 0
  • Add Type Annotations

    Add Type Annotations

    This issue tracks the progress of adding type annotations to coax.

    • [ ] _core
    • [ ] experience_replay
    • [ ] model_updaters
    • [ ] policy_objectives
    • [ ] proba_dists
    • [ ] reward_tracing
    • [ ] td_learning
    • [ ] utils
    • [ ] value_transforms
    • [ ] wrappers

    The types are collected with pyannotate by adding the following snippet to the coax._base.TestCase class:

    import sys

    from pyannotate_runtime import collect_types

    # methods added to the body of coax._base.TestCase
    ...

    @classmethod
    def setUpClass(cls) -> None:
        # start recording the argument and return types seen while the tests run
        collect_types.init_types_collection()
        collect_types.start()

    @classmethod
    def tearDownClass(cls) -> None:
        collect_types.stop()
        # map internal type names onto their public equivalents
        type_replacements = {
            "jaxlib.xla_extension.DeviceArray": "jax.numpy.ndarray",
            "haiku._src.data_structures.FlatMapping": "typing.Mapping",
            "coax._core.policy_test": "gym.Env"
        }
        types_str = collect_types.dumps_stats()
        for inferred_type, replacement in type_replacements.items():
            types_str = types_str.replace(inferred_type, replacement)
        # e.g. coax/_core/policy_test.py -> coax/_core/policy_test_types.json
        with open(sys.modules[cls.__module__].__file__.replace(".py", "_types.json"), "w") as f:
            f.write(types_str)
    ...
    

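    The collected types are then written into the source files automatically:

    shopt -s globstar  # bash: let ** match files in nested directories
    for t in coax/**/*_test_types.json
    do
        pyannotate --type-info "$t" -3 coax/* -w
    done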
    
    enhancement 
    opened by frederikschubert 0
  • PPOClip grad update seems to cause inf update

    PPOClip grad update seems to cause inf update

    Describe the bug

    Hey Kris, love your framework! I'm working with a custom environment, and your discrete-action unit test works perfectly locally. Don't spend much time investigating this yet; I'm just creating this in case something jumps out at you as the problem. I plan to keep debugging this issue.

    During the first PPOClip update with the custom gym environment, the model weights are changed to +/-inf even though the gradients are finite.

    Expected behavior

    ...
    adv = np.random.rand(32)
    grads, function_state, metrics = ppo_clip.grads_and_metrics(transition_batch, Adv=adv)
    print("grads", grads)
    print(ppo_clip._pi.params)
    metrics_pi = ppo_clip.update(transition_batch, Adv=adv) # This is the problem
    print(ppo_clip._pi.params)
    

    Results in:

    grads FlatMapping({
      'linear': FlatMapping({
                  'b': DeviceArray([ 0.0477 , -0.02505, -0.05048,  0.02798], dtype=float16),
                  'w': DeviceArray([[ 0.01338 , -0.01921 , -0.01038 ,  0.01622 ],
                                    [ 0.02406 , -0.01683 , -0.02039 ,  0.01316 ],
                                    [ 0.0332  , -0.0227  , -0.03108 ,  0.02061 ],
                                    ...,
                                    [ 0.02452 , -0.00956 , -0.01997 ,  0.005024],
                                    [ 0.010025,  0.001724, -0.03467 ,  0.02295 ],
                                    [ 0.01886 , -0.01413 , -0.01494 ,  0.01022 ]], dtype=float16),
                }),
    
    FlatMapping({
      'linear': FlatMapping({
                  'w': DeviceArray([[-1.0124e-02,  3.4389e-03,  2.9316e-03,  6.5498e-03],
                                    [ 3.3302e-03, -1.7233e-03, -3.0422e-03, -1.8060e-04],
                                    [-2.8908e-05, -3.3131e-03, -6.1073e-03,  6.5804e-03],
                                    ...,
                                    [-2.5597e-03,  7.3471e-03, -3.6221e-03, -5.6801e-03],
                                    [-7.3471e-03, -3.7746e-03,  5.8746e-03,  6.1531e-03],
                                    [-1.1940e-03,  6.9733e-03, -5.0507e-03,  3.4218e-03]],            dtype=float16),
                  'b': DeviceArray([0., 0., 0., 0.], dtype=float16),
                }),
    })
    
    FlatMapping({
      'linear': FlatMapping({
                  'b': DeviceArray([-0.001002,  0.000978,  0.001001, -0.001007], dtype=float16),
                  'w': DeviceArray([[-0.01111  ,  0.004448 ,  0.00386  ,  0.00551  ],
                                    [ 0.002354 , -0.0007563, -0.002048 , -0.001162 ],
                                    [-0.001021 , -0.002335 , -0.005104 ,  0.005558 ],
                                    ...,
                                    [-0.003561 ,  0.008224 , -0.002628 ,       -inf],
                                    [-0.00828  ,       -inf,  0.006874 ,  0.00515  ],
                                    [-0.002203 ,  0.00804  , -0.004086 ,  0.002493 ]],            dtype=float16),
                }),
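    All of the printed arrays have dtype=float16. One hypothesis worth ruling out, not a confirmed diagnosis: optax's adam uses eps=1e-8 by default, which is below half of float16's smallest subnormal (about 6e-8) and therefore rounds to zero. If a gradient entry is small enough that its second-moment estimate (1 - b2) * g**2 also underflows to zero, the first Adam step divides a nonzero first moment by zero and produces inf. The NumPy sketch below computes a single bias-corrected Adam step by hand to show the effect; it does not call optax, and optax's internal evaluation order may differ.

```python
import numpy as np


def adam_first_step(grad, lr=3e-4, b1=0.9, b2=0.999, eps=1e-8, dtype=np.float16):
    """Parameter update from the very first Adam step, computed entirely in dtype."""
    g = np.asarray(grad, dtype=dtype)
    one = dtype(1)
    m = (one - dtype(b1)) * g          # first-moment estimate
    v = (one - dtype(b2)) * g * g      # second-moment estimate (can underflow to 0)
    m_hat = m / (one - dtype(b1))      # bias correction at step t = 1
    v_hat = v / (one - dtype(b2))
    with np.errstate(divide="ignore", invalid="ignore"):
        step = m_hat / (np.sqrt(v_hat) + dtype(eps))  # eps itself is 0 in float16
    return -dtype(lr) * step
```

    In this sketch a gradient of 1e-3 gives -inf in float16 but a finite update in float32, while a larger gradient such as 0.05 stays finite in both. If this turns out to be the cause, keeping the parameters in float32 or passing a larger eps would be the first things to try.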
    

    Here is the full repro script, adapted from the Pong PPO example. It won't run as-is because it depends on the custom environment. This is a dummy example, not the actual policy and value networks that would be used:

    import os
    from luxai2021.env.lux_env import LuxEnvironment, LuxEnvironmentTeam
    from luxai2021.game.game import Game
    from luxai2021.game.actions import *
    from luxai2021.game.constants import LuxMatchConfigs_Default
    
    from luxai2021.env.agent import Agent, AgentWithTeamModel
    import numpy as np
    
    from agent import TeamAgent
    
    # set some env vars
    os.environ.setdefault('JAX_PLATFORM_NAME', 'cpu')     # tell JAX to use CPU
    os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.1'  # don't use all gpu mem
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'              # tell XLA to be quiet
    
    import gym
    import jax
    import coax
    import haiku as hk
    import jax.numpy as jnp
    from optax import adam
    
    
    # the name of this script
    name = 'ppo'
    
    configs = LuxMatchConfigs_Default
    
    player = TeamAgent(mode="train")
    opponent = Agent()
    
    env = LuxEnvironment(configs=configs,
                         learning_agent=player,
                         opponent_agent=opponent)
    env = coax.wrappers.TrainMonitor(env, name=name, tensorboard_dir=f"./data/tensorboard/{name}")
    
    def func_pi(S, is_training):
        n_actions = 4
        out = {'logits': hk.Linear(n_actions)(hk.Flatten()(S)) }
        return out
    
    def func_v(S, is_training):
        h = jnp.ravel(hk.Linear(1)(hk.Flatten()(S)))
        return h
    
    '''
    def func_pi(S, is_training):
        #print(env.action_space.shape)
        n_filters = 5
        n_actions = 4
        n_layers = 3
    
        h = hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(S)
        for layer in range(n_layers):
            h = jax.nn.relu(h + hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(h))
        
        print('h', type(h), h.shape)
        h_head = (h * S[:,:1]).reshape(h.shape[0], h.shape[1], -1).sum(-1) # torch.Size([1, N_LAYERS])
        h_head_actions = hk.Linear(n_actions)(h_head)
        print('h_head_actions', type(h_head_actions), h_head_actions.shape)
        #print(h_head_actions)
    
        out = {'logits': h_head_actions}
        
        return out
    
    def func_v(S, is_training):
        n_filters = 5
        n_layers = 3
    
        h = hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(S)
        for layer in range(n_layers):
            h = jax.nn.relu(hk.Conv2D(n_filters, kernel_shape=3, stride=2, data_format='NCHW')(h))
    
        h = hk.Flatten()(h)
        h = jax.nn.relu(hk.Linear(64)(h))
        h = jnp.ravel(hk.Linear(1, w_init=jnp.zeros)(h))
        
        return h
    '''
    
    
    # function approximators
    pi = coax.Policy(func_pi, env)
    v = coax.V(func_v, env)
    
    # target networks
    pi_behavior = pi.copy()
    v_targ = v.copy()
    
    # policy regularizer (avoid premature exploitation)
    entropy = coax.regularizers.EntropyRegularizer(pi, beta=0.001)
    
    # updaters
    simpletd = coax.td_learning.SimpleTD(v, v_targ, optimizer=adam(3e-4))
    ppo_clip = coax.policy_objectives.PPOClip(pi, regularizer=entropy, optimizer=adam(3e-4))
    
    # reward tracer and replay buffer
    tracer = coax.reward_tracing.NStep(n=5, gamma=0.99)
    buffer = coax.experience_replay.SimpleReplayBuffer(capacity=256)
    
    # run episodes
    max_episode_steps = 400
    while env.T < 3000000:
        s = env.reset()
    
        for t in range(max_episode_steps):
            print(t)
            a, logp = pi_behavior(s, return_logp=True)
            s_next, r, done, info = env.step(a)
    
            # trace rewards and add transition to replay buffer
            tracer.add(s, a, r, done, logp)
            while tracer:
                buffer.add(tracer.pop())
    
            # learn
            if len(buffer) >= buffer.capacity:
                num_batches = int(4 * buffer.capacity / 32)  # 4 epochs per round
                for i in range(num_batches):
                    transition_batch = buffer.sample(32)
                    grads, function_state, metrics, td_error = simpletd.grads_and_metrics(transition_batch)
                    metrics_v, td_error = simpletd.update(transition_batch, return_td_error=True)
    
                    
                    adv = np.random.rand(32)
                    grads, function_state, metrics = ppo_clip.grads_and_metrics(transition_batch, Adv=adv)
                    print("grads", grads)
                    print(ppo_clip._pi.params)
                    metrics_pi = ppo_clip.update(transition_batch, Adv=adv) # This is the problem
                    print(ppo_clip._pi.params)
                    exit()
                    env.record_metrics(metrics_pi)
                    env.record_metrics(metrics_v)
                    
    
                buffer.clear()
    
                # sync target networks
                pi_behavior.soft_update(pi, tau=0.1)
                v_targ.soft_update(v, tau=0.1)
    
            if done:
                break
    
            s = s_next
    
        # generate an animated GIF to see what's going on
        if env.period(name='generate_gif', T_period=10000) and env.T > 50000:
            T = env.T - env.T % 10000  # round to 10000s
            coax.utils.generate_gif(
                env=env, policy=pi, resize_to=(320, 420),
                filepath=f"./data/gifs/{name}/T{T:08d}.gif")
    
    
    opened by glmcdona 3
Releases(v0.1.12)