diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py
index e69cd15ba..9396971df 100644
--- a/tianshou/policy/modelfree/a2c.py
+++ b/tianshou/policy/modelfree/a2c.py
@@ -14,8 +14,7 @@ class A2CPolicy(PGPolicy):
     :param torch.nn.Module actor: the actor network following the rules in
         :class:`~tianshou.policy.BasePolicy`. (s -> logits)
     :param torch.nn.Module critic: the critic network. (s -> V(s))
-    :param torch.optim.Optimizer optim: the optimizer for actor and critic
-        network.
+    :param torch.optim.Optimizer optim: the optimizer for actor and critic network.
     :param dist_fn: distribution class for computing the action.
     :type dist_fn: Type[torch.distributions.Distribution]
     :param float discount_factor: in [0, 1]. Default to 0.99.
@@ -71,6 +70,13 @@ def __init__(
 
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
+    ) -> Batch:
+        batch = self._compute_returns(batch, buffer, indice)
+        batch.act = to_torch_as(batch.act, batch.v_s)
+        return batch
+
+    def _compute_returns(
+        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
         v_s, v_s_ = [], []
         with torch.no_grad():
@@ -96,7 +102,6 @@ def process_fn(
             self.ret_rms.update(unnormalized_returns)
         else:
             batch.returns = unnormalized_returns
-        batch.act = to_torch_as(batch.act, batch.v_s)
         batch.returns = to_torch_as(batch.returns, batch.v_s)
         batch.adv = to_torch_as(advantages, batch.v_s)
         return batch
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
index 0b2c76e2e..9b9c61272 100644
--- a/tianshou/policy/modelfree/ppo.py
+++ b/tianshou/policy/modelfree/ppo.py
@@ -4,7 +4,7 @@
 from typing import Any, Dict, List, Type, Optional
 
 from tianshou.policy import A2CPolicy
-from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
+from tianshou.data import Batch, ReplayBuffer, to_torch_as
 
 
 class PPOPolicy(A2CPolicy):
@@ -24,6 +24,11 @@ class PPOPolicy(A2CPolicy):
         Default to 5.0 (set None if you do not want to use it).
     :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1.
-        Default to True.
+        Default to False.
+    :param bool advantage_normalization: whether to do per mini-batch advantage
+        normalization. Default to True.
+    :param bool recompute_advantage: whether to recompute advantage every update
+        repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5.
+        Default to False.
     :param float vf_coef: weight for value loss. Default to 0.5.
     :param float ent_coef: weight for entropy loss. Default to 0.01.
     :param float max_grad_norm: clipping gradients in back propagation. Default to
@@ -59,7 +64,9 @@ def __init__(
         dist_fn: Type[torch.distributions.Distribution],
         eps_clip: float = 0.2,
         dual_clip: Optional[float] = None,
-        value_clip: bool = True,
+        value_clip: bool = False,
+        advantage_normalization: bool = True,
+        recompute_advantage: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(actor, critic, optim, dist_fn, **kwargs)
@@ -68,51 +75,41 @@ def __init__(
             "Dual-clip PPO parameter should greater than 1.0."
         self._dual_clip = dual_clip
         self._value_clip = value_clip
+        if not self._rew_norm:
+            assert not self._value_clip, \
+                "value clip is available only when `reward_normalization` is True"
+        self._norm_adv = advantage_normalization
+        self._recompute_adv = recompute_advantage
 
     def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
-        v_s, v_s_, old_log_prob = [], [], []
+        if self._recompute_adv:
+            # cache the input `buffer` and `indice` to be reused in `learn()`
+            self._buffer = buffer
+            self._indice = indice
+        batch = self._compute_returns(batch, buffer, indice)
+        batch.act = to_torch_as(batch.act, batch.v_s)
+        old_log_prob = []
         with torch.no_grad():
             for b in batch.split(self._batch, shuffle=False, merge_last=True):
-                v_s.append(self.critic(b.obs))
-                v_s_.append(self.critic(b.obs_next))
-                old_log_prob.append(self(b).dist.log_prob(to_torch_as(b.act, v_s[0])))
-        batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
-        v_s = to_numpy(batch.v_s)
-        v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
-        # when normalizing values, we do not minus self.ret_rms.mean to be numerically
-        # consistent with OPENAI baselines' value normalization pipeline. Emperical
-        # study also shows that "minus mean" will harm performances a tiny little bit
-        # due to unknown reasons (on Mujoco envs, not confident, though).
-        if self._rew_norm:  # unnormalize v_s & v_s_
-            v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
-            v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
-        unnormalized_returns, advantages = self.compute_episodic_return(
-            batch, buffer, indice, v_s_, v_s,
-            gamma=self._gamma, gae_lambda=self._lambda)
-        if self._rew_norm:
-            batch.returns = unnormalized_returns / \
-                np.sqrt(self.ret_rms.var + self._eps)
-            self.ret_rms.update(unnormalized_returns)
-            mean, std = np.mean(advantages), np.std(advantages)
-            advantages = (advantages - mean) / std
-        else:
-            batch.returns = unnormalized_returns
-        batch.act = to_torch_as(batch.act, batch.v_s)
+                old_log_prob.append(self(b).dist.log_prob(b.act))
         batch.logp_old = torch.cat(old_log_prob, dim=0)
-        batch.returns = to_torch_as(batch.returns, batch.v_s)
-        batch.adv = to_torch_as(advantages, batch.v_s)
         return batch
 
     def learn(  # type: ignore
         self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
     ) -> Dict[str, List[float]]:
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
-        for _ in range(repeat):
+        for step in range(repeat):
+            if self._recompute_adv and step > 0:
+                batch = self._compute_returns(batch, self._buffer, self._indice)
             for b in batch.split(batch_size, merge_last=True):
                 # calculate loss for actor
                 dist = self(b).dist
+                if self._norm_adv:
+                    mean, std = b.adv.mean(), b.adv.std()
+                    b.adv = (b.adv - mean) / std  # per-batch norm
                 ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
                 ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                 surr1 = ratio * b.adv
@@ -130,9 +127,9 @@ def learn(  # type: ignore
                        -self._eps_clip, self._eps_clip)
                     vf1 = (b.returns - value).pow(2)
                     vf2 = (b.returns - v_clip).pow(2)
-                    vf_loss = 0.5 * torch.max(vf1, vf2).mean()
+                    vf_loss = torch.max(vf1, vf2).mean()
                 else:
-                    vf_loss = 0.5 * (b.returns - value).pow(2).mean()
+                    vf_loss = (b.returns - value).pow(2).mean()
                 # calculate regularization and overall loss
                 ent_loss = dist.entropy().mean()
                 loss = clip_loss + self._weight_vf * vf_loss \
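
For reviewers, a minimal standalone sketch (plain PyTorch with synthetic tensors, no Tianshou `Batch`/`ReplayBuffer`; the names `adv`, `logp_old`, and `eps_clip` simply mirror the fields used in `learn()` above) of what the new `advantage_normalization` flag changes: advantages are re-standardized per mini-batch right before the clipped surrogate is formed.

```python
# Illustrative sketch only -- synthetic data, not Tianshou's Batch/ReplayBuffer.
import torch

torch.manual_seed(0)
eps_clip = 0.2
adv = torch.randn(64)                          # GAE advantages for one mini-batch
logp_old = torch.randn(64)                     # log-probs under the old policy
logp_new = logp_old + 0.05 * torch.randn(64)   # log-probs under the current policy

# per mini-batch advantage normalization (the `if self._norm_adv:` branch above)
adv = (adv - adv.mean()) / adv.std()

# standard clipped surrogate, matching the ratio/surr1/surr2 lines in `learn()`
ratio = (logp_new - logp_old).exp()
surr1 = ratio * adv
surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * adv
clip_loss = -torch.min(surr1, surr2).mean()
print(clip_loss.item())
```

The `recompute_advantage` path is not shown here; in the diff it simply re-runs `_compute_returns` with the cached buffer and indices at the start of every repeat after the first.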