Release for Improved Denoising Diffusion Probabilistic Models

Overview

improved-diffusion

This is the codebase for Improved Denoising Diffusion Probabilistic Models.

Usage

This section of the README walks through how to train and sample from a model.

Installation

Clone this repository and navigate to it in your terminal. Then run:

pip install -e .

This should install the improved_diffusion python package that the scripts depend on.

Preparing Data

The training code reads images from a directory of image files. In the datasets folder, we have provided instructions/scripts for preparing these directories for ImageNet, LSUN bedrooms, and CIFAR-10.

For creating your own dataset, simply dump all of your images into a directory with ".jpg", ".jpeg", or ".png" extensions. If you wish to train a class-conditional model, name the files like "mylabel1_XXX.jpg", "mylabel2_YYY.jpg", etc., so that the data loader knows that "mylabel1" and "mylabel2" are the labels. Subdirectories will automatically be enumerated as well, so the images can be organized into a recursive structure (although the directory names will be ignored, and the underscore prefixes are used as names).

The images will automatically be scaled and center-cropped by the data-loading pipeline. Simply pass --data_dir path/to/images to the training script, and it will take care of the rest.

Training

To train your model, you should first decide some hyperparameters. We will split up our hyperparameters into three groups: model architecture, diffusion process, and training flags. Here are some reasonable defaults for a baseline:

MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3"
DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule linear"
TRAIN_FLAGS="--lr 1e-4 --batch_size 128"

Here are some changes we experiment with, and how to set them in the flags:

  • Learned sigmas: add --learn_sigma True to MODEL_FLAGS
  • Cosine schedule: change --noise_schedule linear to --noise_schedule cosine
  • Reweighted VLB: add --use_kl True to DIFFUSION_FLAGS and add --schedule_sampler loss-second-moment to TRAIN_FLAGS.
  • Class-conditional: add --class_cond True to MODEL_FLAGS.

Once you have setup your hyper-parameters, you can run an experiment like so:

python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS

You may also want to train in a distributed manner. In this case, run the same command with mpiexec:

mpiexec -n $NUM_GPUS python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS

When training in a distributed manner, you must manually divide the --batch_size argument by the number of ranks. In lieu of distributed training, you may use --microbatch 16 (or --microbatch 1 in extreme memory-limited cases) to reduce memory usage.

The logs and saved models will be written to a logging directory determined by the OPENAI_LOGDIR environment variable. If it is not set, then a temporary directory will be created in /tmp.

Sampling

The above training script saves checkpoints to .pt files in the logging directory. These checkpoints will have names like ema_0.9999_200000.pt and model200000.pt. You will likely want to sample from the EMA models, since those produce much better samples.

Once you have a path to your model, you can generate a large batch of samples like so:

python scripts/image_sample.py --model_path /path/to/model.pt $MODEL_FLAGS $DIFFUSION_FLAGS

Again, this will save results to a logging directory. Samples are saved as a large npz file, where arr_0 in the file is a large batch of samples.

Just like for training, you can run image_sample.py through MPI to use multiple GPUs and machines.

You can change the number of sampling steps using the --timestep_respacing argument. For example, --timestep_respacing 250 uses 250 steps to sample. Passing --timestep_respacing ddim250 is similar, but uses the uniform stride from the DDIM paper rather than our stride.

To sample using DDIM, pass --use_ddim True.

Owner
OpenAI
OpenAI
Vita Specific Patches and Application for Doki Doki Literature Club (Steam Version) using Ren'Py PSVita

Doki-Doki-Literature-Club-Vita Vita Specific Patches and Application for Doki Doki Literature Club (Steam Version) using Ren'Py PSVita Contains: Modif

Jaylon Gowie 25 Dec 30, 2022
JSEngine is a simple wrapper of Javascript engines.

JSEngine This is a simple wrapper of Javascript engines, it wraps the Javascript interpreter for Python use. There are two ways to call interpreters,

11 Dec 18, 2022
Izy - Python functions and classes that make python even easier than it is

izy Python functions and classes that make it even easier! You will wonder why t

5 Jul 04, 2022
Huggingface package for the discrete VAE used for DALL-E.

DALL-E-Tokenizer Huggingface package for the discrete VAE used for DALL-E.

MyungHoon Jin 5 Sep 01, 2021
Python script for diving image data to train test and val

dataset-division-to-train-val-test-python python script for dividing image data to train test and val If you have an image dataset in the following st

Muhammad Zeeshan 1 Nov 14, 2022
Users can read others' travel journeys in addition to being able to upload and delete posts detailing their own experiences

Users can read others' travel journeys in addition to being able to upload and delete posts detailing their own experiences! Posts are organized by country and destination within that country.

Christopher Zeas 1 Feb 03, 2022
A Python program that generates a maze that solves itself using DFS

Maze Generator And Solver Program Purpose: Generates a maze that then solves itself Language: Python and Pygame Algorithm: Randomized DFS / Floodfill

Joshua Liu 1 Jul 25, 2022
Data on COVID-19 (coronavirus) cases, deaths, hospitalizations, tests • All countries • Updated daily by Our World in Data

COVID-19 Dataset by Our World in Data Find our data on COVID-19 and its documentation in public/data. Documentation Data: complete COVID-19 dataset Da

Our World in Data 5.5k Jan 03, 2023
NUM Alert - A work focus aid created for the Hack the Job hackathon

Contributors: Uladzislau Kaparykha, Amanda Hahn, Nicholas Waller Hackathon Team Name: N.U.M General Purpose: The general purpose of this program is to

Amanda Hahn 1 Jan 10, 2022
flake8 plugin which forbids match statements (PEP 634)

flake8-match flake8 plugin which forbids match statements (PEP 634)

Anthony Sottile 25 Nov 01, 2022
This repository contains each day of Advent of Code 2021 that I've done.

Advent of Code - 2021 I will use this repository as my Advent of Code1 (AoC) repo for the 2021 challenge. I'm changing how I am tackling the problems

Brett Chapin 2 Jan 12, 2022
Convert-Decimal-to-Binary-Octal-and-Hexadecimal

Convert-Decimal-to-Binary-Octal-and-Hexadecimal We have a number in a decimal number, and we have to convert it into a binary, octal, and hexadecimal

Maanyu M 2 Oct 08, 2021
A simple watcher for the XTZ/kUSD pool on Quipuswap

Kolibri Quipuswap Watcher This repo holds the source code for the QuipuBot bot deployed to the #quipuswap-updates channel in the Kolibri Discord Setup

Hover Labs 1 Nov 18, 2021
Cylinder volume calculator features the calculations of the volume of a Right /oblique full cylinder

Cylinder-Volume-Calculator Cylinder volume calculator features the calculations of the volume of a Right /oblique full cylinder. Size : 10.5 mb compat

Abhijeet 4 Nov 07, 2022
HSPyLib is a Python library that will elevate your experience to another level.

HomeSetup Python Library - HSPyLib Your mature python application HSPyLib is a Python library that will elevate your experience to another level. It r

Hugo Saporetti Junior 4 Dec 14, 2022
0xFalcon - 0xFalcon Tool For Python

0xFalcone Installation Install 0xFalcone Tool: apt install git git clone https:/

Alharb7 6 Sep 24, 2022
Ballistic calculator for Airsoft

Ballistic-calculator-for-Airsoft 用于Airsoft的弹道计算器 This is a ballistic calculator for airsoft gun. To calculate your airsoft gun's ballistic, you should

3 Jan 20, 2022
Creates infinite amount of guilded accounts in seconds.

Guilded Cookie Creator [fuck guilded i quit working on this, they patch like every fucking method after 2/3 days i release shit] Optimizations Asynchr

scripted 7 Feb 28, 2022
The Python agent for Apache SkyWalking

SkyWalking Python Agent SkyWalking-Python: The Python Agent for Apache SkyWalking, which provides the native tracing abilities for Python project. Sky

The Apache Software Foundation 149 Dec 12, 2022
Access Modbus RTU via API call to Sungrow WiNet-S

SungrowModbusWebClient Access Modbus RTU via API call to Sungrow WiNet-S Class based on pymodbus.ModbusTcpClient, completely interchangeable, just rep

8 Oct 30, 2022