PyTorch Implementation of DSB for Score Based Generative Modeling. Experiments managed using Hydra.

Overview

Diffusion Schrödinger Bridge with Applications to Score-Based Generative Modeling

This repository contains the implementation for the paper Diffusion Schrödinger Bridge with Applications to Score-Based Generative Modeling.

If using this code, please cite the paper:

    @article{de2021diffusion,
              title={Diffusion Schr$\backslash$" odinger Bridge with Applications to Score-Based Generative Modeling},
              author={De Bortoli, Valentin and Thornton, James and Heng, Jeremy and Doucet, Arnaud},
              journal={arXiv preprint arXiv:2106.01357},
              year={2021}
            }

Contributors

  • Valentin De Bortoli
  • James Thornton
  • Jeremy Heng
  • Arnaud Doucet

What is a Schrödinger bridge?

The Schrödinger Bridge (SB) problem is a classical problem appearing in applied mathematics, optimal control and probability; see [1, 2, 3]. In the discrete-time setting, it takes the following (dynamic) form. Consider as reference density p(x0:N) describing the process adding noise to the data. We aim to find p*(x0:N) such that p*(x0) = pdata(x0) and p*(xN) = pprior(xN) and minimize the Kullback-Leibler divergence between p* and p. In this work we introduce Diffusion Schrodinger Bridge (DSB), a new algorithm which uses score-matching approaches [4] to approximate the Iterative Proportional Fitting algorithm, an iterative method to find the solutions of the SB problem. DSB can be seen as a refinement of existing score-based generative modeling methods [5, 6].

Schrodinger bridge

Installation

This project can be installed from its git repository.

  1. Obtain the sources by:

    git clone https://github.com/anon284/schrodinger_bridge.git

or, if git is unavailable, download as a ZIP from GitHub https://github.com/.

  1. Install:

    conda env create -f conda.yaml

    conda activate bridge

  2. Download data examples:

    • CelebA: python data.py --data celeba --data_dir './data/'
    • MNIST: python data.py --data mnist --data_dir './data/'

How to use this code?

  1. Train Networks:
  • 2d: python main.py dataset=2d model=Basic num_steps=20 num_iter=5000
  • mnist python main.py dataset=stackedmnist num_steps=30 model=UNET num_iter=5000 data_dir=<insert filepath of data dir <local paths/data/>
  • celeba python main.py dataset=celeba num_steps=50 model=UNET num_iter=5000 data_dir=<insert filepath of data dir <local paths/data/>

Checkpoints and sampled images will be saved to a newly created directory. If GPU has insufficient memory, then reduce cache size. 2D dataset should train on CPU. MNIST and CelebA was ran on 2 high-memory V100 GPUs.

References

.. [1] Hans Föllmer Random fields and diffusion processes In: École d'été de Probabilités de Saint-Flour 1985-1987

.. [2] Christian Léonard A survey of the Schrödinger problem and some of its connections with optimal transport In: Discrete & Continuous Dynamical Systems-A 2014

.. [3] Yongxin Chen, Tryphon Georgiou and Michele Pavon Optimal Transport in Systems and Control In: Annual Review of Control, Robotics, and Autonomous Systems 2020

.. [4] Aapo Hyvärinen and Peter Dayan Estimation of non-normalized statistical models by score matching In: Journal of Machine Learning Research 2005

.. [5] Yang Song and Stefano Ermon Generative modeling by estimating gradients of the data distribution In: Advances in Neural Information Processing Systems 2019

.. [6] Jonathan Ho, Ajay Jain and Pieter Abbeel Denoising diffusion probabilistic models In: Advances in Neural Information Processing Systems 2020

Owner
James Thornton
James Thornton
Extreme Lightwegith Portrait Segmentation

Extreme Lightwegith Portrait Segmentation Please go to this link to download code Requirements python 3 pytorch = 0.4.1 torchvision==0.2.1 opencv-pyt

HYOJINPARK 59 Dec 16, 2022
A simple python library for fast image generation of people who do not exist.

Random Face A simple python library for fast image generation of people who do not exist. For more details, please refer to the [paper](https://arxiv.

Sergei Belousov 170 Dec 15, 2022
Self-driving car env with PPO algorithm from stable baseline3

Self-driving car with RL stable baseline3 Most of the project develop from https://github.com/GerardMaggiolino/Gym-Medium-Post Please check it out! Th

Sornsiri.P 7 Dec 22, 2022
Digan - Official PyTorch implementation of Generating Videos with Dynamics-aware Implicit Generative Adversarial Networks

DIGAN (ICLR 2022) Official PyTorch implementation of "Generating Videos with Dyn

Sihyun Yu 147 Dec 31, 2022
A simple baseline for the 2022 IEEE GRSS Data Fusion Contest (DFC2022)

DFC2022 Baseline A simple baseline for the 2022 IEEE GRSS Data Fusion Contest (DFC2022) This repository uses TorchGeo, PyTorch Lightning, and Segmenta

isaac 24 Nov 28, 2022
This is an implementation of Googles Yogi-Optimizer in Keras (tf.keras)

Yogi-Optimizer_Keras This is an implementation of Googles Yogi-Optimizer in Keras (tf.keras) The NeurIPS-Paper can be found here: http://papers.nips.c

14 Sep 13, 2022
TensorFlow CNN for fast style transfer

Fast Style Transfer in TensorFlow Add styles from famous paintings to any photo in a fraction of a second! It takes 100ms on a 2015 Titan X to style t

1 Dec 14, 2021
AoT is a system for automatically generating off-target test harness by using build information.

AoT: Auto off-Target Automatically generating off-target test harness by using build information. Brought to you by the Mobile Security Team at Samsun

Samsung 10 Oct 19, 2022
Benchmarking Pipeline for Prediction of Protein-Protein Interactions

B4PPI Benchmarking Pipeline for the Prediction of Protein-Protein Interactions How this benchmarking pipeline has been built, and how to use it, is de

Loïc Lannelongue 4 Jun 27, 2022
KIDA: Knowledge Inheritance in Data Aggregation

KIDA: Knowledge Inheritance in Data Aggregation This project releases our 1st place solution on NeurIPS2021 ML4CO Dual Task. Slide and model weights a

24 Sep 08, 2022
Intrinsic Image Harmonization

Intrinsic Image Harmonization [Paper] Zonghui Guo, Haiyong Zheng, Yufeng Jiang, Zhaorui Gu, Bing Zheng Here we provide PyTorch implementation and the

VISION @ OUC 44 Dec 21, 2022
This project provides a stock market environment using OpenGym with Deep Q-learning and Policy Gradient.

Stock Trading Market OpenAI Gym Environment with Deep Reinforcement Learning using Keras Overview This project provides a general environment for stoc

Kim, Ki Hyun 769 Dec 25, 2022
PyTorch implementation of paper "StarEnhancer: Learning Real-Time and Style-Aware Image Enhancement" (ICCV 2021 Oral)

StarEnhancer StarEnhancer: Learning Real-Time and Style-Aware Image Enhancement (ICCV 2021 Oral) Abstract: Image enhancement is a subjective process w

IDKiro 133 Dec 28, 2022
ESTDepth: Multi-view Depth Estimation using Epipolar Spatio-Temporal Networks (CVPR 2021)

ESTDepth: Multi-view Depth Estimation using Epipolar Spatio-Temporal Networks (CVPR 2021) Project Page | Video | Paper | Data We present a novel metho

65 Nov 28, 2022
a simple, efficient, and intuitive text editor

Oxygen beta a simple, efficient, and intuitive text editor Overview oxygen is a simple, efficient, and intuitive text editor designed as more featured

Aarush Gupta 1 Feb 23, 2022
A basic implementation of Layer-wise Relevance Propagation (LRP) in PyTorch.

Layer-wise Relevance Propagation (LRP) in PyTorch Basic unsupervised implementation of Layer-wise Relevance Propagation (Bach et al., Montavon et al.)

Kai Fabi 28 Dec 26, 2022
The code for Bi-Mix: Bidirectional Mixing for Domain Adaptive Nighttime Semantic Segmentation

BiMix The code for Bi-Mix: Bidirectional Mixing for Domain Adaptive Nighttime Semantic Segmentation arxiv Framework: visualization results: Requiremen

stanley 18 Sep 18, 2022
LabelImg is a graphical image annotation tool.

LabelImgPlus LabelImg is a graphical image annotation tool. This project is not updated with new functions now. More functions are supported with Labe

lzx1413 200 Dec 20, 2022
Pixel-wise segmentation on VOC2012 dataset using pytorch.

PiWiSe Pixel-wise segmentation on the VOC2012 dataset using pytorch. FCN SegNet PSPNet UNet RefineNet For a more complete implementation of segmentati

Bodo Kaiser 378 Dec 30, 2022
Python inverse kinematics for your robot model based on Pinocchio.

Python inverse kinematics for your robot model based on Pinocchio.

Stéphane Caron 50 Dec 22, 2022