Implementation of the SUMO (Slim U-Net trained on MODA) model

SUMO - Slim U-Net trained on MODA

Implementation of the SUMO (Slim U-Net trained on MODA) model as described in:

TODO: add reference to paper once available

Installation Guide

On Linux with Anaconda or Miniconda installed, run the following commands to clone the repository, create a new environment and install the required dependencies:

git clone https://github.com/dslaborg/sumo.git
cd sumo
conda env create --file environment.yaml
conda activate sumo

Scripts - Quick Guide

Running and evaluating an experiment

The model training and evaluation procedures are implemented in bin/train.py and bin/eval.py using the PyTorch Lightning framework. A configuration chosen to train the model is called an experiment; evaluation is carried out using a configuration together with the result folder of a training run.

train.py

Trains the model as specified in the corresponding configuration file, writes its log to the console, and saves a log file, intermediate results for TensorBoard, and model checkpoints to a result directory.

Arguments:

  • -e NAME, --experiment NAME: name of the experiment to run; a file NAME.yaml must exist in the config directory; default is default
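To make the interface concrete, here is a minimal sketch of how the documented -e/--experiment option could be parsed with Python's argparse and mapped to config/NAME.yaml. The function names parse_train_args and experiment_config_path are hypothetical; this is not the code in bin/train.py.

```python
import argparse
from pathlib import Path


def parse_train_args(argv=None):
    """Parse the documented train.py options (hypothetical sketch)."""
    parser = argparse.ArgumentParser(description="Train the SUMO model")
    parser.add_argument(
        "-e", "--experiment", default="default",
        help="name of the experiment; config/NAME.yaml must exist",
    )
    return parser.parse_args(argv)


def experiment_config_path(name, config_dir="config"):
    """Map an experiment name to its YAML file and fail early if it is missing."""
    path = Path(config_dir) / f"{name}.yaml"
    if not path.is_file():
        raise FileNotFoundError(f"no configuration file {path} for experiment '{name}'")
    return path
```

For example, parse_train_args(["-e", "final"]).experiment yields "final", which experiment_config_path turns into config/final.yaml when run from the project root.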

eval.py

Evaluates a trained model on either the validation data or the test data and reports the achieved metrics.

Arguments:

  • -e NAME, --experiment NAME: name of the configuration used for evaluation; a file NAME.yaml must exist in the config directory; usually the same experiment that was used to train the model; default is default
  • -i PATH, --input PATH: path to the model that should be evaluated; the input can either be a model checkpoint, which is then used directly, or the output directory of a train.py run, in which case the best model from PATH/models/ is used; if cross validation is enabled in the configuration, the output directory is expected and the best model of each fold is taken from PATH/fold_*/models/; no default value
  • -t, --test: if given, the test data is used instead of the validation data
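The rules for the -i/--input argument can be summarized in a short sketch. It only resolves where eval.py would look for models; how the best checkpoint inside a models/ directory is selected is not described here, so it is left out. The function name resolve_model_locations is hypothetical, and treating any existing file as a checkpoint is an assumption.

```python
from pathlib import Path


def resolve_model_locations(input_path, cross_validation=False):
    """Return the checkpoint itself or the directories holding the best model(s).

    Hypothetical sketch: selecting the best checkpoint within a models/
    directory is not covered.
    """
    path = Path(input_path)
    if path.is_file():
        if cross_validation:
            raise ValueError("cross validation expects the output directory of train.py")
        return [path]  # a checkpoint file is used directly
    if cross_validation:
        # one models/ directory per fold, e.g. PATH/fold_0/models
        fold_dirs = sorted(path.glob("fold_*/models"))
        if not fold_dirs:
            raise FileNotFoundError(f"no fold_*/models/ directories below {path}")
        return fold_dirs
    models_dir = path / "models"
    if not models_dir.is_dir():
        raise FileNotFoundError(f"{models_dir} does not exist")
    return [models_dir]
```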

Further example scripts

In addition to the scripts used to create the figures in our manuscript (spindle_analysis.py, spindle_analysis_correlations.py and spindle_detection_example.py), the scripts directory contains two scripts that demonstrate how to use this project.

create_data_splits.py

Demonstrates the procedure used to split the data into test and non-test subjects, followed by the creation of a hold-out validation set or, alternatively, cross validation folds.

Arguments:

  • -i PATH, --input PATH: path containing the (necessary) input data, as produced by the MODA file MODA02_genEEGVectBlock.m; relative paths are resolved from the scripts directory; default is ../input/
  • -o PATH, --output PATH: path in which the generated data splits are stored; relative paths are resolved from the scripts directory; default is ../output/datasets_{datetime}
  • -n NUMBER, --n_datasets NUMBER: number of random split candidates that are drawn; default is 25
  • -t FRACTION, --test FRACTION: proportion of the data used as test data, with 0<=FRACTION<=1; default is 0.2
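A rough sketch of the idea behind these arguments: draw n_datasets random subject-level splits with the requested test fraction, then distribute the non-test subjects over cross validation folds named fold_0, ..., fold_{k-1} (the naming the configuration expects). How the script picks one split among the candidates is not specified in this README and is not reproduced; the helper names and the seeding are assumptions.

```python
import random


def draw_split_candidates(subject_ids, n_datasets=25, test_fraction=0.2, seed=0):
    """Draw n random test/non-test splits at the subject level (sketch)."""
    if not 0 <= test_fraction <= 1:
        raise ValueError("test_fraction must lie in [0, 1]")
    rng = random.Random(seed)
    ids = sorted(subject_ids)
    n_test = round(test_fraction * len(ids))
    candidates = []
    for _ in range(n_datasets):
        shuffled = ids[:]
        rng.shuffle(shuffled)
        candidates.append({"test": sorted(shuffled[:n_test]),
                           "non_test": sorted(shuffled[n_test:])})
    return candidates


def make_folds(subject_ids, k, seed=0):
    """Distribute non-test subjects round-robin over k folds after shuffling (sketch)."""
    rng = random.Random(seed)
    ids = sorted(subject_ids)
    rng.shuffle(ids)
    return {f"fold_{i}": sorted(ids[i::k]) for i in range(k)}
```

Splitting by subject rather than by recording segment keeps all data of one subject on the same side of the split.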

predict_plain_data.py

Demonstrates how to predict spindles with a trained SUMO model on arbitrary EEG data, which is expected as a dict whose keys are the EEG channel names and whose values are the corresponding data vectors.

Arguments:

  • -d PATH, --data_path PATH: path to the input data, either in .pickle or .npy format, as a dict with the channel name as key and the EEG data as value; relative paths are resolved from the scripts directory; no default value
  • -m PATH, --model_path PATH: path to the model checkpoint used to predict spindles; relative paths are resolved from the scripts directory; default is ../output/final.ckpt
  • -g NUMBER, --gpus NUMBER: number of GPUs to use; if 0 is given, computations run on the CPU; default is 0
  • -sr RATE, --sample_rate RATE: sampling rate of the provided data; default is 100.0
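A sketch of preparing input for this script: the EEG is stored as a dict mapping channel names to 1-D signals, in either .npy or .pickle format. The helper names and the channel name C3 used below are illustrative assumptions. Note that a dict saved with numpy.save has to be loaded with allow_pickle=True and unwrapped with .item().

```python
import pickle
from pathlib import Path

import numpy as np


def save_eeg_dict(eeg, path):
    """Store {channel name: 1-D EEG vector} as .npy or .pickle (hypothetical helper)."""
    path = Path(path)
    eeg = {name: np.asarray(signal, dtype=np.float64) for name, signal in eeg.items()}
    if path.suffix == ".npy":
        np.save(path, eeg)  # stored as a 0-d object array wrapping the dict
    elif path.suffix == ".pickle":
        with path.open("wb") as f:
            pickle.dump(eeg, f)
    else:
        raise ValueError("use a .npy or .pickle file")
    return path


def load_eeg_dict(path):
    """Load a dict written by save_eeg_dict."""
    path = Path(path)
    if path.suffix == ".npy":
        return np.load(path, allow_pickle=True).item()
    with path.open("rb") as f:
        return pickle.load(f)
```

For example, save_eeg_dict({"C3": np.random.randn(30 * 100)}, "my_eeg.npy") writes 30 s of (random) signal at the default sampling rate of 100, which could then be passed to the script with -d (relative to the scripts directory) and -sr 100.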

Project Setup

The project is set up as follows:

  • bin/: contains the train.py and eval.py scripts, which are used for model training and subsequent evaluation in experiments (as configured in the config directory) using the PyTorch Lightning framework
  • config/: contains the configurations of the experiments, configuring how to train or evaluate the model
    • default.yaml: provides a sensible default configuration
    • final.yaml: contains the configuration used to train the final model checkpoint (output/final.ckpt)
    • predict.yaml: configuration that can be used to predict spindles on arbitrary data, e.g. by using the script at scripts/predict_plain_data.py
  • input/: should contain the input files, e.g. the EEG data and annotated spindles as produced by the MODA repository and transformed as demonstrated in scripts/create_data_splits.py
  • output/: contains generated output by any experiment runs or scripts, e.g. the created figures
    • final.ckpt: the final model checkpoint, on which the test data performance, as reported in the paper, was obtained
  • scripts/: various scripts used to create the plots of our paper and to demonstrate the usage of this project
    • a7/: Python implementation of the A7 algorithm as described in:
      Karine Lacourse, Jacques Delfrate, Julien Beaudry, Paul E. Peppard and Simon C. Warby. "A sleep spindle detection algorithm that emulates human expert spindle scoring." Journal of Neuroscience Methods 316 (2019): 3-11.
      
    • create_data_splits.py: demonstrates how the data set splits were obtained, including the evaluation with the A7 algorithm
    • predict_plain_data.py: demonstrates the prediction of spindles on arbitrary EEG data, using a trained model checkpoint
    • spindle_analysis.py, spindle_analysis_correlations.py, spindle_detection_example.py: scripts used to create some of the figures used in our paper
  • sumo/: the implementation of the SUMO model and the classes and functions it uses; see the docstrings for more information

Configuration Parameters

The configuration of an experiment is implemented using YAML configuration files. These files must be placed in the config directory, and their name must match the name passed as --experiment to the eval.py or train.py script. The default.yaml file is always loaded first as the set of default parameters; parameters specified in an additional file overwrite these defaults. Any parameter or group of parameters that should be None has to be set to either null or Null, following the YAML specification.
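The override behaviour can be sketched as a recursive dictionary merge. Whether the repository merges nested groups key by key, as assumed here, or replaces whole groups is not stated above; the example also assumes both YAML files have already been parsed into dicts (for instance with PyYAML's yaml.safe_load, which maps null and Null to None).

```python
import copy


def merge_config(default, override):
    """Return a copy of default in which every value from override wins.

    Nested groups (dicts) are merged key by key (an assumption); any other
    value, including None from a YAML null, replaces the default outright.
    """
    merged = copy.deepcopy(default)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


default = {
    "data": {"batch_size": 12, "preprocessing": True},
    "experiment": {"train": {"n_epochs": 800, "early_stopping": 300}},
}
experiment = {"experiment": {"train": {"n_epochs": 100}}, "data": None}
config = merge_config(default, experiment)
# config["experiment"]["train"] == {"n_epochs": 100, "early_stopping": 300}
# config["data"] is None
```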

The available parameters are as follows:

  • data: configuration of the input data; optional, can be None if spindles should be annotated on arbitrary EEG data
    • directory and file_name: the input file containing the Subject objects (see scripts/create_data_splits.py) is expected at ${directory}/${file_name}, where relative paths are resolved from the project root directory; the file should be a (pickled) dict with the name of a data set as key and the list of corresponding subjects as value; default is input/subjects.pickle
    • split: the keys of the data sets to use, specifying either train and validation, or cross_validation, and optionally test
      • cross_validation: either an integer k>=2, in which case the keys fold_0, ..., fold_{k-1} are expected to exist, or a list of keys
    • batch_size: size of the mini-batches used during training; default is 12
    • preprocessing: whether z-scoring is applied to the EEG data; default is True
  • experiment: definition of the performed experiment; mandatory
    • model: definition of the model configuration; mandatory
      • n_classes: number of output classes; default is 2
      • activation: name of an activation function as defined in torch.nn package; default is ReLU
      • depth: number of layers of the U excluding the last layer; default is 2
      • channel_size: number of filters of the convolutions in the first layer; default is 16
      • pools: list containing the sizes of the pooling and upsampling operations; must contain as many values as depth; default is [4, 4]
      • convolution_params: parameters used by the Conv1d modules
      • moving_avg_size: width of the moving average filter; default is 42
    • train: configuration used in training the model; mandatory
      • n_epochs: maximum number of epochs to run before training stops; default is 800
      • early_stopping: number of epochs without any improvement in the val_f1_mean metric, after which training is stopped; default is 300
      • optimizer: configuration of an optimizer as defined in torch.optim package; contains class_name (default is Adam) and parameters, which are passed to the constructor of the used optimizer class
      • lr_scheduler: used learning rate scheduler; optional, default is None
      • loss: configuration of loss function as defined either in sumo.loss package (GeneralizedDiceLoss) or torch.nn package; contains class_name (default is GeneralizedDiceLoss) and parameters, which are passed to the constructor of the used loss class
    • validation: configuration used in evaluating the model; mandatory
      • overlap_threshold_step: step size of the overlap thresholds used to calculate (validation) F1 scores
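A few of the rules above can be checked before training starts. The sketch below uses hypothetical helper names and the defaults listed above; splitting the default input/subjects.pickle into directory input and file_name subjects.pickle is an assumption. It expands cross_validation into data set keys, checks that pools has one entry per level given by depth, and locates the subjects file relative to the project root.

```python
from pathlib import Path


def cross_validation_keys(cross_validation):
    """Expand data.split.cross_validation into the list of data set keys."""
    # bool is a subclass of int in Python, so reject it explicitly
    if isinstance(cross_validation, bool):
        raise TypeError("cross_validation must be an integer k >= 2 or a list of keys")
    if isinstance(cross_validation, int):
        if cross_validation < 2:
            raise ValueError("cross_validation must be at least 2")
        return [f"fold_{i}" for i in range(cross_validation)]
    return list(cross_validation)


def check_pools(model_cfg):
    """Require one pooling/upsampling size per level of the U (len(pools) == depth)."""
    depth = model_cfg.get("depth", 2)
    pools = model_cfg.get("pools", [4, 4])
    if len(pools) != depth:
        raise ValueError(f"pools has {len(pools)} entries but depth is {depth}")
    return pools


def subjects_file(data_cfg, project_root="."):
    """Resolve ${directory}/${file_name}; relative directories start at the project root."""
    directory = Path(data_cfg.get("directory", "input"))
    if not directory.is_absolute():
        directory = Path(project_root) / directory
    return directory / data_cfg.get("file_name", "subjects.pickle")
```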