diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 000000000..536443860 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,22 @@ +{ + "name": "Tianshou", + "dockerFile": "../Dockerfile", + "workspaceFolder": "/workspaces/tianshou", + "runArgs": ["--shm-size=1g"], + "customizations": { + "vscode": { + "settings": { + "terminal.integrated.shell.linux": "/bin/bash", + "python.pythonPath": "/usr/local/bin/python" + }, + "extensions": [ + "ms-python.python", + "ms-toolsai.jupyter", + "ms-python.vscode-pylance" + ] + } + }, + "forwardPorts": [], + "postCreateCommand": "poetry install --with dev", + "remoteUser": "root" + } \ No newline at end of file diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..fa5050fe5 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,14 @@ +data +logs +test/log +docs/jupyter_execute +docs/.jupyter_cache +.lsp +.clj-kondo +docs/_build +coverage* +__pycache__ +*.egg-info +*.egg +.*cache +dist \ No newline at end of file diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index e1c3f2b3c..364f89649 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -2,6 +2,7 @@ - [ ] I have provided a description of the changes in this Pull Request - [ ] I have added documentation for my changes and have listed relevant changes in CHANGELOG.md - [ ] If applicable, I have added tests to cover my changes. +- [ ] If applicable, I have made sure that the determinism tests run through, meaning that my changes haven't influenced any aspect of training. See info in the contributing documentation. - [ ] I have reformatted the code using `poe format` - [ ] I have checked style and types with `poe lint` and `poe type-check` - [ ] (Optional) I ran tests locally with `poe test` diff --git a/.gitignore b/.gitignore index e63e24b00..c6e843c4d 100644 --- a/.gitignore +++ b/.gitignore @@ -158,4 +158,7 @@ docs/conf.py # temporary scripts (for ad-hoc testing), temp folder /temp -/temp*.py \ No newline at end of file +/temp*.py + +# determinism test snapshots +/test/resources/determinism/ diff --git a/CHANGELOG.md b/CHANGELOG.md index c032e5553..fcd58f287 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,23 +1,62 @@ # Changelog -## Unreleased +## Upcoming Release 1.2.0 ### Changes/Improvements -- trainer: +- `trainer`: - Custom scoring now supported for selecting the best model. #1202 -- highlevel: +- `highlevel`: - `DiscreteSACExperimentBuilder`: Expose method `with_actor_factory_default` #1248 #1250 - + - `ActorFactoryDefault`: Fix parameters for hidden sizes and activation not being + passed on in the discrete case (affects `with_actor_factory_default` method of experiment builders) + - `ExperimentConfig`: Do not inherit from other classes, as this breaks automatic handling by + `jsonargparse` when the class is used to define interfaces (as in high-level API examples) + - `AutoAlphaFactoryDefault`: Differentiate discrete and continuous action spaces + and allow coefficient to be modified, adding an informative docstring + (previous implementation was reasonable only for continuous action spaces) + - Adjust usage in `atari_sac_hl` example accordingly. 
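For reference, a minimal sketch of the adjusted `AutoAlphaFactoryDefault` usage (mirroring the `atari_sac_hl` change further down in this diff; the import path is assumed, and the learning rate is only a placeholder):

```python
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault  # import path assumed

# The new coefficient scales the target entropy; 0.98 is the value used in the
# adjusted atari_sac_hl example for the discrete-action case.
alpha = AutoAlphaFactoryDefault(lr=3e-4, target_entropy_coefficient=0.98)
```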
+ - `NPGAgentFactory`, `TRPOAgentFactory`: Fix optimizer instantiation erroneously including the actor parameters + (which was misleadingly suggested in the docstring in the respective policy classes; docstrings were fixed), + as the actor parameters are intended to be handled via natural gradients internally +- `data`: + - `ReplayBuffer`: Fix collection of empty episodes being disallowed + - Collection was slow due to `isinstance` checks on Protocols and due to Buffer integrity validation. This was solved + by no longer performing `isinstance` checks on Protocols and by disabling the integrity validation by default. +- Tests: + - We have introduced extensive **determinism tests** which allow us to validate whether + training processes deterministically compute the same results across different development branches. + This is an important step towards ensuring reproducibility and consistency, which will be + instrumental in supporting Tianshou developers in their work, especially in the context of + algorithm development and evaluation. + ### Breaking Changes -- data: - - stats: - - `InfoStats` has a new non-optional field `best_score` which is used - for selecting the best model. #1202 +- `trainer`: + - `BaseTrainer.run` and `__iter__`: Resetting was never optional prior to running the trainer, + yet the recently introduced parameter `reset_prior_to_run` of `run` suggested that it _was_ optional. + However, the parameter was ultimately not respected, because `__iter__` would always call `reset(reset_collectors=True, reset_buffer=False)` + regardless. The parameter was removed; instead, the parameters of `run` now mirror the parameters of `reset`, + and the implicit `reset` call in `__iter__` was removed. + This aligns with upcoming changes in Tianshou v2.0.0. + * NOTE: If you have been using a trainer without calling `run` but by directly iterating over it, you + will need to call `reset` on the trainer explicitly before iterating over the trainer. + * Using a trainer as an iterator is considered deprecated and support for this will be removed in Tianshou v2.0.0. +- `data`: + - `InfoStats` has a new non-optional field `best_score` which is used + for selecting the best model. #1202 +- `highlevel`: + - Change the way in which seeding is handled: the mechanism introduced in v1.1.0 + was completely revised: + - The `train_seed` and `test_seed` attributes were removed from `SamplingConfig`. + Instead, the seeds are derived from the seed defined in `ExperimentConfig`. + - Seed attributes of `EnvFactory` classes were removed. + Instead, seeds are passed to methods of `EnvFactory`. ## Release 1.1.0 +**NOTE**: This release introduced (potentially severe) performance regressions in data collection; please switch to a newer release for better performance. + ### Highlights #### Evaluation Package diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000..4e3827b26 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,42 @@ +# Use the official Python image for the base image. +FROM --platform=linux/amd64 python:3.11-slim + +# Set environment variables to make Python print directly to the terminal and avoid .pyc files. +ENV PYTHONUNBUFFERED=1 +ENV PYTHONDONTWRITEBYTECODE=1 + +# Install system dependencies required for the project. +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + build-essential \ + git \ + wget \ + unzip \ + libvips-dev \ + gnupg2 \ + && rm -rf /var/lib/apt/lists/* + + +# Install pipx.
RUN python3 -m pip install --no-cache-dir pipx \ + && pipx ensurepath + +# Add poetry to the path +ENV PATH="${PATH}:/root/.local/bin" + +# Install the latest version of Poetry using pipx. +RUN pipx install poetry + +# Set the working directory. IMPORTANT: can't be changed, as it needs to be in sync with the dir where the project is cloned +# to in the codespace +WORKDIR /workspaces/tianshou + +# Copy the pyproject.toml and poetry.lock files (if available) into the image. +COPY pyproject.toml poetry.lock* README.md /workspaces/tianshou/ + +RUN poetry config virtualenvs.create false +RUN poetry install --no-root --with dev + +# The entrypoint will perform an editable install; it is expected that the code is mounted into the container at that point +# If you don't want to mount the code, you should override the entrypoint +ENTRYPOINT ["/bin/bash", "-c", "poetry install --with dev && poetry run jupyter trust notebooks/*.ipynb docs/02_notebooks/*.ipynb && $0 $@"] \ No newline at end of file diff --git a/docs/04_contributing/04_contributing.rst b/docs/04_contributing/04_contributing.rst index 1397e3473..48cf172c8 100644 --- a/docs/04_contributing/04_contributing.rst +++ b/docs/04_contributing/04_contributing.rst @@ -2,11 +2,12 @@ Contributing to Tianshou ======================== -Install Develop Version ----------------------- +Install Development Environment +------------------------------- Tianshou is built and managed by `poetry `_. For example, -to install all relevant requirements in editable mode you can simply call +to install all relevant requirements (and install Tianshou itself in editable mode) +you can simply call .. code-block:: bash @@ -36,9 +37,9 @@ Please set up pre-commit by running in the main directory. This should make sure that your contribution is properly formatted before every commit. -The code is inspected and formatted by `black` and `ruff`. They are executed as -pre-commit hooks. In addition, `poe the poet` tasks are configured. -Simply run `poe` to see the available tasks. +The code is inspected and formatted by ``black`` and ``ruff``. They are executed as +pre-commit hooks. In addition, ``poe the poet`` tasks are configured. +Simply run ``poe`` to see the available tasks. E.g, to format and check the linting manually you can run: .. code-block:: bash @@ -47,8 +48,8 @@ E.g, to format and check the linting manually you can run: $ poe lint -Type Check ---------- +Type Checks +----------- We use `mypy `_ to check the type annotations. To check, in the main directory, run: @@ -57,8 +58,8 @@ We use `mypy `_ to check the type annotations. $ poe type-check -Test Locally ------------- +Testing Locally +--------------- This command will run automatic tests in the main directory @@ -67,6 +68,30 @@ This command will run automatic tests in the main directory $ poe test +Determinism Tests +~~~~~~~~~~~~~~~~~ + +We implemented "determinism tests" for Tianshou's algorithms, which allow us to determine +whether algorithms still compute exactly the same results even after large refactorings. +These tests are applied by + + 1. creating a behavior snapshot in the old code branch before the changes and then + 2. running the test in the new branch to ensure that the behavior is the same. + +Unfortunately, full determinism is difficult to achieve across different platforms and even different +machines using the same platform and Python environment. +Therefore, these tests are not carried out in the CI pipeline.
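For illustration, assuming the standard ``pytest`` selection syntax, a single determinism test such as ``test_ddpg_determinism`` can be invoked on a developer machine as follows (these tests are skipped unless ``ENABLED`` is set in ``AlgorithmDeterminismTest``):

.. code-block:: bash

    $ pytest test/continuous/test_ddpg.py::test_ddpg_determinism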
+Instead, it is up to the developer to run them locally and check the results whenever a change +is made to the code base that could affect algorithm behavior. + +Technically, the two steps are handled by setting static flags in class ``AlgorithmDeterminismTest`` and then +running either the full test suite or a specific determinism test (``test_*_determinism``, e.g. ``test_ddpg_determinism``) +in the two branches to be compared. + + 1. On the old branch: (Temporarily) set ``ENABLED=True`` and ``FORCE_SNAPSHOT_UPDATE=True`` and run the test(s). + 2. On the new branch: (Temporarily) set ``ENABLED=True`` and ``FORCE_SNAPSHOT_UPDATE=False`` and run the test(s). + 3. Inspect the test results; find a summary in ``determinism_tests.log`` + Test by GitHub Actions ---------------------- diff --git a/docs/autogen_rst.py b/docs/autogen_rst.py index b1a8b18d9..93e2a0954 100644 --- a/docs/autogen_rst.py +++ b/docs/autogen_rst.py @@ -114,8 +114,8 @@ def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix="" for f in files_in_dir if os.path.isdir(os.path.join(root, dirname, f)) and not f.startswith("_") ] - if not module_names: - log.debug(f"Skipping {dirname} as it does not contain any .py files") + if not module_names and "__init__.py" not in files_in_dir: + log.debug(f"Skipping {dirname} as it does not contain any modules or __init__.py") continue package_qualname = f"{base_package_qualname}.{dirname}" package_index_rst_path = os.path.join( diff --git a/docs/create_toc.py b/docs/create_toc.py index 3e1779cdf..c6add58e5 100644 --- a/docs/create_toc.py +++ b/docs/create_toc.py @@ -3,6 +3,6 @@ # This script provides a platform-independent way of making the jupyter-book call (used in pyproject.toml) toc_file = Path(__file__).parent / "_toc.yml" -cmd = f"jupyter-book toc from-project docs -e .rst -e .md -e .ipynb >{toc_file}" +cmd = f'jupyter-book toc from-project docs -e .rst -e .md -e .ipynb >"{toc_file}"' print(cmd) os.system(cmd) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt deleted file mode 100644 index 40aa69970..000000000 --- a/docs/spelling_wordlist.txt +++ /dev/null @@ -1,294 +0,0 @@ -tianshou -arXiv -tanh -lr -logits -env -envs -optim -eps -timelimit -TimeLimit -envpool -EnvPool -maxsize -timestep -timesteps -numpy -ndarray -stackoverflow -tensorboard -state_dict -len -tac -fqf -iqn -qrdqn -rl -offpolicy -onpolicy -quantile -quantiles -dqn -param -async -subprocess -deque -nn -equ -cql -fn -boolean -pre -np -cuda -rnn -rew -pre -perceptron -bsz -dataset -mujoco -jit -nstep -preprocess -preprocessing -repo -ReLU -namespace -recv -th -utils -NaN -linesearch -hyperparameters -pseudocode -entropies -nn -config -cpu -rms -debias -indice -regularizer -miniblock -modularize -serializable -softmax -vectorized -optimizers -undiscounted -submodule -subclasses -submodules -tfevent -dirichlet -docstring -webpage -formatter -num -py -pythonic -中文文档位于 -conda -miniconda -Amir -Andreas -Antonoglou -Beattie -Bellemare -Charles -Daan -Demis -Dharshan -Fidjeland -Georg -Hassabis -Helen -Ioannis -Kavukcuoglu -King -Koray -Kumaran -Legg -Mnih -Ostrovski -Petersen -Riedmiller -Rusu -Sadik -Shane -Stig -Veness -Volodymyr -Wierstra -Lillicrap -Pritzel -Heess -Erez -Yuval -Tassa -Schulman -Filip -Wolski -Prafulla -Dhariwal -Radford -Oleg -Klimov -Kaichao -Jiayi -Weng -Duburcq -Huayu -Yi -Su -Strens -Ornstein -Uhlenbeck -mse -gail -airl -ppo -Jupyter -Colab -Colaboratory -IPendulum -Reacher -Runtime -Nvidia -Enduro -Qbert -Seaquest -subnets -subprocesses -isort -yapf 
-pydocstyle -Args -tuples -tuple -Multi -multi -parameterized -Proximal -metadata -GPU -Dopamine -builtin -params -inplace -deepcopy -Gaussian -stdout -parallelization -minibatch -minibatches -MLP -backpropagation -dataclass -superset -subtype -subdirectory -picklable -ShmemVectorEnv -Github -wandb -jupyter -img -src -parallelized -infty -venv -venvs -subproc -bcq -highlevel -icm -modelbased -td -psrl -ddpg -npg -tf -trpo -crr -pettingzoo -multidiscrete -vecbuf -prio -colab -segtree -multiagent -mapolicy -sensai -sensAI -docstrings -superclass -iterable -functools -str -sklearn -attr -bc -redq -modelfree -bdq -util -logp -autogenerated -subpackage -subpackages -recurse -rollout -rollouts -prepend -prepends -dict -dicts -pytorch -tensordict -onwards -Dominik -Tsinghua -Tianshou -appliedAI -macOS -joblib -master -Panchenko -BA -BH -BO -BD -configs -postfix -backend -rliable -hl -v_s -v_s_ -obs -obs_next -dtype -iqm -kwarg -entrypoint -interquantile -init -kwarg -kwargs -autocompletion -codebase -indexable -sliceable -gaussian -logprob -monte -carlo -subclass -subclassing -dist -dists -subbuffer -subbuffers diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 601481523..3bcb0f6c3 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -69,8 +69,6 @@ def main( env_factory = AtariEnvFactory( task, - sampling_config.train_seed, - sampling_config.test_seed, frames_stack, scale=scale_obs, ) diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index b71b0eef3..c644b2469 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -68,8 +68,6 @@ def main( env_factory = AtariEnvFactory( task, - sampling_config.train_seed, - sampling_config.test_seed, frames_stack, scale=scale_obs, ) diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 26ebaba08..983608293 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -73,8 +73,6 @@ def main( env_factory = AtariEnvFactory( task, - sampling_config.train_seed, - sampling_config.test_seed, frames_stack, scale=scale_obs, ) diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index 124def768..76f18f55f 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -45,7 +45,6 @@ def main( training_num: int = 10, test_num: int = 10, frames_stack: int = 4, - save_buffer_name: str | None = None, # TODO add support in high-level API? icm_lr_scale: float = 0.0, icm_reward_scale: float = 0.01, icm_forward_loss_weight: float = 0.2, @@ -69,8 +68,6 @@ def main( env_factory = AtariEnvFactory( task, - sampling_config.train_seed, - sampling_config.test_seed, frames_stack, scale=scale_obs, ) @@ -84,7 +81,9 @@ def main( critic2_lr=critic_lr, gamma=gamma, tau=tau, - alpha=AutoAlphaFactoryDefault(lr=alpha_lr) if auto_alpha else alpha, + alpha=AutoAlphaFactoryDefault(lr=alpha_lr, target_entropy_coefficient=0.98) + if auto_alpha + else alpha, estimation_step=n_step, ), ) diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index d7234d863..25a3b09f2 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -383,8 +383,8 @@ def make_atari_env( :return: a tuple of (single env, training envs, test envs). 
""" - env_factory = AtariEnvFactory(task, seed, seed + training_num, frame_stack, scale=bool(scale)) - envs = env_factory.create_envs(training_num, test_num) + env_factory = AtariEnvFactory(task, frame_stack, scale=bool(scale)) + envs = env_factory.create_envs(training_num, test_num, seed=seed) return envs.env, envs.train_envs, envs.test_envs @@ -392,8 +392,6 @@ class AtariEnvFactory(EnvFactoryRegistered): def __init__( self, task: str, - train_seed: int, - test_seed: int, frame_stack: int, scale: bool = False, use_envpool_if_available: bool = True, @@ -411,14 +409,12 @@ def __init__( log.info("Not using envpool, because it is not available") super().__init__( task=task, - train_seed=train_seed, - test_seed=test_seed, venv_type=venv_type, envpool_factory=envpool_factory, ) - def create_env(self, mode: EnvMode) -> gym.Env: - env = super().create_env(mode) + def _create_env(self, mode: EnvMode) -> gym.Env: + env = super()._create_env(mode) is_train = mode == EnvMode.TRAIN return wrap_deepmind( env, diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 4e52f4ce2..bcf3e7f18 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -83,7 +83,7 @@ def stop_fn(mean_rewards: float) -> bool: # watch performance policy.set_eps(eps_test) collector = ts.data.Collector[CollectStats](policy, env, exploration_noise=True) - collector.collect(n_episode=100, render=1 / 35) + collector.collect(n_episode=100, render=1 / 35, reset_before_collect=True) if __name__ == "__main__": diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index c804d6c26..bba7f9e76 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -54,12 +54,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory( - task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, - obs_norm=True, - ) + env_factory = MujocoEnvFactory(task, obs_norm=True) experiment = ( A2CExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py index 27dbfc8d9..daa936533 100644 --- a/examples/mujoco/mujoco_ddpg_hl.py +++ b/examples/mujoco/mujoco_ddpg_hl.py @@ -52,12 +52,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory( - task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, - obs_norm=False, - ) + env_factory = MujocoEnvFactory(task, obs_norm=False) experiment = ( DDPGExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py index 90f27995b..8aff92793 100644 --- a/examples/mujoco/mujoco_env.py +++ b/examples/mujoco/mujoco_env.py @@ -37,7 +37,7 @@ def make_mujoco_env( :return: a tuple of (single env, training envs, test envs). 
""" - envs = MujocoEnvFactory(task, seed, seed + num_train_envs, obs_norm=obs_norm).create_envs( + envs = MujocoEnvFactory(task, obs_norm=obs_norm).create_envs( num_train_envs, num_test_envs, ) @@ -73,28 +73,18 @@ class MujocoEnvFactory(EnvFactoryRegistered): def __init__( self, task: str, - train_seed: int, - test_seed: int, obs_norm: bool = True, venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO, ) -> None: super().__init__( task=task, - train_seed=train_seed, - test_seed=test_seed, venv_type=venv_type, envpool_factory=EnvPoolFactory() if envpool_is_available else None, ) self.obs_norm = obs_norm - def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv: - """Create vectorized environments. - - :param num_envs: the number of environments - :param mode: the mode for which to create - :return: the vectorized environments - """ - env = super().create_venv(num_envs, mode) + def create_venv(self, num_envs: int, mode: EnvMode, seed: int | None = None) -> BaseVectorEnv: + env = super().create_venv(num_envs, mode, seed=seed) # obs norm wrapper if self.obs_norm: env = VectorEnvNormObs(env, update_obs_rms=mode == EnvMode.TRAIN) @@ -105,8 +95,9 @@ def create_envs( num_training_envs: int, num_test_envs: int, create_watch_env: bool = False, + seed: int | None = None, ) -> ContinuousEnvironments: - envs = super().create_envs(num_training_envs, num_test_envs, create_watch_env) + envs = super().create_envs(num_training_envs, num_test_envs, create_watch_env, seed=seed) assert isinstance(envs, ContinuousEnvironments) if self.obs_norm: diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index 387f87c6e..a231e1b21 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -53,12 +53,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory( - task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, - obs_norm=True, - ) + env_factory = MujocoEnvFactory(task, obs_norm=True) experiment = ( NPGExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index b10d4cf26..973a822a6 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -58,12 +58,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory( - task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, - obs_norm=True, - ) + env_factory = MujocoEnvFactory(task, obs_norm=True) experiment = ( PPOExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py index 333870809..47ccc9ae2 100644 --- a/examples/mujoco/mujoco_ppo_hl_multi.py +++ b/examples/mujoco/mujoco_ppo_hl_multi.py @@ -37,7 +37,7 @@ def main( num_experiments: int = 5, run_experiments_sequentially: bool = True, - logger_type: str = "wandb", + logger_type: str = "tensorboard", ) -> RLiableExperimentResult: """:param num_experiments: the number of experiments to run. The experiments differ exclusively in the seeds. :param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel. 
@@ -70,12 +70,7 @@ def main( repeat_per_collect=1, ) - env_factory = MujocoEnvFactory( - task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, - obs_norm=True, - ) + env_factory = MujocoEnvFactory(task, obs_norm=True) hidden_sizes = (64, 64) diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py index c0c63279a..90f6ef318 100644 --- a/examples/mujoco/mujoco_redq_hl.py +++ b/examples/mujoco/mujoco_redq_hl.py @@ -58,12 +58,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory( - task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, - obs_norm=False, - ) + env_factory = MujocoEnvFactory(task, obs_norm=False) experiment = ( REDQExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py index 59a600568..f3e8821ae 100644 --- a/examples/mujoco/mujoco_reinforce_hl.py +++ b/examples/mujoco/mujoco_reinforce_hl.py @@ -49,12 +49,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory( - task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, - obs_norm=True, - ) + env_factory = MujocoEnvFactory(task, obs_norm=True) experiment = ( PGExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py index a150f5571..5b2e1519b 100644 --- a/examples/mujoco/mujoco_sac_hl.py +++ b/examples/mujoco/mujoco_sac_hl.py @@ -53,12 +53,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory( - task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, - obs_norm=False, - ) + env_factory = MujocoEnvFactory(task, obs_norm=False) experiment = ( SACExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py index 5ec9cc17b..8ca54d591 100644 --- a/examples/mujoco/mujoco_td3_hl.py +++ b/examples/mujoco/mujoco_td3_hl.py @@ -58,12 +58,7 @@ def main( start_timesteps_random=True, ) - env_factory = MujocoEnvFactory( - task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, - obs_norm=False, - ) + env_factory = MujocoEnvFactory(task, obs_norm=False) experiment = ( TD3ExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 1ec26bad2..4dfc39185 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -55,12 +55,7 @@ def main( repeat_per_collect=repeat_per_collect, ) - env_factory = MujocoEnvFactory( - task, - train_seed=sampling_config.train_seed, - test_seed=sampling_config.test_seed, - obs_norm=True, - ) + env_factory = MujocoEnvFactory(task, obs_norm=True) experiment = ( TRPOExperimentBuilder(env_factory, experiment_config, sampling_config) diff --git a/poetry.lock b/poetry.lock index 583b90756..2cd8e2ed7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5382,13 +5382,13 @@ win32 = ["pywin32"] [[package]] name = "sensai-utils" -version = "1.2.1" +version = "1.4.0" description = "Utilities from sensAI, the Python library for sensible AI" optional = false python-versions = "*" files = [ - {file = "sensai_utils-1.2.1-py3-none-any.whl", hash = "sha256:222e60d9f9d371c9d62ffcd1e6def1186f0d5243588b0b5af57e983beecc95bb"}, - {file = "sensai_utils-1.2.1.tar.gz", hash = 
"sha256:4d8ca94179931798cef5f920fb042cbf9e7d806c0026b02afb58d0f72211bf27"}, + {file = "sensai_utils-1.4.0-py3-none-any.whl", hash = "sha256:ed6fc57552620e43b33cf364ea0bc0fd7df39391069dd7b621b113ef55547507"}, + {file = "sensai_utils-1.4.0.tar.gz", hash = "sha256:2d32bdcc91fd1428c5cae0181e98623142d2d5f7e115e23d585a842dd9dc59ba"}, ] [package.dependencies] @@ -6896,4 +6896,4 @@ vizdoom = ["vizdoom"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "1ea1b72b90269fd86b81b1443785085618248ccf5b62506a166b879115749171" +content-hash = "bff3f4f8cc0d8196ea162a799472c7179486109d30968aa7d1b96b40016a459f" diff --git a/pyproject.toml b/pyproject.toml index 5640f948b..d4535f096 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ overrides = "^7.4.0" packaging = "*" pandas = ">=2.0.0" pettingzoo = "^1.22" -sensai-utils = "^1.2.1" +sensai-utils = "^1.4.0" tensorboard = "^2.5.0" # Torch 2.0.1 causes problems, see https://github.com/pytorch/pytorch/issues/100974 torch = "^2.0.0, !=2.0.1, !=2.1.0" @@ -181,7 +181,8 @@ ignore = [ "PLW2901", # overwrite vars in loop "B027", # empty and non-abstract method in abstract class "D404", # It's fine to start with "This" in docstrings - "D407", "D408", "D409", # Ruff rules for underlines under 'Example:' and so clash with Sphinx + "D407", "D408", "D409", # Ruff rules for underlines under 'Example:' and so clash with Sphinx + "B023", # forbids function using loop variable without explicit binding ] unfixable = [ "F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all @@ -227,13 +228,12 @@ _poetry_sort = "poetry sort" clean-nbs = "python docs/nbstripout.py" format = ["_ruff_format", "_ruff_format_nb", "_black_format", "_poetry_install_sort_plugin", "_poetry_sort"] _autogen_rst = "python docs/autogen_rst.py" -_sphinx_build = "sphinx-build -W -b html docs docs/_build" +_sphinx_build = "sphinx-build -b html docs docs/_build -W --keep-going" _jb_generate_toc = "python docs/create_toc.py" _jb_generate_config = "jupyter-book config sphinx docs/" doc-clean = "rm -rf docs/_build" doc-generate-files = ["_autogen_rst", "_jb_generate_toc", "_jb_generate_config"] -doc-spellcheck = "sphinx-build -W -b spelling docs docs/_build" -doc-build = ["doc-generate-files", "doc-spellcheck", "_sphinx_build"] +doc-build = ["doc-generate-files", "_sphinx_build"] _mypy = "mypy tianshou test examples" _mypy_nb = "nbqa mypy docs" type-check = ["_mypy", "_mypy_nb"] diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 5fa40758f..f8af8c521 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -51,10 +51,6 @@ def test_batch() -> None: Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))]) with pytest.raises(TypeError): Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))]) - with pytest.raises(TypeError): - Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))]) - with pytest.raises(TypeError): - Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))]) assert torch.allclose(batch.a, torch.ones(2, 3)) batch.cat_(batch) assert torch.allclose(batch.a, torch.ones(4, 3)) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 6355d8bfc..1e4769bbf 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -72,7 +72,7 @@ def forward( if self.dict_state: if self.action_shape: action_shape = self.action_shape - elif isinstance(batch.obs, BatchProtocol): + elif isinstance(batch.obs, Batch): action_shape = len(batch.obs["index"]) else: action_shape = len(batch.obs) diff --git 
a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index ce7998eff..1569c0df1 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -49,7 +50,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_ddpg(args: argparse.Namespace = get_args()) -> None: +def test_ddpg(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape @@ -131,4 +132,11 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_ddpg_determinism() -> None: + main_fn = lambda args: test_ddpg(args, enable_assertions=False) + AlgorithmDeterminismTest("continuous_ddpg", main_fn, get_args()).run() diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index d853e2186..3b413eec4 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -52,7 +53,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_npg(args: argparse.Namespace = get_args()) -> None: +def test_npg(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) @@ -153,4 +154,11 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_npg_determinism() -> None: + main_fn = lambda args: test_npg(args, enable_assertions=False) + AlgorithmDeterminismTest("continuous_npg", main_fn, get_args()).run() diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 4b56bd630..cbb8544ab 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -58,7 +59,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_ppo(args: argparse.Namespace = get_args()) -> None: +def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) @@ -166,7 +167,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: print("Fail to restore policy and optim.") # trainer - trainer = OnpolicyTrainer( + result = OnpolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, @@ -181,16 +182,17 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: logger=logger, resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, - ) - - for epoch_stat in trainer: - print(f"Epoch: {epoch_stat.epoch}") - print(epoch_stat) - # print(info) + ).run() - assert stop_fn(epoch_stat.info_stat.best_reward) + if enable_assertions: + assert stop_fn(result.best_reward) def test_ppo_resume(args: argparse.Namespace = get_args()) -> None: args.resume = True test_ppo(args) + + +def 
test_ppo_determinism() -> None: + main_fn = lambda args: test_ppo(args, enable_assertions=False) + AlgorithmDeterminismTest("continuous_ppo", main_fn, get_args()).run() diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 82c8f0637..24a7f420c 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -54,7 +55,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_redq(args: argparse.Namespace = get_args()) -> None: +def test_redq(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Box) space_info = SpaceInfo.from_env(env) @@ -162,4 +163,19 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_redq_determinism() -> None: + main_fn = lambda args: test_redq(args, enable_assertions=False) + ignored_messages = [ + "Params[actor_old]", + ] # actor_old only present in v1 (due to flawed inheritance) + AlgorithmDeterminismTest( + "continuous_redq", + main_fn, + get_args(), + ignored_messages=ignored_messages, + ).run() diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 09fc3ca45..1d4fc06fe 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -57,7 +58,11 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: +def test_sac_with_il( + args: argparse.Namespace = get_args(), + enable_assertions: bool = True, + skip_il: bool = False, +) -> None: # if you want to use python vector env, please refer to other test scripts # train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed) # test_envs = envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed) @@ -158,7 +163,12 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + if skip_il: + return # here we define an imitation collector with a trivial policy if args.task.startswith("Pendulum"): @@ -203,4 +213,11 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_sac_determinism() -> None: + main_fn = lambda args: test_sac_with_il(args, enable_assertions=False, skip_il=True) + AlgorithmDeterminismTest("continuous_sac", main_fn, get_args()).run() diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 6c59ea25a..82fcce0fb 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -1,6 +1,6 @@ import argparse import os -import pprint +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -52,7 +52,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_td3(args: argparse.Namespace = get_args()) -> None: 
+def test_td3(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape @@ -135,7 +135,7 @@ def stop_fn(mean_rewards: float) -> bool: return mean_rewards >= args.reward_threshold # Iterator trainer - trainer = OffpolicyTrainer( + result = OffpolicyTrainer( policy=policy, train_collector=train_collector, test_collector=test_collector, @@ -148,10 +148,12 @@ def stop_fn(mean_rewards: float) -> bool: stop_fn=stop_fn, save_best_fn=save_best_fn, logger=logger, - ) - for epoch_stat in trainer: - print(f"Epoch: {epoch_stat.epoch}") - pprint.pprint(epoch_stat) - # print(info) + ).run() + + if enable_assertions: + assert stop_fn(result.best_reward) + - assert stop_fn(epoch_stat.info_stat.best_reward) +def test_td3_determinism() -> None: + main_fn = lambda args: test_td3(args, enable_assertions=False) + AlgorithmDeterminismTest("continuous_td3", main_fn, get_args()).run() diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 91e215116..321e351f3 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -54,7 +55,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_trpo(args: argparse.Namespace = get_args()) -> None: +def test_trpo(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape @@ -153,4 +154,11 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_trpo_determinism() -> None: + main_fn = lambda args: test_trpo(args, enable_assertions=False) + AlgorithmDeterminismTest("continuous_trpo", main_fn, get_args()).run() diff --git a/test/determinism_test.py b/test/determinism_test.py new file mode 100644 index 000000000..828825660 --- /dev/null +++ b/test/determinism_test.py @@ -0,0 +1,123 @@ +from argparse import Namespace +from collections.abc import Callable, Sequence +from pathlib import Path +from typing import Any + +import pytest +import torch + +from tianshou.utils.determinism import TraceDeterminismTest, TraceLoggerContext + + +class TorchDeterministicModeContext: + def __init__(self, mode: str | int = "default") -> None: + self.new_mode = mode + self.original_mode: str | int | None = None + + def __enter__(self) -> None: + self.original_mode = torch.get_deterministic_debug_mode() + torch.set_deterministic_debug_mode(self.new_mode) + + def __exit__(self, exc_type, exc_value, traceback): # type: ignore + assert self.original_mode is not None + torch.set_deterministic_debug_mode(self.original_mode) + + +class AlgorithmDeterminismTest: + """ + Represents a determinism test for Tianshou's RL algorithms. + + A test using this class should be added for every algorithm in Tianshou. + Then, when making changes to one or more algorithms (e.g. refactoring), run the respective tests + on the old branch (creating snapshots) and then on the new branch that contains the changes + (comparing with the snapshots). + + Intended usage is therefore: + + 1. On the old branch: Set ENABLED=True and FORCE_SNAPSHOT_UPDATE=True and run the tests. + 2. 
On the new branch: Set ENABLED=True and FORCE_SNAPSHOT_UPDATE=False and run the tests. + 3. Inspect determinism_tests.log + """ + + ENABLED = False + """ + whether determinism tests are enabled. + """ + FORCE_SNAPSHOT_UPDATE = False + """ + whether to force the update/creation of snapshots for every test. + Enable this when running on the "old" branch and you want to prepare the snapshots + for a comparison with the "new" branch. + """ + PASS_IF_CORE_MESSAGES_UNCHANGED = True + """ + whether to pass the test if only the core messages are unchanged. + If this is False, then the full log is required to be equivalent, whereas if it is True, + only the core messages need to be equivalent. + The core messages test whether the algorithm produces the same network parameters. + """ + + def __init__( + self, + name: str, + main_fn: Callable[[Namespace], Any], + args: Namespace, + is_offline: bool = False, + ignored_messages: Sequence[str] = (), + ): + """ + :param name: the (unique!) name of the test + :param main_fn: the function to be called for the test + :param args: the arguments to be passed to the main function (some of which are overridden + for the test) + :param is_offline: whether the algorithm being tested is an offline algorithm and therefore + does not configure the number of training environments (`training_num`) + :param ignored_messages: message fragments to ignore in the trace log (if any) + """ + self.determinism_test = TraceDeterminismTest( + base_path=Path(__file__).parent / "resources" / "determinism", + log_filename="determinism_tests.log", + core_messages=["Params"], + ignored_messages=ignored_messages, + ) + self.name = name + + def set(attr: str, value: Any) -> None: + old_value = getattr(args, attr) + if old_value is None: + raise ValueError(f"Attribute '{attr}' is not defined for args: {args}") + setattr(args, attr, value) + + set("epoch", 3) + set("step_per_epoch", 100) + set("device", "cpu") + if not is_offline: + set("training_num", 1) + set("test_num", 1) + + self.args = args + self.main_fn = main_fn + + def run(self, update_snapshot: bool = False) -> None: + """ + :param update_snapshot: whether to update to snapshot (may be centrally overridden by + FORCE_SNAPSHOT_UPDATE) + """ + if not self.ENABLED: + pytest.skip("Algorithm determinism tests are disabled.") + + if self.FORCE_SNAPSHOT_UPDATE: + update_snapshot = True + + # run the actual process + with TraceLoggerContext() as trace: + with TorchDeterministicModeContext(): + self.main_fn(self.args) + log = trace.get_log() + + self.determinism_test.check( + log, + self.name, + create_reference_result=update_snapshot, + pass_if_core_messages_unchanged=self.PASS_IF_CORE_MESSAGES_UNCHANGED, + ) diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 192f24c24..cf0f4ef7c 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -8,7 +9,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, CollectStats, VectorReplayBuffer -from tianshou.env import DummyVectorEnv, SubprocVectorEnv +from tianshou.env import DummyVectorEnv from tianshou.policy import A2CPolicy, ImitationPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer, OnpolicyTrainer @@ -59,7 +60,11 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def 
test_a2c_with_il(args: argparse.Namespace = get_args()) -> None: +def test_a2c_with_il( + args: argparse.Namespace = get_args(), + enable_assertions: bool = True, + skip_il: bool = False, +) -> None: # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -144,7 +149,12 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + if skip_il: + return # here we define an imitation collector with a trivial policy # if args.task == 'CartPole-v1': @@ -165,9 +175,8 @@ def stop_fn(mean_rewards: float) -> bool: seed=args.seed, ) else: - il_env = SubprocVectorEnv( + il_env = DummyVectorEnv( [lambda: gym.make(args.task) for _ in range(args.test_num)], - context="fork", ) il_env.seed(args.seed) @@ -189,4 +198,11 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_ppo_determinism() -> None: + main_fn = lambda args: test_a2c_with_il(args, enable_assertions=False, skip_il=True) + AlgorithmDeterminismTest("discrete_a2c", main_fn, get_args()).run() diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index 16042f622..719e10d86 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -1,4 +1,5 @@ import argparse +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -113,7 +114,9 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: exploration_noise=True, ) test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=False) - # policy.set_eps(1) + + # initial data collection + policy.set_eps(args.eps_train) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) @@ -142,3 +145,8 @@ def stop_fn(mean_rewards: float) -> bool: test_fn=test_fn, stop_fn=stop_fn, ).run() + + +def test_bdq_determinism() -> None: + main_fn = lambda args: test_bdq(args) + AlgorithmDeterminismTest("discrete_bdq", main_fn, get_args()).run() diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 41d6a0260..2876c4406 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -1,6 +1,7 @@ import argparse import os import pickle +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -61,7 +62,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_c51(args: argparse.Namespace = get_args()) -> None: +def test_c51(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) @@ -104,6 +105,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) + # buffer buf: ReplayBuffer if args.prioritized_replay: @@ -115,12 +117,16 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) + # collector train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) - # policy.set_eps(1) + + # initial data collection + policy.set_eps(args.eps_train) 
train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) + # log log_path = os.path.join(args.logdir, args.task, "c51") writer = SummaryWriter(log_path) @@ -200,7 +206,9 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) def test_c51_resume(args: argparse.Namespace = get_args()) -> None: @@ -213,3 +221,8 @@ def test_pc51(args: argparse.Namespace = get_args()) -> None: args.gamma = 0.95 args.seed = 1 test_c51(args) + + +def test_c51_determinism() -> None: + main_fn = lambda args: test_c51(args, enable_assertions=False) + AlgorithmDeterminismTest("discrete_c51", main_fn, get_args()).run() diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index f82aca1f6..5c706e884 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -55,7 +56,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_dqn(args: argparse.Namespace = get_args()) -> None: +def test_dqn(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) @@ -106,12 +107,16 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) + # collector train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) - # policy.set_eps(1) + + # initial data collection + policy.set_eps(args.eps_train) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) + # log log_path = os.path.join(args.logdir, args.task, "dqn") writer = SummaryWriter(log_path) @@ -153,7 +158,14 @@ def test_fn(epoch: int, env_step: int | None) -> None: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_dqn_determinism() -> None: + main_fn = lambda args: test_dqn(args, enable_assertions=False) + AlgorithmDeterminismTest("discrete_dqn", main_fn, get_args()).run() def test_pdqn(args: argparse.Namespace = get_args()) -> None: diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 4cc0b6bd0..d3fcdbd89 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -47,7 +48,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_drqn(args: argparse.Namespace = get_args()) -> None: +def test_drqn(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) @@ -82,6 +83,7 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None: action_space=env.action_space, target_update_freq=args.target_update_freq, ) + # collector buffer = VectorReplayBuffer( args.buffer_size, @@ -90,11 +92,15 @@ def test_drqn(args: 
argparse.Namespace = get_args()) -> None: ignore_obs_next=True, ) train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + # the stack_num is for RNN training: sample framestack obs test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) - # policy.set_eps(1) + + # initial data collection + policy.set_eps(args.eps_train) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) + # log log_path = os.path.join(args.logdir, args.task, "drqn") writer = SummaryWriter(log_path) @@ -129,4 +135,11 @@ def test_fn(epoch: int, env_step: int | None) -> None: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_drqn_determinism() -> None: + main_fn = lambda args: test_drqn(args, enable_assertions=False) + AlgorithmDeterminismTest("discrete_drqn", main_fn, get_args()).run() diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index e0899d315..2e56b4e38 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -60,7 +61,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_fqf(args: argparse.Namespace = get_args()) -> None: +def test_fqf(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) assert isinstance(env.action_space, gym.spaces.Discrete) @@ -123,12 +124,16 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) + # collector train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) - # policy.set_eps(1) + + # initial data collection + policy.set_eps(args.eps_train) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) + # log log_path = os.path.join(args.logdir, args.task, "fqf") writer = SummaryWriter(log_path) @@ -170,10 +175,17 @@ def test_fn(epoch: int, env_step: int | None) -> None: logger=logger, update_per_step=args.update_per_step, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) def test_pfqf(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 test_fqf(args) + + +def test_fqf_determinism() -> None: + main_fn = lambda args: test_fqf(args, enable_assertions=False) + AlgorithmDeterminismTest("discrete_fqf", main_fn, get_args()).run() diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 08f545b11..57bf28e73 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -60,7 +61,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_iqn(args: argparse.Namespace = get_args()) -> None: +def test_iqn(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) assert isinstance(env.action_space, gym.spaces.Discrete) @@ -108,6 +109,7 @@ def test_iqn(args: 
argparse.Namespace = get_args()) -> None: estimation_step=args.n_step, target_update_freq=args.target_update_freq, ).to(args.device) + # buffer buf: ReplayBuffer if args.prioritized_replay: @@ -119,12 +121,16 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) + # collector train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) - # policy.set_eps(1) + + # initial data collection + policy.set_eps(args.eps_train) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) + # log log_path = os.path.join(args.logdir, args.task, "iqn") writer = SummaryWriter(log_path) @@ -166,10 +172,17 @@ def test_fn(epoch: int, env_step: int | None) -> None: logger=logger, update_per_step=args.update_per_step, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) def test_piqn(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 test_iqn(args) + + +def test_iqn_determinism() -> None: + main_fn = lambda args: test_iqn(args, enable_assertions=False) + AlgorithmDeterminismTest("discrete_iqn", main_fn, get_args()).run() diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 8a681583d..a4fb28300 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -44,7 +45,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_pg(args: argparse.Namespace = get_args()) -> None: +def test_pg(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape @@ -122,4 +123,11 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_pg_determinism() -> None: + main_fn = lambda args: test_pg(args, enable_assertions=False) + AlgorithmDeterminismTest("discrete_pg", main_fn, get_args()).run() diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 7e541fffb..4226caf8f 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -57,7 +58,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_ppo(args: argparse.Namespace = get_args()) -> None: +def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape @@ -149,4 +150,11 @@ def stop_fn(mean_rewards: float) -> bool: save_best_fn=save_best_fn, logger=logger, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_ppo_determinism() -> None: + main_fn = lambda args: test_ppo(args, enable_assertions=False) + AlgorithmDeterminismTest("discrete_ppo", main_fn, get_args()).run() diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py 
index 5aa543fb5..afa2592c4 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -56,7 +57,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_qrdqn(args: argparse.Namespace = get_args()) -> None: +def test_qrdqn(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) @@ -112,12 +113,16 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) + # collector train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) - # policy.set_eps(1) + + # initial data collection + policy.set_eps(args.eps_train) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) + # log log_path = os.path.join(args.logdir, args.task, "qrdqn") writer = SummaryWriter(log_path) @@ -159,10 +164,17 @@ def test_fn(epoch: int, env_step: int | None) -> None: logger=logger, update_per_step=args.update_per_step, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) def test_pqrdqn(args: argparse.Namespace = get_args()) -> None: args.prioritized_replay = True args.gamma = 0.95 test_qrdqn(args) + + +def test_qrdqn_determinism() -> None: + main_fn = lambda args: test_qrdqn(args, enable_assertions=False) + AlgorithmDeterminismTest("discrete_qrdqn", main_fn, get_args()).run() diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index d7d4b15b1..92d10b06a 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -1,6 +1,7 @@ import argparse import os import pickle +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -64,7 +65,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_rainbow(args: argparse.Namespace = get_args()) -> None: +def test_rainbow(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) @@ -127,12 +128,16 @@ def noisy_linear(x: int, y: int) -> NoisyLinear: ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) + # collector train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) - # policy.set_eps(1) + + # initial data collection + policy.set_eps(args.eps_train) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) + # log log_path = os.path.join(args.logdir, args.task, "rainbow") writer = SummaryWriter(log_path) @@ -221,7 +226,9 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) def test_rainbow_resume(args: argparse.Namespace = get_args()) -> None: @@ -234,3 +241,8 @@ def test_prainbow(args: argparse.Namespace = get_args()) -> None: args.gamma = 0.95 args.seed = 1 test_rainbow(args) + + +def 
test_rainbow_determinism() -> None: + main_fn = lambda args: test_rainbow(args, enable_assertions=False) + AlgorithmDeterminismTest("discrete_rainbow", main_fn, get_args()).run() diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 3409dab0a..9e6a08dc9 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -1,5 +1,6 @@ import argparse import os +from test.determinism_test import AlgorithmDeterminismTest import gymnasium as gym import numpy as np @@ -50,7 +51,10 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: +def test_discrete_sac( + args: argparse.Namespace = get_args(), + enable_assertions: bool = True, +) -> None: env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) @@ -140,4 +144,11 @@ def stop_fn(mean_rewards: float) -> bool: update_per_step=args.update_per_step, test_in_train=False, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_discrete_sac_determinism() -> None: + main_fn = lambda args: test_discrete_sac(args, enable_assertions=False) + AlgorithmDeterminismTest("discrete_sac", main_fn, get_args()).run() diff --git a/test/highlevel/env_factory.py b/test/highlevel/env_factory.py index 4a131e5fd..1d649f15e 100644 --- a/test/highlevel/env_factory.py +++ b/test/highlevel/env_factory.py @@ -8,8 +8,6 @@ class DiscreteTestEnvFactory(EnvFactoryRegistered): def __init__(self) -> None: super().__init__( task="CartPole-v1", - train_seed=42, - test_seed=1337, venv_type=VectorEnvType.DUMMY, ) @@ -18,7 +16,5 @@ class ContinuousTestEnvFactory(EnvFactoryRegistered): def __init__(self) -> None: super().__init__( task="Pendulum-v1", - train_seed=42, - test_seed=1337, venv_type=VectorEnvType.DUMMY, ) diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 2ed910902..c409fdb3f 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -2,6 +2,7 @@ import datetime import os import pickle +from test.determinism_test import AlgorithmDeterminismTest from test.offline.gather_pendulum_data import expert_file_name, gather_data import gymnasium as gym @@ -61,7 +62,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_bcq(args: argparse.Namespace = get_args()) -> None: +def test_bcq(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) @@ -201,4 +202,11 @@ def watch() -> None: logger=logger, show_progress=args.show_progress, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_bcq_determinism() -> None: + main_fn = lambda args: test_bcq(args, enable_assertions=False) + AlgorithmDeterminismTest("offline_bcq", main_fn, get_args(), is_offline=True).run() diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index bd84098ba..ea8b6ac11 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -2,7 +2,7 @@ import datetime import os import pickle -import pprint +from test.determinism_test import AlgorithmDeterminismTest from test.offline.gather_pendulum_data import expert_file_name, gather_data import gymnasium as gym @@ -66,7 +66,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def 
test_cql(args: argparse.Namespace = get_args()) -> None: +def test_cql(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) @@ -194,10 +194,12 @@ def stop_fn(mean_rewards: float) -> bool: stop_fn=stop_fn, logger=logger, ) + stats = trainer.run() - for epoch_stat in trainer: - print(f"Epoch: {epoch_stat.epoch}") - pprint.pprint(epoch_stat) - # print(info) + if enable_assertions: + assert stop_fn(stats.best_reward) - assert stop_fn(epoch_stat.info_stat.best_reward) + +def test_cql_determinism() -> None: + main_fn = lambda args: test_cql(args, enable_assertions=False) + AlgorithmDeterminismTest("offline_cql", main_fn, get_args(), is_offline=True).run() diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index e69e0a1fa..f7d34ba50 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -1,6 +1,7 @@ import argparse import os import pickle +from test.determinism_test import AlgorithmDeterminismTest from test.offline.gather_cartpole_data import expert_file_name, gather_data import gymnasium as gym @@ -36,7 +37,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--unlikely-action-threshold", type=float, default=0.6) parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--update-per-epoch", type=int, default=2000) + parser.add_argument("--step-per-epoch", type=int, default=2000) parser.add_argument("--batch-size", type=int, default=64) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--test-num", type=int, default=100) @@ -53,7 +54,10 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: +def test_discrete_bcq( + args: argparse.Namespace = get_args(), + enable_assertions: bool = True, +) -> None: # envs env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) @@ -155,7 +159,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: buffer=buffer, test_collector=test_collector, max_epoch=args.epoch, - step_per_epoch=args.update_per_epoch, + step_per_epoch=args.step_per_epoch, episode_per_test=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, @@ -164,10 +168,17 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) def test_discrete_bcq_resume(args: argparse.Namespace = get_args()) -> None: test_discrete_bcq() args.resume = True test_discrete_bcq(args) + + +def test_discrete_bcq_determinism() -> None: + main_fn = lambda args: test_discrete_bcq(args, enable_assertions=False) + AlgorithmDeterminismTest("offline_discrete_bcq", main_fn, get_args(), is_offline=True).run() diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 97766d494..373b9a074 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -1,6 +1,7 @@ import argparse import os import pickle +from test.determinism_test import AlgorithmDeterminismTest from test.offline.gather_cartpole_data import 
expert_file_name, gather_data import gymnasium as gym @@ -35,7 +36,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--target-update-freq", type=int, default=500) parser.add_argument("--min-q-weight", type=float, default=10.0) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--update-per-epoch", type=int, default=1000) + parser.add_argument("--step-per-epoch", type=int, default=1000) parser.add_argument("--batch-size", type=int, default=32) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64]) parser.add_argument("--test-num", type=int, default=100) @@ -50,7 +51,10 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: +def test_discrete_cql( + args: argparse.Namespace = get_args(), + enable_assertions: bool = True, +) -> None: # envs env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) @@ -118,7 +122,7 @@ def stop_fn(mean_rewards: float) -> bool: buffer=buffer, test_collector=test_collector, max_epoch=args.epoch, - step_per_epoch=args.update_per_epoch, + step_per_epoch=args.step_per_epoch, episode_per_test=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, @@ -126,4 +130,10 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() - assert stop_fn(result.best_reward) + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_discrete_cql_determinism() -> None: + main_fn = lambda args: test_discrete_cql(args, enable_assertions=False) + AlgorithmDeterminismTest("offline_discrete_cql", main_fn, get_args(), is_offline=True).run() diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index bf9a833a9..3593cf206 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -1,6 +1,7 @@ import argparse import os import pickle +from test.determinism_test import AlgorithmDeterminismTest from test.offline.gather_cartpole_data import expert_file_name, gather_data import gymnasium as gym @@ -33,7 +34,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--n-step", type=int, default=3) parser.add_argument("--target-update-freq", type=int, default=320) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--update-per-epoch", type=int, default=1000) + parser.add_argument("--step-per-epoch", type=int, default=1000) parser.add_argument("--batch-size", type=int, default=64) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--test-num", type=int, default=100) @@ -48,7 +49,10 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: +def test_discrete_crr( + args: argparse.Namespace = get_args(), + enable_assertions: bool = True, +) -> None: # envs env = gym.make(args.task) assert isinstance(env.action_space, gym.spaces.Discrete) @@ -122,7 +126,7 @@ def stop_fn(mean_rewards: float) -> bool: buffer=buffer, test_collector=test_collector, max_epoch=args.epoch, - step_per_epoch=args.update_per_epoch, + step_per_epoch=args.step_per_epoch, episode_per_test=args.test_num, batch_size=args.batch_size, stop_fn=stop_fn, @@ -130,4 +134,10 @@ def stop_fn(mean_rewards: float) -> bool: logger=logger, ).run() - assert stop_fn(result.best_reward) + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_discrete_crr_determinism() -> None: + main_fn = lambda args: 
test_discrete_crr(args, enable_assertions=False) + AlgorithmDeterminismTest("offline_discrete_crr", main_fn, get_args(), is_offline=True).run() diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index c7f183587..98a6b6c48 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -1,6 +1,7 @@ import argparse import os import pickle +from test.determinism_test import AlgorithmDeterminismTest from test.offline.gather_pendulum_data import expert_file_name, gather_data import gymnasium as gym @@ -61,7 +62,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_gail(args: argparse.Namespace = get_args()) -> None: +def test_gail(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) @@ -220,4 +221,11 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str: resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, ).run() - assert stop_fn(result.best_reward) + + if enable_assertions: + assert stop_fn(result.best_reward) + + +def test_gail_determinism() -> None: + main_fn = lambda args: test_gail(args, enable_assertions=False) + AlgorithmDeterminismTest("offline_gail", main_fn, get_args(), is_offline=True).run() diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index 17c3afb06..40b529c68 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -2,6 +2,7 @@ import datetime import os import pickle +from test.determinism_test import AlgorithmDeterminismTest from test.offline.gather_pendulum_data import expert_file_name, gather_data import gymnasium as gym @@ -61,7 +62,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def test_td3_bc(args: argparse.Namespace = get_args()) -> None: +def test_td3_bc(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None: if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) @@ -181,10 +182,12 @@ def stop_fn(mean_rewards: float) -> bool: stop_fn=stop_fn, logger=logger, ) + stats = trainer.run() - for epoch_stat in trainer: - print(f"Epoch: {epoch_stat.epoch}") - print(epoch_stat) - # print(info) + if enable_assertions: + assert stop_fn(stats.best_reward) - assert stop_fn(epoch_stat.info_stat.best_reward) + +def test_td3_bc_determinism() -> None: + main_fn = lambda args: test_td3_bc(args, enable_assertions=False) + AlgorithmDeterminismTest("offline_td3_bc", main_fn, get_args(), is_offline=True).run() diff --git a/tianshou/config.py b/tianshou/config.py new file mode 100644 index 000000000..23cbb0cb2 --- /dev/null +++ b/tianshou/config.py @@ -0,0 +1,2 @@ +ENABLE_VALIDATION = False +"""Validation can help catching bugs and issues but it slows down training and collection. 
Enable it only if needed.""" diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 70478a87d..c7ee7505b 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -983,7 +983,7 @@ def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None: self.__dict__[key][sum_lens[i] : sum_lens[i + 1]] = value def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None: - if isinstance(batches, BatchProtocol | dict): + if isinstance(batches, Batch | dict): batches = [batches] # check input format batch_list = [] @@ -1069,7 +1069,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None { batch_key for batch_key, obj in batch.items() - if not (isinstance(obj, BatchProtocol) and len(obj.get_keys()) == 0) + if not (isinstance(obj, Batch) and len(obj.get_keys()) == 0) } for batch in batches ] @@ -1080,7 +1080,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None if all(isinstance(element, torch.Tensor) for element in value): self.__dict__[shared_key] = torch.stack(value, axis) # third often - elif all(isinstance(element, BatchProtocol | dict) for element in value): + elif all(isinstance(element, Batch | dict) for element in value): self.__dict__[shared_key] = Batch.stack(value, axis) else: # most often case is np.ndarray try: @@ -1114,7 +1114,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None value = batch.get(key) # TODO: fix code/annotations s.t. the ignores can be removed if ( - isinstance(value, BatchProtocol) # type: ignore + isinstance(value, Batch) # type: ignore and len(value.get_keys()) == 0 # type: ignore ): continue # type: ignore @@ -1288,7 +1288,7 @@ def set_array_at_key( ) from exception else: existing_entry = self[key] - if isinstance(existing_entry, BatchProtocol): + if isinstance(existing_entry, Batch): raise ValueError( f"Cannot set sequence at key {key} because it is a nested batch, " f"can only set a subsequence of an array.", @@ -1312,7 +1312,7 @@ def hasnull(self) -> bool: def is_any_true(boolean_batch: BatchProtocol) -> bool: for val in boolean_batch.values(): - if isinstance(val, BatchProtocol): + if isinstance(val, Batch): if is_any_true(val): return True else: @@ -1375,7 +1375,7 @@ def _apply_batch_values_func_recursively( """ result = batch if inplace else deepcopy(batch) for key, val in batch.__dict__.items(): - if isinstance(val, BatchProtocol): + if isinstance(val, Batch): result[key] = _apply_batch_values_func_recursively(val, values_transform, inplace=False) else: result[key] = values_transform(val) diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py index 5c4451d57..72c7af5bb 100644 --- a/tianshou/data/buffer/base.py +++ b/tianshou/data/buffer/base.py @@ -3,6 +3,7 @@ import h5py import numpy as np +from sensai.util.pickle import setstate from tianshou.data import Batch from tianshou.data.batch import ( @@ -77,6 +78,7 @@ def __init__( ignore_obs_next: bool = False, save_only_last_obs: bool = False, sample_avail: bool = False, + random_seed: int = 42, **kwargs: Any, # otherwise PrioritizedVectorReplayBuffer will cause TypeError ) -> None: # TODO: why do we need this? Just for readout? @@ -96,12 +98,21 @@ def __init__( self._save_only_last_obs = save_only_last_obs self._sample_avail = sample_avail self._meta = cast(RolloutBatchProtocol, Batch()) + self._random_state = np.random.RandomState(random_seed) # Keep in sync with reset! 
self.last_index = np.array([0]) self._insertion_idx = self._size = 0 self._ep_return, self._ep_len, self._ep_start_idx = 0.0, 0, 0 + def __setstate__(self, state: dict[str, Any]) -> None: + setstate( + ReplayBuffer, + self, + state, + new_default_properties={"_random_state": np.random.RandomState(42)}, + ) + @property def subbuffer_edges(self) -> np.ndarray: """Edges of contained buffers, mostly needed as part of the VectorReplayBuffer interface. @@ -134,8 +145,7 @@ def _get_start_stop_tuples_for_edge_crossing_interval( if stop >= start: raise ValueError( f"Expected stop < start, but got {start=}, {stop=}. " - f"For stop larger than start this method should never be called, " - f"and stop=start should never occur. This can occur either due to an implementation error, " + f"For stop larger-equal than start this method should never be called. This can occur either due to an implementation error, " f"or due a bad configuration of the buffer that resulted in a single episode being so long that " f"it completely filled a subbuffer (of size len(buffer)/degree_of_vectorization). " f"Consider either shortening the episode, increasing the size of the buffer, or decreasing the " @@ -202,7 +212,7 @@ def get_buffer_indices(self, start: int, stop: int) -> np.ndarray: f"Start and stop indices must be within the same subbuffer. " f"Got {start=} in subbuffer edge {start_left_edge} and {stop=} in subbuffer edge {stop_left_edge}.", ) - if stop > start: + if stop >= start: return np.arange(start, stop, dtype=int) else: (start, upper_edge), ( @@ -230,9 +240,6 @@ def __getattr__(self, key: str) -> Any: except KeyError as exception: raise AttributeError from exception - def __setstate__(self, state: dict[str, Any]) -> None: - self.__dict__.update(state) - def __setattr__(self, key: str, value: Any) -> None: assert key not in self._reserved_keys, f"key '{key}' is reserved and cannot be assigned" super().__setattr__(key, value) @@ -499,7 +506,7 @@ def sample_indices(self, batch_size: int | None) -> np.ndarray: batch_size = len(self) if self.stack_num == 1 or not self._sample_avail: # most often case if batch_size > 0: - return np.random.choice(self._size, batch_size) + return self._random_state.choice(self._size, batch_size) # TODO: is this behavior really desired? 
if batch_size == 0: # construct current available indices return np.concatenate( @@ -520,7 +527,7 @@ def sample_indices(self, batch_size: int | None) -> np.ndarray: prev_indices = self.prev(prev_indices) all_indices = all_indices[prev_indices != self.prev(prev_indices)] if batch_size > 0: - return np.random.choice(all_indices, batch_size) + return self._random_state.choice(all_indices, batch_size) return all_indices def sample(self, batch_size: int | None) -> tuple[RolloutBatchProtocol, np.ndarray]: diff --git a/tianshou/data/buffer/her.py b/tianshou/data/buffer/her.py index 087f8d0b0..eb03c1595 100644 --- a/tianshou/data/buffer/her.py +++ b/tianshou/data/buffer/her.py @@ -150,12 +150,12 @@ def rewrite_transitions(self, indices: np.ndarray) -> None: ep_obs = self[unique_ep_indices].obs # to satisfy mypy # TODO: add protocol covering these batches - assert isinstance(ep_obs, BatchProtocol) + assert isinstance(ep_obs, Batch) ep_rew = self[unique_ep_indices].rew if self._save_obs_next: ep_obs_next = self[unique_ep_indices].obs_next # to satisfy mypy - assert isinstance(ep_obs_next, BatchProtocol) + assert isinstance(ep_obs_next, Batch) future_obs = self[future_t[unique_ep_close_indices]].obs_next else: future_obs = self[self.next(future_t[unique_ep_close_indices])].obs @@ -172,7 +172,7 @@ def rewrite_transitions(self, indices: np.ndarray) -> None: ep_rew[:, her_ep_indices] = self._compute_reward(ep_obs_next)[:, her_ep_indices] else: tmp_ep_obs_next = self[self.next(unique_ep_indices)].obs - assert isinstance(tmp_ep_obs_next, BatchProtocol) + assert isinstance(tmp_ep_obs_next, Batch) ep_rew[:, her_ep_indices] = self._compute_reward(tmp_ep_obs_next)[:, her_ep_indices] # Sanity check @@ -181,7 +181,7 @@ def rewrite_transitions(self, indices: np.ndarray) -> None: assert ep_rew.shape == unique_ep_indices.shape # Re-write meta - assert isinstance(self._meta.obs, BatchProtocol) + assert isinstance(self._meta.obs, Batch) self._meta.obs[unique_ep_indices] = ep_obs if self._save_obs_next: self._meta.obs_next[unique_ep_indices] = ep_obs_next # type: ignore diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py index e8176aa8c..370c358be 100644 --- a/tianshou/data/buffer/manager.py +++ b/tianshou/data/buffer/manager.py @@ -208,11 +208,11 @@ def sample_indices(self, batch_size: int | None) -> np.ndarray: return all_indices if batch_size is None: batch_size = len(all_indices) - return np.random.choice(all_indices, batch_size) + return self._random_state.choice(all_indices, batch_size) if batch_size == 0 or batch_size is None: # get all available indices sample_num = np.zeros(self.buffer_num, int) else: - buffer_idx = np.random.choice( + buffer_idx = self._random_state.choice( self.buffer_num, batch_size, p=self._lengths / self._lengths.sum(), diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 514db4cc0..3c6d75d4d 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -12,6 +12,7 @@ from overrides import override from torch.distributions import Categorical, Distribution +from tianshou.config import ENABLE_VALIDATION from tianshou.data import ( Batch, CachedReplayBuffer, @@ -32,6 +33,7 @@ from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.policy import BasePolicy from tianshou.policy.base import episode_mc_return_to_go +from tianshou.utils.determinism import TraceLogger from tianshou.utils.print import DataclassPPrintMixin from tianshou.utils.torch_utils import torch_train_mode @@ -41,6 +43,8 @@ _TArrLike = 
TypeVar("_TArrLike", bound="np.ndarray | torch.Tensor | Batch | None") +TScalarArrayShape = TypeVar("TScalarArrayShape") + class CollectActionBatchProtocol(Protocol): """A protocol for results of computing actions from a batch of observations within a single collect step. @@ -315,8 +319,32 @@ def __init__( exploration_noise: bool = False, # The typing is correct, there's a bug in mypy, see https://github.com/python/mypy/issues/3737 collect_stats_class: type[TCollectStats] = CollectStats, # type: ignore[assignment] - raise_on_nan_in_buffer: bool = True, + raise_on_nan_in_buffer: bool = ENABLE_VALIDATION, ) -> None: + """ + :param policy: a tianshou policy, each :class:`BasePolicy` is capable of computing a batch + of actions from a batch of observations. + :param env: a ``gymnasium.Env`` environment or a vectorized instance of the + :class:`~tianshou.env.BaseVectorEnv` class. The latter is strongly recommended, as with + a gymnasium env the collection will not happen in parallel (a `DummyVectorEnv` + will be constructed internally from the passed env) + :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. + If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` + of size :data:`DEFAULT_BUFFER_MAXSIZE` * (number of envs) + as the default buffer. + :param exploration_noise: determine whether the action needs to be modified + with the corresponding policy's exploration noise. If so, "policy. + exploration_noise(act, batch)" will be called automatically to add the + exploration noise into action. + the rollout batch with this hook also modifies the data that is collected to the buffer! + :param raise_on_nan_in_buffer: whether to raise a `RuntimeError` if NaNs are found in the buffer after + a collection step. Especially useful when episode-level hooks are passed for making + sure that nothing is broken during the collection. Consider setting to False if + the NaN-check becomes a bottleneck. + :param collect_stats_class: the class to use for collecting statistics. Allows customizing + the stats collection logic by passing a subclass of :class:`CollectStats`. Changing + this is rarely necessary and is mainly done by "power users". + """ if isinstance(env, gym.Env) and not hasattr(env, "__len__"): warnings.warn("Single environment detected, wrap to DummyVectorEnv.") # Unfortunately, mypy seems to ignore the isinstance in lambda, maybe a bug in mypy @@ -554,7 +582,7 @@ def __init__( exploration_noise: bool = False, on_episode_done_hook: Optional["EpisodeRolloutHookProtocol"] = None, on_step_hook: Optional["StepHookProtocol"] = None, - raise_on_nan_in_buffer: bool = True, + raise_on_nan_in_buffer: bool = ENABLE_VALIDATION, collect_stats_class: type[TCollectStats] = CollectStats, # type: ignore[assignment] ) -> None: """ @@ -571,7 +599,7 @@ def __init__( :param exploration_noise: determine whether the action needs to be modified with the corresponding policy's exploration noise. If so, "policy. exploration_noise(act, batch)" will be called automatically to add the - exploration noise into action.. + exploration noise into action. :param on_episode_done_hook: if passed will be executed when an episode is done. The input to the hook will be a `RolloutBatch` that contains the entire episode (and nothing else). 
If a dict is returned by the hook it will be used to add new entries to the buffer @@ -777,10 +805,13 @@ def _collect( # noqa: C901 # TODO: can't do it init since AsyncCollector is currently a subclass of Collector if self.env.is_async: raise ValueError( - f"Please use {AsyncCollector.__name__} for asynchronous environments. " + f"Please use AsyncCollector for asynchronous environments. " f"Env class: {self.env.__class__.__name__}.", ) + ready_env_ids_R: np.ndarray[Any, np.dtype[np.signedinteger]] + """provides a mapping from local indices (indexing within `1, ..., R` where `R` is the number of ready envs) + to global ones (indexing within `1, ..., num_envs`). So the entry i in this array is the global index of the i-th ready env.""" if n_step is not None: ready_env_ids_R = np.arange(self.env_num) elif n_episode is not None: @@ -839,6 +870,7 @@ def _collect( # noqa: C901 last_info_R=last_info_R, last_hidden_state_RH=last_hidden_state_RH, ) + TraceLogger.log(log, lambda: f"Action: {collect_action_computation_batch_R.act}") # Step 3 obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step( @@ -913,6 +945,8 @@ def _collect( # noqa: C901 # local_idx - see block comment on class level # Step 7 env_done_local_idx_D = np.where(done_R)[0] + """Indexes which episodes are done within the ready envs, so it can be used for selecting from `..._R` arrays. + Stands in contrast to the "global" index, which counts within all envs and is unsuitable for selecting from `..._R` arrays.""" episode_lens_D = ep_len_R[env_done_local_idx_D] episode_returns_D = ep_return_R[env_done_local_idx_D] episode_start_indices_D = ep_start_idx_R[env_done_local_idx_D] @@ -931,6 +965,10 @@ def _collect( # noqa: C901 # 0,...,R and this global index is maintained by the ready_env_ids_R array. # See the class block comment for more details env_done_global_idx_D = ready_env_ids_R[env_done_local_idx_D] + """Indexes which episodes are done within all envs, i.e., within the index `1, ..., num_envs`. It can be + used to communicate with the vector env, where env ids are selected from this "global" index. + Is not suited for selecting from the ready envs (`..._R` arrays), use the local counterpart instead. + """ obs_reset_DO, info_reset_D = self.env.reset( env_id=env_done_global_idx_D, **gym_reset_kwargs, @@ -1032,7 +1070,7 @@ def _collect( # noqa: C901 break # Check if we screwed up somewhere - if self.buffer.hasnull(): + if self.raise_on_nan_in_buffer and self.buffer.hasnull(): nan_batch = self.buffer.isnull().apply_values_transform(np.sum) raise MalformedBufferError( diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index ac35ccf3f..ca31b9ac3 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -69,7 +69,14 @@ def wait( raise NotImplementedError def seed(self, seed: int | None = None) -> list[int] | None: - return self.action_space.seed(seed) # issue 299 + """ + Seeds the environment's action space sampler. + NOTE: This does *not* seed the environment itself. 
+ + :param seed: the random seed + :return: a list containing the resulting seed used + """ + return self.action_space.seed(seed) @abstractmethod def render(self, **kwargs: Any) -> Any: diff --git a/tianshou/evaluation/rliable_evaluation_hl.py b/tianshou/evaluation/rliable_evaluation_hl.py index 2b8ff5131..dc6205cc6 100644 --- a/tianshou/evaluation/rliable_evaluation_hl.py +++ b/tianshou/evaluation/rliable_evaluation_hl.py @@ -44,13 +44,18 @@ def from_data_dict(cls, data: dict) -> "LoggedCollectStats": Converts SequenceSummaryStats from dict format to dataclass format and ignores fields that are not present. """ + dataclass_data = {} field_names = [f.name for f in fields(cls)] for k, v in data.items(): if k not in field_names: - data.pop(k) + log.info( + f"Key {k} in data dict is not a valid field of LoggedCollectStats, ignoring it.", + ) + continue if isinstance(v, dict): - data[k] = LoggedSummaryData(**v) - return cls(**data) + v = LoggedSummaryData(**v) + dataclass_data[k] = v + return cls(**dataclass_data) @dataclass @@ -114,14 +119,23 @@ def load_from_disk( data = logger_cls.restore_logged_data(entry.path) # TODO: align low-level and high-level dir structure. This is a hack! if not data: + log.info( + f"Could not find data in {entry.path}, trying to restore from subdirectory.", + ) dirs = [ d for d in os.listdir(entry.path) if os.path.isdir(os.path.join(entry.path, d)) ] if len(dirs) != 1: - raise ValueError( - f"Could not restore data from {entry.path}, " - f"expected either events or exactly one subdirectory, ", + _error_message = ( + f"Could not restore experiment data from {entry.path}, " + f"expected either events or exactly one subdirectory, but got {dirs=}. " ) + if not dirs: + _error_message += ( + "The absence of events/subdirectory may be due to an error causing the training to stop or due to" + " too few environment steps, leading to no data being logged." 
+ ) + raise ValueError(_error_message) data = logger_cls.restore_logged_data(os.path.join(entry.path, dirs[0])) if not data: raise ValueError(f"Could not restore data from {entry.path}.") diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index 2ad533983..a023f0190 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -320,6 +320,10 @@ def __init__( def _get_policy_class(self) -> type[TPolicy]: pass + @abstractmethod + def _include_actor_in_optim(self) -> bool: + pass + def create_actor_critic_module_opt( self, envs: Environments, @@ -329,7 +333,10 @@ def create_actor_critic_module_opt( actor = self.actor_factory.create_module(envs, device) critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action) actor_critic = ActorCritic(actor, critic) - optim = self.optim_factory.create_optimizer(actor_critic, lr) + if self._include_actor_in_optim(): + optim = self.optim_factory.create_optimizer(actor_critic, lr) + else: + optim = self.optim_factory.create_optimizer(critic, lr) return ActorCriticOpt(actor_critic, optim) @typing.no_type_check @@ -356,21 +363,33 @@ def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]): + def _include_actor_in_optim(self) -> bool: + return True + def _get_policy_class(self) -> type[A2CPolicy]: return A2CPolicy class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]): + def _include_actor_in_optim(self) -> bool: + return True + def _get_policy_class(self) -> type[PPOPolicy]: return PPOPolicy class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]): + def _include_actor_in_optim(self) -> bool: + return False + def _get_policy_class(self) -> type[NPGPolicy]: return NPGPolicy class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]): + def _include_actor_in_optim(self) -> bool: + return False + def _get_policy_class(self) -> type[TRPOPolicy]: return TRPOPolicy diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py index ac27cba1a..fb58c8a58 100644 --- a/tianshou/highlevel/config.py +++ b/tianshou/highlevel/config.py @@ -56,9 +56,6 @@ class SamplingConfig(ToStringMixin): num_train_envs: int = -1 """the number of training environments to use. If set to -1, use number of CPUs/threads.""" - train_seed: int = 42 - """the seed to use for the training environments.""" - num_test_envs: int = 1 """the number of test environments to use""" @@ -165,10 +162,6 @@ class SamplingConfig(ToStringMixin): Currently only used in Atari examples and may be removed in the future! """ - @property - def test_seed(self) -> int: - return self.train_seed + self.num_train_envs - def __post_init__(self) -> None: if self.num_train_envs == -1: self.num_train_envs = multiprocessing.cpu_count() diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index e61e0ed36..45d85d9d5 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -7,6 +7,7 @@ import gymnasium as gym import gymnasium.spaces +import numpy as np from gymnasium import Env from sensai.util.pickle import setstate from sensai.util.string import ToStringMixin @@ -370,11 +371,58 @@ def __init__(self, venv_type: VectorEnvType): """ self.venv_type = venv_type + @staticmethod + def _create_rng(seed: int | None) -> np.random.Generator: + """ + Creates a random number generator with the given seed. 
+ + :param seed: the seed to use; if None, a random seed will be used + :return: the random number generator + """ + return np.random.default_rng(seed=seed) + + @staticmethod + def _next_seed(rng: np.random.Generator) -> int: + """ + Samples a random seed from the given random number generator. + + :param rng: the random number generator + :return: the sampled random seed + """ + # int32 is needed for envpool compatibility + return int(rng.integers(0, 2**31, dtype=np.int32)) + @abstractmethod - def create_env(self, mode: EnvMode) -> Env: - pass + def _create_env(self, mode: EnvMode) -> Env: + """Creates a single environment for the given mode. + + :param mode: the mode + :return: an environment + """ + + def create_env(self, mode: EnvMode, seed: int | None = None) -> Env: + """ + Creates a single environment for the given mode. + + :param mode: the mode + :param seed: the random seed to use for the environment; if None, the seed will not be specified, + and gymnasium will use a random seed. + :return: the environment + """ + env = self._create_env(mode) - def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv: + # initialize the environment with the given seed (if any) + if seed is not None: + rng = self._create_rng(seed) + env.np_random = rng + # also set the seed member within the environment such that it can be retrieved + # (gymnasium's random seed handling is, unfortunately, broken) + if hasattr(env, "_np_random_seed"): + env._np_random_seed = seed + + return env + + def create_venv(self, num_envs: int, mode: EnvMode, seed: int | None = None) -> BaseVectorEnv: """Create vectorized environments. :param num_envs: the number of environments @@ -383,28 +431,47 @@ def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv: :return: the vectorized environments """ + rng = self._create_rng(seed) + + def create_factory_fn() -> Callable[[], Env]: + # create a factory function that uses a sampled random seed + return lambda random_seed=self._next_seed(rng): self.create_env(mode, seed=random_seed) # type: ignore + + # create the vectorized environment, seeded appropriately if mode == EnvMode.WATCH: - return VectorEnvType.DUMMY.create_venv([lambda: self.create_env(mode)]) + venv = VectorEnvType.DUMMY.create_venv([create_factory_fn()]) else: - return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs) + venv = self.venv_type.create_venv([create_factory_fn() for _ in range(num_envs)]) + + # seed the action samplers + venv.seed([self._next_seed(rng) for _ in range(num_envs)]) + + return venv def create_envs( self, num_training_envs: int, num_test_envs: int, create_watch_env: bool = False, + seed: int | None = None, ) -> Environments: """Create environments for learning. 
:param num_training_envs: the number of training environments :param num_test_envs: the number of test environments :param create_watch_env: whether to create an environment for watching the agent + :param seed: the random seed to use for environment creation :return: the environments """ + rng = self._create_rng(seed) env = self.create_env(EnvMode.TRAIN) - train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN) - test_envs = self.create_venv(num_test_envs, EnvMode.TEST) - watch_env = self.create_venv(1, EnvMode.WATCH) if create_watch_env else None + train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN, seed=self._next_seed(rng)) + test_envs = self.create_venv(num_test_envs, EnvMode.TEST, seed=self._next_seed(rng)) + watch_env = ( + self.create_venv(1, EnvMode.WATCH, seed=self._next_seed(rng)) + if create_watch_env + else None + ) match EnvType.from_env(env): case EnvType.DISCRETE: return DiscreteEnvironments(env, train_envs, test_envs, watch_env) @@ -423,8 +490,6 @@ def __init__( self, *, task: str, - train_seed: int, - test_seed: int, venv_type: VectorEnvType, envpool_factory: EnvPoolFactory | None = None, render_mode_train: str | None = None, @@ -444,8 +509,6 @@ def __init__( super().__init__(venv_type) self.task = task self.envpool_factory = envpool_factory - self.train_seed = train_seed - self.test_seed = test_seed self.render_modes = { EnvMode.TRAIN: render_mode_train, EnvMode.TEST: render_mode_test, @@ -476,7 +539,7 @@ def _create_kwargs(self, mode: EnvMode) -> dict: kwargs["render_mode"] = self.render_modes.get(mode) return kwargs - def create_env(self, mode: EnvMode) -> Env: + def _create_env(self, mode: EnvMode) -> Env: """Creates a single environment for the given mode. :param mode: the mode @@ -485,17 +548,15 @@ def create_env(self, mode: EnvMode) -> Env: kwargs = self._create_kwargs(mode) return gymnasium.make(self.task, **kwargs) - def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv: - seed = self.train_seed if mode == EnvMode.TRAIN else self.test_seed + def create_venv(self, num_envs: int, mode: EnvMode, seed: int | None = None) -> BaseVectorEnv: if self.envpool_factory is not None: + rng = self._create_rng(seed) return self.envpool_factory.create_venv( self.task, num_envs, mode, - seed, + self._next_seed(rng), self._create_kwargs(mode), ) else: - venv = super().create_venv(num_envs, mode) - venv.seed(seed) - return venv + return super().create_venv(num_envs, mode, seed=seed) diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index c0be23dca..4df648fa9 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -118,7 +118,7 @@ @dataclass -class ExperimentConfig(ToStringMixin, DataclassPPrintMixin): +class ExperimentConfig: """Generic config for setting up the experiment, not RL or training specific.""" seed: int = 42 @@ -219,13 +219,7 @@ def get_seeding_info_as_str(self) -> str: This can be useful for creating unique experiment names based on seeds, e.g. A typical example is to do `experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`. 
""" - return "_".join( - [ - f"exp_seed={self.config.seed}", - f"train_seed={self.sampling_config.train_seed}", - f"test_seed={self.sampling_config.test_seed}", - ], - ) + return f"exp_seed={self.config.seed}" def _set_seed(self) -> None: seed = self.config.seed @@ -298,6 +292,7 @@ def create_experiment_world( self.sampling_config.num_train_envs, self.sampling_config.num_test_envs, create_watch_env=self.config.watch, + seed=self.config.seed, ) log.info(f"Created {envs}") @@ -672,13 +667,10 @@ def build_seeded_collection(self, num_experiments: int) -> ExperimentCollection: Each experiment in the collection will have a unique name created from the original experiment name and the seeds used. """ - num_train_envs = self.sampling_config.num_train_envs - seeded_experiments = [] for i in range(num_experiments): builder = self.copy() builder.experiment_config.seed += i - builder.sampling_config.train_seed += i * num_train_envs experiment = builder.build() experiment.name += f"_{experiment.get_seeding_info_as_str()}" seeded_experiments.append(experiment) diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index 4a1fe5c2e..ceb1262f7 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -140,7 +140,8 @@ def _create_factory(self, envs: Environments) -> ActorFactory: raise ValueError(self.continuous_actor_type) elif env_type == EnvType.DISCRETE: factory = ActorFactoryDiscreteNet( - self.DEFAULT_HIDDEN_SIZES, + self.hidden_sizes, + activation=self.hidden_activation, softmax_output=self.discrete_softmax, ) else: diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py index 2b662eb44..1c5d60438 100644 --- a/tianshou/highlevel/params/alpha.py +++ b/tianshou/highlevel/params/alpha.py @@ -21,8 +21,18 @@ def create_auto_alpha( class AutoAlphaFactoryDefault(AutoAlphaFactory): - def __init__(self, lr: float = 3e-4): + def __init__(self, lr: float = 3e-4, target_entropy_coefficient: float = -1.0): + """ + :param lr: the learning rate for the optimizer of the alpha parameter + :param target_entropy_coefficient: the coefficient with which to multiply the target entropy; + The base value being scaled is `dim(A)` for continuous action spaces and `log(|A|)` for discrete action spaces, + i.e. with the default coefficient -1, we obtain `-dim(A)` and `-log(dim(A))` for continuous and discrete action + spaces respectively, which gives a reasonable trade-off between exploration and exploitation. + For decidedly stochastic exploration, you can use a positive value closer to 1 (e.g. 0.98); + 1.0 would give full entropy exploration. 
+ """ self.lr = lr + self.target_entropy_coefficient = target_entropy_coefficient def create_auto_alpha( self, @@ -30,7 +40,11 @@ def create_auto_alpha( optim_factory: OptimizerFactory, device: TDevice, ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]: - target_entropy = float(-np.prod(envs.get_action_shape())) + action_dim = np.prod(envs.get_action_shape()) + if envs.get_type().is_continuous(): + target_entropy = self.target_entropy_coefficient * float(action_dim) + else: + target_entropy = self.target_entropy_coefficient * np.log(action_dim) log_alpha = torch.zeros(1, requires_grad=True, device=device) alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], self.lr) return target_entropy, log_alpha, alpha_optim diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 066a23a3b..ced0043d1 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -12,6 +12,7 @@ from numba import njit from numpy.typing import ArrayLike from overrides import override +from sensai.util.hash import pickle_hash from torch import nn from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_numpy, to_torch_as @@ -25,6 +26,7 @@ RolloutBatchProtocol, ) from tianshou.utils import MultipleLRSchedulers +from tianshou.utils.determinism import TraceLogger from tianshou.utils.net.common import RandomActor from tianshou.utils.print import DataclassPPrintMixin from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode @@ -541,6 +543,7 @@ def update( return TrainingStats() # type: ignore[return-value] start_time = time.time() batch, indices = buffer.sample(sample_size) + TraceLogger.log(logger, lambda: f"Updating with batch: indices={pickle_hash(indices)}") self.updating = True batch = self.process_fn(batch, buffer, indices) with torch_train_mode(self): diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py index 8c1374709..95b9527d2 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/policy/modelbased/psrl.py @@ -236,7 +236,8 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TPSRL for minibatch in batch.split(size=1): obs, act, obs_next = minibatch.obs, minibatch.act, minibatch.obs_next obs_next = cast(np.ndarray, obs_next) - assert not isinstance(obs, BatchProtocol), "Observations cannot be Batches here" + assert not isinstance(obs, Batch), "Observations cannot be Batches here" + obs = cast(np.ndarray, obs) trans_count[obs, act, obs_next] += 1 rew_sum[obs, act] += minibatch.rew rew_square_sum[obs, act] += minibatch.rew**2 diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index 9e04d3feb..005454396 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -38,7 +38,8 @@ class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # ty If `self.action_type == "discrete"`: (`s` ->`action_values_BA`). If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`). :param critic: the critic network. (s -> V(s)) - :param optim: the optimizer for actor and critic network. + :param optim: the optimizer for the critic network only. The actor network + is optimized via natural gradients internally. :param dist_fn: distribution class for computing the action. :param action_space: env's action space :param optim_critic_iters: Number of times to optimize critic network per update. 
diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 80bcff672..01f059df8 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -1,3 +1,4 @@ +import logging import warnings from collections.abc import Callable from dataclasses import dataclass @@ -27,6 +28,9 @@ from tianshou.utils.net.continuous import ActorProb from tianshou.utils.net.discrete import Actor +log = logging.getLogger(__name__) + + # Dimension Naming Convention # B - Batch Size # A - Action @@ -231,5 +235,4 @@ def learn( # type: ignore losses.append(loss.item()) loss_summary_stat = SequenceSummaryStats.from_sequence(losses) - return PGTrainingStats(loss=loss_summary_stat) # type: ignore[return-value] diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index e7aa5cfd5..51a2d7cf0 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -32,7 +32,8 @@ class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]): If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). :param critic: the critic network. (s -> V(s)) - :param optim: the optimizer for actor and critic network. + :param optim: the optimizer for the critic network only. The actor network + is optimized via natural gradients internally. :param dist_fn: distribution class for computing the action. :param action_space: env's action space :param max_kl: max kl-divergence used to constrain each actor network update. diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 355c9d33b..ec4645741 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -7,6 +7,7 @@ from functools import partial import numpy as np +import torch import tqdm from tianshou.data import ( @@ -27,6 +28,7 @@ LazyLogger, MovAvg, ) +from tianshou.utils.determinism import TraceLogger, torch_param_hash from tianshou.utils.logging import set_numerical_fields_to_precision from tianshou.utils.torch_utils import policy_within_training_step @@ -262,6 +264,7 @@ def _reset_collectors(self, reset_buffer: bool = False) -> None: def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> None: """Initialize or reset the instance to yield a new iterator from zero.""" + TraceLogger.log(log, lambda: "Trainer reset") self.is_run = False self.env_step = 0 if self.resume_from_log: @@ -308,8 +311,37 @@ def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> No self.stop_fn_flag = False self.iter_num = 0 + self._log_params(self.policy) + + def _log_params(self, module: torch.nn.Module) -> None: + """Logs the parameters of the module to the trace logger by subcomponent (if the trace logger is enabled).""" + from tianshou.utils.net.common import ActorCritic + + if not TraceLogger.is_enabled: + return + + def module_has_params(m: torch.nn.Module) -> bool: + return any(p.requires_grad for p in m.parameters()) + + relevant_modules = {} + + def gather_modules(m: torch.nn.Module) -> None: + for name, submodule in m.named_children(): + if isinstance(submodule, ActorCritic): + gather_modules(submodule) + else: + if module_has_params(submodule): + relevant_modules[name] = submodule + + gather_modules(module) + + for name, module in sorted(relevant_modules.items()): + TraceLogger.log( + log, + lambda: f"Params[{name}]: {torch_param_hash(module)}", + ) + def __iter__(self): # type: ignore - self.reset(reset_collectors=True, reset_buffer=False) return self def __next__(self) -> 
EpochStats: @@ -329,23 +361,30 @@ def __next__(self) -> EpochStats: # perform n step_per_epoch steps_done_in_this_epoch = 0 with self._pbar(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", position=1) as t: - train_stat: CollectStatsBase + TraceLogger.log(log, lambda: f"Epoch #{self.epoch} start") + collect_stats: CollectStatsBase while steps_done_in_this_epoch < self.step_per_epoch and not self.stop_fn_flag: - train_stat, update_stat, self.stop_fn_flag = self.training_step() + TraceLogger.log(log, lambda: "Training step") + collect_stats, training_stats, self.stop_fn_flag = self.training_step() + TraceLogger.log( + log, + lambda: f"Training step complete: stats={training_stats.get_loss_stats_dict() if training_stats else None}", + ) + self._log_params(self.policy) - if isinstance(train_stat, CollectStats): + if isinstance(collect_stats, CollectStats): pbar_data_dict = { "env_step": str(self.env_step), "env_episode": str(self.env_episode), "rew": f"{self.last_rew:.2f}", "len": str(int(self.last_len)), - "n/ep": str(train_stat.n_collected_episodes), - "n/st": str(train_stat.n_collected_steps), + "n/ep": str(collect_stats.n_collected_episodes), + "n/st": str(collect_stats.n_collected_steps), } # t might be disabled, we track the steps manually - t.update(train_stat.n_collected_steps) - steps_done_in_this_epoch += train_stat.n_collected_steps + t.update(collect_stats.n_collected_steps) + steps_done_in_this_epoch += collect_stats.n_collected_steps if self.stop_fn_flag: t.set_postfix(**pbar_data_dict) @@ -354,7 +393,7 @@ def __next__(self) -> EpochStats: # Code should be restructured! pbar_data_dict = {} assert self.buffer, "No train_collector or buffer specified" - train_stat = CollectStatsBase( + collect_stats = CollectStatsBase( n_collected_steps=len(self.buffer), ) @@ -408,9 +447,9 @@ def __next__(self) -> EpochStats: # in case trainer is used with run(), epoch_stat will not be returned return EpochStats( epoch=self.epoch, - train_collect_stat=train_stat, + train_collect_stat=collect_stats, test_collect_stat=test_stat, - training_stat=update_stat, + training_stat=training_stats, info_stat=info_stat, ) @@ -499,8 +538,12 @@ def _collect_training_data(self) -> CollectStats: n_step=self.step_per_collect, n_episode=self.episode_per_collect, ) + TraceLogger.log( + log, + lambda: f"Collected {collect_stats.n_collected_steps} steps, {collect_stats.n_collected_episodes} episodes", + ) - if self.train_collector.buffer.hasnull(): + if self.train_collector.raise_on_nan_in_buffer and self.train_collector.buffer.hasnull(): from tianshou.data.collector import EpisodeRolloutHook from tianshou.env import DummyVectorEnv @@ -611,19 +654,21 @@ def policy_update_fn( stats of the whole dataset """ - def run(self, reset_prior_to_run: bool = True, reset_buffer: bool = False) -> InfoStats: + def run(self, reset_collectors: bool = True, reset_buffer: bool = False) -> InfoStats: """Consume iterator. See itertools - recipes. Use functions that consume iterators at C speed (feed the entire iterator into a zero-length deque). - :param reset_prior_to_run: whether to reset collectors prior to run - :param reset_buffer: only has effect if `reset_prior_to_run` is True. - Then it will also reset the buffer. This is usually not necessary, use - with caution. + :param reset_collectors: whether to reset the collectors prior to starting the training process. + Specifically, this will reset the environments in the collectors (starting new episodes), + and the statistics stored in the collector. 
Whether the contained buffers will be reset/cleared + is determined by the `reset_buffer` parameter. + :param reset_collector_buffers: whether, for the case where the collectors are reset, to reset/clear the + contained buffers as well. + This has no effect if `reset_collectors` is False. """ - if reset_prior_to_run: - self.reset(reset_buffer=reset_buffer) + self.reset(reset_collectors=reset_collectors, reset_buffer=reset_buffer) try: self.is_run = True deque(self, maxlen=0) # feed the entire iterator into a zero-length deque diff --git a/tianshou/utils/determinism.py b/tianshou/utils/determinism.py new file mode 100644 index 000000000..5747eeb12 --- /dev/null +++ b/tianshou/utils/determinism.py @@ -0,0 +1,397 @@ +import difflib +import inspect +import os +import re +import time +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from io import StringIO +from pathlib import Path +from typing import Self + +import torch +from sensai.util import logging +from sensai.util.git import GitStatus, git_status +from sensai.util.pickle import dump_pickle, load_pickle + + +def format_log_message( + logger: logging.Logger, + level: int, + msg: str, + formatter: logging.Formatter, + stacklevel: int = 1, +) -> str: + """ + Formats a log message as it would have been created by `logger.log(level, msg)` with the given formatter. + + :param logger: the logger + :param level: the log level + :param msg: the message + :param formatter: the formatter + :param stacklevel: the stack level of the function to report as the generator + :return: the formatted log message (not including trailing newline) + """ + frame_info = inspect.stack()[stacklevel] + pathname = frame_info.filename + lineno = frame_info.lineno + func = frame_info.function + + record = logger.makeRecord( + name=logger.name, + level=level, + fn=pathname, + lno=lineno, + msg=msg, + args=(), + exc_info=None, + func=func, + extra=None, + ) + record.created = time.time() + record.asctime = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.created)) + + return formatter.format(record) + + +class TraceLogger: + """Supports the collection of behavioural trace logs, which can, in particular, be used for determinism tests.""" + + is_enabled = False + """ + whether the trace logger is enabled. + + NOTE: The preferred way to enable this is via the context manager. + """ + verbose = False + """ + whether to print trace log messages to stdout. + """ + MESSAGE_TAG = "[TRACE]" + """ + a tag which is added at the beginning of log messages generated by this logger + """ + LOG_LEVEL = logging.DEBUG + log_buffer: StringIO | None = None + log_formatter: logging.Formatter | None = None + + @classmethod + def log(cls, logger: logging.Logger, message_generator: Callable[[], str]) -> None: + """ + Logs a message intended for tracing agent-env interaction, which is enabled via + `TraceAgentEnvLoggerContext`. + + :param logger: the logger to use for the actual logging + :param message_generator: function which generates the log message (which may be expensive); + if logging is disabled, the function will not be called. 
+ """ + if not cls.is_enabled: + return + + msg = message_generator() + msg = cls.MESSAGE_TAG + " " + msg + + # Log with caller's frame info + logger.log(logging.DEBUG, msg, stacklevel=2) + + # If a dedicated memory buffer is configured, also store the message there + if cls.log_buffer is not None: + msg_formatted = format_log_message( + logger, + logging.DEBUG, + msg, + cls.log_formatter, + stacklevel=2, + ) + cls.log_buffer.write(msg_formatted + "\n") + if cls.verbose: + print(msg_formatted) + + +@dataclass +class TraceLog: + log_lines: list[str] + + def save_log(self, path: str) -> None: + with open(path, "w") as f: + for line in self.log_lines: + f.write(line + "\n") + + def print_log(self) -> None: + for line in self.log_lines: + print(line) + + def get_full_log(self) -> str: + return "\n".join(self.log_lines) + + def reduce_log_to_messages(self) -> "TraceLog": + """ + Removes logger names and function names from the log entries, such that each log message + contains only the main text message itself (starting with the content after the logger's tag). + + :return: the result with reduced log messages + """ + lines = [] + tag = re.escape(TraceLogger.MESSAGE_TAG) + for line in self.log_lines: + lines.append(re.sub(r".*" + tag, "", line)) + return TraceLog(lines) + + def filter_messages( + self, + required_messages: Sequence[str] = (), + optional_messages: Sequence[str] = (), + ignored_messages: Sequence[str] = (), + ) -> "TraceLog": + """ + Applies inclusion and or exclusion filtering to the log messages. + If either `required_messages` or `optional_messages` is empty, inclusion filtering is applied. + If `ignored_messages` is empty, exclusion filtering is applied. + If both inclusion and exclusion filtering are applied, the exclusion filtering takes precedence. + + :param required_messages: required message substrings to filter for; each message is required to appear at least once + (triggering exception otherwise) + :param optional_messages: additional messages fragments to filter for; these are not required + :param ignored_messages: message fragments that result in exclusion; takes precedence over + `required_messages` and `optional_messages` + :return: the result with reduced log messages + """ + import numpy as np + + required_message_counters = np.zeros(len(required_messages)) + + def retain_line(line: str) -> bool: + for ignored_message in ignored_messages: + if ignored_message in line: + return False + if required_messages or optional_messages: + for i, main_message in enumerate(required_messages): + if main_message in line: + required_message_counters[i] += 1 + return True + return any(add_message in line for add_message in optional_messages) + else: + return True + + lines = [] + for line in self.log_lines: + if retain_line(line): + lines.append(line) + + assert np.all( + required_message_counters > 0, + ), "Not all types of required messages were found in the trace. Were log messages changed?" + + return TraceLog(lines) + + +class TraceLoggerContext: + """ + A context manager which enables the trace logger. + Apart from enabling the logging, it can optionally create a memory log buffer, such that + getting the trace log is not strictly dependent on the logging system. 
+ """ + + def __init__( + self, + enable_log_buffer: bool = True, + log_format: str = "%(name)s:%(funcName)s - %(message)s", + ) -> None: + """ + :param enable_log_buffer: whether to enable the dedicated log buffer for trace logs, whose contents + can, within the context of this manager, be accessed via method `get_log`. + :param log_format: the logger format string to use for the dedicated log buffer + """ + self._enable_log_buffer = enable_log_buffer + self._log_format: str = log_format + self._log_buffer: StringIO | None = None + + def __enter__(self) -> Self: + TraceLogger.is_enabled = True + + if self._enable_log_buffer: + TraceLogger.log_buffer = StringIO() + TraceLogger.log_formatter = logging.Formatter(self._log_format) + self._log_buffer = TraceLogger.log_buffer + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore + TraceLogger.is_enabled = False + TraceLogger.log_buffer = None + TraceLogger.log_formatter = None + + def get_log(self) -> TraceLog: + """:return: the full trace log that was captured if `enable_log_buffer` was enabled at construction""" + if self._log_buffer is None: + raise Exception( + "This method is only supported if the log buffer is enabled at construction", + ) + return TraceLog(log_lines=self._log_buffer.getvalue().split("\n")) + + +def torch_param_hash(module: torch.nn.Module) -> str: + """ + Computes a hash of the parameters of the given module; parameters not requiring gradients are ignored. + + :param module: a torch module + :return: a hex digest of the parameters of the module + """ + import hashlib + + hasher = hashlib.sha1() + for param in module.parameters(): + if param.requires_grad: + np_array = param.detach().cpu().numpy() + hasher.update(np_array.tobytes()) + return hasher.hexdigest() + + +class TraceDeterminismTest: + def __init__( + self, + base_path: Path, + core_messages: Sequence[str] = (), + ignored_messages: Sequence[str] = (), + log_filename: str | None = None, + ) -> None: + """ + :param base_path: the directory where the reference results are stored (will be created if necessary) + :param core_messages: message fragments that make up the core of a trace; if empty, all messages are considered core + :param ignored_messages: message fragments to ignore in the trace log (if any); takes precedence over + `core_messages` + :param log_filename: the name of the log file to which results are to be written (if any) + """ + base_path.mkdir(parents=True, exist_ok=True) + self.base_path = base_path + self.core_messages = core_messages + self.ignored_messages = ignored_messages + self.log_filename = log_filename + + @dataclass(kw_only=True) + class Result: + git_status: GitStatus + log: TraceLog + + def check( + self, + current_log: TraceLog, + name: str, + create_reference_result: bool = False, + pass_if_core_messages_unchanged: bool = False, + ) -> None: + """ + Checks the given log against the reference result for the given name. + + :param current_log: the result to check + :param name: the name of the reference result; must be unique among all tests! 
+ :param create_reference_result: whether update the reference result with the given result + """ + import pytest + + reference_result_path = self.base_path / f"{name}.pkl.bz2" + current_git_status = git_status() + + if create_reference_result: + current_result = self.Result(git_status=current_git_status, log=current_log) + dump_pickle(current_result, reference_result_path) + + reference_result: TraceDeterminismTest.Result = load_pickle( + reference_result_path, + ) + reference_log = reference_result.log + + current_log_reduced = current_log.reduce_log_to_messages().filter_messages( + ignored_messages=self.ignored_messages, + ) + reference_log_reduced = reference_log.reduce_log_to_messages().filter_messages( + ignored_messages=self.ignored_messages, + ) + + results: list[tuple[TraceLog, str]] = [ + (reference_log_reduced, "expected"), + (current_log_reduced, "current"), + (reference_log, "expected_full"), + (current_log, "current_full"), + ] + + if self.core_messages: + result_main_messages = current_log_reduced.filter_messages( + required_messages=self.core_messages, + ) + reference_result_main_messages = reference_log_reduced.filter_messages( + required_messages=self.core_messages, + ) + results.extend( + [ + (reference_result_main_messages, "expected_core"), + (result_main_messages, "current_core"), + ], + ) + else: + result_main_messages = current_log_reduced + reference_result_main_messages = reference_log_reduced + + logs_equivalent = current_log_reduced.get_full_log() == reference_log_reduced.get_full_log() + if logs_equivalent: + status_passed = True + status_message = "OK" + else: + core_messages_unchanged = ( + len(self.core_messages) > 0 + and result_main_messages.get_full_log() + == reference_result_main_messages.get_full_log() + ) + status_passed = core_messages_unchanged and pass_if_core_messages_unchanged + + if status_passed: + status_message = "OK (core messages unchanged)" + else: + # save files for comparison + files = [] + for r, suffix in results: + path = os.path.abspath(f"determinism_{name}_{suffix}.txt") + r.save_log(path) + files.append(path) + + paths_str = "\n".join(files) + main_message = ( + f"Please inspect the changes by diffing the log files:\n{paths_str}\n" + f"If the changes are OK, enable the `create_reference_result` flag temporarily, " + "rerun the test and then commit the updated reference file.\n\nHere's the first part of the diff:\n" + ) + + # compute diff and add to message + num_diff_lines_to_show = 30 + for i, line in enumerate( + difflib.unified_diff( + reference_log_reduced.log_lines, + current_log_reduced.log_lines, + fromfile="expected.txt", + tofile="current.txt", + lineterm="", + ), + ): + if i == num_diff_lines_to_show: + break + main_message += line + "\n" + + if core_messages_unchanged: + status_message = ( + "The behaviour log has changed, but the core messages are still the same (so this " + f"probably isn't an issue). {main_message}" + ) + else: + status_message = f"The behaviour log has changed; even the core messages are different. {main_message}" + + # write log message + if self.log_filename: + with open(self.log_filename, "a") as f: + hr = "-" * 100 + f.write(f"\n\n{hr}\nName: {name}\n") + f.write(f"Reference state: {reference_result.git_status}\n") + f.write(f"Current state: {current_git_status}\n") + f.write(f"Test result: {status_message}\n") + + if not status_passed: + pytest.fail(status_message)
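To illustrate how the pieces introduced in `tianshou/utils/determinism.py` fit together, here is a minimal sketch of a determinism test, assuming a small, fully seeded training routine. Only `TraceLoggerContext` and `TraceDeterminismTest` are taken from the diff above; `run_short_training`, the snapshot directory, the test name, and the choice of core messages are illustrative placeholders.

```python
from pathlib import Path

from tianshou.utils.determinism import TraceDeterminismTest, TraceLoggerContext


def run_short_training() -> None:
    """Placeholder for a small, fully seeded training run (e.g. a few epochs on a toy task)."""


def test_training_determinism() -> None:
    # Capture the behavioural trace of the training run in an in-memory buffer.
    with TraceLoggerContext() as trace:
        run_short_training()
        log = trace.get_log()

    determinism_test = TraceDeterminismTest(
        base_path=Path("determinism_snapshots"),  # illustrative snapshot directory
        core_messages=["Params["],  # treat the per-module parameter hashes as the core of the trace
        log_filename="determinism_tests.log",
    )
    determinism_test.check(
        log,
        name="short_training",  # must be unique among all determinism tests
        create_reference_result=False,  # set to True once to (re-)create the reference snapshot
        pass_if_core_messages_unchanged=True,
    )
```

When the trace diverges from the stored reference, `check` writes the expected/current logs (and their core-message subsets) to text files and fails the test with a truncated diff; temporarily enabling `create_reference_result`, rerunning the test and committing the updated snapshot accepts the new behaviour.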
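The `Params[...]` trace entries written by the trainer's new `_log_params` rely on `torch_param_hash`, which hashes only parameters that require gradients. A small sketch of that behaviour (the `Linear` module is merely an example, not part of the diff):

```python
import torch

from tianshou.utils.determinism import torch_param_hash

net = torch.nn.Linear(4, 2)
h_before = torch_param_hash(net)

with torch.no_grad():
    net.weight += 1.0  # any change to a trainable parameter changes the hash

assert torch_param_hash(net) != h_before

net.weight.requires_grad_(False)
h_frozen = torch_param_hash(net)  # frozen parameters are skipped; only the bias is hashed now
```

Because `_log_params` is called after every training step, any divergence in the update path shows up as a changed `Params[...]` line in the trace.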
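For completeness, a sketch of the `TraceLog` post-processing used by the comparison: `reduce_log_to_messages` strips everything up to the `[TRACE]` tag, and `filter_messages` keeps only selected message types. The concrete log lines below are made up, but follow the default `log_format` of `TraceLoggerContext`:

```python
from tianshou.utils.determinism import TraceLog

log = TraceLog(
    log_lines=[
        "tianshou.trainer.base:reset - [TRACE] Trainer reset",
        "tianshou.trainer.base:__next__ - [TRACE] Training step",
        "tianshou.trainer.base:_log_params - [TRACE] Params[actor]: 3f2a9c",
    ]
)

messages = log.reduce_log_to_messages()  # keep only the text after the [TRACE] tag
params_only = messages.filter_messages(required_messages=["Params["])  # asserts the fragment occurs at least once
print(params_only.get_full_log())  # only the parameter-hash line remains
```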