From c8352610e2719831687a6081d1e9d510b291cdc4 Mon Sep 17 00:00:00 2001 From: ChenHuayu <308604256@qq.com> Date: Mon, 20 Jul 2020 20:20:42 +0800 Subject: [PATCH 1/3] buffer update bug fix --- test/base/test_buffer.py | 18 +++++++++++------- tianshou/data/buffer.py | 7 ++++++- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 28ccd88c2..505afc2af 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -10,7 +10,6 @@ def test_replaybuffer(size=10, bufsize=20): env = MyTestEnv(size) buf = ReplayBuffer(bufsize) - buf2 = ReplayBuffer(bufsize) obs = env.reset() action_list = [1] * 5 + [0] * 10 + [1] * 10 for i, a in enumerate(action_list): @@ -22,11 +21,6 @@ def test_replaybuffer(size=10, bufsize=20): assert (indice < len(buf)).all() assert (data.obs < size).all() assert (0 <= data.done).all() and (data.done <= 1).all() - assert len(buf) > len(buf2) - buf2.update(buf) - assert len(buf) == len(buf2) - assert buf2[0].obs == buf[5].obs - assert buf2[-1].obs == buf[4].obs b = ReplayBuffer(size=10) b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}}) assert b.obs[0] == 1 @@ -82,7 +76,6 @@ def test_stack(size=5, bufsize=9, stack_num=4): _, indice = buf2.sample(1) assert indice.sum() == 2 - def test_priortized_replaybuffer(size=32, bufsize=15): env = MyTestEnv(size) buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5) @@ -103,9 +96,20 @@ def test_priortized_replaybuffer(size=32, bufsize=15): assert np.allclose( buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha) +def test_update(): + buf1 = ReplayBuffer(4, stack_num=2) + buf2 = ReplayBuffer(4, stack_num=2) + for i in range(5): + buf1.add(obs=np.array([i]), act=float(i), rew=i*i, done= False, info= {'incident':'found'}) + assert len(buf1) > len(buf2) + buf2.update(buf1) + assert len(buf1) == len(buf2) + assert (buf2[0].obs == buf1[1].obs).all() + assert (buf2[-1].obs == buf1[0].obs).all() if __name__ == '__main__': test_replaybuffer() test_ignore_obs_next() test_stack() test_priortized_replaybuffer(233333, 200000) + test_update() diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index f593d2a74..f36910e4d 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -153,7 +153,11 @@ def update(self, buffer: 'ReplayBuffer') -> None: return i = begin = buffer._index % len(buffer) while True: - self.add(**buffer[i]) + self.add( + buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i], + (buffer.obs_next[i] if not buffer.obs_next.is_empty() else None) if self._save_s_ else None, + buffer.info[i] if not buffer.info.is_empty() else {}, + buffer.policy[i] if not buffer.policy.is_empty() else {}) i = (i + 1) % len(buffer) if i == begin: break @@ -440,3 +444,4 @@ def __getitem__(self, index: Union[ weight=self.weight[index], policy=self.get(index, 'policy'), ) + From 0403780dc7f9603f115473099e025e703dc151bb Mon Sep 17 00:00:00 2001 From: ChenHuayu <308604256@qq.com> Date: Mon, 20 Jul 2020 21:00:01 +0800 Subject: [PATCH 2/3] some fix in buffer update --- test/base/test_buffer.py | 13 +++++++++---- tianshou/data/buffer.py | 15 +++++++++++---- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 505afc2af..4234861d4 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -1,5 +1,6 @@ import numpy as np -from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer + +from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer if __name__ == '__main__': from env import MyTestEnv @@ -76,6 +77,7 @@ def test_stack(size=5, bufsize=9, stack_num=4): _, indice = buf2.sample(1) assert indice.sum() == 2 + def test_priortized_replaybuffer(size=32, bufsize=15): env = MyTestEnv(size) buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5) @@ -96,17 +98,20 @@ def test_priortized_replaybuffer(size=32, bufsize=15): assert np.allclose( buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha) + def test_update(): buf1 = ReplayBuffer(4, stack_num=2) buf2 = ReplayBuffer(4, stack_num=2) for i in range(5): - buf1.add(obs=np.array([i]), act=float(i), rew=i*i, done= False, info= {'incident':'found'}) - assert len(buf1) > len(buf2) + buf1.add(obs=np.array([i]), act=float(i), rew=i * + i, done=False, info={'incident': 'found'}) + assert len(buf1) > len(buf2) buf2.update(buf1) - assert len(buf1) == len(buf2) + assert len(buf1) == len(buf2) assert (buf2[0].obs == buf1[1].obs).all() assert (buf2[-1].obs == buf1[0].obs).all() + if __name__ == '__main__': test_replaybuffer() test_ignore_obs_next() diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index f36910e4d..4ab28c4b0 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -153,11 +153,19 @@ def update(self, buffer: 'ReplayBuffer') -> None: return i = begin = buffer._index % len(buffer) while True: + obs_next = None if isinstance( + buffer.obs_next, Batch) and buffer.obs_next.is_empty() else \ + buffer.obs_next[i] + info = {} if isinstance( + buffer.obs_next, Batch) and buffer.info.is_empty() else \ + buffer.info[i] + policy = {} if isinstance( + buffer.obs_next, Batch) and buffer.policy.is_empty() else \ + buffer.policy[i] self.add( buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i], - (buffer.obs_next[i] if not buffer.obs_next.is_empty() else None) if self._save_s_ else None, - buffer.info[i] if not buffer.info.is_empty() else {}, - buffer.policy[i] if not buffer.policy.is_empty() else {}) + obs_next if self._save_s_ else None, + info, policy) i = (i + 1) % len(buffer) if i == begin: break @@ -444,4 +452,3 @@ def __getitem__(self, index: Union[ weight=self.weight[index], policy=self.get(index, 'policy'), ) - From 1e3044ba38c3a90caf8ab7f2fd0045ebfdcca4f6 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Mon, 20 Jul 2020 22:00:09 +0800 Subject: [PATCH 3/3] polish --- test/base/test_buffer.py | 4 ++-- tianshou/data/buffer.py | 25 +++++++++++-------------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 4234861d4..6178a3299 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -103,8 +103,8 @@ def test_update(): buf1 = ReplayBuffer(4, stack_num=2) buf2 = ReplayBuffer(4, stack_num=2) for i in range(5): - buf1.add(obs=np.array([i]), act=float(i), rew=i * - i, done=False, info={'incident': 'found'}) + buf1.add(obs=np.array([i]), act=float(i), rew=i * i, + done=False, info={'incident': 'found'}) assert len(buf1) > len(buf2) buf2.update(buf1) assert len(buf1) == len(buf2) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 80793ccdb..b7ddcff38 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -157,28 +157,25 @@ def _add_to_buffer(self, name: str, inst: Any) -> None: value.__dict__[key] = _create_value(inst[key], self._maxsize) value[self._index] = inst + def _get_stack_num(self): + return self._stack + + def _set_stack_num(self, num): + self._stack = num + def update(self, buffer: 'ReplayBuffer') -> None: """Move the data from the given buffer to self.""" if len(buffer) == 0: return i = begin = buffer._index % len(buffer) + origin = buffer._get_stack_num() + buffer._set_stack_num(0) while True: - obs_next = None if isinstance( - buffer.obs_next, Batch) and buffer.obs_next.is_empty() else \ - buffer.obs_next[i] - info = {} if isinstance( - buffer.obs_next, Batch) and buffer.info.is_empty() else \ - buffer.info[i] - policy = {} if isinstance( - buffer.obs_next, Batch) and buffer.policy.is_empty() else \ - buffer.policy[i] - self.add( - buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i], - obs_next if self._save_s_ else None, - info, policy) + self.add(**buffer[i]) i = (i + 1) % len(buffer) if i == begin: break + buffer._set_stack_num(origin) def add(self, obs: Union[dict, Batch, np.ndarray], @@ -420,7 +417,7 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: replace=self._replace) p = p[indice] # weight of each sample elif batch_size == 0: - p = np.full(shape=self._size, fill_value=1.0/self._size) + p = np.full(shape=self._size, fill_value=1.0 / self._size) indice = np.concatenate([ np.arange(self._index, self._size), np.arange(0, self._index),