Official code for our CVPR '22 paper "Dataset Distillation by Matching Training Trajectories"

This repo contains code for training expert trajectories and distilling synthetic data from our Dataset Distillation by Matching Training Trajectories paper (CVPR 2022). Please see our project page for more results.

Dataset Distillation by Matching Training Trajectories
George Cazenavette, Tongzhou Wang, Antonio Torralba, Alexei A. Efros, Jun-Yan Zhu
CMU, MIT, UC Berkeley
CVPR 2022

The task of "Dataset Distillation" is to learn a small number of synthetic images such that a model trained on this set alone achieves test performance similar to that of a model trained on the full real dataset.

Our method distills the synthetic dataset by directly optimizing the fake images to induce network training dynamics similar to those of the full, real dataset. We train "student" networks for many iterations on the synthetic data, measure the error in parameter space between the "student" and "expert" networks trained on real data, and back-propagate through all the student network updates to optimize the synthetic pixels.
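
The parameter-space error mentioned above can be illustrated with a small sketch. It assumes a squared L2 distance between the student's final parameters and a later expert checkpoint, normalized by the distance the expert itself travelled. This README does not spell out the exact loss, so the normalization and the name `trajectory_matching_loss` are assumptions for illustration, and parameters are plain lists of floats rather than network weights.

```python
from typing import Sequence


def squared_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two flattened parameter vectors."""
    if len(a) != len(b):
        raise ValueError("parameter vectors must have the same length")
    return sum((x - y) ** 2 for x, y in zip(a, b))


def trajectory_matching_loss(
    student_end: Sequence[float],
    expert_start: Sequence[float],
    expert_target: Sequence[float],
) -> float:
    """Assumed form of the matching error (not taken from distill.py).

    0.0 means the student landed exactly on the expert's target checkpoint;
    1.0 means it is as far from the target as the expert's starting point.
    """
    expert_travel = squared_distance(expert_start, expert_target)
    if expert_travel == 0.0:
        raise ValueError("expert start and target checkpoints are identical")
    return squared_distance(student_end, expert_target) / expert_travel


if __name__ == "__main__":
    start, target = [0.0, 0.0], [1.0, 1.0]
    print(trajectory_matching_loss([0.5, 0.5], start, target))
```

In the actual method this scalar would be differentiated with respect to the synthetic pixels through every student update; the sketch only evaluates the value itself.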

Wearable ImageNet: Synthesizing Tileable Textures

Rather than treating our synthetic data as individual images, we can encourage every random crop (with circular padding) on a larger canvas of pixels to induce a good training trajectory. This results in class-based textures that are continuous around their edges.
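
As a rough illustration of cropping with circular padding, the sketch below takes a square crop from a 2D canvas and wraps indices around the edges with modular arithmetic. The function names and the nested-list canvas are assumptions made for this example; the repo itself operates on image tensors.

```python
import random
from typing import List, Optional

Canvas = List[List[int]]


def circular_crop(canvas: Canvas, top: int, left: int, size: int) -> Canvas:
    """Return a size x size crop whose indices wrap around the canvas edges."""
    height = len(canvas)
    width = len(canvas[0])
    return [
        [canvas[(top + r) % height][(left + c) % width] for c in range(size)]
        for r in range(size)
    ]


def random_circular_crop(
    canvas: Canvas, size: int, rng: Optional[random.Random] = None
) -> Canvas:
    """Pick a uniformly random top-left corner, then take a wrapped crop."""
    if rng is None:
        rng = random.Random()
    top = rng.randrange(len(canvas))
    left = rng.randrange(len(canvas[0]))
    return circular_crop(canvas, top, left, size)


if __name__ == "__main__":
    canvas = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    # A crop starting at the bottom-right pixel wraps to the top-left.
    print(circular_crop(canvas, 2, 2, 2))
```

Because every corner position is valid, a crop can straddle any edge of the canvas, which is what pushes the optimized texture to be continuous when tiled.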

Given these tileable textures, we can apply them wherever seamless tiling is required, for example in clothing patterns.

Visualizations made using FAB3D

Getting Started

First, download our repo:

git clone https://github.com/GeorgeCazenavette/mtt-distillation.git
cd mtt-distillation

For an express installation, we include .yaml files.

If you have an RTX 30XX GPU (or newer), run

conda env create -f requirements_11_3.yaml

If you have an RTX 20XX GPU (or older), run

conda env create -f requirements_10_2.yaml

You can then activate your conda environment with

conda activate distillation

Quadro Users Take Note:

torch.nn.DataParallel does not seem to work on Quadro A5000 GPUs, and this may extend to other Quadro cards.

If you experience indefinite hanging during training, try running the process with only 1 GPU by prepending CUDA_VISIBLE_DEVICES=0 to the command.
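
If you launch the scripts from Python rather than from a shell, the same restriction can be applied through the child process environment. This is a minimal sketch: the helper names `single_gpu_env` and `run_on_single_gpu` are hypothetical and not part of this repo.

```python
import os
import subprocess
import sys
from typing import Dict, List, Mapping


def single_gpu_env(base_env: Mapping[str, str], device: str = "0") -> Dict[str, str]:
    """Copy an environment and expose only one CUDA device to it."""
    env = dict(base_env)
    env["CUDA_VISIBLE_DEVICES"] = device
    return env


def run_on_single_gpu(script_args: List[str]) -> int:
    """Run `python <script_args...>` with one visible GPU; return its exit code."""
    completed = subprocess.run(
        [sys.executable, *script_args],
        env=single_gpu_env(os.environ),
    )
    return completed.returncode


# Example call (left commented out because it needs the repo and a dataset):
# run_on_single_gpu(["distill.py", "--dataset=CIFAR100", "--ipc=1", "--zca"])
```

The variable has to be set before the training process initializes CUDA, which is why the sketch sets it on the child process instead of changing it mid-run.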

Generating Expert Trajectories

Before doing any distillation, you'll need to generate some expert trajectories using buffer.py.

The following command will train 100 ConvNet models on CIFAR-100 with ZCA whitening for 50 epochs each:

python buffer.py --dataset=CIFAR100 --model=ConvNet --train_epochs=50 --num_experts=100 --zca --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

We used 50 epochs with the default learning rate for all of our experts. Training fewer experts (by lowering --num_experts) is faster and gives worse, but still interesting, results. Note that experts need only be trained once and can be re-used for multiple distillation experiments.

Distillation by Matching Training Trajectories

The following command will then use the buffers we just generated to distill CIFAR-100 down to just 1 image per class:

python distill.py --dataset=CIFAR100 --ipc=1 --syn_steps=20 --expert_epochs=3 --max_start_epoch=20 --zca --lr_img=1000 --lr_lr=1e-05 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}
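
The flag names suggest how each distillation step picks a piece of an expert trajectory to match. The sketch below is one assumed reading, not code taken from distill.py: a start epoch is drawn up to --max_start_epoch, and the target is the checkpoint --expert_epochs later. It assumes one checkpoint per epoch with index 0 being the initialization, and treats the upper bound as inclusive; both are illustrative choices.

```python
import random
from typing import Optional, Sequence, Tuple, TypeVar

Checkpoint = TypeVar("Checkpoint")


def sample_expert_segment(
    trajectory: Sequence[Checkpoint],
    max_start_epoch: int,
    expert_epochs: int,
    rng: Optional[random.Random] = None,
) -> Tuple[int, Checkpoint, Checkpoint]:
    """Pick (start_epoch, start_checkpoint, target_checkpoint) from one expert."""
    if rng is None:
        rng = random.Random()
    # The target must still lie inside the saved trajectory.
    last_valid_start = len(trajectory) - 1 - expert_epochs
    if last_valid_start < 0:
        raise ValueError("trajectory is shorter than expert_epochs")
    # Assumption: the start epoch bound is inclusive.
    start = rng.randint(0, min(max_start_epoch, last_valid_start))
    return start, trajectory[start], trajectory[start + expert_epochs]


if __name__ == "__main__":
    # 50 training epochs plus the initialization give 51 checkpoints;
    # integers stand in for saved network parameters.
    fake_trajectory = list(range(51))
    print(sample_expert_segment(fake_trajectory, 20, 3, random.Random(0)))
```

Under this reading, the student would start from the sampled start checkpoint, take --syn_steps updates on the synthetic data, and be compared against the target checkpoint.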

ImageNet

Our method can also distill subsets of ImageNet into low-support synthetic sets.

When generating expert trajectories with buffer.py or distilling the dataset with distill.py, you must designate a named subset of ImageNet with the --subset flag.

For example,

python distill.py --dataset=ImageNet --subset=imagefruit --model=ConvNetD5 --ipc=1 --res=128 --syn_steps=20 --expert_epochs=2 --max_start_epoch=10 --lr_img=1000 --lr_lr=1e-06 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

will distill the imagefruit subset (at 128x128 resolution) into 10 synthetic images, one per class.

To register your own ImageNet subset, you can add it to the Config class at the top of utils.py.

Simply create a list with the desired class IDs and add it to the dictionary.

This gist contains a list of all 1k ImageNet classes and their corresponding numbers.
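
The registration step could look roughly like the sketch below. The `Config` class here is a hypothetical stand-in: the attribute name `subsets`, the subset name `my_subset`, and the class IDs are placeholders chosen for illustration, so check the real class in utils.py for the actual names before editing.

```python
from typing import Dict, List


class Config:
    """Hypothetical stand-in for the Config class at the top of utils.py."""

    # Placeholder ImageNet class IDs (valid range 0-999), chosen arbitrarily.
    my_subset: List[int] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

    # Maps the name passed via --subset to its list of class IDs.
    subsets: Dict[str, List[int]] = {
        "my_subset": my_subset,
    }


def lookup_subset(name: str) -> List[int]:
    """Resolve a subset name to class IDs, failing loudly on unknown names."""
    try:
        return Config.subsets[name]
    except KeyError:
        raise ValueError(f"unknown ImageNet subset: {name!r}") from None


if __name__ == "__main__":
    print(lookup_subset("my_subset"))
```

With an entry like this in place, the new name would be passed on the command line as --subset=my_subset.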

Texture Distillation

You can also use the same set of expert trajectories (except those using ZCA) to distill classes into toroidal textures by simply adding the --texture flag.

For example,

python distill.py --texture --dataset=ImageNet --subset=imagesquawk --model=ConvNetD5 --ipc=1 --res=256 --syn_steps=20 --expert_epochs=2 --max_start_epoch=10 --lr_img=1000 --lr_lr=1e-06 --lr_teacher=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

will distill the imagesquawk subset (at 256x256 resolution) into 10 tileable textures, one per class.

Acknowledgments

We would like to thank Alexander Li, Assaf Shocher, Gokul Swamy, Kangle Deng, Ruihan Gao, Nupur Kumari, Muyang Li, Gaurav Parmar, Chonghyuk Song, Sheng-Yu Wang, and Bingliang Zhang as well as Simon Lucey's Vision Group at the University of Adelaide for their valuable feedback. This work is supported, in part, by the NSF Graduate Research Fellowship under Grant No. DGE1745016 and grants from J.P. Morgan Chase, IBM, and SAP. Our code is adapted from https://github.com/VICO-UoE/DatasetCondensation

Related Work

  1. Tongzhou Wang et al. "Dataset Distillation", in arXiv preprint 2018
  2. Bo Zhao et al. "Dataset Condensation with Gradient Matching", in ICLR 2021
  3. Bo Zhao and Hakan Bilen. "Dataset Condensation with Differentiable Siamese Augmentation", in ICML 2021
  4. Timothy Nguyen et al. "Dataset Meta-Learning from Kernel Ridge-Regression", in ICLR 2021
  5. Timothy Nguyen et al. "Dataset Distillation with Infinitely Wide Convolutional Networks", in NeurIPS 2021
  6. Bo Zhao and Hakan Bilen. "Dataset Condensation with Distribution Matching", in arXiv preprint 2021
  7. Kai Wang et al. "CAFE: Learning to Condense Dataset by Aligning Features", in CVPR 2022

Reference

If you find our code useful for your research, please cite our paper.

@inproceedings{
cazenavette2022distillation,
title={Dataset Distillation by Matching Training Trajectories},
author={George Cazenavette and Tongzhou Wang and Antonio Torralba and Alexei A. Efros and Jun-Yan Zhu},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2022}
}