
Add Timelimit trick to optimize policies #296


Merged
merged 10 commits on Feb 26, 2021
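Context for the diff below: this PR removes the old `ignore_done` workaround and instead uses the `TimeLimit.truncated` flag that gym's time-limit wrapper writes into `info`, so that steps cut off by the time limit still bootstrap from the next-state value while true terminals do not. As a reminder of where that flag comes from, here is a simplified sketch of the wrapper's behavior (illustration only, based on the gym API of this era — the real code lives in `gym.wrappers.TimeLimit`, not in this repository):

```python
import gym


class TimeLimitSketch(gym.Wrapper):
    """Simplified sketch of gym.wrappers.TimeLimit (illustration only)."""

    def __init__(self, env, max_episode_steps):
        super().__init__(env)
        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = 0

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        self._elapsed_steps += 1
        if self._elapsed_steps >= self._max_episode_steps:
            # done is forced to True only because the step budget ran out;
            # the flag lets downstream code tell truncation from a true terminal.
            info["TimeLimit.truncated"] = not done
            done = True
        return obs, rew, done, info

    def reset(self, **kwargs):
        self._elapsed_steps = 0
        return self.env.reset(**kwargs)
```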
2 changes: 1 addition & 1 deletion examples/box2d/mcc_sac.py
@@ -92,7 +92,7 @@ def test_sac(args=get_args()):
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma, alpha=args.alpha,
reward_normalization=args.rew_norm, ignore_done=True,
reward_normalization=args.rew_norm,
exploration_noise=OUNoise(0.0, args.noise_std))
# collector
train_collector = Collector(
2 changes: 1 addition & 1 deletion examples/mujoco/runnable/ant_v2_ddpg.py
@@ -73,7 +73,7 @@ def test_ddpg(args=get_args()):
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
reward_normalization=True, ignore_done=True)
reward_normalization=True)
# collector
train_collector = Collector(
policy, train_envs,
2 changes: 1 addition & 1 deletion examples/mujoco/runnable/ant_v2_td3.py
@@ -81,7 +81,7 @@ def test_td3(args=get_args()):
policy_noise=args.policy_noise,
update_actor_freq=args.update_actor_freq,
noise_clip=args.noise_clip,
reward_normalization=True, ignore_done=True)
reward_normalization=True)
# collector
train_collector = Collector(
policy, train_envs,
2 changes: 1 addition & 1 deletion examples/mujoco/runnable/halfcheetahBullet_v0_sac.py
@@ -81,7 +81,7 @@ def test_sac(args=get_args()):
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma, alpha=args.alpha,
reward_normalization=True, ignore_done=True)
reward_normalization=True)
# collector
train_collector = Collector(
policy, train_envs,
2 changes: 1 addition & 1 deletion examples/mujoco/runnable/point_maze_td3.py
@@ -86,7 +86,7 @@ def test_td3(args=get_args()):
policy_noise=args.policy_noise,
update_actor_freq=args.update_actor_freq,
noise_clip=args.noise_clip,
reward_normalization=True, ignore_done=True)
reward_normalization=True)
# collector
train_collector = Collector(
policy, train_envs,
79 changes: 75 additions & 4 deletions test/base/test_returns.py
@@ -24,6 +24,8 @@ def test_episodic_returns(size=2560):
batch = Batch(
done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]),
rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]),
info=Batch({'TimeLimit.truncated':
np.array([False, False, False, False, False, True, False, False])})
)
for b in batch:
b.obs = b.act = 1
@@ -69,6 +71,24 @@ def test_episodic_returns(size=2560):
474.2876, 390.1027, 299.476, 202.])
assert np.allclose(ret.returns, returns)
buf.reset()
batch = Batch(
done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
rew=np.array([101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]),
info=Batch({'TimeLimit.truncated':
np.array([False, False, False, True, False, False,
False, True, False, False, False, False])})
)
for b in batch:
b.obs = b.act = 1
buf.add(b)
v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
ret = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
returns = np.array([
454.0109, 375.2386, 290.3669, 199.01,
462.9138, 381.3571, 293.5248, 199.02,
474.2876, 390.1027, 299.476, 202.])
assert np.allclose(ret.returns, returns)

if __name__ == '__main__':
buf = ReplayBuffer(size)
batch = Batch(
@@ -106,15 +126,19 @@ def compute_nstep_return_base(nstep, gamma, buffer, indice):
buf_len = len(buffer)
for i in range(len(indice)):
flag, r = False, 0.
real_step_n = nstep
for n in range(nstep):
idx = (indice[i] + n) % buf_len
r += buffer.rew[idx] * gamma ** n
if buffer.done[idx]:
flag = True
if not (hasattr(buffer, 'info') and
buffer.info['TimeLimit.truncated'][idx]):
flag = True
real_step_n = n + 1
break
if not flag:
idx = (indice[i] + nstep - 1) % buf_len
r += to_numpy(target_q_fn(buffer, idx)) * gamma ** nstep
idx = (indice[i] + real_step_n - 1) % buf_len
r += to_numpy(target_q_fn(buffer, idx)) * gamma ** real_step_n
returns[i] = r
return returns
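A small standalone check of the rule encoded by this reference implementation, with made-up numbers (not taken from the test buffer): rewards are summed up to the step where the episode ends, and the bootstrap term survives only when that end is a time-limit truncation — discounted by `gamma ** real_step_n`.

```python
# gamma, rewards and the bootstrap value below are illustrative numbers only
gamma, rews, q_boot = 0.1, [1.0, 2.0, 3.0], 10.0

# episode cut off by the time limit at relative step 1 (done=True,
# TimeLimit.truncated=True): stop summing there, but still bootstrap
ret_truncated = rews[0] + gamma * rews[1] + gamma ** 2 * q_boot
assert abs(ret_truncated - 1.3) < 1e-8

# same episode but a true terminal at relative step 1: no bootstrap term
ret_terminal = rews[0] + gamma * rews[1]
assert abs(ret_terminal - 1.2) < 1e-8
```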

@@ -161,10 +185,56 @@ def test_nstep_returns(size=10000):
).pop('returns'))
assert np.allclose(returns_multidim, returns[:, np.newaxis])


def test_nstep_returns_with_timelimit(size=10000):
buf = ReplayBuffer(10)
for i in range(12):
buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3,
info={"TimeLimit.truncated": i == 3}))
batch, indice = buf.sample(0)
assert np.allclose(indice, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1])
# rew: [11, 12, 3, 4, 5, 6, 7, 8, 9, 10]
# done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0]
# test nstep = 1
returns = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn, gamma=.1, n_step=1
).pop('returns').reshape(-1))
assert np.allclose(returns, [2.6, 3.6, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12])
r_ = compute_nstep_return_base(1, .1, buf, indice)
assert np.allclose(returns, r_), (r_, returns)
returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=1
).pop('returns'))
assert np.allclose(returns_multidim, returns[:, np.newaxis])
# test nstep = 2
returns = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn, gamma=.1, n_step=2
).pop('returns').reshape(-1))
assert np.allclose(returns, [3.36, 3.6, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12])
r_ = compute_nstep_return_base(2, .1, buf, indice)
assert np.allclose(returns, r_)
returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=2
).pop('returns'))
assert np.allclose(returns_multidim, returns[:, np.newaxis])
# test nstep = 10
returns = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn, gamma=.1, n_step=10
).pop('returns').reshape(-1))
assert np.allclose(returns, [3.36, 3.6, 5.678, 6.78,
7.8, 8, 10.122, 11.22, 12.2, 12])
r_ = compute_nstep_return_base(10, .1, buf, indice)
assert np.allclose(returns, r_)
returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=10
).pop('returns'))
assert np.allclose(returns_multidim, returns[:, np.newaxis])

if __name__ == '__main__':
buf = ReplayBuffer(size)
for i in range(int(size * 1.5)):
buf.add(Batch(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0))
buf.add(Batch(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0,
info={"TimeLimit.truncated": i % 33 == 0}))
batch, indice = buf.sample(256)

def vanilla():
@@ -181,4 +251,5 @@ def optimized():

if __name__ == '__main__':
test_nstep_returns()
test_nstep_returns_with_timelimit()
test_episodic_returns()
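The episodic-return tests added above exercise the same idea for GAE. As a hedged sketch of that idea (generic form with illustrative names, not the exact Tianshou internals): a time-limit cut-off still bootstraps from the next-state value, but, like any episode end, it stops the advantage from leaking into the next trajectory.

```python
import numpy as np


def gae_with_truncation(rew, done, truncated, v_next, v_curr,
                        gamma=0.99, lam=0.95):
    """Illustrative GAE recursion with a separate bootstrap mask.

    rew, v_next, v_curr are float arrays; done and truncated are bool arrays.
    """
    bootstrap = (~done) | truncated       # 0 only at true terminal steps
    episode_end = done.astype(float)      # 1 at terminals AND truncations
    adv = np.zeros_like(rew, dtype=float)
    gae = 0.0
    for t in reversed(range(len(rew))):
        delta = rew[t] + gamma * bootstrap[t] * v_next[t] - v_curr[t]
        gae = delta + gamma * lam * (1.0 - episode_end[t]) * gae
        adv[t] = gae
    return adv
```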
1 change: 0 additions & 1 deletion test/continuous/test_ddpg.py
@@ -83,7 +83,6 @@ def test_ddpg(args=get_args()):
tau=args.tau, gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
reward_normalization=args.rew_norm,
ignore_done=args.ignore_done,
estimation_step=args.n_step)
# collector
train_collector = Collector(
1 change: 0 additions & 1 deletion test/continuous/test_sac_with_il.py
@@ -91,7 +91,6 @@ def test_sac_with_il(args=get_args()):
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma, alpha=args.alpha,
reward_normalization=args.rew_norm,
ignore_done=args.ignore_done,
estimation_step=args.n_step)
# collector
train_collector = Collector(
1 change: 0 additions & 1 deletion test/continuous/test_td3.py
@@ -95,7 +95,6 @@ def test_td3(args=get_args()):
update_actor_freq=args.update_actor_freq,
noise_clip=args.noise_clip,
reward_normalization=args.rew_norm,
ignore_done=args.ignore_done,
estimation_step=args.n_step)
# collector
train_collector = Collector(
4 changes: 2 additions & 2 deletions test/discrete/test_c51.py
@@ -116,8 +116,8 @@ def test_fn(epoch, env_step):
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, logger=logger)
args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger)

assert stop_fn(result['best_reward'])
if __name__ == '__main__':
3 changes: 1 addition & 2 deletions test/discrete/test_sac.py
@@ -88,8 +88,7 @@ def test_discrete_sac(args=get_args()):
policy = DiscreteSACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
args.tau, args.gamma, args.alpha,
reward_normalization=args.rew_norm,
ignore_done=args.ignore_done)
reward_normalization=args.rew_norm)
# collector
train_collector = Collector(
policy, train_envs,
9 changes: 7 additions & 2 deletions tianshou/policy/base.py
@@ -203,7 +203,12 @@ def value_mask(buffer: ReplayBuffer, indice: np.ndarray) -> np.ndarray:
:return: A boolean numpy.ndarray with the same shape as indice. "True" means
that "obs_next" of buffer[indice] is valid.
"""
return ~buffer.done[indice].astype(np.bool)
mask = ~buffer.done[indice].astype(np.bool)
# info['TimeLimit.truncated'] is set to True when the 'done' flag was triggered
# only by the environment's time limit; see gym.wrappers.TimeLimit.
if hasattr(buffer, 'info') and 'TimeLimit.truncated' in buffer.info:
mask = mask | buffer.info['TimeLimit.truncated'][indice]
return mask

@staticmethod
def compute_episodic_return(
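To make the effect of the widened mask concrete, here is a hedged one-step sketch (variable names are illustrative, not the exact arguments of Tianshou's return functions): the target bootstraps whenever the step is not a true terminal, including time-limit truncations.

```python
import numpy as np


def one_step_target(rew, done, truncated, v_next, gamma=0.99):
    # bootstrap unless the episode ended in a real terminal state
    mask = (~done) | truncated
    return rew + gamma * mask * v_next


rew = np.array([1.0, 1.0])
done = np.array([True, True])
truncated = np.array([False, True])   # the second step only hit the time limit
v_next = np.array([10.0, 10.0])
print(one_step_target(rew, done, truncated, v_next))  # -> 1.0 and 10.9
```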
@@ -377,7 +382,7 @@ def _nstep_return(
gammas = np.full(indices[0].shape, n_step)
for n in range(n_step - 1, -1, -1):
now = indices[n]
gammas[end_flag[now] > 0] = n
gammas[end_flag[now] > 0] = n + 1
returns[end_flag[now] > 0] = 0.0
returns = (rew[now].reshape(bsz, 1) - mean) / std + gamma * returns
target_q = target_q * gamma_buffer[gammas].reshape(bsz, 1) + returns
6 changes: 0 additions & 6 deletions tianshou/policy/modelfree/ddpg.py
@@ -26,8 +26,6 @@ class DDPGPolicy(BasePolicy):
add to the action, defaults to ``GaussianNoise(sigma=0.1)``.
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to False.
:param bool ignore_done: ignore the done flag while training the policy,
defaults to False.
:param int estimation_step: the number of steps to look ahead (should be
greater than 0), defaults to 1.

@@ -48,7 +46,6 @@ def __init__(
gamma: float = 0.99,
exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
reward_normalization: bool = False,
ignore_done: bool = False,
estimation_step: int = 1,
**kwargs: Any,
) -> None:
@@ -73,7 +70,6 @@ def __init__(
self._action_scale = (action_range[1] - action_range[0]) / 2.0
# using GaussianNoise instead of OUNoise makes only a small difference
# self.noise = OUNoise()
self._rm_done = ignore_done
self._rew_norm = reward_normalization
assert estimation_step > 0, "estimation_step should be greater than 0"
self._n_step = estimation_step
@@ -110,8 +106,6 @@ def _target_q(
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
if self._rm_done:
batch.done = batch.done * 0.0
batch = self.compute_nstep_return(
batch, buffer, indice, self._target_q,
self._gamma, self._n_step, self._rew_norm)
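With `ignore_done` gone, the intended usage is simply to let the environment's time-limit wrapper populate `info['TimeLimit.truncated']` and keep the real done flag. A hedged sketch (only the relevant pieces shown; the policy constructor arguments are abbreviated):

```python
import gym

# Registered gym tasks (e.g. MountainCarContinuous-v0) already come wrapped in
# gym.wrappers.TimeLimit, so info["TimeLimit.truncated"] is filled in on
# time-outs and the replay buffer picks it up automatically.
env = gym.make("MountainCarContinuous-v0")

# before this PR (workaround):  DDPGPolicy(..., ignore_done=True)
# after this PR: drop the flag entirely; truncated steps still bootstrap
# because value_mask() treats them as non-terminal.
```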
5 changes: 1 addition & 4 deletions tianshou/policy/modelfree/discrete_sac.py
@@ -28,8 +28,6 @@ class DiscreteSACPolicy(SACPolicy):
alpha is automatically tuned.
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to ``False``.
:param bool ignore_done: ignore the done flag while training the policy,
defaults to ``False``.

.. seealso::

@@ -51,13 +49,12 @@ def __init__(
float, Tuple[float, torch.Tensor, torch.optim.Optimizer]
] = 0.2,
reward_normalization: bool = False,
ignore_done: bool = False,
estimation_step: int = 1,
**kwargs: Any,
) -> None:
super().__init__(actor, actor_optim, critic1, critic1_optim, critic2,
critic2_optim, (-np.inf, np.inf), tau, gamma, alpha,
reward_normalization, ignore_done, estimation_step,
reward_normalization, estimation_step,
**kwargs)
self._alpha: Union[float, torch.Tensor]

5 changes: 1 addition & 4 deletions tianshou/policy/modelfree/sac.py
@@ -34,8 +34,6 @@ class SACPolicy(DDPGPolicy):
alpha is automatically tuned.
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to False.
:param bool ignore_done: ignore the done flag while training the policy,
defaults to False.
:param BaseNoise exploration_noise: add a noise to action for exploration,
defaults to None. This is useful when solving hard-exploration problem.
:param bool deterministic_eval: whether to use deterministic action (mean
@@ -63,14 +61,13 @@ def __init__(
float, Tuple[float, torch.Tensor, torch.optim.Optimizer]
] = 0.2,
reward_normalization: bool = False,
ignore_done: bool = False,
estimation_step: int = 1,
exploration_noise: Optional[BaseNoise] = None,
deterministic_eval: bool = True,
**kwargs: Any,
) -> None:
super().__init__(None, None, None, None, action_range, tau, gamma,
exploration_noise, reward_normalization, ignore_done,
exploration_noise, reward_normalization,
estimation_step, **kwargs)
self.actor, self.actor_optim = actor, actor_optim
self.critic1, self.critic1_old = critic1, deepcopy(critic1)
9 changes: 3 additions & 6 deletions tianshou/policy/modelfree/td3.py
@@ -37,8 +37,6 @@ class TD3Policy(DDPGPolicy):
network, default to 0.5.
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to False.
:param bool ignore_done: ignore the done flag while training the policy,
defaults to False.

.. seealso::

@@ -62,13 +60,12 @@ def __init__(
update_actor_freq: int = 2,
noise_clip: float = 0.5,
reward_normalization: bool = False,
ignore_done: bool = False,
estimation_step: int = 1,
**kwargs: Any,
) -> None:
super().__init__(actor, actor_optim, None, None, action_range,
tau, gamma, exploration_noise, reward_normalization,
ignore_done, estimation_step, **kwargs)
super().__init__(actor, actor_optim, None, None, action_range, tau, gamma,
exploration_noise, reward_normalization,
estimation_step, **kwargs)
self.critic1, self.critic1_old = critic1, deepcopy(critic1)
self.critic1_old.eval()
self.critic1_optim = critic1_optim
2 changes: 1 addition & 1 deletion tianshou/utils/log_tools.py
@@ -138,7 +138,7 @@ def log_test_data(self, collect_result: dict, step: int) -> None:
def log_update_data(self, update_result: dict, step: int) -> None:
if step - self.last_log_update_step >= self.update_interval:
for k, v in update_result.items():
self.write("train/" + k, step, v) # save in train/
self.write(k, step, v)
self.last_log_update_step = step

