Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.

Overview

Diffrax

Diffrax is a JAX-based library providing numerical differential equation solvers.

Features include:

  • ODE/SDE/CDE (ordinary/stochastic/controlled) solvers;
  • lots of different solvers (including Tsit5, Dopri8, symplectic solvers, implicit solvers);
  • vmappable everything (including the region of integration);
  • using a PyTree as the state;
  • dense solutions;
  • multiple adjoint methods for backpropagation;
  • support for neural differential equations.

From a technical point of view, the internal structure of the library is pretty cool -- all kinds of equations (ODEs, SDEs, CDEs) are solved in a unified way (rather than being treated separately), producing a small tightly-written library.

Installation

pip install diffrax

Requires Python >=3.7 and JAX >=0.2.27.

Documentation

Available at https://docs.kidger.site/diffrax.

Quick example

from diffrax import diffeqsolve, ODETerm, Dopri5
import jax.numpy as jnp

def f(t, y, args):
    return -y

term = ODETerm(f)
solver = Dopri5()
y0 = jnp.array([2., 3.])
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)

Here, Dopri5 refers to the Dormand--Prince 5(4) numerical differential equation solver, which is a standard choice for many problems.

Citation

If you found this library useful in academic research, please cite: (arXiv link)

@phdthesis{kidger2021on,
    title={{O}n {N}eural {D}ifferential {E}quations},
    author={Patrick Kidger},
    year={2021},
    school={University of Oxford},
}

(Also consider starring the project on GitHub.)

Comments
  • [WIP] Delay differential equations

    @thibmonsel

    This is a quick WIP draft of how we might add support for delay diffeqs into Diffrax.

    The goal is to make the API follow:

    def vector_field(t, y, args, *, history):
        ...
    
    delays = [lambda t, y, args: 1.0,
              lambda t, y, args: max(y, 1)]
    
    diffeqsolve(ODETerm(vector_field), ..., delays=delays)
    

    There are several pieces that still need doing:

    • The nonlinear solve, with respect to the dense solution over each step. (E.g. as per Section 4.1 of the DelayDiffEq.jl paper)
    • Detecting discontinuities and stepping to them directly. (Section 4.2)
    • Possibly add special support for "nice" delays, that we might be able to handle more efficiently? E.g. as long as our minimal delay is larger than our step size then the nonlinear solve can be skipped.
    • Adding documentation.
    • Adding an example.
    • Probably now would be a good time to figure out how to add support for solving DAEs as well (e.g. see #62). Both involve a nonlinear solve, and both involve passing extra information to the user-provided vector field. It might be that we can make use of the same mechanisms for both. (And at the very least we should ensure that any choices we make now don't negatively impact DAE support later.)
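
The "nice delays" case in the list above can be illustrated with a small pure-Python sketch. It is not Diffrax code and not the proposed API: the function name `solve_constant_delay`, the test equation y'(t) = -y(t - tau), and the use of explicit Euler are all assumptions made for illustration. The sketch also assumes the delay is a whole number of steps.

```python
def solve_constant_delay(history, tau, t1, dt):
    """Explicit Euler for y'(t) = -y(t - tau), with y(t) = history(t) for t <= 0."""
    lag = round(tau / dt)  # delay measured in whole steps
    if lag < 1:
        raise ValueError("this sketch needs tau >= dt")
    num_steps = round(t1 / dt)
    ys = [history(0.0)]
    for n in range(num_steps):
        k = n - lag
        # The delayed value is either prescribed history or an already-computed
        # step, so no implicit (nonlinear) solve is needed.
        y_delayed = ys[k] if k >= 0 else history(k * dt)
        ys.append(ys[n] + dt * (-y_delayed))
    return ys

ys = solve_constant_delay(history=lambda t: 1.0, tau=1.0, t1=2.0, dt=0.01)
y_at_1 = ys[100]  # exact solution on [0, 1] is 1 - t
y_at_2 = ys[200]  # exact solution at t = 2 is -0.5
```

Because the delay never reaches into the step currently being computed, every delayed value is already known, which is why the nonlinear solve can be skipped in this regime.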
    opened by patrick-kidger 24
  • Can't return solution of coupled differential equations

    I'm trying to solve a mid-sized system of coupled differential equations with diffrax. I'm using version 0.2.0. Here's a short snippet of dummy code that raises the issue I'm having:

    import jax.numpy as jnp
    from diffrax import diffeqsolve, ODETerm, Kvaerno3,PIDController
    
    def Results():
        def Y_prime(t, Y, args):
            dY = jnp.array([Y[6], (Y[5]-Y[6])**2,Y[0]+Y[7], (Y[1])**2, Y[2],Y[3], Y[4]**3, Y[5]**2])
            return dY
            
        t_init = 100
        t_fin = 1e5
    
        Yn_i = 1e-5
        Yp_i = 1e-6
        Yd_i = 1e-12
        Yt_i = 1e-12
        YHe3_i = 1e-12
        Ya_i = 1e-12
        YLi7_i = 1e-12
        YBe7_i = 1e-12
    
        Y0=jnp.array([[Yn_i], [Yp_i], [Yd_i], [Yt_i], [YHe3_i], [Ya_i], [YLi7_i], [YBe7_i]])
        term = ODETerm(Y_prime)
        solver = Kvaerno3()
        stepsize_controller = PIDController(rtol=1e-8, atol=1e-8)
        t_eval = jnp.logspace(jnp.log10(t_init),jnp.log10(t_fin),num=100)
        sol_at_MT = diffeqsolve(term, solver, t0=jnp.float64(t_init), t1=jnp.float64(t_fin), dt0=jnp.float64((t_eval[1]-t_eval[0])/10),y0=Y0,stepsize_controller=stepsize_controller,max_steps=None)
        Yn_MT_f, Yp_MT_f, Yd_MT_f, Yt_MT_f, YHe3_MT_f, Ya_MT_f, YLi7_MT_f, YBe7_MT_f = sol_at_MT.ys[-1][0][0],sol_at_MT.ys[-1][1][0],sol_at_MT.ys[-1][2][0],sol_at_MT.ys[-1][3][0],sol_at_MT.ys[-1][4][0],sol_at_MT.ys[-1][5][0],sol_at_MT.ys[-1][6][0],sol_at_MT.ys[-1][7][0]
    
        Yn_f,Yp_f,Yd_f,Yt_f,YHe3_f,Ya_f,YLi7_f,YBe7_f = Yn_MT_f, Yp_MT_f, Yd_MT_f,Yt_MT_f,YHe3_MT_f,Ya_MT_f,YLi7_MT_f, YBe7_MT_f
        return jnp.array([Yn_f,Yp_f,Yd_f,Yt_f,YHe3_f,Ya_f,YLi7_f,YBe7_f])
    Yn_f,Yp_f,Yd_f,Yt_f,YHe3_f,Ya_f,YLi7_f,YBe7_f = Results()
    print(Yn_f)
    

    It seems diffrax successfully solves the differential equation, but struggles to return the output, i.e. it seems the code hangs when trying to assign values to the variable sol_at_MT. Tampering a bit with the diffrax source, it looks like there are two things going on.

    One is that, no matter what I try to return (even if I set all of the returns to None), if the lines right before the return in integrate.py

    branched_error_if(
        throw & jnp.invert(is_okay(result)),
        error_index,
        RESULTS.reverse_lookup,
    )
    

    aren't commented out, the code will freeze. I can include a print statement right after these lines (just before the return) that prints out successfully even when they're not commented, but I can't assign anything to sol_at_MT without the code hanging if these lines are left in.

    Then, if I comment that branched_error_if() call out, the code still hangs if I try to return ts, ys, stats or result from integrate.py. This doesn't seem to be an issue of time or memory; the code just freezes up and can't even be aborted from the command line whether I'm running locally or with extra resources on a cluster.

    question 
    opened by cgiovanetti 12
  • Handling discontinuities in time derivative?

    Hi, first of all, let me say that this looks like an amazing project. I am looking forward to playing around with this :).

    In a concrete problem I am dealing with, I have a forced system where the external force is piecewise constant. The external force changes at specific time points (t1, ..., tn), causing a discontinuity of the time derivative.
    I would like to use adaptive step-size solvers for increased accuracy, but naively applying adaptive step-size solvers will "waste" a lot of steps to find the point of change.

    Would including the change points in SaveAt avoid this problem? Or is there some other recommended way to handle this?

    opened by jaschau 12
  • Slow `jit` compilation time compared to `jax.experimental.ode.odeint`

    Hi @patrick-kidger, big fan of diffrax!

    I've been playing around with some of the functionality you have in this repository and comparing it with the ODE solver in jax. The one pain point I noticed is that jit compilation seems relatively slow, particularly when I try to jit the grad of a simple loss function containing diffeqsolve. I was wondering if this is an error on my part (perhaps I botched the diffrax implementation) or if some optimization is still to come. The demonstration is below:

    from jax.config import config
    config.update("jax_enable_x64", True)
    config.update("jax_debug_nans", True) 
    config.parse_flags_with_absl()
    import jax
    import jax.numpy as jnp
    from jax import random
    import numpy as np
    from functools import partial
    import haiku as hk
    
    def exact_kinematic_aug_diff_f(t, y, args_tuple):
        """
        """
        _y, _, _ = y
        _params, _key, diff_f = args_tuple
        aug_diff_fn = lambda __y : diff_f(t, __y, (_params,))
        _f, s, t = aug_diff_fn(_y)
        r = jnp.sum(t)
        return _f, r, 0.
    
    def exact_kinematic_odeint_diff_f(y, t, params, canonical_diff_fn):
        run_y = y[0]
        _f, s, t = canonical_diff_fn(t, run_y, (params,))
        return _f, jnp.sum(s), 0.
    
    class TestMLP(hk.Module):
        def __init__(self, num_particles, name=None):
            super().__init__(name=None)
            self._mlp = hk.nets.MLP([8,8,8,8,num_particles*12])
            self._num_particles=num_particles
        def __call__(self, t, y):
            in_y = (y + t).flatten()
            outter = self._mlp(in_y).reshape((4, self._num_particles, 3))
            return outter[:2], outter[2], outter[3]
    
    def test(num_particles):
        import functools
        from jax.experimental.ode import odeint
        import diffrax
        
        #generate positions/velocities
        small_positions = jax.random.normal(jax.random.PRNGKey(261), shape=(num_particles,3))
        small_velocities = jax.random.normal(jax.random.PRNGKey(235), shape=(num_particles,3))
        small_positions_and_velocities = jnp.vstack([small_positions[jnp.newaxis, ...], small_velocities[jnp.newaxis, ...]])
        
        # make module kwargs
        VectorMLP_kwargs = {'num_particles': num_particles}
        
        # make module function
        def _diff_f_wrapper(t, y):
            diff_f = TestMLP(**VectorMLP_kwargs)
            return diff_f(t, y)
        
        diff_f_init, diff_f_apply = hk.without_apply_rng(hk.transform(_diff_f_wrapper))
        init_params = diff_f_init(jax.random.PRNGKey(36), 0., small_positions_and_velocities)
        canonicalized_diff_f_fn = lambda _t, _y, _args_tuple : diff_f_apply(_args_tuple[0], _t, _y)
        
        # make the augmented functions
        odeint_aug_diff_func = functools.partial(exact_kinematic_odeint_diff_f, canonical_diff_fn=canonicalized_diff_f_fn)
        diffeqsolve_aug_diff_func = exact_kinematic_aug_diff_f
        
        # odeint solver
        def odeint_solver(_parameters, _init_y, _key):
            aug_init_y = (_init_y, 0., 0.)
            outs = odeint(odeint_aug_diff_func, aug_init_y, jnp.array([0., 1.]), _parameters, rtol=1.4e-8, atol=1.4e-8)
            final_outs = (outs[0][-1], outs[1][-1], outs[2][-1])
            return final_outs
        
        def diffrax_ode_solver(_parameters, _init_y, _key):
            term=diffrax.ODETerm(diffeqsolve_aug_diff_func)
            stepsize_controller=diffrax.PIDController(rtol=1.4e-8, atol=1.4e-8)
            solver = diffrax.Dopri5()
            aug_init_y = (_init_y, 0., 0.)
            sol = diffrax.diffeqsolve(term, 
                                      solver, 
                                      t0=0., 
                                      t1=1., 
                                      dt0=1e-1, 
                                      y0=aug_init_y, 
                                      stepsize_controller=stepsize_controller, 
                                      args=(_parameters, _key, canonicalized_diff_f_fn))
            return sol.ys[0][0], sol.ys[1][0], sol.ys[2][0]
        
        @jax.jit
        def odeint_loss_fn(_params, _init_y, _key):
            ode_solution = odeint_solver(_params, _init_y, _key)
            return jnp.sum(ode_solution[1]**2)
        
        @jax.jit
        def diffrax_loss_fn(_params, _init_y, _key):
            ode_solution = diffrax_ode_solver(_params, _init_y, _key)
            return jnp.sum((ode_solution[1])**2)
        
        # test
        import time
        
        # odeint compilation time
        start_time = time.time()
        _ = jax.grad(odeint_loss_fn)(init_params, small_positions_and_velocities, jax.random.PRNGKey(34))
        end_time = time.time()
        print(f"odeint comp. time: {end_time - start_time}")
        
        # diffrax compilation time
        start_time = time.time()
        _ = jax.grad(diffrax_loss_fn)(init_params, small_positions_and_velocities, jax.random.PRNGKey(34))
        end_time = time.time()
        print(f"diffrax comp. time: {end_time - start_time}")
    
    

    Running test(8) gives me the following compilation time on CPU:

    odeint comp. time: 2.5580570697784424
    diffrax comp. time: 23.965799570083618
    

    I noticed that if I use diffrax.BacksolveAdjoint, compilation time goes down to ~8 seconds, but I'm keen to avoid that method based on your docs. Also, it looks like the compilation time in diffrax is heavily dependent on the number of hidden layers in TestMLP, perhaps suggesting a non-optimal compilation of for loops in diffrax? Thanks!
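
The timings above include tracing, compilation, and the first execution in a single number. A minimal sketch of separating them is to time the first and second calls of the same jitted function. The `loss` here is a stand-in rather than the solver-based loss from the report, and `timed` is a hypothetical helper.

```python
import time
import jax
import jax.numpy as jnp

def loss(params, x):
    # Stand-in for a loss that contains a differential equation solve.
    return jnp.sum(jnp.tanh(params * x) ** 2)

grad_fn = jax.jit(jax.grad(loss))
params = jnp.array(0.5)
x = jnp.linspace(0.0, 1.0, 16)

def timed(fn, *args):
    start = time.perf_counter()
    out = fn(*args)
    out.block_until_ready()  # wait for asynchronous dispatch to finish
    return out, time.perf_counter() - start

first_out, first_time = timed(grad_fn, params, x)    # trace + compile + run
second_out, second_time = timed(grad_fn, params, x)  # run only, cached executable
compile_estimate = first_time - second_time
```

The difference between the two timings is only a rough estimate of compilation cost, but it avoids attributing run time to the compiler.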

    refactor next 
    opened by dominicrufa 11
  • No GPU/TPU found, falling back to CPU

    Here's the full warning that I get (I do have a GPU):

    >>> import diffrax
    2022-03-24 16:30:19.350737: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:170] XLA service 0x55795c0d4670 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
    2022-03-24 16:30:19.350761: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:178]   StreamExecutor device (0): Interpreter, <undefined>
    2022-03-24 16:30:19.353414: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc:169] TfrtCpuClient created.
    2022-03-24 16:30:19.353886: I external/org_tensorflow/tensorflow/stream_executor/tpu/tpu_platform_interface.cc:74] No TPU platform found.
    WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
    

    Edit: I installed diffrax from conda-forge.

    opened by ma-sadeghi 11
  • Logging metrics during an ODE solve

    Hello @patrick-kidger,

    thank you for open-sourcing this nice library! I was going to resume work on my own small ODE lib, but since this is much more elaborate than what I came up with so far, I am inclined to use this instead for a small project in the future.

    One question that came up to me when reading the source code: Is there currently a way to compute step-wise metrics during the solve? (Think logging step sizes, Jacobian eigenvalues, etc.)

    This would presumably happen in the integrate method. Could I e.g. use the solver_state pytree for this in, say, overridden solver classes? Thank you for your consideration.

    opened by nicholasjng 11
  • Brownian motion classes accept pytrees for shape and dtype arguments

    This changes the argument shape for classes VirtualBrownianTree and UnsafeBrownianPath, and adds an additional argument dtype as per the discussion in #180.

    • I decided upon shape: PyTree[Tuple[int, ...]] instead of shape: Union[Tuple[int, ...], PyTree[jax.ShapeDtypeStruct]]. It's unclear what to do with named_shape in jax.ShapeDtypeStruct -- I don't know if there is a way to sample Brownian motion via named shapes. But if you feel strongly about this and give me some pointers, I can reimplement.
    • To allow specifying dtypes, the dtype argument specifies them as a pytree and has to be a prefix tree of shape.
    • I added __init__ methods to both classes since I was not sure how to have dtype=None without it.
    • Added some helper functions that I use in misc.py, hope that's the right location to place them.
    • Used jtu.tree_map instead of jax.vmap -- was not sure how to supply is_leaf to jax.vmap. Happy to change this as well, with some pointers.
    • Changed the test_brownian.py:test_shape to test pytree shapes and dtypes. Just noticed that formatting made it look pretty bad, not sure if that's a big deal.
    • Tests pass locally.

    Let me know what you think. Thanks!

    opened by ciupakabra 9
  • added new kalman-filter example

    I wrote a little additional example that showcases diffrax in a maybe not so obvious way. It also showcases equinox and the ability to freeze parameters. Let me know what you think (and what needs to be changed). Greetings

    opened by SimiPixel 8
  • Performance against `jax.experimental.ode.odeint`

    Hi @patrick-kidger, I was excited to test out Diffrax in our code. However, we found it did not perform as well as expected. This is likely due to nuances on our end, but because of https://github.com/google/jax/issues/9654 I thought I would post a MWE.

    import diffrax
    import jax
    import ticktack
    
    PARAMS = (774.86, 0.25, 0.8, 6.44)
    
    STEADY_PROD = 1.8803862513018528
    
    STEADY_STATE = jax.numpy.array(
        [1.34432991e+02, 7.07000000e+02, 1.18701144e+03,
        3.95666872e+00, 4.49574232e+04, 1.55056740e+02,
        6.32017337e+02, 4.22182768e+02, 1.80125397e+03,
        6.63307283e+02, 7.28080320e+03], 
        dtype=jax.numpy.float64)
    
    PROD_COEFFS = jax.numpy.array(
        [0.7, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
        dtype=jax.numpy.float64)
    
    MATRIX = jax.numpy.array([
        [-0.509, 0.009, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.508, -0.44, 0.068, 0.0, 0.0, 0.545, 0.0, 0.167, 0.002, 0.002, 0.0],
        [0.0, 0.121, -0.155, 12.0, 0.001, 0.0, 0.0, 0.003, 0.0, 0.0, 0.0],
        [0.0, 0.0, 4.4000e-02, -1.3333e+01, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.042, 1.333, -0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.229, 0.0, 0.0, 0.0, -1.046, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.136, -0.033, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.364, 0.033, -0.183, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.01, -0.002, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003, 0.0, -0.002, 0.0],
        [0.0, 0.0, 3.333e-04, 0.0, 5.291e-06, 0.0, 0.0, 0.0, 0.0, 4.0e-04, -1.2340e-04]], 
        dtype=jax.numpy.float64)
    
    @jax.jit 
    def driving_term(t, args):
        start_time, duration, phase, area = jax.numpy.array(args)
        middle = start_time + duration / 2.
        height = area / duration
    
        gauss = height * \
            jax.numpy.exp(- ((t - middle) / (0.5 * duration)) ** 16.)
        sine = STEADY_PROD + 0.18 * STEADY_PROD *\
            jax.numpy.sin(2 * jax.numpy.pi / 11 * t + phase * 2 * jax.numpy.pi / 11)
    
        return (sine + gauss) * 3.747
    
    @jax.jit
    def jax_dydt(y, t, args, /, matrix=MATRIX, production=driving_term, 
                       prod_coeffs=PROD_COEFFS):
        ans = jax.numpy.matmul(matrix, y)
        production_rate_constant = production(t, args)
        production_term = prod_coeffs * production_rate_constant
        return ans + production_term
    
    @jax.jit
    def diffrax_dydt(t, y, args, /, matrix=MATRIX, production=driving_term, 
                     prod_coeffs=PROD_COEFFS):
        ans = jax.numpy.matmul(matrix, y)
        production_rate_constant = production(t, args)
        production_term = prod_coeffs * production_rate_constant
        return ans + production_term
    
    time_out = jax.numpy.linspace(750, 800, 1000)
    
    %%timeit
    jax.experimental.ode.odeint(jax_dydt, STEADY_STATE, time_out, PARAMS)
    
    term = diffrax.ODETerm(diffrax_dydt)
    solver = diffrax.Bosh3()
    step_size = diffrax.PIDController(rtol=1e-10, atol=1e-10)
    save_time = diffrax.SaveAt(ts=time_out)
    
    %%timeit
    diffrax.diffeqsolve(args=PARAMS, terms=term, solver=solver, y0=STEADY_STATE,
                        t0=time_out.min(), t1=time_out.max(), dt0=0.01,
                        saveat=save_time, stepsize_controller=step_size, 
                        max_steps=10000)
    

    Sorry that the example is so voluminous but I wanted to keep it very similar to our code.

    Thanks in advance.

    Jordan

    opened by Jordan-Dennis 8
  • Weird behaviour due to defaults when using Implicit-Euler

    When using dfx.ImplicitEuler() with everything set to default an error is raised

    missing rtol and atol of NewtonNonlinearSolver

    You are then prompted to set these values in the stepsize-controller, because by default it is supposed to fall back to the values provided in PIDController. But dfx.ImplicitEuler() does not support adaptive step-sizing using a PIDController.

    The solution is to use

    solver=dfx.ImplicitEuler(nonlinear_solver=dfx.NewtonNonlinearSolver(rtol=1e-3, atol=1e-6))
    

    Just something that feels a bit odd.

    refactor 
    opened by SimiPixel 6
  • Transform Feedforward-Network + solver into a Recurrent-Network

    Hello Patrick,

    let me first quickly motivate my feature request. As a side project I am currently working on model-based optimal control. Stateful agents are useful in, for example, an only partially observable environment. So, suppose the action selection of an agent is given by the following method

    def select_action(params, state, observation, time):
        apply = neural_network.apply
        state, action = apply(params, state, observation, time)
        return state, action
    
    while True:
        action = select_action(..., observation, env.time)
        observation = env.step(action)
    

    Typically, the apply-function is some recurrent neural network. Suppose the environment env is differentiable, because it is just some model of the environment (maybe another network). Now, I would like to replace the recurrent neural network with a feedforward network + solver without changing the API of the agent.

    I was wondering if constructing the following is possible and sensible? I.e. I would like to transform a choice of Feedforward-Network + Solver into a Recurrent-Network.

    def select_action(params, ode_state, observation, time):
        rhs = lambda x,u: neural_network.apply(params, x, u)
        solution, ode_state = odeint(ode_state, rhs, t1=time, u=(observation, time))
        return ode_state, solution.x(time)
    

    I would like to emphasize that this select_action must remain differentiable: the x-output w.r.t. the network parameters.

    I would love to hear your input :) Anyways thank you in advance.

    opened by SimiPixel 5
  • ODE solver fails with 'The maximum number of solver steps was reached. Try increasing `max_steps`'

    Hi,

    I was playing with this cool package on a chemical reaction ODE problem. This problem solves the time evolution of seven chemical concentrations, which is a stiff problem but can be solved using a Fortran-based solver. However, the diffrax version fails, with an XlaRuntimeError complaining 'The maximum number of solver steps was reached. Try increasing max_steps'. Unfortunately, the error persists no matter how large max_steps is and which solver is used (e.g., ImplicitEuler or Kvaerno5). Note that when commenting out the error message in the diffeqsolve function, I find that the code can solve roughly the first 100 s and then outputs inf (in solution.ys) at later times.

    Any suggestion would be appreciated!

    Below is the code snippet --

    from diffrax import diffeqsolve, ODETerm, SaveAt
    from diffrax import NewtonNonlinearSolver, Dopri5, Kvaerno3, ImplicitEuler, Euler, Kvaerno5
    from diffrax import PIDController
    
    import jax
    import jax.numpy as jnp
    import jax.random as jrandom
    
    from jax.config import config
    config.update("jax_enable_x64", True)
    
    def funclog2(t, logy, args):
        k1, k2, k3 = args[0], args[1], args[2]
        kd1, kd2, kd3 = args[3], args[4], args[5]
        ka1, ka2, ka3 = args[6], args[7], args[8]
        r4 = args[9]
        
        y = jnp.power(10, logy)
        doc, o2, no3, no2, n2, co2, bm = y
        
        # log transform scale
        scale = 1 / jnp.log(10)
        scale = scale / y
        
        # The stoichiometry matrix
        stoich = jnp.array([
            [-1, -1, -1, 5],
            [0, 0, -1, 0],
            [-2, 0, 0, 0],
            [1, -1, 0, 0],
            [0, 1, 0, 0],
            [1, 1, 1, 0],
            [0, 0, 0, -1]
        ])
        
        # Scale stoich
        stoich = jax.vmap(lambda a, b: a*b, in_axes=0)(scale, stoich)
        
        # Reaction rate
        r1 = k1 * bm * doc/(doc+kd1) * no3/(no3+ka1)
        r2 = k2 * bm * doc/(doc+kd2) * no2/(no2+ka2)
        r3 = k3 * bm * doc/(doc+kd3) * o2/(no2+ka3)
        
        r = jnp.array([r1, r2, r3, r4]).T
        
        return stoich @ r
    
    # Static parameters
    k1, k2, k3 = 3.24e-4, 2.69e-4, 9e-4 # [mol/L/sec/mass [BM]]
    kd1, kd2, kd3 = 2.5e-4, 2.5e-4, 2.5e-4 # [mol/L]
    ka1, ka2, ka3 = 1e-6, 4e-6, 1e-6  # [mol/L]
    r4 = 2.8e-6 # [mol/L/sec]
    args = jnp.array([k1, k2, k3, kd1, kd2, kd3, ka1, ka2, ka3, r4])
    
    # The initial concentrations with the following order [mol/L]:
    # doc, o2, no3, no2, n2, co2, bm
    # y0 = jnp.array([4.16e-05, 0.000266, 0.000396, 1e-10, 1e-10, 0.00248, 0.0003])
    y0 = jnp.array([4.16e-05, 0.000266, 0.000396, 1e-3, 1e-3, 0.00248, 0.0003])
    logy0 = jnp.log10(y0)
    
    term = ODETerm(funclog2)
    # solver = Dopri5()
    # solver = Euler()
    solver = Kvaerno5(NewtonNonlinearSolver(rtol=1e-3, atol=1e-6))
    # solver = ImplicitEuler(NewtonNonlinearSolver(rtol=1e-3, atol=1e-6))
    # t0, t1, dt0 = 0, 3600*24*30, 1
    t0, t1, dt0 = 0, 200, 0.01
    # t0, t1, dt0 = 0, 3600*24, 3600
    solution = diffeqsolve(term, solver, t0=t0, t1=t1, dt0=dt0, max_steps=400000,
                           stepsize_controller=PIDController(rtol=1e-3, atol=1e-6),
                           saveat = SaveAt(t0=True, ts=jnp.linspace(t0,t1)), 
                           y0=logy0, args=args)
    solution.stats
    
    question 
    opened by PeishiJiang 5
  • Truncated Back Propagation through time

    Hi, I was wondering if it is possible to integrate truncated back propagation through time (TBPTT) into Diffrax. I couldn't find any options for this in Diffrax or Equinox, nor could I find any implementation of TBPTT in the source code in integrate.py, but maybe I missed it. My best guess would be to write a custom adjoint class that would implement TBPTT, but I am not sure how to do this. My question is: would it be possible to (easily) implement TBPTT to train my NDEs and how should I approach this?

    feature 
    opened by sdevries0 1
  • Fastest way to evaluate a solution

    Hi, suppose I have a simple ODE that I solve with diffrax. What would be the fastest way to use the solution in another piece of code? I need to evaluate the solution on some points not known in advance, and I thought of generating a dense solution sol and then use its method evaluate on the points of interest, i.e. every time I need it, call sol.evaluate() on my points of interest (using vmap when needed). Is this the most efficient way, or shall I interpolate myself a fixed grid solution and create a jitted function that evaluates it on my points of interest?

    question 
    opened by marcofrancis 1
  • Make diffeqsolve convertable to TensorFlow

    Based on a talk on NODEs on YouTube I came across this package, and it looks perfect for a project we are planning (thanks for the great talk!). One of the platforms where we want to run our code does not support JAX/XLA/TensorFlow, just ONNX. I tried converting a simulation function to TensorFlow for later conversion to ONNX, but this fails because the unsupported unvmap_any is used (at compile time!) to deduce the number of iterations needed.

    Minimal example:

    import tensorflow as tf
    import jax.numpy as jnp
    import tf2onnx
    
    from diffrax import diffeqsolve, ODETerm, Euler
    from jax.experimental import jax2tf
    
    def simulate(y0):
        solution = diffeqsolve(
                terms=ODETerm(lambda t, y, a: -y), solver=Euler(),
                t0=0, t1=1, dt0=0.1, y0=y0)
        return solution.ys[0]
    
    # This works
    x = simulate(100)
    assert jnp.isclose(x, jnp.exp(-1)*100, atol=.1, rtol=.1)
    
    simulate_tf = tf.function(jax2tf.convert(simulate, enable_xla=False))
    
    # Does not work:
    # simulate_tf(100)
    # => NotImplementedError: TensorFlow interpretation rule for 'unvmap_any' not implemented
    
    # Also does not work:
    tf2onnx.convert.from_function(
            simulate_tf, input_signature=[tf.TensorSpec((), tf.float32)])
    # simulate_tf(100)
    # => NotImplementedError: TensorFlow interpretation rule for 'unvmap_any' not implemented
    

    For us, it would be really nice to use a GPU/TSP during training with jax, then transfer to this specific piece of hardware with just ONNX support for inference (at this point I don't need gradient calculation anymore). Of course, solving this might be completely outside the scope of the project and there are other solutions like writing the solvers from scratch or using existing solvers in TF/PyTorch.

    Currently my knowledge of JAX is limited (hopefully this will soon improve!). If this is the only function stopping Diffrax from being TensorFlow-convertible, maybe a small workaround could be possible. I'm also happy with an answer like 'no we don't do this' or 'send us a PR if you want to have this fixed'.

    feature 
    opened by llandsmeer 6
  • Question about BacksolveAdjoint through SemiImplicitEuler solver

    I am testing the adjoint method to calculate the gradients from a SemiImplicitEuler solver. I encountered errors when calculating the gradients using the BacksolveAdjoint method. Here is a working example. It would be great to have some suggestions.

    Thank you in advance!

    from diffrax import diffeqsolve, ODETerm, SemiImplicitEuler, SaveAt, BacksolveAdjoint
    import jax.numpy as jnp
    from jax import grad
    from matplotlib import pyplot as plt
    
    def drdt(t, v, args):
        return v
    
    def dvdt(t, r, args):
        return -args[0]*(r-args[1])
    
    terms = (ODETerm(drdt), ODETerm(dvdt))
    solver = SemiImplicitEuler()
    y0 = (jnp.array([1.0]), jnp.array([0.0]))
    saveat = SaveAt(ts=jnp.arange(0, 30, 0.1))
    
    def loss(y0):
        solution = diffeqsolve(terms, solver, t0=0, t1=30, dt0=0.0001, y0=y0, args=[1.0,0.0], saveat=saveat, max_steps=10000000, adjoint=BacksolveAdjoint())
        return jnp.sum(solution.ys[0])
    
    grads = grad(loss)(y0)
    print(grads)

    here is the error message:

    Traceback (most recent call last):
      File "test_harmonic.py", line 23, in <module>
        grads = grad(loss)(y0)
      File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/grad.py", line 482, in fn_bwd_wrapped
        out = fn_bwd(residuals, grad_diff_array_out, vjp_arg, *args, **kwargs)
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/adjoint.py", line 394, in _loop_backsolve_bwd
        state, _ = _scan_fun(state, val0, first=True)
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/adjoint.py", line 332, in _scan_fun
        _sol = diffeqsolve(
      File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/jit.py", line 82, in __call__
        return __self._fun_wrapper(False, args, kwargs)
      File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/jit.py", line 78, in _fun_wrapper
        dynamic_out, static_out = self._cached(dynamic, static)
      File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/jit.py", line 30, in fun_wrapped
        out = fun(*args, **kwargs)
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/integrate.py", line 858, in diffeqsolve
        final_state, aux_stats = adjoint.loop(
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/adjoint.py", line 499, in loop
        final_state, aux_stats = _loop_backsolve(
      File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/grad.py", line 509, in __call__
        out = self.fn_wrapped(
      File "/home/zwq2834/anaconda3/envs/diffsim_jax/lib/python3.8/site-packages/equinox-0.9.2-py3.8.egg/equinox/grad.py", line 443, in fn_wrapped
        out = self.fn(vjp_arg, *args, **kwargs)
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/adjoint.py", line 250, in _loop_backsolve
        return self._loop_fn(
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/integrate.py", line 497, in loop
        final_state = bounded_while_loop(
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/misc/bounded_while_loop.py", line 125, in bounded_while_loop
        return lax.while_loop(cond_fun, _body_fun, init_val)
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/misc/bounded_while_loop.py", line 118, in _body_fun
        _new_val = body_fun(_val, inplace)
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/integrate.py", line 137, in body_fun
        (y, y_error, dense_info, solver_state, solver_result) = solver.step(
      File "/home/zwq2834/development/DiffSim_Jax/diffrax/diffrax/solver/semi_implicit_euler.py", line 42, in step
        y0_1, y0_2 = y0
    ValueError: too many values to unpack (expected 2)
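The final frame of the traceback fails at `y0_1, y0_2 = y0` inside the semi-implicit Euler step, so the solver was handed a state that does not unpack into exactly two parts. One plausible reading, which is an assumption and not something the traceback confirms, is that the backward solve passes a state carrying extra components. The sketch below reproduces the failure mode in plain Python; `semi_implicit_euler_step` is a hypothetical stand-in written for illustration, not Diffrax's code.

```python
def semi_implicit_euler_step(f, g, y0, dt):
    """Hypothetical stand-in for one semi-implicit Euler step.

    The state y0 must be a pair: (position-like part, velocity-like part).
    """
    y0_1, y0_2 = y0              # raises ValueError if y0 is not a pair
    y1_1 = y0_1 + dt * f(y0_2)   # advance the first part using the old second part
    y1_2 = y0_2 + dt * g(y1_1)   # advance the second part using the new first part
    return (y1_1, y1_2)


# Harmonic oscillator: dx/dt = v, dv/dt = -x
def f(v):
    return v


def g(x):
    return -x


# A two-part state steps without error.
ok = semi_implicit_euler_step(f, g, (1.0, 0.0), dt=0.1)

# A state with a third component (assumed here purely for illustration)
# triggers the same message as in the traceback.
try:
    semi_implicit_euler_step(f, g, (1.0, 0.0, 0.0), dt=0.1)
    message = None
except ValueError as e:
    message = str(e)

print(ok)
print(message)
```

If this reading is right, the error comes from the shape of the state reaching the solver rather than from the vector field itself.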

    bug feature 
    opened by Chenghao-Wu 1
Releases(v0.2.2)
  • v0.2.2(Nov 15, 2022)

    Performance improvements

    • Now makes fewer vector field traces in several cases. (#172, #174)

    Fixes

    • Many documentation improvements.
    • Fixed several warnings about jax.{tree_map,tree_leaves,...} being moved to jax.tree_util.{tree_map,tree_leaves,...}. (Thanks @jacobusmmsmit!)
    • Fixed the step size controller choking if the error is ever NaN. (#143, #152)
    • Fixed some crashes due to JAX-internal changes. (If you've ever seen it throw an error about not knowing how to rewrite closed_call_p, it's this one.)
    • Fixed an obscure edge-case NaN on the backward pass. It occurred only if you were using an implicit solver with an adaptive step size controller, had a step rejected because the implicit solve failed to converge, and were also backpropagating with respect to the controller_state.
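
The NaN-error fix in the list above concerns adaptive step size control. As a rough illustration of why NaN needs explicit handling, here is a minimal plain-Python sketch of a step size rule. It is not Diffrax's implementation: the function `propose_step`, its parameters, and the choice to treat NaN as a rejected step are all assumptions made for this example.

```python
import math


def propose_step(dt, err, order=5, safety=0.9, min_factor=0.2, max_factor=10.0):
    """Return (accept, next_dt) from a scaled error estimate.

    A step is accepted when err <= 1. All names and defaults here are
    illustrative assumptions, not Diffrax's API.
    """
    if math.isnan(err):
        # Every comparison with NaN is False, and NaN would otherwise leak
        # into the next step size. Treat it as a very large error instead:
        # reject the step and shrink as far as allowed.
        return False, dt * min_factor
    accept = err <= 1.0
    if err == 0.0:
        factor = max_factor
    else:
        # Standard rule of thumb: scale by err^(-1/order), with a safety margin,
        # clipped to [min_factor, max_factor].
        factor = safety * err ** (-1.0 / order)
        factor = min(max_factor, max(min_factor, factor))
    return accept, dt * factor


nan_result = propose_step(0.1, float("nan"))
zero_result = propose_step(0.1, 0.0)
large_result = propose_step(0.1, 32.0)
```

With this rule a NaN error estimate leads to a smaller retry step instead of a NaN step size.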

    Other

    • Added a new Kalman filter example (#159) (Thanks @SimiPixel!)
    • Brownian motion classes accept pytrees for shape and dtype arguments (#183) (Thanks @ciupakabra!)
    • The main change is an internal refactor: a lot of functionality has moved from diffrax.misc to equinox.internal.

    New Contributors

    • @jacobusmmsmit made their first contribution in https://github.com/patrick-kidger/diffrax/pull/149
    • @SimiPixel made their first contribution in https://github.com/patrick-kidger/diffrax/pull/159
    • @ciupakabra made their first contribution in https://github.com/patrick-kidger/diffrax/pull/183

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.2.1...v0.2.2

  • v0.2.1(Aug 3, 2022)

    Autogenerated release notes as follows:

    What's Changed

    • Made is_okay,is_successful,is_event public by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/134
    • Fix implicit adjoints assuming array-valued state by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/136
    • Replace jax tree manipulation method that are being deprecated with jax.tree_util equivalents by @mahdi-shafiei in https://github.com/patrick-kidger/diffrax/pull/138
    • bump version by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/141

    New Contributors

    • @mahdi-shafiei made their first contribution in https://github.com/patrick-kidger/diffrax/pull/138

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.2.0...v0.2.1

  • v0.2.0(Jul 20, 2022)

    • Feature: event handling. In particular it is now possible to interrupt a diffeqsolve early. See the events page in the docs and the new steady state example.
    • Compilation time improvements:
      • The compilation speed of NewtonNonlinearSolver (and thus in practice also all implicit solvers like Kvaerno3 etc.) has been improved by a factor of roughly 1.5.
      • The compilation speed of all Runge--Kutta solvers can be dramatically reduced (~factor 3) by passing e.g. Dopri5(scan_stages=True). This may increase runtime slightly. At the moment the default is scan_stages=False for all solvers, but this default might change in the future.
    • Various documentation improvements.

    New Contributors

    • @jatentaki made their first contribution in https://github.com/patrick-kidger/diffrax/pull/121

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.1.2...v0.2.0

  • v0.1.2(May 18, 2022)

    The main change here is a minor technical one: Diffrax will no longer initialise the JAX backend as a side effect of being imported.


    Autogenerated release notes as follows:

    What's Changed

    • Removed explicit jaxlib dependency by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/93
    • switch error_if to python if (regarding google/jax/issues/10047) by @amir-saadat in https://github.com/patrick-kidger/diffrax/pull/99
    • Doc fixes by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/100
    • Bump version by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/107

    New Contributors

    • @amir-saadat made their first contribution in https://github.com/patrick-kidger/diffrax/pull/99

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.1.1...v0.1.2

  • v0.1.1(Apr 7, 2022)

    Diffrax uses some JAX-internal functionality that will shortly be deprecated in JAX. This release adds the appropriate support for both older and newer versions of JAX.


    Autogenerated release notes as follows:

    What's Changed

    • [JAX] Add MHLO lowerings in preparation for xla.lower_fun() removal by @hawkinsp in https://github.com/patrick-kidger/diffrax/pull/91
    • Bump version by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/92

    New Contributors

    • @hawkinsp made their first contribution in https://github.com/patrick-kidger/diffrax/pull/91

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.1.0...v0.1.1

  • v0.1.0(Mar 30, 2022)

    Autogenerated release notes as follows:

    What's Changed

    • Adjusted PIDController by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/89

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.0.6...v0.1.0

  • v0.0.6(Mar 29, 2022)

    Autogenerated release notes as follows:

    What's Changed

    • Symbolic regression text by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/79
    • Fixed edge case infinite loop on stiff-ish problems (+very bad luck) by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/86

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.0.5...v0.0.6

  • v0.0.5(Mar 21, 2022)

    Autogenerated release notes as follows:

    What's Changed

    • Doc tweaks by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/72
    • Added JIT wrapper to stiff ODE example by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/75
    • Added autoreleases by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/78
    • Removed overheads from runtime checking when they can be compiled away. by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/77

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.0.4...v0.0.5

  • v0.0.4(Mar 6, 2022)

    First release using GitHub releases! We'll be using this to serve as a changelog.

    As for what has changed since the v0.0.3 release, we'll let the autogenerated release notes do the talking:

    What's Changed

    • Rewrote RK implementation quite substantially to allow FSAL RK SDE integrators. by @patrick-kidger in https://github.com/patrick-kidger/diffrax/pull/70

    Full Changelog: https://github.com/patrick-kidger/diffrax/compare/v0.0.3...v0.0.4

Owner
Patrick Kidger
Maths+ML PhD student at Oxford. Neural ODEs+SDEs+CDEs, time series, rough analysis. (Also ice skating, martial arts and scuba diving!)