diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py
index 2ee931ec8..54c9f1cf6 100644
--- a/tianshou/data/buffer/base.py
+++ b/tianshou/data/buffer/base.py
@@ -98,11 +98,12 @@ def load_hdf5(cls, path: str, device: Optional[str] = None) -> "ReplayBuffer":
             buf.__setstate__(from_hdf5(f, device=device))
         return buf
 
-    def reset(self) -> None:
+    def reset(self, keep_statistics: bool = False) -> None:
         """Clear all the data in replay buffer and episode statistics."""
         self.last_index = np.array([0])
         self._index = self._size = 0
-        self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, 0
+        if not keep_statistics:
+            self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, 0
 
     def set_batch(self, batch: Batch) -> None:
         """Manually choose the batch you want the ReplayBuffer to manage."""
diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py
index ccd03eb98..fa9db2556 100644
--- a/tianshou/data/buffer/manager.py
+++ b/tianshou/data/buffer/manager.py
@@ -48,11 +48,11 @@ def _compile(self) -> None:
     def __len__(self) -> int:
         return self._lengths.sum()
 
-    def reset(self) -> None:
+    def reset(self, keep_statistics: bool = False) -> None:
         self.last_index = self._offset.copy()
         self._lengths = np.zeros_like(self._offset)
         for buf in self.buffers:
-            buf.reset()
+            buf.reset(keep_statistics=keep_statistics)
 
     def _set_batch_for_children(self) -> None:
         for offset, buf in zip(self._offset, self.buffers):
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 819342863..bf7399080 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -105,9 +105,9 @@ def reset_stat(self) -> None:
         """Reset the statistic variables."""
         self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0
 
-    def reset_buffer(self) -> None:
+    def reset_buffer(self, keep_statistics: bool = False) -> None:
         """Reset the data buffer."""
-        self.buffer.reset()
+        self.buffer.reset(keep_statistics=keep_statistics)
 
     def reset_env(self) -> None:
         """Reset all of the environments."""
diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py
index 2325ce6ee..e396295d0 100644
--- a/tianshou/trainer/onpolicy.py
+++ b/tianshou/trainer/onpolicy.py
@@ -134,7 +134,7 @@ def onpolicy_trainer(
                 losses = policy.update(
                     0, train_collector.buffer, batch_size=batch_size,
                     repeat=repeat_per_collect)
-                train_collector.reset_buffer()
+                train_collector.reset_buffer(keep_statistics=True)
                 step = max([1] + [
                     len(v) for v in losses.values() if isinstance(v, list)])
                 gradient_step += step
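
A minimal usage sketch of the new flag, assuming only the API visible in this diff (ReplayBuffer.reset and Collector.reset_buffer now accept keep_statistics); the buffer size below is an arbitrary placeholder:

    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=1000)

    # ... transitions are added here while an episode is still in progress ...

    # Default: clears the stored transitions AND the running episode statistics
    # (_ep_rew, _ep_len, _ep_idx), matching the behavior before this change.
    buf.reset()

    # New option: clear the stored transitions but keep the running episode
    # statistics, so an episode that spans the reset is still reported with its
    # full reward and length. The on-policy trainer now calls
    # train_collector.reset_buffer(keep_statistics=True) after each policy update.
    buf.reset(keep_statistics=True)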