diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 03031ff2e..7e287e269 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -166,7 +166,7 @@ def test_batch_cat_and_stack(): assert isinstance(b12_stack.a.d.e, np.ndarray) assert b12_stack.a.d.e.ndim == 2 - # test batch with incompatible keys + # test cat with incompatible keys b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5))) b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) test = Batch.cat([b1, b2]) @@ -177,6 +177,7 @@ def test_batch_cat_and_stack(): assert torch.allclose(test.b, ans.b) assert np.allclose(test.common.c, ans.common.c) + # test stack with compatible keys b3 = Batch(a=np.zeros((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[1], [2]])) @@ -194,6 +195,26 @@ def test_batch_cat_and_stack(): assert b5.b.d[0] == b5_dict[0]['b']['d'] assert b5.b.d[1] == 0.0 + # test stack with incompatible keys + a = Batch(a=1, b=2, c=3) + b = Batch(a=4, b=5, d=6) + c = Batch(c=7, b=6, d=9) + d = Batch.stack([a, b, c]) + assert np.allclose(d.a, [1, 4, 0]) + assert np.allclose(d.b, [2, 5, 6]) + assert np.allclose(d.c, [3, 0, 7]) + assert np.allclose(d.d, [0, 6, 9]) + + b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5))) + b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5))) + test = Batch.stack([b1, b2]) + ans = Batch(a=np.stack([b1.a, np.zeros((4, 4))]), + b=torch.stack([torch.zeros(4, 6), b2.b]), + common=Batch(c=np.stack([b1.common.c, b2.common.c]))) + assert np.allclose(test.a, ans.a) + assert torch.allclose(test.b, ans.b) + assert np.allclose(test.common.c, ans.common.c) + def test_batch_over_batch_to_torch(): batch = Batch( diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 1240cfe50..415c8f38a 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -3,7 +3,6 @@ import warnings import numpy as np from copy import deepcopy -from functools import reduce from numbers import Number from 
typing import Any, List, Tuple, Union, Iterator, Optional @@ -24,28 +23,45 @@ def _is_batch_set(data: Any) -> bool: return False -def _create_value(inst: Any, size: int) -> Union[ +def _create_value(inst: Any, size: int, stack=True) -> Union[ 'Batch', np.ndarray, torch.Tensor]: + """ + :param bool stack: whether to stack or to concatenate. E.g. if inst has + shape of (3, 5), size = 10, stack=True returns an np.ndarray with shape + of (10, 3, 5), otherwise (10, 5) + """ + has_shape = isinstance(inst, (np.ndarray, torch.Tensor)) + is_scalar = \ + isinstance(inst, Number) or \ + issubclass(inst.__class__, np.generic) or \ + (has_shape and not inst.shape) + if not stack and is_scalar: + # here we do not consider scalar types, following the + # behavior of numpy which does not support concatenation + # of zero-dimensional arrays (scalars) + raise TypeError(f"cannot cat {inst} with which is scalar") + if has_shape: + shape = (size, *inst.shape) if stack else (size, *inst.shape[1:]) if isinstance(inst, np.ndarray): if issubclass(inst.dtype.type, (np.bool_, np.number)): target_type = inst.dtype.type else: target_type = np.object - return np.full((size, *inst.shape), + return np.full(shape, fill_value=None if target_type == np.object else 0, dtype=target_type) elif isinstance(inst, torch.Tensor): - return torch.full((size, *inst.shape), + return torch.full(shape, fill_value=0, device=inst.device, dtype=inst.dtype) elif isinstance(inst, (dict, Batch)): zero_batch = Batch() for key, val in inst.items(): - zero_batch.__dict__[key] = _create_value(val, size) + zero_batch.__dict__[key] = _create_value(val, size, stack=stack) return zero_batch - elif isinstance(inst, (np.generic, Number)): - return _create_value(np.asarray(inst), size) + elif is_scalar: + return _create_value(np.asarray(inst), size, stack=stack) else: # fall back to np.object return np.array([None for _ in range(size)]) @@ -495,10 +511,12 @@ def cat_(self, # partial keys will be padded by zeros # with the shape of 
[len, rest_shape] lens = [len(x) for x in batches] + sum_lens = [0] + for x in lens: + sum_lens.append(sum_lens[-1] + x) keys_map = list(map(lambda e: set(e.keys()), batches)) keys_shared = set.intersection(*keys_map) - values_shared = [ - [e[k] for e in batches] for k in keys_shared] + values_shared = [[e[k] for e in batches] for k in keys_shared] _assert_type_keys(keys_shared) for k, v in zip(keys_shared, values_shared): if all(isinstance(e, (dict, Batch)) for e in v): @@ -513,40 +531,15 @@ def cat_(self, keys_partial = set.union(*keys_map) - keys_shared _assert_type_keys(keys_partial) for k in keys_partial: - is_dict = False - value = None for i, e in enumerate(batches): val = e.get(k, None) if val is not None: - if isinstance(val, (dict, Batch)): - is_dict = True - else: # np.ndarray or torch.Tensor - value = val - break - if is_dict: - self.__dict__[k] = Batch.cat( - [e.get(k, Batch()) for e in batches]) - else: - if isinstance(value, np.ndarray): - arrs = [] - for i, e in enumerate(batches): - shape = [lens[i]] + list(value.shape[1:]) - pad = np.zeros(shape, dtype=value.dtype) - arrs.append(e.get(k, pad)) - self.__dict__[k] = np.concatenate(arrs) - elif isinstance(value, torch.Tensor): - arrs = [] - for i, e in enumerate(batches): - shape = [lens[i]] + list(value.shape[1:]) - pad = torch.zeros(shape, - dtype=value.dtype, - device=value.device) - arrs.append(e.get(k, pad)) - self.__dict__[k] = torch.cat(arrs) - else: - raise TypeError( - f"cannot cat value with type {type(value)}, we only " - "support dict, Batch, np.ndarray, and torch.Tensor") + try: + self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val + except KeyError: + self.__dict__[k] = \ + _create_value(val, sum_lens[-1], stack=False) + self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val @staticmethod def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch': @@ -576,12 +569,14 @@ def stack_(self, """Stack a list of :class:`~tianshou.data.Batch` object into current batch. 
""" + if len(batches) == 0: + return + batches = [x if isinstance(x, Batch) else Batch(x) for x in batches] if len(self.__dict__) > 0: batches = [self] + list(batches) keys_map = list(map(lambda e: set(e.keys()), batches)) keys_shared = set.intersection(*keys_map) - values_shared = [ - [e[k] for e in batches] for k in keys_shared] + values_shared = [[e[k] for e in batches] for k in keys_shared] _assert_type_keys(keys_shared) for k, v in zip(keys_shared, values_shared): if all(isinstance(e, (dict, Batch)) for e in v): @@ -593,7 +588,11 @@ def stack_(self, if not issubclass(v.dtype.type, (np.bool_, np.number)): v = v.astype(np.object) self.__dict__[k] = v - keys_partial = reduce(set.symmetric_difference, keys_map) + keys_partial = set.difference(set.union(*keys_map), keys_shared) + if keys_partial and axis != 0: + raise ValueError( + f"Stack of Batch with non-shared keys {keys_partial} " + f"is only supported with axis=0, but got axis={axis}!") _assert_type_keys(keys_partial) for k in keys_partial: for i, e in enumerate(batches): @@ -609,7 +608,24 @@ def stack_(self, @staticmethod def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': """Stack a list of :class:`~tianshou.data.Batch` object into a single - new batch. + new batch. For keys that are not shared across all batches, + batches that do not have these keys will be padded by zeros. E.g. + :: + + >>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5]))) + >>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5]))) + >>> c = Batch.stack([a, b]) + >>> c.a.shape + (2, 4, 4) + >>> c.b.shape + (2, 4, 6) + >>> c.common.c.shape + (2, 4, 5) + + .. note:: + + If there are keys that are not shared across all batches, ``stack`` + with ``axis != 0`` is undefined, and will cause an exception. """ batch = Batch() batch.stack_(batches, axis)