Implementation of TabTransformer, an attention network for tabular data, in PyTorch

Overview

Tab Transformer

Implementation of Tab Transformer, an attention network for tabular data, in PyTorch. This simple architecture came within a hair's breadth of the performance of gradient-boosted decision trees (GBDT).

Install

$ pip install tab-transformer-pytorch

Usage

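import torch
from torch import nn                             # needed for nn.ReLU() below
from tab_transformer_pytorch import TabTransformer

cont_mean_std = torch.randn(10, 2)               # placeholder (num_continuous, 2) statistics, one row per continuous feature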

model = TabTransformer(
    categories = (10, 5, 6, 5, 8),      # tuple containing the number of unique values within each category
    num_continuous = 10,                # number of continuous values
    dim = 32,                           # dimension, paper set at 32
    dim_out = 1,                        # binary prediction, but could be anything
    depth = 6,                          # depth, paper recommended 6
    heads = 8,                          # heads, paper recommends 8
    attn_dropout = 0.1,                 # post-attention dropout
    ff_dropout = 0.1,                   # feed forward dropout
    mlp_hidden_mults = (4, 2),          # relative multiples of each hidden dimension of the last mlp to logits
    mlp_act = nn.ReLU(),                # activation for final mlp, defaults to relu, but could be anything else (selu etc)
    continuous_mean_std = cont_mean_std # (optional) - normalize the continuous values before layer norm
)

x_categ = torch.randint(0, 5, (1, 5))     # category values, from 0 - max number of categories, in the order as passed into the constructor above
x_cont = torch.randn(1, 10)               # assume continuous values are already normalized individually

pred = model(x_categ, x_cont)
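The cont_mean_std tensor above is random placeholder data. Assuming, as the example's (10, 2) shape suggests, that continuous_mean_std holds one (mean, std) row per continuous feature with the mean first, the statistics could be computed from training data roughly as follows. compute_mean_std is a hypothetical stdlib-only helper, not part of tab-transformer-pytorch; its result could be wrapped with torch.tensor(...) to obtain a (num_continuous, 2) tensor.

```python
import statistics


def compute_mean_std(rows):
    """Per-feature [mean, std] pairs from training rows (hypothetical helper).

    rows: list of equal-length lists, one per training example, holding only
    the continuous features. Returns one [mean, std] pair per feature.
    """
    stats = []
    for column in zip(*rows):  # iterate feature by feature
        mean = statistics.fmean(column)
        std = statistics.pstdev(column)
        # Guard constant features so a later (x - mean) / std cannot divide by zero.
        stats.append([mean, std or 1.0])
    return stats


train_rows = [[1.0, 10.0], [3.0, 10.0]]
print(compute_mean_std(train_rows))  # [[2.0, 1.0], [10.0, 1.0]]
```

The population standard deviation (statistics.pstdev) is used here; the sample version (statistics.stdev) would be an equally reasonable choice.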

Unsupervised Training

To perform the type of unsupervised training described in the paper, first convert your category tokens to the appropriate unique ids, then use Electra on model.transformer.
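One way to do the id conversion is sketched below. It assumes each categorical column already holds integers in [0, n) for that column's n unique values, and that every column is shifted by any reserved special-token ids plus the sizes of all preceding columns, so no two columns share an id. category_offsets and to_unique_ids are hypothetical helpers, not functions exported by tab-transformer-pytorch.

```python
from itertools import accumulate


def category_offsets(categories, num_special_tokens=0):
    """Starting id of each categorical column in one shared vocabulary (hypothetical helper)."""
    # Column 0 starts right after the reserved special-token ids; every later
    # column starts after all values of the columns that precede it.
    return list(accumulate([num_special_tokens, *categories[:-1]]))


def to_unique_ids(row, categories, num_special_tokens=0):
    """Convert per-column category indices into globally unique token ids."""
    if len(row) != len(categories):
        raise ValueError("row must have one value per categorical column")
    offsets = category_offsets(categories, num_special_tokens)
    ids = []
    for value, size, offset in zip(row, categories, offsets):
        if not 0 <= value < size:
            raise ValueError(f"value {value} outside [0, {size})")
        ids.append(value + offset)
    return ids


categories = (10, 5, 6, 5, 8)
print(category_offsets(categories))                                      # [0, 10, 15, 21, 26]
print(to_unique_ids([0, 4, 5, 1, 7], categories, num_special_tokens=2))  # [2, 16, 22, 24, 35]
```

With this layout the largest possible id is num_special_tokens + sum(categories) - 1, which fixes the vocabulary size a shared token embedding would need.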

Citations

@misc{huang2020tabtransformer,
    title={TabTransformer: Tabular Data Modeling Using Contextual Embeddings}, 
    author={Xin Huang and Ashish Khetan and Milan Cvitkovic and Zohar Karnin},
    year={2020},
    eprint={2012.06678},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
Comments
  • Minor Bug: activation function being applied to output layer in class MLP

    The code for class MLP mistakenly applies the activation function to the last (i.e., output) layer. The error is in how the is_last flag is evaluated. The current code is:

    class MLP(nn.Module):
        def __init__(self, dims, act = None):
            super().__init__()
            dims_pairs = list(zip(dims[:-1], dims[1:]))
            layers = []
            for ind, (dim_in, dim_out) in enumerate(dims_pairs):
                is_last = ind >= (len(dims) - 1)
    

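    The last line should be changed to is_last = ind >= (len(dims) - 2). Since dims_pairs has len(dims) - 1 entries, the largest value of ind is len(dims) - 2, so the current condition is never true: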

    class MLP(nn.Module):
        def __init__(self, dims, act = None):
            super().__init__()
            dims_pairs = list(zip(dims[:-1], dims[1:]))
            layers = []
            for ind, (dim_in, dim_out) in enumerate(dims_pairs):
                is_last = ind >= (len(dims) - 2)
    

    If you like, I can do a pull request.

    opened by rminhas 1
  • Update tab_transformer_pytorch.py

    Apply the activation function outside the loop, once for the whole model, rather than after each of the linear layers. The 'if is_last' condition was producing a linear output every time, regardless of the activation function.

    opened by EveryoneDirn 0
  • Unindent continuous_mean_std buffer

    Problem: continuous_mean_std does not become an attribute of TabTransformer unless it is passed explicitly as an argument. Example reproducing the AttributeError:

    import torch
    from torch import nn
    from tab_transformer_pytorch import TabTransformer

    model = TabTransformer(
        categories = (10, 5, 6, 5, 8),      # tuple containing the number of unique values within each category
        num_continuous = 10,                # number of continuous values
        dim = 32,                           # dimension, paper set at 32
        dim_out = 1,                        # binary prediction, but could be anything
        depth = 6,                          # depth, paper recommended 6
        heads = 8,                          # heads, paper recommends 8
        attn_dropout = 0.1,                 # post-attention dropout
        ff_dropout = 0.1,                   # feed forward dropout
        mlp_hidden_mults = (4, 2),          # relative multiples of each hidden dimension of the last mlp to logits
        mlp_act = nn.ReLU(),                # activation for final mlp, defaults to relu, but could be anything else (selu etc)
        # continuous_mean_std = cont_mean_std,  # (optional) - normalize the continuous values before layer norm
    )
    x_categ = torch.randint(0, 5, (1, 5))     # category values, from 0 - max number of categories, in the order as passed into the constructor above
    x_cont = torch.randn(1, 10)               # assume continuous values are already normalized individually
    pred = model(x_categ, x_cont)             # raises AttributeError

    Solution: un-indent the buffer registration of continuous_mean_std so that it runs whether or not the argument is provided.

    opened by spliew 0
  • Low GPU usage

    Hi.

    I'm having a problem running your code on my dataset: it is quite slow. GPU utilization averages 50%, and each epoch takes almost 900 seconds.

    My dataset has 590,540 rows, 24 categorical features, and 192 continuous features. Categories are encoded with a label encoder. The total dataset size is around 600 MB. My GPU is an integrated NVIDIA RTX 3060 with 6 GB of memory. The optimizer is Adam.

    These are the software versions:

    Windows 10
    Python: 3.7.11
    PyTorch: 1.7.0+cu110
    NumPy: 1.21.2

    Let me know if you need more info from my side.

    Thanks.

    Xin.

    opened by xinqiao123 0
  • Intended usage of num_special_tokens?

    From what I understand, these are reserved for out-of-vocabulary (OOV) values. Is the intended usage to set OOV values in the input to some negative number that overrides the offset? That seems to be what it would take to get the desired outcome, but it also seems confusing and clunky. Or am I misunderstanding its purpose? Thanks!

    opened by LLYX 2
  • No Category Shared Embedding?

    I noticed that this implementation does not seem to include the embedding shared by all values belonging to the same category (c_phi_i) that the paper describes, unless I missed it. If it is indeed missing, do you plan to add it?

    Thanks for this implementation!

    opened by LLYX 3
  • index -1 is out of bounds for dimension 1 with size 17

    I encountered this error during training. What could be causing it, and how can I fix it? Thanks!

      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 583, in forward
        return self.tabnet(x)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 468, in forward
        steps_output, M_loss = self.encoder(x)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 160, in forward
        M = self.att_transformers[step](prior, att)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py", line 637, in forward
        x = self.selector(x)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/sparsemax.py", line 109, in forward
        return sparsemax(input, self.dim)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/sparsemax.py", line 52, in forward
        tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim)
      File "/home/zhanghz/miniforge3/lib/python3.8/site-packages/pytorch_tabnet/sparsemax.py", line 94, in _threshold_and_support
        tau = input_cumsum.gather(dim, support_size - 1)
    RuntimeError: index -1 is out of bounds for dimension 1 with size 17
    Experiment has terminated.
    
    opened by hengzhe-zhang 2
  • Is there a training example for TabTransformer?

    Hi, I want to use it for supervised learning on a tabular dataset, but I don't know how to train this model on my data (the README does not seem to cover this). Could you please help me? Thank you.

    opened by pancodex 0
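The off-by-one in the "Minor Bug" report on class MLP can be checked without PyTorch. The sketch below is a hypothetical plain-Python stand-in for the constructor loop quoted in that report: it records layer names instead of building nn.Linear modules, and it assumes, as the report implies, that the activation is skipped only when is_last is true. last_offset=1 reproduces the reported condition and last_offset=2 the proposed fix.

```python
def layer_plan(dims, last_offset):
    """Hypothetical stand-in for the MLP loop: returns layer names instead of modules."""
    dims_pairs = list(zip(dims[:-1], dims[1:]))
    plan = []
    for ind, (dim_in, dim_out) in enumerate(dims_pairs):
        is_last = ind >= (len(dims) - last_offset)
        plan.append(f"Linear({dim_in}, {dim_out})")
        if is_last:
            continue  # no activation after the output layer
        plan.append("act")
    return plan


dims = [64, 256, 128, 1]
print(layer_plan(dims, last_offset=1))
# ['Linear(64, 256)', 'act', 'Linear(256, 128)', 'act', 'Linear(128, 1)', 'act']
print(layer_plan(dims, last_offset=2))
# ['Linear(64, 256)', 'act', 'Linear(256, 128)', 'act', 'Linear(128, 1)']
```

An equivalent and arguably clearer condition is ind == len(dims_pairs) - 1.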
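The AttributeError in the "Unindent continuous_mean_std buffer" report comes from creating an attribute only inside a conditional branch and then reading it unconditionally. The plain-Python analogue below uses hypothetical classes, not the library's code, to reproduce that pattern and the un-indented fix. In PyTorch, nn.Module.register_buffer also accepts None as the tensor, so a buffer can likewise be registered unconditionally.

```python
class BuggyNormalizer:
    def __init__(self, continuous_mean_std=None):
        if continuous_mean_std is not None:
            # Attribute is only created inside this branch (the indentation problem).
            self.continuous_mean_std = continuous_mean_std

    def __call__(self, x):
        # Raises AttributeError when no statistics were passed to __init__.
        if self.continuous_mean_std is not None:
            return [(v - mean) / std for v, (mean, std) in zip(x, self.continuous_mean_std)]
        return x


class FixedNormalizer(BuggyNormalizer):
    def __init__(self, continuous_mean_std=None):
        # Assigned unconditionally, so the attribute is None when omitted.
        self.continuous_mean_std = continuous_mean_std


try:
    BuggyNormalizer()([1.0, 2.0])
except AttributeError as err:
    print("buggy:", err)
print(FixedNormalizer()([1.0, 2.0]))                          # [1.0, 2.0]
print(FixedNormalizer([(0.0, 2.0), (1.0, 1.0)])([4.0, 3.0]))  # [2.0, 2.0]
```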
Owner
Phil Wang
Working with Attention. It's all we need.