diff --git a/docs/_config.yml b/docs/_config.yml index 925c99439..a0bb290a2 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -102,6 +102,10 @@ sphinx: recursive_update : false # A boolean indicating whether to overwrite the Sphinx config (true) or recursively update (false) config : # key-value pairs to directly over-ride the Sphinx configuration autodoc_typehints_format: "short" + autodoc_member_order: "bysource" + autoclass_content: "both" + autodoc_default_options: + show-inheritance: True html_js_files: # We have to list them explicitly because they need to be loaded in a specific order - js/vega@5.js diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index c30b9f2cb..04a4a7803 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -271,3 +271,12 @@ v_s_ obs obs_next dtype +entrypoint +interquartile +init +kwarg +kwargs +autocompletion +codebase +indexable +sliceable diff --git a/pyproject.toml b/pyproject.toml index d332d9393..2b7ef9744 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -176,7 +176,10 @@ ignore = [ "RET505", "D106", # undocumented public nested class "D205", # blank line after summary (prevents summary-only docstrings, which makes no sense) + "D212", # no blank line after """. This clashes with sphinx for multiline descriptions of :param: that start directly after """ "PLW2901", # overwrite vars in loop + "B027", # empty and non-abstract method in abstract class + "D404", # It's fine to start with "This" in docstrings ] unfixable = [ "F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index c64b0fc75..4ea99dbcb 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -1,3 +1,47 @@ +"""This module implements :class:`Batch`, a flexible data structure for +handling heterogeneous data in reinforcement learning algorithms. Such a data structure +is needed since RL algorithms differ widely in the conceptual fields that they need. +`Batch` is the main data carrier in Tianshou. It bears some similarities to +`TensorDict `_ +that is used for a similar purpose in `pytorch-rl `_. +The main differences between the two are that `Batch` can hold arbitrary objects (and not just torch tensors), +and that Tianshou implements `BatchProtocol` for enabling type checking and autocompletion (more on that below). + +The `Batch` class is designed to store and manipulate collections of data with +varying types and structures. It strikes a balance between flexibility and type safety, the latter mainly +achieved through the use of protocols. One can think of it as a mixture of a dictionary and an array, +as it has both key-value pairs and nesting, while also having a shape, being indexable and sliceable. + +Key features of the `Batch` class include: + +1. Flexible data storage: Can hold numpy arrays, torch tensors, scalars, and nested Batch objects. +2. Dynamic attribute access: Allows setting and accessing data using attribute notation (e.g., `batch.observation`). +   This allows for type-safe and readable code and enables IDE autocompletion. See comments on `BatchProtocol` below. +3. Indexing and slicing: Supports numpy-like indexing and slicing operations. The slicing is extended to nested +   Batch objects and torch Distributions. +4. Batch operations: Provides methods for splitting, shuffling, concatenating and stacking multiple Batch objects. +5. Data type conversion: Offers methods to convert data between numpy arrays and torch tensors.
+6. Value transformations: Allows applying functions to all values in the Batch recursively. +7. Analysis utilities: Provides methods for checking for missing values, dropping entries with missing values, +   and others. + +Since we want to keep `Batch` flexible and not fix a specific set of fields or their types, +we don't have fixed interfaces for actual `Batch` objects that are used throughout +tianshou (such interfaces could be dataclasses, for example). However, we still want to enable +IDE autocompletion and type checking for `Batch` objects. To achieve this, we rely on dynamic duck typing +by using `Protocol`. The :class:`BatchProtocol` defines the interface that all `Batch` objects should adhere to, +and its various implementations (like :class:`~.types.ActBatchProtocol` or :class:`~.types.RolloutBatchProtocol`) define the specific +fields that are expected in the respective `Batch` objects. The protocols are then used as type hints +throughout the codebase. Protocols can't be instantiated, but we can cast to them. +For example, we "instantiate" an `ActBatchProtocol` with something like: + +>>> act_batch = cast(ActBatchProtocol, Batch(act=my_action)) + +Users can decide for themselves how to structure their `Batch` objects, and can opt in to the +`BatchProtocol` style to enable type checking and autocompletion. Opting out will have no effect on +the functionality. +""" + import pprint import warnings from collections.abc import Callable, Collection, Iterable, Iterator, KeysView, Sequence diff --git a/tianshou/data/types.py b/tianshou/data/types.py index a4fd43543..fd2f6d287 100644 --- a/tianshou/data/types.py +++ b/tianshou/data/types.py @@ -9,10 +9,6 @@ TNestedDictValue = np.ndarray | dict[str, "TNestedDictValue"] -d: dict[str, TNestedDictValue] = {"a": {"b": np.array([1, 2, 3])}} -d["c"] = np.array([1, 2, 3]) - - class ObsBatchProtocol(BatchProtocol, Protocol): """Observations of an environment that a policy can turn into actions. @@ -62,6 +58,8 @@ class ActStateBatchProtocol(ActBatchProtocol, Protocol): """Contains action and state (which can be None), useful for policies that can support RNNs.""" state: dict | BatchProtocol | np.ndarray | None +    """Hidden state of RNNs, or None if not using RNNs. Used for recurrent policies. +    At the moment, support for recurrent policies is experimental!""" class ModelOutputBatchProtocol(ActStateBatchProtocol, Protocol): @@ -121,7 +119,7 @@ class QuantileRegressionBatchProtocol(ModelOutputBatchProtocol, Protocol): class ImitationBatchProtocol(ActBatchProtocol, Protocol): - """Similar to other batches, but contains imitation_logits and q_value fields.""" + """Similar to other batches, but contains `imitation_logits` and `q_value` fields.""" state: dict | Batch | np.ndarray | None q_value: torch.Tensor diff --git a/tianshou/evaluation/launcher.py b/tianshou/evaluation/launcher.py index 534e5f835..4dddae5af 100644 --- a/tianshou/evaluation/launcher.py +++ b/tianshou/evaluation/launcher.py @@ -27,6 +27,8 @@ class JoblibConfig: class ExpLauncher(ABC): +    """Base interface for launching multiple experiments simultaneously.""" + def __init__( self, experiment_runner: Callable[ @@ -34,11 +36,13 @@ def __init__( InfoStats | None, ] = lambda exp: exp.run().trainer_result, ): -        """:param experiment_runner: can be used to override the default way in which an experiment is executed. -            Can be useful e.g., if one wants to use the high-level interfaces to setup an experiment (or an experiment -            collection) and tinker with it prior to execution.
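To make the `Batch` features described in `tianshou/data/batch.py` above concrete, here is a small sketch (the field names `obs`, `act` and `env_id` are purely illustrative):

>>> import numpy as np
>>> import torch
>>> from tianshou.data import Batch
>>> b = Batch(obs=np.zeros((4, 3)), act=torch.zeros(4), info=Batch(env_id=np.arange(4)))
>>> len(b)  # the leading dimension is shared by all entries
4
>>> sub = b[:2]  # slicing is applied recursively to nested entries
>>> sub.info.env_id  # dynamic attribute access on nested Batch objects
array([0, 1])
>>> stacked = Batch.stack([b, b])  # Batch.cat, b.split(...) and the to_torch/to_numpy conversions work analogously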
This need often arises when prototyping with mechanisms -            that are not yet supported by the high-level interfaces. -            Passing this allows arbitrary things to happen during experiment execution, so use it with caution! +        """ +        :param experiment_runner: determines how an experiment is to be executed. +            Overriding the default can be useful, e.g., for using high-level interfaces +            to set up an experiment (or an experiment collection) and tinkering with it prior to execution. +            This need often arises when prototyping with mechanisms that are not yet supported by +            the high-level interfaces. +            Allows arbitrary things to happen during experiment execution, so use it with caution! """ self.experiment_runner = experiment_runner @@ -112,6 +116,7 @@ def __init__( super().__init__(experiment_runner=experiment_runner) self.joblib_cfg = copy(joblib_cfg) if joblib_cfg is not None else JoblibConfig() # Joblib's backend is hard-coded to loky since the threading backend produces different results +        # TODO: fix this if self.joblib_cfg.backend != "loky": log.warning( f"Ignoring the user provided joblib backend {self.joblib_cfg.backend} and using loky instead. " diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index c1313262e..6c35710a6 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -125,15 +125,6 @@ def create_train_test_collector( if reset_collectors: train_collector.reset() test_collector.reset() - -        if self.sampling_config.start_timesteps > 0: -            log.info( -                f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})", -            ) -            train_collector.collect( -                n_step=self.sampling_config.start_timesteps, -                random=self.sampling_config.start_timesteps_random, -            ) return train_collector, test_collector def set_policy_wrapper_factory( @@ -200,6 +191,7 @@ def create_trainer( batch_size=sampling_config.batch_size, step_per_collect=sampling_config.step_per_collect, save_best_fn=policy_persistence.get_save_best_fn(world), +            save_checkpoint_fn=policy_persistence.get_save_checkpoint_fn(world), logger=world.logger, test_in_train=False, train_fn=train_fn, diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index c6c692f64..3cb9fadb2 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -361,11 +361,11 @@ def create_venv( class EnvFactory(ToStringMixin, ABC): -    """Main interface for the creation of environments (in various forms).""" - def __init__(self, venv_type: VectorEnvType): -        """:param venv_type: the type of vectorized environment to use for train and test environments. -        watch environments are always created as dummy environments. +        """Main interface for the creation of environments (in various forms). + +        :param venv_type: the type of vectorized environment to use for train and test environments. +            `WATCH` environments are always created as `DUMMY` vector environments. """ self.venv_type = venv_type @@ -377,7 +377,8 @@ def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv: """Create vectorized environments. :param num_envs: the number of environments -        :param mode: the mode for which to create. In `WATCH` mode the resulting venv will always be of type `DUMMY` with a single env. +        :param mode: the mode for which to create environments. +            In `WATCH` mode the resulting venv will always be of type `DUMMY` with a single env.
:return: the vectorized environments """ @@ -437,9 +438,7 @@ def __init__( :param render_mode_train: the render mode to use for training environments :param render_mode_test: the render mode to use for test environments :param render_mode_watch: the render mode to use for environments that are used to watch agent performance - :param make_kwargs: additional keyword arguments to pass on to `gymnasium.make`. - If envpool is used, the gymnasium parameters will be appropriately translated for use with - `envpool.make_gymnasium`. + :param make_kwargs: additional keyword arguments to pass on to `gymnasium.make`. If envpool is used, the gymnasium parameters will be appropriately translated for use with `envpool.make_gymnasium`. """ super().__init__(venv_type) self.task = task diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index dbcd3f156..8a08f9a8f 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -1,12 +1,35 @@ +"""The experiment module provides high-level interfaces for setting up and running reinforcement learning experiments. + +The main entry points are: + +* :class:`ExperimentConfig`: a dataclass for configuring the experiment. The configuration is + different from RL specific configuration (such as policy and trainer parameters) + and only pertains to configuration that is common to all experiments. +* :class:`Experiment`: represents a reinforcement learning experiment. + It is composed of configuration and factory objects, is lightweight and serializable. + An instance of `Experiment` is usually saved as a pickle file after an experiment is executed. +* :class:`ExperimentBuilder`: a helper class for creating experiments. It contains a lot of defaults + and allows for easy customization of the experiment setup. +* :class:`ExperimentCollection`: a shallow wrapper around a list of experiments providing a + simple interface for running them with a launcher. Useful for running multiple experiments in parallel, in + particular, for the important case of running experiments that only differ in their random seeds. + +Various implementations of the `ExperimentBuilder` are provided for each of the algorithms supported by Tianshou. 
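A typical workflow is to pick one of these builders, adjust defaults via its `with_` methods, and then build and run the resulting experiment. A rough sketch using the :class:`PPOExperimentBuilder` defined further below (`MyEnvFactory` stands for a user-defined :class:`~tianshou.highlevel.env.EnvFactory` implementation and is purely illustrative; the constructor arguments shown are those of the base :class:`ExperimentBuilder`):

>>> from tianshou.highlevel.config import SamplingConfig
>>> builder = PPOExperimentBuilder(MyEnvFactory(), ExperimentConfig(), SamplingConfig())
>>> experiment = builder.build()
>>> result = experiment.run(run_name="my_run")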
+""" + import os import pickle from abc import abstractmethod from collections.abc import Sequence +from contextlib import suppress from copy import deepcopy -from dataclasses import dataclass +from dataclasses import asdict, dataclass from pprint import pformat from typing import TYPE_CHECKING, Any, Self, Union, cast +if TYPE_CHECKING: + from tianshou.evaluation.launcher import ExpLauncher, RegisteredExpLauncher + import numpy as np import torch @@ -51,7 +74,10 @@ ) from tianshou.highlevel.module.intermediate import IntermediateModuleFactory from tianshou.highlevel.module.special import ImplicitQuantileNetworkFactory -from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam +from tianshou.highlevel.optim import ( + OptimizerFactory, + OptimizerFactoryAdam, +) from tianshou.highlevel.params.policy_params import ( A2CParams, DDPGParams, @@ -79,13 +105,11 @@ ) from tianshou.highlevel.world import World from tianshou.policy import BasePolicy -from tianshou.utils import LazyLogger, deprecation, logging +from tianshou.utils import LazyLogger, logging from tianshou.utils.logging import datetime_tag from tianshou.utils.net.common import ModuleType from tianshou.utils.string import ToStringMixin - -if TYPE_CHECKING: - from tianshou.evaluation.launcher import ExpLauncher, RegisteredExpLauncher +from tianshou.utils.warning import deprecation log = logging.getLogger(__name__) @@ -125,7 +149,12 @@ class ExperimentResult: """Contains the results of an experiment.""" world: World - """contains all the essential instances of the experiment""" + """The `World` contains all the essential instances of the experiment. + Can also be created via `Experiment.create_experiment_world` for more custom setups, see docstring there. + + Note: it is typically not serializable, so it is not stored in the experiment pickle, and shouldn't be + sent across processes, meaning also that `ExperimentResult` itself is typically not serializable. + """ trainer_result: InfoStats | None """dataclass of results as returned by the trainer (if any)""" @@ -136,6 +165,14 @@ class Experiment(ToStringMixin): An experiment is composed only of configuration and factory objects, which themselves should be designed to contain only configuration. Therefore, experiments can easily be stored/pickled and later restored without any problems. + + The main entry points are: + + 1. :meth:`run`: runs the experiment and returns the results + 2. :meth:`create_experiment_world`: creates the world object for the experiment, which contains all relevant instances. + Useful for setting up the experiment and running it in a more custom way. + + The methods :meth:`save` and :meth:`from_directory` can be used to store and restore experiments. """ LOG_FILENAME = "log.txt" @@ -204,43 +241,42 @@ def save(self, directory: str) -> None: with open(path, "wb") as f: pickle.dump(self, f) - def run( + def create_experiment_world( self, - run_name: str | None = None, + override_experiment_name: str | None = None, logger_run_id: str | None = None, raise_error_on_dirname_collision: bool = True, - **kwargs: dict[str, Any], - ) -> ExperimentResult: - """Run the experiment and return the results. - - :param run_name: Defines a name for this run of the experiment, which determines - the subdirectory (within the persistence base directory) where all results will be saved. - If None, the experiment's name will be used. + reset_collectors: bool = True, + ) -> World: + """Creates the world object for the experiment. 
+ + The world object contains all relevant instances for the experiment, + such as environments, policy, collectors, etc. + This method is the main entrypoint for users who don't want to use `run` directly. A common use case + is that some configuration or custom logic should happen before the training loop starts, but one + still wants to use the convenience of high-level interfaces for setting up the experiment. + + :param override_experiment_name: pass to override the experiment name in the resulting `World`. + Affects the name of the persistence directory and logger configuration. If None, the experiment's + name will be used. The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case a nested directory structure will be created. - :param logger_run_id: Run identifier to use for logger initialization/resumption (applies when - using wandb, in particular). - :param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed - experiment with the same name. - :param kwargs: for backward compatibility with old parameter names only - :return: + :param logger_run_id: Run identifier to use for logger initialization/resumption. + :param raise_error_on_dirname_collision: whether to raise an error on collisions when creating the + persistence directory. Only takes effect if persistence is enabled. Set to `False` e.g., when continuing + a previously executed experiment with the same `persistence_base_dir` and name. + :param reset_collectors: whether to reset the collectors before training starts. + Setting to `False` can be useful when continuing training from a previous run with restored collectors, + or for adding custom logic before training starts. """ - # backward compatibility - _experiment_name = kwargs.pop("experiment_name", None) - if _experiment_name is not None: - run_name = cast(str, _experiment_name) - deprecation( - "Parameter run_name should now be used instead of experiment_name. 
" - "Support for experiment_name will be removed in the future.", - ) - assert len(kwargs) == 0, f"Received unexpected arguments: {kwargs}" - - if run_name is None: - run_name = self.name + if override_experiment_name is not None: + exp_name = override_experiment_name + else: + exp_name = self.name # initialize persistence directory use_persistence = self.config.persistence_enabled - persistence_dir = os.path.join(self.config.persistence_base_dir, run_name) + persistence_dir = os.path.join(self.config.persistence_base_dir, exp_name) if use_persistence: os.makedirs(persistence_dir, exist_ok=not raise_error_on_dirname_collision) @@ -249,7 +285,7 @@ def run( enabled=use_persistence and self.config.log_file_enabled, ): # log initial information - log.info(f"Running experiment (name='{run_name}'):\n{self.pprints()}") + log.info(f"Preparing experiment world (name='{exp_name}'):\n{self.pprints()}") log.info(f"Working directory: {os.getcwd()}") self._set_seed() @@ -276,11 +312,16 @@ def run( # initialize logger full_config = self._build_config_dict() full_config.update(envs.info()) + full_config["experiment_config"] = asdict(self.config) + full_config["sampling_config"] = asdict(self.sampling_config) + with suppress(AttributeError): + full_config["policy_params"] = asdict(self.agent_factory.params) # type: ignore + logger: TLogger if use_persistence: logger = self.logger_factory.create_logger( log_dir=persistence_dir, - experiment_name=run_name, + experiment_name=exp_name, run_id=logger_run_id, config_dict=full_config, ) @@ -294,6 +335,7 @@ def run( train_collector, test_collector = self.agent_factory.create_train_test_collector( policy, envs, + reset_collectors=reset_collectors, ) # create context object with all relevant instances (except trainer; added later) @@ -315,23 +357,87 @@ def run( self.config.device, ) - # train policy - log.info("Starting training") - trainer_result: InfoStats | None = None if self.config.train: trainer = self.agent_factory.create_trainer(world, policy_persistence) world.trainer = trainer - trainer_result = trainer.run() + + return world + + def run( + self, + run_name: str | None = None, + logger_run_id: str | None = None, + raise_error_on_dirname_collision: bool = True, + **kwargs: dict[str, Any], + ) -> ExperimentResult: + """Run the experiment and return the results. + + :param run_name: Defines a name for this run of the experiment, which determines + the subdirectory (within the persistence base directory) where all results will be saved. + If None, the experiment's name will be used. + The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case + a nested directory structure will be created. + :param logger_run_id: Run identifier to use for logger initialization/resumption (applies when + using wandb, in particular). + :param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed + experiment with the same name. + :param kwargs: for backwards compatibility with old parameter names only + :return: + """ + # backward compatibility + _experiment_name = kwargs.pop("experiment_name", None) + if _experiment_name is not None: + run_name = cast(str, _experiment_name) + deprecation( + "Parameter run_name should now be used instead of experiment_name. 
" + "Support for experiment_name will be removed in the future.", + ) + assert len(kwargs) == 0, f"Received unexpected arguments: {kwargs}" + + if run_name is None: + run_name = self.name + + world = self.create_experiment_world( + override_experiment_name=run_name, + logger_run_id=logger_run_id, + raise_error_on_dirname_collision=raise_error_on_dirname_collision, + ) + + persistence_dir = world.persist_directory + use_persistence = self.config.persistence_enabled + + with logging.FileLoggerContext( + os.path.join(persistence_dir, self.LOG_FILENAME), + enabled=use_persistence and self.config.log_file_enabled, + ): + trainer_result: InfoStats | None = None + if self.config.train: + # prefilling buffers with either random or current agent's actions + if self.sampling_config.start_timesteps > 0: + log.info( + f"Collecting {self.sampling_config.start_timesteps} initial environment " + f"steps before training (random={self.sampling_config.start_timesteps_random})", + ) + world.train_collector.collect( + n_step=self.sampling_config.start_timesteps, + random=self.sampling_config.start_timesteps_random, + ) + + log.info("Starting training") + assert world.trainer is not None + world.trainer.run() + if use_persistence: + world.logger.finalize() log.info(f"Training result:\n{pformat(trainer_result)}") # watch agent performance if self.config.watch: - assert envs.watch_env is not None + assert world.envs.watch_env is not None log.info("Watching agent performance") self._watch_agent( self.config.watch_num_episodes, - policy, - envs.watch_env, + world.policy, + world.envs.watch_env, self.config.watch_render, ) @@ -372,16 +478,31 @@ def run( class ExperimentBuilder: + """A helper class (following the builder pattern) for creating experiments. + + It contains a lot of defaults for the setup which can be adjusted using the + various `with_` methods. For example, the default optimizer is Adam, but can be + adjusted using :meth:`with_optim_factory`. Moreover, for simply configuring the default + optimizer instead of using a different one, one can use :meth:`with_optim_factory_default`. + """ + def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, sampling_config: SamplingConfig | None = None, ): + """:param env_factory: controls how environments are to be created. + :param experiment_config: the configuration for the experiment. If None, will use the default values + of `ExperimentConfig`. + :param sampling_config: the sampling configuration to use. If None, will use the default values + of `SamplingConfig`. 
+        """ + if experiment_config is None: experiment_config = ExperimentConfig() if sampling_config is None: sampling_config = SamplingConfig() + self._config = experiment_config self._env_factory = env_factory self._sampling_config = sampling_config @@ -443,6 +564,7 @@ def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self: def with_optim_factory_default( self, +        # Keep values in sync with default values in OptimizerFactoryAdam betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0, @@ -505,6 +627,7 @@ def _create_agent_factory(self) -> AgentFactory: def _get_optim_factory(self) -> OptimizerFactory: if self._optim_factory is None: +            # same mechanism as in `with_optim_factory_default` return OptimizerFactoryAdam() else: return self._optim_factory @@ -531,7 +654,13 @@ def build(self) -> Experiment: def build_seeded_collection(self, num_experiments: int) -> ExperimentCollection: """Creates a collection of experiments with non-overlapping random seeds, starting from the configured seed. -        Each experiment in the collection will have a unique name that is created from the original experiment name and the seeds used. +        Useful for performing statistically meaningful evaluations of an algorithm's performance. +        The `rliable` recommendation is to use at least 5 experiments for computing quantities such as the +        interquartile mean and performance profiles. See the usage in example scripts +        like `examples/mujoco/mujoco_ppo_hl_multi.py`. + +        Each experiment in the collection will have a unique name created from the original experiment name +        and the seeds used. """ num_train_envs = self.sampling_config.num_train_envs @@ -553,7 +682,7 @@ def __init__(self, continuous_actor_type: ContinuousActorType): self._actor_factory: ActorFactory | None = None def with_actor_factory(self, actor_factory: ActorFactory) -> Self: -        """Allows to customize the actor component via the specification of a factory. +        """Allows customizing the actor component via the specification of a factory. If this function is not called, a default actor factory (with default parameters) will be used. diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/optim.py index db5fd90ff..bdb01fbf0 100644 --- a/tianshou/highlevel/optim.py +++ b/tianshou/highlevel/optim.py @@ -45,6 +45,8 @@ def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim class OptimizerFactoryAdam(OptimizerFactory): +    # Note: currently used as default optimizer +    # values should be kept in sync with `ExperimentBuilder.with_optim_factory_default` def __init__( self, betas: tuple[float, float] = (0.9, 0.999), diff --git a/tianshou/highlevel/params/noise.py b/tianshou/highlevel/params/noise.py index 66e0c53c4..ca4ca1fdd 100644 --- a/tianshou/highlevel/params/noise.py +++ b/tianshou/highlevel/params/noise.py @@ -17,8 +17,8 @@ def __init__(self, std_fraction: float): This factory can only be applied to continuous action spaces.
- :param std_fraction: fraction (between 0 and 1) of the maximum action value that shall - be used as the standard deviation + :param std_fraction: fraction (between 0 and 1) of the maximum action value that + shall be used as the standard deviation """ self.std_fraction = std_fraction diff --git a/tianshou/highlevel/persistence.py b/tianshou/highlevel/persistence.py index 52b18d1b6..2758f5066 100644 --- a/tianshou/highlevel/persistence.py +++ b/tianshou/highlevel/persistence.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable from enum import Enum +from pathlib import Path from typing import TYPE_CHECKING import torch @@ -128,3 +129,26 @@ def save_best_fn(pol: torch.nn.Module) -> None: self.persist(pol, world) return save_best_fn + + def get_save_checkpoint_fn(self, world: World) -> Callable[[int, int, int], str] | None: + if not self.enabled: + return None + + def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: + path = Path(self.mode.get_filename()) + path_with_epoch = path.with_stem(f"{path.stem}_epoch_{epoch}") + path = world.persist_path(path_with_epoch.name) + match self.mode: + case self.Mode.POLICY_STATE_DICT: + log.info(f"Saving policy state dictionary in {path}") + torch.save(world.policy.state_dict(), path) + case self.Mode.POLICY: + log.info(f"Saving policy object in {path}") + torch.save(world.policy, path) + case _: + raise NotImplementedError + if self.additional_persistence is not None: + self.additional_persistence.persist(PersistEvent.PERSIST_POLICY, world) + return path + + return save_checkpoint_fn diff --git a/tianshou/highlevel/trainer.py b/tianshou/highlevel/trainer.py index 4eccc6a18..1083017d4 100644 --- a/tianshou/highlevel/trainer.py +++ b/tianshou/highlevel/trainer.py @@ -135,8 +135,9 @@ class EpochStopCallbackRewardThreshold(EpochStopCallback): """ def __init__(self, threshold: float | None = None): - """:param threshold: the reward threshold beyond which to stop training. - If it is None, use threshold given by the environment, i.e. `env.spec.reward_threshold`. + """ + :param threshold: the reward threshold beyond which to stop training. + If it is None, will use threshold specified by the environment, i.e. `env.spec.reward_threshold`. """ self.threshold = threshold diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py index c1bd73795..fadb0ad57 100644 --- a/tianshou/utils/logger/base.py +++ b/tianshou/utils/logger/base.py @@ -153,6 +153,9 @@ def restore_logged_data( :return: a dict containing the logged data. """ + def finalize(self) -> None: + """Finalize the logger, e.g. close the file handler.""" + class LazyLogger(BaseLogger): """A logger that does nothing. Used as the placeholder in trainer."""
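Putting the pieces above together, a multi-seed evaluation in the spirit of `rliable` combines :meth:`ExperimentBuilder.build_seeded_collection` with a launcher from `tianshou/evaluation/launcher.py`. A rough sketch (the `RegisteredExpLauncher.joblib.create_launcher()` call and the `ExperimentCollection.run` signature are assumptions here; `examples/mujoco/mujoco_ppo_hl_multi.py` shows the authoritative usage):

>>> from tianshou.evaluation.launcher import RegisteredExpLauncher
>>> collection = builder.build_seeded_collection(num_experiments=5)
>>> launcher = RegisteredExpLauncher.joblib.create_launcher()
>>> results = collection.run(launcher)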