Tianshou - An elegant PyTorch deep reinforcement learning library.

Overview

Tianshou (天授) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly APIs, or run slowly, Tianshou provides a fast, modularized framework and a pythonic API for building deep reinforcement learning agents with the fewest possible lines of code. The supported interface algorithms currently include:

Here are Tianshou's other features:

  • Elegant framework, using only ~4000 lines of code
  • State-of-the-art MuJoCo benchmark for REINFORCE/A2C/TRPO/PPO/DDPG/TD3/SAC algorithms
  • Support parallel environment simulation (synchronous or asynchronous) for all algorithms Usage
  • Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) Usage
  • Support any type of environment state/action (e.g. a dict, a self-defined class, ...) Usage
  • Support customized training process Usage
  • Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms; GAE, n-step and PER are very fast thanks to numba's JIT compilation and vectorized numpy operations (see the sketch after this list)
  • Support multi-agent RL Usage
  • Support both TensorBoard and W&B log tools
  • Support multi-GPU training Usage
  • Comprehensive documentation, PEP8 code-style checking, type checking and unit tests
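
For example, switching the Quick Start below to prioritized experience replay only requires swapping the replay buffer. This is a minimal sketch, not taken from the official examples; it assumes PrioritizedVectorReplayBuffer takes the total buffer size, the number of parallel environments, and the usual PER exponents alpha/beta (check the API reference for the exact signature):

import tianshou as ts

# drop-in replacement for ts.data.VectorReplayBuffer in the Quick Start below;
# alpha and beta are the standard PER exponents (the values here are illustrative)
buffer = ts.data.PrioritizedVectorReplayBuffer(
    total_size=20000, buffer_num=10, alpha=0.6, beta=0.4)
# then pass it to the collector, e.g.:
# train_collector = ts.data.Collector(policy, train_envs, buffer, exploration_noise=True)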

In Chinese, Tianshou means divinely ordained, i.e., a gift one is born with. Tianshou is a reinforcement learning platform, and its RL algorithms do not learn from humans. The name "Tianshou" therefore means that there is no teacher to learn from; instead, the agent learns by itself through constant interaction with the environment.

Installation

Tianshou is currently hosted on PyPI and conda-forge. It requires Python >= 3.6.

You can simply install Tianshou from PyPI with the following command:

$ pip install tianshou

If you use Anaconda or Miniconda, you can install Tianshou from conda-forge through the following command:

$ conda install -c conda-forge tianshou

You can also install the newest version from GitHub:

$ pip install git+https://github.com/thu-ml/tianshou.git@master --upgrade

After installation, open your Python console and type

import tianshou
print(tianshou.__version__)

If no error occurs, you have successfully installed Tianshou.

Documentation

The tutorials and API documentation are hosted on tianshou.readthedocs.io.

The example scripts are under the test/ and examples/ folders.

The Chinese documentation is located at https://tianshou.readthedocs.io/zh/master/

Why Tianshou?

Comprehensive Functionality

| RL Platform | # of Alg. (1) | Custom Env | Batch Training | RNN Support | Nested Observation | Backend |
| --- | --- | --- | --- | --- | --- | --- |
| Baselines | 9 | ✔️ (gym) | (2) | ✔️ | | TF1 |
| Stable-Baselines | 11 | ✔️ (gym) | (2) | ✔️ | | TF1 |
| Stable-Baselines3 | 7 (3) | ✔️ (gym) | (2) | | ✔️ | PyTorch |
| Ray/RLlib | 16 | ✔️ | ✔️ | ✔️ | ✔️ | TF/PyTorch |
| SpinningUp | 6 | ✔️ (gym) | (2) | | | PyTorch |
| Dopamine | 7 | | | | | TF/JAX |
| ACME | 14 | ✔️ (dm_env) | ✔️ | ✔️ | ✔️ | TF/JAX |
| keras-rl | 7 | ✔️ (gym) | | | | Keras |
| rlpyt | 11 | | ✔️ | ✔️ | ✔️ | PyTorch |
| ChainerRL | 18 | ✔️ (gym) | ✔️ | ✔️ | | Chainer |
| Sample Factory | 1 (4) | ✔️ (gym) | ✔️ | ✔️ | ✔️ | PyTorch |
| Tianshou | 20 | ✔️ (gym) | ✔️ | ✔️ | ✔️ | PyTorch |

(1): access date: 2021-08-08

(2): not all algorithms support this feature

(3): TQC and QR-DQN in sb3-contrib instead of main repo

(4): super fast APPO!

High quality software engineering standard

| RL Platform | Documentation | Code Coverage | Type Hints |
| --- | --- | --- | --- |
| Baselines | | | |
| Stable-Baselines | Documentation Status | coverage | |
| Stable-Baselines3 | Documentation Status | coverage report | ✔️ |
| Ray/RLlib | | (1) | ✔️ |
| SpinningUp | | | |
| Dopamine | | | |
| ACME | | (1) | ✔️ |
| keras-rl | Documentation | (1) | |
| rlpyt | Docs | codecov | |
| ChainerRL | Documentation Status | Coverage Status | |
| Sample Factory | | codecov | |
| Tianshou | Read the Docs | codecov | ✔️ |

(1): it has continuous integration but the coverage rate is not available

Reproducible and High Quality Result

Tianshou has its own unit tests. Unlike other platforms, these tests include the full agent training procedure for all implemented algorithms: a test fails if an algorithm cannot train an agent to perform well enough within a limited number of epochs on toy scenarios. The unit tests thus secure the reproducibility of our platform. Check out the GitHub Actions page for more detail.

The Atari/MuJoCo benchmark results are under the examples/atari/ and examples/mujoco/ folders. Our MuJoCo results beat most of the existing benchmarks.

Modularized Policy

We decouple all of the algorithms roughly into the following parts:

  • __init__: initialize the policy;
  • forward: to compute actions over given observations;
  • process_fn: to preprocess data from replay buffer (since we have reformulated all algorithms to replay-buffer based algorithms);
  • learn: to learn from a given batch of data;
  • post_process_fn: to update the replay buffer from the learning process (e.g., prioritized replay buffer needs to update the weight);
  • update: the main interface for training, i.e., process_fn -> learn -> post_process_fn.

Within this API, we can interact with different policies conveniently.
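
To make the decomposition concrete, here is a minimal sketch of a custom policy (a uniformly random one, not part of the official examples). It only fills in forward and learn and inherits the default process_fn/post_process_fn; it assumes BasePolicy accepts an action_space keyword argument, as recent versions do:

import numpy as np
from tianshou.data import Batch
from tianshou.policy import BasePolicy

class RandomPolicy(BasePolicy):
    """Toy policy used only to illustrate where each method of the decomposition lives."""

    def __init__(self, action_space, **kwargs):
        super().__init__(action_space=action_space, **kwargs)

    def forward(self, batch, state=None, **kwargs):
        # compute (random) actions over the given observations
        act = np.array([self.action_space.sample() for _ in range(len(batch.obs))])
        return Batch(act=act, state=state)

    def learn(self, batch, **kwargs):
        # a real policy would compute losses and take optimizer steps here
        return {}

# policy.update(sample_size, buffer) then runs process_fn -> learn -> post_process_fn internally.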

Quick Start

This is an example of Deep Q Network. You can also run the full script at test/discrete/test_dqn.py.

First, import some relevant packages:

import gym, torch, numpy as np, torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import tianshou as ts

Define some hyper-parameters:

task = 'CartPole-v0'
lr, epoch, batch_size = 1e-3, 10, 64
train_num, test_num = 10, 100
gamma, n_step, target_freq = 0.9, 3, 320
buffer_size = 20000
eps_train, eps_test = 0.1, 0.05
step_per_epoch, step_per_collect = 10000, 10
logger = ts.utils.TensorboardLogger(SummaryWriter('log/dqn'))  # TensorBoard is supported!
# For other loggers: https://tianshou.readthedocs.io/en/master/tutorials/logger.html

Make environments:

# you can also try with SubprocVectorEnv
train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)])

Define the network:

from tianshou.utils.net.common import Net
# you can define other net by following the API:
# https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network
env = gym.make(task)
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
optim = torch.optim.Adam(net.parameters(), lr=lr)

Setup policy and collectors:

policy = ts.policy.DQNPolicy(net, optim, gamma, n_step, target_update_freq=target_freq)
train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(buffer_size, train_num), exploration_noise=True)
test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True)  # because DQN uses epsilon-greedy method

Let's train it:

result = ts.trainer.offpolicy_trainer(
    policy, train_collector, test_collector, epoch, step_per_epoch, step_per_collect,
    test_num, batch_size, update_per_step=1 / step_per_collect,
    train_fn=lambda epoch, env_step: policy.set_eps(eps_train),
    test_fn=lambda epoch, env_step: policy.set_eps(eps_test),
    stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
    logger=logger)
print(f'Finished training! Use {result["duration"]}')

Save / load the trained policy (it's exactly the same as saving/loading a PyTorch nn.Module):

torch.save(policy.state_dict(), 'dqn.pth')
policy.load_state_dict(torch.load('dqn.pth'))
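
The collected experience can be persisted as well. The release notes below mention HDF5 save/load helpers for the replay buffer (added in v0.3.1); a small sketch (not part of the original example), with 'dqn_buffer.hdf5' as a made-up file name:

train_collector.buffer.save_hdf5('dqn_buffer.hdf5')
buf = ts.data.VectorReplayBuffer.load_hdf5('dqn_buffer.hdf5')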

Watch the agent's performance at 35 FPS:

policy.eval()
policy.set_eps(eps_test)
collector = ts.data.Collector(policy, env, exploration_noise=True)
collector.collect(n_episode=1, render=1 / 35)

Look at the results saved in TensorBoard (run this in your terminal):

$ tensorboard --logdir log/dqn

You can check out the documentation for advanced usage.

It's worth a try: here is a test on a laptop (i7-8750H + GTX1060). It takes only about 3 seconds to train an agent based on vanilla policy gradient on the CartPole-v0 task (the seed may differ across platforms and devices):

$ python3 test/discrete/test_pg.py --seed 0 --render 0.03

Contributing

Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out this link.

Citing Tianshou

If you find Tianshou useful, please cite it in your publications.

@article{weng2021tianshou,
  title={Tianshou: A Highly Modularized Deep Reinforcement Learning Library},
  author={Weng, Jiayi and Chen, Huayu and Yan, Dong and You, Kaichao and Duburcq, Alexis and Zhang, Minghao and Su, Hang and Zhu, Jun},
  journal={arXiv preprint arXiv:2107.14171},
  year={2021}
}

Acknowledgment

Tianshou was previously a reinforcement learning platform based on TensorFlow. You can check out the branch priv for more detail. Many thanks to Haosheng Zou's pioneering work on Tianshou before version 0.1.1.

We would like to thank TSAIL and Institute for Artificial Intelligence, Tsinghua University for providing such an excellent AI research platform.

Comments
  • RNN support

    • [ ] I have marked all applicable categories:

      • [ ] exception-raising bug
      • [ ] RL algorithm bug
      • [ ] documentation request (i.e. "X is missing from the documentation.")
      • [x] new feature request
    • [x] I have visited the source website, and in particular read the known issues

    • [x] I have searched through the issue tracker for duplicates

    • [ ] I have mentioned version numbers, operating system and environment, where applicable:

      import tianshou, sys
      print(tianshou.__version__, sys.version, sys.platform)
      

    I see on README that RNN support is on your TODO list. However, the module API seems to support RNN ( forward(obs, state) method). Could you please provide some examples on how to train RNN policy? Thanks!

    enhancement good first issue 
    opened by miriaford 55
  • assignment with heterogeneous batches

    In #142 , we specify the standard behavior of aggregating heterogeneous batches. Here we discuss assignment with heterogeneous batches.

    In Tianshou 0.2.5, batch assignment b0[index] = b1 requires that b0 and b1 are homogeneous (they have the same structure), but sometimes we need to deal with assignment with heterogeneous batches.

    Consider the following examples with reserved keys:

    x = Batch(a=[1, 2, 3, 4], b=Batch())
    y = Batch(a=Batch(), b=[5, 8])
    x[np.array([0, 1])] = y
    
    x = Batch(a=[1, 2, 3, 4], b=Batch())
    y = Batch(b=[5, 8])
    x[np.array([0, 1])] = y
    
    x = Batch(a=[1, 2, 3, 4])
    y = Batch(a=Batch(), b=[5, 8])
    x[np.array([0, 1])] = y
    

    What are the natural answers to these code snippets?

    Just raising exceptions is a possible solution, but I think it is not enough. From the rationale of reserved keys, I think assignment with heterogeneous batches should be allowed when they have clear meaning.

    An example of usage is in async simulation: I reset all environments and gather the initial observation & actions, but do not know about the reward & info, so I have a data = Batch(obs=xxx, act=xxx, info=Batch()). After one round of async simulation, environments 1 and 3 returned their next observations, so I have stepped_data = Batch(obs_next=xxx, info=xxx, rew=xxx). Then I have to set them back to the data, by data[np.array([1, 3])] = stepped_data.

    For any key chain k, there are 3 cases:

    1. data[k] is non-reserved but stepped_data[k] is reserved or does not exist (e.g. k='obs'), then I expect that data[k] is not changed after the assignment data[index] = stepped_data

    2. data[k] is reserved or does not exist but stepped_data[k] is non-reserved (e.g. k='obs_next'), then I expect that data[k] will be filled with zeros first, and then data[k][index] = stepped_data[k] works.

    3. both data[k] and stepped_data[k] are non-reserved, this is easy, just do the assignment data[k][index] = stepped_data[k].

    discussion 
    opened by youkaichao 44
  • W&B: add artifacts support

    • [x] I have marked all applicable categories:
      • [ ] exception-raising fix
      • [ ] algorithm implementation fix
      • [ ] documentation modification
      • [x] new feature
    • [x] I have reformatted the code using make format (required)
    • [x] I have checked the code using make commit-checks (required)
    • [ ] If applicable, I have mentioned the relevant/related issue(s)
    • [x] If applicable, I have listed every item in this Pull Request below

    This PR fixes existing bugs in WandbLogger and adds support for artifacts.

    When using WandbLogger, you can now resume your runs from any device. Example usage is in the examples/atari/atari_dqn_wandb.py:

    cd examples/atari
    python atari_dqn_wandb.py
    # terminate run. The run is executable on any device via
    python atari_dqn_wandb.py --resume_id {your run id}
    

    Let me know if I missed something. This is still a WIP. I'll add some more advanced visualization features.

    Bug fix

    It seems like the atari_dqn task is passing floats to the logger.write method, which crashes the script. I've handled this case, but it is hacky. We might need to modify the trainer to always pass dicts.

    opened by AyushExel 36
  • Pettingzoo

    • [X] I have marked all applicable categories:
      • [X] exception-raising fix
      • [ ] algorithm implementation fix
      • [X] documentation modification
      • [X] new feature
    • [X] I have reformatted the code using make format (required)
    • [X] I have checked the code using make commit-checks (required)
    • [X] If applicable, I have mentioned the relevant/related issue(s)
    • [X] If applicable, I have listed every item in this Pull Request below

    Tested with tic-tac-toe, piston ball (discrete and continuous)

    opened by mahi97 35
  • Add CachedReplayBuffer and ReplayBufferManager

    This is the second commit of 6 commits mentioned in #274, which features minor refactor of ReplayBuffer and adding two new ReplayBuffer classes called CachedReplayBuffer and ReplayBufferManager. You can check #274 for more detail.

    1. Add ReplayBufferManager (handle a list of buffers) and CachedReplayBuffer;
    2. Make sure the reserved keys cannot be edited by methods like buffer.done = xxx;
    3. Add set_batch method for manually choosing the batch the ReplayBuffer wants to handle;
    4. Add sample_index method, same as sample but only return index instead of both index and batch data;
    5. Add prev (one-step previous transition index), next (one-step next transition index) and unfinished_index (the last modified index whose done==False);
    6. Separate alloc_fn method for allocating new memory for self._meta when a new (key, value) pair comes in;
    7. Move buffer's documentation to docs/tutorials/concepts.rst.
    opened by ChenDRAG 31
  • Standardize the behavior of Batch aggregation (stack/cat) when dealing with reserved keys

    Currently, we use Batch() to indicate that a key reserves the place in Batch and it will have value later on. It works fine in most cases, but now we have noticed that it is problematic.

    The critic problem is: how to treat Batch(c=Batch())? There are two opinions:

    • Batch(c=Batch()) means hierarchical key reservation, and should be treated as empty. But the problem is, we want to enable the concatenation of Batches with other empty Batches, which is seemingly natural. But considering Batch.cat([Batch(a=np.random.rand(3, 4)), Batch(a=Batch(c=Batch()))]), we would want to treat Batch(c=Batch()) as non-empty.

    • Batch(c=Batch()) is not empty, and there is no need to support hierarchical key reservation. This makes the implementation straightforward, but does not support hierarchical key reservation. Unfortunately, there are some use cases of hierarchical key reservation in Tianshou.

    I think the critical problem is how to reserve keys and support hierarchical key reservation.

    @Trinkle23897 @duburcqa

    discussion 
    opened by youkaichao 28
  • Batching with a variable action space

    I am trying to train an agent for a board game, Blokus, which has a potentially large (~32K) action space, while only a small fraction of them are valid by the rule per step. I am implementing the environment which provides all valid actions as observation per step (the shape of observation differs every step). Next I plan to train a DQN agent, but I am not sure how replay buffer batches work with a variable action space. I have read the tic-tac-toe example in the docs. What should I change from tic-tac-toe to adapt to a variable number of actions per step? Thanks!

    question 
    opened by dzy1997 27
  • Async Sampling

    • [ ] I have marked all applicable categories:

      • [ ] exception-raising bug
      • [ ] RL algorithm bug
      • [ ] documentation request (i.e. "X is missing from the documentation.")
      • [x] new feature request
    • [x] I have visited the source website, and in particular read the known issues

    • [x] I have searched through the issue tracker and issue categories for duplicates

    • [ ] I have mentioned version numbers, operating system and environment, where applicable:

      import tianshou, torch, sys
      print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
      

    Hello! In most envs, every step costs roughly the same time, so SubprocVectorEnv and RayVectorEnv scale up almost linearly. However, in my env each step costs a very different amount of time, e.g. 90% of steps cost 1s but 10% cost 10s. So when I run 10 SubprocVectorEnv, there is almost no speedup, because on almost every step there is one slow env costing 10s while the other 9 fast envs cost 1s, and the step is 'blocked' waiting for all 10 envs to finish. I think that if the 10 envs could step independently, the total time would scale up almost linearly. But it seems that I would have to change the collector too much. Do you have any suggestions? Thanks very much for your great work on tianshou!

    enhancement good first issue discussion 
    opened by magicly 25
  • Make trainer resumable

    This is my simple idea: put everything (best_epoch and so on) in a class (TrainLog), and load/save all these things with hooked functions. I use a param epoch-per-save to save all these things every epoch-per-save epoch.

    But the problem is that the loggers log every n steps/epochs, and we restart the training process from k * epoch-per-save * step-per-epoch env_step, so the loggers will write duplicate logs. I wonder if we can log only when saving the training data.

    opened by StephenArk30 24
  • Add offline trainer and discrete BCQ algorithm

    Discrete BCQ: https://arxiv.org/abs/1910.01708
    Offline trainer discussion: https://github.com/thu-ml/tianshou/issues/248#issuecomment-744921908

    Will implement a test_imitation.py in the next PR.

    opened by zhujl1991 24
  • How to restore from h5 file saved by PrioritizedVectorReplayBuffer?

    • I have marked all applicable categories:
    • I have visited the source website
    • I have searched through the issue tracker for duplicates
    • I have mentioned version numbers, operating system and environment, where applicable:
    import tianshou, gym, torch, numpy, sys
    print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
    0.4.9 0.25.0 1.11.0+cu113 1.20.3 3.9.7 (default, Sep 16 2021, 16:59:28) [MSC v.1916 64 bit (AMD64)] win32
    

    When I run the following statement, it raises KeyError: 'weight'. How can I fix it?

    data, idx = source.sample(0)

    buff = PrioritizedVectorReplayBuffer.load_hdf5(buffer_path)
    
    for buf_id, source in enumerate(buff.buffers):  # move data from buff to (old) buffer
        data, idx = source.sample(0)  # get all data
        train_collector.buffer.add(
            data, buffer_ids=np.full(shape=len(idx), fill_value=(buf_id % args.training_num)))
    del buff
    
    bug 
    opened by jaried 23
  • RecurrentCritic example

    • [X] documentation request (i.e. "X is missing from the documentation.")

    I tried to use PPO+LSTM and got RecurrentActorProb working with PPO.

    However, I was not able to make RecurrentCritic work with PPO, and there is no concrete example anywhere in the documentation. If someone has one, I would be interested in a full PPO example with both a recurrent critic and a recurrent actor.

    opened by jaak-s 0
  • Problems with using tianshou to train pettingzoo's waterworld environment.

    • [ ] I have marked all applicable categories:
      • [x] exception-raising bug
      • [ ] RL algorithm bug
      • [ ] documentation request (i.e. "X is missing from the documentation.")
      • [ ] new feature request
    • [x] I have visited the source website
    • [x] I have searched through the issue tracker for duplicates
    • [x] I have mentioned version numbers, operating system and environment, where applicable:
      import tianshou, gym, torch, numpy, sys, pettingzoo
      print(pettingzoo.__version__, tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
      

    Versions: print(pettingzoo.__version__, tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform) --> 1.22.0 0.4.10 0.26.2 1.11.0+cpu 1.22.0 3.8.13 (default, Mar 28 2022, 06:59:08) [MSC v.1916 64 bit (AMD64)] win32

    Requesting help, I've been stuck for days! First, I read tianshou's documentation on the PettingZoo environment and learned from it. Then I tried to use tianshou to train PettingZoo's waterworld environment, but ran into a problem I could not solve.

    Next I will show the code I used; please help me check whether I am using tianshou incorrectly. My code: Click

    The error only occurs when round(step_per_collect * update_per_step) >= 1, i.e., once steps start being collected. As soon as steps are collected, it leads to the following error. The error: Click

    I hope to solve this problem. It is quite possible that the issue lies in the way I use tianshou, but I really cannot find the cause. I would appreciate any help, thank you very much!

    opened by boninggogogo 0
  • Gymnasium Integration

    Changes:

    • Disclaimer in README
    • Replaced all occurrences of Gym with Gymnasium
    • Removed code that is now dead since we no longer need to support the old step API
    • Updated type hints to only allow new step API
    • Increased required version of envpool to support Gymnasium
    • Increased required version of PettingZoo to support Gymnasium
    • Updated PettingZooEnv to only use the new step API, removed hack to also support old API
    • I had to add some # type: ignore comments, due to new type hinting in Gymnasium. I'm not that familiar with type hinting but I believe that the issue is on the Gymnasium side and we are looking into it.
    • Had to update MyTestEnv to support options kwarg
    • Skip NNI tests because they still use OpenAI Gym

    Still need to do:

    • Update the Jupyter notebooks in docs
    • Check the entire code base for more dead code (from compatibility stuff)
    • Check the reset functions of all environments/wrappers in code base to make sure they use the options kwarg
    • Someone might want to check test_env_finite.py
    opened by Markus28 3
  • [Illegal moves] Illegal moves made by tictactoe agent

    • [x] I have marked all applicable categories:
      • [x] exception-raising bug
      • [ ] RL algorithm bug
      • [ ] documentation request (i.e. "X is missing from the documentation.")
      • [ ] new feature request
    • [x] I have visited the source website
    • [x] I have searched through the issue tracker for duplicates
    • [ ] I have mentioned version numbers, operating system and environment, where applicable:
      import tianshou, gym, torch, numpy, sys
      print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
      

    Versions: 0.4.10 0.26.3 1.13.1+cu117 1.23.5 3.9.15 | packaged by conda-forge | (main, Nov 22 2022, 15:55:03)

    Hi,

    I followed the script https://github.com/Farama-Foundation/PettingZoo/blob/master/tutorials/Tianshou/2_training_agents.py to train a tictactoe agent. After training the model, I tried to play against the trained agent, but it seems that the agent is making illegal moves. My code

    state_shape = env.observation_space["observation"].shape
    action_shape = env.action_space.n
    
    net = Net(state_shape=state_shape,
              action_shape=action_shape,
              hidden_sizes=[128, 128, 128, 128],
              device="cuda" if torch.cuda.is_available() else "cpu",
              ).to("cuda" if torch.cuda.is_available() else "cpu")
    optim = torch.optim.Adam(net.parameters(), lr=1e-4)
    policy = DQNPolicy(
                model=net,
                optim=optim,
                discount_factor=0.9,
                estimation_step=3,
                target_update_freq=320,
            )
    
    policy.load_state_dict(torch.load(train_path))
    agents = env.agents
    agent = agents[0]
    new_game = True
    
    policy.eval()
    while not done:
        action = env.action_space.sample()
        if new_game:
            action = env.action_space.sample()
        else:
            observation['obs'] = observation['obs'].reshape(-1, int(np.prod(state_shape)))  # Reshape observation
            action = policy(Batch(**observation)).act[0]
    
        observation, reward, done, truncated, info = env.step(action)
        
        if not done:
            player_action = int(input('User input starts with 1 to 7: ')) - 1
            observation, reward, done, truncated, info = env.step(player_action)
            observation['info'] = info
    
        new_game = False
    

    The game: image

    I've checked the mask. It looks correct.

    Anyone able to help?

    question 
    opened by wei-ann-Github 2
  • How to implement simple action mask for multiple policies?

    I'm using a custom env which requires an action mask in some situations (like when obs[0] == 0).

    I only found a mask implementation in DQN, where the action needs to be transformed into a dict.

    Since I'm experimenting with multiple discrete policies such as SAC/DQN/A2C/PPO, is there a universal way to implement an action mask in the forward function of these policies?

    An example: the environment has three actions (0/1/2); when obs[0] == 0, actions 0 and 2 are unavailable and only action 1 is available.

    def action_masks(self):
        if obs[0] == 0:
            return [False, True, False]
        else:
            return [True, True, True]

    question 
    opened by adiyaDalat 1
Releases(v0.4.11)
  • v0.4.11(Dec 24, 2022)

    Enhancement

    1. Hindsight Experience Replay as a replay buffer (#753, @Juno-T)
    2. Fix Atari PPO example (#780, @nuance1979)
    3. Update experiment details of MuJoCo benchmark (#779, @ChenDRAG)
    4. Tiny change since the tests are more than unit tests (#765, @fzyzcjy)

    Bug Fix

    1. Multi-agent: gym->gymnasium; render() update (#769, @WillDudley)
    2. Updated atari wrappers (#781, @Markus28)
    3. Fix info not pass issue in PGPolicy (#787, @Trinkle23897)
    Source code(tar.gz)
    Source code(zip)
    tianshou-0.4.11-py3-none-any.whl(159.82 KB)
    tianshou-0.4.11.tar.gz(120.13 KB)
  • v0.4.10(Oct 17, 2022)

    Enhancement

    1. Changes to support Gym 0.26.0 (#748, @Markus28)
    2. Added pre-commit (#752, @Markus28)
    3. Added support for new PettingZoo API (#751, @Markus28)
    4. Fix docs tictactoc dummy vector env (#749, @5cat)

    Bug fix

    1. Fix 2 bugs and refactor RunningMeanStd to support dict obs norm (#695, @Trinkle23897)
    2. Do not allow async simulation for test collector (#705, @CWHer)
    3. Fix venv wrapper reset retval error with gym env (#712, @Trinkle23897)
    Source code(tar.gz)
    Source code(zip)
    tianshou-0.4.10-py3-none-any.whl(156.15 KB)
    tianshou-0.4.10.tar.gz(117.32 KB)
  • v0.4.9(Jul 4, 2022)

    Bug Fix

    1. Fix save_checkpoint_fn return value to checkpoint_path (#659, @Trinkle23897)
    2. Fix an off-by-one bug in trainer iterator (#659, @Trinkle23897)
    3. Fix a bug in Discrete SAC evaluation; default to deterministic mode (#657, @nuance1979)
    4. Fix a bug in trainer about test reward not logged because self.env_step is not set for offline setting (#660, @nuance1979)
    5. Fix exception with watching pistonball environments (#663, @ycheng517)
    6. Use env.np_random.integers instead of env.np_random.randint in Atari examples (#613, @ycheng517)

    API Change

    1. Upgrade gym to >=0.23.1, support seed and return_info arguments for reset (#613, @ycheng517)

    New Features

    1. Add BranchDQN for large discrete action spaces (#618, @BFAnas)
    2. Add show_progress option for trainer (#641, @michalgregor)
    3. Added support for clipping to DQNPolicy (#642, @michalgregor)
    4. Implement TD3+BC for offline RL (#660, @nuance1979)
    5. Add multiDiscrete to discrete gym action space wrapper (#664, @BFAnas)

    Enhancement

    1. Use envpool in vizdoom example (#634, @Trinkle23897)
    2. Add Atari (discrete) SAC examples (#657, @nuance1979)
    Source code(tar.gz)
    Source code(zip)
    tianshou-0.4.9-py3-none-any.whl(152.88 KB)
    tianshou-0.4.9.tar.gz(110.43 KB)
  • v0.4.8(May 5, 2022)

    Bug fix

    1. Fix action scaling bug in SAC (#591, @ChenDRAG)

    Enhancement

    1. Add write_flush in two loggers, fix argument passing in WandbLogger (#581, @Trinkle23897)
    2. Update Multi-agent RL docs and upgrade pettingzoo (#595, @ycheng517)
    3. Add learning rate scheduler to BasePolicy (#598, @alexnikulkov)
    4. Add Jupyter notebook tutorials using Google Colaboratory (#599, @ChenDRAG)
    5. Unify utils.network: change action_dim to action_shape (#602, @Squeemos)
    6. Update Mujoco benchmark's webpage (#606, @ChenDRAG)
    7. Add Atari results (#600, @gogoduan) (#616, @ChenDRAG)
    8. Convert RL Unplugged Atari datasets to tianshou ReplayBuffer (#621, @nuance1979)
    9. Implement REDQ (#623, @Jimenius)
    10. Improve data loading from D4RL and convert RL Unplugged to D4RL format (#624, @nuance1979)
    11. Add vecenv wrappers for obs_norm to support running mujoco experiment with envpool (#628, @Trinkle23897)
    Source code(tar.gz)
    Source code(zip)
    tianshou-0.4.8-py3-none-any.whl(146.86 KB)
    tianshou-0.4.8.tar.gz(105.18 KB)
  • v0.4.7(Mar 21, 2022)

    Bug Fix

    1. Add map_action_inverse for fixing the error of storing random action (#568)

    API Change

    1. Update WandbLogger implementation and update Atari examples, use Tensorboard SummaryWritter as core with wandb.init(..., sync_tensorboard=True) (#558, #562)
    2. Rename save_fn to save_best_fn to avoid ambiguity (#575)
    3. (Internal) Add tianshou.utils.deprecation for a unified deprecation wrapper. (#575)

    New Features

    1. Implement Generative Adversarial Imitation Learning (GAIL), add Mujoco examples (#550)
    2. Add Trainers as generators: OnpolicyTrainer, OffpolicyTrainer, and OfflineTrainer; remove duplicated code and merge into base trainer (#559)

    Enhancement

    1. Add imitation baselines for offline RL (#566)
    Source code(tar.gz)
    Source code(zip)
    tianshou-0.4.7-py3-none-any.whl(140.71 KB)
    tianshou-0.4.7.tar.gz(103.11 KB)
  • v0.4.6.post1(Feb 25, 2022)

  • v0.4.6(Feb 25, 2022)

    Bug Fix

    1. Fix casts to int by to_torch_as(...) calls in policies when using discrete actions (#521)

    API Change

    1. Change venv internal API name of worker: send_action -> send, get_result -> recv (align with envpool) (#517)

    New Features

    1. Add Intrinsic Curiosity Module (#503)
    2. Implement CQLPolicy and offline_cql example (#506)
    3. Pettingzoo environment support (#494)
    4. Enable venvs.reset() concurrent execution (#517)

    Enhancement

    1. Remove reset_buffer() from reset method (#501)
    2. Add atari ppo example (#523, #529)
    3. Add VizDoom PPO example and results (#533)
    4. Upgrade gym version to >=0.21 (#534)
    5. Switch atari example to use EnvPool by default (#534)

    Documentation

    1. Update dqn tutorial and add envpool to docs (#526)
    Source code(tar.gz)
    Source code(zip)
    tianshou-0.4.6-py3-none-any.whl(133.92 KB)
    tianshou-0.4.6.tar.gz(98.83 KB)
  • v0.4.5(Nov 28, 2021)

    Bug Fix

    1. Fix tqdm issue (#481)
    2. Fix atari wrapper to be deterministic (#467)
    3. Add writer.flush() in TensorboardLogger to ensure real-time logging result (#485)

    Enhancement

    1. Implements set_env_attr and get_env_attr for vector environments (#478)
    2. Implement BCQPolicy and offline_bcq example (#480)
    3. Enable test_collector=None in 3 trainers to turn off testing during training (#485)
    4. Fix an inconsistency in the implementation of Discrete CRR. Now it uses Critic class for its critic, following conventions in other actor-critic policies (#485)
    5. Update several offline policies to use ActorCritic class for its optimizer to eliminate randomness caused by parameter sharing between actor and critic (#485)
    6. Move Atari offline RL examples to examples/offline and tests to test/offline (#485)
    Source code(tar.gz)
    Source code(zip)
    tianshou-0.4.5-py3-none-any.whl(126.49 KB)
    tianshou-0.4.5.tar.gz(94.50 KB)
  • v0.4.4(Oct 13, 2021)

    API Change

    1. add a new class DataParallelNet for multi-GPU training (#461)
    2. add ActorCritic for deterministic parameter grouping for share-head actor-critic network (#458)
    3. collector.collect() now returns 4 extra keys: rew/rew_std/len/len_std (previously this work is done in logger) (#459)
    4. rename WandBLogger -> WandbLogger (#441)

    Bug Fix

    1. fix logging in atari examples (#444)

    Enhancement

    1. save_fn() will be called at the beginning of trainer (#459)
    2. create a new page for logger (#463)
    3. add save_data and restore_data in wandb, allow more input arguments for wandb init, and integrate wandb into test/modelbase/test_psrl.py and examples/atari/atari_dqn.py (#441)
    Source code(tar.gz)
    Source code(zip)
    tianshou-0.4.4-py3-none-any.whl(122.50 KB)
    tianshou-0.4.4.tar.gz(91.53 KB)
  • v0.4.3(Sep 2, 2021)

    Bug Fix

    1. fix a2c/ppo optimizer bug when sharing head (#428)
    2. fix ppo dual clip implementation (#435)

    Enhancement

    1. add Rainbow (#386)
    2. add WandbLogger (#427)
    3. add env_id in preprocess_fn (#391)
    4. update README, add new chart and bibtex (#406)
    5. add Makefile, now you can use make commit-checks to automatically perform almost all checks (#432)
    6. add isort and yapf, apply to existing codebase (#432)
    7. add spelling check by using make spelling (#432)
    8. update contributing.rst (#432)
    Source code(tar.gz)
    Source code(zip)
    tianshou-0.4.3-py3-none-any.whl(121.00 KB)
    tianshou-0.4.3.tar.gz(90.37 KB)
  • v0.4.2(Jun 26, 2021)

    Enhancement

    1. Add model-free dqn family: IQN (#371), FQF (#376)
    2. Add model-free on-policy algorithm: NPG (#344, #347), TRPO (#337, #340)
    3. Add offline-rl algorithm: CQL (#359), CRR (#367)
    4. Support deterministic evaluation for onpolicy algorithms (#354)
    5. Make trainer resumable (#350)
    6. Support different state size and fix exception in venv.__del__ (#352, #384)
    7. Add vizdoom example (#384)
    8. Add numerical analysis tool and interactive plot (#335, #341)
    Source code(tar.gz)
    Source code(zip)
    tianshou-0.4.2-py3-none-any.whl(117.03 KB)
    tianshou-0.4.2.tar.gz(87.02 KB)
  • v0.4.1(Apr 4, 2021)

    API Change

    1. Add observation normalization in BaseVectorEnv (norm_obs, obs_rms, update_obs_rms and RunningMeanStd) (#308)
    2. Add policy.map_action to bound with raw action (e.g., map from (-inf, inf) to [-1, 1] by clipping or tanh squashing), and the mapped action won't store in replaybuffer (#313)
    3. Add lr_scheduler in on-policy algorithms, typically for LambdaLR (#318)

    Note

    To adapt with this version, you should change the action_range=... to action_space=env.action_space in policy initialization.

    Bug Fix

    1. Fix incorrect behaviors (error when n/ep==0 and reward shown in tqdm) with on-policy algorithm (#306, #328)
    2. Fix q-value mask_action error for obs_next (#310)

    Enhancement

    1. Release SOTA Mujoco benchmark (DDPG/TD3/SAC: #305, REINFORCE: #320, A2C: #325, PPO: #330) and add corresponding notes in /examples/mujoco/README.md
    2. Fix numpy>=1.20 typing issue (#323)
    3. Add cross-platform unittest (#331)
    4. Add a test on how to deal with finite env (#324)
    5. Add value normalization in on-policy algorithms (#319, #321)
    6. Separate advantage normalization and value normalization in PPO (#329)
    Source code(tar.gz)
    Source code(zip)
    tianshou-0.4.1-py3-none-any.whl(110.04 KB)
    tianshou-0.4.1.tar.gz(78.51 KB)
  • v0.4.0(Mar 2, 2021)

    This release contains several API and behavior changes.

    API Change

    Buffer

    1. Add ReplayBufferManager, PrioritizedReplayBufferManager, VectorReplayBuffer, PrioritizedVectorReplayBuffer, CachedReplayBuffer (#278, #280);
    2. Change buffer.add API from buffer.add(obs, act, rew, done, obs_next, info, policy, ...) to buffer.add(batch, buffer_ids) in order to add data more efficiently (#280);
    3. Add set_batch method in buffer (#278);
    4. Add sample_index method, same as sample but only return index instead of both index and batch data (#278);
    5. Add prev (one-step previous transition index), next (one-step next transition index) and unfinished_index (the last modified index whose done==False) (#278);
    6. Add internal method _alloc_by_keys_diff in batch to support any form of keys pop up (#280);

    Collector

    1. Rewrite the original Collector, split the async function to AsyncCollector: Collector only supports sync mode, AsyncCollector support both modes (#280);
    2. Drop collector.collect(n_episode=List[int]) because the new collector can collect episodes without bias (#280);
    3. Move reward_metric from Collector to trainer (#280);
    4. Change Collector.collect logic: AsyncCollector.collect's semantic is the same as previous version, where collect(n_step or n_episode) will not collect exact n_step or n_episode transitions; Collector.collect(n_step or n_episode)'s semantic now changes to exact n_step or n_episode collect (#280);

    Policy

    1. Add policy.exploration_noise(action, batch) -> action method instead of implemented in policy.forward() (#280);
    2. Add Timelimit.truncate handler in compute_*_returns (#296);
    3. remove ignore_done flag (#296);
    4. remove reward_normalization option in offpolicy-algorithm (will raise Error if set to True) (#298);

    Trainer

    1. Change collect_per_step to step_per_collect (#293);
    2. Add update_per_step and episode_per_collect (#293);
    3. onpolicy_trainer now supports either step_collect or episode_collect (#293)
    4. Add BasicLogger and LazyLogger to log data more conveniently (#295)

    Bug Fix

    1. Fix VectorEnv action_space seed randomness -- when calling env.seed(seed), it will also call env.action_space.seed(seed); otherwise, using Collector.collect(..., random=True) will produce different results each time (#300, #303).
    Source code(tar.gz)
    Source code(zip)
    tianshou-0.4.0-py3-none-any.whl(103.78 KB)
    tianshou-0.4.0.tar.gz(75.96 KB)
  • v0.3.2(Feb 16, 2021)

    Bug Fix

    1. fix networks under utils/discrete and utils/continuous cannot work well under CUDA+torch<=1.6.0 (#289)
    2. fix 2 bugs of Batch: creating keys in Batch.__setitem__ now throws ValueError instead of KeyError; _create_value now allows placeholder with stack=False option (#284)

    Enhancement

    1. Add QR-DQN algorithm (#276)
    2. small optimization for Batch.cat and Batch.stack (#284), now it is almost as fast as v0.2.3
    Source code(tar.gz)
    Source code(zip)
    tianshou-0.3.2-py3-none-any.whl(92.48 KB)
    tianshou-0.3.2.tar.gz(69.96 KB)
  • v0.3.1(Jan 20, 2021)

    API Change

    1. change utils.network args to support any form of MLP by default (#275), remove layer_num and hidden_layer_size, add hidden_sizes (a list of ints indicating the network architecture)
    2. add HDF5 save/load method for ReplayBuffer (#261)
    3. add offline_trainer (#263)
    4. move Atari-related network to examples/atari/atari_network.py (#275)

    Bug Fix

    1. fix a potential bug in discrete behavior cloning policy (#263)

    Enhancement

    1. update SAC mujoco result (#246)
    2. add C51 algorithm with benchmark result (#266)
    3. enable type checking in utils.network (#275)
    Source code(tar.gz)
    Source code(zip)
    tianshou-0.3.1-py3-none-any.whl(91.16 KB)
    tianshou-0.3.1.tar.gz(69.45 KB)
  • v0.3.0(Sep 26, 2020)

    Since the code has changed substantially from v0.2.0 at this point, we release it as version 0.3 from now on.

    API Change

    1. add policy.updating and clarify collecting state and updating state in training (#224)
    2. change train_fn(epoch) to train_fn(epoch, env_step) and test_fn(epoch) to test_fn(epoch, env_step) (#229)
    3. remove out-of-the-date API: collector.sample, collector.render, collector.seed, VectorEnv (#210)

    Bug Fix

    1. fix a bug in DDQN: target_q could not be sampled from np.random.rand (#224)
    2. fix a bug in DQN atari net: it should add a ReLU before the last layer (#224)
    3. fix a bug in collector timing (#224)
    4. fix a bug in the converter of Batch: deepcopy a Batch in to_numpy and to_torch (#213)
    5. ensure buffer.rew has a type of float (#229)

    Enhancement

    1. Anaconda support: conda install -c conda-forge tianshou (#228)
    2. add PSRL (#202)
    3. add SAC discrete (#216)
    4. add type check in unit test (#200)
    5. format code and update function signatures (#213)
    6. add pydocstyle and doc8 check (#210)
    7. several documentation fix (#210)
    Source code(tar.gz)
    Source code(zip)
    tianshou-0.3.0-py3-none-any.whl(80.95 KB)
    tianshou-0.3.0.tar.gz(63.50 KB)
  • v0.2.7(Sep 8, 2020)

    API Change

    1. exact n_episode for a list of n_episode limitation and save fake data in cache_buffer when self.buffer is None (#184)
    2. add save_only_last_obs for replay buffer in order to save the memory. (#184)
    3. remove default value in batch.split() and add merge_last argument (#185)
    4. fix tensorboard logging: h-axis stands for env step instead of gradient step; add test results into tensorboard (#189)
    5. add max_batchsize in onpolicy algorithms (#189)
    6. keep only sumtree in segment tree implementation (#193)
    7. add __contains__ and pop in batch: key in batch, batch.pop(key, deft) (#189)
    8. remove dict return support for collector preprocess_fn (#189)
    9. remove **kwargs in ReplayBuffer (#189)
    10. add no_grad argument in collector.collect (#204)

    Enhancement

    1. add DQN Atari examples (#187)
    2. change the type-checking order in batch.py and converter.py in order to meet the most often case first (#189)
    3. Numba acceleration for GAE, nstep, and segment tree (#193)
    4. add policy.eval() in all test scripts' "watch performance" (#189)
    5. add test_returns (both GAE and nstep) (#189)
    6. improve the code-coverage (from 90% to 95%) and remove the dead code (#189)
    7. polish examples/box2d/bipedal_hardcore_sac.py (#207)

    Bug fix

    1. fix a bug in MAPolicy: buffer.rew = Batch() doesn't change buffer.rew (thanks mypy) (#207)
    2. ~~set policy.eval() before collector.collect (#204)~~ This is a bug
    3. fix shape inconsistency for torch.Tensor in replay buffer (#189)
    4. potential bugfix for subproc.wait (#189)
    5. fix RecurrentActorProb (#189)
    6. fix some incorrect type annotation (#189)
    7. fix a bug in tictactoe set_eps (#193)
    8. dirty fix for asyncVenv check_id test
    Source code(tar.gz)
    Source code(zip)
    tianshou-0.2.7-py3-none-any.whl(76.32 KB)
  • v0.2.6(Aug 19, 2020)

    API Change

    1. Replay buffer allows stack_num = 1 (#165)
    2. add policy.update to enable post process and remove collector.sample (#180)
    3. Remove collector.close and rename VectorEnv to DummyVectorEnv (#179)

    Enhancement

    1. Enable async simulation for all vector envs (#179)
    2. Improve PER (#159): use segment tree and enable all Q-learning algorithms to use PER
    3. unify single-env and multi-env in collector (#157)
    4. Pickle compatible for replay buffer and improve buffer.get (#182): fix #84 and make buffer more efficient
    5. Add ShmemVectorEnv implementation (#174)
    6. Add Dueling DQN implementation (#170)
    7. Add profile workflow (#143)
    8. Add BipedalWalkerHardcore-v3 SAC example (#177) (it is well-trained in about 1 hour)

    Bug fix

    1. fix #162 of multi-dim action (#160)

    Note: 0.3 is coming soon!

    Source code(tar.gz)
    Source code(zip)
    tianshou-0.2.6-py3-none-any.whl(73.89 KB)
  • v0.2.5(Jul 22, 2020)

    New feature

    Multi-agent Reinforcement Learning: https://tianshou.readthedocs.io/en/latest/tutorials/tictactoe.html (#122)

    Documentation

    Add a tutorial of the Batch class to standardize the behavior of Batch: https://tianshou.readthedocs.io/en/latest/tutorials/batch.html (#142)

    Bugfix

    • Fix inconsistent shape in A2CPolicy and PPOPolicy. Please be careful when dealing with log_prob (#155)
    • Fix list of tensors inside Batch, e.g., Batch(a=[np.zeros(3), torch.zeros(3)]) (#147)
    • Fix buffer update when stack_num > 0 (#154)
    • Remove useless kwargs
    Source code(tar.gz)
    Source code(zip)
    tianshou-0.2.5-py3-none-any.whl(66.55 KB)
  • v0.2.4.post1(Jul 14, 2020)

    Several bug fix and enhancement:

    • remove deprecated API append (#126)
    • Batch.cat_ and Batch.stack_ is now working well with inconsistent keys (#130)
    • Batch.is_empty now correctly recognizes empty over empty Batch (#128)
    • reconstruct collector: remove multiple buffer case, change the internal data to Batch, and add reward_metric for MARL usage (#125)
    • add Batch.update to mimic dict.update (#128)
    Source code(tar.gz)
    Source code(zip)
    tianshou-0.2.4.post1-py3-none-any.whl(60.79 KB)
  • v0.2.4(Jul 10, 2020)

    Algorithm Implementation

    1. n_step returns for all Q-learning based algorithms; (#51)
    2. Auto alpha tuning in SAC (#80)
    3. Reserve policy._state to support saving hidden states in replay buffer (#19)
    4. Add sample_avail argument in ReplayBuffer to sample only available index in RNN training mode (#19)

    New Feature

    1. Batch.cat (#87), Batch.stack (#93), Batch.empty (#106, #110)
    2. Advanced slicing method of Batch (#106)
    3. Batch(kwargs, copy=True) will perform a deep copy (#110)
    4. Add random=True argument in collector.collect to perform sampling with random policy (#78)

    API Change

    1. Batch.append -> Batch.cat
    2. Remove atari wrapper to examples, since it is not a key feature in tianshou (#124)
    3. Add some pre-defined nets in tianshou.utils.net. Since we only define API instead of a class, we do not present it in tianshou.net. (#123)

    Docs

    Add cheatsheet: https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html

    Source code(tar.gz)
    Source code(zip)
    tianshou-0.2.4-py3-none-any.whl(60.11 KB)
  • v0.2.3(Jun 1, 2020)

  • v0.2.2(Apr 26, 2020)

    Algorithm Implementation

    1. Generalized Advantage Estimation (GAE);
    2. Update PPO algorithm with arXiv:1811.02553 and arXiv:1912.09729;
    3. Vanilla Imitation Learning (BC & DA, with continuous/discrete action space);
    4. Prioritized DQN;
    5. RNN-style policy network;
    6. Fix SAC with torch==1.5.0

    API change

    1. change __call__ to forward in policy;
    2. Add save_fn in trainer;
    3. Add __repr__ in tianshou.data, e.g. print(buffer)
    Source code(tar.gz)
    Source code(zip)
    tianshou-0.2.2-py3-none-any.whl(45.44 KB)
  • v0.2.1(Apr 7, 2020)

Owner
Tsinghua Machine Learning Group