PyTorch code for SENTRY: Selective Entropy Optimization via Committee Consistency for Unsupervised DA

Related tags

Deep LearningSENTRY
Overview

PyTorch Code for SENTRY: Selective Entropy Optimization via Committee Consistency for Unsupervised Domain Adaptation

Viraj Prabhu, Shivam Khare, Deeksha Kartik, Judy Hoffman

Many existing approaches for unsupervised domain adaptation (UDA) focus on adapting under only data distribution shift and offer limited success under additional cross-domain label distribution shift. Recent work based on self-training using target pseudolabels has shown promise, but on challenging shifts pseudolabels may be highly unreliable and using them for self-training may cause error accumulation and domain misalignment. We propose Selective Entropy Optimization via Committee Consistency (SENTRY), a UDA algorithm that judges the reliability of a target instance based on its predictive consistency under a committee of random image transformations. Our algorithm then selectively minimizes predictive entropy to increase confidence on highly consistent target instances, while maximizing predictive entropy to reduce confidence on highly inconsistent ones. In combination with pseudolabel-based approximate target class balancing, our approach leads to significant improvements over the state-of-the-art on 27/31 domain shifts from standard UDA benchmarks as well as benchmarks designed to stress-test adaptation under label distribution shift.

method

Table of Contents

Setup and Dependencies

  1. Create an anaconda environment with Python 3.6: conda create -n sentry python=3.6.8 and activate: conda activate sentry
  2. Navigate to the code directory: cd code/
  3. Install dependencies: pip install -r requirements.txt

And you're all set up!

Usage

Download data

Data for SVHN->MNIST is downloaded automatically via PyTorch. Data for other benchmarks can be downloaded from the following links. The splits used for our experiments are already included in the data/ folder):

  1. DomainNet
  2. OfficeHome
  3. VisDA2017 (only train and validation needed)

Pretrained checkpoints

To reproduce numbers reported in the paper, we include a a few pretrained checkpoints. We include checkpoints (source and adapted) for SVHN to MNIST (DIGITS) in the checkpoints directory. Source and adapted checkpoints for Clipart to Sketch adaptation (from DomainNet) and Real_World to Product adaptation (from OfficeHome RS-UT) can be downloaded from this link, and should be saved to the checkpoints/source and checkpoints/SENTRY directory as appropriate.

Train and adapt model

  • Natural label distribution shift: Adapt a model from to for a given (where benchmark may be DomainNet, OfficeHome, VisDA, or DIGITS), as follows:
python train.py --id <experiment_id> \
                --source <source> \
                --target <target> \
                --img_dir <image_directory> \
                --LDS_type <LDS_type> \
                --load_from_cfg True \
                --cfg_file 'config/<benchmark>/<cfg_file>.yml' \
                --use_cuda True

SENTRY hyperparameters are provided via a sentry.yml config file in the corresponding config/<benchmark> folder (On DIGITS, we also provide a config for baseline adaptation via DANN). The list of valid source/target domains per-benchmark are:

  • DomainNet: real, clipart, sketch, painting
  • OfficeHome_RS_UT: Real_World, Clipart, Product
  • OfficeHome: Real_World, Clipart, Product, Art
  • VisDA2017: visda_train, visda_test
  • DIGITS: Only svhn (source) to mnist (target) adaptation is currently supported.

Pass in the path to the parent folder containing dataset images via the --img_dir <name_of_directory> flag (eg. --img_dir '~/data/DomainNet'). Pass in the label distribution shift type via the --LDS_type flag: For DomainNet, OfficeHome (standard), and VisDA2017, pass in --LDS_type 'natural' (default). For OfficeHome RS-UT, pass in --LDS_type 'RS_UT'. For DIGITS, pass in --LDS_type as one of IF1, IF20, IF50, or IF100, to load a manually long-tailed target training split with a given imbalance factor (IF), as described in Table 4 of the paper.

To load a pretrained DA checkpoint instead of training your own, additionally pass --load_da True and --id <benchmark_name> to the script above. Finally, the training script will log performance metrics to the console (average and aggregate accuracy), and additionally plot and save some per-class performance statistics to the results/ folder.

Note: By default this code runs on GPU. To run on CPU pass: --use_cuda False

Reference

If you found this code useful, please consider citing:

@article{prabhu2020sentry
   author = {Prabhu, Viraj and Khare, Shivam and Kartik, Deeksha and Hoffman, Judy},
   title = {SENTRY: Selective Entropy Optimization via Committee Consistency for Unsupervised Domain Adaptation},
   year = {2020},
   journal = {arXiv preprint: 2012.11460},
}

Acknowledgements

We would like to thank the developers of PyTorch for building an excellent framework, in addition to the numerous contributors to all the open-source packages we use.

License

MIT

Create Data & AI apps in 20 lines of code with Shimoku

Install with: pip install shimoku-api-python Start with: from os import getenv import shimoku_api_python.client as Shimoku

Shimoku 5 Nov 07, 2022
PyTorch implementation of a Real-ESRGAN model trained on custom dataset

Real-ESRGAN PyTorch implementation of a Real-ESRGAN model trained on custom dataset. This model shows better results on faces compared to the original

Sber AI 160 Jan 04, 2023
From a body shape, infer the anatomic skeleton.

OSSO: Obtaining Skeletal Shape from Outside (CVPR 2022) This repository contains the official implementation of the skeleton inference from: OSSO: Obt

Marilyn Keller 166 Dec 28, 2022
ML-Decoder: Scalable and Versatile Classification Head

ML-Decoder: Scalable and Versatile Classification Head Paper Official PyTorch Implementation Tal Ridnik, Gilad Sharir, Avi Ben-Cohen, Emanuel Ben-Baru

189 Jan 04, 2023
Evidential Softmax for Sparse Multimodal Distributions in Deep Generative Models

Evidential Softmax for Sparse Multimodal Distributions in Deep Generative Models Abstract Many applications of generative models rely on the marginali

Stanford Intelligent Systems Laboratory 9 Jun 06, 2022
Title: Graduate-Admissions-Predictor

The purpose of this project is create a predictive model capable of identifying the probability of a person securing an admit based on their personal profile parameters. Simplified visualisations hav

Akarsh Singh 1 Jan 26, 2022
Supervised forecasting of sequential data in Python.

Supervised forecasting of sequential data in Python. Intro Supervised forecasting is the machine learning task of making predictions for sequential da

The Alan Turing Institute 54 Nov 15, 2022
Very Deep Convolutional Networks for Large-Scale Image Recognition

pytorch-vgg Some scripts to convert the VGG-16 and VGG-19 models [1] from Caffe to PyTorch. The converted models can be used with the PyTorch model zo

Justin Johnson 217 Dec 05, 2022
The open-source and free to use Python package miseval was developed to establish a standardized medical image segmentation evaluation procedure

miseval: a metric library for Medical Image Segmentation EVALuation The open-source and free to use Python package miseval was developed to establish

59 Dec 10, 2022
PyTorch wrapper for Taichi data-oriented class

Stannum PyTorch wrapper for Taichi data-oriented class PRs are welcomed, please see TODOs. Usage from stannum import Tin import torch data_oriented =

86 Dec 23, 2022
Robotics environments

Robotics environments Details and documentation on these robotics environments are available in OpenAI's blog post and the accompanying technical repo

Farama Foundation 121 Dec 28, 2022
NAS-Bench-x11 and the Power of Learning Curves

NAS-Bench-x11 NAS-Bench-x11 and the Power of Learning Curves Shen Yan, Colin White, Yash Savani, Frank Hutter. NeurIPS 2021. Surrogate NAS benchmarks

AutoML-Freiburg-Hannover 13 Nov 18, 2022
Implementation of "Scaled-YOLOv4: Scaling Cross Stage Partial Network" using PyTorch framwork.

YOLOv4-large This is the implementation of "Scaled-YOLOv4: Scaling Cross Stage Partial Network" using PyTorch framwork. YOLOv4-CSP YOLOv4-tiny YOLOv4-

Kin-Yiu, Wong 2k Jan 02, 2023
DSTC10 Track 2 - Knowledge-grounded Task-oriented Dialogue Modeling on Spoken Conversations

DSTC10 Track 2 - Knowledge-grounded Task-oriented Dialogue Modeling on Spoken Conversations This repository contains the data, scripts and baseline co

Alexa 51 Dec 17, 2022
Framework for joint representation learning, evaluation through multimodal registration and comparison with image translation based approaches

CoMIR: Contrastive Multimodal Image Representation for Registration Framework 🖼 Registration of images in different modalities with Deep Learning 🤖

Methods for Image Data Analysis - MIDA 55 Dec 09, 2022
InferPy: Deep Probabilistic Modeling with Tensorflow Made Easy

InferPy: Deep Probabilistic Modeling Made Easy InferPy is a high-level API for probabilistic modeling written in Python and capable of running on top

PGM-Lab 141 Oct 13, 2022
MASS (Mueen's Algorithm for Similarity Search) - a python 2 and 3 compatible library used for searching time series sub-sequences under z-normalized Euclidean distance for similarity.

Introduction MASS allows you to search a time series for a subquery resulting in an array of distances. These array of distances enable you to identif

Matrix Profile Foundation 79 Dec 31, 2022
Code for our paper "Graph Pre-training for AMR Parsing and Generation" in ACL2022

AMRBART An implementation for ACL2022 paper "Graph Pre-training for AMR Parsing and Generation". You may find our paper here (Arxiv). Requirements pyt

xfbai 60 Jan 03, 2023
SMPLpix: Neural Avatars from 3D Human Models

subject0_validation_poses.mp4 Left: SMPL-X human mesh registered with SMPLify-X, middle: SMPLpix render, right: ground truth video. SMPLpix: Neural Av

Sergey Prokudin 292 Dec 30, 2022
Using multidimensional LSTM neural networks to create a forecast for Bitcoin price

Multidimensional LSTM BitCoin Time Series Using multidimensional LSTM neural networks to create a forecast for Bitcoin price. For notes around this co

Jakob Aungiers 318 Dec 14, 2022