Neural Turing Machines (NTM) - PyTorch Implementation

Overview

PyTorch Neural Turing Machine (NTM)

PyTorch implementation of Neural Turing Machines (NTM).

An NTM is a memory augumented neural network (attached to external memory) where the interactions with the external memory (address, read, write) are done using differentiable transformations. Overall, the network is end-to-end differentiable and thus trainable by a gradient based optimizer.

The NTM is processing input in sequences, much like an LSTM, but with additional benfits: (1) The external memory allows the network to learn algorithmic tasks easier (2) Having larger capacity, without increasing the network's trainable parameters.

The external memory allows the NTM to learn algorithmic tasks, that are much harder for LSTM to learn, and to maintain an internal state much longer than traditional LSTMs.

A PyTorch Implementation

This repository implements a vanilla NTM in a straight forward way. The following architecture is used:

NTM Architecture

Features

  • Batch learning support
  • Numerically stable
  • Flexible head configuration - use X read heads and Y write heads and specify the order of operation
  • copy and repeat-copy experiments agree with the paper

Copy Task

The Copy task tests the NTM's ability to store and recall a long sequence of arbitrary information. The input to the network is a random sequence of bits, ending with a delimiter. The sequence lengths are randomised between 1 to 20.

Training

Training convergence for the copy task using 4 different seeds (see the notebook for details)

NTM Convergence

The following plot shows the cost per sequence length during training. The network was trained with seed=10 and shows fast convergence. Other seeds may not perform as well but should converge in less than 30K iterations.

NTM Convergence

Evaluation

Here is an animated GIF that shows how the model generalize. The model was evaluated after every 500 training samples, using the target sequence shown in the upper part of the image. The bottom part shows the network output at any given training stage.

Copy Task

The following is the same, but with sequence length = 80. Note that the network was trained with sequences of lengths 1 to 20.

Copy Task


Repeat Copy Task

The Repeat Copy task tests whether the NTM can learn a simple nested function, and invoke it by learning to execute a for loop. The input to the network is a random sequence of bits, followed by a delimiter and a scalar value that represents the number of repetitions to output. The number of repetitions, was normalized to have zero mean and variance of one (as in the paper). Both the length of the sequence and the number of repetitions are randomised between 1 to 10.

Training

Training convergence for the repeat-copy task using 4 different seeds (see the notebook for details)

NTM Convergence

Evaluation

The following image shows the input presented to the network, a sequence of bits + delimiter + num-reps scalar. Specifically the sequence length here is eight and the number of repetitions is five.

Repeat Copy Task

And here's the output the network had predicted:

Repeat Copy Task

Here's an animated GIF that shows how the network learns to predict the targets. Specifically, the network was evaluated in each checkpoint saved during training with the same input sequence.

Repeat Copy Task

Installation

The NTM can be used as a reusable module, currently not packaged though.

  1. Clone repository
  2. Install PyTorch
  3. pip install -r requirements.txt

Usage

Execute ./train.py

usage: train.py [-h] [--seed SEED] [--task {copy,repeat-copy}] [-p PARAM]
                [--checkpoint-interval CHECKPOINT_INTERVAL]
                [--checkpoint-path CHECKPOINT_PATH]
                [--report-interval REPORT_INTERVAL]

optional arguments:
  -h, --help            show this help message and exit
  --seed SEED           Seed value for RNGs
  --task {copy,repeat-copy}
                        Choose the task to train (default: copy)
  -p PARAM, --param PARAM
                        Override model params. Example: "-pbatch_size=4
                        -pnum_heads=2"
  --checkpoint-interval CHECKPOINT_INTERVAL
                        Checkpoint interval (default: 1000). Use 0 to disable
                        checkpointing
  --checkpoint-path CHECKPOINT_PATH
                        Path for saving checkpoint data (default: './')
  --report-interval REPORT_INTERVAL
                        Reporting interval
Owner
Guy Zana
I make things, author of Curated Papers
Guy Zana
i-SpaSP: Structured Neural Pruning via Sparse Signal Recovery

i-SpaSP: Structured Neural Pruning via Sparse Signal Recovery This is a public code repository for the publication: i-SpaSP: Structured Neural Pruning

Cameron Ronald Wolfe 5 Nov 04, 2022
Inflated i3d network with inception backbone, weights transfered from tensorflow

I3D models transfered from Tensorflow to PyTorch This repo contains several scripts that allow to transfer the weights from the tensorflow implementat

Yana 479 Dec 08, 2022
Towards the D-Optimal Online Experiment Design for Recommender Selection (KDD 2021)

Towards the D-Optimal Online Experiment Design for Recommender Selection (KDD 2021) Contact 0 Jan 11, 2022

This is a repository for a Semantic Segmentation inference API using the Gluoncv CV toolkit

BMW Semantic Segmentation GPU/CPU Inference API This is a repository for a Semantic Segmentation inference API using the Gluoncv CV toolkit. The train

BMW TechOffice MUNICH 56 Nov 24, 2022
This reporistory contains the test-dev data of the paper "xGQA: Cross-lingual Visual Question Answering".

This reporistory contains the test-dev data of the paper "xGQA: Cross-lingual Visual Question Answering".

AdapterHub 18 Dec 09, 2022
This application explain how we can easily integrate Deepface framework with Python Django application

deepface_suite This application explain how we can easily integrate Deepface framework with Python Django application install redis cache install requ

Mohamed Naji Aboo 3 Apr 18, 2022
Image classification for projects and researches

This is a tool to help you quickly solve classification problems including: data analysis, training, report results and model explanation.

Nguyễn Trường Lâu 2 Dec 27, 2021
Official code release for 3DV 2021 paper Human Performance Capture from Monocular Video in the Wild.

Official code release for 3DV 2021 paper Human Performance Capture from Monocular Video in the Wild.

Chen Guo 58 Dec 24, 2022
RIFE: Real-Time Intermediate Flow Estimation for Video Frame Interpolation

RIFE RIFE: Real-Time Intermediate Flow Estimation for Video Frame Interpolation Ported from https://github.com/hzwer/arXiv2020-RIFE Dependencies NumPy

49 Jan 07, 2023
PyTorch implementation of Barlow Twins.

Barlow Twins: Self-Supervised Learning via Redundancy Reduction PyTorch implementation of Barlow Twins. @article{zbontar2021barlow, title={Barlow Tw

Facebook Research 839 Dec 29, 2022
Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields.

This repository contains the code release for Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields. This implementation is written in JAX, and is a fork of Google's JaxNeRF

Google 625 Dec 30, 2022
Source code of generalized shuffled linear regression

Generalized-Shuffled-Linear-Regression Code for the ICCV 2021 paper: Generalized Shuffled Linear Regression. Authors: Feiran Li, Kent Fujiwara, Fumio

FEI 7 Oct 26, 2022
Official implementation of the ICCV 2021 paper "Joint Inductive and Transductive Learning for Video Object Segmentation"

JOINT This is the official implementation of Joint Inductive and Transductive learning for Video Object Segmentation, to appear in ICCV 2021. @inproce

Yunyao 35 Oct 16, 2022
Hierarchical Aggregation for 3D Instance Segmentation (ICCV 2021)

HAIS Hierarchical Aggregation for 3D Instance Segmentation (ICCV 2021) by Shaoyu Chen, Jiemin Fang, Qian Zhang, Wenyu Liu, Xinggang Wang*. (*) Corresp

Hust Visual Learning Team 145 Jan 05, 2023
Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.

Diffrax Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. Diffrax is a JAX-based library providing numerical differe

Patrick Kidger 717 Jan 09, 2023
FairyTailor: Multimodal Generative Framework for Storytelling

FairyTailor: Multimodal Generative Framework for Storytelling

Eden Bens 172 Dec 30, 2022
Code for unmixing audio signals in four different stems "drums, bass, vocals, others". The code is adapted from "Jukebox: A Generative Model for Music"

Status: Archive (code is provided as-is, no updates expected) Disclaimer This code is a based on "Jukebox: A Generative Model for Music" Paper We adju

Wadhah Zai El Amri 24 Dec 29, 2022
Official Pytorch implementation of paper "Reverse Engineering of Generative Models: Inferring Model Hyperparameters from Generated Images"

Reverse_Engineering_GMs Official Pytorch implementation of paper "Reverse Engineering of Generative Models: Inferring Model Hyperparameters from Gener

100 Dec 18, 2022
A Deep Learning based project for creating line art portraits.

ArtLine The main aim of the project is to create amazing line art portraits. Sounds Intresting,let's get to the pictures!! Model-(Smooth) Model-(Quali

Vijish Madhavan 3.3k Jan 07, 2023
5 Jan 05, 2023