PyTorch reimplementation of Diffusion Models

Overview

PyTorch pretrained Diffusion Models

A PyTorch reimplementation of Denoising Diffusion Probabilistic Models with checkpoints converted from the author's TensorFlow implementation.

Quickstart

Running

pip install -e git+https://github.com/pesser/pytorch_diffusion.git#egg=pytorch_diffusion
pytorch_diffusion_demo

will start a Streamlit demo. It is recommended to run the demo with a GPU available.


Usage

Diffusion models with pretrained weights for cifar10, lsun_bedroom, lsun_cat or lsun_church can be loaded as follows:

from pytorch_diffusion import Diffusion

diffusion = Diffusion.from_pretrained("lsun_church")
samples = diffusion.denoise(4)
diffusion.save(samples, "lsun_church_sample_{:02}.png")

Prefix the name with ema_ to load the exponential moving average (EMA) of the weights, which produces better results. The U-Net model used for denoising is available via diffusion.model and can also be instantiated on its own:

from pytorch_diffusion import Model

model = Model(resolution=32,
              in_channels=3,
              out_ch=3,
              ch=128,
              ch_mult=(1,2,2,2),
              num_res_blocks=2,
              attn_resolutions=(16,),
              dropout=0.1)

This configuration corresponds to the model used for CIFAR-10.
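To see what these parameters imply, the sketch below walks through the resolution levels of the U-Net. It is a minimal illustration, not code from pytorch_diffusion: level_layout is a hypothetical helper, and it assumes the usual DDPM-style layout in which every level except the last ends with a 2x downsampling step and attention blocks are used wherever the current resolution appears in attn_resolutions.

```python
def level_layout(resolution, ch, ch_mult, attn_resolutions):
    """Return (level, resolution, channels, uses_attention) for each U-Net level.

    Assumption: every level except the last is followed by a 2x downsampling
    step, and attention is used at levels whose resolution is in attn_resolutions.
    """
    layout = []
    curr_res = resolution
    for level, mult in enumerate(ch_mult):
        layout.append((level, curr_res, ch * mult, curr_res in attn_resolutions))
        if level != len(ch_mult) - 1:
            curr_res //= 2  # downsample before the next level
    return layout


# CIFAR-10 configuration from the Model(...) call above.
for level, res, channels, attn in level_layout(32, 128, (1, 2, 2, 2), (16,)):
    print(f"level {level}: {res}x{res}, {channels} channels, attention={attn}")
```

Under these assumptions the CIFAR-10 configuration has four levels at 32x32, 16x16, 8x8 and 4x4 with 128, 256, 256 and 256 channels, and attention only at 16x16.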

Producing samples

If you installed directly from GitHub, the cloned repository is located in <venv path>/src/pytorch_diffusion for virtual environments and in <cwd>/src/pytorch_diffusion for global installs. From there, run

python pytorch_diffusion/diffusion.py <name> <bs> <nb>

where <name> is one of cifar10, lsun_bedroom, lsun_cat, lsun_church, or one of these names prefixed with ema_, <bs> is the batch size and <nb> is the number of batches. This will produce samples from the PyTorch models and save them to results/<name>/.
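Each run produces <bs> x <nb> samples, so the 50k samples used in the results below correspond, for example, to a batch size of 500 and 100 batches. A small sketch that assembles the command line and reports the sample count and output directory could look like this; sample_command is a hypothetical helper, not part of the package.

```python
import shlex


def sample_command(name, batch_size, num_batches):
    """Return the diffusion.py argv, the total number of samples, and the output directory."""
    if batch_size < 1 or num_batches < 1:
        raise ValueError("batch size and number of batches must be positive")
    argv = ["python", "pytorch_diffusion/diffusion.py", name,
            str(batch_size), str(num_batches)]
    # Each of the num_batches batches contains batch_size samples.
    return argv, batch_size * num_batches, f"results/{name}/"


argv, total, out_dir = sample_command("ema_cifar10", 500, 100)
print(shlex.join(argv))  # python pytorch_diffusion/diffusion.py ema_cifar10 500 100
print(total, out_dir)    # 50000 results/ema_cifar10/
# To actually launch it from the repository root:
#   import subprocess; subprocess.run(argv, check=True)
```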

Results

Evaluating 50k samples with torch-fidelity gives the following FID scores (lower is better):

Dataset             EMA  Framework   Model           FID
CIFAR10 Train       no   PyTorch     cifar10         12.13775
CIFAR10 Train       no   TensorFlow  tf_cifar10      12.30003
CIFAR10 Train       yes  PyTorch     ema_cifar10     3.21213
CIFAR10 Train       yes  TensorFlow  tf_ema_cifar10  3.245872
CIFAR10 Validation  no   PyTorch     cifar10         14.30163
CIFAR10 Validation  no   TensorFlow  tf_cifar10      14.44705
CIFAR10 Validation  yes  PyTorch     ema_cifar10     5.274105
CIFAR10 Validation  yes  TensorFlow  tf_ema_cifar10  5.325035
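The EMA rows in the table show how much the averaged weights help. As a rough illustration of what "averaged weights" means, here is a minimal sketch of an exponential moving average update, assuming the standard formulation ema <- decay * ema + (1 - decay) * param. ema_update is a hypothetical helper, plain floats stand in for the model's tensors, and the decay of 0.5 is exaggerated so the effect is visible after a few steps; training typically uses a decay very close to 1.

```python
def ema_update(ema_params, params, decay):
    """In-place exponential moving average: ema <- decay * ema + (1 - decay) * param."""
    for name, value in params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value


# Toy example with a single scalar "parameter".
params = {"w": 1.0}
ema = dict(params)               # the EMA copy starts from the initial weights
for step in range(3):
    params["w"] += 1.0           # pretend an optimizer step changed the weight
    ema_update(ema, params, decay=0.5)
print(params["w"], ema["w"])     # 4.0 3.125
```

The EMA copy lags behind the raw weights and smooths out step-to-step fluctuations. With PyTorch tensors the same update is usually applied in place under torch.no_grad(), e.g. ema_p.mul_(decay).add_(p, alpha=1 - decay).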

To reproduce, generate 50k samples from the converted PyTorch models provided in this repo with

python pytorch_diffusion/diffusion.py <name> 500 100

and with

python -c "import convert as m; m.sample_tf(500, 100, which=['cifar10', 'ema_cifar10'])"

for the original TensorFlow models.
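Once both sets of samples are scored, a quick way to compare the converted checkpoints with the originals is the relative FID gap. The sketch below only does arithmetic on the numbers from the results table; the dictionary layout is our own.

```python
# FID scores copied from the results table: (PyTorch, TensorFlow) for each setting.
fid = {
    ("CIFAR10 Train", "no EMA"): (12.13775, 12.30003),
    ("CIFAR10 Train", "EMA"): (3.21213, 3.245872),
    ("CIFAR10 Validation", "no EMA"): (14.30163, 14.44705),
    ("CIFAR10 Validation", "EMA"): (5.274105, 5.325035),
}

# Relative gap in percent; negative means the converted PyTorch model scores lower (better).
gaps = {key: (pt - tf) / tf * 100 for key, (pt, tf) in fid.items()}
for (dataset, ema_setting), gap in gaps.items():
    print(f"{dataset:<19} {ema_setting:<7} {gap:+.2f}%")
```

For the reported numbers every gap is negative and about 1% in magnitude, i.e. the converted models reproduce the original scores closely.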

Running conversions

The converted PyTorch checkpoints are provided for download. If you want to run the conversion yourself, follow the steps below.

Setup

This section assumes your working directory is the root of this repository. Download the pretrained TensorFlow checkpoints; they should follow the original directory structure:

diffusion_models_release/
  diffusion_cifar10_model/
    model.ckpt-790000.data-00000-of-00001
    model.ckpt-790000.index
    model.ckpt-790000.meta
  diffusion_lsun_bedroom_model/
    ...
  ...
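Before converting, it can help to verify that the download is complete. The sketch below checks that the three files of one TensorFlow checkpoint are present; missing_checkpoint_files is a hypothetical helper, and since the tree above only spells out the step for cifar10 (790000), the step is passed in as a parameter rather than guessed for the other models.

```python
from pathlib import Path

# Suffixes of the three files that make up one TensorFlow checkpoint (see the tree above).
CKPT_SUFFIXES = (".data-00000-of-00001", ".index", ".meta")


def missing_checkpoint_files(root, model_dir, step):
    """Return the expected checkpoint files under root/model_dir that are not present."""
    folder = Path(root) / model_dir
    expected = [folder / f"model.ckpt-{step}{suffix}" for suffix in CKPT_SUFFIXES]
    return [path for path in expected if not path.is_file()]


missing = missing_checkpoint_files("diffusion_models_release", "diffusion_cifar10_model", 790000)
if missing:
    print("missing files:")
    for path in missing:
        print("  ", path)
else:
    print("diffusion_cifar10_model: checkpoint complete")
```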

Set the environment variable TFROOT to the directory where you want to store the author's repository, e.g.

export TFROOT=".."

Clone the diffusion repository,

git clone https://github.com/hojonathanho/diffusion.git ${TFROOT}/diffusion

and install its dependencies (pip install -r ${TFROOT}/diffusion/requirements.txt). Then add the following to your PYTHONPATH:

export PYTHONPATH=".:./scripts:${TFROOT}/diffusion:${TFROOT}/diffusion/scripts:${PYTHONPATH}"
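If you run the conversion from a Python session or notebook instead of a shell, the same directories can be prepended to sys.path. This is a sketch under the same assumption as above (the working directory is the repository root); conversion_paths is a hypothetical helper, and TFROOT falls back to ".." as in the example export.

```python
import os
import sys


def conversion_paths(tfroot):
    """Directories the conversion scripts need on the import path (mirrors the PYTHONPATH export)."""
    return [".", "./scripts",
            os.path.join(tfroot, "diffusion"),
            os.path.join(tfroot, "diffusion", "scripts")]


# Insert in reverse so the final order on sys.path matches the export above.
for path in reversed(conversion_paths(os.environ.get("TFROOT", ".."))):
    if path not in sys.path:
        sys.path.insert(0, path)
```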

Testing operations

To test the PyTorch implementations of the required operations against their TensorFlow counterparts with random initialization and random inputs, run

python -c "import convert as m; m.test_ops()"
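The idea behind such a test: run two implementations of the same operation on random inputs and check that the largest absolute difference stays below a tolerance. The sketch below is only an illustration and not the contents of convert.test_ops. As an assumption, it compares two pure-Python formulations of the sinusoidal timestep embedding used in DDPM-style U-Nets; embed_ref, embed_alt and max_abs_diff are hypothetical names, and the tolerance of 1e-6 is an arbitrary choice.

```python
import math
import random


def embed_ref(t, dim):
    """Sinusoidal timestep embedding: [sin(t * f_i) ..., cos(t * f_i) ...]."""
    half = dim // 2
    scale = math.log(10000) / (half - 1)
    freqs = [math.exp(-scale * i) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]


def embed_alt(t, dim):
    """Same embedding, with frequencies written as 10000 ** (-i / (half - 1))."""
    half = dim // 2
    out = [0.0] * (2 * half)
    for i in range(half):
        angle = t * 10000 ** (-i / (half - 1))
        out[i], out[half + i] = math.sin(angle), math.cos(angle)
    return out


def max_abs_diff(a, b):
    """Largest elementwise absolute difference between two equal-length sequences."""
    assert len(a) == len(b)
    return max(abs(x - y) for x, y in zip(a, b))


random.seed(0)
worst = 0.0
for _ in range(100):
    t = random.randint(0, 999)  # random timestep
    worst = max(worst, max_abs_diff(embed_ref(t, 128), embed_alt(t, 128)))
print(f"max abs diff over 100 random timesteps: {worst:.2e}")
assert worst < 1e-6, "implementations disagree"
```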

Converting checkpoints

The following commands load each pretrained TensorFlow model, copy its weights into the corresponding PyTorch model, check that both produce the same outputs on random inputs, and finally save the PyTorch checkpoint:

python -c "import convert as m; m.transplant_cifar10()"
python -c "import convert as m; m.transplant_cifar10(ema=True)"
python -c "import convert as m; m.transplant_lsun_bedroom()"
python -c "import convert as m; m.transplant_lsun_bedroom(ema=True)"
python -c "import convert as m; m.transplant_lsun_cat()"
python -c "import convert as m; m.transplant_lsun_cat(ema=True)"
python -c "import convert as m; m.transplant_lsun_church()"
python -c "import convert as m; m.transplant_lsun_church(ema=True)"
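One thing any such weight transplant has to handle is that TensorFlow and PyTorch store kernels with different axis orders: a TensorFlow conv2d kernel is laid out as [kh, kw, in, out] while torch.nn.Conv2d.weight is [out, in, kh, kw], and a dense kernel is [in, out] while torch.nn.Linear.weight is [out, in]. The sketch below checks those layouts on shapes only; it is a general illustration, not necessarily how convert.py does it, and check_transplant as well as the layer names are hypothetical. With NumPy arrays the actual data would be reordered with numpy.transpose(kernel, perm).

```python
# Axis permutations that map a TensorFlow kernel layout onto the PyTorch one.
# Conv2d: TF [kh, kw, in, out] -> torch [out, in, kh, kw]
# Dense:  TF [in, out]         -> torch [out, in]
TF_TO_TORCH_PERM = {"conv2d": (3, 2, 0, 1), "dense": (1, 0)}


def permute_shape(shape, perm):
    """Shape of an array after reordering its axes with perm."""
    return tuple(shape[axis] for axis in perm)


def check_transplant(tf_shapes, torch_shapes, kinds):
    """Return (name, expected, actual) for every TF kernel that would not fit its PyTorch slot."""
    mismatches = []
    for name, tf_shape in tf_shapes.items():
        expected = permute_shape(tf_shape, TF_TO_TORCH_PERM[kinds[name]])
        if expected != torch_shapes[name]:
            mismatches.append((name, expected, torch_shapes[name]))
    return mismatches


# Toy example: a 3x3 conv from 3 to 128 channels and a 128 -> 512 dense layer.
tf_shapes = {"conv_example": (3, 3, 3, 128), "dense_example": (128, 512)}
torch_shapes = {"conv_example": (128, 3, 3, 3), "dense_example": (512, 128)}
kinds = {"conv_example": "conv2d", "dense_example": "dense"}
print(check_transplant(tf_shapes, torch_shapes, kinds))  # []
```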

The PyTorch checkpoints will be saved in

diffusion_models_converted/
  diffusion_cifar10_model/
    model-790000.ckpt
  ema_diffusion_cifar10_model/
    model-790000.ckpt
  diffusion_lsun_bedroom_model/
    model-2388000.ckpt
  ema_diffusion_lsun_bedroom_model/
    model-2388000.ckpt
  diffusion_lsun_cat_model/
    model-1761000.ckpt
  ema_diffusion_lsun_cat_model/
    model-1761000.ckpt
  diffusion_lsun_church_model/
    model-4432000.ckpt
  ema_diffusion_lsun_church_model/
    model-4432000.ckpt
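Because every model name maps to a fixed directory and training step, the checkpoint path can be derived from the name. The sketch below reads the steps off the tree above and assumes names follow the directory naming (e.g. lsun_bedroom, with an optional ema_ prefix); converted_ckpt_path is a hypothetical helper, not part of the package.

```python
# Training steps of the converted checkpoints, read off the directory tree above.
CKPT_STEPS = {
    "cifar10": 790000,
    "lsun_bedroom": 2388000,
    "lsun_cat": 1761000,
    "lsun_church": 4432000,
}


def converted_ckpt_path(name, root="diffusion_models_converted"):
    """Map a model name such as 'ema_lsun_cat' to its converted checkpoint file."""
    ema = name.startswith("ema_")
    base = name[len("ema_"):] if ema else name
    if base not in CKPT_STEPS:
        raise ValueError(f"unknown model name: {name!r}")
    model_dir = ("ema_" if ema else "") + f"diffusion_{base}_model"
    return f"{root}/{model_dir}/model-{CKPT_STEPS[base]}.ckpt"


print(converted_ckpt_path("ema_lsun_cat"))
# diffusion_models_converted/ema_diffusion_lsun_cat_model/model-1761000.ckpt
```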

Sample TensorFlow models

To produce N samples from each of the pretrained TensorFlow models, run

python -c "import convert as m; m.sample_tf(N)"

To sample only from specific models, pass a list of model names as the keyword argument which, e.g. which=['cifar10', 'ema_cifar10']. Samples will be saved in results/.

Owner
Patrick Esser