Official implementation of Representer Point Selection via Local Jacobian Expansion for Post-hoc Classifier Explanation of Deep Neural Networks and Ensemble Models at NeurIPS 2021

Related tags

Deep LearningRPS_LJE
Overview

Representer Point Selection via Local Jacobian Expansion for Classifier Explanation of Deep Neural Networks and Ensemble Models

This repository is the official implementation of Representer Point Selection via Local Jacobian Expansion for Classifier Explanation of Deep Neural Networks and Ensemble Models at NeurIPS 2021. (will update the link)

Introduction

We propose a novel sample-based explanation method for classifiers with a novel derivation of representer point with Taylor Expansion on the Jacobian matrix.

If you would like to cite this work, a sample bibtex citation is as following:

@inproceedings{yi2021representer,
 author = {Yi Sui, Ga Wu, Scott Sanner},
 booktitle = {Advances in Neural Information Processing Systems},
 title = {Representer Point Selection via Local Jacobian Expansion for Classifier Explanation of Deep Neural Networks and Ensemble Models},
 year = {2021}
}

Set up

To install requirements:

pip install -r requirements.txt

Change the root path in config.py to the path to the project

project_root = #your path here

Download the pre-trained models and calculated weights here

  • Dowload and unzip the saved_models_MODEL_NAME
  • Put the content into the corresponding folders ("models/ MODEL_NAME /saved_models")

Training

In our paper, we run experiment with three tasks

  • CIFAR image classification with ResNet-20 (CNN)
  • IMDB sentiment classification with Bi-LSTM (RNN)
  • German credit analysis with XGBoost (Xgboost)

The models are implemented in the models directory with pre-trained weights under "models/ MODEL_NAME /saved_models/base" : ResNet (CNN), Bi-LSTM (RNN), and XGBoost.

To train theses model(s) in the paper, run the following commands:

python models/CNN/train.py --lr 0.01 --epochs 10 --saved_path saved_models/base
python models/RNN/train.py --lr 1e-3 --epochs 10 --saved_path saved_models/base --use_pretrained True
python models/Xgboost/train.py

Caculate weights

We implemented three different explainers: RPS-LJE, RPS-l2 (modified from official repository of RPS-l2), and Influence Function. To calculate the importance weights, run the following commands:

python explainer/calculate_ours_weights.py --model CNN --lr 0.01
python explainer/calculate_representer_weights.py --model RNN --lmbd 0.003 --epoch 3000
python explainer/calculate_influence.py --model Xgboost

Experiments

Dataset debugging experiment

To run the dataset debugging experiments, run the following commands:

python dataset_debugging/experiment_dataset_debugging_cnn.py --num_of_run 10 --flip_portion 0.2 --path ../models/CNN/saved_models/experiment_dataset_debugging --lr 1e-5
python dataset_debugging/experiment_dataset_debugging_cnn.py --num_of_run 10 --flip_portion 0.2 --path ../models/CNN/saved_models/experiment_dataset_debugging_fix_random_split --lr 1e-5 --seed 11

python dataset_debugging/experiment_dataset_debugging_rnn.py --num_of_run 10 --flip_portion 0.2 --path ../models/RNN/saved_models/experiment_dataset_debugging --lr 1e-5

python dataset_debugging/experiment_dataset_debugging_Xgboost.py --num_of_run 10 --flip_portion 0.3 --path ../models/Xgboost/saved_models/experiment_dataset_debugging --lr 1e-5

The trained models, intermediate outputs, explainer weights, and accuracies at each checkpoint are stored under the specified paths "models/MODEL_NAME/saved_models/experiment_dataset_debugging". To visualize the results, run the notebooks plot_res_cnn.ipynb, plot_res_cnn_fixed_random_split.ipynb, plot_res_rnn.ipynb, plot_res_xgboost.ipynb. The results are saved under folder dataset_debugging/figs.

Other experiments

All remaining experiments are in Jupyter-notebooks organized under "models/ MODEL_NAME /experiments" : ResNet (CNN), Bi-LSTM (RNN), and XGBoost.

A comparison of explanation provided by Influence Function, RPS-l2, and RPS-LJE. Explanation for Image Classification

Owner
Yi(Amy) Sui
Yi(Amy) Sui
An intuitive library to extract features from time series

Time Series Feature Extraction Library Intuitive time series feature extraction This repository hosts the TSFEL - Time Series Feature Extraction Libra

Associação Fraunhofer Portugal Research 589 Jan 04, 2023
The project is an official implementation of our CVPR2019 paper "Deep High-Resolution Representation Learning for Human Pose Estimation"

Deep High-Resolution Representation Learning for Human Pose Estimation (CVPR 2019) News [2020/07/05] A very nice blog from Towards Data Science introd

Leo Xiao 3.9k Jan 05, 2023
EdiBERT is a generative model based on a bi-directional transformer, suited for image manipulation

EdiBERT, a generative model for image editing EdiBERT is a generative model based on a bi-directional transformer, suited for image manipulation. The

16 Dec 07, 2022
Code for ICDM2020 full paper: "Sub-graph Contrast for Scalable Self-Supervised Graph Representation Learning"

Subg-Con Sub-graph Contrast for Scalable Self-Supervised Graph Representation Learning (Jiao et al., ICDM 2020): https://arxiv.org/abs/2009.10273 Over

34 Jul 06, 2022
Medical Insurance Cost Prediction using Machine earning

Medical-Insurance-Cost-Prediction-using-Machine-learning - Here in this project, I will use regression analysis to predict medical insurance cost for people in different regions, and based on several

1 Dec 27, 2021
Code for our paper 'Generalized Category Discovery'

Generalized Category Discovery This repo is a placeholder for code for our paper: Generalized Category Discovery Abstract: In this paper, we consider

107 Dec 28, 2022
YOLTv4 builds upon YOLT and SIMRDWN, and updates these frameworks to use the most performant version of YOLO, YOLOv4

YOLTv4 builds upon YOLT and SIMRDWN, and updates these frameworks to use the most performant version of YOLO, YOLOv4. YOLTv4 is designed to detect objects in aerial or satellite imagery in arbitraril

Adam Van Etten 161 Jan 06, 2023
Predicting future trajectories of people in cameras of novel scenarios and views.

Pedestrian Trajectory Prediction Predicting future trajectories of pedestrians in cameras of novel scenarios and views. This repository contains the c

8 Sep 03, 2022
Multi-modal Content Creation Model Training Infrastructure including the FACT model (AI Choreographer) implementation.

AI Choreographer: Music Conditioned 3D Dance Generation with AIST++ [ICCV-2021]. Overview This package contains the model implementation and training

Google Research 365 Dec 30, 2022
Graduation Project

Gesture-Detection-and-Depth-Estimation This is my graduation project. (1) In this project, I use the YOLOv3 object detection model to detect gesture i

ChaosAT 1 Nov 23, 2021
[ICML 2022] The official implementation of Graph Stochastic Attention (GSAT).

Graph Stochastic Attention (GSAT) The official implementation of GSAT for our paper: Interpretable and Generalizable Graph Learning via Stochastic Att

85 Nov 27, 2022
A method that utilized Generative Adversarial Network (GAN) to interpret the black-box deep image classifier models by PyTorch.

A method that utilized Generative Adversarial Network (GAN) to interpret the black-box deep image classifier models by PyTorch.

Yunxia Zhao 3 Dec 29, 2022
Laser device for neutralizing - mosquitoes, weeds and pests

Laser device for neutralizing - mosquitoes, weeds and pests (in progress) Here I will post information for creating a laser device. A warning!! How It

Ildaron 1k Jan 02, 2023
Training a Resilient Q-Network against Observational Interference, Causal Inference Q-Networks

Obs-Causal-Q-Network AAAI 2022 - Training a Resilient Q-Network against Observational Interference Preprint | Slides | Colab Demo | Environment Setup

23 Nov 21, 2022
EfficientDet (Scalable and Efficient Object Detection) implementation in Keras and Tensorflow

EfficientDet This is an implementation of EfficientDet for object detection on Keras and Tensorflow. The project is based on the official implementati

1.3k Dec 19, 2022
Code for the paper "Reinforced Active Learning for Image Segmentation"

Reinforced Active Learning for Image Segmentation (RALIS) Code for the paper Reinforced Active Learning for Image Segmentation Dependencies python 3.6

Arantxa Casanova 79 Dec 19, 2022
Deep Learning Specialization by Andrew Ng, deeplearning.ai.

Deep Learning Specialization on Coursera Master Deep Learning, and Break into AI This is my personal projects for the course. The course covers deep l

Engen 1.5k Jan 07, 2023
Improving Query Representations for DenseRetrieval with Pseudo Relevance Feedback:A Reproducibility Study.

APR The repo for the paper Improving Query Representations for DenseRetrieval with Pseudo Relevance Feedback:A Reproducibility Study. Environment setu

ielab 8 Nov 26, 2022
CSPML (crystal structure prediction with machine learning-based element substitution)

CSPML (crystal structure prediction with machine learning-based element substitution) CSPML is a unique methodology for the crystal structure predicti

8 Dec 20, 2022
GDR-Net: Geometry-Guided Direct Regression Network for Monocular 6D Object Pose Estimation. (CVPR 2021)

GDR-Net This repo provides the PyTorch implementation of the work: Gu Wang, Fabian Manhardt, Federico Tombari, Xiangyang Ji. GDR-Net: Geometry-Guided

169 Jan 07, 2023