From 2564e989fb0528ab1120312ba8834e704da5482f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 Jul 2020 09:44:47 +0800 Subject: [PATCH 1/4] Improve Batch (#126) * make sure the key type of Batch is string, and add unit tests * add is_empty() function and unit tests * enable cat of mixing dict and Batch, just like stack --- test/base/test_batch.py | 4 ++++ tianshou/data/batch.py | 20 +++++++++++++++++--- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index e2390dec7..a9f2cdd20 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -9,6 +9,10 @@ def test_batch(): assert list(Batch()) == [] + assert Batch().is_empty() + assert not Batch(a=[1, 2, 3]).is_empty() + with pytest.raises(AssertionError): + Batch({1: 2}) batch = Batch(a=[torch.ones(3), torch.ones(3)]) assert torch.allclose(batch.a, torch.ones(2, 3)) batch = Batch(obs=[0], np=np.zeros([3, 4])) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index e14b98f50..6b1051788 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -50,6 +50,12 @@ def _create_value(inst: Any, size: int) -> Union[ return np.array([None for _ in range(size)]) +def _assert_type_keys(keys): + keys = list(keys) + assert all(isinstance(e, str) for e in keys), \ + f"keys should all be string, but got {keys}" + + class Batch: """Tianshou provides :class:`~tianshou.data.Batch` as the internal data structure to pass any kind of data to other methods, for example, a @@ -247,6 +253,7 @@ def __init__(self, batch_dict = deepcopy(batch_dict) if batch_dict is not None: if isinstance(batch_dict, (dict, Batch)): + _assert_type_keys(batch_dict.keys()) for k, v in batch_dict.items(): if isinstance(v, (list, tuple, np.ndarray)): v_ = None @@ -511,12 +518,14 @@ def cat_(self, batch: 'Batch') -> None: raise TypeError(s) @staticmethod - def cat(batches: List['Batch']) -> 'Batch': - """Concatenate a :class:`~tianshou.data.Batch` object into a single + def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch': + """Concatenate a list of :class:`~tianshou.data.Batch` object into a single new batch. """ batch = Batch() for batch_ in batches: + if isinstance(batch_, dict): + batch_ = Batch(batch_) batch.cat_(batch_) return batch @@ -531,6 +540,7 @@ def stack_(self, keys_shared = set.intersection(*keys_map) values_shared = [ [e[k] for e in batches] for k in keys_shared] + _assert_type_keys(keys_shared) for k, v in zip(keys_shared, values_shared): if all(isinstance(e, (dict, Batch)) for e in v): self.__dict__[k] = Batch.stack(v, axis) @@ -542,6 +552,7 @@ def stack_(self, v = v.astype(np.object) self.__dict__[k] = v keys_partial = reduce(set.symmetric_difference, keys_map) + _assert_type_keys(keys_partial) for k in keys_partial: for i, e in enumerate(batches): val = e.get(k, None) @@ -554,7 +565,7 @@ def stack_(self, self.__dict__[k][i] = val @staticmethod - def stack(batches: List['Batch'], axis: int = 0) -> 'Batch': + def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': """Stack a :class:`~tianshou.data.Batch` object into a single new batch. """ @@ -615,6 +626,9 @@ def __len__(self) -> int: raise TypeError("Object of type 'Batch' has no len()") return min(r) + def is_empty(self): + return len(self.__dict__.keys()) == 0 + @property def shape(self) -> List[int]: """Return self.shape.""" From affeec13de764c46e09588b3242cd6a9cbde87ea Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 Jul 2020 21:46:01 +0800 Subject: [PATCH 2/4] Improve Batch (#128) * minor polish * improve and implement Batch.cat_ * bugfix for buffer.sample with field impt_weight * restore the usage of a.cat_(b) * fix 2 bugs in batch and add corresponding unittest * code fix for update * update is_empty to recognize empty over empty; bugfix for len * bugfix for update and add testcase * add testcase of update * fix docs * fix docs * fix docs [ci skip] * fix docs [ci skip] Co-authored-by: Trinkle23897 <463003665@qq.com> --- .github/workflows/pytest.yml | 1 + test/base/test_batch.py | 43 ++++++++- tianshou/data/batch.py | 167 +++++++++++++++++++++++------------ tianshou/data/buffer.py | 21 ++--- 4 files changed, 162 insertions(+), 70 deletions(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index eb14a875e..c1e3604ed 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -5,6 +5,7 @@ on: [push, pull_request] jobs: build: runs-on: ubuntu-latest + if: "!contains(github.event.head_commit.message, 'ci skip')" strategy: matrix: python-version: [3.6, 3.7, 3.8] diff --git a/test/base/test_batch.py b/test/base/test_batch.py index a9f2cdd20..03031ff2e 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -10,7 +10,17 @@ def test_batch(): assert list(Batch()) == [] assert Batch().is_empty() + assert Batch(b={'c': {}}).is_empty() + assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3 assert not Batch(a=[1, 2, 3]).is_empty() + b = Batch() + b.update() + assert b.is_empty() + b.update(c=[3, 5]) + assert np.allclose(b.c, [3, 5]) + # mimic the behavior of dict.update, where kwargs can overwrite keys + b.update({'a': 2}, a=3) + assert b.a == 3 with pytest.raises(AssertionError): Batch({1: 2}) batch = Batch(a=[torch.ones(3), torch.ones(3)]) @@ -86,6 +96,18 @@ def test_batch(): assert batch3.a.d.f[0] == 5.0 with pytest.raises(KeyError): batch3.a.d[0] = Batch(f=5.0, g=0.0) + # auto convert + batch4 = Batch(a=np.array(['a', 'b'])) + assert batch4.a.dtype == np.object # auto convert to np.object + batch4.update(a=np.array(['c', 'd'])) + assert list(batch4.a) == ['c', 'd'] + assert batch4.a.dtype == np.object # auto convert to np.object + batch5 = Batch(a=np.array([{'index': 0}])) + assert isinstance(batch5.a, Batch) + assert np.allclose(batch5.a.index, [0]) + batch5.b = np.array([{'index': 1}]) + assert isinstance(batch5.b, Batch) + assert np.allclose(batch5.b.index, [1]) def test_batch_over_batch(): @@ -100,6 +122,11 @@ def test_batch_over_batch(): assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8]) assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5]) assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0]) + batch2.update(batch2.b, six=[6, 6, 6]) + assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8]) + assert np.allclose(batch2.a, [3, 4, 5, 3, 4, 5]) + assert np.allclose(batch2.b, [4, 5, 0, 4, 5, 0]) + assert np.allclose(batch2.six, [6, 6, 6]) d = {'a': [3, 4, 5], 'b': [4, 5, 6]} batch3 = Batch(c=[6, 7, 8], b=d) batch3.cat_(Batch(c=[6, 7, 8], b=d)) @@ -124,18 +151,32 @@ def test_batch_over_batch(): def test_batch_cat_and_stack(): + # test cat with compatible keys b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}]) b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}]) - b12_cat_out = Batch.cat((b1, b2)) + b12_cat_out = Batch.cat([b1, b2]) b12_cat_in = copy.deepcopy(b1) b12_cat_in.cat_(b2) assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e) assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e) assert isinstance(b12_cat_in.a.d.e, np.ndarray) assert b12_cat_in.a.d.e.ndim == 1 + b12_stack = Batch.stack((b1, b2)) assert isinstance(b12_stack.a.d.e, np.ndarray) assert b12_stack.a.d.e.ndim == 2 + + # test batch with incompatible keys + b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5))) + b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) + test = Batch.cat([b1, b2]) + ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]), + b=torch.cat([torch.zeros(3, 3), b2.b]), + common=Batch(c=np.concatenate([b1.common.c, b2.common.c]))) + assert np.allclose(test.a, ans.a) + assert torch.allclose(test.b, ans.b) + assert np.allclose(test.common.c, ans.common.c) + b3 = Batch(a=np.zeros((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[1], [2]])) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 6b1051788..1240cfe50 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -259,8 +259,7 @@ def __init__(self, v_ = None if not isinstance(v, np.ndarray) and \ all(isinstance(e, torch.Tensor) for e in v): - v_ = torch.stack(v) - self.__dict__[k] = v_ + self.__dict__[k] = torch.stack(v) continue else: v_ = np.asanyarray(v) @@ -294,7 +293,8 @@ def __setattr__(self, key: str, value: Any): value = np.array(value) if not issubclass(value.dtype.type, (np.bool_, np.number)): value = value.astype(np.object) - elif isinstance(value, dict): + elif isinstance(value, dict) or isinstance(value, np.ndarray) \ + and value.dtype == np.object and _is_batch_set(value): value = Batch(value) self.__dict__[key] = value @@ -333,9 +333,8 @@ def __getitem__(self, index: Union[ else: raise IndexError("Cannot access item from empty Batch object.") - def __setitem__( - self, - index: Union[str, slice, int, np.integer, np.ndarray, List[int]], + def __setitem__(self, index: Union[ + str, slice, int, np.integer, np.ndarray, List[int]], value: Any) -> None: """Assign value to self[index].""" if isinstance(value, np.ndarray): @@ -454,10 +453,8 @@ def to_numpy(self) -> None: elif isinstance(v, Batch): v.to_numpy() - def to_torch(self, - dtype: Optional[torch.dtype] = None, - device: Union[str, int, torch.device] = 'cpu' - ) -> None: + def to_torch(self, dtype: Optional[torch.dtype] = None, + device: Union[str, int, torch.device] = 'cpu') -> None: """Change all numpy.ndarray to torch.Tensor. This is an in-place operation. """ @@ -473,66 +470,111 @@ def to_torch(self, v = v.type(dtype) self.__dict__[k] = v elif isinstance(v, torch.Tensor): - if dtype is not None and v.dtype != dtype: - must_update_tensor = True - elif v.device.type != device.type: - must_update_tensor = True - elif device.index is not None and \ + if dtype is not None and v.dtype != dtype or \ + v.device.type != device.type or \ + device.index is not None and \ device.index != v.device.index: - must_update_tensor = True - else: - must_update_tensor = False - if must_update_tensor: if dtype is not None: v = v.type(dtype) self.__dict__[k] = v.to(device) elif isinstance(v, Batch): v.to_torch(dtype, device) - def append(self, batch: 'Batch') -> None: - warnings.warn('Method :meth:`~tianshou.data.Batch.append` will be ' - 'removed soon, please use ' - ':meth:`~tianshou.data.Batch.cat`') - return self.cat_(batch) - - def cat_(self, batch: 'Batch') -> None: - """Concatenate a :class:`~tianshou.data.Batch` object into current - batch. + def cat_(self, + batches: Union['Batch', List[Union[dict, 'Batch']]]) -> None: + """Concatenate a list of (or one) :class:`~tianshou.data.Batch` objects + into current batch. """ - assert isinstance(batch, Batch), \ - 'Only Batch is allowed to be concatenated in-place!' - for k, v in batch.items(): - if v is None: - continue - if not hasattr(self, k) or self.__dict__[k] is None: - self.__dict__[k] = deepcopy(v) - elif isinstance(v, np.ndarray) and v.ndim > 0: - self.__dict__[k] = np.concatenate([self.__dict__[k], v]) - elif isinstance(v, torch.Tensor): - self.__dict__[k] = torch.cat([self.__dict__[k], v]) - elif isinstance(v, Batch): - self.__dict__[k].cat_(v) + if isinstance(batches, Batch): + batches = [batches] + if len(batches) == 0: + return + batches = [x if isinstance(x, Batch) else Batch(x) for x in batches] + if len(self.__dict__) > 0: + batches = [self] + list(batches) + # partial keys will be padded by zeros + # with the shape of [len, rest_shape] + lens = [len(x) for x in batches] + keys_map = list(map(lambda e: set(e.keys()), batches)) + keys_shared = set.intersection(*keys_map) + values_shared = [ + [e[k] for e in batches] for k in keys_shared] + _assert_type_keys(keys_shared) + for k, v in zip(keys_shared, values_shared): + if all(isinstance(e, (dict, Batch)) for e in v): + self.__dict__[k] = Batch.cat(v) + elif all(isinstance(e, torch.Tensor) for e in v): + self.__dict__[k] = torch.cat(v) + else: + v = np.concatenate(v) + if not issubclass(v.dtype.type, (np.bool_, np.number)): + v = v.astype(np.object) + self.__dict__[k] = v + keys_partial = set.union(*keys_map) - keys_shared + _assert_type_keys(keys_partial) + for k in keys_partial: + is_dict = False + value = None + for i, e in enumerate(batches): + val = e.get(k, None) + if val is not None: + if isinstance(val, (dict, Batch)): + is_dict = True + else: # np.ndarray or torch.Tensor + value = val + break + if is_dict: + self.__dict__[k] = Batch.cat( + [e.get(k, Batch()) for e in batches]) else: - s = 'No support for method "cat" with type '\ - f'{type(v)} in class Batch.' - raise TypeError(s) + if isinstance(value, np.ndarray): + arrs = [] + for i, e in enumerate(batches): + shape = [lens[i]] + list(value.shape[1:]) + pad = np.zeros(shape, dtype=value.dtype) + arrs.append(e.get(k, pad)) + self.__dict__[k] = np.concatenate(arrs) + elif isinstance(value, torch.Tensor): + arrs = [] + for i, e in enumerate(batches): + shape = [lens[i]] + list(value.shape[1:]) + pad = torch.zeros(shape, + dtype=value.dtype, + device=value.device) + arrs.append(e.get(k, pad)) + self.__dict__[k] = torch.cat(arrs) + else: + raise TypeError( + f"cannot cat value with type {type(value)}, we only " + "support dict, Batch, np.ndarray, and torch.Tensor") @staticmethod def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch': - """Concatenate a list of :class:`~tianshou.data.Batch` object into a single - new batch. + """Concatenate a list of :class:`~tianshou.data.Batch` object into a + single new batch. For keys that are not shared across all batches, + batches that do not have these keys will be padded by zeros with + appropriate shapes. E.g. + :: + + >>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5]))) + >>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5]))) + >>> c = Batch.cat([a, b]) + >>> c.a.shape + (7, 4) + >>> c.b.shape + (7, 3) + >>> c.common.c.shape + (7, 5) """ batch = Batch() - for batch_ in batches: - if isinstance(batch_, dict): - batch_ = Batch(batch_) - batch.cat_(batch_) + batch.cat_(batches) return batch def stack_(self, batches: List[Union[dict, 'Batch']], axis: int = 0) -> None: - """Stack a :class:`~tianshou.data.Batch` object i into current batch. + """Stack a list of :class:`~tianshou.data.Batch` object into current + batch. """ if len(self.__dict__) > 0: batches = [self] + list(batches) @@ -566,8 +608,8 @@ def stack_(self, @staticmethod def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': - """Stack a :class:`~tianshou.data.Batch` object into a single new - batch. + """Stack a list of :class:`~tianshou.data.Batch` object into a single + new batch. """ batch = Batch() batch.stack_(batches, axis) @@ -611,11 +653,24 @@ def empty(batch: 'Batch', index: Union[ """ return deepcopy(batch).empty_(index) + def update(self, batch: Optional[Union[dict, 'Batch']] = None, + **kwargs) -> None: + """Update this batch from another dict/Batch.""" + if batch is None: + self.update(kwargs) + return + if isinstance(batch, dict): + batch = Batch(batch) + for k, v in batch.items(): + self.__dict__[k] = v + if kwargs: + self.update(kwargs) + def __len__(self) -> int: """Return len(self).""" r = [] for v in self.__dict__.values(): - if isinstance(v, Batch) and len(v.__dict__) == 0: + if isinstance(v, Batch) and v.is_empty(): continue elif hasattr(v, '__len__') and (not isinstance( v, (np.ndarray, torch.Tensor)) or v.ndim > 0): @@ -627,7 +682,9 @@ def __len__(self) -> int: return min(r) def is_empty(self): - return len(self.__dict__.keys()) == 0 + return not any( + not x.is_empty() if isinstance(x, Batch) + else hasattr(x, '__len__') and len(x) > 0 for x in self.values()) @property def shape(self) -> List[int]: diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 33d31789d..f593d2a74 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -108,8 +108,7 @@ def __init__(self, size: int, stack_num: Optional[int] = 0, super().__init__() self._maxsize = size self._stack = stack_num - assert stack_num != 1, \ - 'stack_num should greater than 1' + assert stack_num != 1, 'stack_num should greater than 1' self._avail = sample_avail and stack_num > 1 self._avail_index = [] self._save_s_ = not ignore_obs_next @@ -136,12 +135,11 @@ def _add_to_buffer(self, name: str, inst: Any) -> None: except KeyError: self._meta.__dict__[name] = _create_value(inst, self._maxsize) value = self._meta.__dict__[name] - if isinstance(inst, np.ndarray) and \ - value.shape[1:] != inst.shape: + if isinstance(inst, np.ndarray) and value.shape[1:] != inst.shape: raise ValueError( "Cannot add data to a buffer with different shape, key: " - f"{name}, expect shape: {value.shape[1:]}" - f", given shape: {inst.shape}.") + f"{name}, expect shape: {value.shape[1:]}, " + f"given shape: {inst.shape}.") try: value[self._index] = inst except KeyError: @@ -357,7 +355,7 @@ def __init__(self, size: int, alpha: float, beta: float, self._weight_sum = 0.0 self._amortization_freq = 50 self._replace = replace - self._meta.__dict__['weight'] = np.zeros(size, dtype=np.float64) + self._meta.weight = np.zeros(size, dtype=np.float64) def add(self, obs: Union[dict, np.ndarray], @@ -372,7 +370,7 @@ def add(self, """Add a batch of data into replay buffer.""" # we have to sacrifice some convenience for speed self._weight_sum += np.abs(weight) ** self._alpha - \ - self._meta.__dict__['weight'][self._index] + self._meta.weight[self._index] self._add_to_buffer('weight', np.abs(weight) ** self._alpha) super().add(obs, act, rew, done, obs_next, info, policy) @@ -410,14 +408,9 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: f"batch_size should be less than {len(self)}, \ or set replace=True") batch = self[indice] - impt_weight = Batch( - impt_weight=(self._size * p) ** (-self._beta)) - batch.cat_(impt_weight) + batch["impt_weight"] = (self._size * p) ** (-self._beta) return batch, indice - def reset(self) -> None: - super().reset() - def update_weight(self, indice: Union[slice, np.ndarray], new_weight: np.ndarray) -> None: """Update priority weight by indice in this buffer. From 5599a6d1a65543715bddf37e26cae2fa9ec566e3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 12 Jul 2020 23:45:42 +0800 Subject: [PATCH 3/4] Fix padding of inconsistent keys with Batch.stack and Batch.cat (#130) * re-implement Batch.stack and add testcases * add doc for Batch.stack * reuse _create_values and refactor stack_ & cat_ * fix pep8 * fix docs * raise exception for stacking with partial keys and axis!=0 * minor fix * minor fix Co-authored-by: Trinkle23897 <463003665@qq.com> --- test/base/test_batch.py | 23 ++++++++- tianshou/data/batch.py | 104 +++++++++++++++++++++++----------------- 2 files changed, 82 insertions(+), 45 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 03031ff2e..7e287e269 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -166,7 +166,7 @@ def test_batch_cat_and_stack(): assert isinstance(b12_stack.a.d.e, np.ndarray) assert b12_stack.a.d.e.ndim == 2 - # test batch with incompatible keys + # test cat with incompatible keys b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5))) b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) test = Batch.cat([b1, b2]) @@ -177,6 +177,7 @@ def test_batch_cat_and_stack(): assert torch.allclose(test.b, ans.b) assert np.allclose(test.common.c, ans.common.c) + # test stack with compatible keys b3 = Batch(a=np.zeros((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[1], [2]])) @@ -194,6 +195,26 @@ def test_batch_cat_and_stack(): assert b5.b.d[0] == b5_dict[0]['b']['d'] assert b5.b.d[1] == 0.0 + # test stack with incompatible keys + a = Batch(a=1, b=2, c=3) + b = Batch(a=4, b=5, d=6) + c = Batch(c=7, b=6, d=9) + d = Batch.stack([a, b, c]) + assert np.allclose(d.a, [1, 4, 0]) + assert np.allclose(d.b, [2, 5, 6]) + assert np.allclose(d.c, [3, 0, 7]) + assert np.allclose(d.d, [0, 6, 9]) + + b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5))) + b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5))) + test = Batch.stack([b1, b2]) + ans = Batch(a=np.stack([b1.a, np.zeros((4, 4))]), + b=torch.stack([torch.zeros(4, 6), b2.b]), + common=Batch(c=np.stack([b1.common.c, b2.common.c]))) + assert np.allclose(test.a, ans.a) + assert torch.allclose(test.b, ans.b) + assert np.allclose(test.common.c, ans.common.c) + def test_batch_over_batch_to_torch(): batch = Batch( diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 1240cfe50..415c8f38a 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -3,7 +3,6 @@ import warnings import numpy as np from copy import deepcopy -from functools import reduce from numbers import Number from typing import Any, List, Tuple, Union, Iterator, Optional @@ -24,28 +23,45 @@ def _is_batch_set(data: Any) -> bool: return False -def _create_value(inst: Any, size: int) -> Union[ +def _create_value(inst: Any, size: int, stack=True) -> Union[ 'Batch', np.ndarray, torch.Tensor]: + """ + :param bool stack: whether to stack or to concatenate. E.g. if inst has + shape of (3, 5), size = 10, stack=True returns an np.ndarry with shape + of (10, 3, 5), otherwise (10, 5) + """ + has_shape = isinstance(inst, (np.ndarray, torch.Tensor)) + is_scalar = \ + isinstance(inst, Number) or \ + issubclass(inst.__class__, np.generic) or \ + (has_shape and not inst.shape) + if not stack and is_scalar: + # here we do not consider scalar types, following the + # behavior of numpy which does not support concatenation + # of zero-dimensional arrays (scalars) + raise TypeError(f"cannot cat {inst} with which is scalar") + if has_shape: + shape = (size, *inst.shape) if stack else (size, *inst.shape[1:]) if isinstance(inst, np.ndarray): if issubclass(inst.dtype.type, (np.bool_, np.number)): target_type = inst.dtype.type else: target_type = np.object - return np.full((size, *inst.shape), + return np.full(shape, fill_value=None if target_type == np.object else 0, dtype=target_type) elif isinstance(inst, torch.Tensor): - return torch.full((size, *inst.shape), + return torch.full(shape, fill_value=0, device=inst.device, dtype=inst.dtype) elif isinstance(inst, (dict, Batch)): zero_batch = Batch() for key, val in inst.items(): - zero_batch.__dict__[key] = _create_value(val, size) + zero_batch.__dict__[key] = _create_value(val, size, stack=stack) return zero_batch - elif isinstance(inst, (np.generic, Number)): - return _create_value(np.asarray(inst), size) + elif is_scalar: + return _create_value(np.asarray(inst), size, stack=stack) else: # fall back to np.object return np.array([None for _ in range(size)]) @@ -495,10 +511,12 @@ def cat_(self, # partial keys will be padded by zeros # with the shape of [len, rest_shape] lens = [len(x) for x in batches] + sum_lens = [0] + for x in lens: + sum_lens.append(sum_lens[-1] + x) keys_map = list(map(lambda e: set(e.keys()), batches)) keys_shared = set.intersection(*keys_map) - values_shared = [ - [e[k] for e in batches] for k in keys_shared] + values_shared = [[e[k] for e in batches] for k in keys_shared] _assert_type_keys(keys_shared) for k, v in zip(keys_shared, values_shared): if all(isinstance(e, (dict, Batch)) for e in v): @@ -513,40 +531,15 @@ def cat_(self, keys_partial = set.union(*keys_map) - keys_shared _assert_type_keys(keys_partial) for k in keys_partial: - is_dict = False - value = None for i, e in enumerate(batches): val = e.get(k, None) if val is not None: - if isinstance(val, (dict, Batch)): - is_dict = True - else: # np.ndarray or torch.Tensor - value = val - break - if is_dict: - self.__dict__[k] = Batch.cat( - [e.get(k, Batch()) for e in batches]) - else: - if isinstance(value, np.ndarray): - arrs = [] - for i, e in enumerate(batches): - shape = [lens[i]] + list(value.shape[1:]) - pad = np.zeros(shape, dtype=value.dtype) - arrs.append(e.get(k, pad)) - self.__dict__[k] = np.concatenate(arrs) - elif isinstance(value, torch.Tensor): - arrs = [] - for i, e in enumerate(batches): - shape = [lens[i]] + list(value.shape[1:]) - pad = torch.zeros(shape, - dtype=value.dtype, - device=value.device) - arrs.append(e.get(k, pad)) - self.__dict__[k] = torch.cat(arrs) - else: - raise TypeError( - f"cannot cat value with type {type(value)}, we only " - "support dict, Batch, np.ndarray, and torch.Tensor") + try: + self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val + except KeyError: + self.__dict__[k] = \ + _create_value(val, sum_lens[-1], stack=False) + self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val @staticmethod def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch': @@ -576,12 +569,14 @@ def stack_(self, """Stack a list of :class:`~tianshou.data.Batch` object into current batch. """ + if len(batches) == 0: + return + batches = [x if isinstance(x, Batch) else Batch(x) for x in batches] if len(self.__dict__) > 0: batches = [self] + list(batches) keys_map = list(map(lambda e: set(e.keys()), batches)) keys_shared = set.intersection(*keys_map) - values_shared = [ - [e[k] for e in batches] for k in keys_shared] + values_shared = [[e[k] for e in batches] for k in keys_shared] _assert_type_keys(keys_shared) for k, v in zip(keys_shared, values_shared): if all(isinstance(e, (dict, Batch)) for e in v): @@ -593,7 +588,11 @@ def stack_(self, if not issubclass(v.dtype.type, (np.bool_, np.number)): v = v.astype(np.object) self.__dict__[k] = v - keys_partial = reduce(set.symmetric_difference, keys_map) + keys_partial = set.difference(set.union(*keys_map), keys_shared) + if keys_partial and axis != 0: + raise ValueError( + f"Stack of Batch with non-shared keys {keys_partial} " + f"is only supported with axis=0, but got axis={axis}!") _assert_type_keys(keys_partial) for k in keys_partial: for i, e in enumerate(batches): @@ -609,7 +608,24 @@ def stack_(self, @staticmethod def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': """Stack a list of :class:`~tianshou.data.Batch` object into a single - new batch. + new batch. For keys that are not shared across all batches, + batches that do not have these keys will be padded by zeros. E.g. + :: + + >>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5]))) + >>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5]))) + >>> c = Batch.stack([a, b]) + >>> c.a.shape + (2, 4, 4) + >>> c.b.shape + (2, 4, 6) + >>> c.common.c.shape + (2, 4, 5) + + .. note:: + + If there are keys that are not shared across all batches, ``stack`` + with ``axis != 0`` is undefined, and will cause an exception. """ batch = Batch() batch.stack_(batches, axis) From 26fb87433de0d2604078ecbc99502efe0a815d5d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 13 Jul 2020 00:24:31 +0800 Subject: [PATCH 4/4] Improve collector (#125) * remove multibuf * reward_metric * make fileds with empty Batch rather than None after reset * many fixes and refactor Co-authored-by: Trinkle23897 <463003665@qq.com> --- test/base/env.py | 38 ++--- test/base/test_collector.py | 50 ++++++- tianshou/data/collector.py | 272 +++++++++++++++--------------------- 3 files changed, 183 insertions(+), 177 deletions(-) diff --git a/test/base/env.py b/test/base/env.py index 1aa409fca..b0962154f 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -1,19 +1,34 @@ -import time import gym +import time from gym.spaces.discrete import Discrete class MyTestEnv(gym.Env): - def __init__(self, size, sleep=0, dict_state=False): + """This is a "going right" task. The task is to go right ``size`` steps. + """ + + def __init__(self, size, sleep=0, dict_state=False, ma_rew=0): self.size = size self.sleep = sleep self.dict_state = dict_state + self.ma_rew = ma_rew self.action_space = Discrete(2) self.reset() def reset(self, state=0): self.done = False self.index = state + return self._get_dict_state() + + def _get_reward(self): + """Generate a non-scalar reward if ma_rew is True.""" + x = int(self.done) + if self.ma_rew > 0: + return [x] * self.ma_rew + return x + + def _get_dict_state(self): + """Generate a dict_state if dict_state is True.""" return {'index': self.index} if self.dict_state else self.index def step(self, action): @@ -23,22 +38,13 @@ def step(self, action): time.sleep(self.sleep) if self.index == self.size: self.done = True - if self.dict_state: - return {'index': self.index}, 0, True, {} - else: - return self.index, 0, True, {} + return self._get_dict_state(), self._get_reward(), self.done, {} if action == 0: self.index = max(self.index - 1, 0) - if self.dict_state: - return {'index': self.index}, 0, False, {'key': 1, 'env': self} - else: - return self.index, 0, False, {} + return self._get_dict_state(), self._get_reward(), self.done, \ + {'key': 1, 'env': self} if self.dict_state else {} elif action == 1: self.index += 1 self.done = self.index == self.size - if self.dict_state: - return {'index': self.index}, int(self.done), self.done, \ - {'key': 1, 'env': self} - else: - return self.index, int(self.done), self.done, \ - {'key': 1, 'env': self} + return self._get_dict_state(), self._get_reward(), \ + self.done, {'key': 1, 'env': self} diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 16fbdda8d..ead017a01 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -27,16 +27,16 @@ def learn(self): def preprocess_fn(**kwargs): # modify info before adding into the buffer - if kwargs.get('info', None) is not None: + # if info is not provided from env, it will be a ``Batch()``. + if not kwargs.get('info', Batch()).is_empty(): n = len(kwargs['obs']) info = kwargs['info'] for i in range(n): info[i].update(rew=kwargs['rew'][i]) return {'info': info} - # or - # return Batch(info=info) + # or: return Batch(info=info) else: - return {} + return Batch() class Logger(object): @@ -119,6 +119,48 @@ def test_collector_with_dict_state(): print(batch['obs_next']['index']) +def test_collector_with_ma(): + def reward_metric(x): + return x.sum() + env = MyTestEnv(size=5, sleep=0, ma_rew=4) + policy = MyPolicy() + c0 = Collector(policy, env, ReplayBuffer(size=100), + preprocess_fn, reward_metric=reward_metric) + r = c0.collect(n_step=3)['rew'] + assert np.asanyarray(r).size == 1 and r == 0. + r = c0.collect(n_episode=3)['rew'] + assert np.asanyarray(r).size == 1 and r == 4. + env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) + for i in [2, 3, 4, 5]] + envs = VectorEnv(env_fns) + c1 = Collector(policy, envs, ReplayBuffer(size=100), + preprocess_fn, reward_metric=reward_metric) + r = c1.collect(n_step=10)['rew'] + assert np.asanyarray(r).size == 1 and r == 4. + r = c1.collect(n_episode=[2, 1, 1, 2])['rew'] + assert np.asanyarray(r).size == 1 and r == 4. + batch = c1.sample(10) + print(batch) + c0.buffer.update(c1.buffer) + obs = [ + 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., + 0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0., + 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.] + assert np.allclose(c0.buffer[:len(c0.buffer)].obs, obs) + rew = [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, + 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, + 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1] + assert np.allclose(c0.buffer[:len(c0.buffer)].rew, + [[x] * 4 for x in rew]) + c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4), + preprocess_fn, reward_metric=reward_metric) + r = c2.collect(n_episode=[0, 0, 0, 10])['rew'] + assert np.asanyarray(r).size == 1 and r == 4. + batch = c2.sample(10) + print(batch['obs_next']) + + if __name__ == '__main__': test_collector() test_collector_with_dict_state() + test_collector_with_ma() diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 3a7ad7821..40cd7390c 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -8,8 +8,8 @@ from tianshou.utils import MovAvg from tianshou.env import BaseVectorEnv from tianshou.policy import BasePolicy -from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy from tianshou.exploration import BaseNoise +from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy class Collector(object): @@ -25,12 +25,18 @@ class Collector(object): ``None``, it will automatically assign a small-size :class:`~tianshou.data.ReplayBuffer`. :param function preprocess_fn: a function called before the data has been - added to the buffer, see issue #42, defaults to ``None``. + added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults + to ``None``. :param int stat_size: for the moving average of recording speed, defaults to 100. :param BaseNoise action_noise: add a noise to continuous action. Normally a policy already has a noise param for exploration in training phase, so this is recommended to use in test collector for some purpose. + :param function reward_metric: to be used in multi-agent RL. The reward to + report is of shape [agent_num], but we need to return a single scalar + to monitor training. This function specifies what is the desired + metric, e.g., the reward of agent 1 or the average reward over all + agents. By default, the behavior is to select the reward of agent 1. The ``preprocess_fn`` is a function called before the data has been added to the buffer with batch format, which receives up to 7 keys as listed in @@ -87,68 +93,58 @@ class Collector(object): def __init__(self, policy: BasePolicy, env: Union[gym.Env, BaseVectorEnv], - buffer: Optional[Union[ReplayBuffer, List[ReplayBuffer]]] - = None, + buffer: Optional[ReplayBuffer] = None, preprocess_fn: Callable[[Any], Union[dict, Batch]] = None, stat_size: Optional[int] = 100, action_noise: Optional[BaseNoise] = None, + reward_metric: Optional[Callable[[np.ndarray], float]] = None, **kwargs) -> None: super().__init__() self.env = env self.env_num = 1 - self.collect_time = 0 - self.collect_step = 0 - self.collect_episode = 0 + self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0 self.buffer = buffer self.policy = policy self.preprocess_fn = preprocess_fn - # if preprocess_fn is None: - # def _prep(**kwargs): - # return kwargs - # self.preprocess_fn = _prep self.process_fn = policy.process_fn self._multi_env = isinstance(env, BaseVectorEnv) - self._multi_buf = False # True if buf is a list # need multiple cache buffers only if storing in one buffer self._cached_buf = [] if self._multi_env: self.env_num = len(env) - if isinstance(self.buffer, list): - assert len(self.buffer) == self.env_num, \ - 'The number of data buffer does not match the number of ' \ - 'input env.' - self._multi_buf = True - elif isinstance(self.buffer, ReplayBuffer) or self.buffer is None: - self._cached_buf = [ - ListReplayBuffer() for _ in range(self.env_num)] - else: - raise TypeError('The buffer in data collector is invalid!') + self._cached_buf = [ListReplayBuffer() + for _ in range(self.env_num)] self.stat_size = stat_size self._action_noise = action_noise + + self._rew_metric = reward_metric or Collector._default_rew_metric self.reset() + @staticmethod + def _default_rew_metric(x): + # this internal function is designed for single-agent RL + # for multi-agent RL, a reward_metric must be provided + assert np.asanyarray(x).size == 1, \ + 'Please specify the reward_metric ' \ + 'since the reward is not a scalar.' + return x + def reset(self) -> None: """Reset all related variables in the collector.""" + self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={}, + obs_next={}, policy={}) self.reset_env() self.reset_buffer() - # state over batch is either a list, an np.ndarray, or a torch.Tensor - self.state = None self.step_speed = MovAvg(self.stat_size) self.episode_speed = MovAvg(self.stat_size) - self.collect_step = 0 - self.collect_episode = 0 - self.collect_time = 0 + self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0 if self._action_noise is not None: self._action_noise.reset() def reset_buffer(self) -> None: """Reset the main data buffer.""" - if self._multi_buf: - for b in self.buffer: - b.reset() - else: - if self.buffer is not None: - self.buffer.reset() + if self.buffer is not None: + self.buffer.reset() def get_env_num(self) -> int: """Return the number of environments the collector have.""" @@ -158,34 +154,28 @@ def reset_env(self) -> None: """Reset all of the environment(s)' states and reset all of the cache buffers (if need). """ - self._obs = self.env.reset() + obs = self.env.reset() if not self._multi_env: - self._obs = self._make_batch(self._obs) + obs = self._make_batch(obs) if self.preprocess_fn: - self._obs = self.preprocess_fn(obs=self._obs).get('obs', self._obs) - self._act = self._rew = self._done = self._info = None - if self._multi_env: - self.reward = np.zeros(self.env_num) - self.length = np.zeros(self.env_num) - else: - self.reward, self.length = 0, 0 + obs = self.preprocess_fn(obs=obs).get('obs', obs) + self.data.obs = obs + self.reward = 0. # will be specified when the first data is ready + self.length = np.zeros(self.env_num) for b in self._cached_buf: b.reset() def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None: """Reset all the seed(s) of the given environment(s).""" - if hasattr(self.env, 'seed'): - return self.env.seed(seed) + return self.env.seed(seed) def render(self, **kwargs) -> None: """Render all the environment(s).""" - if hasattr(self.env, 'render'): - return self.env.render(**kwargs) + return self.env.render(**kwargs) def close(self) -> None: """Close the environment(s).""" - if hasattr(self.env, 'close'): - self.env.close() + self.env.close() def _make_batch(self, data: Any) -> np.ndarray: """Return [data].""" @@ -195,20 +185,14 @@ def _make_batch(self, data: Any) -> np.ndarray: return np.array([data]) def _reset_state(self, id: Union[int, List[int]]) -> None: - """Reset self.state[id].""" - if self.state is None: - return - if isinstance(self.state, list): - self.state[id] = None - elif isinstance(self.state, torch.Tensor): - self.state[id].zero_() - elif isinstance(self.state, np.ndarray): - if isinstance(self.state.dtype == np.object): - self.state[id] = None - else: - self.state[id] = 0 - elif isinstance(self.state, Batch): - self.state.empty_(id) + """Reset self.data.state[id].""" + state = self.data.state # it is a reference + if isinstance(state, torch.Tensor): + state[id].zero_() + elif isinstance(state, np.ndarray): + state[id] = None if state.dtype == np.object else 0 + elif isinstance(state, Batch): + state.empty_(id) def collect(self, n_step: int = 0, @@ -244,26 +228,27 @@ def collect(self, * ``rew`` the mean reward over collected episodes. * ``len`` the mean length over collected episodes. """ - warning_count = 0 if not self._multi_env: n_episode = np.sum(n_episode) start_time = time.time() assert sum([(n_step != 0), (n_episode != 0)]) == 1, \ "One and only one collection number specification is permitted!" - cur_step = 0 - cur_episode = np.zeros(self.env_num) if self._multi_env else 0 - reward_sum = 0 - length_sum = 0 + cur_step, cur_episode = 0, np.zeros(self.env_num) + reward_sum, length_sum = 0., 0 while True: - if warning_count >= 100000: + if cur_step >= 100000 and cur_episode.sum() == 0: warnings.warn( 'There are already many steps in an episode. ' 'You should add a time limitation to your environment!', Warning) - batch = Batch( - obs=self._obs, act=self._act, rew=self._rew, - done=self._done, obs_next=None, info=self._info, - policy=None) + + # restore the state and the input data + last_state = self.data.state + if last_state.is_empty(): + last_state = None + self.data.update(state=Batch(), obs_next=Batch(), policy=Batch()) + + # calculate the next action if random: action_space = self.env.action_space if isinstance(action_space, list): @@ -272,69 +257,54 @@ def collect(self, result = Batch(act=self._make_batch(action_space.sample())) else: with torch.no_grad(): - result = self.policy(batch, self.state) + result = self.policy(self.data, last_state) - # save hidden state to policy._state, in order to save into buffer - self.state = result.get('state', None) + # convert None to Batch(), since None is reserved for 0-init + state = result.get('state', Batch()) + if state is None: + state = Batch() + self.data.state = state if hasattr(result, 'policy'): - self._policy = to_numpy(result.policy) - if self.state is not None: - self._policy._state = self.state - elif self.state is not None: - self._policy = Batch(_state=self.state) - else: - self._policy = [{}] * self.env_num + self.data.policy = to_numpy(result.policy) + # save hidden state to policy._state, in order to save into buffer + self.data.policy._state = self.data.state - self._act = to_numpy(result.act) + self.data.act = to_numpy(result.act) if self._action_noise is not None: - self._act += self._action_noise(self._act.shape) - obs_next, self._rew, self._done, self._info = self.env.step( - self._act if self._multi_env else self._act[0]) + self.data.act += self._action_noise(self.data.act.shape) + + # step in env + obs_next, rew, done, info = self.env.step( + self.data.act if self._multi_env else self.data.act[0]) + + # move data to self.data if not self._multi_env: obs_next = self._make_batch(obs_next) - self._rew = self._make_batch(self._rew) - self._done = self._make_batch(self._done) - self._info = self._make_batch(self._info) + rew = self._make_batch(rew) + done = self._make_batch(done) + info = self._make_batch(info) + self.data.obs_next = obs_next + self.data.rew = rew + self.data.done = done + self.data.info = info + if log_fn: - log_fn(self._info if self._multi_env else self._info[0]) + log_fn(info if self._multi_env else info[0]) if render: - self.env.render() + self.render() if render > 0: time.sleep(render) + + # add data into the buffer self.length += 1 - self.reward += self._rew + self.reward += self.data.rew if self.preprocess_fn: - result = self.preprocess_fn( - obs=self._obs, act=self._act, rew=self._rew, - done=self._done, obs_next=obs_next, info=self._info, - policy=self._policy) - self._obs = result.get('obs', self._obs) - self._act = result.get('act', self._act) - self._rew = result.get('rew', self._rew) - self._done = result.get('done', self._done) - obs_next = result.get('obs_next', obs_next) - self._info = result.get('info', self._info) - self._policy = result.get('policy', self._policy) - if self._multi_env: + result = self.preprocess_fn(**self.data) + self.data.update(result) + if self._multi_env: # cache_buffer branch for i in range(self.env_num): - data = { - 'obs': self._obs[i], 'act': self._act[i], - 'rew': self._rew[i], 'done': self._done[i], - 'obs_next': obs_next[i], 'info': self._info[i], - 'policy': self._policy[i]} - if self._cached_buf: - warning_count += 1 - self._cached_buf[i].add(**data) - elif self._multi_buf: - warning_count += 1 - self.buffer[i].add(**data) - cur_step += 1 - else: - warning_count += 1 - if self.buffer is not None: - self.buffer.add(**data) - cur_step += 1 - if self._done[i]: + self._cached_buf[i].add(**self.data[i]) + if self.data.done[i]: if n_step != 0 or np.isscalar(n_episode) or \ cur_episode[i] < n_episode[i]: cur_episode[i] += 1 @@ -344,46 +314,47 @@ def collect(self, cur_step += len(self._cached_buf[i]) if self.buffer is not None: self.buffer.update(self._cached_buf[i]) - self.reward[i], self.length[i] = 0, 0 + self.reward[i], self.length[i] = 0., 0 if self._cached_buf: self._cached_buf[i].reset() self._reset_state(i) - if sum(self._done): - obs_next = self.env.reset(np.where(self._done)[0]) + obs_next = self.data.obs_next + if sum(self.data.done): + obs_next = self.env.reset(np.where(self.data.done)[0]) if self.preprocess_fn: obs_next = self.preprocess_fn(obs=obs_next).get( 'obs', obs_next) + self.data.obs_next = obs_next if n_episode != 0: if isinstance(n_episode, list) and \ (cur_episode >= np.array(n_episode)).all() or \ np.isscalar(n_episode) and \ cur_episode.sum() >= n_episode: break - else: + else: # single buffer, without cache_buffer if self.buffer is not None: - self.buffer.add( - self._obs[0], self._act[0], self._rew[0], - self._done[0], obs_next[0], self._info[0], - self._policy[0]) + self.buffer.add(**self.data[0]) cur_step += 1 - if self._done: + if self.data.done[0]: cur_episode += 1 reward_sum += self.reward[0] - length_sum += self.length - self.reward, self.length = 0, 0 - self.state = None + length_sum += self.length[0] + self.reward, self.length = 0., np.zeros(self.env_num) + self.data.state = Batch() obs_next = self._make_batch(self.env.reset()) if self.preprocess_fn: obs_next = self.preprocess_fn(obs=obs_next).get( 'obs', obs_next) + self.data.obs_next = obs_next if n_episode != 0 and cur_episode >= n_episode: break if n_step != 0 and cur_step >= n_step: break - self._obs = obs_next - self._obs = obs_next - if self._multi_env: - cur_episode = sum(cur_episode) + self.data.obs = self.data.obs_next + self.data.obs = self.data.obs_next + + # generate the statistics + cur_episode = sum(cur_episode) duration = max(time.time() - start_time, 1e-9) self.step_speed.add(cur_step / duration) self.episode_speed.add(cur_episode / duration) @@ -394,12 +365,15 @@ def collect(self, n_episode = np.sum(n_episode) else: n_episode = max(cur_episode, 1) + reward_sum /= n_episode + if np.asanyarray(reward_sum).size > 1: # non-scalar reward_sum + reward_sum = self._rew_metric(reward_sum) return { 'n/ep': cur_episode, 'n/st': cur_step, 'v/st': self.step_speed.get(), 'v/ep': self.episode_speed.get(), - 'rew': reward_sum / n_episode, + 'rew': reward_sum, 'len': length_sum / n_episode, } @@ -412,22 +386,6 @@ def sample(self, batch_size: int) -> Batch: the buffer, otherwise it will extract the data with the given batch_size. """ - if self._multi_buf: - if batch_size > 0: - lens = [len(b) for b in self.buffer] - total = sum(lens) - batch_index = np.random.choice( - len(self.buffer), batch_size, p=np.array(lens) / total) - else: - batch_index = np.array([]) - batch_data = Batch() - for i, b in enumerate(self.buffer): - cur_batch = (batch_index == i).sum() - if batch_size and cur_batch or batch_size <= 0: - batch, indice = b.sample(cur_batch) - batch = self.process_fn(batch, b, indice) - batch_data.cat_(batch) - else: - batch_data, indice = self.buffer.sample(batch_size) - batch_data = self.process_fn(batch_data, self.buffer, indice) + batch_data, indice = self.buffer.sample(batch_size) + batch_data = self.process_fn(batch_data, self.buffer, indice) return batch_data