
fix 2 bugs of batch #284


Merged: 9 commits, Feb 2, 2021
2 changes: 1 addition & 1 deletion setup.py
@@ -47,7 +47,7 @@ def get_version() -> str:
install_requires=[
"gym>=0.15.4",
"tqdm",
"numpy!=1.16.0", # https://github.com/numpy/numpy/issues/12793
"numpy!=1.16.0,<1.20.0", # https://github.com/numpy/numpy/issues/12793
"tensorboard",
"torch>=1.4.0",
"numba>=0.51.0",
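
The tightened upper bound likely reflects numpy 1.20's deprecation of the builtin type aliases (np.object, np.int, ...), which the tests below still use; the alias was later removed entirely in numpy 1.24. A minimal reproduction sketch, assuming a numpy >= 1.20 install:

import warnings
import numpy as np

# Turn the deprecation into an error to make the breakage visible.
with warnings.catch_warnings():
    warnings.simplefilter("error", DeprecationWarning)
    try:
        _ = np.object  # numpy >= 1.20 warns here; numpy >= 1.24 removes the alias
    except (DeprecationWarning, AttributeError) as err:
        print("deprecated alias:", err)
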
12 changes: 11 additions & 1 deletion test/base/test_batch.py
@@ -19,6 +19,8 @@ def test_batch():
assert not Batch(a=np.float64(1.0)).is_empty()
assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
assert not Batch(a=[1, 2, 3]).is_empty()
+ b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None])
+ assert b.c.dtype == np.object
b = Batch()
b.update()
assert b.is_empty()
@@ -143,8 +145,10 @@ def test_batch():
assert batch3.a.d.e[0] == 4.0
batch3.a.d[0] = Batch(f=5.0)
assert batch3.a.d.f[0] == 5.0
- with pytest.raises(KeyError):
+ with pytest.raises(ValueError):
batch3.a.d[0] = Batch(f=5.0, g=0.0)
+ with pytest.raises(ValueError):
+     batch3[0] = Batch(a={"c": 2, "e": 1})
# auto convert
batch4 = Batch(a=np.array(['a', 'b']))
assert batch4.a.dtype == np.object # auto convert to np.object
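
These new assertions pin down one of the two bugs: item assignment on a Batch must not silently create keys, and the error is now ValueError rather than KeyError (see the __setitem__ change in tianshou/data/batch.py below), so callers can tell it apart from a genuine missing-key lookup. A small usage sketch, assuming the tianshou Batch API of this PR:

import numpy as np
from tianshou.data import Batch

b = Batch(a=Batch(c=np.zeros(2), e=np.zeros(2)))
b[0] = Batch(a={"c": 2.0, "e": 1.0})  # fine: both keys already exist
try:
    b[0] = Batch(a={"c": 2.0, "g": 1.0})  # "g" would be a brand-new key
except ValueError as err:
    print(err)  # "Creating keys is not supported by item assignment."
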
@@ -333,6 +337,12 @@ def test_batch_cat_and_stack():
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)

+ # test with illegal input format
+ with pytest.raises(ValueError):
+     Batch.cat([[Batch(a=1)], [Batch(a=1)]])
+ with pytest.raises(ValueError):
+     Batch.stack([[Batch(a=1)], [Batch(a=1)]])

# exceptions
assert Batch.cat([]).is_empty()
assert Batch.stack([]).is_empty()
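
The illegal-input tests correspond to the other fix, the new validation in cat_/stack_ below: every element of the input sequence must be a dict or Batch, so a nested list now raises ValueError instead of being half-converted. A hedged usage sketch under the same API assumption:

import numpy as np
from tianshou.data import Batch

ok = Batch.cat([Batch(a=np.ones((2, 3))), Batch(a=np.zeros((1, 3)))])
print(ok.a.shape)  # (3, 3): concatenated along axis 0
try:
    Batch.cat([[Batch(a=1)], [Batch(a=1)]])  # elements are lists, not Batch/dict
except ValueError as err:
    print(err)  # Cannot concatenate <class 'list'> in Batch.cat_
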
6 changes: 3 additions & 3 deletions test/base/test_buffer.py
@@ -621,9 +621,9 @@ def test_multibuf_hdf5():
'done': i % 3 == 2,
'info': {"number": {"n": i, "t": info_t}, 'extra': None},
}
buffers["vector"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]),
buffers["vector"].add(**Batch.stack([kwargs, kwargs, kwargs]),
buffer_ids=[0, 1, 2])
buffers["cached"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]),
buffers["cached"].add(**Batch.stack([kwargs, kwargs, kwargs]),
cached_buffer_ids=[0, 1, 2])

# save
@@ -657,7 +657,7 @@ def test_multibuf_hdf5():
'done': False,
'info': {"number": {"n": i}, 'Timelimit.truncate': True},
}
- buffers[k].add(**Batch.cat([[kwargs], [kwargs], [kwargs], [kwargs]]))
+ buffers[k].add(**Batch.stack([kwargs, kwargs, kwargs, kwargs]))
act = np.zeros(buffers[k].maxsize)
if k == "vector":
act[np.arange(5)] = np.array([0, 1, 2, 3, 5])
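
The buffer tests switch from Batch.cat over nested single-element lists to Batch.stack over the transition dicts themselves; under the new validation the old spelling is rejected outright, and stacking is what was actually meant (one row per buffer id). A minimal sketch, with a hypothetical transition dict:

from tianshou.data import Batch

kwargs = {'obs': 1.0, 'act': 0, 'rew': 0.0, 'done': False}
batch = Batch.stack([kwargs, kwargs, kwargs])
print(len(batch))  # 3: one (identical) transition per row
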
47 changes: 33 additions & 14 deletions tianshou/data/batch.py
@@ -8,12 +8,6 @@
from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \
Sequence

- # Disable pickle warning related to torch, since it has been removed
- # on torch master branch. See Pull Request #39003 for details:
- # https://github.com/pytorch/pytorch/pull/39003
- warnings.filterwarnings(
-     "ignore", message="pickle support for Storage will be removed in 1.5.")


def _is_batch_set(data: Any) -> bool:
# Batch set is a list/tuple of dict/Batch objects,
@@ -91,6 +85,9 @@ def _create_value(
has_shape = isinstance(inst, (np.ndarray, torch.Tensor))
is_scalar = _is_scalar(inst)
if not stack and is_scalar:
+ # _create_value(Batch(a={}, b=[1, 2, 3]), 10, False) will fail here
+ if isinstance(inst, Batch) and inst.is_empty(recurse=True):
+     return inst
# should never hit since it has already checked in Batch.cat_
# here we do not consider scalar types, following the behavior of numpy
# which does not support concatenation of zero-dimensional arrays
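
The added guard handles entries whose value is a recursively empty Batch: there is no array to pre-allocate for such a placeholder, so it is returned unchanged instead of falling through to the scalar error below. Roughly, under the semantics of is_empty in this module:

from tianshou.data import Batch

inst = Batch(c=Batch())             # nested, but recursively empty
assert inst.is_empty(recurse=True)  # no leaf data anywhere
# _create_value(inst, 10, stack=False) now just returns inst
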
@@ -257,7 +254,7 @@ def __setitem__(
raise ValueError("Batch does not supported tensor assignment. "
"Use a compatible Batch or dict instead.")
if not set(value.keys()).issubset(self.__dict__.keys()):
- raise KeyError(
+ raise ValueError(
"Creating keys is not supported by item assignment.")
for key, val in self.items():
try:
@@ -449,12 +446,21 @@ def cat_(
"""Concatenate a list of (or one) Batch objects into current batch."""
if isinstance(batches, Batch):
batches = [batches]
- if len(batches) == 0:
+ # check input format
+ batch_list = []
+ for b in batches:
+     if isinstance(b, dict):
+         if len(b) > 0:
+             batch_list.append(Batch(b))
+     elif isinstance(b, Batch):
+         # x.is_empty() means that x is Batch() and should be ignored
+         if not b.is_empty():
+             batch_list.append(b)
+     else:
+         raise ValueError(f"Cannot concatenate {type(b)} in Batch.cat_")
+ if len(batch_list) == 0:
return
- batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
-
- # x.is_empty() means that x is Batch() and should be ignored
- batches = [x for x in batches if not x.is_empty()]
+ batches = batch_list
try:
# x.is_empty(recurse=True) here means x is a nested empty batch
# like Batch(a=Batch), and we have to treat it as length zero and
@@ -496,9 +502,22 @@ def stack_(
self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0
) -> None:
"""Stack a list of Batch object into current batch."""
- if len(batches) == 0:
+ # check input format
+ batch_list = []
+ for b in batches:
+     if isinstance(b, dict):
+         if len(b) > 0:
+             batch_list.append(Batch(b))
+     elif isinstance(b, Batch):
+         # x.is_empty() means that x is Batch() and should be ignored
+         if not b.is_empty():
+             batch_list.append(b)
+     else:
+         raise ValueError(
+             f"Cannot concatenate {type(b)} in Batch.stack_")
+ if len(batch_list) == 0:
return
- batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
+ batches = batch_list
if not self.is_empty():
batches = [self] + batches
# collect non-empty keys
2 changes: 1 addition & 1 deletion tianshou/data/buffer.py
@@ -174,7 +174,7 @@ def _add_to_buffer(self, name: str, inst: Any) -> None:
)
try:
value[self._index] = inst
- except KeyError:  # inst is a dict/Batch
+ except ValueError:  # inst is a dict/Batch
for key in set(inst.keys()).difference(value.keys()):
self._buffer_allocator([name, key], inst[key])
self._meta[name][self._index] = inst
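
Finally, the buffer's fallback tracks the error-type change: Batch item assignment now raises ValueError when inst carries keys the stored value has not seen, so catching KeyError would let the new error escape from add(). A condensed, runnable sketch of the allocate-then-retry pattern, assuming the tianshou Batch API of this PR:

import numpy as np
from tianshou.data import Batch

value = Batch(n=np.zeros(4))
inst = Batch(n=1, extra=2.0)  # carries a key `value` has not seen
try:
    value[0] = inst
except ValueError:  # raised by Batch.__setitem__ since this PR
    value.extra = np.zeros(4)  # allocate the missing key, mirroring _add_to_buffer
    value[0] = inst            # retry now succeeds
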