Add CachedReplayBuffer and ReplayBufferManager #278

Merged: 36 commits, Jan 29, 2021
8229b47
add first version of cached replay buffer(baseline), add standard api…
ChenDRAG Jan 19, 2021
483404f
add cached buffer, vec buffer
ChenDRAG Jan 20, 2021
942c2a3
simple pep8 fix
ChenDRAG Jan 20, 2021
ec4b246
init
ChenDRAG Jan 20, 2021
e7f631e
Merge branch 'master(net/utils change #275)' into cached
ChenDRAG Jan 20, 2021
da564be
Merge branch 'master' into cached
Trinkle23897 Jan 20, 2021
36e799e
some change
ChenDRAG Jan 22, 2021
50b20a0
update ReplayBuffer
Trinkle23897 Jan 22, 2021
3e487dc
Merge branch 'cached' of github.com:ChenDRAG/tianshou into cached
ChenDRAG Jan 22, 2021
0ac97af
refactor ReplayBuffer
ChenDRAG Jan 22, 2021
bd90b97
refactor vec/cached buffer
ChenDRAG Jan 22, 2021
2b7c227
pep8 fix
ChenDRAG Jan 22, 2021
3afb2cb
update VectorReplayBuffer and add test
Trinkle23897 Jan 23, 2021
443969d
update cached
Trinkle23897 Jan 24, 2021
ee51e64
order change, small fix
ChenDRAG Jan 25, 2021
a5bc4ad
try unittest
Trinkle23897 Jan 25, 2021
17e3612
add more test and fix bugs
Trinkle23897 Jan 25, 2021
7eba23d
fix a bug and add some corner-case tests
Trinkle23897 Jan 25, 2021
b9f4f2a
re-implement sample_avail function and add test for CachedReplayBuffe…
Trinkle23897 Jan 26, 2021
8fe85f8
improve documents
Trinkle23897 Jan 26, 2021
f59a530
ReplayBuffers._offset
Trinkle23897 Jan 26, 2021
425c2bd
fix atari-style update; support CachedBuffer with main_buffer==PrioBu…
Trinkle23897 Jan 27, 2021
2361755
assert _meta.is_empty() in ReplayBuffers init
Trinkle23897 Jan 27, 2021
0160a7f
Merge branch 'master' into cached
Trinkle23897 Jan 28, 2021
75d581b
small fix
Trinkle23897 Jan 28, 2021
b5d93f3
small fix
Trinkle23897 Jan 28, 2021
c8f27c9
improve coverage
Trinkle23897 Jan 28, 2021
9c879f2
small buffer change
ChenDRAG Jan 28, 2021
26bb74c
pep8 fix
ChenDRAG Jan 28, 2021
74df1c5
fix ci
Trinkle23897 Jan 28, 2021
39463ba
recover speed to 2000+
Trinkle23897 Jan 28, 2021
31f0c94
improve documents
Trinkle23897 Jan 28, 2021
720da29
ReplayBuffers -> ReplayBufferManager; alloc_fn -> _buffer_allocator; …
Trinkle23897 Jan 28, 2021
b9385d8
re-organize test_buffer.py
Trinkle23897 Jan 28, 2021
16bb42e
improve test
Trinkle23897 Jan 29, 2021
cada7cc
test if can be faster
Trinkle23897 Jan 29, 2021
160 changes: 156 additions & 4 deletions docs/tutorials/concepts.rst
@@ -53,11 +53,163 @@ In short, you can define a :class:`~tianshou.data.Batch` with any key-value pair
Buffer
------

.. automodule:: tianshou.data.ReplayBuffer
   :members:
   :noindex:
:class:`~tianshou.data.ReplayBuffer` stores data generated from the interaction between the policy and the environment. ReplayBuffer can be considered a specialized form (or manager) of Batch: it stores all the data in a Batch, organized as a circular queue.
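
The circular-queue behaviour can be sketched in a few lines of plain Python. This is only an illustration under simplifying assumptions: ``RingBuffer`` is a hypothetical class that stores bare integers in a list, not the Tianshou implementation, which stores a Batch.

```python
class RingBuffer:
    """Hypothetical fixed-capacity store that overwrites its oldest entry once full."""

    def __init__(self, size):
        self.maxsize = size
        self.data = [0] * size  # preallocated storage
        self._index = 0         # slot the next write goes to
        self._size = 0          # number of valid entries

    def add(self, item):
        self.data[self._index] = item
        # wrap around to slot 0 after the last slot
        self._index = (self._index + 1) % self.maxsize
        self._size = min(self._size + 1, self.maxsize)

    def __len__(self):
        return self._size


ring = RingBuffer(size=10)
for step in range(15):
    ring.add(step)
print(ring.data)  # [10, 11, 12, 13, 14, 5, 6, 7, 8, 9]
print(len(ring))  # 10
```

Because the write position wraps around, the storage order differs from the chronological order once the buffer has been filled.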

Tianshou provides other types of data buffer, such as :class:`~tianshou.data.ListReplayBuffer` (based on list) and :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out :class:`~tianshou.data.ReplayBuffer` for more details.
The current implementation of Tianshou typically uses 7 reserved keys in
:class:`~tianshou.data.Batch`:

* ``obs`` the observation of step :math:`t` ;
* ``act`` the action of step :math:`t` ;
* ``rew`` the reward of step :math:`t` ;
* ``done`` the done flag of step :math:`t` ;
* ``obs_next`` the observation of step :math:`t+1` ;
* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` function returns 4 values, and the last one is ``info``);
* ``policy`` the data computed by policy in step :math:`t`;

The following code snippet illustrates its usage, including:

- the basic data storage: ``add()``;
- get attribute, get slicing data, ...;
- sample from buffer: ``sample_index(batch_size)`` and ``sample(batch_size)``;
- get previous/next transition index within episodes: ``prev(index)`` and ``next(index)``;
- save/load data from buffer: pickle and HDF5;

::

    >>> import pickle, numpy as np
    >>> from tianshou.data import ReplayBuffer
    >>> buf = ReplayBuffer(size=20)
    >>> for i in range(3):
    ...     buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={})

    >>> # since we set size = 20, len(buf.obs) == 20.
    >>> buf.obs
    array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
    >>> # but there are only three valid items, so len(buf) == 3.
    >>> len(buf)
    3
    >>> # save to file "buf.pkl"
    >>> pickle.dump(buf, open('buf.pkl', 'wb'))
    >>> # save to HDF5 file
    >>> buf.save_hdf5('buf.hdf5')

    >>> buf2 = ReplayBuffer(size=10)
    >>> for i in range(15):
    ...     done = i % 4 == 0
    ...     buf2.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={})
    >>> len(buf2)
    10
    >>> # since its size = 10, it only stores the last 10 steps' result.
    >>> buf2.obs
    array([10, 11, 12, 13, 14,  5,  6,  7,  8,  9])

    >>> # move buf2's result into buf (while keeping it in chronological order)
    >>> buf.update(buf2)
    >>> buf.obs
    array([ 0,  1,  2,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,  0,  0,  0,
            0,  0,  0,  0])

    >>> # get all available index by using batch_size = 0
    >>> indice = buf.sample_index(0)
    >>> indice
    array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
    >>> # get one step previous/next transition
    >>> buf.prev(indice)
    array([ 0,  0,  1,  2,  3,  4,  5,  7,  7,  8,  9, 11, 11])
    >>> buf.next(indice)
    array([ 1,  2,  3,  4,  5,  6,  6,  8,  9, 10, 10, 12, 12])

    >>> # get a random sample from buffer
    >>> # the batch_data is equal to buf[indice].
    >>> batch_data, indice = buf.sample(batch_size=4)
    >>> batch_data.obs == buf[indice].obs
    array([ True,  True,  True,  True])
    >>> len(buf)
    13

    >>> buf = pickle.load(open('buf.pkl', 'rb'))  # load from "buf.pkl"
    >>> len(buf)
    3
    >>> # load complete buffer from HDF5 file
    >>> buf = ReplayBuffer.load_hdf5('buf.hdf5')
    >>> len(buf)
    3

:class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling (typically for RNN usage, see issue#19), skipping storage of the next observation (to save memory in Atari tasks), and multi-modal observations (see issue#38):

.. raw:: html

    <details>
    <summary>Advanced usage of ReplayBuffer</summary>

.. code-block:: python

    >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
    >>> for i in range(16):
    ...     done = i % 5 == 0
    ...     ep_len, ep_rew = buf.add(obs={'id': i}, act=i, rew=i,
    ...                              done=done, obs_next={'id': i + 1})
    ...     print(i, ep_len, ep_rew)
    0 1 0.0
    1 0 0.0
    2 0 0.0
    3 0 0.0
    4 0 0.0
    5 5 15.0
    6 0 0.0
    7 0 0.0
    8 0 0.0
    9 0 0.0
    10 5 40.0
    11 0 0.0
    12 0 0.0
    13 0 0.0
    14 0 0.0
    15 5 65.0
    >>> print(buf)  # you can see obs_next is not saved in buf
    ReplayBuffer(
        obs: Batch(
                 id: array([ 9, 10, 11, 12, 13, 14, 15,  7,  8]),
             ),
        act: array([ 9, 10, 11, 12, 13, 14, 15,  7,  8]),
        rew: array([ 9., 10., 11., 12., 13., 14., 15.,  7.,  8.]),
        done: array([False,  True, False, False, False, False,  True, False,
                     False]),
        info: Batch(),
        policy: Batch(),
    )
    >>> index = np.arange(len(buf))
    >>> print(buf.get(index, 'obs').id)
    [[ 7  7  8  9]
     [ 7  8  9 10]
     [11 11 11 11]
     [11 11 11 12]
     [11 11 12 13]
     [11 12 13 14]
     [12 13 14 15]
     [ 7  7  7  7]
     [ 7  7  7  8]]
    >>> # here is another way to get the stacked data
    >>> # (stack only for obs and obs_next)
    >>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
    0
    >>> # we can get obs_next through __getitem__, even if it doesn't exist
    >>> print(buf[:].obs_next.id)
    [[ 7  8  9 10]
     [ 7  8  9 10]
     [11 11 11 12]
     [11 11 12 13]
     [11 12 13 14]
     [12 13 14 15]
     [12 13 14 15]
     [ 7  7  7  8]
     [ 7  7  8  9]]

.. raw:: html

    </details><br>
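
The stacked observations and the reconstructed ``obs_next`` shown above can be reproduced with a small plain-Python sketch. It assumes the frames are listed in chronological order (7 to 15) rather than in the buffer's storage order, and all helper names (``prev_index``, ``next_index``, ``stacked``) are hypothetical, introduced only for illustration:

```python
def prev_index(i, done):
    """Previous transition, or i itself at an episode start."""
    return i if i == 0 or done[i - 1] else i - 1


def next_index(i, done):
    """Next transition, or i itself at an episode end."""
    return i if done[i] or i == len(done) - 1 else i + 1


def stacked(i, frames, done, stack_num):
    """Last stack_num frames ending at i, repeating the episode's first frame."""
    indices = [i]
    for _ in range(stack_num - 1):
        indices.append(prev_index(indices[-1], done))
    return [frames[j] for j in reversed(indices)]


frames = [7, 8, 9, 10, 11, 12, 13, 14, 15]  # chronological order
done = [f % 5 == 0 for f in frames]         # episodes end at 10 and 15

obs_stack = [stacked(i, frames, done, 4) for i in range(len(frames))]
# With obs_next not stored, it can be rebuilt as the stacked obs
# of the next index (which is the same index at an episode end).
obs_next_stack = [
    stacked(next_index(i, done), frames, done, 4) for i in range(len(frames))
]

print(obs_stack[2])       # [7, 7, 8, 9]
print(obs_stack[4])       # [11, 11, 11, 11]
print(obs_next_stack[2])  # [7, 8, 9, 10]
print(obs_next_stack[3])  # [7, 8, 9, 10]
```

Frame 9 sits at storage index 0 in the buffer printed above, so ``obs_stack[2]`` here corresponds to the first row of ``buf.get(index, 'obs').id``.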

Tianshou provides other types of data buffer, such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``), and :class:`~tianshou.data.CachedReplayBuffer` (adds data from different episodes without losing chronological order). Check out :class:`~tianshou.data.ReplayBuffer` for more details.
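
One way to keep each episode's steps in chronological order when several environments run in parallel is to cache every environment's steps separately and move an episode into the main storage only once it ends. The sketch below illustrates that idea only; ``CachedEpisodes`` is a hypothetical class using plain lists, and it is an assumption about the general approach rather than a description of the :class:`~tianshou.data.CachedReplayBuffer` code:

```python
class CachedEpisodes:
    """Hypothetical per-environment caches that flush whole episodes into one list."""

    def __init__(self, num_envs):
        self.main = []                               # finished episodes, contiguous
        self.caches = [[] for _ in range(num_envs)]  # one unfinished episode per env

    def add(self, env_id, transition, done):
        self.caches[env_id].append(transition)
        if done:
            # the episode is complete: move it to the main storage in one piece
            self.main.extend(self.caches[env_id])
            self.caches[env_id].clear()


store = CachedEpisodes(num_envs=2)
# Two environments step in an interleaved fashion; env 1 finishes first.
store.add(0, "a0", done=False)
store.add(1, "b0", done=False)
store.add(0, "a1", done=False)
store.add(1, "b1", done=True)
store.add(0, "a2", done=True)
print(store.main)  # ['b0', 'b1', 'a0', 'a1', 'a2']
```

Although the steps arrive interleaved, each episode ends up contiguous and in order in ``store.main``.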


Policy