diff --git a/examples/mujoco/README.md b/examples/mujoco/README.md
index 7e719e248..9b32d9063 100644
--- a/examples/mujoco/README.md
+++ b/examples/mujoco/README.md
@@ -16,6 +16,7 @@ Supported algorithms are listed below:
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec)
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec)
- [REINFORCE algorithm](https://papers.nips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e27b5a26f330de446fe15388bf81c3777f024fb9)
+- [Advantage Actor-Critic (A2C)](https://arxiv.org/pdf/1602.01783.pdf), commit id (TODO)
## Offpolicy algorithms
@@ -149,6 +150,45 @@ By comparison to both classic literature and open source implementations (e.g.,
5. We didn't tune `step-per-collect` option and `training-num` option. Default values are finetuned with PPO algorithm so we assume they are also good for REINFORCE. You can play with them if you want, but remember that `buffer-size` should always be larger than `step-per-collect`, and if `step-per-collect` is too small and `training-num` too large, episodes will be truncated and bootstrapped very often, which will harm performances. If `training-num` is too small (e.g., less than 8), speed will go down.
6. Sigma of action is not fixed (normally seen in other implementation) or conditioned on observation, but is an independent parameter which can be updated by gradient descent. We choose this setting because it works well in PPO, and is recommended by [Andrychowicz, Marcin, et al](https://arxiv.org/abs/2006.05990). See Fig. 23.
+### A2C
+
+| Environment | Tianshou (3M steps) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench_vpg.html) |
+| :--------------------: | :----------------: | :--------------------: |
+| Ant | **5236.8+-236.7** | ~5 |
+| HalfCheetah | **2377.3+-1363.7** | ~600 |
+| Hopper | **1608.6+-529.5** | ~800 |
+| Walker2d | **1805.4+-1055.9** | ~460 |
+| Swimmer | 40.2+-1.8 | **~51** |
+| Humanoid | **5316.6+-554.8** | N |
+| Reacher | **-5.2+-0.5** | N |
+| InvertedPendulum | **1000.0+-0.0** | N |
+| InvertedDoublePendulum | **9351.3+-12.8** | N |
+
+| Environment | Tianshou | [PPO paper](https://arxiv.org/abs/1707.06347) A2C | [PPO paper](https://arxiv.org/abs/1707.06347) A2C + Trust Region |
+| :--------------------: | :----------------: | :-------------: | :-------------: |
+| Ant | **3485.4+-433.1** | N | N |
+| HalfCheetah | **1829.9+-1068.3** | ~1000 | ~930 |
+| Hopper | **1253.2+-458.0** | ~900 | ~1220 |
+| Walker2d | **1091.6+-709.2** | ~850 | ~700 |
+| Swimmer | **36.6+-2.1** | ~31 | **~36** |
+| Humanoid | **1726.0+-1070.1** | N | N |
+| Reacher | **-6.7+-2.3** | ~-24 | ~-27 |
+| InvertedPendulum | **1000.0+-0.0** | **~1000** | **~1000** |
+| InvertedDoublePendulum | **9257.7+-277.4** | ~7100 | ~8100 |
+
+\* details[[5]](#footnote5)[[6]](#footnote6)
+
+#### Hints for A2C
+
+0. We choose the `clip` action bounding method for A2C instead of the `tanh` option used in REINFORCE, simply to be consistent with the original implementation. `tanh` may work equally well or better, but we didn't try it.
+1. The (initial) learning rate, lr decay, `step-per-collect` and `training-num` affect the performance of A2C to a great extent. These 4 hyperparameters also interact with each other and should be tuned together. We have done full-scale ablation studies on them (more than 800 agents trained); our findings are listed below.
+2. `step-per-collect` / `training-num` = `bootstrap-length`, the maximum length of an "episode" used by the GAE estimator; with the default settings it is 80/16=5 (see the sketch after this list). When `bootstrap-length` is small, the critic is less well-trained, (maybe) because GAE can look at most 5 steps ahead and has to bootstrap very often, so the actor cannot converge to very high scores. However, if we increase `step-per-collect` to increase `bootstrap-length` (e.g. 256/16=16), the actor and critic are updated less often, so sample efficiency drops and training becomes slow. To conclude: if you don't restrict env timesteps, you can try a larger `bootstrap-length` and train for more steps, which will perhaps give you better converged scores. Train slower, achieve higher.
+3. A 7e-4 learning rate with the decay strategy is proper for `step-per-collect=80` and `training-num=16`, but if you use a larger `step-per-collect` (e.g. 256 - 2048), 7e-4 is a little bit small, because each update now uses more data with less noise and can be more confident in taking larger steps; a higher learning rate (e.g. 1e-3) is more appropriate and usually boosts performance in this setting. If the training curve rises fast in the early stages and becomes unstable later, consider adding lr decay before decreasing the lr itself.
+4. `max-grad-norm` doesn't really help in our experiments; we simply keep it for consistency with other open-source implementations (e.g. SB3).
+5. The original A3C paper uses the RMSprop optimizer, and we find that Adam with the same learning rate works equally well. We use RMSprop anyway, again for consistency.
+6. We notice that SB3's A2C implementation sets `gae-lambda` to 1 by default. We don't know why; after doing some experiments, our results show that 0.95 is better overall.
+7. We find that `step-per-collect=256` and `training-num=8` are also good hyperparameters. You can give them a try.
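+
+Below is a minimal sketch of the two relations discussed in hints 2 and 3, using the default hyperparameters of `mujoco_a2c.py`; the helper names are only for illustration:
+
+```python
+import numpy as np
+
+# hint 2: bootstrap-length is the longest horizon the GAE estimator can look
+# ahead before it has to bootstrap from the critic
+step_per_collect, training_num = 80, 16
+bootstrap_length = step_per_collect // training_num  # 80 / 16 = 5
+
+# hint 3: the decay strategy anneals the learning rate linearly to 0 over all
+# updates (this is what the LambdaLR scheduler in mujoco_a2c.py implements)
+lr, step_per_epoch, epoch = 7e-4, 30000, 100
+max_update_num = np.ceil(step_per_epoch / step_per_collect) * epoch  # 375 * 100
+
+def lr_after_n_updates(n):
+    return lr * (1 - n / max_update_num)
+
+print(bootstrap_length, lr_after_n_updates(0), lr_after_n_updates(max_update_num / 2))
+```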
+
## Note
[1] Supported environments include HalfCheetah-v3, Hopper-v3, Swimmer-v3, Walker2d-v3, Ant-v3, Humanoid-v3, Reacher-v2, InvertedPendulum-v2 and InvertedDoublePendulum-v2. Pusher, Thrower, Striker and HumanoidStandup are not supported because they are not commonly seen in literatures.
diff --git a/examples/mujoco/benchmark/Ant-v3/a2c/figure.png b/examples/mujoco/benchmark/Ant-v3/a2c/figure.png
new file mode 100644
index 000000000..8fa837cf8
Binary files /dev/null and b/examples/mujoco/benchmark/Ant-v3/a2c/figure.png differ
diff --git a/examples/mujoco/benchmark/HalfCheetah-v3/a2c/figure.png b/examples/mujoco/benchmark/HalfCheetah-v3/a2c/figure.png
new file mode 100644
index 000000000..2294caedd
Binary files /dev/null and b/examples/mujoco/benchmark/HalfCheetah-v3/a2c/figure.png differ
diff --git a/examples/mujoco/benchmark/Hopper-v3/a2c/figure.png b/examples/mujoco/benchmark/Hopper-v3/a2c/figure.png
new file mode 100644
index 000000000..7cb00a016
Binary files /dev/null and b/examples/mujoco/benchmark/Hopper-v3/a2c/figure.png differ
diff --git a/examples/mujoco/benchmark/Humanoid-v3/a2c/figure.png b/examples/mujoco/benchmark/Humanoid-v3/a2c/figure.png
new file mode 100644
index 000000000..1e404939e
Binary files /dev/null and b/examples/mujoco/benchmark/Humanoid-v3/a2c/figure.png differ
diff --git a/examples/mujoco/benchmark/InvertedDoublePendulum-v2/a2c/figure.png b/examples/mujoco/benchmark/InvertedDoublePendulum-v2/a2c/figure.png
new file mode 100644
index 000000000..fcee76bf3
Binary files /dev/null and b/examples/mujoco/benchmark/InvertedDoublePendulum-v2/a2c/figure.png differ
diff --git a/examples/mujoco/benchmark/InvertedPendulum-v2/a2c/figure.png b/examples/mujoco/benchmark/InvertedPendulum-v2/a2c/figure.png
new file mode 100644
index 000000000..ca0f56e93
Binary files /dev/null and b/examples/mujoco/benchmark/InvertedPendulum-v2/a2c/figure.png differ
diff --git a/examples/mujoco/benchmark/Reacher-v2/a2c/figure.png b/examples/mujoco/benchmark/Reacher-v2/a2c/figure.png
new file mode 100644
index 000000000..f497d13ef
Binary files /dev/null and b/examples/mujoco/benchmark/Reacher-v2/a2c/figure.png differ
diff --git a/examples/mujoco/benchmark/Swimmer-v3/a2c/figure.png b/examples/mujoco/benchmark/Swimmer-v3/a2c/figure.png
new file mode 100644
index 000000000..ddc296fff
Binary files /dev/null and b/examples/mujoco/benchmark/Swimmer-v3/a2c/figure.png differ
diff --git a/examples/mujoco/benchmark/Walker2d-v3/a2c/figure.png b/examples/mujoco/benchmark/Walker2d-v3/a2c/figure.png
new file mode 100644
index 000000000..d81f79640
Binary files /dev/null and b/examples/mujoco/benchmark/Walker2d-v3/a2c/figure.png differ
diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py
new file mode 100755
index 000000000..bdfa3cb78
--- /dev/null
+++ b/examples/mujoco/mujoco_a2c.py
@@ -0,0 +1,157 @@
+#!/usr/bin/env python3
+
+import os
+import gym
+import torch
+import datetime
+import argparse
+import numpy as np
+from torch import nn
+from torch.optim.lr_scheduler import LambdaLR
+from torch.utils.tensorboard import SummaryWriter
+from torch.distributions import Independent, Normal
+
+from tianshou.policy import A2CPolicy
+from tianshou.utils import BasicLogger
+from tianshou.env import SubprocVectorEnv
+from tianshou.utils.net.common import Net
+from tianshou.trainer import onpolicy_trainer
+from tianshou.utils.net.continuous import ActorProb, Critic
+from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--task', type=str, default='HalfCheetah-v3')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--buffer-size', type=int, default=4096)
+ parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
+ parser.add_argument('--lr', type=float, default=7e-4)
+ parser.add_argument('--gamma', type=float, default=0.99)
+ parser.add_argument('--epoch', type=int, default=100)
+ parser.add_argument('--step-per-epoch', type=int, default=30000)
+ parser.add_argument('--step-per-collect', type=int, default=80)
+ parser.add_argument('--repeat-per-collect', type=int, default=1)
+    # batch-size >> step-per-collect means calculating all data in one single forward pass.
+ parser.add_argument('--batch-size', type=int, default=99999)
+ parser.add_argument('--training-num', type=int, default=16)
+ parser.add_argument('--test-num', type=int, default=10)
+ parser.add_argument('--logdir', type=str, default='log')
+ parser.add_argument('--render', type=float, default=0.)
+ parser.add_argument(
+ '--device', type=str,
+ default='cuda' if torch.cuda.is_available() else 'cpu')
+ parser.add_argument('--resume-path', type=str, default=None)
+ # a2c special
+ parser.add_argument('--rew-norm', type=int, default=True)
+ parser.add_argument('--vf-coef', type=float, default=0.5)
+ parser.add_argument('--ent-coef', type=float, default=0.01)
+ parser.add_argument('--gae-lambda', type=float, default=0.95)
+ parser.add_argument('--bound-action-method', type=str, default="clip")
+ parser.add_argument('--lr-decay', type=int, default=True)
+ parser.add_argument('--max-grad-norm', type=float, default=0.5)
+ return parser.parse_args()
+
+
+def test_a2c(args=get_args()):
+ env = gym.make(args.task)
+ args.state_shape = env.observation_space.shape or env.observation_space.n
+ args.action_shape = env.action_space.shape or env.action_space.n
+ args.max_action = env.action_space.high[0]
+ print("Observations shape:", args.state_shape)
+ print("Actions shape:", args.action_shape)
+ print("Action range:", np.min(env.action_space.low),
+ np.max(env.action_space.high))
+ # train_envs = gym.make(args.task)
+ train_envs = SubprocVectorEnv(
+ [lambda: gym.make(args.task) for _ in range(args.training_num)],
+ norm_obs=True)
+ # test_envs = gym.make(args.task)
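+    # reuse the running obs normalization statistics of train_envs and keep them frozen during evaluation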
+ test_envs = SubprocVectorEnv(
+ [lambda: gym.make(args.task) for _ in range(args.test_num)],
+ norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False)
+
+ # seed
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ train_envs.seed(args.seed)
+ test_envs.seed(args.seed)
+ # model
+ net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+ activation=nn.Tanh, device=args.device)
+ actor = ActorProb(net_a, args.action_shape, max_action=args.max_action,
+ unbounded=True, device=args.device).to(args.device)
+ net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+ activation=nn.Tanh, device=args.device)
+ critic = Critic(net_c, device=args.device).to(args.device)
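+    # sigma_param is the log-std of the Gaussian policy; -0.5 gives an initial action std of exp(-0.5) ≈ 0.61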
+ torch.nn.init.constant_(actor.sigma_param, -0.5)
+ for m in list(actor.modules()) + list(critic.modules()):
+ if isinstance(m, torch.nn.Linear):
+ # orthogonal initialization
+ torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
+ torch.nn.init.zeros_(m.bias)
+ # do last policy layer scaling, this will make initial actions have (close to)
+ # 0 mean and std, and will help boost performances,
+ # see https://arxiv.org/abs/2006.05990, Fig.24 for details
+ for m in actor.mu.modules():
+ if isinstance(m, torch.nn.Linear):
+ torch.nn.init.zeros_(m.bias)
+ m.weight.data.copy_(0.01 * m.weight.data)
+
+ optim = torch.optim.RMSprop(set(actor.parameters()).union(critic.parameters()),
+ lr=args.lr, eps=1e-5, alpha=0.99)
+
+ lr_scheduler = None
+ if args.lr_decay:
+ # decay learning rate to 0 linearly
+ max_update_num = np.ceil(
+ args.step_per_epoch / args.step_per_collect) * args.epoch
+
+ lr_scheduler = LambdaLR(
+ optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
+
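+    # treat each action dimension as an independent Gaussian; Independent(..., 1) sums log-probs over action dims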
+ def dist(*logits):
+ return Independent(Normal(*logits), 1)
+
+ policy = A2CPolicy(actor, critic, optim, dist, discount_factor=args.gamma,
+ gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm,
+ vf_coef=args.vf_coef, ent_coef=args.ent_coef,
+ reward_normalization=args.rew_norm, action_scaling=True,
+ action_bound_method=args.bound_action_method,
+ lr_scheduler=lr_scheduler, action_space=env.action_space)
+
+ # collector
+ if args.training_num > 1:
+ buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
+ else:
+ buffer = ReplayBuffer(args.buffer_size)
+ train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
+ test_collector = Collector(policy, test_envs)
+ # log
+ t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
+ log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_a2c'
+ log_path = os.path.join(args.logdir, args.task, 'a2c', log_file)
+ writer = SummaryWriter(log_path)
+ writer.add_text("args", str(args))
+ logger = BasicLogger(writer, update_interval=100, train_interval=100)
+
+ def save_fn(policy):
+ torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
+ # trainer
+ result = onpolicy_trainer(
+ policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
+ args.repeat_per_collect, args.test_num, args.batch_size,
+ step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
+ test_in_train=False)
+
+ # Let's watch its performance!
+ policy.eval()
+ test_envs.seed(args.seed)
+ test_collector.reset()
+ result = test_collector.collect(n_episode=args.test_num, render=args.render)
+ print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
+
+
+if __name__ == '__main__':
+ test_a2c()
diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py
index fae9b00f9..d491ee711 100755
--- a/examples/mujoco/mujoco_ddpg.py
+++ b/examples/mujoco/mujoco_ddpg.py
@@ -103,9 +103,9 @@ def test_ddpg(args=get_args()):
test_collector = Collector(policy, test_envs)
train_collector.collect(n_step=args.start_timesteps, random=True)
# log
- log_path = os.path.join(args.logdir, args.task, 'ddpg', 'seed_' + str(args.seed) +
- '_' + datetime.datetime.now().strftime('%m%d_%H%M%S') +
- '-' + args.task.replace('-', '_') + '_ddpg')
+ t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
+ log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_ddpg'
+ log_path = os.path.join(args.logdir, args.task, 'ddpg', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py
index f7fc28c85..81683e632 100755
--- a/examples/mujoco/mujoco_reinforce.py
+++ b/examples/mujoco/mujoco_reinforce.py
@@ -123,7 +123,7 @@ def dist(*logits):
log_path = os.path.join(args.logdir, args.task, 'reinforce', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
- logger = BasicLogger(writer, update_interval=10)
+ logger = BasicLogger(writer, update_interval=10, train_interval=100)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py
index 64685ad45..1800944a0 100755
--- a/examples/mujoco/mujoco_sac.py
+++ b/examples/mujoco/mujoco_sac.py
@@ -115,9 +115,9 @@ def test_sac(args=get_args()):
test_collector = Collector(policy, test_envs)
train_collector.collect(n_step=args.start_timesteps, random=True)
# log
- log_path = os.path.join(args.logdir, args.task, 'sac', 'seed_' + str(args.seed) +
- '_' + datetime.datetime.now().strftime('%m%d_%H%M%S') +
- '-' + args.task.replace('-', '_') + '_sac')
+ t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
+ log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_sac'
+ log_path = os.path.join(args.logdir, args.task, 'sac', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py
index 1b33076cb..28fc2a8f4 100755
--- a/examples/mujoco/mujoco_td3.py
+++ b/examples/mujoco/mujoco_td3.py
@@ -117,9 +117,9 @@ def test_td3(args=get_args()):
test_collector = Collector(policy, test_envs)
train_collector.collect(n_step=args.start_timesteps, random=True)
# log
- log_path = os.path.join(args.logdir, args.task, 'td3', 'seed_' + str(args.seed) +
- '_' + datetime.datetime.now().strftime('%m%d_%H%M%S') +
- '-' + args.task.replace('-', '_') + '_td3')
+ t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
+ log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_td3'
+ log_path = os.path.join(args.logdir, args.task, 'td3', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py
index bbf46221f..e69cd15ba 100644
--- a/tianshou/policy/modelfree/a2c.py
+++ b/tianshou/policy/modelfree/a2c.py
@@ -120,7 +120,7 @@ def learn( # type: ignore
- self._weight_ent * ent_loss
self.optim.zero_grad()
loss.backward()
- if self._grad_norm is not None: # clip large gradient
+ if self._grad_norm: # clip large gradient
nn.utils.clip_grad_norm_(
list(self.actor.parameters()) + list(self.critic.parameters()),
max_norm=self._grad_norm)
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
index 643c6f469..0b2c76e2e 100644
--- a/tianshou/policy/modelfree/ppo.py
+++ b/tianshou/policy/modelfree/ppo.py
@@ -96,7 +96,7 @@ def process_fn(
np.sqrt(self.ret_rms.var + self._eps)
self.ret_rms.update(unnormalized_returns)
mean, std = np.mean(advantages), np.std(advantages)
- advantages = (advantages - mean) / std # per-batch norm
+ advantages = (advantages - mean) / std
else:
batch.returns = unnormalized_returns
batch.act = to_torch_as(batch.act, batch.v_s)
@@ -139,7 +139,7 @@ def learn( # type: ignore
- self._weight_ent * ent_loss
self.optim.zero_grad()
loss.backward()
- if self._grad_norm is not None: # clip large gradient
+ if self._grad_norm: # clip large gradient
nn.utils.clip_grad_norm_(
list(self.actor.parameters()) + list(self.critic.parameters()),
max_norm=self._grad_norm)