Differentiable Factor Graph Optimization for Learning Smoothers
Overview
Code release for our IROS 2021 conference paper:
Brent Yi¹, Michelle A. Lee¹, Alina Kloss², Roberto Martín-Martín¹, and Jeannette Bohg¹. Differentiable Factor Graph Optimization for Learning Smoothers. Proceedings of the IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS), October 2021.
¹Stanford University, {brentyi,michellelee,robertom,bohg}@cs.stanford.edu
²Max Planck Institute for Intelligent Systems
This repository contains models, training scripts, and experimental results, and can be used to either reproduce our results or as a reference for implementation details.
Significant chunks of the code written for this paper have been factored out of this repository and released as standalone libraries, which may be useful for building on our work. You can find each of them linked here:
- jaxfg is our core factor graph optimization library.
- jaxlie is our Lie theory library for working with rigid body transformations.
- jax_dataclasses is our library for building JAX pytrees as dataclasses. It's similar to flax.struct, but has workflow improvements for static analysis and nested structures.
- jax-ekf contains our EKF implementation.
Status
Included in this repo for the disk task:
- Smoother training & results
  - Training: python train_disk_fg.py --help
  - Evaluation: python cross_validate.py --experiment-paths ./experiments/disk/fg/**/
- Filter baseline training & results
  - Training: python train_disk_ekf.py --help
  - Evaluation: python cross_validate.py --experiment-paths ./experiments/disk/ekf/**/
- LSTM baseline training & results
  - Training: python train_disk_lstm.py --help
  - Evaluation: python cross_validate.py --experiment-paths ./experiments/disk/lstm/**/
And, for the visual odometry task:
- Smoother training & results (including ablations)
  - Training: python train_kitti_fg.py --help
  - Evaluation: python cross_validate.py --experiment-paths ./experiments/kitti/fg/**/
- EKF baseline training & results
  - Training: python train_kitti_ekf.py --help
  - Evaluation: python cross_validate.py --experiment-paths ./experiments/kitti/ekf/**/
- LSTM baseline training & results
  - Training: python train_kitti_lstm.py --help
  - Evaluation: python cross_validate.py --experiment-paths ./experiments/kitti/lstm/**/
Note that **/ indicates a recursive glob in zsh. This can be emulated in bash 4.0 or later by enabling the globstar option (shopt -s globstar).
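If you would rather not depend on shell globbing at all, Python's standard library can expand the same pattern. A minimal sketch (the helper name find_experiment_dirs is hypothetical and not part of this repo):

```python
import glob
from typing import List


def find_experiment_dirs(pattern: str) -> List[str]:
    """Expand a recursive glob such as "./experiments/disk/**/".

    With recursive=True, "**" matches zero or more directories, and a trailing
    separator restricts matches to directories. The base directory itself is
    therefore included, which mirrors how zsh expands **/.
    """
    return sorted(glob.glob(pattern, recursive=True))


if __name__ == "__main__":
    # Print every experiment directory under the disk task.
    for path in find_experiment_dirs("./experiments/disk/**/"):
        print(path)
```

The resulting list could then be passed as arguments to cross_validate.py (for example via subprocess) without relying on the shell to expand the pattern.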
We've done our best to make our research code easy to parse, but it's still being iterated on! If you have questions, suggestions, or any general comments, please reach out or file an issue.
Setup
We use Python 3.8 and miniconda for development.
- Any calls to CHOLMOD (via scikit-sparse, sometimes used for eval but never for training itself) will require SuiteSparse:
  # Mac
  brew install suite-sparse
  # Debian
  sudo apt-get install -y libsuitesparse-dev
- Dependencies can be installed via pip:
  pip install -r requirements.txt
  In addition to JAX and the first-party dependencies listed above, note that this also includes various other helper packages.
  The requirements.txt provided will install the CPU version of JAX by default. For CUDA support, please see the instructions from the JAX team.
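After installing, one quick way to confirm which of these packages Python can see is importlib.util.find_spec. The script below is a convenience sketch, not part of this repo; the module list is our reading of the dependencies named in this section, written as import names rather than pip package names.

```python
import importlib.util
from typing import Dict, Iterable

# Import names (not pip package names) of dependencies mentioned in this README.
# scikit-sparse, for example, is imported as `sksparse`.
DEPENDENCIES = ("jax", "jaxlib", "jaxfg", "jaxlie", "jax_dataclasses", "h5py", "sksparse")


def check_dependencies(modules: Iterable[str] = DEPENDENCIES) -> Dict[str, bool]:
    """Map each top-level module name to whether it can be found.

    find_spec() locates a module without importing it, so this will not catch
    failures that only surface at import time (e.g. a missing SuiteSparse
    shared library behind sksparse).
    """
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, found in check_dependencies().items():
        print(f"{name:16s} {'ok' if found else 'missing'}")
```

Once JAX is installed, jax.devices() reports the devices it will run on; with the default CPU-only install, it lists only CPU devices.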
Datasets
Datasets are synced from Google Drive and loaded via h5py automatically as needed. If you're interested in downloading them manually, see lib/kitti/data_loading.py and lib/disk/data_loading.py.
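The "download automatically as needed" behavior follows a common fetch-on-first-use pattern. A minimal sketch, assuming a hypothetical ensure_local_copy helper and a caller-supplied download callback; the actual logic lives in the data_loading.py files above and may differ:

```python
import os
from typing import Callable


def ensure_local_copy(path: str, download: Callable[[str], None]) -> str:
    """Return `path`, fetching it first if it is not already on disk.

    `download` is a hypothetical placeholder for whatever fetches the file
    (e.g. a Google Drive sync); it is only invoked on a cache miss.
    """
    if not os.path.exists(path):
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        # Download to a temporary name first so an interrupted download
        # never leaves a truncated file at `path`.
        tmp_path = path + ".partial"
        download(tmp_path)
        os.replace(tmp_path, path)
    return path
```

Once the file exists locally, it can be opened read-only with h5py.File(path, "r").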
Training
The naming convention for training scripts is as follows: train_{task}_{model type}.py.
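Because of this convention, a script name is enough to locate its results. The sketch below parses a script name and builds the matching evaluation glob; the experiments/{task}/{model type}/ layout is inferred from the evaluation commands in the Status section, and the helper names are hypothetical.

```python
import re
from typing import NamedTuple

# Matches train_{task}_{model type}.py, e.g. train_kitti_lstm.py.
_SCRIPT_PATTERN = re.compile(r"^train_(?P<task>[a-z0-9]+)_(?P<model>[a-z0-9]+)\.py$")


class ScriptInfo(NamedTuple):
    task: str
    model_type: str


def parse_training_script(filename: str) -> ScriptInfo:
    """Split a training script name into its task and model type."""
    match = _SCRIPT_PATTERN.match(filename)
    if match is None:
        raise ValueError(f"Not a training script name: {filename!r}")
    return ScriptInfo(match["task"], match["model"])


def experiment_glob(filename: str) -> str:
    """Glob for the results produced by a given training script."""
    info = parse_training_script(filename)
    return f"./experiments/{info.task}/{info.model_type}/**/"
```

For example, experiment_glob("train_kitti_ekf.py") gives the same path that is passed to cross_validate.py for the KITTI EKF baseline.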
All of the training scripts provide a command-line interface for configuring experiment details and hyperparameters. The --help flag will summarize these settings and their default values. For example, to see the options for the factor graph training script on the disk task, run:
> python train_disk_fg.py --help
Factor graph training script for disk task.
optional arguments:
-h, --help show this help message and exit
--experiment-identifier EXPERIMENT_IDENTIFIER
(default: disk/fg/default_experiment/fold_{dataset_fold})
--random-seed RANDOM_SEED
(default: 94305)
--dataset-fold {0,1,2,3,4,5,6,7,8,9}
(default: 0)
--batch-size BATCH_SIZE
(default: 32)
--train-sequence-length TRAIN_SEQUENCE_LENGTH
(default: 20)
--num-epochs NUM_EPOCHS
(default: 30)
--learning-rate LEARNING_RATE
(default: 0.0001)
--warmup-steps WARMUP_STEPS
(default: 50)
--max-gradient-norm MAX_GRADIENT_NORM
(default: 10.0)
--noise-model {CONSTANT,HETEROSCEDASTIC}
(default: CONSTANT)
--loss {JOINT_NLL,SURROGATE_LOSS}
(default: SURROGATE_LOSS)
--pretrained-virtual-sensor-identifier PRETRAINED_VIRTUAL_SENSOR_IDENTIFIER
(default: disk/pretrain_virtual_sensor/fold_{dataset_fold})
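For readers building a similar interface, here is a hand-written argparse equivalent of the options above. It is an illustration only and not necessarily how the training scripts construct their CLI. The enum classes are hypothetical, and substituting the fold number into the {dataset_fold} placeholders is our assumption based on the defaults shown.

```python
import argparse
import enum


class NoiseModel(enum.Enum):
    CONSTANT = enum.auto()
    HETEROSCEDASTIC = enum.auto()


class Loss(enum.Enum):
    JOINT_NLL = enum.auto()
    SURROGATE_LOSS = enum.auto()


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Factor graph training script for disk task.")
    parser.add_argument("--experiment-identifier", default="disk/fg/default_experiment/fold_{dataset_fold}")
    parser.add_argument("--random-seed", type=int, default=94305)
    parser.add_argument("--dataset-fold", type=int, choices=range(10), default=0)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--train-sequence-length", type=int, default=20)
    parser.add_argument("--num-epochs", type=int, default=30)
    parser.add_argument("--learning-rate", type=float, default=1e-4)
    parser.add_argument("--warmup-steps", type=int, default=50)
    parser.add_argument("--max-gradient-norm", type=float, default=10.0)
    # Enum-valued flags accept member names, matching the {A,B} choices shown in --help.
    parser.add_argument("--noise-model", choices=[m.name for m in NoiseModel], default="CONSTANT")
    parser.add_argument("--loss", choices=[m.name for m in Loss], default="SURROGATE_LOSS")
    parser.add_argument(
        "--pretrained-virtual-sensor-identifier",
        default="disk/pretrain_virtual_sensor/fold_{dataset_fold}",
    )
    return parser


def parse_config(argv=None) -> argparse.Namespace:
    args = build_parser().parse_args(argv)
    # Convert enum names back into enum members.
    args.noise_model = NoiseModel[args.noise_model]
    args.loss = Loss[args.loss]
    # Assumption: fill in the {dataset_fold} placeholders, e.g. ".../fold_3".
    args.experiment_identifier = args.experiment_identifier.format(dataset_fold=args.dataset_fold)
    args.pretrained_virtual_sensor_identifier = args.pretrained_virtual_sensor_identifier.format(
        dataset_fold=args.dataset_fold
    )
    return args
```

With no arguments, parse_config([]) reproduces the defaults listed in the help output.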
When run, train scripts serialize experiment configurations to an experiment_config.yaml file. You can find hyperparameters in the experiments/ directory for all results presented in our paper.
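As an illustration of the save/reload round trip, here is a minimal sketch that writes a flat config as key: value lines. It is not the repo's serializer. The ExperimentConfig dataclass is a hypothetical subset of the options above, the file is assumed to sit directly in the experiment directory, and values are encoded with json.dumps because JSON strings, numbers, and booleans are also valid YAML 1.2 scalars.

```python
import dataclasses
import json
import os
from typing import Any, Dict


@dataclasses.dataclass(frozen=True)
class ExperimentConfig:
    # Hypothetical subset of the hyperparameters from the help output above.
    experiment_identifier: str = "disk/fg/default_experiment/fold_0"
    random_seed: int = 94305
    batch_size: int = 32
    learning_rate: float = 1e-4
    noise_model: str = "CONSTANT"


def save_config(config: ExperimentConfig, experiment_dir: str) -> str:
    """Write `config` as flat `key: value` lines and return the file path."""
    os.makedirs(experiment_dir, exist_ok=True)
    path = os.path.join(experiment_dir, "experiment_config.yaml")
    with open(path, "w") as f:
        for key, value in dataclasses.asdict(config).items():
            f.write(f"{key}: {json.dumps(value)}\n")
    return path


def load_config(path: str) -> ExperimentConfig:
    """Read back a file written by save_config (flat scalars only)."""
    fields: Dict[str, Any] = {}
    with open(path) as f:
        for line in f:
            if not line.strip():
                continue
            key, _, value = line.rstrip("\n").partition(": ")
            fields[key] = json.loads(value)
    return ExperimentConfig(**fields)
```

Keeping the format flat means a reloaded config compares equal to the original, which makes it easy to check which hyperparameters produced a given result.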
Evaluation
All evaluation metrics are recorded at train time. The cross_validate.py script can be used to compute metrics across folds:
# Summarize all experiments with means and standard errors of recorded metrics.
python cross_validate.py
# Include statistics for every fold -- this is much more data!
python cross_validate.py --disaggregate
# We can also glob for a partial set of experiments; for example, all of the
# disk experiments.
# Note that the ** wildcard may fail in bash; see above for a fix.
python cross_validate.py --experiment-paths ./experiments/disk/**/
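The summary statistics above are per-metric means and standard errors across folds. A minimal sketch of that computation, assuming the sample (n-1) standard deviation; we have not confirmed which estimator cross_validate.py uses, and the metric names in the usage note are hypothetical:

```python
import math
import statistics
from typing import Dict, List, Tuple


def summarize_folds(metrics_per_fold: List[Dict[str, float]]) -> Dict[str, Tuple[float, float]]:
    """Return {metric: (mean, standard error)} across folds.

    Standard error of the mean = sample standard deviation / sqrt(num_folds).
    """
    if len(metrics_per_fold) < 2:
        raise ValueError("Need at least two folds to estimate a standard error.")
    summary = {}
    for name in metrics_per_fold[0]:
        values = [fold[name] for fold in metrics_per_fold]
        mean = statistics.mean(values)
        stderr = statistics.stdev(values) / math.sqrt(len(values))
        summary[name] = (mean, stderr)
    return summary
```

For instance, three folds with a metric of 1.0, 2.0, and 3.0 give a mean of 2.0 and a standard error of 1/sqrt(3), about 0.577.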
Acknowledgements
We'd like to thank Rika Antonova, Kevin Zakka, Nick Heppert, Angelina Wang, and Philipp Wu for discussions and feedback on both our paper and codebase. Our software design also benefits from ideas from several open-source projects, including Sophus, GTSAM, Ceres Solver, minisam, and SwiftFusion.
This work is partially supported by the Toyota Research Institute (TRI) and Google. This article solely reflects the opinions and conclusions of its authors and not TRI, Google, or any entity associated with TRI or Google.