From 3328252cf71ae309cb388a15257d7fcbf718536c Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 11 Jul 2020 11:06:22 +0800 Subject: [PATCH 01/13] minor polish --- tianshou/data/batch.py | 32 ++++++++------------------------ tianshou/data/buffer.py | 20 +++++++------------- 2 files changed, 15 insertions(+), 37 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 6b1051788..02282af77 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -259,8 +259,7 @@ def __init__(self, v_ = None if not isinstance(v, np.ndarray) and \ all(isinstance(e, torch.Tensor) for e in v): - v_ = torch.stack(v) - self.__dict__[k] = v_ + self.__dict__[k] = torch.stack(v) continue else: v_ = np.asanyarray(v) @@ -333,9 +332,8 @@ def __getitem__(self, index: Union[ else: raise IndexError("Cannot access item from empty Batch object.") - def __setitem__( - self, - index: Union[str, slice, int, np.integer, np.ndarray, List[int]], + def __setitem__(self, index: Union[ + str, slice, int, np.integer, np.ndarray, List[int]], value: Any) -> None: """Assign value to self[index].""" if isinstance(value, np.ndarray): @@ -454,10 +452,8 @@ def to_numpy(self) -> None: elif isinstance(v, Batch): v.to_numpy() - def to_torch(self, - dtype: Optional[torch.dtype] = None, - device: Union[str, int, torch.device] = 'cpu' - ) -> None: + def to_torch(self, dtype: Optional[torch.dtype] = None, + device: Union[str, int, torch.device] = 'cpu') -> None: """Change all numpy.ndarray to torch.Tensor. This is an in-place operation. """ @@ -473,28 +469,16 @@ def to_torch(self, v = v.type(dtype) self.__dict__[k] = v elif isinstance(v, torch.Tensor): - if dtype is not None and v.dtype != dtype: - must_update_tensor = True - elif v.device.type != device.type: - must_update_tensor = True - elif device.index is not None and \ + if dtype is not None and v.dtype != dtype or \ + v.device.type != device.type or \ + device.index is not None and \ device.index != v.device.index: - must_update_tensor = True - else: - must_update_tensor = False - if must_update_tensor: if dtype is not None: v = v.type(dtype) self.__dict__[k] = v.to(device) elif isinstance(v, Batch): v.to_torch(dtype, device) - def append(self, batch: 'Batch') -> None: - warnings.warn('Method :meth:`~tianshou.data.Batch.append` will be ' - 'removed soon, please use ' - ':meth:`~tianshou.data.Batch.cat`') - return self.cat_(batch) - def cat_(self, batch: 'Batch') -> None: """Concatenate a :class:`~tianshou.data.Batch` object into current batch. 
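Note on the to_torch hunk above: it collapses the old four-branch must_update_tensor flag into a single compound condition, relying on Python's "and" binding tighter than "or". A minimal sketch of that equivalence, with a/b/c/d/e as hypothetical stand-ins for the five comparisons in the hunk:

    import itertools

    # a: dtype is not None          b: v.dtype != dtype
    # c: v.device.type differs      d: device.index is not None
    # e: device.index != v.device.index
    for a, b, c, d, e in itertools.product([False, True], repeat=5):
        collapsed = a and b or c or d and e       # form used in the hunk
        grouped = (a and b) or c or (d and e)     # the intended grouping
        assert collapsed == grouped               # "and" binds tighter than "or"
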
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 33d31789d..803796078 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -108,8 +108,7 @@ def __init__(self, size: int, stack_num: Optional[int] = 0, super().__init__() self._maxsize = size self._stack = stack_num - assert stack_num != 1, \ - 'stack_num should greater than 1' + assert stack_num != 1, 'stack_num should greater than 1' self._avail = sample_avail and stack_num > 1 self._avail_index = [] self._save_s_ = not ignore_obs_next @@ -136,12 +135,11 @@ def _add_to_buffer(self, name: str, inst: Any) -> None: except KeyError: self._meta.__dict__[name] = _create_value(inst, self._maxsize) value = self._meta.__dict__[name] - if isinstance(inst, np.ndarray) and \ - value.shape[1:] != inst.shape: + if isinstance(inst, np.ndarray) and value.shape[1:] != inst.shape: raise ValueError( "Cannot add data to a buffer with different shape, key: " - f"{name}, expect shape: {value.shape[1:]}" - f", given shape: {inst.shape}.") + f"{name}, expect shape: {value.shape[1:]}, " + f"given shape: {inst.shape}.") try: value[self._index] = inst except KeyError: @@ -357,7 +355,7 @@ def __init__(self, size: int, alpha: float, beta: float, self._weight_sum = 0.0 self._amortization_freq = 50 self._replace = replace - self._meta.__dict__['weight'] = np.zeros(size, dtype=np.float64) + self._meta.weight = np.zeros(size, dtype=np.float64) def add(self, obs: Union[dict, np.ndarray], @@ -372,7 +370,7 @@ def add(self, """Add a batch of data into replay buffer.""" # we have to sacrifice some convenience for speed self._weight_sum += np.abs(weight) ** self._alpha - \ - self._meta.__dict__['weight'][self._index] + self._meta.weight[self._index] self._add_to_buffer('weight', np.abs(weight) ** self._alpha) super().add(obs, act, rew, done, obs_next, info, policy) @@ -410,14 +408,10 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: f"batch_size should be less than {len(self)}, \ or set replace=True") batch = self[indice] - impt_weight = Batch( - impt_weight=(self._size * p) ** (-self._beta)) + impt_weight = Batch(impt_weight=(self._size * p) ** (-self._beta)) batch.cat_(impt_weight) return batch, indice - def reset(self) -> None: - super().reset() - def update_weight(self, indice: Union[slice, np.ndarray], new_weight: np.ndarray) -> None: """Update priority weight by indice in this buffer. 
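Note on the PrioritizedReplayBuffer hunks above: add() keeps the running _weight_sum in sync by subtracting the priority about to be overwritten at the write index before storing the new one, so no full rescan of the weights is needed. A small sketch of that invariant on a plain circular array (the names, alpha=0.6 and size=4 here are illustrative, not the buffer's actual API):

    import numpy as np

    alpha, size = 0.6, 4
    weight = np.zeros(size)        # stands in for self._meta.weight
    weight_sum, index = 0.0, 0     # stands in for self._weight_sum / self._index

    for raw in [1.0, 2.0, 3.0, 4.0, 5.0]:      # the fifth write overwrites slot 0
        new_p = np.abs(raw) ** alpha
        weight_sum += new_p - weight[index]    # subtract the overwritten priority
        weight[index] = new_p
        index = (index + 1) % size

    assert np.isclose(weight_sum, weight.sum())   # running sum stays exact
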
From 6fe3ac553688c2ac1fb21d386b598b5f1c3c6389 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 Jul 2020 14:46:09 +0800 Subject: [PATCH 02/13] improve and implement Batch.cat_ --- test/base/test_batch.py | 24 +++++++-- tianshou/data/batch.py | 106 +++++++++++++++++++++++++++---------- tianshou/data/buffer.py | 2 +- tianshou/data/collector.py | 2 +- 4 files changed, 100 insertions(+), 34 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index a9f2cdd20..506b4dac4 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -19,7 +19,7 @@ def test_batch(): assert batch.obs == batch["obs"] batch.obs = [1] assert batch.obs == [1] - batch.cat_(batch) + batch.cat_([batch]) assert np.allclose(batch.obs, [1, 1]) assert batch.np.shape == (6, 4) assert np.allclose(batch[0].obs, batch[1].obs) @@ -96,13 +96,13 @@ def test_batch_over_batch(): for k, v in batch2.items(): assert np.all(batch2[k] == v) assert batch2[-1].b.b == 0 - batch2.cat_(Batch(c=[6, 7, 8], b=batch)) + batch2.cat_([Batch(c=[6, 7, 8], b=batch)]) assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8]) assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5]) assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0]) d = {'a': [3, 4, 5], 'b': [4, 5, 6]} batch3 = Batch(c=[6, 7, 8], b=d) - batch3.cat_(Batch(c=[6, 7, 8], b=d)) + batch3.cat_([Batch(c=[6, 7, 8], b=d)]) assert np.allclose(batch3.c, [6, 7, 8, 6, 7, 8]) assert np.allclose(batch3.b.a, [3, 4, 5, 3, 4, 5]) assert np.allclose(batch3.b.b, [4, 5, 6, 4, 5, 6]) @@ -124,18 +124,32 @@ def test_batch_over_batch(): def test_batch_cat_and_stack(): + # test cat with compatible keys b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}]) b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}]) - b12_cat_out = Batch.cat((b1, b2)) + b12_cat_out = Batch.cat([b1, b2]) b12_cat_in = copy.deepcopy(b1) - b12_cat_in.cat_(b2) + b12_cat_in.cat_([b2]) assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e) assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e) assert isinstance(b12_cat_in.a.d.e, np.ndarray) assert b12_cat_in.a.d.e.ndim == 1 + b12_stack = Batch.stack((b1, b2)) assert isinstance(b12_stack.a.d.e, np.ndarray) assert b12_stack.a.d.e.ndim == 2 + + # test batch with incompatible keys + b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5))) + b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) + test = Batch.cat([b1, b2]) + ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]), + b=torch.cat([torch.zeros(3, 3), b2.b]), + common=Batch(c=np.concatenate([b1.common.c, b2.common.c]))) + assert np.allclose(test.a, ans.a) + assert torch.allclose(test.b, ans.b) + assert np.allclose(test.common.c, ans.common.c) + b3 = Batch(a=np.zeros((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[1], [2]])) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 02282af77..a3f0ff1d4 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -479,44 +479,96 @@ def to_torch(self, dtype: Optional[torch.dtype] = None, elif isinstance(v, Batch): v.to_torch(dtype, device) - def cat_(self, batch: 'Batch') -> None: - """Concatenate a :class:`~tianshou.data.Batch` object into current - batch. + def cat_(self, batches: List[Union[dict, 'Batch']]) -> None: + """Concatenate a list of :class:`~tianshou.data.Batch` objects + into current batch. """ - assert isinstance(batch, Batch), \ - 'Only Batch is allowed to be concatenated in-place!' 
- for k, v in batch.items(): - if v is None: - continue - if not hasattr(self, k) or self.__dict__[k] is None: - self.__dict__[k] = deepcopy(v) - elif isinstance(v, np.ndarray) and v.ndim > 0: - self.__dict__[k] = np.concatenate([self.__dict__[k], v]) - elif isinstance(v, torch.Tensor): - self.__dict__[k] = torch.cat([self.__dict__[k], v]) - elif isinstance(v, Batch): - self.__dict__[k].cat_(v) + if len(batches) == 0: + return + batches = [x if isinstance(x, Batch) else Batch(x) for x in batches] + if len(self.__dict__) > 0: + batches = [self] + list(batches) + # partial keys will be padded by zeros + # with the shape of [len, rest_shape] + lens = [len(x) for x in batches] + keys_map = list(map(lambda e: set(e.keys()), batches)) + keys_shared = set.intersection(*keys_map) + values_shared = [ + [e[k] for e in batches] for k in keys_shared] + _assert_type_keys(keys_shared) + for k, v in zip(keys_shared, values_shared): + if all(isinstance(e, (dict, Batch)) for e in v): + self.__dict__[k] = Batch.cat(v) + elif all(isinstance(e, torch.Tensor) for e in v): + self.__dict__[k] = torch.cat(v) else: - s = 'No support for method "cat" with type '\ - f'{type(v)} in class Batch.' - raise TypeError(s) + v = np.concatenate(v) + if not issubclass(v.dtype.type, (np.bool_, np.number)): + v = v.astype(np.object) + self.__dict__[k] = v + keys_partial = set.union(*keys_map) - keys_shared + _assert_type_keys(keys_partial) + for k in keys_partial: + is_dict = False + value = None + for i, e in enumerate(batches): + val = e.get(k, None) + if val is not None: + if isinstance(val, (dict, Batch)): + is_dict = True + else: + # ndarray or torch.Tensor + value = val + break + if is_dict: + self.__dict__[k] = Batch.cat( + [e.get(k, Batch()) for e in batches]) + else: + if isinstance(value, np.ndarray): + arrs = [] + for i, e in enumerate(batches): + shape = [lens[i]] + list(value.shape[1:]) + pad = np.zeros(shape, dtype=value.dtype) + arrs.append(e.get(k, pad)) + self.__dict__[k] = np.concatenate(arrs) + elif isinstance(value, torch.Tensor): + arrs = [] + for i, e in enumerate(batches): + shape = [lens[i]] + list(value.shape[1:]) + pad = torch.zeros(shape, + dtype=value.dtype, + device=value.device) + arrs.append(e.get(k, pad)) + self.__dict__[k] = torch.cat(arrs) + else: + raise TypeError(f"cannot cat value with type " + f"{type(value)},we only support" + f" dict, Batch, np.ndarray, " + f"and torch.Tensor") @staticmethod def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch': """Concatenate a list of :class:`~tianshou.data.Batch` object into a single - new batch. + new batch. For keys that are not shared across all batches, + batches that do not have these keys will be padded by zeros with + appropriate shapes. + e.g. + a = Batch(a=shape(3,4), common=Batch(c=shape(3, 5))) + b = Batch(b=shape(4,3), common=Batch(c=shape(4, 5))) + c = Batch.cat([b, c]) + c.a.shape: (7, 4) + c.b.shape: (7, 3) + c.common.c.shape: (7, 5) """ batch = Batch() - for batch_ in batches: - if isinstance(batch_, dict): - batch_ = Batch(batch_) - batch.cat_(batch_) + batch.cat_(batches) return batch def stack_(self, batches: List[Union[dict, 'Batch']], axis: int = 0) -> None: - """Stack a :class:`~tianshou.data.Batch` object i into current batch. + """Stack a list of :class:`~tianshou.data.Batch` object + into current batch. 
""" if len(self.__dict__) > 0: batches = [self] + list(batches) @@ -550,8 +602,8 @@ def stack_(self, @staticmethod def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': - """Stack a :class:`~tianshou.data.Batch` object into a single new - batch. + """Stack a list of :class:`~tianshou.data.Batch` object + into a single new batch. """ batch = Batch() batch.stack_(batches, axis) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 803796078..3e418b71f 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -409,7 +409,7 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: or set replace=True") batch = self[indice] impt_weight = Batch(impt_weight=(self._size * p) ** (-self._beta)) - batch.cat_(impt_weight) + batch.cat_([impt_weight]) return batch, indice def update_weight(self, indice: Union[slice, np.ndarray], diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 3a7ad7821..bb7ca9596 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -426,7 +426,7 @@ def sample(self, batch_size: int) -> Batch: if batch_size and cur_batch or batch_size <= 0: batch, indice = b.sample(cur_batch) batch = self.process_fn(batch, b, indice) - batch_data.cat_(batch) + batch_data.cat_([batch]) else: batch_data, indice = self.buffer.sample(batch_size) batch_data = self.process_fn(batch_data, self.buffer, indice) From f34b0f8bc769b2965ee660041f158cfe1f3e2d2c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 Jul 2020 15:08:17 +0800 Subject: [PATCH 03/13] bugfix for buffer.sample with field impt_weight --- tianshou/data/buffer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 3e418b71f..f593d2a74 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -408,8 +408,7 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: f"batch_size should be less than {len(self)}, \ or set replace=True") batch = self[indice] - impt_weight = Batch(impt_weight=(self._size * p) ** (-self._beta)) - batch.cat_([impt_weight]) + batch["impt_weight"] = (self._size * p) ** (-self._beta) return batch, indice def update_weight(self, indice: Union[slice, np.ndarray], From eba3b366cb950ad749597b228a7598df3c079f38 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 Jul 2020 18:28:36 +0800 Subject: [PATCH 04/13] restore the usage of a.cat_(b) --- test/base/test_batch.py | 8 ++++---- tianshou/data/batch.py | 9 ++++++--- tianshou/data/collector.py | 2 +- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 506b4dac4..952f44c0d 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -19,7 +19,7 @@ def test_batch(): assert batch.obs == batch["obs"] batch.obs = [1] assert batch.obs == [1] - batch.cat_([batch]) + batch.cat_(batch) assert np.allclose(batch.obs, [1, 1]) assert batch.np.shape == (6, 4) assert np.allclose(batch[0].obs, batch[1].obs) @@ -96,13 +96,13 @@ def test_batch_over_batch(): for k, v in batch2.items(): assert np.all(batch2[k] == v) assert batch2[-1].b.b == 0 - batch2.cat_([Batch(c=[6, 7, 8], b=batch)]) + batch2.cat_(Batch(c=[6, 7, 8], b=batch)) assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8]) assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5]) assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0]) d = {'a': [3, 4, 5], 'b': [4, 5, 6]} batch3 = Batch(c=[6, 7, 8], b=d) - batch3.cat_([Batch(c=[6, 7, 8], b=d)]) + batch3.cat_(Batch(c=[6, 7, 8], b=d)) 
assert np.allclose(batch3.c, [6, 7, 8, 6, 7, 8]) assert np.allclose(batch3.b.a, [3, 4, 5, 3, 4, 5]) assert np.allclose(batch3.b.b, [4, 5, 6, 4, 5, 6]) @@ -129,7 +129,7 @@ def test_batch_cat_and_stack(): b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}]) b12_cat_out = Batch.cat([b1, b2]) b12_cat_in = copy.deepcopy(b1) - b12_cat_in.cat_([b2]) + b12_cat_in.cat_(b2) assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e) assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e) assert isinstance(b12_cat_in.a.d.e, np.ndarray) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index a3f0ff1d4..75794dea5 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -479,10 +479,13 @@ def to_torch(self, dtype: Optional[torch.dtype] = None, elif isinstance(v, Batch): v.to_torch(dtype, device) - def cat_(self, batches: List[Union[dict, 'Batch']]) -> None: - """Concatenate a list of :class:`~tianshou.data.Batch` objects - into current batch. + def cat_(self, + batches: Union['Batch', List[Union[dict, 'Batch']]]) -> None: + """Concatenate a list of (or one) :class:`~tianshou.data.Batch` + objects into current batch. """ + if isinstance(batches, Batch): + batches = [batches] if len(batches) == 0: return batches = [x if isinstance(x, Batch) else Batch(x) for x in batches] diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index bb7ca9596..3a7ad7821 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -426,7 +426,7 @@ def sample(self, batch_size: int) -> Batch: if batch_size and cur_batch or batch_size <= 0: batch, indice = b.sample(cur_batch) batch = self.process_fn(batch, b, indice) - batch_data.cat_([batch]) + batch_data.cat_(batch) else: batch_data, indice = self.buffer.sample(batch_size) batch_data = self.process_fn(batch_data, self.buffer, indice) From 42ee76d5285adceb1131ad0e4cc102f6960716b1 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 11 Jul 2020 19:56:27 +0800 Subject: [PATCH 05/13] fix 2 bugs in batch and add corresponding unittest --- test/base/test_batch.py | 16 ++++++++++++++++ tianshou/data/batch.py | 17 ++++++++++++++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 952f44c0d..e7a845a3b 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -86,6 +86,18 @@ def test_batch(): assert batch3.a.d.f[0] == 5.0 with pytest.raises(KeyError): batch3.a.d[0] = Batch(f=5.0, g=0.0) + # auto convert + batch4 = Batch(a=np.array(['a', 'b'])) + assert batch4.a.dtype == np.object # auto convert to np.object + batch4.update(a=np.array(['c', 'd'])) + assert list(batch4.a) == ['c', 'd'] + assert batch4.a.dtype == np.object # auto convert to np.object + batch5 = Batch(a=np.array([{'index': 0}])) + assert isinstance(batch5.a, Batch) + assert np.allclose(batch5.a.index, [0]) + batch5.b = np.array([{'index': 1}]) + assert isinstance(batch5.b, Batch) + assert np.allclose(batch5.b.index, [1]) def test_batch_over_batch(): @@ -100,6 +112,10 @@ def test_batch_over_batch(): assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8]) assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5]) assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0]) + batch2.update(batch2.b) + assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8]) + assert np.allclose(batch2.a, [3, 4, 5, 3, 4, 5]) + assert np.allclose(batch2.b, [4, 5, 0, 4, 5, 0]) d = {'a': [3, 4, 5], 'b': [4, 5, 6]} batch3 = Batch(c=[6, 7, 8], b=d) batch3.cat_(Batch(c=[6, 7, 8], b=d)) diff --git a/tianshou/data/batch.py 
b/tianshou/data/batch.py index 75794dea5..b861c3c29 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -293,7 +293,8 @@ def __setattr__(self, key: str, value: Any): value = np.array(value) if not issubclass(value.dtype.type, (np.bool_, np.number)): value = value.astype(np.object) - elif isinstance(value, dict): + elif isinstance(value, dict) or isinstance(value, np.ndarray) \ + and value.dtype == np.object and _is_batch_set(value): value = Batch(value) self.__dict__[key] = value @@ -650,6 +651,20 @@ def empty(batch: 'Batch', index: Union[ """ return deepcopy(batch).empty_(index) + def update(self, batch: Optional[Union[dict, 'Batch']] = None, + **kwargs) -> None: + """Update this batch from another dict/Batch.""" + if isinstance(batch, dict): + batch = Batch(batch) + if batch is not None: + for k, v in batch.items(): + self.__dict__[k] = v + if kwargs is not None: + batch = Batch(kwargs) + if batch is not None: + for k, v in batch.items(): + self.__dict__[k] = v + def __len__(self) -> int: """Return len(self).""" r = [] From 7e42b76c063dcca26498228d538b908a5d34b913 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 Jul 2020 20:13:39 +0800 Subject: [PATCH 06/13] code fix for update --- tianshou/data/batch.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index b861c3c29..14489b2be 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -654,16 +654,13 @@ def empty(batch: 'Batch', index: Union[ def update(self, batch: Optional[Union[dict, 'Batch']] = None, **kwargs) -> None: """Update this batch from another dict/Batch.""" + if batch is None: + self.update(kwargs) + return if isinstance(batch, dict): batch = Batch(batch) - if batch is not None: - for k, v in batch.items(): - self.__dict__[k] = v - if kwargs is not None: - batch = Batch(kwargs) - if batch is not None: - for k, v in batch.items(): - self.__dict__[k] = v + for k, v in batch.items(): + self.__dict__[k] = v def __len__(self) -> int: """Return len(self).""" From d6d0a20eae78bb05095e2455a2044664bd241dd9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 Jul 2020 20:20:59 +0800 Subject: [PATCH 07/13] update is_empty to recognize empty over empty; bugfix for len --- test/base/test_batch.py | 2 ++ tianshou/data/batch.py | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index e7a845a3b..8719758ef 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -10,6 +10,8 @@ def test_batch(): assert list(Batch()) == [] assert Batch().is_empty() + assert Batch(b={'c': {}}).is_empty() + assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3 assert not Batch(a=[1, 2, 3]).is_empty() with pytest.raises(AssertionError): Batch({1: 2}) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 14489b2be..923669c5d 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -666,7 +666,7 @@ def __len__(self) -> int: """Return len(self).""" r = [] for v in self.__dict__.values(): - if isinstance(v, Batch) and len(v.__dict__) == 0: + if isinstance(v, Batch) and v.is_empty(): continue elif hasattr(v, '__len__') and (not isinstance( v, (np.ndarray, torch.Tensor)) or v.ndim > 0): @@ -678,7 +678,9 @@ def __len__(self) -> int: return min(r) def is_empty(self): - return len(self.__dict__.keys()) == 0 + return not any( + not x.is_empty() if isinstance(x, Batch) + else hasattr(x, '__len__') and len(x) > 0 for x in self.values()) @property def 
shape(self) -> List[int]: From 988f0f367d2b8d27be2a5d0a562f5ac1bbbf9857 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 Jul 2020 20:35:25 +0800 Subject: [PATCH 08/13] bugfix for update and add testcase --- test/base/test_batch.py | 8 +++++++- tianshou/data/batch.py | 2 ++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 8719758ef..325105c40 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -13,6 +13,11 @@ def test_batch(): assert Batch(b={'c': {}}).is_empty() assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3 assert not Batch(a=[1, 2, 3]).is_empty() + b = Batch() + b.update() + assert b.is_empty() + b.update(c=[3, 5]) + assert np.allclose(b.c, [3, 5]) with pytest.raises(AssertionError): Batch({1: 2}) batch = Batch(a=[torch.ones(3), torch.ones(3)]) @@ -114,10 +119,11 @@ def test_batch_over_batch(): assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8]) assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5]) assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0]) - batch2.update(batch2.b) + batch2.update(batch2.b, six=[6, 6, 6]) assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8]) assert np.allclose(batch2.a, [3, 4, 5, 3, 4, 5]) assert np.allclose(batch2.b, [4, 5, 0, 4, 5, 0]) + assert np.allclose(batch2.six, [6, 6, 6]) d = {'a': [3, 4, 5], 'b': [4, 5, 6]} batch3 = Batch(c=[6, 7, 8], b=d) batch3.cat_(Batch(c=[6, 7, 8], b=d)) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 923669c5d..bab3803c3 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -661,6 +661,8 @@ def update(self, batch: Optional[Union[dict, 'Batch']] = None, batch = Batch(batch) for k, v in batch.items(): self.__dict__[k] = v + if kwargs: + self.update(kwargs) def __len__(self) -> int: """Return len(self).""" From 1c27de85fd6577fca7f53a0dc6894d56ba2db6b4 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 11 Jul 2020 20:41:28 +0800 Subject: [PATCH 09/13] add testcase of update --- test/base/test_batch.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 325105c40..03031ff2e 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -18,6 +18,9 @@ def test_batch(): assert b.is_empty() b.update(c=[3, 5]) assert np.allclose(b.c, [3, 5]) + # mimic the behavior of dict.update, where kwargs can overwrite keys + b.update({'a': 2}, a=3) + assert b.a == 3 with pytest.raises(AssertionError): Batch({1: 2}) batch = Batch(a=[torch.ones(3), torch.ones(3)]) From ae5769b543bc5396d8002b6ae91dafbb8338ab5f Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 11 Jul 2020 21:03:26 +0800 Subject: [PATCH 10/13] fix docs --- tianshou/data/batch.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index bab3803c3..825be854a 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -482,8 +482,8 @@ def to_torch(self, dtype: Optional[torch.dtype] = None, def cat_(self, batches: Union['Batch', List[Union[dict, 'Batch']]]) -> None: - """Concatenate a list of (or one) :class:`~tianshou.data.Batch` - objects into current batch. + """Concatenate a list of (or one) :class:`~tianshou.data.Batch` objects + into current batch. 
""" if isinstance(batches, Batch): batches = [batches] @@ -521,7 +521,7 @@ def cat_(self, if isinstance(val, (dict, Batch)): is_dict = True else: - # ndarray or torch.Tensor + # np.ndarray or torch.Tensor value = val break if is_dict: @@ -552,17 +552,21 @@ def cat_(self, @staticmethod def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch': - """Concatenate a list of :class:`~tianshou.data.Batch` object into a single - new batch. For keys that are not shared across all batches, + """Concatenate a list of :class:`~tianshou.data.Batch` object into a + single new batch. For keys that are not shared across all batches, batches that do not have these keys will be padded by zeros with - appropriate shapes. - e.g. - a = Batch(a=shape(3,4), common=Batch(c=shape(3, 5))) - b = Batch(b=shape(4,3), common=Batch(c=shape(4, 5))) - c = Batch.cat([b, c]) - c.a.shape: (7, 4) - c.b.shape: (7, 3) - c.common.c.shape: (7, 5) + appropriate shapes. E.g. + :: + + >>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5]))) + >>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5]))) + >>> c = Batch.cat([a, b]) + >>> c.a.shape + (7, 4) + >>> c.b.shape + (7, 3) + >>> c.common.c.shape + (7, 5) """ batch = Batch() batch.cat_(batches) @@ -571,8 +575,8 @@ def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch': def stack_(self, batches: List[Union[dict, 'Batch']], axis: int = 0) -> None: - """Stack a list of :class:`~tianshou.data.Batch` object - into current batch. + """Stack a list of :class:`~tianshou.data.Batch` object into current + batch. """ if len(self.__dict__) > 0: batches = [self] + list(batches) From 25750788195ccf6ef6a5af438d0a7b0e88991526 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 11 Jul 2020 21:11:23 +0800 Subject: [PATCH 11/13] fix docs --- tianshou/data/batch.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 825be854a..8933c3a2d 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -520,8 +520,7 @@ def cat_(self, if val is not None: if isinstance(val, (dict, Batch)): is_dict = True - else: - # np.ndarray or torch.Tensor + else: # np.ndarray or torch.Tensor value = val break if is_dict: From 84160b4723c03e9f85942b8a1d1852a889406468 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 11 Jul 2020 21:14:14 +0800 Subject: [PATCH 12/13] fix docs [ci skip] --- tianshou/data/batch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 8933c3a2d..1d390be37 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -609,8 +609,8 @@ def stack_(self, @staticmethod def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': - """Stack a list of :class:`~tianshou.data.Batch` object - into a single new batch. + """Stack a list of :class:`~tianshou.data.Batch` object into a single + new batch. 
""" batch = Batch() batch.stack_(batches, axis) From d5a366fb0e51b2d0e274dc87da90b7d571815da3 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Sat, 11 Jul 2020 21:19:17 +0800 Subject: [PATCH 13/13] fix docs [ci skip] --- .github/workflows/pytest.yml | 1 + tianshou/data/batch.py | 7 +++---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index eb14a875e..c1e3604ed 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -5,6 +5,7 @@ on: [push, pull_request] jobs: build: runs-on: ubuntu-latest + if: "!contains(github.event.head_commit.message, 'ci skip')" strategy: matrix: python-version: [3.6, 3.7, 3.8] diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 1d390be37..1240cfe50 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -544,10 +544,9 @@ def cat_(self, arrs.append(e.get(k, pad)) self.__dict__[k] = torch.cat(arrs) else: - raise TypeError(f"cannot cat value with type " - f"{type(value)},we only support" - f" dict, Batch, np.ndarray, " - f"and torch.Tensor") + raise TypeError( + f"cannot cat value with type {type(value)}, we only " + "support dict, Batch, np.ndarray, and torch.Tensor") @staticmethod def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch':