
ppo benchmark released #330

Merged 7 commits on Mar 30, 2021

5 changes: 3 additions & 2 deletions README.md
@@ -23,7 +23,7 @@
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
- [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
- [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf)
- [C51](https://arxiv.org/pdf/1707.06887.pdf)
- [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf)
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
@@ -39,7 +39,8 @@

Here is Tianshou's other features:

- Elegant framework, using only ~2000 lines of code
- Elegant framework, using only ~3000 lines of code
- State-of-the-art [MuJoCo benchmark](https://github.com/thu-ml/tianshou/tree/master/examples/mujoco) for REINFORCE/A2C/PPO/DDPG/TD3/SAC algorithms
- Support parallel environment simulation (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#parallel-sampling)
- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training)
- Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
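A note on the benchmark line added above: the scripts it points to live under examples/mujoco/ in this PR. As a rough illustration of launching one of them from a checkout (the task, seed, and epoch values below are example choices, not the benchmark's official settings), one could do something like:

```python
# Illustrative launcher for the new PPO benchmark script added in this PR.
# The flags (--task, --seed, --epoch) exist in mujoco_ppo.py's get_args();
# the specific values are examples only.
import subprocess

subprocess.run([
    "python", "examples/mujoco/mujoco_ppo.py",
    "--task", "Ant-v3",   # one of the tasks with a benchmark figure in this PR
    "--seed", "0",
    "--epoch", "100",     # matches the script's default
], check=True)
```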
3 changes: 2 additions & 1 deletion docs/index.rst
@@ -13,7 +13,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN <https://arxiv.org/pdf/1511.06581.pdf>`_
* :class:`~tianshou.policy.C51Policy` `C51 <https://arxiv.org/pdf/1707.06887.pdf>`_
* :class:`~tianshou.policy.C51Policy` `Categorical DQN <https://arxiv.org/pdf/1707.06887.pdf>`_
* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
@@ -30,6 +30,7 @@ Welcome to Tianshou!
Here is Tianshou's other features:

* Elegant framework, using only ~2000 lines of code
* State-of-the-art `MuJoCo benchmark <https://github.com/thu-ml/tianshou/tree/master/examples/mujoco>`_
* Support parallel environment simulation (synchronous or asynchronous) for all algorithms: :ref:`parallel_sampling`
* Support recurrent state representation in actor network and critic network (RNN-style training for POMDP): :ref:`rnn_training`
* Support any type of environment state/action (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env`
265 changes: 151 additions & 114 deletions examples/mujoco/README.md

Large diffs are not rendered by default.

Binary file added examples/mujoco/benchmark/Ant-v3/ppo/figure.png
(Binary benchmark figure files added under examples/mujoco/benchmark/ cannot be displayed inline.)
34 changes: 22 additions & 12 deletions examples/mujoco/mujoco_a2c.py
@@ -3,6 +3,7 @@
import os
import gym
import torch
import pprint
import datetime
import argparse
import numpy as np
@@ -36,12 +37,6 @@ def get_args():
parser.add_argument('--batch-size', type=int, default=99999)
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--resume-path', type=str, default=None)
# a2c special
parser.add_argument('--rew-norm', type=int, default=True)
parser.add_argument('--vf-coef', type=float, default=0.5)
@@ -50,6 +45,14 @@
parser.add_argument('--bound-action-method', type=str, default="clip")
parser.add_argument('--lr-decay', type=int, default=True)
parser.add_argument('--max-grad-norm', type=float, default=0.5)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
return parser.parse_args()


@@ -120,6 +123,11 @@ def dist(*logits):
action_bound_method=args.bound_action_method,
lr_scheduler=lr_scheduler, action_space=env.action_space)

# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)

# collector
if args.training_num > 1:
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
@@ -138,12 +146,14 @@
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

# trainer
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
args.repeat_per_collect, args.test_num, args.batch_size,
step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
test_in_train=False)
if not args.watch:
# trainer
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
args.repeat_per_collect, args.test_num, args.batch_size,
step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
test_in_train=False)
pprint.pprint(result)

# Let's watch its performance!
policy.eval()
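The changes above introduce a `--resume-path`/`--watch` pair so a saved `policy.pth` can be reloaded and evaluated without re-training. A minimal sketch of that evaluation-only path, assuming `policy` and `test_envs` are constructed exactly as in `mujoco_a2c.py` (the helper name and defaults here are illustrative, not part of the PR):

```python
# Sketch of the new --watch flow: load a checkpoint and only evaluate it.
# `policy` and `test_envs` are assumed to be built as in mujoco_a2c.py;
# `watch_only` is an illustrative helper, not part of the PR.
import torch
from tianshou.data import Collector


def watch_only(policy, test_envs, resume_path, device="cpu",
               n_episode=10, render=0.0):
    # reload the weights saved by save_fn during training
    policy.load_state_dict(torch.load(resume_path, map_location=device))
    policy.eval()
    collector = Collector(policy, test_envs)
    result = collector.collect(n_episode=n_episode, render=render)
    print(f'Final reward: {result["rews"].mean()}, '
          f'length: {result["lens"].mean()}')
```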
22 changes: 13 additions & 9 deletions examples/mujoco/mujoco_ddpg.py
@@ -3,6 +3,7 @@
import os
import gym
import torch
import pprint
import datetime
import argparse
import numpy as np
@@ -44,6 +45,8 @@ def get_args():
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
return parser.parse_args()


@@ -87,11 +90,10 @@ def test_ddpg(args=get_args()):
tau=args.tau, gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
estimation_step=args.n_step, action_space=env.action_space)

# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(
args.resume_path, map_location=args.device
))
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)

# collector
@@ -113,12 +115,14 @@ def test_ddpg(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step, test_in_train=False)
if not args.watch:
# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step, test_in_train=False)
pprint.pprint(result)

# Let's watch its performance!
policy.eval()
175 changes: 175 additions & 0 deletions examples/mujoco/mujoco_ppo.py
@@ -0,0 +1,175 @@
#!/usr/bin/env python3

import os
import gym
import torch
import pprint
import datetime
import argparse
import numpy as np
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal

from tianshou.policy import PPOPolicy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='HalfCheetah-v3')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=4096)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=30000)
parser.add_argument('--step-per-collect', type=int, default=2048)
parser.add_argument('--repeat-per-collect', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--training-num', type=int, default=64)
parser.add_argument('--test-num', type=int, default=10)
# ppo special
parser.add_argument('--rew-norm', type=int, default=True)
# In theory, `vf-coef` will not make any difference if using Adam optimizer.
parser.add_argument('--vf-coef', type=float, default=0.25)
parser.add_argument('--ent-coef', type=float, default=0.0)
parser.add_argument('--gae-lambda', type=float, default=0.95)
parser.add_argument('--bound-action-method', type=str, default="clip")
parser.add_argument('--lr-decay', type=int, default=True)
parser.add_argument('--max-grad-norm', type=float, default=0.5)
parser.add_argument('--eps-clip', type=float, default=0.2)
parser.add_argument('--dual-clip', type=float, default=None)
parser.add_argument('--value-clip', type=int, default=0)
parser.add_argument('--norm-adv', type=int, default=0)
parser.add_argument('--recompute-adv', type=int, default=1)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
return parser.parse_args()


def test_ppo(args=get_args()):
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
print("Action range:", np.min(env.action_space.low),
np.max(env.action_space.high))
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)],
norm_obs=True)
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False)

# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
activation=nn.Tanh, device=args.device)
actor = ActorProb(net_a, args.action_shape, max_action=args.max_action,
unbounded=True, device=args.device).to(args.device)
net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
activation=nn.Tanh, device=args.device)
critic = Critic(net_c, device=args.device).to(args.device)
torch.nn.init.constant_(actor.sigma_param, -0.5)
for m in list(actor.modules()) + list(critic.modules()):
if isinstance(m, torch.nn.Linear):
# orthogonal initialization
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
torch.nn.init.zeros_(m.bias)
# do last policy layer scaling, this will make initial actions have (close to)
# 0 mean and std, and will help boost performances,
# see https://arxiv.org/abs/2006.05990, Fig.24 for details
for m in actor.mu.modules():
if isinstance(m, torch.nn.Linear):
torch.nn.init.zeros_(m.bias)
m.weight.data.copy_(0.01 * m.weight.data)

optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)

lr_scheduler = None
if args.lr_decay:
# decay learning rate to 0 linearly
max_update_num = np.ceil(
args.step_per_epoch / args.step_per_collect) * args.epoch

lr_scheduler = LambdaLR(
optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

def dist(*logits):
return Independent(Normal(*logits), 1)

policy = PPOPolicy(actor, critic, optim, dist, discount_factor=args.gamma,
gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm,
vf_coef=args.vf_coef, ent_coef=args.ent_coef,
reward_normalization=args.rew_norm, action_scaling=True,
action_bound_method=args.bound_action_method,
lr_scheduler=lr_scheduler, action_space=env.action_space,
eps_clip=args.eps_clip, value_clip=args.value_clip,
dual_clip=args.dual_clip, advantage_normalization=args.norm_adv,
recompute_advantage=args.recompute_adv)

# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)

# collector
if args.training_num > 1:
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
else:
buffer = ReplayBuffer(args.buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
# log
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_ppo'
log_path = os.path.join(args.logdir, args.task, 'ppo', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer, update_interval=100, train_interval=100)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

if not args.watch:
# trainer
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
args.repeat_per_collect, args.test_num, args.batch_size,
step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
test_in_train=False)
pprint.pprint(result)

# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')


if __name__ == '__main__':
test_ppo()
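One detail of the new script worth calling out: with `--lr-decay` enabled, `LambdaLR` multiplies the base learning rate by `1 - step / max_update_num` on every scheduler step, so the rate reaches zero after roughly `ceil(step_per_epoch / step_per_collect) * epoch` policy updates. A standalone sketch of that schedule using the script's default numbers (the throwaway parameter exists only so the optimizer has something to hold):

```python
# Standalone sketch of the linear lr decay used in mujoco_ppo.py.
# Numbers mirror the script's defaults: 100 epochs, 30000 steps/epoch,
# 2048 steps/collect -> about 1500 scheduler steps in total.
import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR

max_update_num = int(np.ceil(30000 / 2048) * 100)
param = torch.nn.Parameter(torch.zeros(1))  # throwaway parameter
optim = torch.optim.Adam([param], lr=3e-4)
lr_scheduler = LambdaLR(optim, lr_lambda=lambda step: 1 - step / max_update_num)

for _ in range(3):
    optim.step()
    lr_scheduler.step()
print(optim.param_groups[0]["lr"])  # slightly below 3e-4, decaying toward 0
```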
32 changes: 21 additions & 11 deletions examples/mujoco/mujoco_reinforce.py
@@ -3,6 +3,7 @@
import os
import gym
import torch
import pprint
import datetime
import argparse
import numpy as np
@@ -36,17 +37,19 @@ def get_args():
parser.add_argument('--batch-size', type=int, default=99999)
parser.add_argument('--training-num', type=int, default=64)
parser.add_argument('--test-num', type=int, default=10)
# reinforce special
parser.add_argument('--rew-norm', type=int, default=True)
# "clip" option also works well.
parser.add_argument('--action-bound-method', type=str, default="tanh")
parser.add_argument('--lr-decay', type=int, default=True)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--resume-path', type=str, default=None)
# reinforce special
parser.add_argument('--rew-norm', type=int, default=True)
# "clip" option also works well.
parser.add_argument('--action-bound-method', type=str, default="tanh")
parser.add_argument('--lr-decay', type=int, default=True)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
return parser.parse_args()


@@ -110,6 +113,11 @@ def dist(*logits):
action_bound_method=args.action_bound_method,
lr_scheduler=lr_scheduler, action_space=env.action_space)

# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)

# collector
if args.training_num > 1:
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
@@ -128,12 +136,14 @@ def dist(*logits):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

# trainer
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
args.repeat_per_collect, args.test_num, args.batch_size,
step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
test_in_train=False)
if not args.watch:
# trainer
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
args.repeat_per_collect, args.test_num, args.batch_size,
step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
test_in_train=False)
pprint.pprint(result)

# Let's watch its performance!
policy.eval()
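Across the A2C, PPO, and REINFORCE scripts, `dist(*logits)` wraps the actor's `(mu, sigma)` output in `Independent(Normal(...), 1)`, i.e. a diagonal Gaussian whose `log_prob` is summed over action dimensions rather than reported per dimension. A small self-contained illustration (batch and action sizes below are arbitrary example values):

```python
# What Independent(Normal(mu, sigma), 1) buys over a plain Normal:
# log_prob returns one value per batch element instead of one per action dim.
import torch
from torch.distributions import Independent, Normal


def dist(*logits):
    return Independent(Normal(*logits), 1)


mu, sigma = torch.zeros(4, 6), torch.ones(4, 6)  # example batch of 4, 6-dim actions
action = dist(mu, sigma).sample()                # shape (4, 6)
per_dim = Normal(mu, sigma).log_prob(action)     # shape (4, 6)
joint = dist(mu, sigma).log_prob(action)         # shape (4,)
assert torch.allclose(joint, per_dim.sum(-1))
```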