A PyTorch implementation of a Factorization Machine module in cython.

Overview

fmpytorch

A library for factorization machines in pytorch. A factorization machine is like a linear model, except multiplicative interaction terms between the variables are modeled as well.

The input to a factorization machine layer is a vector, and the output is a scalar. Batching is fully supported.

This is a work in progress. Feedback and bugfixes welcome! Hopefully you find the code useful.

Usage

The factorization machine layers in fmpytorch can be used just like any other built-in module. Here's a simple feed-forward model using a factorization machine that takes in a 50-D input, and models interactions using k=5 factors.

import torch
from fmpytorch.second_order.fm import FactorizationMachine

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = torch.nn.Linear(100, 50)
        self.dropout = torch.nn.Dropout(.5)
	# This makes a fm layer mapping from 50-D to 1-D.
	# The number of factors is 5.
        self.fm = FactorizationMachine(50, 5)

    def forward(self, x):
        x = self.linear(x)
        x = self.dropout(x)
        x = self.fm(x)
        return x

See examples/toy.py or examples/regression.py for fuller examples.

Installation

This package requires pytorch, numpy, and cython.

To install, you can run:

cd fmpytorch
sudo python setup.py install

Factorization Machine brief intro

A linear model, given a vector x models its output y as

where w are the learnable weights of the model.

However, the interactions between the input variables x_i are purely additive. In some cases, it might be useful to model the interactions between your variables, e.g., x_i * x_j. You could add terms into your model like

However, this introduces a large number of w2 variables. Specifically, there are O(n^2) parameters introduced in this formulation, one for each interaction pair. A factorization machine approximates w2 using low dimensional factors, i.e.,

where each v_i is a low-dimensional vector. This is the forward pass of a second order factorization machine. This low-rank re-formulation has reduced the number of additional parameters for the factorization machine to O(k*n). Magically, the forward (and backward) pass can be reformulated so that it can be computed in O(k*n), rather than the naive O(k*n^2) formulation above.

Currently supported features

Currently, only a second order factorization machine is supported. The forward and backward passes are implemented in cython. Compared to the autodiff solution, the cython passes run several orders of magnitude faster. I've only tested it with python 2 at the moment.

TODOs

  1. Support for sparse tensors.
  2. More interesting useage examples
  3. More testing, e.g., with python 3, etc.
  4. Make sure all of the code plays nice with torch-specific stuff, e.g., GPUs
  5. Arbitrary order factorization machine support
  6. Better organization/code cleaning

Thanks to

Vlad Niculae (@vene) for his sage wisdom.

The original factorization machine citation, which this layer is based off of, is

@inproceedings{rendle2010factorization,
	       title={Factorization machines},
    	       author={Rendle, Steffen},
      	       booktitle={ICDM},
               pages={995--1000},
	       year={2010},
	       organization={IEEE}
}
Owner
Jack Hessel
Research Scientist @ AI2: PhD in CS previously from Cornell
Jack Hessel
The source code and dataset for the RecGURU paper (WSDM 2022)

RecGURU About The Project Source code and baselines for the RecGURU paper "RecGURU: Adversarial Learning of Generalized User Representations for Cross

Chenglin Li 17 Jan 07, 2023
Replication of Pix2Seq with Pretrained Model

Pretrained-Pix2Seq We provide the pre-trained model of Pix2Seq. This version contains new data augmentation. The model is trained for 300 epochs and c

peng gao 51 Nov 22, 2022
Tensorforce: a TensorFlow library for applied reinforcement learning

Tensorforce: a TensorFlow library for applied reinforcement learning Introduction Tensorforce is an open-source deep reinforcement learning framework,

Tensorforce 3.2k Jan 02, 2023
Tensorflow implementation of soft-attention mechanism for video caption generation.

SA-tensorflow Tensorflow implementation of soft-attention mechanism for video caption generation. An example of soft-attention mechanism. The attentio

Paul Chen 153 Nov 14, 2022
A facial recognition doorbell system using a Raspberry Pi

Facial Recognition Doorbell This project expands on the person-detecting doorbell system to allow it to identify faces, and announce names accordingly

rydercalmdown 22 Apr 15, 2022
Official implementation of Rethinking Graph Neural Architecture Search from Message-passing (CVPR2021)

Rethinking Graph Neural Architecture Search from Message-passing Intro The GNAS can automatically learn better architecture with the optimal depth of

Shaofei Cai 48 Sep 30, 2022
Code Impementation for "Mold into a Graph: Efficient Bayesian Optimization over Mixed Spaces"

Code Impementation for "Mold into a Graph: Efficient Bayesian Optimization over Mixed Spaces" This repo contains the implementation of GEBO algorithm.

Jaeyeon Ahn 2 Mar 22, 2022
Deep Sketch-guided Cartoon Video Inbetweening

Cartoon Video Inbetweening Paper | DOI | Video The source code of Deep Sketch-guided Cartoon Video Inbetweening by Xiaoyu Li, Bo Zhang, Jing Liao, Ped

Xiaoyu Li 37 Dec 22, 2022
A collection of Jupyter notebooks to play with NVIDIA's StyleGAN3 and OpenAI's CLIP for a text-based guided image generation.

StyleGAN3 CLIP-based guidance StyleGAN3 + CLIP StyleGAN3 + inversion + CLIP This repo is a collection of Jupyter notebooks made to easily play with St

Eugenio Herrera 176 Dec 30, 2022
Author Disambiguation using Knowledge Graph Embeddings with Literals

Author Name Disambiguation with Knowledge Graph Embeddings using Literals This is the repository for the master thesis project on Knowledge Graph Embe

12 Oct 19, 2022
This is the pytorch implementation for the paper: Generalizable Mixed-Precision Quantization via Attribution Rank Preservation, which is accepted to ICCV2021.

GMPQ: Generalizable Mixed-Precision Quantization via Attribution Rank Preservation This is the pytorch implementation for the paper: Generalizable Mix

18 Sep 02, 2022
Attention-based Transformation from Latent Features to Point Clouds (AAAI 2022)

Attention-based Transformation from Latent Features to Point Clouds This repository contains a PyTorch implementation of the paper: Attention-based Tr

12 Nov 11, 2022
Breast cancer is been classified into benign tumour and malignant tumour.

Breast cancer is been classified into benign tumour and malignant tumour. Logistic regression is applied in this model.

1 Feb 04, 2022
MazeRL is an application oriented Deep Reinforcement Learning (RL) framework

MazeRL is an application oriented Deep Reinforcement Learning (RL) framework, addressing real-world decision problems. Our vision is to cover the complete development life cycle of RL applications ra

EnliteAI GmbH 222 Dec 24, 2022
Numerical-computing-is-fun - Learning numerical computing with notebooks for all ages.

As much as this series is to educate aspiring computer programmers and data scientists of all ages and all backgrounds, it is also a reminder to mysel

EKA foundation 758 Dec 25, 2022
Get the partition that a file belongs and the percentage of space that consumes

tinos_eisai_sy Get the partition that a file belongs and the percentage of space that consumes (works only with OSes that use the df command) tinos_ei

Konstantinos Patronas 6 Jan 24, 2022
OREO: Object-Aware Regularization for Addressing Causal Confusion in Imitation Learning (NeurIPS 2021)

OREO: Object-Aware Regularization for Addressing Causal Confusion in Imitation Learning (NeurIPS 2021) Video demo We here provide a video demo from co

20 Nov 25, 2022
GEP (GDB Enhanced Prompt) - a GDB plug-in for GDB command prompt with fzf history search, fish-like autosuggestions, auto-completion with floating window, partial string matching in history, and more!

GEP (GDB Enhanced Prompt) GEP (GDB Enhanced Prompt) is a GDB plug-in which make your GDB command prompt more convenient and flexibility. Why I need th

Alan Li 23 Dec 21, 2022
Simulating an AI playing 2048 using the Expectimax algorithm

2048-expectimax Simulating an AI playing 2048 using the Expectimax algorithm The base game engine uses code from here. The AI player is modeled as a m

Subha Ramesh 2 Jan 31, 2022
rastrainer is a QGIS plugin to training remote sensing semantic segmentation model based on PaddlePaddle.

rastrainer rastrainer is a QGIS plugin to training remote sensing semantic segmentation model based on PaddlePaddle. UI TODO Init UI. Add Block. Add l

deepbands 5 Mar 04, 2022