EvoJAX is a scalable, general purpose, hardware-accelerated neuroevolution toolkit

Overview

EvoJAX: Hardware-Accelerated Neuroevolution

EvoJAX is a scalable, general purpose, hardware-accelerated neuroevolution toolkit. Built on top of the JAX library, this toolkit enables neuroevolution algorithms to work with neural networks running in parallel across multiple TPUs/GPUs. EvoJAX achieves very high performance by implementing the evolution algorithm, neural network and task all in NumPy, which is compiled just-in-time to run on accelerators.

This repo also includes several extensible examples of EvoJAX for a wide range of tasks, including supervised learning, reinforcement learning and generative art, demonstrating how EvoJAX can run your evolution experiments within minutes on a single accelerator, compared to hours or days when using CPUs.

EvoJAX paper: https://arxiv.org/abs/2202.05008

Installation

EvoJAX is implemented in JAX which needs to be installed first.

Install JAX: Please first follow JAX's installation instructions with optional GPU/TPU backend support. If JAX is not set up, the EvoJAX installation will still try to pull a CPU-only version of JAX. Note that Colab runtimes come with JAX pre-installed.

Install EvoJAX:

# Install from PyPI.
pip install evojax

# Or, install from our GitHub repo.
pip install git+https://github.com/google/evojax.git@main

Code Overview

EvoJAX is a framework with three major components, which we expect the users to extend.

  1. Neuroevolution Algorithms: All neuroevolution algorithms should implement the evojax.algo.base.NEAlgorithm interface and reside in evojax/algo/. We currently provide PGPE, with more coming soon.
  2. Policy Networks: All neural networks should implement the evojax.policy.base.PolicyNetwork interface and be saved in evojax/policy/. In this repo, we give example implementations of the MLP, ConvNet, Seq2Seq and PermutationInvariant models.
  3. Tasks: All tasks should implement evojax.task.base.VectorizedTask and be in evojax/task/.

These components can be used either independently, or orchestrated by evojax.trainer and evojax.sim_mgr that manage the training pipeline. While they should be sufficient for the currently provided policies and tasks, we plan to extend their functionality in the future as the need arises.
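The division of labor between an algorithm and a task can be sketched with a toy ask/tell loop. This is only an illustration in plain NumPy: `ToyES` and `evaluate` are hypothetical names, the update rule is a simple z-scored evolution strategy rather than PGPE, and nothing here uses the EvoJAX classes themselves.

```python
import numpy as np


class ToyES:
    """Toy ask/tell evolution strategy (illustrative, not an EvoJAX class)."""

    def __init__(self, num_dims, pop_size=64, sigma=0.1, lr=0.5, seed=0):
        self.rng = np.random.default_rng(seed)
        self.mean = np.zeros(num_dims)
        self.pop_size, self.sigma, self.lr = pop_size, sigma, lr
        self._noise = None

    def ask(self):
        # Sample a population of candidate parameter vectors around the mean.
        self._noise = self.rng.standard_normal((self.pop_size, self.mean.size))
        return self.mean + self.sigma * self._noise

    def tell(self, fitness):
        # Move the mean toward candidates with above-average fitness.
        f = np.asarray(fitness, dtype=float)
        f = (f - f.mean()) / (f.std() + 1e-8)
        step = (self._noise.T @ f) / self.pop_size
        self.mean = self.mean + self.lr * self.sigma * step


def evaluate(population, target):
    # Stand-in for a task: higher is better, with a maximum of 0 at `target`.
    return -np.sum((population - target) ** 2, axis=1)


target = np.ones(5)
solver = ToyES(num_dims=5)
initial_score = float(evaluate(solver.mean[None, :], target)[0])
for _ in range(200):
    population = solver.ask()                  # algorithm proposes parameters
    solver.tell(evaluate(population, target))  # task scores them
final_score = float(evaluate(solver.mean[None, :], target)[0])
print(initial_score, final_score)
```

In this sketch the loop body plays the role that a trainer would play: it only moves parameters and scores between the two components, so either side can be swapped out independently.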

Examples

As a quickstart, we provide non-trivial examples (scripts in examples/ and notebooks in examples/notebooks) to illustrate the usage of EvoJAX. We provide example commands to start the training process at the top of each script. These scripts and notebooks were run with TPUs and/or NVIDIA V100 GPU(s).

Supervised Learning Tasks

While one would obviously use gradient-descent for such tasks in practice, the point is to show that neuroevolution can also solve them to some degree of accuracy within a short amount of time, which will be useful when these models are adapted within a more complicated task where gradient-based approaches may not work.

  • MNIST Classification - We show that EvoJAX trains a ConvNet policy to achieve >98% test accuracy within 5 min on a single GPU.
  • Seq2Seq Learning - We demonstrate that EvoJAX is capable of learning a large network with hundreds of thousands of parameters to accomplish a seq2seq task.

Classic Control Tasks

The purpose of including control tasks is two-fold: 1) Unlike supervised learning tasks, control tasks in EvoJAX have an undetermined number of steps, so we use these examples to demonstrate the efficiency of our task roll-out loops. 2) We wish to show the speed-up benefit of implementing tasks in JAX and illustrate how to implement one from scratch.

  • Locomotion - Brax is a differentiable physics engine implemented in JAX. We wrap it as a task and train with EvoJAX on GPUs/TPUs. It takes EvoJAX tens of minutes to solve a locomotion task in Brax.
  • Cart-Pole Swing Up - We illustrate how the classic control task can be implemented in JAX and integrated into EvoJAX's pipeline to significantly speed up training.
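To give a feel for what a roll-out loop written in JAX can look like, here is a minimal sketch under stated assumptions. It uses a toy point-mass task and a two-parameter policy (both hypothetical, not the Cart-Pole task from the examples), with `jax.lax.scan` for the time loop and `jax.vmap` to evaluate a whole population in one compiled call.

```python
import jax
import jax.numpy as jnp


def step(state, action):
    # Toy dynamics: a point mass pushed by the action; reward penalizes distance from 0.
    pos, vel = state
    vel = vel + 0.1 * action
    pos = pos + 0.1 * vel
    return (pos, vel), -(pos ** 2)


def policy(params, state):
    # Two-parameter feedback policy (hypothetical stand-in for a policy network).
    pos, vel = state
    return jnp.tanh(params[0] * pos + params[1] * vel)


def rollout(params, num_steps=50):
    def body(state, _):
        new_state, reward = step(state, policy(params, state))
        return new_state, reward

    init = (jnp.float32(1.0), jnp.float32(0.0))
    _, rewards = jax.lax.scan(body, init, xs=None, length=num_steps)
    return rewards.sum()


# Evaluate every member of the population in parallel with one jitted call.
batched_rollout = jax.jit(jax.vmap(rollout))
population = jnp.array([[0.0, 0.0], [-2.0, -2.0]])
scores = batched_rollout(population)
print(scores)
```

The first population member never acts and so stays at its starting position, while the second pushes the mass back toward the origin and collects a higher total reward.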

Novel Tasks

In this last category, we go beyond simple illustrations and show examples of novel tasks that are more practical and attractive to researchers in the genetic and evolutionary computation area, with the goal of helping them try out ideas in EvoJAX.

Figures: Multi-agent WaterWorld; ES-CLIP: “A drawing of a cat”
  • WaterWorld - In this task, an agent tries to get as much food as possible while avoiding poisons. EvoJAX is able to train the agent in tens of minutes on a single GPU. Moreover, we demonstrate that multi-agent training in EvoJAX is possible, which is beneficial for learning policies that can deal with environmental complexity and uncertainties.
  • Abstract Paintings (notebook 1 and notebook 2) - We reproduce the results from this computational creativity work and show how the original work, whose implementation requires multiple CPUs and GPUs, could be accelerated on a single GPU efficiently using EvoJAX, which was not possible before. Moreover, with multiple GPUs/TPUs, EvoJAX can further speed up the mentioned work almost linearly. We also show that the modular design of EvoJAX allows its components to be used independently -- in this case it is possible to use only the ES algorithms from EvoJAX while leveraging one's own training loops and environment implementation.

Disclaimer

This is not an official Google product.

Comments
  • Some proposals about the `Trainer` logic

    Currently I see two ways of using the Trainer.test_task:

    1. The test_task of the trainer is used for validation. The actual test set is held out and not seen during training or validation. In this case, how do I run the actual test? I can't pass just the test_task to the trainer, because the train_task is non-optional. It looks like there should be a way to do this with evojax.
    2. The test_task of the trainer is used for the actual test, no validation is used at all. In this case, why does the trainer.run return the best model score and not the last model score?

    I propose the following (high level) logic:

    best_val_reward = trainer.fit(train_task: VectorizedTask, val_task: Optional[VectorizedTask] = None)  # maybe the user doesn't want validation (e.g. train on latest data without early stopping)
    test_reward = trainer.test(test_task: VectorizedTask, checkpoint="best|last|path")  # specify which checkpoint to use for testing
    

    Early stopping would probably be necessary for the trainer.fit method. Currently there is no way to determine when to stop, or even which model iteration has the best result.
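The bookkeeping that such a fit method would need can be sketched in a few lines. This is a hypothetical helper, not part of the EvoJAX Trainer: `track_best_with_patience` and its `patience` parameter are assumptions made for illustration, and the validation rewards are passed in as a plain list.

```python
def track_best_with_patience(val_rewards, patience=3):
    """Return (best_reward, best_iter, last_iter) for a stream of validation rewards.

    Stops as soon as `patience` consecutive evaluations fail to improve on the best.
    """
    best_reward = float("-inf")
    best_iter = -1
    last_iter = -1
    stale = 0
    for i, reward in enumerate(val_rewards):
        last_iter = i
        if reward > best_reward:
            # New best: this is the iteration whose checkpoint should be kept.
            best_reward, best_iter, stale = reward, i, 0
        else:
            stale += 1
            if stale >= patience:
                break
    return best_reward, best_iter, last_iter


result = track_best_with_patience([1.0, 2.0, 3.0, 2.5, 2.9, 2.0, 5.0], patience=3)
print(result)
```

With this rule the run stops before the late reward of 5.0 is ever seen, which illustrates the usual trade-off in choosing the patience value.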

    I'm willing to implement this logic in a PR.

    opened by danielgafni 7
  • high dimensional parametric search

    I'm trying to use evojax to evolve my model parameters. I found that the algorithm only accepts the parameter num_dims as the dimension. Can it only be an int? If I want to evolve multidimensional parameters, such as a [1000x1000] array, how can I do it? Thanks!
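One common way to handle this, sketched here under the assumption that the algorithm only ever sees a flat vector of length num_dims, is to flatten the matrix for the solver and reshape it inside the evaluation function. The names `matrix_shape`, `to_matrix` and `evaluate` are hypothetical.

```python
import numpy as np

matrix_shape = (1000, 1000)
num_dims = int(np.prod(matrix_shape))  # the integer passed to the algorithm


def to_matrix(flat_params):
    # Restore the 2-D structure from the flat vector the solver works with.
    return flat_params.reshape(matrix_shape)


def evaluate(flat_params):
    # Placeholder objective: any function of the reshaped matrix works here.
    weights = to_matrix(flat_params)
    return -float(np.sum(weights ** 2))


candidate = np.arange(num_dims, dtype=np.float32)
restored = to_matrix(candidate)
print(num_dims, restored.shape)
```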

    opened by Agnes233 5
  • add CR-FM-NES algorithm

    Adds a wrapper for CR-FM-NES; see "Fast Moving Natural Evolution Strategy for High-Dimensional Problems (CR-FM-NES)" (pdf).

    It wraps the fcmaes Eigen/C++ version of CR-FM-NES which is derived from https://github.com/nomuramasahir0/crfmnes.

    Since there are numpy and Eigen based implementations (and soon a JAX based one) of CR-FM-NES available, it will be possible to compare the performance of these three "backends" for the same algorithm. This commit wraps only the C++/Eigen based implementation crfmnes.cpp.

    Tested on NVIDIA 3090 + AMD 5950x Linux Mint 20 (Ubuntu based). Performance (wall time) is similar to PGPE outperforming CMA_ES_JAX. Benchmark results for waterworld are above all other algorithms. Do "pip install fcmaes --upgrade" before testing.

    opened by dietmarwo 5
  • Evaluating brax environments other than brax-ant. Terminates with error.

    Information

    The issue occurs when running brax environments other than brax-ant. The included humanoid, half cheetah and fetch environments are affected.

    Couldn't find any references to this issue in the repo. I could have missed something.

    Expected Behavior

    /home/<USER>/anaconda3/envs/evojax/bin/python /home/<USER>/evojax/scripts/benchmarks/train.py -config configs/PGPE/brax_halfcheetah.yaml
    brax: 2022-06-16 20:41:01,954 [INFO] EvoJAX brax
    brax: 2022-06-16 20:41:01,954 [INFO] ==============================
    absl: 2022-06-16 20:41:02,137 [INFO] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
    absl: 2022-06-16 20:41:02,221 [INFO] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
    MLPPolicy: 2022-06-16 20:41:03,747 [INFO] MLPPolicy.num_params = 3974
    brax: 2022-06-16 20:41:03,787 [INFO] use_for_loop=False
    brax: 2022-06-16 20:41:03,825 [INFO] Start to train for 1 iterations.
    brax: 2022-06-16 20:41:56,024 [INFO] [TEST] Iter=1, #tests=1, max=-9.7476, avg=-9.7476, min=-9.7476, std=0.0000
    brax: 2022-06-16 20:41:56,087 [INFO] Training done, best_score=-9.7476
    brax: 2022-06-16 20:41:56,093 [INFO] Loaded model parameters from ./log/PGPE/brax/default.
    brax: 2022-06-16 20:41:56,093 [INFO] Start to test the parameters.
    brax: 2022-06-16 20:42:03,478 [INFO] [TEST] #tests=1, max=-9.9009, avg=-9.9009, min=-9.9009, std=0.0000
    

    Current Behavior

    brax: 2022-06-16 20:26:04,657 [INFO] EvoJAX brax
    brax: 2022-06-16 20:26:04,657 [INFO] ==============================
    absl: 2022-06-16 20:26:04,833 [INFO] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
    absl: 2022-06-16 20:26:04,920 [INFO] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
    MLPPolicy: 2022-06-16 20:26:06,465 [INFO] MLPPolicy.num_params = 3974
    brax: 2022-06-16 20:26:06,504 [INFO] use_for_loop=False
    brax: 2022-06-16 20:26:06,541 [INFO] Start to train for 10 iterations.
    Traceback (most recent call last):
      File "/home/<USER>/evojax/scripts/benchmarks/train.py", line 88, in <module>
        main(config)
      File "/home/<USER>/evojax/scripts/benchmarks/train.py", line 64, in main
        trainer.run(demo_mode=False)
      File "/home/<USER>/evojax/evojax/trainer.py", line 152, in run
        scores, bds = self.sim_mgr.eval_params(
      File "/home/<USER>/evojax/evojax/sim_mgr.py", line 258, in eval_params
        return self._scan_loop_eval(params, test)
      File "/home/<USER>/evojax/evojax/sim_mgr.py", line 355, in _scan_loop_eval
        scores, all_obs, masks, final_states = rollout_func(
      File "/home/<USER>/evojax/evojax/sim_mgr.py", line 202, in rollout
        (obs_set, obs_mask)) = jax.lax.scan(
      File "/home/<USER>/anaconda3/envs/evojax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/<USER>/anaconda3/envs/evojax/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 1630, in scan
        _check_tree_and_avals("scan carry output and input",
      File "/home/<USER>/anaconda3/envs/evojax/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 2316, in _check_tree_and_avals
        raise TypeError(f"{what} must have identical types, got\n{diff}.")
    jax._src.traceback_util.UnfilteredStackTrace: TypeError: scan carry output and input must have identical types, got
    (State(state=State(qp=QP(pos='ShapedArray(float32[16384,8,3])', rot='ShapedArray(float32[16384,8,4])', vel='ShapedArray(float32[16384,8,3])', ang='ShapedArray(float32[16384,8,3])'), obs='ShapedArray(float32[16384,18])', reward='ShapedArray(float32[16384])', done='ShapedArray(float32[16384])', metrics={'reward_ctrl_cost': 'ShapedArray(float32[16384])', 'reward_forward': 'ShapedArray(float32[16384])'}, info={'first_obs': 'ShapedArray(float32[16384,18])', 'first_qp': QP(pos='ShapedArray(float32[16384,8,3])', rot='ShapedArray(float32[16384,8,4])', vel='ShapedArray(float32[16384,8,3])', ang='ShapedArray(float32[16384,8,3])'), 'steps': 'ShapedArray(float32[16384])', 'truncation': 'ShapedArray(float32[16384])'}), obs='ShapedArray(float32[16384,18])', feet_contact='DIFFERENT ShapedArray(int32[16384,3]) vs. ShapedArray(int32[16384,4])'), PolicyState(keys='ShapedArray(uint32[16384,2])'), 'ShapedArray(float32[16384,3974])', 'ShapedArray(float32[37])', 'ShapedArray(float32[16384])', 'ShapedArray(float32[16384])').
    

    Exact Error:

    feet_contact='DIFFERENT ShapedArray(int32[16384,3]) vs. ShapedArray(int32[16384,4])')
    

    Failure Information

    Context

    Based on the commit history, this appears to be due to the changes introduced in #33. Manually altering the feet_contact variable in the reset_fn method in evojax/evojax/task/brax_task.py allows the other environments to run.
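The class of error in the traceback can be reproduced in isolation. The sketch below is not the brax_task.py code. It uses a hypothetical `rollout` helper to show that `jax.lax.scan` rejects a loop whose carry changes shape between input and output, and that sizing the initial value to match what the step function returns avoids the error.

```python
import jax
import jax.numpy as jnp


def rollout(init_feet, env_feet, steps=5):
    # init_feet: size assumed when the state is reset (hypothetical parameter).
    # env_feet: size the environment step actually produces (hypothetical parameter).
    def body(feet_contact, _):
        # The step overwrites the contact flags with an array of its own size.
        return jnp.ones(env_feet, dtype=jnp.int32), None

    init = jnp.zeros(init_feet, dtype=jnp.int32)
    final, _ = jax.lax.scan(body, init, xs=None, length=steps)
    return final


# Matching sizes: the carry has the same type before and after each step.
ok = rollout(init_feet=3, env_feet=3)

# Mismatched sizes: scan refuses to trace the loop.
try:
    rollout(init_feet=4, env_feet=3)
    mismatch_error = None
except TypeError as err:
    mismatch_error = err
print(ok.shape, type(mismatch_error).__name__)
```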

    Setup details related to the hardware are irrelevant since error occurs on the hosted colab notebook as well.

    brax                         0.0.13
    evojax                       0.2.11               
    flax                         0.4.0
    jax                          0.3.1
    jaxlib                       0.3.0+cuda11.cudnn82
    

    Steps to Reproduce

    Please provide detailed steps for reproducing the issue.

    1. Run evojax/scripts/benchmarks/train.py using a modified evojax/scripts/benchmarks/configs/<ES> file that uses a non-ant brax environment.
    2. Modify feet_contact array size and test.
    opened by Surya-77 5
  • AssertionError for OpenES

    When I try to instantiate OpenES from open_es.py, I get the following error message: [screenshot, 2022-12-15 20:23:59]. I traced the problem back to line 110 in open_es.py, where both the centered_rank and z_score arguments are set to True: [screenshot, 2022-12-15 20:26:01]. But line 26 of the FitnessShaper class in evosax/utils/reshape_fitness.py says that: [screenshot, 2022-12-15 20:26:49]. How can I get around this issue?
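For context, the two flags named above refer to two different fitness-shaping transforms. The sketch below shows what each typically computes, in plain NumPy. It is not the evosax implementation, and the assumption that a shaper would accept only one of them at a time is an inference from the issue text, not something confirmed here.

```python
import numpy as np


def centered_rank(fitness):
    # Replace each fitness by its rank, scaled to the range [-0.5, 0.5].
    fitness = np.asarray(fitness, dtype=float)
    ranks = np.argsort(np.argsort(fitness))
    return ranks / (fitness.size - 1) - 0.5


def z_score(fitness):
    # Standardize fitness to zero mean and (approximately) unit variance.
    fitness = np.asarray(fitness, dtype=float)
    return (fitness - fitness.mean()) / (fitness.std() + 1e-8)


raw = [3.0, 1.0, 2.0]
with_outlier = [100.0, 1.0, 2.0]
print(centered_rank(raw), centered_rank(with_outlier))
print(z_score(raw), z_score(with_outlier))
```

Rank shaping gives the same output for both inputs because it ignores magnitudes, whereas z-scoring keeps them, which is one reason the two are usually treated as alternatives.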

    opened by pigozzif 4
  • Native implementation in JAX of Augmented Random Search

    Test results

    Note for MNIST I halved the batch size and doubled the iterations due to memory issues.

    | | Benchmark | Params | Results (avg.) |
    |---|---|---|---|
    | CartPole (easy) | 900 (max_iter=1000) | Link | 910 |
    | CartPole (hard) | 600 (max_iter=2000) | Link | 558.02 |
    | MNIST | 0.90 (max_iter=4000) | Link | 0.92 |
    | Brax Ant | 3000 (max_iter=700) | Link | 4129.83 |
    | Waterworld | 6 (max_iter=2000) | Link | 7.29 |
    | Waterworld (MA) | 2 (max_iter=2000) | Link | 1.68 |

    opened by EdoardoPona 4
  • AbstractPainting02.ipynb. doesn't work on colab

    Hello, this is a really great code.

    I was able to run "Abstract Painting 01" very well on Google Colab. However, when I ran "AbstractPainting02", an error occurred.

    Exception                                 Traceback (most recent call last)
    <ipython-input-20-b16203d22159> in <module>()
          2 devices = jax.local_devices()
          3 
    ----> 4 image_fn, text_fn, jax_params, jax_preprocess = clip_jax.load('ViT-B/32', "cpu")
          5 
          6 target_text_ids = jnp.array(clip_jax.tokenize([prompt])) # already with batch dim
    
    3 frames
    /content/CLIP_JAX/clip_jax/clip.py in process_node(value, name)
        117             new_tensor = jnp.array(pytorch_tensor)
        118         else:
    --> 119             raise Exception("not implemented")
        120 
        121         assert new_tensor.shape == value.shape
    
    Exception: not implemented
    

    Which version of clip_jax did you use when you made this notebook?

    Best

    opened by shi3z 4
  • Evosax - Sep-CMA-ES

    • Reference: Ros & Hansen (2008)
    • evosax Source Code: https://github.com/RobertTLange/evosax/blob/main/evosax/strategies/sep_cma_es.py
    • This PR adds a CMA-ES version which imposes a diagonal structure on the estimated covariance matrix. It is therefore a lot more memory efficient than pure CMA-ES, which has to store a (d x d) matrix.
    • Benchmarks and hyperparameters:

    | | Benchmarks | Parameters | Results (Avg) |
    |---|---|---|---|
    | CartPole (easy) | 900 (max_iter=1000) | Link | 924.3028 |
    | CartPole (hard) | 600 (max_iter=1000) | Link | 626.9728 |
    | MNIST | 90.0 (max_iter=2000) | Link | 0.9545 |
    | Brax Ant | 3000 (max_iter=300) | Link | 3980.9194 |
    | Waterworld | 6 (max_iter=500) | Link | 9.9000 |
    | Waterworld (MA) | 2 (max_iter=2000) | Link | 1.1875 |

    Note: Linting doesn't pass due to import error for Open_ES - see PR #19. This has to be merged first.
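The memory argument for a diagonal covariance can be made concrete with a little arithmetic. The sketch below assumes float32 storage and a hypothetical dimension of 10,000. It only illustrates storage cost and sampling with per-dimension variances, and is not the Sep-CMA-ES update itself.

```python
import numpy as np

num_dims = 10_000
bytes_per_float = np.dtype(np.float32).itemsize

# Full covariance stores d x d entries; a diagonal one stores only d.
full_cov_bytes = num_dims * num_dims * bytes_per_float
diag_cov_bytes = num_dims * bytes_per_float

# Sampling with a diagonal covariance is an element-wise scale of standard noise.
rng = np.random.default_rng(0)
mean = np.zeros(num_dims, dtype=np.float32)
diag_var = np.ones(num_dims, dtype=np.float32)
pop_size = 8
noise = rng.standard_normal((pop_size, num_dims)).astype(np.float32)
population = mean + np.sqrt(diag_var) * noise

print(full_cov_bytes, diag_cov_bytes, population.shape)
```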

    opened by RobertTLange 3
  • Adding a Linear Policy

    This is a simple linear policy (1 layer neural network). This policy is especially useful for tasks related to control, with for example augmented random search. In fact, in the original ARS paper, one of the algorithm's key advantages is the ability to find high performing linear policies.

    I created a new policy rather than editing MLP for simplicity, and since they would most likely be used in different contexts (eg. tasks, algorithms)
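As an illustration of what a one-layer policy computes, here is a minimal NumPy sketch. It assumes the policy is an affine map with a bias term and that parameters arrive as one flat vector. Both are assumptions, and `linear_policy` is a hypothetical function rather than the class added in this PR.

```python
import numpy as np


def linear_policy_num_params(obs_dim, act_dim):
    # One weight per (observation, action) pair plus one bias per action.
    return obs_dim * act_dim + act_dim


def linear_policy(params, obs, act_dim):
    # Unpack the flat parameter vector into a weight matrix and a bias vector.
    obs_dim = obs.shape[-1]
    weights = params[: obs_dim * act_dim].reshape(obs_dim, act_dim)
    bias = params[obs_dim * act_dim:]
    return obs @ weights + bias


num_params = linear_policy_num_params(obs_dim=3, act_dim=2)
params = np.arange(num_params, dtype=float)
actions = linear_policy(params, np.array([[1.0, 0.0, 0.0]]), act_dim=2)
print(num_params, actions)
```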

    opened by EdoardoPona 2
  • Add a Python/JAX port of CR-FM-NES

    This PR adds a Python/JAX port of Fast Moving Natural Evolution Strategy for High-Dimensional Problems (CR-FM-NES), see https://arxiv.org/abs/2201.11422 . Derived from https://github.com/nomuramasahir0/crfmnes.

    This variant is slightly faster than FCRFMC (the C++ port) on fast GPUs/TPUs, but slower on CPUs and for smaller dimensions. It uses 32 bit accuracy (FCRFMC uses 64 bit) which mostly doesn't harm the convergence (with Waterworld MA being the exception for very high iteration numbers).

    Wall time and convergence are mostly comparable with PGPE (as with FCRFMC) for the benchmarks. Slower in the beginning, but improving at higher iterations.

    Since there are no for-loops I found no beneficial applications of 'jax.jit', just converted most 'np.arrays' into 'jnp.arrays' deployed on the GPUs/TPUs.

    def sort_indices_by(evals: np.ndarray, z: jnp.ndarray) -> jnp.ndarray:

    takes evals as np.ndarray rather than jnp.ndarray, because the latter slowed things down on my NVIDIA 3090.

    Since this is Python code, no missing shared libraries on Ubuntu 18 this time.

    Added test results for CRFMNES (this Python implementation) at EvoJax.adoc.

    opened by dietmarwo 2
  • Reproducing benchmark scores

    Hello everyone.

    I am currently trying to reproduce scores from the benchmarks, specifically for ARS, as I am implementing my own version natively in JAX and want to compare it with the wrapper already implemented.

    For example, I cannot achieve the score posted in the benchmark table (902.107) for ARS on cartpole_easy.

    Running python train.py -config configs/ARS/cartpole_easy.yaml yields the following training logs:

    cartpole_easy: 2022-09-25 22:45:55,777 [INFO] EvoJAX cartpole_easy
    cartpole_easy: 2022-09-25 22:45:55,777 [INFO] ==============================
    absl: 2022-09-25 22:45:55,791 [INFO] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
    absl: 2022-09-25 22:45:57,247 [INFO] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA Host
    absl: 2022-09-25 22:45:57,247 [INFO] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
    MLPPolicy: 2022-09-25 22:45:59,165 [INFO] MLPPolicy.num_params = 4609
    cartpole_easy: 2022-09-25 22:45:59,429 [INFO] use_for_loop=False
    cartpole_easy: 2022-09-25 22:45:59,496 [INFO] Start to train for 1000 iterations.
    cartpole_easy: 2022-09-25 22:46:10,527 [INFO] Iter=50, size=100, max=399.5886, avg=207.9111, min=0.5843, std=99.0207
    cartpole_easy: 2022-09-25 22:46:19,916 [INFO] Iter=100, size=100, max=543.8907, avg=364.9780, min=28.8478, std=141.8982
    cartpole_easy: 2022-09-25 22:46:21,143 [INFO] [TEST] Iter=100, #tests=100, max=553.4018 avg=510.5583, min=462.4243, std=15.6930
    cartpole_easy: 2022-09-25 22:46:30,627 [INFO] Iter=150, size=100, max=558.2020, avg=314.9279, min=89.8001, std=153.6488
    cartpole_easy: 2022-09-25 22:46:40,068 [INFO] Iter=200, size=100, max=562.4118, avg=354.9529, min=47.0048, std=154.1567
    cartpole_easy: 2022-09-25 22:46:40,114 [INFO] [TEST] Iter=200, #tests=100, max=570.1135 avg=547.5375, min=508.5795, std=10.0840
    cartpole_easy: 2022-09-25 22:46:49,579 [INFO] Iter=250, size=100, max=562.1505, avg=325.3990, min=73.3733, std=161.9460
    cartpole_easy: 2022-09-25 22:46:59,073 [INFO] Iter=300, size=100, max=569.5461, avg=370.2641, min=83.7473, std=166.8020
    cartpole_easy: 2022-09-25 22:46:59,129 [INFO] [TEST] Iter=300, #tests=100, max=573.5941 avg=545.0388, min=505.8637, std=11.3853
    cartpole_easy: 2022-09-25 22:47:08,623 [INFO] Iter=350, size=100, max=579.3894, avg=425.6462, min=82.4907, std=126.6614
    cartpole_easy: 2022-09-25 22:47:18,109 [INFO] Iter=400, size=100, max=627.6509, avg=530.2781, min=156.4797, std=76.0956
    cartpole_easy: 2022-09-25 22:47:18,160 [INFO] [TEST] Iter=400, #tests=100, max=639.7323 avg=600.9105, min=573.7767, std=10.7564
    cartpole_easy: 2022-09-25 22:47:27,653 [INFO] Iter=450, size=100, max=668.2064, avg=546.0261, min=418.5385, std=60.5854
    cartpole_easy: 2022-09-25 22:47:37,149 [INFO] Iter=500, size=100, max=684.4142, avg=574.4891, min=446.3126, std=62.5338
    cartpole_easy: 2022-09-25 22:47:37,202 [INFO] [TEST] Iter=500, #tests=100, max=693.1522 avg=682.7945, min=638.0387, std=12.1575
    cartpole_easy: 2022-09-25 22:47:46,708 [INFO] Iter=550, size=100, max=708.9561, avg=591.0547, min=295.5651, std=73.6026
    cartpole_easy: 2022-09-25 22:47:56,212 [INFO] Iter=600, size=100, max=706.8138, avg=599.4783, min=348.7581, std=55.6310
    cartpole_easy: 2022-09-25 22:47:56,263 [INFO] [TEST] Iter=600, #tests=100, max=691.0123 avg=680.4677, min=630.2983, std=6.1448
    cartpole_easy: 2022-09-25 22:48:05,770 [INFO] Iter=650, size=100, max=707.0887, avg=581.3851, min=418.2251, std=75.9066
    cartpole_easy: 2022-09-25 22:48:15,275 [INFO] Iter=700, size=100, max=712.7586, avg=586.4597, min=362.7628, std=71.5669
    cartpole_easy: 2022-09-25 22:48:15,326 [INFO] [TEST] Iter=700, #tests=100, max=725.2336 avg=714.1309, min=635.7863, std=9.3471
    cartpole_easy: 2022-09-25 22:48:24,849 [INFO] Iter=750, size=100, max=716.1056, avg=602.7747, min=458.0401, std=63.1697
    cartpole_easy: 2022-09-25 22:48:34,365 [INFO] Iter=800, size=100, max=709.3475, avg=587.9896, min=393.0367, std=69.2385
    cartpole_easy: 2022-09-25 22:48:34,418 [INFO] [TEST] Iter=800, #tests=100, max=732.5553 avg=720.5952, min=648.5032, std=8.3936
    cartpole_easy: 2022-09-25 22:48:43,945 [INFO] Iter=850, size=100, max=706.8488, avg=598.3582, min=321.8640, std=75.2542
    cartpole_easy: 2022-09-25 22:48:53,482 [INFO] Iter=900, size=100, max=720.0320, avg=596.1929, min=370.6555, std=77.2801
    cartpole_easy: 2022-09-25 22:48:53,536 [INFO] [TEST] Iter=900, #tests=100, max=703.5345 avg=692.9500, min=677.6909, std=5.9381
    cartpole_easy: 2022-09-25 22:49:03,068 [INFO] Iter=950, size=100, max=716.2341, avg=598.3802, min=422.7760, std=71.7756
    cartpole_easy: 2022-09-25 22:49:12,455 [INFO] [TEST] Iter=1000, #tests=100, max=726.0114, avg=719.0803, min=698.4325, std=4.7247
    cartpole_easy: 2022-09-25 22:49:12,457 [INFO] Training done, best_score=720.5952
    cartpole_easy: 2022-09-25 22:49:12,458 [INFO] Loaded model parameters from ./log/ARS/cartpole_easy/default.
    cartpole_easy: 2022-09-25 22:49:12,459 [INFO] Start to test the parameters.
    cartpole_easy: 2022-09-25 22:49:12,509 [INFO] [TEST] #tests=100, max=728.9848, avg=720.6152, min=698.9832, std=5.0566
    

    I am not entirely sure if the result on the benchmark table is intended to be 720.5952 from cartpole_easy: 2022-09-25 22:49:12,457 [INFO] Training done, best_score=720.5952

    or the max score from the final test. Regardless, neither of these match the one posted on the benchmark table.

    Am I doing something wrong to reproduce these scores? This makes me unable to compare my own implementation of the algorithm.

    Thank you

    opened by EdoardoPona 2
  • Add Diversifier QD Meta Algorithm - JAX backend

    This PR adds a new JAX-based QD meta algorithm called Diversifier. It is a generalization of CMA-ME.

    It uses a MAP-Elites archive not for solution candidate generation, but only to modify the fitness values told (via tell) to the wrapped algorithm. This modification changes the fitness ranking of the population to favor exploration over exploitation. Tested with CR-FM-NES and CMA-ES, but other wrapped algorithms may work as well. Based on fcmaes diversifier.py (see MapElites.adoc).

    The generalization over CMA-ME is necessary in the EvoJAX context, because CMA-ES struggles with a very high number of decision variables. Therefore CR-FM-NES-ME is superior here - as possibly are other not yet tested alternatives.

    https://doi.org/10.1145/2739480.2754664 proposes the QD score (sum of fitness values of all elites in the map) as metric for comparison.
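The QD score as described above can be computed directly from an archive. The sketch below uses two small NumPy arrays as a stand-in for the archive. The function name `qd_metrics` is hypothetical, and the array names merely echo the lattice attributes that appear in the code later in this description.

```python
import numpy as np


def qd_metrics(fitness_lattice, occupancy_lattice):
    # Keep only the cells that actually hold an elite.
    occupied = np.asarray(occupancy_lattice, dtype=bool)
    elite_fitness = np.asarray(fitness_lattice, dtype=float)[occupied]
    num_occupied = int(occupied.sum())
    qd_score = float(elite_fitness.sum())  # sum of fitness over all elites
    return {
        "qd_score": qd_score,
        "occupied": num_occupied,
        "max_score": float(elite_fitness.max()) if num_occupied else float("-inf"),
        "mean_score": qd_score / num_occupied if num_occupied else 0.0,
    }


fitness = np.array([[10.0, 0.0], [30.0, 20.0]])
occupancy = np.array([[True, False], [True, True]])
metrics = qd_metrics(fitness, occupancy)
print(metrics)
```

Under this definition the mean score is simply the QD score divided by the number of occupied niches, so a method can raise its QD score either by filling more niches or by improving the elites it already has.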

    For Brax-Ant CR-FM-NES-ME (Diversifier applied to CR-FM-NES), compared with MAP-Elites, reaches a higher QD-score for high iteration numbers (see details below). So MAP-Elites should only be preferred for a low evaluation budget or if you want to maximize the number of occupied niches.

    On an NVIDIA 3090 + AMD 5950, Linux Mint with optimized configurations we measured:

    • MAP-Elites has the same optimizer overhead (evaluation/sec rate for the same popsize).
    • MAP-Elites has a higher number of occupied niches.

    but

    • CR-FM-NES-ME has a much higher QD score and found a better global optimum for a high evaluation budget.

    Detailed measurements for the Brax-Ant example (NVIDIA 3090 + AMD 5950, Linux Mint):

    After 20 minutes MAP-Elites is in the lead, but slows down from there. CR_FM_NES-ME continues to improve until 500 minutes / 8 million evaluations. CR_FM_NES-ME can even produce a good global optimum - 4107 - thereby still occupying 6138 niches with a mean score of 1208. After 500 minutes MAP-Elites continues to improve where CR_FM_NES-ME does not, but at that time CR_FM_NES-ME has a >70% lead in score.

    CR_FM_NES-ME with init-std = 0.159, popsize = 512, fitness_weight 0.0

    20 min QD score: 1692282 occupied: 4936 max score: 558 mean score: 342 evaluations: 263680
    50 min QD score: 2724260 occupied: 5628 max score: 918 mean score: 484 evaluations: 704512
    100 min QD score: 4289807 occupied: 6087 max score: 1442 mean score: 704 evaluations: 1496576
    200 min QD score: 5928753 occupied: 6138 max score: 2363 mean score: 965 evaluations: 3072000
    300 min QD score: 6524518 occupied: 6138 max score: 2862 mean score: 1063 evaluations: 4710400
    400 min QD score: 7353257 occupied: 6138 max score: 3889 mean score: 1198 evaluations: 6348800
    500 min QD score: 7418018 occupied: 6138 max score: 4107 mean score: 1208 evaluations: 7884800
    600 min QD score: 7444092 occupied: 6138 max score: 4211 mean score: 1212 evaluations: 9523200

    MAP-Elites iso-sigma = 0.05, line-sigma = 0.2, popsize = 1024: (line-sigma = 0.3 is worse)

    20 min QD score: 2509773 occupied: 5621 max score: 643 mean score: 446 evaluations: 346112
    50 min QD score: 3022521 occupied: 6375 max score: 724 mean score: 474 evaluations: 915456
    100 min QD score: 3383041 occupied: 6786 max score: 769 mean score: 498 evaluations: 1941504
    200 min QD score: 3713977 occupied: 7107 max score: 825 mean score: 522 evaluations: 3936256
    300 min QD score: 3915492 occupied: 7265 max score: 927 mean score: 538 evaluations: 5922816
    400 min QD score: 4065677 occupied: 7400 max score: 927 mean score: 549 evaluations: 7941120
    500 min QD score: 4179020 occupied: 7498 max score: 927 mean score: 557 evaluations: 9958400
    600 min QD score: 4272665 occupied: 7566 max score: 927 mean score: 564 evaluations: 12083200
    700 min QD score: 4351397 occupied: 7632 max score: 941 mean score: 570 evaluations: 14094336
    800 min QD score: 4415351 occupied: 7675 max score: 1003 mean score: 575 evaluations: 16040960

    These results indicate that it should be possible to apply MAP-Elites to the resulting CR_FM_NES-ME archive to further improve occupancy and score. As the algorithm wrapped by Diversifier.py, CRFMNES can be replaced by FCRFMC (same algorithm but implemented in C++). We got the same results, but this may reduce the GPU load for smaller GPUs/TPUs and is definitely advantageous for CPU-only execution. On the Nvidia 3090 CRFMNES is slightly faster.

    Note that 'fitness_weight' is a concept used neither in CMA-ME nor in the fcmaes diversifier. All of these implicitly use fitness_weight=0. For fcmaes the reason is that there are other means to improve the elites of a given map, so the focus is on exploration here. We use fitness_weight=0 as the default, because for Brax Ant the final QD score is higher - but the final global optimum found is lower.

    fcmaes even supports sequences of wrapped algorithms, something probably not relevant for EvoJAX.

    Increasing popsize to 1024 closes the evaluations / sec gap to MAP-Elites; the rate is 34% higher than with popsize = 512. But popsize = 1024 seems to produce lower occupancy - which is quite surprising:

    CR_FM_NES-ME with init-std = 0.159, popsize = 1024, fitness_weight 0.0

    20 min QD score: 1864258 occupied: 4856 max score: 538 mean score: 383 evaluations: 350208
    50 min QD score: 2702674 occupied: 5402 max score: 848 mean score: 500 evaluations: 905216
    100 min QD score: 3853807 occupied: 5873 max score: 1288 mean score: 656 evaluations: 1879040
    200 min QD score: 5005292 occupied: 5947 max score: 1781 mean score: 841 evaluations: 3891200
    300 min QD score: 6425120 occupied: 5963 max score: 2936 mean score: 1077 evaluations: 5963776
    400 min QD score: 7103424 occupied: 5976 max score: 3783 mean score: 1188 evaluations: 8192000
    500 min QD score: 7282457 occupied: 5980 max score: 4111 mean score: 1217 evaluations: 10240000
    600 min QD score: 7371868 occupied: 5982 max score: 4227 mean score: 1232 evaluations: 12288000
    700 min QD score: 7405531 occupied: 5982 max score: 4276 mean score: 1237 evaluations: 14336000
    800 min QD score: 7464371 occupied: 5983 max score: 4307 mean score: 1247 evaluations: 16384000
    900 min QD score: 7509290 occupied: 5989 max score: 4325 mean score: 1253 evaluations: 18432000
    1000 min QD score: 7514184 occupied: 5989 max score: 4342 mean score: 1254 evaluations: 20480000

    But why can't we have our cake and eat it too?

    This is not part of the PR but discusses what could be done in the future:

    Both Diversifier and MAP-Elites share the same archive management. They differ only in population generation. In the future both could be unified into a single QD solver - still called MAP-Elites. This new implementation could randomly choose the way "ask" works. We define a probability with which a wrapped solver is used instead of the standard mechanism. If this probability is 0, we have the old MAP-Elites. If it is 1.0, we have Diversifier. The interesting question is: What happens for values in between? Let's try 0.5. This can easily be implemented as:

        def ask(self) -> jnp.ndarray:
            self.key, key = jax.random.split(self.key)
            if jax.random.uniform(key) > 0.5: # a parameter to play with
                self.population = self.solver.ask() # population from wrapped solver
                self.solver_asked = True
            else: # population from MAP-Elites generator
                self.key, mutate_key, parents = self._sample_parents(
                                    key=self.key,
                                    occupancy=self.occupancy_lattice,
                                    params=self.params_lattice)      
                self.population = self._gen_pop(parents, mutate_key)
                self.solver_asked = False
            return self.population
    
        def tell(self, fitness: Union[np.ndarray, jnp.ndarray]) -> None:      
            if self.solver_asked:   
                lattice_fitness = self.fitness_lattice[self.bin_idx]
                to_tell = self._get_to_tell(fitness, lattice_fitness, self.fitness_weight)
                self.solver.tell(to_tell)
            # update lattice 
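
The fixed 0.5 threshold above can be turned into a parameter. The following is a minimal, self-contained sketch of that idea only. `MixedAsk`, `solver_prob` and the string return values are hypothetical stand-ins for illustration and are not part of EvoJAX or this PR. It uses Python's standard `random` module instead of `jax.random` to stay dependency-free.

```python
import random


class MixedAsk:
    """Toy stand-in that only decides which generator an ask() call uses.

    Hypothetical sketch: a real solver would return a population array.
    """

    def __init__(self, solver_prob: float, seed: int = 0) -> None:
        self.solver_prob = solver_prob   # probability of using the wrapped solver
        self.rng = random.Random(seed)
        self.solver_asked = False

    def ask(self) -> str:
        # random() is in [0, 1): prob 0.0 never picks the solver, 1.0 always does.
        self.solver_asked = self.rng.random() < self.solver_prob
        return "wrapped_solver" if self.solver_asked else "map_elites"


def solver_share(solver_prob: float, calls: int = 1000) -> float:
    """Fraction of ask() calls that were served by the wrapped solver."""
    mixed = MixedAsk(solver_prob)
    return sum(mixed.ask() == "wrapped_solver" for _ in range(calls)) / calls
```

With `solver_prob=0.0` this reduces to plain MAP-Elites generation and with `1.0` to Diversifier. Values in between mix the two, as in the snippet above.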
    

    MAP-Elites + CR_FM_NES-ME with iso-sigma = 0.05, line-sigma = 0.2, init-std = 0.159, popsize = 1024, fitness_weight 0.0

        20 min QD score: 2201387 occupied: 5266 max score: 672 mean score: 418 evaluations: 344064
        50 min QD score: 2738691 occupied: 6105 max score: 672 mean score: 448 evaluations: 892928
        100 min QD score: 3348423 occupied: 6656 max score: 857 mean score: 503 evaluations: 1859584
        200 min QD score: 4233393 occupied: 7135 max score: 1103 mean score: 593 evaluations: 3851264
        300 min QD score: 5139277 occupied: 7334 max score: 1586 mean score: 700 evaluations: 5893120
        400 min QD score: 5776929 occupied: 7457 max score: 1884 mean score: 774 evaluations: 7943168
        500 min QD score: 6098261 occupied: 7537 max score: 2104 mean score: 809 evaluations: 9947136
        600 min QD score: 7240351 occupied: 7603 max score: 2890 mean score: 952 evaluations: 11999232
        700 min QD score: 7800357 occupied: 7660 max score: 3421 mean score: 1018 evaluations: 14023680
        800 min QD score: 8004109 occupied: 7699 max score: 3735 mean score: 1039 evaluations: 16000000

    This is an 81% QD score increase compared to MAP-Elites alone, with improved occupancy as well.

        900 min QD score: 8115904 occupied: 7744 max score: 3917 mean score: 1048 evaluations: 18140160
        1000min QD score: 8195842 occupied: 7772 max score: 4019 mean score: 1054 evaluations: 20133888
        1100min QD score: 8259249 occupied: 7799 max score: 4090 mean score: 1059 evaluations: 22155264
        1200min QD score: 8319027 occupied: 7826 max score: 4130 mean score: 1062 evaluations: 24177664
        1300min QD score: 8362034 occupied: 7847 max score: 4156 mean score: 1065 evaluations: 26213376

    A QD score of 8362034 is probably a challenge for any algorithm, regardless of the evaluation budget.

    opened by dietmarwo 1
  • Bug of center_lr_decay_steps when use adam with PGPE

    Bug

    When Adam is used with PGPE, this code

    self._opt_state = self._opt_update(
                self._t // self._lr_decay_steps, -grad_center, self._opt_state
            )
    

    means Adam's step counter t increases only once every self._lr_decay_steps iterations. As a result, mhat and vhat do not work as bias-corrected moving averages, because (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) always stays very small. (The Adam update code is shown below.)

    def update(i, g, state):
        x, m, v = state
        m = (1 - b1) * g + b1 * m  # First  moment estimate.
        v = (1 - b2) * jnp.square(g) + b2 * v  # Second moment estimate.
        mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1))  # Bias correction.
        vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
        x = x - step_size(i) * mhat / (jnp.sqrt(vhat) + eps)
        return x, m, v
    

    Suggestion

    I think it is better to change this code to

    step_size=lambda x: self._center_lr * jnp.power(decay_coef, x // self._lr_decay_steps),
    

    and to remove the division by self._lr_decay_steps in

    self._opt_state = self._opt_update(
                self._t, -grad_center, self._opt_state
            )
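
To see the effect described in this issue, here is a small pure-Python sketch (no JAX required) that replays the moment updates from the Adam code above on a constant gradient and compares the two step indices. The values of `B1`, `B2`, `CENTER_LR`, `DECAY_COEF` and `LR_DECAY_STEPS` are assumptions chosen for illustration, and the helper names are hypothetical.

```python
B1, B2 = 0.9, 0.999        # assumed Adam decay rates
CENTER_LR = 0.15           # assumed base learning rate
DECAY_COEF = 0.999         # assumed decay coefficient
LR_DECAY_STEPS = 1000      # assumed number of steps between decays


def debiased_moments(num_iters, index_fn, grad=1.0):
    """Run the moment updates for num_iters steps on a constant gradient,
    then de-bias them with the step index i = index_fn(t) of the last step."""
    m = v = 0.0
    for _ in range(num_iters):
        m = (1 - B1) * grad + B1 * m
        v = (1 - B2) * grad ** 2 + B2 * v
    i = index_fn(num_iters - 1)
    mhat = m / (1 - B1 ** (i + 1))
    vhat = v / (1 - B2 ** (i + 1))
    return mhat, vhat


def step_size(t):
    """Staircase decay as in the suggestion: the schedule, not Adam's
    step index, absorbs the division by LR_DECAY_STEPS."""
    return CENTER_LR * DECAY_COEF ** (t // LR_DECAY_STEPS)


# Reported behaviour: Adam sees i = t // LR_DECAY_STEPS (0 for the first 1000 steps).
reported = debiased_moments(500, lambda t: t // LR_DECAY_STEPS)
# Suggested behaviour: Adam sees i = t.
suggested = debiased_moments(500, lambda t: t)
```

For a constant gradient of 1.0, both de-biased estimates should equal 1.0. That holds for `suggested`, while `reported` keeps dividing by `1 - B1` and `1 - B2` throughout the first `LR_DECAY_STEPS` steps, which inflates `mhat` by roughly a factor of 10 under these assumed values.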
    
    opened by garam-kim1 1
  • Save top n models per checkpoint

    As I understand it, currently only the best model from the population is saved at the end of each iteration. This may lead to inconsistent train/test results (due to overfitting) in some setups. Blending the top n models could potentially reduce this effect.

    Would you be interested in this feature for evojax? I can work on a PR. It seems that not all solvers can support this feature, though.
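
As a rough illustration of the idea, the sketch below picks the n best parameter vectors from a population by fitness and averages them. Averaging parameters is only one possible reading of "blending" (ensembling predictions would be another). The function names are hypothetical, not EvoJAX API, plain Python lists stand in for arrays, and higher fitness is assumed to be better.

```python
from typing import List, Sequence


def top_n_params(population: Sequence[Sequence[float]],
                 fitness: Sequence[float],
                 n: int) -> List[List[float]]:
    """Return the n parameter vectors with the highest fitness, best first."""
    order = sorted(range(len(fitness)), key=lambda i: fitness[i], reverse=True)
    return [list(population[i]) for i in order[:n]]


def blend(params: Sequence[Sequence[float]]) -> List[float]:
    """Element-wise mean of several parameter vectors."""
    return [sum(column) / len(params) for column in zip(*params)]


# Toy population of four 2-parameter "models" with made-up fitness values.
population = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]
fitness = [0.1, 0.9, 0.5, 0.7]
best_two = top_n_params(population, fitness, n=2)
blended = blend(best_two)
```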

    opened by danielgafni 0
  • fix support for multi-dim observations

    Hey! I found a bug in the observation normalization code. It occurs when the observations are a multi-dimensional array rather than a flat one: the obs_normalizer params are stored as a flat array, so the code fails in this case. Here is the fix for this bug.

    opened by danielgafni 5
  • JAX implementation of CMAES

    Hi, I'm really amazed by this library.

    Currently, CMAES is just a wrapper. I implemented a JAX CMAES based on https://github.com/CyberAgentAILab/cmaes/.

    opened by moskomule 7
Releases(v0.2.15)
Owner
Google