diff --git a/tianshou/algorithm/algorithm_base.py b/tianshou/algorithm/algorithm_base.py
index 50884d646..e4dea38ab 100644
--- a/tianshou/algorithm/algorithm_base.py
+++ b/tianshou/algorithm/algorithm_base.py
@@ -2,6 +2,7 @@
 import time
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Mapping
+from copy import copy
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
 
@@ -945,6 +946,50 @@ def update(
         )
 
 
+class OfflineAlgorithmFromOffPolicyAlgorithm(
+    OfflineAlgorithm[TPolicy],
+    Generic[TPolicy],
+    ABC,
+):
+    """Base class for offline algorithms that use the same data preprocessing as an off-policy algorithm.
+
+    Has to be used within a diamond inheritance pattern, as it does not call `super().__init__` in order not to
+    initialize `Algorithm` (and thereby `nn.Module`) twice. The diamond inheritance is used to transform the
+    respective off-policy algorithm into a derived offline variant; see usages of this class in the codebase.
+    """
+
+    # noinspection PyMissingConstructor
+    def __init__(
+        self, *, policy: TPolicy, off_policy_algorithm_class: type[OfflineAlgorithm[TPolicy]]
+    ):
+        self._off_policy_algorithm_class = off_policy_algorithm_class
+
+    @override
+    def process_buffer(self, buffer: TBuffer) -> TBuffer:
+        """Use the off-policy algorithm's batch preprocessing to process the buffer once before training.
+
+        This implementation avoids unnecessary re-computation of the preprocessing.
+        """
+        buffer = copy(buffer)
+        batch, indices = buffer.sample(0)
+        processed_batch = self._off_policy_algorithm_class._preprocess_batch(
+            self, batch, buffer, indices  # type: ignore[arg-type]
+        )
+        buffer_batch = copy(buffer._meta)
+        buffer_batch.update(processed_batch)
+        buffer.set_batch(buffer_batch)
+        return buffer
+
+    @override
+    def _preprocess_batch(
+        self,
+        batch: RolloutBatchProtocol,
+        buffer: ReplayBuffer,
+        indices: np.ndarray,
+    ) -> RolloutBatchProtocol | BatchWithReturnsProtocol:
+        return batch
+
+
 class OnPolicyWrapperAlgorithm(
     OnPolicyAlgorithm[TPolicy],
     Generic[TPolicy],
diff --git a/tianshou/algorithm/imitation/td3_bc.py b/tianshou/algorithm/imitation/td3_bc.py
index 5ccbbe0fb..68a09dc8c 100644
--- a/tianshou/algorithm/imitation/td3_bc.py
+++ b/tianshou/algorithm/imitation/td3_bc.py
@@ -2,7 +2,7 @@
 import torch.nn.functional as F
 
 from tianshou.algorithm import TD3
-from tianshou.algorithm.algorithm_base import OfflineAlgorithm
+from tianshou.algorithm.algorithm_base import OfflineAlgorithmFromOffPolicyAlgorithm
 from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy
 from tianshou.algorithm.modelfree.td3 import TD3TrainingStats
 from tianshou.algorithm.optim import OptimizerFactory
@@ -11,7 +11,7 @@
 
 
 # NOTE: This uses diamond inheritance to convert from off-policy to offline
-class TD3BC(OfflineAlgorithm[ContinuousDeterministicPolicy], TD3):  # type: ignore
+class TD3BC(OfflineAlgorithmFromOffPolicyAlgorithm[ContinuousDeterministicPolicy], TD3):  # type: ignore
     """Implementation of TD3+BC. arXiv:2106.06860."""
 
     def __init__(
@@ -97,6 +97,9 @@ def __init__(
             update_actor_freq=update_actor_freq,
             n_step_return_horizon=n_step_return_horizon,
         )
+        OfflineAlgorithmFromOffPolicyAlgorithm.__init__(
+            self, policy=policy, off_policy_algorithm_class=TD3  # type: ignore[arg-type]
+        )
         self.alpha = alpha
 
     def _update_with_batch(self, batch: RolloutBatchProtocol) -> TD3TrainingStats:
diff --git a/tianshou/data/buffer/buffer_base.py b/tianshou/data/buffer/buffer_base.py
index 72c7af5bb..96a12d359 100644
--- a/tianshou/data/buffer/buffer_base.py
+++ b/tianshou/data/buffer/buffer_base.py
@@ -297,13 +297,8 @@ def reset(self, keep_statistics: bool = False) -> None:
         if not keep_statistics:
             self._ep_return, self._ep_len = 0.0, 0
 
-    # TODO: is this method really necessary? It's kinda dangerous, can accidentally
-    # remove all references to collected data
    def set_batch(self, batch: RolloutBatchProtocol) -> None:
-        """Manually choose the batch you want the ReplayBuffer to manage."""
-        assert len(batch) == self.maxsize and set(batch.get_keys()).issubset(
-            self._reserved_keys,
-        ), "Input batch doesn't meet ReplayBuffer's data form requirement."
+        """Manually choose the batch you want the ReplayBuffer to manage. Use with caution!"""
         self._meta = batch
 
     def unfinished_index(self) -> np.ndarray:
@@ -495,12 +490,10 @@ def add(
     def sample_indices(self, batch_size: int | None) -> np.ndarray:
         """Get a random sample of index with size = batch_size.
 
-        Return all available indices in the buffer if batch_size is 0; return an empty
-        numpy array if batch_size < 0 or no available index can be sampled.
-
-        :param batch_size: the number of indices to be sampled. If None, it will be set
-            to the length of the buffer (i.e. return all available indices in a
-            random order).
+        :param batch_size: the number of indices to sample. Three cases are possible:
+            1. positive int - sample that many random indices
+            2. zero - return all available indices, in the buffer's current order
+            3. None - return all available indices, in random order
         """
         if batch_size is None:
             batch_size = len(self)
@@ -533,8 +526,10 @@ def sample_indices(self, batch_size: int | None) -> np.ndarray:
     def sample(self, batch_size: int | None) -> tuple[RolloutBatchProtocol, np.ndarray]:
         """Get a random sample from buffer with size = batch_size.
 
-        Return all the data in the buffer if batch_size is 0.
-
+        :param batch_size: the number of indices to sample. Three cases are possible:
+            1. positive int - sample that many random indices
+            2. zero - return all available indices, in the buffer's current order
+            3. None - return all available indices, in random order
         :return: Sample data and its corresponding index inside the buffer.
         """
         indices = self.sample_indices(batch_size)
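
Below is a minimal sketch of the revised sampling semantics documented above, assuming the standard `ReplayBuffer`/`Batch` construction API; the dummy transition contents are purely illustrative, and the exact keys and shapes `add` expects may vary between tianshou versions:

    import numpy as np
    from tianshou.data import Batch, ReplayBuffer

    buf = ReplayBuffer(size=5)
    for i in range(3):
        # add one dummy transition; buffer_ids=[0] signals a leading batch dimension of 1
        buf.add(
            Batch(
                obs=np.array([i]),
                act=np.array([0]),
                rew=np.array([1.0]),
                terminated=np.array([False]),
                truncated=np.array([False]),
                obs_next=np.array([i + 1]),
            ),
            buffer_ids=[0],
        )

    buf.sample_indices(2)     # positive int: that many randomly drawn indices
    buf.sample_indices(0)     # zero: all available indices, in the buffer's current order
    buf.sample_indices(None)  # None: all available indices, in random order
    batch, indices = buf.sample(0)  # the full buffer contents plus their indices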