这是indexloc提供的服务,不要输入任何密码
Skip to content

Improve Batch #128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits
Jul 11, 2020
1 change: 1 addition & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ on: [push, pull_request]
jobs:
build:
runs-on: ubuntu-latest
if: "!contains(github.event.head_commit.message, 'ci skip')"
strategy:
matrix:
python-version: [3.6, 3.7, 3.8]
Expand Down
43 changes: 42 additions & 1 deletion test/base/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,17 @@
def test_batch():
assert list(Batch()) == []
assert Batch().is_empty()
assert Batch(b={'c': {}}).is_empty()
assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
assert not Batch(a=[1, 2, 3]).is_empty()
b = Batch()
b.update()
assert b.is_empty()
b.update(c=[3, 5])
assert np.allclose(b.c, [3, 5])
# mimic the behavior of dict.update, where kwargs can overwrite keys
b.update({'a': 2}, a=3)
assert b.a == 3
with pytest.raises(AssertionError):
Batch({1: 2})
batch = Batch(a=[torch.ones(3), torch.ones(3)])
Expand Down Expand Up @@ -86,6 +96,18 @@ def test_batch():
assert batch3.a.d.f[0] == 5.0
with pytest.raises(KeyError):
batch3.a.d[0] = Batch(f=5.0, g=0.0)
# auto convert
batch4 = Batch(a=np.array(['a', 'b']))
assert batch4.a.dtype == np.object # auto convert to np.object
batch4.update(a=np.array(['c', 'd']))
assert list(batch4.a) == ['c', 'd']
assert batch4.a.dtype == np.object # auto convert to np.object
batch5 = Batch(a=np.array([{'index': 0}]))
assert isinstance(batch5.a, Batch)
assert np.allclose(batch5.a.index, [0])
batch5.b = np.array([{'index': 1}])
assert isinstance(batch5.b, Batch)
assert np.allclose(batch5.b.index, [1])


def test_batch_over_batch():
Expand All @@ -100,6 +122,11 @@ def test_batch_over_batch():
assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5])
assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0])
batch2.update(batch2.b, six=[6, 6, 6])
assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8])
assert np.allclose(batch2.a, [3, 4, 5, 3, 4, 5])
assert np.allclose(batch2.b, [4, 5, 0, 4, 5, 0])
assert np.allclose(batch2.six, [6, 6, 6])
d = {'a': [3, 4, 5], 'b': [4, 5, 6]}
batch3 = Batch(c=[6, 7, 8], b=d)
batch3.cat_(Batch(c=[6, 7, 8], b=d))
Expand All @@ -124,18 +151,32 @@ def test_batch_over_batch():


def test_batch_cat_and_stack():
# test cat with compatible keys
b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
b12_cat_out = Batch.cat((b1, b2))
b12_cat_out = Batch.cat([b1, b2])
b12_cat_in = copy.deepcopy(b1)
b12_cat_in.cat_(b2)
assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
assert isinstance(b12_cat_in.a.d.e, np.ndarray)
assert b12_cat_in.a.d.e.ndim == 1

b12_stack = Batch.stack((b1, b2))
assert isinstance(b12_stack.a.d.e, np.ndarray)
assert b12_stack.a.d.e.ndim == 2

# test batch with incompatible keys
b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
test = Batch.cat([b1, b2])
ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
b=torch.cat([torch.zeros(3, 3), b2.b]),
common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
assert np.allclose(test.a, ans.a)
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)

b3 = Batch(a=np.zeros((3, 4)),
b=torch.ones((2, 5)),
c=Batch(d=[[1], [2]]))
Expand Down
167 changes: 112 additions & 55 deletions tianshou/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,7 @@ def __init__(self,
v_ = None
if not isinstance(v, np.ndarray) and \
all(isinstance(e, torch.Tensor) for e in v):
v_ = torch.stack(v)
self.__dict__[k] = v_
self.__dict__[k] = torch.stack(v)
continue
else:
v_ = np.asanyarray(v)
Expand Down Expand Up @@ -294,7 +293,8 @@ def __setattr__(self, key: str, value: Any):
value = np.array(value)
if not issubclass(value.dtype.type, (np.bool_, np.number)):
value = value.astype(np.object)
elif isinstance(value, dict):
elif isinstance(value, dict) or isinstance(value, np.ndarray) \
and value.dtype == np.object and _is_batch_set(value):
value = Batch(value)
self.__dict__[key] = value

Expand Down Expand Up @@ -333,9 +333,8 @@ def __getitem__(self, index: Union[
else:
raise IndexError("Cannot access item from empty Batch object.")

def __setitem__(
self,
index: Union[str, slice, int, np.integer, np.ndarray, List[int]],
def __setitem__(self, index: Union[
str, slice, int, np.integer, np.ndarray, List[int]],
value: Any) -> None:
"""Assign value to self[index]."""
if isinstance(value, np.ndarray):
Expand Down Expand Up @@ -454,10 +453,8 @@ def to_numpy(self) -> None:
elif isinstance(v, Batch):
v.to_numpy()

def to_torch(self,
dtype: Optional[torch.dtype] = None,
device: Union[str, int, torch.device] = 'cpu'
) -> None:
def to_torch(self, dtype: Optional[torch.dtype] = None,
device: Union[str, int, torch.device] = 'cpu') -> None:
"""Change all numpy.ndarray to torch.Tensor. This is an in-place
operation.
"""
Expand All @@ -473,66 +470,111 @@ def to_torch(self,
v = v.type(dtype)
self.__dict__[k] = v
elif isinstance(v, torch.Tensor):
if dtype is not None and v.dtype != dtype:
must_update_tensor = True
elif v.device.type != device.type:
must_update_tensor = True
elif device.index is not None and \
if dtype is not None and v.dtype != dtype or \
v.device.type != device.type or \
device.index is not None and \
device.index != v.device.index:
must_update_tensor = True
else:
must_update_tensor = False
if must_update_tensor:
if dtype is not None:
v = v.type(dtype)
self.__dict__[k] = v.to(device)
elif isinstance(v, Batch):
v.to_torch(dtype, device)

def append(self, batch: 'Batch') -> None:
warnings.warn('Method :meth:`~tianshou.data.Batch.append` will be '
'removed soon, please use '
':meth:`~tianshou.data.Batch.cat`')
return self.cat_(batch)

def cat_(self, batch: 'Batch') -> None:
"""Concatenate a :class:`~tianshou.data.Batch` object into current
batch.
def cat_(self,
batches: Union['Batch', List[Union[dict, 'Batch']]]) -> None:
"""Concatenate a list of (or one) :class:`~tianshou.data.Batch` objects
into the current batch.
"""
assert isinstance(batch, Batch), \
'Only Batch is allowed to be concatenated in-place!'
for k, v in batch.items():
if v is None:
continue
if not hasattr(self, k) or self.__dict__[k] is None:
self.__dict__[k] = deepcopy(v)
elif isinstance(v, np.ndarray) and v.ndim > 0:
self.__dict__[k] = np.concatenate([self.__dict__[k], v])
elif isinstance(v, torch.Tensor):
self.__dict__[k] = torch.cat([self.__dict__[k], v])
elif isinstance(v, Batch):
self.__dict__[k].cat_(v)
if isinstance(batches, Batch):
batches = [batches]
if len(batches) == 0:
return
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
if len(self.__dict__) > 0:
batches = [self] + list(batches)
# keys that are missing from some batches will be padded with zeros
# of shape [len, rest_shape]
lens = [len(x) for x in batches]
keys_map = list(map(lambda e: set(e.keys()), batches))
keys_shared = set.intersection(*keys_map)
values_shared = [
[e[k] for e in batches] for k in keys_shared]
_assert_type_keys(keys_shared)
for k, v in zip(keys_shared, values_shared):
if all(isinstance(e, (dict, Batch)) for e in v):
self.__dict__[k] = Batch.cat(v)
elif all(isinstance(e, torch.Tensor) for e in v):
self.__dict__[k] = torch.cat(v)
else:
v = np.concatenate(v)
if not issubclass(v.dtype.type, (np.bool_, np.number)):
v = v.astype(np.object)
self.__dict__[k] = v
keys_partial = set.union(*keys_map) - keys_shared
_assert_type_keys(keys_partial)
for k in keys_partial:
is_dict = False
value = None
for i, e in enumerate(batches):
val = e.get(k, None)
if val is not None:
if isinstance(val, (dict, Batch)):
is_dict = True
else: # np.ndarray or torch.Tensor
value = val
break
if is_dict:
self.__dict__[k] = Batch.cat(
[e.get(k, Batch()) for e in batches])
else:
s = 'No support for method "cat" with type '\
f'{type(v)} in class Batch.'
raise TypeError(s)
if isinstance(value, np.ndarray):
arrs = []
for i, e in enumerate(batches):
shape = [lens[i]] + list(value.shape[1:])
pad = np.zeros(shape, dtype=value.dtype)
arrs.append(e.get(k, pad))
self.__dict__[k] = np.concatenate(arrs)
elif isinstance(value, torch.Tensor):
arrs = []
for i, e in enumerate(batches):
shape = [lens[i]] + list(value.shape[1:])
pad = torch.zeros(shape,
dtype=value.dtype,
device=value.device)
arrs.append(e.get(k, pad))
self.__dict__[k] = torch.cat(arrs)
else:
raise TypeError(
f"cannot cat value with type {type(value)}, we only "
"support dict, Batch, np.ndarray, and torch.Tensor")

@staticmethod
def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch':
"""Concatenate a list of :class:`~tianshou.data.Batch` object into a single
new batch.
"""Concatenate a list of :class:`~tianshou.data.Batch` objects into a
single new batch. For keys that are not shared across all batches,
batches that do not have these keys will be padded with zeros of
appropriate shapes. E.g.
::

>>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5])))
>>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))
>>> c = Batch.cat([a, b])
>>> c.a.shape
(7, 4)
>>> c.b.shape
(7, 3)
>>> c.common.c.shape
(7, 5)
"""
batch = Batch()
for batch_ in batches:
if isinstance(batch_, dict):
batch_ = Batch(batch_)
batch.cat_(batch_)
batch.cat_(batches)
return batch

def stack_(self,
batches: List[Union[dict, 'Batch']],
axis: int = 0) -> None:
"""Stack a :class:`~tianshou.data.Batch` object i into current batch.
"""Stack a list of :class:`~tianshou.data.Batch` objects into the current
batch.
"""
if len(self.__dict__) > 0:
batches = [self] + list(batches)
Expand Down Expand Up @@ -566,8 +608,8 @@ def stack_(self,

@staticmethod
def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch':
"""Stack a :class:`~tianshou.data.Batch` object into a single new
batch.
"""Stack a list of :class:`~tianshou.data.Batch` objects into a single
new batch.
"""
batch = Batch()
batch.stack_(batches, axis)
Expand Down Expand Up @@ -611,11 +653,24 @@ def empty(batch: 'Batch', index: Union[
"""
return deepcopy(batch).empty_(index)

def update(self, batch: Optional[Union[dict, 'Batch']] = None,
**kwargs) -> None:
"""Update this batch from another dict/Batch."""
if batch is None:
self.update(kwargs)
return
if isinstance(batch, dict):
batch = Batch(batch)
for k, v in batch.items():
self.__dict__[k] = v
if kwargs:
self.update(kwargs)

def __len__(self) -> int:
"""Return len(self)."""
r = []
for v in self.__dict__.values():
if isinstance(v, Batch) and len(v.__dict__) == 0:
if isinstance(v, Batch) and v.is_empty():
continue
elif hasattr(v, '__len__') and (not isinstance(
v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
Expand All @@ -627,7 +682,9 @@ def __len__(self) -> int:
return min(r)

def is_empty(self):
return len(self.__dict__.keys()) == 0
return not any(
not x.is_empty() if isinstance(x, Batch)
else hasattr(x, '__len__') and len(x) > 0 for x in self.values())

@property
def shape(self) -> List[int]:
Expand Down
21 changes: 7 additions & 14 deletions tianshou/data/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@ def __init__(self, size: int, stack_num: Optional[int] = 0,
super().__init__()
self._maxsize = size
self._stack = stack_num
assert stack_num != 1, \
'stack_num should greater than 1'
assert stack_num != 1, 'stack_num should greater than 1'
self._avail = sample_avail and stack_num > 1
self._avail_index = []
self._save_s_ = not ignore_obs_next
Expand All @@ -136,12 +135,11 @@ def _add_to_buffer(self, name: str, inst: Any) -> None:
except KeyError:
self._meta.__dict__[name] = _create_value(inst, self._maxsize)
value = self._meta.__dict__[name]
if isinstance(inst, np.ndarray) and \
value.shape[1:] != inst.shape:
if isinstance(inst, np.ndarray) and value.shape[1:] != inst.shape:
raise ValueError(
"Cannot add data to a buffer with different shape, key: "
f"{name}, expect shape: {value.shape[1:]}"
f", given shape: {inst.shape}.")
f"{name}, expect shape: {value.shape[1:]}, "
f"given shape: {inst.shape}.")
try:
value[self._index] = inst
except KeyError:
Expand Down Expand Up @@ -357,7 +355,7 @@ def __init__(self, size: int, alpha: float, beta: float,
self._weight_sum = 0.0
self._amortization_freq = 50
self._replace = replace
self._meta.__dict__['weight'] = np.zeros(size, dtype=np.float64)
self._meta.weight = np.zeros(size, dtype=np.float64)

def add(self,
obs: Union[dict, np.ndarray],
Expand All @@ -372,7 +370,7 @@ def add(self,
"""Add a batch of data into replay buffer."""
# we have to sacrifice some convenience for speed
self._weight_sum += np.abs(weight) ** self._alpha - \
self._meta.__dict__['weight'][self._index]
self._meta.weight[self._index]
self._add_to_buffer('weight', np.abs(weight) ** self._alpha)
super().add(obs, act, rew, done, obs_next, info, policy)

Expand Down Expand Up @@ -410,14 +408,9 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
f"batch_size should be less than {len(self)}, \
or set replace=True")
batch = self[indice]
impt_weight = Batch(
impt_weight=(self._size * p) ** (-self._beta))
batch.cat_(impt_weight)
batch["impt_weight"] = (self._size * p) ** (-self._beta)
return batch, indice

def reset(self) -> None:
super().reset()

def update_weight(self, indice: Union[slice, np.ndarray],
new_weight: np.ndarray) -> None:
"""Update priority weight by indice in this buffer.
Expand Down