Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions

torch-imle

A concise and self-contained PyTorch library implementing the I-MLE gradient estimator proposed in our NeurIPS 2021 paper, Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions.

This repository contains a library for transforming any combinatorial black-box solver into a differentiable layer. All code for reproducing the experiments in the NeurIPS paper is available in the official NEC Laboratories Europe repository.

Overview

Implicit MLE (I-MLE) makes it possible to include discrete combinatorial optimization algorithms, such as Dijkstra's algorithm or integer linear program (ILP) solvers, in standard deep learning architectures. The core idea of I-MLE is that it defines an implicit maximum likelihood objective whose gradients are used to update upstream parameters of the model. Every instance of I-MLE requires two ingredients:

  1. A method to approximately sample from a complex and intractable distribution. For this we use Perturb-and-MAP (aka the Gumbel-max trick) and propose a novel family of noise perturbations tailored to the problem at hand.
  2. A method to compute a surrogate empirical distribution. Vanilla MLE reduces the KL divergence between the current distribution and the empirical distribution. Since we do not have access to an empirical distribution in our setting, we have to design surrogate empirical distributions. We propose two families of surrogate distributions that are widely applicable and work well in practice.
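
The first ingredient can be illustrated on the simplest possible case. The following is a minimal sketch, assuming a categorical distribution whose "solver" is a plain argmax and standard Gumbel noise (not the sum-of-gamma noise proposed in the paper). The function name perturb_and_map and the example parameters are hypothetical and are not part of the library.

```python
import numpy as np


def perturb_and_map(theta, nb_samples, rng):
    """Draw one-hot samples by perturbing theta with Gumbel noise and taking the MAP state."""
    # One independent Gumbel(0, 1) perturbation per sample and per entry of theta.
    gumbel = rng.gumbel(loc=0.0, scale=1.0, size=(nb_samples, theta.shape[0]))
    # The MAP state of the perturbed parameters is simply the argmax.
    winners = np.argmax(theta + gumbel, axis=1)
    # Encode each winner as a one-hot row.
    return np.eye(theta.shape[0])[winners]


rng = np.random.default_rng(0)
theta = np.array([1.0, 0.0, -1.0])

samples = perturb_and_map(theta, nb_samples=20000, rng=rng)
empirical = samples.mean(axis=0)                  # estimated marginals
softmax = np.exp(theta) / np.exp(theta).sum()     # exact marginals for this toy case

print("empirical:", empirical)
print("softmax:  ", softmax)
```

For this toy distribution the averaged samples should be close to the softmax of theta, which is what the Gumbel-max trick guarantees. For structured solvers such as shortest-path algorithms, the same averaging yields approximate marginals only.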

Example

For example, let's consider a map from a simple game where the task is to find the shortest path from the top-left to the bottom-right corner. Black areas have the highest and white areas the lowest cost. In the centre, you can see what happens when we use the proposed sum-of-gamma noise distribution to sample paths. On the right, you can see the resulting marginal probabilities for every tile (the probability of each tile being part of a sampled path).

Gradients and Learning

Let us assume that the optimal shortest path is the one on the left. Starting from random weights, the model can learn, via gradient descent, to produce weights that result in the optimal shortest path by minimising the Hamming loss between the produced path and the gold path. Here we show the paths produced during training (middle) and the corresponding map weights (right).
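
The learning procedure described above can be sketched in a few lines of NumPy. This is a simplified reading of the estimator, not the library's implementation. It makes several assumptions: a hypothetical top-k solver stands in for the shortest-path solver, Gumbel noise stands in for the sum-of-gamma noise, and the target parameters are obtained by moving the current parameters against the loss gradient with a strength lam. All function and parameter names below are illustrative.

```python
import numpy as np


def top_k_solver(theta, k):
    """Hypothetical black-box solver: 0/1 mask of the k largest entries."""
    z = np.zeros_like(theta)
    z[np.argsort(-theta, kind="stable")[:k]] = 1.0
    return z


def hamming_loss(z, z_gold):
    """Fraction of positions where the two binary vectors disagree."""
    return float(np.mean(z * (1.0 - z_gold) + (1.0 - z) * z_gold))


def hamming_grad(z, z_gold):
    """Derivative of hamming_loss with respect to z."""
    return (1.0 - 2.0 * z_gold) / z.size


def imle_gradient(theta, z_gold, k, lam, nb_samples, temperature, rng):
    """Estimate dLoss/dtheta as the difference of two perturbed MAP states."""
    grad = np.zeros_like(theta)
    for _ in range(nb_samples):
        # The same noise sample is used for the input and the target parameters.
        eps = temperature * rng.gumbel(size=theta.shape)
        z = top_k_solver(theta + eps, k)
        # Target parameters: step against the gradient of the downstream loss.
        theta_target = theta - lam * hamming_grad(z, z_gold)
        z_target = top_k_solver(theta_target + eps, k)
        grad += z - z_target
    return grad / nb_samples


def train(theta, z_gold, k, steps, lr, lam, nb_samples, temperature, rng):
    """Plain gradient descent on theta using the estimate above."""
    theta = theta.copy()
    for _ in range(steps):
        theta -= lr * imle_gradient(theta, z_gold, k, lam, nb_samples, temperature, rng)
    return theta


rng = np.random.default_rng(0)
theta_init = np.array([0.5, -0.2, 0.1, 0.9, -0.7, 0.3])
z_gold = np.array([0.0, 1.0, 0.0, 0.0, 1.0, 0.0])

# Noise temperature 0.0 makes the run deterministic.
theta_final = train(theta_init, z_gold, k=2, steps=200, lr=0.1, lam=10.0,
                    nb_samples=1, temperature=0.0, rng=rng)

print("loss before:", hamming_loss(top_k_solver(theta_init, 2), z_gold))
print("loss after: ", hamming_loss(top_k_solver(theta_final, 2), z_gold))
```

In this noise-free toy setting the solver output reaches the gold solution and the updates then vanish. Raising the temperature (and nb_samples) makes each update an average over perturbed solutions.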

Input noise temperature set to 0.0, and target noise temperature set to 0.0:

Input noise temperature set to 1.0, and target noise temperature set to 1.0:

Input noise temperature set to 2.0, and target noise temperature set to 2.0:

Input noise temperature set to 5.0, and target noise temperature set to 5.0:

Input noise temperature set to 5.0, and target noise temperature set to 0.0:

All animations were generated by this script.

Code

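Using the library requires only a few lines of code; see this example as a reference. Assume we have a function, solver, that implements a black-box combinatorial solver such as Dijkstra's algorithm. We first wrap it so that it maps a batch of PyTorch tensors to a batch of solutions: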

import numpy as np

import torch
from torch import Tensor

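def torch_solver(weights_batch: Tensor) -> Tensor:
    # Run the black-box solver on CPU, one instance of the batch at a time.
    weights_batch = weights_batch.detach().cpu().numpy()
    # `solver` maps a single NumPy weight array to its discrete solution.
    y_batch = np.asarray([solver(w) for w in weights_batch])
    # The solver output carries no gradient; I-MLE supplies it in the backward pass.
    return torch.tensor(y_batch, requires_grad=False)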
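
The snippet above leaves solver undefined. As an illustration only, a stand-in could look like the following sketch: Dijkstra's algorithm on a 2D grid of non-negative tile costs, moving in four directions from the top-left to the bottom-right corner, returning a binary mask of the tiles on the cheapest path. The grid layout, the connectivity and the output encoding are assumptions made for this example, not a description of the solver used in the paper's experiments.

```python
import heapq

import numpy as np


def solver(weights):
    """Hypothetical stand-in solver: cheapest top-left to bottom-right path on a grid."""
    h, w = weights.shape
    dist = np.full((h, w), np.inf)
    prev = {}
    dist[0, 0] = weights[0, 0]
    heap = [(dist[0, 0], 0, 0)]
    while heap:
        d, r, c = heapq.heappop(heap)
        if d > dist[r, c]:
            continue  # stale queue entry, a shorter route was already found
        if (r, c) == (h - 1, w - 1):
            break  # target reached with its final distance
        for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            nr, nc = r + dr, c + dc
            if 0 <= nr < h and 0 <= nc < w:
                nd = d + weights[nr, nc]
                if nd < dist[nr, nc]:
                    dist[nr, nc] = nd
                    prev[(nr, nc)] = (r, c)
                    heapq.heappush(heap, (nd, nr, nc))
    # Walk back from the target to mark the tiles on the path.
    path = np.zeros((h, w), dtype=np.float32)
    node = (h - 1, w - 1)
    while node != (0, 0):
        path[node] = 1.0
        node = prev[node]
    path[0, 0] = 1.0
    return path


weights = np.array([[1.0, 9.0, 9.0],
                    [1.0, 9.0, 9.0],
                    [1.0, 1.0, 1.0]])
mask = solver(weights)
print(mask)
```

Dijkstra's algorithm requires non-negative costs, so a solver like this one would need the perturbed weights to stay non-negative (for example by clipping) when it is combined with noise.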

We can obtain the corresponding distribution and gradients in this way:

from imle.wrapper import imle
from imle.target import TargetDistribution
from imle.noise import SumOfGammaNoiseDistribution

# `k` is a user-chosen hyperparameter of the noise distribution and must be defined beforehand.
target_distribution = TargetDistribution(alpha=0.0, beta=10.0)
noise_distribution = SumOfGammaNoiseDistribution(k=k, nb_iterations=100)

imle_solver = imle(torch_solver,
                   target_distribution=target_distribution,
                   noise_distribution=noise_distribution,
                   nb_samples=10,
                   input_noise_temperature=input_noise_temperature,
                   target_noise_temperature=target_noise_temperature)

Or, alternatively, using a simple function annotation:

@imle(target_distribution=target_distribution,
      noise_distribution=noise_distribution,
      nb_samples=10,
      input_noise_temperature=input_noise_temperature,
      target_noise_temperature=target_noise_temperature)
def imle_solver(weights_batch: Tensor) -> Tensor:
    return torch_solver(weights_batch)

Papers using I-MLE

Reference

@inproceedings{niepert21imle,
  author    = {Mathias Niepert and
               Pasquale Minervini and
               Luca Franceschi},
  title     = {Implicit {MLE:} Backpropagating Through Discrete Exponential Family
               Distributions},
  booktitle = {NeurIPS},
  series    = {Proceedings of Machine Learning Research},
  publisher = {{PMLR}},
  year      = {2021}
}

Owner

UCL Natural Language Processing