PuppetGAN - Cross-Domain Feature Disentanglement and Manipulation just got way better! 🚀

Overview

Better Cross-Domain Feature Disentanglement and Manipulation with Improved PuppetGAN

Quite cool... Right?

Introduction

This repo contains a TensorFlow implementation of PuppetGAN as well as an improved version of it, capable of manipulating features up to 100% better and up to 3x faster! 😎

PuppetGAN is a model that extends the CycleGAN idea and is capable of extracting and manipulating features from a domain using examples from a different domain. On top of that, one amazing aspect of PuppetGAN is that it does not require a great amount of data; the biggest dataset I used contained 5000 sets of examples while the smallest one contained just slightly over 1000 sets of examples!

The Model(s)

Overview

PuppetGAN consists of 4 different components: one that is responsible for learning to reconstruct the input images, one that is responsible for learning to disentangle the Attribute of Interest, a CycleGAN component and an Attribute CycleGAN. The Attribute CycleGAN acts in a similar manner to CycleGAN with the exception that it deals with cross-domain inputs.



The full architecture of the baseline PuppetGAN (the image is copied from the original paper)

With this repo I add a few more components, which I call Roids, that greatly improve the performance of the Baseline PuppetGAN. One Roid is applied in the disentanglement part and the rest in the attribute cycle part, while the objective of all of them is pretty much the same: to guarantee better disentanglement!

  • The original architecture performs the disentanglement only in the synthetic domain and this ability is passed to the real domain implicitly. The disentanglement Roid takes advantage of the CycleGAN model and performs the disentanglement in the translations of the synthetic images, passing the ability explicitly to the real domain.

  • The attribute cycle Roids act in a similar way, but they instead force the attributes, other than the Attribute of Interest, of the cross-domain translations to be as precise as possible. This can be seen as a stricter version of the disentanglement Roid as well.



The Disentanglement Roid



The Attribute Cycle Roids

Implementation

The only difference between my Baseline and the model from the paper is that my generators and discriminators are modified versions of the ones used in TensorFlow's CycleGAN tutorial. The fact that the creators of PuppetGAN used ResNet blocks may be partially responsible for the memorization effect that seems to be present in some of the results of the paper since the skip connections can allow information to be passed unchanged between different layers.
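To make the point about skip connections concrete, here is a minimal sketch in plain Python rather than the repo's TensorFlow code. The names `residual_block` and `zero_transform` are hypothetical and only illustrate the general idea of a residual connection: the block outputs its input plus a learned transformation, so whenever the learned branch contributes nothing, the input passes through unchanged.

```python
from typing import Callable, List

Vector = List[float]


def residual_block(x: Vector, transform: Callable[[Vector], Vector]) -> Vector:
    """Return x + transform(x), the defining computation of a residual block."""
    fx = transform(x)
    return [xi + fi for xi, fi in zip(x, fx)]


def zero_transform(x: Vector) -> Vector:
    """Stand-in for a learned branch whose output has collapsed to zero."""
    return [0.0 for _ in x]


pixels = [0.2, 0.7, 0.5]
# With a zero branch the block is the identity: the input is copied to the output.
assert residual_block(pixels, zero_transform) == pixels
```

This is only an illustration of why an identity path makes copying information easy; it does not show that memorization actually happens for this reason.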

Other than that, all my implementations use exactly the same parameters as the ones in the original model. Also, neither my architectures nor the parameters have been modified at all between different datasets.

Performance

Both my Baseline implementation and my proposed architecture(s) significantly outperform the original PuppetGAN!

Rotation of MNIST digits

By the Numbers

Just like in the original paper, all the reported scores are for the MNIST dataset. Due to the fact that I didn't have access to the size dataset, I was able to measure the performance of my models only on the rotation dataset.

PuppetGAN                            Accuracy   Manipulation score   Std. dev. score   Epoch
Original (paper)                     0.97       0.40                 0.01              -
My Baseline                          0.96       0.59                 0.01              300
Roids in Attribute Cycle Component   0.97       0.82                 0.02              100
Roids in Disentanglement Component   0.91       0.73                 0.01              250
Roids in Both Components             0.97       0.79                 0.01              300
  • Accuracy (The closer to 1 the better)

The accuracy measures, using a LeNet-5 network, how well the original class is preserved. In other words, this metric is indicative of how well the model manages to disentangle without affecting the rest of the attributes. As we'll see later it is possible though to get very high accuracy while having suboptimal disentanglement performance...

  • Manipulation score (The closer to 1 the better)

This score is the correlation coefficient between the Attribute of Interest of the known images and that of the generated images, and it captures how well the model manipulates the Attribute of Interest.

  • Standard-deviation score (The closer to 0 the better)

This score captures how similar the results are between images that have an identical Attribute of Interest but differ in the rest of the attributes. For this metric I report the standard deviation instead of the variance that is mentioned in the paper, due to the fact that the variance of my models was orders of magnitude smaller than the one reported in the paper. This makes me believe that the standard deviation was used in the paper as well.
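The three scores described above can be sketched with plain Python as follows. This is a hedged illustration and not the code of eval_rotation.py: the function names are hypothetical, and how the class predictions and attribute measurements are obtained from the images (for example with the LeNet-5 network) is outside the scope of the sketch.

```python
import math
from typing import Sequence


def accuracy(predicted: Sequence[int], expected: Sequence[int]) -> float:
    """Fraction of generated images whose predicted class matches the original class."""
    matches = sum(p == e for p, e in zip(predicted, expected))
    return matches / len(expected)


def pearson(xs: Sequence[float], ys: Sequence[float]) -> float:
    """Pearson correlation coefficient between two equally long sequences."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)


def std_dev(values: Sequence[float]) -> float:
    """Population standard deviation, i.e. the square root of the variance."""
    mean = sum(values) / len(values)
    return math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))


# Toy numbers, made up for illustration only.
known_angles = [-30.0, 0.0, 30.0]
generated_angles = [-28.0, 1.0, 31.0]
print(accuracy([3, 7, 7], [3, 7, 1]))
print(pearson(known_angles, generated_angles))
print(std_dev([0.50, 0.52, 0.49]))
```

Because the standard deviation is the square root of the variance, a variance that looks orders of magnitude too small next to a published number can plausibly be explained by the published number being a standard deviation.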

Discussion about the Results

Mouth manipulation after 440 epochs, using the Baseline.

Mouth manipulation after 190 epochs with Roids in the Attribute Cycle component. The model learns to both open and close the mouth more accurately, disentangle in a better way and produce clearer images, and all that way faster!

The best-balanced model seems to be the one that uses both kinds of Roids, since it achieves the same accuracy and standard-deviation score as the original model while increasing the manipulation score by more than 30% compared to my Baseline implementation and almost 100% compared to the original paper. Nevertheless, although it is intuitive that a combination of all the Roids would yield better results, I believe that more experiments are required to determine if its benefits are sufficient to outweigh the great speed-up of the model that uses Roids only in the Attribute Cycle component.
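The percentages quoted above follow from the reported scores, reading the second score column (the correlation-based manipulation score) of the results table. The short check below is only a worked example of that arithmetic; `relative_gain` is a hypothetical helper name.

```python
def relative_gain(new: float, old: float) -> float:
    """Relative improvement of `new` over `old`, as a percentage."""
    return (new - old) / old * 100.0


both_roids = 0.79  # manipulation score with Roids in both components
baseline = 0.59    # manipulation score of my Baseline
paper = 0.40       # manipulation score reported in the original paper

print(f"vs Baseline: {relative_gain(both_roids, baseline):.1f}%")
print(f"vs paper:    {relative_gain(both_roids, paper):.1f}%")
```

The gains come out at roughly 34% over the Baseline and roughly 97.5% over the paper, which matches "more than 30%" and "almost 100%".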

MNIST rotation after adding Roids on the Attribute Cycle component

For now, I would personally favor the model that uses only the Roids of the Attribute Cycle component due to the fact that it manages to outperform every other model in the AoI manipulation score in 1/3 of the time, while having an insignificant difference in the standard-deviation score. As an extra trick, I found that not updating the discriminator in the Attribute Cycle Roids could improve the performance slightly, but that's just an additional hack.

Each Roid implicitly affects the weight of its respective loss due to the fact that extra terms are added to it. In order to ensure that the performance boost is not caused by the increased loss weight, I am providing a comparison between the performance of the model with the Roids in the Attribute Cycle component and the Baseline model with twice the weights of the Attribute Cycle Component.
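The reasoning behind this control experiment can be illustrated with a toy calculation. This is a sketch under stated assumptions: it assumes each component contributes `weight * sum(terms)` to the total objective, and every number and name below is made up for illustration rather than taken from the repo.

```python
from typing import Sequence


def weighted_loss(weight: float, terms: Sequence[float]) -> float:
    """Contribution of one component to the objective: weight * sum of its terms."""
    return weight * sum(terms)


weight = 1.0               # hypothetical attribute cycle weight
base_terms = [0.30, 0.20]  # hypothetical values of the original loss terms
roid_terms = [0.25, 0.15]  # hypothetical values of the extra terms added by the Roids

with_roids = weighted_loss(weight, base_terms + roid_terms)
doubled_baseline = weighted_loss(2 * weight, base_terms)

print(with_roids, doubled_baseline)
```

Both variants enlarge the component's share of the total loss, but only the Roid variant introduces new terms; doubling the weight merely rescales the existing ones. Comparing the two separates the effect of the larger magnitude from the effect of the extra supervision.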

PuppetGAN                            Accuracy   Manipulation score   Std. dev. score   Epoch
Original (paper)                     0.97       0.40                 0.01              -
My Baseline                          0.96       0.59                 0.01              300
Weighted Baseline                    0.84       0.85                 0.01              100
Weighted Baseline                    0.93       0.72                 0.01              150
Weighted Baseline                    0.92       0.68                 0.01              200
Weighted Baseline                    0.95       0.63                 0.01              300
Roids in Attribute Cycle Component   0.97       0.82                 0.02              100

The above results show that increasing the weights of the Attribute Cycle losses can slightly increase the performance of PuppetGAN, but such a model would be comparable to the Baseline and not to the model that utilizes the Roids.

Comparison to the original results

A significant drawback of the original model is that it seems to memorize seen images instead of editing the given ones. This can be observed in the rotation results reported in the paper, where the representation of a real digit may change during the rotation or different representations of a real digit may have the same rotated representations. This doesn't stop it though from having a very high accuracy, which highlights why this metric is not necessarily ideal for calculating the quality of the disentanglement.

The rotation results of the paper

Another issue with both the model of the paper and my models can be observed in the mouth dataset, where PuppetGAN confuses the microphone with the opening of the mouth; when the synthetic image dictates a wider opening, PuppetGAN moves the microphone closer to the mouth. This effect is slightly bigger in my Baseline but I believe that it is due to the fact that I haven't done any hyper-parameter tuning; some experimentation with the magnitude of the noise or with the weights of the different components could eliminate it. Also, the model with Roids in the Attribute of Interest seems to deal with this issue better than the Baseline.

Running the Code

You can manage all the dependencies with Pipenv using the provided Pipfile. This allows for easier reproducibility of the code due to the fact that Pipenv creates a virtual environment containing all the necessary libraries. Just run pipenv shell in the base directory of the project and you're ready to go!

On the other hand, if for any reason you don't want to use Pipenv you can install all the required libraries using the provided requirements.txt file. Neither this file nor Pipenv takes care of CUDA though; in all my experiments I used CUDA 7.5.18.

In order to download the datasets, you can use the fetch_data.sh script which downloads and extracts them in the correct directory, running:

. fetch_data.sh

Unfortunately, I am not allowed to publish any dataset other than MNIST, but feel free to ask the authors of the original PuppetGAN for them, following the instructions on their website 🙂 .

Training a Model

To start a new training, simply run:

python3 main.py

This will automatically look first for any existing checkpoints and will restore the latest one. If you want to continue the training from a specific checkpoint just run:

python3 main.py -c [checkpoint number]

or

python3 main.py --ckpt=[checkpoint number]

To help you keep better track of your work, every time you start a new training, a configuration report is created in ./PuppetGAN/results/config.txt which stores a detailed report of your current configuration. This report contains all your hyper-parameters and their respective values as well as the whole architecture of the model you are using, including every single layer, its parameters and how it is connected to the rest of the model.

Also, to help you keep better track of your progress, at regular epoch intervals my model saves in ./PuppetGAN/results a sample of evaluation rows of generated images, along with gif animations for these rows, to better visualize the performance of your model.

On top of that, in ./PuppetGAN/results are also stored plots of both the supervised and the adversarial losses as well as the images that are produced during the training. This allows you to have in a single folder everything you need to evaluate an experiment, keep track of its progress and reproduce its results!

Unless you want to experiment with different architectures, PuppetGAN/config.json is the only file you'll need. This file allows you to control all the hyper-parameters of the model without having to look at any of the code! More specifically, the parameters you can control are:

  • dataset : The dataset to use. You can choose between "mnist", "mouth" and "light".

  • epochs : The number of epochs that the model will be trained for.

  • noise std : The standard deviation of the noise that will be applied to the translated images. The mean of the noise is 0.

  • bottleneck noise : The standard deviation of the noise that will be applied to the bottleneck. The mean of the noise is 0.

  • on roids : Whether or not to use the proposed Roids.

  • learning rates-real generator : The learning rate of the real generator.

  • learning rates-real discriminator : The learning rate of the real discriminator.

  • learning rates-synthetic generator : The learning rate of the synthetic generator.

  • learning rates-synthetic discriminator : The learning rate of the synthetic discriminator.

  • losses weights-reconstruction : The weight of the reconstruction loss.

  • losses weights-disentanglement : The weight of the disentanglement loss.

  • losses weights-cycle : The weight of the cycle loss.

  • losses weights-attribute cycle b3 : The weight of the part of the attribute cycle loss that is a function of the synthetic image that has both the Attribute of Interest and all the rest of the attributes.

  • losses weights-attribute cycle a : The weight of the part of the attribute cycle loss that is a function of the real image.

  • batch size : The batch size. Depending on the kind of the dataset different values can be given.

  • image size : The size to which the images of the dataset are resized.

  • save images every : The number of epochs between saves of the training images and the sample of the evaluation images.

  • save model every : Every how many epochs to create a checkpoint. Keep in mind that the 5 latest checkpoints are always kept during training.
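As a rough illustration of how such a file could look and be loaded, here is a minimal sketch. It assumes that the hyphenated names above (for example "learning rates-real generator") denote nested JSON objects; every value shown, the type of "image size" and the helper name `load_config` are placeholders invented for the example, not the repo's actual defaults or code.

```python
import json

VALID_DATASETS = {"mnist", "mouth", "light"}

# All values below are made-up placeholders; the nesting is an assumption.
EXAMPLE_CONFIG = """
{
  "dataset": "mnist",
  "epochs": 300,
  "noise std": 0.2,
  "bottleneck noise": 0.1,
  "on roids": true,
  "learning rates": {
    "real generator": 0.0002,
    "real discriminator": 0.0002,
    "synthetic generator": 0.0002,
    "synthetic discriminator": 0.0002
  },
  "losses weights": {
    "reconstruction": 10,
    "disentanglement": 10,
    "cycle": 10,
    "attribute cycle b3": 5,
    "attribute cycle a": 5
  },
  "batch size": 50,
  "image size": 32,
  "save images every": 5,
  "save model every": 10
}
"""


def load_config(text: str) -> dict:
    """Parse the JSON text and run a few basic sanity checks."""
    cfg = json.loads(text)
    if cfg["dataset"] not in VALID_DATASETS:
        raise ValueError(f"unknown dataset: {cfg['dataset']!r}")
    if cfg["epochs"] <= 0 or cfg["batch size"] <= 0:
        raise ValueError("epochs and batch size must be positive")
    return cfg


config = load_config(EXAMPLE_CONFIG)
print(config["dataset"], config["learning rates"]["real generator"])
```

Check the real PuppetGAN/config.json for the authoritative key names and defaults before relying on any of the placeholders above.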

Evaluation of a Model

You can start an evaluation just by running:

python3 main.py -t

or

python3 main.py --test

Just like with training, this will look for the latest checkpoint; if you want to evaluate the performance of a different checkpoint you can simply use the -c and --ckpt options the same way as before.
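For readers who want to see how the -c/--ckpt and -t/--test options fit together, the following is a minimal argparse sketch. It is not the repo's main.py: `build_parser` and `describe_run` are hypothetical names, and the sketch only mirrors the behavior described above (latest checkpoint by default, a specific one when a number is given).

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    """Build a parser exposing the two options described in this README."""
    parser = argparse.ArgumentParser(description="Train or evaluate a model (sketch).")
    parser.add_argument("-c", "--ckpt", type=int, default=None,
                        help="checkpoint number to restore; latest if omitted")
    parser.add_argument("-t", "--test", action="store_true",
                        help="run an evaluation instead of a training")
    return parser


def describe_run(argv: Optional[List[str]] = None) -> str:
    """Summarize what a given command line would do."""
    args = build_parser().parse_args(argv)
    mode = "evaluate" if args.test else "train"
    source = "latest checkpoint" if args.ckpt is None else f"checkpoint {args.ckpt}"
    return f"{mode} from {source}"


print(describe_run([]))
print(describe_run(["-t", "--ckpt=12"]))
```

With argparse, `-c 12` and `--ckpt=12` are equivalent spellings of the same option, which is why both forms can be used interchangeably.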

During the evaluation process, the model creates all the rows of the generated images, where each cell corresponds to the generated image for the respective synthetic and real inputs. Additionally, for each of the evaluation images, the corresponding gif file is also created to help you get a better idea of your results!

If you want to calculate the scores of your model in the MNIST dataset you can use my ./PuppetGAN/eval_rotation.py script, by running:

python3 eval_rotation.py -p [path to the directory of your evaluation images]

or

python3 eval_rotation.py --path=[path to the directory of your evaluation images]

You can also specify a path to save the evaluation report file using the option -t or --target-path. For example, let's say you have just trained and produced the evaluation images for a model and you want to get the evaluation scores for epoch 100 and save the report in the folder of this epoch. Then you should just run:

# make sure you are in ./PuppetGAN
python3 eval_rotation.py -p results/test/100/images -t results/test/100

For a fair comparison I am also providing the checkpoint of my LeNet-5 network in ./PuppetGAN/checkpoints/lenet5. If the eval_rotation.py script doesn't detect the checkpoint it will train one from scratch and in this case there may be a small difference in the accuracy of your model.

Owner
Giorgos Karantonis
Passionate about AI, ML, DL and other abbreviations.