Official Implementation of SWAD (NeurIPS 2021)

Related tags

Deep Learningswad
Overview

SWAD: Domain Generalization by Seeking Flat Minima (NeurIPS'21)

Official PyTorch implementation of SWAD: Domain Generalization by Seeking Flat Minima.

Junbum Cha, Sanghyuk Chun, Kyungjae Lee, Han-Cheol Cho, Seunghyun Park, Yunsung Lee, Sungrae Park.

Note that this project is built upon [email protected].

Preparation

Dependencies

pip install -r requirements.txt

Datasets

python -m domainbed.scripts.download --data_dir=/my/datasets/path

Environments

Environment details used for our study.

Python: 3.8.6
PyTorch: 1.7.0+cu92
Torchvision: 0.8.1+cu92
CUDA: 9.2
CUDNN: 7603
NumPy: 1.19.4
PIL: 8.0.1

How to Run

train_all.py script conducts multiple leave-one-out cross-validations for all target domain.

python train_all.py exp_name --dataset PACS --data_dir /my/datasets/path

Experiment results are reported as a table. In the table, the row SWAD indicates out-of-domain accuracy from SWAD. The row SWAD (inD) indicates in-domain validation accuracy.

Example results:

+------------+--------------+---------+---------+---------+---------+
| Selection  | art_painting | cartoon |  photo  |  sketch |   Avg.  |
+------------+--------------+---------+---------+---------+---------+
|   oracle   |   82.245%    | 85.661% | 97.530% | 83.461% | 87.224% |
|    iid     |   87.919%    | 78.891% | 96.482% | 78.435% | 85.432% |
|    last    |   82.306%    | 81.823% | 95.135% | 82.061% | 85.331% |
| last (inD) |   95.807%    | 95.291% | 96.306% | 95.477% | 95.720% |
| iid (inD)  |   97.275%    | 96.619% | 96.696% | 97.253% | 96.961% |
|    SWAD    |   89.750%    | 82.942% | 97.979% | 81.870% | 88.135% |
| SWAD (inD) |   97.713%    | 97.649% | 97.316% | 98.074% | 97.688% |
+------------+--------------+---------+---------+---------+---------+

In this example, the DG performance of SWAD for PACS dataset is 88.135%.

If you set indomain_test option to True, the validation set is splitted to validation and test sets, and the (inD) keys become to indicate in-domain test accuracy.

Reproduce the results of the paper

We provide the instructions to reproduce the main results of the paper, Table 1 and 2. Note that the difference in a detailed environment or uncontrolled randomness may bring a little different result from the paper.

  • PACS
python train_all.py PACS0 --dataset PACS --deterministic --trial_seed 0 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py PACS1 --dataset PACS --deterministic --trial_seed 1 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py PACS2 --dataset PACS --deterministic --trial_seed 2 --checkpoint_freq 100 --data_dir /my/datasets/path
  • VLCS
python train_all.py VLCS0 --dataset VLCS --deterministic --trial_seed 0 --checkpoint_freq 50 --tolerance_ratio 0.2 --data_dir /my/datasets/path
python train_all.py VLCS1 --dataset VLCS --deterministic --trial_seed 1 --checkpoint_freq 50 --tolerance_ratio 0.2 --data_dir /my/datasets/path
python train_all.py VLCS2 --dataset VLCS --deterministic --trial_seed 2 --checkpoint_freq 50 --tolerance_ratio 0.2 --data_dir /my/datasets/path
  • OfficeHome
python train_all.py OH0 --dataset OfficeHome --deterministic --trial_seed 0 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py OH1 --dataset OfficeHome --deterministic --trial_seed 1 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py OH2 --dataset OfficeHome --deterministic --trial_seed 2 --checkpoint_freq 100 --data_dir /my/datasets/path
  • TerraIncognita
python train_all.py TR0 --dataset TerraIncognita --deterministic --trial_seed 0 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py TR1 --dataset TerraIncognita --deterministic --trial_seed 1 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py TR2 --dataset TerraIncognita --deterministic --trial_seed 2 --checkpoint_freq 100 --data_dir /my/datasets/path
  • DomainNet
python train_all.py DN0 --dataset DomainNet --deterministic --trial_seed 0 --checkpoint_freq 500 --data_dir /my/datasets/path
python train_all.py DN1 --dataset DomainNet --deterministic --trial_seed 1 --checkpoint_freq 500 --data_dir /my/datasets/path
python train_all.py DN2 --dataset DomainNet --deterministic --trial_seed 2 --checkpoint_freq 500 --data_dir /my/datasets/path

Main Results

Citation

The paper will be published at NeurIPS 2021.

@inproceedings{cha2021swad,
  title={SWAD: Domain Generalization by Seeking Flat Minima},
  author={Cha, Junbum and Chun, Sanghyuk and Lee, Kyungjae and Cho, Han-Cheol and Park, Seunghyun and Lee, Yunsung and Park, Sungrae},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2021}
}

License

This source code is released under the MIT license, included here.

This project includes some code from DomainBed, also MIT licensed.

Specificity-preserving RGB-D Saliency Detection

Specificity-preserving RGB-D Saliency Detection Authors: Tao Zhou, Huazhu Fu, Geng Chen, Yi Zhou, Deng-Ping Fan, and Ling Shao. 1. Preface This reposi

Tao Zhou 35 Jan 08, 2023
Code of our paper "Contrastive Object-level Pre-training with Spatial Noise Curriculum Learning"

CCOP Code of our paper Contrastive Object-level Pre-training with Spatial Noise Curriculum Learning Requirement Install OpenSelfSup Install Detectron2

Chenhongyi Yang 21 Dec 13, 2022
Supporting code for the paper "Dangers of Bayesian Model Averaging under Covariate Shift"

Dangers of Bayesian Model Averaging under Covariate Shift This repository contains the code to reproduce the experiments in the paper Dangers of Bayes

Pavel Izmailov 25 Sep 21, 2022
Gradient Step Denoiser for convergent Plug-and-Play

Source code for the paper "Gradient Step Denoiser for convergent Plug-and-Play"

Samuel Hurault 11 Sep 17, 2022
Sequence to Sequence (seq2seq) Recurrent Neural Network (RNN) for Time Series Forecasting

Sequence to Sequence (seq2seq) Recurrent Neural Network (RNN) for Time Series Forecasting Note: You can find here the accompanying seq2seq RNN forecas

Guillaume Chevalier 1k Dec 25, 2022
Used to record WKU's utility bills on a regular basis.

WKU水电费小助手 一个用于定期记录WKU水电费的脚本 Looking for English Readme? 背景 由于WKU校园内的水电账单系统时常存在扣费延迟的现象,而补扣的费用缺乏令人信服的证明。不少学生为费用摸不着头脑,但也没有申诉的依据。为了更好地掌握水电费使用情况,留下一手证据,我开源

2 Jul 21, 2022
Implementation for paper "Towards the Generalization of Contrastive Self-Supervised Learning"

Contrastive Self-Supervised Learning on CIFAR-10 Paper "Towards the Generalization of Contrastive Self-Supervised Learning", Weiran Huang, Mingyang Yi

Weiran Huang 13 Nov 30, 2022
Cluttered MNIST Dataset

Cluttered MNIST Dataset A setup script will download MNIST and produce mnist/*.t7 files: luajit download_mnist.lua Example usage: local mnist_clutter

DeepMind 50 Jul 12, 2022
Face Mask Detection System built with OpenCV, TensorFlow using Computer Vision concepts

Face mask detection Face Mask Detection System built with OpenCV, TensorFlow using Computer Vision concepts in order to detect face masks in static im

Vaibhav Shukla 1 Oct 27, 2021
This repository is for EMNLP 2021 paper: It is Not as Good as You Think! Evaluating Simultaneous Machine Translation on Interpretation Data

InterpretationData This repository is for our EMNLP 2021 paper: It is Not as Good as You Think! Evaluating Simultaneous Machine Translation on Interpr

4 Apr 21, 2022
Anti-UAV base on PaddleDetection

Paddle-Anti-UAV Anti-UAV base on PaddleDetection Background UAVs are very popular and we can see them in many public spaces, such as parks and playgro

Qingzhong Wang 2 Apr 20, 2022
Continuous Augmented Positional Embeddings (CAPE) implementation for PyTorch

PyTorch implementation of Continuous Augmented Positional Embeddings (CAPE), by Likhomanenko et al. Enhance your Transformer positional embeddings with easy-to-use augmentations!

Guillermo Cámbara 26 Dec 13, 2022
This is an early in-development version of training CLIP models with hivemind.

A transformer that does not hog your GPU memory This is an early in-development codebase: if you want a stable and documented hivemind codebase, look

<a href=[email protected]"> 4 Nov 06, 2022
[ICCV 2021] Group-aware Contrastive Regression for Action Quality Assessment

CoRe Created by Xumin Yu*, Yongming Rao*, Wenliang Zhao, Jiwen Lu, Jie Zhou This is the PyTorch implementation for ICCV paper Group-aware Contrastive

Xumin Yu 31 Dec 24, 2022
ALL Snow Removed: Single Image Desnowing Algorithm Using Hierarchical Dual-tree Complex Wavelet Representation and Contradict Channel Loss (HDCWNet)

ALL Snow Removed: Single Image Desnowing Algorithm Using Hierarchical Dual-tree Complex Wavelet Representation and Contradict Channel Loss (HDCWNet) (

Wei-Ting Chen 49 Dec 27, 2022
OpenABC-D: A Large-Scale Dataset For Machine Learning Guided Integrated Circuit Synthesis

OpenABC-D: A Large-Scale Dataset For Machine Learning Guided Integrated Circuit Synthesis Overview OpenABC-D is a large-scale labeled dataset generate

NYU Machine-Learning guided Design Automation (MLDA) 31 Nov 22, 2022
QRec: A Python Framework for quick implementation of recommender systems (TensorFlow Based)

Introduction QRec is a Python framework for recommender systems (Supported by Python 3.7.4 and Tensorflow 1.14+) in which a number of influential and

Yu 1.4k Jan 01, 2023
Source code for "Roto-translated Local Coordinate Framesfor Interacting Dynamical Systems"

Roto-translated Local Coordinate Frames for Interacting Dynamical Systems Source code for Roto-translated Local Coordinate Frames for Interacting Dyna

Miltiadis Kofinas 19 Nov 27, 2022
This repo is duplication of jwyang/faster-rcnn.pytorch

Faster RCNN Pytorch This repo is duplication of jwyang/faster-rcnn.pytorch C/C++ code are removed and easier to study. Python 3.8.5 Ubuntu 20.04.1 LTS

Kim Jihwan 1 Jan 14, 2022
Boundary-aware Transformers for Skin Lesion Segmentation

Boundary-aware Transformers for Skin Lesion Segmentation Introduction This is an official release of the paper Boundary-aware Transformers for Skin Le

Jiacheng Wang 79 Dec 16, 2022