Contextual Attention Network: Transformer Meets U-Net

Overview


Contextual attention network for medical image segmentation, with state-of-the-art results on skin lesion segmentation and multiple myeloma cell segmentation. This method incorporates a transformer module into a U-Net structure to capture long-range dependencies together with rich local information. If this code helps with your research, please consider citing the following paper:

R. Azad, Moein Heidari, Yuli Wu and Dorit Merhof, "Contextual Attention Network: Transformer Meets U-Net", arXiv preprint arXiv:2203.01932, 2022.

@article{reza2022contextual,
  title={Contextual Attention Network: Transformer Meets U-Net},
  author={Azad, Reza and Heidari, Moein and Wu, Yuli and Merhof, Dorit},
  journal={arXiv preprint arXiv:2203.01932},
  year={2022}
}

Please consider starring this repository if you find it useful. Thanks!

Prerequisites

This code is implemented in Python using the PyTorch library and has been tested on Ubuntu, though it should be compatible with similar environments. The following environment and libraries are needed to run the code:

  • Python 3
  • PyTorch
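A quick way to confirm that both prerequisites are in place before running the scripts is a short check like the one below. It is a minimal sketch that is not part of the repository: the function name describe_environment is hypothetical, and reporting CUDA availability is an addition, since the requirements above do not state whether a GPU is needed.

```python
import platform
import sys

import torch


def describe_environment() -> dict:
    """Collect the interpreter and PyTorch details relevant to running this code."""
    return {
        "python": platform.python_version(),
        "torch": str(torch.__version__),
        # Assumption: a GPU is optional; this only reports whether one is usable.
        "cuda_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    # The README requires Python 3 and PyTorch; fail early otherwise.
    if sys.version_info.major != 3:
        raise SystemExit("Python 3 is required")
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```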

Run Demo

To train the model and evaluate it on each dataset, follow the steps below:
1- Download the ISIC 2018 training dataset from this link and extract both the training data and ground truth folders into dataset_isic18.
2- Run Prepare_ISIC2018.py to prepare the data and divide it into training, validation and test sets.
3- Run train_skin.py to train the model using the training and validation sets. The model is trained for 100 epochs, and the weights that perform best on the validation set are saved.
4- To compute performance measures and produce segmentation results, run evaluate_skin.py. It reports the performance measures and saves the related results in the results folder.
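Step 3 trains for 100 epochs and keeps the weights that score best on the validation set. The sketch below illustrates that general pattern in PyTorch; it is not the code in train_skin.py. The function name, the Adam optimizer, the learning rate, the binary cross-entropy loss on logits, and the use of validation loss (rather than, say, validation Dice) as the selection criterion are all assumptions.

```python
import copy

import torch
from torch import nn


def train_with_best_checkpoint(model, train_loader, val_loader, epochs=100,
                               lr=1e-3, device="cpu", checkpoint_path=None):
    """Train for `epochs` epochs and keep the weights with the lowest validation loss."""
    model.to(device)
    criterion = nn.BCEWithLogitsLoss()  # assumption: binary masks, model outputs logits
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_loss = float("inf")
    best_state = copy.deepcopy(model.state_dict())

    for _ in range(epochs):
        model.train()
        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), masks)
            loss.backward()
            optimizer.step()

        # Average the validation loss over all samples, without tracking gradients.
        model.eval()
        total, count = 0.0, 0
        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)
                total += criterion(model(images), masks).item() * images.size(0)
                count += images.size(0)
        val_loss = total / count

        # Keep (and optionally save) the weights whenever validation improves.
        if val_loss < best_loss:
            best_loss = val_loss
            best_state = copy.deepcopy(model.state_dict())
            if checkpoint_path is not None:
                torch.save(best_state, checkpoint_path)

    model.load_state_dict(best_state)
    return model, best_loss
```

Copying the state_dict in memory, rather than only writing it to disk, means the returned model always holds the best weights even when no checkpoint path is given.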

Note: For training and evaluating on ISIC 2017 and PH2, follow the steps below:

ISIC 2017- Download the ISIC 2017 training dataset from this link and extract both the training data and ground truth folders into dataset_isic18\7.
Then run Prepare_ISIC2017.py to prepare the data and divide it into training, validation and test sets.
PH2- Download the PH2 dataset from this link, extract it, then run Prepare_ph2.py to prepare the data and divide it into training, validation and test sets.
Then follow steps 3 and 4 for model training and performance evaluation. For the PH2 dataset, you first need to train the model on the ISIC 2017 dataset and then fine-tune the trained model on the PH2 dataset.
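The PH2 procedure (train on ISIC 2017, then fine-tune on PH2) amounts to initializing the model from the ISIC 2017 weights before training on PH2. A minimal sketch of that initialization follows; the function name prepare_for_finetuning, the Adam optimizer and the reduced learning rate are assumptions, and the repository's scripts may handle this differently.

```python
import torch
from torch import nn


def prepare_for_finetuning(model: nn.Module, pretrained, lr: float = 1e-4):
    """Load ISIC 2017 weights (path or file-like object) and build an optimizer for PH2."""
    state = torch.load(pretrained, map_location="cpu")
    model.load_state_dict(state)
    # A smaller learning rate than in the initial run is a common fine-tuning choice;
    # this value is an assumption, not the repository's setting.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    return model, optimizer
```

The returned model and optimizer can then be passed to the same training loop used for ISIC 2017, with the PH2 data loaders.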

Quick Overview

Diagram of the proposed method

Perceptual visualization of the proposed Contextual Attention module.
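The method combines the long-range dependencies modeled by a transformer with the local information captured by a U-Net. The following is a minimal, hypothetical PyTorch sketch of that general idea, not the paper's Contextual Attention module: a 3x3 convolution branch for local context, a multi-head self-attention branch over all pixel positions for global context, and a 1x1 convolution that fuses the two. The class name LocalGlobalBlock and its layout are assumptions for illustration only.

```python
import torch
from torch import nn


class LocalGlobalBlock(nn.Module):
    """Hypothetical block fusing local convolutional features with global self-attention."""

    def __init__(self, channels: int, num_heads: int = 4):
        super().__init__()
        # Local branch: a 3x3 convolution sees only a small neighbourhood.
        self.local = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )
        # Global branch: self-attention over every spatial position
        # (channels must be divisible by num_heads).
        self.norm = nn.LayerNorm(channels)
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)
        # 1x1 convolution merging both branches back to `channels`.
        self.fuse = nn.Conv2d(2 * channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        local = self.local(x)
        tokens = self.norm(x.flatten(2).transpose(1, 2))  # (B, H*W, C)
        global_feats, _ = self.attn(tokens, tokens, tokens)
        global_feats = global_feats.transpose(1, 2).reshape(b, c, h, w)
        return self.fuse(torch.cat([local, global_feats], dim=1))
```

Because self-attention over H*W tokens needs memory quadratic in the number of pixels, a block like this is usually placed on low-resolution feature maps such as the U-Net bottleneck.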


Results

To evaluate the performance of the proposed method, two challenging medical image segmentation tasks are considered. The results of the proposed approach are presented below.

Task 1: Skin Lesion Segmentation

Performance Comparison on Skin Lesion Segmentation

To compare the proposed method with state-of-the-art approaches on skin lesion segmentation, we use the ISIC 2017 dataset.

Method (on ISIC 2017)             Dice score  Sensitivity  Specificity  Accuracy
U-Net (Ronneberger et al.)        0.8159      0.8172       0.9680       0.9164
Attention U-Net (Oktay et al.)    0.8082      0.7998       0.9776       0.9145
DAGAN (Lei et al.)                0.8425      0.8363       0.9716       0.9304
TransUNet (Chen et al.)           0.8123      0.8263       0.9577       0.9207
MCGU-Net (Asadi et al.)           0.8927      0.8502       0.9855       0.9570
MedT (Valanarasu et al.)          0.8037      0.8064       0.9546       0.9090
FAT-Net (Wu et al.)               0.8500      0.8392       0.9725       0.9326
TMUnet, proposed (Azad et al.)    0.9164      0.9128       0.9789       0.9660
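The four measures in the table follow their usual pixel-wise definitions: Dice = 2TP/(2TP+FP+FN), sensitivity = TP/(TP+FN), specificity = TN/(TN+FP), and accuracy = (TP+TN)/(TP+TN+FP+FN). A minimal PyTorch sketch computing them from binary masks follows. Pooling the counts over all pixels, the small eps guard against division by zero, and the function name are assumptions; evaluate_skin.py may aggregate differently (for example per image, or after thresholding predicted probabilities).

```python
import torch


def binary_segmentation_metrics(pred: torch.Tensor, target: torch.Tensor,
                                eps: float = 1e-7) -> dict:
    """Pixel-wise Dice, sensitivity, specificity and accuracy for binary masks."""
    pred = pred.bool().flatten()
    target = target.bool().flatten()
    # Confusion-matrix counts over all pixels.
    tp = (pred & target).sum().item()
    tn = (~pred & ~target).sum().item()
    fp = (pred & ~target).sum().item()
    fn = (~pred & target).sum().item()
    return {
        "dice": 2 * tp / (2 * tp + fp + fn + eps),
        "sensitivity": tp / (tp + fn + eps),
        "specificity": tn / (tn + fp + eps),
        "accuracy": (tp + tn) / (tp + tn + fp + fn),
    }
```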

For more results on the ISIC 2018 and PH2 datasets, please refer to the paper.

Skin lesion segmentation results on test data

Skin lesion segmentation results: (a) Input images. (b) Ground truth. (c) U-Net. (d) Gated Axial-Attention. (e) Proposed method without the contextual attention module. (f) Proposed method.

Task 2: Multiple Myeloma Cell Segmentation

Performance Evaluation on the Multiple Myeloma Cell Segmentation Task

Method                            mIoU
Frequency recalibration U-Net     0.9392
XLAB Insights                     0.9360
DSC-IITISM                        0.9356
Multi-scale attention DeepLabv3+  0.9065
U-Net                             0.7665
Baseline                          0.9172
Proposed                          0.9395
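Mean intersection-over-union (mIoU) averages, over classes, the ratio |prediction ∩ ground truth| / |prediction ∪ ground truth|. The following is a minimal semantic (per-class) sketch for label maps. The function name is hypothetical, and the evaluation protocol behind the numbers above may compute mIoU differently (for example per cell instance), so treat this as illustrative only.

```python
import torch


def mean_iou(pred: torch.Tensor, target: torch.Tensor, num_classes: int = 2) -> float:
    """Mean IoU over the classes that appear in the prediction or the target."""
    ious = []
    for cls in range(num_classes):
        p = pred == cls
        t = target == cls
        union = (p | t).sum().item()
        if union == 0:
            continue  # class absent from both masks: skip rather than score it
        ious.append((p & t).sum().item() / union)
    return sum(ious) / len(ious)
```

Skipping classes absent from both masks avoids rewarding or penalizing a class that simply does not occur in an image.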

Multiple Myeloma Cell Segmentation results


Model weights

You can download the learned weights for each dataset in the following table.

Dataset     Learned weights
ISIC 2018   TMUnet
ISIC 2017   TMUnet
PH2         TMUnet
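Once a weight file from the table has been downloaded, it can typically be loaded with torch.load and load_state_dict. This sketch is hedged: the format of the released files is not documented here, so it assumes a state_dict, possibly nested under a "state_dict" key or saved from nn.DataParallel with a "module." prefix on every key. If a file instead stores a whole pickled model, this approach does not apply. The function name load_pretrained is hypothetical, and the model passed in must be the same architecture used in the repository.

```python
from collections import OrderedDict

import torch
from torch import nn


def load_pretrained(model: nn.Module, checkpoint, device: str = "cpu") -> nn.Module:
    """Load a state_dict (path or file-like object), tolerating a 'module.' prefix."""
    state = torch.load(checkpoint, map_location=device)
    # Assumption: either a bare state_dict or a dict wrapping one under "state_dict".
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    # Strip the prefix that nn.DataParallel adds to parameter names.
    cleaned = OrderedDict(
        (key[len("module."):] if key.startswith("module.") else key, value)
        for key, value in state.items()
    )
    model.load_state_dict(cleaned)
    return model.to(device).eval()
```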

Query

All implementations were done by Reza Azad and Moein Heidari. For any queries, please contact us:

rezazad68@gmail.com
moeinheidari7829@gmail.com