Benchmarks for semi-supervised domain generalization.

Overview

Semi-Supervised Domain Generalization

This code is the official implementation of the following paper: Semi-Supervised Domain Generalization with Stochastic StyleMatch. The paper addresses a practical and yet under-studied setting for domain generalization: one needs to use limited labeled data along with abundant unlabeled data gathered from multiple distinct domains to learn a generalizable model. This setting greatly challenges existing domain generalization methods, which are not designed to deal with unlabeled data and are thus less scalable in practice. Our approach, StyleMatch, extends the pseudo-labeling-based FixMatch—a state-of-the-art semi-supervised learning framework—in two crucial ways: 1) a stochastic classifier is designed to reduce overfitting and 2) the two-view consistency learning paradigm in FixMatch is upgraded to a multi-view version with style augmentation as the third complementary view. Two benchmarks are constructed for evaluation. Please see the paper at https://arxiv.org/abs/2106.00592 for more details.

How to setup the environment

This code is built on top of Dassl.pytorch. Please follow the instructions provided in https://github.com/KaiyangZhou/Dassl.pytorch to install the dassl environment, as well as to prepare the datasets, PACS and OfficeHome. The five random labeled-unlabeled splits can be downloaded at the following links: pacs, officehome. The splits need to be extracted to the two datasets' folders. Assume you put the datasets under the directory $DATA, the structure should look like

$DATA/
    pacs/
        images/
        splits/
        splits_ssdg/
    office_home_dg/
        art/
        clipart/
        product/
        real_world/
        splits_ssdg/

The style augmentation is based on AdaIN and the implementation is based on this code https://github.com/naoto0804/pytorch-AdaIN. Please download the weights of the decoder and the VGG from https://github.com/naoto0804/pytorch-AdaIN and put them under a new folder ssdg-benchmark/weights.

How to run StyleMatch

The script is provided in ssdg-benchmark/scripts/StyleMatch/run_ssdg.sh. You need to update the DATA variable that points to the directory where you put the datasets. There are three input arguments: DATASET, NLAB (total number of labels), and CFG. See the tables below regarding how to set the values for these variables.

Dataset NLAB
ssdg_pacs 210 or 105
ssdg_officehome 1950 or 975
CFG Description
v1 FixMatch + stochastic classifier + T_style
v2 FixMatch + stochastic classifier + T_style-only (i.e. no T_strong)
v3 FixMatch + stochastic classifier
v4 FixMatch

v1 refers to StyleMatch, which is our final model. See the config files in configs/trainers/StyleMatch for the detailed settings.

Here we give an example. Say you want to run StyleMatch on PACS under the 10-labels-per-class setting (i.e. 210 labels in total), simply run the following commands in your terminal,

conda activate dassl
cd ssdg-benchmark/scripts/StyleMatch
bash run_ssdg.sh ssdg_pacs 210 v1

In this case, the code will run StyleMatch in four different setups (four target domains), each for five times (five random seeds). You can modify the code to run a single experiment instead of all at once if you have multiple GPUs.

At the end of training, you will have

output/
    ssdg_pacs/
        nlab_210/
            StyleMatch/
                resnet18/
                    v1/ # contains results on four target domains
                        art_painting/ # contains five folders: seed1-5
                        cartoon/
                        photo/
                        sketch/

To show the results, simply do

python parse_test_res.py output/ssdg_pacs/nlab_210/StyleMatch/resnet18/v1 --multi-exp

Citation

If you use this code in your research, please cite our paper

@article{zhou2021stylematch,
    title={Semi-Supervised Domain Generalization with Stochastic StyleMatch},
    author={Zhou, Kaiyang and Loy, Chen Change and Liu, Ziwei},
    journal={arXiv preprint arXiv:2106.00592},
    year={2021}
}
Owner
Kaiyang
Researcher in computer vision and machine learning :)
Kaiyang
Research shows Google collects 20x more data from Android than Apple collects from iOS. Block this non-consensual telemetry using pihole blocklists.

pihole-antitelemetry Research shows Google collects 20x more data from Android than Apple collects from iOS. Block both using these pihole lists. Proj

Adrian Edwards 290 Jan 09, 2023
Unofficial PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners

Unofficial PyTorch implementation of Masked Autoencoders Are Scalable Vision Learners This repository is built upon BEiT, thanks very much! Now, we on

Zhiliang Peng 2.3k Jan 04, 2023
Python Classes: Medical Insurance Project using Object Oriented Programming Concepts

Medical-Insurance-Project-OOP Python Classes: Medical Insurance Project using Object Oriented Programming Concepts Classes are an incredibly useful pr

Hugo B. 0 Feb 04, 2022
deep_image_prior_extension

Code for "Is Deep Image Prior in Need of a Good Education?" Project page: https://jleuschn.github.io/docs.educated_deep_image_prior/. Supplementary Ma

riccardo barbano 7 Jan 09, 2022
Build an Amazon SageMaker Pipeline to Transform Raw Texts to A Knowledge Graph

Build an Amazon SageMaker Pipeline to Transform Raw Texts to A Knowledge Graph This repository provides a pipeline to create a knowledge graph from ra

AWS Samples 3 Jan 01, 2022
[CVPR'2020] DeepDeform: Learning Non-rigid RGB-D Reconstruction with Semi-supervised Data

DeepDeform (CVPR'2020) DeepDeform is an RGB-D video dataset containing over 390,000 RGB-D frames in 400 videos, with 5,533 optical and scene flow imag

Aljaz Bozic 165 Jan 09, 2023
Ludwig Benchmarking Toolkit

Ludwig Benchmarking Toolkit The Ludwig Benchmarking Toolkit is a personalized benchmarking toolkit for running end-to-end benchmark studies across an

HazyResearch 17 Nov 18, 2022
Auto HMM: Automatic Discrete and Continous HMM including Model selection

Auto HMM: Automatic Discrete and Continous HMM including Model selection

Chess_champion 29 Dec 07, 2022
Cancer Drug Response Prediction via a Hybrid Graph Convolutional Network

DeepCDR Cancer Drug Response Prediction via a Hybrid Graph Convolutional Network This work has been accepted to ECCB2020 and was also published in the

Qiao Liu 50 Dec 18, 2022
Social Distancing Detector

Computer vision has opened up a lot of opportunities to explore into AI domain that were earlier highly limited. Here is an application of haarcascade classifier and OpenCV to develop a social distan

Ashish Pandey 2 Jul 18, 2022
[ICCV 2021 (oral)] Planar Surface Reconstruction from Sparse Views

Planar Surface Reconstruction From Sparse Views Linyi Jin, Shengyi Qian, Andrew Owens, David F. Fouhey University of Michigan ICCV 2021 (Oral) This re

Linyi Jin 89 Jan 05, 2023
Python Fanduel API (2021) - Lineup Automation

Southpaw is a python package that provides access to the Fanduel API. Optimize your DFS experience by programmatically updating your lineups, analyzin

Brandin Canfield 13 Jan 04, 2023
Implementation for paper: Self-Regulation for Semantic Segmentation

Self-Regulation for Semantic Segmentation This is the PyTorch implementation for paper Self-Regulation for Semantic Segmentation, ICCV 2021. Citing SR

Dong ZHANG 30 Nov 21, 2022
Orchestrating Distributed Materials Acceleration Platform Tutorial

Orchestrating Distributed Materials Acceleration Platform Tutorial This tutorial for orchestrating distributed materials acceleration platform was pre

BIG-MAP 1 Jan 25, 2022
This is an unofficial PyTorch implementation of Meta Pseudo Labels

This is an unofficial PyTorch implementation of Meta Pseudo Labels. The official Tensorflow implementation is here.

Jungdae Kim 320 Jan 08, 2023
An Efficient Implementation of Analytic Mesh Algorithm for 3D Iso-surface Extraction from Neural Networks

AnalyticMesh Analytic Marching is an exact meshing solution from neural networks. Compared to standard methods, it completely avoids geometric and top

Karbo 45 Dec 21, 2022
HashNeRF-pytorch - Pure PyTorch Implementation of NVIDIA paper on Instant Training of Neural Graphics primitives

HashNeRF-pytorch Instant-NGP recently introduced a Multi-resolution Hash Encodin

Yash Sanjay Bhalgat 616 Jan 06, 2023
Lex Rosetta: Transfer of Predictive Models Across Languages, Jurisdictions, and Legal Domains

Lex Rosetta: Transfer of Predictive Models Across Languages, Jurisdictions, and Legal Domains This is an accompanying repository to the ICAIL 2021 pap

4 Dec 16, 2021
A PyTorch implementation of "Graph Wavelet Neural Network" (ICLR 2019)

Graph Wavelet Neural Network ⠀⠀ A PyTorch implementation of Graph Wavelet Neural Network (ICLR 2019). Abstract We present graph wavelet neural network

Benedek Rozemberczki 490 Dec 16, 2022
FSL-Mate: A collection of resources for few-shot learning (FSL).

FSL-Mate is a collection of resources for few-shot learning (FSL). In particular, FSL-Mate currently contains FewShotPapers: a paper list which tracks

Yaqing Wang 1.5k Jan 08, 2023