Brax is a differentiable physics engine that simulates environments made up of rigid bodies, joints, and actuators

Overview

Brax is a differentiable physics engine that simulates environments made up of rigid bodies, joints, and actuators. It's also a suite of learning algorithms to train agents to operate in these environments (PPO, SAC, evolutionary strategy, and direct trajectory optimization are implemented).

Brax is written in JAX and is designed to run on accelerator hardware. It is efficient for single-core training and scales to massively parallel simulation, without the need for pesky datacenters.

Some policies trained via Brax. Brax simulates these environments at millions of physics steps per second on TPU.

Colab Notebooks

Explore Brax easily and quickly through a series of Colab notebooks:

  • Brax Basics introduces the Brax API and shows how to simulate basic physics primitives.
  • Brax Training introduces Brax environments and training algorithms, and lets you train your own policies directly within the Colab.

Using Brax locally

To install Brax from source, clone this repo, cd to it, and then:

python3 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install -e .

To train a model:

learn

Training on NVIDIA GPUs is supported, but you must first install CUDA, cuDNN, and JAX with GPU support.

Citing Brax

If you would like to reference Brax in a publication, please use:

@software{brax2021github,
  author = {C. Daniel Freeman and Erik Frey and Anton Raichuk and Sertan Girgin and Igor Mordatch and Olivier Bachem},
  title = {Brax - A Differentiable Physics Engine for Large Scale Rigid Body Simulation},
  url = {http://github.com/google/brax},
  version = {0.1.0},
  year = {2021},
}
Comments
  • Question about JS development

    This is probably a really silly question, but I have no experience whatsoever with JS and anything web related.

    I'm trying to add a torus primitive to the set of colliders. I made a simple, mostly empty env to try it out in, and I'm using the HTML render functionality showcased in other notebook examples to try and debug my progress. I've added colliders, changed the protobuf definition and compiled it, etc. Not much of a clue what I'm doing yet, but all parts of the brax repo that mention spheres now have a torus equivalent. The plane I added in my env shows up nicely, but my torus does not show up. Could be many reasons of course, but I've found at least one: deleting createPlane in system.js in my repo does not stop my plane render code from working.

    Further digging reveals that html.py has a line import {Viewer} from 'https://cdn.jsdelivr.net/gh/google/[email protected]/js/viewer.js';

    Unless I'm missing something, that's not referring to viewer.js inside my repo, so it makes sense that I'm not seeing changes made there reflected in my notebook.

    Now I'm not sure if this is just some WIP development code that got merged, or if I'm missing something fundamental about the zen-of-webdev here; but if you were me and looking to make a change that would allow me to do local development on this JS, and would also stand a chance of getting merged, what would you do? Do I need to locally host my viewer.js and link the HTML to that? But then how would I merge a non-broken PR if it's supposed to refer to this CDN of a past brax release?

    In general, is there something I'm missing about doing local development? Is debugging using the 3JS viz in a notebook the way to go?

    opened by EelcoHoogendoorn 15
  • Add Brax to conda

    A number of machine learning projects use conda, and it is better for conda if all dependencies are also on conda.

    Could Brax be added to conda-forge?

    opened by pseudo-rnd-thoughts 14
  • Bug: capsule/sphere-plane collision working unreliably

    I'm trying to simulate a pool billiards tabletop that I've modeled with a plane as floor and 4 planes surrounding it, facing inward. If I now spawn a ball and bounce it off the walls, sometimes the collisions work well:

    (left is top-down view, right is side view)

    brax-collision-right1

    ...but sometimes it doesn't work and the ball gets stuck in the wall:

    brax-collision-weird3

    brax-collision-weird4

    I've made a Colab to reproduce the issue: https://colab.research.google.com/drive/1flnseQcjarIYM4G_rECTEAaps-kXPoey?usp=sharing

    If anybody has any pointers, that'd be greatly appreciated.

    Best, Flo

    bug 
    opened by fgolemo 14
  • Enable per-collider friction specification

    • Add optional float field "friction" to colliders; if unspecified, the global friction coefficient is applied
    • Re-compile protobuf file (config_pb2.py)
    • Introduce Collider dataclass to pass into _collide() and _collide_pair()
    • Replace references to config.friction with per-body friction (Collider.friction)
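
The fallback rule in the first bullet can be sketched as follows. This is a minimal illustration in plain Python, not the code from this PR: the `Collider` fields and the helper `effective_friction` are hypothetical names chosen for the example, and `None` is assumed to stand for "friction not specified".

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Collider:
    """Hypothetical stand-in for the Collider dataclass mentioned above."""
    name: str
    friction: Optional[float] = None  # None means "not specified"


def effective_friction(collider: Collider, global_friction: float) -> float:
    """Return the collider's own coefficient if set, else the global one."""
    if collider.friction is None:
        return global_friction
    return collider.friction


ice = Collider("ice", friction=0.05)
frictionless = Collider("frictionless", friction=0.0)
unspecified = Collider("unspecified")
```

Checking `is None` rather than truthiness keeps an explicit coefficient of 0.0 distinct from an unspecified one, so `effective_friction(frictionless, 0.6)` stays 0.0 while `effective_friction(unspecified, 0.6)` falls back to 0.6.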

    Fixes #55

    Please note:

    1. I did some primitive tests within physics_test.py but observed that the simulation is currently unstable with friction due to lateral friction, no matter how large I set the number of substeps. Hence I ended up rolling back, in case someone else can advise. I'd appreciate any help writing tests for these changes in the near future.

    2. I am preserving the plane-body (and mesh-body?) collision behaviour in _collide(), i.e. only the body's coefficient matters and not the plane's. However, a plane's collision coefficient is pretty important for rolling and sliding motions. I'm not a simulation expert yet, but I'd love to see how e.g. Bullet or MuJoCo implements lateral friction.

    cla: yes 
    opened by namheegordonkim 12
  • Support for drag force to implement gym swimmer environment

    I was trying to get drag force into my simulation in order to get the swimmer environment running (and possibly simple underwater simulations!) because the rest of the swimmer env is simple to port over to brax.

    The drag force can be described as follows:

    $F_d = -\frac{1}{2} \rho \|v\|^2 A C_d \frac{v}{\|v\|}$

    where:

    • \rho is the density of the fluid
    • ||v||^2 is the magnitude of velocity squared
    • A is the surface area in the direction of the velocity
    • C_d is the coefficient of drag (friction of interaction)

    \rho and C_d as well as the 1/2 could easily be merged into one constant as they initially play a minor role of determining the properties of the interaction.
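
As a quick illustration of the formula above, here is a minimal sketch in plain Python (kept dependency-free on purpose; it is not Brax code, and a version integrated into the engine would presumably operate on jax.numpy arrays instead). The function name `drag_force` and its parameter names are assumptions made for this example, and the projected area A is simply passed in as a number.

```python
import math


def drag_force(velocity, rho, area, drag_coefficient):
    """Quadratic drag: F_d = -1/2 * rho * |v|^2 * A * C_d * v / |v|."""
    speed = math.sqrt(sum(c * c for c in velocity))
    if speed == 0.0:
        # No motion means no drag; this also avoids dividing by zero.
        return tuple(0.0 for _ in velocity)
    magnitude = 0.5 * rho * speed ** 2 * area * drag_coefficient
    # The force points opposite to the direction of motion.
    return tuple(-magnitude * c / speed for c in velocity)


# Example: a body moving at 5 m/s in the x-y plane, with all constants set to 1.
force = drag_force((3.0, 4.0, 0.0), rho=1.0, area=1.0, drag_coefficient=1.0)
```

Merging the constants as suggested above would amount to replacing `0.5 * rho * drag_coefficient` with a single tunable parameter.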

    This can in principle be added easily at the location in the code where the forces are applied to the different bodies. I got stuck on two problems:

    • Where is the best location in the code to add this formula, or how do you want to add it? I think you guys know a lot better than I do @cdfreeman-google
    • Whereas the velocity of each body is certainly available, the surface area in the direction of velocity is certainly not (right?). I am pretty sure, as I could not think of a part of brax that would require it. I think I will have to implement, for each collider shape, a function that projects the collider onto a plane whose normal points in the direction of the velocity. I would be happy if there is an easier way to do this in brax.
    enhancement 
    opened by benelot 11
  • Performance nitpick

    https://github.com/google/brax/blob/8e58feb923ce86b7b8c7036a05429793bbc3fa65/brax/physics/math.py#L278

    Little nitpick but

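    import jax.numpy as jnp

    S = jnp.array([1., -1., -1., -1.])  # flips the vector part of [w, x, y, z]
    def inv_quat(q):
      return q * S  # conjugate; equals the inverse for unit quaternions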
    

    This benchmarks as 10% faster on my laptop CPU at least, and I suspect the same would be more true of architectures more aggressively tuned for vectorization. I don't have any experience with TPUs and their compilers, but I imagine this formulation would also make it easier for a GPU compiler to get to the GPU-optimal compiled code.
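
To see why multiplying by a sign vector gives the inverse, recall that for a unit quaternion the inverse equals the conjugate. The following plain-Python sketch checks this numerically. It is not the Brax implementation: it assumes [w, x, y, z] ordering (as the sign vector in the snippet suggests), and `quat_mul` is a helper written only for this illustration.

```python
import math

SIGNS = (1.0, -1.0, -1.0, -1.0)


def inv_quat(q):
    """Conjugate of q = (w, x, y, z); the inverse when q has unit length."""
    return tuple(c * s for c, s in zip(q, SIGNS))


def quat_mul(a, b):
    """Hamilton product of two quaternions in (w, x, y, z) order."""
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return (
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    )


# Unit quaternion for a rotation of 0.7 rad about the z axis.
q = (math.cos(0.35), 0.0, 0.0, math.sin(0.35))
identity = quat_mul(q, inv_quat(q))  # close to (1, 0, 0, 0)
```

For a quaternion that is not unit length, the conjugate would additionally have to be divided by the squared norm.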

    opened by EelcoHoogendoorn 10
  • External Torque

    For discussion: Beginning to implement external torque inputs. API might make sense to change, though, and is currently incomplete.

    The idea is that the frozen field on the Thruster could indicate whether to freeze either force or torque axes, allowing anywhere from 0/1 to 6 DOF for a Thruster.

    Currently this ignores the frozen field and provides all possible DOF. In addition, it might be worth adding a different strength field for torque versus force. However, I'm not going crazy in case a preferred approach is an entirely separate class for external torques than forces.

    This builds on #94 and addresses #61

    cla: yes 
    opened by peabody124 9
  • Multi-Agent Environments

    Hello,

    Are you planning to create any multi-agent environments, such as crowd simulation?

    Is there also a possibility to have non-uniform terrain, walls, etc. in each environment,

    so that each agent can be initialized in a random location to vary its experience?

    (Without that, I don't see a major advantage in the engine's parallel simulation capability.)

    Sincerely, Kamer

    question 
    opened by kayuksel 9
  • Support of height maps and collision between box corner and height map

    Hello maintainers of Brax,

    I really like your repo and would like to add support for height maps in order to train locomotion policies in uneven terrain.

    This PR implements both:

    • The visualization of height maps with the THREE interface.
    • The collision handling between height maps and box corners.

    I hope this might be useful to you and wish the best for your project.

    cla: yes 
    opened by o-Oscar 9
  • 'jaxlib.xla_extension' has no attribute 'CpuDevice'

    Thanks for your great work. I just finished the installation, and the versions of the libs are:

    brax              0.0.12 
    jax                0.3.7
    jaxlib             0.3.7+cuda11.cudnn805
    

    However, when I run the "learn" from the README, the log shows:

    Traceback (most recent call last):
      File "/home/yangwang/brax-0.0.12/env/bin/learn", line 7, in <module>
        exec(compile(f.read(), __file__, 'exec'))
      File "/home/yangwang/brax-0.0.12/bin/learn", line 4, in <module>
        from brax.training import learner
      File "/home/yangwang/brax-0.0.12/brax/training/learner.py", line 26, in <module>
        from brax.training import apg
      File "/home/yangwang/brax-0.0.12/brax/training/apg.py", line 32, in <module>
        import optax
      File "/home/yangwang/brax-0.0.12/env/lib/python3.9/site-packages/optax/__init__.py", line 17, in <module>
        from optax import experimental
      File "/home/yangwang/brax-0.0.12/env/lib/python3.9/site-packages/optax/experimental/__init__.py", line 20, in <module>
        from optax._src.experimental.complex_valued import split_real_and_imaginary
      File "/home/yangwang/brax-0.0.12/env/lib/python3.9/site-packages/optax/_src/experimental/complex_valued.py", line 32, in <module>
        import chex
      File "/home/yangwang/brax-0.0.12/env/lib/python3.9/site-packages/chex/__init__.py", line 17, in <module>
        from chex._src.asserts import assert_axis_dimension
      File "/home/yangwang/brax-0.0.12/env/lib/python3.9/site-packages/chex/_src/asserts.py", line 26, in <module>
        from chex._src import asserts_internal as _ai
      File "/home/yangwang/brax-0.0.12/env/lib/python3.9/site-packages/chex/_src/asserts_internal.py", line 32, in <module>
        from chex._src import pytypes
      File "/home/yangwang/brax-0.0.12/env/lib/python3.9/site-packages/chex/_src/pytypes.py", line 40, in <module>
        CpuDevice = jax.lib.xla_extension.CpuDevice
    AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'
    

    I do not have a clue about this problem, and I am already using the latest versions of both brax and jax (with the GPU option).

    I hope you can give me some suggestions.

    opened by jzhzhang 8
  • Add Acrobot Environment, Remove Redundant Inverted Double Pendulum Observations

    Hello,

    I've been experimenting with some swing-up pendulum environments, and I thought Brax might find some of them useful or interesting.

    I am proposing adding an Acrobot environment to Brax. The one I am submitting here is a bit different from, and more difficult than, the one in e.g. gym, since the action space is continuous, and to "solve" the environment an agent must both swing up and balance the system. I like this environment because it is deceptively difficult for most model-free RL (see for example this paper I wrote a couple of years ago).

    I've got my own twist on APG that works well for this acrobot, but I haven't yet been able to get good performance on it from any of the brax algorithms (including APG). I would be very interested to hear if anyone gets the brax RL working well with the environment!

    In this PR I also remove some redundant / useless observations from the double inverted pendulum.

    Let me know what you all think!

    Edit: Also, here is a basic smoke test training setup in Colab.

    opened by sgillen 8
  • Added loading mass and joint limits in URDF model importer.

    I have added the URDF mass property and the upper/lower revolute joint limits to the Brax model importer.

    Test updated accordingly.

    I have not added the inertia properties, since the internal model config differs from the official format, in which inertia is a 3x3 matrix.

    The inertia property in Body is a Vector3. I guess changing it to a 3x3 matrix could break other parts of the engine; if this change should be made, I would gladly take it on.

    https://github.com/google/brax/blob/7eaa16b4bf446b117b538dbe9c9401f97cf4afa2/brax/physics/config.proto#L24-L35


    How would the other limit properties fit into the engine computations? I tried to look for effort and velocity.

    I saw limit_strength, but without being sure of the units it uses I didn't want to break anything.

    For velocity I did not find any similar parameter in the back end. Are joint speeds limited?

    Regarding the joint types "universal" and "spherical", it was not clear to me how they are written in URDF, since they are not official. I saw similar ones in Gazebo, SDF and MuJoCo. Because of this I left them as they were.

    opened by ManuCorrea 0
  • JaxToTorchWrapper error with jax 0.4.1

    Hello, I am trying to run some experiments using pytorch with the JaxToTorchWrapper. I'm running the default Training in Brax with PyTorch on GPUs on a local jupyter instance, but there are errors.

    image

    The error occurs when I use jax==0.4.1 and goes away when I use a lower version. It seems like it has to do with the new jax.Array type introduced in 0.4.1.

    Environment:

    • Python 3.10.7
    • Cuda 11.8
    • jax[cuda]
    • brax==0.0.16

    Thanks!

    opened by jypark0 0
  • `mujoco_convert.py` does not use the default values of MJCF

    I tried running mujoco_convert.py on a sample MJCF file provided on the dm_control repository (this one, but others have similar issues). I get errors such as "unsupported geom type : None", because the type attribute of the elements is not always defined in the MJCF, since it relies on the defaults defined by MuJoCo in their XML reference. For instance, the documentation seems to indicate that the default value for the attribute "type" is "sphere".

    I asked over at dm_control whether the parse_xml function worked as intended or if they planned on adding the default values to it, and they said it was the expected behaviour. So I think brax should take those defaults into account after the parse_xml function. What do you think?
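
One way to picture the proposed post-processing step is the sketch below. It uses only Python's standard `xml.etree.ElementTree` rather than dm_control's parse_xml or anything in mujoco_convert.py, and the helper name `fill_default_geom_types` is made up for this example. It covers only the built-in default for the geom "type" attribute ("sphere", as noted above) and ignores any `<default>` classes defined in the MJCF file itself.

```python
import xml.etree.ElementTree as ET

# Built-in default for a geom's "type" attribute, per the MuJoCo XML reference.
DEFAULT_GEOM_TYPE = "sphere"


def fill_default_geom_types(mjcf_xml: str) -> ET.Element:
    """Parse MJCF text and give every geom without a type the default one."""
    root = ET.fromstring(mjcf_xml.strip())
    for geom in root.iter("geom"):
        if geom.get("type") is None:
            geom.set("type", DEFAULT_GEOM_TYPE)
    return root


SAMPLE = """
<mujoco>
  <worldbody>
    <geom name="ball" size="0.1"/>
    <geom name="limb" type="capsule" size="0.05 0.2"/>
  </worldbody>
</mujoco>
"""

root = fill_default_geom_types(SAMPLE)
geom_types = {g.get("name"): g.get("type") for g in root.iter("geom")}
```

Here the "ball" geom, which has no type attribute, ends up as a sphere, while the explicitly typed "limb" geom is left untouched.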

    opened by Theo-Cheynel 2
  • TracedConfig does not influence rendering

    I was trying to extend the domain randomization examples to include changing collider shapes using the code below. However, the renderer doesn't pick up the changes. Specifically this occurs because the mesh information comes from json_format.MessageToDict which does not pick up any of the changes in TracedConfig.

    I don't know enough about the calls json_format.MessageToDict is making into the internal structures to know what to override, but if anyone knew and wanted to point me in that direction, I would be happy to make a PR.

    def scale_bodies(config, body_scale_dict: dict):
      """Constructs a custom tree that scales the capsule colliders of each body.

      Bodies whose name contains a key of body_scale_dict are scaled by the
      matching value; all other bodies keep a scale of 1.0.

      Args:
        config: Config whose bodies are scaled
        body_scale_dict: Maps a body-name key to its scale factor

      Returns:
        A pytree containing the scaled values packed into a tree structure
        parseable by the TracedConfig class
      """

      def scale_body(b, x):
        # Scale every collider's position and capsule dimensions by x.
        colliders = []
        for c in b.colliders:
          collider = {
            'position': {
              'x': c.position.x * x,
              'y': c.position.y * x,
              'z': c.position.z * x
            },
            'capsule': {
              'length': c.capsule.length * x,
              'radius': c.capsule.radius * x
            }
          }
          colliders.append(collider)
        return {'colliders': colliders}

      custom_tree = {'bodies': []}

      for b in config.bodies:
        # Use the first key contained in the body name; default to no scaling.
        # (Looking up body_scale_dict[b.name] fails for partial-name matches.)
        scale = next((s for k, s in body_scale_dict.items() if k in b.name), 1.0)
        custom_tree['bodies'].append(scale_body(b, scale))

      return custom_tree
    
    opened by peabody124 1
  • Agents: Short-Horizon Actor Critic

    This is an implementation of https://arxiv.org/pdf/2204.07137.pdf

    Not sure if there is interest in merging this into the main branch. This might be an algorithm worth supporting, as it leverages the differentiable simulator to outperform PPO according to the paper.

    Note that many of the environments don't actually have rewards that are differentiable w.r.t. the actions, in which case this algorithm performs poorly. For example, the fast environment used for testing APG and SHAC isn't. I added a fast_differentiable env and also made APG use this by default, after which the performance is much better.

    It could still do with tuning for environments and with replicating the performance benefits seen in the original manuscript.

    Addresses #247

    opened by peabody124 2
  • feat(composer): agent specific observations

    Hey 👋

    I am working in my fork on supporting agent specific observations in the composer sub-package. Additionally I need this feature to support observation masking that is dependent on the current state of the environment.

    Best approach I can see:

    1. Each agent component already accepts observers: a) Similar to how reward_fns is done b) This could support dynamic masking
    2. Introduce two methods to flatten and unflatten the observation array so that each agent can be given its specific observation.
    3. Pass this sort of information inside the edge parameter to Composer: a) Could be part of its responsibility to "automatically create necessary edge information among 2+ components" b) But maybe edge information is only meant to care about pairs of components
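
The flatten/unflatten pair from step 2 could look roughly like the sketch below. It is a plain-Python illustration under assumed names (`flatten_observations`, `unflatten_observations`, and the `layout` index map are not part of the composer API); a real implementation would presumably work on JAX arrays rather than lists.

```python
from typing import Dict, List, Tuple

# Maps an agent name to its [start, stop) slice in the flat observation.
Layout = Dict[str, Tuple[int, int]]


def flatten_observations(
    per_agent: Dict[str, List[float]]
) -> Tuple[List[float], Layout]:
    """Concatenate per-agent observations and record where each one lives."""
    flat: List[float] = []
    layout: Layout = {}
    for agent, obs in per_agent.items():
        layout[agent] = (len(flat), len(flat) + len(obs))
        flat.extend(obs)
    return flat, layout


def unflatten_observations(
    flat: List[float], layout: Layout
) -> Dict[str, List[float]]:
    """Recover each agent's observation from the flat list."""
    return {agent: flat[start:stop] for agent, (start, stop) in layout.items()}


per_agent = {"agent_0": [0.1, 0.2, 0.3], "agent_1": [1.0, 2.0]}
flat, layout = flatten_observations(per_agent)
restored = unflatten_observations(flat, layout)
```

Keeping the layout separate from the data means that agents with observations of different sizes can share one flat array.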

    Let me know if this sounds cool with you guys, but otherwise I am going to make a start on this feature 😀

    opened by cemlyn007 4
Releases(v0.1.0)
  • v0.1.0(Dec 21, 2022)

    Brax v0.1.0 Release Notes

    This minor release adds a preview of a major overhaul to Brax's API and functionality. This overhaul (found in the v2/ folder) will eventually become Brax's first stable (1.0) release.

    The new features of Brax v2 include:

    • Generalized physics backend.
    • Continued support for the Spring physics backends. PBD will soon follow.
    • Direct support for Mujoco XML format, and URDF by association.
    • Fully traceable System object.
    • Env API that better supports custom physics backends.
    • Open sourced visualizer server.
  • v0.0.16(Dec 15, 2022)

    Brax v0.0.16 Release Notes

    This release adds a new module: brax.experimental.tracing that allows for domain randomization during training. This release also adds support for placing replay buffers on device using pjit which allows for more configurable parallelism across many devices. Finally this release includes a number of small bug fixes.

    This will be the final release before we release a preview of a significant API change, so users may want to pin to this version if API stability is important.

  • v0.0.15(Sep 9, 2022)

    Brax v0.0.14 Release Notes

    This release includes a refactor of the training code to make it more modular and hackable, with each algorithm now as a separate submodule under brax.training.agents.

    This release also updates references to the deprecated jax.tree* functions to their new home in jax.tree_util, fixes a few bugs in physics/collision code, and adds an initial implementation of box-box collisions.

  • v0.0.14(Sep 9, 2022)

    Brax v0.0.14 Release Notes

    This release includes a refactor of the training code to make it more modular and hackable, with each algorithm now as a separate submodule under brax.training.agents.

    This release also updates references to the deprecated jax.tree* functions to their new home in jax.tree_util, fixes a few bugs in physics/collision code, and adds an initial implementation of box-box collisions.

  • v0.0.13(May 4, 2022)

    Brax v0.0.13 Release Notes

    This release fixes a few bugs in the collision handling in PBD, and adds support for specifying collider visibility, color, and contact participation.

  • v0.0.12(Mar 16, 2022)

  • v0.0.11(Mar 16, 2022)

    Brax Version 0.0.11 Release Notes

    This version introduces a significant overhaul to the physics algorithms. We now support position based dynamics for resolving joint and collision constraints. See this paper for details about PBD.

    The most noticeable difference to prior versions of Brax is that joints are now modeled as infinitely stiff, whereas before they were stiff damped spring systems. This new physics is now default, and all environments use PBD-based joints and collisions by default.

    If you would like to preserve the behavior used in previous versions of brax, you can either:

    1. Version pin to 0.0.10 – the version right before this upgrade. While you will not get the latest and greatest improvements to Brax, you will have unambiguously consistent behavior.

    2. Add dynamics_mode: "legacy_spring" to your brax configuration file. This causes brax to take the old code path.

    3. Supply legacy_spring=True as a kwarg to env creation. This causes Brax to load the older config for all the environments currently defined in Brax (see the logic in the init functions of each env for details).

    Thank you for using Brax, and feel free to open an Issue if you have any questions!

  • v0.0.10(Dec 13, 2021)

  • v0.0.9(Dec 10, 2021)

  • v0.0.8(Nov 29, 2021)

  • v0.0.7(Nov 8, 2021)

  • v0.0.6(Oct 4, 2021)

  • v0.0.5(Sep 10, 2021)

  • v0.0.4(Aug 17, 2021)

Owner
Google
Google ❤️ Open Source