From 2b1594a1c8932a1362a9e8e034cea901673b2e88 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 30 Apr 2024 16:12:43 +0200 Subject: [PATCH 1/4] Clean up handling of an Experiment's name (and, by extension, a run's name) --- CHANGELOG.md | 6 +- examples/atari/atari_dqn_hl.py | 2 +- examples/atari/atari_iqn_hl.py | 2 +- examples/atari/atari_ppo_hl.py | 2 +- examples/atari/atari_sac_hl.py | 2 +- examples/mujoco/mujoco_a2c_hl.py | 2 +- examples/mujoco/mujoco_ddpg_hl.py | 2 +- examples/mujoco/mujoco_npg_hl.py | 2 +- examples/mujoco/mujoco_ppo_hl.py | 2 +- examples/mujoco/mujoco_redq_hl.py | 2 +- examples/mujoco/mujoco_reinforce_hl.py | 2 +- examples/mujoco/mujoco_sac_hl.py | 2 +- examples/mujoco/mujoco_td3_hl.py | 2 +- examples/mujoco/mujoco_trpo_hl.py | 2 +- test/highlevel/test_experiment_builder.py | 4 +- tianshou/highlevel/experiment.py | 67 ++++++++++++----------- 16 files changed, 56 insertions(+), 47 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b542ea11..2c62cca4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,11 @@ - `to_dict` in Batch supports also non-recursive conversion. #1098 - Batch `__eq__` implemented, semantic equality check of batches is now possible. #1098 - `Batch.keys()` deprecated in favor of `Batch.get_keys()` (needed to make iteration consistent with naming) #1105. -- `Experiment` and `ExperimentConfig` now have a `name`, that can however be overridden when `Experiment.run()` is called. #1074 +- `highlevel.experiment`: + - `Experiment` now has a `name` attribute, which can be set using `ExperimentBuilder.with_name` and + which determines the default run name and therefore the persistence subdirectory. + It can still be overridden in `Experiment.run()`, the new parameter name being `run_name` rather than + `experiment_name` (although the latter will still be interpreted correctly). #1074 - When building an `Experiment` from an `ExperimentConfig`, the user has the option to add info about seeds to the name. #1074 - New method in `ExperimentConfig` called `build_default_seeded_experiments`. #1074 - `SamplingConfig` has an explicit training seed, `test_seed` is inferred. #1074 diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 289363e1e..aa76983be 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -104,7 +104,7 @@ def main( ) experiment = builder.build() - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index 850c0ffa4..23df1cd25 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -96,7 +96,7 @@ def main( .with_epoch_stop_callback(AtariEpochStopCallback(task)) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index ea45df556..10dcd0a7e 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -115,7 +115,7 @@ def main( ), ) experiment = builder.build() - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 8b1bf2825..cf09b40ea 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -103,7 +103,7 @@ def main( ), ) experiment = builder.build() - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index 96ad8c584..bce02e9c0 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -83,7 +83,7 @@ def main( .with_critic_factory_default(hidden_sizes, nn.Tanh) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index a476245ab..db9c4e3e2 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -74,7 +74,7 @@ def main( .with_critic_factory_default(hidden_sizes) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index 2e437caca..ab265a87a 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -85,7 +85,7 @@ def main( .with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index 601b08413..27a701b12 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -95,7 +95,7 @@ def main( .with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index 9b4bca75b..f52372906 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -83,7 +83,7 @@ def main( .with_critic_ensemble_factory_default(hidden_sizes) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index a5ec65f9a..46eb64fa2 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -72,7 +72,7 @@ def main( .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index 9ffa0f43c..5ca731868 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -80,7 +80,7 @@ def main( .with_common_critic_factory_default(hidden_sizes) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 6adc73d26..3a32c7f42 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -85,7 +85,7 @@ def main( .with_common_critic_factory_default(hidden_sizes, torch.nn.Tanh) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 3af69bd45..f54d4c312 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -89,7 +89,7 @@ def main( .with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .build() ) - experiment.run(override_experiment_name=log_name) + experiment.run(run_name=log_name) if __name__ == "__main__": diff --git a/test/highlevel/test_experiment_builder.py b/test/highlevel/test_experiment_builder.py index 725d7f7b5..0ba8a7bac 100644 --- a/test/highlevel/test_experiment_builder.py +++ b/test/highlevel/test_experiment_builder.py @@ -49,7 +49,7 @@ def test_experiment_builder_continuous_default_params(builder_cls: type[Experime sampling_config=sampling_config, ) experiment = builder.build() - experiment.run(override_experiment_name="test") + experiment.run(run_name="test") print(experiment) @@ -77,7 +77,7 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment sampling_config=sampling_config, ) experiment = builder.build() - experiment.run(override_experiment_name="test") + experiment.run(run_name="test") print(experiment) diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 6f9eb7c00..e31a39289 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -6,7 +6,7 @@ from copy import copy from dataclasses import dataclass from pprint import pformat -from typing import Literal, Self +from typing import Self, Dict, Any import numpy as np import torch @@ -80,7 +80,7 @@ ) from tianshou.highlevel.world import World from tianshou.policy import BasePolicy -from tianshou.utils import LazyLogger, logging +from tianshou.utils import LazyLogger, deprecation, logging from tianshou.utils.logging import datetime_tag from tianshou.utils.net.common import ModuleType from tianshou.utils.string import ToStringMixin @@ -145,8 +145,8 @@ def __init__( env_factory: EnvFactory, agent_factory: AgentFactory, sampling_config: SamplingConfig, + name: str, logger_factory: LoggerFactory | None = None, - name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG", ): if logger_factory is None: logger_factory = LoggerFactoryDefault() @@ -155,8 +155,6 @@ def __init__( self.env_factory = env_factory self.agent_factory = agent_factory self.logger_factory = logger_factory - if name == "DATETIME_TAG": - name = datetime_tag() self.name = name def get_seeding_info_as_str(self) -> str: @@ -205,33 +203,41 @@ def save(self, directory: str) -> None: def run( self, - override_experiment_name: str | Literal["DATETIME_TAG"] | None = None, + run_name: str | None = None, logger_run_id: str | None = None, raise_error_on_dirname_collision: bool = True, + **kwargs: Dict[str, Any], ) -> ExperimentResult: """Run the experiment and return the results. - :param override_experiment_name: if not None, will adjust the current instance's `name` name attribute. - The name corresponds to the directory (within the logging - directory) where all results associated with the experiment will be saved. + :param run_name: Defines a name for this run of the experiment, which determines + the subdirectory (within the persistence base directory) where all results will be saved. + If None, the experiment's name will be used. The name may contain path separators (i.e. `os.path.sep`, as used by `os.path.join`), in which case a nested directory structure will be created. - If "DATETIME_TAG" is passed, use a name containing the current date and time. This option - is useful for preventing file-name collisions if a single experiment is executed repeatedly. :param logger_run_id: Run identifier to use for logger initialization/resumption (applies when using wandb, in particular). :param raise_error_on_dirname_collision: set to `False` e.g., when continuing a previously executed experiment with the same name. + :param kwargs: for backward compatibility with old parameter names only :return: """ - if override_experiment_name is not None: - if override_experiment_name == "DATETIME_TAG": - override_experiment_name = datetime_tag() - self.name = override_experiment_name + # backward compatibility + _experiment_name = kwargs.pop("experiment_name", None) + if _experiment_name is not None: + run_name = _experiment_name + deprecation( + "Parameter run_name should now be used instead of experiment_name. " + "Support for experiment_name will be removed in the future.", + ) + assert len(kwargs) == 0, f"Received unexpected arguments: {kwargs}" + + if run_name is None: + run_name = self.name # initialize persistence directory use_persistence = self.config.persistence_enabled - persistence_dir = os.path.join(self.config.persistence_base_dir, self.name) + persistence_dir = os.path.join(self.config.persistence_base_dir, run_name) if use_persistence: os.makedirs(persistence_dir, exist_ok=not raise_error_on_dirname_collision) @@ -240,7 +246,7 @@ def run( enabled=use_persistence and self.config.log_file_enabled, ): # log initial information - log.info(f"Running experiment (name='{self.name}'):\n{self.pprints()}") + log.info(f"Running experiment (name='{run_name}'):\n{self.pprints()}") log.info(f"Working directory: {os.getcwd()}") self._set_seed() @@ -271,7 +277,7 @@ def run( if use_persistence: logger = self.logger_factory.create_logger( log_dir=persistence_dir, - experiment_name=self.name, + experiment_name=run_name, run_id=logger_run_id, config_dict=full_config, ) @@ -364,7 +370,7 @@ def __init__( self._optim_factory: OptimizerFactory | None = None self._policy_wrapper_factory: PolicyWrapperFactory | None = None self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks() - self._experiment_name: str = "" + self._name: str = self.__class__.__name__.replace("Builder", "") + "_" + datetime_tag() @contextmanager def temp_config_mutation(self) -> Iterator[Self]: @@ -467,18 +473,17 @@ def with_epoch_stop_callback(self, callback: EpochStopCallback) -> Self: self._trainer_callbacks.epoch_stop_callback = callback return self - def with_experiment_name( + def with_name( self, - experiment_name: str | Literal["DATETIME_TAG"] = "DATETIME_TAG", + name: str, ) -> Self: """Sets the name of the experiment. - :param experiment_name: the name. If "DATETIME_TAG" (default) is given, the current date and time will be used. + :param name: the name to use for this experiment, which, when the experiment is run, + will determine the storage sub-folder by default :return: the builder """ - if experiment_name == "DATETIME_TAG": - experiment_name = datetime_tag() - self._experiment_name = experiment_name + self._name = name return self @abstractmethod @@ -504,12 +509,12 @@ def build(self, add_seeding_info_to_name: bool = False) -> Experiment: if self._policy_wrapper_factory: agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory) experiment: Experiment = Experiment( - self._config, - self._env_factory, - agent_factory, - self._sampling_config, - self._logger_factory, - name=self._experiment_name, + config=self._config, + env_factory=self._env_factory, + agent_factory=agent_factory, + sampling_config=self._sampling_config, + name=self._name, + logger_factory=self._logger_factory, ) if add_seeding_info_to_name: if not experiment.name: From f8cca8b07c45a2f98036df8b9f359bbfcc3eb43e Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 30 Apr 2024 17:22:11 +0200 Subject: [PATCH 2/4] Improve creation of multiple seeded experiments: * Add class ExperimentCollection to improve usability * Remove parameters from ExperimentBuilder.build * Renamed ExperimentBuilder.build_default_seeded_experiments to build_seeded_collection, changing the return type to ExperimentCollection * Replace temp_config_mutation (which was not appropriate for the public API) with method copy (which performs a safe deep copy) --- examples/mujoco/mujoco_ppo_hl_multi.py | 135 ++++++---------------- test/highlevel/test_experiment_builder.py | 27 ----- tianshou/highlevel/experiment.py | 92 ++++++++------- 3 files changed, 83 insertions(+), 171 deletions(-) diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py index 4408a132c..319375f12 100644 --- a/examples/mujoco/mujoco_ppo_hl_multi.py +++ b/examples/mujoco/mujoco_ppo_hl_multi.py @@ -14,8 +14,6 @@ import os import sys -from collections.abc import Sequence -from typing import Literal import torch @@ -41,86 +39,30 @@ def main( - experiment_config: ExperimentConfig, - task: str = "Ant-v4", - num_experiments: int = 5, - buffer_size: int = 4096, - hidden_sizes: Sequence[int] = (64, 64), - lr: float = 3e-4, - gamma: float = 0.99, - epoch: int = 3, - step_per_epoch: int = 30000, - step_per_collect: int = 2048, - repeat_per_collect: int = 10, - batch_size: int = 64, - training_num: int = 10, - test_num: int = 10, - rew_norm: bool = True, - vf_coef: float = 0.25, - ent_coef: float = 0.0, - gae_lambda: float = 0.95, - bound_action_method: Literal["clip", "tanh"] | None = "clip", - lr_decay: bool = True, - max_grad_norm: float = 0.5, - eps_clip: float = 0.2, - dual_clip: float | None = None, - value_clip: bool = False, - norm_adv: bool = False, - recompute_adv: bool = True, + num_experiments: int = 2, run_experiments_sequentially: bool = True, -) -> str: - """Use the high-level API of TianShou to evaluate the PPO algorithm on a MuJoCo environment with multiple seeds for - a given configuration. The results for each run are stored in separate sub-folders. After the agents are trained, - the results are evaluated using the rliable API. - - :param experiment_config: - :param task: a mujoco task name - :param num_experiments: how many experiments to run with different seeds - :param buffer_size: - :param hidden_sizes: - :param lr: - :param gamma: - :param epoch: - :param step_per_epoch: - :param step_per_collect: - :param repeat_per_collect: - :param batch_size: - :param training_num: - :param test_num: - :param rew_norm: - :param vf_coef: - :param ent_coef: - :param gae_lambda: - :param bound_action_method: - :param lr_decay: - :param max_grad_norm: - :param eps_clip: - :param dual_clip: - :param value_clip: - :param norm_adv: - :param recompute_adv: - :param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel. +) -> RLiableExperimentResult: + """:param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel. LIMITATIONS: currently, the parallel execution does not seem to work properly on linux. It might generally be undesired to run multiple experiments in parallel on the same machine, as a single experiment already uses all available CPU cores by default. :return: the directory where the results are stored """ + task = "Ant-v4" persistence_dir = os.path.abspath(os.path.join("log", task, "ppo", datetime_tag())) - experiment_config.persistence_base_dir = persistence_dir - log.info(f"Will save all experiment results to {persistence_dir}.") - experiment_config.watch = False + experiment_config = ExperimentConfig(persistence_base_dir=persistence_dir, watch=False) sampling_config = SamplingConfig( - num_epochs=epoch, - step_per_epoch=step_per_epoch, - batch_size=batch_size, - num_train_envs=training_num, - num_test_envs=test_num, - num_test_episodes=test_num, - buffer_size=buffer_size, - step_per_collect=step_per_collect, - repeat_per_collect=repeat_per_collect, + num_epochs=1, + step_per_epoch=5000, + batch_size=64, + num_train_envs=10, + num_test_envs=10, + num_test_episodes=10, + buffer_size=4096, + step_per_collect=2048, + repeat_per_collect=10, ) env_factory = MujocoEnvFactory( @@ -133,52 +75,45 @@ def main( else VectorEnvType.SUBPROC_SHARED_MEM, ) - experiments = ( + hidden_sizes = (64, 64) + + experiment_collection = ( PPOExperimentBuilder(env_factory, experiment_config, sampling_config) .with_ppo_params( PPOParams( - discount_factor=gamma, - gae_lambda=gae_lambda, - action_bound_method=bound_action_method, - reward_normalization=rew_norm, - ent_coef=ent_coef, - vf_coef=vf_coef, - max_grad_norm=max_grad_norm, - value_clip=value_clip, - advantage_normalization=norm_adv, - eps_clip=eps_clip, - dual_clip=dual_clip, - recompute_advantage=recompute_adv, - lr=lr, - lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config) - if lr_decay - else None, + discount_factor=0.99, + gae_lambda=0.95, + action_bound_method="clip", + reward_normalization=True, + ent_coef=0.0, + vf_coef=0.25, + max_grad_norm=0.5, + value_clip=False, + advantage_normalization=False, + eps_clip=0.2, + dual_clip=None, + recompute_advantage=True, + lr=3e-4, + lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config), dist_fn=DistributionFunctionFactoryIndependentGaussians(), ), ) .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) .with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .with_logger_factory(LoggerFactoryDefault("tensorboard")) - .build_default_seeded_experiments(num_experiments) + .build_seeded_collection(num_experiments) ) if run_experiments_sequentially: launcher = RegisteredExpLauncher.sequential.create_launcher() else: launcher = RegisteredExpLauncher.joblib.create_launcher() - launcher.launch(experiments) - - return persistence_dir - + experiment_collection.run(launcher) -def eval_experiments(log_dir: str) -> RLiableExperimentResult: - """Evaluate the experiments in the given log directory using the rliable API.""" - rliable_result = RLiableExperimentResult.load_from_disk(log_dir) + rliable_result = RLiableExperimentResult.load_from_disk(persistence_dir) rliable_result.eval_results(show_plots=True, save_plots=True) return rliable_result if __name__ == "__main__": - log_dir = logging.run_cli(main, level=logging.INFO) - assert isinstance(log_dir, str) # for mypy - evaluation_result = eval_experiments(log_dir) + result = logging.run_cli(main, level=logging.INFO) diff --git a/test/highlevel/test_experiment_builder.py b/test/highlevel/test_experiment_builder.py index 0ba8a7bac..cb52c5ae3 100644 --- a/test/highlevel/test_experiment_builder.py +++ b/test/highlevel/test_experiment_builder.py @@ -79,30 +79,3 @@ def test_experiment_builder_discrete_default_params(builder_cls: type[Experiment experiment = builder.build() experiment.run(run_name="test") print(experiment) - - -def test_temp_builder_modification() -> None: - env_factory = DiscreteTestEnvFactory() - sampling_config = SamplingConfig( - num_epochs=1, - step_per_epoch=100, - num_train_envs=2, - num_test_envs=2, - ) - builder = PPOExperimentBuilder( - experiment_config=ExperimentConfig(persistence_enabled=False), - env_factory=env_factory, - sampling_config=sampling_config, - ) - original_seed = builder.experiment_config.seed - original_train_seed = builder.sampling_config.train_seed - - with builder.temp_config_mutation(): - builder.experiment_config.seed += 12345 - builder.sampling_config.train_seed += 456 - exp = builder.build() - - assert builder.experiment_config.seed == original_seed - assert builder.sampling_config.train_seed == original_train_seed - assert exp.config.seed == original_seed + 12345 - assert exp.sampling_config.train_seed == original_train_seed + 456 diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index e31a39289..8fc21cfad 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -1,12 +1,11 @@ import os import pickle from abc import abstractmethod -from collections.abc import Iterator, Sequence -from contextlib import contextmanager -from copy import copy +from collections.abc import Sequence +from copy import deepcopy from dataclasses import dataclass from pprint import pformat -from typing import Self, Dict, Any +from typing import TYPE_CHECKING, Any, Self, Union, cast import numpy as np import torch @@ -85,6 +84,10 @@ from tianshou.utils.net.common import ModuleType from tianshou.utils.string import ToStringMixin +if TYPE_CHECKING: + from tianshou.evaluation.launcher import ExpLauncher, RegisteredExpLauncher + + log = logging.getLogger(__name__) @@ -157,19 +160,6 @@ def __init__( self.logger_factory = logger_factory self.name = name - def get_seeding_info_as_str(self) -> str: - """Useful for creating unique experiment names based on seeds. - - A typical example is to do `experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`. - """ - return "_".join( - [ - f"exp_seed={self.config.seed}", - f"train_seed={self.sampling_config.train_seed}", - f"test_seed={self.sampling_config.test_seed}", - ], - ) - @classmethod def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment": """Restores an experiment from a previously stored pickle. @@ -184,6 +174,20 @@ def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experim experiment.config.policy_restore_directory = directory return experiment + def get_seeding_info_as_str(self) -> str: + """Returns information on the seeds used in the experiment as a string. + + This can be useful for creating unique experiment names based on seeds, e.g. + A typical example is to do `experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`. + """ + return "_".join( + [ + f"exp_seed={self.config.seed}", + f"train_seed={self.sampling_config.train_seed}", + f"test_seed={self.sampling_config.test_seed}", + ], + ) + def _set_seed(self) -> None: seed = self.config.seed log.info(f"Setting random seed {seed}") @@ -206,7 +210,7 @@ def run( run_name: str | None = None, logger_run_id: str | None = None, raise_error_on_dirname_collision: bool = True, - **kwargs: Dict[str, Any], + **kwargs: dict[str, Any], ) -> ExperimentResult: """Run the experiment and return the results. @@ -225,7 +229,7 @@ def run( # backward compatibility _experiment_name = kwargs.pop("experiment_name", None) if _experiment_name is not None: - run_name = _experiment_name + run_name = cast(str, _experiment_name) deprecation( "Parameter run_name should now be used instead of experiment_name. " "Support for experiment_name will be removed in the future.", @@ -352,6 +356,18 @@ def _watch_agent( ) +class ExperimentCollection: + def __init__(self, experiments: list[Experiment]): + self.experiments = experiments + + def run(self, launcher: Union["ExpLauncher", "RegisteredExpLauncher"]) -> None: + from tianshou.evaluation.launcher import RegisteredExpLauncher + + if isinstance(launcher, RegisteredExpLauncher): + launcher = launcher.create_launcher() + launcher.launch(experiments=self.experiments) + + class ExperimentBuilder: def __init__( self, @@ -372,14 +388,8 @@ def __init__( self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks() self._name: str = self.__class__.__name__.replace("Builder", "") + "_" + datetime_tag() - @contextmanager - def temp_config_mutation(self) -> Iterator[Self]: - """Returns the builder instance where the configs can be modified without affecting the current instance.""" - original_sampling_config = copy(self.sampling_config) - original_experiment_config = copy(self.experiment_config) - yield self - self.sampling_config = original_sampling_config - self.experiment_config = original_experiment_config + def copy(self) -> Self: + return deepcopy(self) @property def experiment_config(self) -> ExperimentConfig: @@ -496,12 +506,9 @@ def _get_optim_factory(self) -> OptimizerFactory: else: return self._optim_factory - def build(self, add_seeding_info_to_name: bool = False) -> Experiment: + def build(self) -> Experiment: """Creates the experiment based on the options specified via this builder. - :param add_seeding_info_to_name: whether to add a postfix to the experiment name that contains - info about the training seeds. Useful for creating multiple experiments that only differ - by seeds. :return: the experiment """ agent_factory = self._create_agent_factory() @@ -516,27 +523,24 @@ def build(self, add_seeding_info_to_name: bool = False) -> Experiment: name=self._name, logger_factory=self._logger_factory, ) - if add_seeding_info_to_name: - if not experiment.name: - experiment.name = experiment.get_seeding_info_as_str() - else: - experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}" return experiment - def build_default_seeded_experiments(self, num_experiments: int) -> list[Experiment]: - """Creates a list of experiments with non-overlapping seeds, starting from the configured seed. + def build_seeded_collection(self, num_experiments: int) -> ExperimentCollection: + """Creates a collection of experiments with non-overlapping random seeds, starting from the configured seed. - Each experiment will have a unique name that is created from the original experiment name and the seeds used. + Each experiment in the collection will have a unique name that is created from the original experiment name and the seeds used. """ num_train_envs = self.sampling_config.num_train_envs seeded_experiments = [] for i in range(num_experiments): - with self.temp_config_mutation(): - self.experiment_config.seed += i - self.sampling_config.train_seed += i * num_train_envs - seeded_experiments.append(self.build(add_seeding_info_to_name=True)) - return seeded_experiments + builder = self.copy() + builder.experiment_config.seed += i + builder.sampling_config.train_seed += i * num_train_envs + experiment = builder.build() + experiment.name += f"_{experiment.get_seeding_info_as_str()}" + seeded_experiments.append(experiment) + return ExperimentCollection(seeded_experiments) class _BuilderMixinActorFactory(ActorFutureProviderProtocol): From ea0c4f1a30da926d7e1c27e6359e755edc3098f2 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 30 Apr 2024 17:31:48 +0200 Subject: [PATCH 3/4] Update change log with changes from #1131 --- CHANGELOG.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2c62cca4f..361fd57c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,9 +16,12 @@ - `Experiment` now has a `name` attribute, which can be set using `ExperimentBuilder.with_name` and which determines the default run name and therefore the persistence subdirectory. It can still be overridden in `Experiment.run()`, the new parameter name being `run_name` rather than - `experiment_name` (although the latter will still be interpreted correctly). #1074 -- When building an `Experiment` from an `ExperimentConfig`, the user has the option to add info about seeds to the name. #1074 -- New method in `ExperimentConfig` called `build_default_seeded_experiments`. #1074 + `experiment_name` (although the latter will still be interpreted correctly). #1074 #1131 + - Add class `ExperimentCollection` for the convenient execution of multiple experiment runs #1131 + - `ExperimentBuilder`: + - Add method `build_seeded_collection` for the sound creation of multiple + experiments with varying random seeds #1131 + - Add method `copy` to facilitate the creation of multiple experiments from a single builder #1131 - `SamplingConfig` has an explicit training seed, `test_seed` is inferred. #1074 - New `evaluation` package for repeating the same experiment with multiple seeds and aggregating the results (important extension!). Launchers for parallelization currently in alpha state. #1074 From 393e55aa5892e96a8095f430517654585f4b561e Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 30 Apr 2024 17:47:06 +0200 Subject: [PATCH 4/4] Improve change log #1129 --- CHANGELOG.md | 60 +++++++++++++++++++++++++++++----------------------- 1 file changed, 34 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 361fd57c1..35db33ab4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,34 +3,42 @@ ## Release 1.1.0 ### Api Extensions -- Batch received two new methods: `to_dict` and `to_list_of_dicts`. #1063 -- `Collector`s can now be closed, and their reset is more granular. #1063 -- Trainers can control whether collectors should be reset prior to training. #1063 -- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063 -- `SamplingConfig` supports `batch_size=None`. #1077 -- Batch received new methods: `to_numpy_` and `to_torch_`. #1098, #1117 -- `to_dict` in Batch supports also non-recursive conversion. #1098 -- Batch `__eq__` implemented, semantic equality check of batches is now possible. #1098 +- `data`: + - `Batch`: + - Add methods `to_dict` and `to_list_of_dicts`. #1063 #1098 + - Add methods `to_numpy_` and `to_torch_`. #1098, #1117 + - Add `__eq__` (semantic equality check). #1098 + - `data.collector`: + - `Collector`: + - Add method `close` #1063 + - Method `reset` is now more granular (new flags controlling behavior). #1063 + - `CollectStats`: Add convenience constructor `with_autogenerated_stats`. #1063 +- `trainer`: + - Trainers can now control whether collectors should be reset prior to training. #1063 - `Batch.keys()` deprecated in favor of `Batch.get_keys()` (needed to make iteration consistent with naming) #1105. -- `highlevel.experiment`: - - `Experiment` now has a `name` attribute, which can be set using `ExperimentBuilder.with_name` and - which determines the default run name and therefore the persistence subdirectory. - It can still be overridden in `Experiment.run()`, the new parameter name being `run_name` rather than - `experiment_name` (although the latter will still be interpreted correctly). #1074 #1131 - - Add class `ExperimentCollection` for the convenient execution of multiple experiment runs #1131 - - `ExperimentBuilder`: - - Add method `build_seeded_collection` for the sound creation of multiple - experiments with varying random seeds #1131 - - Add method `copy` to facilitate the creation of multiple experiments from a single builder #1131 -- `SamplingConfig` has an explicit training seed, `test_seed` is inferred. #1074 -- New `evaluation` package for repeating the same experiment with multiple seeds and aggregating the results (important extension!). -Launchers for parallelization currently in alpha state. #1074 +- `highlevel`: + - `SamplingConfig`: + - Add support for `batch_size=None`. #1077 + - Add `training_seed` for explicit seeding of training and test environments, the `test_seed` is inferred from `training_seed`. #1074 + - `highlevel.experiment`: + - `Experiment` now has a `name` attribute, which can be set using `ExperimentBuilder.with_name` and + which determines the default run name and therefore the persistence subdirectory. + It can still be overridden in `Experiment.run()`, the new parameter name being `run_name` rather than + `experiment_name` (although the latter will still be interpreted correctly). #1074 #1131 + - Add class `ExperimentCollection` for the convenient execution of multiple experiment runs #1131 + - `ExperimentBuilder`: + - Add method `build_seeded_collection` for the sound creation of multiple + experiments with varying random seeds #1131 + - Add method `copy` to facilitate the creation of multiple experiments from a single builder #1131 +- `evaluation`: New package for repeating the same experiment with multiple seeds and aggregating the results. #1074 + - The module `evaluation.launchers` for parallelization is currently in alpha state. - Loggers can now restore the logged data into python by using the new `restore_logged_data` method. #1074 -- `continuous.Critic`: - - Add flag `apply_preprocess_net_to_obs_only` to allow the - preprocessing network to be applied to the observations only (without - the actions concatenated), which is essential for the case where we want - to reuse the actor's preprocessing network #1128 +- `utils.net`: + - `continuous.Critic`: + - Add flag `apply_preprocess_net_to_obs_only` to allow the + preprocessing network to be applied to the observations only (without + the actions concatenated), which is essential for the case where we want + to reuse the actor's preprocessing network #1128 ### Fixes - `CriticFactoryReuseActor`: Enable the Critic flag `apply_preprocess_net_to_obs_only` for continuous critics,