From 8f2780b48596086a41e2f10400160f5c45279bfe Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Sat, 25 Jul 2020 12:13:33 +0200 Subject: [PATCH 1/2] stack_num starts at 1 (for no stacking) instead of 0. --- tianshou/data/buffer.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index b7ddcff38..10f1e1e10 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -104,7 +104,7 @@ class ReplayBuffer: :param int size: the size of replay buffer. :param int stack_num: the frame-stack sampling argument, should be greater - than 1, defaults to 0 (no stacking). + than or equal to 1, defaults to 1 (no stacking). :param bool ignore_obs_next: whether to store obs_next, defaults to ``False``. :param bool sample_avail: the parameter indicating sampling only available @@ -112,19 +112,19 @@ class ReplayBuffer: This feature is not supported in Prioritized Replay Buffer currently. """ - def __init__(self, size: int, stack_num: Optional[int] = 0, + def __init__(self, size: int, stack_num: int = 1, ignore_obs_next: bool = False, sample_avail: bool = False, **kwargs) -> None: super().__init__() self._maxsize = size - self._stack = stack_num - assert stack_num != 1, 'stack_num should greater than 1' + self._stack_num = None self._avail = sample_avail and stack_num > 1 self._avail_index = [] self._save_s_ = not ignore_obs_next self._index = 0 self._size = 0 self._meta = Batch() + self._set_stack_num(stack_num) self.reset() def __len__(self) -> int: @@ -161,6 +161,7 @@ def _get_stack_num(self): return self._stack def _set_stack_num(self, num): + assert num > 0, 'stack_num should greater than 0' self._stack = num def update(self, buffer: 'ReplayBuffer') -> None: @@ -169,7 +170,7 @@ def update(self, buffer: 'ReplayBuffer') -> None: return i = begin = buffer._index % len(buffer) origin = buffer._get_stack_num() - buffer._set_stack_num(0) + buffer._set_stack_num(1) while True: 
self.add(**buffer[i]) i = (i + 1) % len(buffer) @@ -276,7 +277,7 @@ def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str, key = 'obs' val = self._meta.__dict__[key] try: - if stack_num > 0: + if stack_num > 1: stack = [] for _ in range(stack_num): stack = [val[indice]] + stack @@ -300,7 +301,7 @@ def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str, def __getitem__(self, index: Union[ slice, int, np.integer, np.ndarray]) -> Batch: - """Return a data batch: self[index]. If stack_num is set to be > 0, + """Return a data batch: self[index]. If stack_num is larger than 1, return the stacked obs and obs_next with shape [batch, len, ...]. """ return Batch( From f35a22ec399af1c9f147e52733c3ee97ad87cc60 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Sat, 25 Jul 2020 12:45:23 +0200 Subject: [PATCH 2/2] Use getter/setter for stack_num. --- tianshou/data/buffer.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 10f1e1e10..4491bee01 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -117,14 +117,14 @@ def __init__(self, size: int, stack_num: int = 1, sample_avail: bool = False, **kwargs) -> None: super().__init__() self._maxsize = size - self._stack_num = None + self._stack = None + self.stack_num = stack_num self._avail = sample_avail and stack_num > 1 self._avail_index = [] self._save_s_ = not ignore_obs_next self._index = 0 self._size = 0 self._meta = Batch() - self._set_stack_num(stack_num) self.reset() def __len__(self) -> int: @@ -157,10 +157,12 @@ def _add_to_buffer(self, name: str, inst: Any) -> None: value.__dict__[key] = _create_value(inst[key], self._maxsize) value[self._index] = inst - def _get_stack_num(self): + @property + def stack_num(self): return self._stack - def _set_stack_num(self, num): + @stack_num.setter + def stack_num(self, num): assert num > 0, 'stack_num should greater than 0'
self._stack = num @@ -169,14 +171,14 @@ def update(self, buffer: 'ReplayBuffer') -> None: if len(buffer) == 0: return i = begin = buffer._index % len(buffer) - origin = buffer._get_stack_num() - buffer._set_stack_num(1) + stack_num_orig = buffer.stack_num + buffer.stack_num = 1 while True: self.add(**buffer[i]) i = (i + 1) % len(buffer) if i == begin: break - buffer._set_stack_num(origin) + buffer.stack_num = stack_num_orig def add(self, obs: Union[dict, Batch, np.ndarray], @@ -205,15 +207,15 @@ def add(self, if self._avail: # update current frame avail = sum(self.done[i] for i in range( - self._index - self._stack + 1, self._index)) == 0 - if self._size < self._stack - 1: + self._index - self.stack_num + 1, self._index)) == 0 + if self._size < self.stack_num - 1: avail = False if avail and self._index not in self._avail_index: self._avail_index.append(self._index) elif not avail and self._index in self._avail_index: self._avail_index.remove(self._index) # remove the later available frame because of broken storage - t = (self._index + self._stack - 1) % self._maxsize + t = (self._index + self.stack_num - 1) % self._maxsize if t in self._avail_index: self._avail_index.remove(t) @@ -256,7 +258,7 @@ def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str, given from buffer initialization procedure. """ if stack_num is None: - stack_num = self._stack + stack_num = self.stack_num if isinstance(indice, slice): indice = np.arange( 0 if indice.start is None