Official Pytorch and JAX implementation of "Efficient-VDVAE: Less is more"

Overview

The official Pytorch and JAX implementation of the arXiv preprint "Efficient-VDVAE: Less is more".

Louay Hazami   ·   Rayhane Mama   ·   Ragavan Thurairatnam


MIT license

Efficient-VDVAE is a memory- and compute-efficient very deep hierarchical VAE. It converges faster and is more stable than current hierarchical VAE models. It also achieves state-of-the-art likelihood-based performance on several image datasets.

Pre-trained model checkpoints

We provide checkpoints of pre-trained models on MNIST, CIFAR-10, Imagenet 32x32, Imagenet 64x64, CelebA 64x64, CelebAHQ 256x256 (5-bits and 8-bits), FFHQ 256x256 (5-bits and 8-bits), CelebAHQ 1024x1024 and FFHQ 1024x1024 through the links in the table below. All provided models are the ones trained for Table 4 of the paper.

Dataset | Pytorch Logs | Pytorch Checkpoints | JAX Logs | JAX Checkpoints | Negative ELBO
MNIST | link | link | link | link | 79.09 nats
CIFAR-10 | Queued | Queued | link | link | 2.87 bits/dim
Imagenet 32x32 | link | link | link | link | 3.58 bits/dim
Imagenet 64x64 | link | link | link | link | 3.30 bits/dim
CelebA 64x64 | link | link | link | link | 1.83 bits/dim
CelebAHQ 256x256 (5-bits) | link | link | link | link | 0.51 bits/dim
CelebAHQ 256x256 (8-bits) | link | link | link | link | 1.35 bits/dim
FFHQ 256x256 (5-bits) | link | link | link | link | 0.53 bits/dim
FFHQ 256x256 (8-bits) | link | link | link | link | 2.17 bits/dim
CelebAHQ 1024x1024 | link | link | link | link | 1.01 bits/dim
FFHQ 1024x1024 | link | link | link | link | 2.30 bits/dim

Notes:

  • The "Checkpoints" link downloads the minimal files required to resume training or run inference: the model checkpoint file and the saved hyper-parameters of the run (explained further below).
  • The "Logs" link downloads additional logs from the training of the pre-trained models, such as tensorboard files or saved images from training. "Logs" also holds the saved hyper-parameters of the run.
  • Downloaded "Logs" and/or "Checkpoints" should always be unzipped in their implementation folder (efficient_vdvae_torch for Pytorch checkpoints and efficient_vdvae_jax for JAX checkpoints).
  • Some of the model checkpoints are missing in either Pytorch or JAX for the moment. We will update them soon.
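
The unzip step described in the notes above could be scripted roughly as follows. This is a minimal sketch: the helper name and the archive file name are hypothetical, and only the two target folder names come from this README.

```python
import zipfile
from pathlib import Path

# Implementation folders named in this README.
IMPLEMENTATION_DIRS = {"torch": "efficient_vdvae_torch", "jax": "efficient_vdvae_jax"}


def unzip_into_implementation(archive_path, framework, project_root="."):
    """Extract a downloaded "Logs" or "Checkpoints" archive into its implementation folder.

    `archive_path` is wherever the download was saved (hypothetical name);
    `framework` is either "torch" or "jax".
    """
    target = Path(project_root) / IMPLEMENTATION_DIRS[framework]
    target.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall(target)
    return target
```

For example, `unzip_into_implementation("downloaded_checkpoints.zip", "jax")` would extract a (hypothetically named) archive into efficient_vdvae_jax under the current directory.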

Pre-requisites

To run this codebase, you need:

  • A machine running a Linux-based OS (tested on Ubuntu 20.04 (LTS))
  • GPUs (preferably with more than 16GB of memory)
  • Docker
  • Python 3.7 or higher
  • CUDA 11.1 or higher (can be installed from here)

We recommend running all the code below inside a Linux screen or any other terminal multiplexer, since some commands can take hours/days to finish and you don't want them to die when you close your terminal.

Note:

  • If you're planning on running the JAX implementation, the installed JAX must match exactly the CUDA and cuDNN versions installed on the machine. Our default Dockerfile assumes the code will run with CUDA 11.4 or newer and should be changed otherwise. For more details, refer to JAX installation.

Installation

To create the docker image used in both the Pytorch and JAX implementations:

cd build  
docker build -t efficient_vdvae_image .  

Note:

  • If using the JAX library on Ampere architecture GPUs, you may run into a random GPU hanging problem when training on multiple GPUs (issue). In that case, we provide an alternative docker image with an older version of JAX to bypass the issue until a solution is found.

All code executions should be done within a docker container. To start the docker container, we provide a utility script:

sh docker_run.sh  # Starts the container and attaches terminal
cd /workspace/Efficient-VDVAE  # Inside docker container

Setup datasets

All datasets can be automatically downloaded and pre-processed using the convenience script we provide:

cd data_scripts
sh download_and_preprocess.sh <dataset_name>

Notes:

  • <dataset_name> can be one of (imagenet32, imagenet64, celeba, celebahq, ffhq). The MNIST and CIFAR-10 datasets are downloaded automatically later when training the model, and they do not require any dataset setup.
  • For the celeba dataset, a manual download of img_align_celeba.zip and list_eval_partition.txt files is necessary. Both files should be placed under <project_path>/dataset_dumps/.
  • img_align_celeba.zip download link.
  • list_eval_partition.txt download link.
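
Before running the script for celeba, it can help to confirm that the two manually downloaded files are in place. The check below is a small sketch; the helper name is made up, while the file names and the dataset_dumps folder come from the note above.

```python
from pathlib import Path

# File names taken from the note above.
CELEBA_FILES = ("img_align_celeba.zip", "list_eval_partition.txt")


def missing_celeba_files(project_path):
    """Return the CelebA files not yet present under <project_path>/dataset_dumps/."""
    dumps = Path(project_path) / "dataset_dumps"
    return [name for name in CELEBA_FILES if not (dumps / name).is_file()]
```

An empty list means both files were found; otherwise the returned names are the ones still to be downloaded.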

Setting the hyper-parameters

In this repository, we use the hparams library (already included in the Dockerfile) for hyper-parameter management:

  • Specify all run parameters (number of GPUs, model parameters, etc) in one .cfg file
  • Hparams evaluates any expression used as "value" in the .cfg file. "value" can be any basic python object (floats, strings, lists, etc.) or any basic python expression (1/2, max(3, 7), etc.), as long as the evaluation does not require importing a library and does not rely on other values from the .cfg.
  • Hparams saves the configuration of previous runs for reproducibility, resuming training, etc.
  • All hparams are saved by name, and re-using the same name will recall the old run instead of making a new one.
  • The .cfg file is split into sections for readability, and all parameters in the file are accessible as class attributes in the codebase for convenience.
  • The HParams object keeps a global state throughout all the scripts in the code.
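
To make the "values are evaluated expressions" rule more concrete, here is a minimal sketch that mimics the idea with the Python standard library only. It is not the hparams library's implementation or API, and the section and key names in the sample config are invented for illustration.

```python
import configparser

# Hypothetical config text; real parameter names live in hparams.cfg.
SAMPLE_CFG = """
[run]
name = 'my_first_run'
num_gpus = 2

[optimizer]
learning_rate = 1 / 1000
warmup_steps = max(3, 7)
betas = [0.9, 0.999]
"""

# Only a few harmless builtins are exposed: no imports, no access to other keys.
SAFE_BUILTINS = {"max": max, "min": min, "abs": abs, "round": round}


def load_evaluated_cfg(text):
    """Parse sectioned config text and evaluate every value as a Python expression."""
    parser = configparser.ConfigParser()
    parser.read_string(text)
    return {
        section: {
            key: eval(raw, {"__builtins__": SAFE_BUILTINS}, {})
            for key, raw in parser[section].items()
        }
        for section in parser.sections()
    }


config = load_evaluated_cfg(SAMPLE_CFG)
print(config["optimizer"]["warmup_steps"])  # 7
```

Note that string values need quotes because every value is evaluated, and that restricting builtins this way is only an illustration, not a security boundary.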

We highly recommend having a deeper look into how this library works by reading the hparams library documentation, the parameters description and figures 4 and 5 in the paper before trying to run Efficient-VDVAE.

We have heavily tested the robustness and stability of our approach, so changing the model/optimization hyper-parameters to reduce memory load should not introduce instabilities severe enough to make the model untrainable, as long as the changes don't negate the important stability points we describe in the paper.

Training the Efficient-VDVAE

To run Efficient-VDVAE in Torch:

cd efficient_vdvae_torch  
# Set the hyper-parameters in "hparams.cfg" file  
# Set "NUM_GPUS_PER_NODE" in "train.sh" file  
sh train.sh  

To run Efficient-VDVAE in JAX:

cd efficient_vdvae_jax  
# Set the hyper-parameters in "hparams.cfg" file  
python train.py  

If you want to run the model with fewer GPUs than are available on the hardware, for example 2 GPUs out of 8:

CUDA_VISIBLE_DEVICES=0,1 sh train.sh  # For torch  
CUDA_VISIBLE_DEVICES=0,1 python train.py  # For JAX  
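
If training is launched from a Python script rather than a shell, the same restriction can be expressed by building the environment explicitly. This is only a sketch; the helper name is made up, and the commented launch line assumes the train.sh script mentioned above.

```python
import os


def env_with_visible_gpus(gpu_ids, base_env=None):
    """Copy an environment and restrict CUDA to the given device indices."""
    env = dict(os.environ if base_env is None else base_env)
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
    return env


# Environment for a run limited to GPUs 0 and 1.
env = env_with_visible_gpus([0, 1])
print(env["CUDA_VISIBLE_DEVICES"])  # 0,1
# It could then be handed to a launcher, for example:
# subprocess.run(["sh", "train.sh"], env=env, check=True)
```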

Models automatically create checkpoints during training. To resume a model from its last checkpoint, set its <run.name> in the hparams.cfg file and re-run the same training commands.

Training commands save the hparams of the run defined in the .cfg file, so re-using a name in hparams.cfg recalls that saved run. To restart a pre-existing run from scratch instead, we provide a convenience script for resetting saved runs:

cd efficient_vdvae_torch  # or cd efficient_vdvae_jax  
sh reset.sh <run.name>  # <run.name> is the first field in hparams.cfg  

Note:

  • To make things easier for new users, we provide example hparams.cfg files under the egs folder. A detailed description of the role of each parameter is also inside hparams.cfg.
  • Hparams in egs are to be viewed only as guiding examples; they are not meant to match the pre-trained checkpoints or the experiments done in the paper exactly.
  • While the example hparams under the naming convention ..._baseline.cfg are not exactly the hparams of the C2 models in the paper (pre-trained checkpoints), they describe models that are easier to design, achieve the same performance, and can be treated as equivalents to the C2 models.

Monitoring the training process

While writing this codebase, we put extra emphasis on verbosity and logging. Aside from the printed logs on terminal (during training), you can monitor the training progress and keep track of useful metrics using Tensorboard:

# While outside efficient_vdvae_torch or efficient_vdvae_jax  
# Run outside the docker container
tensorboard --logdir . --port <port_id> --reload_multifile True  

In the browser, navigate to localhost:<port_id> to visualize all saved metrics.

If Tensorboard is not installed (outside the docker container):

pip install --upgrade tensorboard

Inference with the Efficient-VDVAE

Efficient-VDVAE supports multiple inference modes:

  • "reconstruction": Encodes then decodes the test set images and computes test NLL and SSIM.
  • "generation": Generates random images from the prior distribution. Randomness is controlled by the run.seed parameter.
  • "div_stats": Pre-computes the average KL divergence stats used to determine turned-off variates (refer to section 7 of the paper). Note: This mode needs to be run before "encoding" mode and before trying to do masked "reconstruction" (Refer to hparams.cfg for a detailed description).
  • "encoding": Extracts the latent distribution from the inference model, pruned to the quantile defined by synthesis.variates_masks_quantile parameter. This latent distribution is usable in downstream tasks.
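
As a rough illustration of the idea behind masking turned-off variates (not the repository's code), assume we hold one average KL value per latent variate and keep the variates whose value is at or above a chosen quantile. The function name, the nearest-rank quantile rule and the direction of the comparison are all assumptions made for this sketch; refer to hparams.cfg and the paper for the actual behaviour.

```python
def variate_mask(avg_kl, quantile):
    """Return a 0/1 mask keeping variates whose average KL reaches the quantile threshold.

    `avg_kl` is a list with one average KL value per variate (assumed layout);
    `quantile` is a float in [0, 1].
    """
    ranked = sorted(avg_kl)
    # Nearest-rank threshold, clamped to the last element for quantile == 1.
    index = min(len(ranked) - 1, int(quantile * len(ranked)))
    threshold = ranked[index]
    return [1 if value >= threshold else 0 for value in avg_kl]


# Two variates carry almost no information and are masked out.
print(variate_mask([0.0, 0.001, 0.5, 1.2], 0.5))  # [0, 0, 1, 1]
```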

To run the inference:

cd efficient_vdvae_torch  # or cd efficient_vdvae_jax  
# Set the inference mode in "logs-<run.name>/hparams-<run.name>.cfg"  
# Set the same <run.name> in "hparams.cfg"  
python synthesize.py  

Notes:

  • Training a model with a name <run.name> saves its configuration under logs-<run.name>/hparams-<run.name>.cfg for reproducibility and error reduction. Any changes you want to make at inference time therefore need to be applied to the saved hparams file (logs-<run.name>/hparams-<run.name>.cfg) instead of the main file hparams.cfg.
  • The torch implementation currently doesn't support multi-GPU inference. The JAX implementation does.
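
Since the saved file is the one to edit, a small helper for locating it may be handy. This sketch only follows the logs-<run.name>/hparams-<run.name>.cfg naming pattern stated above; the helper name and the run name "my_run" are placeholders.

```python
from pathlib import Path


def saved_hparams_path(implementation_dir, run_name):
    """Build the path of the hparams file saved for a given run."""
    return Path(implementation_dir) / f"logs-{run_name}" / f"hparams-{run_name}.cfg"


print(saved_hparams_path("efficient_vdvae_jax", "my_run").as_posix())
# efficient_vdvae_jax/logs-my_run/hparams-my_run.cfg
```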

Potential TODOs

  • Make data loaders Out-Of-Core (OOC) in Pytorch
  • Make data loaders Out-Of-Core (OOC) in JAX
  • Update pre-trained model checkpoints
  • Add Fréchet-Inception Distance (FID) and Inception Score (IS) as measures for sample quality performance.
  • Improve the format of the encoded dataset used in downstream tasks (output of encoding mode, if there is a need)
  • Write a decoding mode API (if needed).

Bibtex

If you happen to use this codebase, please cite our paper:

@article{hazami2022efficient,
  title={Efficient-VDVAE: Less is more},
  author={Hazami, Louay and Mama, Rayhane and Thurairatnam, Ragavan},
  journal={arXiv preprint arXiv:2203.13751},
  year={2022}
}