这是indexloc提供的服务,不要输入任何密码
Skip to content

change Batch.empty to in-place fill; add copy option for Batch construction #110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jul 6, 2020
Merged
89 changes: 69 additions & 20 deletions test/base/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def test_batch_over_batch():
assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3)


def test_batch_cat_and_stack_and_empty():
def test_batch_cat_and_stack():
b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
b12_cat_out = Batch.cat((b1, b2))
Expand Down Expand Up @@ -145,24 +145,6 @@ def test_batch_cat_and_stack_and_empty():
assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
assert b5.b.d[0] == b5_dict[0]['b']['d']
assert b5.b.d[1] == 0.0
b5[1] = Batch.empty(b5[0])
assert np.allclose(b5.a, [False, False])
assert np.allclose(b5.b.c, [2, 0])
assert np.allclose(b5.b.d, [1, 0])
data = Batch(a=[False, True],
b={'c': [2., 'st'], 'd': [1, None], 'e': [2., float('nan')]},
c=np.array([1, 3, 4], dtype=np.int),
t=torch.tensor([4, 5, 6, 7.]))
data[-1] = Batch.empty(data[1])
assert np.allclose(data.c, [1, 3, 0])
assert np.allclose(data.a, [False, False])
assert list(data.b.c) == ['2.0', '']
assert list(data.b.d) == [1, None]
assert np.allclose(data.b.e, [2, 0])
assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
b0 = Batch()
b0.empty_()
assert b0.shape == []


def test_batch_over_batch_to_torch():
Expand Down Expand Up @@ -225,6 +207,71 @@ def test_batch_from_to_numpy_without_copy():
assert c_mem_addr_new == c_mem_addr_orig


def test_batch_copy():
    """Check that ``Batch(copy=...)`` shares or duplicates the stored data.

    With ``copy=False`` the new Batch must reuse the very same objects (and
    hence the same array buffers) as the source; with ``copy=True`` every
    field, the nested sub-batch, and every array buffer must be distinct.
    """
    def buffer_address(arr):
        # address of the first byte of the array's data buffer
        return arr.__array_interface__['data'][0]

    def members(batch):
        # every field checked for identity: c, the sub-batch b, b.a and b.b
        return batch.c, batch.b, batch.b.a, batch.b.b

    def addresses(batch):
        # buffer addresses of the three leaf arrays c, b.a and b.b
        return [buffer_address(leaf)
                for leaf in (batch.c, batch.b.a, batch.b.b)]

    inner = Batch(a=np.array([3, 4, 5]), b=np.array([4, 5, 6]))
    source = Batch({'c': np.array([6, 7, 8]), 'b': inner})
    source_addrs = addresses(source)

    # copy=False: the new Batch aliases every object of the source
    alias = Batch(copy=False, **source)
    alias_addrs = addresses(alias)
    for old, new in zip(members(source), members(alias)):
        assert old is new
    for old, new in zip(source_addrs, alias_addrs):
        assert old == new

    # copy=True: every object and every buffer is freshly allocated
    clone = Batch(copy=True, **source)
    clone_addrs = addresses(clone)
    for old, new in zip(members(source), members(clone)):
        assert old is not new
    for old, new in zip(source_addrs, clone_addrs):
        assert old != new


def test_batch_empty():
    """Exercise ``Batch.empty`` (out-of-place) and ``Batch.empty_`` (in-place).

    Covers boolean, numeric, ``object``-dtype and torch entries, nested
    batches, resetting a single position via ``index=``, and an empty Batch.
    """
    b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
                        {'a': True, 'b': {'c': 3.0}}])
    b5 = Batch(b5_dict)
    b5[1] = Batch.empty(b5[0])
    assert np.allclose(b5.a, [False, False])
    assert np.allclose(b5.b.c, [2, 0])
    assert np.allclose(b5.b.d, [1, 0])
    # ``np.object`` / ``np.int`` were mere aliases of the builtins; they were
    # deprecated in NumPy 1.20 and removed in 1.24, so use the builtins.
    data = Batch(a=[False, True],
                 b={'c': np.array([2., 'st'], dtype=object),
                    'd': [1, None],
                    'e': [2., float('nan')]},
                 c=np.array([1, 3, 4], dtype=int),
                 t=torch.tensor([4, 5, 6, 7.]))
    data[-1] = Batch.empty(data[1])
    assert np.allclose(data.c, [1, 3, 0])
    assert np.allclose(data.a, [False, False])
    assert list(data.b.c) == [2.0, None]
    assert list(data.b.d) == [1, None]
    assert np.allclose(data.b.e, [2, 0])
    assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
    # data[0] builds a new Batch, and resetting it in place only reaches
    # data.t; a, b.c, b.d, b.e and c are left untouched (presumably because
    # data[0].t is a view of data.t while the NumPy entries come back as
    # detached scalars).
    data[0].empty_()
    assert torch.allclose(data.t, torch.tensor([0., 5, 6, 0]))
    # resetting through the parent with index=0 clears position 0 in place
    data.empty_(index=0)
    assert np.allclose(data.c, [0, 3, 0])
    assert list(data.b.c) == [None, None]
    assert list(data.b.d) == [None, None]
    assert list(data.b.e) == [0, 0]
    b0 = Batch()
    b0.empty_()
    assert b0.shape == []


def test_batch_numpy_compatibility():
batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]),
b=Batch(),
Expand All @@ -246,4 +293,6 @@ def test_batch_numpy_compatibility():
test_batch_pickle()
test_batch_from_to_numpy_without_copy()
test_batch_numpy_compatibility()
test_batch_cat_and_stack_and_empty()
test_batch_cat_and_stack()
test_batch_copy()
test_batch_empty()
65 changes: 42 additions & 23 deletions tianshou/data/batch.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
import copy
import pprint
import warnings
import numpy as np
from copy import deepcopy
from functools import reduce
from numbers import Number
from typing import Any, List, Tuple, Union, Iterator, Optional
Expand Down Expand Up @@ -85,8 +85,13 @@ class Batch:
c: '2312312',
)

In short, you can define a :class:`Batch` with any key-value pair. The
current implementation of Tianshou typically use 7 reserved keys in
In short, you can define a :class:`Batch` with any key-value pair.

    For NumPy arrays, only numeric dtypes and ``np.object`` are supported.
    Strings and other data types can still be stored, but they must be held
    in ``np.object`` arrays.

    The current implementation of Tianshou typically uses 7 reserved keys in
:class:`~tianshou.data.Batch`:

* ``obs`` the observation of step :math:`t` ;
Expand Down Expand Up @@ -252,7 +257,10 @@ def __init__(self,
batch_dict: Optional[Union[
dict, 'Batch', Tuple[Union[dict, 'Batch']],
List[Union[dict, 'Batch']], np.ndarray]] = None,
copy: bool = False,
**kwargs) -> None:
if copy:
batch_dict = deepcopy(batch_dict)
if _is_batch_set(batch_dict):
self.stack_(batch_dict)
elif isinstance(batch_dict, (dict, Batch)):
Expand All @@ -264,7 +272,7 @@ def __init__(self,
v = np.array(v)
self.__dict__[k] = v
if len(kwargs) > 0:
self.__init__(kwargs)
self.__init__(kwargs, copy=copy)

def __setattr__(self, key: str, value: Any):
"""self[key] = value"""
Expand Down Expand Up @@ -360,7 +368,7 @@ def __iadd__(self, other: Union['Batch', Number]):
def __add__(self, other: Union['Batch', Number]):
"""Algebraic addition with another :class:`~tianshou.data.Batch`
instance out-of-place."""
return copy.deepcopy(self).__iadd__(other)
return deepcopy(self).__iadd__(other)

def __imul__(self, val: Number):
"""Algebraic multiplication with a scalar value in-place."""
Expand All @@ -372,7 +380,7 @@ def __imul__(self, val: Number):

def __mul__(self, val: Number):
"""Algebraic multiplication with a scalar value out-of-place."""
return copy.deepcopy(self).__imul__(val)
return deepcopy(self).__imul__(val)

def __itruediv__(self, val: Number):
"""Algebraic division wibyth a scalar value in-place."""
Expand All @@ -384,7 +392,7 @@ def __itruediv__(self, val: Number):

def __truediv__(self, val: Number):
"""Algebraic division wibyth a scalar value out-of-place."""
return copy.deepcopy(self).__itruediv__(val)
return deepcopy(self).__itruediv__(val)

def __repr__(self) -> str:
"""Return str(self)."""
Expand Down Expand Up @@ -476,7 +484,7 @@ def cat_(self, batch: 'Batch') -> None:
if v is None:
continue
if not hasattr(self, k) or self.__dict__[k] is None:
self.__dict__[k] = copy.deepcopy(v)
self.__dict__[k] = deepcopy(v)
elif isinstance(v, np.ndarray) and v.ndim > 0:
self.__dict__[k] = np.concatenate([self.__dict__[k], v])
elif isinstance(v, torch.Tensor):
Expand Down Expand Up @@ -537,34 +545,45 @@ def stack(batches: List['Batch'], axis: int = 0) -> 'Batch':
batch.stack_(batches, axis)
return batch

def empty_(self) -> 'Batch':
def empty_(self, index: Union[
str, slice, int, np.integer, np.ndarray, List[int]] = None
) -> 'Batch':
"""Return an empty a :class:`~tianshou.data.Batch` object with 0 or
``None`` filled.
        ``None`` filled. If ``index`` is specified, only the entries selected
        by ``index`` are reset.
"""
for k, v in self.items():
if v is None:
continue
if isinstance(v, Batch):
self.__dict__[k].empty_()
elif isinstance(v, np.ndarray) and v.dtype == np.object:
self.__dict__[k].fill(None)
elif isinstance(v, torch.Tensor): # cannot apply fill_ directly
self.__dict__[k] = torch.zeros_like(self.__dict__[k])
else: # np
self.__dict__[k] *= 0
if hasattr(v, 'dtype') and v.dtype.kind in 'fc':
self.__dict__[k] = np.nan_to_num(self.__dict__[k])
self.__dict__[k].empty_(index=index)
elif isinstance(v, torch.Tensor):
self.__dict__[k][index] = 0
elif isinstance(v, np.ndarray):
if v.dtype == np.object:
self.__dict__[k][index] = None
else:
self.__dict__[k][index] = 0
else: # scalar value
warnings.warn('You are calling Batch.empty on a NumPy scalar, '
'which may cause undefined behaviors.')
if isinstance(v, (np.generic, Number)):
self.__dict__[k] *= 0
if np.isnan(self.__dict__[k]):
self.__dict__[k] = 0
else:
self.__dict__[k] = None
return self

@staticmethod
def empty(batch: 'Batch') -> 'Batch':
def empty(batch: 'Batch', index: Union[
str, slice, int, np.integer, np.ndarray, List[int]] = None
) -> 'Batch':
"""Return an empty :class:`~tianshou.data.Batch` object with 0 or
``None`` filled, the shape is the same as the given
:class:`~tianshou.data.Batch`.
"""
batch = Batch(**batch)
batch.empty_()
return batch
return deepcopy(batch).empty_(index)

def __len__(self) -> int:
"""Return len(self)."""
Expand Down
13 changes: 9 additions & 4 deletions tianshou/data/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,15 @@ def _reset_state(self, id: Union[int, List[int]]) -> None:
return
if isinstance(self.state, list):
self.state[id] = None
elif isinstance(self.state, (torch.Tensor, np.ndarray)):
self.state[id] *= 0
else: # Batch
self.state[id].empty_()
elif isinstance(self.state, torch.Tensor):
self.state[id].zero_()
elif isinstance(self.state, np.ndarray):
if isinstance(self.state.dtype == np.object):
self.state[id] = None
else:
self.state[id] = 0
elif isinstance(self.state, Batch):
self.state.empty_(id)

def collect(self,
n_step: int = 0,
Expand Down