The official repository for our paper "The Neural Data Router: Adaptive Control Flow in Transformers Improves Systematic Generalization".

Overview

Codebase for learning control flow in transformers

The official repository for our paper "The Neural Data Router: Adaptive Control Flow in Transformers Improves Systematic Generalization".

Paper: https://arxiv.org/abs/2110.07732

Please note that this repository is a cleaned-up version of the internal research repository we use. In case you encounter any problems with it, please don't hesitate to contact me.

Setup

This project requires Python 3 (tested with Python 3.8 and 3.9) and PyTorch 1.8.

pip3 install -r requirements.txt

Create a Weights and Biases account and run

wandb login

More information on setting up Weights and Biases can be found on https://docs.wandb.com/quickstart.

For plotting, LaTeX is required (to avoid Type 3 fonts and to render symbols). Installation is OS specific.

Usage

Running the experiments from the paper on a cluster

The code makes use of Weights and Biases for experiment tracking. In the sweeps directory, we provide sweep configurations for all experiments we have performed. The sweeps are officially meant for hyperparameter optimization, but we use them to run multiple configurations and seeds.

To reproduce our results, start a sweep for each of the YAML files in the sweeps directory. Run wandb agent for each of them in the root directory of the project. This will run all the experiments, and they will be displayed on the W&B dashboard. The name of the sweeps must match the name of the files in sweeps directory, except the .yaml ending. More details on how to run W&B sweeps can be found at https://docs.wandb.com/sweeps/quickstart. If you want to use a Linux cluster to run the experiments, you might find https://github.com/robertcsordas/cluster_tool useful.

For example, if you want to run NDR on compositional table lookup, run wandb sweep --name ctl_ndr sweeps/ctl_ndr.yaml. This creates the sweep and prints out its ID. Then run wandb agent <ID> with that ID.

Re-creating plots from the paper

Edit config file paper/config.json. Enter your project name in the field "wandb_project" (e.g. "username/project").

Run the scripts in the paper directory. For example:

cd paper
./run_all.sh

The output will be generated in the paper/out/ directory. Tables will be printed to stdout in latex format.

If you want to reproduce individual plots, it can be done by running individial python files in the paper directory.

Running experiments locally

It is possible to run single experiments with Tensorboard without using Weights and Biases. This is intended to be used for debugging the code locally.

If you want to run experiments locally, you can use run.py:

./run.py sweeps/ctl_ndr.yaml

If the sweep in question has multiple parameter choices, run.py will interactively prompt choices of each of them.

The experiment also starts a Tensorboard instance automatically on port 7000. If the port is already occupied, it will incrementally search for the next free port.

Note that the plotting scripts work only with Weights and Biases.

Reducing memory usage

In case some tasks won't fit on your GPU, play around with "-max_length_per_batch " argument. It can trade off memory usage/speed by slicing batches and executing them in multiple passes. Reduce it until the model fits.

BibText

@article{csordas2021neural,
      title={The Neural Data Router: Adaptive Control Flow in Transformers Improves Systematic Generalization}, 
      author={R\'obert Csord\'as and Kazuki Irie and J\"urgen Schmidhuber},
      journal={Preprint arXiv:2110.07732},
      year={2021},
      month={October}
}
Owner
Csordás Róbert
Csordás Róbert
Tensorflow port of a full NetVLAD network

netvlad_tf The main intention of this repo is deployment of a full NetVLAD network, which was originally implemented in Matlab, in Python. We provide

Robotics and Perception Group 225 Nov 08, 2022
FID calculation with proper image resizing and quantization steps

clean-fid: Fixing Inconsistencies in FID Project | Paper The FID calculation involves many steps that can produce inconsistencies in the final metric.

Gaurav Parmar 606 Jan 06, 2023
QI-Q RoboMaster2022 CV Algorithm

QI-Q RoboMaster2022 CV Algorithm

2 Jan 10, 2022
MassiveSumm: a very large-scale, very multilingual, news summarisation dataset

MassiveSumm: a very large-scale, very multilingual, news summarisation dataset This repository contains links to data and code to fetch and reproduce

Daniel Varab 19 Dec 16, 2022
The Unreasonable Effectiveness of Random Pruning: Return of the Most Naive Baseline for Sparse Training

[ICLR 2022] The Unreasonable Effectiveness of Random Pruning: Return of the Most Naive Baseline for Sparse Training The Unreasonable Effectiveness of

VITA 44 Dec 23, 2022
This repository is an open-source implementation of the ICRA 2021 paper: Locus: LiDAR-based Place Recognition using Spatiotemporal Higher-Order Pooling.

Locus This repository is an open-source implementation of the ICRA 2021 paper: Locus: LiDAR-based Place Recognition using Spatiotemporal Higher-Order

Robotics and Autonomous Systems Group 96 Dec 15, 2022
Code Repository for Liquid Time-Constant Networks (LTCs)

Liquid time-constant Networks (LTCs) [Update] A Pytorch version is added in our sister repository: https://github.com/mlech26l/keras-ncp This is the o

Ramin Hasani 553 Dec 27, 2022
3D dataset of humans Manipulating Objects in-the-Wild (MOW)

MOW dataset [Website] This repository maintains our 3D dataset of humans Manipulating Objects in-the-Wild (MOW). The dataset contains 512 images in th

Zhe Cao 28 Nov 06, 2022
Semi-supervised Adversarial Learning to Generate Photorealistic Face Images of New Identities from 3D Morphable Model

Semi-supervised Adversarial Learning to Generate Photorealistic Face Images of New Identities from 3D Morphable Model Baris Gecer 1, Binod Bhattarai 1

Baris Gecer 190 Dec 29, 2022
TalkNet 2: Non-Autoregressive Depth-Wise Separable Convolutional Model for Speech Synthesis with Explicit Pitch and Duration Prediction.

TalkNet 2 [WIP] TalkNet 2: Non-Autoregressive Depth-Wise Separable Convolutional Model for Speech Synthesis with Explicit Pitch and Duration Predictio

Rishikesh (ऋषिकेश) 69 Dec 17, 2022
Improving XGBoost survival analysis with embeddings and debiased estimators

xgbse: XGBoost Survival Embeddings "There are two cultures in the use of statistical modeling to reach conclusions from data

Loft 242 Dec 30, 2022
Code for KHGT model, AAAI2021

KHGT Code for KHGT accepted by AAAI2021 Please unzip the data files in Datasets/ first. To run KHGT on Yelp data, use python labcode_yelp.py For Movi

32 Nov 29, 2022
StocksMA is a package to facilitate access to financial and economic data of Moroccan stocks.

Creating easier access to the Moroccan stock market data What is StocksMA ? StocksMA is a package to facilitate access to financial and economic data

Salah Eddine LABIAD 28 Jan 04, 2023
Rohit Ingole 2 Mar 24, 2022
Weakly Supervised Learning of Rigid 3D Scene Flow

Weakly Supervised Learning of Rigid 3D Scene Flow This repository provides code and data to train and evaluate a weakly supervised method for rigid 3D

Zan Gojcic 124 Dec 27, 2022
PyTorch implementation of Deep HDR Imaging via A Non-Local Network (TIP 2020).

NHDRRNet-PyTorch This is the PyTorch implementation of Deep HDR Imaging via A Non-Local Network (TIP 2020). 0. Differences between Original Paper and

Yutong Zhang 1 Mar 01, 2022
CenterNet:Objects as Points目标检测模型在Pytorch当中的实现

CenterNet:Objects as Points目标检测模型在Pytorch当中的实现

Bubbliiiing 267 Dec 29, 2022
Conversational text Analysis using various NLP techniques

PyConverse Let me try first Installation pip install pyconverse Usage Please try this notebook that demos the core functionalities: basic usage noteb

Rita Anjana 158 Dec 25, 2022
Face Mask Detection is a project to determine whether someone is wearing mask or not, using deep neural network.

face-mask-detection Face Mask Detection is a project to determine whether someone is wearing mask or not, using deep neural network. It contains 3 scr

amirsalar 13 Jan 18, 2022
This is the code repository for the paper A hierarchical semantic segmentation framework for computer-vision-based bridge column damage detection

Bridge-damage-segmentation This is the code repository for the paper A hierarchical semantic segmentation framework for computer-vision-based bridge c

Jingxiao Liu 5 Dec 07, 2022