
Saving and loading replay buffer with HDF5 #261


Merged Dec 17, 2020

Changes from all commits (46 commits)
be78460
Add saving and loading of replay buffer to/from HDF5 file.
nicoguertler Dec 4, 2020
224328a
Add test for HDF5-based saving of replay buffer.
nicoguertler Dec 4, 2020
72beac7
Also copy size and index.
nicoguertler Dec 4, 2020
6241681
Remove swap files.
nicoguertler Dec 4, 2020
00965d6
Add dependency on h5py.
nicoguertler Dec 4, 2020
5323977
Adapt to PEP8.
nicoguertler Dec 7, 2020
94d5315
Finalize docstrings.
nicoguertler Dec 7, 2020
c298790
Add support for Batch and torch.Tensor in saving to HDF5.
nicoguertler Dec 7, 2020
396e9de
Adapt to PEP8.
nicoguertler Dec 7, 2020
e96d8d9
Use read_direct even if entry doesn't exist.
nicoguertler Dec 8, 2020
f541948
Load replay buffer into numpy arrays by default.
nicoguertler Dec 8, 2020
cddf131
Adapt to PEP8.
nicoguertler Dec 8, 2020
ff8b112
Version constraint for h5py and execution of test_hdf5 in __main__.
nicoguertler Dec 9, 2020
1bb7ccb
Update test/base/test_buffer.py
nicoguertler Dec 9, 2020
c01c2e1
Update tianshou/data/buffer.py
nicoguertler Dec 9, 2020
cc87c20
Update tianshou/data/buffer.py
nicoguertler Dec 9, 2020
bcd41dd
Update tianshou/data/buffer.py
nicoguertler Dec 9, 2020
2ea6dfd
Update tianshou/data/buffer.py
Trinkle23897 Dec 9, 2020
b30ae5b
Add _hdf5 to loading and saving methods.
nicoguertler Dec 9, 2020
dda680e
Merge branch 'save_load_replay_buffer' of github.com:nicoguertler/tia…
nicoguertler Dec 9, 2020
f655020
Make sure tensor is in main memory before converting to numpy.
nicoguertler Dec 9, 2020
6071cb2
fix test
Trinkle23897 Dec 9, 2020
a554ba8
Add HDF5 example to usage code snippet.
nicoguertler Dec 9, 2020
eab5b69
Merge branch 'save_load_replay_buffer' of github.com:nicoguertler/tia…
nicoguertler Dec 9, 2020
953f9d9
minor fix
Trinkle23897 Dec 9, 2020
36d4754
Merge branch 'save_load_replay_buffer' of github.com:nicoguertler/tia…
Trinkle23897 Dec 9, 2020
b9e14aa
test more
Trinkle23897 Dec 9, 2020
3443a39
Use __getstate__ for saving via HDF5.
nicoguertler Dec 9, 2020
eff1f99
Fix PEP8.
nicoguertler Dec 9, 2020
a807128
Fix type hints.
nicoguertler Dec 9, 2020
b07b40b
PEP8.
nicoguertler Dec 9, 2020
f41a571
Implement fall back to pickle.
nicoguertler Dec 10, 2020
782fef8
Fix formatting.
nicoguertler Dec 10, 2020
fcbc5ac
Fix type hints, add tests, clean up.
nicoguertler Dec 11, 2020
4e130ea
PEP8
nicoguertler Dec 11, 2020
7214503
Fix missing import, implement some suggestions.
nicoguertler Dec 11, 2020
07459d9
Exception when saving, use __name__ and more readable unit test.
nicoguertler Dec 14, 2020
801ca14
Fix typo.
nicoguertler Dec 14, 2020
ea838d2
fix test
Trinkle23897 Dec 15, 2020
335ff26
fix new version of sphinxcontrib-bibtex
Trinkle23897 Dec 15, 2020
5b10e98
fix test
Trinkle23897 Dec 15, 2020
4173092
add extra test
Trinkle23897 Dec 15, 2020
2d45c1a
Fallback to pickle for arrays with data type not supported by HDF5.
nicoguertler Dec 16, 2020
a5d4505
Turn into inner function.
nicoguertler Dec 16, 2020
7ccd995
RuntimeError only for ndarrays and not for Tensors.
nicoguertler Dec 16, 2020
c4197e9
Only use __class__.__name__ when value unclear in advance.
nicoguertler Dec 16, 2020
1 change: 1 addition & 0 deletions .gitignore
@@ -144,3 +144,4 @@ MUJOCO_LOG.TXT
.DS_Store
*.zip
*.pstats
*.swp
9 changes: 9 additions & 0 deletions docs/bibtex.json
@@ -0,0 +1,9 @@
{
    "cited": {
        "tutorials/dqn": [
            "DQN",
            "DDPG",
            "PPO"
        ]
    }
}
1 change: 1 addition & 0 deletions docs/conf.py
@@ -70,6 +70,7 @@
]
)
}
bibtex_bibfiles = ['refs.bib']

# -- Options for HTML output -------------------------------------------------

3 changes: 2 additions & 1 deletion setup.py
@@ -47,10 +47,11 @@ def get_version() -> str:
    install_requires=[
        "gym>=0.15.4",
        "tqdm",
        "numpy",
        "numpy!=1.16.0",  # https://github.com/numpy/numpy/issues/12793
        "tensorboard",
        "torch>=1.4.0",
        "numba>=0.51.0",
        "h5py>=3.1.0"
    ],
    extras_require={
        "dev": [
70 changes: 70 additions & 0 deletions test/base/test_buffer.py
@@ -1,11 +1,15 @@
import os
import torch
import pickle
import pytest
import tempfile
import h5py
import numpy as np
from timeit import timeit

from tianshou.data import Batch, SegmentTree, \
    ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
from tianshou.data.utils.converter import to_hdf5

if __name__ == '__main__':
    from env import MyTestEnv
@@ -278,7 +282,73 @@ def test_pickle():
pbuf.weight[np.arange(len(pbuf))])


def test_hdf5():
    size = 100
    buffers = {
        "array": ReplayBuffer(size, stack_num=2),
        "list": ListReplayBuffer(),
        "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4)
    }
    buffer_types = {k: b.__class__ for k, b in buffers.items()}
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    rew = torch.tensor([1.]).to(device)
    for i in range(4):
        kwargs = {
            'obs': Batch(index=np.array([i])),
            'act': i,
            'rew': rew,
            'done': 0,
            'info': {"number": {"n": i}, 'extra': None},
        }
        buffers["array"].add(**kwargs)
        buffers["list"].add(**kwargs)
        buffers["prioritized"].add(weight=np.random.rand(), **kwargs)

    # save
    paths = {}
    for k, buf in buffers.items():
        f, path = tempfile.mkstemp(suffix='.hdf5')
        os.close(f)
        buf.save_hdf5(path)
        paths[k] = path

    # load replay buffer
    _buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths.keys()}

    # compare
    for k in buffers.keys():
        assert len(_buffers[k]) == len(buffers[k])
        assert np.allclose(_buffers[k].act, buffers[k].act)
        assert _buffers[k].stack_num == buffers[k].stack_num
        assert _buffers[k]._maxsize == buffers[k]._maxsize
        assert _buffers[k]._index == buffers[k]._index
        assert np.all(_buffers[k]._indices == buffers[k]._indices)
    for k in ["array", "prioritized"]:
        assert isinstance(buffers[k].get(0, "info"), Batch)
        assert isinstance(_buffers[k].get(0, "info"), Batch)
    for k in ["array"]:
        assert np.all(
            buffers[k][:].info.number.n == _buffers[k][:].info.number.n)
        assert np.all(
            buffers[k][:].info.extra == _buffers[k][:].info.extra)

    for path in paths.values():
        os.remove(path)

    # raise exception when value cannot be pickled
    data = {"not_supported": lambda x: x*x}
    grp = h5py.Group  # the class (not an instance) suffices: to_hdf5 fails before writing
    with pytest.raises(NotImplementedError):
        to_hdf5(data, grp)
    # ndarray with data type not supported by HDF5 that cannot be pickled
    data = {"not_supported": np.array(lambda x: x*x)}
    grp = h5py.Group
    with pytest.raises(RuntimeError):
        to_hdf5(data, grp)


if __name__ == '__main__':
    test_hdf5()
    test_replaybuffer()
    test_ignore_obs_next()
    test_stack()
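As an aside, the two failure-path checks above pass the h5py.Group class itself rather than a group instance, which only works because to_hdf5 raises before touching the group. A sketch of the same check against a real in-memory group — the core-driver setup and file name here are assumptions, not part of the PR:

import h5py
import pytest

from tianshou.data.utils.converter import to_hdf5

# In-memory HDF5 file (core driver); nothing is written to disk.
with h5py.File("scratch.hdf5", "w", driver="core", backing_store=False) as f:
    grp = f.create_group("data")
    # A lambda cannot be pickled, so to_hdf5 raises NotImplementedError.
    with pytest.raises(NotImplementedError):
        to_hdf5({"not_supported": lambda x: x * x}, grp)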
41 changes: 38 additions & 3 deletions tianshou/data/buffer.py
@@ -1,10 +1,12 @@
import h5py
import torch
import numpy as np
from numbers import Number
from typing import Any, Dict, List, Tuple, Union, Optional

from tianshou.data import Batch, SegmentTree, to_numpy
from tianshou.data.batch import _create_value
from tianshou.data import Batch, SegmentTree, to_numpy
from tianshou.data.utils.converter import to_hdf5, from_hdf5


class ReplayBuffer:
@@ -38,7 +40,10 @@ class ReplayBuffer:
>>> # but there are only three valid items, so len(buf) == 3.
>>> len(buf)
3
>>> pickle.dump(buf, open('buf.pkl', 'wb')) # save to file "buf.pkl"
>>> # save to file "buf.pkl"
>>> pickle.dump(buf, open('buf.pkl', 'wb'))
>>> # save to HDF5 file
>>> buf.save_hdf5('buf.hdf5')
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
@@ -54,7 +59,7 @@ class ReplayBuffer:
0., 0., 0., 0., 0., 0., 0.])

>>> # get a random sample from buffer
>>> # the batch_data is equal to buf[incide].
>>> # the batch_data is equal to buf[indice].
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True, True, True, True])
Expand All @@ -63,6 +68,15 @@ class ReplayBuffer:
>>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl"
>>> len(buf)
3
>>> # load complete buffer from HDF5 file
>>> buf = ReplayBuffer.load_hdf5('buf.hdf5')
>>> len(buf)
3
>>> # load contents of HDF5 file into existing buffer
>>> # (only possible if size of buffer and data in file match)
>>> buf.load_contents_hdf5('buf.hdf5')
>>> len(buf)
3

:class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
(typically for RNN usage, see issue#19), ignoring storing the next
@@ -167,8 +181,14 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
        We need it because pickling buffer does not work out-of-the-box
        ("buffer.__getattr__" is customized).
        """
        self._indices = np.arange(state["_maxsize"])
        self.__dict__.update(state)

    def __getstate__(self) -> dict:
        exclude = {"_indices"}
        state = {k: v for k, v in self.__dict__.items() if k not in exclude}
        return state

    def _add_to_buffer(self, name: str, inst: Any) -> None:
        try:
            value = self._meta.__dict__[name]
@@ -359,6 +379,21 @@ def __getitem__(
            policy=self.get(index, "policy"),
        )

    def save_hdf5(self, path: str) -> None:
        """Save replay buffer to HDF5 file."""
        with h5py.File(path, "w") as f:
            to_hdf5(self.__getstate__(), f)

    @classmethod
    def load_hdf5(
        cls, path: str, device: Optional[str] = None
    ) -> "ReplayBuffer":
        """Load replay buffer from HDF5 file."""
        with h5py.File(path, "r") as f:
            buf = cls.__new__(cls)
            buf.__setstate__(from_hdf5(f, device=device))
        return buf


class ListReplayBuffer(ReplayBuffer):
"""List-based replay buffer.
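Taken together, save_hdf5 and load_hdf5 give a one-line round trip for a buffer. A minimal usage sketch mirroring the class docstring above (the file name buf.hdf5 is arbitrary):

import numpy as np

from tianshou.data import ReplayBuffer

buf = ReplayBuffer(size=10)
for i in range(3):
    buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={})

# Persist the complete buffer state (data plus _maxsize, _index, ...) to HDF5.
buf.save_hdf5('buf.hdf5')

# Restore into a fresh buffer; `device` controls where tensors are loaded.
buf2 = ReplayBuffer.load_hdf5('buf.hdf5', device='cpu')
assert len(buf2) == len(buf)
assert np.allclose(buf2.act, buf.act)

Because load_hdf5 is a classmethod that constructs via cls.__new__, the same call works on subclasses such as PrioritizedReplayBuffer, which is exactly what the unit test above relies on.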
91 changes: 90 additions & 1 deletion tianshou/data/utils/converter.py
@@ -1,8 +1,10 @@
import h5py
import torch
import pickle
import numpy as np
from copy import deepcopy
from numbers import Number
from typing import Union, Optional
from typing import Dict, Union, Optional

from tianshou.data.batch import _parse_value, Batch

@@ -80,3 +82,90 @@ def to_torch_as(
"""
assert isinstance(y, torch.Tensor)
return to_torch(x, dtype=y.dtype, device=y.device)


# Note: object is used as a proxy for objects that can be pickled
# Note: mypy does not support cyclic definition currently
Hdf5ConvertibleValues = Union[  # type: ignore
    int, float, Batch, np.ndarray, torch.Tensor, object,
    'Hdf5ConvertibleType',  # type: ignore
]

Hdf5ConvertibleType = Dict[str, Hdf5ConvertibleValues] # type: ignore


def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None:
    """Copy object into HDF5 group."""

    def to_hdf5_via_pickle(x: object, y: h5py.Group, key: str) -> None:
        """Pickle, convert to numpy array and write to HDF5 dataset."""
        data = np.frombuffer(pickle.dumps(x), dtype=np.byte)
        y.create_dataset(key, data=data)

    for k, v in x.items():
        if isinstance(v, (Batch, dict)):
            # dicts and batches are both represented by groups
            subgrp = y.create_group(k)
            if isinstance(v, Batch):
                subgrp_data = v.__getstate__()
                subgrp.attrs["__data_type__"] = "Batch"
            else:
                subgrp_data = v
            to_hdf5(subgrp_data, subgrp)
        elif isinstance(v, torch.Tensor):
            # PyTorch tensors are written to datasets
            y.create_dataset(k, data=to_numpy(v))
            y[k].attrs["__data_type__"] = "Tensor"
        elif isinstance(v, np.ndarray):
            try:
                # NumPy arrays are written to datasets
                y.create_dataset(k, data=v)
                y[k].attrs["__data_type__"] = "ndarray"
            except TypeError:
                # If the data type is not supported by HDF5, fall back to
                # pickle. This happens if dtype=object (e.g. due to entries
                # being None) and possibly in other cases like structured
                # arrays.
                try:
                    to_hdf5_via_pickle(v, y, k)
                except Exception as e:
                    raise RuntimeError(
                        f"Attempted to pickle {v.__class__.__name__} due to "
                        "data type not supported by HDF5 and failed."
                    ) from e
                y[k].attrs["__data_type__"] = "pickled_ndarray"
        elif isinstance(v, (int, float)):
            # ints and floats are stored as attributes of groups
            y.attrs[k] = v
        else:  # resort to pickle for any other type of object
            try:
                to_hdf5_via_pickle(v, y, k)
            except Exception as e:
                raise NotImplementedError(
                    f"No conversion to HDF5 for object of type '{type(v)}' "
                    "implemented and fallback to pickle failed."
                ) from e
            y[k].attrs["__data_type__"] = v.__class__.__name__


def from_hdf5(
    x: h5py.Group, device: Optional[str] = None
) -> Hdf5ConvertibleType:
    """Restore object from HDF5 group."""
    if isinstance(x, h5py.Dataset):
        # handle datasets
        if x.attrs["__data_type__"] == "ndarray":
            y = np.array(x)
        elif x.attrs["__data_type__"] == "Tensor":
            y = torch.tensor(x, device=device)
        else:
            y = pickle.loads(x[()])
    else:
        # handle groups representing a dict or a Batch
        y = {k: v for k, v in x.attrs.items() if k != "__data_type__"}
        for k, v in x.items():
            y[k] = from_hdf5(v, device)
        if "__data_type__" in x.attrs:
            # if dictionary represents Batch, convert to Batch
            if x.attrs["__data_type__"] == "Batch":
                y = Batch(y)
    return y
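Finally, a round-trip sketch for the converter pair itself; the __data_type__ attribute written by to_hdf5 is what from_hdf5 uses to restore the original Python types (the scratch file name is an assumption):

import h5py
import numpy as np
import torch

from tianshou.data import Batch
from tianshou.data.utils.converter import to_hdf5, from_hdf5

data = {
    "scalar": 1.5,                  # stored as a group attribute
    "array": np.arange(3),          # dataset tagged "ndarray"
    "tensor": torch.ones(2),        # dataset tagged "Tensor"
    "batch": Batch(a=np.zeros(2)),  # subgroup tagged "Batch"
}
with h5py.File("scratch.hdf5", "w") as f:
    to_hdf5(data, f)
with h5py.File("scratch.hdf5", "r") as f:
    restored = from_hdf5(f, device="cpu")

assert isinstance(restored["batch"], Batch)
assert isinstance(restored["tensor"], torch.Tensor)
assert np.allclose(restored["array"], data["array"])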