Neural Turing Machines (NTM) - PyTorch Implementation

Overview

PyTorch Neural Turing Machine (NTM)

PyTorch implementation of Neural Turing Machines (NTM).

An NTM is a memory augumented neural network (attached to external memory) where the interactions with the external memory (address, read, write) are done using differentiable transformations. Overall, the network is end-to-end differentiable and thus trainable by a gradient based optimizer.

The NTM is processing input in sequences, much like an LSTM, but with additional benfits: (1) The external memory allows the network to learn algorithmic tasks easier (2) Having larger capacity, without increasing the network's trainable parameters.

The external memory allows the NTM to learn algorithmic tasks, that are much harder for LSTM to learn, and to maintain an internal state much longer than traditional LSTMs.

A PyTorch Implementation

This repository implements a vanilla NTM in a straight forward way. The following architecture is used:

NTM Architecture

Features

  • Batch learning support
  • Numerically stable
  • Flexible head configuration - use X read heads and Y write heads and specify the order of operation
  • copy and repeat-copy experiments agree with the paper

Copy Task

The Copy task tests the NTM's ability to store and recall a long sequence of arbitrary information. The input to the network is a random sequence of bits, ending with a delimiter. The sequence lengths are randomised between 1 to 20.

Training

Training convergence for the copy task using 4 different seeds (see the notebook for details)

NTM Convergence

The following plot shows the cost per sequence length during training. The network was trained with seed=10 and shows fast convergence. Other seeds may not perform as well but should converge in less than 30K iterations.

NTM Convergence

Evaluation

Here is an animated GIF that shows how the model generalize. The model was evaluated after every 500 training samples, using the target sequence shown in the upper part of the image. The bottom part shows the network output at any given training stage.

Copy Task

The following is the same, but with sequence length = 80. Note that the network was trained with sequences of lengths 1 to 20.

Copy Task


Repeat Copy Task

The Repeat Copy task tests whether the NTM can learn a simple nested function, and invoke it by learning to execute a for loop. The input to the network is a random sequence of bits, followed by a delimiter and a scalar value that represents the number of repetitions to output. The number of repetitions, was normalized to have zero mean and variance of one (as in the paper). Both the length of the sequence and the number of repetitions are randomised between 1 to 10.

Training

Training convergence for the repeat-copy task using 4 different seeds (see the notebook for details)

NTM Convergence

Evaluation

The following image shows the input presented to the network, a sequence of bits + delimiter + num-reps scalar. Specifically the sequence length here is eight and the number of repetitions is five.

Repeat Copy Task

And here's the output the network had predicted:

Repeat Copy Task

Here's an animated GIF that shows how the network learns to predict the targets. Specifically, the network was evaluated in each checkpoint saved during training with the same input sequence.

Repeat Copy Task

Installation

The NTM can be used as a reusable module, currently not packaged though.

  1. Clone repository
  2. Install PyTorch
  3. pip install -r requirements.txt

Usage

Execute ./train.py

usage: train.py [-h] [--seed SEED] [--task {copy,repeat-copy}] [-p PARAM]
                [--checkpoint-interval CHECKPOINT_INTERVAL]
                [--checkpoint-path CHECKPOINT_PATH]
                [--report-interval REPORT_INTERVAL]

optional arguments:
  -h, --help            show this help message and exit
  --seed SEED           Seed value for RNGs
  --task {copy,repeat-copy}
                        Choose the task to train (default: copy)
  -p PARAM, --param PARAM
                        Override model params. Example: "-pbatch_size=4
                        -pnum_heads=2"
  --checkpoint-interval CHECKPOINT_INTERVAL
                        Checkpoint interval (default: 1000). Use 0 to disable
                        checkpointing
  --checkpoint-path CHECKPOINT_PATH
                        Path for saving checkpoint data (default: './')
  --report-interval REPORT_INTERVAL
                        Reporting interval
Owner
Guy Zana
I make things, author of Curated Papers
Guy Zana
repro_eval is a collection of measures to evaluate the reproducibility/replicability of system-oriented IR experiments

repro_eval repro_eval is a collection of measures to evaluate the reproducibility/replicability of system-oriented IR experiments. The measures were d

IR Group at Technische Hochschule Köln 9 May 25, 2022
A comprehensive list of published machine learning applications to cosmology

ml-in-cosmology This github attempts to maintain a comprehensive list of published machine learning applications to cosmology, organized by subject ma

George Stein 290 Dec 29, 2022
The codes and models in 'Gaze Estimation using Transformer'.

GazeTR We provide the code of GazeTR-Hybrid in "Gaze Estimation using Transformer". We recommend you to use data processing codes provided in GazeHub.

65 Dec 27, 2022
A Quick and Dirty Progressive Neural Network written in TensorFlow.

prog_nn .▄▄ · ▄· ▄▌ ▐ ▄ ▄▄▄· ▐ ▄ ▐█ ▀. ▐█▪██▌•█▌▐█▐█ ▄█▪ •█▌▐█ ▄▀▀▀█▄▐█▌▐█▪▐█▐▐▌ ██▀

SynPon 53 Dec 12, 2022
Satellite labelling tool for manual labelling of storm top features such as overshooting tops, above-anvil plumes, cold U/Vs, rings etc.

Satellite labelling tool About this app A tool for manual labelling of storm top features such as overshooting tops, above-anvil plumes, cold U/Vs, ri

Czech Hydrometeorological Institute - Satellite Department 10 Sep 14, 2022
Composable transformations of Python+NumPy programsComposable transformations of Python+NumPy programs

Chex Chex is a library of utilities for helping to write reliable JAX code. This includes utils to help: Instrument your code (e.g. assertions) Debug

DeepMind 506 Jan 08, 2023
PyTorch code of my ICDAR 2021 paper Vision Transformer for Fast and Efficient Scene Text Recognition (ViTSTR)

Vision Transformer for Fast and Efficient Scene Text Recognition (ICDAR 2021) ViTSTR is a simple single-stage model that uses a pre-trained Vision Tra

Rowel Atienza 198 Dec 27, 2022
Source codes of CenterTrack++ in 2021 ICME Workshop on Big Surveillance Data Processing and Analysis

MOT Tracked object bounding box association (CenterTrack++) New association method based on CenterTrack. Two new branches (Tracked Size and IOU) are a

36 Oct 04, 2022
Official code for "EagerMOT: 3D Multi-Object Tracking via Sensor Fusion" [ICRA 2021]

EagerMOT: 3D Multi-Object Tracking via Sensor Fusion Read our ICRA 2021 paper here. Check out the 3 minute video for the quick intro or the full prese

Aleksandr Kim 276 Dec 30, 2022
[CVPR-2021] UnrealPerson: An adaptive pipeline for costless person re-identification

UnrealPerson: An Adaptive Pipeline for Costless Person Re-identification In our paper (arxiv), we propose a novel pipeline, UnrealPerson, that decreas

ZhangTianyu 70 Oct 10, 2022
Fewshot-face-translation-GAN - Generative adversarial networks integrating modules from FUNIT and SPADE for face-swapping.

Few-shot face translation A GAN based approach for one model to swap them all. The table below shows our priliminary face-swapping results requiring o

768 Dec 24, 2022
Repository for scripts and notebooks from the book: Programming PyTorch for Deep Learning

Repository for scripts and notebooks from the book: Programming PyTorch for Deep Learning

Ian Pointer 368 Dec 17, 2022
GeoMol: Torsional Geometric Generation of Molecular 3D Conformer Ensembles

GeoMol: Torsional Geometric Generation of Molecular 3D Conformer Ensembles This repository contains a method to generate 3D conformer ensembles direct

127 Dec 20, 2022
用强化学习DQN算法,训练AI模型来玩合成大西瓜游戏,提供Keras版本和PARL(paddle)版本

用强化学习玩合成大西瓜 代码地址:https://github.com/Sharpiless/play-daxigua-using-Reinforcement-Learning 用强化学习DQN算法,训练AI模型来玩合成大西瓜游戏,提供Keras版本、PARL(paddle)版本和pytorch版本

72 Dec 17, 2022
This is the official PyTorch implementation of our paper: "Artistic Style Transfer with Internal-external Learning and Contrastive Learning".

Artistic Style Transfer with Internal-external Learning and Contrastive Learning This is the official PyTorch implementation of our paper: "Artistic S

51 Dec 20, 2022
ZSL-KG is a general-purpose zero-shot learning framework with a novel transformer graph convolutional network (TrGCN) to learn class representation from common sense knowledge graphs.

ZSL-KG is a general-purpose zero-shot learning framework with a novel transformer graph convolutional network (TrGCN) to learn class representa

Bats Research 94 Nov 21, 2022
An off-line judger supporting distributed problem repositories

Thaw 中文 | English Thaw is an off-line judger supporting distributed problem repositories. Everyone can use Thaw release problems with license on GitHu

countercurrent_time 2 Jan 09, 2022
SOLOv2 on onnx & tensorRT

SOLOv2.tensorRT: NOTE: code based on WXinlong/SOLO add support to TensorRT inference onnxruntime tensorRT full_dims and dynamic shape postprocess with

47 Nov 26, 2022
discovering subdomains, hidden paths, extracting unique links

python-website-crawler discovering subdomains, hidden paths, extracting unique links pip install -r requirements.txt discover subdomain: You can give

merve 4 Sep 05, 2022
Summary Explorer is a tool to visually explore the state-of-the-art in text summarization.

Summary Explorer Summary Explorer is a tool to visually inspect the summaries from several state-of-the-art neural summarization models across multipl

Webis 42 Aug 14, 2022