[CVPR 2021] Teachers Do More Than Teach: Compressing Image-to-Image Models (CAT)

Overview

CAT

arXiv

Pytorch implementation of our method for compressing image-to-image models.
Teachers Do More Than Teach: Compressing Image-to-Image Models
Qing Jin1, Jian Ren2, Oliver J. Woodford, Jiazhuo Wang2, Geng Yuan1, Yanzhi Wang1, Sergey Tulyakov2
1Northeastern University, 2Snap Inc.
In CVPR 2021.

Overview

Compression And Teaching (CAT) framework for compressing image-to-image models: ① Given a pre-trained teacher generator Gt, we determine the architecture of a compressed student generator Gs by eliminating those channels with smallest magnitudes of batch norm scaling factors. ② We then distill knowledge from the pretrained teacher Gt on the student Gs via a novel distillation technique, which maximize the similarity between features of both generators, defined in terms of kernel alignment (KA).

Prerequisites

  • Linux
  • Python 3
  • CPU or NVIDIA GPU + CUDA CuDNN

Getting Started

Installation

  • Clone this repo:

    git clone [email protected]:snap-research/CAT.git
    cd CAT
  • Install PyTorch 1.7 and other dependencies (e.g., torchvision).

    • For pip users, please type the command pip install -r requirements.txt.
    • For Conda users, please create a new Conda environment using conda env create -f environment.yml.

Data Preparation

CycleGAN

Setup

  • Download the CycleGAN dataset (e.g., horse2zebra).

    bash datasets/download_cyclegan_dataset.sh horse2zebra
  • Get the statistical information for the ground-truth images for your dataset to compute FID. We provide pre-prepared real statistic information for several datasets on Google Drive Folder.

Pix2pix

Setup

  • Download the pix2pix dataset (e.g., cityscapes).

    bash datasets/download_pix2pix_dataset.sh cityscapes

Cityscapes Dataset

For the Cityscapes dataset, we cannot provide it due to license issue. Please download the dataset from https://cityscapes-dataset.com and use the script prepare_cityscapes_dataset.py to preprocess it. You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip and unzip them in the same folder. For example, you may put gtFine and leftImg8bit in database/cityscapes-origin. You need to prepare the dataset with the following commands:

python datasets/get_trainIds.py database/cityscapes-origin/gtFine/
python datasets/prepare_cityscapes_dataset.py \
--gtFine_dir database/cityscapes-origin/gtFine \
--leftImg8bit_dir database/cityscapes-origin/leftImg8bit \
--output_dir database/cityscapes \
--table_path datasets/table.txt

You will get a preprocessed dataset in database/cityscapes and a mapping table (used to compute mIoU) in dataset/table.txt.

  • Get the statistical information for the ground-truth images for your dataset to compute FID. We provide pre-prepared real statistics for several datasets. For example,

    bash datasets/download_real_stat.sh cityscapes A

Evaluation Preparation

mIoU Computation

To support mIoU computation, you need to download a pre-trained DRN model drn-d-105_ms_cityscapes.pth from http://go.yf.io/drn-cityscapes-models. By default, we put the drn model in the root directory of our repo. Then you can test our compressed models on cityscapes after you have downloaded our compressed models.

FID/KID Computation

To compute the FID/KID score, you need to get some statistical information from the groud-truth images of your dataset. We provide a script get_real_stat.py to extract statistical information. For example, for the map2arial dataset, you could run the following command:

python get_real_stat.py \
--dataroot database/map2arial \
--output_path real_stat/maps_B.npz \
--direction AtoB

For paired image-to-image translation (pix2pix and GauGAN), we calculate the FID between generated test images to real test images. For unpaired image-to-image translation (CycleGAN), we calculate the FID between generated test images to real training+test images. This allows us to use more images for a stable FID evaluation, as done in previous unconditional GANs research. The difference of the two protocols is small. The FID of our compressed CycleGAN model increases by 4 when using real test images instead of real training+test images.

KID is not supported for the cityscapes dataset.

Model Training

Teacher Training

The first step of our framework is to train a teacher model. For this purpose, please run the script train_inception_teacher.sh under the correponding folder named as the dataset, for example, run

bash scripts/cycle_gan/horse2zebra/train_inception_teacher.sh

Student Training

With the pretrained teacher model, we can determine the architecture of student model under prescribed computational budget. For this purpose, please run the script train_inception_student_XXX.sh under the correponding folder named as the dataset, where XXX stands for the computational budget (in terms of FLOPs for this case) and can be different for different datasets and models. For example, for CycleGAN with Horse2Zebra dataset, our computational budget is 2.6B FLOPs, so we run

bash scripts/cycle_gan/horse2zebra/train_inception_student_2p6B.sh

Pre-trained Models

For convenience, we also provide pretrained teacher and student models on Google Drive Folder.

Model Evaluation

With pretrained teacher and student models, we can evaluate them on the dataset. For this purpose, please run the script evaluate_inception_student_XXX.sh under the corresponding folder named as the dataset, where XXX is the computational budget (in terms of FLOPs). For example, for CycleGAN with Horse2Zebra dataset where the computational budget is 2.6B FLOPs, please run

bash scripts/cycle_gan/horse2zebra/evaluate_inception_student_2p6B.sh

Model Export

The final step is to export the trained compressed model as onnx file to run on mobile devices. For this purpose, please run the script onnx_export_inception_student_XXX.sh under the corresponding folder named as the dataset, where XXX is the computational budget (in terms of FLOPs). For example, for CycleGAN with Horse2Zebra dataset where the computational budget is 2.6B FLOPs, please run

bash scripts/cycle_gan/horse2zebra/onnx_export_inception_student_2p6B.sh

This will create one .onnx file in addition to log files.

Citation

If you use this code for your research, please cite our paper.

@inproceedings{jin2021teachers,
  title={Teachers Do More Than Teach: Compressing Image-to-Image Models},
  author={Jin, Qing and Ren, Jian and Woodford, Oliver J and Wang, Jiazhuo and Yuan, Geng and Wang, Yanzhi and Tulyakov, Sergey},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2021}
}

Acknowledgements

Our code is developed based on AtomNAS and gan-compression.

We also thank pytorch-fid for FID computation and drn for mIoU computation.

Owner
Snap Research
Snap Research
Official PyTorch implementation of "Improving Face Recognition with Large AgeGaps by Learning to Distinguish Children" (BMVC 2021)

Inter-Prototype (BMVC 2021): Official Project Webpage This repository provides the official PyTorch implementation of the following paper: Improving F

Jungsoo Lee 16 Jun 30, 2022
dualFace: Two-Stage Drawing Guidance for Freehand Portrait Sketching (CVMJ)

dualFace dualFace: Two-Stage Drawing Guidance for Freehand Portrait Sketching (CVMJ) We provide python implementations for our CVM 2021 paper "dualFac

Haoran XIE 46 Nov 10, 2022
🤗 Paper Style Guide

🤗 Paper Style Guide (Work in progress, send a PR!) Libraries to Know booktabs natbib cleveref Either seaborn, plotly or altair for graphs algorithmic

Hugging Face 66 Dec 12, 2022
LSTM-VAE Implementation and Relevant Evaluations

LSTM-VAE Implementation and Relevant Evaluations Before using any file in this repository, please create two directories under the root directory name

Lan Zhang 5 Oct 08, 2022
Hierarchical User Intent Graph Network for Multimedia Recommendation

Hierarchical User Intent Graph Network for Multimedia Recommendation This is our Pytorch implementation for the paper: Hierarchical User Intent Graph

6 Jan 05, 2023
This project provides the proof of the uniqueness of the equilibrium and the global asymptotic stability.

Delayed-cellular-neural-network This project provides the proof of the uniqueness of the equilibrium and the global asymptotic stability. There is als

4 Apr 28, 2022
A graph neural network (GNN) model to predict protein-protein interactions (PPI) with no sample features

A graph neural network (GNN) model to predict protein-protein interactions (PPI) with no sample features

2 Jul 25, 2022
Official Pytorch Implementation of Unsupervised Image Denoising with Frequency Domain Knowledge

Unsupervised Image Denoising with Frequency Domain Knowledge (BMVC 2021 Oral) : Official Project Page This repository provides the official PyTorch im

Donggon Jang 12 Sep 26, 2022
Statsmodels: statistical modeling and econometrics in Python

About statsmodels statsmodels is a Python package that provides a complement to scipy for statistical computations including descriptive statistics an

statsmodels 8.1k Jan 02, 2023
Dynamic Multi-scale Filters for Semantic Segmentation (DMNet ICCV'2019)

Dynamic Multi-scale Filters for Semantic Segmentation (DMNet ICCV'2019) Introduction Official implementation of Dynamic Multi-scale Filters for Semant

23 Oct 21, 2022
This repository contains code from the paper "TTS-GAN: A Transformer-based Time-Series Generative Adversarial Network"

TTS-GAN: A Transformer-based Time-Series Generative Adversarial Network This repository contains code from the paper "TTS-GAN: A Transformer-based Tim

Intelligent Multimodal Computing and Sensing Laboratory (IMICS Lab) - Texas State University 108 Dec 29, 2022
No-reference Image Quality Assessment(NIQA) Algorithms (BRISQUE, NIQE, PIQE, RankIQA, MetaIQA)

No-Reference Image Quality Assessment Algorithms No-reference Image Quality Assessment(NIQA) is a task of evaluating an image without a reference imag

Dae-Young Song 26 Jan 04, 2023
Learning multiple gaits of quadruped robot using hierarchical reinforcement learning

Learning multiple gaits of quadruped robot using hierarchical reinforcement learning We propose a method to learn multiple gaits of quadruped robot us

Yunho Kim 17 Dec 11, 2022
Simulate genealogical trees and genomic sequence data using population genetic models

msprime msprime is a population genetics simulator based on tskit. Msprime can simulate random ancestral histories for a sample of individuals (consis

Tskit developers 150 Dec 14, 2022
AISTATS 2019: Confidence-based Graph Convolutional Networks for Semi-Supervised Learning

Confidence-based Graph Convolutional Networks for Semi-Supervised Learning Source code for AISTATS 2019 paper: Confidence-based Graph Convolutional Ne

MALL Lab (IISc) 56 Dec 03, 2022
This is the official github repository of the Met dataset

The Met dataset This is the official github repository of the Met dataset. The official webpage of the dataset can be found here. What is it? This cod

Nikolaos-Antonios Ypsilantis 35 Dec 17, 2022
MPViT:Multi-Path Vision Transformer for Dense Prediction

MPViT : Multi-Path Vision Transformer for Dense Prediction This repository inlcu

Youngwan Lee 272 Dec 20, 2022
Replication attempt for the Protein Folding Model

RGN2-Replica (WIP) To eventually become an unofficial working Pytorch implementation of RGN2, an state of the art model for MSA-less Protein Folding f

Eric Alcaide 36 Nov 29, 2022
Pytorch implementation for Semantic Segmentation/Scene Parsing on MIT ADE20K dataset

Semantic Segmentation on MIT ADE20K dataset in PyTorch This is a PyTorch implementation of semantic segmentation models on MIT ADE20K scene parsing da

MIT CSAIL Computer Vision 4.5k Jan 08, 2023
For IBM Quantum Challenge 2021 (May 20 - 26)

IBM Quantum Challenge 2021 Introduction Commemorating the 40-year anniversary of the Physics of Computation conference, and 5-year anniversary of IBM

Qiskit Community 140 Jan 01, 2023