diff --git a/examples/box2d/README.md b/examples/box2d/README.md
index 0935534b4..53300dd1c 100644
--- a/examples/box2d/README.md
+++ b/examples/box2d/README.md
@@ -1,7 +1,6 @@
 # Bipedal-Hardcore-SAC
 
-- Our default choice: remove the done flag penalty, will soon converge to \~250 reward within 100 epochs (10M env steps, 3~4 hours, see the image below)
+- Our default choice: remove the done flag penalty; training then converges to \~270 reward within 100 epochs (10M env steps, 3-4 hours, see the image below)
 - If the done penalty is not removed, it converges much slower than before, about 200 epochs (20M env steps) to reach the same performance (\~200 reward)
 
-![](results/sac/BipedalHardcore.png)
\ No newline at end of file
+![](results/sac/BipedalHardcore.png)
diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py
index 4e123719b..e6e7d73e7 100644
--- a/examples/box2d/bipedal_hardcore_sac.py
+++ b/examples/box2d/bipedal_hardcore_sac.py
@@ -24,6 +24,8 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--alpha', type=float, default=0.1)
+    parser.add_argument('--auto_alpha', type=int, default=1)
+    parser.add_argument('--alpha_lr', type=float, default=3e-4)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=10000)
     parser.add_argument('--collect-per-step', type=int, default=10)
@@ -46,7 +48,7 @@ def get_args():
 class EnvWrapper(object):
     """Env wrapper for reward scale, action repeat and action noise"""
 
-    def __init__(self, task, action_repeat=3, reward_scale=5, act_noise=0.3):
+    def __init__(self, task, action_repeat=3, reward_scale=5, act_noise=0.0):
         self._env = gym.make(task)
         self.action_repeat = action_repeat
         self.reward_scale = reward_scale
@@ -109,6 +111,12 @@ def IsStop(reward):
     critic2 = Critic(net_c2, args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
 
+    if args.auto_alpha:
+        target_entropy = -np.prod(env.action_space.shape)
+        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
+        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
+        args.alpha = (target_entropy, log_alpha, alpha_optim)
+
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
         args.tau, args.gamma, args.alpha,
diff --git a/examples/box2d/results/sac/BipedalHardcore.png b/examples/box2d/results/sac/BipedalHardcore.png
index 0b4196955..86367e591 100644
Binary files a/examples/box2d/results/sac/BipedalHardcore.png and b/examples/box2d/results/sac/BipedalHardcore.png differ
diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py
index 5019f5b62..bee5af4e7 100644
--- a/test/continuous/test_ppo.py
+++ b/test/continuous/test_ppo.py
@@ -5,11 +5,11 @@
 import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
+from torch.distributions import Independent, Normal
 
 from tianshou.policy import PPOPolicy
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
-from tianshou.policy.dist import DiagGaussian
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.utils.net.continuous import ActorProb, Critic
@@ -84,7 +84,11 @@ def test_ppo(args=get_args()):
             torch.nn.init.zeros_(m.bias)
     optim = torch.optim.Adam(list(
         actor.parameters()) + list(critic.parameters()), lr=args.lr)
-    dist = DiagGaussian
+
+    # replace DiagGaussian with Independent(Normal), which is equivalent
+    # pass *logits to be consistent with policy.forward
+    def dist(*logits):
+        return Independent(Normal(*logits), 1)
     policy = PPOPolicy(
         actor, critic, optim, dist, args.gamma,
         max_grad_norm=args.max_grad_norm,
diff --git a/tianshou/policy/dist.py b/tianshou/policy/dist.py
deleted file mode 100644
index f1792e4eb..000000000
--- a/tianshou/policy/dist.py
+++ /dev/null
@@ -1,11 +0,0 @@
-import torch
-
-
-class DiagGaussian(torch.distributions.Normal):
-    """Diagonal Gaussian distribution."""
-
-    def log_prob(self, actions):
-        return super().log_prob(actions).sum(-1, keepdim=True)
-
-    def entropy(self):
-        return super().entropy().sum(-1)
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index 920bfb167..47823c676 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -2,9 +2,9 @@
 import numpy as np
 from copy import deepcopy
 from typing import Dict, Tuple, Union, Optional
+from torch.distributions import Normal, Independent
 
 from tianshou.policy import DDPGPolicy
-from tianshou.policy.dist import DiagGaussian
 from tianshou.data import Batch, to_torch_as, ReplayBuffer
 from tianshou.exploration import BaseNoise
 
@@ -47,23 +47,26 @@ class SACPolicy(DDPGPolicy):
         explanation.
     """
 
-    def __init__(self,
-                 actor: torch.nn.Module,
-                 actor_optim: torch.optim.Optimizer,
-                 critic1: torch.nn.Module,
-                 critic1_optim: torch.optim.Optimizer,
-                 critic2: torch.nn.Module,
-                 critic2_optim: torch.optim.Optimizer,
-                 tau: float = 0.005,
-                 gamma: float = 0.99,
-                 alpha: Tuple[float, torch.Tensor, torch.optim.Optimizer]
-                 or float = 0.2,
-                 action_range: Optional[Tuple[float, float]] = None,
-                 reward_normalization: bool = False,
-                 ignore_done: bool = False,
-                 estimation_step: int = 1,
-                 exploration_noise: Optional[BaseNoise] = None,
-                 **kwargs) -> None:
+    def __init__(
+        self,
+        actor: torch.nn.Module,
+        actor_optim: torch.optim.Optimizer,
+        critic1: torch.nn.Module,
+        critic1_optim: torch.optim.Optimizer,
+        critic2: torch.nn.Module,
+        critic2_optim: torch.optim.Optimizer,
+        tau: float = 0.005,
+        gamma: float = 0.99,
+        alpha: Union[
+            float, Tuple[float, torch.Tensor, torch.optim.Optimizer]
+        ] = 0.2,
+        action_range: Optional[Tuple[float, float]] = None,
+        reward_normalization: bool = False,
+        ignore_done: bool = False,
+        estimation_step: int = 1,
+        exploration_noise: Optional[BaseNoise] = None,
+        **kwargs
+    ) -> None:
         super().__init__(None, None, None, None, tau, gamma,
                          exploration_noise, action_range, reward_normalization,
                          ignore_done, estimation_step, **kwargs)
@@ -75,14 +78,12 @@ def __init__(self,
         self.critic2_old.eval()
         self.critic2_optim = critic2_optim
 
-        self._automatic_alpha_tuning = not isinstance(alpha, float)
-        if self._automatic_alpha_tuning:
-            self._target_entropy = alpha[0]
-            assert(alpha[1].shape == torch.Size([1])
-                   and alpha[1].requires_grad)
-            self._log_alpha = alpha[1]
-            self._alpha_optim = alpha[2]
-            self._alpha = self._log_alpha.exp()
+        self._is_auto_alpha = False
+        if isinstance(alpha, tuple):
+            self._is_auto_alpha = True
+            self._target_entropy, self._log_alpha, self._alpha_optim = alpha
+            assert alpha[1].shape == torch.Size([1]) and alpha[1].requires_grad
+            self._alpha = self._log_alpha.detach().exp()
         else:
             self._alpha = alpha
 
@@ -111,12 +112,13 @@ def forward(self, batch: Batch,
         obs = getattr(batch, input)
         logits, h = self.actor(obs, state=state, info=batch.info)
         assert isinstance(logits, tuple)
-        dist = DiagGaussian(*logits)
+        dist = Independent(Normal(*logits), 1)
         x = dist.rsample()
        y = torch.tanh(x)
         act = y * self._action_scale + self._action_bias
         y = self._action_scale * (1 - y.pow(2)) + self.__eps
-        log_prob = dist.log_prob(x) - torch.log(y).sum(-1, keepdim=True)
+        log_prob = dist.log_prob(x).unsqueeze(-1)
+        log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
         if self._noise is not None and self.training and explorating:
             act += to_torch_as(self._noise(act.shape), act)
         act = act.clamp(self._range[0], self._range[1])
@@ -167,13 +169,13 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
         actor_loss.backward()
         self.actor_optim.step()
 
-        if self._automatic_alpha_tuning:
-            log_prob = (obs_result.log_prob + self._target_entropy).detach()
+        if self._is_auto_alpha:
+            log_prob = obs_result.log_prob.detach() + self._target_entropy
             alpha_loss = -(self._log_alpha * log_prob).mean()
             self._alpha_optim.zero_grad()
             alpha_loss.backward()
             self._alpha_optim.step()
-            self._alpha = self._log_alpha.exp()
+            self._alpha = self._log_alpha.detach().exp()
 
         self.sync_weight()
 
@@ -182,6 +184,7 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
             'loss/critic1': critic1_loss.item(),
             'loss/critic2': critic2_loss.item(),
         }
-        if self._automatic_alpha_tuning:
+        if self._is_auto_alpha:
             result['loss/alpha'] = alpha_loss.item()
+            result['v/alpha'] = self._alpha.item()
         return result
diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py
index c04a4b556..be52c5400 100644
--- a/tianshou/trainer/offpolicy.py
+++ b/tianshou/trainer/offpolicy.py
@@ -77,13 +77,13 @@ def offpolicy_trainer(
     start_time = time.time()
     test_in_train = test_in_train and train_collector.policy == policy
     for epoch in range(1, 1 + max_epoch):
+        # train
+        policy.train()
+        if train_fn:
+            train_fn(epoch)
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
-                # collect
-                if train_fn:
-                    train_fn(epoch)
-                policy.eval()
                 result = train_collector.collect(n_step=collect_per_step)
                 data = {}
                 if test_in_train and stop_fn and stop_fn(result['rew']):
@@ -100,10 +100,9 @@ def offpolicy_trainer(
                             start_time, train_collector, test_collector,
                             test_result['rew'])
                 else:
+                    policy.train()
                     if train_fn:
                         train_fn(epoch)
-                    # train
-                    policy.train()
                 for i in range(update_per_step * min(
                         result['n/st'] // collect_per_step, t.total - t.n)):
                     global_step += collect_per_step
diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py
index 6af43173c..db13d06e3 100644
--- a/tianshou/trainer/onpolicy.py
+++ b/tianshou/trainer/onpolicy.py
@@ -77,13 +77,13 @@ def onpolicy_trainer(
     start_time = time.time()
     test_in_train = test_in_train and train_collector.policy == policy
     for epoch in range(1, 1 + max_epoch):
+        # train
+        policy.train()
+        if train_fn:
+            train_fn(epoch)
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
-                # collect
-                if train_fn:
-                    train_fn(epoch)
-                policy.eval()
                 result = train_collector.collect(n_episode=collect_per_step)
                 data = {}
                 if test_in_train and stop_fn and stop_fn(result['rew']):
@@ -100,10 +100,9 @@ def onpolicy_trainer(
                             start_time, train_collector, test_collector,
                             test_result['rew'])
                 else:
+                    policy.train()
                     if train_fn:
                         train_fn(epoch)
-                    # train
-                    policy.train()
                 losses = policy.update(
                     0, train_collector.buffer, batch_size, repeat_per_collect)
                 train_collector.reset_buffer()
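
For reviewers who want to verify the "equivalent" claim in the new test_ppo.py comment, here is a small standalone check (not part of the patch; the tensor shapes are arbitrary) showing that Independent(Normal(...), 1) reproduces the removed DiagGaussian behavior, up to the keepdim=True that SACPolicy.forward now restores with .unsqueeze(-1):

import torch
from torch.distributions import Independent, Normal

loc, scale = torch.zeros(4, 3), torch.ones(4, 3)   # batch of 4, action dim 3
act = torch.randn(4, 3)

normal = Normal(loc, scale)
diag = Independent(normal, 1)   # treat the last dim as a single event

# DiagGaussian.log_prob(x) == Normal.log_prob(x).sum(-1, keepdim=True);
# Independent sums over the event dim without keepdim, hence the extra
# unsqueeze(-1) in the updated SACPolicy.forward.
assert torch.allclose(diag.log_prob(act), normal.log_prob(act).sum(-1))
assert torch.allclose(diag.entropy(), normal.entropy().sum(-1))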
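
The new --auto_alpha path can also be exercised outside the BipedalWalkerHardcore example. The sketch below is not part of the patch; the 4-dimensional action_shape stand-in and the 3e-4 learning rate (matching the new --alpha_lr default) are assumptions, and the actor/critic construction is omitted. It shows the tuple that switches SACPolicy into automatic entropy tuning:

import numpy as np
import torch

action_shape = (4,)     # hypothetical stand-in for env.action_space.shape
device = 'cpu'
alpha_lr = 3e-4         # same value as the new --alpha_lr default

# SAC heuristic: target entropy = -dim(action space).
target_entropy = -np.prod(action_shape)
# Optimize alpha in log space so it stays positive.
log_alpha = torch.zeros(1, requires_grad=True, device=device)
alpha_optim = torch.optim.Adam([log_alpha], lr=alpha_lr)

# Passing this tuple instead of a float enables automatic alpha tuning:
alpha = (target_entropy, log_alpha, alpha_optim)
# policy = SACPolicy(actor, actor_optim, critic1, critic1_optim,
#                    critic2, critic2_optim, alpha=alpha, ...)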