Official implementation of UTNet: A Hybrid Transformer Architecture for Medical Image Segmentation

UTNet (Accepted at MICCAI 2021)

Introduction

Transformer architecture has proven successful in a number of natural language processing tasks. However, its applications to medical vision remain largely unexplored. In this study, we present UTNet, a simple yet powerful hybrid Transformer architecture that integrates self-attention into a convolutional neural network for enhancing medical image segmentation. UTNet applies self-attention modules in both the encoder and the decoder to capture long-range dependencies at different scales with minimal overhead. To this end, we propose an efficient self-attention mechanism along with relative position encoding that significantly reduces the complexity of the self-attention operation from O(n²) to approximately O(n). A new self-attention decoder is also proposed to recover fine-grained details from the skip connections in the encoder. Our approach addresses the dilemma that Transformer requires huge amounts of data to learn vision inductive bias. Our hybrid layer design allows the initialization of Transformer into convolutional networks without the need for pre-training. We have evaluated UTNet on the multi-label, multi-vendor cardiac magnetic resonance imaging cohort. UTNet demonstrates superior segmentation performance and robustness against the state-of-the-art approaches, holding promise to generalize well to other medical image segmentation tasks.
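The drop from O(n²) toward O(n) comes from letting the n queries attend to a small, fixed number of keys and values instead of all n positions. The snippet below is a minimal, framework-free sketch of that idea, not UTNet's implementation. It assumes the keys and values are reduced by simple average pooling over groups of consecutive tokens; UTNet works on 2D feature maps, uses its own downsampling, and adds relative position encoding, none of which is reproduced here. All function names are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python product of an (m x p) matrix and a (p x q) matrix."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(row) for row in zip(*a)]


def softmax_rows(a: Matrix) -> Matrix:
    """Row-wise softmax, subtracting each row's max for numerical stability."""
    out = []
    for row in a:
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def pool_tokens(x: Matrix, num_reduced: int) -> Matrix:
    """Average consecutive groups of tokens: n tokens -> num_reduced tokens."""
    n, d = len(x), len(x[0])
    if n % num_reduced:
        raise ValueError("number of tokens must be divisible by num_reduced")
    group = n // num_reduced
    return [[sum(x[i * group + r][c] for r in range(group)) / group for c in range(d)]
            for i in range(num_reduced)]


def efficient_attention(q: Matrix, keys: Matrix, values: Matrix, num_reduced: int) -> Matrix:
    """Scaled dot-product attention against num_reduced pooled keys/values.

    The score matrix is n x num_reduced instead of n x n, so for a fixed
    num_reduced the cost grows linearly with the number of queries n.
    """
    d = len(q[0])
    k_small = pool_tokens(keys, num_reduced)
    v_small = pool_tokens(values, num_reduced)
    scores = matmul(q, transpose(k_small))  # n x num_reduced
    scores = [[s / math.sqrt(d) for s in row] for row in scores]
    return matmul(softmax_rows(scores), v_small)  # n x d
```

With 16 tokens and `num_reduced=4`, the sketch computes 64 attention scores instead of 256; setting `num_reduced` equal to the number of tokens makes the pooling a no-op and recovers ordinary attention.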

Supported models

UTNet

TransUNet

ResNet50-UTNet

ResNet50-UNet

SwinUNet

To be continued ...

Getting Started

Currently, only the M&Ms dataset is supported.

Prerequisites

Python >= 3.6
pytorch = 1.8.1
SimpleITK = 2.0.2
numpy = 1.19.5
einops = 0.3.2

Preprocess

Resample all data to a spacing of 1.2 x 1.2 mm in the x-y plane. The z-axis spacing is left unchanged, as UTNet is a 2D network. Then put all data into 'dataset/'.
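The resampling step can be done with SimpleITK, which is one of the listed prerequisites. The sketch below is one possible way to do it, not the authors' preprocessing script, and the function names are hypothetical. `target_size` computes the new grid so that the physical extent is preserved (x and y go to 1.2 mm, z keeps its spacing and slice count); `resample_xy` applies it with `sitk.ResampleImageFilter`, using nearest-neighbour interpolation for label maps so that they stay integer-valued.

```python
from typing import Sequence, Tuple


def target_size(size: Sequence[int], spacing: Sequence[float],
                xy_spacing: float = 1.2) -> Tuple[Tuple[int, ...], Tuple[float, ...]]:
    """Return (new_size, new_spacing) for resampling x and y to xy_spacing mm.

    Sizes and spacings follow SimpleITK's (x, y, z) order; z is left unchanged.
    """
    new_spacing = (xy_spacing, xy_spacing, float(spacing[2]))
    new_size = tuple(
        int(round(n * old / new)) for n, old, new in zip(size, spacing, new_spacing)
    )
    return new_size, new_spacing


def resample_xy(image, xy_spacing: float = 1.2, is_label: bool = False):
    """Resample a 3D SimpleITK image in-plane only (requires SimpleITK)."""
    import SimpleITK as sitk  # imported lazily so target_size works without it

    new_size, new_spacing = target_size(image.GetSize(), image.GetSpacing(), xy_spacing)
    resampler = sitk.ResampleImageFilter()
    resampler.SetOutputSpacing(new_spacing)
    resampler.SetSize([int(s) for s in new_size])
    resampler.SetOutputOrigin(image.GetOrigin())
    resampler.SetOutputDirection(image.GetDirection())
    resampler.SetDefaultPixelValue(0)
    # Nearest neighbour keeps label maps integer-valued; linear suits intensities.
    resampler.SetInterpolator(sitk.sitkNearestNeighbor if is_label else sitk.sitkLinear)
    return resampler.Execute(image)
```

For example, a 256 x 256 x 10 volume with 1.5 x 1.5 x 10 mm spacing becomes 320 x 320 x 10 at 1.2 x 1.2 x 10 mm.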

Training

The M&Ms dataset provides data from 4 vendors: vendors A and B are provided for training, while all four vendors (A, B, C and D) are used for testing. The '--domain' option controls which vendor(s) are used for training: '--domain A' uses vendor A only, '--domain B' uses vendor B only, and '--domain AB' uses both vendors A and B. Testing always uses all 4 vendors.
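The vendor split can be summarised in a few lines. This is a hypothetical sketch of how a `--domain` value maps to training and test vendors, not the parsing code in train_deep.py; `parse_domain`, `build_parser` and the constants are names introduced here for illustration, and the argparse `choices` simply mirror the three values documented above.

```python
import argparse
from typing import List

ALL_VENDORS = "ABCD"    # all four M&Ms vendors are used for testing
TRAIN_VENDORS = "AB"    # only vendors A and B come with training data


def parse_domain(domain: str) -> List[str]:
    """Turn a --domain string such as 'AB' into a sorted list of training vendors."""
    vendors = sorted(set(domain.upper()))
    if not vendors or any(v not in TRAIN_VENDORS for v in vendors):
        raise ValueError(f"--domain must use vendors from {TRAIN_VENDORS!r}, got {domain!r}")
    return vendors


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="hypothetical --domain handling")
    parser.add_argument("--domain", default="AB", choices=["A", "B", "AB"])
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["--domain", "AB"])
    print("train on:", parse_domain(args.domain), "test on:", list(ALL_VENDORS))
```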

UTNet

To train UTNet with the default setting, run:

python train_deep.py -m UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --reduce_size 8 --block_list 1234 --num_blocks 1,1,1,1 --domain AB --gpu 0 --aux_loss

Alternatively, use '-m UTNet_encoder' to apply transformer blocks in the encoder only. In some cases this setting is more stable than the default one.

To tune UTNet for your own task, there are several hyperparameters to consider:

'--block_list': specifies the resolution levels at which transformer blocks are applied. Each digit is a number of downsamplings, e.g. 34 applies transformer blocks to the features after 3 and 4 downsamplings. Applying transformer blocks to higher-resolution feature maps introduces much more computation.

'--num_blocks': specifies the number of transformer blocks applied at each level, e.g. --block_list 34 --num_blocks 2,4 applies 2 transformer blocks at the level after 3 downsamplings and 4 transformer blocks at the level after 4 downsamplings.

'--reduce_size': specifies the size to which features are downsampled for efficient attention. In our experiments, reduce_size 8 and 16 perform similarly, but 16 introduces more computation, so we chose 8 as the default setting. 16 might perform better in other applications.

'--aux_loss': applies deep supervision during training, which introduces some computational overhead but gives slightly better performance.
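The interplay of these three options, and why higher-resolution levels are expensive, can be checked with a little arithmetic. The sketch below rests on several assumptions: the argument formats used in the commands in this README (digits for --block_list, comma-separated integers for --num_blocks), a square input of illustrative size 256, and keys/values downsampled to a reduce_size x reduce_size grid. The function names and the length check are ours, not part of train_deep.py.

```python
from typing import Dict


def parse_blocks(block_list: str, num_blocks: str) -> Dict[int, int]:
    """Map each downsampling level in --block_list to its --num_blocks count."""
    levels = [int(c) for c in block_list]
    counts = [int(c) for c in num_blocks.split(",")] if num_blocks else []
    if len(levels) != len(counts):
        raise ValueError("--block_list and --num_blocks must have the same length")
    return dict(zip(levels, counts))


def attention_scores(input_size: int, level: int, reduce_size: int) -> Dict[str, int]:
    """Count attention-score entries for one block at one level (square input).

    After `level` 2x downsamplings the map has n = (input_size / 2**level)**2
    tokens. Full self-attention scores n * n pairs; attending to a
    reduce_size x reduce_size grid of keys scores n * reduce_size**2 pairs.
    """
    side = input_size // 2 ** level
    n = side * side
    return {"tokens": n, "full": n * n, "efficient": n * reduce_size ** 2}


if __name__ == "__main__":
    # 256 is an illustrative input size, not necessarily the one used for M&Ms.
    for level, count in parse_blocks("1234", "1,1,1,1").items():
        print(level, count, attention_scores(256, level, reduce_size=8))
```

Under these assumptions, the level-1 map (128 x 128, 16,384 tokens) needs 268,435,456 scores per block with full attention but 1,048,576 with reduce_size 8, while at level 4 (16 x 16) the figures are 65,536 and 16,384.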

Here are some recommended parameter settings:

--block_list 1234 --num_blocks 1,1,1,1

Our default and most efficient setting. Suitable for tasks with limited training data, where most errors occur at the ROI boundary and high-resolution information is important.

--block_list 1234 --num_blocks 1,1,4,8

Similar to the previous setting, but with larger model capacity since more transformer blocks are included. It needs a larger dataset for training.

--block_list 234 --num_blocks 2,4,8

Suitable for tasks with complex context, where errors occur inside the ROI. More transformer blocks can help learn higher-level relationships.
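The three recommended settings can be kept side by side and turned into train_deep.py invocations. This is a minimal sketch: the preset names are hypothetical, and every flag other than --block_list and --num_blocks is copied from the default UTNet command above.

```python
from typing import Dict, List

# The three recommended UTNet settings listed above (preset names are hypothetical).
PRESETS: Dict[str, Dict[str, str]] = {
    "default": {"block_list": "1234", "num_blocks": "1,1,1,1"},
    "larger": {"block_list": "1234", "num_blocks": "1,1,4,8"},
    "context": {"block_list": "234", "num_blocks": "2,4,8"},
}


def total_blocks(preset: Dict[str, str]) -> int:
    """Total number of transformer blocks across all levels."""
    return sum(int(n) for n in preset["num_blocks"].split(","))


def train_command(exp_name: str, data_path: str, preset: Dict[str, str]) -> List[str]:
    """Build an argv list for train_deep.py, e.g. to pass to subprocess.run."""
    return [
        "python", "train_deep.py", "-m", "UTNet", "-u", exp_name,
        "--data_path", data_path, "--reduce_size", "8",
        "--block_list", preset["block_list"], "--num_blocks", preset["num_blocks"],
        "--domain", "AB", "--gpu", "0", "--aux_loss",
    ]
```

Counting blocks shows that the second and third settings both contain 14 transformer blocks (1+1+4+8 and 2+4+8). They differ in where those blocks sit: the third setting skips the highest-resolution level entirely.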

Feel free to try other combinations of hyperparameters, such as base_chan, reduce_size, and num_blocks at each level, to trade off between capacity and efficiency for your own tasks and datasets.

TransUNet

We borrow code from the original TransUNet repo and fit it into our training framework. To use pre-trained weights, download them from the original repo. The TransUNet configuration is not parsed from the command line, so to change it you need to edit it inside train_deep.py.

python train_deep.py -m TransUNet -u EXP_NAME --data_path YOUR_OWN_PATH --gpu 0

ResNet50-UTNet

For a fair comparison with TransUNet, we implement the efficient attention proposed in UTNet in a ResNet50 backbone, which essentially appends transformer blocks after the ResNet blocks at the specified levels. ResNet50-UTNet performs slightly better than the default UTNet on the M&Ms dataset.
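The idea of appending transformer blocks after the ResNet blocks at selected levels can be pictured as a list of layer names. In the sketch below, the bottleneck counts [3, 4, 6, 3] are the standard ResNet50 stage sizes. The rule that ResNet stage i corresponds to downsampling level i is purely an illustrative assumption, and the actual level indexing in the repo may differ. `encoder_layout` is a hypothetical helper.

```python
from typing import Dict, List

# Bottleneck blocks in the four standard ResNet50 stages.
RESNET50_STAGES = [3, 4, 6, 3]


def encoder_layout(blocks: Dict[int, int]) -> List[str]:
    """List encoder layers, appending transformer blocks after the chosen stages.

    `blocks` maps a level to a number of transformer blocks, as
    --block_list / --num_blocks do. Assumes (hypothetically) that ResNet
    stage i sits at level i; an empty dict gives the plain ResNet50 encoder.
    """
    layout: List[str] = []
    for level, n_bottlenecks in enumerate(RESNET50_STAGES, start=1):
        layout.append(f"resnet{level}x{n_bottlenecks}")
        layout.extend(f"transformer_L{level}" for _ in range(blocks.get(level, 0)))
    return layout
```

Passing an empty mapping, which corresponds to `--block_list ''`, leaves only the ResNet stages. This matches the ResNet50-UNet baseline described further down.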

python train_deep.py -m ResNet_UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --reduce_size 8 --block_list 123 --num_blocks 1,1,1 --gpu 0

Similar to UTNet, this is the most efficient setting, suitable for tasks with limited training data.

--block_list 23 --num_blocks 2,4

Suitable for tasks with complex context, where errors occur inside the ROI. More transformer blocks can help learn higher-level relationships.

ResNet50-UNet

Without transformer blocks, ResNet50-UTNet reduces to ResNet50-UNet. You can use it as a baseline to measure the performance improvement from Transformer, for a fair comparison with TransUNet and our UTNet.

python train_deep.py -m ResNet_UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --block_list ''  --gpu 0

SwinUNet

Download the pre-trained model from the original repo. Because Swin Transformer's input size is tied to its window size and is hard to change after pre-training, we set our input size to 224. Without pre-training, SwinUNet performs very poorly.
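One way to see why 224 is a natural input size: assuming the common Swin-T configuration (patch size 4, window size 7, four stages with 2x patch merging between them), a 224 input gives feature maps of 56, 28, 14 and 7 tokens per side, all multiples of the window size. The hypothetical check below tests that condition. Some Swin implementations pad inputs that do not fit, so a failing size is not necessarily unusable; it just no longer tiles into whole windows.

```python
def swin_size_ok(input_size: int, patch_size: int = 4, window_size: int = 7,
                 num_stages: int = 4) -> bool:
    """Check that every Swin stage's feature map tiles evenly into windows.

    Stage 0 has input_size / patch_size tokens per side and each later stage
    halves that (patch merging). Windows of window_size x window_size tokens
    must tile every stage exactly.
    """
    if input_size % patch_size:
        return False
    side = input_size // patch_size
    for stage in range(num_stages):
        if side % window_size:
            return False
        if stage < num_stages - 1:
            if side % 2:  # patch merging needs an even side length
                return False
            side //= 2
    return True
```

Under these assumptions, 224 passes while 256 fails at the first stage (64 tokens per side is not a multiple of 7).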

python train_deep.py -m SwinUNet -u EXP_NAME --data_path YOUR_OWN_PATH --crop_size 224

Citation

If you find this repo helpful, please cite our paper. Thanks!

@inproceedings{gao2021utnet,
  title={UTNet: a hybrid transformer architecture for medical image segmentation},
  author={Gao, Yunhe and Zhou, Mu and Metaxas, Dimitris N},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={61--71},
  year={2021},
  organization={Springer}
}