From aefcdfcde99be244ca591f78ea6467ecf5452a86 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 Jul 2020 22:13:57 +0800 Subject: [PATCH 1/8] re-implement Batch.stack and add testcases --- test/base/test_batch.py | 23 ++++++++++++- tianshou/data/batch.py | 74 ++++++++++++++++++++++++++++++++--------- 2 files changed, 80 insertions(+), 17 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 03031ff2e..7e287e269 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -166,7 +166,7 @@ def test_batch_cat_and_stack(): assert isinstance(b12_stack.a.d.e, np.ndarray) assert b12_stack.a.d.e.ndim == 2 - # test batch with incompatible keys + # test cat with incompatible keys b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5))) b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) test = Batch.cat([b1, b2]) @@ -177,6 +177,7 @@ def test_batch_cat_and_stack(): assert torch.allclose(test.b, ans.b) assert np.allclose(test.common.c, ans.common.c) + # test stack with compatible keys b3 = Batch(a=np.zeros((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[1], [2]])) @@ -194,6 +195,26 @@ def test_batch_cat_and_stack(): assert b5.b.d[0] == b5_dict[0]['b']['d'] assert b5.b.d[1] == 0.0 + # test stack with incompatible keys + a = Batch(a=1, b=2, c=3) + b = Batch(a=4, b=5, d=6) + c = Batch(c=7, b=6, d=9) + d = Batch.stack([a, b, c]) + assert np.allclose(d.a, [1, 4, 0]) + assert np.allclose(d.b, [2, 5, 6]) + assert np.allclose(d.c, [3, 0, 7]) + assert np.allclose(d.d, [0, 6, 9]) + + b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5))) + b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5))) + test = Batch.stack([b1, b2]) + ans = Batch(a=np.stack([b1.a, np.zeros((4, 4))]), + b=torch.stack([torch.zeros(4, 6), b2.b]), + common=Batch(c=np.stack([b1.common.c, b2.common.c]))) + assert np.allclose(test.a, ans.a) + assert torch.allclose(test.b, ans.b) + assert np.allclose(test.common.c, ans.common.c) + def test_batch_over_batch_to_torch(): batch = Batch( diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 1240cfe50..5beb8db02 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -3,7 +3,6 @@ import warnings import numpy as np from copy import deepcopy -from functools import reduce from numbers import Number from typing import Any, List, Tuple, Union, Iterator, Optional @@ -530,20 +529,26 @@ def cat_(self, if isinstance(value, np.ndarray): arrs = [] for i, e in enumerate(batches): - shape = [lens[i]] + list(value.shape[1:]) - pad = np.zeros(shape, dtype=value.dtype) - arrs.append(e.get(k, pad)) + v = e.get(k, None) + if v is None: + shape = [lens[i]] + list(value.shape[1:]) + v = np.zeros(shape, dtype=value.dtype) + arrs.append(v) self.__dict__[k] = np.concatenate(arrs) elif isinstance(value, torch.Tensor): arrs = [] for i, e in enumerate(batches): - shape = [lens[i]] + list(value.shape[1:]) - pad = torch.zeros(shape, - dtype=value.dtype, - device=value.device) - arrs.append(e.get(k, pad)) + v = e.get(k, None) + if v is None: + shape = [lens[i]] + list(value.shape[1:]) + v = torch.zeros( + shape, dtype=value.dtype, device=value.device) + arrs.append(v) self.__dict__[k] = torch.cat(arrs) else: + # here we do not consider scalar types, following the + # behavior of numpy which does not support concatenation + # of zero-dimensional arrays (scalars) raise TypeError( f"cannot cat value with type {type(value)}, we only " "support dict, Batch, np.ndarray, and torch.Tensor") @@ -576,6 +581,9 @@ 
def stack_(self, """Stack a list of :class:`~tianshou.data.Batch` object into current batch. """ + if len(batches) == 0: + return + batches = [x if isinstance(x, Batch) else Batch(x) for x in batches] if len(self.__dict__) > 0: batches = [self] + list(batches) keys_map = list(map(lambda e: set(e.keys()), batches)) @@ -593,18 +601,52 @@ def stack_(self, if not issubclass(v.dtype.type, (np.bool_, np.number)): v = v.astype(np.object) self.__dict__[k] = v - keys_partial = reduce(set.symmetric_difference, keys_map) + keys_partial = set.union(*keys_map) - keys_shared _assert_type_keys(keys_partial) for k in keys_partial: + is_dict = False + value = None for i, e in enumerate(batches): val = e.get(k, None) if val is not None: - try: - self.__dict__[k][i] = val - except KeyError: - self.__dict__[k] = \ - _create_value(val, len(batches)) - self.__dict__[k][i] = val + if isinstance(val, (dict, Batch)): + is_dict = True + else: # np.ndarray / scalar or torch.Tensor + value = val + break + if is_dict: + self.__dict__[k] = Batch.stack( + [e.get(k, Batch()) for e in batches], axis) + else: + if isinstance(value, np.ndarray) \ + or issubclass(value.__class__, np.generic) \ + or isinstance(value, Number): + if isinstance(value, Number): + value = np.asarray(value) + arrs = [] + for i, e in enumerate(batches): + v = e.get(k, None) + if v is None: + v = np.zeros_like(value) + arrs.append(v) + v = np.stack(arrs, axis) + if not issubclass(v.dtype.type, (np.bool_, np.number)): + v = v.astype(np.object) + self.__dict__[k] = v + elif isinstance(value, torch.Tensor): + arrs = [] + for i, e in enumerate(batches): + v = e.get(k, None) + if v is None: + v = torch.zeros_like(value) + arrs.append(v) + self.__dict__[k] = torch.stack(arrs, axis) + else: # fallback to object + arrs = [] + for i, e in enumerate(batches): + v = e.get(k, None) + arrs.append(v) + self.__dict__[k] = np.array(arrs) @staticmethod def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': From b3ab3fe5d0f3e40c4c245d62e87f3d7d5dda367f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 Jul 2020 22:16:50 +0800 Subject: [PATCH 2/8] add doc for Batch.stack --- tianshou/data/batch.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 5beb8db02..0ec80b1af 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -651,7 +651,19 @@ def stack_(self, @staticmethod def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': """Stack a list of :class:`~tianshou.data.Batch` object into a single - new batch. + new batch. For keys that are not shared across all batches, + batches that do not have these keys will be padded by zeros. E.g. 
+ :: + + >>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5]))) + >>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5]))) + >>> c = Batch.stack([a, b]) + >>> c.a.shape + (8, 4) + >>> c.b.shape + (8, 6) + >>> c.common.c.shape + (8, 5) """ batch = Batch() batch.stack_(batches, axis) From f7d7482732d9f5bddf4aae01e929094ac6e119f8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 12 Jul 2020 01:33:53 +0800 Subject: [PATCH 3/8] reuse _create_values and refactor stack_ & cat_ --- tianshou/data/batch.py | 121 +++++++++++++---------------------------- 1 file changed, 37 insertions(+), 84 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 0ec80b1af..6976ea1d9 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -23,28 +23,43 @@ def _is_batch_set(data: Any) -> bool: return False -def _create_value(inst: Any, size: int) -> Union[ +def _create_value(inst: Any, size: int, stack=True) -> Union[ 'Batch', np.ndarray, torch.Tensor]: + """ + :param bool stack: whether to stack or to concatenate. E.g. if inst has + shape of (3, 5), size = 10, stack=True returns an np.ndarry with shape + of (10, 3, 5), otherwise (10, 5) + """ + has_shape = isinstance(inst, (np.ndarray, torch.Tensor)) + is_scalar = isinstance(inst, Number) or \ + issubclass(inst.__class__, np.generic) or (has_shape and not inst.shape) + if not stack and is_scalar: + # here we do not consider scalar types, following the + # behavior of numpy which does not support concatenation + # of zero-dimensional arrays (scalars) + raise TypeError(f"cannot cat {inst} with which is scalar") + if has_shape: + shape = (size, *inst.shape) if stack else (size, *inst.shape[1:]) if isinstance(inst, np.ndarray): if issubclass(inst.dtype.type, (np.bool_, np.number)): target_type = inst.dtype.type else: target_type = np.object - return np.full((size, *inst.shape), + return np.full(shape, fill_value=None if target_type == np.object else 0, dtype=target_type) elif isinstance(inst, torch.Tensor): - return torch.full((size, *inst.shape), + return torch.full(shape, fill_value=0, device=inst.device, dtype=inst.dtype) elif isinstance(inst, (dict, Batch)): zero_batch = Batch() for key, val in inst.items(): - zero_batch.__dict__[key] = _create_value(val, size) + zero_batch.__dict__[key] = _create_value(val, size, stack=stack) return zero_batch - elif isinstance(inst, (np.generic, Number)): - return _create_value(np.asarray(inst), size) + elif is_scalar: + return _create_value(np.asarray(inst), size, stack=stack) else: # fall back to np.object return np.array([None for _ in range(size)]) @@ -494,6 +509,9 @@ def cat_(self, # partial keys will be padded by zeros # with the shape of [len, rest_shape] lens = [len(x) for x in batches] + sum_lens = [0] + for x in lens: + sum_lens.append(sum_lens[-1] + x) keys_map = list(map(lambda e: set(e.keys()), batches)) keys_shared = set.intersection(*keys_map) values_shared = [ @@ -512,46 +530,15 @@ def cat_(self, keys_partial = set.union(*keys_map) - keys_shared _assert_type_keys(keys_partial) for k in keys_partial: - is_dict = False - value = None for i, e in enumerate(batches): val = e.get(k, None) if val is not None: - if isinstance(val, (dict, Batch)): - is_dict = True - else: # np.ndarray or torch.Tensor - value = val - break - if is_dict: - self.__dict__[k] = Batch.cat( - [e.get(k, Batch()) for e in batches]) - else: - if isinstance(value, np.ndarray): - arrs = [] - for i, e in enumerate(batches): - v = e.get(k, None) - if v is None: - shape = [lens[i]] + 
list(value.shape[1:]) - v = np.zeros(shape, dtype=value.dtype) - arrs.append(v) - self.__dict__[k] = np.concatenate(arrs) - elif isinstance(value, torch.Tensor): - arrs = [] - for i, e in enumerate(batches): - v = e.get(k, None) - if v is None: - shape = [lens[i]] + list(value.shape[1:]) - v = torch.zeros( - shape, dtype=value.dtype, device=value.device) - arrs.append(v) - self.__dict__[k] = torch.cat(arrs) - else: - # here we do not consider scalar types, following the - # behavior of numpy which does not support concatenation - # of zero-dimensional arrays (scalars) - raise TypeError( - f"cannot cat value with type {type(value)}, we only " - "support dict, Batch, np.ndarray, and torch.Tensor") + try: + self.__dict__[k][sum_lens[i]:sum_lens[i+1]] = val + except KeyError: + self.__dict__[k] = \ + _create_value(val, sum_lens[-1], stack=False) + self.__dict__[k][sum_lens[i]:sum_lens[i+1]] = val @staticmethod def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch': @@ -601,52 +588,18 @@ def stack_(self, if not issubclass(v.dtype.type, (np.bool_, np.number)): v = v.astype(np.object) self.__dict__[k] = v - keys_partial = set.union(*keys_map) - keys_shared + keys_partial = set.difference(set.union(*keys_map), keys_shared) _assert_type_keys(keys_partial) for k in keys_partial: - is_dict = False - value = None for i, e in enumerate(batches): val = e.get(k, None) if val is not None: - if isinstance(val, (dict, Batch)): - is_dict = True - else: # np.ndarray / scalar or torch.Tensor - value = val - break - if is_dict: - self.__dict__[k] = Batch.stack( - [e.get(k, Batch()) for e in batches], axis) - else: - if isinstance(value, np.ndarray) \ - or issubclass(value.__class__, np.generic) \ - or isinstance(value, Number): - if isinstance(value, Number): - value = np.asarray(value) - arrs = [] - for i, e in enumerate(batches): - v = e.get(k, None) - if v is None: - v = np.zeros_like(value) - arrs.append(v) - v = np.stack(arrs, axis) - if not issubclass(v.dtype.type, (np.bool_, np.number)): - v = v.astype(np.object) - self.__dict__[k] = v - elif isinstance(value, torch.Tensor): - arrs = [] - for i, e in enumerate(batches): - v = e.get(k, None) - if v is None: - v = torch.zeros_like(value) - arrs.append(v) - self.__dict__[k] = torch.stack(arrs, axis) - else: # fallback to object - arrs = [] - for i, e in enumerate(batches): - v = e.get(k, None) - arrs.append(v) - self.__dict__[k] = np.array(arrs) + try: + self.__dict__[k][i] = val + except KeyError: + self.__dict__[k] = \ + _create_value(val, len(batches)) + self.__dict__[k][i] = val @staticmethod def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': From 6bf23c76e077d44ed13ef9b6834b178bae3cea6f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 12 Jul 2020 01:42:20 +0800 Subject: [PATCH 4/8] fix pep8 --- tianshou/data/batch.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 6976ea1d9..11abb4cda 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -31,8 +31,10 @@ def _create_value(inst: Any, size: int, stack=True) -> Union[ of (10, 3, 5), otherwise (10, 5) """ has_shape = isinstance(inst, (np.ndarray, torch.Tensor)) - is_scalar = isinstance(inst, Number) or \ - issubclass(inst.__class__, np.generic) or (has_shape and not inst.shape) + is_scalar = \ + isinstance(inst, Number) or \ + issubclass(inst.__class__, np.generic) or \ + (has_shape and not inst.shape) if not stack and is_scalar: # here we do not consider scalar types, following the # 
behavior of numpy which does not support concatenation From f434a26f4fd3740987c757f5ff1e809903788714 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sun, 12 Jul 2020 10:33:56 +0800 Subject: [PATCH 5/8] fix docs --- tianshou/data/batch.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 11abb4cda..ccc76573a 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -516,8 +516,7 @@ def cat_(self, sum_lens.append(sum_lens[-1] + x) keys_map = list(map(lambda e: set(e.keys()), batches)) keys_shared = set.intersection(*keys_map) - values_shared = [ - [e[k] for e in batches] for k in keys_shared] + values_shared = [[e[k] for e in batches] for k in keys_shared] _assert_type_keys(keys_shared) for k, v in zip(keys_shared, values_shared): if all(isinstance(e, (dict, Batch)) for e in v): @@ -577,8 +576,7 @@ def stack_(self, batches = [self] + list(batches) keys_map = list(map(lambda e: set(e.keys()), batches)) keys_shared = set.intersection(*keys_map) - values_shared = [ - [e[k] for e in batches] for k in keys_shared] + values_shared = [[e[k] for e in batches] for k in keys_shared] _assert_type_keys(keys_shared) for k, v in zip(keys_shared, values_shared): if all(isinstance(e, (dict, Batch)) for e in v): @@ -614,11 +612,11 @@ def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': >>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5]))) >>> c = Batch.stack([a, b]) >>> c.a.shape - (8, 4) + (2, 4, 4) >>> c.b.shape - (8, 6) + (2, 4, 6) >>> c.common.c.shape - (8, 5) + (2, 4, 5) """ batch = Batch() batch.stack_(batches, axis) From a78d4acd7a48ec0a3094fc59bcd52b7b420eea99 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 12 Jul 2020 11:12:50 +0800 Subject: [PATCH 6/8] raise exception for stacking with partial keys and axis!=0 --- tianshou/data/batch.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index ccc76573a..c4eb917f0 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -589,6 +589,10 @@ def stack_(self, v = v.astype(np.object) self.__dict__[k] = v keys_partial = set.difference(set.union(*keys_map), keys_shared) + if keys_partial and axis != 0: + raise ValueError( + f"Stack of Batch with non-shared keys {keys_partial}" + f" is only supported with axis=0, but got axis={axis}!") _assert_type_keys(keys_partial) for k in keys_partial: for i, e in enumerate(batches): @@ -617,6 +621,12 @@ def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': (2, 4, 6) >>> c.common.c.shape (2, 4, 5) + + .. note:: + + If there are keys that are not shared across all batches, + ``stack`` with ``axis!=0`` is undefined, and will cause an + exception. 
""" batch = Batch() batch.stack_(batches, axis) From 1f08c4b835f30e126e46373dbe04b57c3ae36eb1 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sun, 12 Jul 2020 11:30:50 +0800 Subject: [PATCH 7/8] minor fix --- tianshou/data/batch.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index c4eb917f0..08ed2c6ae 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -535,11 +535,11 @@ def cat_(self, val = e.get(k, None) if val is not None: try: - self.__dict__[k][sum_lens[i]:sum_lens[i+1]] = val + self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val except KeyError: self.__dict__[k] = \ _create_value(val, sum_lens[-1], stack=False) - self.__dict__[k][sum_lens[i]:sum_lens[i+1]] = val + self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val @staticmethod def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch': @@ -624,9 +624,8 @@ def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': .. note:: - If there are keys that are not shared across all batches, - ``stack`` with ``axis!=0`` is undefined, and will cause an - exception. + If there are keys that are not shared across all batches, ``stack`` + with ``axis != 0`` is undefined, and will cause an exception. """ batch = Batch() batch.stack_(batches, axis) From c1873904aa1835094b945179a05ca1d939b796c3 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sun, 12 Jul 2020 15:07:11 +0800 Subject: [PATCH 8/8] minor fix --- tianshou/data/batch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 08ed2c6ae..415c8f38a 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -591,8 +591,8 @@ def stack_(self, keys_partial = set.difference(set.union(*keys_map), keys_shared) if keys_partial and axis != 0: raise ValueError( - f"Stack of Batch with non-shared keys {keys_partial}" - f" is only supported with axis=0, but got axis={axis}!") + f"Stack of Batch with non-shared keys {keys_partial} " + f"is only supported with axis=0, but got axis={axis}!") _assert_type_keys(keys_partial) for k in keys_partial: for i, e in enumerate(batches):