diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index 3f65ad633..225375d0a 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -67,7 +67,7 @@ def test_replaybuffer(size=10, bufsize=20):
     assert b.info.a[1] == 4 and b.info.b.c[1] == 0
     assert b.info.d.e[1] == -np.inf
     # test batch-style adding method, where len(batch) == 1
-    batch.done = 1
+    batch.done = [1]
     batch.info.e = np.zeros([1, 4])
     batch = Batch.stack([batch])
     ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0])
@@ -79,6 +79,13 @@ def test_replaybuffer(size=10, bufsize=20):
     assert b.info.e.shape == (b.maxsize, 1, 4)
     with pytest.raises(IndexError):
         b[22]
+    # test prev / next
+    assert np.all(b.prev(np.array([0, 1, 2])) == [0, 1, 1])
+    assert np.all(b.next(np.array([0, 1, 2])) == [0, 2, 2])
+    batch.done = [0]
+    b.add(batch, buffer_ids=[0])
+    assert np.all(b.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3])
+    assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3])
 
 
 def test_ignore_obs_next(size=10):
@@ -718,7 +725,6 @@ def test_multibuf_hdf5():
     test_stack()
     test_segtree()
     test_priortized_replaybuffer()
-    test_priortized_replaybuffer(233333, 200000)
     test_update()
     test_pickle()
     test_hdf5()
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 477a2531a..b048c1ead 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -1,6 +1,7 @@
 import h5py
 import torch
 import numpy as np
+from numba import njit
 from typing import Any, Dict, List, Tuple, Union, Sequence, Optional
 
 from tianshou.data.batch import _create_value
@@ -116,6 +117,7 @@ def load_hdf5(cls, path: str, device: Optional[str] = None) -> "ReplayBuffer":
 
     def reset(self) -> None:
         """Clear all the data in replay buffer and episode statistics."""
+        self.last_index = np.array([0])
         self._index = self._size = 0
         self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, 0
 
@@ -137,7 +139,7 @@ def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
         The index won't be modified if it is the beginning of an episode.
         """
         index = (index - 1) % self._size
-        end_flag = self.done[index] | np.isin(index, self.unfinished_index())
+        end_flag = self.done[index] | (index == self.last_index[0])
         return (index + end_flag) % self._size
 
     def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
@@ -145,7 +147,7 @@ def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
 
         The index won't be modified if it is the end of an episode.
         """
-        end_flag = self.done[index] | np.isin(index, self.unfinished_index())
+        end_flag = self.done[index] | (index == self.last_index[0])
         return (index + (1 - end_flag)) % self._size
 
     def update(self, buffer: "ReplayBuffer") -> np.ndarray:
@@ -163,6 +165,7 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray:
         to_indices = []
         for _ in range(len(from_indices)):
             to_indices.append(self._index)
+            self.last_index[0] = self._index
             self._index = (self._index + 1) % self.maxsize
             self._size = min(self._size + 1, self.maxsize)
         to_indices = np.array(to_indices)
@@ -180,7 +183,7 @@ def _add_index(
         Return (index_to_be_modified, episode_reward, episode_length,
         episode_start_index).
         """
-        ptr = self._index
+        self.last_index[0] = ptr = self._index
         self._size = min(self._size + 1, self.maxsize)
         self._index = (self._index + 1) % self.maxsize
 
@@ -296,6 +299,13 @@ def get(
         """Return the stacked result.
 
         E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the index.
+
+        :param index: the index for getting stacked data (t in the example).
+        :param str key: the key to get, should be one of the reserved_keys.
+        :param default_value: if the given key's data is not found and default_value is
+            set, return this default_value.
+        :param int stack_num: the stack num (4 in the example). Default to
+            self.stack_num.
         """
         if key not in self._meta and default_value is not None:
             return default_value
@@ -306,7 +316,10 @@ def get(
             if stack_num == 1:  # the most often case
                 return val[index]
             stack: List[Any] = []
-            indice = np.asarray(index)
+            if isinstance(index, list):
+                indice = np.array(index)
+            else:
+                indice = index
             for _ in range(stack_num):
                 stack = [val[indice]] + stack
                 indice = self.prev(indice)
@@ -453,12 +466,24 @@ def __init__(self, buffer_list: List[ReplayBuffer]) -> None:
             offset.append(size)
             size += buf.maxsize
         self._offset = np.array(offset)
+        self._extend_offset = np.array(offset + [size])
+        self._lengths = np.zeros_like(offset)
         super().__init__(size=size, **kwargs)
+        self._compile()
+
+    def _compile(self) -> None:
+        lens = last = index = np.array([0])
+        offset = np.array([0, 1])
+        done = np.array([False, False])
+        _prev_index(index, offset, done, last, lens)
+        _next_index(index, offset, done, last, lens)
 
     def __len__(self) -> int:
-        return sum([len(buf) for buf in self.buffers])
+        return self._lengths.sum()
 
     def reset(self) -> None:
+        self.last_index = self._offset.copy()
+        self._lengths = np.zeros_like(self._offset)
         for buf in self.buffers:
             buf.reset()
 
@@ -477,22 +502,20 @@ def unfinished_index(self) -> np.ndarray:
         ])
 
     def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
-        index = np.asarray(index) % self.maxsize
-        prev_indices = np.zeros_like(index)
-        for offset, buf in zip(self._offset, self.buffers):
-            mask = (offset <= index) & (index < offset + buf.maxsize)
-            if np.any(mask):
-                prev_indices[mask] = buf.prev(index[mask] - offset) + offset
-        return prev_indices
+        if isinstance(index, (list, np.ndarray)):
+            return _prev_index(np.asarray(index), self._extend_offset,
+                               self.done, self.last_index, self._lengths)
+        else:
+            return _prev_index(np.array([index]), self._extend_offset,
+                               self.done, self.last_index, self._lengths)[0]
 
     def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray:
-        index = np.asarray(index) % self.maxsize
-        next_indices = np.zeros_like(index)
-        for offset, buf in zip(self._offset, self.buffers):
-            mask = (offset <= index) & (index < offset + buf.maxsize)
-            if np.any(mask):
-                next_indices[mask] = buf.next(index[mask] - offset) + offset
-        return next_indices
+        if isinstance(index, (list, np.ndarray)):
+            return _next_index(np.asarray(index), self._extend_offset,
+                               self.done, self.last_index, self._lengths)
+        else:
+            return _next_index(np.array([index]), self._extend_offset,
+                               self.done, self.last_index, self._lengths)[0]
 
     def update(self, buffer: ReplayBuffer) -> np.ndarray:
         """The ReplayBufferManager cannot be updated by any buffer."""
@@ -534,6 +557,8 @@ def add(
             ep_lens.append(ep_len)
             ep_rews.append(ep_rew)
             ep_idxs.append(ep_idx + self._offset[buffer_id])
+            self.last_index[buffer_id] = ptr + self._offset[buffer_id]
+            self._lengths[buffer_id] = len(self.buffers[buffer_id])
         ptrs = np.array(ptrs)
         try:
             self._meta[ptrs] = batch
@@ -564,9 +589,8 @@ def sample_index(self, batch_size: int) -> np.ndarray:
         if batch_size == 0:  # get all available indices
             sample_num = np.zeros(self.buffer_num, np.int)
         else:
-            buffer_lens = np.array([len(buf) for buf in self.buffers])
             buffer_idx = np.random.choice(
-                self.buffer_num, batch_size, p=buffer_lens / buffer_lens.sum()
+                self.buffer_num, batch_size, p=self._lengths / self._lengths.sum()
             )
             sample_num = np.bincount(buffer_idx, minlength=self.buffer_num)
             # avoid batch_size > 0 and sample_num == 0 -> get child's all data
@@ -726,6 +750,51 @@ def add(
             updated_ep_idx.append(index[0])
             updated_ptr.append(index[-1])
             self.buffers[buffer_idx].reset()
+            self._lengths[0] = len(self.main_buffer)
+            self._lengths[buffer_idx] = 0
+            self.last_index[0] = index[-1]
+            self.last_index[buffer_idx] = self._offset[buffer_idx]
         ptr[done] = updated_ptr
         ep_idx[done] = updated_ep_idx
         return ptr, ep_rew, ep_len, ep_idx
+
+
+@njit
+def _prev_index(
+    index: np.ndarray,
+    offset: np.ndarray,
+    done: np.ndarray,
+    last_index: np.ndarray,
+    lengths: np.ndarray,
+) -> np.ndarray:
+    index = index % offset[-1]
+    prev_index = np.zeros_like(index)
+    for start, end, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index):
+        mask = (start <= index) & (index < end)
+        cur_len = max(1, cur_len)
+        if np.sum(mask) > 0:
+            subind = index[mask]
+            subind = (subind - start - 1) % cur_len
+            end_flag = done[subind + start] | (subind + start == last)
+            prev_index[mask] = (subind + end_flag) % cur_len + start
+    return prev_index
+
+
+@njit
+def _next_index(
+    index: np.ndarray,
+    offset: np.ndarray,
+    done: np.ndarray,
+    last_index: np.ndarray,
+    lengths: np.ndarray,
+) -> np.ndarray:
+    index = index % offset[-1]
+    next_index = np.zeros_like(index)
+    for start, end, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index):
+        mask = (start <= index) & (index < end)
+        cur_len = max(1, cur_len)
+        if np.sum(mask) > 0:
+            subind = index[mask]
+            end_flag = done[subind] | (subind == last)
+            next_index[mask] = (subind - start + 1 - end_flag) % cur_len + start
+    return next_index
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 9023bf47b..3b663d630 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -220,12 +220,12 @@ def compute_episodic_return(
         Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438)
         to calculate q function/reward to go of given batch.
 
-        :param Batch batch: a data batch which contains several episodes of data
-            in sequential order. Mind that the end of each finished episode of batch
+        :param Batch batch: a data batch which contains several episodes of data in
+            sequential order. Mind that the end of each finished episode of batch
             should be marked by done flag, unfinished (or collecting) episodes will be
             recongized by buffer.unfinished_index().
-        :param np.ndarray indice: tell batch's location in buffer, batch is
-            equal to buffer[indice].
+        :param numpy.ndarray indice: tell batch's location in buffer, batch is equal to
+            buffer[indice].
         :param np.ndarray v_s_: the value function of all next states :math:`V(s')`.
         :param float gamma: the discount factor, should be in [0, 1]. Default to 0.99.
         :param float gae_lambda: the parameter for Generalized Advantage Estimation,
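For context, here is a minimal usage sketch of the `last_index` bookkeeping that now backs `ReplayBuffer.prev` / `ReplayBuffer.next`, mirroring the assertions added to `test_replaybuffer` above. It assumes the `tianshou.data` API as of this patch; the buffer size and transition values are illustrative only.

```python
import numpy as np

from tianshou.data import Batch, ReplayBuffer

buf = ReplayBuffer(size=10)
# One finished episode (index 0) followed by an unfinished one (indices 1-2).
for i, done in enumerate([True, False, False]):
    buf.add(Batch(obs=i, act=i, rew=float(i), done=done, obs_next=i + 1))

# prev() stops at episode starts; next() stops at done flags or at the most
# recently written index, which buf.last_index now tracks instead of
# recomputing unfinished_index() on every call.
print(buf.prev(np.array([0, 1, 2])))  # -> [0 1 1]
print(buf.next(np.array([0, 1, 2])))  # -> [0 2 2]
```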
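A similar sketch for the numba-backed `_prev_index` / `_next_index` helpers as used through `ReplayBufferManager`: neighbour lookups never cross sub-buffer boundaries. The two-buffer setup and the expected outputs below are worked out from the helper code in this diff rather than copied from the test suite, so treat them as an illustration.

```python
import numpy as np

from tianshou.data import Batch, ReplayBuffer, ReplayBufferManager

# Two sub-buffers of size 4 share one flat storage: global indices 0-3 and 4-7.
buf = ReplayBufferManager([ReplayBuffer(size=4) for _ in range(2)])
batch = Batch(
    obs=np.zeros((2, 1)), act=np.zeros(2), rew=np.zeros(2),
    obs_next=np.zeros((2, 1)), done=np.array([False, True]),
)
buf.add(batch, buffer_ids=[0, 1])
buf.add(batch, buffer_ids=[0, 1])

# Sub-buffer 0 holds one unfinished episode at indices 0-1; sub-buffer 1 holds
# two finished one-step episodes at indices 4 and 5.
print(buf.prev(np.array([0, 1, 4, 5])))  # -> [0 0 4 5]
print(buf.next(np.array([0, 1, 4, 5])))  # -> [1 1 4 5]
```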