Add TRPO policy #337

Merged (25 commits) on Apr 16, 2021
5 changes: 3 additions & 2 deletions README.md
@@ -19,15 +19,16 @@
**Tianshou** ([天授](https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88)) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, an unfriendly API, or run slowly, Tianshou provides a fast, modularized framework and a pythonic API for building deep reinforcement learning agents with the fewest possible lines of code. The supported interface algorithms currently include:


- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
- [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
- [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf)
- [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf)
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
- [Trust Region Policy Optimization](https://arxiv.org/pdf/1502.05477.pdf)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
5 changes: 5 additions & 0 deletions docs/api/tianshou.policy.rst
@@ -48,6 +48,11 @@ On-policy
:undoc-members:
:show-inheritance:

.. autoclass:: tianshou.policy.TRPOPolicy
:members:
:undoc-members:
:show-inheritance:

.. autoclass:: tianshou.policy.PPOPolicy
:members:
:undoc-members:
7 changes: 4 additions & 3 deletions docs/index.rst
@@ -9,15 +9,16 @@ Welcome to Tianshou!

**Tianshou** (`天授 <https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88>`_) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, an unfriendly API, or run slowly, Tianshou provides a fast framework and a pythonic API for building deep reinforcement learning agents. The supported interface algorithms include:

* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN <https://arxiv.org/pdf/1511.06581.pdf>`_
* :class:`~tianshou.policy.C51Policy` `Categorical DQN <https://arxiv.org/pdf/1707.06887.pdf>`_
* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
* :class:`~tianshou.policy.TRPOPolicy` `Trust Region Policy Optimization <https://arxiv.org/pdf/1502.05477.pdf>`_
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
@@ -84,7 +85,7 @@ Tianshou is still under development, you can also check out the documents in sta
tutorials/concepts
tutorials/batch
tutorials/tictactoe
tutorials/trick
tutorials/benchmark
tutorials/cheatsheet


14 changes: 14 additions & 0 deletions docs/tutorials/benchmark.rst
@@ -0,0 +1,14 @@
Benchmark
=========

Mujoco Benchmark
----------------

Tianshou's Mujoco benchmark contains state-of-the-art results (even better than `SpinningUp <https://spinningup.openai.com/en/latest/spinningup/bench.html>`_!).

Please refer to https://github.com/thu-ml/tianshou/tree/master/examples/mujoco

Atari Benchmark
---------------

Please refer to https://github.com/thu-ml/tianshou/tree/master/examples/atari
10 changes: 5 additions & 5 deletions docs/tutorials/cheatsheet.rst
@@ -101,7 +101,7 @@ This is related to `Issue 42 <https://github.com/thu-ml/tianshou/issues/42>`_.

If you want to gather log statistics from the data stream, pre-process batched images, or modify the reward with the given env info, use ``preprocess_fn`` in :class:`~tianshou.data.Collector`. This is a hook that will be called before the data is added to the buffer.

This function receives up to 7 keys ``obs``, ``act``, ``rew``, ``done``, ``obs_next``, ``info``, and ``policy``, as listed in :class:`~tianshou.data.Batch`. It returns the modified part within a :class:`~tianshou.data.Batch`. Only ``obs`` is defined at env.reset, while every key is specified for normal steps.
It receives only "obs" when the collector resets the environment, and receives five keys "obs_next", "rew", "done", "info", "policy" in a normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with the modified keys and values.

These variables are intended to gather all the information required to keep track of a simulation step, namely the (observation, action, reward, done flag, next observation, info, intermediate result of the policy) at time t, for the whole duration of the simulation.

@@ -122,7 +122,7 @@ For example, you can write your hook as:
def preprocess_fn(**kwargs):
"""change reward to zero mean"""
# if only obs exist -> reset
# if obs/act/rew/done/... exist -> normal step
# if obs_next/act/rew/done/policy exist -> normal step
if 'rew' not in kwargs:
# means that it is called after env.reset(), it can only process the obs
return Batch() # none of the variables are needed to be updated
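
The diff view truncates the hook above. A minimal, self-contained sketch of such a hook is given below, assuming a simple running-mean reward normalizer; the ``RewardNormalizer`` class and its state are illustrative, not part of Tianshou:

import numpy as np
from tianshou.data import Batch


class RewardNormalizer:
    """Illustrative hook: shift rewards by their running mean."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0

    def preprocess_fn(self, **kwargs):
        # only "obs" is present -> the collector just reset the environment
        if 'rew' not in kwargs:
            return Batch()  # nothing needs to be updated
        # normal step: update the running mean and return the shifted reward
        rew = np.asarray(kwargs['rew'], dtype=np.float64)
        for r in rew.flatten():
            self.count += 1
            self.mean += (r - self.mean) / self.count
        return Batch(rew=rew - self.mean)

The hook can then be registered as ``Collector(policy, envs, buffer, preprocess_fn=RewardNormalizer().preprocess_fn)``, assuming the Collector's ``preprocess_fn`` keyword described in this cheatsheet.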
@@ -163,10 +163,10 @@ First, add an argument "stack_num" to :class:`~tianshou.data.ReplayBuffer`, :cla

Then, change the network to recurrent-style, for example, :class:`~tianshou.utils.net.common.Recurrent`, :class:`~tianshou.utils.net.continuous.RecurrentActorProb` and :class:`~tianshou.utils.net.continuous.RecurrentCritic`.
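
A minimal sketch of these two steps, assuming a ``VectorReplayBuffer`` and the ``Recurrent`` constructor used elsewhere in this PR (the shapes and sizes below are placeholder values, and the exact signatures should be checked against the installed version):

from tianshou.data import VectorReplayBuffer
from tianshou.utils.net.common import Recurrent

state_shape, action_shape = (4,), (2,)  # placeholder shapes for illustration

# step 1: ask the buffer to return stacked observations (last 4 frames here)
buffer = VectorReplayBuffer(total_size=20000, buffer_num=8, stack_num=4)

# step 2: switch to a recurrent network that can consume the stacked input
net = Recurrent(layer_num=1, state_shape=state_shape, action_shape=action_shape,
                device='cpu', hidden_layer_size=128)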

The above code supports only stacked-observation. If you want to use stacked-action (for Q(stacked-s, stacked-a)), stacked-reward, or other stacked variables, you can add a ``gym.Wrapper`` to modify the state representation. For example, if we add a wrapper that map [s, a] pair to a new state:
The above code supports only stacked-observation. If you want to use stacked-action (for Q(stacked-s, stacked-a)), stacked-reward, or other stacked variables, you can add a ``gym.Wrapper`` to modify the state representation. For example, if we add a wrapper that maps the ``[s, a]`` pair to a new state (a sketch of such a wrapper follows the list below):

- Before: (s, a, s', r, d) stored in replay buffer, and get stacked s;
- After applying wrapper: ([s, a], a, [s', a'], r, d) stored in replay buffer, and get both stacked s and a.
- Before: ``(s, a, s', r, d)`` stored in replay buffer, and get stacked s;
- After applying wrapper: ``([s, a], a, [s', a'], r, d)`` stored in replay buffer, and get both stacked s and a.
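
A minimal sketch of such a wrapper, assuming a flat Box observation space and a continuous action space; the ``StateActionWrapper`` name and the zero-padding on reset are illustrative choices, not part of Tianshou:

import gym
import numpy as np


class StateActionWrapper(gym.Wrapper):
    """Concatenate the previous action onto the observation: s -> [s, a]."""

    def __init__(self, env):
        super().__init__(env)
        self._act_dim = int(np.prod(env.action_space.shape))
        low = np.concatenate([env.observation_space.low.flatten(),
                              env.action_space.low.flatten()])
        high = np.concatenate([env.observation_space.high.flatten(),
                               env.action_space.high.flatten()])
        self.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        # no action has been taken yet, so pad the action slot with zeros
        return np.concatenate([np.asarray(obs, dtype=np.float32).flatten(),
                               np.zeros(self._act_dim, dtype=np.float32)])

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        obs = np.concatenate([np.asarray(obs, dtype=np.float32).flatten(),
                              np.asarray(action, dtype=np.float32).flatten()])
        return obs, rew, done, info

With this wrapper the buffer stores ``([s, a], a, [s', a'], r, d)``, so frame-stacking over ``obs`` yields both stacked states and stacked actions.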


.. _self_defined_env:
88 changes: 0 additions & 88 deletions docs/tutorials/trick.rst

This file was deleted.

4 changes: 1 addition & 3 deletions test/continuous/test_ppo.py
@@ -52,7 +52,6 @@ def get_args():


def test_ppo(args=get_args()):
torch.set_num_threads(1) # we just need only one thread for NN
env = gym.make(args.task)
if args.task == 'Pendulum-v0':
env.spec.reward_threshold = -250
@@ -110,8 +109,7 @@ def dist(*logits):
# collector
train_collector = Collector(
policy, train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
exploration_noise=True)
VectorReplayBuffer(args.buffer_size, len(train_envs)))
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'ppo')
141 changes: 141 additions & 0 deletions test/continuous/test_trpo.py
@@ -0,0 +1,141 @@
import os
import gym
import torch
import pprint
import argparse
import numpy as np
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal

from tianshou.policy import TRPOPolicy
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils.net.continuous import ActorProb, Critic


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--buffer-size', type=int, default=50000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.95)
parser.add_argument('--epoch', type=int, default=5)
parser.add_argument('--step-per-epoch', type=int, default=50000)
parser.add_argument('--step-per-collect', type=int, default=2048)
parser.add_argument('--repeat-per-collect', type=int,
default=2) # theoretically it should be 1
parser.add_argument('--batch-size', type=int, default=99999)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
# trpo special
parser.add_argument('--gae-lambda', type=float, default=0.95)
parser.add_argument('--rew-norm', type=int, default=1)
parser.add_argument('--norm-adv', type=int, default=1)
parser.add_argument('--optim-critic-iters', type=int, default=5)
parser.add_argument('--max-kl', type=float, default=0.01)
parser.add_argument('--backtrack-coeff', type=float, default=0.8)
parser.add_argument('--max-backtracks', type=int, default=10)

args = parser.parse_known_args()[0]
return args


def test_trpo(args=get_args()):
env = gym.make(args.task)
if args.task == 'Pendulum-v0':
env.spec.reward_threshold = -250
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
# you can also use tianshou.env.SubprocVectorEnv
# train_envs = gym.make(args.task)
train_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
activation=nn.Tanh, device=args.device)
actor = ActorProb(net, args.action_shape, max_action=args.max_action,
unbounded=True, device=args.device).to(args.device)
critic = Critic(Net(
args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device,
activation=nn.Tanh), device=args.device).to(args.device)
# orthogonal initialization
for m in list(actor.modules()) + list(critic.modules()):
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)

    # replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward
def dist(*logits):
return Independent(Normal(*logits), 1)

policy = TRPOPolicy(
actor, critic, optim, dist,
discount_factor=args.gamma,
reward_normalization=args.rew_norm,
advantage_normalization=args.norm_adv,
gae_lambda=args.gae_lambda,
action_space=env.action_space,
optim_critic_iters=args.optim_critic_iters,
max_kl=args.max_kl,
backtrack_coeff=args.backtrack_coeff,
max_backtracks=args.max_backtracks)
# collector
train_collector = Collector(
policy, train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)))
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'trpo')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold

# trainer
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
step_per_collect=args.step_per_collect, stop_fn=stop_fn, save_fn=save_fn,
logger=logger)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


if __name__ == '__main__':
test_trpo()
3 changes: 1 addition & 2 deletions test/discrete/test_a2c_with_il.py
@@ -86,8 +86,7 @@ def test_a2c_with_il(args=get_args()):
# collector
train_collector = Collector(
policy, train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
exploration_noise=True)
VectorReplayBuffer(args.buffer_size, len(train_envs)))
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'a2c')
3 changes: 1 addition & 2 deletions test/discrete/test_pg.py
@@ -73,8 +73,7 @@ def test_pg(args=get_args()):
# collector
train_collector = Collector(
policy, train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
exploration_noise=True)
VectorReplayBuffer(args.buffer_size, len(train_envs)))
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'pg')
3 changes: 1 addition & 2 deletions test/discrete/test_ppo.py
@@ -93,8 +93,7 @@ def test_ppo(args=get_args()):
# collector
train_collector = Collector(
policy, train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
exploration_noise=True)
VectorReplayBuffer(args.buffer_size, len(train_envs)))
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'ppo')
3 changes: 1 addition & 2 deletions test/discrete/test_sac.py
@@ -91,8 +91,7 @@ def test_discrete_sac(args=get_args()):
# collector
train_collector = Collector(
policy, train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
exploration_noise=True)
VectorReplayBuffer(args.buffer_size, len(train_envs)))
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.buffer_size)
# log