这是indexloc提供的服务,不要输入任何密码
Skip to content

Improve buffer.prev() & buffer.next() #294

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Feb 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions test/base/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_replaybuffer(size=10, bufsize=20):
assert b.info.a[1] == 4 and b.info.b.c[1] == 0
assert b.info.d.e[1] == -np.inf
# test batch-style adding method, where len(batch) == 1
batch.done = 1
batch.done = [1]
batch.info.e = np.zeros([1, 4])
batch = Batch.stack([batch])
ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0])
Expand All @@ -79,6 +79,13 @@ def test_replaybuffer(size=10, bufsize=20):
assert b.info.e.shape == (b.maxsize, 1, 4)
with pytest.raises(IndexError):
b[22]
# test prev / next
assert np.all(b.prev(np.array([0, 1, 2])) == [0, 1, 1])
assert np.all(b.next(np.array([0, 1, 2])) == [0, 2, 2])
batch.done = [0]
b.add(batch, buffer_ids=[0])
assert np.all(b.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3])
assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3])


def test_ignore_obs_next(size=10):
Expand Down Expand Up @@ -718,7 +725,6 @@ def test_multibuf_hdf5():
test_stack()
test_segtree()
test_priortized_replaybuffer()
test_priortized_replaybuffer(233333, 200000)
test_update()
test_pickle()
test_hdf5()
Expand Down
111 changes: 90 additions & 21 deletions tianshou/data/buffer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import h5py
import torch
import numpy as np
from numba import njit
from typing import Any, Dict, List, Tuple, Union, Sequence, Optional

from tianshou.data.batch import _create_value
Expand Down Expand Up @@ -116,6 +117,7 @@ def load_hdf5(cls, path: str, device: Optional[str] = None) -> "ReplayBuffer":

def reset(self) -> None:
    """Clear all the data in replay buffer and episode statistics.

    Resets the write pointer (``_index``), the stored size (``_size``), the
    running episode reward/length/start index, and ``last_index``.
    """
    # last_index is kept as a one-element array rather than a plain int;
    # other methods update it in place through ``self.last_index[0]``.
    self.last_index = np.array([0])
    self._index = self._size = 0
    self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, 0

Expand All @@ -137,15 +139,15 @@ def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
The index won't be modified if it is the beginning of an episode.
"""
index = (index - 1) % self._size
end_flag = self.done[index] | np.isin(index, self.unfinished_index())
end_flag = self.done[index] | (index == self.last_index[0])
return (index + end_flag) % self._size

def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
"""Return the index of next transition.

The index won't be modified if it is the end of an episode.
"""
end_flag = self.done[index] | np.isin(index, self.unfinished_index())
end_flag = self.done[index] | (index == self.last_index[0])
return (index + (1 - end_flag)) % self._size

def update(self, buffer: "ReplayBuffer") -> np.ndarray:
Expand All @@ -163,6 +165,7 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray:
to_indices = []
for _ in range(len(from_indices)):
to_indices.append(self._index)
self.last_index[0] = self._index
self._index = (self._index + 1) % self.maxsize
self._size = min(self._size + 1, self.maxsize)
to_indices = np.array(to_indices)
Expand All @@ -180,7 +183,7 @@ def _add_index(
Return (index_to_be_modified, episode_reward, episode_length,
episode_start_index).
"""
ptr = self._index
self.last_index[0] = ptr = self._index
self._size = min(self._size + 1, self.maxsize)
self._index = (self._index + 1) % self.maxsize

Expand Down Expand Up @@ -296,6 +299,13 @@ def get(
"""Return the stacked result.

E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the index.

:param index: the index for getting stacked data (t in the example).
:param str key: the key to get, should be one of the reserved_keys.
:param default_value: if the given key's data is not found and default_value is
set, return this default_value.
:param int stack_num: the stack num (4 in the example). Default to
self.stack_num.
"""
if key not in self._meta and default_value is not None:
return default_value
Expand All @@ -306,7 +316,10 @@ def get(
if stack_num == 1: # the most often case
return val[index]
stack: List[Any] = []
indice = np.asarray(index)
if isinstance(index, list):
indice = np.array(index)
else:
indice = index
for _ in range(stack_num):
stack = [val[indice]] + stack
indice = self.prev(indice)
Expand Down Expand Up @@ -453,12 +466,24 @@ def __init__(self, buffer_list: List[ReplayBuffer]) -> None:
offset.append(size)
size += buf.maxsize
self._offset = np.array(offset)
self._extend_offset = np.array(offset + [size])
self._lengths = np.zeros_like(offset)
super().__init__(size=size, **kwargs)
self._compile()

def _compile(self) -> None:
    """Call the jitted index helpers once on minimal dummy inputs.

    The return values are discarded. Presumably this exists to trigger
    numba compilation of ``_prev_index`` and ``_next_index`` up front
    rather than on the first real lookup -- TODO confirm.
    """
    # A single one-element integer array stands in for the lengths,
    # last-index and index arguments; offset/done are two-element arrays
    # describing one segment of size 1.
    lens = last = index = np.array([0])
    offset = np.array([0, 1])
    done = np.array([False, False])
    _prev_index(index, offset, done, last, lens)
    _next_index(index, offset, done, last, lens)

def __len__(self) -> int:
return sum([len(buf) for buf in self.buffers])
return self._lengths.sum()

def reset(self) -> None:
    """Reset the per-buffer bookkeeping and every sub-buffer.

    ``last_index`` is re-seeded from the sub-buffer offsets, ``_lengths``
    is zeroed, and ``reset()`` is called on each entry of ``self.buffers``.
    """
    # Copy: last_index is written in place elsewhere, and those writes
    # must not leak into the _offset array.
    self.last_index = self._offset.copy()
    self._lengths = np.zeros_like(self._offset)
    for buf in self.buffers:
        buf.reset()

Expand All @@ -477,22 +502,20 @@ def unfinished_index(self) -> np.ndarray:
])

def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
index = np.asarray(index) % self.maxsize
prev_indices = np.zeros_like(index)
for offset, buf in zip(self._offset, self.buffers):
mask = (offset <= index) & (index < offset + buf.maxsize)
if np.any(mask):
prev_indices[mask] = buf.prev(index[mask] - offset) + offset
return prev_indices
if isinstance(index, (list, np.ndarray)):
return _prev_index(np.asarray(index), self._extend_offset,
self.done, self.last_index, self._lengths)
else:
return _prev_index(np.array([index]), self._extend_offset,
self.done, self.last_index, self._lengths)[0]

def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
index = np.asarray(index) % self.maxsize
next_indices = np.zeros_like(index)
for offset, buf in zip(self._offset, self.buffers):
mask = (offset <= index) & (index < offset + buf.maxsize)
if np.any(mask):
next_indices[mask] = buf.next(index[mask] - offset) + offset
return next_indices
if isinstance(index, (list, np.ndarray)):
return _next_index(np.asarray(index), self._extend_offset,
self.done, self.last_index, self._lengths)
else:
return _next_index(np.array([index]), self._extend_offset,
self.done, self.last_index, self._lengths)[0]

def update(self, buffer: ReplayBuffer) -> np.ndarray:
"""The ReplayBufferManager cannot be updated by any buffer."""
Expand Down Expand Up @@ -534,6 +557,8 @@ def add(
ep_lens.append(ep_len)
ep_rews.append(ep_rew)
ep_idxs.append(ep_idx + self._offset[buffer_id])
self.last_index[buffer_id] = ptr + self._offset[buffer_id]
self._lengths[buffer_id] = len(self.buffers[buffer_id])
ptrs = np.array(ptrs)
try:
self._meta[ptrs] = batch
Expand Down Expand Up @@ -564,9 +589,8 @@ def sample_index(self, batch_size: int) -> np.ndarray:
if batch_size == 0: # get all available indices
sample_num = np.zeros(self.buffer_num, np.int)
else:
buffer_lens = np.array([len(buf) for buf in self.buffers])
buffer_idx = np.random.choice(
self.buffer_num, batch_size, p=buffer_lens / buffer_lens.sum()
self.buffer_num, batch_size, p=self._lengths / self._lengths.sum()
)
sample_num = np.bincount(buffer_idx, minlength=self.buffer_num)
# avoid batch_size > 0 and sample_num == 0 -> get child's all data
Expand Down Expand Up @@ -726,6 +750,51 @@ def add(
updated_ep_idx.append(index[0])
updated_ptr.append(index[-1])
self.buffers[buffer_idx].reset()
self._lengths[0] = len(self.main_buffer)
self._lengths[buffer_idx] = 0
self.last_index[0] = index[-1]
self.last_index[buffer_idx] = self._offset[buffer_idx]
ptr[done] = updated_ptr
ep_idx[done] = updated_ep_idx
return ptr, ep_rew, ep_len, ep_idx


@njit
def _prev_index(
    index: np.ndarray,
    offset: np.ndarray,
    done: np.ndarray,
    last_index: np.ndarray,
    lengths: np.ndarray,
) -> np.ndarray:
    """Return the previous-transition index for each entry of ``index``.

    :param index: global indices to look up; wrapped modulo ``offset[-1]``.
    :param offset: segment boundaries; each consecutive pair
        ``(offset[i], offset[i + 1])`` delimits one segment.
    :param done: done flags, addressed by global index.
    :param last_index: one global index per segment; stepping back onto it
        is treated like stepping onto a done slot.
    :param lengths: one length per segment, used as the wrap-around modulus.

    :return: an array shaped like ``index``. Within a segment the result is
        one step back (modulo the segment length), unless that slot is
        flagged done or equals the segment's ``last_index``, in which case
        the step is undone.
    """
    index = index % offset[-1]
    prev_index = np.zeros_like(index)
    for start, end, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index):
        mask = (start <= index) & (index < end)
        # An empty segment has length 0; clamp to 1 to avoid modulo by zero.
        cur_len = max(1, cur_len)
        if np.sum(mask) > 0:
            subind = index[mask]
            # Candidate previous slot, as a segment-local index.
            subind = (subind - start - 1) % cur_len
            end_flag = done[subind + start] | (subind + start == last)
            # Adding end_flag (0 or 1) cancels the step back when needed.
            prev_index[mask] = (subind + end_flag) % cur_len + start
    return prev_index


@njit
def _next_index(
    index: np.ndarray,
    offset: np.ndarray,
    done: np.ndarray,
    last_index: np.ndarray,
    lengths: np.ndarray,
) -> np.ndarray:
    """Return the next-transition index for each entry of ``index``.

    :param index: global indices to look up; wrapped modulo ``offset[-1]``.
    :param offset: segment boundaries; each consecutive pair
        ``(offset[i], offset[i + 1])`` delimits one segment.
    :param done: done flags, addressed by global index.
    :param last_index: one global index per segment; an index equal to it
        is not advanced.
    :param lengths: one length per segment, used as the wrap-around modulus.

    :return: an array shaped like ``index``. Within a segment the result is
        one step forward (modulo the segment length), unless the current
        slot is flagged done or equals the segment's ``last_index``, in
        which case no step is taken.
    """
    index = index % offset[-1]
    next_index = np.zeros_like(index)
    for start, end, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index):
        mask = (start <= index) & (index < end)
        # An empty segment has length 0; clamp to 1 to avoid modulo by zero.
        cur_len = max(1, cur_len)
        if np.sum(mask) > 0:
            # Unlike _prev_index, subind stays a global index here, so it
            # is compared against done/last directly.
            subind = index[mask]
            end_flag = done[subind] | (subind == last)
            # Subtracting end_flag (0 or 1) suppresses the step forward.
            next_index[mask] = (subind - start + 1 - end_flag) % cur_len + start
    return next_index
8 changes: 4 additions & 4 deletions tianshou/policy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,12 +220,12 @@ def compute_episodic_return(
Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438)
to calculate q function/reward to go of given batch.

:param Batch batch: a data batch which contains several episodes of data
in sequential order. Mind that the end of each finished episode of batch
:param Batch batch: a data batch which contains several episodes of data in
sequential order. Mind that the end of each finished episode of batch
should be marked by done flag, unfinished (or collecting) episodes will be
recognized by buffer.unfinished_index().
:param np.ndarray indice: tell batch's location in buffer, batch is
equal to buffer[indice].
:param numpy.ndarray indice: tell batch's location in buffer, batch is equal to
buffer[indice].
:param np.ndarray v_s_: the value function of all next states :math:`V(s')`.
:param float gamma: the discount factor, should be in [0, 1]. Default to 0.99.
:param float gae_lambda: the parameter for Generalized Advantage Estimation,
Expand Down