27 changes: 25 additions & 2 deletions test/base/test_collector.py
@@ -105,7 +105,7 @@ def test_collector(gym_reset_kwargs) -> None:
policy,
env,
ReplayBuffer(size=100),
logger.preprocess_fn,
preprocess_fn=logger.preprocess_fn,
)
c0.collect(n_step=3, gym_reset_kwargs=gym_reset_kwargs)
assert len(c0.buffer) == 3
@@ -207,7 +207,7 @@ def test_collector(gym_reset_kwargs) -> None:
Collector(policy, dum, ReplayBuffer(10))
with pytest.raises(TypeError):
Collector(policy, dum, PrioritizedReplayBuffer(10, 0.5, 0.5))
with pytest.raises(TypeError):
with pytest.raises(ValueError):
c2.collect()
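For context, a minimal sketch of the behavior this edited assertion encodes (MyPolicy and MyTestEnv are the helpers already defined in this test module): calling collect() with neither n_step nor n_episode is a usage error, and with this PR it surfaces as a ValueError rather than a TypeError.

    import pytest
    from tianshou.data import Collector, ReplayBuffer

    policy = MyPolicy()
    env = MyTestEnv(size=5)
    c = Collector(policy, env, ReplayBuffer(size=100))
    with pytest.raises(ValueError):
        c.collect()  # neither n_step nor n_episode was given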

# test NXEnv
@@ -792,6 +792,28 @@ def test_collector_envpool_gym_reset_return_info() -> None:
assert np.allclose(c0.buffer.info["env_id"], env_ids)


def test_collector_with_vector_env() -> None:
writer = SummaryWriter("log/collector")
logger = Logger(writer)
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [1, 8, 9, 10]]

dum = DummyVectorEnv(env_fns)
policy = MyPolicy()

c2 = Collector(
policy,
dum,
VectorReplayBuffer(total_size=100, buffer_num=4),
preprocess_fn=logger.preprocess_fn,
)

c2r = c2.collect(n_episode=10, gym_reset_kwargs=None)
assert np.array_equal(np.array([1, 1, 1, 1, 1, 1, 1, 8, 9, 10]), c2r.lens)

c3r = c2.collect(n_episode=12, sample_equal_num_episodes_per_worker=True, gym_reset_kwargs=None)
assert np.array_equal(np.array([1, 8, 9, 10, 1, 8, 9, 10, 1, 8, 9, 10]), c3r.lens)
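Reading the two assertions together (with MyTestEnv episodes lasting as many steps as their size argument): plain n_episode=10 lets the fastest worker finish over and over, yielding seven length-1 episodes, whereas sample_equal_num_episodes_per_worker=True appears to force 12 // 4 = 3 episodes from each of the four workers, so the lens pattern [1, 8, 9, 10] repeats three times.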


if __name__ == "__main__":
test_collector(gym_reset_kwargs=None)
test_collector(gym_reset_kwargs={})
@@ -801,3 +823,4 @@ def test_collector_envpool_gym_reset_return_info() -> None:
test_collector_with_async(gym_reset_kwargs=None)
test_collector_with_async(gym_reset_kwargs={"return_info": True})
test_collector_envpool_gym_reset_return_info()
test_collector_with_vector_env()
3 changes: 1 addition & 2 deletions tianshou/data/batch.py
@@ -385,9 +385,8 @@ def split(
batch if the length of the batch is smaller than "size". Size of -1 means
the whole batch.
:param shuffle: randomly shuffle the entire data batch if it is
True, otherwise remain in the same. Default to True.
True, otherwise keep the original order.
:param merge_last: merge the last batch into the previous one.
Default to False.
"""
...
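As a quick illustration of the parameters documented above, a minimal sketch (the sizes are worked out by hand and assume the current merge_last semantics):

    import numpy as np
    from tianshou.data import Batch

    b = Batch(x=np.arange(5))
    # With size=2 the chunks would be [2, 2, 1]; merge_last=True folds the
    # trailing short chunk into the previous one, giving [2, 3].
    sizes = [len(mb) for mb in b.split(size=2, shuffle=False, merge_last=True)]
    assert sizes == [2, 3]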

76 changes: 43 additions & 33 deletions tianshou/data/buffer/base.py
@@ -16,21 +16,20 @@ class ReplayBuffer:

ReplayBuffer can be considered as a specialized form (or management) of Batch. It
stores all the data in a batch with circular-queue style.

For the example usage of ReplayBuffer, please check out Section Buffer in
For an example of how to use the ReplayBuffer, please refer to section Buffer in
:doc:`/01_tutorials/01_concepts`.

:param size: the maximum size of replay buffer.
:param stack_num: the frame-stack sampling argument, should be greater than or
equal to 1. Default to 1 (no stacking).
:param ignore_obs_next: whether to not store obs_next. Default to False.
:param stack_num: the frame-stack sampling argument. It should be greater than or
equal to 1.
:param ignore_obs_next: whether to not store obs_next.
:param save_only_last_obs: only save the last obs/obs_next when it has a shape
of (timestep, ...) because of temporal stacking. Default to False.
of (timestep, ...) because of temporal stacking.
:param sample_avail: whether to sample only available indices
when using frame-stack sampling method. Default to False.
when using the frame-stack sampling method.
"""

_reserved_keys = (
_RESERVED_KEYS = (
"obs",
"act",
"rew",
@@ -41,7 +40,7 @@ class ReplayBuffer:
"info",
"policy",
)
_input_keys = (
_INPUT_KEYS = (
"obs",
"act",
"rew",
@@ -51,6 +50,7 @@
"info",
"policy",
)
_REQUIRED_KEYS = frozenset({"obs", "act", "rew", "terminated", "truncated", "done"})

def __init__(
self,
@@ -104,7 +104,8 @@ def __setstate__(self, state: dict[str, Any]) -> None:

def __setattr__(self, key: str, value: Any) -> None:
"""Set self.key = value."""
assert key not in self._reserved_keys, f"key '{key}' is reserved and cannot be assigned"
if key in self._RESERVED_KEYS:
raise ValueError(f"key '{key}' is reserved and cannot be assigned")
super().__setattr__(key, value)
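A minimal sketch of the observable effect of this change (assuming the PR branch is installed): assigning to a reserved key now raises ValueError instead of tripping an assert, so the check also survives running under python -O.

    import pytest
    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=10)
    with pytest.raises(ValueError):
        buf.obs = 42  # "obs" is reserved; the guard rejects the assignment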

def save_hdf5(self, path: str, compression: str | None = None) -> None:
Expand Down Expand Up @@ -162,7 +163,7 @@ def reset(self, keep_statistics: bool = False) -> None:
def set_batch(self, batch: RolloutBatchProtocol) -> None:
"""Manually choose the batch you want the ReplayBuffer to manage."""
assert len(batch) == self.maxsize and set(batch.keys()).issubset(
self._reserved_keys,
self._RESERVED_KEYS,
), "Input batch doesn't meet ReplayBuffer's data form requirement."
self._meta = batch

@@ -181,17 +182,17 @@ def prev(self, index: int | np.ndarray) -> np.ndarray:
return (index + end_flag) % self._size

def next(self, index: int | np.ndarray) -> np.ndarray:
"""Return the index of next transition.
"""Return the index of the next transition.

The index won't be modified if it is the end of an episode.
"""
end_flag = self.done[index] | (index == self.last_index[0])
return (index + (1 - end_flag)) % self._size

def update(self, buffer: "ReplayBuffer") -> np.ndarray:
"""Move the data from the given buffer to current buffer.
"""Move the data from the given buffer to the current buffer.

Return the updated indices. If update fails, return an empty array.
Return the updated indices. If the update fails, return an empty array.
"""
if len(buffer) == 0 or self.maxsize == 0:
return np.array([], int)
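A small sketch of the next() contract described in the docstring (a single three-step episode, worked out by hand):

    import numpy as np
    from tianshou.data import Batch, ReplayBuffer

    buf = ReplayBuffer(size=6)
    for t in range(3):
        buf.add(Batch(obs=t, act=0, rew=0.0, terminated=t == 2,
                      truncated=False, obs_next=t + 1, info={}))
    # Index 2 ends the episode, so next() leaves it unchanged;
    # index 1 advances normally to 2.
    assert buf.next(np.array([1, 2])).tolist() == [2, 2]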
@@ -212,41 +213,43 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray:
self._meta[to_indices] = buffer._meta[from_indices]
return to_indices

def _add_index(
def _update_buffer_state_after_adding_batch(
self,
rew: float | np.ndarray,
done: bool,
) -> tuple[int, float | np.ndarray, int, int]:
"""Maintain the buffer's state after adding one data batch.

Return (index_to_be_modified, episode_reward, episode_length,
Return (index_to_add_at, episode_reward, episode_length,
episode_start_index).
"""
self.last_index[0] = ptr = self._index
self.last_index[0] = index_to_add_at = self._index
self._size = min(self._size + 1, self.maxsize)
self._index = (self._index + 1) % self.maxsize

self._ep_rew += rew
self._ep_len += 1

if done:
result = ptr, self._ep_rew, self._ep_len, self._ep_idx
result = index_to_add_at, self._ep_rew, self._ep_len, self._ep_idx
self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, self._index
return result
return ptr, self._ep_rew * 0.0, 0, self._ep_idx
return index_to_add_at, self._ep_rew * 0.0, 0, self._ep_idx

def add(
self,
batch: RolloutBatchProtocol,
buffer_ids: np.ndarray | list[int] | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Add a batch of data into replay buffer.
"""Add a batch of data into the replay buffer.

:param batch: the input data batch. "obs", "act", "rew",
"terminated", "truncated" are required keys.
:param buffer_ids: kept for consistency with other buffers' add methods; if it
is not None, the input batch's first dimension is assumed to be 1.

Note: episode_start_index is the index of the first transition in the episode.

Return (current_index, episode_reward, episode_length, episode_start_index). If
the episode is not finished, the return value of episode_length and
episode_reward is 0.
@@ -257,26 +260,32 @@ def add(
new_batch.__dict__[key] = batch[key]
batch = new_batch
batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(
batch.keys(),
) # important to do after preprocess batch
# important to do this after preprocessing the batch
if missing_keys := self._REQUIRED_KEYS.difference(batch.keys()):
raise RuntimeError(
f"The input batch you try to add is missing the keys {missing_keys}.",
)
stacked_batch = buffer_ids is not None
if stacked_batch:
assert len(batch) == 1
if stacked_batch and len(batch) != 1:
raise RuntimeError(
f"len(batch) has to equal 1 when buffer_ids is not None (currently it is {buffer_ids}), but instead it is {len(batch)}.",
)
if self._save_only_last_obs:
batch.obs = batch.obs[:, -1] if stacked_batch else batch.obs[-1]
if not self._save_obs_next:
batch.pop("obs_next", None)
elif self._save_only_last_obs:
batch.obs_next = batch.obs_next[:, -1] if stacked_batch else batch.obs_next[-1]
# get ptr
# get index to add at
if stacked_batch:
rew, done = batch.rew[0], batch.done[0]
else:
rew, done = batch.rew, batch.done
ptr, ep_rew, ep_len, ep_idx = (np.array([x]) for x in self._add_index(rew, done))
ep_add_at_idx, ep_rew, ep_len, ep_start_idx = (
np.array([x]) for x in self._update_buffer_state_after_adding_batch(rew, done)
)
try:
self._meta[ptr] = batch
self._meta[ep_add_at_idx] = batch
except ValueError:
stack = not stacked_batch
batch.rew = batch.rew.astype(float)
Expand All @@ -287,8 +296,8 @@ def add(
self._meta = create_value(batch, self.maxsize, stack) # type: ignore
else: # dynamic key pops up in batch
alloc_by_keys_diff(self._meta, batch, self.maxsize, stack)
self._meta[ptr] = batch
return ptr, ep_rew, ep_len, ep_idx
self._meta[ep_add_at_idx] = batch
return ep_add_at_idx, ep_rew, ep_len, ep_start_idx
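A hedged sketch of the return semantics documented above (episode reward and length stay zero until the episode finishes); the field names follow the required keys checked in this diff:

    import numpy as np
    from tianshou.data import Batch, ReplayBuffer

    buf = ReplayBuffer(size=20)
    for t in range(3):
        idx, ep_rew, ep_len, ep_start = buf.add(
            Batch(obs=t, act=0, rew=1.0, terminated=t == 2,
                  truncated=False, obs_next=t + 1, info={}),
        )
    # Only the final add reports the finished episode:
    # idx == [2], ep_rew == [3.0], ep_len == [3], ep_start == [0].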

def sample_indices(self, batch_size: int | None) -> np.ndarray:
"""Get a random sample of index with size = batch_size.
@@ -349,10 +358,11 @@ def get(
stacked result as ``[obs[t-3], obs[t-2], obs[t-1], obs[t]]``.

:param index: the index for getting stacked data.
:param str key: the key to get, should be one of the reserved_keys.
:param str key: the key to get. Should be one of the reserved_keys.
:param default_value: if the given key's data is not found and default_value is
set, return this default_value.
:param stack_num: Default to self.stack_num.
:param stack_num: number of objects to stack. It should be greater than or
equal to 1.
"""
if key not in self._meta and default_value is not None:
return default_value
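A short sketch of the frame-stack behavior described in the docstring (a buffer with stack_num=4 and one five-step episode; the expected output is worked out by hand):

    import numpy as np
    from tianshou.data import Batch, ReplayBuffer

    buf = ReplayBuffer(size=8, stack_num=4)
    for t in range(5):
        buf.add(Batch(obs=t, act=0, rew=0.0, terminated=t == 4,
                      truncated=False, obs_next=t + 1, info={}))
    # Index 4 stacks the four most recent frames of its episode:
    assert buf.get(np.array([4]), "obs").tolist() == [[1, 2, 3, 4]]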
@@ -415,6 +425,6 @@ def __getitem__(self, index: slice | int | list[int] | np.ndarray) -> RolloutBatchProtocol:
"policy": self.get(indices, "policy", Batch()),
}
for key in self._meta.__dict__:
if key not in self._input_keys:
if key not in self._INPUT_KEYS:
batch_dict[key] = self._meta[key][indices]
return cast(RolloutBatchProtocol, Batch(batch_dict))
46 changes: 27 additions & 19 deletions tianshou/data/buffer/manager.py
@@ -15,9 +15,9 @@ class ReplayBufferManager(ReplayBuffer):
These replay buffers have contiguous memory layout, and the storage space each
buffer has is a shallow copy of the topmost memory.

:param buffer_list: a list of ReplayBuffer needed to be handled.
:param buffer_list: a list of ReplayBuffer objects needed to be handled.

.. seealso::

Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
"""
@@ -118,20 +118,25 @@ def add(
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Add a batch of data into ReplayBufferManager.

Each of the data's length (first dimension) must equal to the length of
buffer_ids. By default buffer_ids is [0, 1, ..., buffer_num - 1].
Each of the data's lengths (first dimension) must be equal to the length of
buffer_ids. By default, buffer_ids is [0, 1, ..., buffer_num - 1].

Return (current_index, episode_reward, episode_length, episode_start_index). If
the episode is not finished, the return value of episode_length and
episode_reward is 0.
"""
# todo heavy code duplication with ReplayBuffer in buffer/base.py
# preprocess batch
new_batch = Batch()
for key in set(self._reserved_keys).intersection(batch.keys()):
for key in set(self._RESERVED_KEYS).intersection(batch.keys()):
new_batch.__dict__[key] = batch[key]
batch = new_batch
batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(batch.keys())
if missing_keys := self._REQUIRED_KEYS.difference(batch.keys()):
raise RuntimeError(
f"The input batch you try to add is missing the keys {missing_keys}.",
)

if self._save_only_last_obs:
batch.obs = batch.obs[:, -1]
if not self._save_obs_next:
@@ -141,21 +146,24 @@
# get index
if buffer_ids is None:
buffer_ids = np.arange(self.buffer_num)
ptrs, ep_lens, ep_rews, ep_idxs = [], [], [], []

ep_add_at_idxs, ep_lens, ep_rews, ep_start_idxs = [], [], [], []
for batch_idx, buffer_id in enumerate(buffer_ids):
ptr, ep_rew, ep_len, ep_idx = self.buffers[buffer_id]._add_index(
ep_add_at_idx, ep_rew, ep_len, ep_start_idx = self.buffers[
buffer_id
]._update_buffer_state_after_adding_batch(
batch.rew[batch_idx],
batch.done[batch_idx],
)
ptrs.append(ptr + self._offset[buffer_id])
ep_add_at_idxs.append(ep_add_at_idx + self._offset[buffer_id])
ep_lens.append(ep_len)
ep_rews.append(ep_rew)
ep_idxs.append(ep_idx + self._offset[buffer_id])
self.last_index[buffer_id] = ptr + self._offset[buffer_id]
ep_start_idxs.append(ep_start_idx + self._offset[buffer_id])
self.last_index[buffer_id] = ep_add_at_idx + self._offset[buffer_id]
self._lengths[buffer_id] = len(self.buffers[buffer_id])
ptrs = np.array(ptrs)
ep_add_at_idxs = np.array(ep_add_at_idxs)
[Review thread on ep_add_at_idxs]

Collaborator: What is the semantics of this variable? The docstring calls it the current index. If it's not the ep_last_idxs, what does it mean?

Contributor (author): The index at which the transition is added to the buffer. As there is ep_start_idx that indicates the first transition of the current episode, ep_last_idx would be the index of the last transition in the episode. Whenever the current transition does not contain done, this is not the last index of the episode (as it continues) but the index at which to add the current transition.

Collaborator: Marking for discussion in pair programming.

try:
self._meta[ptrs] = batch
self._meta[ep_add_at_idxs] = batch
except ValueError:
batch.rew = batch.rew.astype(float)
batch.done = batch.done.astype(bool)
@@ -166,8 +174,8 @@
else: # dynamic key pops up in batch
alloc_by_keys_diff(self._meta, batch, self.maxsize, False)
self._set_batch_for_children()
self._meta[ptrs] = batch
return ptrs, np.array(ep_rews), np.array(ep_lens), np.array(ep_idxs)
self._meta[ep_add_at_idxs] = batch
return ep_add_at_idxs, np.array(ep_rews), np.array(ep_lens), np.array(ep_start_idxs)
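A hedged sketch of the offset bookkeeping above (two sub-buffers of size 4, so the second buffer's slots start at global index 4):

    import numpy as np
    from tianshou.data import Batch, VectorReplayBuffer

    vbuf = VectorReplayBuffer(total_size=8, buffer_num=2)
    batch = Batch(
        obs=np.zeros((2, 1)), act=np.zeros(2), rew=np.ones(2),
        terminated=np.array([False, True]), truncated=np.array([False, False]),
        obs_next=np.zeros((2, 1)),
    )
    idxs, ep_rews, ep_lens, ep_starts = vbuf.add(batch)  # buffer_ids -> [0, 1]
    # The first write in each sub-buffer lands at that buffer's offset.
    assert np.array_equal(idxs, np.array([0, 4]))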

def sample_indices(self, batch_size: int | None) -> np.ndarray:
# TODO: simplify this code
@@ -212,9 +220,9 @@ class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManager):
These replay buffers have contiguous memory layout, and the storage space each
buffer has is a shallow copy of the topmost memory.

:param buffer_list: a list of PrioritizedReplayBuffer needed to be handled.
:param buffer_list: a list of PrioritizedReplayBuffer objects needed to be handled.

.. seealso::

Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
"""
Expand All @@ -233,9 +241,9 @@ class HERReplayBufferManager(ReplayBufferManager):
These replay buffers have contiguous memory layout, and the storage space each
buffer has is a shallow copy of the topmost memory.

:param buffer_list: a list of HERReplayBuffer needed to be handled.
:param buffer_list: a list of HERReplayBuffer objects needed to be handled.

.. seealso::

Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
"""