Tilted Empirical Risk Minimization (ICLR '21)

Overview

Tilted Empirical Risk Minimization

This repository contains the implementation for the paper

Tilted Empirical Risk Minimization

ICLR 2021

Empirical risk minimization (ERM) is typically designed to perform well on the average loss, which can result in estimators that are sensitive to outliers, generalize poorly, or treat subgroups unfairly. While many methods aim to address these problems individually, in this work, we explore them through a unified framework---tilted empirical risk minimization (TERM).

This repository contains the data, code, and experiments to reproduce our empirical results. We demonstrate that TERM can be used for a multitude of applications, such as enforcing fairness between subgroups, mitigating the effect of outliers, and handling class imbalance. TERM is not only competitive with existing solutions tailored to these individual problems, but can also enable entirely new applications, such as simultaneously addressing outliers and promoting fairness.

Getting started

Dependencies

As we apply TERM to a diverse set of real-world applications, the dependencies for different applications can be different.

  • if we mention that the code is based on other public codebases, then one needs to follow the same setup of those codebases.
  • otherwise, need the following dependencies (the latest versions will work):
    • python3
    • sklearn
    • numpy
    • matplotlib
    • colorsys
    • seaborn
    • scipy
    • cvxpy (optional)

Properties of TERM

Motivating examples

These figures illustrate TERM as a function of t: (a) finding a point estimate from a set of 2D samples, (b) linear regression with outliers, and (c) logistic regression with imbalanced classes. While positive values of t magnify outliers, negative values suppress them. Setting t=0 recovers the original ERM objective.

(How to generate these figures: cd TERM/toy_example & jupyter notebook , and directly run the three notebooks.)

A toy problem to visualize the solutions to TERM

TERM objectives for a squared loss problem with N=3. As t moves from - to +, t-tilted losses recover min-loss (t-->+), avg-loss (t=0), and max-loss (t-->+), and approximate median-loss (for some t). TERM is smooth for all finite t and convex for positive t.

(How to generate this figure: cd TERM/properties & jupyter notebook , and directly run the notebook.)

How to run the code for different applications

1. Robust regression

cd TERM/robust_regression
python regression.py --obj $OBJ --corrupt 1 --noise $NOISE

where $OBJ is the objective and $NOISE is the noise level (see code for options).

2. Robust classification

cd TERM/robust_classification

3. Mitigating noisy annotators

cd TERM/noisy_annotator/pytorch_resnet_cifar10
python trainer.py --t -2  # TERM

4. Fair PCA

cd TERM/fair_pca
jupyter notebook

and directly run the notebook fair_pca_credit.ipynb.

  • built upon the public fair pca codebase
  • we directly extract the pre-processed Credit data dumped from the original matlab code, which are called data.csv, A.csv, and B.csv saved under TERM/fair_pca/multi-criteria-dimensionality-reduction-master/data/credit/.
  • dependencies: same as the fair pca code

5. Handling class imbalance

cd TERM/class_imbalance
python3 -m mnist.mnist_train_tilting --exp tilting  # TERM, common class=99.5%

6. Variance reduction for generalization

cd TERM/DRO
python variance_reduction.py --obj $OBJ $OTHER_PARAS  

where $OBJ is the objective, and $OTHER_PARAS$ are the hyperparameters associated with the objective (see code for options). We report how we select the hyperparameters along with all hyperparameter values in Appendix E of the paper. For instance, for TERM with t=50, run the following:

python variance_reduction.py --obj tilting --t 50  

7. Fair federated learning

cd TERM/fair_flearn
bash run.sh tilting 0 0 term_t0.1_seed0 > term_t0.1_seed0 2>&1 &

8. Hierarchical multi-objective tilting

cd TERM/hierarchical
python mixed_level1.py --imbalance 1 --corrupt 1 --obj tilting --t_in -2 --t_out 10  # TERM_sc
python mixed_level2.py --imbalance 1 --corrupt 1 --obj tilting --t_in 50 --t_out -2 # TERM_ca
  • mixed_level1.py: TERM_{sc}: (sample level, class level)
  • mixed_level2.py: TERM_{ca}: (class level, annotator level)

References

Please see the paper for more details of TERM as well as a complete list of related work.

Owner
Tian Li
Tian Li
Code for "Share With Thy Neighbors: Single-View Reconstruction by Cross-Instance Consistency" paper

UNICORN 🦄 Webpage | Paper | BibTex PyTorch implementation of "Share With Thy Neighbors: Single-View Reconstruction by Cross-Instance Consistency" pap

118 Jan 06, 2023
SeMask: Semantically Masked Transformers for Semantic Segmentation.

SeMask: Semantically Masked Transformers Jitesh Jain, Anukriti Singh, Nikita Orlov, Zilong Huang, Jiachen Li, Steven Walton, Humphrey Shi This repo co

Picsart AI Research (PAIR) 186 Dec 30, 2022
Reference models and tools for Cloud TPUs.

Cloud TPUs This repository is a collection of reference models and tools used with Cloud TPUs. The fastest way to get started training a model on a Cl

5k Jan 05, 2023
Improving Transferability of Representations via Augmentation-Aware Self-Supervision

Improving Transferability of Representations via Augmentation-Aware Self-Supervision Accepted to NeurIPS 2021 TL;DR: Learning augmentation-aware infor

hankook 38 Sep 16, 2022
ML-Decoder: Scalable and Versatile Classification Head

ML-Decoder: Scalable and Versatile Classification Head Paper Official PyTorch Implementation Tal Ridnik, Gilad Sharir, Avi Ben-Cohen, Emanuel Ben-Baru

189 Jan 04, 2023
Repository sharing code and the model for the paper "Rescoring Sequence-to-Sequence Models for Text Line Recognition with CTC-Prefixes"

Rescoring Sequence-to-Sequence Models for Text Line Recognition with CTC-Prefixes Setup virtualenv -p python3 venv source venv/bin/activate pip instal

Planet AI GmbH 9 May 20, 2022
Semi-supervised Representation Learning for Remote Sensing Image Classification Based on Generative Adversarial Networks

SSRL-for-image-classification Semi-supervised Representation Learning for Remote Sensing Image Classification Based on Generative Adversarial Networks

Feng 2 Nov 19, 2021
Demo project for real time anomaly detection using kafka and python

kafkaml-anomaly-detection Project for real time anomaly detection using kafka and python It's assumed that zookeeper and kafka are running in the loca

Rodrigo Arenas 36 Dec 12, 2022
IDRLnet, a Python toolbox for modeling and solving problems through Physics-Informed Neural Network (PINN) systematically.

IDRLnet IDRLnet is a machine learning library on top of PyTorch. Use IDRLnet if you need a machine learning library that solves both forward and inver

IDRL 105 Dec 17, 2022
Source code of the paper Meta-learning with an Adaptive Task Scheduler.

ATS About Source code of the paper Meta-learning with an Adaptive Task Scheduler. If you find this repository useful in your research, please cite the

Huaxiu Yao 16 Dec 26, 2022
Official code for paper "ISNet: Costless and Implicit Image Segmentation for Deep Classifiers, with Application in COVID-19 Detection"

Official code for paper "ISNet: Costless and Implicit Image Segmentation for Deep Classifiers, with Application in COVID-19 Detection". LRPDenseNet.py

Pedro Ricardo Ariel Salvador Bassi 2 Sep 21, 2022
Uni-Fold: Training your own deep protein-folding models.

Uni-Fold: Training your own deep protein-folding models. This package provides and implementation of a trainable, Transformer-based deep protein foldi

DeepModeling 88 Jan 03, 2023
Reimplementation of Dynamic Multi-scale filters for Semantic Segmentation.

Paddle implementation of Dynamic Multi-scale filters for Semantic Segmentation.

Hongqiang.Wang 2 Nov 01, 2021
Multi-Stage Episodic Control for Strategic Exploration in Text Games

XTX: eXploit - Then - eXplore Requirements First clone this repo using git clone https://github.com/princeton-nlp/XTX.git Please create two conda envi

Princeton Natural Language Processing 9 May 24, 2022
Discover hidden deepweb pages

DeepWeb Scapper Att: Demo version An simple script to scrappe deepweb to find pages. Will return if any of those exists and will save on a file. You s

Héber Júlio 77 Oct 02, 2022
Boosting Adversarial Attacks with Enhanced Momentum (BMVC 2021)

EMI-FGSM This repository contains code to reproduce results from the paper: Boosting Adversarial Attacks with Enhanced Momentum (BMVC 2021) Xiaosen Wa

John Hopcroft Lab at HUST 10 Sep 26, 2022
Machine learning notebooks in different subjects optimized to run in google collaboratory

Notebooks Name Description Category Link Training pix2pix This notebook shows a simple pipeline for training pix2pix on a simple dataset. Most of the

Zaid Alyafeai 363 Dec 06, 2022
Deep Reinforcement Learning for Keras.

Deep Reinforcement Learning for Keras What is it? keras-rl implements some state-of-the art deep reinforcement learning algorithms in Python and seaml

Keras-RL 0 Dec 15, 2022
using STGCN to achieve egg classification task

EEG Classification   The task requires us to classify electroencephalography(EEG) into six categories, including human body, human face, animal body,

4 Jun 13, 2022
This repository contains the code used in the paper "Prompt-Based Multi-Modal Image Segmentation".

Prompt-Based Multi-Modal Image Segmentation This repository contains the code used in the paper "Prompt-Based Multi-Modal Image Segmentation". The sys

Timo Lüddecke 305 Dec 30, 2022