Toolkit for building machine learning models that generalize to unseen domains and are robust to privacy and other attacks.

Overview

Toolkit for Building Robust ML models that generalize to unseen domains (RobustDG)

Divyat Mahajan, Shruti Tople, Amit Sharma

Privacy & Causal Learning (ICML 2020) | MatchDG: Causal View of DG (ICML 2021) | Privacy & DG Connection paper

For machine learning models to be reliable, they need to generalize to data beyond the training distribution. In addition, ML models should be robust to privacy attacks such as membership inference and to domain knowledge-based attacks such as adversarial attacks.

To advance research in building robust and generalizable models, we are releasing a toolkit for building and evaluating ML models, RobustDG. RobustDG contains implementations of domain generalization algorithms and includes evaluation benchmarks based on out-of-distribution accuracy and robustness to membership privacy attacks. We will be adding evaluation for adversarial attacks and more privacy attacks soon.

The toolkit is easily extensible: add your own DG algorithms and evaluate them on different benchmarks.
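To make the idea of a membership privacy attack more concrete, here is a minimal sketch of a generic loss-threshold membership inference attack. It is not taken from RobustDG and may differ from the attacks the toolkit implements. The function name `threshold_attack_accuracy`, the threshold, and the loss values are all hypothetical and chosen only for illustration.

```python
def threshold_attack_accuracy(member_losses, nonmember_losses, threshold):
    """Balanced accuracy of an attacker who guesses 'member' when loss < threshold.

    Models usually fit their training points more closely, so a low loss is
    treated as evidence that a point was in the training set.
    """
    true_pos = sum(loss < threshold for loss in member_losses)
    true_neg = sum(loss >= threshold for loss in nonmember_losses)
    return 0.5 * (true_pos / len(member_losses) + true_neg / len(nonmember_losses))


# Made-up per-example losses, for illustration only (not real results).
toy_member_losses = [0.05, 0.10, 0.20, 0.90]     # points seen during training
toy_nonmember_losses = [0.30, 0.80, 1.20, 1.50]  # held-out points

print(threshold_attack_accuracy(toy_member_losses, toy_nonmember_losses, 0.25))  # 0.875
```

An attack accuracy near 0.5 means the attacker does no better than guessing, so lower values indicate a model that leaks less about its training set.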

Installation

To use the command-line interface of RobustDG, clone this repo and add the folder to your system's PATH (or alternatively, run the commands from the RobustDG root directory).

Load dataset

Let's first load the Rotated MNIST dataset in a format suitable for the resnet18 architecture.

python data/data_gen_mnist.py --dataset rot_mnist --model resnet18 --img_h 224 --img_w 224 --subset_size 2000

Train and evaluate ML model

The following commands train and evaluate the MatchDG method on the Rotated MNIST dataset.

python train.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --epochs 50 --batch_size 64 --pos_metric cos --match_func_aug_case 1

python train.py --dataset rot_mnist --method_name matchdg_erm --penalty_ws 0.1 --match_case -1 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --epochs 25

python test.py --dataset rot_mnist --method_name matchdg_erm --penalty_ws 0.1 --match_case -1 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 --ctr_model_name resnet18 --epochs 25 --test_metric acc

python test.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.0 --match_flag 1 --pos_metric cos --test_metric match_score
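The four commands above are meant to be run in order, so it can be convenient to chain them from one script. The following is a minimal sketch, assuming it is saved in and run from the RobustDG root directory. The helper names (`build_commands`, `run_pipeline`) and the `ROBUSTDG_EXECUTE` environment variable are hypothetical and not part of the toolkit; the command strings are copied verbatim from this section. By default the script only prints the commands.

```python
import os
import shlex
import subprocess
import sys

# The four commands from this section, in the order they are listed.
PIPELINE = [
    "train.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.0 "
    "--match_flag 1 --epochs 50 --batch_size 64 --pos_metric cos --match_func_aug_case 1",
    "train.py --dataset rot_mnist --method_name matchdg_erm --penalty_ws 0.1 "
    "--match_case -1 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 "
    "--ctr_model_name resnet18 --epochs 25",
    "test.py --dataset rot_mnist --method_name matchdg_erm --penalty_ws 0.1 "
    "--match_case -1 --ctr_match_case 0.0 --ctr_match_flag 1 --ctr_match_interrupt 5 "
    "--ctr_model_name resnet18 --epochs 25 --test_metric acc",
    "test.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.0 "
    "--match_flag 1 --pos_metric cos --test_metric match_score",
]


def build_commands(python=sys.executable):
    """Split each command string into an argv list prefixed with the interpreter."""
    return [[python, *shlex.split(line)] for line in PIPELINE]


def run_pipeline(execute=False):
    """Print each command; run it only if execute is True, stopping at the first failure."""
    for argv in build_commands():
        print(" ".join(shlex.quote(arg) for arg in argv))
        if execute:
            subprocess.run(argv, check=True)  # raises CalledProcessError on non-zero exit


if __name__ == "__main__":
    # Hypothetical opt-in switch: set ROBUSTDG_EXECUTE=1 to actually launch the runs.
    run_pipeline(execute=os.environ.get("ROBUSTDG_EXECUTE") == "1")
```

Keeping the dry run as the default makes it easy to check the flags before starting a long training job.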

Demo

A quick introduction to using the repository is available in the Getting Started notebook.

If you are interested in reproducing results from the MatchDG paper, check out the Reproducing results notebook.

Roadmap

  • Support for more domain generalization algorithms such as CSD and IRM. If you are an author of a DG algorithm and would like to contribute, please open a pull request or get in touch.
  • More evaluation metrics based on adversarial attacks and on privacy attacks such as model inversion. If you'd like to see an evaluation metric implemented, please raise an issue.

Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments.

Owner
Microsoft