diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
new file mode 100644
index 000000000..536443860
--- /dev/null
+++ b/.devcontainer/devcontainer.json
@@ -0,0 +1,22 @@
+{
+ "name": "Tianshou",
+ "dockerFile": "../Dockerfile",
+ "workspaceFolder": "/workspaces/tianshou",
+ "runArgs": ["--shm-size=1g"],
+ "customizations": {
+ "vscode": {
+ "settings": {
+ "terminal.integrated.shell.linux": "/bin/bash",
+ "python.pythonPath": "/usr/local/bin/python"
+ },
+ "extensions": [
+ "ms-python.python",
+ "ms-toolsai.jupyter",
+ "ms-python.vscode-pylance"
+ ]
+ }
+ },
+ "forwardPorts": [],
+ "postCreateCommand": "poetry install --with dev",
+ "remoteUser": "root"
+ }
\ No newline at end of file
diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 000000000..fa5050fe5
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,14 @@
+data
+logs
+test/log
+docs/jupyter_execute
+docs/.jupyter_cache
+.lsp
+.clj-kondo
+docs/_build
+coverage*
+__pycache__
+*.egg-info
+*.egg
+.*cache
+dist
\ No newline at end of file
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index e1c3f2b3c..364f89649 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -2,6 +2,7 @@
- [ ] I have provided a description of the changes in this Pull Request
- [ ] I have added documentation for my changes and have listed relevant changes in CHANGELOG.md
- [ ] If applicable, I have added tests to cover my changes.
+- [ ] If applicable, I have made sure that the determinism tests pass, i.e. my changes do not influence any aspect of training. See the contributing documentation for details.
- [ ] I have reformatted the code using `poe format`
- [ ] I have checked style and types with `poe lint` and `poe type-check`
- [ ] (Optional) I ran tests locally with `poe test`
diff --git a/.gitignore b/.gitignore
index e63e24b00..c6e843c4d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -158,4 +158,7 @@ docs/conf.py
# temporary scripts (for ad-hoc testing), temp folder
/temp
-/temp*.py
\ No newline at end of file
+/temp*.py
+
+# determinism test snapshots
+/test/resources/determinism/
diff --git a/CHANGELOG.md b/CHANGELOG.md
index c032e5553..fcd58f287 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,23 +1,62 @@
# Changelog
-## Unreleased
+## Upcoming Release 1.2.0
### Changes/Improvements
-- trainer:
+- `trainer`:
- Custom scoring now supported for selecting the best model. #1202
-- highlevel:
+- `highlevel`:
- `DiscreteSACExperimentBuilder`: Expose method `with_actor_factory_default` #1248 #1250
-
+ - `ActorFactoryDefault`: Fix parameters for hidden sizes and activation not being
+ passed on in the discrete case (affects the `with_actor_factory_default` method of experiment builders)
+ - `ExperimentConfig`: Do not inherit from other classes, as this breaks automatic handling by
+ `jsonargparse` when the class is used to define interfaces (as in high-level API examples)
+ - `AutoAlphaFactoryDefault`: Differentiate discrete and continuous action spaces
+ and allow coefficient to be modified, adding an informative docstring
+ (previous implementation was reasonable only for continuous action spaces)
+ - Adjust usage in `atari_sac_hl` example accordingly.
+ - `NPGAgentFactory`, `TRPOAgentFactory`: Fix optimizer instantiation including the actor parameters
+ (which was misleadingly suggested in the docstrings of the respective policy classes; the docstrings were fixed),
+ as the actor parameters are intended to be handled via natural gradients internally
+- `data`:
+ - `ReplayBuffer`: Fix collection of empty episodes being disallowed
+ - Collection was slow due to `isinstance` checks on Protocols and due to Buffer integrity validation. This was resolved
+ by no longer performing `isinstance` checks on Protocols and by disabling the integrity validation by default.
+- Tests:
+ - We have introduced extensive **determinism tests**, which allow us to validate whether
+ training processes deterministically compute the same results across different development branches.
+ This is an important step towards ensuring reproducibility and consistency, which will be
+ instrumental in supporting Tianshou developers in their work, especially in the context of
+ algorithm development and evaluation.
+
### Breaking Changes
-- data:
- - stats:
- - `InfoStats` has a new non-optional field `best_score` which is used
- for selecting the best model. #1202
+- `trainer`:
+ - `BaseTrainer.run` and `__iter__`: Resetting was never optional prior to running the trainer,
+ even though the recently introduced parameter `reset_prior_to_run` of `run` suggested that it _was_ optional.
+ The parameter was ultimately not respected, because `__iter__` would always call `reset(reset_collectors=True, reset_buffer=False)`
+ regardless. The parameter was removed; instead, the parameters of `run` now mirror the parameters of `reset`,
+ and the implicit `reset` call in `__iter__` was removed.
+ This aligns with upcoming changes in Tianshou v2.0.0.
+ * NOTE: If you have been using a trainer without calling `run` but by directly iterating over it, you
+ will need to call `reset` on the trainer explicitly before iterating over the trainer.
+ * Using a trainer as an iterator is considered deprecated and support for this will be removed in Tianshou v2.0.0.
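+ * A minimal sketch of the now-required explicit reset when iterating over a trainer directly (assuming `trainer` is an instance of a `BaseTrainer` subclass such as `OnpolicyTrainer`):
+ ```python
+ trainer.reset(reset_collectors=True, reset_buffer=False)
+ for epoch_stat in trainer:
+     print(epoch_stat)
+ ```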
+- `data`:
+ - `InfoStats` has a new non-optional field `best_score` which is used
+ for selecting the best model. #1202
+- `highlevel`:
+ - Change the way in which seeding is handled: The mechanism introduced in v1.1.0
+ was completely revised:
+ - The `train_seed` and `test_seed` attributes were removed from `SamplingConfig`.
+ Instead, the seeds are derived from the seed defined in `ExperimentConfig`.
+ - Seed attributes of `EnvFactory` classes were removed.
+ Instead, seeds are passed to methods of `EnvFactory`.
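+ - For illustration, a sketch of the resulting usage, mirroring the updated `atari_wrapper.py` example
+ (`task`, `frame_stack`, `training_num`, `test_num` and `seed` are placeholders):
+ ```python
+ env_factory = AtariEnvFactory(task, frame_stack, scale=False)  # no seed arguments at construction
+ envs = env_factory.create_envs(training_num, test_num, seed=seed)
+ ```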
## Release 1.1.0
+**NOTE**: This release introduced (potentially severe) performance regressions in data collection; please switch to a newer release for better performance.
+
### Highlights
#### Evaluation Package
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 000000000..4e3827b26
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,42 @@
+# Use the official Python image for the base image.
+FROM --platform=linux/amd64 python:3.11-slim
+
+# Set environment variables to make Python print directly to the terminal and avoid .pyc files.
+ENV PYTHONUNBUFFERED=1
+ENV PYTHONDONTWRITEBYTECODE=1
+
+# Install system dependencies required for the project.
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ curl \
+ build-essential \
+ git \
+ wget \
+ unzip \
+ libvips-dev \
+ gnupg2 \
+ && rm -rf /var/lib/apt/lists/*
+
+
+# Install pipx.
+RUN python3 -m pip install --no-cache-dir pipx \
+ && pipx ensurepath
+
+# Add poetry to the path
+ENV PATH="${PATH}:/root/.local/bin"
+
+# Install the latest version of Poetry using pipx.
+RUN pipx install poetry
+
+# Set the working directory. IMPORTANT: this cannot be changed, as it needs to match the directory
+# to which the project is cloned in the codespace.
+WORKDIR /workspaces/tianshou
+
+# Copy the pyproject.toml and poetry.lock files (if available) into the image.
+COPY pyproject.toml poetry.lock* README.md /workspaces/tianshou/
+
+RUN poetry config virtualenvs.create false
+RUN poetry install --no-root --with dev
+
+# The entrypoint performs an editable install; it expects the code to be mounted into the container at runtime.
+# If you don't want to mount the code, you should override the entrypoint.
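+# Example usage (illustrative; the image tag and host path are assumptions):
+#   docker build -t tianshou .
+#   docker run --rm -it --shm-size=1g -v "$(pwd)":/workspaces/tianshou tianshou bash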
+ENTRYPOINT ["/bin/bash", "-c", "poetry install --with dev && poetry run jupyter trust notebooks/*.ipynb docs/02_notebooks/*.ipynb && $0 $@"]
\ No newline at end of file
diff --git a/docs/04_contributing/04_contributing.rst b/docs/04_contributing/04_contributing.rst
index 1397e3473..48cf172c8 100644
--- a/docs/04_contributing/04_contributing.rst
+++ b/docs/04_contributing/04_contributing.rst
@@ -2,11 +2,12 @@ Contributing to Tianshou
========================
-Install Develop Version
------------------------
+Install Development Environment
+-------------------------------
Tianshou is built and managed by `poetry `_. For example,
-to install all relevant requirements in editable mode you can simply call
+to install all relevant requirements (and install Tianshou itself in editable mode)
+you can simply call
.. code-block:: bash
@@ -36,9 +37,9 @@ Please set up pre-commit by running
in the main directory. This should make sure that your contribution is properly
formatted before every commit.
-The code is inspected and formatted by `black` and `ruff`. They are executed as
-pre-commit hooks. In addition, `poe the poet` tasks are configured.
-Simply run `poe` to see the available tasks.
+The code is inspected and formatted by ``black`` and ``ruff``. They are executed as
+pre-commit hooks. In addition, ``poe the poet`` tasks are configured.
+Simply run ``poe`` to see the available tasks.
E.g, to format and check the linting manually you can run:
.. code-block:: bash
@@ -47,8 +48,8 @@ E.g, to format and check the linting manually you can run:
$ poe lint
-Type Check
-----------
+Type Checks
+-----------
We use `mypy `_ to check the type annotations. To check, in the main directory, run:
@@ -57,8 +58,8 @@ We use `mypy `_ to check the type annotations.
$ poe type-check
-Test Locally
-------------
+Testing Locally
+---------------
This command will run automatic tests in the main directory
@@ -67,6 +68,30 @@ This command will run automatic tests in the main directory
$ poe test
+Determinism Tests
+~~~~~~~~~~~~~~~~~
+
+We implemented "determinism tests" for Tianshou's algorithms, which allow us to determine
+whether algorithms still compute exactly the same results even after large refactorings.
+These tests are applied by
+
+ 1. creating a behavior snapshot in the old code branch before the changes and then
+ 2. running the test in the new branch to ensure that the behavior is the same.
+
+Unfortunately, full determinism is difficult to achieve across different platforms and even different
+machines using the same platform and Python environment.
+Therefore, these tests are not carried out in the CI pipeline.
+Instead, it is up to the developer to run them locally and check the results whenever a change
+is made to the code base that could affect algorithm behavior.
+
+Technically, the two steps are handled by setting static flags in class ``AlgorithmDeterminismTest`` and then
+running either the full test suite or a specific determinism test (``test_*_determinism``, e.g. ``test_ddpg_determinism``)
+in the two branches to be compared.
+
+ 1. On the old branch: (Temporarily) set ``ENABLED=True`` and ``FORCE_SNAPSHOT_UPDATE=True`` and run the test(s).
+ 2. On the new branch: (Temporarily) set ``ENABLED=True`` and ``FORCE_SNAPSHOT_UPDATE=False`` and run the test(s).
+ 3. Inspect the test results; a summary can be found in ``determinism_tests.log``.
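+
+For example, to compare a single algorithm across the two branches, you can run the corresponding
+determinism test in both of them (an illustrative invocation; any ``test_*_determinism`` test can be
+substituted):
+
+.. code-block:: bash
+
+    $ pytest test/continuous/test_ddpg.py -k test_ddpg_determinism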
+
Test by GitHub Actions
----------------------
diff --git a/docs/autogen_rst.py b/docs/autogen_rst.py
index b1a8b18d9..93e2a0954 100644
--- a/docs/autogen_rst.py
+++ b/docs/autogen_rst.py
@@ -114,8 +114,8 @@ def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix=""
for f in files_in_dir
if os.path.isdir(os.path.join(root, dirname, f)) and not f.startswith("_")
]
- if not module_names:
- log.debug(f"Skipping {dirname} as it does not contain any .py files")
+ if not module_names and "__init__.py" not in files_in_dir:
+ log.debug(f"Skipping {dirname} as it does not contain any modules or __init__.py")
continue
package_qualname = f"{base_package_qualname}.{dirname}"
package_index_rst_path = os.path.join(
diff --git a/docs/create_toc.py b/docs/create_toc.py
index 3e1779cdf..c6add58e5 100644
--- a/docs/create_toc.py
+++ b/docs/create_toc.py
@@ -3,6 +3,6 @@
# This script provides a platform-independent way of making the jupyter-book call (used in pyproject.toml)
toc_file = Path(__file__).parent / "_toc.yml"
-cmd = f"jupyter-book toc from-project docs -e .rst -e .md -e .ipynb >{toc_file}"
+cmd = f'jupyter-book toc from-project docs -e .rst -e .md -e .ipynb >"{toc_file}"'
print(cmd)
os.system(cmd)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
deleted file mode 100644
index 40aa69970..000000000
--- a/docs/spelling_wordlist.txt
+++ /dev/null
@@ -1,294 +0,0 @@
-tianshou
-arXiv
-tanh
-lr
-logits
-env
-envs
-optim
-eps
-timelimit
-TimeLimit
-envpool
-EnvPool
-maxsize
-timestep
-timesteps
-numpy
-ndarray
-stackoverflow
-tensorboard
-state_dict
-len
-tac
-fqf
-iqn
-qrdqn
-rl
-offpolicy
-onpolicy
-quantile
-quantiles
-dqn
-param
-async
-subprocess
-deque
-nn
-equ
-cql
-fn
-boolean
-pre
-np
-cuda
-rnn
-rew
-pre
-perceptron
-bsz
-dataset
-mujoco
-jit
-nstep
-preprocess
-preprocessing
-repo
-ReLU
-namespace
-recv
-th
-utils
-NaN
-linesearch
-hyperparameters
-pseudocode
-entropies
-nn
-config
-cpu
-rms
-debias
-indice
-regularizer
-miniblock
-modularize
-serializable
-softmax
-vectorized
-optimizers
-undiscounted
-submodule
-subclasses
-submodules
-tfevent
-dirichlet
-docstring
-webpage
-formatter
-num
-py
-pythonic
-中文文档位于
-conda
-miniconda
-Amir
-Andreas
-Antonoglou
-Beattie
-Bellemare
-Charles
-Daan
-Demis
-Dharshan
-Fidjeland
-Georg
-Hassabis
-Helen
-Ioannis
-Kavukcuoglu
-King
-Koray
-Kumaran
-Legg
-Mnih
-Ostrovski
-Petersen
-Riedmiller
-Rusu
-Sadik
-Shane
-Stig
-Veness
-Volodymyr
-Wierstra
-Lillicrap
-Pritzel
-Heess
-Erez
-Yuval
-Tassa
-Schulman
-Filip
-Wolski
-Prafulla
-Dhariwal
-Radford
-Oleg
-Klimov
-Kaichao
-Jiayi
-Weng
-Duburcq
-Huayu
-Yi
-Su
-Strens
-Ornstein
-Uhlenbeck
-mse
-gail
-airl
-ppo
-Jupyter
-Colab
-Colaboratory
-IPendulum
-Reacher
-Runtime
-Nvidia
-Enduro
-Qbert
-Seaquest
-subnets
-subprocesses
-isort
-yapf
-pydocstyle
-Args
-tuples
-tuple
-Multi
-multi
-parameterized
-Proximal
-metadata
-GPU
-Dopamine
-builtin
-params
-inplace
-deepcopy
-Gaussian
-stdout
-parallelization
-minibatch
-minibatches
-MLP
-backpropagation
-dataclass
-superset
-subtype
-subdirectory
-picklable
-ShmemVectorEnv
-Github
-wandb
-jupyter
-img
-src
-parallelized
-infty
-venv
-venvs
-subproc
-bcq
-highlevel
-icm
-modelbased
-td
-psrl
-ddpg
-npg
-tf
-trpo
-crr
-pettingzoo
-multidiscrete
-vecbuf
-prio
-colab
-segtree
-multiagent
-mapolicy
-sensai
-sensAI
-docstrings
-superclass
-iterable
-functools
-str
-sklearn
-attr
-bc
-redq
-modelfree
-bdq
-util
-logp
-autogenerated
-subpackage
-subpackages
-recurse
-rollout
-rollouts
-prepend
-prepends
-dict
-dicts
-pytorch
-tensordict
-onwards
-Dominik
-Tsinghua
-Tianshou
-appliedAI
-macOS
-joblib
-master
-Panchenko
-BA
-BH
-BO
-BD
-configs
-postfix
-backend
-rliable
-hl
-v_s
-v_s_
-obs
-obs_next
-dtype
-iqm
-kwarg
-entrypoint
-interquantile
-init
-kwarg
-kwargs
-autocompletion
-codebase
-indexable
-sliceable
-gaussian
-logprob
-monte
-carlo
-subclass
-subclassing
-dist
-dists
-subbuffer
-subbuffers
diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py
index 601481523..3bcb0f6c3 100644
--- a/examples/atari/atari_dqn_hl.py
+++ b/examples/atari/atari_dqn_hl.py
@@ -69,8 +69,6 @@ def main(
env_factory = AtariEnvFactory(
task,
- sampling_config.train_seed,
- sampling_config.test_seed,
frames_stack,
scale=scale_obs,
)
diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py
index b71b0eef3..c644b2469 100644
--- a/examples/atari/atari_iqn_hl.py
+++ b/examples/atari/atari_iqn_hl.py
@@ -68,8 +68,6 @@ def main(
env_factory = AtariEnvFactory(
task,
- sampling_config.train_seed,
- sampling_config.test_seed,
frames_stack,
scale=scale_obs,
)
diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py
index 26ebaba08..983608293 100644
--- a/examples/atari/atari_ppo_hl.py
+++ b/examples/atari/atari_ppo_hl.py
@@ -73,8 +73,6 @@ def main(
env_factory = AtariEnvFactory(
task,
- sampling_config.train_seed,
- sampling_config.test_seed,
frames_stack,
scale=scale_obs,
)
diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py
index 124def768..76f18f55f 100644
--- a/examples/atari/atari_sac_hl.py
+++ b/examples/atari/atari_sac_hl.py
@@ -45,7 +45,6 @@ def main(
training_num: int = 10,
test_num: int = 10,
frames_stack: int = 4,
- save_buffer_name: str | None = None, # TODO add support in high-level API?
icm_lr_scale: float = 0.0,
icm_reward_scale: float = 0.01,
icm_forward_loss_weight: float = 0.2,
@@ -69,8 +68,6 @@ def main(
env_factory = AtariEnvFactory(
task,
- sampling_config.train_seed,
- sampling_config.test_seed,
frames_stack,
scale=scale_obs,
)
@@ -84,7 +81,9 @@ def main(
critic2_lr=critic_lr,
gamma=gamma,
tau=tau,
- alpha=AutoAlphaFactoryDefault(lr=alpha_lr) if auto_alpha else alpha,
+ alpha=AutoAlphaFactoryDefault(lr=alpha_lr, target_entropy_coefficient=0.98)
+ if auto_alpha
+ else alpha,
estimation_step=n_step,
),
)
diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py
index d7234d863..25a3b09f2 100644
--- a/examples/atari/atari_wrapper.py
+++ b/examples/atari/atari_wrapper.py
@@ -383,8 +383,8 @@ def make_atari_env(
:return: a tuple of (single env, training envs, test envs).
"""
- env_factory = AtariEnvFactory(task, seed, seed + training_num, frame_stack, scale=bool(scale))
- envs = env_factory.create_envs(training_num, test_num)
+ env_factory = AtariEnvFactory(task, frame_stack, scale=bool(scale))
+ envs = env_factory.create_envs(training_num, test_num, seed=seed)
return envs.env, envs.train_envs, envs.test_envs
@@ -392,8 +392,6 @@ class AtariEnvFactory(EnvFactoryRegistered):
def __init__(
self,
task: str,
- train_seed: int,
- test_seed: int,
frame_stack: int,
scale: bool = False,
use_envpool_if_available: bool = True,
@@ -411,14 +409,12 @@ def __init__(
log.info("Not using envpool, because it is not available")
super().__init__(
task=task,
- train_seed=train_seed,
- test_seed=test_seed,
venv_type=venv_type,
envpool_factory=envpool_factory,
)
- def create_env(self, mode: EnvMode) -> gym.Env:
- env = super().create_env(mode)
+ def _create_env(self, mode: EnvMode) -> gym.Env:
+ env = super()._create_env(mode)
is_train = mode == EnvMode.TRAIN
return wrap_deepmind(
env,
diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py
index 4e52f4ce2..bcf3e7f18 100644
--- a/examples/discrete/discrete_dqn.py
+++ b/examples/discrete/discrete_dqn.py
@@ -83,7 +83,7 @@ def stop_fn(mean_rewards: float) -> bool:
# watch performance
policy.set_eps(eps_test)
collector = ts.data.Collector[CollectStats](policy, env, exploration_noise=True)
- collector.collect(n_episode=100, render=1 / 35)
+ collector.collect(n_episode=100, render=1 / 35, reset_before_collect=True)
if __name__ == "__main__":
diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py
index c804d6c26..bba7f9e76 100644
--- a/examples/mujoco/mujoco_a2c_hl.py
+++ b/examples/mujoco/mujoco_a2c_hl.py
@@ -54,12 +54,7 @@ def main(
repeat_per_collect=repeat_per_collect,
)
- env_factory = MujocoEnvFactory(
- task,
- train_seed=sampling_config.train_seed,
- test_seed=sampling_config.test_seed,
- obs_norm=True,
- )
+ env_factory = MujocoEnvFactory(task, obs_norm=True)
experiment = (
A2CExperimentBuilder(env_factory, experiment_config, sampling_config)
diff --git a/examples/mujoco/mujoco_ddpg_hl.py b/examples/mujoco/mujoco_ddpg_hl.py
index 27dbfc8d9..daa936533 100644
--- a/examples/mujoco/mujoco_ddpg_hl.py
+++ b/examples/mujoco/mujoco_ddpg_hl.py
@@ -52,12 +52,7 @@ def main(
start_timesteps_random=True,
)
- env_factory = MujocoEnvFactory(
- task,
- train_seed=sampling_config.train_seed,
- test_seed=sampling_config.test_seed,
- obs_norm=False,
- )
+ env_factory = MujocoEnvFactory(task, obs_norm=False)
experiment = (
DDPGExperimentBuilder(env_factory, experiment_config, sampling_config)
diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py
index 90f27995b..8aff92793 100644
--- a/examples/mujoco/mujoco_env.py
+++ b/examples/mujoco/mujoco_env.py
@@ -37,7 +37,7 @@ def make_mujoco_env(
:return: a tuple of (single env, training envs, test envs).
"""
- envs = MujocoEnvFactory(task, seed, seed + num_train_envs, obs_norm=obs_norm).create_envs(
+ envs = MujocoEnvFactory(task, obs_norm=obs_norm).create_envs(
num_train_envs,
num_test_envs,
)
@@ -73,28 +73,18 @@ class MujocoEnvFactory(EnvFactoryRegistered):
def __init__(
self,
task: str,
- train_seed: int,
- test_seed: int,
obs_norm: bool = True,
venv_type: VectorEnvType = VectorEnvType.SUBPROC_SHARED_MEM_AUTO,
) -> None:
super().__init__(
task=task,
- train_seed=train_seed,
- test_seed=test_seed,
venv_type=venv_type,
envpool_factory=EnvPoolFactory() if envpool_is_available else None,
)
self.obs_norm = obs_norm
- def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
- """Create vectorized environments.
-
- :param num_envs: the number of environments
- :param mode: the mode for which to create
- :return: the vectorized environments
- """
- env = super().create_venv(num_envs, mode)
+ def create_venv(self, num_envs: int, mode: EnvMode, seed: int | None = None) -> BaseVectorEnv:
+ env = super().create_venv(num_envs, mode, seed=seed)
# obs norm wrapper
if self.obs_norm:
env = VectorEnvNormObs(env, update_obs_rms=mode == EnvMode.TRAIN)
@@ -105,8 +95,9 @@ def create_envs(
num_training_envs: int,
num_test_envs: int,
create_watch_env: bool = False,
+ seed: int | None = None,
) -> ContinuousEnvironments:
- envs = super().create_envs(num_training_envs, num_test_envs, create_watch_env)
+ envs = super().create_envs(num_training_envs, num_test_envs, create_watch_env, seed=seed)
assert isinstance(envs, ContinuousEnvironments)
if self.obs_norm:
diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py
index 387f87c6e..a231e1b21 100644
--- a/examples/mujoco/mujoco_npg_hl.py
+++ b/examples/mujoco/mujoco_npg_hl.py
@@ -53,12 +53,7 @@ def main(
repeat_per_collect=repeat_per_collect,
)
- env_factory = MujocoEnvFactory(
- task,
- train_seed=sampling_config.train_seed,
- test_seed=sampling_config.test_seed,
- obs_norm=True,
- )
+ env_factory = MujocoEnvFactory(task, obs_norm=True)
experiment = (
NPGExperimentBuilder(env_factory, experiment_config, sampling_config)
diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py
index b10d4cf26..973a822a6 100644
--- a/examples/mujoco/mujoco_ppo_hl.py
+++ b/examples/mujoco/mujoco_ppo_hl.py
@@ -58,12 +58,7 @@ def main(
repeat_per_collect=repeat_per_collect,
)
- env_factory = MujocoEnvFactory(
- task,
- train_seed=sampling_config.train_seed,
- test_seed=sampling_config.test_seed,
- obs_norm=True,
- )
+ env_factory = MujocoEnvFactory(task, obs_norm=True)
experiment = (
PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py
index 333870809..47ccc9ae2 100644
--- a/examples/mujoco/mujoco_ppo_hl_multi.py
+++ b/examples/mujoco/mujoco_ppo_hl_multi.py
@@ -37,7 +37,7 @@
def main(
num_experiments: int = 5,
run_experiments_sequentially: bool = True,
- logger_type: str = "wandb",
+ logger_type: str = "tensorboard",
) -> RLiableExperimentResult:
""":param num_experiments: the number of experiments to run. The experiments differ exclusively in the seeds.
:param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel.
@@ -70,12 +70,7 @@ def main(
repeat_per_collect=1,
)
- env_factory = MujocoEnvFactory(
- task,
- train_seed=sampling_config.train_seed,
- test_seed=sampling_config.test_seed,
- obs_norm=True,
- )
+ env_factory = MujocoEnvFactory(task, obs_norm=True)
hidden_sizes = (64, 64)
diff --git a/examples/mujoco/mujoco_redq_hl.py b/examples/mujoco/mujoco_redq_hl.py
index c0c63279a..90f6ef318 100644
--- a/examples/mujoco/mujoco_redq_hl.py
+++ b/examples/mujoco/mujoco_redq_hl.py
@@ -58,12 +58,7 @@ def main(
start_timesteps_random=True,
)
- env_factory = MujocoEnvFactory(
- task,
- train_seed=sampling_config.train_seed,
- test_seed=sampling_config.test_seed,
- obs_norm=False,
- )
+ env_factory = MujocoEnvFactory(task, obs_norm=False)
experiment = (
REDQExperimentBuilder(env_factory, experiment_config, sampling_config)
diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py
index 59a600568..f3e8821ae 100644
--- a/examples/mujoco/mujoco_reinforce_hl.py
+++ b/examples/mujoco/mujoco_reinforce_hl.py
@@ -49,12 +49,7 @@ def main(
repeat_per_collect=repeat_per_collect,
)
- env_factory = MujocoEnvFactory(
- task,
- train_seed=sampling_config.train_seed,
- test_seed=sampling_config.test_seed,
- obs_norm=True,
- )
+ env_factory = MujocoEnvFactory(task, obs_norm=True)
experiment = (
PGExperimentBuilder(env_factory, experiment_config, sampling_config)
diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py
index a150f5571..5b2e1519b 100644
--- a/examples/mujoco/mujoco_sac_hl.py
+++ b/examples/mujoco/mujoco_sac_hl.py
@@ -53,12 +53,7 @@ def main(
start_timesteps_random=True,
)
- env_factory = MujocoEnvFactory(
- task,
- train_seed=sampling_config.train_seed,
- test_seed=sampling_config.test_seed,
- obs_norm=False,
- )
+ env_factory = MujocoEnvFactory(task, obs_norm=False)
experiment = (
SACExperimentBuilder(env_factory, experiment_config, sampling_config)
diff --git a/examples/mujoco/mujoco_td3_hl.py b/examples/mujoco/mujoco_td3_hl.py
index 5ec9cc17b..8ca54d591 100644
--- a/examples/mujoco/mujoco_td3_hl.py
+++ b/examples/mujoco/mujoco_td3_hl.py
@@ -58,12 +58,7 @@ def main(
start_timesteps_random=True,
)
- env_factory = MujocoEnvFactory(
- task,
- train_seed=sampling_config.train_seed,
- test_seed=sampling_config.test_seed,
- obs_norm=False,
- )
+ env_factory = MujocoEnvFactory(task, obs_norm=False)
experiment = (
TD3ExperimentBuilder(env_factory, experiment_config, sampling_config)
diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py
index 1ec26bad2..4dfc39185 100644
--- a/examples/mujoco/mujoco_trpo_hl.py
+++ b/examples/mujoco/mujoco_trpo_hl.py
@@ -55,12 +55,7 @@ def main(
repeat_per_collect=repeat_per_collect,
)
- env_factory = MujocoEnvFactory(
- task,
- train_seed=sampling_config.train_seed,
- test_seed=sampling_config.test_seed,
- obs_norm=True,
- )
+ env_factory = MujocoEnvFactory(task, obs_norm=True)
experiment = (
TRPOExperimentBuilder(env_factory, experiment_config, sampling_config)
diff --git a/poetry.lock b/poetry.lock
index 583b90756..2cd8e2ed7 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -5382,13 +5382,13 @@ win32 = ["pywin32"]
[[package]]
name = "sensai-utils"
-version = "1.2.1"
+version = "1.4.0"
description = "Utilities from sensAI, the Python library for sensible AI"
optional = false
python-versions = "*"
files = [
- {file = "sensai_utils-1.2.1-py3-none-any.whl", hash = "sha256:222e60d9f9d371c9d62ffcd1e6def1186f0d5243588b0b5af57e983beecc95bb"},
- {file = "sensai_utils-1.2.1.tar.gz", hash = "sha256:4d8ca94179931798cef5f920fb042cbf9e7d806c0026b02afb58d0f72211bf27"},
+ {file = "sensai_utils-1.4.0-py3-none-any.whl", hash = "sha256:ed6fc57552620e43b33cf364ea0bc0fd7df39391069dd7b621b113ef55547507"},
+ {file = "sensai_utils-1.4.0.tar.gz", hash = "sha256:2d32bdcc91fd1428c5cae0181e98623142d2d5f7e115e23d585a842dd9dc59ba"},
]
[package.dependencies]
@@ -6896,4 +6896,4 @@ vizdoom = ["vizdoom"]
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
-content-hash = "1ea1b72b90269fd86b81b1443785085618248ccf5b62506a166b879115749171"
+content-hash = "bff3f4f8cc0d8196ea162a799472c7179486109d30968aa7d1b96b40016a459f"
diff --git a/pyproject.toml b/pyproject.toml
index 5640f948b..d4535f096 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,7 @@ overrides = "^7.4.0"
packaging = "*"
pandas = ">=2.0.0"
pettingzoo = "^1.22"
-sensai-utils = "^1.2.1"
+sensai-utils = "^1.4.0"
tensorboard = "^2.5.0"
# Torch 2.0.1 causes problems, see https://github.com/pytorch/pytorch/issues/100974
torch = "^2.0.0, !=2.0.1, !=2.1.0"
@@ -181,7 +181,8 @@ ignore = [
"PLW2901", # overwrite vars in loop
"B027", # empty and non-abstract method in abstract class
"D404", # It's fine to start with "This" in docstrings
- "D407", "D408", "D409", # Ruff rules for underlines under 'Example:' and so clash with Sphinx
+ "D407", "D408", "D409", # Ruff rules for underlines under 'Example:' and so clash with Sphinx
+ "B023", # forbids function using loop variable without explicit binding
]
unfixable = [
"F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all
@@ -227,13 +228,12 @@ _poetry_sort = "poetry sort"
clean-nbs = "python docs/nbstripout.py"
format = ["_ruff_format", "_ruff_format_nb", "_black_format", "_poetry_install_sort_plugin", "_poetry_sort"]
_autogen_rst = "python docs/autogen_rst.py"
-_sphinx_build = "sphinx-build -W -b html docs docs/_build"
+_sphinx_build = "sphinx-build -b html docs docs/_build -W --keep-going"
_jb_generate_toc = "python docs/create_toc.py"
_jb_generate_config = "jupyter-book config sphinx docs/"
doc-clean = "rm -rf docs/_build"
doc-generate-files = ["_autogen_rst", "_jb_generate_toc", "_jb_generate_config"]
-doc-spellcheck = "sphinx-build -W -b spelling docs docs/_build"
-doc-build = ["doc-generate-files", "doc-spellcheck", "_sphinx_build"]
+doc-build = ["doc-generate-files", "_sphinx_build"]
_mypy = "mypy tianshou test examples"
_mypy_nb = "nbqa mypy docs"
type-check = ["_mypy", "_mypy_nb"]
diff --git a/test/base/test_batch.py b/test/base/test_batch.py
index 5fa40758f..f8af8c521 100644
--- a/test/base/test_batch.py
+++ b/test/base/test_batch.py
@@ -51,10 +51,6 @@ def test_batch() -> None:
Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
with pytest.raises(TypeError):
Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))])
- with pytest.raises(TypeError):
- Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
- with pytest.raises(TypeError):
- Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
assert torch.allclose(batch.a, torch.ones(2, 3))
batch.cat_(batch)
assert torch.allclose(batch.a, torch.ones(4, 3))
diff --git a/test/base/test_collector.py b/test/base/test_collector.py
index 6355d8bfc..1e4769bbf 100644
--- a/test/base/test_collector.py
+++ b/test/base/test_collector.py
@@ -72,7 +72,7 @@ def forward(
if self.dict_state:
if self.action_shape:
action_shape = self.action_shape
- elif isinstance(batch.obs, BatchProtocol):
+ elif isinstance(batch.obs, Batch):
action_shape = len(batch.obs["index"])
else:
action_shape = len(batch.obs)
diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py
index ce7998eff..1569c0df1 100644
--- a/test/continuous/test_ddpg.py
+++ b/test/continuous/test_ddpg.py
@@ -1,5 +1,6 @@
import argparse
import os
+from test.determinism_test import AlgorithmDeterminismTest
import gymnasium as gym
import numpy as np
@@ -49,7 +50,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_ddpg(args: argparse.Namespace = get_args()) -> None:
+def test_ddpg(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
env = gym.make(args.task)
space_info = SpaceInfo.from_env(env)
args.state_shape = space_info.observation_info.obs_shape
@@ -131,4 +132,11 @@ def stop_fn(mean_rewards: float) -> bool:
save_best_fn=save_best_fn,
logger=logger,
).run()
- assert stop_fn(result.best_reward)
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
+
+
+def test_ddpg_determinism() -> None:
+ main_fn = lambda args: test_ddpg(args, enable_assertions=False)
+ AlgorithmDeterminismTest("continuous_ddpg", main_fn, get_args()).run()
diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py
index d853e2186..3b413eec4 100644
--- a/test/continuous/test_npg.py
+++ b/test/continuous/test_npg.py
@@ -1,5 +1,6 @@
import argparse
import os
+from test.determinism_test import AlgorithmDeterminismTest
import gymnasium as gym
import numpy as np
@@ -52,7 +53,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_npg(args: argparse.Namespace = get_args()) -> None:
+def test_npg(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
env = gym.make(args.task)
space_info = SpaceInfo.from_env(env)
@@ -153,4 +154,11 @@ def stop_fn(mean_rewards: float) -> bool:
save_best_fn=save_best_fn,
logger=logger,
).run()
- assert stop_fn(result.best_reward)
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
+
+
+def test_npg_determinism() -> None:
+ main_fn = lambda args: test_npg(args, enable_assertions=False)
+ AlgorithmDeterminismTest("continuous_npg", main_fn, get_args()).run()
diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py
index 4b56bd630..cbb8544ab 100644
--- a/test/continuous/test_ppo.py
+++ b/test/continuous/test_ppo.py
@@ -1,5 +1,6 @@
import argparse
import os
+from test.determinism_test import AlgorithmDeterminismTest
import gymnasium as gym
import numpy as np
@@ -58,7 +59,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_ppo(args: argparse.Namespace = get_args()) -> None:
+def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
env = gym.make(args.task)
space_info = SpaceInfo.from_env(env)
@@ -166,7 +167,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
print("Fail to restore policy and optim.")
# trainer
- trainer = OnpolicyTrainer(
+ result = OnpolicyTrainer(
policy=policy,
train_collector=train_collector,
test_collector=test_collector,
@@ -181,16 +182,17 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
logger=logger,
resume_from_log=args.resume,
save_checkpoint_fn=save_checkpoint_fn,
- )
-
- for epoch_stat in trainer:
- print(f"Epoch: {epoch_stat.epoch}")
- print(epoch_stat)
- # print(info)
+ ).run()
- assert stop_fn(epoch_stat.info_stat.best_reward)
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
def test_ppo_resume(args: argparse.Namespace = get_args()) -> None:
args.resume = True
test_ppo(args)
+
+
+def test_ppo_determinism() -> None:
+ main_fn = lambda args: test_ppo(args, enable_assertions=False)
+ AlgorithmDeterminismTest("continuous_ppo", main_fn, get_args()).run()
diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py
index 82c8f0637..24a7f420c 100644
--- a/test/continuous/test_redq.py
+++ b/test/continuous/test_redq.py
@@ -1,5 +1,6 @@
import argparse
import os
+from test.determinism_test import AlgorithmDeterminismTest
import gymnasium as gym
import numpy as np
@@ -54,7 +55,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_redq(args: argparse.Namespace = get_args()) -> None:
+def test_redq(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
env = gym.make(args.task)
assert isinstance(env.action_space, gym.spaces.Box)
space_info = SpaceInfo.from_env(env)
@@ -162,4 +163,19 @@ def stop_fn(mean_rewards: float) -> bool:
save_best_fn=save_best_fn,
logger=logger,
).run()
- assert stop_fn(result.best_reward)
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
+
+
+def test_redq_determinism() -> None:
+ main_fn = lambda args: test_redq(args, enable_assertions=False)
+ ignored_messages = [
+ "Params[actor_old]",
+ ] # actor_old only present in v1 (due to flawed inheritance)
+ AlgorithmDeterminismTest(
+ "continuous_redq",
+ main_fn,
+ get_args(),
+ ignored_messages=ignored_messages,
+ ).run()
diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py
index 09fc3ca45..1d4fc06fe 100644
--- a/test/continuous/test_sac_with_il.py
+++ b/test/continuous/test_sac_with_il.py
@@ -1,5 +1,6 @@
import argparse
import os
+from test.determinism_test import AlgorithmDeterminismTest
import gymnasium as gym
import numpy as np
@@ -57,7 +58,11 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_sac_with_il(args: argparse.Namespace = get_args()) -> None:
+def test_sac_with_il(
+ args: argparse.Namespace = get_args(),
+ enable_assertions: bool = True,
+ skip_il: bool = False,
+) -> None:
# if you want to use python vector env, please refer to other test scripts
# train_envs = env = envpool.make_gymnasium(args.task, num_envs=args.training_num, seed=args.seed)
# test_envs = envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed)
@@ -158,7 +163,12 @@ def stop_fn(mean_rewards: float) -> bool:
save_best_fn=save_best_fn,
logger=logger,
).run()
- assert stop_fn(result.best_reward)
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
+
+ if skip_il:
+ return
# here we define an imitation collector with a trivial policy
if args.task.startswith("Pendulum"):
@@ -203,4 +213,11 @@ def stop_fn(mean_rewards: float) -> bool:
save_best_fn=save_best_fn,
logger=logger,
).run()
- assert stop_fn(result.best_reward)
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
+
+
+def test_sac_determinism() -> None:
+ main_fn = lambda args: test_sac_with_il(args, enable_assertions=False, skip_il=True)
+ AlgorithmDeterminismTest("continuous_sac", main_fn, get_args()).run()
diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py
index 6c59ea25a..82fcce0fb 100644
--- a/test/continuous/test_td3.py
+++ b/test/continuous/test_td3.py
@@ -1,6 +1,6 @@
import argparse
import os
-import pprint
+from test.determinism_test import AlgorithmDeterminismTest
import gymnasium as gym
import numpy as np
@@ -52,7 +52,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_td3(args: argparse.Namespace = get_args()) -> None:
+def test_td3(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
env = gym.make(args.task)
space_info = SpaceInfo.from_env(env)
args.state_shape = space_info.observation_info.obs_shape
@@ -135,7 +135,7 @@ def stop_fn(mean_rewards: float) -> bool:
return mean_rewards >= args.reward_threshold
# Iterator trainer
- trainer = OffpolicyTrainer(
+ result = OffpolicyTrainer(
policy=policy,
train_collector=train_collector,
test_collector=test_collector,
@@ -148,10 +148,12 @@ def stop_fn(mean_rewards: float) -> bool:
stop_fn=stop_fn,
save_best_fn=save_best_fn,
logger=logger,
- )
- for epoch_stat in trainer:
- print(f"Epoch: {epoch_stat.epoch}")
- pprint.pprint(epoch_stat)
- # print(info)
+ ).run()
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
+
- assert stop_fn(epoch_stat.info_stat.best_reward)
+def test_td3_determinism() -> None:
+ main_fn = lambda args: test_td3(args, enable_assertions=False)
+ AlgorithmDeterminismTest("continuous_td3", main_fn, get_args()).run()
diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py
index 91e215116..321e351f3 100644
--- a/test/continuous/test_trpo.py
+++ b/test/continuous/test_trpo.py
@@ -1,5 +1,6 @@
import argparse
import os
+from test.determinism_test import AlgorithmDeterminismTest
import gymnasium as gym
import numpy as np
@@ -54,7 +55,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_trpo(args: argparse.Namespace = get_args()) -> None:
+def test_trpo(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
env = gym.make(args.task)
space_info = SpaceInfo.from_env(env)
args.state_shape = space_info.observation_info.obs_shape
@@ -153,4 +154,11 @@ def stop_fn(mean_rewards: float) -> bool:
save_best_fn=save_best_fn,
logger=logger,
).run()
- assert stop_fn(result.best_reward)
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
+
+
+def test_trpo_determinism() -> None:
+ main_fn = lambda args: test_trpo(args, enable_assertions=False)
+ AlgorithmDeterminismTest("continuous_trpo", main_fn, get_args()).run()
diff --git a/test/determinism_test.py b/test/determinism_test.py
new file mode 100644
index 000000000..828825660
--- /dev/null
+++ b/test/determinism_test.py
@@ -0,0 +1,123 @@
+from argparse import Namespace
+from collections.abc import Callable, Sequence
+from pathlib import Path
+from typing import Any
+
+import pytest
+import torch
+
+from tianshou.utils.determinism import TraceDeterminismTest, TraceLoggerContext
+
+
+class TorchDeterministicModeContext:
+ def __init__(self, mode: str | int = "default") -> None:
+ self.new_mode = mode
+ self.original_mode: str | int | None = None
+
+ def __enter__(self) -> None:
+ self.original_mode = torch.get_deterministic_debug_mode()
+ torch.set_deterministic_debug_mode(self.new_mode)
+
+ def __exit__(self, exc_type, exc_value, traceback): # type: ignore
+ assert self.original_mode is not None
+ torch.set_deterministic_debug_mode(self.original_mode)
+
+
+class AlgorithmDeterminismTest:
+ """
+ Represents a determinism test for Tianshou's RL algorithms.
+
+ A test using this class should be added for every algorithm in Tianshou.
+ Then, when making changes to one or more algorithms (e.g. refactoring), run the respective tests
+ on the old branch (creating snapshots) and then on the new branch that contains the changes
+ (comparing with the snapshots).
+
+ Intended usage is therefore:
+
+ 1. On the old branch: Set ENABLED=True and FORCE_SNAPSHOT_UPDATE=True and run the tests.
+ 2. On the new branch: Set ENABLED=True and FORCE_SNAPSHOT_UPDATE=False and run the tests.
+ 3. Inspect determinism_tests.log
+ """
+
+ ENABLED = False
+ """
+ whether determinism tests are enabled.
+ """
+ FORCE_SNAPSHOT_UPDATE = False
+ """
+ whether to force the update/creation of snapshots for every test.
+ Enable this when running on the "old" branch and you want to prepare the snapshots
+ for a comparison with the "new" branch.
+ """
+ PASS_IF_CORE_MESSAGES_UNCHANGED = True
+ """
+ whether to pass the test if only the core messages are unchanged.
+ If this is False, then the full log is required to be equivalent, whereas if it is True,
+ only the core messages need to be equivalent.
+ The core messages test whether the algorithm produces the same network parameters.
+ """
+
+ def __init__(
+ self,
+ name: str,
+ main_fn: Callable[[Namespace], Any],
+ args: Namespace,
+ is_offline: bool = False,
+ ignored_messages: Sequence[str] = (),
+ ):
+ """
+ :param name: the (unique!) name of the test
+ :param main_fn: the function to be called for the test
+ :param args: the arguments to be passed to the main function (some of which are overridden
+ for the test)
+ :param is_offline: whether the algorithm being tested is an offline algorithm and therefore
+ does not configure the number of training environments (`training_num`)
+ :param ignored_messages: message fragments to ignore in the trace log (if any)
+ """
+ self.determinism_test = TraceDeterminismTest(
+ base_path=Path(__file__).parent / "resources" / "determinism",
+ log_filename="determinism_tests.log",
+ core_messages=["Params"],
+ ignored_messages=ignored_messages,
+ )
+ self.name = name
+
+ def set(attr: str, value: Any) -> None:
+ old_value = getattr(args, attr)
+ if old_value is None:
+ raise ValueError(f"Attribute '{attr}' is not defined for args: {args}")
+ setattr(args, attr, value)
+
+ set("epoch", 3)
+ set("step_per_epoch", 100)
+ set("device", "cpu")
+ if not is_offline:
+ set("training_num", 1)
+ set("test_num", 1)
+
+ self.args = args
+ self.main_fn = main_fn
+
+ def run(self, update_snapshot: bool = False) -> None:
+ """
+ :param update_snapshot: whether to update the snapshot (may be centrally overridden by
+ FORCE_SNAPSHOT_UPDATE)
+ """
+ if not self.ENABLED:
+ pytest.skip("Algorithm determinism tests are disabled.")
+
+ if self.FORCE_SNAPSHOT_UPDATE:
+ update_snapshot = True
+
+ # run the actual process
+ with TraceLoggerContext() as trace:
+ with TorchDeterministicModeContext():
+ self.main_fn(self.args)
+ log = trace.get_log()
+
+ self.determinism_test.check(
+ log,
+ self.name,
+ create_reference_result=update_snapshot,
+ pass_if_core_messages_unchanged=self.PASS_IF_CORE_MESSAGES_UNCHANGED,
+ )
diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py
index 192f24c24..cf0f4ef7c 100644
--- a/test/discrete/test_a2c_with_il.py
+++ b/test/discrete/test_a2c_with_il.py
@@ -1,5 +1,6 @@
import argparse
import os
+from test.determinism_test import AlgorithmDeterminismTest
import gymnasium as gym
import numpy as np
@@ -8,7 +9,7 @@
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
-from tianshou.env import DummyVectorEnv, SubprocVectorEnv
+from tianshou.env import DummyVectorEnv
from tianshou.policy import A2CPolicy, ImitationPolicy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OffpolicyTrainer, OnpolicyTrainer
@@ -59,7 +60,11 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
+def test_a2c_with_il(
+ args: argparse.Namespace = get_args(),
+ enable_assertions: bool = True,
+ skip_il: bool = False,
+) -> None:
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
@@ -144,7 +149,12 @@ def stop_fn(mean_rewards: float) -> bool:
save_best_fn=save_best_fn,
logger=logger,
).run()
- assert stop_fn(result.best_reward)
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
+
+ if skip_il:
+ return
# here we define an imitation collector with a trivial policy
# if args.task == 'CartPole-v1':
@@ -165,9 +175,8 @@ def stop_fn(mean_rewards: float) -> bool:
seed=args.seed,
)
else:
- il_env = SubprocVectorEnv(
+ il_env = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
- context="fork",
)
il_env.seed(args.seed)
@@ -189,4 +198,11 @@ def stop_fn(mean_rewards: float) -> bool:
save_best_fn=save_best_fn,
logger=logger,
).run()
- assert stop_fn(result.best_reward)
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
+
+
+def test_a2c_determinism() -> None:
+ main_fn = lambda args: test_a2c_with_il(args, enable_assertions=False, skip_il=True)
+ AlgorithmDeterminismTest("discrete_a2c", main_fn, get_args()).run()
diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py
index 16042f622..719e10d86 100644
--- a/test/discrete/test_bdq.py
+++ b/test/discrete/test_bdq.py
@@ -1,4 +1,5 @@
import argparse
+from test.determinism_test import AlgorithmDeterminismTest
import gymnasium as gym
import numpy as np
@@ -113,7 +114,9 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
exploration_noise=True,
)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=False)
- # policy.set_eps(1)
+
+ # initial data collection
+ policy.set_eps(args.eps_train)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
@@ -142,3 +145,8 @@ def stop_fn(mean_rewards: float) -> bool:
test_fn=test_fn,
stop_fn=stop_fn,
).run()
+
+
+def test_bdq_determinism() -> None:
+ main_fn = lambda args: test_bdq(args)
+ AlgorithmDeterminismTest("discrete_bdq", main_fn, get_args()).run()
diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py
index 41d6a0260..2876c4406 100644
--- a/test/discrete/test_c51.py
+++ b/test/discrete/test_c51.py
@@ -1,6 +1,7 @@
import argparse
import os
import pickle
+from test.determinism_test import AlgorithmDeterminismTest
import gymnasium as gym
import numpy as np
@@ -61,7 +62,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_c51(args: argparse.Namespace = get_args()) -> None:
+def test_c51(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
env = gym.make(args.task)
assert isinstance(env.action_space, gym.spaces.Discrete)
space_info = SpaceInfo.from_env(env)
@@ -104,6 +105,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
estimation_step=args.n_step,
target_update_freq=args.target_update_freq,
).to(args.device)
+
# buffer
buf: ReplayBuffer
if args.prioritized_replay:
@@ -115,12 +117,16 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
)
else:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
+
# collector
train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)
- # policy.set_eps(1)
+
+ # initial data collection
+ policy.set_eps(args.eps_train)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
+
# log
log_path = os.path.join(args.logdir, args.task, "c51")
writer = SummaryWriter(log_path)
@@ -200,7 +206,9 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
resume_from_log=args.resume,
save_checkpoint_fn=save_checkpoint_fn,
).run()
- assert stop_fn(result.best_reward)
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
def test_c51_resume(args: argparse.Namespace = get_args()) -> None:
@@ -213,3 +221,8 @@ def test_pc51(args: argparse.Namespace = get_args()) -> None:
args.gamma = 0.95
args.seed = 1
test_c51(args)
+
+
+def test_c51_determinism() -> None:
+ main_fn = lambda args: test_c51(args, enable_assertions=False)
+ AlgorithmDeterminismTest("discrete_c51", main_fn, get_args()).run()
diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py
index f82aca1f6..5c706e884 100644
--- a/test/discrete/test_dqn.py
+++ b/test/discrete/test_dqn.py
@@ -1,5 +1,6 @@
import argparse
import os
+from test.determinism_test import AlgorithmDeterminismTest
import gymnasium as gym
import numpy as np
@@ -55,7 +56,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_dqn(args: argparse.Namespace = get_args()) -> None:
+def test_dqn(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
env = gym.make(args.task)
assert isinstance(env.action_space, gym.spaces.Discrete)
space_info = SpaceInfo.from_env(env)
@@ -106,12 +107,16 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
)
else:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
+
# collector
train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)
- # policy.set_eps(1)
+
+ # initial data collection
+ policy.set_eps(args.eps_train)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
+
# log
log_path = os.path.join(args.logdir, args.task, "dqn")
writer = SummaryWriter(log_path)
@@ -153,7 +158,14 @@ def test_fn(epoch: int, env_step: int | None) -> None:
save_best_fn=save_best_fn,
logger=logger,
).run()
- assert stop_fn(result.best_reward)
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
+
+
+def test_dqn_determinism() -> None:
+ main_fn = lambda args: test_dqn(args, enable_assertions=False)
+ AlgorithmDeterminismTest("discrete_dqn", main_fn, get_args()).run()
def test_pdqn(args: argparse.Namespace = get_args()) -> None:
diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py
index 4cc0b6bd0..d3fcdbd89 100644
--- a/test/discrete/test_drqn.py
+++ b/test/discrete/test_drqn.py
@@ -1,5 +1,6 @@
import argparse
import os
+from test.determinism_test import AlgorithmDeterminismTest
import gymnasium as gym
import numpy as np
@@ -47,7 +48,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_drqn(args: argparse.Namespace = get_args()) -> None:
+def test_drqn(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
env = gym.make(args.task)
assert isinstance(env.action_space, gym.spaces.Discrete)
space_info = SpaceInfo.from_env(env)
@@ -82,6 +83,7 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
action_space=env.action_space,
target_update_freq=args.target_update_freq,
)
+
# collector
buffer = VectorReplayBuffer(
args.buffer_size,
@@ -90,11 +92,15 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
ignore_obs_next=True,
)
train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True)
+
# the stack_num is for RNN training: sample framestack obs
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)
- # policy.set_eps(1)
+
+ # initial data collection
+ policy.set_eps(args.eps_train)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
+
# log
log_path = os.path.join(args.logdir, args.task, "drqn")
writer = SummaryWriter(log_path)
@@ -129,4 +135,11 @@ def test_fn(epoch: int, env_step: int | None) -> None:
save_best_fn=save_best_fn,
logger=logger,
).run()
- assert stop_fn(result.best_reward)
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
+
+
+def test_drqn_determinism() -> None:
+ main_fn = lambda args: test_drqn(args, enable_assertions=False)
+ AlgorithmDeterminismTest("discrete_drqn", main_fn, get_args()).run()
diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py
index e0899d315..2e56b4e38 100644
--- a/test/discrete/test_fqf.py
+++ b/test/discrete/test_fqf.py
@@ -1,5 +1,6 @@
import argparse
import os
+from test.determinism_test import AlgorithmDeterminismTest
import gymnasium as gym
import numpy as np
@@ -60,7 +61,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_fqf(args: argparse.Namespace = get_args()) -> None:
+def test_fqf(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
env = gym.make(args.task)
space_info = SpaceInfo.from_env(env)
assert isinstance(env.action_space, gym.spaces.Discrete)
@@ -123,12 +124,16 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
)
else:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
+
# collector
train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)
- # policy.set_eps(1)
+
+ # initial data collection
+ policy.set_eps(args.eps_train)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
+
# log
log_path = os.path.join(args.logdir, args.task, "fqf")
writer = SummaryWriter(log_path)
@@ -170,10 +175,17 @@ def test_fn(epoch: int, env_step: int | None) -> None:
logger=logger,
update_per_step=args.update_per_step,
).run()
- assert stop_fn(result.best_reward)
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
def test_pfqf(args: argparse.Namespace = get_args()) -> None:
args.prioritized_replay = True
args.gamma = 0.95
test_fqf(args)
+
+
+def test_fqf_determinism() -> None:
+ main_fn = lambda args: test_fqf(args, enable_assertions=False)
+ AlgorithmDeterminismTest("discrete_fqf", main_fn, get_args()).run()
diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py
index 08f545b11..57bf28e73 100644
--- a/test/discrete/test_iqn.py
+++ b/test/discrete/test_iqn.py
@@ -1,5 +1,6 @@
import argparse
import os
+from test.determinism_test import AlgorithmDeterminismTest
import gymnasium as gym
import numpy as np
@@ -60,7 +61,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_iqn(args: argparse.Namespace = get_args()) -> None:
+def test_iqn(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
env = gym.make(args.task)
space_info = SpaceInfo.from_env(env)
assert isinstance(env.action_space, gym.spaces.Discrete)
@@ -108,6 +109,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
estimation_step=args.n_step,
target_update_freq=args.target_update_freq,
).to(args.device)
+
# buffer
buf: ReplayBuffer
if args.prioritized_replay:
@@ -119,12 +121,16 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
)
else:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
+
# collector
train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)
- # policy.set_eps(1)
+
+ # initial data collection
+ policy.set_eps(args.eps_train)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
+
# log
log_path = os.path.join(args.logdir, args.task, "iqn")
writer = SummaryWriter(log_path)
@@ -166,10 +172,17 @@ def test_fn(epoch: int, env_step: int | None) -> None:
logger=logger,
update_per_step=args.update_per_step,
).run()
- assert stop_fn(result.best_reward)
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
def test_piqn(args: argparse.Namespace = get_args()) -> None:
args.prioritized_replay = True
args.gamma = 0.95
test_iqn(args)
+
+
+def test_iqn_determinism() -> None:
+ main_fn = lambda args: test_iqn(args, enable_assertions=False)
+ AlgorithmDeterminismTest("discrete_iqn", main_fn, get_args()).run()
diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py
index 8a681583d..a4fb28300 100644
--- a/test/discrete/test_pg.py
+++ b/test/discrete/test_pg.py
@@ -1,5 +1,6 @@
import argparse
import os
+from test.determinism_test import AlgorithmDeterminismTest
import gymnasium as gym
import numpy as np
@@ -44,7 +45,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_pg(args: argparse.Namespace = get_args()) -> None:
+def test_pg(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
env = gym.make(args.task)
space_info = SpaceInfo.from_env(env)
args.state_shape = space_info.observation_info.obs_shape
@@ -122,4 +123,11 @@ def stop_fn(mean_rewards: float) -> bool:
save_best_fn=save_best_fn,
logger=logger,
).run()
- assert stop_fn(result.best_reward)
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
+
+
+def test_pg_determinism() -> None:
+ main_fn = lambda args: test_pg(args, enable_assertions=False)
+ AlgorithmDeterminismTest("discrete_pg", main_fn, get_args()).run()
diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py
index 7e541fffb..4226caf8f 100644
--- a/test/discrete/test_ppo.py
+++ b/test/discrete/test_ppo.py
@@ -1,5 +1,6 @@
import argparse
import os
+from test.determinism_test import AlgorithmDeterminismTest
import gymnasium as gym
import numpy as np
@@ -57,7 +58,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_ppo(args: argparse.Namespace = get_args()) -> None:
+def test_ppo(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
env = gym.make(args.task)
space_info = SpaceInfo.from_env(env)
args.state_shape = space_info.observation_info.obs_shape
@@ -149,4 +150,11 @@ def stop_fn(mean_rewards: float) -> bool:
save_best_fn=save_best_fn,
logger=logger,
).run()
- assert stop_fn(result.best_reward)
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
+
+
+def test_ppo_determinism() -> None:
+ main_fn = lambda args: test_ppo(args, enable_assertions=False)
+ AlgorithmDeterminismTest("discrete_ppo", main_fn, get_args()).run()
diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py
index 5aa543fb5..afa2592c4 100644
--- a/test/discrete/test_qrdqn.py
+++ b/test/discrete/test_qrdqn.py
@@ -1,5 +1,6 @@
import argparse
import os
+from test.determinism_test import AlgorithmDeterminismTest
import gymnasium as gym
import numpy as np
@@ -56,7 +57,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
+def test_qrdqn(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
env = gym.make(args.task)
assert isinstance(env.action_space, gym.spaces.Discrete)
@@ -112,12 +113,16 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
)
else:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
+
# collector
train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)
- # policy.set_eps(1)
+
+ # initial data collection
+ policy.set_eps(args.eps_train)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
+
# log
log_path = os.path.join(args.logdir, args.task, "qrdqn")
writer = SummaryWriter(log_path)
@@ -159,10 +164,17 @@ def test_fn(epoch: int, env_step: int | None) -> None:
logger=logger,
update_per_step=args.update_per_step,
).run()
- assert stop_fn(result.best_reward)
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
def test_pqrdqn(args: argparse.Namespace = get_args()) -> None:
args.prioritized_replay = True
args.gamma = 0.95
test_qrdqn(args)
+
+
+def test_qrdqn_determinism() -> None:
+ main_fn = lambda args: test_qrdqn(args, enable_assertions=False)
+ AlgorithmDeterminismTest("discrete_qrdqn", main_fn, get_args()).run()
diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py
index d7d4b15b1..92d10b06a 100644
--- a/test/discrete/test_rainbow.py
+++ b/test/discrete/test_rainbow.py
@@ -1,6 +1,7 @@
import argparse
import os
import pickle
+from test.determinism_test import AlgorithmDeterminismTest
import gymnasium as gym
import numpy as np
@@ -64,7 +65,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_rainbow(args: argparse.Namespace = get_args()) -> None:
+def test_rainbow(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
env = gym.make(args.task)
assert isinstance(env.action_space, gym.spaces.Discrete)
@@ -127,12 +128,16 @@ def noisy_linear(x: int, y: int) -> NoisyLinear:
)
else:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
+
# collector
train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)
- # policy.set_eps(1)
+
+ # initial data collection
+ policy.set_eps(args.eps_train)
train_collector.reset()
train_collector.collect(n_step=args.batch_size * args.training_num)
+
# log
log_path = os.path.join(args.logdir, args.task, "rainbow")
writer = SummaryWriter(log_path)
@@ -221,7 +226,9 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
resume_from_log=args.resume,
save_checkpoint_fn=save_checkpoint_fn,
).run()
- assert stop_fn(result.best_reward)
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
def test_rainbow_resume(args: argparse.Namespace = get_args()) -> None:
@@ -234,3 +241,8 @@ def test_prainbow(args: argparse.Namespace = get_args()) -> None:
args.gamma = 0.95
args.seed = 1
test_rainbow(args)
+
+
+def test_rainbow_determinism() -> None:
+ main_fn = lambda args: test_rainbow(args, enable_assertions=False)
+ AlgorithmDeterminismTest("discrete_rainbow", main_fn, get_args()).run()
diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py
index 3409dab0a..9e6a08dc9 100644
--- a/test/discrete/test_sac.py
+++ b/test/discrete/test_sac.py
@@ -1,5 +1,6 @@
import argparse
import os
+from test.determinism_test import AlgorithmDeterminismTest
import gymnasium as gym
import numpy as np
@@ -50,7 +51,10 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
+def test_discrete_sac(
+ args: argparse.Namespace = get_args(),
+ enable_assertions: bool = True,
+) -> None:
env = gym.make(args.task)
assert isinstance(env.action_space, gym.spaces.Discrete)
@@ -140,4 +144,11 @@ def stop_fn(mean_rewards: float) -> bool:
update_per_step=args.update_per_step,
test_in_train=False,
).run()
- assert stop_fn(result.best_reward)
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
+
+
+def test_discrete_sac_determinism() -> None:
+ main_fn = lambda args: test_discrete_sac(args, enable_assertions=False)
+ AlgorithmDeterminismTest("discrete_sac", main_fn, get_args()).run()
diff --git a/test/highlevel/env_factory.py b/test/highlevel/env_factory.py
index 4a131e5fd..1d649f15e 100644
--- a/test/highlevel/env_factory.py
+++ b/test/highlevel/env_factory.py
@@ -8,8 +8,6 @@ class DiscreteTestEnvFactory(EnvFactoryRegistered):
def __init__(self) -> None:
super().__init__(
task="CartPole-v1",
- train_seed=42,
- test_seed=1337,
venv_type=VectorEnvType.DUMMY,
)
@@ -18,7 +16,5 @@ class ContinuousTestEnvFactory(EnvFactoryRegistered):
def __init__(self) -> None:
super().__init__(
task="Pendulum-v1",
- train_seed=42,
- test_seed=1337,
venv_type=VectorEnvType.DUMMY,
)
diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py
index 2ed910902..c409fdb3f 100644
--- a/test/offline/test_bcq.py
+++ b/test/offline/test_bcq.py
@@ -2,6 +2,7 @@
import datetime
import os
import pickle
+from test.determinism_test import AlgorithmDeterminismTest
from test.offline.gather_pendulum_data import expert_file_name, gather_data
import gymnasium as gym
@@ -61,7 +62,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_bcq(args: argparse.Namespace = get_args()) -> None:
+def test_bcq(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
if args.load_buffer_name.endswith(".hdf5"):
buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
@@ -201,4 +202,11 @@ def watch() -> None:
logger=logger,
show_progress=args.show_progress,
).run()
- assert stop_fn(result.best_reward)
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
+
+
+def test_bcq_determinism() -> None:
+ main_fn = lambda args: test_bcq(args, enable_assertions=False)
+ AlgorithmDeterminismTest("offline_bcq", main_fn, get_args(), is_offline=True).run()
diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py
index bd84098ba..ea8b6ac11 100644
--- a/test/offline/test_cql.py
+++ b/test/offline/test_cql.py
@@ -2,7 +2,7 @@
import datetime
import os
import pickle
-import pprint
+from test.determinism_test import AlgorithmDeterminismTest
from test.offline.gather_pendulum_data import expert_file_name, gather_data
import gymnasium as gym
@@ -66,7 +66,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_cql(args: argparse.Namespace = get_args()) -> None:
+def test_cql(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
if args.load_buffer_name.endswith(".hdf5"):
buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
@@ -194,10 +194,12 @@ def stop_fn(mean_rewards: float) -> bool:
stop_fn=stop_fn,
logger=logger,
)
+ stats = trainer.run()
- for epoch_stat in trainer:
- print(f"Epoch: {epoch_stat.epoch}")
- pprint.pprint(epoch_stat)
- # print(info)
+ if enable_assertions:
+ assert stop_fn(stats.best_reward)
- assert stop_fn(epoch_stat.info_stat.best_reward)
+
+def test_cql_determinism() -> None:
+ main_fn = lambda args: test_cql(args, enable_assertions=False)
+ AlgorithmDeterminismTest("offline_cql", main_fn, get_args(), is_offline=True).run()
diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py
index e69e0a1fa..f7d34ba50 100644
--- a/test/offline/test_discrete_bcq.py
+++ b/test/offline/test_discrete_bcq.py
@@ -1,6 +1,7 @@
import argparse
import os
import pickle
+from test.determinism_test import AlgorithmDeterminismTest
from test.offline.gather_cartpole_data import expert_file_name, gather_data
import gymnasium as gym
@@ -36,7 +37,7 @@ def get_args() -> argparse.Namespace:
parser.add_argument("--unlikely-action-threshold", type=float, default=0.6)
parser.add_argument("--imitation-logits-penalty", type=float, default=0.01)
parser.add_argument("--epoch", type=int, default=5)
- parser.add_argument("--update-per-epoch", type=int, default=2000)
+ parser.add_argument("--step-per-epoch", type=int, default=2000)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])
parser.add_argument("--test-num", type=int, default=100)
@@ -53,7 +54,10 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
+def test_discrete_bcq(
+ args: argparse.Namespace = get_args(),
+ enable_assertions: bool = True,
+) -> None:
# envs
env = gym.make(args.task)
assert isinstance(env.action_space, gym.spaces.Discrete)
@@ -155,7 +159,7 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
buffer=buffer,
test_collector=test_collector,
max_epoch=args.epoch,
- step_per_epoch=args.update_per_epoch,
+ step_per_epoch=args.step_per_epoch,
episode_per_test=args.test_num,
batch_size=args.batch_size,
stop_fn=stop_fn,
@@ -164,10 +168,17 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
resume_from_log=args.resume,
save_checkpoint_fn=save_checkpoint_fn,
).run()
- assert stop_fn(result.best_reward)
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
def test_discrete_bcq_resume(args: argparse.Namespace = get_args()) -> None:
test_discrete_bcq()
args.resume = True
test_discrete_bcq(args)
+
+
+def test_discrete_bcq_determinism() -> None:
+ main_fn = lambda args: test_discrete_bcq(args, enable_assertions=False)
+ AlgorithmDeterminismTest("offline_discrete_bcq", main_fn, get_args(), is_offline=True).run()
diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py
index 97766d494..373b9a074 100644
--- a/test/offline/test_discrete_cql.py
+++ b/test/offline/test_discrete_cql.py
@@ -1,6 +1,7 @@
import argparse
import os
import pickle
+from test.determinism_test import AlgorithmDeterminismTest
from test.offline.gather_cartpole_data import expert_file_name, gather_data
import gymnasium as gym
@@ -35,7 +36,7 @@ def get_args() -> argparse.Namespace:
parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument("--min-q-weight", type=float, default=10.0)
parser.add_argument("--epoch", type=int, default=5)
- parser.add_argument("--update-per-epoch", type=int, default=1000)
+ parser.add_argument("--step-per-epoch", type=int, default=1000)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64])
parser.add_argument("--test-num", type=int, default=100)
@@ -50,7 +51,10 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
+def test_discrete_cql(
+ args: argparse.Namespace = get_args(),
+ enable_assertions: bool = True,
+) -> None:
# envs
env = gym.make(args.task)
assert isinstance(env.action_space, gym.spaces.Discrete)
@@ -118,7 +122,7 @@ def stop_fn(mean_rewards: float) -> bool:
buffer=buffer,
test_collector=test_collector,
max_epoch=args.epoch,
- step_per_epoch=args.update_per_epoch,
+ step_per_epoch=args.step_per_epoch,
episode_per_test=args.test_num,
batch_size=args.batch_size,
stop_fn=stop_fn,
@@ -126,4 +130,10 @@ def stop_fn(mean_rewards: float) -> bool:
logger=logger,
).run()
- assert stop_fn(result.best_reward)
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
+
+
+def test_discrete_cql_determinism() -> None:
+ main_fn = lambda args: test_discrete_cql(args, enable_assertions=False)
+ AlgorithmDeterminismTest("offline_discrete_cql", main_fn, get_args(), is_offline=True).run()
diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py
index bf9a833a9..3593cf206 100644
--- a/test/offline/test_discrete_crr.py
+++ b/test/offline/test_discrete_crr.py
@@ -1,6 +1,7 @@
import argparse
import os
import pickle
+from test.determinism_test import AlgorithmDeterminismTest
from test.offline.gather_cartpole_data import expert_file_name, gather_data
import gymnasium as gym
@@ -33,7 +34,7 @@ def get_args() -> argparse.Namespace:
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=320)
parser.add_argument("--epoch", type=int, default=5)
- parser.add_argument("--update-per-epoch", type=int, default=1000)
+ parser.add_argument("--step-per-epoch", type=int, default=1000)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])
parser.add_argument("--test-num", type=int, default=100)
@@ -48,7 +49,10 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
+def test_discrete_crr(
+ args: argparse.Namespace = get_args(),
+ enable_assertions: bool = True,
+) -> None:
# envs
env = gym.make(args.task)
assert isinstance(env.action_space, gym.spaces.Discrete)
@@ -122,7 +126,7 @@ def stop_fn(mean_rewards: float) -> bool:
buffer=buffer,
test_collector=test_collector,
max_epoch=args.epoch,
- step_per_epoch=args.update_per_epoch,
+ step_per_epoch=args.step_per_epoch,
episode_per_test=args.test_num,
batch_size=args.batch_size,
stop_fn=stop_fn,
@@ -130,4 +134,10 @@ def stop_fn(mean_rewards: float) -> bool:
logger=logger,
).run()
- assert stop_fn(result.best_reward)
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
+
+
+def test_discrete_crr_determinism() -> None:
+ main_fn = lambda args: test_discrete_crr(args, enable_assertions=False)
+ AlgorithmDeterminismTest("offline_discrete_crr", main_fn, get_args(), is_offline=True).run()
diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py
index c7f183587..98a6b6c48 100644
--- a/test/offline/test_gail.py
+++ b/test/offline/test_gail.py
@@ -1,6 +1,7 @@
import argparse
import os
import pickle
+from test.determinism_test import AlgorithmDeterminismTest
from test.offline.gather_pendulum_data import expert_file_name, gather_data
import gymnasium as gym
@@ -61,7 +62,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_gail(args: argparse.Namespace = get_args()) -> None:
+def test_gail(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
if args.load_buffer_name.endswith(".hdf5"):
buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
@@ -220,4 +221,11 @@ def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
resume_from_log=args.resume,
save_checkpoint_fn=save_checkpoint_fn,
).run()
- assert stop_fn(result.best_reward)
+
+ if enable_assertions:
+ assert stop_fn(result.best_reward)
+
+
+def test_gail_determinism() -> None:
+ main_fn = lambda args: test_gail(args, enable_assertions=False)
+ AlgorithmDeterminismTest("offline_gail", main_fn, get_args(), is_offline=True).run()
diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py
index 17c3afb06..40b529c68 100644
--- a/test/offline/test_td3_bc.py
+++ b/test/offline/test_td3_bc.py
@@ -2,6 +2,7 @@
import datetime
import os
import pickle
+from test.determinism_test import AlgorithmDeterminismTest
from test.offline.gather_pendulum_data import expert_file_name, gather_data
import gymnasium as gym
@@ -61,7 +62,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
-def test_td3_bc(args: argparse.Namespace = get_args()) -> None:
+def test_td3_bc(args: argparse.Namespace = get_args(), enable_assertions: bool = True) -> None:
if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
if args.load_buffer_name.endswith(".hdf5"):
buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
@@ -181,10 +182,12 @@ def stop_fn(mean_rewards: float) -> bool:
stop_fn=stop_fn,
logger=logger,
)
+ stats = trainer.run()
- for epoch_stat in trainer:
- print(f"Epoch: {epoch_stat.epoch}")
- print(epoch_stat)
- # print(info)
+ if enable_assertions:
+ assert stop_fn(stats.best_reward)
- assert stop_fn(epoch_stat.info_stat.best_reward)
+
+def test_td3_bc_determinism() -> None:
+ main_fn = lambda args: test_td3_bc(args, enable_assertions=False)
+ AlgorithmDeterminismTest("offline_td3_bc", main_fn, get_args(), is_offline=True).run()
diff --git a/tianshou/config.py b/tianshou/config.py
new file mode 100644
index 000000000..23cbb0cb2
--- /dev/null
+++ b/tianshou/config.py
@@ -0,0 +1,2 @@
+ENABLE_VALIDATION = False
+"""Validation can help catching bugs and issues but it slows down training and collection. Enable it only if needed."""
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index 70478a87d..c7ee7505b 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -983,7 +983,7 @@ def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None:
self.__dict__[key][sum_lens[i] : sum_lens[i + 1]] = value
def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None:
- if isinstance(batches, BatchProtocol | dict):
+ if isinstance(batches, Batch | dict):
batches = [batches]
# check input format
batch_list = []
@@ -1069,7 +1069,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
{
batch_key
for batch_key, obj in batch.items()
- if not (isinstance(obj, BatchProtocol) and len(obj.get_keys()) == 0)
+ if not (isinstance(obj, Batch) and len(obj.get_keys()) == 0)
}
for batch in batches
]
@@ -1080,7 +1080,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
if all(isinstance(element, torch.Tensor) for element in value):
self.__dict__[shared_key] = torch.stack(value, axis)
# third often
- elif all(isinstance(element, BatchProtocol | dict) for element in value):
+ elif all(isinstance(element, Batch | dict) for element in value):
self.__dict__[shared_key] = Batch.stack(value, axis)
else: # most often case is np.ndarray
try:
@@ -1114,7 +1114,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
value = batch.get(key)
# TODO: fix code/annotations s.t. the ignores can be removed
if (
- isinstance(value, BatchProtocol) # type: ignore
+ isinstance(value, Batch) # type: ignore
and len(value.get_keys()) == 0 # type: ignore
):
continue # type: ignore
@@ -1288,7 +1288,7 @@ def set_array_at_key(
) from exception
else:
existing_entry = self[key]
- if isinstance(existing_entry, BatchProtocol):
+ if isinstance(existing_entry, Batch):
raise ValueError(
f"Cannot set sequence at key {key} because it is a nested batch, "
f"can only set a subsequence of an array.",
@@ -1312,7 +1312,7 @@ def hasnull(self) -> bool:
def is_any_true(boolean_batch: BatchProtocol) -> bool:
for val in boolean_batch.values():
- if isinstance(val, BatchProtocol):
+ if isinstance(val, Batch):
if is_any_true(val):
return True
else:
@@ -1375,7 +1375,7 @@ def _apply_batch_values_func_recursively(
"""
result = batch if inplace else deepcopy(batch)
for key, val in batch.__dict__.items():
- if isinstance(val, BatchProtocol):
+ if isinstance(val, Batch):
result[key] = _apply_batch_values_func_recursively(val, values_transform, inplace=False)
else:
result[key] = values_transform(val)
diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py
index 5c4451d57..72c7af5bb 100644
--- a/tianshou/data/buffer/base.py
+++ b/tianshou/data/buffer/base.py
@@ -3,6 +3,7 @@
import h5py
import numpy as np
+from sensai.util.pickle import setstate
from tianshou.data import Batch
from tianshou.data.batch import (
@@ -77,6 +78,7 @@ def __init__(
ignore_obs_next: bool = False,
save_only_last_obs: bool = False,
sample_avail: bool = False,
+ random_seed: int = 42,
**kwargs: Any, # otherwise PrioritizedVectorReplayBuffer will cause TypeError
) -> None:
# TODO: why do we need this? Just for readout?
@@ -96,12 +98,21 @@ def __init__(
self._save_only_last_obs = save_only_last_obs
self._sample_avail = sample_avail
self._meta = cast(RolloutBatchProtocol, Batch())
+ self._random_state = np.random.RandomState(random_seed)
# Keep in sync with reset!
self.last_index = np.array([0])
self._insertion_idx = self._size = 0
self._ep_return, self._ep_len, self._ep_start_idx = 0.0, 0, 0
+ def __setstate__(self, state: dict[str, Any]) -> None:
+ setstate(
+ ReplayBuffer,
+ self,
+ state,
+ new_default_properties={"_random_state": np.random.RandomState(42)},
+ )
+
@property
def subbuffer_edges(self) -> np.ndarray:
"""Edges of contained buffers, mostly needed as part of the VectorReplayBuffer interface.
@@ -134,8 +145,7 @@ def _get_start_stop_tuples_for_edge_crossing_interval(
if stop >= start:
raise ValueError(
f"Expected stop < start, but got {start=}, {stop=}. "
- f"For stop larger than start this method should never be called, "
- f"and stop=start should never occur. This can occur either due to an implementation error, "
+ f"For stop greater than or equal to start, this method should never be called. This can occur either due to an implementation error, "
f"or due to a bad configuration of the buffer that resulted in a single episode being so long that "
f"it completely filled a subbuffer (of size len(buffer)/degree_of_vectorization). "
f"Consider either shortening the episode, increasing the size of the buffer, or decreasing the "
@@ -202,7 +212,7 @@ def get_buffer_indices(self, start: int, stop: int) -> np.ndarray:
f"Start and stop indices must be within the same subbuffer. "
f"Got {start=} in subbuffer edge {start_left_edge} and {stop=} in subbuffer edge {stop_left_edge}.",
)
- if stop > start:
+ if stop >= start:
return np.arange(start, stop, dtype=int)
else:
(start, upper_edge), (
@@ -230,9 +240,6 @@ def __getattr__(self, key: str) -> Any:
except KeyError as exception:
raise AttributeError from exception
- def __setstate__(self, state: dict[str, Any]) -> None:
- self.__dict__.update(state)
-
def __setattr__(self, key: str, value: Any) -> None:
assert key not in self._reserved_keys, f"key '{key}' is reserved and cannot be assigned"
super().__setattr__(key, value)
@@ -499,7 +506,7 @@ def sample_indices(self, batch_size: int | None) -> np.ndarray:
batch_size = len(self)
if self.stack_num == 1 or not self._sample_avail: # most often case
if batch_size > 0:
- return np.random.choice(self._size, batch_size)
+ return self._random_state.choice(self._size, batch_size)
# TODO: is this behavior really desired?
if batch_size == 0: # construct current available indices
return np.concatenate(
@@ -520,7 +527,7 @@ def sample_indices(self, batch_size: int | None) -> np.ndarray:
prev_indices = self.prev(prev_indices)
all_indices = all_indices[prev_indices != self.prev(prev_indices)]
if batch_size > 0:
- return np.random.choice(all_indices, batch_size)
+ return self._random_state.choice(all_indices, batch_size)
return all_indices
def sample(self, batch_size: int | None) -> tuple[RolloutBatchProtocol, np.ndarray]:
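
With the buffer now owning a seeded `RandomState`, two buffers created with the same `random_seed` and filled with the same transitions draw identical sample indices, which is what makes buffer sampling reproducible across runs. A minimal sketch (the fill step is elided; without data, sampling is not meaningful):

    import numpy as np
    from tianshou.data import ReplayBuffer

    buf_a = ReplayBuffer(size=100, random_seed=0)
    buf_b = ReplayBuffer(size=100, random_seed=0)
    # ... add the same transitions to both buffers via buf.add(...) ...
    # afterwards the draws match:
    # np.array_equal(buf_a.sample_indices(8), buf_b.sample_indices(8))  -> True
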
diff --git a/tianshou/data/buffer/her.py b/tianshou/data/buffer/her.py
index 087f8d0b0..eb03c1595 100644
--- a/tianshou/data/buffer/her.py
+++ b/tianshou/data/buffer/her.py
@@ -150,12 +150,12 @@ def rewrite_transitions(self, indices: np.ndarray) -> None:
ep_obs = self[unique_ep_indices].obs
# to satisfy mypy
# TODO: add protocol covering these batches
- assert isinstance(ep_obs, BatchProtocol)
+ assert isinstance(ep_obs, Batch)
ep_rew = self[unique_ep_indices].rew
if self._save_obs_next:
ep_obs_next = self[unique_ep_indices].obs_next
# to satisfy mypy
- assert isinstance(ep_obs_next, BatchProtocol)
+ assert isinstance(ep_obs_next, Batch)
future_obs = self[future_t[unique_ep_close_indices]].obs_next
else:
future_obs = self[self.next(future_t[unique_ep_close_indices])].obs
@@ -172,7 +172,7 @@ def rewrite_transitions(self, indices: np.ndarray) -> None:
ep_rew[:, her_ep_indices] = self._compute_reward(ep_obs_next)[:, her_ep_indices]
else:
tmp_ep_obs_next = self[self.next(unique_ep_indices)].obs
- assert isinstance(tmp_ep_obs_next, BatchProtocol)
+ assert isinstance(tmp_ep_obs_next, Batch)
ep_rew[:, her_ep_indices] = self._compute_reward(tmp_ep_obs_next)[:, her_ep_indices]
# Sanity check
@@ -181,7 +181,7 @@ def rewrite_transitions(self, indices: np.ndarray) -> None:
assert ep_rew.shape == unique_ep_indices.shape
# Re-write meta
- assert isinstance(self._meta.obs, BatchProtocol)
+ assert isinstance(self._meta.obs, Batch)
self._meta.obs[unique_ep_indices] = ep_obs
if self._save_obs_next:
self._meta.obs_next[unique_ep_indices] = ep_obs_next # type: ignore
diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py
index e8176aa8c..370c358be 100644
--- a/tianshou/data/buffer/manager.py
+++ b/tianshou/data/buffer/manager.py
@@ -208,11 +208,11 @@ def sample_indices(self, batch_size: int | None) -> np.ndarray:
return all_indices
if batch_size is None:
batch_size = len(all_indices)
- return np.random.choice(all_indices, batch_size)
+ return self._random_state.choice(all_indices, batch_size)
if batch_size == 0 or batch_size is None: # get all available indices
sample_num = np.zeros(self.buffer_num, int)
else:
- buffer_idx = np.random.choice(
+ buffer_idx = self._random_state.choice(
self.buffer_num,
batch_size,
p=self._lengths / self._lengths.sum(),
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 514db4cc0..3c6d75d4d 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -12,6 +12,7 @@
from overrides import override
from torch.distributions import Categorical, Distribution
+from tianshou.config import ENABLE_VALIDATION
from tianshou.data import (
Batch,
CachedReplayBuffer,
@@ -32,6 +33,7 @@
from tianshou.env import BaseVectorEnv, DummyVectorEnv
from tianshou.policy import BasePolicy
from tianshou.policy.base import episode_mc_return_to_go
+from tianshou.utils.determinism import TraceLogger
from tianshou.utils.print import DataclassPPrintMixin
from tianshou.utils.torch_utils import torch_train_mode
@@ -41,6 +43,8 @@
_TArrLike = TypeVar("_TArrLike", bound="np.ndarray | torch.Tensor | Batch | None")
+TScalarArrayShape = TypeVar("TScalarArrayShape")
+
class CollectActionBatchProtocol(Protocol):
"""A protocol for results of computing actions from a batch of observations within a single collect step.
@@ -315,8 +319,32 @@ def __init__(
exploration_noise: bool = False,
# The typing is correct, there's a bug in mypy, see https://github.com/python/mypy/issues/3737
collect_stats_class: type[TCollectStats] = CollectStats, # type: ignore[assignment]
- raise_on_nan_in_buffer: bool = True,
+ raise_on_nan_in_buffer: bool = ENABLE_VALIDATION,
) -> None:
+ """
+ :param policy: a tianshou policy, each :class:`BasePolicy` is capable of computing a batch
+ of actions from a batch of observations.
+ :param env: a ``gymnasium.Env`` environment or a vectorized instance of the
+ :class:`~tianshou.env.BaseVectorEnv` class. The latter is strongly recommended, as with
+ a gymnasium env the collection will not happen in parallel (a `DummyVectorEnv`
+ will be constructed internally from the passed env)
+ :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
+ If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer`
+ of size :data:`DEFAULT_BUFFER_MAXSIZE` * (number of envs)
+ as the default buffer.
+ :param exploration_noise: determine whether the action needs to be modified
+ with the corresponding policy's exploration noise. If so, "policy.
+ exploration_noise(act, batch)" will be called automatically to add the
+ exploration noise into action.
+ :param raise_on_nan_in_buffer: whether to raise a `RuntimeError` if NaNs are found in the buffer after
+ a collection step. Especially useful when episode-level hooks are passed for making
+ sure that nothing is broken during the collection. Consider setting to False if
+ the NaN-check becomes a bottleneck.
+ :param collect_stats_class: the class to use for collecting statistics. Allows customizing
+ the stats collection logic by passing a subclass of :class:`CollectStats`. Changing
+ this is rarely necessary and is mainly done by "power users".
+ """
if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
# Unfortunately, mypy seems to ignore the isinstance in lambda, maybe a bug in mypy
@@ -554,7 +582,7 @@ def __init__(
exploration_noise: bool = False,
on_episode_done_hook: Optional["EpisodeRolloutHookProtocol"] = None,
on_step_hook: Optional["StepHookProtocol"] = None,
- raise_on_nan_in_buffer: bool = True,
+ raise_on_nan_in_buffer: bool = ENABLE_VALIDATION,
collect_stats_class: type[TCollectStats] = CollectStats, # type: ignore[assignment]
) -> None:
"""
@@ -571,7 +599,7 @@ def __init__(
:param exploration_noise: determine whether the action needs to be modified
with the corresponding policy's exploration noise. If so, "policy.
exploration_noise(act, batch)" will be called automatically to add the
- exploration noise into action..
+ exploration noise into action.
:param on_episode_done_hook: if passed will be executed when an episode is done.
The input to the hook will be a `RolloutBatch` that contains the entire episode (and nothing else).
If a dict is returned by the hook it will be used to add new entries to the buffer
@@ -777,10 +805,13 @@ def _collect( # noqa: C901
# TODO: can't do it init since AsyncCollector is currently a subclass of Collector
if self.env.is_async:
raise ValueError(
- f"Please use {AsyncCollector.__name__} for asynchronous environments. "
+ "Please use AsyncCollector for asynchronous environments. "
f"Env class: {self.env.__class__.__name__}.",
)
+ ready_env_ids_R: np.ndarray[Any, np.dtype[np.signedinteger]]
+ """provides a mapping from local indices (indexing within `1, ..., R` where `R` is the number of ready envs)
+ to global ones (indexing within `1, ..., num_envs`). So the entry i in this array is the global index of the i-th ready env."""
if n_step is not None:
ready_env_ids_R = np.arange(self.env_num)
elif n_episode is not None:
@@ -839,6 +870,7 @@ def _collect( # noqa: C901
last_info_R=last_info_R,
last_hidden_state_RH=last_hidden_state_RH,
)
+ TraceLogger.log(log, lambda: f"Action: {collect_action_computation_batch_R.act}")
# Step 3
obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step(
@@ -913,6 +945,8 @@ def _collect( # noqa: C901
# local_idx - see block comment on class level
# Step 7
env_done_local_idx_D = np.where(done_R)[0]
+ """Indexes which episodes are done within the ready envs, so it can be used for selecting from `..._R` arrays.
+ Stands in contrast to the "global" index, which counts within all envs and is unsuitable for selecting from `..._R` arrays."""
episode_lens_D = ep_len_R[env_done_local_idx_D]
episode_returns_D = ep_return_R[env_done_local_idx_D]
episode_start_indices_D = ep_start_idx_R[env_done_local_idx_D]
@@ -931,6 +965,10 @@ def _collect( # noqa: C901
# 0,...,R and this global index is maintained by the ready_env_ids_R array.
# See the class block comment for more details
env_done_global_idx_D = ready_env_ids_R[env_done_local_idx_D]
+ """Indexes which episodes are done within all envs, i.e., within the index `1, ..., num_envs`. It can be
+ used to communicate with the vector env, where env ids are selected from this "global" index.
+ Is not suited for selecting from the ready envs (`..._R` arrays), use the local counterpart instead.
+ """
obs_reset_DO, info_reset_D = self.env.reset(
env_id=env_done_global_idx_D,
**gym_reset_kwargs,
@@ -1032,7 +1070,7 @@ def _collect( # noqa: C901
break
# Check if we screwed up somewhere
- if self.buffer.hasnull():
+ if self.raise_on_nan_in_buffer and self.buffer.hasnull():
nan_batch = self.buffer.isnull().apply_values_transform(np.sum)
raise MalformedBufferError(
diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py
index ac35ccf3f..ca31b9ac3 100644
--- a/tianshou/env/worker/base.py
+++ b/tianshou/env/worker/base.py
@@ -69,7 +69,14 @@ def wait(
raise NotImplementedError
def seed(self, seed: int | None = None) -> list[int] | None:
- return self.action_space.seed(seed) # issue 299
+ """
+ Seeds the environment's action space sampler.
+ NOTE: This does *not* seed the environment itself.
+
+ :param seed: the random seed
+ :return: a list containing the resulting seed used
+ """
+ return self.action_space.seed(seed)
@abstractmethod
def render(self, **kwargs: Any) -> Any:
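
The docstring makes explicit that worker-level seeding only affects action-space sampling; reproducible episodes additionally require seeding the environment itself, which the factory changes below now take care of when creating envs. The distinction in plain gymnasium terms:

    import gymnasium as gym

    env = gym.make("CartPole-v1")
    env.action_space.seed(7)       # what EnvWorker.seed() effectively does
    obs, info = env.reset(seed=7)  # what is needed for a reproducible episode
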
diff --git a/tianshou/evaluation/rliable_evaluation_hl.py b/tianshou/evaluation/rliable_evaluation_hl.py
index 2b8ff5131..dc6205cc6 100644
--- a/tianshou/evaluation/rliable_evaluation_hl.py
+++ b/tianshou/evaluation/rliable_evaluation_hl.py
@@ -44,13 +44,18 @@ def from_data_dict(cls, data: dict) -> "LoggedCollectStats":
Converts SequenceSummaryStats from dict format to dataclass format and ignores fields that are not present.
"""
+ dataclass_data = {}
field_names = [f.name for f in fields(cls)]
for k, v in data.items():
if k not in field_names:
- data.pop(k)
+ log.info(
+ f"Key {k} in data dict is not a valid field of LoggedCollectStats, ignoring it.",
+ )
+ continue
if isinstance(v, dict):
- data[k] = LoggedSummaryData(**v)
- return cls(**data)
+ v = LoggedSummaryData(**v)
+ dataclass_data[k] = v
+ return cls(**dataclass_data)
@dataclass
@@ -114,14 +119,23 @@ def load_from_disk(
data = logger_cls.restore_logged_data(entry.path)
# TODO: align low-level and high-level dir structure. This is a hack!
if not data:
+ log.info(
+ f"Could not find data in {entry.path}, trying to restore from subdirectory.",
+ )
dirs = [
d for d in os.listdir(entry.path) if os.path.isdir(os.path.join(entry.path, d))
]
if len(dirs) != 1:
- raise ValueError(
- f"Could not restore data from {entry.path}, "
- f"expected either events or exactly one subdirectory, ",
+ _error_message = (
+ f"Could not restore experiment data from {entry.path}, "
+ f"expected either events or exactly one subdirectory, but got {dirs=}. "
)
+ if not dirs:
+ _error_message += (
+ "The absence of events/subdirectory may be due to an error causing the training to stop or due to"
+ " too few environment steps, leading to no data being logged."
+ )
+ raise ValueError(_error_message)
data = logger_cls.restore_logged_data(os.path.join(entry.path, dirs[0]))
if not data:
raise ValueError(f"Could not restore data from {entry.path}.")
diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index 2ad533983..a023f0190 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -320,6 +320,10 @@ def __init__(
def _get_policy_class(self) -> type[TPolicy]:
pass
+ @abstractmethod
+ def _include_actor_in_optim(self) -> bool:
+ pass
+
def create_actor_critic_module_opt(
self,
envs: Environments,
@@ -329,7 +333,10 @@ def create_actor_critic_module_opt(
actor = self.actor_factory.create_module(envs, device)
critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action)
actor_critic = ActorCritic(actor, critic)
- optim = self.optim_factory.create_optimizer(actor_critic, lr)
+ if self._include_actor_in_optim():
+ optim = self.optim_factory.create_optimizer(actor_critic, lr)
+ else:
+ optim = self.optim_factory.create_optimizer(critic, lr)
return ActorCriticOpt(actor_critic, optim)
@typing.no_type_check
@@ -356,21 +363,33 @@ def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
+ def _include_actor_in_optim(self) -> bool:
+ return True
+
def _get_policy_class(self) -> type[A2CPolicy]:
return A2CPolicy
class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
+ def _include_actor_in_optim(self) -> bool:
+ return True
+
def _get_policy_class(self) -> type[PPOPolicy]:
return PPOPolicy
class NPGAgentFactory(ActorCriticAgentFactory[NPGParams, NPGPolicy]):
+ def _include_actor_in_optim(self) -> bool:
+ return False
+
def _get_policy_class(self) -> type[NPGPolicy]:
return NPGPolicy
class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
+ def _include_actor_in_optim(self) -> bool:
+ return False
+
def _get_policy_class(self) -> type[TRPOPolicy]:
return TRPOPolicy
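
`_include_actor_in_optim` makes the intent explicit: A2C/PPO optimize actor and critic jointly through the torch optimizer, whereas NPG/TRPO hand only the critic's parameters to it, since their actor is updated via natural-gradient steps inside the policy. A self-contained sketch of the difference using stand-in networks (the `nn.Linear` modules are placeholders, not the real factories' output):

    import torch
    from torch import nn
    from tianshou.utils.net.common import ActorCritic

    actor, critic = nn.Linear(4, 2), nn.Linear(4, 1)  # stand-ins for the real networks
    actor_critic = ActorCritic(actor, critic)
    a2c_ppo_optim = torch.optim.Adam(actor_critic.parameters(), lr=3e-4)  # actor + critic
    npg_trpo_optim = torch.optim.Adam(critic.parameters(), lr=3e-4)       # critic only
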
diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py
index ac27cba1a..fb58c8a58 100644
--- a/tianshou/highlevel/config.py
+++ b/tianshou/highlevel/config.py
@@ -56,9 +56,6 @@ class SamplingConfig(ToStringMixin):
num_train_envs: int = -1
"""the number of training environments to use. If set to -1, use number of CPUs/threads."""
- train_seed: int = 42
- """the seed to use for the training environments."""
-
num_test_envs: int = 1
"""the number of test environments to use"""
@@ -165,10 +162,6 @@ class SamplingConfig(ToStringMixin):
Currently only used in Atari examples and may be removed in the future!
"""
- @property
- def test_seed(self) -> int:
- return self.train_seed + self.num_train_envs
-
def __post_init__(self) -> None:
if self.num_train_envs == -1:
self.num_train_envs = multiprocessing.cpu_count()
diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py
index e61e0ed36..45d85d9d5 100644
--- a/tianshou/highlevel/env.py
+++ b/tianshou/highlevel/env.py
@@ -7,6 +7,7 @@
import gymnasium as gym
import gymnasium.spaces
+import numpy as np
from gymnasium import Env
from sensai.util.pickle import setstate
from sensai.util.string import ToStringMixin
@@ -370,11 +371,58 @@ def __init__(self, venv_type: VectorEnvType):
"""
self.venv_type = venv_type
+ @staticmethod
+ def _create_rng(seed: int | None) -> np.random.Generator:
+ """
+ Creates a random number generator with the given seed.
+
+ :param seed: the seed to use; if None, a random seed will be used
+ :return: the random number generator
+ """
+ return np.random.default_rng(seed=seed)
+
+ @staticmethod
+ def _next_seed(rng: np.random.Generator) -> int:
+ """
+ Samples a random seed from the given random number generator.
+
+ :param rng: the random number generator
+ :return: the sampled random seed
+ """
+ # int32 is needed for envpool compatibility
+ return int(rng.integers(0, 2**31, dtype=np.int32))
+
@abstractmethod
- def create_env(self, mode: EnvMode) -> Env:
- pass
+ def _create_env(self, mode: EnvMode) -> Env:
+ """Creates a single environment for the given mode.
+
+ :param mode: the mode
+ :return: an environment
+ """
+
+ def create_env(self, mode: EnvMode, seed: int | None = None) -> Env:
+ """
+ Creates a single environment for the given mode.
+
+ :param mode: the mode
+ :param seed: the random seed to use for the environment; if None, the seed will not be specified,
+ and gymnasium will use a random seed.
+ :return: the environment
+ """
+ env = self._create_env(mode)
- def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
+ # initialize the environment with the given seed (if any)
+ if seed is not None:
+ rng = self._create_rng(seed)
+ env.np_random = rng
+ # also set the seed member within the environment such that it can be retrieved
+ # (gymnasium's random seed handling is, unfortunately, broken)
+ if hasattr(env, "_np_random_seed"):
+ env._np_random_seed = seed
+
+ return env
+
+ def create_venv(self, num_envs: int, mode: EnvMode, seed: int | None = None) -> BaseVectorEnv:
"""Create vectorized environments.
:param num_envs: the number of environments
@@ -383,28 +431,47 @@ def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
:return: the vectorized environments
"""
+ rng = self._create_rng(seed)
+
+ def create_factory_fn() -> Callable[[], Env]:
+ # create a factory function that uses a sampled random seed
+ return lambda random_seed=self._next_seed(rng): self.create_env(mode, seed=random_seed) # type: ignore
+
+ # create the vectorized environment, seeded appropriately
if mode == EnvMode.WATCH:
- return VectorEnvType.DUMMY.create_venv([lambda: self.create_env(mode)])
+ venv = VectorEnvType.DUMMY.create_venv([create_factory_fn()])
else:
- return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs)
+ venv = self.venv_type.create_venv([create_factory_fn() for _ in range(num_envs)])
+
+ # seed the action samplers
+ venv.seed([self._next_seed(rng) for _ in range(num_envs)])
+
+ return venv
def create_envs(
self,
num_training_envs: int,
num_test_envs: int,
create_watch_env: bool = False,
+ seed: int | None = None,
) -> Environments:
"""Create environments for learning.
:param num_training_envs: the number of training environments
:param num_test_envs: the number of test environments
:param create_watch_env: whether to create an environment for watching the agent
+ :param seed: the random seed to use for environment creation
:return: the environments
"""
+ rng = self._create_rng(seed)
env = self.create_env(EnvMode.TRAIN)
- train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN)
- test_envs = self.create_venv(num_test_envs, EnvMode.TEST)
- watch_env = self.create_venv(1, EnvMode.WATCH) if create_watch_env else None
+ train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN, seed=self._next_seed(rng))
+ test_envs = self.create_venv(num_test_envs, EnvMode.TEST, seed=self._next_seed(rng))
+ watch_env = (
+ self.create_venv(1, EnvMode.WATCH, seed=self._next_seed(rng))
+ if create_watch_env
+ else None
+ )
match EnvType.from_env(env):
case EnvType.DISCRETE:
return DiscreteEnvironments(env, train_envs, test_envs, watch_env)
@@ -423,8 +490,6 @@ def __init__(
self,
*,
task: str,
- train_seed: int,
- test_seed: int,
venv_type: VectorEnvType,
envpool_factory: EnvPoolFactory | None = None,
render_mode_train: str | None = None,
@@ -444,8 +509,6 @@ def __init__(
super().__init__(venv_type)
self.task = task
self.envpool_factory = envpool_factory
- self.train_seed = train_seed
- self.test_seed = test_seed
self.render_modes = {
EnvMode.TRAIN: render_mode_train,
EnvMode.TEST: render_mode_test,
@@ -476,7 +539,7 @@ def _create_kwargs(self, mode: EnvMode) -> dict:
kwargs["render_mode"] = self.render_modes.get(mode)
return kwargs
- def create_env(self, mode: EnvMode) -> Env:
+ def _create_env(self, mode: EnvMode) -> Env:
"""Creates a single environment for the given mode.
:param mode: the mode
@@ -485,17 +548,15 @@ def create_env(self, mode: EnvMode) -> Env:
kwargs = self._create_kwargs(mode)
return gymnasium.make(self.task, **kwargs)
- def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
- seed = self.train_seed if mode == EnvMode.TRAIN else self.test_seed
+ def create_venv(self, num_envs: int, mode: EnvMode, seed: int | None = None) -> BaseVectorEnv:
if self.envpool_factory is not None:
+ rng = self._create_rng(seed)
return self.envpool_factory.create_venv(
self.task,
num_envs,
mode,
- seed,
+ self._next_seed(rng),
self._create_kwargs(mode),
)
else:
- venv = super().create_venv(num_envs, mode)
- venv.seed(seed)
- return venv
+ return super().create_venv(num_envs, mode, seed=seed)
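
With the per-mode seeds removed, all environment and action-space seeds are derived from the single seed handed to `create_envs`/`create_venv` via an `np.random.default_rng` instance. A minimal usage sketch with the registered factory (mirroring the simplified test factories above):

    from tianshou.highlevel.env import EnvFactoryRegistered, VectorEnvType

    factory = EnvFactoryRegistered(task="CartPole-v1", venv_type=VectorEnvType.DUMMY)
    envs = factory.create_envs(num_training_envs=4, num_test_envs=2, seed=42)
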
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index c0be23dca..4df648fa9 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -118,7 +118,7 @@
@dataclass
-class ExperimentConfig(ToStringMixin, DataclassPPrintMixin):
+class ExperimentConfig:
"""Generic config for setting up the experiment, not RL or training specific."""
seed: int = 42
@@ -219,13 +219,7 @@ def get_seeding_info_as_str(self) -> str:
This can be useful for creating unique experiment names based on seeds, e.g.
A typical example is to do `experiment.name = f"{experiment.name}_{experiment.get_seeding_info_as_str()}"`.
"""
- return "_".join(
- [
- f"exp_seed={self.config.seed}",
- f"train_seed={self.sampling_config.train_seed}",
- f"test_seed={self.sampling_config.test_seed}",
- ],
- )
+ return f"exp_seed={self.config.seed}"
def _set_seed(self) -> None:
seed = self.config.seed
@@ -298,6 +292,7 @@ def create_experiment_world(
self.sampling_config.num_train_envs,
self.sampling_config.num_test_envs,
create_watch_env=self.config.watch,
+ seed=self.config.seed,
)
log.info(f"Created {envs}")
@@ -672,13 +667,10 @@ def build_seeded_collection(self, num_experiments: int) -> ExperimentCollection:
Each experiment in the collection will have a unique name created from the original experiment name
and the seeds used.
"""
- num_train_envs = self.sampling_config.num_train_envs
-
seeded_experiments = []
for i in range(num_experiments):
builder = self.copy()
builder.experiment_config.seed += i
- builder.sampling_config.train_seed += i * num_train_envs
experiment = builder.build()
experiment.name += f"_{experiment.get_seeding_info_as_str()}"
seeded_experiments.append(experiment)
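
Because only the experiment seed varies now, each experiment in a seeded collection differs solely in `experiment_config.seed`, and the name suffix produced by `get_seeding_info_as_str` reduces to that seed. A short sketch, assuming an already configured `builder` with the default seed of 42:

    collection = builder.build_seeded_collection(num_experiments=3)
    # resulting experiment names end in "_exp_seed=42", "_exp_seed=43", "_exp_seed=44"
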
diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py
index 4a1fe5c2e..ceb1262f7 100644
--- a/tianshou/highlevel/module/actor.py
+++ b/tianshou/highlevel/module/actor.py
@@ -140,7 +140,8 @@ def _create_factory(self, envs: Environments) -> ActorFactory:
raise ValueError(self.continuous_actor_type)
elif env_type == EnvType.DISCRETE:
factory = ActorFactoryDiscreteNet(
- self.DEFAULT_HIDDEN_SIZES,
+ self.hidden_sizes,
+ activation=self.hidden_activation,
softmax_output=self.discrete_softmax,
)
else:
diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py
index 2b662eb44..1c5d60438 100644
--- a/tianshou/highlevel/params/alpha.py
+++ b/tianshou/highlevel/params/alpha.py
@@ -21,8 +21,18 @@ def create_auto_alpha(
class AutoAlphaFactoryDefault(AutoAlphaFactory):
- def __init__(self, lr: float = 3e-4):
+ def __init__(self, lr: float = 3e-4, target_entropy_coefficient: float = -1.0):
+ """
+ :param lr: the learning rate for the optimizer of the alpha parameter
+ :param target_entropy_coefficient: the coefficient with which to multiply the target entropy;
+ The base value being scaled is `dim(A)` for continuous action spaces and `log(|A|)` for discrete action spaces,
+ i.e. with the default coefficient -1, we obtain `-dim(A)` and `-log(|A|)` for continuous and discrete action
+ spaces respectively, which gives a reasonable trade-off between exploration and exploitation.
+ For decidedly stochastic exploration, you can use a positive value closer to 1 (e.g. 0.98);
+ 1.0 would give full entropy exploration.
+ """
self.lr = lr
+ self.target_entropy_coefficient = target_entropy_coefficient
def create_auto_alpha(
self,
@@ -30,7 +40,11 @@ def create_auto_alpha(
optim_factory: OptimizerFactory,
device: TDevice,
) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
- target_entropy = float(-np.prod(envs.get_action_shape()))
+ action_dim = np.prod(envs.get_action_shape())
+ if envs.get_type().is_continuous():
+ target_entropy = self.target_entropy_coefficient * float(action_dim)
+ else:
+ target_entropy = self.target_entropy_coefficient * np.log(action_dim)
log_alpha = torch.zeros(1, requires_grad=True, device=device)
alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], self.lr)
return target_entropy, log_alpha, alpha_optim
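
With the default coefficient of -1 this reproduces the usual heuristics: `-dim(A)` for continuous spaces and `-log(|A|)` for discrete ones. A small numerical check of the formula as implemented above (the environment names are only examples):

    import numpy as np

    coefficient = -1.0
    continuous_target = coefficient * 1.0      # e.g. Pendulum-v1: dim(A) = 1  -> -1.0
    discrete_target = coefficient * np.log(2)  # e.g. CartPole-v1: |A| = 2     -> about -0.693
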
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 066a23a3b..ced0043d1 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -12,6 +12,7 @@
from numba import njit
from numpy.typing import ArrayLike
from overrides import override
+from sensai.util.hash import pickle_hash
from torch import nn
from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_numpy, to_torch_as
@@ -25,6 +26,7 @@
RolloutBatchProtocol,
)
from tianshou.utils import MultipleLRSchedulers
+from tianshou.utils.determinism import TraceLogger
from tianshou.utils.net.common import RandomActor
from tianshou.utils.print import DataclassPPrintMixin
from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode
@@ -541,6 +543,7 @@ def update(
return TrainingStats() # type: ignore[return-value]
start_time = time.time()
batch, indices = buffer.sample(sample_size)
+ TraceLogger.log(logger, lambda: f"Updating with batch: indices={pickle_hash(indices)}")
self.updating = True
batch = self.process_fn(batch, buffer, indices)
with torch_train_mode(self):
diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py
index 8c1374709..95b9527d2 100644
--- a/tianshou/policy/modelbased/psrl.py
+++ b/tianshou/policy/modelbased/psrl.py
@@ -236,7 +236,8 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TPSRL
for minibatch in batch.split(size=1):
obs, act, obs_next = minibatch.obs, minibatch.act, minibatch.obs_next
obs_next = cast(np.ndarray, obs_next)
- assert not isinstance(obs, BatchProtocol), "Observations cannot be Batches here"
+ assert not isinstance(obs, Batch), "Observations cannot be Batches here"
+ obs = cast(np.ndarray, obs)
trans_count[obs, act, obs_next] += 1
rew_sum[obs, act] += minibatch.rew
rew_square_sum[obs, act] += minibatch.rew**2
diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py
index 9e04d3feb..005454396 100644
--- a/tianshou/policy/modelfree/npg.py
+++ b/tianshou/policy/modelfree/npg.py
@@ -38,7 +38,8 @@ class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # ty
If `self.action_type == "discrete"`: (`s` ->`action_values_BA`).
If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s))
- :param optim: the optimizer for actor and critic network.
+ :param optim: the optimizer for the critic network only. The actor network
+ is optimized via natural gradients internally.
:param dist_fn: distribution class for computing the action.
:param action_space: env's action space
:param optim_critic_iters: Number of times to optimize critic network per update.
diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index 80bcff672..01f059df8 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -1,3 +1,4 @@
+import logging
import warnings
from collections.abc import Callable
from dataclasses import dataclass
@@ -27,6 +28,9 @@
from tianshou.utils.net.continuous import ActorProb
from tianshou.utils.net.discrete import Actor
+log = logging.getLogger(__name__)
+
+
# Dimension Naming Convention
# B - Batch Size
# A - Action
@@ -231,5 +235,4 @@ def learn( # type: ignore
losses.append(loss.item())
loss_summary_stat = SequenceSummaryStats.from_sequence(losses)
-
return PGTrainingStats(loss=loss_summary_stat) # type: ignore[return-value]
diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py
index e7aa5cfd5..51a2d7cf0 100644
--- a/tianshou/policy/modelfree/trpo.py
+++ b/tianshou/policy/modelfree/trpo.py
@@ -32,7 +32,8 @@ class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]):
If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s))
- :param optim: the optimizer for actor and critic network.
+ :param optim: the optimizer for the critic network only. The actor network
+ is optimized via natural gradients internally.
:param dist_fn: distribution class for computing the action.
:param action_space: env's action space
:param max_kl: max kl-divergence used to constrain each actor network update.
diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py
index 355c9d33b..ec4645741 100644
--- a/tianshou/trainer/base.py
+++ b/tianshou/trainer/base.py
@@ -7,6 +7,7 @@
from functools import partial
import numpy as np
+import torch
import tqdm
from tianshou.data import (
@@ -27,6 +28,7 @@
LazyLogger,
MovAvg,
)
+from tianshou.utils.determinism import TraceLogger, torch_param_hash
from tianshou.utils.logging import set_numerical_fields_to_precision
from tianshou.utils.torch_utils import policy_within_training_step
@@ -262,6 +264,7 @@ def _reset_collectors(self, reset_buffer: bool = False) -> None:
def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> None:
"""Initialize or reset the instance to yield a new iterator from zero."""
+ TraceLogger.log(log, lambda: "Trainer reset")
self.is_run = False
self.env_step = 0
if self.resume_from_log:
@@ -308,8 +311,37 @@ def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> No
self.stop_fn_flag = False
self.iter_num = 0
+ self._log_params(self.policy)
+
+ def _log_params(self, module: torch.nn.Module) -> None:
+ """Logs the parameters of the module to the trace logger by subcomponent (if the trace logger is enabled)."""
+ from tianshou.utils.net.common import ActorCritic
+
+ if not TraceLogger.is_enabled:
+ return
+
+ def module_has_params(m: torch.nn.Module) -> bool:
+ return any(p.requires_grad for p in m.parameters())
+
+ relevant_modules = {}
+
+ def gather_modules(m: torch.nn.Module) -> None:
+ for name, submodule in m.named_children():
+ if isinstance(submodule, ActorCritic):
+ gather_modules(submodule)
+ else:
+ if module_has_params(submodule):
+ relevant_modules[name] = submodule
+
+ gather_modules(module)
+
+ for name, submodule in sorted(relevant_modules.items()):
+ TraceLogger.log(
+ log,
+ lambda name=name, submodule=submodule: f"Params[{name}]: {torch_param_hash(submodule)}",
+ )
+
def __iter__(self): # type: ignore
- self.reset(reset_collectors=True, reset_buffer=False)
return self
def __next__(self) -> EpochStats:
@@ -329,23 +361,30 @@ def __next__(self) -> EpochStats:
# perform n step_per_epoch
steps_done_in_this_epoch = 0
with self._pbar(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", position=1) as t:
- train_stat: CollectStatsBase
+ TraceLogger.log(log, lambda: f"Epoch #{self.epoch} start")
+ collect_stats: CollectStatsBase
while steps_done_in_this_epoch < self.step_per_epoch and not self.stop_fn_flag:
- train_stat, update_stat, self.stop_fn_flag = self.training_step()
+ TraceLogger.log(log, lambda: "Training step")
+ collect_stats, training_stats, self.stop_fn_flag = self.training_step()
+ TraceLogger.log(
+ log,
+ lambda: f"Training step complete: stats={training_stats.get_loss_stats_dict() if training_stats else None}",
+ )
+ self._log_params(self.policy)
- if isinstance(train_stat, CollectStats):
+ if isinstance(collect_stats, CollectStats):
pbar_data_dict = {
"env_step": str(self.env_step),
"env_episode": str(self.env_episode),
"rew": f"{self.last_rew:.2f}",
"len": str(int(self.last_len)),
- "n/ep": str(train_stat.n_collected_episodes),
- "n/st": str(train_stat.n_collected_steps),
+ "n/ep": str(collect_stats.n_collected_episodes),
+ "n/st": str(collect_stats.n_collected_steps),
}
# t might be disabled, we track the steps manually
- t.update(train_stat.n_collected_steps)
- steps_done_in_this_epoch += train_stat.n_collected_steps
+ t.update(collect_stats.n_collected_steps)
+ steps_done_in_this_epoch += collect_stats.n_collected_steps
if self.stop_fn_flag:
t.set_postfix(**pbar_data_dict)
@@ -354,7 +393,7 @@ def __next__(self) -> EpochStats:
# Code should be restructured!
pbar_data_dict = {}
assert self.buffer, "No train_collector or buffer specified"
- train_stat = CollectStatsBase(
+ collect_stats = CollectStatsBase(
n_collected_steps=len(self.buffer),
)
@@ -408,9 +447,9 @@ def __next__(self) -> EpochStats:
# in case trainer is used with run(), epoch_stat will not be returned
return EpochStats(
epoch=self.epoch,
- train_collect_stat=train_stat,
+ train_collect_stat=collect_stats,
test_collect_stat=test_stat,
- training_stat=update_stat,
+ training_stat=training_stats,
info_stat=info_stat,
)
@@ -499,8 +538,12 @@ def _collect_training_data(self) -> CollectStats:
n_step=self.step_per_collect,
n_episode=self.episode_per_collect,
)
+ TraceLogger.log(
+ log,
+ lambda: f"Collected {collect_stats.n_collected_steps} steps, {collect_stats.n_collected_episodes} episodes",
+ )
- if self.train_collector.buffer.hasnull():
+ if self.train_collector.raise_on_nan_in_buffer and self.train_collector.buffer.hasnull():
from tianshou.data.collector import EpisodeRolloutHook
from tianshou.env import DummyVectorEnv
@@ -611,19 +654,21 @@ def policy_update_fn(
stats of the whole dataset
"""
- def run(self, reset_prior_to_run: bool = True, reset_buffer: bool = False) -> InfoStats:
+ def run(self, reset_collectors: bool = True, reset_buffer: bool = False) -> InfoStats:
"""Consume iterator.
See itertools - recipes. Use functions that consume iterators at C speed
(feed the entire iterator into a zero-length deque).
- :param reset_prior_to_run: whether to reset collectors prior to run
- :param reset_buffer: only has effect if `reset_prior_to_run` is True.
- Then it will also reset the buffer. This is usually not necessary, use
- with caution.
+ :param reset_collectors: whether to reset the collectors prior to starting the training process.
+ Specifically, this will reset the environments in the collectors (starting new episodes),
+ and the statistics stored in the collector. Whether the contained buffers will be reset/cleared
+ is determined by the `reset_buffer` parameter.
+ :param reset_buffer: whether, for the case where the collectors are reset, to reset/clear the
+ contained buffers as well.
+ This has no effect if `reset_collectors` is False.
"""
- if reset_prior_to_run:
- self.reset(reset_buffer=reset_buffer)
+ self.reset(reset_collectors=reset_collectors, reset_buffer=reset_buffer)
try:
self.is_run = True
deque(self, maxlen=0) # feed the entire iterator into a zero-length deque
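
For reference, a minimal usage sketch of the renamed `run` parameters; the `trainer` instance is assumed to have been built elsewhere and is not part of this patch:

# Hypothetical trainer instance; only the call signature is illustrated.
# Reset the collectors (environments and statistics) but keep their buffers intact:
result = trainer.run(reset_collectors=True, reset_buffer=False)
print(result)  # InfoStats returned once the training loop completes
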
diff --git a/tianshou/utils/determinism.py b/tianshou/utils/determinism.py
new file mode 100644
index 000000000..5747eeb12
--- /dev/null
+++ b/tianshou/utils/determinism.py
@@ -0,0 +1,397 @@
+import difflib
+import inspect
+import os
+import re
+import time
+from collections.abc import Callable, Sequence
+from dataclasses import dataclass
+from io import StringIO
+from pathlib import Path
+from typing import Self
+
+import torch
+from sensai.util import logging
+from sensai.util.git import GitStatus, git_status
+from sensai.util.pickle import dump_pickle, load_pickle
+
+
+def format_log_message(
+ logger: logging.Logger,
+ level: int,
+ msg: str,
+ formatter: logging.Formatter,
+ stacklevel: int = 1,
+) -> str:
+ """
+ Formats a log message as it would have been created by `logger.log(level, msg)` with the given formatter.
+
+ :param logger: the logger
+ :param level: the log level
+ :param msg: the message
+ :param formatter: the formatter
+ :param stacklevel: the stack level of the function to report as the generator
+ :return: the formatted log message (not including trailing newline)
+ """
+ frame_info = inspect.stack()[stacklevel]
+ pathname = frame_info.filename
+ lineno = frame_info.lineno
+ func = frame_info.function
+
+ record = logger.makeRecord(
+ name=logger.name,
+ level=level,
+ fn=pathname,
+ lno=lineno,
+ msg=msg,
+ args=(),
+ exc_info=None,
+ func=func,
+ extra=None,
+ )
+ record.created = time.time()
+ record.asctime = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.created))
+
+ return formatter.format(record)
+
+
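
An illustrative sketch of `format_log_message`; the logger name and format string are arbitrary choices, not mandated by this patch:

import logging

from tianshou.utils.determinism import format_log_message

logger = logging.getLogger("tianshou.trainer")
formatter = logging.Formatter("%(name)s:%(funcName)s - %(message)s")
# Renders the line as the formatter would have produced it, attributed to the caller's frame
line = format_log_message(logger, logging.DEBUG, "hello", formatter)
print(line)
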
+class TraceLogger:
+ """Supports the collection of behavioural trace logs, which can, in particular, be used for determinism tests."""
+
+ is_enabled = False
+ """
+ whether the trace logger is enabled.
+
+ NOTE: The preferred way to enable this is via the context manager.
+ """
+ verbose = False
+ """
+ whether to print trace log messages to stdout.
+ """
+ MESSAGE_TAG = "[TRACE]"
+ """
+ a tag which is added at the beginning of log messages generated by this logger
+ """
+ LOG_LEVEL = logging.DEBUG
+ log_buffer: StringIO | None = None
+ log_formatter: logging.Formatter | None = None
+
+ @classmethod
+ def log(cls, logger: logging.Logger, message_generator: Callable[[], str]) -> None:
+ """
+ Logs a message intended for tracing agent-env interaction, which is enabled via
+ `TraceLoggerContext`.
+
+ :param logger: the logger to use for the actual logging
+ :param message_generator: function which generates the log message (which may be expensive);
+ if logging is disabled, the function will not be called.
+ """
+ if not cls.is_enabled:
+ return
+
+ msg = message_generator()
+ msg = cls.MESSAGE_TAG + " " + msg
+
+ # Log with caller's frame info
+ logger.log(logging.DEBUG, msg, stacklevel=2)
+
+ # If a dedicated memory buffer is configured, also store the message there
+ if cls.log_buffer is not None:
+ msg_formatted = format_log_message(
+ logger,
+ logging.DEBUG,
+ msg,
+ cls.log_formatter,
+ stacklevel=2,
+ )
+ cls.log_buffer.write(msg_formatted + "\n")
+ if cls.verbose:
+ print(msg_formatted)
+
+
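
A minimal sketch of how a component might emit trace messages; the logger and the traced quantity are illustrative:

import logging

from tianshou.utils.determinism import TraceLogger

log = logging.getLogger(__name__)

def collect_step(step: int) -> None:
    # The lambda is evaluated only when tracing is enabled, so building the
    # message costs nothing in ordinary runs.
    TraceLogger.log(log, lambda: f"Collect step {step}")
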
+@dataclass
+class TraceLog:
+ log_lines: list[str]
+
+ def save_log(self, path: str) -> None:
+ with open(path, "w") as f:
+ for line in self.log_lines:
+ f.write(line + "\n")
+
+ def print_log(self) -> None:
+ for line in self.log_lines:
+ print(line)
+
+ def get_full_log(self) -> str:
+ return "\n".join(self.log_lines)
+
+ def reduce_log_to_messages(self) -> "TraceLog":
+ """
+ Removes logger names and function names from the log entries, such that each log message
+ contains only the main text message itself (starting with the content after the logger's tag).
+
+ :return: the result with reduced log messages
+ """
+ lines = []
+ tag = re.escape(TraceLogger.MESSAGE_TAG)
+ for line in self.log_lines:
+ lines.append(re.sub(r".*" + tag, "", line))
+ return TraceLog(lines)
+
+ def filter_messages(
+ self,
+ required_messages: Sequence[str] = (),
+ optional_messages: Sequence[str] = (),
+ ignored_messages: Sequence[str] = (),
+ ) -> "TraceLog":
+ """
+ Applies inclusion and/or exclusion filtering to the log messages.
+ If `required_messages` or `optional_messages` is non-empty, inclusion filtering is applied.
+ If `ignored_messages` is non-empty, exclusion filtering is applied.
+ If both inclusion and exclusion filtering are applied, the exclusion filtering takes precedence.
+
+ :param required_messages: required message substrings to filter for; each such substring must appear
+ in at least one message (otherwise an exception is raised)
+ :param optional_messages: additional message fragments to filter for; these are not required to appear
+ :param ignored_messages: message fragments that result in exclusion; takes precedence over
+ `required_messages` and `optional_messages`
+ :return: the result with reduced log messages
+ """
+ import numpy as np
+
+ required_message_counters = np.zeros(len(required_messages))
+
+ def retain_line(line: str) -> bool:
+ for ignored_message in ignored_messages:
+ if ignored_message in line:
+ return False
+ if required_messages or optional_messages:
+ for i, main_message in enumerate(required_messages):
+ if main_message in line:
+ required_message_counters[i] += 1
+ return True
+ return any(add_message in line for add_message in optional_messages)
+ else:
+ return True
+
+ lines = []
+ for line in self.log_lines:
+ if retain_line(line):
+ lines.append(line)
+
+ assert np.all(
+ required_message_counters > 0,
+ ), "Not all types of required messages were found in the trace. Were log messages changed?"
+
+ return TraceLog(lines)
+
+
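
An illustrative sketch of the reduction and filtering steps; the log lines are made up for the example:

from tianshou.utils.determinism import TraceLog, TraceLogger

trace = TraceLog(log_lines=[
    f"trainer:reset - {TraceLogger.MESSAGE_TAG} Trainer reset",
    f"trainer:_log_params - {TraceLogger.MESSAGE_TAG} Params[actor]: 3f2a...",
])
# Strip logger/function prefixes, then keep only the parameter-hash messages;
# "Params[" must occur at least once, otherwise an assertion fails.
core = trace.reduce_log_to_messages().filter_messages(required_messages=["Params["])
print(core.get_full_log())
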
+class TraceLoggerContext:
+ """
+ A context manager which enables the trace logger.
+ Apart from enabling trace logging, it can optionally create an in-memory log buffer, such that
+ the trace log can be retrieved without depending on the logging system's configuration.
+ """
+
+ def __init__(
+ self,
+ enable_log_buffer: bool = True,
+ log_format: str = "%(name)s:%(funcName)s - %(message)s",
+ ) -> None:
+ """
+ :param enable_log_buffer: whether to enable the dedicated log buffer for trace logs, whose contents
+ can, within the context of this manager, be accessed via method `get_log`.
+ :param log_format: the logger format string to use for the dedicated log buffer
+ """
+ self._enable_log_buffer = enable_log_buffer
+ self._log_format: str = log_format
+ self._log_buffer: StringIO | None = None
+
+ def __enter__(self) -> Self:
+ TraceLogger.is_enabled = True
+
+ if self._enable_log_buffer:
+ TraceLogger.log_buffer = StringIO()
+ TraceLogger.log_formatter = logging.Formatter(self._log_format)
+ self._log_buffer = TraceLogger.log_buffer
+
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore
+ TraceLogger.is_enabled = False
+ TraceLogger.log_buffer = None
+ TraceLogger.log_formatter = None
+
+ def get_log(self) -> TraceLog:
+ """:return: the full trace log that was captured if `enable_log_buffer` was enabled at construction"""
+ if self._log_buffer is None:
+ raise Exception(
+ "This method is only supported if the log buffer is enabled at construction",
+ )
+ return TraceLog(log_lines=self._log_buffer.getvalue().split("\n"))
+
+
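
A minimal usage sketch; the traced code is a placeholder:

from tianshou.utils.determinism import TraceLoggerContext

with TraceLoggerContext(enable_log_buffer=True) as ctx:
    ...  # run the code to be traced, e.g. a short, seeded training run
    trace = ctx.get_log()  # retrieve the buffered trace while the context is active
trace.print_log()
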
+def torch_param_hash(module: torch.nn.Module) -> str:
+ """
+ Computes a hash of the parameters of the given module; parameters not requiring gradients are ignored.
+
+ :param module: a torch module
+ :return: a hex digest of the parameters of the module
+ """
+ import hashlib
+
+ hasher = hashlib.sha1()
+ for param in module.parameters():
+ if param.requires_grad:
+ np_array = param.detach().cpu().numpy()
+ hasher.update(np_array.tobytes())
+ return hasher.hexdigest()
+
+
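
A small sketch of the parameter hash; the module choice is arbitrary:

import torch

from tianshou.utils.determinism import torch_param_hash

net = torch.nn.Linear(4, 2)
print(torch_param_hash(net))  # hex digest over weight and bias
net.bias.requires_grad_(False)
print(torch_param_hash(net))  # frozen parameters are excluded, so the digest changes
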
+class TraceDeterminismTest:
+ def __init__(
+ self,
+ base_path: Path,
+ core_messages: Sequence[str] = (),
+ ignored_messages: Sequence[str] = (),
+ log_filename: str | None = None,
+ ) -> None:
+ """
+ :param base_path: the directory where the reference results are stored (will be created if necessary)
+ :param core_messages: message fragments that make up the core of a trace; if empty, all messages are considered core
+ :param ignored_messages: message fragments to ignore in the trace log (if any); takes precedence over
+ `core_messages`
+ :param log_filename: the name of the log file to which results are to be written (if any)
+ """
+ base_path.mkdir(parents=True, exist_ok=True)
+ self.base_path = base_path
+ self.core_messages = core_messages
+ self.ignored_messages = ignored_messages
+ self.log_filename = log_filename
+
+ @dataclass(kw_only=True)
+ class Result:
+ git_status: GitStatus
+ log: TraceLog
+
+ def check(
+ self,
+ current_log: TraceLog,
+ name: str,
+ create_reference_result: bool = False,
+ pass_if_core_messages_unchanged: bool = False,
+ ) -> None:
+ """
+ Checks the given log against the reference result for the given name.
+
+ :param current_log: the trace log to check
+ :param name: the name of the reference result; must be unique among all tests!
+ :param create_reference_result: whether to update the reference result with the given log
+ :param pass_if_core_messages_unchanged: whether the check shall pass as long as the core messages
+ (as defined at construction) are unchanged, even if other messages differ
+ """
+ import pytest
+
+ reference_result_path = self.base_path / f"{name}.pkl.bz2"
+ current_git_status = git_status()
+
+ if create_reference_result:
+ current_result = self.Result(git_status=current_git_status, log=current_log)
+ dump_pickle(current_result, reference_result_path)
+
+ reference_result: TraceDeterminismTest.Result = load_pickle(
+ reference_result_path,
+ )
+ reference_log = reference_result.log
+
+ current_log_reduced = current_log.reduce_log_to_messages().filter_messages(
+ ignored_messages=self.ignored_messages,
+ )
+ reference_log_reduced = reference_log.reduce_log_to_messages().filter_messages(
+ ignored_messages=self.ignored_messages,
+ )
+
+ results: list[tuple[TraceLog, str]] = [
+ (reference_log_reduced, "expected"),
+ (current_log_reduced, "current"),
+ (reference_log, "expected_full"),
+ (current_log, "current_full"),
+ ]
+
+ if self.core_messages:
+ result_main_messages = current_log_reduced.filter_messages(
+ required_messages=self.core_messages,
+ )
+ reference_result_main_messages = reference_log_reduced.filter_messages(
+ required_messages=self.core_messages,
+ )
+ results.extend(
+ [
+ (reference_result_main_messages, "expected_core"),
+ (result_main_messages, "current_core"),
+ ],
+ )
+ else:
+ result_main_messages = current_log_reduced
+ reference_result_main_messages = reference_log_reduced
+
+ logs_equivalent = current_log_reduced.get_full_log() == reference_log_reduced.get_full_log()
+ if logs_equivalent:
+ status_passed = True
+ status_message = "OK"
+ else:
+ core_messages_unchanged = (
+ len(self.core_messages) > 0
+ and result_main_messages.get_full_log()
+ == reference_result_main_messages.get_full_log()
+ )
+ status_passed = core_messages_unchanged and pass_if_core_messages_unchanged
+
+ if status_passed:
+ status_message = "OK (core messages unchanged)"
+ else:
+ # save files for comparison
+ files = []
+ for r, suffix in results:
+ path = os.path.abspath(f"determinism_{name}_{suffix}.txt")
+ r.save_log(path)
+ files.append(path)
+
+ paths_str = "\n".join(files)
+ main_message = (
+ f"Please inspect the changes by diffing the log files:\n{paths_str}\n"
+ f"If the changes are OK, enable the `create_reference_result` flag temporarily, "
+ "rerun the test and then commit the updated reference file.\n\nHere's the first part of the diff:\n"
+ )
+
+ # compute diff and add to message
+ num_diff_lines_to_show = 30
+ for i, line in enumerate(
+ difflib.unified_diff(
+ reference_log_reduced.log_lines,
+ current_log_reduced.log_lines,
+ fromfile="expected.txt",
+ tofile="current.txt",
+ lineterm="",
+ ),
+ ):
+ if i == num_diff_lines_to_show:
+ break
+ main_message += line + "\n"
+
+ if core_messages_unchanged:
+ status_message = (
+ "The behaviour log has changed, but the core messages are still the same (so this "
+ f"probably isn't an issue). {main_message}"
+ )
+ else:
+ status_message = f"The behaviour log has changed; even the core messages are different. {main_message}"
+
+ # write log message
+ if self.log_filename:
+ with open(self.log_filename, "a") as f:
+ hr = "-" * 100
+ f.write(f"\n\n{hr}\nName: {name}\n")
+ f.write(f"Reference state: {reference_result.git_status}\n")
+ f.write(f"Current state: {current_git_status}\n")
+ f.write(f"Test result: {status_message}\n")
+
+ if not status_passed:
+ pytest.fail(status_message)
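
Putting the pieces together, a determinism test might look roughly like the sketch below; the snapshot directory, the helper performing a short seeded training run, and the core-message fragment are illustrative assumptions, not part of this patch:

from pathlib import Path

from tianshou.utils.determinism import TraceDeterminismTest, TraceLoggerContext

def test_ppo_determinism() -> None:
    with TraceLoggerContext() as ctx:
        run_short_seeded_training()  # hypothetical helper performing a small, fully seeded training run
        log = ctx.get_log()
    TraceDeterminismTest(
        base_path=Path("test/resources/determinism"),  # illustrative snapshot location
        core_messages=["Params["],  # parameter hashes as the essential part of the trace
    ).check(log, name="ppo", create_reference_result=False)
    # Set create_reference_result=True once to record the snapshot, then commit the reference file.
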