这是indexloc提供的服务,不要输入任何密码
Skip to content

Replay buffer allows stack_num = 1 #165

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 25, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions tianshou/data/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,21 +104,21 @@ class ReplayBuffer:

:param int size: the size of replay buffer.
:param int stack_num: the frame-stack sampling argument, should be greater
than 1, defaults to 0 (no stacking).
than or equal to 1, defaults to 1 (no stacking).
:param bool ignore_obs_next: whether to store obs_next, defaults to
``False``.
:param bool sample_avail: the parameter indicating sampling only available
index when using frame-stack sampling method, defaults to ``False``.
This feature is not supported in Prioritized Replay Buffer currently.
"""

def __init__(self, size: int, stack_num: Optional[int] = 0,
def __init__(self, size: int, stack_num: int = 1,
ignore_obs_next: bool = False,
sample_avail: bool = False, **kwargs) -> None:
super().__init__()
self._maxsize = size
self._stack = stack_num
assert stack_num != 1, 'stack_num should greater than 1'
self._stack = None
self.stack_num = stack_num
self._avail = sample_avail and stack_num > 1
self._avail_index = []
self._save_s_ = not ignore_obs_next
Expand Down Expand Up @@ -157,25 +157,28 @@ def _add_to_buffer(self, name: str, inst: Any) -> None:
value.__dict__[key] = _create_value(inst[key], self._maxsize)
value[self._index] = inst

@property
def stack_num(self):
    """Frame-stack length used when reading ``obs``/``obs_next``.

    A value of 1 means no stacking. A value greater than 1 makes
    ``get`` and ``self[index]`` return stacked ``obs``/``obs_next``
    with shape ``[batch, len, ...]``.
    """
    return self._stack

@stack_num.setter
def stack_num(self, num):
    """Set the frame-stack length; it must be greater than 0."""
    # NOTE: validated with ``assert`` (AssertionError is kept for existing
    # callers), so this check is skipped when Python runs with ``-O``.
    assert num > 0, 'stack_num should be greater than 0'
    self._stack = num

def update(self, buffer: 'ReplayBuffer') -> None:
    """Copy every transition stored in ``buffer`` into ``self``.

    Each stored slot is visited exactly once. The walk starts at
    ``buffer._index % len(buffer)`` and wraps around the ring; the
    number of slots is read once, up front. ``buffer`` is not emptied.
    Only its ``stack_num`` is changed temporarily, and it is always
    restored, even if ``self.add`` raises.

    :param buffer: the source buffer; an empty buffer is a no-op.
    """
    size = len(buffer)
    if size == 0:
        return
    begin = buffer._index % size
    stack_num_orig = buffer.stack_num
    # Read unstacked frames: with stack_num > 1, buffer[i] would return
    # stacked obs/obs_next instead of the single stored transition.
    buffer.stack_num = 1
    try:
        for offset in range(size):
            self.add(**buffer[(begin + offset) % size])
    finally:
        # Restore on every exit path so a failing add() cannot leave the
        # source buffer with a silently changed stack_num.
        buffer.stack_num = stack_num_orig

def add(self,
obs: Union[dict, Batch, np.ndarray],
Expand Down Expand Up @@ -204,15 +207,15 @@ def add(self,
if self._avail:
# update current frame
avail = sum(self.done[i] for i in range(
self._index - self._stack + 1, self._index)) == 0
if self._size < self._stack - 1:
self._index - self.stack_num + 1, self._index)) == 0
if self._size < self.stack_num - 1:
avail = False
if avail and self._index not in self._avail_index:
self._avail_index.append(self._index)
elif not avail and self._index in self._avail_index:
self._avail_index.remove(self._index)
# remove the later available frame because of broken storage
t = (self._index + self._stack - 1) % self._maxsize
t = (self._index + self.stack_num - 1) % self._maxsize
if t in self._avail_index:
self._avail_index.remove(t)

Expand Down Expand Up @@ -255,7 +258,7 @@ def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str,
given from buffer initialization procedure.
"""
if stack_num is None:
stack_num = self._stack
stack_num = self.stack_num
if isinstance(indice, slice):
indice = np.arange(
0 if indice.start is None
Expand All @@ -276,7 +279,7 @@ def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str,
key = 'obs'
val = self._meta.__dict__[key]
try:
if stack_num > 0:
if stack_num > 1:
stack = []
for _ in range(stack_num):
stack = [val[indice]] + stack
Expand All @@ -300,7 +303,7 @@ def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str,

def __getitem__(self, index: Union[
slice, int, np.integer, np.ndarray]) -> Batch:
"""Return a data batch: self[index]. If stack_num is set to be > 0,
"""Return a data batch: self[index]. If stack_num is larger than 1,
return the stacked obs and obs_next with shape [batch, len, ...].
"""
return Batch(
Expand Down