[CVPR 2022] Unsupervised Image-to-Image Translation with Generative Prior

Overview

GP-UNIT - Official PyTorch Implementation

This repository provides the official PyTorch implementation for the following paper:

Unsupervised Image-to-Image Translation with Generative Prior
Shuai Yang, Liming Jiang, Ziwei Liu and Chen Change Loy
In CVPR 2022.
Project Page | Paper | Supplementary Video

Abstract: Unsupervised image-to-image translation aims to learn the translation between two visual domains without paired data. Despite the recent progress in image translation models, it remains challenging to build mappings between complex domains with drastic visual discrepancies. In this work, we present a novel framework, Generative Prior-guided UNsupervised Image-to-image Translation (GP-UNIT), to improve the overall quality and applicability of the translation algorithm. Our key insight is to leverage the generative prior from pre-trained class-conditional GANs (e.g., BigGAN) to learn rich content correspondences across various domains. We propose a novel coarse-to-fine scheme: we first distill the generative prior to capture a robust coarse-level content representation that can link objects at an abstract semantic level, based on which fine-level content features are adaptively learned for more accurate multi-level content correspondences. Extensive experiments demonstrate the superiority of our versatile framework over state-of-the-art methods in robust, high-quality and diversified translations, even for challenging and distant domains.

Updates

  • [03/2022] Paper and supplementary video are released.
  • [04/2022] Code and dataset are released.
  • [03/2022] This website is created.

Installation

Clone this repo:

git clone https://github.com/williamyang1991/GP-UNIT.git
cd GP-UNIT

Dependencies:

We have tested on:

  • CUDA 10.1
  • PyTorch 1.7.0
  • Pillow 8.0.1; Matplotlib 3.3.3; opencv-python 4.4.0; Faiss 1.7.0; tqdm 4.54.0

All dependencies for defining the environment are provided in environment/gpunit_env.yaml. We recommend running this repository using Anaconda:

conda env create -f ./environment/gpunit_env.yaml

We use CUDA 10.1, so the environment file installs PyTorch 1.7.0 (see Line 16, Line 113, Line 120 and Line 121 of gpunit_env.yaml). Please install the PyTorch version that matches your own CUDA version, following https://pytorch.org/.
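After creating the environment, it can be useful to compare what is actually installed against the tested versions listed above. The snippet below is a minimal sketch using only the standard library (Python 3.8 or newer). The version numbers are copied from this README; the distribution names used as keys are assumptions and may differ in your environment. Faiss is left out because its distribution name depends on how it was installed.

```python
from importlib import metadata

# Versions this README reports as tested. The keys are ASSUMED distribution
# names; adjust them if your environment names the packages differently.
TESTED = {
    "torch": "1.7.0",
    "Pillow": "8.0.1",
    "matplotlib": "3.3.3",
    "opencv-python": "4.4.0",
    "tqdm": "4.54.0",
}


def compare_versions(tested=TESTED):
    """Return {package: (tested_version, installed_version_or_None)}."""
    report = {}
    for name, wanted in tested.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # not installed under this distribution name
        report[name] = (wanted, installed)
    return report


if __name__ == "__main__":
    for name, (wanted, installed) in compare_versions().items():
        status = "missing" if installed is None else installed
        print(f"{name}: tested {wanted}, installed {status}")
```

A mismatch is not necessarily a problem, in particular for PyTorch, which should follow your CUDA version rather than the tested one.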


(1) Dataset Preparation

The human face, animal face and artistic human face datasets can be downloaded from their official pages. The bird, dog and car datasets can be built from ImageNet with our provided script.

Task | Used Dataset
Male←→Female | CelebA-HQ: divided into male and female subsets by StarGANv2
Dog←→Cat←→Wild | AFHQ provided by StarGANv2
Face←→Cat or Dog | CelebA-HQ and AFHQ
Bird←→Dog | 4 classes of birds and 4 classes of dogs in ImageNet291. Please refer to dataset preparation for building ImageNet291 from ImageNet
Bird←→Car | 4 classes of birds and 4 classes of cars in ImageNet291. Please refer to dataset preparation for building ImageNet291 from ImageNet
Face→MetFace | CelebA-HQ and MetFaces

(2) Inference for Latent-Guided and Exemplar-Guided Translation

Inference Notebook


To help users get started, we provide a Jupyter notebook at ./notebooks/inference_playground.ipynb that allows one to visualize the performance of GP-UNIT. The notebook will download the necessary pretrained models and run inference on the images in ./data/.

Web Demo

Try the web demo on Replicate.

Pretrained Models

Pretrained models can be downloaded from Google Drive or Baidu Cloud (access code: cvpr):

Task | Pretrained Models
Prior Distillation | content encoder
Male←→Female | generators for male2female and female2male
Dog←→Cat←→Wild | generators for dog2cat, cat2dog, dog2wild, wild2dog, cat2wild and wild2cat
Face←→Cat or Dog | generators for face2cat, cat2face, dog2face and face2dog
Bird←→Dog | generators for bird2dog and dog2bird
Bird←→Car | generators for bird2car and car2bird
Face→MetFace | generator for face2metface

The saved checkpoints are under the following folder structure:

checkpoint
|--content_encoder.pt     % Content encoder
|--bird2car.pt            % Bird-to-Car translation model
|--bird2dog.pt            % Bird-to-Dog translation model
...
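Before running inference, a quick check that the expected files are in place can save a failed run. The helper below is a small sketch, not part of the repository: `missing_checkpoints` is a hypothetical name, and it only assumes the `TASK.pt` naming pattern shown in the folder structure above.

```python
from pathlib import Path


def missing_checkpoints(tasks, checkpoint_dir="./checkpoint"):
    """Return the expected checkpoint file names that are not present.

    Assumes the layout shown above: content_encoder.pt plus one
    <task>.pt file per translation task (e.g. bird2car.pt).
    """
    root = Path(checkpoint_dir)
    expected = ["content_encoder.pt"] + [f"{task}.pt" for task in tasks]
    return [name for name in expected if not (root / name).is_file()]


if __name__ == "__main__":
    # An empty list means every listed checkpoint was found.
    print(missing_checkpoints(["dog2cat", "bird2car", "bird2dog"]))
```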

Latent-Guided Translation

Translate a content image to the target domain with randomly sampled latent styles:

python inference.py --generator_path PRETRAINED_GENERATOR_PATH --content_encoder_path PRETRAINED_ENCODER_PATH \
                    --content CONTENT_IMAGE_PATH --batch STYLE_NUMBER --device DEVICE

By default, the script uses ./checkpoint/dog2cat.pt as PRETRAINED_GENERATOR_PATH, ./checkpoint/content_encoder.pt as PRETRAINED_ENCODER_PATH, and cuda as DEVICE to run on the GPU. To run on the CPU, use --device cpu.
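When calling the script from Python, for example to process many content images, the command can be assembled programmatically. The following is a sketch under stated assumptions: `pick_device` and `build_latent_command` are hypothetical helpers, not part of this repository, and only the flags documented above are used. Leaving an argument as `None` keeps the script's own default.

```python
import sys


def pick_device():
    """Return "cuda" if PyTorch reports a usable GPU, otherwise "cpu"."""
    try:
        import torch
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


def build_latent_command(content, batch=None, generator_path=None,
                         content_encoder_path=None, device=None):
    """Build the argument list for latent-guided inference.py.

    Arguments left as None are omitted, so inference.py falls back to
    its defaults described above.
    """
    cmd = [sys.executable, "inference.py", "--content", content]
    optional = {
        "--batch": batch,
        "--generator_path": generator_path,
        "--content_encoder_path": content_encoder_path,
        "--device": device,
    }
    for flag, value in optional.items():
        if value is not None:
            cmd += [flag, str(value)]
    return cmd


if __name__ == "__main__":
    cmd = build_latent_command("path/to/content.jpg", batch=6,
                               device=pick_device())
    print(cmd)  # pass to subprocess.run(cmd, check=True) from the repo root
```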

Take Dog→Cat as an example, run:

python inference.py --content ./data/afhq/images512x512/test/dog/flickr_dog_000572.jpg --batch 6

Six results, translation_flickr_dog_000572_N.jpg (N=0~5), are saved in the folder ./output/. A corresponding overview image, translation_flickr_dog_000572_overview.jpg, is also saved to illustrate the input content image and the six results.

Evaluation Metrics: We use the code of StarGANv2 to calculate FID and Diversity with LPIPS in our paper.

Exemplar-Guided Translation

Translate a content image to the target domain in the style of a style image by additionally specifying --style:

python inference.py --generator_path PRETRAINED_GENERATOR_PATH --content_encoder_path PRETRAINED_ENCODER_PATH \
                    --content CONTENT_IMAGE_PATH --style STYLE_IMAGE_PATH --device DEVICE

Take Dog→Cat as an example, run:

python inference.py --content ./data/afhq/images512x512/test/dog/flickr_dog_000572.jpg --style ./data/afhq/images512x512/test/cat/flickr_cat_000418.jpg

The result, translation_flickr_dog_000572_to_flickr_cat_000418.jpg, is saved in the folder ./output/. A corresponding overview image, translation_flickr_dog_000572_to_flickr_cat_000418_overview.jpg, is also saved to illustrate the input content image, the style image, and the result.
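The output file names for both modes follow from the examples in this section, which is handy when post-processing results in a script. The helper below is a sketch that reproduces those naming patterns; `expected_outputs` is a hypothetical name, and the patterns are inferred from the two examples in this README rather than read from inference.py.

```python
from pathlib import Path


def expected_outputs(content, style=None, batch=1, output_dir="./output"):
    """List the files inference.py is expected to write, per the examples above."""
    out = Path(output_dir)
    c = Path(content).stem
    if style is None:
        # Latent-guided: one file per sampled style, numbered from 0,
        # plus a single overview image.
        names = [f"translation_{c}_{n}.jpg" for n in range(batch)]
        names.append(f"translation_{c}_overview.jpg")
    else:
        # Exemplar-guided: one result plus its overview image.
        s = Path(style).stem
        names = [f"translation_{c}_to_{s}.jpg",
                 f"translation_{c}_to_{s}_overview.jpg"]
    return [out / name for name in names]


if __name__ == "__main__":
    dog = "./data/afhq/images512x512/test/dog/flickr_dog_000572.jpg"
    cat = "./data/afhq/images512x512/test/cat/flickr_cat_000418.jpg"
    for path in expected_outputs(dog, batch=6) + expected_outputs(dog, style=cat):
        print(path)
```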

For another example, Cat→Wild, run:

python inference.py --generator_path ./checkpoint/cat2wild.pt --content ./data/afhq/images512x512/test/cat/flickr_cat_000418.jpg --style ./data/afhq/images512x512/test/wild/flickr_wild_001112.jpg

The overview image is as follows:


(3) Training GP-UNIT

Download the supporting models to the ./checkpoint/ folder:

Model | Description
content_encoder.pt | Our pretrained content encoder which distills BigGAN prior from the synImageNet291 dataset.
model_ir_se50.pth | Pretrained IR-SE50 model taken from TreB1eN for ID loss.

Train Image-to-Image Translation Network

python train.py --task TASK --batch BATCH_SIZE --iter ITERATIONS \
                --source_paths SPATH1 SPATH2 ... SPATHS --source_num SNUM1 SNUM2 ... SNUMS \
                --target_paths TPATH1 TPATH2 ... TPATHT --target_num TNUM1 TNUM2 ... TNUMT

where SPATH1~SPATHS are paths to S folders containing images from the source domain (e.g., S classes of ImageNet birds), and SNUMi is the number of images in SPATHi used for training. TPATHi and TNUMi are defined similarly for the target domain. By default, BATCH_SIZE=16 and ITERATIONS=75000. If --source_num/--target_num is not specified, all images in the folders are used.
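Because each SNUMi/TNUMi must pair with one folder, a pre-flight check of the folder contents can catch mismatches before a long training run. The sketch below is not part of the repository: `count_images` and `check_nums` are hypothetical helpers, the set of image extensions is an assumption, and how train.py itself reacts to a count larger than the folder size is not documented here.

```python
from pathlib import Path

# ASSUMED set of image extensions; adjust to match your data.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def count_images(folder):
    """Count image files directly inside `folder` (non-recursive)."""
    return sum(
        1 for p in Path(folder).iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )


def check_nums(paths, nums=None):
    """Return the per-folder image counts that would be used for training.

    With nums=None, every image in each folder is counted, mirroring the
    behaviour described above when --source_num/--target_num is omitted.
    """
    available = [count_images(p) for p in paths]
    if nums is None:
        return available
    if len(nums) != len(paths):
        raise ValueError("need exactly one count per folder")
    for path, want, have in zip(paths, nums, available):
        if want > have:
            raise ValueError(f"{path}: requested {want} images, found {have}")
    return list(nums)
```

For example, `check_nums(source_paths, [4000])` raises a `ValueError` if the single source folder holds fewer than 4000 images.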

The trained model is saved as ./checkpoint/TASK-ITERATIONS.pt. Intermediate results are saved in ./log/TASK/.

These default settings do not necessarily give the best results; training can be further customized with additional command-line options:

  • --style_layer (default: 4): the discriminator layer used to compute the feature matching loss. We found that setting style_layer=5 gives better performance on human faces.
  • --use_allskip (default: False): whether to use dynamic skip connections to compute the reconstruction loss. For tasks involving close domains like gender translation, season transfer and face stylization, using use_allskip gives better results.
  • --use_idloss (default: False): whether to use the identity loss. For Cat/Dog→Face and Face→MetFace tasks, we use this loss.
  • --not_flip_style (default: False): whether to disable random flipping of the style image when extracting the style feature. Random flipping prevents the network from learning position information from the style image.
  • --mitigate_style_bias (default: False): whether to resample style features when training the sampling network. For imbalanced datasets with minority groups, mitigate_style_bias oversamples the style features that are far from the mean style feature of the whole dataset. This leads to more diversified latent-guided translation at the cost of slight image quality degradation. We use it on CelebA-HQ and AFHQ-related tasks.
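To make the idea behind --mitigate_style_bias more concrete, the following plain-Python sketch shows one way to oversample features that lie far from the dataset mean. It is an illustration only and is not the repository's implementation: the weighting rule (probability proportional to Euclidean distance from the mean) and the function names are assumptions chosen for simplicity. It requires Python 3.8 or newer for `math.dist`.

```python
import math
import random


def distance_weights(features):
    """Sampling weights proportional to each feature's distance from the mean.

    ASSUMPTION: a simple linear weighting, used here only for illustration.
    """
    dim = len(features[0])
    mean = [sum(f[i] for f in features) / len(features) for i in range(dim)]
    dists = [math.dist(f, mean) for f in features]
    total = sum(dists)
    if total == 0:
        # All features are identical: fall back to uniform sampling.
        return [1.0 / len(features)] * len(features)
    return [d / total for d in dists]


def resample(features, k, seed=0):
    """Draw k features with replacement, favouring those far from the mean."""
    rng = random.Random(seed)
    return rng.choices(features, weights=distance_weights(features), k=k)


if __name__ == "__main__":
    # Three features near the origin (the majority) and one outlier.
    feats = [[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [5.0, 5.0]]
    print(distance_weights(feats))
    print(resample(feats, k=8))
```

With weights like these, rare styles are drawn more often than under uniform sampling, which matches the described trade-off of more diversity at some cost in fidelity to the typical style.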

Here are some examples:
(Some of our tasks require the ImageNet291 dataset. Please refer to data preparation.)

Male→Female

python train.py --task male2female --source_paths ./data/celeba_hq/train/male --target_paths ./data/celeba_hq/train/female --style_layer 5 --mitigate_style_bias --use_allskip --not_flip_style

Cat→Dog

python train.py --task cat2dog --source_paths ./data/afhq/images512x512/train/cat --source_num 4000 --target_paths ./data/afhq/images512x512/train/dog --target_num 4000 --mitigate_style_bias

Cat→Face

python train.py --task cat2face --source_paths ./data/afhq/images512x512/train/cat --source_num 4000 --target_paths ./data/ImageNet291/train/1001_face/ --style_layer 5 --mitigate_style_bias --not_flip_style --use_idloss

Bird→Car (translating 4 classes of birds to 4 classes of cars)

python train.py --task bird2car --source_paths ./data/ImageNet291/train/10_bird/ ./data/ImageNet291/train/11_bird/ ./data/ImageNet291/train/12_bird/ ./data/ImageNet291/train/13_bird/ --source_num 600 600 600 600 --target_paths ./data/ImageNet291/train/436_vehicle/ ./data/ImageNet291/train/511_vehicle/ ./data/ImageNet291/train/627_vehicle/ ./data/ImageNet291/train/656_vehicle/ --target_num 600 600 600 600
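Multi-class commands like the one above are long and easy to mistype, so it may be convenient to generate them. The sketch below rebuilds the Bird→Car command from its class folder numbers; `build_train_command` is a hypothetical helper, not part of the repository, and it uses only the flags shown in this README. It requires Python 3.8 or newer for `shlex.join`.

```python
import shlex


def build_train_command(task, source_paths, target_paths,
                        source_num=None, target_num=None, extra_flags=()):
    """Assemble a train.py command line as a list of arguments."""
    cmd = ["python", "train.py", "--task", task,
           "--source_paths", *source_paths]
    if source_num is not None:
        cmd += ["--source_num", *map(str, source_num)]
    cmd += ["--target_paths", *target_paths]
    if target_num is not None:
        cmd += ["--target_num", *map(str, target_num)]
    cmd += list(extra_flags)  # e.g. ["--mitigate_style_bias"]
    return cmd


# Class folders copied from the Bird→Car example above.
root = "./data/ImageNet291/train"
birds = [f"{root}/{i}_bird/" for i in (10, 11, 12, 13)]
cars = [f"{root}/{i}_vehicle/" for i in (436, 511, 627, 656)]

cmd = build_train_command("bird2car", birds, cars, [600] * 4, [600] * 4)
print(shlex.join(cmd))  # shell-ready string; or pass cmd to subprocess.run
```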

Train Content Encoder of Prior Distillation

We provide our pretrained model content_encoder.pt at Google Drive or Baidu Cloud (access code: cvpr). This model is obtained by:

python prior_distillation.py --unpaired_data_root ./data/ImageNet291/train/ --paired_data_root ./data/synImageNet291/train/ --unpaired_mask_root ./data/ImageNet291_mask/train/ --paired_mask_root ./data/synImageNet291_mask/train/

The training requires ImageNet291 and synImageNet291 datasets. Please refer to data preparation.


Results

Male-to-Female: close domains

male2female

Cat-to-Dog: related domains

cat2dog

Dog-to-Human and Bird-to-Dog: distant domains

dog2human

bird2dog

Bird-to-Car: extremely distant domains for stress testing

bird2car

Citation

If you find this work useful for your research, please consider citing our paper:

@inproceedings{yang2022Unsupervised,
  title={Unsupervised Image-to-Image Translation with Generative Prior},
  author={Yang, Shuai and Jiang, Liming and Liu, Ziwei and Loy, Chen Change},
  booktitle={CVPR},
  year={2022}
}

Acknowledgments

The code is developed based on StarGAN v2, SPADE and Imaginaire.
