
Improve to_torch/to_numpy converters #147

Merged
33 commits, merged Jul 21, 2020

Commits
54d3dfb  Enable converting list/tuple back and forth from/to numpy/torch. (Jul 18, 2020)
628733d  Add fallbacks. (Jul 18, 2020)
85d13c1  Fix PEP8 (Jul 18, 2020)
d90bcee  Update unit tests. (Jul 18, 2020)
a6e6c3f  Type annotation. Robust dtype check. (Jul 20, 2020)
9965f6f  List of object are converted individually, as a single tensor otherwise. (Jul 20, 2020)
5777408  Improve robustness of _to_array_with_correct_type (Jul 20, 2020)
750dddd  Add unit tests. (Jul 20, 2020)
61d4186  Do not catch exception at _to_array_with_correct_type level. (Jul 20, 2020)
5b0e145  Use _parse_value (Jul 20, 2020)
ca9ae18  Fix PEP8 (Jul 20, 2020)
1199310  Fix _parse_value list output type fallback. (Jul 20, 2020)
fe6ad9a  Catch torch exception. (Jul 20, 2020)
9d784f4  Do not convert torch tensor during fallback. (Jul 20, 2020)
d6bea24  Improve unit tests. (Jul 20, 2020)
bfe0a39  Add unit tests. (Jul 20, 2020)
edaed5a  FIx missing import (Jul 20, 2020)
da4bbc2  Remove support of numpy arrays of tensors for Batch value parser. (Jul 20, 2020)
53ed48d  Forbid numpy arrays of tensors. (Jul 20, 2020)
7b96800  Fix PEP8. (Jul 20, 2020)
d91189e  Fix comment. (Jul 20, 2020)
ca2dd9e  Reduce _parse_value branch number. (Jul 20, 2020)
fb0f74e  Fix None value. (Jul 20, 2020)
4c4357d  Forward error message for debugging purpose. (Jul 20, 2020)
128a816  Fix _is_scalar. (Jul 21, 2020)
26b739b  More specific try/catch blocks. (Jul 21, 2020)
cca2c10  Fix exception chaining. (Jul 21, 2020)
ffbd017  Fix PEP8. (Jul 21, 2020)
7cef13d  Fix _is_scalar. (Jul 21, 2020)
f793320  Fix missing corner case. (Jul 21, 2020)
3f89363  Fix PEP8. (Jul 21, 2020)
99c7f95  Allow Batch empty key. (Jul 21, 2020)
a7e7779  Fix multi-dim array datatype check. (Jul 21, 2020)
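
Read together, the commits above converge on a consistent rule for sequences of arrays/tensors: a list or tuple whose elements share one common shape is converted into a single stacked array or tensor, while mixed-shape input is rejected inside Batch and converted element by element by the standalone converters. A minimal sketch of the shape check behind that rule (the helper _stackable is illustrative only, not part of the PR):

import numpy as np

def _stackable(seq) -> bool:
    # A sequence can collapse into one array/tensor only when every
    # element has the same shape; otherwise stacking would fail.
    shapes = {np.asanyarray(e).shape for e in seq}
    return len(shapes) == 1

print(_stackable([np.zeros((3, 3)), np.zeros((3, 3))]))  # True  -> one stacked value
print(_stackable([np.zeros((3, 2)), np.zeros((3, 3))]))  # False -> per-element handling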
test/base/test_batch.py (46 additions, 5 deletions)
@@ -3,8 +3,9 @@
 import pickle
 import pytest
 import numpy as np
+from itertools import starmap
 
-from tianshou.data import Batch, to_torch
+from tianshou.data import Batch, to_torch, to_numpy
 
 
 def test_batch():
@@ -28,8 +29,19 @@ def test_batch():
     assert b.a == 3
     with pytest.raises(AssertionError):
         Batch({1: 2})
+    with pytest.raises(TypeError):
+        Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))])
+    with pytest.raises(TypeError):
+        Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
+    with pytest.raises(TypeError):
+        Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))])
+    with pytest.raises(TypeError):
+        Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
+    with pytest.raises(TypeError):
+        Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
+    batch = Batch(a=[torch.ones(3), torch.ones(3)])
+    assert torch.allclose(batch.a, torch.ones(2, 3))
     Batch(a=[])
     batch = Batch(obs=[0], np=np.zeros([3, 4]))
     assert batch.obs == batch["obs"]
     batch.obs = [1]
@@ -307,7 +319,7 @@ def test_batch_over_batch_to_torch():
     assert batch.b.d.dtype == torch.float32
 
 
-def test_utils_to_torch():
+def test_utils_to_torch_numpy():
     batch = Batch(
         a=np.float64(1.0),
         b=Batch(
@@ -323,8 +335,37 @@ def test_utils_to_torch_numpy():
     assert batch_torch_float.a.dtype == torch.float32
     assert batch_torch_float.b.c.dtype == torch.float32
     assert batch_torch_float.b.d.dtype == torch.float32
-    array_list = [float('nan'), 1.0]
-    assert to_torch(array_list).dtype == torch.float64
+    data_list = [float('nan'), 1]
+    data_list_torch = to_torch(data_list)
+    assert data_list_torch.dtype == torch.float64
+    data_list_2 = [np.random.rand(3, 3), np.random.rand(3, 3)]
+    data_list_2_torch = to_torch(data_list_2)
+    assert data_list_2_torch.shape == (2, 3, 3)
+    assert np.allclose(to_numpy(to_torch(data_list_2)), data_list_2)
+    data_list_3 = [np.zeros((3, 2)), np.zeros((3, 3))]
+    data_list_3_torch = to_torch(data_list_3)
+    assert isinstance(data_list_3_torch, list)
+    assert all(isinstance(e, torch.Tensor) for e in data_list_3_torch)
+    assert all(starmap(np.allclose,
+                       zip(to_numpy(to_torch(data_list_3)), data_list_3)))
+    data_list_4 = [np.zeros((2, 3)), np.zeros((3, 3))]
+    data_list_4_torch = to_torch(data_list_4)
+    assert isinstance(data_list_4_torch, list)
+    assert all(isinstance(e, torch.Tensor) for e in data_list_4_torch)
+    assert all(starmap(np.allclose,
+                       zip(to_numpy(to_torch(data_list_4)), data_list_4)))
+    data_list_5 = [np.zeros(2), np.zeros((3, 3))]
+    data_list_5_torch = to_torch(data_list_5)
+    assert isinstance(data_list_5_torch, list)
+    assert all(isinstance(e, torch.Tensor) for e in data_list_5_torch)
+    data_array = np.random.rand(3, 2, 2)
+    data_empty_tensor = to_torch(data_array[[]])
+    assert isinstance(data_empty_tensor, torch.Tensor)
+    assert data_empty_tensor.shape == (0, 2, 2)
+    data_empty_array = to_numpy(data_empty_tensor)
+    assert isinstance(data_empty_array, np.ndarray)
+    assert data_empty_array.shape == (0, 2, 2)
+    assert np.allclose(to_numpy(to_torch(data_array)), data_array)
 
 
 def test_batch_pickle():
@@ -432,7 +473,7 @@ def test_batch_standard_compatibility():
     test_batch()
     test_batch_over_batch()
     test_batch_over_batch_to_torch()
-    test_utils_to_torch()
+    test_utils_to_torch_numpy()
     test_batch_pickle()
     test_batch_from_to_numpy_without_copy()
     test_batch_standard_compatibility()
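
At the call site, the assertions above boil down to the following behavior (a small usage sketch against the merged branch):

import numpy as np
import torch
from tianshou.data import to_torch, to_numpy

# Arrays of identical shape are stacked into a single (2, 3, 3) tensor.
stacked = to_torch([np.random.rand(3, 3), np.random.rand(3, 3)])
assert isinstance(stacked, torch.Tensor) and stacked.shape == (2, 3, 3)

# Ragged shapes fall back to a list of individually converted tensors.
ragged = to_torch([np.zeros((3, 2)), np.zeros((3, 3))])
assert isinstance(ragged, list)
assert all(isinstance(t, torch.Tensor) for t in ragged)

# to_numpy inverts the conversion.
assert isinstance(to_numpy(stacked), np.ndarray)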
tianshou/data/batch.py (32 additions, 19 deletions)
@@ -4,6 +4,7 @@
 import numpy as np
 from copy import deepcopy
 from numbers import Number
+from collections.abc import Collection
 from typing import Any, List, Tuple, Union, Iterator, Optional
 
 # Disable pickle warning related to torch, since it has been removed
@@ -36,8 +37,11 @@ def _is_scalar(value: Any) -> bool:
     # 3. python object rather than dict / Batch / tensor
     # the check of dict / Batch is omitted because this only checks a value.
     # a dict / Batch will eventually check their values
-    value = np.asanyarray(value)
-    return value.size == 1 and not value.shape
+    if isinstance(value, torch.Tensor):
+        return value.numel() == 1 and not value.shape
+    else:
+        value = np.asanyarray(value)
+        return value.size == 1 and not value.shape
 
 
 def _is_number(value: Any) -> bool:
@@ -53,16 +57,21 @@ def _is_number(value: Any) -> bool:
 def _to_array_with_correct_type(v: Any) -> np.ndarray:
     # convert the value to np.ndarray
     # convert to np.object data type if neither bool nor number
+    # raises an exception if array's elements are tensors themselves
     v = np.asanyarray(v)
     if not issubclass(v.dtype.type, (np.bool_, np.number)):
         v = v.astype(np.object)
-    if v.dtype == np.object and not v.shape:
+    if v.dtype == np.object:
         # scalar ndarray with np.object data type is very annoying
         # a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)])
         # a is not array([{}, {}], dtype=object), and a[0]={} results in
         # something very strange:
         # array([{}, array({}, dtype=object)], dtype=object)
-        v = v.item(0)
+        if not v.shape:
+            v = v.item(0)
+        elif any(isinstance(e, (np.ndarray, torch.Tensor))
+                 for e in v.reshape(-1)):
+            raise ValueError("Numpy arrays of tensors are not supported yet.")
     return v
 
 
@@ -113,25 +122,29 @@ def _assert_type_keys(keys):
 
 
 def _parse_value(v: Any):
-    if isinstance(v, (list, tuple, np.ndarray)):
-        if not isinstance(v, np.ndarray) and \
-                all(isinstance(e, torch.Tensor) for e in v):
-            v = torch.stack(v)
-            return v
-        v_ = _to_array_with_correct_type(v)
-        if v_.dtype == np.object and _is_batch_set(v):
-            v = Batch(v)  # list of dict / Batch
-        else:
-            # normal data list (main case)
-            # or actually a data list with objects
-            v = v_
-    elif isinstance(v, dict):
+    if isinstance(v, dict):
         v = Batch(v)
     elif isinstance(v, (Batch, torch.Tensor)):
         pass
     else:
-        # scalar case, convert to ndarray
-        v = _to_array_with_correct_type(v)
+        if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \
+                len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v):
+            try:
+                return torch.stack(v)
+            except RuntimeError as e:
+                raise TypeError("Batch does not support non-stackable iterable"
+                                " of torch.Tensor as unique value yet.") from e
+        try:
+            v_ = _to_array_with_correct_type(v)
+        except ValueError as e:
+            raise TypeError("Batch does not support heterogeneous list/tuple"
+                            " of tensors as unique value yet.") from e
+        if _is_batch_set(v):
+            v = Batch(v)  # list of dict / Batch
+        else:
+            # None, scalar, normal data list (main case)
+            # or an actual list of objects
+            v = v_
     return v
 
 
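
For Batch construction this means homogeneous tensor lists are stacked into one value, while non-stackable ones now fail with an explicit TypeError instead of producing an object array (a sketch mirroring the tests above):

import torch
from tianshou.data import Batch

# A homogeneous list of tensors becomes one stacked value.
batch = Batch(a=[torch.ones(3), torch.ones(3)])
assert torch.allclose(batch.a, torch.ones(2, 3))

# Mismatched shapes are rejected by the new TypeError branch.
try:
    Batch(a=[torch.zeros(2, 3), torch.zeros(3, 3)])
except TypeError as err:
    print(err)  # "Batch does not support non-stackable iterable ..."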
tianshou/data/utils.py (25 additions, 13 deletions)
@@ -3,12 +3,12 @@
 from numbers import Number
 from typing import Union, Optional
 
-from tianshou.data import Batch
+from tianshou.data.batch import _parse_value, Batch
 
 
 def to_numpy(x: Union[
-    torch.Tensor, dict, Batch, np.ndarray]) -> Union[
-        dict, Batch, np.ndarray]:
+    Batch, dict, list, tuple, np.ndarray, torch.Tensor]) -> Union[
+        Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
     """Return an object without torch.Tensor."""
     if isinstance(x, torch.Tensor):
         x = x.detach().cpu().numpy()
@@ -17,13 +17,20 @@ def to_numpy(x: Union[
         x[k] = to_numpy(v)
     elif isinstance(x, Batch):
         x.to_numpy()
+    elif isinstance(x, (list, tuple)):
+        try:
+            x = to_numpy(_parse_value(x))
+        except TypeError:
+            x = [to_numpy(e) for e in x]
     else:  # fallback
         x = np.asanyarray(x)
     return x
 
 
-def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
+def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
              dtype: Optional[torch.dtype] = None,
              device: Union[str, int, torch.device] = 'cpu'
-             ) -> Union[dict, Batch, torch.Tensor]:
+             ) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
     """Return an object without np.ndarray."""
     if isinstance(x, torch.Tensor):
         if dtype is not None:
@@ -36,14 +43,19 @@ def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
         x.to_torch(dtype, device)
     elif isinstance(x, (np.number, np.bool_, Number)):
         x = to_torch(np.asanyarray(x), dtype, device)
-    elif isinstance(x, list) and len(x) > 0 and \
-            all(isinstance(e, (np.number, np.bool_, Number)) for e in x):
-        x = to_torch(np.asanyarray(x), dtype, device)
-    elif isinstance(x, np.ndarray) and \
-            isinstance(x.item(0), (np.number, np.bool_, Number)):
-        x = torch.from_numpy(x).to(device)
-        if dtype is not None:
-            x = x.type(dtype)
+    elif isinstance(x, (list, tuple)):
+        try:
+            x = to_torch(_parse_value(x), dtype, device)
+        except TypeError:
+            x = [to_torch(e, dtype, device) for e in x]
+    else:  # fallback
+        x = np.asanyarray(x)
+        if issubclass(x.dtype.type, (np.bool_, np.number)):
+            x = torch.from_numpy(x).to(device)
+            if dtype is not None:
+                x = x.type(dtype)
+        else:
+            raise TypeError(f"object {x} cannot be converted to torch.")
     return x
 
 
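Finally, a short round trip through the rewritten converters; the empty-array case works because the fallback now inspects x.dtype.type instead of calling x.item(0), which raises on empty arrays:

import numpy as np
import torch
from tianshou.data import to_torch, to_numpy

data = np.random.rand(3, 2, 2)
tensor = to_torch(data, dtype=torch.float32, device='cpu')
assert tensor.dtype == torch.float32

# Empty arrays keep their trailing dimensions through the round trip.
empty = to_numpy(to_torch(data[[]]))
assert isinstance(empty, np.ndarray) and empty.shape == (0, 2, 2)

# Objects that cannot become tensors now fail loudly.
try:
    to_torch(object())
except TypeError as err:
    print(err)  # object ... cannot be converted to torch.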