Few-shot Image Generation via Cross-domain Correspondence
Utkarsh Ojha, Yijun Li, Jingwan Lu, Alexei A. Efros, Yong Jae Lee, Eli Shechtman, Richard Zhang
Adobe Research, UC Davis, UC Berkeley
PyTorch implementation of adapting a source GAN (trained on a large dataset) to a target domain using very few images.
Project page | Paper
Overview
Our method adapts a source GAN to a target domain while preserving a one-to-one correspondence between the source images Gs(z) and the target images Gt(z).
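As a quick illustration of this correspondence property, the sketch below feeds the same latent codes to a source and an adapted generator, assuming the rosinality-style StyleGAN2 Generator in model.py (see Requirements below); the checkpoint file names are placeholders, not names this repository guarantees.
```python
# Minimal sketch of the cross-domain correspondence property.
# Assumes the rosinality-style StyleGAN2 Generator in model.py;
# the checkpoint file names below are placeholders.
import torch
from model import Generator

device = 'cuda'
g_source = Generator(256, 512, 8).to(device)  # image size, style dim, MLP depth
g_target = Generator(256, 512, 8).to(device)
g_source.load_state_dict(torch.load('source.pt')['g_ema'])
g_target.load_state_dict(torch.load('target.pt')['g_ema'])

# The same latent z yields corresponding images in the two domains:
# Gs(z) and Gt(z) share structure, pose, etc.
z = torch.randn(4, 512, device=device)
with torch.no_grad():
    img_source, _ = g_source([z])  # Gs(z)
    img_target, _ = g_target([z])  # Gt(z)
```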
Requirements
Note The base model is taken from the StyleGAN2 implementation by @rosinality.
- Linux
- NVIDIA GPU + CUDA 10.2 + cuDNN
- PyTorch 1.7.0
- Python 3.6.9
- Install the remaining libraries through
pip install -r requirements.txt
Testing
We currently provide the sets of images used to produce the quantitative results in Tables 1 and 2.
Evaluating FID
Three sets of images are used to obtain the results in Table 1:
- A set of real images from the target domain -- Rtest
- The 10 images from Rtest used to train the algorithm -- Rtrain
- 5000 images generated by the GAN-based method -- F
The following table provides a link to each of these images:
| | Rtrain | Rtest | F |
|---|---|---|---|
| Babies | link | link | link |
| Sunglasses | link | link | link |
| Sketches | link | link | link |
Rtrain is provided only to illustrate what the algorithm sees during training; it is not used for computing the FID score.
Download and unzip the image sets into your desired directory, then compute the FID score (using pytorch-fid) between the real (Rtest) and fake (F) images by running the following command:
python -m pytorch_fid /path/to/real/images /path/to/fake/images
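If you prefer to compute the score from Python rather than the CLI, the sketch below uses pytorch-fid's calculate_fid_given_paths; its exact signature has varied across pytorch-fid releases, so treat this as a sketch rather than a guaranteed interface.
```python
# Sketch of computing FID programmatically with pytorch-fid.
# The signature below matches recent pytorch-fid releases but has
# changed between versions -- verify against your installed version.
import torch
from pytorch_fid.fid_score import calculate_fid_given_paths

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fid = calculate_fid_given_paths(
    ['/path/to/real/images', '/path/to/fake/images'],  # [Rtest, F]
    batch_size=50,
    device=device,
    dims=2048,  # standard FID setting: InceptionV3 pool3 features
)
print(f'FID: {fid:.2f}')
```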
Evaluating intra-cluster distance
Download the full set of images used for the results in Table 2 from here (1.1 GB). The collection is organized as follows:
cluster_centers
└── amedeo                  # target domain -- one of [amedeo, sketches]
    └── ours                # method -- one of [tgan, tgan_ada, freezeD, ewc, ours]
        └── c0              # cluster id -- there are 10 clusters [c0, c1, ..., c9]
            ├── center.png  # cluster center -- one of the 10 training images; each cluster has its own center
            ├── img0.png    # generated images matched to this cluster's center, according to LPIPS distance
            ├── img1.png
            └── ...
Unzip the file, then run the following command to compute the results for a given baseline on a dataset:
CUDA_VISIBLE_DEVICES=0 python3 feat_cluster.py --baseline <baseline> --dataset <target_domain> --mode intra_cluster_dist
For example:
CUDA_VISIBLE_DEVICES=0 python3 feat_cluster.py --baseline tgan --dataset sketches --mode intra_cluster_dist
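Conceptually, the intra-cluster distance is the average pairwise LPIPS distance between the generated members of each cluster. The sketch below computes it for a single cluster using @richzhang's lpips package; the cluster path is an example, and the VGG backbone is an assumption, not necessarily what feat_cluster.py uses.
```python
# Sketch of the intra-cluster LPIPS distance for one cluster.
# Illustrative only -- not the repository's feat_cluster.py code;
# the backbone ('vgg') and the cluster path are assumptions.
import os
import torch
import lpips

loss_fn = lpips.LPIPS(net='vgg')

def image_tensor(path):
    # lpips.load_image returns an RGB array in [0, 255];
    # lpips.im2tensor rescales it to a [-1, 1] torch tensor
    return lpips.im2tensor(lpips.load_image(path))

cluster_dir = 'cluster_centers/sketches/tgan/c0'  # example cluster
members = sorted(f for f in os.listdir(cluster_dir) if f.startswith('img'))

# average pairwise LPIPS distance between the cluster's members
dists = []
with torch.no_grad():
    for i in range(len(members)):
        for j in range(i + 1, len(members)):
            x = image_tensor(os.path.join(cluster_dir, members[i]))
            y = image_tensor(os.path.join(cluster_dir, members[j]))
            dists.append(loss_fn(x, y).item())
print(f'intra-cluster LPIPS distance: {sum(dists) / len(dists):.4f}')
```
The visualize_members mode described next can be thought of as the same kind of computation against center.png, ranking members by their LPIPS distance to the cluster center.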
We also provide a utility to visualize the closest and farthest members of a cluster, as shown in Figure 14 (below), using the following command:
CUDA_VISIBLE_DEVICES=0 python3 feat_cluster.py --baseline tgan --dataset sketches --mode visualize_members
The command saves the generated images that are closest to and farthest from the cluster center as closest.png and farthest.png, respectively.
Note We cannot share the images for the caricature domain due to licensing issues.
More results coming soon.
Bibtex
@inproceedings{ojha2021few-shot-gan,
title={Few-shot Image Generation via Cross-domain Correspondence},
author={Ojha, Utkarsh and Li, Yijun and Lu, Jingwan and Efros, Alexei A. and Lee, Yong Jae and Shechtman, Eli and Zhang, Richard},
booktitle={CVPR},
year={2021}
}
Acknowledgment
As mentioned above, the StyleGAN2 model is borrowed from this wonderful PyTorch implementation by @rosinality. We are also thankful to @mseitzer and @richzhang for their user-friendly implementations of the FID score and LPIPS metric, respectively.