diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst
index 8ea2d272e..a314cdedb 100644
--- a/docs/tutorials/concepts.rst
+++ b/docs/tutorials/concepts.rst
@@ -53,11 +53,163 @@ In short, you can define a :class:`~tianshou.data.Batch` with any key-value pair
Buffer
------
-.. automodule:: tianshou.data.ReplayBuffer
- :members:
- :noindex:
+:class:`~tianshou.data.ReplayBuffer` stores data generated from the interaction between the policy and the environment. It can be considered a specialized form (or manager) of :class:`~tianshou.data.Batch`: it stores all the data in an internal batch with a circular-queue layout.
-Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out :class:`~tianshou.data.ReplayBuffer` for more detail.
+The current implementation of Tianshou typically uses 7 reserved keys in
+:class:`~tianshou.data.Batch`:
+
+* ``obs`` the observation of step :math:`t` ;
+* ``act`` the action of step :math:`t` ;
+* ``rew`` the reward of step :math:`t` ;
+* ``done`` the done flag of step :math:`t` ;
+* ``obs_next`` the observation of step :math:`t+1` ;
+* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` function returns 4 values, and the last one is ``info``);
+* ``policy`` the data computed by policy in step :math:`t`;
+
+The following code snippet illustrates its usage, including:
+
+- basic data storage via ``add()``;
+- attribute access and slicing;
+- sampling from the buffer: ``sample_index(batch_size)`` and ``sample(batch_size)``;
+- getting the previous/next transition index within an episode: ``prev(index)`` and ``next(index)``;
+- saving/loading the buffer via pickle and HDF5.
+
+::
+
+ >>> import pickle, numpy as np
+ >>> from tianshou.data import ReplayBuffer
+ >>> buf = ReplayBuffer(size=20)
+ >>> for i in range(3):
+ ... buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={})
+
+ >>> buf.obs
+ # since we set size = 20, len(buf.obs) == 20.
+ array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
+ >>> # but there are only three valid items, so len(buf) == 3.
+ >>> len(buf)
+ 3
+ >>> # save to file "buf.pkl"
+ >>> pickle.dump(buf, open('buf.pkl', 'wb'))
+ >>> # save to HDF5 file
+ >>> buf.save_hdf5('buf.hdf5')
+
+ >>> buf2 = ReplayBuffer(size=10)
+ >>> for i in range(15):
+ ... done = i % 4 == 0
+ ... buf2.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={})
+ >>> len(buf2)
+ 10
+ >>> buf2.obs
+ # since its size = 10, it only stores the last 10 steps' result.
+ array([10, 11, 12, 13, 14, 5, 6, 7, 8, 9])
+
+ >>> # move buf2's result into buf (keeping the chronological order)
+ >>> buf.update(buf2)
+ >>> buf.obs
+ array([ 0, 1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, 0, 0,
+ 0, 0, 0, 0])
+
+ >>> # get all available indices by using batch_size = 0
+ >>> indice = buf.sample_index(0)
+ >>> indice
+ array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
+ >>> # get one step previous/next transition
+ >>> buf.prev(indice)
+ array([ 0, 0, 1, 2, 3, 4, 5, 7, 7, 8, 9, 11, 11])
+ >>> buf.next(indice)
+ array([ 1, 2, 3, 4, 5, 6, 6, 8, 9, 10, 10, 12, 12])
+
+ >>> # get a random sample from buffer
+ >>> # the batch_data is equal to buf[indice].
+ >>> batch_data, indice = buf.sample(batch_size=4)
+ >>> batch_data.obs == buf[indice].obs
+ array([ True, True, True, True])
+ >>> len(buf)
+ 13
+
+ >>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl"
+ >>> len(buf)
+ 3
+ >>> # load complete buffer from HDF5 file
+ >>> buf = ReplayBuffer.load_hdf5('buf.hdf5')
+ >>> len(buf)
+ 3
+
+:class:`~tianshou.data.ReplayBuffer` also supports frame-stack sampling (typically for RNN usage, see issue#19), not storing the next observation (to save memory in Atari tasks), and multi-modal observations (see issue#38):
+
+.. raw:: html
+
+ <details>
+ <summary>Advanced usage of ReplayBuffer</summary>
+
+.. code-block:: python
+
+ >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
+ >>> for i in range(16):
+ ... done = i % 5 == 0
+ ... ep_len, ep_rew = buf.add(obs={'id': i}, act=i, rew=i,
+ ... done=done, obs_next={'id': i + 1})
+ ... print(i, ep_len, ep_rew)
+ 0 1 0.0
+ 1 0 0.0
+ 2 0 0.0
+ 3 0 0.0
+ 4 0 0.0
+ 5 5 15.0
+ 6 0 0.0
+ 7 0 0.0
+ 8 0 0.0
+ 9 0 0.0
+ 10 5 40.0
+ 11 0 0.0
+ 12 0 0.0
+ 13 0 0.0
+ 14 0 0.0
+ 15 5 65.0
+ >>> print(buf) # you can see obs_next is not saved in buf
+ ReplayBuffer(
+ obs: Batch(
+ id: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8]),
+ ),
+ act: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8]),
+ rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
+ done: array([False, True, False, False, False, False, True, False,
+ False]),
+ info: Batch(),
+ policy: Batch(),
+ )
+ >>> index = np.arange(len(buf))
+ >>> print(buf.get(index, 'obs').id)
+ [[ 7 7 8 9]
+ [ 7 8 9 10]
+ [11 11 11 11]
+ [11 11 11 12]
+ [11 11 12 13]
+ [11 12 13 14]
+ [12 13 14 15]
+ [ 7 7 7 7]
+ [ 7 7 7 8]]
+ >>> # here is another way to get the stacked data
+ >>> # (stack only for obs and obs_next)
+ >>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
+ 0
+ >>> # we can get obs_next through __getitem__, even if it doesn't exist
+ >>> print(buf[:].obs_next.id)
+ [[ 7 8 9 10]
+ [ 7 8 9 10]
+ [11 11 11 12]
+ [11 11 12 13]
+ [11 12 13 14]
+ [12 13 14 15]
+ [12 13 14 15]
+ [ 7 7 7 8]
+ [ 7 7 8 9]]
+
+.. raw:: html
+
+ </details><br>
+
+Tianshou provides other types of data buffers such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on a segment tree and ``numpy.ndarray``), and :class:`~tianshou.data.CachedReplayBuffer` (which collects data from several episodes in parallel while keeping the chronological order, as sketched below). Check out :class:`~tianshou.data.ReplayBuffer` for more detail.
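+
+A minimal sketch of constructing and using these buffers (the arguments below
+follow the test cases in ``test/base/test_buffer.py``):
+
+::
+
+ >>> from tianshou.data import ReplayBuffer, PrioritizedReplayBuffer, CachedReplayBuffer
+ >>> pbuf = PrioritizedReplayBuffer(10, 0.6, 0.4)  # size, alpha, beta
+ >>> # a main buffer of size 10 plus 4 cached buffers of size 5 each
+ >>> cbuf = CachedReplayBuffer(ReplayBuffer(10), cached_buffer_num=4, max_episode_length=5)
+ >>> # data first goes into cached_buffers[1] ...
+ >>> ep_len, ep_rew = cbuf.add(obs=[1], act=[1], rew=[1], done=[0], cached_buffer_ids=[1])
+ >>> # ... and is moved into the main buffer once the episode terminates
+ >>> ep_len, ep_rew = cbuf.add(obs=[2], act=[2], rew=[2], done=[1], cached_buffer_ids=[1])
+ >>> len(cbuf)  # both transitions of the episode now live in the main buffer
+ 2
+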
Policy
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index 1695195b4..04b1928e2 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -7,9 +7,10 @@
import numpy as np
from timeit import timeit
-from tianshou.data import Batch, SegmentTree, \
- ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
from tianshou.data.utils.converter import to_hdf5
+from tianshou.data import Batch, SegmentTree, ReplayBuffer
+from tianshou.data import ListReplayBuffer, PrioritizedReplayBuffer
+from tianshou.data import ReplayBufferManager, CachedReplayBuffer
if __name__ == '__main__':
from env import MyTestEnv
@@ -38,11 +39,14 @@ def test_replaybuffer(size=10, bufsize=20):
assert (data.obs < size).all()
assert (0 <= data.done).all() and (data.done <= 1).all()
b = ReplayBuffer(size=10)
- b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}})
+ # neg bsz should return empty index
+ assert b.sample_index(-1).tolist() == []
+ b.add(1, 1, 1, 1, 'str', {'a': 3, 'b': {'c': 5.0}})
assert b.obs[0] == 1
- assert b.done[0] == 'str'
+ assert b.done[0]
+ assert b.obs_next[0] == 'str'
assert np.all(b.obs[1:] == 0)
- assert np.all(b.done[1:] == np.array(None))
+ assert np.all(b.obs_next[1:] == np.array(None))
assert b.info.a[0] == 3 and b.info.a.dtype == np.integer
assert np.all(b.info.a[1:] == 0)
assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact
@@ -91,7 +95,7 @@ def test_ignore_obs_next(size=10):
assert data.obs_next
-def test_stack(size=5, bufsize=9, stack_num=4):
+def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3):
env = MyTestEnv(size)
buf = ReplayBuffer(bufsize, stack_num=stack_num)
buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
@@ -115,7 +119,9 @@ def test_stack(size=5, bufsize=9, stack_num=4):
_, indice = buf2.sample(0)
assert indice.tolist() == [2, 6]
_, indice = buf2.sample(1)
- assert indice in [2, 6]
+ assert indice[0] in [2, 6]
+ batch, indice = buf2.sample(-1) # neg bsz -> no data
+ assert indice.tolist() == [] and len(batch) == 0
with pytest.raises(IndexError):
buf[bufsize * 2]
@@ -152,6 +158,12 @@ def test_update():
assert len(buf1) == len(buf2)
assert (buf2[0].obs == buf1[1].obs).all()
assert (buf2[-1].obs == buf1[0].obs).all()
+ b = ListReplayBuffer()
+ with pytest.raises(NotImplementedError):
+ b.update(b)
+ b = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
+ with pytest.raises(NotImplementedError):
+ b.update(b)
def test_segtree():
@@ -260,8 +272,7 @@ def test_pickle():
vbuf = ReplayBuffer(size, stack_num=2)
lbuf = ListReplayBuffer()
pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4)
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
- rew = torch.tensor([1.]).to(device)
+ rew = np.array([1, 1])
for i in range(4):
vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0)
for i in range(3):
@@ -287,18 +298,18 @@ def test_hdf5():
buffers = {
"array": ReplayBuffer(size, stack_num=2),
"list": ListReplayBuffer(),
- "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4)
+ "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4),
}
buffer_types = {k: b.__class__ for k, b in buffers.items()}
device = 'cuda' if torch.cuda.is_available() else 'cpu'
- rew = torch.tensor([1.]).to(device)
+ info_t = torch.tensor([1.]).to(device)
for i in range(4):
kwargs = {
'obs': Batch(index=np.array([i])),
'act': i,
- 'rew': rew,
- 'done': 0,
- 'info': {"number": {"n": i}, 'extra': None},
+ 'rew': np.array([1, 2]),
+ 'done': i % 3 == 2,
+ 'info': {"number": {"n": i, "t": info_t}, 'extra': None},
}
buffers["array"].add(**kwargs)
buffers["list"].add(**kwargs)
@@ -320,10 +331,10 @@ def test_hdf5():
assert len(_buffers[k]) == len(buffers[k])
assert np.allclose(_buffers[k].act, buffers[k].act)
assert _buffers[k].stack_num == buffers[k].stack_num
- assert _buffers[k]._maxsize == buffers[k]._maxsize
- assert _buffers[k]._index == buffers[k]._index
+ assert _buffers[k].maxsize == buffers[k].maxsize
assert np.all(_buffers[k]._indices == buffers[k]._indices)
for k in ["array", "prioritized"]:
+ assert _buffers[k]._index == buffers[k]._index
assert isinstance(buffers[k].get(0, "info"), Batch)
assert isinstance(_buffers[k].get(0, "info"), Batch)
for k in ["array"]:
@@ -332,28 +343,350 @@ def test_hdf5():
assert np.all(
buffers[k][:].info.extra == _buffers[k][:].info.extra)
- for path in paths.values():
- os.remove(path)
-
# raise exception when value cannot be pickled
- data = {"not_supported": lambda x: x*x}
+ data = {"not_supported": lambda x: x * x}
grp = h5py.Group
with pytest.raises(NotImplementedError):
to_hdf5(data, grp)
# ndarray with data type not supported by HDF5 that cannot be pickled
- data = {"not_supported": np.array(lambda x: x*x)}
+ data = {"not_supported": np.array(lambda x: x * x)}
grp = h5py.Group
with pytest.raises(RuntimeError):
to_hdf5(data, grp)
+def test_replaybuffermanager():
+ buf = ReplayBufferManager([ReplayBuffer(size=5) for i in range(4)])
+ ep_len, ep_rew = buf.add(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3],
+ done=[0, 0, 1], buffer_ids=[0, 1, 2])
+ assert np.allclose(ep_len, [0, 0, 1]) and np.allclose(ep_rew, [0, 0, 3])
+ with pytest.raises(NotImplementedError):
+ # ReplayBufferManager cannot be updated
+ buf.update(buf)
+ # sample index / prev / next / unfinished_index
+ indice = buf.sample_index(11000)
+ assert np.bincount(indice)[[0, 5, 10]].min() >= 3000 # uniform sample
+ batch, indice = buf.sample(0)
+ assert np.allclose(indice, [0, 5, 10])
+ indice_prev = buf.prev(indice)
+ assert np.allclose(indice_prev, indice), indice_prev
+ indice_next = buf.next(indice)
+ assert np.allclose(indice_next, indice), indice_next
+ assert np.allclose(buf.unfinished_index(), [0, 5])
+ buf.add(obs=[4], act=[4], rew=[4], done=[1], buffer_ids=[3])
+ assert np.allclose(buf.unfinished_index(), [0, 5])
+ batch, indice = buf.sample(10)
+ batch, indice = buf.sample(0)
+ assert np.allclose(indice, [0, 5, 10, 15])
+ indice_prev = buf.prev(indice)
+ assert np.allclose(indice_prev, indice), indice_prev
+ indice_next = buf.next(indice)
+ assert np.allclose(indice_next, indice), indice_next
+ data = np.array([0, 0, 0, 0])
+ buf.add(obs=data, act=data, rew=data, done=data, buffer_ids=[0, 1, 2, 3])
+ buf.add(obs=data, act=data, rew=data, done=1 - data,
+ buffer_ids=[0, 1, 2, 3])
+ assert len(buf) == 12
+ buf.add(obs=data, act=data, rew=data, done=data, buffer_ids=[0, 1, 2, 3])
+ buf.add(obs=data, act=data, rew=data, done=[0, 1, 0, 1],
+ buffer_ids=[0, 1, 2, 3])
+ assert len(buf) == 20
+ indice = buf.sample_index(120000)
+ assert np.bincount(indice).min() >= 5000
+ batch, indice = buf.sample(10)
+ indice = buf.sample_index(0)
+ assert np.allclose(indice, np.arange(len(buf)))
+ # check the actual data stored in buf._meta
+ assert np.allclose(buf.done, [
+ 0, 0, 1, 0, 0,
+ 0, 0, 1, 0, 1,
+ 1, 0, 1, 0, 0,
+ 1, 0, 1, 0, 1,
+ ])
+ assert np.allclose(buf.prev(indice), [
+ 0, 0, 1, 3, 3,
+ 5, 5, 6, 8, 8,
+ 10, 11, 11, 13, 13,
+ 15, 16, 16, 18, 18,
+ ])
+ assert np.allclose(buf.next(indice), [
+ 1, 2, 2, 4, 4,
+ 6, 7, 7, 9, 9,
+ 10, 12, 12, 14, 14,
+ 15, 17, 17, 19, 19,
+ ])
+ assert np.allclose(buf.unfinished_index(), [4, 14])
+ ep_len, ep_rew = buf.add(obs=[1], act=[1], rew=[1], done=[1],
+ buffer_ids=[2])
+ assert np.allclose(ep_len, [3]) and np.allclose(ep_rew, [1])
+ assert np.allclose(buf.unfinished_index(), [4])
+ indice = list(sorted(buf.sample_index(0)))
+ assert np.allclose(indice, np.arange(len(buf)))
+ assert np.allclose(buf.prev(indice), [
+ 0, 0, 1, 3, 3,
+ 5, 5, 6, 8, 8,
+ 14, 11, 11, 13, 13,
+ 15, 16, 16, 18, 18,
+ ])
+ assert np.allclose(buf.next(indice), [
+ 1, 2, 2, 4, 4,
+ 6, 7, 7, 9, 9,
+ 10, 12, 12, 14, 10,
+ 15, 17, 17, 19, 19,
+ ])
+ # corner case: list, int and -1
+ assert buf.prev(-1) == buf.prev([buf.maxsize - 1])[0]
+ assert buf.next(-1) == buf.next([buf.maxsize - 1])[0]
+ batch = buf._meta
+ batch.info.n = np.ones(buf.maxsize)
+ buf.set_batch(batch)
+ assert np.allclose(buf.buffers[-1].info.n, [1] * 5)
+ assert buf.sample_index(-1).tolist() == []
+ assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == np.object
+
+
+def test_cachedbuffer():
+ buf = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
+ assert buf.sample_index(0).tolist() == []
+ # check the normal function/usage/storage in CachedReplayBuffer
+ ep_len, ep_rew = buf.add(obs=[1], act=[1], rew=[1], done=[0],
+ cached_buffer_ids=[1])
+ obs = np.zeros(buf.maxsize)
+ obs[15] = 1
+ indice = buf.sample_index(0)
+ assert np.allclose(indice, [15])
+ assert np.allclose(buf.prev(indice), [15])
+ assert np.allclose(buf.next(indice), [15])
+ assert np.allclose(buf.obs, obs)
+ assert np.allclose(ep_len, [0]) and np.allclose(ep_rew, [0.0])
+ ep_len, ep_rew = buf.add(obs=[2], act=[2], rew=[2], done=[1],
+ cached_buffer_ids=[3])
+ obs[[0, 25]] = 2
+ indice = buf.sample_index(0)
+ assert np.allclose(indice, [0, 15])
+ assert np.allclose(buf.prev(indice), [0, 15])
+ assert np.allclose(buf.next(indice), [0, 15])
+ assert np.allclose(buf.obs, obs)
+ assert np.allclose(ep_len, [1]) and np.allclose(ep_rew, [2.0])
+ assert np.allclose(buf.unfinished_index(), [15])
+ assert np.allclose(buf.sample_index(0), [0, 15])
+ ep_len, ep_rew = buf.add(obs=[3, 4], act=[3, 4], rew=[3, 4],
+ done=[0, 1], cached_buffer_ids=[3, 1])
+ assert np.allclose(ep_len, [0, 2]) and np.allclose(ep_rew, [0, 5.0])
+ obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3]
+ assert np.allclose(buf.obs, obs)
+ assert np.allclose(buf.unfinished_index(), [25])
+ indice = buf.sample_index(0)
+ assert np.allclose(indice, [0, 1, 2, 25])
+ assert np.allclose(buf.done[indice], [1, 0, 1, 0])
+ assert np.allclose(buf.prev(indice), [0, 1, 1, 25])
+ assert np.allclose(buf.next(indice), [0, 2, 2, 25])
+ indice = buf.sample_index(10000)
+ assert np.bincount(indice)[[0, 1, 2, 25]].min() > 2000 # uniform sample
+ # cached buffer with main_buffer size == 0 (no update)
+ # used in test_collector
+ buf = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5)
+ data = np.zeros(4)
+ rew = np.ones([4, 4])
+ buf.add(obs=data, act=data, rew=rew, done=[0, 0, 1, 1], obs_next=data)
+ buf.add(obs=data, act=data, rew=rew, done=[0, 0, 0, 0], obs_next=data)
+ buf.add(obs=data, act=data, rew=rew, done=[1, 1, 1, 1], obs_next=data)
+ buf.add(obs=data, act=data, rew=rew, done=[0, 0, 0, 0], obs_next=data)
+ buf.add(obs=data, act=data, rew=rew, done=[0, 1, 0, 1], obs_next=data)
+ assert np.allclose(buf.done, [
+ 0, 0, 1, 0, 0,
+ 0, 1, 1, 0, 0,
+ 0, 0, 0, 0, 0,
+ 0, 1, 0, 0, 0,
+ ])
+ indice = buf.sample_index(0)
+ assert np.allclose(indice, [0, 1, 10, 11])
+ assert np.allclose(buf.prev(indice), [0, 0, 10, 10])
+ assert np.allclose(buf.next(indice), [1, 1, 11, 11])
+
+
+def test_multibuf_stack():
+ size = 5
+ bufsize = 9
+ stack_num = 4
+ cached_num = 3
+ env = MyTestEnv(size)
+ # test if CachedReplayBuffer can handle stack_num + ignore_obs_next
+ buf4 = CachedReplayBuffer(
+ ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True),
+ cached_num, size)
+ # test if CachedReplayBuffer can handle super corner case:
+ # prio-buffer + stack_num + ignore_obs_next + sample_avail
+ buf5 = CachedReplayBuffer(
+ PrioritizedReplayBuffer(bufsize, 0.6, 0.4, stack_num=stack_num,
+ ignore_obs_next=True, sample_avail=True),
+ cached_num, size)
+ obs = env.reset(1)
+ for i in range(18):
+ obs_next, rew, done, info = env.step(1)
+ obs_list = np.array([obs + size * i for i in range(cached_num)])
+ act_list = [1] * cached_num
+ rew_list = [rew] * cached_num
+ done_list = [done] * cached_num
+ obs_next_list = -obs_list
+ info_list = [info] * cached_num
+ buf4.add(obs_list, act_list, rew_list, done_list,
+ obs_next_list, info_list)
+ buf5.add(obs_list, act_list, rew_list, done_list,
+ obs_next_list, info_list)
+ obs = obs_next
+ if done:
+ obs = env.reset(1)
+ # check the `add` order is correct
+ assert np.allclose(buf4.obs.reshape(-1), [
+ 12, 13, 14, 4, 6, 7, 8, 9, 11, # main_buffer
+ 1, 2, 3, 4, 0, # cached_buffer[0]
+ 6, 7, 8, 9, 0, # cached_buffer[1]
+ 11, 12, 13, 14, 0, # cached_buffer[2]
+ ]), buf4.obs
+ assert np.allclose(buf4.done, [
+ 0, 0, 1, 1, 0, 0, 0, 1, 0, # main_buffer
+ 0, 0, 0, 1, 0, # cached_buffer[0]
+ 0, 0, 0, 1, 0, # cached_buffer[1]
+ 0, 0, 0, 1, 0, # cached_buffer[2]
+ ]), buf4.done
+ assert np.allclose(buf4.unfinished_index(), [10, 15, 20])
+ indice = sorted(buf4.sample_index(0))
+ assert np.allclose(indice, list(range(bufsize)) + [9, 10, 14, 15, 19, 20])
+ assert np.allclose(buf4[indice].obs[..., 0], [
+ [11, 11, 11, 12], [11, 11, 12, 13], [11, 12, 13, 14],
+ [4, 4, 4, 4], [6, 6, 6, 6], [6, 6, 6, 7],
+ [6, 6, 7, 8], [6, 7, 8, 9], [11, 11, 11, 11],
+ [1, 1, 1, 1], [1, 1, 1, 2], [6, 6, 6, 6],
+ [6, 6, 6, 7], [11, 11, 11, 11], [11, 11, 11, 12],
+ ])
+ assert np.allclose(buf4[indice].obs_next[..., 0], [
+ [11, 11, 12, 13], [11, 12, 13, 14], [11, 12, 13, 14],
+ [4, 4, 4, 4], [6, 6, 6, 7], [6, 6, 7, 8],
+ [6, 7, 8, 9], [6, 7, 8, 9], [11, 11, 11, 12],
+ [1, 1, 1, 2], [1, 1, 1, 2], [6, 6, 6, 7],
+ [6, 6, 6, 7], [11, 11, 11, 12], [11, 11, 11, 12],
+ ])
+ assert np.all(buf4.done == buf5.done)
+ indice = buf5.sample_index(0)
+ assert np.allclose(sorted(indice), [2, 7])
+ assert np.all(np.isin(buf5.sample_index(100), indice))
+ # manually change the stack num
+ buf5.stack_num = 2
+ for buf in buf5.buffers:
+ buf.stack_num = 2
+ indice = buf5.sample_index(0)
+ assert np.allclose(sorted(indice), [0, 1, 2, 5, 6, 7, 10, 15, 20])
+ batch, _ = buf5.sample(0)
+ assert np.allclose(buf5[np.arange(buf5.maxsize)].weight, 1)
+ buf5.update_weight(indice, batch.weight * 0)
+ weight = buf5[np.arange(buf5.maxsize)].weight
+ modified_weight = weight[[0, 1, 2, 5, 6, 7]]
+ assert modified_weight.min() == modified_weight.max()
+ assert modified_weight.max() < 1
+ unmodified_weight = weight[[3, 4, 8]]
+ assert unmodified_weight.min() == unmodified_weight.max()
+ assert unmodified_weight.max() < 1
+ cached_weight = weight[9:]
+ assert cached_weight.min() == cached_weight.max() == 1
+ # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next
+ buf6 = CachedReplayBuffer(
+ ReplayBuffer(bufsize, stack_num=stack_num,
+ save_only_last_obs=True, ignore_obs_next=True),
+ cached_num, size)
+ obs = np.random.rand(size, 4, 84, 84)
+ buf6.add(obs=[obs[2], obs[0]], act=[1, 1], rew=[0, 0], done=[0, 1],
+ obs_next=[obs[3], obs[1]], cached_buffer_ids=[1, 2])
+ assert buf6.obs.shape == (buf6.maxsize, 84, 84)
+ assert np.allclose(buf6.obs[0], obs[0, -1])
+ assert np.allclose(buf6.obs[14], obs[2, -1])
+ assert np.allclose(buf6.obs[19], obs[0, -1])
+ assert buf6[0].obs.shape == (4, 84, 84)
+
+
+def test_multibuf_hdf5():
+ size = 100
+ buffers = {
+ "vector": ReplayBufferManager([ReplayBuffer(size) for i in range(4)]),
+ "cached": CachedReplayBuffer(ReplayBuffer(size), 4, size)
+ }
+ buffer_types = {k: b.__class__ for k, b in buffers.items()}
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ info_t = torch.tensor([1.]).to(device)
+ for i in range(4):
+ kwargs = {
+ 'obs': Batch(index=np.array([i])),
+ 'act': i,
+ 'rew': np.array([1, 2]),
+ 'done': i % 3 == 2,
+ 'info': {"number": {"n": i, "t": info_t}, 'extra': None},
+ }
+ buffers["vector"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]),
+ buffer_ids=[0, 1, 2])
+ buffers["cached"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]),
+ cached_buffer_ids=[0, 1, 2])
+
+ # save
+ paths = {}
+ for k, buf in buffers.items():
+ f, path = tempfile.mkstemp(suffix='.hdf5')
+ os.close(f)
+ buf.save_hdf5(path)
+ paths[k] = path
+
+ # load replay buffer
+ _buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths.keys()}
+
+ # compare
+ for k in buffers.keys():
+ assert len(_buffers[k]) == len(buffers[k])
+ assert np.allclose(_buffers[k].act, buffers[k].act)
+ assert _buffers[k].stack_num == buffers[k].stack_num
+ assert _buffers[k].maxsize == buffers[k].maxsize
+ assert np.all(_buffers[k]._indices == buffers[k]._indices)
+ # check shallow copy in ReplayBufferManager
+ for k in ["vector", "cached"]:
+ buffers[k].info.number.n[0] = -100
+ assert buffers[k].buffers[0].info.number.n[0] == -100
+ # check if the buffers still behave normally
+ for k in ["vector", "cached"]:
+ kwargs = {
+ 'obs': Batch(index=np.array([5])),
+ 'act': 5,
+ 'rew': np.array([2, 1]),
+ 'done': False,
+ 'info': {"number": {"n": i}, 'Timelimit.truncate': True},
+ }
+ buffers[k].add(**Batch.cat([[kwargs], [kwargs], [kwargs], [kwargs]]))
+ act = np.zeros(buffers[k].maxsize)
+ if k == "vector":
+ act[np.arange(5)] = np.array([0, 1, 2, 3, 5])
+ act[np.arange(5) + size] = np.array([0, 1, 2, 3, 5])
+ act[np.arange(5) + size * 2] = np.array([0, 1, 2, 3, 5])
+ act[size * 3] = 5
+ elif k == "cached":
+ act[np.arange(9)] = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2])
+ act[np.arange(3) + size] = np.array([3, 5, 2])
+ act[np.arange(3) + size * 2] = np.array([3, 5, 2])
+ act[np.arange(3) + size * 3] = np.array([3, 5, 2])
+ act[size * 4] = 5
+ assert np.allclose(buffers[k].act, act)
+
+ for path in paths.values():
+ os.remove(path)
+
+
if __name__ == '__main__':
- test_hdf5()
test_replaybuffer()
test_ignore_obs_next()
test_stack()
- test_pickle()
test_segtree()
test_priortized_replaybuffer()
test_priortized_replaybuffer(233333, 200000)
test_update()
+ test_pickle()
+ test_hdf5()
+ test_replaybuffermanager()
+ test_cachedbuffer()
+ test_multibuf_stack()
+ test_multibuf_hdf5()
diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py
index d0987cdbf..1a48aaee9 100644
--- a/test/continuous/test_ppo.py
+++ b/test/continuous/test_ppo.py
@@ -18,7 +18,7 @@
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0')
- parser.add_argument('--seed', type=int, default=1626)
+ parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.99)
diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py
index 7020df275..0d03fce97 100644
--- a/test/discrete/test_qrdqn.py
+++ b/test/discrete/test_qrdqn.py
@@ -16,7 +16,7 @@
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0')
- parser.add_argument('--seed', type=int, default=1626)
+ parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000)
diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py
index e51b8d161..368427a19 100644
--- a/tianshou/data/__init__.py
+++ b/tianshou/data/__init__.py
@@ -1,8 +1,8 @@
from tianshou.data.batch import Batch
from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as
from tianshou.data.utils.segtree import SegmentTree
-from tianshou.data.buffer import ReplayBuffer, \
- ListReplayBuffer, PrioritizedReplayBuffer
+from tianshou.data.buffer import ReplayBuffer, ListReplayBuffer, \
+ PrioritizedReplayBuffer, ReplayBufferManager, CachedReplayBuffer
from tianshou.data.collector import Collector
__all__ = [
@@ -14,5 +14,7 @@
"ReplayBuffer",
"ListReplayBuffer",
"PrioritizedReplayBuffer",
+ "ReplayBufferManager",
+ "CachedReplayBuffer",
"Collector",
]
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 74299df9d..4c1ea14a6 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -1,5 +1,6 @@
import h5py
import torch
+import warnings
import numpy as np
from numbers import Number
from typing import Any, Dict, List, Tuple, Union, Optional
@@ -13,121 +14,13 @@ class ReplayBuffer:
""":class:`~tianshou.data.ReplayBuffer` stores data generated from \
interaction between the policy and environment.
- The current implementation of Tianshou typically use 7 reserved keys in
- :class:`~tianshou.data.Batch`:
-
- * ``obs`` the observation of step :math:`t` ;
- * ``act`` the action of step :math:`t` ;
- * ``rew`` the reward of step :math:`t` ;
- * ``done`` the done flag of step :math:`t` ;
- * ``obs_next`` the observation of step :math:`t+1` ;
- * ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` \
- function returns 4 arguments, and the last one is ``info``);
- * ``policy`` the data computed by policy in step :math:`t`;
-
- The following code snippet illustrates its usage:
- ::
-
- >>> import pickle, numpy as np
- >>> from tianshou.data import ReplayBuffer
- >>> buf = ReplayBuffer(size=20)
- >>> for i in range(3):
- ... buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
- >>> buf.obs
- # since we set size = 20, len(buf.obs) == 20.
- array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
- 0., 0., 0., 0.])
- >>> # but there are only three valid items, so len(buf) == 3.
- >>> len(buf)
- 3
- >>> # save to file "buf.pkl"
- >>> pickle.dump(buf, open('buf.pkl', 'wb'))
- >>> # save to HDF5 file
- >>> buf.save_hdf5('buf.hdf5')
- >>> buf2 = ReplayBuffer(size=10)
- >>> for i in range(15):
- ... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
- >>> len(buf2)
- 10
- >>> buf2.obs
- # since its size = 10, it only stores the last 10 steps' result.
- array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.])
-
- >>> # move buf2's result into buf (meanwhile keep it chronologically)
- >>> buf.update(buf2)
- array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
- 0., 0., 0., 0., 0., 0., 0.])
-
- >>> # get a random sample from buffer
- >>> # the batch_data is equal to buf[indice].
- >>> batch_data, indice = buf.sample(batch_size=4)
- >>> batch_data.obs == buf[indice].obs
- array([ True, True, True, True])
- >>> len(buf)
- 13
- >>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl"
- >>> len(buf)
- 3
- >>> # load complete buffer from HDF5 file
- >>> buf = ReplayBuffer.load_hdf5('buf.hdf5')
- >>> len(buf)
- 3
- >>> # load contents of HDF5 file into existing buffer
- >>> # (only possible if size of buffer and data in file match)
- >>> buf.load_contents_hdf5('buf.hdf5')
- >>> len(buf)
- 3
-
- :class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
- (typically for RNN usage, see issue#19), ignoring storing the next
- observation (save memory in atari tasks), and multi-modal observation (see
- issue#38):
- ::
-
- >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
- >>> for i in range(16):
- ... done = i % 5 == 0
- ... buf.add(obs={'id': i}, act=i, rew=i, done=done,
- ... obs_next={'id': i + 1})
- >>> print(buf) # you can see obs_next is not saved in buf
- ReplayBuffer(
- act: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
- done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
- info: Batch(),
- obs: Batch(
- id: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
- ),
- policy: Batch(),
- rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
- )
- >>> index = np.arange(len(buf))
- >>> print(buf.get(index, 'obs').id)
- [[ 7. 7. 8. 9.]
- [ 7. 8. 9. 10.]
- [11. 11. 11. 11.]
- [11. 11. 11. 12.]
- [11. 11. 12. 13.]
- [11. 12. 13. 14.]
- [12. 13. 14. 15.]
- [ 7. 7. 7. 7.]
- [ 7. 7. 7. 8.]]
- >>> # here is another way to get the stacked data
- >>> # (stack only for obs and obs_next)
- >>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
- 0.0
- >>> # we can get obs_next through __getitem__, even if it doesn't exist
- >>> print(buf[:].obs_next.id)
- [[ 7. 8. 9. 10.]
- [ 7. 8. 9. 10.]
- [11. 11. 11. 12.]
- [11. 11. 12. 13.]
- [11. 12. 13. 14.]
- [12. 13. 14. 15.]
- [12. 13. 14. 15.]
- [ 7. 7. 7. 8.]
- [ 7. 7. 8. 9.]]
-
- :param int size: the size of replay buffer.
+ ReplayBuffer can be considered a specialized form (or manager) of Batch: it
+ stores all the data in an internal batch with a circular-queue layout.
+
+ For example usage of ReplayBuffer, please check out the Buffer section in
+ :doc:`/tutorials/concepts`.
+
+ :param int size: the maximum size of replay buffer.
:param int stack_num: the frame-stack sampling argument, should be greater
than or equal to 1, defaults to 1 (no stacking).
:param bool ignore_obs_next: whether to store obs_next, defaults to False.
@@ -136,9 +29,11 @@ class ReplayBuffer:
False.
:param bool sample_avail: the parameter indicating sampling only available
index when using frame-stack sampling method, defaults to False.
- This feature is not supported in Prioritized Replay Buffer currently.
"""
+ _reserved_keys = ("obs", "act", "rew", "done",
+ "obs_next", "info", "policy")
+
def __init__(
self,
size: int,
@@ -147,16 +42,20 @@ def __init__(
save_only_last_obs: bool = False,
sample_avail: bool = False,
) -> None:
+ self.options: Dict[str, Any] = {
+ "stack_num": stack_num,
+ "ignore_obs_next": ignore_obs_next,
+ "save_only_last_obs": save_only_last_obs,
+ "sample_avail": sample_avail,
+ }
super().__init__()
- self._maxsize = size
- self._indices = np.arange(size)
+ self.maxsize = size
+ assert stack_num > 0, "stack_num should be greater than 0"
self.stack_num = stack_num
- self._avail = sample_avail and stack_num > 1
- self._avail_index: List[int] = []
- self._save_s_ = not ignore_obs_next
- self._last_obs = save_only_last_obs
- self._index = 0
- self._size = 0
+ self._indices = np.arange(size)
+ self._save_obs_next = not ignore_obs_next
+ self._save_only_last_obs = save_only_last_obs
+ self._sample_avail = sample_avail
self._meta: Batch = Batch()
self.reset()
@@ -181,20 +80,92 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
We need it because pickling buffer does not work out-of-the-box
("buffer.__getattr__" is customized).
"""
- self._indices = np.arange(state["_maxsize"])
self.__dict__.update(state)
+ # compatible with version == 0.3.1's HDF5 data format
+ self._indices = np.arange(self.maxsize)
+
+ def __setattr__(self, key: str, value: Any) -> None:
+ """Set self.key = value."""
+ assert key not in self._reserved_keys, (
+ "key '{}' is reserved and cannot be assigned".format(key))
+ super().__setattr__(key, value)
+
+ def save_hdf5(self, path: str) -> None:
+ """Save replay buffer to HDF5 file."""
+ with h5py.File(path, "w") as f:
+ to_hdf5(self.__dict__, f)
+
+ @classmethod
+ def load_hdf5(
+ cls, path: str, device: Optional[str] = None
+ ) -> "ReplayBuffer":
+ """Load replay buffer from HDF5 file."""
+ with h5py.File(path, "r") as f:
+ buf = cls.__new__(cls)
+ buf.__setstate__(from_hdf5(f, device=device))
+ return buf
+
+ def reset(self) -> None:
+ """Clear all the data in replay buffer and episode statistics."""
+ self._index = self._size = 0
+ self._episode_length, self._episode_reward = 0, 0.0
+
+ def set_batch(self, batch: Batch) -> None:
+ """Manually choose the batch you want the ReplayBuffer to manage."""
+ assert len(batch) == self.maxsize and \
+ set(batch.keys()).issubset(self._reserved_keys), \
+ "Input batch doesn't meet ReplayBuffer's data form requirement."
+ self._meta = batch
+
+ def unfinished_index(self) -> np.ndarray:
+ """Return the index of unfinished episode."""
+ last = (self._index - 1) % self._size if self._size else 0
+ return np.array(
+ [last] if not self.done[last] and self._size else [], np.int)
+
+ def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
+ """Return the index of previous transition.
+
+ The index won't be modified if it is the beginning of an episode.
+ """
+ index = (index - 1) % self._size
+ end_flag = self.done[index] | np.isin(index, self.unfinished_index())
+ return (index + end_flag) % self._size
- def __getstate__(self) -> dict:
- exclude = {"_indices"}
- state = {k: v for k, v in self.__dict__.items() if k not in exclude}
- return state
+ def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
+ """Return the index of next transition.
+
+ The index won't be modified if it is the end of an episode.
+ """
+ end_flag = self.done[index] | np.isin(index, self.unfinished_index())
+ return (index + (1 - end_flag)) % self._size
+
+ def update(self, buffer: "ReplayBuffer") -> None:
+ """Move the data from the given buffer to current buffer."""
+ if len(buffer) == 0 or self.maxsize == 0:
+ return
+ stack_num, buffer.stack_num = buffer.stack_num, 1
+ save_only_last_obs = self._save_only_last_obs
+ self._save_only_last_obs = False
+ indices = buffer.sample_index(0) # get all available indices
+ for i in indices:
+ self.add(**buffer[i]) # type: ignore
+ buffer.stack_num = stack_num
+ self._save_only_last_obs = save_only_last_obs
+
+ def _buffer_allocator(self, key: List[str], value: Any) -> None:
+ """Allocate memory on buffer._meta for new (key, value) pair."""
+ data = self._meta
+ for k in key[:-1]:
+ data = data[k]
+ data[key[-1]] = _create_value(value, self.maxsize)
def _add_to_buffer(self, name: str, inst: Any) -> None:
try:
value = self._meta.__dict__[name]
except KeyError:
- self._meta.__dict__[name] = _create_value(inst, self._maxsize)
- value = self._meta.__dict__[name]
+ self._buffer_allocator([name], inst)
+ value = self._meta[name]
if isinstance(inst, (torch.Tensor, np.ndarray)):
if inst.shape != value.shape[1:]:
raise ValueError(
@@ -203,33 +174,10 @@ def _add_to_buffer(self, name: str, inst: Any) -> None:
)
try:
value[self._index] = inst
- except KeyError:
- for key in set(inst.keys()).difference(value.__dict__.keys()):
- value.__dict__[key] = _create_value(inst[key], self._maxsize)
- value[self._index] = inst
-
- @property
- def stack_num(self) -> int:
- return self._stack
-
- @stack_num.setter
- def stack_num(self, num: int) -> None:
- assert num > 0, "stack_num should greater than 0"
- self._stack = num
-
- def update(self, buffer: "ReplayBuffer") -> None:
- """Move the data from the given buffer to self."""
- if len(buffer) == 0:
- return
- i = begin = buffer._index % len(buffer)
- stack_num_orig = buffer.stack_num
- buffer.stack_num = 1
- while True:
- self.add(**buffer[i]) # type: ignore
- i = (i + 1) % len(buffer)
- if i == begin:
- break
- buffer.stack_num = stack_num_orig
+ except KeyError: # inst is a dict/Batch
+ for key in set(inst.keys()).difference(value.keys()):
+ self._buffer_allocator([name, key], inst[key])
+ self._meta[name][self._index] = inst
def add(
self,
@@ -241,117 +189,110 @@ def add(
info: Optional[Union[dict, Batch]] = {},
policy: Optional[Union[dict, Batch]] = {},
**kwargs: Any,
- ) -> None:
- """Add a batch of data into replay buffer."""
+ ) -> Tuple[int, Union[float, np.ndarray]]:
+ """Add a batch of data into replay buffer.
+
+ Return (episode_length, episode_reward) if one episode is terminated,
+ otherwise return (0, 0.0).
+ """
assert isinstance(
info, (dict, Batch)
), "You should return a dict in the last argument of env.step()."
- if self._last_obs:
+ if self._save_only_last_obs:
obs = obs[-1]
self._add_to_buffer("obs", obs)
self._add_to_buffer("act", act)
- # make sure the reward is a float instead of an int
- self._add_to_buffer("rew", rew * 1.0) # type: ignore
- self._add_to_buffer("done", done)
- if self._save_s_:
+ # make sure the data type of reward is float instead of int
+ # but rew may be an np.ndarray, so we cannot use float(rew)
+ rew = rew * 1.0 # type: ignore
+ self._add_to_buffer("rew", rew)
+ self._add_to_buffer("done", bool(done)) # done should be a bool scalar
+ if self._save_obs_next:
if obs_next is None:
obs_next = Batch()
- elif self._last_obs:
+ elif self._save_only_last_obs:
obs_next = obs_next[-1]
self._add_to_buffer("obs_next", obs_next)
self._add_to_buffer("info", info)
self._add_to_buffer("policy", policy)
- # maintain available index for frame-stack sampling
- if self._avail:
- # update current frame
- avail = sum(self.done[i] for i in range(
- self._index - self.stack_num + 1, self._index)) == 0
- if self._size < self.stack_num - 1:
- avail = False
- if avail and self._index not in self._avail_index:
- self._avail_index.append(self._index)
- elif not avail and self._index in self._avail_index:
- self._avail_index.remove(self._index)
- # remove the later available frame because of broken storage
- t = (self._index + self.stack_num - 1) % self._maxsize
- if t in self._avail_index:
- self._avail_index.remove(t)
-
- if self._maxsize > 0:
- self._size = min(self._size + 1, self._maxsize)
- self._index = (self._index + 1) % self._maxsize
+ if self.maxsize > 0:
+ self._size = min(self._size + 1, self.maxsize)
+ self._index = (self._index + 1) % self.maxsize
+ else: # TODO: remove this after deleting ListReplayBuffer
+ self._size = self._index = self._size + 1
+
+ self._episode_reward += rew
+ self._episode_length += 1
+
+ if done:
+ result = self._episode_length, self._episode_reward
+ self._episode_length, self._episode_reward = 0, 0.0
+ return result
else:
- self._size = self._index = self._index + 1
+ return 0, self._episode_reward * 0.0
- def reset(self) -> None:
- """Clear all the data in replay buffer."""
- self._index = 0
- self._size = 0
- self._avail_index = []
+ def sample_index(self, batch_size: int) -> np.ndarray:
+ """Get a random sample of index with size = batch_size.
+
+ Return all available indices in the buffer if batch_size is 0; return
+ an empty numpy array if batch_size < 0 or no available index can be
+ sampled.
+ """
+ if self.stack_num == 1 or not self._sample_avail: # the most common case
+ if batch_size > 0:
+ return np.random.choice(self._size, batch_size)
+ elif batch_size == 0: # construct current available indices
+ return np.concatenate([
+ np.arange(self._index, self._size),
+ np.arange(self._index)])
+ else:
+ return np.array([], np.int)
+ else:
+ if batch_size < 0:
+ return np.array([], np.int)
+ all_indices = prev_indices = np.concatenate([
+ np.arange(self._index, self._size), np.arange(self._index)])
+ for _ in range(self.stack_num - 2):
+ prev_indices = self.prev(prev_indices)
+ all_indices = all_indices[prev_indices != self.prev(prev_indices)]
+ if batch_size > 0:
+ return np.random.choice(all_indices, batch_size)
+ else:
+ return all_indices
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
- """Get a random sample from buffer with size equal to batch_size.
+ """Get a random sample from buffer with size = batch_size.
Return all the data in the buffer if batch_size is 0.
:return: Sample data and its corresponding index inside the buffer.
"""
- if batch_size > 0:
- _all = self._avail_index if self._avail else self._size
- indice = np.random.choice(_all, batch_size)
- else:
- if self._avail:
- indice = np.array(self._avail_index)
- else:
- indice = np.concatenate([
- np.arange(self._index, self._size),
- np.arange(0, self._index),
- ])
- assert len(indice) > 0, "No available indice can be sampled."
- return self[indice], indice
+ indices = self.sample_index(batch_size)
+ return self[indices], indices
def get(
self,
- indice: Union[slice, int, np.integer, np.ndarray],
+ index: Union[int, np.integer, np.ndarray],
key: str,
stack_num: Optional[int] = None,
) -> Union[Batch, np.ndarray]:
"""Return the stacked result.
E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the
- indice. The stack_num (here equals to 4) is given from buffer
- initialization procedure.
+ index.
"""
if stack_num is None:
stack_num = self.stack_num
- if stack_num == 1: # the most often case
- if key != "obs_next" or self._save_s_:
- val = self._meta.__dict__[key]
- try:
- return val[indice]
- except IndexError as e:
- if not (isinstance(val, Batch) and val.is_empty()):
- raise e # val != Batch()
- return Batch()
- indice = self._indices[:self._size][indice]
- done = self._meta.__dict__["done"]
- if key == "obs_next" and not self._save_s_:
- indice += 1 - done[indice].astype(np.int)
- indice[indice == self._size] = 0
- key = "obs"
- val = self._meta.__dict__[key]
+ val = self._meta[key]
try:
- if stack_num == 1:
- return val[indice]
+ if stack_num == 1: # the most common case
+ return val[index]
stack: List[Any] = []
+ indice = np.asarray(index)
for _ in range(stack_num):
stack = [val[indice]] + stack
- pre_indice = np.asarray(indice - 1)
- pre_indice[pre_indice == -1] = self._size - 1
- indice = np.asarray(
- pre_indice + done[pre_indice].astype(np.int))
- indice[indice == self._size] = 0
+ indice = self.prev(indice)
if isinstance(val, Batch):
return Batch.stack(stack, axis=indice.ndim)
else:
@@ -369,31 +310,24 @@ def __getitem__(
If stack_num is larger than 1, return the stacked obs and obs_next with
shape (batch, len, ...).
"""
+ if isinstance(index, slice): # change slice to np array
+ index = self._indices[:len(self)][index]
+ # raise KeyError first instead of AttributeError, to support np.array
+ obs = self.get(index, "obs")
+ if self._save_obs_next:
+ obs_next = self.get(index, "obs_next")
+ else:
+ obs_next = self.get(self.next(index), "obs")
return Batch(
- obs=self.get(index, "obs"),
+ obs=obs,
act=self.act[index],
rew=self.rew[index],
done=self.done[index],
- obs_next=self.get(index, "obs_next"),
+ obs_next=obs_next,
info=self.get(index, "info"),
policy=self.get(index, "policy"),
)
- def save_hdf5(self, path: str) -> None:
- """Save replay buffer to HDF5 file."""
- with h5py.File(path, "w") as f:
- to_hdf5(self.__getstate__(), f)
-
- @classmethod
- def load_hdf5(
- cls, path: str, device: Optional[str] = None
- ) -> "ReplayBuffer":
- """Load replay buffer from HDF5 file."""
- with h5py.File(path, "r") as f:
- buf = cls.__new__(cls)
- buf.__setstate__(from_hdf5(f, device=device))
- return buf
-
class ListReplayBuffer(ReplayBuffer):
"""List-based replay buffer.
@@ -411,24 +345,27 @@ class ListReplayBuffer(ReplayBuffer):
"""
def __init__(self, **kwargs: Any) -> None:
+ warnings.warn("ListReplayBuffer will be replaced in version 0.4.0.")
super().__init__(size=0, ignore_obs_next=False, **kwargs)
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
raise NotImplementedError("ListReplayBuffer cannot be sampled!")
- def _add_to_buffer(
- self, name: str, inst: Union[dict, Batch, np.ndarray, float, int, bool]
- ) -> None:
- if self._meta.__dict__.get(name) is None:
+ def _add_to_buffer(self, name: str, inst: Any) -> None:
+ if self._meta.get(name) is None:
self._meta.__dict__[name] = []
- self._meta.__dict__[name].append(inst)
+ self._meta[name].append(inst)
def reset(self) -> None:
- self._index = self._size = 0
- for k in list(self._meta.__dict__.keys()):
- if isinstance(self._meta.__dict__[k], list):
+ super().reset()
+ for k in self._meta.keys():
+ if isinstance(self._meta[k], list):
self._meta.__dict__[k] = []
+ def update(self, buffer: ReplayBuffer) -> None:
+ """The ListReplayBuffer cannot be updated by any buffer."""
+ raise NotImplementedError
+
class PrioritizedReplayBuffer(ReplayBuffer):
"""Implementation of Prioritized Experience Replay. arXiv:1511.05952.
@@ -464,8 +401,7 @@ def add(
policy: Optional[Union[dict, Batch]] = {},
weight: Optional[Union[Number, np.number]] = None,
**kwargs: Any,
- ) -> None:
- """Add a batch of data into replay buffer."""
+ ) -> Tuple[int, Union[float, np.ndarray]]:
if weight is None:
weight = self._max_prio
else:
@@ -473,60 +409,289 @@ def add(
self._max_prio = max(self._max_prio, weight)
self._min_prio = min(self._min_prio, weight)
self.weight[self._index] = weight ** self._alpha
- super().add(obs, act, rew, done, obs_next, info, policy, **kwargs)
+ return super().add(obs, act, rew, done, obs_next,
+ info, policy, **kwargs)
- def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
- """Get a random sample from buffer with priority probability.
-
- Return all the data in the buffer if batch_size is 0.
+ def sample_index(self, batch_size: int) -> np.ndarray:
+ if batch_size > 0 and self._size > 0:
+ scalar = np.random.rand(batch_size) * self.weight.reduce()
+ return self.weight.get_prefix_sum_idx(scalar)
+ else:
+ return super().sample_index(batch_size)
- :return: Sample data and its corresponding index inside the buffer.
+ def get_weight(
+ self, index: Union[slice, int, np.integer, np.ndarray]
+ ) -> np.ndarray:
+ """Get the importance sampling weight.
The "weight" in the returned Batch is the weight on loss function
to de-bias the sampling process (some transition tuples are sampled
more often so their losses are weighted less).
"""
- assert self._size > 0, "Cannot sample a buffer with 0 size!"
- if batch_size == 0:
- indice = np.concatenate([
- np.arange(self._index, self._size),
- np.arange(0, self._index),
- ])
- else:
- scalar = np.random.rand(batch_size) * self.weight.reduce()
- indice = self.weight.get_prefix_sum_idx(scalar)
- batch = self[indice]
# important sampling weight calculation
# original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
# simplified formula: (p_j/p_min)**(-beta)
- batch.weight = (batch.weight / self._min_prio) ** (-self._beta)
- return batch, indice
+ return (self.weight[index] / self._min_prio) ** (-self._beta)
def update_weight(
self,
- indice: Union[np.ndarray],
- new_weight: Union[np.ndarray, torch.Tensor]
+ index: np.ndarray,
+ new_weight: Union[np.ndarray, torch.Tensor],
) -> None:
- """Update priority weight by indice in this buffer.
+ """Update priority weight by index in this buffer.
- :param np.ndarray indice: indice you want to update weight.
+ :param np.ndarray index: index you want to update weight.
:param np.ndarray new_weight: new priority weight you want to update.
"""
weight = np.abs(to_numpy(new_weight)) + self.__eps
- self.weight[indice] = weight ** self._alpha
+ self.weight[index] = weight ** self._alpha
self._max_prio = max(self._max_prio, weight.max())
self._min_prio = min(self._min_prio, weight.min())
def __getitem__(
self, index: Union[slice, int, np.integer, np.ndarray]
) -> Batch:
- return Batch(
- obs=self.get(index, "obs"),
- act=self.act[index],
- rew=self.rew[index],
- done=self.done[index],
- obs_next=self.get(index, "obs_next"),
- info=self.get(index, "info"),
- policy=self.get(index, "policy"),
- weight=self.weight[index],
- )
+ batch = super().__getitem__(index)
+ batch.weight = self.get_weight(index)
+ return batch
+
+
+class ReplayBufferManager(ReplayBuffer):
+ """ReplayBufferManager contains a list of ReplayBuffer.
+
+ These replay buffers have a contiguous memory layout, and the storage space
+ of each sub-buffer is a shallow copy of the topmost memory.
+
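+ A minimal usage sketch (mirroring ``test/base/test_buffer.py``), where data
+ for several sub-buffers is added in a single call via ``buffer_ids``::
+
+ >>> buf = ReplayBufferManager([ReplayBuffer(size=5) for _ in range(4)])
+ >>> ep_len, ep_rew = buf.add(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3],
+ ... done=[0, 0, 1], buffer_ids=[0, 1, 2])
+ >>> len(buf)
+ 3
+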
+ :param buffer_list: a list of ReplayBuffers to be managed.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed
+ explanation.
+ """
+
+ def __init__(self, buffer_list: List[ReplayBuffer], **kwargs: Any) -> None:
+ self.buffer_num = len(buffer_list)
+ self.buffers = buffer_list
+ self._offset = []
+ offset = 0
+ for buf in self.buffers:
+ # overwrite sub-buffers' _buffer_allocator so that
+ # the top buffer can allocate new memory for all sub-buffers
+ buf._buffer_allocator = self._buffer_allocator # type: ignore
+ assert buf._meta.is_empty()
+ self._offset.append(offset)
+ offset += buf.maxsize
+ super().__init__(size=offset, **kwargs)
+
+ def __len__(self) -> int:
+ return sum([len(buf) for buf in self.buffers])
+
+ def reset(self) -> None:
+ for buf in self.buffers:
+ buf.reset()
+
+ def _set_batch_for_children(self) -> None:
+ for offset, buf in zip(self._offset, self.buffers):
+ buf.set_batch(self._meta[offset:offset + buf.maxsize])
+
+ def set_batch(self, batch: Batch) -> None:
+ super().set_batch(batch)
+ self._set_batch_for_children()
+
+ def unfinished_index(self) -> np.ndarray:
+ return np.concatenate([
+ buf.unfinished_index() + offset
+ for offset, buf in zip(self._offset, self.buffers)])
+
+ def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
+ index = np.asarray(index) % self.maxsize
+ prev_indices = np.zeros_like(index)
+ for offset, buf in zip(self._offset, self.buffers):
+ mask = (offset <= index) & (index < offset + buf.maxsize)
+ if np.any(mask):
+ prev_indices[mask] = buf.prev(index[mask] - offset) + offset
+ return prev_indices
+
+ def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
+ index = np.asarray(index) % self.maxsize
+ next_indices = np.zeros_like(index)
+ for offset, buf in zip(self._offset, self.buffers):
+ mask = (offset <= index) & (index < offset + buf.maxsize)
+ if np.any(mask):
+ next_indices[mask] = buf.next(index[mask] - offset) + offset
+ return next_indices
+
+ def update(self, buffer: ReplayBuffer) -> None:
+ """The ReplayBufferManager cannot be updated by any buffer."""
+ raise NotImplementedError
+
+ def _buffer_allocator(self, key: List[str], value: Any) -> None:
+ super()._buffer_allocator(key, value)
+ self._set_batch_for_children()
+
+ def add( # type: ignore
+ self,
+ obs: Any,
+ act: Any,
+ rew: np.ndarray,
+ done: np.ndarray,
+ obs_next: Any = Batch(),
+ info: Optional[Batch] = Batch(),
+ policy: Optional[Batch] = Batch(),
+ buffer_ids: Optional[Union[np.ndarray, List[int]]] = None,
+ **kwargs: Any
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """Add a batch of data into ReplayBufferManager.
+
+ The length (first dimension) of each data field must equal the length of
+ buffer_ids. By default, buffer_ids is [0, 1, ..., buffer_num - 1].
+
+ Return the arrays of episode_length and episode_reward with shape
+ (len(buffer_ids), ...), where (episode_length[i], episode_reward[i])
+ is the episode result corresponding to buffer_ids[i].
+ """
+ if buffer_ids is None:
+ buffer_ids = np.arange(self.buffer_num)
+ # assume each element in buffer_ids is unique
+ assert np.bincount(buffer_ids).max() == 1
+ batch = Batch(obs=obs, act=act, rew=rew, done=done,
+ obs_next=obs_next, info=info, policy=policy)
+ assert len(buffer_ids) == len(batch)
+ episode_lengths = [] # (len(buffer_ids),)
+ episode_rewards = [] # (len(buffer_ids), ...)
+ for batch_idx, buffer_id in enumerate(buffer_ids):
+ length, reward = self.buffers[buffer_id].add(**batch[batch_idx])
+ episode_lengths.append(length)
+ episode_rewards.append(reward)
+ return np.stack(episode_lengths), np.stack(episode_rewards)
+
+ def sample_index(self, batch_size: int) -> np.ndarray:
+ if batch_size < 0:
+ return np.array([], np.int)
+ if self._sample_avail and self.stack_num > 1:
+ all_indices = np.concatenate([
+ buf.sample_index(0) + offset
+ for offset, buf in zip(self._offset, self.buffers)])
+ if batch_size == 0:
+ return all_indices
+ else:
+ return np.random.choice(all_indices, batch_size)
+ if batch_size == 0: # get all available indices
+ sample_num = np.zeros(self.buffer_num, np.int)
+ else:
+ buffer_lens = np.array([len(buf) for buf in self.buffers])
+ buffer_idx = np.random.choice(self.buffer_num, batch_size,
+ p=buffer_lens / buffer_lens.sum())
+ sample_num = np.bincount(buffer_idx, minlength=self.buffer_num)
+ # avoid batch_size > 0 with sample_num == 0, since a child buffer's
+ # sample_index(0) would return all of its available data
+ sample_num[sample_num == 0] = -1
+
+ return np.concatenate([
+ buf.sample_index(bsz) + offset
+ for offset, buf, bsz in zip(self._offset, self.buffers, sample_num)
+ ])
+
+
+class CachedReplayBuffer(ReplayBufferManager):
+ """CachedReplayBuffer contains a given main buffer and n cached buffers, \
+ cached_buffer_num * ReplayBuffer(size=max_episode_length).
+
+ The memory layout is: ``| main_buffer | cached_buffers[0] |
+ cached_buffers[1] | ... | cached_buffers[cached_buffer_num - 1]``.
+
+ The data is first stored in the cached buffers. When an episode terminates,
+ the episode's data is moved into the main buffer and the corresponding
+ cached buffer is reset.
+
+ :param ReplayBuffer main_buffer: the main buffer whose ``.update()``
+ function behaves normally.
+ :param int cached_buffer_num: the number of cached buffers (each one is a
+ ReplayBuffer) to create.
+ :param int max_episode_length: the maximum length of one episode, used as
+ each cached buffer's maxsize.
+
+ .. seealso::
+
+ Please refer to :class:`~tianshou.data.ReplayBuffer` or
+ :class:`~tianshou.data.ReplayBufferManager` for more detailed
+ explanation.
+ """
+
+ def __init__(
+ self,
+ main_buffer: ReplayBuffer,
+ cached_buffer_num: int,
+ max_episode_length: int,
+ ) -> None:
+ assert cached_buffer_num > 0 and max_episode_length > 0
+ self._is_prioritized = isinstance(main_buffer, PrioritizedReplayBuffer)
+ kwargs = main_buffer.options
+ buffers = [main_buffer] + [ReplayBuffer(max_episode_length, **kwargs)
+ for _ in range(cached_buffer_num)]
+ super().__init__(buffer_list=buffers, **kwargs)
+ self.main_buffer = self.buffers[0]
+ self.cached_buffers = self.buffers[1:]
+ self.cached_buffer_num = cached_buffer_num
+
+ def add( # type: ignore
+ self,
+ obs: Any,
+ act: Any,
+ rew: np.ndarray,
+ done: np.ndarray,
+ obs_next: Any = Batch(),
+ info: Optional[Batch] = Batch(),
+ policy: Optional[Batch] = Batch(),
+ cached_buffer_ids: Optional[Union[np.ndarray, List[int]]] = None,
+ **kwargs: Any,
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """Add a batch of data into CachedReplayBuffer.
+
+ The length (first dimension) of each data field must equal the length of
+ cached_buffer_ids. By default, cached_buffer_ids is [0, 1, ...,
+ cached_buffer_num - 1].
+
+ Return the arrays of episode_length and episode_reward with shape
+ (len(cached_buffer_ids), ...), where (episode_length[i],
+ episode_reward[i]) is the episode result of the cached buffer with id
+ cached_buffer_ids[i].
+ """
+ if cached_buffer_ids is None:
+ cached_buffer_ids = np.arange(self.cached_buffer_num)
+ else: # make sure it is np.ndarray
+ cached_buffer_ids = np.asarray(cached_buffer_ids)
+ # in self.buffers, the first buffer is main_buffer
+ buffer_ids = cached_buffer_ids + 1 # type: ignore
+ result = super().add(obs, act, rew, done, obs_next, info,
+ policy, buffer_ids=buffer_ids, **kwargs)
+ # find the terminated episode, move data from cached buf to main buf
+ for buffer_idx in cached_buffer_ids[np.asarray(done, np.bool_)]:
+ self.main_buffer.update(self.cached_buffers[buffer_idx])
+ self.cached_buffers[buffer_idx].reset()
+ return result
+
+ def __getitem__(
+ self, index: Union[slice, int, np.integer, np.ndarray]
+ ) -> Batch:
+ batch = super().__getitem__(index)
+ if self._is_prioritized:
+ indice = self._indices[index]
+ mask = indice < self.main_buffer.maxsize
+ batch.weight = np.ones(len(indice))
+ batch.weight[mask] = self.main_buffer.get_weight(indice[mask])
+ return batch
+
+ def update_weight(
+ self,
+ index: np.ndarray,
+ new_weight: Union[np.ndarray, torch.Tensor],
+ ) -> None:
+ """Update priority weight by index in main buffer.
+
+ :param np.ndarray index: index you want to update weight.
+ :param np.ndarray new_weight: new priority weight you want to update.
+ """
+ if self._is_prioritized:
+ mask = index < self.main_buffer.maxsize
+ self.main_buffer.update_weight(index[mask], new_weight[mask])