diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 51b8bccdb..77d1343ef 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -3,8 +3,9 @@ import pickle import pytest import numpy as np +from itertools import starmap -from tianshou.data import Batch, to_torch +from tianshou.data import Batch, to_torch, to_numpy def test_batch(): @@ -28,8 +29,19 @@ def test_batch(): assert b.a == 3 with pytest.raises(AssertionError): Batch({1: 2}) + with pytest.raises(TypeError): + Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]) + with pytest.raises(TypeError): + Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))]) + with pytest.raises(TypeError): + Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))]) + with pytest.raises(TypeError): + Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))]) + with pytest.raises(TypeError): + Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))]) batch = Batch(a=[torch.ones(3), torch.ones(3)]) assert torch.allclose(batch.a, torch.ones(2, 3)) + Batch(a=[]) batch = Batch(obs=[0], np=np.zeros([3, 4])) assert batch.obs == batch["obs"] batch.obs = [1] @@ -307,7 +319,7 @@ def test_batch_over_batch_to_torch(): assert batch.b.d.dtype == torch.float32 -def test_utils_to_torch(): +def test_utils_to_torch_numpy(): batch = Batch( a=np.float64(1.0), b=Batch( @@ -323,8 +335,37 @@ def test_utils_to_torch(): assert batch_torch_float.a.dtype == torch.float32 assert batch_torch_float.b.c.dtype == torch.float32 assert batch_torch_float.b.d.dtype == torch.float32 - array_list = [float('nan'), 1.0] - assert to_torch(array_list).dtype == torch.float64 + data_list = [float('nan'), 1] + data_list_torch = to_torch(data_list) + assert data_list_torch.dtype == torch.float64 + data_list_2 = [np.random.rand(3, 3), np.random.rand(3, 3)] + data_list_2_torch = to_torch(data_list_2) + assert data_list_2_torch.shape == (2, 3, 3) + assert np.allclose(to_numpy(to_torch(data_list_2)), data_list_2) + data_list_3 = [np.zeros((3, 2)), np.zeros((3, 3))] + data_list_3_torch = 
to_torch(data_list_3) + assert isinstance(data_list_3_torch, list) + assert all(isinstance(e, torch.Tensor) for e in data_list_3_torch) + assert all(starmap(np.allclose, + zip(to_numpy(to_torch(data_list_3)), data_list_3))) + data_list_4 = [np.zeros((2, 3)), np.zeros((3, 3))] + data_list_4_torch = to_torch(data_list_4) + assert isinstance(data_list_4_torch, list) + assert all(isinstance(e, torch.Tensor) for e in data_list_4_torch) + assert all(starmap(np.allclose, + zip(to_numpy(to_torch(data_list_4)), data_list_4))) + data_list_5 = [np.zeros(2), np.zeros((3, 3))] + data_list_5_torch = to_torch(data_list_5) + assert isinstance(data_list_5_torch, list) + assert all(isinstance(e, torch.Tensor) for e in data_list_5_torch) + data_array = np.random.rand(3, 2, 2) + data_empty_tensor = to_torch(data_array[[]]) + assert isinstance(data_empty_tensor, torch.Tensor) + assert data_empty_tensor.shape == (0, 2, 2) + data_empty_array = to_numpy(data_empty_tensor) + assert isinstance(data_empty_array, np.ndarray) + assert data_empty_array.shape == (0, 2, 2) + assert np.allclose(to_numpy(to_torch(data_array)), data_array) def test_batch_pickle(): @@ -432,7 +473,7 @@ def test_batch_standard_compatibility(): test_batch() test_batch_over_batch() test_batch_over_batch_to_torch() - test_utils_to_torch() + test_utils_to_torch_numpy() test_batch_pickle() test_batch_from_to_numpy_without_copy() test_batch_standard_compatibility() diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index f147ca326..ae07023ea 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -4,6 +4,7 @@ import numpy as np from copy import deepcopy from numbers import Number +from collections.abc import Collection from typing import Any, List, Tuple, Union, Iterator, Optional # Disable pickle warning related to torch, since it has been removed @@ -36,8 +37,11 @@ def _is_scalar(value: Any) -> bool: # 3. 
python object rather than dict / Batch / tensor # the check of dict / Batch is omitted because this only checks a value. # a dict / Batch will eventually check their values - value = np.asanyarray(value) - return value.size == 1 and not value.shape + if isinstance(value, torch.Tensor): + return value.numel() == 1 and not value.shape + else: + value = np.asanyarray(value) + return value.size == 1 and not value.shape def _is_number(value: Any) -> bool: @@ -53,16 +57,21 @@ def _is_number(value: Any) -> bool: def _to_array_with_correct_type(v: Any) -> np.ndarray: # convert the value to np.ndarray # convert to np.object data type if neither bool nor number + # raises an exception if array's elements are tensors themselves v = np.asanyarray(v) if not issubclass(v.dtype.type, (np.bool_, np.number)): v = v.astype(np.object) - if v.dtype == np.object and not v.shape: + if v.dtype == np.object: # scalar ndarray with np.object data type is very annoying # a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)]) # a is not array([{}, {}], dtype=object), and a[0]={} results in # something very strange: # array([{}, array({}, dtype=object)], dtype=object) - v = v.item(0) + if not v.shape: + v = v.item(0) + elif any(isinstance(e, (np.ndarray, torch.Tensor)) + for e in v.reshape(-1)): + raise ValueError("Numpy arrays of tensors are not supported yet.") return v @@ -113,25 +122,29 @@ def _assert_type_keys(keys): def _parse_value(v: Any): - if isinstance(v, (list, tuple, np.ndarray)): - if not isinstance(v, np.ndarray) and \ - all(isinstance(e, torch.Tensor) for e in v): - v = torch.stack(v) - return v - v_ = _to_array_with_correct_type(v) - if v_.dtype == np.object and _is_batch_set(v): - v = Batch(v) # list of dict / Batch - else: - # normal data list (main case) - # or actually a data list with objects - v = v_ - elif isinstance(v, dict): + if isinstance(v, dict): v = Batch(v) elif isinstance(v, (Batch, torch.Tensor)): pass else: - # scalar case, convert to ndarray - v 
= _to_array_with_correct_type(v) + if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \ + len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v): + try: + return torch.stack(v) + except RuntimeError as e: + raise TypeError("Batch does not support non-stackable iterable" + " of torch.Tensor as unique value yet.") from e + try: + v_ = _to_array_with_correct_type(v) + except ValueError as e: + raise TypeError("Batch does not support heterogeneous list/tuple" + " of tensors as unique value yet.") from e + if _is_batch_set(v): + v = Batch(v) # list of dict / Batch + else: + # None, scalar, normal data list (main case) + # or an actual list of objects + v = v_ return v diff --git a/tianshou/data/utils.py b/tianshou/data/utils.py index bf4b3f62c..92a9db0f6 100644 --- a/tianshou/data/utils.py +++ b/tianshou/data/utils.py @@ -3,12 +3,12 @@ from numbers import Number from typing import Union, Optional -from tianshou.data import Batch +from tianshou.data.batch import _parse_value, Batch def to_numpy(x: Union[ - torch.Tensor, dict, Batch, np.ndarray]) -> Union[ - dict, Batch, np.ndarray]: + Batch, dict, list, tuple, np.ndarray, torch.Tensor]) -> Union[ + Batch, dict, list, tuple, np.ndarray, torch.Tensor]: """Return an object without torch.Tensor.""" if isinstance(x, torch.Tensor): x = x.detach().cpu().numpy() @@ -17,13 +17,20 @@ def to_numpy(x: Union[ x[k] = to_numpy(v) elif isinstance(x, Batch): x.to_numpy() + elif isinstance(x, (list, tuple)): + try: + x = to_numpy(_parse_value(x)) + except TypeError: + x = [to_numpy(e) for e in x] + else: # fallback + x = np.asanyarray(x) return x -def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray], +def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor], dtype: Optional[torch.dtype] = None, device: Union[str, int, torch.device] = 'cpu' - ) -> Union[dict, Batch, torch.Tensor]: + ) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]: """Return an object without np.ndarray.""" if 
isinstance(x, torch.Tensor): if dtype is not None: @@ -36,14 +43,19 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray], x.to_torch(dtype, device) elif isinstance(x, (np.number, np.bool_, Number)): x = to_torch(np.asanyarray(x), dtype, device) - elif isinstance(x, list) and len(x) > 0 and \ - all(isinstance(e, (np.number, np.bool_, Number)) for e in x): - x = to_torch(np.asanyarray(x), dtype, device) - elif isinstance(x, np.ndarray) and \ - isinstance(x.item(0), (np.number, np.bool_, Number)): - x = torch.from_numpy(x).to(device) - if dtype is not None: - x = x.type(dtype) + elif isinstance(x, (list, tuple)): + try: + x = to_torch(_parse_value(x), dtype, device) + except TypeError: + x = [to_torch(e, dtype, device) for e in x] + else: # fallback + x = np.asanyarray(x) + if issubclass(x.dtype.type, (np.bool_, np.number)): + x = torch.from_numpy(x).to(device) + if dtype is not None: + x = x.type(dtype) + else: + raise TypeError(f"object {x} cannot be converted to torch.") return x