
Fix padding of inconsistent keys with Batch.stack and Batch.cat #130


Merged · 8 commits · Jul 12, 2020
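
In brief, this PR fixes the zero-padding of keys that are present in only some of the input batches, for both Batch.cat and Batch.stack. A minimal sketch of the cat case (shapes mirror the updated tests below):

>>> import numpy as np, torch
>>> from tianshou.data import Batch
>>> b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
>>> b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
>>> test = Batch.cat([b1, b2])
>>> test.a.shape     # b2 has no 'a': rows 3:7 are zero-padded
(7, 4)
>>> test.b.shape     # b1 has no 'b': rows 0:3 are zero-padded
torch.Size([7, 3])
>>> test.common.c.shape
(7, 5)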
23 changes: 22 additions & 1 deletion test/base/test_batch.py
@@ -166,7 +166,7 @@ def test_batch_cat_and_stack():
assert isinstance(b12_stack.a.d.e, np.ndarray)
assert b12_stack.a.d.e.ndim == 2

# test batch with incompatible keys
# test cat with incompatible keys
b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
test = Batch.cat([b1, b2])
@@ -177,6 +177,7 @@ def test_batch_cat_and_stack():
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)

# test stack with compatible keys
b3 = Batch(a=np.zeros((3, 4)),
b=torch.ones((2, 5)),
c=Batch(d=[[1], [2]]))
@@ -194,6 +195,26 @@ def test_batch_cat_and_stack():
assert b5.b.d[0] == b5_dict[0]['b']['d']
assert b5.b.d[1] == 0.0

# test stack with incompatible keys
a = Batch(a=1, b=2, c=3)
b = Batch(a=4, b=5, d=6)
c = Batch(c=7, b=6, d=9)
d = Batch.stack([a, b, c])
assert np.allclose(d.a, [1, 4, 0])
assert np.allclose(d.b, [2, 5, 6])
assert np.allclose(d.c, [3, 0, 7])
assert np.allclose(d.d, [0, 6, 9])

b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
test = Batch.stack([b1, b2])
ans = Batch(a=np.stack([b1.a, np.zeros((4, 4))]),
b=torch.stack([torch.zeros(4, 6), b2.b]),
common=Batch(c=np.stack([b1.common.c, b2.common.c])))
assert np.allclose(test.a, ans.a)
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)


def test_batch_over_batch_to_torch():
batch = Batch(
104 changes: 60 additions & 44 deletions tianshou/data/batch.py
@@ -3,7 +3,6 @@
import warnings
import numpy as np
from copy import deepcopy
from functools import reduce
from numbers import Number
from typing import Any, List, Tuple, Union, Iterator, Optional

@@ -24,28 +23,45 @@ def _is_batch_set(data: Any) -> bool:
return False


def _create_value(inst: Any, size: int) -> Union[
def _create_value(inst: Any, size: int, stack=True) -> Union[
'Batch', np.ndarray, torch.Tensor]:
"""
:param bool stack: whether to stack or to concatenate. E.g. if inst has
shape (3, 5) and size = 10, stack=True returns an np.ndarray with shape
(10, 3, 5), otherwise (10, 5).
"""
has_shape = isinstance(inst, (np.ndarray, torch.Tensor))
is_scalar = \
isinstance(inst, Number) or \
issubclass(inst.__class__, np.generic) or \
(has_shape and not inst.shape)
if not stack and is_scalar:
# here we do not consider scalar types, following the
# behavior of numpy which does not support concatenation
# of zero-dimensional arrays (scalars)
raise TypeError(f"cannot cat {inst} with which is scalar")
if has_shape:
shape = (size, *inst.shape) if stack else (size, *inst.shape[1:])
if isinstance(inst, np.ndarray):
if issubclass(inst.dtype.type, (np.bool_, np.number)):
target_type = inst.dtype.type
else:
target_type = np.object
return np.full((size, *inst.shape),
return np.full(shape,
fill_value=None if target_type == np.object else 0,
dtype=target_type)
elif isinstance(inst, torch.Tensor):
return torch.full((size, *inst.shape),
return torch.full(shape,
fill_value=0,
device=inst.device,
dtype=inst.dtype)
elif isinstance(inst, (dict, Batch)):
zero_batch = Batch()
for key, val in inst.items():
zero_batch.__dict__[key] = _create_value(val, size)
zero_batch.__dict__[key] = _create_value(val, size, stack=stack)
return zero_batch
elif isinstance(inst, (np.generic, Number)):
return _create_value(np.asarray(inst), size)
elif is_scalar:
return _create_value(np.asarray(inst), size, stack=stack)
else: # fall back to np.object
return np.array([None for _ in range(size)])
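
To illustrate the helper's two modes, matching the docstring above (a quick sketch):

>>> import numpy as np
>>> inst = np.zeros((3, 5))
>>> _create_value(inst, 10, stack=True).shape    # (size, *inst.shape)
(10, 3, 5)
>>> _create_value(inst, 10, stack=False).shape   # (size, *inst.shape[1:])
(10, 5)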

@@ -495,10 +511,12 @@ def cat_(self,
# partial keys will be padded by zeros
# with the shape of [len, rest_shape]
lens = [len(x) for x in batches]
sum_lens = [0]
for x in lens:
sum_lens.append(sum_lens[-1] + x)
keys_map = list(map(lambda e: set(e.keys()), batches))
keys_shared = set.intersection(*keys_map)
values_shared = [
[e[k] for e in batches] for k in keys_shared]
values_shared = [[e[k] for e in batches] for k in keys_shared]
_assert_type_keys(keys_shared)
for k, v in zip(keys_shared, values_shared):
if all(isinstance(e, (dict, Batch)) for e in v):
@@ -513,40 +531,15 @@
keys_partial = set.union(*keys_map) - keys_shared
_assert_type_keys(keys_partial)
for k in keys_partial:
is_dict = False
value = None
for i, e in enumerate(batches):
val = e.get(k, None)
if val is not None:
if isinstance(val, (dict, Batch)):
is_dict = True
else: # np.ndarray or torch.Tensor
value = val
break
if is_dict:
self.__dict__[k] = Batch.cat(
[e.get(k, Batch()) for e in batches])
else:
if isinstance(value, np.ndarray):
arrs = []
for i, e in enumerate(batches):
shape = [lens[i]] + list(value.shape[1:])
pad = np.zeros(shape, dtype=value.dtype)
arrs.append(e.get(k, pad))
self.__dict__[k] = np.concatenate(arrs)
elif isinstance(value, torch.Tensor):
arrs = []
for i, e in enumerate(batches):
shape = [lens[i]] + list(value.shape[1:])
pad = torch.zeros(shape,
dtype=value.dtype,
device=value.device)
arrs.append(e.get(k, pad))
self.__dict__[k] = torch.cat(arrs)
else:
raise TypeError(
f"cannot cat value with type {type(value)}, we only "
"support dict, Batch, np.ndarray, and torch.Tensor")
try:
self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
except KeyError:
self.__dict__[k] = \
_create_value(val, sum_lens[-1], stack=False)
self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val

@staticmethod
def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch':
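
cat_ is the in-place form; a minimal sketch, assuming the static cat simply applies cat_ to a fresh Batch, just as stack delegates to stack_ at the end of this diff:

>>> b = Batch(a=np.zeros((3, 4)))
>>> b.cat_([Batch(a=np.ones((2, 4)))])
>>> b.a.shape
(5, 4)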
@@ -576,12 +569,14 @@ def stack_(self,
"""Stack a list of :class:`~tianshou.data.Batch` object into current
batch.
"""
if len(batches) == 0:
return
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
if len(self.__dict__) > 0:
batches = [self] + list(batches)
keys_map = list(map(lambda e: set(e.keys()), batches))
keys_shared = set.intersection(*keys_map)
values_shared = [
[e[k] for e in batches] for k in keys_shared]
values_shared = [[e[k] for e in batches] for k in keys_shared]
_assert_type_keys(keys_shared)
for k, v in zip(keys_shared, values_shared):
if all(isinstance(e, (dict, Batch)) for e in v):
@@ -593,7 +588,11 @@
if not issubclass(v.dtype.type, (np.bool_, np.number)):
v = v.astype(np.object)
self.__dict__[k] = v
keys_partial = reduce(set.symmetric_difference, keys_map)
keys_partial = set.difference(set.union(*keys_map), keys_shared)
if keys_partial and axis != 0:
raise ValueError(
f"Stack of Batch with non-shared keys {keys_partial} "
f"is only supported with axis=0, but got axis={axis}!")
_assert_type_keys(keys_partial)
for k in keys_partial:
for i, e in enumerate(batches):
@@ -609,7 +608,24 @@
@staticmethod
def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch':
"""Stack a list of :class:`~tianshou.data.Batch` object into a single
new batch.
new batch. For keys that are not shared across all batches,
batches that do not have these keys will be padded by zeros. E.g.
::

>>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
>>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.stack([a, b])
>>> c.a.shape
(2, 4, 4)
>>> c.b.shape
(2, 4, 6)
>>> c.common.c.shape
(2, 4, 5)

.. note::

If there are keys that are not shared across all batches, ``stack``
with ``axis != 0`` is undefined, and will cause an exception.
"""
batch = Batch()
batch.stack_(batches, axis)
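
A sketch of the guarded case from the note above (the set ordering inside the message may vary):

>>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))
>>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))
>>> Batch.stack([a, b], axis=1)
Traceback (most recent call last):
    ...
ValueError: Stack of Batch with non-shared keys {'a', 'b'} is only supported with axis=0, but got axis=1!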