diff --git a/test/base/test_returns.py b/test/base/test_returns.py
index e8d70de5c..fcf689036 100644
--- a/test/base/test_returns.py
+++ b/test/base/test_returns.py
@@ -30,9 +30,9 @@ def test_episodic_returns(size=2560):
     for b in batch:
         b.obs = b.act = 1
         buf.add(b)
-    batch = fn(batch, buf, buf.sample_index(0), None, gamma=.1, gae_lambda=1)
+    returns, _ = fn(batch, buf, buf.sample_index(0), gamma=.1, gae_lambda=1)
     ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
-    assert np.allclose(batch.returns, ans)
+    assert np.allclose(returns, ans)
     buf.reset()
     batch = Batch(
         done=np.array([0, 1, 0, 1, 0, 1, 0.]),
@@ -41,9 +41,9 @@ def test_episodic_returns(size=2560):
     for b in batch:
         b.obs = b.act = 1
         buf.add(b)
-    batch = fn(batch, buf, buf.sample_index(0), None, gamma=.1, gae_lambda=1)
+    returns, _ = fn(batch, buf, buf.sample_index(0), gamma=.1, gae_lambda=1)
     ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
-    assert np.allclose(batch.returns, ans)
+    assert np.allclose(returns, ans)
     buf.reset()
     batch = Batch(
         done=np.array([0, 1, 0, 1, 0, 0, 1.]),
@@ -52,9 +52,9 @@ def test_episodic_returns(size=2560):
     for b in batch:
         b.obs = b.act = 1
         buf.add(b)
-    batch = fn(batch, buf, buf.sample_index(0), None, gamma=.1, gae_lambda=1)
+    returns, _ = fn(batch, buf, buf.sample_index(0), gamma=.1, gae_lambda=1)
     ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
-    assert np.allclose(batch.returns, ans)
+    assert np.allclose(returns, ans)
     buf.reset()
     batch = Batch(
         done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
@@ -64,12 +64,12 @@ def test_episodic_returns(size=2560):
         b.obs = b.act = 1
         buf.add(b)
     v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
-    ret = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
-    returns = np.array([
+    returns, _ = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
+    ground_truth = np.array([
         454.8344, 376.1143, 291.298, 200.,
         464.5610, 383.1085, 295.387, 201.,
         474.2876, 390.1027, 299.476, 202.])
-    assert np.allclose(ret.returns, returns)
+    assert np.allclose(returns, ground_truth)
     buf.reset()
     batch = Batch(
         done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
@@ -82,12 +82,12 @@ def test_episodic_returns(size=2560):
         b.obs = b.act = 1
         buf.add(b)
     v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
-    ret = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
-    returns = np.array([
+    returns, _ = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
+    ground_truth = np.array([
         454.0109, 375.2386, 290.3669, 199.01,
         462.9138, 381.3571, 293.5248, 199.02,
         474.2876, 390.1027, 299.476, 202.])
-    assert np.allclose(ret.returns, returns)
+    assert np.allclose(returns, ground_truth)
 
     if __name__ == '__main__':
         buf = ReplayBuffer(size)
diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py
index 895d3c1f5..336e4b673 100644
--- a/test/continuous/test_ppo.py
+++ b/test/continuous/test_ppo.py
@@ -91,7 +91,8 @@ def test_ppo(args=get_args()):
     def dist(*logits):
         return Independent(Normal(*logits), 1)
     policy = PPOPolicy(
-        actor, critic, optim, dist, args.gamma,
+        actor, critic, optim, dist,
+        discount_factor=args.gamma,
         max_grad_norm=args.max_grad_norm,
         eps_clip=args.eps_clip,
         vf_coef=args.vf_coef,
diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py
index 323d14848..e9003ce8b 100644
--- a/test/discrete/test_a2c_with_il.py
+++ b/test/discrete/test_a2c_with_il.py
@@ -78,7 +78,8 @@ def test_a2c_with_il(args=get_args()):
         actor.parameters()).union(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = A2CPolicy(
-        actor, critic, optim, dist, args.gamma, gae_lambda=args.gae_lambda,
+        actor, critic, optim, dist,
+        discount_factor=args.gamma, gae_lambda=args.gae_lambda,
         vf_coef=args.vf_coef, ent_coef=args.ent_coef,
         max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm,
         action_space=env.action_space)
diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py
index 83e3c1f6b..d96609a26 100644
--- a/test/discrete/test_pg.py
+++ b/test/discrete/test_pg.py
@@ -17,7 +17,7 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='CartPole-v0')
-    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--seed', type=int, default=1)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.95)
@@ -27,7 +27,7 @@ def get_args():
     parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--hidden-sizes', type=int,
-                        nargs='*', default=[128, 128, 128, 128])
+                        nargs='*', default=[64, 64])
     parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
@@ -65,6 +65,11 @@ def test_pg(args=get_args()):
     policy = PGPolicy(net, optim, dist, args.gamma,
                       reward_normalization=args.rew_norm,
                       action_space=env.action_space)
+    for m in net.modules():
+        if isinstance(m, torch.nn.Linear):
+            # orthogonal initialization
+            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
+            torch.nn.init.zeros_(m.bias)
     # collector
     train_collector = Collector(
         policy, train_envs,
diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py
index 11428dc0d..8ba380e9e 100644
--- a/test/discrete/test_ppo.py
+++ b/test/discrete/test_ppo.py
@@ -80,7 +80,8 @@ def test_ppo(args=get_args()):
         actor.parameters()).union(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = PPOPolicy(
-        actor, critic, optim, dist, args.gamma,
+        actor, critic, optim, dist,
+        discount_factor=args.gamma,
         max_grad_norm=args.max_grad_norm,
         eps_clip=args.eps_clip,
         vf_coef=args.vf_coef,
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 1d420173b..b29706575 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -4,7 +4,7 @@
 from torch import nn
 from numba import njit
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Union, Optional, Callable
+from typing import Any, Dict, Tuple, Union, Optional, Callable
 
 from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
 
@@ -254,14 +254,14 @@ def compute_episodic_return(
         buffer: ReplayBuffer,
         indice: np.ndarray,
         v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
+        v_s: Optional[Union[np.ndarray, torch.Tensor]] = None,
         gamma: float = 0.99,
         gae_lambda: float = 0.95,
-        rew_norm: bool = False,
-    ) -> Batch:
+    ) -> Tuple[np.ndarray, np.ndarray]:
         """Compute returns over given batch.
 
         Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438)
-        to calculate q function/reward to go of given batch.
+        to calculate q/advantage value of given batch.
 
         :param Batch batch: a data batch which contains several episodes of data
             in sequential order. Mind that the end of each finished episode of batch
@@ -273,10 +273,8 @@
         :param float gamma: the discount factor, should be in [0, 1]. Default to 0.99.
         :param float gae_lambda: the parameter for Generalized Advantage Estimation,
             should be in [0, 1]. Default to 0.95.
-        :param bool rew_norm: normalize the reward to Normal(0, 1). Default to False.
 
-        :return: a Batch. The result will be stored in batch.returns as a numpy
-            array with shape (bsz, ).
+        :return: two numpy arrays (returns, advantage) with each shape (bsz, ).
         """
         rew = batch.rew
         if v_s_ is None:
@@ -284,14 +282,14 @@
             v_s_ = np.zeros_like(rew)
         else:
             v_s_ = to_numpy(v_s_.flatten()) * BasePolicy.value_mask(buffer, indice)
+        v_s = np.roll(v_s_, 1) if v_s is None else to_numpy(v_s.flatten())
         end_flag = batch.done.copy()
         end_flag[np.isin(indice, buffer.unfinished_index())] = True
-        returns = _episodic_return(v_s_, rew, end_flag, gamma, gae_lambda)
-        if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2):
-            returns = (returns - returns.mean()) / returns.std()
-        batch.returns = returns
-        return batch
+        advantage = _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda)
+        returns = advantage + v_s
+        # normalization varies from each policy, so we don't do it here
+        return returns, advantage
 
     @staticmethod
     def compute_nstep_return(
@@ -355,8 +353,6 @@ def _compile(self) -> None:
         i64 = np.array([[0, 1]], dtype=np.int64)
         _gae_return(f64, f64, f64, b, 0.1, 0.1)
         _gae_return(f32, f32, f64, b, 0.1, 0.1)
-        _episodic_return(f64, f64, b, 0.1, 0.1)
-        _episodic_return(f32, f64, b, 0.1, 0.1)
         _nstep_return(f64, b, f32.reshape(-1, 1), i64, 0.1, 1)
 
 
@@ -379,19 +375,6 @@ def _gae_return(
     return returns
 
 
-@njit
-def _episodic_return(
-    v_s_: np.ndarray,
-    rew: np.ndarray,
-    end_flag: np.ndarray,
-    gamma: float,
-    gae_lambda: float,
-) -> np.ndarray:
-    """Numba speedup: 4.1s -> 0.057s."""
-    v_s = np.roll(v_s_, 1)
-    return _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda) + v_s
-
-
 @njit
 def _nstep_return(
     rew: np.ndarray,
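For reference, the recursion behind the remaining `_gae_return` kernel, which `compute_episodic_return` now surfaces as a `(returns, advantage)` pair, is the standard GAE backward pass. The sketch below is illustrative only (plain NumPy, hypothetical function name, no numba decoration) and assumes `v_s_` has already been zeroed at terminal states, as `BasePolicy.value_mask` does:

import numpy as np

def gae_advantage(v_s, v_s_, rew, end_flag, gamma=0.99, gae_lambda=0.95):
    # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    # A_t     = delta_t + gamma * lambda * (1 - end_t) * A_{t+1}
    delta = rew + gamma * v_s_ - v_s
    discount = (1.0 - end_flag) * gamma * gae_lambda
    advantage = np.zeros_like(rew)
    gae = 0.0
    for i in range(len(rew) - 1, -1, -1):
        gae = delta[i] + discount[i] * gae
        advantage[i] = gae
    return advantage

The refactored `compute_episodic_return` pairs this advantage with `returns = advantage + v_s`, where `v_s` defaults to `np.roll(v_s_, 1)` when the caller does not supply it.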
diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py
index 0abf62cd1..3dd1e561a 100644
--- a/tianshou/policy/modelfree/a2c.py
+++ b/tianshou/policy/modelfree/a2c.py
@@ -2,7 +2,7 @@
 import numpy as np
 from torch import nn
 import torch.nn.functional as F
-from typing import Any, Dict, List, Type, Union, Optional
+from typing import Any, Dict, List, Type, Optional
 
 from tianshou.policy import PGPolicy
 from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
@@ -53,17 +53,14 @@ def __init__(
         critic: torch.nn.Module,
         optim: torch.optim.Optimizer,
         dist_fn: Type[torch.distributions.Distribution],
-        discount_factor: float = 0.99,
         vf_coef: float = 0.5,
         ent_coef: float = 0.01,
         max_grad_norm: Optional[float] = None,
         gae_lambda: float = 0.95,
-        reward_normalization: bool = False,
         max_batchsize: int = 256,
         **kwargs: Any
     ) -> None:
-        super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
-        self.actor = actor
+        super().__init__(actor, optim, dist_fn, **kwargs)
         self.critic = critic
         assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]."
         self._lambda = gae_lambda
@@ -71,51 +68,27 @@ def __init__(
         self._weight_ent = ent_coef
         self._grad_norm = max_grad_norm
         self._batch = max_batchsize
-        self._rew_norm = reward_normalization
 
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
-        if self._lambda in [0.0, 1.0]:
-            return self.compute_episodic_return(
-                batch, buffer, indice,
-                None, gamma=self._gamma, gae_lambda=self._lambda)
-        v_ = []
+        v_s_ = []
         with torch.no_grad():
             for b in batch.split(self._batch, shuffle=False, merge_last=True):
-                v_.append(to_numpy(self.critic(b.obs_next)))
-        v_ = np.concatenate(v_, axis=0)
-        return self.compute_episodic_return(
-            batch, buffer, indice, v_,
-            gamma=self._gamma, gae_lambda=self._lambda, rew_norm=self._rew_norm)
-
-    def forward(
-        self,
-        batch: Batch,
-        state: Optional[Union[dict, Batch, np.ndarray]] = None,
-        **kwargs: Any
-    ) -> Batch:
-        """Compute action over the given batch data.
-
-        :return: A :class:`~tianshou.data.Batch` which has 4 keys:
-
-            * ``act`` the action.
-            * ``logits`` the network's raw output.
-            * ``dist`` the action distribution.
-            * ``state`` the hidden state.
-
-        .. seealso::
-
-            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
-            more detailed explanation.
-        """
-        logits, h = self.actor(batch.obs, state=state, info=batch.info)
-        if isinstance(logits, tuple):
-            dist = self.dist_fn(*logits)
+                v_s_.append(to_numpy(self.critic(b.obs_next)))
+        v_s_ = np.concatenate(v_s_, axis=0)
+        if self._rew_norm:  # unnormalize v_s_
+            v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
+        unnormalized_returns, _ = self.compute_episodic_return(
+            batch, buffer, indice, v_s_=v_s_,
+            gamma=self._gamma, gae_lambda=self._lambda)
+        if self._rew_norm:
+            batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
+                np.sqrt(self.ret_rms.var + self._eps)
+            self.ret_rms.update(unnormalized_returns)
         else:
-            dist = self.dist_fn(logits)
-        act = dist.sample()
-        return Batch(logits=logits, act=act, state=h, dist=dist)
+            batch.returns = unnormalized_returns
+        return batch
 
     def learn(  # type: ignore
         self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index 4333112b4..ac06f1c00 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -4,10 +4,11 @@
 
 from tianshou.policy import BasePolicy
 from tianshou.data import Batch, ReplayBuffer, to_torch_as
+from tianshou.utils import RunningMeanStd
 
 
 class PGPolicy(BasePolicy):
-    """Implementation of Vanilla Policy Gradient.
+    """Implementation of REINFORCE algorithm.
 
     :param torch.nn.Module model: a model following the rules in
         :class:`~tianshou.policy.BasePolicy`. (s -> logits)
@@ -33,7 +34,7 @@ class PGPolicy(BasePolicy):
 
     def __init__(
         self,
-        model: Optional[torch.nn.Module],
+        model: torch.nn.Module,
         optim: torch.optim.Optimizer,
         dist_fn: Type[torch.distributions.Distribution],
         discount_factor: float = 0.99,
@@ -45,14 +46,15 @@ def __init__(
     ) -> None:
         super().__init__(action_scaling=action_scaling,
                          action_bound_method=action_bound_method, **kwargs)
-        if model is not None:
-            self.model: torch.nn.Module = model
+        self.actor = model
         self.optim = optim
         self.lr_scheduler = lr_scheduler
         self.dist_fn = dist_fn
         assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
         self._gamma = discount_factor
         self._rew_norm = reward_normalization
+        self.ret_rms = RunningMeanStd()
+        self._eps = 1e-8
 
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
@@ -65,11 +67,16 @@ def process_fn(
         where :math:`T` is the terminal time step, :math:`\gamma` is the discount
         factor, :math:`\gamma \in [0, 1]`.
         """
-        # batch.returns = self._vanilla_returns(batch)
-        # batch.returns = self._vectorized_returns(batch)
-        return self.compute_episodic_return(
-            batch, buffer, indice, gamma=self._gamma,
-            gae_lambda=1.0, rew_norm=self._rew_norm)
+        v_s_ = np.full(indice.shape, self.ret_rms.mean)
+        unnormalized_returns, _ = self.compute_episodic_return(
+            batch, buffer, indice, v_s_=v_s_, gamma=self._gamma, gae_lambda=1.0)
+        if self._rew_norm:
+            batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
+                np.sqrt(self.ret_rms.var + self._eps)
+            self.ret_rms.update(unnormalized_returns)
+        else:
+            batch.returns = unnormalized_returns
+        return batch
 
     def forward(
         self,
@@ -91,7 +98,7 @@ def forward(
             Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
             more detailed explanation.
         """
-        logits, h = self.model(batch.obs, state=state, info=batch.info)
+        logits, h = self.actor(batch.obs, state=state)
         if isinstance(logits, tuple):
             dist = self.dist_fn(*logits)
         else:
@@ -106,9 +113,10 @@ def learn(  # type: ignore
         for _ in range(repeat):
             for b in batch.split(batch_size, merge_last=True):
                 self.optim.zero_grad()
-                dist = self(b).dist
-                a = to_torch_as(b.act, dist.logits)
-                r = to_torch_as(b.returns, dist.logits)
+                result = self(b)
+                dist = result.dist
+                a = to_torch_as(b.act, result.act)
+                r = to_torch_as(b.returns, result.act)
                 log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
                 loss = -(log_prob * r).mean()
                 loss.backward()
@@ -119,27 +127,3 @@ def learn(  # type: ignore
             self.lr_scheduler.step()
 
         return {"loss": losses}
-
-    # def _vanilla_returns(self, batch):
-    #     returns = batch.rew[:]
-    #     last = 0
-    #     for i in range(len(returns) - 1, -1, -1):
-    #         if not batch.done[i]:
-    #             returns[i] += self._gamma * last
-    #         last = returns[i]
-    #     return returns
-
-    # def _vectorized_returns(self, batch):
-    #     # according to my tests, it is slower than _vanilla_returns
-    #     # import scipy.signal
-    #     convolve = np.convolve
-    #     # convolve = scipy.signal.convolve
-    #     rew = batch.rew[::-1]
-    #     batch_size = len(rew)
-    #     gammas = self._gamma ** np.arange(batch_size)
-    #     c = convolve(rew, gammas)[:batch_size]
-    #     T = np.where(batch.done[::-1])[0]
-    #     d = np.zeros_like(rew)
-    #     d[T] += c[T] - rew[T]
-    #     d[T[1:]] -= d[T[:-1]] * self._gamma ** np.diff(T)
-    #     return (c - convolve(d, gammas)[:batch_size])[::-1]
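PGPolicy, A2CPolicy, and PPOPolicy now share the same return-normalization pattern in `process_fn`: critic outputs trained against normalized targets are first unnormalized with the running statistics, the raw returns are computed, and only then are the returns re-normalized and the statistics updated. A condensed, hypothetical helper (not part of the library API; it relies only on the `mean`, `var`, and `update()` members of `RunningMeanStd` used above):

import numpy as np
from tianshou.utils import RunningMeanStd

def unnormalize_value(v: np.ndarray, ret_rms: RunningMeanStd, eps: float = 1e-8) -> np.ndarray:
    # The critic was trained on normalized targets, so undo the scaling
    # before using its output as a bootstrap value.
    return v * np.sqrt(ret_rms.var + eps) + ret_rms.mean

def normalize_returns(raw_returns: np.ndarray, ret_rms: RunningMeanStd, eps: float = 1e-8) -> np.ndarray:
    # Scale the returns to roughly Normal(0, 1), then update the running
    # statistics with the raw (unnormalized) values, in that order.
    normalized = (raw_returns - ret_rms.mean) / np.sqrt(ret_rms.var + eps)
    ret_rms.update(raw_returns)
    return normalized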
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
index 4d81dd6cd..db7a22c6f 100644
--- a/tianshou/policy/modelfree/ppo.py
+++ b/tianshou/policy/modelfree/ppo.py
@@ -1,13 +1,13 @@
 import torch
 import numpy as np
 from torch import nn
-from typing import Any, Dict, List, Type, Union, Optional
+from typing import Any, Dict, List, Type, Optional
 
-from tianshou.policy import PGPolicy
+from tianshou.policy import A2CPolicy
 from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
 
 
-class PPOPolicy(PGPolicy):
+class PPOPolicy(A2CPolicy):
     r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347.
 
     :param torch.nn.Module actor: the actor network following the rules in
@@ -30,8 +30,8 @@
         Default to 5.0 (set None if you do not want to use it).
     :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1.
         Default to True.
-    :param bool reward_normalization: normalize the returns to Normal(0, 1).
-        Default to True.
+    :param bool reward_normalization: normalize the returns and advantage to
+        Normal(0, 1). Default to False.
     :param int max_batchsize: the maximum size of the batch when computing GAE,
         depends on the size of available memory and the memory cost of the
         model; should be as large as possible within the memory constraint.
@@ -58,7 +58,6 @@ def __init__(
         critic: torch.nn.Module,
         optim: torch.optim.Optimizer,
         dist_fn: Type[torch.distributions.Distribution],
-        discount_factor: float = 0.99,
         max_grad_norm: Optional[float] = None,
         eps_clip: float = 0.2,
         vf_coef: float = 0.5,
@@ -66,81 +65,50 @@ def __init__(
         gae_lambda: float = 0.95,
         dual_clip: Optional[float] = None,
         value_clip: bool = True,
-        reward_normalization: bool = True,
         max_batchsize: int = 256,
         **kwargs: Any,
     ) -> None:
-        super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
-        self._max_grad_norm = max_grad_norm
+        super().__init__(
+            actor, critic, optim, dist_fn, max_grad_norm=max_grad_norm,
+            vf_coef=vf_coef, ent_coef=ent_coef, gae_lambda=gae_lambda,
+            max_batchsize=max_batchsize, **kwargs)
         self._eps_clip = eps_clip
-        self._weight_vf = vf_coef
-        self._weight_ent = ent_coef
-        self.actor = actor
-        self.critic = critic
-        self._batch = max_batchsize
-        assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]."
-        self._lambda = gae_lambda
         assert dual_clip is None or dual_clip > 1.0, \
             "Dual-clip PPO parameter should greater than 1.0."
         self._dual_clip = dual_clip
         self._value_clip = value_clip
-        self._rew_norm = reward_normalization
 
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
-        if self._rew_norm:
-            mean, std = batch.rew.mean(), batch.rew.std()
-            if not np.isclose(std, 0.0, 1e-2):
-                batch.rew = (batch.rew - mean) / std
-        v, v_, old_log_prob = [], [], []
+        v_s, v_s_, old_log_prob = [], [], []
         with torch.no_grad():
             for b in batch.split(self._batch, shuffle=False, merge_last=True):
-                v_.append(self.critic(b.obs_next))
-                v.append(self.critic(b.obs))
-                old_log_prob.append(self(b).dist.log_prob(to_torch_as(b.act, v[0])))
-        v_ = to_numpy(torch.cat(v_, dim=0))
-        batch = self.compute_episodic_return(
-            batch, buffer, indice, v_, gamma=self._gamma,
-            gae_lambda=self._lambda, rew_norm=self._rew_norm)
-        batch.v = torch.cat(v, dim=0).flatten()  # old value
-        batch.act = to_torch_as(batch.act, v[0])
-        batch.logp_old = torch.cat(old_log_prob, dim=0)
-        batch.returns = to_torch_as(batch.returns, v[0])
-        batch.adv = batch.returns - batch.v
+                v_s.append(self.critic(b.obs))
+                v_s_.append(self.critic(b.obs_next))
+                old_log_prob.append(self(b).dist.log_prob(to_torch_as(b.act, v_s[0])))
+        batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
+        v_s = to_numpy(batch.v_s)
+        v_s_ = to_numpy(torch.cat(v_s_, dim=0).flatten())
+        if self._rew_norm:  # unnormalize v_s & v_s_
+            v_s = v_s * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
+            v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.mean
+        unnormalized_returns, advantages = self.compute_episodic_return(
+            batch, buffer, indice, v_s_, v_s,
+            gamma=self._gamma, gae_lambda=self._lambda)
         if self._rew_norm:
-            mean, std = batch.adv.mean(), batch.adv.std()
-            if not np.isclose(std.item(), 0.0, 1e-2):
-                batch.adv = (batch.adv - mean) / std
-        return batch
-
-    def forward(
-        self,
-        batch: Batch,
-        state: Optional[Union[dict, Batch, np.ndarray]] = None,
-        **kwargs: Any,
-    ) -> Batch:
-        """Compute action over the given batch data.
-
-        :return: A :class:`~tianshou.data.Batch` which has 4 keys:
-
-            * ``act`` the action.
-            * ``logits`` the network's raw output.
-            * ``dist`` the action distribution.
-            * ``state`` the hidden state.
-
-        .. seealso::
-
-            Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
-            more detailed explanation.
- """ - logits, h = self.actor(batch.obs, state=state, info=batch.info) - if isinstance(logits, tuple): - dist = self.dist_fn(*logits) + batch.returns = (unnormalized_returns - self.ret_rms.mean) / \ + np.sqrt(self.ret_rms.var + self._eps) + self.ret_rms.update(unnormalized_returns) + mean, std = np.mean(advantages), np.std(advantages) + advantages = (advantages - mean) / std # per-batch norm else: - dist = self.dist_fn(logits) - act = dist.sample() - return Batch(logits=logits, act=act, state=h, dist=dist) + batch.returns = unnormalized_returns + batch.act = to_torch_as(batch.act, batch.v_s) + batch.logp_old = torch.cat(old_log_prob, dim=0) + batch.returns = to_torch_as(batch.returns, batch.v_s) + batch.adv = to_torch_as(advantages, batch.v_s) + return batch def learn( # type: ignore self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any @@ -162,7 +130,8 @@ def learn( # type: ignore clip_loss = -torch.min(surr1, surr2).mean() clip_losses.append(clip_loss.item()) if self._value_clip: - v_clip = b.v + (value - b.v).clamp(-self._eps_clip, self._eps_clip) + v_clip = b.v_s + (value - b.v_s).clamp( + -self._eps_clip, self._eps_clip) vf1 = (b.returns - value).pow(2) vf2 = (b.returns - v_clip).pow(2) vf_loss = 0.5 * torch.max(vf1, vf2).mean() @@ -176,10 +145,10 @@ def learn( # type: ignore losses.append(loss.item()) self.optim.zero_grad() loss.backward() - if self._max_grad_norm: + if self._grad_norm is not None: nn.utils.clip_grad_norm_( list(self.actor.parameters()) + list(self.critic.parameters()), - self._max_grad_norm) + self._grad_norm) self.optim.step() # update learning rate if lr_scheduler is given if self.lr_scheduler is not None: