Official PyTorch implementation of the MixMo framework

Overview

MixMo: Mixing Multiple Inputs for Multiple Outputs via Deep Subnetworks

Official PyTorch implementation of the MixMo framework | paper | docs

Alexandre Ramé, Rémy Sun, Matthieu Cord

Citation

If you find this code useful for your research, please cite:

@article{rame2021mixmo,
    title={MixMo: Mixing Multiple Inputs for Multiple Outputs via Deep Subnetworks},
    author={Alexandre Rame and Remy Sun and Matthieu Cord},
    year={2021},
    journal={arXiv preprint arXiv:2103.06132}
}

Abstract

Recent strategies achieved ensembling “for free” by fitting concurrently diverse subnetworks inside a single base network. The main idea during training is that each subnetwork learns to classify only one of the multiple inputs simultaneously provided. However, the question of how to best mix these multiple inputs has not been studied so far.

In this paper, we introduce MixMo, a new generalized framework for learning multi-input multi-output deep subnetworks. Our key motivation is to replace the suboptimal summing operation hidden in previous approaches by a more appropriate mixing mechanism. For that purpose, we draw inspiration from successful mixed sample data augmentations. We show that binary mixing in features - particularly with rectangular patches from CutMix - enhances results by making subnetworks stronger and more diverse.

We improve the state of the art for image classification on the CIFAR-100 and Tiny ImageNet datasets. Our easy-to-implement models notably outperform data-augmented deep ensembles, without the inference and memory overheads. As we operate in features and simply better leverage the expressiveness of large networks, we open a new line of research complementary to previous works.

Overview

Most important code sections

This repository provides a general wrapper over PyTorch to reproduce the main results from the paper. The code sections specific to MixMo can be found in:

  1. mixmo.loaders.dataset_wrapper.py and specifically MixMoDataset to create batches with multiple inputs and multiple outputs.
  2. mixmo.augmentations.mixing_blocks.py where we create the mixing masks, e.g. via linear summing (_mixup_mask) or via patch mixing (_cutmix_mask).
  3. mixmo.networks.resnet.py and mixmo.networks.wrn.py where we adapt the network structures to handle:
    • multiple inputs via multiple conv1 encoders (one per input). The function mixmo.augmentations.mixing_blocks.mix_manifold mixes the extracted representations according to the masks provided in the metadata from MixMoDataset.
    • multiple outputs via multiple predictions.

This translates to additional tensor management in mixmo.learners.learner.py.
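To make the two mixing options concrete, here is a minimal NumPy sketch of a linear mask, a CutMix-style rectangular mask and the feature-mixing step. It is not the code of mixmo.augmentations.mixing_blocks: the function names (linear_mask, cutmix_mask, mix_features, sample_mask) are hypothetical, the rectangle is kept fully inside the feature map for simplicity, and the factor 2 in mix_features is our assumption that the mixed features keep the scale of MIMO's summed encodings. The defaults alpha=2 and p_cut=0.5 mirror mixmo_alpha=2 and the cutmixmo-p5 configs.

```python
import numpy as np


def linear_mask(kappa: float, shape: tuple) -> np.ndarray:
    """Linear (Mixup-style) mask: every position gives weight kappa to input 0."""
    return np.full(shape, kappa, dtype=np.float32)


def cutmix_mask(kappa: float, shape: tuple, rng: np.random.Generator) -> np.ndarray:
    """Binary (CutMix-style) mask: a rectangle of ones covering about kappa of
    the H x W area (input 0 inside, input 1 outside), placed at random."""
    h, w = shape
    cut_h = int(round(h * np.sqrt(kappa)))
    cut_w = int(round(w * np.sqrt(kappa)))
    # Top-left corner chosen so the rectangle stays inside the map.
    top = int(rng.integers(0, h - cut_h + 1))
    left = int(rng.integers(0, w - cut_w + 1))
    mask = np.zeros(shape, dtype=np.float32)
    mask[top:top + cut_h, left:left + cut_w] = 1.0
    return mask


def mix_features(l0: np.ndarray, l1: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Mix two (C, H, W) feature maps with an (H, W) mask (broadcast over C).

    The factor 2 is an assumption: it keeps the scale of MIMO's summed encodings."""
    return 2.0 * (mask * l0 + (1.0 - mask) * l1)


def sample_mask(shape: tuple, alpha: float = 2.0, p_cut: float = 0.5, rng=None):
    """Draw kappa ~ Beta(alpha, alpha), then use a CutMix-style mask with
    probability p_cut, else a linear mask. Returns (kappa, mask)."""
    rng = rng if rng is not None else np.random.default_rng()
    kappa = float(rng.beta(alpha, alpha))
    if rng.random() < p_cut:
        return kappa, cutmix_mask(kappa, shape, rng)
    return kappa, linear_mask(kappa, shape)
```

Linear mixing blends every feature, while patch mixing keeps each spatial location from a single input, which is the property the abstract credits for stronger and more diverse subnetworks.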

Pseudo code

Our MixMoDataset wraps a PyTorch Dataset. The batch_repetition_sampler repeats the same index b times in each batch. Moreover, we provide SoftCrossEntropyLoss, which handles the soft labels required by mixed sample data augmentations such as CutMix.

from mixmo.loaders import (dataset_wrapper, batch_repetition_sampler)
from mixmo.networks.wrn import WideResNetMixMo
from mixmo.core.loss import SoftCrossEntropyLoss as criterion

...

# cf mixmo.loaders.loader
train_dataset = dataset_wrapper.MixMoDataset(
        dataset=CIFAR100(os.path.join(dataplace, "cifar100-data")),
        num_members=2,  # we use M=2 subnetworks
        mixmo_mix_method="cutmix",  # patch mixing, linked to mixmo.augmentations.mixing_blocks._cutmix_mask
        mixmo_alpha=2,  # mixing ratio sampled from a Beta distribution with concentration 2
        mixmo_weight_root=3  # root r=3 used to reweight the loss components
        )
network = WideResNetMixMo(depth=28, widen_factor=10, num_classes=100)

...

# cf mixmo.learners.learner and mixmo.learners.model_wrapper
for _ in range(num_epochs):
    for indexes_0, indexes_1 in batch_repetition_sampler(batch_size=64, b=4, max_index=len(train_dataset)):
        for (inputs_0, inputs_1, targets_0, targets_1, metadata_mixmo_masks) in train_dataset(indexes_0, indexes_1):
            outputs_0, outputs_1 = network([inputs_0, inputs_1], metadata_mixmo_masks)
            loss = criterion(outputs_0, targets_0) + criterion(outputs_1, targets_1)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
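The loss in the pseudo code above sums one soft cross-entropy term per subnetwork. The sketch below shows one way to combine those terms with the mixmo_weight_root reweighting. It assumes the weighting w_r(kappa) = 2 kappa^(1/r) / (kappa^(1/r) + (1 - kappa)^(1/r)), where kappa is the share of the mix taken from input 0; this is our reading of the paper, so check the repository for the exact formula. All function names are hypothetical, and kappa is a single scalar here, whereas a real batch would carry one kappa per sample.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis."""
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def soft_cross_entropy(logits: np.ndarray, soft_targets: np.ndarray) -> float:
    """Cross-entropy with soft labels (e.g. CutMix targets), averaged over the batch."""
    return float(-(soft_targets * log_softmax(logits)).sum(axis=-1).mean())


def root_weight(kappa: float, root: float = 3.0) -> float:
    """Assumed reweighting w_r(kappa) = 2 k^(1/r) / (k^(1/r) + (1 - k)^(1/r))."""
    a = kappa ** (1.0 / root)
    b = (1.0 - kappa) ** (1.0 / root)
    return 2.0 * a / (a + b)


def mixmo_loss(logits_0, targets_0, logits_1, targets_1, kappa, root=3.0) -> float:
    """Weighted sum of the two members' losses; input 0 contributes kappa of the mix."""
    return (root_weight(kappa, root) * soft_cross_entropy(logits_0, targets_0)
            + root_weight(1.0 - kappa, root) * soft_cross_entropy(logits_1, targets_1))
```

Under this form the two weights always sum to 2. With r = 1 they are proportional to the mixing ratio, and as r grows they approach 1 for both members, so the root interpolates between the two regimes.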

Configuration files

Our code heavily relies on YAML config files. In the mixmo-pytorch/config folder, we provide the configs to reproduce the main paper results.

For example, the name of the state-of-the-art config exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4 encodes the following:

  • cifar100: dataset is CIFAR-100.
  • wrn2810-2: WideResNet-28-10 network architecture with M=2 subnetworks.
  • cutmixmo-p5: mixing block is patch mixing with probability p=0.5 else linear mixing.
  • msdacutmix: use CutMix mixed sample data augmentation.
  • bar4: batch repetition with b=4.
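The bar4 suffix refers to batch repetition. Below is a hypothetical sketch of a sampler in the spirit of batch_repetition_sampler; its name, signature and exact shuffling scheme are assumptions, not the repository's API. Each batch holds batch_size // b distinct samples, each repeated b times, and each subnetwork gets its own permutation of the repeated batch so that the same image is mixed with different partners.

```python
import numpy as np


def batch_repetition_indexes(max_index: int, batch_size: int = 64, b: int = 4, rng=None):
    """Yield (indexes_0, indexes_1) pairs covering one epoch.

    Each batch contains batch_size // b distinct dataset indexes, each repeated
    b times; the two members receive independent permutations of that batch.
    The last incomplete group of indexes is dropped."""
    rng = rng if rng is not None else np.random.default_rng()
    unique_per_batch = batch_size // b
    order = rng.permutation(max_index)  # shuffle the dataset once per epoch
    for start in range(0, max_index - unique_per_batch + 1, unique_per_batch):
        repeated = np.repeat(order[start:start + unique_per_batch], b)
        yield rng.permutation(repeated), rng.permutation(repeated)
```

With b=4 and batch_size=64, each batch therefore contains 16 distinct images, each seen 4 times.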

Results and available checkpoints

CIFAR-100 with WideResNet-28-10

| Subnetwork method | MSDA    | Top-1 Accuracy (%) | Config file in mixmo-pytorch/config/cifar100           |
|-------------------|---------|--------------------|--------------------------------------------------------|
| --                | Vanilla | 81.79              | exp_cifar100_wrn2810_1net_standard_bar1.yaml           |
| --                | Mixup   | 83.43              | exp_cifar100_wrn2810_1net_msdamixup_bar1.yaml          |
| --                | CutMix  | 83.95              | exp_cifar100_wrn2810_1net_msdacutmix_bar1.yaml         |
| MIMO              | --      | 82.92              | exp_cifar100_wrn2810-2_mimo_standard_bar4.yaml         |
| Linear-MixMo      | --      | 82.96              | exp_cifar100_wrn2810-2_linearmixmo_standard_bar4.yaml  |
| Cut-MixMo         | --      | 85.52 - 85.59      | exp_cifar100_wrn2810-2_cutmixmo-p5_standard_bar4.yaml  |
| Linear-MixMo      | CutMix  | 85.36 - 85.57      | exp_cifar100_wrn2810-2_linearmixmo_msdacutmix_bar4.yaml |
| Cut-MixMo         | CutMix  | 85.77 - 85.92      | exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml |

CIFAR-10 with WideResNet-28-10

| Subnetwork method | MSDA    | Top-1 Accuracy (%) | Config file in mixmo-pytorch/config/cifar10           |
|-------------------|---------|--------------------|-------------------------------------------------------|
| --                | Vanilla | 96.37              | exp_cifar10_wrn2810_1net_standard_bar1.yaml           |
| --                | Mixup   | 97.07              | exp_cifar10_wrn2810_1net_msdamixup_bar1.yaml          |
| --                | CutMix  | 97.28              | exp_cifar10_wrn2810_1net_msdacutmix_bar1.yaml         |
| MIMO              | --      | 96.71              | exp_cifar10_wrn2810-2_mimo_standard_bar4.yaml         |
| Linear-MixMo      | --      | 96.88              | exp_cifar10_wrn2810-2_linearmixmo_standard_bar4.yaml  |
| Cut-MixMo         | --      | 97.52              | exp_cifar10_wrn2810-2_cutmixmo-p5_standard_bar4.yaml  |
| Linear-MixMo      | CutMix  | 97.73              | exp_cifar10_wrn2810-2_linearmixmo_msdacutmix_bar4.yaml |
| Cut-MixMo         | CutMix  | 97.83              | exp_cifar10_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml |

Tiny ImageNet-200 with PreActResNet-18-width

| Method       | Width | Top-1 Accuracy (%) | Config file in mixmo-pytorch/config/tiny                  |
|--------------|-------|--------------------|-----------------------------------------------------------|
| Vanilla      | 1     | 62.75              | exp_tinyimagenet_res18_1net_standard_bar1.yaml            |
| Linear-MixMo | 1     | 62.91              | exp_tinyimagenet_res18-2_linearmixmo_standard_bar4.yaml   |
| Cut-MixMo    | 1     | 64.32              | exp_tinyimagenet_res18-2_cutmixmo-p5_standard_bar4.yaml   |
| Vanilla      | 2     | 64.91              | exp_tinyimagenet_res182_1net_standard_bar1.yaml           |
| Linear-MixMo | 2     | 67.03              | exp_tinyimagenet_res182-2_linearmixmo_standard_bar4.yaml  |
| Cut-MixMo    | 2     | 69.12              | exp_tinyimagenet_res182-2_cutmixmo-p5_standard_bar4.yaml  |
| Vanilla      | 3     | 65.84              | exp_tinyimagenet_res183_1net_standard_bar1.yaml           |
| Linear-MixMo | 3     | 68.36              | exp_tinyimagenet_res183-2_linearmixmo_standard_bar4.yaml  |
| Cut-MixMo    | 3     | 70.23              | exp_tinyimagenet_res183-2_cutmixmo-p5_standard_bar4.yaml  |

Installation

Requirements overview

  • python >= 3.6
  • torch >= 1.4.0
  • torchsummary >= 1.5.1
  • torchvision >= 0.5.0
  • tensorboard >= 1.14.0

Procedure

  1. Clone the repo:
$ git clone https://github.com/alexrame/mixmo-pytorch.git
  2. Create a conda environment, then install the dependencies using pip:
$ conda create --name mixmo python=3.6.10
$ conda activate mixmo
$ cd mixmo-pytorch
$ pip install -r requirements.txt

With this, you can edit the MixMo code on the fly.

Datasets

We advise first creating a dedicated data folder, dataplace, which is passed as an argument to the scripts below.

  • CIFAR

The CIFAR-10 and CIFAR-100 datasets are handled by the PyTorch dataloaders: the first time you run a script, the dataset is downloaded into the dataplace you provide.

  • Tiny-ImageNet

The Tiny-ImageNet dataset needs to be downloaded beforehand. The following procedure is adapted from the Manifold Mixup repository.

  1. Download the zipped data from https://tiny-imagenet.herokuapp.com/.
  2. Extract the zipped data in folder dataplace.
  3. Run the following script, which arranges the validation data in the format required by the PyTorch loader:
$ python scripts/script_load_tiny_data.py --dataplace $dataplace
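For reference, such a rearrangement typically moves every validation image into a sub-folder named after its class, so that an ImageFolder-style loader can read it. The sketch below assumes the standard Tiny-ImageNet layout (val/images/*.JPEG plus a tab-separated val/val_annotations.txt whose first two fields are the file name and the class id); it is not scripts/script_load_tiny_data.py, and the function name is hypothetical.

```python
import shutil
from pathlib import Path


def arrange_tiny_imagenet_val(val_dir: str) -> int:
    """Move each validation image from val_dir/images into val_dir/<class_id>/.

    Assumes val_dir/val_annotations.txt lists one image per line as
    'filename<TAB>class_id<TAB>...'. Returns the number of images moved."""
    val_path = Path(val_dir)
    moved = 0
    with open(val_path / "val_annotations.txt") as f:
        for line in f:
            fields = line.strip().split("\t")
            if len(fields) < 2:
                continue  # skip blank or malformed lines
            filename, class_id = fields[0], fields[1]
            src = val_path / "images" / filename
            if not src.exists():
                continue  # already moved or missing
            dst_dir = val_path / class_id
            dst_dir.mkdir(exist_ok=True)
            shutil.move(str(src), str(dst_dir / filename))
            moved += 1
    return moved
```

Because already-moved files are skipped, running the function twice on the same folder is harmless.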

Running the code

Training

Baseline

First, to train a baseline model, simply execute the following command:

$ python3 scripts/train.py --config_path config/cifar100/exp_cifar100_wrn2810_1net_standard_bar1.yaml --dataplace $dataplace --saveplace $saveplace

This creates an output folder exp_cifar100_wrn2810_1net_standard_bar1 inside the parent folder saveplace. This folder contains the model checkpoints, a copy of your config file, the logs and the TensorBoard logs. By default, if the output folder already exists, training loads the weights from the last saved epoch and resumes from there. To force training to restart from scratch, add --from_scratch as an argument.

MixMo

When training MixMo, you just need to select the appropriate config file. For example, to obtain state-of-the-art results on CIFAR-100 by combining Cut-MixMo and CutMix, execute:

$ python3 scripts/train.py --config_path config/cifar100/exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml --dataplace $dataplace --saveplace $saveplace

Evaluation

To evaluate the accuracy of a given strategy, you can train your own model, or just download our pretrained checkpoints:

$ python3 scripts/evaluate.py --config_path config/cifar100/exp_cifar100_wrn2810-2_cutmixmo-p5_msdacutmix_bar4.yaml --dataplace $dataplace --checkpoint $checkpoint --tempscal
  • checkpoint can be either:
    • a path towards a checkpoint.
    • an int matching the training epoch you wish to evaluate. In that case, you need to provide --saveplace $saveplace.
    • the string best: we then automatically select the best training epoch. In that case, you need to provide --saveplace $saveplace.
  • --tempscal: indicates that you will apply temperature scaling

Results will be printed at the end of the script.
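Temperature scaling divides all logits by a single scalar T, fitted on held-out data to minimize the negative log-likelihood. Because dividing by T > 0 does not change the argmax, it affects calibration but not top-1 accuracy. The NumPy sketch below fits T with a simple grid search; the function names and the grid are illustrative assumptions, not what --tempscal does internally.

```python
import numpy as np


def nll(logits: np.ndarray, labels: np.ndarray, temperature: float) -> float:
    """Mean negative log-likelihood of the labels under softmax(logits / T)."""
    z = logits / temperature
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())


def fit_temperature(logits: np.ndarray, labels: np.ndarray, grid=None) -> float:
    """Return the temperature on the grid that minimizes the held-out NLL."""
    if grid is None:
        grid = np.linspace(0.05, 10.0, 400)
    losses = [nll(logits, labels, t) for t in grid]
    return float(grid[int(np.argmin(losses))])
```

Overconfident logits yield T > 1 (softer probabilities), underconfident ones yield T < 1; the fitted T is then applied to the test logits before computing calibration metrics.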

If you wish to test the models against common corruptions and perturbations, download the CIFAR-100-C dataset into your dataplace, then add --robustness to the evaluation command.

Create your own configuration files and learning strategies

You can create new configs automatically via:

$ python3 scripts/templateutils_mixmo.py --template_path scripts/exp_mixmo_template.yaml --config_dir config/$your_config_dir --dataset $dataset

Acknowledgements and references
