Lower level access to hl world #1187


Merged
merged 8 commits into from Aug 6, 2024
4 changes: 4 additions & 0 deletions docs/_config.yml
@@ -102,6 +102,10 @@ sphinx:
recursive_update : false # A boolean indicating whether to overwrite the Sphinx config (true) or recursively update (false)
config : # key-value pairs to directly over-ride the Sphinx configuration
autodoc_typehints_format: "short"
autodoc_member_order: "bysource"
autoclass_content: "both"
autodoc_default_options:
show-inheritance: True
html_js_files:
# We have to list them explicitly because they need to be loaded in a specific order
- js/vega@5.js
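
For context, a minimal sketch (with a hypothetical class, not from the codebase) of what the `autoclass_content: "both"` option added above does: Sphinx renders the class docstring and the `__init__` docstring together, so parameter documentation can live on `__init__`:

class ReplayBufferSketch:
    """One-line class summary, rendered first."""

    def __init__(self, size: int) -> None:
        """:param size: maximum number of transitions to store; with
        autoclass_content "both", this is rendered below the class summary.
        """
        self.size = size
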
9 changes: 9 additions & 0 deletions docs/spelling_wordlist.txt
@@ -271,3 +271,12 @@ v_s_
obs
obs_next
dtype
entrypoint
interquantile
init
kwarg
kwargs
autocompletion
codebase
indexable
sliceable
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -176,7 +176,10 @@ ignore = [
"RET505",
"D106", # undocumented public nested class
"D205", # blank line after summary (prevents summary-only docstrings, which makes no sense)
"D212", # no blank line after """. This clashes with sphinx for multiline descriptions of :param: that start directly after """
"PLW2901", # overwrite vars in loop
"B027", # empty and non-abstract method in abstract class
"D404", # It's fine to start with "This" in docstrings
]
unfixable = [
"F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all
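
For context on the D212 entry above, this is the docstring style used throughout this PR (a hypothetical function; the `:param:` description starts directly after the opening quotes, with no summary line):

def example(venv_type):
    """:param venv_type: a multiline description that starts directly after
    the opening quotes and continues on the next line, which is the Sphinx
    style this rule would otherwise clash with.
    """
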
44 changes: 44 additions & 0 deletions tianshou/data/batch.py
@@ -1,3 +1,47 @@
"""This module implements :class:`Batch`, a flexible data structure for
handling heterogeneous data in reinforcement learning algorithms. Such a data structure
is needed since RL algorithms differ widely in the conceptual fields that they need.
`Batch` is the main data carrier in Tianshou. It bears some similarities to
`TensorDict <https://github.com/pytorch/tensordict>`_
that is used for a similar purpose in `pytorch-rl <https://github.com/pytorch/rl>`_.
The main differences between the two are that `Batch` can hold arbitrary objects (and not just torch tensors),
and that Tianshou implements `BatchProtocol` for enabling type checking and autocompletion (more on that below).

The `Batch` class is designed to store and manipulate collections of data with
varying types and structures. It strikes a balance between flexibility and type safety, the latter mainly
achieved through the use of protocols. One can think of it as a mixture of a dictionary and an array,
as it has both key-value pairs and nesting, while also having a shape, being indexable and sliceable.

Key features of the `Batch` class include:

1. Flexible data storage: Can hold numpy arrays, torch tensors, scalars, and nested Batch objects.
2. Dynamic attribute access: Allows setting and accessing data using attribute notation (e.g., `batch.observation`).
This allows for type-safe and readable code and enables IDE autocompletion. See comments on `BatchProtocol` below.
3. Indexing and slicing: Supports numpy-like indexing and slicing operations. The slicing is extended to nested
Batch objects and torch Distributions.
4. Batch operations: Provides methods for splitting, shuffling, concatenating and stacking multiple Batch objects.
5. Data type conversion: Offers methods to convert data between numpy arrays and torch tensors.
6. Value transformations: Allows applying functions to all values in the Batch recursively.
7. Analysis utilities: Provides methods for checking for missing values, dropping entries with missing values,
and others.

Since we want to keep `Batch` flexible and not fix a specific set of fields or their types,
we don't have fixed interfaces for actual `Batch` objects that are used throughout
tianshou (such interfaces could be dataclasses, for example). However, we still want to enable
IDE autocompletion and type checking for `Batch` objects. To achieve this, we rely on dynamic duck typing
by using `Protocol`. The :class:`BatchProtocol` defines the interface that all `Batch` objects should adhere to,
and its various implementations (like :class:`~.types.ActBatchProtocol` or :class:`~.types.RolloutBatchProtocol`) define the specific
fields that are expected in the respective `Batch` objects. The protocols are then used as type hints
throughout the codebase. Protocols can't be instantiated, but we can cast to them.
For example, we "instantiate" an `ActBatchProtocol` with something like:

>>> act_batch = cast(ActBatchProtocol, Batch(act=my_action))

Users can decide for themselves how to structure their `Batch` objects and can opt in to the
`BatchProtocol` style to enable type checking and autocompletion. Opting out will have no effect on
the functionality.
"""

import pprint
import warnings
from collections.abc import Callable, Collection, Iterable, Iterator, KeysView, Sequence
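
To make the documented features concrete, here is a minimal, hedged usage sketch (a standard Tianshou install is assumed; `my_action` in the doctest above can be any array-like):

from typing import cast

import numpy as np

from tianshou.data import Batch
from tianshou.data.types import ActBatchProtocol

# Flexible storage and nesting: numpy arrays plus a nested Batch.
b = Batch(obs=np.zeros((4, 3)), info=Batch(env_id=np.arange(4)))

# Numpy-like slicing, applied recursively to nested Batch objects.
print(b[:2].info.env_id)  # -> [0 1]

# Batch operations: concatenation along the first axis.
b2 = Batch.cat([b, b])  # b2.obs.shape == (8, 3)

# Opting in to the BatchProtocol style via cast, as described above.
act_batch = cast(ActBatchProtocol, Batch(act=np.array([1, 0, 1])))
print(act_batch.act)
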
8 changes: 3 additions & 5 deletions tianshou/data/types.py
@@ -9,10 +9,6 @@
TNestedDictValue = np.ndarray | dict[str, "TNestedDictValue"]


d: dict[str, TNestedDictValue] = {"a": {"b": np.array([1, 2, 3])}}
d["c"] = np.array([1, 2, 3])


class ObsBatchProtocol(BatchProtocol, Protocol):
"""Observations of an environment that a policy can turn into actions.

@@ -62,6 +58,8 @@ class ActStateBatchProtocol(ActBatchProtocol, Protocol):
"""Contains action and state (which can be None), useful for policies that can support RNNs."""

state: dict | BatchProtocol | np.ndarray | None
"""Hidden state of RNNs, or None if not using RNNs. Used for recurrent policies.
At the moment, support for recurrent policies is experimental!"""


class ModelOutputBatchProtocol(ActStateBatchProtocol, Protocol):
@@ -121,7 +119,7 @@ class QuantileRegressionBatchProtocol(ModelOutputBatchProtocol, Protocol):


class ImitationBatchProtocol(ActBatchProtocol, Protocol):
"""Similar to other batches, but contains imitation_logits and q_value fields."""
"""Similar to other batches, but contains `imitation_logits` and `q_value` fields."""

state: dict | Batch | np.ndarray | None
q_value: torch.Tensor
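
A hedged sketch of how these protocols are used in practice (the field values are made up; only the protocol and field names come from this file):

from typing import cast

import numpy as np

from tianshou.data import Batch
from tianshou.data.types import ActStateBatchProtocol

# A recurrent policy's forward output: actions plus the RNN hidden state.
# For non-recurrent policies, state would simply be None.
result = cast(
    ActStateBatchProtocol,
    Batch(act=np.array([0, 1]), state=np.zeros((2, 64))),
)
assert result.state is not None
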
15 changes: 10 additions & 5 deletions tianshou/evaluation/launcher.py
@@ -27,18 +27,22 @@ class JoblibConfig:


class ExpLauncher(ABC):
"""Base interface for launching multiple experiments simultaneously."""

def __init__(
self,
experiment_runner: Callable[
[Experiment],
InfoStats | None,
] = lambda exp: exp.run().trainer_result,
):
""":param experiment_runner: can be used to override the default way in which an experiment is executed.
Can be useful e.g., if one wants to use the high-level interfaces to setup an experiment (or an experiment
collection) and tinker with it prior to execution. This need often arises when prototyping with mechanisms
that are not yet supported by the high-level interfaces.
Passing this allows arbitrary things to happen during experiment execution, so use it with caution!
"""
:param experiment_runner: determines how an experiment is to be executed.
Overriding the default can be useful, e.g., for using high-level interfaces
to set up an experiment (or an experiment collection) and tinkering with it prior to execution.
This need often arises when prototyping with mechanisms that are not yet supported by
the high-level interfaces.
Allows arbitrary things to happen during experiment execution, so use it with caution!
"""
self.experiment_runner = experiment_runner

@@ -112,6 +116,7 @@ def __init__(
super().__init__(experiment_runner=experiment_runner)
self.joblib_cfg = copy(joblib_cfg) if joblib_cfg is not None else JoblibConfig()
# Joblib's backend is hard-coded to loky since the threading backend produces different results
# TODO: fix this
if self.joblib_cfg.backend != "loky":
log.warning(
f"Ignoring the user provided joblib backend {self.joblib_cfg.backend} and using loky instead. "
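
A hedged sketch of the `experiment_runner` override described in the docstring above; the default is the `lambda exp: exp.run().trainer_result` shown in the diff, and a custom runner may tinker with the experiment before executing it:

def patched_runner(exp):
    # Hypothetical tweak prior to execution, e.g. adjusting something the
    # high-level interfaces do not yet expose, then running as usual.
    return exp.run().trainer_result

# Passed to any concrete launcher; `JoblibExpLauncher` is an assumed
# subclass name here:
# launcher = JoblibExpLauncher(experiment_runner=patched_runner)
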
10 changes: 1 addition & 9 deletions tianshou/highlevel/agent.py
@@ -125,15 +125,6 @@ def create_train_test_collector(
if reset_collectors:
train_collector.reset()
test_collector.reset()

if self.sampling_config.start_timesteps > 0:
log.info(
f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})",
)
train_collector.collect(
n_step=self.sampling_config.start_timesteps,
random=self.sampling_config.start_timesteps_random,
)
return train_collector, test_collector

def set_policy_wrapper_factory(
@@ -200,6 +191,7 @@ def create_trainer(
batch_size=sampling_config.batch_size,
step_per_collect=sampling_config.step_per_collect,
save_best_fn=policy_persistence.get_save_best_fn(world),
save_checkpoint_fn=policy_persistence.get_save_checkpoint_fn(world),
logger=world.logger,
test_in_train=False,
train_fn=train_fn,
15 changes: 7 additions & 8 deletions tianshou/highlevel/env.py
@@ -361,11 +361,11 @@ def create_venv(


class EnvFactory(ToStringMixin, ABC):
"""Main interface for the creation of environments (in various forms)."""

def __init__(self, venv_type: VectorEnvType):
""":param venv_type: the type of vectorized environment to use for train and test environments.
watch environments are always created as dummy environments.
"""Main interface for the creation of environments (in various forms).

:param venv_type: the type of vectorized environment to use for train and test environments.
`WATCH` environments are always created as `DUMMY` vector environments.
"""
self.venv_type = venv_type

@@ -377,7 +377,8 @@ def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
"""Create vectorized environments.

:param num_envs: the number of environments
:param mode: the mode for which to create. In `WATCH` mode the resulting venv will always be of type `DUMMY` with a single env.
:param mode: the mode for which to create.
In `WATCH` mode the resulting venv will always be of type `DUMMY` with a single env.

:return: the vectorized environments
"""
@@ -437,9 +438,7 @@ def __init__(
:param render_mode_train: the render mode to use for training environments
:param render_mode_test: the render mode to use for test environments
:param render_mode_watch: the render mode to use for environments that are used to watch agent performance
:param make_kwargs: additional keyword arguments to pass on to `gymnasium.make`.
If envpool is used, the gymnasium parameters will be appropriately translated for use with
`envpool.make_gymnasium`.
:param make_kwargs: additional keyword arguments to pass on to `gymnasium.make`. If envpool is used, the gymnasium parameters will be appropriately translated for use with `envpool.make_gymnasium`.
"""
super().__init__(venv_type)
self.task = task
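
A hedged sketch of implementing the `EnvFactory` interface (only `__init__` and `create_venv` appear in this diff; the `create_env` hook is an assumption about the abstract method a subclass must provide):

import gymnasium as gym

from tianshou.highlevel.env import EnvFactory, EnvMode, VectorEnvType

class CartPoleEnvFactory(EnvFactory):
    """Hypothetical factory: SUBPROC venvs for train/test, DUMMY for watch."""

    def __init__(self) -> None:
        super().__init__(VectorEnvType.SUBPROC)

    def create_env(self, mode: EnvMode) -> gym.Env:  # assumed abstract hook
        return gym.make("CartPole-v1")

venvs = CartPoleEnvFactory().create_venv(num_envs=4, mode=EnvMode.TRAIN)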