Image-to-image regression with uncertainty quantification in PyTorch

Overview

im2im-uq

A platform for image-to-image regression with rigorous, distribution-free uncertainty quantification.


An algorithmic MRI reconstruction with uncertainty. A rapidly acquired but undersampled MR image of a knee (A) is fed into a model that predicts a sharp reconstruction (B) along with a calibrated notion of uncertainty (C). In (C), red means high uncertainty and blue means low uncertainty. Wherever the reconstruction contains hallucinations, the uncertainty is high; see the hallucination in the image patch (E), which has high uncertainty in (F), and does not exist in the ground truth (G).

Summary

This repository provides a convenient way to train deep-learning models in PyTorch for image-to-image regression (any task where the input and output are both images), along with rigorous uncertainty quantification. The uncertainty quantification takes the form of an interval for each pixel that is guaranteed to contain most true pixel values with high probability, no matter the choice of model or the dataset used (it is a risk-controlling prediction set). The training pipeline already handles multiple GPUs, and all training and calibration should run automatically.
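To make the "most pixels, with high probability" guarantee concrete, here is a minimal sketch of risk-controlling calibration in plain Python. It assumes the per-image loss is the fraction of pixels whose true value falls outside the interval, and that the interval grows with a scalar lam. It swaps in the simple Hoeffding upper confidence bound. The repository's own calibration code may use a different, tighter bound. The names `miscoverage`, `hoeffding_ucb` and `calibrate_lambda` are hypothetical.

```python
import math
from typing import Callable, Optional, Sequence


def miscoverage(lower: Sequence[float], upper: Sequence[float],
                truth: Sequence[float]) -> float:
    """Fraction of pixels whose true value lies outside [lower, upper]."""
    misses = sum(1 for lo, hi, y in zip(lower, upper, truth) if not lo <= y <= hi)
    return misses / len(truth)


def hoeffding_ucb(risks: Sequence[float], delta: float) -> float:
    """Upper confidence bound (level 1 - delta) on the mean of losses in [0, 1]."""
    n = len(risks)
    return sum(risks) / n + math.sqrt(math.log(1.0 / delta) / (2.0 * n))


def calibrate_lambda(loss_at: Callable[[float], Sequence[float]],
                     lambdas: Sequence[float], alpha: float, delta: float) -> float:
    """Smallest lam such that the UCB stays below alpha for every lam' >= lam.

    loss_at(lam) must return one loss per calibration image, computed with
    intervals scaled by lam.
    """
    chosen: Optional[float] = None
    # Scan from the largest (most conservative) lam downward; stop at the first failure.
    for lam in sorted(lambdas, reverse=True):
        if hoeffding_ucb(loss_at(lam), delta) >= alpha:
            break
        chosen = lam
    if chosen is None:
        raise ValueError("even the largest lambda does not control the risk")
    return chosen
```

Scanning from the largest lam downward and stopping at the first failure keeps the selected lam valid for every larger value on the grid. This is what the "for all larger lam" condition of a risk-controlling prediction set requires.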

The basic workflow is

  • Define your dataset in core/datasets/.
  • Create a folder for your experiment experiments/new_experiment, along with a file experiments/new_experiment/config.yml defining the model architecture, hyperparameters, and method of uncertainty quantification. You can use experiments/fastmri_test/config.yml as a template.
  • Edit core/scripts/router.py to point to your data directory.
  • From the root folder, run wandb sweep experiments/new_experiment/config.yml, then run the resulting sweep (wandb sweep prints the wandb agent command that starts it).
  • After the sweep is complete, models will be saved in experiments/new_experiment/checkpoints, the metrics will be printed to the terminal, and outputs will be in experiments/new_experiment/output/. See experiments/fastmri_test/plot.py for an example of how to make plots from the raw outputs.

Following this procedure will train one or more models (depending on config.yml) that perform image-to-image regression with rigorous uncertainty quantification.

There are two pre-baked examples that you can run on your own after downloading the open-source data: experiments/fastmri_test/config.yml and experiments/temca_test/config.yml. The third pre-baked example, experiments/bsbcm_test/config.yml, relies on data collected at Berkeley that has not yet been publicly released (but will be soon).

Paper

Image-to-Image Regression with Distribution-Free Uncertainty Quantification and Applications in Imaging

@article{angelopoulos2022image,
  title={Image-to-Image Regression with Distribution-Free Uncertainty Quantification and Applications in Imaging},
  author={Angelopoulos, Anastasios N and Kohli, Amit P and Bates, Stephen and Jordan, Michael I and Malik, Jitendra and Alshaabi, Thayer and Upadhyayula, Srigokul and Romano, Yaniv},
  journal={arXiv preprint arXiv:2202.05265},
  year={2022}
}

Installation

You will need to execute

conda env create -f environment.yml
conda activate im2im-uq

You will also need to go through the Weights and Biases setup process that initiates when you run your first sweep. You may need to make an account on their website.

Reproducing the results

FastMRI dataset

  • Download the FastMRI dataset to your machine and unzip it. We worked with the knee_singlecoil_train dataset.
  • Edit Line 71 of core/scripts/router.py to point to your local dataset.
  • From the root folder, run wandb sweep experiments/fastmri_test/config.yml
  • After the run is complete, run plot.py from experiments/fastmri_test (cd experiments/fastmri_test, then python plot.py) to plot the results.

TEMCA2 dataset

  • Download the TEMCA2 dataset to your machine and unzip it. We worked with sections 3501 through 3839.
  • Edit Line 78 of core/scripts/router.py to point to your local dataset.
  • From the root folder, run wandb sweep experiments/temca_test/config.yml
  • After the run is complete, run plot.py from experiments/temca_test (cd experiments/temca_test, then python plot.py) to plot the results.

Adding a new experiment

If you want to extend this code to a new experiment, you will need to write some code compatible with our infrastructure. If you are adding a new dataset, you will need to write a valid PyTorch dataset object. If you are adding a new model architecture, you will need to specify it, and so on. Usually, you will want to start by creating a folder experiments/new_experiment along with a config file experiments/new_experiment/config.yml. The easiest way is to start from an existing config, like experiments/fastmri_test/config.yml.

Adding new datasets

To add a new dataset, use the following procedure.

  • Download the dataset to your machine.
  • In core/datasets, make a new folder for your dataset core/datasets/new_dataset.
  • Make a valid PyTorch Dataset class for your new dataset. The most critical part is writing a __getitem__ method that returns an image-image pair in CxHxW order; see core/datasets/bsbcm/BSBCMDataset.py for a simple example.
  • Make a file core/datasets/new_dataset/__init__.py and export your dataset by adding the line from .NewDataset import NewDatasetClass (substituting your filename, without the .py extension, and your class name appropriately).
  • Edit core/scripts/router.py to load your new dataset, near Line 64, following the pattern therein. You will also need to import your dataset object.
  • Populate your new config file experiments/new_experiment/config.yml with the correct directories and experiment name.
  • Execute wandb sweep experiments/new_experiment/config.yml and proceed as normal!
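For illustration, a dataset class of the kind described in these steps might look like the following. It is a hypothetical sketch. `NewDatasetClass`, the in-memory `inputs`/`targets` lists, and the assumption that images are stored as HxW or HxWxC arrays are all illustrative and do not come from the repository. core/datasets/bsbcm/BSBCMDataset.py remains the reference.

```python
import torch
from torch.utils.data import Dataset


class NewDatasetClass(Dataset):
    """Hypothetical paired-image dataset returning (input, target) in CxHxW order.

    Assumes inputs[i] and targets[i] are HxW (grayscale) or HxWxC arrays/tensors.
    """

    def __init__(self, inputs, targets):
        if len(inputs) != len(targets):
            raise ValueError("inputs and targets must have the same length")
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    @staticmethod
    def _to_chw(img):
        t = torch.as_tensor(img, dtype=torch.float32)
        if t.dim() == 2:  # HxW -> 1xHxW
            return t.unsqueeze(0)
        if t.dim() == 3:  # HxWxC -> CxHxW
            return t.permute(2, 0, 1).contiguous()
        raise ValueError(f"expected a 2-D or 3-D image, got shape {tuple(t.shape)}")

    def __getitem__(self, idx):
        # PyTorch's DataLoader calls this method; it must return the image-image pair.
        return self._to_chw(self.inputs[idx]), self._to_chw(self.targets[idx])
```

Converting to CxHxW inside `__getitem__` keeps the on-disk format independent of what the training loop expects.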

Adding new models

In our system, a model has two parts: the base architecture, which we call a trunk (e.g. a U-Net), and the final layer. Defining a trunk is as simple as writing a regular PyTorch nn.Module and adding it near Line 87 of core/scripts/router.py (you will also need to import it); see core/models/trunks/unet.py for an example.
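As a hedged sketch of what a trunk can look like, the hypothetical `TinyTrunk` below is a plain nn.Module. It maps an (N, in_ch, H, W) batch to (N, out_ch, H, W) features at the same resolution. It is far simpler than the repository's U-Net and only illustrates the interface. The channel counts, and the assumption that the final layer consumes full-resolution features, are ours.

```python
import torch
from torch import nn


class TinyTrunk(nn.Module):
    """Hypothetical minimal trunk: (N, in_ch, H, W) -> (N, out_ch, H, W)."""

    def __init__(self, in_ch: int = 1, out_ch: int = 64, hidden: int = 32):
        super().__init__()
        # padding=1 with 3x3 kernels preserves the spatial size H x W.
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, hidden, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
```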

The process for adding a final layer is a bit more involved. The final layer is simply a PyTorch nn.Module, but it must also come with two functions: a loss function and a nested prediction set function. See core/models/finallayers/quantile_layer.py for an example. The steps are:

  • Create a final layer nn.Module object. The final layer should also have a heuristic notion of uncertainty built in, like quantile outputs.
  • Specify the loss function used to train a network with this final layer.
  • Specify a nested prediction set function that uses the output of the final layer to form a prediction set. The prediction set should grow and shrink with a free factor lam, which will later be calibrated. The function should have the same prototype as the one on Line 34 of core/models/finallayers/quantile_layer.py.
  • After creating the new final layer and related functions, add it to core/models/add_uncertainty.py as in Line 59.
  • Edit experiments/new_experiment/config.yml to include your new final layer, and run the sweep as normal!
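The three pieces in these steps can be sketched as follows, assuming quantile outputs as the heuristic uncertainty. Everything here is hypothetical. `QuantileFinalLayer`, `quantile_loss`, `nested_sets`, the (N, 3, C, H, W) output layout, and the 0.05/0.95 quantiles are illustrative choices. The real prototype to match is the one in core/models/finallayers/quantile_layer.py.

```python
import torch
from torch import nn


class QuantileFinalLayer(nn.Module):
    """Hypothetical final layer predicting (lower, point, upper) per output channel."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.out_ch = out_ch
        self.head = nn.Conv2d(in_ch, 3 * out_ch, kernel_size=1)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        n, _, h, w = features.shape
        # Shape (N, 3, out_ch, H, W): index 0 = lower, 1 = point, 2 = upper.
        return self.head(features).view(n, 3, self.out_ch, h, w)


def quantile_loss(output, target, q_lo=0.05, q_hi=0.95):
    """Sum of pinball losses for the lower, point (median) and upper heads."""
    total = 0.0
    for i, q in enumerate((q_lo, 0.5, q_hi)):
        diff = target - output[:, i]
        total = total + torch.maximum(q * diff, (q - 1) * diff).mean()
    return total


def nested_sets(output, lam):
    """Interval [pred - lam * d_lo, pred + lam * d_hi], which widens as lam grows."""
    lower, pred, upper = output[:, 0], output[:, 1], output[:, 2]
    d_lo = torch.relu(pred - lower)  # clamp so the interval always contains pred
    d_hi = torch.relu(upper - pred)
    return pred - lam * d_lo, pred, pred + lam * d_hi
```

The key property is nestedness. For lam1 ≤ lam2, the interval at lam1 is contained in the interval at lam2, and this is what allows lam to be calibrated after training.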

Owner

Anastasios Angelopoulos, Ph.D. student at UC Berkeley AI Research.