From fab18b0445171551e92f0060abeb0cf11110377d Mon Sep 17 00:00:00 2001 From: sunmingzhi <531483935@qq.com> Date: Wed, 1 Mar 2023 09:56:32 +0800 Subject: [PATCH 1/3] fix bug #811 --- test/base/test_buffer.py | 32 ++++++++++++++++++++++++++++++++ tianshou/data/buffer/her.py | 7 ++++--- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index cf011b7f1..c4d3c6bd7 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -437,6 +437,38 @@ def compute_reward_fn(ag, g): assert np.all(buf[10:].obs.desired_goal == buf[0].obs.desired_goal) # (same ep) assert np.all(buf[0].obs.desired_goal != buf[5].obs.desired_goal) # (diff ep) + # Another test case for cycled indices + env_size = 99 + bufsize = 15 + env = MyGoalEnv(env_size, array_state=False) + buf = HERReplayBuffer( + bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8 + ) + buf.future_p = 1 + for x, ep_len in enumerate([10, 20]): + obs, _ = env.reset() + for i in range(ep_len): + act = 1 + obs_next, rew, terminated, truncated, info = env.step(act) + batch = Batch( + obs=obs, + act=[act], + rew=rew, + terminated=(i == ep_len - 1), + truncated=(i == ep_len - 1), + obs_next=obs_next, + info=info + ) + if x == 1 and obs["observation"] < 10: + obs = obs_next + continue + buf.add(batch) + obs = obs_next + buf._restore_cache() + sample_indices = np.array([10]) # Suppose the sampled indices is [10] + buf.rewrite_transitions(sample_indices) + assert int(buf.obs.desired_goal[10][0]) in [11, 12, 13, 14, 15, 16, 17, 18, 19, 20] + def test_update(): buf1 = ReplayBuffer(4, stack_num=2) diff --git a/tianshou/data/buffer/her.py b/tianshou/data/buffer/her.py index 8c5c37166..fc182436e 100644 --- a/tianshou/data/buffer/her.py +++ b/tianshou/data/buffer/her.py @@ -120,9 +120,10 @@ def rewrite_transitions(self, indices: np.ndarray) -> None: # Calculate future timestep to use current = indices[0] terminal = indices[-1] - future_offset = np.random.uniform(size=len(indices[0])) * (terminal - current) - future_offset = future_offset.astype(int) - future_t = (current + future_offset) + episodes_len = (terminal - current + self.maxsize) % self.maxsize + future_offset = np.random.uniform(size=len(indices[0])) * episodes_len + future_offset = np.round(future_offset).astype(int) + future_t = (current + future_offset) % self.maxsize # Compute indices # open indices are used to find longest, unique trajectories among From 1e29223f84406b44eb524f6813ff55fd70571818 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Fri, 3 Mar 2023 15:42:18 -0800 Subject: [PATCH 2/3] fix --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 296035903..0e11d5441 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,7 +8,7 @@ exclude = dist *.egg-info max-line-length = 87 -ignore = B305,W504,B006,B008,B024,W503 +ignore = B305,W504,B006,B008,B024,W503,B028 [yapf] based_on_style = pep8 From 746e2c5fd881be88adfd9fa2eeebf0622cbe1ba8 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Fri, 3 Mar 2023 16:02:00 -0800 Subject: [PATCH 3/3] polish --- setup.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 5912cd172..253be206b 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,6 @@ def get_install_requires() -> str: "torch>=1.4.0", "numba>=0.51.0", "h5py>=2.10.0", # to match tensorflow's minimal requirements - "protobuf~=3.19.0", # breaking change, sphinx fail "packaging", ] @@ -30,9 +29,9 @@ def get_install_requires() -> str: def get_extras_require() -> str: req = { "dev": [ - "sphinx<4", + "sphinx", "sphinx_rtd_theme", - "jinja2<3.1", # temporary fix + "jinja2", "sphinxcontrib-bibtex", "flake8", "flake8-bugbear",