diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 029220c56..e1c3f2b3c 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,6 +1,6 @@ - [ ] I have added the correct label(s) to this Pull Request or linked the relevant issue(s) - [ ] I have provided a description of the changes in this Pull Request -- [ ] I have added documentation for my changes +- [ ] I have added documentation for my changes and have listed relevant changes in CHANGELOG.md - [ ] If applicable, I have added tests to cover my changes. - [ ] I have reformatted the code using `poe format` - [ ] I have checked style and types with `poe lint` and `poe type-check` diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ce216139..5b542ea11 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,19 @@ - New `evaluation` package for repeating the same experiment with multiple seeds and aggregating the results (important extension!). Launchers for parallelization currently in alpha state. #1074 - Loggers can now restore the logged data into python by using the new `restore_logged_data` method. 
#1074 +- `continuous.Critic`: + - Add flag `apply_preprocess_net_to_obs_only` to allow the + preprocessing network to be applied to the observations only (without + the actions concatenated), which is essential for the case where we want + to reuse the actor's preprocessing network #1128 + +### Fixes +- `CriticFactoryReuseActor`: Enable the Critic flag `apply_preprocess_net_to_obs_only` for continuous critics, + fixing the case where we want to reuse an actor's preprocessing network for the critic (affects usages + of the experiment builder method `with_critic_factory_use_actor` with continuous environments) #1128 +- `atari_network.DQN`: + - Fix constructor input validation #1128 + - Fix `output_dim` not being set if `features_only`=True and `output_dim_added_layer` is not None #1128 ### Internal Improvements - `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063 diff --git a/docs/02_notebooks/L7_Experiment.ipynb b/docs/02_notebooks/L7_Experiment.ipynb index 9a97b20cb..0a6675cb7 100644 --- a/docs/02_notebooks/L7_Experiment.ipynb +++ b/docs/02_notebooks/L7_Experiment.ipynb @@ -152,7 +152,7 @@ "id": "Lh2-hwE5Dn9I" }, "source": [ - "Once we have defined the actor, the critic and the optimizer. We can use them to construct our PPO agent. CartPole is a discrete action space problem, so the distribution of our action space can be a categorical distribution." + "Once we have defined the actor, the critic and the optimizer, we can use them to construct our PPO agent. CartPole is a discrete action space problem, so the distribution of our action space can be a categorical distribution." 
] }, { diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index ea900e975..4f2a5600a 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -66,7 +66,7 @@ def __init__( layer_init: Callable[[nn.Module], nn.Module] = lambda x: x, ) -> None: # TODO: Add docstring - if features_only and output_dim_added_layer is not None: + if not features_only and output_dim_added_layer is not None: raise ValueError( "Should not provide explicit output dimension using `output_dim_added_layer` when `features_only` is true.", ) @@ -98,6 +98,7 @@ def __init__( layer_init(nn.Linear(base_cnn_output_dim, output_dim_added_layer)), nn.ReLU(inplace=True), ) + self.output_dim = output_dim_added_layer else: self.output_dim = base_cnn_output_dim diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py index a34964f5a..b1719be56 100644 --- a/tianshou/data/buffer/base.py +++ b/tianshou/data/buffer/base.py @@ -172,18 +172,19 @@ def unfinished_index(self) -> np.ndarray: return np.array([last] if not self.done[last] and self._size else [], int) def prev(self, index: int | np.ndarray) -> np.ndarray: - """Return the index of previous transition. - - The index won't be modified if it is the beginning of an episode. + """Return the index of preceding step within the same episode if it exists. + If it does not exist (because it is the first index within the episode), + the index remains unmodified. """ - index = (index - 1) % self._size + index = (index - 1) % self._size # compute preceding index with wrap-around + # end_flag will be 1 if the previous index is the last step of an episode or + # if it is the very last index of the buffer (wrap-around case), and 0 otherwise end_flag = self.done[index] | (index == self.last_index[0]) return (index + end_flag) % self._size def next(self, index: int | np.ndarray) -> np.ndarray: - """Return the index of next transition. - - The index won't be modified if it is the end of an episode. 
+ """Return the index of next step if there is a next step within the episode. + If there isn't a next step, the index remains unmodified. """ end_flag = self.done[index] | (index == self.last_index[0]) return (index + (1 - end_flag)) % self._size diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py index aa4f53ec7..f8ca0c0d6 100644 --- a/tianshou/highlevel/config.py +++ b/tianshou/highlevel/config.py @@ -118,9 +118,12 @@ class SamplingConfig(ToStringMixin): replay_buffer_ignore_obs_next: bool = False replay_buffer_save_only_last_obs: bool = False - """if True, only the most recent frame is saved when appending to experiences rather than the - full stacked frames. This avoids duplicating observations in buffer memory. Set to False to - save stacked frames in full. + """if True, for the case where the environment outputs stacked frames (e.g. because it + is using a `FrameStack` wrapper), save only the most recent frame so as not to duplicate + observations in buffer memory. Specifically, if the environment outputs observations `obs` with + shape (N, ...), only obs[-1] of shape (...) will be stored. + Frame stacking with a fixed number of frames can then be recreated at the buffer level by setting + :attr:`replay_buffer_stack_num`. """ replay_buffer_stack_num: int = 1 @@ -128,6 +131,9 @@ class SamplingConfig(ToStringMixin): the number of consecutive environment observations to stack and use as the observation input to the agent for each time step. Setting this to a value greater than 1 can help agents learn temporal aspects (e.g. velocities of moving objects for which only positions are observed). + + If the environment already stacks frames (e.g. using a `FrameStack` wrapper), this should either not + be used or should be used in conjunction with :attr:`replay_buffer_save_only_last_obs`. 
""" @property diff --git a/tianshou/highlevel/module/critic.py b/tianshou/highlevel/module/critic.py index f1984e4d7..4eacef115 100644 --- a/tianshou/highlevel/module/critic.py +++ b/tianshou/highlevel/module/critic.py @@ -197,7 +197,11 @@ def create_module( last_size=last_size, ).to(device) elif envs.get_type().is_continuous(): - return continuous.Critic(actor.get_preprocess_net(), device=device).to(device) + return continuous.Critic( + actor.get_preprocess_net(), + device=device, + apply_preprocess_net_to_obs_only=True, + ).to(device) else: raise ValueError diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 6cd4a0f63..0b28f98f9 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -15,6 +15,7 @@ TLinearLayer, get_output_dim, ) +from tianshou.utils.pickle import setstate SIGMA_MIN = -20 SIGMA_MAX = 2 @@ -109,6 +110,9 @@ class Critic(CriticBase): `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. :param linear_layer: use this module as linear layer. :param flatten_input: whether to flatten input data for the last layer. + :param apply_preprocess_net_to_obs_only: whether to apply `preprocess_net` to the observations only (before + concatenating with the action) - and without the observations being modified in any way beforehand. + This allows the actor's preprocessing network to be reused for the critic. For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. 
@@ -122,11 +126,13 @@ def __init__( preprocess_net_output_dim: int | None = None, linear_layer: TLinearLayer = nn.Linear, flatten_input: bool = True, + apply_preprocess_net_to_obs_only: bool = False, ) -> None: super().__init__() self.device = device self.preprocess = preprocess_net self.output_dim = 1 + self.apply_preprocess_net_to_obs_only = apply_preprocess_net_to_obs_only input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) self.last = MLP( input_dim, @@ -137,6 +143,14 @@ def __init__( flatten_input=flatten_input, ) + def __setstate__(self, state: dict) -> None: + setstate( + Critic, + self, + state, + new_default_properties={"apply_preprocess_net_to_obs_only": False}, + ) + def forward( self, obs: np.ndarray | torch.Tensor, @@ -148,7 +162,10 @@ def forward( obs, device=self.device, dtype=torch.float32, - ).flatten(1) + ) + if self.apply_preprocess_net_to_obs_only: + obs, _ = self.preprocess(obs) + obs = obs.flatten(1) if act is not None: act = torch.as_tensor( act, @@ -156,8 +173,9 @@ def forward( dtype=torch.float32, ).flatten(1) obs = torch.cat([obs, act], dim=1) - values_B, hidden_BH = self.preprocess(obs) - return self.last(values_B) + if not self.apply_preprocess_net_to_obs_only: + obs, _ = self.preprocess(obs) + return self.last(obs) class ActorProb(BaseActor): diff --git a/tianshou/utils/pickle.py b/tianshou/utils/pickle.py new file mode 100644 index 000000000..924716222 --- /dev/null +++ b/tianshou/utils/pickle.py @@ -0,0 +1,97 @@ +"""Helper functions for persistence/pickling, which have been copied from sensAI (specifically `sensai.util.pickle`).""" + +from collections.abc import Iterable +from copy import copy +from typing import Any + + +def setstate( + cls: type, + obj: Any, + state: dict[str, Any], + renamed_properties: dict[str, str] | None = None, + new_optional_properties: list[str] | None = None, + new_default_properties: dict[str, Any] | None = None, + removed_properties: list[str] | None = None, +) -> None: + """Helper 
function for safe implementations of `__setstate__` in classes, which appropriately handles the cases where + a parent class already implements `__setstate__` and where it does not. Call this function whenever you would actually + like to call the super-class' implementation. + Unfortunately, `__setstate__` is not implemented in `object`, rendering `super().__setstate__(state)` invalid in the general case. + + :param cls: the class in which you are implementing `__setstate__` + :param obj: the instance of `cls` + :param state: the state dictionary + :param renamed_properties: a mapping from old property names to new property names + :param new_optional_properties: a list of names of new properties, which, if not present, shall be initialized with None + :param new_default_properties: a dictionary mapping property names to their default values, which shall be added if they are not present + :param removed_properties: a list of names of properties that are no longer being used + """ + # handle new/changed properties + if renamed_properties is not None: + for mOld, mNew in renamed_properties.items(): + if mOld in state: + state[mNew] = state[mOld] + del state[mOld] + if new_optional_properties is not None: + for mNew in new_optional_properties: + if mNew not in state: + state[mNew] = None + if new_default_properties is not None: + for mNew, mValue in new_default_properties.items(): + if mNew not in state: + state[mNew] = mValue + if removed_properties is not None: + for p in removed_properties: + if p in state: + del state[p] + # call super implementation, if any + s = super(cls, obj) + if hasattr(s, "__setstate__"): + s.__setstate__(state) + else: + obj.__dict__ = state + + +def getstate( + cls: type, + obj: Any, + transient_properties: Iterable[str] | None = None, + excluded_properties: Iterable[str] | None = None, + override_properties: dict[str, Any] | None = None, + excluded_default_properties: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Helper
function for safe implementations of `__getstate__` in classes, which appropriately handles the cases where + a parent class already implements `__getstate__` and where it does not. Call this function whenever you would actually + like to call the super-class' implementation. + Unfortunately, `__getstate__` is not implemented in `object` before Python 3.11, rendering `super().__getstate__()` invalid in the general case. + + :param cls: the class in which you are implementing `__getstate__` + :param obj: the instance of `cls` + :param transient_properties: transient properties which shall be set to None in serializations + :param excluded_properties: properties which shall be completely removed from serializations + :param override_properties: a mapping from property names to values specifying (new or existing) properties which are to be set; + use this to set a fixed value for an existing property or to add a completely new property + :param excluded_default_properties: properties which shall be completely removed from serializations, if they are set + to the given default value + :return: the state dictionary, which may be modified by the receiver + """ + s = super(cls, obj) + d = s.__getstate__() if hasattr(s, "__getstate__") else obj.__dict__ + d = copy(d) + if transient_properties is not None: + for p in transient_properties: + if p in d: + d[p] = None + if excluded_properties is not None: + for p in excluded_properties: + if p in d: + del d[p] + if override_properties is not None: + for k, v in override_properties.items(): + d[k] = v + if excluded_default_properties is not None: + for p, v in excluded_default_properties.items(): + if p in d and d[p] == v: + del d[p] + return d