From bbc36d014e03be76d989684c9600a0194779d85b Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 4 Mar 2025 19:44:42 +0100 Subject: [PATCH 01/56] Fix missing reset in discrete_dqn --- examples/discrete/discrete_dqn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 4e52f4ce2..bcf3e7f18 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -83,7 +83,7 @@ def stop_fn(mean_rewards: float) -> bool: # watch performance policy.set_eps(eps_test) collector = ts.data.Collector[CollectStats](policy, env, exploration_noise=True) - collector.collect(n_episode=100, render=1 / 35) + collector.collect(n_episode=100, render=1 / 35, reset_before_collect=True) if __name__ == "__main__": From b5665e31644690659331d8a295cf40f3415ff6eb Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 5 Mar 2025 00:15:56 +0100 Subject: [PATCH 02/56] ActorFactoryDefault: Fix hidden sizes and activation not being passed on for discrete case --- tianshou/highlevel/module/actor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index 4a1fe5c2e..ceb1262f7 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -140,7 +140,8 @@ def _create_factory(self, envs: Environments) -> ActorFactory: raise ValueError(self.continuous_actor_type) elif env_type == EnvType.DISCRETE: factory = ActorFactoryDiscreteNet( - self.DEFAULT_HIDDEN_SIZES, + self.hidden_sizes, + activation=self.hidden_activation, softmax_output=self.discrete_softmax, ) else: From decb416ee95fdc37d0b46cd8d294ff55cb2dfd19 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 4 Mar 2025 22:30:17 +0100 Subject: [PATCH 03/56] ExperimentConfig: Do not inherit from anything (breaks jsonargparse auto-handling with defaults) --- tianshou/highlevel/experiment.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index c0be23dca..a908e1d06 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -118,7 +118,7 @@ @dataclass -class ExperimentConfig(ToStringMixin, DataclassPPrintMixin): +class ExperimentConfig: """Generic config for setting up the experiment, not RL or training specific.""" seed: int = 42 From 76a25d1225cf88dfd434e86aefd72ad9ea13575d Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 7 Mar 2025 01:52:32 +0100 Subject: [PATCH 04/56] AutoAlphaFactoryDefault: Differentiate discrete and continuous action spaces and allow coefficient to be modified, adding an informative docstring (previous implementation was reasonable only for continuous action spaces) Adjust parametrisation to match procedural example in atari_sac_hl --- examples/atari/atari_sac_hl.py | 5 +++-- tianshou/highlevel/params/alpha.py | 18 ++++++++++++++++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 124def768..211e07d57 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -45,7 +45,6 @@ def main( training_num: int = 10, test_num: int = 10, frames_stack: int = 4, - save_buffer_name: str | None = None, # TODO add support in high-level API? 
icm_lr_scale: float = 0.0, icm_reward_scale: float = 0.01, icm_forward_loss_weight: float = 0.2, @@ -84,7 +83,9 @@ def main( critic2_lr=critic_lr, gamma=gamma, tau=tau, - alpha=AutoAlphaFactoryDefault(lr=alpha_lr) if auto_alpha else alpha, + alpha=AutoAlphaFactoryDefault(lr=alpha_lr, target_entropy_coefficient=0.98) + if auto_alpha + else alpha, estimation_step=n_step, ), ) diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py index 2b662eb44..70fff0942 100644 --- a/tianshou/highlevel/params/alpha.py +++ b/tianshou/highlevel/params/alpha.py @@ -21,8 +21,18 @@ def create_auto_alpha( class AutoAlphaFactoryDefault(AutoAlphaFactory): - def __init__(self, lr: float = 3e-4): + def __init__(self, lr: float = 3e-4, target_entropy_coefficient: float = -1.0): + """ + :param lr: the learning rate for the optimizer of the alpha parameter + :param target_entropy_coefficient: the coefficient with which to multiply the target entropy; + The base value being scaled is dim(A) for continuous action spaces and log(|A|) for discrete action spaces, + i.e. with the default coefficient -1, we obtain -dim(A) and -log(dim(A)) for continuous and discrete action + spaces respectively, which gives a reasonable trade-off between exploration and exploitation. + For decidedly stochastic exploration, you can use a positive value closer to 1 (e.g. 0.98); + 1.0 would give full entropy exploration. 
+ """ self.lr = lr + self.target_entropy_coefficient = target_entropy_coefficient def create_auto_alpha( self, @@ -30,7 +40,11 @@ def create_auto_alpha( optim_factory: OptimizerFactory, device: TDevice, ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]: - target_entropy = float(-np.prod(envs.get_action_shape())) + action_dim = np.prod(envs.get_action_shape()) + if envs.get_type().is_continuous(): + target_entropy = self.target_entropy_coefficient * float(action_dim) + else: + target_entropy = self.target_entropy_coefficient * np.log(action_dim) log_alpha = torch.zeros(1, requires_grad=True, device=device) alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], self.lr) return target_entropy, log_alpha, alpha_optim From eeb6610506b2f0e1cb8ef7d78a666f1f5253c238 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 7 Mar 2025 23:26:51 +0100 Subject: [PATCH 05/56] Use DummyVectorEnv instead of Subproc in test_a2c_with_il --- test/discrete/test_a2c_with_il.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 192f24c24..4aceba169 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -8,7 +8,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv, SubprocVectorEnv +from tianshou.env import DummyVectorEnv from tianshou.policy import A2CPolicy, ImitationPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer, OnpolicyTrainer @@ -165,9 +165,8 @@ def stop_fn(mean_rewards: float) -> bool: seed=args.seed, ) else: - il_env = SubprocVectorEnv( + il_env = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.test_num)], - context="fork", ) il_env.seed(args.seed) From 8dbf0bfb2b90f23b76b6fdbfe2df1836d3ba5bbe Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 17 Mar 2025 
13:10:42 +0100 Subject: [PATCH 06/56] Fix misleading docstring and corresponding errors pertaining to optim in NPG and TRPO * Parameter optim must not include the actor parameters (as they are updated via natural gradients that are computed internally) * Fix incorrect optimizer instantiation in high-level API --- tianshou/highlevel/agent.py | 21 ++++++++++++++++++++- tianshou/policy/modelfree/npg.py | 3 ++- tianshou/policy/modelfree/trpo.py | 3 ++- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index 2ad533983..a023f0190 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -320,6 +320,10 @@ def __init__( def _get_policy_class(self) -> type[TPolicy]: pass + @abstractmethod + def _include_actor_in_optim(self) -> bool: + pass + def create_actor_critic_module_opt( self, envs: Environments, @@ -329,7 +333,10 @@ def create_actor_critic_module_opt( actor = self.actor_factory.create_module(envs, device) critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action) actor_critic = ActorCritic(actor, critic) - optim = self.optim_factory.create_optimizer(actor_critic, lr) + if self._include_actor_in_optim(): + optim = self.optim_factory.create_optimizer(actor_critic, lr) + else: + optim = self.optim_factory.create_optimizer(critic, lr) return ActorCriticOpt(actor_critic, optim) @typing.no_type_check @@ -356,21 +363,33 @@ def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]): + def _include_actor_in_optim(self) -> bool: + return True + def _get_policy_class(self) -> type[A2CPolicy]: return A2CPolicy class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]): + def _include_actor_in_optim(self) -> bool: + return True + def _get_policy_class(self) -> type[PPOPolicy]: return PPOPolicy class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]): + 
def _include_actor_in_optim(self) -> bool: + return False + def _get_policy_class(self) -> type[NPGPolicy]: return NPGPolicy class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]): + def _include_actor_in_optim(self) -> bool: + return False + def _get_policy_class(self) -> type[TRPOPolicy]: return TRPOPolicy diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index 9e04d3feb..005454396 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -38,7 +38,8 @@ class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # ty If `self.action_type == "discrete"`: (`s` ->`action_values_BA`). If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`). :param critic: the critic network. (s -> V(s)) - :param optim: the optimizer for actor and critic network. + :param optim: the optimizer for the critic network only. The actor network + is optimized via natural gradients internally. :param dist_fn: distribution class for computing the action. :param action_space: env's action space :param optim_critic_iters: Number of times to optimize critic network per update. diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index e7aa5cfd5..51a2d7cf0 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -32,7 +32,8 @@ class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]): If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). :param critic: the critic network. (s -> V(s)) - :param optim: the optimizer for actor and critic network. + :param optim: the optimizer for the critic network only. The actor network + is optimized via natural gradients internally. :param dist_fn: distribution class for computing the action. :param action_space: env's action space :param max_kl: max kl-divergence used to constrain each actor network update. 
From 17133314c90eef2527e5436ec93cf7454dc4a8b2 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 22 Apr 2025 22:18:33 +0200 Subject: [PATCH 07/56] Update changelog --- CHANGELOG.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c032e5553..06a196377 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,17 @@ - Custom scoring now supported for selecting the best model. #1202 - highlevel: - `DiscreteSACExperimentBuilder`: Expose method `with_actor_factory_default` #1248 #1250 + - `ActorFactoryDefault`: Fix parameters for hidden sizes and activation not being + passed on in the discrete case (affects `with_actor_factory_default` method of experiment builders) + - `ExperimentConfig`: Do not inherit from other classes, as this breaks automatic handling by + `jsonargparse` when the class is used to define interfaces (as in high-level API examples) + - `AutoAlphaFactoryDefault`: Differentiate discrete and continuous action spaces + and allow coefficient to be modified, adding an informative docstring + (previous implementation was reasonable only for continuous action spaces) + - Adjust usage in `atari_sac_hl` example accordingly. + - `NPGAgentFactory`, `TRPOAgentFactory`: Fix optimizer instantiation including the actor parameters + (which was misleadingly suggested in the docstring in the respective policy classes; docstrings were fixed), + as the actor parameters are intended to be handled via natural gradients internally ### Breaking Changes From 528fd2cfbf871fdf4b55441c06dc72058b224e0a Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 22 Apr 2025 23:08:40 +0200 Subject: [PATCH 08/56] `BaseTrainer.run` and `__iter__`: Resetting was never optional prior to running the trainer, yet recently introduced parameter `reset_prior_to_run` of `run` suggested that it was optional. But it was not respected, because `__iter__` would always call `reset(reset_collectors=True, reset_buffer=False)` regardless. 
The parameter was removed; instead, the parameters of `run` now mirror the parameters of `reset`, and the implicit `reset` call in `__iter__` was removed. This aligns with upcoming changes in Tianshou v2.0.0. --- CHANGELOG.md | 7 +++++++ tianshou/trainer/base.py | 17 +++++++++-------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 06a196377..c3e440ca2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,13 @@ ### Breaking Changes +- trainer: + - `BaseTrainer.run` and `__iter__`: Resetting was never optional prior to running the trainer, + yet recently introduced parameter `reset_prior_to_run` of `run` suggested that it was optional. + But it was not respected, because `__iter__` would always call `reset(reset_collectors=True, reset_buffer=False)` + regardless. The parameter was removed; instead, the parameters of `run` now mirror the parameters of `reset`, + and the implicit `reset` call in `__iter__` was removed. + This aligns with upcoming changes in Tianshou v2.0.0. - data: - stats: - `InfoStats` has a new non-optional field `best_score` which is used diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 355c9d33b..d4193de5a 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -309,7 +309,6 @@ def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> No self.iter_num = 0 def __iter__(self): # type: ignore - self.reset(reset_collectors=True, reset_buffer=False) return self def __next__(self) -> EpochStats: @@ -611,19 +610,21 @@ def policy_update_fn( stats of the whole dataset """ - def run(self, reset_prior_to_run: bool = True, reset_buffer: bool = False) -> InfoStats: + def run(self, reset_collectors: bool = True, reset_buffer: bool = False) -> InfoStats: """Consume iterator. See itertools - recipes. Use functions that consume iterators at C speed (feed the entire iterator into a zero-length deque). 
- :param reset_prior_to_run: whether to reset collectors prior to run - :param reset_buffer: only has effect if `reset_prior_to_run` is True. - Then it will also reset the buffer. This is usually not necessary, use - with caution. + :param reset_collectors: whether to reset the collectors prior to starting the training process. + Specifically, this will reset the environments in the collectors (starting new episodes), + and the statistics stored in the collector. Whether the contained buffers will be reset/cleared + is determined by the `reset_buffer` parameter. + :param reset_collector_buffers: whether, for the case where the collectors are reset, to reset/clear the + contained buffers as well. + This has no effect if `reset_collectors` is False. """ - if reset_prior_to_run: - self.reset(reset_buffer=reset_buffer) + self.reset(reset_collectors=reset_collectors, reset_buffer=reset_buffer) try: self.is_run = True deque(self, maxlen=0) # feed the entire iterator into a zero-length deque From 4f17673a30b5265254df2352b58cf14b7e209724 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 22 Apr 2025 19:01:49 +0200 Subject: [PATCH 09/56] Add basic implementation for determinism tests --- pyproject.toml | 3 +- test/determinism_test.py | 34 ++++ test/discrete/test_dqn.py | 12 +- test/discrete/test_pg.py | 12 +- test/discrete/test_ppo.py | 12 +- tianshou/data/collector.py | 2 + tianshou/policy/modelfree/pg.py | 5 +- tianshou/trainer/base.py | 28 +++ tianshou/utils/determinism.py | 332 ++++++++++++++++++++++++++++++++ 9 files changed, 432 insertions(+), 8 deletions(-) create mode 100644 test/determinism_test.py create mode 100644 tianshou/utils/determinism.py diff --git a/pyproject.toml b/pyproject.toml index 5640f948b..d773bab56 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -181,7 +181,8 @@ ignore = [ "PLW2901", # overwrite vars in loop "B027", # empty and non-abstract method in abstract class "D404", # It's fine to start with "This" in docstrings - "D407", "D408", 
"D409", # Ruff rules for underlines under 'Example:' and so clash with Sphinx + "D407", "D408", "D409", # Ruff rules for underlines under 'Example:' and so clash with Sphinx + "B023", # forbids function using loop variable without explicit binding ] unfixable = [ "F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all diff --git a/test/determinism_test.py b/test/determinism_test.py new file mode 100644 index 000000000..084934434 --- /dev/null +++ b/test/determinism_test.py @@ -0,0 +1,34 @@ +from argparse import Namespace +from collections.abc import Callable +from pathlib import Path +from typing import Any + +from tianshou.utils.determinism import TraceDeterminismTest, TraceLoggerContext + + +class AlgorithmDeterminismTest: + def __init__(self, name: str, main_fn: Callable[[Namespace], Any], args: Namespace): + self.determinism_test = TraceDeterminismTest( + base_path=Path(__file__).parent / "resources" / "determinism", + ) + self.name = name + + def set(attr: str, value: Any) -> None: + old_value = getattr(args, attr) + if old_value is None: + raise ValueError(f"Attribute '{attr}' is not defined for args: {args}") + setattr(args, attr, value) + + set("epoch", 3) + set("step_per_epoch", 100) + set("device", "cpu") + set("training_num", 1) + set("test_num", 1) + + # run the actual process + with TraceLoggerContext() as trace: + main_fn(args) + self.log = trace.get_log() + + def run(self, update_snapshot: bool = False) -> None: + self.determinism_test.check(self.log, self.name, create_reference_result=update_snapshot) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index f82aca1f6..61cfb6174 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -55,7 +56,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_dqn(args: 
argparse.Namespace = get_args()) -> None: +def test_dqn(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) @@ -153,7 +154,14 @@ def test_fn(epoch: int, env_step: int | None) -> None: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_dqn_determinism() -> None: + main_fn = lambda args: test_dqn(args, enable_assertions=False) + AlgorithmDeterminismTest("discrete_dqn", main_fn, get_args()).run(update_snapshot=True) def test_pdqn(args: argparse.Namespace = get_args()) -> None: diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 8a681583d..67479954f 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -44,7 +45,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_pg(args: argparse.Namespace = get_args()) -> None: +def test_pg(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape @@ -122,4 +123,11 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_pg_determinism() -> None: + main_fn = lambda args: test_pg(args, enable_assertions=False) + AlgorithmDeterminismTest("discrete_pg", main_fn, get_args()).run(update_snapshot=False) diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 7e541fffb..2540c3934 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -1,5 +1,6 @@ 
import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -57,7 +58,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_ppo(args: argparse.Namespace = get_args()) -> None: +def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape @@ -149,4 +150,11 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_ppo_determinism() -> None: + main_fn = lambda args: test_ppo(args, enable_assertions=False) + AlgorithmDeterminismTest("discrete_ppo", main_fn, get_args()).run(update_snapshot=False) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 514db4cc0..d615af4e8 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -32,6 +32,7 @@ from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.policy import BasePolicy from tianshou.policy.base import episode_mc_return_to_go +from tianshou.utils.determinism import TraceLogger from tianshou.utils.print import DataclassPPrintMixin from tianshou.utils.torch_utils import torch_train_mode @@ -839,6 +840,7 @@ def _collect( # noqa: C901 last_info_R=last_info_R, last_hidden_state_RH=last_hidden_state_RH, ) + TraceLogger.log(log, lambda: f"Action: {collect_action_computation_batch_R.act}") # Step 3 obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step( diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 80bcff672..01f059df8 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -1,3 +1,4 @@ +import logging import warnings from collections.abc import Callable from dataclasses import dataclass @@ 
-27,6 +28,9 @@ from tianshou.utils.net.continuous import ActorProb from tianshou.utils.net.discrete import Actor +log = logging.getLogger(__name__) + + # Dimension Naming Convention # B - Batch Size # A - Action @@ -231,5 +235,4 @@ def learn( # type: ignore losses.append(loss.item()) loss_summary_stat = SequenceSummaryStats.from_sequence(losses) - return PGTrainingStats(loss=loss_summary_stat) # type: ignore[return-value] diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index d4193de5a..eb37ecaa5 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -7,6 +7,7 @@ from functools import partial import numpy as np +import torch import tqdm from tianshou.data import ( @@ -27,6 +28,7 @@ LazyLogger, MovAvg, ) +from tianshou.utils.determinism import TraceLogger, torch_param_hash from tianshou.utils.logging import set_numerical_fields_to_precision from tianshou.utils.torch_utils import policy_within_training_step @@ -308,6 +310,27 @@ def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> No self.stop_fn_flag = False self.iter_num = 0 + self._log_params(self.policy) + + def _log_params(self, module: torch.nn.Module) -> None: + """Logs the parameters of the module to the trace logger by subcomponent (if the trace logger is enabled).""" + if not TraceLogger.is_enabled: + return + + def module_has_params(m: torch.nn.Module) -> bool: + return any(p.requires_grad for p in m.parameters()) + + relevant_modules = {} + for name, submodule in module.named_children(): + if module_has_params(submodule): + relevant_modules[name] = submodule + + for name, module in sorted(relevant_modules.items()): + TraceLogger.log( + log, + lambda: f"Params[{name}]: {torch_param_hash(module)}", + ) + def __iter__(self): # type: ignore return self @@ -331,6 +354,7 @@ def __next__(self) -> EpochStats: train_stat: CollectStatsBase while steps_done_in_this_epoch < self.step_per_epoch and not self.stop_fn_flag: train_stat, update_stat, 
self.stop_fn_flag = self.training_step() + self._log_params(self.policy) if isinstance(train_stat, CollectStats): pbar_data_dict = { @@ -498,6 +522,10 @@ def _collect_training_data(self) -> CollectStats: n_step=self.step_per_collect, n_episode=self.episode_per_collect, ) + TraceLogger.log( + log, + lambda: f"Collected {collect_stats.n_collected_steps} steps, {collect_stats.n_collected_episodes} episodes", + ) if self.train_collector.buffer.hasnull(): from tianshou.data.collector import EpisodeRolloutHook diff --git a/tianshou/utils/determinism.py b/tianshou/utils/determinism.py new file mode 100644 index 000000000..e0c3aa18e --- /dev/null +++ b/tianshou/utils/determinism.py @@ -0,0 +1,332 @@ +import difflib +import inspect +import os +import re +import time +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from io import StringIO +from pathlib import Path +from typing import Self + +import torch +from sensai.util import logging +from sensai.util.pickle import dump_pickle, load_pickle + + +def format_log_message( + logger: logging.Logger, + level: int, + msg: str, + formatter: logging.Formatter, + stacklevel: int = 1, +) -> str: + """ + Formats a log message as it would have been created by `logger.log(level, msg)` with the given formatter. 
+ + :param logger: the logger + :param level: the log level + :param msg: the message + :param formatter: the formatter + :param stacklevel: the stack level of the function to report as the generator + :return: the formatted log message (not including trailing newline) + """ + frame_info = inspect.stack()[stacklevel] + pathname = frame_info.filename + lineno = frame_info.lineno + func = frame_info.function + + record = logger.makeRecord( + name=logger.name, + level=level, + fn=pathname, + lno=lineno, + msg=msg, + args=(), + exc_info=None, + func=func, + extra=None, + ) + record.created = time.time() + record.asctime = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.created)) + + return formatter.format(record) + + +class TraceLogger: + """Supports the collection of behavioural trace logs, which can, in particular, be used for determinism tests.""" + + is_enabled = False + """ + whether the trace logger is enabled. + + NOTE: The preferred way to enable this is via the context manager. + """ + MESSAGE_TAG = "[TRACE]" + """ + a tag which is added at the beginning of log messages generated by this logger + """ + LOG_LEVEL = logging.DEBUG + log_buffer: StringIO | None = None + log_formatter: logging.Formatter | None = None + + @classmethod + def log(cls, logger: logging.Logger, message_generator: Callable[[], str]) -> None: + """ + Logs a message intended for tracing agent-env interaction, which is enabled via + `TraceAgentEnvLoggerContext`. + + :param logger: the logger to use for the actual logging + :param message_generator: function which generates the log message (which may be expensive); + if logging is disabled, the function will not be called. 
+ """ + if not cls.is_enabled: + return + + msg = message_generator() + msg = cls.MESSAGE_TAG + " " + msg + + # Log with caller's frame info + logger.log(logging.DEBUG, msg, stacklevel=2) + + # If a dedicated memory buffer is configured, also store the message there + if cls.log_buffer is not None: + msg_formatted = format_log_message( + logger, + logging.DEBUG, + msg, + cls.log_formatter, + stacklevel=2, + ) + cls.log_buffer.write(msg_formatted + "\n") + + +@dataclass +class TraceLog: + log_lines: list[str] + + def save_log(self, path: str) -> None: + with open(path, "w") as f: + for line in self.log_lines: + f.write(line + "\n") + + def print_log(self) -> None: + for line in self.log_lines: + print(line) + + def get_full_log(self) -> str: + return "\n".join(self.log_lines) + + def reduce_log_to_messages(self) -> "TraceLog": + """ + Removes logger names and function names from the log entries, such that each log message + contains only the main text message itself (starting with the content after the logger's tag). + + :return: the result with reduced log messages + """ + lines = [] + tag = re.escape(TraceLogger.MESSAGE_TAG) + for line in self.log_lines: + lines.append(re.sub(r".*" + tag, "", line)) + return TraceLog(lines) + + def filter_messages( + self, + required_messages: Sequence[str] = (), + optional_messages: Sequence[str] = (), + ) -> "TraceLog": + """ + Reduces the set of log messages to a set of core messages that indicate that the fundamental + trace is still the same (same actions, same states, same images). 
+ + :param required_messages: message substrings to filter for; each message is required to appear at least once + (triggering exception otherwise) + :param optional_messages: additional messages fragments to filter for; these are not required + :return: the result with reduced log messages + """ + import numpy as np + + required_message_counters = np.zeros(len(required_messages)) + + def retain_line(line: str) -> bool: + for i, main_message in enumerate(required_messages): + if main_message in line: + required_message_counters[i] += 1 + return True + return any(add_message in line for add_message in optional_messages) + + lines = [] + for line in self.log_lines: + if retain_line(line): + lines.append(line) + + assert np.all( + required_message_counters > 0, + ), "Not all types of required messages were found in the trace. Were log messages changed?" + + return TraceLog(lines) + + +class TraceLoggerContext: + """ + A context manager which enables the trace logger. + Apart from enabling the logging, it can optionally create a memory log buffer, such that + getting the trace log is not strictly dependent on the logging system. + """ + + def __init__( + self, + enable_log_buffer: bool = True, + log_format: str = "%(name)s:%(funcName)s - %(message)s", + ) -> None: + """ + :param enable_log_buffer: whether to enable the dedicated log buffer for trace logs, whose contents + can, within the context of this manager, be accessed via method `get_log`. 
+ :param log_format: the logger format string to use for the dedicated log buffer + """ + self._enable_log_buffer = enable_log_buffer + self._log_format: str = log_format + self._log_buffer: StringIO | None = None + + def __enter__(self) -> Self: + TraceLogger.is_enabled = True + + if self._enable_log_buffer: + TraceLogger.log_buffer = StringIO() + TraceLogger.log_formatter = logging.Formatter(self._log_format) + self._log_buffer = TraceLogger.log_buffer + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore + TraceLogger.is_enabled = False + TraceLogger.log_buffer = None + TraceLogger.log_formatter = None + + def get_log(self) -> TraceLog: + """:return: the full trace log that was captured if `enable_log_buffer` was enabled at construction""" + if self._log_buffer is None: + raise Exception( + "This method is only supported if the log buffer is enabled at construction", + ) + return TraceLog(log_lines=self._log_buffer.getvalue().split("\n")) + + +def torch_param_hash(module: torch.nn.Module) -> str: + """ + Computes a hash of the parameters of the given module; parameters not requiring gradients are ignored. 
+ + :param module: a torch module + :return: a hex digest of the parameters of the module + """ + import hashlib + + hasher = hashlib.sha1() + for param in module.parameters(): + if param.requires_grad: + np_array = param.detach().cpu().numpy() + hasher.update(np_array.tobytes()) + return hasher.hexdigest() + + +class TraceDeterminismTest: + def __init__(self, base_path: Path, core_messages: Sequence[str] = ()): + """ + :param base_path: the directory where the reference results are stored (will be created if necessary) + :param core_messages: message fragments that make up the core of a trace; if empty, all messages are considered core + """ + base_path.mkdir(parents=True, exist_ok=True) + self.base_path = base_path + self.core_messages = core_messages + + def check(self, result: TraceLog, name: str, create_reference_result: bool = False) -> None: + """ + Checks the given log against the reference result for the given name. + + :param result: the result to check + :param name: the name of the reference result + :param create_reference_result: whether update the reference result with the given result + """ + import pytest + + reference_result_path = self.base_path / f"{name}.pkl.bz2" + + if create_reference_result: + dump_pickle(result, reference_result_path) + + reference_result: TraceLog = load_pickle( + reference_result_path, + ) + + result_reduced = result.reduce_log_to_messages() + reference_result_reduced = reference_result.reduce_log_to_messages() + + results: list[tuple[TraceLog, str]] = [ + (reference_result_reduced, "expected"), + (result_reduced, "current"), + (reference_result, "expected_full"), + (result, "current_full"), + ] + + if self.core_messages: + result_main_messages = result_reduced.filter_messages( + required_messages=self.core_messages, + ) + reference_result_main_messages = reference_result_reduced.filter_messages( + required_messages=self.core_messages, + ) + results.extend( + [ + (reference_result_main_messages, "expected_core"), + 
(result_main_messages, "current_core"), + ], + ) + else: + result_main_messages = result_reduced + reference_result_main_messages = reference_result_reduced + + logs_equivalent = result_reduced.get_full_log() == reference_result_reduced.get_full_log() + if not logs_equivalent: + # save files for comparison + files = [] + for r, suffix in results: + path = os.path.abspath(f"determinism_{name}_{suffix}.txt") + r.save_log(path) + files.append(path) + + paths_str = "\n".join(files) + main_message = ( + f"Please inspect the changes by diffing the log files:\n{paths_str}\n" + f"If the changes are OK, enable the `create_reference_result` flag temporarily, " + "rerun the test and then commit the updated reference file.\n\nHere's the first part of the diff:\n" + ) + + # compute diff and add to message + num_diff_lines_to_show = 30 + for i, line in enumerate( + difflib.unified_diff( + reference_result_reduced.log_lines, + result_reduced.log_lines, + fromfile="expected.txt", + tofile="current.txt", + lineterm="", + ), + ): + if i == num_diff_lines_to_show: + break + main_message += line + "\n" + + core_messages_changed_only = ( + len(self.core_messages) > 0 + and result_main_messages.get_full_log() + == reference_result_main_messages.get_full_log() + ) + if core_messages_changed_only: + pytest.fail( + "The meta-agent training log has changed, but the core messages are still the same (so this " + f"probably isn't an issue). {main_message}", + ) + else: + pytest.fail( + f"The meta-agent training log has changed; even the core messages are different. 
{main_message}", + ) From d78f0ed06eec8880c872f9bc83345ccf7ae48bf1 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 22 Apr 2025 23:40:56 +0200 Subject: [PATCH 10/56] Log parameters of ActorCritic components separately --- tianshou/trainer/base.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index eb37ecaa5..87e9df41a 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -314,6 +314,8 @@ def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> No def _log_params(self, module: torch.nn.Module) -> None: """Logs the parameters of the module to the trace logger by subcomponent (if the trace logger is enabled).""" + from tianshou.utils.net.common import ActorCritic + if not TraceLogger.is_enabled: return @@ -321,9 +323,16 @@ def module_has_params(m: torch.nn.Module) -> bool: return any(p.requires_grad for p in m.parameters()) relevant_modules = {} - for name, submodule in module.named_children(): - if module_has_params(submodule): - relevant_modules[name] = submodule + + def gather_modules(m: torch.nn.Module) -> None: + for name, submodule in m.named_children(): + if isinstance(submodule, ActorCritic): + gather_modules(submodule) + else: + if module_has_params(submodule): + relevant_modules[name] = submodule + + gather_modules(module) for name, module in sorted(relevant_modules.items()): TraceLogger.log( From 5061c226f86e2c443af914fb39fe602b790bdb8f Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 23 Apr 2025 15:15:40 +0200 Subject: [PATCH 11/56] Fix failure message --- tianshou/utils/determinism.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/utils/determinism.py b/tianshou/utils/determinism.py index e0c3aa18e..888c90a6c 100644 --- a/tianshou/utils/determinism.py +++ b/tianshou/utils/determinism.py @@ -323,10 +323,10 @@ def check(self, result: TraceLog, name: str, create_reference_result: bool = 
Fal ) if core_messages_changed_only: pytest.fail( - "The meta-agent training log has changed, but the core messages are still the same (so this " + "The behaviour log has changed, but the core messages are still the same (so this " f"probably isn't an issue). {main_message}", ) else: pytest.fail( - f"The meta-agent training log has changed; even the core messages are different. {main_message}", + f"The behaviour log has changed; even the core messages are different. {main_message}", ) From c88f84401d0caebe1ce7c08cd52e300ded6b7604 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 23 Apr 2025 18:01:28 +0200 Subject: [PATCH 12/56] Add TorchDeterministicModeContext --- test/determinism_test.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/test/determinism_test.py b/test/determinism_test.py index 084934434..308606bb0 100644 --- a/test/determinism_test.py +++ b/test/determinism_test.py @@ -1,3 +1,4 @@ +import torch from argparse import Namespace from collections.abc import Callable from pathlib import Path @@ -6,6 +7,19 @@ from tianshou.utils.determinism import TraceDeterminismTest, TraceLoggerContext +class TorchDeterministicModeContext: + def __init__(self, mode="default"): + self.new_mode = mode + self.original_mode = None + + def __enter__(self): + self.original_mode = torch.get_deterministic_debug_mode() + torch.set_deterministic_debug_mode(self.new_mode) + + def __exit__(self, exc_type, exc_value, traceback): + torch.set_deterministic_debug_mode(self.original_mode) + + class AlgorithmDeterminismTest: def __init__(self, name: str, main_fn: Callable[[Namespace], Any], args: Namespace): self.determinism_test = TraceDeterminismTest( @@ -27,7 +41,8 @@ def set(attr: str, value: Any) -> None: # run the actual process with TraceLoggerContext() as trace: - main_fn(args) + with TorchDeterministicModeContext(): + main_fn(args) self.log = trace.get_log() def run(self, update_snapshot: bool = False) -> None: From 
9fbfd9930e03033412f73846c0897f1ece01abe0 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 23 Apr 2025 19:44:52 +0200 Subject: [PATCH 13/56] Devcontainer --- .devcontainer/devcontainer.json | 22 +++++++++++++++++ .dockerignore | 14 +++++++++++ Dockerfile | 42 +++++++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+) create mode 100644 .devcontainer/devcontainer.json create mode 100644 .dockerignore create mode 100644 Dockerfile diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 000000000..536443860 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,22 @@ +{ + "name": "Tianshou", + "dockerFile": "../Dockerfile", + "workspaceFolder": "/workspaces/tianshou", + "runArgs": ["--shm-size=1g"], + "customizations": { + "vscode": { + "settings": { + "terminal.integrated.shell.linux": "/bin/bash", + "python.pythonPath": "/usr/local/bin/python" + }, + "extensions": [ + "ms-python.python", + "ms-toolsai.jupyter", + "ms-python.vscode-pylance" + ] + } + }, + "forwardPorts": [], + "postCreateCommand": "poetry install --with dev", + "remoteUser": "root" + } \ No newline at end of file diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..fa5050fe5 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,14 @@ +data +logs +test/log +docs/jupyter_execute +docs/.jupyter_cache +.lsp +.clj-kondo +docs/_build +coverage* +__pycache__ +*.egg-info +*.egg +.*cache +dist \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000..4e3827b26 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,42 @@ +# Use the official Python image for the base image. +FROM --platform=linux/amd64 python:3.11-slim + +# Set environment variables to make Python print directly to the terminal and avoid .pyc files. +ENV PYTHONUNBUFFERED=1 +ENV PYTHONDONTWRITEBYTECODE=1 + +# Install system dependencies required for the project. 
+RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + build-essential \ + git \ + wget \ + unzip \ + libvips-dev \ + gnupg2 \ + && rm -rf /var/lib/apt/lists/* + + +# Install pipx. +RUN python3 -m pip install --no-cache-dir pipx \ + && pipx ensurepath + +# Add poetry to the path +ENV PATH="${PATH}:/root/.local/bin" + +# Install the latest version of Poetry using pipx. +RUN pipx install poetry + +# Set the working directory. IMPORTANT: can't be changed as needs to be in sync to the dir where the project is cloned +# to in the codespace +WORKDIR /workspaces/tianshou + +# Copy the pyproject.toml and poetry.lock files (if available) into the image. +COPY pyproject.toml poetry.lock* README.md /workspaces/tianshou/ + +RUN poetry config virtualenvs.create false +RUN poetry install --no-root --with dev + +# The entrypoint will perform an editable install, it is expected that the code is mounted in the container then +# If you don't want to mount the code, you should override the entrypoint +ENTRYPOINT ["/bin/bash", "-c", "poetry install --with dev && poetry run jupyter trust notebooks/*.ipynb docs/02_notebooks/*.ipynb && $0 $@"] \ No newline at end of file From cf0e0d83d9643a694ca3eca71a89b5299b18cdb7 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 5 May 2025 12:28:20 +0200 Subject: [PATCH 14/56] Update sensai-utils to 1.4.0 --- poetry.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/poetry.lock b/poetry.lock index 583b90756..2cd8e2ed7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5382,13 +5382,13 @@ win32 = ["pywin32"] [[package]] name = "sensai-utils" -version = "1.2.1" +version = "1.4.0" description = "Utilities from sensAI, the Python library for sensible AI" optional = false python-versions = "*" files = [ - {file = "sensai_utils-1.2.1-py3-none-any.whl", hash = "sha256:222e60d9f9d371c9d62ffcd1e6def1186f0d5243588b0b5af57e983beecc95bb"}, - {file = "sensai_utils-1.2.1.tar.gz", hash = 
"sha256:4d8ca94179931798cef5f920fb042cbf9e7d806c0026b02afb58d0f72211bf27"}, + {file = "sensai_utils-1.4.0-py3-none-any.whl", hash = "sha256:ed6fc57552620e43b33cf364ea0bc0fd7df39391069dd7b621b113ef55547507"}, + {file = "sensai_utils-1.4.0.tar.gz", hash = "sha256:2d32bdcc91fd1428c5cae0181e98623142d2d5f7e115e23d585a842dd9dc59ba"}, ] [package.dependencies] @@ -6896,4 +6896,4 @@ vizdoom = ["vizdoom"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "1ea1b72b90269fd86b81b1443785085618248ccf5b62506a166b879115749171" +content-hash = "bff3f4f8cc0d8196ea162a799472c7179486109d30968aa7d1b96b40016a459f" diff --git a/pyproject.toml b/pyproject.toml index d773bab56..178222f9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ overrides = "^7.4.0" packaging = "*" pandas = ">=2.0.0" pettingzoo = "^1.22" -sensai-utils = "^1.2.1" +sensai-utils = "^1.4.0" tensorboard = "^2.5.0" # Torch 2.0.1 causes problems, see https://github.com/pytorch/pytorch/issues/100974 torch = "^2.0.0, !=2.0.1, !=2.1.0" From 3ed3c20b58b529313366d53fb6eb3a600d4e8004 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 5 May 2025 13:10:07 +0200 Subject: [PATCH 15/56] Support a new mode of operation for determinism tests, where each developer is responsible for creating the snapshot(s) on the original branch and then comparing them with results on a modified branch. Add writing of a log file for determinism tests.
--- .gitignore | 5 ++- test/determinism_test.py | 48 +++++++++++++++++--- test/discrete/test_dqn.py | 2 +- test/discrete/test_pg.py | 2 +- test/discrete/test_ppo.py | 2 +- tianshou/utils/determinism.py | 83 +++++++++++++++++++++++++---------- 6 files changed, 111 insertions(+), 31 deletions(-) diff --git a/.gitignore b/.gitignore index e63e24b00..c6e843c4d 100644 --- a/.gitignore +++ b/.gitignore @@ -158,4 +158,7 @@ docs/conf.py # temporary scripts (for ad-hoc testing), temp folder /temp -/temp*.py \ No newline at end of file +/temp*.py + +# determinism test snapshots +/test/resources/determinism/ diff --git a/test/determinism_test.py b/test/determinism_test.py index 308606bb0..181537c30 100644 --- a/test/determinism_test.py +++ b/test/determinism_test.py @@ -1,9 +1,11 @@ -import torch from argparse import Namespace from collections.abc import Callable from pathlib import Path from typing import Any +import pytest +import torch + from tianshou.utils.determinism import TraceDeterminismTest, TraceLoggerContext @@ -21,9 +23,36 @@ def __exit__(self, exc_type, exc_value, traceback): class AlgorithmDeterminismTest: + """ + Represents a determinism test for Tianshou's RL algorithms. + + A test using this class should be added for every algorithm in Tianshou. + Then, when making changes to one or more algorithms (e.g. refactoring), run the respective tests + on the old branch (creating snapshots) and then on the new branch that contains the changes + (comparing with the snapshots). + + Intended usage is therefore: + + 1. On the old branch: Set ENABLED=True and FORCE_SNAPSHOT_UPDATE=True and run the tests. + 2. On the new branch: Set ENABLED=True and FORCE_SNAPSHOT_UPDATE=False and run the tests. + 3. Inspect determinism_tests.log + """ + + ENABLED = False + """ + whether determinism tests are enabled. + """ + FORCE_SNAPSHOT_UPDATE = False + """ + whether to force the update/creation of snapshots for every test. 
+ Enable this when running on the "old" branch and you want to prepare the snapshots + for a comparison with the "new" branch. + """ + def __init__(self, name: str, main_fn: Callable[[Namespace], Any], args: Namespace): self.determinism_test = TraceDeterminismTest( base_path=Path(__file__).parent / "resources" / "determinism", + log_filename="determinism_tests.log", ) self.name = name @@ -39,11 +68,20 @@ def set(attr: str, value: Any) -> None: set("training_num", 1) set("test_num", 1) + self.args = args + self.main_fn = main_fn + + def run(self, update_snapshot: bool = False) -> None: + if not self.ENABLED: + pytest.skip("Algorithm determinism tests are disabled.") + + if self.FORCE_SNAPSHOT_UPDATE: + update_snapshot = True + # run the actual process with TraceLoggerContext() as trace: with TorchDeterministicModeContext(): - main_fn(args) - self.log = trace.get_log() + self.main_fn(self.args) + log = trace.get_log() - def run(self, update_snapshot: bool = False) -> None: - self.determinism_test.check(self.log, self.name, create_reference_result=update_snapshot) + self.determinism_test.check(log, self.name, create_reference_result=update_snapshot) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 61cfb6174..eeb5e8207 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -161,7 +161,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: def test_dqn_determinism() -> None: main_fn = lambda args: test_dqn(args, enable_assertions=False) - AlgorithmDeterminismTest("discrete_dqn", main_fn, get_args()).run(update_snapshot=True) + AlgorithmDeterminismTest("discrete_dqn", main_fn, get_args()).run() def test_pdqn(args: argparse.Namespace = get_args()) -> None: diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 67479954f..a4fb28300 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -130,4 +130,4 @@ def stop_fn(mean_rewards: float) -> bool: def test_pg_determinism() -> None: main_fn = 
lambda args: test_pg(args, enable_assertions=False) - AlgorithmDeterminismTest("discrete_pg", main_fn, get_args()).run(update_snapshot=False) + AlgorithmDeterminismTest("discrete_pg", main_fn, get_args()).run() diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 2540c3934..4226caf8f 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -157,4 +157,4 @@ def stop_fn(mean_rewards: float) -> bool: def test_ppo_determinism() -> None: main_fn = lambda args: test_ppo(args, enable_assertions=False) - AlgorithmDeterminismTest("discrete_ppo", main_fn, get_args()).run(update_snapshot=False) + AlgorithmDeterminismTest("discrete_ppo", main_fn, get_args()).run() diff --git a/tianshou/utils/determinism.py b/tianshou/utils/determinism.py index 888c90a6c..a4d79e73b 100644 --- a/tianshou/utils/determinism.py +++ b/tianshou/utils/determinism.py @@ -11,6 +11,7 @@ import torch from sensai.util import logging +from sensai.util.git import GitStatus, git_status from sensai.util.pickle import dump_pickle, load_pickle @@ -230,49 +231,69 @@ def torch_param_hash(module: torch.nn.Module) -> str: class TraceDeterminismTest: - def __init__(self, base_path: Path, core_messages: Sequence[str] = ()): + def __init__( + self, + base_path: Path, + core_messages: Sequence[str] = (), + log_filename: str | None = None, + ) -> None: """ :param base_path: the directory where the reference results are stored (will be created if necessary) :param core_messages: message fragments that make up the core of a trace; if empty, all messages are considered core + :param log_filename: the name of the log file to which results are to be written (if any) """ base_path.mkdir(parents=True, exist_ok=True) self.base_path = base_path self.core_messages = core_messages + self.log_filename = log_filename + + @dataclass(kw_only=True) + class Result: + git_status: GitStatus + log: TraceLog - def check(self, result: TraceLog, name: str, create_reference_result: bool = False) -> None: + 
def check( + self, + current_log: TraceLog, + name: str, + create_reference_result: bool = False, + ) -> None: """ Checks the given log against the reference result for the given name. - :param result: the result to check - :param name: the name of the reference result + :param current_log: the result to check + :param name: the name of the reference result; must be unique among all tests! :param create_reference_result: whether update the reference result with the given result """ import pytest reference_result_path = self.base_path / f"{name}.pkl.bz2" + current_git_status = git_status() if create_reference_result: - dump_pickle(result, reference_result_path) + current_result = self.Result(git_status=current_git_status, log=current_log) + dump_pickle(current_result, reference_result_path) - reference_result: TraceLog = load_pickle( + reference_result: TraceDeterminismTest.Result = load_pickle( reference_result_path, ) + reference_log = reference_result.log - result_reduced = result.reduce_log_to_messages() - reference_result_reduced = reference_result.reduce_log_to_messages() + current_log_reduced = current_log.reduce_log_to_messages() + reference_log_reduced = reference_log.reduce_log_to_messages() results: list[tuple[TraceLog, str]] = [ - (reference_result_reduced, "expected"), - (result_reduced, "current"), - (reference_result, "expected_full"), - (result, "current_full"), + (reference_log_reduced, "expected"), + (current_log_reduced, "current"), + (reference_log, "expected_full"), + (current_log, "current_full"), ] if self.core_messages: - result_main_messages = result_reduced.filter_messages( + result_main_messages = current_log_reduced.filter_messages( required_messages=self.core_messages, ) - reference_result_main_messages = reference_result_reduced.filter_messages( + reference_result_main_messages = reference_log_reduced.filter_messages( required_messages=self.core_messages, ) results.extend( @@ -282,11 +303,17 @@ def check(self, result: TraceLog, name: 
str, create_reference_result: bool = Fal ], ) else: - result_main_messages = result_reduced - reference_result_main_messages = reference_result_reduced + result_main_messages = current_log_reduced + reference_result_main_messages = reference_log_reduced + + status_passed = True + logs_equivalent = current_log_reduced.get_full_log() == reference_log_reduced.get_full_log() + if logs_equivalent: + status_passed = True + status_message = "OK" + else: + status_passed = False - logs_equivalent = result_reduced.get_full_log() == reference_result_reduced.get_full_log() - if not logs_equivalent: # save files for comparison files = [] for r, suffix in results: @@ -305,8 +332,8 @@ def check(self, result: TraceLog, name: str, create_reference_result: bool = Fal num_diff_lines_to_show = 30 for i, line in enumerate( difflib.unified_diff( - reference_result_reduced.log_lines, - result_reduced.log_lines, + reference_log_reduced.log_lines, + current_log_reduced.log_lines, fromfile="expected.txt", tofile="current.txt", lineterm="", @@ -322,11 +349,23 @@ def check(self, result: TraceLog, name: str, create_reference_result: bool = Fal == reference_result_main_messages.get_full_log() ) if core_messages_changed_only: - pytest.fail( + status_message = ( "The behaviour log has changed, but the core messages are still the same (so this " f"probably isn't an issue). {main_message}", ) else: - pytest.fail( + status_message = ( f"The behaviour log has changed; even the core messages are different. 
{main_message}", ) + + # write log message + if self.log_filename: + with open(self.log_filename, "a") as f: + hr = "-" * 100 + f.write(f"\n\n{hr}\nName: {name}\n") + f.write(f"Reference state: {reference_result.git_status}\n") + f.write(f"Current state: {current_git_status}\n") + f.write(f"Test result: {status_message}\n") + + if not status_passed: + pytest.fail(status_message) From 364814d66806a05db34917201b126d9f8d1710f5 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 5 May 2025 21:42:47 +0200 Subject: [PATCH 16/56] Add determinism test for DiscreteBCQ --- test/determinism_test.py | 19 +++++++++++++++++-- test/offline/test_discrete_bcq.py | 18 ++++++++++++++---- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/test/determinism_test.py b/test/determinism_test.py index 181537c30..7dfdb2dfb 100644 --- a/test/determinism_test.py +++ b/test/determinism_test.py @@ -49,7 +49,21 @@ class AlgorithmDeterminismTest: for a comparison with the "new" branch. """ - def __init__(self, name: str, main_fn: Callable[[Namespace], Any], args: Namespace): + def __init__( + self, + name: str, + main_fn: Callable[[Namespace], Any], + args: Namespace, + is_offline: bool = False, + ): + """ + :param name: the (unique!) 
name of the test + :param main_fn: the function to be called for the test + :param args: the arguments to be passed to the main function (some of which are overridden + for the test) + :param is_offline: whether the algorithm being tested is an offline algorithm and therefore + does not configure the number of training environments (`training_num`) + """ self.determinism_test = TraceDeterminismTest( base_path=Path(__file__).parent / "resources" / "determinism", log_filename="determinism_tests.log", @@ -65,7 +79,8 @@ def set(attr: str, value: Any) -> None: set("epoch", 3) set("step_per_epoch", 100) set("device", "cpu") - set("training_num", 1) + if not is_offline: + set("training_num", 1) set("test_num", 1) self.args = args diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index e69e0a1fa..ddb6ce672 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -1,6 +1,7 @@ import argparse import os import pickle +from test.determinism_test import AlgorithmDeterminismTest from test.offline.gather_cartpole_data import expert_file_name, gather_data import gymnasium as gym @@ -36,7 +37,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--unlikely-action-threshold", type=float, default=0.6) parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--update-per-epoch", type=int, default=2000) + parser.add_argument("--step-per-epoch", type=int, default=2000) parser.add_argument("--batch-size", type=int, default=64) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--test-num", type=int, default=100) @@ -53,7 +54,9 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: +def test_discrete_bcq( + args: argparse.Namespace = get_args(), enable_assertions: bool = True +) -> 
None: # envs env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) @@ -155,7 +158,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: buffer=buffer, test_collector=test_collector, max_epoch=args.epoch, - step_per_epoch=args.update_per_epoch, + step_per_epoch=args.step_per_epoch, episode_per_test=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, @@ -164,10 +167,17 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) def test_discrete_bcq_resume(args: argparse.Namespace = get_args()) -> None: test_discrete_bcq() args.resume = True test_discrete_bcq(args) + + +def test_discrete_bcq_determinism() -> None: + main_fn = lambda args: test_discrete_bcq(args, enable_assertions=False) + AlgorithmDeterminismTest("discrete_bcq", main_fn, get_args(), is_offline=True).run() From 7a8902a777ccf30c146c16841eae4ff5ad23e46f Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 5 May 2025 21:47:54 +0200 Subject: [PATCH 17/56] Fix message assignment --- tianshou/utils/determinism.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/utils/determinism.py b/tianshou/utils/determinism.py index a4d79e73b..c29c72898 100644 --- a/tianshou/utils/determinism.py +++ b/tianshou/utils/determinism.py @@ -351,11 +351,11 @@ def check( if core_messages_changed_only: status_message = ( "The behaviour log has changed, but the core messages are still the same (so this " - f"probably isn't an issue). {main_message}", + f"probably isn't an issue). {main_message}" ) else: status_message = ( - f"The behaviour log has changed; even the core messages are different. {main_message}", + f"The behaviour log has changed; even the core messages are different. 
{main_message}" ) # write log message From 60e8cead379ba357c05db932975d9b5f6d74542b Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 5 May 2025 22:09:46 +0200 Subject: [PATCH 18/56] Log TrainingStats with TraceLogger after every training step --- tianshou/trainer/base.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 87e9df41a..647ab2af3 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -360,24 +360,25 @@ def __next__(self) -> EpochStats: # perform n step_per_epoch steps_done_in_this_epoch = 0 with self._pbar(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", position=1) as t: - train_stat: CollectStatsBase + collect_stats: CollectStatsBase while steps_done_in_this_epoch < self.step_per_epoch and not self.stop_fn_flag: - train_stat, update_stat, self.stop_fn_flag = self.training_step() + collect_stats, training_stats, self.stop_fn_flag = self.training_step() + TraceLogger.log(log, lambda: f"Training step complete: stats={training_stats.get_loss_stats_dict()}") self._log_params(self.policy) - if isinstance(train_stat, CollectStats): + if isinstance(collect_stats, CollectStats): pbar_data_dict = { "env_step": str(self.env_step), "env_episode": str(self.env_episode), "rew": f"{self.last_rew:.2f}", "len": str(int(self.last_len)), - "n/ep": str(train_stat.n_collected_episodes), - "n/st": str(train_stat.n_collected_steps), + "n/ep": str(collect_stats.n_collected_episodes), + "n/st": str(collect_stats.n_collected_steps), } # t might be disabled, we track the steps manually - t.update(train_stat.n_collected_steps) - steps_done_in_this_epoch += train_stat.n_collected_steps + t.update(collect_stats.n_collected_steps) + steps_done_in_this_epoch += collect_stats.n_collected_steps if self.stop_fn_flag: t.set_postfix(**pbar_data_dict) @@ -386,7 +387,7 @@ def __next__(self) -> EpochStats: # Code should be restructured! 
pbar_data_dict = {} assert self.buffer, "No train_collector or buffer specified" - train_stat = CollectStatsBase( + collect_stats = CollectStatsBase( n_collected_steps=len(self.buffer), ) @@ -440,9 +441,9 @@ def __next__(self) -> EpochStats: # in case trainer is used with run(), epoch_stat will not be returned return EpochStats( epoch=self.epoch, - train_collect_stat=train_stat, + train_collect_stat=collect_stats, test_collect_stat=test_stat, - training_stat=update_stat, + training_stat=training_stats, info_stat=info_stat, ) From 57ec496fd683c60c44fdf1f393b8fa445f6eef49 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 5 May 2025 22:27:48 +0200 Subject: [PATCH 19/56] Log sampled batch indices with TraceLogger when performing update --- tianshou/policy/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 066a23a3b..125bc84df 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -12,6 +12,7 @@ from numba import njit from numpy.typing import ArrayLike from overrides import override +from sensai.util.hash import pickle_hash from torch import nn from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_numpy, to_torch_as @@ -25,6 +26,7 @@ RolloutBatchProtocol, ) from tianshou.utils import MultipleLRSchedulers +from tianshou.utils.determinism import TraceLogger from tianshou.utils.net.common import RandomActor from tianshou.utils.print import DataclassPPrintMixin from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode @@ -541,6 +543,7 @@ def update( return TrainingStats() # type: ignore[return-value] start_time = time.time() batch, indices = buffer.sample(sample_size) + TraceLogger.log(logger, lambda: f"Updating with batch: {pickle_hash(indices)}") self.updating = True batch = self.process_fn(batch, buffer, indices) with torch_train_mode(self): From 4c0699b52f7e9d4d7ee69cedf81ed1c9e24f147c Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 5 May 
2025 22:30:17 +0200 Subject: [PATCH 20/56] Formatting --- test/offline/test_discrete_bcq.py | 3 ++- tianshou/trainer/base.py | 5 ++++- tianshou/utils/determinism.py | 4 +--- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index ddb6ce672..ba79f48fd 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -55,7 +55,8 @@ def get_args() -> argparse.Namespace: def test_discrete_bcq( - args: argparse.Namespace = get_args(), enable_assertions: bool = True + args: argparse.Namespace = get_args(), + enable_assertions: bool = True, ) -> None: # envs env = gym.make(args.task) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 647ab2af3..044c369d2 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -363,7 +363,10 @@ def __next__(self) -> EpochStats: collect_stats: CollectStatsBase while steps_done_in_this_epoch < self.step_per_epoch and not self.stop_fn_flag: collect_stats, training_stats, self.stop_fn_flag = self.training_step() - TraceLogger.log(log, lambda: f"Training step complete: stats={training_stats.get_loss_stats_dict()}") + TraceLogger.log( + log, + lambda: f"Training step complete: stats={training_stats.get_loss_stats_dict()}", + ) self._log_params(self.policy) if isinstance(collect_stats, CollectStats): diff --git a/tianshou/utils/determinism.py b/tianshou/utils/determinism.py index c29c72898..9d713f1ef 100644 --- a/tianshou/utils/determinism.py +++ b/tianshou/utils/determinism.py @@ -354,9 +354,7 @@ def check( f"probably isn't an issue). {main_message}" ) else: - status_message = ( - f"The behaviour log has changed; even the core messages are different. {main_message}" - ) + status_message = f"The behaviour log has changed; even the core messages are different. 
{main_message}" # write log message if self.log_filename: From c1f580e83ab1e9fd81186c3308c6b59768fd51dd Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 5 May 2025 22:41:15 +0200 Subject: [PATCH 21/56] ReplayBuffer: Establish determinism by using a well-defined RandomState --- tianshou/data/buffer/base.py | 18 +++++++++++++----- tianshou/data/buffer/manager.py | 4 ++-- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py index 5c4451d57..aad0c6f2b 100644 --- a/tianshou/data/buffer/base.py +++ b/tianshou/data/buffer/base.py @@ -3,6 +3,7 @@ import h5py import numpy as np +from sensai.util.pickle import setstate from tianshou.data import Batch from tianshou.data.batch import ( @@ -77,6 +78,7 @@ def __init__( ignore_obs_next: bool = False, save_only_last_obs: bool = False, sample_avail: bool = False, + random_seed: int = 42, **kwargs: Any, # otherwise PrioritizedVectorReplayBuffer will cause TypeError ) -> None: # TODO: why do we need this? Just for readout? @@ -96,12 +98,21 @@ def __init__( self._save_only_last_obs = save_only_last_obs self._sample_avail = sample_avail self._meta = cast(RolloutBatchProtocol, Batch()) + self._random_state = np.random.RandomState(random_seed) # Keep in sync with reset! self.last_index = np.array([0]) self._insertion_idx = self._size = 0 self._ep_return, self._ep_len, self._ep_start_idx = 0.0, 0, 0 + def __setstate__(self, state: dict[str, Any]) -> None: + setstate( + ReplayBuffer, + self, + state, + new_default_properties={"_random_state": np.random.RandomState(42)}, + ) + @property def subbuffer_edges(self) -> np.ndarray: """Edges of contained buffers, mostly needed as part of the VectorReplayBuffer interface. 
@@ -230,9 +241,6 @@ def __getattr__(self, key: str) -> Any: except KeyError as exception: raise AttributeError from exception - def __setstate__(self, state: dict[str, Any]) -> None: - self.__dict__.update(state) - def __setattr__(self, key: str, value: Any) -> None: assert key not in self._reserved_keys, f"key '{key}' is reserved and cannot be assigned" super().__setattr__(key, value) @@ -499,7 +507,7 @@ def sample_indices(self, batch_size: int | None) -> np.ndarray: batch_size = len(self) if self.stack_num == 1 or not self._sample_avail: # most often case if batch_size > 0: - return np.random.choice(self._size, batch_size) + return self._random_state.choice(self._size, batch_size) # TODO: is this behavior really desired? if batch_size == 0: # construct current available indices return np.concatenate( @@ -520,7 +528,7 @@ def sample_indices(self, batch_size: int | None) -> np.ndarray: prev_indices = self.prev(prev_indices) all_indices = all_indices[prev_indices != self.prev(prev_indices)] if batch_size > 0: - return np.random.choice(all_indices, batch_size) + return self._random_state.choice(all_indices, batch_size) return all_indices def sample(self, batch_size: int | None) -> tuple[RolloutBatchProtocol, np.ndarray]: diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py index e8176aa8c..370c358be 100644 --- a/tianshou/data/buffer/manager.py +++ b/tianshou/data/buffer/manager.py @@ -208,11 +208,11 @@ def sample_indices(self, batch_size: int | None) -> np.ndarray: return all_indices if batch_size is None: batch_size = len(all_indices) - return np.random.choice(all_indices, batch_size) + return self._random_state.choice(all_indices, batch_size) if batch_size == 0 or batch_size is None: # get all available indices sample_num = np.zeros(self.buffer_num, int) else: - buffer_idx = np.random.choice( + buffer_idx = self._random_state.choice( self.buffer_num, batch_size, p=self._lengths / self._lengths.sum(), From 
5f515a157a670e1fce5ff6babb5e7b671653bc64 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 12 May 2025 22:25:34 +0200 Subject: [PATCH 22/56] Improve change log entry pertaining to the breaking change in the trainer --- CHANGELOG.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c3e440ca2..229613e5e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,11 +24,13 @@ - trainer: - `BaseTrainer.run` and `__iter__`: Resetting was never optional prior to running the trainer, - yet recently introduced parameter `reset_prior_to_run` of `run` suggested that it was optional. - But it was not respected, because `__iter__` would always call `reset(reset_collectors=True, reset_buffer=False)` + yet the recently introduced parameter `reset_prior_to_run` of `run` suggested that it _was_ optional. + Yet the parameter was ultimately not respected, because `__iter__` would always call `reset(reset_collectors=True, reset_buffer=False)` regardless. The parameter was removed; instead, the parameters of `run` now mirror the parameters of `reset`, - and the implicit `reset` call in `__iter__` was removed. - This aligns with upcoming changes in Tianshou v2.0.0. + and the implicit `reset` call in `__iter__` was removed. + This aligns with upcoming changes in Tianshou v2.0.0. + NOTE: If you have been using a trainer without calling `run` but by directly iterating over it, you + will need to call `reset` on the trainer explicitly before iterating over the trainer. 
- data: - stats: - `InfoStats` has a new non-optional field `best_score` which is used From c05294fd4c0c16455c168266f1ca9e513103d706 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 12 May 2025 22:28:31 +0200 Subject: [PATCH 23/56] Add determinism tests for virtually all algorithms Fix some broken tests that directly used the trainer's iterator instead of using run(): * test/continuous/test_ppo * test/continuous/test_td3 --- test/continuous/test_ddpg.py | 12 ++++++++++-- test/continuous/test_npg.py | 12 ++++++++++-- test/continuous/test_ppo.py | 20 +++++++++++--------- test/continuous/test_redq.py | 12 ++++++++++-- test/continuous/test_sac_with_il.py | 23 ++++++++++++++++++++--- test/continuous/test_td3.py | 20 +++++++++++--------- test/continuous/test_trpo.py | 12 ++++++++++-- test/determinism_test.py | 4 ++++ test/discrete/test_a2c_with_il.py | 23 ++++++++++++++++++++--- test/discrete/test_bdq.py | 6 ++++++ test/discrete/test_c51.py | 12 ++++++++++-- test/discrete/test_drqn.py | 12 ++++++++++-- test/discrete/test_fqf.py | 12 ++++++++++-- test/discrete/test_iqn.py | 12 ++++++++++-- test/discrete/test_qrdqn.py | 12 ++++++++++-- test/discrete/test_rainbow.py | 12 ++++++++++-- test/discrete/test_sac.py | 15 +++++++++++++-- test/offline/test_bcq.py | 12 ++++++++++-- test/offline/test_cql.py | 11 +++++++++-- test/offline/test_discrete_bcq.py | 2 +- test/offline/test_discrete_cql.py | 18 ++++++++++++++---- test/offline/test_discrete_crr.py | 18 ++++++++++++++---- test/offline/test_gail.py | 12 ++++++++++-- test/offline/test_td3_bc.py | 11 +++++++++-- 24 files changed, 252 insertions(+), 63 deletions(-) diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index ce7998eff..3ec3fabbe 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -49,7 +50,7 @@ def get_args() -> 
argparse.Namespace: return parser.parse_known_args()[0] -def test_ddpg(args: argparse.Namespace = get_args()) -> None: +def test_ddpg(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape @@ -131,4 +132,11 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_ddpg_determinism(): + main_fn = lambda args: test_ddpg(args, enable_assertions=False) + AlgorithmDeterminismTest("continuous_ddpg", main_fn, get_args()).run() diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index d853e2186..185f4749a 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -52,7 +53,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_npg(args: argparse.Namespace = get_args()) -> None: +def test_npg(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) @@ -153,4 +154,11 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_npg_determinism(): + main_fn = lambda args: test_npg(args, enable_assertions=False) + AlgorithmDeterminismTest("continuous_npg", main_fn, get_args()).run() diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 4b56bd630..b045acbd9 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import 
AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -58,7 +59,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_ppo(args: argparse.Namespace = get_args()) -> None: +def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) @@ -166,7 +167,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: print("Fail to restore policy and optim.") # trainer - trainer = OnpolicyTrainer( + result = OnpolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, @@ -181,16 +182,17 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: logger=logger, resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, - ) - - for epoch_stat in trainer: - print(f"Epoch: {epoch_stat.epoch}") - print(epoch_stat) - # print(info) + ).run() - assert stop_fn(epoch_stat.info_stat.best_reward) + if enable_assertions: + assert stop_fn(result.best_reward) def test_ppo_resume(args: argparse.Namespace = get_args()) -> None: args.resume = True test_ppo(args) + + +def test_ppo_determinism(): + main_fn = lambda args: test_ppo(args, enable_assertions=False) + AlgorithmDeterminismTest("continuous_ppo", main_fn, get_args()).run() diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 82c8f0637..5dacdd869 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -54,7 +55,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_redq(args: argparse.Namespace = get_args()) -> None: +def test_redq(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) assert isinstance(env.action_space, 
gym.spaces.Box) space_info = SpaceInfo.from_env(env) @@ -162,4 +163,11 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_redq_determinism(): + main_fn = lambda args: test_redq(args, enable_assertions=False) + AlgorithmDeterminismTest("continuous_redq", main_fn, get_args()).run() diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 09fc3ca45..31103ae59 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -57,7 +58,11 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: +def test_sac_with_il( + args: argparse.Namespace = get_args(), + enable_assertions: bool = True, + skip_il: bool = False, +) -> None: # if you want to use python vector env, please refer to other test scripts # train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed) # test_envs = envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed) @@ -158,7 +163,12 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + if skip_il: + return # here we define an imitation collector with a trivial policy if args.task.startswith("Pendulum"): @@ -203,4 +213,11 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_sac_determinism(): + main_fn = lambda args: test_sac_with_il(args, 
enable_assertions=False, skip_il=True) + AlgorithmDeterminismTest("continuous_sac", main_fn, get_args()).run() diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 6c59ea25a..26c723f19 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -1,6 +1,6 @@ import argparse import os -import pprint +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -52,7 +52,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_td3(args: argparse.Namespace = get_args()) -> None: +def test_td3(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape @@ -135,7 +135,7 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # Iterator trainer - trainer = OffpolicyTrainer( + result = OffpolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, @@ -148,10 +148,12 @@ def stop_fn(mean_rewards: float) -> bool: stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, - ) - for epoch_stat in trainer: - print(f"Epoch: {epoch_stat.epoch}") - pprint.pprint(epoch_stat) - # print(info) + ).run() + + if enable_assertions: + assert stop_fn(result.best_reward) + - assert stop_fn(epoch_stat.info_stat.best_reward) +def test_td3_determinism(): + main_fn = lambda args: test_td3(args, enable_assertions=False) + AlgorithmDeterminismTest("continuous_td3", main_fn, get_args()).run() diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 91e215116..74a2f85f8 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -54,7 +55,7 @@ def get_args() -> argparse.Namespace: return 
parser.parse_known_args()[0] -def test_trpo(args: argparse.Namespace = get_args()) -> None: +def test_trpo(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape @@ -153,4 +154,11 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_trpo_determinism(): + main_fn = lambda args: test_trpo(args, enable_assertions=False) + AlgorithmDeterminismTest("continuous_trpo", main_fn, get_args()).run() diff --git a/test/determinism_test.py b/test/determinism_test.py index 7dfdb2dfb..f675e44bd 100644 --- a/test/determinism_test.py +++ b/test/determinism_test.py @@ -87,6 +87,10 @@ def set(attr: str, value: Any) -> None: self.main_fn = main_fn def run(self, update_snapshot: bool = False) -> None: + """ + :param update_snapshot: whether to update to snapshot (may be centrally overridden by + FORCE_SNAPSHOT_UPDATE) + """ if not self.ENABLED: pytest.skip("Algorithm determinism tests are disabled.") diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 4aceba169..cf0f4ef7c 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -59,7 +60,11 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None: +def test_a2c_with_il( + args: argparse.Namespace = get_args(), + enable_assertions: bool = True, + skip_il: bool = False, +) -> None: # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -144,7 +149,12 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, logger=logger, 
).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + if skip_il: + return # here we define an imitation collector with a trivial policy # if args.task == 'CartPole-v1': @@ -188,4 +198,11 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_ppo_determinism() -> None: + main_fn = lambda args: test_a2c_with_il(args, enable_assertions=False, skip_il=True) + AlgorithmDeterminismTest("discrete_a2c", main_fn, get_args()).run() diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index 16042f622..7a1f9650a 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -1,4 +1,5 @@ import argparse +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -142,3 +143,8 @@ def stop_fn(mean_rewards: float) -> bool: test_fn=test_fn, stop_fn=stop_fn, ).run() + + +def test_ppo_determinism() -> None: + main_fn = lambda args: test_bdq(args) + AlgorithmDeterminismTest("discrete_bdq", main_fn, get_args()).run() diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 41d6a0260..e7405ef47 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -1,6 +1,7 @@ import argparse import os import pickle +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -61,7 +62,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_c51(args: argparse.Namespace = get_args()) -> None: +def test_c51(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) @@ -200,7 +201,9 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: 
resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) def test_c51_resume(args: argparse.Namespace = get_args()) -> None: @@ -213,3 +216,8 @@ def test_pc51(args: argparse.Namespace = get_args()) -> None: args.gamma = 0.95 args.seed = 1 test_c51(args) + + +def test_c51_determinism(): + main_fn = lambda args: test_c51(args, enable_assertions=False) + AlgorithmDeterminismTest("discrete_c51", main_fn, get_args()).run() diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 4cc0b6bd0..8b7d7ed85 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -47,7 +48,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_drqn(args: argparse.Namespace = get_args()) -> None: +def test_drqn(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) @@ -129,4 +130,11 @@ def test_fn(epoch: int, env_step: int | None) -> None: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_drqn_determinism() -> None: + main_fn = lambda args: test_drqn(args, enable_assertions=False) + AlgorithmDeterminismTest("discrete_drqn", main_fn, get_args()).run() diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index e0899d315..72453ee38 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -60,7 +61,7 @@ def get_args() -> argparse.Namespace: 
return parser.parse_known_args()[0] -def test_fqf(args: argparse.Namespace = get_args()) -> None: +def test_fqf(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) assert isinstance(env.action_space, gym.spaces.Discrete) @@ -170,10 +171,17 @@ def test_fn(epoch: int, env_step: int | None) -> None: logger=logger, update_per_step=args.update_per_step, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) def test_pfqf(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 test_fqf(args) + + +def test_fqf_determinism() -> None: + main_fn = lambda args: test_fqf(args, enable_assertions=False) + AlgorithmDeterminismTest("discrete_fqf", main_fn, get_args()).run() diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 08f545b11..018c97b0f 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -60,7 +61,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_iqn(args: argparse.Namespace = get_args()) -> None: +def test_iqn(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) assert isinstance(env.action_space, gym.spaces.Discrete) @@ -166,10 +167,17 @@ def test_fn(epoch: int, env_step: int | None) -> None: logger=logger, update_per_step=args.update_per_step, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) def test_piqn(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 test_iqn(args) + + +def test_iqm_determinism() -> None: + main_fn = lambda args: test_iqn(args, enable_assertions=False) 
+ AlgorithmDeterminismTest("discrete_iqm", main_fn, get_args()).run() diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 5aa543fb5..1b66798f5 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -56,7 +57,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_qrdqn(args: argparse.Namespace = get_args()) -> None: +def test_qrdqn(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) @@ -159,10 +160,17 @@ def test_fn(epoch: int, env_step: int | None) -> None: logger=logger, update_per_step=args.update_per_step, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) def test_pqrdqn(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 test_qrdqn(args) + + +def test_qrdqn_determinism(): + main_fn = lambda args: test_qrdqn(args, enable_assertions=False) + AlgorithmDeterminismTest("discrete_qrdqn", main_fn, get_args()).run() diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index d7d4b15b1..1f94fd849 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -1,6 +1,7 @@ import argparse import os import pickle +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -64,7 +65,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_rainbow(args: argparse.Namespace = get_args()) -> None: +def test_rainbow(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) @@ -221,7 +222,9 @@ def save_checkpoint_fn(epoch: int, 
env_step: int, gradient_step: int) -> str: resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) def test_rainbow_resume(args: argparse.Namespace = get_args()) -> None: @@ -234,3 +237,8 @@ def test_prainbow(args: argparse.Namespace = get_args()) -> None: args.gamma = 0.95 args.seed = 1 test_rainbow(args) + + +def test_rainbow_determinism(): + main_fn = lambda args: test_rainbow(args, enable_assertions=False) + AlgorithmDeterminismTest("discrete_rainbow", main_fn, get_args()).run() diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 3409dab0a..99572f8fe 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -50,7 +51,10 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: +def test_discrete_sac( + args: argparse.Namespace = get_args(), + enable_assertions: bool = True, +) -> None: env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) @@ -140,4 +144,11 @@ def stop_fn(mean_rewards: float) -> bool: update_per_step=args.update_per_step, test_in_train=False, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_ppo_determinism() -> None: + main_fn = lambda args: test_discrete_sac(args, enable_assertions=False) + AlgorithmDeterminismTest("discrete_sac", main_fn, get_args()).run() diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 2ed910902..c409fdb3f 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -2,6 +2,7 @@ import datetime import os import pickle +from test.determinism_test import AlgorithmDeterminismTest from 
test.offline.gather_pendulum_data import expert_file_name, gather_data import gymnasium as gym @@ -61,7 +62,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_bcq(args: argparse.Namespace = get_args()) -> None: +def test_bcq(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) @@ -201,4 +202,11 @@ def watch() -> None: logger=logger, show_progress=args.show_progress, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_bcq_determinism() -> None: + main_fn = lambda args: test_bcq(args, enable_assertions=False) + AlgorithmDeterminismTest("offline_bcq", main_fn, get_args(), is_offline=True).run() diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index bd84098ba..e4be3ecbe 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -3,6 +3,7 @@ import os import pickle import pprint +from test.determinism_test import AlgorithmDeterminismTest from test.offline.gather_pendulum_data import expert_file_name, gather_data import gymnasium as gym @@ -66,7 +67,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_cql(args: argparse.Namespace = get_args()) -> None: +def test_cql(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) @@ -200,4 +201,10 @@ def stop_fn(mean_rewards: float) -> bool: pprint.pprint(epoch_stat) # print(info) - assert stop_fn(epoch_stat.info_stat.best_reward) + if enable_assertions: + assert stop_fn(epoch_stat.info_stat.best_reward) + + +def test_cql_determinism(): + 
main_fn = lambda args: test_cql(args, enable_assertions=False) + AlgorithmDeterminismTest("offline_cql", main_fn, get_args(), is_offline=True).run() diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index ba79f48fd..f7d34ba50 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -181,4 +181,4 @@ def test_discrete_bcq_resume(args: argparse.Namespace = get_args()) -> None: def test_discrete_bcq_determinism() -> None: main_fn = lambda args: test_discrete_bcq(args, enable_assertions=False) - AlgorithmDeterminismTest("discrete_bcq", main_fn, get_args(), is_offline=True).run() + AlgorithmDeterminismTest("offline_discrete_bcq", main_fn, get_args(), is_offline=True).run() diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 97766d494..373b9a074 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -1,6 +1,7 @@ import argparse import os import pickle +from test.determinism_test import AlgorithmDeterminismTest from test.offline.gather_cartpole_data import expert_file_name, gather_data import gymnasium as gym @@ -35,7 +36,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--target-update-freq", type=int, default=500) parser.add_argument("--min-q-weight", type=float, default=10.0) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--update-per-epoch", type=int, default=1000) + parser.add_argument("--step-per-epoch", type=int, default=1000) parser.add_argument("--batch-size", type=int, default=32) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64]) parser.add_argument("--test-num", type=int, default=100) @@ -50,7 +51,10 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: +def test_discrete_cql( + args: argparse.Namespace = get_args(), + enable_assertions: bool = True, +) -> None: # envs 
env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) @@ -118,7 +122,7 @@ def stop_fn(mean_rewards: float) -> bool: buffer=buffer, test_collector=test_collector, max_epoch=args.epoch, - step_per_epoch=args.update_per_epoch, + step_per_epoch=args.step_per_epoch, episode_per_test=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, @@ -126,4 +130,10 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() - assert stop_fn(result.best_reward) + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_discrete_cql_determinism() -> None: + main_fn = lambda args: test_discrete_cql(args, enable_assertions=False) + AlgorithmDeterminismTest("offline_discrete_cql", main_fn, get_args(), is_offline=True).run() diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index bf9a833a9..3593cf206 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -1,6 +1,7 @@ import argparse import os import pickle +from test.determinism_test import AlgorithmDeterminismTest from test.offline.gather_cartpole_data import expert_file_name, gather_data import gymnasium as gym @@ -33,7 +34,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--target-update-freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--update-per-epoch", type=int, default=1000) + parser.add_argument("--step-per-epoch", type=int, default=1000) parser.add_argument("--batch-size", type=int, default=64) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--test-num", type=int, default=100) @@ -48,7 +49,10 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: +def test_discrete_crr( + args: argparse.Namespace = get_args(), + enable_assertions: bool = 
True, +) -> None: # envs env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) @@ -122,7 +126,7 @@ def stop_fn(mean_rewards: float) -> bool: buffer=buffer, test_collector=test_collector, max_epoch=args.epoch, - step_per_epoch=args.update_per_epoch, + step_per_epoch=args.step_per_epoch, episode_per_test=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, @@ -130,4 +134,10 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() - assert stop_fn(result.best_reward) + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_discrete_crr_determinism() -> None: + main_fn = lambda args: test_discrete_crr(args, enable_assertions=False) + AlgorithmDeterminismTest("offline_discrete_crr", main_fn, get_args(), is_offline=True).run() diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index c7f183587..98a6b6c48 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -1,6 +1,7 @@ import argparse import os import pickle +from test.determinism_test import AlgorithmDeterminismTest from test.offline.gather_pendulum_data import expert_file_name, gather_data import gymnasium as gym @@ -61,7 +62,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_gail(args: argparse.Namespace = get_args()) -> None: +def test_gail(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) @@ -220,4 +221,11 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_gail_determinism() -> None: + main_fn = lambda args: test_gail(args, enable_assertions=False) 
+ AlgorithmDeterminismTest("offline_gail", main_fn, get_args(), is_offline=True).run() diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index 17c3afb06..a2b8744d8 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -2,6 +2,7 @@ import datetime import os import pickle +from test.determinism_test import AlgorithmDeterminismTest from test.offline.gather_pendulum_data import expert_file_name, gather_data import gymnasium as gym @@ -61,7 +62,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_td3_bc(args: argparse.Namespace = get_args()) -> None: +def test_td3_bc(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) @@ -187,4 +188,10 @@ def stop_fn(mean_rewards: float) -> bool: print(epoch_stat) # print(info) - assert stop_fn(epoch_stat.info_stat.best_reward) + if enable_assertions: + assert stop_fn(epoch_stat.info_stat.best_reward) + + +def test_discrete_bcq_determinism() -> None: + main_fn = lambda args: test_td3_bc(args, enable_assertions=False) + AlgorithmDeterminismTest("offline_td3_bc", main_fn, get_args(), is_offline=True).run() From 61c9fa335b644b5526c194ef6954f995830d75de Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 13 May 2025 01:00:52 +0200 Subject: [PATCH 24/56] Fix determinism test name --- test/discrete/test_iqn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 018c97b0f..523a99471 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -180,4 +180,4 @@ def test_piqn(args: argparse.Namespace = get_args()) -> None: def test_iqm_determinism() -> None: main_fn = lambda args: test_iqn(args, enable_assertions=False) - AlgorithmDeterminismTest("discrete_iqm", 
main_fn, get_args()).run() + AlgorithmDeterminismTest("discrete_iqn", main_fn, get_args()).run() From 2816d0470d180da0d8843622336a927aad69829a Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 13 May 2025 02:32:24 +0200 Subject: [PATCH 25/56] Fix test name --- test/discrete/test_sac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 99572f8fe..9e6a08dc9 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -149,6 +149,6 @@ def stop_fn(mean_rewards: float) -> bool: assert stop_fn(result.best_reward) -def test_ppo_determinism() -> None: +def test_discrete_sac_determinism() -> None: main_fn = lambda args: test_discrete_sac(args, enable_assertions=False) AlgorithmDeterminismTest("discrete_sac", main_fn, get_args()).run() From c7d48a3cb694a90f2d9689509acad69b7a5a38b3 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 13 May 2025 23:47:21 +0200 Subject: [PATCH 26/56] Add more trace log messages for context --- tianshou/trainer/base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 044c369d2..90dd2eb0e 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -264,6 +264,7 @@ def _reset_collectors(self, reset_buffer: bool = False) -> None: def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> None: """Initialize or reset the instance to yield a new iterator from zero.""" + TraceLogger.log(log, lambda: "Trainer reset") self.is_run = False self.env_step = 0 if self.resume_from_log: @@ -360,8 +361,10 @@ def __next__(self) -> EpochStats: # perform n step_per_epoch steps_done_in_this_epoch = 0 with self._pbar(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", position=1) as t: + TraceLogger.log(log, lambda: f"Epoch #{self.epoch} start") collect_stats: CollectStatsBase while steps_done_in_this_epoch < self.step_per_epoch and not self.stop_fn_flag: + 
TraceLogger.log(log, lambda: "Training step") collect_stats, training_stats, self.stop_fn_flag = self.training_step() TraceLogger.log( log, From 63c5e9550338c3cbfaefd48ca9840991fca8d09a Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 13 May 2025 23:50:01 +0200 Subject: [PATCH 27/56] Configure training eps value for initial data collection (DQN, BDQ) --- test/discrete/test_bdq.py | 4 +++- test/discrete/test_dqn.py | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index 7a1f9650a..657d0f0df 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -114,7 +114,9 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: exploration_noise=True, ) test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=False) - # policy.set_eps(1) + + # initial data collection + policy.set_eps(args.eps_train) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index eeb5e8207..5c706e884 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -107,12 +107,16 @@ def test_dqn(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) + # collector train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) - # policy.set_eps(1) + + # initial data collection + policy.set_eps(args.eps_train) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) + # log log_path = os.path.join(args.logdir, args.task, "dqn") writer = SummaryWriter(log_path) From b735e0b49c5826ca2fc5d2984e99120640c991d1 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 13 May 2025 23:52:59 +0200 Subject: [PATCH 28/56] Fix test names --- 
test/discrete/test_bdq.py | 2 +- test/discrete/test_iqn.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index 657d0f0df..719e10d86 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -147,6 +147,6 @@ def stop_fn(mean_rewards: float) -> bool: ).run() -def test_ppo_determinism() -> None: +def test_bdq_determinism() -> None: main_fn = lambda args: test_bdq(args) AlgorithmDeterminismTest("discrete_bdq", main_fn, get_args()).run() diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 523a99471..8fc669280 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -178,6 +178,6 @@ def test_piqn(args: argparse.Namespace = get_args()) -> None: test_iqn(args) -def test_iqm_determinism() -> None: +def test_iqn_determinism() -> None: main_fn = lambda args: test_iqn(args, enable_assertions=False) AlgorithmDeterminismTest("discrete_iqn", main_fn, get_args()).run() From 809279b686c45002b2ceaaac6480619c4160d68a Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 14 May 2025 00:07:18 +0200 Subject: [PATCH 29/56] Configure training eps value for initial data collection (C51, FQF, IQN, QRDQN, Rainbow) --- test/discrete/test_c51.py | 7 ++++++- test/discrete/test_fqf.py | 6 +++++- test/discrete/test_iqn.py | 7 ++++++- test/discrete/test_qrdqn.py | 6 +++++- test/discrete/test_rainbow.py | 6 +++++- 5 files changed, 27 insertions(+), 5 deletions(-) diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index e7405ef47..d0c9c7218 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -105,6 +105,7 @@ def test_c51(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) + # buffer buf: ReplayBuffer if args.prioritized_replay: @@ -116,12 +117,16 @@ def test_c51(args: argparse.Namespace = get_args(), enable_assertions: bool = 
Tr ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) + # collector train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) - # policy.set_eps(1) + + # initial data collection + policy.set_eps(args.eps_train) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) + # log log_path = os.path.join(args.logdir, args.task, "c51") writer = SummaryWriter(log_path) diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 72453ee38..2e56b4e38 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -124,12 +124,16 @@ def test_fqf(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) + # collector train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) - # policy.set_eps(1) + + # initial data collection + policy.set_eps(args.eps_train) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) + # log log_path = os.path.join(args.logdir, args.task, "fqf") writer = SummaryWriter(log_path) diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 8fc669280..57bf28e73 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -109,6 +109,7 @@ def test_iqn(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) + # buffer buf: ReplayBuffer if args.prioritized_replay: @@ -120,12 +121,16 @@ def test_iqn(args: argparse.Namespace = get_args(), enable_assertions: bool = Tr ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) + # collector train_collector = 
Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) - # policy.set_eps(1) + + # initial data collection + policy.set_eps(args.eps_train) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) + # log log_path = os.path.join(args.logdir, args.task, "iqn") writer = SummaryWriter(log_path) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 1b66798f5..3e2fa81db 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -113,12 +113,16 @@ def test_qrdqn(args: argparse.Namespace = get_args(), enable_assertions: bool = ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) + # collector train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) - # policy.set_eps(1) + + # initial data collection + policy.set_eps(args.eps_train) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) + # log log_path = os.path.join(args.logdir, args.task, "qrdqn") writer = SummaryWriter(log_path) diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 1f94fd849..d29870012 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -128,12 +128,16 @@ def noisy_linear(x: int, y: int) -> NoisyLinear: ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) + # collector train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) - # policy.set_eps(1) + + # initial data collection + policy.set_eps(args.eps_train) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) + # log log_path = os.path.join(args.logdir, 
args.task, "rainbow") writer = SummaryWriter(log_path) From 790dbb36bdc79a42d1a649cb9a83e10938bdbe9c Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 14 May 2025 21:09:40 +0200 Subject: [PATCH 30/56] TraceLogger: Add flag 'verbose' --- tianshou/utils/determinism.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tianshou/utils/determinism.py b/tianshou/utils/determinism.py index 9d713f1ef..11bb325bb 100644 --- a/tianshou/utils/determinism.py +++ b/tianshou/utils/determinism.py @@ -63,6 +63,10 @@ class TraceLogger: NOTE: The preferred way to enable this is via the context manager. """ + verbose = False + """ + whether to print trace log messages to stdout. + """ MESSAGE_TAG = "[TRACE]" """ a tag which is added at the beginning of log messages generated by this logger @@ -100,6 +104,8 @@ def log(cls, logger: logging.Logger, message_generator: Callable[[], str]) -> No stacklevel=2, ) cls.log_buffer.write(msg_formatted + "\n") + if cls.verbose: + print(msg_formatted) @dataclass From 3fc484a24918e9e7d4e59a32491c46236084f064 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 14 May 2025 22:16:02 +0200 Subject: [PATCH 31/56] v1: Removed unused and failing test --- test/base/test_batch.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 5fa40758f..f8af8c521 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -51,10 +51,6 @@ def test_batch() -> None: Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))]) with pytest.raises(TypeError): Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))]) - with pytest.raises(TypeError): - Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))]) - with pytest.raises(TypeError): - Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))]) assert torch.allclose(batch.a, torch.ones(2, 3)) batch.cat_(batch) assert torch.allclose(batch.a, torch.ones(4, 3)) From 5b46038582ab22d2c17c20bfb50a8791aa47c254 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 14 May 2025 
22:16:40 +0200 Subject: [PATCH 32/56] v1: minor type validation --- test/determinism_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/determinism_test.py b/test/determinism_test.py index f675e44bd..e3824a1c9 100644 --- a/test/determinism_test.py +++ b/test/determinism_test.py @@ -19,6 +19,9 @@ def __enter__(self): torch.set_deterministic_debug_mode(self.new_mode) def __exit__(self, exc_type, exc_value, traceback): + assert ( + self.original_mode is not None + ), "original_mode should not be None, did you enter the context?" torch.set_deterministic_debug_mode(self.original_mode) From f73b24772d1dd21e6b7fc5d8a57f9afb0035897d Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 14 May 2025 22:24:21 +0200 Subject: [PATCH 33/56] Fix test name --- test/offline/test_td3_bc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index a2b8744d8..aae4233e8 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -192,6 +192,6 @@ def stop_fn(mean_rewards: float) -> bool: assert stop_fn(epoch_stat.info_stat.best_reward) -def test_discrete_bcq_determinism() -> None: +def test_td3_bc_determinism() -> None: main_fn = lambda args: test_td3_bc(args, enable_assertions=False) AlgorithmDeterminismTest("offline_td3_bc", main_fn, get_args(), is_offline=True).run() From b6fe90e21f7db7e61672442238efa91d3156109d Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 14 May 2025 22:25:51 +0200 Subject: [PATCH 34/56] Use trainer run instead of direct iteration --- test/offline/test_cql.py | 8 ++------ test/offline/test_td3_bc.py | 8 ++------ 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index e4be3ecbe..c7d8b5e48 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -195,14 +195,10 @@ def stop_fn(mean_rewards: float) -> bool: stop_fn=stop_fn, logger=logger, ) - - for epoch_stat in trainer: - 
print(f"Epoch: {epoch_stat.epoch}") - pprint.pprint(epoch_stat) - # print(info) + stats = trainer.run() if enable_assertions: - assert stop_fn(epoch_stat.info_stat.best_reward) + assert stop_fn(stats.best_reward) def test_cql_determinism(): diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index aae4233e8..40b529c68 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -182,14 +182,10 @@ def stop_fn(mean_rewards: float) -> bool: stop_fn=stop_fn, logger=logger, ) - - for epoch_stat in trainer: - print(f"Epoch: {epoch_stat.epoch}") - print(epoch_stat) - # print(info) + stats = trainer.run() if enable_assertions: - assert stop_fn(epoch_stat.info_stat.best_reward) + assert stop_fn(stats.best_reward) def test_td3_bc_determinism() -> None: From 744561e82a334e86061965fbf21ebda21b5c1042 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 14 May 2025 22:45:11 +0200 Subject: [PATCH 35/56] Improve trace log message --- tianshou/policy/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 125bc84df..ced0043d1 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -543,7 +543,7 @@ def update( return TrainingStats() # type: ignore[return-value] start_time = time.time() batch, indices = buffer.sample(sample_size) - TraceLogger.log(logger, lambda: f"Updating with batch: {pickle_hash(indices)}") + TraceLogger.log(logger, lambda: f"Updating with batch: indices={pickle_hash(indices)}") self.updating = True batch = self.process_fn(batch, buffer, indices) with torch_train_mode(self): From a89cb14473fb09a991d655e4063c716db90d6b4c Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 15 May 2025 12:59:54 +0200 Subject: [PATCH 36/56] Improve change log --- CHANGELOG.md | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 229613e5e..c0d29b0dd 100644 --- a/CHANGELOG.md +++ 
b/CHANGELOG.md @@ -1,12 +1,12 @@ # Changelog -## Unreleased +## Upcoming Release 1.2.0 ### Changes/Improvements -- trainer: +- `trainer`: - Custom scoring now supported for selecting the best model. #1202 -- highlevel: +- `highlevel`: - `DiscreteSACExperimentBuilder`: Expose method `with_actor_factory_default` #1248 #1250 - `ActorFactoryDefault`: Fix parameters for hidden sizes and activation not being passed on in the discrete case (affects `with_actor_factory_default` method of experiment builders) @@ -19,22 +19,29 @@ - `NPGAgentFactory`, `TRPOAgentFactory`: Fix optimizer instantiation including the actor parameters (which was misleadingly suggested in the docstring in the respective policy classes; docstrings were fixed), as the actor parameters are intended to be handled via natural gradients internally - +- Tests: + - We have introduced extensive **determinism tests** which allow to validate whether + training processes deterministically compute the same results across different development branches. + This is an important step towards ensuring reproducibility and consistency, which will be + instrumental in supporting Tianshou developers in their work, especially in the context of + algorithm development and evaluation. + ### Breaking Changes -- trainer: +- `trainer`: - `BaseTrainer.run` and `__iter__`: Resetting was never optional prior to running the trainer, yet the recently introduced parameter `reset_prior_to_run` of `run` suggested that it _was_ optional. Yet the parameter was ultimately not respected, because `__iter__` would always call `reset(reset_collectors=True, reset_buffer=False)` regardless. The parameter was removed; instead, the parameters of `run` now mirror the parameters of `reset`, and the implicit `reset` call in `__iter__` was removed. This aligns with upcoming changes in Tianshou v2.0.0. 
- NOTE: If you have been using a trainer without calling `run` but by directly iterating over it, you - will need to call `reset` on the trainer explicitly before iterating over the trainer. -- data: - - stats: - - `InfoStats` has a new non-optional field `best_score` which is used - for selecting the best model. #1202 + * NOTE: If you have been using a trainer without calling `run` but by directly iterating over it, you + will need to call `reset` on the trainer explicitly before iterating over the trainer. + * Using a trainer as an iterator is considered deprecated and support for this will be removed in Tianshou v2.0.0. +- `data`: + - `InfoStats` has a new non-optional field `best_score` which is used + for selecting the best model. #1202 + ## Release 1.1.0 From cd57fa7df7f279657f4b78af60842b71806b4677 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 15 May 2025 13:07:25 +0200 Subject: [PATCH 37/56] Fix mypy issues --- test/continuous/test_ddpg.py | 2 +- test/continuous/test_npg.py | 2 +- test/continuous/test_ppo.py | 2 +- test/continuous/test_redq.py | 2 +- test/continuous/test_sac_with_il.py | 2 +- test/continuous/test_td3.py | 2 +- test/continuous/test_trpo.py | 2 +- test/determinism_test.py | 12 +++++------- test/discrete/test_c51.py | 2 +- test/discrete/test_qrdqn.py | 2 +- test/discrete/test_rainbow.py | 2 +- test/offline/test_cql.py | 3 +-- tianshou/trainer/base.py | 1 + 13 files changed, 17 insertions(+), 19 deletions(-) diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 3ec3fabbe..1569c0df1 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -137,6 +137,6 @@ def stop_fn(mean_rewards: float) -> bool: assert stop_fn(result.best_reward) -def test_ddpg_determinism(): +def test_ddpg_determinism() -> None: main_fn = lambda args: test_ddpg(args, enable_assertions=False) AlgorithmDeterminismTest("continuous_ddpg", main_fn, get_args()).run() diff --git a/test/continuous/test_npg.py 
b/test/continuous/test_npg.py index 185f4749a..3b413eec4 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -159,6 +159,6 @@ def stop_fn(mean_rewards: float) -> bool: assert stop_fn(result.best_reward) -def test_npg_determinism(): +def test_npg_determinism() -> None: main_fn = lambda args: test_npg(args, enable_assertions=False) AlgorithmDeterminismTest("continuous_npg", main_fn, get_args()).run() diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index b045acbd9..cbb8544ab 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -193,6 +193,6 @@ def test_ppo_resume(args: argparse.Namespace = get_args()) -> None: test_ppo(args) -def test_ppo_determinism(): +def test_ppo_determinism() -> None: main_fn = lambda args: test_ppo(args, enable_assertions=False) AlgorithmDeterminismTest("continuous_ppo", main_fn, get_args()).run() diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 5dacdd869..bb43033ed 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -168,6 +168,6 @@ def stop_fn(mean_rewards: float) -> bool: assert stop_fn(result.best_reward) -def test_redq_determinism(): +def test_redq_determinism() -> None: main_fn = lambda args: test_redq(args, enable_assertions=False) AlgorithmDeterminismTest("continuous_redq", main_fn, get_args()).run() diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 31103ae59..1d4fc06fe 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -218,6 +218,6 @@ def stop_fn(mean_rewards: float) -> bool: assert stop_fn(result.best_reward) -def test_sac_determinism(): +def test_sac_determinism() -> None: main_fn = lambda args: test_sac_with_il(args, enable_assertions=False, skip_il=True) AlgorithmDeterminismTest("continuous_sac", main_fn, get_args()).run() diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 
26c723f19..82fcce0fb 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -154,6 +154,6 @@ def stop_fn(mean_rewards: float) -> bool: assert stop_fn(result.best_reward) -def test_td3_determinism(): +def test_td3_determinism() -> None: main_fn = lambda args: test_td3(args, enable_assertions=False) AlgorithmDeterminismTest("continuous_td3", main_fn, get_args()).run() diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 74a2f85f8..321e351f3 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -159,6 +159,6 @@ def stop_fn(mean_rewards: float) -> bool: assert stop_fn(result.best_reward) -def test_trpo_determinism(): +def test_trpo_determinism() -> None: main_fn = lambda args: test_trpo(args, enable_assertions=False) AlgorithmDeterminismTest("continuous_trpo", main_fn, get_args()).run() diff --git a/test/determinism_test.py b/test/determinism_test.py index e3824a1c9..5cf3b8773 100644 --- a/test/determinism_test.py +++ b/test/determinism_test.py @@ -10,18 +10,16 @@ class TorchDeterministicModeContext: - def __init__(self, mode="default"): + def __init__(self, mode: str | int = "default") -> None: self.new_mode = mode - self.original_mode = None + self.original_mode: str | int | None = None - def __enter__(self): + def __enter__(self) -> None: self.original_mode = torch.get_deterministic_debug_mode() torch.set_deterministic_debug_mode(self.new_mode) - def __exit__(self, exc_type, exc_value, traceback): - assert ( - self.original_mode is not None - ), "original_mode should not be None, did you enter the context?" 
+ def __exit__(self, exc_type, exc_value, traceback): # type: ignore + assert self.original_mode is not None torch.set_deterministic_debug_mode(self.original_mode) diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index d0c9c7218..2876c4406 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -223,6 +223,6 @@ def test_pc51(args: argparse.Namespace = get_args()) -> None: test_c51(args) -def test_c51_determinism(): +def test_c51_determinism() -> None: main_fn = lambda args: test_c51(args, enable_assertions=False) AlgorithmDeterminismTest("discrete_c51", main_fn, get_args()).run() diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 3e2fa81db..afa2592c4 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -175,6 +175,6 @@ def test_pqrdqn(args: argparse.Namespace = get_args()) -> None: test_qrdqn(args) -def test_qrdqn_determinism(): +def test_qrdqn_determinism() -> None: main_fn = lambda args: test_qrdqn(args, enable_assertions=False) AlgorithmDeterminismTest("discrete_qrdqn", main_fn, get_args()).run() diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index d29870012..92d10b06a 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -243,6 +243,6 @@ def test_prainbow(args: argparse.Namespace = get_args()) -> None: test_rainbow(args) -def test_rainbow_determinism(): +def test_rainbow_determinism() -> None: main_fn = lambda args: test_rainbow(args, enable_assertions=False) AlgorithmDeterminismTest("discrete_rainbow", main_fn, get_args()).run() diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index c7d8b5e48..ea8b6ac11 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -2,7 +2,6 @@ import datetime import os import pickle -import pprint from test.determinism_test import AlgorithmDeterminismTest from test.offline.gather_pendulum_data import expert_file_name, gather_data @@ -201,6 +200,6 @@ def 
stop_fn(mean_rewards: float) -> bool: assert stop_fn(stats.best_reward) -def test_cql_determinism(): +def test_cql_determinism() -> None: main_fn = lambda args: test_cql(args, enable_assertions=False) AlgorithmDeterminismTest("offline_cql", main_fn, get_args(), is_offline=True).run() diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 90dd2eb0e..00928783f 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -366,6 +366,7 @@ def __next__(self) -> EpochStats: while steps_done_in_this_epoch < self.step_per_epoch and not self.stop_fn_flag: TraceLogger.log(log, lambda: "Training step") collect_stats, training_stats, self.stop_fn_flag = self.training_step() + assert training_stats is not None TraceLogger.log( log, lambda: f"Training step complete: stats={training_stats.get_loss_stats_dict()}", From 2b57654089135573226ddddd312da3f93d112b04 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 16 May 2025 19:58:28 +0200 Subject: [PATCH 38/56] Relax determinism tests: * Require only core message equivalence (network parameter hashes) for the test to pass * Allow to ignore certain messages on a per-test level --- test/continuous/test_redq.py | 7 +- test/determinism_test.py | 20 +++++- tianshou/utils/determinism.py | 120 ++++++++++++++++++++-------------- 3 files changed, 95 insertions(+), 52 deletions(-) diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index bb43033ed..b8f5d2e8e 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -170,4 +170,9 @@ def stop_fn(mean_rewards: float) -> bool: def test_redq_determinism() -> None: main_fn = lambda args: test_redq(args, enable_assertions=False) - AlgorithmDeterminismTest("continuous_redq", main_fn, get_args()).run() + ignored_messages = [ + "Params[actor_old]" + ] # actor_old only present in v1 (due to flawed inheritance) + AlgorithmDeterminismTest( + "continuous_redq", main_fn, get_args(), ignored_messages=ignored_messages + ).run() diff 
--git a/test/determinism_test.py b/test/determinism_test.py index 5cf3b8773..828825660 100644 --- a/test/determinism_test.py +++ b/test/determinism_test.py @@ -1,5 +1,5 @@ from argparse import Namespace -from collections.abc import Callable +from collections.abc import Callable, Sequence from pathlib import Path from typing import Any @@ -49,6 +49,13 @@ class AlgorithmDeterminismTest: Enable this when running on the "old" branch and you want to prepare the snapshots for a comparison with the "new" branch. """ + PASS_IF_CORE_MESSAGES_UNCHANGED = True + """ + whether to pass the test if only the core messages are unchanged. + If this is False, then the full log is required to be equivalent, whereas if it is True, + only the core messages need to be equivalent. + The core messages test whether the algorithm produces the same network parameters. + """ def __init__( self, @@ -56,6 +63,7 @@ def __init__( main_fn: Callable[[Namespace], Any], args: Namespace, is_offline: bool = False, + ignored_messages: Sequence[str] = (), ): """ :param name: the (unique!) 
name of the test @@ -64,10 +72,13 @@ def __init__( for the test) :param is_offline: whether the algorithm being tested is an offline algorithm and therefore does not configure the number of training environments (`training_num`) + :param ignored_messages: message fragments to ignore in the trace log (if any) """ self.determinism_test = TraceDeterminismTest( base_path=Path(__file__).parent / "resources" / "determinism", log_filename="determinism_tests.log", + core_messages=["Params"], + ignored_messages=ignored_messages, ) self.name = name @@ -104,4 +115,9 @@ def run(self, update_snapshot: bool = False) -> None: self.main_fn(self.args) log = trace.get_log() - self.determinism_test.check(log, self.name, create_reference_result=update_snapshot) + self.determinism_test.check( + log, + self.name, + create_reference_result=update_snapshot, + pass_if_core_messages_unchanged=self.PASS_IF_CORE_MESSAGES_UNCHANGED, + ) diff --git a/tianshou/utils/determinism.py b/tianshou/utils/determinism.py index 11bb325bb..77149f4a5 100644 --- a/tianshou/utils/determinism.py +++ b/tianshou/utils/determinism.py @@ -141,14 +141,19 @@ def filter_messages( self, required_messages: Sequence[str] = (), optional_messages: Sequence[str] = (), + ignored_messages: Sequence[str] = (), ) -> "TraceLog": """ - Reduces the set of log messages to a set of core messages that indicate that the fundamental - trace is still the same (same actions, same states, same images). + Applies inclusion and or exclusion filtering to the log messages. + If either `required_messages` or `optional_messages` is empty, inclusion filtering is applied. + If `ignored_messages` is empty, exclusion filtering is applied. + If both inclusion and exclusion filtering are applied, the exclusion filtering takes precedence. 
- :param required_messages: message substrings to filter for; each message is required to appear at least once + :param required_messages: required message substrings to filter for; each message is required to appear at least once (triggering exception otherwise) :param optional_messages: additional messages fragments to filter for; these are not required + :param ignored_messages: message fragments that result in exclusion; takes precedence over + `required_messages` and `optional_messages` :return: the result with reduced log messages """ import numpy as np @@ -156,11 +161,17 @@ def filter_messages( required_message_counters = np.zeros(len(required_messages)) def retain_line(line: str) -> bool: - for i, main_message in enumerate(required_messages): - if main_message in line: - required_message_counters[i] += 1 - return True - return any(add_message in line for add_message in optional_messages) + for ignored_message in ignored_messages: + if ignored_message in line: + return False + if required_messages or optional_messages: + for i, main_message in enumerate(required_messages): + if main_message in line: + required_message_counters[i] += 1 + return True + return any(add_message in line for add_message in optional_messages) + else: + return True lines = [] for line in self.log_lines: @@ -241,16 +252,20 @@ def __init__( self, base_path: Path, core_messages: Sequence[str] = (), + ignored_messages: Sequence[str] = (), log_filename: str | None = None, ) -> None: """ :param base_path: the directory where the reference results are stored (will be created if necessary) :param core_messages: message fragments that make up the core of a trace; if empty, all messages are considered core + :param ignored_messages: message fragments to ignore in the trace log (if any); takes precedence over + `core_messages` :param log_filename: the name of the log file to which results are to be written (if any) """ base_path.mkdir(parents=True, exist_ok=True) self.base_path = base_path 
self.core_messages = core_messages + self.ignored_messages = ignored_messages self.log_filename = log_filename @dataclass(kw_only=True) @@ -263,6 +278,7 @@ def check( current_log: TraceLog, name: str, create_reference_result: bool = False, + pass_if_core_messages_unchanged: bool = False, ) -> None: """ Checks the given log against the reference result for the given name. @@ -285,8 +301,12 @@ def check( ) reference_log = reference_result.log - current_log_reduced = current_log.reduce_log_to_messages() - reference_log_reduced = reference_log.reduce_log_to_messages() + current_log_reduced = current_log.reduce_log_to_messages().filter_messages( + ignored_messages=self.ignored_messages + ) + reference_log_reduced = reference_log.reduce_log_to_messages().filter_messages( + ignored_messages=self.ignored_messages + ) results: list[tuple[TraceLog, str]] = [ (reference_log_reduced, "expected"), @@ -312,55 +332,57 @@ def check( result_main_messages = current_log_reduced reference_result_main_messages = reference_log_reduced - status_passed = True logs_equivalent = current_log_reduced.get_full_log() == reference_log_reduced.get_full_log() if logs_equivalent: status_passed = True status_message = "OK" else: - status_passed = False - - # save files for comparison - files = [] - for r, suffix in results: - path = os.path.abspath(f"determinism_{name}_{suffix}.txt") - r.save_log(path) - files.append(path) - - paths_str = "\n".join(files) - main_message = ( - f"Please inspect the changes by diffing the log files:\n{paths_str}\n" - f"If the changes are OK, enable the `create_reference_result` flag temporarily, " - "rerun the test and then commit the updated reference file.\n\nHere's the first part of the diff:\n" - ) - - # compute diff and add to message - num_diff_lines_to_show = 30 - for i, line in enumerate( - difflib.unified_diff( - reference_log_reduced.log_lines, - current_log_reduced.log_lines, - fromfile="expected.txt", - tofile="current.txt", - lineterm="", - ), - ): - if i 
== num_diff_lines_to_show: - break - main_message += line + "\n" - - core_messages_changed_only = ( + core_messages_unchanged = ( len(self.core_messages) > 0 and result_main_messages.get_full_log() == reference_result_main_messages.get_full_log() ) - if core_messages_changed_only: - status_message = ( - "The behaviour log has changed, but the core messages are still the same (so this " - f"probably isn't an issue). {main_message}" - ) + status_passed = core_messages_unchanged and pass_if_core_messages_unchanged + + if status_passed: + status_message = "OK (core messages unchanged)" else: - status_message = f"The behaviour log has changed; even the core messages are different. {main_message}" + # save files for comparison + files = [] + for r, suffix in results: + path = os.path.abspath(f"determinism_{name}_{suffix}.txt") + r.save_log(path) + files.append(path) + + paths_str = "\n".join(files) + main_message = ( + f"Please inspect the changes by diffing the log files:\n{paths_str}\n" + f"If the changes are OK, enable the `create_reference_result` flag temporarily, " + "rerun the test and then commit the updated reference file.\n\nHere's the first part of the diff:\n" + ) + + # compute diff and add to message + num_diff_lines_to_show = 30 + for i, line in enumerate( + difflib.unified_diff( + reference_log_reduced.log_lines, + current_log_reduced.log_lines, + fromfile="expected.txt", + tofile="current.txt", + lineterm="", + ), + ): + if i == num_diff_lines_to_show: + break + main_message += line + "\n" + + if core_messages_unchanged: + status_message = ( + "The behaviour log has changed, but the core messages are still the same (so this " + f"probably isn't an issue). {main_message}" + ) + else: + status_message = f"The behaviour log has changed; even the core messages are different. 
{main_message}" # write log message if self.log_filename: From 0c385f98c1d6ea28a8b2119e7623496e809596bc Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 16 May 2025 20:44:46 +0200 Subject: [PATCH 39/56] test_drqn: Collect initial data in training mode --- test/discrete/test_drqn.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 8b7d7ed85..d3fcdbd89 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -83,6 +83,7 @@ def test_drqn(args: argparse.Namespace = get_args(), enable_assertions: bool = T action_space=env.action_space, target_update_freq=args.target_update_freq, ) + # collector buffer = VectorReplayBuffer( args.buffer_size, @@ -91,11 +92,15 @@ def test_drqn(args: argparse.Namespace = get_args(), enable_assertions: bool = T ignore_obs_next=True, ) train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + # the stack_num is for RNN training: sample framestack obs test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) - # policy.set_eps(1) + + # initial data collection + policy.set_eps(args.eps_train) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) + # log log_path = os.path.join(args.logdir, args.task, "drqn") writer = SummaryWriter(log_path) From a4e81eaa603ec320a4d87479d8e9329004c0c300 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 16 May 2025 22:47:54 +0200 Subject: [PATCH 40/56] Formatting --- test/continuous/test_redq.py | 7 +++++-- tianshou/utils/determinism.py | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index b8f5d2e8e..24a7f420c 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -171,8 +171,11 @@ def stop_fn(mean_rewards: float) -> bool: def test_redq_determinism() -> None: main_fn = lambda args: 
test_redq(args, enable_assertions=False) ignored_messages = [ - "Params[actor_old]" + "Params[actor_old]", ] # actor_old only present in v1 (due to flawed inheritance) AlgorithmDeterminismTest( - "continuous_redq", main_fn, get_args(), ignored_messages=ignored_messages + "continuous_redq", + main_fn, + get_args(), + ignored_messages=ignored_messages, ).run() diff --git a/tianshou/utils/determinism.py b/tianshou/utils/determinism.py index 77149f4a5..5747eeb12 100644 --- a/tianshou/utils/determinism.py +++ b/tianshou/utils/determinism.py @@ -302,10 +302,10 @@ def check( reference_log = reference_result.log current_log_reduced = current_log.reduce_log_to_messages().filter_messages( - ignored_messages=self.ignored_messages + ignored_messages=self.ignored_messages, ) reference_log_reduced = reference_log.reduce_log_to_messages().filter_messages( - ignored_messages=self.ignored_messages + ignored_messages=self.ignored_messages, ) results: list[tuple[TraceLog, str]] = [ From eaa7f96cae5b869e78b75eae575ac5d190d4bf56 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 16 May 2025 22:48:05 +0200 Subject: [PATCH 41/56] Fix assertion (stats can be None) --- tianshou/trainer/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 00928783f..b5ea2de34 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -366,10 +366,9 @@ def __next__(self) -> EpochStats: while steps_done_in_this_epoch < self.step_per_epoch and not self.stop_fn_flag: TraceLogger.log(log, lambda: "Training step") collect_stats, training_stats, self.stop_fn_flag = self.training_step() - assert training_stats is not None TraceLogger.log( log, - lambda: f"Training step complete: stats={training_stats.get_loss_stats_dict()}", + lambda: f"Training step complete: stats={training_stats.get_loss_stats_dict() if training_stats else None}", ) self._log_params(self.policy) From af0a959903f1eeb410fcfa01982bd87df76e865f Mon Sep 
17 00:00:00 2001 From: Dominik Jain Date: Mon, 19 May 2025 12:24:42 +0200 Subject: [PATCH 42/56] Fix create_toc_py not accounting for spaces in paths --- docs/create_toc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/create_toc.py b/docs/create_toc.py index 3e1779cdf..c6add58e5 100644 --- a/docs/create_toc.py +++ b/docs/create_toc.py @@ -3,6 +3,6 @@ # This script provides a platform-independent way of making the jupyter-book call (used in pyproject.toml) toc_file = Path(__file__).parent / "_toc.yml" -cmd = f"jupyter-book toc from-project docs -e .rst -e .md -e .ipynb >{toc_file}" +cmd = f'jupyter-book toc from-project docs -e .rst -e .md -e .ipynb >"{toc_file}"' print(cmd) os.system(cmd) From 619051c366efb0a835e7ea66120570bf66a78e79 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 19 May 2025 12:31:36 +0200 Subject: [PATCH 43/56] Fix unquoted maths in docstring --- tianshou/highlevel/params/alpha.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py index 70fff0942..1c5d60438 100644 --- a/tianshou/highlevel/params/alpha.py +++ b/tianshou/highlevel/params/alpha.py @@ -25,8 +25,8 @@ def __init__(self, lr: float = 3e-4, target_entropy_coefficient: float = -1.0): """ :param lr: the learning rate for the optimizer of the alpha parameter :param target_entropy_coefficient: the coefficient with which to multiply the target entropy; - The base value being scaled is dim(A) for continuous action spaces and log(|A|) for discrete action spaces, - i.e. with the default coefficient -1, we obtain -dim(A) and -log(dim(A)) for continuous and discrete action + The base value being scaled is `dim(A)` for continuous action spaces and `log(|A|)` for discrete action spaces, + i.e. 
with the default coefficient -1, we obtain `-dim(A)` and `-log(dim(A))` for continuous and discrete action spaces respectively, which gives a reasonable trade-off between exploration and exploitation. For decidedly stochastic exploration, you can use a positive value closer to 1 (e.g. 0.98); 1.0 would give full entropy exploration. From cf22adf08dc806f6e1edd968b284bf3a6be0d654 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sat, 17 May 2025 14:51:03 +0200 Subject: [PATCH 44/56] v1: improvement in doc-build commands --- docs/autogen_rst.py | 4 +- docs/spelling_wordlist.txt | 294 ------------------------------------- pyproject.toml | 5 +- 3 files changed, 4 insertions(+), 299 deletions(-) delete mode 100644 docs/spelling_wordlist.txt diff --git a/docs/autogen_rst.py b/docs/autogen_rst.py index b1a8b18d9..1477b2117 100644 --- a/docs/autogen_rst.py +++ b/docs/autogen_rst.py @@ -114,8 +114,8 @@ def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix="" for f in files_in_dir if os.path.isdir(os.path.join(root, dirname, f)) and not f.startswith("_") ] - if not module_names: - log.debug(f"Skipping {dirname} as it does not contain any .py files") + if not module_names and not "__init__.py" in files_in_dir: + log.debug(f"Skipping {dirname} as it does not contain any modules or __init__.py") continue package_qualname = f"{base_package_qualname}.{dirname}" package_index_rst_path = os.path.join( diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt deleted file mode 100644 index 40aa69970..000000000 --- a/docs/spelling_wordlist.txt +++ /dev/null @@ -1,294 +0,0 @@ -tianshou -arXiv -tanh -lr -logits -env -envs -optim -eps -timelimit -TimeLimit -envpool -EnvPool -maxsize -timestep -timesteps -numpy -ndarray -stackoverflow -tensorboard -state_dict -len -tac -fqf -iqn -qrdqn -rl -offpolicy -onpolicy -quantile -quantiles -dqn -param -async -subprocess -deque -nn -equ -cql -fn -boolean -pre -np -cuda -rnn -rew -pre -perceptron -bsz -dataset 
-mujoco -jit -nstep -preprocess -preprocessing -repo -ReLU -namespace -recv -th -utils -NaN -linesearch -hyperparameters -pseudocode -entropies -nn -config -cpu -rms -debias -indice -regularizer -miniblock -modularize -serializable -softmax -vectorized -optimizers -undiscounted -submodule -subclasses -submodules -tfevent -dirichlet -docstring -webpage -formatter -num -py -pythonic -中文文档位于 -conda -miniconda -Amir -Andreas -Antonoglou -Beattie -Bellemare -Charles -Daan -Demis -Dharshan -Fidjeland -Georg -Hassabis -Helen -Ioannis -Kavukcuoglu -King -Koray -Kumaran -Legg -Mnih -Ostrovski -Petersen -Riedmiller -Rusu -Sadik -Shane -Stig -Veness -Volodymyr -Wierstra -Lillicrap -Pritzel -Heess -Erez -Yuval -Tassa -Schulman -Filip -Wolski -Prafulla -Dhariwal -Radford -Oleg -Klimov -Kaichao -Jiayi -Weng -Duburcq -Huayu -Yi -Su -Strens -Ornstein -Uhlenbeck -mse -gail -airl -ppo -Jupyter -Colab -Colaboratory -IPendulum -Reacher -Runtime -Nvidia -Enduro -Qbert -Seaquest -subnets -subprocesses -isort -yapf -pydocstyle -Args -tuples -tuple -Multi -multi -parameterized -Proximal -metadata -GPU -Dopamine -builtin -params -inplace -deepcopy -Gaussian -stdout -parallelization -minibatch -minibatches -MLP -backpropagation -dataclass -superset -subtype -subdirectory -picklable -ShmemVectorEnv -Github -wandb -jupyter -img -src -parallelized -infty -venv -venvs -subproc -bcq -highlevel -icm -modelbased -td -psrl -ddpg -npg -tf -trpo -crr -pettingzoo -multidiscrete -vecbuf -prio -colab -segtree -multiagent -mapolicy -sensai -sensAI -docstrings -superclass -iterable -functools -str -sklearn -attr -bc -redq -modelfree -bdq -util -logp -autogenerated -subpackage -subpackages -recurse -rollout -rollouts -prepend -prepends -dict -dicts -pytorch -tensordict -onwards -Dominik -Tsinghua -Tianshou -appliedAI -macOS -joblib -master -Panchenko -BA -BH -BO -BD -configs -postfix -backend -rliable -hl -v_s -v_s_ -obs -obs_next -dtype -iqm -kwarg -entrypoint -interquantile -init -kwarg -kwargs 
-autocompletion -codebase -indexable -sliceable -gaussian -logprob -monte -carlo -subclass -subclassing -dist -dists -subbuffer -subbuffers diff --git a/pyproject.toml b/pyproject.toml index 178222f9a..d4535f096 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -228,13 +228,12 @@ _poetry_sort = "poetry sort" clean-nbs = "python docs/nbstripout.py" format = ["_ruff_format", "_ruff_format_nb", "_black_format", "_poetry_install_sort_plugin", "_poetry_sort"] _autogen_rst = "python docs/autogen_rst.py" -_sphinx_build = "sphinx-build -W -b html docs docs/_build" +_sphinx_build = "sphinx-build -b html docs docs/_build -W --keep-going" _jb_generate_toc = "python docs/create_toc.py" _jb_generate_config = "jupyter-book config sphinx docs/" doc-clean = "rm -rf docs/_build" doc-generate-files = ["_autogen_rst", "_jb_generate_toc", "_jb_generate_config"] -doc-spellcheck = "sphinx-build -W -b spelling docs docs/_build" -doc-build = ["doc-generate-files", "doc-spellcheck", "_sphinx_build"] +doc-build = ["doc-generate-files", "_sphinx_build"] _mypy = "mypy tianshou test examples" _mypy_nb = "nbqa mypy docs" type-check = ["_mypy", "_mypy_nb"] From 3192dbfb0e299819e7cf828a58ae8d182319f251 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 19 May 2025 14:03:11 +0200 Subject: [PATCH 45/56] Fix ruff complaint --- docs/autogen_rst.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/autogen_rst.py b/docs/autogen_rst.py index 1477b2117..93e2a0954 100644 --- a/docs/autogen_rst.py +++ b/docs/autogen_rst.py @@ -114,7 +114,7 @@ def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix="" for f in files_in_dir if os.path.isdir(os.path.join(root, dirname, f)) and not f.startswith("_") ] - if not module_names and not "__init__.py" in files_in_dir: + if not module_names and "__init__.py" not in files_in_dir: log.debug(f"Skipping {dirname} as it does not contain any modules or __init__.py") continue package_qualname = 
f"{base_package_qualname}.{dirname}" From de78ecb3eb037b839e8a744e962179cf238db760 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 19 May 2025 17:45:25 +0200 Subject: [PATCH 46/56] Document determinism test usage --- docs/04_contributing/04_contributing.rst | 44 ++++++++++++++++++------ 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/docs/04_contributing/04_contributing.rst b/docs/04_contributing/04_contributing.rst index 1397e3473..bf831fb39 100644 --- a/docs/04_contributing/04_contributing.rst +++ b/docs/04_contributing/04_contributing.rst @@ -2,11 +2,12 @@ Contributing to Tianshou ======================== -Install Develop Version ------------------------ +Install Development Environment +------------------------------- Tianshou is built and managed by `poetry `_. For example, -to install all relevant requirements in editable mode you can simply call +to install all relevant requirements (and install Tianshou itself in editable mode) +you can simply call .. code-block:: bash @@ -36,9 +37,9 @@ Please set up pre-commit by running in the main directory. This should make sure that your contribution is properly formatted before every commit. -The code is inspected and formatted by `black` and `ruff`. They are executed as -pre-commit hooks. In addition, `poe the poet` tasks are configured. -Simply run `poe` to see the available tasks. +The code is inspected and formatted by ``black`` and ``ruff``. They are executed as +pre-commit hooks. In addition, ``poe the poet`` tasks are configured. +Simply run ``poe`` to see the available tasks. E.g, to format and check the linting manually you can run: .. code-block:: bash @@ -47,8 +48,8 @@ E.g, to format and check the linting manually you can run: $ poe lint -Type Check ----------- +Type Checks +----------- We use `mypy `_ to check the type annotations. To check, in the main directory, run: @@ -57,8 +58,8 @@ We use `mypy `_ to check the type annotations. 
$ poe type-check -Test Locally ------------ +Testing Locally +--------------- This command will run automatic tests in the main directory @@ -67,6 +68,29 @@ This command will run automatic tests in the main directory $ poe test +Determinism Tests +~~~~~~~~~~~~~~~~~ + +We implemented "determinism tests" for Tianshou's algorithms, which allow us to determine +whether algorithms still compute exactly the same results even after large refactorings. +These tests are applied by + 1. creating a behavior snapshot in the old code branch before the changes and then + 2. running the test in the new branch to ensure that the behavior is the same. + +Unfortunately, full determinism is difficult to achieve across different platforms and even different +machines using the same platform and Python environment. +Therefore, these tests are not carried out in the CI pipeline. +Instead, it is up to the developer to run them locally and check the results whenever a change +is made to the code base that could affect algorithm behavior. + +Technically, the two steps are handled by setting static flags in class ``AlgorithmDeterminismTest`` and then +running either the full test suite or a specific determinism test (``test_*_determinism``, e.g. ``test_ddpg_determinism``) +in the two branches to be compared. + + 1. On the old branch: (Temporarily) set ``ENABLED=True`` and ``FORCE_SNAPSHOT_UPDATE=True`` and run the test(s). + 2. On the new branch: (Temporarily) set ``ENABLED=True`` and ``FORCE_SNAPSHOT_UPDATE=False`` and run the test(s). + 3.
Inspect the test results; find a summary in ``determinism_tests.log`` + Test by GitHub Actions ---------------------- From 802fb83a4273d696c6930963428a36b2da24e3e2 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 19 May 2025 18:01:08 +0200 Subject: [PATCH 47/56] Mentioned determinism tests in PR template --- .github/PULL_REQUEST_TEMPLATE.md | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index e1c3f2b3c..364f89649 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -2,6 +2,7 @@ - [ ] I have provided a description of the changes in this Pull Request - [ ] I have added documentation for my changes and have listed relevant changes in CHANGELOG.md - [ ] If applicable, I have added tests to cover my changes. +- [ ] If applicable, I have made sure that the determinism tests run through, meaning that my changes haven't influenced any aspect of training. See info in the contributing documentation. 
- [ ] I have reformatted the code using `poe format` - [ ] I have checked style and types with `poe lint` and `poe type-check` - [ ] (Optional) I ran tests locally with `poe test` From 2cd40cb3aed0e6810a760160a7889d48a0de4db3 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 19 May 2025 18:37:04 +0200 Subject: [PATCH 48/56] Allow collection of empty episodes (done on reset) Slightly enhanced docstrings in collector --- CHANGELOG.md | 2 ++ tianshou/data/buffer/base.py | 5 ++--- tianshou/data/collector.py | 13 ++++++++++++- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c0d29b0dd..611593f9a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ - `NPGAgentFactory`, `TRPOAgentFactory`: Fix optimizer instantiation including the actor parameters (which was misleadingly suggested in the docstring in the respective policy classes; docstrings were fixed), as the actor parameters are intended to be handled via natural gradients internally +- `data`: + - `ReplayBuffer`: Fix collection of empty episodes being disallowed - Tests: - We have introduced extensive **determinism tests** which allow to validate whether training processes deterministically compute the same results across different development branches. diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py index aad0c6f2b..72c7af5bb 100644 --- a/tianshou/data/buffer/base.py +++ b/tianshou/data/buffer/base.py @@ -145,8 +145,7 @@ def _get_start_stop_tuples_for_edge_crossing_interval( if stop >= start: raise ValueError( f"Expected stop < start, but got {start=}, {stop=}. " - f"For stop larger than start this method should never be called, " - f"and stop=start should never occur. This can occur either due to an implementation error, " + f"For stop larger-equal than start this method should never be called. 
This can occur either due to an implementation error, " f"or due a bad configuration of the buffer that resulted in a single episode being so long that " f"it completely filled a subbuffer (of size len(buffer)/degree_of_vectorization). " f"Consider either shortening the episode, increasing the size of the buffer, or decreasing the " @@ -213,7 +212,7 @@ def get_buffer_indices(self, start: int, stop: int) -> np.ndarray: f"Start and stop indices must be within the same subbuffer. " f"Got {start=} in subbuffer edge {start_left_edge} and {stop=} in subbuffer edge {stop_left_edge}.", ) - if stop > start: + if stop >= start: return np.arange(start, stop, dtype=int) else: (start, upper_edge), ( diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index d615af4e8..f436cac48 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -42,6 +42,8 @@ _TArrLike = TypeVar("_TArrLike", bound="np.ndarray | torch.Tensor | Batch | None") +TScalarArrayShape = TypeVar("TScalarArrayShape") + class CollectActionBatchProtocol(Protocol): """A protocol for results of computing actions from a batch of observations within a single collect step. @@ -778,10 +780,13 @@ def _collect( # noqa: C901 # TODO: can't do it init since AsyncCollector is currently a subclass of Collector if self.env.is_async: raise ValueError( - f"Please use {AsyncCollector.__name__} for asynchronous environments. " + f"Please use AsyncCollector for asynchronous environments. " f"Env class: {self.env.__class__.__name__}.", ) + ready_env_ids_R: np.ndarray[Any, np.dtype[np.signedinteger]] + """provides a mapping from local indices (indexing within `1, ..., R` where `R` is the number of ready envs) + to global ones (indexing within `1, ..., num_envs`). 
So the entry i in this array is the global index of the i-th ready env.""" if n_step is not None: ready_env_ids_R = np.arange(self.env_num) elif n_episode is not None: @@ -915,6 +920,8 @@ def _collect( # noqa: C901 # local_idx - see block comment on class level # Step 7 env_done_local_idx_D = np.where(done_R)[0] + """Indexes which episodes are done within the ready envs, so it can be used for selecting from `..._R` arrays. + Stands in contrast to the "global" index, which counts within all envs and is unsuitable for selecting from `..._R` arrays.""" episode_lens_D = ep_len_R[env_done_local_idx_D] episode_returns_D = ep_return_R[env_done_local_idx_D] episode_start_indices_D = ep_start_idx_R[env_done_local_idx_D] @@ -933,6 +940,10 @@ def _collect( # noqa: C901 # 0,...,R and this global index is maintained by the ready_env_ids_R array. # See the class block comment for more details env_done_global_idx_D = ready_env_ids_R[env_done_local_idx_D] + """Indexes which episodes are done within all envs, i.e., within the index `1, ..., num_envs`. It can be + used to communicate with the vector env, where env ids are selected from this "global" index. + Is not suited for selecting from the ready envs (`..._R` arrays), use the local counterpart instead. + """ obs_reset_DO, info_reset_D = self.env.reset( env_id=env_done_global_idx_D, **gym_reset_kwargs, From 981e649f26f7a112802e159aae531edf8be0aac1 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 19 May 2025 19:59:48 +0200 Subject: [PATCH 49/56] High-level API: Change the way in which seeding is handled The mechanism introduced in v1.1.0 was completely revised: - The `train_seed` and `test_seed` attributes were removed from `SamplingConfig`. Instead, the seeds are derived from the seed defined in `ExperimentConfig`. - Seed attributes of `EnvFactory` classes were removed. Instead, seeds are passed to methods of `EnvFactory`. 
--- CHANGELOG.md | 8 ++- examples/atari/atari_dqn_hl.py | 2 - examples/atari/atari_iqn_hl.py | 2 - examples/atari/atari_ppo_hl.py | 2 - examples/atari/atari_sac_hl.py | 2 - examples/atari/atari_wrapper.py | 10 +-- examples/mujoco/mujoco_a2c_hl.py | 7 +- examples/mujoco/mujoco_ddpg_hl.py | 7 +- examples/mujoco/mujoco_env.py | 19 ++--- examples/mujoco/mujoco_npg_hl.py | 7 +- examples/mujoco/mujoco_ppo_hl.py | 7 +- examples/mujoco/mujoco_ppo_hl_multi.py | 9 +-- examples/mujoco/mujoco_redq_hl.py | 7 +- examples/mujoco/mujoco_reinforce_hl.py | 7 +- examples/mujoco/mujoco_sac_hl.py | 7 +- examples/mujoco/mujoco_td3_hl.py | 7 +- examples/mujoco/mujoco_trpo_hl.py | 7 +- test/highlevel/env_factory.py | 4 -- tianshou/env/worker/base.py | 9 ++- tianshou/highlevel/config.py | 7 -- tianshou/highlevel/env.py | 98 +++++++++++++++++++++----- tianshou/highlevel/experiment.py | 12 +--- 22 files changed, 115 insertions(+), 132 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 611593f9a..e4d14ae09 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,7 +43,13 @@ - `data`: - `InfoStats` has a new non-optional field `best_score` which is used for selecting the best model. #1202 - +- `highlevel`: + - Change the way in which seeding is handled: The mechanism introduced in v1.1.0 + was completely revised: + - The `train_seed` and `test_seed` attributes were removed from `SamplingConfig`. + Instead, the seeds are derived from the seed defined in `ExperimentConfig`. + - Seed attributes of `EnvFactory` classes were removed. + Instead, seeds are passed to methods of `EnvFactory`. 
## Release 1.1.0 diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 601481523..3bcb0f6c3 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -69,8 +69,6 @@ def main( env_factory = AtariEnvFactory( task, - sampling_config.train_seed, - sampling_config.test_seed, frames_stack, scale=scale_obs, ) diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index b71b0eef3..c644b2469 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -68,8 +68,6 @@ def main( env_factory = AtariEnvFactory( task, - sampling_config.train_seed, - sampling_config.test_seed, frames_stack, scale=scale_obs, ) diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 26ebaba08..983608293 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -73,8 +73,6 @@ def main( env_factory = AtariEnvFactory( task, - sampling_config.train_seed, - sampling_config.test_seed, frames_stack, scale=scale_obs, ) diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 211e07d57..76f18f55f 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -68,8 +68,6 @@ def main( env_factory = AtariEnvFactory( task, - sampling_config.train_seed, - sampling_config.test_seed, frames_stack, scale=scale_obs, ) diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index d7234d863..db6c6dcd8 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -383,8 +383,8 @@ def make_atari_env( :return: a tuple of (single env, training envs, test envs). 
""" - env_factory = AtariEnvFactory(task, seed, seed + training_num, frame_stack, scale=bool(scale)) - envs = env_factory.create_envs(training_num, test_num) + env_factory = AtariEnvFactory(task, frame_stack, scale=bool(scale)) + envs = env_factory.create_envs(training_num, test_num, seed=seed) return envs.env, envs.train_envs, envs.test_envs @@ -392,8 +392,6 @@ class AtariEnvFactory(EnvFactoryRegistered): def __init__( self, task: str, - train_seed: int, - test_seed: int, frame_stack: int, scale: bool = False, use_envpool_if_available: bool = True, @@ -411,13 +409,11 @@ def __init__( log.info("Not using envpool, because it is not available") super().__init__( task=task, - train_seed=train_seed, - test_seed=test_seed, venv_type=venv_type, envpool_factory=envpool_factory, ) - def create_env(self, mode: EnvMode) -> gym.Env: + def _create_env(self, mode: EnvMode) -> gym.Env: env = super().create_env(mode) is_train = mode == EnvMode.TRAIN return wrap_deepmind( diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index c804d6c26..bba7f9e76 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -54,12 +54,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory( - task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, - obs_norm=True, - ) + env_factory = MujocoEnvFactory(task, obs_norm=True) experiment = ( A2CExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index 27dbfc8d9..daa936533 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -52,12 +52,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory( - task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, - obs_norm=False, - ) + env_factory = MujocoEnvFactory(task, obs_norm=False) experiment = ( 
DDPGExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py index 90f27995b..8aff92793 100644 --- a/examples/mujoco/mujoco_env.py +++ b/examples/mujoco/mujoco_env.py @@ -37,7 +37,7 @@ def make_mujoco_env( :return: a tuple of (single env, training envs, test envs). """ - envs = MujocoEnvFactory(task, seed, seed + num_train_envs, obs_norm=obs_norm).create_envs( + envs = MujocoEnvFactory(task, obs_norm=obs_norm).create_envs( num_train_envs, num_test_envs, ) @@ -73,28 +73,18 @@ class MujocoEnvFactory(EnvFactoryRegistered): def __init__( self, task: str, - train_seed: int, - test_seed: int, obs_norm: bool = True, venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO, ) -> None: super().__init__( task=task, - train_seed=train_seed, - test_seed=test_seed, venv_type=venv_type, envpool_factory=EnvPoolFactory() if envpool_is_available else None, ) self.obs_norm = obs_norm - def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv: - """Create vectorized environments. 
- - :param num_envs: the number of environments - :param mode: the mode for which to create - :return: the vectorized environments - """ - env = super().create_venv(num_envs, mode) + def create_venv(self, num_envs: int, mode: EnvMode, seed: int | None = None) -> BaseVectorEnv: + env = super().create_venv(num_envs, mode, seed=seed) # obs norm wrapper if self.obs_norm: env = VectorEnvNormObs(env, update_obs_rms=mode == EnvMode.TRAIN) @@ -105,8 +95,9 @@ def create_envs( num_training_envs: int, num_test_envs: int, create_watch_env: bool = False, + seed: int | None = None, ) -> ContinuousEnvironments: - envs = super().create_envs(num_training_envs, num_test_envs, create_watch_env) + envs = super().create_envs(num_training_envs, num_test_envs, create_watch_env, seed=seed) assert isinstance(envs, ContinuousEnvironments) if self.obs_norm: diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index 387f87c6e..a231e1b21 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -53,12 +53,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory( - task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, - obs_norm=True, - ) + env_factory = MujocoEnvFactory(task, obs_norm=True) experiment = ( NPGExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index b10d4cf26..973a822a6 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -58,12 +58,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory( - task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, - obs_norm=True, - ) + env_factory = MujocoEnvFactory(task, obs_norm=True) experiment = ( PPOExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py 
b/examples/mujoco/mujoco_ppo_hl_multi.py index 333870809..47ccc9ae2 100644 --- a/examples/mujoco/mujoco_ppo_hl_multi.py +++ b/examples/mujoco/mujoco_ppo_hl_multi.py @@ -37,7 +37,7 @@ def main( num_experiments: int = 5, run_experiments_sequentially: bool = True, - logger_type: str = "wandb", + logger_type: str = "tensorboard", ) -> RLiableExperimentResult: """:param num_experiments: the number of experiments to run. The experiments differ exclusively in the seeds. :param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel. @@ -70,12 +70,7 @@ def main( repeat_per_collect=1, ) - env_factory = MujocoEnvFactory( - task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, - obs_norm=True, - ) + env_factory = MujocoEnvFactory(task, obs_norm=True) hidden_sizes = (64, 64) diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index c0c63279a..90f6ef318 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -58,12 +58,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory( - task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, - obs_norm=False, - ) + env_factory = MujocoEnvFactory(task, obs_norm=False) experiment = ( REDQExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index 59a600568..f3e8821ae 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -49,12 +49,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory( - task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, - obs_norm=True, - ) + env_factory = MujocoEnvFactory(task, obs_norm=True) experiment = ( PGExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git 
a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index a150f5571..5b2e1519b 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -53,12 +53,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory( - task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, - obs_norm=False, - ) + env_factory = MujocoEnvFactory(task, obs_norm=False) experiment = ( SACExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 5ec9cc17b..8ca54d591 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -58,12 +58,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory( - task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, - obs_norm=False, - ) + env_factory = MujocoEnvFactory(task, obs_norm=False) experiment = ( TD3ExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 1ec26bad2..4dfc39185 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -55,12 +55,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory( - task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, - obs_norm=True, - ) + env_factory = MujocoEnvFactory(task, obs_norm=True) experiment = ( TRPOExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/test/highlevel/env_factory.py b/test/highlevel/env_factory.py index 4a131e5fd..1d649f15e 100644 --- a/test/highlevel/env_factory.py +++ b/test/highlevel/env_factory.py @@ -8,8 +8,6 @@ class DiscreteTestEnvFactory(EnvFactoryRegistered): def __init__(self) -> None: super().__init__( task="CartPole-v1", - train_seed=42, - test_seed=1337, venv_type=VectorEnvType.DUMMY, ) @@ 
-18,7 +16,5 @@ class ContinuousTestEnvFactory(EnvFactoryRegistered): def __init__(self) -> None: super().__init__( task="Pendulum-v1", - train_seed=42, - test_seed=1337, venv_type=VectorEnvType.DUMMY, ) diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index ac35ccf3f..ca31b9ac3 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -69,7 +69,14 @@ def wait( raise NotImplementedError def seed(self, seed: int | None = None) -> list[int] | None: - return self.action_space.seed(seed) # issue 299 + """ + Seeds the environment's action space sampler. + NOTE: This does *not* seed the environment itself. + + :param seed: the random seed + :return: a list containing the resulting seed used + """ + return self.action_space.seed(seed) @abstractmethod def render(self, **kwargs: Any) -> Any: diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py index ac27cba1a..fb58c8a58 100644 --- a/tianshou/highlevel/config.py +++ b/tianshou/highlevel/config.py @@ -56,9 +56,6 @@ class SamplingConfig(ToStringMixin): num_train_envs: int = -1 """the number of training environments to use. If set to -1, use number of CPUs/threads.""" - train_seed: int = 42 - """the seed to use for the training environments.""" - num_test_envs: int = 1 """the number of test environments to use""" @@ -165,10 +162,6 @@ class SamplingConfig(ToStringMixin): Currently only used in Atari examples and may be removed in the future! 
""" - @property - def test_seed(self) -> int: - return self.train_seed + self.num_train_envs - def __post_init__(self) -> None: if self.num_train_envs == -1: self.num_train_envs = multiprocessing.cpu_count() diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index e61e0ed36..d4acd65be 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -7,6 +7,7 @@ import gymnasium as gym import gymnasium.spaces +import numpy as np from gymnasium import Env from sensai.util.pickle import setstate from sensai.util.string import ToStringMixin @@ -370,11 +371,57 @@ def __init__(self, venv_type: VectorEnvType): """ self.venv_type = venv_type + @staticmethod + def _create_rng(seed: int | None) -> np.random.Generator: + """ + Creates a random number generator with the given seed. + + :param seed: the seed to use; if None, a random seed will be used + :return: the random number generator + """ + return np.random.default_rng(seed=seed) + + @staticmethod + def _next_seed(rng: np.random.Generator) -> int: + """ + Samples a random seed from the given random number generator. + + :param rng: the random number generator + :return: the sampled random seed + """ + return int(rng.integers(0, 2**64, dtype=np.uint64)) + @abstractmethod - def create_env(self, mode: EnvMode) -> Env: - pass + def _create_env(self, mode: EnvMode) -> Env: + """Creates a single environment for the given mode. + + :param mode: the mode + :return: an environment + """ + + def create_env(self, mode: EnvMode, seed: int | None = None) -> Env: + """ + Creates a single environment for the given mode. + + :param mode: the mode + :param seed: the random seed to use for the environment; if None, the seed will not be specified, + and gymnasium will use a random seed. 
+ :return: the environment + """ + env = self._create_env(mode) - def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv: + # initialize the environment with the given seed (if any) + if seed is not None: + rng = self._create_rng(seed) + env.np_random = rng + # also set the seed member within the environment such that it can be retrieved + # (gymnasium's random seed handling is, unfortunately, broken) + if hasattr(env, "_np_random_seed"): + env._np_random_seed = seed + + return env + + def create_venv(self, num_envs: int, mode: EnvMode, seed: int | None = None) -> BaseVectorEnv: """Create vectorized environments. :param num_envs: the number of environments @@ -383,28 +430,47 @@ def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv: :return: the vectorized environments """ + rng = self._create_rng(seed) + + def create_factory_fn() -> Callable[[], Env]: + # create a factory function that uses a sampled random seed + return lambda random_seed=self._next_seed(rng): self.create_env(mode, seed=random_seed) # type: ignore + + # create the vectorized environment, seeded appropriately if mode == EnvMode.WATCH: - return VectorEnvType.DUMMY.create_venv([lambda: self.create_env(mode)]) + venv = VectorEnvType.DUMMY.create_venv([create_factory_fn()]) else: - return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs) + venv = self.venv_type.create_venv([create_factory_fn() for _ in range(num_envs)]) + + # seed the action samplers + venv.seed([self._next_seed(rng) for _ in range(num_envs)]) + + return venv def create_envs( self, num_training_envs: int, num_test_envs: int, create_watch_env: bool = False, + seed: int | None = None, ) -> Environments: """Create environments for learning. 
:param num_training_envs: the number of training environments :param num_test_envs: the number of test environments :param create_watch_env: whether to create an environment for watching the agent + :param seed: the random seed to use for environment creation :return: the environments """ + rng = self._create_rng(seed) env = self.create_env(EnvMode.TRAIN) - train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN) - test_envs = self.create_venv(num_test_envs, EnvMode.TEST) - watch_env = self.create_venv(1, EnvMode.WATCH) if create_watch_env else None + train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN, seed=self._next_seed(rng)) + test_envs = self.create_venv(num_test_envs, EnvMode.TEST, seed=self._next_seed(rng)) + watch_env = ( + self.create_venv(1, EnvMode.WATCH, seed=self._next_seed(rng)) + if create_watch_env + else None + ) match EnvType.from_env(env): case EnvType.DISCRETE: return DiscreteEnvironments(env, train_envs, test_envs, watch_env) @@ -423,8 +489,6 @@ def __init__( self, *, task: str, - train_seed: int, - test_seed: int, venv_type: VectorEnvType, envpool_factory: EnvPoolFactory | None = None, render_mode_train: str | None = None, @@ -444,8 +508,6 @@ def __init__( super().__init__(venv_type) self.task = task self.envpool_factory = envpool_factory - self.train_seed = train_seed - self.test_seed = test_seed self.render_modes = { EnvMode.TRAIN: render_mode_train, EnvMode.TEST: render_mode_test, @@ -476,7 +538,7 @@ def _create_kwargs(self, mode: EnvMode) -> dict: kwargs["render_mode"] = self.render_modes.get(mode) return kwargs - def create_env(self, mode: EnvMode) -> Env: + def _create_env(self, mode: EnvMode) -> Env: """Creates a single environment for the given mode. 
:param mode: the mode @@ -485,17 +547,15 @@ def create_env(self, mode: EnvMode) -> Env: kwargs = self._create_kwargs(mode) return gymnasium.make(self.task, **kwargs) - def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv: - seed = self.train_seed if mode == EnvMode.TRAIN else self.test_seed + def create_venv(self, num_envs: int, mode: EnvMode, seed: int | None = None) -> BaseVectorEnv: if self.envpool_factory is not None: + rng = self._create_rng(seed) return self.envpool_factory.create_venv( self.task, num_envs, mode, - seed, + self._next_seed(rng), self._create_kwargs(mode), ) else: - venv = super().create_venv(num_envs, mode) - venv.seed(seed) - return venv + return super().create_venv(num_envs, mode, seed=seed) diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index a908e1d06..4df648fa9 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -219,13 +219,7 @@ def get_seeding_info_as_str(self) -> str: This can be useful for creating unique experiment names based on seeds, e.g. A typical example is to do `experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`. """ - return "_".join( - [ - f"exp_seed={self.config.seed}", - f"train_seed={self.sampling_config.train_seed}", - f"test_seed={self.sampling_config.test_seed}", - ], - ) + return f"exp_seed={self.config.seed}" def _set_seed(self) -> None: seed = self.config.seed @@ -298,6 +292,7 @@ def create_experiment_world( self.sampling_config.num_train_envs, self.sampling_config.num_test_envs, create_watch_env=self.config.watch, + seed=self.config.seed, ) log.info(f"Created {envs}") @@ -672,13 +667,10 @@ def build_seeded_collection(self, num_experiments: int) -> ExperimentCollection: Each experiment in the collection will have a unique name created from the original experiment name and the seeds used. 
""" - num_train_envs = self.sampling_config.num_train_envs - seeded_experiments = [] for i in range(num_experiments): builder = self.copy() builder.experiment_config.seed += i - builder.sampling_config.train_seed += i * num_train_envs experiment = builder.build() experiment.name += f"_{experiment.get_seeding_info_as_str()}" seeded_experiments.append(experiment) From 5f5bab96bcc17d1bef1ae0c1b378551cad11434b Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Mon, 19 May 2025 20:55:02 +0200 Subject: [PATCH 50/56] Fix syntax issue --- docs/04_contributing/04_contributing.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/04_contributing/04_contributing.rst b/docs/04_contributing/04_contributing.rst index bf831fb39..48cf172c8 100644 --- a/docs/04_contributing/04_contributing.rst +++ b/docs/04_contributing/04_contributing.rst @@ -74,6 +74,7 @@ Determinism Tests We implemented "determinism tests" for Tianshou's algorithms, which allow us to determine whether algorithms still compute exactly the same results even after large refactorings. These tests are applied by + 1. creating a behavior snapshot ine the old code branch before the changes and then 2. running the test in the new branch to ensure that the behavior is the same. 
From a86e246cfee15d8d03fce6bfcea1f6a77fe37f9f Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 19 Jun 2025 10:34:46 +0200 Subject: [PATCH 51/56] AtariEnvFactory: Fix super call --- examples/atari/atari_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index db6c6dcd8..25a3b09f2 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -414,7 +414,7 @@ def __init__( ) def _create_env(self, mode: EnvMode) -> gym.Env: - env = super().create_env(mode) + env = super()._create_env(mode) is_train = mode == EnvMode.TRAIN return wrap_deepmind( env, From 856e2b89e365b3f83a896444fb4310c0c35f0318 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sat, 21 Jun 2025 14:57:52 +0200 Subject: [PATCH 52/56] v1: adjust range for seed to be compatible with envpool --- tianshou/highlevel/env.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index d4acd65be..45d85d9d5 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -389,7 +389,8 @@ def _next_seed(rng: np.random.Generator) -> int: :param rng: the random number generator :return: the sampled random seed """ - return int(rng.integers(0, 2**64, dtype=np.uint64)) + # int32 is needed for envpool compatibility + return int(rng.integers(0, 2**31, dtype=np.int32)) @abstractmethod def _create_env(self, mode: EnvMode) -> Env: From d8daab20c8a7a96ffcdda4675ca95afa11d6000d Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sat, 21 Jun 2025 15:03:45 +0200 Subject: [PATCH 53/56] v1: disable buffer hasnull checks by default Control validation enabling with global flag --- tianshou/config.py | 2 ++ tianshou/data/collector.py | 33 +++++++++++++++++++++++++++++---- tianshou/trainer/base.py | 2 +- 3 files changed, 32 insertions(+), 5 deletions(-) create mode 100644 tianshou/config.py diff --git a/tianshou/config.py 
b/tianshou/config.py new file mode 100644 index 000000000..23cbb0cb2 --- /dev/null +++ b/tianshou/config.py @@ -0,0 +1,2 @@ +ENABLE_VALIDATION = False +"""Validation can help catching bugs and issues but it slows down training and collection. Enable it only if needed.""" diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index f436cac48..3c6d75d4d 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -12,6 +12,7 @@ from overrides import override from torch.distributions import Categorical, Distribution +from tianshou.config import ENABLE_VALIDATION from tianshou.data import ( Batch, CachedReplayBuffer, @@ -318,8 +319,32 @@ def __init__( exploration_noise: bool = False, # The typing is correct, there's a bug in mypy, see https://github.com/python/mypy/issues/3737 collect_stats_class: type[TCollectStats] = CollectStats, # type: ignore[assignment] - raise_on_nan_in_buffer: bool = True, + raise_on_nan_in_buffer: bool = ENABLE_VALIDATION, ) -> None: + """ + :param policy: a tianshou policy, each :class:`BasePolicy` is capable of computing a batch + of actions from a batch of observations. + :param env: a ``gymnasium.Env`` environment or a vectorized instance of the + :class:`~tianshou.env.BaseVectorEnv` class. The latter is strongly recommended, as with + a gymnasium env the collection will not happen in parallel (a `DummyVectorEnv` + will be constructed internally from the passed env) + :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. + If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` + of size :data:`DEFAULT_BUFFER_MAXSIZE` * (number of envs) + as the default buffer. + :param exploration_noise: determine whether the action needs to be modified + with the corresponding policy's exploration noise. If so, "policy. + exploration_noise(act, batch)" will be called automatically to add the + exploration noise into action. 
+ the rollout batch with this hook also modifies the data that is collected to the buffer! + :param raise_on_nan_in_buffer: whether to raise a `RuntimeError` if NaNs are found in the buffer after + a collection step. Especially useful when episode-level hooks are passed for making + sure that nothing is broken during the collection. Consider setting to False if + the NaN-check becomes a bottleneck. + :param collect_stats_class: the class to use for collecting statistics. Allows customizing + the stats collection logic by passing a subclass of :class:`CollectStats`. Changing + this is rarely necessary and is mainly done by "power users". + """ if isinstance(env, gym.Env) and not hasattr(env, "__len__"): warnings.warn("Single environment detected, wrap to DummyVectorEnv.") # Unfortunately, mypy seems to ignore the isinstance in lambda, maybe a bug in mypy @@ -557,7 +582,7 @@ def __init__( exploration_noise: bool = False, on_episode_done_hook: Optional["EpisodeRolloutHookProtocol"] = None, on_step_hook: Optional["StepHookProtocol"] = None, - raise_on_nan_in_buffer: bool = True, + raise_on_nan_in_buffer: bool = ENABLE_VALIDATION, collect_stats_class: type[TCollectStats] = CollectStats, # type: ignore[assignment] ) -> None: """ @@ -574,7 +599,7 @@ def __init__( :param exploration_noise: determine whether the action needs to be modified with the corresponding policy's exploration noise. If so, "policy. exploration_noise(act, batch)" will be called automatically to add the - exploration noise into action.. + exploration noise into action. :param on_episode_done_hook: if passed will be executed when an episode is done. The input to the hook will be a `RolloutBatch` that contains the entire episode (and nothing else). 
If a dict is returned by the hook it will be used to add new entries to the buffer @@ -1045,7 +1070,7 @@ def _collect( # noqa: C901 break # Check if we screwed up somewhere - if self.buffer.hasnull(): + if self.raise_on_nan_in_buffer and self.buffer.hasnull(): nan_batch = self.buffer.isnull().apply_values_transform(np.sum) raise MalformedBufferError( diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index b5ea2de34..ec4645741 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -543,7 +543,7 @@ def _collect_training_data(self) -> CollectStats: lambda: f"Collected {collect_stats.n_collected_steps} steps, {collect_stats.n_collected_episodes} episodes", ) - if self.train_collector.buffer.hasnull(): + if self.train_collector.raise_on_nan_in_buffer and self.train_collector.buffer.hasnull(): from tianshou.data.collector import EpisodeRolloutHook from tianshou.env import DummyVectorEnv From 0db2e7458c0fdb058717fe1ea549c773ff439927 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sat, 21 Jun 2025 16:05:36 +0200 Subject: [PATCH 54/56] v1: fixes in rliable eval data loading, better logging Don't mutate incoming dict, don't load invalid fields --- tianshou/evaluation/rliable_evaluation_hl.py | 26 +++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/tianshou/evaluation/rliable_evaluation_hl.py b/tianshou/evaluation/rliable_evaluation_hl.py index 2b8ff5131..dc6205cc6 100644 --- a/tianshou/evaluation/rliable_evaluation_hl.py +++ b/tianshou/evaluation/rliable_evaluation_hl.py @@ -44,13 +44,18 @@ def from_data_dict(cls, data: dict) -> "LoggedCollectStats": Converts SequenceSummaryStats from dict format to dataclass format and ignores fields that are not present. 
""" + dataclass_data = {} field_names = [f.name for f in fields(cls)] for k, v in data.items(): if k not in field_names: - data.pop(k) + log.info( + f"Key {k} in data dict is not a valid field of LoggedCollectStats, ignoring it.", + ) + continue if isinstance(v, dict): - data[k] = LoggedSummaryData(**v) - return cls(**data) + v = LoggedSummaryData(**v) + dataclass_data[k] = v + return cls(**dataclass_data) @dataclass @@ -114,14 +119,23 @@ def load_from_disk( data = logger_cls.restore_logged_data(entry.path) # TODO: align low-level and high-level dir structure. This is a hack! if not data: + log.info( + f"Could not find data in {entry.path}, trying to restore from subdirectory.", + ) dirs = [ d for d in os.listdir(entry.path) if os.path.isdir(os.path.join(entry.path, d)) ] if len(dirs) != 1: - raise ValueError( - f"Could not restore data from {entry.path}, " - f"expected either events or exactly one subdirectory, ", + _error_message = ( + f"Could not restore experiment data from {entry.path}, " + f"expected either events or exactly one subdirectory, but got {dirs=}. " ) + if not dirs: + _error_message += ( + "The absence of events/subdirectory may be due to an error causing the training to stop or due to" + " too few environment steps, leading to no data being logged." 
+ ) + raise ValueError(_error_message) data = logger_cls.restore_logged_data(os.path.join(entry.path, dirs[0])) if not data: raise ValueError(f"Could not restore data from {entry.path}.") From fd93ab35dae877bdfbf436b48b83a8bcfb20d4fe Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 23 Jun 2025 12:20:43 +0200 Subject: [PATCH 55/56] v1: replace all isinstance checks from BatchProtocol to Batch Seriously improves performance of Batch constructor --- test/base/test_collector.py | 2 +- tianshou/data/batch.py | 14 +++++++------- tianshou/data/buffer/her.py | 8 ++++---- tianshou/policy/modelbased/psrl.py | 3 ++- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 6355d8bfc..1e4769bbf 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -72,7 +72,7 @@ def forward( if self.dict_state: if self.action_shape: action_shape = self.action_shape - elif isinstance(batch.obs, BatchProtocol): + elif isinstance(batch.obs, Batch): action_shape = len(batch.obs["index"]) else: action_shape = len(batch.obs) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 70478a87d..c7ee7505b 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -983,7 +983,7 @@ def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None: self.__dict__[key][sum_lens[i] : sum_lens[i + 1]] = value def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None: - if isinstance(batches, BatchProtocol | dict): + if isinstance(batches, Batch | dict): batches = [batches] # check input format batch_list = [] @@ -1069,7 +1069,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None { batch_key for batch_key, obj in batch.items() - if not (isinstance(obj, BatchProtocol) and len(obj.get_keys()) == 0) + if not (isinstance(obj, Batch) and len(obj.get_keys()) == 0) } for batch in batches ] @@ -1080,7 +1080,7 @@ def stack_(self, batches: 
Sequence[dict | BatchProtocol], axis: int = 0) -> None if all(isinstance(element, torch.Tensor) for element in value): self.__dict__[shared_key] = torch.stack(value, axis) # third often - elif all(isinstance(element, BatchProtocol | dict) for element in value): + elif all(isinstance(element, Batch | dict) for element in value): self.__dict__[shared_key] = Batch.stack(value, axis) else: # most often case is np.ndarray try: @@ -1114,7 +1114,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None value = batch.get(key) # TODO: fix code/annotations s.t. the ignores can be removed if ( - isinstance(value, BatchProtocol) # type: ignore + isinstance(value, Batch) # type: ignore and len(value.get_keys()) == 0 # type: ignore ): continue # type: ignore @@ -1288,7 +1288,7 @@ def set_array_at_key( ) from exception else: existing_entry = self[key] - if isinstance(existing_entry, BatchProtocol): + if isinstance(existing_entry, Batch): raise ValueError( f"Cannot set sequence at key {key} because it is a nested batch, " f"can only set a subsequence of an array.", @@ -1312,7 +1312,7 @@ def hasnull(self) -> bool: def is_any_true(boolean_batch: BatchProtocol) -> bool: for val in boolean_batch.values(): - if isinstance(val, BatchProtocol): + if isinstance(val, Batch): if is_any_true(val): return True else: @@ -1375,7 +1375,7 @@ def _apply_batch_values_func_recursively( """ result = batch if inplace else deepcopy(batch) for key, val in batch.__dict__.items(): - if isinstance(val, BatchProtocol): + if isinstance(val, Batch): result[key] = _apply_batch_values_func_recursively(val, values_transform, inplace=False) else: result[key] = values_transform(val) diff --git a/tianshou/data/buffer/her.py b/tianshou/data/buffer/her.py index 087f8d0b0..eb03c1595 100644 --- a/tianshou/data/buffer/her.py +++ b/tianshou/data/buffer/her.py @@ -150,12 +150,12 @@ def rewrite_transitions(self, indices: np.ndarray) -> None: ep_obs = self[unique_ep_indices].obs # to satisfy mypy 
# TODO: add protocol covering these batches - assert isinstance(ep_obs, BatchProtocol) + assert isinstance(ep_obs, Batch) ep_rew = self[unique_ep_indices].rew if self._save_obs_next: ep_obs_next = self[unique_ep_indices].obs_next # to satisfy mypy - assert isinstance(ep_obs_next, BatchProtocol) + assert isinstance(ep_obs_next, Batch) future_obs = self[future_t[unique_ep_close_indices]].obs_next else: future_obs = self[self.next(future_t[unique_ep_close_indices])].obs @@ -172,7 +172,7 @@ def rewrite_transitions(self, indices: np.ndarray) -> None: ep_rew[:, her_ep_indices] = self._compute_reward(ep_obs_next)[:, her_ep_indices] else: tmp_ep_obs_next = self[self.next(unique_ep_indices)].obs - assert isinstance(tmp_ep_obs_next, BatchProtocol) + assert isinstance(tmp_ep_obs_next, Batch) ep_rew[:, her_ep_indices] = self._compute_reward(tmp_ep_obs_next)[:, her_ep_indices] # Sanity check @@ -181,7 +181,7 @@ def rewrite_transitions(self, indices: np.ndarray) -> None: assert ep_rew.shape == unique_ep_indices.shape # Re-write meta - assert isinstance(self._meta.obs, BatchProtocol) + assert isinstance(self._meta.obs, Batch) self._meta.obs[unique_ep_indices] = ep_obs if self._save_obs_next: self._meta.obs_next[unique_ep_indices] = ep_obs_next # type: ignore diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py index 8c1374709..95b9527d2 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/policy/modelbased/psrl.py @@ -236,7 +236,8 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TPSRL for minibatch in batch.split(size=1): obs, act, obs_next = minibatch.obs, minibatch.act, minibatch.obs_next obs_next = cast(np.ndarray, obs_next) - assert not isinstance(obs, BatchProtocol), "Observations cannot be Batches here" + assert not isinstance(obs, Batch), "Observations cannot be Batches here" + obs = cast(np.ndarray, obs) trans_count[obs, act, obs_next] += 1 rew_sum[obs, act] += minibatch.rew rew_square_sum[obs, act] += 
minibatch.rew**2 From 812afc8735ce4ffd7e9df754eb5aa392b15adc7f Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 23 Jun 2025 13:40:42 +0200 Subject: [PATCH 56/56] v1: changelog [ci skip] --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e4d14ae09..fcd58f287 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ as the actor parameters are intended to be handled via natural gradients internally - `data`: - `ReplayBuffer`: Fix collection of empty episodes being disallowed + - Collection was slow due to `isinstance` checks on Protocols and due to Buffer integrity validation. This was solved + by no longer performing `isinstance` on Protocols and by making the integrity validation disabled by default. - Tests: - We have introduced extensive **determinism tests** which allow to validate whether training processes deterministically compute the same results across different development branches. @@ -53,6 +55,8 @@ ## Release 1.1.0 +**NOTE**: This release introduced (potentially severe) performance regressions in data collection, please switch to a newer release for better performance. + ### Highlights #### Evaluation Package