diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index 8ea2d272e..a314cdedb 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -53,11 +53,163 @@ In short, you can define a :class:`~tianshou.data.Batch` with any key-value pair Buffer ------ -.. automodule:: tianshou.data.ReplayBuffer - :members: - :noindex: +:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. ReplayBuffer can be considered as a specialized form (or management) of Batch. It stores all the data in a batch with circular-queue style. -Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out :class:`~tianshou.data.ReplayBuffer` for more detail. +The current implementation of Tianshou typically use 7 reserved keys in +:class:`~tianshou.data.Batch`: + +* ``obs`` the observation of step :math:`t` ; +* ``act`` the action of step :math:`t` ; +* ``rew`` the reward of step :math:`t` ; +* ``done`` the done flag of step :math:`t` ; +* ``obs_next`` the observation of step :math:`t+1` ; +* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` function returns 4 arguments, and the last one is ``info``); +* ``policy`` the data computed by policy in step :math:`t`; + +The following code snippet illustrates its usage, including: + +- the basic data storage: ``add()``; +- get attribute, get slicing data, ...; +- sample from buffer: ``sample_index(batch_size)`` and ``sample(batch_size)``; +- get previous/next transition index within episodes: ``prev(index)`` and ``next(index)``; +- save/load data from buffer: pickle and HDF5; + +:: + + >>> import pickle, numpy as np + >>> from tianshou.data import ReplayBuffer + >>> buf = ReplayBuffer(size=20) + >>> for i in range(3): + ... buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}) + + >>> buf.obs + # since we set size = 20, len(buf.obs) == 20. + array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + >>> # but there are only three valid items, so len(buf) == 3. + >>> len(buf) + 3 + >>> # save to file "buf.pkl" + >>> pickle.dump(buf, open('buf.pkl', 'wb')) + >>> # save to HDF5 file + >>> buf.save_hdf5('buf.hdf5') + + >>> buf2 = ReplayBuffer(size=10) + >>> for i in range(15): + ... done = i % 4 == 0 + ... buf2.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={}) + >>> len(buf2) + 10 + >>> buf2.obs + # since its size = 10, it only stores the last 10 steps' result. + array([10, 11, 12, 13, 14, 5, 6, 7, 8, 9]) + + >>> # move buf2's result into buf (meanwhile keep it chronologically) + >>> buf.update(buf2) + >>> buf.obs + array([ 0, 1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, 0, 0, + 0, 0, 0, 0]) + + >>> # get all available index by using batch_size = 0 + >>> indice = buf.sample_index(0) + >>> indice + array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) + >>> # get one step previous/next transition + >>> buf.prev(indice) + array([ 0, 0, 1, 2, 3, 4, 5, 7, 7, 8, 9, 11, 11]) + >>> buf.next(indice) + array([ 1, 2, 3, 4, 5, 6, 6, 8, 9, 10, 10, 12, 12]) + + >>> # get a random sample from buffer + >>> # the batch_data is equal to buf[indice]. 
+ >>> batch_data, indice = buf.sample(batch_size=4) + >>> batch_data.obs == buf[indice].obs + array([ True, True, True, True]) + >>> len(buf) + 13 + + >>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl" + >>> len(buf) + 3 + >>> # load complete buffer from HDF5 file + >>> buf = ReplayBuffer.load_hdf5('buf.hdf5') + >>> len(buf) + 3 + +:class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling (typically for RNN usage, see issue#19), ignoring storing the next observation (save memory in Atari tasks), and multi-modal observation (see issue#38): + +.. raw:: html + +
+ Advance usage of ReplayBuffer + +.. code-block:: python + + >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True) + >>> for i in range(16): + ... done = i % 5 == 0 + ... ep_len, ep_rew = buf.add(obs={'id': i}, act=i, rew=i, + ... done=done, obs_next={'id': i + 1}) + ... print(i, ep_len, ep_rew) + 0 1 0.0 + 1 0 0.0 + 2 0 0.0 + 3 0 0.0 + 4 0 0.0 + 5 5 15.0 + 6 0 0.0 + 7 0 0.0 + 8 0 0.0 + 9 0 0.0 + 10 5 40.0 + 11 0 0.0 + 12 0 0.0 + 13 0 0.0 + 14 0 0.0 + 15 5 65.0 + >>> print(buf) # you can see obs_next is not saved in buf + ReplayBuffer( + obs: Batch( + id: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8]), + ), + act: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8]), + rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]), + done: array([False, True, False, False, False, False, True, False, + False]), + info: Batch(), + policy: Batch(), + ) + >>> index = np.arange(len(buf)) + >>> print(buf.get(index, 'obs').id) + [[ 7 7 8 9] + [ 7 8 9 10] + [11 11 11 11] + [11 11 11 12] + [11 11 12 13] + [11 12 13 14] + [12 13 14 15] + [ 7 7 7 7] + [ 7 7 7 8]] + >>> # here is another way to get the stacked data + >>> # (stack only for obs and obs_next) + >>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum() + 0 + >>> # we can get obs_next through __getitem__, even if it doesn't exist + >>> print(buf[:].obs_next.id) + [[ 7 8 9 10] + [ 7 8 9 10] + [11 11 11 12] + [11 11 12 13] + [11 12 13 14] + [12 13 14 15] + [12 13 14 15] + [ 7 7 7 8] + [ 7 7 8 9]] + +.. raw:: html + +

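+
+The ``stack_num`` used for frame-stack sampling can also be overridden for a single call via ``buf.get(index, key, stack_num=...)``. A small sketch that continues the buffer built in the collapsed example above (the expected output simply keeps the last two columns of the 4-frame stack shown there, so treat it as illustrative):
+
+.. code-block:: python
+
+    >>> print(buf.get(index, 'obs', stack_num=2).id)
+    [[ 8  9]
+     [ 9 10]
+     [11 11]
+     [11 12]
+     [12 13]
+     [13 14]
+     [14 15]
+     [ 7  7]
+     [ 7  8]]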
+ +Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``), :class:`~tianshou.data.CachedReplayBuffer` (add different episodes' data but without losing chronological order). Check out :class:`~tianshou.data.ReplayBuffer` for more detail. Policy diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 1695195b4..04b1928e2 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -7,9 +7,10 @@ import numpy as np from timeit import timeit -from tianshou.data import Batch, SegmentTree, \ - ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer from tianshou.data.utils.converter import to_hdf5 +from tianshou.data import Batch, SegmentTree, ReplayBuffer +from tianshou.data import ListReplayBuffer, PrioritizedReplayBuffer +from tianshou.data import ReplayBufferManager, CachedReplayBuffer if __name__ == '__main__': from env import MyTestEnv @@ -38,11 +39,14 @@ def test_replaybuffer(size=10, bufsize=20): assert (data.obs < size).all() assert (0 <= data.done).all() and (data.done <= 1).all() b = ReplayBuffer(size=10) - b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}}) + # neg bsz should return empty index + assert b.sample_index(-1).tolist() == [] + b.add(1, 1, 1, 1, 'str', {'a': 3, 'b': {'c': 5.0}}) assert b.obs[0] == 1 - assert b.done[0] == 'str' + assert b.done[0] + assert b.obs_next[0] == 'str' assert np.all(b.obs[1:] == 0) - assert np.all(b.done[1:] == np.array(None)) + assert np.all(b.obs_next[1:] == np.array(None)) assert b.info.a[0] == 3 and b.info.a.dtype == np.integer assert np.all(b.info.a[1:] == 0) assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact @@ -91,7 +95,7 @@ def test_ignore_obs_next(size=10): assert data.obs_next -def test_stack(size=5, bufsize=9, stack_num=4): +def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): env = MyTestEnv(size) buf = ReplayBuffer(bufsize, stack_num=stack_num) buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True) @@ -115,7 +119,9 @@ def test_stack(size=5, bufsize=9, stack_num=4): _, indice = buf2.sample(0) assert indice.tolist() == [2, 6] _, indice = buf2.sample(1) - assert indice in [2, 6] + assert indice[0] in [2, 6] + batch, indice = buf2.sample(-1) # neg bsz -> no data + assert indice.tolist() == [] and len(batch) == 0 with pytest.raises(IndexError): buf[bufsize * 2] @@ -152,6 +158,12 @@ def test_update(): assert len(buf1) == len(buf2) assert (buf2[0].obs == buf1[1].obs).all() assert (buf2[-1].obs == buf1[0].obs).all() + b = ListReplayBuffer() + with pytest.raises(NotImplementedError): + b.update(b) + b = CachedReplayBuffer(ReplayBuffer(10), 4, 5) + with pytest.raises(NotImplementedError): + b.update(b) def test_segtree(): @@ -260,8 +272,7 @@ def test_pickle(): vbuf = ReplayBuffer(size, stack_num=2) lbuf = ListReplayBuffer() pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4) - device = 'cuda' if torch.cuda.is_available() else 'cpu' - rew = torch.tensor([1.]).to(device) + rew = np.array([1, 1]) for i in range(4): vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0) for i in range(3): @@ -287,18 +298,18 @@ def test_hdf5(): buffers = { "array": ReplayBuffer(size, stack_num=2), "list": ListReplayBuffer(), - "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4) + "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4), } buffer_types = {k: b.__class__ for k, b in buffers.items()} device = 'cuda' if torch.cuda.is_available() else 'cpu' - 
rew = torch.tensor([1.]).to(device) + info_t = torch.tensor([1.]).to(device) for i in range(4): kwargs = { 'obs': Batch(index=np.array([i])), 'act': i, - 'rew': rew, - 'done': 0, - 'info': {"number": {"n": i}, 'extra': None}, + 'rew': np.array([1, 2]), + 'done': i % 3 == 2, + 'info': {"number": {"n": i, "t": info_t}, 'extra': None}, } buffers["array"].add(**kwargs) buffers["list"].add(**kwargs) @@ -320,10 +331,10 @@ def test_hdf5(): assert len(_buffers[k]) == len(buffers[k]) assert np.allclose(_buffers[k].act, buffers[k].act) assert _buffers[k].stack_num == buffers[k].stack_num - assert _buffers[k]._maxsize == buffers[k]._maxsize - assert _buffers[k]._index == buffers[k]._index + assert _buffers[k].maxsize == buffers[k].maxsize assert np.all(_buffers[k]._indices == buffers[k]._indices) for k in ["array", "prioritized"]: + assert _buffers[k]._index == buffers[k]._index assert isinstance(buffers[k].get(0, "info"), Batch) assert isinstance(_buffers[k].get(0, "info"), Batch) for k in ["array"]: @@ -332,28 +343,350 @@ def test_hdf5(): assert np.all( buffers[k][:].info.extra == _buffers[k][:].info.extra) - for path in paths.values(): - os.remove(path) - # raise exception when value cannot be pickled - data = {"not_supported": lambda x: x*x} + data = {"not_supported": lambda x: x * x} grp = h5py.Group with pytest.raises(NotImplementedError): to_hdf5(data, grp) # ndarray with data type not supported by HDF5 that cannot be pickled - data = {"not_supported": np.array(lambda x: x*x)} + data = {"not_supported": np.array(lambda x: x * x)} grp = h5py.Group with pytest.raises(RuntimeError): to_hdf5(data, grp) +def test_replaybuffermanager(): + buf = ReplayBufferManager([ReplayBuffer(size=5) for i in range(4)]) + ep_len, ep_rew = buf.add(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3], + done=[0, 0, 1], buffer_ids=[0, 1, 2]) + assert np.allclose(ep_len, [0, 0, 1]) and np.allclose(ep_rew, [0, 0, 3]) + with pytest.raises(NotImplementedError): + # ReplayBufferManager cannot be updated + buf.update(buf) + # sample index / prev / next / unfinished_index + indice = buf.sample_index(11000) + assert np.bincount(indice)[[0, 5, 10]].min() >= 3000 # uniform sample + batch, indice = buf.sample(0) + assert np.allclose(indice, [0, 5, 10]) + indice_prev = buf.prev(indice) + assert np.allclose(indice_prev, indice), indice_prev + indice_next = buf.next(indice) + assert np.allclose(indice_next, indice), indice_next + assert np.allclose(buf.unfinished_index(), [0, 5]) + buf.add(obs=[4], act=[4], rew=[4], done=[1], buffer_ids=[3]) + assert np.allclose(buf.unfinished_index(), [0, 5]) + batch, indice = buf.sample(10) + batch, indice = buf.sample(0) + assert np.allclose(indice, [0, 5, 10, 15]) + indice_prev = buf.prev(indice) + assert np.allclose(indice_prev, indice), indice_prev + indice_next = buf.next(indice) + assert np.allclose(indice_next, indice), indice_next + data = np.array([0, 0, 0, 0]) + buf.add(obs=data, act=data, rew=data, done=data, buffer_ids=[0, 1, 2, 3]) + buf.add(obs=data, act=data, rew=data, done=1 - data, + buffer_ids=[0, 1, 2, 3]) + assert len(buf) == 12 + buf.add(obs=data, act=data, rew=data, done=data, buffer_ids=[0, 1, 2, 3]) + buf.add(obs=data, act=data, rew=data, done=[0, 1, 0, 1], + buffer_ids=[0, 1, 2, 3]) + assert len(buf) == 20 + indice = buf.sample_index(120000) + assert np.bincount(indice).min() >= 5000 + batch, indice = buf.sample(10) + indice = buf.sample_index(0) + assert np.allclose(indice, np.arange(len(buf))) + # check the actual data stored in buf._meta + assert np.allclose(buf.done, [ + 0, 
0, 1, 0, 0, + 0, 0, 1, 0, 1, + 1, 0, 1, 0, 0, + 1, 0, 1, 0, 1, + ]) + assert np.allclose(buf.prev(indice), [ + 0, 0, 1, 3, 3, + 5, 5, 6, 8, 8, + 10, 11, 11, 13, 13, + 15, 16, 16, 18, 18, + ]) + assert np.allclose(buf.next(indice), [ + 1, 2, 2, 4, 4, + 6, 7, 7, 9, 9, + 10, 12, 12, 14, 14, + 15, 17, 17, 19, 19, + ]) + assert np.allclose(buf.unfinished_index(), [4, 14]) + ep_len, ep_rew = buf.add(obs=[1], act=[1], rew=[1], done=[1], + buffer_ids=[2]) + assert np.allclose(ep_len, [3]) and np.allclose(ep_rew, [1]) + assert np.allclose(buf.unfinished_index(), [4]) + indice = list(sorted(buf.sample_index(0))) + assert np.allclose(indice, np.arange(len(buf))) + assert np.allclose(buf.prev(indice), [ + 0, 0, 1, 3, 3, + 5, 5, 6, 8, 8, + 14, 11, 11, 13, 13, + 15, 16, 16, 18, 18, + ]) + assert np.allclose(buf.next(indice), [ + 1, 2, 2, 4, 4, + 6, 7, 7, 9, 9, + 10, 12, 12, 14, 10, + 15, 17, 17, 19, 19, + ]) + # corner case: list, int and -1 + assert buf.prev(-1) == buf.prev([buf.maxsize - 1])[0] + assert buf.next(-1) == buf.next([buf.maxsize - 1])[0] + batch = buf._meta + batch.info.n = np.ones(buf.maxsize) + buf.set_batch(batch) + assert np.allclose(buf.buffers[-1].info.n, [1] * 5) + assert buf.sample_index(-1).tolist() == [] + assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == np.object + + +def test_cachedbuffer(): + buf = CachedReplayBuffer(ReplayBuffer(10), 4, 5) + assert buf.sample_index(0).tolist() == [] + # check the normal function/usage/storage in CachedReplayBuffer + ep_len, ep_rew = buf.add(obs=[1], act=[1], rew=[1], done=[0], + cached_buffer_ids=[1]) + obs = np.zeros(buf.maxsize) + obs[15] = 1 + indice = buf.sample_index(0) + assert np.allclose(indice, [15]) + assert np.allclose(buf.prev(indice), [15]) + assert np.allclose(buf.next(indice), [15]) + assert np.allclose(buf.obs, obs) + assert np.allclose(ep_len, [0]) and np.allclose(ep_rew, [0.0]) + ep_len, ep_rew = buf.add(obs=[2], act=[2], rew=[2], done=[1], + cached_buffer_ids=[3]) + obs[[0, 25]] = 2 + indice = buf.sample_index(0) + assert np.allclose(indice, [0, 15]) + assert np.allclose(buf.prev(indice), [0, 15]) + assert np.allclose(buf.next(indice), [0, 15]) + assert np.allclose(buf.obs, obs) + assert np.allclose(ep_len, [1]) and np.allclose(ep_rew, [2.0]) + assert np.allclose(buf.unfinished_index(), [15]) + assert np.allclose(buf.sample_index(0), [0, 15]) + ep_len, ep_rew = buf.add(obs=[3, 4], act=[3, 4], rew=[3, 4], + done=[0, 1], cached_buffer_ids=[3, 1]) + assert np.allclose(ep_len, [0, 2]) and np.allclose(ep_rew, [0, 5.0]) + obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3] + assert np.allclose(buf.obs, obs) + assert np.allclose(buf.unfinished_index(), [25]) + indice = buf.sample_index(0) + assert np.allclose(indice, [0, 1, 2, 25]) + assert np.allclose(buf.done[indice], [1, 0, 1, 0]) + assert np.allclose(buf.prev(indice), [0, 1, 1, 25]) + assert np.allclose(buf.next(indice), [0, 2, 2, 25]) + indice = buf.sample_index(10000) + assert np.bincount(indice)[[0, 1, 2, 25]].min() > 2000 # uniform sample + # cached buffer with main_buffer size == 0 (no update) + # used in test_collector + buf = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5) + data = np.zeros(4) + rew = np.ones([4, 4]) + buf.add(obs=data, act=data, rew=rew, done=[0, 0, 1, 1], obs_next=data) + buf.add(obs=data, act=data, rew=rew, done=[0, 0, 0, 0], obs_next=data) + buf.add(obs=data, act=data, rew=rew, done=[1, 1, 1, 1], obs_next=data) + buf.add(obs=data, act=data, rew=rew, done=[0, 0, 0, 0], obs_next=data) + buf.add(obs=data, act=data, rew=rew, 
done=[0, 1, 0, 1], obs_next=data) + assert np.allclose(buf.done, [ + 0, 0, 1, 0, 0, + 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, + ]) + indice = buf.sample_index(0) + assert np.allclose(indice, [0, 1, 10, 11]) + assert np.allclose(buf.prev(indice), [0, 0, 10, 10]) + assert np.allclose(buf.next(indice), [1, 1, 11, 11]) + + +def test_multibuf_stack(): + size = 5 + bufsize = 9 + stack_num = 4 + cached_num = 3 + env = MyTestEnv(size) + # test if CachedReplayBuffer can handle stack_num + ignore_obs_next + buf4 = CachedReplayBuffer( + ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True), + cached_num, size) + # test if CachedReplayBuffer can handle super corner case: + # prio-buffer + stack_num + ignore_obs_next + sample_avail + buf5 = CachedReplayBuffer( + PrioritizedReplayBuffer(bufsize, 0.6, 0.4, stack_num=stack_num, + ignore_obs_next=True, sample_avail=True), + cached_num, size) + obs = env.reset(1) + for i in range(18): + obs_next, rew, done, info = env.step(1) + obs_list = np.array([obs + size * i for i in range(cached_num)]) + act_list = [1] * cached_num + rew_list = [rew] * cached_num + done_list = [done] * cached_num + obs_next_list = -obs_list + info_list = [info] * cached_num + buf4.add(obs_list, act_list, rew_list, done_list, + obs_next_list, info_list) + buf5.add(obs_list, act_list, rew_list, done_list, + obs_next_list, info_list) + obs = obs_next + if done: + obs = env.reset(1) + # check the `add` order is correct + assert np.allclose(buf4.obs.reshape(-1), [ + 12, 13, 14, 4, 6, 7, 8, 9, 11, # main_buffer + 1, 2, 3, 4, 0, # cached_buffer[0] + 6, 7, 8, 9, 0, # cached_buffer[1] + 11, 12, 13, 14, 0, # cached_buffer[2] + ]), buf4.obs + assert np.allclose(buf4.done, [ + 0, 0, 1, 1, 0, 0, 0, 1, 0, # main_buffer + 0, 0, 0, 1, 0, # cached_buffer[0] + 0, 0, 0, 1, 0, # cached_buffer[1] + 0, 0, 0, 1, 0, # cached_buffer[2] + ]), buf4.done + assert np.allclose(buf4.unfinished_index(), [10, 15, 20]) + indice = sorted(buf4.sample_index(0)) + assert np.allclose(indice, list(range(bufsize)) + [9, 10, 14, 15, 19, 20]) + assert np.allclose(buf4[indice].obs[..., 0], [ + [11, 11, 11, 12], [11, 11, 12, 13], [11, 12, 13, 14], + [4, 4, 4, 4], [6, 6, 6, 6], [6, 6, 6, 7], + [6, 6, 7, 8], [6, 7, 8, 9], [11, 11, 11, 11], + [1, 1, 1, 1], [1, 1, 1, 2], [6, 6, 6, 6], + [6, 6, 6, 7], [11, 11, 11, 11], [11, 11, 11, 12], + ]) + assert np.allclose(buf4[indice].obs_next[..., 0], [ + [11, 11, 12, 13], [11, 12, 13, 14], [11, 12, 13, 14], + [4, 4, 4, 4], [6, 6, 6, 7], [6, 6, 7, 8], + [6, 7, 8, 9], [6, 7, 8, 9], [11, 11, 11, 12], + [1, 1, 1, 2], [1, 1, 1, 2], [6, 6, 6, 7], + [6, 6, 6, 7], [11, 11, 11, 12], [11, 11, 11, 12], + ]) + assert np.all(buf4.done == buf5.done) + indice = buf5.sample_index(0) + assert np.allclose(sorted(indice), [2, 7]) + assert np.all(np.isin(buf5.sample_index(100), indice)) + # manually change the stack num + buf5.stack_num = 2 + for buf in buf5.buffers: + buf.stack_num = 2 + indice = buf5.sample_index(0) + assert np.allclose(sorted(indice), [0, 1, 2, 5, 6, 7, 10, 15, 20]) + batch, _ = buf5.sample(0) + assert np.allclose(buf5[np.arange(buf5.maxsize)].weight, 1) + buf5.update_weight(indice, batch.weight * 0) + weight = buf5[np.arange(buf5.maxsize)].weight + modified_weight = weight[[0, 1, 2, 5, 6, 7]] + assert modified_weight.min() == modified_weight.max() + assert modified_weight.max() < 1 + unmodified_weight = weight[[3, 4, 8]] + assert unmodified_weight.min() == unmodified_weight.max() + assert unmodified_weight.max() < 1 + cached_weight = weight[9:] + assert 
cached_weight.min() == cached_weight.max() == 1 + # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next + buf6 = CachedReplayBuffer( + ReplayBuffer(bufsize, stack_num=stack_num, + save_only_last_obs=True, ignore_obs_next=True), + cached_num, size) + obs = np.random.rand(size, 4, 84, 84) + buf6.add(obs=[obs[2], obs[0]], act=[1, 1], rew=[0, 0], done=[0, 1], + obs_next=[obs[3], obs[1]], cached_buffer_ids=[1, 2]) + assert buf6.obs.shape == (buf6.maxsize, 84, 84) + assert np.allclose(buf6.obs[0], obs[0, -1]) + assert np.allclose(buf6.obs[14], obs[2, -1]) + assert np.allclose(buf6.obs[19], obs[0, -1]) + assert buf6[0].obs.shape == (4, 84, 84) + + +def test_multibuf_hdf5(): + size = 100 + buffers = { + "vector": ReplayBufferManager([ReplayBuffer(size) for i in range(4)]), + "cached": CachedReplayBuffer(ReplayBuffer(size), 4, size) + } + buffer_types = {k: b.__class__ for k, b in buffers.items()} + device = 'cuda' if torch.cuda.is_available() else 'cpu' + info_t = torch.tensor([1.]).to(device) + for i in range(4): + kwargs = { + 'obs': Batch(index=np.array([i])), + 'act': i, + 'rew': np.array([1, 2]), + 'done': i % 3 == 2, + 'info': {"number": {"n": i, "t": info_t}, 'extra': None}, + } + buffers["vector"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]), + buffer_ids=[0, 1, 2]) + buffers["cached"].add(**Batch.cat([[kwargs], [kwargs], [kwargs]]), + cached_buffer_ids=[0, 1, 2]) + + # save + paths = {} + for k, buf in buffers.items(): + f, path = tempfile.mkstemp(suffix='.hdf5') + os.close(f) + buf.save_hdf5(path) + paths[k] = path + + # load replay buffer + _buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths.keys()} + + # compare + for k in buffers.keys(): + assert len(_buffers[k]) == len(buffers[k]) + assert np.allclose(_buffers[k].act, buffers[k].act) + assert _buffers[k].stack_num == buffers[k].stack_num + assert _buffers[k].maxsize == buffers[k].maxsize + assert np.all(_buffers[k]._indices == buffers[k]._indices) + # check shallow copy in ReplayBufferManager + for k in ["vector", "cached"]: + buffers[k].info.number.n[0] = -100 + assert buffers[k].buffers[0].info.number.n[0] == -100 + # check if still behave normally + for k in ["vector", "cached"]: + kwargs = { + 'obs': Batch(index=np.array([5])), + 'act': 5, + 'rew': np.array([2, 1]), + 'done': False, + 'info': {"number": {"n": i}, 'Timelimit.truncate': True}, + } + buffers[k].add(**Batch.cat([[kwargs], [kwargs], [kwargs], [kwargs]])) + act = np.zeros(buffers[k].maxsize) + if k == "vector": + act[np.arange(5)] = np.array([0, 1, 2, 3, 5]) + act[np.arange(5) + size] = np.array([0, 1, 2, 3, 5]) + act[np.arange(5) + size * 2] = np.array([0, 1, 2, 3, 5]) + act[size * 3] = 5 + elif k == "cached": + act[np.arange(9)] = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2]) + act[np.arange(3) + size] = np.array([3, 5, 2]) + act[np.arange(3) + size * 2] = np.array([3, 5, 2]) + act[np.arange(3) + size * 3] = np.array([3, 5, 2]) + act[size * 4] = 5 + assert np.allclose(buffers[k].act, act) + + for path in paths.values(): + os.remove(path) + + if __name__ == '__main__': - test_hdf5() test_replaybuffer() test_ignore_obs_next() test_stack() - test_pickle() test_segtree() test_priortized_replaybuffer() test_priortized_replaybuffer(233333, 200000) test_update() + test_pickle() + test_hdf5() + test_replaybuffermanager() + test_cachedbuffer() + test_multibuf_stack() + test_multibuf_hdf5() diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index d0987cdbf..1a48aaee9 100644 --- a/test/continuous/test_ppo.py +++ 
b/test/continuous/test_ppo.py @@ -18,7 +18,7 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='Pendulum-v0') - parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--seed', type=int, default=1) parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.99) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 7020df275..0d03fce97 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -16,7 +16,7 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='CartPole-v0') - parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--seed', type=int, default=1) parser.add_argument('--eps-test', type=float, default=0.05) parser.add_argument('--eps-train', type=float, default=0.1) parser.add_argument('--buffer-size', type=int, default=20000) diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index e51b8d161..368427a19 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -1,8 +1,8 @@ from tianshou.data.batch import Batch from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as from tianshou.data.utils.segtree import SegmentTree -from tianshou.data.buffer import ReplayBuffer, \ - ListReplayBuffer, PrioritizedReplayBuffer +from tianshou.data.buffer import ReplayBuffer, ListReplayBuffer, \ + PrioritizedReplayBuffer, ReplayBufferManager, CachedReplayBuffer from tianshou.data.collector import Collector __all__ = [ @@ -14,5 +14,7 @@ "ReplayBuffer", "ListReplayBuffer", "PrioritizedReplayBuffer", + "ReplayBufferManager", + "CachedReplayBuffer", "Collector", ] diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 74299df9d..4c1ea14a6 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -1,5 +1,6 @@ import h5py import torch +import warnings import numpy as np from numbers import Number from typing import Any, Dict, List, Tuple, Union, Optional @@ -13,121 +14,13 @@ class ReplayBuffer: """:class:`~tianshou.data.ReplayBuffer` stores data generated from \ interaction between the policy and environment. - The current implementation of Tianshou typically use 7 reserved keys in - :class:`~tianshou.data.Batch`: - - * ``obs`` the observation of step :math:`t` ; - * ``act`` the action of step :math:`t` ; - * ``rew`` the reward of step :math:`t` ; - * ``done`` the done flag of step :math:`t` ; - * ``obs_next`` the observation of step :math:`t+1` ; - * ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` \ - function returns 4 arguments, and the last one is ``info``); - * ``policy`` the data computed by policy in step :math:`t`; - - The following code snippet illustrates its usage: - :: - - >>> import pickle, numpy as np - >>> from tianshou.data import ReplayBuffer - >>> buf = ReplayBuffer(size=20) - >>> for i in range(3): - ... buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={}) - >>> buf.obs - # since we set size = 20, len(buf.obs) == 20. - array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0.]) - >>> # but there are only three valid items, so len(buf) == 3. - >>> len(buf) - 3 - >>> # save to file "buf.pkl" - >>> pickle.dump(buf, open('buf.pkl', 'wb')) - >>> # save to HDF5 file - >>> buf.save_hdf5('buf.hdf5') - >>> buf2 = ReplayBuffer(size=10) - >>> for i in range(15): - ... 
buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={}) - >>> len(buf2) - 10 - >>> buf2.obs - # since its size = 10, it only stores the last 10 steps' result. - array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.]) - - >>> # move buf2's result into buf (meanwhile keep it chronologically) - >>> buf.update(buf2) - array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., - 0., 0., 0., 0., 0., 0., 0.]) - - >>> # get a random sample from buffer - >>> # the batch_data is equal to buf[indice]. - >>> batch_data, indice = buf.sample(batch_size=4) - >>> batch_data.obs == buf[indice].obs - array([ True, True, True, True]) - >>> len(buf) - 13 - >>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl" - >>> len(buf) - 3 - >>> # load complete buffer from HDF5 file - >>> buf = ReplayBuffer.load_hdf5('buf.hdf5') - >>> len(buf) - 3 - >>> # load contents of HDF5 file into existing buffer - >>> # (only possible if size of buffer and data in file match) - >>> buf.load_contents_hdf5('buf.hdf5') - >>> len(buf) - 3 - - :class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling - (typically for RNN usage, see issue#19), ignoring storing the next - observation (save memory in atari tasks), and multi-modal observation (see - issue#38): - :: - - >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True) - >>> for i in range(16): - ... done = i % 5 == 0 - ... buf.add(obs={'id': i}, act=i, rew=i, done=done, - ... obs_next={'id': i + 1}) - >>> print(buf) # you can see obs_next is not saved in buf - ReplayBuffer( - act: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]), - done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]), - info: Batch(), - obs: Batch( - id: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]), - ), - policy: Batch(), - rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]), - ) - >>> index = np.arange(len(buf)) - >>> print(buf.get(index, 'obs').id) - [[ 7. 7. 8. 9.] - [ 7. 8. 9. 10.] - [11. 11. 11. 11.] - [11. 11. 11. 12.] - [11. 11. 12. 13.] - [11. 12. 13. 14.] - [12. 13. 14. 15.] - [ 7. 7. 7. 7.] - [ 7. 7. 7. 8.]] - >>> # here is another way to get the stacked data - >>> # (stack only for obs and obs_next) - >>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum() - 0.0 - >>> # we can get obs_next through __getitem__, even if it doesn't exist - >>> print(buf[:].obs_next.id) - [[ 7. 8. 9. 10.] - [ 7. 8. 9. 10.] - [11. 11. 11. 12.] - [11. 11. 12. 13.] - [11. 12. 13. 14.] - [12. 13. 14. 15.] - [12. 13. 14. 15.] - [ 7. 7. 7. 8.] - [ 7. 7. 8. 9.]] - - :param int size: the size of replay buffer. + ReplayBuffer can be considered as a specialized form (or management) of + Batch. It stores all the data in a batch with circular-queue style. + + For the example usage of ReplayBuffer, please check out Section Buffer in + :doc:`/tutorials/concepts`. + + :param int size: the maximum size of replay buffer. :param int stack_num: the frame-stack sampling argument, should be greater than or equal to 1, defaults to 1 (no stacking). :param bool ignore_obs_next: whether to store obs_next, defaults to False. @@ -136,9 +29,11 @@ class ReplayBuffer: False. :param bool sample_avail: the parameter indicating sampling only available index when using frame-stack sampling method, defaults to False. - This feature is not supported in Prioritized Replay Buffer currently. 
""" + _reserved_keys = ("obs", "act", "rew", "done", + "obs_next", "info", "policy") + def __init__( self, size: int, @@ -147,16 +42,20 @@ def __init__( save_only_last_obs: bool = False, sample_avail: bool = False, ) -> None: + self.options: Dict[str, Any] = { + "stack_num": stack_num, + "ignore_obs_next": ignore_obs_next, + "save_only_last_obs": save_only_last_obs, + "sample_avail": sample_avail, + } super().__init__() - self._maxsize = size - self._indices = np.arange(size) + self.maxsize = size + assert stack_num > 0, "stack_num should greater than 0" self.stack_num = stack_num - self._avail = sample_avail and stack_num > 1 - self._avail_index: List[int] = [] - self._save_s_ = not ignore_obs_next - self._last_obs = save_only_last_obs - self._index = 0 - self._size = 0 + self._indices = np.arange(size) + self._save_obs_next = not ignore_obs_next + self._save_only_last_obs = save_only_last_obs + self._sample_avail = sample_avail self._meta: Batch = Batch() self.reset() @@ -181,20 +80,92 @@ def __setstate__(self, state: Dict[str, Any]) -> None: We need it because pickling buffer does not work out-of-the-box ("buffer.__getattr__" is customized). """ - self._indices = np.arange(state["_maxsize"]) self.__dict__.update(state) + # compatible with version == 0.3.1's HDF5 data format + self._indices = np.arange(self.maxsize) + + def __setattr__(self, key: str, value: Any) -> None: + """Set self.key = value.""" + assert key not in self._reserved_keys, ( + "key '{}' is reserved and cannot be assigned".format(key)) + super().__setattr__(key, value) + + def save_hdf5(self, path: str) -> None: + """Save replay buffer to HDF5 file.""" + with h5py.File(path, "w") as f: + to_hdf5(self.__dict__, f) + + @classmethod + def load_hdf5( + cls, path: str, device: Optional[str] = None + ) -> "ReplayBuffer": + """Load replay buffer from HDF5 file.""" + with h5py.File(path, "r") as f: + buf = cls.__new__(cls) + buf.__setstate__(from_hdf5(f, device=device)) + return buf + + def reset(self) -> None: + """Clear all the data in replay buffer and episode statistics.""" + self._index = self._size = 0 + self._episode_length, self._episode_reward = 0, 0.0 + + def set_batch(self, batch: Batch) -> None: + """Manually choose the batch you want the ReplayBuffer to manage.""" + assert len(batch) == self.maxsize and \ + set(batch.keys()).issubset(self._reserved_keys), \ + "Input batch doesn't meet ReplayBuffer's data form requirement." + self._meta = batch + + def unfinished_index(self) -> np.ndarray: + """Return the index of unfinished episode.""" + last = (self._index - 1) % self._size if self._size else 0 + return np.array( + [last] if not self.done[last] and self._size else [], np.int) + + def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: + """Return the index of previous transition. + + The index won't be modified if it is the beginning of an episode. + """ + index = (index - 1) % self._size + end_flag = self.done[index] | np.isin(index, self.unfinished_index()) + return (index + end_flag) % self._size - def __getstate__(self) -> dict: - exclude = {"_indices"} - state = {k: v for k, v in self.__dict__.items() if k not in exclude} - return state + def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: + """Return the index of next transition. + + The index won't be modified if it is the end of an episode. 
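+
+        A small illustrative sketch of :meth:`prev` / :meth:`next` (assumes
+        ``np`` is imported and uses a fresh hypothetical buffer)::
+
+            >>> buf = ReplayBuffer(size=5)
+            >>> for i in range(3):
+            ...     ep_len, ep_rew = buf.add(obs=i, act=i, rew=i, done=i == 2)
+            >>> buf.prev(np.array([0, 1, 2]))  # index 0 starts the episode
+            array([0, 0, 1])
+            >>> buf.next(np.array([0, 1, 2]))  # index 2 ends the episode
+            array([1, 2, 2])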
+ """ + end_flag = self.done[index] | np.isin(index, self.unfinished_index()) + return (index + (1 - end_flag)) % self._size + + def update(self, buffer: "ReplayBuffer") -> None: + """Move the data from the given buffer to current buffer.""" + if len(buffer) == 0 or self.maxsize == 0: + return + stack_num, buffer.stack_num = buffer.stack_num, 1 + save_only_last_obs = self._save_only_last_obs + self._save_only_last_obs = False + indices = buffer.sample_index(0) # get all available indices + for i in indices: + self.add(**buffer[i]) # type: ignore + buffer.stack_num = stack_num + self._save_only_last_obs = save_only_last_obs + + def _buffer_allocator(self, key: List[str], value: Any) -> None: + """Allocate memory on buffer._meta for new (key, value) pair.""" + data = self._meta + for k in key[:-1]: + data = data[k] + data[key[-1]] = _create_value(value, self.maxsize) def _add_to_buffer(self, name: str, inst: Any) -> None: try: value = self._meta.__dict__[name] except KeyError: - self._meta.__dict__[name] = _create_value(inst, self._maxsize) - value = self._meta.__dict__[name] + self._buffer_allocator([name], inst) + value = self._meta[name] if isinstance(inst, (torch.Tensor, np.ndarray)): if inst.shape != value.shape[1:]: raise ValueError( @@ -203,33 +174,10 @@ def _add_to_buffer(self, name: str, inst: Any) -> None: ) try: value[self._index] = inst - except KeyError: - for key in set(inst.keys()).difference(value.__dict__.keys()): - value.__dict__[key] = _create_value(inst[key], self._maxsize) - value[self._index] = inst - - @property - def stack_num(self) -> int: - return self._stack - - @stack_num.setter - def stack_num(self, num: int) -> None: - assert num > 0, "stack_num should greater than 0" - self._stack = num - - def update(self, buffer: "ReplayBuffer") -> None: - """Move the data from the given buffer to self.""" - if len(buffer) == 0: - return - i = begin = buffer._index % len(buffer) - stack_num_orig = buffer.stack_num - buffer.stack_num = 1 - while True: - self.add(**buffer[i]) # type: ignore - i = (i + 1) % len(buffer) - if i == begin: - break - buffer.stack_num = stack_num_orig + except KeyError: # inst is a dict/Batch + for key in set(inst.keys()).difference(value.keys()): + self._buffer_allocator([name, key], inst[key]) + self._meta[name][self._index] = inst def add( self, @@ -241,117 +189,110 @@ def add( info: Optional[Union[dict, Batch]] = {}, policy: Optional[Union[dict, Batch]] = {}, **kwargs: Any, - ) -> None: - """Add a batch of data into replay buffer.""" + ) -> Tuple[int, Union[float, np.ndarray]]: + """Add a batch of data into replay buffer. + + Return (episode_length, episode_reward) if one episode is terminated, + otherwise return (0, 0.0). + """ assert isinstance( info, (dict, Batch) ), "You should return a dict in the last argument of env.step()." 
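+        # Note (summary added for clarity): the rest of add() writes each
+        # reserved key at self._index, advances the write index (circularly
+        # when maxsize > 0), and accumulates the running episode length and
+        # reward, so that add() returns (episode_length, episode_reward)
+        # once `done` is True and (0, zero-reward) otherwise.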
- if self._last_obs: + if self._save_only_last_obs: obs = obs[-1] self._add_to_buffer("obs", obs) self._add_to_buffer("act", act) - # make sure the reward is a float instead of an int - self._add_to_buffer("rew", rew * 1.0) # type: ignore - self._add_to_buffer("done", done) - if self._save_s_: + # make sure the data type of reward is float instead of int + # but rew may be np.ndarray, so that we cannot use float(rew) + rew = rew * 1.0 # type: ignore + self._add_to_buffer("rew", rew) + self._add_to_buffer("done", bool(done)) # done should be a bool scalar + if self._save_obs_next: if obs_next is None: obs_next = Batch() - elif self._last_obs: + elif self._save_only_last_obs: obs_next = obs_next[-1] self._add_to_buffer("obs_next", obs_next) self._add_to_buffer("info", info) self._add_to_buffer("policy", policy) - # maintain available index for frame-stack sampling - if self._avail: - # update current frame - avail = sum(self.done[i] for i in range( - self._index - self.stack_num + 1, self._index)) == 0 - if self._size < self.stack_num - 1: - avail = False - if avail and self._index not in self._avail_index: - self._avail_index.append(self._index) - elif not avail and self._index in self._avail_index: - self._avail_index.remove(self._index) - # remove the later available frame because of broken storage - t = (self._index + self.stack_num - 1) % self._maxsize - if t in self._avail_index: - self._avail_index.remove(t) - - if self._maxsize > 0: - self._size = min(self._size + 1, self._maxsize) - self._index = (self._index + 1) % self._maxsize + if self.maxsize > 0: + self._size = min(self._size + 1, self.maxsize) + self._index = (self._index + 1) % self.maxsize + else: # TODO: remove this after deleting ListReplayBuffer + self._size = self._index = self._size + 1 + + self._episode_reward += rew + self._episode_length += 1 + + if done: + result = self._episode_length, self._episode_reward + self._episode_length, self._episode_reward = 0, 0.0 + return result else: - self._size = self._index = self._index + 1 + return 0, self._episode_reward * 0.0 - def reset(self) -> None: - """Clear all the data in replay buffer.""" - self._index = 0 - self._size = 0 - self._avail_index = [] + def sample_index(self, batch_size: int) -> np.ndarray: + """Get a random sample of index with size = batch_size. + + Return all available indices in the buffer if batch_size is 0; return + an empty numpy array if batch_size < 0 or no available index can be + sampled. + """ + if self.stack_num == 1 or not self._sample_avail: # most often case + if batch_size > 0: + return np.random.choice(self._size, batch_size) + elif batch_size == 0: # construct current available indices + return np.concatenate([ + np.arange(self._index, self._size), + np.arange(self._index)]) + else: + return np.array([], np.int) + else: + if batch_size < 0: + return np.array([], np.int) + all_indices = prev_indices = np.concatenate([ + np.arange(self._index, self._size), np.arange(self._index)]) + for _ in range(self.stack_num - 2): + prev_indices = self.prev(prev_indices) + all_indices = all_indices[prev_indices != self.prev(prev_indices)] + if batch_size > 0: + return np.random.choice(all_indices, batch_size) + else: + return all_indices def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: - """Get a random sample from buffer with size equal to batch_size. + """Get a random sample from buffer with size = batch_size. Return all the data in the buffer if batch_size is 0. :return: Sample data and its corresponding index inside the buffer. 
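+
+        A minimal hypothetical sketch (``buf`` already holds a few
+        transitions)::
+
+            >>> batch, indices = buf.sample(batch_size=4)
+            >>> assert (batch.obs == buf[indices].obs).all()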
""" - if batch_size > 0: - _all = self._avail_index if self._avail else self._size - indice = np.random.choice(_all, batch_size) - else: - if self._avail: - indice = np.array(self._avail_index) - else: - indice = np.concatenate([ - np.arange(self._index, self._size), - np.arange(0, self._index), - ]) - assert len(indice) > 0, "No available indice can be sampled." - return self[indice], indice + indices = self.sample_index(batch_size) + return self[indices], indices def get( self, - indice: Union[slice, int, np.integer, np.ndarray], + index: Union[int, np.integer, np.ndarray], key: str, stack_num: Optional[int] = None, ) -> Union[Batch, np.ndarray]: """Return the stacked result. E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the - indice. The stack_num (here equals to 4) is given from buffer - initialization procedure. + index. """ if stack_num is None: stack_num = self.stack_num - if stack_num == 1: # the most often case - if key != "obs_next" or self._save_s_: - val = self._meta.__dict__[key] - try: - return val[indice] - except IndexError as e: - if not (isinstance(val, Batch) and val.is_empty()): - raise e # val != Batch() - return Batch() - indice = self._indices[:self._size][indice] - done = self._meta.__dict__["done"] - if key == "obs_next" and not self._save_s_: - indice += 1 - done[indice].astype(np.int) - indice[indice == self._size] = 0 - key = "obs" - val = self._meta.__dict__[key] + val = self._meta[key] try: - if stack_num == 1: - return val[indice] + if stack_num == 1: # the most often case + return val[index] stack: List[Any] = [] + indice = np.asarray(index) for _ in range(stack_num): stack = [val[indice]] + stack - pre_indice = np.asarray(indice - 1) - pre_indice[pre_indice == -1] = self._size - 1 - indice = np.asarray( - pre_indice + done[pre_indice].astype(np.int)) - indice[indice == self._size] = 0 + indice = self.prev(indice) if isinstance(val, Batch): return Batch.stack(stack, axis=indice.ndim) else: @@ -369,31 +310,24 @@ def __getitem__( If stack_num is larger than 1, return the stacked obs and obs_next with shape (batch, len, ...). """ + if isinstance(index, slice): # change slice to np array + index = self._indices[:len(self)][index] + # raise KeyError first instead of AttributeError, to support np.array + obs = self.get(index, "obs") + if self._save_obs_next: + obs_next = self.get(index, "obs_next") + else: + obs_next = self.get(self.next(index), "obs") return Batch( - obs=self.get(index, "obs"), + obs=obs, act=self.act[index], rew=self.rew[index], done=self.done[index], - obs_next=self.get(index, "obs_next"), + obs_next=obs_next, info=self.get(index, "info"), policy=self.get(index, "policy"), ) - def save_hdf5(self, path: str) -> None: - """Save replay buffer to HDF5 file.""" - with h5py.File(path, "w") as f: - to_hdf5(self.__getstate__(), f) - - @classmethod - def load_hdf5( - cls, path: str, device: Optional[str] = None - ) -> "ReplayBuffer": - """Load replay buffer from HDF5 file.""" - with h5py.File(path, "r") as f: - buf = cls.__new__(cls) - buf.__setstate__(from_hdf5(f, device=device)) - return buf - class ListReplayBuffer(ReplayBuffer): """List-based replay buffer. 
@@ -411,24 +345,27 @@ class ListReplayBuffer(ReplayBuffer): """ def __init__(self, **kwargs: Any) -> None: + warnings.warn("ListReplayBuffer will be replaced in version 0.4.0.") super().__init__(size=0, ignore_obs_next=False, **kwargs) def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: raise NotImplementedError("ListReplayBuffer cannot be sampled!") - def _add_to_buffer( - self, name: str, inst: Union[dict, Batch, np.ndarray, float, int, bool] - ) -> None: - if self._meta.__dict__.get(name) is None: + def _add_to_buffer(self, name: str, inst: Any) -> None: + if self._meta.get(name) is None: self._meta.__dict__[name] = [] - self._meta.__dict__[name].append(inst) + self._meta[name].append(inst) def reset(self) -> None: - self._index = self._size = 0 - for k in list(self._meta.__dict__.keys()): - if isinstance(self._meta.__dict__[k], list): + super().reset() + for k in self._meta.keys(): + if isinstance(self._meta[k], list): self._meta.__dict__[k] = [] + def update(self, buffer: ReplayBuffer) -> None: + """The ListReplayBuffer cannot be updated by any buffer.""" + raise NotImplementedError + class PrioritizedReplayBuffer(ReplayBuffer): """Implementation of Prioritized Experience Replay. arXiv:1511.05952. @@ -464,8 +401,7 @@ def add( policy: Optional[Union[dict, Batch]] = {}, weight: Optional[Union[Number, np.number]] = None, **kwargs: Any, - ) -> None: - """Add a batch of data into replay buffer.""" + ) -> Tuple[int, Union[float, np.ndarray]]: if weight is None: weight = self._max_prio else: @@ -473,60 +409,289 @@ def add( self._max_prio = max(self._max_prio, weight) self._min_prio = min(self._min_prio, weight) self.weight[self._index] = weight ** self._alpha - super().add(obs, act, rew, done, obs_next, info, policy, **kwargs) + return super().add(obs, act, rew, done, obs_next, + info, policy, **kwargs) - def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: - """Get a random sample from buffer with priority probability. - - Return all the data in the buffer if batch_size is 0. + def sample_index(self, batch_size: int) -> np.ndarray: + if batch_size > 0 and self._size > 0: + scalar = np.random.rand(batch_size) * self.weight.reduce() + return self.weight.get_prefix_sum_idx(scalar) + else: + return super().sample_index(batch_size) - :return: Sample data and its corresponding index inside the buffer. + def get_weight( + self, index: Union[slice, int, np.integer, np.ndarray] + ) -> np.ndarray: + """Get the importance sampling weight. The "weight" in the returned Batch is the weight on loss function to de-bias the sampling process (some transition tuples are sampled more often so their losses are weighted less). """ - assert self._size > 0, "Cannot sample a buffer with 0 size!" - if batch_size == 0: - indice = np.concatenate([ - np.arange(self._index, self._size), - np.arange(0, self._index), - ]) - else: - scalar = np.random.rand(batch_size) * self.weight.reduce() - indice = self.weight.get_prefix_sum_idx(scalar) - batch = self[indice] # important sampling weight calculation # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta)) # simplified formula: (p_j/p_min)**(-beta) - batch.weight = (batch.weight / self._min_prio) ** (-self._beta) - return batch, indice + return (self.weight[index] / self._min_prio) ** (-self._beta) def update_weight( self, - indice: Union[np.ndarray], - new_weight: Union[np.ndarray, torch.Tensor] + index: np.ndarray, + new_weight: Union[np.ndarray, torch.Tensor], ) -> None: - """Update priority weight by indice in this buffer. 
+ """Update priority weight by index in this buffer. - :param np.ndarray indice: indice you want to update weight. + :param np.ndarray index: index you want to update weight. :param np.ndarray new_weight: new priority weight you want to update. """ weight = np.abs(to_numpy(new_weight)) + self.__eps - self.weight[indice] = weight ** self._alpha + self.weight[index] = weight ** self._alpha self._max_prio = max(self._max_prio, weight.max()) self._min_prio = min(self._min_prio, weight.min()) def __getitem__( self, index: Union[slice, int, np.integer, np.ndarray] ) -> Batch: - return Batch( - obs=self.get(index, "obs"), - act=self.act[index], - rew=self.rew[index], - done=self.done[index], - obs_next=self.get(index, "obs_next"), - info=self.get(index, "info"), - policy=self.get(index, "policy"), - weight=self.weight[index], - ) + batch = super().__getitem__(index) + batch.weight = self.get_weight(index) + return batch + + +class ReplayBufferManager(ReplayBuffer): + """ReplayBufferManager contains a list of ReplayBuffer. + + These replay buffers have contiguous memory layout, and the storage space + each buffer has is a shallow copy of the topmost memory. + + :param int buffer_list: a list of ReplayBuffers needed to be handled. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed + explanation. + """ + + def __init__(self, buffer_list: List[ReplayBuffer], **kwargs: Any) -> None: + self.buffer_num = len(buffer_list) + self.buffers = buffer_list + self._offset = [] + offset = 0 + for buf in self.buffers: + # overwrite sub-buffers' _buffer_allocator so that + # the top buffer can allocate new memory for all sub-buffers + buf._buffer_allocator = self._buffer_allocator # type: ignore + assert buf._meta.is_empty() + self._offset.append(offset) + offset += buf.maxsize + super().__init__(size=offset, **kwargs) + + def __len__(self) -> int: + return sum([len(buf) for buf in self.buffers]) + + def reset(self) -> None: + for buf in self.buffers: + buf.reset() + + def _set_batch_for_children(self) -> None: + for offset, buf in zip(self._offset, self.buffers): + buf.set_batch(self._meta[offset:offset + buf.maxsize]) + + def set_batch(self, batch: Batch) -> None: + super().set_batch(batch) + self._set_batch_for_children() + + def unfinished_index(self) -> np.ndarray: + return np.concatenate([ + buf.unfinished_index() + offset + for offset, buf in zip(self._offset, self.buffers)]) + + def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: + index = np.asarray(index) % self.maxsize + prev_indices = np.zeros_like(index) + for offset, buf in zip(self._offset, self.buffers): + mask = (offset <= index) & (index < offset + buf.maxsize) + if np.any(mask): + prev_indices[mask] = buf.prev(index[mask] - offset) + offset + return prev_indices + + def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: + index = np.asarray(index) % self.maxsize + next_indices = np.zeros_like(index) + for offset, buf in zip(self._offset, self.buffers): + mask = (offset <= index) & (index < offset + buf.maxsize) + if np.any(mask): + next_indices[mask] = buf.next(index[mask] - offset) + offset + return next_indices + + def update(self, buffer: ReplayBuffer) -> None: + """The ReplayBufferManager cannot be updated by any buffer.""" + raise NotImplementedError + + def _buffer_allocator(self, key: List[str], value: Any) -> None: + super()._buffer_allocator(key, value) + self._set_batch_for_children() + + def add( # type: ignore + self, + obs: Any, + act: Any, + rew: 
np.ndarray, + done: np.ndarray, + obs_next: Any = Batch(), + info: Optional[Batch] = Batch(), + policy: Optional[Batch] = Batch(), + buffer_ids: Optional[Union[np.ndarray, List[int]]] = None, + **kwargs: Any + ) -> Tuple[np.ndarray, np.ndarray]: + """Add a batch of data into ReplayBufferManager. + + Each of the data's length (first dimension) must equal to the length of + buffer_ids. By default buffer_ids is [0, 1, ..., buffer_num - 1]. + + Return the array of episode_length and episode_reward with shape + (len(buffer_ids), ...), where (episode_length[i], episode_reward[i]) + refers to the buffer_ids[i]'s corresponding episode result. + """ + if buffer_ids is None: + buffer_ids = np.arange(self.buffer_num) + # assume each element in buffer_ids is unique + assert np.bincount(buffer_ids).max() == 1 + batch = Batch(obs=obs, act=act, rew=rew, done=done, + obs_next=obs_next, info=info, policy=policy) + assert len(buffer_ids) == len(batch) + episode_lengths = [] # (len(buffer_ids),) + episode_rewards = [] # (len(buffer_ids), ...) + for batch_idx, buffer_id in enumerate(buffer_ids): + length, reward = self.buffers[buffer_id].add(**batch[batch_idx]) + episode_lengths.append(length) + episode_rewards.append(reward) + return np.stack(episode_lengths), np.stack(episode_rewards) + + def sample_index(self, batch_size: int) -> np.ndarray: + if batch_size < 0: + return np.array([], np.int) + if self._sample_avail and self.stack_num > 1: + all_indices = np.concatenate([ + buf.sample_index(0) + offset + for offset, buf in zip(self._offset, self.buffers)]) + if batch_size == 0: + return all_indices + else: + return np.random.choice(all_indices, batch_size) + if batch_size == 0: # get all available indices + sample_num = np.zeros(self.buffer_num, np.int) + else: + buffer_lens = np.array([len(buf) for buf in self.buffers]) + buffer_idx = np.random.choice(self.buffer_num, batch_size, + p=buffer_lens / buffer_lens.sum()) + sample_num = np.bincount(buffer_idx, minlength=self.buffer_num) + # avoid batch_size > 0 and sample_num == 0 -> get child's all data + sample_num[sample_num == 0] = -1 + + return np.concatenate([ + buf.sample_index(bsz) + offset + for offset, buf, bsz in zip(self._offset, self.buffers, sample_num) + ]) + + +class CachedReplayBuffer(ReplayBufferManager): + """CachedReplayBuffer contains a given main buffer and n cached buffers, \ + cached_buffer_num * ReplayBuffer(size=max_episode_length). + + The memory layout is: ``| main_buffer | cached_buffers[0] | + cached_buffers[1] | ... | cached_buffers[cached_buffer_num - 1]``. + + The data is first stored in cached buffers. When the episode is + terminated, the data will move to the main buffer and the corresponding + cached buffer will be reset. + + :param ReplayBuffer main_buffer: the main buffer whose ``.update()`` + function behaves normally. + :param int cached_buffer_num: number of ReplayBuffer needs to be created + for cached buffer. + :param int max_episode_length: the maximum length of one episode, used in + each cached buffer's maxsize. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` or + :class:`~tianshou.data.ReplayBufferManager` for more detailed + explanation. 
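+
+    A small sketch whose numbers are taken from ``test_cachedbuffer`` in the
+    unit tests, so they should be read as illustrative::
+
+        >>> buf = CachedReplayBuffer(ReplayBuffer(10), 4, 5)
+        >>> # total maxsize = 10 + 4 * 5 = 30; main_buffer occupies [0, 10)
+        >>> ep_len, ep_rew = buf.add(obs=[1], act=[1], rew=[1], done=[0],
+        ...                          cached_buffer_ids=[1])
+        >>> buf.sample_index(0)  # data still sits in cached_buffers[1]
+        array([15])
+        >>> ep_len, ep_rew = buf.add(obs=[2], act=[2], rew=[2], done=[1],
+        ...                          cached_buffer_ids=[3])
+        >>> buf.sample_index(0)  # the finished episode moved to main_buffer
+        array([ 0, 15])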
+ """ + + def __init__( + self, + main_buffer: ReplayBuffer, + cached_buffer_num: int, + max_episode_length: int, + ) -> None: + assert cached_buffer_num > 0 and max_episode_length > 0 + self._is_prioritized = isinstance(main_buffer, PrioritizedReplayBuffer) + kwargs = main_buffer.options + buffers = [main_buffer] + [ReplayBuffer(max_episode_length, **kwargs) + for _ in range(cached_buffer_num)] + super().__init__(buffer_list=buffers, **kwargs) + self.main_buffer = self.buffers[0] + self.cached_buffers = self.buffers[1:] + self.cached_buffer_num = cached_buffer_num + + def add( # type: ignore + self, + obs: Any, + act: Any, + rew: np.ndarray, + done: np.ndarray, + obs_next: Any = Batch(), + info: Optional[Batch] = Batch(), + policy: Optional[Batch] = Batch(), + cached_buffer_ids: Optional[Union[np.ndarray, List[int]]] = None, + **kwargs: Any, + ) -> Tuple[np.ndarray, np.ndarray]: + """Add a batch of data into CachedReplayBuffer. + + Each of the data's length (first dimension) must equal to the length of + cached_buffer_ids. By default the cached_buffer_ids is [0, 1, ..., + cached_buffer_num - 1]. + + Return the array of episode_length and episode_reward with shape + (len(cached_buffer_ids), ...), where (episode_length[i], + episode_reward[i]) refers to the cached_buffer_ids[i]th cached buffer's + corresponding episode result. + """ + if cached_buffer_ids is None: + cached_buffer_ids = np.arange(self.cached_buffer_num) + else: # make sure it is np.ndarray + cached_buffer_ids = np.asarray(cached_buffer_ids) + # in self.buffers, the first buffer is main_buffer + buffer_ids = cached_buffer_ids + 1 # type: ignore + result = super().add(obs, act, rew, done, obs_next, info, + policy, buffer_ids=buffer_ids, **kwargs) + # find the terminated episode, move data from cached buf to main buf + for buffer_idx in cached_buffer_ids[np.asarray(done, np.bool_)]: + self.main_buffer.update(self.cached_buffers[buffer_idx]) + self.cached_buffers[buffer_idx].reset() + return result + + def __getitem__( + self, index: Union[slice, int, np.integer, np.ndarray] + ) -> Batch: + batch = super().__getitem__(index) + if self._is_prioritized: + indice = self._indices[index] + mask = indice < self.main_buffer.maxsize + batch.weight = np.ones(len(indice)) + batch.weight[mask] = self.main_buffer.get_weight(indice[mask]) + return batch + + def update_weight( + self, + index: np.ndarray, + new_weight: Union[np.ndarray, torch.Tensor], + ) -> None: + """Update priority weight by index in main buffer. + + :param np.ndarray index: index you want to update weight. + :param np.ndarray new_weight: new priority weight you want to update. + """ + if self._is_prioritized: + mask = index < self.main_buffer.maxsize + self.main_buffer.update_weight(index[mask], new_weight[mask])
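For reference, a minimal usage sketch of ReplayBufferManager distilled from ``test_replaybuffermanager`` above (outputs mirror the test's assertions and are illustrative):

    >>> import numpy as np
    >>> from tianshou.data import ReplayBuffer, ReplayBufferManager
    >>> # four size-5 buffers share one contiguous storage (maxsize = 20)
    >>> buf = ReplayBufferManager([ReplayBuffer(size=5) for _ in range(4)])
    >>> ep_len, ep_rew = buf.add(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3],
    ...                          done=[0, 0, 1], buffer_ids=[0, 1, 2])
    >>> ep_len, ep_rew  # only the episode in buffer 2 has finished
    (array([0, 0, 1]), array([0., 0., 3.]))
    >>> buf.sample_index(0)  # valid indices are offset by each sub-buffer
    array([ 0,  5, 10])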