From 680557e8a6df89279d1e9afdff8b516137182ded Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Sat, 20 Feb 2021 21:03:54 +0800 Subject: [PATCH 1/8] remove unfinished_index --- test/base/test_buffer.py | 10 +--------- tianshou/data/buffer.py | 30 ++++++++++++++---------------- tianshou/policy/base.py | 10 +++++----- 3 files changed, 20 insertions(+), 30 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 3f65ad633..6419c1e3a 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -396,7 +396,7 @@ def test_replaybuffermanager(): with pytest.raises(NotImplementedError): # ReplayBufferManager cannot be updated buf.update(buf) - # sample index / prev / next / unfinished_index + # sample index / prev / next indice = buf.sample_index(11000) assert np.bincount(indice)[[0, 5, 10]].min() >= 3000 # uniform sample batch, indice = buf.sample(0) @@ -405,9 +405,7 @@ def test_replaybuffermanager(): assert np.allclose(indice_prev, indice), indice_prev indice_next = buf.next(indice) assert np.allclose(indice_next, indice), indice_next - assert np.allclose(buf.unfinished_index(), [0, 5]) buf.add(Batch(obs=[4], act=[4], rew=[4], done=[1]), buffer_ids=[3]) - assert np.allclose(buf.unfinished_index(), [0, 5]) batch, indice = buf.sample(10) batch, indice = buf.sample(0) assert np.allclose(indice, [0, 5, 10, 15]) @@ -450,12 +448,10 @@ def test_replaybuffermanager(): 10, 12, 12, 14, 14, 15, 17, 17, 19, 19, ]) - assert np.allclose(buf.unfinished_index(), [4, 14]) ptr, ep_rew, ep_len, ep_idx = buf.add( Batch(obs=[1], act=[1], rew=[1], done=[1]), buffer_ids=[2]) assert np.all(ep_len == [3]) and np.all(ep_rew == [1]) assert np.all(ptr == [10]) and np.all(ep_idx == [13]) - assert np.allclose(buf.unfinished_index(), [4]) indice = list(sorted(buf.sample_index(0))) assert np.allclose(indice, np.arange(len(buf))) assert np.allclose(buf.prev(indice), [ @@ -506,7 +502,6 @@ def test_cachedbuffer(): assert np.allclose(buf.obs, obs) assert np.all(ep_len == [1]) and np.all(ep_rew == [2.0]) assert np.all(ptr == [0]) and np.all(ep_idx == [0]) - assert np.allclose(buf.unfinished_index(), [15]) assert np.allclose(buf.sample_index(0), [0, 15]) ptr, ep_rew, ep_len, ep_idx = buf.add( Batch(obs=[3, 4], act=[3, 4], rew=[3, 4], done=[0, 1]), @@ -515,7 +510,6 @@ def test_cachedbuffer(): assert np.all(ptr == [25, 2]) and np.all(ep_idx == [25, 1]) obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3] assert np.allclose(buf.obs, obs) - assert np.allclose(buf.unfinished_index(), [25]) indice = buf.sample_index(0) assert np.allclose(indice, [0, 1, 2, 25]) assert np.allclose(buf.done[indice], [1, 0, 1, 0]) @@ -596,7 +590,6 @@ def test_multibuf_stack(): 0, 0, 0, 1, 0, # cached_buffer[1] 0, 0, 0, 1, 0, # cached_buffer[2] ]), buf4.done - assert np.allclose(buf4.unfinished_index(), [10, 15, 20]) indice = sorted(buf4.sample_index(0)) assert np.allclose(indice, list(range(bufsize)) + [9, 10, 14, 15, 19, 20]) assert np.allclose(buf4[indice].obs[..., 0], [ @@ -718,7 +711,6 @@ def test_multibuf_hdf5(): test_stack() test_segtree() test_priortized_replaybuffer() - test_priortized_replaybuffer(233333, 200000) test_update() test_pickle() test_hdf5() diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index b24e8c6b2..6cc17d872 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -116,7 +116,7 @@ def load_hdf5(cls, path: str, device: Optional[str] = None) -> "ReplayBuffer": def reset(self) -> None: """Clear all the data in replay buffer and episode statistics.""" - self._index = self._size = 0 + self.last_index = self._index = self._size = 0 self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, 0 def set_batch(self, batch: Batch) -> None: @@ -126,18 +126,13 @@ def set_batch(self, batch: Batch) -> None: ), "Input batch doesn't meet ReplayBuffer's data form requirement." self._meta = batch - def unfinished_index(self) -> np.ndarray: - """Return the index of unfinished episode.""" - last = (self._index - 1) % self._size if self._size else 0 - return np.array([last] if not self.done[last] and self._size else [], np.int) - def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: """Return the index of previous transition. The index won't be modified if it is the beginning of an episode. """ index = (index - 1) % self._size - end_flag = self.done[index] | np.isin(index, self.unfinished_index()) + end_flag = self.done[index] | (index == self.last_index) return (index + end_flag) % self._size def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: @@ -145,7 +140,7 @@ def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: The index won't be modified if it is the end of an episode. """ - end_flag = self.done[index] | np.isin(index, self.unfinished_index()) + end_flag = self.done[index] | (index == self.last_index) return (index + (1 - end_flag)) % self._size def update(self, buffer: "ReplayBuffer") -> np.ndarray: @@ -163,6 +158,7 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray: to_indices = [] for _ in range(len(from_indices)): to_indices.append(self._index) + self.last_index = self._index self._index = (self._index + 1) % self.maxsize self._size = min(self._size + 1, self.maxsize) to_indices = np.array(to_indices) @@ -180,7 +176,7 @@ def _add_index( Return (index_to_be_modified, episode_reward, episode_length, episode_start_index). """ - ptr = self._index + self.last_index = ptr = self._index self._size = min(self._size + 1, self.maxsize) self._index = (self._index + 1) % self.maxsize @@ -296,6 +292,13 @@ def get( """Return the stacked result. E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the index. + + :param index: the index for getting stacked data (t in the example). + :param str key: the key to get, should be one of the reserved_keys. + :param default_value: if the given key's data is not found and default_value is + set, return this default_value. + :param int stack_num: the stack num (4 in the example). Default to + self.stack_num. """ if key not in self._meta and default_value is not None: return default_value @@ -459,6 +462,7 @@ def __len__(self) -> int: return sum([len(buf) for buf in self.buffers]) def reset(self) -> None: + self.last_index = np.array([0, *self._offset[:-1]]) for buf in self.buffers: buf.reset() @@ -470,12 +474,6 @@ def set_batch(self, batch: Batch) -> None: super().set_batch(batch) self._set_batch_for_children() - def unfinished_index(self) -> np.ndarray: - return np.concatenate([ - buf.unfinished_index() + offset - for offset, buf in zip(self._offset, self.buffers) - ]) - def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: index = np.asarray(index) % self.maxsize prev_indices = np.zeros_like(index) @@ -534,7 +532,7 @@ def add( ep_lens.append(ep_len) ep_rews.append(ep_rew) ep_idxs.append(ep_idx + self._offset[buffer_id]) - ptrs = np.array(ptrs) + self.last_index = ptrs = np.array(ptrs) try: self._meta[ptrs] = batch except ValueError: diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index be6d8216b..b1f04460a 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -220,10 +220,10 @@ def compute_episodic_return( Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438) to calculate q function/reward to go of given batch. - :param batch: a data batch which contains several episodes of data - in sequential order. Mind that the end of each finished episode of batch + :param batch: a data batch which contains several episodes of data in + sequential order. Mind that the end of each finished episode of batch should be marked by done flag, unfinished (or collecting) episodes will be - recongized by buffer.unfinished_index(). + recongized by buffer.last_index. :type batch: :class:`~tianshou.data.Batch` :param numpy.ndarray indice: tell batch's location in buffer, batch is equal to buffer[indice]. @@ -247,7 +247,7 @@ def compute_episodic_return( v_s_ = to_numpy(v_s_.flatten()) * BasePolicy.value_mask(buffer, indice) end_flag = batch.done.copy() - end_flag[np.isin(indice, buffer.unfinished_index())] = True + end_flag[np.isin(indice, buffer.last_index)] = True returns = _episodic_return(v_s_, rew, end_flag, gamma, gae_lambda) if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2): returns = (returns - returns.mean()) / returns.std() @@ -310,7 +310,7 @@ def compute_nstep_return( target_q = to_numpy(target_q_torch.reshape(bsz, -1)) target_q = target_q * BasePolicy.value_mask(buffer, terminal).reshape(-1, 1) end_flag = buffer.done.copy() - end_flag[buffer.unfinished_index()] = True + end_flag[buffer.last_index] = True target_q = _nstep_return(rew, end_flag, target_q, indices, gamma, n_step, mean, std) From 17b69ed7200b879b8b78721285606623b4446bab Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Sat, 20 Feb 2021 21:40:02 +0800 Subject: [PATCH 2/8] intermidiate result but it is wrong in the edge of buffer --- tianshou/data/buffer.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 6cc17d872..2ce07025c 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -462,7 +462,7 @@ def __len__(self) -> int: return sum([len(buf) for buf in self.buffers]) def reset(self) -> None: - self.last_index = np.array([0, *self._offset[:-1]]) + self.last_index = self._offset.copy() for buf in self.buffers: buf.reset() @@ -475,22 +475,14 @@ def set_batch(self, batch: Batch) -> None: self._set_batch_for_children() def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: - index = np.asarray(index) % self.maxsize - prev_indices = np.zeros_like(index) - for offset, buf in zip(self._offset, self.buffers): - mask = (offset <= index) & (index < offset + buf.maxsize) - if np.any(mask): - prev_indices[mask] = buf.prev(index[mask] - offset) + offset - return prev_indices + index = (np.asarray(index) - 1) % self.maxsize + end_flag = self.done[index] | np.isin(index, self.last_index) + return (index + end_flag) % self.maxsize def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: index = np.asarray(index) % self.maxsize - next_indices = np.zeros_like(index) - for offset, buf in zip(self._offset, self.buffers): - mask = (offset <= index) & (index < offset + buf.maxsize) - if np.any(mask): - next_indices[mask] = buf.next(index[mask] - offset) + offset - return next_indices + end_flag = self.done[index] | np.isin(index, self.last_index) + return (index + (1 - end_flag)) % self.maxsize def update(self, buffer: ReplayBuffer) -> np.ndarray: """The ReplayBufferManager cannot be updated by any buffer.""" From 8df07969eb97081788f079e1eab3a85620344c62 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Sat, 20 Feb 2021 21:50:37 +0800 Subject: [PATCH 3/8] another version uses LUT --- tianshou/data/buffer.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 2ce07025c..6f3750958 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -456,6 +456,11 @@ def __init__(self, buffer_list: List[ReplayBuffer]) -> None: offset.append(size) size += buf.maxsize self._offset = np.array(offset) + extend_offset = np.array(offset + [size]) + self._prev_indices = np.arange(size) - 1 + self._next_indices = np.arange(size) + 1 + self._prev_indices[offset] = extend_offset[1:] - 1 + self._next_indices[extend_offset[1:] - 1] = offset super().__init__(size=size, **kwargs) def __len__(self) -> int: @@ -475,14 +480,16 @@ def set_batch(self, batch: Batch) -> None: self._set_batch_for_children() def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: - index = (np.asarray(index) - 1) % self.maxsize + index = self._prev_indices[index] end_flag = self.done[index] | np.isin(index, self.last_index) - return (index + end_flag) % self.maxsize + index[end_flag] = self._next_indices[index[end_flag]] + return index def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: index = np.asarray(index) % self.maxsize end_flag = self.done[index] | np.isin(index, self.last_index) - return (index + (1 - end_flag)) % self.maxsize + index[~end_flag] = self._next_indices[index[~end_flag]] + return index def update(self, buffer: ReplayBuffer) -> np.ndarray: """The ReplayBufferManager cannot be updated by any buffer.""" From 9378dac7d2977cfd0463a233efe496e7bcd95984 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Sat, 20 Feb 2021 23:09:32 +0800 Subject: [PATCH 4/8] 31% -> 6% --- test/base/test_buffer.py | 13 ++++-- tianshou/data/buffer.py | 87 ++++++++++++++++++++++++++++++++-------- 2 files changed, 80 insertions(+), 20 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 6419c1e3a..e203df5ea 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -67,7 +67,7 @@ def test_replaybuffer(size=10, bufsize=20): assert b.info.a[1] == 4 and b.info.b.c[1] == 0 assert b.info.d.e[1] == -np.inf # test batch-style adding method, where len(batch) == 1 - batch.done = 1 + batch.done = [1] batch.info.e = np.zeros([1, 4]) batch = Batch.stack([batch]) ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0]) @@ -79,6 +79,13 @@ def test_replaybuffer(size=10, bufsize=20): assert b.info.e.shape == (b.maxsize, 1, 4) with pytest.raises(IndexError): b[22] + # test prev / next + assert np.all(b.prev(np.array([0, 1, 2])) == [0, 1, 1]) + assert np.all(b.next(np.array([0, 1, 2])) == [0, 2, 2]) + batch.done = [0] + b.add(batch, buffer_ids=[0]) + assert np.all(b.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3]) + assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3]) def test_ignore_obs_next(size=10): @@ -467,8 +474,8 @@ def test_replaybuffermanager(): 15, 17, 17, 19, 19, ]) # corner case: list, int and -1 - assert buf.prev(-1) == buf.prev([buf.maxsize - 1])[0] - assert buf.next(-1) == buf.next([buf.maxsize - 1])[0] + assert buf.prev(-1) == buf.prev([buf.maxsize - 1]) + assert buf.next(-1) == buf.next([buf.maxsize - 1]) batch = buf._meta batch.info = np.ones(buf.maxsize) buf.set_batch(batch) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 6f3750958..474ac5584 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -1,6 +1,7 @@ import h5py import torch import numpy as np +from numba import njit from typing import Any, Dict, List, Tuple, Union, Sequence, Optional from tianshou.data.batch import _create_value @@ -456,18 +457,24 @@ def __init__(self, buffer_list: List[ReplayBuffer]) -> None: offset.append(size) size += buf.maxsize self._offset = np.array(offset) - extend_offset = np.array(offset + [size]) - self._prev_indices = np.arange(size) - 1 - self._next_indices = np.arange(size) + 1 - self._prev_indices[offset] = extend_offset[1:] - 1 - self._next_indices[extend_offset[1:] - 1] = offset + self._extend_offset = np.array(offset + [size]) + self._lengths = np.zeros_like(offset) super().__init__(size=size, **kwargs) + self._compile() + + def _compile(self) -> None: + lens = last = index = np.array([0]) + offset = np.array([0, 10]) + done = np.array([False, False]) + _prev_index(index, offset, done, last, lens) + _next_index(index, offset, done, last, lens) def __len__(self) -> int: - return sum([len(buf) for buf in self.buffers]) + return self._lengths.sum() def reset(self) -> None: self.last_index = self._offset.copy() + self._lengths = np.zeros_like(self._offset) for buf in self.buffers: buf.reset() @@ -480,16 +487,18 @@ def set_batch(self, batch: Batch) -> None: self._set_batch_for_children() def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: - index = self._prev_indices[index] - end_flag = self.done[index] | np.isin(index, self.last_index) - index[end_flag] = self._next_indices[index[end_flag]] - return index + if np.isscalar(index): + return _prev_index(np.array([index]), self._extend_offset, + self.done, self.last_index, self._lengths)[0] + return _prev_index(np.asarray(index), self._extend_offset, + self.done, self.last_index, self._lengths) def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: - index = np.asarray(index) % self.maxsize - end_flag = self.done[index] | np.isin(index, self.last_index) - index[~end_flag] = self._next_indices[index[~end_flag]] - return index + if np.isscalar(index): + return _next_index(np.array([index]), self._extend_offset, + self.done, self.last_index, self._lengths)[0] + return _next_index(np.asarray(index), self._extend_offset, + self.done, self.last_index, self._lengths) def update(self, buffer: ReplayBuffer) -> np.ndarray: """The ReplayBufferManager cannot be updated by any buffer.""" @@ -531,7 +540,9 @@ def add( ep_lens.append(ep_len) ep_rews.append(ep_rew) ep_idxs.append(ep_idx + self._offset[buffer_id]) - self.last_index = ptrs = np.array(ptrs) + self.last_index[buffer_id] = ptr + self._offset[buffer_id] + self._lengths[buffer_id] = len(self.buffers[buffer_id]) + ptrs = np.array(ptrs) try: self._meta[ptrs] = batch except ValueError: @@ -561,9 +572,8 @@ def sample_index(self, batch_size: int) -> np.ndarray: if batch_size == 0: # get all available indices sample_num = np.zeros(self.buffer_num, np.int) else: - buffer_lens = np.array([len(buf) for buf in self.buffers]) buffer_idx = np.random.choice( - self.buffer_num, batch_size, p=buffer_lens / buffer_lens.sum() + self.buffer_num, batch_size, p=self._lengths / self._lengths.sum() ) sample_num = np.bincount(buffer_idx, minlength=self.buffer_num) # avoid batch_size > 0 and sample_num == 0 -> get child's all data @@ -723,6 +733,49 @@ def add( updated_ep_idx.append(index[0]) updated_ptr.append(index[-1]) self.buffers[buffer_idx].reset() + self._lengths[0] = len(self.main_buffer) + self._lengths[buffer_idx] = 0 + self.last_index[0] = index[-1] + self.last_index[buffer_idx] = self._offset[buffer_idx] ptr[done] = updated_ptr ep_idx[done] = updated_ep_idx return ptr, ep_rew, ep_len, ep_idx + + +@njit +def _prev_index( + index: np.ndarray, + offset: np.ndarray, + done: np.ndarray, + last_index: np.ndarray, + lengths: np.ndarray, +) -> np.ndarray: + index = index % offset[-1] + prev_index = np.zeros_like(index) + for left, right, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index): + mask = (left <= index) & (index < right) + subind = index[mask] + if len(subind) > 0: + subind = (subind - left - 1) % cur_len + end_flag = done[subind + left] | (subind + left == last) + prev_index[mask] = (subind + end_flag) % cur_len + left + return prev_index + + +@njit +def _next_index( + index: np.ndarray, + offset: np.ndarray, + done: np.ndarray, + last_index: np.ndarray, + lengths: np.ndarray, +) -> np.ndarray: + index = index % offset[-1] + next_index = np.zeros_like(index) + for left, right, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index): + mask = (left <= index) & (index < right) + subind = index[mask] + if len(subind) > 0: + end_flag = done[subind] | (subind == last) + next_index[mask] = (subind - left + 1 - end_flag) % cur_len + left + return next_index From e0640294205a84ec57bc8bfa533b48e953a22644 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Sat, 20 Feb 2021 23:41:03 +0800 Subject: [PATCH 5/8] fix test --- tianshou/data/buffer.py | 55 +++++++++++++++++++++++------------------ 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 474ac5584..b10a8b81e 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -117,7 +117,8 @@ def load_hdf5(cls, path: str, device: Optional[str] = None) -> "ReplayBuffer": def reset(self) -> None: """Clear all the data in replay buffer and episode statistics.""" - self.last_index = self._index = self._size = 0 + self.last_index = np.array([0]) + self._index = self._size = 0 self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, 0 def set_batch(self, batch: Batch) -> None: @@ -133,7 +134,7 @@ def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: The index won't be modified if it is the beginning of an episode. """ index = (index - 1) % self._size - end_flag = self.done[index] | (index == self.last_index) + end_flag = self.done[index] | (index == self.last_index[0]) return (index + end_flag) % self._size def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: @@ -141,7 +142,7 @@ def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: The index won't be modified if it is the end of an episode. """ - end_flag = self.done[index] | (index == self.last_index) + end_flag = self.done[index] | (index == self.last_index[0]) return (index + (1 - end_flag)) % self._size def update(self, buffer: "ReplayBuffer") -> np.ndarray: @@ -159,7 +160,7 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray: to_indices = [] for _ in range(len(from_indices)): to_indices.append(self._index) - self.last_index = self._index + self.last_index[0] = self._index self._index = (self._index + 1) % self.maxsize self._size = min(self._size + 1, self.maxsize) to_indices = np.array(to_indices) @@ -177,7 +178,7 @@ def _add_index( Return (index_to_be_modified, episode_reward, episode_length, episode_start_index). """ - self.last_index = ptr = self._index + self.last_index[0] = ptr = self._index self._size = min(self._size + 1, self.maxsize) self._index = (self._index + 1) % self.maxsize @@ -464,7 +465,7 @@ def __init__(self, buffer_list: List[ReplayBuffer]) -> None: def _compile(self) -> None: lens = last = index = np.array([0]) - offset = np.array([0, 10]) + offset = np.array([0, 1]) done = np.array([False, False]) _prev_index(index, offset, done, last, lens) _next_index(index, offset, done, last, lens) @@ -487,18 +488,22 @@ def set_batch(self, batch: Batch) -> None: self._set_batch_for_children() def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: - if np.isscalar(index): + try: + len(index) # type: ignore + return _prev_index(np.asarray(index), self._extend_offset, + self.done, self.last_index, self._lengths) + except TypeError: # generated by len(index) return _prev_index(np.array([index]), self._extend_offset, self.done, self.last_index, self._lengths)[0] - return _prev_index(np.asarray(index), self._extend_offset, - self.done, self.last_index, self._lengths) def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: - if np.isscalar(index): + try: + len(index) # type: ignore + return _next_index(np.asarray(index), self._extend_offset, + self.done, self.last_index, self._lengths) + except TypeError: # generated by len(index) return _next_index(np.array([index]), self._extend_offset, self.done, self.last_index, self._lengths)[0] - return _next_index(np.asarray(index), self._extend_offset, - self.done, self.last_index, self._lengths) def update(self, buffer: ReplayBuffer) -> np.ndarray: """The ReplayBufferManager cannot be updated by any buffer.""" @@ -752,13 +757,14 @@ def _prev_index( ) -> np.ndarray: index = index % offset[-1] prev_index = np.zeros_like(index) - for left, right, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index): - mask = (left <= index) & (index < right) - subind = index[mask] - if len(subind) > 0: - subind = (subind - left - 1) % cur_len - end_flag = done[subind + left] | (subind + left == last) - prev_index[mask] = (subind + end_flag) % cur_len + left + for start, end, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index): + mask = (start <= index) & (index < end) + cur_len = max(1, cur_len) + if np.sum(mask) > 0: + subind = index[mask] + subind = (subind - start - 1) % cur_len + end_flag = done[subind + start] | (subind + start == last) + prev_index[mask] = (subind + end_flag) % cur_len + start return prev_index @@ -772,10 +778,11 @@ def _next_index( ) -> np.ndarray: index = index % offset[-1] next_index = np.zeros_like(index) - for left, right, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index): - mask = (left <= index) & (index < right) - subind = index[mask] - if len(subind) > 0: + for start, end, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index): + mask = (start <= index) & (index < end) + cur_len = max(1, cur_len) + if np.sum(mask) > 0: + subind = index[mask] end_flag = done[subind] | (subind == last) - next_index[mask] = (subind - left + 1 - end_flag) % cur_len + left + next_index[mask] = (subind - start + 1 - end_flag) % cur_len + start return next_index From 23503734e8ffc14b2fb87b71cf3fff98224d5697 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Sun, 21 Feb 2021 07:58:11 +0800 Subject: [PATCH 6/8] fix dirty fix --- tianshou/data/buffer.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index b10a8b81e..334e91d48 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -311,7 +311,10 @@ def get( if stack_num == 1: # the most often case return val[index] stack: List[Any] = [] - indice = np.asarray(index) + if isinstance(index, list): + indice = np.array(index) + else: + indice = index for _ in range(stack_num): stack = [val[indice]] + stack indice = self.prev(indice) @@ -488,20 +491,18 @@ def set_batch(self, batch: Batch) -> None: self._set_batch_for_children() def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: - try: - len(index) # type: ignore + if isinstance(index, (list, np.ndarray)): return _prev_index(np.asarray(index), self._extend_offset, self.done, self.last_index, self._lengths) - except TypeError: # generated by len(index) + else: return _prev_index(np.array([index]), self._extend_offset, self.done, self.last_index, self._lengths)[0] def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: - try: - len(index) # type: ignore + if isinstance(index, (list, np.ndarray)): return _next_index(np.asarray(index), self._extend_offset, self.done, self.last_index, self._lengths) - except TypeError: # generated by len(index) + else: return _next_index(np.array([index]), self._extend_offset, self.done, self.last_index, self._lengths)[0] From 1dc1a6722204e04bd06435aa3ec0a73dc68290ef Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Mon, 22 Feb 2021 16:54:48 +0800 Subject: [PATCH 7/8] revert unfinished_index() --- test/base/test_buffer.py | 13 ++++++++++--- tianshou/data/buffer.py | 11 +++++++++++ tianshou/policy/base.py | 6 +++--- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index e203df5ea..225375d0a 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -403,7 +403,7 @@ def test_replaybuffermanager(): with pytest.raises(NotImplementedError): # ReplayBufferManager cannot be updated buf.update(buf) - # sample index / prev / next + # sample index / prev / next / unfinished_index indice = buf.sample_index(11000) assert np.bincount(indice)[[0, 5, 10]].min() >= 3000 # uniform sample batch, indice = buf.sample(0) @@ -412,7 +412,9 @@ def test_replaybuffermanager(): assert np.allclose(indice_prev, indice), indice_prev indice_next = buf.next(indice) assert np.allclose(indice_next, indice), indice_next + assert np.allclose(buf.unfinished_index(), [0, 5]) buf.add(Batch(obs=[4], act=[4], rew=[4], done=[1]), buffer_ids=[3]) + assert np.allclose(buf.unfinished_index(), [0, 5]) batch, indice = buf.sample(10) batch, indice = buf.sample(0) assert np.allclose(indice, [0, 5, 10, 15]) @@ -455,10 +457,12 @@ def test_replaybuffermanager(): 10, 12, 12, 14, 14, 15, 17, 17, 19, 19, ]) + assert np.allclose(buf.unfinished_index(), [4, 14]) ptr, ep_rew, ep_len, ep_idx = buf.add( Batch(obs=[1], act=[1], rew=[1], done=[1]), buffer_ids=[2]) assert np.all(ep_len == [3]) and np.all(ep_rew == [1]) assert np.all(ptr == [10]) and np.all(ep_idx == [13]) + assert np.allclose(buf.unfinished_index(), [4]) indice = list(sorted(buf.sample_index(0))) assert np.allclose(indice, np.arange(len(buf))) assert np.allclose(buf.prev(indice), [ @@ -474,8 +478,8 @@ def test_replaybuffermanager(): 15, 17, 17, 19, 19, ]) # corner case: list, int and -1 - assert buf.prev(-1) == buf.prev([buf.maxsize - 1]) - assert buf.next(-1) == buf.next([buf.maxsize - 1]) + assert buf.prev(-1) == buf.prev([buf.maxsize - 1])[0] + assert buf.next(-1) == buf.next([buf.maxsize - 1])[0] batch = buf._meta batch.info = np.ones(buf.maxsize) buf.set_batch(batch) @@ -509,6 +513,7 @@ def test_cachedbuffer(): assert np.allclose(buf.obs, obs) assert np.all(ep_len == [1]) and np.all(ep_rew == [2.0]) assert np.all(ptr == [0]) and np.all(ep_idx == [0]) + assert np.allclose(buf.unfinished_index(), [15]) assert np.allclose(buf.sample_index(0), [0, 15]) ptr, ep_rew, ep_len, ep_idx = buf.add( Batch(obs=[3, 4], act=[3, 4], rew=[3, 4], done=[0, 1]), @@ -517,6 +522,7 @@ def test_cachedbuffer(): assert np.all(ptr == [25, 2]) and np.all(ep_idx == [25, 1]) obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3] assert np.allclose(buf.obs, obs) + assert np.allclose(buf.unfinished_index(), [25]) indice = buf.sample_index(0) assert np.allclose(indice, [0, 1, 2, 25]) assert np.allclose(buf.done[indice], [1, 0, 1, 0]) @@ -597,6 +603,7 @@ def test_multibuf_stack(): 0, 0, 0, 1, 0, # cached_buffer[1] 0, 0, 0, 1, 0, # cached_buffer[2] ]), buf4.done + assert np.allclose(buf4.unfinished_index(), [10, 15, 20]) indice = sorted(buf4.sample_index(0)) assert np.allclose(indice, list(range(bufsize)) + [9, 10, 14, 15, 19, 20]) assert np.allclose(buf4[indice].obs[..., 0], [ diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 334e91d48..7e4af1d1e 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -128,6 +128,11 @@ def set_batch(self, batch: Batch) -> None: ), "Input batch doesn't meet ReplayBuffer's data form requirement." self._meta = batch + def unfinished_index(self) -> np.ndarray: + """Return the index of unfinished episode.""" + last = (self._index - 1) % self._size if self._size else 0 + return np.array([last] if not self.done[last] and self._size else [], np.int) + def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: """Return the index of previous transition. @@ -490,6 +495,12 @@ def set_batch(self, batch: Batch) -> None: super().set_batch(batch) self._set_batch_for_children() + def unfinished_index(self) -> np.ndarray: + return np.concatenate([ + buf.unfinished_index() + offset + for offset, buf in zip(self._offset, self.buffers) + ]) + def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: if isinstance(index, (list, np.ndarray)): return _prev_index(np.asarray(index), self._extend_offset, diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index b1f04460a..b121090c3 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -223,7 +223,7 @@ def compute_episodic_return( :param batch: a data batch which contains several episodes of data in sequential order. Mind that the end of each finished episode of batch should be marked by done flag, unfinished (or collecting) episodes will be - recongized by buffer.last_index. + recongized by buffer.unfinished_index(). :type batch: :class:`~tianshou.data.Batch` :param numpy.ndarray indice: tell batch's location in buffer, batch is equal to buffer[indice]. @@ -247,7 +247,7 @@ def compute_episodic_return( v_s_ = to_numpy(v_s_.flatten()) * BasePolicy.value_mask(buffer, indice) end_flag = batch.done.copy() - end_flag[np.isin(indice, buffer.last_index)] = True + end_flag[np.isin(indice, buffer.unfinished_index())] = True returns = _episodic_return(v_s_, rew, end_flag, gamma, gae_lambda) if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2): returns = (returns - returns.mean()) / returns.std() @@ -310,7 +310,7 @@ def compute_nstep_return( target_q = to_numpy(target_q_torch.reshape(bsz, -1)) target_q = target_q * BasePolicy.value_mask(buffer, terminal).reshape(-1, 1) end_flag = buffer.done.copy() - end_flag[buffer.last_index] = True + end_flag[buffer.unfinished_index()] = True target_q = _nstep_return(rew, end_flag, target_q, indices, gamma, n_step, mean, std) From e4d681ac0a9e1eed2ccecb0607784c4e35ddfcf9 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Mon, 22 Feb 2021 16:59:50 +0800 Subject: [PATCH 8/8] small update --- tianshou/policy/base.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 1386e7012..3b663d630 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -220,13 +220,12 @@ def compute_episodic_return( Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438) to calculate q function/reward to go of given batch. - :param batch: a data batch which contains several episodes of data in + :param Batch batch: a data batch which contains several episodes of data in sequential order. Mind that the end of each finished episode of batch should be marked by done flag, unfinished (or collecting) episodes will be recongized by buffer.unfinished_index(). - :type batch: :class:`~tianshou.data.Batch` - :param numpy.ndarray indice: tell batch's location in buffer, batch is - equal to buffer[indice]. + :param numpy.ndarray indice: tell batch's location in buffer, batch is equal to + buffer[indice]. :param np.ndarray v_s_: the value function of all next states :math:`V(s')`. :param float gamma: the discount factor, should be in [0, 1]. Default to 0.99. :param float gae_lambda: the parameter for Generalized Advantage Estimation,