Second-Order Neural ODE Optimizer, NeurIPS 2021 spotlight

✔️ faster convergence in wall-clock time | ✔️ O(1) memory cost |
✔️ better test-time performance | ✔️ architecture co-optimization

This repo provides a PyTorch implementation of the Second-order Neural ODE Optimizer (SNOpt), a second-order optimizer for training Neural ODEs that retains O(1) memory cost while improving convergence and test-time performance.

[Figure: SNOpt result]

Installation

This code was developed with Python 3. PyTorch >= 1.7 (we recommend 1.8.1) and torchdiffeq >= 0.2.0 are required.

  1. Install the dependencies with Anaconda and activate the environment snopt with
    conda env create --file requirements.yaml python=3
    conda activate snopt
  2. [Optional] This repo provides a modification (only 15 lines!) of torchdiffeq that allows SNOpt to collect 2nd-order information during adjoint-based training. If you wish to run torchdiffeq at another commit, copy that folder into this directory and then apply the provided snopt_integration.patch.
    cp -r <path_to_your_torchdiffeq_folder> .
    git apply snopt_integration.patch
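
After installation, it can help to confirm that the environment meets the version requirements stated above. The following is a minimal sketch of such a check, not part of this repo; the helper names `version_tuple` and `meets_minimum` are hypothetical.

```python
import re


def version_tuple(version):
    """Turn a string such as '1.8.1+cu111' into (1, 8, 1).

    The local build tag after '+' is dropped, and parsing stops at the
    first dot-separated piece that does not start with a digit.
    """
    core = version.split("+")[0]
    parts = []
    for piece in core.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Return True if `installed` is at least `minimum`."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that '1.7' and '1.7.0' compare as equal.
    a = a + (0,) * (width - len(a))
    b = b + (0,) * (width - len(b))
    return a >= b


if __name__ == "__main__":
    print(meets_minimum("1.8.1+cu111", "1.7"))
```

In practice you would pass `torch.__version__` as `installed` and `"1.7"` as `minimum`.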

Run the code

We provide example code for 8 datasets across image classification (main_img_clf.py), time-series prediction (main_time_series.py), and continuous normalizing flow (main_cnf.py). The commands that reproduce results similar to those in our paper are listed in the scripts folder. Datasets are downloaded automatically to the data folder on the first call, and all results are saved to the result folder.

bash scripts/run_img_clf.sh     <dataset> # dataset can be {mnist, svhn, cifar10}
bash scripts/run_time_series.sh <dataset> # dataset can be {char-traj, art-wr, spo-ad}
bash scripts/run_cnf.sh         <dataset> # dataset can be {miniboone, gas}

For architecture (specifically integration time) co-optimization, run

bash scripts/run_img_clf.sh cifar10-t1-optimize

Integration with your workflow

snopt can be integrated seamlessly into an existing training workflow. Below we provide a handy checklist and pseudo-code to help with your integration. For more complex examples, please refer to main_*.py in this repo.

  • Import a torchdiffeq that has been patched with the snopt integration; the simplest option is to use the torchdiffeq shipped in this repo.
  • Subclass snopt.ODEFuncBase to define your vector field; implement the forward pass in F rather than forward.
  • Create the Neural ODE with ODE layer(s) built from snopt.ODEBlock; implement the properties odes and ode_mods.
  • Initialize snopt.SNOpt as a preconditioner; call train_itr_setup() before the standard optim.zero_grad(), and step() before optim.step() (see the code below).
  • That's it 🤓 ! Enjoy your second-order training 🚂 🚅 !
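import torch
from torchdiffeq import odeint_adjoint as odesolve  # the snopt-patched torchdiffeq
from snopt import SNOpt, ODEFuncBase, ODEBlock
from easydict import EasyDict as dict  # note: this shadows the built-in dict

input_dim = 2  # placeholder feature dimension; set it to match your data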

class ODEFunc(ODEFuncBase):
    def __init__(self, opt):
        super(ODEFunc, self).__init__(opt)
        self.linear = torch.nn.Linear(input_dim, input_dim)

    def F(self, t, z):
        return self.linear(z)

class NeuralODE(torch.nn.Module):
    def __init__(self, ode):
        super(NeuralODE, self).__init__()
        self.ode = ode

    def forward(self, z):
        return self.ode(z)

    @property
    def odes(self): # in case we have multiple odes, collect them in a list
        return [self.ode]

    @property
    def ode_mods(self): # modules of all ode(s)
        return [mod for mod in self.ode.odefunc.modules()]

# Create Neural ODE
opt = dict(
    optimizer='SNOpt',
    tol=1e-3,
    ode_solver='dopri5',
    use_adaptive_t1=False,
    snopt_step_size=0.01,
)
odefunc = ODEFunc(opt)
integration_time = torch.tensor([0.0, 1.0]).float()
ode = ODEBlock(opt, odefunc, odesolve, integration_time)
net = NeuralODE(ode)

# Create SNOpt optimizer
precond = SNOpt(net, eps=0.05, update_freq=100)
optim = torch.optim.SGD(net.parameters(), lr=0.001)

# Training loop (training_loader and loss_function come from your own pipeline)
for x, y in training_loader:
    precond.train_itr_setup() # <--- additional step for precond
    optim.zero_grad()

    loss = loss_function(net(x), y)
    loss.backward()

    # Run SNOpt optimizer
    precond.step()            # <--- additional step for precond
    optim.step()

What the library actually contains

This snopt library implements the following objects for efficient 2nd-order adjoint-based training of Neural ODEs.

  • ODEFuncBase: Defines the vector field (inherits torch.nn.Module) of Neural ODE.
  • CNFFuncBase: Serves the same purpose as ODEFuncBase, but for CNF applications.
  • ODEBlock: A Neural-ODE module (torch.nn.Module) that solves the initial value problem (given the vector field, the integration time, and an ODE solver) and handles integration time co-optimization with a feedback policy.
  • SNOpt: Our primary 2nd-order optimizer (torch.optim.Optimizer), implemented as a "preconditioner" (see example code above). It takes the following arguments.
    • net is the Neural ODE. Note that the entire network (rather than net.parameters()) is required.
    • eps is the regularization that stabilizes preconditioning. We recommend a value in [0.05, 0.1].
    • update_freq is how often (in training iterations) the 2nd-order information is refreshed. We recommend a value of 100~200.
    • alpha controls the running averages of the eigenvalues. We recommend fixing the value to 0.75.
    • full_precond decides whether to also precondition layers outside the Neural ODEs.
  • SNOptAdjointCollector: A helper to collect information from torchdiffeq to construct 2nd-order matrices.
  • IntegrationTimeOptimizer: Our 2nd-order method that co-optimizes the integration time (i.e., t1). This is done by calling t1_train_itr_setup(train_it) and update_t1() together with optim.zero_grad() and optim.step() (see trainer.py).
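
To give a feel for what eps and alpha control, here is a minimal NumPy sketch of damped Kronecker-factored preconditioning for a single linear layer. It is an illustration of the general idea only and is not the code in this library: SNOpt uses the eigenvalue-regularized variant adopted from EKFAC-pytorch, whereas this sketch swaps in plain damped inversion. The function names, the factor matrices A and B, and the convention that alpha weights the previous estimate are all assumptions.

```python
import numpy as np


def update_running_average(old, new, alpha=0.75):
    """Exponential moving average.

    Assumption: alpha is the weight kept on the previous estimate.
    """
    return alpha * old + (1.0 - alpha) * new


def damped_kron_precondition(grad, A, B, eps=0.05):
    """Return (B + eps*I)^-1 @ grad @ (A + eps*I)^-1.

    grad: (out_dim, in_dim) gradient of a linear layer's weight.
    A:    (in_dim, in_dim) symmetric input-side curvature factor.
    B:    (out_dim, out_dim) symmetric output-side curvature factor.
    eps:  damping added to both factors so the solves stay well conditioned.
    """
    out_dim, in_dim = grad.shape
    left = np.linalg.solve(B + eps * np.eye(out_dim), grad)
    # Right-multiplying by an inverse is the same as solving against the transpose.
    return np.linalg.solve((A + eps * np.eye(in_dim)).T, left.T).T


if __name__ == "__main__":
    G = np.arange(6.0).reshape(3, 2)
    print(damped_kron_precondition(G, np.eye(2), np.eye(3), eps=1.0))
```

A larger eps pushes the update toward a rescaled plain gradient, which is one way to see why it stabilizes preconditioning when the curvature estimates are noisy.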

The options are passed in as opt, which contains the following fields (see options.py for full descriptions).

  • optimizer is the training method. Use "SNOpt" to enable our method.
  • ode_solver specifies the ODE solver (default is "dopri5") with the absolute/relative tolerance tol.
  • For CNF applications, use divergence_type to specify how divergence should be computed.
  • snopt_step_size determines the step size at which SNOpt samples along the integration to compute 2nd-order matrices. We recommend the value 0.01 for integration time [0,1], which yields around 100 sampled points.
  • For integration time (t1) co-optimization, enable the flag use_adaptive_t1 and set up the following options.
    • adaptive_t1 specifies the t1 optimization method. Choices are "baseline" and "feedback" (ours).
    • t1_lr is the learning rate. We recommend a value in [0.05, 0.1].
    • t1_reg is the coefficient of the quadratic penalty imposed on t1. The performance is quite sensitive to this value. We recommend a value in [1e-4, 1e-3].
    • t1_update_freq is the frequency at which t1 is updated. We recommend a value of 50~100.
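
The relationship between snopt_step_size and the number of sampled points is simple arithmetic. The sketch below spells it out, assuming a uniform grid that includes both endpoints; the helper `sample_times` is hypothetical, and how SNOpt actually places its samples may differ.

```python
def sample_times(t0, t1, step):
    """Return evenly spaced time points from t0 to t1, roughly `step` apart.

    Assumption: both endpoints are included, so the grid has one more
    point than the number of steps.
    """
    n_steps = max(1, round((t1 - t0) / step))
    return [t0 + i * (t1 - t0) / n_steps for i in range(n_steps + 1)]


if __name__ == "__main__":
    grid = sample_times(0.0, 1.0, 0.01)
    print(len(grid) - 1)  # number of steps across the interval
```

If the integration time is co-optimized and t1 grows, keeping the same step size increases the number of sampled points proportionally, so the cost of collecting 2nd-order information scales with t1.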

Remarks & Citation

The current library only supports adjoint-based training, but it can be extended to the normal odeint method (stay tuned!). The pre-processing of the tabular and UEA datasets is adopted from ffjord and NeuralCDE, and the eigenvalue-regularized preconditioning is adopted from EKFAC-pytorch.

If you find this library useful, please cite ⬇️ . Contact me ([email protected]) if you have any questions!

@inproceedings{liu2021second,
  title={Second-order Neural ODE Optimizer},
  author={Liu, Guan-Horng and Chen, Tianrong and Theodorou, Evangelos A},
  booktitle={Advances in Neural Information Processing Systems},
  year={2021},
}
Owner
Guan-Horng Liu
CMU RI → Uber ATG → GaTech ML