From ed61f92dfb624d3c1789fbb105fca9bc2ab3e045 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 10 Sep 2020 16:43:27 +0800 Subject: [PATCH 1/3] update docs --- .github/workflows/lint_and_docs.yml | 7 +- README.md | 1 + docs/conf.py | 54 +++++++++------ docs/contributing.rst | 30 +++++++++ docs/index.rst | 13 ++-- docs/tutorials/batch.rst | 14 +++- docs/tutorials/cheatsheet.rst | 19 ++++-- docs/tutorials/concepts.rst | 44 ++++++++++-- docs/tutorials/dqn.rst | 32 ++++++--- docs/tutorials/tictactoe.rst | 15 +++-- docs/tutorials/trick.rst | 6 +- setup.cfg | 5 ++ setup.py | 86 ++++++++++++------------ test/base/test_collector.py | 1 - tianshou/data/batch.py | 92 ++++++++++++-------------- tianshou/data/buffer.py | 60 +++++++++-------- tianshou/data/collector.py | 26 +++----- tianshou/data/utils/converter.py | 5 +- tianshou/data/utils/segtree.py | 31 +++++---- tianshou/env/maenv.py | 20 +++--- tianshou/env/utils.py | 2 +- tianshou/env/venvs.py | 61 ++++++++++------- tianshou/env/worker/base.py | 6 +- tianshou/env/worker/dummy.py | 2 +- tianshou/env/worker/subproc.py | 2 +- tianshou/exploration/random.py | 10 +-- tianshou/policy/base.py | 49 ++++++++------ tianshou/policy/modelfree/a2c.py | 5 +- tianshou/policy/modelfree/ddpg.py | 2 +- tianshou/policy/modelfree/dqn.py | 19 +++--- tianshou/policy/modelfree/pg.py | 6 +- tianshou/policy/modelfree/ppo.py | 4 +- tianshou/policy/modelfree/sac.py | 2 +- tianshou/policy/modelfree/td3.py | 3 +- tianshou/policy/multiagent/mapolicy.py | 16 +++-- tianshou/policy/random.py | 20 +++--- tianshou/trainer/offpolicy.py | 5 +- tianshou/trainer/onpolicy.py | 5 +- tianshou/utils/compile.py | 5 +- tianshou/utils/moving_average.py | 11 +-- tianshou/utils/net/common.py | 25 ++++--- tianshou/utils/net/continuous.py | 26 +++++--- tianshou/utils/net/discrete.py | 20 +++--- 43 files changed, 516 insertions(+), 351 deletions(-) create mode 100644 setup.cfg diff --git a/.github/workflows/lint_and_docs.yml b/.github/workflows/lint_and_docs.yml index 90c6caf67..50e307a68 100644 --- a/.github/workflows/lint_and_docs.yml +++ b/.github/workflows/lint_and_docs.yml @@ -16,15 +16,14 @@ jobs: python -m pip install --upgrade pip setuptools wheel - name: Install dependencies run: | - python -m pip install flake8 + pip install ".[dev]" --upgrade - name: Lint with flake8 run: | flake8 . --count --show-source --statistics - - name: Install dependencies - run: | - pip install ".[dev]" --upgrade - name: Documentation test run: | + pydocstyle tianshou + doc8 docs --max-line-length 1000 cd docs make html SPHINXOPTS="-W" cd .. diff --git a/README.md b/README.md index eebbfde64..f65837e1c 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,7 @@ Here is Tianshou's other features: - Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process) - Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation - Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html##multi-agent-reinforcement-learning) +- Comprehensive documentation, PEP8 code-style checking, type checking and [unit tests](https://github.com/thu-ml/tianshou/actions) In Chinese, Tianshou means divinely ordained and is derived to the gift of being born with. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. 
So taking "Tianshou" means that there is no teacher to study with, but rather to learn by themselves through constant interaction with the environment. diff --git a/docs/conf.py b/docs/conf.py index eb6f65f4e..2169ab8ec 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -23,9 +23,9 @@ # -- Project information ----------------------------------------------------- -project = 'Tianshou' -copyright = '2020, Tianshou contributors.' -author = 'Tianshou contributors' +project = "Tianshou" +copyright = "2020, Tianshou contributors." +author = "Tianshou contributors" # The full version, including alpha/beta/rc tags release = version @@ -37,51 +37,61 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.coverage', + "sphinx.ext.autodoc", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx.ext.coverage", # 'sphinx.ext.imgmath', - 'sphinx.ext.mathjax', - 'sphinx.ext.ifconfig', - 'sphinx.ext.viewcode', - 'sphinx.ext.githubpages', - 'sphinxcontrib.bibtex', + "sphinx.ext.mathjax", + "sphinx.ext.ifconfig", + "sphinx.ext.viewcode", + "sphinx.ext.githubpages", + "sphinxcontrib.bibtex", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] -source_suffix = ['.rst', '.md'] -master_doc = 'index' +templates_path = ["_templates"] +source_suffix = [".rst", ".md"] +master_doc = "index" # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] -autodoc_default_options = {'special-members': ', '.join([ - '__len__', '__call__', '__getitem__', '__setitem__', - '__getattr__', '__setattr__'])} +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] +autodoc_default_options = { + "special-members": ", ".join( + [ + "__len__", + "__call__", + "__getitem__", + "__setitem__", + # "__getattr__", + # "__setattr__", + ] + ) +} # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] -html_logo = '_static/images/tianshou-logo.png' +html_logo = "_static/images/tianshou-logo.png" def setup(app): app.add_js_file("js/copybutton.js") app.add_css_file("css/style.css") + # -- Extension configuration ------------------------------------------------- # -- Options for intersphinx extension --------------------------------------- diff --git a/docs/contributing.rst b/docs/contributing.rst index 063db7804..dc7be4f07 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -1,6 +1,7 @@ Contributing to Tianshou ======================== + Install Develop Version ----------------------- @@ -16,6 +17,7 @@ in the main directory. This installation is removable by $ python setup.py develop --uninstall + PEP8 Code Style Check --------------------- @@ -25,6 +27,7 @@ We follow PEP8 python code style. 
To check, in the main directory, run: $ flake8 . --count --show-source --statistics + Test Locally ------------ @@ -34,6 +37,7 @@ This command will run automatic tests in the main directory $ pytest test --cov tianshou -s --durations 0 -v + Test by GitHub Actions ---------------------- @@ -54,6 +58,7 @@ Test by GitHub Actions .. image:: _static/images/action3.png :align: center + Documentation ------------- @@ -70,3 +75,28 @@ To compile documentation into webpages, run under the ``docs/`` directory. The generated webpages are in ``docs/_build`` and can be viewed with browsers. Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/. + + +Documentation Test +------------------ + +We have the following three documentation tests: + +1. pydocstyle: test docstrings under ``tianshou/``. To check, in the main directory, run: + +.. code-block:: bash + + $ pydocstyle tianshou + +2. doc8: test ReStructuredText formats. To check, in the main directory, run: + +.. code-block:: bash + + $ doc8 docs + +3. sphinx test: test if there is any errors/warnings when generating front-end html documentations. To check, in the main directory, run: + +.. code-block:: bash + + $ cd docs + $ make html SPHINXOPTS="-W" diff --git a/docs/index.rst b/docs/index.rst index bfb2ddfdf..5d962af71 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -3,6 +3,7 @@ You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. + Welcome to Tianshou! ==================== @@ -25,14 +26,15 @@ Here is Tianshou's other features: * Elegant framework, using only ~2000 lines of code * Support parallel environment simulation (synchronous or asynchronous) for all algorithms: :ref:`parallel_sampling` -* Support recurrent state/action representation in actor network and critic network (RNN-style training for POMDP): :ref:`rnn_training` -* Support any type of environment state (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env` -* Support customized training process: :ref:`customize_training` +* Support recurrent state representation in actor network and critic network (RNN-style training for POMDP): :ref:`rnn_training` +* Support any type of environment state/action (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env` +* Support :ref:`customize_training` * Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation -* Support multi-agent RL: :doc:`/tutorials/tictactoe` +* Support :doc:`/tutorials/tictactoe` 中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ `_ + Installation ------------ @@ -70,6 +72,7 @@ If no error occurs, you have successfully installed Tianshou. Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ `_. + .. toctree:: :maxdepth: 1 :caption: Tutorials @@ -81,6 +84,7 @@ Tianshou is still under development, you can also check out the documents in sta tutorials/trick tutorials/cheatsheet + .. toctree:: :maxdepth: 1 :caption: API Docs @@ -92,6 +96,7 @@ Tianshou is still under development, you can also check out the documents in sta api/tianshou.exploration api/tianshou.utils + .. 
toctree:: :maxdepth: 1 :caption: Community diff --git a/docs/tutorials/batch.rst b/docs/tutorials/batch.rst index 390fc41ac..49d913d9f 100644 --- a/docs/tutorials/batch.rst +++ b/docs/tutorials/batch.rst @@ -3,9 +3,10 @@ Understand Batch ================ -:class:`~tianshou.data.Batch` is the internal data structure extensively used in Tianshou. It is designed to store and manipulate hierarchical named tensors. This tutorial aims to help users correctly understand the concept and the behavior of ``Batch`` so that users can make the best of Tianshou. +:class:`~tianshou.data.Batch` is the internal data structure extensively used in Tianshou. It is designed to store and manipulate hierarchical named tensors. This tutorial aims to help users correctly understand the concept and the behavior of :class:`~tianshou.data.Batch` so that users can make the best of Tianshou. + +The tutorial has three parts. We first explain the concept of hierarchical named tensors, and introduce basic usage of :class:`~tianshou.data.Batch`, followed by advanced topics of :class:`~tianshou.data.Batch`. -The tutorial has three parts. We first explain the concept of hierarchical named tensors, and introduce basic usage of ``Batch``, followed by advanced topics of ``Batch``. Hierarchical Named Tensors --------------------------- @@ -43,11 +44,13 @@ Note that, storing hierarchical named tensors is as easy as creating nested dict The real problem is how to **manipulate them**, such as adding new transition tuples into replay buffer and dealing with their heterogeneity. ``Batch`` is designed to easily create, store, and manipulate these hierarchical named tensors. + Basic Usages ------------ Here we cover some basic usages of ``Batch``, describing what ``Batch`` contains, how to construct ``Batch`` objects and how to manipulate them. + What Does Batch Contain ^^^^^^^^^^^^^^^^^^^^^^^ @@ -69,6 +72,7 @@ The content of ``Batch`` objects can be defined by the following rules. The data types of tensors are bool and numbers (any size of int and float as long as they are supported by NumPy or PyTorch). Besides, NumPy supports ndarray of objects and we take advantage of this feature to store non-number objects in ``Batch``. If one wants to store data that are neither boolean nor numbers (such as strings and sets), they can store the data in ``np.ndarray`` with the ``np.object`` data type. This way, ``Batch`` can store any type of python objects. + Construction of Batch ^^^^^^^^^^^^^^^^^^^^^ @@ -136,6 +140,7 @@ There are two ways to construct a ``Batch`` object: from a ``dict``, or using ``
+ Data Manipulation With Batch ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -285,11 +290,13 @@ Stacking and concatenating multiple ``Batch`` instances, or split an instance in
+ Advanced Topics --------------- From here on, this tutorial focuses on advanced topics of ``Batch``, including key reservation, length/shape, and aggregation of heterogeneous batches. + .. _key_reservations: Key Reservations @@ -347,6 +354,7 @@ The ``Batch.is_empty`` function has an option to decide whether to identify dire Do not get confused with ``Batch.is_empty`` and ``Batch.empty``. ``Batch.empty`` and its in-place variant ``Batch.empty_`` are used to set some values to zeros or None. Check the API documentation for further details. + Length and Shape ^^^^^^^^^^^^^^^^ @@ -391,6 +399,7 @@ The ``obj.shape`` attribute of ``Batch`` behaves somewhat similar to ``len(obj)` 4. The shape of reserved keys is undetermined, too. We treat their shape as ``[]``. + .. _aggregation: Aggregation of Heterogeneous Batches @@ -457,6 +466,7 @@ For a set of ``Batch`` objects denoted as :math:`S`, they can be aggregated if t The ``Batch`` object ``b`` satisfying these rules with the minimum number of keys determines the structure of aggregating :math:`S`. The values are relatively easy to define: for any key chain ``k`` that applies to ``b``, ``b[k]`` is the stack/concatenation of ``[bi[k] for bi in S]`` (if ``k`` does not apply to ``bi``, the appropriate size of zeros or ``None`` are filled automatically). If ``bi[k]`` are all ``Batch()``, then the aggregation result is also an empty ``Batch()``. + Miscellaneous Notes ^^^^^^^^^^^^^^^^^^^ diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index 32358439b..2784eac3f 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -5,6 +5,7 @@ This page shows some code snippets of how to use Tianshou to develop new algorit By the way, some of these issues can be resolved by using a ``gym.wrapper``. It could be a universal solution in the policy-environment interaction. But you can also use the batch processor :ref:`preprocess_fn`. + .. _network_api: Build Policy Network @@ -12,6 +13,7 @@ Build Policy Network See :ref:`build_the_network`. + .. _new_policy: Build New Policy @@ -19,6 +21,7 @@ Build New Policy See :class:`~tianshou.policy.BasePolicy`. + .. _customize_training: Customize Training Process @@ -26,6 +29,7 @@ Customize Training Process See :ref:`customized_trainer`. + .. _parallel_sampling: Parallel Sampling @@ -87,6 +91,7 @@ The figure in the right gives an intuitive comparison among synchronous/asynchro Otherwise, the outputs of these envs may be the same with each other. + .. _preprocess_fn: Handle Batched Data Stream in Collector @@ -96,17 +101,20 @@ This is related to `Issue 42 `_. If you want to get log stat from data stream / pre-process batch-image / modify the reward with given env info, use ``preproces_fn`` in :class:`~tianshou.data.Collector`. This is a hook which will be called before the data adding into the buffer. -This function receives up to 7 keys ``obs``, ``act``, ``rew``, ``done``, ``obs_next``, ``info``, and ``policy``, as listed in :class:`~tianshou.data.Batch`, and returns the modified part within a :class:`~tianshou.data.Batch`. Only ``obs`` is defined at env reset, while every key is specified for normal steps. For example, you can write your hook as: +This function receives up to 7 keys ``obs``, ``act``, ``rew``, ``done``, ``obs_next``, ``info``, and ``policy``, as listed in :class:`~tianshou.data.Batch`, and returns the modified part within a :class:`~tianshou.data.Batch`. Only ``obs`` is defined at env.reset, while every key is specified for normal steps. 
For example, you can write your hook as: :: import numpy as np from collections import deque + + class MyProcessor: def __init__(self, size=100): self.episode_log = None self.main_log = deque(maxlen=size) self.main_log.append(0) self.baseline = 0 + def preprocess_fn(**kwargs): """change reward to zero mean""" # if only obs exist -> reset @@ -136,6 +144,7 @@ And finally, Some examples are in `test/base/test_collector.py `_. + .. _rnn_training: RNN-style Training @@ -148,13 +157,14 @@ First, add an argument ``stack_num`` to :class:`~tianshou.data.ReplayBuffer`: buf = ReplayBuffer(size=size, stack_num=stack_num) -Then, change the network to recurrent-style, for example, class ``Recurrent`` in `code snippet 1 `_, or ``RecurrentActor`` and ``RecurrentCritic`` in `code snippet 2 `_. +Then, change the network to recurrent-style, for example, :class:`~tianshou.utils.net.common.Recurrent`, :class:`~tianshou.utils.net.continuous.RecurrentActorProb` and :class:`~tianshou.utils.net.continuous.RecurrentCritic`. The above code supports only stacked-observation. If you want to use stacked-action (for Q(stacked-s, stacked-a)), stacked-reward, or other stacked variables, you can add a ``gym.wrapper`` to modify the state representation. For example, if we add a wrapper that map [s, a] pair to a new state: - Before: (s, a, s', r, d) stored in replay buffer, and get stacked s; - After applying wrapper: ([s, a], a, [s', a'], r, d) stored in replay buffer, and get both stacked s and a. + .. _self_defined_env: User-defined Environment and Different State Representation @@ -174,9 +184,9 @@ First of all, your self-defined environment must follow the Gym's API, some of t - close() -> None -- observation_space +- observation_space: gym.Space -- action_space +- action_space: gym.Space The state can be a ``numpy.ndarray`` or a Python dictionary. Take ``FetchReach-v1`` as an example: :: @@ -285,6 +295,7 @@ But the state stored in the buffer may be a shallow-copy. To make sure each of y ... return copy.deepcopy(self.graph), reward, done, {} + .. _marl_example: Multi-Agent Reinforcement Learning diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index ba771ad55..c09db285d 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -14,6 +14,7 @@ Here is a more detailed description, where ``Env`` is the environment and ``Mode :align: center :height: 300 + Batch ----- @@ -48,6 +49,7 @@ In short, you can define a :class:`~tianshou.data.Batch` with any key-value pair :ref:`batch_concept` is a dedicated tutorial for :class:`~tianshou.data.Batch`. We strongly recommend every user to read it so as to correctly understand and use :class:`~tianshou.data.Batch`. + Buffer ------ @@ -57,7 +59,6 @@ Buffer Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out :class:`~tianshou.data.ReplayBuffer` for more detail. -.. _policy_concept: Policy ------ @@ -73,6 +74,37 @@ A policy class typically has the following parts: * :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the buffer with a given batch of data. * :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from buffer, pre-process data (such as computing n-step return), learn with the data, and finally post-process the data (such as updating prioritized replay buffer); in short, ``process_fn -> learn -> post_process_fn``. 
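Putting these pieces together, a brand-new algorithm mainly needs to fill in ``forward`` and ``learn``. Below is a minimal sketch of the interface (the class ``MyRandomPolicy`` is purely illustrative and not part of Tianshou):
::

    import numpy as np

    from tianshou.data import Batch
    from tianshou.policy import BasePolicy


    class MyRandomPolicy(BasePolicy):
        """A toy policy that picks discrete actions uniformly at random."""

        def __init__(self, action_dim, **kwargs):
            super().__init__(**kwargs)
            self.action_dim = action_dim

        def forward(self, batch, state=None, **kwargs):
            # map a batch of observations to a batch of actions
            act = np.random.randint(self.action_dim, size=len(batch.obs))
            return Batch(act=act)

        def learn(self, batch, **kwargs):
            # nothing to optimize in this toy policy; return logged scalars
            return {"loss": 0.0}

The builtin :class:`~tianshou.policy.RandomPolicy` follows the same pattern; ``forward`` and ``process_fn`` are explained in more detail below.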
+ +policy.forward +^^^^^^^^^^^^^^ + +The ``forward`` function computes the action over given observations. The input and output is algorithm-specific but generally, the function is a mapping of ``(batch, state, ...) -> batch``. + +The input batch is the environment data. The first dimension of all variables in the input ``batch`` should be equal to the batch-size. + +The output is also a Batch which may contain "act", "state", "policy", and some other algorithm-specific keys. +The keyword "policy" is reserved and the corresponding data will be stored into the replay buffer. Checkout :meth:`~tianshou.policy.BasePolicy.forward` for more explanation. + +For example, if you try to use your policy to evaluate one episode (and don't want to use :meth:`~tianshou.data.Collector.collect`), use the following code-snippet: +:: + + # assume env is a gym.Env + obs, done = env.reset(), False + while not done: + batch = Batch(obs=[obs]) # the first dimension is batch-size + act = policy(batch).act[0] # policy.forward return a batch, use ".act" to extract the action + obs, rew, done, info = env.step(act) + +Here, ``Batch(obs=[obs])`` will automatically create the 0-dimension to be the batch-size. Otherwise, the network cannot determine the batch-size. + + +.. _process_fn: + +policy.process_fn +^^^^^^^^^^^^^^^^^ + +The ``process_fn`` function computes some variables that depends on time-series. For example, compute the N-step or GAE returns. + Take 2-step return DQN as an example. The 2-step return DQN compute each frame's return as: .. math:: @@ -128,11 +160,11 @@ Collector The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently. -:class:`~tianshou.data.Collector` has one main method :meth:`~tianshou.data.Collector.collect`: it let the policy perform (at least) a specified number of step ``n_step`` or episode ``n_episode`` and store the data in the replay buffer. +:meth:`~tianshou.data.Collector.collect` is the main method of Collector: it let the policy perform (at least) a specified number of step ``n_step`` or episode ``n_episode`` and store the data in the replay buffer. Why do we mention **at least** here? For multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically. -The solution is to add some cache buffers inside the collector. Once collecting **a full episode of trajectory**, it will move the stored data from the cache buffer to the main buffer. To satisfy this condition, the collector will interact with environments that may exceed the given step number or episode number. +Our solution is to add some cache buffers inside the collector. Once collecting **a full episode of trajectory**, it will move the stored data from the cache buffer to the main buffer. To satisfy this condition, the collector will interact with environments that may exceed the given step number or episode number. The general explanation is listed in :ref:`pseudocode`. Other usages of collector are listed in :class:`~tianshou.data.Collector` documentation. 
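As a short sketch of typical usage (``policy`` and ``envs`` are assumed to be an already constructed policy and vectorized environment):
::

    from tianshou.data import Collector, ReplayBuffer

    collector = Collector(policy, envs, ReplayBuffer(size=20000))
    collector.collect(n_step=1000)           # collect at least 1000 transitions
    collector.collect(n_episode=10)          # collect at least 10 full episodes
    batch = collector.sample(batch_size=64)  # sample a batch; process_fn is applied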
@@ -150,7 +182,7 @@ Tianshou has two types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` an A High-level Explanation ------------------------ -We give a high-level explanation through the pseudocode used in section :ref:`policy_concept`: +We give a high-level explanation through the pseudocode used in section :ref:`process_fn`: :: # pseudocode, cannot work # methods in tianshou @@ -158,13 +190,13 @@ We give a high-level explanation through the pseudocode used in section :ref:`po buffer = Buffer(size=10000) # buffer = tianshou.data.ReplayBuffer(size=10000) agent = DQN() # policy.__init__(...) for i in range(int(1e6)): # done in trainer - a = agent.compute_action(s) # policy(batch, ...) + a = agent.compute_action(s) # act = policy(batch, ...).act s_, r, d, _ = env.step(a) # collector.collect(...) buffer.store(s, a, s_, r, d) # collector.collect(...) s = s_ # collector.collect(...) if i % 1000 == 0: # done in trainer # the following is done in policy.update(batch_size, buffer) - b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # buffer.sample(batch_size) + b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # batch, indice = buffer.sample(batch_size) # compute 2-step returns. How? b_ret = compute_2_step_return(buffer, b_r, b_d, ...) # policy.process_fn(batch, buffer, indice) # update DQN policy diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index 9655ee8e1..a391069fd 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -8,6 +8,7 @@ The full script is at `test/discrete/test_dqn.py `_, which could only accept a config specification of hyperparameters, network, and others, Tianshou provides an easy way of construction through the code-level. + Make an Environment ------------------- @@ -21,10 +22,11 @@ First of all, you have to make an environment for your agent to interact with. F CartPole-v0 is a simple environment with a discrete action space, for which DQN applies. You have to identify whether the action space is continuous or discrete and apply eligible algorithms. DDPG :cite:`DDPG`, for example, could only be applied to continuous action spaces, while almost all other policy gradient methods could be applied to both, depending on the probability distribution on the action. + Setup Multi-environment Wrapper ------------------------------- -It is available if you want the original ``gym.Env``: +It is okay if you want the original ``gym.Env``: :: train_envs = gym.make('CartPole-v0') @@ -38,7 +40,7 @@ Tianshou supports parallel sampling for all algorithms. It provides four types o Here, we set up 8 environments in ``train_envs`` and 100 environments in ``test_envs``. -For the demonstration, here we use the second block of codes. +For the demonstration, here we use the second code-block. .. warning:: @@ -51,12 +53,13 @@ For the demonstration, here we use the second block of codes. Otherwise, the outputs of these envs may be the same with each other. + .. _build_the_network: Build the Network ----------------- -Tianshou supports any user-defined PyTorch networks and optimizers but with the limitation of input and output API. Here is an example code: +Tianshou supports any user-defined PyTorch networks and optimizers only with the limitation of input and output API. 
Here is an example: :: import torch, numpy as np @@ -65,12 +68,13 @@ Tianshou supports any user-defined PyTorch networks and optimizers but with the class Net(nn.Module): def __init__(self, state_shape, action_shape): super().__init__() - self.model = nn.Sequential(*[ + self.model = nn.Sequential( nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True), nn.Linear(128, 128), nn.ReLU(inplace=True), nn.Linear(128, 128), nn.ReLU(inplace=True), - nn.Linear(128, np.prod(action_shape)) - ]) + nn.Linear(128, np.prod(action_shape)), + ) + def forward(self, obs, state=None, info={}): if not isinstance(obs, torch.Tensor): obs = torch.tensor(obs, dtype=torch.float) @@ -83,29 +87,32 @@ Tianshou supports any user-defined PyTorch networks and optimizers but with the net = Net(state_shape, action_shape) optim = torch.optim.Adam(net.parameters(), lr=1e-3) -You can also have a try with those pre-defined networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are: +You can have a try with those pre-defined networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are: 1. Input: observation ``obs`` (may be a ``numpy.ndarray``, ``torch.Tensor``, dict, or self-defined class), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment. -2. Output: some ``logits``, the next hidden state ``state``, and intermediate result during the policy forwarding procedure ``policy``. The logits could be a tuple instead of a ``torch.Tensor``. It depends on how the policy process the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for Gaussian policy. The ``policy`` can be a Batch of torch.Tensor or other things, which will be stored in the replay buffer, and can be accessed in the policy update process (e.g. in ``policy.learn()``, the ``batch.policy`` is what you need). +2. Output: some ``logits``, the next hidden state ``state``. The logits could be a tuple instead of a ``torch.Tensor``, or containing intermediate result during the policy forwarding procedure. It depends on how the policy class process the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for Gaussian policy. + Setup Policy ------------ -We use the defined ``net`` and ``optim``, with extra policy hyper-parameters, to define a policy. Here we define a DQN policy with using a target network: +We use the defined ``net`` and ``optim`` above, with extra policy hyper-parameters, to define a policy. Here we define a DQN policy with a target network: :: policy = ts.policy.DQNPolicy(net, optim, discount_factor=0.9, estimation_step=3, target_update_freq=320) + Setup Collector --------------- -The collector is a key concept in Tianshou. It allows the policy to interact with different types of environments conveniently. +The collector is a key concept in Tianshou. It allows the policy to interact with different types of environments conveniently. In each step, the collector will let the policy perform (at least) a specified number of steps or episodes and store the data in a replay buffer. 
:: train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(size=20000)) test_collector = ts.data.Collector(policy, test_envs) + Train Policy with a Trainer --------------------------- @@ -161,15 +168,17 @@ The returned result is a dictionary as follows: It shows that within approximately 4 seconds, we finished training a DQN agent on CartPole. The mean returns over 100 consecutive episodes is 199.03. + Save/Load Policy ---------------- -Since the policy inherits the ``torch.nn.Module`` class, saving and loading the policy are exactly the same as a torch module: +Since the policy inherits the class ``torch.nn.Module``, saving and loading the policy are exactly the same as a torch module: :: torch.save(policy.state_dict(), 'dqn.pth') policy.load_state_dict(torch.load('dqn.pth')) + Watch the Agent's Performance ----------------------------- @@ -181,6 +190,7 @@ Watch the Agent's Performance collector = ts.data.Collector(policy, env) collector.collect(n_episode=1, render=1 / 35) + .. _customized_trainer: Train a Policy with Customized Codes diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst index 6911177d2..cc4116deb 100644 --- a/docs/tutorials/tictactoe.rst +++ b/docs/tutorials/tictactoe.rst @@ -6,6 +6,7 @@ In this section, we describe how to use Tianshou to implement multi-agent reinfo .. image:: ../_static/images/tic-tac-toe.png :align: center + Tic-Tac-Toe Environment ----------------------- @@ -15,11 +16,11 @@ The scripts are located at ``test/multiagent/``. We have implemented a Tic-Tac-T >>> from tic_tac_toe_env import TicTacToeEnv # the module tic_tac_toe_env is in test/multiagent/ >>> board_size = 6 # the size of board size >>> win_size = 4 # how many signs in a row are considered to win - >>> + >>> >>> # This board has 6 rows and 6 cols (36 places in total) >>> # Players place 'x' and 'o' in turn on the board >>> # The player who first gets 4 consecutive 'x's or 'o's wins - >>> + >>> >>> env = TicTacToeEnv(size=board_size, win_size=win_size) >>> obs = env.reset() >>> env.render() # render the empty board @@ -105,6 +106,7 @@ One worth-noting case is that the game is over when there is only one empty posi After being familiar with the environment, let's try to play with random agents first! + Two Random Agent ---------------- @@ -119,7 +121,7 @@ Tianshou already provides some builtin classes for multi-agent learning. You can >>> from tianshou.data import Collector >>> from tianshou.policy import RandomPolicy, MultiAgentPolicyManager >>> - >>> # agents should be wrapped into one policy, + >>> # agents should be wrapped into one policy, >>> # which is responsible for calling the acting agent correctly >>> # here we use two random agents >>> policy = MultiAgentPolicyManager([RandomPolicy(), RandomPolicy()]) @@ -159,7 +161,8 @@ Tianshou already provides some builtin classes for multi-agent learning. You can ===x _ _ _ x x=== ================= -Random agents perform badly. In the above game, although agent 2 wins finally, it is clear that a smart agent 1 would place an ``x`` at row 4 col 4 to win directly. +Random agents perform badly. In the above game, although agent 2 wins finally, it is clear that a smart agent 1 would place an ``x`` at row 4 col 4 to win directly. 
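The same wrapper also works once one side is replaced by a learned agent. As a sketch, assuming ``dqn_agent`` is a :class:`~tianshou.policy.DQNPolicy` built as in the DQN tutorial:
::

    policy = MultiAgentPolicyManager([dqn_agent, RandomPolicy()])
    collector = Collector(policy, env)
    result = collector.collect(n_episode=1, render=0.1)

The next section shows how to actually train such an agent.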
+ Train an MARL Agent ------------------- @@ -212,7 +215,7 @@ The explanation of each Tianshou class/function will be deferred to their first parser.add_argument('--watch', default=False, action='store_true', help='no training, watch the play of pre-trained models') parser.add_argument('--agent_id', type=int, default=2, - help='the learned agent plays as the agent_id-th player. choices are 1 and 2.') + help='the learned agent plays as the agent_id-th player. Choices are 1 and 2.') parser.add_argument('--resume_path', type=str, default='', help='the path of agent pth file for resuming from a pre-trained agent') parser.add_argument('--opponent_path', type=str, default='', @@ -229,7 +232,7 @@ The following ``get_agents`` function returns agents and their optimizers from e - The action model we use is an instance of :class:`~tianshou.utils.net.common.Net`, essentially a multi-layer perceptron with the ReLU activation function; - The network model is passed to a :class:`~tianshou.policy.DQNPolicy`, where actions are selected according to both the action mask and their Q-values; -- The opponent can be either a random agent :class:`~tianshou.policy.RandomPolicy` that randomly chooses an action from legal actions, or it can be a pre-trained :class:`~tianshou.policy.DQNPolicy` allowing learned agents to play with themselves. +- The opponent can be either a random agent :class:`~tianshou.policy.RandomPolicy` that randomly chooses an action from legal actions, or it can be a pre-trained :class:`~tianshou.policy.DQNPolicy` allowing learned agents to play with themselves. Both agents are passed to :class:`~tianshou.policy.MultiAgentPolicyManager`, which is responsible to call the correct agent according to the ``agent_id`` in the observation. :class:`~tianshou.policy.MultiAgentPolicyManager` also dispatches data to each agent according to ``agent_id``, so that each agent seems to play with a virtual single-agent environment. diff --git a/docs/tutorials/trick.rst b/docs/tutorials/trick.rst index 5a73ff9e6..0de8b411e 100644 --- a/docs/tutorials/trick.rst +++ b/docs/tutorials/trick.rst @@ -56,13 +56,13 @@ Algorithm specific tricks Here is about the experience of hyper-parameter tuning on CartPole and Pendulum: -* :class:`~tianshou.policy.DQNPolicy`: use estimation_step greater than 1 and target network, also with a suitable size of replay buffer; +* :class:`~tianshou.policy.DQNPolicy`: use estimation_step = 3 or 4 and target network, also with a suitable size of replay buffer; * :class:`~tianshou.policy.PGPolicy`: TBD * :class:`~tianshou.policy.A2CPolicy`: TBD * :class:`~tianshou.policy.PPOPolicy`: TBD * :class:`~tianshou.policy.DDPGPolicy`, :class:`~tianshou.policy.TD3Policy`, and :class:`~tianshou.policy.SACPolicy`: We found two tricks. The first is to ignore the done flag. The second is to normalize reward to a standard normal distribution (it is against the theoretical analysis, but indeed works very well). The two tricks work amazingly on Mujoco tasks, typically with a faster converge speed (1M -> 200K). -* On-policy algorithms: increase the repeat-time (to 2 or 4 for trivial benchmark, 10 for mujoco) of the given batch in each training update will make the algorithm more stable. +* On-policy algorithms: increase the repeat-time (to 2 or 4 for trivial benchmark, 10 for mujoco) of the given batch in each training update will make the algorithm more stable. Code-level optimization @@ -70,8 +70,6 @@ Code-level optimization Tianshou has many short-but-efficient lines of code. 
For example, when we want to compute :math:`V(s)` and :math:`V(s')` by the same network, the best way is to concatenate :math:`s` and :math:`s'` together instead of computing the value function using twice of network forward. -.. Jiayi: I write each line of code after quite a lot of time of consideration. Details make a difference. - Atari/Mujoco Task Specific -------------------------- diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 000000000..188700c10 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,5 @@ +[pydocstyle] +ignore = D100,D102,D104,D105,D107,D203,D213,D401,D402 + +[doc8] +max-line-length = 1000 diff --git a/setup.py b/setup.py index 789ea2d57..93acc4a7e 100644 --- a/setup.py +++ b/setup.py @@ -12,65 +12,61 @@ def get_version() -> str: setup( - name='tianshou', + name="tianshou", version=get_version(), - description='A Library for Deep Reinforcement Learning', - long_description=open('README.md', encoding='utf8').read(), - long_description_content_type='text/markdown', - url='https://github.com/thu-ml/tianshou', - author='TSAIL', - author_email='trinkle23897@gmail.com', - license='MIT', - python_requires='>=3.6', + description="A Library for Deep Reinforcement Learning", + long_description=open("README.md", encoding="utf8").read(), + long_description_content_type="text/markdown", + url="https://github.com/thu-ml/tianshou", + author="TSAIL", + author_email="trinkle23897@gmail.com", + license="MIT", + python_requires=">=3.6", classifiers=[ # How mature is this project? Common values are # 3 - Alpha # 4 - Beta # 5 - Production/Stable - 'Development Status :: 3 - Alpha', + "Development Status :: 4 - Beta", # Indicate who your project is intended for - 'Intended Audience :: Science/Research', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: Software Development :: Libraries :: Python Modules', + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules", # Pick your license as you wish (should match "license" above) - 'License :: OSI Approved :: MIT License', + "License :: OSI Approved :: MIT License", # Specify the Python versions you support here. In particular, ensure # that you indicate whether you support Python 2, Python 3 or both. 
- 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", ], - keywords='reinforcement learning platform pytorch', - packages=find_packages(exclude=['test', 'test.*', - 'examples', 'examples.*', - 'docs', 'docs.*']), + keywords="reinforcement learning platform pytorch", + packages=find_packages( + exclude=["test", "test.*", "examples", "examples.*", "docs", "docs.*"] + ), install_requires=[ - 'gym>=0.15.4', - 'tqdm', - 'numpy', - 'tensorboard', - 'torch>=1.4.0', - 'numba>=0.51.0', + "gym>=0.15.4", + "tqdm", + "numpy", + "tensorboard", + "torch>=1.4.0", + "numba>=0.51.0", ], extras_require={ - 'dev': [ - 'Sphinx', - 'sphinx_rtd_theme', - 'sphinxcontrib-bibtex', - 'flake8', - 'pytest', - 'pytest-cov', - 'ray>=0.8.0', - ], - 'atari': [ - 'atari_py', - 'cv2', - ], - 'mujoco': [ - 'mujoco_py', - ], - 'pybullet': [ - 'pybullet', + "dev": [ + "Sphinx", + "sphinx_rtd_theme", + "sphinxcontrib-bibtex", + "flake8", + "pytest", + "pytest-cov", + "ray>=0.8.0", + "mypy", + "pydocstyle", + "doc8", ], + "atari": ["atari_py", "cv2"], + "mujoco": ["mujoco_py"], + "pybullet": ["pybullet"], }, ) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 217531611..e7e11759f 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -200,7 +200,6 @@ def test_collector_with_dict_state(): assert not np.isclose(obs[0]['rand'], obs[1]['rand']) c1 = Collector(policy, envs, ReplayBuffer(size=100), Logger.single_preprocess_fn) - c1.seed(0) c1.collect(n_step=10) c1.collect(n_episode=[2, 1, 1, 2]) batch, _ = c1.buffer.sample(10) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index fe70d1f7a..3726126e5 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -79,7 +79,8 @@ def _to_array_with_correct_type(v: Any) -> np.ndarray: def _create_value(inst: Any, size: int, stack=True) -> Union[ 'Batch', np.ndarray, torch.Tensor]: - """ + """Create empty place-holders accroding to inst's shape. + :param bool stack: whether to stack or to concatenate. E.g. if inst has shape of (3, 5), size = 10, stack=True returns an np.ndarry with shape of (10, 3, 5), otherwise (10, 5) @@ -154,9 +155,8 @@ def _parse_value(v: Any): class Batch: - """Tianshou provides :class:`~tianshou.data.Batch` as the internal data - structure to pass any kind of data to other methods, for example, a - collector gives a :class:`~tianshou.data.Batch` to policy for learning. + """Batch the internal data structure to pass any kind of data to other \ + methods, for example, a collector gives a batch data to policy.forward(). For a detailed description, please refer to :ref:`batch_concept`. """ @@ -180,12 +180,13 @@ def __init__(self, self.__init__(kwargs, copy=copy) def __setattr__(self, key: str, value: Any) -> None: - """self.key = value""" + """Set self.key = value.""" self.__dict__[key] = _parse_value(value) def __getstate__(self) -> dict: - """Pickling interface. Only the actual data are serialized for both - efficiency and simplicity. + """Pickling interface. + + Only the actual data are serialized for both efficiency and simplicity. """ state = {} for k, v in self.items(): @@ -195,9 +196,10 @@ def __getstate__(self) -> dict: return state def __setstate__(self, state) -> None: - """Unpickling interface. 
At this point, self is an empty Batch instance - that has not been initialized, so it can safely be initialized by the - pickle state. + """Unpickling interface. + + At this point, self is an empty Batch instance that has not been + initialized, so it can safely be initialized by the pickle state. """ self.__init__(**state) @@ -246,8 +248,7 @@ def __setitem__(self, index: Union[ self.__dict__[key][index] = None def __iadd__(self, other: Union['Batch', Number, np.number]): - """Algebraic addition with another :class:`~tianshou.data.Batch` - instance in-place.""" + """Algebraic addition with another Batch instance in-place.""" if isinstance(other, Batch): for (k, r), v in zip(self.__dict__.items(), other.__dict__.values()): @@ -268,14 +269,12 @@ def __iadd__(self, other: Union['Batch', Number, np.number]): raise TypeError("Only addition of Batch or number is supported.") def __add__(self, other: Union['Batch', Number, np.number]): - """Algebraic addition with another :class:`~tianshou.data.Batch` - instance out-of-place.""" + """Algebraic addition with another Batch instance out-of-place.""" return deepcopy(self).__iadd__(other) def __imul__(self, val: Union[Number, np.number]): """Algebraic multiplication with a scalar value in-place.""" - assert _is_number(val), \ - "Only multiplication by a number is supported." + assert _is_number(val), "Only multiplication by a number is supported." for k, r in self.__dict__.items(): if isinstance(r, Batch) and r.is_empty(): continue @@ -288,8 +287,7 @@ def __mul__(self, val: Union[Number, np.number]): def __itruediv__(self, val: Union[Number, np.number]): """Algebraic division with a scalar value in-place.""" - assert _is_number(val), \ - "Only division by a number is supported." + assert _is_number(val), "Only division by a number is supported." for k, r in self.__dict__.items(): if isinstance(r, Batch) and r.is_empty(): continue @@ -336,15 +334,11 @@ def get(self, k: str, d: Optional[Any] = None) -> Any: return self.__dict__.get(k, d) def pop(self, k: str, d: Optional[Any] = None) -> Any: - """Return and remove self[k] if k in self else d. d defaults to - None. - """ + """Return & remove self[k] if k in self else d. d defaults to None.""" return self.__dict__.pop(k, d) def to_numpy(self) -> None: - """Change all torch.Tensor to numpy.ndarray. This is an in-place - operation. - """ + """Change all torch.Tensor to numpy.ndarray in-place.""" for k, v in self.items(): if isinstance(v, torch.Tensor): self.__dict__[k] = v.detach().cpu().numpy() @@ -353,9 +347,7 @@ def to_numpy(self) -> None: def to_torch(self, dtype: Optional[torch.dtype] = None, device: Union[str, int, torch.device] = 'cpu') -> None: - """Change all numpy.ndarray to torch.Tensor. This is an in-place - operation. - """ + """Change all numpy.ndarray to torch.Tensor in-place.""" if not isinstance(device, torch.device): device = torch.device(device) @@ -382,7 +374,9 @@ def to_torch(self, dtype: Optional[torch.dtype] = None, def __cat(self, batches: List[Union[dict, 'Batch']], lens: List[int]) -> None: - """:: + """Private method for Batch.cat_. + + :: >>> a = Batch(a=np.random.randn(3, 4)) >>> x = Batch(a=a, b=np.random.randn(4, 4)) @@ -448,9 +442,7 @@ def __cat(self, def cat_(self, batches: Union['Batch', List[Union[dict, 'Batch']]]) -> None: - """Concatenate a list of (or one) :class:`~tianshou.data.Batch` objects - into current batch. 
- """ + """Concatenate a list of (or one) Batch objects into current batch.""" if isinstance(batches, Batch): batches = [batches] if len(batches) == 0: @@ -477,10 +469,10 @@ def cat_(self, @staticmethod def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch': - """Concatenate a list of :class:`~tianshou.data.Batch` object into a - single new batch. For keys that are not shared across all batches, - batches that do not have these keys will be padded by zeros with - appropriate shapes. E.g. + """Concatenate a list of Batch object into a single new batch. + + For keys that are not shared across all batches, batches that do not + have these keys will be padded by zeros with appropriate shapes. E.g. :: >>> a = Batch(a=np.zeros([3, 4]), common=Batch(c=np.zeros([3, 5]))) @@ -500,9 +492,7 @@ def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch': def stack_(self, batches: List[Union[dict, 'Batch']], axis: int = 0) -> None: - """Stack a list of :class:`~tianshou.data.Batch` object into current - batch. - """ + """Stack a list of Batch object into current batch.""" if len(batches) == 0: return batches = [x if isinstance(x, Batch) else Batch(x) for x in batches] @@ -553,9 +543,10 @@ def stack_(self, @staticmethod def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': - """Stack a list of :class:`~tianshou.data.Batch` object into a single - new batch. For keys that are not shared across all batches, - batches that do not have these keys will be padded by zeros. E.g. + """Stack a list of Batch object into a single new batch. + + For keys that are not shared across all batches, batches that do not + have these keys will be padded by zeros. E.g. :: >>> a = Batch(a=np.zeros([4, 4]), common=Batch(c=np.zeros([4, 5]))) @@ -580,9 +571,9 @@ def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': def empty_(self, index: Union[ str, slice, int, np.integer, np.ndarray, List[int]] = None ) -> 'Batch': - """Return an empty a :class:`~tianshou.data.Batch` object with 0 or - ``None`` filled. If ``index`` is specified, it will only reset the - specific indexed-data. + """Return an empty Batch object with 0 or "None" filled. + + If "index" is specified, it will only reset the specific indexed-data. :: >>> data.empty_() @@ -629,9 +620,9 @@ def empty_(self, index: Union[ def empty(batch: 'Batch', index: Union[ str, slice, int, np.integer, np.ndarray, List[int]] = None ) -> 'Batch': - """Return an empty :class:`~tianshou.data.Batch` object with 0 or - ``None`` filled, the shape is the same as the given - :class:`~tianshou.data.Batch`. + """Return an empty Batch object with 0 or "None" filled. + + The shape is the same as the given Batch. """ return deepcopy(batch).empty_(index) @@ -664,9 +655,10 @@ def __len__(self) -> int: return min(r) def is_empty(self, recurse: bool = False): - """ - Test if a Batch is empty. If ``recurse=True``, it further tests the - values of the object; else it only tests the existence of any key. + """Test if a Batch is empty. + + If ``recurse=True``, it further tests the values of the object; else + it only tests the existence of any key. ``b.is_empty(recurse=True)`` is mainly used to distinguish ``Batch(a=Batch(a=Batch()))`` and ``Batch(a=1)``. 
They both raise diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 3c5658f8e..c636baf19 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -7,9 +7,11 @@ class ReplayBuffer: - """:class:`~tianshou.data.ReplayBuffer` stores data generated from - interaction between the policy and environment. The current implementation - of Tianshou typically use 7 reserved keys in :class:`~tianshou.data.Batch`: + """:class:`~tianshou.data.ReplayBuffer` stores data generated from \ + interaction between the policy and environment. + + The current implementation of Tianshou typically use 7 reserved keys in + :class:`~tianshou.data.Batch` * ``obs`` the observation of step :math:`t` ; * ``act`` the action of step :math:`t` ; @@ -150,15 +152,17 @@ def __repr__(self) -> str: return self.__class__.__name__ + self._meta.__repr__()[5:] def __getattr__(self, key: str) -> Any: - """Return self.key""" + """Return self.key.""" try: return self._meta[key] except KeyError as e: raise AttributeError from e def __setstate__(self, state): - """Unpickling interface. We need it because pickling buffer does not - work out-of-the-box (``buffer.__getattr__`` is customized). + """Unpickling interface. + + We need it because pickling buffer does not work out-of-the-box + (``buffer.__getattr__`` is customized). """ self.__dict__.update(state) @@ -280,9 +284,11 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str, stack_num: Optional[int] = None) -> Union[Batch, np.ndarray]: - """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], - where s is self.key, t is indice. The stack_num (here equals to 4) is - given from buffer initialization procedure. + """Return the stacked result. + + E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the + indice. The stack_num (here equals to 4) is given from buffer + initialization procedure. """ if stack_num is None: stack_num = self.stack_num @@ -325,8 +331,10 @@ def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str, def __getitem__(self, index: Union[ slice, int, np.integer, np.ndarray]) -> Batch: - """Return a data batch: self[index]. If stack_num is larger than 1, - return the stacked obs and obs_next with shape [batch, len, ...]. + """Return a data batch: self[index]. + + If stack_num is larger than 1, return the stacked obs and obs_next + with shape (batch, len, ...). """ return Batch( obs=self.get(index, 'obs'), @@ -340,8 +348,10 @@ def __getitem__(self, index: Union[ class ListReplayBuffer(ReplayBuffer): - """The function of :class:`~tianshou.data.ListReplayBuffer` is almost the - same as :class:`~tianshou.data.ReplayBuffer`. The only difference is that + """List-based replay buffer. + + The function of :class:`~tianshou.data.ListReplayBuffer` is almost the same + as :class:`~tianshou.data.ReplayBuffer`. The only difference is that :class:`~tianshou.data.ListReplayBuffer` is based on ``list``. Therefore, it does not support advanced indexing, which means you cannot sample a batch of data out of it. It is typically used for storing data. @@ -373,7 +383,7 @@ def reset(self) -> None: class PrioritizedReplayBuffer(ReplayBuffer): - """Implementation of Prioritized Experience Replay. arXiv:1511.05952 + """Implementation of Prioritized Experience Replay. arXiv:1511.05952. :param float alpha: the prioritization exponent. :param float beta: the importance sample soft coefficient. 
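A short usage sketch for the prioritized buffer (the numbers 10000/0.6/0.4 and the dummy transition are illustrative placeholders, not defaults defined here):
::

    import numpy as np
    from tianshou.data import PrioritizedReplayBuffer

    buf = PrioritizedReplayBuffer(size=10000, alpha=0.6, beta=0.4)
    # transitions are added the same way as with a plain ReplayBuffer
    buf.add(obs=np.zeros(4), act=0, rew=1.0, done=False, obs_next=np.zeros(4))
    batch, indice = buf.sample(4)  # indices are drawn proportionally to priority
    # batch.weight holds the importance-sampling weights (p_j / p_min) ** (-beta);
    # write back new priorities (normally the absolute TD errors) afterwards
    buf.update_weight(indice, np.ones(4))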
@@ -388,18 +398,11 @@ def __init__(self, size: int, alpha: float, beta: float, **kwargs) -> None: super().__init__(size, **kwargs) assert alpha > 0. and beta >= 0. self._alpha, self._beta = alpha, beta - self._max_prio = 1. - self._min_prio = 1. - # bypass the check - self._weight = SegmentTree(size) + self._max_prio = self._min_prio = 1.0 + # save weight directly in this class instead of self._meta + self.weight = SegmentTree(size) self.__eps = np.finfo(np.float32).eps.item() - def __getattr__(self, key: str) -> Union['Batch', Any]: - """Return self.key""" - if key == 'weight': - return self._weight - return super().__getattr__(key) - def add(self, obs: Union[dict, Batch, np.ndarray, float], act: Union[dict, Batch, np.ndarray, float], @@ -418,11 +421,12 @@ def add(self, self._max_prio = max(self._max_prio, weight) self._min_prio = min(self._min_prio, weight) self.weight[self._index] = weight ** self._alpha - super().add(obs, act, rew, done, obs_next, info, policy) + super().add(obs, act, rew, done, obs_next, info, policy, **kwargs) def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: - """Get a random sample from buffer with priority probability. Return - all the data in the buffer if batch_size is ``0``. + """Get a random sample from buffer with priority probability. + + Return all the data in the buffer if batch_size is ``0``. :return: Sample data and its corresponding index inside the buffer. @@ -440,7 +444,7 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: scalar = np.random.rand(batch_size) * self.weight.reduce() indice = self.weight.get_prefix_sum_idx(scalar) batch = self[indice] - # impt_weight + # important sampling weight calculation # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta)) # simplified formula: (p_j/p_min)**(-beta) batch.weight = (batch.weight / self._min_prio) ** (-self._beta) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 4ffacac86..5c2f28068 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -14,8 +14,7 @@ class Collector(object): - """The :class:`~tianshou.data.Collector` enables the policy to interact - with different types of environments conveniently. + """Collector enables the policy to interact with different types of envs. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. @@ -42,7 +41,7 @@ class Collector(object): :class:`~tianshou.data.Batch` with the modified keys and values. Examples are in "test/base/test_collector.py". - Example: + Here is the example: :: policy = PGPolicy(...) # or other policies if you wish @@ -139,9 +138,7 @@ def get_env_num(self) -> int: return self.env_num def reset_env(self) -> None: - """Reset all of the environment(s)' states and reset all of the cache - buffers (if need). 
- """ + """Reset all of the environment(s)' states and the cache buffers.""" self._ready_env_ids = np.arange(self.env_num) obs = self.env.reset() if self.preprocess_fn: @@ -150,14 +147,6 @@ def reset_env(self) -> None: for b in self._cached_buf: b.reset() - def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None: - """Reset all the seed(s) of the given environment(s).""" - return self.env.seed(seed) - - def render(self, **kwargs) -> None: - """Render all the environment(s).""" - return self.env.render(**kwargs) - def _reset_state(self, id: Union[int, List[int]]) -> None: """Reset the hidden state: self.data.state[id].""" state = self.data.state # it is a reference @@ -291,7 +280,7 @@ def collect(self, self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) if render: - self.render() + self.env.render() time.sleep(render) # add data into the buffer @@ -378,9 +367,10 @@ def collect(self, } def sample(self, batch_size: int) -> Batch: - """Sample a data batch from the internal replay buffer. It will call - :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the - final batch data. + """Sample a data batch from the internal replay buffer. + + It will call :meth:`~tianshou.policy.BasePolicy.process_fn` before + returning the final batch data. :param int batch_size: ``0`` means it will extract all the data from the buffer, otherwise it will extract the data with the given diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index e97b05416..dd36c73d1 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -67,8 +67,9 @@ def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor], def to_torch_as(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor], y: torch.Tensor ) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]: - """Return an object without np.ndarray. Same as - ``to_torch(x, dtype=y.dtype, device=y.device)``. + """Return an object without np.ndarray. + + Same as ``to_torch(x, dtype=y.dtype, device=y.device)``. """ assert isinstance(y, torch.Tensor) return to_torch(x, dtype=y.dtype, device=y.device) diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py index e049f3a14..c4b4871d9 100644 --- a/tianshou/data/utils/segtree.py +++ b/tianshou/data/utils/segtree.py @@ -4,13 +4,13 @@ class SegmentTree: - """Implementation of Segment Tree: store an array ``arr`` with size ``n`` - in a segment tree, support value update and fast query of the sum for the - interval ``[left, right)`` in O(log n) time. + """Implementation of Segment Tree. - The detailed procedure is as follows: + The segment tree stores an array ``arr`` with size ``n``. It supports value + update and fast query of the sum for the interval ``[left, right)`` in + O(log n) time. The detailed procedure is as follows: - 1. Pad the array to have length of power of 2, so that leaf nodes in the\ + 1. Pad the array to have length of power of 2, so that leaf nodes in the \ segment tree have the same depth. 2. Store the segment tree in a binary heap. @@ -30,12 +30,14 @@ def __len__(self): def __getitem__(self, index: Union[int, np.ndarray] ) -> Union[float, np.ndarray]: - """Return self[index]""" + """Return self[index].""" return self._value[index + self._bound] def __setitem__(self, index: Union[int, np.ndarray], value: Union[float, np.ndarray]) -> None: - """Duplicate values in ``index`` are handled by numpy: later index + """Update values in segment tree. 
+ + Duplicate values in ``index`` are handled by numpy: later index overwrites previous ones. :: @@ -61,9 +63,11 @@ def reduce(self, start: int = 0, end: Optional[int] = None) -> float: def get_prefix_sum_idx( self, value: Union[float, np.ndarray]) -> Union[int, np.ndarray]: - """Return the minimum index for each ``v`` in ``value`` so that - :math:`v \\le \\mathrm{sums}_i`, where :math:`\\mathrm{sums}_i = - \\sum_{j=0}^{i} \\mathrm{arr}_j`. + r"""Find the index with given value. + + Return the minimum index for each ``v`` in ``value`` so that + :math:`v \le \mathrm{sums}_i`, where + :math:`\mathrm{sums}_i = \sum_{j = 0}^{i} \mathrm{arr}_j`. .. warning:: @@ -81,7 +85,7 @@ def get_prefix_sum_idx( @njit def _setitem(tree: np.ndarray, index: np.ndarray, value: np.ndarray) -> None: - """4x faster: 0.1 -> 0.024""" + """Numba version, 4x faster: 0.1 -> 0.024.""" tree[index] = value while index[0] > 1: index //= 2 @@ -90,7 +94,7 @@ def _setitem(tree: np.ndarray, index: np.ndarray, value: np.ndarray) -> None: @njit def _reduce(tree: np.ndarray, start: int, end: int) -> float: - """2x faster: 0.009 -> 0.005""" + """Numba version, 2x faster: 0.009 -> 0.005.""" # nodes in (start, end) should be aggregated result = 0. while end - start > 1: # (start, end) interval is not empty @@ -106,7 +110,8 @@ def _reduce(tree: np.ndarray, start: int, end: int) -> float: @njit def _get_prefix_sum_idx(value: np.ndarray, bound: int, sums: np.ndarray) -> np.ndarray: - """numba version (v0.51), 5x speed up with size=100000 and bsz=64 + """Numba version (v0.51), 5x speed up with size=100000 and bsz=64. + vectorized np: 0.0923 (numpy best) -> 0.024 (now) for-loop: 0.2914 -> 0.019 (but not so stable) """ diff --git a/tianshou/env/maenv.py b/tianshou/env/maenv.py index ea9284ea7..9153cf563 100644 --- a/tianshou/env/maenv.py +++ b/tianshou/env/maenv.py @@ -5,8 +5,10 @@ class MultiAgentEnv(ABC, gym.Env): - """The interface for multi-agent environments. Multi-agent environments - must be wrapped as :class:`~tianshou.env.MultiAgentEnv`. Here is the usage: + """The interface for multi-agent environments. + + Multi-agent environments must be wrapped as + :class:`~tianshou.env.MultiAgentEnv`. Here is the usage: :: env = MultiAgentEnv(...) @@ -25,18 +27,20 @@ def __init__(self, **kwargs) -> None: @abstractmethod def reset(self) -> dict: - """Reset the state. Return the initial state, first agent_id, and the - initial action set, for example, - ``{'obs': obs, 'agent_id': agent_id, 'mask': mask}`` + """Reset the state. + + Return the initial state, first agent_id, and the initial action set, + for example, ``{'obs': obs, 'agent_id': agent_id, 'mask': mask}`` """ pass @abstractmethod def step(self, action: np.ndarray ) -> Tuple[dict, np.ndarray, np.ndarray, np.ndarray]: - """Run one timestep of the environment’s dynamics. When the end of - episode is reached, you are responsible for calling reset() to reset - the environment’s state. + """Run one timestep of the environment’s dynamics. + + When the end of episode is reached, you are responsible for calling + reset() to reset the environment’s state. Accept action and return a tuple (obs, rew, done, info). 
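
[Editor's note] To make the ``reset``/``step`` contract documented above concrete, here is a minimal sketch of a turn-based two-player environment following the :class:`~tianshou.env.MultiAgentEnv` interface shown in this hunk. The board and reward logic are purely hypothetical, ``seed``/``render``/``close`` and all error checking are omitted, and the class is not meant to be instantiated as-is; it only illustrates the ``{'obs': ..., 'agent_id': ..., 'mask': ...}`` observation dict and the ``(obs, rew, done, info)`` tuple returned by ``step``:
::

    import numpy as np
    from tianshou.env import MultiAgentEnv

    class TurnBasedEnv(MultiAgentEnv):
        """Hypothetical 2-player environment with 3 positions to fill."""

        def __init__(self):
            super().__init__()
            self.board = np.zeros(3)
            self.agent_id = 1

        def reset(self):
            self.board = np.zeros(3)
            self.agent_id = 1
            return {'obs': self.board.copy(), 'agent_id': self.agent_id,
                    'mask': np.ones(3, dtype=bool)}  # all moves legal at start

        def step(self, action):
            self.board[action] = self.agent_id    # current player moves
            done = bool(self.board.all())         # every position filled
            rew = np.zeros(2)                     # one reward entry per agent
            self.agent_id = 3 - self.agent_id     # switch turn: 1 <-> 2
            obs = {'obs': self.board.copy(), 'agent_id': self.agent_id,
                   'mask': self.board == 0}       # only empty positions legal
            return obs, rew, done, {}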
diff --git a/tianshou/env/utils.py b/tianshou/env/utils.py index 41b9edea6..f7d8c5801 100644 --- a/tianshou/env/utils.py +++ b/tianshou/env/utils.py @@ -2,7 +2,7 @@ class CloudpickleWrapper(object): - """A cloudpickle wrapper used in :class:`~tianshou.env.SubprocVectorEnv`""" + """A cloudpickle wrapper used in SubprocVectorEnv.""" def __init__(self, data): self.data = data diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 04323498d..fa6cf812d 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -8,7 +8,9 @@ class BaseVectorEnv(gym.Env): - """Base class for vectorized environments wrapper. Usage: + """Base class for vectorized environments wrapper. + + Usage: :: env_num = 8 @@ -45,7 +47,7 @@ def seed(self, seed): :param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith env. :param worker_fn: a callable worker, ``worker_fn(env_fns[i])`` generates a - worker which contains this env. + worker which contains the i-th env. :param int wait_num: use in asynchronous simulation if the time cost of ``env.step`` varies with time and synchronously waiting for all environments to finish a step is time-wasting. In that case, we can @@ -98,10 +100,12 @@ def __len__(self) -> int: return self.env_num def __getattribute__(self, key: str) -> Any: - """Any class who inherits ``gym.Env`` will inherit some attributes, - like ``action_space``. However, we would like the attribute lookup to - go straight into the worker (in fact, this vector env's action_space - is always ``None``). + """Switch the attribute getter depending on the key. + + Any class who inherits ``gym.Env`` will inherit some attributes, like + ``action_space``. However, we would like the attribute lookup to go + straight into the worker (in fact, this vector env's action_space is + always ``None``). """ if key in ['metadata', 'reward_range', 'spec', 'action_space', 'observation_space']: # reserved keys in gym.Env @@ -110,9 +114,11 @@ def __getattribute__(self, key: str) -> Any: return super().__getattribute__(key) def __getattr__(self, key: str) -> Any: - """Try to retrieve an attribute from each individual wrapped - environment, if it does not belong to the wrapping vector environment - class. + """Fetch a list of env attributes. + + This function tries to retrieve an attribute from each individual + wrapped environment, if it does not belong to the wrapping vector + environment class. """ return [getattr(worker, key) for worker in self.workers] @@ -133,9 +139,11 @@ def _assert_id(self, id: List[int]) -> None: def reset(self, id: Optional[Union[int, List[int], np.ndarray]] = None ) -> np.ndarray: - """Reset the state of all the environments and return initial - observations if id is ``None``, otherwise reset the specific - environments with the given id, either an int or a list. + """Reset the state of some envs and return initial observations. + + If id is "None", reset the state of all the environments and return + initial observations, otherwise reset the specific environments with + the given id, either an int or a list. """ self._assert_is_not_closed() id = self._wrap_id(id) @@ -148,7 +156,9 @@ def step(self, action: np.ndarray, id: Optional[Union[int, List[int], np.ndarray]] = None ) -> List[np.ndarray]: - """Run one timestep of all the environments’ dynamics if id is "None", + """Run one timestep of some environments' dynamics. + + If id is "None", run one timestep of all the environments’ dynamics; otherwise run one timestep for some environments with given id, either an int or a list. 
When the end of episode is reached, you are responsible for calling reset(id) to reset this environment’s state. @@ -175,7 +185,7 @@ def step(self, should correspond to the ``id`` argument, and the ``id`` argument should be a subset of the ``env_id`` in the last returned ``info`` (initially they are env_ids of all the environments). If action is - ``None``, fetch unfinished step() calls instead. + "None", fetch unfinished step() calls instead. """ self._assert_is_not_closed() id = self._wrap_id(id) @@ -239,9 +249,11 @@ def render(self, **kwargs) -> List[Any]: return [w.render(**kwargs) for w in self.workers] def close(self) -> None: - """Close all of the environments. This function will be called only - once (if not, it will be called during garbage collected). This way, - ``close`` of all workers can be assured. + """Close all of the environments. + + This function will be called only once (if not, it will be called + during garbage collected). This way, ``close`` of all workers can be + assured. """ self._assert_is_not_closed() for w in self.workers: @@ -249,6 +261,7 @@ def close(self) -> None: self.is_closed = True def __del__(self) -> None: + """Redirect to self.close().""" if not self.is_closed: self.close() @@ -270,6 +283,8 @@ def __init__(self, env_fns: List[Callable[[], gym.Env]], class VectorEnv(DummyVectorEnv): + """VectorEnv is renamed to DummyVectorEnv.""" + def __init__(self, *args, **kwargs) -> None: warnings.warn( 'VectorEnv is renamed to DummyVectorEnv, and will be removed in ' @@ -296,9 +311,9 @@ def worker_fn(fn): class ShmemVectorEnv(BaseVectorEnv): - """Optimized version of SubprocVectorEnv which uses shared variables to - communicate observations. ShmemVectorEnv has exactly the same API as - SubprocVectorEnv. + """Optimized SubprocVectorEnv with shared buffers to exchange observations. + + ShmemVectorEnv has exactly the same API as SubprocVectorEnv. .. seealso:: @@ -316,9 +331,9 @@ def worker_fn(fn): class RayVectorEnv(BaseVectorEnv): - """Vectorized environment wrapper based on - `ray `_. This is a choice to run - distributed environments in a cluster. + """Vectorized environment wrapper based on ray. + + This is a choice to run distributed environments in a cluster. .. seealso:: diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index 2b56dab9b..e1a0708e7 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -30,8 +30,10 @@ def get_result(self) -> Tuple[ def step(self, action: np.ndarray ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """``send_action`` and ``get_result`` are coupled in sync simulation, - so typically users only call ``step`` function. But they can be called + """Perform one timestep of the environment's dynamic. + + ``send_action`` and ``get_result`` are coupled in sync simulation, so + typically users only call ``step`` function. But they can be called separately in async simulation, i.e. someone calls ``send_action`` first, and calls ``get_result`` later. 
""" diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py index 893500b28..b705913e6 100644 --- a/tianshou/env/worker/dummy.py +++ b/tianshou/env/worker/dummy.py @@ -22,7 +22,7 @@ def reset(self) -> Any: def wait(workers: List['DummyEnvWorker'], wait_num: int, timeout: Optional[float] = None) -> List['DummyEnvWorker']: - # SequentialEnvWorker objects are always ready + # Sequential EnvWorker objects are always ready return workers def send_action(self, action: np.ndarray) -> None: diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index 3186b01db..857d1483e 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -76,7 +76,7 @@ def _encode_obs(obs, buffer): class ShArray: - """Wrapper of multiprocessing Array""" + """Wrapper of multiprocessing Array.""" def __init__(self, dtype, shape): self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape))) diff --git a/tianshou/exploration/random.py b/tianshou/exploration/random.py index 19f4424cc..d9ef0069f 100644 --- a/tianshou/exploration/random.py +++ b/tianshou/exploration/random.py @@ -20,9 +20,7 @@ def reset(self) -> None: class GaussianNoise(BaseNoise): - """Class for vanilla gaussian process, - used for exploration in DDPG by default. - """ + """The vanilla gaussian process, for exploration in DDPG by default.""" def __init__(self, mu: float = 0.0, @@ -38,6 +36,7 @@ def __call__(self, size: tuple) -> np.ndarray: class OUNoise(BaseNoise): """Class for Ornstein-Uhlenbeck process, as used for exploration in DDPG. + Usage: :: @@ -67,8 +66,9 @@ def __init__(self, self.reset() def __call__(self, size: tuple, mu: Optional[float] = None) -> np.ndarray: - """Generate new noise. Return a ``numpy.ndarray`` which size is equal - to ``size``. + """Generate new noise. + + Return a ``numpy.ndarray`` which size is equal to ``size``. """ if self._x is None or self._x.shape != size: self._x = 0 diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 7b943ae9a..0a079ca9f 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -11,8 +11,10 @@ class BasePolicy(ABC, nn.Module): - """Tianshou aims to modularizing RL algorithms. It comes into several - classes of policies in Tianshou. All of the policy classes must inherit + """The base class for any RL policy. + + Tianshou aims to modularizing RL algorithms. It comes into several classes + of policies in Tianshou. All of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`. A policy class typically has four parts: @@ -62,7 +64,7 @@ def __init__(self, self.agent_id = 0 def set_agent_id(self, agent_id: int) -> None: - """set self.agent_id = agent_id, for MARL.""" + """Set self.agent_id = agent_id, for MARL.""" self.agent_id = agent_id @abstractmethod @@ -86,7 +88,7 @@ def forward(self, batch: Batch, return Batch(logits=..., act=..., state=None, dist=...) The keyword ``policy`` is reserved and the corresponding data will be - stored into the replay buffer in numpy. For instance, + stored into the replay buffer. For instance, :: # some code @@ -98,8 +100,10 @@ def forward(self, batch: Batch, def process_fn(self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray) -> Batch: - """Pre-process the data from the provided replay buffer. Check out - :ref:`policy_concept` for more information. + """Pre-process the data from the provided replay buffer. + + Used in :meth:`update`. Check out :ref:`process_fn` for more + information. 
""" return batch @@ -123,26 +127,28 @@ def learn(self, batch: Batch, **kwargs def post_process_fn(self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray) -> None: - """Post-process the data from the provided replay buffer. Typical - usage is to update the sampling weight in prioritized experience - replay. Check out :ref:`policy_concept` for more information. + """Post-process the data from the provided replay buffer. + + Typical usage is to update the sampling weight in prioritized + experience replay. Used in :meth:`update`. """ if isinstance(buffer, PrioritizedReplayBuffer) \ and hasattr(batch, 'weight'): buffer.update_weight(indice, batch.weight) - def update(self, batch_size: int, buffer: Optional[ReplayBuffer], + def update(self, sample_size: int, buffer: Optional[ReplayBuffer], *args, **kwargs) -> Dict[str, Union[float, List[float]]]: - """Update the policy network and replay buffer (if needed). It includes - three function steps: process_fn, learn, and post_process_fn. + """Update the policy network and replay buffer. - :param int batch_size: 0 means it will extract all the data from the - buffer, otherwise it will sample a batch with the given batch_size. + It includes 3 function steps: process_fn, learn, and post_process_fn. + + :param int sample_size: 0 means it will extract all the data from the + buffer, otherwise it will sample a batch with given sample_size. :param ReplayBuffer buffer: the corresponding replay buffer. """ if buffer is None: return {} - batch, indice = buffer.sample(batch_size) + batch, indice = buffer.sample(sample_size) batch = self.process_fn(batch, buffer, indice) result = self.learn(batch, *args, **kwargs) self.post_process_fn(batch, buffer, indice) @@ -156,8 +162,9 @@ def compute_episodic_return( gae_lambda: float = 0.95, rew_norm: bool = False, ) -> Batch: - """Compute returns over given full-length episodes, including the - implementation of Generalized Advantage Estimator (arXiv:1506.02438). + """Compute returns over given full-length episodes. + + Implementation of Generalized Advantage Estimator (arXiv:1506.02438). :param batch: a data batch which contains several full-episode data chronologically. @@ -192,13 +199,13 @@ def compute_nstep_return( n_step: int = 1, rew_norm: bool = False, ) -> Batch: - r"""Compute n-step return for Q-learning targets: + r"""Compute n-step return for Q-learning targets. .. math:: G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i + \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n}) - , where :math:`\gamma` is the discount factor, + where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`, :math:`d_t` is the done flag of step :math:`t`. @@ -249,7 +256,7 @@ def _episodic_return( v_s_: np.ndarray, rew: np.ndarray, done: np.ndarray, gamma: float, gae_lambda: float, ) -> np.ndarray: - """Numba speedup: 4.1s -> 0.057s""" + """Numba speedup: 4.1s -> 0.057s.""" returns = np.roll(v_s_, 1) m = (1. 
- done) * gamma delta = rew + v_s_ * m - returns @@ -267,7 +274,7 @@ def _nstep_return( indice: np.ndarray, gamma: float, n_step: int, buf_len: int, mean: float, std: float ) -> np.ndarray: - """Numba speedup: 0.3s -> 0.15s""" + """Numba speedup: 0.3s -> 0.15s.""" returns = np.zeros(indice.shape) gammas = np.full(indice.shape, n_step) for n in range(n_step - 1, -1, -1): diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 0f7cffd58..ec872c002 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -9,15 +9,14 @@ class A2CPolicy(PGPolicy): - """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783 + """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783. :param torch.nn.Module actor: the actor network following the rules in :class:`~tianshou.policy.BasePolicy`. (s -> logits) :param torch.nn.Module critic: the critic network. (s -> V(s)) :param torch.optim.Optimizer optim: the optimizer for actor and critic network. - :param torch.distributions.Distribution dist_fn: for computing the action, - defaults to ``torch.distributions.Categorical``. + :param dist_fn: distribution class for computing the action. :param float discount_factor: in [0, 1], defaults to 0.99. :param float vf_coef: weight for value loss, defaults to 0.5. :param float ent_coef: weight for entropy loss, defaults to 0.01. diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 6c34e34ac..5b11ae886 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -9,7 +9,7 @@ class DDPGPolicy(BasePolicy): - """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971 + """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971. :param torch.nn.Module actor: the actor network following the rules in :class:`~tianshou.policy.BasePolicy`. (s -> logits) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 5c5e45d5a..95e7be6f3 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -8,12 +8,12 @@ class DQNPolicy(BasePolicy): - """Implementation of Deep Q Network. arXiv:1312.5602 + """Implementation of Deep Q Network. arXiv:1312.5602. - Implementation of Double Q-Learning. arXiv:1509.06461 + Implementation of Double Q-Learning. arXiv:1509.06461. Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is - implemented in the network side, not here) + implemented in the network side, not here). :param torch.nn.Module model: a model following the rules in :class:`~tianshou.policy.BasePolicy`. (s -> logits) @@ -87,8 +87,10 @@ def _target_q(self, buffer: ReplayBuffer, def process_fn(self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray) -> Batch: - """Compute the n-step return for Q-learning targets. More details can - be found at :meth:`~tianshou.policy.BasePolicy.compute_nstep_return`. + """Compute the n-step return for Q-learning targets. + + More details can be found at + :meth:`~tianshou.policy.BasePolicy.compute_nstep_return`. """ batch = self.compute_nstep_return( batch, buffer, indice, self._target_q, @@ -101,9 +103,10 @@ def forward(self, batch: Batch, input: str = 'obs', eps: Optional[float] = None, **kwargs) -> Batch: - """Compute action over the given batch data. If you need to mask the - action, please add a "mask" into batch.obs, for example, if we have an - environment that has "0/1/2" three actions: + """Compute action over the given batch data. 
+ + If you need to mask the action, please add a "mask" into batch.obs, for + example, if we have an environment that has "0/1/2" three actions: :: batch == Batch( diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 3eaae641e..e0a52126e 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -12,7 +12,7 @@ class PGPolicy(BasePolicy): :param torch.nn.Module model: a model following the rules in :class:`~tianshou.policy.BasePolicy`. (s -> logits) :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. - :param torch.distributions.Distribution dist_fn: for computing the action. + :param dist_fn: distribution class for computing the action. :param float discount_factor: in [0, 1]. .. seealso:: @@ -38,12 +38,12 @@ def __init__(self, def process_fn(self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray) -> Batch: - r"""Compute the discounted returns for each frame: + r"""Compute the discounted returns for each frame. .. math:: G_t = \sum_{i=t}^T \gamma^{i-t}r_i - , where :math:`T` is the terminal time step, :math:`\gamma` is the + where :math:`T` is the terminal time step, :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`. """ # batch.returns = self._vanilla_returns(batch) diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 2db5baf3b..9a4d1a26f 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -8,14 +8,14 @@ class PPOPolicy(PGPolicy): - r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347 + r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347. :param torch.nn.Module actor: the actor network following the rules in :class:`~tianshou.policy.BasePolicy`. (s -> logits) :param torch.nn.Module critic: the critic network. (s -> V(s)) :param torch.optim.Optimizer optim: the optimizer for actor and critic network. - :param torch.distributions.Distribution dist_fn: for computing the action. + :param dist_fn: distribution class for computing the action. :param float discount_factor: in [0, 1], defaults to 0.99. :param float max_grad_norm: clipping gradients in back propagation, defaults to ``None``. diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index dfbc60e05..efb7f9549 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -10,7 +10,7 @@ class SACPolicy(DDPGPolicy): - """Implementation of Soft Actor-Critic. arXiv:1812.05905 + """Implementation of Soft Actor-Critic. arXiv:1812.05905. :param torch.nn.Module actor: the actor network following the rules in :class:`~tianshou.policy.BasePolicy`. (s -> logits) diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 9150f3770..8ed093563 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -9,8 +9,7 @@ class TD3Policy(DDPGPolicy): - """Implementation of Twin Delayed Deep Deterministic Policy Gradient, - arXiv:1802.09477 + """Implementation of TD3, arXiv:1802.09477. :param torch.nn.Module actor: the actor network following the rules in :class:`~tianshou.policy.BasePolicy`. 
(s -> logits) diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index 74ab0ecac..bf7b18229 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -6,7 +6,9 @@ class MultiAgentPolicyManager(BasePolicy): - """This multi-agent policy manager accepts a list of + """Multi-agent policy manager for MARL. + + This multi-agent policy manager accepts a list of :class:`~tianshou.policy.BasePolicy`. It dispatches the batch data to each of these policies when the "forward" is called. The same as "process_fn" and "learn": it splits the data and feeds them to each policy. A figure in @@ -28,7 +30,9 @@ def replace_policy(self, policy, agent_id): def process_fn(self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray) -> Batch: - """Save original multi-dimensional rew in "save_rew", set rew to the + """Dispatch batch data from obs.agent_id to every policy's process_fn. + + Save original multi-dimensional rew in "save_rew", set rew to the reward of each agent during their ``process_fn``, and restore the original reward afterwards. """ @@ -57,7 +61,9 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer, def forward(self, batch: Batch, state: Optional[Union[dict, Batch]] = None, **kwargs) -> Batch: - """:param state: if None, it means all agents have no state. If not + """Dispatch batch data from obs.agent_id to every policy's forward. + + :param state: if None, it means all agents have no state. If not None, it should contain keys of "agent_1", "agent_2", ... :return: a Batch with the following contents: @@ -120,7 +126,9 @@ def forward(self, batch: Batch, def learn(self, batch: Batch, **kwargs ) -> Dict[str, Union[float, List[float]]]: - """:return: a dict with the following contents: + """Dispatch the data to all policies for learning. + + :return: a dict with the following contents: :: diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py index a300e8c92..baac7425d 100644 --- a/tianshou/policy/random.py +++ b/tianshou/policy/random.py @@ -6,19 +6,20 @@ class RandomPolicy(BasePolicy): - """A random agent used in multi-agent learning. It randomly chooses an - action from the legal action. + """A random agent used in multi-agent learning. + + It randomly chooses an action from the legal action. """ def forward(self, batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, **kwargs) -> Batch: - """Compute the random action over the given batch data. The input - should contain a mask in batch.obs, with "True" to be available and - "False" to be unavailable. - For example, ``batch.obs.mask == np.array([[False, True, False]])`` - means with batch size 1, action "1" is available but action "0" and - "2" are unavailable. + """Compute the random action over the given batch data. + + The input should contain a mask in batch.obs, with "True" to be + available and "False" to be unavailable. For example, + ``batch.obs.mask == np.array([[False, True, False]])`` means with batch + size 1, action "1" is available but action "0" and "2" are unavailable. :return: A :class:`~tianshou.data.Batch` with "act" key, containing the random action. 
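
[Editor's note] A tiny usage sketch of the masking convention described in this docstring; the mask values are only an illustration and the assertion simply restates what the docstring promises (action "1" is the single legal action):
::

    import numpy as np
    from tianshou.data import Batch
    from tianshou.policy import RandomPolicy

    policy = RandomPolicy()
    # batch size 1: only action "1" is marked available by the mask
    batch = Batch(obs=Batch(mask=np.array([[False, True, False]])))
    act = policy(batch).act   # calls RandomPolicy.forward
    assert act[0] == 1        # the only legal action is chosen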
@@ -35,6 +36,5 @@ def forward(self, batch: Batch, def learn(self, batch: Batch, **kwargs ) -> Dict[str, Union[float, List[float]]]: - """No need of a learn function for a random agent, so it returns an - empty dict.""" + """Since a random agent learn nothing, it returns an empty dict.""" return {} diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index a7c6ffa36..584a2bf5e 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -28,8 +28,9 @@ def offpolicy_trainer( verbose: bool = True, test_in_train: bool = True, ) -> Dict[str, Union[float, str]]: - """A wrapper for off-policy trainer procedure. The ``step`` in trainer - means a policy network update. + """A wrapper for off-policy trainer procedure. + + The ``step`` in trainer means a policy network update. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 0a564fc53..383002e22 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -28,8 +28,9 @@ def onpolicy_trainer( verbose: bool = True, test_in_train: bool = True, ) -> Dict[str, Union[float, str]]: - """A wrapper for on-policy trainer procedure. The ``step`` in trainer means - a policy network update. + """A wrapper for on-policy trainer procedure. + + The ``step`` in trainer means a policy network update. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. diff --git a/tianshou/utils/compile.py b/tianshou/utils/compile.py index b3700b3cb..bf051bd74 100644 --- a/tianshou/utils/compile.py +++ b/tianshou/utils/compile.py @@ -1,12 +1,13 @@ import numpy as np -# functions that need to pre-compile for producing benchmark result from tianshou.policy.base import _episodic_return, _nstep_return from tianshou.data.utils.segtree import _reduce, _setitem, _get_prefix_sum_idx def pre_compile(): - """Since Numba acceleration needs to compile the function in the first run, + """Functions that need to pre-compile for producing benchmark result. + + Since Numba acceleration needs to compile the function in the first run, here we use some fake data for the common-type function-call compilation. Otherwise, the current training speed cannot compare with the previous. """ diff --git a/tianshou/utils/moving_average.py b/tianshou/utils/moving_average.py index 7dfc0b1a0..a138b1c1a 100644 --- a/tianshou/utils/moving_average.py +++ b/tianshou/utils/moving_average.py @@ -6,8 +6,9 @@ class MovAvg(object): - """Class for moving average. It will automatically exclude the infinity and - NaN. Usage: + """Class for moving average. + + It will automatically exclude the infinity and NaN. Usage: :: >>> stat = MovAvg(size=66) @@ -30,8 +31,10 @@ def __init__(self, size: int = 100) -> None: self.banned = [np.inf, np.nan, -np.inf] def add(self, x: Union[float, list, np.ndarray, torch.Tensor]) -> float: - """Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with - only one element, a python scalar, or a list of python scalar. + """Add a scalar into :class:`MovAvg`. + + You can add ``torch.Tensor`` with only one element, a python scalar, or + a list of python scalar. 
""" if isinstance(x, torch.Tensor): x = to_numpy(x.flatten()) diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index eb68a9710..cc64855ab 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -8,6 +8,7 @@ def miniblock(inp: int, oup: int, norm_layer: nn.modules.Module) -> List[nn.modules.Module]: + """Construct a miniblock with given input/output-size and norm layer.""" ret = [nn.Linear(inp, oup)] if norm_layer is not None: ret += [norm_layer(oup)] @@ -16,8 +17,10 @@ def miniblock(inp: int, oup: int, class Net(nn.Module): - """Simple MLP backbone. For advanced usage (how to customize the network), - please refer to :ref:`build_the_network`. + """Simple MLP backbone. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. :param bool concat: whether the input shape is concatenated by state_shape and action_shape. If it is True, ``action_shape`` is not the output @@ -25,7 +28,7 @@ class Net(nn.Module): :param bool dueling: whether to use dueling network to calculate Q values (for Dueling DQN), defaults to False. :param nn.modules.Module norm_layer: use which normalization before ReLU, - e.g., ``nn.LayerNorm`` and ``nn.BatchNorm1d``, defaults to None. + e.g., ``nn.LayerNorm`` and ``nn.BatchNorm1d``, defaults to "None". """ def __init__(self, layer_num: int, state_shape: tuple, @@ -76,7 +79,7 @@ def __init__(self, layer_num: int, state_shape: tuple, self.model = nn.Sequential(*self.model) def forward(self, s, state=None, info={}): - """s -> flatten -> logits""" + """Mapping: s -> flatten -> logits.""" s = to_torch(s, device=self.device, dtype=torch.float32) s = s.reshape(s.size(0), -1) logits = self.model(s) @@ -89,8 +92,10 @@ def forward(self, s, state=None, info={}): class Recurrent(nn.Module): - """Simple Recurrent network based on LSTM. For advanced usage (how to - customize the network), please refer to :ref:`build_the_network`. + """Simple Recurrent network based on LSTM. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. """ def __init__(self, layer_num, state_shape, action_shape, @@ -106,9 +111,11 @@ def __init__(self, layer_num, state_shape, action_shape, self.fc2 = nn.Linear(hidden_layer_size, np.prod(action_shape)) def forward(self, s, state=None, info={}): - """In the evaluation mode, s should be with shape ``[bsz, dim]``; in - the training mode, s should be with shape ``[bsz, len, dim]``. See the - code and comment for more detail. + """Mapping: s -> flatten -> logits. + + In the evaluation mode, s should be with shape ``[bsz, dim]``; in the + training mode, s should be with shape ``[bsz, len, dim]``. See the code + and comment for more detail. """ s = to_torch(s, device=self.device, dtype=torch.float32) # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 03a11f59c..19586bb46 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -6,7 +6,9 @@ class Actor(nn.Module): - """For advanced usage (how to customize the network), please refer to + """Simple actor network with MLP. + + For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. 
""" @@ -18,14 +20,16 @@ def __init__(self, preprocess_net, action_shape, max_action=1., self._max = max_action def forward(self, s, state=None, info={}): - """s -> logits -> action""" + """Mapping: s -> logits -> action.""" logits, h = self.preprocess(s, state) logits = self._max * torch.tanh(self.last(logits)) return logits, h class Critic(nn.Module): - """For advanced usage (how to customize the network), please refer to + """Simple critic network with MLP. + + For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. """ @@ -36,7 +40,7 @@ def __init__(self, preprocess_net, device='cpu', hidden_layer_size=128): self.last = nn.Linear(hidden_layer_size, 1) def forward(self, s, a=None, info={}): - """(s, a) -> logits -> Q(s, a)""" + """Mapping: (s, a) -> logits -> Q(s, a).""" s = to_torch(s, device=self.device, dtype=torch.float32) s = s.flatten(1) if a is not None: @@ -49,7 +53,9 @@ def forward(self, s, a=None, info={}): class ActorProb(nn.Module): - """For advanced usage (how to customize the network), please refer to + """Simple actor network (output with a Gauss distribution) with MLP. + + For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. """ @@ -64,7 +70,7 @@ def __init__(self, preprocess_net, action_shape, max_action=1., self._unbounded = unbounded def forward(self, s, state=None, info={}): - """s -> logits -> (mu, sigma)""" + """Mapping: s -> logits -> (mu, sigma).""" logits, h = self.preprocess(s, state) mu = self.mu(logits) if not self._unbounded: @@ -76,7 +82,9 @@ def forward(self, s, state=None, info={}): class RecurrentActorProb(nn.Module): - """For advanced usage (how to customize the network), please refer to + """Recurrent version of ActorProb. + + For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. """ @@ -121,7 +129,9 @@ def forward(self, s, state=None, info={}): class RecurrentCritic(nn.Module): - """For advanced usage (how to customize the network), please refer to + """Recurrent version of Critic. + + For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. """ diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index c7fed2bcb..03f458398 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -5,7 +5,9 @@ class Actor(nn.Module): - """For advanced usage (how to customize the network), please refer to + """Simple actor network with MLP. + + For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. """ @@ -15,14 +17,16 @@ def __init__(self, preprocess_net, action_shape, hidden_layer_size=128): self.last = nn.Linear(hidden_layer_size, np.prod(action_shape)) def forward(self, s, state=None, info={}): - r"""s -> Q(s, \*)""" + r"""Mapping: s -> Q(s, \*).""" logits, h = self.preprocess(s, state) logits = F.softmax(self.last(logits), dim=-1) return logits, h class Critic(nn.Module): - """For advanced usage (how to customize the network), please refer to + """Simple critic network with MLP. + + For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. 
""" @@ -32,17 +36,17 @@ def __init__(self, preprocess_net, hidden_layer_size=128): self.last = nn.Linear(hidden_layer_size, 1) def forward(self, s, **kwargs): - """s -> V(s)""" + """Mapping: s -> V(s).""" logits, h = self.preprocess(s, state=kwargs.get('state', None)) logits = self.last(logits) return logits class DQN(nn.Module): - """For advanced usage (how to customize the network), please refer to - :ref:`build_the_network`. + """Reference: Human-level control through deep reinforcement learning. - Reference paper: "Human-level control through deep reinforcement learning". + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. """ def __init__(self, c, h, w, action_shape, device='cpu'): @@ -78,7 +82,7 @@ def conv2d_layers_size_out(size, ) def forward(self, x, state=None, info={}): - r"""x -> Q(x, \*)""" + r"""Mapping: x -> Q(x, \*).""" if not isinstance(x, torch.Tensor): x = torch.tensor(x, device=self.device, dtype=torch.float32) return self.net(x), state From 4696eea508e8994d4e1e886d74b92afc841d68e4 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 10 Sep 2020 21:25:46 +0800 Subject: [PATCH 2/3] docs update --- .github/workflows/lint_and_docs.yml | 2 +- .github/workflows/profile.yml | 2 +- .github/workflows/pytest.yml | 2 +- docs/contributing.rst | 4 ++-- docs/tutorials/cheatsheet.rst | 12 ++++++---- docs/tutorials/concepts.rst | 7 +++--- docs/tutorials/dqn.rst | 8 +++---- tianshou/data/batch.py | 21 ++++++++++------- tianshou/data/buffer.py | 17 +++++++------- tianshou/data/collector.py | 8 +++---- tianshou/env/venvs.py | 8 +++---- tianshou/env/worker/base.py | 8 +++---- tianshou/policy/base.py | 32 ++++++++++++-------------- tianshou/policy/modelfree/a2c.py | 4 ++-- tianshou/policy/modelfree/ddpg.py | 4 ++-- tianshou/policy/modelfree/dqn.py | 6 ++--- tianshou/policy/modelfree/ppo.py | 6 ++--- tianshou/policy/modelfree/sac.py | 4 ++-- tianshou/policy/modelfree/td3.py | 4 ++-- tianshou/policy/multiagent/mapolicy.py | 2 +- tianshou/trainer/offpolicy.py | 2 +- tianshou/trainer/onpolicy.py | 2 +- tianshou/utils/net/common.py | 2 +- 23 files changed, 86 insertions(+), 81 deletions(-) diff --git a/.github/workflows/lint_and_docs.yml b/.github/workflows/lint_and_docs.yml index 50e307a68..07e1defe0 100644 --- a/.github/workflows/lint_and_docs.yml +++ b/.github/workflows/lint_and_docs.yml @@ -16,7 +16,7 @@ jobs: python -m pip install --upgrade pip setuptools wheel - name: Install dependencies run: | - pip install ".[dev]" --upgrade + python -m pip install ".[dev]" --upgrade - name: Lint with flake8 run: | flake8 . 
--count --show-source --statistics diff --git a/.github/workflows/profile.yml b/.github/workflows/profile.yml index 3bdd8ea8a..5b0f3581b 100644 --- a/.github/workflows/profile.yml +++ b/.github/workflows/profile.yml @@ -16,7 +16,7 @@ jobs: python -m pip install --upgrade pip setuptools wheel - name: Install dependencies run: | - pip install ".[dev]" --upgrade + python -m pip install ".[dev]" --upgrade - name: Test with pytest run: | pytest test/throughput --durations=0 -v diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 84ab74f65..1b4277010 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -26,7 +26,7 @@ jobs: python -m pip install --upgrade pip setuptools wheel - name: Install dependencies run: | - pip install ".[dev]" --upgrade + python -m pip install ".[dev]" --upgrade - name: Test with pytest # ignore test/throughput which only profiles the code run: | diff --git a/docs/contributing.rst b/docs/contributing.rst index dc7be4f07..b56a7a440 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -77,8 +77,8 @@ under the ``docs/`` directory. The generated webpages are in ``docs/_build`` and Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/. -Documentation Test ------------------- +Documentation Generation Test +----------------------------- We have the following three documentation tests: diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index 2784eac3f..fd58834c3 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -70,7 +70,7 @@ Asynchronous simulation is a built-in functionality of :class:`~tianshou.env.Bas # DummyVectorEnv, ShmemVectorEnv, or RayVectorEnv, whichever you like. venv = SubprocVectorEnv(env_fns, wait_num=3, timeout=0.2) venv.reset() # returns the initial observations of each environment - # returns ``wait_num`` steps or finished steps after ``timeout`` seconds, + # returns "wait_num" steps or finished steps after "timeout" seconds, # whichever occurs first. venv.step(actions, ready_id) @@ -101,7 +101,11 @@ This is related to `Issue 42 `_. If you want to get log stat from data stream / pre-process batch-image / modify the reward with given env info, use ``preproces_fn`` in :class:`~tianshou.data.Collector`. This is a hook which will be called before the data adding into the buffer. -This function receives up to 7 keys ``obs``, ``act``, ``rew``, ``done``, ``obs_next``, ``info``, and ``policy``, as listed in :class:`~tianshou.data.Batch`, and returns the modified part within a :class:`~tianshou.data.Batch`. Only ``obs`` is defined at env.reset, while every key is specified for normal steps. For example, you can write your hook as: +This function receives up to 7 keys ``obs``, ``act``, ``rew``, ``done``, ``obs_next``, ``info``, and ``policy``, as listed in :class:`~tianshou.data.Batch`. It returns the modified part within a :class:`~tianshou.data.Batch`. Only ``obs`` is defined at env.reset, while every key is specified for normal steps. + +These variables are intended to gather all the information requires to keep track of a simulation step, namely the (observation, action, reward, done flag, next observation, info, intermediate result of the policy) at time t, for the whole duration of the simulation. + +For example, you can write your hook as: :: import numpy as np @@ -152,7 +156,7 @@ RNN-style Training This is related to `Issue 19 `_. 
-First, add an argument ``stack_num`` to :class:`~tianshou.data.ReplayBuffer`: +First, add an argument "stack_num" to :class:`~tianshou.data.ReplayBuffer`: :: buf = ReplayBuffer(size=size, stack_num=stack_num) @@ -188,7 +192,7 @@ First of all, your self-defined environment must follow the Gym's API, some of t - action_space: gym.Space -The state can be a ``numpy.ndarray`` or a Python dictionary. Take ``FetchReach-v1`` as an example: +The state can be a ``numpy.ndarray`` or a Python dictionary. Take "FetchReach-v1" as an example: :: >>> e = gym.make('FetchReach-v1') diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index c09db285d..d7f5971d2 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -80,10 +80,9 @@ policy.forward The ``forward`` function computes the action over given observations. The input and output is algorithm-specific but generally, the function is a mapping of ``(batch, state, ...) -> batch``. -The input batch is the environment data. The first dimension of all variables in the input ``batch`` should be equal to the batch-size. +The input batch is the environment data (e.g., observation, reward, done flag and info). It comes from either :meth:`~tianshou.data.Collector.collect` or :meth:`~tianshou.data.ReplayBuffer.sample`. The first dimension of all variables in the input ``batch`` should be equal to the batch-size. -The output is also a Batch which may contain "act", "state", "policy", and some other algorithm-specific keys. -The keyword "policy" is reserved and the corresponding data will be stored into the replay buffer. Checkout :meth:`~tianshou.policy.BasePolicy.forward` for more explanation. +The output is also a Batch which must contain "act" (action) and may contain "state" (hidden state of policy), "policy" (the intermediate result of policy which needs to save into the buffer, see :meth:`~tianshou.policy.BasePolicy.forward`), and some other algorithm-specific keys. For example, if you try to use your policy to evaluate one episode (and don't want to use :meth:`~tianshou.data.Collector.collect`), use the following code-snippet: :: @@ -164,7 +163,7 @@ The :class:`~tianshou.data.Collector` enables the policy to interact with differ Why do we mention **at least** here? For multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically. -Our solution is to add some cache buffers inside the collector. Once collecting **a full episode of trajectory**, it will move the stored data from the cache buffer to the main buffer. To satisfy this condition, the collector will interact with environments that may exceed the given step number or episode number. +The proposed solution is to add some cache buffers inside the collector. Once collecting **a full episode of trajectory**, it will move the stored data from the cache buffer to the main buffer. To satisfy this condition, the collector will interact with environments that may exceed the given step number or episode number. The general explanation is listed in :ref:`pseudocode`. Other usages of collector are listed in :class:`~tianshou.data.Collector` documentation. 
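
[Editor's note] As a small follow-up to the ``stack_num`` argument mentioned in the cheatsheet hunk above, a toy usage sketch with scalar observations (the values are arbitrary); it relies only on the ``ReplayBuffer.add``/``sample`` signatures shown elsewhere in this patch:
::

    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=20, stack_num=4)
    for i in range(5):
        buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1)
    batch, indice = buf.sample(0)   # batch_size=0 returns everything
    # each sampled "obs" now carries the last 4 stacked observations;
    # indices near the episode start simply repeat the earliest frame
    print(batch.obs)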
diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index a391069fd..d923b5669 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -26,7 +26,7 @@ CartPole-v0 is a simple environment with a discrete action space, for which DQN Setup Multi-environment Wrapper ------------------------------- -It is okay if you want the original ``gym.Env``: +If you want to use the original ``gym.Env``: :: train_envs = gym.make('CartPole-v0') @@ -59,7 +59,7 @@ For the demonstration, here we use the second code-block. Build the Network ----------------- -Tianshou supports any user-defined PyTorch networks and optimizers only with the limitation of input and output API. Here is an example: +Tianshou supports any user-defined PyTorch networks and optimizers. Yet, of course, the inputs and outputs must comply with Tianshou's API. Here is an example: :: import torch, numpy as np @@ -87,10 +87,10 @@ Tianshou supports any user-defined PyTorch networks and optimizers only with the net = Net(state_shape, action_shape) optim = torch.optim.Adam(net.parameters(), lr=1e-3) -You can have a try with those pre-defined networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are: +It is also possible to use pre-defined MLP networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are: 1. Input: observation ``obs`` (may be a ``numpy.ndarray``, ``torch.Tensor``, dict, or self-defined class), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment. -2. Output: some ``logits``, the next hidden state ``state``. The logits could be a tuple instead of a ``torch.Tensor``, or containing intermediate result during the policy forwarding procedure. It depends on how the policy class process the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for Gaussian policy. +2. Output: some ``logits``, the next hidden state ``state``. The logits could be a tuple instead of a ``torch.Tensor``, or some other useful variables or results during the policy forwarding procedure. It depends on how the policy class process the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for Gaussian policy. Setup Policy diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 3726126e5..7f7f10b3b 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -20,7 +20,7 @@ def _is_batch_set(data: Any) -> bool: # or 1-D np.ndarray with np.object type, # where each element is a dict/Batch object if isinstance(data, np.ndarray): # most often case - # ``for e in data`` will just unpack the first dimension, + # "for e in data" will just unpack the first dimension, # but data.tolist() will flatten ndarray of objects # so do not use data.tolist() return data.dtype == np.object and \ @@ -155,8 +155,13 @@ def _parse_value(v: Any): class Batch: - """Batch the internal data structure to pass any kind of data to other \ - methods, for example, a collector gives a batch data to policy.forward(). + """The internal data structure in Tianshou. + + Batch is a kind of supercharged array (of temporal data) stored + individually in a (recursive) dictionary of object that can be either numpy + array, torch tensor, or batch themself. 
It is designed to make it extremely + easily to access, manipulate and set partial view of the heterogeneous data + conveniently. For a detailed description, please refer to :ref:`batch_concept`. """ @@ -571,7 +576,7 @@ def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': def empty_(self, index: Union[ str, slice, int, np.integer, np.ndarray, List[int]] = None ) -> 'Batch': - """Return an empty Batch object with 0 or "None" filled. + """Return an empty Batch object with 0 or None filled. If "index" is specified, it will only reset the specific indexed-data. :: @@ -620,7 +625,7 @@ def empty_(self, index: Union[ def empty(batch: 'Batch', index: Union[ str, slice, int, np.integer, np.ndarray, List[int]] = None ) -> 'Batch': - """Return an empty Batch object with 0 or "None" filled. + """Return an empty Batch object with 0 or None filled. The shape is the same as the given Batch. """ @@ -707,11 +712,11 @@ def split(self, size: int, shuffle: bool = True, """Split whole data into multiple small batches. :param int size: divide the data batch with the given size, but one - batch if the length of the batch is smaller than ``size``. + batch if the length of the batch is smaller than "size". :param bool shuffle: randomly shuffle the entire data batch if it is - ``True``, otherwise remain in the same. Default to ``True``. + True, otherwise remain in the same. Default to True. :param bool merge_last: merge the last batch into the previous one. - Default to ``False``. + Default to False. """ length = len(self) assert 1 <= size # size can be greater than length, return whole batch diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index c636baf19..e9f47b9d8 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -115,13 +115,12 @@ class ReplayBuffer: :param int size: the size of replay buffer. :param int stack_num: the frame-stack sampling argument, should be greater than or equal to 1, defaults to 1 (no stacking). - :param bool ignore_obs_next: whether to store obs_next, defaults to - ``False``. + :param bool ignore_obs_next: whether to store obs_next, defaults to False. :param bool save_only_last_obs: only save the last obs/obs_next when it has a shape of (timestep, ...) because of temporal stacking, defaults to - ``False``. + False. :param bool sample_avail: the parameter indicating sampling only available - index when using frame-stack sampling method, defaults to ``False``. + index when using frame-stack sampling method, defaults to False. This feature is not supported in Prioritized Replay Buffer currently. """ @@ -162,7 +161,7 @@ def __setstate__(self, state): """Unpickling interface. We need it because pickling buffer does not work out-of-the-box - (``buffer.__getattr__`` is customized). + ("buffer.__getattr__" is customized). """ self.__dict__.update(state) @@ -264,7 +263,7 @@ def reset(self) -> None: def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: """Get a random sample from buffer with size equal to batch_size. \ - Return all the data in the buffer if batch_size is ``0``. + Return all the data in the buffer if batch_size is 0. :return: Sample data and its corresponding index inside the buffer. """ @@ -352,7 +351,7 @@ class ListReplayBuffer(ReplayBuffer): The function of :class:`~tianshou.data.ListReplayBuffer` is almost the same as :class:`~tianshou.data.ReplayBuffer`. The only difference is that - :class:`~tianshou.data.ListReplayBuffer` is based on ``list``. Therefore, + :class:`~tianshou.data.ListReplayBuffer` is based on list. 
Therefore, it does not support advanced indexing, which means you cannot sample a batch of data out of it. It is typically used for storing data. @@ -426,11 +425,11 @@ def add(self, def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: """Get a random sample from buffer with priority probability. - Return all the data in the buffer if batch_size is ``0``. + Return all the data in the buffer if batch_size is 0. :return: Sample data and its corresponding index inside the buffer. - The ``weight`` in the returned Batch is the weight on loss function + The "weight" in the returned Batch is the weight on loss function to de-bias the sampling process (some transition tuples are sampled more often so their losses are weighted less). """ diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 5c2f28068..65f869a3e 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -24,7 +24,7 @@ class Collector(object): class. If set to ``None`` (testing phase), it will not store the data. :param function preprocess_fn: a function called before the data has been added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults - to ``None``. + to None. :param BaseNoise action_noise: add a noise to continuous action. Normally a policy already has a noise param for exploration in training phase, so this is recommended to use in test collector for some purpose. @@ -172,11 +172,11 @@ def collect(self, a list, it means to collect exactly ``n_episode[i]`` episodes in the i-th environment :param bool random: whether to use random policy for collecting data, - defaults to ``False``. + defaults to False. :param float render: the sleep time between rendering consecutive - frames, defaults to ``None`` (no rendering). + frames, defaults to None (no rendering). :param bool no_grad: whether to retain gradient in policy.forward, - defaults to ``True`` (no gradient retaining). + defaults to True (no gradient retaining). .. note:: diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index fa6cf812d..72c5b9e23 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -105,7 +105,7 @@ def __getattribute__(self, key: str) -> Any: Any class who inherits ``gym.Env`` will inherit some attributes, like ``action_space``. However, we would like the attribute lookup to go straight into the worker (in fact, this vector env's action_space is - always ``None``). + always None). """ if key in ['metadata', 'reward_range', 'spec', 'action_space', 'observation_space']: # reserved keys in gym.Env @@ -141,7 +141,7 @@ def reset(self, id: Optional[Union[int, List[int], np.ndarray]] = None ) -> np.ndarray: """Reset the state of some envs and return initial observations. - If id is "None", reset the state of all the environments and return + If id is None, reset the state of all the environments and return initial observations, otherwise reset the specific environments with the given id, either an int or a list. """ @@ -158,7 +158,7 @@ def step(self, ) -> List[np.ndarray]: """Run one timestep of some environments' dynamics. - If id is "None", run one timestep of all the environments’ dynamics; + If id is None, run one timestep of all the environments’ dynamics; otherwise run one timestep for some environments with given id, either an int or a list. When the end of episode is reached, you are responsible for calling reset(id) to reset this environment’s state. 
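
[Editor's note] For reference, a short sketch of the id-based ``reset``/``step`` calls documented in this hunk; CartPole is only an example environment and the action values are placeholders:
::

    import gym
    import numpy as np
    from tianshou.env import DummyVectorEnv

    venv = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(4)])
    obs = venv.reset()              # reset all 4 envs, one observation each
    obs = venv.reset(id=[0, 2])     # reset only env 0 and env 2
    act = np.array([0, 1])          # one action per selected env
    obs, rew, done, info = venv.step(act, id=[0, 2])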
@@ -185,7 +185,7 @@ def step(self, should correspond to the ``id`` argument, and the ``id`` argument should be a subset of the ``env_id`` in the last returned ``info`` (initially they are env_ids of all the environments). If action is - "None", fetch unfinished step() calls instead. + None, fetch unfinished step() calls instead. """ self._assert_is_not_closed() id = self._wrap_id(id) diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index e1a0708e7..c3600fafe 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -32,10 +32,10 @@ def step(self, action: np.ndarray ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Perform one timestep of the environment's dynamic. - ``send_action`` and ``get_result`` are coupled in sync simulation, so - typically users only call ``step`` function. But they can be called - separately in async simulation, i.e. someone calls ``send_action`` - first, and calls ``get_result`` later. + "send_action" and "get_result" are coupled in sync simulation, so + typically users only call "step" function. But they can be called + separately in async simulation, i.e. someone calls "send_action" first, + and calls "get_result" later. """ self.send_action(action) return self.get_result() diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 0a079ca9f..a380670c9 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -31,27 +31,25 @@ class BasePolicy(ABC, nn.Module): Most of the policy needs a neural network to predict the action and an optimizer to optimize the policy. The rules of self-defined networks are: - 1. Input: observation ``obs`` (may be a ``numpy.ndarray``, a \ - ``torch.Tensor``, a dict or any others), hidden state ``state`` (for \ - RNN usage), and other information ``info`` provided by the \ - environment. - 2. Output: some ``logits``, the next hidden state ``state``, and the \ - intermediate result during policy forwarding procedure ``policy``. The\ - ``logits`` could be a tuple instead of a ``torch.Tensor``. It depends \ - on how the policy process the network output. For example, in PPO, the\ - return of the network might be ``(mu, sigma), state`` for Gaussian \ - policy. The ``policy`` can be a Batch of torch.Tensor or other things,\ - which will be stored in the replay buffer, and can be accessed in the \ - policy update process (e.g. in ``policy.learn()``, the \ - ``batch.policy`` is what you need). + 1. Input: observation "obs" (may be a ``numpy.ndarray``, a \ + ``torch.Tensor``, a dict or any others), hidden state "state" (for RNN \ + usage), and other information "info" provided by the environment. + 2. Output: some "logits", the next hidden state "state", and the \ + intermediate result during policy forwarding procedure "policy". The \ + "logits" could be a tuple instead of a ``torch.Tensor``. It depends on how\ + the policy process the network output. For example, in PPO, the return of \ + the network might be ``(mu, sigma), state`` for Gaussian policy. The \ + "policy" can be a Batch of torch.Tensor or other things, which will be \ + stored in the replay buffer, and can be accessed in the policy update \ + process (e.g. in "policy.learn()", the "batch.policy" is what you need). 
Since :class:`~tianshou.policy.BasePolicy` inherits ``torch.nn.Module``, you can use :class:`~tianshou.policy.BasePolicy` almost the same as ``torch.nn.Module``, for instance, loading and saving the model: :: - torch.save(policy.state_dict(), 'policy.pth') - policy.load_state_dict(torch.load('policy.pth')) + torch.save(policy.state_dict(), "policy.pth") + policy.load_state_dict(torch.load("policy.pth")) """ def __init__(self, @@ -176,7 +174,7 @@ def compute_episodic_return( :param float gae_lambda: the parameter for Generalized Advantage Estimation, should be in [0, 1], defaults to 0.95. :param bool rew_norm: normalize the reward to Normal(0, 1), defaults - to ``False``. + to False. :return: a Batch. The result will be stored in batch.returns as a numpy array with shape (bsz, ). @@ -223,7 +221,7 @@ def compute_nstep_return( :param int n_step: the number of estimation step, should be an int greater than 0, defaults to 1. :param bool rew_norm: normalize the reward to Normal(0, 1), defaults - to ``False``. + to False. :return: a Batch. The result will be stored in batch.returns as a torch.Tensor with shape (bsz, ). diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index ec872c002..dfdc02db4 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -21,11 +21,11 @@ class A2CPolicy(PGPolicy): :param float vf_coef: weight for value loss, defaults to 0.5. :param float ent_coef: weight for entropy loss, defaults to 0.01. :param float max_grad_norm: clipping gradients in back propagation, - defaults to ``None``. + defaults to None. :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation, defaults to 0.95. :param bool reward_normalization: normalize the reward to Normal(0, 1), - defaults to ``False``. + defaults to False. :param int max_batchsize: the maximum size of the batch when computing GAE, depends on the size of available memory and the memory cost of the model; should be as large as possible within the memory constraint; diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 5b11ae886..f8cf60ade 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -25,9 +25,9 @@ class DDPGPolicy(BasePolicy): :param action_range: the action range (minimum, maximum). :type action_range: (float, float) :param bool reward_normalization: normalize the reward to Normal(0, 1), - defaults to ``False``. + defaults to False. :param bool ignore_done: ignore the done flag while training the policy, - defaults to ``False``. + defaults to False. :param int estimation_step: greater than 1, the number of steps to look ahead. diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 95e7be6f3..401875070 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -21,10 +21,10 @@ class DQNPolicy(BasePolicy): :param float discount_factor: in [0, 1]. :param int estimation_step: greater than 1, the number of steps to look ahead. - :param int target_update_freq: the target network update frequency (``0`` - if you do not use the target network). + :param int target_update_freq: the target network update frequency (0 if + you do not use the target network). :param bool reward_normalization: normalize the reward to Normal(0, 1), - defaults to ``False``. + defaults to False. .. seealso::
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 9a4d1a26f..df84eb989 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -18,7 +18,7 @@ class PPOPolicy(PGPolicy): :param dist_fn: distribution class for computing the action. :param float discount_factor: in [0, 1], defaults to 0.99. :param float max_grad_norm: clipping gradients in back propagation, - defaults to ``None``. + defaults to None. :param float eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original paper, defaults to 0.2. :param float vf_coef: weight for value loss, defaults to 0.5. @@ -31,9 +31,9 @@ class PPOPolicy(PGPolicy): where c > 1 is a constant indicating the lower bound, defaults to 5.0 (set ``None`` if you do not want to use it). :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1, - defaults to ``True``. + defaults to True. :param bool reward_normalization: normalize the returns to Normal(0, 1), - defaults to ``True``. + defaults to True. :param int max_batchsize: the maximum size of the batch when computing GAE, depends on the size of available memory and the memory cost of the model; should be as large as possible within the memory constraint; diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index efb7f9549..920bfb167 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -35,9 +35,9 @@ class SACPolicy(DDPGPolicy): :param action_range: the action range (minimum, maximum). :type action_range: (float, float) :param bool reward_normalization: normalize the reward to Normal(0, 1), - defaults to ``False``. + defaults to False. :param bool ignore_done: ignore the done flag while training the policy, - defaults to ``False``. + defaults to False. :param BaseNoise exploration_noise: add a noise to action for exploration. This is useful when solving hard-exploration problem. diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 8ed093563..384d4b9bd 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -36,9 +36,9 @@ class TD3Policy(DDPGPolicy): :param action_range: the action range (minimum, maximum). :type action_range: (float, float) :param bool reward_normalization: normalize the reward to Normal(0, 1), - defaults to ``False``. + defaults to False. :param bool ignore_done: ignore the done flag while training the policy, - defaults to ``False``. + defaults to False. .. seealso:: diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index bf7b18229..541481e2e 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -33,7 +33,7 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer, """Dispatch batch data from obs.agent_id to every policy's process_fn. Save original multi-dimensional rew in "save_rew", set rew to the - reward of each agent during their ``process_fn``, and restore the + reward of each agent during their "process_fn", and restore the original reward afterwards. """ results = {} diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 584a2bf5e..c04a4b556 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -30,7 +30,7 @@ def offpolicy_trainer( ) -> Dict[str, Union[float, str]]: """A wrapper for off-policy trainer procedure. - The ``step`` in trainer means a policy network update.
+ The "step" in trainer means a policy network update. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 383002e22..6af43173c 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -30,7 +30,7 @@ def onpolicy_trainer( ) -> Dict[str, Union[float, str]]: """A wrapper for on-policy trainer procedure. - The ``step`` in trainer means a policy network update. + The "step" in trainer means a policy network update. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index cc64855ab..8c4fcc52b 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -28,7 +28,7 @@ class Net(nn.Module): :param bool dueling: whether to use dueling network to calculate Q values (for Dueling DQN), defaults to False. :param nn.modules.Module norm_layer: use which normalization before ReLU, - e.g., ``nn.LayerNorm`` and ``nn.BatchNorm1d``, defaults to "None". + e.g., ``nn.LayerNorm`` and ``nn.BatchNorm1d``, defaults to None. """ def __init__(self, layer_num: int, state_shape: tuple, From 51e167cfea60d2025de1268f795a97919875cfd0 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 10 Sep 2020 21:29:50 +0800 Subject: [PATCH 3/3] fix test --- docs/tutorials/cheatsheet.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index fd58834c3..c088a8db6 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -101,7 +101,7 @@ This is related to `Issue 42 `_. If you want to get log stat from data stream / pre-process batch-image / modify the reward with given env info, use ``preprocess_fn`` in :class:`~tianshou.data.Collector`. This is a hook which will be called before the data is added into the buffer. -This function receives up to 7 keys ``obs``, ``act``, ``rew``, ``done``, ``obs_next``, ``info``, and ``policy``, as listed in :class:`~tianshou.data.Batch`. It returns the modified part within a :class:`~tianshou.data.Batch`. Only ``obs`` is defined at env.reset, while every key is specified for normal steps. +This function receives up to 7 keys ``obs``, ``act``, ``rew``, ``done``, ``obs_next``, ``info``, and ``policy``, as listed in :class:`~tianshou.data.Batch`. It returns the modified part within a :class:`~tianshou.data.Batch`. Only ``obs`` is defined at env.reset, while every key is specified for normal steps. These variables are intended to gather all the information required to keep track of a simulation step, namely the (observation, action, reward, done flag, next observation, info, intermediate result of the policy) at time t, for the whole duration of the simulation.
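Following that description, a ``preprocess_fn`` can be as small as the sketch below (the reward clipping is an arbitrary example of "the modified part", not something prescribed by this patch): ::

    import numpy as np
    from tianshou.data import Batch

    def preprocess_fn(**kwargs):
        # right after env.reset() only "obs" is passed in, so there is nothing to modify
        if "rew" not in kwargs:
            return Batch()
        # for normal steps, return only the keys that were changed
        return Batch(rew=np.sign(kwargs["rew"]))  # e.g. clip rewards to {-1, 0, +1}

    # collector = Collector(policy, train_envs, buffer, preprocess_fn=preprocess_fn)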