From 54b4cb6b6c2eec2643c61fddb3d90bd29d748492 Mon Sep 17 00:00:00 2001 From: ChenDRAG <308604256@qq.com> Date: Wed, 24 Feb 2021 15:03:38 +0800 Subject: [PATCH 01/10] rebase --- examples/box2d/mcc_sac.py | 2 +- examples/mujoco/runnable/ant_v2_ddpg.py | 2 +- examples/mujoco/runnable/ant_v2_td3.py | 2 +- examples/mujoco/runnable/halfcheetahBullet_v0_sac.py | 2 +- examples/mujoco/runnable/point_maze_td3.py | 2 +- test/continuous/test_ddpg.py | 1 - test/continuous/test_sac_with_il.py | 1 - test/continuous/test_td3.py | 1 - test/discrete/test_sac.py | 3 +-- tianshou/policy/base.py | 7 ++++++- tianshou/policy/modelfree/ddpg.py | 6 ------ tianshou/policy/modelfree/discrete_sac.py | 5 +---- tianshou/policy/modelfree/sac.py | 5 +---- tianshou/policy/modelfree/td3.py | 9 +++------ 14 files changed, 17 insertions(+), 31 deletions(-) diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 7fe3daed2..f22c7846e 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -92,7 +92,7 @@ def test_sac(args=get_args()): actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, action_range=[env.action_space.low[0], env.action_space.high[0]], tau=args.tau, gamma=args.gamma, alpha=args.alpha, - reward_normalization=args.rew_norm, ignore_done=True, + reward_normalization=args.rew_norm, exploration_noise=OUNoise(0.0, args.noise_std)) # collector train_collector = Collector( diff --git a/examples/mujoco/runnable/ant_v2_ddpg.py b/examples/mujoco/runnable/ant_v2_ddpg.py index bc75e1f51..53e9ac4d7 100644 --- a/examples/mujoco/runnable/ant_v2_ddpg.py +++ b/examples/mujoco/runnable/ant_v2_ddpg.py @@ -73,7 +73,7 @@ def test_ddpg(args=get_args()): action_range=[env.action_space.low[0], env.action_space.high[0]], tau=args.tau, gamma=args.gamma, exploration_noise=GaussianNoise(sigma=args.exploration_noise), - reward_normalization=True, ignore_done=True) + reward_normalization=True) # collector train_collector = Collector( policy, train_envs, diff --git a/examples/mujoco/runnable/ant_v2_td3.py b/examples/mujoco/runnable/ant_v2_td3.py index 004b604a6..cbbd952f3 100644 --- a/examples/mujoco/runnable/ant_v2_td3.py +++ b/examples/mujoco/runnable/ant_v2_td3.py @@ -81,7 +81,7 @@ def test_td3(args=get_args()): policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, - reward_normalization=True, ignore_done=True) + reward_normalization=True) # collector train_collector = Collector( policy, train_envs, diff --git a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py index 2ed046294..db0ce6ec8 100644 --- a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py +++ b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py @@ -81,7 +81,7 @@ def test_sac(args=get_args()): actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, action_range=[env.action_space.low[0], env.action_space.high[0]], tau=args.tau, gamma=args.gamma, alpha=args.alpha, - reward_normalization=True, ignore_done=True) + reward_normalization=True) # collector train_collector = Collector( policy, train_envs, diff --git a/examples/mujoco/runnable/point_maze_td3.py b/examples/mujoco/runnable/point_maze_td3.py index 8e5f37be7..ed2ce0efc 100644 --- a/examples/mujoco/runnable/point_maze_td3.py +++ b/examples/mujoco/runnable/point_maze_td3.py @@ -86,7 +86,7 @@ def test_td3(args=get_args()): policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, - reward_normalization=True, 
ignore_done=True) + reward_normalization=True) # collector train_collector = Collector( policy, train_envs, diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 311aa65a7..232eef17c 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -83,7 +83,6 @@ def test_ddpg(args=get_args()): tau=args.tau, gamma=args.gamma, exploration_noise=GaussianNoise(sigma=args.exploration_noise), reward_normalization=args.rew_norm, - ignore_done=args.ignore_done, estimation_step=args.n_step) # collector train_collector = Collector( diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 1b4a977ef..8d1842876 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -91,7 +91,6 @@ def test_sac_with_il(args=get_args()): action_range=[env.action_space.low[0], env.action_space.high[0]], tau=args.tau, gamma=args.gamma, alpha=args.alpha, reward_normalization=args.rew_norm, - ignore_done=args.ignore_done, estimation_step=args.n_step) # collector train_collector = Collector( diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index d67818194..c24741c3c 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -95,7 +95,6 @@ def test_td3(args=get_args()): update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, reward_normalization=args.rew_norm, - ignore_done=args.ignore_done, estimation_step=args.n_step) # collector train_collector = Collector( diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index d7f408ffe..b5871f66a 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -88,8 +88,7 @@ def test_discrete_sac(args=get_args()): policy = DiscreteSACPolicy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, args.tau, args.gamma, args.alpha, - reward_normalization=args.rew_norm, - ignore_done=args.ignore_done) + reward_normalization=args.rew_norm) # collector train_collector = Collector( policy, train_envs, diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 99a2f8a81..639eca7d6 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -203,7 +203,12 @@ def value_mask(buffer: ReplayBuffer, indice: np.ndarray) -> np.ndarray: :return: A bool type numpy.ndarray in the same shape with indice. "True" means "obs_next" of that buffer[indice] is valid. """ - return ~buffer.done[indice].astype(np.bool) + mask = ~buffer.done[indice].astype(np.bool) + # info['TimeLimit.truncated'] will be set to True if 'done' flag is generated + # because of timelimit of environments. Checkout gym.wrappers.TimeLimit. + if hasattr(buffer, 'info') and 'TimeLimit.truncated' in buffer.info: + mask = mask | buffer.info['TimeLimit.truncated'][indice] + return mask @staticmethod def compute_episodic_return( diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index efa9fb7f9..d91359a2f 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -26,8 +26,6 @@ class DDPGPolicy(BasePolicy): add to the action, defaults to ``GaussianNoise(sigma=0.1)``. :param bool reward_normalization: normalize the reward to Normal(0, 1), defaults to False. - :param bool ignore_done: ignore the done flag while training the policy, - defaults to False. :param int estimation_step: greater than 1, the number of steps to look ahead. 
@@ -48,7 +46,6 @@ def __init__( gamma: float = 0.99, exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1), reward_normalization: bool = False, - ignore_done: bool = False, estimation_step: int = 1, **kwargs: Any, ) -> None: @@ -73,7 +70,6 @@ def __init__( self._action_scale = (action_range[1] - action_range[0]) / 2.0 # it is only a little difference to use GaussianNoise # self.noise = OUNoise() - self._rm_done = ignore_done self._rew_norm = reward_normalization assert estimation_step > 0, "estimation_step should be greater than 0" self._n_step = estimation_step @@ -110,8 +106,6 @@ def _target_q( def process_fn( self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray ) -> Batch: - if self._rm_done: - batch.done = batch.done * 0.0 batch = self.compute_nstep_return( batch, buffer, indice, self._target_q, self._gamma, self._n_step, self._rew_norm) diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 9c46fc4a3..fd67d4738 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -28,8 +28,6 @@ class DiscreteSACPolicy(SACPolicy): alpha is automatatically tuned. :param bool reward_normalization: normalize the reward to Normal(0, 1), defaults to ``False``. - :param bool ignore_done: ignore the done flag while training the policy, - defaults to ``False``. .. seealso:: @@ -51,13 +49,12 @@ def __init__( float, Tuple[float, torch.Tensor, torch.optim.Optimizer] ] = 0.2, reward_normalization: bool = False, - ignore_done: bool = False, estimation_step: int = 1, **kwargs: Any, ) -> None: super().__init__(actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, (-np.inf, np.inf), tau, gamma, alpha, - reward_normalization, ignore_done, estimation_step, + reward_normalization, estimation_step, **kwargs) self._alpha: Union[float, torch.Tensor] diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 091d1243c..cb53fad7f 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -34,8 +34,6 @@ class SACPolicy(DDPGPolicy): alpha is automatatically tuned. :param bool reward_normalization: normalize the reward to Normal(0, 1), defaults to False. - :param bool ignore_done: ignore the done flag while training the policy, - defaults to False. :param BaseNoise exploration_noise: add a noise to action for exploration, defaults to None. This is useful when solving hard-exploration problem. :param bool deterministic_eval: whether to use deterministic action (mean @@ -63,14 +61,13 @@ def __init__( float, Tuple[float, torch.Tensor, torch.optim.Optimizer] ] = 0.2, reward_normalization: bool = False, - ignore_done: bool = False, estimation_step: int = 1, exploration_noise: Optional[BaseNoise] = None, deterministic_eval: bool = True, **kwargs: Any, ) -> None: super().__init__(None, None, None, None, action_range, tau, gamma, - exploration_noise, reward_normalization, ignore_done, + exploration_noise, reward_normalization, estimation_step, **kwargs) self.actor, self.actor_optim = actor, actor_optim self.critic1, self.critic1_old = critic1, deepcopy(critic1) diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index f79c2a0d5..23e16d88a 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -37,8 +37,6 @@ class TD3Policy(DDPGPolicy): network, default to 0.5. :param bool reward_normalization: normalize the reward to Normal(0, 1), defaults to False. 
- :param bool ignore_done: ignore the done flag while training the policy, - defaults to False. .. seealso:: @@ -62,13 +60,12 @@ def __init__( update_actor_freq: int = 2, noise_clip: float = 0.5, reward_normalization: bool = False, - ignore_done: bool = False, estimation_step: int = 1, **kwargs: Any, ) -> None: - super().__init__(actor, actor_optim, None, None, action_range, - tau, gamma, exploration_noise, reward_normalization, - ignore_done, estimation_step, **kwargs) + super().__init__(actor, actor_optim, None, None, action_range, tau, gamma, + exploration_noise, reward_normalization, + estimation_step, **kwargs) self.critic1, self.critic1_old = critic1, deepcopy(critic1) self.critic1_old.eval() self.critic1_optim = critic1_optim From 34ef6c04b88f3728a14e89c255dafa1b77ebc3f4 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Thu, 25 Feb 2021 13:54:44 +0800 Subject: [PATCH 02/10] add test, solve bug --- test/base/test_returns.py | 40 +++++++++++++++++++++++++++++++-------- tianshou/policy/base.py | 2 +- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/test/base/test_returns.py b/test/base/test_returns.py index 5aba7f56a..e4e485164 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -24,6 +24,8 @@ def test_episodic_returns(size=2560): batch = Batch( done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]), rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]), + info=Batch( {'TimeLimit.truncated': + np.array([False, False, False, False, False, True, False, False])}) ) for b in batch: b.obs = b.act = 1 @@ -69,6 +71,24 @@ def test_episodic_returns(size=2560): 474.2876, 390.1027, 299.476, 202.]) assert np.allclose(ret.returns, returns) buf.reset() + batch = Batch( + done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]), + rew=np.array([101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]), + info=Batch( {'TimeLimit.truncated': + np.array([False, False, False, True, False, False, + False, True, False, False, False, False])}) + ) + for b in batch: + b.obs = b.act = 1 + buf.add(b) + v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3]) + ret = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95) + returns = np.array( + [454.0109, 375.2386, 290.3669 , 199.01, + 462.9138, 381.3571, 293.5248 , 199.02, + 474.2876, 390.1027, 299.476 , 202. ]) + assert np.allclose(ret.returns, returns) + if __name__ == '__main__': buf = ReplayBuffer(size) batch = Batch( @@ -106,15 +126,18 @@ def compute_nstep_return_base(nstep, gamma, buffer, indice): buf_len = len(buffer) for i in range(len(indice)): flag, r = False, 0. 
+ real_step_n = nstep for n in range(nstep): idx = (indice[i] + n) % buf_len r += buffer.rew[idx] * gamma ** n - if buffer.done[idx]: + if buffer.done[idx] and not buffer.info['TimeLimit.truncated'][idx]: flag = True + if buffer.done[idx]: + real_step_n = n + 1 break if not flag: - idx = (indice[i] + nstep - 1) % buf_len - r += to_numpy(target_q_fn(buffer, idx)) * gamma ** nstep + idx = (indice[i] + real_step_n - 1) % buf_len + r += to_numpy(target_q_fn(buffer, idx)) * gamma ** real_step_n returns[i] = r return returns @@ -122,7 +145,7 @@ def compute_nstep_return_base(nstep, gamma, buffer, indice): def test_nstep_returns(size=10000): buf = ReplayBuffer(10) for i in range(12): - buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3)) + buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3, info={"TimeLimit.truncated": i==3})) batch, indice = buf.sample(0) assert np.allclose(indice, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1]) # rew: [11, 12, 3, 4, 5, 6, 7, 8, 9, 10] @@ -131,7 +154,7 @@ def test_nstep_returns(size=10000): returns = to_numpy(BasePolicy.compute_nstep_return( batch, buf, indice, target_q_fn, gamma=.1, n_step=1 ).pop('returns').reshape(-1)) - assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12]) + assert np.allclose(returns, [2.6, 3.6, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12]) r_ = compute_nstep_return_base(1, .1, buf, indice) assert np.allclose(returns, r_), (r_, returns) returns_multidim = to_numpy(BasePolicy.compute_nstep_return( @@ -142,7 +165,7 @@ def test_nstep_returns(size=10000): returns = to_numpy(BasePolicy.compute_nstep_return( batch, buf, indice, target_q_fn, gamma=.1, n_step=2 ).pop('returns').reshape(-1)) - assert np.allclose(returns, [3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12]) + assert np.allclose(returns, [3.36, 3.6, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12]) r_ = compute_nstep_return_base(2, .1, buf, indice) assert np.allclose(returns, r_) returns_multidim = to_numpy(BasePolicy.compute_nstep_return( @@ -153,7 +176,7 @@ def test_nstep_returns(size=10000): returns = to_numpy(BasePolicy.compute_nstep_return( batch, buf, indice, target_q_fn, gamma=.1, n_step=10 ).pop('returns').reshape(-1)) - assert np.allclose(returns, [3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12]) + assert np.allclose(returns, [3.36, 3.6, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12]) r_ = compute_nstep_return_base(10, .1, buf, indice) assert np.allclose(returns, r_) returns_multidim = to_numpy(BasePolicy.compute_nstep_return( @@ -164,7 +187,8 @@ def test_nstep_returns(size=10000): if __name__ == '__main__': buf = ReplayBuffer(size) for i in range(int(size * 1.5)): - buf.add(Batch(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0)) + buf.add(Batch(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0, + info={"TimeLimit.truncated": i==3})) batch, indice = buf.sample(256) def vanilla(): diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 639eca7d6..cf2678f74 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -382,7 +382,7 @@ def _nstep_return( gammas = np.full(indices[0].shape, n_step) for n in range(n_step - 1, -1, -1): now = indices[n] - gammas[end_flag[now] > 0] = n + gammas[end_flag[now] > 0] = n + 1 returns[end_flag[now] > 0] = 0.0 returns = (rew[now].reshape(bsz, 1) - mean) / std + gamma * returns target_q = target_q * gamma_buffer[gammas].reshape(bsz, 1) + returns From 024055fcc32a4b62181497f4f85fa6537c6dec79 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Thu, 25 Feb 2021 13:58:09 +0800 Subject: [PATCH 03/10] pep8 
--- test/base/test_returns.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/test/base/test_returns.py b/test/base/test_returns.py index e4e485164..03c54fb62 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -24,8 +24,8 @@ def test_episodic_returns(size=2560): batch = Batch( done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]), rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]), - info=Batch( {'TimeLimit.truncated': - np.array([False, False, False, False, False, True, False, False])}) + info=Batch({'TimeLimit.truncated': + np.array([False, False, False, False, False, True, False, False])}) ) for b in batch: b.obs = b.act = 1 @@ -74,9 +74,9 @@ def test_episodic_returns(size=2560): batch = Batch( done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]), rew=np.array([101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]), - info=Batch( {'TimeLimit.truncated': - np.array([False, False, False, True, False, False, - False, True, False, False, False, False])}) + info=Batch({'TimeLimit.truncated': + np.array([False, False, False, True, False, False, + False, True, False, False, False, False])}) ) for b in batch: b.obs = b.act = 1 @@ -84,9 +84,9 @@ def test_episodic_returns(size=2560): v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3]) ret = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95) returns = np.array( - [454.0109, 375.2386, 290.3669 , 199.01, - 462.9138, 381.3571, 293.5248 , 199.02, - 474.2876, 390.1027, 299.476 , 202. ]) + [454.0109, 375.2386, 290.3669, 199.01, + 462.9138, 381.3571, 293.5248, 199.02, + 474.2876, 390.1027, 299.476, 202.]) assert np.allclose(ret.returns, returns) if __name__ == '__main__': @@ -145,7 +145,8 @@ def compute_nstep_return_base(nstep, gamma, buffer, indice): def test_nstep_returns(size=10000): buf = ReplayBuffer(10) for i in range(12): - buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3, info={"TimeLimit.truncated": i==3})) + buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % + 4 == 3, info={"TimeLimit.truncated": i == 3})) batch, indice = buf.sample(0) assert np.allclose(indice, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1]) # rew: [11, 12, 3, 4, 5, 6, 7, 8, 9, 10] @@ -176,7 +177,8 @@ def test_nstep_returns(size=10000): returns = to_numpy(BasePolicy.compute_nstep_return( batch, buf, indice, target_q_fn, gamma=.1, n_step=10 ).pop('returns').reshape(-1)) - assert np.allclose(returns, [3.36, 3.6, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12]) + assert np.allclose(returns, [3.36, 3.6, 5.678, 6.78, + 7.8, 8, 10.122, 11.22, 12.2, 12]) r_ = compute_nstep_return_base(10, .1, buf, indice) assert np.allclose(returns, r_) returns_multidim = to_numpy(BasePolicy.compute_nstep_return( @@ -188,7 +190,7 @@ def test_nstep_returns(size=10000): buf = ReplayBuffer(size) for i in range(int(size * 1.5)): buf.add(Batch(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0, - info={"TimeLimit.truncated": i==3})) + info={"TimeLimit.truncated": i == 3})) batch, indice = buf.sample(256) def vanilla(): From 2d22143db25287844c542e67aa35cc1ca179e67c Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Thu, 25 Feb 2021 15:44:15 +0800 Subject: [PATCH 04/10] small change in logger --- tianshou/utils/log_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/utils/log_tools.py b/tianshou/utils/log_tools.py index 7605a27ba..c50c8ebae 100644 --- a/tianshou/utils/log_tools.py +++ b/tianshou/utils/log_tools.py @@ -138,7 +138,7 @@ def log_test_data(self, collect_result: dict, step: int) -> None: 
def log_update_data(self, update_result: dict, step: int) -> None: if step - self.last_log_update_step >= self.update_interval: for k, v in update_result.items(): - self.write("train/" + k, step, v) # save in train/ + self.write(k, step, v) self.last_log_update_step = step From d83d7e1a88d312c73871bd8d73abad1a74843dd3 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Thu, 25 Feb 2021 16:37:04 +0800 Subject: [PATCH 05/10] fix test_c51 --- test/discrete/test_c51.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 53768dae6..1d0c4cc0a 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -116,8 +116,8 @@ def test_fn(epoch, env_step): result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger) + args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, + test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': From ebca5e4a9e83a3afff5ca6e6aef2b6588b1d4c5c Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 25 Feb 2021 17:26:37 +0800 Subject: [PATCH 06/10] format --- test/base/test_returns.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/base/test_returns.py b/test/base/test_returns.py index 03c54fb62..5c5bc7086 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -83,10 +83,10 @@ def test_episodic_returns(size=2560): buf.add(b) v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3]) ret = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95) - returns = np.array( - [454.0109, 375.2386, 290.3669, 199.01, - 462.9138, 381.3571, 293.5248, 199.02, - 474.2876, 390.1027, 299.476, 202.]) + returns = np.array([ + 454.0109, 375.2386, 290.3669, 199.01, + 462.9138, 381.3571, 293.5248, 199.02, + 474.2876, 390.1027, 299.476, 202.]) assert np.allclose(ret.returns, returns) if __name__ == '__main__': @@ -145,8 +145,8 @@ def compute_nstep_return_base(nstep, gamma, buffer, indice): def test_nstep_returns(size=10000): buf = ReplayBuffer(10) for i in range(12): - buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % - 4 == 3, info={"TimeLimit.truncated": i == 3})) + buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3, + info={"TimeLimit.truncated": i == 3})) batch, indice = buf.sample(0) assert np.allclose(indice, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1]) # rew: [11, 12, 3, 4, 5, 6, 7, 8, 9, 10] From 71f16eea104dc404f39dc32aa43f644e78705ae9 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Thu, 25 Feb 2021 19:32:57 +0800 Subject: [PATCH 07/10] add test --- test/base/test_returns.py | 35 +++++++++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/test/base/test_returns.py b/test/base/test_returns.py index 5c5bc7086..5528cbf63 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -130,9 +130,10 @@ def compute_nstep_return_base(nstep, gamma, buffer, indice): for n in range(nstep): idx = (indice[i] + n) % buf_len r += buffer.rew[idx] * gamma ** n - if buffer.done[idx] and not buffer.info['TimeLimit.truncated'][idx]: - flag = True if buffer.done[idx]: + if not (hasattr(buffer, 'info') and + buffer.info['TimeLimit.truncated'][idx]): + flag = True real_step_n = n + 1 break if not flag: @@ -204,6 +205,36 @@ 
def optimized(): print('nstep vanilla', timeit(vanilla, setup=vanilla, number=cnt)) print('nstep optim ', timeit(optimized, setup=optimized, number=cnt)) +def test_nstep_returns_without_timelimit(size=10000): + buf = ReplayBuffer(10) + for i in range(12): + buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3)) + batch, indice = buf.sample(0) + assert np.allclose(indice, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1]) + # rew: [11, 12, 3, 4, 5, 6, 7, 8, 9, 10] + # done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0] + # test nstep = 1 + returns = to_numpy(BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn, gamma=.1, n_step=1 + ).pop('returns').reshape(-1)) + assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12]) + r_ = compute_nstep_return_base(1, .1, buf, indice) + assert np.allclose(returns, r_), (r_, returns) + returns_multidim = to_numpy(BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=1 + ).pop('returns')) + assert np.allclose(returns_multidim, returns[:, np.newaxis]) + # test nstep = 2 + returns = to_numpy(BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn, gamma=.1, n_step=2 + ).pop('returns').reshape(-1)) + assert np.allclose(returns, [3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12]) + r_ = compute_nstep_return_base(2, .1, buf, indice) + assert np.allclose(returns, r_) + returns_multidim = to_numpy(BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=2 + ).pop('returns')) + assert np.allclose(returns_multidim, returns[:, np.newaxis]) if __name__ == '__main__': test_nstep_returns() From 297495601869d996b0244ff89256893e09180812 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Thu, 25 Feb 2021 19:36:16 +0800 Subject: [PATCH 08/10] pep8 --- test/base/test_returns.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/base/test_returns.py b/test/base/test_returns.py index 5528cbf63..99a49f9b0 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -205,6 +205,7 @@ def optimized(): print('nstep vanilla', timeit(vanilla, setup=vanilla, number=cnt)) print('nstep optim ', timeit(optimized, setup=optimized, number=cnt)) + def test_nstep_returns_without_timelimit(size=10000): buf = ReplayBuffer(10) for i in range(12): @@ -236,6 +237,7 @@ def test_nstep_returns_without_timelimit(size=10000): ).pop('returns')) assert np.allclose(returns_multidim, returns[:, np.newaxis]) + if __name__ == '__main__': test_nstep_returns() test_episodic_returns() From 3e2c4a903f6e1cb43749c600ca0f454eabe4009a Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 25 Feb 2021 20:33:44 +0800 Subject: [PATCH 09/10] look at diff --- test/base/test_returns.py | 65 ++++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 32 deletions(-) diff --git a/test/base/test_returns.py b/test/base/test_returns.py index 99a49f9b0..e96983db8 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -144,6 +144,38 @@ def compute_nstep_return_base(nstep, gamma, buffer, indice): def test_nstep_returns(size=10000): + buf = ReplayBuffer(10) + for i in range(12): + buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3)) + batch, indice = buf.sample(0) + assert np.allclose(indice, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1]) + # rew: [11, 12, 3, 4, 5, 6, 7, 8, 9, 10] + # done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0] + # test nstep = 1 + returns = to_numpy(BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn, gamma=.1, n_step=1 + ).pop('returns').reshape(-1)) + assert 
np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12]) + r_ = compute_nstep_return_base(1, .1, buf, indice) + assert np.allclose(returns, r_), (r_, returns) + returns_multidim = to_numpy(BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=1 + ).pop('returns')) + assert np.allclose(returns_multidim, returns[:, np.newaxis]) + # test nstep = 2 + returns = to_numpy(BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn, gamma=.1, n_step=2 + ).pop('returns').reshape(-1)) + assert np.allclose(returns, [3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12]) + r_ = compute_nstep_return_base(2, .1, buf, indice) + assert np.allclose(returns, r_) + returns_multidim = to_numpy(BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=2 + ).pop('returns')) + assert np.allclose(returns_multidim, returns[:, np.newaxis]) + + +def test_nstep_returns_with_timelimit(size=10000): buf = ReplayBuffer(10) for i in range(12): buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3, @@ -206,38 +238,7 @@ def optimized(): print('nstep optim ', timeit(optimized, setup=optimized, number=cnt)) -def test_nstep_returns_without_timelimit(size=10000): - buf = ReplayBuffer(10) - for i in range(12): - buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3)) - batch, indice = buf.sample(0) - assert np.allclose(indice, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1]) - # rew: [11, 12, 3, 4, 5, 6, 7, 8, 9, 10] - # done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0] - # test nstep = 1 - returns = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indice, target_q_fn, gamma=.1, n_step=1 - ).pop('returns').reshape(-1)) - assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12]) - r_ = compute_nstep_return_base(1, .1, buf, indice) - assert np.allclose(returns, r_), (r_, returns) - returns_multidim = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=1 - ).pop('returns')) - assert np.allclose(returns_multidim, returns[:, np.newaxis]) - # test nstep = 2 - returns = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indice, target_q_fn, gamma=.1, n_step=2 - ).pop('returns').reshape(-1)) - assert np.allclose(returns, [3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12]) - r_ = compute_nstep_return_base(2, .1, buf, indice) - assert np.allclose(returns, r_) - returns_multidim = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=2 - ).pop('returns')) - assert np.allclose(returns_multidim, returns[:, np.newaxis]) - - if __name__ == '__main__': test_nstep_returns() + test_nstep_returns_with_timelimit() test_episodic_returns() From b41dd09faae4f7ede670723f6d662b7760a75b19 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Thu, 25 Feb 2021 20:37:58 +0800 Subject: [PATCH 10/10] add missing test --- test/base/test_returns.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/test/base/test_returns.py b/test/base/test_returns.py index e96983db8..e8d70de5c 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -173,6 +173,17 @@ def test_nstep_returns(size=10000): batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=2 ).pop('returns')) assert np.allclose(returns_multidim, returns[:, np.newaxis]) + # test nstep = 10 + returns = to_numpy(BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn, gamma=.1, n_step=10 + ).pop('returns').reshape(-1)) + assert np.allclose(returns, [3.4, 4, 5.678, 6.78, 
7.8, 8, 10.122, 11.22, 12.2, 12]) + r_ = compute_nstep_return_base(10, .1, buf, indice) + assert np.allclose(returns, r_) + returns_multidim = to_numpy(BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=10 + ).pop('returns')) + assert np.allclose(returns_multidim, returns[:, np.newaxis]) def test_nstep_returns_with_timelimit(size=10000): @@ -223,7 +234,7 @@ def test_nstep_returns_with_timelimit(size=10000): buf = ReplayBuffer(size) for i in range(int(size * 1.5)): buf.add(Batch(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0, - info={"TimeLimit.truncated": i == 3})) + info={"TimeLimit.truncated": i % 33 == 0})) batch, indice = buf.sample(256) def vanilla():
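
The series as a whole replaces the old ignore_done switch with per-transition handling of info['TimeLimit.truncated'], the flag written by gym.wrappers.TimeLimit when an episode is cut off by its step limit rather than genuinely terminating. A rough sketch of the resulting behaviour, assuming the patched BasePolicy.value_mask staticmethod above and the ReplayBuffer/Batch calls already used in test_returns.py; the five-step toy buffer is illustrative only and is not part of the patched tests:

import numpy as np
from tianshou.data import Batch, ReplayBuffer
from tianshou.policy import BasePolicy

# Illustrative buffer: a five-step episode whose last step is cut off by the
# time limit. gym.wrappers.TimeLimit records this as done=True together with
# info['TimeLimit.truncated']=True.
buf = ReplayBuffer(size=10)
for i in range(5):
    truncated = (i == 4)
    buf.add(Batch(obs=0, act=0, rew=1.0, done=truncated,
                  info={"TimeLimit.truncated": truncated}))

indice = np.arange(5)
# value_mask marks where obs_next may still be bootstrapped: ~done OR truncated.
# Every entry here should be True, because the only done flag comes from
# truncation; a genuine terminal step would yield False instead.
print(BasePolicy.value_mask(buf, indice))

Because compute_nstep_return masks the bootstrapped target Q value with this mask, a step that ended only through the time limit keeps its bootstrap term while a real terminal state does not, which is what the adjusted expected returns in test_returns.py above encode.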