Review notes (free text ahead of the first "diff --git" header):
  * The patch was re-flowed from a whitespace-collapsed form into standard
    unified-diff layout. Hunk line counts match the @@ headers; indentation
    and blank context lines were reconstructed from the Python structure and
    should be confirmed against the target files.
  * tianshou/data/collector.py, _reset_state: the NumPy branch called
    isinstance() with a single argument (TypeError at runtime); it is now a
    plain dtype comparison.
  * tianshou/data/collector.py, _reset_state: the torch branch zeroed
    self.state[id] via .zero_(), which has no effect when id is a list
    (list indexing yields a copy); it now assigns 0 through __setitem__.
  * The collector.py "index" post-image hash predates these two fixes.

diff --git a/test/base/test_batch.py b/test/base/test_batch.py
index 5f4772399..e91d16ab7 100644
--- a/test/base/test_batch.py
+++ b/test/base/test_batch.py
@@ -116,7 +116,7 @@ def test_batch_over_batch():
     assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3)
 
 
-def test_batch_cat_and_stack_and_empty():
+def test_batch_cat_and_stack():
     b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
     b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
     b12_cat_out = Batch.cat((b1, b2))
@@ -145,24 +145,6 @@ def test_batch_cat_and_stack_and_empty():
     assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
     assert b5.b.d[0] == b5_dict[0]['b']['d']
     assert b5.b.d[1] == 0.0
-    b5[1] = Batch.empty(b5[0])
-    assert np.allclose(b5.a, [False, False])
-    assert np.allclose(b5.b.c, [2, 0])
-    assert np.allclose(b5.b.d, [1, 0])
-    data = Batch(a=[False, True],
-                 b={'c': [2., 'st'], 'd': [1, None], 'e': [2., float('nan')]},
-                 c=np.array([1, 3, 4], dtype=np.int),
-                 t=torch.tensor([4, 5, 6, 7.]))
-    data[-1] = Batch.empty(data[1])
-    assert np.allclose(data.c, [1, 3, 0])
-    assert np.allclose(data.a, [False, False])
-    assert list(data.b.c) == ['2.0', '']
-    assert list(data.b.d) == [1, None]
-    assert np.allclose(data.b.e, [2, 0])
-    assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
-    b0 = Batch()
-    b0.empty_()
-    assert b0.shape == []
 
 
 def test_batch_over_batch_to_torch():
@@ -225,6 +207,71 @@ def test_batch_from_to_numpy_without_copy():
     assert c_mem_addr_new == c_mem_addr_orig
 
 
+def test_batch_copy():
+    batch = Batch(a=np.array([3, 4, 5]), b=np.array([4, 5, 6]))
+    batch2 = Batch({'c': np.array([6, 7, 8]), 'b': batch})
+    orig_c_addr = batch2.c.__array_interface__['data'][0]
+    orig_b_a_addr = batch2.b.a.__array_interface__['data'][0]
+    orig_b_b_addr = batch2.b.b.__array_interface__['data'][0]
+    # test with copy=False
+    batch3 = Batch(copy=False, **batch2)
+    curr_c_addr = batch3.c.__array_interface__['data'][0]
+    curr_b_a_addr = batch3.b.a.__array_interface__['data'][0]
+    curr_b_b_addr = batch3.b.b.__array_interface__['data'][0]
+    assert batch2.c is batch3.c
+    assert batch2.b is batch3.b
+    assert batch2.b.a is batch3.b.a
+    assert batch2.b.b is batch3.b.b
+    assert orig_c_addr == curr_c_addr
+    assert orig_b_a_addr == curr_b_a_addr
+    assert orig_b_b_addr == curr_b_b_addr
+    # test with copy=True
+    batch3 = Batch(copy=True, **batch2)
+    curr_c_addr = batch3.c.__array_interface__['data'][0]
+    curr_b_a_addr = batch3.b.a.__array_interface__['data'][0]
+    curr_b_b_addr = batch3.b.b.__array_interface__['data'][0]
+    assert batch2.c is not batch3.c
+    assert batch2.b is not batch3.b
+    assert batch2.b.a is not batch3.b.a
+    assert batch2.b.b is not batch3.b.b
+    assert orig_c_addr != curr_c_addr
+    assert orig_b_a_addr != curr_b_a_addr
+    assert orig_b_b_addr != curr_b_b_addr
+
+
+def test_batch_empty():
+    b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
+                        {'a': True, 'b': {'c': 3.0}}])
+    b5 = Batch(b5_dict)
+    b5[1] = Batch.empty(b5[0])
+    assert np.allclose(b5.a, [False, False])
+    assert np.allclose(b5.b.c, [2, 0])
+    assert np.allclose(b5.b.d, [1, 0])
+    data = Batch(a=[False, True],
+                 b={'c': np.array([2., 'st'], dtype=np.object),
+                    'd': [1, None],
+                    'e': [2., float('nan')]},
+                 c=np.array([1, 3, 4], dtype=np.int),
+                 t=torch.tensor([4, 5, 6, 7.]))
+    data[-1] = Batch.empty(data[1])
+    assert np.allclose(data.c, [1, 3, 0])
+    assert np.allclose(data.a, [False, False])
+    assert list(data.b.c) == [2.0, None]
+    assert list(data.b.d) == [1, None]
+    assert np.allclose(data.b.e, [2, 0])
+    assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
+    data[0].empty_()  # which will fail in a, b.c, b.d, b.e, c
+    assert torch.allclose(data.t, torch.tensor([0., 5, 6, 0]))
+    data.empty_(index=0)
+    assert np.allclose(data.c, [0, 3, 0])
+    assert list(data.b.c) == [None, None]
+    assert list(data.b.d) == [None, None]
+    assert list(data.b.e) == [0, 0]
+    b0 = Batch()
+    b0.empty_()
+    assert b0.shape == []
+
+
 def test_batch_numpy_compatibility():
     batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]),
                   b=Batch(),
@@ -246,4 +293,6 @@ def test_batch_numpy_compatibility():
     test_batch_pickle()
     test_batch_from_to_numpy_without_copy()
     test_batch_numpy_compatibility()
-    test_batch_cat_and_stack_and_empty()
+    test_batch_cat_and_stack()
+    test_batch_copy()
+    test_batch_empty()
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index bbf1c6a94..b3f1f8a74 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -1,8 +1,8 @@
 import torch
-import copy
 import pprint
 import warnings
 import numpy as np
+from copy import deepcopy
 from functools import reduce
 from numbers import Number
 from typing import Any, List, Tuple, Union, Iterator, Optional
@@ -85,8 +85,13 @@ class Batch:
             c: '2312312',
         )
 
-    In short, you can define a :class:`Batch` with any key-value pair. The
-    current implementation of Tianshou typically use 7 reserved keys in
+    In short, you can define a :class:`Batch` with any key-value pair.
+
+    For Numpy arrays, only data types with ``np.object`` and numbers are
+    supported. For strings or other data types, however, they can be held
+    in ``np.object`` arrays.
+
+    The current implementation of Tianshou typically use 7 reserved keys in
     :class:`~tianshou.data.Batch`:
 
     * ``obs`` the observation of step :math:`t` ;
@@ -252,7 +257,10 @@ def __init__(self,
                  batch_dict: Optional[Union[
                      dict, 'Batch', Tuple[Union[dict, 'Batch']],
                      List[Union[dict, 'Batch']], np.ndarray]] = None,
+                 copy: bool = False,
                  **kwargs) -> None:
+        if copy:
+            batch_dict = deepcopy(batch_dict)
         if _is_batch_set(batch_dict):
             self.stack_(batch_dict)
         elif isinstance(batch_dict, (dict, Batch)):
@@ -264,7 +272,7 @@ def __init__(self,
                         v = np.array(v)
                     self.__dict__[k] = v
         if len(kwargs) > 0:
-            self.__init__(kwargs)
+            self.__init__(kwargs, copy=copy)
 
     def __setattr__(self, key: str, value: Any):
         """self[key] = value"""
@@ -360,7 +368,7 @@ def __iadd__(self, other: Union['Batch', Number]):
     def __add__(self, other: Union['Batch', Number]):
         """Algebraic addition with another :class:`~tianshou.data.Batch`
         instance out-of-place."""
-        return copy.deepcopy(self).__iadd__(other)
+        return deepcopy(self).__iadd__(other)
 
     def __imul__(self, val: Number):
         """Algebraic multiplication with a scalar value in-place."""
@@ -372,7 +380,7 @@ def __imul__(self, val: Number):
 
     def __mul__(self, val: Number):
         """Algebraic multiplication with a scalar value out-of-place."""
-        return copy.deepcopy(self).__imul__(val)
+        return deepcopy(self).__imul__(val)
 
     def __itruediv__(self, val: Number):
         """Algebraic division wibyth a scalar value in-place."""
@@ -384,7 +392,7 @@ def __itruediv__(self, val: Number):
 
     def __truediv__(self, val: Number):
         """Algebraic division wibyth a scalar value out-of-place."""
-        return copy.deepcopy(self).__itruediv__(val)
+        return deepcopy(self).__itruediv__(val)
 
     def __repr__(self) -> str:
         """Return str(self)."""
@@ -476,7 +484,7 @@ def cat_(self, batch: 'Batch') -> None:
             if v is None:
                 continue
             if not hasattr(self, k) or self.__dict__[k] is None:
-                self.__dict__[k] = copy.deepcopy(v)
+                self.__dict__[k] = deepcopy(v)
             elif isinstance(v, np.ndarray) and v.ndim > 0:
                 self.__dict__[k] = np.concatenate([self.__dict__[k], v])
             elif isinstance(v, torch.Tensor):
@@ -537,34 +545,45 @@ def stack(batches: List['Batch'], axis: int = 0) -> 'Batch':
         batch.stack_(batches, axis)
         return batch
 
-    def empty_(self) -> 'Batch':
+    def empty_(self, index: Union[
+        str, slice, int, np.integer, np.ndarray, List[int]] = None
+    ) -> 'Batch':
         """Return an empty a :class:`~tianshou.data.Batch` object with 0 or
-        ``None`` filled.
+        ``None`` filled. If ``index`` is specified, it will only reset the
+        specific indexed-data.
         """
         for k, v in self.items():
             if v is None:
                 continue
             if isinstance(v, Batch):
-                self.__dict__[k].empty_()
-            elif isinstance(v, np.ndarray) and v.dtype == np.object:
-                self.__dict__[k].fill(None)
-            elif isinstance(v, torch.Tensor):  # cannot apply fill_ directly
-                self.__dict__[k] = torch.zeros_like(self.__dict__[k])
-            else:  # np
-                self.__dict__[k] *= 0
-                if hasattr(v, 'dtype') and v.dtype.kind in 'fc':
-                    self.__dict__[k] = np.nan_to_num(self.__dict__[k])
+                self.__dict__[k].empty_(index=index)
+            elif isinstance(v, torch.Tensor):
+                self.__dict__[k][index] = 0
+            elif isinstance(v, np.ndarray):
+                if v.dtype == np.object:
+                    self.__dict__[k][index] = None
+                else:
+                    self.__dict__[k][index] = 0
+            else:  # scalar value
+                warnings.warn('You are calling Batch.empty on a NumPy scalar, '
+                              'which may cause undefined behaviors.')
+                if isinstance(v, (np.generic, Number)):
+                    self.__dict__[k] *= 0
+                    if np.isnan(self.__dict__[k]):
+                        self.__dict__[k] = 0
+                else:
+                    self.__dict__[k] = None
         return self
 
     @staticmethod
-    def empty(batch: 'Batch') -> 'Batch':
+    def empty(batch: 'Batch', index: Union[
+        str, slice, int, np.integer, np.ndarray, List[int]] = None
+    ) -> 'Batch':
         """Return an empty :class:`~tianshou.data.Batch` object with 0 or
         ``None`` filled, the shape is the same as the given
         :class:`~tianshou.data.Batch`.
         """
-        batch = Batch(**batch)
-        batch.empty_()
-        return batch
+        return deepcopy(batch).empty_(index)
 
     def __len__(self) -> int:
         """Return len(self)."""
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 75c948a8a..3a7ad7821 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -200,10 +200,15 @@ def _reset_state(self, id: Union[int, List[int]]) -> None:
             return
         if isinstance(self.state, list):
             self.state[id] = None
-        elif isinstance(self.state, (torch.Tensor, np.ndarray)):
-            self.state[id] *= 0
-        else:  # Batch
-            self.state[id].empty_()
+        elif isinstance(self.state, torch.Tensor):
+            self.state[id] = 0
+        elif isinstance(self.state, np.ndarray):
+            if self.state.dtype == np.object:
+                self.state[id] = None
+            else:
+                self.state[id] = 0
+        elif isinstance(self.state, Batch):
+            self.state.empty_(id)
 
     def collect(self,
                 n_step: int = 0,