Finding an Unsupervised Image Segmenter in each of your Deep Generative Models

Description

Recent research has shown that numerous human-interpretable directions exist in the latent space of GANs. In this paper, we develop an automatic procedure for finding directions that lead to foreground-background image separation, and we use these directions to train an image segmentation model without human supervision. Our method is generator-agnostic, producing strong segmentation results with a wide range of different GAN architectures. Furthermore, by leveraging GANs pretrained on large datasets such as ImageNet, we are able to segment images from a range of domains without further training or finetuning. Evaluating our method on image segmentation benchmarks, we compare favorably to prior work while using neither human supervision nor access to the training data. Broadly, our results demonstrate that automatically extracting foreground-background structure from pretrained deep generative models can serve as a remarkably effective substitute for human supervision.

How to run

Dependencies

This code depends on pytorch-pretrained-gans, a repository I developed that exposes a standard interface for a variety of pretrained GANs. Install it with:

pip install git+https://github.com/lukemelas/pytorch-pretrained-gans

The pretrained weights for most GANs are downloaded automatically. For those that are not, I have provided scripts in that repository.

You will also need some standard dependencies. Install them with:

pip install hydra-core==1.1.0dev5 pytorch_lightning albumentations tqdm retry kornia

General Approach

Our unsupervised segmentation approach has two steps: (1) finding a good direction in latent space, and (2) training a segmentation model from data and masks that are generated using this direction.

In detail, this means:

  1. We use optimization/main.py to find a salient direction (or two salient directions) in the latent space of a given pretrained GAN that leads to foreground-background image separation.
  2. We use segmentation/main.py to train a standard segmentation network (a UNet) on generated data. The data can be generated in two ways: (1) you can generate the images on-the-fly during training, or (2) you can generate the images before training the segmentation model using segmentation/generate_and_save.py and then train the segmentation network afterward. The second approach is faster, but requires more disk space (~10GB for 1 million images). We will also provide a pre-generated dataset (coming soon).
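To make the two steps concrete, here is a minimal sketch of how a latent direction can turn generated images into image-mask pairs. It assumes a toy linear "generator" in place of a real GAN, and the mask rule (a pixel is foreground wherever moving the latent along the direction makes it brighter) is an illustrative assumption, not necessarily the exact rule used in optimization/main.py or gan_dataset.py. The names toy_generator, shift_latent and brightness_mask are hypothetical.

```python
import numpy as np

# Toy stand-in for a pretrained GAN generator: a fixed random linear map from a
# latent vector to an 8x8 grayscale "image". This is NOT a real GAN.
rng = np.random.default_rng(0)
LATENT_DIM, H, W = 16, 8, 8
_weights = rng.normal(size=(H * W, LATENT_DIM))


def toy_generator(z):
    """Map latents of shape (N, LATENT_DIM) to images of shape (N, H, W) in [0, 1]."""
    logits = z @ _weights.T
    return (1.0 / (1.0 + np.exp(-logits))).reshape(-1, H, W)


def shift_latent(z, direction, magnitude):
    """Move every latent by `magnitude` along the unit-normalised direction."""
    unit = direction / np.linalg.norm(direction)
    return z + magnitude * unit


def brightness_mask(z, direction, magnitude=3.0, threshold=0.0):
    """Assumed rule: foreground = pixels that get brighter after the shift."""
    original = toy_generator(z)
    shifted = toy_generator(shift_latent(z, direction, magnitude))
    return (shifted - original) > threshold


# Demo: in the real pipeline the direction comes from step (1); here it is random.
z = rng.normal(size=(4, LATENT_DIM))
direction = rng.normal(size=LATENT_DIM)
masks = brightness_mask(z, direction)  # boolean array of shape (4, H, W)
```

Each (toy_generator(z), masks) pair then plays the role of one training example for the segmentation network in step (2).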

Configuration and Logging

We use Hydra for configuration and Weights and Biases for logging. With Hydra, you can specify a config file (found in config/) with --config-name=myconfig.yaml. You can also override the config from the command line by specifying the overriding arguments (without --). For example, you can enable Weights and Biases with wandb=True and you can name the run with name=myname.
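As a rough mental model of what command-line overrides do, the sketch below applies dotted key=value overrides to a nested dict. It is a toy illustration, not Hydra's implementation. Like Hydra, it refuses to override keys that are missing from the config (Hydra requires a + prefix to add new keys). Overrides containing a slash, such as data_gen/generator=bigbigan, select a file from a config group rather than setting a value, so the sketch only records them. The function names are hypothetical.

```python
import copy


def parse_value(text):
    """Turn a command-line string into a bool, int, or float where possible."""
    if text in ("True", "true"):
        return True
    if text in ("False", "false"):
        return False
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def apply_overrides(config, overrides):
    """Apply dotted key=value overrides to a copy of a nested dict.

    Returns (new_config, group_choices). Keys containing '/' are treated as
    config-group selections and collected instead of applied.
    """
    result = copy.deepcopy(config)
    group_choices = {}
    for item in overrides:
        key, _, raw = item.partition("=")
        if "/" in key:
            group_choices[key] = raw
            continue
        *parents, leaf = key.split(".")
        node = result
        for part in parents:
            node = node[part]  # KeyError if an intermediate key is missing
        if leaf not in node:
            raise KeyError(f"cannot override missing key {key!r}")
        node[leaf] = parse_value(raw)
    return result, group_choices


base = {"name": "default", "wandb": False,
        "data_gen": {"kwargs": {"generation_batch_size": 64}}}
cfg, groups = apply_overrides(base, ["wandb=True", "name=myname",
                                     "data_gen/generator=bigbigan",
                                     "data_gen.kwargs.generation_batch_size=128"])
```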

The structure of the configs is as follows:

config
├── data_gen
│   ├── generated.yaml  # <- for generating data with 1 latent direction
│   ├── generated-dual.yaml   # <- for generating data with 2 latent directions
│   ├── generator  # <- different types of GANs for generating data
│   │   ├── bigbigan.yaml
│   │   ├── pretrainedbiggan.yaml
│   │   ├── selfconditionedgan.yaml
│   │   ├── studiogan.yaml
│   │   └── stylegan2.yaml 
│   └── saved.yaml  # <- for using pre-generated data
├── optimize.yaml  # <- for optimization
└── segment.yaml   # <- for segmentation

Code Structure

The code is structured as follows:

src
├── models  # <- segmentation model
│   ├── __init__.py
│   ├── latent_shift_model.py  # <- shifts direction in latent space
│   ├── unet_model.py  # <- segmentation model
│   └── unet_parts.py
├── config  # <- configuration, explained above
│   ├── ... 
├── datasets  # <- classes for loading datasets during segmentation/generation
│   ├── __init__.py
│   ├── gan_dataset.py  # <- for generating dataset
│   ├── saved_gan_dataset.py  # <- for pre-generated dataset
│   └── real_dataset.py  # <- for evaluation datasets (i.e. real images)
├── optimization
│   ├── main.py  # <- main script
│   └── utils.py  # <- helper functions
└── segmentation
    ├── generate_and_save.py  # <- for generating a dataset and saving it to disk
    ├── main.py  # <- main script, uses PyTorch Lightning 
    ├── metrics.py  # <- for mIoU/F-score calculations
    └── utils.py  # <- helper functions
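metrics.py handles the mIoU and F-score calculations. The sketch below shows one common way to compute IoU and the F-beta score for a single binary mask; it is an assumption, not a copy of metrics.py. The default β² = 0.3 is the value conventionally used in salient object detection benchmarks, and whether mIoU averages over images or over the foreground/background classes is not stated here.

```python
import numpy as np


def iou(pred, target):
    """Intersection over union of two binary masks (defined as 1.0 if both are empty)."""
    pred, target = np.asarray(pred, dtype=bool), np.asarray(target, dtype=bool)
    union = np.logical_or(pred, target).sum()
    if union == 0:
        return 1.0
    return float(np.logical_and(pred, target).sum() / union)


def f_beta(pred, target, beta_sq=0.3, eps=1e-8):
    """Weighted harmonic mean of precision and recall for binary masks."""
    pred, target = np.asarray(pred, dtype=bool), np.asarray(target, dtype=bool)
    tp = np.logical_and(pred, target).sum()
    precision = tp / (pred.sum() + eps)
    recall = tp / (target.sum() + eps)
    return float((1 + beta_sq) * precision * recall
                 / (beta_sq * precision + recall + eps))
```

Because β² < 1 weights precision more heavily than recall, a mask that spills onto the background is penalised more than one that misses part of the object.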

Datasets

The datasets should have the following structure. You can easily add your own datasets or use only a subset of these datasets by modifying config/segment.yaml. You should specify your directory by modifying root in that file on line 19, or by passing data_seg.root=MY_DIR on the command line whenever you call python segmentation/main.py.

├── DUT_OMRON
│   ├── DUT-OMRON-image
│   │   └── ...
│   └── pixelwiseGT-new-PNG
│       └── ...
├── DUTS
│   ├── DUTS-TE
│   │   ├── DUTS-TE-Image
│   │   │   └── ...
│   │   └── DUTS-TE-Mask
│   │       └── ...
│   └── DUTS-TR
│       ├── DUTS-TR-Image
│       │   └── ...
│       └── DUTS-TR-Mask
│           └── ...
├── ECSSD
│   ├── ground_truth_mask
│   │   └── ...
│   └── images
│       └── ...
├── CUB_200_2011
│   ├── train_images
│   │   └── ...
│   ├── train_segmentations
│   │   └── ...
│   ├── test_images
│   │   └── ...
│   └── test_segmentations
│       └── ...
└── Flowers
    ├── train_images
    │   └── ...
    ├── train_segmentations
    │   └── ...
    ├── test_images
    │   └── ...
    └── test_segmentations
        └── ...
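When adding your own dataset, each image needs a matching mask. A simple convention, assumed here rather than taken from real_dataset.py, is to pair files by their stem, since images and masks often use different extensions (for example .jpg images and .png masks). The function name is hypothetical, and it only scans one flat directory, so datasets whose files sit in per-class subfolders would need a recursive variant.

```python
from pathlib import Path

IMAGE_EXTS = {".jpg", ".jpeg", ".png"}


def pair_images_and_masks(images_dir, masks_dir):
    """Match each image to the mask file with the same stem (extensions may differ)."""
    masks = {p.stem: p for p in Path(masks_dir).iterdir()
             if p.suffix.lower() in IMAGE_EXTS}
    pairs = []
    for image in sorted(Path(images_dir).iterdir()):
        if image.suffix.lower() not in IMAGE_EXTS:
            continue  # skip non-image files such as README or .DS_Store
        mask = masks.get(image.stem)
        if mask is None:
            raise FileNotFoundError(f"no mask for {image.name} in {masks_dir}")
        pairs.append((image, mask))
    return pairs

# Example: pair_images_and_masks(root / "ECSSD" / "images", root / "ECSSD" / "ground_truth_mask")
```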

The datasets can be downloaded from:

Training

Before training, make sure you understand the general approach (explained above).

Note: All commands are called from within the src directory.

In the example commands below, we use BigBiGAN. You can easily switch out BigBiGAN for another model if you would like to.

Optimization

PYTHONPATH=. python optimization/main.py data_gen/generator=bigbigan name=NAME

This should take less than 5 minutes to run. The output will be saved in outputs/optimization/fixed-BigBiGAN-NAME/DATE/, with the final checkpoint in latest.pth.
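The next commands need the path to this checkpoint, and because Hydra changes the working directory, an absolute path is the safest choice. A small helper like the following (hypothetical, standard library only) picks the most recently written latest.pth under a run directory, whatever the DATE subfolder layout turns out to be:

```python
from pathlib import Path


def find_latest_checkpoint(run_dir):
    """Return the absolute path of the newest latest.pth anywhere under run_dir."""
    candidates = list(Path(run_dir).rglob("latest.pth"))
    if not candidates:
        raise FileNotFoundError(f"no latest.pth found under {run_dir}")
    newest = max(candidates, key=lambda p: p.stat().st_mtime)
    return newest.resolve()  # absolute, so it survives Hydra's directory change

# Example (run from src/): find_latest_checkpoint("outputs/optimization/fixed-BigBiGAN-NAME")
```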

Segmentation with precomputed generations

The recommended way of training is to generate the data first and train afterward. An example generation script would be:

PYTHONPATH=. python segmentation/generate_and_save.py \
name=NAME \
data_gen=generated \
data_gen/generator=bigbigan \
data_gen.checkpoint="YOUR_OPTIMIZATION_DIR_FROM_ABOVE/latest.pth" \
data_gen.save_dir="YOUR_OUTPUT_DIR" \
data_gen.save_size=1000000 \
data_gen.kwargs.batch_size=1 \
data_gen.kwargs.generation_batch_size=128

This will generate 1 million image-label pairs and save them to YOUR_OUTPUT_DIR/images. Note that YOUR_OUTPUT_DIR should be an absolute path, not a relative one, because Hydra changes the working directory. You may also want to tune the generation_batch_size to maximize GPU utilization on your machine. It takes around 3-4 hours to generate 1 million images on a single V100 GPU.
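For planning a run, the figures above translate into a few back-of-the-envelope numbers. The sketch assumes that generation_batch_size is the number of images produced per generator forward pass (the name suggests this, but it is an assumption) and uses decimal units for disk space.

```python
import math


def generation_budget(num_images, generation_batch_size, hours, disk_gb):
    """Rough cost figures for pre-generating a dataset."""
    return {
        # generator forward passes, rounding the final partial batch up
        "generator_calls": math.ceil(num_images / generation_batch_size),
        "images_per_second": num_images / (hours * 3600),
        # decimal units: 1 GB = 1e6 kB
        "kb_per_pair": disk_gb * 1e6 / num_images,
    }


# 1 million images, generation_batch_size=128, 3 to 4 hours, ~10GB on disk
budgets = {h: generation_budget(1_000_000, 128, h, 10) for h in (3, 4)}
```

With the numbers from this section, that works out to 7,813 generator calls, roughly 69 to 93 images per second, and about 10 kB per image-label pair.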

Once you have generated data, you can train a segmentation model:

PYTHONPATH=. python segmentation/main.py \
name=NAME \
data_gen=saved \
data_gen.data.root="YOUR_OUTPUT_DIR_FROM_ABOVE"

It takes around 3 hours on 1 GPU to complete 18000 iterations, by which point the model has converged (in fact, you can probably get away with fewer steps; I would guess around 5000).

Segmentation with on-the-fly generations

Alternatively, you can generate data while training the segmentation model. An example script would be:

PYTHONPATH=. python segmentation/main.py \
name=NAME \
data_gen=generated \
data_gen/generator=bigbigan \
data_gen.checkpoint="YOUR_OPTIMIZATION_DIR_FROM_ABOVE/latest.pth" \
data_gen.kwargs.generation_batch_size=128
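Conceptually, on-the-fly generation replaces the dataset on disk with a stream: the GAN produces a chunk of samples, and the training loop consumes them until the chunk runs out. The sketch below shows that pattern with a plain Python generator. sample_fn and toy_sampler are hypothetical stand-ins for the GAN plus the latent-direction masking, and reading generation_batch_size as the chunk size is an assumption.

```python
import numpy as np


def on_the_fly_pairs(sample_fn, generation_batch_size, num_chunks):
    """Yield (image, mask) pairs one by one, generating them chunk by chunk."""
    for _ in range(num_chunks):
        images, masks = sample_fn(generation_batch_size)
        for image, mask in zip(images, masks):
            yield image, mask


def toy_sampler(n, size=8, seed=None):
    """Stand-in for a GAN: random grayscale images with thresholded masks."""
    rng = np.random.default_rng(seed)
    images = rng.random((n, size, size))
    return images, images > 0.5


pairs = list(on_the_fly_pairs(toy_sampler, generation_batch_size=4, num_chunks=3))
```

Nothing is written to disk, but every training step now pays for a generator forward pass, which is consistent with the pre-generated route above being the faster option.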

Evaluation

To evaluate, set the train argument to False. For example:

PYTHONPATH=. python segmentation/main.py \
name="eval" \
train=False \
eval_checkpoint=${checkpoint} \
data_seg.root=${DATASETS_DIR}

Pretrained models

  • Coming soon!

Available GANs

It should be possible to use any GAN from pytorch-pretrained-gans, including:

Citation

@inproceedings{melaskyriazi2021finding,
  author    = {Melas-Kyriazi, Luke and Rupprecht, Christian and Laina, Iro and Vedaldi, Andrea},
  title     = {Finding an Unsupervised Image Segmenter in each of your Deep Generative Models},
  booktitle = arxiv,
  year      = {2021}
}

Owner
Luke Melas-Kyriazi
I'm a student at Harvard University studying mathematics and computer science, always open to collaborating on interesting projects!