diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py
index 381581ce9..05595af49 100644
--- a/tianshou/data/buffer/base.py
+++ b/tianshou/data/buffer/base.py
@@ -96,7 +96,7 @@ def load_hdf5(cls, path: str, device: Optional[str] = None) -> "ReplayBuffer":
         """Load replay buffer from HDF5 file."""
         with h5py.File(path, "r") as f:
             buf = cls.__new__(cls)
-            buf.__setstate__(from_hdf5(f, device=device))
+            buf.__setstate__(from_hdf5(f, device=device))  # type: ignore
         return buf
 
     def reset(self, keep_statistics: bool = False) -> None:
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 5e5006dee..82545a041 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -46,6 +46,10 @@ class Collector(object):
 
     Please make sure the given environment has a time limitation if using n_episode
     collect option.
+
+    .. note::
+        In past versions of Tianshou, the replay buffer that was passed to `__init__`
+        was automatically reset. This is not done in the current implementation.
     """
 
     def __init__(
@@ -68,7 +72,7 @@ def __init__(
         self.preprocess_fn = preprocess_fn
         self._action_space = env.action_space
         # avoid creating attribute outside __init__
-        self.reset()
+        self.reset(False)
 
     def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None:
         """Check if the buffer matches the constraint."""
@@ -94,15 +98,20 @@ def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None:
             )
         self.buffer = buffer
 
-    def reset(self) -> None:
-        """Reset all related variables in the collector."""
+    def reset(self, reset_buffer: bool = True) -> None:
+        """Reset the environment, statistics, current data and possibly replay memory.
+
+        :param bool reset_buffer: if true, reset the replay buffer that is attached
+            to the collector.
+        """
         # use empty Batch for "state" so that self.data supports slicing
         # convert empty Batch to None when passing data to policy
         self.data = Batch(
             obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={}
         )
         self.reset_env()
-        self.reset_buffer()
+        if reset_buffer:
+            self.reset_buffer()
         self.reset_stat()
 
     def reset_stat(self) -> None:
diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py
index 761540502..38effe3f2 100644
--- a/tianshou/utils/net/continuous.py
+++ b/tianshou/utils/net/continuous.py
@@ -48,7 +48,12 @@ def __init__(
         self.preprocess = preprocess_net
         self.output_dim = int(np.prod(action_shape))
         input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.last = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
+        self.last = MLP(
+            input_dim,  # type: ignore
+            self.output_dim,
+            hidden_sizes,
+            device=self.device
+        )
         self._max = max_action
 
     def forward(
@@ -96,7 +101,12 @@ def __init__(
         self.preprocess = preprocess_net
         self.output_dim = 1
         input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.last = MLP(input_dim, 1, hidden_sizes, device=self.device)
+        self.last = MLP(
+            input_dim,  # type: ignore
+            1,
+            hidden_sizes,
+            device=self.device
+        )
 
     def forward(
         self,
@@ -165,11 +175,19 @@ def __init__(
         self.device = device
         self.output_dim = int(np.prod(action_shape))
         input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
+        self.mu = MLP(
+            input_dim,  # type: ignore
+            self.output_dim,
+            hidden_sizes,
+            device=self.device
+        )
         self._c_sigma = conditioned_sigma
         if conditioned_sigma:
             self.sigma = MLP(
-                input_dim, self.output_dim, hidden_sizes, device=self.device
+                input_dim,  # type: ignore
+                self.output_dim,
+                hidden_sizes,
+                device=self.device
             )
         else:
             self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))
diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py
index bcc6531e3..844691c09 100644
--- a/tianshou/utils/net/discrete.py
+++ b/tianshou/utils/net/discrete.py
@@ -49,7 +49,12 @@ def __init__(
         self.preprocess = preprocess_net
         self.output_dim = int(np.prod(action_shape))
         input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.last = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
+        self.last = MLP(
+            input_dim,  # type: ignore
+            self.output_dim,
+            hidden_sizes,
+            device=self.device
+        )
         self.softmax_output = softmax_output
 
     def forward(
@@ -101,7 +106,12 @@ def __init__(
         self.preprocess = preprocess_net
         self.output_dim = last_size
         input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)
+        self.last = MLP(
+            input_dim,  # type: ignore
+            last_size,
+            hidden_sizes,
+            device=self.device
+        )
 
     def forward(
         self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any
@@ -183,8 +193,10 @@ def __init__(
         self.input_dim = getattr(
             preprocess_net, "output_dim", preprocess_net_output_dim
         )
-        self.embed_model = CosineEmbeddingNetwork(num_cosines,
-                                                  self.input_dim).to(device)
+        self.embed_model = CosineEmbeddingNetwork(
+            num_cosines,
+            self.input_dim  # type: ignore
+        ).to(device)
 
     def forward(  # type: ignore
         self, s: Union[np.ndarray, torch.Tensor], sample_size: int, **kwargs: Any
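
Reviewer note: the only behavioral change in this patch is that `Collector.__init__` now calls `self.reset(False)`, so a pre-filled buffer handed to the constructor (e.g. one restored via `ReplayBuffer.load_hdf5`) is no longer wiped; the remaining hunks merely reformat the `MLP`/`CosineEmbeddingNetwork` calls and add `# type: ignore` markers for mypy. Below is a minimal sketch of the new contract, assuming tianshou 0.4.x with the classic gym API; the `CartPole-v0` env and the `RandomPolicy` stub are illustrative assumptions, not part of this diff:

```python
import gym
import numpy as np

from tianshou.data import Batch, Collector, ReplayBuffer
from tianshou.policy import BasePolicy


class RandomPolicy(BasePolicy):
    """Illustrative stub (not in tianshou): uniform random discrete actions."""

    def forward(self, batch, state=None, **kwargs):
        # one random CartPole action (0 or 1) per environment in the batch
        return Batch(act=np.random.randint(0, 2, size=len(batch.obs)))

    def learn(self, batch, **kwargs):
        return {}


env = gym.make("CartPole-v0")
buf = ReplayBuffer(size=100)  # could also come from ReplayBuffer.load_hdf5(...)
Collector(RandomPolicy(), env, buf).collect(n_step=10)

# New behavior: constructing another collector around the same buffer
# preserves its contents (past versions reset it inside __init__).
collector = Collector(RandomPolicy(), env, buf)
assert len(buf) > 0

collector.reset(reset_buffer=False)  # resets env/stats/data, keeps the buffer
assert len(buf) > 0
collector.reset()                    # default reset_buffer=True still clears it
assert len(buf) == 0
```

Code that relied on the old implicit reset must now call `collector.reset()` (or `buf.reset()`) explicitly.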