DABO: Data Augmentation with Bilevel Optimization

The goal is to automatically learn an efficient data augmentation regime for image classification.

Accepted at WACV 2021.

Overview

What's new: This method automatically learns data augmentation to improve image classification performance. It does not require hard-coding augmentation techniques, which may need domain knowledge or an expensive hyper-parameter search on the validation set.

Key insight: Our method efficiently trains a network that performs data augmentation. The augmentation network is trained with the gradient of the classifier's validation loss, which is computed using an online version of bilevel optimization. We also use truncated back-propagation to significantly reduce the computational cost of bilevel optimization.
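For a single inner step, the gradient that reaches the augmentation network can be written as below. This is a sketch under stated assumptions, not the paper's notation:

- the inner update is one plain SGD step with learning rate η;
- w are the classifier weights and θ the augmentation parameters;
- ∇²_{θw} denotes the matrix of mixed second derivatives ∂²L/∂θ_i∂w_j;
- the validation loss depends on θ only through the updated weights w'.

```latex
\begin{aligned}
w' &= w - \eta \, \nabla_w \mathcal{L}_{\text{train}}(w, \theta) \\
\nabla_\theta \mathcal{L}_{\text{val}}(w') &= -\eta \, \nabla^2_{\theta w} \mathcal{L}_{\text{train}}(w, \theta) \, \nabla_{w'} \mathcal{L}_{\text{val}}(w')
\end{aligned}
```

Automatic differentiation computes this mixed-derivative-times-vector product without forming the full matrix. Truncating the unroll to a few recent inner steps keeps the cost close to that of ordinary training.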

How it works: Our method jointly trains a classifier and an augmentation network through the following steps:

  • For each mini-batch, a forward pass is made to compute the training loss.
  • Using the gradient of the training loss, the classifier takes an optimization step in the inner loop.
  • A forward pass is then made through the classifier with the updated weights to compute the validation loss.
  • The gradient of the validation loss is backpropagated to train the augmentation network.
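The four steps above can be illustrated with a deliberately tiny, pure-Python sketch. It is not the DABO implementation, and every name in it is hypothetical:

- the classifier is a scalar linear model `y_hat = w * x`;
- the "augmentation network" is reduced to one learnable input scale `a`;
- both losses are squared errors;
- the inner loop is a single plain SGD step, i.e. a one-step truncation of the unrolled optimization, with gradients derived by hand.

```python
# Toy sketch of online bilevel optimization with a one-step truncated unroll.
# Hypothetical setup, NOT the DABO code: scalar classifier y_hat = w * x,
# "augmentation" = one learnable scale `a` applied to training inputs only,
# squared-error losses, plain SGD. All gradients are derived by hand.


def inner_step(w, a, xs, ys, lr_w):
    """Steps 1-2: training loss on augmented inputs a*x, then one SGD step on w."""
    n = len(xs)
    g_w = 2.0 / n * sum((w * a * x - y) * a * x for x, y in zip(xs, ys))
    return w - lr_w * g_w


def val_loss(w, xs, ys):
    """Mean squared error of the classifier on (non-augmented) validation data."""
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def hypergradient(w, a, train, val, lr_w):
    """Return (w_new, d val_loss(w_new) / d a) via the chain rule through the inner step."""
    xs, ys = train
    xv, yv = val
    w_new = inner_step(w, a, xs, ys, lr_w)
    # Step 3: gradient of the validation loss w.r.t. the updated weight.
    g_val = 2.0 / len(xv) * sum((w_new * x - y) * x for x, y in zip(xv, yv))
    # Mixed second derivative d^2 L_train / (dw da) of the training loss.
    d2 = 2.0 / len(xs) * sum(2 * w * a * x * x - y * x for x, y in zip(xs, ys))
    # Step 4: d w_new / d a = -lr_w * d2, chained with g_val.
    return w_new, g_val * (-lr_w * d2)


def bilevel_step(w, a, train, val, lr_w=0.05, lr_a=0.05):
    """One online iteration: keep the inner update of w, move a along -hypergradient."""
    w_new, g_a = hypergradient(w, a, train, val, lr_w)
    return w_new, a - lr_a * g_a


if __name__ == "__main__":
    train = ([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
    val = ([1.5, 2.5], [3.0, 5.0])
    w, a = 0.5, 1.3
    for _ in range(200):
        w, a = bilevel_step(w, a, train, val)
    print(f"w={w:.3f} a={a:.3f} val_loss={val_loss(w, *val):.5f}")
```

A central finite difference of `val_loss(inner_step(...))` with respect to `a` is an easy way to check the hand-derived chain rule. In a real network, automatic differentiation through the inner update would compute the same product.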

Results: Our model obtains better results than carefully hand-engineered transformations and GAN-based approaches. Further, its results are competitive with methods that use a policy search on the CIFAR10, CIFAR100, BACH, Tiny-ImageNet and ImageNet datasets.

Why it matters: Proper data augmentation can significantly improve generalization performance. Unfortunately, deriving these augmentations requires domain expertise or an extensive hyper-parameter search. An automatic and fast way of identifying efficient data augmentations can therefore have a large impact on obtaining better models.

Where to go from here: Performance could be improved by extending the set of learned transformations to non-differentiable transformations. The estimate of the validation loss could also be improved by further exploring the influence of the number of iterations in the inner loop. Finally, the method could be extended to other tasks such as object detection or image segmentation.

Experiments

1. Install requirements: Run this command to install the dependencies. These include the Haven library, which helps manage experiments.

pip install -r requirements.txt

2.1 CIFAR10 experiments: The following command runs the training and validation loop on CIFAR10.

python trainval.py -e cifar -sb ../results -d ../data -r 1

where -e defines the experiment group, -sb is the result directory, and -d is the dataset directory.

2.2 BACH experiments: The following command runs the training and validation loop on the BACH dataset.

python trainval.py -e bach -sb ../results -d ../data -r 1

where -e defines the experiment group, -sb is the result directory, and -d is the dataset directory.
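The command-line flags above could be parsed roughly as follows. This is a hypothetical sketch, not the repository's `trainval.py`:

- the long option names are our own;
- `-r` is not documented here, so treating it as an integer reset switch is an assumption.

```python
# Hypothetical sketch of a CLI accepting the flags used above.
# The real trainval.py may define these differently; -r's meaning is ASSUMED.
import argparse


def build_parser():
    parser = argparse.ArgumentParser(description="Train and validate an experiment group")
    parser.add_argument("-e", "--exp_group", required=True,
                        help="experiment group name, e.g. cifar or bach")
    parser.add_argument("-sb", "--savedir_base", required=True,
                        help="directory where results are written")
    parser.add_argument("-d", "--datadir", required=True,
                        help="directory containing the datasets")
    # Assumption: -r is an integer switch, e.g. 1 = reset an existing run.
    parser.add_argument("-r", "--reset", type=int, default=0,
                        help="assumed reset flag (hypothetical meaning)")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(vars(args))
```

With this parser, `python trainval.py -e cifar -sb ../results -d ../data -r 1` would yield `exp_group='cifar'`, `savedir_base='../results'`, `datadir='../data'` and `reset=1`.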

3. Results: Display the results by following the steps below:

Launch Jupyter by running the following in a terminal:

jupyter nbextension enable --py widgetsnbextension
jupyter notebook

Then, run the following script in a Jupyter cell:

from haven import haven_jupyter as hj
from haven import haven_results as hr
from haven import haven_utils as hu

# path to where the experiments were saved (the -sb directory passed to trainval.py)
savedir_base = '../results'
exp_list = None

# exp_list = hu.load_py().EXP_GROUPS[]
# get experiments
rm = hr.ResultManager(exp_list=exp_list, 
                      savedir_base=savedir_base, 
                      verbose=0
                     )
y_metrics = ['test_acc']
bar_agg = 'max'
mode = 'bar'
legend_list = ['model.netA.name']
title_list = 'dataset.name'
legend_format = 'Augmentation Network: {}'
filterby_list = {'dataset':{'name':'cifar10'}, 'model':{'netC':{'name':'resnet18_meta_2'}}}

# launch dashboard
hj.get_dashboard(rm, vars(), wide_display=True)
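The dashboard settings above can be made concrete with a plain-Python sketch of one plausible reading of them:

- keep the experiments whose config matches `filterby_list`;
- group them by the `model.netA.name` legend key;
- reduce each run's per-epoch `test_acc` with `max`, following the `bar_agg` setting.

This is a hypothetical simplification, not Haven's implementation. The experiment record format it expects (a `config` dict plus a list of per-epoch `scores`) is an assumption, and the values in the example are dummy numbers.

```python
# Hypothetical simplification of the bar-chart settings above; NOT Haven's code.
# Each experiment is assumed to be {"config": nested dict, "scores": [per-epoch dicts]}.


def get_nested(cfg, dotted):
    """Look up a dotted key such as 'model.netA.name' in a nested dict."""
    for key in dotted.split("."):
        cfg = cfg[key]
    return cfg


def matches(cfg, filt):
    """True if every (possibly nested) key/value pair in `filt` appears in `cfg`."""
    for key, val in filt.items():
        if isinstance(val, dict):
            sub = cfg.get(key)
            if not isinstance(sub, dict) or not matches(sub, val):
                return False
        elif cfg.get(key) != val:
            return False
    return True


def bar_values(experiments, filterby, legend_key, metric, agg=max):
    """Map each legend value to the aggregated metric of every matching run."""
    bars = {}
    for exp in experiments:
        if not matches(exp["config"], filterby):
            continue
        label = get_nested(exp["config"], legend_key)
        value = agg(s[metric] for s in exp["scores"])  # e.g. best test_acc over epochs
        bars.setdefault(label, []).append(value)
    return bars


if __name__ == "__main__":
    dummy = [  # made-up records for illustration only
        {"config": {"dataset": {"name": "cifar10"},
                    "model": {"netA": {"name": "augA"}, "netC": {"name": "resnet18_meta_2"}}},
         "scores": [{"test_acc": 0.5}, {"test_acc": 0.7}]},
    ]
    filt = {"dataset": {"name": "cifar10"}, "model": {"netC": {"name": "resnet18_meta_2"}}}
    print(bar_values(dummy, filt, "model.netA.name", "test_acc"))
```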

Citation

@article{mounsaveng2020learning,
  title={Learning Data Augmentation with Online Bilevel Optimization for Image Classification},
  author={Mounsaveng, Saypraseuth and Laradji, Issam and Ayed, Ismail Ben and Vazquez, David and Pedersoli, Marco},
  journal={arXiv preprint arXiv:2006.14699},
  year={2020}
}
Owner: ElementAI