Very deep VAEs in JAX/Flax

Implementation of the experiments in the paper "Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on Images" using JAX and Flax, ported from the official OpenAI PyTorch implementation.

I have tried to keep this implementation as close as possible to the original. I was able to re-use a large proportion of the code, including the data input pipeline, which still uses PyTorch. I recommend installing a CPU-only version of PyTorch for this.

Tested with JAX 0.2.10, Flax 0.3.0, PyTorch 1.7.1 and NumPy 1.19.2. I also trained to convergence on CIFAR-10 with --conv_precision=highest (see below) and reproduced the paper's test ELBO of 2.87. If anyone would like trained CIFAR-10 checkpoints, I will be happy to upload them.

From the paper, some model samples and a visualization of how it generates them:

[Figure: model samples and a visualization of the generation process, from the paper]

Setup

As well as JAX, Flax, NumPy and PyTorch, this implementation depends on Pillow and scikit-learn:

pip install pillow
pip install scikit-learn

You will also need to download the dataset(s) you want to train on:

./setup_cifar10.sh
./setup_imagenet.sh imagenet32
./setup_imagenet.sh imagenet64
./setup_ffhq256.sh
./setup_ffhq1024.sh  /path/to/images1024x1024  # requires first downloading the `images1024x1024` subfolder from https://github.com/NVlabs/ffhq-dataset yourself and running `pip install torchvision`

Training models

All hyperparameters are defined in hps.py. Select a hyperparameter set with --hps:

python train.py --hps cifar10
python train.py --hps imagenet32
python train.py --hps imagenet64
python train.py --hps ffhq256
python train.py --hps ffhq1024

TODOs

  • Implement support for 5-bit images, which were used in the paper's FFHQ-256 experiments.
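Since 5-bit support is still a TODO, the sketch below only illustrates one common way to reduce 8-bit images to 5 bits (dropping the 3 least significant bits, as in Glow-style preprocessing). It is an assumption, not necessarily how the paper did it, and the helper names `reduce_bits`, `to_unit_range` and `bin_half_width` are hypothetical.

```python
import numpy as np


def reduce_bits(x_uint8, n_bits=5):
    """Quantize 8-bit pixels (0..255) to n_bits levels (0..2**n_bits - 1)
    by discarding the least significant bits."""
    x = np.asarray(x_uint8, dtype=np.uint8)
    return x >> (8 - n_bits)


def to_unit_range(x_q, n_bits=5):
    """Map quantized integers 0..2**n_bits - 1 linearly onto [-1, 1]."""
    n_levels = 2 ** n_bits
    return x_q.astype(np.float32) / (n_levels - 1) * 2.0 - 1.0


def bin_half_width(n_bits=5):
    """Half the spacing between adjacent levels in [-1, 1].

    For 8 bits this is 1/255, the offset used by the 8-bit DMOL loss;
    a 5-bit likelihood would (by assumption) integrate over +/- 1/31 instead.
    """
    return 1.0 / (2 ** n_bits - 1)


x5 = reduce_bits(np.array([0, 7, 8, 255], dtype=np.uint8))
print(x5.tolist())  # [0, 0, 1, 31]
```

Dropping low bits keeps the quantization uniform, so the only change a discretized likelihood needs is the wider bin.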

Known differences from the original

  • We use the Flax default layer initializers instead of the PyTorch defaults.
  • Renamed rate/distortion to kl/loglikelihood.
  • In multihost configurations, checkpoints are saved to disk on all hosts.
  • Slight changes to the DMOL (discretized mixture of logistics) loss.

Things to watch out for

We tried to keep this implementation as close as possible to the author's original PyTorch implementation, and chose to preserve two potentially confusing aspects. First, the --n_batch command-line argument specifies the per-device batch size, so the global batch size is --n_batch multiplied by the total number of devices across all hosts; take this into account when comparing runs on different multi-GPU/TPU or multi-host configurations. Second, some of the default hyperparameter settings in hps.py do not match those used for the paper's experiments, which are listed on page 15 of the paper.
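The batch-size arithmetic can be sketched with JAX's device-count functions. The helper names below are hypothetical, and whether train.py shards host batches exactly this way for `jax.pmap` is an assumption.

```python
import jax
import numpy as np


def global_batch_size(n_batch):
    """Examples per step when --n_batch is the per-device batch size."""
    return n_batch * jax.device_count()  # devices across all hosts


def n_batch_for(global_batch, n_devices):
    """Per-device --n_batch that keeps a given global batch on n_devices."""
    if global_batch % n_devices:
        raise ValueError("global batch must be divisible by the device count")
    return global_batch // n_devices


def shard_host_batch(x, n_batch):
    """Reshape this host's batch to (local_devices, n_batch, ...) for jax.pmap."""
    n_local = jax.local_device_count()
    if x.shape[0] != n_local * n_batch:
        raise ValueError(f"expected {n_local * n_batch} examples, got {x.shape[0]}")
    return x.reshape((n_local, n_batch) + x.shape[1:])


# A run with --n_batch 16 on 8 devices sees 128 examples per step;
# matching it on 32 devices needs --n_batch 4.
print(n_batch_for(16 * 8, 32))  # 4
```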

To reproduce the paper's results on TPU, it may be necessary to set --conv_precision=highest, which makes TPU convolutions use full float32 precision, similar to GPUs. This can slow down training. In my CIFAR-10 experiments, this setting changed the final ELBO by about 1% and was necessary to reproduce the value of 2.87 reported in the paper.
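This sketch assumes a --conv_precision-style string is mapped onto `jax.lax.Precision` and passed to each convolution's `precision` argument (Flax's `nn.Conv` exposes the same argument). The `PRECISIONS` table and `conv2d` helper are hypothetical; the repo's actual plumbing may differ.

```python
import jax
import jax.numpy as jnp

# Hypothetical mapping from flag values to JAX precision settings.
PRECISIONS = {
    "default": jax.lax.Precision.DEFAULT,  # on TPU: fast, reduced-precision (bfloat16) passes
    "high": jax.lax.Precision.HIGH,
    "highest": jax.lax.Precision.HIGHEST,  # full float32 accuracy, slower on TPU
}


def conv2d(x, w, precision="default"):
    """Stride-1 'SAME' convolution on NHWC inputs with HWIO kernels."""
    return jax.lax.conv_general_dilated(
        x, w,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
        precision=PRECISIONS[precision],
    )


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (2, 8, 8, 3))
w = jax.random.normal(key_w, (3, 3, 3, 16))
y_default = conv2d(x, w, "default")
y_highest = conv2d(x, w, "highest")
print(y_highest.shape)  # (2, 8, 8, 16)
# The gap between the two is typically largest on TPU.
print(float(jnp.max(jnp.abs(y_default - y_highest))))
```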

Acknowledgements

This code is very closely based on Rewon Child's implementation; thanks to him for writing it. Thanks also to Julius Kunze for tidying the code and fixing some bugs.

Owner
Jamie Townsend