From 7f582577f0c0cb794bc9723248a5a695531d4fe2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 10 Jul 2020 22:33:36 +0800 Subject: [PATCH 1/3] make sure the key type of Batch is string, and add unit tests --- test/base/test_batch.py | 2 ++ tianshou/data/batch.py | 9 +++++++++ 2 files changed, 11 insertions(+) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index e2390dec7..063fa8c3f 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -9,6 +9,8 @@ def test_batch(): assert list(Batch()) == [] + with pytest.raises(TypeError): + Batch({1: 2}) batch = Batch(a=[torch.ones(3), torch.ones(3)]) assert torch.allclose(batch.a, torch.ones(2, 3)) batch = Batch(obs=[0], np=np.zeros([3, 4])) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index e14b98f50..2e102cd5b 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -50,6 +50,12 @@ def _create_value(inst: Any, size: int) -> Union[ return np.array([None for _ in range(size)]) +def _assert_type_keys(keys) -> None: + keys = list(keys) + if not all(isinstance(e, str) for e in keys): + raise TypeError(f"keys should all be string, but got {keys}") + + class Batch: """Tianshou provides :class:`~tianshou.data.Batch` as the internal data structure to pass any kind of data to other methods, for example, a @@ -247,6 +253,7 @@ def __init__(self, batch_dict = deepcopy(batch_dict) if batch_dict is not None: if isinstance(batch_dict, (dict, Batch)): + _assert_type_keys(batch_dict.keys()) for k, v in batch_dict.items(): if isinstance(v, (list, tuple, np.ndarray)): v_ = None @@ -531,6 +538,7 @@ def stack_(self, keys_shared = set.intersection(*keys_map) values_shared = [ [e[k] for e in batches] for k in keys_shared] + _assert_type_keys(keys_shared) for k, v in zip(keys_shared, values_shared): if all(isinstance(e, (dict, Batch)) for e in v): self.__dict__[k] = Batch.stack(v, axis) @@ -542,6 +550,7 @@ def stack_(self, v = v.astype(np.object) self.__dict__[k] = v keys_partial =
reduce(set.symmetric_difference, keys_map) + _assert_type_keys(keys_partial) for k in keys_partial: for i, e in enumerate(batches): val = e.get(k, None) From 4129116c86e23bc16922e71c13122c7cc66b965b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 10 Jul 2020 23:22:28 +0800 Subject: [PATCH 2/3] add is_empty() function and unit tests --- test/base/test_batch.py | 2 ++ tianshou/data/batch.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 063fa8c3f..a9f2cdd20 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -9,6 +9,8 @@ def test_batch(): assert list(Batch()) == [] + assert Batch().is_empty() + assert not Batch(a=[1, 2, 3]).is_empty() with pytest.raises(TypeError): Batch({1: 2}) batch = Batch(a=[torch.ones(3), torch.ones(3)]) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 2e102cd5b..f31c9db10 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -624,6 +624,10 @@ def __len__(self) -> int: raise TypeError("Object of type 'Batch' has no len()") return min(r) + def is_empty(self) -> bool: + """Return whether this Batch holds no keys.""" + return len(self.__dict__) == 0 + @property def shape(self) -> List[int]: """Return self.shape.""" From 31f6da7debc1d3c3ce3c6d90b13d0469db12e3e6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 Jul 2020 00:56:08 +0800 Subject: [PATCH 3/3] enable cat of mixing dict and Batch, just like stack --- tianshou/data/batch.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index f31c9db10..6b1051788 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -518,12 +518,14 @@ def cat_(self, batch: 'Batch') -> None: raise TypeError(s) @staticmethod - def cat(batches: List['Batch']) -> 'Batch': - """Concatenate a :class:`~tianshou.data.Batch` object into a single + def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch': + """Concatenate a list of :class:`~tianshou.data.Batch` objects
into a single new batch. """ batch = Batch() for batch_ in batches: + if isinstance(batch_, dict): + batch_ = Batch(batch_) batch.cat_(batch_) return batch @@ -563,7 +565,7 @@ def stack_(self, self.__dict__[k][i] = val @staticmethod - def stack(batches: List['Batch'], axis: int = 0) -> 'Batch': + def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': """Stack a :class:`~tianshou.data.Batch` object into a single new batch. """