Hierarchical probabilistic 3D U-Net, with attention mechanisms (β€”π˜ˆπ˜΅π˜΅π˜¦π˜―π˜΅π˜ͺ𝘰𝘯 𝘜-π˜•π˜¦π˜΅, π˜šπ˜Œπ˜™π˜¦π˜΄π˜•π˜¦π˜΅) and a nested decoder structure with deep supervision (β€”π˜œπ˜•π˜¦π˜΅++).

Overview

Clinically Significant Prostate Cancer Detection in bpMRI

Note: This repo will be continually updated upon future advancements and we welcome open-source contributions! Currently, it shares the TensorFlow 2.5 version of the Hierarchical Probabilistic 3D U-Net (with attention mechanisms, nested decoder structure and deep supervision), titled M1, as explored in the publication(s) listed below. Source code used for training this model, as per our original setup, carry a large number of dependencies on internal datasets, tooling, infrastructure and hardware, and their release is currently not feasible. However, an equivalent minimal adaptation has been made available. We encourage users to test out M1, identify potential areas for significant improvement and propose PRs for inclusion to this repo.

Pre-Trained Model using 1950 bpMRI with PI-RADS v2 Annotations [Training:Validation Ratio - 80:20]:
To infer lesion predictions on testing samples using the pre-trained variant (architecture in commit 58b784f) of this algorithm, please visit https://grand-challenge.org/algorithms/prostate-mri-cad-cspca/

Main Scripts
● Preprocessing Functions: tf2.5/scripts/preprocess.py
● Tensor-Based Augmentations: tf2.5/scripts/model/augmentations.py
● Training Script Template: tf2.5/scripts/train_model.py
● Basic Callbacks (e.g. LR Schedules): tf2.5/scripts/callbacks.py
● Loss Functions: tf2.5/scripts/model/losses.py
● Network Architecture: tf2.5/scripts/model/unets/networks.py

Requirements
● Complete Docker Container: anindox8/m1:latest
● Key Python Packages: tf2.5/requirements.txt

schematic Train-time schematic for the Bayesian/hierarchical probabilistic configuration of M1. L_S denotes the segmentation loss between prediction p and ground-truth Y. Additionally, L_KL, denoting the Kullback–Leibler divergence loss between prior distribution P and posterior distribution Q, is used at train-time (refer to arXiv:1905.13077). For each execution of the model, latent samples z_i ∈ Q (train-time) or z_i ∈ P (test-time) are successively drawn at increasing scales of the model to predict one segmentation mask p.

schematic Architecture schematic of M1, with attention mechanisms and a nested decoder structure with deep supervision.

Minimal Example of Model Setup in TensorFlow 2.5:
(More Details: Training CNNs in TF2: Walkthrough; TF2 Datasets: Best Practices; TensorFlow Probability)

# U-Net Definition (Note: Hyperparameters are Data-Centric -> Require Adequate Tuning for Optimal Performance)
unet_model = unets.networks.M1(\
                        input_spatial_dims =  (20,160,160),            
                        input_channels     =   3,
                        num_classes        =   2,                       
                        filters            =  (32,64,128,256,512),   
                        strides            = ((1,1,1),(1,2,2),(1,2,2),(2,2,2),(2,2,2)),  
                        kernel_sizes       = ((1,3,3),(1,3,3),(3,3,3),(3,3,3),(3,3,3)),
                        prob_latent_dims   =  (3,2,1,0)
                        dropout_rate       =   0.50,       
                        dropout_mode       =  'monte-carlo',
                        se_reduction       =  (8,8,8,8,8),
                        att_sub_samp       = ((1,1,1),(1,1,1),(1,1,1),(1,1,1)),
                        kernel_initializer =   tf.keras.initializers.Orthogonal(gain=1), 
                        bias_initializer   =   tf.keras.initializers.TruncatedNormal(mean=0, stddev=1e-3),
                        kernel_regularizer =   tf.keras.regularizers.l2(1e-4),
                        bias_regularizer   =   tf.keras.regularizers.l2(1e-4),     
                        cascaded           =   False,
                        probabilistic      =   True,
                        deep_supervision   =   True,
                        summary            =   True)  

# Schedule Cosine Annealing Learning Rate with Warm Restarts
LR_SCHEDULE = (tf.keras.optimizers.schedules.CosineDecayRestarts(\
                        initial_learning_rate=1e-3, t_mul=2.00, m_mul=1.00, alpha=1e-3,
                        first_decay_steps=int(np.ceil(((TRAIN_SAMPLES)/BATCH_SIZE)))*10))
                                                  
# Compile Model w/ Optimizer and Loss Function(s)
unet_model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate=LR_SCHEDULE, amsgrad=True), 
                   loss      = losses.Focal(alpha=[0.75, 0.25], gamma=2.00).loss)

# Train Model
unet_model.fit(...)

If you use this repo or some part of its codebase, please cite the following articles (see bibtex):

● A. Saha, J. Bosma, J. Linmans, M. Hosseinzadeh, H. Huisman (2021), "Anatomical and Diagnostic Bayesian Segmentation in Prostate MRI βˆ’Should Different Clinical Objectives Mandate Different Loss Functions?", Medical Imaging Meets NeurIPS Workshop – 35th Conference on Neural Information Processing Systems (NeurIPS), Sydney, Australia. (architecture in commit 914ec9d)

● A. Saha, M. Hosseinzadeh, H. Huisman (2021), "End-to-End Prostate Cancer Detection in bpMRI via 3D CNNs: Effect of Attention Mechanisms, Clinical Priori and Decoupled False Positive Reduction", Medical Image Analysis:102155. (architecture in commit 58b784f)

● A. Saha, M. Hosseinzadeh, H. Huisman (2020), "Encoding Clinical Priori in 3D Convolutional Neural Networks for Prostate Cancer Detection in bpMRI", Medical Imaging Meets NeurIPS Workshop – 34th Conference on Neural Information Processing Systems (NeurIPS), Vancouver, Canada. (architecture in commit 58b784f)

Contact: [email protected]; [email protected]

Related U-Net Architectures:
● nnU-Net: https://github.com/MIC-DKFZ/nnUNet
● Attention U-Net: https://github.com/ozan-oktay/Attention-Gated-Networks
● UNet++: https://github.com/MrGiovanni/UNetPlusPlus
● Hierarchical Probabilistic U-Net: https://github.com/deepmind/deepmind-research/tree/master/hierarchical_probabilistic_unet

Owner
Diagnostic Image Analysis Group
Diagnostic Image Analysis Group
Implementation of Geometric Vector Perceptron, a simple circuit for 3d rotation equivariance for learning over large biomolecules, in Pytorch. Idea proposed and accepted at ICLR 2021

Geometric Vector Perceptron Implementation of Geometric Vector Perceptron, a simple circuit with 3d rotation equivariance for learning over large biom

Phil Wang 59 Nov 24, 2022
Linear Variational State Space Filters

Linear Variational State Space Filters To set up the environment, use the provided scripts in the docker/ folder to build and run the codebase inside

0 Dec 13, 2021
Shape-Adaptive Selection and Measurement for Oriented Object Detection

Source Code of AAAI22-2171 Introduction The source code includes training and inference procedures for the proposed method of the paper submitted to t

houliping 24 Nov 29, 2022
Extending JAX with custom C++ and CUDA code

Extending JAX with custom C++ and CUDA code This repository is meant as a tutorial demonstrating the infrastructure required to provide custom ops in

Dan Foreman-Mackey 237 Dec 23, 2022
hipCaffe: the HIP port of Caffe

Caffe Caffe is a deep learning framework made with expression, speed, and modularity in mind. It is developed by the Berkeley Vision and Learning Cent

ROCm Software Platform 126 Dec 05, 2022
Implementation of H-UCRL Algorithm

Implementation of H-UCRL Algorithm This repository is an implementation of the H-UCRL algorithm introduced in Curi, S., Berkenkamp, F., & Krause, A. (

Sebastian Curi 25 May 20, 2022
A set of tools for Namebase and HNS

HNS-TOOLS A set of tools for Namebase and HNS To install: pip install -r requirements.txt To run: py main.py My Namebase referral code: http://namebas

RunDavidMC 7 Apr 08, 2022
An original implementation of "MetaICL Learning to Learn In Context" by Sewon Min, Mike Lewis, Luke Zettlemoyer and Hannaneh Hajishirzi

MetaICL: Learning to Learn In Context This includes an original implementation of "MetaICL: Learning to Learn In Context" by Sewon Min, Mike Lewis, Lu

Meta Research 141 Jan 07, 2023
Implementation for "Domain-Specific Bias Filtering for Single Labeled Domain Generalization"

DSBF Introduction This repository contains the implementation code for paper: Domain-Specific Bias Filtering for Single Labeled Domain Generalization

ScottYuan 7 Jan 05, 2023
Nicholas Lee 3 Jan 09, 2022
Vertex AI: Serverless framework for MLOPs (ESP / ENG)

Vertex AI: Serverless framework for MLOPs (ESP / ENG) EspaΓ±ol QuΓ© es esto? Este repo contiene un pipeline end to end diseΓ±ado usando el SDK de Kubeflo

HernΓ‘n Escudero 2 Apr 28, 2022
A3C LSTM Atari with Pytorch plus A3G design

NEWLY ADDED A3G A NEW GPU/CPU ARCHITECTURE OF A3C FOR SUBSTANTIALLY ACCELERATED TRAINING!! RL A3C Pytorch NEWLY ADDED A3G!! New implementation of A3C

David Griffis 532 Jan 02, 2023
Omniscient Video Super-Resolution

Omniscient Video Super-Resolution This is the official code of OVSR (Omniscient Video Super-Resolution, ICCV 2021). This work is based on PFNL. Datase

36 Oct 27, 2022
H&M Fashion Image similarity search with Weaviate and DocArray

H&M Fashion Image similarity search with Weaviate and DocArray This example shows how to do image similarity search using DocArray and Weaviate as Doc

Laura Ham 18 Aug 11, 2022
This repo implements several applications of the proposed generalized Bures-Wasserstein (GBW) geometry on symmetric positive definite matrices.

GBW This repo implements several applications of the proposed generalized Bures-Wasserstein (GBW) geometry on symmetric positive definite matrices. Ap

Andi Han 0 Oct 22, 2021
"SOLQ: Segmenting Objects by Learning Queries", SOLQ is an end-to-end instance segmentation framework with Transformer.

SOLQ: Segmenting Objects by Learning Queries This repository is an official implementation of the paper SOLQ: Segmenting Objects by Learning Queries.

MEGVII Research 179 Jan 02, 2023
Contains source code for the winning solution of the xView3 challenge

Winning Solution for xView3 Challenge This repository contains source code and pretrained models for my (Eugene Khvedchenya) solution to xView 3 Chall

Eugene Khvedchenya 51 Dec 30, 2022
Behavioral "black-box" testing for recommender systems

RecList RecList Free software: MIT license Documentation: https://reclist.readthedocs.io. Overview RecList is an open source library providing behavio

Jacopo Tagliabue 375 Dec 30, 2022
Code base for the paper "Scalable One-Pass Optimisation of High-Dimensional Weight-Update Hyperparameters by Implicit Differentiation"

This repository contains code for the paper Scalable One-Pass Optimisation of High-Dimensional Weight-Update Hyperparameters by Implicit Differentiati

8 Aug 28, 2022
Lightweight mmm - Lightweight (Bayesian) Media Mix Model

Lightweight (Bayesian) Media Mix Model This is not an official Google product. L

Google 342 Jan 03, 2023