
Data-Efficient GANs with DiffAugment

project | paper | datasets | video | slides

Generated using only 100 images of Obama, grumpy cats, pandas, the Bridge of Sighs, the Medici Fountain, and the Temple of Heaven, without pre-training.

[NEW!] PyTorch training with DiffAugment-stylegan2-pytorch is now available!

[NEW!] Our Colab tutorial is released!

[NEW!] FFHQ training is supported! See the DiffAugment-stylegan2 README.

[NEW!] Time to generate 100-shot interpolation videos with generate_gif.py!

[NEW!] Our DiffAugment-biggan-imagenet repo (for TPU training) is released!

[NEW!] Our DiffAugment-biggan-cifar PyTorch repo is released!

This repository contains our implementation of Differentiable Augmentation (DiffAugment) in both PyTorch and TensorFlow. It can significantly improve the data efficiency of GAN training. We provide DiffAugment-stylegan2 (TensorFlow), DiffAugment-stylegan2-pytorch and DiffAugment-biggan-cifar (PyTorch) for GPU training, and DiffAugment-biggan-imagenet (TensorFlow) for TPU training.

Low-shot generation without pre-training. With DiffAugment, our model can generate high-fidelity images using only 100 Obama portraits, grumpy cats, or pandas from our collected 100-shot datasets, or 160 cats or 389 dogs from the AnimalFace dataset, all at 256×256 resolution.

Unconditional generation results on CIFAR-10. StyleGAN2’s performance drastically degrades given less training data. With DiffAugment, we are able to roughly match its FID and outperform its Inception Score (IS) using only 20% of the training data.

Differentiable Augmentation for Data-Efficient GAN Training
Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
MIT, Tsinghua University, Adobe Research, CMU
arXiv

Overview

Overview of DiffAugment for updating D (left) and G (right). DiffAugment applies the augmentation T to both the real sample x and the generated output G(z). When we update G, gradients need to be back-propagated through T (iii), which requires T to be differentiable w.r.t. the input.

Training and Generation with 100 Images

To generate an interpolation video using our pre-trained models:

cd DiffAugment-stylegan2
python generate_gif.py -r mit-han-lab:DiffAugment-stylegan2-100-shot-obama.pkl -o obama.gif

or to train a new model:

python run_low_shot.py --dataset=100-shot-obama --num-gpus=4

You may also try out 100-shot-grumpy_cat, 100-shot-panda, 100-shot-bridge_of_sighs, 100-shot-medici_fountain, 100-shot-temple_of_heaven, 100-shot-wuzhen, or a folder containing your own training images, as shown below. Please refer to the DiffAugment-stylegan2 README for the dependencies and details.
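For example, a run on a custom dataset might look like this (path/to/your-images is a hypothetical placeholder; passing a folder to --dataset follows from the sentence above):

python run_low_shot.py --dataset=path/to/your-images --num-gpus=4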

[NEW!] PyTorch training is now available:

cd DiffAugment-stylegan2-pytorch
python train.py --outdir=training-runs --data=https://data-efficient-gans.mit.edu/datasets/100-shot-obama.zip --gpus=1

DiffAugment for StyleGAN2

To run StyleGAN2 + DiffAugment for unconditional generation on the 100-shot datasets, CIFAR, FFHQ, or LSUN, please refer to the DiffAugment-stylegan2 README or DiffAugment-stylegan2-pytorch for the PyTorch version.

DiffAugment for BigGAN

Please refer to the DiffAugment-biggan-cifar README to run BigGAN + DiffAugment for conditional generation on CIFAR (using GPUs), and the DiffAugment-biggan-imagenet README to run on ImageNet (using TPUs).

Using DiffAugment for Your Own Training

To help you use DiffAugment in your own codebase, we provide portable DiffAugment operations in both TensorFlow and PyTorch, in DiffAugment_tf.py and DiffAugment_pytorch.py. Generally, DiffAugment can be easily adopted in any model by substituting every D(x) with D(T(x)), where x can be real or fake images, D is the discriminator, and T is the DiffAugment operation. For example,

from DiffAugment_pytorch import DiffAugment
# from DiffAugment_tf import DiffAugment

# If your dataset is as small as ours (e.g., hundreds of images), we recommend
# using the strongest Color + Translation + Cutout combination. For large
# datasets, try a subset of ['color', 'translation', 'cutout']. You are welcome
# to explore more DiffAugment transformations!
policy = 'color,translation,cutout'

...
# Training loop: update D
reals = sample_real_images() # a batch of real images
z = sample_latent_vectors()
fakes = Generator(z) # a batch of fake images
real_scores = Discriminator(DiffAugment(reals, policy=policy))
fake_scores = Discriminator(DiffAugment(fakes, policy=policy))
# Calculating D's loss based on real_scores and fake_scores...
...

...
# Training loop: update G
z = sample_latent_vectors()
fakes = Generator(z) # a batch of fake images
fake_scores = Discriminator(DiffAugment(fakes, policy=policy))
# Calculating G's loss based on fake_scores...
...

We have implemented the Color, Translation, and Cutout DiffAugment transformations; see the paper for a visualization of each.
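To make the differentiability requirement concrete, here is a minimal, self-contained PyTorch sketch in the spirit of DiffAugment_pytorch.py (an illustration, not the repository's actual code): brightness jitter and cutout reduce to plain tensor addition and multiplication, so gradients reach the input x, which is exactly what updating G through T requires.

import torch

def rand_brightness(x):
    # Random per-image brightness shift in [-0.5, 0.5); plain addition,
    # so gradients pass straight through to x.
    return x + (torch.rand(x.size(0), 1, 1, 1, device=x.device) - 0.5)

def rand_cutout(x, ratio=0.5):
    # Zero out one random square patch per image by multiplying with a
    # {0, 1} mask; the mask is constant w.r.t. x, so the product remains
    # differentiable w.r.t. the input.
    n, _, h, w = x.shape
    cut_h, cut_w = int(h * ratio), int(w * ratio)
    mask = torch.ones(n, 1, h, w, device=x.device)
    for i in range(n):
        cy = int(torch.randint(0, h, (1,)))
        cx = int(torch.randint(0, w, (1,)))
        mask[i, :, max(cy - cut_h // 2, 0):cy + cut_h // 2,
                   max(cx - cut_w // 2, 0):cx + cut_w // 2] = 0
    return x * mask

x = torch.randn(4, 3, 32, 32, requires_grad=True)
y = rand_cutout(rand_brightness(x))
y.sum().backward()   # gradients flow back through both augmentations
print(x.grad.shape)  # torch.Size([4, 3, 32, 32])

Translation can be handled in the same spirit, e.g. via padding followed by a random crop, which likewise preserves gradients with respect to the input.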

Citation

If you find this code helpful, please cite our paper:

@inproceedings{zhao2020diffaugment,
  title={Differentiable Augmentation for Data-Efficient GAN Training},
  author={Zhao, Shengyu and Liu, Zhijian and Lin, Ji and Zhu, Jun-Yan and Han, Song},
  booktitle={Conference on Neural Information Processing Systems (NeurIPS)},
  year={2020}
}

Acknowledgements

We thank NSF Career Award #1943349, MIT-IBM Watson AI Lab, Google, Adobe, and Sony for supporting this research. This research was also supported with Cloud TPUs from Google's TensorFlow Research Cloud (TFRC). We thank William S. Peebles and Yijun Li for helpful comments.
