Code for the paper "Reinforcement Learning as One Big Sequence Modeling Problem"

Overview

Trajectory Transformer

Code release for Reinforcement Learning as One Big Sequence Modeling Problem.

Installation

All python dependencies are in environment.yml. Install with:

conda env create -f environment.yml
conda activate trajectory
pip install -e .

For reproducibility, we have also included system requirements in a Dockerfile (see installation instructions), but the conda installation should work on most standard Linux machines.

Usage

Train a transformer with: python scripts/train.py --dataset halfcheetah-medium-v2

To reproduce the offline RL results: python scripts/plan.py --dataset halfcheetah-medium-v2

By default, these commands will use the hyperparameters in config/offline.py. You can override them with runtime flags:

python scripts/plan.py --dataset halfcheetah-medium-v2 \
	--horizon 5 --beam_width 32

A few hyperparameters are different from those listed in the paper because of changes to the discretization strategy. These hyperparameters will be updated in the next arxiv version to match what is currently in the codebase.

Pretrained models

We have provided pretrained models for 16 datasets: {halfcheetah, hopper, walker2d, ant}-{expert-v2, medium-expert-v2, medium-v2, medium-replay-v2}. Download them with ./pretrained.sh

The models will be saved in logs/$DATASET/gpt/pretrained. To plan with these models, refer to them using the gpt_loadpath flag:

python scripts/plan.py --dataset halfcheetah-medium-v2 \
	--gpt_loadpath gpt/pretrained

pretrained.sh will also download 15 plans from each model, saved to logs/$DATASET/plans/pretrained. Read them with python plotting/read_results.py.

To create the table of offline RL results from the paper, run python plotting/table.py. This will print a table that can be copied into a Latex document. (Expand to view table source.)
\begin{table*}[h]
\centering
\small
\begin{tabular}{llrrrrrr}
\toprule
\multicolumn{1}{c}{\bf Dataset} & \multicolumn{1}{c}{\bf Environment} & \multicolumn{1}{c}{\bf BC} & \multicolumn{1}{c}{\bf MBOP} & \multicolumn{1}{c}{\bf BRAC} & \multicolumn{1}{c}{\bf CQL} & \multicolumn{1}{c}{\bf DT} & \multicolumn{1}{c}{\bf TT (Ours)} \\ 
\midrule
Medium-Expert & HalfCheetah & $59.9$ & $105.9$ & $41.9$ & $62.4$ & $86.8$ & $95.0$ \scriptsize{\raisebox{1pt}{$\pm 0.2$}} \\ 
Medium-Expert & Hopper & $79.6$ & $55.1$ & $0.9$ & $111.0$ & $107.6$ & $110.0$ \scriptsize{\raisebox{1pt}{$\pm 2.7$}} \\ 
Medium-Expert & Walker2d & $36.6$ & $70.2$ & $81.6$ & $98.7$ & $108.1$ & $101.9$ \scriptsize{\raisebox{1pt}{$\pm 6.8$}} \\ 
Medium-Expert & Ant & $-$ & $-$ & $-$ & $-$ & $-$ & $116.1$ \scriptsize{\raisebox{1pt}{$\pm 9.0$}} \\ 
\midrule
Medium & HalfCheetah & $43.1$ & $44.6$ & $46.3$ & $44.4$ & $42.6$ & $46.9$ \scriptsize{\raisebox{1pt}{$\pm 0.4$}} \\ 
Medium & Hopper & $63.9$ & $48.8$ & $31.3$ & $58.0$ & $67.6$ & $61.1$ \scriptsize{\raisebox{1pt}{$\pm 3.6$}} \\ 
Medium & Walker2d & $77.3$ & $41.0$ & $81.1$ & $79.2$ & $74.0$ & $79.0$ \scriptsize{\raisebox{1pt}{$\pm 2.8$}} \\ 
Medium & Ant & $-$ & $-$ & $-$ & $-$ & $-$ & $83.1$ \scriptsize{\raisebox{1pt}{$\pm 7.3$}} \\ 
\midrule
Medium-Replay & HalfCheetah & $4.3$ & $42.3$ & $47.7$ & $46.2$ & $36.6$ & $41.9$ \scriptsize{\raisebox{1pt}{$\pm 2.5$}} \\ 
Medium-Replay & Hopper & $27.6$ & $12.4$ & $0.6$ & $48.6$ & $82.7$ & $91.5$ \scriptsize{\raisebox{1pt}{$\pm 3.6$}} \\ 
Medium-Replay & Walker2d & $36.9$ & $9.7$ & $0.9$ & $26.7$ & $66.6$ & $82.6$ \scriptsize{\raisebox{1pt}{$\pm 6.9$}} \\ 
Medium-Replay & Ant & $-$ & $-$ & $-$ & $-$ & $-$ & $77.0$ \scriptsize{\raisebox{1pt}{$\pm 6.8$}} \\ 
\midrule
\multicolumn{2}{c}{\bf Average (without Ant)} & 47.7 & 47.8 & 36.9 & 63.9 & 74.7 & 78.9 \hspace{.6cm} \\ 
\multicolumn{2}{c}{\bf Average (all settings)} & $-$ & $-$ & $-$ & $-$ & $-$ & 82.2 \hspace{.6cm} \\ 
\bottomrule
\end{tabular}
\label{table:d4rl}
\end{table*}

To create the average performance plot, run python plotting/plot.py. (Expand to view plot.)

Docker

Copy your MuJoCo key to the Docker build context and build the container:

cp ~/.mujoco/mjkey.txt azure/files/
docker build -f azure/Dockerfile . -t trajectory

Test the container:

docker run -it --rm --gpus all \
	--mount type=bind,source=$PWD,target=/home/code \
	--mount type=bind,source=$HOME/.d4rl,target=/root/.d4rl \
	trajectory \
	bash -c \
	"export PYTHONPATH=$PYTHONPATH:/home/code && \
	python /home/code/scripts/train.py --dataset hopper-medium-expert-v2 --exp_name docker/"

Running on Azure

Setup

  1. Launching jobs on Azure requires one more python dependency:
pip install git+https://github.com/JannerM/[email protected]
  1. Tag the image built in the previous section and push it to Docker Hub:
export DOCKER_USERNAME=$(docker info | sed '/Username:/!d;s/.* //')
docker tag trajectory ${DOCKER_USERNAME}/trajectory:latest
docker image push ${DOCKER_USERNAME}/trajectory
  1. Update azure/config.py, either by modifying the file directly or setting the relevant environment variables. To set the AZURE_STORAGE_CONNECTION variable, navigate to the Access keys section of your storage account. Click Show keys and copy the Connection string.

  2. Download azcopy: ./azure/download.sh

Usage

Launch training jobs with python azure/launch_train.py and planning jobs with python azure/launch_plan.py.

These scripts do not take runtime arguments. Instead, they run the corresponding scripts (scripts/train.py and scripts/plan.py, respectively) using the Cartesian product of the parameters in params_to_sweep.

Viewing results

To rsync the results from the Azure storage container, run ./azure/sync.sh.

To mount the storage container:

  1. Create a blobfuse config with ./azure/make_fuse_config.sh
  2. Run ./azure/mount.sh to mount the storage container to ~/azure_mount

To unmount the container, run sudo umount -f ~/azure_mount; rm -r ~/azure_mount

Reference

@article{janner2021sequence,
  title={Reinforcement Learning as One Big Sequence Modeling Problem},
  author={Michael Janner and Qiyang Li and Sergey Levine},
  journal={arXiv preprint arXiv:2106.02039},
  year={2021},
}

Acknowledgements

The GPT implementation is from Andrej Karpathy's minGPT repo.

AAAI 2022: Stationary diffusion state neural estimation

Stationary Diffusion State Neural Estimation Although many graph-based clustering methods attempt to model the stationary diffusion state in their obj

绽琨 33 Nov 24, 2022
Official Pytorch Implementation of 'Learning Action Completeness from Points for Weakly-supervised Temporal Action Localization' (ICCV-21 Oral)

Learning-Action-Completeness-from-Points Official Pytorch Implementation of 'Learning Action Completeness from Points for Weakly-supervised Temporal A

Pilhyeon Lee 67 Jan 03, 2023
🚀 PyTorch Implementation of "Progressive Distillation for Fast Sampling of Diffusion Models(v-diffusion)"

PyTorch Implementation of "Progressive Distillation for Fast Sampling of Diffusion Models(v-diffusion)" Unofficial PyTorch Implementation of Progressi

Vitaliy Hramchenko 58 Dec 19, 2022
AutoML library for deep learning

Official Website: autokeras.com AutoKeras: An AutoML system based on Keras. It is developed by DATA Lab at Texas A&M University. The goal of AutoKeras

Keras 8.7k Jan 08, 2023
The official PyTorch implementation for NCSNv2 (NeurIPS 2020)

Improved Techniques for Training Score-Based Generative Models This repo contains the official implementation for the paper Improved Techniques for Tr

174 Dec 26, 2022
modelvshuman is a Python library to benchmark the gap between human and machine vision

modelvshuman is a Python library to benchmark the gap between human and machine vision. Using this library, both PyTorch and TensorFlow models can be evaluated on 17 out-of-distribution datasets with

Bethge Lab 244 Jan 03, 2023
using yolox+deepsort for object-tracker

YOLOX_deepsort_tracker yolox+deepsort实现目标跟踪 最新的yolox尝尝鲜~~(yolox正处在频繁更新阶段,因此直接链接yolox仓库作为子模块) Install Clone the repository recursively: git clone --rec

245 Dec 26, 2022
Easy-to-use library to boost AI inference leveraging state-of-the-art optimization techniques.

NEW RELEASE How Nebullvm Works • Tutorials • Benchmarks • Installation • Get Started • Optimization Examples Discord | Website | LinkedIn | Twitter Ne

Nebuly 1.7k Dec 31, 2022
GANSketchingJittor - Implementation of Sketch Your Own GAN in Jittor

GANSketching in Jittor Implementation of (Sketch Your Own GAN) in Jittor(计图). Or

Bernard Tan 10 Jul 02, 2022
基于Paddle框架的fcanet复现

fcanet-Paddle 基于Paddle框架的fcanet复现 fcanet 本项目基于paddlepaddle框架复现fcanet,并参加百度第三届论文复现赛,将在2021年5月15日比赛完后提供AIStudio链接~敬请期待 参考项目: frazerlin-fcanet 数据准备 本项目已挂

QuanHao Guo 7 Mar 07, 2022
Pytorch implementation of "Neural Wireframe Renderer: Learning Wireframe to Image Translations"

Neural Wireframe Renderer: Learning Wireframe to Image Translations Pytorch implementation of ideas from the paper Neural Wireframe Renderer: Learning

Yuan Xue 7 Nov 14, 2022
CLIP + VQGAN / PixelDraw

clipit Yet Another VQGAN-CLIP Codebase This started as a fork of @nerdyrodent's VQGAN-CLIP code which was based on the notebooks of @RiversWithWings a

dribnet 276 Dec 12, 2022
Code for "Single-view robot pose and joint angle estimation via render & compare", CVPR 2021 (Oral).

Single-view robot pose and joint angle estimation via render & compare Yann Labbé, Justin Carpentier, Mathieu Aubry, Josef Sivic CVPR: Conference on C

Yann Labbé 51 Oct 14, 2022
FeTaQA: Free-form Table Question Answering

FeTaQA: Free-form Table Question Answering FeTaQA is a Free-form Table Question Answering dataset with 10K Wikipedia-based {table, question, free-form

Language, Information, and Learning at Yale 40 Dec 13, 2022
Implementation of 'X-Linear Attention Networks for Image Captioning' [CVPR 2020]

Introduction This repository is for X-Linear Attention Networks for Image Captioning (CVPR 2020). The original paper can be found here. Please cite wi

JDAI-CV 240 Dec 17, 2022
Caffe: a fast open framework for deep learning.

Caffe Caffe is a deep learning framework made with expression, speed, and modularity in mind. It is developed by Berkeley AI Research (BAIR)/The Berke

Berkeley Vision and Learning Center 33k Dec 28, 2022
PyTorch code for EMNLP 2021 paper: Don't be Contradicted with Anything! CI-ToD: Towards Benchmarking Consistency for Task-oriented Dialogue System

Don’t be Contradicted with Anything!CI-ToD: Towards Benchmarking Consistency for Task-oriented Dialogue System This repository contains the PyTorch im

Libo Qin 25 Sep 06, 2022
A scientific and useful toolbox, which contains practical and effective long-tail related tricks with extensive experimental results

Bag of tricks for long-tailed visual recognition with deep convolutional neural networks This repository is the official PyTorch implementation of AAA

Yong-Shun Zhang 181 Dec 28, 2022
The official implementation for "FQ-ViT: Fully Quantized Vision Transformer without Retraining".

FQ-ViT [arXiv] This repo contains the official implementation of "FQ-ViT: Fully Quantized Vision Transformer without Retraining". Table of Contents In

132 Jan 08, 2023
Gluon CV Toolkit

Gluon CV Toolkit | Installation | Documentation | Tutorials | GluonCV provides implementations of the state-of-the-art (SOTA) deep learning models in

Distributed (Deep) Machine Learning Community 5.4k Jan 06, 2023