refactor ppo #329

Merged · 7 commits · Mar 28, 2021

11 changes: 8 additions & 3 deletions tianshou/policy/modelfree/a2c.py
@@ -14,8 +14,7 @@ class A2CPolicy(PGPolicy):
     :param torch.nn.Module actor: the actor network following the rules in
         :class:`~tianshou.policy.BasePolicy`. (s -> logits)
     :param torch.nn.Module critic: the critic network. (s -> V(s))
-    :param torch.optim.Optimizer optim: the optimizer for actor and critic
-        network.
+    :param torch.optim.Optimizer optim: the optimizer for actor and critic network.
     :param dist_fn: distribution class for computing the action.
     :type dist_fn: Type[torch.distributions.Distribution]
     :param float discount_factor: in [0, 1]. Default to 0.99.
@@ -71,6 +70,13 @@ def __init__(
 
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
+        batch = self._compute_returns(batch, buffer, indice)
+        batch.act = to_torch_as(batch.act, batch.v_s)
+        return batch
+
+    def _compute_returns(
+        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
+    ) -> Batch:
         v_s, v_s_ = [], []
         with torch.no_grad():
@@ -96,7 +102,6 @@ def process_fn
                 self.ret_rms.update(unnormalized_returns)
             else:
                 batch.returns = unnormalized_returns
-        batch.act = to_torch_as(batch.act, batch.v_s)
         batch.returns = to_torch_as(batch.returns, batch.v_s)
         batch.adv = to_torch_as(advantages, batch.v_s)
         return batch
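Note: the middle of the new `_compute_returns` helper is collapsed between the two hunks above. As a reading aid, here is a condensed paraphrase of what the shared helper computes, pieced together from the visible a2c.py hunks and the PPO code this PR deletes below; it is a sketch, not a verbatim quote of the file.

```python
# Condensed paraphrase of A2CPolicy._compute_returns (the collapsed middle of
# the method), reconstructed from the hunks shown in this diff.
import numpy as np
import torch
from tianshou.data import to_numpy, to_torch_as


def _compute_returns(self, batch, buffer, indice):
    # evaluate V(s) and V(s') in minibatches without tracking gradients
    v_s, v_s_ = [], []
    with torch.no_grad():
        for b in batch.split(self._batch, shuffle=False, merge_last=True):
            v_s.append(self.critic(b.obs))
            v_s_.append(self.critic(b.obs_next))
    batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
    v_s = to_numpy(batch.v_s)
    v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
    if self._rew_norm:  # unnormalize v_s & v_s_ to match the raw rewards
        v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
        v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
    # GAE(lambda) returns and advantages from the value estimates
    unnormalized_returns, advantages = self.compute_episodic_return(
        batch, buffer, indice, v_s_, v_s,
        gamma=self._gamma, gae_lambda=self._lambda)
    if self._rew_norm:  # keep a running std of returns and rescale them
        batch.returns = unnormalized_returns / np.sqrt(self.ret_rms.var + self._eps)
        self.ret_rms.update(unnormalized_returns)
    else:
        batch.returns = unnormalized_returns
    batch.returns = to_torch_as(batch.returns, batch.v_s)
    batch.adv = to_torch_as(advantages, batch.v_s)
    return batch
```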
63 changes: 30 additions & 33 deletions tianshou/policy/modelfree/ppo.py
@@ -4,7 +4,7 @@
 from typing import Any, Dict, List, Type, Optional
 
 from tianshou.policy import A2CPolicy
-from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
+from tianshou.data import Batch, ReplayBuffer, to_torch_as
 
 
 class PPOPolicy(A2CPolicy):
@@ -24,6 +24,11 @@ class PPOPolicy(A2CPolicy):
         Default to 5.0 (set None if you do not want to use it).
     :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1.
         Default to True.
+    :param bool advantage_normalization: whether to do per mini-batch advantage
+        normalization. Default to True.
+    :param bool recompute_advantage: whether to recompute advantage every update
+        repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5.
+        Default to False.
     :param float vf_coef: weight for value loss. Default to 0.5.
     :param float ent_coef: weight for entropy loss. Default to 0.01.
     :param float max_grad_norm: clipping gradients in back propagation. Default to
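Note: as a usage sketch for the two new switches documented above, assuming a small discrete-action task and the `Net`/`Actor`/`Critic` helpers as they look in Tianshou 0.4.x. The shapes and helper usage here are illustrative; only the `PPOPolicy` keyword arguments come from this PR.

```python
import torch
from tianshou.policy import PPOPolicy
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Actor, Critic

# Hypothetical 4-dim observation / 2-action task, only to make the call concrete.
actor = Actor(Net(state_shape=4, hidden_sizes=[64, 64]), action_shape=2)
critic = Critic(Net(state_shape=4, hidden_sizes=[64, 64]))
optim = torch.optim.Adam(
    list(actor.parameters()) + list(critic.parameters()), lr=3e-4)

policy = PPOPolicy(
    actor, critic, optim,
    dist_fn=torch.distributions.Categorical,
    eps_clip=0.2,
    value_clip=False,              # new default in this PR
    advantage_normalization=True,  # per mini-batch advantage normalization in learn()
    recompute_advantage=False,     # True: recompute GAE every update repeat
)
```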
@@ -59,7 +64,9 @@ def __init__(
         dist_fn: Type[torch.distributions.Distribution],
         eps_clip: float = 0.2,
         dual_clip: Optional[float] = None,
-        value_clip: bool = True,
+        value_clip: bool = False,
+        advantage_normalization: bool = True,
+        recompute_advantage: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(actor, critic, optim, dist_fn, **kwargs)
@@ -68,51 +75,41 @@ def __init__(
             "Dual-clip PPO parameter should greater than 1.0."
         self._dual_clip = dual_clip
         self._value_clip = value_clip
         if not self._rew_norm:
             assert not self._value_clip, \
                 "value clip is available only when `reward_normalization` is True"
+        self._norm_adv = advantage_normalization
+        self._recompute_adv = recompute_advantage
 
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
-        v_s, v_s_, old_log_prob = [], [], []
+        if self._recompute_adv:
+            # buffer input `buffer` and `indice` to be used in `learn()`.
+            self._buffer = buffer
+            self._indice = indice
+        batch = self._compute_returns(batch, buffer, indice)
+        batch.act = to_torch_as(batch.act, batch.v_s)
+        old_log_prob = []
         with torch.no_grad():
             for b in batch.split(self._batch, shuffle=False, merge_last=True):
-                v_s.append(self.critic(b.obs))
-                v_s_.append(self.critic(b.obs_next))
-                old_log_prob.append(self(b).dist.log_prob(to_torch_as(b.act, v_s[0])))
-        batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
-        v_s = to_numpy(batch.v_s)
-        v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
-        # when normalizing values, we do not minus self.ret_rms.mean to be numerically
-        # consistent with OPENAI baselines' value normalization pipeline. Emperical
-        # study also shows that "minus mean" will harm performances a tiny little bit
-        # due to unknown reasons (on Mujoco envs, not confident, though).
-        if self._rew_norm:  # unnormalize v_s & v_s_
-            v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
-            v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
-        unnormalized_returns, advantages = self.compute_episodic_return(
-            batch, buffer, indice, v_s_, v_s,
-            gamma=self._gamma, gae_lambda=self._lambda)
-        if self._rew_norm:
-            batch.returns = unnormalized_returns / \
-                np.sqrt(self.ret_rms.var + self._eps)
-            self.ret_rms.update(unnormalized_returns)
-            mean, std = np.mean(advantages), np.std(advantages)
-            advantages = (advantages - mean) / std
-        else:
-            batch.returns = unnormalized_returns
-        batch.act = to_torch_as(batch.act, batch.v_s)
+                old_log_prob.append(self(b).dist.log_prob(b.act))
         batch.logp_old = torch.cat(old_log_prob, dim=0)
-        batch.returns = to_torch_as(batch.returns, batch.v_s)
-        batch.adv = to_torch_as(advantages, batch.v_s)
         return batch
 
     def learn(  # type: ignore
         self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
     ) -> Dict[str, List[float]]:
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
-        for _ in range(repeat):
+        for step in range(repeat):
+            if self._recompute_adv and step > 0:
+                batch = self._compute_returns(batch, self._buffer, self._indice)
             for b in batch.split(batch_size, merge_last=True):
                 # calculate loss for actor
                 dist = self(b).dist
+                if self._norm_adv:
+                    mean, std = b.adv.mean(), b.adv.std()
+                    b.adv = (b.adv - mean) / std  # per-batch norm
                 ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
                 ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                 surr1 = ratio * b.adv
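Note: the unchanged actor-loss lines between this hunk and the next are collapsed in the diff. For context, a self-contained restatement of the clipped (and optionally dual-clipped) surrogate those lines compute, assuming they match the surrounding code; the function name and local variables are illustrative.

```python
from typing import Optional

import torch


def clip_loss_sketch(ratio: torch.Tensor, adv: torch.Tensor,
                     eps_clip: float = 0.2,
                     dual_clip: Optional[float] = None) -> torch.Tensor:
    """Clipped PPO actor loss, with the optional dual-clip lower bound."""
    surr1 = ratio * adv
    surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * adv
    if dual_clip is not None:
        # dual clip (arXiv:1912.09729): additionally bound the objective from
        # below by dual_clip * adv, which only matters when adv < 0
        return -torch.max(torch.min(surr1, surr2), dual_clip * adv).mean()
    return -torch.min(surr1, surr2).mean()
```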
@@ -130,9 +127,9 @@ def learn(  # type: ignore
                         -self._eps_clip, self._eps_clip)
                     vf1 = (b.returns - value).pow(2)
                     vf2 = (b.returns - v_clip).pow(2)
-                    vf_loss = 0.5 * torch.max(vf1, vf2).mean()
+                    vf_loss = torch.max(vf1, vf2).mean()
                 else:
-                    vf_loss = 0.5 * (b.returns - value).pow(2).mean()
+                    vf_loss = (b.returns - value).pow(2).mean()
                 # calculate regularization and overall loss
                 ent_loss = dist.entropy().mean()
                 loss = clip_loss + self._weight_vf * vf_loss \
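Note: the `0.5` factor dropped in this hunk can be recovered by halving `vf_coef`, since `vf_loss` is multiplied by `self._weight_vf` in the total loss. A self-contained restatement of the critic loss shown here, with `old_value` standing in for `b.v_s`; the function name is illustrative.

```python
import torch


def value_loss_sketch(value: torch.Tensor, old_value: torch.Tensor,
                      returns: torch.Tensor, eps_clip: float = 0.2,
                      value_clip: bool = False) -> torch.Tensor:
    """Critic loss of this hunk; value clipping follows arXiv:1811.02553 Sec. 4.1."""
    if value_clip:
        # keep the new value prediction within eps_clip of the old one
        v_clip = old_value + (value - old_value).clamp(-eps_clip, eps_clip)
        vf1 = (returns - value).pow(2)
        vf2 = (returns - v_clip).pow(2)
        return torch.max(vf1, vf2).mean()
    return (returns - value).pow(2).mean()
```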