From 224b2428027541ddfd6988fa3c874578950bb3c1 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 12 Aug 2020 09:48:43 +0800 Subject: [PATCH 01/19] pickle compatible for buffer --- test/base/test_buffer.py | 39 +++++++++++++++++++++++++++++++++++++-- tianshou/data/buffer.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 4c6bc7176..4b9f36eda 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -2,8 +2,8 @@ import numpy as np from timeit import timeit -from tianshou.data import Batch, PrioritizedReplayBuffer, \ - ReplayBuffer, SegmentTree +from tianshou.data import Batch, SegmentTree, \ + ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer if __name__ == '__main__': from env import MyTestEnv @@ -214,10 +214,45 @@ def sample_tree(): print('tree', timeit(sample_tree, setup=sample_tree, number=1000)) +def test_pickle(): + size = 100 + vbuf = ReplayBuffer(size, stack_num=2) + lbuf = ListReplayBuffer() + pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4) + for i in range(4): + vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=1, done=0) + for i in range(3): + lbuf.add(obs=Batch(index=np.array([i])), act=1, rew=1, done=0) + for i in range(5): + pbuf.add(obs=Batch(index=np.array([i])), + act=2, rew=1, done=0, weight=np.random.rand()) + # save + vbuf.save('/tmp/vbuf.pkl') + lbuf.save('/tmp/lbuf.pkl') + pbuf.save('/tmp/pbuf.pkl') + # normal load + _vbuf = ReplayBuffer.load('/tmp/vbuf.pkl') + _lbuf = ListReplayBuffer.load('/tmp/lbuf.pkl') + _pbuf = PrioritizedReplayBuffer.load('/tmp/pbuf.pkl') + assert len(_vbuf) == len(vbuf) and np.allclose(_vbuf.act, vbuf.act) + assert len(_lbuf) == len(lbuf) and np.allclose(_lbuf.act, lbuf.act) + assert len(_pbuf) == len(pbuf) and np.allclose(_pbuf.act, pbuf.act) + # make sure the meta var is identical + assert _vbuf.stack_num == vbuf.stack_num + assert np.allclose(_pbuf.weight[np.arange(len(_pbuf))], + pbuf.weight[np.arange(len(pbuf))]) + # load data by inconsistent class will raise an error + with pytest.raises(AssertionError): + ReplayBuffer.load('/tmp/lbuf.pkl') + with pytest.raises(AssertionError): + PrioritizedReplayBuffer.load('/tmp/vbuf.pkl') + + if __name__ == '__main__': test_replaybuffer() test_ignore_obs_next() test_stack() + test_pickle() test_segtree() test_priortized_replaybuffer() test_priortized_replaybuffer(233333, 200000) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 1d7a80f3c..e699094bd 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -1,4 +1,5 @@ import torch +import pickle import numpy as np from typing import Any, Tuple, Union, Optional @@ -35,6 +36,7 @@ class ReplayBuffer: >>> # but there are only three valid items, so len(buf) == 3. >>> len(buf) 3 + >>> buf.save('old_buf.pkl') # save to file "old_buf.pkl" >>> buf2 = ReplayBuffer(size=10) >>> for i in range(15): ... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={}) @@ -54,6 +56,11 @@ class ReplayBuffer: >>> batch_data, indice = buf.sample(batch_size=4) >>> batch_data.obs == buf[indice].obs array([ True, True, True, True]) + >>> len(buf) + 13 + >>> buf = ReplayBuffer.load('old_buf.pkl') # load from "old_buf.pkl" + >>> len(buf) + 3 :class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling (typically for RNN usage, see issue#19), ignoring storing the next @@ -141,6 +148,15 @@ def __getattr__(self, key: str) -> Union['Batch', Any]: """Return self.key""" return self._meta.__dict__[key] + def __setstate__(self, state): + """unpickling interface""" + self.__dict__.update(state) + + def __getstate__(self, data={}): + """pickling interface""" + data.update(self.__dict__) + return data + def _add_to_buffer(self, name: str, inst: Any) -> None: try: value = self._meta.__dict__[name] @@ -159,6 +175,18 @@ def _add_to_buffer(self, name: str, inst: Any) -> None: value.__dict__[key] = _create_value(inst[key], self._maxsize) value[self._index] = inst + def save(self, filename): + """Save data to a pickle file.""" + pickle.dump(self, open(filename, 'wb')) + + @classmethod + def load(cls, filename): + """Load data from a pickle file.""" + buf = pickle.load(open(filename, 'rb')) + assert type(buf) == cls, \ + f"Cannot load a {cls} from a {type(buf)}." + return buf + @property def stack_num(self): return self._stack From 7ad05a5c747101b1820a5d45634b4e85161d3d6f Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 12 Aug 2020 09:53:04 +0800 Subject: [PATCH 02/19] fix assert error log --- tianshou/data/buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index e699094bd..0bd3bd963 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -184,7 +184,7 @@ def load(cls, filename): """Load data from a pickle file.""" buf = pickle.load(open(filename, 'rb')) assert type(buf) == cls, \ - f"Cannot load a {cls} from a {type(buf)}." + f"Cannot load a {cls.__name__} from a {buf.__class__.__name__}." return buf @property From e6592b0ed3a3950fc370b6fdda96f49db8ce1af2 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 12 Aug 2020 10:19:23 +0800 Subject: [PATCH 03/19] cuda tensor in test buffer --- test/base/test_buffer.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 4b9f36eda..f2ae2ca5b 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -1,3 +1,4 @@ +import torch import pytest import numpy as np from timeit import timeit @@ -219,13 +220,16 @@ def test_pickle(): vbuf = ReplayBuffer(size, stack_num=2) lbuf = ListReplayBuffer() pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4) + device = 'cuda' if torch.cuda.is_available() else 'cpu' + rew = torch.tensor([1.]).to(device) + print(rew) for i in range(4): - vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=1, done=0) + vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0) for i in range(3): - lbuf.add(obs=Batch(index=np.array([i])), act=1, rew=1, done=0) + lbuf.add(obs=Batch(index=np.array([i])), act=1, rew=rew, done=0) for i in range(5): pbuf.add(obs=Batch(index=np.array([i])), - act=2, rew=1, done=0, weight=np.random.rand()) + act=2, rew=rew, done=0, weight=np.random.rand()) # save vbuf.save('/tmp/vbuf.pkl') lbuf.save('/tmp/lbuf.pkl') From f74b1c8dc59ec23415885f0f84904b330947faca Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 12 Aug 2020 10:32:14 +0800 Subject: [PATCH 04/19] remove buffer.load and buffer.save --- test/base/test_buffer.py | 20 ++++++++------------ tianshou/data/buffer.py | 19 +++---------------- 2 files changed, 11 insertions(+), 28 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index f2ae2ca5b..f64f26ca7 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -1,4 +1,5 @@ import torch +import pickle import pytest import numpy as np from timeit import timeit @@ -231,13 +232,13 @@ def test_pickle(): pbuf.add(obs=Batch(index=np.array([i])), act=2, rew=rew, done=0, weight=np.random.rand()) # save - vbuf.save('/tmp/vbuf.pkl') - lbuf.save('/tmp/lbuf.pkl') - pbuf.save('/tmp/pbuf.pkl') - # normal load - _vbuf = ReplayBuffer.load('/tmp/vbuf.pkl') - _lbuf = ListReplayBuffer.load('/tmp/lbuf.pkl') - _pbuf = PrioritizedReplayBuffer.load('/tmp/pbuf.pkl') + pickle.dump(vbuf, open('/tmp/vbuf.pkl', 'wb')) + pickle.dump(lbuf, open('/tmp/lbuf.pkl', 'wb')) + pickle.dump(pbuf, open('/tmp/pbuf.pkl', 'wb')) + # load + _vbuf = pickle.load(open('/tmp/vbuf.pkl', 'rb')) + _lbuf = pickle.load(open('/tmp/lbuf.pkl', 'rb')) + _pbuf = pickle.load(open('/tmp/pbuf.pkl', 'rb')) assert len(_vbuf) == len(vbuf) and np.allclose(_vbuf.act, vbuf.act) assert len(_lbuf) == len(lbuf) and np.allclose(_lbuf.act, lbuf.act) assert len(_pbuf) == len(pbuf) and np.allclose(_pbuf.act, pbuf.act) @@ -245,11 +246,6 @@ def test_pickle(): assert _vbuf.stack_num == vbuf.stack_num assert np.allclose(_pbuf.weight[np.arange(len(_pbuf))], pbuf.weight[np.arange(len(pbuf))]) - # load data by inconsistent class will raise an error - with pytest.raises(AssertionError): - ReplayBuffer.load('/tmp/lbuf.pkl') - with pytest.raises(AssertionError): - PrioritizedReplayBuffer.load('/tmp/vbuf.pkl') if __name__ == '__main__': diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 0bd3bd963..fe692b2a7 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -1,5 +1,4 @@ import torch -import pickle import numpy as np from typing import Any, Tuple, Union, Optional @@ -24,7 +23,7 @@ class ReplayBuffer: The following code snippet illustrates its usage: :: - >>> import numpy as np + >>> import pickle, numpy as np >>> from tianshou.data import ReplayBuffer >>> buf = ReplayBuffer(size=20) >>> for i in range(3): @@ -36,7 +35,7 @@ class ReplayBuffer: >>> # but there are only three valid items, so len(buf) == 3. >>> len(buf) 3 - >>> buf.save('old_buf.pkl') # save to file "old_buf.pkl" + >>> pickle.dump(buf, open('buf.pkl', 'wb')) # save to file "buf.pkl" >>> buf2 = ReplayBuffer(size=10) >>> for i in range(15): ... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={}) @@ -58,7 +57,7 @@ class ReplayBuffer: array([ True, True, True, True]) >>> len(buf) 13 - >>> buf = ReplayBuffer.load('old_buf.pkl') # load from "old_buf.pkl" + >>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl" >>> len(buf) 3 @@ -175,18 +174,6 @@ def _add_to_buffer(self, name: str, inst: Any) -> None: value.__dict__[key] = _create_value(inst[key], self._maxsize) value[self._index] = inst - def save(self, filename): - """Save data to a pickle file.""" - pickle.dump(self, open(filename, 'wb')) - - @classmethod - def load(cls, filename): - """Load data from a pickle file.""" - buf = pickle.load(open(filename, 'rb')) - assert type(buf) == cls, \ - f"Cannot load a {cls.__name__} from a {buf.__class__.__name__}." - return buf - @property def stack_num(self): return self._stack From d380c539d19d183a9a66300b468b947b77b762c1 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 12 Aug 2020 10:33:59 +0800 Subject: [PATCH 05/19] simplify getstate --- tianshou/data/buffer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index fe692b2a7..02d987a0f 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -151,10 +151,9 @@ def __setstate__(self, state): """unpickling interface""" self.__dict__.update(state) - def __getstate__(self, data={}): + def __getstate__(self): """pickling interface""" - data.update(self.__dict__) - return data + return self.__dict__ def _add_to_buffer(self, name: str, inst: Any) -> None: try: From c2d731d42cea9d6b67ff7196959365b16c9c79d0 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 12 Aug 2020 10:48:30 +0800 Subject: [PATCH 06/19] use loads and dumps --- test/base/test_buffer.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index f64f26ca7..3e03c7f69 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -231,14 +231,10 @@ def test_pickle(): for i in range(5): pbuf.add(obs=Batch(index=np.array([i])), act=2, rew=rew, done=0, weight=np.random.rand()) - # save - pickle.dump(vbuf, open('/tmp/vbuf.pkl', 'wb')) - pickle.dump(lbuf, open('/tmp/lbuf.pkl', 'wb')) - pickle.dump(pbuf, open('/tmp/pbuf.pkl', 'wb')) - # load - _vbuf = pickle.load(open('/tmp/vbuf.pkl', 'rb')) - _lbuf = pickle.load(open('/tmp/lbuf.pkl', 'rb')) - _pbuf = pickle.load(open('/tmp/pbuf.pkl', 'rb')) + # save & load + _vbuf = pickle.loads(pickle.dumps(vbuf)) + _lbuf = pickle.loads(pickle.dumps(lbuf)) + _pbuf = pickle.loads(pickle.dumps(pbuf)) assert len(_vbuf) == len(vbuf) and np.allclose(_vbuf.act, vbuf.act) assert len(_lbuf) == len(lbuf) and np.allclose(_lbuf.act, lbuf.act) assert len(_pbuf) == len(pbuf) and np.allclose(_pbuf.act, pbuf.act) From e93d0810adf25f6f3944e91e4fdfa85c5c78725f Mon Sep 17 00:00:00 2001 From: n+e Date: Wed, 12 Aug 2020 13:06:38 +0800 Subject: [PATCH 07/19] Update test/base/test_buffer.py Co-authored-by: youkaichao --- test/base/test_buffer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 3e03c7f69..8a76b0c9a 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -223,7 +223,6 @@ def test_pickle(): pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4) device = 'cuda' if torch.cuda.is_available() else 'cpu' rew = torch.tensor([1.]).to(device) - print(rew) for i in range(4): vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0) for i in range(3): From 77bb6aa6d15f521b640b8873452628a645eed67e Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 12 Aug 2020 21:36:41 +0800 Subject: [PATCH 08/19] simplify code --- tianshou/data/buffer.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 02d987a0f..31c138f61 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -145,16 +145,10 @@ def __repr__(self) -> str: def __getattr__(self, key: str) -> Union['Batch', Any]: """Return self.key""" + if key.startswith('__'): # since we do not use key begin with "__" + raise AttributeError return self._meta.__dict__[key] - def __setstate__(self, state): - """unpickling interface""" - self.__dict__.update(state) - - def __getstate__(self): - """pickling interface""" - return self.__dict__ - def _add_to_buffer(self, name: str, inst: Any) -> None: try: value = self._meta.__dict__[name] @@ -394,7 +388,7 @@ def __getattr__(self, key: str) -> Union['Batch', Any]: """Return self.key""" if key == 'weight': return self._weight - return self._meta.__dict__[key] + return super().__getattr__(key) def add(self, obs: Union[dict, np.ndarray], From 821399f8b2ad6fef099ba43a98a056eefd0e9b43 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Wed, 12 Aug 2020 21:42:52 +0800 Subject: [PATCH 09/19] remove a __dict__ --- tianshou/data/buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 31c138f61..705df4246 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -147,7 +147,7 @@ def __getattr__(self, key: str) -> Union['Batch', Any]: """Return self.key""" if key.startswith('__'): # since we do not use key begin with "__" raise AttributeError - return self._meta.__dict__[key] + return self._meta[key] def _add_to_buffer(self, name: str, inst: Any) -> None: try: From a5387cab09ac4a63f2e107a078cfc45f4a0e14f5 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 13 Aug 2020 07:57:08 +0800 Subject: [PATCH 10/19] exclude __xxx__ --- tianshou/data/buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 705df4246..100303ba4 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -145,8 +145,8 @@ def __repr__(self) -> str: def __getattr__(self, key: str) -> Union['Batch', Any]: """Return self.key""" - if key.startswith('__'): # since we do not use key begin with "__" - raise AttributeError + if key.startswith('__') and key.endswith('__'): + raise AttributeError # since we do not use undefined key "__xxx__" return self._meta[key] def _add_to_buffer(self, name: str, inst: Any) -> None: From 0c7dc6cf24951d80490b4015e99c0524bbca7f85 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 14 Aug 2020 08:33:37 +0800 Subject: [PATCH 11/19] resolve RecursionError --- tianshou/data/buffer.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 100303ba4..b0320f92e 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -145,9 +145,13 @@ def __repr__(self) -> str: def __getattr__(self, key: str) -> Union['Batch', Any]: """Return self.key""" - if key.startswith('__') and key.endswith('__'): - raise AttributeError # since we do not use undefined key "__xxx__" - return self._meta[key] + if '_meta' not in self.__dict__: + # pickle.load will not init self._meta at first place + raise AttributeError + try: + return self._meta[key] + except KeyError: + raise AttributeError def _add_to_buffer(self, name: str, inst: Any) -> None: try: From bd5a4be5ee72dbdbf20db46cc16cfe80fa673b1a Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 14 Aug 2020 13:29:03 +0800 Subject: [PATCH 12/19] setstate --- tianshou/data/buffer.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index b0320f92e..364bcf03a 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -143,16 +143,16 @@ def __repr__(self) -> str: """Return str(self).""" return self.__class__.__name__ + self._meta.__repr__()[5:] - def __getattr__(self, key: str) -> Union['Batch', Any]: + def __getattr__(self, key: str) -> Any: """Return self.key""" - if '_meta' not in self.__dict__: - # pickle.load will not init self._meta at first place - raise AttributeError try: return self._meta[key] except KeyError: raise AttributeError + def __setstate__(self, state): + self.__dict__.update(state) + def _add_to_buffer(self, name: str, inst: Any) -> None: try: value = self._meta.__dict__[name] @@ -284,11 +284,12 @@ def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str, 1 if indice.step is None else indice.step) else: indice = np.array(indice, copy=True) + done = self._meta.done # set last frame done to True last_index = (self._index - 1 + self._size) % self._size - last_done, self.done[last_index] = self.done[last_index], True + last_done, done[last_index] = done[last_index], True if key == 'obs_next' and (not self._save_s_ or self.obs_next is None): - indice += 1 - self.done[indice].astype(np.int) + indice += 1 - done[indice].astype(np.int) indice[indice == self._size] = 0 key = 'obs' val = self._meta.__dict__[key] @@ -300,7 +301,7 @@ def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str, pre_indice = np.asarray(indice - 1) pre_indice[pre_indice == -1] = self._size - 1 indice = np.asarray( - pre_indice + self.done[pre_indice].astype(np.int)) + pre_indice + done[pre_indice].astype(np.int)) indice[indice == self._size] = 0 if isinstance(val, Batch): stack = Batch.stack(stack, axis=indice.ndim) @@ -312,7 +313,7 @@ def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str, stack = Batch() if not isinstance(val, Batch) or len(val.__dict__) > 0: raise e - self.done[last_index] = last_done + done[last_index] = last_done return stack def __getitem__(self, index: Union[ From 9325a49294e454ce52b29106c28208e0374c3a40 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Fri, 14 Aug 2020 13:33:19 +0800 Subject: [PATCH 13/19] doc --- tianshou/data/buffer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 364bcf03a..c3bbfdc70 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -151,6 +151,9 @@ def __getattr__(self, key: str) -> Any: raise AttributeError def __setstate__(self, state): + """Unpickling interface. We need it because pickling buffer does not + work out-of-the-box (buffer.__getattr__ is customized). + """ self.__dict__.update(state) def _add_to_buffer(self, name: str, inst: Any) -> None: From a302db841a4ab0f8a0cead9ad8dbe4ac3378db58 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 14 Aug 2020 14:01:12 +0800 Subject: [PATCH 14/19] Update doc --- tianshou/data/buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index c3bbfdc70..f2cae507d 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -152,7 +152,7 @@ def __getattr__(self, key: str) -> Any: def __setstate__(self, state): """Unpickling interface. We need it because pickling buffer does not - work out-of-the-box (buffer.__getattr__ is customized). + work out-of-the-box (``buffer.__getattr__`` is customized). """ self.__dict__.update(state) From 5f43d5aa1feeedd30a55660ce94f0eec2b97e9ae Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 15 Aug 2020 16:06:49 +0800 Subject: [PATCH 15/19] return the most often case first --- tianshou/data/buffer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index f2cae507d..7b15abc16 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -276,6 +276,12 @@ def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str, """ if stack_num is None: stack_num = self.stack_num + # the most often case: stack_num == 1 + if stack_num == 1 and not (key == 'obs_next' and not self._save_s_): + try: + return self._meta[key][indice] + except IndexError: + pass if isinstance(indice, slice): indice = np.arange( 0 if indice.start is None @@ -291,7 +297,7 @@ def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str, # set last frame done to True last_index = (self._index - 1 + self._size) % self._size last_done, done[last_index] = done[last_index], True - if key == 'obs_next' and (not self._save_s_ or self.obs_next is None): + if key == 'obs_next' and not self._save_s_: indice += 1 - done[indice].astype(np.int) indice[indice == self._size] = 0 key = 'obs' From 8f16864ac9b696f5138de2ae3d7f93736c1759cb Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 15 Aug 2020 16:16:58 +0800 Subject: [PATCH 16/19] polish condition --- tianshou/data/buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 7b15abc16..e577f5232 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -277,7 +277,7 @@ def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str, if stack_num is None: stack_num = self.stack_num # the most often case: stack_num == 1 - if stack_num == 1 and not (key == 'obs_next' and not self._save_s_): + if stack_num == 1 and (key != 'obs_next' or self._save_s_): try: return self._meta[key][indice] except IndexError: From 9ca276452fcd828f8f2ba459a5bee4deda9bff4e Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 15 Aug 2020 19:12:02 +0800 Subject: [PATCH 17/19] improve buffer.get --- test/base/test_buffer.py | 32 ++++++++++++----- tianshou/data/buffer.py | 74 +++++++++++++++++----------------------- 2 files changed, 55 insertions(+), 51 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 8a76b0c9a..393534c03 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -41,10 +41,11 @@ def test_replaybuffer(size=10, bufsize=20): def test_ignore_obs_next(size=10): # Issue 82 - buf = ReplayBuffer(size, ignore_obs_net=True) + buf = ReplayBuffer(size, ignore_obs_next=True) for i in range(size): buf.add(obs={'mask1': np.array([i, 1, 1, 0, 0]), - 'mask2': np.array([i + 4, 0, 1, 0, 0])}, + 'mask2': np.array([i + 4, 0, 1, 0, 0]), + 'mask': i}, act={'act_id': i, 'position_id': i + 3}, rew=i, @@ -57,6 +58,22 @@ def test_ignore_obs_next(size=10): assert isinstance(data, Batch) assert isinstance(data2, Batch) assert np.allclose(indice, orig) + assert np.allclose(data.obs_next.mask, data2.obs_next.mask) + assert np.allclose(data.obs_next.mask, [0, 2, 3, 3, 5, 6, 6, 8, 9, 9]) + buf.stack_num = 4 + data = buf[indice] + data2 = buf[indice] + assert np.allclose(data.obs_next.mask, data2.obs_next.mask) + assert np.allclose(data.obs_next.mask, np.array([ + [0, 0, 0, 0], [1, 1, 1, 2], [1, 1, 2, 3], [1, 1, 2, 3], + [4, 4, 4, 5], [4, 4, 5, 6], [4, 4, 5, 6], + [7, 7, 7, 8], [7, 7, 8, 9], [7, 7, 8, 9]])) + assert np.allclose(data.info['if'], data2.info['if']) + assert np.allclose(data.info['if'], np.array([ + [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], + [4, 4, 4, 4], [4, 4, 4, 5], [4, 4, 5, 6], + [7, 7, 7, 7], [7, 7, 7, 8], [7, 7, 8, 9]])) + assert data.obs_next def test_stack(size=5, bufsize=9, stack_num=4): @@ -64,7 +81,7 @@ def test_stack(size=5, bufsize=9, stack_num=4): buf = ReplayBuffer(bufsize, stack_num=stack_num) buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True) obs = env.reset(1) - for i in range(15): + for i in range(16): obs_next, rew, done, info = env.step(1) buf.add(obs, 1, rew, done, None, info) buf2.add(obs, 1, rew, done, None, info) @@ -75,12 +92,11 @@ def test_stack(size=5, bufsize=9, stack_num=4): assert np.allclose(buf.get(indice, 'obs'), np.expand_dims( [[1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], - [3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]], axis=-1)) - print(buf) + [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]], axis=-1)) _, indice = buf2.sample(0) - assert indice == [2] + assert indice.tolist() == [2, 6] _, indice = buf2.sample(1) - assert indice.sum() == 2 + assert indice in [2, 6] def test_priortized_replaybuffer(size=32, bufsize=15): @@ -109,7 +125,7 @@ def test_update(): buf2 = ReplayBuffer(4, stack_num=2) for i in range(5): buf1.add(obs=np.array([i]), act=float(i), rew=i * i, - done=False, info={'incident': 'found'}) + done=i % 2 == 0, info={'incident': 'found'}) assert len(buf1) > len(buf2) buf2.update(buf1) assert len(buf1) == len(buf2) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index e577f5232..08f472df0 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -125,6 +125,7 @@ def __init__(self, size: int, stack_num: int = 1, sample_avail: bool = False, **kwargs) -> None: super().__init__() self._maxsize = size + self._indices = np.arange(size) self._stack = None self.stack_num = stack_num self._avail = sample_avail and stack_num > 1 @@ -164,9 +165,8 @@ def _add_to_buffer(self, name: str, inst: Any) -> None: value = self._meta.__dict__[name] if isinstance(inst, np.ndarray) and value.shape[1:] != inst.shape: raise ValueError( - "Cannot add data to a buffer with different shape, key: " - f"{name}, expect shape: {value.shape[1:]}, " - f"given shape: {inst.shape}.") + "Cannot add data to a buffer with different shape, with key " + f"{name}, expect {value.shape[1:]}, given {inst.shape}.") try: value[self._index] = inst except KeyError: @@ -276,54 +276,42 @@ def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str, """ if stack_num is None: stack_num = self.stack_num - # the most often case: stack_num == 1 - if stack_num == 1 and (key != 'obs_next' or self._save_s_): - try: - return self._meta[key][indice] - except IndexError: - pass - if isinstance(indice, slice): - indice = np.arange( - 0 if indice.start is None - else self._size - indice.start if indice.start < 0 - else indice.start, - self._size if indice.stop is None - else self._size - indice.stop if indice.stop < 0 - else indice.stop, - 1 if indice.step is None else indice.step) - else: - indice = np.array(indice, copy=True) - done = self._meta.done - # set last frame done to True - last_index = (self._index - 1 + self._size) % self._size - last_done, done[last_index] = done[last_index], True + if stack_num == 1: # the most often case + if key != 'obs_next' or self._save_s_: + val = self._meta.__dict__[key] + try: + return val[indice] + except IndexError as e: + if not (isinstance(val, Batch) and val.is_empty()): + raise e # val != Batch() + return Batch() + indice = self._indices[:self._size][indice] + done = self._meta.__dict__['done'] if key == 'obs_next' and not self._save_s_: indice += 1 - done[indice].astype(np.int) indice[indice == self._size] = 0 key = 'obs' val = self._meta.__dict__[key] try: - if stack_num > 1: - stack = [] - for _ in range(stack_num): - stack = [val[indice]] + stack - pre_indice = np.asarray(indice - 1) - pre_indice[pre_indice == -1] = self._size - 1 - indice = np.asarray( - pre_indice + done[pre_indice].astype(np.int)) - indice[indice == self._size] = 0 - if isinstance(val, Batch): - stack = Batch.stack(stack, axis=indice.ndim) - else: - stack = np.stack(stack, axis=indice.ndim) + if stack_num == 1: + return val[indice] + stack = [] + for _ in range(stack_num): + stack = [val[indice]] + stack + pre_indice = np.asarray(indice - 1) + pre_indice[pre_indice == -1] = self._size - 1 + indice = np.asarray( + pre_indice + done[pre_indice].astype(np.int)) + indice[indice == self._size] = 0 + if isinstance(val, Batch): + stack = Batch.stack(stack, axis=indice.ndim) else: - stack = val[indice] + stack = np.stack(stack, axis=indice.ndim) + return stack except IndexError as e: - stack = Batch() - if not isinstance(val, Batch) or len(val.__dict__) > 0: - raise e - done[last_index] = last_done - return stack + if not (isinstance(val, Batch) and val.is_empty()): + raise e # val != Batch() + return Batch() def __getitem__(self, index: Union[ slice, int, np.integer, np.ndarray]) -> Batch: From b289bafbeeb43335787f609289d504b9bf5296c2 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sun, 16 Aug 2020 06:29:52 +0800 Subject: [PATCH 18/19] add a check --- tianshou/policy/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index a2f545e48..01398ca3c 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -231,7 +231,8 @@ def post_process_fn(self, batch: Batch, usage is to update the sampling weight in prioritized experience replay. Check out :ref:`policy_concept` for more information. """ - if isinstance(buffer, PrioritizedReplayBuffer): + if isinstance(buffer, PrioritizedReplayBuffer) \ + and hasattr(batch, 'weight'): buffer.update_weight(indice, batch.weight) def update(self, batch_size: int, buffer: ReplayBuffer, *args, **kwargs): From 76d7f5a16bf255d203e29cc310e60ec44391df07 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sun, 16 Aug 2020 16:08:43 +0800 Subject: [PATCH 19/19] raise from --- tianshou/data/buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 08f472df0..24aeb0bcc 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -148,8 +148,8 @@ def __getattr__(self, key: str) -> Any: """Return self.key""" try: return self._meta[key] - except KeyError: - raise AttributeError + except KeyError as e: + raise AttributeError from e def __setstate__(self, state): """Unpickling interface. We need it because pickling buffer does not