Very deep VAEs in JAX/Flax

A JAX and Flax implementation of the experiments in the paper "Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on Images", ported from the official OpenAI PyTorch implementation.

I have tried to keep this implementation as close as possible to the original. I was able to reuse a large proportion of the code, including the data input pipeline, which still uses PyTorch. Because PyTorch is only needed for data loading, I recommend installing a CPU-only version of it.

Tested with JAX 0.2.10, Flax 0.3.0, PyTorch 1.7.1 and NumPy 1.19.2. I also trained to convergence on CIFAR-10 and reproduced the paper's test ELBO of 2.87 bits per dimension, using --conv_precision=highest (see below). If anyone would like trained CIFAR-10 checkpoints, I am happy to upload them.

[Figure from the paper: model samples and a visualization of how the model generates them.]

Setup

As well as JAX, Flax, NumPy and PyTorch, this implementation depends on Pillow and scikit-learn:

pip install pillow
pip install scikit-learn

You will also need to download the dataset(s) you want to train on:

./setup_cifar10.sh
./setup_imagenet.sh imagenet32
./setup_imagenet.sh imagenet64
./setup_ffhq256.sh
./setup_ffhq1024.sh  /path/to/images1024x1024  # requires first downloading the `images1024x1024` subfolder from https://github.com/NVlabs/ffhq-dataset yourself and running `pip install torchvision`

Training models

All hyperparameters are defined in hps.py. Select a hyperparameter set with --hps:

python train.py --hps cifar10
python train.py --hps imagenet32
python train.py --hps imagenet64
python train.py --hps ffhq256
python train.py --hps ffhq1024

TODOs

  • Implement support for 5-bit images, which were used in the paper's FFHQ-256 experiments.
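As a rough illustration of what 5-bit support involves, the sketch below reduces 8-bit pixels to 5 bits by dropping the three least significant bits. This leaves 2^5 = 32 intensity levels per channel, which are then rescaled to [-1, 1]. It assumes uint8 input images and is not necessarily the paper's exact preprocessing. Both function names are hypothetical and not part of this repo.

```python
import numpy as np


def to_5bit(x):
    """Quantize 8-bit pixel values (0..255) to 5 bits (0..31).

    Hypothetical helper: drops the 3 least significant bits, leaving
    2**5 = 32 intensity levels per channel.
    """
    return np.asarray(x, dtype=np.uint8) >> 3


def five_bit_to_unit_range(x5):
    """Rescale 5-bit values 0..31 linearly to the range [-1, 1]."""
    return x5.astype(np.float32) / 31.0 * 2.0 - 1.0


pixels = np.array([0, 7, 8, 128, 255], dtype=np.uint8)
q = to_5bit(pixels)
print(q)  # [ 0  0  1 16 31]
print(five_bit_to_unit_range(q))
```

A model trained on such data would also need its likelihood adjusted to 32 discrete bins per channel instead of 256.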

Known differences from the original

  • We use the Flax default layer initializers instead of the PyTorch defaults.
  • Renamed rate/distortion to kl/loglikelihood.
  • In multihost configurations, checkpoints are saved to disk on all hosts.
  • Slight changes to the discretized mixture of logistics (DMOL) loss.
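Under the usual VAE convention, the rate is the KL divergence between the approximate posterior and the prior, and the distortion is the negative log-likelihood of the data. The negative ELBO is therefore kl minus loglikelihood. The minimal sketch below combines the two terms into a per-dimension figure. It assumes both terms are summed over all dimensions of one image and measured in nats. Converting to bits per dimension, the unit of the paper's CIFAR-10 result, is a division by ln 2. The function names and input values are hypothetical and purely illustrative.

```python
import math


def neg_elbo_per_dim(kl, loglikelihood, n_dims):
    """Negative ELBO per dimension in nats: (KL - log p(x|z)) / D.

    Assumes kl and loglikelihood are summed over all D dimensions of one image.
    """
    return (kl - loglikelihood) / n_dims


def nats_to_bits(x_nats):
    """Convert a quantity measured in nats to bits."""
    return x_nats / math.log(2.0)


# A CIFAR-10 image has 32 * 32 * 3 = 3072 dimensions.
n_dims = 32 * 32 * 3
# Illustrative values only, not results from this repo.
nats = neg_elbo_per_dim(kl=1000.0, loglikelihood=-5000.0, n_dims=n_dims)
print(nats)  # 1.953125
print(nats_to_bits(nats))
```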

Things to watch out for

We tried to keep this implementation as close as possible to the author's original PyTorch implementation, and in doing so preserved two potentially confusing behaviors. Firstly, the --n_batch command line argument specifies the per-device batch size. On configurations with multiple GPUs/TPUs and/or multiple hosts, the effective (global) batch size is therefore larger than --n_batch, and this must be taken into account when comparing runs on different configurations. Secondly, some of the default hyperparameter settings in hps.py do not match those used for the paper's experiments, which are listed on page 15 of the paper.
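The sketch below shows how the effective batch size relates to --n_batch. It assumes plain data-parallel training in which every device, across all hosts, processes n_batch examples per step. The helper `global_batch_size` is hypothetical and not part of the repo.

```python
import jax


def global_batch_size(n_batch_per_device, num_devices=None):
    """Return the effective (global) batch size for data-parallel training.

    Hypothetical helper: assumes each device processes n_batch_per_device
    examples per step, as --n_batch does in this repo.
    """
    if num_devices is None:
        # jax.device_count() counts devices on all hosts;
        # jax.local_device_count() would count only this host's devices.
        num_devices = jax.device_count()
    return n_batch_per_device * num_devices


# Example: --n_batch=16 on 2 hosts with 4 devices each.
print(global_batch_size(16, num_devices=2 * 4))  # 128
```

For example, to compare a single-GPU run against this 8-device run at the same global batch size, the single-GPU run would need --n_batch=128 (memory permitting).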

To reproduce results from the paper on TPU, it may be necessary to set --conv_precision=highest. This makes the TPU compute convolutions at GPU-like float32 precision instead of its faster default, which uses reduced-precision (bfloat16) passes. Note that this can slow training down. In my experiments on CIFAR-10, this setting changed the final ELBO by about 1% and was necessary to reproduce the value of 2.87 reported in the paper.
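In JAX, convolution precision is controlled by the `precision` argument of `jax.lax.conv_general_dilated`. Flax's `nn.Conv` exposes a `precision` attribute as well. The minimal sketch below assumes NHWC inputs and HWIO kernels. It also assumes, without having checked this repo's code, that --conv_precision=highest corresponds to `jax.lax.Precision.HIGHEST`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def conv2d(x, w, precision=lax.Precision.HIGHEST):
    """2D convolution: NHWC input, HWIO kernel, stride 1, 'SAME' padding.

    On TPU, Precision.DEFAULT uses faster reduced-precision (bfloat16)
    passes for float32 inputs; Precision.HIGHEST requests full float32.
    """
    return lax.conv_general_dilated(
        x,
        w,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
        precision=precision,
    )


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (2, 8, 8, 3))  # batch of 2 RGB 8x8 images
w = jax.random.normal(key_w, (3, 3, 3, 4))  # 3x3 kernel, 3 -> 4 channels
y_highest = conv2d(x, w)
y_default = conv2d(x, w, precision=lax.Precision.DEFAULT)
# On CPU the two results are typically (near-)identical; on TPU they can differ.
print(jnp.max(jnp.abs(y_highest - y_default)))
```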

Acknowledgements

This code is very closely based on Rewon Child's implementation; thanks to him for writing it. Thanks also to Julius Kunze for tidying the code and fixing some bugs.

Owner
Jamie Townsend