Dangers of Bayesian Model Averaging under Covariate Shift

Overview

This repository contains the code to reproduce the experiments in the paper Dangers of Bayesian Model Averaging under Covariate Shift by Pavel Izmailov, Patrick Nicholson, Sanae Lotfi and Andrew Gordon Wilson.

The code is forked from the Google Research BNN HMC repo.

Introduction

Approximate Bayesian inference for neural networks is considered a robust alternative to standard training, often providing good performance on out-of-distribution data. However, it was recently shown that Bayesian neural networks (BNNs) with high-fidelity inference through Hamiltonian Monte Carlo (HMC) provide shockingly poor performance under covariate shift. For example, a ResNet-20 BNN approximated with HMC underperforms a maximum a posteriori (MAP) solution by 25% on the pixelate-corrupted CIFAR-10 test set. This result is particularly surprising given that on the in-distribution test data the BNN outperforms the MAP solution by over 5%. In this work, we seek to understand, further demonstrate, and help remedy this concerning behaviour.

As an example, let us consider a fully-connected network on MNIST. MNIST contains many dead pixels: pixels near the boundary that are zero for all training images. The corresponding weights in the first layer of the network are always multiplied by zero and have no effect on the likelihood of the training data. Consequently, in a Bayesian neural network, these weights will be sampled from the prior. A MAP solution, on the other hand, will set these parameters close to zero. In the animation, we visualize the weights in the first layer of a Bayesian neural network and a MAP solution. For each sample, we show the value of the weight corresponding to the highlighted pixel.

If at test time the data is corrupted, e.g. by Gaussian noise, and the pixels near the boundary of the image are activated, the MAP solution will ignore these pixels, while the predictions of the BNN will be significantly affected.
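
To make this concrete, here is a minimal sketch (assuming the tensorflow_datasets package from the requirements below) that counts MNIST's dead pixels and shows how Gaussian-noise corruption activates them; the noise level 0.3 is illustrative:

import numpy as np
import tensorflow_datasets as tfds

# Load all MNIST training images as one array of shape (60000, 28, 28, 1).
train = tfds.as_numpy(tfds.load("mnist", split="train", batch_size=-1))
images = train["image"].astype(np.float32) / 255.0

# A pixel is dead if it is zero in every training image; the first-layer
# weights attached to it never affect the training likelihood, so HMC
# samples them from the prior while MAP drives them towards zero.
dead = images.max(axis=0) == 0
print(f"{int(dead.sum())} of {dead.size} pixels are dead")

# Gaussian-noise corruption activates the previously dead border pixels,
# so the prior-sampled BNN weights now influence the predictions.
noisy = np.clip(images + 0.3 * np.random.randn(*images.shape), 0.0, 1.0)
print(int((noisy.max(axis=0) == 0).sum()), "pixels remain dead after corruption")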

In the paper, we extend this reasoning to general linear dependencies between input features for both fully-connected and convolutional Bayesian neural networks. We also propose EmpCov, a prior based on the empirical covariance of the data, which significantly improves the robustness of BNNs to covariate shift. We implement EmpCov as well as other priors for Bayesian neural networks in this repo.
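
The notebooks in the notebooks folder construct the covariance matrices actually used in the paper; the following is a minimal sketch of the idea for an MLP, where the regularization strength eps and the pre-processing are illustrative assumptions:

import numpy as np

def empcov_inverse_covariance(inputs, eps=0.1):
    # inputs: (N, D) flattened training inputs, e.g. MNIST with D = 784.
    centered = inputs - inputs.mean(axis=0, keepdims=True)
    cov = centered.T @ centered / len(centered)  # empirical covariance
    cov += eps * np.eye(cov.shape[0])            # regularize for invertibility
    return np.linalg.inv(cov)

# The result is saved as a .npy file and passed to the training scripts via
# --empcov_invcov_ckpt (see empcov_covs/ for the precomputed matrices).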

Requirements

We provide a requirements.txt file that can be used to create a conda environment for running the code in this repo:

conda create --name <env> --file requirements.txt

Example set-up using pip:

pip install tensorflow

pip install --upgrade pip
pip install --upgrade jax jaxlib==0.1.65+cuda112 -f \
https://storage.googleapis.com/jax-releases/jax_releases.html

pip install git+https://github.com/deepmind/dm-haiku
pip install tensorflow_datasets
pip install tabulate
pip install optax

Please see the JAX repo for the latest instructions on how to install JAX on your hardware.

File Structure

The implementations of HMC and the other methods forked from the BNN HMC repo are in the bnn_hmc folder. The main training scripts are run_hmc.py for HMC and run_sgd.py for SGD. In the notebooks folder we show examples of how to extract the covariance matrices for EmpCov priors and how to evaluate the results under various corruptions.

.
+-- bnn_hmc/
|   +-- core/
|   |   +-- hmc.py (The Hamiltonian Monte Carlo algorithm)
|   |   +-- sgmcmc.py (SGMCMC methods as optax optimizers)
|   |   +-- vi.py (Mean field variational inference)
|   +-- utils/ (Utility functions used by the training scripts)
|   |   +-- train_utils.py (The training epochs and update rules)
|   |   +-- models.py (Models used in the experiments)
|   |   +-- losses.py (Prior and likelihood functions)
|   |   +-- data_utils.py (Loading and pre-processing the data)
|   |   +-- optim_utils.py (Optimizers and learning rate schedules)
|   |   +-- ensemble_utils.py (Implementation of ensembling of predictions)
|   |   +-- metrics.py (Metrics used in evaluation)
|   |   +-- cmd_args_utils.py (Common command line arguments)
|   |   +-- script_utils.py (Common functionality of the training scripts)
|   |   +-- checkpoint_utils.py (Saving and loading checkpoints)
|   |   +-- logging_utils.py (Utilities for logging and printing the results)
|   |   +-- precision_utils.py (Controlling the numerical precision)
|   |   +-- tree_utils.py (Common operations on pytree objects)
+-- notebooks/  
|   +-- cnn_robustness_cifar10.ipynb (Creates CIFAR-10 CNN figures used in paper)  
|   +-- mlp_robustness_mnist.ipynb (Creates MNIST MLP figures used in paper)
|   +-- cifar10_cnn_extract_empcov.ipynb (Constructs EmpCov prior covariance matrix for CIFAR-10 CNN)
|   +-- mnist_extract_empcov.ipynb (Constructs EmpCov prior covariance matrices for the MNIST CNN and MLP)
+-- empcov_covs/
|   +-- cifar_cnn_pca_inv_cov.npy (EmpCov inverse prior covariance for CIFAR-10 CNN)
|   +-- mnist_cnn_pca_inv_cov.npy (EmpCov inverse prior covariance for MNIST CNN)
|   +-- mnist_mlp_pca_inv_cov.npy (EmpCov inverse prior covariance for MNIST MLP)
+-- run_hmc.py (HMC training script)
+-- run_sgd.py (SGD training script)

Training Scripts

The training scripts are adapted from the Google Research BNN HMC repo. For completeness, we provide full details about the command line arguments here.

Common command line arguments:

  • seed — random seed
  • dir — training directory for saving the checkpoints and tensorboard logs
  • dataset_name — name of the dataset, e.g. cifar10, cifar100, mnist
  • subset_train_to — number of datapoints to use from the dataset; by default, the full dataset is used
  • model_name — name of the neural network architecture, e.g. lenet, resnet20_frn_swish, cnn_lstm, mlp_regression_small
  • weight_decay — weight decay; for Bayesian methods, weight decay determines the prior variance (prior_var = 1 / weight_decay; see the worked example after this list)
  • temperature — posterior temperature (default: 1)
  • init_checkpoint — path to the checkpoint to use for initialization (optional)
  • tabulate_freq — frequency of tabulate table header logging
  • use_float64 — use float64 precision (does not work on TPUs and some GPUs); by default, we use float32 precision
  • prior_family — type of prior to use; must be one of Gaussian, ExpFNormP, Laplace, StudentT, SumFilterLeNet, EmpCovLeNet or EmpCovMLP; see the next section for more details
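
For instance, the weight_decay=100 used in several HMC examples below corresponds to a Gaussian prior standard deviation of 0.1:

# prior_var = 1 / weight_decay
weight_decay = 100.0
prior_var = 1.0 / weight_decay   # = 0.01
prior_std = prior_var ** 0.5     # = 0.1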

Prior Families

In this repo we implement several prior distribution families. Some of the prior families have additional command line arguments specifying the parameters of the prior:

  • Gaussian — iid Gaussian prior centered at 0 with variance equal to 1 / weight_decay
  • Laplace — iid Laplace prior centered at 0 with variance equal to 1 / weight_decay
  • StudentT — iid Student-t prior centered at 0 with studentt_degrees_of_freedom degrees of freedom and scaled by 1 / weight_decay
  • ExpFNormP — iid ExpNorm prior centered at 0 defined in the paper. expfnormp_power specifies the power under the exponent in the prior, and 1 / weight_decay defines the scale of the prior
  • EmpCovLeNet and EmpCovMLP — EmpCov priors with the inverse of the empirical covariance matrix of the data provided as a .npy array via empcov_invcov_ckpt; empcov_wd allows rescaling the covariance matrix for the first layer (a schematic sketch of the log-density follows this list)
  • SumFilterLeNet — SumFilter prior presented in the paper; 1 / sumfilterlenet_weight_decay determines the prior variance for the sum of the filter weights in the first layer

Some prior families require additional arguments, such as empcov_wd and studentt_degrees_of_freedom; run the scripts with --help for full details.
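
As a rough illustration of what the EmpCov families compute, the first-layer log-prior is a zero-mean Gaussian whose precision is the rescaled inverse empirical covariance. This is a schematic sketch, not the actual implementation in bnn_hmc/utils/losses.py:

import numpy as np

def empcov_log_prior(first_layer_weights, inv_cov, empcov_wd):
    # first_layer_weights: (num_units, D); inv_cov: (D, D) inverse empirical
    # covariance. log p(w) = -empcov_wd / 2 * sum_i w_i^T inv_cov w_i + const.
    quad = np.einsum("id,de,ie->i", first_layer_weights, inv_cov,
                     first_layer_weights)
    return -0.5 * empcov_wd * quad.sum()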

Running HMC

To run HMC, you can use the run_hmc.py training script. Arguments:

  • step_size — HMC step size
  • trajectory_len — HMC trajectory length
  • num_iterations — Total number of HMC iterations
  • max_num_leapfrog_steps — Maximum number of leapfrog steps allowed; meant as a sanity check and should be greater than trajectory_len / step_size (see the check below)
  • num_burn_in_iterations — Number of burn-in iterations (default: 0)
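
For example, the Gaussian-prior CIFAR-10 command below implies 5000 leapfrog steps per iteration, so max_num_leapfrog_steps=5300 leaves some headroom:

step_size, trajectory_len = 3e-5, 0.15
n_leapfrog = round(trajectory_len / step_size)  # = 5000
assert n_leapfrog <= 5300                       # max_num_leapfrog_steps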

Examples

CNN on CIFAR-10 with different priors:

# Gaussian prior
python3 run_hmc.py --seed=0 --weight_decay=100 --temperature=1. \
  --dir=runs/hmc/cifar10/gaussian/ --dataset_name=cifar10 \
  --model_name=lenet --step_size=3.e-05 --trajectory_len=0.15 \
  --num_iterations=100 --max_num_leapfrog_steps=5300 \
  --num_burn_in_iterations=10

# Laplace prior
python3 run_hmc.py --seed=0 --weight_decay=100 --temperature=1. \
  --dir=runs/hmc/cifar10/laplace --dataset_name=cifar10 \
  --model_name=lenet --step_size=3.e-05 --trajectory_len=0.15 \
  --num_iterations=100 --max_num_leapfrog_steps=5300 \
  --num_burn_in_iterations=10 --prior_family=Laplace

# Gaussian prior, T=0.01
python3 run_hmc.py --seed=0 --weight_decay=3 --temperature=0.01 \
  --dir=runs/hmc/cifar10/lenet/temp --dataset_name=cifar10 \
  --model_name=lenet --step_size=1.e-05 --trajectory_len=0.1 \
  --num_iterations=100 --max_num_leapfrog_steps=10000 \
  --num_burn_in_iterations=10

# EmpCov prior
python3 run_hmc.py --seed=0 --weight_decay=100. --temperature=1. \
  --dir=runs/hmc/cifar10/EmpCov --dataset_name=cifar10 \
  --model_name=lenet --step_size=1.e-4 --trajectory_len=0.157 \
  --num_iterations=100 --max_num_leapfrog_steps=2000 \
  --num_burn_in_iterations=10 --prior_family=EmpCovLeNet \
  --empcov_invcov_ckpt=empcov_covs/cifar_cnn_pca_inv_cov.npy \
  --empcov_wd=100.

We ran these commands on a machine with 8 NVIDIA Tesla V100 GPUs.

MLP on MNIST using different priors:

# Gaussian prior
python3 run_hmc.py --seed=2 --weight_decay=100  \
  --dir=runs/hmc/mnist/gaussian \
  --dataset_name=mnist --model_name=mlp_classification \
  --step_size=1.e-05 --trajectory_len=0.15 \
  --num_iterations=100 --max_num_leapfrog_steps=15500 \
  --num_burn_in_iterations=10

# Laplace prior
python3 run_hmc.py --seed=0 --weight_decay=3.0 \
  --dir=runs/hmc/mnist/laplace --dataset_name=mnist \
  --model_name=mlp_classification --step_size=6.e-05 \
  --trajectory_len=0.9 --num_iterations=100 \
  --max_num_leapfrog_steps=15500 \
  --num_burn_in_iterations=10 --prior_family=Laplace

# Student-T prior
python3 run_hmc.py --seed=0 --weight_decay=10. \
  --dir=runs/hmc/mnist/studentt --dataset_name=mnist \
  --model_name=mlp_classification --step_size=1.e-4 --trajectory_len=0.49 \
  --num_iterations=100 --max_num_leapfrog_steps=5000 \
  --num_burn_in_iterations=10 --prior_family=StudentT \
  --studentt_degrees_of_freedom=5.

# Gaussian prior, T=0.01
python3 run_hmc.py --seed=11 --weight_decay=100 \
  --temperature=0.01 --dir=runs/hmc/mnist/temp \
  --dataset_name=mnist --model_name=mlp_classification \
  --step_size=6.3e-07 --trajectory_len=0.015 \
  --num_iterations=100 --max_num_leapfrog_steps=25500 \
  --num_burn_in_iterations=10

# EmpCov prior
python3 run_hmc.py --seed=0 --weight_decay=100 \
  --dir=runs/hmc/mnist/empcov --dataset_name=mnist \
  --model_name=mlp_classification --step_size=1.e-05 \
  --trajectory_len=0.15 --num_iterations=100 \
  --max_num_leapfrog_steps=15500 \
  --num_burn_in_iterations=10 --prior_family=EmpCovMLP \
  --empcov_invcov_ckpt=empcov_covs/mnist_mlp_pca_inv_cov.npy \
  --empcov_wd=100  

This script can be run on a single GPU or a TPU v3-8.

Running SGD

To run SGD, you can use the run_sgd.py training script. Arguments:

  • init_step_size — Initial SGD step size; we use a cosine schedule (see the sketch after this list)
  • num_epochs — total number of SGD training epochs
  • batch_size — batch size
  • eval_freq — frequency of evaluation (epochs)
  • save_freq — frequency of checkpointing (epochs)
  • momentum_decay — momentum decay parameter for SGD
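
A minimal sketch of a cosine step-size schedule of the kind run_sgd.py uses, built with optax; the decay horizon and momentum value here are illustrative, not the script's exact settings:

import optax

num_epochs, steps_per_epoch = 300, 50000 // 80   # CIFAR-10 example below
schedule = optax.cosine_decay_schedule(
    init_value=1e-7,                             # --init_step_size
    decay_steps=num_epochs * steps_per_epoch)
optimizer = optax.sgd(learning_rate=schedule, momentum=0.9)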

Examples

MLP on MNIST:

python3 run_sgd.py --seed=0 --weight_decay=100 --dir=runs/sgd/mnist/ \
  --dataset_name=mnist --model_name=mlp_classification \
  --init_step_size=1e-7 --eval_freq=10 --batch_size=80 \
  --num_epochs=100 --save_freq=100

CNN on CIFAR-10:

python3 run_sgd.py --seed=0 --weight_decay=100. --dir=runs/sgd/cifar10/lenet \
  --dataset_name=cifar10 --model_name=lenet --init_step_size=1e-7 --batch_size=80 \
  --num_epochs=300 --save_freq=300

To train a deep ensemble, we simply train multiple copies of SGD with different random seeds.
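
Ensembling then amounts to averaging the predictive distributions of the individual runs; a minimal sketch (the actual implementation is in bnn_hmc/utils/ensemble_utils.py):

import numpy as np

def ensemble_predict(member_probs):
    # member_probs: list of (num_test, num_classes) softmax outputs,
    # one per SGD run trained with a different random seed.
    return np.mean(np.stack(member_probs), axis=0)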

Results

We consider the corrupted versions of the MNIST and CIFAR-10 datasets with both fully-connected (mlp_classification) and convolutional (lenet) architectures. Additionally, we consider domain-shift problems from MNIST to SVHN and from CIFAR-10 to STL-10. We apply the EmpCov prior to the first layer of Bayesian neural networks (BNNs) and a Gaussian prior to all other layers, using the commands in the examples above. The following figure shows the results for deep ensembles, MAP solutions obtained through SGD, BNNs with a Gaussian prior, and BNNs with our novel EmpCov prior. The EmpCov prior improves the robustness of BNNs to covariate shift, leading to better results on most corruptions and competitive performance with deep ensembles for both fully-connected and convolutional architectures.

[Figure: results under corruption and domain shift for deep ensembles, MAP (SGD), Gaussian-prior BNNs, and EmpCov-prior BNNs]
