diff --git a/.github/workflows/lint_and_docs.yml b/.github/workflows/lint_and_docs.yml index fce69a8fa..26a095958 100644 --- a/.github/workflows/lint_and_docs.yml +++ b/.github/workflows/lint_and_docs.yml @@ -1,4 +1,4 @@ -name: PEP8, Types and Docs Check +name: Check Formatting/Typing and Build Docs on: pull_request: @@ -46,9 +46,12 @@ jobs: - name: Install the project dependencies run: | poetry install --with dev --extras "eval" - - name: Lint + - name: Check formatting run: poetry run poe lint - - name: Types + - name: Check typing run: poetry run poe type-check - - name: Docs - run: poetry run poe doc-build + - name: Build docs + run: MYSTNB_DEBUG=1 poetry run poe doc-build + - name: Show errors (if any) + if: failure() + run: find docs/_build/reports -name "*.err.log" -exec echo "--- {} ---" \; -exec cat {} \; diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c725a510a..3d747e4b5 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,3 +1,3 @@ # Contributing to Tianshou -Please refer to [the 'Contributing' section on tianshou.org](https://tianshou.org/en/stable/04_contributing/04_contributing.html). +Please refer to the ['Developer Guide' on tianshou.org](https://tianshou.org/en/latest/04_developer_guide/developer_guide.html). diff --git a/README.md b/README.md index 725d30b64..e6d93334d 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ > ℹ️ **Introducing Tianshou version 2** > -> We have just released the first beta version 2.0.0b1 of the new major version of Tianshou, and we invite you to try it! +> We have released the second beta version 2.0.0b2 of the new major version of Tianshou on PyPI, and we invite you to try it! > Version 2 is a complete overhaul of the software design of the procedural API, in which > * we establish a clear separation between learning algorithms and policies (via the separate abstractions `Algorithm` and `Policy`). > * we provide more well-defined, more usable interfaces with extensive documentation of all algorithm and trainer parameters, diff --git a/docs/01_tutorials/02_internals.rst b/docs/01_tutorials/02_internals.rst deleted file mode 100644 index b9a3b7ac2..000000000 --- a/docs/01_tutorials/02_internals.rst +++ /dev/null @@ -1,430 +0,0 @@ -Understanding Tianshou Internals -================================ - -Tianshou splits a Reinforcement Learning agent training procedure into these parts: algorithm, trainer, collector, policy, a data buffer and batches from the buffer. -The algorithm encapsulates the specific RL learning method (e.g., DQN, PPO), which contains a policy and defines how to update it. - -.. - The general control flow can be described as: - - .. image:: /_static/images/concepts_arch.png - :align: center - :height: 300 - - - Here is a more detailed description, where ``Env`` is the environment and ``Model`` is the neural network: - - .. image:: /_static/images/concepts_arch2.png - :align: center - :height: 300 - - -Batch ------ - -Tianshou provides :class:`~tianshou.data.Batch` as the internal data structure to pass any kind of data to other methods, for example, a collector gives a :class:`~tianshou.data.Batch` to policy for learning. 
Let's take a look at this script: -:: - - >>> import torch, numpy as np - >>> from tianshou.data import Batch - >>> data = Batch(a=4, b=[5, 5], c='2312312', d=('a', -2, -3)) - >>> # the list will automatically be converted to numpy array - >>> data.b - array([5, 5]) - >>> data.b = np.array([3, 4, 5]) - >>> print(data) - Batch( - a: 4, - b: array([3, 4, 5]), - c: '2312312', - d: array(['a', '-2', '-3'], dtype=object), - ) - >>> data = Batch(obs={'index': np.zeros((2, 3))}, act=torch.zeros((2, 2))) - >>> data[:, 1] += 6 - >>> print(data[-1]) - Batch( - obs: Batch( - index: array([0., 6., 0.]), - ), - act: tensor([0., 6.]), - ) - -In short, you can define a :class:`~tianshou.data.Batch` with any key-value pair, and perform some common operations over it. - -:ref:`batch_concept` is a dedicated tutorial for :class:`~tianshou.data.Batch`. We strongly recommend every user to read it so as to correctly understand and use :class:`~tianshou.data.Batch`. - - -Buffer ------- - -:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. ReplayBuffer can be considered as a specialized form (or management) of :class:`~tianshou.data.Batch`. It stores all the data in a batch with circular-queue style. - -The current implementation of Tianshou typically use the following reserved keys in -:class:`~tianshou.data.Batch`: - -* ``obs`` the observation of step :math:`t` ; -* ``act`` the action of step :math:`t` ; -* ``rew`` the reward of step :math:`t` ; -* ``terminated`` the terminated flag of step :math:`t` ; -* ``truncated`` the truncated flag of step :math:`t` ; -* ``done`` the done flag of step :math:`t` (can be inferred as ``terminated or truncated``); -* ``obs_next`` the observation of step :math:`t+1` ; -* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` function returns 4 arguments, and the last one is ``info``); -* ``policy`` the data computed by policy in step :math:`t`; - -When adding data to a replay buffer, the done flag will be inferred automatically from ``terminated`` or ``truncated``. - -The following code snippet illustrates the usage, including: - -- the basic data storage: ``add()``; -- get attribute, get slicing data, ...; -- sample from buffer: ``sample_indices(batch_size)`` and ``sample(batch_size)``; -- get previous/next transition index within episodes: ``prev(index)`` and ``next(index)``; -- save/load data from buffer: pickle and HDF5; - -:: - - >>> import pickle, numpy as np - >>> from tianshou.data import Batch, ReplayBuffer - >>> buf = ReplayBuffer(size=20) - >>> for i in range(3): - ... buf.add(Batch(obs=i, act=i, rew=i, terminated=0, truncated=0, obs_next=i + 1, info={})) - - >>> buf.obs - # since we set size = 20, len(buf.obs) == 20. - array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) - >>> # but there are only three valid items, so len(buf) == 3. - >>> len(buf) - 3 - >>> # save to file "buf.pkl" - >>> pickle.dump(buf, open('buf.pkl', 'wb')) - >>> # save to HDF5 file - >>> buf.save_hdf5('buf.hdf5') - - >>> buf2 = ReplayBuffer(size=10) - >>> for i in range(15): - ... terminated = i % 4 == 0 - ... buf2.add(Batch(obs=i, act=i, rew=i, terminated=terminated, truncated=False, obs_next=i + 1, info={})) - >>> len(buf2) - 10 - >>> buf2.obs - # since its size = 10, it only stores the last 10 steps' result. 
- array([10, 11, 12, 13, 14, 5, 6, 7, 8, 9]) - - >>> # move buf2's result into buf (meanwhile keep it chronologically) - >>> buf.update(buf2) - >>> buf.obs - array([ 0, 1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, 0, 0, - 0, 0, 0, 0]) - - >>> # get all available index by using batch_size = 0 - >>> indices = buf.sample_indices(0) - >>> indices - array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) - >>> # get one step previous/next transition - >>> buf.prev(indices) - array([ 0, 0, 1, 2, 3, 4, 5, 7, 7, 8, 9, 11, 11]) - >>> buf.next(indices) - array([ 1, 2, 3, 4, 5, 6, 6, 8, 9, 10, 10, 12, 12]) - - >>> # get a random sample from buffer - >>> # the batch_data is equal to buf[indices]. - >>> batch_data, indices = buf.sample(batch_size=4) - >>> batch_data.obs == buf[indices].obs - array([ True, True, True, True]) - >>> len(buf) - 13 - - >>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl" - >>> len(buf) - 3 - >>> # load complete buffer from HDF5 file - >>> buf = ReplayBuffer.load_hdf5('buf.hdf5') - >>> len(buf) - 3 - -:class:`~tianshou.data.ReplayBuffer` also supports "frame stack" sampling (typically for RNN usage, see `https://github.com/thu-ml/tianshou/issues/19`), ignoring storing the next observation (save memory in Atari tasks), and multi-modal observation (see `https://github.com/thu-ml/tianshou/issues/38`): - -.. raw:: html - -
- Advance usage of ReplayBuffer - -.. code-block:: python - - >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True) - >>> for i in range(16): - ... terminated = i % 5 == 0 - ... ptr, ep_rew, ep_len, ep_idx = buf.add( - ... Batch(obs={'id': i}, act=i, rew=i, - ... terminated=terminated, truncated=False, obs_next={'id': i + 1})) - ... print(i, ep_len, ep_rew) - 0 [1] [0.] - 1 [0] [0.] - 2 [0] [0.] - 3 [0] [0.] - 4 [0] [0.] - 5 [5] [15.] - 6 [0] [0.] - 7 [0] [0.] - 8 [0] [0.] - 9 [0] [0.] - 10 [5] [40.] - 11 [0] [0.] - 12 [0] [0.] - 13 [0] [0.] - 14 [0] [0.] - 15 [5] [65.] - >>> print(buf) # you can see obs_next is not saved in buf - ReplayBuffer( - obs: Batch( - id: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8]), - ), - act: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8]), - rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]), - done: array([False, True, False, False, False, False, True, False, - False]), - ) - >>> index = np.arange(len(buf)) - >>> print(buf.get(index, 'obs').id) - [[ 7 7 8 9] - [ 7 8 9 10] - [11 11 11 11] - [11 11 11 12] - [11 11 12 13] - [11 12 13 14] - [12 13 14 15] - [ 7 7 7 7] - [ 7 7 7 8]] - >>> # here is another way to get the stacked data - >>> # (stack only for obs and obs_next) - >>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum() - 0 - >>> # we can get obs_next through __getitem__, even if it doesn't exist - >>> # however, [:] will select the item according to timestamp, - >>> # that equals to index == [7, 8, 0, 1, 2, 3, 4, 5, 6] - >>> print(buf[:].obs_next.id) - [[ 7 7 7 8] - [ 7 7 8 9] - [ 7 8 9 10] - [ 7 8 9 10] - [11 11 11 12] - [11 11 12 13] - [11 12 13 14] - [12 13 14 15] - [12 13 14 15]] - >>> full_index = np.array([7, 8, 0, 1, 2, 3, 4, 5, 6]) - >>> np.allclose(buf[:].obs_next.id, buf[full_index].obs_next.id) - True - -.. raw:: html - -
-   </details><br>
- -Tianshou provides other type of data buffer such as :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``) and :class:`~tianshou.data.VectorReplayBuffer` (add different episodes' data but without losing chronological order). Check out :class:`~tianshou.data.ReplayBuffer` for more detail. - - -Algorithm and Policy --------------------- - -Tianshou's RL framework is built around two key abstractions: :class:`~tianshou.algorithm.Algorithm` and :class:`~tianshou.algorithm.Policy`. - -**Algorithm**: The core abstraction that encapsulates a complete RL learning method (e.g., DQN, PPO, SAC). Each algorithm contains a policy and defines how to update it using training data. All algorithm classes inherit from :class:`~tianshou.algorithm.Algorithm`. - -An algorithm class typically has the following parts: - -* :meth:`~tianshou.algorithm.Algorithm.__init__`: initialize the algorithm with a policy and optimization configuration; -* :meth:`~tianshou.algorithm.Algorithm._preprocess_batch`: pre-process data from the replay buffer (e.g., compute n-step returns); -* :meth:`~tianshou.algorithm.Algorithm._update_with_batch`: the algorithm-specific network update logic; -* :meth:`~tianshou.algorithm.Algorithm._postprocess_batch`: post-process the batch data (e.g., update prioritized replay buffer weights); -* :meth:`~tianshou.algorithm.Algorithm.create_trainer`: create the appropriate trainer for this algorithm; - -**Policy**: Represents the mapping from observations to actions. Policy classes inherit from :class:`~tianshou.algorithm.Policy`. - -A policy class typically provides: - -* :meth:`~tianshou.algorithm.Policy.forward`: compute action distribution or Q-values given observations; -* :meth:`~tianshou.algorithm.Policy.compute_action`: get concrete actions from observations for environment interaction; -* :meth:`~tianshou.algorithm.Policy.map_action`: transform raw network outputs to environment action space; - - -.. _policy_state: - -States for policy -^^^^^^^^^^^^^^^^^ - -During the training process, the policy has two main states: training state and testing state. The training state can be further divided into the collecting state and updating state. - -The meaning of training and testing state is obvious: the agent interacts with environment, collects training data and performs update, that's training state; the testing state is to evaluate the performance of the current policy during training process. - -As for the collecting state, it is defined as interacting with environments and collecting training data into the buffer; -we define the updating state as performing a model update by the algorithm's update methods during training process. - -The collection of data from the env may differ in training and in inference (for example, in training one may add exploration noise, or sample from the predicted action distribution instead of taking its mode). The switch between the different collection strategies in training and inference is controlled by ``policy.is_within_training_step``, see also the docstring of it -for more details. - - -policy.forward -^^^^^^^^^^^^^^ - -The ``forward`` function computes the action over given observations. The input and output is algorithm-specific but generally, the function is a mapping of ``(batch, state, ...) -> batch``. - -The input batch is the environment data (e.g., observation, reward, done flag and info). It comes from either :meth:`~tianshou.data.Collector.collect` or :meth:`~tianshou.data.ReplayBuffer.sample`. 
The first dimension of all variables in the input ``batch`` should be equal to the batch-size. - -The output is also a ``Batch`` which must contain "act" (action) and may contain "state" (hidden state of policy), "policy" (the intermediate result of policy which needs to save into the buffer, see :meth:`~tianshou.algorithm.BasePolicy.forward`), and some other algorithm-specific keys. - -For example, if you try to use your policy to evaluate one episode (and don't want to use :meth:`~tianshou.data.Collector.collect`), use the following code-snippet: -:: - - # assume env is a gym.Env - obs, done = env.reset(), False - while not done: - batch = Batch(obs=[obs]) # the first dimension is batch-size - act = policy(batch).act[0] # policy.forward return a batch, use ".act" to extract the action - obs, rew, done, info = env.step(act) - -For inference, it is recommended to use the shortcut method :meth:`~tianshou.algorithm.Policy.compute_action` to compute the action directly from the observation. - -Here, ``Batch(obs=[obs])`` will automatically create the 0-dimension to be the batch-size. Otherwise, the network cannot determine the batch-size. - - -.. _process_fn: - -Algorithm Preprocessing and N-step Returns -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The algorithm handles data preprocessing, including computing variables that depend on time-series such as N-step or GAE returns. This functionality is implemented in :meth:`~tianshou.algorithm.Algorithm._preprocess_batch` and the static methods :meth:`~tianshou.algorithm.Algorithm.compute_nstep_return` and :meth:`~tianshou.algorithm.Algorithm.compute_episodic_return`. - -Take 2-step return DQN as an example. The 2-step return DQN compute each transition's return as: - -.. math:: - - G_t = r_t + \gamma r_{t + 1} + \gamma^2 \max_a Q(s_{t + 2}, a) - -where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`. Here is the pseudocode showing the training process **without Tianshou framework**: -:: - - # pseudocode, cannot work - obs = env.reset() - buffer = Buffer(size=10000) - algorithm = DQN(...) - for i in range(int(1e6)): - act = algorithm.policy.compute_action(obs) - obs_next, rew, done, _ = env.step(act) - buffer.store(obs, act, obs_next, rew, done) - obs = obs_next - if i % 1000 == 0: - # algorithm handles sampling, preprocessing, and updating - algorithm.update(sample_size=64, buffer=buffer) - -The algorithm's :meth:`~tianshou.algorithm.Algorithm._preprocess_batch` method automatically handles n-step return computation by calling :meth:`~tianshou.algorithm.Algorithm.compute_nstep_return`, which provides the replay buffer, sample indices, and batch data. Since we store all the data in the order of time, the n-step return can be computed efficiently using the buffer's temporal structure. - -For custom preprocessing logic, you can override :meth:`~tianshou.algorithm.Algorithm._preprocess_batch` in your algorithm subclass. The method receives the sampled batch, buffer, and indices, allowing you to add computed values like returns, advantages, or other algorithm-specific preprocessing steps. - - -Collector ---------- - -The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently. 
- -:meth:`~tianshou.data.Collector.collect` is the main method of :class:`~tianshou.data.Collector`: it lets the policy perform a specified number of steps (``n_step``) or episodes (``n_episode``) and store the data in the replay buffer, then return the statistics of the collected data such as episode's total reward. - -The general explanation is listed in :ref:`pseudocode`. Other usages of collector are listed in :class:`~tianshou.data.Collector` documentation. Here are some example usages: -:: - - policy = DiscreteQLearningPolicy(...) # or other policies if you wish - env = gym.make("CartPole-v1") - - replay_buffer = ReplayBuffer(size=10000) - - # here we set up a collector with a single environment - collector = Collector(policy, env, buffer=replay_buffer) - - # the collector supports vectorized environments as well - vec_buffer = VectorReplayBuffer(total_size=10000, buffer_num=3) - # buffer_num should be equal to (suggested) or larger than #envs - envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(3)]) - collector = Collector(policy, envs, buffer=vec_buffer) - - # collect 3 episodes - collector.collect(n_episode=3) - # collect at least 2 steps - collector.collect(n_step=2) - # collect episodes with visual rendering ("render" is the sleep time between - # rendering consecutive frames) - collector.collect(n_episode=1, render=0.03) - -There is also another type of collector :class:`~tianshou.data.AsyncCollector` which supports asynchronous environment setting (for those taking a long time to step). However, AsyncCollector only supports **at least** ``n_step`` or ``n_episode`` collection due to the property of asynchronous environments. - - -Trainer -------- - -Once you have an algorithm and a collector, you can start the training process. The trainer orchestrates the training loop and calls upon the algorithm's specific network updating logic. Each algorithm creates its appropriate trainer type through the :meth:`~tianshou.algorithm.Algorithm.create_trainer` method. - -Tianshou has three main trainer classes: :class:`~tianshou.trainer.OnPolicyTrainer` for on-policy algorithms such as Policy Gradient, :class:`~tianshou.trainer.OffPolicyTrainer` for off-policy algorithms such as DQN, and :class:`~tianshou.trainer.OfflineTrainer` for offline algorithms such as BCQ. - -The typical workflow is: -:: - - # Create algorithm with policy - algorithm = DQN(policy=policy, optim=optimizer_factory, ...) - - # Create trainer parameters - params = OffPolicyTrainerParams( - max_epochs=100, - step_per_epoch=1000, - train_collector=train_collector, - test_collector=test_collector, - ... - ) - - # Run training (trainer is created automatically) - result = algorithm.run_training(params) - -You can also create trainers manually for more control: -:: - - trainer = algorithm.create_trainer(params) - result = trainer.run() - - -.. _pseudocode: - -A High-level Explanation ------------------------- - -We give a high-level explanation through the pseudocode used in section :ref:`process_fn`: -:: - - # pseudocode, cannot work # methods in tianshou - obs = env.reset() - buffer = Buffer(size=10000) # buffer = tianshou.data.ReplayBuffer(size=10000) - algorithm = DQN(policy=policy, ...) # algorithm.__init__(...) - for i in range(int(1e6)): # done in trainer - act = algorithm.policy.compute_action(obs) # act = policy.compute_action(obs) - obs_next, rew, done, _ = env.step(act) # collector.collect(...) - buffer.store(obs, act, obs_next, rew, done) # collector.collect(...) 
- obs = obs_next # collector.collect(...) - if i % 1000 == 0: # done in trainer - # the following is done in algorithm.update(batch_size, buffer) - b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # batch, indices = buffer.sample(batch_size) - # compute 2-step returns. How? - b_ret = compute_2_step_return(buffer, b_r, b_d, ...) # algorithm._preprocess_batch(batch, buffer, indices) - # update DQN policy - algorithm.update(b_s, b_a, b_s_, b_r, b_d, b_ret) # algorithm._update_with_batch(batch) - - -Conclusion ----------- - -So far, we've covered the overall framework of Tianshou with its new architecture centered around the Algorithm abstraction. The key components are: - -- **Algorithm**: Encapsulates the complete RL learning method, containing a policy and defining how to update it -- **Policy**: Handles the mapping from observations to actions -- **Collector**: Manages environment interaction and data collection -- **Trainer**: Orchestrates the training loop and calls the algorithm's update logic -- **Buffer**: Stores and manages experience data -- **Batch**: A flexible data structure for passing data between components. Batches are collected to the buffer by the Collector and are sampled from the buffer by the `Algorithm` where they are used for learning. - -This modular design cleanly separates concerns while maintaining the flexibility to implement various RL algorithms. diff --git a/docs/01_tutorials/03_batch.rst b/docs/01_tutorials/03_batch.rst deleted file mode 100644 index f778f0b71..000000000 --- a/docs/01_tutorials/03_batch.rst +++ /dev/null @@ -1,501 +0,0 @@ -.. _batch_concept: - -Understand Batch -================ - -:class:`~tianshou.data.Batch` is the internal data structure extensively used in Tianshou. It is designed to store and manipulate hierarchical named tensors. This tutorial aims to help users correctly understand the concept and the behavior of :class:`~tianshou.data.Batch` so that users can make the best of Tianshou. - -The tutorial has three parts. We first explain the concept of hierarchical named tensors, and introduce basic usage of :class:`~tianshou.data.Batch`, followed by advanced topics of :class:`~tianshou.data.Batch`. - - -Hierarchical Named Tensors ---------------------------- - -.. sidebar:: The structure of a Batch shown by a tree - - .. Figure:: ../_static/images/batch_tree.png - -"Hierarchical named tensors" refers to a set of tensors where their names form a hierarchy. Suppose there are four tensors ``[t1, t2, t3, t4]`` with names ``[name1, name2, name3, name4]``, where ``name1`` and ``name2`` belong to the same namespace ``name0``, then the full name of tensor ``t1`` is ``name0.name1``. That is, the hierarchy lies in the names of tensors. - -We can describe the structure of hierarchical named tensors using a tree in the right. There is always a "virtual root" node to represent the whole object; internal nodes are keys (names), and leaf nodes are values (scalars or tensors). - -Hierarchical named tensors are needed because we have to deal with the heterogeneity of reinforcement learning problems. The abstraction of RL is very simple, just:: - - state, reward, done = env.step(action) - -``reward`` and ``done`` are simple, they are mostly scalar values. However, the ``state`` and ``action`` vary with environments. For example, ``state`` can be simply a vector, a tensor, or a camera input combined with sensory input. In the last case, it is natural to store them as hierarchical named tensors. 
This hierarchy can go beyond ``state`` and ``action``: we can store ``state``, ``action``, ``reward``, and ``done`` together as hierarchical named tensors. - -Note that, storing hierarchical named tensors is as easy as creating nested dictionary objects: -:: - - { - 'done': done, - 'reward': reward, - 'state': { - 'camera': camera, - 'sensory': sensory - } - 'action': { - 'direct': direct, - 'point_3d': point_3d, - 'force': force, - } - } - -The real problem is how to **manipulate them**, such as adding new transition tuples into replay buffer and dealing with their heterogeneity. ``Batch`` is designed to easily create, store, and manipulate these hierarchical named tensors. - - -Basic Usages ------------- - -Here we cover some basic usages of ``Batch``, describing what ``Batch`` contains, how to construct ``Batch`` objects and how to manipulate them. - - -What Does Batch Contain -^^^^^^^^^^^^^^^^^^^^^^^ - -The content of ``Batch`` objects can be defined by the following rules. - -1. A ``Batch`` object can be an empty ``Batch()``, or have at least one key-value pairs. ``Batch()`` can be used to reserve keys, too. See :ref:`key_reservations` for this advanced usage. - -2. The keys are always strings (they are names of corresponding values). - -3. The values can be scalars, tensors, or Batch objects. The recursive definition makes it possible to form a hierarchy of batches. - -4. Tensors are the most important values. In short, tensors are n-dimensional arrays of the same data type. We support two types of tensors: `PyTorch `_ tensor type ``torch.Tensor`` and `NumPy `_ tensor type ``np.ndarray``. - -5. Scalars are also valid values. A scalar is a single boolean, number, or object. They can be python scalar (``False``, ``1``, ``2.3``, ``None``, ``'hello'``) or NumPy scalar (``np.bool_(True)``, ``np.int32(1)``, ``np.float64(2.3)``). They just shouldn't be mixed up with Batch/dict/tensors. - -.. note:: - - ``Batch`` cannot store ``dict`` objects, because internally ``Batch`` uses ``dict`` to store data. During construction, ``dict`` objects will be automatically converted to ``Batch`` objects. - - The data types of tensors are bool and numbers (any size of int and float as long as they are supported by NumPy or PyTorch). Besides, NumPy supports ndarray of objects and we take advantage of this feature to store non-number objects in ``Batch``. If one wants to store data that are neither boolean nor numbers (such as strings and sets), they can store the data in ``np.ndarray`` with the ``np.object`` data type. This way, ``Batch`` can store any type of python objects. - - -Construction of Batch -^^^^^^^^^^^^^^^^^^^^^ - -There are two ways to construct a ``Batch`` object: from a ``dict``, or using ``kwargs``. Below are some code snippets. - -.. raw:: html - -
- Construct Batch from dict - -.. code-block:: python - - >>> # directly passing a dict object (possibly nested) is ok - >>> data = Batch({'a': 4, 'b': [5, 5], 'c': '2312312'}) - >>> # the list will automatically be converted to numpy array - >>> data.b - array([5, 5]) - >>> data.b = np.array([3, 4, 5]) - >>> print(data) - Batch( - a: 4, - b: array([3, 4, 5]), - c: '2312312', - ) - >>> # a list of dict objects (possibly nested) will be automatically stacked - >>> data = Batch([{'a': 0.0, 'b': "hello"}, {'a': 1.0, 'b': "world"}]) - >>> print(data) - Batch( - a: array([0., 1.]), - b: array(['hello', 'world'], dtype=object), - ) - -.. raw:: html - -
-   </details><br>
- -.. raw:: html - -
- Construct Batch from kwargs - -.. code-block:: python - - >>> # construct a Batch with keyword arguments - >>> data = Batch(a=[4, 4], b=[5, 5], c=[None, None]) - >>> print(data) - Batch( - a: array([4, 4]), - b: array([5, 5]), - c: array([None, None], dtype=object), - ) - >>> # combining keyword arguments and batch_dict works fine - >>> data = Batch({'a':[4, 4], 'b':[5, 5]}, c=[None, None]) # the first argument is a dict, and 'c' is a keyword argument - >>> print(data) - Batch( - a: array([4, 4]), - b: array([5, 5]), - c: array([None, None], dtype=object), - ) - >>> arr = np.zeros((3, 4)) - >>> # By default, Batch only keeps the reference to the data, but it also supports data copying - >>> data = Batch(arr=arr, copy=True) # data.arr now is a copy of 'arr' - -.. raw:: html - -
-   </details><br>
- - -Data Manipulation With Batch -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Users can access the internal data by ``b.key`` or ``b[key]``, where ``b.key`` finds the sub-tree with ``key`` as the root node. If the result is a sub-tree with non-empty keys, the key-reference can be chained, i.e. ``b.key.key1.key2.key3``. When it reaches a leaf node, users get the data (scalars/tensors) stored in that ``Batch`` object. - -.. raw:: html - -
- Access data stored in Batch - -.. code-block:: python - - >>> data = Batch(a=4, b=[5, 5]) - >>> print(data.b) - [5 5] - >>> # obj.key is equivalent to obj["key"] - >>> print(data["a"]) - 4 - >>> # iterating over data items like a dict is supported - >>> for key, value in data.items(): - >>> print(f"{key}: {value}") - a: 4 - b: [5, 5] - >>> # obj.keys() and obj.values() work just like dict.keys() and dict.values() - >>> for key in data.keys(): - >>> print(f"{key}") - a - b - >>> # obj.update() behaves like dict.update() - >>> # this is the same as data.c = 1; data.c = 2; data.e = 3; - >>> data.update(c=1, d=2, e=3) - >>> print(data) - Batch( - a: 4, - b: array([5, 5]), - c: 1, - d: 2, - e: 3, - ) - -.. raw:: html - -
-   </details><br>
-
-
-.. note::
-
-    If ``data`` is a ``dict`` object, ``for x in data`` iterates over the keys of the dict. However, it has a different meaning for ``Batch`` objects: ``for x in data`` iterates over ``data[0], data[1], ..., data[-1]``. An example is given below.
-
-``Batch`` also partially reproduces the NumPy ndarray APIs. It supports advanced slicing, such as ``batch[:, i]``, so long as the slice is valid. NumPy's broadcast mechanism works for ``Batch``, too.
-
-.. raw:: html
-
-
- Length, shape, indexing, and slicing of Batch - -.. code-block:: python - - >>> # initialize Batch with tensors - >>> data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[[5, -5], [1, -2]]) - >>> # if values have the same length/shape, that length/shape is used for this Batch - >>> # else, check the advanced topic for details - >>> print(len(data)) - 2 - >>> print(data.shape) - [2, 2] - >>> # access the first item of all the stored tensors, while keeping the structure of Batch - >>> print(data[0]) - Batch( - a: array([0., 2.]) - b: array([ 5, -5]), - ) - >>> # iterates over ``data[0], data[1], ..., data[-1]`` - >>> for sample in data: - >>> print(sample.a) - [0. 2.] - [1. 3.] - - >>> # Advanced slicing works just fine - >>> # Arithmetic operations are passed to each value in the Batch, with broadcast enabled - >>> data[:, 1] += 1 - >>> print(data) - Batch( - a: array([[0., 3.], - [1., 4.]]), - b: array([[ 5, -4]]), - ) - - >>> # amazingly, you can directly apply np.mean to a Batch object - >>> print(np.mean(data)) - Batch( - a: 1.5, - b: -0.25, - ) - - >>> # directly converted to a list is also available - >>> list(data) - [Batch( - a: array([0., 3.]), - b: array([ 5, -4]), - ), - Batch( - a: array([1., 4.]), - b: array([ 1, -1]), - )] - -.. raw:: html - -
-   </details><br>
-
-
-Stacking and concatenating multiple ``Batch`` instances, or splitting an instance into multiple batches, is easy and intuitive in Tianshou. For now, we stick to the aggregation (stack/concatenate) of homogeneous (same-structure) batches. Stacking/concatenation of heterogeneous batches is discussed in :ref:`aggregation`.
-
-.. raw:: html
-
-
- Stack / Concatenate / Split of Batches - -.. code-block:: python - - >>> data_1 = Batch(a=np.array([0.0, 2.0]), b=5) - >>> data_2 = Batch(a=np.array([1.0, 3.0]), b=-5) - >>> data = Batch.stack((data_1, data_2)) - >>> print(data) - Batch( - b: array([ 5, -5]), - a: array([[0., 2.], - [1., 3.]]), - ) - >>> # split supports random shuffling - >>> data_split = list(data.split(1, shuffle=False)) - >>> print(list(data.split(1, shuffle=False))) - [Batch( - b: array([5]), - a: array([[0., 2.]]), - ), Batch( - b: array([-5]), - a: array([[1., 3.]]), - )] - >>> data_cat = Batch.cat(data_split) - >>> print(data_cat) - Batch( - b: array([ 5, -5]), - a: array([[0., 2.], - [1., 3.]]), - ) - -.. raw:: html - -
-   </details><br>
- - -Advanced Topics ---------------- - -From here on, this tutorial focuses on advanced topics of ``Batch``, including key reservation, length/shape, and aggregation of heterogeneous batches. - - -.. _key_reservations: - -Key Reservations -^^^^^^^^^^^^^^^^ - -.. sidebar:: The structure of a Batch with reserved keys - - .. Figure:: ../_static/images/batch_reserve.png - -In many cases, we know in the first place what keys we have, but we do not know the shape of values until we run the environment. To deal with this, Tianshou supports key reservations: **reserve a key and use a placeholder value**. - -The usage is easy: just use ``Batch()`` to be the value of reserved keys. - -.. code-block:: python - - a = Batch(b=Batch()) # 'b' is a reserved key - # this is called hierarchical key reservation - a = Batch(b=Batch(c=Batch()), d=Batch()) # 'c' and 'd' are reserved key - # the structure of this last Batch is shown in the right figure - a = Batch(key1=tensor1, key2=tensor2, key3=Batch(key4=Batch(), key5=Batch())) - -Still, we can use a tree (in the right) to show the structure of ``Batch`` objects with reserved keys, where reserved keys are special internal nodes that do not have attached leaf nodes. - -.. note:: - - Reserved keys mean that in the future there will eventually be values attached to them. The values can be scalars, tensors, or even **Batch** objects. Understanding this is critical to understand the behavior of ``Batch`` when dealing with heterogeneous Batches. - -The introduction of reserved keys gives rise to the need to check if a key is reserved. - -.. raw:: html - -
- Examples of checking whether Batch is empty - -.. code-block:: python - - >>> len(Batch().get_keys()) == 0 - True - >>> len(Batch(a=Batch(), b=Batch(c=Batch())).get_keys()) == 0 - False - >>> len(Batch(a=Batch(), b=Batch(c=Batch()))) == 0 - True - >>> len(Batch(d=1).get_keys()) == 0 - False - >>> len(Batch(a=np.float64(1.0)).get_keys()) == 0 - False - -.. raw:: html - -
-   </details><br>
-
-
-To check whether a Batch is empty, use ``len(Batch.get_keys()) == 0`` to test for direct emptiness (the object is just a ``Batch()``), or ``len(Batch) == 0`` to test for recursive emptiness (a ``Batch`` object without any scalar/tensor leaf nodes).
-
-.. note::
-
-    Do not confuse this with ``Batch.empty``. ``Batch.empty`` and its in-place variant ``Batch.empty_`` are used to set some values to zeros or None. Check the API documentation for further details.
-
-
-Length and Shape
-^^^^^^^^^^^^^^^^
-
-The most common usage of ``Batch`` is to store a batch of data. The term "Batch" comes from the deep learning community, where it denotes a mini-batch of samples drawn from the whole dataset. In this regard, "Batch" typically means a collection of tensors whose first dimensions are the same. Then the length of a ``Batch`` object is simply the batch-size.
-
-If all the leaf nodes in a ``Batch`` object are tensors but they have different lengths, they can still be readily stored in ``Batch``. However, for a ``Batch`` of this kind, the meaning of ``len(obj)`` is somewhat ambiguous. Currently, Tianshou returns the length of the shortest tensor, but we strongly recommend that users do not use the ``len(obj)`` operator on ``Batch`` objects with tensors of different lengths.
-
-.. raw:: html
-
-
- Examples of len and obj.shape for Batch objects - -.. code-block:: python - - >>> data = Batch(a=[5., 4.], b=np.zeros((2, 3, 4))) - >>> data.shape - [2] - >>> len(data) - 2 - >>> data[0].shape - [] - >>> len(data[0]) - TypeError: Object of type 'Batch' has no len() - -.. raw:: html - -
-   </details><br>
- -.. note:: - - Following the convention of scientific computation, scalars have no length. If there is any scalar leaf node in a ``Batch`` object, an exception will occur when users call ``len(obj)``. - - Besides, values of reserved keys are undetermined, so they have no length, neither. Or, to be specific, values of reserved keys have lengths of **any**. When there is a mix of tensors and reserved keys, the latter will be ignored in ``len(obj)`` and the minimum length of tensors is returned. When there is not any tensor in the ``Batch`` object, Tianshou raises an exception, too. - -The ``obj.shape`` attribute of ``Batch`` behaves somewhat similar to ``len(obj)``: - -1. If all the leaf nodes in a ``Batch`` object are tensors with the same shape, that shape is returned. - -2. If all the leaf nodes in a ``Batch`` object are tensors but they have different shapes, the minimum length of each dimension is returned. - -3. If there is any scalar value in a ``Batch`` object, ``obj.shape`` returns ``[]``. - -4. The shape of reserved keys is undetermined, too. We treat their shape as ``[]``. - - -.. _aggregation: - -Aggregation of Heterogeneous Batches -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In this section, we talk about aggregation operators (stack/concatenate) on heterogeneous ``Batch`` objects. -The following picture will give you an intuitive understanding of this behavior. It shows two examples of aggregation operators with heterogeneous ``Batch``. The shapes of tensors are annotated in the leaf nodes. - -.. image:: ../_static/images/aggregation.png - -We only consider the heterogeneity in the structure of ``Batch`` objects. The aggregation operators are eventually done by NumPy/PyTorch operators (``np.stack``, ``np.concatenate``, ``torch.stack``, ``torch.cat``). Heterogeneity in values can fail these operators (such as stacking ``np.ndarray`` with ``torch.Tensor``, or stacking tensors with different shapes) and an exception will be raised. - -The behavior is natural: for keys that are not shared across all batches, batches that do not have these keys will be padded by zeros (or ``None`` if the data type is ``np.object``). It can be written in the following scripts: -:: - - >>> # examples of stack: a is missing key `b`, and b is missing key `a` - >>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5]))) - >>> b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5]))) - >>> c = Batch.stack([a, b]) - >>> c.a.shape - (2, 4, 4) - >>> c.b.shape - (2, 4, 6) - >>> c.common.c.shape - (2, 4, 5) - >>> # None or 0 is padded with appropriate shape - >>> data_1 = Batch(a=np.array([0.0, 2.0])) - >>> data_2 = Batch(a=np.array([1.0, 3.0]), b='done') - >>> data = Batch.stack((data_1, data_2)) - >>> print(data) - Batch( - a: array([[0., 2.], - [1., 3.]]), - b: array([None, 'done'], dtype=object), - ) - >>> # examples of cat: a is missing key `b`, and b is missing key `a` - >>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5]))) - >>> b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5]))) - >>> c = Batch.cat([a, b]) - >>> c.a.shape - (7, 4) - >>> c.b.shape - (7, 3) - >>> c.common.c.shape - (7, 5) - -However, there are some cases when batches are too heterogeneous that they cannot be aggregated: -:: - - >>> a = Batch(a=np.zeros([4, 4])) - >>> b = Batch(a=Batch(b=Batch())) - >>> # this will raise an exception - >>> c = Batch.stack([a, b]) - -Then how to determine if batches can be aggregated? Let's rethink the purpose of reserved keys. 
What is the advantage of ``a1=Batch(b=Batch())`` over ``a2=Batch()``? The only difference is that ``a1.b`` returns ``Batch()`` but ``a2.b`` raises an exception. That's to say, **we reserve keys for attribute reference**. - -We say a key chain ``k=[key1, key2, ..., keyn]`` applies to ``b`` if the expression ``b.key1.key2.{...}.keyn`` is valid, and the result is ``b[k]``. - -For a set of ``Batch`` objects denoted as :math:`S`, they can be aggregated if there exists a ``Batch`` object ``b`` satisfying the following rules: - - 1. Key chain applicability: For any object ``bi`` in :math:`S`, and any key chain ``k``, if ``bi[k]`` is valid, then ``b[k]`` is valid. - - 2. Type consistency: If ``bi[k]`` is not ``Batch()`` (the last key in the key chain is not a reserved key), then the type of ``b[k]`` should be the same as ``bi[k]`` (both should be scalar/tensor/non-empty Batch values). - -The ``Batch`` object ``b`` satisfying these rules with the minimum number of keys determines the structure of aggregating :math:`S`. The values are relatively easy to define: for any key chain ``k`` that applies to ``b``, ``b[k]`` is the stack/concatenation of ``[bi[k] for bi in S]`` (if ``k`` does not apply to ``bi``, the appropriate size of zeros or ``None`` are filled automatically). If ``bi[k]`` are all ``Batch()``, then the aggregation result is also an empty ``Batch()``. - - -Miscellaneous Notes -^^^^^^^^^^^^^^^^^^^ - -1. ``Batch`` is serializable and therefore Pickle compatible. ``Batch`` objects can be saved to disk and later restored by the python ``pickle`` module. This pickle compatibility is especially important for distributed sampling from environments. - -.. raw:: html - -
- Batch.to_torch_ and Batch.to_numpy_ - -:: - - >>> data = Batch(a=np.zeros((3, 4))) - >>> data.to_torch_(dtype=torch.float32, device='cpu') - >>> print(data.a) - tensor([[0., 0., 0., 0.], - [0., 0., 0., 0.], - [0., 0., 0., 0.]]) - >>> # data.to_numpy_ is also available - >>> data.to_numpy_() - -.. raw:: html - -
-   </details><br>
- -2. It is often the case that the observations returned from the environment are all NumPy ndarray but the policy requires ``torch.Tensor`` for prediction and learning. In this regard, Tianshou provides helper functions to convert the stored data in-place into Numpy arrays or Torch tensors. - -3. ``obj.stack_([a, b])`` is the same as ``Batch.stack([obj, a, b])``, and ``obj.cat_([a, b])`` is the same as ``Batch.cat([obj, a, b])``. Considering the frequent requirement of concatenating two ``Batch`` objects, Tianshou also supports ``obj.cat_(a)`` to be an alias of ``obj.cat_([a])``. - -4. ``Batch.cat`` and ``Batch.cat_`` does not support ``axis`` argument as ``np.concatenate`` and ``torch.cat`` currently. - -5. ``Batch.stack`` and ``Batch.stack_`` support the ``axis`` argument so that one can stack batches besides the first dimension. But be cautious, if there are keys that are not shared across all batches, ``stack`` with ``axis != 0`` is undefined, and will cause an exception currently. diff --git a/docs/01_tutorials/04_tictactoe.rst b/docs/01_tutorials/04_tictactoe.rst deleted file mode 100644 index c5d6a87c0..000000000 --- a/docs/01_tutorials/04_tictactoe.rst +++ /dev/null @@ -1,650 +0,0 @@ -RL against random policy opponent with PettingZoo -================================================= - -Tianshou is compatible with `PettingZoo` environments for multi-agent RL, although does not directly provide facilities for multi-agent RL. Here are some helpful tutorial links: - -* https://pettingzoo.farama.org/tutorials/tianshou/beginner/ -* https://pettingzoo.farama.org/tutorials/tianshou/intermediate/ -* https://pettingzoo.farama.org/tutorials/tianshou/advanced/ - -In this section, we describe how to use Tianshou to implement RL in a multi-agent setting where, however, only one agent is trained, and the other one adopts a fixed random policy. -The user can then use this as a blueprint to replace the random policy with another trainable agent. -Specifically, we will design an algorithm to learn how to play `Tic Tac Toe `_ (see the image below) against a random opponent. - -.. image:: ../_static/images/tic-tac-toe.png - :align: center - - -Tic-Tac-Toe Environment ------------------------ - -The scripts are located at ``test/pettingzoo/``. We have implemented :class:`~tianshou.env.PettingZooEnv` which can wrap any `PettingZoo `_ environment. PettingZoo offers a 3x3 Tic-Tac-Toe environment, let's first explore it. -:: - - >>> from tianshou.env import PettingZooEnv # wrapper for PettingZoo environments - >>> from pettingzoo.classic import tictactoe_v3 # the Tic-Tac-Toe environment to be wrapped - >>> # This board has 3 rows and 3 cols (9 places in total) - >>> # Players place 'x' and 'o' in turn on the board - >>> # The player who first gets 3 consecutive 'x's or 'o's wins - >>> - >>> env = PettingZooEnv(tictactoe_v3.env(render_mode="human")) - >>> obs = env.reset() - >>> env.render() # render the empty board - board (step 0): - | | - - | - | - - _____|_____|_____ - | | - - | - | - - _____|_____|_____ - | | - - | - | - - | | - >>> print(obs) # let's see the shape of the observation - {'agent_id': 'player_1', 'obs': array([[[0, 0], - [0, 0], - [0, 0]], - - [[0, 0], - [0, 0], - [0, 0]], - - [[0, 0], - [0, 0], - [0, 0]]], dtype=int8), 'mask': [True, True, True, True, True, True, True, True, True]} - - -The observation variable ``obs`` returned from the environment is a ``dict``, with three keys ``agent_id``, ``obs``, ``mask``. This is a general structure in multi-agent RL where agents take turns. 
The meaning of these keys are: - -- ``agent_id``: the id of the current acting agent. In our Tic-Tac-Toe case, the agent_id can be ``player_1`` or ``player_2``. - -- ``obs``: the actual observation of the environment. In the Tic-Tac-Toe game above, the observation variable ``obs`` is a ``np.ndarray`` with the shape of (3, 3, 2). For ``player_1``, the first 3x3 plane represents the placement of Xs, and the second plane shows the placement of Os. The possible values for each cell are 0 or 1; in the first plane, 1 indicates that an X has been placed in that cell, and 0 indicates that X is not in that cell. Similarly, in the second plane, 1 indicates that an O has been placed in that cell, while 0 indicates that an O has not been placed. For ``player_2``, the observation is the same, but Xs and Os swap positions, so Os are encoded in plane 1 and Xs in plane 2. - -- ``mask``: the action mask in the current timestep. In board games or card games, the legal action set varies with time. The mask is a boolean array. For Tic-Tac-Toe, index ``i`` means the place of ``i/N`` th row and ``i%N`` th column. If ``mask[i] == True``, the player can place an ``x`` or ``o`` at that position. Now the board is empty, so the mask is all the true, contains all the positions on the board. - -.. note:: - - There is no special formulation of ``mask`` either in discrete action space or in continuous action space. You can also use some action spaces like ``gymnasium.spaces.Discrete`` or ``gymnasium.spaces.Box`` to represent the available action space. Currently, we use a boolean array. - -Let's play two steps to have an intuitive understanding of the environment. - -:: - - >>> import numpy as np - >>> action = 0 # action is either an integer, or an np.ndarray with one element - >>> obs, reward, done, info = env.step(action) # the env.step follows the api of Gymnasium - >>> print(obs) # notice the change in the observation - {'agent_id': 'player_2', 'obs': array([[[0, 1], - [0, 0], - [0, 0]], - - [[0, 0], - [0, 0], - [0, 0]], - - [[0, 0], - [0, 0], - [0, 0]]], dtype=int8), 'mask': [False, True, True, True, True, True, True, True, True]} - >>> # reward has two items, one for each player: 1 for win, -1 for lose, and 0 otherwise - >>> print(reward) - [0. 0.] - >>> print(done) # done indicates whether the game is over - False - >>> # info is always an empty dict in Tic-Tac-Toe, but may contain some useful information in environments other than Tic-Tac-Toe. - >>> print(info) - {} - -One worth-noting case is that the game is over when there is only one empty position, rather than when there is no position. This is because the player just has one choice (literally no choice) in this game. -:: - - >>> # omitted actions: 3, 1, 4 - >>> obs, reward, done, info = env.step(2) # player_1 wins - >>> print((reward, done)) - ([1, -1], True) - >>> env.render() - | | - X | O | - - _____|_____|_____ - | | - X | O | - - _____|_____|_____ - | | - X | - | - - | | - -After being familiar with the environment, let's try to play with random agents first! - - -Two Random Agents ------------------ - -.. sidebar:: The relationship between MultiAgentPolicyManager (Manager) and BasePolicy (Agent) - - .. Figure:: ../_static/images/marl.png - -Tianshou already provides some builtin classes for multi-agent learning. You can check out the API documentation for details. Here we use :class:`~tianshou.algorithm.MARLRandomPolicy` and :class:`~tianshou.algorithm.MultiAgentPolicyManager`. The figure on the right gives an intuitive explanation. 
- -:: - - >>> from tianshou.data import Collector - >>> from tianshou.env import DummyVectorEnv - >>> from tianshou.algorithm import RandomPolicy, MultiAgentPolicyManager - >>> - >>> # agents should be wrapped into one policy, - >>> # which is responsible for calling the acting agent correctly - >>> # here we use two random agents - >>> policy = MultiAgentPolicyManager( - >>> [RandomPolicy(action_space=env.action_space), RandomPolicy(action_space=env.action_space)], env - >>> ) - >>> - >>> # need to vectorize the environment for the collector - >>> env = DummyVectorEnv([lambda: env]) - >>> - >>> # use collectors to collect a episode of trajectories - >>> # the reward is a vector, so we need a scalar metric to monitor the training - >>> collector = Collector(policy, env) - >>> - >>> # you will see a long trajectory showing the board status at each timestep - >>> result = collector.collect(n_episode=1, render=.1) - (only show the last 3 steps) - | | - X | X | - - _____|_____|_____ - | | - X | O | - - _____|_____|_____ - | | - O | - | - - | | - | | - X | X | - - _____|_____|_____ - | | - X | O | - - _____|_____|_____ - | | - O | - | O - | | - | | - X | X | X - _____|_____|_____ - | | - X | O | - - _____|_____|_____ - | | - O | - | O - | | - -Random agents perform badly. In the above game, although agent 2 wins finally, it is clear that a smart agent 1 would place an ``x`` at row 4 col 4 to win directly. - - -Train one Agent against a random opponent ------------------------------------------ - -So let's start to train our Tic-Tac-Toe agent! First, import some required modules. -:: - - import argparse - import os - from copy import deepcopy - from typing import Optional, Tuple - - import gymnasium as gym - import numpy as np - import torch - from pettingzoo.classic import tictactoe_v3 - from torch.utils.tensorboard import SummaryWriter - - from tianshou.data import Collector, VectorReplayBuffer - from tianshou.env import DummyVectorEnv - from tianshou.env.pettingzoo_env import PettingZooEnv - from tianshou.algorithm import ( - BasePolicy, - DQNPolicy, - MultiAgentPolicyManager, - MARLRandomPolicy, - ) - from tianshou.trainer import OffpolicyTrainer - from tianshou.utils import TensorboardLogger - from tianshou.utils.net.common import MLPActor - -The explanation of each Tianshou class/function will be deferred to their first usages. Here we define some arguments and hyperparameters of the experiment. The meaning of arguments is clear by just looking at their names. 
-:: - - def get_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--eps-test', type=float, default=0.05) - parser.add_argument('--eps-train', type=float, default=0.1) - parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--lr', type=float, default=1e-4) - parser.add_argument( - '--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win' - ) - parser.add_argument('--n-step', type=int, default=3) - parser.add_argument('--target-update-freq', type=int, default=320) - parser.add_argument('--epoch', type=int, default=50) - parser.add_argument('--epoch_num_steps', type=int, default=1000) - parser.add_argument('--collection_step_num_env_steps', type=int, default=10) - parser.add_argument('--update-per-step', type=float, default=0.1) - parser.add_argument('--batch_size', type=int, default=64) - parser.add_argument( - '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128] - ) - parser.add_argument('--num_train_envs', type=int, default=10) - parser.add_argument('--num_test_envs', type=int, default=10) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.1) - parser.add_argument( - '--win-rate', - type=float, - default=0.6, - help='the expected winning rate: Optimal policy can get 0.7' - ) - parser.add_argument( - '--watch', - default=False, - action='store_true', - help='no training, ' - 'watch the play of pre-trained models' - ) - parser.add_argument( - '--agent-id', - type=int, - default=2, - help='the learned agent plays as the' - ' agent_id-th player. Choices are 1 and 2.' - ) - parser.add_argument( - '--resume-path', - type=str, - default='', - help='the path of agent pth file ' - 'for resuming from a pre-trained agent' - ) - parser.add_argument( - '--opponent-path', - type=str, - default='', - help='the path of opponent agent pth file ' - 'for resuming from a pre-trained agent' - ) - parser.add_argument( - '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' - ) - return parser - - def get_args() -> argparse.Namespace: - parser = get_parser() - return parser.parse_known_args()[0] - -.. sidebar:: The relationship between MultiAgentPolicyManager (Manager) and BasePolicy (Agent) - - .. Figure:: ../_static/images/marl.png - -The following ``get_agents`` function returns agents and their optimizers from either constructing a new policy, or loading from disk, or using the pass-in arguments. For the models: - -- The action model we use is an instance of :class:`~tianshou.utils.net.common.MLPActor`, essentially a multi-layer perceptron with the ReLU activation function; -- The network model is passed to a :class:`~tianshou.algorithm.DQNPolicy`, where actions are selected according to both the action mask and their Q-values; -- The opponent can be either a random agent :class:`~tianshou.algorithm.MARLRandomPolicy` that randomly chooses an action from legal actions, or it can be a pre-trained :class:`~tianshou.algorithm.DQNPolicy` allowing learned agents to play with themselves. - -Both agents are passed to :class:`~tianshou.algorithm.MultiAgentPolicyManager`, which is responsible to call the correct agent according to the ``agent_id`` in the observation. :class:`~tianshou.algorithm.MultiAgentPolicyManager` also dispatches data to each agent according to ``agent_id``, so that each agent seems to play with a virtual single-agent environment. 
- -Here it is: -:: - - def get_agents( - args: argparse.Namespace = get_args(), - agent_learn: Optional[BasePolicy] = None, - agent_opponent: Optional[BasePolicy] = None, - optim: Optional[torch.optim.Optimizer] = None, - ) -> Tuple[BasePolicy, torch.optim.Optimizer, list]: - env = get_env() - observation_space = env.observation_space['observation'] if isinstance( - env.observation_space, gym.spaces.Dict - ) else env.observation_space - args.state_shape = observation_space.shape or observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - if agent_learn is None: - # model - net = MLPActor( - args.state_shape, - args.action_shape, - hidden_sizes=args.hidden_sizes, - device=args.device - ).to(args.device) - if optim is None: - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - agent_learn = DQNPolicy( - model=net, - optim=optim, - gamma=args.gamma, - action_space=env.action_space, - estimate_space=args.n_step, - target_update_freq=args.target_update_freq - ) - if args.resume_path: - agent_learn.load_state_dict(torch.load(args.resume_path)) - - if agent_opponent is None: - if args.opponent_path: - agent_opponent = deepcopy(agent_learn) - agent_opponent.load_state_dict(torch.load(args.opponent_path)) - else: - agent_opponent = RandomPolicy(action_space=env.action_space) - - if args.agent_id == 1: - agents = [agent_learn, agent_opponent] - else: - agents = [agent_opponent, agent_learn] - policy = MultiAgentPolicyManager(agents, env) - return policy, optim, env.agents - -With the above preparation, we are close to the first learned agent. The following code is almost the same as the code in the DQN tutorial. - -:: - - def get_env(render_mode=None): - return PettingZooEnv(tictactoe_v3.env(render_mode=render_mode)) - - - def train_agent( - args: argparse.Namespace = get_args(), - agent_learn: Optional[BasePolicy] = None, - agent_opponent: Optional[BasePolicy] = None, - optim: Optional[torch.optim.Optimizer] = None, - ) -> Tuple[dict, BasePolicy]: - - # ======== environment setup ========= - train_envs = DummyVectorEnv([get_env for _ in range(args.num_train_envs)]) - test_envs = DummyVectorEnv([get_env for _ in range(args.num_test_envs)]) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) - train_envs.seed(args.seed) - test_envs.seed(args.seed) - - # ======== agent setup ========= - policy, optim, agents = get_agents( - args, agent_learn=agent_learn, agent_opponent=agent_opponent, optim=optim - ) - - # ======== collector setup ========= - train_collector = Collector( - policy, - train_envs, - VectorReplayBuffer(args.buffer_size, len(train_envs)), - exploration_noise=True - ) - test_collector = Collector(policy, test_envs, exploration_noise=True) - # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.num_train_envs) - - # ======== tensorboard logging setup ========= - log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn') - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - logger = TensorboardLogger(writer) - - # ======== callback functions used during training ========= - def save_best_fn(policy): - if hasattr(args, 'model_save_path'): - model_save_path = args.model_save_path - else: - model_save_path = os.path.join( - args.logdir, 'tic_tac_toe', 'dqn', 'policy.pth' - ) - torch.save( - policy.policies[agents[args.agent_id - 1]].state_dict(), model_save_path - ) - - def stop_fn(mean_rewards): - return mean_rewards >= args.win_rate - - def train_fn(epoch, env_step): - 
policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_train) - - def test_fn(epoch, env_step): - policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) - - def reward_metric(rews): - return rews[:, args.agent_id - 1] - - # trainer - result = OffpolicyTrainer( - policy, - train_collector, - test_collector, - args.epoch, - args.epoch_num_steps, - args.collection_step_num_env_steps, - args.num_test_envs, - args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_best_fn=save_best_fn, - update_per_step=args.update_per_step, - logger=logger, - test_in_train=False, - reward_metric=reward_metric - ).run() - - return result, policy.policies[agents[args.agent_id - 1]] - - # ======== a test function that tests a pre-trained agent ====== - def watch( - args: argparse.Namespace = get_args(), - agent_learn: Optional[BasePolicy] = None, - agent_opponent: Optional[BasePolicy] = None, - ) -> None: - env = get_env(render_mode="human") - env = DummyVectorEnv([lambda: env]) - policy, optim, agents = get_agents( - args, agent_learn=agent_learn, agent_opponent=agent_opponent - ) - policy.eval() - policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) - collector = Collector(policy, env, exploration_noise=True) - result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}") - - # train the agent and watch its performance in a match! - args = get_args() - result, agent = train_agent(args) - watch(args, agent) - -That's it. By executing the code, you will see a progress bar indicating the progress of training. After about less than 1 minute, the agent has finished training, and you can see how it plays against the random agent. Here is an example: - -.. raw:: html - -
- Play with random agent - -:: - - | | - - | - | - - _____|_____|_____ - | | - - | - | X - _____|_____|_____ - | | - - | - | - - | | - | | - - | - | - - _____|_____|_____ - | | - - | O | X - _____|_____|_____ - | | - - | - | - - | | - | | - - | - | - - _____|_____|_____ - | | - X | O | X - _____|_____|_____ - | | - - | - | - - | | - | | - - | O | - - _____|_____|_____ - | | - X | O | X - _____|_____|_____ - | | - - | - | - - | | - | | - - | O | - - _____|_____|_____ - | | - X | O | X - _____|_____|_____ - | | - - | X | - - | | - | | - O | O | - - _____|_____|_____ - | | - X | O | X - _____|_____|_____ - | | - - | X | - - | | - | | - O | O | X - _____|_____|_____ - | | - X | O | X - _____|_____|_____ - | | - - | X | - - | | - | | - O | O | X - _____|_____|_____ - | | - X | O | X - _____|_____|_____ - | | - - | X | O - | | - Final reward: 1.0, length: 8.0 - -.. raw:: html - -

- -Notice that, our learned agent plays the role of agent 2, placing ``o`` on the board. The agent performs pretty well against the random opponent! It learns the rule of the game by trial and error, and learns that three consecutive ``o`` means winning, so it does! - -The above code can be executed in a python shell or can be saved as a script file (we have saved it in ``test/pettingzoo/test_tic_tac_toe.py``). In the latter case, you can train an agent by - -.. code-block:: console - - $ python test_tic_tac_toe.py - -By default, the trained agent is stored in ``log/tic_tac_toe/dqn/policy.pth``. You can also make the trained agent play against itself, by - -.. code-block:: console - - $ python test_tic_tac_toe.py --watch --resume-path log/tic_tac_toe/dqn/policy.pth --opponent-path log/tic_tac_toe/dqn/policy.pth - -Here is our output: - -.. raw:: html - -
- The trained agent play against itself - -:: - - | | - - | - | - - _____|_____|_____ - | | - - | X | - - _____|_____|_____ - | | - - | - | - - | | - | | - - | O | - - _____|_____|_____ - | | - - | X | - - _____|_____|_____ - | | - - | - | - - | | - | | - X | O | - - _____|_____|_____ - | | - - | X | - - _____|_____|_____ - | | - - | - | - - | | - | | - X | O | - - _____|_____|_____ - | | - - | X | - - _____|_____|_____ - | | - - | - | O - | | - | | - X | O | - - _____|_____|_____ - | | - - | X | - - _____|_____|_____ - | | - - | X | O - | | - | | - X | O | O - _____|_____|_____ - | | - - | X | - - _____|_____|_____ - | | - - | X | O - | | - | | - X | O | O - _____|_____|_____ - | | - - | X | - - _____|_____|_____ - | | - X | X | O - | | - | | - X | O | O - _____|_____|_____ - | | - - | X | O - _____|_____|_____ - | | - X | X | O - | | - Final reward: 1.0, length: 8.0 - -.. raw:: html - -

- -Well, although the learned agent plays well against the random agent, it is far away from intelligence. - -Next, maybe you can try to build more intelligent agents by letting the agent learn from self-play, just like AlphaZero! - -In this tutorial, we show an example of how to use Tianshou for training a single agent in a MARL setting. Tianshou is a flexible and easy to use RL library. Make the best of Tianshou by yourself! diff --git a/docs/01_tutorials/05_logger.rst b/docs/01_tutorials/05_logger.rst deleted file mode 100644 index c8161374d..000000000 --- a/docs/01_tutorials/05_logger.rst +++ /dev/null @@ -1,66 +0,0 @@ -Logging Experiments -=================== - -Tianshou comes with multiple experiment tracking and logging solutions to manage and reproduce your experiments. -The dashboard loggers currently available are: - -* :class:`~tianshou.utils.TensorboardLogger` -* :class:`~tianshou.utils.WandbLogger` -* :class:`~tianshou.utils.LazyLogger` - - -TensorboardLogger ------------------ - -Tensorboard tracks your experiment metrics in a local dashboard. Here is how you can use TensorboardLogger in your experiment: - -:: - - from torch.utils.tensorboard import SummaryWriter - from tianshou.utils import TensorboardLogger - - log_path = os.path.join(args.logdir, args.task, "dqn") - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - logger = TensorboardLogger(writer) - result = trainer(..., logger=logger) - - -WandbLogger ------------ - -:class:`~tianshou.utils.WandbLogger` can be used to visualize your experiments in a hosted `W&B dashboard `_. It can be installed via ``pip install wandb``. You can also save your checkpoints in the cloud and restore your runs from those checkpoints. Here is how you can enable WandbLogger: - -:: - - from tianshou.utils import WandbLogger - from torch.utils.tensorboard import SummaryWriter - - logger = WandbLogger(...) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - logger.load(writer) - result = trainer(..., logger=logger) - -Please refer to :class:`~tianshou.utils.WandbLogger` documentation for advanced configuration. - -For logging checkpoints on any device, you need to define a ``save_checkpoint_fn`` which saves the experiment checkpoint and returns the path of the saved checkpoint: - -:: - - def save_checkpoint_fn(epoch, env_step, gradient_step): - ckpt_path = ... - # save model - return ckpt_path - -Then, use this function with ``WandbLogger`` to automatically version your experiment checkpoints after every ``save_interval`` step. - -For resuming runs from checkpoint artifacts on any device, pass the W&B ``run_id`` of the run that you want to continue in ``WandbLogger``. It will then download the latest version of the checkpoint and resume your runs from the checkpoint. - -The example scripts are under `test_psrl.py `_ and `atari_dqn.py `_. - - -LazyLogger ----------- - -This is a place-holder logger that does nothing. diff --git a/docs/01_tutorials/07_cheatsheet.rst b/docs/01_tutorials/07_cheatsheet.rst deleted file mode 100644 index c9f3e425d..000000000 --- a/docs/01_tutorials/07_cheatsheet.rst +++ /dev/null @@ -1,459 +0,0 @@ -Cheat Sheet -=========== - -**IMPORTANT**: The content here has not yet been adjusted to the v2 version of Tianshou. It is partially outdated and will be updated soon. - -This page shows some code snippets of how to use Tianshou to develop new -algorithms / apply algorithms to new scenarios. - -By the way, some of these issues can be resolved by using a ``gymnasium.Wrapper``. 
-It could be a universal solution in the policy-environment interaction. But -you can also use the batch processor :ref:`preprocess_fn` or vectorized -environment wrapper :class:`~tianshou.env.VectorEnvWrapper`. - - -.. _eval_policy: - -Manually Evaluate Policy ------------------------- - -If you'd like to manually see the action generated by a well-trained agent: -:: - - # assume obs is a single environment observation - action = policy(Batch(obs=np.array([obs]))).act[0] - - -.. _resume_training: - -Resume Training Process ------------------------ - -This is related to `Issue 349 `_. - -To resume training process from an existing checkpoint, you need to do the following things in the training process: - -1. Make sure you write ``save_checkpoint_fn`` which saves everything needed in the training process, i.e., policy, optim, buffer; pass it to trainer; -2. Use ``TensorboardLogger``; -3. To adjust the save frequency, specify ``save_interval`` when initializing TensorboardLogger. - -And to successfully resume from a checkpoint: - -1. Load everything needed in the training process **before trainer initialization**, i.e., policy, optim, buffer; -2. Set ``resume_from_log=True`` with trainer; - -We provide an example to show how these steps work: checkout `test_c51.py `_, `test_ppo.py `_ or `test_discrete_bcq.py `_ by running - -.. code-block:: console - - $ python3 test/discrete/test_c51.py # train some epoch - $ python3 test/discrete/test_c51.py --resume # restore from existing log and continuing training - - -To correctly render the data (including several tfevent files), we highly recommend using ``tensorboard >= 2.5.0`` (see `here `_ for the reason). Otherwise, it may cause overlapping issue that you need to manually handle with. - -.. _parallel_sampling: - -Parallel Sampling ------------------ - -Tianshou provides the following classes for vectorized environment: - -- :class:`~tianshou.env.DummyVectorEnv` is for pseudo-parallel simulation (implemented with a for-loop, useful for debugging). -- :class:`~tianshou.env.SubprocVectorEnv` uses multiple processes for parallel simulation. This is the most often choice for parallel simulation. -- :class:`~tianshou.env.ShmemVectorEnv` has a similar implementation to :class:`~tianshou.env.SubprocVectorEnv`, but is optimized (in terms of both memory footprint and simulation speed) for environments with large observations such as images. -- :class:`~tianshou.env.RayVectorEnv` is currently the only choice for parallel simulation in a cluster with multiple machines. - -Although these classes are optimized for different scenarios, they have exactly the same APIs because they are sub-classes of :class:`~tianshou.env.BaseVectorEnv`. Just provide a list of functions who return environments upon called, and it is all set. - -:: - - env_fns = [lambda x=i: MyTestEnv(size=x) for i in [2, 3, 4, 5]] - venv = SubprocVectorEnv(env_fns) # DummyVectorEnv, ShmemVectorEnv, or RayVectorEnv, whichever you like. - venv.reset() # returns the initial observations of each environment - venv.step(actions) # provide actions for each environment and get their results - -.. sidebar:: An example of sync/async VectorEnv (steps with the same color end up in one batch that is disposed by the policy at the same time). - - .. Figure:: ../_static/images/async.png - -By default, parallel environment simulation is synchronous: a step is done after all environments have finished a step. Synchronous simulation works well if each step of environments costs roughly the same time. 
- -In case the time cost of environments varies a lot (e.g. 90% step cost 1s, but 10% cost 10s) where slow environments lag fast environments behind, async simulation can be used (related to `Issue 103 `_). The idea is to start those finished environments without waiting for slow environments. - -Asynchronous simulation is a built-in functionality of -:class:`~tianshou.env.BaseVectorEnv`. Just provide ``wait_num`` or ``timeout`` -(or both) and async simulation works. - -:: - - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=x) for i in [2, 3, 4, 5]] - # DummyVectorEnv, ShmemVectorEnv, or RayVectorEnv, whichever you like. - venv = SubprocVectorEnv(env_fns, wait_num=3, timeout=0.2) - venv.reset() # returns the initial observations of each environment - # returns "wait_num" steps or finished steps after "timeout" seconds, - # whichever occurs first. - venv.step(actions, ready_id) - -If we have 4 envs and set ``wait_num = 3``, each of the step only returns 3 results of these 4 envs. - -You can treat the ``timeout`` parameter as a dynamic ``wait_num``. In each vectorized step it only returns the environments finished within the given time. If there is no such environment, it will wait until any of them finished. - -The figure in the right gives an intuitive comparison among synchronous/asynchronous simulation. - -.. note:: - - The async simulation collector would cause some exceptions when used as - ``test_collector`` in :doc:`/03_api/trainer/index` (related to - `Issue 700 `_). Please use - sync version for ``test_collector`` instead. - -.. warning:: - - If you use your own environment, please make sure the ``seed`` method is set up properly, e.g., - - :: - - def seed(self, seed): - np.random.seed(seed) - - Otherwise, the outputs of these envs may be the same with each other. - -.. _envpool_integration: - -EnvPool Integration -------------------- - -`EnvPool `_ is a C++-based vectorized environment implementation and is way faster than the above solutions. The APIs are almost the same as above four classes, so that means you can directly switch the vectorized environment to envpool and get immediate speed-up. - -Currently it supports -`Atari `_, -`Mujoco `_, -`VizDoom `_, -toy_text and classic_control environments. For more information, please refer to `EnvPool's documentation `_. - -:: - - # install envpool: pip3 install envpool - - import envpool - envs = envpool.make_gymnasium("CartPole-v1", num_envs=10) - collector = Collector(policy, envs, buffer) - -Here are some other `examples `_. - -.. _preprocess_fn: - -Handle Batched Data Stream in Collector ---------------------------------------- - -This is related to `Issue 42 `_. - -If you want to get log stat from data stream / pre-process batch-image / modify the reward with given env info, use ``preproces_fn`` in :class:`~tianshou.data.Collector`. This is a hook which will be called before the data adding into the buffer. - -It will receive with "obs" and "env_id" when the collector resets the environment, and will receive six keys "obs_next", "rew", "done", "info", "policy", "env_id" in a normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with the modified keys and values. - -These variables are intended to gather all the information requires to keep track of a simulation step, namely the (observation, action, reward, done flag, next observation, info, intermediate result of the policy) at time t, for the whole duration of the simulation. 
- -For example, you can write your hook as: -:: - - import numpy as np - from collections import deque - - - class MyProcessor: - def __init__(self, size=100): - self.episode_log = None - self.main_log = deque(maxlen=size) - self.main_log.append(0) - self.baseline = 0 - - def preprocess_fn(**kwargs): - """change reward to zero mean""" - # if obs && env_id exist -> reset - # if obs_next/act/rew/done/policy/env_id exist -> normal step - if 'rew' not in kwargs: - # means that it is called after env.reset(), it can only process the obs - return Batch() # none of the variables are needed to be updated - else: - n = len(kwargs['rew']) # the number of envs in collector - if self.episode_log is None: - self.episode_log = [[] for i in range(n)] - for i in range(n): - self.episode_log[i].append(kwargs['rew'][i]) - kwargs['rew'][i] -= self.baseline - for i in range(n): - if kwargs['done'][i]: - self.main_log.append(np.mean(self.episode_log[i])) - self.episode_log[i] = [] - self.baseline = np.mean(self.main_log) - return Batch(rew=kwargs['rew']) - -And finally, -:: - - test_processor = MyProcessor(size=100) - collector = Collector(policy, env, buffer, preprocess_fn=test_processor.preprocess_fn) - -Some examples are in `test/base/test_collector.py `_. - -Another solution is to create a vector environment wrapper through :class:`~tianshou.env.VectorEnvWrapper`, e.g. -:: - - import numpy as np - from collections import deque - from tianshou.env import VectorEnvWrapper - - class MyWrapper(VectorEnvWrapper): - def __init__(self, venv, size=100): - self.episode_log = None - self.main_log = deque(maxlen=size) - self.main_log.append(0) - self.baseline = 0 - - def step(self, action, env_id): - obs, rew, done, info = self.venv.step(action, env_id) - n = len(rew) - if self.episode_log is None: - self.episode_log = [[] for i in range(n)] - for i in range(n): - self.episode_log[i].append(rew[i]) - rew[i] -= self.baseline - for i in range(n): - if done[i]: - self.main_log.append(np.mean(self.episode_log[i])) - self.episode_log[i] = [] - self.baseline = np.mean(self.main_log) - return obs, rew, done, info - - env = MyWrapper(env, size=100) - collector = Collector(policy, env, buffer) - -We provide an observation normalization vector env wrapper: :class:`~tianshou.env.VectorEnvNormObs`. - - -.. _rnn_training: - -RNN-style Training ------------------- - -This is related to `Issue 19 `_. - -First, add an argument "stack_num" to :class:`~tianshou.data.ReplayBuffer`, :class:`~tianshou.data.VectorReplayBuffer`, or other types of buffer you are using, like: -:: - - buf = ReplayBuffer(size=size, stack_num=stack_num) - -Then, change the network to recurrent-style, for example, :class:`~tianshou.utils.net.common.Recurrent`, :class:`~tianshou.utils.net.continuous.RecurrentActorProb` and :class:`~tianshou.utils.net.continuous.RecurrentCritic`. - -The above code supports only stacked-observation. If you want to use stacked-action (for Q(stacked-s, stacked-a)), stacked-reward, or other stacked variables, you can add a ``gym.Wrapper`` to modify the state representation. For example, if we add a wrapper that map ``[s, a]`` pair to a new state: - -- Before: ``(s, a, s', r, d)`` stored in replay buffer, and get stacked s; -- After applying wrapper: ``([s, a], a, [s', a'], r, d)`` stored in replay buffer, and get both stacked s and a. - - -.. 
_multi_gpu: - -Multi-GPU Training ------------------- - -To enable training an RL agent with multiple GPUs for a standard environment (i.e., without nested observation) with default networks provided by Tianshou: - -1. Import :class:`~tianshou.utils.net.common.DataParallelNet` from ``tianshou.utils.net.common``; -2. Change the ``device`` argument to ``None`` in the existing networks such as ``MLPActor``, ``Actor``, ``Critic``, ``ActorProb`` -3. Apply ``DataParallelNet`` wrapper to these networks. - -:: - - from tianshou.utils.net.common import MLPActor, DataParallelNet - from tianshou.utils.net.discrete import Actor, Critic - - actor = DataParallelNet(Actor(net, args.action_shape, device=None).to(args.device)) - critic = DataParallelNet(Critic(net, device=None).to(args.device)) - -Yes, that's all! This general approach can be applied to almost all kinds of algorithms implemented in Tianshou. -We provide a complete script to show how to run multi-GPU: `test/discrete/test_ppo.py `_ - -As for other cases such as customized network or environments that have a nested observation, here are the rules: - -1. The data format transformation (numpy -> cuda) is done in the ``DataParallelNet`` wrapper; your customized network should not apply any kinds of data format transformation; -2. Create a similar class that inherit ``DataParallelNet``, which is only in charge of data format transformation (numpy -> cuda); -3. Do the same things above. - - -.. _self_defined_env: - -User-defined Environment and Different State Representation ------------------------------------------------------------ - -This is related to `Issue 38 `_ and `Issue 69 `_. - -First of all, your self-defined environment must follow the Gym's API, some of them are listed below: - -- reset() -> state - -- step(action) -> state, reward, done, info - -- seed(s) -> List[int] - -- render(mode) -> Any - -- close() -> None - -- observation_space: gym.Space - -- action_space: gym.Space - -The state can be a ``numpy.ndarray`` or a Python dictionary. Take "FetchReach-v1" as an example: -:: - - >>> e = gym.make('FetchReach-v1') - >>> e.reset() - {'observation': array([ 1.34183265e+00, 7.49100387e-01, 5.34722720e-01, 1.97805133e-04, - 7.15193042e-05, 7.73933014e-06, 5.51992816e-08, -2.42927453e-06, - 4.73325650e-06, -2.28455228e-06]), - 'achieved_goal': array([1.34183265, 0.74910039, 0.53472272]), - 'desired_goal': array([1.24073906, 0.77753463, 0.63457791])} - -It shows that the state is a dictionary which has 3 keys. It will stored in :class:`~tianshou.data.ReplayBuffer` as: -:: - - >>> from tianshou.data import Batch, ReplayBuffer - >>> b = ReplayBuffer(size=3) - >>> b.add(Batch(obs=e.reset(), act=0, rew=0, done=0)) - >>> print(b) - ReplayBuffer( - act: array([0, 0, 0]), - done: array([False, False, False]), - obs: Batch( - achieved_goal: array([[1.34183265, 0.74910039, 0.53472272], - [0. , 0. , 0. ], - [0. , 0. , 0. ]]), - desired_goal: array([[1.42154265, 0.62505137, 0.62929863], - [0. , 0. , 0. ], - [0. , 0. , 0. 
]]), - observation: array([[ 1.34183265e+00, 7.49100387e-01, 5.34722720e-01, - 1.97805133e-04, 7.15193042e-05, 7.73933014e-06, - 5.51992816e-08, -2.42927453e-06, 4.73325650e-06, - -2.28455228e-06], - [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00], - [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, - 0.00000000e+00]]), - ), - rew: array([0, 0, 0]), - ) - >>> print(b.obs.achieved_goal) - [[1.34183265 0.74910039 0.53472272] - [0. 0. 0. ] - [0. 0. 0. ]] - -And the data batch sampled from this replay buffer: -:: - - >>> batch, indices = b.sample(2) - >>> batch.keys() - ['act', 'done', 'info', 'obs', 'obs_next', 'policy', 'rew'] - >>> batch.obs[-1] - Batch( - achieved_goal: array([1.34183265, 0.74910039, 0.53472272]), - desired_goal: array([1.42154265, 0.62505137, 0.62929863]), - observation: array([ 1.34183265e+00, 7.49100387e-01, 5.34722720e-01, 1.97805133e-04, - 7.15193042e-05, 7.73933014e-06, 5.51992816e-08, -2.42927453e-06, - 4.73325650e-06, -2.28455228e-06]), - ) - >>> batch.obs.desired_goal[-1] # recommended - array([1.42154265, 0.62505137, 0.62929863]) - >>> batch.obs[-1].desired_goal # not recommended - array([1.42154265, 0.62505137, 0.62929863]) - >>> batch[-1].obs.desired_goal # not recommended - array([1.42154265, 0.62505137, 0.62929863]) - -Thus, in your self-defined network, just change the ``forward`` function as: -:: - - def forward(self, s, ...): - # s is a batch - observation = s.observation - achieved_goal = s.achieved_goal - desired_goal = s.desired_goal - ... - -For self-defined class, the replay buffer will store the reference into a ``numpy.ndarray``, e.g.: -:: - - >>> import networkx as nx - >>> b = ReplayBuffer(size=3) - >>> b.add(Batch(obs=nx.Graph(), act=0, rew=0, done=0)) - >>> print(b) - ReplayBuffer( - act: array([0, 0, 0]), - done: array([0, 0, 0]), - info: Batch(), - obs: array([, None, - None], dtype=object), - policy: Batch(), - rew: array([0, 0, 0]), - ) - -But the state stored in the buffer may be a shallow-copy. To make sure each of your state stored in the buffer is distinct, please return the deep-copy version of your state in your env: -:: - - def reset(): - return copy.deepcopy(self.graph) - def step(action): - ... - return copy.deepcopy(self.graph), reward, done, {} - -.. note :: - - Please make sure this variable is numpy-compatible, e.g., np.array([variable]) will not result in an empty array. Otherwise, ReplayBuffer cannot create an numpy array to store it. - - -.. _marl_example: - -Multi-Agent Reinforcement Learning ----------------------------------- - -This is related to `Issue 121 `_. The discussion is still goes on. - -With the flexible core APIs, Tianshou can support multi-agent reinforcement learning with minimal efforts. - -Currently, we support three types of multi-agent reinforcement learning paradigms: - -1. Simultaneous move: at each timestep, all the agents take their actions (example: MOBA games) - -2. Cyclic move: players take action in turn (example: Go game) - -3. Conditional move, at each timestep, the environment conditionally selects an agent to take action. (example: `Pig Game `_) - -We mainly address these multi-agent RL problems by converting them into traditional RL formulations. - -For simultaneous move, the solution is simple: we can just add a ``num_agent`` dimension to state, action, and reward. 
Nothing else is going to change. - -For 2 & 3 (cyclic move and conditional move), they can be unified into a single framework: at each timestep, the environment selects an agent with id ``agent_id`` to play. Since multi-agents are usually wrapped into one object (which we call "abstract agent"), we can pass the ``agent_id`` to the "abstract agent", leaving it to further call the specific agent. - -In addition, legal actions in multi-agent RL often vary with timestep (just like Go games), so the environment should also passes the legal action mask to the "abstract agent", where the mask is a boolean array that "True" for available actions and "False" for illegal actions at the current step. Below is a figure that explains the abstract agent. - -.. image:: /_static/images/marl.png - :align: center - :height: 300 - -The above description gives rise to the following formulation of multi-agent RL: -:: - - act = policy(state, agent_id, mask) - (next_state, next_agent_id, next_mask), reward = env.step(act) - -By constructing a new state ``state_ = (state, agent_id, mask)``, essentially we can return to the typical formulation of RL: -:: - - act = policy(state_) - next_state_, reward = env.step(act) - -Following this idea, we write a tiny example of playing `Tic Tac Toe `_ against a random player by using a Q-learning algorithm. The tutorial is at :doc:`/01_tutorials/04_tictactoe`. diff --git a/docs/01_tutorials/index.rst b/docs/01_tutorials/index.rst deleted file mode 100644 index f08b66f9f..000000000 --- a/docs/01_tutorials/index.rst +++ /dev/null @@ -1,2 +0,0 @@ -Tutorials -========= \ No newline at end of file diff --git a/docs/01_tutorials/00_training_process.md b/docs/01_user_guide/00_training_process.md similarity index 77% rename from docs/01_tutorials/00_training_process.md rename to docs/01_user_guide/00_training_process.md index 5fe77b17c..81693fdc0 100644 --- a/docs/01_tutorials/00_training_process.md +++ b/docs/01_user_guide/00_training_process.md @@ -1,4 +1,4 @@ -# Understanding the Reinforcement Learning Loop +# The Reinforcement Learning Process The following diagram illustrates the key mechanisms underlying the learning process in model-free reinforcement learning algorithms. It shows how the agent interacts with the environment, collects experiences, and periodically updates its policy based on those experiences. @@ -56,35 +56,36 @@ These entities have direct correspondences in Tianshou's codebase: * The environment is represented by an instance of a class that inherits from `gymnasium.Env`, which is a standard interface for reinforcement learning environments. In practice, environments are typically vectorized to enable parallel interactions, increasing efficiency. - * The policy is encapsulated in the `Policy` class, which provides methods for action selection. - * The replay buffer is implemented in the `ReplayBuffer` class. - A `Collector` instance is used to manage the addition of new experiences to the replay buffer as the agent interacts with the + * The policy is encapsulated in the {class}`~tianshou.algorithm.algorithm_base.Policy` class, which provides methods for action selection. + * The replay buffer is implemented in the {class}`~tianshou.data.buffer.buffer_base.ReplayBuffer` class. + A {class}`~tianshou.data.collector.Collector` instance is used to manage the addition of new experiences to the replay buffer as the agent interacts with the environment. 
- During the learning phase, the replay buffer may be sampled, providing an instance of `Batch` for the policy update. - * The abstraction for learning algorithms is given by the `Algorithm` class, which defines how to update the policy using data from the + During the learning phase, the replay buffer may be sampled, providing an instance of {class}`~tianshou.data.batch.Batch` for the policy update. + * The abstraction for learning algorithms is given by the {class}`~tianshou.algorithm.algorithm_base.Algorithm` class, which defines how to update the policy using data from the replay buffer. -## The Training Process +(structuring-the-process)= +## Structuring the Process -The learning process itself is reified in Tianshou's `Trainer` class, which orchestrates the interaction between the agent and the +The learning process itself is reified in Tianshou's {class}`~tianshou.trainer.trainer.Trainer` class, which orchestrates the interaction between the agent and the environment, manages the replay buffer, and coordinates the policy updates according to the specified learning algorithm. In general, the process can be described as executing a number of epochs as follows: -* **Epoch**: - * Repeat until a sufficient number of steps is reached (for online learning, typically environment step count) - * **Training Step**: - * For online learning algorithms … - * **Collection Step**: collect state transitions in the environment by running the agent - * (Optionally) conduct a test step if collected data indicates promising behaviour - * **Update Step**: Apply gradient updates using the algorithm’s update logic. +* **epoch**: + * repeat until a sufficient number of steps is reached (for online learning, typically environment step count) + * **training step**: + * for online learning algorithms … + * **collection step**: collect state transitions in the environment by running the agent + * (optionally) conduct a test step if collected data indicates promising behaviour + * **update step**: apply gradient updates using the algorithm’s update logic. The update is based on … * data from the preceding collection step only (on-policy learning) * data from the collection step and previous data (off-policy learning) * data from a user-provided replay buffer (offline learning) - * **Test Step** - * Collect test episodes from dedicated test environments and evaluate agent performance - * (Optionally) stop training early if performance is sufficiently high + * **test step** + * collect test episodes from dedicated test environments and evaluate agent performance + * (optionally) stop training early if performance is sufficiently high ```{admonition} Glossary :class: note @@ -97,4 +98,4 @@ Note that the above description encompasses several modes of model-free reinforc * off-policy learning (where the policy is updated based on data collected using the current and previous policies) * offline learning (where the replay buffer is pre-filled and not updated during training) -In Tianshou, the `Trainer` and `Algorithm` classes are specialised to handle these different modes accordingly. +In Tianshou, the {class}`~tianshou.trainer.trainer.Trainer` and {class}`~tianshou.algorithm.algorithm_base.Algorithm` classes are specialised to handle these different modes accordingly. 
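To make the epoch structure above more tangible, here is a minimal schematic sketch of one epoch of off-policy training. It is purely illustrative: the callables `collect_step`, `update_step` and `test_step` as well as the parameter names are hypothetical stand-ins for what the specialised `Trainer` classes do internally, not actual Tianshou interfaces.

```python
# Illustrative sketch of one epoch (assumed structure, not Tianshou's actual Trainer code).
def run_epoch(collect_step, update_step, test_step,
              epoch_num_steps: int, collection_step_num_env_steps: int,
              updates_per_collection: int) -> float:
    env_steps = 0
    while env_steps < epoch_num_steps:
        # collection step: run the agent to gather transitions into the replay buffer
        env_steps += collect_step(collection_step_num_env_steps)
        # update step: apply gradient updates using the algorithm's update logic
        for _ in range(updates_per_collection):
            update_step()
    # test step: evaluate the current policy on dedicated test environments
    return test_step()
```

The difference between the modes lies mainly in what happens around these steps: an off-policy trainer keeps accumulating data in the replay buffer across collection steps, an on-policy trainer clears the buffer after each update, and an offline trainer skips collection entirely and samples from a pre-filled buffer.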
diff --git a/docs/01_tutorials/01_apis.md b/docs/01_user_guide/01_apis.md similarity index 68% rename from docs/01_tutorials/01_apis.md rename to docs/01_user_guide/01_apis.md index c9f9d670a..c08cb9499 100644 --- a/docs/01_tutorials/01_apis.md +++ b/docs/01_user_guide/01_apis.md @@ -1,9 +1,9 @@ -# Tianshou's Dual API Architecture +# Dual APIs Tianshou provides two distinct APIs to serve different use cases and user preferences: -1. **High-Level API**: A declarative, configuration-based interface designed for ease of use -2. **Procedural API**: A flexible, imperative interface providing maximum control +1. **high-level API**: a declarative, configuration-based interface designed for ease of use +2. **procedural API**: a flexible, imperative interface providing maximum control Both APIs access the same underlying algorithm implementations, allowing you to choose the level of abstraction that best fits your needs without sacrificing functionality. @@ -18,12 +18,12 @@ you declare _what_ you want through configuration objects and let Tianshou handl build and execute the experiment. **Key characteristics:** -- Centered around `ExperimentBuilder` classes (e.g., `DQNExperimentBuilder`, `PPOExperimentBuilder`, etc.) -- Uses configuration dataclasses and factories for all relevant parameters -- Automatically handles component creation and "wiring" -- Provides sensible defaults that adapt to the nature of your environment -- Includes built-in persistence, logging, and experiment management -- Excellent IDE support with auto-completion +- centered around {class}`~tianshou.highlevel.experiment.ExperimentBuilder` classes (e.g., {class}`~tianshou.highlevel.experiment.DQNExperimentBuilder`, {class}`~tianshou.highlevel.experiment.PPOExperimentBuilder`, etc.) +- uses configuration dataclasses and factories for all relevant parameters +- automatically handles component creation and "wiring" +- provides sensible defaults that adapt to the nature of your environment +- includes built-in persistence, logging, and experiment management +- full type hints (but object structure is not flat; a proper IDE is required for seamless user experience) ### Procedural API @@ -32,31 +32,30 @@ You manually create environments, networks, policies, algorithms, collectors, an trainers, then wire them together. **Key characteristics:** -- Direct instantiation of all components -- Explicit control over the training loop -- Lower-level access to internal mechanisms -- Minimal abstraction (closer to the implementation) -- Ideal for algorithm development and research +- direct instantiation of all components +- explicit control over the training loop +- lower-level access to internal mechanisms +- minimal abstraction (closer to the implementation) +- ideal for algorithm development and research ## When to Use Which API -### Use the High-Level API when: +Use the high-level API when ... 
-- **You're applying existing algorithms** to new problems -- **You want to get started quickly** with minimal boilerplate -- **You need experiment management** with persistence, logging, and reproducibility -- **You prefer declarative code** that focuses on configuration -- **You're building applications** rather than developing new algorithms -- **You want strong IDE support** with auto-completion and type hints +- **you're applying existing algorithms** to new problems +- **you want to get started quickly** with minimal boilerplate +- **you need experiment management** with persistence, logging, and reproducibility +- **you prefer declarative code** that focuses on configuration +- **you're building applications** rather than developing new algorithms -### Use the Procedural API when: +Use the procedural API when: -- **You're developing new algorithms** or modifying existing ones -- **You need fine-grained control** over the training process -- **You want to understand** the internal workings of Tianshou -- **You're implementing custom components** not supported by the high-level API -- **You prefer imperative programming** where each step is explicit -- **You need maximum flexibility** for experimental research +- **you're developing new algorithms** or modifying existing ones +- **you need fine-grained control** over the training process +- **you want to understand** the internal workings of Tianshou +- **you're implementing custom components** not supported by the high-level API +- **you prefer imperative programming** where each step is explicit +- **you need maximum flexibility** for experimental research ## Comparison by Example @@ -123,18 +122,18 @@ experiment.run() ``` **What's happening here:** -1. We create an `ExperimentBuilder` with three main configuration objects +1. We create an {class}`~tianshou.highlevel.experiment.ExperimentBuilder` with three main configuration objects 2. We chain builder methods to specify algorithm parameters, model architecture, and callbacks 3. We call `.build()` to construct the experiment 4. We call `.run()` to execute the entire training pipeline -The high-level API handles: -- Creating and configuring environments -- Building the neural network -- Instantiating the policy and algorithm -- Setting up collectors and replay buffer -- Managing the training loop -- Watching the trained agent +The high-level API handles ... +- creating and configuring environments +- building the neural network +- instantiating the policy and algorithm +- setting up collectors and replay buffer +- managing the training loop +- watching the trained agent ### Procedural API Example @@ -240,18 +239,17 @@ collector.collect(n_episode=100, render=1 / 35) 8. We call `algorithm.run_training()` with explicit parameters 9. We manually set up and run the evaluation collector -The procedural API requires: -- Explicit creation of every component -- Manual extraction of environment properties -- Direct specification of all connections -- Custom callback function definitions +The procedural API requires ... +- explicit creation of every component +- manual extraction of environment properties +- direct specification of all connections ## Key Concepts in the High-Level API ### ExperimentBuilder -The `ExperimentBuilder` is the core abstraction. -Each algorithm has its own builder (e.g., `DQNExperimentBuilder`, `PPOExperimentBuilder`, `SACExperimentBuilder`). +The {class}`~tianshou.highlevel.experiment.ExperimentBuilder` is the core abstraction. 
+Each algorithm has its own builder (e.g., {class}`~tianshou.highlevel.experiment.DQNExperimentBuilder`, {class}`~tianshou.highlevel.experiment.PPOExperimentBuilder`, {class}`~tianshou.highlevel.experiment.SACExperimentBuilder`). **Some methods you will find in experiment builders:** - `.with__params()` - Set algorithm-specific parameters @@ -266,26 +264,26 @@ Each algorithm has its own builder (e.g., `DQNExperimentBuilder`, `PPOExperiment Three main configuration objects are required when constructing an experiment builder: -1. **Environment Configuration** (`EnvFactory` subclasses) +1. **Environment Configuration** ({class}`~tianshou.highlevel.env.EnvFactory` subclasses) - Defines how to create and configure environments - Existing factories: - - `EnvFactoryRegistered` - For the creation of environments registered in Gymnasium - - `AtariEnvFactory` - For Atari environments with preprocessing - - Custom factories for your own environments can be created by subclassing `EnvFactory` + - {class}`~tianshou.highlevel.env.EnvFactoryRegistered` - For the creation of environments registered in Gymnasium + - {class}`~tianshou.highlevel.env.atari.atari_wrapper.AtariEnvFactory` - For Atari environments with preprocessing + - Custom factories for your own environments can be created by subclassing {class}`~tianshou.highlevel.env.EnvFactory` -2. **Experiment Configuration** (`ExperimentConfig`): +2. **Experiment Configuration** ({class}`~tianshou.highlevel.experiment.ExperimentConfig`): General settings for the experiment, particularly related to - logging - randomization - persistence - watching the trained agent's performance after training -3. **Training Configuration** (`OffPolicyTrainingConfig`, `OnPolicyTrainingConfig`): +3. **Training Configuration** ({class}`~tianshou.highlevel.config.OffPolicyTrainingConfig`, {class}`~tianshou.highlevel.config.OnPolicyTrainingConfig`): Defines all parameters related to the training process ### Parameter Classes -Algorithm parameters are defined in dataclasses specific to each algorithm (e.g., `DQNParams`, `PPOParams`). +Algorithm parameters are defined in dataclasses specific to each algorithm (e.g., {class}`~tianshou.highlevel.params.algorithm_params.DQNParams`, {class}`~tianshou.highlevel.params.algorithm_params.PPOParams`). The parameters are extensively documented. ```{note} @@ -295,7 +293,7 @@ Make sure to use a modern IDE to take advantage of auto-completion and inline do ### Factories The high-level API uses factories extensively: -- **Model Factories**: Create neural networks (e.g., `IntermediateModuleFactoryAtariDQN()`) +- **Model Factories**: Create neural networks (e.g., {class}`~tianshou.highlevel.module.intermediate.IntermediateModuleFactoryAtariDQN`) - **Environment Factories**: Create and configure environments - **Optimizer Factories**: Create optimizers with specific configurations @@ -341,15 +339,15 @@ experiment = ( ### Core Components -You manually create and connect: +You manually create and connect ... -1. **Environments**: Using `gym.make()` and vectorization (`DummyVectorEnv`, `SubprocVectorEnv`) -2. **Networks**: Using `Net` or custom PyTorch modules -3. **Policies**: Using algorithm-specific policy classes (e.g., `DiscreteQLearningPolicy`) -4. **Algorithms**: Using algorithm classes (e.g., `DQN`, `PPO`, `SAC`) -5. **Collectors**: Using `Collector` to gather experience -6. **Buffers**: Using `VectorReplayBuffer` or `ReplayBuffer` -7. 
**Trainers**: Using the respective trainer class and corresponding parameter class (e.g., `OffPolicyTrainer` and `OffPolicyTrainerParams`) +1. **environments**: e.g. using `gym.make()` and vectorization ({class}`~tianshou.env.DummyVectorEnv`, {class}`~tianshou.env.SubprocVectorEnv`) +2. **networks**: using {class}`~tianshou.utils.net.common.Net` or other PyTorch modules +3. **policies**: using algorithm-specific policy classes (e.g., {class}`~tianshou.algorithm.modelfree.dqn.DiscreteQLearningPolicy`) +4. **algorithms**: using algorithm classes (e.g., {class}`~tianshou.algorithm.modelfree.dqn.DQN`, {class}`~tianshou.algorithm.modelfree.ppo.PPO`, {class}`~tianshou.algorithm.modelfree.sac.SAC`) +5. **collectors**: using {class}`~tianshou.data.Collector` to gather experience +6. **buffers**: using {class}`~tianshou.data.buffer.VectorReplayBuffer` or {class}`~tianshou.data.buffer.ReplayBuffer` +7. **trainers**: using the respective trainer class and corresponding parameter class (e.g., {class}`~tianshou.trainer.OffPolicyTrainer` and {class}`~tianshou.trainer.OffPolicyTrainerParams`) ### Training Loop @@ -357,20 +355,7 @@ The training is executed via `algorithm.run_training()`, which takes a trainer p You can alternatively implement custom training loops (or even your own trainer class) for maximum flexibility. -## Choosing Your Path - -**Use the high-level API** if ... -- you are new to Tianshou, -- you are focused on applying RL to problems, -- you prefer declarative code. - -**Use the procedural API** if ... -- you are developing new algorithms, -- you need maximum flexibility, -- you are comfortable with RL internals, -- you prefer imperative code. - ## Additional Resources -- **High-Level API Examples**: See `examples/` directory (scripts ending in `_hl.py`) -- **Procedural API Examples**: See `examples/` directory (scripts without suffix) +- **high-Level API examples**: See `examples/` directory (scripts ending in `_hl.py`) +- **procedural API examples**: See `examples/` directory (scripts without suffix) diff --git a/docs/01_user_guide/02_core_abstractions.md b/docs/01_user_guide/02_core_abstractions.md new file mode 100644 index 000000000..95e7ecc77 --- /dev/null +++ b/docs/01_user_guide/02_core_abstractions.md @@ -0,0 +1,319 @@ +# Core Abstractions + +Tianshou's architecture is built around a number of key abstractions that work together to provide a modular and flexible reinforcement learning framework. +This document describes the conceptual foundation and functionality of each abstraction, helping you understand how they interact to enable RL agent training. + +Knowing these abstractions is primarily relevant when using the procedural API – and particularly when implementing one's own learning algoriithms. + +## Algorithm + +The **{class}`~tianshou.algorithm.algorithm_base.Algorithm`** is the central abstraction representing the core of a reinforcement learning method (such as DQN, PPO, or SAC). +It implements the key steps within the {ref}`learning process `, containing a {ref}`policy` and defining how to update it from experience data. + +Since an Algorithm contains neural networks and manages their training, the class inherits from `torch.nn.Module`. + +### Core Responsibilities + +An Algorithm implements the details of an {ref}`update step `: + +1. **preprocessing**: Before the actual update begins, the algorithm prepares the training data. 
+ This includes computing derived quantities that depend on temporal sequences, such as n-step returns, GAE advantages, or terminal state handling. + The {meth}`~tianshou.algorithm.algorithm_base.Algorithm._preprocess_batch` method handles this phase, often leveraging static methods like + {meth}`~tianshou.algorithm.algorithm_base.Algorithm.compute_nstep_return` and + {meth}`~tianshou.algorithm.algorithm_base.Algorithm.compute_episodic_return` to + efficiently compute returns using the buffer's temporal structure. + +2. **network update**: The algorithm performs the actual neural network updates based on its specific learning method. + Each algorithm implements its own {meth}`~tianshou.algorithm.algorithm_base.Algorithm._update_with_batch` logic that defines how to update + the policy networks using the preprocessed batch data. + +3. **postprocessing**: After the update, the algorithm may perform cleanup operations, such as updating prioritized replay buffer weights or other + algorithm-specific bookkeeping. + +### Learning Orchestration + +The Algorithm orchestrates the {ref}`update step ` through its +{meth}`~tianshou.algorithm.algorithm_base.Algorithm.update` method, which ensures these three phases execute in proper sequence. +It also manages optimizer state and learning rate schedulers, making them available for state persistence through +{meth}`~tianshou.algorithm.algorithm_base.Algorithm.state_dict` and +{meth}`~tianshou.algorithm.algorithm_base.Algorithm.load_state_dict` methods. + +Each algorithm type (on-policy, off-policy, offline) creates its appropriate trainer through the +{meth}`~tianshou.algorithm.algorithm_base.Algorithm.create_trainer` method, +establishing the connection between the learning logic and the training loop. + +(policy)= +## Policy + +The **{class}`~tianshou.algorithm.algorithm_base.Policy`** represents the agent's decision-making component, i.e. the mapping from observations to actions. +While the Algorithm defines how to learn, the Policy defines what is learned and how to act. + +Like Algorithm, the class inherits from `torch.nn.Module`. + +### States of Operation + +A Policy operates in two main modes: + +- **training mode**: During training, the policy may employ exploration strategies, sample from action distributions, or add noise to encourage discovery. + Training mode is further divided into: + - *collecting state*: When gathering experience from environment interaction + - *updating state*: When performing network updates during learning + +- **testing/inference mode**: During evaluation, the policy typically acts deterministically or uses the mode of predicted distributions to showcase + learned behavior without exploration. + +The flag `is_within_training_step` controls the collection strategy, distinguishing between training and inference behavior. + +### Key Methods + +The Policy provides several essential methods: + +- **{meth}`~tianshou.algorithm.algorithm_base.Policy.forward`**: + The core computation method that processes batched observations to produce action distributions or Q-values. + It takes a batch of environment data and optional hidden state (for recurrent policies), + returning a batch containing at minimum the "act" key, + and potentially "state" (hidden state) and "policy" (intermediate results to be stored in the buffer). + +- **{meth}`~tianshou.algorithm.algorithm_base.Policy.compute_action`**: + A convenient method for inference that takes a single observation and returns a concrete action suitable for the environment. 
+ This method internally calls `forward` with proper batching and unbatching. + +- **{meth}`~tianshou.algorithm.algorithm_base.Policy.map_action`**: Transforms the raw neural network output to the environment's action space format, handling any necessary scaling or discretization. + +The separation between `forward` (which works with batches) and +{meth}`~tianshou.algorithm.algorithm_base.Policy.compute_action` (which works with single observations) provides efficiency +during training and convenience during inference. + +## Collector + +The class **{class}`~tianshou.data.Collector`** bridges the gap between the policy and the environment(s), +managing the process of gathering experience data. +It enables efficient interaction with both single environments and vectorized environments (multiple parallel environments). + +### Data Collection + +The Collector's primary method, {meth}`~tianshou.data.Collector.collect`, orchestrates the environment interaction loop. It can collect either: +- a specified number of steps (`n_step`): useful for maintaining consistent training batch sizes +- a specified number of episodes (`n_episode`): useful for evaluation or when episode-level statistics are important + +During collection, the Collector ... +1. obtains observations from the environment(s), +2. calls the policy to compute actions, +3. steps the environment(s) with these actions, +4. stores the resulting transitions (observation, action, reward, next observation, termination flags, and info) in the replay buffer, +5. manages episode boundaries and reset logic, +6. collects statistics such as episode returns, lengths, and collection speed. + +### Hooks and Extensibility + +The Collector supports customization through hooks that can be triggered at different points in the collection process: +- **step hooks**: called after each environment step +- **episode done hooks**: called when episodes complete + +These hooks enable custom logging, curriculum learning, or other dynamic behaviors during data collection. + +### Vectorized Environments + +The Collector seamlessly handles vectorized environments, where multiple environment instances run in parallel. +This significantly speeds up data collection while maintaining correct episode boundaries and statistics for each environment instance. + +## Trainer + +The **{class}`~tianshou.trainer.Trainer`** orchestrates the complete training loop, coordinating data collection, policy updates, and evaluation. +It provides the high-level control flow that brings all components together. + +### Trainer Types + +Tianshou provides three main trainer types, each suited to different algorithm families: + +- **{class}`~tianshou.trainer.OnPolicyTrainer`**: for algorithms that must learn from freshly collected data (e.g., PPO, A2C). + After each collection phase, the buffer is used for updates and thereafter is cleared. + +- **{class}`~tianshou.trainer.OffPolicyTrainer`**: for algorithms that can learn from any past experience (e.g., DQN, SAC, DDPG). + Data accumulates in the replay buffer over time, and updates sample from this growing pool of experience. + +- **{class}`~tianshou.trainer.OfflineTrainer`**: for algorithms that learn exclusively from a fixed dataset without any environment interaction (e.g., BCQ, CQL). + +### Training Loop Structure + +The training process is organized into epochs, where each epoch consists of: + +1. **data collection**: The trainer uses the train collector to gather experience according to its algorithm type's needs +2. 
**policy update**: The algorithm performs one or more update steps using the collected data +3. **evaluation**: Periodically, the trainer uses the test collector to evaluate the current policy's performance +4. **logging**: Statistics from collection, updates, and evaluation are logged +5. **checkpointing**: The best policy (according to a scoring function) is saved + +The trainer handles the detailed choreography of these steps, including determining when to collect more data, +how many update steps to perform, when to evaluate, and when to stop training (based on maximum epochs, timesteps, or early stopping criteria). + +### Configuration + +Trainers are configured through parameter dataclasses +({class}`~tianshou.trainer.OnPolicyTrainerParams`, {class}`~tianshou.trainer.OffPolicyTrainerParams`, {class}`~tianshou.trainer.OfflineTrainerParams`) +that specify in particular: +- training duration (number of epochs, steps per epoch) +- collectors for training and testing +- update frequency and batch size +- evaluation frequency +- logging and checkpointing settings +- early stopping criteria + +## Batch + +The class **{class}`~tianshou.data.Batch`** is Tianshou's flexible data structure for passing information between components. +It serves as the lingua franca of the framework, carrying everything from raw environment observations to computed returns and policy outputs. + +### Design Philosophy + +Batch is designed to be ... +- **flexible**: can contain any key-value pairs, with nested structures supported, +- **numpy/torch-compatible**: automatically converts lists to arrays and seamlessly works with both NumPy arrays and PyTorch tensors, +- **sliceable**: supports indexing and slicing operations that work across all contained data, +- **composable**: can be concatenated, stacked, and split to support batching operations. + +### Type Safety with BatchProtocol + +While `Batch` provides a flexible, dictionary-like structure for holding arbitrary data, this flexibility can make it challenging to statically type-check which attributes are present in a batch at any given point in the code. To address this, Tianshou uses **{class}`~tianshou.data.batch.BatchProtocol`** and derived protocols to specify the expected attributes while keeping the actual runtime type as `Batch`. + +BatchProtocol is a Python `Protocol` (from `typing.Protocol`) that defines the interface of a Batch object, specifying which operations and attributes should be available. 
More importantly, Tianshou provides a rich set of derived protocols in {mod}`tianshou.data.types` that describe batches with specific sets of attributes commonly used throughout the framework: + +- **{class}`~tianshou.data.types.ObsBatchProtocol`**: Contains `obs` and `info` - the minimal batch for policy forward passes +- **{class}`~tianshou.data.types.RolloutBatchProtocol`**: Adds `obs_next`, `act`, `rew`, `terminated`, and `truncated` - typical data from replay buffer sampling +- **{class}`~tianshou.data.types.BatchWithReturnsProtocol`**: Extends RolloutBatchProtocol with `returns` computed from rewards +- **{class}`~tianshou.data.types.BatchWithAdvantagesProtocol`**: Includes `adv` (advantages) and `v_s` (value estimates) for policy gradient methods +- **{class}`~tianshou.data.types.ActStateBatchProtocol`**: Contains `act` and `state` for policy outputs, especially with RNN support +- **{class}`~tianshou.data.types.ModelOutputBatchProtocol`**: Adds `logits` to action and state information +- **{class}`~tianshou.data.types.DistBatchProtocol`**: Contains action distributions (`dist`) for stochastic policies +- **{class}`~tianshou.data.types.PrioBatchProtocol`**: Includes `weight` for prioritized experience replay + +These protocols serve as type hints in function signatures throughout Tianshou, making it explicit what attributes are expected and available. For example, a policy's `forward` method might accept an `ObsBatchProtocol` and return an `ActStateBatchProtocol`, clearly documenting the data contract. Despite these type annotations, the actual objects remain flexible `Batch` instances at runtime, preserving Tianshou's dynamic nature while improving code clarity and IDE support. + +### Common Use Cases + +Batches flow through the system carrying different types of information: + +1. **environment data**: observations, rewards, done flags, and info from environment steps +2. **policy outputs**: actions, hidden states, and intermediate computations +3. **training data**: returns, advantages, and other computed quantities needed for learning +4. **sampling results**: batches sampled from the replay buffer for training + +### Operations + +Key operations on batches include: +- **attribute access**: dot notation (`batch.obs`) or dictionary-style access (`batch['obs']`) +- **slicing**: extract subsets with standard indexing (`batch[0:10]`, `batch[[1,3,5]]`) +- **stacking**: combine multiple batches along a new dimension +- **type conversion**: convert between NumPy and PyTorch with `to_numpy()` and `to_torch()` +- **null handling**: detect and remove null values with `hasnull()`, `isnull()`, and `dropnull()` + +The first dimension of all data in a Batch represents the batch size, enabling vectorized operations. + +## Buffer + +A **buffer** (i.e. class {class}`~tianshou.data.buffer.ReplayBuffer` and its variants) manages the storage and retrieval of experience data. +It acts as the memory of the learning system, preserving the temporal structure of episodes while providing efficient access patterns. + +### Storage Structure + +Buffers store data in a circular queue fashion with a fixed maximum size. When the buffer fills, new data overwrites the oldest stored experiences. 
+All data is stored within a single underlying Batch object, with the buffer managing: +- **pointer tracking**: current insertion position +- **episode boundaries**: which transitions belong to which episodes +- **temporal relationships**: the sequential order of transitions + +### Reserved Keys + +Buffers use a standard set of keys for storing transitions: +- `obs`: Observation at time t +- `act`: Action taken at time t +- `rew`: Reward received at time t +- `terminated`: True if the episode ended naturally at time t +- `truncated`: True if the episode was cut off at time t (e.g., time limit) +- `done`: Automatically inferred as `terminated or truncated` +- `obs_next`: Observation at time t+1 +- `info`: Additional information from the environment +- `policy`: Intermediate policy computations to be stored + +### Core Operations + +**adding data**: The {meth}`~tianshou.data.buffer.buffer_base.ReplayBuffer.add` method stores new transitions, +automatically handling episode boundaries and computing episode statistics (return, length) +when episodes complete. + +**sampling**: The {meth}`~tianshou.data.buffer.buffer_base.ReplayBuffer.sample` method retrieves batches of experiences for training, +returning both the sampled batch and the corresponding indices. +The sample size can be specified, or set to 0 to retrieve all available data. + +**temporal navigation**: The {meth}`~tianshou.data.buffer.buffer_base.ReplayBuffer.prev` and {meth}`~tianshou.data.buffer.ReplayBuffer.next` +methods enable traversal along the temporal sequence, respecting episode boundaries. +This is essential for computing n-step returns and other time-dependent quantities. + +**persistence**: Buffers support saving and loading via pickle or HDF5 format, enabling dataset collection and offline learning. + +### Buffer Variants + +Tianshou provides specialized buffer types: + +- **{class}`~tianshou.data.buffer.buffer_base.ReplayBuffer`**: the standard buffer for single environments +- **{class}`~tianshou.data.buffer.vecbuf.VectorReplayBuffer`**: manages separate sub-buffers for multiple parallel environments while maintaining chronological order +- **{class}`~tianshou.data.buffer.prio.PrioritizedReplayBuffer`**: samples transitions based on their TD-error or other priority metrics, using an efficient segment tree implementation + +### Advanced Features + +Buffers support sophisticated use cases: +- **frame stacking**: automatically stacks consecutive observations (useful for RNN inputs or Atari) +- **memory optimization**: option to skip storing next observations (useful for Atari where they can be inferred) +- **multi-modal observations**: handle observations with multiple components (e.g., image + vector) + +## Logger + +The **{class}`~tianshou.utils.logger.logger_base.BaseLogger`** abstraction provides a unified interface for recording and tracking training progress, metrics, and statistics. +It decouples the training loop from the specifics of where and how data is logged. 
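+For orientation, here is a minimal sketch (the import path follows the class references below; the log directory name is an arbitrary example) of constructing a concrete logger, which is then handed to the trainer:
+
+```python
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.utils.logger.tensorboard import TensorboardLogger
+
+# Wrap a SummaryWriter so that train/test/update/info data is written
+# to TensorBoard event files under the given directory.
+logger = TensorboardLogger(SummaryWriter("log/experiment"))
+```
+
+Once constructed, the logger is passed to the trainer, which calls the scope-specific logging methods described below as training progresses.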
+ +### Purpose + +Loggers serve several essential functions: +- **progress tracking**: record timesteps, episodes, and epochs as training progresses +- **metric collection**: store performance indicators like rewards, losses, and success rates +- **experiment organization**: manage different data scopes (training, testing, updating) +- **reproducibility**: save training curves and hyperparameters for later analysis + +### Logging Scopes + +The framework organizes logged data into distinct scopes: +- **train data**: metrics from the training collector (episode returns, steps, collection speed) +- **test data**: evaluation metrics from the test collector +- **update data**: learning statistics from the algorithm (losses, gradients, learning rates) +- **info data**: additional custom metrics or metadata + +Each scope has a corresponding log method (`log_train_data`, `log_test_data`, `log_update_data`, `log_info_data`) that the trainer calls at appropriate times. + +### Implementations + +Tianshou provides several logger implementations: +- **{class}`~tianshou.utils.logger.tensorboard.TensorboardLogger`**: writes TensorBoard event files for visualization with TensorBoard +- **{class}`~tianshou.utils.logger.wandb.WandbLogger`**: integrates with Weights & Biases for cloud-based experiment tracking + +All implementations inherit from {class}`~tianshou.utils.logger.logger_base.BaseLogger` and share a common interface, +making it easy to switch between logging backends or use multiple loggers simultaneously. + +### Data Preparation + +Before writing, loggers prepare data through the `prepare_dict_for_logging` method, which can filter, transform, or aggregate metrics. +The `write` method then persists the prepared data to the logging backend with an associated step count. + +## How They Work Together + +These seven abstractions collaborate to enable reinforcement learning: + +1. The **Trainer** initializes and orchestrates the training process. +2. The **Collector** uses the **Policy** to gather experience from environments. +3. Collected transitions are stored in the **Buffer** and later extracted as **Batches**. +4. The **Algorithm** samples from the **Buffer**, preprocesses the data, and updates the **Policy**. +5. The **Logger** records metrics throughout the process. +6. The cycle repeats until training completes. + +This modular design allows each component to focus on its specific responsibility while maintaining clean interfaces. +You can customize individual components (e.g., implementing a new Algorithm or Buffer) without affecting the others, +making Tianshou both powerful and flexible. diff --git a/docs/01_user_guide/index.rst b/docs/01_user_guide/index.rst new file mode 100644 index 000000000..b9210906c --- /dev/null +++ b/docs/01_user_guide/index.rst @@ -0,0 +1,5 @@ +User Guide +========== + +The user guide provides an introduction to core concepts, establishes the glossary of terms, +introduces Tianshou's dual API architecture, and provides an overview of important abstractions. \ No newline at end of file diff --git a/docs/02_deep_dives/0_intro.md b/docs/02_deep_dives/0_intro.md new file mode 100644 index 000000000..72762477f --- /dev/null +++ b/docs/02_deep_dives/0_intro.md @@ -0,0 +1,4 @@ +# Deep Dives + +Our deep dives are a collection of executable tutorials on some of the internal representations used by Tianshou. +They are provided as notebooks, which you can run directly in Colab or download and run locally. 
diff --git a/docs/02_deep_dives/L1_Batch.ipynb b/docs/02_deep_dives/L1_Batch.ipynb new file mode 100644 index 000000000..3c87344eb --- /dev/null +++ b/docs/02_deep_dives/L1_Batch.ipynb @@ -0,0 +1,1471 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Batch: Tianshou's Core Data Structure\n", + "\n", + "The `Batch` class is Tianshou's fundamental data structure for efficiently storing and manipulating heterogeneous data in reinforcement learning. This tutorial provides comprehensive guidance on understanding its conceptual foundations, operational behavior, and best practices.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pickle\n", + "from typing import cast\n", + "\n", + "import numpy as np\n", + "import torch\n", + "from torch.distributions import Categorical, Normal\n", + "\n", + "from tianshou.data import Batch\n", + "from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Introduction: Why Batch?\n", + "\n", + "### The Challenge in Reinforcement Learning\n", + "\n", + "Reinforcement learning algorithms face a fundamental data management challenge:\n", + "\n", + "1. **Diverse Data Requirements**: Different RL algorithms need different data fields:\n", + " - Basic algorithms: `state`, `action`, `reward`, `done`, `next_state`\n", + " - Actor-Critic: additionally `advantages`, `returns`, `values`\n", + " - Policy Gradient: additionally `log_probs`, `old_log_probs`\n", + " - Off-policy: additionally `priority_weights`\n", + "\n", + "2. **Heterogeneous Observation Spaces**: Environments return diverse observation types:\n", + " - Simple: vectors (`np.array([1.0, 2.0, 3.0])`)\n", + " - Complex: images (`np.array(shape=(84, 84, 3))`)\n", + " - Hybrid: dictionaries combining multiple modalities\n", + " ```python\n", + " obs = {\n", + " 'camera': np.array(shape=(64, 64, 3)),\n", + " 'velocity': np.array([1.2, 0.5]),\n", + " 'inventory': np.array([5, 2, 0])\n", + " }\n", + " ```\n", + "\n", + "3. 
**Data Flow Across Components**: Data must flow seamlessly through:\n", + " - Collectors (gathering experience from environments)\n", + " - Replay Buffers (storing and sampling transitions)\n", + " - Policies and Algorithms (learning and inference)\n", + "\n", + "### Why Not Alternatives?\n", + "\n", + "#### Plain Dictionaries\n", + "Dictionaries lack essential features:\n", + "```python\n", + "data = {'obs': np.array([1, 2]), 'reward': np.array([1.0, 2.0])}\n", + "```\n", + "\n", + "They would work in principle but have no shape/length semantics, no indexing, and no type safety.\n", + "\n", + "#### TensorDict\n", + "While `TensorDict` (used in TorchRL) is a powerful alternative, Batch has several advantages:\n", + "- **Batch supports arbitrary objects**, not just tensors (useful for object-dtype arrays, custom types)\n", + "- **Batch has better type checking** via `BatchProtocol` (enables IDE autocompletion)\n", + "- **Batch preceded TensorDict** and provides a stable foundation for Tianshou\n", + "- **TensorDict isn't part of core PyTorch** (external dependency)\n", + "\n", + "### What is Batch?\n", + "\n", + "**Batch = Dictionary + Array hybrid with RL-specific features**\n", + "\n", + "Key capabilities:\n", + "- **Dict-like**: Key-value storage with attribute access (`batch.obs`, `batch.reward`)\n", + "- **Array-like**: Shape, indexing, slicing (`batch[0]`, `batch[:10]`, `batch.shape`)\n", + "- **Hierarchical**: Nested structures for complex data\n", + "- **Type-safe**: Protocol-based typing for IDE support\n", + "- **RL-aware**: Special handling for distributions, missing values, heterogeneous aggregation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Core Concepts\n", + "\n", + "### Hierarchical Named Tensors\n", + "\n", + "Batch stores **hierarchical named tensors** - collections of tensors whose identifiers form a structured hierarchy. Consider tensors `[t1, t2, t3, t4]` with names `[name1, name2, name3, name4]`, where `name1` and `name2` are under namespace `name0`. 
The fully qualified name of `t1` is `name0.name1`.\n", + "\n", + "### Tree Structure Visualization\n", + "\n", + "The structure can be visualized as a tree with:\n", + "- **Root**: The Batch object itself\n", + "- **Internal nodes**: Keys (names)\n", + "- **Leaf nodes**: Values (scalars, arrays, tensors)\n", + "\n", + "```mermaid\n", + "graph TD\n", + " root[\"Batch (root)\"]\n", + " root --> obs[\"obs\"]\n", + " root --> act[\"act\"]\n", + " root --> rew[\"rew\"]\n", + " obs --> camera[\"camera\"]\n", + " obs --> sensory[\"sensory\"]\n", + " camera --> cam_data[\"np.array(3,3)\"]\n", + " sensory --> sens_data[\"np.array(5,)\"]\n", + " act --> act_data[\"np.array(2,)\"]\n", + " rew --> rew_data[\"3.66\"]\n", + " \n", + " style root fill:#e1f5ff\n", + " style obs fill:#fff4e1\n", + " style act fill:#fff4e1\n", + " style rew fill:#fff4e1\n", + " style camera fill:#ffe1f5\n", + " style sensory fill:#ffe1f5\n", + " style cam_data fill:#e8f5e1\n", + " style sens_data fill:#e8f5e1\n", + " style act_data fill:#e8f5e1\n", + " style rew_data fill:#e8f5e1\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Example: hierarchical structure\n", + "data = {\n", + " \"action\": np.array([1.0, 2.0, 3.0]),\n", + " \"reward\": 3.66,\n", + " \"obs\": {\n", + " \"camera\": np.zeros((3, 3)),\n", + " \"sensory\": np.ones(5),\n", + " },\n", + "}\n", + "\n", + "batch = Batch(data)\n", + "print(batch)\n", + "print(\"\\nAccessing nested values:\")\n", + "print(f\"batch.obs.camera.shape = {batch.obs.camera.shape}\")\n", + "print(f\"batch.obs.sensory = {batch.obs.sensory}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Data Flow in RL Pipeline\n", + "\n", + "Batch facilitates data flow throughout the RL pipeline:\n", + "\n", + "```mermaid\n", + "graph LR\n", + " A[Environment] -->|ObsBatchProtocol| B[Collector]\n", + " B -->|RolloutBatchProtocol| C[Replay Buffer]\n", + " C -->|RolloutBatchProtocol| D[Policy]\n", + " D -->|ActBatchProtocol| A\n", + " D -->|BatchWithAdvantages| E[Algorithm/Trainer]\n", + " E --> D\n", + " \n", + " style A fill:#e1f5ff\n", + " style B fill:#fff4e1\n", + " style C fill:#ffe1f5\n", + " style D fill:#e8f5e1\n", + " style E fill:#f5e1e1\n", + "```\n", + "\n", + "Each arrow represents a specific `BatchProtocol` that defines what fields are expected at that stage." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
Basic Operations\n", + "\n", + "### 3.1 Construction\n", + "\n", + "Batch objects can be constructed in several ways:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# From keyword arguments\n", + "batch1 = Batch(a=4, b=[5, 5], c=\"hello\")\n", + "print(\"From kwargs:\", batch1)\n", + "\n", + "# From dictionary\n", + "batch2 = Batch({\"a\": 4, \"b\": [5, 5], \"c\": \"hello\"})\n", + "print(\"\\nFrom dict:\", batch2)\n", + "\n", + "# From list of dictionaries (automatically stacked)\n", + "batch3 = Batch([{\"a\": 1, \"b\": 2}, {\"a\": 3, \"b\": 4}])\n", + "print(\"\\nFrom list of dicts:\", batch3)\n", + "\n", + "# Nested batch\n", + "batch4 = Batch(obs=Batch(x=1, y=2), act=5)\n", + "print(\"\\nNested:\", batch4)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.2 Content Rules\n", + "\n", + "Understanding what Batch can store and how it converts data:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Keys must be strings\n", + "batch = Batch()\n", + "batch.key1 = \"value\"\n", + "batch.key2 = np.array([1, 2, 3])\n", + "print(\"Keys:\", list(batch.keys()))\n", + "\n", + "# Automatic conversions\n", + "demo = Batch(\n", + " scalar_int=5, # → np.array(5)\n", + " scalar_float=3.14, # → np.array(3.14)\n", + " list_nums=[1, 2, 3], # → np.array([1, 2, 3])\n", + " list_mixed=[1, \"hello\", None], # → np.array([1, \"hello\", None], dtype=object)\n", + " dict_val={\"x\": 1, \"y\": 2}, # → Batch(x=1, y=2)\n", + ")\n", + "\n", + "print(\"\\nAutomatic conversions:\")\n", + "print(f\"scalar_int type: {type(demo.scalar_int)}, value: {demo.scalar_int}\")\n", + "print(f\"list_nums type: {type(demo.list_nums)}, dtype: {demo.list_nums.dtype}\")\n", + "print(f\"list_mixed dtype: {demo.list_mixed.dtype}\")\n", + "print(f\"dict_val type: {type(demo.dict_val)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Important conversions:**\n", + "- Lists of numbers → NumPy arrays\n", + "- Lists with mixed types → Object-dtype arrays\n", + "- Dictionaries → Batch objects (recursively)\n", + "- Scalars → NumPy scalars" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.3 Access Patterns\n", + "\n", + "**Important: Understanding Iteration**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "batch = Batch(a=[1, 2, 3], b=[4, 5, 6])\n", + "\n", + "# Attribute vs dictionary access (equivalent)\n", + "print(\"Attribute access:\", batch.a)\n", + "print(\"Dict access:\", batch[\"a\"])\n", + "\n", + "# Getting keys\n", + "print(\"\\nKeys:\", list(batch.keys()))\n", + "\n", + "# Gotcha: Iteration is array like, not over keys\n", + "print(\"\\nIteration behavior:\")\n", + "print(\"for x in batch iterates over batch[0], batch[1], ..., NOT keys!\")\n", + "for i, item in enumerate(batch):\n", + " print(f\"batch[{i}] = {item}\")\n", + "\n", + "# This is different from dict behavior!\n", + "regular_dict = {\"a\": [1, 2, 3], \"b\": [4, 5, 6]}\n", + "print(\"\\nCompare with dict iteration (iterates over keys):\")\n", + "for key in regular_dict:\n", + " print(f\"key = {key}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.4 Indexing & Slicing\n", + "\n", + "Batch supports NumPy-like indexing and slicing:" + ] 
+ }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "batch = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[[5.0, -5.0], [1.0, -2.0]])\n", + "\n", + "print(\"Original batch shape:\", batch.shape)\n", + "print(\"Original batch length:\", len(batch))\n", + "\n", + "# Single index\n", + "print(\"\\nbatch[0]:\")\n", + "print(batch[0])\n", + "\n", + "# Slicing\n", + "print(\"\\nbatch[:1]:\")\n", + "print(batch[:1])\n", + "\n", + "# Advanced indexing\n", + "print(\"\\nbatch[[0, 1]]:\")\n", + "print(batch[[0, 1]])\n", + "\n", + "# Multi-dimensional indexing\n", + "print(\"\\nbatch[:, 0] (first column of all arrays):\")\n", + "print(batch[:, 0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Broadcasting and in-place operations\n", + "batch[:, 1] += 10\n", + "print(\"After batch[:, 1] += 10:\")\n", + "print(batch)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.5 Stack, Concatenate, and Split\n", + "\n", + "Combining and splitting batches:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "# Stack: adds a new dimension\nbatch1 = Batch(a=np.array([1, 2]), b=np.array([5, 6]))\nbatch2 = Batch(a=np.array([3, 4]), b=np.array([7, 8]))\n\nstacked = Batch.stack([batch1, batch2])\nprint(\"Stacked:\")\nprint(stacked)\nprint(f\"Shape: {stacked.shape}\")\n\n# Concatenate: extends along existing dimension\nconcatenated = Batch.cat([batch1, batch2])\nprint(\"\\nConcatenated:\")\nprint(concatenated)\nprint(f\"Shape: {concatenated.shape}\")" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Split\n", + "batch = Batch(a=np.arange(10), b=np.arange(10, 20))\n", + "splits = list(batch.split(size=3, shuffle=False))\n", + "print(f\"Split into {len(splits)} batches:\")\n", + "for i, split in enumerate(splits):\n", + " print(f\"Split {i}: a={split.a}, length={len(split)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.6 Data Type Conversion\n", + "\n", + "Converting between NumPy and PyTorch:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create batch with NumPy arrays\n", + "batch = Batch(a=np.zeros((3, 4)), b=np.ones(5))\n", + "print(\"Original (NumPy):\")\n", + "print(f\"batch.a type: {type(batch.a)}\")\n", + "\n", + "# Convert to PyTorch (in-place)\n", + "batch.to_torch_(dtype=torch.float32, device=\"cpu\")\n", + "print(\"\\nAfter to_torch_():\")\n", + "print(f\"batch.a type: {type(batch.a)}\")\n", + "print(f\"batch.a dtype: {batch.a.dtype}\")\n", + "\n", + "# Convert back to NumPy (in-place)\n", + "batch.to_numpy_()\n", + "print(\"\\nAfter to_numpy_():\")\n", + "print(f\"batch.a type: {type(batch.a)}\")\n", + "\n", + "# Non-in-place versions return a new batch\n", + "batch_torch = batch.to_torch()\n", + "print(\"\\nOriginal batch unchanged:\", type(batch.a))\n", + "print(\"New batch:\", type(batch_torch.a))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Type Safety with Protocols\n", + "\n", + "### Why Protocols?\n", + "\n", + "Batch needs to be **flexible** (not fixed fields like dataclasses) but we still want **type safety** and **IDE autocompletion**. 
Protocols provide the best of both worlds:\n", + "\n", + "- **Runtime flexibility**: Add any fields dynamically\n", + "- **Static type checking**: Type checkers (mypy, pyright) verify correct usage\n", + "- **IDE support**: Autocompletion for expected fields\n", + "\n", + "### What is BatchProtocol?\n", + "\n", + "A `Protocol` defines an interface without implementation. Think of it as a contract: \"any object with these fields is valid.\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Creating a typed batch using cast\n", + "# This enables IDE autocompletion and type checking\n", + "\n", + "# ActBatchProtocol: just needs 'act' field\n", + "act_batch = cast(ActBatchProtocol, Batch(act=np.array([1, 2, 3])))\n", + "print(\"ActBatchProtocol:\", act_batch.act)\n", + "\n", + "# ObsBatchProtocol: needs 'obs' and 'info' fields\n", + "obs_batch = cast(\n", + " ObsBatchProtocol,\n", + " Batch(obs=np.array([[1.0, 2.0], [3.0, 4.0]]), info=np.array([{}, {}], dtype=object)),\n", + ")\n", + "print(\"\\nObsBatchProtocol:\", obs_batch.obs)\n", + "\n", + "# RolloutBatchProtocol: needs obs, obs_next, act, rew, terminated, truncated\n", + "rollout_batch = cast(\n", + " RolloutBatchProtocol,\n", + " Batch(\n", + " obs=np.array([[1.0, 2.0], [3.0, 4.0]]),\n", + " obs_next=np.array([[2.0, 3.0], [4.0, 5.0]]),\n", + " act=np.array([0, 1]),\n", + " rew=np.array([1.0, 2.0]),\n", + " terminated=np.array([False, True]),\n", + " truncated=np.array([False, False]),\n", + " info=np.array([{}, {}], dtype=object),\n", + " ),\n", + ")\n", + "print(\"\\nRolloutBatchProtocol reward:\", rollout_batch.rew)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Protocol Hierarchy\n", + "\n", + "Tianshou defines a hierarchy of protocols for different use cases:\n", + "\n", + "```mermaid\n", + "graph TD\n", + " BP[BatchProtocol
Base protocol] --> OBP[ObsBatchProtocol
obs, info]\n", + " BP --> ABP[ActBatchProtocol
act]\n", + " ABP --> ASBP[ActStateBatchProtocol
act, state]\n", + " OBP --> RBP[RolloutBatchProtocol
+obs_next, act, rew,
terminated, truncated]\n", + " RBP --> BWRP[BatchWithReturnsProtocol
+returns]\n", + " BWRP --> BWAP[BatchWithAdvantagesProtocol
+adv, v_s]\n", + " ASBP --> MOBP[ModelOutputBatchProtocol
+logits]\n", + " MOBP --> DBP[DistBatchProtocol
+dist]\n", + " DBP --> DLPBP[DistLogProbBatchProtocol
+log_prob]\n", + " BWAP --> LOPBP[LogpOldProtocol
+logp_old]\n", + " \n", + " style BP fill:#e1f5ff\n", + " style OBP fill:#fff4e1\n", + " style ABP fill:#fff4e1\n", + " style RBP fill:#ffe1f5\n", + " style BWRP fill:#e8f5e1\n", + " style BWAP fill:#e8f5e1\n", + " style DBP fill:#f5e1e1\n", + " style LOPBP fill:#e1e1f5\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Using Protocols in Functions\n", + "\n", + "Protocols enable type-safe function signatures:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def process_observations(batch: ObsBatchProtocol) -> np.ndarray:\n", + " \"\"\"Function that expects observations.\n", + "\n", + " IDE will autocomplete batch.obs and batch.info!\n", + " Type checker will verify these fields exist.\n", + " \"\"\"\n", + " # IDE knows batch.obs exists\n", + " return batch.obs if isinstance(batch.obs, np.ndarray) else np.array(batch.obs)\n", + "\n", + "\n", + "def compute_advantage(batch: RolloutBatchProtocol) -> np.ndarray:\n", + " \"\"\"Function that expects rollout data.\n", + "\n", + " IDE will autocomplete batch.rew, batch.obs_next, etc.\n", + " \"\"\"\n", + " # Simplified advantage computation\n", + " return batch.rew # IDE knows this exists\n", + "\n", + "\n", + "# Example usage\n", + "obs_data = Batch(obs=np.array([1, 2, 3]), info=np.array([{}], dtype=object))\n", + "result = process_observations(obs_data)\n", + "print(\"Processed obs:\", result)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Key Protocol Types:**\n", + "\n", + "- `ActBatchProtocol`: Just actions (for simple policies)\n", + "- `ObsBatchProtocol`: Observations and info\n", + "- `RolloutBatchProtocol`: Complete transitions (obs, act, rew, done, obs_next)\n", + "- `BatchWithReturnsProtocol`: Rollouts + computed returns\n", + "- `BatchWithAdvantagesProtocol`: Returns + advantages and values\n", + "- `DistBatchProtocol`: Contains distribution objects\n", + "- `LogpOldProtocol`: For importance sampling (PPO, etc.)\n", + "\n", + "See `tianshou/data/types.py` for the complete list!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Distribution Slicing\n", + "\n", + "### Why Special Handling?\n", + "\n", + "PyTorch `Distribution` objects need special slicing because they're not simple arrays. 
When you slice `batch[0:2]`, Tianshou needs to slice the underlying distribution parameters correctly.\n", + "\n", + "### Supported Distributions\n", + "\n", + "Tianshou supports slicing for:\n", + "- `Categorical`: Discrete distributions\n", + "- `Normal`: Continuous Gaussian distributions\n", + "- `Independent`: Wraps other distributions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Categorical distribution\n", + "probs = torch.tensor([[0.3, 0.7], [0.4, 0.6], [0.5, 0.5]])\n", + "dist = Categorical(probs=probs)\n", + "batch = Batch(dist=dist, values=np.array([1, 2, 3]))\n", + "\n", + "print(\"Original batch length:\", len(batch))\n", + "print(\"Original dist probs shape:\", batch.dist.probs.shape)\n", + "\n", + "# Slicing automatically handles the distribution\n", + "sliced = batch[0:2]\n", + "print(\"\\nSliced batch length:\", len(sliced))\n", + "print(\"Sliced dist probs shape:\", sliced.dist.probs.shape)\n", + "print(\"Sliced values:\", sliced.values)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Normal distribution\n", + "loc = torch.tensor([0.0, 1.0, 2.0])\n", + "scale = torch.tensor([1.0, 1.0, 1.0])\n", + "normal_dist = Normal(loc=loc, scale=scale)\n", + "batch_normal = Batch(dist=normal_dist, actions=np.array([0.5, 1.5, 2.5]))\n", + "\n", + "print(\"Normal distribution batch:\")\n", + "print(f\"Original mean: {batch_normal.dist.mean}\")\n", + "\n", + "# Index a single element\n", + "single = batch_normal[1]\n", + "print(f\"\\nSingle element mean: {single.dist.mean}\")\n", + "print(f\"Single element action: {single.actions}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Converting to At Least 2D\n", + "\n", + "Sometimes you need to ensure distributions have a batch dimension:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tianshou.data.batch import dist_to_atleast_2d\n", + "\n", + "# Scalar distribution (no batch dimension)\n", + "scalar_dist = Categorical(probs=torch.tensor([0.3, 0.7]))\n", + "print(\"Scalar dist batch_shape:\", scalar_dist.batch_shape)\n", + "\n", + "# Convert to have batch dimension\n", + "batched_dist = dist_to_atleast_2d(scalar_dist)\n", + "print(\"Batched dist batch_shape:\", batched_dist.batch_shape)\n", + "\n", + "# For entire batch\n", + "scalar_batch = Batch(a=1, b=2, dist=Categorical(probs=torch.ones(3)))\n", + "print(\"\\nBefore to_at_least_2d:\", scalar_batch.dist.batch_shape)\n", + "\n", + "batch_2d = scalar_batch.to_at_least_2d()\n", + "print(\"After to_at_least_2d:\", batch_2d.dist.batch_shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Use Cases\n", + "\n", + "Distribution slicing is used in:\n", + "- **Policy sampling**: When policies output distributions, slicing batches preserves distribution structure\n", + "- **Replay buffer sampling**: Distributions are stored and retrieved correctly\n", + "- **Advantage computation**: Computing log probabilities on subsets of data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Advanced Topics\n", + "\n", + "### 6.1 Key Reservation\n", + "\n", + "Sometimes you know what keys you'll need but don't have values yet. 
Reserve keys using empty `Batch()` objects:\n", + "\n", + "```mermaid\n", + "graph TD\n", + " root[\"Batch\"]\n", + " root --> a[\"key1: np.array([1,2,3])\"]\n", + " root --> b[\"key2: Batch() (reserved)\"]\n", + " root --> c[\"key3\"]\n", + " c --> c1[\"subkey1: Batch() (reserved)\"]\n", + " c --> c2[\"subkey2: np.array([4,5])\"]\n", + " \n", + " style root fill:#e1f5ff\n", + " style a fill:#e8f5e1\n", + " style b fill:#ffcccc\n", + " style c fill:#fff4e1\n", + " style c1 fill:#ffcccc\n", + " style c2 fill:#e8f5e1\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Reserving keys\n", + "batch = Batch(\n", + " known_field=np.array([1, 2]),\n", + " future_field=Batch(), # Reserved for later\n", + ")\n", + "print(\"Batch with reserved key:\")\n", + "print(batch)\n", + "\n", + "# Later, assign actual data\n", + "batch.future_field = np.array([3, 4])\n", + "print(\"\\nAfter assignment:\")\n", + "print(batch)\n", + "\n", + "# Nested reservation\n", + "batch2 = Batch(\n", + " obs=Batch(\n", + " camera=Batch(), # Reserved\n", + " lidar=np.zeros(10),\n", + " )\n", + ")\n", + "print(\"\\nNested reservation:\")\n", + "print(batch2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6.2 Length and Shape Semantics\n", + "\n", + "Understanding when `len()` works and what `shape` means:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Normal case: all tensors same length\n", + "batch1 = Batch(a=[1, 2, 3], b=np.array([4, 5, 6]))\n", + "print(\"Normal batch:\")\n", + "print(f\"len(batch1) = {len(batch1)}\")\n", + "print(f\"batch1.shape = {batch1.shape}\")\n", + "\n", + "# Scalars have no length\n", + "batch2 = Batch(a=5, b=10)\n", + "print(\"\\nScalar batch:\")\n", + "print(f\"batch2.shape = {batch2.shape}\")\n", + "try:\n", + " print(f\"len(batch2) = {len(batch2)}\")\n", + "except TypeError as e:\n", + " print(f\"len(batch2) raises TypeError: {e}\")\n", + "\n", + "# Mixed lengths: returns minimum\n", + "batch3 = Batch(a=[1, 2], b=[3, 4, 5])\n", + "print(\"\\nMixed length batch:\")\n", + "print(f\"len(batch3) = {len(batch3)} (minimum of 2 and 3)\")\n", + "\n", + "# Reserved keys are ignored\n", + "batch4 = Batch(a=[1, 2, 3], reserved=Batch())\n", + "print(\"\\nBatch with reserved key:\")\n", + "print(f\"len(batch4) = {len(batch4)} (reserved key ignored)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6.3 Empty Batches\n", + "\n", + "Understanding different meanings of \"empty\":" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 1. No keys at all\n", + "empty1 = Batch()\n", + "print(\"No keys:\")\n", + "print(f\"len(empty1.get_keys()) = {len(list(empty1.get_keys()))}\")\n", + "print(f\"len(empty1) = {len(empty1)}\")\n", + "\n", + "# 2. Has keys but they're all reserved\n", + "empty2 = Batch(a=Batch(), b=Batch())\n", + "print(\"\\nReserved keys only:\")\n", + "print(f\"len(empty2.get_keys()) = {len(list(empty2.get_keys()))}\")\n", + "print(f\"len(empty2) = {len(empty2)}\")\n", + "\n", + "# 3. 
Has data but length is 0\n", + "empty3 = Batch(a=np.array([]), b=np.array([]))\n", + "print(\"\\nZero-length arrays:\")\n", + "print(f\"len(empty3.get_keys()) = {len(list(empty3.get_keys()))}\")\n", + "print(f\"len(empty3) = {len(empty3)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Checking emptiness:**\n", + "- `len(batch.get_keys()) == 0`: No keys (completely empty)\n", + "- `len(batch) == 0`: No data elements (may have reserved keys)\n", + "\n", + "**The `.empty()` and `.empty_()` methods:**\n", + "These reset values to zeros/None, different from checking emptiness:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "batch = Batch(a=[1, 2, 3], b=[\"x\", \"y\", \"z\"])\n", + "print(\"Original:\", batch)\n", + "\n", + "# Empty specific index\n", + "batch[0] = Batch.empty(batch[0])\n", + "print(\"\\nAfter emptying index 0:\")\n", + "print(batch)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6.4 Heterogeneous Aggregation\n", + "\n", + "Stacking/concatenating batches with different keys:\n", + "\n", + "```mermaid\n", + "graph LR\n", + " A[\"Batch(a=[1,2], c=5)\"] --> C[\"Batch.stack\"]\n", + " B[\"Batch(b=[3,4], c=6)\"] --> C\n", + " C --> D[\"Batch(a=[[1,2],[0,0]],
b=[[0,0],[3,4]],
c=[5,6])\"]\n", + " \n", + " style A fill:#e1f5ff\n", + " style B fill:#fff4e1\n", + " style C fill:#ffe1f5\n", + " style D fill:#e8f5e1\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Stack with different keys (missing keys padded with zeros)\n", + "batch_a = Batch(a=np.ones((2, 3)), shared=np.array([1, 2]))\n", + "batch_b = Batch(b=np.zeros((2, 4)), shared=np.array([3, 4]))\n", + "\n", + "stacked = Batch.stack([batch_a, batch_b])\n", + "print(\"Stacked batch:\")\n", + "print(f\"a.shape = {stacked.a.shape} (padded with zeros for batch_b)\")\n", + "print(f\"b.shape = {stacked.b.shape} (padded with zeros for batch_a)\")\n", + "print(f\"shared.shape = {stacked.shared.shape} (in both batches)\")\n", + "print(stacked)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6.5 Missing Values\n", + "\n", + "Handling `None` and `NaN` values:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Batch with missing values\n", + "batch = Batch(a=[1, 2, None, 4], b=[5.0, np.nan, 7.0, 8.0], c=[[1, 2], [3, 4], [5, 6], [7, 8]])\n", + "\n", + "# Check for nulls\n", + "print(\"Has null?\", batch.hasnull())\n", + "\n", + "# Get null mask\n", + "null_mask = batch.isnull()\n", + "print(\"\\nNull mask:\")\n", + "print(f\"a: {null_mask.a}\")\n", + "print(f\"b: {null_mask.b}\")\n", + "\n", + "# Drop rows with any null\n", + "clean_batch = batch.dropnull()\n", + "print(\"\\nAfter dropnull() (keeps rows 0 and 3):\")\n", + "print(f\"Length: {len(clean_batch)}\")\n", + "print(f\"a: {clean_batch.a}\")\n", + "print(f\"b: {clean_batch.b}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6.6 Value Transformations\n", + "\n", + "Applying functions to all values recursively:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "batch = Batch(a=np.array([1, 2, 3]), nested=Batch(b=np.array([4.0, 5.0]), c=np.array([6, 7, 8])))\n", + "\n", + "# Apply transformation (returns new batch)\n", + "doubled = batch.apply_values_transform(lambda x: x * 2)\n", + "print(\"Original batch a:\", batch.a)\n", + "print(\"Doubled batch a:\", doubled.a)\n", + "print(\"Doubled nested.b:\", doubled.nested.b)\n", + "\n", + "# In-place transformation\n", + "batch.apply_values_transform(lambda x: x + 10, inplace=True)\n", + "print(\"\\nAfter in-place +10:\")\n", + "print(\"a:\", batch.a)\n", + "print(\"nested.b:\", batch.nested.b)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. 
Surprising Behaviors & Gotchas\n", + "\n", + "### Iteration Does NOT Iterate Over Keys!\n", + "\n", + "**This is the most common source of confusion:**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "batch = Batch(a=[1, 2, 3], b=[4, 5, 6])\n", + "\n", + "print(\"WRONG: This doesn't iterate over keys!\")\n", + "for item in batch:\n", + " print(f\"item = {item}\") # Prints batch[0], batch[1], batch[2]\n", + "\n", + "print(\"\\nCORRECT: To iterate over keys:\")\n", + "for key in batch.keys():\n", + " print(f\"key = {key}\")\n", + "\n", + "print(\"\\nCORRECT: To iterate over key-value pairs:\")\n", + "for key, value in batch.items():\n", + " print(f\"{key} = {value}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Automatic Type Conversions\n", + "\n", + "Be aware of these automatic conversions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Lists become arrays\n", + "batch = Batch(a=[1, 2, 3])\n", + "print(\"List → array:\", type(batch.a), batch.a.dtype)\n", + "\n", + "# Dicts become Batch\n", + "batch = Batch(a={\"x\": 1, \"y\": 2})\n", + "print(\"Dict → Batch:\", type(batch.a))\n", + "\n", + "# Scalars become numpy scalars\n", + "batch = Batch(a=5)\n", + "print(\"Scalar → np.ndarray:\", type(batch.a), batch.a)\n", + "\n", + "# Mixed types → object dtype\n", + "batch = Batch(a=[1, \"hello\", None])\n", + "print(\"Mixed → object:\", batch.a.dtype, batch.a)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Length Edge Cases" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 1. Scalars have no length\n", + "batch_scalar = Batch(a=5, b=10)\n", + "try:\n", + " len(batch_scalar)\n", + "except TypeError as e:\n", + " print(f\"Scalar batch: {e}\")\n", + "\n", + "# 2. Empty nested batches ignored in len()\n", + "batch_empty_nested = Batch(a=[1, 2, 3], b=Batch())\n", + "print(f\"\\nWith empty nested: len = {len(batch_empty_nested)} (ignores b)\")\n", + "\n", + "# 3. Different lengths: returns minimum\n", + "batch_different = Batch(a=[1, 2], b=[1, 2, 3, 4])\n", + "print(f\"Different lengths: len = {len(batch_different)} (minimum)\")\n", + "\n", + "# 4. 
None values don't affect length\n", + "batch_none = Batch(a=[1, 2, 3], b=None)\n", + "print(f\"With None: len = {len(batch_none)} (None ignored)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### String Keys Only" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Integer keys not allowed\n", + "try:\n", + " batch = Batch({1: \"value\", 2: \"other\"})\n", + "except AssertionError as e:\n", + " print(\"Integer keys not allowed:\", e)\n", + "\n", + "# String keys work\n", + "batch = Batch({\"key1\": \"value\", \"key2\": \"other\"})\n", + "print(\"\\nString keys work:\", list(batch.keys()))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Cat vs Stack Behavior\n", + "\n", + "Recent changes have made concatenation stricter about structure:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Stack pads missing keys with zeros\n", + "b1 = Batch(a=[1, 2])\n", + "b2 = Batch(b=[3, 4])\n", + "stacked = Batch.stack([b1, b2])\n", + "print(\"Stack (different keys):\")\n", + "print(f\" a: {stacked.a} (b2.a padded with 0)\")\n", + "print(f\" b: {stacked.b} (b1.b padded with 0)\")\n", + "\n", + "# Cat requires same structure now\n", + "b3 = Batch(a=[1, 2], b=[3, 4])\n", + "b4 = Batch(a=[5, 6], b=[7, 8])\n", + "concatenated = Batch.cat([b3, b4])\n", + "print(\"\\nCat (same keys):\")\n", + "print(f\" a: {concatenated.a}\")\n", + "print(f\" b: {concatenated.b}\")\n", + "\n", + "# Cat with different structures raises error\n", + "try:\n", + " Batch.cat([b1, b2]) # Different keys!\n", + "except ValueError:\n", + " print(\"\\nCat with different keys: ValueError raised\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8. 
Best Practices\n", + "\n", + "### When to Use Batch\n", + "\n", + "**Good use cases:**\n", + "- Collecting environment data (transitions, episodes)\n", + "- Storing replay buffer data\n", + "- Passing data between components (collector → buffer → policy)\n", + "- Handling heterogeneous observations (dict spaces)\n", + "\n", + "**Consider alternatives:**\n", + "- Simple scalar tracking (use regular variables)\n", + "- Pure tensor operations (use PyTorch tensors directly)\n", + "- Deeply nested arbitrary structures (use dataclasses)\n", + "\n", + "### Structuring Your Batches\n", + "\n", + "**Use protocols for type safety:**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Good: Use protocols for clear interfaces\n", + "def train_step(batch: RolloutBatchProtocol) -> float:\n", + " \"\"\"IDE knows what fields exist.\"\"\"\n", + " loss = ((batch.rew - 0.5) ** 2).mean() # Type-safe\n", + " return float(loss)\n", + "\n", + "\n", + "# Create properly typed batch\n", + "train_batch = cast(\n", + " RolloutBatchProtocol,\n", + " Batch(\n", + " obs=np.random.randn(10, 4),\n", + " obs_next=np.random.randn(10, 4),\n", + " act=np.random.randint(0, 2, 10),\n", + " rew=np.random.randn(10),\n", + " terminated=np.zeros(10, dtype=bool),\n", + " truncated=np.zeros(10, dtype=bool),\n", + " info=np.array([{}] * 10, dtype=object),\n", + " ),\n", + ")\n", + "\n", + "loss = train_step(train_batch)\n", + "print(f\"Loss: {loss:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Consistent key naming:**\n", + "- Follow Tianshou conventions: `obs`, `act`, `rew`, `terminated`, `truncated`\n", + "- Use descriptive names: `camera_obs` not `co`\n", + "- Avoid name collisions with Batch methods: don't use `keys`, `items`, `get`, etc.\n", + "\n", + "**When to nest vs flatten:**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Good: Nest related data\n", + "batch_nested = Batch(\n", + " obs=Batch(\n", + " camera=np.zeros((32, 64, 64, 3)), lidar=np.zeros((32, 100)), position=np.zeros((32, 3))\n", + " ),\n", + " act=np.zeros(32),\n", + ")\n", + "print(\"Nested structure for related obs:\")\n", + "print(f\" Access: batch.obs.camera.shape = {batch_nested.obs.camera.shape}\")\n", + "\n", + "# Less good: Flat structure loses semantic grouping\n", + "batch_flat = Batch(\n", + " camera=np.zeros((32, 64, 64, 3)),\n", + " lidar=np.zeros((32, 100)),\n", + " position=np.zeros((32, 3)),\n", + " act=np.zeros(32),\n", + ")\n", + "print(\"\\nFlat structure (works but less clear):\")\n", + "print(f\" Access: batch.camera.shape = {batch_flat.camera.shape}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Performance Tips\n", + "\n", + "**Use in-place operations:**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "\n", + "batch = Batch(a=np.random.randn(1000, 100))\n", + "\n", + "# Creates copy\n", + "start = time.time()\n", + "for _ in range(100):\n", + " _ = batch.to_torch()\n", + "time_copy = time.time() - start\n", + "\n", + "# In-place (faster)\n", + "start = time.time()\n", + "for _ in range(100):\n", + " batch.to_torch_()\n", + " batch.to_numpy_()\n", + "time_inplace = time.time() - start\n", + "\n", + "print(f\"Copy: {time_copy:.4f}s\")\n", + "print(f\"In-place: {time_inplace:.4f}s\")\n", + "print(f\"Speedup: {time_copy / time_inplace:.1f}x\")" + ] + 
}, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Be mindful of copies:**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "arr = np.array([1, 2, 3])\n", + "\n", + "# Default: creates reference (be careful!)\n", + "batch1 = Batch(a=arr)\n", + "batch1.a[0] = 999\n", + "print(f\"Original array modified: {arr}\") # Changed!\n", + "\n", + "# Explicit copy when needed\n", + "arr = np.array([1, 2, 3])\n", + "batch2 = Batch(a=arr, copy=True)\n", + "batch2.a[0] = 999\n", + "print(f\"Original array preserved: {arr}\") # Unchanged" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Avoid unnecessary conversions:**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Inefficient: multiple conversions\n", + "batch = Batch(a=np.random.randn(100, 10))\n", + "batch.to_torch_()\n", + "batch.to_numpy_() # Unnecessary if we just need NumPy\n", + "\n", + "# Efficient: convert once, use many times\n", + "batch = Batch(a=np.random.randn(100, 10))\n", + "batch.to_torch_() # Convert once\n", + "# ... do torch operations ...\n", + "# Keep as torch if that's what you need!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Common Patterns\n", + "\n", + "**Pattern 1: Building batches incrementally**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Collect data from multiple steps\n", + "step_data = []\n", + "for i in range(5):\n", + " step_data.append({\"obs\": np.random.randn(4), \"act\": i, \"rew\": np.random.randn()})\n", + "\n", + "# Convert to batch (automatically stacks)\n", + "episode_batch = Batch(step_data)\n", + "print(\"Episode batch shape:\", episode_batch.shape)\n", + "print(\"obs shape:\", episode_batch.obs.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Pattern 2: Slicing for mini-batches**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Large batch\n", + "large_batch = Batch(obs=np.random.randn(100, 4), act=np.random.randint(0, 2, 100))\n", + "\n", + "# Split into mini-batches\n", + "batch_size = 32\n", + "for mini_batch in large_batch.split(batch_size, shuffle=True):\n", + " print(f\"Mini-batch size: {len(mini_batch)}\")\n", + " # Train on mini_batch...\n", + " break # Just show one iteration" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Pattern 3: Extending batches**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Start with some data\n", + "batch = Batch(obs=np.array([[1, 2], [3, 4]]), act=np.array([0, 1]))\n", + "print(\"Initial:\", len(batch))\n", + "\n", + "# Add more data\n", + "new_data = Batch(obs=np.array([[5, 6]]), act=np.array([1]))\n", + "batch.cat_(new_data)\n", + "print(\"After cat_:\", len(batch))\n", + "print(\"obs:\", batch.obs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 9. Summary\n", + "\n", + "### Key Takeaways\n", + "\n", + "1. **Batch = Dict + Array**: Combines key-value storage with array operations\n", + "2. **Hierarchical Structure**: Perfect for complex RL data (nested observations, etc.)\n", + "3. **Type Safety via Protocols**: Use `BatchProtocol` subclasses for IDE support and type checking\n", + "4. 
**Special RL Features**: Distribution slicing, heterogeneous aggregation, missing value handling\n", + "5. **Remember**: Iteration is over indices, NOT keys!\n", + "\n", + "### Quick Reference\n", + "\n", + "| Operation | Code | Notes |\n", + "|-----------|------|-------|\n", + "| Create | `Batch(a=1, b=[2, 3])` | Auto-converts types |\n", + "| Access | `batch.a` or `batch[\"a\"]` | Equivalent |\n", + "| Index | `batch[0]`, `batch[:10]` | Returns sliced Batch |\n", + "| Iterate indices | `for item in batch:` | Yields batch[0], batch[1], ... |\n", + "| Iterate keys | `for k in batch.keys():` | Like dict |\n", + "| Stack | `Batch.stack([b1, b2])` | Adds dimension |\n", + "| Concatenate | `Batch.cat([b1, b2])` | Extends dimension |\n", + "| Split | `batch.split(size=10)` | Returns iterator |\n", + "| To PyTorch | `batch.to_torch_()` | In-place |\n", + "| To NumPy | `batch.to_numpy_()` | In-place |\n", + "| Transform | `batch.apply_values_transform(fn)` | Recursive |\n", + "\n", + "### Next Steps\n", + "\n", + "- **Collector Deep Dive**: See how Batch flows through data collection\n", + "- **Buffer Deep Dive**: Understand how Batch is stored and sampled\n", + "- **Policy Guide**: Learn how policies work with BatchProtocol\n", + "- **API Reference**: Full details at [Batch API documentation](https://tianshou.org/en/stable/api/tianshou.data.html#tianshou.data.Batch)\n", + "\n", + "### Questions?\n", + "\n", + "- Check the [Tianshou GitHub discussions](https://github.com/thu-ml/tianshou/discussions)\n", + "- Review [issue tracker](https://github.com/thu-ml/tianshou/issues) for known gotchas\n", + "- Read the [source code](https://github.com/thu-ml/tianshou/blob/master/tianshou/data/batch.py) - it's well-documented!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Appendix: Serialization & Advanced Topics\n", + "\n", + "### Pickle Support" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Batch objects are picklable\n", + "original = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4]))\n", + "\n", + "# Serialize and deserialize\n", + "serialized = pickle.dumps(original)\n", + "restored = pickle.loads(serialized)\n", + "\n", + "print(\"Original obs.a:\", original.obs.a)\n", + "print(\"Restored obs.a:\", restored.obs.a)\n", + "print(\"Equal:\", original == restored)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Advanced Indexing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Multi-dimensional data\n", + "batch = Batch(a=np.random.randn(5, 3, 2))\n", + "print(\"Original shape:\", batch.a.shape)\n", + "\n", + "# Various indexing operations\n", + "print(\"batch[0].a.shape:\", batch[0].a.shape)\n", + "print(\"batch[:, 0].a.shape:\", batch[:, 0].a.shape)\n", + "print(\"batch[[0, 2, 4]].a.shape:\", batch[[0, 2, 4]].a.shape)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/02_deep_dives/L2_Buffer.ipynb b/docs/02_deep_dives/L2_Buffer.ipynb new file mode 100644 index 000000000..a604dd5bf --- 
/dev/null +++ b/docs/02_deep_dives/L2_Buffer.ipynb @@ -0,0 +1,1826 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Buffer: Experience Replay in Tianshou\n", + "\n", + "The replay buffer is a fundamental component in reinforcement learning, particularly for off-policy algorithms. Tianshou's buffer implementation extends beyond simple data storage to provide sophisticated trajectory tracking, efficient sampling, and seamless integration with the RL training pipeline.\n", + "\n", + "This tutorial provides comprehensive coverage of Tianshou's buffer system, from basic concepts to advanced features and integration patterns." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pickle\n", + "import tempfile\n", + "\n", + "import numpy as np\n", + "\n", + "from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer, VectorReplayBuffer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Introduction: Why Buffers in Reinforcement Learning?\n", + "\n", + "### The Role of Experience Replay\n", + "\n", + "Experience replay is a critical technique in modern reinforcement learning that addresses three fundamental challenges:\n", + "\n", + "1. **Breaking Temporal Correlation**: Sequential experiences from an agent are highly correlated. Training directly on these sequences can lead to unstable learning. By storing experiences and sampling randomly, we break these correlations.\n", + "\n", + "2. **Sample Efficiency**: In RL, collecting data through environment interaction is often expensive. Experience replay allows us to reuse each experience multiple times for training, dramatically improving sample efficiency.\n", + "\n", + "3. **Mini-batch Training**: Modern deep learning requires mini-batch gradient descent. Buffers enable efficient batching of experiences for neural network training.\n", + "\n", + "### Why Not Alternatives?\n", + "\n", + "**Plain Python Lists**\n", + "- No efficient random sampling\n", + "- No automatic circular queue behavior\n", + "- No trajectory boundary tracking\n", + "- Poor memory management for large datasets\n", + "\n", + "**Simple Batch Storage**\n", + "- No automatic overwriting when full\n", + "- No episode metadata (returns, lengths)\n", + "- No methods for boundary navigation (prev/next)\n", + "- No specialized sampling strategies\n", + "\n", + "### Buffer = Batch + Trajectory Management + Sampling\n", + "\n", + "Tianshou's buffers build on the `Batch` class to provide:\n", + "- **Circular queue storage**: Automatic overwriting of oldest data\n", + "- **Trajectory tracking**: Episode boundaries, returns, and lengths\n", + "- **Efficient sampling**: Random access with various strategies\n", + "- **Integration utilities**: Seamless connection to Collector and Policy\n", + "\n", + "### Use Cases\n", + "\n", + "- **Off-policy algorithms**: DQN, SAC, TD3, DDPG require experience replay\n", + "- **On-policy with replay**: Some PPO implementations reuse buffer data\n", + "- **Offline RL**: Loading and using pre-collected datasets\n", + "- **Multi-environment training**: VectorReplayBuffer for parallel collection" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Buffer Types and Hierarchy\n", + "\n", + "Tianshou provides several buffer implementations, each designed for specific use cases. 
Understanding this hierarchy is crucial for choosing the right buffer.\n", + "\n", + "### Buffer Hierarchy\n", + "\n", + "```mermaid\n", + "graph TD\n", + " RB[ReplayBuffer
Single environment
Circular queue] --> RBM[ReplayBufferManager
Manages multiple buffers
Contiguous memory]\n", + " RBM --> VRB[VectorReplayBuffer
Parallel environments
Maintains temporal order]\n", + " \n", + " RB --> PRB[PrioritizedReplayBuffer
TD-error based sampling
Importance weights]\n", + " PRB --> PVRB[PrioritizedVectorReplayBuffer
Prioritized + Parallel]\n", + " \n", + " RB --> CRB[CachedReplayBuffer
Primary + auxiliary caches
Imitation learning]\n", + " \n", + " RB --> HERB[HERReplayBuffer
Hindsight Experience Replay
Goal-conditioned RL]\n", + " HERB --> HVRB[HERVectorReplayBuffer
HER + Parallel]\n", + " \n", + " style RB fill:#e1f5ff\n", + " style RBM fill:#fff4e1\n", + " style VRB fill:#ffe1f5\n", + " style PRB fill:#e8f5e1\n", + " style CRB fill:#f5e1e1\n", + " style HERB fill:#e1e1f5\n", + "```\n", + "\n", + "### When to Use Which Buffer\n", + "\n", + "**ReplayBuffer**: Single environment scenarios\n", + "- Simple setup and testing\n", + "- Debugging algorithms\n", + "- Low-parallelism training\n", + "\n", + "**VectorReplayBuffer**: Multiple parallel environments (most common)\n", + "- Standard production use case\n", + "- Efficient parallel data collection\n", + "- Maintains per-environment episode boundaries\n", + "\n", + "**PrioritizedReplayBuffer**: DQN variants with prioritization\n", + "- Rainbow DQN\n", + "- Algorithms requiring importance sampling\n", + "- When some transitions are more valuable than others\n", + "\n", + "**CachedReplayBuffer**: Separate primary and auxiliary caches\n", + "- Imitation learning (expert + agent data)\n", + "- GAIL and similar algorithms\n", + "- When you need different sampling strategies for different data sources\n", + "\n", + "**HERReplayBuffer**: Goal-conditioned reinforcement learning\n", + "- Sparse reward environments\n", + "- Robotics tasks with explicit goals\n", + "- Relabeling failed experiences with achieved goals" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Basic Operations\n", + "\n", + "### 3.1 Construction and Configuration\n", + "\n", + "The ReplayBuffer constructor accepts several important parameters that control its behavior:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a buffer with all configuration options\n", + "buf = ReplayBuffer(\n", + " size=20, # Maximum capacity (transitions)\n", + " stack_num=1, # Frame stacking for RNNs (default: 1, no stacking)\n", + " ignore_obs_next=False, # Save memory by not storing obs_next\n", + " save_only_last_obs=False, # For temporal stacking (Atari-style)\n", + " sample_avail=False, # Sample only valid indices for frame stacking\n", + " random_seed=42, # Reproducible sampling\n", + ")\n", + "\n", + "print(f\"Buffer created: {buf}\")\n", + "print(f\"Max size: {buf.maxsize}\")\n", + "print(f\"Current length: {len(buf)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Parameter Explanations**:\n", + "\n", + "- `size`: Maximum number of transitions the buffer can hold. When full, oldest data is overwritten.\n", + "- `stack_num`: Number of consecutive frames to stack. Used for RNN inputs or frame-based policies (Atari).\n", + "- `ignore_obs_next`: If True, obs_next is not stored, saving memory. The buffer reconstructs it from the next obs when needed.\n", + "- `save_only_last_obs`: For temporal stacking. Only saves the last observation in a stack.\n", + "- `sample_avail`: When True with stack_num > 1, only samples indices where a complete stack is available.\n", + "- `random_seed`: Seeds the random number generator for reproducible sampling." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.2 Reserved Keys and the Done Flag System\n", + "\n", + "ReplayBuffer uses nine reserved keys that integrate with Gymnasium conventions. Understanding the done flag system is critical." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The nine reserved keys\n", + "print(\"Reserved keys:\")\n", + "print(ReplayBuffer._reserved_keys)\n", + "print(\"\\nKeys required for add():\")\n", + "print(ReplayBuffer._required_keys_for_add)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Important: Understanding done, terminated, and truncated**\n", + "\n", + "Gymnasium (the successor to OpenAI Gym) introduced a crucial distinction:\n", + "\n", + "- `terminated`: Episode ended naturally (agent reached goal or failed)\n", + " - Examples: CartPole fell over, agent reached goal state\n", + " - Should be used for bootstrapping calculations\n", + "\n", + "- `truncated`: Episode was cut off artificially (time limit, external interruption)\n", + " - Examples: Maximum episode length reached, environment reset externally \n", + " - Should NOT be used for bootstrapping (the episode could have continued)\n", + "\n", + "- `done`: Computed automatically as `terminated OR truncated`\n", + " - Used internally for episode boundary tracking\n", + " - You should NEVER manually set this field\n", + "\n", + "**Best Practice**: Always use the `info` dictionary for custom metadata rather than adding top-level keys:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# GOOD: Custom metadata in info dictionary\n", + "good_batch = Batch(\n", + " obs=np.array([1.0, 2.0]),\n", + " act=0,\n", + " rew=1.0,\n", + " terminated=False,\n", + " truncated=False,\n", + " obs_next=np.array([1.5, 2.5]),\n", + " info={\"custom_metric\": 0.95, \"step_count\": 10}, # Custom data here\n", + ")\n", + "\n", + "# BAD: Don't add custom top-level keys (may conflict with future buffer features)\n", + "# bad_batch = Batch(..., custom_metric=0.95) # Don't do this!\n", + "\n", + "print(\"Good batch structure:\")\n", + "print(good_batch)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.3 Circular Queue Storage\n", + "\n", + "The buffer implements a circular queue: when it reaches maximum capacity, new data overwrites the oldest entries." 
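As a simplified mental model (an illustration only, not the actual `ReplayBuffer` implementation), you can think of the insertion slot as the running transition count modulo `maxsize`: once the buffer is full, each new transition lands on top of the oldest one. The next cell shows the real behavior.

```python
# Simplified circular-queue model (assumed for illustration; the real buffer
# also tracks episode metadata alongside the storage slots).
maxsize = 5
slots = [None] * maxsize

for i in range(8):          # insert transitions 0..7
    slots[i % maxsize] = i  # the slot index wraps around once the buffer is full

print(slots)  # [5, 6, 7, 3, 4] -- transitions 0, 1, 2 were overwritten by 5, 6, 7
```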
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a small buffer to demonstrate circular behavior\n", + "demo_buf = ReplayBuffer(size=5)\n", + "\n", + "print(\"Adding 3 transitions:\")\n", + "for i in range(3):\n", + " demo_buf.add(\n", + " Batch(\n", + " obs=i,\n", + " act=i,\n", + " rew=float(i),\n", + " terminated=False,\n", + " truncated=False,\n", + " obs_next=i + 1,\n", + " info={},\n", + " )\n", + " )\n", + "print(f\"Length: {len(demo_buf)}, Max: {demo_buf.maxsize}\")\n", + "print(f\"Observations: {demo_buf.obs[: len(demo_buf)]}\")\n", + "\n", + "print(\"\\nAdding 5 more transitions (total 8, exceeds capacity 5):\")\n", + "for i in range(3, 8):\n", + " demo_buf.add(\n", + " Batch(\n", + " obs=i,\n", + " act=i,\n", + " rew=float(i),\n", + " terminated=False,\n", + " truncated=False,\n", + " obs_next=i + 1,\n", + " info={},\n", + " )\n", + " )\n", + "print(f\"Length: {len(demo_buf)}, Max: {demo_buf.maxsize}\")\n", + "print(f\"Observations: {demo_buf.obs[: len(demo_buf)]}\")\n", + "print(\"\\nNotice: First 3 transitions (0,1,2) were overwritten by (3,4,5)\")\n", + "print(\"Buffer now contains: [3, 4, 5, 6, 7]\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.4 Batch-Compatible Operations\n", + "\n", + "Since ReplayBuffer extends Batch functionality, it supports standard indexing and slicing:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Indexing and slicing\n", + "print(\"Last transition:\")\n", + "print(demo_buf[-1])\n", + "\n", + "print(\"\\nLast 3 transitions:\")\n", + "print(demo_buf[-3:])\n", + "\n", + "print(\"\\nSpecific indices [0, 2, 4]:\")\n", + "print(demo_buf[np.array([0, 2, 4])])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Trajectory Management\n", + "\n", + "A key distinguishing feature of ReplayBuffer is its automatic tracking of episode boundaries and metadata.\n", + "\n", + "### 4.1 Episode Tracking and Metadata\n", + "\n", + "The `add()` method returns four values that provide episode information:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a fresh buffer for trajectory demonstration\n", + "traj_buf = ReplayBuffer(size=20)\n", + "\n", + "print(\"Episode 1: 4 steps, terminates naturally\")\n", + "for i in range(4):\n", + " idx, ep_rew, ep_len, ep_start = traj_buf.add(\n", + " Batch(\n", + " obs=i,\n", + " act=i,\n", + " rew=float(i + 1), # Rewards: 1, 2, 3, 4\n", + " terminated=i == 3, # Last step terminates\n", + " truncated=False,\n", + " obs_next=i + 1,\n", + " info={},\n", + " )\n", + " )\n", + " print(f\" Step {i}: idx={idx}, ep_rew={ep_rew}, ep_len={ep_len}, ep_start={ep_start}\")\n", + "\n", + "print(\"\\nNotice: Episode return (10.0) and length (4) only appear at the end!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Return Values Explained**:\n", + "\n", + "1. `idx`: Index where the transition was inserted (np.ndarray of shape (1,))\n", + "2. `ep_rew`: Episode return, only non-zero when `done=True` (np.ndarray of shape (1,))\n", + "3. `ep_len`: Episode length, only non-zero when `done=True` (np.ndarray of shape (1,))\n", + "4. `ep_start`: Index where the episode started (np.ndarray of shape (1,))\n", + "\n", + "This automatic computation eliminates manual episode tracking during data collection." 
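Because `add()` already reports completed-episode statistics, a collection loop can log returns without any extra bookkeeping. Here is a hedged sketch (toy episode rule and illustrative names):

```python
# Hedged sketch: collecting episode returns directly from add()'s return values.
from tianshou.data import Batch, ReplayBuffer

log_buf = ReplayBuffer(size=100)
episode_returns = []

for i in range(12):
    terminated = (i + 1) % 4 == 0  # toy rule: an episode ends every 4 steps
    _, ep_rew, ep_len, _ = log_buf.add(
        Batch(
            obs=i,
            act=0,
            rew=1.0,
            terminated=terminated,
            truncated=False,
            obs_next=i + 1,
            info={},
        )
    )
    if ep_len[0] > 0:  # non-zero only when an episode just finished
        episode_returns.append(float(ep_rew[0]))

print(episode_returns)  # [4.0, 4.0, 4.0]
```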
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Continue with Episode 2: 5 steps\n", + "print(\"Episode 2: 5 steps, truncated (time limit)\")\n", + "for i in range(4, 9):\n", + " idx, ep_rew, ep_len, ep_start = traj_buf.add(\n", + " Batch(\n", + " obs=i,\n", + " act=i,\n", + " rew=float(i + 1),\n", + " terminated=False,\n", + " truncated=i == 8, # Last step truncated\n", + " obs_next=i + 1,\n", + " info={},\n", + " )\n", + " )\n", + " if i == 8:\n", + " print(\n", + " f\" Final step: idx={idx}, ep_rew={ep_rew[0]:.1f}, ep_len={ep_len[0]}, ep_start={ep_start}\"\n", + " )\n", + "\n", + "# Episode 3: Ongoing (not finished)\n", + "print(\"\\nEpisode 3: 3 steps, ongoing (not done)\")\n", + "for i in range(9, 12):\n", + " idx, ep_rew, ep_len, ep_start = traj_buf.add(\n", + " Batch(\n", + " obs=i,\n", + " act=i,\n", + " rew=float(i + 1),\n", + " terminated=False,\n", + " truncated=False, # Episode continues\n", + " obs_next=i + 1,\n", + " info={},\n", + " )\n", + " )\n", + " if i == 11:\n", + " print(\n", + " f\" Latest step: idx={idx}, ep_rew={ep_rew}, ep_len={ep_len} (zeros because not done)\"\n", + " )\n", + "\n", + "print(f\"\\nBuffer state: {len(traj_buf)} transitions across 2 complete + 1 ongoing episode\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4.2 Boundary Navigation: prev() and next()\n", + "\n", + "The buffer provides methods to navigate within episodes while respecting episode boundaries:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Examine the buffer structure\n", + "print(\"Buffer contents:\")\n", + "print(f\"Indices: {np.arange(len(traj_buf))}\")\n", + "print(f\"Obs: {traj_buf.obs[: len(traj_buf)]}\")\n", + "print(f\"Terminated: {traj_buf.terminated[: len(traj_buf)]}\")\n", + "print(f\"Truncated: {traj_buf.truncated[: len(traj_buf)]}\")\n", + "print(f\"Done: {traj_buf.done[: len(traj_buf)]}\")\n", + "print(\"\\nEpisode boundaries: indices 3 (terminated) and 8 (truncated)\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# prev() returns the previous index within the same episode\n", + "# It STOPS at episode boundaries\n", + "test_indices = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])\n", + "prev_indices = traj_buf.prev(test_indices)\n", + "\n", + "print(\"prev() behavior:\")\n", + "print(f\"Index: {test_indices}\")\n", + "print(f\"Prev: {prev_indices}\")\n", + "print(\"\\nObservations:\")\n", + "print(\"- Index 0 stays at 0 (start of episode 1)\")\n", + "print(\"- Index 4 stays at 4 (start of episode 2, can't go back to episode 1)\")\n", + "print(\"- Index 9 stays at 9 (start of episode 3, can't go back to episode 2)\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# next() returns the next index within the same episode\n", + "# It STOPS at episode boundaries\n", + "next_indices = traj_buf.next(test_indices)\n", + "\n", + "print(\"next() behavior:\")\n", + "print(f\"Index: {test_indices}\")\n", + "print(f\"Next: {next_indices}\")\n", + "print(\"\\nObservations:\")\n", + "print(\"- Index 3 stays at 3 (end of episode 1, terminated)\")\n", + "print(\"- Index 8 stays at 8 (end of episode 2, truncated)\")\n", + "print(\"- Indices 9-11 advance normally (episode 3 ongoing)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Use Cases for 
prev() and next()**:\n", + "\n", + "These methods are essential for computing algorithmic quantities:\n", + "- **N-step returns**: Use prev() to look back N steps within an episode\n", + "- **GAE (Generalized Advantage Estimation)**: Navigate backwards through episodes\n", + "- **Episode extraction**: Find episode start/end indices\n", + "- **Temporal difference targets**: Ensure you don't bootstrap across episode boundaries" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4.3 Identifying Unfinished Episodes\n", + "\n", + "The `unfinished_index()` method returns indices of ongoing episodes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "unfinished = traj_buf.unfinished_index()\n", + "print(f\"Unfinished episode indices: {unfinished}\")\n", + "print(f\"Latest step of ongoing episode: obs={traj_buf.obs[unfinished[0]]}\")\n", + "\n", + "# After finishing episode 3\n", + "traj_buf.add(\n", + " Batch(\n", + " obs=12,\n", + " act=12,\n", + " rew=13.0,\n", + " terminated=True,\n", + " truncated=False,\n", + " obs_next=13,\n", + " info={},\n", + " )\n", + ")\n", + "\n", + "unfinished_after = traj_buf.unfinished_index()\n", + "print(\"\\nAfter finishing episode 3:\")\n", + "print(f\"Unfinished episodes: {unfinished_after} (empty array)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Sampling Strategies\n", + "\n", + "Efficient sampling is critical for RL training. The buffer provides several sampling methods and strategies.\n", + "\n", + "### 5.1 Basic Sampling" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a buffer with some data\n", + "sample_buf = ReplayBuffer(size=100)\n", + "for i in range(50):\n", + " sample_buf.add(\n", + " Batch(\n", + " obs=i,\n", + " act=i % 4,\n", + " rew=np.random.random(),\n", + " terminated=(i + 1) % 10 == 0,\n", + " truncated=False,\n", + " obs_next=i + 1,\n", + " info={},\n", + " )\n", + " )\n", + "\n", + "# Sample with batch_size\n", + "batch, indices = sample_buf.sample(batch_size=8)\n", + "print(f\"Sampled batch size: {len(batch)}\")\n", + "print(f\"Sampled indices: {indices}\")\n", + "print(f\"Sampled observations: {batch.obs}\")\n", + "\n", + "# batch_size=None: return all data in random order\n", + "all_data, all_indices = sample_buf.sample(batch_size=None)\n", + "print(f\"\\nSample all (batch_size=None): {len(all_data)} transitions\")\n", + "\n", + "# batch_size=0: return all data in buffer order\n", + "ordered_data, ordered_indices = sample_buf.sample(batch_size=0)\n", + "print(f\"Get all in order (batch_size=0): {len(ordered_data)} transitions\")\n", + "print(f\"Indices in order: {ordered_indices[:10]}...\") # Show first 10" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Sampling Behavior Summary**:\n", + "\n", + "- `batch_size > 0`: Random sample of specified size\n", + "- `batch_size = None`: All data in random order \n", + "- `batch_size = 0`: All data in insertion order\n", + "- `batch_size < 0`: Empty array (edge case handling)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5.2 Frame Stacking\n", + "\n", + "The `stack_num` parameter enables automatic frame stacking, useful for RNN inputs or Atari-style environments where temporal context matters:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create buffer 
with frame stacking\n", + "stack_buf = ReplayBuffer(size=20, stack_num=4)\n", + "\n", + "# Add observations: 0, 1, 2, ..., 9\n", + "for i in range(10):\n", + " stack_buf.add(\n", + " Batch(\n", + " obs=np.array([i]), # Single frame\n", + " act=0,\n", + " rew=1.0,\n", + " terminated=i == 9,\n", + " truncated=False,\n", + " obs_next=np.array([i + 1]),\n", + " info={},\n", + " )\n", + " )\n", + "\n", + "# Get stacked frames for index 6\n", + "# Should return [3, 4, 5, 6] (4 consecutive frames ending at 6)\n", + "stacked = stack_buf.get(index=6, key=\"obs\")\n", + "print(\"Frame stacking demo:\")\n", + "print(\"Requested index: 6\")\n", + "print(f\"Stacked frames shape: {stacked.shape}\")\n", + "print(f\"Stacked frames: {stacked.flatten()}\")\n", + "print(\"\\nExplanation: stack_num=4, so index 6 returns [obs[3], obs[4], obs[5], obs[6]]\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Demonstrate episode boundary handling with frame stacking\n", + "boundary_buf = ReplayBuffer(size=20, stack_num=4)\n", + "\n", + "# Episode 1: indices 0-4\n", + "for i in range(5):\n", + " boundary_buf.add(\n", + " Batch(\n", + " obs=np.array([i]),\n", + " act=0,\n", + " rew=1.0,\n", + " terminated=i == 4,\n", + " truncated=False,\n", + " obs_next=np.array([i + 1]),\n", + " info={},\n", + " )\n", + " )\n", + "\n", + "# Episode 2: indices 5-9\n", + "for i in range(5, 10):\n", + " boundary_buf.add(\n", + " Batch(\n", + " obs=np.array([i]),\n", + " act=0,\n", + " rew=1.0,\n", + " terminated=i == 9,\n", + " truncated=False,\n", + " obs_next=np.array([i + 1]),\n", + " info={},\n", + " )\n", + " )\n", + "\n", + "# Try to get stacked frames at episode boundary\n", + "boundary_stack = boundary_buf.get(index=6, key=\"obs\") # Early in episode 2\n", + "print(\"\\nFrame stacking at episode boundary:\")\n", + "print(f\"Index 6 stacked frames: {boundary_stack.flatten()}\")\n", + "print(\"Notice: Frames don't cross episode boundary (5,5,5,6 not 3,4,5,6)\")\n", + "print(\"The buffer uses prev() internally, which respects episode boundaries\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Frame Stacking Use Cases**:\n", + "\n", + "- **RNN/LSTM inputs**: Provide temporal context to recurrent networks\n", + "- **Atari games**: Stack 4 frames to capture motion (as in DQN paper)\n", + "- **Velocity estimation**: Multiple frames allow computing derivatives\n", + "- **Partially observable environments**: Build up state estimates\n", + "\n", + "**Important Notes**:\n", + "- Frame stacking respects episode boundaries (won't stack across episodes)\n", + "- Set `sample_avail=True` to only sample indices where full stacks are available\n", + "- `save_only_last_obs=True` saves memory in Atari-style setups" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. VectorReplayBuffer: Parallel Environment Support\n", + "\n", + "VectorReplayBuffer is essential for modern RL training with parallel environments. 
It maintains separate subbuffers for each environment while providing a unified interface.\n", + "\n", + "### 6.1 Motivation and Architecture\n", + "\n", + "When training with multiple parallel environments (e.g., 8 environments running simultaneously), we need:\n", + "- **Per-environment episode tracking**: Each environment has its own episode boundaries\n", + "- **Temporal ordering**: Preserve the sequence of events within each environment\n", + "- **Unified sampling**: Sample uniformly across all environments for training\n", + "\n", + "```mermaid\n", + "graph LR\n", + " E1[Env 1] --> B1[Subbuffer 1
2500 capacity]\n", + " E2[Env 2] --> B2[Subbuffer 2
2500 capacity]\n", + " E3[Env 3] --> B3[Subbuffer 3
2500 capacity]\n", + " E4[Env 4] --> B4[Subbuffer 4
2500 capacity]\n", + " \n", + " B1 --> VRB[VectorReplayBuffer
Total: 10000
Unified Sampling]\n", + " B2 --> VRB\n", + " B3 --> VRB\n", + " B4 --> VRB\n", + " \n", + " VRB --> Policy[Policy Training]\n", + " \n", + " style E1 fill:#e1f5ff\n", + " style E2 fill:#e1f5ff\n", + " style E3 fill:#e1f5ff\n", + " style E4 fill:#e1f5ff\n", + " style B1 fill:#fff4e1\n", + " style B2 fill:#fff4e1\n", + " style B3 fill:#fff4e1\n", + " style B4 fill:#fff4e1\n", + " style VRB fill:#ffe1f5\n", + " style Policy fill:#e8f5e1\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create VectorReplayBuffer for 4 parallel environments\n", + "vec_buf = VectorReplayBuffer(\n", + " total_size=100, # Total capacity across all subbuffers\n", + " buffer_num=4, # Number of parallel environments\n", + ")\n", + "\n", + "print(\"VectorReplayBuffer created:\")\n", + "print(f\"Total size: {vec_buf.maxsize}\")\n", + "print(f\"Number of subbuffers: {vec_buf.buffer_num}\")\n", + "print(f\"Size per subbuffer: {vec_buf.maxsize // vec_buf.buffer_num}\")\n", + "print(f\"Subbuffer edges: {vec_buf.subbuffer_edges}\")\n", + "print(\"\\nSubbuffer edges define the boundary indices: [0, 25, 50, 75, 100]\")\n", + "print(\"Subbuffer 0: indices 0-24, Subbuffer 1: indices 25-49, etc.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6.2 The buffer_ids Parameter\n", + "\n", + "This is one of the most confusing aspects for new users. The `buffer_ids` parameter specifies which subbuffer each transition belongs to." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Simulate data from 4 parallel environments\n", + "# Each environment produces one transition\n", + "parallel_batch = Batch(\n", + " obs=np.array([[0.1, 0.2], [1.1, 1.2], [2.1, 2.2], [3.1, 3.2]]), # 4 observations\n", + " act=np.array([0, 1, 0, 1]), # 4 actions\n", + " rew=np.array([1.0, 2.0, 3.0, 4.0]), # 4 rewards\n", + " terminated=np.array([False, False, False, False]),\n", + " truncated=np.array([False, False, False, False]),\n", + " obs_next=np.array([[0.2, 0.3], [1.2, 1.3], [2.2, 2.3], [3.2, 3.3]]),\n", + " info=np.array([{}, {}, {}, {}], dtype=object),\n", + ")\n", + "\n", + "print(\"Parallel batch shape:\", parallel_batch.obs.shape)\n", + "print(\"This represents 4 transitions, one from each environment\")\n", + "\n", + "# Add with buffer_ids specifying which subbuffer each transition goes to\n", + "indices, ep_rews, ep_lens, ep_starts = vec_buf.add(\n", + " parallel_batch,\n", + " buffer_ids=[0, 1, 2, 3], # Transition 0→Subbuf 0, 1→Subbuf 1, etc.\n", + ")\n", + "\n", + "print(f\"\\nAdded to indices: {indices}\")\n", + "print(\"Notice: Indices are in different subbuffers:\")\n", + "print(f\" Index {indices[0]} in subbuffer 0 (range 0-24)\")\n", + "print(f\" Index {indices[1]} in subbuffer 1 (range 25-49)\")\n", + "print(f\" Index {indices[2]} in subbuffer 2 (range 50-74)\")\n", + "print(f\" Index {indices[3]} in subbuffer 3 (range 75-99)\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Add more data to demonstrate buffer_ids\n", + "# Environments don't always produce data in order 0,1,2,3\n", + "# For example, if only environments 1 and 3 are ready:\n", + "partial_batch = Batch(\n", + " obs=np.array([[1.2, 1.3], [3.2, 3.3]]), # Only 2 observations\n", + " act=np.array([0, 1]),\n", + " rew=np.array([2.5, 4.5]),\n", + " terminated=np.array([False, False]),\n", + " truncated=np.array([False, False]),\n", + " 
obs_next=np.array([[1.3, 1.4], [3.3, 3.4]]),\n", + " info=np.array([{}, {}], dtype=object),\n", + ")\n", + "\n", + "# Only environments 1 and 3 produced data\n", + "indices2, _, _, _ = vec_buf.add(\n", + " partial_batch,\n", + " buffer_ids=[1, 3], # Only these two subbuffers receive data\n", + ")\n", + "\n", + "print(\"Added partial batch (only envs 1 and 3):\")\n", + "print(f\"Indices: {indices2}\")\n", + "print(f\"Subbuffer 1 received data at index {indices2[0]}\")\n", + "print(f\"Subbuffer 3 received data at index {indices2[1]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Important: buffer_ids Requirements**:\n", + "\n", + "For `VectorReplayBuffer`:\n", + "- `buffer_ids` length must match batch size\n", + "- Values must be in range [0, buffer_num)\n", + "- Can be partial (not all environments at once)\n", + "\n", + "For regular `ReplayBuffer`:\n", + "- If `buffer_ids` is not None, it must be [0]\n", + "- Batch must have shape (1, data_length)\n", + "- This is for API compatibility with VectorReplayBuffer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6.3 Subbuffer Edges and Episode Handling\n", + "\n", + "Subbuffer edges prevent episodes from spanning across subbuffers, ensuring data from different environments doesn't get mixed:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The subbuffer_edges property defines boundaries\n", + "print(f\"Subbuffer edges: {vec_buf.subbuffer_edges}\")\n", + "print(\"\\nThis creates 4 subbuffers:\")\n", + "for i in range(vec_buf.buffer_num):\n", + " start = vec_buf.subbuffer_edges[i]\n", + " end = vec_buf.subbuffer_edges[i + 1]\n", + " print(f\"Subbuffer {i}: indices [{start}, {end})\")\n", + "\n", + "# Episodes cannot cross these boundaries\n", + "# prev() and next() respect subbuffer edges just like episode boundaries\n", + "test_idx = np.array([24, 25, 49, 50]) # At subbuffer edges\n", + "prev_result = vec_buf.prev(test_idx)\n", + "next_result = vec_buf.next(test_idx)\n", + "\n", + "print(\"\\nBoundary navigation test:\")\n", + "print(f\"Indices: {test_idx}\")\n", + "print(f\"prev(): {prev_result}\")\n", + "print(f\"next(): {next_result}\")\n", + "print(\"\\nNotice: prev/next don't cross subbuffer boundaries\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6.4 Sampling from VectorReplayBuffer\n", + "\n", + "Sampling is uniform across all subbuffers (proportional to their current fill level):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Add more data to have enough for sampling\n", + "for _step in range(10):\n", + " batch = Batch(\n", + " obs=np.random.randn(4, 2),\n", + " act=np.random.randint(0, 2, size=4),\n", + " rew=np.random.random(4),\n", + " terminated=np.zeros(4, dtype=bool),\n", + " truncated=np.zeros(4, dtype=bool),\n", + " obs_next=np.random.randn(4, 2),\n", + " info=np.array([{}] * 4, dtype=object),\n", + " )\n", + " vec_buf.add(batch, buffer_ids=[0, 1, 2, 3])\n", + "\n", + "# Sample batch\n", + "sampled, indices = vec_buf.sample(batch_size=16)\n", + "print(f\"Sampled {len(sampled)} transitions\")\n", + "print(f\"Sample indices (from different subbuffers): {indices}\")\n", + "print(\"\\nNotice indices span across all subbuffer ranges\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. 
Specialized Buffer Variants\n", + "\n", + "### 7.1 PrioritizedReplayBuffer\n", + "\n", + "Implements prioritized experience replay where transitions are sampled based on their TD-error magnitudes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "# Create prioritized buffer\nprio_buf = PrioritizedReplayBuffer(\n size=100,\n alpha=0.6, # Prioritization exponent (0=uniform, 1=fully prioritized)\n beta=0.4, # Importance sampling correction (annealed to 1)\n)\n\n# Add some transitions\nfor i in range(20):\n prio_buf.add(\n Batch(\n obs=np.array([i]),\n act=i % 4,\n rew=np.random.random(),\n terminated=False,\n truncated=False,\n obs_next=np.array([i + 1]),\n info={},\n )\n )\n\n# Sample returns batch and indices\n# Importance weights are INSIDE the batch as batch.weight\nbatch, indices = prio_buf.sample(batch_size=8)\nprint(f\"Sampled batch size: {len(batch)}\")\nprint(f\"Indices: {indices}\")\nprint(f\"Importance weights (batch.weight): {batch.weight}\")\nprint(\"\\nWeights are stored in batch.weight and compensate for biased sampling\")" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# After computing TD-errors from the sampled batch, update priorities\n", + "# In practice, these would be actual TD-errors: |Q(s,a) - (r + γ*max Q(s',a'))|\n", + "fake_td_errors = np.random.random(len(indices)) * 10 # Simulated TD-errors\n", + "\n", + "# Update priorities (higher TD-error = higher priority)\n", + "prio_buf.update_weight(indices, fake_td_errors)\n", + "\n", + "print(\"Updated priorities based on TD-errors\")\n", + "print(\"Transitions with higher TD-errors will be sampled more frequently\")\n", + "\n", + "# Demonstrate beta annealing\n", + "prio_buf.set_beta(0.6) # Increase beta over training\n", + "print(f\"\\nAnnealed beta to: {prio_buf.options['beta']}\")\n", + "print(\"Beta typically starts at 0.4 and anneals to 1.0 over training\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**PrioritizedReplayBuffer Use Cases**:\n", + "- Rainbow DQN and variants\n", + "- Any algorithm where some transitions are more \"surprising\" and valuable\n", + "- Environments with rare but important events\n", + "\n", + "**Key Parameters**:\n", + "- `alpha`: Controls how much prioritization affects sampling (0=uniform, 1=fully proportional to priority)\n", + "- `beta`: Importance sampling correction to remain unbiased (anneal from ~0.4 to 1.0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 7.2 Other Specialized Buffers\n", + "\n", + "**CachedReplayBuffer**: Maintains a primary buffer plus auxiliary caches\n", + "- Use case: Imitation learning where you want separate expert and agent buffers\n", + "- Example: GAIL (Generative Adversarial Imitation Learning)\n", + "- Allows different sampling ratios from different sources\n", + "\n", + "**HERReplayBuffer**: Hindsight Experience Replay for goal-conditioned tasks\n", + "- Use case: Sparse reward robotics tasks\n", + "- Relabels failed episodes with achieved goals as if they were intended\n", + "- Dramatically improves learning in goal-reaching tasks\n", + "- See the HER documentation for detailed examples\n", + "\n", + "For detailed usage of these specialized buffers, refer to the Tianshou API documentation and algorithm-specific tutorials." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8. 
Serialization and Persistence\n", + "\n", + "Buffers support multiple serialization formats for saving and loading data.\n", + "\n", + "### 8.1 Pickle Serialization\n", + "\n", + "The simplest method, preserving all buffer state including trajectory metadata:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create and populate a buffer\n", + "save_buf = ReplayBuffer(size=50)\n", + "for i in range(30):\n", + " save_buf.add(\n", + " Batch(\n", + " obs=np.array([i, i + 1]),\n", + " act=i % 4,\n", + " rew=float(i),\n", + " terminated=(i + 1) % 10 == 0,\n", + " truncated=False,\n", + " obs_next=np.array([i + 1, i + 2]),\n", + " info={\"step\": i},\n", + " )\n", + " )\n", + "\n", + "print(f\"Original buffer: {len(save_buf)} transitions\")\n", + "\n", + "# Serialize with pickle\n", + "pickled_data = pickle.dumps(save_buf)\n", + "print(f\"Serialized size: {len(pickled_data)} bytes\")\n", + "\n", + "# Deserialize\n", + "loaded_buf = pickle.loads(pickled_data)\n", + "print(f\"Loaded buffer: {len(loaded_buf)} transitions\")\n", + "print(f\"Data preserved: obs[0] = {loaded_buf.obs[0]}\")\n", + "print(f\"Metadata preserved: info[0] = {loaded_buf.info[0]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 8.2 HDF5 Serialization\n", + "\n", + "HDF5 is recommended for large datasets and cross-platform compatibility:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Save to HDF5\n", + "with tempfile.NamedTemporaryFile(suffix=\".hdf5\", delete=False) as tmp:\n", + " hdf5_path = tmp.name\n", + "\n", + "save_buf.save_hdf5(hdf5_path, compression=\"gzip\")\n", + "print(f\"Saved to HDF5: {hdf5_path}\")\n", + "\n", + "# Load from HDF5\n", + "loaded_hdf5_buf = ReplayBuffer.load_hdf5(hdf5_path)\n", + "print(f\"Loaded from HDF5: {len(loaded_hdf5_buf)} transitions\")\n", + "print(f\"Data matches: {np.array_equal(save_buf.obs, loaded_hdf5_buf.obs)}\")\n", + "\n", + "# Clean up\n", + "import os\n", + "\n", + "os.unlink(hdf5_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**When to Use HDF5**:\n", + "- Large datasets (> 1GB)\n", + "- Offline RL with pre-collected data\n", + "- Sharing data across platforms\n", + "- Need for compression\n", + "- Integration with external tools (many scientific tools read HDF5)\n", + "\n", + "**When to Use Pickle**:\n", + "- Quick saves during development\n", + "- Small buffers\n", + "- Python-only workflow\n", + "- Simpler serialization needs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 8.3 Loading from Raw Data with from_data()\n", + "\n", + "For offline RL, you can create a buffer from raw arrays:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Simulate pre-collected offline dataset\n", + "import h5py\n", + "\n", + "# Create temporary HDF5 file with raw data\n", + "with tempfile.NamedTemporaryFile(suffix=\".hdf5\", delete=False) as tmp:\n", + " offline_path = tmp.name\n", + "\n", + "with h5py.File(offline_path, \"w\") as f:\n", + " # Create datasets\n", + " n = 100\n", + " f.create_dataset(\"obs\", data=np.random.randn(n, 4))\n", + " f.create_dataset(\"act\", data=np.random.randint(0, 2, n))\n", + " f.create_dataset(\"rew\", data=np.random.randn(n))\n", + " f.create_dataset(\"terminated\", data=np.random.random(n) < 0.1)\n", + " f.create_dataset(\"truncated\", data=np.zeros(n, 
dtype=bool))\n", + " f.create_dataset(\"done\", data=np.random.random(n) < 0.1)\n", + " f.create_dataset(\"obs_next\", data=np.random.randn(n, 4))\n", + "\n", + "# Load into buffer\n", + "with h5py.File(offline_path, \"r\") as f:\n", + " offline_buf = ReplayBuffer.from_data(\n", + " obs=f[\"obs\"],\n", + " act=f[\"act\"],\n", + " rew=f[\"rew\"],\n", + " terminated=f[\"terminated\"],\n", + " truncated=f[\"truncated\"],\n", + " done=f[\"done\"],\n", + " obs_next=f[\"obs_next\"],\n", + " )\n", + "\n", + "print(f\"Loaded offline dataset: {len(offline_buf)} transitions\")\n", + "print(f\"Observation shape: {offline_buf.obs.shape}\")\n", + "\n", + "# Clean up\n", + "os.unlink(offline_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is the standard approach for offline RL where you have pre-collected datasets from other sources." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 9. Integration with the RL Pipeline\n", + "\n", + "Understanding how buffers integrate with other Tianshou components is essential for effective usage.\n", + "\n", + "### 9.1 Data Flow in RL Training\n", + "\n", + "```mermaid\n", + "graph LR\n", + " ENV[Vectorized
Environments] -->|observations| COL[Collector]\n", + " POL[Policy] -->|actions| COL\n", + " COL -->|transitions| BUF[Buffer]\n", + " BUF -->|sampled batches| POL\n", + " POL -->|forward pass| ALG[Algorithm]\n", + " ALG -->|loss & gradients| POL\n", + " \n", + " style ENV fill:#e1f5ff\n", + " style COL fill:#fff4e1\n", + " style BUF fill:#ffe1f5\n", + " style POL fill:#e8f5e1\n", + " style ALG fill:#f5e1e1\n", + "```\n", + "\n", + "### 9.2 Typical Training Loop Pattern\n", + "\n", + "Here's how buffers are typically used in a training loop:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Pseudocode for typical RL training loop\n", + "# (This is illustrative; actual implementation would use Trainer)\n", + "\n", + "\n", + "def training_loop_pseudocode():\n", + " \"\"\"\n", + " Illustrative training loop showing buffer integration.\n", + "\n", + " In practice, use Tianshou's Trainer class which handles this.\n", + " \"\"\"\n", + " # Setup (illustration only)\n", + " # env = make_vectorized_env(num_envs=8)\n", + " # policy = make_policy()\n", + " # buffer = VectorReplayBuffer(total_size=100000, buffer_num=8)\n", + " # collector = Collector(policy, env, buffer)\n", + "\n", + " # Training loop\n", + " # for epoch in range(num_epochs):\n", + " # # 1. Collect data from environments\n", + " # collect_result = collector.collect(n_step=1000)\n", + " # # Collector automatically adds transitions to buffer with correct buffer_ids\n", + " #\n", + " # # 2. Train on multiple batches\n", + " # for _ in range(update_per_collect):\n", + " # # Sample batch from buffer\n", + " # batch, indices = buffer.sample(batch_size=256)\n", + " #\n", + " # # Compute loss and update policy\n", + " # loss = policy.learn(batch)\n", + " #\n", + " # # For prioritized buffers, update priorities\n", + " # # if isinstance(buffer, PrioritizedReplayBuffer):\n", + " # # buffer.update_weight(indices, td_errors)\n", + "\n", + " print(\"This pseudocode illustrates the buffer's role:\")\n", + " print(\"1. Collector fills buffer from environment interaction\")\n", + " print(\"2. Buffer provides random samples for training\")\n", + " print(\"3. Policy learns from sampled batches\")\n", + " print(\"\\nIn practice, use Tianshou's Trainer for this workflow\")\n", + "\n", + "\n", + "training_loop_pseudocode()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 9.3 Collector Integration\n", + "\n", + "The Collector class handles the complexity of:\n", + "- Calling policy to get actions\n", + "- Stepping environments\n", + "- Adding transitions to buffer with correct buffer_ids\n", + "- Tracking episode statistics\n", + "\n", + "When you create a Collector, you pass it a buffer, and it automatically:\n", + "- Uses VectorReplayBuffer for vectorized environments\n", + "- Sets buffer_ids based on which environments are ready\n", + "- Handles episode resets and boundary tracking\n", + "\n", + "See the Collector tutorial for detailed examples of this integration." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 10. Advanced Topics and Edge Cases\n", + "\n", + "### 10.1 Buffer Overflow and Episode Boundaries\n", + "\n", + "What happens when the buffer fills up mid-episode?" 
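In short: stored transitions are overwritten, but the episode statistics keep accumulating, because the running return and length are tracked separately from the storage slots. The following rough sketch conveys the idea with hypothetical names (not the buffer's real attributes); the cell below then shows the actual behavior.

```python
# Rough model of overflow-safe episode bookkeeping (illustration only).
maxsize = 8
slots = [None] * maxsize
running_return, running_len = 0.0, 0

for i in range(12):            # a 12-step episode in an 8-slot buffer
    slots[i % maxsize] = i     # old transitions get overwritten...
    running_return += 1.0      # ...but the running episode return keeps growing
    running_len += 1

print(slots)                        # only the last 8 transitions survive
print(running_return, running_len)  # 12.0 12 -> reported once the episode ends
```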
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Small buffer to demonstrate overflow\n", + "overflow_buf = ReplayBuffer(size=8)\n", + "\n", + "# Add a long episode (12 steps, buffer size is only 8)\n", + "print(\"Adding 12-step episode to buffer with size 8:\")\n", + "for i in range(12):\n", + " idx, ep_rew, ep_len, ep_start = overflow_buf.add(\n", + " Batch(\n", + " obs=i,\n", + " act=0,\n", + " rew=1.0,\n", + " terminated=i == 11,\n", + " truncated=False,\n", + " obs_next=i + 1,\n", + " info={},\n", + " )\n", + " )\n", + " if i in [7, 11]:\n", + " print(f\" Step {i}: idx={idx}, buffer_len={len(overflow_buf)}\")\n", + "\n", + "print(\"\\nFinal buffer contents (most recent 8 steps):\")\n", + "print(f\"Observations: {overflow_buf.obs[: len(overflow_buf)]}\")\n", + "print(f\"Episode return: {ep_rew[0]} (sum of all 12 steps, tracked correctly!)\")\n", + "print(\"\\nNote: Buffer overwrote old data but episode statistics are still correct\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Important**: Episode returns and lengths are tracked internally and remain correct even when the episode spans buffer overflows. The buffer maintains `_ep_return`, `_ep_len`, and `_ep_start_idx` to track ongoing episodes." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 10.2 Episode Spanning Subbuffer Edges\n", + "\n", + "In VectorReplayBuffer, episodes can wrap around within their subbuffer:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create small VectorReplayBuffer to demonstrate edge crossing\n", + "edge_buf = VectorReplayBuffer(total_size=20, buffer_num=2) # 10 per subbuffer\n", + "\n", + "print(f\"Subbuffer edges: {edge_buf.subbuffer_edges}\")\n", + "print(\"Subbuffer 0: indices 0-9, Subbuffer 1: indices 10-19\\n\")\n", + "\n", + "# Fill subbuffer 0 with 12 steps (wraps around since capacity is 10)\n", + "for i in range(12):\n", + " batch = Batch(\n", + " obs=np.array([[i]]),\n", + " act=np.array([0]),\n", + " rew=np.array([1.0]),\n", + " terminated=np.array([i == 11]),\n", + " truncated=np.array([False]),\n", + " obs_next=np.array([[i + 1]]),\n", + " info=np.array([{}], dtype=object),\n", + " )\n", + " idx, _, _, _ = edge_buf.add(batch, buffer_ids=[0])\n", + " if i >= 10:\n", + " print(f\"Step {i} added at index {idx[0]} (wrapped around in subbuffer 0)\")\n", + "\n", + "# get_buffer_indices handles this correctly\n", + "episode_indices = edge_buf.get_buffer_indices(start=8, stop=2) # Crosses edge\n", + "print(f\"\\nEpisode spanning edge (from 8 to 1): {episode_indices}\")\n", + "print(\"Correctly retrieves [8, 9, 0, 1] within subbuffer 0\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 10.3 ignore_obs_next Memory Optimization\n", + "\n", + "For memory-constrained scenarios, you can avoid storing obs_next:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Buffer that doesn't store obs_next\n", + "memory_buf = ReplayBuffer(size=10, ignore_obs_next=True)\n", + "\n", + "# Add transitions (obs_next is ignored)\n", + "for i in range(5):\n", + " memory_buf.add(\n", + " Batch(\n", + " obs=np.array([i, i + 1]),\n", + " act=i,\n", + " rew=1.0,\n", + " terminated=False,\n", + " truncated=False,\n", + " obs_next=np.array([i + 1, i + 2]), # Provided but not stored\n", + " info={},\n", + " )\n", + " )\n", + "\n", + "# 
When sampling, obs_next is reconstructed from next obs\n", + "sample, _ = memory_buf.sample(batch_size=1)\n", + "print(f\"Sampled obs: {sample.obs}\")\n", + "print(f\"Sampled obs_next: {sample.obs_next}\")\n", + "print(\"\\nobs_next was reconstructed, not stored directly\")\n", + "print(\"This saves memory at the cost of slightly more complex retrieval\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is particularly useful for Atari environments with large observation spaces (84x84x4 frames)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 11. Surprising Behaviors and Gotchas\n", + "\n", + "### 11.1 Most Common Mistake: buffer_ids Confusion\n", + "\n", + "The buffer_ids parameter is the most common source of errors:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# COMMON ERROR 1: Forgetting buffer_ids with VectorReplayBuffer\n", + "vec_demo = VectorReplayBuffer(total_size=100, buffer_num=4)\n", + "\n", + "parallel_data = Batch(\n", + " obs=np.random.randn(4, 2),\n", + " act=np.array([0, 1, 0, 1]),\n", + " rew=np.array([1.0, 2.0, 3.0, 4.0]),\n", + " terminated=np.array([False, False, False, False]),\n", + " truncated=np.array([False, False, False, False]),\n", + " obs_next=np.random.randn(4, 2),\n", + " info=np.array([{}, {}, {}, {}], dtype=object),\n", + ")\n", + "\n", + "# WRONG: Omitting buffer_ids (defaults to [0,1,2,3] which is OK here)\n", + "# But if you have partial data, this will fail\n", + "vec_demo.add(parallel_data) # Works by default\n", + "\n", + "# CORRECT: Always explicit\n", + "vec_demo.add(parallel_data, buffer_ids=[0, 1, 2, 3])\n", + "print(\"Always specify buffer_ids explicitly for clarity\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# COMMON ERROR 2: Shape mismatch with buffer_ids\n", + "try:\n", + " # Trying to add 2 transitions but specifying 4 buffer_ids\n", + " wrong_batch = Batch(\n", + " obs=np.random.randn(2, 2), # Only 2 transitions!\n", + " act=np.array([0, 1]),\n", + " rew=np.array([1.0, 2.0]),\n", + " terminated=np.array([False, False]),\n", + " truncated=np.array([False, False]),\n", + " obs_next=np.random.randn(2, 2),\n", + " info=np.array([{}, {}], dtype=object),\n", + " )\n", + " vec_demo.add(wrong_batch, buffer_ids=[0, 1, 2, 3]) # MISMATCH!\n", + "except (IndexError, ValueError) as e:\n", + " print(f\"Error caught: {type(e).__name__}\")\n", + " print(\"Lesson: buffer_ids length must match batch size\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 11.2 Done Flag Confusion\n", + "\n", + "Never manually set the `done` flag:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# WRONG: Manually setting done\n", + "wrong_batch = Batch(\n", + " obs=1,\n", + " act=0,\n", + " rew=1.0,\n", + " terminated=True,\n", + " truncated=False,\n", + " # done=True, # DON'T DO THIS! 
It will be overwritten anyway\n", + " obs_next=2,\n", + " info={},\n", + ")\n", + "\n", + "# CORRECT: Only set terminated and truncated\n", + "# done is automatically computed as (terminated OR truncated)\n", + "correct_batch = Batch(\n", + " obs=1,\n", + " act=0,\n", + " rew=1.0,\n", + " terminated=True, # Episode ended naturally\n", + " truncated=False, # Not cut off\n", + " obs_next=2,\n", + " info={},\n", + ")\n", + "\n", + "demo = ReplayBuffer(size=10)\n", + "demo.add(correct_batch)\n", + "print(f\"Terminated: {demo.terminated[0]}\")\n", + "print(f\"Truncated: {demo.truncated[0]}\")\n", + "print(f\"Done (auto-computed): {demo.done[0]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 11.3 Sampling from Empty or Near-Empty Buffers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Edge case: Sampling more than available\n", + "small_buf = ReplayBuffer(size=100)\n", + "for i in range(5): # Only 5 transitions\n", + " small_buf.add(\n", + " Batch(obs=i, act=0, rew=1.0, terminated=False, truncated=False, obs_next=i + 1, info={})\n", + " )\n", + "\n", + "# Request 20 but only 5 available - samples with replacement\n", + "batch, indices = small_buf.sample(batch_size=20)\n", + "print(f\"Requested 20, buffer has {len(small_buf)}, got {len(batch)}\")\n", + "print(f\"Indices: {indices}\")\n", + "print(\"Notice: Some indices repeat (sampling with replacement)\")\n", + "\n", + "# Defensive pattern: Check buffer size\n", + "if len(small_buf) >= 128:\n", + " batch, _ = small_buf.sample(128)\n", + "else:\n", + " print(f\"Buffer has {len(small_buf)} < 128, waiting for more data\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 11.4 Frame Stacking Valid Indices\n", + "\n", + "With stack_num > 1, not all indices are valid for sampling:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# With frame stacking, early indices can't form complete stacks\n", + "stack_demo = ReplayBuffer(size=20, stack_num=4, sample_avail=True)\n", + "\n", + "for i in range(10):\n", + " stack_demo.add(\n", + " Batch(\n", + " obs=np.array([i]),\n", + " act=0,\n", + " rew=1.0,\n", + " terminated=i == 9,\n", + " truncated=False,\n", + " obs_next=np.array([i + 1]),\n", + " info={},\n", + " )\n", + " )\n", + "\n", + "# With sample_avail=True, only valid indices are sampled\n", + "sampled, indices = stack_demo.sample(batch_size=5)\n", + "print(f\"Sampled indices with stack_num=4, sample_avail=True: {indices}\")\n", + "print(\"All indices >= 3 (can form complete 4-frame stacks)\")\n", + "\n", + "# Without sample_avail, any index can be sampled (may have incomplete stacks)\n", + "stack_demo2 = ReplayBuffer(size=20, stack_num=4, sample_avail=False)\n", + "for i in range(10):\n", + " stack_demo2.add(\n", + " Batch(\n", + " obs=np.array([i]),\n", + " act=0,\n", + " rew=1.0,\n", + " terminated=False,\n", + " truncated=False,\n", + " obs_next=np.array([i + 1]),\n", + " info={},\n", + " )\n", + " )\n", + "\n", + "sampled2, indices2 = stack_demo2.sample(batch_size=5)\n", + "print(f\"\\nSampled indices with sample_avail=False: {indices2}\")\n", + "print(\"May include indices < 3 (incomplete stacks repeated from boundary)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 12. Best Practices\n", + "\n", + "### 12.1 Choosing the Right Buffer\n", + "\n", + "**Decision Tree**:\n", + "\n", + "1. 
Are you using parallel environments?\n", + " - Yes → Use `VectorReplayBuffer`\n", + " - No → Continue to 2\n", + "\n", + "2. Do you need prioritized experience replay?\n", + " - Yes → Use `PrioritizedReplayBuffer` or `PrioritizedVectorReplayBuffer`\n", + " - No → Continue to 3\n", + "\n", + "3. Is it goal-conditioned RL with sparse rewards?\n", + " - Yes → Use `HERReplayBuffer` or `HERVectorReplayBuffer`\n", + " - No → Continue to 4\n", + "\n", + "4. Do you need separate expert and agent buffers?\n", + " - Yes → Use `CachedReplayBuffer`\n", + " - No → Use `ReplayBuffer` (single env) or `VectorReplayBuffer` (standard choice)\n", + "\n", + "**Most Common Setup**: `VectorReplayBuffer` for production training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 12.2 Buffer Sizing Guidelines\n", + "\n", + "**Rule of Thumb by Domain**:\n", + "\n", + "- **Atari games**: 1,000,000 transitions (1e6)\n", + "- **Continuous control (MuJoCo)**: 100,000-1,000,000 (1e5-1e6)\n", + "- **Robotics**: 100,000-500,000 (1e5-5e5)\n", + "- **Simple environments (CartPole)**: 10,000-50,000 (1e4-5e4)\n", + "\n", + "**Factors to Consider**:\n", + "- Available RAM (each transition ~observation_size * 2 + metadata)\n", + "- Training time vs sample efficiency tradeoff\n", + "- Algorithm requirements (some need larger buffers)\n", + "\n", + "**Memory Estimation**:\n", + "```python\n", + "# For environments with observation shape (84, 84, 4) (Atari):\n", + "# Each transition: 2 * 84 * 84 * 4 bytes (obs + obs_next) + ~100 bytes overhead\n", + "# = ~56KB per transition\n", + "# 1M transitions = ~56GB (use ignore_obs_next to halve this!)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 12.3 Configuration Best Practices\n", + "\n", + "**When to use stack_num > 1**:\n", + "- RNN/LSTM policies need temporal context\n", + "- Frame-based policies (Atari with 4-frame stacking)\n", + "- Velocity estimation from positions\n", + "\n", + "**When to use ignore_obs_next=True**:\n", + "- Memory-constrained environments\n", + "- Atari (large observation spaces)\n", + "- When obs_next can be reconstructed from next obs\n", + "\n", + "**When to use save_only_last_obs=True**:\n", + "- Atari with temporal stacking in environment wrapper\n", + "- When observations already contain frame history\n", + "\n", + "**When to use sample_avail=True**:\n", + "- Always use with stack_num > 1 for correctness\n", + "- Ensures samples have complete frame stacks\n", + "- Small performance cost but worth it for data quality" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 12.4 Integration Patterns\n", + "\n", + "**Pattern 1: Standard Off-Policy Setup**\n", + "```python\n", + "# env = make_vectorized_env(num_envs=8)\n", + "# buffer = VectorReplayBuffer(total_size=100000, buffer_num=8)\n", + "# policy = SACPolicy(...)\n", + "# collector = Collector(policy, env, buffer)\n", + "# \n", + "# # Collect and train\n", + "# collector.collect(n_step=1000)\n", + "# for _ in range(10):\n", + "# batch, indices = buffer.sample(256)\n", + "# policy.learn(batch)\n", + "```\n", + "\n", + "**Pattern 2: Pre-fill Buffer Before Training**\n", + "```python\n", + "# # Collect random exploration data\n", + "# collector.collect(n_step=10000) # Fill buffer\n", + "# \n", + "# # Then start training\n", + "# while not converged:\n", + "# collector.collect(n_step=100)\n", + "# for _ in range(10):\n", + "# batch = buffer.sample(256)\n", + "# policy.learn(batch)\n", + "```\n", + "\n", + 
"**Pattern 3: Offline RL**\n", + "```python\n", + "# # Load pre-collected dataset\n", + "# buffer = ReplayBuffer.load_hdf5(\"expert_data.hdf5\")\n", + "# \n", + "# # Train without further collection\n", + "# for epoch in range(num_epochs):\n", + "# for _ in range(updates_per_epoch):\n", + "# batch = buffer.sample(256)\n", + "# policy.learn(batch)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 12.5 Performance Tips\n", + "\n", + "**Tip 1: Pre-allocate buffer size appropriately**\n", + "- Don't make buffer too large (wastes memory)\n", + "- Don't make it too small (loses important old experiences)\n", + "- Start with domain defaults and adjust based on performance\n", + "\n", + "**Tip 2: Use HDF5 for large offline datasets**\n", + "- Compression saves disk space\n", + "- Faster loading than pickle for large files\n", + "- Better for sharing across systems\n", + "\n", + "**Tip 3: Batch sampling efficiently**\n", + "- Sample once and use multiple times if possible\n", + "- Don't sample more than you need\n", + "- For multi-GPU training, sample once and split\n", + "\n", + "**Tip 4: Monitor buffer usage**\n", + "```python\n", + "# print(f\"Buffer usage: {len(buffer)}/{buffer.maxsize}\")\n", + "# if len(buffer) < batch_size:\n", + "# print(\"Warning: Sampling with replacement!\")\n", + "```\n", + "\n", + "**Tip 5: Consider ignore_obs_next for large observation spaces**\n", + "- Can halve memory usage\n", + "- Small computational overhead on sampling\n", + "- Especially valuable for image-based RL" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## 13. Quick Reference\n\n### Method Summary\n\n| Method | Purpose | Returns | Notes |\n|--------|---------|---------|-------|\n| `add(batch, buffer_ids)` | Add transition(s) | `(idx, ep_rew, ep_len, ep_start)` | ep_rew/ep_len only non-zero when done=True |\n| `sample(size)` | Random sample | `(batch, indices)` | size=None for all (random), 0 for all (ordered) |\n| `prev(idx)` | Previous in episode | `indices` | Stops at episode boundaries |\n| `next(idx)` | Next in episode | `indices` | Stops at episode boundaries |\n| `get(idx, key, stack_num)` | Get with stacking | `data` | Returns stacked frames if stack_num > 1 |\n| `get_buffer_indices(start, stop)` | Episode range | `indices` | Handles edge-crossing episodes |\n| `unfinished_index()` | Ongoing episodes | `indices` | Returns last step of unfinished episodes |\n| `save_hdf5(path)` | Save to HDF5 | - | Recommended for large datasets |\n| `load_hdf5(path)` | Load from HDF5 | `buffer` | Class method |\n| `from_data(...)` | Create from arrays | `buffer` | For offline RL datasets |\n| `reset()` | Clear buffer | - | Optionally keep episode statistics |\n| `sample_indices(size)` | Get indices only | `indices` | For custom sampling logic |\n\n### Common Patterns Cheatsheet\n\n**Single Environment**:\n```python\nbuffer = ReplayBuffer(size=10000)\nbuffer.add(Batch(obs=..., act=..., rew=..., terminated=..., truncated=..., obs_next=..., info={}))\nbatch, indices = buffer.sample(batch_size=256)\n```\n\n**Parallel Environments**:\n```python\nbuffer = VectorReplayBuffer(total_size=100000, buffer_num=8)\nbuffer.add(parallel_batch, buffer_ids=[0,1,2,3,4,5,6,7])\nbatch, indices = buffer.sample(batch_size=256)\n```\n\n**Frame Stacking**:\n```python\nbuffer = ReplayBuffer(size=100000, stack_num=4, sample_avail=True)\nstacked_obs = buffer.get(index=50, key=\"obs\") # Returns 4 stacked frames\n```\n\n**Prioritized Replay**:\n```python\nbuffer = 
PrioritizedReplayBuffer(size=100000, alpha=0.6, beta=0.4)\nbatch, indices = buffer.sample(batch_size=256)\nweights = batch.weight # Importance weights are inside the batch\n# ... compute TD errors ...\nbuffer.update_weight(indices, td_errors)\n```\n\n**Offline RL**:\n```python\nbuffer = ReplayBuffer.load_hdf5(\"dataset.hdf5\")\n# Or:\nwith h5py.File(\"dataset.hdf5\", \"r\") as f:\n buffer = ReplayBuffer.from_data(obs=f[\"obs\"], act=f[\"act\"], ...)\n```\n\n**Episode Retrieval**:\n```python\n# Find episode boundaries, then:\nepisode_indices = buffer.get_buffer_indices(start=ep_start_idx, stop=ep_end_idx+1)\nepisode = buffer[episode_indices]\n```" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary and Next Steps\n", + "\n", + "This tutorial covered Tianshou's buffer system comprehensively:\n", + "\n", + "1. **Buffer fundamentals**: Why buffers are essential for RL\n", + "2. **Buffer hierarchy**: Understanding different buffer types\n", + "3. **Basic operations**: Construction, configuration, and data management\n", + "4. **Trajectory management**: Episode tracking and boundary navigation\n", + "5. **Sampling strategies**: Basic sampling and frame stacking\n", + "6. **VectorReplayBuffer**: Critical for parallel environments\n", + "7. **Specialized buffers**: Prioritized, cached, and HER variants\n", + "8. **Serialization**: Pickle and HDF5 persistence\n", + "9. **Integration**: How buffers fit in the RL pipeline\n", + "10. **Advanced topics**: Edge cases and overflow handling\n", + "11. **Gotchas**: Common mistakes and how to avoid them\n", + "12. **Best practices**: Configuration, sizing, and performance\n", + "13. **Quick reference**: Method summary and common patterns\n", + "\n", + "### Next Steps\n", + "\n", + "- **Collector Deep Dive**: Learn how Collector fills buffers from environments\n", + "- **Policy Tutorial**: Understand how policies sample from buffers for training\n", + "- **Algorithm Examples**: See buffer usage in specific algorithms (DQN, SAC, PPO)\n", + "- **API Reference**: Full details at [Buffer API documentation](https://tianshou.org/en/stable/api/tianshou.data.html)\n", + "\n", + "### Further Resources\n", + "\n", + "- [Tianshou GitHub](https://github.com/thu-ml/tianshou) for source code and examples\n", + "- [Gymnasium Documentation](https://gymnasium.farama.org/) for environment conventions\n", + "- Research papers on experience replay and prioritized sampling" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/02_deep_dives/L3_Environments.ipynb b/docs/02_deep_dives/L3_Environments.ipynb new file mode 100644 index 000000000..dafac3717 --- /dev/null +++ b/docs/02_deep_dives/L3_Environments.ipynb @@ -0,0 +1,709 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Environments\n", + "\n", + "In reinforcement learning, agents interact with environments to improve their performance through trial and error. This tutorial explores how Tianshou handles environments, from basic single-environment setups to advanced vectorized and parallel configurations.\n", + "\n", + "
\n", + "
\n", + "The agent-environment interaction loop\n", + "
\n", + "\n", + "Tianshou maintains full compatibility with the [Gymnasium](https://gymnasium.farama.org/) API (formerly OpenAI Gym), making it easy to use any Gymnasium-compatible environment." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The Bottleneck Problem\n", + "\n", + "In a standard Gymnasium environment, each interaction follows a sequential pattern:\n", + "\n", + "1. Agent selects an action\n", + "2. Environment processes the action and returns observation and reward\n", + "3. Repeat\n", + "\n", + "This sequential process can become a significant bottleneck in deep reinforcement learning experiments, especially when:\n", + "- The environment simulation is computationally intensive\n", + "- Network training is fast but data collection is slow\n", + "- You have multiple CPU cores available but aren't using them\n", + "\n", + "Tianshou addresses this bottleneck through **vectorized environments**, which allow parallel sampling across multiple CPU cores." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Vectorized Environments\n", + "\n", + "Vectorized environments enable you to run multiple environment instances in parallel, dramatically accelerating data collection. Let's see this in action." + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "import time\n", + "\n", + "import gymnasium as gym\n", + "import numpy as np\n", + "\n", + "from tianshou.env import DummyVectorEnv, SubprocVectorEnv" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Performance Comparison\n", + "\n", + "Let's compare the sampling speed with different numbers of parallel environments:" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "num_cpus = [1, 2, 5]\n", + "\n", + "for num_cpu in num_cpus:\n", + " # Create vectorized environment with multiple processes\n", + " env = SubprocVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(num_cpu)])\n", + " env.reset()\n", + "\n", + " sampled_steps = 0\n", + " time_start = time.time()\n", + "\n", + " # Sample 1000 steps\n", + " while sampled_steps < 1000:\n", + " act = np.random.choice(2, size=num_cpu)\n", + " obs, rew, terminated, truncated, info = env.step(act)\n", + "\n", + " # Reset terminated environments\n", + " if np.sum(terminated):\n", + " env.reset(np.where(terminated)[0])\n", + "\n", + " sampled_steps += num_cpu\n", + "\n", + " time_used = time.time() - time_start\n", + " print(f\"Sampled 1000 steps in {time_used:.3f}s using {num_cpu} CPU(s)\")\n", + " print(f\" → Speed: {1000 / time_used:.1f} steps/second\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Understanding the Results\n", + "\n", + "You might notice that the speedup isn't perfectly linear with the number of CPUs. Several factors contribute to this:\n", + "\n", + "1. **Straggler Effect**: In synchronous mode, all environments must complete before the next batch begins. Slower environments hold back faster ones.\n", + "2. **Communication Overhead**: Inter-process communication has costs, especially for fast environments.\n", + "3. **Environment Complexity**: For simple environments like CartPole, the overhead may outweigh the benefits.\n", + "\n", + "> **Important**: `SubprocVectorEnv` should only be used when environment execution is slow. 
For simple, fast environments like CartPole, `DummyVectorEnv` (or even raw Gymnasium environments) can be more efficient because they avoid both the straggler effect and inter-process communication overhead." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Types of Vectorized Environments\n", + "\n", + "Tianshou provides several vectorized environment implementations, each optimized for different scenarios:\n", + "\n", + "### 1. DummyVectorEnv\n", + "**Pseudo-parallel simulation using a for-loop**\n", + "- Best for: Simple/fast environments, debugging\n", + "- Pros: No overhead, deterministic execution\n", + "- Cons: No actual parallelization\n", + "\n", + "### 2. SubprocVectorEnv\n", + "**Multiple processes for true parallel simulation**\n", + "- Best for: Most parallel simulation scenarios\n", + "- Pros: True parallelization, good balance\n", + "- Cons: Inter-process communication overhead\n", + "\n", + "### 3. ShmemVectorEnv\n", + "**Shared memory optimization of SubprocVectorEnv**\n", + "- Best for: Environments with large observations (e.g., images)\n", + "- Pros: Reduced memory footprint, faster for large states\n", + "- Cons: More complex implementation\n", + "\n", + "### 4. RayVectorEnv\n", + "**Ray-based distributed simulation**\n", + "- Best for: Cluster computing with multiple machines\n", + "- Pros: Scales to multiple machines\n", + "- Cons: Requires Ray installation and setup\n", + "\n", + "All these classes share the same API through their base class `BaseVectorEnv`, making it easy to switch between them." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basic Usage\n", + "\n", + "### Creating a Vectorized Environment" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Standard Gymnasium environment\n", + "gym_env = gym.make(\"CartPole-v1\")\n", + "\n", + "\n", + "# Tianshou vectorized environment\n", + "def create_cartpole_env() -> gym.Env:\n", + " return gym.make(\"CartPole-v1\")\n", + "\n", + "\n", + "# Create 5 parallel environments\n", + "vector_env = DummyVectorEnv([create_cartpole_env for _ in range(5)])\n", + "\n", + "print(f\"Created vectorized environment with {vector_env.env_num} environments\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Environment Interaction\n", + "\n", + "The key difference from standard Gymnasium is that actions, observations, and rewards are all vectorized:" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Standard Gymnasium: reset() returns a single observation\n", + "print(\"Standard Gymnasium reset:\")\n", + "single_obs, info = gym_env.reset()\n", + "print(f\" Shape: {single_obs.shape}\")\n", + "print(f\" Value: {single_obs}\")\n", + "\n", + "print(\"\\n\" + \"=\" * 50 + \"\\n\")\n", + "\n", + "# Vectorized environment: reset() returns stacked observations\n", + "print(\"Vectorized environment reset:\")\n", + "vector_obs, info = vector_env.reset()\n", + "print(f\" Shape: {vector_obs.shape}\")\n", + "print(f\" Value:\\n{vector_obs}\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Taking Vectorized Steps" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Take random actions in all environments\n", + "actions = np.random.choice(2, size=vector_env.env_num)\n", + "obs, rew, terminated, truncated, info = vector_env.step(actions)\n", + "\n", + "print(f\"Actions 
taken: {actions}\")\n", + "print(f\"Rewards received: {rew}\")\n", + "print(f\"Terminated flags: {terminated}\")\n", + "print(\"Info\", info)" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Selective Environment Execution\n", + "\n", + "You can interact with specific environments using the `id` parameter:" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Execute only environments 0, 1, and 3\n", + "selected_actions = np.random.choice(2, size=3)\n", + "obs, rew, terminated, truncated, info = vector_env.step(selected_actions, id=[0, 1, 3])\n", + "\n", + "print(\"Executed actions in environments [0, 1, 3]\")\n", + "print(f\"Received {len(rew)} results\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Parallel Sampling: Synchronous vs Asynchronous\n", + "\n", + "### Synchronous Mode (Default)\n", + "\n", + "By default, vectorized environments operate synchronously: a step completes only after **all** environments finish their step. This works well when all environments take roughly the same time per step.\n", + "\n", + "### Asynchronous Mode\n", + "\n", + "When environment step times vary significantly (e.g., 90% of steps take 1s, but 10% take 10s), asynchronous mode can help. It allows faster environments to continue without waiting for slower ones.\n", + "\n", + "
\n", + "
\n", + "Comparison of synchronous and asynchronous vectorized environments
\n", + "(Steps with the same color are processed together)\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Enabling Asynchronous Mode\n", + "\n", + "Use the `wait_num` or `timeout` parameters (or both):" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "from functools import partial\n", + "\n", + "\n", + "# Create environments with varying step times\n", + "class SlowEnv(gym.Env):\n", + " \"\"\"Environment with variable step duration.\"\"\"\n", + "\n", + " def __init__(self, sleep_time):\n", + " self.sleep_time = sleep_time\n", + " self.observation_space = gym.spaces.Box(low=0, high=1, shape=(4,))\n", + " self.action_space = gym.spaces.Discrete(2)\n", + " super().__init__()\n", + "\n", + " def reset(self, seed=None, options=None):\n", + " super().reset(seed=seed)\n", + " return np.random.rand(4), {}\n", + "\n", + " def step(self, action):\n", + " time.sleep(self.sleep_time) # Simulate slow computation\n", + " return np.random.rand(4), 0.0, False, False, {}\n", + "\n", + "\n", + "# Create async vectorized environment\n", + "env_fns = [partial(SlowEnv, sleep_time=0.01 * i) for i in [1, 2, 3, 4]]\n", + "async_env = SubprocVectorEnv(env_fns, wait_num=3, timeout=0.1)\n", + "\n", + "print(\"Asynchronous environment created\")\n", + "print(\" wait_num=3: Returns after 3 environments complete\")\n", + "print(\" timeout=0.1: Or after 0.1 seconds, whichever comes first\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### How Async Parameters Work\n", + "\n", + "- **`wait_num`**: Minimum number of environments to wait for (e.g., `wait_num=3` means each step returns results from at least 3 environments)\n", + "- **`timeout`**: Maximum time to wait in seconds (acts as a dynamic `wait_num`—returns whatever is ready after timeout)\n", + "- If no environment finishes within the timeout, the system waits until at least one completes\n", + "\n", + "> **Warning**: Asynchronous collectors can cause exceptions when used as `test_collector` in trainers. Always use synchronous mode for test collectors." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## EnvPool Integration\n", + "\n", + "[EnvPool](https://github.com/sail-sg/envpool/) is a C++-based vectorized environment library that provides significant performance improvements over Python-based solutions for many of the standard environments. 
Tianshou fully supports EnvPool with minimal code changes.\n", + "\n", + "### Why EnvPool?\n", + "\n", + "- **Performance**: 10x-100x faster than standard vectorized environments for supported environments\n", + "- **Memory Efficient**: Optimized memory usage through shared buffers\n", + "- **Drop-in Replacement**: Nearly identical API to Tianshou's vectorized environments\n", + "\n", + "### Supported Environments\n", + "\n", + "EnvPool currently supports:\n", + "- Atari games\n", + "- MuJoCo physics simulations\n", + "- VizDoom 3D environments\n", + "- Classic control environments\n", + "- Toy text environments" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Using EnvPool\n", + "\n", + "First, install EnvPool:\n", + "\n", + "```bash\n", + "pip install envpool\n", + "```\n", + "\n", + "Then use it directly with Tianshou:\n", + "\n", + "```python\n", + "import envpool\n", + "\n", + "# Create EnvPool vectorized environment\n", + "envs = envpool.make_gymnasium(\"CartPole-v1\", num_envs=10)\n", + "\n", + "print(f\"Created EnvPool environment with {envs.spec.config.num_envs} environments\")\n", + "print(\"Ready to use with Tianshou collectors!\")\n", + "\n", + "# Use directly with Tianshou\n", + "collector = Collector(algorithm, envs, buffer)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### EnvPool Examples\n", + "\n", + "For complete examples of using EnvPool with Tianshou:\n", + "- [Atari with EnvPool](https://github.com/thu-ml/tianshou/tree/master/examples/atari#envpool)\n", + "- [MuJoCo with EnvPool](https://github.com/thu-ml/tianshou/tree/master/examples/mujoco#envpool)\n", + "- [VizDoom with EnvPool](https://github.com/thu-ml/tianshou/tree/master/examples/vizdoom#envpool)\n", + "- [More EnvPool Examples](https://github.com/sail-sg/envpool/tree/master/examples/tianshou_examples)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Custom Environments and State Representations\n", + "\n", + "Tianshou works seamlessly with custom environments as long as they follow the Gymnasium API. Let's explore how to handle different state representations." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Required Gymnasium API\n", + "\n", + "Your custom environment must implement:\n", + "\n", + "```python\n", + "class MyEnv(gym.Env):\n", + " def reset(self, seed=None, options=None) -> Tuple[observation, info]:\n", + " \"\"\"Reset environment to initial state.\"\"\"\n", + " pass\n", + " \n", + " def step(self, action) -> Tuple[observation, reward, terminated, truncated, info]:\n", + " \"\"\"Execute one step in the environment.\"\"\"\n", + " pass\n", + " \n", + " def seed(self, seed: int) -> List[int]:\n", + " \"\"\"Set random seed.\"\"\"\n", + " pass\n", + " \n", + " def render(self, mode='human') -> Any:\n", + " \"\"\"Render the environment.\"\"\"\n", + " pass\n", + " \n", + " def close(self) -> None:\n", + " \"\"\"Clean up resources.\"\"\"\n", + " pass\n", + " \n", + " # Required spaces\n", + " observation_space: gym.Space\n", + " action_space: gym.Space\n", + "```\n", + "\n", + "> **Important**: Make sure your `seed()` method is implemented correctly:\n", + "> ```python\n", + "> def seed(self, seed):\n", + "> np.random.seed(seed)\n", + "> # Also seed other random generators used in your environment\n", + "> ```\n", + "> Without proper seeding, parallel environments may produce identical outputs!" 
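To make the seeding requirement concrete, below is a minimal sketch of a custom environment written against this API. The environment (`NoisyGridEnv`) and its reward are invented purely for illustration; the relevant detail is that `super().reset(seed=seed)` initializes `self.np_random`, which the environment then uses for all of its randomness, so differently seeded copies produce different trajectories:

```python
import gymnasium as gym
import numpy as np

from tianshou.env import DummyVectorEnv


class NoisyGridEnv(gym.Env):
    """Toy environment: emits 4 random features per step and truncates after max_steps."""

    def __init__(self, max_steps: int = 10):
        super().__init__()
        self.max_steps = max_steps
        self.steps = 0
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(4,), dtype=np.float32)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self, seed=None, options=None):
        # super().reset(seed=seed) initializes self.np_random, so every copy
        # of the environment gets its own, independently seeded generator.
        super().reset(seed=seed)
        self.steps = 0
        return self.np_random.random(4, dtype=np.float32), {}

    def step(self, action):
        self.steps += 1
        obs = self.np_random.random(4, dtype=np.float32)
        reward = float(action)  # placeholder reward, purely illustrative
        truncated = self.steps >= self.max_steps
        return obs, reward, False, truncated, {}


# Two copies seeded differently should produce different observations ...
obs_a, _ = NoisyGridEnv().reset(seed=1)
obs_b, _ = NoisyGridEnv().reset(seed=2)
print(obs_a, obs_b, sep="\n")

# ... and the environment can be vectorized like any Gymnasium environment.
envs = DummyVectorEnv([lambda: NoisyGridEnv() for _ in range(3)])
obs, info = envs.reset()
print(obs.shape)  # (3, 4)
```

If two differently seeded copies print identical observations, the environment is drawing from a shared global random state and should be switched to `self.np_random` as above.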
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Dictionary Observations\n", + "\n", + "Many environments return observations as dictionaries rather than simple arrays. Tianshou's `Batch` class handles this elegantly.\n", + "\n", + "Example with the FetchReach environment:" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "from tianshou.data import Batch, ReplayBuffer\n", + "\n", + "# Example: Creating a mock observation similar to FetchReach\n", + "observation = {\n", + " \"observation\": np.array([1.34, 0.75, 0.53, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),\n", + " \"achieved_goal\": np.array([1.34, 0.75, 0.53]),\n", + " \"desired_goal\": np.array([1.24, 0.78, 0.63]),\n", + "}\n", + "\n", + "# Store in replay buffer\n", + "buffer = ReplayBuffer(size=10)\n", + "buffer.add(Batch(obs=observation, act=0, rew=0.0, terminated=False, truncated=False))\n", + "\n", + "print(\"Stored observation structure:\")\n", + "print(buffer.obs)" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Accessing Dictionary Observations\n", + "\n", + "When sampling from the buffer, you can access nested dictionary values in multiple ways:" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# Sample a batch\n", + "batch, indices = buffer.sample(batch_size=1)\n", + "\n", + "print(\"Batch keys:\", list(batch.keys()))\n", + "print(\"\\nAccessing nested observation:\")\n", + "\n", + "# Recommended way: access through batch first\n", + "print(\"batch.obs.desired_goal[0]:\", batch.obs.desired_goal[0])\n", + "\n", + "# Alternative ways (not recommended)\n", + "print(\"batch.obs[0].desired_goal:\", batch.obs[0].desired_goal)\n", + "print(\"batch[0].obs.desired_goal:\", batch[0].obs.desired_goal)" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Using Dictionary Observations in Networks\n", + "\n", + "When designing networks for environments with dictionary observations:" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "\n", + "class CustomNetwork(nn.Module):\n", + " \"\"\"Network that processes dictionary observations.\"\"\"\n", + "\n", + " def __init__(self, obs_dim, goal_dim, hidden_dim, action_dim):\n", + " super().__init__()\n", + "\n", + " # Separate processing for different observation components\n", + " self.obs_encoder = nn.Linear(obs_dim, hidden_dim)\n", + " self.goal_encoder = nn.Linear(goal_dim * 2, hidden_dim) # achieved + desired\n", + "\n", + " # Combined processing\n", + " self.fc = nn.Sequential(\n", + " nn.Linear(hidden_dim * 2, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, action_dim)\n", + " )\n", + "\n", + " def forward(self, obs_batch, **kwargs):\n", + " # Extract components from the batch\n", + " observation = obs_batch.observation\n", + " achieved_goal = obs_batch.achieved_goal\n", + " desired_goal = obs_batch.desired_goal\n", + "\n", + " # Process each component\n", + " obs_feat = self.obs_encoder(observation)\n", + " goal_feat = self.goal_encoder(torch.cat([achieved_goal, desired_goal], dim=-1))\n", + "\n", + " # Combine and output\n", + " combined = torch.cat([obs_feat, goal_feat], dim=-1)\n", + " return self.fc(combined)\n", + "\n", + "\n", + "# Example usage\n", + "net = CustomNetwork(obs_dim=10, goal_dim=3, hidden_dim=64, action_dim=4)\n", + "print(\"Network created for dictionary observations\")\n", + "print(\" 
Input: observation (10D) + achieved_goal (3D) + desired_goal (3D)\")\n", + "print(\" Output: actions (4D)\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Custom Object States\n", + "\n", + "For more complex state representations (e.g., graphs, custom objects), Tianshou stores references in numpy arrays. However, you must ensure deep copies to avoid state aliasing:" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "import copy\n", + "\n", + "import networkx as nx\n", + "\n", + "\n", + "class GraphEnv(gym.Env):\n", + " \"\"\"Example environment with graph-based states.\"\"\"\n", + "\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.graph = nx.Graph()\n", + " self.action_space = gym.spaces.Discrete(5)\n", + " self.observation_space = gym.spaces.Box(low=0, high=1, shape=(10,)) # for compatibility\n", + "\n", + " def reset(self, seed=None, options=None):\n", + " super().reset(seed=seed)\n", + " self.graph = nx.erdos_renyi_graph(10, 0.3)\n", + " # IMPORTANT: Return deep copy to avoid reference issues\n", + " return copy.deepcopy(self.graph), {}\n", + "\n", + " def step(self, action):\n", + " # Modify graph based on action\n", + " if action < 4 and len(self.graph.nodes) > 0:\n", + " nodes = list(self.graph.nodes)\n", + " if len(nodes) >= 2:\n", + " self.graph.add_edge(nodes[0], nodes[1])\n", + "\n", + " # IMPORTANT: Return deep copy\n", + " return copy.deepcopy(self.graph), 0.0, False, False, {}\n", + "\n", + "\n", + "# Test storing graph objects\n", + "graph_buffer = ReplayBuffer(size=5)\n", + "env = GraphEnv()\n", + "obs, _ = env.reset()\n", + "graph_buffer.add(Batch(obs=obs, act=0, rew=0.0, terminated=False, truncated=False))\n", + "\n", + "print(\"Graph objects stored in buffer:\")\n", + "print(graph_buffer.obs)" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "> **Important**: When using custom objects as states:\n", + "> 1. Always return `copy.deepcopy(state)` in both `reset()` and `step()`\n", + "> 2. Ensure the object is numpy-compatible: `np.array([your_object])` should not result in an empty array\n", + "> 3. The object may be stored as a shallow copy in the buffer—deep copying prevents state aliasing" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Best Practices Summary\n", + "\n", + "### Choosing the Right Environment Wrapper\n", + "\n", + "| Scenario | Recommended Wrapper | Why |\n", + "|----------|-------------------|-----|\n", + "| Simple/fast environments | `DummyVectorEnv` or raw Gym | Minimal overhead |\n", + "| Most parallel scenarios | `SubprocVectorEnv` | Good balance of speed and simplicity |\n", + "| Large observations (images) | `ShmemVectorEnv` | Optimized memory usage |\n", + "| Multi-machine clusters | `RayVectorEnv` | Distributed computing support |\n", + "| Maximum performance | EnvPool | C++-based, 10x-100x speedup |\n", + "\n", + "### Performance Tips\n", + "\n", + "1. **Profile First**: Measure whether environment or training is your bottleneck before optimizing\n", + "2. **Start Simple**: Begin with `DummyVectorEnv` for debugging, then upgrade to parallel versions\n", + "3. **Use EnvPool**: If your environment is supported, EnvPool offers the best performance\n", + "4. **Async for Variable Times**: Use asynchronous mode only when environment step times vary significantly\n", + "5. 
**Proper Seeding**: Always implement the `seed()` method correctly in custom environments\n", + "\n", + "### Common Pitfalls\n", + "\n", + "- ❌ Using `SubprocVectorEnv` for fast environments → Use `DummyVectorEnv` instead\n", + "- ❌ Forgetting to deep-copy custom states → States will be aliased in the buffer\n", + "- ❌ Not implementing `seed()` properly → Parallel environments produce identical results\n", + "- ❌ Using async collectors for testing → Causes exceptions in trainers\n", + "- ❌ Assuming linear speedup → Account for communication overhead and straggler effects" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Further Reading\n", + "\n", + "- **Tianshou Documentation**: [Environment API Reference](https://tianshou.org/en/master/03_api/env/venvs.html)\n", + "- **EnvPool**: [Official Documentation](https://envpool.readthedocs.io/)\n", + "- **Gymnasium**: [Environment Creation Tutorial](https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/)\n", + "- **Ray**: [Distributed RL with Ray](https://docs.ray.io/en/latest/rllib/index.html)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/02_notebooks/L4_GAE.ipynb b/docs/02_deep_dives/L4_GAE.ipynb similarity index 99% rename from docs/02_notebooks/L4_GAE.ipynb rename to docs/02_deep_dives/L4_GAE.ipynb index 8393d6f92..87d318a03 100644 --- a/docs/02_notebooks/L4_GAE.ipynb +++ b/docs/02_deep_dives/L4_GAE.ipynb @@ -6,7 +6,7 @@ "id": "QJ5krjrcbuiA" }, "source": [ - "# Notes on Generalized Advantage Estimation\n", + "# Generalized Advantage Estimation\n", "\n", "\n" ] diff --git a/docs/02_deep_dives/L5_Collector.ipynb b/docs/02_deep_dives/L5_Collector.ipynb new file mode 100644 index 000000000..5bf3de12b --- /dev/null +++ b/docs/02_deep_dives/L5_Collector.ipynb @@ -0,0 +1,713 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "M98bqxdMsTXK" + }, + "source": [ + "# Collector\n", + "\n", + "The Collector serves as the orchestration layer between the policy (agent) and the environment in Tianshou's architecture. It manages the interaction loop, persists collected experiences to a replay buffer, and computes episode-level statistics. This module is fundamental to both training data collection and policy evaluation workflows." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OX5cayLv4Ziu" + }, + "source": [ + "## Core Applications\n", + "\n", + "The Collector supports two primary use cases in reinforcement learning experiments:\n", + "1. **Training**: Collecting interaction data for policy optimization\n", + "2. **Evaluation**: Assessing policy performance without learning" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Z6XKbj28u8Ze" + }, + "source": [ + "### Policy Evaluation\n", + "\n", + "Periodic policy evaluation is essential in deep reinforcement learning (DRL) experiments to monitor training progress and assess generalization. The Collector provides a standardized interface for this purpose.\n", + "\n", + "**Setup**: A Collector requires two components:\n", + "1. An environment (or vectorized environment for parallelization)\n", + "2. 
A policy instance to evaluate" + ] + }, + { + "cell_type": "code", + "metadata": { + "editable": true, + "id": "w8t9ubO7u69J", + "slideshow": { + "slide_type": "" + }, + "tags": [ + "hide-cell", + "remove-output" + ], + "ExecuteTime": { + "end_time": "2025-10-26T21:59:25.914405Z", + "start_time": "2025-10-26T21:59:22.196044Z" + } + }, + "source": [ + "import gymnasium as gym\n", + "import torch\n", + "\n", + "from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy" + ], + "outputs": [], + "execution_count": 1 + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-26T21:59:30.621207Z", + "start_time": "2025-10-26T21:59:25.922401Z" + } + }, + "source": [ + "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", + "from tianshou.env import DummyVectorEnv\n", + "from tianshou.utils.net.common import Net\n", + "from tianshou.utils.net.discrete import DiscreteActor\n", + "\n", + "# Initialize single environment for configuration\n", + "env = gym.make(\"CartPole-v1\")\n", + "\n", + "# Create vectorized test environments (2 parallel environments)\n", + "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(2)])\n", + "\n", + "# Configure neural network architecture\n", + "assert env.observation_space.shape is not None # for mypy\n", + "preprocess_net = Net(\n", + " state_shape=env.observation_space.shape,\n", + " hidden_sizes=[\n", + " 16,\n", + " ],\n", + ")\n", + "\n", + "# Initialize discrete action actor network\n", + "assert isinstance(env.action_space, gym.spaces.Discrete) # for mypy\n", + "actor = DiscreteActor(preprocess_net=preprocess_net, action_shape=env.action_space.n)\n", + "\n", + "# Create policy with categorical action distribution\n", + "policy = ProbabilisticActorPolicy(\n", + " actor=actor,\n", + " dist_fn=torch.distributions.Categorical,\n", + " action_space=env.action_space,\n", + " action_scaling=False,\n", + ")\n", + "\n", + "# Initialize collector for evaluation\n", + "test_collector = Collector[CollectStats](policy, test_envs)" + ], + "outputs": [], + "execution_count": 2 + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wmt8vuwpzQdR" + }, + "source": [ + "### Evaluating Untrained Policy Performance\n", + "\n", + "We now evaluate the randomly initialized policy across 9 episodes to establish a baseline performance metric:" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9SuT6MClyjyH", + "outputId": "1e48f13b-c1fe-4fc2-ca1b-669485efdcae", + "ExecuteTime": { + "end_time": "2025-10-26T21:59:31.362074Z", + "start_time": "2025-10-26T21:59:30.752198Z" + } + }, + "source": [ + "# Collect 9 complete episodes with environment reset\n", + "collect_result = test_collector.collect(reset_before_collect=True, n_episode=9)\n", + "\n", + "collect_result.pprint_asdict()" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CollectStats\n", + "----------------------------------------\n", + "{ 'collect_speed': 288.36823267420584,\n", + " 'collect_time': 0.5860562324523926,\n", + " 'lens': array([15, 22, 29, 22, 8, 16, 28, 10, 19]),\n", + " 'lens_stat': { 'max': 29.0,\n", + " 'mean': 18.77777777777778,\n", + " 'min': 8.0,\n", + " 'std': 6.876332643007022},\n", + " 'n_collected_episodes': 9,\n", + " 'n_collected_steps': 169,\n", + " 'pred_dist_std_array': array([[0.49482444],\n", + " [0.49513358],\n", + " [0.491721 ],\n", + " [0.49804375],\n", + " [0.48936436],\n", + " 
[0.49519676],\n", + " [0.49186328],\n", + " [0.4981152 ],\n", + " [0.49512368],\n", + " [0.49527684],\n", + " [0.49800068],\n", + " [0.4982014 ],\n", + " [0.49516457],\n", + " [0.4953748 ],\n", + " [0.49805665],\n", + " [0.49269873],\n", + " [0.49946678],\n", + " [0.49553478],\n", + " [0.49997097],\n", + " [0.49289048],\n", + " [0.4998387 ],\n", + " [0.4957225 ],\n", + " [0.49908724],\n", + " [0.4930759 ],\n", + " [0.49776313],\n", + " [0.49590224],\n", + " [0.49913287],\n", + " [0.4986142 ],\n", + " [0.4998805 ],\n", + " [0.49605882],\n", + " [0.49581137],\n", + " [0.49861884],\n", + " [0.4922438 ],\n", + " [0.4962572 ],\n", + " [0.49569792],\n", + " [0.49863097],\n", + " [0.4982123 ],\n", + " [0.49961406],\n", + " [0.49553847],\n", + " [0.4999985 ],\n", + " [0.49808985],\n", + " [0.4997094 ],\n", + " [0.49964666],\n", + " [0.4987858 ],\n", + " [0.49796286],\n", + " [0.4948797 ],\n", + " [0.49960598],\n", + " [0.4916098 ],\n", + " [0.4999896 ],\n", + " [0.49003887],\n", + " [0.4997966 ],\n", + " [0.48927104],\n", + " [0.4999768 ],\n", + " [0.4899478 ],\n", + " [0.49948972],\n", + " [0.49140957],\n", + " [0.4978501 ],\n", + " [0.49466696],\n", + " [0.49509352],\n", + " [0.49118617],\n", + " [0.49797186],\n", + " [0.49447665],\n", + " [0.49950802],\n", + " [0.49740306],\n", + " [0.498081 ],\n", + " [0.49935713],\n", + " [0.49534237],\n", + " [0.49994358],\n", + " [0.49823537],\n", + " [0.499905 ],\n", + " [0.4955164 ],\n", + " [0.49991024],\n", + " [0.49839276],\n", + " [0.4999328 ],\n", + " [0.49570385],\n", + " [0.4993451 ],\n", + " [0.49855497],\n", + " [0.49995714],\n", + " [0.4995841 ],\n", + " [0.49939576],\n", + " [0.49999622],\n", + " [0.49824795],\n", + " [0.49972966],\n", + " [0.49653304],\n", + " [0.49880832],\n", + " [0.49425703],\n", + " [0.49974373],\n", + " [0.49662435],\n", + " [0.49473634],\n", + " [0.49472553],\n", + " [0.49773476],\n", + " [0.49140546],\n", + " [0.49935693],\n", + " [0.48954245],\n", + " [0.4999408 ],\n", + " [0.491403 ],\n", + " [0.49988943],\n", + " [0.49483237],\n", + " [0.49920663],\n", + " [0.49134007],\n", + " [0.49793968],\n", + " [0.4894678 ],\n", + " [0.49924025],\n", + " [0.4912263 ],\n", + " [0.4945435 ],\n", + " [0.49469063],\n", + " [0.49759832],\n", + " [0.49754107],\n", + " [0.49466005],\n", + " [0.49943802],\n", + " [0.4977191 ],\n", + " [0.49995443],\n", + " [0.49479046],\n", + " [0.49937534],\n", + " [0.49785116],\n", + " [0.49731255],\n", + " [0.49934685],\n", + " [0.4993554 ],\n", + " [0.49798217],\n", + " [0.4999266 ],\n", + " [0.4993439 ],\n", + " [0.49931702],\n", + " [0.49815634],\n", + " [0.49991363],\n", + " [0.4993506 ],\n", + " [0.49928144],\n", + " [0.49821213],\n", + " [0.4973895 ],\n", + " [0.49938264],\n", + " [0.4992856 ],\n", + " [0.4999623 ],\n", + " [0.49991205],\n", + " [0.49940434],\n", + " [0.49991933],\n", + " [0.49825713],\n", + " [0.49990463],\n", + " [0.49554875],\n", + " [0.49924377],\n", + " [0.49196848],\n", + " [0.49991465],\n", + " [0.48965713],\n", + " [0.49991086],\n", + " [0.4888782 ],\n", + " [0.49921995],\n", + " [0.48808664],\n", + " [0.49516302],\n", + " [0.48725367],\n", + " [0.49179506],\n", + " [0.4879356 ],\n", + " [0.4952572 ],\n", + " [0.48861024],\n", + " [0.49187768],\n", + " [0.48927858],\n", + " [0.4953288 ],\n", + " [0.48839873],\n", + " [0.49193203],\n", + " [0.49538046],\n", + " [0.49808696],\n", + " [0.49537748],\n", + " [0.49810043],\n", + " [0.4953903 ],\n", + " [0.4981276 ],\n", + " [0.49956635],\n", + " [0.49998853],\n", + " [0.49978945],\n", + " [0.49897715],\n", + " [0.4975953 
],\n", + " [0.49903452],\n", + " [0.49765074]], dtype=float32),\n", + " 'pred_dist_std_array_stat': { 0: { 'max': 0.4999985098838806,\n", + " 'mean': 0.4965951144695282,\n", + " 'min': 0.48725366592407227,\n", + " 'std': 0.003376598935574293}},\n", + " 'returns': array([15., 22., 29., 22., 8., 16., 28., 10., 19.]),\n", + " 'returns_stat': { 'max': 29.0,\n", + " 'mean': 18.77777777777778,\n", + " 'min': 8.0,\n", + " 'std': 6.876332643007022}}\n" + ] + } + ], + "execution_count": 3 + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zX9AQY0M0R3C" + }, + "source": [ + "### Baseline Comparison: Random Policy\n", + "\n", + "To contextualize the initialized policy's performance, we establish a random action baseline:" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "UEcs8P8P0RLt", + "outputId": "85f02f9d-b79b-48b2-99c6-36a1602f0884", + "ExecuteTime": { + "end_time": "2025-10-26T21:59:31.431099Z", + "start_time": "2025-10-26T21:59:31.371074Z" + } + }, + "source": [ + "# Evaluate random policy by sampling actions uniformly from action space\n", + "collect_result = test_collector.collect(reset_before_collect=True, n_episode=9, random=True)\n", + "\n", + "collect_result.pprint_asdict()" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CollectStats\n", + "----------------------------------------\n", + "{ 'collect_speed': 4407.5322624798,\n", + " 'collect_time': 0.053998470306396484,\n", + " 'lens': array([11, 13, 15, 29, 15, 12, 15, 30, 98]),\n", + " 'lens_stat': { 'max': 98.0,\n", + " 'mean': 26.444444444444443,\n", + " 'min': 11.0,\n", + " 'std': 26.16236105175657},\n", + " 'n_collected_episodes': 9,\n", + " 'n_collected_steps': 238,\n", + " 'pred_dist_std_array': None,\n", + " 'pred_dist_std_array_stat': None,\n", + " 'returns': array([11., 13., 15., 29., 15., 12., 15., 30., 98.]),\n", + " 'returns_stat': { 'max': 98.0,\n", + " 'mean': 26.444444444444443,\n", + " 'min': 11.0,\n", + " 'std': 26.16236105175657}}\n" + ] + } + ], + "execution_count": 4 + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sKQRTiG10ljU" + }, + "source": [ + "**Observation**: The randomly initialized policy performs comparably to (or worse than) uniform random actions prior to training. This is expected behavior, as the network weights lack task-specific optimization." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8RKmHIoG1A1k" + }, + "source": [ + "### Training Data Collection\n", + "\n", + "During the training phase, the Collector manages experience gathering and automatic storage in a replay buffer. This enables the experience replay mechanism fundamental to off-policy algorithms." 
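To preview how the pieces created in the next cell fit together, here is a rough sketch of the collect-then-learn cycle that off-policy training follows. It only uses calls shown elsewhere in this tutorial; the actual gradient update is left as a comment because, in practice, Tianshou's trainer classes drive this loop for you:

```python
# Illustrative sketch only: alternate data collection and replay sampling.
for _ in range(3):
    # 1. Gather fresh experience from the environments into the replay buffer.
    train_collector.collect(n_step=10)

    # 2. Sample a mini-batch of past transitions (experience replay).
    batch, indices = replayBuffer.sample(32)

    # 3. An off-policy algorithm would now compute its loss from `batch`
    #    and update the policy's network parameters.
```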
+ ] + }, + { + "cell_type": "code", + "metadata": { + "editable": true, + "id": "CB9XB9bF1YPC", + "slideshow": { + "slide_type": "" + }, + "tags": [], + "ExecuteTime": { + "end_time": "2025-10-26T21:59:31.452144Z", + "start_time": "2025-10-26T21:59:31.444096Z" + } + }, + "source": [ + "# Configuration for parallel training data collection\n", + "train_env_num = 4\n", + "buffer_size = 100\n", + "\n", + "# Initialize vectorized training environments\n", + "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(train_env_num)])\n", + "\n", + "# Create replay buffer compatible with vectorized environments\n", + "replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n", + "\n", + "# Initialize training collector with buffer integration\n", + "train_collector = Collector[CollectStats](policy, train_envs, replayBuffer)" + ], + "outputs": [], + "execution_count": 5 + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rWKDazA42IUQ" + }, + "source": [ + "### Step-Based Collection\n", + "\n", + "The Collector supports both step-based and episode-based collection modes. Here we demonstrate step-based collection, which is commonly used in training loops with fixed update frequencies.\n", + "\n", + "**Note**: When using vectorized environments, the actual number of collected steps may exceed the requested amount to maintain synchronization across parallel environments." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "-fUtQOnM2Yi1", + "outputId": "dceee987-433e-4b75-ed9e-823c20a9e1c2", + "ExecuteTime": { + "end_time": "2025-10-26T21:59:31.501487Z", + "start_time": "2025-10-26T21:59:31.459140Z" + } + }, + "source": [ + "# Reset collector and buffer to clean state\n", + "train_collector.reset()\n", + "replayBuffer.reset()\n", + "\n", + "print(f\"Replay buffer before collecting is empty, and has length={len(replayBuffer)} \\n\")\n", + "\n", + "# Collect 50 environment steps\n", + "n_step = 50\n", + "collect_result = train_collector.collect(n_step=n_step)\n", + "\n", + "print(\n", + " f\"Replay buffer after collecting {n_step} steps has length={len(replayBuffer)}.\\n\"\n", + " f\"The actual count may exceed n_step when it is not a multiple of train_env_num \\n\"\n", + " f\"due to vectorization synchronization requirements.\\n\",\n", + ")\n", + "\n", + "collect_result.pprint_asdict()" + ], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Replay buffer before collecting is empty, and has length=0 \n", + "\n", + "Replay buffer after collecting 50 steps has length=52.\n", + "The actual count may exceed n_step when it is not a multiple of train_env_num \n", + "due to vectorization synchronization requirements.\n", + "\n", + "CollectStats\n", + "----------------------------------------\n", + "{ 'collect_speed': 1529.5011711244197,\n", + " 'collect_time': 0.03399801254272461,\n", + " 'lens': array([], dtype=int32),\n", + " 'lens_stat': None,\n", + " 'n_collected_episodes': 0,\n", + " 'n_collected_steps': 52,\n", + " 'pred_dist_std_array': array([[0.4944575 ],\n", + " [0.49571753],\n", + " [0.49482644],\n", + " [0.49571693],\n", + " [0.49746 ],\n", + " [0.49228 ],\n", + " [0.491648 ],\n", + " [0.49237084],\n", + " [0.49931562],\n", + " [0.48953396],\n", + " [0.4949102 ],\n", + " [0.49022076],\n", + " [0.49992043],\n", + " [0.4921799 ],\n", + " [0.49171764],\n", + " [0.4894729 ],\n", + " [0.4992769 ],\n", + " [0.48948848],\n", + " [0.49497682],\n", + " [0.48870105],\n", + " 
[0.49763048],\n", + " [0.49201292],\n", + " [0.49787888],\n", + " [0.4893877 ],\n", + " [0.49927947],\n", + " [0.4955971 ],\n", + " [0.49943653],\n", + " [0.49005648],\n", + " [0.49780723],\n", + " [0.49179533],\n", + " [0.49995926],\n", + " [0.49153325],\n", + " [0.49928913],\n", + " [0.48941523],\n", + " [0.49986592],\n", + " [0.49499276],\n", + " [0.4999287 ],\n", + " [0.49152908],\n", + " [0.4991583 ],\n", + " [0.49093276],\n", + " [0.4998997 ],\n", + " [0.48936346],\n", + " [0.4978821 ],\n", + " [0.49442956],\n", + " [0.49992698],\n", + " [0.49117777],\n", + " [0.49921465],\n", + " [0.49751103],\n", + " [0.4992887 ],\n", + " [0.4893143 ],\n", + " [0.49991187],\n", + " [0.4992216 ]], dtype=float32),\n", + " 'pred_dist_std_array_stat': { 0: { 'max': 0.499959260225296,\n", + " 'mean': 0.49497732520103455,\n", + " 'min': 0.4887010455131531,\n", + " 'std': 0.003929081838577986}},\n", + " 'returns': array([], dtype=float64),\n", + " 'returns_stat': None}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "F:\\Users\\Dominik Jain\\Dev\\AI\\tianshou\\tianshou\\data\\collector.py:537: UserWarning: n_step=50 is not a multiple of (self.env_num=4), which may cause extra transitions being collected into the buffer.\n", + " warnings.warn(\n" + ] + } + ], + "execution_count": 6 + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Buffer Sampling Verification\n", + "\n", + "Verify that collected experiences are properly stored and can be sampled for training:" + ] + }, + { + "cell_type": "code", + "metadata": { + "ExecuteTime": { + "end_time": "2025-10-26T21:59:31.517583Z", + "start_time": "2025-10-26T21:59:31.509483Z" + } + }, + "source": [ + "# Sample mini-batch of 10 transitions from buffer\n", + "replayBuffer.sample(10)" + ], + "outputs": [ + { + "data": { + "text/plain": [ + "(Batch(\n", + " obs: array([[-7.59119692e-04, -3.54404569e-01, 8.15278068e-02,\n", + " 6.34967446e-01],\n", + " [ 2.03953441e-02, -5.46947002e-01, 4.59121428e-02,\n", + " 8.69558692e-01],\n", + " [-5.53812869e-02, -3.63834441e-01, 1.84285983e-01,\n", + " 8.54350269e-01],\n", + " [ 5.94463721e-02, -3.39802876e-02, -5.61027192e-02,\n", + " -2.05838066e-02],\n", + " [ 1.70439295e-02, -3.58715117e-01, 2.22064722e-02,\n", + " 6.39448643e-01],\n", + " [ 1.51256351e-02, 2.27344140e-01, 1.95531528e-02,\n", + " -2.54039675e-01],\n", + " [-7.69001395e-02, -7.54580617e-01, 1.79230303e-01,\n", + " 1.36748278e+00],\n", + " [-3.51171643e-02, -1.14145672e+00, 1.09657384e-01,\n", + " 1.86768615e+00],\n", + " [ 2.10114848e-02, 3.47817928e-01, -1.05900057e-01,\n", + " -6.93330288e-01],\n", + " [-1.53460149e-02, 5.40259123e-01, -4.36654910e-02,\n", + " -9.24050748e-01]], dtype=float32),\n", + " act: array([1, 1, 0, 1, 0, 0, 1, 1, 0, 1], dtype=int64),\n", + " rew: array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),\n", + " terminated: array([False, False, False, False, False, False, False, False, False,\n", + " False]),\n", + " truncated: array([False, False, False, False, False, False, False, False, False,\n", + " False]),\n", + " done: array([False, False, False, False, False, False, False, False, False,\n", + " False]),\n", + " obs_next: array([[-0.00784721, -0.16050874, 0.09422715, 0.36903235],\n", + " [ 0.0094564 , -0.3524787 , 0.06330331, 0.59165704],\n", + " [-0.06265797, -0.5609251 , 0.20137298, 1.1988543 ],\n", + " [ 0.05876676, 0.16189948, -0.05651439, -0.33042672],\n", + " [ 0.00986963, -0.5541395 , 0.03499544, 0.93904114],\n", + " [ 0.01967252, 0.03194853, 0.01447236, 0.04474595],\n", + 
" [-0.09199175, -0.5620968 , 0.20657995, 1.135794 ],\n", + " [-0.05794629, -0.94769216, 0.1470111 , 1.6109598 ],\n", + " [ 0.02796784, 0.15431203, -0.11976666, -0.435774 ],\n", + " [-0.00454083, 0.73594284, -0.06214651, -1.2301302 ]],\n", + " dtype=float32),\n", + " info: Batch(\n", + " env_id: array([0, 0, 0, 1, 2, 2, 2, 2, 3, 3]),\n", + " ),\n", + " policy: Batch(),\n", + " ),\n", + " array([ 6, 3, 12, 31, 56, 53, 62, 60, 81, 78]))" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 7 + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8NP7lOBU3-VS" + }, + "source": [ + "## Advanced Topics\n", + "\n", + "### Asynchronous Collection\n", + "\n", + "The standard `Collector` implementation may collect more steps than requested when using vectorized environments. In the example above, requesting 50 steps resulted in 52 steps (the smallest multiple of 4 that is ≥50).\n", + "\n", + "For scenarios requiring precise step counts, Tianshou provides the `AsyncCollector`, which enables exact step collection at the cost of additional implementation complexity. This is particularly relevant for:\n", + "- Strict reproducibility requirements\n", + "- Algorithms sensitive to exact batch sizes\n", + "- Fine-grained control over data collection\n", + "\n", + "Consult the [AsyncCollector documentation](https://tianshou.org/en/master/03_api/data/collector.html#tianshou.data.collector.AsyncCollector) for implementation details and usage patterns." + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/02_deep_dives/L6_MARL.ipynb b/docs/02_deep_dives/L6_MARL.ipynb new file mode 100644 index 000000000..855e7e3f7 --- /dev/null +++ b/docs/02_deep_dives/L6_MARL.ipynb @@ -0,0 +1,840 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Multi-Agent Reinforcement Learning (MARL)\n", + "\n", + "This tutorial demonstrates how to use Tianshou for multi-agent reinforcement learning scenarios. We'll explore different MARL paradigms and implement a practical example using the Tic-Tac-Toe game." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## MARL Paradigms\n", + "\n", + "Tianshou supports three fundamental types of multi-agent reinforcement learning paradigms:\n", + "\n", + "1. **Simultaneous move**: All agents take their actions at each timestep simultaneously (e.g., MOBA games)\n", + "2. **Cyclic move**: Agents take actions sequentially in turns (e.g., Go)\n", + "3. **Conditional move**: The environment conditionally selects which agent acts at each timestep (e.g., [Pig Game](https://en.wikipedia.org/wiki/Pig_(dice_game)))\n", + "\n", + "Our approach addresses these multi-agent RL problems by converting them into traditional single-agent RL formulations." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Converting MARL to Single-Agent RL\n", + "\n", + "### Simultaneous Move\n", + "\n", + "For simultaneous-move scenarios, the solution is straightforward: we add an extra `num_agents` dimension to the state, action, and reward tensors. No other modifications are necessary.\n", + "\n", + "### Cyclic and Conditional Move\n", + "\n", + "Both cyclic and conditional move scenarios can be unified into a single framework. At each timestep, the environment selects an agent identified by `agent_id` to act. Since multiple agents are typically wrapped into a single object (the \"abstract agent\"), we pass the `agent_id` to this abstract agent, which then delegates the action to the appropriate specific agent.\n", + "\n", + "Additionally, in multi-agent RL, the set of legal actions often varies across timesteps (as in Go). Therefore, the environment must also provide a legal action mask to the abstract agent. This mask is a boolean array where `True` indicates available actions and `False` indicates illegal actions at the current timestep.\n", + "\n", + "
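To make this concrete, here is a small sketch (with invented values) of what such an augmented observation can look like, and how a mask can be applied to a policy's action scores so that only legal actions are ever selected:

```python
import numpy as np

# An invented augmented observation for a 4-action environment:
obs_ = {
    "agent_id": "player_1",
    "obs": np.zeros((3, 3, 2)),                    # raw environment observation
    "mask": np.array([True, False, True, True]),   # actions 0, 2, 3 are legal
}

# Suppose the policy produced these action scores (e.g., Q-values):
q_values = np.array([0.1, 0.9, 0.3, 0.2])

# Assign illegal actions a score of -inf before choosing an action:
masked_q = np.where(obs_["mask"], q_values, -np.inf)
act = int(np.argmax(masked_q))
print(act)  # 2 -- the best *legal* action, even though action 1 scored higher
```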
\n", + "
\n", + "The abstract agent framework for multi-agent RL\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Unified Formulation\n", + "\n", + "This architecture leads to the following formulation of multi-agent RL:\n", + "\n", + "```python\n", + "act = policy(state, agent_id, mask)\n", + "(next_state, next_agent_id, next_mask), reward = env.step(act)\n", + "```\n", + "\n", + "By constructing an augmented state `state_ = (state, agent_id, mask)`, we can reduce this to the standard single-agent RL formulation:\n", + "\n", + "```python\n", + "act = policy(state_)\n", + "next_state_, reward = env.step(act)\n", + "```\n", + "\n", + "Following this principle, we'll implement a Q-learning algorithm to play [Tic-Tac-Toe](https://en.wikipedia.org/wiki/Tic-tac-toe) against a random opponent." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## PettingZoo Integration\n", + "\n", + "Tianshou is fully compatible with [PettingZoo](https://pettingzoo.farama.org/) environments for multi-agent RL. While Tianshou doesn't directly provide specialized MARL facilities, it offers a flexible framework that can be adapted to various MARL scenarios.\n", + "\n", + "For comprehensive tutorials on using Tianshou with PettingZoo, refer to:\n", + "\n", + "* [Beginner Tutorial](https://pettingzoo.farama.org/tutorials/tianshou/beginner/)\n", + "* [Intermediate Tutorial](https://pettingzoo.farama.org/tutorials/tianshou/intermediate/)\n", + "* [Advanced Tutorial](https://pettingzoo.farama.org/tutorials/tianshou/advanced/)\n", + "\n", + "In this tutorial, we'll demonstrate how to use Tianshou in a multi-agent setting where only one agent is trained while the other uses a fixed random policy. You can then use this as a blueprint to replace the random policy with another trainable agent.\n", + "\n", + "Specifically, we'll train an agent to play Tic-Tac-Toe against a random opponent:\n", + "\n", + "
\n", + "
\n", + "Tic-Tac-Toe game board\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Exploring the Tic-Tac-Toe Environment\n", + "\n", + "The complete scripts are located in `test/pettingzoo/`. Tianshou provides the `PettingZooEnv` wrapper class that can wrap any PettingZoo environment. Let's explore the 3×3 Tic-Tac-Toe environment provided by PettingZoo." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pettingzoo.classic import tictactoe_v3 # the Tic-Tac-Toe environment\n", + "\n", + "from tianshou.env import PettingZooEnv # wrapper for PettingZoo environments\n", + "\n", + "# Initialize the environment\n", + "# The board has 3 rows and 3 columns (9 positions total)\n", + "# Players place 'X' and 'O' alternately on the board\n", + "# The first player to get 3 consecutive marks wins\n", + "env = PettingZooEnv(tictactoe_v3.env(render_mode=\"human\"))\n", + "obs = env.reset()\n", + "env.render() # render the empty board" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The output shows an empty 3×3 board:\n", + "\n", + "```\n", + "board (step 0):\n", + " | |\n", + " - | - | -\n", + "_____|_____|_____\n", + " | |\n", + " - | - | -\n", + "_____|_____|_____\n", + " | |\n", + " - | - | -\n", + " | |\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Examine the observation structure\n", + "print(obs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Understanding the Observation Space\n", + "\n", + "The observation returned by the environment is a dictionary with three keys:\n", + "\n", + "- **`agent_id`**: The identifier of the currently acting agent (e.g., `'player_1'` or `'player_2'`)\n", + "\n", + "- **`obs`**: The actual environment observation. For Tic-Tac-Toe, this is a numpy array with shape `(3, 3, 2)`:\n", + " - For `player_1`: The first 3×3 plane represents X placements, the second plane represents O placements\n", + " - For `player_2`: The planes are swapped (O in first plane, X in second)\n", + " - Each cell contains either 0 (empty/not placed) or 1 (mark placed)\n", + "\n", + "- **`mask`**: A boolean array indicating legal actions at the current timestep. For Tic-Tac-Toe, index `i` corresponds to position `(i // 3, i % 3)` on the board. If `mask[i] == True`, the player can place their mark at that position. Initially, all positions are available, so all mask values are `True`.\n", + "\n", + "> **Note**: The mask representation is flexible and works for both discrete and continuous action spaces. While we use a boolean array here, you could also use action spaces like `gymnasium.spaces.Discrete` or `gymnasium.spaces.Box` to represent available actions." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Playing a Few Steps\n", + "\n", + "Let's play a couple of moves to understand the environment dynamics better." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "# Take an action (place mark at position 0 - top-left corner)\n", + "action = 0 # action can be an integer or a numpy array with one element\n", + "obs, reward, done, truncated, info = env.step(action) # follows the Gymnasium API\n", + "\n", + "print(\"Observation after first move:\")\n", + "print(obs)\n", + "\n", + "# Examine the reward structure\n", + "# Reward has two items (one for each player): 1 for win, -1 for loss, 0 otherwise\n", + "print(f\"\\nReward: {reward}\")\n", + "\n", + "# Check if the game is over\n", + "print(f\"Done: {done}\")\n", + "\n", + "# Info is typically an empty dict in Tic-Tac-Toe but may contain useful information in other environments\n", + "print(f\"Info: {info}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that after the first move:\n", + "- The `agent_id` switches to `'player_2'`\n", + "- The observation array shows the X placement in the first position\n", + "- The mask now has `False` at index 0 (that position is occupied)\n", + "- The reward is `[0, 0]` (no winner yet)\n", + "- The game continues (`done = False`)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "Note: If we continue playing, the game terminates when only one empty position remains, rather than when the board is completely full. This is because a player with only one available position has no meaningful choice." + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Random Agents\n", + "\n", + "Now that we understand the environment, let's start by watching two random agents play against each other.\n", + "\n", + "Tianshou provides built-in classes for multi-agent learning. The key components are:\n", + "\n", + "- **`RandomPolicy`**: A policy that randomly selects actions\n", + "- **`MultiAgentPolicyManager`**: Manages multiple agent policies and delegates actions to the appropriate agent based on `agent_id`\n", + "\n", + "
\n", + "
\n", + "The relationship between MultiAgentPolicyManager and individual agent policies\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tianshou.algorithm.multiagent.marl import MultiAgentOffPolicyAlgorithm\n", + "from tianshou.algorithm.random import MARLRandomDiscreteMaskedOffPolicyAlgorithm\n", + "from tianshou.data import Collector\n", + "from tianshou.env import DummyVectorEnv\n", + "\n", + "# Create a multi-agent algorithm with two random agents\n", + "policy = MultiAgentOffPolicyAlgorithm(\n", + " algorithms=[\n", + " MARLRandomDiscreteMaskedOffPolicyAlgorithm(action_space=env.action_space),\n", + " MARLRandomDiscreteMaskedOffPolicyAlgorithm(action_space=env.action_space),\n", + " ],\n", + " env=env,\n", + ")\n", + "\n", + "# Vectorize the environment for the collector\n", + "env = DummyVectorEnv([lambda: env])\n", + "\n", + "# Create a collector to gather trajectories\n", + "collector = Collector(policy, env)\n", + "\n", + "# Collect and visualize one episode\n", + "result = collector.collect(n_episode=1, render=0.1, reset_before_collect=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You'll see the game progress step by step. Here's an example of the final moves:\n", + "\n", + "```\n", + " | |\n", + " X | X | -\n", + "_____|_____|_____\n", + " | |\n", + " X | O | -\n", + "_____|_____|_____\n", + " | |\n", + " O | - | -\n", + " | |\n", + "```\n", + "\n", + "```\n", + " | |\n", + " X | X | -\n", + "_____|_____|_____\n", + " | |\n", + " X | O | -\n", + "_____|_____|_____\n", + " | |\n", + " O | - | O\n", + " | |\n", + "```\n", + "\n", + "```\n", + " | |\n", + " X | X | X\n", + "_____|_____|_____\n", + " | |\n", + " X | O | -\n", + "_____|_____|_____\n", + " | |\n", + " O | - | O\n", + " | |\n", + "```\n", + "\n", + "Random agents perform poorly. In the game above, although agent 2 eventually wins, a smart agent 1 would have won immediately by placing an X at position (1, 1) (center of middle row)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training an Agent Against a Random Opponent\n", + "\n", + "Now let's train an intelligent agent! We'll use Deep Q-Network (DQN) to learn optimal play against a random opponent." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Imports and Setup\n", + "\n", + "First, let's import all necessary modules:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from copy import deepcopy\n", + "from functools import partial\n", + "\n", + "import gymnasium\n", + "import torch\n", + "from pettingzoo.classic import tictactoe_v3\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "\n", + "from tianshou.algorithm import (\n", + " DQN,\n", + " Algorithm,\n", + " MARLRandomDiscreteMaskedOffPolicyAlgorithm,\n", + " MultiAgentOffPolicyAlgorithm,\n", + ")\n", + "from tianshou.algorithm.algorithm_base import OffPolicyAlgorithm\n", + "from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy\n", + "from tianshou.algorithm.optim import AdamOptimizerFactory, OptimizerFactory\n", + "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", + "from tianshou.data.stats import InfoStats\n", + "from tianshou.env import DummyVectorEnv\n", + "from tianshou.env.pettingzoo_env import PettingZooEnv\n", + "from tianshou.trainer import OffPolicyTrainerParams\n", + "from tianshou.utils import TensorboardLogger\n", + "from tianshou.utils.net.common import Net" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Hyperparameters\n", + "\n", + "Let's define the hyperparameters for our training experiment directly (no argparse needed in notebooks!):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define hyperparameters\n", + "class Args:\n", + " seed = 1626\n", + " eps_test = 0.05\n", + " eps_train = 0.1\n", + " buffer_size = 20000\n", + " lr = 1e-4\n", + " gamma = 0.9 # A smaller gamma favors earlier wins\n", + " n_step = 3\n", + " target_update_freq = 320\n", + " epoch = 50\n", + " epoch_num_steps = 1000\n", + " collection_step_num_env_steps = 10\n", + " update_per_step = 0.1\n", + " batch_size = 64\n", + " hidden_sizes = [128, 128, 128, 128] # noqa: RUF012\n", + " num_train_envs = 10\n", + " num_test_envs = 10\n", + " logdir = \"log\"\n", + " render = 0.1\n", + " win_rate = 0.6 # Target winning rate (optimal policy can get ~0.7)\n", + " watch = False # Set to True to skip training and watch pre-trained models\n", + " agent_id = 2 # The learned agent plays as player 2\n", + " resume_path = \"\" # Path to pre-trained agent .pth file\n", + " opponent_path = \"\" # Path to pre-trained opponent .pth file\n", + " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + " model_save_path = None # Will be set in save_best_fn\n", + "\n", + "\n", + "args = Args()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Agent Setup\n", + "\n", + "The `get_agents` function creates and configures our agents:\n", + "\n", + "- **Neural Network**: We use `Net`, a multi-layer perceptron with ReLU activations\n", + "- **Learning Algorithm**: A `DiscreteQLearningPolicy` combined with `DQN` for Q-learning updates\n", + "- **Opponent**: Either a `MARLRandomDiscreteMaskedOffPolicyAlgorithm` that randomly chooses legal actions, or a pre-trained agent for self-play\n", + "\n", + "Both agents are managed by `MultiAgentOffPolicyAlgorithm`, which:\n", + "- Calls the correct agent based on `agent_id` in the observation\n", + "- Dispatches data to each agent according to their `agent_id`\n", + "- Makes each agent perceive the environment as a single-agent 
problem\n", + "\n", + "
\n", + "
\n", + "How MultiAgentOffPolicyAlgorithm coordinates agent algorithms\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_env(render_mode: str | None = None) -> PettingZooEnv:\n", + " return PettingZooEnv(tictactoe_v3.env(render_mode=render_mode))\n", + "\n", + "\n", + "def get_agents(\n", + " args,\n", + " agent_learn: OffPolicyAlgorithm | None = None,\n", + " agent_opponent: OffPolicyAlgorithm | None = None,\n", + " optim: OptimizerFactory | None = None,\n", + ") -> tuple[MultiAgentOffPolicyAlgorithm, torch.optim.Optimizer | None, list]:\n", + " \"\"\"Create or load agents for training.\"\"\"\n", + " env = get_env()\n", + " observation_space = (\n", + " env.observation_space.spaces[\"observation\"]\n", + " if isinstance(env.observation_space, gymnasium.spaces.Dict)\n", + " else env.observation_space\n", + " )\n", + " args.state_shape = observation_space.shape or int(observation_space.n)\n", + " args.action_shape = env.action_space.shape or int(env.action_space.n)\n", + "\n", + " if agent_learn is None:\n", + " # Create the neural network model\n", + " net = Net(\n", + " state_shape=args.state_shape,\n", + " action_shape=args.action_shape,\n", + " hidden_sizes=args.hidden_sizes,\n", + " ).to(args.device)\n", + "\n", + " if optim is None:\n", + " optim = AdamOptimizerFactory(lr=args.lr)\n", + "\n", + " # Create Q-learning policy for the learning agent\n", + " algorithm = DiscreteQLearningPolicy(\n", + " model=net,\n", + " action_space=env.action_space,\n", + " eps_training=args.eps_train,\n", + " eps_inference=args.eps_test,\n", + " )\n", + "\n", + " # Wrap in DQN algorithm\n", + " agent_learn = DQN(\n", + " policy=algorithm,\n", + " optim=optim,\n", + " n_step_return_horizon=args.n_step,\n", + " gamma=args.gamma,\n", + " target_update_freq=args.target_update_freq,\n", + " )\n", + "\n", + " if args.resume_path:\n", + " agent_learn.load_state_dict(torch.load(args.resume_path))\n", + "\n", + " if agent_opponent is None:\n", + " if args.opponent_path:\n", + " # Load a pre-trained opponent for self-play\n", + " agent_opponent = deepcopy(agent_learn)\n", + " agent_opponent.load_state_dict(torch.load(args.opponent_path))\n", + " else:\n", + " # Use a random opponent\n", + " agent_opponent = MARLRandomDiscreteMaskedOffPolicyAlgorithm(\n", + " action_space=env.action_space\n", + " )\n", + "\n", + " # Arrange agents based on which player position the learning agent takes\n", + " if args.agent_id == 1:\n", + " agents = [agent_learn, agent_opponent]\n", + " else:\n", + " agents = [agent_opponent, agent_learn]\n", + "\n", + " ma_algorithm = MultiAgentOffPolicyAlgorithm(algorithms=agents, env=env)\n", + " return ma_algorithm, optim, env.agents" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training Loop\n", + "\n", + "The training procedure follows the standard Tianshou workflow, similar to single-agent DQN training:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def train_agent(\n", + " args,\n", + " agent_learn: OffPolicyAlgorithm | None = None,\n", + " agent_opponent: OffPolicyAlgorithm | None = None,\n", + " optim: OptimizerFactory | None = None,\n", + ") -> tuple[InfoStats, OffPolicyAlgorithm]:\n", + " \"\"\"Train the agent using DQN.\"\"\"\n", + " # ======== Environment Setup =========\n", + " train_envs = DummyVectorEnv([get_env for _ in range(args.num_train_envs)])\n", + " test_envs = DummyVectorEnv([get_env for _ in range(args.num_test_envs)])\n", + "\n", + " # Set random seeds for 
reproducibility\n", + " np.random.seed(args.seed)\n", + " torch.manual_seed(args.seed)\n", + " train_envs.seed(args.seed)\n", + " test_envs.seed(args.seed)\n", + "\n", + " # ======== Agent Setup =========\n", + " marl_algorithm, optim, agents = get_agents(\n", + " args,\n", + " agent_learn=agent_learn,\n", + " agent_opponent=agent_opponent,\n", + " optim=optim,\n", + " )\n", + "\n", + " # ======== Collector Setup =========\n", + " train_collector = Collector[CollectStats](\n", + " marl_algorithm,\n", + " train_envs,\n", + " VectorReplayBuffer(args.buffer_size, len(train_envs)),\n", + " exploration_noise=True,\n", + " )\n", + " test_collector = Collector[CollectStats](marl_algorithm, test_envs, exploration_noise=True)\n", + "\n", + " # Collect initial random samples\n", + " train_collector.reset()\n", + " train_collector.collect(n_step=args.batch_size * args.num_train_envs)\n", + "\n", + " # ======== Logging Setup =========\n", + " log_path = os.path.join(args.logdir, \"tic_tac_toe\", \"dqn\")\n", + " writer = SummaryWriter(log_path)\n", + " writer.add_text(\"args\", str(args))\n", + " logger = TensorboardLogger(writer)\n", + "\n", + " player_agent_id = agents[args.agent_id - 1]\n", + "\n", + " # ======== Callback Functions =========\n", + " def save_best_fn(policy: Algorithm) -> None:\n", + " \"\"\"Save the best performing policy.\"\"\"\n", + " if hasattr(args, \"model_save_path\") and args.model_save_path:\n", + " model_save_path = args.model_save_path\n", + " else:\n", + " model_save_path = os.path.join(args.logdir, \"tic_tac_toe\", \"dqn\", \"policy.pth\")\n", + " torch.save(policy.get_algorithm(player_agent_id).state_dict(), model_save_path)\n", + "\n", + " def stop_fn(mean_rewards: float) -> bool:\n", + " \"\"\"Stop training when target win rate is achieved.\"\"\"\n", + " return mean_rewards >= args.win_rate\n", + "\n", + " def reward_metric(rews: np.ndarray) -> np.ndarray:\n", + " \"\"\"Extract the reward for our learning agent.\"\"\"\n", + " return rews[:, args.agent_id - 1]\n", + "\n", + " # ======== Trainer =========\n", + " result = marl_algorithm.run_training(\n", + " OffPolicyTrainerParams(\n", + " train_collector=train_collector,\n", + " test_collector=test_collector,\n", + " max_epochs=args.epoch,\n", + " epoch_num_steps=args.epoch_num_steps,\n", + " collection_step_num_env_steps=args.collection_step_num_env_steps,\n", + " test_step_num_episodes=args.num_test_envs,\n", + " batch_size=args.batch_size,\n", + " stop_fn=stop_fn,\n", + " save_best_fn=save_best_fn,\n", + " update_step_num_gradient_steps_per_sample=args.update_per_step,\n", + " logger=logger,\n", + " test_in_train=False,\n", + " multi_agent_return_reduction=reward_metric,\n", + " show_progress=False,\n", + " )\n", + " )\n", + "\n", + " return result, marl_algorithm.get_algorithm(player_agent_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Evaluation Function\n", + "\n", + "This function allows us to watch a trained agent play:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def watch(\n", + " args,\n", + " agent_learn: OffPolicyAlgorithm | None = None,\n", + " agent_opponent: OffPolicyAlgorithm | None = None,\n", + ") -> None:\n", + " \"\"\"Watch a pre-trained agent play.\"\"\"\n", + " env = DummyVectorEnv([partial(get_env, render_mode=\"human\")])\n", + " policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent)\n", + " collector = Collector[CollectStats](policy, env, 
exploration_noise=True)\n", + " result = collector.collect(n_episode=1, render=args.render, reset_before_collect=True)\n", + " result.pprint_asdict()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Running the Training\n", + "\n", + "Now let's train the agent and watch it play!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Train the agent\n", + "result, agent = train_agent(args)\n", + "\n", + "# Watch the trained agent play\n", + "watch(args, agent)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training Results\n", + "\n", + "After training for less than a minute, you'll see the agent play against the random opponent. Here's an example game:\n", + "\n", + "
\n", + "Example: Trained Agent vs Random Opponent\n", + "\n", + "```\n", + " | |\n", + " - | - | -\n", + "_____|_____|_____\n", + " | |\n", + " - | - | X\n", + "_____|_____|_____\n", + " | |\n", + " - | - | -\n", + " | |\n", + "```\n", + "\n", + "```\n", + " | |\n", + " - | - | -\n", + "_____|_____|_____\n", + " | |\n", + " - | O | X\n", + "_____|_____|_____\n", + " | |\n", + " - | - | -\n", + " | |\n", + "```\n", + "\n", + "```\n", + " | |\n", + " - | - | -\n", + "_____|_____|_____\n", + " | |\n", + " X | O | X\n", + "_____|_____|_____\n", + " | |\n", + " - | - | -\n", + " | |\n", + "```\n", + "\n", + "```\n", + " | |\n", + " - | O | -\n", + "_____|_____|_____\n", + " | |\n", + " X | O | X\n", + "_____|_____|_____\n", + " | |\n", + " - | - | -\n", + " | |\n", + "```\n", + "\n", + "```\n", + " | |\n", + " - | O | -\n", + "_____|_____|_____\n", + " | |\n", + " X | O | X\n", + "_____|_____|_____\n", + " | |\n", + " - | X | -\n", + " | |\n", + "```\n", + "\n", + "```\n", + " | |\n", + " O | O | -\n", + "_____|_____|_____\n", + " | |\n", + " X | O | X\n", + "_____|_____|_____\n", + " | |\n", + " - | X | -\n", + " | |\n", + "```\n", + "\n", + "```\n", + " | |\n", + " O | O | X\n", + "_____|_____|_____\n", + " | |\n", + " X | O | X\n", + "_____|_____|_____\n", + " | |\n", + " - | X | -\n", + " | |\n", + "```\n", + "\n", + "```\n", + " | |\n", + " O | O | X\n", + "_____|_____|_____\n", + " | |\n", + " X | O | X\n", + "_____|_____|_____\n", + " | |\n", + " - | X | O\n", + " | |\n", + "```\n", + "\n", + "Final reward: 1.0, length: 8.0\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that our trained agent plays as player 2 (O) and wins! The agent has learned the game rules through trial and error, understanding that three consecutive O marks lead to victory." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "It is easily possible to make the trained agent play against itself. Try this as an exercise!" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "While the trained agent plays well against a random opponent, it's still far from perfect play. The next step would be to implement self-play training, similar to AlphaZero, where the agent continuously improves by playing against increasingly stronger versions of itself." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "In this tutorial, we demonstrated how to use Tianshou for training a single agent in a multi-agent reinforcement learning setting. Key takeaways:\n", + "\n", + "1. **MARL Paradigms**: Tianshou supports simultaneous, cyclic, and conditional move scenarios\n", + "2. **Abstraction**: Multi-agent problems can be converted to single-agent RL through clever state augmentation\n", + "3. **PettingZoo Integration**: Seamless compatibility with PettingZoo environments via `PettingZooEnv`\n", + "4. **Algorithm Management**: `MultiAgentOffPolicyAlgorithm` handles agent coordination and data distribution\n", + "5. **Flexible Framework**: Easy to extend from single-agent training to more complex multi-agent scenarios\n", + "\n", + "Tianshou provides a flexible and intuitive framework for reinforcement learning. Experiment with different architectures, training regimes, and opponent strategies to build even more capable agents!" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/02_notebooks/0_intro.md b/docs/02_notebooks/0_intro.md deleted file mode 100644 index e68b36e63..000000000 --- a/docs/02_notebooks/0_intro.md +++ /dev/null @@ -1,9 +0,0 @@ -# Notebook Tutorials - -Here is a collection of executable tutorials for Tianshou. You can run them -directly in colab, or download them and run them locally. - -They will guide you step by step to show you how the most basic modules in Tianshou -work and how they collaborate with each other to conduct a classic DRL experiment. - -**IMPORTANT**: The notebooks are not yet adjusted to the v2 version of Tianshou! Their content is partly outdated and will be updated soon. diff --git a/docs/02_notebooks/L1_Batch.ipynb b/docs/02_notebooks/L1_Batch.ipynb deleted file mode 100644 index d40869287..000000000 --- a/docs/02_notebooks/L1_Batch.ipynb +++ /dev/null @@ -1,407 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "69y6AHvq1S3f" - }, - "source": [ - "# Batch\n", - "In this tutorial, we will introduce the **Batch** to you, which serves as the fundamental data structure in Tianshou. Think of Batch as a numpy-enhanced version of a Python dictionary. It is also similar to pytorch's tensordict,\n", - "although with a somehow different type structure." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "editable": true, - "id": "NkfiIe_y2FI-", - "outputId": "5008275f-8f77-489a-af64-b35af4448589", - "slideshow": { - "slide_type": "" - }, - "tags": [ - "remove-output", - "hide-cell" - ] - }, - "outputs": [], - "source": [ - "import pickle\n", - "\n", - "import numpy as np\n", - "import torch\n", - "\n", - "from tianshou.data import Batch" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "data = Batch(a=4, b=[5, 5], c=\"2312312\", d=(\"a\", -2, -3))\n", - "print(data)\n", - "print(data.b)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "S6e6OuXe3UT-" - }, - "source": [ - "A batch stores all passed in data as key-value pairs, and automatically turns the value into a numpy array if possible.\n", - "\n", - "## Why do we need Batch in Tianshou?\n", - "The motivation behind the implementation of Batch module is simple. In DRL, you need to handle a lot of dictionary-format data. For instance, most algorithms would require you to store state, action, and reward data for every step when interacting with the environment. All of them can be organized as a dictionary, and the\n", - " Batch class helps Tianshou in unifying the interfaces of a diverse set of algorithms. In addition, Batch supports advanced indexing, concatenation, and splitting, as well as printing formatted outputs akin to standard numpy arrays, proving invaluable for developers.\n", - "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_Xenx64M9HhV" - }, - "source": [ - "## Basic Usages" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4YGX_f1Z9Uil" - }, - "source": [ - "### Initialization\n", - "Batch can be constructed directly from a python dictionary, and all data structures\n", - " will be converted to numpy arrays if possible." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Jl3-4BRbp3MM", - "outputId": "a8b225f6-2893-4716-c694-3c2ff558b7f0" - }, - "outputs": [], - "source": [ - "# converted from a python library\n", - "print(\"========================================\")\n", - "batch1 = Batch({\"a\": [4, 4], \"b\": (5, 5)})\n", - "print(batch1)\n", - "\n", - "# initialization of batch2 is equivalent to batch1\n", - "print(\"========================================\")\n", - "batch2 = Batch(a=[4, 4], b=(5, 5))\n", - "print(batch2)\n", - "\n", - "# the dictionary can be nested, and it will be turned into a nested Batch\n", - "print(\"========================================\")\n", - "data = {\n", - " \"action\": np.array([1.0, 2.0, 3.0]),\n", - " \"reward\": 3.66,\n", - " \"obs\": {\n", - " \"rgb_obs\": np.zeros((3, 3)),\n", - " \"flatten_obs\": np.ones(5),\n", - " },\n", - "}\n", - "\n", - "batch3 = Batch(data, extra=\"extra_string\")\n", - "print(batch3)\n", - "# batch3.obs is also a Batch\n", - "print(type(batch3.obs))\n", - "print(batch3.obs.rgb_obs)\n", - "\n", - "# a list of dictionary/Batch will automatically be concatenated/stacked, providing convenience if you\n", - "# want to use parallelized environments to collect data.\n", - "print(\"========================================\")\n", - "batch4 = Batch([data] * 3)\n", - "print(batch4)\n", - "print(batch4.obs.rgb_obs.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JCf6bqY3uf5L" - }, - "source": [ - "### Getting access to data\n", - "You can effortlessly search for or modify key-value pairs within a Batch, much like interacting with a Python dictionary." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "2TNIY90-vU9b", - "outputId": "de52ffe9-03c2-45f2-d95a-4071132daa4a" - }, - "outputs": [], - "source": [ - "batch1 = Batch({\"a\": [4, 4], \"b\": (5, 5)})\n", - "print(batch1)\n", - "\n", - "# add or delete key-value pair in batch1\n", - "print(\"========================================\")\n", - "batch1.c = Batch(c1=np.arange(3), c2=False)\n", - "del batch1.a\n", - "print(batch1)\n", - "\n", - "# access value by key\n", - "print(\"========================================\")\n", - "assert batch1[\"c\"] is batch1.c\n", - "print(\"c\" in batch1)\n", - "\n", - "# traverse the Batch\n", - "print(\"========================================\")\n", - "for key, value in batch1.items():\n", - " print(str(key) + \": \" + str(value))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bVywStbV9jD2" - }, - "source": [ - "### Indexing and Slicing\n", - "If all values in Batch share the same shape in certain dimensions, Batch can support array-like indexing and slicing." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "gKza3OJnzc_D", - "outputId": "4f240bfe-4a69-4c1b-b40e-983c5c4d0cbc" - }, - "outputs": [], - "source": [ - "# Let us suppose we have collected the data from stepping from 4 environments\n", - "step_outputs = [\n", - " {\n", - " \"act\": np.random.randint(10),\n", - " \"rew\": 0.0,\n", - " \"obs\": np.ones((3, 3)),\n", - " \"info\": {\"done\": np.random.choice(2), \"failed\": False},\n", - " \"terminated\": False,\n", - " \"truncated\": False,\n", - " }\n", - " for _ in range(4)\n", - "]\n", - "batch = Batch(step_outputs)\n", - "print(batch)\n", - "print(batch.shape)\n", - "\n", - "# advanced indexing is supported, if we only want to select data in a given set of environments\n", - "print(\"========================================\")\n", - "print(batch[0])\n", - "print(batch[[0, 3]])\n", - "\n", - "# slicing is also supported\n", - "print(\"========================================\")\n", - "print(batch[-2:])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Aggregation and Splitting\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1vUwQ-Hw9jtu" - }, - "source": [ - "Again, just like a numpy array. Play the example code below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "f5UkReyn3_kb", - "outputId": "e7bb3324-7f20-4810-a328-479117efca55" - }, - "outputs": [], - "source": [ - "# concat batches with compatible keys\n", - "# try incompatible keys yourself if you feel curious\n", - "print(\"========================================\")\n", - "b1 = Batch(a=[{\"b\": np.float64(1.0), \"d\": Batch(e=np.array(3.0))}])\n", - "b2 = Batch(a=[{\"b\": np.float64(4.0), \"d\": {\"e\": np.array(6.0)}}])\n", - "b12_cat_out = Batch.cat([b1, b2])\n", - "print(b1)\n", - "print(b2)\n", - "print(b12_cat_out)\n", - "\n", - "# stack batches with compatible keys\n", - "# try incompatible keys yourself if you feel curious\n", - "print(\"========================================\")\n", - "b3 = Batch(a=np.zeros((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[1], [2]]))\n", - "b4 = Batch(a=np.ones((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[0], [3]]))\n", - "b34_stack = Batch.stack((b3, b4), axis=1)\n", - "print(b3)\n", - "print(b4)\n", - "print(b34_stack)\n", - "\n", - "# split the batch into small batches of size 1, breaking the order of the data\n", - "print(\"========================================\")\n", - "print(type(b34_stack.split(1)))\n", - "print(list(b34_stack.split(1, shuffle=True)))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Smc_W1Cx6zRS" - }, - "source": [ - "### Data type converting\n", - "Besides numpy array, Batch actually also supports Torch Tensor. The usages are exactly the same. Cool, isn't it?" 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Y6im_Mtb7Ody", - "outputId": "898e82c4-b940-4c35-a0f9-dedc4a9bc500" - }, - "outputs": [], - "source": [ - "batch1 = Batch(a=np.arange(2), b=torch.zeros((2, 2)))\n", - "batch2 = Batch(a=np.arange(2), b=torch.ones((2, 2)))\n", - "batch_cat = Batch.cat([batch1, batch2, batch1])\n", - "print(batch_cat)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1wfTUVKb6xki" - }, - "source": [ - "You can convert the data type easily, if you no longer want to use hybrid data type anymore." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "F7WknVs98DHD", - "outputId": "cfd0712a-1df3-4208-e6cc-9149840bdc40" - }, - "outputs": [], - "source": [ - "batch_cat.to_numpy_()\n", - "print(batch_cat)\n", - "batch_cat.to_torch_()\n", - "print(batch_cat)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NTFVle1-9Biz" - }, - "source": [ - "Batch is even serializable, just in case you may need to save it to disk or restore it." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Lnf17OXv9YRb", - "outputId": "753753f2-3f66-4d4b-b4ff-d57f9c40d1da" - }, - "outputs": [], - "source": [ - "batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4]))\n", - "batch_pk = pickle.loads(pickle.dumps(batch))\n", - "print(batch_pk)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-vPMiPZ-9kJN" - }, - "source": [ - "## Further Reading" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8Oc1p8ud9kcu" - }, - "source": [ - "Would you like to learn more advanced usages of Batch? Feel curious about how data is organized inside the Batch? Check the [documentation](https://tianshou.readthedocs.io/en/master/03_api/data/batch.html) and other [tutorials](https://tianshou.readthedocs.io/en/master/01_tutorials/03_batch.html#) for more details." - ] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.4" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/docs/02_notebooks/L2_Buffer.ipynb b/docs/02_notebooks/L2_Buffer.ipynb deleted file mode 100644 index 4f51abca5..000000000 --- a/docs/02_notebooks/L2_Buffer.ipynb +++ /dev/null @@ -1,427 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import pickle\n", - "\n", - "import numpy as np\n", - "\n", - "from tianshou.data import Batch, ReplayBuffer" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xoPiGVD8LNma" - }, - "source": [ - "# Buffer\n", - "Replay Buffer is a very common module in DRL implementations. In Tianshou, the Buffer module can be viewed as a specialized form of Batch, designed to track all data trajectories and offering utilities like sampling methods beyond basic storage.\n", - "\n", - "There are many kinds of Buffer modules in Tianshou, two most basic ones are ReplayBuffer and VectorReplayBuffer. 
The later one is specially designed for parallelized environments (will introduce in tutorial [Vectorized Environment](https://tianshou.readthedocs.io/en/master/02_notebooks/L3_Vectorized__Environment.html)). In this tutorial, we will focus on ReplayBuffer." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "OdesCAxANehZ" - }, - "source": [ - "## Usages" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fUbLl9T_SrTR" - }, - "source": [ - "### Basic usages as a batch\n", - "Typically, a buffer stores all data in batches, employing a circular-queue mechanism." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "mocZ6IqZTH62", - "outputId": "66cc4181-c51b-4a47-aacf-666b92b7fc52" - }, - "outputs": [], - "source": [ - "# a buffer is initialised with its maxsize set to 10 (older data will be discarded if more data flow in).\n", - "print(\"========================================\")\n", - "dummy_buf = ReplayBuffer(size=10)\n", - "print(dummy_buf)\n", - "print(f\"maxsize: {dummy_buf.maxsize}, data length: {len(dummy_buf)}\")\n", - "\n", - "# add 3 steps of data into ReplayBuffer sequentially\n", - "print(\"========================================\")\n", - "for i in range(3):\n", - " dummy_buf.add(\n", - " Batch(obs=i, act=i, rew=i, terminated=0, truncated=0, done=0, obs_next=i + 1, info={}),\n", - " )\n", - "print(dummy_buf)\n", - "print(f\"maxsize: {dummy_buf.maxsize}, data length: {len(dummy_buf)}\")\n", - "\n", - "# add another 10 steps of data into ReplayBuffer sequentially\n", - "print(\"========================================\")\n", - "for i in range(3, 13):\n", - " dummy_buf.add(\n", - " Batch(obs=i, act=i, rew=i, terminated=0, truncated=0, done=0, obs_next=i + 1, info={}),\n", - " )\n", - "print(dummy_buf)\n", - "print(f\"maxsize: {dummy_buf.maxsize}, data length: {len(dummy_buf)}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "H8B85Y5yUfTy" - }, - "source": [ - "Just like Batch, ReplayBuffer supports concatenation, splitting, advanced slicing and indexing, etc." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "cOX-ADOPNeEK", - "outputId": "f1a8ec01-b878-419b-f180-bdce3dee73e6" - }, - "outputs": [], - "source": [ - "print(dummy_buf[-1])\n", - "print(dummy_buf[-3:])\n", - "# Try more methods you find useful in Batch yourself." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vqldap-2WQBh" - }, - "source": [ - "ReplayBuffer can also be saved into local disk, still keeping track of the trajectories. This is extremely helpful in offline DRL settings." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Ppx0L3niNT5K" - }, - "outputs": [], - "source": [ - "_dummy_buf = pickle.loads(pickle.dumps(dummy_buf))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Eqezp0OyXn6J" - }, - "source": [ - "### Understanding reserved keys for buffer\n", - "As explained above, ReplayBuffer is specially designed to utilize the implementations of DRL algorithms. 
So, for convenience, we reserve certain nine reserved keys in Batch.\n", - "\n", - "* `obs`\n", - "* `act`\n", - "* `rew`\n", - "* `terminated`\n", - "* `truncated`\n", - "* `done`\n", - "* `obs_next`\n", - "* `info`\n", - "* `policy`\n", - "\n", - "The meaning of these nine reserved keys are consistent with the meaning in [Gymansium](https://gymnasium.farama.org/index.html#). We would recommend you simply use these nine keys when adding batched data into ReplayBuffer, because\n", - "some of them are tracked in ReplayBuffer (e.g. \"done\" value is tracked to help us determine a trajectory's start index and end index, together with its total reward and episode length.)\n", - "\n", - "```\n", - "buf.add(Batch(......, extro_info=0)) # This is okay but not recommended.\n", - "buf.add(Batch(......, info={\"extro_info\":0})) # Recommended.\n", - "```\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ueAbTspsc6jo" - }, - "source": [ - "### Data sampling\n", - "The primary purpose of maintaining a replay buffer in DRL is to sample data for training. `ReplayBuffer.sample()` and `ReplayBuffer.split(..., shuffle=True)` can both fulfill this need." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "P5xnYOhrchDl", - "outputId": "bcd2c970-efa6-43bb-8709-720d38f77bbd" - }, - "outputs": [], - "source": [ - "dummy_buf.sample(batch_size=5)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "IWyaOSKOcgK4" - }, - "source": [ - "## Trajectory tracking\n", - "Compared to Batch, a unique feature of ReplayBuffer is that it can help you track the environment trajectories.\n", - "\n", - "First, let us simulate a situation, where we add three trajectories into the buffer. The last trajectory is still not finished yet." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "editable": true, - "id": "H0qRb6HLfhLB", - "outputId": "9bdb7d4e-b6ec-489f-a221-0bddf706d85b", - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "trajectory_buffer = ReplayBuffer(size=10)\n", - "# Add the first trajectory (length is 3) into ReplayBuffer\n", - "print(\"========================================\")\n", - "for i in range(3):\n", - " result = trajectory_buffer.add(\n", - " Batch(\n", - " obs=i,\n", - " act=i,\n", - " rew=i,\n", - " terminated=1 if i == 2 else 0,\n", - " truncated=0,\n", - " done=i == 2,\n", - " obs_next=i + 1,\n", - " info={},\n", - " ),\n", - " )\n", - " print(result)\n", - "print(trajectory_buffer)\n", - "print(f\"maxsize: {trajectory_buffer.maxsize}, data length: {len(trajectory_buffer)}\")\n", - "\n", - "# Add the second trajectory (length is 5) into ReplayBuffer\n", - "print(\"========================================\")\n", - "for i in range(3, 8):\n", - " result = trajectory_buffer.add(\n", - " Batch(\n", - " obs=i,\n", - " act=i,\n", - " rew=i,\n", - " terminated=1 if i == 7 else 0,\n", - " truncated=0,\n", - " done=i == 7,\n", - " obs_next=i + 1,\n", - " info={},\n", - " ),\n", - " )\n", - " print(result)\n", - "print(trajectory_buffer)\n", - "print(f\"maxsize: {trajectory_buffer.maxsize}, data length: {len(trajectory_buffer)}\")\n", - "\n", - "# Add the third trajectory (length is 5, still not finished) into ReplayBuffer\n", - "print(\"========================================\")\n", - "for i in range(8, 13):\n", - " result = trajectory_buffer.add(\n", - " Batch(obs=i, act=i, rew=i, terminated=0, truncated=0, done=False, obs_next=i + 1, info={}),\n", - " )\n", - " print(result)\n", - "print(trajectory_buffer)\n", - "print(f\"maxsize: {trajectory_buffer.maxsize}, data length: {len(trajectory_buffer)}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dO7PWdb_hkXA" - }, - "source": [ - "### Episode length and rewards tracking\n", - "Notice that `ReplayBuffer.add()` returns a tuple of 4 numbers every time, meaning `(current_index, episode_reward, episode_length, episode_start_index)`. `episode_reward` and `episode_length` are valid only when a trajectory is finished. This might save developers some trouble.\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xbVc90z8itH0" - }, - "source": [ - "### Episode index management\n", - "In the ReplayBuffer above, we can get access to any data step by indexing.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "4mKwo54MjupY", - "outputId": "9ae14a7e-908b-44eb-afec-89b45bac5961" - }, - "outputs": [], - "source": [ - "print(trajectory_buffer)\n", - "print(\"========================================\")\n", - "\n", - "data = trajectory_buffer[6]\n", - "print(data)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "p5Co_Fmzj8Sw" - }, - "source": [ - "We know that step \"6\" is not the start of an episode - which should be step \"3\", since \"3-7\" is the second trajectory we add into the ReplayBuffer - but we wonder how do we get the earliest index of that episode.\n", - "\n", - "This may seem easy but actually it is not. 
We cannot simply look at the \"done\" flag preceding the start of a new episode, because since the third-added trajectory is not finished yet, step \"3\" is surrounded by flag \"False\". There are many things to consider. Things could get more nasty when using more advanced ReplayBuffer like VectorReplayBuffer, since it does not store the data in a simple circular-queue.\n", - "\n", - "Luckily, all ReplayBuffer instances help you identify step indexes through a unified API. One can simply input an array of indexes and look for their previous index in the episode." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# previous step of indexes [0, 1, 2, 3, 4, 5, 6] are:\n", - "print(trajectory_buffer.prev(np.array([0, 1, 2, 3, 4, 5, 6])))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4Wlb57V4lQyQ" - }, - "source": [ - "Using `ReplayBuffer.prev()`, we know that the earliest step of that episode is step \"3\". Similarly, `ReplayBuffer.next()` helps us identify the last index of an episode regardless of which kind of ReplayBuffer we are using." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zl5TRMo7oOy5", - "outputId": "4a11612c-3ee0-4e74-b028-c8759e71fbdb" - }, - "outputs": [], - "source": [ - "# next step of indexes [4,5,6,7,8,9] are:\n", - "print(trajectory_buffer.next(np.array([4, 5, 6, 7, 8, 9])))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YJ9CcWZXoOXw" - }, - "source": [ - "We can also search for the indexes which are labeled \"done: False\", but are the last step in a trajectory." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Xkawk97NpItg", - "outputId": "df10b359-c2c7-42ca-e50d-9caee6bccadd" - }, - "outputs": [], - "source": [ - "print(trajectory_buffer.unfinished_index())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8_lMr0j3pOmn" - }, - "source": [ - "Aforementioned APIs will be helpful when we calculate quantities like GAE and n-step-returns in DRL algorithms ([Example usage in Tianshou](https://github.com/thu-ml/tianshou/blob/6fc68578127387522424460790cbcb32a2bd43c4/tianshou/policy/base.py#L384)). The unified APIs ensure a modular design and a flexible interface." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FEyE0c7tNfwa" - }, - "source": [ - "## Further Reading\n", - "### Other Buffer Module\n", - "\n", - "* PrioritizedReplayBuffer, which helps you implement [prioritized experience replay](https://arxiv.org/abs/1511.05952)\n", - "* CachedReplayBuffer, one main buffer with several cached buffers (higher sample efficiency in some scenarios)\n", - "* ReplayBufferManager, A base class that can be inherited (may help you manage multiple buffers).\n", - "\n", - "Refer to the documentation and source code for further details.\n", - "\n", - "### Support for steps stacking to use RNN in DRL.\n", - "There is an option called `stack_num` (default to 1) when initializing the ReplayBuffer, which may help you use RNN in your algorithm. Check the documentation for details." 
- ] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.4" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/docs/02_notebooks/L3_Vectorized__Environment.ipynb b/docs/02_notebooks/L3_Vectorized__Environment.ipynb deleted file mode 100644 index 19e5489a2..000000000 --- a/docs/02_notebooks/L3_Vectorized__Environment.ipynb +++ /dev/null @@ -1,229 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "W5V7z3fVX7_b" - }, - "source": [ - "# Vectorized Environment\n", - "In reinforcement learning, an agent engages with environments to enhance its performance. In this tutorial we will concentrate on the environment part. Although there are many kinds of environments or their libraries in DRL research, Tianshou chooses to keep a consistent API with [OPENAI Gym](https://gym.openai.com/).\n", - "\n", - "
\n", - "\n", - "\n", - " The agents interacting with the environment \n", - "
\n", - "\n", - "In Gym, an environment receives an action and returns next observation and reward. This process is slow and sometimes can be the throughput bottleneck in a DRL experiment.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "A0NGWZ8adBwt" - }, - "source": [ - "Tianshou provides vectorized environment wrapper for a Gym environment. This wrapper allows you to make use of multiple cpu cores in your server to accelerate the data sampling." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "editable": true, - "id": "67wKtkiNi3lb", - "outputId": "1e04353b-7a91-4c32-e2ae-f3889d58aa5e", - "slideshow": { - "slide_type": "" - }, - "tags": [ - "remove-output", - "hide-cell" - ] - }, - "outputs": [], - "source": [ - "import time\n", - "\n", - "import gymnasium as gym\n", - "import numpy as np\n", - "\n", - "from tianshou.env import DummyVectorEnv, SubprocVectorEnv" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "num_cpus = [1, 2, 5]\n", - "for num_cpu in num_cpus:\n", - " env = SubprocVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(num_cpu)])\n", - " env.reset()\n", - " sampled_steps = 0\n", - " time_start = time.time()\n", - " while sampled_steps < 1000:\n", - " act = np.random.choice(2, size=num_cpu)\n", - " obs, rew, terminated, truncated, info = env.step(act)\n", - " if np.sum(terminated):\n", - " env.reset(np.where(terminated)[0])\n", - " sampled_steps += num_cpu\n", - " time_used = time.time() - time_start\n", - " print(f\"{time_used}s used to sample 1000 steps if using {num_cpu} cpus.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "S1b6vxp9nEUS" - }, - "source": [ - "You may notice that the speed doesn't increase linearly when we add subprocess numbers. There are multiple reasons behind this. One reason is that synchronize exception causes straggler effect. One way to solve this would be to use asynchronous mode. We leave this for further reading if you feel interested.\n", - "\n", - "Note that SubprocVectorEnv should only be used when the environment execution is slow. In practice, DummyVectorEnv (or raw Gym environment) is actually more efficient for a simple environment like CartPole because now you avoid both straggler effect and the overhead of communication between subprocesses." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Z6yPxdqFp18j" - }, - "source": [ - "## Usages\n", - "### Initialization\n", - "Just pass in a list of functions which return the initialized environment upon called." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ssLcrL_pq24-" - }, - "outputs": [], - "source": [ - "# In Gym\n", - "gym_env = gym.make(\"CartPole-v1\")\n", - "\n", - "\n", - "# In Tianshou\n", - "def create_cartpole_env() -> gym.Env:\n", - " return gym.make(\"CartPole-v1\")\n", - "\n", - "\n", - "# We can distribute the environments on the available cpus, which we assume to be 5 in this case\n", - "vector_env = DummyVectorEnv([create_cartpole_env for _ in range(5)])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "X7p8csjdrwIN" - }, - "source": [ - "### EnvPool supporting\n", - "Besides integrated environment wrappers, Tianshou also fully supports [EnvPool](https://github.com/sail-sg/envpool/). Explore its Github page yourself." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kvIfqh0vqAR5" - }, - "source": [ - "### Environment execution and resetting\n", - "The only difference between Vectorized environments and standard Gym environments is that passed in actions and returned rewards/observations are also vectorized." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "BH1ZnPG6tkdD" - }, - "outputs": [], - "source": [ - "# In gymnasium, env.reset() returns an observation, info tuple\n", - "print(\"In Gym, env.reset() returns a single observation.\")\n", - "print(gym_env.reset())\n", - "\n", - "# In Tianshou, envs.reset() returns stacked observations.\n", - "print(\"========================================\")\n", - "print(\"In Tianshou, a VectorEnv's reset() returns stacked observations.\")\n", - "print(vector_env.reset())\n", - "\n", - "info = vector_env.step(np.random.choice(2, size=vector_env.env_num))[4]\n", - "print(info)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qXroB7KluvP9" - }, - "source": [ - "If we only want to execute several environments. The `id` argument can be used." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ufvFViKTu8d_" - }, - "outputs": [], - "source": [ - "info = vector_env.step(np.random.choice(2, size=3), id=[0, 3, 1])[4]\n", - "print(info)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fekHR1a6X_HB" - }, - "source": [ - "## Further Reading\n", - "### Other environment wrappers in Tianshou\n", - "\n", - "\n", - "* ShmemVectorEnv: use share memory instead of pipe based on SubprocVectorEnv;\n", - "* RayVectorEnv: use Ray for concurrent activities and is currently the only choice for parallel simulation in a cluster with multiple machines.\n", - "\n", - "Check the [documentation](https://tianshou.org/en/master/03_api/env/venvs.html) for details.\n", - "\n", - "### Difference between synchronous and asynchronous mode (How to choose?)\n", - "For further insights, refer to the [Parallel Sampling](https://tianshou.org/en/master/01_tutorials/07_cheatsheet.html#parallel-sampling) tutorial." - ] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.4" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/docs/02_notebooks/L5_Collector.ipynb b/docs/02_notebooks/L5_Collector.ipynb deleted file mode 100644 index a52dd25eb..000000000 --- a/docs/02_notebooks/L5_Collector.ipynb +++ /dev/null @@ -1,271 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "M98bqxdMsTXK" - }, - "source": [ - "# Collector\n", - "From its literal meaning, we can easily know that the Collector in Tianshou is used to collect training data. More specifically, the Collector controls the interaction between Policy (agent) and the environment. It also helps save the interaction data into the ReplayBuffer and returns episode statistics.\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "OX5cayLv4Ziu" - }, - "source": [ - "## Usages\n", - "Collector can be used both for training (data collecting) and evaluation in Tianshou." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Z6XKbj28u8Ze" - }, - "source": [ - "### Policy evaluation\n", - "We need to evaluate our trained policy from time to time in DRL experiments. Collector can help us with this.\n", - "\n", - "First we have to initialize a Collector with an (vectorized) environment and a given policy (agent)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "id": "w8t9ubO7u69J", - "slideshow": { - "slide_type": "" - }, - "tags": [ - "hide-cell", - "remove-output" - ] - }, - "outputs": [], - "source": [ - "import gymnasium as gym\n", - "import torch\n", - "\n", - "from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", - "from tianshou.env import DummyVectorEnv\n", - "from tianshou.utils.net.common import Net\n", - "from tianshou.utils.net.discrete import DiscreteActor\n", - "\n", - "env = gym.make(\"CartPole-v1\")\n", - "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(2)])\n", - "\n", - "# model\n", - "assert env.observation_space.shape is not None # for mypy\n", - "preprocess_net = Net(\n", - " state_shape=env.observation_space.shape,\n", - " hidden_sizes=[\n", - " 16,\n", - " ],\n", - ")\n", - "\n", - "assert isinstance(env.action_space, gym.spaces.Discrete) # for mypy\n", - "actor = DiscreteActor(preprocess_net=preprocess_net, action_shape=env.action_space.n)\n", - "\n", - "policy = ProbabilisticActorPolicy(\n", - " actor=actor,\n", - " dist_fn=torch.distributions.Categorical,\n", - " action_space=env.action_space,\n", - " action_scaling=False,\n", - ")\n", - "test_collector = Collector[CollectStats](policy, test_envs)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wmt8vuwpzQdR" - }, - "source": [ - "Now we would like to collect 9 episodes of data to test how our initialized Policy performs." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "9SuT6MClyjyH", - "outputId": "1e48f13b-c1fe-4fc2-ca1b-669485efdcae" - }, - "outputs": [], - "source": [ - "collect_result = test_collector.collect(reset_before_collect=True, n_episode=9)\n", - "\n", - "collect_result.pprint_asdict()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zX9AQY0M0R3C" - }, - "source": [ - "Now we wonder what is the performance of a random policy." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "UEcs8P8P0RLt", - "outputId": "85f02f9d-b79b-48b2-99c6-36a1602f0884" - }, - "outputs": [], - "source": [ - "# Reset the collector\n", - "collect_result = test_collector.collect(reset_before_collect=True, n_episode=9, random=True)\n", - "\n", - "collect_result.pprint_asdict()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sKQRTiG10ljU" - }, - "source": [ - "It seems like an initialized policy performs even worse than a random policy without any training." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8RKmHIoG1A1k" - }, - "source": [ - "### Data Collecting\n", - "Data collecting is mostly used during training, when we need to store the collected data in a ReplayBuffer." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "editable": true, - "id": "CB9XB9bF1YPC", - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "train_env_num = 4\n", - "buffer_size = 100\n", - "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(train_env_num)])\n", - "replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n", - "\n", - "train_collector = Collector[CollectStats](policy, train_envs, replayBuffer)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rWKDazA42IUQ" - }, - "source": [ - "Now we can collect 50 steps of data, which will be automatically saved in the replay buffer. You can still choose to collect a certain number of episodes rather than steps. Try it yourself." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "-fUtQOnM2Yi1", - "outputId": "dceee987-433e-4b75-ed9e-823c20a9e1c2" - }, - "outputs": [], - "source": [ - "train_collector.reset()\n", - "replayBuffer.reset()\n", - "\n", - "print(f\"Replay buffer before collecting is empty, and has length={len(replayBuffer)} \\n\")\n", - "n_step = 50\n", - "collect_result = train_collector.collect(n_step=n_step)\n", - "print(\n", - " f\"Replay buffer after collecting {n_step} steps has length={len(replayBuffer)}.\\n\"\n", - " f\"This may exceed n_step when it is not a multiple of train_env_num because of vectorization.\\n\",\n", - ")\n", - "collect_result.pprint_asdict()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Sample some data from the replay buffer." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "replayBuffer.sample(10)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8NP7lOBU3-VS" - }, - "source": [ - "## Further Reading\n", - "The above collector actually collects 52 data at a time because 52 % 4 = 0. There is one asynchronous collector which allows you collect exactly 50 steps. Check the [documentation](https://tianshou.org/en/master/03_api/data/collector.html#tianshou.data.collector.AsyncCollector) for details." - ] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.4" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/docs/01_tutorials/06_benchmark.rst b/docs/04_benchmarks/benchmarks.rst similarity index 99% rename from docs/01_tutorials/06_benchmark.rst rename to docs/04_benchmarks/benchmarks.rst index 26f80eb2b..6112322e3 100644 --- a/docs/01_tutorials/06_benchmark.rst +++ b/docs/04_benchmarks/benchmarks.rst @@ -1,5 +1,5 @@ -Benchmark -========= +Benchmarks +========== Mujoco Benchmark diff --git a/docs/04_contributing/04_contributing.rst b/docs/04_contributing/04_contributing.rst deleted file mode 100644 index 48cf172c8..000000000 --- a/docs/04_contributing/04_contributing.rst +++ /dev/null @@ -1,148 +0,0 @@ -Contributing to Tianshou -======================== - - -Install Development Environment -------------------------------- - -Tianshou is built and managed by `poetry `_. 
For example, -to install all relevant requirements (and install Tianshou itself in editable mode) -you can simply call - -.. code-block:: bash - - $ poetry install --with dev - - -Platform-Specific Configuration -------------------------------- - -**Windows**: -Since the repository contains symbolic links, make sure this is supported: - - * Enable Windows Developer Mode to allow symbolic links to be created: Search Start Menu for "Developer Settings" and enable "Developer Mode" - * Enable symbolic links for this repository: ``git config core.symlinks true`` - * Re-checkout the current git state: ``git checkout .`` - - -PEP8 Code Style Check and Formatting ----------------------------------------- - -Please set up pre-commit by running - -.. code-block:: bash - - $ pre-commit install - -in the main directory. This should make sure that your contribution is properly -formatted before every commit. - -The code is inspected and formatted by ``black`` and ``ruff``. They are executed as -pre-commit hooks. In addition, ``poe the poet`` tasks are configured. -Simply run ``poe`` to see the available tasks. -E.g, to format and check the linting manually you can run: - -.. code-block:: bash - - $ poe format - $ poe lint - - -Type Checks ------------ - -We use `mypy `_ to check the type annotations. To check, in the main directory, run: - -.. code-block:: bash - - $ poe type-check - - -Testing Locally ---------------- - -This command will run automatic tests in the main directory - -.. code-block:: bash - - $ poe test - - -Determinism Tests -~~~~~~~~~~~~~~~~~ - -We implemented "determinism tests" for Tianshou's algorithms, which allow us to determine -whether algorithms still compute exactly the same results even after large refactorings. -These tests are applied by - - 1. creating a behavior snapshot ine the old code branch before the changes and then - 2. running the test in the new branch to ensure that the behavior is the same. - -Unfortunately, full determinism is difficult to achieve across different platforms and even different -machines using the same platform an Python environment. -Therefore, these tests are not carried out in the CI pipeline. -Instead, it is up to the developer to run them locally and check the results whenever a change -is made to the code base that could affect algorithm behavior. - -Technically, the two steps are handled by setting static flags in class ``AlgorithmDeterminismTest`` and then -running either the full test suite or a specific determinism test (``test_*_determinism``, e.g. ``test_ddpg_determinism``) -in the two branches to be compared. - - 1. On the old branch: (Temporarily) set ``ENABLED=True`` and ``FORCE_SNAPSHOT_UPDATE=True`` and run the test(s). - 2. On the new branch: (Temporarily) set ``ENABLED=True`` and ``FORCE_SNAPSHOT_UPDATE=False`` and run the test(s). - 3. Inspect the test results; find a summary in ``determinism_tests.log`` - -Test by GitHub Actions ----------------------- - -1. Click the ``Actions`` button in your own repo: - -.. image:: ../_static/images/action1.jpg - :align: center - -2. Click the green button: - -.. image:: ../_static/images/action2.jpg - :align: center - -3. You will see ``Actions Enabled.`` on the top of html page. - -4. When you push a new commit to your own repo (e.g. ``git push``), it will automatically run the test in this page: - -.. image:: ../_static/images/action3.png - :align: center - - -Documentation -------------- - -Documentations are written under the ``docs/`` directory as ReStructuredText (``.rst``) files. 
``index.rst`` is the main page. A Tutorial on ReStructuredText can be found `here `_. - -API References are automatically generated by `Sphinx `_ according to the outlines under ``docs/api/`` and should be modified when any code changes. - -To compile documentation into webpage, run - -.. code-block:: bash - - $ poe doc-build - -The generated webpage is in ``docs/_build`` and can be viewed with browser (http://0.0.0.0:8000/). - - -Documentation Generation Test ------------------------------ - -We have the following three documentation tests: - -1. pydocstyle (as part of ruff): test all docstring under ``tianshou/``; - -2. doc8 (as part of ruff): test ReStructuredText format; - -3. sphinx spelling and test: test if there is any error/warning when generating front-end html documentation. - -To check, in the main directory, run: - -.. code-block:: bash - - $ poe lint - $ poe doc-build diff --git a/docs/05_developer_guide/developer_guide.md b/docs/05_developer_guide/developer_guide.md new file mode 100644 index 000000000..c5dc71881 --- /dev/null +++ b/docs/05_developer_guide/developer_guide.md @@ -0,0 +1,215 @@ +# Developer Guide + +The section addresses developers of Tianshou, providing information for +both casual contributors and maintainers alike. + + +## Python Virtual Environment + +Tianshou is built and managed by [poetry](https://python-poetry.org/). + +The development environment uses Python 3.11. + +To install all relevant requirements (as well as Tianshou itself in editable mode) +you can simply call + + poetry install --with dev + +```{important} +Depending on your setup, you may need to create and activate an empty virtual environment +using the right Python version beforehand. For instance, to do this with conda, use: + + conda create -n tianshou python=3.11 + conda activate tianshou +``` + + +## Code Style and Auto-Formatting + +When editing code in Tianshou, strive for **local consistency**, i.e. +adhere to the style already present in the codebase. + +Tianshou uses an auto-formatting for consistency. +To apply it, call + + poe format + +To check whether your formatting is compliant without applying the +auto-formatter, call + + poe lint + + +## Type Checking + +We use [mypy](https://github.com/python/mypy/) to perform static type analysis. +To check typing, run + + poe type-check + + +## Tests + +### Running the Test Suite Locally + +Tianshou uses pytest. Tests are located in `./test`. + +To run the full set of tests locally, run + + poe test + +### Determinism Tests + +We implemented **determinism tests** for Tianshou's algorithms, which allow us to determine +whether algorithms still compute exactly the same results even after large refactorings. +These tests are applied by + + 1. creating a behavior snapshot in the old code branch before the changes and then + 2. running the respective determinism test in the new branch to ensure that the behavior is the same. + +Unfortunately, full determinism is difficult to achieve across different platforms and even different +machines using the same platform an Python environment. +Therefore, these tests are not carried out in the CI pipeline. +Instead, it is up to the developer to run them locally and check the results whenever a change +is made to the codebase that could affect algorithm behavior. + +Technically, the two steps are handled by setting static flags in class `AlgorithmDeterminismTest` and then +running either the full test suite or a specific determinism test (`test_*_determinism`, e.g. 
`test_ddpg_determinism`) +in the two branches to be compared. + + 1. On the old branch: (Temporarily) set `ENABLED=True` and `FORCE_SNAPSHOT_UPDATE=True` and run the test(s). + 2. On the new branch: (Temporarily) set `ENABLED=True` and `FORCE_SNAPSHOT_UPDATE=False` and run the test(s). + 3. Inspect the test results; find a summary in `determinism_tests.log` + +### Tests in CI (GitHub Actions) + +CI tests will extensively test Tianshou's functionality in multiple environments. + +In particular, we test + * on Ubuntu (full functionality tested) + * **py_pinned**: using the pinned development environment (Python 3.11, known versions of all dependencies) + * **py_latest**: using a more recent Python version with the newest set of compatible dependencies (automatically resolved) + * on Windows and macOS (core functionality tested) + + +#### Principle of Maximum Compatibility + +The idea behind testing with dynamically resolved dependencies is that we want to maximize the applicability +of Tianshou: For important dependencies that could conflict with environments used by our users, **we do not restrict the version of a dependency unless there is a known incompatibility.** + +If incompatibilities should arise (e.g. by the "py_latest" test failing), we either + * resolve them by making the code compatible with both old and new versions OR + * add an upper bound to our dependency declarations (excluding the incompatible versions) and release a new + version of Tianshou to make these exclusions explicit. + + +## High-Level API + +The high-level API provides a declarative, user-friendly interface for setting up reinforcement learning experiments. From a library developer's perspective, it is important that this API be clearly structured and maintainable. This section explains the architectural principles and how to extend the API to support new algorithms. + +### Core Abstractions + +The high-level API is built around a clear separation of concerns: + +**Parameter Classes** are dataclasses (inheriting from `Params`) that represent algorithm-specific configuration. +They capture hyperparameters in a high-level, user-friendly form. +Because the high-level interface must abstract away from low-level details, parameters may need transformation before being passed to policy classes. +This is handled via `ParamTransformer` instances, which successively transform the parameter dictionary representation. +To maintain clarity and reduce coupling, parameter transformers are co-located with the parameters they affect. +The system uses inheritance and mixins extensively to reduce duplication while maintaining flexibility. + +**Factories** embody the principle of declarative configuration. +Because object creation may depend on other objects that don't yet exist at configuration time (e.g., neural networks depend on environment properties), +the API transitions from objects to factories. +Key factory types include: +- `EnvFactory` for creating training, test, and watch environments +- `AgentFactory` as the central factory that creates policies, trainers, and collectors +- Various specialized factories for optimizers, actors, critics, noise, distributions, learning rate schedulers, and policy wrappers + +**Algorithm Factories** (subclasses of `AlgorithmFactory`) are the core components responsible for orchestrating the creation of all algorithm-specific objects. 
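+To make this orchestration a little more concrete, here is a minimal, self-contained
+sketch of the pattern. All names in it (`MyParams`, `lr_to_optimizer`,
+`SchematicAlgorithmFactory`) are illustrative placeholders and do not reproduce
+Tianshou's actual classes or signatures:
+
+```python
+# Schematic only: an "algorithm factory" builds environment-dependent objects,
+# runs the parameter transformers over the dataclass's dict form, and assembles
+# the keyword arguments for the low-level algorithm class.
+from dataclasses import asdict, dataclass
+from typing import Any, Callable
+
+
+@dataclass
+class MyParams:
+    """High-level, user-facing hyperparameters (placeholder)."""
+
+    lr: float = 1e-3
+    gamma: float = 0.99
+
+
+def lr_to_optimizer(kwargs: dict[str, Any]) -> None:
+    """A parameter transformer: replaces `lr` by an optimizer factory stand-in."""
+    lr = kwargs.pop("lr")
+    kwargs["optim"] = ("AdamOptimizerFactory", lr)  # placeholder, not the real object
+
+
+class SchematicAlgorithmFactory:
+    """Turns declarative configuration into concrete keyword arguments."""
+
+    def __init__(
+        self,
+        params: MyParams,
+        transformers: list[Callable[[dict[str, Any]], None]],
+    ) -> None:
+        self.params = params
+        self.transformers = transformers
+
+    def create_algorithm_kwargs(self, action_shape: int) -> dict[str, Any]:
+        # 1. Create environment-dependent objects (stand-in for network creation).
+        model = f"Net(action_shape={action_shape})"
+        # 2. Transform high-level parameters into low-level keyword arguments.
+        kwargs = asdict(self.params)
+        for transform in self.transformers:
+            transform(kwargs)
+        # 3. Hand the assembled arguments to the algorithm class (returned here).
+        return {"model": model, **kwargs}
+
+
+factory = SchematicAlgorithmFactory(MyParams(), [lr_to_optimizer])
+print(factory.create_algorithm_kwargs(action_shape=2))
+```
+
+Tianshou's actual algorithm factories are, of course, considerably more involved.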
+
+### Design Principles
+
+The architecture follows several key principles:
+
+**Separation of Concerns**: Configuration is cleanly separated from implementation.
+The transformation system bridges these layers while maintaining independence.
+
+**Declarative Configuration**: Factories enable a declarative style where experiments are defined by what should be created rather than imperative steps.
+This makes experiments easily serializable and reproducible.
+
+**Composition and Inheritance**: Mixins and inheritance reduce code duplication.
+Common functionality is factored into reusable components while maintaining flexibility for algorithm-specific requirements.
+
+**Progressive Disclosure**: The API provides sensible defaults for simple use cases while allowing deep customization when needed.
+Users can progress from simple configurations to advanced setups without fighting the abstractions.
+
+**Co-location**: Related code is kept together. Parameter transformers are defined near the parameters they transform,
+maintaining clarity about dependencies and making the codebase easier to navigate.
+
+**Type Safety**: Extensive use of generics and type hints ensures that type checkers can catch configuration errors at development time rather than runtime.
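+
+To give a feel for how these principles surface in user code, a typical experiment definition reads
+roughly like the following, continuing the fictitious "MyAlgo" example from above (builder name,
+constructor arguments and `with_*` methods are purely illustrative; see the scripts under `examples/`
+for working configurations):
+
+```python
+# Declarative setup: state what should be created (via factories and parameter objects),
+# then build the experiment and run it.
+experiment = (
+    MyAlgoExperimentBuilder(env_factory, experiment_config, training_config)
+    .with_my_algo_params(MyAlgoParams(my_special_coefficient=0.3))
+    .build()
+)
+experiment.run()
+```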
+
+
+## Documentation
+
+Documentation is in the `docs/` directory, using Markdown (`.md`), ReStructuredText (`.rst`) and notebook files.
+`index.rst` is the main page.
+
+API References are automatically generated by [Sphinx](http://www.sphinx-doc.org/en/stable/) according to the outlines under `docs/api/`, which should be updated whenever the code structure changes.
+
+To compile the documentation into web pages, run
+
+    poe doc-build
+
+The generated web pages can subsequently be found in `docs/_build` and viewed with any browser.
+
+### Verifications
+
+We have several automated verification methods for the documentation:
+
+1. pydocstyle (as part of ruff): tests all docstrings under `tianshou/`;
+
+2. doc8 (as part of ruff): tests the ReStructuredText format;
+
+3. Sphinx spelling and build check: verifies that no errors or warnings occur when generating the front-end HTML documentation.
+
+
+## Creating a Release
+
+To release a new version on PyPI,
+
+ * set the version to be released in `tianshou/__init__.py` and in `pyproject.toml`, creating a commit
+ * tag the commit with the version (using the format `v1.2.3`)
+ * push the commit (`git push`) and the tag (`git push --tags`)
+ * create a new release on GitHub based on the tag; this will trigger the release job for PyPI.
+
+In the past, we provided releases to conda-forge as well, but this is currently not maintained.
diff --git a/docs/04_contributing/05_contributors.rst b/docs/06_contributors/contributors.rst similarity index 100% rename from docs/04_contributing/05_contributors.rst rename to docs/06_contributors/contributors.rst diff --git a/docs/_config.yml b/docs/_config.yml index 0f110fb33..ffe3959bb 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -98,6 +98,7 @@ sphinx: - sphinx.ext.viewcode - sphinx_toolbox.more_autodoc.sourcelink - sphinxcontrib.spelling + - sphinxcontrib.mermaid local_extensions : # A list of local extensions to load by sphinx specified by "name: path" items recursive_update : false # A boolean indicating whether to overwrite the Sphinx config (true) or recursively update (false) config : # key-value pairs to directly over-ride the Sphinx configuration @@ -140,6 +141,7 @@ sphinx: [ 'spelling', 'text/plain', 90 ], ] mathjax_path: https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js + myst_fence_as_directive: ["mermaid"] mathjax3_config: loader: { load: [ '[tex]/configmacros' ] } tex: diff --git a/docs/_static/images/action1.jpg b/docs/_static/images/action1.jpg deleted file mode 100644 index 49620d512..000000000 Binary files a/docs/_static/images/action1.jpg and /dev/null differ diff --git a/docs/_static/images/action2.jpg b/docs/_static/images/action2.jpg deleted file mode 100644 index e07c33f52..000000000 Binary files a/docs/_static/images/action2.jpg and /dev/null differ diff --git a/docs/_static/images/action3.png b/docs/_static/images/action3.png deleted file mode 100644 index 8da8da442..000000000 Binary files a/docs/_static/images/action3.png and /dev/null differ diff --git a/docs/_static/js/benchmark.js b/docs/_static/js/benchmark.js index da44b43a9..655b8e6d3 100644 --- a/docs/_static/js/benchmark.js +++ b/docs/_static/js/benchmark.js @@ -22,7 +22,7 @@ var atari_envs = [ function getDataSource(selectEnv, dirName) {
return { - // Paths are relative to the only file using this script, which is docs/01_tutorials/06_benchmark.rst + // Paths are relative to the only file using this script, which is docs/04_benchmarks/benchmarks.rst $schema: "https://vega.github.io/schema/vega-lite/v5.json", data: { url: "../_static/js/" + dirName + "/benchmark/" + selectEnv + "/result.json" diff --git a/poetry.lock b/poetry.lock index 09adb4a4b..5b81bd448 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6171,6 +6171,25 @@ files = [ [package.extras] test = ["flake8", "mypy", "pytest"] +[[package]] +name = "sphinxcontrib-mermaid" +version = "1.0.0" +description = "Mermaid diagrams in yours Sphinx powered docs" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "sphinxcontrib_mermaid-1.0.0-py3-none-any.whl", hash = "sha256:60b72710ea02087f212028feb09711225fbc2e343a10d34822fe787510e1caa3"}, + {file = "sphinxcontrib_mermaid-1.0.0.tar.gz", hash = "sha256:2e8ab67d3e1e2816663f9347d026a8dee4a858acdd4ad32dd1c808893db88146"}, +] + +[package.dependencies] +pyyaml = "*" +sphinx = "*" + +[package.extras] +test = ["defusedxml", "myst-parser", "pytest", "ruff", "sphinx"] + [[package]] name = "sphinxcontrib-qthelp" version = "1.0.6" @@ -7070,4 +7089,4 @@ vizdoom = ["vizdoom"] [metadata] lock-version = "2.1" python-versions = "^3.11" -content-hash = "6a5ae8b5b701f0daee90e241187c1628477b6ac96394a3cb15f2921659e80e34" +content-hash = "74f33f02b6e6d6e6d45c1a8fa1798d6fdd29decb3cc3c81b6d3a8fe0d5c45ad7" diff --git a/pyproject.toml b/pyproject.toml index 7c38d5b93..c84b87142 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,6 +118,7 @@ sphinx-togglebutton = "^0.3.2" sphinx-toolbox = "^3.5.0" sphinxcontrib-bibtex = "*" sphinxcontrib-spelling = "^8.0.0" +sphinxcontrib-mermaid = "^1.0.0" types-requests = "^2.31.0.20240311" types-tabulate = "^0.9.0.20240106" # this is needed for wandb only (undisclosed dependency) @@ -161,6 +162,7 @@ select = [ "ASYNC", "B", "C4", "C90", "COM", "D", "DTZ", "E", "F", "FLY", "G", "I", "ISC", "PIE", "PLC", "PLE", "PLW", "RET", "RUF", "RSE", "SIM", "TID", "UP", "W", "YTT", ] ignore = [ + "RUF003", # custom (greek) letters "SIM118", # Needed b/c iter(batch) != iter(batch.keys()). See https://github.com/thu-ml/tianshou/issues/922 "E501", # line too long. ruff does a good enough job "E741", # variable names like "l". this isn't a huge problem @@ -240,4 +242,4 @@ doc-generate-files = ["_autogen_rst", "_jb_generate_toc", "_jb_generate_config"] doc-build = ["doc-clean", "doc-generate-files", "_sphinx_build"] _mypy = "mypy tianshou test examples" _mypy_nb = "nbqa mypy docs" -type-check = ["_mypy", "_mypy_nb"] +type-check = ["_mypy"] diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index ffdcf0efe..652102f21 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -315,8 +315,6 @@ class BatchProtocol(Protocol): (recursive) dictionary of objects that can be either numpy arrays, torch tensors, or batches themselves. It is designed to make it extremely easily to access, manipulate and set partial view of the heterogeneous data conveniently. - - For a detailed description, please refer to :ref:`batch_concept`. """ @property diff --git a/tianshou/data/buffer/buffer_base.py b/tianshou/data/buffer/buffer_base.py index fe07f7834..9de6eda42 100644 --- a/tianshou/data/buffer/buffer_base.py +++ b/tianshou/data/buffer/buffer_base.py @@ -28,9 +28,6 @@ class ReplayBuffer: ReplayBuffer can be considered as a specialized form (or management) of Batch. 
It stores all the data in a batch with circular-queue style. - For the example usage of ReplayBuffer, please check out Section "Buffer" in - :doc:`/01_tutorials/02_internals`. - :param size: the maximum size of replay buffer. :param stack_num: the frame-stack sampling argument, should be greater than or equal to 1. Default to 1 (no stacking). @@ -396,13 +393,12 @@ def _update_state_pre_add( ep_len = self._ep_len else: if isinstance(self._ep_return, np.ndarray): # type: ignore[unreachable] - # TODO: fix this! - log.error( # type: ignore[unreachable] - f"ep_return should be a scalar but is a numpy array: {self._ep_return.shape=}. " - "This doesn't make sense for a ReplayBuffer, but currently tests of CachedReplayBuffer require" - "this behavior for some reason. Should be fixed ASAP! " - "Returning an array of zeros instead of a scalar zero.", - ) + # TODO: [original remark by MischaPanch] Check whether the entire else case is really correct/necessary. + # ep_return should be a scalar but is a numpy array. + # This doesn't make sense for a ReplayBuffer, but currently tests of CachedReplayBuffer require + # this behavior for some reason; it also occurs in the MARL notebook, for example. + # Will return an array of zeros instead of a scalar zero. + pass ep_return = np.zeros_like(self._ep_return) # type: ignore ep_len = 0 diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py index 2e8287ba9..91e77632d 100644 --- a/tianshou/env/pettingzoo_env.py +++ b/tianshou/env/pettingzoo_env.py @@ -32,7 +32,6 @@ class PettingZooEnv(AECEnv, ABC): env.close() The available action's mask is set to True, otherwise it is set to False. - Further usage can be found at :ref:`marl_example`. """ def __init__(self, env: BaseWrapper): diff --git a/tianshou/trainer/trainer.py b/tianshou/trainer.py similarity index 100% rename from tianshou/trainer/trainer.py rename to tianshou/trainer.py diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py deleted file mode 100644 index 702ce6bf4..000000000 --- a/tianshou/trainer/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Trainer package.""" - -from .trainer import ( - OfflineTrainer, - OfflineTrainerParams, - OffPolicyTrainer, - OffPolicyTrainerParams, - OnPolicyTrainer, - OnPolicyTrainerParams, - Trainer, - TrainerParams, -)