From f40a6e5d5bee6f99baf838d47f94a7a189b7ae09 Mon Sep 17 00:00:00 2001
From: Michael Panchenko
Date: Sun, 4 Aug 2024 18:12:37 +0200
Subject: [PATCH 1/8] Highlevel, persistence: more customization for policy
 persistence

---
 tianshou/highlevel/agent.py       |  1 +
 tianshou/highlevel/persistence.py | 23 +++++++++++++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index c1313262e..765efc649 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -200,6 +200,7 @@ def create_trainer(
             batch_size=sampling_config.batch_size,
             step_per_collect=sampling_config.step_per_collect,
             save_best_fn=policy_persistence.get_save_best_fn(world),
+            save_checkpoint_fn=policy_persistence.get_save_checkpoint_fn(world),
             logger=world.logger,
             test_in_train=False,
             train_fn=train_fn,
diff --git a/tianshou/highlevel/persistence.py b/tianshou/highlevel/persistence.py
index 52b18d1b6..cc04d0653 100644
--- a/tianshou/highlevel/persistence.py
+++ b/tianshou/highlevel/persistence.py
@@ -2,6 +2,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from enum import Enum
+from pathlib import Path
 from typing import TYPE_CHECKING

 import torch
@@ -128,3 +129,25 @@ def save_best_fn(pol: torch.nn.Module) -> None:
             self.persist(pol, world)

         return save_best_fn
+
+    def get_save_checkpoint_fn(self, world: World) -> Callable[[int, int, int], str]:
+        def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
+            if not self.enabled:
+                return None
+            path = Path(self.mode.get_filename())
+            path_with_epoch = path.with_stem(f"{path.stem}_epoch_{epoch}")
+            path = world.persist_path(path_with_epoch.name)
+            match self.mode:
+                case self.Mode.POLICY_STATE_DICT:
+                    log.info(f"Saving policy state dictionary in {path}")
+                    torch.save(world.policy.state_dict(), path)
+                case self.Mode.POLICY:
+                    log.info(f"Saving policy object in {path}")
+                    torch.save(world.policy, path)
+                case _:
+                    raise NotImplementedError
+            if self.additional_persistence is not None:
+                self.additional_persistence.persist(PersistEvent.PERSIST_POLICY, world)
+            return path
+
+        return save_checkpoint_fn

From fe4efc989add3cd4715e35b8426153af6b6b347e Mon Sep 17 00:00:00 2001
From: Michael Panchenko
Date: Sun, 4 Aug 2024 18:17:10 +0200
Subject: [PATCH 2/8] Highlevel, experiment: permit extraction of World.

Moved handling of `start_timesteps` from AgentFactory to Experiment.run.
Slight internal simplifications, extended documentation
---
 pyproject.toml                   |   1 +
 tianshou/highlevel/agent.py      |   9 --
 tianshou/highlevel/experiment.py | 233 ++++++++++++++++++++++---------
 tianshou/utils/logger/base.py    |   3 +
 4 files changed, 173 insertions(+), 73 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index d332d9393..ed8b13347 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -177,6 +177,7 @@ ignore = [
   "D106", # undocumented public nested class
   "D205", # blank line after summary (prevents summary-only docstrings, which makes no sense)
   "PLW2901", # overwrite vars in loop
+  "B027", # empty and non-abstract method in abstract class
 ]
 unfixable = [
   "F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all
diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index 765efc649..6c35710a6 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -125,15 +125,6 @@ def create_train_test_collector(
         if reset_collectors:
             train_collector.reset()
             test_collector.reset()
-
-        if self.sampling_config.start_timesteps > 0:
-            log.info(
-                f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})",
-            )
-            train_collector.collect(
-                n_step=self.sampling_config.start_timesteps,
-                random=self.sampling_config.start_timesteps_random,
-            )
         return train_collector, test_collector

     def set_policy_wrapper_factory(
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index dbcd3f156..62f09f802 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -1,11 +1,33 @@
+"""The experiment module provides high-level interfaces for setting up and running reinforcement learning experiments.
+
+The main entry points are:
+ - `ExperimentConfig`: a dataclass for configuring the experiment. The configuration is
+   different from RL specific configuration (such as policy and trainer parameters)
+   and only pertains to configuration that is common to all experiments.
+- `Experiment`: represents a reinforcement learning experiment.
+  It is composed of configuration and factory objects, is lightweight and serializable.
+  An instance of `Experiment` is usually saved as a pickle file after an experiment is executed.
+- `ExperimentBuilder`: a helper class for creating experiments. It contains a lot of defaults
+  and allows for easy customization of the experiment setup.
+- `ExperimentCollection`: a shallow wrapper around a list of experiments providing a
+  simple interface for running them with a launcher. Useful for running multiple experiments in parallel, in
+  particular, for the important case of running experiments that only differ in their random seeds.
+
+Various implementations of the `ExperimentBuilder` are provided for each of the algorithms supported by Tianshou.
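+
+A minimal usage sketch (the concrete builder class and its `with_...` methods depend on
+the chosen algorithm; a DQN builder is assumed here purely for illustration):
+
+>>> experiment = DQNExperimentBuilder(env_factory).build()
+>>> result = experiment.run(run_name="my_experiment")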
+""" + import os import pickle from abc import abstractmethod from collections.abc import Sequence +from contextlib import suppress from copy import deepcopy -from dataclasses import dataclass +from dataclasses import asdict, dataclass from pprint import pformat -from typing import TYPE_CHECKING, Any, Self, Union, cast +from typing import TYPE_CHECKING, Literal, Self, Union, Unpack + +if TYPE_CHECKING: + from tianshou.evaluation.launcher import ExpLauncher, RegisteredExpLauncher import numpy as np import torch @@ -51,7 +73,13 @@ ) from tianshou.highlevel.module.intermediate import IntermediateModuleFactory from tianshou.highlevel.module.special import ImplicitQuantileNetworkFactory -from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam +from tianshou.highlevel.optim import ( + DEFAULT_OPTIM_FACTORY_PARAMS, + DefaultOptimFactoryParams, + OptimizerFactory, + OptimizerFactoryAdam, + OptimizerFactoryDefault, +) from tianshou.highlevel.params.policy_params import ( A2CParams, DDPGParams, @@ -79,14 +107,11 @@ ) from tianshou.highlevel.world import World from tianshou.policy import BasePolicy -from tianshou.utils import LazyLogger, deprecation, logging +from tianshou.utils import LazyLogger, logging from tianshou.utils.logging import datetime_tag from tianshou.utils.net.common import ModuleType from tianshou.utils.string import ToStringMixin -if TYPE_CHECKING: - from tianshou.evaluation.launcher import ExpLauncher, RegisteredExpLauncher - log = logging.getLogger(__name__) @@ -125,7 +150,12 @@ class ExperimentResult: """Contains the results of an experiment.""" world: World - """contains all the essential instances of the experiment""" + """The `World` contains all the essential instances of the experiment. + Can also be created via `Experiment.create_experiment_world` for more custom setups, see docstring there. + + Note: it is typically not serializable, so it is not stored in the experiment pickle, and shouldn't be + sent across processes, meaning also that `ExperimentResult` itself is typically not serializable. + """ trainer_result: InfoStats | None """dataclass of results as returned by the trainer (if any)""" @@ -136,6 +166,14 @@ class Experiment(ToStringMixin): An experiment is composed only of configuration and factory objects, which themselves should be designed to contain only configuration. Therefore, experiments can easily be stored/pickled and later restored without any problems. + + The main entry points are: + + 1. `run`: runs the experiment and returns the results + 2. `create_experiment_world`: creates the world object for the experiment, which contains all relevant instances. + Useful for setting up the experiment and running it in a more custom way. + + The methods `save` and `from_directory` can be used to store and restore experiments. """ LOG_FILENAME = "log.txt" @@ -204,43 +242,39 @@ def save(self, directory: str) -> None: with open(path, "wb") as f: pickle.dump(self, f) - def run( + def create_experiment_world( self, - run_name: str | None = None, + override_experiment_name: str | Literal["DATETIME_TAG"] | None = None, logger_run_id: str | None = None, raise_error_on_dirname_collision: bool = True, - **kwargs: dict[str, Any], - ) -> ExperimentResult: - """Run the experiment and return the results. - - :param run_name: Defines a name for this run of the experiment, which determines - the subdirectory (within the persistence base directory) where all results will be saved. - If None, the experiment's name will be used. 
- The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case - a nested directory structure will be created. - :param logger_run_id: Run identifier to use for logger initialization/resumption (applies when - using wandb, in particular). - :param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed - experiment with the same name. - :param kwargs: for backward compatibility with old parameter names only - :return: + reset_collectors: bool = True, + ) -> World: + """Creates the world object for the experiment. + + The world object contains all relevant instances for the experiment, + such as environments, policy, collectors, etc. + This method is the main entrypoint for users who don't want to use `run` directly. A common use case + is that some configuration or custom logic should happen before the training loop starts, but one + still wants to use the convenience of high-level interfaces for setting up the experiment. + + :param override_experiment_name: whether to override the experiment name in the resulting `World`. + Passing `DATETIME_TAG` will use a name containing the current date and time. + :param logger_run_id: Run identifier to use for logger initialization/resumption. + :param raise_error_on_dirname_collision: whether to raise an error on collisions when creating the + persistence directory. Only takes effect if persistence is enabled. Set to `False` e.g., when continuing + a previously executed experiment with the same `persistence_base_dir` and name. + :param reset_collectors: whether to reset the collectors before training starts. + Setting to `False` can be useful when continuing training from a previous run with restored collectors, + or for adding custom logic before training starts. """ - # backward compatibility - _experiment_name = kwargs.pop("experiment_name", None) - if _experiment_name is not None: - run_name = cast(str, _experiment_name) - deprecation( - "Parameter run_name should now be used instead of experiment_name. 
" - "Support for experiment_name will be removed in the future.", - ) - assert len(kwargs) == 0, f"Received unexpected arguments: {kwargs}" - - if run_name is None: - run_name = self.name + if override_experiment_name is not None: + if override_experiment_name == "DATETIME_TAG": + override_experiment_name = datetime_tag() + self.name = override_experiment_name # initialize persistence directory use_persistence = self.config.persistence_enabled - persistence_dir = os.path.join(self.config.persistence_base_dir, run_name) + persistence_dir = os.path.join(self.config.persistence_base_dir, self.name) if use_persistence: os.makedirs(persistence_dir, exist_ok=not raise_error_on_dirname_collision) @@ -249,7 +283,7 @@ def run( enabled=use_persistence and self.config.log_file_enabled, ): # log initial information - log.info(f"Running experiment (name='{run_name}'):\n{self.pprints()}") + log.info(f"Running experiment (name='{self.name}'):\n{self.pprints()}") log.info(f"Working directory: {os.getcwd()}") self._set_seed() @@ -276,11 +310,16 @@ def run( # initialize logger full_config = self._build_config_dict() full_config.update(envs.info()) + full_config["experiment_config"] = asdict(self.config) + full_config["sampling_config"] = asdict(self.sampling_config) + with suppress(AttributeError): + full_config["policy_params"] = asdict(self.agent_factory.params) # type: ignore + logger: TLogger if use_persistence: logger = self.logger_factory.create_logger( log_dir=persistence_dir, - experiment_name=run_name, + experiment_name=self.name, run_id=logger_run_id, config_dict=full_config, ) @@ -294,6 +333,7 @@ def run( train_collector, test_collector = self.agent_factory.create_train_test_collector( policy, envs, + reset_collectors=reset_collectors, ) # create context object with all relevant instances (except trainer; added later) @@ -315,23 +355,74 @@ def run( self.config.device, ) - # train policy - log.info("Starting training") - trainer_result: InfoStats | None = None if self.config.train: trainer = self.agent_factory.create_trainer(world, policy_persistence) world.trainer = trainer - trainer_result = trainer.run() + + return world + + def run( + self, + run_name: str | Literal["DATETIME_TAG"] | None = None, + logger_run_id: str | None = None, + raise_error_on_dirname_collision: bool = True, + ) -> ExperimentResult: + """Run the experiment and return the results. + + :param run_name: if not None, will adjust the current instance's `name` name attribute. + The name corresponds to the directory (within the logging + directory) where all results associated with the experiment will be saved. + The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case + a nested directory structure will be created. + If "DATETIME_TAG" is passed, use a name containing the current date and time. This option + is useful for preventing file-name collisions if a single experiment is executed repeatedly. + :param logger_run_id: Run identifier to use for logger initialization/resumption (applies when + using wandb, in particular). + :param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed + experiment with the same name. 
+ :return: + """ + world = self.create_experiment_world( + override_experiment_name=run_name, + logger_run_id=logger_run_id, + raise_error_on_dirname_collision=raise_error_on_dirname_collision, + ) + + persistence_dir = world.persist_directory + use_persistence = self.config.persistence_enabled + + with logging.FileLoggerContext( + os.path.join(persistence_dir, self.LOG_FILENAME), + enabled=use_persistence and self.config.log_file_enabled, + ): + trainer_result: InfoStats | None = None + if self.config.train: + # prefilling buffers with random actions + if self.sampling_config.start_timesteps > 0: + log.info( + f"Collecting {self.sampling_config.start_timesteps} initial environment " + f"steps before training (random={self.sampling_config.start_timesteps_random})", + ) + world.train_collector.collect( + n_step=self.sampling_config.start_timesteps, + random=self.sampling_config.start_timesteps_random, + ) + + log.info("Starting training") + assert world.trainer is not None + world.trainer.run() + if use_persistence: + world.logger.finalize() log.info(f"Training result:\n{pformat(trainer_result)}") # watch agent performance if self.config.watch: - assert envs.watch_env is not None + assert world.envs.watch_env is not None log.info("Watching agent performance") self._watch_agent( self.config.watch_num_episodes, - policy, - envs.watch_env, + world.policy, + world.envs.watch_env, self.config.watch_render, ) @@ -372,19 +463,30 @@ def run( class ExperimentBuilder: + OPTIM_FACTORY_DEFAULT_CLS = OptimizerFactoryAdam + def __init__( self, env_factory: EnvFactory, experiment_config: ExperimentConfig | None = None, sampling_config: SamplingConfig | None = None, ): - if experiment_config is None: - experiment_config = ExperimentConfig() - if sampling_config is None: - sampling_config = SamplingConfig() - self._config = experiment_config + """A helper class (following the builder pattern) for creating experiments. + + It contains a lot of defaults for the setup which can be adjusted using the + various `with_` methods. For example, the default optimizer is Adam, but can be + adjusted using `with_optim_factory`. Moreover, for simply configuring the default + optimizer instead of using a different one, one can use `with_optim_factory_default`. + + :param env_factory: controls how environments are to be created. + :param experiment_config: the configuration for the experiment. If None, will use the default values + of `ExperimentConfig`. + :param sampling_config: the sampling configuration to use. If None, will use the default values + of `SamplingConfig`. + """ + self._config = experiment_config or ExperimentConfig() self._env_factory = env_factory - self._sampling_config = sampling_config + self._sampling_config = sampling_config or SamplingConfig() self._logger_factory: LoggerFactory | None = None self._optim_factory: OptimizerFactory | None = None self._policy_wrapper_factory: PolicyWrapperFactory | None = None @@ -443,18 +545,15 @@ def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self: def with_optim_factory_default( self, - betas: tuple[float, float] = (0.9, 0.999), - eps: float = 1e-08, - weight_decay: float = 0, + **kwargs: Unpack[DefaultOptimFactoryParams], ) -> Self: - """Configures the use of the default optimizer, Adam, with the given parameters. + """Configures the use of the default optimizer, with the given parameters. 
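+
+        For example, `builder.with_optim_factory_default(weight_decay=1e-5)` keeps the
+        default optimizer but overrides a single parameter (a sketch; any subset of the
+        keys in `DefaultOptimFactoryParams` can be passed this way).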

-        :param betas: coefficients used for computing running averages of gradient and its square
-        :param eps: term added to the denominator to improve numerical stability
-        :param weight_decay: weight decay (L2 penalty)
+        :param kwargs: the parameters to use for the optimizer, see `DefaultOptimFactoryParams`.
         :return: the builder
         """
-        self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay)
+        default_optim_params: DefaultOptimFactoryParams = {**DEFAULT_OPTIM_FACTORY_PARAMS, **kwargs}
+        self._optim_factory = OptimizerFactoryDefault(**default_optim_params)
         return self

     def with_epoch_train_callback(self, callback: EpochTrainCallback) -> Self:
@@ -505,7 +604,7 @@ def _create_agent_factory(self) -> AgentFactory:

     def _get_optim_factory(self) -> OptimizerFactory:
         if self._optim_factory is None:
-            return OptimizerFactoryAdam()
+            return OptimizerFactoryDefault()
         else:
             return self._optim_factory
@@ -531,7 +630,13 @@ def build(self) -> Experiment:
     def build_seeded_collection(self, num_experiments: int) -> ExperimentCollection:
         """Creates a collection of experiments with non-overlapping random seeds, starting from the configured seed.

-        Each experiment in the collection will have a unique name that is created from the original experiment name and the seeds used.
+        Useful for performing statistically meaningful evaluations of an algorithm's performance.
+        The `rliable` recommendation is to use at least 5 experiments for computing quantities such as the
+        interquantile mean and performance profiles. See the usage in example scripts
+        like `examples/mujoco/mujoco_ppo_hl_multi.py`.
+
+        Each experiment in the collection will have a unique name created from the original experiment name
+        and the seeds used.
         """
         num_train_envs = self.sampling_config.num_train_envs
@@ -553,7 +658,7 @@ def __init__(self, continuous_actor_type: ContinuousActorType):
         self._actor_factory: ActorFactory | None = None

     def with_actor_factory(self, actor_factory: ActorFactory) -> Self:
-        """Allows to customize the actor component via the specification of a factory.
+        """Allows customizing the actor component via the specification of a factory.

         If this function is not called, a default actor factory (with default parameters) will be used.
diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py
index c1bd73795..fadb0ad57 100644
--- a/tianshou/utils/logger/base.py
+++ b/tianshou/utils/logger/base.py
@@ -153,6 +153,9 @@ def restore_logged_data(
         :return: a dict containing the logged data.
         """

+    def finalize(self) -> None:
+        """Finalize the logger, e.g. close the file handler."""
+

 class LazyLogger(BaseLogger):
     """A logger that does nothing. Used as the placeholder in trainer."""

From 88017e5e3a236b3647818a24b515aa0e2b040b8b Mon Sep 17 00:00:00 2001
From: Michael Panchenko
Date: Sun, 4 Aug 2024 18:27:58 +0200
Subject: [PATCH 3/8] Highlevel, persistence: fix return of
 get_save_checkpoint_fn

---
 tianshou/highlevel/persistence.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/tianshou/highlevel/persistence.py b/tianshou/highlevel/persistence.py
index cc04d0653..7e0e16810 100644
--- a/tianshou/highlevel/persistence.py
+++ b/tianshou/highlevel/persistence.py
@@ -76,6 +76,7 @@ class Mode(Enum):
         def get_filename(self) -> str:
             return self.value + ".pt"

+
     def __init__(
         self,
         additional_persistence: Persistence | None = None,
@@ -130,10 +131,11 @@ def save_best_fn(pol: torch.nn.Module) -> None:

         return save_best_fn

-    def get_save_checkpoint_fn(self, world: World) -> Callable[[int, int, int], str]:
+    def get_save_checkpoint_fn(self, world: World) -> Callable[[int, int, int], str] | None:
+        if not self.enabled:
+            return None
+
         def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
-            if not self.enabled:
-                return None
             path = Path(self.mode.get_filename())
             path_with_epoch = path.with_stem(f"{path.stem}_epoch_{epoch}")
             path = world.persist_path(path_with_epoch.name)

From 488251585340139a5971ead75c5faa0fcd0ab9d1 Mon Sep 17 00:00:00 2001
From: Michael Panchenko
Date: Sun, 4 Aug 2024 18:49:44 +0200
Subject: [PATCH 4/8] Docs

---
 docs/_config.yml                   |  4 +++
 docs/spelling_wordlist.txt         |  5 +++
 pyproject.toml                     |  1 +
 tianshou/evaluation/launcher.py    | 15 ++++++---
 tianshou/highlevel/env.py          | 15 +++++----
 tianshou/highlevel/experiment.py   | 49 +++++++++++++++---------------
 tianshou/highlevel/params/noise.py |  4 +--
 tianshou/highlevel/persistence.py  |  1 -
 tianshou/highlevel/trainer.py      |  5 +--
 9 files changed, 56 insertions(+), 43 deletions(-)

diff --git a/docs/_config.yml b/docs/_config.yml
index 925c99439..a0bb290a2 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -102,6 +102,10 @@ sphinx:
   recursive_update : false # A boolean indicating whether to overwrite the Sphinx config (true) or recursively update (false)
   config : # key-value pairs to directly over-ride the Sphinx configuration
     autodoc_typehints_format: "short"
+    autodoc_member_order: "bysource"
+    autoclass_content: "both"
+    autodoc_default_options:
+      show-inheritance: True
     html_js_files:
       # We have to list them explicitly because they need to be loaded in a specific order
       - js/vega@5.js
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index c30b9f2cb..94a941d20 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -271,3 +271,8 @@ v_s_
 obs
 obs_next
 dtype
+entrypoint
+interquantile
+init
+kwarg
+kwargs
diff --git a/pyproject.toml b/pyproject.toml
index ed8b13347..bc068cda9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -176,6 +176,7 @@ ignore = [
   "RET505",
   "D106", # undocumented public nested class
   "D205", # blank line after summary (prevents summary-only docstrings, which makes no sense)
+  "D212", # no blank line after """. This clashes with sphinx for multiline descriptions of :param: that start directly after """
   "PLW2901", # overwrite vars in loop
   "B027", # empty and non-abstract method in abstract class
 ]
diff --git a/tianshou/evaluation/launcher.py b/tianshou/evaluation/launcher.py
index 534e5f835..4dddae5af 100644
--- a/tianshou/evaluation/launcher.py
+++ b/tianshou/evaluation/launcher.py
@@ -27,6 +27,8 @@ class JoblibConfig:


 class ExpLauncher(ABC):
+    """Base interface for launching multiple experiments simultaneously."""
+
     def __init__(
         self,
         experiment_runner: Callable[
@@ -34,11 +36,13 @@ def __init__(
             InfoStats | None,
         ] = lambda exp: exp.run().trainer_result,
     ):
-        """:param experiment_runner: can be used to override the default way in which an experiment is executed.
-        Can be useful e.g., if one wants to use the high-level interfaces to setup an experiment (or an experiment
-        collection) and tinker with it prior to execution. This need often arises when prototyping with mechanisms
-        that are not yet supported by the high-level interfaces.
-        Passing this allows arbitrary things to happen during experiment execution, so use it with caution!
+        """
+        :param experiment_runner: determines how an experiment is to be executed.
+            Overriding the default can be useful, e.g., for using high-level interfaces
+            to set up an experiment (or an experiment collection) and tinkering with it prior to execution.
+            This need often arises when prototyping with mechanisms that are not yet supported by
+            the high-level interfaces.
+            Allows arbitrary things to happen during experiment execution, so use it with caution!
         """
         self.experiment_runner = experiment_runner
@@ -112,6 +116,7 @@ def __init__(
         super().__init__(experiment_runner=experiment_runner)
         self.joblib_cfg = copy(joblib_cfg) if joblib_cfg is not None else JoblibConfig()
         # Joblib's backend is hard-coded to loky since the threading backend produces different results
+        # TODO: fix this
         if self.joblib_cfg.backend != "loky":
             log.warning(
                 f"Ignoring the user provided joblib backend {self.joblib_cfg.backend} and using loky instead. "
diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py
index c6c692f64..3cb9fadb2 100644
--- a/tianshou/highlevel/env.py
+++ b/tianshou/highlevel/env.py
@@ -361,11 +361,11 @@ def create_venv(


 class EnvFactory(ToStringMixin, ABC):
-    """Main interface for the creation of environments (in various forms)."""
-
     def __init__(self, venv_type: VectorEnvType):
-        """:param venv_type: the type of vectorized environment to use for train and test environments.
-        watch environments are always created as dummy environments.
+        """Main interface for the creation of environments (in various forms).
+
+        :param venv_type: the type of vectorized environment to use for train and test environments.
+            `WATCH` environments are always created as `DUMMY` vector environments.
         """
         self.venv_type = venv_type
@@ -377,7 +377,8 @@ def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
         """Create vectorized environments.

         :param num_envs: the number of environments
-        :param mode: the mode for which to create. In `WATCH` mode the resulting venv will always be of type `DUMMY` with a single env.
+        :param mode: the mode for which to create.
+            In `WATCH` mode the resulting venv will always be of type `DUMMY` with a single env.
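+            For example, `create_venv(10, EnvMode.TRAIN)` yields ten parallel training
+            environments of the configured `venv_type`, whereas watch environments are
+            always created as a single `DUMMY` venv regardless of `num_envs`.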

         :return: the vectorized environments
         """
@@ -437,9 +438,7 @@ def __init__(
         :param render_mode_train: the render mode to use for training environments
         :param render_mode_test: the render mode to use for test environments
         :param render_mode_watch: the render mode to use for environments that are used to watch agent performance
-        :param make_kwargs: additional keyword arguments to pass on to `gymnasium.make`.
-            If envpool is used, the gymnasium parameters will be appropriately translated for use with
-            `envpool.make_gymnasium`.
+        :param make_kwargs: additional keyword arguments to pass on to `gymnasium.make`. If envpool is used, the gymnasium parameters will be appropriately translated for use with `envpool.make_gymnasium`.
         """
         super().__init__(venv_type)
         self.task = task
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index 62f09f802..d3dec9c69 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -1,17 +1,18 @@
 """The experiment module provides high-level interfaces for setting up and running reinforcement learning experiments.

 The main entry points are:
- - `ExperimentConfig`: a dataclass for configuring the experiment. The configuration is
-   different from RL specific configuration (such as policy and trainer parameters)
-   and only pertains to configuration that is common to all experiments.
-- `Experiment`: represents a reinforcement learning experiment.
-  It is composed of configuration and factory objects, is lightweight and serializable.
-  An instance of `Experiment` is usually saved as a pickle file after an experiment is executed.
-- `ExperimentBuilder`: a helper class for creating experiments. It contains a lot of defaults
-  and allows for easy customization of the experiment setup.
-- `ExperimentCollection`: a shallow wrapper around a list of experiments providing a
-  simple interface for running them with a launcher. Useful for running multiple experiments in parallel, in
-  particular, for the important case of running experiments that only differ in their random seeds.
+
+* :class:`ExperimentConfig`: a dataclass for configuring the experiment. The configuration is
+  different from RL specific configuration (such as policy and trainer parameters)
+  and only pertains to configuration that is common to all experiments.
+* :class:`Experiment`: represents a reinforcement learning experiment.
+  It is composed of configuration and factory objects, is lightweight and serializable.
+  An instance of `Experiment` is usually saved as a pickle file after an experiment is executed.
+* :class:`ExperimentBuilder`: a helper class for creating experiments. It contains a lot of defaults
+  and allows for easy customization of the experiment setup.
+* :class:`ExperimentCollection`: a shallow wrapper around a list of experiments providing a
+  simple interface for running them with a launcher. Useful for running multiple experiments in parallel, in
+  particular, for the important case of running experiments that only differ in their random seeds.

 Various implementations of the `ExperimentBuilder` are provided for each of the algorithms supported by Tianshou.
 """
@@ -77,7 +78,6 @@
     DEFAULT_OPTIM_FACTORY_PARAMS,
     DefaultOptimFactoryParams,
     OptimizerFactory,
-    OptimizerFactoryAdam,
     OptimizerFactoryDefault,
 )
 from tianshou.highlevel.params.policy_params import (
@@ -169,11 +169,11 @@ class Experiment(ToStringMixin):

     The main entry points are:

-    1. `run`: runs the experiment and returns the results
-    2. `create_experiment_world`: creates the world object for the experiment, which contains all relevant instances.
+    1. :meth:`run`: runs the experiment and returns the results
+    2. :meth:`create_experiment_world`: creates the world object for the experiment, which contains all relevant instances.
         Useful for setting up the experiment and running it in a more custom way.

-    The methods `save` and `from_directory` can be used to store and restore experiments.
+    The methods :meth:`save` and :meth:`from_directory` can be used to store and restore experiments.
     """

     LOG_FILENAME = "log.txt"
@@ -313,7 +313,7 @@ def create_experiment_world(
             full_config["experiment_config"] = asdict(self.config)
             full_config["sampling_config"] = asdict(self.sampling_config)
             with suppress(AttributeError):
-                full_config["policy_params"] = asdict(self.agent_factory.params) # type: ignore
+                full_config["policy_params"] = asdict(self.agent_factory.params)  # type: ignore

             logger: TLogger
             if use_persistence:
@@ -463,7 +463,13 @@ def run(


 class ExperimentBuilder:
-    OPTIM_FACTORY_DEFAULT_CLS = OptimizerFactoryAdam
+    """A helper class (following the builder pattern) for creating experiments.
+
+    It contains a lot of defaults for the setup which can be adjusted using the
+    various `with_` methods. For example, the default optimizer is Adam, but can be
+    adjusted using :meth:`with_optim_factory`. Moreover, for simply configuring the default
+    optimizer instead of using a different one, one can use :meth:`with_optim_factory_default`.
+    """

     def __init__(
         self,
@@ -471,14 +477,7 @@ def __init__(
         experiment_config: ExperimentConfig | None = None,
         sampling_config: SamplingConfig | None = None,
     ):
-        """A helper class (following the builder pattern) for creating experiments.
-
-        It contains a lot of defaults for the setup which can be adjusted using the
-        various `with_` methods. For example, the default optimizer is Adam, but can be
-        adjusted using `with_optim_factory`. Moreover, for simply configuring the default
-        optimizer instead of using a different one, one can use `with_optim_factory_default`.
-
-        :param env_factory: controls how environments are to be created.
+        """:param env_factory: controls how environments are to be created.
         :param experiment_config: the configuration for the experiment. If None, will use the default values
             of `ExperimentConfig`.
         :param sampling_config: the sampling configuration to use. If None, will use the default values
             of `SamplingConfig`.
         """
diff --git a/tianshou/highlevel/params/noise.py b/tianshou/highlevel/params/noise.py
index 66e0c53c4..ca4ca1fdd 100644
--- a/tianshou/highlevel/params/noise.py
+++ b/tianshou/highlevel/params/noise.py
@@ -17,8 +17,8 @@ def __init__(self, std_fraction: float):

         This factory can only be applied to continuous action spaces.
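+
+        For example, with `std_fraction=0.1` in an environment whose maximum action value
+        is 2.0, the resulting Gaussian exploration noise has standard deviation 0.2.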

-        :param std_fraction: fraction (between 0 and 1) of the maximum action value that shall
-            be used as the standard deviation
+        :param std_fraction: fraction (between 0 and 1) of the maximum action value that
+            shall be used as the standard deviation
         """
         self.std_fraction = std_fraction
diff --git a/tianshou/highlevel/persistence.py b/tianshou/highlevel/persistence.py
index 7e0e16810..2758f5066 100644
--- a/tianshou/highlevel/persistence.py
+++ b/tianshou/highlevel/persistence.py
@@ -76,7 +76,6 @@ class Mode(Enum):
         def get_filename(self) -> str:
             return self.value + ".pt"

-
     def __init__(
         self,
         additional_persistence: Persistence | None = None,
diff --git a/tianshou/highlevel/trainer.py b/tianshou/highlevel/trainer.py
index 4eccc6a18..1083017d4 100644
--- a/tianshou/highlevel/trainer.py
+++ b/tianshou/highlevel/trainer.py
@@ -135,8 +135,9 @@ class EpochStopCallbackRewardThreshold(EpochStopCallback):
     """

     def __init__(self, threshold: float | None = None):
-        """:param threshold: the reward threshold beyond which to stop training.
-        If it is None, use threshold given by the environment, i.e. `env.spec.reward_threshold`.
+        """
+        :param threshold: the reward threshold beyond which to stop training.
+            If it is None, will use threshold specified by the environment, i.e. `env.spec.reward_threshold`.
         """
         self.threshold = threshold

From d30bd3c2d67408d29a08a69376cec7692de4671c Mon Sep 17 00:00:00 2001
From: Michael Panchenko
Date: Mon, 5 Aug 2024 11:01:03 +0200
Subject: [PATCH 5/8] Docs for batch module

---
 docs/spelling_wordlist.txt |  4 +++
 pyproject.toml             |  1 +
 tianshou/data/batch.py     | 58 +++++++++++++++++++++++++++++++++-----
 tianshou/data/types.py     |  8 ++----
 4 files changed, 59 insertions(+), 12 deletions(-)

diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 94a941d20..04a4a7803 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -276,3 +276,7 @@ interquantile
 init
 kwarg
 kwargs
+autocompletion
+codebase
+indexable
+sliceable
diff --git a/pyproject.toml b/pyproject.toml
index bc068cda9..2b7ef9744 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -179,6 +179,7 @@ ignore = [
   "D212", # no blank line after """. This clashes with sphinx for multiline descriptions of :param: that start directly after """
   "PLW2901", # overwrite vars in loop
   "B027", # empty and non-abstract method in abstract class
+  "D404", # It's fine to start with "This" in docstrings
 ]
 unfixable = [
   "F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index c64b0fc75..a1e313db4 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -1,3 +1,47 @@
+"""This module implements :class:`Batch`, a flexible data structure for
+handling heterogeneous data in reinforcement learning algorithms. Such a data structure
+is needed since RL algorithms differ widely in the conceptual fields that they need.
+`Batch` is the main data carrier in Tianshou. It bears some similarities to
+`TensorDict `_
+that is used for a similar purpose in `pytorch-rl `_.
+The main differences between the two are that `Batch` can hold arbitrary objects (and not just torch tensors),
+and that Tianshou implements `BatchProtocol` for enabling type checking and autocompletion (more on that below).
+
+The `Batch` class is designed to store and manipulate collections of data with
+varying types and structures. It strikes a balance between flexibility and type safety, the latter mainly
+achieved through the use of protocols. One can think of it as a mixture of a dictionary and an array,
+as it has both key-value pairs and nesting, while also having a shape, being indexable and sliceable.
+
+Key features of the `Batch` class include:
+
+1. Flexible data storage: Can hold numpy arrays, torch tensors, scalars, and nested Batch objects.
+2. Dynamic attribute access: Allows setting and accessing data using attribute notation (e.g., `batch.observation`).
+    This allows for type-safe and readable code and enables IDE autocompletion. See comments on `BatchProtocol` below.
+3. Indexing and slicing: Supports numpy-like indexing and slicing operations. The slicing is extended to nested
+    Batch objects and torch Distributions.
+4. Batch operations: Provides methods for splitting, shuffling, concatenating and stacking multiple Batch objects.
+5. Data type conversion: Offers methods to convert data between numpy arrays and torch tensors.
+6. Value transformations: Allows applying functions to all values in the Batch recursively.
+7. Analysis utilities: Provides methods for checking for missing values, dropping entries with missing values,
+    and others.
+
+Since we want to keep `Batch` flexible and not fix a specific set of fields or their types,
+we don't have fixed interfaces for actual `Batch` objects that are used throughout
+tianshou (such interfaces could be dataclasses, for example). However, we still want to enable
+IDE autocompletion and type checking for `Batch` objects. To achieve this, we rely on dynamic duck typing
+by using `Protocol`. The :class:`BatchProtocol` defines the interface that all `Batch` objects should adhere to,
+and its various implementations (like :class:`~.types.ActBatchProtocol` or :class:`~.types.RolloutBatchProtocol`) define the specific
+fields that are expected in the respective `Batch` objects. The protocols are then used as type hints
+throughout the codebase. Protocols can't be instantiated, but we can cast to them.
+For example, we "instantiate" an `ActBatchProtocol` with something like:
+
+>>> act_batch = cast(ActBatchProtocol, Batch(act=my_action))
+
+The users can decide for themselves how to structure their `Batch` objects, and can opt in to the
+`BatchProtocol` style to enable type checking and autocompletion. Opting out will have no effect on
+the functionality.
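+
+As a small illustrative sketch (not an excerpt from the test suite), construction, nesting
+and slicing work as follows:
+
+>>> import numpy as np
+>>> b = Batch(obs=np.zeros((4, 2)), info=Batch(done=np.array([0, 0, 1, 0])))
+>>> len(b)
+4
+>>> b[:2].obs.shape
+(2, 2)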
+""" + import pprint import warnings from collections.abc import Callable, Collection, Iterable, Iterator, KeysView, Sequence @@ -563,10 +607,10 @@ class Batch(BatchProtocol): def __init__( self, batch_dict: dict - | BatchProtocol - | Sequence[dict | BatchProtocol] - | np.ndarray - | None = None, + | BatchProtocol + | Sequence[dict | BatchProtocol] + | np.ndarray + | None = None, copy: bool = False, **kwargs: Any, ) -> None: @@ -906,10 +950,10 @@ def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None: if isinstance(value, Batch) and len(value.get_keys()) == 0: continue try: - self.__dict__[key][sum_lens[i] : sum_lens[i + 1]] = value + self.__dict__[key][sum_lens[i]: sum_lens[i + 1]] = value except KeyError: self.__dict__[key] = create_value(value, sum_lens[-1], stack=False) - self.__dict__[key][sum_lens[i] : sum_lens[i + 1]] = value + self.__dict__[key][sum_lens[i]: sum_lens[i + 1]] = value def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None: if isinstance(batches, BatchProtocol | dict): @@ -1144,7 +1188,7 @@ def split( if merge_last and idx + size + size >= length: yield self[indices[idx:]] break - yield self[indices[idx : idx + size]] + yield self[indices[idx: idx + size]] @overload def apply_values_transform( diff --git a/tianshou/data/types.py b/tianshou/data/types.py index a4fd43543..fd2f6d287 100644 --- a/tianshou/data/types.py +++ b/tianshou/data/types.py @@ -9,10 +9,6 @@ TNestedDictValue = np.ndarray | dict[str, "TNestedDictValue"] -d: dict[str, TNestedDictValue] = {"a": {"b": np.array([1, 2, 3])}} -d["c"] = np.array([1, 2, 3]) - - class ObsBatchProtocol(BatchProtocol, Protocol): """Observations of an environment that a policy can turn into actions. @@ -62,6 +58,8 @@ class ActStateBatchProtocol(ActBatchProtocol, Protocol): """Contains action and state (which can be None), useful for policies that can support RNNs.""" state: dict | BatchProtocol | np.ndarray | None + """Hidden state of RNNs, or None if not using RNNs. Used for recurrent policies. 
+ At the moment support for recurrent is experimental!""" class ModelOutputBatchProtocol(ActStateBatchProtocol, Protocol): @@ -121,7 +119,7 @@ class QuantileRegressionBatchProtocol(ModelOutputBatchProtocol, Protocol): class ImitationBatchProtocol(ActBatchProtocol, Protocol): - """Similar to other batches, but contains imitation_logits and q_value fields.""" + """Similar to other batches, but contains `imitation_logits` and `q_value` fields.""" state: dict | Batch | np.ndarray | None q_value: torch.Tensor From 0c6e598d0b10b0bf8d80f464ae4b067b8d384bb2 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 5 Aug 2024 11:02:40 +0200 Subject: [PATCH 6/8] Formatting --- tianshou/data/batch.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index a1e313db4..4ea99dbcb 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -607,10 +607,10 @@ class Batch(BatchProtocol): def __init__( self, batch_dict: dict - | BatchProtocol - | Sequence[dict | BatchProtocol] - | np.ndarray - | None = None, + | BatchProtocol + | Sequence[dict | BatchProtocol] + | np.ndarray + | None = None, copy: bool = False, **kwargs: Any, ) -> None: @@ -950,10 +950,10 @@ def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None: if isinstance(value, Batch) and len(value.get_keys()) == 0: continue try: - self.__dict__[key][sum_lens[i]: sum_lens[i + 1]] = value + self.__dict__[key][sum_lens[i] : sum_lens[i + 1]] = value except KeyError: self.__dict__[key] = create_value(value, sum_lens[-1], stack=False) - self.__dict__[key][sum_lens[i]: sum_lens[i + 1]] = value + self.__dict__[key][sum_lens[i] : sum_lens[i + 1]] = value def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None: if isinstance(batches, BatchProtocol | dict): @@ -1188,7 +1188,7 @@ def split( if merge_last and idx + size + size >= length: yield self[indices[idx:]] break - yield self[indices[idx: idx + size]] + yield self[indices[idx : idx + size]] @overload def apply_values_transform( From 892736dd43d627560a48daf68dce32dacffd0c4d Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Tue, 6 Aug 2024 10:26:44 +0200 Subject: [PATCH 7/8] Experiment: restored backwards compat. for passing experiment_name Also minor aesthetic changes --- tianshou/highlevel/experiment.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index d3dec9c69..04d57eb88 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -111,6 +111,7 @@ from tianshou.utils.logging import datetime_tag from tianshou.utils.net.common import ModuleType from tianshou.utils.string import ToStringMixin +from tianshou.utils.warning import deprecation log = logging.getLogger(__name__) @@ -366,6 +367,7 @@ def run( run_name: str | Literal["DATETIME_TAG"] | None = None, logger_run_id: str | None = None, raise_error_on_dirname_collision: bool = True, + **kwargs, ) -> ExperimentResult: """Run the experiment and return the results. @@ -380,8 +382,22 @@ def run( using wandb, in particular). :param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed experiment with the same name. 
+ :param kwargs: for backwards compatibility with old parameter names only :return: """ + # backward compatibility + _experiment_name = kwargs.pop("experiment_name", None) + if _experiment_name is not None: + run_name = _experiment_name + deprecation( + "Parameter run_name should now be used instead of experiment_name. " + "Support for experiment_name will be removed in the future.", + ) + assert len(kwargs) == 0, f"Received unexpected arguments: {kwargs}" + + if run_name is None: + run_name = self.name + world = self.create_experiment_world( override_experiment_name=run_name, logger_run_id=logger_run_id, @@ -397,7 +413,7 @@ def run( ): trainer_result: InfoStats | None = None if self.config.train: - # prefilling buffers with random actions + # prefilling buffers with either random or current agent's actions if self.sampling_config.start_timesteps > 0: log.info( f"Collecting {self.sampling_config.start_timesteps} initial environment " @@ -483,9 +499,14 @@ def __init__( :param sampling_config: the sampling configuration to use. If None, will use the default values of `SamplingConfig`. """ - self._config = experiment_config or ExperimentConfig() + if experiment_config is None: + experiment_config = ExperimentConfig() + if sampling_config is None: + sampling_config = SamplingConfig() + + self._config = experiment_config self._env_factory = env_factory - self._sampling_config = sampling_config or SamplingConfig() + self._sampling_config = sampling_config self._logger_factory: LoggerFactory | None = None self._optim_factory: OptimizerFactory | None = None self._policy_wrapper_factory: PolicyWrapperFactory | None = None From 4acdac724a510940181661e8991dfdbaa33d560a Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Tue, 6 Aug 2024 12:59:16 +0200 Subject: [PATCH 8/8] ExperimentBuilder: Restored run_name logic --- tianshou/highlevel/experiment.py | 58 +++++++++++++++++--------------- tianshou/highlevel/optim.py | 2 ++ 2 files changed, 33 insertions(+), 27 deletions(-) diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 04d57eb88..8a08f9a8f 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -25,7 +25,7 @@ from copy import deepcopy from dataclasses import asdict, dataclass from pprint import pformat -from typing import TYPE_CHECKING, Literal, Self, Union, Unpack +from typing import TYPE_CHECKING, Any, Self, Union, cast if TYPE_CHECKING: from tianshou.evaluation.launcher import ExpLauncher, RegisteredExpLauncher @@ -75,10 +75,8 @@ from tianshou.highlevel.module.intermediate import IntermediateModuleFactory from tianshou.highlevel.module.special import ImplicitQuantileNetworkFactory from tianshou.highlevel.optim import ( - DEFAULT_OPTIM_FACTORY_PARAMS, - DefaultOptimFactoryParams, OptimizerFactory, - OptimizerFactoryDefault, + OptimizerFactoryAdam, ) from tianshou.highlevel.params.policy_params import ( A2CParams, @@ -245,7 +243,7 @@ def save(self, directory: str) -> None: def create_experiment_world( self, - override_experiment_name: str | Literal["DATETIME_TAG"] | None = None, + override_experiment_name: str | None = None, logger_run_id: str | None = None, raise_error_on_dirname_collision: bool = True, reset_collectors: bool = True, @@ -258,8 +256,11 @@ def create_experiment_world( is that some configuration or custom logic should happen before the training loop starts, but one still wants to use the convenience of high-level interfaces for setting up the experiment. 
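+
+        A minimal sketch of such a custom run (note: `world.trainer` is only set when
+        `self.config.train` is enabled, and buffer prefilling via `start_timesteps`
+        is handled by `run`, not by this method):
+
+        >>> world = experiment.create_experiment_world(reset_collectors=False)
+        >>> # ... custom logic, e.g. inspecting world.policy, goes here ...
+        >>> world.trainer.run()
+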
-        :param override_experiment_name: whether to override the experiment name in the resulting `World`.
-            Passing `DATETIME_TAG` will use a name containing the current date and time.
+        :param override_experiment_name: pass to override the experiment name in the resulting `World`.
+            Affects the name of the persistence directory and logger configuration. If None, the experiment's
+            name will be used.
+            The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case
+            a nested directory structure will be created.
         :param logger_run_id: Run identifier to use for logger initialization/resumption.
         :param raise_error_on_dirname_collision: whether to raise an error on collisions when creating the
@@ -269,13 +270,13 @@ def create_experiment_world(
             or for adding custom logic before training starts.
         """
         if override_experiment_name is not None:
-            if override_experiment_name == "DATETIME_TAG":
-                override_experiment_name = datetime_tag()
-            self.name = override_experiment_name
+            exp_name = override_experiment_name
+        else:
+            exp_name = self.name

         # initialize persistence directory
         use_persistence = self.config.persistence_enabled
-        persistence_dir = os.path.join(self.config.persistence_base_dir, self.name)
+        persistence_dir = os.path.join(self.config.persistence_base_dir, exp_name)
         if use_persistence:
             os.makedirs(persistence_dir, exist_ok=not raise_error_on_dirname_collision)
@@ -284,7 +285,7 @@ def create_experiment_world(
             enabled=use_persistence and self.config.log_file_enabled,
         ):
             # log initial information
-            log.info(f"Running experiment (name='{self.name}'):\n{self.pprints()}")
+            log.info(f"Preparing experiment world (name='{exp_name}'):\n{self.pprints()}")
             log.info(f"Working directory: {os.getcwd()}")

             self._set_seed()
@@ -320,7 +321,7 @@ def create_experiment_world(
             if use_persistence:
                 logger = self.logger_factory.create_logger(
                     log_dir=persistence_dir,
-                    experiment_name=self.name,
+                    experiment_name=exp_name,
                     run_id=logger_run_id,
                     config_dict=full_config,
                 )
@@ -364,20 +365,18 @@ def create_experiment_world(

     def run(
         self,
-        run_name: str | Literal["DATETIME_TAG"] | None = None,
+        run_name: str | None = None,
         logger_run_id: str | None = None,
         raise_error_on_dirname_collision: bool = True,
-        **kwargs,
+        **kwargs: dict[str, Any],
     ) -> ExperimentResult:
         """Run the experiment and return the results.

-        :param run_name: if not None, will adjust the current instance's `name` attribute.
-            The name corresponds to the directory (within the logging
-            directory) where all results associated with the experiment will be saved.
+        :param run_name: Defines a name for this run of the experiment, which determines
+            the subdirectory (within the persistence base directory) where all results will be saved.
+            If None, the experiment's name will be used.
             The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case
             a nested directory structure will be created.
-            If "DATETIME_TAG" is passed, use a name containing the current date and time. This option
-            is useful for preventing file-name collisions if a single experiment is executed repeatedly.
         :param logger_run_id: Run identifier to use for logger initialization/resumption (applies when
             using wandb, in particular).
         :param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed
             experiment with the same name.
         :param kwargs: for backwards compatibility with old parameter names only
         :return:
         """
         # backward compatibility
         _experiment_name = kwargs.pop("experiment_name", None)
         if _experiment_name is not None:
-            run_name = _experiment_name
+            run_name = cast(str, _experiment_name)
             deprecation(
                 "Parameter run_name should now be used instead of experiment_name. "
                 "Support for experiment_name will be removed in the future.",
             )
         assert len(kwargs) == 0, f"Received unexpected arguments: {kwargs}"

         if run_name is None:
             run_name = self.name

         world = self.create_experiment_world(
             override_experiment_name=run_name,
             logger_run_id=logger_run_id,
@@ -565,15 +564,19 @@ def with_optim_factory(self, optim_factory: OptimizerFactory) -> Self:

     def with_optim_factory_default(
         self,
-        **kwargs: Unpack[DefaultOptimFactoryParams],
+        # Keep values in sync with default values in OptimizerFactoryAdam
+        betas: tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-08,
+        weight_decay: float = 0,
     ) -> Self:
-        """Configures the use of the default optimizer, with the given parameters.
+        """Configures the use of the default optimizer, Adam, with the given parameters.

-        :param kwargs: the parameters to use for the optimizer, see `DefaultOptimFactoryParams`.
+        :param betas: coefficients used for computing running averages of gradient and its square
+        :param eps: term added to the denominator to improve numerical stability
+        :param weight_decay: weight decay (L2 penalty)
         :return: the builder
         """
-        default_optim_params: DefaultOptimFactoryParams = {**DEFAULT_OPTIM_FACTORY_PARAMS, **kwargs}
-        self._optim_factory = OptimizerFactoryDefault(**default_optim_params)
+        self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay)
         return self

     def with_epoch_train_callback(self, callback: EpochTrainCallback) -> Self:
@@ -624,7 +627,8 @@ def _create_agent_factory(self) -> AgentFactory:

     def _get_optim_factory(self) -> OptimizerFactory:
         if self._optim_factory is None:
-            return OptimizerFactoryDefault()
+            # same mechanism as in `with_optim_factory_default`
+            return OptimizerFactoryAdam()
         else:
             return self._optim_factory
diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/optim.py
index db5fd90ff..bdb01fbf0 100644
--- a/tianshou/highlevel/optim.py
+++ b/tianshou/highlevel/optim.py
@@ -45,6 +45,8 @@ def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim


 class OptimizerFactoryAdam(OptimizerFactory):
+    # Note: currently used as default optimizer
+    # values should be kept in sync with `ExperimentBuilder.with_optim_factory_default`
     def __init__(
         self,
         betas: tuple[float, float] = (0.9, 0.999),