
fix numpy>=1.20 typing check #323


Merged: 14 commits, Mar 30, 2021
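This PR swaps numpy's deprecated scalar-type aliases (`np.object`, `np.int`, `np.float`, `np.bool`) for the plain Python builtins, so the type checks keep working on numpy >= 1.20, which emits DeprecationWarning for those aliases. A minimal sketch of the equivalence (illustrative code, not part of the diff):

```python
import numpy as np

# The deprecated aliases were only re-exports of the builtins,
# so every dtype check keeps its meaning after the rename.
a = np.array([None, {"k": 1}])
assert a.dtype == object                         # was: a.dtype == np.object
assert np.zeros(3, dtype=int).dtype == int       # was: dtype=np.int
assert np.ones(2, dtype=float).dtype == float    # was: dtype=np.float
assert np.array([True]).dtype == bool            # was: a check against np.bool
```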
2 changes: 1 addition & 1 deletion setup.py
@@ -47,7 +47,7 @@ def get_version() -> str:
install_requires=[
"gym>=0.15.4",
"tqdm",
"numpy!=1.16.0,<1.20.0", # https://github.com/numpy/numpy/issues/12793
"numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793
"tensorboard",
"torch>=1.4.0",
"numba>=0.51.0",
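With the aliases gone from the codebase, the `<1.20.0` cap on numpy is no longer needed; the `!=1.16.0` exclusion (motivated by the linked numpy issue) is kept by requiring `numpy>1.16.0`. A quick sanity check that an installed environment satisfies the relaxed pin (a sketch; it assumes the `packaging` library and Python >= 3.8 are available):

```python
from importlib.metadata import version      # stdlib since Python 3.8
from packaging.specifiers import SpecifierSet

# The relaxed constraint from setup.py above.
assert version("numpy") in SpecifierSet(">1.16.0")
```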
40 changes: 21 additions & 19 deletions test/base/test_batch.py
@@ -20,9 +20,9 @@ def test_batch():
assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
assert not Batch(a=[1, 2, 3]).is_empty()
b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None])
assert b.c.dtype == np.object
assert b.c.dtype == object
b = Batch(d=[None], e=[starmap], f=Batch)
assert b.d.dtype == b.e.dtype == np.object and b.f == Batch
assert b.d.dtype == b.e.dtype == object and b.f == Batch
b = Batch()
b.update()
assert b.is_empty()
@@ -153,10 +153,10 @@ def test_batch():
batch3[0] = Batch(a={"c": 2, "e": 1})
# auto convert
batch4 = Batch(a=np.array(['a', 'b']))
assert batch4.a.dtype == np.object # auto convert to np.object
assert batch4.a.dtype == object # auto convert to object
batch4.update(a=np.array(['c', 'd']))
assert list(batch4.a) == ['c', 'd']
assert batch4.a.dtype == np.object # auto convert to np.object
assert batch4.a.dtype == object # auto convert to object
batch5 = Batch(a=np.array([{'index': 0}]))
assert isinstance(batch5.a, Batch)
assert np.allclose(batch5.a.index, [0])
@@ -405,21 +405,23 @@ def test_utils_to_torch_numpy():
assert data_list_2_torch.shape == (2, 3, 3)
assert np.allclose(to_numpy(to_torch(data_list_2)), data_list_2)
data_list_3 = [np.zeros((3, 2)), np.zeros((3, 3))]
data_list_3_torch = to_torch(data_list_3)
assert isinstance(data_list_3_torch, list)
assert all(isinstance(e, torch.Tensor) for e in data_list_3_torch)
assert all(starmap(np.allclose,
zip(to_numpy(to_torch(data_list_3)), data_list_3)))
data_list_3_torch = [torch.zeros((3, 2)), torch.zeros((3, 3))]
with pytest.raises(TypeError):
to_torch(data_list_3)
with pytest.raises(TypeError):
to_numpy(data_list_3_torch)
data_list_4 = [np.zeros((2, 3)), np.zeros((3, 3))]
data_list_4_torch = to_torch(data_list_4)
assert isinstance(data_list_4_torch, list)
assert all(isinstance(e, torch.Tensor) for e in data_list_4_torch)
assert all(starmap(np.allclose,
zip(to_numpy(to_torch(data_list_4)), data_list_4)))
data_list_4_torch = [torch.zeros((2, 3)), torch.zeros((3, 3))]
with pytest.raises(TypeError):
to_torch(data_list_4)
with pytest.raises(TypeError):
to_numpy(data_list_4_torch)
data_list_5 = [np.zeros(2), np.zeros((3, 3))]
data_list_5_torch = to_torch(data_list_5)
assert isinstance(data_list_5_torch, list)
assert all(isinstance(e, torch.Tensor) for e in data_list_5_torch)
data_list_5_torch = [torch.zeros(2), torch.zeros((3, 3))]
with pytest.raises(TypeError):
to_torch(data_list_5)
with pytest.raises(TypeError):
to_numpy(data_list_5_torch)
data_array = np.random.rand(3, 2, 2)
data_empty_tensor = to_torch(data_array[[]])
assert isinstance(data_empty_tensor, torch.Tensor)
@@ -508,10 +510,10 @@ def test_batch_empty():
assert np.allclose(b5.b.c, [2, 0])
assert np.allclose(b5.b.d, [1, 0])
data = Batch(a=[False, True],
b={'c': np.array([2., 'st'], dtype=np.object),
b={'c': np.array([2., 'st'], dtype=object),
'd': [1, None],
'e': [2., float('nan')]},
c=np.array([1, 3, 4], dtype=np.int),
c=np.array([1, 3, 4], dtype=int),
t=torch.tensor([4, 5, 6, 7.]))
data[-1] = Batch.empty(data[1])
assert np.allclose(data.c, [1, 3, 0])
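The pattern throughout test_batch.py is the same rename: `np.object` becomes the builtin `object`, and the assertions keep their meaning because an object-dtype array compares equal under either spelling. A small illustration (mine, not taken from the test suite):

```python
import numpy as np

c = np.array([None, None])
assert c.dtype == object                                    # new spelling used in the tests
assert c.dtype == np.object_                                # canonical numpy type still exists
assert np.array([2., 'st'], dtype=object).dtype == object   # mixed contents stay object dtype
```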
14 changes: 7 additions & 7 deletions test/base/test_buffer.py
@@ -33,7 +33,7 @@ def test_replaybuffer(size=10, bufsize=20):
done=done, obs_next=obs_next, info=info))
obs = obs_next
assert len(buf) == min(bufsize, i + 1)
assert buf.act.dtype == np.int
assert buf.act.dtype == int
assert buf.act.shape == (bufsize, 1)
data, indice = buf.sample(bufsize * 2)
assert (indice < len(buf)).all()
@@ -50,9 +50,9 @@ def test_replaybuffer(size=10, bufsize=20):
assert b.obs_next[0] == 'str'
assert np.all(b.obs[1:] == 0)
assert np.all(b.obs_next[1:] == np.array(None))
assert b.info.a[0] == 3 and b.info.a.dtype == np.integer
assert b.info.a[0] == 3 and b.info.a.dtype == int
assert np.all(b.info.a[1:] == 0)
assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact
assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == float
assert np.all(b.info.b.c[1:] == 0.0)
assert ptr.shape == (1,) and ptr[0] == 0
assert ep_rew.shape == (1,) and ep_rew[0] == 1
@@ -180,8 +180,8 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
assert len(buf2) == min(bufsize, 3 * (i + 1))
# check single buffer's data
assert buf.info.key.shape == (buf.maxsize,)
assert buf.rew.dtype == np.float
assert buf.done.dtype == np.bool_
assert buf.rew.dtype == float
assert buf.done.dtype == bool
data, indice = buf.sample(len(buf) // 2)
buf.update_weight(indice, -data.weight / 2)
assert np.allclose(buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha)
@@ -273,7 +273,7 @@ def test_segtree():
index = tree.get_prefix_sum_idx(scalar)
assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
# corner case here
naive = np.ones(actual_len, np.int)
naive = np.ones(actual_len, int)
tree[np.arange(actual_len)] = naive
for scalar in range(actual_len):
index = tree.get_prefix_sum_idx(scalar * 1.)
@@ -485,7 +485,7 @@ def test_replaybuffermanager():
buf.set_batch(batch)
assert np.allclose(buf.buffers[-1].info, [1] * 5)
assert buf.sample_index(-1).tolist() == []
assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == np.object
assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == object


def test_cachedbuffer():
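test_buffer.py gets the analogous treatment for the numeric aliases: dtype comparisons now go through the builtins `int`, `float`, and `bool`, which numpy maps to its default scalar types. A minimal sketch (illustrative only):

```python
import numpy as np

act = np.zeros((20, 1), dtype=int)
assert act.dtype == int                   # platform default integer, e.g. int64
assert np.array([1.0]).dtype == float     # float64
assert np.array([True]).dtype == bool     # bool_
```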
52 changes: 20 additions & 32 deletions tianshou/data/batch.py
@@ -7,16 +7,18 @@
from collections.abc import Collection
from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, Sequence

IndexType = Union[slice, int, np.ndarray, List[int]]


def _is_batch_set(data: Any) -> bool:
# Batch set is a list/tuple of dict/Batch objects,
# or 1-D np.ndarray with np.object type,
# or 1-D np.ndarray with object type,
# where each element is a dict/Batch object
if isinstance(data, np.ndarray): # most often case
# "for e in data" will just unpack the first dimension,
# but data.tolist() will flatten ndarray of objects
# so do not use data.tolist()
return data.dtype == np.object and all(
return data.dtype == object and all(
isinstance(e, (dict, Batch)) for e in data)
elif isinstance(data, (list, tuple)):
if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data):
@@ -50,13 +52,13 @@ def _to_array_with_correct_type(v: Any) -> np.ndarray:
if isinstance(v, np.ndarray) and issubclass(v.dtype.type, (np.bool_, np.number)):
return v # most often case
# convert the value to np.ndarray
# convert to np.object data type if neither bool nor number
# convert to object data type if neither bool nor number
# raises an exception if array's elements are tensors themself
v = np.asanyarray(v)
if not issubclass(v.dtype.type, (np.bool_, np.number)):
v = v.astype(np.object)
if v.dtype == np.object:
# scalar ndarray with np.object data type is very annoying
v = v.astype(object)
if v.dtype == object:
# scalar ndarray with object data type is very annoying
# a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)])
# a is not array([{}, {}], dtype=object), and a[0]={} results in
# something very strange:
@@ -87,13 +89,11 @@ def _create_value(
if has_shape:
shape = (size, *inst.shape) if stack else (size, *inst.shape[1:])
if isinstance(inst, np.ndarray):
if issubclass(inst.dtype.type, (np.bool_, np.number)):
target_type = inst.dtype.type
else:
target_type = np.object
target_type = inst.dtype.type if issubclass(
inst.dtype.type, (np.bool_, np.number)) else object
return np.full(
shape,
fill_value=None if target_type == np.object else 0,
fill_value=None if target_type == object else 0,
dtype=target_type
)
elif isinstance(inst, torch.Tensor):
Expand All @@ -105,8 +105,8 @@ def _create_value(
return zero_batch
elif is_scalar:
return _create_value(np.asarray(inst), size, stack=stack)
else: # fall back to np.object
return np.array([None for _ in range(size)])
else: # fall back to object
return np.array([None for _ in range(size)], object)


def _assert_type_keys(keys: Iterable[str]) -> None:
@@ -187,7 +187,7 @@ def __init__(
for k, v in batch_dict.items():
self.__dict__[k] = _parse_value(v)
elif _is_batch_set(batch_dict):
self.stack_(batch_dict)
self.stack_(batch_dict) # type: ignore
if len(kwargs) > 0:
self.__init__(kwargs, copy=copy) # type: ignore

@@ -223,9 +223,7 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
"""
self.__init__(**state) # type: ignore

def __getitem__(
self, index: Union[str, slice, int, np.integer, np.ndarray, List[int]]
) -> Any:
def __getitem__(self, index: Union[str, IndexType]) -> Any:
"""Return self[index]."""
if isinstance(index, str):
return self.__dict__[index]
@@ -241,11 +239,7 @@ def __getitem__(
else:
raise IndexError("Cannot access item from empty Batch object.")

def __setitem__(
self,
index: Union[str, slice, int, np.integer, np.ndarray, List[int]],
value: Any,
) -> None:
def __setitem__(self, index: Union[str, IndexType], value: Any) -> None:
"""Assign value to self[index]."""
value = _parse_value(value)
if isinstance(index, str):
@@ -530,8 +524,7 @@ def stack_(self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> None
elif all(isinstance(e, (Batch, dict)) for e in v): # third often
self.__dict__[k] = Batch.stack(v, axis)
else: # most often case is np.ndarray
v = np.stack(v, axis)
self.__dict__[k] = _to_array_with_correct_type(v)
self.__dict__[k] = _to_array_with_correct_type(np.stack(v, axis))
# all the keys
keys_total = set.union(*[set(b.keys()) for b in batches])
# keys that are reserved in all batches
@@ -587,9 +580,7 @@ def stack(batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> "Batch":
batch.stack_(batches, axis)
return batch

def empty_(
self, index: Union[str, slice, int, np.integer, np.ndarray, List[int]] = None
) -> "Batch":
def empty_(self, index: Optional[Union[slice, IndexType]] = None) -> "Batch":
"""Return an empty Batch object with 0 or None filled.

If "index" is specified, it will only reset the specific indexed-data.
@@ -620,7 +611,7 @@ def empty_(
elif v is None:
continue
elif isinstance(v, np.ndarray):
if v.dtype == np.object:
if v.dtype == object:
self.__dict__[k][index] = None
else:
self.__dict__[k][index] = 0
@@ -636,10 +627,7 @@ def empty_(
return self

@staticmethod
def empty(
batch: "Batch",
index: Union[str, slice, int, np.integer, np.ndarray, List[int]] = None,
) -> "Batch":
def empty(batch: "Batch", index: Optional[IndexType] = None) -> "Batch":
"""Return an empty Batch object with 0 or None filled.

The shape is the same as the given Batch.
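One detail worth noting in the `_create_value` fallback above: passing `object` explicitly keeps the placeholder's dtype stable even when `size` is 0, where a bare `np.array([None for ...])` would default to float64. A tiny sketch of that behaviour (my illustration of a plausible motivation, not something stated in the PR):

```python
import numpy as np

def placeholder(size: int) -> np.ndarray:
    # mirrors the object-dtype fallback used in _create_value above
    return np.array([None for _ in range(size)], object)

assert placeholder(0).dtype == object
assert placeholder(3).dtype == object
# without the explicit dtype, the empty case degrades to float64
assert np.array([None for _ in range(0)]).dtype == np.float64
```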