From 0548533ad643df74fcc86f2b93ec1e596305d27a Mon Sep 17 00:00:00 2001 From: Robert Mueller <2robert.mueller@gmail.com> Date: Tue, 23 Jan 2024 10:34:11 +0100 Subject: [PATCH 1/4] Draft for explicit seed mechanism of train seed, specific test seeds --- examples/mujoco/mujoco_env.py | 33 ++++++- examples/mujoco/mujoco_ppo_hl_experiment.py | 103 ++++++++++++++++++++ test/highlevel/test_eval_procedure.py | 79 +++++++++++++++ tianshou/env/venvs.py | 6 +- tianshou/highlevel/experiment.py | 18 +++- tianshou/hpo_sweepers/__init__.py | 0 6 files changed, 231 insertions(+), 8 deletions(-) create mode 100644 examples/mujoco/mujoco_ppo_hl_experiment.py create mode 100644 test/highlevel/test_eval_procedure.py create mode 100644 tianshou/hpo_sweepers/__init__.py diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py index 3a4812108..125516c22 100644 --- a/examples/mujoco/mujoco_env.py +++ b/examples/mujoco/mujoco_env.py @@ -1,6 +1,7 @@ import logging import pickle import warnings +from typing import Literal import gymnasium as gym @@ -17,7 +18,11 @@ log = logging.getLogger(__name__) -def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool): +def make_mujoco_env(task: str, seed: int, num_train_envs: int, + num_test_envs: int, obs_norm: bool, + train_seed_mechanism: Literal["consecutive"]|Literal["repeat"]| None = None, + test_seeds: tuple[int] | None = None + ): #makes mujoco envs, name is not really honest """Wrapper function for Mujoco env. If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env. @@ -27,6 +32,7 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in if envpool is not None: train_envs = env = envpool.make_gymnasium(task, num_envs=num_train_envs, seed=seed) test_envs = envpool.make_gymnasium(task, num_envs=num_test_envs, seed=seed) + #todo robert check how seeding is done here else: warnings.warn( "Recommend using envpool (pip install envpool) " @@ -35,8 +41,22 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in env = gym.make(task) train_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_train_envs)]) test_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)]) - train_envs.seed(seed) - test_envs.seed(seed) + if train_seed_mechanism == "consecutive": + train_envs.seed([seed + i for i in range(num_train_envs)]) + elif train_seed_mechanism == "repeat": + train_envs.seed([seed for _ in range(num_train_envs)]) + elif train_seed_mechanism is None: + train_envs.seed(seed) + else: + NotImplementedError(f"train_seed_mechanism {train_seed_mechanism} not implemented") + + #train_envs.seed(seed) # the make_mujoco_env function requieres seed to be an int, whereas the seed function allows for seed in int | list[int] | None with very differnt behavior + if test_seeds is None: + test_envs.seed(seed) + else: + assert len(test_seeds) == num_test_envs + test_envs.seed(test_seeds) + if obs_norm: # obs norm wrapper train_envs = VectorEnvNormObs(train_envs) @@ -69,10 +89,13 @@ def restore(self, event: RestoreEvent, world: World): class MujocoEnvFactory(EnvFactory): - def __init__(self, task: str, seed: int, obs_norm=True): + def __init__(self, task: str, seed: int, obs_norm=True, + train_seed_mechanism: Literal["consecutive"]|Literal["repeat"]|None = None, test_seeds: tuple[int]|None = None): self.task = task self.seed = seed self.obs_norm = obs_norm + self.train_seed_mechanism = train_seed_mechanism + self.test_seeds = test_seeds 
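# Illustration only (not part of the patch); example values are arbitrary.
# The two explicit mechanisms stored above resolve to the following per-env seed lists:
num_train_envs, seed = 4, 42
consecutive_seeds = [seed + i for i in range(num_train_envs)]  # [42, 43, 44, 45]
repeated_seeds = [seed for _ in range(num_train_envs)]         # [42, 42, 42, 42]
# With train_seed_mechanism=None, train_envs.seed(seed) receives the plain int, which the
# vector env's seed() expands to [seed + i for i in range(env_num)] (see the venvs.py hunk
# below), i.e. the same list as "consecutive", just computed inside the vector env.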
def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments: env, train_envs, test_envs = make_mujoco_env( @@ -81,6 +104,8 @@ def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousE num_train_envs=num_training_envs, num_test_envs=num_test_envs, obs_norm=self.obs_norm, + train_seed_mechanism=self.train_seed_mechanism, + test_seeds=self.test_seeds, ) envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs) if self.obs_norm: diff --git a/examples/mujoco/mujoco_ppo_hl_experiment.py b/examples/mujoco/mujoco_ppo_hl_experiment.py new file mode 100644 index 000000000..c9eab23ef --- /dev/null +++ b/examples/mujoco/mujoco_ppo_hl_experiment.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 + +import os +from collections.abc import Sequence +from typing import Literal + +import torch + +from examples.mujoco.mujoco_env import MujocoEnvFactory +from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.experiment import ( + ExperimentConfig, + EvaluationProtocalExperimentConfig, + PPOExperimentBuilder, +) +from tianshou.highlevel.params.dist_fn import ( + DistributionFunctionFactoryIndependentGaussians, +) +from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear +from tianshou.highlevel.params.policy_params import PPOParams +from tianshou.utils import logging +from tianshou.utils.logging import datetime_tag + + +def main( + experiment_config: EvaluationProtocalExperimentConfig, + task: str = "Ant-v4", + buffer_size: int = 4096, + hidden_sizes: Sequence[int] = (64, 64), + lr: float = 3e-4, + gamma: float = 0.99, + epoch: int = 100, + step_per_epoch: int = 30000, + step_per_collect: int = 2048, + repeat_per_collect: int = 10, + batch_size: int = 64, + training_num: int = 4, #64, + test_num: int = 2, #10, + rew_norm: bool = True, + vf_coef: float = 0.25, + ent_coef: float = 0.0, + gae_lambda: float = 0.95, + bound_action_method: Literal["clip", "tanh"] | None = "clip", + lr_decay: bool = True, + max_grad_norm: float = 0.5, + eps_clip: float = 0.2, + dual_clip: float | None = None, + value_clip: bool = False, + norm_adv: bool = False, + recompute_adv: bool = True, +): + + log_name = os.path.join(task, "ppo", str(experiment_config.seed), datetime_tag()) + + sampling_config = SamplingConfig( + num_epochs=epoch, + step_per_epoch=step_per_epoch, + batch_size=batch_size, + num_train_envs=training_num, + num_test_envs=test_num, + buffer_size=buffer_size, + step_per_collect=step_per_collect, + repeat_per_collect=repeat_per_collect, + ) + + env_factory = MujocoEnvFactory(task, + experiment_config.seed, + obs_norm=True, + train_seed_mechanism=experiment_config.train_seed_mechanism, + test_seeds=experiment_config.test_seeds) + + experiment = ( + PPOExperimentBuilder(env_factory, experiment_config, sampling_config) + .with_ppo_params( + PPOParams( + discount_factor=gamma, + gae_lambda=gae_lambda, + action_bound_method=bound_action_method, + reward_normalization=rew_norm, + ent_coef=ent_coef, + vf_coef=vf_coef, + max_grad_norm=max_grad_norm, + value_clip=value_clip, + advantage_normalization=norm_adv, + eps_clip=eps_clip, + dual_clip=dual_clip, + recompute_advantage=recompute_adv, + lr=lr, + lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config) + if lr_decay + else None, + dist_fn=DistributionFunctionFactoryIndependentGaussians(), + ), + ) + .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True) + .with_critic_factory_default(hidden_sizes, torch.nn.Tanh) + .build() + 
) + experiment.run(log_name) + + +if __name__ == "__main__": + logging.run_cli(main) diff --git a/test/highlevel/test_eval_procedure.py b/test/highlevel/test_eval_procedure.py new file mode 100644 index 000000000..a7bd4a8f9 --- /dev/null +++ b/test/highlevel/test_eval_procedure.py @@ -0,0 +1,79 @@ +from test.highlevel.env_factory import ContinuousTestEnvFactory, DiscreteTestEnvFactory + +import pytest + +from tianshou.highlevel.config import SamplingConfig +from tianshou.highlevel.experiment import ( + A2CExperimentBuilder, + DDPGExperimentBuilder, + DiscreteSACExperimentBuilder, + DQNExperimentBuilder, + IQNExperimentBuilder, + PGExperimentBuilder, + PPOExperimentBuilder, + REDQExperimentBuilder, + SACExperimentBuilder, + TD3ExperimentBuilder, + TRPOExperimentBuilder, + ExperimentConfig, + EvaluationProtocalExperimentConfig +) +from examples.mujoco.mujoco_env import MujocoEnvFactory + + +def test_standard_experiment_config(): + experiment_config = ExperimentConfig + sampling_config = SamplingConfig( + num_epochs=1, + step_per_epoch=100, + num_train_envs=2, + num_test_envs=2, + ) + env_factory = MujocoEnvFactory(task="Ant-v4", + seed=experiment_config.seed, + obs_norm=True,) + + ppo = PPOExperimentBuilder( + experiment_config=experiment_config, + env_factory=env_factory, + sampling_config=sampling_config, + ) + experiment = ppo.build() + experiment.run("test") + print(experiment) + + +@pytest.mark.parametrize("experiment_config", + [ + EvaluationProtocalExperimentConfig( + persistence_enabled=False, + train_seed_mechanism="consecutive", + test_seeds=(2,3)), + EvaluationProtocalExperimentConfig(persistence_enabled=False, train_seed_mechanism="repeat", + test_seeds=(2,3)), + EvaluationProtocalExperimentConfig(persistence_enabled=False, train_seed_mechanism=None, + test_seeds=(2,3)), + + ], +) +def test_experiment_builder_continuous_default_params(experiment_config): + sampling_config = SamplingConfig( + num_epochs=1, + step_per_epoch=100, + num_train_envs=2, + num_test_envs=2, + ) + env_factory = MujocoEnvFactory(task="Ant-v4", + seed=experiment_config.seed, + obs_norm=True, + train_seed_mechanism=experiment_config.train_seed_mechanism, + test_seeds=experiment_config.test_seeds) + + ppo = PPOExperimentBuilder( + experiment_config=experiment_config, + env_factory=env_factory, + sampling_config=sampling_config, + ) + experiment = ppo.build() + experiment.run("test") + print(experiment) diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index a9b0aff74..f96255d0a 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -315,7 +315,7 @@ def step( np.stack(info_list), ) - def seed(self, seed: int | list[int] | None = None) -> list[list[int] | None]: + def seed(self, seed: int | list[int] |tuple[int]| None = None) -> list[list[int] | None]: """Set the seed for all environments. Accept ``None``, an int (which will extend ``i`` to @@ -323,11 +323,11 @@ def seed(self, seed: int | list[int] | None = None) -> list[list[int] | None]: :return: The list of seeds used in this env's random number generators. The first value in the list should be the "main" seed, or the value - which a reproducer pass to "seed". + which a reproducer passes to "seed". 
""" self._assert_is_not_closed() seed_list: list[None] | list[int] - if seed is None: + if seed is None: #todo robert check this can happen, results in error when put in config seed_list = [seed] * self.env_num elif isinstance(seed, int): seed_list = [seed + i for i in range(self.env_num)] diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 3989d3583..c33ad71e2 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from dataclasses import dataclass from pprint import pformat -from typing import Self +from typing import Self, Literal import numpy as np import torch @@ -114,6 +114,22 @@ class ExperimentConfig: policy_persistence_mode: PolicyPersistence.Mode = PolicyPersistence.Mode.POLICY """Controls the way in which the policy is persisted""" +@dataclass +class EvaluationProtocalExperimentConfig(ExperimentConfig): + train_seed_mechanism: Literal["consecutive"]|Literal["repeat"] = "consecutive" # or repeat, if all train seeds are supposed to be the same or consecutive numbers, compare seed function in venvs.py + test_seeds : tuple[int] = (22,49,1995,123456) + +from dataclasses import dataclass + +@dataclass +class Person: + name: str + age: int + +@dataclass +class Employee(Person): + employee_id: int + department: str @dataclass class ExperimentResult: diff --git a/tianshou/hpo_sweepers/__init__.py b/tianshou/hpo_sweepers/__init__.py new file mode 100644 index 000000000..e69de29bb From ee56baf78a46ceda58bd4915ee352af620b7a06d Mon Sep 17 00:00:00 2001 From: Robert Mueller <2robert.mueller@gmail.com> Date: Wed, 24 Jan 2024 12:30:41 +0100 Subject: [PATCH 2/4] Add enum instead of literal, fix subset of comments --- examples/mujoco/mujoco_env.py | 14 ++++---- examples/mujoco/mujoco_ppo_hl_experiment.py | 2 +- test/highlevel/test_eval_procedure.py | 21 ++++-------- tianshou/highlevel/experiment.py | 38 +++++++++++++-------- 4 files changed, 38 insertions(+), 37 deletions(-) diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py index 125516c22..f537aee63 100644 --- a/examples/mujoco/mujoco_env.py +++ b/examples/mujoco/mujoco_env.py @@ -1,12 +1,12 @@ import logging import pickle import warnings -from typing import Literal import gymnasium as gym from tianshou.env import ShmemVectorEnv, VectorEnvNormObs from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory +from tianshou.highlevel.experiment import TrainSeedMechanism from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent from tianshou.highlevel.world import World @@ -20,7 +20,7 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool, - train_seed_mechanism: Literal["consecutive"]|Literal["repeat"]| None = None, + train_seed_mechanism: TrainSeedMechanism = TrainSeedMechanism.NONE, test_seeds: tuple[int] | None = None ): #makes mujoco envs, name is not really honest """Wrapper function for Mujoco env. 
@@ -41,16 +41,15 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, env = gym.make(task) train_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_train_envs)]) test_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)]) - if train_seed_mechanism == "consecutive": + if train_seed_mechanism.is_consecutive(): train_envs.seed([seed + i for i in range(num_train_envs)]) - elif train_seed_mechanism == "repeat": + elif train_seed_mechanism.is_repeat(): train_envs.seed([seed for _ in range(num_train_envs)]) - elif train_seed_mechanism is None: + elif train_seed_mechanism.is_none(): train_envs.seed(seed) else: NotImplementedError(f"train_seed_mechanism {train_seed_mechanism} not implemented") - #train_envs.seed(seed) # the make_mujoco_env function requieres seed to be an int, whereas the seed function allows for seed in int | list[int] | None with very differnt behavior if test_seeds is None: test_envs.seed(seed) else: @@ -90,7 +89,8 @@ def restore(self, event: RestoreEvent, world: World): class MujocoEnvFactory(EnvFactory): def __init__(self, task: str, seed: int, obs_norm=True, - train_seed_mechanism: Literal["consecutive"]|Literal["repeat"]|None = None, test_seeds: tuple[int]|None = None): + train_seed_mechanism: TrainSeedMechanism = TrainSeedMechanism.NONE, + test_seeds: tuple[int]|None = None): self.task = task self.seed = seed self.obs_norm = obs_norm diff --git a/examples/mujoco/mujoco_ppo_hl_experiment.py b/examples/mujoco/mujoco_ppo_hl_experiment.py index c9eab23ef..257649621 100644 --- a/examples/mujoco/mujoco_ppo_hl_experiment.py +++ b/examples/mujoco/mujoco_ppo_hl_experiment.py @@ -10,7 +10,7 @@ from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( ExperimentConfig, - EvaluationProtocalExperimentConfig, + EvaluationProtocolExperimentConfig, PPOExperimentBuilder, ) from tianshou.highlevel.params.dist_fn import ( diff --git a/test/highlevel/test_eval_procedure.py b/test/highlevel/test_eval_procedure.py index a7bd4a8f9..4c96ef022 100644 --- a/test/highlevel/test_eval_procedure.py +++ b/test/highlevel/test_eval_procedure.py @@ -4,19 +4,10 @@ from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.experiment import ( - A2CExperimentBuilder, - DDPGExperimentBuilder, - DiscreteSACExperimentBuilder, - DQNExperimentBuilder, - IQNExperimentBuilder, - PGExperimentBuilder, PPOExperimentBuilder, - REDQExperimentBuilder, - SACExperimentBuilder, - TD3ExperimentBuilder, - TRPOExperimentBuilder, ExperimentConfig, - EvaluationProtocalExperimentConfig + EvaluationProtocolExperimentConfig, + TrainSeedMechanism ) from examples.mujoco.mujoco_env import MujocoEnvFactory @@ -45,13 +36,13 @@ def test_standard_experiment_config(): @pytest.mark.parametrize("experiment_config", [ - EvaluationProtocalExperimentConfig( + EvaluationProtocolExperimentConfig( persistence_enabled=False, - train_seed_mechanism="consecutive", + train_seed_mechanism=TrainSeedMechanism.CONSECUTIVE, test_seeds=(2,3)), - EvaluationProtocalExperimentConfig(persistence_enabled=False, train_seed_mechanism="repeat", + EvaluationProtocolExperimentConfig(persistence_enabled=False, train_seed_mechanism=TrainSeedMechanism.REPEAT, test_seeds=(2,3)), - EvaluationProtocalExperimentConfig(persistence_enabled=False, train_seed_mechanism=None, + EvaluationProtocolExperimentConfig(persistence_enabled=False, train_seed_mechanism=TrainSeedMechanism.NONE, test_seeds=(2,3)), ], diff --git a/tianshou/highlevel/experiment.py 
b/tianshou/highlevel/experiment.py index c33ad71e2..b346f84b8 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -3,6 +3,7 @@ from abc import abstractmethod from collections.abc import Sequence from dataclasses import dataclass +from enum import Enum from pprint import pformat from typing import Self, Literal @@ -84,7 +85,21 @@ log = logging.getLogger(__name__) +class TrainSeedMechanism(Enum): + """Enumeration of mechanisms to create the train seed set.""" + REPEAT = "repeat" + CONSECUTIVE = "consecutive" + NONE = "none" + + def is_repeat(self) -> bool: + return self == TrainSeedMechanism.REPEAT + + def is_consecutive(self) -> bool: + return self == TrainSeedMechanism.CONSECUTIVE + + def is_none(self) -> bool: + return self == TrainSeedMechanism.NONE @dataclass class ExperimentConfig: """Generic config for setting up the experiment, not RL or training specific.""" @@ -114,22 +129,17 @@ class ExperimentConfig: policy_persistence_mode: PolicyPersistence.Mode = PolicyPersistence.Mode.POLICY """Controls the way in which the policy is persisted""" -@dataclass -class EvaluationProtocalExperimentConfig(ExperimentConfig): - train_seed_mechanism: Literal["consecutive"]|Literal["repeat"] = "consecutive" # or repeat, if all train seeds are supposed to be the same or consecutive numbers, compare seed function in venvs.py +@dataclass(kw_only=True) +class EvaluationProtocolExperimentConfig(ExperimentConfig): + train_seed_mechanism: TrainSeedMechanism = TrainSeedMechanism.CONSECUTIVE + """Whether all train seeds are supposed to be the same or consecutive + numbers, compare seed function in venvs.py """ test_seeds : tuple[int] = (22,49,1995,123456) + """Set of seeds to use during testing""" + def __post_init__(self): + assert set(self.test_seeds).isdisjoint({self.seed}) + #todo adapt this for more complex version of train and test seed sets -from dataclasses import dataclass - -@dataclass -class Person: - name: str - age: int - -@dataclass -class Employee(Person): - employee_id: int - department: str @dataclass class ExperimentResult: From 945891e8d24e2235e02eeb576982612edfafce49 Mon Sep 17 00:00:00 2001 From: Robert Mueller <2robert.mueller@gmail.com> Date: Wed, 24 Jan 2024 22:44:47 +0100 Subject: [PATCH 3/4] Fix tuple type hint, add seed of env generating a transtition to the buffer --- examples/mujoco/mujoco_env.py | 6 +++--- examples/mujoco/mujoco_ppo_hl_experiment.py | 8 ++++---- test/highlevel/test_eval_procedure.py | 3 ++- tianshou/data/collector.py | 8 +++++++- tianshou/highlevel/agent.py | 7 ++++++- tianshou/highlevel/env.py | 4 +++- tianshou/highlevel/experiment.py | 8 ++++++-- 7 files changed, 31 insertions(+), 13 deletions(-) diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py index f537aee63..fd0313555 100644 --- a/examples/mujoco/mujoco_env.py +++ b/examples/mujoco/mujoco_env.py @@ -21,7 +21,7 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool, train_seed_mechanism: TrainSeedMechanism = TrainSeedMechanism.NONE, - test_seeds: tuple[int] | None = None + test_seeds: tuple[int, ...] | None = None ): #makes mujoco envs, name is not really honest """Wrapper function for Mujoco env. @@ -90,7 +90,7 @@ def restore(self, event: RestoreEvent, world: World): class MujocoEnvFactory(EnvFactory): def __init__(self, task: str, seed: int, obs_norm=True, train_seed_mechanism: TrainSeedMechanism = TrainSeedMechanism.NONE, - test_seeds: tuple[int]|None = None): + test_seeds: tuple[int, ...] 
| None = None): self.task = task self.seed = seed self.obs_norm = obs_norm @@ -107,7 +107,7 @@ def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousE train_seed_mechanism=self.train_seed_mechanism, test_seeds=self.test_seeds, ) - envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs) + envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs, test_seeds = self.test_seeds) if self.obs_norm: envs.set_persistence(MujocoEnvObsRmsPersistence()) return envs diff --git a/examples/mujoco/mujoco_ppo_hl_experiment.py b/examples/mujoco/mujoco_ppo_hl_experiment.py index 257649621..1b0244bf8 100644 --- a/examples/mujoco/mujoco_ppo_hl_experiment.py +++ b/examples/mujoco/mujoco_ppo_hl_experiment.py @@ -23,7 +23,7 @@ def main( - experiment_config: EvaluationProtocalExperimentConfig, + experiment_config: EvaluationProtocolExperimentConfig, task: str = "Ant-v4", buffer_size: int = 4096, hidden_sizes: Sequence[int] = (64, 64), @@ -34,7 +34,7 @@ def main( step_per_collect: int = 2048, repeat_per_collect: int = 10, batch_size: int = 64, - training_num: int = 4, #64, + training_num: int = 3, #64, test_num: int = 2, #10, rew_norm: bool = True, vf_coef: float = 0.25, @@ -49,7 +49,7 @@ def main( norm_adv: bool = False, recompute_adv: bool = True, ): - + b = 5 log_name = os.path.join(task, "ppo", str(experiment_config.seed), datetime_tag()) sampling_config = SamplingConfig( @@ -96,7 +96,7 @@ def main( .with_critic_factory_default(hidden_sizes, torch.nn.Tanh) .build() ) - experiment.run(log_name) + experiment.run(log_name, experiment_config.record_seed_of_transition_to_buffer_test) if __name__ == "__main__": diff --git a/test/highlevel/test_eval_procedure.py b/test/highlevel/test_eval_procedure.py index 4c96ef022..de4ce6f53 100644 --- a/test/highlevel/test_eval_procedure.py +++ b/test/highlevel/test_eval_procedure.py @@ -66,5 +66,6 @@ def test_experiment_builder_continuous_default_params(experiment_config): sampling_config=sampling_config, ) experiment = ppo.build() - experiment.run("test") + experiment.run("test", + record_seed_of_transition_to_buffer_test=experiment_config.record_seed_of_transition_to_buffer_test) print(experiment) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index f188f3ca0..d096823e2 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -94,6 +94,8 @@ def __init__( buffer: ReplayBuffer | None = None, preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None, exploration_noise: bool = False, + record_seed_of_transition_to_buffer_test: bool = False, + test_seeds: tuple[int, ...] 
| None = None, ) -> None: super().__init__() if isinstance(env, gym.Env) and not hasattr(env, "__len__"): @@ -111,6 +113,8 @@ def __init__( self.data: RolloutBatchProtocol # avoid creating attribute outside __init__ self.reset(False) + self.record_seed_of_transition_to_buffer_test = record_seed_of_transition_to_buffer_test + self.test_seeds = test_seeds def _assign_buffer(self, buffer: ReplayBuffer | None) -> None: """Check if the buffer matches the constraint.""" @@ -317,7 +321,9 @@ def collect( ready_env_ids, ) done = np.logical_or(terminated, truncated) - + if self.record_seed_of_transition_to_buffer_test: + for env_id, active_id in enumerate(ready_env_ids): + info[env_id]["seed"] = self.test_seeds[active_id] self.data.update( obs_next=obs_next, rew=rew, diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index 727bfb244..61639c827 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -93,6 +93,8 @@ def create_train_test_collector( self, policy: BasePolicy, envs: Environments, + record_seed_of_transition_to_buffer_test: bool = False, + test_seeds: tuple[int, ...] | None = None, ) -> tuple[Collector, Collector]: buffer_size = self.sampling_config.buffer_size train_envs = envs.train_envs @@ -113,7 +115,10 @@ def create_train_test_collector( ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next, ) train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, envs.test_envs) + test_collector = Collector(policy, envs.test_envs, + record_seed_of_transition_to_buffer_test=record_seed_of_transition_to_buffer_test, + test_seeds = test_seeds) + self.test_seeds = test_seeds if self.sampling_config.start_timesteps > 0: log.info( f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})", diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index c78b72c6c..21fc9f8fb 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -147,9 +147,11 @@ def get_type(self) -> EnvType: class ContinuousEnvironments(Environments): """Represents (vectorized) continuous environments.""" - def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): + def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv, + test_seeds: tuple[int, ...]|None = None): super().__init__(env, train_envs, test_envs) self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env) + self.test_seeds = test_seeds @staticmethod def from_factory( diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index b346f84b8..e0b18aaaf 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from enum import Enum from pprint import pformat -from typing import Self, Literal +from typing import Self, Literal, Tuple import numpy as np import torch @@ -134,8 +134,9 @@ class EvaluationProtocolExperimentConfig(ExperimentConfig): train_seed_mechanism: TrainSeedMechanism = TrainSeedMechanism.CONSECUTIVE """Whether all train seeds are supposed to be the same or consecutive numbers, compare seed function in venvs.py """ - test_seeds : tuple[int] = (22,49,1995,123456) + test_seeds : tuple[int,...] 
= (4,22)
     """Set of seeds to use during testing"""
+    record_seed_of_transition_to_buffer_test: bool = True
     def __post_init__(self):
         assert set(self.test_seeds).isdisjoint({self.seed})
         #todo adapt this for more complex version of train and test seed sets
@@ -213,6 +214,7 @@ def run(
         self,
         experiment_name: str | None = None,
         logger_run_id: str | None = None,
+        record_seed_of_transition_to_buffer_test: bool = False,
     ) -> ExperimentResult:
         """Run the experiment and return the results.
 
@@ -283,6 +285,8 @@ def run(
         train_collector, test_collector = self.agent_factory.create_train_test_collector(
             policy,
             envs,
+            record_seed_of_transition_to_buffer_test=record_seed_of_transition_to_buffer_test,
+            test_seeds=envs.test_seeds
         )
 
         # create context object with all relevant instances (except trainer; added later)

From 5063f1a69b051f3a745d25e022f127ffe69ae9a7 Mon Sep 17 00:00:00 2001
From: Robert Mueller <2robert.mueller@gmail.com>
Date: Fri, 26 Jan 2024 11:24:10 +0100
Subject: [PATCH 4/4] Load TensorBoard logs into Python using tbparse

---
 tianshou/utils/evaluation_protocol_analysis.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)
 create mode 100644 tianshou/utils/evaluation_protocol_analysis.py

diff --git a/tianshou/utils/evaluation_protocol_analysis.py b/tianshou/utils/evaluation_protocol_analysis.py
new file mode 100644
index 000000000..b7819a900
--- /dev/null
+++ b/tianshou/utils/evaluation_protocol_analysis.py
@@ -0,0 +1,15 @@
+from tbparse import SummaryReader
+
+
+def read_data(log_dir):
+    """Load all TensorBoard scalars logged under ``log_dir`` into a pandas DataFrame."""
+    reader = SummaryReader(log_dir)
+    # reader.scalars holds one row per logged scalar event (columns: step, tag, value).
+    df = reader.scalars
+    print(df["tag"].unique())
+    return df
+
+
+if __name__ == "__main__":
+    log_dir = "examples/mujoco/log/Ant-v4/ppo/42/20240126-105848"
+    read_data(log_dir)
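A natural follow-up to evaluation_protocol_analysis.py is aggregating the logged test scalars per run with tbparse and pandas. The sketch below is only an illustration, not part of the patches: summarize_test_returns is a hypothetical helper, and the tag name "test/reward" is a placeholder that has to be checked against the tags the logger actually writes (e.g. via read_data above).

import pandas as pd
from tbparse import SummaryReader


def summarize_test_returns(log_dir: str, tag: str = "test/reward") -> pd.DataFrame:
    # Read every scalar event below log_dir; "dir_name" identifies the individual run directories.
    df = SummaryReader(log_dir, extra_columns={"dir_name"}).scalars
    # Keep only the requested tag and summarize its values per run.
    test_df = df[df["tag"] == tag]
    return test_df.groupby("dir_name")["value"].agg(["mean", "std", "count"])


# Example call, assuming the log layout used above:
# summarize_test_returns("examples/mujoco/log/Ant-v4/ppo/42")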