From d60423b9da647c220c06031d18a74d182fcd1011 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 17 Nov 2025 21:36:05 +0700 Subject: [PATCH 1/3] Docs: support for mermaid diagrams --- docs/_config.yml | 2 ++ poetry.lock | 21 ++++++++++++++++++++- pyproject.toml | 1 + 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/docs/_config.yml b/docs/_config.yml index 0f110fb33..ffe3959bb 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -98,6 +98,7 @@ sphinx: - sphinx.ext.viewcode - sphinx_toolbox.more_autodoc.sourcelink - sphinxcontrib.spelling + - sphinxcontrib.mermaid local_extensions : # A list of local extensions to load by sphinx specified by "name: path" items recursive_update : false # A boolean indicating whether to overwrite the Sphinx config (true) or recursively update (false) config : # key-value pairs to directly over-ride the Sphinx configuration @@ -140,6 +141,7 @@ sphinx: [ 'spelling', 'text/plain', 90 ], ] mathjax_path: https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js + myst_fence_as_directive: ["mermaid"] mathjax3_config: loader: { load: [ '[tex]/configmacros' ] } tex: diff --git a/poetry.lock b/poetry.lock index 09adb4a4b..5b81bd448 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6171,6 +6171,25 @@ files = [ [package.extras] test = ["flake8", "mypy", "pytest"] +[[package]] +name = "sphinxcontrib-mermaid" +version = "1.0.0" +description = "Mermaid diagrams in yours Sphinx powered docs" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "sphinxcontrib_mermaid-1.0.0-py3-none-any.whl", hash = "sha256:60b72710ea02087f212028feb09711225fbc2e343a10d34822fe787510e1caa3"}, + {file = "sphinxcontrib_mermaid-1.0.0.tar.gz", hash = "sha256:2e8ab67d3e1e2816663f9347d026a8dee4a858acdd4ad32dd1c808893db88146"}, +] + +[package.dependencies] +pyyaml = "*" +sphinx = "*" + +[package.extras] +test = ["defusedxml", "myst-parser", "pytest", "ruff", "sphinx"] + [[package]] name = "sphinxcontrib-qthelp" version = 
"1.0.6" @@ -7070,4 +7089,4 @@ vizdoom = ["vizdoom"] [metadata] lock-version = "2.1" python-versions = "^3.11" -content-hash = "6a5ae8b5b701f0daee90e241187c1628477b6ac96394a3cb15f2921659e80e34" +content-hash = "74f33f02b6e6d6e6d45c1a8fa1798d6fdd29decb3cc3c81b6d3a8fe0d5c45ad7" diff --git a/pyproject.toml b/pyproject.toml index 9bbdf61fc..78a22687d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,6 +118,7 @@ sphinx-togglebutton = "^0.3.2" sphinx-toolbox = "^3.5.0" sphinxcontrib-bibtex = "*" sphinxcontrib-spelling = "^8.0.0" +sphinxcontrib-mermaid = "^1.0.0" types-requests = "^2.31.0.20240311" types-tabulate = "^0.9.0.20240106" # this is needed for wandb only (undisclosed dependency) From 46d306c9fce3d99ea0ee06c9a624c0f5c168b5da Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 17 Nov 2025 21:36:24 +0700 Subject: [PATCH 2/3] Rewrote and extended batch notebook --- docs/02_deep_dives/L1_Batch.ipynb | 1529 ++++++++++++++++++++--------- 1 file changed, 1062 insertions(+), 467 deletions(-) diff --git a/docs/02_deep_dives/L1_Batch.ipynb b/docs/02_deep_dives/L1_Batch.ipynb index 379239a8c..3c87344eb 100644 --- a/docs/02_deep_dives/L1_Batch.ipynb +++ b/docs/02_deep_dives/L1_Batch.ipynb @@ -4,857 +4,1452 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Batch\n", + "# Batch: Tianshou's Core Data Structure\n", "\n", - "The `Batch` class serves as a fundamental data structure within Tianshou, designed to efficiently store and manipulate hierarchical named tensors. 
This tutorial provides comprehensive guidance on understanding the conceptual foundations and operational behavior of `Batch`, enabling users to fully leverage Tianshou's capabilities.\n", - "\n", - "The tutorial is organized into three sections: first, we establish the concept of hierarchical named tensors; second, we introduce basic `Batch` operations; and third, we explore advanced topics.\n", - "\n", - "## Hierarchical Named Tensors\n", - "\n", - "Hierarchical named tensors refer to a collection of tensors whose identifiers form a structured hierarchy. Consider a set of four tensors `[t1, t2, t3, t4]` with corresponding names `[name1, name2, name3, name4]`, where `name1` and `name2` reside within namespace `name0`. In this configuration, the fully qualified name of tensor `t1` becomes `name0.name1`, demonstrating how hierarchy manifests through tensor naming conventions.\n", - "\n", - "The structure of hierarchical named tensors can be represented using a tree data structure. This representation includes a virtual root node representing the entire object, with internal nodes serving as keys (names) and leaf nodes containing values (scalars or tensors).\n", - "\n", - "
\n", - "\n", - "
\n", - "\n", - "The necessity for hierarchical named tensors arises from the inherent heterogeneity of reinforcement learning problems. While the RL abstraction is elegantly simple:\n", - "\n", - "```python\n", - "state, reward, done = env.step(action)\n", - "```\n", - "\n", - "The `reward` and `done` components are typically scalar values. However, both `state` and `action` exhibit significant variation across different environments. For instance, a `state` may be represented as a simple vector, a tensor, or a combination of camera and sensory inputs. In the latter case, hierarchical named tensors provide a natural storage mechanism. This hierarchical structure extends beyond `state` and `action` to encompass all transition components (`state`, `action`, `reward`, `done`) within a unified hierarchical framework.\n", - "\n", - "While storing hierarchical named tensors is straightforward using nested dictionary structures:\n", - "\n", - "```python\n", - "{\n", - " 'done': done,\n", - " 'reward': reward,\n", - " 'state': {\n", - " 'camera': camera,\n", - " 'sensory': sensory\n", - " },\n", - " 'action': {\n", - " 'direct': direct,\n", - " 'point_3d': point_3d,\n", - " 'force': force,\n", - " }\n", - "}\n", - "```\n", - "\n", - "The challenge lies in **manipulating** these structures efficiently—for example, when adding new transition tuples to a replay buffer while handling their heterogeneity. The `Batch` class addresses this challenge by providing streamlined methods to create, store, and manipulate hierarchical named tensors.\n", - "\n", - "`Batch` can be conceptualized as a NumPy-enhanced Python dictionary. It shares similarities with PyTorch's `tensordict`, though with distinct type structure characteristics.\n", - "\n", - "
\n", - "\n", - "Data flow\n", - "
" + "The `Batch` class is Tianshou's fundamental data structure for efficiently storing and manipulating heterogeneous data in reinforcement learning. This tutorial provides comprehensive guidance on understanding its conceptual foundations, operational behavior, and best practices.\n" ] }, { "cell_type": "code", - "metadata": { - "tags": [ - "remove-output", - "hide-cell" - ] - }, + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "import pickle\n", + "from typing import cast\n", "\n", "import numpy as np\n", "import torch\n", + "from torch.distributions import Categorical, Normal\n", "\n", - "from tianshou.data import Batch" - ], - "outputs": [], - "execution_count": null + "from tianshou.data import Batch\n", + "from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol" + ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Basic Usage\n", + "## 1. Introduction: Why Batch?\n", + "\n", + "### The Challenge in Reinforcement Learning\n", + "\n", + "Reinforcement learning algorithms face a fundamental data management challenge:\n", "\n", - "This section covers fundamental `Batch` operations, including the contents of `Batch` objects, construction methods, and manipulation techniques.\n", + "1. **Diverse Data Requirements**: Different RL algorithms need different data fields:\n", + " - Basic algorithms: `state`, `action`, `reward`, `done`, `next_state`\n", + " - Actor-Critic: additionally `advantages`, `returns`, `values`\n", + " - Policy Gradient: additionally `log_probs`, `old_log_probs`\n", + " - Off-policy: additionally `priority_weights`\n", "\n", - "### Content Specification\n", + "2. 
**Heterogeneous Observation Spaces**: Environments return diverse observation types:\n", + " - Simple: vectors (`np.array([1.0, 2.0, 3.0])`)\n", + " - Complex: images (e.g., `np.zeros((84, 84, 3))`)\n", + " - Hybrid: dictionaries combining multiple modalities\n", + " ```python\n", + " obs = {\n", + " 'camera': np.zeros((64, 64, 3)),\n", + " 'velocity': np.array([1.2, 0.5]),\n", + " 'inventory': np.array([5, 2, 0])\n", + " }\n", + " ```\n", "\n", - "The content of `Batch` objects is defined by the following rules:\n", + "3. **Data Flow Across Components**: Data must flow seamlessly through:\n", + " - Collectors (gathering experience from environments)\n", + " - Replay Buffers (storing and sampling transitions)\n", + " - Policies and Algorithms (learning and inference)\n", "\n", - "1. A `Batch` object may be empty (`Batch()`) or contain at least one key-value pair. Empty `Batch` objects can be utilized for key reservation (detailed in the Advanced Topics section).\n", + "### Why Not Alternatives?\n", "\n", - "2. Keys must be strings, serving as identifiers for their corresponding values.\n", + "#### Plain Dictionaries\n", + "Dictionaries lack essential features:\n", + "```python\n", + "data = {'obs': np.array([1, 2]), 'reward': np.array([1.0, 2.0])}\n", + "```\n", "\n", - "3. Values may be scalars, tensors, or `Batch` objects. This recursive definition enables the construction of hierarchical batch structures.\n", + "They work in principle, but they have no shape/length semantics, no indexing, and no type safety.\n", "\n", - "4. Tensors constitute the primary value type. Tensors are n-dimensional arrays of uniform data type.
Two tensor types are supported: [PyTorch](https://pytorch.org/) tensor type `torch.Tensor` and [NumPy](https://numpy.org/) tensor type `np.ndarray`.\n", + "#### TensorDict\n", + "While `TensorDict` (used in TorchRL) is a powerful alternative:\n", + "- **Batch supports arbitrary objects**, not just tensors (useful for object-dtype arrays, custom types)\n", + "- **Batch has better type checking** via `BatchProtocol` (enables IDE autocompletion)\n", + "- **Batch preceded TensorDict** and provides a stable foundation for Tianshou\n", + "- **TensorDict isn't part of core PyTorch** (external dependency)\n", "\n", - "5. Scalars represent valid values, comprising single boolean values, numbers, or objects. These include Python scalars (`False`, `1`, `2.3`, `None`, `'hello'`) and NumPy scalars (`np.bool_(True)`, `np.int32(1)`, `np.float64(2.3)`). Scalars must not be conflated with `Batch`/dict/tensor types.\n", + "### What is Batch?\n", "\n", - "**Note:** `Batch` objects cannot directly store `dict` objects due to internal implementation using dictionaries for data storage. During construction, `dict` objects are automatically converted to `Batch` objects.\n", + "**Batch = Dictionary + Array hybrid with RL-specific features**\n", "\n", - "Supported tensor data types include boolean and numeric types (any integer or floating-point precision supported by NumPy or PyTorch). NumPy's support for object arrays enables storage of non-numeric data types within `Batch`. For data that are neither boolean nor numeric (e.g., strings, sets), storage within `np.ndarray` with `np.object` data type is supported, allowing `Batch` to accommodate arbitrary Python objects."
+ "Key capabilities:\n", + "- **Dict-like**: Key-value storage with attribute access (`batch.obs`, `batch.reward`)\n", + "- **Array-like**: Shape, indexing, slicing (`batch[0]`, `batch[:10]`, `batch.shape`)\n", + "- **Hierarchical**: Nested structures for complex data\n", + "- **Type-safe**: Protocol-based typing for IDE support\n", + "- **RL-aware**: Special handling for distributions, missing values, heterogeneous aggregation" ] }, { - "cell_type": "code", + "cell_type": "markdown", "metadata": {}, "source": [ - "data = Batch(a=4, b=[5, 5], c=\"2312312\", d=(\"a\", -2, -3))\n", - "print(data)\n", - "print(data.b)" - ], - "outputs": [], - "execution_count": null + "## 2. Core Concepts\n", + "\n", + "### Hierarchical Named Tensors\n", + "\n", + "Batch stores **hierarchical named tensors** - collections of tensors whose identifiers form a structured hierarchy. Consider tensors `[t1, t2, t3, t4]` with names `[name1, name2, name3, name4]`, where `name1` and `name2` are under namespace `name0`. 
The fully qualified name of `t1` is `name0.name1`.\n", + "\n", + "### Tree Structure Visualization\n", + "\n", + "The structure can be visualized as a tree with:\n", + "- **Root**: The Batch object itself\n", + "- **Internal nodes**: Keys (names)\n", + "- **Leaf nodes**: Values (scalars, arrays, tensors)\n", + "\n", + "```mermaid\n", + "graph TD\n", + " root[\"Batch (root)\"]\n", + " root --> obs[\"obs\"]\n", + " root --> act[\"act\"]\n", + " root --> rew[\"rew\"]\n", + " obs --> camera[\"camera\"]\n", + " obs --> sensory[\"sensory\"]\n", + " camera --> cam_data[\"np.array(3,3)\"]\n", + " sensory --> sens_data[\"np.array(5,)\"]\n", + " act --> act_data[\"np.array(2,)\"]\n", + " rew --> rew_data[\"3.66\"]\n", + " \n", + " style root fill:#e1f5ff\n", + " style obs fill:#fff4e1\n", + " style act fill:#fff4e1\n", + " style rew fill:#fff4e1\n", + " style camera fill:#ffe1f5\n", + " style sensory fill:#ffe1f5\n", + " style cam_data fill:#e8f5e1\n", + " style sens_data fill:#e8f5e1\n", + " style act_data fill:#e8f5e1\n", + " style rew_data fill:#e8f5e1\n", + "```" + ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "A `Batch` object stores all input data as key-value pairs and automatically converts values to NumPy arrays when applicable." 
+ "# Example: hierarchical structure\n", + "data = {\n", + " \"action\": np.array([1.0, 2.0, 3.0]),\n", + " \"reward\": 3.66,\n", + " \"obs\": {\n", + " \"camera\": np.zeros((3, 3)),\n", + " \"sensory\": np.ones(5),\n", + " },\n", + "}\n", + "\n", + "batch = Batch(data)\n", + "print(batch)\n", + "print(\"\\nAccessing nested values:\")\n", + "print(f\"batch.obs.camera.shape = {batch.obs.camera.shape}\")\n", + "print(f\"batch.obs.sensory = {batch.obs.sensory}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Construction Methods\n", + "### Data Flow in RL Pipeline\n", + "\n", + "Batch facilitates data flow throughout the RL pipeline:\n", + "\n", + "```mermaid\n", + "graph LR\n", + " A[Environment] -->|ObsBatchProtocol| B[Collector]\n", + " B -->|RolloutBatchProtocol| C[Replay Buffer]\n", + " C -->|RolloutBatchProtocol| D[Policy]\n", + " D -->|ActBatchProtocol| A\n", + " D -->|BatchWithAdvantages| E[Algorithm/Trainer]\n", + " E --> D\n", + " \n", + " style A fill:#e1f5ff\n", + " style B fill:#fff4e1\n", + " style C fill:#ffe1f5\n", + " style D fill:#e8f5e1\n", + " style E fill:#f5e1e1\n", + "```\n", "\n", - "Two primary construction methods are available for `Batch` objects: construction from a dictionary, or using keyword arguments. The following examples demonstrate these approaches." + "Each arrow represents a specific `BatchProtocol` that defines what fields are expected at that stage." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "#### Dictionary-Based Construction" + "## 3. 
Basic Operations\n", + "\n", + "### 3.1 Construction\n", + "\n", + "Batch objects can be constructed in several ways:" ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "# Direct dictionary passing (potentially nested) is supported\n", - "data = Batch({\"a\": 4, \"b\": [5, 5], \"c\": \"2312312\"})\n", - "# Lists are automatically converted to NumPy arrays\n", - "print(data.b)\n", - "data.b = np.array([3, 4, 5])\n", - "print(data)" - ], "outputs": [], - "execution_count": null + "source": [ + "# From keyword arguments\n", + "batch1 = Batch(a=4, b=[5, 5], c=\"hello\")\n", + "print(\"From kwargs:\", batch1)\n", + "\n", + "# From dictionary\n", + "batch2 = Batch({\"a\": 4, \"b\": [5, 5], \"c\": \"hello\"})\n", + "print(\"\\nFrom dict:\", batch2)\n", + "\n", + "# From list of dictionaries (automatically stacked)\n", + "batch3 = Batch([{\"a\": 1, \"b\": 2}, {\"a\": 3, \"b\": 4}])\n", + "print(\"\\nFrom list of dicts:\", batch3)\n", + "\n", + "# Nested batch\n", + "batch4 = Batch(obs=Batch(x=1, y=2), act=5)\n", + "print(\"\\nNested:\", batch4)" + ] }, { - "cell_type": "code", + "cell_type": "markdown", "metadata": {}, "source": [ - "# Lists of dictionary objects (potentially nested) are automatically stacked\n", - "data = Batch([{\"a\": 0.0, \"b\": \"hello\"}, {\"a\": 1.0, \"b\": \"world\"}])\n", - "print(data)" - ], + "### 3.2 Content Rules\n", + "\n", + "Understanding what Batch can store and how it converts data:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, "outputs": [], - "execution_count": null + "source": [ + "# Keys must be strings\n", + "batch = Batch()\n", + "batch.key1 = \"value\"\n", + "batch.key2 = np.array([1, 2, 3])\n", + "print(\"Keys:\", list(batch.keys()))\n", + "\n", + "# Automatic conversions\n", + "demo = Batch(\n", + " scalar_int=5, # → np.array(5)\n", + " scalar_float=3.14, # → np.array(3.14)\n", + " list_nums=[1, 2, 3], # → np.array([1, 2, 3])\n", + " list_mixed=[1, 
\"hello\", None], # → np.array([1, \"hello\", None], dtype=object)\n", + " dict_val={\"x\": 1, \"y\": 2}, # → Batch(x=1, y=2)\n", + ")\n", + "\n", + "print(\"\\nAutomatic conversions:\")\n", + "print(f\"scalar_int type: {type(demo.scalar_int)}, value: {demo.scalar_int}\")\n", + "print(f\"list_nums type: {type(demo.list_nums)}, dtype: {demo.list_nums.dtype}\")\n", + "print(f\"list_mixed dtype: {demo.list_mixed.dtype}\")\n", + "print(f\"dict_val type: {type(demo.dict_val)}\")" + ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "#### Keyword Argument Construction" + "**Important conversions:**\n", + "- Lists of numbers → NumPy arrays\n", + "- Lists with mixed types → Object-dtype arrays\n", + "- Dictionaries → Batch objects (recursively)\n", + "- Scalars → NumPy scalars" ] }, { - "cell_type": "code", + "cell_type": "markdown", "metadata": {}, "source": [ - "# Construction using keyword arguments\n", - "data = Batch(a=[4, 4], b=[5, 5], c=[None, None])\n", - "print(data)" - ], - "outputs": [], - "execution_count": null + "### 3.3 Access Patterns\n", + "\n", + "**Important: Understanding Iteration**" + ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "# Combining dictionary and keyword arguments\n", - "data = Batch(\n", - " {\"a\": [4, 4], \"b\": [5, 5]}, c=[None, None]\n", - ") # First argument is a dictionary; 'c' is a keyword argument\n", - "print(data)" - ], "outputs": [], - "execution_count": null + "source": [ + "batch = Batch(a=[1, 2, 3], b=[4, 5, 6])\n", + "\n", + "# Attribute vs dictionary access (equivalent)\n", + "print(\"Attribute access:\", batch.a)\n", + "print(\"Dict access:\", batch[\"a\"])\n", + "\n", + "# Getting keys\n", + "print(\"\\nKeys:\", list(batch.keys()))\n", + "\n", + "# Gotcha: Iteration is array like, not over keys\n", + "print(\"\\nIteration behavior:\")\n", + "print(\"for x in batch iterates over batch[0], batch[1], ..., NOT keys!\")\n", + "for i, item in enumerate(batch):\n", + " 
print(f\"batch[{i}] = {item}\")\n", + "\n", + "# This is different from dict behavior!\n", + "regular_dict = {\"a\": [1, 2, 3], \"b\": [4, 5, 6]}\n", + "print(\"\\nCompare with dict iteration (iterates over keys):\")\n", + "for key in regular_dict:\n", + " print(f\"key = {key}\")" + ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "arr = np.zeros((3, 4))\n", - "# By default, Batch maintains references to data; explicit copying is supported via the copy parameter\n", - "data = Batch(arr=arr, copy=True) # data.arr is now a copy of 'arr'" - ], "outputs": [], - "execution_count": null + "source": "" }, { "cell_type": "markdown", "metadata": {}, "source": [ - "#### Nested Batch Construction" + "### 3.4 Indexing & Slicing\n", + "\n", + "Batch supports NumPy-like indexing and slicing:" ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "# Nested dictionaries are converted to nested Batch objects\n", - "data = {\n", - " \"action\": np.array([1.0, 2.0, 3.0]),\n", - " \"reward\": 3.66,\n", - " \"obs\": {\n", - " \"rgb_obs\": np.zeros((3, 3)),\n", - " \"flatten_obs\": np.ones(5),\n", - " },\n", - "}\n", + "batch = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[[5.0, -5.0], [1.0, -2.0]])\n", "\n", - "batch = Batch(data, extra=\"extra_string\")\n", - "print(batch)\n", - "# batch.obs is also a Batch instance\n", - "print(type(batch.obs))\n", - "print(batch.obs.rgb_obs)" - ], - "outputs": [], - "execution_count": null + "print(\"Original batch shape:\", batch.shape)\n", + "print(\"Original batch length:\", len(batch))\n", + "\n", + "# Single index\n", + "print(\"\\nbatch[0]:\")\n", + "print(batch[0])\n", + "\n", + "# Slicing\n", + "print(\"\\nbatch[:1]:\")\n", + "print(batch[:1])\n", + "\n", + "# Advanced indexing\n", + "print(\"\\nbatch[[0, 1]]:\")\n", + "print(batch[[0, 1]])\n", + "\n", + "# Multi-dimensional indexing\n", + "print(\"\\nbatch[:, 0] (first column of all arrays):\")\n", + 
"print(batch[:, 0])" + ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "# Lists of dictionaries/Batches are automatically concatenated/stacked\n", - "# This feature facilitates data collection from parallelized environments\n", - "batch = Batch([data] * 3)\n", - "print(batch)\n", - "print(batch.obs.rgb_obs.shape)" - ], "outputs": [], - "execution_count": null + "source": [ + "# Broadcasting and in-place operations\n", + "batch[:, 1] += 10\n", + "print(\"After batch[:, 1] += 10:\")\n", + "print(batch)" + ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Data Manipulation\n", + "### 3.5 Stack, Concatenate, and Split\n", "\n", - "Internal data can be accessed using either `b.key` or `b[key]` notation, where `b.key` retrieves the subtree rooted at `key`. When the result is a non-empty subtree, key references can be chained (e.g., `b.key.key1.key2.key3`). Upon reaching a leaf node, the stored data (scalars or tensors) is returned." + "Combining and splitting batches:" ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "data = Batch(a=4, b=[5, 5])\n", - "print(data.b)\n", - "# Attribute access (obj.key) is equivalent to dictionary access (obj[\"key\"])\n", - "print(data[\"a\"])" - ], "outputs": [], - "execution_count": null + "source": "# Stack: adds a new dimension\nbatch1 = Batch(a=np.array([1, 2]), b=np.array([5, 6]))\nbatch2 = Batch(a=np.array([3, 4]), b=np.array([7, 8]))\n\nstacked = Batch.stack([batch1, batch2])\nprint(\"Stacked:\")\nprint(stacked)\nprint(f\"Shape: {stacked.shape}\")\n\n# Concatenate: extends along existing dimension\nconcatenated = Batch.cat([batch1, batch2])\nprint(\"\\nConcatenated:\")\nprint(concatenated)\nprint(f\"Shape: {concatenated.shape}\")" }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "# Dictionary-style iteration over items is supported\n", - "for key, value in data.items():\n", - " print(f\"{key}: 
{value}\")" - ], "outputs": [], - "execution_count": null + "source": [ + "# Split\n", + "batch = Batch(a=np.arange(10), b=np.arange(10, 20))\n", + "splits = list(batch.split(size=3, shuffle=False))\n", + "print(f\"Split into {len(splits)} batches:\")\n", + "for i, split in enumerate(splits):\n", + " print(f\"Split {i}: a={split.a}, length={len(split)}\")" + ] }, { - "cell_type": "code", + "cell_type": "markdown", "metadata": {}, "source": [ - "# Methods keys() and values() behave analogously to their dict counterparts\n", - "for key in data.keys():\n", - " print(f\"{key}\")" - ], - "outputs": [], - "execution_count": null + "### 3.6 Data Type Conversion\n", + "\n", + "Converting between NumPy and PyTorch:" + ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "# The update() method operates analogously to dict.update()\n", - "# Equivalent to: data.c = 1; data.d = 2; data.e = 3;\n", - "data.update(c=1, d=2, e=3)\n", - "print(data)" - ], "outputs": [], - "execution_count": null + "source": [ + "# Create batch with NumPy arrays\n", + "batch = Batch(a=np.zeros((3, 4)), b=np.ones(5))\n", + "print(\"Original (NumPy):\")\n", + "print(f\"batch.a type: {type(batch.a)}\")\n", + "\n", + "# Convert to PyTorch (in-place)\n", + "batch.to_torch_(dtype=torch.float32, device=\"cpu\")\n", + "print(\"\\nAfter to_torch_():\")\n", + "print(f\"batch.a type: {type(batch.a)}\")\n", + "print(f\"batch.a dtype: {batch.a.dtype}\")\n", + "\n", + "# Convert back to NumPy (in-place)\n", + "batch.to_numpy_()\n", + "print(\"\\nAfter to_numpy_():\")\n", + "print(f\"batch.a type: {type(batch.a)}\")\n", + "\n", + "# Non-in-place versions return a new batch\n", + "batch_torch = batch.to_torch()\n", + "print(\"\\nOriginal batch unchanged:\", type(batch.a))\n", + "print(\"New batch:\", type(batch_torch.a))" + ] }, { - "cell_type": "code", + "cell_type": "markdown", "metadata": {}, "source": [ - "# Adding and deleting key-value pairs\n", - "batch1 = Batch({\"a\": [4, 
4], \"b\": (5, 5)})\n", - "print(batch1)\n", + "## 4. Type Safety with Protocols\n", "\n", - "batch1.c = Batch(c1=np.arange(3), c2=False)\n", - "del batch1.a\n", - "print(batch1)\n", + "### Why Protocols?\n", "\n", - "# Accessing values by key\n", - "assert batch1[\"c\"] is batch1.c\n", - "print(\"c\" in batch1)" - ], + "Batch needs to be **flexible** (not fixed fields like dataclasses) but we still want **type safety** and **IDE autocompletion**. Protocols provide the best of both worlds:\n", + "\n", + "- **Runtime flexibility**: Add any fields dynamically\n", + "- **Static type checking**: Type checkers (mypy, pyright) verify correct usage\n", + "- **IDE support**: Autocompletion for expected fields\n", + "\n", + "### What is BatchProtocol?\n", + "\n", + "A `Protocol` defines an interface without implementation. Think of it as a contract: \"any object with these fields is valid.\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, "outputs": [], - "execution_count": null + "source": [ + "# Creating a typed batch using cast\n", + "# This enables IDE autocompletion and type checking\n", + "\n", + "# ActBatchProtocol: just needs 'act' field\n", + "act_batch = cast(ActBatchProtocol, Batch(act=np.array([1, 2, 3])))\n", + "print(\"ActBatchProtocol:\", act_batch.act)\n", + "\n", + "# ObsBatchProtocol: needs 'obs' and 'info' fields\n", + "obs_batch = cast(\n", + " ObsBatchProtocol,\n", + " Batch(obs=np.array([[1.0, 2.0], [3.0, 4.0]]), info=np.array([{}, {}], dtype=object)),\n", + ")\n", + "print(\"\\nObsBatchProtocol:\", obs_batch.obs)\n", + "\n", + "# RolloutBatchProtocol: needs obs, obs_next, act, rew, terminated, truncated\n", + "rollout_batch = cast(\n", + " RolloutBatchProtocol,\n", + " Batch(\n", + " obs=np.array([[1.0, 2.0], [3.0, 4.0]]),\n", + " obs_next=np.array([[2.0, 3.0], [4.0, 5.0]]),\n", + " act=np.array([0, 1]),\n", + " rew=np.array([1.0, 2.0]),\n", + " terminated=np.array([False, True]),\n", + " truncated=np.array([False, 
False]),\n", + " info=np.array([{}, {}], dtype=object),\n", + " ),\n", + ")\n", + "print(\"\\nRolloutBatchProtocol reward:\", rollout_batch.rew)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "**Important Note:** While `for x in data` iterates over keys when `data` is a `dict` object, for `Batch` objects this syntax iterates over `data[0], data[1], ..., data[-1]`." + "### Protocol Hierarchy\n", + "\n", + "Tianshou defines a hierarchy of protocols for different use cases:\n", + "\n", + "```mermaid\n", + "graph TD\n", + " BP[BatchProtocol<br/>Base protocol] --> OBP[ObsBatchProtocol<br/>obs, info]\n", + " BP --> ABP[ActBatchProtocol<br/>act]\n", + " ABP --> ASBP[ActStateBatchProtocol<br/>act, state]\n", + " OBP --> RBP[RolloutBatchProtocol<br/>+obs_next, act, rew,<br/>terminated, truncated]\n", + " RBP --> BWRP[BatchWithReturnsProtocol<br/>+returns]\n", + " BWRP --> BWAP[BatchWithAdvantagesProtocol<br/>+adv, v_s]\n", + " ASBP --> MOBP[ModelOutputBatchProtocol<br/>+logits]\n", + " MOBP --> DBP[DistBatchProtocol<br/>+dist]\n", + " DBP --> DLPBP[DistLogProbBatchProtocol<br/>+log_prob]\n", + " BWAP --> LOPBP[LogpOldProtocol<br/>+logp_old]\n", + " \n", + " style BP fill:#e1f5ff\n", + " style OBP fill:#fff4e1\n", + " style ABP fill:#fff4e1\n", + " style RBP fill:#ffe1f5\n", + " style BWRP fill:#e8f5e1\n", + " style BWAP fill:#e8f5e1\n", + " style DBP fill:#f5e1e1\n", + " style LOPBP fill:#e1e1f5\n", + "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Length, Shape, Indexing, and Slicing\n", + "### Using Protocols in Functions\n", "\n", - "`Batch` implements a subset of NumPy ndarray APIs, supporting advanced slicing operations (e.g., `batch[:, i]`) provided the slice is valid. NumPy's broadcasting mechanism is also supported." + "Protocols enable type-safe function signatures:" ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "# Initializing Batch with tensors\n", - "data = Batch(a=np.array([[0.0, 2.0], [1.0, 3.0]]), b=[[5.0, -5.0], [1.0, -2.0]])\n", - "# When all values share the same length/shape, the Batch adopts that length/shape\n", - "print(len(data))\n", - "print(data.shape)" - ], "outputs": [], - "execution_count": null + "source": [ + "def process_observations(batch: ObsBatchProtocol) -> np.ndarray:\n", + " \"\"\"Function that expects observations.\n", + "\n", + " IDE will autocomplete batch.obs and batch.info!\n", + " Type checker will verify these fields exist.\n", + " \"\"\"\n", + " # IDE knows batch.obs exists\n", + " return batch.obs if isinstance(batch.obs, np.ndarray) else np.array(batch.obs)\n", + "\n", + "\n", + "def compute_advantage(batch: RolloutBatchProtocol) -> np.ndarray:\n", + " \"\"\"Function that expects rollout data.\n", + "\n", + " IDE will autocomplete batch.rew, batch.obs_next, etc.\n", + " \"\"\"\n", + " # Simplified advantage computation\n", + " return batch.rew # IDE knows this exists\n", + "\n", + "\n", + "# Example usage\n", + "obs_data = Batch(obs=np.array([1, 2, 3]), info=np.array([{}], dtype=object))\n", + "result = process_observations(obs_data)\n", + "print(\"Processed obs:\", result)" ]
}, { - "cell_type": "code", + "cell_type": "markdown", "metadata": {}, "source": [ - "# Accessing the first element of all stored tensors while preserving Batch structure\n", - "print(data[0])" - ], - "outputs": [], - "execution_count": null + "**Key Protocol Types:**\n", + "\n", + "- `ActBatchProtocol`: Just actions (for simple policies)\n", + "- `ObsBatchProtocol`: Observations and info\n", + "- `RolloutBatchProtocol`: Complete transitions (obs, act, rew, terminated, truncated, obs_next)\n", + "- `BatchWithReturnsProtocol`: Rollouts + computed returns\n", + "- `BatchWithAdvantagesProtocol`: Returns + advantages and values\n", + "- `DistBatchProtocol`: Contains distribution objects\n", + "- `LogpOldProtocol`: For importance sampling (PPO, etc.)\n", + "\n", + "See `tianshou/data/types.py` for the complete list!" ] }, { - "cell_type": "code", + "cell_type": "markdown", "metadata": {}, "source": [ - "# Iteration over data[0], data[1], ..., data[-1]\n", - "for sample in data:\n", - " print(sample.a)" - ], - "outputs": [], - "execution_count": null + "## 5. Distribution Slicing\n", + "\n", + "### Why Special Handling?\n", + "\n", + "PyTorch `Distribution` objects need special slicing because they're not simple arrays.
When you slice `batch[0:2]`, Tianshou needs to slice the underlying distribution parameters correctly.\n", + "\n", + "### Supported Distributions\n", + "\n", + "Tianshou supports slicing for:\n", + "- `Categorical`: Discrete distributions\n", + "- `Normal`: Continuous Gaussian distributions\n", + "- `Independent`: Wraps other distributions" + ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "# Advanced slicing with arithmetic operations and broadcasting\n", - "data[:, 1] += 1\n", - "print(data)" - ], "outputs": [], - "execution_count": null + "source": [ + "# Categorical distribution\n", + "probs = torch.tensor([[0.3, 0.7], [0.4, 0.6], [0.5, 0.5]])\n", + "dist = Categorical(probs=probs)\n", + "batch = Batch(dist=dist, values=np.array([1, 2, 3]))\n", + "\n", + "print(\"Original batch length:\", len(batch))\n", + "print(\"Original dist probs shape:\", batch.dist.probs.shape)\n", + "\n", + "# Slicing automatically handles the distribution\n", + "sliced = batch[0:2]\n", + "print(\"\\nSliced batch length:\", len(sliced))\n", + "print(\"Sliced dist probs shape:\", sliced.dist.probs.shape)\n", + "print(\"Sliced values:\", sliced.values)" + ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "# Direct application of NumPy functions to Batch objects\n", - "print(np.mean(data))" - ], "outputs": [], - "execution_count": null + "source": [ + "# Normal distribution\n", + "loc = torch.tensor([0.0, 1.0, 2.0])\n", + "scale = torch.tensor([1.0, 1.0, 1.0])\n", + "normal_dist = Normal(loc=loc, scale=scale)\n", + "batch_normal = Batch(dist=normal_dist, actions=np.array([0.5, 1.5, 2.5]))\n", + "\n", + "print(\"Normal distribution batch:\")\n", + "print(f\"Original mean: {batch_normal.dist.mean}\")\n", + "\n", + "# Index a single element\n", + "single = batch_normal[1]\n", + "print(f\"\\nSingle element mean: {single.dist.mean}\")\n", + "print(f\"Single element action: {single.actions}\")" + ] }, { - "cell_type": 
"code", + "cell_type": "markdown", "metadata": {}, "source": [ - "# Conversion to list is supported\n", - "list(data)" - ], + "### Converting to At Least 2D\n", + "\n", + "Sometimes you need to ensure distributions have a batch dimension:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, "outputs": [], - "execution_count": null + "source": [ + "from tianshou.data.batch import dist_to_atleast_2d\n", + "\n", + "# Scalar distribution (no batch dimension)\n", + "scalar_dist = Categorical(probs=torch.tensor([0.3, 0.7]))\n", + "print(\"Scalar dist batch_shape:\", scalar_dist.batch_shape)\n", + "\n", + "# Convert to have batch dimension\n", + "batched_dist = dist_to_atleast_2d(scalar_dist)\n", + "print(\"Batched dist batch_shape:\", batched_dist.batch_shape)\n", + "\n", + "# For entire batch\n", + "scalar_batch = Batch(a=1, b=2, dist=Categorical(probs=torch.ones(3)))\n", + "print(\"\\nBefore to_at_least_2d:\", scalar_batch.dist.batch_shape)\n", + "\n", + "batch_2d = scalar_batch.to_at_least_2d()\n", + "print(\"After to_at_least_2d:\", batch_2d.dist.batch_shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Use Cases\n", + "\n", + "Distribution slicing is used in:\n", + "- **Policy sampling**: When policies output distributions, slicing batches preserves distribution structure\n", + "- **Replay buffer sampling**: Distributions are stored and retrieved correctly\n", + "- **Advantage computation**: Computing log probabilities on subsets of data" + ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "#### Environment Stepping Example" + "## 6. Advanced Topics\n", + "\n", + "### 6.1 Key Reservation\n", + "\n", + "Sometimes you know what keys you'll need but don't have values yet. 
Reserve keys using empty `Batch()` objects:\n", + "\n", + "```mermaid\n", + "graph TD\n", + " root[\"Batch\"]\n", + " root --> a[\"key1: np.array([1,2,3])\"]\n", + " root --> b[\"key2: Batch() (reserved)\"]\n", + " root --> c[\"key3\"]\n", + " c --> c1[\"subkey1: Batch() (reserved)\"]\n", + " c --> c2[\"subkey2: np.array([4,5])\"]\n", + " \n", + " style root fill:#e1f5ff\n", + " style a fill:#e8f5e1\n", + " style b fill:#ffcccc\n", + " style c fill:#fff4e1\n", + " style c1 fill:#ffcccc\n", + " style c2 fill:#e8f5e1\n", + "```" ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "# Example: Data collected from four parallel environments\n", - "step_outputs = [\n", - " {\n", - " \"act\": np.random.randint(10),\n", - " \"rew\": 0.0,\n", - " \"obs\": np.ones((3, 3)),\n", - " \"info\": {\"done\": np.random.choice(2), \"failed\": False},\n", - " \"terminated\": False,\n", - " \"truncated\": False,\n", - " }\n", - " for _ in range(4)\n", - "]\n", - "batch = Batch(step_outputs)\n", + "# Reserving keys\n", + "batch = Batch(\n", + " known_field=np.array([1, 2]),\n", + " future_field=Batch(), # Reserved for later\n", + ")\n", + "print(\"Batch with reserved key:\")\n", "print(batch)\n", - "print(batch.shape)" - ], - "outputs": [], - "execution_count": null + "\n", + "# Later, assign actual data\n", + "batch.future_field = np.array([3, 4])\n", + "print(\"\\nAfter assignment:\")\n", + "print(batch)\n", + "\n", + "# Nested reservation\n", + "batch2 = Batch(\n", + " obs=Batch(\n", + " camera=Batch(), # Reserved\n", + " lidar=np.zeros(10),\n", + " )\n", + ")\n", + "print(\"\\nNested reservation:\")\n", + "print(batch2)" + ] }, { - "cell_type": "code", + "cell_type": "markdown", "metadata": {}, "source": [ - "# Advanced indexing for selecting data from specific environments\n", - "print(batch[0])\n", - "print(batch[[0, 3]])" - ], - "outputs": [], - "execution_count": null + "### 6.2 Length and Shape Semantics\n", + "\n", + 
"Understanding when `len()` works and what `shape` means:" + ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "# Slicing operations are supported\n", - "print(batch[-2:])" - ], "outputs": [], - "execution_count": null + "source": [ + "# Normal case: all tensors same length\n", + "batch1 = Batch(a=[1, 2, 3], b=np.array([4, 5, 6]))\n", + "print(\"Normal batch:\")\n", + "print(f\"len(batch1) = {len(batch1)}\")\n", + "print(f\"batch1.shape = {batch1.shape}\")\n", + "\n", + "# Scalars have no length\n", + "batch2 = Batch(a=5, b=10)\n", + "print(\"\\nScalar batch:\")\n", + "print(f\"batch2.shape = {batch2.shape}\")\n", + "try:\n", + " print(f\"len(batch2) = {len(batch2)}\")\n", + "except TypeError as e:\n", + " print(f\"len(batch2) raises TypeError: {e}\")\n", + "\n", + "# Mixed lengths: returns minimum\n", + "batch3 = Batch(a=[1, 2], b=[3, 4, 5])\n", + "print(\"\\nMixed length batch:\")\n", + "print(f\"len(batch3) = {len(batch3)} (minimum of 2 and 3)\")\n", + "\n", + "# Reserved keys are ignored\n", + "batch4 = Batch(a=[1, 2, 3], reserved=Batch())\n", + "print(\"\\nBatch with reserved key:\")\n", + "print(f\"len(batch4) = {len(batch4)} (reserved key ignored)\")" + ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Stack, Concatenate, and Split Operations\n", + "### 6.3 Empty Batches\n", "\n", - "Tianshou provides intuitive methods for stacking and concatenating multiple `Batch` instances, as well as splitting instances into multiple batches. Currently, we focus on aggregation (stack/concatenate) of homogeneous (structurally identical) batches." + "Understanding different meanings of \"empty\":" ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "data_1 = Batch(a=np.array([0.0, 2.0]), b=5)\n", - "data_2 = Batch(a=np.array([1.0, 3.0]), b=-5)\n", - "data = Batch.stack((data_1, data_2))\n", - "print(data)" - ], "outputs": [], - "execution_count": null + "source": [ + "# 1. 
No keys at all\n", + "empty1 = Batch()\n", + "print(\"No keys:\")\n", + "print(f\"len(empty1.get_keys()) = {len(list(empty1.get_keys()))}\")\n", + "print(f\"len(empty1) = {len(empty1)}\")\n", + "\n", + "# 2. Has keys but they're all reserved\n", + "empty2 = Batch(a=Batch(), b=Batch())\n", + "print(\"\\nReserved keys only:\")\n", + "print(f\"len(empty2.get_keys()) = {len(list(empty2.get_keys()))}\")\n", + "print(f\"len(empty2) = {len(empty2)}\")\n", + "\n", + "# 3. Has data but length is 0\n", + "empty3 = Batch(a=np.array([]), b=np.array([]))\n", + "print(\"\\nZero-length arrays:\")\n", + "print(f\"len(empty3.get_keys()) = {len(list(empty3.get_keys()))}\")\n", + "print(f\"len(empty3) = {len(empty3)}\")" + ] }, { - "cell_type": "code", + "cell_type": "markdown", "metadata": {}, "source": [ - "# Split operation with optional shuffling\n", - "data_split = list(data.split(1, shuffle=False))\n", - "print(data_split)" - ], - "outputs": [], - "execution_count": null + "**Checking emptiness:**\n", + "- `len(batch.get_keys()) == 0`: No keys (completely empty)\n", + "- `len(batch) == 0`: No data elements (may have reserved keys)\n", + "\n", + "**The `.empty()` and `.empty_()` methods:**\n", + "These reset values to zeros/None, different from checking emptiness:" + ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "data_cat = Batch.cat(data_split)\n", - "print(data_cat)" - ], "outputs": [], - "execution_count": null + "source": [ + "batch = Batch(a=[1, 2, 3], b=[\"x\", \"y\", \"z\"])\n", + "print(\"Original:\", batch)\n", + "\n", + "# Empty specific index\n", + "batch[0] = Batch.empty(batch[0])\n", + "print(\"\\nAfter emptying index 0:\")\n", + "print(batch)" + ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "#### Additional Concatenation and Stacking Examples" + "### 6.4 Heterogeneous Aggregation\n", + "\n", + "Stacking/concatenating batches with different keys:\n", + "\n", + "```mermaid\n", + "graph LR\n", + " 
A[\"Batch(a=[1,2], c=5)\"] --> C[\"Batch.stack\"]\n", + " B[\"Batch(b=[3,4], c=6)\"] --> C\n", + " C --> D[\"Batch(a=[[1,2],[0,0]], b=[[0,0],[3,4]], c=[5,6])\"]\n", + " \n", + " style A fill:#e1f5ff\n", + " style B fill:#fff4e1\n", + " style C fill:#ffe1f5\n", + " style D fill:#e8f5e1\n", + "```" ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "# Concatenating batches with compatible keys\n", - "b1 = Batch(a=[{\"b\": np.float64(1.0), \"d\": Batch(e=np.array(3.0))}])\n", - "b2 = Batch(a=[{\"b\": np.float64(4.0), \"d\": {\"e\": np.array(6.0)}}])\n", - "b12_cat_out = Batch.cat([b1, b2])\n", - "print(b1)\n", - "print(b2)\n", - "print(b12_cat_out)" - ], "outputs": [], - "execution_count": null + "source": [ + "# Stack with different keys (missing keys padded with zeros)\n", + "batch_a = Batch(a=np.ones((2, 3)), shared=np.array([1, 2]))\n", + "batch_b = Batch(b=np.zeros((2, 4)), shared=np.array([3, 4]))\n", + "\n", + "stacked = Batch.stack([batch_a, batch_b])\n", + "print(\"Stacked batch:\")\n", + "print(f\"a.shape = {stacked.a.shape} (padded with zeros for batch_b)\")\n", + "print(f\"b.shape = {stacked.b.shape} (padded with zeros for batch_a)\")\n", + "print(f\"shared.shape = {stacked.shared.shape} (in both batches)\")\n", + "print(stacked)" + ] }, { - "cell_type": "code", + "cell_type": "markdown", "metadata": {}, "source": [ - "# Stacking batches with compatible keys along specified axis\n", - "b3 = Batch(a=np.zeros((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[1], [2]]))\n", - "b4 = Batch(a=np.ones((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[0], [3]]))\n", - "b34_stack = Batch.stack((b3, b4), axis=1)\n", - "print(b3)\n", - "print(b4)\n", - "print(b34_stack)" - ], - "outputs": [], - "execution_count": null + "### 6.5 Missing Values\n", + "\n", + "Handling `None` and `NaN` values:" ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "# Splitting batch into unit-sized batches with optional shuffling\n", - "print(type(b34_stack.split(1)))\n", - "print(list(b34_stack.split(1, shuffle=True)))" - ], "outputs": [], - "execution_count": null + "source": [ +
"# Batch with missing values\n", + "batch = Batch(a=[1, 2, None, 4], b=[5.0, np.nan, 7.0, 8.0], c=[[1, 2], [3, 4], [5, 6], [7, 8]])\n", + "\n", + "# Check for nulls\n", + "print(\"Has null?\", batch.hasnull())\n", + "\n", + "# Get null mask\n", + "null_mask = batch.isnull()\n", + "print(\"\\nNull mask:\")\n", + "print(f\"a: {null_mask.a}\")\n", + "print(f\"b: {null_mask.b}\")\n", + "\n", + "# Drop rows with any null\n", + "clean_batch = batch.dropnull()\n", + "print(\"\\nAfter dropnull() (keeps rows 0 and 3):\")\n", + "print(f\"Length: {len(clean_batch)}\")\n", + "print(f\"a: {clean_batch.a}\")\n", + "print(f\"b: {clean_batch.b}\")" + ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Data Type Conversion\n", + "### 6.6 Value Transformations\n", "\n", - "While `Batch` supports both NumPy arrays and PyTorch Tensors with identical usage patterns, seamless conversion between these types is provided." + "Applying functions to all values recursively:" ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "batch1 = Batch(a=np.arange(2), b=torch.zeros((2, 2)))\n", - "batch2 = Batch(a=np.arange(2), b=torch.ones((2, 2)))\n", - "batch_cat = Batch.cat([batch1, batch2, batch1])\n", - "print(batch_cat)" - ], "outputs": [], - "execution_count": null + "source": [ + "batch = Batch(a=np.array([1, 2, 3]), nested=Batch(b=np.array([4.0, 5.0]), c=np.array([6, 7, 8])))\n", + "\n", + "# Apply transformation (returns new batch)\n", + "doubled = batch.apply_values_transform(lambda x: x * 2)\n", + "print(\"Original batch a:\", batch.a)\n", + "print(\"Doubled batch a:\", doubled.a)\n", + "print(\"Doubled nested.b:\", doubled.nested.b)\n", + "\n", + "# In-place transformation\n", + "batch.apply_values_transform(lambda x: x + 10, inplace=True)\n", + "print(\"\\nAfter in-place +10:\")\n", + "print(\"a:\", batch.a)\n", + "print(\"nested.b:\", batch.nested.b)" + ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Data type conversion is 
straightforward when uniform data types are desired." + "## 7. Surprising Behaviors & Gotchas\n", + "\n", + "### Iteration Does NOT Iterate Over Keys!\n", + "\n", + "**This is the most common source of confusion:**" ] }, { "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "batch = Batch(a=[1, 2, 3], b=[4, 5, 6])\n", + "\n", + "print(\"WRONG: This doesn't iterate over keys!\")\n", + "for item in batch:\n", + " print(f\"item = {item}\") # Prints batch[0], batch[1], batch[2]\n", + "\n", + "print(\"\\nCORRECT: To iterate over keys:\")\n", + "for key in batch.keys():\n", + " print(f\"key = {key}\")\n", + "\n", + "print(\"\\nCORRECT: To iterate over key-value pairs:\")\n", + "for key, value in batch.items():\n", + " print(f\"{key} = {value}\")" + ] + }, + { + "cell_type": "markdown", "metadata": {}, "source": [ - "data = Batch(a=np.zeros((3, 4)))\n", - "data.to_torch_(dtype=torch.float32, device=\"cpu\")\n", - "print(data.a)\n", - "# Conversion to NumPy is also supported via to_numpy_()\n", - "data.to_numpy_()\n", - "print(data.a)" - ], + "### Automatic Type Conversions\n", + "\n", + "Be aware of these automatic conversions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, "outputs": [], - "execution_count": null + "source": [ + "# Lists become arrays\n", + "batch = Batch(a=[1, 2, 3])\n", + "print(\"List → array:\", type(batch.a), batch.a.dtype)\n", + "\n", + "# Dicts become Batch\n", + "batch = Batch(a={\"x\": 1, \"y\": 2})\n", + "print(\"Dict → Batch:\", type(batch.a))\n", + "\n", + "# Scalars become numpy scalars\n", + "batch = Batch(a=5)\n", + "print(\"Scalar → np.ndarray:\", type(batch.a), batch.a)\n", + "\n", + "# Mixed types → object dtype\n", + "batch = Batch(a=[1, \"hello\", None])\n", + "print(\"Mixed → object:\", batch.a.dtype, batch.a)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Length Edge Cases" + ] }, { "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 1. Scalars have no length\n", + "batch_scalar = Batch(a=5, b=10)\n", + "try:\n", + " len(batch_scalar)\n", + "except TypeError as e:\n", + " print(f\"Scalar batch: {e}\")\n", + "\n", + "# 2. Empty nested batches ignored in len()\n", + "batch_empty_nested = Batch(a=[1, 2, 3], b=Batch())\n", + "print(f\"\\nWith empty nested: len = {len(batch_empty_nested)} (ignores b)\")\n", + "\n", + "# 3. Different lengths: returns minimum\n", + "batch_different = Batch(a=[1, 2], b=[1, 2, 3, 4])\n", + "print(f\"Different lengths: len = {len(batch_different)} (minimum)\")\n", + "\n", + "# 4. None values don't affect length\n", + "batch_none = Batch(a=[1, 2, 3], b=None)\n", + "print(f\"With None: len = {len(batch_none)} (None ignored)\")" + ] + }, + { + "cell_type": "markdown", "metadata": {}, "source": [ - "batch_cat.to_numpy_()\n", - "print(batch_cat)\n", - "batch_cat.to_torch_()\n", - "print(batch_cat)" - ], + "### String Keys Only" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, "outputs": [], - "execution_count": null + "source": [ + "# Integer keys not allowed\n", + "try:\n", + " batch = Batch({1: \"value\", 2: \"other\"})\n", + "except AssertionError as e:\n", + " print(\"Integer keys not allowed:\", e)\n", + "\n", + "# String keys work\n", + "batch = Batch({\"key1\": \"value\", \"key2\": \"other\"})\n", + "print(\"\\nString keys work:\", list(batch.keys()))" + ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Serialization\n", + "### Cat vs Stack Behavior\n", "\n", - "`Batch` objects are serializable and compatible with Python's `pickle` module, enabling persistent storage and restoration. This capability is particularly important for distributed environment sampling." 
+ "Recent changes have made concatenation stricter about structure:" ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4]))\n", - "batch_pk = pickle.loads(pickle.dumps(batch))\n", - "print(batch_pk)" - ], "outputs": [], - "execution_count": null + "source": [ + "# Stack pads missing keys with zeros\n", + "b1 = Batch(a=[1, 2])\n", + "b2 = Batch(b=[3, 4])\n", + "stacked = Batch.stack([b1, b2])\n", + "print(\"Stack (different keys):\")\n", + "print(f\" a: {stacked.a} (b2.a padded with 0)\")\n", + "print(f\" b: {stacked.b} (b1.b padded with 0)\")\n", + "\n", + "# Cat requires same structure now\n", + "b3 = Batch(a=[1, 2], b=[3, 4])\n", + "b4 = Batch(a=[5, 6], b=[7, 8])\n", + "concatenated = Batch.cat([b3, b4])\n", + "print(\"\\nCat (same keys):\")\n", + "print(f\" a: {concatenated.a}\")\n", + "print(f\" b: {concatenated.b}\")\n", + "\n", + "# Cat with different structures raises error\n", + "try:\n", + " Batch.cat([b1, b2]) # Different keys!\n", + "except ValueError:\n", + " print(\"\\nCat with different keys: ValueError raised\")" + ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Advanced Topics\n", + "## 8. Best Practices\n", "\n", - "This section addresses advanced `Batch` concepts, including key reservation mechanisms, detailed length and shape semantics, and aggregation of heterogeneous batches.\n", + "### When to Use Batch\n", "\n", - "### Key Reservation\n", + "**Good use cases:**\n", + "- Collecting environment data (transitions, episodes)\n", + "- Storing replay buffer data\n", + "- Passing data between components (collector → buffer → policy)\n", + "- Handling heterogeneous observations (dict spaces)\n", "\n", - "In many scenarios, the key structure is known in advance while value shapes remain undetermined until runtime (e.g., after environment execution). 
Tianshou supports key reservation through placeholder values.\n", + "**Consider alternatives:**\n", + "- Simple scalar tracking (use regular variables)\n", + "- Pure tensor operations (use PyTorch tensors directly)\n", + "- Deeply nested arbitrary structures (use dataclasses)\n", "\n", - "
\n", - "
\n", - "Structure of a batch with reserved keys\n", - "
\n", + "### Structuring Your Batches\n", "\n", - "Key reservation is implemented using empty `Batch()` objects as placeholder values." + "**Use protocols for type safety:**" ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "a = Batch(b=Batch()) # 'b' is a reserved key\n", - "print(a)\n", - "\n", - "# Hierarchical key reservation\n", - "a = Batch(b=Batch(c=Batch()), d=Batch()) # 'c' and 'd' are reserved keys\n", - "print(a)\n", - "\n", - "a = Batch(key1=np.array([1, 2]), key2=np.array([3, 4]), key3=Batch(key4=Batch(), key5=Batch()))\n", - "print(a)" - ], "outputs": [], - "execution_count": null + "source": [ + "# Good: Use protocols for clear interfaces\n", + "def train_step(batch: RolloutBatchProtocol) -> float:\n", + " \"\"\"IDE knows what fields exist.\"\"\"\n", + " loss = ((batch.rew - 0.5) ** 2).mean() # Type-safe\n", + " return float(loss)\n", + "\n", + "\n", + "# Create properly typed batch\n", + "train_batch = cast(\n", + " RolloutBatchProtocol,\n", + " Batch(\n", + " obs=np.random.randn(10, 4),\n", + " obs_next=np.random.randn(10, 4),\n", + " act=np.random.randint(0, 2, 10),\n", + " rew=np.random.randn(10),\n", + " terminated=np.zeros(10, dtype=bool),\n", + " truncated=np.zeros(10, dtype=bool),\n", + " info=np.array([{}] * 10, dtype=object),\n", + " ),\n", + ")\n", + "\n", + "loss = train_step(train_batch)\n", + "print(f\"Loss: {loss:.4f}\")" + ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The structure of `Batch` objects with reserved keys can be visualized using tree notation, where reserved keys represent internal nodes lacking attached leaf nodes.\n", - "\n", - "**Important:** Reserved keys indicate that values will eventually be assigned. These values may be scalars, tensors, or `Batch` objects. 
Understanding this concept is essential for working with heterogeneous batches.\n", + "**Consistent key naming:**\n", + "- Follow Tianshou conventions: `obs`, `act`, `rew`, `terminated`, `truncated`\n", + "- Use descriptive names: `camera_obs` not `co`\n", + "- Avoid name collisions with Batch methods: don't use `keys`, `items`, `get`, etc.\n", "\n", - "The introduction of reserved keys necessitates verification methods." + "**When to nest vs flatten:**" ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "# Examples of checking whether a Batch is empty\n", - "print(len(Batch().get_keys()) == 0)\n", - "print(len(Batch(a=Batch(), b=Batch(c=Batch())).get_keys()) == 0)\n", - "print(len(Batch(a=Batch(), b=Batch(c=Batch()))) == 0)\n", - "print(len(Batch(d=1).get_keys()) == 0)\n", - "print(len(Batch(a=np.float64(1.0)).get_keys()) == 0)" - ], "outputs": [], - "execution_count": null + "source": [ + "# Good: Nest related data\n", + "batch_nested = Batch(\n", + " obs=Batch(\n", + " camera=np.zeros((32, 64, 64, 3)), lidar=np.zeros((32, 100)), position=np.zeros((32, 3))\n", + " ),\n", + " act=np.zeros(32),\n", + ")\n", + "print(\"Nested structure for related obs:\")\n", + "print(f\" Access: batch.obs.camera.shape = {batch_nested.obs.camera.shape}\")\n", + "\n", + "# Less good: Flat structure loses semantic grouping\n", + "batch_flat = Batch(\n", + " camera=np.zeros((32, 64, 64, 3)),\n", + " lidar=np.zeros((32, 100)),\n", + " position=np.zeros((32, 3)),\n", + " act=np.zeros(32),\n", + ")\n", + "print(\"\\nFlat structure (works but less clear):\")\n", + "print(f\" Access: batch.camera.shape = {batch_flat.camera.shape}\")" + ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "To verify emptiness, use `len(Batch.get_keys()) == 0` for direct emptiness (a simple `Batch()`) or `len(Batch) == 0` for recursive emptiness (a `Batch` without scalar or tensor leaf nodes).\n", + "### Performance Tips\n", "\n", - "**Note:** The `Batch.empty` 
attribute differs from emptiness checking. `Batch.empty` and its in-place variant `Batch.empty_` are used to reset values to zeros or `None`. Consult the API documentation for additional details." + "**Use in-place operations:**" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "### Length and Shape Semantics\n", + "import time\n", + "\n", + "batch = Batch(a=np.random.randn(1000, 100))\n", "\n", - "The primary use case for `Batch` is storing batched data collections. The term \"Batch\" originates from deep learning terminology, denoting mini-batches sampled from datasets. Typically, a \"Batch\" represents a collection of tensors sharing a common first dimension, with batch size corresponding to the `Batch` object's length.\n", + "# Creates copy\n", + "start = time.time()\n", + "for _ in range(100):\n", + " _ = batch.to_torch()\n", + "time_copy = time.time() - start\n", "\n", - "When all leaf nodes in a `Batch` object are tensors but possess different lengths, storage within `Batch` remains possible. However, the semantics of `len(obj)` become ambiguous. Currently, Tianshou returns the minimum tensor length, though we strongly recommend avoiding `len(obj)` operations on `Batch` objects containing tensors of varying lengths." 
+ "# In-place (faster)\n", + "start = time.time()\n", + "for _ in range(100):\n", + " batch.to_torch_()\n", + " batch.to_numpy_()\n", + "time_inplace = time.time() - start\n", + "\n", + "print(f\"Copy: {time_copy:.4f}s\")\n", + "print(f\"In-place: {time_inplace:.4f}s\")\n", + "print(f\"Speedup: {time_copy / time_inplace:.1f}x\")" ] }, { - "cell_type": "code", + "cell_type": "markdown", "metadata": {}, "source": [ - "# Length and shape examples for Batch objects\n", - "data = Batch(a=[5.0, 4.0], b=np.zeros((2, 3, 4)))\n", - "print(data.shape)\n", - "print(len(data))\n", - "print(data[0].shape)\n", - "try:\n", - " len(data[0])\n", - "except TypeError as e:\n", - " print(f\"TypeError: {e}\")" - ], - "outputs": [], - "execution_count": null + "**Be mindful of copies:**" + ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "**Important:** Following scientific computing conventions, scalars possess no length. If any scalar leaf node exists in a `Batch` object, invoking `len(obj)` raises an exception.\n", - "\n", - "Similarly, reserved keys have undetermined values and therefore no defined length (or equivalently, **arbitrary** length). When tensors and reserved keys coexist, the latter are ignored in `len(obj)` calculations, returning the minimum tensor length. When no tensors exist in the `Batch` object, Tianshou raises an exception.\n", - "\n", - "The `obj.shape` attribute exhibits similar behavior to `len(obj)`:\n", - "\n", - "1. When all leaf nodes are tensors with identical shapes, that shape is returned.\n", + "arr = np.array([1, 2, 3])\n", "\n", - "2. When all leaf nodes are tensors with differing shapes, the minimum length per dimension is returned.\n", + "# Default: creates reference (be careful!)\n", + "batch1 = Batch(a=arr)\n", + "batch1.a[0] = 999\n", + "print(f\"Original array modified: {arr}\") # Changed!\n", "\n", - "3. 
When any scalar value exists, `obj.shape` returns `[]`.\n", - "\n", - "4. Reserved keys have undetermined shape, treated as `[]`." + "# Explicit copy when needed\n", + "arr = np.array([1, 2, 3])\n", + "batch2 = Batch(a=arr, copy=True)\n", + "batch2.a[0] = 999\n", + "print(f\"Original array preserved: {arr}\") # Unchanged" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Aggregation of Heterogeneous Batches\n", - "\n", - "This section examines aggregation operations (stack/concatenate) on heterogeneous `Batch` objects, focusing on structural heterogeneity. Aggregation operations ultimately invoke NumPy/PyTorch operators (`np.stack`, `np.concatenate`, `torch.stack`, `torch.cat`). Value heterogeneity that violates these operators' requirements (e.g., stacking `np.ndarray` with `torch.Tensor`, or stacking tensors with incompatible shapes) results in exceptions.\n", + "**Avoid unnecessary conversions:**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Inefficient: multiple conversions\n", + "batch = Batch(a=np.random.randn(100, 10))\n", + "batch.to_torch_()\n", + "batch.to_numpy_() # Unnecessary if we just need NumPy\n", "\n", - "
\n", - "
\n", - "
\n", + "# Efficient: convert once, use many times\n", + "batch = Batch(a=np.random.randn(100, 10))\n", + "batch.to_torch_() # Convert once\n", + "# ... do torch operations ...\n", + "# Keep as torch if that's what you need!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Common Patterns\n", "\n", - "The behavior is intuitive: keys not shared across all batches are padded with zeros (or `None` for `np.object` data type) in batches lacking these keys." + "**Pattern 1: Building batches incrementally**" ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "# Stack example: batch a lacks key 'b', batch b lacks key 'a'\n", - "a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5])))\n", - "b = Batch(b=np.zeros([4, 6]), common=Batch(c=np.zeros([4, 5])))\n", - "c = Batch.stack([a, b])\n", - "print(c.a.shape)\n", - "print(c.b.shape)\n", - "print(c.common.c.shape)" - ], "outputs": [], - "execution_count": null + "source": [ + "# Collect data from multiple steps\n", + "step_data = []\n", + "for i in range(5):\n", + " step_data.append({\"obs\": np.random.randn(4), \"act\": i, \"rew\": np.random.randn()})\n", + "\n", + "# Convert to batch (automatically stacks)\n", + "episode_batch = Batch(step_data)\n", + "print(\"Episode batch shape:\", episode_batch.shape)\n", + "print(\"obs shape:\", episode_batch.obs.shape)" + ] }, { - "cell_type": "code", + "cell_type": "markdown", "metadata": {}, "source": [ - "# Automatic padding with None or 0 using appropriate shapes\n", - "data_1 = Batch(a=np.array([0.0, 2.0]))\n", - "data_2 = Batch(a=np.array([1.0, 3.0]), b=\"done\")\n", - "data = Batch.stack((data_1, data_2))\n", - "print(data)" - ], - "outputs": [], - "execution_count": null + "**Pattern 2: Slicing for mini-batches**" + ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "# Concatenation example: batch a lacks key 'b', batch b lacks key 'a'\n", - "a = Batch(a=np.zeros([3, 
4]), common=Batch(c=np.zeros([3, 5])))\n", - "b = Batch(b=np.zeros([4, 3]), common=Batch(c=np.zeros([4, 5])))\n", - "# Note: Recent changes have modified concatenation behavior for heterogeneous batches\n", - "# The following operation is no longer supported:\n", - "# c = Batch.cat([a, b])\n", - "# print(c.a.shape)\n", - "# print(c.b.shape)\n", - "# print(c.common.c.shape)" - ], "outputs": [], - "execution_count": null + "source": [ + "# Large batch\n", + "large_batch = Batch(obs=np.random.randn(100, 4), act=np.random.randint(0, 2, 100))\n", + "\n", + "# Split into mini-batches\n", + "batch_size = 32\n", + "for mini_batch in large_batch.split(batch_size, shuffle=True):\n", + " print(f\"Mini-batch size: {len(mini_batch)}\")\n", + " # Train on mini_batch...\n", + " break # Just show one iteration" + ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "However, certain cases of extreme heterogeneity prevent aggregation:" + "**Pattern 3: Extending batches**" ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "# Example of incompatible batches that cannot be aggregated\n", - "try:\n", - " a = Batch(a=np.zeros([4, 4]))\n", - " b = Batch(a=Batch(b=Batch()))\n", - " c = Batch.stack([a, b])\n", - "except Exception as e:\n", - " print(f\"Exception: {e}\")" - ], "outputs": [], - "execution_count": null + "source": [ + "# Start with some data\n", + "batch = Batch(obs=np.array([[1, 2], [3, 4]]), act=np.array([0, 1]))\n", + "print(\"Initial:\", len(batch))\n", + "\n", + "# Add more data\n", + "new_data = Batch(obs=np.array([[5, 6]]), act=np.array([1]))\n", + "batch.cat_(new_data)\n", + "print(\"After cat_:\", len(batch))\n", + "print(\"obs:\", batch.obs)" + ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "How can we determine if batches can be aggregated? Reconsider the purpose of reserved keys. 
The distinction between `a1=Batch(b=Batch())` and `a2=Batch()` is that `a1.b` returns `Batch()` while `a2.b` raises an exception. **Reserved keys enable attribute reference for future value assignment.**\n", + "## 9. Summary\n", + "\n", + "### Key Takeaways\n", + "\n", + "1. **Batch = Dict + Array**: Combines key-value storage with array operations\n", + "2. **Hierarchical Structure**: Perfect for complex RL data (nested observations, etc.)\n", + "3. **Type Safety via Protocols**: Use `BatchProtocol` subclasses for IDE support and type checking\n", + "4. **Special RL Features**: Distribution slicing, heterogeneous aggregation, missing value handling\n", + "5. **Remember**: Iteration is over indices, NOT keys!\n", "\n", - "A key chain `k=[key1, key2, ..., keyn]` applies to `b` if the expression `b.key1.key2.{...}.keyn` is valid, with the result being `b[k]`.\n", + "### Quick Reference\n", "\n", - "For a set of `Batch` objects S, aggregation is possible if there exists a `Batch` object `b` satisfying:\n", + "| Operation | Code | Notes |\n", + "|-----------|------|-------|\n", + "| Create | `Batch(a=1, b=[2, 3])` | Auto-converts types |\n", + "| Access | `batch.a` or `batch[\"a\"]` | Equivalent |\n", + "| Index | `batch[0]`, `batch[:10]` | Returns sliced Batch |\n", + "| Iterate indices | `for item in batch:` | Yields batch[0], batch[1], ... |\n", + "| Iterate keys | `for k in batch.keys():` | Like dict |\n", + "| Stack | `Batch.stack([b1, b2])` | Adds dimension |\n", + "| Concatenate | `Batch.cat([b1, b2])` | Extends dimension |\n", + "| Split | `batch.split(size=10)` | Returns iterator |\n", + "| To PyTorch | `batch.to_torch_()` | In-place |\n", + "| To NumPy | `batch.to_numpy_()` | In-place |\n", + "| Transform | `batch.apply_values_transform(fn)` | Recursive |\n", "\n", - "1. **Key chain applicability:** For any object `bi` in S and any key chain `k`, if `bi[k]` is valid, then `b[k]` must be valid.\n", + "### Next Steps\n", "\n", - "2. 
**Type consistency:** If `bi[k]` is not `Batch()` (the final key in the chain is not reserved), then the type of `b[k]` must match `bi[k]` (both must be scalar/tensor/non-empty Batch values).\n", + "- **Collector Deep Dive**: See how Batch flows through data collection\n", + "- **Buffer Deep Dive**: Understand how Batch is stored and sampled\n", + "- **Policy Guide**: Learn how policies work with BatchProtocol\n", + "- **API Reference**: Full details at [Batch API documentation](https://tianshou.org/en/stable/api/tianshou.data.html#tianshou.data.Batch)\n", "\n", - "The `Batch` object `b` satisfying these rules with minimal keys determines the aggregation structure. Values are defined as follows: for any applicable key chain `k`, `b[k]` represents the stack/concatenation of `[bi[k] for bi in S]` (with appropriate zero or `None` padding when `k` does not apply to `bi`). When all `bi[k]` are `Batch()`, the aggregation result is also an empty `Batch()`." + "### Questions?\n", + "\n", + "- Check the [Tianshou GitHub discussions](https://github.com/thu-ml/tianshou/discussions)\n", + "- Review [issue tracker](https://github.com/thu-ml/tianshou/issues) for known gotchas\n", + "- Read the [source code](https://github.com/thu-ml/tianshou/blob/master/tianshou/data/batch.py) - it's well-documented!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Additional Considerations\n", + "## Appendix: Serialization & Advanced Topics\n", "\n", - "1. Environment observations typically utilize NumPy ndarrays, while policies require `torch.Tensor` for prediction and learning. Tianshou provides helper functions for in-place conversion between NumPy arrays and Torch tensors.\n", + "### Pickle Support" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Batch objects are picklable\n", + "original = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4]))\n", "\n", - "2. 
`obj.stack_([a, b])` is equivalent to `Batch.stack([obj, a, b])`, and `obj.cat_([a, b])` is equivalent to `Batch.cat([obj, a, b])`. For frequently required two-batch concatenation, `obj.cat_(a)` serves as an alias for `obj.cat_([a])`.\n", + "# Serialize and deserialize\n", + "serialized = pickle.dumps(original)\n", + "restored = pickle.loads(serialized)\n", "\n", - "3. `Batch.cat` and `Batch.cat_` currently do not support the `axis` argument available in `np.concatenate` and `torch.cat`.\n", + "print(\"Original obs.a:\", original.obs.a)\n", + "print(\"Restored obs.a:\", restored.obs.a)\n", + "print(\"Equal:\", original == restored)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Advanced Indexing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Multi-dimensional data\n", + "batch = Batch(a=np.random.randn(5, 3, 2))\n", + "print(\"Original shape:\", batch.a.shape)\n", "\n", - "4. `Batch.stack` and `Batch.stack_` support the `axis` argument, enabling stacking along dimensions beyond the first. However, when keys are not shared across all batches, `stack` with `axis != 0` is undefined and currently raises an exception." 
+ "# Various indexing operations\n", + "print(\"batch[0].a.shape:\", batch[0].a.shape)\n", + "print(\"batch[:, 0].a.shape:\", batch[:, 0].a.shape)\n", + "print(\"batch[[0, 2, 4]].a.shape:\", batch[[0, 2, 4]].a.shape)" ] } ], "metadata": { - "colab": { - "provenance": [] - }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -868,7 +1463,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.11.0" } }, "nbformat": 4, From 8734a25f17735f79360e2297e9fbfbf4b26ad2c6 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 17 Nov 2025 22:28:33 +0700 Subject: [PATCH 3/3] Rewrote and extended buffer notebook --- docs/02_deep_dives/L2_Buffer.ipynb | 1927 ++++++++++++++++++++++++---- pyproject.toml | 1 + 2 files changed, 1644 insertions(+), 284 deletions(-) diff --git a/docs/02_deep_dives/L2_Buffer.ipynb b/docs/02_deep_dives/L2_Buffer.ipynb index c7da8e09b..a604dd5bf 100644 --- a/docs/02_deep_dives/L2_Buffer.ipynb +++ b/docs/02_deep_dives/L2_Buffer.ipynb @@ -2,450 +2,1809 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "id": "xoPiGVD8LNma" - }, + "metadata": {}, "source": [ - "# Buffer\n", + "# Buffer: Experience Replay in Tianshou\n", "\n", - "The replay buffer is a fundamental component in deep reinforcement learning (DRL) implementations. In Tianshou, the Buffer module extends the functionality of the Batch class by providing trajectory tracking capabilities and sampling utilities that go beyond basic data storage.\n", + "The replay buffer is a fundamental component in reinforcement learning, particularly for off-policy algorithms. 
Tianshou's buffer implementation extends beyond simple data storage to provide sophisticated trajectory tracking, efficient sampling, and seamless integration with the RL training pipeline.\n", "\n", - "Tianshou provides several buffer implementations, with `ReplayBuffer` and `VectorReplayBuffer` being the most fundamental. The latter is specifically designed for parallelized environments, which will be covered in the [Vectorized Environment](https://tianshou.readthedocs.io/en/master/02_notebooks/L3_Vectorized__Environment.html) tutorial. This tutorial focuses exclusively on the `ReplayBuffer` implementation." + "This tutorial provides comprehensive coverage of Tianshou's buffer system, from basic concepts to advanced features and integration patterns." ] }, { - "metadata": {}, "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "import pickle\n", + "import tempfile\n", "\n", "import numpy as np\n", "\n", - "from tianshou.data import Batch, ReplayBuffer" - ], - "outputs": [], - "execution_count": null + "from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer, VectorReplayBuffer" + ] }, { "cell_type": "markdown", - "metadata": { - "id": "OdesCAxANehZ" - }, + "metadata": {}, "source": [ - "## Core Functionality" + "## 1. Introduction: Why Buffers in Reinforcement Learning?\n", + "\n", + "### The Role of Experience Replay\n", + "\n", + "Experience replay is a critical technique in modern reinforcement learning that addresses three fundamental challenges:\n", + "\n", + "1. **Breaking Temporal Correlation**: Sequential experiences from an agent are highly correlated. Training directly on these sequences can lead to unstable learning. By storing experiences and sampling randomly, we break these correlations.\n", + "\n", + "2. **Sample Efficiency**: In RL, collecting data through environment interaction is often expensive. 
Experience replay allows us to reuse each experience multiple times for training, dramatically improving sample efficiency.\n", + "\n", + "3. **Mini-batch Training**: Modern deep learning requires mini-batch gradient descent. Buffers enable efficient batching of experiences for neural network training.\n", + "\n", + "### Why Not Alternatives?\n", + "\n", + "**Plain Python Lists**\n", + "- No efficient random sampling\n", + "- No automatic circular queue behavior\n", + "- No trajectory boundary tracking\n", + "- Poor memory management for large datasets\n", + "\n", + "**Simple Batch Storage**\n", + "- No automatic overwriting when full\n", + "- No episode metadata (returns, lengths)\n", + "- No methods for boundary navigation (prev/next)\n", + "- No specialized sampling strategies\n", + "\n", + "### Buffer = Batch + Trajectory Management + Sampling\n", + "\n", + "Tianshou's buffers build on the `Batch` class to provide:\n", + "- **Circular queue storage**: Automatic overwriting of oldest data\n", + "- **Trajectory tracking**: Episode boundaries, returns, and lengths\n", + "- **Efficient sampling**: Random access with various strategies\n", + "- **Integration utilities**: Seamless connection to Collector and Policy\n", + "\n", + "### Use Cases\n", + "\n", + "- **Off-policy algorithms**: DQN, SAC, TD3, DDPG require experience replay\n", + "- **On-policy with replay**: Some PPO implementations reuse buffer data\n", + "- **Offline RL**: Loading and using pre-collected datasets\n", + "- **Multi-environment training**: VectorReplayBuffer for parallel collection" ] }, { "cell_type": "markdown", - "metadata": { - "id": "fUbLl9T_SrTR" - }, + "metadata": {}, "source": [ - "### Circular Queue Storage Mechanism\n", + "## 2. Buffer Types and Hierarchy\n", + "\n", + "Tianshou provides several buffer implementations, each designed for specific use cases. 
Understanding this hierarchy is crucial for choosing the right buffer.\n", "\n", - "The buffer stores data in batches using a circular queue mechanism. When the buffer reaches its maximum capacity, newly added data automatically overwrites the oldest entries, ensuring efficient memory utilization while maintaining the most recent experiences." + "### Buffer Hierarchy\n", + "\n", + "```mermaid\n", + "graph TD\n", + " RB[ReplayBuffer
Single environment
Circular queue] --> RBM[ReplayBufferManager
Manages multiple buffers
Contiguous memory]\n", + " RBM --> VRB[VectorReplayBuffer
Parallel environments
Maintains temporal order]\n", + " \n", + " RB --> PRB[PrioritizedReplayBuffer
TD-error based sampling
Importance weights]\n", + " PRB --> PVRB[PrioritizedVectorReplayBuffer
Prioritized + Parallel]\n", + " \n", + " RB --> CRB[CachedReplayBuffer
Primary + auxiliary caches
Imitation learning]\n", + " \n", + " RB --> HERB[HERReplayBuffer
Hindsight Experience Replay
Goal-conditioned RL]\n", + " HERB --> HVRB[HERVectorReplayBuffer
HER + Parallel]\n", + " \n", + " style RB fill:#e1f5ff\n", + " style RBM fill:#fff4e1\n", + " style VRB fill:#ffe1f5\n", + " style PRB fill:#e8f5e1\n", + " style CRB fill:#f5e1e1\n", + " style HERB fill:#e1e1f5\n", + "```\n", + "\n", + "### When to Use Which Buffer\n", + "\n", + "**ReplayBuffer**: Single environment scenarios\n", + "- Simple setup and testing\n", + "- Debugging algorithms\n", + "- Low-parallelism training\n", + "\n", + "**VectorReplayBuffer**: Multiple parallel environments (most common)\n", + "- Standard production use case\n", + "- Efficient parallel data collection\n", + "- Maintains per-environment episode boundaries\n", + "\n", + "**PrioritizedReplayBuffer**: DQN variants with prioritization\n", + "- Rainbow DQN\n", + "- Algorithms requiring importance sampling\n", + "- When some transitions are more valuable than others\n", + "\n", + "**CachedReplayBuffer**: Separate primary and auxiliary caches\n", + "- Imitation learning (expert + agent data)\n", + "- GAIL and similar algorithms\n", + "- When you need different sampling strategies for different data sources\n", + "\n", + "**HERReplayBuffer**: Goal-conditioned reinforcement learning\n", + "- Sparse reward environments\n", + "- Robotics tasks with explicit goals\n", + "- Relabeling failed experiences with achieved goals" ] }, { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "mocZ6IqZTH62", - "outputId": "66cc4181-c51b-4a47-aacf-666b92b7fc52" - }, + "cell_type": "markdown", + "metadata": {}, "source": [ - "# Initialize buffer with maximum capacity of 10 transitions\n", - "print(\"========================================\")\n", - "dummy_buf = ReplayBuffer(size=10)\n", - "print(dummy_buf)\n", - "print(f\"maxsize: {dummy_buf.maxsize}, data length: {len(dummy_buf)}\")\n", + "## 3. 
Basic Operations\n", "\n", - "# Add 3 transition steps sequentially\n", - "print(\"========================================\")\n", - "for i in range(3):\n", - " dummy_buf.add(\n", - " Batch(obs=i, act=i, rew=i, terminated=0, truncated=0, done=0, obs_next=i + 1, info={}),\n", - " )\n", - "print(dummy_buf)\n", - "print(f\"maxsize: {dummy_buf.maxsize}, data length: {len(dummy_buf)}\")\n", - "\n", - "# Add 10 additional transitions to demonstrate circular queue behavior\n", - "# Note: First 3 transitions will be overwritten as capacity is exceeded\n", - "print(\"========================================\")\n", - "for i in range(3, 13):\n", - " dummy_buf.add(\n", - " Batch(obs=i, act=i, rew=i, terminated=0, truncated=0, done=0, obs_next=i + 1, info={}),\n", - " )\n", - "print(dummy_buf)\n", - "print(f\"maxsize: {dummy_buf.maxsize}, data length: {len(dummy_buf)}\")" - ], + "### 3.1 Construction and Configuration\n", + "\n", + "The ReplayBuffer constructor accepts several important parameters that control its behavior:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, "outputs": [], - "execution_count": null + "source": [ + "# Create a buffer with all configuration options\n", + "buf = ReplayBuffer(\n", + " size=20, # Maximum capacity (transitions)\n", + " stack_num=1, # Frame stacking for RNNs (default: 1, no stacking)\n", + " ignore_obs_next=False, # Save memory by not storing obs_next\n", + " save_only_last_obs=False, # For temporal stacking (Atari-style)\n", + " sample_avail=False, # Sample only valid indices for frame stacking\n", + " random_seed=42, # Reproducible sampling\n", + ")\n", + "\n", + "print(f\"Buffer created: {buf}\")\n", + "print(f\"Max size: {buf.maxsize}\")\n", + "print(f\"Current length: {len(buf)}\")" + ] }, { "cell_type": "markdown", - "metadata": { - "id": "H8B85Y5yUfTy" - }, + "metadata": {}, "source": [ - "### Batch-Compatible Operations\n", + "**Parameter Explanations**:\n", "\n", - "Consistent with the `Batch` 
interface, `ReplayBuffer` supports standard operations including concatenation, splitting, advanced slicing, and indexing." + "- `size`: Maximum number of transitions the buffer can hold. When full, oldest data is overwritten.\n", + "- `stack_num`: Number of consecutive frames to stack. Used for RNN inputs or frame-based policies (Atari).\n", + "- `ignore_obs_next`: If True, obs_next is not stored, saving memory. The buffer reconstructs it from the next obs when needed.\n", + "- `save_only_last_obs`: For temporal stacking. Only saves the last observation in a stack.\n", + "- `sample_avail`: When True with stack_num > 1, only samples indices where a complete stack is available.\n", + "- `random_seed`: Seeds the random number generator for reproducible sampling." ] }, { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "cOX-ADOPNeEK", - "outputId": "f1a8ec01-b878-419b-f180-bdce3dee73e6" - }, + "cell_type": "markdown", + "metadata": {}, "source": [ - "print(dummy_buf[-1])\n", - "print(dummy_buf[-3:])\n", - "# Additional Batch methods can be explored as needed" - ], + "### 3.2 Reserved Keys and the Done Flag System\n", + "\n", + "ReplayBuffer uses nine reserved keys that integrate with Gymnasium conventions. Understanding the done flag system is critical." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, "outputs": [], - "execution_count": null + "source": [ + "# The nine reserved keys\n", + "print(\"Reserved keys:\")\n", + "print(ReplayBuffer._reserved_keys)\n", + "print(\"\\nKeys required for add():\")\n", + "print(ReplayBuffer._required_keys_for_add)" + ] }, { "cell_type": "markdown", - "metadata": { - "id": "vqldap-2WQBh" - }, + "metadata": {}, "source": [ - "### Persistence and Serialization\n", + "**Important: Understanding done, terminated, and truncated**\n", + "\n", + "Gymnasium (the successor to OpenAI Gym) introduced a crucial distinction:\n", "\n", - "The buffer can be serialized to disk while preserving trajectory information. This capability is particularly valuable for offline reinforcement learning applications, where pre-collected experience datasets are utilized for training." + "- `terminated`: Episode ended naturally (agent reached goal or failed)\n", + " - Examples: CartPole fell over, agent reached goal state\n", + " - Should be used for bootstrapping calculations\n", + "\n", + "- `truncated`: Episode was cut off artificially (time limit, external interruption)\n", + " - Examples: Maximum episode length reached, environment reset externally \n", + " - Should NOT be used for bootstrapping (the episode could have continued)\n", + "\n", + "- `done`: Computed automatically as `terminated OR truncated`\n", + " - Used internally for episode boundary tracking\n", + " - You should NEVER manually set this field\n", + "\n", + "**Best Practice**: Always use the `info` dictionary for custom metadata rather than adding top-level keys:" ] }, { "cell_type": "code", - "metadata": { - "id": "Ppx0L3niNT5K" - }, - "source": [ - "_dummy_buf = pickle.loads(pickle.dumps(dummy_buf))" - ], + "execution_count": null, + "metadata": {}, "outputs": [], - "execution_count": null + "source": [ + "# GOOD: Custom metadata in info dictionary\n", + "good_batch = Batch(\n", + " obs=np.array([1.0, 
2.0]),\n", + " act=0,\n", + " rew=1.0,\n", + " terminated=False,\n", + " truncated=False,\n", + " obs_next=np.array([1.5, 2.5]),\n", + " info={\"custom_metric\": 0.95, \"step_count\": 10}, # Custom data here\n", + ")\n", + "\n", + "# BAD: Don't add custom top-level keys (may conflict with future buffer features)\n", + "# bad_batch = Batch(..., custom_metric=0.95) # Don't do this!\n", + "\n", + "print(\"Good batch structure:\")\n", + "print(good_batch)" + ] }, { "cell_type": "markdown", - "metadata": { - "id": "Eqezp0OyXn6J" - }, + "metadata": {}, + "source": [ + "### 3.3 Circular Queue Storage\n", + "\n", + "The buffer implements a circular queue: when it reaches maximum capacity, new data overwrites the oldest entries." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ - "### Reserved Keys for DRL Integration\n", + "# Create a small buffer to demonstrate circular behavior\n", + "demo_buf = ReplayBuffer(size=5)\n", "\n", - "To facilitate seamless integration with DRL algorithms, `ReplayBuffer` utilizes nine reserved keys within the `Batch` structure. 
These keys follow the [Gymnasium](https://gymnasium.farama.org/index.html#) conventions:\n", + "print(\"Adding 3 transitions:\")\n", + "for i in range(3):\n", + " demo_buf.add(\n", + " Batch(\n", + " obs=i,\n", + " act=i,\n", + " rew=float(i),\n", + " terminated=False,\n", + " truncated=False,\n", + " obs_next=i + 1,\n", + " info={},\n", + " )\n", + " )\n", + "print(f\"Length: {len(demo_buf)}, Max: {demo_buf.maxsize}\")\n", + "print(f\"Observations: {demo_buf.obs[: len(demo_buf)]}\")\n", "\n", - "* `obs` - Current observation\n", - "* `act` - Action taken\n", - "* `rew` - Reward received\n", - "* `terminated` - Episode termination flag (goal reached or failure)\n", - "* `truncated` - Episode truncation flag (time limit or external interruption)\n", - "* `done` - Combined termination/truncation indicator\n", - "* `obs_next` - Subsequent observation\n", - "* `info` - Auxiliary information dictionary\n", - "* `policy` - Policy-specific data\n", + "print(\"\\nAdding 5 more transitions (total 8, exceeds capacity 5):\")\n", + "for i in range(3, 8):\n", + " demo_buf.add(\n", + " Batch(\n", + " obs=i,\n", + " act=i,\n", + " rew=float(i),\n", + " terminated=False,\n", + " truncated=False,\n", + " obs_next=i + 1,\n", + " info={},\n", + " )\n", + " )\n", + "print(f\"Length: {len(demo_buf)}, Max: {demo_buf.maxsize}\")\n", + "print(f\"Observations: {demo_buf.obs[: len(demo_buf)]}\")\n", + "print(\"\\nNotice: First 3 transitions (0,1,2) were overwritten by the newest ones (5,6,7)\")\n", + "print(\"Buffer now holds transitions 3-7, stored in circular order [5, 6, 7, 3, 4]\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3.4 Batch-Compatible Operations\n", "\n", - "**Best Practice**: Use the `info` dictionary for custom metadata rather than adding additional top-level keys.
The `done` flag is internally tracked to determine trajectory boundaries, episode lengths, and cumulative rewards.\n", + "Since ReplayBuffer extends Batch functionality, it supports standard indexing and slicing:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Indexing and slicing\n", + "print(\"Last transition:\")\n", + "print(demo_buf[-1])\n", "\n", - "```python\n", - "# Not recommended: Custom top-level keys\n", - "buf.add(Batch(......, extra_info=0))\n", + "print(\"\\nLast 3 transitions:\")\n", + "print(demo_buf[-3:])\n", "\n", - "# Recommended: Use info dictionary\n", - "buf.add(Batch(......, info={\"extra_info\": 0}))\n", - "```" + "print(\"\\nSpecific indices [0, 2, 4]:\")\n", + "print(demo_buf[np.array([0, 2, 4])])" ] }, { "cell_type": "markdown", - "metadata": { - "id": "ueAbTspsc6jo" - }, + "metadata": {}, "source": [ - "### Experience Sampling\n", + "## 4. Trajectory Management\n", + "\n", + "A key distinguishing feature of ReplayBuffer is its automatic tracking of episode boundaries and metadata.\n", "\n", - "The primary function of a replay buffer in DRL is to enable experience sampling for training. The buffer provides two methods for this purpose:\n", + "### 4.1 Episode Tracking and Metadata\n", "\n", - "1. `ReplayBuffer.sample()` - Direct batch sampling with specified size\n", - "2. 
`ReplayBuffer.split(..., shuffle=True)` - Split buffer into multiple batches with optional shuffling" + "The `add()` method returns four values that provide episode information:" ] }, { "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "P5xnYOhrchDl", - "outputId": "bcd2c970-efa6-43bb-8709-720d38f77bbd" - }, - "source": [ - "dummy_buf.sample(batch_size=5)" - ], + "execution_count": null, + "metadata": {}, "outputs": [], - "execution_count": null + "source": [ + "# Create a fresh buffer for trajectory demonstration\n", + "traj_buf = ReplayBuffer(size=20)\n", + "\n", + "print(\"Episode 1: 4 steps, terminates naturally\")\n", + "for i in range(4):\n", + " idx, ep_rew, ep_len, ep_start = traj_buf.add(\n", + " Batch(\n", + " obs=i,\n", + " act=i,\n", + " rew=float(i + 1), # Rewards: 1, 2, 3, 4\n", + " terminated=i == 3, # Last step terminates\n", + " truncated=False,\n", + " obs_next=i + 1,\n", + " info={},\n", + " )\n", + " )\n", + " print(f\" Step {i}: idx={idx}, ep_rew={ep_rew}, ep_len={ep_len}, ep_start={ep_start}\")\n", + "\n", + "print(\"\\nNotice: Episode return (10.0) and length (4) only appear at the end!\")" + ] }, { "cell_type": "markdown", - "metadata": { - "id": "IWyaOSKOcgK4" - }, + "metadata": {}, "source": [ - "## Trajectory Management\n", + "**Return Values Explained**:\n", "\n", - "A distinguishing feature of `ReplayBuffer` compared to `Batch` is its trajectory tracking capability, which maintains episode boundaries and associated metadata.\n", + "1. `idx`: Index where the transition was inserted (np.ndarray of shape (1,))\n", + "2. `ep_rew`: Episode return, only non-zero when `done=True` (np.ndarray of shape (1,))\n", + "3. `ep_len`: Episode length, only non-zero when `done=True` (np.ndarray of shape (1,))\n", + "4. `ep_start`: Index where the episode started (np.ndarray of shape (1,))\n", "\n", - "The following example demonstrates trajectory tracking by simulating three episodes:\n", - "1. 
First episode: 3 steps (completed)\n", - "2. Second episode: 5 steps (completed)\n", - "3. Third episode: 5 steps (ongoing)" + "This automatic computation eliminates manual episode tracking during data collection." ] }, { "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "editable": true, - "id": "H0qRb6HLfhLB", - "outputId": "9bdb7d4e-b6ec-489f-a221-0bddf706d85b", - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ - "trajectory_buffer = ReplayBuffer(size=10)\n", - "\n", - "# Episode 1: 3 steps, terminates at step 2\n", - "print(\"========================================\")\n", - "for i in range(3):\n", - " result = trajectory_buffer.add(\n", + "# Continue with Episode 2: 5 steps\n", + "print(\"Episode 2: 5 steps, truncated (time limit)\")\n", + "for i in range(4, 9):\n", + " idx, ep_rew, ep_len, ep_start = traj_buf.add(\n", " Batch(\n", " obs=i,\n", " act=i,\n", - " rew=i,\n", - " terminated=1 if i == 2 else 0,\n", - " truncated=0,\n", - " done=i == 2,\n", + " rew=float(i + 1),\n", + " terminated=False,\n", + " truncated=i == 8, # Last step truncated\n", " obs_next=i + 1,\n", " info={},\n", - " ),\n", + " )\n", " )\n", - " print(result)\n", - "print(trajectory_buffer)\n", - "print(f\"maxsize: {trajectory_buffer.maxsize}, data length: {len(trajectory_buffer)}\")\n", + " if i == 8:\n", + " print(\n", + " f\" Final step: idx={idx}, ep_rew={ep_rew[0]:.1f}, ep_len={ep_len[0]}, ep_start={ep_start}\"\n", + " )\n", "\n", - "# Episode 2: 5 steps, terminates at step 7\n", - "print(\"========================================\")\n", - "for i in range(3, 8):\n", - " result = trajectory_buffer.add(\n", + "# Episode 3: Ongoing (not finished)\n", + "print(\"\\nEpisode 3: 3 steps, ongoing (not done)\")\n", + "for i in range(9, 12):\n", + " idx, ep_rew, ep_len, ep_start = traj_buf.add(\n", " Batch(\n", " obs=i,\n", " act=i,\n", - " rew=i,\n", - " 
terminated=1 if i == 7 else 0,\n", - " truncated=0,\n", - " done=i == 7,\n", + " rew=float(i + 1),\n", + " terminated=False,\n", + " truncated=False, # Episode continues\n", " obs_next=i + 1,\n", " info={},\n", - " ),\n", - " )\n", - " print(result)\n", - "print(trajectory_buffer)\n", - "print(f\"maxsize: {trajectory_buffer.maxsize}, data length: {len(trajectory_buffer)}\")\n", - "\n", - "# Episode 3: 5 steps added, episode still ongoing\n", - "print(\"========================================\")\n", - "for i in range(8, 13):\n", - " result = trajectory_buffer.add(\n", - " Batch(obs=i, act=i, rew=i, terminated=0, truncated=0, done=False, obs_next=i + 1, info={}),\n", + " )\n", " )\n", - " print(result)\n", - "print(trajectory_buffer)\n", - "print(f\"maxsize: {trajectory_buffer.maxsize}, data length: {len(trajectory_buffer)}\")" - ], - "outputs": [], - "execution_count": null + " if i == 11:\n", + " print(\n", + " f\" Latest step: idx={idx}, ep_rew={ep_rew}, ep_len={ep_len} (zeros because not done)\"\n", + " )\n", + "\n", + "print(f\"\\nBuffer state: {len(traj_buf)} transitions across 2 complete + 1 ongoing episode\")" + ] }, { "cell_type": "markdown", - "metadata": { - "id": "dO7PWdb_hkXA" - }, + "metadata": {}, "source": [ - "### Episode Metrics Tracking\n", - "\n", - "The `ReplayBuffer.add()` method returns a tuple containing four values: `(current_index, episode_reward, episode_length, episode_start_index)`. \n", + "### 4.2 Boundary Navigation: prev() and next()\n", "\n", - "**Important**: The `episode_reward` and `episode_length` fields are only populated when an episode completes (i.e., when `done=True`). This automatic computation eliminates the need for manual episode metric tracking during data collection." 
+ "The buffer provides methods to navigate within episodes while respecting episode boundaries:" ] }, { - "cell_type": "markdown", - "metadata": { - "id": "xbVc90z8itH0" - }, + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ - "### Episode Boundary Navigation\n", - "\n", - "The buffer provides mechanisms to navigate episode boundaries efficiently. Consider the following scenario where we query a mid-episode step:" + "# Examine the buffer structure\n", + "print(\"Buffer contents:\")\n", + "print(f\"Indices: {np.arange(len(traj_buf))}\")\n", + "print(f\"Obs: {traj_buf.obs[: len(traj_buf)]}\")\n", + "print(f\"Terminated: {traj_buf.terminated[: len(traj_buf)]}\")\n", + "print(f\"Truncated: {traj_buf.truncated[: len(traj_buf)]}\")\n", + "print(f\"Done: {traj_buf.done[: len(traj_buf)]}\")\n", + "print(\"\\nEpisode boundaries: indices 3 (terminated) and 8 (truncated)\")" ] }, { "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "4mKwo54MjupY", - "outputId": "9ae14a7e-908b-44eb-afec-89b45bac5961" - }, + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ - "print(trajectory_buffer)\n", - "print(\"========================================\")\n", + "# prev() returns the previous index within the same episode\n", + "# It STOPS at episode boundaries\n", + "test_indices = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])\n", + "prev_indices = traj_buf.prev(test_indices)\n", "\n", - "data = trajectory_buffer[6]\n", - "print(data)" - ], + "print(\"prev() behavior:\")\n", + "print(f\"Index: {test_indices}\")\n", + "print(f\"Prev: {prev_indices}\")\n", + "print(\"\\nObservations:\")\n", + "print(\"- Index 0 stays at 0 (start of episode 1)\")\n", + "print(\"- Index 4 stays at 4 (start of episode 2, can't go back to episode 1)\")\n", + "print(\"- Index 9 stays at 9 (start of episode 3, can't go back to episode 2)\")" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, "outputs": [], - "execution_count": null + "source": [ + "# next() returns the next index within the same episode\n", + "# It STOPS at episode boundaries\n", + "next_indices = traj_buf.next(test_indices)\n", + "\n", + "print(\"next() behavior:\")\n", + "print(f\"Index: {test_indices}\")\n", + "print(f\"Next: {next_indices}\")\n", + "print(\"\\nObservations:\")\n", + "print(\"- Index 3 stays at 3 (end of episode 1, terminated)\")\n", + "print(\"- Index 8 stays at 8 (end of episode 2, truncated)\")\n", + "print(\"- Indices 9-10 advance normally; index 11 stays at 11 (latest step of ongoing episode 3)\")" + ] }, { "cell_type": "markdown", - "metadata": { - "id": "p5Co_Fmzj8Sw" - }, + "metadata": {}, "source": [ - "### Determining Episode Start Indices\n", - "\n", - "Step 6 belongs to the second episode (steps 3-7). While this may appear straightforward, determining the episode start index programmatically is non-trivial due to:\n", + "**Use Cases for prev() and next()**:\n", "\n", - "1. **Ambiguous done flags**: The preceding `done` flag approach fails when the buffer contains incomplete episodes, as step 3 is surrounded by `done=False` values\n", - "2.
**Complex buffer structures**: Advanced buffers like `VectorReplayBuffer` do not store data sequentially, making boundary detection more challenging\n", + "These methods are essential for computing algorithmic quantities:\n", + "- **N-step returns**: Use prev() to look back N steps within an episode\n", + "- **GAE (Generalized Advantage Estimation)**: Navigate backwards through episodes\n", + "- **Episode extraction**: Find episode start/end indices\n", + "- **Temporal difference targets**: Ensure you don't bootstrap across episode boundaries" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4.3 Identifying Unfinished Episodes\n", "\n", - "The buffer provides a unified API to handle these complexities through the `prev()` method, which identifies the previous step within an episode:" + "The `unfinished_index()` method returns indices of ongoing episodes:" ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "# Query previous steps for indices [0, 1, 2, 3, 4, 5, 6]\n", - "# Episode boundaries prevent backward traversal past episode starts\n", - "print(trajectory_buffer.prev(np.array([0, 1, 2, 3, 4, 5, 6])))" - ], "outputs": [], - "execution_count": null + "source": [ + "unfinished = traj_buf.unfinished_index()\n", + "print(f\"Unfinished episode indices: {unfinished}\")\n", + "print(f\"Latest step of ongoing episode: obs={traj_buf.obs[unfinished[0]]}\")\n", + "\n", + "# After finishing episode 3\n", + "traj_buf.add(\n", + " Batch(\n", + " obs=12,\n", + " act=12,\n", + " rew=13.0,\n", + " terminated=True,\n", + " truncated=False,\n", + " obs_next=13,\n", + " info={},\n", + " )\n", + ")\n", + "\n", + "unfinished_after = traj_buf.unfinished_index()\n", + "print(\"\\nAfter finishing episode 3:\")\n", + "print(f\"Unfinished episodes: {unfinished_after} (empty array)\")" + ] }, { "cell_type": "markdown", - "metadata": { - "id": "4Wlb57V4lQyQ" - }, + "metadata": {}, "source": [ - "The output confirms that step 
3 marks the episode start. The complementary `ReplayBuffer.next()` method enables forward traversal to identify episode terminations, providing a consistent interface across all buffer implementations." + "## 5. Sampling Strategies\n", + "\n", + "Efficient sampling is critical for RL training. The buffer provides several sampling methods and strategies.\n", + "\n", + "### 5.1 Basic Sampling" ] }, { "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zl5TRMo7oOy5", - "outputId": "4a11612c-3ee0-4e74-b028-c8759e71fbdb" - }, - "source": [ - "# Query next steps for indices [4, 5, 6, 7, 8, 9]\n", - "# Episode boundaries prevent forward traversal past episode ends\n", - "print(trajectory_buffer.next(np.array([4, 5, 6, 7, 8, 9])))" - ], + "execution_count": null, + "metadata": {}, "outputs": [], - "execution_count": null + "source": [ + "# Create a buffer with some data\n", + "sample_buf = ReplayBuffer(size=100)\n", + "for i in range(50):\n", + " sample_buf.add(\n", + " Batch(\n", + " obs=i,\n", + " act=i % 4,\n", + " rew=np.random.random(),\n", + " terminated=(i + 1) % 10 == 0,\n", + " truncated=False,\n", + " obs_next=i + 1,\n", + " info={},\n", + " )\n", + " )\n", + "\n", + "# Sample with batch_size\n", + "batch, indices = sample_buf.sample(batch_size=8)\n", + "print(f\"Sampled batch size: {len(batch)}\")\n", + "print(f\"Sampled indices: {indices}\")\n", + "print(f\"Sampled observations: {batch.obs}\")\n", + "\n", + "# batch_size=None: return all data in random order\n", + "all_data, all_indices = sample_buf.sample(batch_size=None)\n", + "print(f\"\\nSample all (batch_size=None): {len(all_data)} transitions\")\n", + "\n", + "# batch_size=0: return all data in buffer order\n", + "ordered_data, ordered_indices = sample_buf.sample(batch_size=0)\n", + "print(f\"Get all in order (batch_size=0): {len(ordered_data)} transitions\")\n", + "print(f\"Indices in order: {ordered_indices[:10]}...\") # Show first 10" + ] }, { 
"cell_type": "markdown", - "metadata": { - "id": "YJ9CcWZXoOXw" - }, + "metadata": {}, + "source": [ + "**Sampling Behavior Summary**:\n", + "\n", + "- `batch_size > 0`: Random sample of specified size\n", + "- `batch_size = None`: All data in random order \n", + "- `batch_size = 0`: All data in insertion order\n", + "- `batch_size < 0`: Empty array (edge case handling)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, "source": [ - "### Identifying Incomplete Episodes\n", + "### 5.2 Frame Stacking\n", "\n", - "The buffer maintains tracking of incomplete episodes through the `unfinished_index()` method, which identifies the most recent step of ongoing episodes (marked with `done=False`):" + "The `stack_num` parameter enables automatic frame stacking, useful for RNN inputs or Atari-style environments where temporal context matters:" ] }, { "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Xkawk97NpItg", - "outputId": "df10b359-c2c7-42ca-e50d-9caee6bccadd" - }, + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ - "print(trajectory_buffer.unfinished_index())" - ], + "# Create buffer with frame stacking\n", + "stack_buf = ReplayBuffer(size=20, stack_num=4)\n", + "\n", + "# Add observations: 0, 1, 2, ..., 9\n", + "for i in range(10):\n", + " stack_buf.add(\n", + " Batch(\n", + " obs=np.array([i]), # Single frame\n", + " act=0,\n", + " rew=1.0,\n", + " terminated=i == 9,\n", + " truncated=False,\n", + " obs_next=np.array([i + 1]),\n", + " info={},\n", + " )\n", + " )\n", + "\n", + "# Get stacked frames for index 6\n", + "# Should return [3, 4, 5, 6] (4 consecutive frames ending at 6)\n", + "stacked = stack_buf.get(index=6, key=\"obs\")\n", + "print(\"Frame stacking demo:\")\n", + "print(\"Requested index: 6\")\n", + "print(f\"Stacked frames shape: {stacked.shape}\")\n", + "print(f\"Stacked frames: {stacked.flatten()}\")\n", + "print(\"\\nExplanation: stack_num=4, so index 6 returns 
[obs[3], obs[4], obs[5], obs[6]]\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, "outputs": [], - "execution_count": null + "source": [ + "# Demonstrate episode boundary handling with frame stacking\n", + "boundary_buf = ReplayBuffer(size=20, stack_num=4)\n", + "\n", + "# Episode 1: indices 0-4\n", + "for i in range(5):\n", + " boundary_buf.add(\n", + " Batch(\n", + " obs=np.array([i]),\n", + " act=0,\n", + " rew=1.0,\n", + " terminated=i == 4,\n", + " truncated=False,\n", + " obs_next=np.array([i + 1]),\n", + " info={},\n", + " )\n", + " )\n", + "\n", + "# Episode 2: indices 5-9\n", + "for i in range(5, 10):\n", + " boundary_buf.add(\n", + " Batch(\n", + " obs=np.array([i]),\n", + " act=0,\n", + " rew=1.0,\n", + " terminated=i == 9,\n", + " truncated=False,\n", + " obs_next=np.array([i + 1]),\n", + " info={},\n", + " )\n", + " )\n", + "\n", + "# Try to get stacked frames at episode boundary\n", + "boundary_stack = boundary_buf.get(index=6, key=\"obs\") # Early in episode 2\n", + "print(\"\\nFrame stacking at episode boundary:\")\n", + "print(f\"Index 6 stacked frames: {boundary_stack.flatten()}\")\n", + "print(\"Notice: Frames don't cross episode boundary (5,5,5,6 not 3,4,5,6)\")\n", + "print(\"The buffer uses prev() internally, which respects episode boundaries\")" + ] }, { "cell_type": "markdown", - "metadata": { - "id": "8_lMr0j3pOmn" - }, + "metadata": {}, "source": [ - "### Applications in DRL Algorithms\n", + "**Frame Stacking Use Cases**:\n", "\n", - "These trajectory navigation APIs are essential for computing algorithmic quantities such as:\n", - "- Generalized Advantage Estimation (GAE)\n", - "- N-step returns\n", - "- Temporal difference targets\n", + "- **RNN/LSTM inputs**: Provide temporal context to recurrent networks\n", + "- **Atari games**: Stack 4 frames to capture motion (as in DQN paper)\n", + "- **Velocity estimation**: Multiple frames allow computing derivatives\n", + "- **Partially observable 
environments**: Build up state estimates\n", "\n", - "The unified interface ensures modular design and enables algorithm implementations that generalize across different buffer types. For reference implementations, see the [Tianshou policy base class](https://github.com/thu-ml/tianshou/blob/6fc68578127387522424460790cbcb32a2bd43c4/tianshou/policy/base.py#L384)." + "**Important Notes**:\n", + "- Frame stacking respects episode boundaries (won't stack across episodes)\n", + "- Set `sample_avail=True` to only sample indices where full stacks are available\n", + "- `save_only_last_obs=True` saves memory in Atari-style setups" ] }, { "cell_type": "markdown", - "metadata": { - "id": "FEyE0c7tNfwa" - }, + "metadata": {}, "source": [ - "## Advanced Topics\n", + "## 6. VectorReplayBuffer: Parallel Environment Support\n", "\n", - "### Specialized Buffer Implementations\n", + "VectorReplayBuffer is essential for modern RL training with parallel environments. It maintains separate subbuffers for each environment while providing a unified interface.\n", "\n", - "Tianshou provides several specialized buffer variants for advanced use cases:\n", + "### 6.1 Motivation and Architecture\n", "\n", - "* **PrioritizedReplayBuffer**: Implements [prioritized experience replay](https://arxiv.org/abs/1511.05952) for importance-weighted sampling\n", - "* **CachedReplayBuffer**: Maintains a primary buffer with auxiliary cached buffers for improved sample efficiency in specific scenarios\n", - "* **ReplayBufferManager**: Base class for custom buffer implementations requiring management of multiple buffer instances\n", + "When training with multiple parallel environments (e.g., 8 environments running simultaneously), we need:\n", + "- **Per-environment episode tracking**: Each environment has its own episode boundaries\n", + "- **Temporal ordering**: Preserve the sequence of events within each environment\n", + "- **Unified sampling**: Sample uniformly across all environments for training\n", 
"\n", - "Consult the API documentation and source code for detailed implementation specifications.\n", + "\n", + "```mermaid\n", + "graph LR\n", + " E1[Env 1] --> B1[Subbuffer 1<br/>2500 capacity]\n", + " E2[Env 2] --> B2[Subbuffer 2<br/>2500 capacity]\n", + " E3[Env 3] --> B3[Subbuffer 3<br/>2500 capacity]\n", + " E4[Env 4] --> B4[Subbuffer 4<br/>2500 capacity]\n", + " \n", + " B1 --> VRB[VectorReplayBuffer<br/>Total: 10000<br/>Unified Sampling]\n", + " B2 --> VRB\n", + " B3 --> VRB\n", + " B4 --> VRB\n", + " \n", + " VRB --> Policy[Policy Training]\n", + " \n", + " style E1 fill:#e1f5ff\n", + " style E2 fill:#e1f5ff\n", + " style E3 fill:#e1f5ff\n", + " style E4 fill:#e1f5ff\n", + " style B1 fill:#fff4e1\n", + " style B2 fill:#fff4e1\n", + " style B3 fill:#fff4e1\n", + " style B4 fill:#fff4e1\n", + " style VRB fill:#ffe1f5\n", + " style Policy fill:#e8f5e1\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create VectorReplayBuffer for 4 parallel environments\n", + "vec_buf = VectorReplayBuffer(\n", + " total_size=100, # Total capacity across all subbuffers\n", + " buffer_num=4, # Number of parallel environments\n", + ")\n", "\n", - "### Recurrent Neural Network Support\n", + "print(\"VectorReplayBuffer created:\")\n", + "print(f\"Total size: {vec_buf.maxsize}\")\n", + "print(f\"Number of subbuffers: {vec_buf.buffer_num}\")\n", + "print(f\"Size per subbuffer: {vec_buf.maxsize // vec_buf.buffer_num}\")\n", + "print(f\"Subbuffer edges: {vec_buf.subbuffer_edges}\")\n", + "print(\"\\nSubbuffer edges define the boundary indices: [0, 25, 50, 75, 100]\")\n", + "print(\"Subbuffer 0: indices 0-24, Subbuffer 1: indices 25-49, etc.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6.2 The buffer_ids Parameter\n", "\n", - "The buffer initialization accepts a `stack_num` parameter (default: 1) to enable frame stacking for recurrent neural network (RNN) integration in DRL algorithms. This feature facilitates temporal sequence processing by automatically stacking consecutive observations. Refer to the API documentation for configuration details and usage examples." + "This is one of the most confusing aspects for new users. The `buffer_ids` parameter specifies which subbuffer each transition belongs to."
] - } - ], - "metadata": { - "colab": { - "provenance": [] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Simulate data from 4 parallel environments\n", + "# Each environment produces one transition\n", + "parallel_batch = Batch(\n", + " obs=np.array([[0.1, 0.2], [1.1, 1.2], [2.1, 2.2], [3.1, 3.2]]), # 4 observations\n", + " act=np.array([0, 1, 0, 1]), # 4 actions\n", + " rew=np.array([1.0, 2.0, 3.0, 4.0]), # 4 rewards\n", + " terminated=np.array([False, False, False, False]),\n", + " truncated=np.array([False, False, False, False]),\n", + " obs_next=np.array([[0.2, 0.3], [1.2, 1.3], [2.2, 2.3], [3.2, 3.3]]),\n", + " info=np.array([{}, {}, {}, {}], dtype=object),\n", + ")\n", + "\n", + "print(\"Parallel batch shape:\", parallel_batch.obs.shape)\n", + "print(\"This represents 4 transitions, one from each environment\")\n", + "\n", + "# Add with buffer_ids specifying which subbuffer each transition goes to\n", + "indices, ep_rews, ep_lens, ep_starts = vec_buf.add(\n", + " parallel_batch,\n", + " buffer_ids=[0, 1, 2, 3], # Transition 0→Subbuf 0, 1→Subbuf 1, etc.\n", + ")\n", + "\n", + "print(f\"\\nAdded to indices: {indices}\")\n", + "print(\"Notice: Indices are in different subbuffers:\")\n", + "print(f\" Index {indices[0]} in subbuffer 0 (range 0-24)\")\n", + "print(f\" Index {indices[1]} in subbuffer 1 (range 25-49)\")\n", + "print(f\" Index {indices[2]} in subbuffer 2 (range 50-74)\")\n", + "print(f\" Index {indices[3]} in subbuffer 3 (range 75-99)\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Add more data to demonstrate buffer_ids\n", + "# Environments don't always produce data in order 0,1,2,3\n", + "# For example, if only environments 1 and 3 are ready:\n", + "partial_batch = Batch(\n", + " obs=np.array([[1.2, 1.3], [3.2, 3.3]]), # Only 2 observations\n", + " act=np.array([0, 1]),\n", + " rew=np.array([2.5, 4.5]),\n", + " 
terminated=np.array([False, False]),\n", + " truncated=np.array([False, False]),\n", + " obs_next=np.array([[1.3, 1.4], [3.3, 3.4]]),\n", + " info=np.array([{}, {}], dtype=object),\n", + ")\n", + "\n", + "# Only environments 1 and 3 produced data\n", + "indices2, _, _, _ = vec_buf.add(\n", + " partial_batch,\n", + " buffer_ids=[1, 3], # Only these two subbuffers receive data\n", + ")\n", + "\n", + "print(\"Added partial batch (only envs 1 and 3):\")\n", + "print(f\"Indices: {indices2}\")\n", + "print(f\"Subbuffer 1 received data at index {indices2[0]}\")\n", + "print(f\"Subbuffer 3 received data at index {indices2[1]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Important: buffer_ids Requirements**:\n", + "\n", + "For `VectorReplayBuffer`:\n", + "- `buffer_ids` length must match batch size\n", + "- Values must be in range [0, buffer_num)\n", + "- Can be partial (not all environments at once)\n", + "\n", + "For regular `ReplayBuffer`:\n", + "- If `buffer_ids` is not None, it must be [0]\n", + "- Batch must have shape (1, data_length)\n", + "- This is for API compatibility with VectorReplayBuffer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6.3 Subbuffer Edges and Episode Handling\n", + "\n", + "Subbuffer edges prevent episodes from spanning across subbuffers, ensuring data from different environments doesn't get mixed:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The subbuffer_edges property defines boundaries\n", + "print(f\"Subbuffer edges: {vec_buf.subbuffer_edges}\")\n", + "print(\"\\nThis creates 4 subbuffers:\")\n", + "for i in range(vec_buf.buffer_num):\n", + " start = vec_buf.subbuffer_edges[i]\n", + " end = vec_buf.subbuffer_edges[i + 1]\n", + " print(f\"Subbuffer {i}: indices [{start}, {end})\")\n", + "\n", + "# Episodes cannot cross these boundaries\n", + "# prev() and next() respect subbuffer edges just like 
episode boundaries\n", + "test_idx = np.array([24, 25, 49, 50]) # At subbuffer edges\n", + "prev_result = vec_buf.prev(test_idx)\n", + "next_result = vec_buf.next(test_idx)\n", + "\n", + "print(\"\\nBoundary navigation test:\")\n", + "print(f\"Indices: {test_idx}\")\n", + "print(f\"prev(): {prev_result}\")\n", + "print(f\"next(): {next_result}\")\n", + "print(\"\\nNotice: prev/next don't cross subbuffer boundaries\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6.4 Sampling from VectorReplayBuffer\n", + "\n", + "Sampling is uniform across all subbuffers (proportional to their current fill level):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Add more data to have enough for sampling\n", + "for _step in range(10):\n", + " batch = Batch(\n", + " obs=np.random.randn(4, 2),\n", + " act=np.random.randint(0, 2, size=4),\n", + " rew=np.random.random(4),\n", + " terminated=np.zeros(4, dtype=bool),\n", + " truncated=np.zeros(4, dtype=bool),\n", + " obs_next=np.random.randn(4, 2),\n", + " info=np.array([{}] * 4, dtype=object),\n", + " )\n", + " vec_buf.add(batch, buffer_ids=[0, 1, 2, 3])\n", + "\n", + "# Sample batch\n", + "sampled, indices = vec_buf.sample(batch_size=16)\n", + "print(f\"Sampled {len(sampled)} transitions\")\n", + "print(f\"Sample indices (from different subbuffers): {indices}\")\n", + "print(\"\\nNotice indices span across all subbuffer ranges\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. 
Specialized Buffer Variants\n", + "\n", + "### 7.1 PrioritizedReplayBuffer\n", + "\n", + "Implements prioritized experience replay where transitions are sampled based on their TD-error magnitudes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": "# Create prioritized buffer\nprio_buf = PrioritizedReplayBuffer(\n size=100,\n alpha=0.6, # Prioritization exponent (0=uniform, 1=fully prioritized)\n beta=0.4, # Importance sampling correction (annealed to 1)\n)\n\n# Add some transitions\nfor i in range(20):\n prio_buf.add(\n Batch(\n obs=np.array([i]),\n act=i % 4,\n rew=np.random.random(),\n terminated=False,\n truncated=False,\n obs_next=np.array([i + 1]),\n info={},\n )\n )\n\n# Sample returns batch and indices\n# Importance weights are INSIDE the batch as batch.weight\nbatch, indices = prio_buf.sample(batch_size=8)\nprint(f\"Sampled batch size: {len(batch)}\")\nprint(f\"Indices: {indices}\")\nprint(f\"Importance weights (batch.weight): {batch.weight}\")\nprint(\"\\nWeights are stored in batch.weight and compensate for biased sampling\")" + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# After computing TD-errors from the sampled batch, update priorities\n", + "# In practice, these would be actual TD-errors: |Q(s,a) - (r + γ*max Q(s',a'))|\n", + "fake_td_errors = np.random.random(len(indices)) * 10 # Simulated TD-errors\n", + "\n", + "# Update priorities (higher TD-error = higher priority)\n", + "prio_buf.update_weight(indices, fake_td_errors)\n", + "\n", + "print(\"Updated priorities based on TD-errors\")\n", + "print(\"Transitions with higher TD-errors will be sampled more frequently\")\n", + "\n", + "# Demonstrate beta annealing\n", + "prio_buf.set_beta(0.6) # Increase beta over training\n", + "print(f\"\\nAnnealed beta to: {prio_buf.options['beta']}\")\n", + "print(\"Beta typically starts at 0.4 and anneals to 1.0 over training\")" + ] + }, 
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**PrioritizedReplayBuffer Use Cases**:\n", + "- Rainbow DQN and variants\n", + "- Any algorithm where some transitions are more \"surprising\" and valuable\n", + "- Environments with rare but important events\n", + "\n", + "**Key Parameters**:\n", + "- `alpha`: Controls how much prioritization affects sampling (values near 0 approach uniform sampling, 1 = fully proportional to priority; must be > 0)\n", + "- `beta`: Importance sampling correction to remain unbiased (anneal from ~0.4 to 1.0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 7.2 Other Specialized Buffers\n", + "\n", + "**CachedReplayBuffer**: A main buffer plus a set of cached buffers (typically one per parallel environment)\n", + "- Use case: collecting from parallel environments while keeping each episode contiguous in the main buffer\n", + "- New transitions are first written to a cached buffer; when an episode finishes, it is moved into the main buffer and that cache is reset\n", + "- It only controls storage layout; it is not a mechanism for mixing data from different sources (e.g. expert and agent data)\n", + "\n", + "**HERReplayBuffer**: Hindsight Experience Replay for goal-conditioned tasks\n", + "- Use case: Sparse reward robotics tasks\n", + "- Relabels failed episodes with achieved goals as if they were intended\n", + "- Dramatically improves learning in goal-reaching tasks\n", + "- See the HER documentation for detailed examples\n", + "\n", + "For detailed usage of these specialized buffers, refer to the Tianshou API documentation and algorithm-specific tutorials." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8. 
Serialization and Persistence\n", + "\n", + "Buffers support multiple serialization formats for saving and loading data.\n", + "\n", + "### 8.1 Pickle Serialization\n", + "\n", + "The simplest method, preserving all buffer state including trajectory metadata:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create and populate a buffer\n", + "save_buf = ReplayBuffer(size=50)\n", + "for i in range(30):\n", + " save_buf.add(\n", + " Batch(\n", + " obs=np.array([i, i + 1]),\n", + " act=i % 4,\n", + " rew=float(i),\n", + " terminated=(i + 1) % 10 == 0,\n", + " truncated=False,\n", + " obs_next=np.array([i + 1, i + 2]),\n", + " info={\"step\": i},\n", + " )\n", + " )\n", + "\n", + "print(f\"Original buffer: {len(save_buf)} transitions\")\n", + "\n", + "# Serialize with pickle\n", + "pickled_data = pickle.dumps(save_buf)\n", + "print(f\"Serialized size: {len(pickled_data)} bytes\")\n", + "\n", + "# Deserialize\n", + "loaded_buf = pickle.loads(pickled_data)\n", + "print(f\"Loaded buffer: {len(loaded_buf)} transitions\")\n", + "print(f\"Data preserved: obs[0] = {loaded_buf.obs[0]}\")\n", + "print(f\"Metadata preserved: info[0] = {loaded_buf.info[0]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 8.2 HDF5 Serialization\n", + "\n", + "HDF5 is recommended for large datasets and cross-platform compatibility:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Save to HDF5\n", + "with tempfile.NamedTemporaryFile(suffix=\".hdf5\", delete=False) as tmp:\n", + " hdf5_path = tmp.name\n", + "\n", + "save_buf.save_hdf5(hdf5_path, compression=\"gzip\")\n", + "print(f\"Saved to HDF5: {hdf5_path}\")\n", + "\n", + "# Load from HDF5\n", + "loaded_hdf5_buf = ReplayBuffer.load_hdf5(hdf5_path)\n", + "print(f\"Loaded from HDF5: {len(loaded_hdf5_buf)} transitions\")\n", + "print(f\"Data matches: {np.array_equal(save_buf.obs, 
loaded_hdf5_buf.obs)}\")\n", + "\n", + "# Clean up\n", + "import os\n", + "\n", + "os.unlink(hdf5_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**When to Use HDF5**:\n", + "- Large datasets (> 1GB)\n", + "- Offline RL with pre-collected data\n", + "- Sharing data across platforms\n", + "- Need for compression\n", + "- Integration with external tools (many scientific tools read HDF5)\n", + "\n", + "**When to Use Pickle**:\n", + "- Quick saves during development\n", + "- Small buffers\n", + "- Python-only workflow\n", + "- Simpler serialization needs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 8.3 Loading from Raw Data with from_data()\n", + "\n", + "For offline RL, you can create a buffer from raw arrays:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Simulate pre-collected offline dataset\n", + "import h5py\n", + "\n", + "# Create temporary HDF5 file with raw data\n", + "with tempfile.NamedTemporaryFile(suffix=\".hdf5\", delete=False) as tmp:\n", + " offline_path = tmp.name\n", + "\n", + "with h5py.File(offline_path, \"w\") as f:\n", + " # Create datasets\n", + " n = 100\n", + " f.create_dataset(\"obs\", data=np.random.randn(n, 4))\n", + " f.create_dataset(\"act\", data=np.random.randint(0, 2, n))\n", + " f.create_dataset(\"rew\", data=np.random.randn(n))\n", + " f.create_dataset(\"terminated\", data=np.random.random(n) < 0.1)\n", + " f.create_dataset(\"truncated\", data=np.zeros(n, dtype=bool))\n", + " f.create_dataset(\"done\", data=np.random.random(n) < 0.1)\n", + " f.create_dataset(\"obs_next\", data=np.random.randn(n, 4))\n", + "\n", + "# Load into buffer\n", + "with h5py.File(offline_path, \"r\") as f:\n", + " offline_buf = ReplayBuffer.from_data(\n", + " obs=f[\"obs\"],\n", + " act=f[\"act\"],\n", + " rew=f[\"rew\"],\n", + " terminated=f[\"terminated\"],\n", + " truncated=f[\"truncated\"],\n", + " 
done=f[\"done\"],\n", + " obs_next=f[\"obs_next\"],\n", + " )\n", + "\n", + "print(f\"Loaded offline dataset: {len(offline_buf)} transitions\")\n", + "print(f\"Observation shape: {offline_buf.obs.shape}\")\n", + "\n", + "# Clean up\n", + "os.unlink(offline_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is the standard approach for offline RL where you have pre-collected datasets from other sources." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 9. Integration with the RL Pipeline\n", + "\n", + "Understanding how buffers integrate with other Tianshou components is essential for effective usage.\n", + "\n", + "### 9.1 Data Flow in RL Training\n", + "\n", + "```mermaid\n", + "graph LR\n", + " ENV[Vectorized<br/>Environments] -->|observations| COL[Collector]\n", + " POL[Policy] -->|actions| COL\n", + " COL -->|transitions| BUF[Buffer]\n", + " BUF -->|sampled batches| POL\n", + " POL -->|forward pass| ALG[Algorithm]\n", + " ALG -->|loss & gradients| POL\n", + " \n", + " style ENV fill:#e1f5ff\n", + " style COL fill:#fff4e1\n", + " style BUF fill:#ffe1f5\n", + " style POL fill:#e8f5e1\n", + " style ALG fill:#f5e1e1\n", + "```\n", + "\n", + "### 9.2 Typical Training Loop Pattern\n", + "\n", + "Here's how buffers are typically used in a training loop:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Pseudocode for typical RL training loop\n", + "# (This is illustrative; actual implementation would use Trainer)\n", + "\n", + "\n", + "def training_loop_pseudocode():\n", + " \"\"\"\n", + " Illustrative training loop showing buffer integration.\n", + "\n", + " In practice, use Tianshou's Trainer class which handles this.\n", + " \"\"\"\n", + " # Setup (illustration only)\n", + " # env = make_vectorized_env(num_envs=8)\n", + " # policy = make_policy()\n", + " # buffer = VectorReplayBuffer(total_size=100000, buffer_num=8)\n", + " # collector = Collector(policy, env, buffer)\n", + "\n", + " # Training loop\n", + " # for epoch in range(num_epochs):\n", + " # # 1. Collect data from environments\n", + " # collect_result = collector.collect(n_step=1000)\n", + " # # Collector automatically adds transitions to buffer with correct buffer_ids\n", + " #\n", + " # # 2. 
Train on multiple batches\n", + " # for _ in range(update_per_collect):\n", + " # # Sample batch from buffer\n", + " # batch, indices = buffer.sample(batch_size=256)\n", + " #\n", + " # # Compute loss and update policy\n", + " # loss = policy.learn(batch)\n", + " #\n", + " # # For prioritized buffers, update priorities\n", + " # # if isinstance(buffer, PrioritizedReplayBuffer):\n", + " # # buffer.update_weight(indices, td_errors)\n", + "\n", + " print(\"This pseudocode illustrates the buffer's role:\")\n", + " print(\"1. Collector fills buffer from environment interaction\")\n", + " print(\"2. Buffer provides random samples for training\")\n", + " print(\"3. Policy learns from sampled batches\")\n", + " print(\"\\nIn practice, use Tianshou's Trainer for this workflow\")\n", + "\n", + "\n", + "training_loop_pseudocode()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 9.3 Collector Integration\n", + "\n", + "The Collector class handles the complexity of:\n", + "- Calling policy to get actions\n", + "- Stepping environments\n", + "- Adding transitions to buffer with correct buffer_ids\n", + "- Tracking episode statistics\n", + "\n", + "When you create a Collector for vectorized environments, it:\n", + "- Creates a VectorReplayBuffer automatically if you don't pass a buffer; a buffer you do pass must be a VectorReplayBuffer with at least one subbuffer per environment (a plain ReplayBuffer is rejected for more than one environment)\n", + "- Sets buffer_ids based on which environments are ready\n", + "- Handles episode resets and boundary tracking\n", + "\n", + "See the Collector tutorial for detailed examples of this integration." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 10. Advanced Topics and Edge Cases\n", + "\n", + "### 10.1 Buffer Overflow and Episode Boundaries\n", + "\n", + "What happens when the buffer fills up mid-episode?"
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Small buffer to demonstrate overflow\n", + "overflow_buf = ReplayBuffer(size=8)\n", + "\n", + "# Add a long episode (12 steps, buffer size is only 8)\n", + "print(\"Adding 12-step episode to buffer with size 8:\")\n", + "for i in range(12):\n", + " idx, ep_rew, ep_len, ep_start = overflow_buf.add(\n", + " Batch(\n", + " obs=i,\n", + " act=0,\n", + " rew=1.0,\n", + " terminated=i == 11,\n", + " truncated=False,\n", + " obs_next=i + 1,\n", + " info={},\n", + " )\n", + " )\n", + " if i in [7, 11]:\n", + " print(f\" Step {i}: idx={idx}, buffer_len={len(overflow_buf)}\")\n", + "\n", + "print(\"\\nFinal buffer contents (most recent 8 steps):\")\n", + "print(f\"Observations: {overflow_buf.obs[: len(overflow_buf)]}\")\n", + "print(f\"Episode return: {ep_rew[0]} (sum of all 12 steps, tracked correctly!)\")\n", + "print(\"\\nNote: Buffer overwrote old data but episode statistics are still correct\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Important**: Episode returns and lengths are tracked internally and remain correct even when the episode spans buffer overflows. The buffer maintains `_ep_return`, `_ep_len`, and `_ep_start_idx` to track ongoing episodes." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 10.2 Episode Spanning Subbuffer Edges\n", + "\n", + "In VectorReplayBuffer, episodes can wrap around within their subbuffer:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create small VectorReplayBuffer to demonstrate edge crossing\n", + "edge_buf = VectorReplayBuffer(total_size=20, buffer_num=2) # 10 per subbuffer\n", + "\n", + "print(f\"Subbuffer edges: {edge_buf.subbuffer_edges}\")\n", + "print(\"Subbuffer 0: indices 0-9, Subbuffer 1: indices 10-19\\n\")\n", + "\n", + "# Fill subbuffer 0 with 12 steps (wraps around since capacity is 10)\n", + "for i in range(12):\n", + " batch = Batch(\n", + " obs=np.array([[i]]),\n", + " act=np.array([0]),\n", + " rew=np.array([1.0]),\n", + " terminated=np.array([i == 11]),\n", + " truncated=np.array([False]),\n", + " obs_next=np.array([[i + 1]]),\n", + " info=np.array([{}], dtype=object),\n", + " )\n", + " idx, _, _, _ = edge_buf.add(batch, buffer_ids=[0])\n", + " if i >= 10:\n", + " print(f\"Step {i} added at index {idx[0]} (wrapped around in subbuffer 0)\")\n", + "\n", + "# get_buffer_indices handles this correctly\n", + "episode_indices = edge_buf.get_buffer_indices(start=8, stop=2) # Crosses edge\n", + "print(f\"\\nEpisode spanning edge (from 8 to 1): {episode_indices}\")\n", + "print(\"Correctly retrieves [8, 9, 0, 1] within subbuffer 0\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 10.3 ignore_obs_next Memory Optimization\n", + "\n", + "For memory-constrained scenarios, you can avoid storing obs_next:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Buffer that doesn't store obs_next\n", + "memory_buf = ReplayBuffer(size=10, ignore_obs_next=True)\n", + "\n", + "# Add transitions (obs_next is ignored)\n", + "for i in range(5):\n", + " memory_buf.add(\n", + " Batch(\n", + 
" obs=np.array([i, i + 1]),\n", + " act=i,\n", + " rew=1.0,\n", + " terminated=False,\n", + " truncated=False,\n", + " obs_next=np.array([i + 1, i + 2]), # Provided but not stored\n", + " info={},\n", + " )\n", + " )\n", + "\n", + "# When sampling, obs_next is reconstructed from next obs\n", + "sample, _ = memory_buf.sample(batch_size=1)\n", + "print(f\"Sampled obs: {sample.obs}\")\n", + "print(f\"Sampled obs_next: {sample.obs_next}\")\n", + "print(\"\\nobs_next was reconstructed, not stored directly\")\n", + "print(\"This saves memory at the cost of slightly more complex retrieval\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is particularly useful for Atari environments with large observation spaces (84x84x4 frames)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 11. Surprising Behaviors and Gotchas\n", + "\n", + "### 11.1 Most Common Mistake: buffer_ids Confusion\n", + "\n", + "The buffer_ids parameter is the most common source of errors:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# COMMON ERROR 1: Forgetting buffer_ids with VectorReplayBuffer\n", + "vec_demo = VectorReplayBuffer(total_size=100, buffer_num=4)\n", + "\n", + "parallel_data = Batch(\n", + " obs=np.random.randn(4, 2),\n", + " act=np.array([0, 1, 0, 1]),\n", + " rew=np.array([1.0, 2.0, 3.0, 4.0]),\n", + " terminated=np.array([False, False, False, False]),\n", + " truncated=np.array([False, False, False, False]),\n", + " obs_next=np.random.randn(4, 2),\n", + " info=np.array([{}, {}, {}, {}], dtype=object),\n", + ")\n", + "\n", + "# WRONG: Omitting buffer_ids (defaults to [0,1,2,3] which is OK here)\n", + "# But if you have partial data, this will fail\n", + "vec_demo.add(parallel_data) # Works by default\n", + "\n", + "# CORRECT: Always explicit\n", + "vec_demo.add(parallel_data, buffer_ids=[0, 1, 2, 3])\n", + "print(\"Always specify buffer_ids explicitly for 
clarity\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# COMMON ERROR 2: Shape mismatch with buffer_ids\n", + "try:\n", + " # Trying to add 2 transitions but specifying 4 buffer_ids\n", + " wrong_batch = Batch(\n", + " obs=np.random.randn(2, 2), # Only 2 transitions!\n", + " act=np.array([0, 1]),\n", + " rew=np.array([1.0, 2.0]),\n", + " terminated=np.array([False, False]),\n", + " truncated=np.array([False, False]),\n", + " obs_next=np.random.randn(2, 2),\n", + " info=np.array([{}, {}], dtype=object),\n", + " )\n", + " vec_demo.add(wrong_batch, buffer_ids=[0, 1, 2, 3]) # MISMATCH!\n", + "except (IndexError, ValueError) as e:\n", + " print(f\"Error caught: {type(e).__name__}\")\n", + " print(\"Lesson: buffer_ids length must match batch size\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 11.2 Done Flag Confusion\n", + "\n", + "Never manually set the `done` flag:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# WRONG: Manually setting done\n", + "wrong_batch = Batch(\n", + " obs=1,\n", + " act=0,\n", + " rew=1.0,\n", + " terminated=True,\n", + " truncated=False,\n", + " # done=True, # DON'T DO THIS! 
It will be overwritten anyway\n", + " obs_next=2,\n", + " info={},\n", + ")\n", + "\n", + "# CORRECT: Only set terminated and truncated\n", + "# done is automatically computed as (terminated OR truncated)\n", + "correct_batch = Batch(\n", + " obs=1,\n", + " act=0,\n", + " rew=1.0,\n", + " terminated=True, # Episode ended naturally\n", + " truncated=False, # Not cut off\n", + " obs_next=2,\n", + " info={},\n", + ")\n", + "\n", + "demo = ReplayBuffer(size=10)\n", + "demo.add(correct_batch)\n", + "print(f\"Terminated: {demo.terminated[0]}\")\n", + "print(f\"Truncated: {demo.truncated[0]}\")\n", + "print(f\"Done (auto-computed): {demo.done[0]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 11.3 Sampling from Empty or Near-Empty Buffers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Edge case: Sampling more than available\n", + "small_buf = ReplayBuffer(size=100)\n", + "for i in range(5): # Only 5 transitions\n", + " small_buf.add(\n", + " Batch(obs=i, act=0, rew=1.0, terminated=False, truncated=False, obs_next=i + 1, info={})\n", + " )\n", + "\n", + "# Request 20 but only 5 available - samples with replacement\n", + "batch, indices = small_buf.sample(batch_size=20)\n", + "print(f\"Requested 20, buffer has {len(small_buf)}, got {len(batch)}\")\n", + "print(f\"Indices: {indices}\")\n", + "print(\"Notice: Some indices repeat (sampling with replacement)\")\n", + "\n", + "# Defensive pattern: Check buffer size\n", + "if len(small_buf) >= 128:\n", + " batch, _ = small_buf.sample(128)\n", + "else:\n", + " print(f\"Buffer has {len(small_buf)} < 128, waiting for more data\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 11.4 Frame Stacking Valid Indices\n", + "\n", + "With stack_num > 1, not all indices are valid for sampling:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + 
"# With frame stacking, early indices can't form complete stacks\n", + "stack_demo = ReplayBuffer(size=20, stack_num=4, sample_avail=True)\n", + "\n", + "for i in range(10):\n", + " stack_demo.add(\n", + " Batch(\n", + " obs=np.array([i]),\n", + " act=0,\n", + " rew=1.0,\n", + " terminated=i == 9,\n", + " truncated=False,\n", + " obs_next=np.array([i + 1]),\n", + " info={},\n", + " )\n", + " )\n", + "\n", + "# With sample_avail=True, only valid indices are sampled\n", + "sampled, indices = stack_demo.sample(batch_size=5)\n", + "print(f\"Sampled indices with stack_num=4, sample_avail=True: {indices}\")\n", + "print(\"All indices >= 3 (can form complete 4-frame stacks)\")\n", + "\n", + "# Without sample_avail, any index can be sampled (may have incomplete stacks)\n", + "stack_demo2 = ReplayBuffer(size=20, stack_num=4, sample_avail=False)\n", + "for i in range(10):\n", + " stack_demo2.add(\n", + " Batch(\n", + " obs=np.array([i]),\n", + " act=0,\n", + " rew=1.0,\n", + " terminated=False,\n", + " truncated=False,\n", + " obs_next=np.array([i + 1]),\n", + " info={},\n", + " )\n", + " )\n", + "\n", + "sampled2, indices2 = stack_demo2.sample(batch_size=5)\n", + "print(f\"\\nSampled indices with sample_avail=False: {indices2}\")\n", + "print(\"May include indices < 3 (incomplete stacks repeated from boundary)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 12. Best Practices\n", + "\n", + "### 12.1 Choosing the Right Buffer\n", + "\n", + "**Decision Tree**:\n", + "\n", + "1. Are you using parallel environments?\n", + " - Yes → Use `VectorReplayBuffer`\n", + " - No → Continue to 2\n", + "\n", + "2. Do you need prioritized experience replay?\n", + " - Yes → Use `PrioritizedReplayBuffer` or `PrioritizedVectorReplayBuffer`\n", + " - No → Continue to 3\n", + "\n", + "3. Is it goal-conditioned RL with sparse rewards?\n", + " - Yes → Use `HERReplayBuffer` or `HERVectorReplayBuffer`\n", + " - No → Continue to 4\n", + "\n", + "4. 
Do you want each environment's transitions held in a per-environment cache and moved into a main buffer only once its episode finishes?\n", + " - Yes → Use `CachedReplayBuffer`\n", + " - No → Use `ReplayBuffer` (single env) or `VectorReplayBuffer` (standard choice)\n", + "\n", + "**Most Common Setup**: `VectorReplayBuffer` for production training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 12.2 Buffer Sizing Guidelines\n", + "\n", + "**Rule of Thumb by Domain**:\n", + "\n", + "- **Atari games**: 1,000,000 transitions (1e6)\n", + "- **Continuous control (MuJoCo)**: 100,000-1,000,000 (1e5-1e6)\n", + "- **Robotics**: 100,000-500,000 (1e5-5e5)\n", + "- **Simple environments (CartPole)**: 10,000-50,000 (1e4-5e4)\n", + "\n", + "**Factors to Consider**:\n", + "- Available RAM (each transition ~observation_size * 2 + metadata)\n", + "- Training time vs sample efficiency tradeoff\n", + "- Algorithm requirements (some need larger buffers)\n", + "\n", + "**Memory Estimation**:\n", + "```python\n", + "# For environments with observation shape (84, 84, 4) (Atari):\n", + "# Each transition: 2 * 84 * 84 * 4 bytes (obs + obs_next) + ~100 bytes overhead\n", + "# = ~56KB per transition\n", + "# 1M transitions = ~56GB (use ignore_obs_next to halve this!)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 12.3 Configuration Best Practices\n", + "\n", + "**When to use stack_num > 1**:\n", + "- RNN/LSTM policies need temporal context\n", + "- Frame-based policies (Atari with 4-frame stacking)\n", + "- Velocity estimation from positions\n", + "\n", + "**When to use ignore_obs_next=True**:\n", + "- Memory-constrained environments\n", + "- Atari (large observation spaces)\n", + "- When obs_next can be reconstructed from next obs\n", + "\n", + "**When to use save_only_last_obs=True**:\n", + "- Atari with temporal stacking in environment wrapper\n", + "- When observations already contain frame history\n", + "\n", + "**When to use sample_avail=True**:\n", + "- Always use with stack_num > 1 for
correctness\n", + "- Ensures samples have complete frame stacks\n", + "- Small performance cost but worth it for data quality" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 12.4 Integration Patterns\n", + "\n", + "**Pattern 1: Standard Off-Policy Setup**\n", + "```python\n", + "# env = make_vectorized_env(num_envs=8)\n", + "# buffer = VectorReplayBuffer(total_size=100000, buffer_num=8)\n", + "# policy = SACPolicy(...)\n", + "# collector = Collector(policy, env, buffer)\n", + "# \n", + "# # Collect and train\n", + "# collector.collect(n_step=1000)\n", + "# for _ in range(10):\n", + "# batch, indices = buffer.sample(256)\n", + "# policy.learn(batch)\n", + "```\n", + "\n", + "**Pattern 2: Pre-fill Buffer Before Training**\n", + "```python\n", + "# # Collect random exploration data\n", + "# collector.collect(n_step=10000) # Fill buffer\n", + "# \n", + "# # Then start training\n", + "# while not converged:\n", + "# collector.collect(n_step=100)\n", + "# for _ in range(10):\n", + "# batch, _ = buffer.sample(256)\n", + "# policy.learn(batch)\n", + "```\n", + "\n", + "**Pattern 3: Offline RL**\n", + "```python\n", + "# # Load pre-collected dataset\n", + "# buffer = ReplayBuffer.load_hdf5(\"expert_data.hdf5\")\n", + "# \n", + "# # Train without further collection\n", + "# for epoch in range(num_epochs):\n", + "# for _ in range(updates_per_epoch):\n", + "# batch, _ = buffer.sample(256)\n", + "# policy.learn(batch)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 12.5 Performance Tips\n", + "\n", + "**Tip 1: Pre-allocate buffer size appropriately**\n", + "- Don't make buffer too large (wastes memory)\n", + "- Don't make it too small (loses important old experiences)\n", + "- Start with domain defaults and adjust based on performance\n", + "\n", + "**Tip 2: Use HDF5 for large offline datasets**\n", + "- Compression saves disk space\n", + "- Faster loading than pickle for large files\n", + "- Better for
sharing across systems\n", + "\n", + "**Tip 3: Batch sampling efficiently**\n", + "- Sample once and use multiple times if possible\n", + "- Don't sample more than you need\n", + "- For multi-GPU training, sample once and split\n", + "\n", + "**Tip 4: Monitor buffer usage**\n", + "```python\n", + "# print(f\"Buffer usage: {len(buffer)}/{buffer.maxsize}\")\n", + "# if len(buffer) < batch_size:\n", + "# print(\"Warning: Sampling with replacement!\")\n", + "```\n", + "\n", + "**Tip 5: Consider ignore_obs_next for large observation spaces**\n", + "- Can halve memory usage\n", + "- Small computational overhead on sampling\n", + "- Especially valuable for image-based RL" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "## 13. Quick Reference\n\n### Method Summary\n\n| Method | Purpose | Returns | Notes |\n|--------|---------|---------|-------|\n| `add(batch, buffer_ids)` | Add transition(s) | `(idx, ep_rew, ep_len, ep_start)` | ep_rew/ep_len only non-zero when done=True |\n| `sample(size)` | Random sample | `(batch, indices)` | size=None for all (random), 0 for all (ordered) |\n| `prev(idx)` | Previous in episode | `indices` | Stops at episode boundaries |\n| `next(idx)` | Next in episode | `indices` | Stops at episode boundaries |\n| `get(idx, key, stack_num)` | Get with stacking | `data` | Returns stacked frames if stack_num > 1 |\n| `get_buffer_indices(start, stop)` | Episode range | `indices` | Handles edge-crossing episodes |\n| `unfinished_index()` | Ongoing episodes | `indices` | Returns last step of unfinished episodes |\n| `save_hdf5(path)` | Save to HDF5 | - | Recommended for large datasets |\n| `load_hdf5(path)` | Load from HDF5 | `buffer` | Class method |\n| `from_data(...)` | Create from arrays | `buffer` | For offline RL datasets |\n| `reset()` | Clear buffer | - | Optionally keep episode statistics |\n| `sample_indices(size)` | Get indices only | `indices` | For custom sampling logic |\n\n### Common Patterns Cheatsheet\n\n**Single 
Environment**:\n```python\nbuffer = ReplayBuffer(size=10000)\nbuffer.add(Batch(obs=..., act=..., rew=..., terminated=..., truncated=..., obs_next=..., info={}))\nbatch, indices = buffer.sample(batch_size=256)\n```\n\n**Parallel Environments**:\n```python\nbuffer = VectorReplayBuffer(total_size=100000, buffer_num=8)\nbuffer.add(parallel_batch, buffer_ids=[0,1,2,3,4,5,6,7])\nbatch, indices = buffer.sample(batch_size=256)\n```\n\n**Frame Stacking**:\n```python\nbuffer = ReplayBuffer(size=100000, stack_num=4, sample_avail=True)\nstacked_obs = buffer.get(index=50, key=\"obs\") # Returns 4 stacked frames\n```\n\n**Prioritized Replay**:\n```python\nbuffer = PrioritizedReplayBuffer(size=100000, alpha=0.6, beta=0.4)\nbatch, indices = buffer.sample(batch_size=256)\nweights = batch.weight # Importance weights are inside the batch\n# ... compute TD errors ...\nbuffer.update_weight(indices, td_errors)\n```\n\n**Offline RL**:\n```python\nbuffer = ReplayBuffer.load_hdf5(\"dataset.hdf5\")\n# Or:\nwith h5py.File(\"dataset.hdf5\", \"r\") as f:\n buffer = ReplayBuffer.from_data(obs=f[\"obs\"], act=f[\"act\"], ...)\n```\n\n**Episode Retrieval**:\n```python\n# Find episode boundaries, then:\nepisode_indices = buffer.get_buffer_indices(start=ep_start_idx, stop=ep_end_idx+1)\nepisode = buffer[episode_indices]\n```" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary and Next Steps\n", + "\n", + "This tutorial covered Tianshou's buffer system comprehensively:\n", + "\n", + "1. **Buffer fundamentals**: Why buffers are essential for RL\n", + "2. **Buffer hierarchy**: Understanding different buffer types\n", + "3. **Basic operations**: Construction, configuration, and data management\n", + "4. **Trajectory management**: Episode tracking and boundary navigation\n", + "5. **Sampling strategies**: Basic sampling and frame stacking\n", + "6. **VectorReplayBuffer**: Critical for parallel environments\n", + "7. 
**Specialized buffers**: Prioritized, cached, and HER variants\n", + "8. **Serialization**: Pickle and HDF5 persistence\n", + "9. **Integration**: How buffers fit in the RL pipeline\n", + "10. **Advanced topics**: Edge cases and overflow handling\n", + "11. **Gotchas**: Common mistakes and how to avoid them\n", + "12. **Best practices**: Configuration, sizing, and performance\n", + "13. **Quick reference**: Method summary and common patterns\n", + "\n", + "### Next Steps\n", + "\n", + "- **Collector Deep Dive**: Learn how Collector fills buffers from environments\n", + "- **Policy Tutorial**: Understand how policies sample from buffers for training\n", + "- **Algorithm Examples**: See buffer usage in specific algorithms (DQN, SAC, PPO)\n", + "- **API Reference**: Full details at [Buffer API documentation](https://tianshou.org/en/stable/api/tianshou.data.html)\n", + "\n", + "### Further Resources\n", + "\n", + "- [Tianshou GitHub](https://github.com/thu-ml/tianshou) for source code and examples\n", + "- [Gymnasium Documentation](https://gymnasium.farama.org/) for environment conventions\n", + "- Research papers on experience replay and prioritized sampling" + ] + } + ], + "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -459,7 +1818,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.11.0" } }, "nbformat": 4, diff --git a/pyproject.toml b/pyproject.toml index 78a22687d..c84b87142 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -162,6 +162,7 @@ select = [ "ASYNC", "B", "C4", "C90", "COM", "D", "DTZ", "E", "F", "FLY", "G", "I", "ISC", "PIE", "PLC", "PLE", "PLW", "RET", "RUF", "RSE", "SIM", "TID", "UP", "W", "YTT", ] ignore = [ + "RUF003", # ambiguous unicode characters (e.g. Greek letters) in comments "SIM118", # Needed b/c iter(batch) != iter(batch.keys()). See https://github.com/thu-ml/tianshou/issues/922 "E501", # line too long.
ruff does a good enough job "E741", # variable names like "l". this isn't a huge problem