Code accompanying the paper on "An Empirical Investigation of Domain Generalization with Empirical Risk Minimizers" published at NeurIPS, 2021

Overview

Code for "An Empirical Investigation of Domain Generalization with Empirical Risk Minimizers" (NeurIPS 2021)

Motivation and Introduction

Domain generalization is a machine learning task in which a model trained on data from one or more input distributions is expected to perform well on test data drawn from a different input distribution. For example, one might train a digit classifier on MNIST and ask it to classify digits that have been rotated by, say, 30 degrees.

While many approaches have been proposed for this problem, we were intrigued by results on the DomainBed benchmark, which suggested that simple empirical risk minimization (ERM) with a proper hyperparameter sweep performs close to the state of the art on domain generalization problems.

What governs how well a deep learning model trained with ERM generalizes to a given data distribution? This is the question we seek to answer in our NeurIPS 2021 paper:

An Empirical Investigation of Domain Generalization with Empirical Risk Minimizers. Rama Vedantam, David Lopez-Paz*, David Schwab*.

NeurIPS 2021 (*=Equal Contribution)

This repository contains code used for producing the results in our paper.

Initial Setup

  1. Run source init.sh to install all the dependencies for the project. This also initializes DomainBed as a submodule of the project.

  2. Set requisite paths in setup.sh, and run source setup.sh

Computing Generalization Measures

  • Get set up with the DomainBed codebase and launch a sweep to obtain an initial set of trained models (illustrated below for the RotatedMNIST dataset):
cd DomainBed/

python -m domainbed.scripts.sweep launch \
       --data_dir=${DOMAINBED_DATA} \
       --output_dir=${DOMAINBED_RUN_DIR}/sweep_fifty_fifty \
       --algorithms=ERM \
       --holdout_fraction=0.5 \
       --datasets=RotatedMNIST \
       --n_hparams=1 \
       --command_launcher submitit

After this step, we have a set of trained models to evaluate and measure. Note that, unlike the original DomainBed paper, we hold out a larger fraction (50%) of the data for evaluating the measures.
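
As a rough illustration of what a 50% holdout means, the sketch below shuffles example indices and sets half of them aside. It is a generic example under stated assumptions, not DomainBed's own splitting code: the helper name split_indices, the seed, and the dataset size of 10 are all hypothetical.

```python
import random


def split_indices(num_examples, holdout_fraction, seed=0):
    """Shuffle example indices and hold out the requested fraction.

    Hypothetical helper for illustration; not part of DomainBed.
    """
    indices = list(range(num_examples))
    random.Random(seed).shuffle(indices)
    num_holdout = int(num_examples * holdout_fraction)
    # The first num_holdout shuffled indices are held out, the rest are kept.
    return indices[num_holdout:], indices[:num_holdout]


# With holdout_fraction=0.5, half of the examples are set aside.
train_idx, heldout_idx = split_indices(num_examples=10, holdout_fraction=0.5)
print(len(train_idx), len(heldout_idx))
```

With a larger held-out portion, more examples are available for computing the measures, at the cost of fewer examples for training.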

  • Once the sweep finishes, aggregate the different files for use by the domainbed_measures codebase:
python domainbed_measures/write_job_status_file.py \
                --sweep_dir=${DOMAINBED_RUN_DIR}/sweep_fifty_fifty \
                --output_txt="domainbed_measures/scratch/sweep_release.txt"
  • Once this step is complete, we can compute various generalization measures and store them to disk for future analysis using:
SLURM_PARTITION="TO_BE_SET"
python domainbed_measures/compute_gen_correlations.py \
    --algorithm=ERM \
    --job_done_file="domainbed_measures/scratch/sweep_release.txt" \
    --run_dir=${MEASURE_RUN_DIR} \
    --all_measures_one_job \
    --slurm_partition=${SLURM_PARTITION}

Here we use Slurm on a compute cluster to scale the experiments to thousands of models. If you do not have access to such a cluster with multiple GPUs to parallelize the computation, pass --slurm_partition="" above and the code will run on a single GPU (although the results might take a long time to compute).

  • Finally, once the above code is done, use the following code snippet to aggregate the values of the different generalization measures:
python domainbed_measures/extract_generalization_features.py \
    --run_dir=${MEASURE_RUN_DIR} \
    --sweep_name="_out_ERM_RotatedMNIST"

This step yields .csv files in which each row corresponds to a trained model. Each row has the following format:

dataset | test_envs | measure 1 | measure 2 | measure 3 | target_err

where:

  • test_envs specifies which environments the model is tested on; since the remaining environments are used for training, it equivalently determines the training environments
  • target_err specifies the target error value for regression
  • measure 1 specifies which measure is being computed, e.g. sharpness or Fisher eigenvalue based measures

In the file named, for example, sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv, target_err is the out-of-domain generalization error, and the within-domain validation error wd_out_domain_err is also used as one of the measures. All measures are computed on a held-out set of image inputs from the target domain (see the paper for more details).

Alternatively, in the file named sweeps__out_ERM_RotatedMNIST_canon_False_wd.csv, target_err is the in-domain validation accuracy, and all measures are computed on the in-distribution held-out images.
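
The row format described above can be read with Python's standard csv module. The sketch below is only an illustration: it assumes the files are comma-separated with a header row, and the sample contents (column spellings and all numbers) are invented stand-ins rather than real output. The helper load_rows is hypothetical.

```python
import csv
import io

# Hypothetical stand-in for one of the generated .csv files; the values and
# the exact header spelling are invented for illustration only.
SAMPLE_CSV = """dataset,test_envs,sharp_mag,entropy,target_err
RotatedMNIST,0,0.12,0.40,0.05
RotatedMNIST,0,0.30,0.55,0.09
RotatedMNIST,1,0.20,0.35,0.03
"""

# Every column that is not metadata or the target is treated as a measure.
NON_MEASURE_COLUMNS = {"dataset", "test_envs", "target_err"}


def load_rows(handle):
    """Read rows and split each into metadata, measure values, and target."""
    rows = []
    for record in csv.DictReader(handle):
        measures = {
            name: float(value)
            for name, value in record.items()
            if name not in NON_MEASURE_COLUMNS
        }
        rows.append(
            {
                "dataset": record["dataset"],
                "test_envs": record["test_envs"],
                "measures": measures,
                "target_err": float(record["target_err"]),
            }
        )
    return rows


rows = load_rows(io.StringIO(SAMPLE_CSV))
print(len(rows), sorted(rows[0]["measures"]))
```

To read a real file, replace io.StringIO(SAMPLE_CSV) with an open file handle and check the header against the columns actually written by the extraction step.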

  • Using this file one can do a number of interesting regression analyses as reported in the paper for measuring generalization.

For example, to generate the kind of results shown in Table 1 of the paper in the joint setting, run:

python domainbed_measures/analyze_results.py \
    --input_csv="${MEASURE_RUN_DIR}/sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv" \
    --stratified_or_joint="joint" \
    --num_features=2 \
    --fix_one_feature_to_wd

Alternatively, to generate results in the stratified setting, run:

python domainbed_measures/analyze_results.py \
    --input_csv="${MEASURE_RUN_DIR}/sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv" \
    --stratified_or_joint="stratified" \
    --num_features=2 \
    --fix_one_feature_to_wd

Finally, to generate results using a single feature (the Alone setting in Table 1), run:

python domainbed_measures/analyze_results.py \
    --input_csv="${MEASURE_RUN_DIR}/sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv" \
    --num_features=1
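
To make the difference between the joint and stratified settings more concrete, here is a minimal sketch in plain Python. It assumes that "joint" means fitting a single regression over models from all test environments, and that "stratified" means fitting one regression per test environment. This reading, the toy numbers, and the helper fit_line are illustrative assumptions and are not taken from analyze_results.py; see the paper for the actual procedure.

```python
from collections import defaultdict


def fit_line(xs, ys):
    """Ordinary least squares for y = slope * x + intercept (one feature)."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    var_x = sum((x - mean_x) ** 2 for x in xs)
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = cov_xy / var_x  # requires at least two distinct x values
    return slope, mean_y - slope * mean_x


# Hypothetical (test_envs, measure value, target_err) triples.
rows = [
    ("0", 0.0, 0.10), ("0", 1.0, 0.20), ("0", 2.0, 0.30),
    ("1", 0.0, 0.50), ("1", 1.0, 0.70), ("1", 2.0, 0.90),
]

# Assumed "joint" reading: a single fit over the rows of all environments.
joint = fit_line([x for _, x, _ in rows], [y for _, _, y in rows])

# Assumed "stratified" reading: a separate fit for each test environment.
groups = defaultdict(list)
for env, x, y in rows:
    groups[env].append((x, y))
stratified = {
    env: fit_line([x for x, _ in points], [y for _, y in points])
    for env, points in groups.items()
}

print(joint)
print(stratified)
```

In this toy data the two environments follow different lines, so the pooled fit differs from both per-environment fits, which is the kind of discrepancy the two settings are meant to separate.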

Translation of measures from the code to the paper

The following table lists all the measures in the paper (Appendix Table 2) and how they are referred to in the codebase:

Measure Name | Code Reference
H-divergence | c2st
H-divergence + Source Error | c2st_perr
H-divergence MS | c2st_per_env
H-divergence MS + Source Error | c2st_per_env_perr
H-divergence (train) | c2st_train
H-divergence (train) + Source Error | c2st_train_perr
H-divergence (train) MS | c2st_train_per_env
Entropy-Source or Entropy | entropy
Entropy-Target | entropy_held_out
Fisher-Eigval-Diff | fisher_eigval_sum_diff_ex_75
Fisher-Eigval | fisher_eigval_sum_ex_75
Fisher-Align or Fisher (main paper) | fisher_eigvec_align_ex_75
HΔH-divergence SS | hdh
HΔH-divergence SS + Source Error | hdh_perr
HΔH-divergence MS | hdh_per_env
HΔH-divergence MS + Source Error | hdh_per_env_perr
HΔH-divergence (train) SS | hdh_train
HΔH-divergence (train) SS + Source Error | hdh_train_perr
Jacobian | jacobian_norm
Jacobian Ratio | jacobian_norm_relative
Jacobian Diff | jacobian_norm_relative_diff
Jacobian Log Ratio | jacobian_norm_relative_log_diff
Mixup | mixup
Mixup Ratio | mixup_relative
Mixup Diff | mixup_relative_diff
Mixup Log Ratio | mixup_relative_log_diff
MMD-Gaussian | mmd_gaussian
MMD-Mean-Cov | mmd_mean_cov
L2-Path-Norm. | path_norm
Sharpness | sharp_mag
H+-divergence SS | v_plus_c2st
H+-divergence SS + Source Error | v_plus_c2st_perr
H+-divergence MS | v_plus_c2st_per_env
H+-divergence MS + Source Error | v_plus_c2st_per_env_perr
H+ΔH+-divergence SS | v_plus_hdh
H+ΔH+-divergence SS + Source Error | v_plus_hdh_perr
H+ΔH+-divergence MS | v_plus_hdh_per_env
H+ΔH+-divergence MS + Source Error | v_plus_hdh_per_env_perr
Source Error | wd_out_domain_err
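
When reading results, it can help to translate code references back to the names used in the paper. The sketch below encodes a subset of the table above as a Python dictionary. The helper paper_name and the example header are hypothetical and only for illustration; extend the dictionary with the remaining rows as needed.

```python
# Subset of the table above: code reference -> measure name used in the paper.
CODE_TO_PAPER_NAME = {
    "c2st": "H-divergence",
    "entropy": "Entropy-Source",
    "entropy_held_out": "Entropy-Target",
    "fisher_eigvec_align_ex_75": "Fisher-Align",
    "jacobian_norm": "Jacobian",
    "mixup": "Mixup",
    "mmd_gaussian": "MMD-Gaussian",
    "path_norm": "L2-Path-Norm.",
    "sharp_mag": "Sharpness",
    "wd_out_domain_err": "Source Error",
}


def paper_name(code_name):
    """Return the paper's name for a column, or the input if it is unmapped."""
    return CODE_TO_PAPER_NAME.get(code_name, code_name)


# Hypothetical header; the columns of a real file depend on the measures run.
header = ["dataset", "test_envs", "sharp_mag", "wd_out_domain_err", "target_err"]
renamed = [paper_name(column) for column in header]
print(renamed)
```

Columns that are not measures, such as dataset or target_err, pass through unchanged because the lookup falls back to the original name.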

Acknowledgments

We thank the developers of Decodable Information Bottleneck and DomainBed, as well as Jonathan Frankle, for code we found useful for this project.

License

This source code is released under the Creative Commons Attribution-NonCommercial 4.0 International license, included here.
