Code for "Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations"

Overview

Infinitely Deep Bayesian Neural Networks with SDEs

This library contains JAX and Pytorch implementations of neural ODEs and Bayesian layers for stochastic variational inference. A rudimentary JAX implementation of differentiable SDE solvers is also provided, refer to torchsde [2] for a full set of differentiable SDE solvers in Pytorch and similarly to torchdiffeq [3] for differentiable ODE solvers.

Continuous-depth hidden unit trajectories in Neural ODE vs uncertain posterior dynamics SDE-BNN.

Installation

This library runs on jax==0.1.77 and torch==1.6.0. To install all other requirements:

pip install -r requirements.txt

Note: Package versions may change, refer to official JAX installation instructions here.

JaxSDE: Differentiable SDE Solvers in JAX

The jaxsde library contains SDE solvers in the Ito and Stratonovich form. Solvers of different orders can be specified with the following method={euler_maruyama|milstein|euler_heun} (strong orders 0.5|1|0.5 and orders 1|1|1 in the case of an additive noise SDE). Stochastic adjoint (sdeint_ito) training mode does not work efficiently yet, use sdeint_ito_fixed_grid for now. Tradeoff solver speed for precision during training or inference by adjusting --nsteps <# steps>.

Usage

Default solver: Backpropagation through the solver.

from jaxsde.jaxsde.sdeint import sdeint_ito_fixed_grid

y1 = sdeint_ito_fixed_grid(f, g, y0, ts, rng, fw_params, method="euler_maruyama")

Stochastic adjoint: Using O(1) memory instead of solving an adjoint SDE in the backward pass.

from jaxsde.jaxsde.sdeint import sdeint_ito

y1 = sdeint_ito(f, g, y0, ts, rng, fw_params, method="milstein")

Brax: Bayesian SDE Framework in JAX

Implementation of composable Bayesian layers in the stax API. Our SDE Bayesian layers can be used with the SDEBNN block composed with multiple parameterizations of time-dependent layers in diffeq_layers. Sticking-the-landing (STL) trick can be enabled during training with --stl for improving convergence rate. Augment the inputs by a custom amount --aug <integer>, set the number of samples averaged over with --nsamples <integer>. If memory constraints pose a problem, train in gradient accumulation mode: --acc_grad and gradient checkpointing: --remat.

Samples from SDEBNN-learned predictive prior and posterior density distributions.

Usage

All examples can be swapped in with different vision datasets. For better readability, tensorboard logging has been excluded (see torchbnn instead).

Toy 1D regression to learn complex posteriors:

python examples/jax/sdebnn_toy1d.py --ds cos --activn swish --loss laplace --kl_scale 1. --diff_const 0.2 --driftw_scale 0.1 --aug_dim 2 --stl --prior_dw ou

Image Classification:

To train an SDEBNN model:

python examples/jax/sdebnn_classification.py --output <output directory> --model sdenet --aug 2 --nblocks 2-2-2 --diff_coef 0.2 --fx_dim 64 --fw_dims 2-64-2 --nsteps 20 --nsamples 1

To train a ResNet baseline, specify --model resnet and for a Bayesian ResNet baseline, specify --meanfield_sdebnn.

TorchBNN: SDE-BNN in Pytorch

A PyTorch implementation of the Brax framework powered by the torchsde backend.

Usage

All examples can be swapped in with different vision datasets and includes tensorboard logging for critical metrics.

Toy 1D regression to learn multi-modal posterior:

python examples/torch/sdebnn_toy1d.py --output_dir <dst_path>

Arbitrarily expression approximate posteriors from learning non-Gaussian marginals.

Image Classification:

All hyperparameters can be found in the training script. Train with adjoint for memory efficient backpropagation and adaptive mode for adaptive computation (and ensure --adjoint_adaptive True if training with adjoint and adaptive modes).

python examples/torch/sdebnn_classification.py --train-dir <output directory> --data cifar10 --dt 0.05 --method midpoint --adjoint True --adaptive True --adjoint_adaptive True --inhomogeneous True

References

[1] Winnie Xu, Ricky T. Q. Chen, Xuechen Li, David Duvenaud. "Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations." Preprint 2021. [arxiv]

[2] Xuechen Li, Ting-Kam Leonard Wong, Ricky T. Q. Chen, David Duvenaud. "Scalable Gradients for Stochastic Differential Equations." AISTATS 2020. [arxiv]

[3] Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." NeurIPS. 2018. [arxiv]


If you found this library useful in your research, please consider citing

@article{xu2021sdebnn,
  title={Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations},
  author={Xu, Winnie and Chen, Ricky T. Q. and Li, Xuechen and Duvenaud, David},
  archivePrefix = {arXiv},
  year={2021}
}
Owner
Winnie Xu
Undergrad in CS/Stats/Math '22 @ UToronto. Working on something secret @cohere-ai. Deep neural networks @for-ai @VectorInstitute. Prev. @google-research @NVIDIA
Winnie Xu
Perception-aware multi-sensor fusion for 3D LiDAR semantic segmentation (ICCV 2021)

Perception-Aware Multi-Sensor Fusion for 3D LiDAR Semantic Segmentation (ICCV 2021) [中文|EN] 概述 本工作主要探索一种高效的多传感器(激光雷达和摄像头)融合点云语义分割方法。现有的多传感器融合方法主要将点云投影

ICE 126 Dec 30, 2022
Usable Implementation of "Bootstrap Your Own Latent" self-supervised learning, from Deepmind, in Pytorch

Bootstrap Your Own Latent (BYOL), in Pytorch Practical implementation of an astoundingly simple method for self-supervised learning that achieves a ne

Phil Wang 1.4k Dec 29, 2022
Code repo for "Cross-Scale Internal Graph Neural Network for Image Super-Resolution" (NeurIPS'20)

IGNN Code repo for "Cross-Scale Internal Graph Neural Network for Image Super-Resolution" [paper] [supp] Prepare datasets 1 Download training dataset

Shangchen Zhou 278 Jan 03, 2023
This is the official pytorch implementation of the BoxEL for the description logic EL++

BoxEL: Box EL++ Embedding This is the official pytorch implementation of the BoxEL for the description logic EL++. BoxEL++ is a geometric approach bas

1 Nov 03, 2022
eXPeditious Data Transfer

xpdt: eXPeditious Data Transfer About xpdt is (yet another) language for defining data-types and generating code for serializing and deserializing the

Gianni Tedesco 3 Jan 06, 2022
Text-to-Image generation

Generate vivid Images for Any (Chinese) text CogView is a pretrained (4B-param) transformer for text-to-image generation in general domain. Read our p

THUDM 1.3k Dec 29, 2022
harmonic-percussive-residual separation algorithm wrapped as a VST3 plugin (iPlug2)

Harmonic-percussive-residual separation plug-in This work is a study on the plausibility of a sines-transients-noise decomposition inspired algorithm

Derp Learning 9 Sep 01, 2022
Paper Title: Heterogeneous Knowledge Distillation for Simultaneous Infrared-Visible Image Fusion and Super-Resolution

HKDnet Paper Title: "Heterogeneous Knowledge Distillation for Simultaneous Infrared-Visible Image Fusion and Super-Resolution" Email:

wasteland 11 Nov 12, 2022
Yolo algorithm for detection + centroid tracker to track vehicles

Vehicle Tracking using Centroid tracker Algorithm used : Yolo algorithm for detection + centroid tracker to track vehicles Backend : opencv and python

6 Dec 21, 2022
Ludwig is a toolbox that allows to train and evaluate deep learning models without the need to write code.

Translated in 🇰🇷 Korean/ Ludwig is a toolbox that allows users to train and test deep learning models without the need to write code. It is built on

Ludwig 8.7k Jan 05, 2023
Amazing-Python-Scripts - 🚀 Curated collection of Amazing Python scripts from Basics to Advance with automation task scripts.

📑 Introduction A curated collection of Amazing Python scripts from Basics to Advance with automation task scripts. This is your Personal space to fin

Avinash Ranjan 1.1k Dec 29, 2022
🏃‍♀️ A curated list about human motion capture, analysis and synthesis.

Awesome Human Motion 🏃‍♀️ A curated list about human motion capture, analysis and synthesis. Contents Introduction Human Models Datasets Data Process

Dennis Wittchen 274 Dec 14, 2022
Entity-Based Knowledge Conflicts in Question Answering.

Entity-Based Knowledge Conflicts in Question Answering Run Instructions | Paper | Citation | License This repository provides the Substitution Framewo

Apple 35 Oct 19, 2022
some academic posters as references. May we have in-person poster session soon!

some academic posters as references. May we have in-person poster session soon!

Bolei Zhou 472 Jan 06, 2023
We are More than Our JOints: Predicting How 3D Bodies Move

We are More than Our JOints: Predicting How 3D Bodies Move Citation This repo contains the official implementation of our paper MOJO: @inproceedings{Z

72 Oct 20, 2022
Collision risk estimation using stochastic motion models

collision_risk_estimation Collision risk estimation using stochastic motion models. This is a new approach, based on stochastic models, to predict the

Unmesh 7 Jun 26, 2022
CZU-MHAD: A multimodal dataset for human action recognition utilizing a depth camera and 10 wearable inertial sensors

CZU-MHAD: A multimodal dataset for human action recognition utilizing a depth camera and 10 wearable inertial sensors   In order to facilitate the res

yujmo 11 Dec 12, 2022
For AILAB: Cross Lingual Retrieval on Yelp Search Engine

Cross-lingual Information Retrieval Model for Document Search Train Phase CUDA_VISIBLE_DEVICES="0,1,2,3" \ python -m torch.distributed.launch --nproc_

Chilia Waterhouse 104 Nov 12, 2022
PyTorch code for our paper "Attention in Attention Network for Image Super-Resolution"

Under construction... Attention in Attention Network for Image Super-Resolution (A2N) This repository is an PyTorch implementation of the paper "Atten

Haoyu Chen 71 Dec 30, 2022
[CVPR 2021] Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers

[CVPR 2021] Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers

Fudan Zhang Vision Group 897 Jan 05, 2023