
Overview

Why Spectral Normalization Stabilizes GANs: Analysis and Improvements

[paper (NeurIPS 2021)] [paper (arXiv)] [code]

Authors: Zinan Lin, Vyas Sekar, Giulia Fanti

Abstract: Spectral normalization (SN) is a widely-used technique for improving the stability and sample quality of Generative Adversarial Networks (GANs). However, there is currently limited understanding of why SN is effective. In this work, we show that SN controls two important failure modes of GAN training: exploding and vanishing gradients. Our proofs illustrate a (perhaps unintentional) connection with the successful LeCun initialization. This connection helps to explain why the most popular implementation of SN for GANs requires no hyper-parameter tuning, whereas stricter implementations of SN have poor empirical performance out-of-the-box. Unlike LeCun initialization which only controls gradient vanishing at the beginning of training, SN preserves this property throughout training. Building on this theoretical understanding, we propose a new spectral normalization technique: Bidirectional Scaled Spectral Normalization (BSSN), which incorporates insights from later improvements to LeCun initialization: Xavier initialization and Kaiming initialization. Theoretically, we show that BSSN gives better gradient control than SN. Empirically, we demonstrate that it outperforms SN in sample quality and training stability on several benchmark datasets.
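To make the normalization concrete, here is a minimal NumPy sketch of spectral normalization with power iteration. The weight is reshaped into a 2-D matrix, its largest singular value σ(W) is estimated, and the weight is divided by it. The function name, the (kh, kw, c_in, c_out) kernel layout, and the iteration counts are assumptions for illustration. This is written for Python 3 and is not this repository's TensorFlow code, nor an implementation of BSSN.

```python
import numpy as np


def spectral_normalize(w, u=None, n_iters=1, eps=1e-12):
    """Return (w / sigma, u, sigma), where sigma estimates the largest singular
    value of w reshaped to a 2-D matrix.

    Assumes conv kernels are laid out as (kh, kw, c_in, c_out), so the matrix
    has shape (kh * kw * c_in, c_out). `u` is the persistent right
    singular-vector estimate carried between training steps (None starts fresh).
    """
    w_mat = w.reshape(-1, w.shape[-1])
    if u is None:
        u = np.random.default_rng(0).standard_normal(w_mat.shape[1])
    for _ in range(n_iters):
        # Alternate between left (v) and right (u) singular-vector estimates.
        v = w_mat @ u
        v /= np.linalg.norm(v) + eps
        u = w_mat.T @ v
        u /= np.linalg.norm(u) + eps
    # Estimate of the top singular value from the two singular vectors.
    sigma = v @ w_mat @ u
    return w / sigma, u, sigma


# Usage: with enough iterations the estimate matches the exact SVD value.
rng = np.random.default_rng(1)
kernel = rng.standard_normal((3, 3, 4, 8))
kernel_sn, u, sigma = spectral_normalize(kernel, n_iters=500)
exact = np.linalg.svd(kernel.reshape(-1, 8), compute_uv=False)[0]
print(sigma, exact)
```

During training, `u` would be stored and reused across steps, which is why a single power iteration per step is typically enough to track σ(W) as the weights change slowly.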


This repo contains the code for reproducing the experiments with our BSN and the other SN variants in the paper. The code was tested with Python 2.7.5 and TensorFlow 1.14.0.

Preparing datasets

CIFAR10

Download cifar-10-python.tar.gz from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz (or from other sources).
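For reference, this archive follows the standard CIFAR-10 Python layout: pickled batches named data_batch_1 through data_batch_5 (plus test_batch), each holding a data array of 3072-byte rows (the red, green, and blue planes of a 32x32 image) and a labels list. The sketch below reads the training batches into an (N, 32, 32, 3) array. `load_cifar10_train` is a hypothetical Python 3 helper, not part of this repository.

```python
import pickle
import tarfile

import numpy as np


def load_cifar10_train(path):
    """Read the training batches from cifar-10-python.tar.gz.

    Returns (images, labels) with images as uint8 of shape (N, 32, 32, 3).
    """
    images, labels = [], []
    with tarfile.open(path, "r:gz") as tar:
        members = sorted(
            (m for m in tar.getmembers() if "data_batch_" in m.name),
            key=lambda m: m.name,
        )
        for member in members:
            # The batches were pickled under Python 2, so keys come back as bytes.
            batch = pickle.load(tar.extractfile(member), encoding="bytes")
            # Each 3072-byte row holds the red, green, then blue 32x32 plane.
            data = batch[b"data"].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
            images.append(data)
            labels.extend(batch[b"labels"])
    return np.concatenate(images), np.array(labels)
```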

STL10

Download stl10_binary.tar.gz from http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz (or from other sources), and put it in the dataset_preprocess/STL10 folder. Then run python preprocess.py. This script resizes the images to 48x48x3 and saves them in stl10.npy.
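As an illustration of the resizing step, the sketch below decodes the STL-10 binary image format (uint8, 3x96x96 bytes per image, stored channel-first and column-major, as in the *_X.bin files such as unlabeled_X.bin) and halves each side to 48x48 with 2x2 average pooling. The pooling filter and the helper name are assumptions; the repository's preprocess.py may read different files and resample differently.

```python
import numpy as np


def stl10_to_48(raw_bytes):
    """Decode STL-10 binary image data and downsample 96x96 -> 48x48.

    Returns a uint8 array of shape (N, 48, 48, 3).
    """
    # Each image is 3 * 96 * 96 bytes, channel-first and column-major.
    images = np.frombuffer(raw_bytes, dtype=np.uint8).reshape(-1, 3, 96, 96)
    images = images.transpose(0, 3, 2, 1)  # -> (N, rows, cols, channels)
    n = images.shape[0]
    # 2x2 average pooling (an assumed resampling filter).
    pooled = images.reshape(n, 48, 2, 48, 2, 3).astype(np.float32).mean(axis=(2, 4))
    return np.round(pooled).astype(np.uint8)
```

Because this converts to float32, the 100,000-image unlabeled split is best processed in chunks of bytes (a multiple of 27648 each) to bound memory.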

CelebA

Download img_align_celeba.zip from https://www.kaggle.com/jessicali9530/celeba-dataset (or from other sources), and put it in the dataset_preprocess/CelebA folder. Then run python preprocess.py. This script crops and resizes the images to 64x64x3 and saves them in celeba.npy.
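A rough sketch of a crop-and-resize step for the aligned CelebA images (178 pixels wide, 218 tall): take a centered square crop, then downsample it to 64x64 with nearest-neighbour sampling to stay NumPy-only. The 108-pixel crop size and the helper name are hypothetical; the repository's preprocess.py may crop and resample differently.

```python
import numpy as np


def center_crop_resize(img, crop=108, size=64):
    """Center-crop an (H, W, 3) image to crop x crop, then resize to size x size.

    crop=108 is a hypothetical choice, not necessarily what preprocess.py uses.
    """
    h, w = img.shape[:2]
    top, left = (h - crop) // 2, (w - crop) // 2
    patch = img[top:top + crop, left:left + crop]
    # Source row/column index for each output pixel (nearest neighbour).
    idx = (np.arange(size) * crop) // size
    return patch[idx][:, idx]
```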

ImageNet

Download ILSVRC2012_img_train.tar from http://www.image-net.org/ (or from other sources), and put it in the dataset_preprocess/ImageNet folder. Then run python preprocess.py. This script crops and resizes the images to 128x128x3 and saves them in the ILSVRC2012 folder. Each subfolder of ILSVRC2012 corresponds to one class, and each .npy file in a subfolder corresponds to one image.
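Since the preprocessed data is one .npy file per image inside one subfolder per class, a loader only needs to enumerate that tree. The sketch below builds a sorted (path, label) index, assigning labels by sorted subfolder name. The helper name and that labeling convention are assumptions, not necessarily what the repository's gan_task.py does.

```python
import os


def index_imagenet_npy(root):
    """List (npy_path, label) pairs for a tree of root/<class>/<image>.npy.

    Labels follow the sorted order of class-folder names (an assumption).
    """
    classes = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    samples = []
    for label, cls in enumerate(classes):
        cls_dir = os.path.join(root, cls)
        for fname in sorted(os.listdir(cls_dir)):
            if fname.endswith(".npy"):
                samples.append((os.path.join(cls_dir, fname), label))
    return classes, samples
```

Each indexed file can then be read with numpy.load, yielding one 128x128x3 image.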

Training BSN and SN variants

Prerequisites

The code is built on the GPUTaskScheduler library, which automatically schedules jobs across GPU nodes, so please install it first. You may need to change the GPU configuration to match your devices; it is set in config.py in each directory. See GPUTaskScheduler's GitHub page for details on how to write a proper configuration.

You can also run the code without GPUTaskScheduler: just run python gan.py in the gan subfolders.

CIFAR10, STL10, CelebA

Preparation

Copy the preprocessed datasets from the previous steps into the following paths:

  • CIFAR10: /data/CIFAR10/cifar-10-python.tar.gz.
  • STL10: /data/STL10/cifar-10-stl10.npy.
  • CelebA: /data/CelebA/celeba.npy.

The paths above are relative to the code folder of the variant you want to train:

  • Vanilla SN and our proposed BSSN/SSN/BSN without gammas: no_gamma-CNN.
  • SN with the same gammas: same_gamma-CNN.
  • SN with different gammas: diff_gamma-CNN.

Alternatively, you can directly change the dataset paths in /gan_task.py to the locations of the preprocessed datasets.

Running codes

Now you can run python main.py in each code folder to train the models.

All configurable hyper-parameters are set in config.py, and the values in the file are already set to reproduce the results in the paper. See GPUTaskScheduler's GitHub page for details on the syntax of this file.

ImageNet

Preparation

Copy the preprocessed ILSVRC2012 folder from the previous steps to /data/imagenet/ILSVRC2012, relative to the code folder of the variant:

  • Vanilla SN and our proposed BSSN/SSN/BSN without gammas: no_gamma-ResNet.

Alternatively, you can directly change the dataset path in /gan_task.py to the location of the preprocessed ILSVRC2012 folder.

Running codes

Now you can run python main.py in each code folder to train the models.

All configurable hyper-parameters are set in config.py, and the values in the file are already set to reproduce the results in the paper. See GPUTaskScheduler's GitHub page for details on the syntax of this file.

The code supports multi-GPU training for speed-up by splitting each data batch equally across multiple GPUs. This requires only minor changes to config.py. For example, if you have two GPUs with IDs 0 and 1, (1) change "gpu": ["0"] to "gpu": [["0", "1"]], and (2) change "num_gpus": [1] to "num_gpus": [2]. Note that the number of GPUs may affect the results, because in this implementation the batch normalization layers on different GPUs are independent. In our experiments, we used only one GPU.
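The batch normalization caveat can be illustrated numerically. The NumPy sketch below (a hypothetical helper, not the repository's TensorFlow code) splits a batch equally into one shard per simulated GPU and normalizes each shard with its own mean and variance, as independent per-GPU batch normalization layers would. With two shards the output differs from the single-GPU result, because each shard's statistics differ from those of the full batch.

```python
import numpy as np


def per_shard_batchnorm(x, num_gpus, eps=1e-5):
    """Normalize a (batch, features) array shard by shard.

    Each of the `num_gpus` equal shards uses its own batch statistics,
    mimicking independent batch normalization layers on separate GPUs
    (learned scale and offset are omitted).
    """
    if x.shape[0] % num_gpus != 0:
        raise ValueError("batch size must be divisible by num_gpus")
    shards = np.split(x, num_gpus, axis=0)
    normalized = [(s - s.mean(axis=0)) / np.sqrt(s.var(axis=0) + eps) for s in shards]
    return np.concatenate(normalized, axis=0)


rng = np.random.default_rng(0)
batch = rng.standard_normal((64, 16))
one_gpu = per_shard_batchnorm(batch, num_gpus=1)
two_gpus = per_shard_batchnorm(batch, num_gpus=2)
print(np.abs(one_gpu - two_gpus).max())  # nonzero: shard statistics differ
```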

Results

The code generates the following result files/folders:

  • /results/ /worker.log : Standard output and error from the code.
  • /results/ /metrics.csv : Inception Score and FID during training.
  • /results/ /sample/*.png : Generated images during training.
  • /results/ /checkpoint/* : TensorFlow checkpoints.
  • /results/ /time.txt : Training iteration timestamps.