PyTorch Implementation of Residual Vision Transformers (ResViT)

ResViT

Official PyTorch implementation of Residual Vision Transformers (ResViT), described in the following paper:

Onat Dalmaz, Mahmut Yurt, and Tolga Çukur. ResViT: Residual vision transformers for multi-modal medical image synthesis. arXiv, 2021.

Dependencies

python>=3.6.9
torch>=1.7.1
torchvision>=0.8.2
visdom
dominate
cuda>=11.2
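A quick way to confirm an environment meets the minimum versions listed above is sketched below. This is not part of the repository; the helper names `parse_version` and `meets` are hypothetical, and the parser simply ignores local suffixes such as `+cu110`.

```python
import re
import sys

# Minimum versions copied from the Dependencies list above.
MINIMUMS = {"python": (3, 6, 9), "torch": (1, 7, 1), "torchvision": (0, 8, 2)}


def parse_version(text):
    """Turn '1.7.1+cu110' into (1, 7, 1); trailing suffixes are ignored."""
    match = re.match(r"(\d+(?:\.\d+)*)", text)
    if match is None:
        raise ValueError(f"unparseable version: {text!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def meets(text, minimum):
    """True if the version string is at least the minimum tuple."""
    return parse_version(text) >= minimum


if __name__ == "__main__":
    print("python ok:", tuple(sys.version_info[:3]) >= MINIMUMS["python"])
    try:
        import torch
        import torchvision
    except ImportError as exc:
        print("missing dependency:", exc.name)
    else:
        print("torch ok:", meets(torch.__version__, MINIMUMS["torch"]))
        print("torchvision ok:", meets(torchvision.__version__, MINIMUMS["torchvision"]))
        # torch.version.cuda is the CUDA version torch was built against.
        print("CUDA build:", torch.version.cuda, "available:", torch.cuda.is_available())
```

Comparing integer tuples rather than raw strings avoids the classic mistake of treating "1.10" as older than "1.7".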

Installation

  • Clone this repo:
git clone https://github.com/icon-lab/ResViT
cd ResViT

Download pre-trained ViT models from Google:

wget https://storage.googleapis.com/vit_models/imagenet21k/R50-ViT-B_16.npz &&
mkdir -p ../model/vit_checkpoint/imagenet21k &&
mv R50-ViT-B_16.npz ../model/vit_checkpoint/imagenet21k/R50-ViT-B_16.npz
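To check that the checkpoint landed where the commands above put it, a minimal sketch (assuming NumPy is installed; `summarize_npz` is a hypothetical helper, not part of ResViT) lists the arrays stored in the `.npz` archive:

```python
import os

import numpy as np


def summarize_npz(path):
    """Return {array_name: shape} for every array stored in an .npz archive."""
    # np.load on an .npz returns an NpzFile; the context manager closes it.
    with np.load(path) as archive:
        return {name: archive[name].shape for name in archive.files}


if __name__ == "__main__":
    ckpt = "../model/vit_checkpoint/imagenet21k/R50-ViT-B_16.npz"
    if os.path.exists(ckpt):
        shapes = summarize_npz(ckpt)
        print(f"{len(shapes)} arrays in {ckpt}")
        for name in sorted(shapes)[:5]:
            print(name, shapes[name])
    else:
        print("checkpoint not found:", ckpt)
```

An empty or unreadable archive usually means the download was interrupted.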

Dataset

You should structure your aligned dataset in the following way:

/Datasets/BRATS/
  ├── T1_T2
  ├── T2_FLAIR
  .
  .
  ├── T1_FLAIR__T2
/Datasets/BRATS/T2_FLAIR/
  ├── train
  ├── val
  ├── test
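Before launching a long run it can help to verify that every task folder contains the three split directories shown in the tree above. The sketch below assumes only that layout; `missing_splits` is a hypothetical helper and does not inspect the image files themselves.

```python
from pathlib import Path

# Split folders expected inside every task directory (from the tree above).
SPLITS = ("train", "val", "test")


def missing_splits(dataset_root):
    """Map each task folder under dataset_root to the splits it lacks."""
    root = Path(dataset_root)
    report = {}
    for task_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        missing = [s for s in SPLITS if not (task_dir / s).is_dir()]
        if missing:
            report[task_dir.name] = missing
    return report


if __name__ == "__main__":
    root = Path("/Datasets/BRATS")
    if root.is_dir():
        problems = missing_splits(root)
        print(problems or "every task folder has train/val/test")
```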

Note that for many-to-one tasks with two input modalities, the source modalities should be placed in the red and green channels of the input image.
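The channel convention above implies a three-channel input, which matches `--input_nc 3` in the many-to-one commands. A minimal NumPy sketch of the packing follows. The README does not say what the blue channel should hold, so filling it with a constant is an assumption of this sketch, and `stack_sources` is a hypothetical name. How the packed source is paired with the target image on disk is not covered here.

```python
import numpy as np


def stack_sources(src_r, src_g, fill=0):
    """Pack two single-channel source slices into one H x W x 3 array.

    Source 1 goes to red (channel 0), source 2 to green (channel 1).
    The blue channel is filled with a constant: an ASSUMPTION, since the
    README only specifies the red and green channels.
    """
    src_r = np.asarray(src_r)
    src_g = np.asarray(src_g)
    if src_r.ndim != 2 or src_r.shape != src_g.shape:
        raise ValueError("expected two 2-D slices of identical shape")
    blue = np.full_like(src_r, fill)
    return np.stack([src_r, src_g, blue], axis=-1)


if __name__ == "__main__":
    t1 = np.random.rand(256, 256)  # stand-in for a T1 slice
    t2 = np.random.rand(256, 256)  # stand-in for a T2 slice
    print(stack_sources(t1, t2).shape)  # (256, 256, 3)
```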

Pre-training ART blocks without transformers

For many-to-one tasks:
python3 train.py --dataroot Datasets/IXI/T1_T2__PD/ --name T1_T2_PD_IXI_pre_trained --gpu_ids 0 --model resvit_many --which_model_netG res_cnn --which_direction AtoB --lambda_A 100 --dataset_mode aligned --norm batch --pool_size 0 --output_nc 1 --input_nc 3 --loadSize 256 --fineSize 256 --niter 50 --niter_decay 50 --save_epoch_freq 5 --checkpoints_dir checkpoints/ --display_id 0

For one-to-one tasks:
python3 train.py --dataroot Datasets/IXI/T1_T2/ --name T1_T2_IXI_pre_trained --gpu_ids 0 --model resvit_one --which_model_netG res_cnn --which_direction AtoB --lambda_A 100 --dataset_mode aligned --norm batch --pool_size 0 --output_nc 1 --input_nc 1 --loadSize 256 --fineSize 256 --niter 50 --niter_decay 50 --save_epoch_freq 5 --checkpoints_dir checkpoints/ --display_id 0
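The two pre-training commands share every flag except `--dataroot`, `--name`, `--model` and `--input_nc`. The sketch below makes that explicit by building the argument list from a shared block; `pretrain_command` is a hypothetical wrapper, and the flag values are copied verbatim from the commands above.

```python
import shlex

# Flags shared by both pre-training commands in this README.
COMMON = [
    "--gpu_ids", "0", "--which_model_netG", "res_cnn", "--which_direction", "AtoB",
    "--lambda_A", "100", "--dataset_mode", "aligned", "--norm", "batch",
    "--pool_size", "0", "--output_nc", "1", "--loadSize", "256", "--fineSize", "256",
    "--niter", "50", "--niter_decay", "50", "--save_epoch_freq", "5",
    "--checkpoints_dir", "checkpoints/", "--display_id", "0",
]

# Per-task differences: model variant and number of input channels.
TASKS = {
    "many": {"model": "resvit_many", "input_nc": "3"},
    "one": {"model": "resvit_one", "input_nc": "1"},
}


def pretrain_command(task, dataroot, name):
    """Argument list for the res_cnn pre-training run of a given task type."""
    cfg = TASKS[task]
    return ["python3", "train.py", "--dataroot", dataroot, "--name", name,
            "--model", cfg["model"], "--input_nc", cfg["input_nc"], *COMMON]


if __name__ == "__main__":
    cmd = pretrain_command("one", "Datasets/IXI/T1_T2/", "T1_T2_IXI_pre_trained")
    # shlex.quote keeps this compatible with Python 3.6 (shlex.join is 3.8+).
    print(" ".join(shlex.quote(arg) for arg in cmd))
    # subprocess.run(cmd, check=True) would launch it from the repo root.
```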

Fine-tuning ResViT

For many-to-one tasks:
python3 train.py --dataroot Datasets/IXI/T1_T2__PD/ --name T1_T2_PD_IXI_resvit --gpu_ids 0 --model resvit_many --which_model_netG resvit --which_direction AtoB --lambda_A 100 --dataset_mode aligned --norm batch --pool_size 0 --output_nc 1 --input_nc 3 --loadSize 256 --fineSize 256 --niter 25 --niter_decay 25 --save_epoch_freq 5 --checkpoints_dir checkpoints/ --display_id 0 --pre_trained_transformer 1 --pre_trained_resnet 1 --pre_trained_path checkpoints/T1_T2_PD_IXI_pre_trained/latest_net_G.pth --lr 0.001

For one-to-one tasks:
python3 train.py --dataroot Datasets/IXI/T1_T2/ --name T1_T2_IXI_resvit --gpu_ids 0 --model resvit_one --which_model_netG resvit --which_direction AtoB --lambda_A 100 --dataset_mode aligned --norm batch --pool_size 0 --output_nc 1 --input_nc 1 --loadSize 256 --fineSize 256 --niter 25 --niter_decay 25 --save_epoch_freq 5 --checkpoints_dir checkpoints/ --display_id 0 --pre_trained_transformer 1 --pre_trained_resnet 1 --pre_trained_path checkpoints/T1_T2_IXI_pre_trained/latest_net_G.pth --lr 0.001
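Fine-tuning depends on the generator saved by the pre-training run: `--pre_trained_path` points at `checkpoints/<pre-training name>/latest_net_G.pth`. A small sketch (hypothetical helpers, flag values taken from the commands above) that fails early if pre-training has not produced that file:

```python
from pathlib import Path


def pretrained_generator_path(checkpoints_dir, pretrain_name):
    """Path that --pre_trained_path points to in the fine-tuning commands."""
    return Path(checkpoints_dir) / pretrain_name / "latest_net_G.pth"


def finetune_overrides(checkpoints_dir, pretrain_name):
    """Flags that change between pre-training and fine-tuning in this README."""
    ckpt = pretrained_generator_path(checkpoints_dir, pretrain_name)
    if not ckpt.is_file():
        raise FileNotFoundError(f"run pre-training first; missing {ckpt}")
    return ["--which_model_netG", "resvit", "--niter", "25", "--niter_decay", "25",
            "--pre_trained_transformer", "1", "--pre_trained_resnet", "1",
            "--pre_trained_path", str(ckpt), "--lr", "0.001"]


if __name__ == "__main__":
    try:
        print(finetune_overrides("checkpoints/", "T1_T2_IXI_pre_trained"))
    except FileNotFoundError as err:
        print(err)
```

Apart from these flags (and the new `--name`), the fine-tuning commands reuse the pre-training settings unchanged.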

Testing

For many-to-one tasks:
python3 test.py --dataroot Datasets/IXI/T1_T2__PD/ --name T1_T2_PD_IXI_resvit --gpu_ids 0 --model resvit_many --which_model_netG resvit --dataset_mode aligned --norm batch --phase test --output_nc 1 --input_nc 3 --how_many 10000 --serial_batches --fineSize 256 --loadSize 256 --results_dir results/ --checkpoints_dir checkpoints/ --which_epoch latest

For one-to-one tasks:
python3 test.py --dataroot Datasets/IXI/T1_T2/ --name T1_T2_IXI_resvit --gpu_ids 0 --model resvit_one --which_model_netG resvit --dataset_mode aligned --norm batch --phase test --output_nc 1 --input_nc 1 --how_many 10000 --serial_batches --fineSize 256 --loadSize 256 --results_dir results/ --checkpoints_dir checkpoints/ --which_epoch latest
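Since the code builds on pix2pix, test outputs are likely written in the pix2pix style under `--results_dir`. The folder pattern `<results_dir>/<name>/<phase>_<which_epoch>/images` and the `real_A` / `fake_B` / `real_B` file suffixes used below are assumptions borrowed from pix2pix, not confirmed by this README; check them against your own results folder. `collect_results` is a hypothetical helper that groups each slice's source, synthesized and ground-truth images.

```python
from collections import defaultdict
from pathlib import Path

# ASSUMED pix2pix-style visual labels; verify against your results folder.
LABELS = ("real_A", "fake_B", "real_B")


def collect_results(results_dir, name, phase="test", epoch="latest"):
    """Group saved test images by slice: {slice_id: {label: path}}."""
    image_dir = Path(results_dir) / name / f"{phase}_{epoch}" / "images"
    if not image_dir.is_dir():
        return {}
    grouped = defaultdict(dict)
    for png in sorted(image_dir.glob("*.png")):
        for label in LABELS:
            suffix = "_" + label
            if png.stem.endswith(suffix):
                grouped[png.stem[: -len(suffix)]][label] = png
                break
    return dict(grouped)


if __name__ == "__main__":
    results = collect_results("results/", "T1_T2_IXI_resvit")
    print(f"{len(results)} slices found")
```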

Citation

You are encouraged to modify/distribute this code. However, please acknowledge this code and cite the paper appropriately.

@misc{dalmaz2021resvit,
      title={ResViT: Residual vision transformers for multi-modal medical image synthesis}, 
      author={Onat Dalmaz and Mahmut Yurt and Tolga Çukur},
      year={2021},
      eprint={2106.16031},
      archivePrefix={arXiv},
      primaryClass={eess.IV}
}

For any questions, comments, and contributions, please contact Onat Dalmaz (onat[at]ee.bilkent.edu.tr).

(c) ICON Lab 2021

Acknowledgments

This code uses libraries from the pGAN and pix2pix repositories.
