From 54d3dfb8f4b1efd7ea061d997c0103aa6ba4377f Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Sat, 18 Jul 2020 14:53:51 +0200 Subject: [PATCH 01/33] Enable converting list/tuple back and forth from/to numpy/torch. --- tianshou/data/utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tianshou/data/utils.py b/tianshou/data/utils.py index bf4b3f62c..03cdda167 100644 --- a/tianshou/data/utils.py +++ b/tianshou/data/utils.py @@ -17,6 +17,8 @@ def to_numpy(x: Union[ x[k] = to_numpy(v) elif isinstance(x, Batch): x.to_numpy() + elif isinstance(x, (list, tuple)): + x = [to_numpy(x_i) for x_i in x] return x @@ -36,9 +38,8 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray], x.to_torch(dtype, device) elif isinstance(x, (np.number, np.bool_, Number)): x = to_torch(np.asanyarray(x), dtype, device) - elif isinstance(x, list) and len(x) > 0 and \ - all(isinstance(e, (np.number, np.bool_, Number)) for e in x): - x = to_torch(np.asanyarray(x), dtype, device) + elif isinstance(x, (list, tuple)): + x = [to_torch(x_i, dtype, device) for x_i in x] elif isinstance(x, np.ndarray) and \ isinstance(x.item(0), (np.number, np.bool_, Number)): x = torch.from_numpy(x).to(device) From 628733df24b0ad684f407c98049668e38d2f04b9 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Sat, 18 Jul 2020 14:58:58 +0200 Subject: [PATCH 02/33] Add fallbacks. 
--- tianshou/data/utils.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tianshou/data/utils.py b/tianshou/data/utils.py index 03cdda167..c3cccfc6a 100644 --- a/tianshou/data/utils.py +++ b/tianshou/data/utils.py @@ -19,6 +19,8 @@ def to_numpy(x: Union[ x.to_numpy() elif isinstance(x, (list, tuple)): x = [to_numpy(x_i) for x_i in x] + else: # fallback + x = np.asanyarray(x) return x @@ -40,11 +42,14 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray], x = to_torch(np.asanyarray(x), dtype, device) elif isinstance(x, (list, tuple)): x = [to_torch(x_i, dtype, device) for x_i in x] - elif isinstance(x, np.ndarray) and \ - isinstance(x.item(0), (np.number, np.bool_, Number)): - x = torch.from_numpy(x).to(device) - if dtype is not None: - x = x.type(dtype) + else: # fallback + x = np.asanyarray(x) + if isinstance(x.item(0), (np.number, np.bool_, Number)): + x = torch.from_numpy(x).to(device) + if dtype is not None: + x = x.type(dtype) + else: + raise TypeError(f"object {x} cannot be converted to torch.") return x From 85d13c155926efd1cfc1cbd0de6c9987077ea677 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Sat, 18 Jul 2020 15:08:16 +0200 Subject: [PATCH 03/33] Fix PEP8 --- tianshou/data/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/data/utils.py b/tianshou/data/utils.py index c3cccfc6a..e3cd32c66 100644 --- a/tianshou/data/utils.py +++ b/tianshou/data/utils.py @@ -19,7 +19,7 @@ def to_numpy(x: Union[ x.to_numpy() elif isinstance(x, (list, tuple)): x = [to_numpy(x_i) for x_i in x] - else: # fallback + else: # fallback x = np.asanyarray(x) return x @@ -42,7 +42,7 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray], x = to_torch(np.asanyarray(x), dtype, device) elif isinstance(x, (list, tuple)): x = [to_torch(x_i, dtype, device) for x_i in x] - else: # fallback + else: # fallback x = np.asanyarray(x) if isinstance(x.item(0), (np.number, np.bool_, Number)): x = 
torch.from_numpy(x).to(device) From d90bceefab1265739a3aa22cddc86fad1b7b4eb9 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Sat, 18 Jul 2020 15:14:50 +0200 Subject: [PATCH 04/33] Update unit tests. --- test/base/test_batch.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 51b8bccdb..c6894c1c4 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -323,8 +323,10 @@ def test_utils_to_torch(): assert batch_torch_float.a.dtype == torch.float32 assert batch_torch_float.b.c.dtype == torch.float32 assert batch_torch_float.b.d.dtype == torch.float32 - array_list = [float('nan'), 1.0] - assert to_torch(array_list).dtype == torch.float64 + data_list = [float('nan'), 1] + data_list_torch = to_torch(data_list) + assert data_list_torch[0].dtype == torch.float64 + assert data_list_torch[1].dtype == torch.int64 def test_batch_pickle(): From a6e6c3f04dcd3165eaf93b3567f72ecfc0fc83bf Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 20 Jul 2020 10:08:52 +0200 Subject: [PATCH 05/33] Type annotation. Robust dtype check. 
--- tianshou/data/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tianshou/data/utils.py b/tianshou/data/utils.py index e3cd32c66..1da20e36f 100644 --- a/tianshou/data/utils.py +++ b/tianshou/data/utils.py @@ -7,8 +7,8 @@ def to_numpy(x: Union[ - torch.Tensor, dict, Batch, np.ndarray]) -> Union[ - dict, Batch, np.ndarray]: + Batch, dict, list, tuple, np.ndarray, torch.Tensor]) -> Union[ + Batch, dict, list, tuple, np.ndarray, torch.Tensor]: """Return an object without torch.Tensor.""" if isinstance(x, torch.Tensor): x = x.detach().cpu().numpy() @@ -24,10 +24,10 @@ def to_numpy(x: Union[ return x -def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray], +def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor], dtype: Optional[torch.dtype] = None, device: Union[str, int, torch.device] = 'cpu' - ) -> Union[dict, Batch, torch.Tensor]: + ) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]: """Return an object without np.ndarray.""" if isinstance(x, torch.Tensor): if dtype is not None: @@ -44,7 +44,7 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray], x = [to_torch(x_i, dtype, device) for x_i in x] else: # fallback x = np.asanyarray(x) - if isinstance(x.item(0), (np.number, np.bool_, Number)): + if issubclass(x.dtype.type, (np.bool_, np.number)): x = torch.from_numpy(x).to(device) if dtype is not None: x = x.type(dtype) From 9965f6f1b2a95f750c6202dacfc0eddc1eba8bad Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 20 Jul 2020 10:30:53 +0200 Subject: [PATCH 06/33] List of object are converted individually, as a single tensor otherwise. 
--- tianshou/data/utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tianshou/data/utils.py b/tianshou/data/utils.py index 1da20e36f..4b0d916a5 100644 --- a/tianshou/data/utils.py +++ b/tianshou/data/utils.py @@ -18,7 +18,11 @@ def to_numpy(x: Union[ elif isinstance(x, Batch): x.to_numpy() elif isinstance(x, (list, tuple)): - x = [to_numpy(x_i) for x_i in x] + x = np.asanyarray(x) + if x.dtype == np.object: + x = [to_numpy(e) for e in x] + else: + x = to_numpy(x) else: # fallback x = np.asanyarray(x) return x @@ -41,7 +45,11 @@ def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor], elif isinstance(x, (np.number, np.bool_, Number)): x = to_torch(np.asanyarray(x), dtype, device) elif isinstance(x, (list, tuple)): - x = [to_torch(x_i, dtype, device) for x_i in x] + x = np.asanyarray(x) + if x.dtype == np.object: + x = [to_torch(e, dtype, device) for e in x] + else: + x = to_torch(x, dtype, device) else: # fallback x = np.asanyarray(x) if issubclass(x.dtype.type, (np.bool_, np.number)): From 577740850c5a5bc80d48c4d45c0dc81054c8978f Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 20 Jul 2020 10:48:25 +0200 Subject: [PATCH 07/33] Improve robustness of _to_array_with_correct_type --- tianshou/data/batch.py | 23 +++++++++++++---------- tianshou/data/utils.py | 8 ++++---- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index f147ca326..0c6d99c87 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -53,16 +53,19 @@ def _is_number(value: Any) -> bool: def _to_array_with_correct_type(v: Any) -> np.ndarray: # convert the value to np.ndarray # convert to np.object data type if neither bool nor number - v = np.asanyarray(v) - if not issubclass(v.dtype.type, (np.bool_, np.number)): - v = v.astype(np.object) - if v.dtype == np.object and not v.shape: - # scalar ndarray with np.object data type is very annoying - # a=np.array([np.array({}, 
dtype=object), np.array({}, dtype=object)]) - # a is not array([{}, {}], dtype=object), and a[0]={} results in - # something very strange: - # array([{}, array({}, dtype=object)], dtype=object) - v = v.item(0) + try: + v = np.asanyarray(v) + if not issubclass(v.dtype.type, (np.bool_, np.number)): + v = v.astype(np.object) + if v.dtype == np.object and not v.shape: + # scalar ndarray with np.object data type is very annoying + # a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)]) + # a is not array([{}, {}], dtype=object), and a[0]={} results in + # something very strange: + # array([{}, array({}, dtype=object)], dtype=object) + v = v.item(0) + except ValueError: + pass return v diff --git a/tianshou/data/utils.py b/tianshou/data/utils.py index 4b0d916a5..4982d1610 100644 --- a/tianshou/data/utils.py +++ b/tianshou/data/utils.py @@ -1,9 +1,9 @@ import torch import numpy as np from numbers import Number -from typing import Union, Optional +from typing import Any, Union, Optional -from tianshou.data import Batch +from tianshou.data.batch import _to_array_with_correct_type, Batch def to_numpy(x: Union[ @@ -18,7 +18,7 @@ def to_numpy(x: Union[ elif isinstance(x, Batch): x.to_numpy() elif isinstance(x, (list, tuple)): - x = np.asanyarray(x) + x = _to_array_with_correct_type(x) if x.dtype == np.object: x = [to_numpy(e) for e in x] else: @@ -45,7 +45,7 @@ def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor], elif isinstance(x, (np.number, np.bool_, Number)): x = to_torch(np.asanyarray(x), dtype, device) elif isinstance(x, (list, tuple)): - x = np.asanyarray(x) + x = _to_array_with_correct_type(x) if x.dtype == np.object: x = [to_torch(e, dtype, device) for e in x] else: From 750dddde61c755d5b317bf4835552061ac67864a Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 20 Jul 2020 10:48:41 +0200 Subject: [PATCH 08/33] Add unit tests. 
--- test/base/test_batch.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index c6894c1c4..c7cba26af 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -325,9 +325,22 @@ def test_utils_to_torch(): assert batch_torch_float.b.d.dtype == torch.float32 data_list = [float('nan'), 1] data_list_torch = to_torch(data_list) - assert data_list_torch[0].dtype == torch.float64 - assert data_list_torch[1].dtype == torch.int64 - + assert data_list_torch.dtype == torch.float64 + data_list_2 = [np.zeros((3, 3)), np.zeros((3, 3))] + data_list_2_torch = to_torch(data_list_2) + assert data_list_2_torch.shape == (2, 3, 3) + data_list_3 = [np.zeros((3, 2)), np.zeros((3, 3))] + data_list_3_torch = to_torch(data_list_3) + assert isinstance(data_list_3_torch, list) + assert isinstance(data_list_3_torch[0], torch.Tensor) + data_list_4 = [np.zeros(2), np.zeros((3, 3))] + data_list_4_torch = to_torch(data_list_4) + assert isinstance(data_list_4_torch, list) + assert isinstance(data_list_4_torch[0], torch.Tensor) + data_array = np.zeros((3, 2, 2)) + data_tensor = to_torch(data_array[[]]) + assert isinstance(data_tensor, torch.Tensor) + assert data_tensor.shape == (0, 2, 2) def test_batch_pickle(): batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), From 61d418624a569d355c5ab445a361cdd8ff68eb4e Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 20 Jul 2020 10:54:53 +0200 Subject: [PATCH 09/33] Do not catch exception at _to_array_with_correct_type level. 
--- tianshou/data/batch.py | 24 ++++++++++-------------- tianshou/data/utils.py | 5 ++++- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 0c6d99c87..434e1346c 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -52,20 +52,16 @@ def _is_number(value: Any) -> bool: def _to_array_with_correct_type(v: Any) -> np.ndarray: # convert the value to np.ndarray - # convert to np.object data type if neither bool nor number - try: - v = np.asanyarray(v) - if not issubclass(v.dtype.type, (np.bool_, np.number)): - v = v.astype(np.object) - if v.dtype == np.object and not v.shape: - # scalar ndarray with np.object data type is very annoying - # a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)]) - # a is not array([{}, {}], dtype=object), and a[0]={} results in - # something very strange: - # array([{}, array({}, dtype=object)], dtype=object) - v = v.item(0) - except ValueError: - pass + v = np.asanyarray(v) + if not issubclass(v.dtype.type, (np.bool_, np.number)): + v = v.astype(np.object) + if v.dtype == np.object and not v.shape: + # scalar ndarray with np.object data type is very annoying + # a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)]) + # a is not array([{}, {}], dtype=object), and a[0]={} results in + # something very strange: + # array([{}, array({}, dtype=object)], dtype=object) + v = v.item(0) return v diff --git a/tianshou/data/utils.py b/tianshou/data/utils.py index 4982d1610..3be17270d 100644 --- a/tianshou/data/utils.py +++ b/tianshou/data/utils.py @@ -45,7 +45,10 @@ def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor], elif isinstance(x, (np.number, np.bool_, Number)): x = to_torch(np.asanyarray(x), dtype, device) elif isinstance(x, (list, tuple)): - x = _to_array_with_correct_type(x) + try: + x = _to_array_with_correct_type(x) + except ValueError: + pass if x.dtype == np.object: x = [to_torch(e, dtype, device) for e in 
x] else: From 5b0e1459c96dc4255ad5b077c2c9dd76c0ce3024 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 20 Jul 2020 11:08:38 +0200 Subject: [PATCH 10/33] Use _parse_value --- test/base/test_batch.py | 6 +++++- tianshou/data/batch.py | 20 +++++++++++--------- tianshou/data/utils.py | 13 +++++-------- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index c7cba26af..56e078468 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -333,10 +333,14 @@ def test_utils_to_torch(): data_list_3_torch = to_torch(data_list_3) assert isinstance(data_list_3_torch, list) assert isinstance(data_list_3_torch[0], torch.Tensor) - data_list_4 = [np.zeros(2), np.zeros((3, 3))] + data_list_4 = [np.zeros((2, 3)), np.zeros((3, 3))] data_list_4_torch = to_torch(data_list_4) assert isinstance(data_list_4_torch, list) assert isinstance(data_list_4_torch[0], torch.Tensor) + data_list_5 = [np.zeros(2), np.zeros((3, 3))] + data_list_5_torch = to_torch(data_list_5) + assert isinstance(data_list_5_torch, list) + assert isinstance(data_list_5_torch[0], torch.Tensor) data_array = np.zeros((3, 2, 2)) data_tensor = to_torch(data_array[[]]) assert isinstance(data_tensor, torch.Tensor) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 434e1346c..46cce693f 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -115,15 +115,17 @@ def _parse_value(v: Any): if isinstance(v, (list, tuple, np.ndarray)): if not isinstance(v, np.ndarray) and \ all(isinstance(e, torch.Tensor) for e in v): - v = torch.stack(v) - return v - v_ = _to_array_with_correct_type(v) - if v_.dtype == np.object and _is_batch_set(v): - v = Batch(v) # list of dict / Batch - else: - # normal data list (main case) - # or actually a data list with objects - v = v_ + return torch.stack(v) + try: + v_ = _to_array_with_correct_type(v) + if v_.dtype == np.object and _is_batch_set(v): + v = Batch(v) # list of dict / Batch + else: + 
# normal data list (main case) + # or actually a data list with objects + v = v_ + except ValueError: + v = [_to_array_with_correct_type(e) for e in v] elif isinstance(v, dict): v = Batch(v) elif isinstance(v, (Batch, torch.Tensor)): diff --git a/tianshou/data/utils.py b/tianshou/data/utils.py index 3be17270d..4f6296661 100644 --- a/tianshou/data/utils.py +++ b/tianshou/data/utils.py @@ -3,7 +3,7 @@ from numbers import Number from typing import Any, Union, Optional -from tianshou.data.batch import _to_array_with_correct_type, Batch +from tianshou.data.batch import _parse_value, Batch def to_numpy(x: Union[ @@ -18,8 +18,8 @@ def to_numpy(x: Union[ elif isinstance(x, Batch): x.to_numpy() elif isinstance(x, (list, tuple)): - x = _to_array_with_correct_type(x) - if x.dtype == np.object: + x = _parse_value(x) + if isinstance(x, (list, tuple)) or x.dtype == np.object: x = [to_numpy(e) for e in x] else: x = to_numpy(x) @@ -45,11 +45,8 @@ def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor], elif isinstance(x, (np.number, np.bool_, Number)): x = to_torch(np.asanyarray(x), dtype, device) elif isinstance(x, (list, tuple)): - try: - x = _to_array_with_correct_type(x) - except ValueError: - pass - if x.dtype == np.object: + x = _parse_value(x) + if isinstance(x, (list, tuple)) or x.dtype == np.object: x = [to_torch(e, dtype, device) for e in x] else: x = to_torch(x, dtype, device) From ca9ae18dac5eaddddfb28c7638bfac7f37beda2d Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 20 Jul 2020 11:17:37 +0200 Subject: [PATCH 11/33] Fix PEP8 --- tianshou/data/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/data/utils.py b/tianshou/data/utils.py index 4f6296661..1953bc721 100644 --- a/tianshou/data/utils.py +++ b/tianshou/data/utils.py @@ -1,7 +1,7 @@ import torch import numpy as np from numbers import Number -from typing import Any, Union, Optional +from typing import Union, Optional from tianshou.data.batch import 
_parse_value, Batch From 1199310a6e8dd641dd48624afe61a71f5ed24c0c Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 20 Jul 2020 11:28:12 +0200 Subject: [PATCH 12/33] Fix _parse_value list output type fallback. --- tianshou/data/batch.py | 5 ++++- tianshou/data/utils.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 46cce693f..ccc9715c4 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -125,7 +125,10 @@ def _parse_value(v: Any): # or actually a data list with objects v = v_ except ValueError: - v = [_to_array_with_correct_type(e) for e in v] + v_ = np.empty(len(v), dtype=np.object) + for i, e in enumerate(v): + v_[i] = _to_array_with_correct_type(e) + v = v_ elif isinstance(v, dict): v = Batch(v) elif isinstance(v, (Batch, torch.Tensor)): diff --git a/tianshou/data/utils.py b/tianshou/data/utils.py index 1953bc721..59fd27434 100644 --- a/tianshou/data/utils.py +++ b/tianshou/data/utils.py @@ -19,7 +19,7 @@ def to_numpy(x: Union[ x.to_numpy() elif isinstance(x, (list, tuple)): x = _parse_value(x) - if isinstance(x, (list, tuple)) or x.dtype == np.object: + if x.dtype == np.object: x = [to_numpy(e) for e in x] else: x = to_numpy(x) @@ -46,7 +46,7 @@ def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor], x = to_torch(np.asanyarray(x), dtype, device) elif isinstance(x, (list, tuple)): x = _parse_value(x) - if isinstance(x, (list, tuple)) or x.dtype == np.object: + if x.dtype == np.object: x = [to_torch(e, dtype, device) for e in x] else: x = to_torch(x, dtype, device) From fe6ad9a021869dd0a1dc44e98d9cb8e75ff6eea6 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 20 Jul 2020 11:32:42 +0200 Subject: [PATCH 13/33] Catch torch exception. 
--- tianshou/data/batch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index ccc9715c4..03501311b 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -113,10 +113,10 @@ def _assert_type_keys(keys): def _parse_value(v: Any): if isinstance(v, (list, tuple, np.ndarray)): - if not isinstance(v, np.ndarray) and \ - all(isinstance(e, torch.Tensor) for e in v): - return torch.stack(v) try: + if not isinstance(v, np.ndarray) and \ + all(isinstance(e, torch.Tensor) for e in v): + return torch.stack(v) v_ = _to_array_with_correct_type(v) if v_.dtype == np.object and _is_batch_set(v): v = Batch(v) # list of dict / Batch @@ -124,7 +124,7 @@ def _parse_value(v: Any): # normal data list (main case) # or actually a data list with objects v = v_ - except ValueError: + except (ValueError, RuntimeError): v_ = np.empty(len(v), dtype=np.object) for i, e in enumerate(v): v_[i] = _to_array_with_correct_type(e) From 9d784f41ddb1834a3cc8cc09f9560b486d8a44b8 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 20 Jul 2020 11:36:20 +0200 Subject: [PATCH 14/33] Do not convert torch tensor during fallback. --- tianshou/data/batch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 03501311b..4e165d211 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -127,7 +127,10 @@ def _parse_value(v: Any): except (ValueError, RuntimeError): v_ = np.empty(len(v), dtype=np.object) for i, e in enumerate(v): - v_[i] = _to_array_with_correct_type(e) + if not isinstance(e, torch.Tensor): + v_[i] = _to_array_with_correct_type(e) + else: + v_[i] = e v = v_ elif isinstance(v, dict): v = Batch(v) From d6bea246cd0e75923f7addf903c6be10848df012 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 20 Jul 2020 15:57:44 +0200 Subject: [PATCH 15/33] Improve unit tests. 
--- test/base/test_batch.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 56e078468..28a8f604a 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -3,6 +3,7 @@ import pickle import pytest import numpy as np +from itertools import starmap from tianshou.data import Batch, to_torch @@ -307,7 +308,7 @@ def test_batch_over_batch_to_torch(): assert batch.b.d.dtype == torch.float32 -def test_utils_to_torch(): +def test_utils_to_torch_numpy(): batch = Batch( a=np.float64(1.0), b=Batch( @@ -326,25 +327,32 @@ def test_utils_to_torch(): data_list = [float('nan'), 1] data_list_torch = to_torch(data_list) assert data_list_torch.dtype == torch.float64 - data_list_2 = [np.zeros((3, 3)), np.zeros((3, 3))] + data_list_2 = [np.random.rand(3, 3), np.random.rand(3, 3)] data_list_2_torch = to_torch(data_list_2) assert data_list_2_torch.shape == (2, 3, 3) + assert np.allclose(to_numpy(to_torch(data_list_2)), data_list_2) data_list_3 = [np.zeros((3, 2)), np.zeros((3, 3))] data_list_3_torch = to_torch(data_list_3) assert isinstance(data_list_3_torch, list) - assert isinstance(data_list_3_torch[0], torch.Tensor) + assert all(isinstance(e, torch.Tensor) for e in data_list_3_torch) + assert all(starmap(np.allclose, + zip(to_numpy(to_torch(data_list_3)), data_list_3))) data_list_4 = [np.zeros((2, 3)), np.zeros((3, 3))] data_list_4_torch = to_torch(data_list_4) assert isinstance(data_list_4_torch, list) - assert isinstance(data_list_4_torch[0], torch.Tensor) + assert all(isinstance(e, torch.Tensor) for e in data_list_4_torch) + assert all(starmap(np.allclose, + zip(to_numpy(to_torch(data_list_4)), data_list_4))) data_list_5 = [np.zeros(2), np.zeros((3, 3))] data_list_5_torch = to_torch(data_list_5) assert isinstance(data_list_5_torch, list) - assert isinstance(data_list_5_torch[0], torch.Tensor) + assert all(isinstance(e, torch.Tensor) for e in data_list_5_torch) data_array = 
np.zeros((3, 2, 2)) data_tensor = to_torch(data_array[[]]) assert isinstance(data_tensor, torch.Tensor) assert data_tensor.shape == (0, 2, 2) + assert np.allclose(to_numpy(to_torch(data_array)), data_array) + def test_batch_pickle(): batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), @@ -451,7 +459,7 @@ def test_batch_standard_compatibility(): test_batch() test_batch_over_batch() test_batch_over_batch_to_torch() - test_utils_to_torch() + test_utils_to_torch_numpy() test_batch_pickle() test_batch_from_to_numpy_without_copy() test_batch_standard_compatibility() From bfe0a3964e3eb8b71aa03ad73fa2fb5d649a591a Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 20 Jul 2020 16:01:11 +0200 Subject: [PATCH 16/33] Add unit tests. --- test/base/test_batch.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 28a8f604a..7fcb9e1da 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -347,10 +347,13 @@ def test_utils_to_torch_numpy(): data_list_5_torch = to_torch(data_list_5) assert isinstance(data_list_5_torch, list) assert all(isinstance(e, torch.Tensor) for e in data_list_5_torch) - data_array = np.zeros((3, 2, 2)) - data_tensor = to_torch(data_array[[]]) - assert isinstance(data_tensor, torch.Tensor) - assert data_tensor.shape == (0, 2, 2) + data_array = np.random.rand(3, 2, 2) + data_empty_tensor = to_torch(data_array[[]]) + assert isinstance(data_empty_tensor, torch.Tensor) + assert data_empty_tensor.shape == (0, 2, 2) + data_empty_array = to_numpy(data_empty_tensor) + assert isinstance(data_empty_array, np.ndarray) + assert data_empty_array.shape == (0, 2, 2) assert np.allclose(to_numpy(to_torch(data_array)), data_array) From edaed5a89e00ae1a195fee54291bd8965c4a67f8 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 20 Jul 2020 16:02:33 +0200 Subject: [PATCH 17/33] Fix missing import --- test/base/test_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 7fcb9e1da..829d167e2 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -5,7 +5,7 @@ import numpy as np from itertools import starmap -from tianshou.data import Batch, to_torch +from tianshou.data import Batch, to_torch, to_numpy def test_batch(): From da4bbc26b5dd86a3efed1f1a9ae907c8d6a7d1b8 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 20 Jul 2020 16:23:11 +0200 Subject: [PATCH 18/33] Remove support of numpy arrays of tensors for Batch value parser. --- test/base/test_batch.py | 4 ++++ tianshou/data/batch.py | 9 ++------- tianshou/data/utils.py | 14 ++++++-------- 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 829d167e2..4ed8b973d 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -29,6 +29,10 @@ def test_batch(): assert b.a == 3 with pytest.raises(AssertionError): Batch({1: 2}) + with pytest.raises(TypeError): + Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]) + with pytest.raises(TypeError): + Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))]) batch = Batch(a=[torch.ones(3), torch.ones(3)]) assert torch.allclose(batch.a, torch.ones(2, 3)) batch = Batch(obs=[0], np=np.zeros([3, 4])) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 4e165d211..1096e0257 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -125,13 +125,8 @@ def _parse_value(v: Any): # or actually a data list with objects v = v_ except (ValueError, RuntimeError): - v_ = np.empty(len(v), dtype=np.object) - for i, e in enumerate(v): - if not isinstance(e, torch.Tensor): - v_[i] = _to_array_with_correct_type(e) - else: - v_[i] = e - v = v_ + raise TypeError("Batch does not support non-stackable list/tuple of "\ + "tensors as value yet.") elif isinstance(v, dict): v = Batch(v) elif isinstance(v, (Batch, torch.Tensor)): diff --git a/tianshou/data/utils.py 
b/tianshou/data/utils.py index 59fd27434..92a9db0f6 100644 --- a/tianshou/data/utils.py +++ b/tianshou/data/utils.py @@ -18,11 +18,10 @@ def to_numpy(x: Union[ elif isinstance(x, Batch): x.to_numpy() elif isinstance(x, (list, tuple)): - x = _parse_value(x) - if x.dtype == np.object: + try: + x = to_numpy(_parse_value(x)) + except TypeError: x = [to_numpy(e) for e in x] - else: - x = to_numpy(x) else: # fallback x = np.asanyarray(x) return x @@ -45,11 +44,10 @@ def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor], elif isinstance(x, (np.number, np.bool_, Number)): x = to_torch(np.asanyarray(x), dtype, device) elif isinstance(x, (list, tuple)): - x = _parse_value(x) - if x.dtype == np.object: + try: + x = to_torch(_parse_value(x), dtype, device) + except TypeError: x = [to_torch(e, dtype, device) for e in x] - else: - x = to_torch(x, dtype, device) else: # fallback x = np.asanyarray(x) if issubclass(x.dtype.type, (np.bool_, np.number)): From 53ed48d16ce9986a5ecc682f403445dd5832b4c1 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 20 Jul 2020 16:32:11 +0200 Subject: [PATCH 19/33] Forbid numpy arrays of tensors. 
--- test/base/test_batch.py | 2 ++ tianshou/data/batch.py | 7 +++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 4ed8b973d..6263a24b2 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -31,6 +31,8 @@ def test_batch(): Batch({1: 2}) with pytest.raises(TypeError): Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]) + with pytest.raises(TypeError): + Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))]) with pytest.raises(TypeError): Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))]) batch = Batch(a=[torch.ones(3), torch.ones(3)]) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 1096e0257..b28cee85a 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -55,13 +55,16 @@ def _to_array_with_correct_type(v: Any) -> np.ndarray: v = np.asanyarray(v) if not issubclass(v.dtype.type, (np.bool_, np.number)): v = v.astype(np.object) - if v.dtype == np.object and not v.shape: + if v.dtype == np.object: # scalar ndarray with np.object data type is very annoying # a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)]) # a is not array([{}, {}], dtype=object), and a[0]={} results in # something very strange: # array([{}, array({}, dtype=object)], dtype=object) - v = v.item(0) + if not v.shape: + v = v.item(0) + elif isinstance(v.item(0), (np.ndarray, torch.Tensor)): + raise ValueError("Numpy arrays of tensors are not supported yet.") return v From 7b96800d286ccf42caadde6ef015572c63970cb8 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 20 Jul 2020 16:35:51 +0200 Subject: [PATCH 20/33] Fix PEP8. 
--- tianshou/data/batch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index b28cee85a..2ce895dec 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -128,8 +128,8 @@ def _parse_value(v: Any): # or actually a data list with objects v = v_ except (ValueError, RuntimeError): - raise TypeError("Batch does not support non-stackable list/tuple of "\ - "tensors as value yet.") + raise TypeError("Batch does not support non-stackable list/tuple " + "of tensors as value yet.") elif isinstance(v, dict): v = Batch(v) elif isinstance(v, (Batch, torch.Tensor)): From d91189e2ee161e0336bb6916f25005824f8a006a Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 20 Jul 2020 16:43:18 +0200 Subject: [PATCH 21/33] Fix comment. --- tianshou/data/batch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 2ce895dec..9926c04cc 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -52,6 +52,8 @@ def _is_number(value: Any) -> bool: def _to_array_with_correct_type(v: Any) -> np.ndarray: # convert the value to np.ndarray + # convert to np.object data type if neither bool nor number + # raises an exception if array's elements are tensors themself v = np.asanyarray(v) if not issubclass(v.dtype.type, (np.bool_, np.number)): v = v.astype(np.object) From ca2dd9e7bfc1381d6c90ff8a3bcd511304c8a4fb Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 20 Jul 2020 16:51:50 +0200 Subject: [PATCH 22/33] Reduce _parse_value branch number. 
--- tianshou/data/batch.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 9926c04cc..b0f228b89 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -4,6 +4,7 @@ import numpy as np from copy import deepcopy from numbers import Number +from collections.abc import Iterable from typing import Any, List, Tuple, Union, Iterator, Optional # Disable pickle warning related to torch, since it has been removed @@ -117,28 +118,25 @@ def _assert_type_keys(keys): def _parse_value(v: Any): - if isinstance(v, (list, tuple, np.ndarray)): + if isinstance(v, dict): + v = Batch(v) + elif isinstance(v, (Batch, torch.Tensor)): + pass + else: try: - if not isinstance(v, np.ndarray) and \ + if not isinstance(v, np.ndarray) and isinstance(v, Iterable) and \ all(isinstance(e, torch.Tensor) for e in v): return torch.stack(v) v_ = _to_array_with_correct_type(v) if v_.dtype == np.object and _is_batch_set(v): v = Batch(v) # list of dict / Batch else: - # normal data list (main case) + # scalar case, normal data list (main case) # or actually a data list with objects v = v_ except (ValueError, RuntimeError): raise TypeError("Batch does not support non-stackable list/tuple " "of tensors as value yet.") - elif isinstance(v, dict): - v = Batch(v) - elif isinstance(v, (Batch, torch.Tensor)): - pass - else: - # scalar case, convert to ndarray - v = _to_array_with_correct_type(v) return v From fb0f74e78d4d681b4428c43fb10c73f289108f4c Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 20 Jul 2020 17:05:40 +0200 Subject: [PATCH 23/33] Fix None value. 
--- tianshou/data/batch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index b0f228b89..fb679ab31 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -128,11 +128,11 @@ def _parse_value(v: Any): all(isinstance(e, torch.Tensor) for e in v): return torch.stack(v) v_ = _to_array_with_correct_type(v) - if v_.dtype == np.object and _is_batch_set(v): + if _is_batch_set(v): v = Batch(v) # list of dict / Batch else: - # scalar case, normal data list (main case) - # or actually a data list with objects + # None, scalar, normal data list (main case) + # or an actual list of objects v = v_ except (ValueError, RuntimeError): raise TypeError("Batch does not support non-stackable list/tuple " From 4c4357d0e83a9ea755fce2e6fcaf675f0308e1c0 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 20 Jul 2020 17:22:40 +0200 Subject: [PATCH 24/33] Forward error message for debugging purpose. --- tianshou/data/batch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index fb679ab31..6ef95cf3a 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -134,9 +134,9 @@ def _parse_value(v: Any): # None, scalar, normal data list (main case) # or an actual list of objects v = v_ - except (ValueError, RuntimeError): + except (ValueError, RuntimeError) as e: raise TypeError("Batch does not support non-stackable list/tuple " - "of tensors as value yet.") + "of tensors as value yet: \n" + str(e)) return v From 128a81632226750da80615b522b8c25b0d76945a Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Tue, 21 Jul 2020 08:26:41 +0200 Subject: [PATCH 25/33] Fix _is_scalar. 
--- tianshou/data/batch.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 6ef95cf3a..09ada1e5e 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -37,8 +37,11 @@ def _is_scalar(value: Any) -> bool: # 3. python object rather than dict / Batch / tensor # the check of dict / Batch is omitted because this only checks a value. # a dict / Batch will eventually check their values - value = np.asanyarray(value) - return value.size == 1 and not value.shape + if isinstance(value, (np.ndarray, torch.Tensor)): + return value.size == 1 and not value.shape + else: + value = np.asanyarray(value) + return _is_scalar(value) def _is_number(value: Any) -> bool: From 26b739b2dc4df9f73ccdf4d114d55a43e91c587e Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Tue, 21 Jul 2020 08:27:23 +0200 Subject: [PATCH 26/33] More specific try/catch blocks. --- tianshou/data/batch.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 09ada1e5e..8109c5c8a 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -126,20 +126,26 @@ def _parse_value(v: Any): elif isinstance(v, (Batch, torch.Tensor)): pass else: - try: - if not isinstance(v, np.ndarray) and isinstance(v, Iterable) and \ - all(isinstance(e, torch.Tensor) for e in v): + if not isinstance(v, np.ndarray) and isinstance(v, Iterable) and \ + all(isinstance(e, torch.Tensor) for e in v): + try: return torch.stack(v) + except RuntimeError as e: + raise Exception([ + TypeError("Batch does not support non-stackable list/tuple " + "of torch.Tensor as unique value yet."), e]) + try: v_ = _to_array_with_correct_type(v) - if _is_batch_set(v): - v = Batch(v) # list of dict / Batch - else: - # None, scalar, normal data list (main case) - # or an actual list of objects - v = v_ - except (ValueError, RuntimeError) as e: - raise TypeError("Batch does 
not support non-stackable list/tuple " - "of tensors as value yet: \n" + str(e)) + except ValueError as e: + raise Exception([ + TypeError("Batch does not support heterogeneous list/tuple " + "of tensors as unique value yet."), e]) + if _is_batch_set(v): + v = Batch(v) # list of dict / Batch + else: + # None, scalar, normal data list (main case) + # or an actual list of objects + v = v_ return v From cca2c10df9ef082a90e20b7eaf1ba477e91c229e Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Tue, 21 Jul 2020 08:30:53 +0200 Subject: [PATCH 27/33] Fix exception chaining. --- tianshou/data/batch.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 8109c5c8a..3bcf4302a 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -131,15 +131,13 @@ def _parse_value(v: Any): try: return torch.stack(v) except RuntimeError as e: - raise Exception([ - TypeError("Batch does not support non-stackable list/tuple " - "of torch.Tensor as unique value yet."), e]) + raise TypeError("Batch does not support non-stackable list/tuple " + "of torch.Tensor as unique value yet.") from e try: v_ = _to_array_with_correct_type(v) except ValueError as e: - raise Exception([ - TypeError("Batch does not support heterogeneous list/tuple " - "of tensors as unique value yet."), e]) + raise TypeError("Batch does not support heterogeneous list/tuple " + "of tensors as unique value yet.") from e if _is_batch_set(v): v = Batch(v) # list of dict / Batch else: From ffbd0176f7b6333ffb1d21ba064e479d68e5e040 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Tue, 21 Jul 2020 08:38:24 +0200 Subject: [PATCH 28/33] Fix PEP8. 
--- tianshou/data/batch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 3bcf4302a..9b2872be7 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -131,13 +131,13 @@ def _parse_value(v: Any): try: return torch.stack(v) except RuntimeError as e: - raise TypeError("Batch does not support non-stackable list/tuple " - "of torch.Tensor as unique value yet.") from e + raise TypeError("Batch does not support non-stackable iterable" + " of torch.Tensor as unique value yet.") from e try: v_ = _to_array_with_correct_type(v) except ValueError as e: - raise TypeError("Batch does not support heterogeneous list/tuple " - "of tensors as unique value yet.") from e + raise TypeError("Batch does not support heterogeneous list/tuple" + " of tensors as unique value yet.") from e if _is_batch_set(v): v = Batch(v) # list of dict / Batch else: From 7cef13d4c82582b704e07d3834e0425143399e3d Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Tue, 21 Jul 2020 09:11:44 +0200 Subject: [PATCH 29/33] Fix _is_scalar. --- tianshou/data/batch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 9b2872be7..c5008c83c 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -37,11 +37,11 @@ def _is_scalar(value: Any) -> bool: # 3. python object rather than dict / Batch / tensor # the check of dict / Batch is omitted because this only checks a value. 
# a dict / Batch will eventually check their values - if isinstance(value, (np.ndarray, torch.Tensor)): - return value.size == 1 and not value.shape + if isinstance(value, torch.Tensor): + return value.numel() == 1 and not value.shape else: value = np.asanyarray(value) - return _is_scalar(value) + return value.size == 1 and not value.shape def _is_number(value: Any) -> bool: From f793320a432697d0b7a99d3d9f3ef19ef3ffffb7 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Tue, 21 Jul 2020 09:34:33 +0200 Subject: [PATCH 30/33] Fix missing corner case. --- test/base/test_batch.py | 4 ++++ tianshou/data/batch.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 6263a24b2..3cf8485fd 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -35,6 +35,10 @@ def test_batch(): Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))]) with pytest.raises(TypeError): Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))]) + with pytest.raises(TypeError): + Batch(a=[torch.zeros(3,3), np.zeros([3,3])]) + with pytest.raises(TypeError): + Batch(a=[1, np.zeros([3,3]), torch.zeros(3,3)]) batch = Batch(a=[torch.ones(3), torch.ones(3)]) assert torch.allclose(batch.a, torch.ones(2, 3)) batch = Batch(obs=[0], np=np.zeros([3, 4])) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index c5008c83c..644202ca5 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -69,7 +69,7 @@ def _to_array_with_correct_type(v: Any) -> np.ndarray: # array([{}, array({}, dtype=object)], dtype=object) if not v.shape: v = v.item(0) - elif isinstance(v.item(0), (np.ndarray, torch.Tensor)): + elif any(isinstance(e, (np.ndarray, torch.Tensor)) for e in v): raise ValueError("Numpy arrays of tensors are not supported yet.") return v From 3f893634d12fe7655f37df86c36708d598f3aff3 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Tue, 21 Jul 2020 09:37:08 +0200 Subject: [PATCH 31/33] Fix PEP8. 
--- test/base/test_batch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 3cf8485fd..d82fb3039 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -36,9 +36,9 @@ def test_batch(): with pytest.raises(TypeError): Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))]) with pytest.raises(TypeError): - Batch(a=[torch.zeros(3,3), np.zeros([3,3])]) + Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))]) with pytest.raises(TypeError): - Batch(a=[1, np.zeros([3,3]), torch.zeros(3,3)]) + Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))]) batch = Batch(a=[torch.ones(3), torch.ones(3)]) assert torch.allclose(batch.a, torch.ones(2, 3)) batch = Batch(obs=[0], np=np.zeros([3, 4])) From 99c7f95160cea10847eca58801bc98b390994f43 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Tue, 21 Jul 2020 09:48:28 +0200 Subject: [PATCH 32/33] Allow Batch empty key. --- test/base/test_batch.py | 1 + tianshou/data/batch.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index d82fb3039..77d1343ef 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -41,6 +41,7 @@ def test_batch(): Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))]) batch = Batch(a=[torch.ones(3), torch.ones(3)]) assert torch.allclose(batch.a, torch.ones(2, 3)) + Batch(a=[]) batch = Batch(obs=[0], np=np.zeros([3, 4])) assert batch.obs == batch["obs"] batch.obs = [1] diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 644202ca5..647dc7845 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -4,7 +4,7 @@ import numpy as np from copy import deepcopy from numbers import Number -from collections.abc import Iterable +from collections.abc import Collection from typing import Any, List, Tuple, Union, Iterator, Optional # Disable pickle warning related to torch, since it has been removed @@ -126,8 +126,8 @@ def _parse_value(v: 
Any): elif isinstance(v, (Batch, torch.Tensor)): pass else: - if not isinstance(v, np.ndarray) and isinstance(v, Iterable) and \ - all(isinstance(e, torch.Tensor) for e in v): + if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \ + len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v): try: return torch.stack(v) except RuntimeError as e: From a7e777965027cd76938201213b2865efda4dad84 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Tue, 21 Jul 2020 10:27:08 +0200 Subject: [PATCH 33/33] Fix multi-dim array datatype check. --- tianshou/data/batch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 647dc7845..ae07023ea 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -69,7 +69,8 @@ def _to_array_with_correct_type(v: Any) -> np.ndarray: # array([{}, array({}, dtype=object)], dtype=object) if not v.shape: v = v.item(0) - elif any(isinstance(e, (np.ndarray, torch.Tensor)) for e in v): + elif any(isinstance(e, (np.ndarray, torch.Tensor)) + for e in v.reshape(-1)): raise ValueError("Numpy arrays of tensors are not supported yet.") return v