ESGD-M - A stochastic non-convex second order optimizer, suitable for training deep learning models, for PyTorch

ESGD-M

ESGD-M is a stochastic non-convex second order optimizer, suitable for training deep learning models. It is based on ESGD (from the paper "Equilibrated adaptive learning rates for non-convex optimization") and incorporates quasi-hyperbolic momentum (from the paper "Quasi-hyperbolic momentum and Adam for deep learning") to accelerate convergence, which considerably improves its performance over plain ESGD.

ESGD-M obtains Hessian information through occasional Hessian-vector products (by default, every ten optimizer steps; each Hessian-vector product is approximately the same cost as a gradient evaluation) and uses it to adapt per-parameter learning rates. It estimates the diagonal of the absolute Hessian, diag(|H|), to use as a diagonal preconditioner.

To use this optimizer you must call .backward() with the create_graph=True option. Gradient accumulation steps and distributed training are currently not supported.

Learning rates

Learning rates in ESGD-M have a different meaning from those in SGD or in Adagrad/Adam-class optimizers, so you may need to try learning rates in the range 1e-3 to 1.

SGD class optimizers:

  • If you rescale your parameters by a factor of n, you must scale your learning rate by a factor of n^2.

  • If you rescale your loss by a factor of n, you must scale your learning rate by a factor of 1 / n.

Adagrad/Adam class optimizers:

  • If you rescale your parameters by a factor of n, you must scale your learning rate by a factor of n.

  • If you rescale your loss by a factor of n, you do not have to scale your learning rate.

Second order optimizers (including ESGD-M):

  • You do not have to scale your learning rate if you rescale either your parameters or your loss.
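The scaling rules above can be checked on a one-dimensional quadratic loss. The following is a minimal sketch, not ESGD-M itself: the function name `updates` and the toy loss are assumptions made for illustration, and the exact second derivative stands in for the estimated preconditioner.

```python
def updates(w, lr, curvature=4.0, target=1.0, loss_scale=1.0, param_scale=1.0):
    """Return (sgd_update, preconditioned_update), both expressed in the units of w."""
    u = param_scale * w  # the rescaled parameter the optimizer actually sees
    # Toy loss as a function of u (an illustrative assumption):
    #   loss_scale * 0.5 * curvature * (u / param_scale - target) ** 2
    grad = loss_scale * curvature * (u / param_scale - target) / param_scale
    hess = loss_scale * curvature / param_scale ** 2
    sgd = -lr * grad
    preconditioned = -lr * grad / abs(hess)  # exact |H| stands in for the estimate
    # Convert the updates on u back to updates on w so they can be compared.
    return sgd / param_scale, preconditioned / param_scale

base = updates(w=3.0, lr=0.1)
scaled_loss = updates(w=3.0, lr=0.1, loss_scale=10.0)
scaled_params = updates(w=3.0, lr=0.1, param_scale=10.0)
print(base, scaled_loss, scaled_params)
```

In this toy setting the SGD update grows by a factor of n when the loss is scaled by n and shrinks by n^2 when the parameters are scaled by n, while the curvature-preconditioned update is the same in all three cases.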

Momentum

The default configuration is Nesterov momentum (if v is not specified then it will default to the value of beta_1, producing Nesterov momentum):

opt = ESGD(model.parameters(), lr=1, betas=(0.9, 0.999), v=0.9)

The Quasi-Hyperbolic Momentum recommended defaults can be obtained using:

opt = ESGD(model.parameters(), lr=1, betas=(0.999, 0.999), v=0.7)

Setting v equal to 1 will do normal (non-Nesterov) momentum.
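To make the role of v concrete, here is a minimal sketch of the quasi-hyperbolic momentum update rule as described in the QHM paper, written for a single scalar parameter. It is an illustration under assumptions, not the code of this optimizer: the names `qhm_step` and `buf` are hypothetical, and the division by the ESGD-M preconditioner is left out.

```python
def qhm_step(param, grad, buf, lr, beta, v):
    """One quasi-hyperbolic momentum step for a scalar parameter."""
    # Exponential moving average of the gradients (the momentum buffer).
    buf = beta * buf + (1 - beta) * grad
    # Interpolate between the raw gradient (v = 0) and the momentum buffer (v = 1).
    update = (1 - v) * grad + v * buf
    return param - lr * update, buf

# v = 0 reduces to plain gradient descent, v = 1 to normal momentum.
plain, _ = qhm_step(param=1.0, grad=2.0, buf=0.0, lr=0.1, beta=0.9, v=0.0)
momentum, _ = qhm_step(param=1.0, grad=2.0, buf=0.0, lr=0.1, beta=0.9, v=1.0)
nesterov_like, _ = qhm_step(param=1.0, grad=2.0, buf=0.0, lr=0.1, beta=0.9, v=0.9)
print(plain, momentum, nesterov_like)
```

With v equal to beta the interpolation corresponds to the Nesterov-style default mentioned above; values in between, such as v=0.7, weight the raw gradient more heavily.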

The ESGD-M decay coefficient beta_2 refers not to the squared gradient as in Adam but to the squared Hessian diagonal estimate, which it uses in place of the squared gradient to provide per-parameter adaptive learning rates.

Hessian-vector products

The absolute Hessian diagonal diag(|H|) is estimated every update_d_every steps (default: 10). In addition, during the first d_warmup steps the diagonal is estimated on every step, to obtain a lower variance estimate of diag(|H|) quickly. The estimation uses a Hessian-vector product, which takes around the same amount of time to compute as a gradient evaluation. You must explicitly signal to PyTorch that you want to do a double backward pass by passing create_graph=True to .backward():

opt.zero_grad(set_to_none=True)
loss = loss_fn(model(inputs), targets)
loss.backward(create_graph=True)
opt.step()
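For readers unfamiliar with double backward passes, the following sketch shows one common way to compute a Hessian-vector product in PyTorch with torch.autograd.grad. It illustrates the general technique only and is not taken from the ESGD-M source; the helper name `hessian_vector_product` and the small quadratic loss are assumptions chosen so that the result can be checked by hand.

```python
import torch

def hessian_vector_product(loss, params, vectors):
    """Compute H @ v for each parameter using two backward passes."""
    # First pass: keep the graph so the gradients are themselves differentiable.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    # Second pass: differentiating (grads . vectors) w.r.t. params gives H @ vectors.
    return torch.autograd.grad(grads, params, grad_outputs=vectors)

# Toy loss 0.5 * x^T A x with symmetric A, whose Hessian is exactly A.
A = torch.tensor([[2.0, 1.0], [1.0, 3.0]])
x = torch.tensor([1.0, -1.0], requires_grad=True)
loss = 0.5 * x @ A @ x
v = torch.tensor([1.0, 0.0])
(hv,) = hessian_vector_product(loss, [x], [v])
print(hv)
```

The first call is why create_graph=True is needed: without the retained graph, the gradients could not be differentiated a second time.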

Weight decay

Weight decay is performed separately from the Hessian-vector product and the preconditioner, similar to AdamW except that the weight decay value provided by the user is multiplied by the current learning rate to determine the factor to decay the weights by.
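One reading of this description, sketched for a single scalar weight, is shown below. The exact form is an assumption: the helper `decay_weight` is hypothetical, and the real optimizer may apply the decay differently (for example in place on tensors).

```python
def decay_weight(weight, lr, weight_decay):
    """Shrink a weight independently of the gradient and the preconditioner."""
    # Assumed form: the decay factor is the user's weight_decay times the current lr.
    return weight * (1 - lr * weight_decay)

decayed = decay_weight(weight=2.0, lr=0.5, weight_decay=0.01)
print(decayed)
```

Under this reading, lowering the learning rate (for instance through a scheduler or the built-in warmup) also lowers the amount of decay applied on that step.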

Learning rate warmup

Because the diag(|H|) estimates are high variance, the adaptive learning rates are not very reliable until many steps have been taken and many estimates have been averaged together. To deal with this, ESGD-M applies a short exponential learning rate warmup by default (it is combined with any external learning rate schedulers). On each step (counting from 1) the learning rate will be:

lr * (1 - lr_warmup**step)

The default value for lr_warmup is 0.99, which reaches 63% of the specified learning rate in 100 steps and 95% in 300 steps.
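The warmup factor is easy to tabulate from the formula above. This short sketch evaluates 1 - lr_warmup**step at a few steps; the function name `warmup_factor` is only illustrative.

```python
def warmup_factor(step, lr_warmup=0.99):
    """Fraction of the specified learning rate in effect at a given step (step >= 1)."""
    return 1 - lr_warmup ** step

for step in (1, 10, 100, 300, 1000):
    print(step, round(warmup_factor(step), 4))
```

A value of lr_warmup closer to 1 lengthens the warmup, and a smaller value shortens it.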

Owner
Katherine Crowson
AI/generative artist.