Official implementation of UTNet: A Hybrid Transformer Architecture for Medical Image Segmentation

Related tags

Deep LearningUTNet
Overview

UTNet (Accepted at MICCAI 2021)

Official implementation of UTNet: A Hybrid Transformer Architecture for Medical Image Segmentation

Introduction

Transformer architecture has emerged to be successful in a number of natural language processing tasks. However, its applications to medical vision remain largely unexplored. In this study, we present UTNet, a simple yet powerful hybrid Transformer architecture that integrates self-attention into a convolutional neural network for enhancing medical image segmentation. UTNet applies self-attention modules in both encoder and decoder for capturing long-range dependency at dif- ferent scales with minimal overhead. To this end, we propose an efficient self-attention mechanism along with relative position encoding that reduces the complexity of self-attention operation significantly from O(n2) to approximate O(n). A new self-attention decoder is also proposed to recover fine-grained details from the skipped connections in the encoder. Our approach addresses the dilemma that Transformer requires huge amounts of data to learn vision inductive bias. Our hybrid layer design allows the initialization of Transformer into convolutional networks without a need of pre-training. We have evaluated UTNet on the multi- label, multi-vendor cardiac magnetic resonance imaging cohort. UTNet demonstrates superior segmentation performance and robustness against the state-of-the-art approaches, holding the promise to generalize well on other medical image segmentations.

image image

Supportting models

UTNet

TransUNet

ResNet50-UTNet

ResNet50-UNet

SwinUNet

To be continue ...

Getting Started

Currently, we only support M&Ms dataset.

Prerequisites

Python >= 3.6
pytorch = 1.8.1
SimpleITK = 2.0.2
numpy = 1.19.5
einops = 0.3.2

Preprocess

Resample all data to spacing of 1.2x1.2 mm in x-y plane. We don't change the spacing of z-axis, as UTNet is a 2D network. Then put all data into 'dataset/'

Training

The M&M dataset provides data from 4 venders, where vendor AB are provided for training while ABCD for testing. The '--domain' is used to control using which vendor for training. '--domain A' for using vender A only. '--domain B' for using vender B only. '--domain AB' for using both vender A and B. For testing, all 4 venders will be used.

UTNet

For default UTNet setting, training with:

python train_deep.py -m UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --reduce_size 8 --block_list 1234 --num_blocks 1,1,1,1 --domain AB --gpu 0 --aux_loss

Or you can use '-m UTNet_encoder' to use transformer blocks in the encoder only. This setting is more stable than the default setting in some cases.

To optimize UTNet in your own task, there are several hyperparameters to tune:

'--block_list': indicates apply transformer blocks in which resolution. The number means the number of downsamplings, e.g. 3,4 means apply transformer blocks in features after 3 and 4 times downsampling. Apply transformer blocks in higher resolution feature maps will introduce much more computation.

'--num_blocks': indicates the number of transformer blocks applied in each level. e.g. block_list='3,4', num_blocks=2,4 means apply 2 transformer blocks in 3-times downsampling level and apply 4 transformer blocks in 4-time downsampling level.

'--reduce_size': indicates the size of downsampling for efficient attention. In our experiments, reduce_size 8 and 16 don't have much difference, but 16 will introduce more computation, so we choost 8 as our default setting. 16 might have better performance in other applications.

'--aux_loss': applies deep supervision in training, will introduce some computation overhead but has slightly better performance.

Here are some recomended parameter setting:

--block_list 1234 --num_blocks 1,1,1,1

Our default setting, most efficient setting. Suitable for tasks with limited training data, and most errors occur in the boundary of ROI where high resolution information is important.

--block_list 1234 --num_blocks 1,1,4,8

Similar to the previous one. The model capacity is larger as more transformer blocks are including, but needs larger dataset for training.

--block_list 234 --num_blocks 2,4,8

Suitable for tasks that has complex contexts and errors occurs inside ROI. More transformer blocks can help learn higher-level relationship.

Feel free to try other combinations of the hyperparameter like base_chan, reduce_size and num_blocks in each level etc. to trade off between capacity and efficiency to fit your own tasks and datasets.

TransUNet

We borrow code from the original TransUNet repo and fit it into our training framework. If you want to use pre-trained weight, please download from the original repo. The configuration is not parsed by command line, so if you want change the configuration of TransUNet, you need change it inside the train_deep.py.

python train_deep.py -m TransUNet -u EXP_NAME --data_path YOUR_OWN_PATH --gpu 0

ResNet50-UTNet

For fair comparison with TransUNet, we implement the efficient attention proposed in UTNet into ResNet50 backbone, which is basically append transformer blocks into specified level after ResNet blocks. ResNet50-UTNet is slightly better in performance than the default UTNet in M&M dataset.

python train_deep.py -m ResNet_UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --reduce_size 8 --block_list 123 --num_blocks 1,1,1 --gpu 0

Similar to UTNet, this is the most efficient setting, suitable for tasks with limited training data.

--block_list 23 --num_blocks 2,4

Suitable for tasks that has complex contexts and errors occurs inside ROI. More transformer blocks can help learn higher-level relationship.

ResNet50-UNet

If you don't use Transformer blocks in ResNet50-UTNet, it is actually ResNet50-UNet. So you can use this as the baseline to compare the performance improvement from Transformer for fair comparision with TransUNet and our UTNet.

python train_deep.py -m ResNet_UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --block_list ''  --gpu 0

SwinUNet

Download pre-trained model from the origin repo. As Swin-Transformer's input size is related to window size and is hard to change after pretraining, so we adapt our input size to 224. Without pre-training, SwinUNet's performance is very low.

python train_deep.py -m SwinUNet -u EXP_NAME --data_path YOUR_OWN_PATH --crop_size 224

Citation

If you find this repo helps, please kindly cite our paper, thanks!

@inproceedings{gao2021utnet,
  title={UTNet: a hybrid transformer architecture for medical image segmentation},
  author={Gao, Yunhe and Zhou, Mu and Metaxas, Dimitris N},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={61--71},
  year={2021},
  organization={Springer}
}
Pytorch implementation of the paper "Class-Balanced Loss Based on Effective Number of Samples"

Class-balanced-loss-pytorch Pytorch implementation of the paper Class-Balanced Loss Based on Effective Number of Samples presented at CVPR'19. Yin Cui

Vandit Jain 697 Dec 29, 2022
Implementation of Sequence Generative Adversarial Nets with Policy Gradient

SeqGAN Requirements: Tensorflow r1.0.1 Python 2.7 CUDA 7.5+ (For GPU) Introduction Apply Generative Adversarial Nets to generating sequences of discre

Lantao Yu 2k Dec 29, 2022
CurriculumNet: Weakly Supervised Learning from Large-Scale Web Images

CurriculumNet Introduction This repo contains related code and models from the ECCV 2018 CurriculumNet paper. CurriculumNet is a new training strategy

156 Jul 04, 2022
Bayesian regularization for functional graphical models.

BayesFGM Paper: Jiajing Niu, Andrew Brown. Bayesian regularization for functional graphical models. Requirements R version 3.6.3 and up Python 3.6 and

0 Oct 07, 2021
SpeechBrain is an open-source and all-in-one speech toolkit based on PyTorch.

The SpeechBrain Toolkit SpeechBrain is an open-source and all-in-one speech toolkit based on PyTorch. The goal is to create a single, flexible, and us

SpeechBrain 5.1k Jan 02, 2023
2021搜狐校园文本匹配算法大赛 分比我们低的都是帅哥队

sohu_text_matching 2021搜狐校园文本匹配算法大赛Top2:分比我们低的都是帅哥队 本repo包含了本次大赛决赛环节提交的代码文件及答辩PPT,提交的模型文件可在百度网盘获取(链接:https://pan.baidu.com/s/1T9FtwiGFZhuC8qqwXKZSNA ,

hflserdaniel 43 Oct 01, 2022
Translation-equivariant Image Quantizer for Bi-directional Image-Text Generation

Translation-equivariant Image Quantizer for Bi-directional Image-Text Generation Woncheol Shin1, Gyubok Lee1, Jiyoung Lee1, Joonseok Lee2,3, Edward Ch

Woncheol Shin 7 Sep 26, 2022
High level network definitions with pre-trained weights in TensorFlow

TensorNets High level network definitions with pre-trained weights in TensorFlow (tested with 2.1.0 = TF = 1.4.0). Guiding principles Applicability.

Taehoon Lee 1k Dec 13, 2022
An Agnostic Computer Vision Framework - Pluggable to any Training Library: Fastai, Pytorch-Lightning with more to come

IceVision is the first agnostic computer vision framework to offer a curated collection with hundreds of high-quality pre-trained models from torchvision, MMLabs, and soon Pytorch Image Models. It or

airctic 789 Dec 29, 2022
Self-Supervised Image Denoising via Iterative Data Refinement

Self-Supervised Image Denoising via Iterative Data Refinement Yi Zhang1, Dasong Li1, Ka Lung Law2, Xiaogang Wang1, Hongwei Qin2, Hongsheng Li1 1CUHK-S

Zhang Yi 72 Jan 01, 2023
BOOKSUM: A Collection of Datasets for Long-form Narrative Summarization

BOOKSUM: A Collection of Datasets for Long-form Narrative Summarization Authors: Wojciech Kryściński, Nazneen Rajani, Divyansh Agarwal, Caiming Xiong,

Salesforce 125 Dec 31, 2022
Amazing-Python-Scripts - 🚀 Curated collection of Amazing Python scripts from Basics to Advance with automation task scripts.

📑 Introduction A curated collection of Amazing Python scripts from Basics to Advance with automation task scripts. This is your Personal space to fin

Avinash Ranjan 1.1k Dec 29, 2022
A PyTorch Implementation of "SINE: Scalable Incomplete Network Embedding" (ICDM 2018).

Scalable Incomplete Network Embedding ⠀⠀ A PyTorch implementation of Scalable Incomplete Network Embedding (ICDM 2018). Abstract Attributed network em

Benedek Rozemberczki 69 Sep 22, 2022
Organseg dags - The repository contains the codebase for multi-organ segmentation with directed acyclic graphs (DAGs) in CT.

Organseg dags - The repository contains the codebase for multi-organ segmentation with directed acyclic graphs (DAGs) in CT.

yzf 1 Jun 12, 2022
Provided is code that demonstrates the training and evaluation of the work presented in the paper: "On the Detection of Digital Face Manipulation" published in CVPR 2020.

FFD Source Code Provided is code that demonstrates the training and evaluation of the work presented in the paper: "On the Detection of Digital Face M

88 Nov 22, 2022
Deep Learning Tutorial for Kaggle Ultrasound Nerve Segmentation competition, using Keras

Deep Learning Tutorial for Kaggle Ultrasound Nerve Segmentation competition, using Keras This tutorial shows how to use Keras library to build deep ne

Marko Jocić 922 Dec 19, 2022
🎓Automatically Update CV Papers Daily using Github Actions (Update at 12:00 UTC Every Day)

🎓Automatically Update CV Papers Daily using Github Actions (Update at 12:00 UTC Every Day)

Realcat 270 Jan 07, 2023
Differentiable Wavetable Synthesis

Differentiable Wavetable Synthesis

4 Feb 11, 2022
Python package for downloading ECMWF reanalysis data and converting it into a time series format.

ecmwf_models Readers and converters for data from the ECMWF reanalysis models. Written in Python. Works great in combination with pytesmo. Citation If

TU Wien - Department of Geodesy and Geoinformation 31 Dec 26, 2022
Predict stock movement with Machine Learning and Deep Learning algorithms

Project Overview Stock market movement prediction using LSTM Deep Neural Networks and machine learning algorithms Software and Library Requirements Th

Naz Delam 46 Sep 13, 2022