diff --git a/.gitignore b/.gitignore
index 604f2dfc1..082dcef12 100644
--- a/.gitignore
+++ b/.gitignore
@@ -144,3 +144,4 @@ MUJOCO_LOG.TXT
 .DS_Store
 *.zip
 *.pstats
+*.swp
diff --git a/docs/bibtex.json b/docs/bibtex.json
new file mode 100644
index 000000000..685d5db67
--- /dev/null
+++ b/docs/bibtex.json
@@ -0,0 +1,9 @@
+{
+    "cited": {
+        "tutorials/dqn": [
+            "DQN",
+            "DDPG",
+            "PPO"
+        ]
+    }
+}
\ No newline at end of file
diff --git a/docs/conf.py b/docs/conf.py
index 2169ab8ec..f7bcc562d 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -70,6 +70,7 @@
         ]
     )
 }
+bibtex_bibfiles = ['refs.bib']
 
 # -- Options for HTML output -------------------------------------------------
 
diff --git a/setup.py b/setup.py
index a7410ce8c..5af0b206f 100644
--- a/setup.py
+++ b/setup.py
@@ -47,10 +47,11 @@ def get_version() -> str:
     install_requires=[
         "gym>=0.15.4",
         "tqdm",
-        "numpy",
+        "numpy!=1.16.0",  # https://github.com/numpy/numpy/issues/12793
         "tensorboard",
         "torch>=1.4.0",
         "numba>=0.51.0",
+        "h5py>=3.1.0"
     ],
     extras_require={
         "dev": [
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index d8c46dd77..1695195b4 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -1,11 +1,15 @@
+import os
 import torch
 import pickle
 import pytest
+import tempfile
+import h5py
 import numpy as np
 from timeit import timeit
 
 from tianshou.data import Batch, SegmentTree, \
     ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
+from tianshou.data.utils.converter import to_hdf5
 
 if __name__ == '__main__':
     from env import MyTestEnv
@@ -278,7 +282,73 @@ def test_pickle():
                        pbuf.weight[np.arange(len(pbuf))])
 
 
+def test_hdf5():
+    size = 100
+    buffers = {
+        "array": ReplayBuffer(size, stack_num=2),
+        "list": ListReplayBuffer(),
+        "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4)
+    }
+    buffer_types = {k: b.__class__ for k, b in buffers.items()}
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    rew = torch.tensor([1.]).to(device)
+    for i in range(4):
+        kwargs = {
+            'obs': Batch(index=np.array([i])),
+            'act': i,
+            'rew': rew,
+            'done': 0,
+            'info': {"number": {"n": i}, 'extra': None},
+        }
+        buffers["array"].add(**kwargs)
+        buffers["list"].add(**kwargs)
+        buffers["prioritized"].add(weight=np.random.rand(), **kwargs)
+
+    # save
+    paths = {}
+    for k, buf in buffers.items():
+        f, path = tempfile.mkstemp(suffix='.hdf5')
+        os.close(f)
+        buf.save_hdf5(path)
+        paths[k] = path
+
+    # load replay buffer
+    _buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths.keys()}
+
+    # compare
+    for k in buffers.keys():
+        assert len(_buffers[k]) == len(buffers[k])
+        assert np.allclose(_buffers[k].act, buffers[k].act)
+        assert _buffers[k].stack_num == buffers[k].stack_num
+        assert _buffers[k]._maxsize == buffers[k]._maxsize
+        assert _buffers[k]._index == buffers[k]._index
+        assert np.all(_buffers[k]._indices == buffers[k]._indices)
+    for k in ["array", "prioritized"]:
+        assert isinstance(buffers[k].get(0, "info"), Batch)
+        assert isinstance(_buffers[k].get(0, "info"), Batch)
+    for k in ["array"]:
+        assert np.all(
+            buffers[k][:].info.number.n == _buffers[k][:].info.number.n)
+        assert np.all(
+            buffers[k][:].info.extra == _buffers[k][:].info.extra)
+
+    for path in paths.values():
+        os.remove(path)
+
+    # raise exception when value cannot be pickled
+    data = {"not_supported": lambda x: x*x}
+    grp = h5py.Group
+    with pytest.raises(NotImplementedError):
+        to_hdf5(data, grp)
+    # ndarray with data type not supported by HDF5 that cannot be pickled
+    data = {"not_supported": np.array(lambda x: x*x)}
+    grp = h5py.Group
+    with pytest.raises(RuntimeError):
+        to_hdf5(data, grp)
+
+
 if __name__ == '__main__':
+    test_hdf5()
     test_replaybuffer()
     test_ignore_obs_next()
     test_stack()
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index c8c572a22..74299df9d 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -1,10 +1,12 @@
+import h5py
 import torch
 import numpy as np
 from numbers import Number
 from typing import Any, Dict, List, Tuple, Union, Optional
 
-from tianshou.data import Batch, SegmentTree, to_numpy
 from tianshou.data.batch import _create_value
+from tianshou.data import Batch, SegmentTree, to_numpy
+from tianshou.data.utils.converter import to_hdf5, from_hdf5
 
 
 class ReplayBuffer:
@@ -38,7 +40,10 @@ class ReplayBuffer:
         >>> # but there are only three valid items, so len(buf) == 3.
         >>> len(buf)
         3
-        >>> pickle.dump(buf, open('buf.pkl', 'wb'))  # save to file "buf.pkl"
+        >>> # save to file "buf.pkl"
+        >>> pickle.dump(buf, open('buf.pkl', 'wb'))
+        >>> # save to HDF5 file
+        >>> buf.save_hdf5('buf.hdf5')
         >>> buf2 = ReplayBuffer(size=10)
         >>> for i in range(15):
         ...     buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
@@ -54,7 +59,7 @@
                 0., 0., 0., 0., 0., 0., 0.])
 
         >>> # get a random sample from buffer
-        >>> # the batch_data is equal to buf[incide].
+        >>> # the batch_data is equal to buf[indice].
         >>> batch_data, indice = buf.sample(batch_size=4)
         >>> batch_data.obs == buf[indice].obs
         array([ True, True, True, True])
@@ -63,6 +68,15 @@
         >>> buf = pickle.load(open('buf.pkl', 'rb'))  # load from "buf.pkl"
         >>> len(buf)
         3
+        >>> # load complete buffer from HDF5 file
+        >>> buf = ReplayBuffer.load_hdf5('buf.hdf5')
+        >>> len(buf)
+        3
+        >>> # load contents of HDF5 file into existing buffer
+        >>> # (only possible if size of buffer and data in file match)
+        >>> buf.load_contents_hdf5('buf.hdf5')
+        >>> len(buf)
+        3
 
     :class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
     (typically for RNN usage, see issue#19), ignoring storing the next
@@ -167,8 +181,14 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
         We need it because pickling buffer does not work out-of-the-box
         ("buffer.__getattr__" is customized).
         """
+        self._indices = np.arange(state["_maxsize"])
         self.__dict__.update(state)
 
+    def __getstate__(self) -> dict:
+        exclude = {"_indices"}
+        state = {k: v for k, v in self.__dict__.items() if k not in exclude}
+        return state
+
     def _add_to_buffer(self, name: str, inst: Any) -> None:
         try:
             value = self._meta.__dict__[name]
@@ -359,6 +379,21 @@ def __getitem__(
             policy=self.get(index, "policy"),
         )
 
+    def save_hdf5(self, path: str) -> None:
+        """Save replay buffer to HDF5 file."""
+        with h5py.File(path, "w") as f:
+            to_hdf5(self.__getstate__(), f)
+
+    @classmethod
+    def load_hdf5(
+        cls, path: str, device: Optional[str] = None
+    ) -> "ReplayBuffer":
+        """Load replay buffer from HDF5 file."""
+        with h5py.File(path, "r") as f:
+            buf = cls.__new__(cls)
+            buf.__setstate__(from_hdf5(f, device=device))
+        return buf
+
 
 class ListReplayBuffer(ReplayBuffer):
     """List-based replay buffer.
diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py
index a8e884508..52b0744cf 100644
--- a/tianshou/data/utils/converter.py
+++ b/tianshou/data/utils/converter.py
@@ -1,8 +1,10 @@
+import h5py
 import torch
+import pickle
 import numpy as np
 from copy import deepcopy
 from numbers import Number
-from typing import Union, Optional
+from typing import Dict, Union, Optional
 
 from tianshou.data.batch import _parse_value, Batch
@@ -80,3 +82,90 @@ def to_torch_as(
     """
     assert isinstance(y, torch.Tensor)
     return to_torch(x, dtype=y.dtype, device=y.device)
+
+
+# Note: object is used as a proxy for objects that can be pickled
+# Note: mypy does not support cyclic definition currently
+Hdf5ConvertibleValues = Union[  # type: ignore
+    int, float, Batch, np.ndarray, torch.Tensor, object,
+    'Hdf5ConvertibleType',  # type: ignore
+]
+
+Hdf5ConvertibleType = Dict[str, Hdf5ConvertibleValues]  # type: ignore
+
+
+def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None:
+    """Copy object into HDF5 group."""
+
+    def to_hdf5_via_pickle(x: object, y: h5py.Group, key: str) -> None:
+        """Pickle, convert to numpy array and write to HDF5 dataset."""
+        data = np.frombuffer(pickle.dumps(x), dtype=np.byte)
+        y.create_dataset(key, data=data)
+
+    for k, v in x.items():
+        if isinstance(v, (Batch, dict)):
+            # dicts and batches are both represented by groups
+            subgrp = y.create_group(k)
+            if isinstance(v, Batch):
+                subgrp_data = v.__getstate__()
+                subgrp.attrs["__data_type__"] = "Batch"
+            else:
+                subgrp_data = v
+            to_hdf5(subgrp_data, subgrp)
+        elif isinstance(v, torch.Tensor):
+            # PyTorch tensors are written to datasets
+            y.create_dataset(k, data=to_numpy(v))
+            y[k].attrs["__data_type__"] = "Tensor"
+        elif isinstance(v, np.ndarray):
+            try:
+                # NumPy arrays are written to datasets
+                y.create_dataset(k, data=v)
+                y[k].attrs["__data_type__"] = "ndarray"
+            except TypeError:
+                # If data type is not supported by HDF5 fall back to pickle.
+                # This happens if dtype=object (e.g. due to entries being None)
+                # and possibly in other cases like structured arrays.
+                try:
+                    to_hdf5_via_pickle(v, y, k)
+                except Exception as e:
+                    raise RuntimeError(
+                        f"Attempted to pickle {v.__class__.__name__} due to "
+                        "data type not supported by HDF5 and failed."
+                    ) from e
+                y[k].attrs["__data_type__"] = "pickled_ndarray"
+        elif isinstance(v, (int, float)):
+            # ints and floats are stored as attributes of groups
+            y.attrs[k] = v
+        else:  # resort to pickle for any other type of object
+            try:
+                to_hdf5_via_pickle(v, y, k)
+            except Exception as e:
+                raise NotImplementedError(
+                    f"No conversion to HDF5 for object of type '{type(v)}' "
+                    "implemented and fallback to pickle failed."
+                ) from e
+            y[k].attrs["__data_type__"] = v.__class__.__name__
+
+
+def from_hdf5(
+    x: h5py.Group, device: Optional[str] = None
+) -> Hdf5ConvertibleType:
+    """Restore object from HDF5 group."""
+    if isinstance(x, h5py.Dataset):
+        # handle datasets
+        if x.attrs["__data_type__"] == "ndarray":
+            y = np.array(x)
+        elif x.attrs["__data_type__"] == "Tensor":
+            y = torch.tensor(x, device=device)
+        else:
+            y = pickle.loads(x[()])
+    else:
+        # handle groups representing a dict or a Batch
+        y = {k: v for k, v in x.attrs.items() if k != "__data_type__"}
+        for k, v in x.items():
+            y[k] = from_hdf5(v, device)
+        if "__data_type__" in x.attrs:
+            # if dictionary represents Batch, convert to Batch
+            if x.attrs["__data_type__"] == "Batch":
+                y = Batch(y)
+    return y
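
For quick reference when reviewing, below is a minimal usage sketch of the two interfaces this patch adds: the high-level ReplayBuffer.save_hdf5 / ReplayBuffer.load_hdf5 round-trip and the low-level to_hdf5 / from_hdf5 converters. It is illustrative only and not part of the patch; the file names 'buf.hdf5' and 'demo.hdf5' are arbitrary placeholders, and the snippet assumes a checkout with this patch installed.

# Round-trip a ReplayBuffer through HDF5 (mirrors the docstring example in buffer.py).
import numpy as np
from tianshou.data import ReplayBuffer

buf = ReplayBuffer(size=10)
for i in range(3):
    buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={})
buf.save_hdf5('buf.hdf5')                  # write buffer state to an HDF5 file
buf2 = ReplayBuffer.load_hdf5('buf.hdf5')  # rebuild a fresh buffer from that file
assert len(buf2) == len(buf) and np.allclose(buf2.act, buf.act)

# Round-trip a nested dict/Batch through the low-level converters.
import h5py
from tianshou.data import Batch
from tianshou.data.utils.converter import to_hdf5, from_hdf5

data = {"obs": Batch(index=np.arange(3)), "rew": np.ones(3), "n": 7}
with h5py.File('demo.hdf5', 'w') as f:
    to_hdf5(data, f)   # Batch/dict become groups, arrays become datasets, numbers become attrs
with h5py.File('demo.hdf5', 'r') as f:
    restored = from_hdf5(f)
assert isinstance(restored["obs"], Batch)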