
A2C benchmark for mujoco #325

Merged (15 commits) on Mar 28, 2021
40 changes: 40 additions & 0 deletions examples/mujoco/README.md
@@ -16,6 +16,7 @@ Supported algorithms are listed below:
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec)
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec)
- [REINFORCE algorithm](https://papers.nips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e27b5a26f330de446fe15388bf81c3777f024fb9)
- A2C, commit id (TODO)

## Offpolicy algorithms

@@ -149,6 +150,45 @@ By comparison to both classic literature and open source implementations (e.g.,
5. We didn't tune the `step-per-collect` and `training-num` options. The default values were fine-tuned for the PPO algorithm, so we assume they are also good for REINFORCE. You can play with them if you want, but remember that `buffer-size` should always be larger than `step-per-collect`; if `step-per-collect` is too small and `training-num` too large, episodes will be truncated and bootstrapped very often, which harms performance. If `training-num` is too small (e.g., less than 8), speed will go down.
6. The sigma of the action distribution is neither fixed (as commonly seen in other implementations) nor conditioned on the observation; it is an independent parameter updated by gradient descent. We choose this setting because it works well in PPO and is recommended by [Andrychowicz, Marcin, et al](https://arxiv.org/abs/2006.05990), see Fig. 23. A minimal sketch of this parameterization is given below.
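
For reference, here is a minimal sketch of what a state-independent, learnable sigma looks like. This is a hypothetical module, not Tianshou's exact `ActorProb` implementation; the A2C script added in this PR initializes the corresponding `actor.sigma_param` to -0.5 in the same spirit.

```python
import torch
from torch import nn


class GaussianHead(nn.Module):
    """Sketch only: the mean comes from the network, log-sigma is a free parameter."""

    def __init__(self, feature_dim: int, action_dim: int):
        super().__init__()
        self.mu = nn.Linear(feature_dim, action_dim)
        # one log-sigma per action dimension, shared across all observations,
        # updated by gradient descent together with the rest of the model
        self.log_sigma = nn.Parameter(torch.full((action_dim,), -0.5))

    def forward(self, feature: torch.Tensor):
        mu = self.mu(feature)
        sigma = self.log_sigma.exp().expand_as(mu)
        return mu, sigma
```

The `Normal(mu, sigma)` distribution built from this output is what the policy gradient differentiates through, so sigma shrinks or grows purely from the learning signal rather than from a hand-tuned schedule.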

### A2C

| Environment | Tianshou (3M steps) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench_vpg.html) |
| :--------------------: | :----------------: | :--------------------: |
| Ant | **5236.8+-236.7** | ~5 |
| HalfCheetah | **2377.3+-1363.7** | ~600 |
| Hopper | **1608.6+-529.5** | ~800 |
| Walker2d | **1805.4+-1055.9** | ~460 |
| Swimmer | 40.2+-1.8 | **~51** |
| Humanoid | **5316.6+-554.8** | N |
| Reacher | **-5.2+-0.5** | N |
| InvertedPendulum | **1000.0+-0.0** | N |
| InvertedDoublePendulum | **9351.3+-12.8** | N |

| Environment | Tianshou | [PPO paper](https://arxiv.org/abs/1707.06347) A2C | [PPO paper](https://arxiv.org/abs/1707.06347) A2C + Trust Region |
| :--------------------: | :----------------: | :-------------: | :-------------: |
| Ant | **3485.4+-433.1** | N | N |
| HalfCheetah | **1829.9+-1068.3** | ~1000 | ~930 |
| Hopper | **1253.2+-458.0** | ~900 | ~1220 |
| Walker2d | **1091.6+-709.2** | ~850 | ~700 |
| Swimmer | **36.6+-2.1** | ~31 | **~36** |
| Humanoid | **1726.0+-1070.1** | N | N |
| Reacher | **-6.7+-2.3** | ~-24 | ~-27 |
| InvertedPendulum | **1000.0+-0.0** | **~1000** | **~1000** |
| InvertedDoublePendulum | **9257.7+-277.4** | ~7100 | ~8100 |

\* details<sup>[[5]](#footnote5)</sup><sup>[[6]](#footnote6)</sup>

#### Hints for A2C

0. We choose the `clip` action bounding method in A2C instead of the `tanh` option used in REINFORCE, simply to be consistent with the original implementation. `tanh` may work equally well or better, but we didn't try it.
1. The (initial) learning rate, lr decay, `step-per-collect`, and `training-num` affect the performance of A2C to a great extent. These 4 hyperparameters also interact with each other and should be tuned together. We have done full-scale ablation studies on these 4 hyperparameters (more than 800 agents trained); below are our findings.
2. `step-per-collect`/`training-num` = `bootstrap-length`, the maximum length of an "episode" segment used in the GAE estimator; with the default settings this is 80/16 = 5. When `bootstrap-length` is small, (maybe) because GAE can look at most 5 steps ahead and has to bootstrap very often, the critic is less well trained, so the actor cannot converge to very high scores. However, if we increase `step-per-collect` to increase `bootstrap-length` (e.g. 256/16 = 16), the actor and critic are updated less often, so sample efficiency is lower and training becomes slow. To conclude: if you don't restrict env timesteps, you can try a larger `bootstrap-length` and train for more steps, which will perhaps give you better converged scores (see the sketch after this list). Train slower, achieve higher.
3. A 7e-4 learning rate with a decay strategy is appropriate for `step-per-collect=80` and `training-num=16`, but if you use a larger `step-per-collect` (e.g. 256 - 2048), 7e-4 is a little bit small, because now you have more data and less noise for each update and can be more confident taking larger steps; a higher learning rate (e.g. 1e-3) is more appropriate and usually boosts performance in this setting. If the training curve rises fast in early stages and becomes unstable later, consider lr decay before decreasing lr.
4. `max-grad-norm` doesn't really help in our experiments. We simply keep it for consistency with other open-source implementations (e.g. SB3).
5. The original A3C paper uses the RMSprop optimizer; we find that Adam with the same learning rate works equally well. We use RMSprop anyway, again for consistency.
6. We notice that SB3's implementation of A2C sets `gae-lambda` to 1 by default; we don't know why, and after doing some experiments our results show that 0.95 is better overall.
7. We find that `step-per-collect=256` with `training-num=8` is also a good hyperparameter combination. You can give it a try.
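
To make hint 2 above concrete, below is a small sketch of advantage estimation over one collected segment (not Tianshou's implementation; the function name and the toy inputs are illustrative only, and episode-termination flags are omitted). The segment length is exactly `bootstrap-length`, so with the default 80/16 = 5 the estimator can look at most 5 steps ahead before it has to bootstrap from the critic's value of the state right after the segment.

```python
import numpy as np


def truncated_gae(rewards, values, last_value, gamma=0.99, gae_lambda=0.95):
    """GAE over a single segment of length `bootstrap-length`.

    `values` holds the critic's V(s_t) for each step of the segment and
    `last_value` is V(s_T) of the state right after the segment, used to
    bootstrap because the segment is cut off rather than run to termination.
    """
    values = np.append(values, last_value)
    advantages = np.zeros_like(rewards, dtype=np.float64)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * gae_lambda * gae
        advantages[t] = gae
    return advantages


# default setting: step-per-collect=80, training-num=16 -> segments of length 5
adv = truncated_gae(rewards=np.ones(5), values=np.zeros(5), last_value=0.0)
```

A longer segment lets the estimator rely more on observed rewards and less on the critic, at the cost of fewer policy updates per environment step, which is exactly the trade-off described in hint 2.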

## Note

<a name="footnote1">[1]</a> Supported environments include HalfCheetah-v3, Hopper-v3, Swimmer-v3, Walker2d-v3, Ant-v3, Humanoid-v3, Reacher-v2, InvertedPendulum-v2 and InvertedDoublePendulum-v2. Pusher, Thrower, Striker and HumanoidStandup are not supported because they are not commonly seen in literatures.
Binary file added examples/mujoco/benchmark/Ant-v3/a2c/figure.png
157 changes: 157 additions & 0 deletions examples/mujoco/mujoco_a2c.py
@@ -0,0 +1,157 @@
#!/usr/bin/env python3

import os
import gym
import torch
import datetime
import argparse
import numpy as np
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal

from tianshou.policy import A2CPolicy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='HalfCheetah-v3')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--buffer-size', type=int, default=4096)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
    parser.add_argument('--lr', type=float, default=7e-4)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=30000)
    parser.add_argument('--step-per-collect', type=int, default=80)
    parser.add_argument('--repeat-per-collect', type=int, default=1)
    # batch-size >> step-per-collect means calculating all collected data
    # in one single forward pass.
    parser.add_argument('--batch-size', type=int, default=99999)
    parser.add_argument('--training-num', type=int, default=16)
    parser.add_argument('--test-num', type=int, default=10)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--resume-path', type=str, default=None)
    # a2c special
    parser.add_argument('--rew-norm', type=int, default=True)
    parser.add_argument('--vf-coef', type=float, default=0.5)
    parser.add_argument('--ent-coef', type=float, default=0.01)
    parser.add_argument('--gae-lambda', type=float, default=0.95)
    parser.add_argument('--bound-action-method', type=str, default="clip")
    parser.add_argument('--lr-decay', type=int, default=True)
    parser.add_argument('--max-grad-norm', type=float, default=0.5)
    return parser.parse_args()


def test_a2c(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low),
          np.max(env.action_space.high))
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        norm_obs=True)
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False)

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
                activation=nn.Tanh, device=args.device)
    actor = ActorProb(net_a, args.action_shape, max_action=args.max_action,
                      unbounded=True, device=args.device).to(args.device)
    net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
                activation=nn.Tanh, device=args.device)
    critic = Critic(net_c, device=args.device).to(args.device)
    torch.nn.init.constant_(actor.sigma_param, -0.5)
    for m in list(actor.modules()) + list(critic.modules()):
        if isinstance(m, torch.nn.Linear):
            # orthogonal initialization
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
    # do last policy layer scaling, this will make initial actions have (close to)
    # 0 mean and std, and will help boost performances,
    # see https://arxiv.org/abs/2006.05990, Fig.24 for details
    for m in actor.mu.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.zeros_(m.bias)
            m.weight.data.copy_(0.01 * m.weight.data)

    optim = torch.optim.RMSprop(set(actor.parameters()).union(critic.parameters()),
                                lr=args.lr, eps=1e-5, alpha=0.99)

    lr_scheduler = None
    if args.lr_decay:
        # decay learning rate to 0 linearly
        max_update_num = np.ceil(
            args.step_per_epoch / args.step_per_collect) * args.epoch

        lr_scheduler = LambdaLR(
            optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

    def dist(*logits):
        return Independent(Normal(*logits), 1)

    policy = A2CPolicy(actor, critic, optim, dist, discount_factor=args.gamma,
                       gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm,
                       vf_coef=args.vf_coef, ent_coef=args.ent_coef,
                       reward_normalization=args.rew_norm, action_scaling=True,
                       action_bound_method=args.bound_action_method,
                       lr_scheduler=lr_scheduler, action_space=env.action_space)

    # collector
    if args.training_num > 1:
        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    else:
        buffer = ReplayBuffer(args.buffer_size)
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    # log
    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_a2c'
    log_path = os.path.join(args.logdir, args.task, 'a2c', log_file)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = BasicLogger(writer, update_interval=100, train_interval=100)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    # trainer
    result = onpolicy_trainer(
        policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
        args.repeat_per_collect, args.test_num, args.batch_size,
        step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
        test_in_train=False)

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')


if __name__ == '__main__':
    test_a2c()
6 changes: 3 additions & 3 deletions examples/mujoco/mujoco_ddpg.py
@@ -103,9 +103,9 @@ def test_ddpg(args=get_args()):
     test_collector = Collector(policy, test_envs)
     train_collector.collect(n_step=args.start_timesteps, random=True)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'ddpg', 'seed_' + str(args.seed) +
-                            '_' + datetime.datetime.now().strftime('%m%d_%H%M%S') +
-                            '-' + args.task.replace('-', '_') + '_ddpg')
+    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
+    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_ddpg'
+    log_path = os.path.join(args.logdir, args.task, 'ddpg', log_file)
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
     logger = BasicLogger(writer)
2 changes: 1 addition & 1 deletion examples/mujoco/mujoco_reinforce.py
@@ -123,7 +123,7 @@ def dist(*logits):
     log_path = os.path.join(args.logdir, args.task, 'reinforce', log_file)
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = BasicLogger(writer, update_interval=10)
+    logger = BasicLogger(writer, update_interval=10, train_interval=100)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
6 changes: 3 additions & 3 deletions examples/mujoco/mujoco_sac.py
@@ -115,9 +115,9 @@ def test_sac(args=get_args()):
     test_collector = Collector(policy, test_envs)
     train_collector.collect(n_step=args.start_timesteps, random=True)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'sac', 'seed_' + str(args.seed) +
-                            '_' + datetime.datetime.now().strftime('%m%d_%H%M%S') +
-                            '-' + args.task.replace('-', '_') + '_sac')
+    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
+    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_sac'
+    log_path = os.path.join(args.logdir, args.task, 'sac', log_file)
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
     logger = BasicLogger(writer)
6 changes: 3 additions & 3 deletions examples/mujoco/mujoco_td3.py
@@ -117,9 +117,9 @@ def test_td3(args=get_args()):
     test_collector = Collector(policy, test_envs)
     train_collector.collect(n_step=args.start_timesteps, random=True)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'td3', 'seed_' + str(args.seed) +
-                            '_' + datetime.datetime.now().strftime('%m%d_%H%M%S') +
-                            '-' + args.task.replace('-', '_') + '_td3')
+    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
+    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_td3'
+    log_path = os.path.join(args.logdir, args.task, 'td3', log_file)
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
     logger = BasicLogger(writer)
2 changes: 1 addition & 1 deletion tianshou/policy/modelfree/a2c.py
@@ -120,7 +120,7 @@ def learn(  # type: ignore
                     - self._weight_ent * ent_loss
                 self.optim.zero_grad()
                 loss.backward()
-                if self._grad_norm is not None:  # clip large gradient
+                if self._grad_norm:  # clip large gradient
                     nn.utils.clip_grad_norm_(
                         list(self.actor.parameters()) + list(self.critic.parameters()),
                         max_norm=self._grad_norm)
4 changes: 2 additions & 2 deletions tianshou/policy/modelfree/ppo.py
@@ -96,7 +96,7 @@ def process_fn(
                 np.sqrt(self.ret_rms.var + self._eps)
             self.ret_rms.update(unnormalized_returns)
             mean, std = np.mean(advantages), np.std(advantages)
-            advantages = (advantages - mean) / std  # per-batch norm
+            advantages = (advantages - mean) / std
         else:
             batch.returns = unnormalized_returns
         batch.act = to_torch_as(batch.act, batch.v_s)
@@ -139,7 +139,7 @@ def learn(  # type: ignore
                     - self._weight_ent * ent_loss
                 self.optim.zero_grad()
                 loss.backward()
-                if self._grad_norm is not None:  # clip large gradient
+                if self._grad_norm:  # clip large gradient
                     nn.utils.clip_grad_norm_(
                         list(self.actor.parameters()) + list(self.critic.parameters()),
                         max_norm=self._grad_norm)