PyTorch implementation of Value Iteration Networks (NIPS 2016 best paper)

Overview

VIN: Value Iteration Networks

Architecture of Value Iteration Network

A quick thank you

A few others have released amazing related work which helped inspire and improve my own implementation. It goes without saying that this release would not be nearly as good if it were not for all of the following:

Why another VIN implementation?

  1. The PyTorch VIN model in this repository is, in my opinion, more readable and closer to the original Theano implementation than others I have found (both TensorFlow and PyTorch).
  2. This is not simply an implementation of the VIN model in PyTorch; it is also a full Python implementation of the gridworld environments used in the original MATLAB implementation.
  3. It provides a more extensible research base for others to build on without needing to get past a possible MATLAB paywall.

Installation

This repository requires the following packages:

Use pip to install the necessary dependencies:

pip install -U -r requirements.txt 

Note that PyTorch cannot be installed directly from PyPI; refer to http://pytorch.org/ for custom installation instructions specific to your needs.

How to train

8x8 gridworld

python train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128

16x16 gridworld

python train.py --datafile dataset/gridworld_16x16.npz --imsize 16 --lr 0.002 --epochs 30 --k 20 --batch_size 128

28x28 gridworld

python train.py --datafile dataset/gridworld_28x28.npz --imsize 28 --lr 0.002 --epochs 30 --k 36 --batch_size 128

Flags:

  • datafile: The path to the data files.
  • imsize: The size of input images. One of: [8, 16, 28]
  • lr: Learning rate with RMSProp optimizer. Recommended: [0.01, 0.005, 0.002, 0.001]
  • epochs: Number of epochs to train. Default: 30
  • k: Number of Value Iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28]
  • l_i: Number of channels in input layer. Default: 2, i.e. obstacles image and goal image.
  • l_h: Number of channels in first convolutional layer. Default: 150, described in paper.
  • l_q: Number of channels in q layer (~actions) in VI-module. Default: 10, described in paper.
  • batch_size: Batch size. Default: 128
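The flags above map naturally onto an argparse parser. The following is a minimal sketch, not the repository's actual train.py. The defaults for epochs, l_i, l_h, l_q, and batch_size are the ones listed above; the defaults for datafile, imsize, lr, and k are assumptions borrowed from the 8x8 example command, and build_parser is a hypothetical helper name.

```python
import argparse


def build_parser():
    """Hypothetical helper: a parser mirroring the training flags listed above."""
    parser = argparse.ArgumentParser(description="VIN training flags (sketch)")
    # Assumed defaults (taken from the 8x8 example command, not from the source).
    parser.add_argument("--datafile", type=str, default="dataset/gridworld_8x8.npz")
    parser.add_argument("--imsize", type=int, default=8, choices=[8, 16, 28])
    parser.add_argument("--lr", type=float, default=0.005)
    parser.add_argument("--k", type=int, default=10)
    # Defaults stated in the flag list.
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--l_i", type=int, default=2)
    parser.add_argument("--l_h", type=int, default=150)
    parser.add_argument("--l_q", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=128)
    return parser


# Parse an explicit argument list, equivalent to part of the 16x16 command.
config = build_parser().parse_args(
    ["--datafile", "dataset/gridworld_16x16.npz", "--imsize", "16", "--lr", "0.002", "--k", "20"]
)
print(config.imsize, config.k, config.l_h, config.batch_size)
```

Flags that are not passed fall back to their defaults, which is why the example commands only spell out the values that change with the grid size.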

How to test / visualize paths (requires training first)

8x8 gridworld

python test.py --weights trained/vin_8x8.pth --imsize 8 --k 10

16x16 gridworld

python test.py --weights trained/vin_16x16.pth --imsize 16 --k 20

28x28 gridworld

python test.py --weights trained/vin_28x28.pth --imsize 28 --k 36

To visualize the optimal and predicted paths simply pass:

--plot

Flags:

  • weights: Path to trained weights.
  • imsize: The size of input images. One of: [8, 16, 28]
  • plot: If supplied, the optimal and predicted paths will be plotted
  • k: Number of Value Iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28]
  • l_i: Number of channels in input layer. Default: 2, i.e. obstacles image and goal image.
  • l_h: Number of channels in first convolutional layer. Default: 150, described in paper.
  • l_q: Number of channels in q layer (~actions) in VI-module. Default: 10, described in paper.
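The recommended k grows with the grid size because each value-iteration step can only propagate goal information one cell further. The sketch below illustrates this with plain tabular value iteration on an obstacle-free grid. It is not the VIN module itself, which uses learned convolutions; the 8-connected moves, the unit step cost, and the helper name cells_reached are assumptions made for illustration.

```python
def cells_reached(size, goal, k):
    """Count cells whose value is finite after k value-iteration sweeps.

    Hypothetical helper: obstacle-free grid, 8-connected moves, step cost 1.
    """
    neg_inf = float("-inf")
    v = [[neg_inf] * size for _ in range(size)]
    v[goal[0]][goal[1]] = 0.0
    moves = [(dr, dc) for dr in (-1, 0, 1) for dc in (-1, 0, 1) if (dr, dc) != (0, 0)]
    for _ in range(k):
        new_v = [row[:] for row in v]
        for r in range(size):
            for c in range(size):
                for dr, dc in moves:
                    nr, nc = r + dr, c + dc
                    if 0 <= nr < size and 0 <= nc < size:
                        # Bellman backup with a step cost of 1 and no discounting.
                        new_v[r][c] = max(new_v[r][c], v[nr][nc] - 1.0)
        v = new_v
    return sum(1 for row in v for x in row if x != neg_inf)


# With the goal in a corner, check whether the recommended k covers the whole grid.
for size, k in [(8, 10), (16, 20), (28, 36)]:
    print(size, k, cells_reached(size, (0, 0), k) == size * size)
```

In this simplified setting the recommended values cover every cell with some margin to spare. Obstacles can make shortest paths longer than the grid's side, which is one plausible reason for that margin.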

Results

[Figure: two sample results (Sample One, Sample Two) for each gridworld size: 8x8, 16x16, and 28x28.]

Datasets

Each data sample consists of an obstacle image and a goal image, followed by the (x, y) coordinates of the current state in the gridworld.
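To make that layout concrete, here is a minimal sketch that builds one such sample in plain Python. The field names (obstacle_image, goal_image, state), the use of 1 to mark obstacle cells, and the goal value of 10 are illustrative assumptions, not the repository's actual on-disk format.

```python
def make_sample(size, obstacles, goal, state):
    """Build one gridworld sample (hypothetical layout, for illustration only)."""
    obstacle_image = [[0] * size for _ in range(size)]
    for r, c in obstacles:
        obstacle_image[r][c] = 1  # assumed convention: 1 marks an obstacle cell
    goal_image = [[0] * size for _ in range(size)]
    goal_image[goal[0]][goal[1]] = 10  # assumed value placed at the goal cell
    return {"obstacle_image": obstacle_image, "goal_image": goal_image, "state": state}


sample = make_sample(8, obstacles=[(2, 3), (2, 4)], goal=(7, 7), state=(0, 0))
# The two images stacked together form the two input channels (l_i = 2).
channels = [sample["obstacle_image"], sample["goal_image"]]
print(len(channels), len(channels[0]), len(channels[0][0]), sample["state"])
```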

Dataset size    8x8      16x16     28x28
Train set       81337    456309    1529584
Test set        13846    77203     251755

The datasets (8x8, 16x16, and 28x28) included in this repository can be reproduced using the dataset/make_training_data.py script. Note that this script is not optimized and runs rather slowly (also uses a lot of memory :D)

Performance: Success Rate

This is the success rate from rollouts of the learned policy in the environment (taken over 5000 randomly generated domains).

Success Rate    8x8       16x16     28x28
PyTorch         99.69%    96.99%    91.07%
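As a rough illustration of how a rollout success rate can be computed, the sketch below follows a policy from a start cell until it reaches the goal, hits an obstacle, leaves the grid, or runs out of steps. It is not the repository's test.py: greedy_policy is a stand-in for the trained network, and the domain dictionary layout and the max_steps limit are assumptions.

```python
def greedy_policy(state, goal):
    """Stand-in for the trained network: always step straight toward the goal."""
    def sign(x):
        return (x > 0) - (x < 0)
    return (sign(goal[0] - state[0]), sign(goal[1] - state[1]))


def rollout_succeeds(policy, domain, max_steps):
    """Return True if following `policy` reaches the goal within `max_steps` moves."""
    state, goal = domain["start"], domain["goal"]
    for _ in range(max_steps):
        if state == goal:
            return True
        dr, dc = policy(state, goal)
        nxt = (state[0] + dr, state[1] + dc)
        in_bounds = 0 <= nxt[0] < domain["size"] and 0 <= nxt[1] < domain["size"]
        # Leaving the grid or stepping onto an obstacle counts as a failure here.
        if not in_bounds or nxt in domain["obstacles"]:
            return False
        state = nxt
    return state == goal


def success_rate(policy, domains, max_steps):
    wins = sum(rollout_succeeds(policy, d, max_steps) for d in domains)
    return wins / len(domains)


domains = [
    {"size": 8, "start": (0, 0), "goal": (7, 7), "obstacles": set()},
    {"size": 8, "start": (0, 0), "goal": (7, 7), "obstacles": {(3, 3)}},
]
rate = success_rate(greedy_policy, domains, max_steps=16)
print(f"{rate:.2%}")
```

The stand-in policy ignores obstacles, so it fails the second domain. A single bad step is enough to fail an entire rollout.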

Performance: Test Accuracy

NOTE: This is the accuracy on the test set. It is different from the table in the paper, which reports the success rate from rollouts of the learned policy in the environment.

Test Accuracy   8x8       16x16     28x28
PyTorch         99.83%    94.84%    88.54%
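For contrast with the rollout-based success rate, test accuracy is usually computed per state, as the fraction of predicted actions that match the labelled optimal action. The sketch below assumes actions are encoded as integer indices; action_accuracy is a hypothetical helper, not a function from this repository.

```python
def action_accuracy(predicted_actions, optimal_actions):
    """Fraction of states where the predicted action equals the labelled one."""
    matches = sum(p == o for p, o in zip(predicted_actions, optimal_actions))
    return matches / len(optimal_actions)


predicted = [0, 3, 3, 1, 7]
optimal = [0, 3, 2, 1, 7]
accuracy = action_accuracy(predicted, optimal)
print(f"{accuracy:.2%}")
```

The two metrics can diverge in either direction. A mismatched action may still lead to the goal, while one wrong action in a long trajectory can fail the whole rollout.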
Comments
  • testing accuracy fairly low

    I followed the instructions in the repo and tested the trained models, but got fairly low accuracy. I'm using PyTorch 0.1.12_1. Is there anything I should pay attention to?

    opened by xinleipan 10
  • Prebuilt Dataset Generation

    Hello,

    I was wondering how you generated the prebuilt datasets that are downloaded when running download_weights_and_datasets.sh, i.e. what were the max_obs and max_obs_size parameters?

    Did you follow this file in the original repo? https://github.com/avivt/VIN/blob/master/scripts/make_data_gridworld_nips.m

    Thanks, Emilio

    opened by eparisotto 5
  • the rollout accuracy in test script is lower than the test accuracy in train script.

    Hello!

    I have a small question. Does the rollout accuracy indicate the success rate? If so, why is it lower than the prediction accuracy? In Aviv's implementation, the success rate on the 8x8 gridworld was as high as 99.6%. Why is the success rate in your experiment relatively low?

    Thanks!

    opened by albzni 4
  • RUN ERROR

    When I run 'python train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128', it works. But when I then ran 'python train.py --datafile dataset/gridworld_16x16.npz --imsize 16 --lr 0.002 --epochs 30 --k 20 --batch_size 128', the following error occurred:

    [email protected]:~/pytorch-value-iteration-networks$ python train.py --datafile dataset/gridworld_16x16.npz --imsize 16 --lr 0.002 --epochs 10 --k 20 --batch_size 128
    Traceback (most recent call last):
      File "train.py", line 135, in <module>
        config.datafile, imsize=config.imsize, train=True, transform=transform)
      File "/home/ni/pytorch-value-iteration-networks/dataset/dataset.py", line 22, in __init__
        self._process(file, self.train)
      File "/home/ni/pytorch-value-iteration-networks/dataset/dataset.py", line 58, in _process
        images = images.astype(np.float32)
    MemoryError

    opened by N-Kingsley 3
  • Problem of running the test script

    Hello,

    I downloaded the data with the .sh download script you provided, and I also got an nps weights file after training. When I ran the testing command I got the following error:

    Traceback (most recent call last):
      File "/home/research/DL/VIN/pytorch-value-iteration-networks/test.py", line 158, in <module>
        main(config)
      File "/home/research/DL/VIN/pytorch-value-iteration-networks/test.py", line 85, in main
        _, predictions = vin(X_in, S1_in, S2_in, config)
      File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 357, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/research/DL/VIN/pytorch-value-iteration-networks/model.py", line 64, in forward
        return logits, self.sm(logits)
      File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 352, in __call__
        for hook in self._forward_pre_hooks.values():
      File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 398, in __getattr__
        type(self).__name__, name))
    AttributeError: 'Softmax' object has no attribute '_forward_pre_hooks'

    Thanks for helping!

    opened by YantianZha 3
  • Improved readability of the VIN model, in addition to minor changes

    My main modification is in the forward method of the model, where q_out is extracted from the q values; it also avoids repeating q = F.conv2d(...) in two places. I also made minor improvements, such as adding argparse to the dataset creation script and changing .cuda() to .to(device) in test.py.

    opened by shuishida 2
  • Inconsistent tensor sizes when starting training

    Hey there. I'm trying to run

    python train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128
    

    But I get the following error

    Number of Train Samples: 103926
    Number of Test Samples: 17434
         Epoch | Train Loss | Train Error | Epoch Time
    Traceback (most recent call last):
      File "train.py", line 147, in <module>
        train(net, trainloader, config, criterion, optimizer, use_GPU)
      File "train.py", line 40, in train
        outputs, predictions = net(X, S1, S2, config)
      File "/home/j1k1000o/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 224, in __call__
        result = self.forward(*input, **kwargs)
      File "/media/user_home2/j1k1000o/j1k/VINs/pytorch-value-iteration-networks/model.py", line 44, in forward
        q = F.conv2d(torch.cat([r, v], 1), 
      File "/home/j1k1000o/anaconda3/lib/python3.6/site-packages/torch/autograd/variable.py", line 897, in cat
        return Concat.apply(dim, *iterable)
      File "/home/j1k1000o/anaconda3/lib/python3.6/site-packages/torch/autograd/_functions/tensor.py", line 317, in forward
        return torch.cat(inputs, dim)
    RuntimeError: inconsistent tensor sizes at /opt/conda/conda-bld/pytorch_1502009910772/work/torch/lib/THC/generic/THCTensorMath.cu:141
    

    I've executed

    ./download_weights_and_datasets.sh
    

    as well as

    python ./dataset/make_training_data.py
    

    I'm running it on Ubuntu 16.04 with Python 3.6 and all the requirements installed.

    Can you help me out?

    opened by juancprzs 2
  • Don't understand VIN last step

        slice_s1 = S1.long().expand(config.imsize, 1, config.l_q, q.size(0))
        slice_s1 = slice_s1.permute(3, 2, 1, 0)
        q_out = q.gather(2, slice_s1).squeeze(2)
    

    What do these 3 lines do?

    opened by QiXuanWang 1
  • KeyError: 'arr_1 is not a file in the archive'

    python3 train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128
    Traceback (most recent call last):
      File "train.py", line 135, in <module>
        config.datafile, imsize=config.imsize, train=True, transform=transform)
      File "/home/user/pytorch/tutorials/valueiterationnetworks/pytorch-value-iteration-networks/dataset/dataset.py", line 22, in __init__
        self._process(file, self.train)
      File "/home/user/pytorch/tutorials/valueiterationnetworks/pytorch-value-iteration-networks/dataset/dataset.py", line 49, in _process
        S1 = f['arr_1']
      File "/home/user/miniconda3/lib/python3.6/site-packages/numpy/lib/npyio.py", line 255, in __getitem__
        raise KeyError("%s is not a file in the archive" % key)
    KeyError: 'arr_1 is not a file in the archive'

    I got this error, could you please

    opened by derelearnro 1
  • Problem of running dataset/make_training_data.py script

    Hi

    When I tried to run the make_training_data.py script to generate the gridworld.npz file, I got the following error:

    FileNotFoundError: [Errno 2] No such file or directory: 'dataset/gridworld_28x28.npz'
    

    And I found that line 101 should be modified as follows:

    save_path = "gridworld_{0}x{1}".format(dom_size[0], dom_size[1])
    
    opened by ruqing00 0
Owner
Kent Sommer
Software Engineer @ Toyota Research Institute (SF Bay Area)