Official implementation of paper Gradient Matching for Domain Generalization

Related tags

Deep Learningfish
Overview

Gradient Matching for Domain Generalisation

This is the official PyTorch implementation of Gradient Matching for Domain Generalisation. In our paper, we propose an inter-domain gradient matching (IDGM) objective that targets domain generalization by maximizing the inner product between gradients from different domains. To avoid computing the expensive second-order derivative of the IDGM objective, we derive a simpler first-order algorithm named Fish that approximates its optimization.

This repository contains code to reproduce the main results of our paper.

Dependencies

(Recommended) You can setup up conda environment with all required dependencies using environment.yml:

conda env create -f environment.yml
conda activate fish

Otherwise you can also install the following packages manually:

python=3.7.10
numpy=1.20.2
pytorch=1.8.1
torchaudio=0.8.1
torchvision=0.9.1
torch-cluster=1.5.9
torch-geometric=1.7.0
torch-scatter=2.0.6
torch-sparse=0.6.9
wilds=1.1.0
scikit-learn=0.24.2
scipy=1.6.3
seaborn=0.11.1
tqdm=4.61.0

Running Experiments

We offer options to train using our proposed method Fish or by using Empirical Risk Minimisation baseline. This can be specified by the --algorithm flag (either fish or erm).

CdSprites-N

We propose this simple shape-color dataset based on the dSprites dataset, which contains a collection of white 2D sprites of different shapes, scales, rotations and positions. The dataset contains N domains, where N can be specified. The goal is to classify the shape of the sprites, and there is a shape-color deterministic matching that is specific per domain. This way we have shape as the invariant feature and color as the spurious feature. On the test set, however, this correlation between color and shape is removed. See the image below for an illustration.

cdsprites

The CdSprites-N dataset can be downloaded here. After downloading, please extract the zip file to your preferred data dir (e.g. <your_data_dir>/cdsprites). The following command runs an experiment using Fish with number of domains N=15:

python main.py --dataset cdsprites --algorithm fish --data-dir <your_data_dir> --num-domains 15

The number of domains you can choose from are: N = 5, 10, 15, 20, 25, 30, 35, 40, 45, 50.

WILDS

We include the following 6 datasets from the WILDS benchmark: amazon, camelyon, civil, fmow, iwildcam, poverty. The datasets can be downloaded automatically to a specified data folder. For instance, to train with Fish on Amazon dataset, simply run:

python main.py --dataset amazon --algorithm fish --data-dir <your_data_dir>

This should automatically download the Amazon dataset to <your_data_dir>/wilds. Experiments on other datasets can be ran by the following commands:

python main.py --dataset camelyon --algorithm fish --data-dir <your_data_dir>
python main.py --dataset civil --algorithm fish --data-dir <your_data_dir>
python main.py --dataset fmow --algorithm fish --data-dir <your_data_dir>
python main.py --dataset iwildcam --algorithm fish --data-dir <your_data_dir>
python main.py --dataset poverty --algorithm fish --data-dir <your_data_dir>

Alternatively, you can also download the datasets to <your_data_dir>/wilds manually by following the instructions here. See current results on WILDS here: image

DomainBed

For experiments on datasets including CMNIST, RMNIST, VLCS, PACS, OfficeHome, TerraInc and DomainNet, we implemented Fish on the DomainBed benchmark (see here) and you can compare our algorithm against up to 20 SOTA baselines. See current results on DomainBed here:

image

Citation

If you make use of this code in your research, we would appreciate if you considered citing the paper that is most relevant to your work:

@article{shi2021gradient,
	title="Gradient Matching for Domain Generalization.",
	author="Yuge {Shi} and Jeffrey {Seely} and Philip H. S. {Torr} and N. {Siddharth} and Awni {Hannun} and Nicolas {Usunier} and Gabriel {Synnaeve}",
	journal="arXiv preprint arXiv:2104.09937",
	year="2021"}

Contributions

We welcome contributions via pull requests. Please email [email protected] or [email protected] for any question/request.

Official PyTorch code of Holistic 3D Scene Understanding from a Single Image with Implicit Representation (CVPR 2021)

Implicit3DUnderstanding (Im3D) [Project Page] Holistic 3D Scene Understanding from a Single Image with Implicit Representation Cheng Zhang, Zhaopeng C

Cheng Zhang 149 Jan 08, 2023
NeurIPS-2021: Neural Auto-Curricula in Two-Player Zero-Sum Games.

NAC Official PyTorch implementation of NAC from the paper: Neural Auto-Curricula in Two-Player Zero-Sum Games. We release code for: Gradient based ora

Xidong Feng 19 Nov 11, 2022
Megaverse is a new 3D simulation platform for reinforcement learning and embodied AI research

Megaverse Megaverse is a new 3D simulation platform for reinforcement learning and embodied AI research. The efficient design of the engine enables ph

Aleksei Petrenko 191 Dec 23, 2022
"Neural Turing Machine" in Tensorflow

Neural Turing Machine in Tensorflow Tensorflow implementation of Neural Turing Machine. This implementation uses an LSTM controller. NTM models with m

Taehoon Kim 1k Dec 06, 2022
[NeurIPS-2021] Mosaicking to Distill: Knowledge Distillation from Out-of-Domain Data

MosaicKD Code for NeurIPS-21 paper "Mosaicking to Distill: Knowledge Distillation from Out-of-Domain Data" 1. Motivation Natural images share common l

ZJU-VIPA 37 Nov 10, 2022
This repository implements variational graph auto encoder by Thomas Kipf.

Variational Graph Auto-encoder in Pytorch This repository implements variational graph auto-encoder by Thomas Kipf. For details of the model, refer to

DaehanKim 215 Jan 02, 2023
WORD: Revisiting Organs Segmentation in the Whole Abdominal Region

WORD: Revisiting Organs Segmentation in the Whole Abdominal Region (Paper and DataSet). [New] Note that all the emails about the download permission o

Healthcare Intelligence Laboratory 71 Dec 22, 2022
SCAAML is a deep learning framwork dedicated to side-channel attacks run on top of TensorFlow 2.x.

SCAAML (Side Channel Attacks Assisted with Machine Learning) is a deep learning framwork dedicated to side-channel attacks. It is written in python and run on top of TensorFlow 2.x.

Google 69 Dec 21, 2022
This library is a location of the LegacyLogger for PyTorch Lightning.

neptune-contrib Documentation See neptune-contrib documentation site Installation Get prerequisites python versions 3.5.6/3.6 are supported Install li

neptune.ai 26 Oct 07, 2021
Compares various time-series feature sets on computational performance, within-set structure, and between-set relationships.

feature-set-comp Compares various time-series feature sets on computational performance, within-set structure, and between-set relationships. Reposito

Trent Henderson 7 May 25, 2022
Lowest memory consumption and second shortest runtime in NTIRE 2022 challenge on Efficient Super-Resolution

FMEN Lowest memory consumption and second shortest runtime in NTIRE 2022 on Efficient Super-Resolution. Our paper: Fast and Memory-Efficient Network T

33 Dec 01, 2022
Source Code for DialogBERT: Discourse-Aware Response Generation via Learning to Recover and Rank Utterances (https://arxiv.org/pdf/2012.01775.pdf)

DialogBERT This is a PyTorch implementation of the DialogBERT model described in DialogBERT: Neural Response Generation via Hierarchical BERT with Dis

Xiaodong Gu 67 Jan 06, 2023
🔥RandLA-Net in Tensorflow (CVPR 2020, Oral & IEEE TPAMI 2021)

RandLA-Net: Efficient Semantic Segmentation of Large-Scale Point Clouds (CVPR 2020) This is the official implementation of RandLA-Net (CVPR2020, Oral

Qingyong 1k Dec 30, 2022
[CVPR 2021] Pytorch implementation of Hijack-GAN: Unintended-Use of Pretrained, Black-Box GANs

Hijack-GAN: Unintended-Use of Pretrained, Black-Box GANs In this work, we propose a framework HijackGAN, which enables non-linear latent space travers

Hui-Po Wang 46 Sep 05, 2022
Neural-Pull: Learning Signed Distance Functions from Point Clouds by Learning to Pull Space onto Surfaces(ICML 2021)

Neural-Pull: Learning Signed Distance Functions from Point Clouds by Learning to Pull Space onto Surfaces(ICML 2021) This repository contains the code

149 Dec 15, 2022
Learning from Guided Play: A Scheduled Hierarchical Approach for Improving Exploration in Adversarial Imitation Learning Source Code

Learning from Guided Play: A Scheduled Hierarchical Approach for Improving Exploration in Adversarial Imitation Learning Source Code

STARS Laboratory 8 Sep 14, 2022
A GOOD REPRESENTATION DETECTS NOISY LABELS

A GOOD REPRESENTATION DETECTS NOISY LABELS This code is a PyTorch implementation of the paper: Prerequisites Python 3.6.9 PyTorch 1.7.1 Torchvision 0.

<a href=[email protected]"> 64 Jan 04, 2023
Code for Towards Unifying Behavioral and Response Diversity for Open-ended Learning in Zero-sum Games

Unifying Behavioral and Response Diversity for Open-ended Learning in Zero-sum Games How to run our algorithm? Create the new environment using: conda

MARL @ SJTU 8 Dec 27, 2022
Individual Tree Crown classification on WorldView-2 Images using Autoencoder -- Group 9 Weak learners - Final Project (Machine Learning 2020 Course)

Created by Olga Sutyrina, Sarah Elemili, Abduragim Shtanchaev and Artur Bille Individual Tree Crown classification on WorldView-2 Images using Autoenc

2 Dec 08, 2022
Experiments for distributed optimization algorithms

Network-Distributed Algorithm Experiments -- This repository contains a set of optimization algorithms and objective functions, and all code needed to

Boyue Li 40 Dec 04, 2022