diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index d72600676..280583538 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -3,14 +3,7 @@
+ [ ] algorithm implementation fix
+ [ ] documentation modification
+ [ ] new feature
+- [ ] I have reformatted the code using `make format` (**required**)
+- [ ] I have checked the code using `make commit-checks` (**required**)
- [ ] If applicable, I have mentioned the relevant/related issue(s)
-
-Less important but also useful:
-
-- [ ] I have visited the [source website](https://github.com/thu-ml/tianshou)
-- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
-- [ ] I have mentioned version numbers, operating system and environment, where applicable:
- ```python
- import tianshou, torch, numpy, sys
- print(tianshou.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
- ```
+- [ ] If applicable, I have listed every item in this Pull Request below
diff --git a/.github/workflows/lint_and_docs.yml b/.github/workflows/lint_and_docs.yml
index 3a69bfa71..681654689 100644
--- a/.github/workflows/lint_and_docs.yml
+++ b/.github/workflows/lint_and_docs.yml
@@ -20,13 +20,14 @@ jobs:
- name: Lint with flake8
run: |
flake8 . --count --show-source --statistics
+ - name: Code formatter
+ run: |
+ yapf -r -d .
+ isort --check .
- name: Type check
run: |
mypy
- name: Documentation test
run: |
- pydocstyle tianshou
- doc8 docs --max-line-length 1000
- cd docs
- make html SPHINXOPTS="-W"
- cd ..
+ make check-docstyle
+ make spelling
diff --git a/.github/workflows/profile.yml b/.github/workflows/profile.yml
index 3d62da417..82c793e99 100644
--- a/.github/workflows/profile.yml
+++ b/.github/workflows/profile.yml
@@ -5,6 +5,7 @@ on: [push, pull_request]
jobs:
profile:
runs-on: ubuntu-latest
+ if: "!contains(github.event.head_commit.message, 'ci skip')"
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.8
diff --git a/.gitignore b/.gitignore
index be8453abd..e9510a1df 100644
--- a/.gitignore
+++ b/.gitignore
@@ -147,3 +147,4 @@ MUJOCO_LOG.TXT
*.swp
*.pkl
*.hdf5
+wandb/
diff --git a/Makefile b/Makefile
new file mode 100644
index 000000000..da5030ccd
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,60 @@
+SHELL=/bin/bash
+PROJECT_NAME=tianshou
+PROJECT_PATH=${PROJECT_NAME}/
+LINT_PATHS=${PROJECT_PATH} test/ docs/conf.py examples/ setup.py
+
+check_install = python3 -c "import $(1)" || pip3 install $(1) --upgrade
+check_install_extra = python3 -c "import $(1)" || pip3 install $(2) --upgrade
+
+pytest:
+ $(call check_install, pytest)
+ $(call check_install, pytest_cov)
+ $(call check_install, pytest_xdist)
+ pytest test --cov ${PROJECT_PATH} --durations 0 -v --cov-report term-missing
+
+mypy:
+ $(call check_install, mypy)
+ mypy ${PROJECT_NAME}
+
+lint:
+ $(call check_install, flake8)
+ $(call check_install_extra, bugbear, flake8_bugbear)
+ flake8 ${LINT_PATHS} --count --show-source --statistics
+
+format:
+ # sort imports
+ $(call check_install, isort)
+ isort ${LINT_PATHS}
+ # reformat using yapf
+ $(call check_install, yapf)
+ yapf -ir ${LINT_PATHS}
+
+check-codestyle:
+ $(call check_install, isort)
+ $(call check_install, yapf)
+ isort --check ${LINT_PATHS} && yapf -r -d ${LINT_PATHS}
+
+check-docstyle:
+ $(call check_install, pydocstyle)
+ $(call check_install, doc8)
+ $(call check_install, sphinx)
+ $(call check_install, sphinx_rtd_theme)
+ pydocstyle ${PROJECT_PATH} && doc8 docs && cd docs && make html SPHINXOPTS="-W"
+
+doc:
+ $(call check_install, sphinx)
+ $(call check_install, sphinx_rtd_theme)
+ cd docs && make html && cd _build/html && python3 -m http.server
+
+spelling:
+ $(call check_install, sphinx)
+ $(call check_install, sphinx_rtd_theme)
+ $(call check_install_extra, sphinxcontrib.spelling, sphinxcontrib.spelling pyenchant)
+ cd docs && make spelling SPHINXOPTS="-W"
+
+clean:
+ cd docs && make clean
+
+commit-checks: format lint mypy check-docstyle spelling
+
+.PHONY: clean spelling doc mypy lint format check-codestyle check-docstyle commit-checks
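For reference, a rough Python equivalent of the `check_install` macro above, assuming nothing beyond the standard library; the Makefile stays the source of truth, this sketch only spells out the import-or-install fallback:

    import importlib
    import subprocess
    import sys
    from typing import Optional

    def check_install(module: str, pip_spec: Optional[str] = None) -> None:
        # Mirror of the Makefile macro: try to import `module`; if that fails,
        # install `pip_spec` (or `module` itself) with pip and upgrade it.
        try:
            importlib.import_module(module)
        except ImportError:
            packages = (pip_spec or module).split()
            subprocess.check_call(
                [sys.executable, "-m", "pip", "install", "--upgrade", *packages]
            )

    check_install("yapf")
    # the "extra" variant maps an import name to different pip package names:
    check_install("sphinxcontrib.spelling", "sphinxcontrib.spelling pyenchant")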
diff --git a/docs/conf.py b/docs/conf.py
index c258cf0a7..57d8b48df 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -14,9 +14,9 @@
# import sys
# sys.path.insert(0, os.path.abspath('.'))
+import sphinx_rtd_theme
import tianshou
-import sphinx_rtd_theme
# Get the version string
version = tianshou.__version__
@@ -30,7 +30,6 @@
# The full version, including alpha/beta/rc tags
release = version
-
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
@@ -51,7 +50,7 @@
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
-source_suffix = [".rst", ".md"]
+source_suffix = [".rst"]
master_doc = "index"
# List of patterns, relative to source directory, that match files and
@@ -59,7 +58,8 @@
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
autodoc_default_options = {
- "special-members": ", ".join(
+ "special-members":
+ ", ".join(
[
"__len__",
"__call__",
diff --git a/docs/contributing.rst b/docs/contributing.rst
index cf015a95e..d1de0b65b 100644
--- a/docs/contributing.rst
+++ b/docs/contributing.rst
@@ -18,14 +18,26 @@ in the main directory. This installation is removable by
$ python setup.py develop --uninstall
-PEP8 Code Style Check
----------------------
+PEP8 Code Style Check and Code Formatter
+----------------------------------------
-We follow PEP8 python code style. To check, in the main directory, run:
+We follow the PEP8 Python code style and check it with flake8. To check, in the main directory, run:
.. code-block:: bash
- $ flake8 . --count --show-source --statistics
+ $ make lint
+
+We use isort and yapf to format all code. To format, in the main directory, run:
+
+.. code-block:: bash
+
+ $ make format
+
+To check whether the code is formatted correctly, in the main directory, run:
+
+.. code-block:: bash
+
+ $ make check-codestyle
Type Check
@@ -35,7 +47,7 @@ We use `mypy `_ to check the type annotations.
.. code-block:: bash
- $ mypy
+ $ make mypy
Test Locally
@@ -45,7 +57,7 @@ This command will run automatic tests in the main directory
.. code-block:: bash
- $ pytest test --cov tianshou -s --durations 0 -v
+ $ make pytest
Test by GitHub Actions
@@ -76,13 +88,13 @@ Documentations are written under the ``docs/`` directory as ReStructuredText (``
API References are automatically generated by `Sphinx `_ according to the outlines under ``docs/api/`` and should be modified when any code changes.
-To compile documentation into webpages, run
+To compile documentation into a webpage, run
.. code-block:: bash
- $ make html
+ $ make doc
-under the ``docs/`` directory. The generated webpages are in ``docs/_build`` and can be viewed with browsers.
+The generated webpage is in ``docs/_build`` and can be viewed with a browser (http://0.0.0.0:8000/).
Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/.
@@ -92,21 +104,14 @@ Documentation Generation Test
We have the following three documentation tests:
-1. pydocstyle: test docstrings under ``tianshou/``. To check, in the main directory, run:
-
-.. code-block:: bash
-
- $ pydocstyle tianshou
-
-2. doc8: test ReStructuredText formats. To check, in the main directory, run:
+1. pydocstyle: test all docstrings under ``tianshou/``;
-.. code-block:: bash
+2. doc8: test ReStructuredText format;
- $ doc8 docs
+3. sphinx test: test if there is any error/warning when generating front-end html documentation.
-3. sphinx test: test if there is any errors/warnings when generating front-end html documentations. To check, in the main directory, run:
+To check, in the main directory, run:
.. code-block:: bash
- $ cd docs
- $ make html SPHINXOPTS="-W"
+ $ make check-docstyle
diff --git a/docs/contributor.rst b/docs/contributor.rst
index c594b2c0d..d71dc385b 100644
--- a/docs/contributor.rst
+++ b/docs/contributor.rst
@@ -4,7 +4,6 @@ Contributor
We always welcome contributions to help make Tianshou better. Below are an incomplete list of our contributors (find more on `this page `_).
* Jiayi Weng (`Trinkle23897 `_)
-* Minghao Zhang (`Mehooz `_)
* Alexis Duburcq (`duburcqa `_)
* Kaichao You (`youkaichao `_)
* Huayu Chen (`ChenDRAG `_)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
new file mode 100644
index 000000000..3649df71d
--- /dev/null
+++ b/docs/spelling_wordlist.txt
@@ -0,0 +1,135 @@
+tianshou
+arXiv
+tanh
+lr
+logits
+env
+envs
+optim
+eps
+timelimit
+TimeLimit
+maxsize
+timestep
+numpy
+ndarray
+stackoverflow
+len
+tac
+fqf
+iqn
+qrdqn
+rl
+quantile
+quantiles
+dqn
+param
+async
+subprocess
+nn
+equ
+cql
+fn
+boolean
+pre
+np
+rnn
+rew
+pre
+perceptron
+bsz
+dataset
+mujoco
+jit
+nstep
+preprocess
+repo
+ReLU
+namespace
+th
+utils
+NaN
+linesearch
+hyperparameters
+pseudocode
+entropies
+nn
+config
+cpu
+rms
+debias
+indice
+regularizer
+miniblock
+modularize
+serializable
+softmax
+vectorized
+optimizers
+undiscounted
+submodule
+subclasses
+submodules
+tfevent
+dirichlet
+docstring
+webpage
+formatter
+num
+py
+pythonic
+中文文档位于
+conda
+miniconda
+Amir
+Andreas
+Antonoglou
+Beattie
+Bellemare
+Charles
+Daan
+Demis
+Dharshan
+Fidjeland
+Georg
+Hassabis
+Helen
+Ioannis
+Kavukcuoglu
+King
+Koray
+Kumaran
+Legg
+Mnih
+Ostrovski
+Petersen
+Riedmiller
+Rusu
+Sadik
+Shane
+Stig
+Veness
+Volodymyr
+Wierstra
+Lillicrap
+Pritzel
+Heess
+Erez
+Yuval
+Tassa
+Schulman
+Filip
+Wolski
+Prafulla
+Dhariwal
+Radford
+Oleg
+Klimov
+Kaichao
+Jiayi
+Weng
+Duburcq
+Huayu
+Strens
+Ornstein
+Uhlenbeck
diff --git a/docs/tutorials/batch.rst b/docs/tutorials/batch.rst
index 49d913d9f..71f82f84e 100644
--- a/docs/tutorials/batch.rst
+++ b/docs/tutorials/batch.rst
@@ -60,7 +60,7 @@ The content of ``Batch`` objects can be defined by the following rules.
2. The keys are always strings (they are names of corresponding values).
-3. The values can be scalars, tensors, or Batch objects. The recurse definition makes it possible to form a hierarchy of batches.
+3. The values can be scalars, tensors, or Batch objects. The recursive definition makes it possible to form a hierarchy of batches.
4. Tensors are the most important values. In short, tensors are n-dimensional arrays of the same data type. We support two types of tensors: `PyTorch `_ tensor type ``torch.Tensor`` and `NumPy `_ tensor type ``np.ndarray``.
@@ -348,7 +348,7 @@ The introduction of reserved keys gives rise to the need to check if a key is re
-The ``Batch.is_empty`` function has an option to decide whether to identify direct emptiness (just a ``Batch()``) or to identify recurse emptiness (a ``Batch`` object without any scalar/tensor leaf nodes).
+The ``Batch.is_empty`` function has an option to decide whether to identify direct emptiness (just a ``Batch()``) or to identify recursive emptiness (a ``Batch`` object without any scalar/tensor leaf nodes).
.. note::
@@ -492,7 +492,7 @@ Miscellaneous Notes
-2. It is often the case that the observations returned from the environment are NumPy ndarrays but the policy requires ``torch.Tensor`` for prediction and learning. In this regard, Tianshou provides helper functions to convert the stored data in-place into Numpy arrays or Torch tensors.
+2. It is often the case that the observations returned from the environment are all NumPy ndarrays but the policy requires ``torch.Tensor`` for prediction and learning. In this regard, Tianshou provides helper functions to convert the stored data in-place into Numpy arrays or Torch tensors.
3. ``obj.stack_([a, b])`` is the same as ``Batch.stack([obj, a, b])``, and ``obj.cat_([a, b])`` is the same as ``Batch.cat([obj, a, b])``. Considering the frequent requirement of concatenating two ``Batch`` objects, Tianshou also supports ``obj.cat_(a)`` to be an alias of ``obj.cat_([a])``.
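To make the notes above concrete, here is a small sketch using the ``tianshou.data.Batch`` methods named in this tutorial (``is_empty``, ``to_torch``, ``cat_``); treat it as an illustration rather than part of the diff:

    import numpy as np
    import torch
    from tianshou.data import Batch

    # recursive emptiness: a Batch holding only empty Batches has no tensor leaves
    b = Batch(a=Batch(), b=Batch())
    print(b.is_empty())              # False: it still has keys
    print(b.is_empty(recurse=True))  # True: no scalar/tensor leaf anywhere

    # in-place conversion from NumPy ndarrays to torch tensors
    data = Batch(obs=np.zeros((2, 4)), act=np.array([0, 1]))
    data.to_torch(dtype=torch.float32)
    assert isinstance(data.obs, torch.Tensor)

    # obj.cat_(a) is an alias of obj.cat_([a]) / Batch.cat([obj, a])
    other = Batch(obs=torch.ones((3, 4)), act=torch.tensor([1., 0., 1.]))
    data.cat_(other)                 # data.obs now has shape (5, 4)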
diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst
index 38a0291f4..c224b193b 100644
--- a/docs/tutorials/cheatsheet.rst
+++ b/docs/tutorials/cheatsheet.rst
@@ -341,7 +341,7 @@ With the flexible core APIs, Tianshou can support multi-agent reinforcement lear
Currently, we support three types of multi-agent reinforcement learning paradigms:
-1. Simultaneous move: at each timestep, all the agents take their actions (example: moba games)
+1. Simultaneous move: at each timestep, all the agents take their actions (example: MOBA games)
2. Cyclic move: players take action in turn (example: Go game)
@@ -371,4 +371,4 @@ By constructing a new state ``state_ = (state, agent_id, mask)``, essentially we
action = policy(state_)
next_state_, reward = env.step(action)
-Following this idea, we write a tiny example of playing `Tic Tac Toe `_ against a random player by using a Q-lerning algorithm. The tutorial is at :doc:`/tutorials/tictactoe`.
+Following this idea, we write a tiny example of playing `Tic Tac Toe `_ against a random player by using a Q-learning algorithm. The tutorial is at :doc:`/tutorials/tictactoe`.
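As a hedged illustration of the augmented state (the exact dict keys follow the Tic-Tac-Toe tutorial and are an assumption here): the state carries the raw observation, the acting agent, and the legal-action mask, so a single-agent method such as Q-learning only needs to score the legal moves:

    import numpy as np

    obs = {
        "obs": np.zeros((3, 3), dtype=np.int8),  # raw board observation
        "agent_id": 1,                           # which player is expected to move
        "mask": np.ones(9, dtype=bool),          # True where a move is legal
    }
    q_values = np.random.rand(9)                 # stand-in for a learned Q(s, a)
    masked_q = np.where(obs["mask"], q_values, -np.inf)
    best_action = int(np.argmax(masked_q))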
diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst
index 7222d9116..b1f76deb7 100644
--- a/docs/tutorials/concepts.rst
+++ b/docs/tutorials/concepts.rst
@@ -219,7 +219,7 @@ Tianshou provides other type of data buffer such as :class:`~tianshou.data.Prior
Policy
------
-Tianshou aims to modularizing RL algorithms. It comes into several classes of policies in Tianshou. All of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`.
+Tianshou aims to modularize RL algorithms. It provides several classes of policies. All of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`.
A policy class typically has the following parts:
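Since every policy must inherit :class:`~tianshou.policy.BasePolicy`, a minimal sketch may help; it assumes the 0.4-era interface in which ``forward`` maps a batch of observations to actions and ``learn`` consumes a sampled batch (the exact signatures are an assumption, not part of this diff):

    import numpy as np
    from tianshou.data import Batch
    from tianshou.policy import BasePolicy

    class RandomDiscretePolicy(BasePolicy):
        """Uniformly random discrete actions; learn() has nothing to optimize."""

        def __init__(self, action_num: int, **kwargs) -> None:
            super().__init__(**kwargs)
            self.action_num = action_num

        def forward(self, batch: Batch, state=None, **kwargs) -> Batch:
            # one action per observation in the incoming batch
            act = np.random.randint(self.action_num, size=len(batch.obs))
            return Batch(act=act, state=state)

        def learn(self, batch: Batch, **kwargs) -> dict:
            # a random policy learns nothing, so there are no losses to report
            return {}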
diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py
index b4cd5b62d..1be441013 100644
--- a/examples/atari/atari_bcq.py
+++ b/examples/atari/atari_bcq.py
@@ -1,21 +1,21 @@
+import argparse
+import datetime
import os
-import torch
import pickle
import pprint
-import datetime
-import argparse
+
import numpy as np
+import torch
+from atari_network import DQN
+from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
-from tianshou.utils import TensorboardLogger
+from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
+from tianshou.policy import DiscreteBCQPolicy
from tianshou.trainer import offline_trainer
+from tianshou.utils import TensorboardLogger
from tianshou.utils.net.discrete import Actor
-from tianshou.policy import DiscreteBCQPolicy
-from tianshou.data import Collector, VectorReplayBuffer
-
-from atari_network import DQN
-from atari_wrapper import wrap_deepmind
def get_args():
@@ -38,15 +38,19 @@ def get_args():
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument("--resume-path", type=str, default=None)
- parser.add_argument("--watch", default=False, action="store_true",
- help="watch the play of pre-trained policy only")
+ parser.add_argument(
+ "--watch",
+ default=False,
+ action="store_true",
+ help="watch the play of pre-trained policy only"
+ )
parser.add_argument("--log-interval", type=int, default=100)
parser.add_argument(
- "--load-buffer-name", type=str,
- default="./expert_DQN_PongNoFrameskip-v4.hdf5")
+ "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5"
+ )
parser.add_argument(
- "--device", type=str,
- default="cuda" if torch.cuda.is_available() else "cpu")
+ "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
+ )
args = parser.parse_known_args()[0]
return args
@@ -56,8 +60,12 @@ def make_atari_env(args):
def make_atari_env_watch(args):
- return wrap_deepmind(args.task, frame_stack=args.frames_stack,
- episode_life=False, clip_rewards=False)
+ return wrap_deepmind(
+ args.task,
+ frame_stack=args.frames_stack,
+ episode_life=False,
+ clip_rewards=False
+ )
def test_discrete_bcq(args=get_args()):
@@ -69,32 +77,43 @@ def test_discrete_bcq(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
- test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
- for _ in range(args.test_num)])
+ test_envs = SubprocVectorEnv(
+ [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
- feature_net = DQN(*args.state_shape, args.action_shape,
- device=args.device, features_only=True).to(args.device)
+ feature_net = DQN(
+ *args.state_shape, args.action_shape, device=args.device, features_only=True
+ ).to(args.device)
policy_net = Actor(
- feature_net, args.action_shape, device=args.device,
- hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device)
+ feature_net,
+ args.action_shape,
+ device=args.device,
+ hidden_sizes=args.hidden_sizes,
+ softmax_output=False
+ ).to(args.device)
imitation_net = Actor(
- feature_net, args.action_shape, device=args.device,
- hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device)
+ feature_net,
+ args.action_shape,
+ device=args.device,
+ hidden_sizes=args.hidden_sizes,
+ softmax_output=False
+ ).to(args.device)
optim = torch.optim.Adam(
- list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr)
+ list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr
+ )
# define policy
policy = DiscreteBCQPolicy(
policy_net, imitation_net, optim, args.gamma, args.n_step,
- args.target_update_freq, args.eps_test,
- args.unlikely_action_threshold, args.imitation_logits_penalty)
+ args.target_update_freq, args.eps_test, args.unlikely_action_threshold,
+ args.imitation_logits_penalty
+ )
# load a previous policy
if args.resume_path:
- policy.load_state_dict(torch.load(
- args.resume_path, map_location=args.device))
+ policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# buffer
assert os.path.exists(args.load_buffer_name), \
@@ -113,7 +132,8 @@ def test_discrete_bcq(args=get_args()):
# log
log_path = os.path.join(
args.logdir, args.task, 'bcq',
- f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
+ f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}'
+ )
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer, update_interval=args.log_interval)
@@ -132,8 +152,7 @@ def watch():
test_envs.seed(args.seed)
print("Testing agent ...")
test_collector.reset()
- result = test_collector.collect(n_episode=args.test_num,
- render=args.render)
+ result = test_collector.collect(n_episode=args.test_num, render=args.render)
pprint.pprint(result)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
@@ -143,9 +162,17 @@ def watch():
exit(0)
result = offline_trainer(
- policy, buffer, test_collector, args.epoch,
- args.update_per_epoch, args.test_num, args.batch_size,
- stop_fn=stop_fn, save_fn=save_fn, logger=logger)
+ policy,
+ buffer,
+ test_collector,
+ args.epoch,
+ args.update_per_epoch,
+ args.test_num,
+ args.batch_size,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger
+ )
pprint.pprint(result)
watch()
diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py
index b70cc6cc0..291fb7007 100644
--- a/examples/atari/atari_c51.py
+++ b/examples/atari/atari_c51.py
@@ -1,18 +1,18 @@
+import argparse
import os
-import torch
import pprint
-import argparse
+
import numpy as np
+import torch
+from atari_network import C51
+from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
-from tianshou.policy import C51Policy
-from tianshou.utils import TensorboardLogger
+from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
+from tianshou.policy import C51Policy
from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
-
-from atari_network import C51
-from atari_wrapper import wrap_deepmind
+from tianshou.utils import TensorboardLogger
def get_args():
@@ -40,12 +40,16 @@ def get_args():
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
- parser.add_argument('--watch', default=False, action='store_true',
- help='watch the play of pre-trained policy only')
+ parser.add_argument(
+ '--watch',
+ default=False,
+ action='store_true',
+ help='watch the play of pre-trained policy only'
+ )
parser.add_argument('--save-buffer-name', type=str, default=None)
return parser.parse_args()
@@ -55,8 +59,12 @@ def make_atari_env(args):
def make_atari_env_watch(args):
- return wrap_deepmind(args.task, frame_stack=args.frames_stack,
- episode_life=False, clip_rewards=False)
+ return wrap_deepmind(
+ args.task,
+ frame_stack=args.frames_stack,
+ episode_life=False,
+ clip_rewards=False
+ )
def test_c51(args=get_args()):
@@ -67,23 +75,30 @@ def test_c51(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
- train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
- for _ in range(args.training_num)])
- test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
- for _ in range(args.test_num)])
+ train_envs = SubprocVectorEnv(
+ [lambda: make_atari_env(args) for _ in range(args.training_num)]
+ )
+ test_envs = SubprocVectorEnv(
+ [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
- net = C51(*args.state_shape, args.action_shape,
- args.num_atoms, args.device)
+ net = C51(*args.state_shape, args.action_shape, args.num_atoms, args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy = C51Policy(
- net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
- args.n_step, target_update_freq=args.target_update_freq
+ net,
+ optim,
+ args.gamma,
+ args.num_atoms,
+ args.v_min,
+ args.v_max,
+ args.n_step,
+ target_update_freq=args.target_update_freq
).to(args.device)
# load a previous policy
if args.resume_path:
@@ -92,8 +107,12 @@ def test_c51(args=get_args()):
# replay buffer: `save_last_obs` and `stack_num` can be removed together
# when you have enough RAM
buffer = VectorReplayBuffer(
- args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True,
- save_only_last_obs=True, stack_num=args.frames_stack)
+ args.buffer_size,
+ buffer_num=len(train_envs),
+ ignore_obs_next=True,
+ save_only_last_obs=True,
+ stack_num=args.frames_stack
+ )
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
@@ -136,11 +155,13 @@ def watch():
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
buffer = VectorReplayBuffer(
- args.buffer_size, buffer_num=len(test_envs),
- ignore_obs_next=True, save_only_last_obs=True,
- stack_num=args.frames_stack)
- collector = Collector(policy, test_envs, buffer,
- exploration_noise=True)
+ args.buffer_size,
+ buffer_num=len(test_envs),
+ ignore_obs_next=True,
+ save_only_last_obs=True,
+ stack_num=args.frames_stack
+ )
+ collector = Collector(policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
@@ -148,8 +169,9 @@ def watch():
else:
print("Testing agent ...")
test_collector.reset()
- result = test_collector.collect(n_episode=args.test_num,
- render=args.render)
+ result = test_collector.collect(
+ n_episode=args.test_num, render=args.render
+ )
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
@@ -161,11 +183,22 @@ def watch():
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, train_fn=train_fn, test_fn=test_fn,
- stop_fn=stop_fn, save_fn=save_fn, logger=logger,
- update_per_step=args.update_per_step, test_in_train=False)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ train_fn=train_fn,
+ test_fn=test_fn,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger,
+ update_per_step=args.update_per_step,
+ test_in_train=False
+ )
pprint.pprint(result)
watch()
diff --git a/examples/atari/atari_cql.py b/examples/atari/atari_cql.py
index cbab82029..db4e33a9a 100644
--- a/examples/atari/atari_cql.py
+++ b/examples/atari/atari_cql.py
@@ -1,20 +1,20 @@
+import argparse
+import datetime
import os
-import torch
import pickle
import pprint
-import datetime
-import argparse
+
import numpy as np
+import torch
+from atari_network import QRDQN
+from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
-from tianshou.utils import TensorboardLogger
+from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
-from tianshou.trainer import offline_trainer
from tianshou.policy import DiscreteCQLPolicy
-from tianshou.data import Collector, VectorReplayBuffer
-
-from atari_network import QRDQN
-from atari_wrapper import wrap_deepmind
+from tianshou.trainer import offline_trainer
+from tianshou.utils import TensorboardLogger
def get_args():
@@ -37,15 +37,19 @@ def get_args():
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument("--resume-path", type=str, default=None)
- parser.add_argument("--watch", default=False, action="store_true",
- help="watch the play of pre-trained policy only")
+ parser.add_argument(
+ "--watch",
+ default=False,
+ action="store_true",
+ help="watch the play of pre-trained policy only"
+ )
parser.add_argument("--log-interval", type=int, default=100)
parser.add_argument(
- "--load-buffer-name", type=str,
- default="./expert_DQN_PongNoFrameskip-v4.hdf5")
+ "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5"
+ )
parser.add_argument(
- "--device", type=str,
- default="cuda" if torch.cuda.is_available() else "cpu")
+ "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
+ )
args = parser.parse_known_args()[0]
return args
@@ -55,8 +59,12 @@ def make_atari_env(args):
def make_atari_env_watch(args):
- return wrap_deepmind(args.task, frame_stack=args.frames_stack,
- episode_life=False, clip_rewards=False)
+ return wrap_deepmind(
+ args.task,
+ frame_stack=args.frames_stack,
+ episode_life=False,
+ clip_rewards=False
+ )
def test_discrete_cql(args=get_args()):
@@ -68,25 +76,29 @@ def test_discrete_cql(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
- test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
- for _ in range(args.test_num)])
+ test_envs = SubprocVectorEnv(
+ [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
- net = QRDQN(*args.state_shape, args.action_shape,
- args.num_quantiles, args.device)
+ net = QRDQN(*args.state_shape, args.action_shape, args.num_quantiles, args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy = DiscreteCQLPolicy(
- net, optim, args.gamma, args.num_quantiles, args.n_step,
- args.target_update_freq, min_q_weight=args.min_q_weight
+ net,
+ optim,
+ args.gamma,
+ args.num_quantiles,
+ args.n_step,
+ args.target_update_freq,
+ min_q_weight=args.min_q_weight
).to(args.device)
# load a previous policy
if args.resume_path:
- policy.load_state_dict(torch.load(
- args.resume_path, map_location=args.device))
+ policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# buffer
assert os.path.exists(args.load_buffer_name), \
@@ -105,7 +117,8 @@ def test_discrete_cql(args=get_args()):
# log
log_path = os.path.join(
args.logdir, args.task, 'cql',
- f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
+ f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}'
+ )
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer, update_interval=args.log_interval)
@@ -124,8 +137,7 @@ def watch():
test_envs.seed(args.seed)
print("Testing agent ...")
test_collector.reset()
- result = test_collector.collect(n_episode=args.test_num,
- render=args.render)
+ result = test_collector.collect(n_episode=args.test_num, render=args.render)
pprint.pprint(result)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
@@ -135,9 +147,17 @@ def watch():
exit(0)
result = offline_trainer(
- policy, buffer, test_collector, args.epoch,
- args.update_per_epoch, args.test_num, args.batch_size,
- stop_fn=stop_fn, save_fn=save_fn, logger=logger)
+ policy,
+ buffer,
+ test_collector,
+ args.epoch,
+ args.update_per_epoch,
+ args.test_num,
+ args.batch_size,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger
+ )
pprint.pprint(result)
watch()
diff --git a/examples/atari/atari_crr.py b/examples/atari/atari_crr.py
index e8e1ba54e..06cde415b 100644
--- a/examples/atari/atari_crr.py
+++ b/examples/atari/atari_crr.py
@@ -1,21 +1,21 @@
+import argparse
+import datetime
import os
-import torch
import pickle
import pprint
-import datetime
-import argparse
+
import numpy as np
+import torch
+from atari_network import DQN
+from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
-from tianshou.utils import TensorboardLogger
+from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
+from tianshou.policy import DiscreteCRRPolicy
from tianshou.trainer import offline_trainer
+from tianshou.utils import TensorboardLogger
from tianshou.utils.net.discrete import Actor
-from tianshou.policy import DiscreteCRRPolicy
-from tianshou.data import Collector, VectorReplayBuffer
-
-from atari_network import DQN
-from atari_wrapper import wrap_deepmind
def get_args():
@@ -38,15 +38,19 @@ def get_args():
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument("--resume-path", type=str, default=None)
- parser.add_argument("--watch", default=False, action="store_true",
- help="watch the play of pre-trained policy only")
+ parser.add_argument(
+ "--watch",
+ default=False,
+ action="store_true",
+ help="watch the play of pre-trained policy only"
+ )
parser.add_argument("--log-interval", type=int, default=100)
parser.add_argument(
- "--load-buffer-name", type=str,
- default="./expert_DQN_PongNoFrameskip-v4.hdf5")
+ "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5"
+ )
parser.add_argument(
- "--device", type=str,
- default="cuda" if torch.cuda.is_available() else "cpu")
+ "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
+ )
args = parser.parse_known_args()[0]
return args
@@ -56,8 +60,12 @@ def make_atari_env(args):
def make_atari_env_watch(args):
- return wrap_deepmind(args.task, frame_stack=args.frames_stack,
- episode_life=False, clip_rewards=False)
+ return wrap_deepmind(
+ args.task,
+ frame_stack=args.frames_stack,
+ episode_life=False,
+ clip_rewards=False
+ )
def test_discrete_crr(args=get_args()):
@@ -69,33 +77,44 @@ def test_discrete_crr(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
- test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
- for _ in range(args.test_num)])
+ test_envs = SubprocVectorEnv(
+ [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
- feature_net = DQN(*args.state_shape, args.action_shape,
- device=args.device, features_only=True).to(args.device)
- actor = Actor(feature_net, args.action_shape, device=args.device,
- hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device)
+ feature_net = DQN(
+ *args.state_shape, args.action_shape, device=args.device, features_only=True
+ ).to(args.device)
+ actor = Actor(
+ feature_net,
+ args.action_shape,
+ device=args.device,
+ hidden_sizes=args.hidden_sizes,
+ softmax_output=False
+ ).to(args.device)
critic = DQN(*args.state_shape, args.action_shape,
device=args.device).to(args.device)
- optim = torch.optim.Adam(list(actor.parameters()) + list(critic.parameters()),
- lr=args.lr)
+ optim = torch.optim.Adam(
+ list(actor.parameters()) + list(critic.parameters()), lr=args.lr
+ )
# define policy
policy = DiscreteCRRPolicy(
- actor, critic, optim, args.gamma,
+ actor,
+ critic,
+ optim,
+ args.gamma,
policy_improvement_mode=args.policy_improvement_mode,
- ratio_upper_bound=args.ratio_upper_bound, beta=args.beta,
+ ratio_upper_bound=args.ratio_upper_bound,
+ beta=args.beta,
min_q_weight=args.min_q_weight,
target_update_freq=args.target_update_freq
).to(args.device)
# load a previous policy
if args.resume_path:
- policy.load_state_dict(torch.load(
- args.resume_path, map_location=args.device))
+ policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# buffer
assert os.path.exists(args.load_buffer_name), \
@@ -114,7 +133,8 @@ def test_discrete_crr(args=get_args()):
# log
log_path = os.path.join(
args.logdir, args.task, 'crr',
- f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
+ f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}'
+ )
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer, update_interval=args.log_interval)
@@ -132,8 +152,7 @@ def watch():
test_envs.seed(args.seed)
print("Testing agent ...")
test_collector.reset()
- result = test_collector.collect(n_episode=args.test_num,
- render=args.render)
+ result = test_collector.collect(n_episode=args.test_num, render=args.render)
pprint.pprint(result)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
@@ -143,9 +162,17 @@ def watch():
exit(0)
result = offline_trainer(
- policy, buffer, test_collector, args.epoch,
- args.update_per_epoch, args.test_num, args.batch_size,
- stop_fn=stop_fn, save_fn=save_fn, logger=logger)
+ policy,
+ buffer,
+ test_collector,
+ args.epoch,
+ args.update_per_epoch,
+ args.test_num,
+ args.batch_size,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger
+ )
pprint.pprint(result)
watch()
diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py
index 69ec08349..c9f74af8c 100644
--- a/examples/atari/atari_dqn.py
+++ b/examples/atari/atari_dqn.py
@@ -1,18 +1,18 @@
+import argparse
import os
-import torch
import pprint
-import argparse
+
import numpy as np
+import torch
+from atari_network import DQN
+from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
-from tianshou.policy import DQNPolicy
-from tianshou.utils import TensorboardLogger
+from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
+from tianshou.policy import DQNPolicy
from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
-
-from atari_network import DQN
-from atari_wrapper import wrap_deepmind
+from tianshou.utils import TensorboardLogger
def get_args():
@@ -37,12 +37,16 @@ def get_args():
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
- parser.add_argument('--watch', default=False, action='store_true',
- help='watch the play of pre-trained policy only')
+ parser.add_argument(
+ '--watch',
+ default=False,
+ action='store_true',
+ help='watch the play of pre-trained policy only'
+ )
parser.add_argument('--save-buffer-name', type=str, default=None)
return parser.parse_args()
@@ -52,8 +56,12 @@ def make_atari_env(args):
def make_atari_env_watch(args):
- return wrap_deepmind(args.task, frame_stack=args.frames_stack,
- episode_life=False, clip_rewards=False)
+ return wrap_deepmind(
+ args.task,
+ frame_stack=args.frames_stack,
+ episode_life=False,
+ clip_rewards=False
+ )
def test_dqn(args=get_args()):
@@ -64,22 +72,28 @@ def test_dqn(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
- train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
- for _ in range(args.training_num)])
- test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
- for _ in range(args.test_num)])
+ train_envs = SubprocVectorEnv(
+ [lambda: make_atari_env(args) for _ in range(args.training_num)]
+ )
+ test_envs = SubprocVectorEnv(
+ [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
- net = DQN(*args.state_shape,
- args.action_shape, args.device).to(args.device)
+ net = DQN(*args.state_shape, args.action_shape, args.device).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
- policy = DQNPolicy(net, optim, args.gamma, args.n_step,
- target_update_freq=args.target_update_freq)
+ policy = DQNPolicy(
+ net,
+ optim,
+ args.gamma,
+ args.n_step,
+ target_update_freq=args.target_update_freq
+ )
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
@@ -87,8 +101,12 @@ def test_dqn(args=get_args()):
# replay buffer: `save_last_obs` and `stack_num` can be removed together
# when you have enough RAM
buffer = VectorReplayBuffer(
- args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True,
- save_only_last_obs=True, stack_num=args.frames_stack)
+ args.buffer_size,
+ buffer_num=len(train_envs),
+ ignore_obs_next=True,
+ save_only_last_obs=True,
+ stack_num=args.frames_stack
+ )
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
@@ -131,11 +149,13 @@ def watch():
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
buffer = VectorReplayBuffer(
- args.buffer_size, buffer_num=len(test_envs),
- ignore_obs_next=True, save_only_last_obs=True,
- stack_num=args.frames_stack)
- collector = Collector(policy, test_envs, buffer,
- exploration_noise=True)
+ args.buffer_size,
+ buffer_num=len(test_envs),
+ ignore_obs_next=True,
+ save_only_last_obs=True,
+ stack_num=args.frames_stack
+ )
+ collector = Collector(policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
@@ -143,8 +163,9 @@ def watch():
else:
print("Testing agent ...")
test_collector.reset()
- result = test_collector.collect(n_episode=args.test_num,
- render=args.render)
+ result = test_collector.collect(
+ n_episode=args.test_num, render=args.render
+ )
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
@@ -156,11 +177,22 @@ def watch():
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, train_fn=train_fn, test_fn=test_fn,
- stop_fn=stop_fn, save_fn=save_fn, logger=logger,
- update_per_step=args.update_per_step, test_in_train=False)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ train_fn=train_fn,
+ test_fn=test_fn,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger,
+ update_per_step=args.update_per_step,
+ test_in_train=False
+ )
pprint.pprint(result)
watch()
diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py
index 4a6e97c06..4629bede2 100644
--- a/examples/atari/atari_fqf.py
+++ b/examples/atari/atari_fqf.py
@@ -1,20 +1,20 @@
+import argparse
import os
-import torch
import pprint
-import argparse
+
import numpy as np
+import torch
+from atari_network import DQN
+from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
-from tianshou.policy import FQFPolicy
-from tianshou.utils import TensorboardLogger
+from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
+from tianshou.policy import FQFPolicy
from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.utils import TensorboardLogger
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction
-from atari_network import DQN
-from atari_wrapper import wrap_deepmind
-
def get_args():
parser = argparse.ArgumentParser()
@@ -43,12 +43,16 @@ def get_args():
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
- parser.add_argument('--watch', default=False, action='store_true',
- help='watch the play of pre-trained policy only')
+ parser.add_argument(
+ '--watch',
+ default=False,
+ action='store_true',
+ help='watch the play of pre-trained policy only'
+ )
parser.add_argument('--save-buffer-name', type=str, default=None)
return parser.parse_args()
@@ -58,8 +62,12 @@ def make_atari_env(args):
def make_atari_env_watch(args):
- return wrap_deepmind(args.task, frame_stack=args.frames_stack,
- episode_life=False, clip_rewards=False)
+ return wrap_deepmind(
+ args.task,
+ frame_stack=args.frames_stack,
+ episode_life=False,
+ clip_rewards=False
+ )
def test_fqf(args=get_args()):
@@ -70,30 +78,43 @@ def test_fqf(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
- train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
- for _ in range(args.training_num)])
- test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
- for _ in range(args.test_num)])
+ train_envs = SubprocVectorEnv(
+ [lambda: make_atari_env(args) for _ in range(args.training_num)]
+ )
+ test_envs = SubprocVectorEnv(
+ [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
- feature_net = DQN(*args.state_shape, args.action_shape, args.device,
- features_only=True)
+ feature_net = DQN(
+ *args.state_shape, args.action_shape, args.device, features_only=True
+ )
net = FullQuantileFunction(
- feature_net, args.action_shape, args.hidden_sizes,
- args.num_cosines, device=args.device
+ feature_net,
+ args.action_shape,
+ args.hidden_sizes,
+ args.num_cosines,
+ device=args.device
).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim)
- fraction_optim = torch.optim.RMSprop(fraction_net.parameters(),
- lr=args.fraction_lr)
+ fraction_optim = torch.optim.RMSprop(
+ fraction_net.parameters(), lr=args.fraction_lr
+ )
# define policy
policy = FQFPolicy(
- net, optim, fraction_net, fraction_optim,
- args.gamma, args.num_fractions, args.ent_coef, args.n_step,
+ net,
+ optim,
+ fraction_net,
+ fraction_optim,
+ args.gamma,
+ args.num_fractions,
+ args.ent_coef,
+ args.n_step,
target_update_freq=args.target_update_freq
).to(args.device)
# load a previous policy
@@ -103,8 +124,12 @@ def test_fqf(args=get_args()):
# replay buffer: `save_last_obs` and `stack_num` can be removed together
# when you have enough RAM
buffer = VectorReplayBuffer(
- args.buffer_size, buffer_num=len(train_envs),
- ignore_obs_next=True, save_only_last_obs=True, stack_num=args.frames_stack)
+ args.buffer_size,
+ buffer_num=len(train_envs),
+ ignore_obs_next=True,
+ save_only_last_obs=True,
+ stack_num=args.frames_stack
+ )
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
@@ -147,11 +172,13 @@ def watch():
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
buffer = VectorReplayBuffer(
- args.buffer_size, buffer_num=len(test_envs),
- ignore_obs_next=True, save_only_last_obs=True,
- stack_num=args.frames_stack)
- collector = Collector(policy, test_envs, buffer,
- exploration_noise=True)
+ args.buffer_size,
+ buffer_num=len(test_envs),
+ ignore_obs_next=True,
+ save_only_last_obs=True,
+ stack_num=args.frames_stack
+ )
+ collector = Collector(policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
@@ -159,8 +186,9 @@ def watch():
else:
print("Testing agent ...")
test_collector.reset()
- result = test_collector.collect(n_episode=args.test_num,
- render=args.render)
+ result = test_collector.collect(
+ n_episode=args.test_num, render=args.render
+ )
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
@@ -172,11 +200,22 @@ def watch():
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, train_fn=train_fn, test_fn=test_fn,
- stop_fn=stop_fn, save_fn=save_fn, logger=logger,
- update_per_step=args.update_per_step, test_in_train=False)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ train_fn=train_fn,
+ test_fn=test_fn,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger,
+ update_per_step=args.update_per_step,
+ test_in_train=False
+ )
pprint.pprint(result)
watch()
diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py
index e5966a318..d0e7773d0 100644
--- a/examples/atari/atari_iqn.py
+++ b/examples/atari/atari_iqn.py
@@ -1,20 +1,20 @@
+import argparse
import os
-import torch
import pprint
-import argparse
+
import numpy as np
+import torch
+from atari_network import DQN
+from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
-from tianshou.policy import IQNPolicy
-from tianshou.utils import TensorboardLogger
+from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
+from tianshou.policy import IQNPolicy
from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.utils import TensorboardLogger
from tianshou.utils.net.discrete import ImplicitQuantileNetwork
-from atari_network import DQN
-from atari_wrapper import wrap_deepmind
-
def get_args():
parser = argparse.ArgumentParser()
@@ -43,12 +43,16 @@ def get_args():
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
- parser.add_argument('--watch', default=False, action='store_true',
- help='watch the play of pre-trained policy only')
+ parser.add_argument(
+ '--watch',
+ default=False,
+ action='store_true',
+ help='watch the play of pre-trained policy only'
+ )
parser.add_argument('--save-buffer-name', type=str, default=None)
return parser.parse_args()
@@ -58,8 +62,12 @@ def make_atari_env(args):
def make_atari_env_watch(args):
- return wrap_deepmind(args.task, frame_stack=args.frames_stack,
- episode_life=False, clip_rewards=False)
+ return wrap_deepmind(
+ args.task,
+ frame_stack=args.frames_stack,
+ episode_life=False,
+ clip_rewards=False
+ )
def test_iqn(args=get_args()):
@@ -70,27 +78,38 @@ def test_iqn(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
- train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
- for _ in range(args.training_num)])
- test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
- for _ in range(args.test_num)])
+ train_envs = SubprocVectorEnv(
+ [lambda: make_atari_env(args) for _ in range(args.training_num)]
+ )
+ test_envs = SubprocVectorEnv(
+ [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
- feature_net = DQN(*args.state_shape, args.action_shape, args.device,
- features_only=True)
+ feature_net = DQN(
+ *args.state_shape, args.action_shape, args.device, features_only=True
+ )
net = ImplicitQuantileNetwork(
- feature_net, args.action_shape, args.hidden_sizes,
- num_cosines=args.num_cosines, device=args.device
+ feature_net,
+ args.action_shape,
+ args.hidden_sizes,
+ num_cosines=args.num_cosines,
+ device=args.device
).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy = IQNPolicy(
- net, optim, args.gamma, args.sample_size, args.online_sample_size,
- args.target_sample_size, args.n_step,
+ net,
+ optim,
+ args.gamma,
+ args.sample_size,
+ args.online_sample_size,
+ args.target_sample_size,
+ args.n_step,
target_update_freq=args.target_update_freq
).to(args.device)
# load a previous policy
@@ -100,8 +119,12 @@ def test_iqn(args=get_args()):
# replay buffer: `save_last_obs` and `stack_num` can be removed together
# when you have enough RAM
buffer = VectorReplayBuffer(
- args.buffer_size, buffer_num=len(train_envs),
- ignore_obs_next=True, save_only_last_obs=True, stack_num=args.frames_stack)
+ args.buffer_size,
+ buffer_num=len(train_envs),
+ ignore_obs_next=True,
+ save_only_last_obs=True,
+ stack_num=args.frames_stack
+ )
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
@@ -144,11 +167,13 @@ def watch():
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
buffer = VectorReplayBuffer(
- args.buffer_size, buffer_num=len(test_envs),
- ignore_obs_next=True, save_only_last_obs=True,
- stack_num=args.frames_stack)
- collector = Collector(policy, test_envs, buffer,
- exploration_noise=True)
+ args.buffer_size,
+ buffer_num=len(test_envs),
+ ignore_obs_next=True,
+ save_only_last_obs=True,
+ stack_num=args.frames_stack
+ )
+ collector = Collector(policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
@@ -156,8 +181,9 @@ def watch():
else:
print("Testing agent ...")
test_collector.reset()
- result = test_collector.collect(n_episode=args.test_num,
- render=args.render)
+ result = test_collector.collect(
+ n_episode=args.test_num, render=args.render
+ )
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
@@ -169,11 +195,22 @@ def watch():
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, train_fn=train_fn, test_fn=test_fn,
- stop_fn=stop_fn, save_fn=save_fn, logger=logger,
- update_per_step=args.update_per_step, test_in_train=False)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ train_fn=train_fn,
+ test_fn=test_fn,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger,
+ update_per_step=args.update_per_step,
+ test_in_train=False
+ )
pprint.pprint(result)
watch()
diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py
index 2eccf11af..4598fce11 100644
--- a/examples/atari/atari_network.py
+++ b/examples/atari/atari_network.py
@@ -1,7 +1,9 @@
-import torch
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
+
import numpy as np
+import torch
from torch import nn
-from typing import Any, Dict, Tuple, Union, Optional, Sequence
+
from tianshou.utils.net.discrete import NoisyLinear
@@ -27,15 +29,15 @@ def __init__(
nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True),
nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True),
- nn.Flatten())
+ nn.Flatten()
+ )
with torch.no_grad():
- self.output_dim = np.prod(
- self.net(torch.zeros(1, c, h, w)).shape[1:])
+ self.output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])
if not features_only:
self.net = nn.Sequential(
- self.net,
- nn.Linear(self.output_dim, 512), nn.ReLU(inplace=True),
- nn.Linear(512, np.prod(action_shape)))
+ self.net, nn.Linear(self.output_dim, 512), nn.ReLU(inplace=True),
+ nn.Linear(512, np.prod(action_shape))
+ )
self.output_dim = np.prod(action_shape)
def forward(
@@ -113,12 +115,14 @@ def linear(x, y):
self.Q = nn.Sequential(
linear(self.output_dim, 512), nn.ReLU(inplace=True),
- linear(512, self.action_num * self.num_atoms))
+ linear(512, self.action_num * self.num_atoms)
+ )
self._is_dueling = is_dueling
if self._is_dueling:
self.V = nn.Sequential(
linear(self.output_dim, 512), nn.ReLU(inplace=True),
- linear(512, self.num_atoms))
+ linear(512, self.num_atoms)
+ )
self.output_dim = self.action_num * self.num_atoms
def forward(
diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py
index 781f81d5d..23a7966eb 100644
--- a/examples/atari/atari_qrdqn.py
+++ b/examples/atari/atari_qrdqn.py
@@ -1,18 +1,18 @@
+import argparse
import os
-import torch
import pprint
-import argparse
+
import numpy as np
+import torch
+from atari_network import QRDQN
+from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
-from tianshou.utils import TensorboardLogger
-from tianshou.policy import QRDQNPolicy
+from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
+from tianshou.policy import QRDQNPolicy
from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
-
-from atari_network import QRDQN
-from atari_wrapper import wrap_deepmind
+from tianshou.utils import TensorboardLogger
def get_args():
@@ -38,12 +38,16 @@ def get_args():
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
- parser.add_argument('--watch', default=False, action='store_true',
- help='watch the play of pre-trained policy only')
+ parser.add_argument(
+ '--watch',
+ default=False,
+ action='store_true',
+ help='watch the play of pre-trained policy only'
+ )
parser.add_argument('--save-buffer-name', type=str, default=None)
return parser.parse_args()
@@ -53,8 +57,12 @@ def make_atari_env(args):
def make_atari_env_watch(args):
- return wrap_deepmind(args.task, frame_stack=args.frames_stack,
- episode_life=False, clip_rewards=False)
+ return wrap_deepmind(
+ args.task,
+ frame_stack=args.frames_stack,
+ episode_life=False,
+ clip_rewards=False
+ )
def test_qrdqn(args=get_args()):
@@ -65,23 +73,28 @@ def test_qrdqn(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
- train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
- for _ in range(args.training_num)])
- test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
- for _ in range(args.test_num)])
+ train_envs = SubprocVectorEnv(
+ [lambda: make_atari_env(args) for _ in range(args.training_num)]
+ )
+ test_envs = SubprocVectorEnv(
+ [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
- net = QRDQN(*args.state_shape, args.action_shape,
- args.num_quantiles, args.device)
+ net = QRDQN(*args.state_shape, args.action_shape, args.num_quantiles, args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy = QRDQNPolicy(
- net, optim, args.gamma, args.num_quantiles,
- args.n_step, target_update_freq=args.target_update_freq
+ net,
+ optim,
+ args.gamma,
+ args.num_quantiles,
+ args.n_step,
+ target_update_freq=args.target_update_freq
).to(args.device)
# load a previous policy
if args.resume_path:
@@ -90,8 +103,12 @@ def test_qrdqn(args=get_args()):
# replay buffer: `save_last_obs` and `stack_num` can be removed together
# when you have enough RAM
buffer = VectorReplayBuffer(
- args.buffer_size, buffer_num=len(train_envs),
- ignore_obs_next=True, save_only_last_obs=True, stack_num=args.frames_stack)
+ args.buffer_size,
+ buffer_num=len(train_envs),
+ ignore_obs_next=True,
+ save_only_last_obs=True,
+ stack_num=args.frames_stack
+ )
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
@@ -134,11 +151,13 @@ def watch():
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
buffer = VectorReplayBuffer(
- args.buffer_size, buffer_num=len(test_envs),
- ignore_obs_next=True, save_only_last_obs=True,
- stack_num=args.frames_stack)
- collector = Collector(policy, test_envs, buffer,
- exploration_noise=True)
+ args.buffer_size,
+ buffer_num=len(test_envs),
+ ignore_obs_next=True,
+ save_only_last_obs=True,
+ stack_num=args.frames_stack
+ )
+ collector = Collector(policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
@@ -146,8 +165,9 @@ def watch():
else:
print("Testing agent ...")
test_collector.reset()
- result = test_collector.collect(n_episode=args.test_num,
- render=args.render)
+ result = test_collector.collect(
+ n_episode=args.test_num, render=args.render
+ )
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
@@ -159,11 +179,22 @@ def watch():
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, train_fn=train_fn, test_fn=test_fn,
- stop_fn=stop_fn, save_fn=save_fn, logger=logger,
- update_per_step=args.update_per_step, test_in_train=False)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ train_fn=train_fn,
+ test_fn=test_fn,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger,
+ update_per_step=args.update_per_step,
+ test_in_train=False
+ )
pprint.pprint(result)
watch()
diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py
index f2f44f0cd..b131cce5f 100644
--- a/examples/atari/atari_rainbow.py
+++ b/examples/atari/atari_rainbow.py
@@ -1,19 +1,19 @@
+import argparse
+import datetime
import os
-import torch
import pprint
-import datetime
-import argparse
+
import numpy as np
+import torch
+from atari_network import Rainbow
+from atari_wrapper import wrap_deepmind
from torch.utils.tensorboard import SummaryWriter
-from tianshou.policy import RainbowPolicy
-from tianshou.utils import TensorboardLogger
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
+from tianshou.policy import RainbowPolicy
from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
-
-from atari_network import Rainbow
-from atari_wrapper import wrap_deepmind
+from tianshou.utils import TensorboardLogger
def get_args():
@@ -50,12 +50,16 @@ def get_args():
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
- parser.add_argument('--watch', default=False, action='store_true',
- help='watch the play of pre-trained policy only')
+ parser.add_argument(
+ '--watch',
+ default=False,
+ action='store_true',
+ help='watch the play of pre-trained policy only'
+ )
parser.add_argument('--save-buffer-name', type=str, default=None)
return parser.parse_args()
@@ -65,8 +69,12 @@ def make_atari_env(args):
def make_atari_env_watch(args):
- return wrap_deepmind(args.task, frame_stack=args.frames_stack,
- episode_life=False, clip_rewards=False)
+ return wrap_deepmind(
+ args.task,
+ frame_stack=args.frames_stack,
+ episode_life=False,
+ clip_rewards=False
+ )
def test_rainbow(args=get_args()):
@@ -77,25 +85,38 @@ def test_rainbow(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
- train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
- for _ in range(args.training_num)])
- test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
- for _ in range(args.test_num)])
+ train_envs = SubprocVectorEnv(
+ [lambda: make_atari_env(args) for _ in range(args.training_num)]
+ )
+ test_envs = SubprocVectorEnv(
+ [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
- net = Rainbow(*args.state_shape, args.action_shape,
- args.num_atoms, args.noisy_std, args.device,
- is_dueling=not args.no_dueling,
- is_noisy=not args.no_noisy)
+ net = Rainbow(
+ *args.state_shape,
+ args.action_shape,
+ args.num_atoms,
+ args.noisy_std,
+ args.device,
+ is_dueling=not args.no_dueling,
+ is_noisy=not args.no_noisy
+ )
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy = RainbowPolicy(
- net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
- args.n_step, target_update_freq=args.target_update_freq
+ net,
+ optim,
+ args.gamma,
+ args.num_atoms,
+ args.v_min,
+ args.v_max,
+ args.n_step,
+ target_update_freq=args.target_update_freq
).to(args.device)
# load a previous policy
if args.resume_path:
@@ -105,20 +126,31 @@ def test_rainbow(args=get_args()):
# when you have enough RAM
if args.no_priority:
buffer = VectorReplayBuffer(
- args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True,
- save_only_last_obs=True, stack_num=args.frames_stack)
+ args.buffer_size,
+ buffer_num=len(train_envs),
+ ignore_obs_next=True,
+ save_only_last_obs=True,
+ stack_num=args.frames_stack
+ )
else:
buffer = PrioritizedVectorReplayBuffer(
- args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True,
- save_only_last_obs=True, stack_num=args.frames_stack, alpha=args.alpha,
- beta=args.beta, weight_norm=not args.no_weight_norm)
+ args.buffer_size,
+ buffer_num=len(train_envs),
+ ignore_obs_next=True,
+ save_only_last_obs=True,
+ stack_num=args.frames_stack,
+ alpha=args.alpha,
+ beta=args.beta,
+ weight_norm=not args.no_weight_norm
+ )
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_path = os.path.join(
args.logdir, args.task, 'rainbow',
- f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
+ f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}'
+ )
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
@@ -164,12 +196,15 @@ def watch():
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
buffer = PrioritizedVectorReplayBuffer(
- args.buffer_size, buffer_num=len(test_envs),
- ignore_obs_next=True, save_only_last_obs=True,
- stack_num=args.frames_stack, alpha=args.alpha,
- beta=args.beta)
- collector = Collector(policy, test_envs, buffer,
- exploration_noise=True)
+ args.buffer_size,
+ buffer_num=len(test_envs),
+ ignore_obs_next=True,
+ save_only_last_obs=True,
+ stack_num=args.frames_stack,
+ alpha=args.alpha,
+ beta=args.beta
+ )
+ collector = Collector(policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
@@ -177,8 +212,9 @@ def watch():
else:
print("Testing agent ...")
test_collector.reset()
- result = test_collector.collect(n_episode=args.test_num,
- render=args.render)
+ result = test_collector.collect(
+ n_episode=args.test_num, render=args.render
+ )
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
@@ -190,11 +226,22 @@ def watch():
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, train_fn=train_fn, test_fn=test_fn,
- stop_fn=stop_fn, save_fn=save_fn, logger=logger,
- update_per_step=args.update_per_step, test_in_train=False)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ train_fn=train_fn,
+ test_fn=test_fn,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger,
+ update_per_step=args.update_per_step,
+ test_in_train=False
+ )
pprint.pprint(result)
watch()
diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py
index 53a662613..333f9787a 100644
--- a/examples/atari/atari_wrapper.py
+++ b/examples/atari/atari_wrapper.py
@@ -1,10 +1,11 @@
# Borrow a lot from openai baselines:
# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
+from collections import deque
+
import cv2
import gym
import numpy as np
-from collections import deque
class NoopResetEnv(gym.Wrapper):
@@ -48,7 +49,7 @@ def step(self, action):
reward, and max over last observations.
"""
obs_list, total_reward, done = [], 0., False
- for i in range(self._skip):
+ for _ in range(self._skip):
obs, reward, done, info = self.env.step(action)
obs_list.append(obs)
total_reward += reward
@@ -127,13 +128,14 @@ def __init__(self, env):
self.observation_space = gym.spaces.Box(
low=np.min(env.observation_space.low),
high=np.max(env.observation_space.high),
- shape=(self.size, self.size), dtype=env.observation_space.dtype)
+ shape=(self.size, self.size),
+ dtype=env.observation_space.dtype
+ )
def observation(self, frame):
"""returns the current observation from a frame"""
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
- return cv2.resize(frame, (self.size, self.size),
- interpolation=cv2.INTER_AREA)
+ return cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA)
class ScaledFloatFrame(gym.ObservationWrapper):
@@ -149,8 +151,8 @@ def __init__(self, env):
self.bias = low
self.scale = high - low
self.observation_space = gym.spaces.Box(
- low=0., high=1., shape=env.observation_space.shape,
- dtype=np.float32)
+ low=0., high=1., shape=env.observation_space.shape, dtype=np.float32
+ )
def observation(self, observation):
return (observation - self.bias) / self.scale
@@ -182,11 +184,13 @@ def __init__(self, env, n_frames):
super().__init__(env)
self.n_frames = n_frames
self.frames = deque([], maxlen=n_frames)
- shape = (n_frames,) + env.observation_space.shape
+ shape = (n_frames, ) + env.observation_space.shape
self.observation_space = gym.spaces.Box(
low=np.min(env.observation_space.low),
high=np.max(env.observation_space.high),
- shape=shape, dtype=env.observation_space.dtype)
+ shape=shape,
+ dtype=env.observation_space.dtype
+ )
def reset(self):
obs = self.env.reset()
@@ -205,8 +209,14 @@ def _get_ob(self):
return np.stack(self.frames, axis=0)
-def wrap_deepmind(env_id, episode_life=True, clip_rewards=True,
- frame_stack=4, scale=False, warp_frame=True):
+def wrap_deepmind(
+ env_id,
+ episode_life=True,
+ clip_rewards=True,
+ frame_stack=4,
+ scale=False,
+ warp_frame=True
+):
"""Configure environment for DeepMind-style Atari. The observation is
channel-first: (c, h, w) instead of (h, w, c).
diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py
index 4889f7eb2..76246fd3e 100644
--- a/examples/box2d/acrobot_dualdqn.py
+++ b/examples/box2d/acrobot_dualdqn.py
@@ -1,17 +1,18 @@
+import argparse
import os
-import gym
-import torch
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
+from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.env import DummyVectorEnv
from tianshou.policy import DQNPolicy
+from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
-from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
def get_args():
@@ -31,17 +32,19 @@ def get_args():
parser.add_argument('--update-per-step', type=float, default=0.01)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128])
- parser.add_argument('--dueling-q-hidden-sizes', type=int,
- nargs='*', default=[128, 128])
- parser.add_argument('--dueling-v-hidden-sizes', type=int,
- nargs='*', default=[128, 128])
+ parser.add_argument(
+ '--dueling-q-hidden-sizes', type=int, nargs='*', default=[128, 128]
+ )
+ parser.add_argument(
+ '--dueling-v-hidden-sizes', type=int, nargs='*', default=[128, 128]
+ )
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
return parser.parse_args()
@@ -52,10 +55,12 @@ def test_dqn(args=get_args()):
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
@@ -64,18 +69,28 @@ def test_dqn(args=get_args()):
# model
Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes}
V_param = {"hidden_sizes": args.dueling_v_hidden_sizes}
- net = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes, device=args.device,
- dueling_param=(Q_param, V_param)).to(args.device)
+ net = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ device=args.device,
+ dueling_param=(Q_param, V_param)
+ ).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = DQNPolicy(
- net, optim, args.gamma, args.n_step,
- target_update_freq=args.target_update_freq)
+ net,
+ optim,
+ args.gamma,
+ args.n_step,
+ target_update_freq=args.target_update_freq
+ )
# collector
train_collector = Collector(
- policy, train_envs,
+ policy,
+ train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
- exploration_noise=True)
+ exploration_noise=True
+ )
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
@@ -105,10 +120,21 @@ def test_fn(epoch, env_step):
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
- update_per_step=args.update_per_step, train_fn=train_fn, test_fn=test_fn,
- stop_fn=stop_fn, save_fn=save_fn, logger=logger)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ update_per_step=args.update_per_step,
+ train_fn=train_fn,
+ test_fn=test_fn,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py
index 4caa50b94..598622d01 100644
--- a/examples/box2d/bipedal_hardcore_sac.py
+++ b/examples/box2d/bipedal_hardcore_sac.py
@@ -1,17 +1,18 @@
+import argparse
import os
-import gym
-import torch
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
+from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.env import SubprocVectorEnv
from tianshou.policy import SACPolicy
+from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
-from tianshou.env import SubprocVectorEnv
-from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils.net.continuous import ActorProb, Critic
@@ -32,16 +33,15 @@ def get_args():
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=128)
- parser.add_argument('--hidden-sizes', type=int,
- nargs='*', default=[128, 128])
+ parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--n-step', type=int, default=4)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
parser.add_argument('--resume-path', type=str, default=None)
return parser.parse_args()
@@ -75,13 +75,16 @@ def test_sac_bipedal(args=get_args()):
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
- train_envs = SubprocVectorEnv([
- lambda: Wrapper(gym.make(args.task))
- for _ in range(args.training_num)])
+ train_envs = SubprocVectorEnv(
+ [lambda: Wrapper(gym.make(args.task)) for _ in range(args.training_num)]
+ )
# test_envs = gym.make(args.task)
- test_envs = SubprocVectorEnv([
- lambda: Wrapper(gym.make(args.task), reward_scale=1, rm_done=False)
- for _ in range(args.test_num)])
+ test_envs = SubprocVectorEnv(
+ [
+ lambda: Wrapper(gym.make(args.task), reward_scale=1, rm_done=False)
+ for _ in range(args.test_num)
+ ]
+ )
# seed
np.random.seed(args.seed)
@@ -90,22 +93,33 @@ def test_sac_bipedal(args=get_args()):
test_envs.seed(args.seed)
# model
- net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- device=args.device)
+ net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
actor = ActorProb(
- net_a, args.action_shape, max_action=args.max_action,
- device=args.device, unbounded=True).to(args.device)
+ net_a,
+ args.action_shape,
+ max_action=args.max_action,
+ device=args.device,
+ unbounded=True
+ ).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
- net_c1 = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes,
- concat=True, device=args.device)
+ net_c1 = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ concat=True,
+ device=args.device
+ )
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
- net_c2 = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes,
- concat=True, device=args.device)
+ net_c2 = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ concat=True,
+ device=args.device
+ )
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
@@ -116,9 +130,18 @@ def test_sac_bipedal(args=get_args()):
args.alpha = (target_entropy, log_alpha, alpha_optim)
policy = SACPolicy(
- actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
- tau=args.tau, gamma=args.gamma, alpha=args.alpha,
- estimation_step=args.n_step, action_space=env.action_space)
+ actor,
+ actor_optim,
+ critic1,
+ critic1_optim,
+ critic2,
+ critic2_optim,
+ tau=args.tau,
+ gamma=args.gamma,
+ alpha=args.alpha,
+ estimation_step=args.n_step,
+ action_space=env.action_space
+ )
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path))
@@ -126,9 +149,11 @@ def test_sac_bipedal(args=get_args()):
# collector
train_collector = Collector(
- policy, train_envs,
+ policy,
+ train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
- exploration_noise=True)
+ exploration_noise=True
+ )
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.buffer_size)
# log
@@ -144,10 +169,20 @@ def stop_fn(mean_rewards):
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
- update_per_step=args.update_per_step, test_in_train=False,
- stop_fn=stop_fn, save_fn=save_fn, logger=logger)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ update_per_step=args.update_per_step,
+ test_in_train=False,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger
+ )
if __name__ == '__main__':
pprint.pprint(result)
@@ -155,8 +190,7 @@ def stop_fn(mean_rewards):
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
- result = test_collector.collect(n_episode=args.test_num,
- render=args.render)
+ result = test_collector.collect(n_episode=args.test_num, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py
index bb73ac615..88f4c397b 100644
--- a/examples/box2d/lunarlander_dqn.py
+++ b/examples/box2d/lunarlander_dqn.py
@@ -1,17 +1,18 @@
+import argparse
import os
-import gym
-import torch
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
+from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.env import DummyVectorEnv, SubprocVectorEnv
from tianshou.policy import DQNPolicy
+from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
-from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import DummyVectorEnv, SubprocVectorEnv
def get_args():
@@ -31,19 +32,20 @@ def get_args():
parser.add_argument('--step-per-collect', type=int, default=16)
parser.add_argument('--update-per-step', type=float, default=0.0625)
parser.add_argument('--batch-size', type=int, default=128)
- parser.add_argument('--hidden-sizes', type=int,
- nargs='*', default=[128, 128])
- parser.add_argument('--dueling-q-hidden-sizes', type=int,
- nargs='*', default=[128, 128])
- parser.add_argument('--dueling-v-hidden-sizes', type=int,
- nargs='*', default=[128, 128])
+ parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128])
+ parser.add_argument(
+ '--dueling-q-hidden-sizes', type=int, nargs='*', default=[128, 128]
+ )
+ parser.add_argument(
+ '--dueling-v-hidden-sizes', type=int, nargs='*', default=[128, 128]
+ )
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
return parser.parse_args()
@@ -54,10 +56,12 @@ def test_dqn(args=get_args()):
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
@@ -66,18 +70,28 @@ def test_dqn(args=get_args()):
# model
Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes}
V_param = {"hidden_sizes": args.dueling_v_hidden_sizes}
- net = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes, device=args.device,
- dueling_param=(Q_param, V_param)).to(args.device)
+ net = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ device=args.device,
+ dueling_param=(Q_param, V_param)
+ ).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = DQNPolicy(
- net, optim, args.gamma, args.n_step,
- target_update_freq=args.target_update_freq)
+ net,
+ optim,
+ args.gamma,
+ args.n_step,
+ target_update_freq=args.target_update_freq
+ )
# collector
train_collector = Collector(
- policy, train_envs,
+ policy,
+ train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
- exploration_noise=True)
+ exploration_noise=True
+ )
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
@@ -93,7 +107,7 @@ def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
def train_fn(epoch, env_step): # exp decay
- eps = max(args.eps_train * (1 - 5e-6) ** env_step, args.eps_test)
+ eps = max(args.eps_train * (1 - 5e-6)**env_step, args.eps_test)
policy.set_eps(eps)
def test_fn(epoch, env_step):
@@ -101,10 +115,21 @@ def test_fn(epoch, env_step):
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
- update_per_step=args.update_per_step, stop_fn=stop_fn, train_fn=train_fn,
- test_fn=test_fn, save_fn=save_fn, logger=logger)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ update_per_step=args.update_per_step,
+ stop_fn=stop_fn,
+ train_fn=train_fn,
+ test_fn=test_fn,
+ save_fn=save_fn,
+ logger=logger
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py
index a43728be5..0638e8f61 100644
--- a/examples/box2d/mcc_sac.py
+++ b/examples/box2d/mcc_sac.py
@@ -1,18 +1,19 @@
+import argparse
import os
-import gym
-import torch
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
-from tianshou.policy import SACPolicy
-from tianshou.utils import TensorboardLogger
+from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.exploration import OUNoise
-from tianshou.utils.net.common import Net
+from tianshou.policy import SACPolicy
from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.utils import TensorboardLogger
+from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic
@@ -34,16 +35,15 @@ def get_args():
parser.add_argument('--step-per-collect', type=int, default=5)
parser.add_argument('--update-per-step', type=float, default=0.2)
parser.add_argument('--batch-size', type=int, default=128)
- parser.add_argument('--hidden-sizes', type=int,
- nargs='*', default=[128, 128])
+ parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=5)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--rew-norm', type=bool, default=False)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
return parser.parse_args()
@@ -54,31 +54,43 @@ def test_sac(args=get_args()):
args.max_action = env.action_space.high[0]
# train_envs = gym.make(args.task)
train_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- device=args.device)
+ net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
actor = ActorProb(
- net, args.action_shape,
- max_action=args.max_action, device=args.device, unbounded=True
+ net,
+ args.action_shape,
+ max_action=args.max_action,
+ device=args.device,
+ unbounded=True
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
- net_c1 = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes, concat=True,
- device=args.device)
+ net_c1 = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ concat=True,
+ device=args.device
+ )
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
- net_c2 = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes, concat=True,
- device=args.device)
+ net_c2 = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ concat=True,
+ device=args.device
+ )
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
@@ -89,16 +101,26 @@ def test_sac(args=get_args()):
args.alpha = (target_entropy, log_alpha, alpha_optim)
policy = SACPolicy(
- actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
- tau=args.tau, gamma=args.gamma, alpha=args.alpha,
+ actor,
+ actor_optim,
+ critic1,
+ critic1_optim,
+ critic2,
+ critic2_optim,
+ tau=args.tau,
+ gamma=args.gamma,
+ alpha=args.alpha,
reward_normalization=args.rew_norm,
exploration_noise=OUNoise(0.0, args.noise_std),
- action_space=env.action_space)
+ action_space=env.action_space
+ )
# collector
train_collector = Collector(
- policy, train_envs,
+ policy,
+ train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
- exploration_noise=True)
+ exploration_noise=True
+ )
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.buffer_size)
# log
@@ -114,10 +136,19 @@ def stop_fn(mean_rewards):
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
- update_per_step=args.update_per_step, stop_fn=stop_fn,
- save_fn=save_fn, logger=logger)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ update_per_step=args.update_per_step,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
diff --git a/examples/mujoco/analysis.py b/examples/mujoco/analysis.py
index 01a2cf678..ed0bb6872 100755
--- a/examples/mujoco/analysis.py
+++ b/examples/mujoco/analysis.py
@@ -1,11 +1,12 @@
#!/usr/bin/env python3
-import re
import argparse
+import re
+from collections import defaultdict
+
import numpy as np
from tabulate import tabulate
-from collections import defaultdict
-from tools import find_all_files, group_files, csv2numpy
+from tools import csv2numpy, find_all_files, group_files
def numerical_anysis(root_dir, xlim, norm=False):
@@ -20,13 +21,16 @@ def numerical_anysis(root_dir, xlim, norm=False):
for f in csv_files:
result = csv2numpy(f)
if norm:
- result = np.stack([
- result['env_step'],
- result['rew'] - result['rew'][0],
- result['rew:shaded']])
+ result = np.stack(
+ [
+ result['env_step'], result['rew'] - result['rew'][0],
+ result['rew:shaded']
+ ]
+ )
else:
- result = np.stack([
- result['env_step'], result['rew'], result['rew:shaded']])
+ result = np.stack(
+ [result['env_step'], result['rew'], result['rew:shaded']]
+ )
if result[0, -1] < xlim:
continue
@@ -79,11 +83,17 @@ def numerical_anysis(root_dir, xlim, norm=False):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
- parser.add_argument('--xlim', type=int, default=1000000,
- help='x-axis limitation (default: 1000000)')
+ parser.add_argument(
+ '--xlim',
+ type=int,
+ default=1000000,
+ help='x-axis limitation (default: 1000000)'
+ )
parser.add_argument('--root-dir', type=str)
parser.add_argument(
- '--norm', action="store_true",
- help="Normalize all results according to environment.")
+ '--norm',
+ action="store_true",
+ help="Normalize all results according to environment."
+ )
args = parser.parse_args()
numerical_anysis(args.root_dir, args.xlim, norm=args.norm)
diff --git a/examples/mujoco/gen_json.py b/examples/mujoco/gen_json.py
index 0c0b113e9..99cad74a3 100755
--- a/examples/mujoco/gen_json.py
+++ b/examples/mujoco/gen_json.py
@@ -1,15 +1,15 @@
#!/usr/bin/env python3
-import os
import csv
-import sys
import json
+import os
+import sys
def merge(rootdir):
"""format: $rootdir/$algo/*.csv"""
result = []
- for path, dirnames, filenames in os.walk(rootdir):
+ for path, _, filenames in os.walk(rootdir):
filenames = [f for f in filenames if f.endswith('.csv')]
if len(filenames) == 0:
continue
@@ -19,12 +19,14 @@ def merge(rootdir):
algo = os.path.relpath(path, rootdir).upper()
reader = csv.DictReader(open(os.path.join(path, filenames[0])))
for row in reader:
- result.append({
- 'env_step': int(row['env_step']),
- 'rew': float(row['rew']),
- 'rew_std': float(row['rew:shaded']),
- 'Agent': algo,
- })
+ result.append(
+ {
+ 'env_step': int(row['env_step']),
+ 'rew': float(row['rew']),
+ 'rew_std': float(row['rew:shaded']),
+ 'Agent': algo,
+ }
+ )
open(os.path.join(rootdir, 'result.json'), 'w').write(json.dumps(result))
diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py
index e9debc906..02978697b 100755
--- a/examples/mujoco/mujoco_a2c.py
+++ b/examples/mujoco/mujoco_a2c.py
@@ -1,24 +1,25 @@
#!/usr/bin/env python3
+import argparse
+import datetime
import os
-import gym
-import torch
import pprint
-import datetime
-import argparse
+
+import gym
import numpy as np
+import torch
from torch import nn
+from torch.distributions import Independent, Normal
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
-from torch.distributions import Independent, Normal
+from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
+from tianshou.env import SubprocVectorEnv
from tianshou.policy import A2CPolicy
+from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
-from tianshou.trainer import onpolicy_trainer
from tianshou.utils.net.continuous import ActorProb, Critic
-from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
def get_args():
@@ -48,11 +49,15 @@ def get_args():
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
parser.add_argument('--resume-path', type=str, default=None)
- parser.add_argument('--watch', default=False, action='store_true',
- help='watch the play of pre-trained policy only')
+ parser.add_argument(
+ '--watch',
+ default=False,
+ action='store_true',
+ help='watch the play of pre-trained policy only'
+ )
return parser.parse_args()
@@ -63,16 +68,18 @@ def test_a2c(args=get_args()):
args.max_action = env.action_space.high[0]
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
- print("Action range:", np.min(env.action_space.low),
- np.max(env.action_space.high))
+ print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)],
- norm_obs=True)
+ [lambda: gym.make(args.task) for _ in range(args.training_num)], norm_obs=True
+ )
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
- norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False)
+ norm_obs=True,
+ obs_rms=train_envs.obs_rms,
+ update_obs_rms=False
+ )
# seed
np.random.seed(args.seed)
@@ -80,12 +87,25 @@ def test_a2c(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- activation=nn.Tanh, device=args.device)
- actor = ActorProb(net_a, args.action_shape, max_action=args.max_action,
- unbounded=True, device=args.device).to(args.device)
- net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- activation=nn.Tanh, device=args.device)
+ net_a = Net(
+ args.state_shape,
+ hidden_sizes=args.hidden_sizes,
+ activation=nn.Tanh,
+ device=args.device
+ )
+ actor = ActorProb(
+ net_a,
+ args.action_shape,
+ max_action=args.max_action,
+ unbounded=True,
+ device=args.device
+ ).to(args.device)
+ net_c = Net(
+ args.state_shape,
+ hidden_sizes=args.hidden_sizes,
+ activation=nn.Tanh,
+ device=args.device
+ )
critic = Critic(net_c, device=args.device).to(args.device)
torch.nn.init.constant_(actor.sigma_param, -0.5)
for m in list(actor.modules()) + list(critic.modules()):
@@ -101,27 +121,43 @@ def test_a2c(args=get_args()):
torch.nn.init.zeros_(m.bias)
m.weight.data.copy_(0.01 * m.weight.data)
- optim = torch.optim.RMSprop(list(actor.parameters()) + list(critic.parameters()),
- lr=args.lr, eps=1e-5, alpha=0.99)
+ optim = torch.optim.RMSprop(
+ list(actor.parameters()) + list(critic.parameters()),
+ lr=args.lr,
+ eps=1e-5,
+ alpha=0.99
+ )
lr_scheduler = None
if args.lr_decay:
# decay learning rate to 0 linearly
max_update_num = np.ceil(
- args.step_per_epoch / args.step_per_collect) * args.epoch
+ args.step_per_epoch / args.step_per_collect
+ ) * args.epoch
lr_scheduler = LambdaLR(
- optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
+ optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
+ )
def dist(*logits):
return Independent(Normal(*logits), 1)
- policy = A2CPolicy(actor, critic, optim, dist, discount_factor=args.gamma,
- gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm,
- vf_coef=args.vf_coef, ent_coef=args.ent_coef,
- reward_normalization=args.rew_norm, action_scaling=True,
- action_bound_method=args.bound_action_method,
- lr_scheduler=lr_scheduler, action_space=env.action_space)
+ policy = A2CPolicy(
+ actor,
+ critic,
+ optim,
+ dist,
+ discount_factor=args.gamma,
+ gae_lambda=args.gae_lambda,
+ max_grad_norm=args.max_grad_norm,
+ vf_coef=args.vf_coef,
+ ent_coef=args.ent_coef,
+ reward_normalization=args.rew_norm,
+ action_scaling=True,
+ action_bound_method=args.bound_action_method,
+ lr_scheduler=lr_scheduler,
+ action_space=env.action_space
+ )
# load a previous policy
if args.resume_path:
@@ -149,10 +185,19 @@ def save_fn(policy):
if not args.watch:
# trainer
result = onpolicy_trainer(
- policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
- args.repeat_per_collect, args.test_num, args.batch_size,
- step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
- test_in_train=False)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.repeat_per_collect,
+ args.test_num,
+ args.batch_size,
+ step_per_collect=args.step_per_collect,
+ save_fn=save_fn,
+ logger=logger,
+ test_in_train=False
+ )
pprint.pprint(result)
# Let's watch its performance!
diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py
index 28bac056a..8d436b573 100755
--- a/examples/mujoco/mujoco_ddpg.py
+++ b/examples/mujoco/mujoco_ddpg.py
@@ -1,22 +1,23 @@
#!/usr/bin/env python3
+import argparse
+import datetime
import os
-import gym
-import torch
import pprint
-import datetime
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
-from tianshou.policy import DDPGPolicy
-from tianshou.utils import TensorboardLogger
+from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
-from tianshou.utils.net.common import Net
from tianshou.exploration import GaussianNoise
+from tianshou.policy import DDPGPolicy
from tianshou.trainer import offpolicy_trainer
+from tianshou.utils import TensorboardLogger
+from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic
-from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
def get_args():
@@ -42,11 +43,15 @@ def get_args():
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
parser.add_argument('--resume-path', type=str, default=None)
- parser.add_argument('--watch', default=False, action='store_true',
- help='watch the play of pre-trained policy only')
+ parser.add_argument(
+ '--watch',
+ default=False,
+ action='store_true',
+ help='watch the play of pre-trained policy only'
+ )
return parser.parse_args()
@@ -58,17 +63,18 @@ def test_ddpg(args=get_args()):
args.exploration_noise = args.exploration_noise * args.max_action
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
- print("Action range:", np.min(env.action_space.low),
- np.max(env.action_space.high))
+ print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
# train_envs = gym.make(args.task)
if args.training_num > 1:
train_envs = SubprocVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
else:
train_envs = gym.make(args.task)
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
@@ -77,19 +83,29 @@ def test_ddpg(args=get_args()):
# model
net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
actor = Actor(
- net_a, args.action_shape, max_action=args.max_action,
- device=args.device).to(args.device)
+ net_a, args.action_shape, max_action=args.max_action, device=args.device
+ ).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
- net_c = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes,
- concat=True, device=args.device)
+ net_c = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ concat=True,
+ device=args.device
+ )
critic = Critic(net_c, device=args.device).to(args.device)
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
policy = DDPGPolicy(
- actor, actor_optim, critic, critic_optim,
- tau=args.tau, gamma=args.gamma,
+ actor,
+ actor_optim,
+ critic,
+ critic_optim,
+ tau=args.tau,
+ gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
- estimation_step=args.n_step, action_space=env.action_space)
+ estimation_step=args.n_step,
+ action_space=env.action_space
+ )
# load a previous policy
if args.resume_path:
@@ -118,10 +134,19 @@ def save_fn(policy):
if not args.watch:
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, save_fn=save_fn, logger=logger,
- update_per_step=args.update_per_step, test_in_train=False)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ save_fn=save_fn,
+ logger=logger,
+ update_per_step=args.update_per_step,
+ test_in_train=False
+ )
pprint.pprint(result)
# Let's watch its performance!
diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py
index 00b2a1a2c..23883a119 100755
--- a/examples/mujoco/mujoco_npg.py
+++ b/examples/mujoco/mujoco_npg.py
@@ -1,24 +1,25 @@
#!/usr/bin/env python3
+import argparse
+import datetime
import os
-import gym
-import torch
import pprint
-import datetime
-import argparse
+
+import gym
import numpy as np
+import torch
from torch import nn
+from torch.distributions import Independent, Normal
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
-from torch.distributions import Independent, Normal
+from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
+from tianshou.env import SubprocVectorEnv
from tianshou.policy import NPGPolicy
+from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
-from tianshou.trainer import onpolicy_trainer
from tianshou.utils.net.continuous import ActorProb, Critic
-from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
def get_args():
@@ -26,8 +27,9 @@ def get_args():
parser.add_argument('--task', type=str, default='HalfCheetah-v3')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=4096)
- parser.add_argument('--hidden-sizes', type=int, nargs='*',
- default=[64, 64]) # baselines [32, 32]
+ parser.add_argument(
+ '--hidden-sizes', type=int, nargs='*', default=[64, 64]
+ ) # baselines [32, 32]
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--epoch', type=int, default=100)
@@ -49,11 +51,15 @@ def get_args():
parser.add_argument('--optim-critic-iters', type=int, default=20)
parser.add_argument('--actor-step-size', type=float, default=0.1)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
parser.add_argument('--resume-path', type=str, default=None)
- parser.add_argument('--watch', default=False, action='store_true',
- help='watch the play of pre-trained policy only')
+ parser.add_argument(
+ '--watch',
+ default=False,
+ action='store_true',
+ help='watch the play of pre-trained policy only'
+ )
return parser.parse_args()
@@ -64,16 +70,18 @@ def test_npg(args=get_args()):
args.max_action = env.action_space.high[0]
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
- print("Action range:", np.min(env.action_space.low),
- np.max(env.action_space.high))
+ print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)],
- norm_obs=True)
+ [lambda: gym.make(args.task) for _ in range(args.training_num)], norm_obs=True
+ )
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
- norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False)
+ norm_obs=True,
+ obs_rms=train_envs.obs_rms,
+ update_obs_rms=False
+ )
# seed
np.random.seed(args.seed)
@@ -81,12 +89,25 @@ def test_npg(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- activation=nn.Tanh, device=args.device)
- actor = ActorProb(net_a, args.action_shape, max_action=args.max_action,
- unbounded=True, device=args.device).to(args.device)
- net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- activation=nn.Tanh, device=args.device)
+ net_a = Net(
+ args.state_shape,
+ hidden_sizes=args.hidden_sizes,
+ activation=nn.Tanh,
+ device=args.device
+ )
+ actor = ActorProb(
+ net_a,
+ args.action_shape,
+ max_action=args.max_action,
+ unbounded=True,
+ device=args.device
+ ).to(args.device)
+ net_c = Net(
+ args.state_shape,
+ hidden_sizes=args.hidden_sizes,
+ activation=nn.Tanh,
+ device=args.device
+ )
critic = Critic(net_c, device=args.device).to(args.device)
torch.nn.init.constant_(actor.sigma_param, -0.5)
for m in list(actor.modules()) + list(critic.modules()):
@@ -107,22 +128,32 @@ def test_npg(args=get_args()):
if args.lr_decay:
# decay learning rate to 0 linearly
max_update_num = np.ceil(
- args.step_per_epoch / args.step_per_collect) * args.epoch
+ args.step_per_epoch / args.step_per_collect
+ ) * args.epoch
lr_scheduler = LambdaLR(
- optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
+ optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
+ )
def dist(*logits):
return Independent(Normal(*logits), 1)
- policy = NPGPolicy(actor, critic, optim, dist, discount_factor=args.gamma,
- gae_lambda=args.gae_lambda,
- reward_normalization=args.rew_norm, action_scaling=True,
- action_bound_method=args.bound_action_method,
- lr_scheduler=lr_scheduler, action_space=env.action_space,
- advantage_normalization=args.norm_adv,
- optim_critic_iters=args.optim_critic_iters,
- actor_step_size=args.actor_step_size)
+ policy = NPGPolicy(
+ actor,
+ critic,
+ optim,
+ dist,
+ discount_factor=args.gamma,
+ gae_lambda=args.gae_lambda,
+ reward_normalization=args.rew_norm,
+ action_scaling=True,
+ action_bound_method=args.bound_action_method,
+ lr_scheduler=lr_scheduler,
+ action_space=env.action_space,
+ advantage_normalization=args.norm_adv,
+ optim_critic_iters=args.optim_critic_iters,
+ actor_step_size=args.actor_step_size
+ )
# load a previous policy
if args.resume_path:
@@ -150,10 +181,19 @@ def save_fn(policy):
if not args.watch:
# trainer
result = onpolicy_trainer(
- policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
- args.repeat_per_collect, args.test_num, args.batch_size,
- step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
- test_in_train=False)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.repeat_per_collect,
+ args.test_num,
+ args.batch_size,
+ step_per_collect=args.step_per_collect,
+ save_fn=save_fn,
+ logger=logger,
+ test_in_train=False
+ )
pprint.pprint(result)
# Let's watch its performance!
diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py
index fb3e7a0a2..01dc5aa3f 100755
--- a/examples/mujoco/mujoco_ppo.py
+++ b/examples/mujoco/mujoco_ppo.py
@@ -1,24 +1,25 @@
#!/usr/bin/env python3
+import argparse
+import datetime
import os
-import gym
-import torch
import pprint
-import datetime
-import argparse
+
+import gym
import numpy as np
+import torch
from torch import nn
+from torch.distributions import Independent, Normal
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
-from torch.distributions import Independent, Normal
+from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
+from tianshou.env import SubprocVectorEnv
from tianshou.policy import PPOPolicy
+from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
-from tianshou.trainer import onpolicy_trainer
from tianshou.utils.net.continuous import ActorProb, Critic
-from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
def get_args():
@@ -53,11 +54,15 @@ def get_args():
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
parser.add_argument('--resume-path', type=str, default=None)
- parser.add_argument('--watch', default=False, action='store_true',
- help='watch the play of pre-trained policy only')
+ parser.add_argument(
+ '--watch',
+ default=False,
+ action='store_true',
+ help='watch the play of pre-trained policy only'
+ )
return parser.parse_args()
@@ -68,16 +73,18 @@ def test_ppo(args=get_args()):
args.max_action = env.action_space.high[0]
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
- print("Action range:", np.min(env.action_space.low),
- np.max(env.action_space.high))
+ print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)],
- norm_obs=True)
+ [lambda: gym.make(args.task) for _ in range(args.training_num)], norm_obs=True
+ )
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
- norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False)
+ norm_obs=True,
+ obs_rms=train_envs.obs_rms,
+ update_obs_rms=False
+ )
# seed
np.random.seed(args.seed)
@@ -85,12 +92,25 @@ def test_ppo(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- activation=nn.Tanh, device=args.device)
- actor = ActorProb(net_a, args.action_shape, max_action=args.max_action,
- unbounded=True, device=args.device).to(args.device)
- net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- activation=nn.Tanh, device=args.device)
+ net_a = Net(
+ args.state_shape,
+ hidden_sizes=args.hidden_sizes,
+ activation=nn.Tanh,
+ device=args.device
+ )
+ actor = ActorProb(
+ net_a,
+ args.action_shape,
+ max_action=args.max_action,
+ unbounded=True,
+ device=args.device
+ ).to(args.device)
+ net_c = Net(
+ args.state_shape,
+ hidden_sizes=args.hidden_sizes,
+ activation=nn.Tanh,
+ device=args.device
+ )
critic = Critic(net_c, device=args.device).to(args.device)
torch.nn.init.constant_(actor.sigma_param, -0.5)
for m in list(actor.modules()) + list(critic.modules()):
@@ -107,29 +127,44 @@ def test_ppo(args=get_args()):
m.weight.data.copy_(0.01 * m.weight.data)
optim = torch.optim.Adam(
- list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
+ list(actor.parameters()) + list(critic.parameters()), lr=args.lr
+ )
lr_scheduler = None
if args.lr_decay:
# decay learning rate to 0 linearly
max_update_num = np.ceil(
- args.step_per_epoch / args.step_per_collect) * args.epoch
+ args.step_per_epoch / args.step_per_collect
+ ) * args.epoch
lr_scheduler = LambdaLR(
- optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
+ optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
+ )
def dist(*logits):
return Independent(Normal(*logits), 1)
- policy = PPOPolicy(actor, critic, optim, dist, discount_factor=args.gamma,
- gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm,
- vf_coef=args.vf_coef, ent_coef=args.ent_coef,
- reward_normalization=args.rew_norm, action_scaling=True,
- action_bound_method=args.bound_action_method,
- lr_scheduler=lr_scheduler, action_space=env.action_space,
- eps_clip=args.eps_clip, value_clip=args.value_clip,
- dual_clip=args.dual_clip, advantage_normalization=args.norm_adv,
- recompute_advantage=args.recompute_adv)
+ policy = PPOPolicy(
+ actor,
+ critic,
+ optim,
+ dist,
+ discount_factor=args.gamma,
+ gae_lambda=args.gae_lambda,
+ max_grad_norm=args.max_grad_norm,
+ vf_coef=args.vf_coef,
+ ent_coef=args.ent_coef,
+ reward_normalization=args.rew_norm,
+ action_scaling=True,
+ action_bound_method=args.bound_action_method,
+ lr_scheduler=lr_scheduler,
+ action_space=env.action_space,
+ eps_clip=args.eps_clip,
+ value_clip=args.value_clip,
+ dual_clip=args.dual_clip,
+ advantage_normalization=args.norm_adv,
+ recompute_advantage=args.recompute_adv
+ )
# load a previous policy
if args.resume_path:
@@ -157,10 +192,19 @@ def save_fn(policy):
if not args.watch:
# trainer
result = onpolicy_trainer(
- policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
- args.repeat_per_collect, args.test_num, args.batch_size,
- step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
- test_in_train=False)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.repeat_per_collect,
+ args.test_num,
+ args.batch_size,
+ step_per_collect=args.step_per_collect,
+ save_fn=save_fn,
+ logger=logger,
+ test_in_train=False
+ )
pprint.pprint(result)
# Let's watch its performance!
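As an aside for readers skimming the reformatted MuJoCo examples: the small `dist(*logits)` helper that now sits on its own lines in mujoco_ppo.py (and again in mujoco_reinforce.py and mujoco_trpo.py below) wraps `Normal` in `Independent` to obtain a diagonal Gaussian over the whole action vector. A self-contained sketch of that behaviour, using made-up batch and action dimensions:

```python
# Illustration only: Independent(Normal(loc, scale), 1) reinterprets the last
# dimension as event dims, so log_prob returns one value per sample instead of
# one value per action component. Shapes below are arbitrary.
import torch
from torch.distributions import Independent, Normal

loc = torch.zeros(4, 3)    # hypothetical batch of 4 mean vectors, action dim 3
scale = torch.ones(4, 3)   # hypothetical per-dimension standard deviations

plain = Normal(loc, scale)                 # factorised, per-component density
diag = Independent(Normal(loc, scale), 1)  # diagonal Gaussian over the vector

sample = diag.sample()
print(plain.log_prob(sample).shape)  # torch.Size([4, 3])
print(diag.log_prob(sample).shape)   # torch.Size([4])
```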
diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py
index b7698562a..914b46251 100755
--- a/examples/mujoco/mujoco_reinforce.py
+++ b/examples/mujoco/mujoco_reinforce.py
@@ -1,24 +1,25 @@
#!/usr/bin/env python3
+import argparse
+import datetime
import os
-import gym
-import torch
import pprint
-import datetime
-import argparse
+
+import gym
import numpy as np
+import torch
from torch import nn
+from torch.distributions import Independent, Normal
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
-from torch.distributions import Independent, Normal
+from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
+from tianshou.env import SubprocVectorEnv
from tianshou.policy import PGPolicy
+from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
-from tianshou.trainer import onpolicy_trainer
from tianshou.utils.net.continuous import ActorProb
-from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
def get_args():
@@ -45,11 +46,15 @@ def get_args():
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
parser.add_argument('--resume-path', type=str, default=None)
- parser.add_argument('--watch', default=False, action='store_true',
- help='watch the play of pre-trained policy only')
+ parser.add_argument(
+ '--watch',
+ default=False,
+ action='store_true',
+ help='watch the play of pre-trained policy only'
+ )
return parser.parse_args()
@@ -60,16 +65,18 @@ def test_reinforce(args=get_args()):
args.max_action = env.action_space.high[0]
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
- print("Action range:", np.min(env.action_space.low),
- np.max(env.action_space.high))
+ print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)],
- norm_obs=True)
+ [lambda: gym.make(args.task) for _ in range(args.training_num)], norm_obs=True
+ )
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
- norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False)
+ norm_obs=True,
+ obs_rms=train_envs.obs_rms,
+ update_obs_rms=False
+ )
# seed
np.random.seed(args.seed)
@@ -77,10 +84,19 @@ def test_reinforce(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- activation=nn.Tanh, device=args.device)
- actor = ActorProb(net_a, args.action_shape, max_action=args.max_action,
- unbounded=True, device=args.device).to(args.device)
+ net_a = Net(
+ args.state_shape,
+ hidden_sizes=args.hidden_sizes,
+ activation=nn.Tanh,
+ device=args.device
+ )
+ actor = ActorProb(
+ net_a,
+ args.action_shape,
+ max_action=args.max_action,
+ unbounded=True,
+ device=args.device
+ ).to(args.device)
torch.nn.init.constant_(actor.sigma_param, -0.5)
for m in actor.modules():
if isinstance(m, torch.nn.Linear):
@@ -100,18 +116,27 @@ def test_reinforce(args=get_args()):
if args.lr_decay:
# decay learning rate to 0 linearly
max_update_num = np.ceil(
- args.step_per_epoch / args.step_per_collect) * args.epoch
+ args.step_per_epoch / args.step_per_collect
+ ) * args.epoch
lr_scheduler = LambdaLR(
- optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
+ optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
+ )
def dist(*logits):
return Independent(Normal(*logits), 1)
- policy = PGPolicy(actor, optim, dist, discount_factor=args.gamma,
- reward_normalization=args.rew_norm, action_scaling=True,
- action_bound_method=args.action_bound_method,
- lr_scheduler=lr_scheduler, action_space=env.action_space)
+ policy = PGPolicy(
+ actor,
+ optim,
+ dist,
+ discount_factor=args.gamma,
+ reward_normalization=args.rew_norm,
+ action_scaling=True,
+ action_bound_method=args.action_bound_method,
+ lr_scheduler=lr_scheduler,
+ action_space=env.action_space
+ )
# load a previous policy
if args.resume_path:
@@ -139,10 +164,19 @@ def save_fn(policy):
if not args.watch:
# trainer
result = onpolicy_trainer(
- policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
- args.repeat_per_collect, args.test_num, args.batch_size,
- step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
- test_in_train=False)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.repeat_per_collect,
+ args.test_num,
+ args.batch_size,
+ step_per_collect=args.step_per_collect,
+ save_fn=save_fn,
+ logger=logger,
+ test_in_train=False
+ )
pprint.pprint(result)
# Let's watch its performance!
diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py
index c2dfd3618..cb764f473 100755
--- a/examples/mujoco/mujoco_sac.py
+++ b/examples/mujoco/mujoco_sac.py
@@ -1,21 +1,22 @@
#!/usr/bin/env python3
+import argparse
+import datetime
import os
-import gym
-import torch
import pprint
-import datetime
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
+from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
+from tianshou.env import SubprocVectorEnv
from tianshou.policy import SACPolicy
+from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
-from tianshou.trainer import offpolicy_trainer
from tianshou.utils.net.continuous import ActorProb, Critic
-from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
def get_args():
@@ -43,11 +44,15 @@ def get_args():
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
parser.add_argument('--resume-path', type=str, default=None)
- parser.add_argument('--watch', default=False, action='store_true',
- help='watch the play of pre-trained policy only')
+ parser.add_argument(
+ '--watch',
+ default=False,
+ action='store_true',
+ help='watch the play of pre-trained policy only'
+ )
return parser.parse_args()
@@ -58,17 +63,18 @@ def test_sac(args=get_args()):
args.max_action = env.action_space.high[0]
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
- print("Action range:", np.min(env.action_space.low),
- np.max(env.action_space.high))
+ print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
# train_envs = gym.make(args.task)
if args.training_num > 1:
train_envs = SubprocVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
else:
train_envs = gym.make(args.task)
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
@@ -77,16 +83,28 @@ def test_sac(args=get_args()):
# model
net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
actor = ActorProb(
- net_a, args.action_shape, max_action=args.max_action,
- device=args.device, unbounded=True, conditioned_sigma=True
+ net_a,
+ args.action_shape,
+ max_action=args.max_action,
+ device=args.device,
+ unbounded=True,
+ conditioned_sigma=True
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
- net_c1 = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes,
- concat=True, device=args.device)
- net_c2 = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes,
- concat=True, device=args.device)
+ net_c1 = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ concat=True,
+ device=args.device
+ )
+ net_c2 = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ concat=True,
+ device=args.device
+ )
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(net_c2, device=args.device).to(args.device)
@@ -99,9 +117,18 @@ def test_sac(args=get_args()):
args.alpha = (target_entropy, log_alpha, alpha_optim)
policy = SACPolicy(
- actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
- tau=args.tau, gamma=args.gamma, alpha=args.alpha,
- estimation_step=args.n_step, action_space=env.action_space)
+ actor,
+ actor_optim,
+ critic1,
+ critic1_optim,
+ critic2,
+ critic2_optim,
+ tau=args.tau,
+ gamma=args.gamma,
+ alpha=args.alpha,
+ estimation_step=args.n_step,
+ action_space=env.action_space
+ )
# load a previous policy
if args.resume_path:
@@ -130,10 +157,19 @@ def save_fn(policy):
if not args.watch:
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, save_fn=save_fn, logger=logger,
- update_per_step=args.update_per_step, test_in_train=False)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ save_fn=save_fn,
+ logger=logger,
+ update_per_step=args.update_per_step,
+ test_in_train=False
+ )
pprint.pprint(result)
# Let's watch its performance!
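The SAC hunk above keeps `args.alpha = (target_entropy, log_alpha, alpha_optim)` as unchanged context; passing such a tuple as `alpha` is how the example switches SACPolicy into automatic entropy-coefficient tuning. A rough sketch of how a tuple like that is typically assembled; every concrete number below is a placeholder, not a value taken from this diff:

```python
# Sketch under stated assumptions: the action dimension and learning rate are made up.
import numpy as np
import torch

action_shape = (6,)                                   # hypothetical action dimension
target_entropy = -np.prod(action_shape)               # common heuristic: -dim(A)
log_alpha = torch.zeros(1, requires_grad=True)        # learnable log(alpha)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)  # optimiser for log_alpha only
alpha = (target_entropy, log_alpha, alpha_optim)      # what gets passed as `alpha=`
print(target_entropy, torch.exp(log_alpha).item())
```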
diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py
index 9a0179899..9e0ca0d82 100755
--- a/examples/mujoco/mujoco_td3.py
+++ b/examples/mujoco/mujoco_td3.py
@@ -1,22 +1,23 @@
#!/usr/bin/env python3
+import argparse
+import datetime
import os
-import gym
-import torch
import pprint
-import datetime
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
-from tianshou.policy import TD3Policy
-from tianshou.utils import TensorboardLogger
+from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
-from tianshou.utils.net.common import Net
from tianshou.exploration import GaussianNoise
+from tianshou.policy import TD3Policy
from tianshou.trainer import offpolicy_trainer
+from tianshou.utils import TensorboardLogger
+from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic
-from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
def get_args():
@@ -45,11 +46,15 @@ def get_args():
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
parser.add_argument('--resume-path', type=str, default=None)
- parser.add_argument('--watch', default=False, action='store_true',
- help='watch the play of pre-trained policy only')
+ parser.add_argument(
+ '--watch',
+ default=False,
+ action='store_true',
+ help='watch the play of pre-trained policy only'
+ )
return parser.parse_args()
@@ -63,17 +68,18 @@ def test_td3(args=get_args()):
args.noise_clip = args.noise_clip * args.max_action
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
- print("Action range:", np.min(env.action_space.low),
- np.max(env.action_space.high))
+ print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
# train_envs = gym.make(args.task)
if args.training_num > 1:
train_envs = SubprocVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
else:
train_envs = gym.make(args.task)
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
@@ -82,27 +88,44 @@ def test_td3(args=get_args()):
# model
net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
actor = Actor(
- net_a, args.action_shape, max_action=args.max_action,
- device=args.device).to(args.device)
+ net_a, args.action_shape, max_action=args.max_action, device=args.device
+ ).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
- net_c1 = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes,
- concat=True, device=args.device)
- net_c2 = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes,
- concat=True, device=args.device)
+ net_c1 = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ concat=True,
+ device=args.device
+ )
+ net_c2 = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ concat=True,
+ device=args.device
+ )
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = TD3Policy(
- actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
- tau=args.tau, gamma=args.gamma,
+ actor,
+ actor_optim,
+ critic1,
+ critic1_optim,
+ critic2,
+ critic2_optim,
+ tau=args.tau,
+ gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
- policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq,
- noise_clip=args.noise_clip, estimation_step=args.n_step,
- action_space=env.action_space)
+ policy_noise=args.policy_noise,
+ update_actor_freq=args.update_actor_freq,
+ noise_clip=args.noise_clip,
+ estimation_step=args.n_step,
+ action_space=env.action_space
+ )
# load a previous policy
if args.resume_path:
@@ -131,10 +154,19 @@ def save_fn(policy):
if not args.watch:
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, save_fn=save_fn, logger=logger,
- update_per_step=args.update_per_step, test_in_train=False)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ save_fn=save_fn,
+ logger=logger,
+ update_per_step=args.update_per_step,
+ test_in_train=False
+ )
pprint.pprint(result)
# Let's watch its performance!
diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py
index b00f2e3d6..aef324fd5 100755
--- a/examples/mujoco/mujoco_trpo.py
+++ b/examples/mujoco/mujoco_trpo.py
@@ -1,24 +1,25 @@
#!/usr/bin/env python3
+import argparse
+import datetime
import os
-import gym
-import torch
import pprint
-import datetime
-import argparse
+
+import gym
import numpy as np
+import torch
from torch import nn
+from torch.distributions import Independent, Normal
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
-from torch.distributions import Independent, Normal
+from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
+from tianshou.env import SubprocVectorEnv
from tianshou.policy import TRPOPolicy
+from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
-from tianshou.trainer import onpolicy_trainer
from tianshou.utils.net.continuous import ActorProb, Critic
-from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
def get_args():
@@ -26,8 +27,9 @@ def get_args():
parser.add_argument('--task', type=str, default='HalfCheetah-v3')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=4096)
- parser.add_argument('--hidden-sizes', type=int, nargs='*',
- default=[64, 64]) # baselines [32, 32]
+ parser.add_argument(
+ '--hidden-sizes', type=int, nargs='*', default=[64, 64]
+ ) # baselines [32, 32]
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--epoch', type=int, default=100)
@@ -52,11 +54,15 @@ def get_args():
parser.add_argument('--backtrack-coeff', type=float, default=0.8)
parser.add_argument('--max-backtracks', type=int, default=10)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
parser.add_argument('--resume-path', type=str, default=None)
- parser.add_argument('--watch', default=False, action='store_true',
- help='watch the play of pre-trained policy only')
+ parser.add_argument(
+ '--watch',
+ default=False,
+ action='store_true',
+ help='watch the play of pre-trained policy only'
+ )
return parser.parse_args()
@@ -67,16 +73,18 @@ def test_trpo(args=get_args()):
args.max_action = env.action_space.high[0]
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
- print("Action range:", np.min(env.action_space.low),
- np.max(env.action_space.high))
+ print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)],
- norm_obs=True)
+ [lambda: gym.make(args.task) for _ in range(args.training_num)], norm_obs=True
+ )
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
- norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False)
+ norm_obs=True,
+ obs_rms=train_envs.obs_rms,
+ update_obs_rms=False
+ )
# seed
np.random.seed(args.seed)
@@ -84,12 +92,25 @@ def test_trpo(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- activation=nn.Tanh, device=args.device)
- actor = ActorProb(net_a, args.action_shape, max_action=args.max_action,
- unbounded=True, device=args.device).to(args.device)
- net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- activation=nn.Tanh, device=args.device)
+ net_a = Net(
+ args.state_shape,
+ hidden_sizes=args.hidden_sizes,
+ activation=nn.Tanh,
+ device=args.device
+ )
+ actor = ActorProb(
+ net_a,
+ args.action_shape,
+ max_action=args.max_action,
+ unbounded=True,
+ device=args.device
+ ).to(args.device)
+ net_c = Net(
+ args.state_shape,
+ hidden_sizes=args.hidden_sizes,
+ activation=nn.Tanh,
+ device=args.device
+ )
critic = Critic(net_c, device=args.device).to(args.device)
torch.nn.init.constant_(actor.sigma_param, -0.5)
for m in list(actor.modules()) + list(critic.modules()):
@@ -110,24 +131,34 @@ def test_trpo(args=get_args()):
if args.lr_decay:
# decay learning rate to 0 linearly
max_update_num = np.ceil(
- args.step_per_epoch / args.step_per_collect) * args.epoch
+ args.step_per_epoch / args.step_per_collect
+ ) * args.epoch
lr_scheduler = LambdaLR(
- optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
+ optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
+ )
def dist(*logits):
return Independent(Normal(*logits), 1)
- policy = TRPOPolicy(actor, critic, optim, dist, discount_factor=args.gamma,
- gae_lambda=args.gae_lambda,
- reward_normalization=args.rew_norm, action_scaling=True,
- action_bound_method=args.bound_action_method,
- lr_scheduler=lr_scheduler, action_space=env.action_space,
- advantage_normalization=args.norm_adv,
- optim_critic_iters=args.optim_critic_iters,
- max_kl=args.max_kl,
- backtrack_coeff=args.backtrack_coeff,
- max_backtracks=args.max_backtracks)
+ policy = TRPOPolicy(
+ actor,
+ critic,
+ optim,
+ dist,
+ discount_factor=args.gamma,
+ gae_lambda=args.gae_lambda,
+ reward_normalization=args.rew_norm,
+ action_scaling=True,
+ action_bound_method=args.bound_action_method,
+ lr_scheduler=lr_scheduler,
+ action_space=env.action_space,
+ advantage_normalization=args.norm_adv,
+ optim_critic_iters=args.optim_critic_iters,
+ max_kl=args.max_kl,
+ backtrack_coeff=args.backtrack_coeff,
+ max_backtracks=args.max_backtracks
+ )
# load a previous policy
if args.resume_path:
@@ -155,10 +186,19 @@ def save_fn(policy):
if not args.watch:
# trainer
result = onpolicy_trainer(
- policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
- args.repeat_per_collect, args.test_num, args.batch_size,
- step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
- test_in_train=False)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.repeat_per_collect,
+ args.test_num,
+ args.batch_size,
+ step_per_collect=args.step_per_collect,
+ save_fn=save_fn,
+ logger=logger,
+ test_in_train=False
+ )
pprint.pprint(result)
# Let's watch its performance!
diff --git a/examples/mujoco/plotter.py b/examples/mujoco/plotter.py
index 4ecd530c7..e3e7057e4 100755
--- a/examples/mujoco/plotter.py
+++ b/examples/mujoco/plotter.py
@@ -1,13 +1,13 @@
#!/usr/bin/env python3
-import re
-import os
import argparse
-import numpy as np
+import os
+import re
+
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
-
-from tools import find_all_files, group_files, csv2numpy
+import numpy as np
+from tools import csv2numpy, find_all_files, group_files
def smooth(y, radius, mode='two_sided', valid_only=False):
@@ -38,28 +38,49 @@ def smooth(y, radius, mode='two_sided', valid_only=False):
return out
-COLORS = ([
- # deepmind style
- '#0072B2',
- '#009E73',
- '#D55E00',
- '#CC79A7',
- # '#F0E442',
- '#d73027', # RED
- # built-in color
- 'blue', 'red', 'pink', 'cyan', 'magenta', 'yellow', 'black', 'purple',
- 'brown', 'orange', 'teal', 'lightblue', 'lime', 'lavender', 'turquoise',
- 'darkgreen', 'tan', 'salmon', 'gold', 'darkred', 'darkblue', 'green',
- # personal color
- '#313695', # DARK BLUE
- '#74add1', # LIGHT BLUE
- '#f46d43', # ORANGE
- '#4daf4a', # GREEN
- '#984ea3', # PURPLE
- '#f781bf', # PINK
- '#ffc832', # YELLOW
- '#000000', # BLACK
-])
+COLORS = (
+ [
+ # deepmind style
+ '#0072B2',
+ '#009E73',
+ '#D55E00',
+ '#CC79A7',
+ # '#F0E442',
+ '#d73027', # RED
+ # built-in color
+ 'blue',
+ 'red',
+ 'pink',
+ 'cyan',
+ 'magenta',
+ 'yellow',
+ 'black',
+ 'purple',
+ 'brown',
+ 'orange',
+ 'teal',
+ 'lightblue',
+ 'lime',
+ 'lavender',
+ 'turquoise',
+ 'darkgreen',
+ 'tan',
+ 'salmon',
+ 'gold',
+ 'darkred',
+ 'darkblue',
+ 'green',
+ # personal color
+ '#313695', # DARK BLUE
+ '#74add1', # LIGHT BLUE
+ '#f46d43', # ORANGE
+ '#4daf4a', # GREEN
+ '#984ea3', # PURPLE
+ '#f781bf', # PINK
+ '#ffc832', # YELLOW
+ '#000000', # BLACK
+ ]
+)
def plot_ax(
@@ -76,6 +97,7 @@ def plot_ax(
shaded_std=True,
legend_outside=False,
):
+
def legend_fn(x):
# return os.path.split(os.path.join(
# args.root_dir, x))[0].replace('/', '_') + " (10)"
@@ -96,8 +118,11 @@ def legend_fn(x):
y_shaded = smooth(csv_dict[ykey + ':shaded'], radius=smooth_radius)
ax.fill_between(x, y - y_shaded, y + y_shaded, color=color, alpha=.2)
- ax.legend(legneds, loc=2 if legend_outside else None,
- bbox_to_anchor=(1, 1) if legend_outside else None)
+ ax.legend(
+ legneds,
+ loc=2 if legend_outside else None,
+ bbox_to_anchor=(1, 1) if legend_outside else None
+ )
ax.xaxis.set_major_formatter(mticker.EngFormatter())
if xlim is not None:
ax.set_xlim(xmin=0, xmax=xlim)
@@ -127,8 +152,14 @@ def plot_figure(
res = group_files(file_lists, group_pattern)
row_n = int(np.ceil(len(res) / 3))
col_n = min(len(res), 3)
- fig, axes = plt.subplots(row_n, col_n, sharex=sharex, sharey=sharey, figsize=(
- fig_length * col_n, fig_width * row_n), squeeze=False)
+ fig, axes = plt.subplots(
+ row_n,
+ col_n,
+ sharex=sharex,
+ sharey=sharey,
+ figsize=(fig_length * col_n, fig_width * row_n),
+ squeeze=False
+ )
axes = axes.flatten()
for i, (k, v) in enumerate(res.items()):
plot_ax(axes[i], v, title=k, **kwargs)
@@ -138,53 +169,95 @@ def plot_figure(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='plotter')
- parser.add_argument('--fig-length', type=int, default=6,
- help='matplotlib figure length (default: 6)')
- parser.add_argument('--fig-width', type=int, default=6,
- help='matplotlib figure width (default: 6)')
- parser.add_argument('--style', default='seaborn',
- help='matplotlib figure style (default: seaborn)')
- parser.add_argument('--title', default=None,
- help='matplotlib figure title (default: None)')
- parser.add_argument('--xkey', default='env_step',
- help='x-axis key in csv file (default: env_step)')
- parser.add_argument('--ykey', default='rew',
- help='y-axis key in csv file (default: rew)')
- parser.add_argument('--smooth', type=int, default=0,
- help='smooth radius of y axis (default: 0)')
- parser.add_argument('--xlabel', default='Timesteps',
- help='matplotlib figure xlabel')
- parser.add_argument('--ylabel', default='Episode Reward',
- help='matplotlib figure ylabel')
- parser.add_argument(
- '--shaded-std', action='store_true',
- help='shaded region corresponding to standard deviation of the group')
- parser.add_argument('--sharex', action='store_true',
- help='whether to share x axis within multiple sub-figures')
- parser.add_argument('--sharey', action='store_true',
- help='whether to share y axis within multiple sub-figures')
- parser.add_argument('--legend-outside', action='store_true',
- help='place the legend outside of the figure')
- parser.add_argument('--xlim', type=int, default=None,
- help='x-axis limitation (default: None)')
+ parser.add_argument(
+ '--fig-length',
+ type=int,
+ default=6,
+ help='matplotlib figure length (default: 6)'
+ )
+ parser.add_argument(
+ '--fig-width',
+ type=int,
+ default=6,
+ help='matplotlib figure width (default: 6)'
+ )
+ parser.add_argument(
+ '--style',
+ default='seaborn',
+ help='matplotlib figure style (default: seaborn)'
+ )
+ parser.add_argument(
+ '--title', default=None, help='matplotlib figure title (default: None)'
+ )
+ parser.add_argument(
+ '--xkey',
+ default='env_step',
+ help='x-axis key in csv file (default: env_step)'
+ )
+ parser.add_argument(
+ '--ykey', default='rew', help='y-axis key in csv file (default: rew)'
+ )
+ parser.add_argument(
+ '--smooth', type=int, default=0, help='smooth radius of y axis (default: 0)'
+ )
+ parser.add_argument(
+ '--xlabel', default='Timesteps', help='matplotlib figure xlabel'
+ )
+ parser.add_argument(
+ '--ylabel', default='Episode Reward', help='matplotlib figure ylabel'
+ )
+ parser.add_argument(
+ '--shaded-std',
+ action='store_true',
+ help='shaded region corresponding to standard deviation of the group'
+ )
+ parser.add_argument(
+ '--sharex',
+ action='store_true',
+ help='whether to share x axis within multiple sub-figures'
+ )
+ parser.add_argument(
+ '--sharey',
+ action='store_true',
+ help='whether to share y axis within multiple sub-figures'
+ )
+ parser.add_argument(
+ '--legend-outside',
+ action='store_true',
+ help='place the legend outside of the figure'
+ )
+ parser.add_argument(
+ '--xlim', type=int, default=None, help='x-axis limitation (default: None)'
+ )
parser.add_argument('--root-dir', default='./', help='root dir (default: ./)')
parser.add_argument(
- '--file-pattern', type=str, default=r".*/test_rew_\d+seeds.csv$",
+ '--file-pattern',
+ type=str,
+ default=r".*/test_rew_\d+seeds.csv$",
help='regular expression to determine whether or not to include target csv '
- 'file, default to including all test_rew_{num}seeds.csv file under rootdir')
+ 'file, default to including all test_rew_{num}seeds.csv file under rootdir'
+ )
parser.add_argument(
- '--group-pattern', type=str, default=r"(/|^)\w*?\-v(\d|$)",
+ '--group-pattern',
+ type=str,
+ default=r"(/|^)\w*?\-v(\d|$)",
help='regular expression to group files in sub-figure, default to grouping '
- 'according to env_name dir, "" means no grouping')
+ 'according to env_name dir, "" means no grouping'
+ )
parser.add_argument(
- '--legend-pattern', type=str, default=r".*",
+ '--legend-pattern',
+ type=str,
+ default=r".*",
help='regular expression to extract legend from csv file path, default to '
- 'using file path as legend name.')
+ 'using file path as legend name.'
+ )
parser.add_argument('--show', action='store_true', help='show figure')
- parser.add_argument('--output-path', type=str,
- help='figure save path', default="./figure.png")
- parser.add_argument('--dpi', type=int, default=200,
- help='figure dpi (default: 200)')
+ parser.add_argument(
+ '--output-path', type=str, help='figure save path', default="./figure.png"
+ )
+ parser.add_argument(
+ '--dpi', type=int, default=200, help='figure dpi (default: 200)'
+ )
args = parser.parse_args()
file_lists = find_all_files(args.root_dir, re.compile(args.file_pattern))
file_lists = [os.path.relpath(f, args.root_dir) for f in file_lists]
@@ -207,9 +280,9 @@ def plot_figure(
sharey=args.sharey,
smooth_radius=args.smooth,
shaded_std=args.shaded_std,
- legend_outside=args.legend_outside)
+ legend_outside=args.legend_outside
+ )
if args.output_path:
- plt.savefig(args.output_path,
- dpi=args.dpi, bbox_inches='tight')
+ plt.savefig(args.output_path, dpi=args.dpi, bbox_inches='tight')
if args.show:
plt.show()
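plotter.py's `smooth(y, radius, ...)` (only its signature and reformatted call sites appear in this hunk) is a windowed moving average applied to the reward curve before plotting. A numpy-only sketch of the same idea, which approximates the behaviour conceptually rather than reproducing the file's exact implementation:

```python
# Conceptual stand-in for plotter.smooth: a two-sided moving average with
# edge-aware normalisation. Input values are arbitrary.
import numpy as np

def moving_average(y, radius):
    if radius <= 0:
        return np.asarray(y, dtype=float)
    y = np.asarray(y, dtype=float)
    kernel = np.ones(2 * radius + 1)
    # divide by the number of points actually covered near the edges
    return np.convolve(y, kernel, mode="same") / np.convolve(
        np.ones_like(y), kernel, mode="same"
    )

rewards = np.array([0., 10., 5., 20., 15., 30.])
print(moving_average(rewards, radius=1))
```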
diff --git a/examples/mujoco/tools.py b/examples/mujoco/tools.py
index 9e49206e0..3ed4791fd 100755
--- a/examples/mujoco/tools.py
+++ b/examples/mujoco/tools.py
@@ -1,12 +1,13 @@
#!/usr/bin/env python3
+import argparse
+import csv
import os
import re
-import csv
-import tqdm
-import argparse
-import numpy as np
from collections import defaultdict
+
+import numpy as np
+import tqdm
from tensorboard.backend.event_processing import event_accumulator
@@ -66,11 +67,13 @@ def convert_tfevents_to_csv(root_dir, refresh=False):
initial_time = ea._first_event_timestamp
content = [["env_step", "rew", "time"]]
for test_rew in ea.scalars.Items("test/rew"):
- content.append([
- round(test_rew.step, 4),
- round(test_rew.value, 4),
- round(test_rew.wall_time - initial_time, 4),
- ])
+ content.append(
+ [
+ round(test_rew.step, 4),
+ round(test_rew.value, 4),
+ round(test_rew.wall_time - initial_time, 4),
+ ]
+ )
csv.writer(open(output_file, 'w')).writerows(content)
result[output_file] = content
return result
@@ -80,13 +83,15 @@ def merge_csv(csv_files, root_dir, remove_zero=False):
"""Merge result in csv_files into a single csv file."""
assert len(csv_files) > 0
if remove_zero:
- for k, v in csv_files.items():
+ for v in csv_files.values():
if v[1][0] == 0:
v.pop(1)
sorted_keys = sorted(csv_files.keys())
sorted_values = [csv_files[k][1:] for k in sorted_keys]
- content = [["env_step", "rew", "rew:shaded"] + list(map(
- lambda f: "rew:" + os.path.relpath(f, root_dir), sorted_keys))]
+ content = [
+ ["env_step", "rew", "rew:shaded"] +
+ list(map(lambda f: "rew:" + os.path.relpath(f, root_dir), sorted_keys))
+ ]
for rows in zip(*sorted_values):
array = np.array(rows)
assert len(set(array[:, 0])) == 1, (set(array[:, 0]), array[:, 0])
@@ -101,11 +106,15 @@ def merge_csv(csv_files, root_dir, remove_zero=False):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
- '--refresh', action="store_true",
- help="Re-generate all csv files instead of using existing one.")
+ '--refresh',
+ action="store_true",
+ help="Re-generate all csv files instead of using existing one."
+ )
parser.add_argument(
- '--remove-zero', action="store_true",
- help="Remove the data point of env_step == 0.")
+ '--remove-zero',
+ action="store_true",
+ help="Remove the data point of env_step == 0."
+ )
parser.add_argument('--root-dir', type=str)
args = parser.parse_args()
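tools.py reads TensorBoard event files through `event_accumulator` and, as the hunk above shows, copies the `test/rew` scalars into CSV rows. A minimal sketch of the underlying API for reference; the log directory is a placeholder, not a path from this repository:

```python
# Minimal EventAccumulator usage of the kind convert_tfevents_to_csv builds on.
# "log/HalfCheetah-v3/ppo" is a hypothetical directory.
from tensorboard.backend.event_processing import event_accumulator

ea = event_accumulator.EventAccumulator("log/HalfCheetah-v3/ppo")
ea.Reload()  # parse whatever tfevents files exist in that directory
if "test/rew" in ea.scalars.Keys():
    for event in ea.scalars.Items("test/rew"):
        print(event.step, event.value, event.wall_time)
```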
diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py
index 017ab7750..290cb92e5 100644
--- a/examples/vizdoom/env.py
+++ b/examples/vizdoom/env.py
@@ -1,4 +1,5 @@
import os
+
import cv2
import gym
import numpy as np
@@ -33,9 +34,8 @@ def battle_button_comb():
class Env(gym.Env):
- def __init__(
- self, cfg_path, frameskip=4, res=(4, 40, 60), save_lmp=False
- ):
+
+ def __init__(self, cfg_path, frameskip=4, res=(4, 40, 60), save_lmp=False):
super().__init__()
self.save_lmp = save_lmp
self.health_setting = "battle" in cfg_path
@@ -75,8 +75,7 @@ def reset(self):
self.obs_buffer = np.zeros(self.res, dtype=np.uint8)
self.get_obs()
self.health = self.game.get_game_variable(vzd.GameVariable.HEALTH)
- self.killcount = self.game.get_game_variable(
- vzd.GameVariable.KILLCOUNT)
+ self.killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT)
self.ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2)
return self.obs_buffer
@@ -121,7 +120,7 @@ def close(self):
obs = env.reset()
print(env.spec.reward_threshold)
print(obs.shape, action_num)
- for i in range(4000):
+ for _ in range(4000):
obs, rew, done, info = env.step(0)
if done:
env.reset()
diff --git a/examples/vizdoom/maps/spectator.py b/examples/vizdoom/maps/spectator.py
index d4d7e8c7c..2180ed7c7 100644
--- a/examples/vizdoom/maps/spectator.py
+++ b/examples/vizdoom/maps/spectator.py
@@ -10,11 +10,12 @@
from __future__ import print_function
+from argparse import ArgumentParser
from time import sleep
+
import vizdoom as vzd
-from argparse import ArgumentParser
-# import cv2
+# import cv2
if __name__ == "__main__":
parser = ArgumentParser("ViZDoom example showing how to use SPECTATOR mode.")
diff --git a/examples/vizdoom/replay.py b/examples/vizdoom/replay.py
index a1e556fce..30cdb31bd 100755
--- a/examples/vizdoom/replay.py
+++ b/examples/vizdoom/replay.py
@@ -1,6 +1,7 @@
# import cv2
import sys
import time
+
import tqdm
import vizdoom as vzd
diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py
index 1123151c8..bb3a1f207 100644
--- a/examples/vizdoom/vizdoom_c51.py
+++ b/examples/vizdoom/vizdoom_c51.py
@@ -1,18 +1,18 @@
+import argparse
import os
-import torch
import pprint
-import argparse
+
import numpy as np
+import torch
+from env import Env
+from network import C51
from torch.utils.tensorboard import SummaryWriter
-from tianshou.policy import C51Policy
-from tianshou.utils import TensorboardLogger
+from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
+from tianshou.policy import C51Policy
from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
-
-from env import Env
-from network import C51
+from tianshou.utils import TensorboardLogger
def get_args():
@@ -40,15 +40,23 @@ def get_args():
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--skip-num', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
- parser.add_argument('--watch', default=False, action='store_true',
- help='watch the play of pre-trained policy only')
- parser.add_argument('--save-lmp', default=False, action='store_true',
- help='save lmp file for replay whole episode')
+ parser.add_argument(
+ '--watch',
+ default=False,
+ action='store_true',
+ help='watch the play of pre-trained policy only'
+ )
+ parser.add_argument(
+ '--save-lmp',
+ default=False,
+ action='store_true',
+ help='save lmp file for replay whole episode'
+ )
parser.add_argument('--save-buffer-name', type=str, default=None)
return parser.parse_args()
@@ -64,26 +72,36 @@ def test_c51(args=get_args()):
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
- train_envs = SubprocVectorEnv([
- lambda: Env(args.cfg_path, args.frames_stack, args.res)
- for _ in range(args.training_num)])
- test_envs = SubprocVectorEnv([
- lambda: Env(args.cfg_path, args.frames_stack,
- args.res, args.save_lmp)
- for _ in range(min(os.cpu_count() - 1, args.test_num))])
+ train_envs = SubprocVectorEnv(
+ [
+ lambda: Env(args.cfg_path, args.frames_stack, args.res)
+ for _ in range(args.training_num)
+ ]
+ )
+ test_envs = SubprocVectorEnv(
+ [
+ lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp)
+ for _ in range(min(os.cpu_count() - 1, args.test_num))
+ ]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
- net = C51(*args.state_shape, args.action_shape,
- args.num_atoms, args.device)
+ net = C51(*args.state_shape, args.action_shape, args.num_atoms, args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy = C51Policy(
- net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
- args.n_step, target_update_freq=args.target_update_freq
+ net,
+ optim,
+ args.gamma,
+ args.num_atoms,
+ args.v_min,
+ args.v_max,
+ args.n_step,
+ target_update_freq=args.target_update_freq
).to(args.device)
# load a previous policy
if args.resume_path:
@@ -92,8 +110,12 @@ def test_c51(args=get_args()):
    # replay buffer: `save_only_last_obs` and `stack_num` can be removed together
# when you have enough RAM
buffer = VectorReplayBuffer(
- args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True,
- save_only_last_obs=True, stack_num=args.frames_stack)
+ args.buffer_size,
+ buffer_num=len(train_envs),
+ ignore_obs_next=True,
+ save_only_last_obs=True,
+ stack_num=args.frames_stack
+ )
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
@@ -136,11 +158,13 @@ def watch():
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
buffer = VectorReplayBuffer(
- args.buffer_size, buffer_num=len(test_envs),
- ignore_obs_next=True, save_only_last_obs=True,
- stack_num=args.frames_stack)
- collector = Collector(policy, test_envs, buffer,
- exploration_noise=True)
+ args.buffer_size,
+ buffer_num=len(test_envs),
+ ignore_obs_next=True,
+ save_only_last_obs=True,
+ stack_num=args.frames_stack
+ )
+ collector = Collector(policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
@@ -148,8 +172,9 @@ def watch():
else:
print("Testing agent ...")
test_collector.reset()
- result = test_collector.collect(n_episode=args.test_num,
- render=args.render)
+ result = test_collector.collect(
+ n_episode=args.test_num, render=args.render
+ )
rew = result["rews"].mean()
lens = result["lens"].mean() * args.skip_num
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
@@ -163,11 +188,22 @@ def watch():
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, train_fn=train_fn, test_fn=test_fn,
- stop_fn=stop_fn, save_fn=save_fn, logger=logger,
- update_per_step=args.update_per_step, test_in_train=False)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ train_fn=train_fn,
+ test_fn=test_fn,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger,
+ update_per_step=args.update_per_step,
+ test_in_train=False
+ )
pprint.pprint(result)
watch()
diff --git a/setup.cfg b/setup.cfg
index d485e6d06..7bb1f1e7c 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -8,6 +8,18 @@ exclude =
dist
*.egg-info
max-line-length = 87
+ignore = B305,W504,B006,B008
+
+[yapf]
+based_on_style = pep8
+dedent_closing_brackets = true
+column_limit = 87
+blank_line_before_nested_class_or_def = true
+
+[isort]
+profile = black
+multi_line_output = 3
+line_length = 87
[mypy]
files = tianshou/**/*.py
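The `[yapf]` and `[isort]` sections added to setup.cfg are what drive the reformatting seen throughout this diff: with `column_limit = 87` and `dedent_closing_brackets = true`, call sites that no longer fit on one line are split across lines (often one argument per line) and the closing bracket is dedented back to the opening line's indentation. A self-contained sketch of the resulting layout; `configure` and its arguments are placeholders, not symbols from the codebase:

```python
# Placeholder function used only to show the bracket layout yapf produces here.
def configure(name, hidden_sizes, activation, device, unbounded=False):
    return dict(
        name=name,
        hidden_sizes=hidden_sizes,
        activation=activation,
        device=device,
        unbounded=unbounded,
    )

# This call exceeds 87 characters on a single line, so the configured style
# breaks it up and dedents the closing parenthesis:
model = configure(
    name="actor",
    hidden_sizes=[64, 64],
    activation="tanh",
    device="cpu",
    unbounded=True
)
print(model["hidden_sizes"])
```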
diff --git a/setup.py b/setup.py
index 208af8e14..bf48020e5 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,8 @@
# -*- coding: utf-8 -*-
import os
-from setuptools import setup, find_packages
+
+from setuptools import find_packages, setup
def get_version() -> str:
@@ -60,6 +61,9 @@ def get_version() -> str:
"sphinx_rtd_theme",
"sphinxcontrib-bibtex",
"flake8",
+ "flake8-bugbear",
+ "yapf",
+ "isort",
"pytest",
"pytest-cov",
"ray>=1.0.0",
diff --git a/test/base/env.py b/test/base/env.py
index 1151f5b76..cdcb51efa 100644
--- a/test/base/env.py
+++ b/test/base/env.py
@@ -1,19 +1,28 @@
-import gym
-import time
import random
-import numpy as np
-import networkx as nx
+import time
from copy import deepcopy
-from gym.spaces import Discrete, MultiDiscrete, Box, Dict, Tuple
+
+import gym
+import networkx as nx
+import numpy as np
+from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
class MyTestEnv(gym.Env):
"""This is a "going right" task. The task is to go right ``size`` steps.
"""
- def __init__(self, size, sleep=0, dict_state=False, recurse_state=False,
- ma_rew=0, multidiscrete_action=False, random_sleep=False,
- array_state=False):
+ def __init__(
+ self,
+ size,
+ sleep=0,
+ dict_state=False,
+ recurse_state=False,
+ ma_rew=0,
+ multidiscrete_action=False,
+ random_sleep=False,
+ array_state=False
+ ):
assert dict_state + recurse_state + array_state <= 1, \
"dict_state / recurse_state / array_state can be only one true"
self.size = size
@@ -28,17 +37,32 @@ def __init__(self, size, sleep=0, dict_state=False, recurse_state=False,
self.steps = 0
if dict_state:
self.observation_space = Dict(
- {"index": Box(shape=(1, ), low=0, high=size - 1),
- "rand": Box(shape=(1,), low=0, high=1, dtype=np.float64)})
+ {
+ "index": Box(shape=(1, ), low=0, high=size - 1),
+ "rand": Box(shape=(1, ), low=0, high=1, dtype=np.float64)
+ }
+ )
elif recurse_state:
self.observation_space = Dict(
- {"index": Box(shape=(1, ), low=0, high=size - 1),
- "dict": Dict({
- "tuple": Tuple((Discrete(2), Box(shape=(2,),
- low=0, high=1, dtype=np.float64))),
- "rand": Box(shape=(1, 2), low=0, high=1,
- dtype=np.float64)})
- })
+ {
+ "index":
+ Box(shape=(1, ), low=0, high=size - 1),
+ "dict":
+ Dict(
+ {
+ "tuple":
+ Tuple(
+ (
+ Discrete(2),
+ Box(shape=(2, ), low=0, high=1, dtype=np.float64)
+ )
+ ),
+ "rand":
+ Box(shape=(1, 2), low=0, high=1, dtype=np.float64)
+ }
+ )
+ }
+ )
elif array_state:
self.observation_space = Box(shape=(4, 84, 84), low=0, high=255)
else:
@@ -70,13 +94,18 @@ def _get_reward(self):
def _get_state(self):
"""Generate state(observation) of MyTestEnv"""
if self.dict_state:
- return {'index': np.array([self.index], dtype=np.float32),
- 'rand': self.rng.rand(1)}
+ return {
+ 'index': np.array([self.index], dtype=np.float32),
+ 'rand': self.rng.rand(1)
+ }
elif self.recurse_state:
- return {'index': np.array([self.index], dtype=np.float32),
- 'dict': {"tuple": (np.array([1],
- dtype=int), self.rng.rand(2)),
- "rand": self.rng.rand(1, 2)}}
+ return {
+ 'index': np.array([self.index], dtype=np.float32),
+ 'dict': {
+ "tuple": (np.array([1], dtype=int), self.rng.rand(2)),
+ "rand": self.rng.rand(1, 2)
+ }
+ }
elif self.array_state:
img = np.zeros([4, 84, 84], int)
img[3, np.arange(84), np.arange(84)] = self.index
@@ -112,6 +141,7 @@ def step(self, action):
class NXEnv(gym.Env):
+
def __init__(self, size, obs_type, feat_dim=32):
self.size = size
self.feat_dim = feat_dim
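The MyTestEnv hunk above mostly re-nests the construction of its `Dict`/`Tuple` observation space. For readers unfamiliar with composite gym spaces, a small sketch of how such a space behaves; the bounds and shapes here are invented rather than copied from the test env:

```python
# Demo of a nested observation space similar in shape to MyTestEnv's; the
# specific bounds, shapes and dtypes are placeholders.
import numpy as np
from gym.spaces import Box, Dict, Discrete, Tuple

space = Dict(
    {
        "index": Box(low=0, high=9, shape=(1, ), dtype=np.float32),
        "dict": Dict(
            {
                "tuple": Tuple(
                    (Discrete(2), Box(low=0, high=1, shape=(2, ), dtype=np.float64))
                ),
                "rand": Box(low=0, high=1, shape=(1, 2), dtype=np.float64),
            }
        ),
    }
)
sample = space.sample()        # an OrderedDict mirroring the nesting above
print(space.contains(sample))  # True
```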
diff --git a/test/base/test_batch.py b/test/base/test_batch.py
index 15357f16c..53ee8ffa3 100644
--- a/test/base/test_batch.py
+++ b/test/base/test_batch.py
@@ -1,13 +1,14 @@
-import sys
import copy
-import torch
import pickle
-import pytest
-import numpy as np
-import networkx as nx
+import sys
from itertools import starmap
-from tianshou.data import Batch, to_torch, to_numpy
+import networkx as nx
+import numpy as np
+import pytest
+import torch
+
+from tianshou.data import Batch, to_numpy, to_torch
def test_batch():
@@ -99,10 +100,13 @@ def test_batch():
assert batch_item.a.c == batch_dict['c']
assert isinstance(batch_item.a.d, torch.Tensor)
assert batch_item.a.d == batch_dict['d']
- batch2 = Batch(a=[{
- 'b': np.float64(1.0),
- 'c': np.zeros(1),
- 'd': Batch(e=np.array(3.0))}])
+ batch2 = Batch(
+ a=[{
+ 'b': np.float64(1.0),
+ 'c': np.zeros(1),
+ 'd': Batch(e=np.array(3.0))
+ }]
+ )
assert len(batch2) == 1
assert Batch().shape == []
assert Batch(a=1).shape == []
@@ -141,9 +145,12 @@ def test_batch():
assert batch2_sum.a.d.f.is_empty()
with pytest.raises(TypeError):
batch2 += [1]
- batch3 = Batch(a={
- 'c': np.zeros(1),
- 'd': Batch(e=np.array([0.0]), f=np.array([3.0]))})
+ batch3 = Batch(
+ a={
+ 'c': np.zeros(1),
+ 'd': Batch(e=np.array([0.0]), f=np.array([3.0]))
+ }
+ )
batch3.a.d[0] = {'e': 4.0}
assert batch3.a.d.e[0] == 4.0
batch3.a.d[0] = Batch(f=5.0)
@@ -202,7 +209,7 @@ def test_batch_over_batch():
assert np.allclose(batch3.c, [6, 7, 8, 6, 7, 8])
assert np.allclose(batch3.b.a, [3, 4, 5, 3, 4, 5])
assert np.allclose(batch3.b.b, [4, 5, 6, 4, 5, 6])
- batch4 = Batch(({'a': {'b': np.array([1.0])}},))
+ batch4 = Batch(({'a': {'b': np.array([1.0])}}, ))
assert batch4.a.b.ndim == 2
assert batch4.a.b[0, 0] == 1.0
# advanced slicing
@@ -239,14 +246,14 @@ def test_batch_cat_and_stack():
a = Batch(a=Batch(a=np.random.randn(3, 4)))
assert np.allclose(
np.concatenate([a.a.a, a.a.a]),
- Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a)
+ Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a
+ )
# test cat with lens infer
a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4))
b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4))
ans = Batch.cat([a, b, a])
- assert np.allclose(ans.a.a,
- np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a]))
+ assert np.allclose(ans.a.a, np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a]))
assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b]))
assert ans.a.t.is_empty()
@@ -258,51 +265,61 @@ def test_batch_cat_and_stack():
b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
test = Batch.cat([b1, b2])
- ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
- b=torch.cat([torch.zeros(3, 3), b2.b]),
- common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
+ ans = Batch(
+ a=np.concatenate([b1.a, np.zeros((4, 4))]),
+ b=torch.cat([torch.zeros(3, 3), b2.b]),
+ common=Batch(c=np.concatenate([b1.common.c, b2.common.c]))
+ )
assert np.allclose(test.a, ans.a)
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)
# test cat with reserved keys (values are Batch())
b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
- b2 = Batch(a=Batch(),
- b=torch.rand(4, 3),
- common=Batch(c=np.random.rand(4, 5)))
+ b2 = Batch(a=Batch(), b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
test = Batch.cat([b1, b2])
- ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
- b=torch.cat([torch.zeros(3, 3), b2.b]),
- common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
+ ans = Batch(
+ a=np.concatenate([b1.a, np.zeros((4, 4))]),
+ b=torch.cat([torch.zeros(3, 3), b2.b]),
+ common=Batch(c=np.concatenate([b1.common.c, b2.common.c]))
+ )
assert np.allclose(test.a, ans.a)
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)
# test cat with all reserved keys (values are Batch())
b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5)))
- b2 = Batch(a=Batch(),
- b=torch.rand(4, 3),
- common=Batch(c=np.random.rand(4, 5)))
+ b2 = Batch(a=Batch(), b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
test = Batch.cat([b1, b2])
- ans = Batch(a=Batch(),
- b=torch.cat([torch.zeros(3, 3), b2.b]),
- common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
+ ans = Batch(
+ a=Batch(),
+ b=torch.cat([torch.zeros(3, 3), b2.b]),
+ common=Batch(c=np.concatenate([b1.common.c, b2.common.c]))
+ )
assert ans.a.is_empty()
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)
# test stack with compatible keys
- b3 = Batch(a=np.zeros((3, 4)),
- b=torch.ones((2, 5)),
- c=Batch(d=[[1], [2]]))
- b4 = Batch(a=np.ones((3, 4)),
- b=torch.ones((2, 5)),
- c=Batch(d=[[0], [3]]))
+ b3 = Batch(a=np.zeros((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[1], [2]]))
+ b4 = Batch(a=np.ones((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[0], [3]]))
b34_stack = Batch.stack((b3, b4), axis=1)
assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))
assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d))))
- b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
- {'a': True, 'b': {'c': 3.0}}])
+ b5_dict = np.array(
+ [{
+ 'a': False,
+ 'b': {
+ 'c': 2.0,
+ 'd': 1.0
+ }
+ }, {
+ 'a': True,
+ 'b': {
+ 'c': 3.0
+ }
+ }]
+ )
b5 = Batch(b5_dict)
assert b5.a[0] == np.array(False) and b5.a[1] == np.array(True)
assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0))
@@ -335,15 +352,16 @@ def test_batch_cat_and_stack():
test = Batch.stack([b1, b2], axis=-1)
assert test.a.is_empty()
assert test.b.is_empty()
- assert np.allclose(test.common.c,
- np.stack([b1.common.c, b2.common.c], axis=-1))
+ assert np.allclose(test.common.c, np.stack([b1.common.c, b2.common.c], axis=-1))
b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
test = Batch.stack([b1, b2])
- ans = Batch(a=np.stack([b1.a, np.zeros((4, 4))]),
- b=torch.stack([torch.zeros(4, 6), b2.b]),
- common=Batch(c=np.stack([b1.common.c, b2.common.c])))
+ ans = Batch(
+ a=np.stack([b1.a, np.zeros((4, 4))]),
+ b=torch.stack([torch.zeros(4, 6), b2.b]),
+ common=Batch(c=np.stack([b1.common.c, b2.common.c]))
+ )
assert np.allclose(test.a, ans.a)
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)
@@ -369,8 +387,8 @@ def test_batch_over_batch_to_torch():
batch = Batch(
a=np.float64(1.0),
b=Batch(
- c=np.ones((1,), dtype=np.float32),
- d=torch.ones((1,), dtype=torch.float64)
+ c=np.ones((1, ), dtype=np.float32),
+ d=torch.ones((1, ), dtype=torch.float64)
)
)
batch.b.__dict__['e'] = 1 # bypass the check
@@ -397,8 +415,8 @@ def test_utils_to_torch_numpy():
batch = Batch(
a=np.float64(1.0),
b=Batch(
- c=np.ones((1,), dtype=np.float32),
- d=torch.ones((1,), dtype=torch.float64)
+ c=np.ones((1, ), dtype=np.float32),
+ d=torch.ones((1, ), dtype=torch.float64)
)
)
a_torch_float = to_torch(batch.a, dtype=torch.float32)
@@ -464,8 +482,7 @@ def test_utils_to_torch_numpy():
def test_batch_pickle():
- batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])),
- np=np.zeros([3, 4]))
+ batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4]))
batch_pk = pickle.loads(pickle.dumps(batch))
assert batch.obs.a == batch_pk.obs.a
assert torch.all(batch.obs.c == batch_pk.obs.c)
@@ -473,7 +490,7 @@ def test_batch_pickle():
def test_batch_from_to_numpy_without_copy():
- batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
+ batch = Batch(a=np.ones((1, )), b=Batch(c=np.ones((1, ))))
a_mem_addr_orig = batch.a.__array_interface__['data'][0]
c_mem_addr_orig = batch.b.c.__array_interface__['data'][0]
batch.to_torch()
@@ -517,19 +534,35 @@ def test_batch_copy():
def test_batch_empty():
- b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
- {'a': True, 'b': {'c': 3.0}}])
+ b5_dict = np.array(
+ [{
+ 'a': False,
+ 'b': {
+ 'c': 2.0,
+ 'd': 1.0
+ }
+ }, {
+ 'a': True,
+ 'b': {
+ 'c': 3.0
+ }
+ }]
+ )
b5 = Batch(b5_dict)
b5[1] = Batch.empty(b5[0])
assert np.allclose(b5.a, [False, False])
assert np.allclose(b5.b.c, [2, 0])
assert np.allclose(b5.b.d, [1, 0])
- data = Batch(a=[False, True],
- b={'c': np.array([2., 'st'], dtype=object),
- 'd': [1, None],
- 'e': [2., float('nan')]},
- c=np.array([1, 3, 4], dtype=int),
- t=torch.tensor([4, 5, 6, 7.]))
+ data = Batch(
+ a=[False, True],
+ b={
+ 'c': np.array([2., 'st'], dtype=object),
+ 'd': [1, None],
+ 'e': [2., float('nan')]
+ },
+ c=np.array([1, 3, 4], dtype=int),
+ t=torch.tensor([4, 5, 6, 7.])
+ )
data[-1] = Batch.empty(data[1])
assert np.allclose(data.c, [1, 3, 0])
assert np.allclose(data.a, [False, False])
@@ -550,9 +583,9 @@ def test_batch_empty():
def test_batch_standard_compatibility():
- batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]),
- b=Batch(),
- c=np.array([5.0, 6.0]))
+ batch = Batch(
+ a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=Batch(), c=np.array([5.0, 6.0])
+ )
batch_mean = np.mean(batch)
assert isinstance(batch_mean, Batch)
assert sorted(batch_mean.keys()) == ['a', 'b', 'c']
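test_batch.py exercises Tianshou's `Batch` container heavily; the operations the reformatted assertions rely on boil down to attribute access, indexing, and `cat`/`stack`. A tiny refresher with arbitrary values:

```python
# The array contents here are arbitrary; only the Batch operations matter.
import numpy as np
from tianshou.data import Batch

b = Batch(obs=np.arange(4), info={"score": np.zeros(4)})
print(b.obs[1])           # attribute-style access into numpy-backed fields
print(b[2].info.score)    # integer indexing returns a Batch "row"
both = Batch.cat([b, b])  # concatenate along the first axis
print(len(both.obs))      # 8
```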
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index 39e5badd5..c1568c75d 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -1,18 +1,23 @@
import os
-import h5py
-import torch
import pickle
-import pytest
import tempfile
-import numpy as np
from timeit import timeit
-from tianshou.data.utils.converter import to_hdf5
-from tianshou.data import Batch, SegmentTree, ReplayBuffer
-from tianshou.data import PrioritizedReplayBuffer
-from tianshou.data import VectorReplayBuffer, CachedReplayBuffer
-from tianshou.data import PrioritizedVectorReplayBuffer
+import h5py
+import numpy as np
+import pytest
+import torch
+from tianshou.data import (
+ Batch,
+ CachedReplayBuffer,
+ PrioritizedReplayBuffer,
+ PrioritizedVectorReplayBuffer,
+ ReplayBuffer,
+ SegmentTree,
+ VectorReplayBuffer,
+)
+from tianshou.data.utils.converter import to_hdf5
if __name__ == '__main__':
from env import MyTestEnv
@@ -29,8 +34,9 @@ def test_replaybuffer(size=10, bufsize=20):
action_list = [1] * 5 + [0] * 10 + [1] * 10
for i, a in enumerate(action_list):
obs_next, rew, done, info = env.step(a)
- buf.add(Batch(obs=obs, act=[a], rew=rew,
- done=done, obs_next=obs_next, info=info))
+ buf.add(
+ Batch(obs=obs, act=[a], rew=rew, done=done, obs_next=obs_next, info=info)
+ )
obs = obs_next
assert len(buf) == min(bufsize, i + 1)
assert buf.act.dtype == int
@@ -43,8 +49,20 @@ def test_replaybuffer(size=10, bufsize=20):
# neg bsz should return empty index
assert b.sample_indices(-1).tolist() == []
ptr, ep_rew, ep_len, ep_idx = b.add(
- Batch(obs=1, act=1, rew=1, done=1, obs_next='str',
- info={'a': 3, 'b': {'c': 5.0}}))
+ Batch(
+ obs=1,
+ act=1,
+ rew=1,
+ done=1,
+ obs_next='str',
+ info={
+ 'a': 3,
+ 'b': {
+ 'c': 5.0
+ }
+ }
+ )
+ )
assert b.obs[0] == 1
assert b.done[0]
assert b.obs_next[0] == 'str'
@@ -54,13 +72,24 @@ def test_replaybuffer(size=10, bufsize=20):
assert np.all(b.info.a[1:] == 0)
assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == float
assert np.all(b.info.b.c[1:] == 0.0)
- assert ptr.shape == (1,) and ptr[0] == 0
- assert ep_rew.shape == (1,) and ep_rew[0] == 1
- assert ep_len.shape == (1,) and ep_len[0] == 1
- assert ep_idx.shape == (1,) and ep_idx[0] == 0
+ assert ptr.shape == (1, ) and ptr[0] == 0
+ assert ep_rew.shape == (1, ) and ep_rew[0] == 1
+ assert ep_len.shape == (1, ) and ep_len[0] == 1
+ assert ep_idx.shape == (1, ) and ep_idx[0] == 0
# test extra keys pop up, the buffer should handle it dynamically
- batch = Batch(obs=2, act=2, rew=2, done=0, obs_next="str2",
- info={"a": 4, "d": {"e": -np.inf}})
+ batch = Batch(
+ obs=2,
+ act=2,
+ rew=2,
+ done=0,
+ obs_next="str2",
+ info={
+ "a": 4,
+ "d": {
+ "e": -np.inf
+ }
+ }
+ )
b.add(batch)
info_keys = ["a", "b", "d"]
assert set(b.info.keys()) == set(info_keys)
@@ -71,10 +100,10 @@ def test_replaybuffer(size=10, bufsize=20):
batch.info.e = np.zeros([1, 4])
batch = Batch.stack([batch])
ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0])
- assert ptr.shape == (1,) and ptr[0] == 2
- assert ep_rew.shape == (1,) and ep_rew[0] == 4
- assert ep_len.shape == (1,) and ep_len[0] == 2
- assert ep_idx.shape == (1,) and ep_idx[0] == 1
+ assert ptr.shape == (1, ) and ptr[0] == 2
+ assert ep_rew.shape == (1, ) and ep_rew[0] == 4
+ assert ep_len.shape == (1, ) and ep_len[0] == 2
+ assert ep_idx.shape == (1, ) and ep_idx[0] == 1
assert set(b.info.keys()) == set(info_keys + ["e"])
assert b.info.e.shape == (b.maxsize, 1, 4)
with pytest.raises(IndexError):
@@ -92,14 +121,22 @@ def test_ignore_obs_next(size=10):
# Issue 82
buf = ReplayBuffer(size, ignore_obs_next=True)
for i in range(size):
- buf.add(Batch(obs={'mask1': np.array([i, 1, 1, 0, 0]),
- 'mask2': np.array([i + 4, 0, 1, 0, 0]),
- 'mask': i},
- act={'act_id': i,
- 'position_id': i + 3},
- rew=i,
- done=i % 3 == 0,
- info={'if': i}))
+ buf.add(
+ Batch(
+ obs={
+ 'mask1': np.array([i, 1, 1, 0, 0]),
+ 'mask2': np.array([i + 4, 0, 1, 0, 0]),
+ 'mask': i
+ },
+ act={
+ 'act_id': i,
+ 'position_id': i + 3
+ },
+ rew=i,
+ done=i % 3 == 0,
+ info={'if': i}
+ )
+ )
indices = np.arange(len(buf))
orig = np.arange(len(buf))
data = buf[indices]
@@ -113,15 +150,25 @@ def test_ignore_obs_next(size=10):
data = buf[indices]
data2 = buf[indices]
assert np.allclose(data.obs_next.mask, data2.obs_next.mask)
- assert np.allclose(data.obs_next.mask, np.array([
- [0, 0, 0, 0], [1, 1, 1, 2], [1, 1, 2, 3], [1, 1, 2, 3],
- [4, 4, 4, 5], [4, 4, 5, 6], [4, 4, 5, 6],
- [7, 7, 7, 8], [7, 7, 8, 9], [7, 7, 8, 9]]))
+ assert np.allclose(
+ data.obs_next.mask,
+ np.array(
+ [
+ [0, 0, 0, 0], [1, 1, 1, 2], [1, 1, 2, 3], [1, 1, 2, 3], [4, 4, 4, 5],
+ [4, 4, 5, 6], [4, 4, 5, 6], [7, 7, 7, 8], [7, 7, 8, 9], [7, 7, 8, 9]
+ ]
+ )
+ )
assert np.allclose(data.info['if'], data2.info['if'])
- assert np.allclose(data.info['if'], np.array([
- [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
- [4, 4, 4, 4], [4, 4, 4, 5], [4, 4, 5, 6],
- [7, 7, 7, 7], [7, 7, 7, 8], [7, 7, 8, 9]]))
+ assert np.allclose(
+ data.info['if'],
+ np.array(
+ [
+ [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], [4, 4, 4, 4],
+ [4, 4, 4, 5], [4, 4, 5, 6], [7, 7, 7, 7], [7, 7, 7, 8], [7, 7, 8, 9]
+ ]
+ )
+ )
assert data.obs_next
@@ -131,20 +178,30 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3):
buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True)
obs = env.reset(1)
- for i in range(16):
+ for _ in range(16):
obs_next, rew, done, info = env.step(1)
buf.add(Batch(obs=obs, act=1, rew=rew, done=done, info=info))
buf2.add(Batch(obs=obs, act=1, rew=rew, done=done, info=info))
- buf3.add(Batch(obs=[obs, obs, obs], act=1, rew=rew,
- done=done, obs_next=[obs, obs], info=info))
+ buf3.add(
+ Batch(
+ obs=[obs, obs, obs],
+ act=1,
+ rew=rew,
+ done=done,
+ obs_next=[obs, obs],
+ info=info
+ )
+ )
obs = obs_next
if done:
obs = env.reset(1)
indices = np.arange(len(buf))
- assert np.allclose(buf.get(indices, 'obs')[..., 0], [
- [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
- [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
- [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]])
+ assert np.allclose(
+ buf.get(indices, 'obs')[..., 0], [
+ [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], [1, 1, 1, 1], [1, 1, 1, 2],
+ [1, 1, 2, 3], [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]
+ ]
+ )
assert np.allclose(buf.get(indices, 'obs'), buf3.get(indices, 'obs'))
assert np.allclose(buf.get(indices, 'obs'), buf3.get(indices, 'obs_next'))
_, indices = buf2.sample(0)
@@ -165,8 +222,15 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
action_list = [1] * 5 + [0] * 10 + [1] * 10
for i, a in enumerate(action_list):
obs_next, rew, done, info = env.step(a)
- batch = Batch(obs=obs, act=a, rew=rew, done=done, obs_next=obs_next,
- info=info, policy=np.random.randn() - 0.5)
+ batch = Batch(
+ obs=obs,
+ act=a,
+ rew=rew,
+ done=done,
+ obs_next=obs_next,
+ info=info,
+ policy=np.random.randn() - 0.5
+ )
batch_stack = Batch.stack([batch, batch, batch])
buf.add(Batch.stack([batch]), buffer_ids=[0])
buf2.add(batch_stack, buffer_ids=[0, 1, 2])
@@ -179,12 +243,12 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
assert len(buf) == min(bufsize, i + 1)
assert len(buf2) == min(bufsize, 3 * (i + 1))
# check single buffer's data
- assert buf.info.key.shape == (buf.maxsize,)
+ assert buf.info.key.shape == (buf.maxsize, )
assert buf.rew.dtype == float
assert buf.done.dtype == bool
data, indices = buf.sample(len(buf) // 2)
buf.update_weight(indices, -data.weight / 2)
- assert np.allclose(buf.weight[indices], np.abs(-data.weight / 2) ** buf._alpha)
+ assert np.allclose(buf.weight[indices], np.abs(-data.weight / 2)**buf._alpha)
# check multi buffer's data
assert np.allclose(buf2[np.arange(buf2.maxsize)].weight, 1)
batch, indices = buf2.sample(10)
@@ -200,8 +264,15 @@ def test_update():
buf1 = ReplayBuffer(4, stack_num=2)
buf2 = ReplayBuffer(4, stack_num=2)
for i in range(5):
- buf1.add(Batch(obs=np.array([i]), act=float(i), rew=i * i,
- done=i % 2 == 0, info={'incident': 'found'}))
+ buf1.add(
+ Batch(
+ obs=np.array([i]),
+ act=float(i),
+ rew=i * i,
+ done=i % 2 == 0,
+ info={'incident': 'found'}
+ )
+ )
assert len(buf1) > len(buf2)
buf2.update(buf1)
assert len(buf1) == len(buf2)
@@ -242,11 +313,10 @@ def test_segtree():
naive[index] = value
tree[index] = value
assert np.allclose(realop(naive), tree.reduce())
- for i in range(10):
+ for _ in range(10):
left = np.random.randint(actual_len)
right = np.random.randint(left + 1, actual_len + 1)
- assert np.allclose(realop(naive[left:right]),
- tree.reduce(left, right))
+ assert np.allclose(realop(naive[left:right]), tree.reduce(left, right))
# large test
actual_len = 16384
tree = SegmentTree(actual_len)
@@ -257,11 +327,10 @@ def test_segtree():
naive[index] = value
tree[index] = value
assert np.allclose(realop(naive), tree.reduce())
- for i in range(10):
+ for _ in range(10):
left = np.random.randint(actual_len)
right = np.random.randint(left + 1, actual_len + 1)
- assert np.allclose(realop(naive[left:right]),
- tree.reduce(left, right))
+ assert np.allclose(realop(naive[left:right]), tree.reduce(left, right))
# test prefix-sum-idx
actual_len = 8
@@ -280,8 +349,9 @@ def test_segtree():
assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
tree = SegmentTree(10)
tree[np.arange(3)] = np.array([0.1, 0, 0.1])
- assert np.allclose(tree.get_prefix_sum_idx(
- np.array([0, .1, .1 + 1e-6, .2 - 1e-6])), [0, 0, 2, 2])
+ assert np.allclose(
+ tree.get_prefix_sum_idx(np.array([0, .1, .1 + 1e-6, .2 - 1e-6])), [0, 0, 2, 2]
+ )
with pytest.raises(AssertionError):
tree.get_prefix_sum_idx(.2)
# test large prefix-sum-idx
@@ -321,8 +391,15 @@ def test_pickle():
for i in range(4):
vbuf.add(Batch(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0))
for i in range(5):
- pbuf.add(Batch(obs=Batch(index=np.array([i])),
- act=2, rew=rew, done=0, info=np.random.rand()))
+ pbuf.add(
+ Batch(
+ obs=Batch(index=np.array([i])),
+ act=2,
+ rew=rew,
+ done=0,
+ info=np.random.rand()
+ )
+ )
# save & load
_vbuf = pickle.loads(pickle.dumps(vbuf))
_pbuf = pickle.loads(pickle.dumps(pbuf))
@@ -330,8 +407,9 @@ def test_pickle():
assert len(_pbuf) == len(pbuf) and np.allclose(_pbuf.act, pbuf.act)
# make sure the meta var is identical
assert _vbuf.stack_num == vbuf.stack_num
- assert np.allclose(_pbuf.weight[np.arange(len(_pbuf))],
- pbuf.weight[np.arange(len(pbuf))])
+ assert np.allclose(
+ _pbuf.weight[np.arange(len(_pbuf))], pbuf.weight[np.arange(len(pbuf))]
+ )
def test_hdf5():
@@ -349,7 +427,13 @@ def test_hdf5():
'act': i,
'rew': np.array([1, 2]),
'done': i % 3 == 2,
- 'info': {"number": {"n": i, "t": info_t}, 'extra': None},
+ 'info': {
+ "number": {
+ "n": i,
+ "t": info_t
+ },
+ 'extra': None
+ },
}
buffers["array"].add(Batch(kwargs))
buffers["prioritized"].add(Batch(kwargs))
@@ -377,10 +461,8 @@ def test_hdf5():
assert isinstance(buffers[k].get(0, "info"), Batch)
assert isinstance(_buffers[k].get(0, "info"), Batch)
for k in ["array"]:
- assert np.all(
- buffers[k][:].info.number.n == _buffers[k][:].info.number.n)
- assert np.all(
- buffers[k][:].info.extra == _buffers[k][:].info.extra)
+ assert np.all(buffers[k][:].info.number.n == _buffers[k][:].info.number.n)
+ assert np.all(buffers[k][:].info.extra == _buffers[k][:].info.extra)
# raise exception when value cannot be pickled
data = {"not_supported": lambda x: x * x}
@@ -423,15 +505,16 @@ def test_replaybuffermanager():
indices_next = buf.next(indices)
assert np.allclose(indices_next, indices), indices_next
data = np.array([0, 0, 0, 0])
- buf.add(Batch(obs=data, act=data, rew=data, done=data),
- buffer_ids=[0, 1, 2, 3])
- buf.add(Batch(obs=data, act=data, rew=data, done=1 - data),
- buffer_ids=[0, 1, 2, 3])
+ buf.add(Batch(obs=data, act=data, rew=data, done=data), buffer_ids=[0, 1, 2, 3])
+ buf.add(
+ Batch(obs=data, act=data, rew=data, done=1 - data), buffer_ids=[0, 1, 2, 3]
+ )
assert len(buf) == 12
- buf.add(Batch(obs=data, act=data, rew=data, done=data),
- buffer_ids=[0, 1, 2, 3])
- buf.add(Batch(obs=data, act=data, rew=data, done=[0, 1, 0, 1]),
- buffer_ids=[0, 1, 2, 3])
+ buf.add(Batch(obs=data, act=data, rew=data, done=data), buffer_ids=[0, 1, 2, 3])
+ buf.add(
+ Batch(obs=data, act=data, rew=data, done=[0, 1, 0, 1]),
+ buffer_ids=[0, 1, 2, 3]
+ )
assert len(buf) == 20
indices = buf.sample_indices(120000)
assert np.bincount(indices).min() >= 5000
@@ -439,44 +522,135 @@ def test_replaybuffermanager():
indices = buf.sample_indices(0)
assert np.allclose(indices, np.arange(len(buf)))
# check the actual data stored in buf._meta
- assert np.allclose(buf.done, [
- 0, 0, 1, 0, 0,
- 0, 0, 1, 0, 1,
- 1, 0, 1, 0, 0,
- 1, 0, 1, 0, 1,
- ])
- assert np.allclose(buf.prev(indices), [
- 0, 0, 1, 3, 3,
- 5, 5, 6, 8, 8,
- 10, 11, 11, 13, 13,
- 15, 16, 16, 18, 18,
- ])
- assert np.allclose(buf.next(indices), [
- 1, 2, 2, 4, 4,
- 6, 7, 7, 9, 9,
- 10, 12, 12, 14, 14,
- 15, 17, 17, 19, 19,
- ])
+ assert np.allclose(
+ buf.done, [
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 1,
+ 1,
+ 0,
+ 1,
+ 0,
+ 0,
+ 1,
+ 0,
+ 1,
+ 0,
+ 1,
+ ]
+ )
+ assert np.allclose(
+ buf.prev(indices), [
+ 0,
+ 0,
+ 1,
+ 3,
+ 3,
+ 5,
+ 5,
+ 6,
+ 8,
+ 8,
+ 10,
+ 11,
+ 11,
+ 13,
+ 13,
+ 15,
+ 16,
+ 16,
+ 18,
+ 18,
+ ]
+ )
+ assert np.allclose(
+ buf.next(indices), [
+ 1,
+ 2,
+ 2,
+ 4,
+ 4,
+ 6,
+ 7,
+ 7,
+ 9,
+ 9,
+ 10,
+ 12,
+ 12,
+ 14,
+ 14,
+ 15,
+ 17,
+ 17,
+ 19,
+ 19,
+ ]
+ )
assert np.allclose(buf.unfinished_index(), [4, 14])
ptr, ep_rew, ep_len, ep_idx = buf.add(
- Batch(obs=[1], act=[1], rew=[1], done=[1]), buffer_ids=[2])
+ Batch(obs=[1], act=[1], rew=[1], done=[1]), buffer_ids=[2]
+ )
assert np.all(ep_len == [3]) and np.all(ep_rew == [1])
assert np.all(ptr == [10]) and np.all(ep_idx == [13])
assert np.allclose(buf.unfinished_index(), [4])
indices = list(sorted(buf.sample_indices(0)))
assert np.allclose(indices, np.arange(len(buf)))
- assert np.allclose(buf.prev(indices), [
- 0, 0, 1, 3, 3,
- 5, 5, 6, 8, 8,
- 14, 11, 11, 13, 13,
- 15, 16, 16, 18, 18,
- ])
- assert np.allclose(buf.next(indices), [
- 1, 2, 2, 4, 4,
- 6, 7, 7, 9, 9,
- 10, 12, 12, 14, 10,
- 15, 17, 17, 19, 19,
- ])
+ assert np.allclose(
+ buf.prev(indices), [
+ 0,
+ 0,
+ 1,
+ 3,
+ 3,
+ 5,
+ 5,
+ 6,
+ 8,
+ 8,
+ 14,
+ 11,
+ 11,
+ 13,
+ 13,
+ 15,
+ 16,
+ 16,
+ 18,
+ 18,
+ ]
+ )
+ assert np.allclose(
+ buf.next(indices), [
+ 1,
+ 2,
+ 2,
+ 4,
+ 4,
+ 6,
+ 7,
+ 7,
+ 9,
+ 9,
+ 10,
+ 12,
+ 12,
+ 14,
+ 10,
+ 15,
+ 17,
+ 17,
+ 19,
+ 19,
+ ]
+ )
# corner case: list, int and -1
assert buf.prev(-1) == buf.prev([buf.maxsize - 1])[0]
assert buf.next(-1) == buf.next([buf.maxsize - 1])[0]
@@ -493,7 +667,8 @@ def test_cachedbuffer():
assert buf.sample_indices(0).tolist() == []
# check the normal function/usage/storage in CachedReplayBuffer
ptr, ep_rew, ep_len, ep_idx = buf.add(
- Batch(obs=[1], act=[1], rew=[1], done=[0]), buffer_ids=[1])
+ Batch(obs=[1], act=[1], rew=[1], done=[0]), buffer_ids=[1]
+ )
obs = np.zeros(buf.maxsize)
obs[15] = 1
indices = buf.sample_indices(0)
@@ -504,7 +679,8 @@ def test_cachedbuffer():
assert np.all(ep_len == [0]) and np.all(ep_rew == [0.0])
assert np.all(ptr == [15]) and np.all(ep_idx == [15])
ptr, ep_rew, ep_len, ep_idx = buf.add(
- Batch(obs=[2], act=[2], rew=[2], done=[1]), buffer_ids=[3])
+ Batch(obs=[2], act=[2], rew=[2], done=[1]), buffer_ids=[3]
+ )
obs[[0, 25]] = 2
indices = buf.sample_indices(0)
assert np.allclose(indices, [0, 15])
@@ -516,8 +692,8 @@ def test_cachedbuffer():
assert np.allclose(buf.unfinished_index(), [15])
assert np.allclose(buf.sample_indices(0), [0, 15])
ptr, ep_rew, ep_len, ep_idx = buf.add(
- Batch(obs=[3, 4], act=[3, 4], rew=[3, 4], done=[0, 1]),
- buffer_ids=[3, 1])
+ Batch(obs=[3, 4], act=[3, 4], rew=[3, 4], done=[0, 1]), buffer_ids=[3, 1]
+ )
assert np.all(ep_len == [0, 2]) and np.all(ep_rew == [0, 5.0])
assert np.all(ptr == [25, 2]) and np.all(ep_idx == [25, 1])
obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3]
@@ -540,16 +716,35 @@ def test_cachedbuffer():
buf.add(Batch(obs=data, act=data, rew=rew, done=[1, 1, 1, 1]))
buf.add(Batch(obs=data, act=data, rew=rew, done=[0, 0, 0, 0]))
ptr, ep_rew, ep_len, ep_idx = buf.add(
- Batch(obs=data, act=data, rew=rew, done=[0, 1, 0, 1]))
+ Batch(obs=data, act=data, rew=rew, done=[0, 1, 0, 1])
+ )
assert np.all(ptr == [1, -1, 11, -1]) and np.all(ep_idx == [0, -1, 10, -1])
assert np.all(ep_len == [0, 2, 0, 2])
assert np.all(ep_rew == [data, data + 2, data, data + 2])
- assert np.allclose(buf.done, [
- 0, 0, 1, 0, 0,
- 0, 1, 1, 0, 0,
- 0, 0, 0, 0, 0,
- 0, 1, 0, 0, 0,
- ])
+ assert np.allclose(
+ buf.done, [
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 1,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ ]
+ )
indices = buf.sample_indices(0)
assert np.allclose(indices, [0, 1, 10, 11])
assert np.allclose(buf.prev(indices), [0, 0, 10, 10])
@@ -564,14 +759,16 @@ def test_multibuf_stack():
env = MyTestEnv(size)
# test if CachedReplayBuffer can handle stack_num + ignore_obs_next
buf4 = CachedReplayBuffer(
- ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True),
- cached_num, size)
+ ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True), cached_num,
+ size
+ )
# test if CachedReplayBuffer can handle corner case:
# buffer + stack_num + ignore_obs_next + sample_avail
buf5 = CachedReplayBuffer(
- ReplayBuffer(bufsize, stack_num=stack_num,
- ignore_obs_next=True, sample_avail=True),
- cached_num, size)
+ ReplayBuffer(
+ bufsize, stack_num=stack_num, ignore_obs_next=True, sample_avail=True
+ ), cached_num, size
+ )
obs = env.reset(1)
for i in range(18):
obs_next, rew, done, info = env.step(1)
@@ -581,8 +778,14 @@ def test_multibuf_stack():
done_list = [done] * cached_num
obs_next_list = -obs_list
info_list = [info] * cached_num
- batch = Batch(obs=obs_list, act=act_list, rew=rew_list,
- done=done_list, obs_next=obs_next_list, info=info_list)
+ batch = Batch(
+ obs=obs_list,
+ act=act_list,
+ rew=rew_list,
+ done=done_list,
+ obs_next=obs_next_list,
+ info=info_list
+ )
buf5.add(batch)
buf4.add(batch)
assert np.all(buf4.obs == buf5.obs)
@@ -591,35 +794,105 @@ def test_multibuf_stack():
if done:
obs = env.reset(1)
# check the `add` order is correct
- assert np.allclose(buf4.obs.reshape(-1), [
- 12, 13, 14, 4, 6, 7, 8, 9, 11, # main_buffer
- 1, 2, 3, 4, 0, # cached_buffer[0]
- 6, 7, 8, 9, 0, # cached_buffer[1]
- 11, 12, 13, 14, 0, # cached_buffer[2]
- ]), buf4.obs
- assert np.allclose(buf4.done, [
- 0, 0, 1, 1, 0, 0, 0, 1, 0, # main_buffer
- 0, 0, 0, 1, 0, # cached_buffer[0]
- 0, 0, 0, 1, 0, # cached_buffer[1]
- 0, 0, 0, 1, 0, # cached_buffer[2]
- ]), buf4.done
+ assert np.allclose(
+ buf4.obs.reshape(-1),
+ [
+ 12,
+ 13,
+ 14,
+ 4,
+ 6,
+ 7,
+ 8,
+ 9,
+ 11, # main_buffer
+ 1,
+ 2,
+ 3,
+ 4,
+ 0, # cached_buffer[0]
+ 6,
+ 7,
+ 8,
+ 9,
+ 0, # cached_buffer[1]
+ 11,
+ 12,
+ 13,
+ 14,
+ 0, # cached_buffer[2]
+ ]
+ ), buf4.obs
+ assert np.allclose(
+ buf4.done,
+ [
+ 0,
+ 0,
+ 1,
+ 1,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0, # main_buffer
+ 0,
+ 0,
+ 0,
+ 1,
+ 0, # cached_buffer[0]
+ 0,
+ 0,
+ 0,
+ 1,
+ 0, # cached_buffer[1]
+ 0,
+ 0,
+ 0,
+ 1,
+ 0, # cached_buffer[2]
+ ]
+ ), buf4.done
assert np.allclose(buf4.unfinished_index(), [10, 15, 20])
indices = sorted(buf4.sample_indices(0))
assert np.allclose(indices, list(range(bufsize)) + [9, 10, 14, 15, 19, 20])
- assert np.allclose(buf4[indices].obs[..., 0], [
- [11, 11, 11, 12], [11, 11, 12, 13], [11, 12, 13, 14],
- [4, 4, 4, 4], [6, 6, 6, 6], [6, 6, 6, 7],
- [6, 6, 7, 8], [6, 7, 8, 9], [11, 11, 11, 11],
- [1, 1, 1, 1], [1, 1, 1, 2], [6, 6, 6, 6],
- [6, 6, 6, 7], [11, 11, 11, 11], [11, 11, 11, 12],
- ])
- assert np.allclose(buf4[indices].obs_next[..., 0], [
- [11, 11, 12, 13], [11, 12, 13, 14], [11, 12, 13, 14],
- [4, 4, 4, 4], [6, 6, 6, 7], [6, 6, 7, 8],
- [6, 7, 8, 9], [6, 7, 8, 9], [11, 11, 11, 12],
- [1, 1, 1, 2], [1, 1, 1, 2], [6, 6, 6, 7],
- [6, 6, 6, 7], [11, 11, 11, 12], [11, 11, 11, 12],
- ])
+ assert np.allclose(
+ buf4[indices].obs[..., 0], [
+ [11, 11, 11, 12],
+ [11, 11, 12, 13],
+ [11, 12, 13, 14],
+ [4, 4, 4, 4],
+ [6, 6, 6, 6],
+ [6, 6, 6, 7],
+ [6, 6, 7, 8],
+ [6, 7, 8, 9],
+ [11, 11, 11, 11],
+ [1, 1, 1, 1],
+ [1, 1, 1, 2],
+ [6, 6, 6, 6],
+ [6, 6, 6, 7],
+ [11, 11, 11, 11],
+ [11, 11, 11, 12],
+ ]
+ )
+ assert np.allclose(
+ buf4[indices].obs_next[..., 0], [
+ [11, 11, 12, 13],
+ [11, 12, 13, 14],
+ [11, 12, 13, 14],
+ [4, 4, 4, 4],
+ [6, 6, 6, 7],
+ [6, 6, 7, 8],
+ [6, 7, 8, 9],
+ [6, 7, 8, 9],
+ [11, 11, 11, 12],
+ [1, 1, 1, 2],
+ [1, 1, 1, 2],
+ [6, 6, 6, 7],
+ [6, 6, 6, 7],
+ [11, 11, 11, 12],
+ [11, 11, 11, 12],
+ ]
+ )
indices = buf5.sample_indices(0)
assert np.allclose(sorted(indices), [2, 7])
assert np.all(np.isin(buf5.sample_indices(100), indices))
@@ -632,12 +905,24 @@ def test_multibuf_stack():
batch, _ = buf5.sample(0)
# test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next
buf6 = CachedReplayBuffer(
- ReplayBuffer(bufsize, stack_num=stack_num,
- save_only_last_obs=True, ignore_obs_next=True),
- cached_num, size)
+ ReplayBuffer(
+ bufsize,
+ stack_num=stack_num,
+ save_only_last_obs=True,
+ ignore_obs_next=True
+ ), cached_num, size
+ )
obs = np.random.rand(size, 4, 84, 84)
- buf6.add(Batch(obs=[obs[2], obs[0]], act=[1, 1], rew=[0, 0], done=[0, 1],
- obs_next=[obs[3], obs[1]]), buffer_ids=[1, 2])
+ buf6.add(
+ Batch(
+ obs=[obs[2], obs[0]],
+ act=[1, 1],
+ rew=[0, 0],
+ done=[0, 1],
+ obs_next=[obs[3], obs[1]]
+ ),
+ buffer_ids=[1, 2]
+ )
assert buf6.obs.shape == (buf6.maxsize, 84, 84)
assert np.allclose(buf6.obs[0], obs[0, -1])
assert np.allclose(buf6.obs[14], obs[2, -1])
@@ -660,12 +945,20 @@ def test_multibuf_hdf5():
'act': i,
'rew': np.array([1, 2]),
'done': i % 3 == 2,
- 'info': {"number": {"n": i, "t": info_t}, 'extra': None},
+ 'info': {
+ "number": {
+ "n": i,
+ "t": info_t
+ },
+ 'extra': None
+ },
}
- buffers["vector"].add(Batch.stack([kwargs, kwargs, kwargs]),
- buffer_ids=[0, 1, 2])
- buffers["cached"].add(Batch.stack([kwargs, kwargs, kwargs]),
- buffer_ids=[0, 1, 2])
+ buffers["vector"].add(
+ Batch.stack([kwargs, kwargs, kwargs]), buffer_ids=[0, 1, 2]
+ )
+ buffers["cached"].add(
+ Batch.stack([kwargs, kwargs, kwargs]), buffer_ids=[0, 1, 2]
+ )
# save
paths = {}
@@ -696,7 +989,12 @@ def test_multibuf_hdf5():
'act': 5,
'rew': np.array([2, 1]),
'done': False,
- 'info': {"number": {"n": i}, 'Timelimit.truncate': True},
+ 'info': {
+ "number": {
+ "n": i
+ },
+ 'Timelimit.truncate': True
+ },
}
buffers[k].add(Batch.stack([kwargs, kwargs, kwargs, kwargs]))
act = np.zeros(buffers[k].maxsize)
diff --git a/test/base/test_collector.py b/test/base/test_collector.py
index 79d1430b8..61bd5a6fc 100644
--- a/test/base/test_collector.py
+++ b/test/base/test_collector.py
@@ -1,17 +1,19 @@
-import tqdm
-import pytest
import numpy as np
+import pytest
+import tqdm
from torch.utils.tensorboard import SummaryWriter
-from tianshou.policy import BasePolicy
-from tianshou.env import DummyVectorEnv, SubprocVectorEnv
-from tianshou.data import Batch, Collector, AsyncCollector
from tianshou.data import (
- ReplayBuffer,
+ AsyncCollector,
+ Batch,
+ CachedReplayBuffer,
+ Collector,
PrioritizedReplayBuffer,
+ ReplayBuffer,
VectorReplayBuffer,
- CachedReplayBuffer,
)
+from tianshou.env import DummyVectorEnv, SubprocVectorEnv
+from tianshou.policy import BasePolicy
if __name__ == '__main__':
from env import MyTestEnv, NXEnv
@@ -20,6 +22,7 @@
class MyPolicy(BasePolicy):
+
def __init__(self, dict_state=False, need_state=True):
"""
:param bool dict_state: if the observation of the environment is a dict
@@ -44,6 +47,7 @@ def learn(self):
class Logger:
+
def __init__(self, writer):
self.cnt = 0
self.writer = writer
@@ -56,8 +60,7 @@ def preprocess_fn(self, **kwargs):
info = kwargs['info']
info.rew = kwargs['rew']
if 'key' in info.keys():
- self.writer.add_scalar(
- 'key', np.mean(info.key), global_step=self.cnt)
+ self.writer.add_scalar('key', np.mean(info.key), global_step=self.cnt)
self.cnt += 1
return Batch(info=info)
else:
@@ -91,13 +94,12 @@ def test_collector():
c0.collect(n_episode=3)
assert len(c0.buffer) == 8
assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0])
- assert np.allclose(c0.buffer[:].obs_next[..., 0],
- [1, 2, 1, 2, 1, 2, 1, 2])
+ assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
c0.collect(n_step=3, random=True)
c1 = Collector(
- policy, venv,
- VectorReplayBuffer(total_size=100, buffer_num=4),
- logger.preprocess_fn)
+ policy, venv, VectorReplayBuffer(total_size=100, buffer_num=4),
+ logger.preprocess_fn
+ )
c1.collect(n_step=8)
obs = np.zeros(100)
obs[[0, 1, 25, 26, 50, 51, 75, 76]] = [0, 1, 0, 1, 0, 1, 0, 1]
@@ -108,13 +110,15 @@ def test_collector():
assert len(c1.buffer) == 16
obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4]
assert np.allclose(c1.buffer.obs[:, 0], obs)
- assert np.allclose(c1.buffer[:].obs_next[..., 0],
- [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
+ assert np.allclose(
+ c1.buffer[:].obs_next[..., 0],
+ [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]
+ )
c1.collect(n_episode=4, random=True)
c2 = Collector(
- policy, dum,
- VectorReplayBuffer(total_size=100, buffer_num=4),
- logger.preprocess_fn)
+ policy, dum, VectorReplayBuffer(total_size=100, buffer_num=4),
+ logger.preprocess_fn
+ )
c2.collect(n_episode=7)
obs1 = obs.copy()
obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2]
@@ -139,10 +143,10 @@ def test_collector():
# test NXEnv
for obs_type in ["array", "object"]:
- envs = SubprocVectorEnv([
- lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]])
- c3 = Collector(policy, envs,
- VectorReplayBuffer(total_size=100, buffer_num=4))
+ envs = SubprocVectorEnv(
+ [lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]]
+ )
+ c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4))
c3.collect(n_step=6)
assert c3.buffer.obs.dtype == object
@@ -151,23 +155,23 @@ def test_collector_with_async():
env_lens = [2, 3, 4, 5]
writer = SummaryWriter('log/async_collector')
logger = Logger(writer)
- env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True)
- for i in env_lens]
+ env_fns = [
+ lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens
+ ]
venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1)
policy = MyPolicy()
bufsize = 60
c1 = AsyncCollector(
- policy, venv,
- VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4),
- logger.preprocess_fn)
+ policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4),
+ logger.preprocess_fn
+ )
ptr = [0, 0, 0, 0]
for n_episode in tqdm.trange(1, 30, desc="test async n_episode"):
result = c1.collect(n_episode=n_episode)
assert result["n/ep"] >= n_episode
# check buffer data, obs and obs_next, env_id
- for i, count in enumerate(
- np.bincount(result["lens"], minlength=6)[2:]):
+ for i, count in enumerate(np.bincount(result["lens"], minlength=6)[2:]):
env_len = i + 2
total = env_len * count
indices = np.arange(ptr[i], ptr[i] + total) % bufsize
@@ -176,8 +180,7 @@ def test_collector_with_async():
buf = c1.buffer.buffers[i]
assert np.all(buf.info.env_id[indices] == i)
assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
- assert np.all(buf.obs_next[indices].reshape(
- count, env_len) == seq + 1)
+ assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
# test async n_step, for now the buffer should be full of data
for n_step in tqdm.trange(1, 15, desc="test async n_step"):
result = c1.collect(n_step=n_step)
@@ -196,21 +199,21 @@ def test_collector_with_async():
def test_collector_with_dict_state():
env = MyTestEnv(size=5, sleep=0, dict_state=True)
policy = MyPolicy(dict_state=True)
- c0 = Collector(policy, env, ReplayBuffer(size=100),
- Logger.single_preprocess_fn)
+ c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn)
c0.collect(n_step=3)
c0.collect(n_episode=2)
assert len(c0.buffer) == 10
- env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True)
- for i in [2, 3, 4, 5]]
+ env_fns = [
+ lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True) for i in [2, 3, 4, 5]
+ ]
envs = DummyVectorEnv(env_fns)
envs.seed(666)
obs = envs.reset()
assert not np.isclose(obs[0]['rand'], obs[1]['rand'])
c1 = Collector(
- policy, envs,
- VectorReplayBuffer(total_size=100, buffer_num=4),
- Logger.single_preprocess_fn)
+ policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4),
+ Logger.single_preprocess_fn
+ )
c1.collect(n_step=12)
result = c1.collect(n_episode=8)
assert result['n/ep'] == 8
@@ -221,25 +224,104 @@ def test_collector_with_dict_state():
c0.buffer.update(c1.buffer)
assert len(c0.buffer) in [42, 43]
if len(c0.buffer) == 42:
- assert np.all(c0.buffer[:].obs.index[..., 0] == [
- 0, 1, 2, 3, 4, 0, 1, 2, 3, 4,
- 0, 1, 0, 1, 0, 1, 0, 1,
- 0, 1, 2, 0, 1, 2,
- 0, 1, 2, 3, 0, 1, 2, 3,
- 0, 1, 2, 3, 4, 0, 1, 2, 3, 4,
- ]), c0.buffer[:].obs.index[..., 0]
+ assert np.all(
+ c0.buffer[:].obs.index[..., 0] == [
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ 0,
+ 1,
+ 0,
+ 1,
+ 0,
+ 1,
+ 0,
+ 1,
+ 0,
+ 1,
+ 2,
+ 0,
+ 1,
+ 2,
+ 0,
+ 1,
+ 2,
+ 3,
+ 0,
+ 1,
+ 2,
+ 3,
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ ]
+ ), c0.buffer[:].obs.index[..., 0]
else:
- assert np.all(c0.buffer[:].obs.index[..., 0] == [
- 0, 1, 2, 3, 4, 0, 1, 2, 3, 4,
- 0, 1, 0, 1, 0, 1,
- 0, 1, 2, 0, 1, 2, 0, 1, 2,
- 0, 1, 2, 3, 0, 1, 2, 3,
- 0, 1, 2, 3, 4, 0, 1, 2, 3, 4,
- ]), c0.buffer[:].obs.index[..., 0]
+ assert np.all(
+ c0.buffer[:].obs.index[..., 0] == [
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ 0,
+ 1,
+ 0,
+ 1,
+ 0,
+ 1,
+ 0,
+ 1,
+ 2,
+ 0,
+ 1,
+ 2,
+ 0,
+ 1,
+ 2,
+ 0,
+ 1,
+ 2,
+ 3,
+ 0,
+ 1,
+ 2,
+ 3,
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ ]
+ ), c0.buffer[:].obs.index[..., 0]
c2 = Collector(
- policy, envs,
- VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4),
- Logger.single_preprocess_fn)
+ policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4),
+ Logger.single_preprocess_fn
+ )
c2.collect(n_episode=10)
batch, _ = c2.buffer.sample(10)
@@ -247,20 +329,18 @@ def test_collector_with_dict_state():
def test_collector_with_ma():
env = MyTestEnv(size=5, sleep=0, ma_rew=4)
policy = MyPolicy()
- c0 = Collector(policy, env, ReplayBuffer(size=100),
- Logger.single_preprocess_fn)
+ c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn)
# n_step=3 will collect a full episode
r = c0.collect(n_step=3)['rews']
assert len(r) == 0
r = c0.collect(n_episode=2)['rews']
assert r.shape == (2, 4) and np.all(r == 1)
- env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4)
- for i in [2, 3, 4, 5]]
+ env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]]
envs = DummyVectorEnv(env_fns)
c1 = Collector(
- policy, envs,
- VectorReplayBuffer(total_size=100, buffer_num=4),
- Logger.single_preprocess_fn)
+ policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4),
+ Logger.single_preprocess_fn
+ )
r = c1.collect(n_step=12)['rews']
assert r.shape == (2, 4) and np.all(r == 1), r
r = c1.collect(n_episode=8)['rews']
@@ -271,26 +351,101 @@ def test_collector_with_ma():
assert len(c0.buffer) in [42, 43]
if len(c0.buffer) == 42:
rew = [
- 0, 0, 0, 0, 1, 0, 0, 0, 0, 1,
- 0, 1, 0, 1, 0, 1, 0, 1,
- 0, 0, 1, 0, 0, 1,
- 0, 0, 0, 1, 0, 0, 0, 1,
- 0, 0, 0, 0, 1, 0, 0, 0, 0, 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 1,
+ 0,
+ 1,
+ 0,
+ 1,
+ 0,
+ 1,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
]
else:
rew = [
- 0, 0, 0, 0, 1, 0, 0, 0, 0, 1,
- 0, 1, 0, 1, 0, 1,
- 0, 0, 1, 0, 0, 1, 0, 0, 1,
- 0, 0, 0, 1, 0, 0, 0, 1,
- 0, 0, 0, 0, 1, 0, 0, 0, 0, 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 1,
+ 0,
+ 1,
+ 0,
+ 1,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
]
assert np.all(c0.buffer[:].rew == [[x] * 4 for x in rew])
assert np.all(c0.buffer[:].done == rew)
c2 = Collector(
- policy, envs,
- VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4),
- Logger.single_preprocess_fn)
+ policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4),
+ Logger.single_preprocess_fn
+ )
r = c2.collect(n_episode=10)['rews']
assert r.shape == (10, 4) and np.all(r == 1)
batch, _ = c2.buffer.sample(10)
@@ -326,22 +481,23 @@ def test_collector_with_atari_setting():
c2 = Collector(
policy, env,
- ReplayBuffer(size=100, ignore_obs_next=True, save_only_last_obs=True))
+ ReplayBuffer(size=100, ignore_obs_next=True, save_only_last_obs=True)
+ )
c2.collect(n_step=8)
assert c2.buffer.obs.shape == (100, 84, 84)
obs = np.zeros_like(c2.buffer.obs)
obs[np.arange(8)] = reference_obs[[0, 1, 2, 3, 4, 0, 1, 2], -1]
assert np.all(c2.buffer.obs == obs)
- assert np.allclose(c2.buffer[:].obs_next,
- reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1])
+ assert np.allclose(
+ c2.buffer[:].obs_next, reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1]
+ )
# atari multi buffer
- env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, array_state=True)
- for i in [2, 3, 4, 5]]
+ env_fns = [
+ lambda x=i: MyTestEnv(size=x, sleep=0, array_state=True) for i in [2, 3, 4, 5]
+ ]
envs = DummyVectorEnv(env_fns)
- c3 = Collector(
- policy, envs,
- VectorReplayBuffer(total_size=100, buffer_num=4))
+ c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4))
c3.collect(n_step=12)
result = c3.collect(n_episode=9)
assert result["n/ep"] == 9 and result["n/st"] == 23
@@ -360,8 +516,14 @@ def test_collector_with_atari_setting():
assert np.all(obs_next == c3.buffer.obs_next)
c4 = Collector(
policy, envs,
- VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4,
- ignore_obs_next=True, save_only_last_obs=True))
+ VectorReplayBuffer(
+ total_size=100,
+ buffer_num=4,
+ stack_num=4,
+ ignore_obs_next=True,
+ save_only_last_obs=True
+ )
+ )
c4.collect(n_step=12)
result = c4.collect(n_episode=9)
assert result["n/ep"] == 9 and result["n/st"] == 23
@@ -374,12 +536,45 @@ def test_collector_with_atari_setting():
obs[np.arange(75, 85)] = slice_obs[[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]]
assert np.all(c4.buffer.obs == obs)
obs_next = np.zeros([len(c4.buffer), 4, 84, 84])
- ref_index = np.array([
- 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 2, 2, 1, 2, 2, 1, 2, 2,
- 1, 2, 3, 3, 1, 2, 3, 3,
- 1, 2, 3, 4, 4, 1, 2, 3, 4, 4,
- ])
+ ref_index = np.array(
+ [
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 2,
+ 2,
+ 1,
+ 2,
+ 2,
+ 1,
+ 2,
+ 2,
+ 1,
+ 2,
+ 3,
+ 3,
+ 1,
+ 2,
+ 3,
+ 3,
+ 1,
+ 2,
+ 3,
+ 4,
+ 4,
+ 1,
+ 2,
+ 3,
+ 4,
+ 4,
+ ]
+ )
obs_next[:, -1] = slice_obs[ref_index]
ref_index -= 1
ref_index[ref_index < 0] = 0
@@ -392,20 +587,25 @@ def test_collector_with_atari_setting():
obs_next[:, -4] = slice_obs[ref_index]
assert np.all(obs_next == c4.buffer[:].obs_next)
- buf = ReplayBuffer(100, stack_num=4, ignore_obs_next=True,
- save_only_last_obs=True)
+ buf = ReplayBuffer(100, stack_num=4, ignore_obs_next=True, save_only_last_obs=True)
c5 = Collector(policy, envs, CachedReplayBuffer(buf, 4, 10))
result_ = c5.collect(n_step=12)
assert len(buf) == 5 and len(c5.buffer) == 12
result = c5.collect(n_episode=9)
assert result["n/ep"] == 9 and result["n/st"] == 23
assert len(buf) == 35
- assert np.all(buf.obs[:len(buf)] == slice_obs[[
- 0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 3, 4,
- 0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4]])
- assert np.all(buf[:].obs_next[:, -1] == slice_obs[[
- 1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 3, 4, 4,
- 1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 2, 1, 2, 3, 4, 4]])
+ assert np.all(
+ buf.obs[:len(buf)] == slice_obs[[
+ 0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 3, 4, 0, 1, 0, 1, 2, 0, 1, 0, 1,
+ 2, 3, 0, 1, 2, 0, 1, 2, 3, 4
+ ]]
+ )
+ assert np.all(
+ buf[:].obs_next[:, -1] == slice_obs[[
+ 1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 3, 4, 4, 1, 1, 1, 2, 2, 1, 1, 1, 2,
+ 3, 3, 1, 2, 2, 1, 2, 3, 4, 4
+ ]]
+ )
assert len(buf) == len(c5.buffer)
# test buffer=None
diff --git a/test/base/test_env.py b/test/base/test_env.py
index cc1dc84c7..b9d6489b6 100644
--- a/test/base/test_env.py
+++ b/test/base/test_env.py
@@ -1,10 +1,11 @@
import sys
import time
+
import numpy as np
from gym.spaces.discrete import Discrete
+
from tianshou.data import Batch
-from tianshou.env import DummyVectorEnv, SubprocVectorEnv, \
- ShmemVectorEnv, RayVectorEnv
+from tianshou.env import DummyVectorEnv, RayVectorEnv, ShmemVectorEnv, SubprocVectorEnv
if __name__ == '__main__':
from env import MyTestEnv, NXEnv
@@ -24,17 +25,14 @@ def recurse_comp(a, b):
try:
if isinstance(a, np.ndarray):
if a.dtype == object:
- return np.array(
- [recurse_comp(m, n) for m, n in zip(a, b)]).all()
+ return np.array([recurse_comp(m, n) for m, n in zip(a, b)]).all()
else:
return np.allclose(a, b)
elif isinstance(a, (list, tuple)):
- return np.array(
- [recurse_comp(m, n) for m, n in zip(a, b)]).all()
+ return np.array([recurse_comp(m, n) for m, n in zip(a, b)]).all()
elif isinstance(a, dict):
- return np.array(
- [recurse_comp(a[k], b[k]) for k in a.keys()]).all()
- except(Exception):
+ return np.array([recurse_comp(a[k], b[k]) for k in a.keys()]).all()
+ except (Exception):
return False
@@ -75,7 +73,7 @@ def test_async_env(size=10000, num=8, sleep=0.1):
# truncate env_ids with the first terms
# typically len(env_ids) == len(A) == len(action), except for the
# last batch when actions are not enough
- env_ids = env_ids[: len(action)]
+ env_ids = env_ids[:len(action)]
spent_time = time.time() - spent_time
Batch.cat(o)
v.close()
@@ -85,10 +83,12 @@ def test_async_env(size=10000, num=8, sleep=0.1):
def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7):
- env_fns = [lambda: MyTestEnv(size=size, sleep=sleep * 2),
- lambda: MyTestEnv(size=size, sleep=sleep * 3),
- lambda: MyTestEnv(size=size, sleep=sleep * 5),
- lambda: MyTestEnv(size=size, sleep=sleep * 7)]
+ env_fns = [
+ lambda: MyTestEnv(size=size, sleep=sleep * 2),
+ lambda: MyTestEnv(size=size, sleep=sleep * 3),
+ lambda: MyTestEnv(size=size, sleep=sleep * 5),
+ lambda: MyTestEnv(size=size, sleep=sleep * 7)
+ ]
test_cls = [SubprocVectorEnv, ShmemVectorEnv]
if has_ray():
test_cls += [RayVectorEnv]
@@ -113,8 +113,10 @@ def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7):
t = time.time() - t
ids = Batch(info).env_id
print(ids, t)
- if not (len(ids) == len(res) and np.allclose(sorted(ids), res)
- and (t < timeout) == (len(res) == num - 1)):
+ if not (
+ len(ids) == len(res) and np.allclose(sorted(ids), res) and
+ (t < timeout) == (len(res) == num - 1)
+ ):
pass_check = 0
break
total_pass += pass_check
@@ -138,7 +140,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
v.seed(0)
action_list = [1] * 5 + [0] * 10 + [1] * 20
o = [v.reset() for v in venv]
- for i, a in enumerate(action_list):
+ for a in action_list:
o = []
for v in venv:
A, B, C, D = v.step([a] * num)
@@ -150,6 +152,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
continue
for info in infos:
assert recurse_comp(infos[0], info)
+
if __name__ == '__main__':
t = [0] * len(venv)
for i, e in enumerate(venv):
@@ -162,6 +165,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
t[i] = time.time() - t[i]
for i, v in enumerate(venv):
print(f'{type(v)}: {t[i]:.6f}s')
+
for v in venv:
assert v.size == list(range(size, size + num))
assert v.env_num == num
@@ -172,8 +176,9 @@ def test_vecenv(size=10, num=8, sleep=0.001):
def test_env_obs():
for obs_type in ["array", "object"]:
- envs = SubprocVectorEnv([
- lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]])
+ envs = SubprocVectorEnv(
+ [lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]]
+ )
obs = envs.reset()
assert obs.dtype == object
obs = envs.step([1, 1, 1, 1])[0]
diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py
index b670d65e9..54b438507 100644
--- a/test/base/test_env_finite.py
+++ b/test/base/test_env_finite.py
@@ -1,17 +1,19 @@
# see issue #322 for detail
-import gym
import copy
-import numpy as np
from collections import Counter
-from torch.utils.data import Dataset, DataLoader, DistributedSampler
-from tianshou.policy import BasePolicy
-from tianshou.data import Collector, Batch
+import gym
+import numpy as np
+from torch.utils.data import DataLoader, Dataset, DistributedSampler
+
+from tianshou.data import Batch, Collector
from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv
+from tianshou.policy import BasePolicy
class DummyDataset(Dataset):
+
def __init__(self, length):
self.length = length
self.episodes = [3 * i % 5 + 1 for i in range(self.length)]
@@ -25,6 +27,7 @@ def __len__(self):
class FiniteEnv(gym.Env):
+
def __init__(self, dataset, num_replicas, rank):
self.dataset = dataset
self.num_replicas = num_replicas
@@ -32,7 +35,8 @@ def __init__(self, dataset, num_replicas, rank):
self.loader = DataLoader(
dataset,
sampler=DistributedSampler(dataset, num_replicas, rank),
- batch_size=None)
+ batch_size=None
+ )
self.iterator = None
def reset(self):
@@ -54,6 +58,7 @@ def step(self, action):
class FiniteVectorEnv(BaseVectorEnv):
+
def __init__(self, env_fns, **kwargs):
super().__init__(env_fns, **kwargs)
self._alive_env_ids = set()
@@ -79,6 +84,7 @@ def _get_default_obs(self):
def _get_default_info(self):
return copy.deepcopy(self._default_info)
+
# END
def reset(self, id=None):
@@ -147,6 +153,7 @@ class FiniteSubprocVectorEnv(FiniteVectorEnv, SubprocVectorEnv):
class AnyPolicy(BasePolicy):
+
def forward(self, batch, state=None):
return Batch(act=np.stack([1] * len(batch)))
@@ -159,6 +166,7 @@ def _finite_env_factory(dataset, num_replicas, rank):
class MetricTracker:
+
def __init__(self):
self.counter = Counter()
self.finished = set()
@@ -179,30 +187,32 @@ def validate(self):
def test_finite_dummy_vector_env():
dataset = DummyDataset(100)
- envs = FiniteSubprocVectorEnv([
- _finite_env_factory(dataset, 5, i) for i in range(5)])
+ envs = FiniteSubprocVectorEnv(
+ [_finite_env_factory(dataset, 5, i) for i in range(5)]
+ )
policy = AnyPolicy()
test_collector = Collector(policy, envs, exploration_noise=True)
for _ in range(3):
envs.tracker = MetricTracker()
try:
- test_collector.collect(n_step=10 ** 18)
+ test_collector.collect(n_step=10**18)
except StopIteration:
envs.tracker.validate()
def test_finite_subproc_vector_env():
dataset = DummyDataset(100)
- envs = FiniteSubprocVectorEnv([
- _finite_env_factory(dataset, 5, i) for i in range(5)])
+ envs = FiniteSubprocVectorEnv(
+ [_finite_env_factory(dataset, 5, i) for i in range(5)]
+ )
policy = AnyPolicy()
test_collector = Collector(policy, envs, exploration_noise=True)
for _ in range(3):
envs.tracker = MetricTracker()
try:
- test_collector.collect(n_step=10 ** 18)
+ test_collector.collect(n_step=10**18)
except StopIteration:
envs.tracker.validate()
diff --git a/test/base/test_returns.py b/test/base/test_returns.py
index a104eb674..3adcdaf5c 100644
--- a/test/base/test_returns.py
+++ b/test/base/test_returns.py
@@ -1,9 +1,10 @@
-import torch
-import numpy as np
from timeit import timeit
-from tianshou.policy import BasePolicy
+import numpy as np
+import torch
+
from tianshou.data import Batch, ReplayBuffer, to_numpy
+from tianshou.policy import BasePolicy
def compute_episodic_return_base(batch, gamma):
@@ -24,8 +25,12 @@ def test_episodic_returns(size=2560):
batch = Batch(
done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]),
rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]),
- info=Batch({'TimeLimit.truncated':
- np.array([False, False, False, False, False, True, False, False])})
+ info=Batch(
+ {
+ 'TimeLimit.truncated':
+ np.array([False, False, False, False, False, True, False, False])
+ }
+ )
)
for b in batch:
b.obs = b.act = 1
@@ -65,28 +70,40 @@ def test_episodic_returns(size=2560):
buf.add(b)
v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
returns, _ = fn(batch, buf, buf.sample_indices(0), v, gamma=0.99, gae_lambda=0.95)
- ground_truth = np.array([
- 454.8344, 376.1143, 291.298, 200.,
- 464.5610, 383.1085, 295.387, 201.,
- 474.2876, 390.1027, 299.476, 202.])
+ ground_truth = np.array(
+ [
+ 454.8344, 376.1143, 291.298, 200., 464.5610, 383.1085, 295.387, 201.,
+ 474.2876, 390.1027, 299.476, 202.
+ ]
+ )
assert np.allclose(returns, ground_truth)
buf.reset()
batch = Batch(
done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
rew=np.array([101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]),
- info=Batch({'TimeLimit.truncated':
- np.array([False, False, False, True, False, False,
- False, True, False, False, False, False])})
+ info=Batch(
+ {
+ 'TimeLimit.truncated':
+ np.array(
+ [
+ False, False, False, True, False, False, False, True, False,
+ False, False, False
+ ]
+ )
+ }
+ )
)
for b in batch:
b.obs = b.act = 1
buf.add(b)
v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
returns, _ = fn(batch, buf, buf.sample_indices(0), v, gamma=0.99, gae_lambda=0.95)
- ground_truth = np.array([
- 454.0109, 375.2386, 290.3669, 199.01,
- 462.9138, 381.3571, 293.5248, 199.02,
- 474.2876, 390.1027, 299.476, 202.])
+ ground_truth = np.array(
+ [
+ 454.0109, 375.2386, 290.3669, 199.01, 462.9138, 381.3571, 293.5248, 199.02,
+ 474.2876, 390.1027, 299.476, 202.
+ ]
+ )
assert np.allclose(returns, ground_truth)
if __name__ == '__main__':
@@ -129,16 +146,17 @@ def compute_nstep_return_base(nstep, gamma, buffer, indices):
real_step_n = nstep
for n in range(nstep):
idx = (indices[i] + n) % buf_len
- r += buffer.rew[idx] * gamma ** n
+ r += buffer.rew[idx] * gamma**n
if buffer.done[idx]:
- if not (hasattr(buffer, 'info') and
- buffer.info['TimeLimit.truncated'][idx]):
+ if not (
+ hasattr(buffer, 'info') and buffer.info['TimeLimit.truncated'][idx]
+ ):
flag = True
real_step_n = n + 1
break
if not flag:
idx = (indices[i] + real_step_n - 1) % buf_len
- r += to_numpy(target_q_fn(buffer, idx)) * gamma ** real_step_n
+ r += to_numpy(target_q_fn(buffer, idx)) * gamma**real_step_n
returns[i] = r
return returns
@@ -152,89 +170,128 @@ def test_nstep_returns(size=10000):
# rew: [11, 12, 3, 4, 5, 6, 7, 8, 9, 10]
# done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0]
# test nstep = 1
- returns = to_numpy(BasePolicy.compute_nstep_return(
- batch, buf, indices, target_q_fn, gamma=.1, n_step=1
- ).pop('returns').reshape(-1))
+ returns = to_numpy(
+ BasePolicy.compute_nstep_return(
+ batch, buf, indices, target_q_fn, gamma=.1, n_step=1
+ ).pop('returns').reshape(-1)
+ )
assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12])
r_ = compute_nstep_return_base(1, .1, buf, indices)
assert np.allclose(returns, r_), (r_, returns)
- returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
- batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=1
- ).pop('returns'))
+ returns_multidim = to_numpy(
+ BasePolicy.compute_nstep_return(
+ batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=1
+ ).pop('returns')
+ )
assert np.allclose(returns_multidim, returns[:, np.newaxis])
# test nstep = 2
- returns = to_numpy(BasePolicy.compute_nstep_return(
- batch, buf, indices, target_q_fn, gamma=.1, n_step=2
- ).pop('returns').reshape(-1))
+ returns = to_numpy(
+ BasePolicy.compute_nstep_return(
+ batch, buf, indices, target_q_fn, gamma=.1, n_step=2
+ ).pop('returns').reshape(-1)
+ )
assert np.allclose(returns, [3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12])
r_ = compute_nstep_return_base(2, .1, buf, indices)
assert np.allclose(returns, r_)
- returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
- batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=2
- ).pop('returns'))
+ returns_multidim = to_numpy(
+ BasePolicy.compute_nstep_return(
+ batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=2
+ ).pop('returns')
+ )
assert np.allclose(returns_multidim, returns[:, np.newaxis])
# test nstep = 10
- returns = to_numpy(BasePolicy.compute_nstep_return(
- batch, buf, indices, target_q_fn, gamma=.1, n_step=10
- ).pop('returns').reshape(-1))
+ returns = to_numpy(
+ BasePolicy.compute_nstep_return(
+ batch, buf, indices, target_q_fn, gamma=.1, n_step=10
+ ).pop('returns').reshape(-1)
+ )
assert np.allclose(returns, [3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12])
r_ = compute_nstep_return_base(10, .1, buf, indices)
assert np.allclose(returns, r_)
- returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
- batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=10
- ).pop('returns'))
+ returns_multidim = to_numpy(
+ BasePolicy.compute_nstep_return(
+ batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=10
+ ).pop('returns')
+ )
assert np.allclose(returns_multidim, returns[:, np.newaxis])
def test_nstep_returns_with_timelimit(size=10000):
buf = ReplayBuffer(10)
for i in range(12):
- buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3,
- info={"TimeLimit.truncated": i == 3}))
+ buf.add(
+ Batch(
+ obs=0,
+ act=0,
+ rew=i + 1,
+ done=i % 4 == 3,
+ info={"TimeLimit.truncated": i == 3}
+ )
+ )
batch, indices = buf.sample(0)
assert np.allclose(indices, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1])
# rew: [11, 12, 3, 4, 5, 6, 7, 8, 9, 10]
# done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0]
# test nstep = 1
- returns = to_numpy(BasePolicy.compute_nstep_return(
- batch, buf, indices, target_q_fn, gamma=.1, n_step=1
- ).pop('returns').reshape(-1))
+ returns = to_numpy(
+ BasePolicy.compute_nstep_return(
+ batch, buf, indices, target_q_fn, gamma=.1, n_step=1
+ ).pop('returns').reshape(-1)
+ )
assert np.allclose(returns, [2.6, 3.6, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12])
r_ = compute_nstep_return_base(1, .1, buf, indices)
assert np.allclose(returns, r_), (r_, returns)
- returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
- batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=1
- ).pop('returns'))
+ returns_multidim = to_numpy(
+ BasePolicy.compute_nstep_return(
+ batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=1
+ ).pop('returns')
+ )
assert np.allclose(returns_multidim, returns[:, np.newaxis])
# test nstep = 2
- returns = to_numpy(BasePolicy.compute_nstep_return(
- batch, buf, indices, target_q_fn, gamma=.1, n_step=2
- ).pop('returns').reshape(-1))
+ returns = to_numpy(
+ BasePolicy.compute_nstep_return(
+ batch, buf, indices, target_q_fn, gamma=.1, n_step=2
+ ).pop('returns').reshape(-1)
+ )
assert np.allclose(returns, [3.36, 3.6, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12])
r_ = compute_nstep_return_base(2, .1, buf, indices)
assert np.allclose(returns, r_)
- returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
- batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=2
- ).pop('returns'))
+ returns_multidim = to_numpy(
+ BasePolicy.compute_nstep_return(
+ batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=2
+ ).pop('returns')
+ )
assert np.allclose(returns_multidim, returns[:, np.newaxis])
# test nstep = 10
- returns = to_numpy(BasePolicy.compute_nstep_return(
- batch, buf, indices, target_q_fn, gamma=.1, n_step=10
- ).pop('returns').reshape(-1))
- assert np.allclose(returns, [3.36, 3.6, 5.678, 6.78,
- 7.8, 8, 10.122, 11.22, 12.2, 12])
+ returns = to_numpy(
+ BasePolicy.compute_nstep_return(
+ batch, buf, indices, target_q_fn, gamma=.1, n_step=10
+ ).pop('returns').reshape(-1)
+ )
+ assert np.allclose(
+ returns, [3.36, 3.6, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12]
+ )
r_ = compute_nstep_return_base(10, .1, buf, indices)
assert np.allclose(returns, r_)
- returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
- batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=10
- ).pop('returns'))
+ returns_multidim = to_numpy(
+ BasePolicy.compute_nstep_return(
+ batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=10
+ ).pop('returns')
+ )
assert np.allclose(returns_multidim, returns[:, np.newaxis])
if __name__ == '__main__':
buf = ReplayBuffer(size)
for i in range(int(size * 1.5)):
- buf.add(Batch(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0,
- info={"TimeLimit.truncated": i % 33 == 0}))
+ buf.add(
+ Batch(
+ obs=0,
+ act=0,
+ rew=i + 1,
+ done=np.random.randint(3) == 0,
+ info={"TimeLimit.truncated": i % 33 == 0}
+ )
+ )
batch, indices = buf.sample(256)
def vanilla():
@@ -242,7 +299,8 @@ def vanilla():
def optimized():
return BasePolicy.compute_nstep_return(
- batch, buf, indices, target_q_fn, gamma=.1, n_step=3)
+ batch, buf, indices, target_q_fn, gamma=.1, n_step=3
+ )
cnt = 3000
print('nstep vanilla', timeit(vanilla, setup=vanilla, number=cnt))
diff --git a/test/base/test_utils.py b/test/base/test_utils.py
index d0e1cefb3..38bf5d40e 100644
--- a/test/base/test_utils.py
+++ b/test/base/test_utils.py
@@ -1,9 +1,9 @@
-import torch
import numpy as np
+import torch
-from tianshou.utils.net.common import MLP, Net
-from tianshou.utils import MovAvg, RunningMeanStd
from tianshou.exploration import GaussianNoise, OUNoise
+from tianshou.utils import MovAvg, RunningMeanStd
+from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
@@ -20,14 +20,14 @@ def test_moving_average():
stat = MovAvg(10)
assert np.allclose(stat.get(), 0)
assert np.allclose(stat.mean(), 0)
- assert np.allclose(stat.std() ** 2, 0)
+ assert np.allclose(stat.std()**2, 0)
stat.add(torch.tensor([1]))
stat.add(np.array([2]))
stat.add([3, 4])
stat.add(5.)
assert np.allclose(stat.get(), 3)
assert np.allclose(stat.mean(), 3)
- assert np.allclose(stat.std() ** 2, 2)
+ assert np.allclose(stat.std()**2, 2)
def test_rms():
@@ -55,23 +55,36 @@ def test_net():
action_shape = (5, )
data = torch.rand([bsz, *state_shape])
expect_output_shape = [bsz, *action_shape]
- net = Net(state_shape, action_shape, hidden_sizes=[128, 128],
- norm_layer=torch.nn.LayerNorm, activation=None)
+ net = Net(
+ state_shape,
+ action_shape,
+ hidden_sizes=[128, 128],
+ norm_layer=torch.nn.LayerNorm,
+ activation=None
+ )
assert list(net(data)[0].shape) == expect_output_shape
assert str(net).count("LayerNorm") == 2
assert str(net).count("ReLU") == 0
Q_param = V_param = {"hidden_sizes": [128, 128]}
- net = Net(state_shape, action_shape, hidden_sizes=[128, 128],
- dueling_param=(Q_param, V_param))
+ net = Net(
+ state_shape,
+ action_shape,
+ hidden_sizes=[128, 128],
+ dueling_param=(Q_param, V_param)
+ )
assert list(net(data)[0].shape) == expect_output_shape
# concat
- net = Net(state_shape, action_shape, hidden_sizes=[128],
- concat=True)
+ net = Net(state_shape, action_shape, hidden_sizes=[128], concat=True)
data = torch.rand([bsz, np.prod(state_shape) + np.prod(action_shape)])
expect_output_shape = [bsz, 128]
assert list(net(data)[0].shape) == expect_output_shape
- net = Net(state_shape, action_shape, hidden_sizes=[128],
- concat=True, dueling_param=(Q_param, V_param))
+ net = Net(
+ state_shape,
+ action_shape,
+ hidden_sizes=[128],
+ concat=True,
+ dueling_param=(Q_param, V_param)
+ )
assert list(net(data)[0].shape) == expect_output_shape
# recurrent actor/critic
data = torch.rand([bsz, *state_shape]).flatten(1)
diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py
index 3dbeb9091..e88c0869c 100644
--- a/test/continuous/test_ddpg.py
+++ b/test/continuous/test_ddpg.py
@@ -1,18 +1,19 @@
+import argparse
import os
-import gym
-import torch
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
+from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.env import DummyVectorEnv
+from tianshou.exploration import GaussianNoise
from tianshou.policy import DDPGPolicy
+from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
-from tianshou.trainer import offpolicy_trainer
-from tianshou.exploration import GaussianNoise
-from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils.net.continuous import Actor, Critic
@@ -39,8 +40,8 @@ def get_args():
parser.add_argument('--rew-norm', action="store_true", default=False)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
args = parser.parse_known_args()[0]
return args
@@ -56,36 +57,51 @@ def test_ddpg(args=get_args()):
# you can also use tianshou.env.SubprocVectorEnv
# train_envs = gym.make(args.task)
train_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- device=args.device)
- actor = Actor(net, args.action_shape, max_action=args.max_action,
- device=args.device).to(args.device)
+ net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
+ actor = Actor(
+ net, args.action_shape, max_action=args.max_action, device=args.device
+ ).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
- net = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes, concat=True, device=args.device)
+ net = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ concat=True,
+ device=args.device
+ )
critic = Critic(net, device=args.device).to(args.device)
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
policy = DDPGPolicy(
- actor, actor_optim, critic, critic_optim,
- tau=args.tau, gamma=args.gamma,
+ actor,
+ actor_optim,
+ critic,
+ critic_optim,
+ tau=args.tau,
+ gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
reward_normalization=args.rew_norm,
- estimation_step=args.n_step, action_space=env.action_space)
+ estimation_step=args.n_step,
+ action_space=env.action_space
+ )
# collector
train_collector = Collector(
- policy, train_envs,
+ policy,
+ train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
- exploration_noise=True)
+ exploration_noise=True
+ )
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'ddpg')
@@ -100,10 +116,19 @@ def stop_fn(mean_rewards):
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
- update_per_step=args.update_per_step, stop_fn=stop_fn,
- save_fn=save_fn, logger=logger)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ update_per_step=args.update_per_step,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py
index ad758897f..1a9e82623 100644
--- a/test/continuous/test_npg.py
+++ b/test/continuous/test_npg.py
@@ -1,19 +1,20 @@
+import argparse
import os
-import gym
-import torch
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch import nn
-from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal
+from torch.utils.tensorboard import SummaryWriter
+from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.env import DummyVectorEnv
from tianshou.policy import NPGPolicy
+from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
-from tianshou.trainer import onpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils.net.continuous import ActorProb, Critic
@@ -27,8 +28,9 @@ def get_args():
parser.add_argument('--epoch', type=int, default=5)
parser.add_argument('--step-per-epoch', type=int, default=50000)
parser.add_argument('--step-per-collect', type=int, default=2048)
- parser.add_argument('--repeat-per-collect', type=int,
- default=2) # theoretically it should be 1
+ parser.add_argument(
+ '--repeat-per-collect', type=int, default=2
+ ) # theoretically it should be 1
parser.add_argument('--batch-size', type=int, default=99999)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--training-num', type=int, default=16)
@@ -36,8 +38,8 @@ def get_args():
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
# npg special
parser.add_argument('--gae-lambda', type=float, default=0.95)
parser.add_argument('--rew-norm', type=int, default=1)
@@ -58,23 +60,40 @@ def test_npg(args=get_args()):
# you can also use tianshou.env.SubprocVectorEnv
# train_envs = gym.make(args.task)
train_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- activation=nn.Tanh, device=args.device)
- actor = ActorProb(net, args.action_shape, max_action=args.max_action,
- unbounded=True, device=args.device).to(args.device)
- critic = Critic(Net(
- args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device,
- activation=nn.Tanh), device=args.device).to(args.device)
+ net = Net(
+ args.state_shape,
+ hidden_sizes=args.hidden_sizes,
+ activation=nn.Tanh,
+ device=args.device
+ )
+ actor = ActorProb(
+ net,
+ args.action_shape,
+ max_action=args.max_action,
+ unbounded=True,
+ device=args.device
+ ).to(args.device)
+ critic = Critic(
+ Net(
+ args.state_shape,
+ hidden_sizes=args.hidden_sizes,
+ device=args.device,
+ activation=nn.Tanh
+ ),
+ device=args.device
+ ).to(args.device)
# orthogonal initialization
for m in list(actor.modules()) + list(critic.modules()):
if isinstance(m, torch.nn.Linear):
@@ -88,7 +107,10 @@ def dist(*logits):
return Independent(Normal(*logits), 1)
policy = NPGPolicy(
- actor, critic, optim, dist,
+ actor,
+ critic,
+ optim,
+ dist,
discount_factor=args.gamma,
reward_normalization=args.rew_norm,
advantage_normalization=args.norm_adv,
@@ -96,11 +118,12 @@ def dist(*logits):
action_space=env.action_space,
optim_critic_iters=args.optim_critic_iters,
actor_step_size=args.actor_step_size,
- deterministic_eval=True)
+ deterministic_eval=True
+ )
# collector
train_collector = Collector(
- policy, train_envs,
- VectorReplayBuffer(args.buffer_size, len(train_envs)))
+ policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))
+ )
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'npg')
@@ -115,10 +138,19 @@ def stop_fn(mean_rewards):
# trainer
result = onpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
- step_per_collect=args.step_per_collect, stop_fn=stop_fn, save_fn=save_fn,
- logger=logger)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.repeat_per_collect,
+ args.test_num,
+ args.batch_size,
+ step_per_collect=args.step_per_collect,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py
index 09c73fdd5..473222816 100644
--- a/test/continuous/test_ppo.py
+++ b/test/continuous/test_ppo.py
@@ -1,18 +1,19 @@
+import argparse
import os
-import gym
-import torch
import pprint
-import argparse
+
+import gym
import numpy as np
-from torch.utils.tensorboard import SummaryWriter
+import torch
from torch.distributions import Independent, Normal
+from torch.utils.tensorboard import SummaryWriter
+from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.env import DummyVectorEnv
from tianshou.policy import PPOPolicy
+from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
-from tianshou.trainer import onpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils.net.continuous import ActorProb, Critic
@@ -34,8 +35,8 @@ def get_args():
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
# ppo special
parser.add_argument('--vf-coef', type=float, default=0.25)
parser.add_argument('--ent-coef', type=float, default=0.0)
@@ -63,30 +64,34 @@ def test_ppo(args=get_args()):
# you can also use tianshou.env.SubprocVectorEnv
# train_envs = gym.make(args.task)
train_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- device=args.device)
- actor = ActorProb(net, args.action_shape, max_action=args.max_action,
- device=args.device).to(args.device)
- critic = Critic(Net(
- args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device
- ), device=args.device).to(args.device)
+ net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
+ actor = ActorProb(
+ net, args.action_shape, max_action=args.max_action, device=args.device
+ ).to(args.device)
+ critic = Critic(
+ Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device),
+ device=args.device
+ ).to(args.device)
# orthogonal initialization
for m in set(actor.modules()).union(critic.modules()):
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(
- set(actor.parameters()).union(critic.parameters()), lr=args.lr)
+ set(actor.parameters()).union(critic.parameters()), lr=args.lr
+ )
# replace DiagGuassian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward
@@ -94,7 +99,10 @@ def dist(*logits):
return Independent(Normal(*logits), 1)
policy = PPOPolicy(
- actor, critic, optim, dist,
+ actor,
+ critic,
+ optim,
+ dist,
discount_factor=args.gamma,
max_grad_norm=args.max_grad_norm,
eps_clip=args.eps_clip,
@@ -107,11 +115,12 @@ def dist(*logits):
# dual clip cause monotonically increasing log_std :)
value_clip=args.value_clip,
gae_lambda=args.gae_lambda,
- action_space=env.action_space)
+ action_space=env.action_space
+ )
# collector
train_collector = Collector(
- policy, train_envs,
- VectorReplayBuffer(args.buffer_size, len(train_envs)))
+ policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))
+ )
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'ppo')
@@ -126,10 +135,12 @@ def stop_fn(mean_rewards):
def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
- torch.save({
- 'model': policy.state_dict(),
- 'optim': optim.state_dict(),
- }, os.path.join(log_path, 'checkpoint.pth'))
+ torch.save(
+ {
+ 'model': policy.state_dict(),
+ 'optim': optim.state_dict(),
+ }, os.path.join(log_path, 'checkpoint.pth')
+ )
if args.resume:
# load from existing checkpoint
@@ -145,11 +156,21 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
# trainer
result = onpolicy_trainer(
- policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
- args.repeat_per_collect, args.test_num, args.batch_size,
- episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
- logger=logger, resume_from_log=args.resume,
- save_checkpoint_fn=save_checkpoint_fn)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.repeat_per_collect,
+ args.test_num,
+ args.batch_size,
+ episode_per_collect=args.episode_per_collect,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger,
+ resume_from_log=args.resume,
+ save_checkpoint_fn=save_checkpoint_fn
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py
index 5b6a79492..da20290ec 100644
--- a/test/continuous/test_sac_with_il.py
+++ b/test/continuous/test_sac_with_il.py
@@ -1,17 +1,18 @@
+import argparse
import os
-import gym
-import torch
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
-from tianshou.utils import TensorboardLogger
+from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
-from tianshou.utils.net.common import Net
+from tianshou.policy import ImitationPolicy, SACPolicy
from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.policy import SACPolicy, ImitationPolicy
+from tianshou.utils import TensorboardLogger
+from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, ActorProb, Critic
@@ -34,10 +35,10 @@ def get_args():
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=128)
- parser.add_argument('--hidden-sizes', type=int,
- nargs='*', default=[128, 128])
- parser.add_argument('--imitation-hidden-sizes', type=int,
- nargs='*', default=[128, 128])
+ parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128])
+ parser.add_argument(
+ '--imitation-hidden-sizes', type=int, nargs='*', default=[128, 128]
+ )
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
@@ -45,8 +46,8 @@ def get_args():
parser.add_argument('--rew-norm', action="store_true", default=False)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
args = parser.parse_known_args()[0]
return args
@@ -62,29 +63,43 @@ def test_sac_with_il(args=get_args()):
# you can also use tianshou.env.SubprocVectorEnv
# train_envs = gym.make(args.task)
train_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- device=args.device)
- actor = ActorProb(net, args.action_shape, max_action=args.max_action,
- device=args.device, unbounded=True).to(args.device)
+ net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
+ actor = ActorProb(
+ net,
+ args.action_shape,
+ max_action=args.max_action,
+ device=args.device,
+ unbounded=True
+ ).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
- net_c1 = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes,
- concat=True, device=args.device)
+ net_c1 = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ concat=True,
+ device=args.device
+ )
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
- net_c2 = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes,
- concat=True, device=args.device)
+ net_c2 = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ concat=True,
+ device=args.device
+ )
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
@@ -95,15 +110,26 @@ def test_sac_with_il(args=get_args()):
args.alpha = (target_entropy, log_alpha, alpha_optim)
policy = SACPolicy(
- actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
- tau=args.tau, gamma=args.gamma, alpha=args.alpha,
+ actor,
+ actor_optim,
+ critic1,
+ critic1_optim,
+ critic2,
+ critic2_optim,
+ tau=args.tau,
+ gamma=args.gamma,
+ alpha=args.alpha,
reward_normalization=args.rew_norm,
- estimation_step=args.n_step, action_space=env.action_space)
+ estimation_step=args.n_step,
+ action_space=env.action_space
+ )
# collector
train_collector = Collector(
- policy, train_envs,
+ policy,
+ train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
- exploration_noise=True)
+ exploration_noise=True
+ )
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.buffer_size)
# log
@@ -119,10 +145,19 @@ def stop_fn(mean_rewards):
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
- update_per_step=args.update_per_step, stop_fn=stop_fn,
- save_fn=save_fn, logger=logger)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ update_per_step=args.update_per_step,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
@@ -140,23 +175,41 @@ def stop_fn(mean_rewards):
if args.task == 'Pendulum-v0':
env.spec.reward_threshold = -300 # lower the goal
net = Actor(
- Net(args.state_shape, hidden_sizes=args.imitation_hidden_sizes,
- device=args.device),
- args.action_shape, max_action=args.max_action, device=args.device
+ Net(
+ args.state_shape,
+ hidden_sizes=args.imitation_hidden_sizes,
+ device=args.device
+ ),
+ args.action_shape,
+ max_action=args.max_action,
+ device=args.device
).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
il_policy = ImitationPolicy(
- net, optim, action_space=env.action_space,
- action_scaling=True, action_bound_method="clip")
+ net,
+ optim,
+ action_space=env.action_space,
+ action_scaling=True,
+ action_bound_method="clip"
+ )
il_test_collector = Collector(
il_policy,
DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
)
train_collector.reset()
result = offpolicy_trainer(
- il_policy, train_collector, il_test_collector, args.epoch,
- args.il_step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
+ il_policy,
+ train_collector,
+ il_test_collector,
+ args.epoch,
+ args.il_step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py
index 8bae1edfa..2e3ef7ba7 100644
--- a/test/continuous/test_td3.py
+++ b/test/continuous/test_td3.py
@@ -1,18 +1,19 @@
+import argparse
import os
-import gym
-import torch
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
+from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.env import DummyVectorEnv
+from tianshou.exploration import GaussianNoise
from tianshou.policy import TD3Policy
+from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
-from tianshou.trainer import offpolicy_trainer
-from tianshou.exploration import GaussianNoise
-from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils.net.continuous import Actor, Critic
@@ -42,8 +43,8 @@ def get_args():
parser.add_argument('--rew-norm', action="store_true", default=False)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
args = parser.parse_known_args()[0]
return args
@@ -59,46 +60,65 @@ def test_td3(args=get_args()):
# you can also use tianshou.env.SubprocVectorEnv
# train_envs = gym.make(args.task)
train_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- device=args.device)
- actor = Actor(net, args.action_shape, max_action=args.max_action,
- device=args.device).to(args.device)
+ net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
+ actor = Actor(
+ net, args.action_shape, max_action=args.max_action, device=args.device
+ ).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
- net_c1 = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes,
- concat=True, device=args.device)
+ net_c1 = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ concat=True,
+ device=args.device
+ )
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
- net_c2 = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes,
- concat=True, device=args.device)
+ net_c2 = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ concat=True,
+ device=args.device
+ )
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = TD3Policy(
- actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
- tau=args.tau, gamma=args.gamma,
+ actor,
+ actor_optim,
+ critic1,
+ critic1_optim,
+ critic2,
+ critic2_optim,
+ tau=args.tau,
+ gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
policy_noise=args.policy_noise,
update_actor_freq=args.update_actor_freq,
noise_clip=args.noise_clip,
reward_normalization=args.rew_norm,
estimation_step=args.n_step,
- action_space=env.action_space)
+ action_space=env.action_space
+ )
# collector
train_collector = Collector(
- policy, train_envs,
+ policy,
+ train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
- exploration_noise=True)
+ exploration_noise=True
+ )
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.buffer_size)
# log
@@ -114,10 +134,19 @@ def stop_fn(mean_rewards):
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
- update_per_step=args.update_per_step, stop_fn=stop_fn,
- save_fn=save_fn, logger=logger)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ update_per_step=args.update_per_step,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py
index 65535fd50..4a4206f5f 100644
--- a/test/continuous/test_trpo.py
+++ b/test/continuous/test_trpo.py
@@ -1,19 +1,20 @@
+import argparse
import os
-import gym
-import torch
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch import nn
-from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal
+from torch.utils.tensorboard import SummaryWriter
+from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.env import DummyVectorEnv
from tianshou.policy import TRPOPolicy
+from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
-from tianshou.trainer import onpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils.net.continuous import ActorProb, Critic
@@ -27,8 +28,9 @@ def get_args():
parser.add_argument('--epoch', type=int, default=5)
parser.add_argument('--step-per-epoch', type=int, default=50000)
parser.add_argument('--step-per-collect', type=int, default=2048)
- parser.add_argument('--repeat-per-collect', type=int,
- default=2) # theoretically it should be 1
+ parser.add_argument(
+ '--repeat-per-collect', type=int, default=2
+ ) # theoretically it should be 1
parser.add_argument('--batch-size', type=int, default=99999)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--training-num', type=int, default=16)
@@ -36,8 +38,8 @@ def get_args():
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
# trpo special
parser.add_argument('--gae-lambda', type=float, default=0.95)
parser.add_argument('--rew-norm', type=int, default=1)
@@ -61,23 +63,40 @@ def test_trpo(args=get_args()):
# you can also use tianshou.env.SubprocVectorEnv
# train_envs = gym.make(args.task)
train_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- activation=nn.Tanh, device=args.device)
- actor = ActorProb(net, args.action_shape, max_action=args.max_action,
- unbounded=True, device=args.device).to(args.device)
- critic = Critic(Net(
- args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device,
- activation=nn.Tanh), device=args.device).to(args.device)
+ net = Net(
+ args.state_shape,
+ hidden_sizes=args.hidden_sizes,
+ activation=nn.Tanh,
+ device=args.device
+ )
+ actor = ActorProb(
+ net,
+ args.action_shape,
+ max_action=args.max_action,
+ unbounded=True,
+ device=args.device
+ ).to(args.device)
+ critic = Critic(
+ Net(
+ args.state_shape,
+ hidden_sizes=args.hidden_sizes,
+ device=args.device,
+ activation=nn.Tanh
+ ),
+ device=args.device
+ ).to(args.device)
# orthogonal initialization
for m in list(actor.modules()) + list(critic.modules()):
if isinstance(m, torch.nn.Linear):
@@ -91,7 +110,10 @@ def dist(*logits):
return Independent(Normal(*logits), 1)
policy = TRPOPolicy(
- actor, critic, optim, dist,
+ actor,
+ critic,
+ optim,
+ dist,
discount_factor=args.gamma,
reward_normalization=args.rew_norm,
advantage_normalization=args.norm_adv,
@@ -100,11 +122,12 @@ def dist(*logits):
optim_critic_iters=args.optim_critic_iters,
max_kl=args.max_kl,
backtrack_coeff=args.backtrack_coeff,
- max_backtracks=args.max_backtracks)
+ max_backtracks=args.max_backtracks
+ )
# collector
train_collector = Collector(
- policy, train_envs,
- VectorReplayBuffer(args.buffer_size, len(train_envs)))
+ policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))
+ )
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'trpo')
@@ -119,10 +142,19 @@ def stop_fn(mean_rewards):
# trainer
result = onpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
- step_per_collect=args.step_per_collect, stop_fn=stop_fn, save_fn=save_fn,
- logger=logger)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.repeat_per_collect,
+ args.test_num,
+ args.batch_size,
+ step_per_collect=args.step_per_collect,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py
index d11ce360c..745295826 100644
--- a/test/discrete/test_a2c_with_il.py
+++ b/test/discrete/test_a2c_with_il.py
@@ -1,18 +1,19 @@
+import argparse
import os
-import gym
-import torch
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
-from tianshou.utils import TensorboardLogger
+from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
+from tianshou.policy import A2CPolicy, ImitationPolicy
+from tianshou.trainer import offpolicy_trainer, onpolicy_trainer
+from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
-from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils.net.discrete import Actor, Critic
-from tianshou.policy import A2CPolicy, ImitationPolicy
-from tianshou.trainer import onpolicy_trainer, offpolicy_trainer
def get_args():
@@ -31,17 +32,15 @@ def get_args():
parser.add_argument('--update-per-step', type=float, default=1 / 16)
parser.add_argument('--repeat-per-collect', type=int, default=1)
parser.add_argument('--batch-size', type=int, default=64)
- parser.add_argument('--hidden-sizes', type=int,
- nargs='*', default=[64, 64])
- parser.add_argument('--imitation-hidden-sizes', type=int,
- nargs='*', default=[128])
+ parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
+ parser.add_argument('--imitation-hidden-sizes', type=int, nargs='*', default=[128])
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
# a2c special
parser.add_argument('--vf-coef', type=float, default=0.5)
parser.add_argument('--ent-coef', type=float, default=0.0)
@@ -60,33 +59,42 @@ def test_a2c_with_il(args=get_args()):
# you can also use tianshou.env.SubprocVectorEnv
# train_envs = gym.make(args.task)
train_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- device=args.device)
+ net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
actor = Actor(net, args.action_shape, device=args.device).to(args.device)
critic = Critic(net, device=args.device).to(args.device)
optim = torch.optim.Adam(
- set(actor.parameters()).union(critic.parameters()), lr=args.lr)
+ set(actor.parameters()).union(critic.parameters()), lr=args.lr
+ )
dist = torch.distributions.Categorical
policy = A2CPolicy(
- actor, critic, optim, dist,
- discount_factor=args.gamma, gae_lambda=args.gae_lambda,
- vf_coef=args.vf_coef, ent_coef=args.ent_coef,
- max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm,
- action_space=env.action_space)
+ actor,
+ critic,
+ optim,
+ dist,
+ discount_factor=args.gamma,
+ gae_lambda=args.gae_lambda,
+ vf_coef=args.vf_coef,
+ ent_coef=args.ent_coef,
+ max_grad_norm=args.max_grad_norm,
+ reward_normalization=args.rew_norm,
+ action_space=env.action_space
+ )
# collector
train_collector = Collector(
- policy, train_envs,
- VectorReplayBuffer(args.buffer_size, len(train_envs)))
+ policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))
+ )
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'a2c')
@@ -101,10 +109,19 @@ def stop_fn(mean_rewards):
# trainer
result = onpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
- episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
- logger=logger)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.repeat_per_collect,
+ args.test_num,
+ args.batch_size,
+ episode_per_collect=args.episode_per_collect,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
@@ -121,8 +138,7 @@ def stop_fn(mean_rewards):
# here we define an imitation collector with a trivial policy
if args.task == 'CartPole-v0':
env.spec.reward_threshold = 190 # lower the goal
- net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- device=args.device)
+ net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
net = Actor(net, args.action_shape, device=args.device).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
il_policy = ImitationPolicy(net, optim, action_space=env.action_space)
@@ -132,9 +148,18 @@ def stop_fn(mean_rewards):
)
train_collector.reset()
result = offpolicy_trainer(
- il_policy, train_collector, il_test_collector, args.epoch,
- args.il_step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
+ il_policy,
+ train_collector,
+ il_test_collector,
+ args.epoch,
+ args.il_step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py
index 3208e83c8..3c74723eb 100644
--- a/test/discrete/test_c51.py
+++ b/test/discrete/test_c51.py
@@ -1,18 +1,19 @@
+import argparse
import os
-import gym
-import torch
import pickle
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
+from tianshou.env import DummyVectorEnv
from tianshou.policy import C51Policy
+from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
-from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
def get_args():
@@ -34,20 +35,20 @@ def get_args():
parser.add_argument('--step-per-collect', type=int, default=8)
parser.add_argument('--update-per-step', type=float, default=0.125)
parser.add_argument('--batch-size', type=int, default=64)
- parser.add_argument('--hidden-sizes', type=int,
- nargs='*', default=[128, 128, 128, 128])
+ parser.add_argument(
+ '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]
+ )
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
- parser.add_argument('--prioritized-replay',
- action="store_true", default=False)
+ parser.add_argument('--prioritized-replay', action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument('--resume', action="store_true")
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
parser.add_argument("--save-interval", type=int, default=4)
args = parser.parse_known_args()[0]
return args
@@ -60,29 +61,45 @@ def test_c51(args=get_args()):
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- net = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes, device=args.device,
- softmax=True, num_atoms=args.num_atoms)
+ net = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ device=args.device,
+ softmax=True,
+ num_atoms=args.num_atoms
+ )
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = C51Policy(
- net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
- args.n_step, target_update_freq=args.target_update_freq
+ net,
+ optim,
+ args.gamma,
+ args.num_atoms,
+ args.v_min,
+ args.v_max,
+ args.n_step,
+ target_update_freq=args.target_update_freq
).to(args.device)
# buffer
if args.prioritized_replay:
buf = PrioritizedVectorReplayBuffer(
- args.buffer_size, buffer_num=len(train_envs),
- alpha=args.alpha, beta=args.beta)
+ args.buffer_size,
+ buffer_num=len(train_envs),
+ alpha=args.alpha,
+ beta=args.beta
+ )
else:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
# collector
@@ -117,12 +134,16 @@ def test_fn(epoch, env_step):
def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
- torch.save({
- 'model': policy.state_dict(),
- 'optim': optim.state_dict(),
- }, os.path.join(log_path, 'checkpoint.pth'))
- pickle.dump(train_collector.buffer,
- open(os.path.join(log_path, 'train_buffer.pkl'), "wb"))
+ torch.save(
+ {
+ 'model': policy.state_dict(),
+ 'optim': optim.state_dict(),
+ }, os.path.join(log_path, 'checkpoint.pth')
+ )
+ pickle.dump(
+ train_collector.buffer,
+ open(os.path.join(log_path, 'train_buffer.pkl'), "wb")
+ )
if args.resume:
# load from existing checkpoint
@@ -144,11 +165,23 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
- test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
- resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ update_per_step=args.update_per_step,
+ train_fn=train_fn,
+ test_fn=test_fn,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger,
+ resume_from_log=args.resume,
+ save_checkpoint_fn=save_checkpoint_fn
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py
index aae609ec6..6912a1933 100644
--- a/test/discrete/test_dqn.py
+++ b/test/discrete/test_dqn.py
@@ -1,18 +1,19 @@
+import argparse
import os
-import gym
-import torch
import pickle
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
+from tianshou.env import DummyVectorEnv
from tianshou.policy import DQNPolicy
+from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
-from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
def get_args():
@@ -31,22 +32,22 @@ def get_args():
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=64)
- parser.add_argument('--hidden-sizes', type=int,
- nargs='*', default=[128, 128, 128, 128])
+ parser.add_argument(
+ '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]
+ )
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
- parser.add_argument('--prioritized-replay',
- action="store_true", default=False)
+ parser.add_argument('--prioritized-replay', action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument(
- '--save-buffer-name', type=str,
- default="./expert_DQN_CartPole-v0.pkl")
+ '--save-buffer-name', type=str, default="./expert_DQN_CartPole-v0.pkl"
+ )
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
args = parser.parse_known_args()[0]
return args
@@ -58,10 +59,12 @@ def test_dqn(args=get_args()):
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
@@ -69,19 +72,29 @@ def test_dqn(args=get_args()):
test_envs.seed(args.seed)
# Q_param = V_param = {"hidden_sizes": [128]}
# model
- net = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes, device=args.device,
- # dueling=(Q_param, V_param),
- ).to(args.device)
+ net = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ device=args.device,
+ # dueling=(Q_param, V_param),
+ ).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = DQNPolicy(
- net, optim, args.gamma, args.n_step,
- target_update_freq=args.target_update_freq)
+ net,
+ optim,
+ args.gamma,
+ args.n_step,
+ target_update_freq=args.target_update_freq
+ )
# buffer
if args.prioritized_replay:
buf = PrioritizedVectorReplayBuffer(
- args.buffer_size, buffer_num=len(train_envs),
- alpha=args.alpha, beta=args.beta)
+ args.buffer_size,
+ buffer_num=len(train_envs),
+ alpha=args.alpha,
+ beta=args.beta
+ )
else:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
# collector
@@ -116,10 +129,21 @@ def test_fn(epoch, env_step):
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
- test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ update_per_step=args.update_per_step,
+ train_fn=train_fn,
+ test_fn=test_fn,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py
index aa4fbbe0f..064dbba24 100644
--- a/test/discrete/test_drqn.py
+++ b/test/discrete/test_drqn.py
@@ -1,17 +1,18 @@
+import argparse
import os
-import gym
-import torch
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
-from tianshou.policy import DQNPolicy
-from tianshou.utils import TensorboardLogger
+from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
+from tianshou.policy import DQNPolicy
from tianshou.trainer import offpolicy_trainer
+from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Recurrent
-from tianshou.data import Collector, VectorReplayBuffer
def get_args():
@@ -37,8 +38,8 @@ def get_args():
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
args = parser.parse_known_args()[0]
return args
@@ -50,26 +51,35 @@ def test_drqn(args=get_args()):
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- net = Recurrent(args.layer_num, args.state_shape,
- args.action_shape, args.device).to(args.device)
+ net = Recurrent(args.layer_num, args.state_shape, args.action_shape,
+ args.device).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = DQNPolicy(
- net, optim, args.gamma, args.n_step,
- target_update_freq=args.target_update_freq)
+ net,
+ optim,
+ args.gamma,
+ args.n_step,
+ target_update_freq=args.target_update_freq
+ )
# collector
buffer = VectorReplayBuffer(
- args.buffer_size, buffer_num=len(train_envs),
- stack_num=args.stack_num, ignore_obs_next=True)
+ args.buffer_size,
+ buffer_num=len(train_envs),
+ stack_num=args.stack_num,
+ ignore_obs_next=True
+ )
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
# the stack_num is for RNN training: sample framestack obs
test_collector = Collector(policy, test_envs, exploration_noise=True)
@@ -94,11 +104,21 @@ def test_fn(epoch, env_step):
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, update_per_step=args.update_per_step,
- train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn,
- save_fn=save_fn, logger=logger)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ update_per_step=args.update_per_step,
+ train_fn=train_fn,
+ test_fn=test_fn,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py
index 0df2efb74..1763380f1 100644
--- a/test/discrete/test_fqf.py
+++ b/test/discrete/test_fqf.py
@@ -1,18 +1,19 @@
+import argparse
import os
-import gym
-import torch
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
+from tianshou.env import DummyVectorEnv
from tianshou.policy import FQFPolicy
+from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
-from tianshou.trainer import offpolicy_trainer
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction
-from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
def get_args():
@@ -35,19 +36,17 @@ def get_args():
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=64)
- parser.add_argument('--hidden-sizes', type=int,
- nargs='*', default=[64, 64, 64])
+ parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64, 64])
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
- parser.add_argument('--prioritized-replay',
- action="store_true", default=False)
+ parser.add_argument('--prioritized-replay', action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
args = parser.parse_known_args()[0]
return args
@@ -59,22 +58,31 @@ def test_fqf(args=get_args()):
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- feature_net = Net(args.state_shape, args.hidden_sizes[-1],
- hidden_sizes=args.hidden_sizes[:-1], device=args.device,
- softmax=False)
+ feature_net = Net(
+ args.state_shape,
+ args.hidden_sizes[-1],
+ hidden_sizes=args.hidden_sizes[:-1],
+ device=args.device,
+ softmax=False
+ )
net = FullQuantileFunction(
- feature_net, args.action_shape, args.hidden_sizes,
- num_cosines=args.num_cosines, device=args.device
+ feature_net,
+ args.action_shape,
+ args.hidden_sizes,
+ num_cosines=args.num_cosines,
+ device=args.device
)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim)
@@ -82,14 +90,24 @@ def test_fqf(args=get_args()):
fraction_net.parameters(), lr=args.fraction_lr
)
policy = FQFPolicy(
- net, optim, fraction_net, fraction_optim, args.gamma, args.num_fractions,
- args.ent_coef, args.n_step, target_update_freq=args.target_update_freq
+ net,
+ optim,
+ fraction_net,
+ fraction_optim,
+ args.gamma,
+ args.num_fractions,
+ args.ent_coef,
+ args.n_step,
+ target_update_freq=args.target_update_freq
).to(args.device)
# buffer
if args.prioritized_replay:
buf = PrioritizedVectorReplayBuffer(
- args.buffer_size, buffer_num=len(train_envs),
- alpha=args.alpha, beta=args.beta)
+ args.buffer_size,
+ buffer_num=len(train_envs),
+ alpha=args.alpha,
+ beta=args.beta
+ )
else:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
# collector
@@ -124,11 +142,21 @@ def test_fn(epoch, env_step):
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, train_fn=train_fn, test_fn=test_fn,
- stop_fn=stop_fn, save_fn=save_fn, logger=logger,
- update_per_step=args.update_per_step)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ train_fn=train_fn,
+ test_fn=test_fn,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger,
+ update_per_step=args.update_per_step
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py
index f404ce7dc..47540dadd 100644
--- a/test/discrete/test_il_bcq.py
+++ b/test/discrete/test_il_bcq.py
@@ -1,18 +1,19 @@
+import argparse
import os
-import gym
-import torch
import pickle
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector
-from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
-from tianshou.utils.net.common import Net
-from tianshou.trainer import offline_trainer
from tianshou.policy import DiscreteBCQPolicy
+from tianshou.trainer import offline_trainer
+from tianshou.utils import TensorboardLogger
+from tianshou.utils.net.common import Net
def get_args():
@@ -29,17 +30,18 @@ def get_args():
parser.add_argument("--epoch", type=int, default=5)
parser.add_argument("--update-per-epoch", type=int, default=2000)
parser.add_argument("--batch-size", type=int, default=64)
- parser.add_argument('--hidden-sizes', type=int,
- nargs='*', default=[64, 64])
+ parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument("--test-num", type=int, default=100)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
- "--load-buffer-name", type=str,
+ "--load-buffer-name",
+ type=str,
default="./expert_DQN_CartPole-v0.pkl",
)
parser.add_argument(
- "--device", type=str,
+ "--device",
+ type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
)
parser.add_argument("--resume", action="store_true")
@@ -56,26 +58,39 @@ def test_discrete_bcq(args=get_args()):
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
test_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
policy_net = Net(
- args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes, device=args.device).to(args.device)
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ device=args.device
+ ).to(args.device)
imitation_net = Net(
- args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes, device=args.device).to(args.device)
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ device=args.device
+ ).to(args.device)
optim = torch.optim.Adam(
- list(policy_net.parameters()) + list(imitation_net.parameters()),
- lr=args.lr)
+ list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr
+ )
policy = DiscreteBCQPolicy(
- policy_net, imitation_net, optim, args.gamma, args.n_step,
- args.target_update_freq, args.eps_test,
- args.unlikely_action_threshold, args.imitation_logits_penalty,
+ policy_net,
+ imitation_net,
+ optim,
+ args.gamma,
+ args.n_step,
+ args.target_update_freq,
+ args.eps_test,
+ args.unlikely_action_threshold,
+ args.imitation_logits_penalty,
)
# buffer
assert os.path.exists(args.load_buffer_name), \
@@ -97,10 +112,12 @@ def stop_fn(mean_rewards):
def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
- torch.save({
- 'model': policy.state_dict(),
- 'optim': optim.state_dict(),
- }, os.path.join(log_path, 'checkpoint.pth'))
+ torch.save(
+ {
+ 'model': policy.state_dict(),
+ 'optim': optim.state_dict(),
+ }, os.path.join(log_path, 'checkpoint.pth')
+ )
if args.resume:
# load from existing checkpoint
@@ -115,10 +132,19 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
print("Fail to restore policy and optim.")
result = offline_trainer(
- policy, buffer, test_collector,
- args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
- stop_fn=stop_fn, save_fn=save_fn, logger=logger,
- resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn)
+ policy,
+ buffer,
+ test_collector,
+ args.epoch,
+ args.update_per_epoch,
+ args.test_num,
+ args.batch_size,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger,
+ resume_from_log=args.resume,
+ save_checkpoint_fn=save_checkpoint_fn
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
diff --git a/test/discrete/test_il_crr.py b/test/discrete/test_il_crr.py
index 858d2b6f7..929469e8b 100644
--- a/test/discrete/test_il_crr.py
+++ b/test/discrete/test_il_crr.py
@@ -1,18 +1,19 @@
+import argparse
import os
-import gym
-import torch
import pickle
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector
-from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
-from tianshou.utils.net.common import Net
-from tianshou.trainer import offline_trainer
from tianshou.policy import DiscreteCRRPolicy
+from tianshou.trainer import offline_trainer
+from tianshou.utils import TensorboardLogger
+from tianshou.utils.net.common import Net
def get_args():
@@ -26,17 +27,18 @@ def get_args():
parser.add_argument("--epoch", type=int, default=5)
parser.add_argument("--update-per-epoch", type=int, default=1000)
parser.add_argument("--batch-size", type=int, default=64)
- parser.add_argument('--hidden-sizes', type=int,
- nargs='*', default=[64, 64])
+ parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument("--test-num", type=int, default=100)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
- "--load-buffer-name", type=str,
+ "--load-buffer-name",
+ type=str,
default="./expert_DQN_CartPole-v0.pkl",
)
parser.add_argument(
- "--device", type=str,
+ "--device",
+ type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
)
args = parser.parse_known_args()[0]
@@ -51,23 +53,36 @@ def test_discrete_crr(args=get_args()):
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
test_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
- actor = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes, device=args.device,
- softmax=False)
- critic = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes, device=args.device,
- softmax=False)
- optim = torch.optim.Adam(list(actor.parameters()) + list(critic.parameters()),
- lr=args.lr)
+ actor = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ device=args.device,
+ softmax=False
+ )
+ critic = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ device=args.device,
+ softmax=False
+ )
+ optim = torch.optim.Adam(
+ list(actor.parameters()) + list(critic.parameters()), lr=args.lr
+ )
policy = DiscreteCRRPolicy(
- actor, critic, optim, args.gamma,
+ actor,
+ critic,
+ optim,
+ args.gamma,
target_update_freq=args.target_update_freq,
).to(args.device)
# buffer
@@ -89,9 +104,17 @@ def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
result = offline_trainer(
- policy, buffer, test_collector,
- args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
- stop_fn=stop_fn, save_fn=save_fn, logger=logger)
+ policy,
+ buffer,
+ test_collector,
+ args.epoch,
+ args.update_per_epoch,
+ args.test_num,
+ args.batch_size,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger
+ )
assert stop_fn(result['best_reward'])
diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py
index 0234c36f0..c93ddfc0d 100644
--- a/test/discrete/test_iqn.py
+++ b/test/discrete/test_iqn.py
@@ -1,18 +1,19 @@
+import argparse
import os
-import gym
-import torch
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
+from tianshou.env import DummyVectorEnv
from tianshou.policy import IQNPolicy
+from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
-from tianshou.trainer import offpolicy_trainer
from tianshou.utils.net.discrete import ImplicitQuantileNetwork
-from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
def get_args():
@@ -35,19 +36,17 @@ def get_args():
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=64)
- parser.add_argument('--hidden-sizes', type=int,
- nargs='*', default=[64, 64, 64])
+ parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64, 64])
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
- parser.add_argument('--prioritized-replay',
- action="store_true", default=False)
+ parser.add_argument('--prioritized-replay', action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
args = parser.parse_known_args()[0]
return args
@@ -59,33 +58,50 @@ def test_iqn(args=get_args()):
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- feature_net = Net(args.state_shape, args.hidden_sizes[-1],
- hidden_sizes=args.hidden_sizes[:-1], device=args.device,
- softmax=False)
+ feature_net = Net(
+ args.state_shape,
+ args.hidden_sizes[-1],
+ hidden_sizes=args.hidden_sizes[:-1],
+ device=args.device,
+ softmax=False
+ )
net = ImplicitQuantileNetwork(
- feature_net, args.action_shape,
- num_cosines=args.num_cosines, device=args.device)
+ feature_net,
+ args.action_shape,
+ num_cosines=args.num_cosines,
+ device=args.device
+ )
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = IQNPolicy(
- net, optim, args.gamma, args.sample_size, args.online_sample_size,
- args.target_sample_size, args.n_step,
+ net,
+ optim,
+ args.gamma,
+ args.sample_size,
+ args.online_sample_size,
+ args.target_sample_size,
+ args.n_step,
target_update_freq=args.target_update_freq
).to(args.device)
# buffer
if args.prioritized_replay:
buf = PrioritizedVectorReplayBuffer(
- args.buffer_size, buffer_num=len(train_envs),
- alpha=args.alpha, beta=args.beta)
+ args.buffer_size,
+ buffer_num=len(train_envs),
+ alpha=args.alpha,
+ beta=args.beta
+ )
else:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
# collector
@@ -120,11 +136,21 @@ def test_fn(epoch, env_step):
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, train_fn=train_fn, test_fn=test_fn,
- stop_fn=stop_fn, save_fn=save_fn, logger=logger,
- update_per_step=args.update_per_step)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ train_fn=train_fn,
+ test_fn=test_fn,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger,
+ update_per_step=args.update_per_step
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py
index 3c36bb265..fafd7cc49 100644
--- a/test/discrete/test_pg.py
+++ b/test/discrete/test_pg.py
@@ -1,17 +1,18 @@
+import argparse
import os
-import gym
-import torch
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
+from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.env import DummyVectorEnv
from tianshou.policy import PGPolicy
+from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
-from tianshou.trainer import onpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
def get_args():
@@ -26,16 +27,15 @@ def get_args():
parser.add_argument('--episode-per-collect', type=int, default=8)
parser.add_argument('--repeat-per-collect', type=int, default=2)
parser.add_argument('--batch-size', type=int, default=64)
- parser.add_argument('--hidden-sizes', type=int,
- nargs='*', default=[64, 64])
+ parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--rew-norm', type=int, default=1)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
args = parser.parse_known_args()[0]
return args
@@ -47,24 +47,35 @@ def test_pg(args=get_args()):
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- net = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes,
- device=args.device, softmax=True).to(args.device)
+ net = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ device=args.device,
+ softmax=True
+ ).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
dist = torch.distributions.Categorical
- policy = PGPolicy(net, optim, dist, args.gamma,
- reward_normalization=args.rew_norm,
- action_space=env.action_space)
+ policy = PGPolicy(
+ net,
+ optim,
+ dist,
+ args.gamma,
+ reward_normalization=args.rew_norm,
+ action_space=env.action_space
+ )
for m in net.modules():
if isinstance(m, torch.nn.Linear):
# orthogonal initialization
@@ -72,8 +83,8 @@ def test_pg(args=get_args()):
torch.nn.init.zeros_(m.bias)
# collector
train_collector = Collector(
- policy, train_envs,
- VectorReplayBuffer(args.buffer_size, len(train_envs)))
+ policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))
+ )
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'pg')
@@ -88,10 +99,19 @@ def stop_fn(mean_rewards):
# trainer
result = onpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
- episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
- logger=logger)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.repeat_per_collect,
+ args.test_num,
+ args.batch_size,
+ episode_per_collect=args.episode_per_collect,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
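The dist argument passed to PGPolicy above is treated as a factory that maps the network output to a torch distribution. A rough, illustrative sketch of that contract (not the library's internal code):

```python
import torch

# Illustrative only: with softmax=True the Net outputs action probabilities,
# and the policy builds dist(probs) to sample and score actions.
dist = torch.distributions.Categorical
probs = torch.tensor([[0.7, 0.2, 0.1]])  # hypothetical network output
d = dist(probs)
action = d.sample()            # e.g. tensor([0])
log_prob = d.log_prob(action)  # used in the policy-gradient loss
```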
diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py
index 3658f364b..96650b14b 100644
--- a/test/discrete/test_ppo.py
+++ b/test/discrete/test_ppo.py
@@ -1,17 +1,18 @@
+import argparse
import os
-import gym
-import torch
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
+from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.env import DummyVectorEnv
from tianshou.policy import PPOPolicy
+from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
-from tianshou.trainer import onpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils.net.discrete import Actor, Critic
@@ -33,8 +34,8 @@ def get_args():
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
# ppo special
parser.add_argument('--vf-coef', type=float, default=0.5)
parser.add_argument('--ent-coef', type=float, default=0.0)
@@ -57,18 +58,19 @@ def test_ppo(args=get_args()):
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- device=args.device)
+ net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
actor = Actor(net, args.action_shape, device=args.device).to(args.device)
critic = Critic(net, device=args.device).to(args.device)
# orthogonal initialization
@@ -77,10 +79,14 @@ def test_ppo(args=get_args()):
torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(
- set(actor.parameters()).union(critic.parameters()), lr=args.lr)
+ set(actor.parameters()).union(critic.parameters()), lr=args.lr
+ )
dist = torch.distributions.Categorical
policy = PPOPolicy(
- actor, critic, optim, dist,
+ actor,
+ critic,
+ optim,
+ dist,
discount_factor=args.gamma,
max_grad_norm=args.max_grad_norm,
eps_clip=args.eps_clip,
@@ -93,11 +99,12 @@ def test_ppo(args=get_args()):
action_space=env.action_space,
deterministic_eval=True,
advantage_normalization=args.norm_adv,
- recompute_advantage=args.recompute_adv)
+ recompute_advantage=args.recompute_adv
+ )
# collector
train_collector = Collector(
- policy, train_envs,
- VectorReplayBuffer(args.buffer_size, len(train_envs)))
+ policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))
+ )
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'ppo')
@@ -112,10 +119,19 @@ def stop_fn(mean_rewards):
# trainer
result = onpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
- step_per_collect=args.step_per_collect, stop_fn=stop_fn, save_fn=save_fn,
- logger=logger)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.repeat_per_collect,
+ args.test_num,
+ args.batch_size,
+ step_per_collect=args.step_per_collect,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py
index 2385db0ee..cf8d22212 100644
--- a/test/discrete/test_qrdqn.py
+++ b/test/discrete/test_qrdqn.py
@@ -1,18 +1,19 @@
+import argparse
import os
-import gym
-import torch
import pickle
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
-from tianshou.utils import TensorboardLogger
-from tianshou.policy import QRDQNPolicy
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
-from tianshou.utils.net.common import Net
+from tianshou.policy import QRDQNPolicy
from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
+from tianshou.utils import TensorboardLogger
+from tianshou.utils.net.common import Net
def get_args():
@@ -32,22 +33,22 @@ def get_args():
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=64)
- parser.add_argument('--hidden-sizes', type=int,
- nargs='*', default=[128, 128, 128, 128])
+ parser.add_argument(
+ '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]
+ )
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
- parser.add_argument('--prioritized-replay',
- action="store_true", default=False)
+ parser.add_argument('--prioritized-replay', action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument(
- '--save-buffer-name', type=str,
- default="./expert_QRDQN_CartPole-v0.pkl")
+ '--save-buffer-name', type=str, default="./expert_QRDQN_CartPole-v0.pkl"
+ )
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
args = parser.parse_known_args()[0]
return args
@@ -61,29 +62,43 @@ def test_qrdqn(args=get_args()):
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- net = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes, device=args.device,
- softmax=False, num_atoms=args.num_quantiles)
+ net = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ device=args.device,
+ softmax=False,
+ num_atoms=args.num_quantiles
+ )
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = QRDQNPolicy(
- net, optim, args.gamma, args.num_quantiles,
- args.n_step, target_update_freq=args.target_update_freq
+ net,
+ optim,
+ args.gamma,
+ args.num_quantiles,
+ args.n_step,
+ target_update_freq=args.target_update_freq
).to(args.device)
# buffer
if args.prioritized_replay:
buf = PrioritizedVectorReplayBuffer(
- args.buffer_size, buffer_num=len(train_envs),
- alpha=args.alpha, beta=args.beta)
+ args.buffer_size,
+ buffer_num=len(train_envs),
+ alpha=args.alpha,
+ beta=args.beta
+ )
else:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
# collector
@@ -118,11 +133,21 @@ def test_fn(epoch, env_step):
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, train_fn=train_fn, test_fn=test_fn,
- stop_fn=stop_fn, save_fn=save_fn, logger=logger,
- update_per_step=args.update_per_step)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ train_fn=train_fn,
+ test_fn=test_fn,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger,
+ update_per_step=args.update_per_step
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
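test_qrdqn defaults --save-buffer-name to ./expert_QRDQN_CartPole-v0.pkl, and the offline tests below load that file as expert data. A minimal sketch of the pickle hand-off, with hypothetical helper names:

```python
import pickle


def save_buffer(train_collector, path="./expert_QRDQN_CartPole-v0.pkl"):
    # Persist the collected replay buffer so offline tests can reuse it.
    with open(path, "wb") as f:
        pickle.dump(train_collector.buffer, f)


def load_buffer(path="./expert_QRDQN_CartPole-v0.pkl"):
    with open(path, "rb") as f:
        return pickle.load(f)
```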
diff --git a/test/discrete/test_qrdqn_il_cql.py b/test/discrete/test_qrdqn_il_cql.py
index dbfd42aad..01b868f13 100644
--- a/test/discrete/test_qrdqn_il_cql.py
+++ b/test/discrete/test_qrdqn_il_cql.py
@@ -1,18 +1,19 @@
+import argparse
import os
-import gym
-import torch
import pickle
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector
-from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
-from tianshou.utils.net.common import Net
-from tianshou.trainer import offline_trainer
from tianshou.policy import DiscreteCQLPolicy
+from tianshou.trainer import offline_trainer
+from tianshou.utils import TensorboardLogger
+from tianshou.utils.net.common import Net
def get_args():
@@ -29,17 +30,18 @@ def get_args():
parser.add_argument("--epoch", type=int, default=5)
parser.add_argument("--update-per-epoch", type=int, default=1000)
parser.add_argument("--batch-size", type=int, default=64)
- parser.add_argument('--hidden-sizes', type=int,
- nargs='*', default=[64, 64])
+ parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument("--test-num", type=int, default=100)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
- "--load-buffer-name", type=str,
+ "--load-buffer-name",
+ type=str,
default="./expert_QRDQN_CartPole-v0.pkl",
)
parser.add_argument(
- "--device", type=str,
+ "--device",
+ type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
)
args = parser.parse_known_args()[0]
@@ -54,20 +56,31 @@ def test_discrete_cql(args=get_args()):
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
test_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
- net = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes, device=args.device,
- softmax=False, num_atoms=args.num_quantiles)
+ net = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ device=args.device,
+ softmax=False,
+ num_atoms=args.num_quantiles
+ )
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = DiscreteCQLPolicy(
- net, optim, args.gamma, args.num_quantiles, args.n_step,
- args.target_update_freq, min_q_weight=args.min_q_weight
+ net,
+ optim,
+ args.gamma,
+ args.num_quantiles,
+ args.n_step,
+ args.target_update_freq,
+ min_q_weight=args.min_q_weight
).to(args.device)
# buffer
assert os.path.exists(args.load_buffer_name), \
@@ -88,9 +101,17 @@ def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
result = offline_trainer(
- policy, buffer, test_collector,
- args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
- stop_fn=stop_fn, save_fn=save_fn, logger=logger)
+ policy,
+ buffer,
+ test_collector,
+ args.epoch,
+ args.update_per_epoch,
+ args.test_num,
+ args.batch_size,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger
+ )
assert stop_fn(result['best_reward'])
diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py
index 4fdcfd352..b226a025c 100644
--- a/test/discrete/test_rainbow.py
+++ b/test/discrete/test_rainbow.py
@@ -1,19 +1,20 @@
+import argparse
import os
-import gym
-import torch
import pickle
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
+from tianshou.env import DummyVectorEnv
from tianshou.policy import RainbowPolicy
+from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import NoisyLinear
-from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer
def get_args():
@@ -36,21 +37,21 @@ def get_args():
parser.add_argument('--step-per-collect', type=int, default=8)
parser.add_argument('--update-per-step', type=float, default=0.125)
parser.add_argument('--batch-size', type=int, default=64)
- parser.add_argument('--hidden-sizes', type=int,
- nargs='*', default=[128, 128, 128, 128])
+ parser.add_argument(
+ '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]
+ )
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
- parser.add_argument('--prioritized-replay',
- action="store_true", default=False)
+ parser.add_argument('--prioritized-replay', action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument('--beta-final', type=float, default=1.)
parser.add_argument('--resume', action="store_true")
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
parser.add_argument("--save-interval", type=int, default=4)
args = parser.parse_known_args()[0]
return args
@@ -63,35 +64,56 @@ def test_rainbow(args=get_args()):
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
train_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
+
# model
def noisy_linear(x, y):
return NoisyLinear(x, y, args.noisy_std)
- net = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes, device=args.device,
- softmax=True, num_atoms=args.num_atoms,
- dueling_param=({"linear_layer": noisy_linear},
- {"linear_layer": noisy_linear}))
+ net = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ device=args.device,
+ softmax=True,
+ num_atoms=args.num_atoms,
+ dueling_param=({
+ "linear_layer": noisy_linear
+ }, {
+ "linear_layer": noisy_linear
+ })
+ )
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = RainbowPolicy(
- net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
- args.n_step, target_update_freq=args.target_update_freq
+ net,
+ optim,
+ args.gamma,
+ args.num_atoms,
+ args.v_min,
+ args.v_max,
+ args.n_step,
+ target_update_freq=args.target_update_freq
).to(args.device)
# buffer
if args.prioritized_replay:
buf = PrioritizedVectorReplayBuffer(
- args.buffer_size, buffer_num=len(train_envs),
- alpha=args.alpha, beta=args.beta, weight_norm=True)
+ args.buffer_size,
+ buffer_num=len(train_envs),
+ alpha=args.alpha,
+ beta=args.beta,
+ weight_norm=True
+ )
else:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
# collector
@@ -136,12 +158,16 @@ def test_fn(epoch, env_step):
def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
- torch.save({
- 'model': policy.state_dict(),
- 'optim': optim.state_dict(),
- }, os.path.join(log_path, 'checkpoint.pth'))
- pickle.dump(train_collector.buffer,
- open(os.path.join(log_path, 'train_buffer.pkl'), "wb"))
+ torch.save(
+ {
+ 'model': policy.state_dict(),
+ 'optim': optim.state_dict(),
+ }, os.path.join(log_path, 'checkpoint.pth')
+ )
+ pickle.dump(
+ train_collector.buffer,
+ open(os.path.join(log_path, 'train_buffer.pkl'), "wb")
+ )
if args.resume:
# load from existing checkpoint
@@ -163,11 +189,23 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
- test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
- resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ update_per_step=args.update_per_step,
+ train_fn=train_fn,
+ test_fn=test_fn,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger,
+ resume_from_log=args.resume,
+ save_checkpoint_fn=save_checkpoint_fn
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
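The --beta and --beta-final flags suggest the prioritized buffer's importance-sampling exponent is annealed during training. A minimal sketch of one such schedule, assuming linear annealing over environment steps and that the buffer exposes a set_beta method; if it does not, beta would have to be fixed at construction time:

```python
def linear_beta(env_step, beta=0.4, beta_final=1.0, anneal_steps=50000):
    # Hypothetical linear schedule: ramp the IS-correction exponent from
    # beta to beta_final over the first anneal_steps environment steps.
    frac = min(env_step / anneal_steps, 1.0)
    return beta + frac * (beta_final - beta)


# inside train_fn(epoch, env_step), assuming the buffer has set_beta:
#     buf.set_beta(linear_beta(env_step, args.beta, args.beta_final))
```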
diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py
index d8abf48e4..41be36838 100644
--- a/test/discrete/test_sac.py
+++ b/test/discrete/test_sac.py
@@ -1,18 +1,19 @@
+import argparse
import os
-import gym
-import torch
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
-from tianshou.utils import TensorboardLogger
+from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
-from tianshou.utils.net.common import Net
from tianshou.policy import DiscreteSACPolicy
from tianshou.trainer import offpolicy_trainer
+from tianshou.utils import TensorboardLogger
+from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Actor, Critic
-from tianshou.data import Collector, VectorReplayBuffer
def get_args():
@@ -40,8 +41,8 @@ def get_args():
parser.add_argument('--rew-norm', action="store_true", default=False)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
args = parser.parse_known_args()[0]
return args
@@ -52,27 +53,26 @@ def test_discrete_sac(args=get_args()):
args.action_shape = env.action_space.shape or env.action_space.n
train_envs = SubprocVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
test_envs = SubprocVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- device=args.device)
- actor = Actor(net, args.action_shape,
- softmax_output=False, device=args.device).to(args.device)
+ net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
+ actor = Actor(net, args.action_shape, softmax_output=False,
+ device=args.device).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
- net_c1 = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- device=args.device)
+ net_c1 = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
critic1 = Critic(net_c1, last_size=args.action_shape,
device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
- net_c2 = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- device=args.device)
+ net_c2 = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
critic2 = Critic(net_c2, last_size=args.action_shape,
device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
@@ -85,13 +85,22 @@ def test_discrete_sac(args=get_args()):
args.alpha = (target_entropy, log_alpha, alpha_optim)
policy = DiscreteSACPolicy(
- actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
- args.tau, args.gamma, args.alpha, estimation_step=args.n_step,
- reward_normalization=args.rew_norm)
+ actor,
+ actor_optim,
+ critic1,
+ critic1_optim,
+ critic2,
+ critic2_optim,
+ args.tau,
+ args.gamma,
+ args.alpha,
+ estimation_step=args.n_step,
+ reward_normalization=args.rew_norm
+ )
# collector
train_collector = Collector(
- policy, train_envs,
- VectorReplayBuffer(args.buffer_size, len(train_envs)))
+ policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs))
+ )
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.buffer_size)
# log
@@ -107,10 +116,20 @@ def stop_fn(mean_rewards):
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
- update_per_step=args.update_per_step, test_in_train=False)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ logger=logger,
+ update_per_step=args.update_per_step,
+ test_in_train=False
+ )
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
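The context line args.alpha = (target_entropy, log_alpha, alpha_optim) refers to automatic entropy tuning in DiscreteSACPolicy. One way such a tuple can be assembled is sketched below; the 0.98 * log(|A|) target is a common heuristic and an assumption here, not necessarily what this test uses:

```python
import numpy as np
import torch


def make_auto_alpha(num_actions, lr=3e-4, device="cpu"):
    # Hypothetical helper: entropy target plus a learnable log-alpha
    # with its own optimizer, packed the way the policy expects.
    target_entropy = 0.98 * np.log(num_actions)
    log_alpha = torch.zeros(1, requires_grad=True, device=device)
    alpha_optim = torch.optim.Adam([log_alpha], lr=lr)
    return target_entropy, log_alpha, alpha_optim
```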
diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py
index 01710d827..3a50f36e9 100644
--- a/test/modelbased/test_psrl.py
+++ b/test/modelbased/test_psrl.py
@@ -1,15 +1,16 @@
+import argparse
import os
-import gym
-import torch
import pprint
-import argparse
+
+import gym
import numpy as np
+import torch
from torch.utils.tensorboard import SummaryWriter
-from tianshou.policy import PSRLPolicy
-from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
+from tianshou.policy import PSRLPolicy
+from tianshou.trainer import onpolicy_trainer
def get_args():
@@ -42,10 +43,12 @@ def test_psrl(args=get_args()):
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
train_envs = DummyVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ [lambda: gym.make(args.task) for _ in range(args.training_num)]
+ )
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ [lambda: gym.make(args.task) for _ in range(args.test_num)]
+ )
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
@@ -59,12 +62,15 @@ def test_psrl(args=get_args()):
rew_std_prior = np.full((n_state, n_action), args.rew_std_prior)
policy = PSRLPolicy(
trans_count_prior, rew_mean_prior, rew_std_prior, args.gamma, args.eps,
- args.add_done_loop)
+ args.add_done_loop
+ )
# collector
train_collector = Collector(
- policy, train_envs,
+ policy,
+ train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
- exploration_noise=True)
+ exploration_noise=True
+ )
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'psrl')
@@ -80,11 +86,19 @@ def stop_fn(mean_rewards):
train_collector.collect(n_step=args.buffer_size, random=True)
# trainer, test it without logger
result = onpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, 1, args.test_num, 0,
- episode_per_collect=args.episode_per_collect, stop_fn=stop_fn,
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ 1,
+ args.test_num,
+ 0,
+ episode_per_collect=args.episode_per_collect,
+ stop_fn=stop_fn,
# logger=logger,
- test_in_train=False)
+ test_in_train=False
+ )
if __name__ == '__main__':
pprint.pprint(result)
diff --git a/test/multiagent/Gomoku.py b/test/multiagent/Gomoku.py
index 6418ee8ec..7d1754245 100644
--- a/test/multiagent/Gomoku.py
+++ b/test/multiagent/Gomoku.py
@@ -1,17 +1,17 @@
import os
import pprint
-import numpy as np
from copy import deepcopy
+
+import numpy as np
+from tic_tac_toe import get_agents, get_parser, train_agent, watch
+from tic_tac_toe_env import TicTacToeEnv
from torch.utils.tensorboard import SummaryWriter
-from tianshou.env import DummyVectorEnv
from tianshou.data import Collector
+from tianshou.env import DummyVectorEnv
from tianshou.policy import RandomPolicy
from tianshou.utils import TensorboardLogger
-from tic_tac_toe_env import TicTacToeEnv
-from tic_tac_toe import get_parser, get_agents, train_agent, watch
-
def get_args():
parser = get_parser()
@@ -39,6 +39,7 @@ def gomoku(args=get_args()):
def env_func():
return TicTacToeEnv(args.board_size, args.win_size)
+
test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
for r in range(args.self_play_round):
rews = []
@@ -65,11 +66,11 @@ def env_func():
# previous learner can only be used for forward
agent.forward = opponent.forward
args.model_save_path = os.path.join(
- args.logdir, 'Gomoku', 'dqn',
- f'policy_round_{r}_epoch_{epoch}.pth')
+ args.logdir, 'Gomoku', 'dqn', f'policy_round_{r}_epoch_{epoch}.pth'
+ )
result, agent_learn = train_agent(
- args, agent_learn=agent_learn,
- agent_opponent=agent, optim=optim)
+ args, agent_learn=agent_learn, agent_opponent=agent, optim=optim
+ )
print(f'round_{r}_epoch_{epoch}')
pprint.pprint(result)
learnt_agent = deepcopy(agent_learn)
diff --git a/test/multiagent/test_tic_tac_toe.py b/test/multiagent/test_tic_tac_toe.py
index 1cc06d374..aeb4644e1 100644
--- a/test/multiagent/test_tic_tac_toe.py
+++ b/test/multiagent/test_tic_tac_toe.py
@@ -1,4 +1,5 @@
import pprint
+
from tic_tac_toe import get_args, train_agent, watch
diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py
index dcd293a06..02fd47cd7 100644
--- a/test/multiagent/tic_tac_toe.py
+++ b/test/multiagent/tic_tac_toe.py
@@ -1,20 +1,24 @@
-import os
-import torch
import argparse
-import numpy as np
+import os
from copy import deepcopy
from typing import Optional, Tuple
+
+import numpy as np
+import torch
+from tic_tac_toe_env import TicTacToeEnv
from torch.utils.tensorboard import SummaryWriter
-from tianshou.utils import TensorboardLogger
+from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
-from tianshou.utils.net.common import Net
+from tianshou.policy import (
+ BasePolicy,
+ DQNPolicy,
+ MultiAgentPolicyManager,
+ RandomPolicy,
+)
from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.policy import BasePolicy, DQNPolicy, RandomPolicy, \
- MultiAgentPolicyManager
-
-from tic_tac_toe_env import TicTacToeEnv
+from tianshou.utils import TensorboardLogger
+from tianshou.utils.net.common import Net
def get_parser() -> argparse.ArgumentParser:
@@ -24,8 +28,9 @@ def get_parser() -> argparse.ArgumentParser:
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3)
- parser.add_argument('--gamma', type=float, default=0.9,
- help='a smaller gamma favors earlier win')
+ parser.add_argument(
+ '--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win'
+ )
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=20)
@@ -33,31 +38,49 @@ def get_parser() -> argparse.ArgumentParser:
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=64)
- parser.add_argument('--hidden-sizes', type=int,
- nargs='*', default=[128, 128, 128, 128])
+ parser.add_argument(
+ '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]
+ )
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.1)
parser.add_argument('--board-size', type=int, default=6)
parser.add_argument('--win-size', type=int, default=4)
- parser.add_argument('--win-rate', type=float, default=0.9,
- help='the expected winning rate')
- parser.add_argument('--watch', default=False, action='store_true',
- help='no training, '
- 'watch the play of pre-trained models')
- parser.add_argument('--agent-id', type=int, default=2,
- help='the learned agent plays as the'
- ' agent_id-th player. Choices are 1 and 2.')
- parser.add_argument('--resume-path', type=str, default='',
- help='the path of agent pth file '
- 'for resuming from a pre-trained agent')
- parser.add_argument('--opponent-path', type=str, default='',
- help='the path of opponent agent pth file '
- 'for resuming from a pre-trained agent')
parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
+ '--win-rate', type=float, default=0.9, help='the expected winning rate'
+ )
+ parser.add_argument(
+ '--watch',
+ default=False,
+ action='store_true',
+ help='no training, '
+ 'watch the play of pre-trained models'
+ )
+ parser.add_argument(
+ '--agent-id',
+ type=int,
+ default=2,
+ help='the learned agent plays as the'
+ ' agent_id-th player. Choices are 1 and 2.'
+ )
+ parser.add_argument(
+ '--resume-path',
+ type=str,
+ default='',
+ help='the path of agent pth file '
+ 'for resuming from a pre-trained agent'
+ )
+ parser.add_argument(
+ '--opponent-path',
+ type=str,
+ default='',
+ help='the path of opponent agent pth file '
+ 'for resuming from a pre-trained agent'
+ )
+ parser.add_argument(
+ '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
+ )
return parser
@@ -77,14 +100,21 @@ def get_agents(
args.action_shape = env.action_space.shape or env.action_space.n
if agent_learn is None:
# model
- net = Net(args.state_shape, args.action_shape,
- hidden_sizes=args.hidden_sizes, device=args.device
- ).to(args.device)
+ net = Net(
+ args.state_shape,
+ args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ device=args.device
+ ).to(args.device)
if optim is None:
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
agent_learn = DQNPolicy(
- net, optim, args.gamma, args.n_step,
- target_update_freq=args.target_update_freq)
+ net,
+ optim,
+ args.gamma,
+ args.n_step,
+ target_update_freq=args.target_update_freq
+ )
if args.resume_path:
agent_learn.load_state_dict(torch.load(args.resume_path))
@@ -109,8 +139,10 @@ def train_agent(
agent_opponent: Optional[BasePolicy] = None,
optim: Optional[torch.optim.Optimizer] = None,
) -> Tuple[dict, BasePolicy]:
+
def env_func():
return TicTacToeEnv(args.board_size, args.win_size)
+
train_envs = DummyVectorEnv([env_func for _ in range(args.training_num)])
test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
# seed
@@ -120,14 +152,16 @@ def env_func():
test_envs.seed(args.seed)
policy, optim = get_agents(
- args, agent_learn=agent_learn,
- agent_opponent=agent_opponent, optim=optim)
+ args, agent_learn=agent_learn, agent_opponent=agent_opponent, optim=optim
+ )
# collector
train_collector = Collector(
- policy, train_envs,
+ policy,
+ train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
- exploration_noise=True)
+ exploration_noise=True
+ )
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
@@ -142,10 +176,9 @@ def save_fn(policy):
model_save_path = args.model_save_path
else:
model_save_path = os.path.join(
- args.logdir, 'tic_tac_toe', 'dqn', 'policy.pth')
- torch.save(
- policy.policies[args.agent_id - 1].state_dict(),
- model_save_path)
+ args.logdir, 'tic_tac_toe', 'dqn', 'policy.pth'
+ )
+ torch.save(policy.policies[args.agent_id - 1].state_dict(), model_save_path)
def stop_fn(mean_rewards):
return mean_rewards >= args.win_rate
@@ -161,11 +194,23 @@ def reward_metric(rews):
# trainer
result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, train_fn=train_fn, test_fn=test_fn,
- stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step,
- logger=logger, test_in_train=False, reward_metric=reward_metric)
+ policy,
+ train_collector,
+ test_collector,
+ args.epoch,
+ args.step_per_epoch,
+ args.step_per_collect,
+ args.test_num,
+ args.batch_size,
+ train_fn=train_fn,
+ test_fn=test_fn,
+ stop_fn=stop_fn,
+ save_fn=save_fn,
+ update_per_step=args.update_per_step,
+ logger=logger,
+ test_in_train=False,
+ reward_metric=reward_metric
+ )
return result, policy.policies[args.agent_id - 1]
@@ -177,7 +222,8 @@ def watch(
) -> None:
env = TicTacToeEnv(args.board_size, args.win_size)
policy, optim = get_agents(
- args, agent_learn=agent_learn, agent_opponent=agent_opponent)
+ args, agent_learn=agent_learn, agent_opponent=agent_opponent
+ )
policy.eval()
policy.policies[args.agent_id - 1].set_eps(args.eps_test)
collector = Collector(policy, env, exploration_noise=True)
diff --git a/test/multiagent/tic_tac_toe_env.py b/test/multiagent/tic_tac_toe_env.py
index 2fc045afa..2c79d303d 100644
--- a/test/multiagent/tic_tac_toe_env.py
+++ b/test/multiagent/tic_tac_toe_env.py
@@ -1,7 +1,8 @@
+from functools import partial
+from typing import Optional, Tuple, Union
+
import gym
import numpy as np
-from functools import partial
-from typing import Tuple, Optional
from tianshou.env import MultiAgentEnv
@@ -27,7 +28,8 @@ def __init__(self, size: int = 3, win_size: int = 3):
f'be larger than board size {size}'
self.convolve_kernel = np.ones(win_size)
self.observation_space = gym.spaces.Box(
- low=-1.0, high=1.0, shape=(size, size), dtype=np.float32)
+ low=-1.0, high=1.0, shape=(size, size), dtype=np.float32
+ )
self.action_space = gym.spaces.Discrete(size * size)
self.current_board = None
self.current_agent = None
@@ -45,11 +47,10 @@ def reset(self) -> dict:
'mask': self.current_board.flatten() == 0
}
- def step(self, action: [int, np.ndarray]
- ) -> Tuple[dict, np.ndarray, np.ndarray, dict]:
+ def step(self, action: Union[int, np.ndarray]
+          ) -> Tuple[dict, np.ndarray, np.ndarray, dict]:
if self.current_agent is None:
- raise ValueError(
- "calling step() of unreset environment is prohibited!")
+ raise ValueError("calling step() of unreset environment is prohibited!")
assert 0 <= action < self.size * self.size
assert self.current_board.item(action) == 0
_current_agent = self.current_agent
@@ -97,18 +98,28 @@ def _test_win(self):
rboard = self.current_board[row, :]
cboard = self.current_board[:, col]
current = self.current_board[row, col]
- rightup = [self.current_board[row - i, col + i]
- for i in range(1, self.size - col) if row - i >= 0]
- leftdown = [self.current_board[row + i, col - i]
- for i in range(1, col + 1) if row + i < self.size]
+ rightup = [
+ self.current_board[row - i, col + i] for i in range(1, self.size - col)
+ if row - i >= 0
+ ]
+ leftdown = [
+ self.current_board[row + i, col - i] for i in range(1, col + 1)
+ if row + i < self.size
+ ]
rdiag = np.array(leftdown[::-1] + [current] + rightup)
- rightdown = [self.current_board[row + i, col + i]
- for i in range(1, self.size - col) if row + i < self.size]
- leftup = [self.current_board[row - i, col - i]
- for i in range(1, col + 1) if row - i >= 0]
+ rightdown = [
+ self.current_board[row + i, col + i] for i in range(1, self.size - col)
+ if row + i < self.size
+ ]
+ leftup = [
+ self.current_board[row - i, col - i] for i in range(1, col + 1)
+ if row - i >= 0
+ ]
diag = np.array(leftup[::-1] + [current] + rightdown)
- results = [np.convolve(k, self.convolve_kernel, mode='valid')
- for k in (rboard, cboard, rdiag, diag)]
+ results = [
+ np.convolve(k, self.convolve_kernel, mode='valid')
+ for k in (rboard, cboard, rdiag, diag)
+ ]
return any([(np.abs(x) == self.win_size).any() for x in results])
def seed(self, seed: Optional[int] = None) -> int:
@@ -128,6 +139,7 @@ def f(i, data):
if number == -1:
return 'O' if last_move else 'o'
return '_'
+
for i, row in enumerate(self.current_board):
print(pad + ' '.join(map(partial(f, i), enumerate(row))) + pad)
print(top)
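_test_win convolves each board line with a ones kernel of length win_size, so a window of win_size identical marks sums to ±win_size. A small worked example of that check:

```python
import numpy as np

row = np.array([0, 1, 1, 1, 1, -1])  # hypothetical 6-wide board row
kernel = np.ones(4)                  # win_size = 4
sums = np.convolve(row, kernel, mode='valid')
print(sums)                       # [3. 4. 2.]
print((np.abs(sums) == 4).any())  # True -> four in a row detected
```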
diff --git a/test/throughput/test_batch_profile.py b/test/throughput/test_batch_profile.py
index 9654f5838..fbd6fb89c 100644
--- a/test/throughput/test_batch_profile.py
+++ b/test/throughput/test_batch_profile.py
@@ -12,13 +12,20 @@
def data():
print("Initialising data...")
np.random.seed(0)
- batch_set = [Batch(a=[j for j in np.arange(1e3)],
- b={'b1': (3.14, 3.14), 'b2': np.arange(1e3)},
- c=i) for i in np.arange(int(1e4))]
+ batch_set = [
+ Batch(
+ a=[j for j in np.arange(1e3)],
+ b={
+ 'b1': (3.14, 3.14),
+ 'b2': np.arange(1e3)
+ },
+ c=i
+ ) for i in np.arange(int(1e4))
+ ]
batch0 = Batch(
a=np.ones((3, 4), dtype=np.float64),
b=Batch(
- c=np.ones((1,), dtype=np.float64),
+ c=np.ones((1, ), dtype=np.float64),
d=torch.ones((3, 3, 3), dtype=torch.float32),
e=list(range(3))
)
@@ -26,19 +33,25 @@ def data():
batchs1 = [copy.deepcopy(batch0) for _ in np.arange(1e4)]
batchs2 = [copy.deepcopy(batch0) for _ in np.arange(1e4)]
batch_len = int(1e4)
- batch3 = Batch(obs=[np.arange(20) for _ in np.arange(batch_len)],
- reward=np.arange(batch_len))
- indexs = np.random.choice(batch_len,
- size=batch_len // 10, replace=False)
- slice_dict = {'obs': [np.arange(20)
- for _ in np.arange(batch_len // 10)],
- 'reward': np.arange(batch_len // 10)}
- dict_set = [{'obs': np.arange(20), 'info': "this is info", 'reward': 0}
- for _ in np.arange(1e2)]
+ batch3 = Batch(
+ obs=[np.arange(20) for _ in np.arange(batch_len)], reward=np.arange(batch_len)
+ )
+ indexs = np.random.choice(batch_len, size=batch_len // 10, replace=False)
+ slice_dict = {
+ 'obs': [np.arange(20) for _ in np.arange(batch_len // 10)],
+ 'reward': np.arange(batch_len // 10)
+ }
+ dict_set = [
+ {
+ 'obs': np.arange(20),
+ 'info': "this is info",
+ 'reward': 0
+ } for _ in np.arange(1e2)
+ ]
batch4 = Batch(
a=np.ones((10000, 4), dtype=np.float64),
b=Batch(
- c=np.ones((1,), dtype=np.float64),
+ c=np.ones((1, ), dtype=np.float64),
d=torch.ones((1000, 1000), dtype=torch.float32),
e=np.arange(1000)
)
diff --git a/test/throughput/test_buffer_profile.py b/test/throughput/test_buffer_profile.py
index 40ce68889..2bb00c143 100644
--- a/test/throughput/test_buffer_profile.py
+++ b/test/throughput/test_buffer_profile.py
@@ -1,8 +1,10 @@
import sys
-import gym
import time
-import tqdm
+
+import gym
import numpy as np
+import tqdm
+
from tianshou.data import Batch, ReplayBuffer, VectorReplayBuffer
@@ -12,7 +14,7 @@ def test_replaybuffer(task="Pendulum-v0"):
env = gym.make(task)
buf = ReplayBuffer(10000)
obs = env.reset()
- for i in range(100000):
+ for _ in range(100000):
act = env.action_space.sample()
obs_next, rew, done, info = env.step(act)
batch = Batch(
@@ -35,7 +37,7 @@ def test_vectorbuffer(task="Pendulum-v0"):
env = gym.make(task)
buf = VectorReplayBuffer(total_size=10000, buffer_num=1)
obs = env.reset()
- for i in range(100000):
+ for _ in range(100000):
act = env.action_space.sample()
obs_next, rew, done, info = env.step(act)
batch = Batch(
diff --git a/test/throughput/test_collector_profile.py b/test/throughput/test_collector_profile.py
index 6242e694b..bf9c4dc05 100644
--- a/test/throughput/test_collector_profile.py
+++ b/test/throughput/test_collector_profile.py
@@ -1,9 +1,9 @@
-import tqdm
import numpy as np
+import tqdm
-from tianshou.policy import BasePolicy
+from tianshou.data import AsyncCollector, Batch, Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
-from tianshou.data import Batch, Collector, AsyncCollector, VectorReplayBuffer
+from tianshou.policy import BasePolicy
if __name__ == '__main__':
from env import MyTestEnv
@@ -12,6 +12,7 @@
class MyPolicy(BasePolicy):
+
def __init__(self, dict_state=False, need_state=True):
"""
:param bool dict_state: if the observation of the environment is a dict
@@ -40,8 +41,7 @@ def test_collector_nstep():
env_fns = [lambda x=i: MyTestEnv(size=x) for i in np.arange(2, 11)]
dum = DummyVectorEnv(env_fns)
num = len(env_fns)
- c3 = Collector(policy, dum,
- VectorReplayBuffer(total_size=40000, buffer_num=num))
+ c3 = Collector(policy, dum, VectorReplayBuffer(total_size=40000, buffer_num=num))
for i in tqdm.trange(1, 400, desc="test step collector n_step"):
c3.reset()
result = c3.collect(n_step=i * len(env_fns))
@@ -53,8 +53,7 @@ def test_collector_nepisode():
env_fns = [lambda x=i: MyTestEnv(size=x) for i in np.arange(2, 11)]
dum = DummyVectorEnv(env_fns)
num = len(env_fns)
- c3 = Collector(policy, dum,
- VectorReplayBuffer(total_size=40000, buffer_num=num))
+ c3 = Collector(policy, dum, VectorReplayBuffer(total_size=40000, buffer_num=num))
for i in tqdm.trange(1, 400, desc="test step collector n_episode"):
c3.reset()
result = c3.collect(n_episode=i)
@@ -64,22 +63,22 @@ def test_collector_nepisode():
def test_asynccollector():
env_lens = [2, 3, 4, 5]
- env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True)
- for i in env_lens]
+ env_fns = [
+ lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens
+ ]
venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1)
policy = MyPolicy()
bufsize = 300
c1 = AsyncCollector(
- policy, venv,
- VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4))
+ policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4)
+ )
ptr = [0, 0, 0, 0]
for n_episode in tqdm.trange(1, 100, desc="test async n_episode"):
result = c1.collect(n_episode=n_episode)
assert result["n/ep"] >= n_episode
# check buffer data, obs and obs_next, env_id
- for i, count in enumerate(
- np.bincount(result["lens"], minlength=6)[2:]):
+ for i, count in enumerate(np.bincount(result["lens"], minlength=6)[2:]):
env_len = i + 2
total = env_len * count
indices = np.arange(ptr[i], ptr[i] + total) % bufsize
@@ -88,8 +87,7 @@ def test_asynccollector():
buf = c1.buffer.buffers[i]
assert np.all(buf.info.env_id[indices] == i)
assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
- assert np.all(buf.obs_next[indices].reshape(
- count, env_len) == seq + 1)
+ assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
# test async n_step, for now the buffer should be full of data
for n_step in tqdm.trange(1, 150, desc="test async n_step"):
result = c1.collect(n_step=n_step)
diff --git a/tianshou/__init__.py b/tianshou/__init__.py
index fb362d054..3430fb09b 100644
--- a/tianshou/__init__.py
+++ b/tianshou/__init__.py
@@ -1,7 +1,6 @@
-from tianshou import data, env, utils, policy, trainer, exploration
+from tianshou import data, env, exploration, policy, trainer, utils
-
-__version__ = "0.4.2"
+__version__ = "0.4.3"
__all__ = [
"env",
diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py
index 75e02a940..89250d009 100644
--- a/tianshou/data/__init__.py
+++ b/tianshou/data/__init__.py
@@ -1,12 +1,19 @@
+"""Data package."""
+# isort:skip_file
+
from tianshou.data.batch import Batch
from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as
from tianshou.data.utils.segtree import SegmentTree
from tianshou.data.buffer.base import ReplayBuffer
from tianshou.data.buffer.prio import PrioritizedReplayBuffer
-from tianshou.data.buffer.manager import ReplayBufferManager
-from tianshou.data.buffer.manager import PrioritizedReplayBufferManager
-from tianshou.data.buffer.vecbuf import VectorReplayBuffer
-from tianshou.data.buffer.vecbuf import PrioritizedVectorReplayBuffer
+from tianshou.data.buffer.manager import (
+ ReplayBufferManager,
+ PrioritizedReplayBufferManager,
+)
+from tianshou.data.buffer.vecbuf import (
+ VectorReplayBuffer,
+ PrioritizedVectorReplayBuffer,
+)
from tianshou.data.buffer.cached import CachedReplayBuffer
from tianshou.data.collector import Collector, AsyncCollector
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index f0ce76c2b..98adf680b 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -1,11 +1,12 @@
-import torch
import pprint
import warnings
-import numpy as np
+from collections.abc import Collection
from copy import deepcopy
from numbers import Number
-from collections.abc import Collection
-from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, Sequence
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Union
+
+import numpy as np
+import torch
IndexType = Union[slice, int, np.ndarray, List[int]]
@@ -18,8 +19,7 @@ def _is_batch_set(data: Any) -> bool:
# "for e in data" will just unpack the first dimension,
# but data.tolist() will flatten ndarray of objects
# so do not use data.tolist()
- return data.dtype == object and all(
- isinstance(e, (dict, Batch)) for e in data)
+ return data.dtype == object and all(isinstance(e, (dict, Batch)) for e in data)
elif isinstance(data, (list, tuple)):
if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data):
return True
@@ -53,7 +53,7 @@ def _to_array_with_correct_type(v: Any) -> np.ndarray:
return v # most often case
# convert the value to np.ndarray
# convert to object data type if neither bool nor number
- # raises an exception if array's elements are tensors themself
+ # raises an exception if array's elements are tensors themselves
v = np.asanyarray(v)
if not issubclass(v.dtype.type, (np.bool_, np.number)):
v = v.astype(object)
@@ -73,7 +73,9 @@ def _to_array_with_correct_type(v: Any) -> np.ndarray:
def _create_value(
- inst: Any, size: int, stack: bool = True
+ inst: Any,
+ size: int,
+ stack: bool = True,
) -> Union["Batch", np.ndarray, torch.Tensor]:
"""Create empty place-holders accroding to inst's shape.
@@ -92,11 +94,10 @@ def _create_value(
shape = (size, *inst.shape) if stack else (size, *inst.shape[1:])
if isinstance(inst, np.ndarray):
target_type = inst.dtype.type if issubclass(
- inst.dtype.type, (np.bool_, np.number)) else object
+ inst.dtype.type, (np.bool_, np.number)
+ ) else object
return np.full(
- shape,
- fill_value=None if target_type == object else 0,
- dtype=target_type
+ shape, fill_value=None if target_type == object else 0, dtype=target_type
)
elif isinstance(inst, torch.Tensor):
return torch.full(shape, fill_value=0, device=inst.device, dtype=inst.dtype)
@@ -133,8 +134,10 @@ def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]:
try:
return torch.stack(v) # type: ignore
except RuntimeError as e:
- raise TypeError("Batch does not support non-stackable iterable"
- " of torch.Tensor as unique value yet.") from e
+ raise TypeError(
+ "Batch does not support non-stackable iterable"
+ " of torch.Tensor as unique value yet."
+ ) from e
if _is_batch_set(v):
v = Batch(v) # list of dict / Batch
else:
@@ -143,8 +146,10 @@ def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]:
try:
v = _to_array_with_correct_type(v)
except ValueError as e:
- raise TypeError("Batch does not support heterogeneous list/"
- "tuple of tensors as unique value yet.") from e
+ raise TypeError(
+ "Batch does not support heterogeneous list/"
+ "tuple of tensors as unique value yet."
+ ) from e
return v
@@ -164,20 +169,18 @@ def _alloc_by_keys_diff(
class Batch:
"""The internal data structure in Tianshou.
- Batch is a kind of supercharged array (of temporal data) stored
- individually in a (recursive) dictionary of object that can be either numpy
- array, torch tensor, or batch themself. It is designed to make it extremely
- easily to access, manipulate and set partial view of the heterogeneous data
- conveniently.
+ Batch is a kind of supercharged array (of temporal data) stored individually in a
+ (recursive) dictionary of objects that can be either numpy arrays, torch tensors,
+ or batches themselves. It is designed to make it extremely easy to access,
+ manipulate and set partial views of the heterogeneous data conveniently.
For a detailed description, please refer to :ref:`batch_concept`.
"""
def __init__(
self,
- batch_dict: Optional[
- Union[dict, "Batch", Sequence[Union[dict, "Batch"]], np.ndarray]
- ] = None,
+ batch_dict: Optional[Union[dict, "Batch", Sequence[Union[dict, "Batch"]],
+ np.ndarray]] = None,
copy: bool = False,
**kwargs: Any,
) -> None:
@@ -248,11 +251,12 @@ def __setitem__(self, index: Union[str, IndexType], value: Any) -> None:
self.__dict__[index] = value
return
if not isinstance(value, Batch):
- raise ValueError("Batch does not supported tensor assignment. "
- "Use a compatible Batch or dict instead.")
- if not set(value.keys()).issubset(self.__dict__.keys()):
raise ValueError(
- "Creating keys is not supported by item assignment.")
+ "Batch does not supported tensor assignment. "
+ "Use a compatible Batch or dict instead."
+ )
+ if not set(value.keys()).issubset(self.__dict__.keys()):
+ raise ValueError("Creating keys is not supported by item assignment.")
for key, val in self.items():
try:
self.__dict__[key][index] = value[key]
@@ -368,9 +372,7 @@ def to_torch(
v = v.type(dtype)
self.__dict__[k] = v
- def __cat(
- self, batches: Sequence[Union[dict, "Batch"]], lens: List[int]
- ) -> None:
+ def __cat(self, batches: Sequence[Union[dict, "Batch"]], lens: List[int]) -> None:
"""Private method for Batch.cat_.
::
@@ -397,9 +399,11 @@ def __cat(
sum_lens.append(sum_lens[-1] + x)
# collect non-empty keys
keys_map = [
- set(k for k, v in batch.items()
- if not (isinstance(v, Batch) and v.is_empty()))
- for batch in batches]
+ set(
+ k for k, v in batch.items()
+ if not (isinstance(v, Batch) and v.is_empty())
+ ) for batch in batches
+ ]
keys_shared = set.intersection(*keys_map)
values_shared = [[e[k] for e in batches] for k in keys_shared]
for k, v in zip(keys_shared, values_shared):
@@ -433,8 +437,7 @@ def __cat(
try:
self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
except KeyError:
- self.__dict__[k] = _create_value(
- val, sum_lens[-1], stack=False)
+ self.__dict__[k] = _create_value(val, sum_lens[-1], stack=False)
self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
def cat_(self, batches: Union["Batch", Sequence[Union[dict, "Batch"]]]) -> None:
@@ -465,7 +468,8 @@ def cat_(self, batches: Union["Batch", Sequence[Union[dict, "Batch"]]]) -> None:
raise ValueError(
"Batch.cat_ meets an exception. Maybe because there is any "
f"scalar in {batches} but Batch.cat_ does not support the "
- "concatenation of scalar.") from e
+ "concatenation of scalar."
+ ) from e
if not self.is_empty():
batches = [self] + list(batches)
lens = [0 if self.is_empty(recurse=True) else len(self)] + lens
@@ -506,8 +510,7 @@ def stack_(self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> None
if not b.is_empty():
batch_list.append(b)
else:
- raise ValueError(
- f"Cannot concatenate {type(b)} in Batch.stack_")
+ raise ValueError(f"Cannot concatenate {type(b)} in Batch.stack_")
if len(batch_list) == 0:
return
batches = batch_list
@@ -515,9 +518,11 @@ def stack_(self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> None
batches = [self] + batches
# collect non-empty keys
keys_map = [
- set(k for k, v in batch.items()
- if not (isinstance(v, Batch) and v.is_empty()))
- for batch in batches]
+ set(
+ k for k, v in batch.items()
+ if not (isinstance(v, Batch) and v.is_empty())
+ ) for batch in batches
+ ]
keys_shared = set.intersection(*keys_map)
values_shared = [[e[k] for e in batches] for k in keys_shared]
for k, v in zip(keys_shared, values_shared):
@@ -529,8 +534,10 @@ def stack_(self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> None
try:
self.__dict__[k] = _to_array_with_correct_type(np.stack(v, axis))
except ValueError:
- warnings.warn("You are using tensors with different shape,"
- " fallback to dtype=object by default.")
+ warnings.warn(
+ "You are using tensors with different shape,"
+ " fallback to dtype=object by default."
+ )
self.__dict__[k] = np.array(v, dtype=object)
# all the keys
keys_total = set.union(*[set(b.keys()) for b in batches])
@@ -543,7 +550,8 @@ def stack_(self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> None
if keys_partial and axis != 0:
raise ValueError(
f"Stack of Batch with non-shared keys {keys_partial} is only "
- f"supported with axis=0, but got axis={axis}!")
+ f"supported with axis=0, but got axis={axis}!"
+ )
for k in keys_reserve:
# reserved keys
self.__dict__[k] = Batch()
@@ -625,8 +633,10 @@ def empty_(self, index: Optional[Union[slice, IndexType]] = None) -> "Batch":
elif isinstance(v, Batch):
self.__dict__[k].empty_(index=index)
else: # scalar value
- warnings.warn("You are calling Batch.empty on a NumPy scalar, "
- "which may cause undefined behaviors.")
+ warnings.warn(
+ "You are calling Batch.empty on a NumPy scalar, "
+ "which may cause undefined behaviors."
+ )
if _is_number(v):
self.__dict__[k] = v.__class__(0)
else:
@@ -701,7 +711,8 @@ def is_empty(self, recurse: bool = False) -> bool:
return False
return all(
False if not isinstance(x, Batch) else x.is_empty(recurse=True)
- for x in self.values())
+ for x in self.values()
+ )
@property
def shape(self) -> List[int]:
@@ -718,9 +729,10 @@ def shape(self) -> List[int]:
return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \
else data_shape[0]
- def split(
- self, size: int, shuffle: bool = True, merge_last: bool = False
- ) -> Iterator["Batch"]:
+ def split(self,
+ size: int,
+ shuffle: bool = True,
+ merge_last: bool = False) -> Iterator["Batch"]:
"""Split whole data into multiple small batches.
:param int size: divide the data batch with the given size, but one
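For context, a minimal standalone sketch of the `Batch` container and the `split` signature reflowed above (the keys, sizes and chunking are made-up examples, not taken from this patch):

```python
import numpy as np
from tianshou.data import Batch

# Toy Batch: every value shares the leading dimension 8.
b = Batch(obs=np.arange(8).reshape(8, 1), act=np.zeros(8), info={"env_id": np.arange(8)})
print(len(b))        # 8
print(b[2:5].obs)    # slicing returns a sub-Batch over the same keys

# split() as reflowed above: with merge_last=True the short tail chunk
# is folded into the previous one, so the sizes here are 3 and 5.
for mini in b.split(3, shuffle=False, merge_last=True):
    print(len(mini))
```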
diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py
index d7fde0e07..381581ce9 100644
--- a/tianshou/data/buffer/base.py
+++ b/tianshou/data/buffer/base.py
@@ -1,10 +1,11 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+
import h5py
import numpy as np
-from typing import Any, Dict, List, Tuple, Union, Optional
from tianshou.data import Batch
-from tianshou.data.utils.converter import to_hdf5, from_hdf5
-from tianshou.data.batch import _create_value, _alloc_by_keys_diff
+from tianshou.data.batch import _alloc_by_keys_diff, _create_value
+from tianshou.data.utils.converter import from_hdf5, to_hdf5
class ReplayBuffer:
@@ -81,9 +82,8 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
def __setattr__(self, key: str, value: Any) -> None:
"""Set self.key = value."""
- assert (
- key not in self._reserved_keys
- ), "key '{}' is reserved and cannot be assigned".format(key)
+ assert (key not in self._reserved_keys
+ ), "key '{}' is reserved and cannot be assigned".format(key)
super().__setattr__(key, value)
def save_hdf5(self, path: str) -> None:
@@ -160,9 +160,8 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray:
self._meta[to_indices] = buffer._meta[from_indices]
return to_indices
- def _add_index(
- self, rew: Union[float, np.ndarray], done: bool
- ) -> Tuple[int, Union[float, np.ndarray], int, int]:
+ def _add_index(self, rew: Union[float, np.ndarray],
+ done: bool) -> Tuple[int, Union[float, np.ndarray], int, int]:
"""Maintain the buffer's state after adding one data batch.
Return (index_to_be_modified, episode_reward, episode_length,
@@ -183,7 +182,9 @@ def _add_index(
return ptr, self._ep_rew * 0.0, 0, self._ep_idx
def add(
- self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
+ self,
+ batch: Batch,
+ buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Add a batch of data into replay buffer.
@@ -246,7 +247,8 @@ def sample_indices(self, batch_size: int) -> np.ndarray:
return np.random.choice(self._size, batch_size)
elif batch_size == 0: # construct current available indices
return np.concatenate(
- [np.arange(self._index, self._size), np.arange(self._index)]
+ [np.arange(self._index, self._size),
+ np.arange(self._index)]
)
else:
return np.array([], int)
@@ -254,7 +256,8 @@ def sample_indices(self, batch_size: int) -> np.ndarray:
if batch_size < 0:
return np.array([], int)
all_indices = prev_indices = np.concatenate(
- [np.arange(self._index, self._size), np.arange(self._index)]
+ [np.arange(self._index, self._size),
+ np.arange(self._index)]
)
for _ in range(self.stack_num - 2):
prev_indices = self.prev(prev_indices)
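The two `sample_indices` hunks above concatenate `np.arange(self._index, self._size)` with `np.arange(self._index)`; a tiny standalone sketch with a hypothetical buffer state shows why this ordering walks the ring buffer from oldest to newest slot:

```python
import numpy as np

# Hypothetical ring-buffer state: capacity/_size 6, next write slot _index 2,
# so slots 2..5 hold the oldest data and 0..1 the newest.
size, index = 6, 2
chronological = np.concatenate([np.arange(index, size), np.arange(index)])
print(chronological)  # [2 3 4 5 0 1] -> storage slots from oldest to newest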
diff --git a/tianshou/data/buffer/cached.py b/tianshou/data/buffer/cached.py
index 49bb33bcf..19310dbd0 100644
--- a/tianshou/data/buffer/cached.py
+++ b/tianshou/data/buffer/cached.py
@@ -1,5 +1,6 @@
+from typing import List, Optional, Tuple, Union
+
import numpy as np
-from typing import List, Tuple, Union, Optional
from tianshou.data import Batch, ReplayBuffer, ReplayBufferManager
@@ -45,7 +46,9 @@ def __init__(
self.cached_buffer_num = cached_buffer_num
def add(
- self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
+ self,
+ batch: Batch,
+ buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Add a batch of data into CachedReplayBuffer.
diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py
index dc93b6867..70ebcab03 100644
--- a/tianshou/data/buffer/manager.py
+++ b/tianshou/data/buffer/manager.py
@@ -1,9 +1,10 @@
+from typing import List, Optional, Sequence, Tuple, Union
+
import numpy as np
from numba import njit
-from typing import List, Tuple, Union, Sequence, Optional
-from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer
-from tianshou.data.batch import _create_value, _alloc_by_keys_diff
+from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer
+from tianshou.data.batch import _alloc_by_keys_diff, _create_value
class ReplayBufferManager(ReplayBuffer):
@@ -63,33 +64,45 @@ def set_batch(self, batch: Batch) -> None:
self._set_batch_for_children()
def unfinished_index(self) -> np.ndarray:
- return np.concatenate([
- buf.unfinished_index() + offset
- for offset, buf in zip(self._offset, self.buffers)
- ])
+ return np.concatenate(
+ [
+ buf.unfinished_index() + offset
+ for offset, buf in zip(self._offset, self.buffers)
+ ]
+ )
def prev(self, index: Union[int, np.ndarray]) -> np.ndarray:
if isinstance(index, (list, np.ndarray)):
- return _prev_index(np.asarray(index), self._extend_offset,
- self.done, self.last_index, self._lengths)
+ return _prev_index(
+ np.asarray(index), self._extend_offset, self.done, self.last_index,
+ self._lengths
+ )
else:
- return _prev_index(np.array([index]), self._extend_offset,
- self.done, self.last_index, self._lengths)[0]
+ return _prev_index(
+ np.array([index]), self._extend_offset, self.done, self.last_index,
+ self._lengths
+ )[0]
def next(self, index: Union[int, np.ndarray]) -> np.ndarray:
if isinstance(index, (list, np.ndarray)):
- return _next_index(np.asarray(index), self._extend_offset,
- self.done, self.last_index, self._lengths)
+ return _next_index(
+ np.asarray(index), self._extend_offset, self.done, self.last_index,
+ self._lengths
+ )
else:
- return _next_index(np.array([index]), self._extend_offset,
- self.done, self.last_index, self._lengths)[0]
+ return _next_index(
+ np.array([index]), self._extend_offset, self.done, self.last_index,
+ self._lengths
+ )[0]
def update(self, buffer: ReplayBuffer) -> np.ndarray:
"""The ReplayBufferManager cannot be updated by any buffer."""
raise NotImplementedError
def add(
- self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
+ self,
+ batch: Batch,
+ buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Add a batch of data into ReplayBufferManager.
@@ -145,10 +158,12 @@ def sample_indices(self, batch_size: int) -> np.ndarray:
if batch_size < 0:
return np.array([], int)
if self._sample_avail and self.stack_num > 1:
- all_indices = np.concatenate([
- buf.sample_indices(0) + offset
- for offset, buf in zip(self._offset, self.buffers)
- ])
+ all_indices = np.concatenate(
+ [
+ buf.sample_indices(0) + offset
+ for offset, buf in zip(self._offset, self.buffers)
+ ]
+ )
if batch_size == 0:
return all_indices
else:
@@ -163,10 +178,12 @@ def sample_indices(self, batch_size: int) -> np.ndarray:
# avoid batch_size > 0 and sample_num == 0 -> get child's all data
sample_num[sample_num == 0] = -1
- return np.concatenate([
- buf.sample_indices(bsz) + offset
- for offset, buf, bsz in zip(self._offset, self.buffers, sample_num)
- ])
+ return np.concatenate(
+ [
+ buf.sample_indices(bsz) + offset
+ for offset, buf, bsz in zip(self._offset, self.buffers, sample_num)
+ ]
+ )
class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManager):
diff --git a/tianshou/data/buffer/prio.py b/tianshou/data/buffer/prio.py
index a4357d5ed..fa3c49be8 100644
--- a/tianshou/data/buffer/prio.py
+++ b/tianshou/data/buffer/prio.py
@@ -1,8 +1,9 @@
-import torch
+from typing import Any, List, Optional, Tuple, Union
+
import numpy as np
-from typing import Any, List, Tuple, Union, Optional
+import torch
-from tianshou.data import Batch, SegmentTree, to_numpy, ReplayBuffer
+from tianshou.data import Batch, ReplayBuffer, SegmentTree, to_numpy
class PrioritizedReplayBuffer(ReplayBuffer):
@@ -39,7 +40,7 @@ def __init__(
self._weight_norm = weight_norm
def init_weight(self, index: Union[int, np.ndarray]) -> None:
- self.weight[index] = self._max_prio ** self._alpha
+ self.weight[index] = self._max_prio**self._alpha
def update(self, buffer: ReplayBuffer) -> np.ndarray:
indices = super().update(buffer)
@@ -47,7 +48,9 @@ def update(self, buffer: ReplayBuffer) -> np.ndarray:
return indices
def add(
- self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
+ self,
+ batch: Batch,
+ buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids)
self.init_weight(ptr)
@@ -63,14 +66,14 @@ def sample_indices(self, batch_size: int) -> np.ndarray:
def get_weight(self, index: Union[int, np.ndarray]) -> Union[float, np.ndarray]:
"""Get the importance sampling weight.
- The "weight" in the returned Batch is the weight on loss function to de-bias
+ The "weight" in the returned Batch is the weight on loss function to debias
the sampling process (some transition tuples are sampled more often so their
losses are weighted less).
"""
# importance sampling weight calculation
# original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
# simplified formula: (p_j/p_min)**(-beta)
- return (self.weight[index] / self._min_prio) ** (-self._beta)
+ return (self.weight[index] / self._min_prio)**(-self._beta)
def update_weight(
self, index: np.ndarray, new_weight: Union[np.ndarray, torch.Tensor]
@@ -81,7 +84,7 @@ def update_weight(
:param np.ndarray new_weight: new priority weight you want to update.
"""
weight = np.abs(to_numpy(new_weight)) + self.__eps
- self.weight[index] = weight ** self._alpha
+ self.weight[index] = weight**self._alpha
self._max_prio = max(self._max_prio, weight.max())
self._min_prio = min(self._min_prio, weight.min())
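The exponent edits above are pure yapf spacing, but the formulas they touch are the usual prioritized-replay ones. A rough standalone sketch of that arithmetic (alpha, beta and the TD errors are made up; it does not mirror the buffer's internal bookkeeping):

```python
import numpy as np

alpha, beta = 0.6, 0.4                      # arbitrary PER hyper-parameters
td_error = np.array([0.1, 1.0, 2.5])        # |TD error| per sampled transition

priority = (td_error + 1e-8)**alpha         # what update_weight stores (weight**alpha)
prob = priority / priority.sum()            # sampling probability p_j
is_weight = (priority / priority.min())**(-beta)  # simplified (p_j/p_min)**(-beta)
print(prob, is_weight)
```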
diff --git a/tianshou/data/buffer/vecbuf.py b/tianshou/data/buffer/vecbuf.py
index 374765bdc..2d4831c06 100644
--- a/tianshou/data/buffer/vecbuf.py
+++ b/tianshou/data/buffer/vecbuf.py
@@ -1,8 +1,13 @@
-import numpy as np
from typing import Any
-from tianshou.data import ReplayBuffer, ReplayBufferManager
-from tianshou.data import PrioritizedReplayBuffer, PrioritizedReplayBufferManager
+import numpy as np
+
+from tianshou.data import (
+ PrioritizedReplayBuffer,
+ PrioritizedReplayBufferManager,
+ ReplayBuffer,
+ ReplayBufferManager,
+)
class VectorReplayBuffer(ReplayBufferManager):
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 192213aa3..d52cc511d 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -1,21 +1,22 @@
-import gym
import time
-import torch
import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import gym
import numpy as np
-from typing import Any, Dict, List, Union, Optional, Callable
+import torch
-from tianshou.policy import BasePolicy
-from tianshou.data.batch import _alloc_by_keys_diff
-from tianshou.env import BaseVectorEnv, DummyVectorEnv
from tianshou.data import (
Batch,
+ CachedReplayBuffer,
ReplayBuffer,
ReplayBufferManager,
VectorReplayBuffer,
- CachedReplayBuffer,
to_numpy,
)
+from tianshou.data.batch import _alloc_by_keys_diff
+from tianshou.env import BaseVectorEnv, DummyVectorEnv
+from tianshou.policy import BasePolicy
class Collector(object):
@@ -97,8 +98,9 @@ def reset(self) -> None:
"""Reset all related variables in the collector."""
# use empty Batch for "state" so that self.data supports slicing
# convert empty Batch to None when passing data to policy
- self.data = Batch(obs={}, act={}, rew={}, done={},
- obs_next={}, info={}, policy={})
+ self.data = Batch(
+ obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={}
+ )
self.reset_env()
self.reset_buffer()
self.reset_stat()
@@ -115,8 +117,8 @@ def reset_env(self) -> None:
"""Reset all of the environments."""
obs = self.env.reset()
if self.preprocess_fn:
- obs = self.preprocess_fn(
- obs=obs, env_id=np.arange(self.env_num)).get("obs", obs)
+ obs = self.preprocess_fn(obs=obs,
+ env_id=np.arange(self.env_num)).get("obs", obs)
self.data.obs = obs
def _reset_state(self, id: Union[int, List[int]]) -> None:
@@ -184,8 +186,10 @@ def collect(
ready_env_ids = np.arange(min(self.env_num, n_episode))
self.data = self.data[:min(self.env_num, n_episode)]
else:
- raise TypeError("Please specify at least one (either n_step or n_episode) "
- "in AsyncCollector.collect().")
+ raise TypeError(
+ "Please specify at least one (either n_step or n_episode) "
+ "in AsyncCollector.collect()."
+ )
start_time = time.time()
@@ -203,7 +207,8 @@ def collect(
# get the next action
if random:
self.data.update(
- act=[self._action_space[i].sample() for i in ready_env_ids])
+ act=[self._action_space[i].sample() for i in ready_env_ids]
+ )
else:
if no_grad:
with torch.no_grad(): # faster than retain_grad version
@@ -225,19 +230,21 @@ def collect(
# get bounded and remapped actions first (not saved into buffer)
action_remap = self.policy.map_action(self.data.act)
# step in env
- obs_next, rew, done, info = self.env.step(
- action_remap, ready_env_ids) # type: ignore
+ result = self.env.step(action_remap, ready_env_ids) # type: ignore
+ obs_next, rew, done, info = result
self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
if self.preprocess_fn:
- self.data.update(self.preprocess_fn(
- obs_next=self.data.obs_next,
- rew=self.data.rew,
- done=self.data.done,
- info=self.data.info,
- policy=self.data.policy,
- env_id=ready_env_ids,
- ))
+ self.data.update(
+ self.preprocess_fn(
+ obs_next=self.data.obs_next,
+ rew=self.data.rew,
+ done=self.data.done,
+ info=self.data.info,
+ policy=self.data.policy,
+ env_id=ready_env_ids,
+ )
+ )
if render:
self.env.render()
@@ -246,7 +253,8 @@ def collect(
# add data into the buffer
ptr, ep_rew, ep_len, ep_idx = self.buffer.add(
- self.data, buffer_ids=ready_env_ids)
+ self.data, buffer_ids=ready_env_ids
+ )
# collect statistics
step_count += len(ready_env_ids)
@@ -263,7 +271,8 @@ def collect(
obs_reset = self.env.reset(env_ind_global)
if self.preprocess_fn:
obs_reset = self.preprocess_fn(
- obs=obs_reset, env_id=env_ind_global).get("obs", obs_reset)
+ obs=obs_reset, env_id=env_ind_global
+ ).get("obs", obs_reset)
self.data.obs_next[env_ind_local] = obs_reset
for i in env_ind_local:
self._reset_state(i)
@@ -290,13 +299,18 @@ def collect(
self.collect_time += max(time.time() - start_time, 1e-9)
if n_episode:
- self.data = Batch(obs={}, act={}, rew={}, done={},
- obs_next={}, info={}, policy={})
+ self.data = Batch(
+ obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={}
+ )
self.reset_env()
if episode_count > 0:
- rews, lens, idxs = list(map(
- np.concatenate, [episode_rews, episode_lens, episode_start_indices]))
+ rews, lens, idxs = list(
+ map(
+ np.concatenate,
+ [episode_rews, episode_lens, episode_start_indices]
+ )
+ )
else:
rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)
@@ -377,8 +391,10 @@ def collect(
elif n_episode is not None:
assert n_episode > 0
else:
- raise TypeError("Please specify at least one (either n_step or n_episode) "
- "in AsyncCollector.collect().")
+ raise TypeError(
+ "Please specify at least one (either n_step or n_episode) "
+ "in AsyncCollector.collect()."
+ )
warnings.warn("Using async setting may collect extra transitions into buffer.")
ready_env_ids = self._ready_env_ids
@@ -401,7 +417,8 @@ def collect(
# get the next action
if random:
self.data.update(
- act=[self._action_space[i].sample() for i in ready_env_ids])
+ act=[self._action_space[i].sample() for i in ready_env_ids]
+ )
else:
if no_grad:
with torch.no_grad(): # faster than retain_grad version
@@ -431,8 +448,8 @@ def collect(
# get bounded and remapped actions first (not saved into buffer)
action_remap = self.policy.map_action(self.data.act)
# step in env
- obs_next, rew, done, info = self.env.step(
- action_remap, ready_env_ids) # type: ignore
+ result = self.env.step(action_remap, ready_env_ids) # type: ignore
+ obs_next, rew, done, info = result
# change self.data here because ready_env_ids has changed
ready_env_ids = np.array([i["env_id"] for i in info])
@@ -440,13 +457,15 @@ def collect(
self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
if self.preprocess_fn:
- self.data.update(self.preprocess_fn(
- obs_next=self.data.obs_next,
- rew=self.data.rew,
- done=self.data.done,
- info=self.data.info,
- env_id=ready_env_ids,
- ))
+ self.data.update(
+ self.preprocess_fn(
+ obs_next=self.data.obs_next,
+ rew=self.data.rew,
+ done=self.data.done,
+ info=self.data.info,
+ env_id=ready_env_ids,
+ )
+ )
if render:
self.env.render()
@@ -455,7 +474,8 @@ def collect(
# add data into the buffer
ptr, ep_rew, ep_len, ep_idx = self.buffer.add(
- self.data, buffer_ids=ready_env_ids)
+ self.data, buffer_ids=ready_env_ids
+ )
# collect statistics
step_count += len(ready_env_ids)
@@ -472,7 +492,8 @@ def collect(
obs_reset = self.env.reset(env_ind_global)
if self.preprocess_fn:
obs_reset = self.preprocess_fn(
- obs=obs_reset, env_id=env_ind_global).get("obs", obs_reset)
+ obs=obs_reset, env_id=env_ind_global
+ ).get("obs", obs_reset)
self.data.obs_next[env_ind_local] = obs_reset
for i in env_ind_local:
self._reset_state(i)
@@ -500,8 +521,12 @@ def collect(
self.collect_time += max(time.time() - start_time, 1e-9)
if episode_count > 0:
- rews, lens, idxs = list(map(
- np.concatenate, [episode_rews, episode_lens, episode_start_indices]))
+ rews, lens, idxs = list(
+ map(
+ np.concatenate,
+ [episode_rews, episode_lens, episode_start_indices]
+ )
+ )
else:
rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)
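The call sites reformatted above pin down the `preprocess_fn` contract: it is invoked with keyword batches (`obs` and `env_id` on reset; `obs_next`, `rew`, `done`, `info`, `env_id`, plus `policy` in the sync Collector, after a step), and whatever keys it returns overwrite the collector's data. A hedged sketch of one such hook (the reward clipping is just an illustrative choice):

```python
import numpy as np

def preprocess_fn(**kwargs):
    # reset-time call passes obs/env_id only; step-time call also passes
    # obs_next, rew, done, info (and policy in the sync Collector)
    if "rew" in kwargs:
        return {"rew": np.clip(kwargs["rew"], -1.0, 1.0)}  # clip rewards before storage
    return {}                                              # keep everything unchanged
```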
diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py
index 9f7d88a82..bd1dd5358 100644
--- a/tianshou/data/utils/converter.py
+++ b/tianshou/data/utils/converter.py
@@ -1,12 +1,13 @@
-import h5py
-import torch
import pickle
-import numpy as np
from copy import deepcopy
from numbers import Number
-from typing import Any, Dict, Union, Optional
+from typing import Any, Dict, Optional, Union
+
+import h5py
+import numpy as np
+import torch
-from tianshou.data.batch import _parse_value, Batch
+from tianshou.data.batch import Batch, _parse_value
def to_numpy(x: Any) -> Union[Batch, np.ndarray]:
diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py
index 5bb6fcc06..063675c53 100644
--- a/tianshou/data/utils/segtree.py
+++ b/tianshou/data/utils/segtree.py
@@ -1,6 +1,7 @@
+from typing import Optional, Union
+
import numpy as np
from numba import njit
-from typing import Union, Optional
class SegmentTree:
@@ -29,9 +30,7 @@ def __init__(self, size: int) -> None:
def __len__(self) -> int:
return self._size
- def __getitem__(
- self, index: Union[int, np.ndarray]
- ) -> Union[float, np.ndarray]:
+ def __getitem__(self, index: Union[int, np.ndarray]) -> Union[float, np.ndarray]:
"""Return self[index]."""
return self._value[index + self._bound]
@@ -64,9 +63,8 @@ def reduce(self, start: int = 0, end: Optional[int] = None) -> float:
end += self._size
return _reduce(self._value, start + self._bound - 1, end + self._bound)
- def get_prefix_sum_idx(
- self, value: Union[float, np.ndarray]
- ) -> Union[int, np.ndarray]:
+ def get_prefix_sum_idx(self, value: Union[float,
+ np.ndarray]) -> Union[int, np.ndarray]:
r"""Find the index with given value.
Return the minimum index for each ``v`` in ``value`` so that
@@ -122,9 +120,7 @@ def _reduce(tree: np.ndarray, start: int, end: int) -> float:
@njit
-def _get_prefix_sum_idx(
- value: np.ndarray, bound: int, sums: np.ndarray
-) -> np.ndarray:
+def _get_prefix_sum_idx(value: np.ndarray, bound: int, sums: np.ndarray) -> np.ndarray:
"""Numba version (v0.51), 5x speed up with size=100000 and bsz=64.
vectorized np: 0.0923 (numpy best) -> 0.024 (now)
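`get_prefix_sum_idx`, reflowed above, answers the query "smallest index whose prefix sum reaches a given value". A numpy-only sketch of the same query on hypothetical priorities, which the segment tree simply answers in O(log n):

```python
import numpy as np

weights = np.array([0.5, 1.0, 2.0, 0.5])    # hypothetical priorities
prefix = np.cumsum(weights)                 # [0.5, 1.5, 3.5, 4.0]
value = np.random.rand(3) * prefix[-1]      # uniform scalars in [0, total)

# smallest i with value <= prefix[i]
idx = np.searchsorted(prefix, value, side="left")
print(idx)                                  # indices drawn proportionally to weights
```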
diff --git a/tianshou/env/__init__.py b/tianshou/env/__init__.py
index a25e06f86..c77c30c3f 100644
--- a/tianshou/env/__init__.py
+++ b/tianshou/env/__init__.py
@@ -1,6 +1,13 @@
-from tianshou.env.venvs import BaseVectorEnv, DummyVectorEnv, \
- SubprocVectorEnv, ShmemVectorEnv, RayVectorEnv
+"""Env package."""
+
from tianshou.env.maenv import MultiAgentEnv
+from tianshou.env.venvs import (
+ BaseVectorEnv,
+ DummyVectorEnv,
+ RayVectorEnv,
+ ShmemVectorEnv,
+ SubprocVectorEnv,
+)
__all__ = [
"BaseVectorEnv",
diff --git a/tianshou/env/maenv.py b/tianshou/env/maenv.py
index f6a454c3c..456bbca13 100644
--- a/tianshou/env/maenv.py
+++ b/tianshou/env/maenv.py
@@ -1,7 +1,8 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Tuple
+
import gym
import numpy as np
-from typing import Any, Dict, Tuple
-from abc import ABC, abstractmethod
class MultiAgentEnv(ABC, gym.Env):
diff --git a/tianshou/env/utils.py b/tianshou/env/utils.py
index 5c873ce36..cec159012 100644
--- a/tianshou/env/utils.py
+++ b/tianshou/env/utils.py
@@ -1,6 +1,7 @@
-import cloudpickle
from typing import Any
+import cloudpickle
+
class CloudpickleWrapper(object):
"""A cloudpickle wrapper used in SubprocVectorEnv."""
diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py
index f9349ff24..654f55b69 100644
--- a/tianshou/env/venvs.py
+++ b/tianshou/env/venvs.py
@@ -1,10 +1,15 @@
+from typing import Any, Callable, List, Optional, Tuple, Union
+
import gym
import numpy as np
-from typing import Any, List, Tuple, Union, Optional, Callable
+from tianshou.env.worker import (
+ DummyEnvWorker,
+ EnvWorker,
+ RayEnvWorker,
+ SubprocEnvWorker,
+)
from tianshou.utils import RunningMeanStd
-from tianshou.env.worker import EnvWorker, DummyEnvWorker, SubprocEnvWorker, \
- RayEnvWorker
class BaseVectorEnv(gym.Env):
@@ -44,7 +49,7 @@ def seed(self, seed):
Otherwise, the outputs of these envs may be the same as each other.
- :param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith env.
+ :param env_fns: a list of callable envs, ``env_fns[i]()`` generates the i-th env.
:param worker_fn: a callable worker, ``worker_fn(env_fns[i])`` generates a
worker which contains the i-th env.
:param int wait_num: use in asynchronous simulation if the time cost of
@@ -56,7 +61,7 @@ def seed(self, seed):
:param float timeout: used in asynchronous simulation, same as above; in each
vectorized step it only deals with those environments spending time
within ``timeout`` seconds.
- :param bool norm_obs: Whether to track mean/std of data and normalise observation
+ :param bool norm_obs: Whether to track mean/std of data and normalize observation
on return. For now, observation normalization only support observation of
type np.ndarray.
:param obs_rms: class to track mean&std of observation. If not given, it will
@@ -122,8 +127,9 @@ def __getattribute__(self, key: str) -> Any:
``action_space``. However, we would like the attribute lookup to go straight
into the worker (in fact, this vector env's action_space is always None).
"""
- if key in ['metadata', 'reward_range', 'spec', 'action_space',
- 'observation_space']: # reserved keys in gym.Env
+ if key in [
+ 'metadata', 'reward_range', 'spec', 'action_space', 'observation_space'
+ ]: # reserved keys in gym.Env
return self.__getattr__(key)
else:
return super().__getattribute__(key)
@@ -137,7 +143,8 @@ def __getattr__(self, key: str) -> List[Any]:
return [getattr(worker, key) for worker in self.workers]
def _wrap_id(
- self, id: Optional[Union[int, List[int], np.ndarray]] = None
+ self,
+ id: Optional[Union[int, List[int], np.ndarray]] = None
) -> Union[List[int], np.ndarray]:
if id is None:
return list(range(self.env_num))
@@ -222,7 +229,7 @@ def step(
if action is not None:
self._assert_id(id)
assert len(action) == len(id)
- for i, (act, env_id) in enumerate(zip(action, id)):
+ for act, env_id in zip(action, id):
self.workers[env_id].send_action(act)
self.waiting_conn.append(self.workers[env_id])
self.waiting_id.append(env_id)
@@ -230,7 +237,8 @@ def step(
ready_conns: List[EnvWorker] = []
while not ready_conns:
ready_conns = self.worker_class.wait(
- self.waiting_conn, self.wait_num, self.timeout)
+ self.waiting_conn, self.wait_num, self.timeout
+ )
result = []
for conn in ready_conns:
waiting_index = self.waiting_conn.index(conn)
@@ -246,13 +254,15 @@ def step(
except ValueError: # different len(obs)
obs_stack = np.array(obs_list, dtype=object)
rew_stack, done_stack, info_stack = map(
- np.stack, [rew_list, done_list, info_list])
+ np.stack, [rew_list, done_list, info_list]
+ )
if self.obs_rms and self.update_obs_rms:
self.obs_rms.update(obs_stack)
return self.normalize_obs(obs_stack), rew_stack, done_stack, info_stack
def seed(
- self, seed: Optional[Union[int, List[int]]] = None
+ self,
+ seed: Optional[Union[int, List[int]]] = None
) -> List[Optional[List[int]]]:
"""Set the seed for all environments.
@@ -279,7 +289,8 @@ def render(self, **kwargs: Any) -> List[Any]:
if self.is_async and len(self.waiting_id) > 0:
raise RuntimeError(
f"Environments {self.waiting_id} are still stepping, cannot "
- "render them now.")
+ "render them now."
+ )
return [w.render(**kwargs) for w in self.workers]
def close(self) -> None:
@@ -324,6 +335,7 @@ class SubprocVectorEnv(BaseVectorEnv):
"""
def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
+
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=False)
@@ -341,6 +353,7 @@ class ShmemVectorEnv(BaseVectorEnv):
"""
def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
+
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=True)
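A minimal usage sketch of the vectorized env described in the docstring above (the CartPole task and the number of workers are placeholders, not part of this patch):

```python
import gym
import numpy as np
from tianshou.env import DummyVectorEnv

# env_fns[i]() builds the i-th env, as stated in the docstring above.
env_fns = [lambda: gym.make("CartPole-v0") for _ in range(4)]
envs = DummyVectorEnv(env_fns)

obs = envs.reset()                                           # stacked observations, one row per env
acts = np.array([sp.sample() for sp in envs.action_space])   # attribute lookup is forwarded to workers
obs, rew, done, info = envs.step(acts)
envs.close()
```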
diff --git a/tianshou/env/worker/__init__.py b/tianshou/env/worker/__init__.py
index b9a20cf1e..1b1f37510 100644
--- a/tianshou/env/worker/__init__.py
+++ b/tianshou/env/worker/__init__.py
@@ -1,7 +1,7 @@
from tianshou.env.worker.base import EnvWorker
from tianshou.env.worker.dummy import DummyEnvWorker
-from tianshou.env.worker.subproc import SubprocEnvWorker
from tianshou.env.worker.ray import RayEnvWorker
+from tianshou.env.worker.subproc import SubprocEnvWorker
__all__ = [
"EnvWorker",
diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py
index dbf350a33..6fef9f68d 100644
--- a/tianshou/env/worker/base.py
+++ b/tianshou/env/worker/base.py
@@ -1,7 +1,8 @@
+from abc import ABC, abstractmethod
+from typing import Any, Callable, List, Optional, Tuple
+
import gym
import numpy as np
-from abc import ABC, abstractmethod
-from typing import Any, List, Tuple, Optional, Callable
class EnvWorker(ABC):
@@ -11,7 +12,7 @@ def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
self._env_fn = env_fn
self.is_closed = False
self.result: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
- self.action_space = getattr(self, "action_space")
+ self.action_space = getattr(self, "action_space") # noqa: B009
@abstractmethod
def __getattr__(self, key: str) -> Any:
@@ -43,7 +44,9 @@ def step(
@staticmethod
def wait(
- workers: List["EnvWorker"], wait_num: int, timeout: Optional[float] = None
+ workers: List["EnvWorker"],
+ wait_num: int,
+ timeout: Optional[float] = None
) -> List["EnvWorker"]:
"""Given a list of workers, return those ready ones."""
raise NotImplementedError
diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py
index d0579d162..9e68e9f04 100644
--- a/tianshou/env/worker/dummy.py
+++ b/tianshou/env/worker/dummy.py
@@ -1,6 +1,7 @@
+from typing import Any, Callable, List, Optional
+
import gym
import numpy as np
-from typing import Any, List, Callable, Optional
from tianshou.env.worker import EnvWorker
diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py
index af7285b22..5d73763f2 100644
--- a/tianshou/env/worker/ray.py
+++ b/tianshou/env/worker/ray.py
@@ -1,6 +1,7 @@
+from typing import Any, Callable, List, Optional, Tuple
+
import gym
import numpy as np
-from typing import Any, List, Callable, Tuple, Optional
from tianshou.env.worker import EnvWorker
diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py
index 8b89b6c34..8ef264360 100644
--- a/tianshou/env/worker/subproc.py
+++ b/tianshou/env/worker/subproc.py
@@ -1,15 +1,15 @@
-import gym
-import time
import ctypes
-import numpy as np
+import time
from collections import OrderedDict
-from multiprocessing.context import Process
from multiprocessing import Array, Pipe, connection
-from typing import Any, List, Tuple, Union, Callable, Optional
+from multiprocessing.context import Process
+from typing import Any, Callable, List, Optional, Tuple, Union
-from tianshou.env.worker import EnvWorker
-from tianshou.env.utils import CloudpickleWrapper
+import gym
+import numpy as np
+from tianshou.env.utils import CloudpickleWrapper
+from tianshou.env.worker import EnvWorker
_NP_TO_CT = {
np.bool_: ctypes.c_bool,
@@ -62,6 +62,7 @@ def _worker(
env_fn_wrapper: CloudpickleWrapper,
obs_bufs: Optional[Union[dict, tuple, ShArray]] = None,
) -> None:
+
def _encode_obs(
obs: Union[dict, tuple, np.ndarray], buffer: Union[dict, tuple, ShArray]
) -> None:
@@ -144,6 +145,7 @@ def __getattr__(self, key: str) -> Any:
return self.parent_remote.recv()
def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
+
def decode_obs(
buffer: Optional[Union[dict, tuple, ShArray]]
) -> Union[dict, tuple, np.ndarray]:
diff --git a/tianshou/exploration/random.py b/tianshou/exploration/random.py
index a59085809..25316e98d 100644
--- a/tianshou/exploration/random.py
+++ b/tianshou/exploration/random.py
@@ -1,6 +1,7 @@
-import numpy as np
from abc import ABC, abstractmethod
-from typing import Union, Optional, Sequence
+from typing import Optional, Sequence, Union
+
+import numpy as np
class BaseNoise(ABC, object):
@@ -20,7 +21,7 @@ def __call__(self, size: Sequence[int]) -> np.ndarray:
class GaussianNoise(BaseNoise):
- """The vanilla gaussian process, for exploration in DDPG by default."""
+ """The vanilla Gaussian process, for exploration in DDPG by default."""
def __init__(self, mu: float = 0.0, sigma: float = 1.0) -> None:
super().__init__()
@@ -45,7 +46,7 @@ class OUNoise(BaseNoise):
For required parameters, you can refer to the stackoverflow page. However,
our experiment results show that (similar to OpenAI SpinningUp) using
- vanilla gaussian process has little difference from using the
+ vanilla Gaussian process has little difference from using the
Ornstein-Uhlenbeck process.
"""
@@ -74,7 +75,8 @@ def __call__(self, size: Sequence[int], mu: Optional[float] = None) -> np.ndarra
Return a numpy array whose size is equal to ``size``.
"""
if self._x is None or isinstance(
- self._x, np.ndarray) and self._x.shape != size:
+ self._x, np.ndarray
+ ) and self._x.shape != size:
self._x = 0.0
if mu is None:
mu = self._mu
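A quick standalone sketch of drawing exploration noise with `GaussianNoise` as declared above (the sigma and the action shape are arbitrary):

```python
import numpy as np
from tianshou.exploration import GaussianNoise

noise = GaussianNoise(sigma=0.1)     # zero-mean Gaussian exploration noise
act = np.zeros((4, 2))               # a batch of 4 two-dimensional actions
noisy_act = act + noise(act.shape)   # __call__(size) returns an array of that shape
```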
diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py
index 8a9c6478e..6a842356f 100644
--- a/tianshou/policy/__init__.py
+++ b/tianshou/policy/__init__.py
@@ -1,3 +1,6 @@
+"""Policy package."""
+# isort:skip_file
+
from tianshou.policy.base import BasePolicy
from tianshou.policy.random import RandomPolicy
from tianshou.policy.modelfree.dqn import DQNPolicy
@@ -22,7 +25,6 @@
from tianshou.policy.modelbased.psrl import PSRLPolicy
from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager
-
__all__ = [
"BasePolicy",
"RandomPolicy",
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 6dc3bdda8..feb6479ce 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -1,19 +1,20 @@
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+
import gym
-import torch
import numpy as np
-from torch import nn
+import torch
+from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete
from numba import njit
-from abc import ABC, abstractmethod
-from typing import Any, Dict, Tuple, Union, Optional, Callable
-from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary
+from torch import nn
-from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
+from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
class BasePolicy(ABC, nn.Module):
"""The base class for any RL policy.
- Tianshou aims to modularizing RL algorithms. It comes into several classes of
+ Tianshou aims to modularize RL algorithms. There are several classes of
policies in Tianshou. All of the policy classes must inherit
:class:`~tianshou.policy.BasePolicy`.
@@ -84,9 +85,8 @@ def set_agent_id(self, agent_id: int) -> None:
"""Set self.agent_id = agent_id, for MARL."""
self.agent_id = agent_id
- def exploration_noise(
- self, act: Union[np.ndarray, Batch], batch: Batch
- ) -> Union[np.ndarray, Batch]:
+ def exploration_noise(self, act: Union[np.ndarray, Batch],
+ batch: Batch) -> Union[np.ndarray, Batch]:
"""Modify the action from policy.forward with exploration noise.
:param act: a data batch or numpy.ndarray which is the action taken by
@@ -216,9 +216,8 @@ def post_process_fn(
if hasattr(buffer, "update_weight") and hasattr(batch, "weight"):
buffer.update_weight(indices, batch.weight)
- def update(
- self, sample_size: int, buffer: Optional[ReplayBuffer], **kwargs: Any
- ) -> Dict[str, Any]:
+ def update(self, sample_size: int, buffer: Optional[ReplayBuffer],
+ **kwargs: Any) -> Dict[str, Any]:
"""Update the policy network and replay buffer.
It includes 3 function steps: process_fn, learn, and post_process_fn. In
@@ -286,7 +285,7 @@ def compute_episodic_return(
:param Batch batch: a data batch which contains several episodes of data in
sequential order. Mind that the end of each finished episode of batch
should be marked by the done flag; unfinished (or collecting) episodes will be
- recongized by buffer.unfinished_index().
+ recognized by buffer.unfinished_index().
:param numpy.ndarray indices: tell batch's location in buffer, batch is equal
to buffer[indices].
:param np.ndarray v_s_: the value function of all next states :math:`V(s')`.
diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py
index f94aa1d39..a5321acdf 100644
--- a/tianshou/policy/imitation/base.py
+++ b/tianshou/policy/imitation/base.py
@@ -1,7 +1,8 @@
-import torch
+from typing import Any, Dict, Optional, Union
+
import numpy as np
+import torch
import torch.nn.functional as F
-from typing import Any, Dict, Union, Optional
from tianshou.data import Batch, to_torch
from tianshou.policy import BasePolicy
diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py
index 671b8b080..d9cac65df 100644
--- a/tianshou/policy/imitation/discrete_bcq.py
+++ b/tianshou/policy/imitation/discrete_bcq.py
@@ -1,11 +1,12 @@
import math
-import torch
+from typing import Any, Dict, Optional, Union
+
import numpy as np
+import torch
import torch.nn.functional as F
-from typing import Any, Dict, Union, Optional
-from tianshou.policy import DQNPolicy
from tianshou.data import Batch, ReplayBuffer, to_torch
+from tianshou.policy import DQNPolicy
class DiscreteBCQPolicy(DQNPolicy):
@@ -14,7 +15,7 @@ class DiscreteBCQPolicy(DQNPolicy):
:param torch.nn.Module model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> q_value)
:param torch.nn.Module imitator: a model following the rules in
- :class:`~tianshou.policy.BasePolicy`. (s -> imtation_logits)
+ :class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits)
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
:param float discount_factor: in [0, 1].
:param int estimation_step: the number of steps to look ahead. Default to 1.
@@ -22,7 +23,7 @@ class DiscreteBCQPolicy(DQNPolicy):
:param float eval_eps: the epsilon-greedy noise added in evaluation.
:param float unlikely_action_threshold: the threshold (tau) for unlikely
actions, as shown in Equ. (17) in the paper. Default to 0.3.
- :param float imitation_logits_penalty: reguralization weight for imitation
+ :param float imitation_logits_penalty: regularization weight for imitation
logits. Default to 1e-2.
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
@@ -47,8 +48,10 @@ def __init__(
reward_normalization: bool = False,
**kwargs: Any,
) -> None:
- super().__init__(model, optim, discount_factor, estimation_step,
- target_update_freq, reward_normalization, **kwargs)
+ super().__init__(
+ model, optim, discount_factor, estimation_step, target_update_freq,
+ reward_normalization, **kwargs
+ )
assert target_update_freq > 0, "BCQ needs target network setting."
self.imitator = imitator
assert 0.0 <= unlikely_action_threshold < 1.0, \
@@ -93,8 +96,12 @@ def forward( # type: ignore
mask = (ratio < self._log_tau).float()
action = (q_value - np.inf * mask).argmax(dim=-1)
- return Batch(act=action, state=state, q_value=q_value,
- imitation_logits=imitation_logits)
+ return Batch(
+ act=action,
+ state=state,
+ q_value=q_value,
+ imitation_logits=imitation_logits
+ )
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
if self._iter % self._freq == 0:
@@ -108,7 +115,9 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
act = to_torch(batch.act, dtype=torch.long, device=target_q.device)
q_loss = F.smooth_l1_loss(current_q, target_q)
i_loss = F.nll_loss(
- F.log_softmax(imitation_logits, dim=-1), act) # type: ignore
+ F.log_softmax(imitation_logits, dim=-1),
+ act # type: ignore
+ )
reg_loss = imitation_logits.pow(2).mean()
loss = q_loss + i_loss + self._weight_reg * reg_loss
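The `forward` hunk above keeps only actions whose imitation probability is within `tau` of the best one before taking the Q-argmax. A standalone sketch of that filter (the values are made up, and it uses `np.where` instead of the in-place `-np.inf * mask` arithmetic):

```python
import numpy as np

tau = 0.3
q_value = np.array([1.0, 2.0, 0.5])
imitation_logits = np.array([0.0, -5.0, 0.2])      # behaviour-cloning logits

ratio = imitation_logits - imitation_logits.max()  # log(pi(a) / max_a pi(a))
allowed = ratio >= np.log(tau)                     # keep actions close enough to the data
action = np.where(allowed, q_value, -np.inf).argmax()
print(action)  # action 1 has the best Q but is filtered out, so action 0 is picked
```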
diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py
index c6e1b50d0..ad4ed19a3 100644
--- a/tianshou/policy/imitation/discrete_cql.py
+++ b/tianshou/policy/imitation/discrete_cql.py
@@ -1,10 +1,11 @@
-import torch
+from typing import Any, Dict
+
import numpy as np
+import torch
import torch.nn.functional as F
-from typing import Any, Dict
-from tianshou.policy import QRDQNPolicy
from tianshou.data import Batch, to_torch
+from tianshou.policy import QRDQNPolicy
class DiscreteCQLPolicy(QRDQNPolicy):
@@ -40,8 +41,10 @@ def __init__(
min_q_weight: float = 10.0,
**kwargs: Any,
) -> None:
- super().__init__(model, optim, discount_factor, num_quantiles, estimation_step,
- target_update_freq, reward_normalization, **kwargs)
+ super().__init__(
+ model, optim, discount_factor, num_quantiles, estimation_step,
+ target_update_freq, reward_normalization, **kwargs
+ )
self._min_q_weight = min_q_weight
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
@@ -55,9 +58,10 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
target_dist = batch.returns.unsqueeze(1)
# calculate each element's difference between curr_dist and target_dist
u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
- huber_loss = (u * (
- self.tau_hat - (target_dist - curr_dist).detach().le(0.).float()
- ).abs()).sum(-1).mean(1)
+ huber_loss = (
+ u * (self.tau_hat -
+ (target_dist - curr_dist).detach().le(0.).float()).abs()
+ ).sum(-1).mean(1)
qr_loss = (huber_loss * weight).mean()
# ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
# blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
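The expression reshuffled above is the quantile Huber (QR-DQN style) loss. A standalone sketch with the smooth-L1 term written out by hand (batch size and quantile count are arbitrary; shapes are my assumption of the broadcast pattern):

```python
import torch

bsz, n = 4, 8                                      # arbitrary batch size / quantile count
curr = torch.randn(bsz, n, 1)                      # current quantiles theta_i(s, a)
target = torch.randn(bsz, 1, n)                    # Bellman target quantiles
tau_hat = (torch.arange(n) + 0.5).view(1, -1, 1) / n

diff = target - curr                               # (bsz, n, n) pairwise differences
huber = torch.where(diff.abs() <= 1, 0.5 * diff**2, diff.abs() - 0.5)  # smooth L1, beta=1
loss = (huber * (tau_hat - (diff.detach() <= 0).float()).abs()).sum(-1).mean(1).mean()
```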
diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py
index 05e4b2655..6a149509e 100644
--- a/tianshou/policy/imitation/discrete_crr.py
+++ b/tianshou/policy/imitation/discrete_crr.py
@@ -1,11 +1,12 @@
-import torch
from copy import deepcopy
from typing import Any, Dict
+
+import torch
import torch.nn.functional as F
from torch.distributions import Categorical
-from tianshou.policy.modelfree.pg import PGPolicy
from tianshou.data import Batch, to_torch, to_torch_as
+from tianshou.policy.modelfree.pg import PGPolicy
class DiscreteCRRPolicy(PGPolicy):
diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py
index b438dbcbc..e5ea5c9a5 100644
--- a/tianshou/policy/modelbased/psrl.py
+++ b/tianshou/policy/modelbased/psrl.py
@@ -1,6 +1,7 @@
-import torch
+from typing import Any, Dict, Optional, Tuple, Union
+
import numpy as np
-from typing import Any, Dict, Tuple, Union, Optional
+import torch
from tianshou.data import Batch
from tianshou.policy import BasePolicy
@@ -70,14 +71,16 @@ def observe(
sum_count = self.rew_count + rew_count
self.rew_mean = (self.rew_mean * self.rew_count + rew_sum) / sum_count
self.rew_square_sum += rew_square_sum
- raw_std2 = self.rew_square_sum / sum_count - self.rew_mean ** 2
- self.rew_std = np.sqrt(1 / (
- sum_count / (raw_std2 + self.__eps) + 1 / self.rew_std_prior ** 2))
+ raw_std2 = self.rew_square_sum / sum_count - self.rew_mean**2
+ self.rew_std = np.sqrt(
+ 1 / (sum_count / (raw_std2 + self.__eps) + 1 / self.rew_std_prior**2)
+ )
self.rew_count = sum_count
def sample_trans_prob(self) -> np.ndarray:
sample_prob = torch.distributions.Dirichlet(
- torch.from_numpy(self.trans_count)).sample().numpy()
+ torch.from_numpy(self.trans_count)
+ ).sample().numpy()
return sample_prob
def sample_reward(self) -> np.ndarray:
@@ -168,12 +171,10 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
- assert (
- 0.0 <= discount_factor <= 1.0
- ), "discount factor should be in [0, 1]"
+ assert (0.0 <= discount_factor <= 1.0), "discount factor should be in [0, 1]"
self.model = PSRLModel(
- trans_count_prior, rew_mean_prior, rew_std_prior,
- discount_factor, epsilon)
+ trans_count_prior, rew_mean_prior, rew_std_prior, discount_factor, epsilon
+ )
self._add_done_loop = add_done_loop
def forward(
@@ -195,9 +196,7 @@ def forward(
act = self.model(batch.obs, state=state, info=batch.info)
return Batch(act=act)
- def learn(
- self, batch: Batch, *args: Any, **kwargs: Any
- ) -> Dict[str, float]:
+ def learn(self, batch: Batch, *args: Any, **kwargs: Any) -> Dict[str, float]:
n_s, n_a = self.model.n_state, self.model.n_action
trans_count = np.zeros((n_s, n_a, n_s))
rew_sum = np.zeros((n_s, n_a))
@@ -207,7 +206,7 @@ def learn(
obs, act, obs_next = b.obs, b.act, b.obs_next
trans_count[obs, act, obs_next] += 1
rew_sum[obs, act] += b.rew
- rew_square_sum[obs, act] += b.rew ** 2
+ rew_square_sum[obs, act] += b.rew**2
rew_count[obs, act] += 1
if self._add_done_loop and b.done:
# special operation for terminal states: add a self-loop
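The reward update reflowed above combines the empirical precision with the prior precision. A one-shot standalone sketch of the same arithmetic on a toy reward list (the real model accumulates these sums across calls):

```python
import numpy as np

rew_std_prior, eps = 1.0, 1e-8
rewards = np.array([0.0, 1.0, 1.0, 2.0])      # observed rewards for one (s, a) pair

count = len(rewards)
mean = rewards.mean()
raw_var = (rewards**2).mean() - mean**2       # E[r^2] - E[r]^2
# combine the data precision count/var with the prior precision 1/std_prior^2
post_std = np.sqrt(1.0 / (count / (raw_var + eps) + 1.0 / rew_std_prior**2))
print(mean, post_std)
```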
diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py
index 3e05ce0b6..e44b58b58 100644
--- a/tianshou/policy/modelfree/a2c.py
+++ b/tianshou/policy/modelfree/a2c.py
@@ -1,11 +1,12 @@
-import torch
+from typing import Any, Dict, List, Optional, Type
+
import numpy as np
-from torch import nn
+import torch
import torch.nn.functional as F
-from typing import Any, Dict, List, Type, Optional
+from torch import nn
-from tianshou.policy import PGPolicy
from tianshou.data import Batch, ReplayBuffer, to_torch_as
+from tianshou.policy import PGPolicy
class A2CPolicy(PGPolicy):
@@ -96,8 +97,14 @@ def _compute_returns(
v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
unnormalized_returns, advantages = self.compute_episodic_return(
- batch, buffer, indices, v_s_, v_s,
- gamma=self._gamma, gae_lambda=self._lambda)
+ batch,
+ buffer,
+ indices,
+ v_s_,
+ v_s,
+ gamma=self._gamma,
+ gae_lambda=self._lambda
+ )
if self._rew_norm:
batch.returns = unnormalized_returns / \
np.sqrt(self.ret_rms.var + self._eps)
@@ -130,7 +137,8 @@ def learn( # type: ignore
if self._grad_norm: # clip large gradient
nn.utils.clip_grad_norm_(
set(self.actor.parameters()).union(self.critic.parameters()),
- max_norm=self._grad_norm)
+ max_norm=self._grad_norm
+ )
self.optim.step()
actor_losses.append(actor_loss.item())
vf_losses.append(vf_loss.item())
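`compute_episodic_return`, whose call is reflowed above, boils down to the GAE recursion. A standalone sketch of that recursion on a toy trajectory (gamma, gae_lambda, the rewards and the value estimates are all made up):

```python
import numpy as np

gamma, lam = 0.99, 0.95
rew = np.array([1.0, 1.0, 1.0])
v_s = np.array([0.5, 0.6, 0.7])          # V(s_t)
v_s_ = np.array([0.6, 0.7, 0.0])         # V(s_{t+1}), zero after the terminal step
done = np.array([False, False, True])

adv = np.zeros_like(rew)
gae = 0.0
for t in reversed(range(len(rew))):
    delta = rew[t] + gamma * v_s_[t] * (1 - done[t]) - v_s[t]
    gae = delta + gamma * lam * (1 - done[t]) * gae
    adv[t] = gae
returns = adv + v_s                      # unnormalized returns paired with the advantages
```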
diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py
index a664096f5..4e79eb356 100644
--- a/tianshou/policy/modelfree/c51.py
+++ b/tianshou/policy/modelfree/c51.py
@@ -1,9 +1,10 @@
-import torch
-import numpy as np
from typing import Any, Dict, Optional
-from tianshou.policy import DQNPolicy
+import numpy as np
+import torch
+
from tianshou.data import Batch, ReplayBuffer
+from tianshou.policy import DQNPolicy
class C51Policy(DQNPolicy):
@@ -44,8 +45,10 @@ def __init__(
reward_normalization: bool = False,
**kwargs: Any,
) -> None:
- super().__init__(model, optim, discount_factor, estimation_step,
- target_update_freq, reward_normalization, **kwargs)
+ super().__init__(
+ model, optim, discount_factor, estimation_step, target_update_freq,
+ reward_normalization, **kwargs
+ )
assert num_atoms > 1, "num_atoms should be greater than 1"
assert v_min < v_max, "v_max should be larger than v_min"
self._num_atoms = num_atoms
@@ -77,9 +80,10 @@ def _target_dist(self, batch: Batch) -> torch.Tensor:
target_support = batch.returns.clamp(self._v_min, self._v_max)
# An amazing trick for calculating the projection gracefully.
# ref: https://github.com/ShangtongZhang/DeepRL
- target_dist = (1 - (target_support.unsqueeze(1) -
- self.support.view(1, -1, 1)).abs() / self.delta_z
- ).clamp(0, 1) * next_dist.unsqueeze(1)
+ target_dist = (
+ 1 - (target_support.unsqueeze(1) - self.support.view(1, -1, 1)).abs() /
+ self.delta_z
+ ).clamp(0, 1) * next_dist.unsqueeze(1)
return target_dist.sum(-1)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
@@ -92,7 +96,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
curr_dist = self(batch).logits
act = batch.act
curr_dist = curr_dist[np.arange(len(act)), act, :]
- cross_entropy = - (target_dist * torch.log(curr_dist + 1e-8)).sum(1)
+ cross_entropy = -(target_dist * torch.log(curr_dist + 1e-8)).sum(1)
loss = (cross_entropy * weight).mean()
# ref: https://github.com/Kaixhin/Rainbow/blob/master/agent.py L94-100
batch.weight = cross_entropy.detach() # prio-buffer
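The clamp expression above is the usual C51 projection of shifted atoms back onto the fixed support. A standalone numpy sketch on a toy support (reward, discount and atom count are made up):

```python
import numpy as np

v_min, v_max, n_atoms = -1.0, 1.0, 5
support = np.linspace(v_min, v_max, n_atoms)            # fixed atoms z_i
delta_z = (v_max - v_min) / (n_atoms - 1)

next_dist = np.array([0.1, 0.2, 0.4, 0.2, 0.1])         # p(z_j | s', a*)
target_support = np.clip(0.3 + 0.9 * support, v_min, v_max)  # r + gamma * z_j, clamped

# each shifted atom spreads its mass linearly onto its two nearest fixed atoms
ratio = np.clip(1 - np.abs(target_support[None, :] - support[:, None]) / delta_z, 0, 1)
target_dist = (ratio * next_dist[None, :]).sum(-1)
print(target_dist, target_dist.sum())                   # still a distribution (sums to 1)
```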
diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py
index fc4a622cc..18bb81b6b 100644
--- a/tianshou/policy/modelfree/ddpg.py
+++ b/tianshou/policy/modelfree/ddpg.py
@@ -1,12 +1,13 @@
-import torch
import warnings
-import numpy as np
from copy import deepcopy
-from typing import Any, Dict, Tuple, Union, Optional
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
-from tianshou.policy import BasePolicy
-from tianshou.exploration import BaseNoise, GaussianNoise
from tianshou.data import Batch, ReplayBuffer
+from tianshou.exploration import BaseNoise, GaussianNoise
+from tianshou.policy import BasePolicy
class DDPGPolicy(BasePolicy):
@@ -53,8 +54,11 @@ def __init__(
action_bound_method: str = "clip",
**kwargs: Any,
) -> None:
- super().__init__(action_scaling=action_scaling,
- action_bound_method=action_bound_method, **kwargs)
+ super().__init__(
+ action_scaling=action_scaling,
+ action_bound_method=action_bound_method,
+ **kwargs
+ )
assert action_bound_method != "tanh", "tanh mapping is not supported" \
"in policies where action is used as input of critic , because" \
"raw action in range (-inf, inf) will cause instability in training"
@@ -96,21 +100,21 @@ def sync_weight(self) -> None:
for o, n in zip(self.critic_old.parameters(), self.critic.parameters()):
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
- def _target_q(
- self, buffer: ReplayBuffer, indices: np.ndarray
- ) -> torch.Tensor:
+ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs_next: s_{t+n}
target_q = self.critic_old(
batch.obs_next,
- self(batch, model='actor_old', input='obs_next').act)
+ self(batch, model='actor_old', input='obs_next').act
+ )
return target_q
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
) -> Batch:
batch = self.compute_nstep_return(
- batch, buffer, indices, self._target_q,
- self._gamma, self._n_step, self._rew_norm)
+ batch, buffer, indices, self._target_q, self._gamma, self._n_step,
+ self._rew_norm
+ )
return batch
def forward(
@@ -156,8 +160,7 @@ def _mse_optimizer(
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
# critic
- td, critic_loss = self._mse_optimizer(
- batch, self.critic, self.critic_optim)
+ td, critic_loss = self._mse_optimizer(batch, self.critic, self.critic_optim)
batch.weight = td # prio-buffer
# actor
action = self(batch).act
@@ -171,9 +174,8 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
"loss/critic": critic_loss.item(),
}
- def exploration_noise(
- self, act: Union[np.ndarray, Batch], batch: Batch
- ) -> Union[np.ndarray, Batch]:
+ def exploration_noise(self, act: Union[np.ndarray, Batch],
+ batch: Batch) -> Union[np.ndarray, Batch]:
if self._noise is None:
return act
if isinstance(act, np.ndarray):
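For reference, the soft target update shown in `sync_weight` earlier in this file is plain Polyak averaging; a standalone sketch (tau and the toy networks are arbitrary):

```python
import torch

tau = 0.005                              # arbitrary soft-update rate
net = torch.nn.Linear(4, 2)
target = torch.nn.Linear(4, 2)
target.load_state_dict(net.state_dict())

# same update as sync_weight: target <- (1 - tau) * target + tau * online
for o, n in zip(target.parameters(), net.parameters()):
    o.data.copy_(o.data * (1.0 - tau) + n.data * tau)
```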
diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py
index 18ac9fa12..7c580f37a 100644
--- a/tianshou/policy/modelfree/discrete_sac.py
+++ b/tianshou/policy/modelfree/discrete_sac.py
@@ -1,10 +1,11 @@
-import torch
+from typing import Any, Dict, Optional, Tuple, Union
+
import numpy as np
+import torch
from torch.distributions import Categorical
-from typing import Any, Dict, Tuple, Union, Optional
-from tianshou.policy import SACPolicy
from tianshou.data import Batch, ReplayBuffer, to_torch
+from tianshou.policy import SACPolicy
class DiscreteSACPolicy(SACPolicy):
@@ -24,7 +25,7 @@ class DiscreteSACPolicy(SACPolicy):
:param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
regularization coefficient. Default to 0.2.
If a tuple (target_entropy, log_alpha, alpha_optim) is provided, the
- alpha is automatatically tuned.
+ alpha is automatically tuned.
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
@@ -50,9 +51,21 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(
- actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
- tau, gamma, alpha, reward_normalization, estimation_step,
- action_scaling=False, action_bound_method="", **kwargs)
+ actor,
+ actor_optim,
+ critic1,
+ critic1_optim,
+ critic2,
+ critic2_optim,
+ tau,
+ gamma,
+ alpha,
+ reward_normalization,
+ estimation_step,
+ action_scaling=False,
+ action_bound_method="",
+ **kwargs
+ )
self._alpha: Union[float, torch.Tensor]
def forward( # type: ignore
@@ -68,9 +81,7 @@ def forward( # type: ignore
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist)
- def _target_q(
- self, buffer: ReplayBuffer, indices: np.ndarray
- ) -> torch.Tensor:
+ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs: s_{t+n}
obs_next_result = self(batch, input="obs_next")
dist = obs_next_result.dist
@@ -85,7 +96,8 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
weight = batch.pop("weight", 1.0)
target_q = batch.returns.flatten()
act = to_torch(
- batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long)
+ batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long
+ )
# critic 1
current_q1 = self.critic1(batch.obs).gather(1, act).flatten()
@@ -139,7 +151,6 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
return result
- def exploration_noise(
- self, act: Union[np.ndarray, Batch], batch: Batch
- ) -> Union[np.ndarray, Batch]:
+ def exploration_noise(self, act: Union[np.ndarray, Batch],
+ batch: Batch) -> Union[np.ndarray, Batch]:
return act
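Because the policy here is a `Categorical` distribution, `_target_q` can take the exact expectation over actions instead of sampling. A small sketch of one equivalent way to write that soft target, E_a[min(Q1, Q2) - alpha * log pi(a|s')]; all tensors and names are illustrative stand-ins, not the reformatted method itself:

import torch

def discrete_sac_target(q1, q2, probs, alpha):
    # q1, q2: [batch, n_actions] target-critic values at s_{t+n}
    # probs:  [batch, n_actions] categorical policy pi(. | s_{t+n})
    # exact expectation over actions, with the entropy bonus folded in
    entropy_bonus = -alpha * torch.log(probs.clamp_min(1e-8))
    return (probs * (torch.min(q1, q2) + entropy_bonus)).sum(dim=1)

q1, q2 = torch.randn(2, 3), torch.randn(2, 3)
probs = torch.softmax(torch.randn(2, 3), dim=1)
print(discrete_sac_target(q1, q2, probs, alpha=0.2))  # shape [2]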
diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py
index 03c39d171..850490998 100644
--- a/tianshou/policy/modelfree/dqn.py
+++ b/tianshou/policy/modelfree/dqn.py
@@ -1,10 +1,11 @@
-import torch
-import numpy as np
from copy import deepcopy
-from typing import Any, Dict, Union, Optional
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+import torch
+from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
from tianshou.policy import BasePolicy
-from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
class DQNPolicy(BasePolicy):
@@ -96,8 +97,9 @@ def process_fn(
:meth:`~tianshou.policy.BasePolicy.compute_nstep_return`.
"""
batch = self.compute_nstep_return(
- batch, buffer, indices, self._target_q,
- self._gamma, self._n_step, self._rew_norm)
+ batch, buffer, indices, self._target_q, self._gamma, self._n_step,
+ self._rew_norm
+ )
return batch
def compute_q_value(
@@ -173,9 +175,8 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
self._iter += 1
return {"loss": loss.item()}
- def exploration_noise(
- self, act: Union[np.ndarray, Batch], batch: Batch
- ) -> Union[np.ndarray, Batch]:
+ def exploration_noise(self, act: Union[np.ndarray, Batch],
+ batch: Batch) -> Union[np.ndarray, Batch]:
if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
bsz = len(act)
rand_mask = np.random.rand(bsz) < self.eps
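The `exploration_noise` hunk above implements epsilon-greedy action selection: with probability `eps` the greedy action is replaced by a uniformly random one. A self-contained sketch of that masking trick, using a hypothetical `epsilon_greedy` helper and numpy's generator API:

import numpy as np

def epsilon_greedy(act, num_actions, eps, rng):
    # with probability eps, overwrite the greedy action with a uniform random one
    act = act.copy()
    rand_mask = rng.random(len(act)) < eps
    act[rand_mask] = rng.integers(0, num_actions, size=int(rand_mask.sum()))
    return act

rng = np.random.default_rng(0)
print(epsilon_greedy(np.array([2, 0, 1, 3]), num_actions=4, eps=0.5, rng=rng))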
diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py
index dc03365e9..3c015b3d0 100644
--- a/tianshou/policy/modelfree/fqf.py
+++ b/tianshou/policy/modelfree/fqf.py
@@ -1,10 +1,11 @@
-import torch
+from typing import Any, Dict, Optional, Union
+
import numpy as np
+import torch
import torch.nn.functional as F
-from typing import Any, Dict, Optional, Union
+from tianshou.data import Batch, ReplayBuffer, to_numpy
from tianshou.policy import DQNPolicy, QRDQNPolicy
-from tianshou.data import Batch, to_numpy, ReplayBuffer
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction
@@ -88,12 +89,14 @@ def forward(
)
else:
(logits, _, quantiles_tau), h = model(
- obs_, propose_model=self.propose_model, fractions=fractions,
- state=state, info=batch.info
+ obs_,
+ propose_model=self.propose_model,
+ fractions=fractions,
+ state=state,
+ info=batch.info
)
- weighted_logits = (
- fractions.taus[:, 1:] - fractions.taus[:, :-1]
- ).unsqueeze(1) * logits
+ weighted_logits = (fractions.taus[:, 1:] -
+ fractions.taus[:, :-1]).unsqueeze(1) * logits
q = DQNPolicy.compute_q_value(
self, weighted_logits.sum(2), getattr(obs, "mask", None)
)
@@ -101,7 +104,10 @@ def forward(
self.max_action_num = q.shape[1]
act = to_numpy(q.max(dim=1)[1])
return Batch(
- logits=logits, act=act, state=h, fractions=fractions,
+ logits=logits,
+ act=act,
+ state=h,
+ fractions=fractions,
quantiles_tau=quantiles_tau
)
@@ -117,9 +123,12 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
target_dist = batch.returns.unsqueeze(1)
# calculate each element's difference between curr_dist and target_dist
u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
- huber_loss = (u * (
- tau_hats.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.).float()
- ).abs()).sum(-1).mean(1)
+ huber_loss = (
+ u * (
+ tau_hats.unsqueeze(2) -
+ (target_dist - curr_dist).detach().le(0.).float()
+ ).abs()
+ ).sum(-1).mean(1)
quantile_loss = (huber_loss * weight).mean()
# ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
# blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
@@ -131,16 +140,18 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
# ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
# blob/master/fqf_iqn_qrdqn/agent/fqf_agent.py L169
values_1 = sa_quantiles - sa_quantile_hats[:, :-1]
- signs_1 = sa_quantiles > torch.cat([
- sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], dim=1)
+ signs_1 = sa_quantiles > torch.cat(
+ [sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], dim=1
+ )
values_2 = sa_quantiles - sa_quantile_hats[:, 1:]
- signs_2 = sa_quantiles < torch.cat([
- sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], dim=1)
+ signs_2 = sa_quantiles < torch.cat(
+ [sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], dim=1
+ )
gradient_of_taus = (
- torch.where(signs_1, values_1, -values_1)
- + torch.where(signs_2, values_2, -values_2)
+ torch.where(signs_1, values_1, -values_1) +
+ torch.where(signs_2, values_2, -values_2)
)
fraction_loss = (gradient_of_taus * taus[:, 1:-1]).sum(1).mean()
# calculate entropy loss
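The reformatted `huber_loss` expression above is the asymmetric quantile Huber loss u * |tau - 1{target < current}| shared by QRDQN, IQN and FQF. A compact sketch under assumed shapes: `[batch, n, 1]` current quantiles against `[batch, 1, n]` target quantiles, which broadcast inside `smooth_l1_loss` exactly as in the hunks above (PyTorch warns about the size mismatch; the broadcast is intentional):

import torch
import torch.nn.functional as F

def quantile_huber_loss(curr_dist, target_dist, taus):
    # curr_dist:   [batch, n, 1]  current quantile estimates for the chosen action
    # target_dist: [batch, 1, n]  target quantile samples
    # taus:        broadcastable to [batch, n, 1], fractions of curr_dist
    u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")  # -> [batch, n, n]
    weight = (taus - (target_dist - curr_dist).detach().le(0.).float()).abs()
    return (u * weight).sum(-1).mean(1).mean()

b, n = 4, 8
curr = torch.randn(b, n, 1)
target = torch.randn(b, 1, n)
tau = torch.linspace(0, 1, n + 1)
tau_hat = ((tau[:-1] + tau[1:]) / 2).view(1, n, 1)
print(quantile_huber_loss(curr, target, tau_hat))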
diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py
index 4c54d3563..9d9777b98 100644
--- a/tianshou/policy/modelfree/iqn.py
+++ b/tianshou/policy/modelfree/iqn.py
@@ -1,10 +1,11 @@
-import torch
+from typing import Any, Dict, Optional, Union
+
import numpy as np
+import torch
import torch.nn.functional as F
-from typing import Any, Dict, Optional, Union
-from tianshou.policy import QRDQNPolicy
from tianshou.data import Batch, to_numpy
+from tianshou.policy import QRDQNPolicy
class IQNPolicy(QRDQNPolicy):
@@ -45,8 +46,10 @@ def __init__(
reward_normalization: bool = False,
**kwargs: Any,
) -> None:
- super().__init__(model, optim, discount_factor, sample_size, estimation_step,
- target_update_freq, reward_normalization, **kwargs)
+ super().__init__(
+ model, optim, discount_factor, sample_size, estimation_step,
+ target_update_freq, reward_normalization, **kwargs
+ )
assert sample_size > 1, "sample_size should be greater than 1"
assert online_sample_size > 1, "online_sample_size should be greater than 1"
assert target_sample_size > 1, "target_sample_size should be greater than 1"
@@ -71,9 +74,8 @@ def forward(
model = getattr(self, model)
obs = batch[input]
obs_ = obs.obs if hasattr(obs, "obs") else obs
- (logits, taus), h = model(
- obs_, sample_size=sample_size, state=state, info=batch.info
- )
+ (logits,
+ taus), h = model(obs_, sample_size=sample_size, state=state, info=batch.info)
q = self.compute_q_value(logits, getattr(obs, "mask", None))
if not hasattr(self, "max_action_num"):
self.max_action_num = q.shape[1]
@@ -92,9 +94,11 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
target_dist = batch.returns.unsqueeze(1)
# calculate each element's difference between curr_dist and target_dist
u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
- huber_loss = (u * (
- taus.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.).float()
- ).abs()).sum(-1).mean(1)
+ huber_loss = (
+ u *
+ (taus.unsqueeze(2) -
+ (target_dist - curr_dist).detach().le(0.).float()).abs()
+ ).sum(-1).mean(1)
loss = (huber_loss * weight).mean()
# ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
# blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py
index 76da4bf0b..758093d1a 100644
--- a/tianshou/policy/modelfree/npg.py
+++ b/tianshou/policy/modelfree/npg.py
@@ -1,13 +1,13 @@
-import torch
+from typing import Any, Dict, List, Type
+
import numpy as np
-from torch import nn
+import torch
import torch.nn.functional as F
-from typing import Any, Dict, List, Type
+from torch import nn
from torch.distributions import kl_divergence
-
-from tianshou.policy import A2CPolicy
from tianshou.data import Batch, ReplayBuffer
+from tianshou.policy import A2CPolicy
class NPGPolicy(A2CPolicy):
@@ -82,7 +82,7 @@ def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
actor_losses, vf_losses, kls = [], [], []
- for step in range(repeat):
+ for _ in range(repeat):
for b in batch.split(batch_size, merge_last=True):
# optimize actor
# direction: calculate vanilla gradient
@@ -91,7 +91,8 @@ def learn( # type: ignore
log_prob = log_prob.reshape(log_prob.size(0), -1).transpose(0, 1)
actor_loss = -(log_prob * b.adv).mean()
flat_grads = self._get_flat_grad(
- actor_loss, self.actor, retain_graph=True).detach()
+ actor_loss, self.actor, retain_graph=True
+ ).detach()
# direction: calculate natural gradient
with torch.no_grad():
@@ -101,12 +102,14 @@ def learn( # type: ignore
# calculate first order gradient of kl with respect to theta
flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True)
search_direction = -self._conjugate_gradients(
- flat_grads, flat_kl_grad, nsteps=10)
+ flat_grads, flat_kl_grad, nsteps=10
+ )
# step
with torch.no_grad():
- flat_params = torch.cat([param.data.view(-1)
- for param in self.actor.parameters()])
+ flat_params = torch.cat(
+ [param.data.view(-1) for param in self.actor.parameters()]
+ )
new_flat_params = flat_params + self._step_size * search_direction
self._set_from_flat_params(self.actor, new_flat_params)
new_dist = self(b).dist
@@ -138,8 +141,8 @@ def _MVP(self, v: torch.Tensor, flat_kl_grad: torch.Tensor) -> torch.Tensor:
"""Matrix vector product."""
# calculate second order gradient of kl with respect to theta
kl_v = (flat_kl_grad * v).sum()
- flat_kl_grad_grad = self._get_flat_grad(
- kl_v, self.actor, retain_graph=True).detach()
+ flat_kl_grad_grad = self._get_flat_grad(kl_v, self.actor,
+ retain_graph=True).detach()
return flat_kl_grad_grad + v * self._damping
def _conjugate_gradients(
@@ -154,7 +157,7 @@ def _conjugate_gradients(
# Note: should be 'r, p = b - MVP(x)', but for x=0, MVP(x)=0.
# Change if doing warm start.
rdotr = r.dot(r)
- for i in range(nsteps):
+ for _ in range(nsteps):
z = self._MVP(p, flat_kl_grad)
alpha = rdotr / p.dot(z)
x += alpha * p
@@ -179,6 +182,7 @@ def _set_from_flat_params(
for param in model.parameters():
flat_size = int(np.prod(list(param.size())))
param.data.copy_(
- flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
+ flat_params[prev_ind:prev_ind + flat_size].view(param.size())
+ )
prev_ind += flat_size
return model
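`_conjugate_gradients` above solves H x = g using only Hessian-vector products from `_MVP`. A standalone sketch of the same ten-step conjugate-gradient loop on a tiny explicit matrix, so the resulting direction can be checked by hand:

import torch

def conjugate_gradients(mvp, b, nsteps=10, residual_tol=1e-10):
    # solve A x = b using only the matrix-vector product mvp(v) = A v
    x = torch.zeros_like(b)
    r, p = b.clone(), b.clone()  # residual and search direction for the x = 0 start
    rdotr = r.dot(r)
    for _ in range(nsteps):
        z = mvp(p)
        alpha = rdotr / p.dot(z)
        x += alpha * p
        r -= alpha * z
        new_rdotr = r.dot(r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x

A = torch.tensor([[4.0, 1.0], [1.0, 3.0]])  # symmetric positive definite stand-in
b = torch.tensor([1.0, 2.0])
x = conjugate_gradients(lambda v: A @ v, b)
print(x, A @ x)  # A @ x should reproduce b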
diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index 1eb4dde4a..a64828874 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -1,10 +1,11 @@
-import torch
+from typing import Any, Dict, List, Optional, Type, Union
+
import numpy as np
-from typing import Any, Dict, List, Type, Union, Optional
+import torch
+from tianshou.data import Batch, ReplayBuffer, to_torch_as
from tianshou.policy import BasePolicy
from tianshou.utils import RunningMeanStd
-from tianshou.data import Batch, ReplayBuffer, to_torch_as
class PGPolicy(BasePolicy):
@@ -47,8 +48,11 @@ def __init__(
deterministic_eval: bool = False,
**kwargs: Any,
) -> None:
- super().__init__(action_scaling=action_scaling,
- action_bound_method=action_bound_method, **kwargs)
+ super().__init__(
+ action_scaling=action_scaling,
+ action_bound_method=action_bound_method,
+ **kwargs
+ )
self.actor = model
self.optim = optim
self.lr_scheduler = lr_scheduler
@@ -73,7 +77,8 @@ def process_fn(
"""
v_s_ = np.full(indices.shape, self.ret_rms.mean)
unnormalized_returns, _ = self.compute_episodic_return(
- batch, buffer, indices, v_s_=v_s_, gamma=self._gamma, gae_lambda=1.0)
+ batch, buffer, indices, v_s_=v_s_, gamma=self._gamma, gae_lambda=1.0
+ )
if self._rew_norm:
batch.returns = (unnormalized_returns - self.ret_rms.mean) / \
np.sqrt(self.ret_rms.var + self._eps)
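`process_fn` above normalizes the episodic returns with `self.ret_rms` before refreshing those running statistics. A rough, self-contained stand-in for that flow; `SimpleRunningMeanStd` mimics the parallel mean/variance update but is not tianshou's actual class:

import numpy as np

class SimpleRunningMeanStd:
    # a stripped-down stand-in for tianshou.utils.RunningMeanStd
    def __init__(self):
        self.mean, self.var, self.count = 0.0, 1.0, 0

    def update(self, x):
        batch_mean, batch_var, batch_count = x.mean(), x.var(), len(x)
        delta = batch_mean - self.mean
        total = self.count + batch_count
        new_mean = self.mean + delta * batch_count / total
        m2 = self.var * self.count + batch_var * batch_count \
            + delta ** 2 * self.count * batch_count / total
        self.mean, self.var, self.count = new_mean, m2 / total, total

ret_rms, eps = SimpleRunningMeanStd(), 1e-8
unnormalized_returns = np.array([1.0, 2.0, 3.0, 4.0])
# normalize with the current statistics, then refresh them, as in process_fn above
normalized = (unnormalized_returns - ret_rms.mean) / np.sqrt(ret_rms.var + eps)
ret_rms.update(unnormalized_returns)
print(normalized, ret_rms.mean, ret_rms.var)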
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
index 824e19ad5..e1e17aa2f 100644
--- a/tianshou/policy/modelfree/ppo.py
+++ b/tianshou/policy/modelfree/ppo.py
@@ -1,10 +1,11 @@
-import torch
+from typing import Any, Dict, List, Optional, Type
+
import numpy as np
+import torch
from torch import nn
-from typing import Any, Dict, List, Type, Optional
-from tianshou.policy import A2CPolicy
from tianshou.data import Batch, ReplayBuffer, to_torch_as
+from tianshou.policy import A2CPolicy
class PPOPolicy(A2CPolicy):
@@ -124,8 +125,8 @@ def learn( # type: ignore
# calculate loss for critic
value = self.critic(b.obs).flatten()
if self._value_clip:
- v_clip = b.v_s + (value - b.v_s).clamp(
- -self._eps_clip, self._eps_clip)
+ v_clip = b.v_s + (value -
+ b.v_s).clamp(-self._eps_clip, self._eps_clip)
vf1 = (b.returns - value).pow(2)
vf2 = (b.returns - v_clip).pow(2)
vf_loss = torch.max(vf1, vf2).mean()
@@ -140,7 +141,8 @@ def learn( # type: ignore
if self._grad_norm: # clip large gradient
nn.utils.clip_grad_norm_(
set(self.actor.parameters()).union(self.critic.parameters()),
- max_norm=self._grad_norm)
+ max_norm=self._grad_norm
+ )
self.optim.step()
clip_losses.append(clip_loss.item())
vf_losses.append(vf_loss.item())
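The `_value_clip` branch above is PPO's clipped value loss: the new value prediction may move at most `eps_clip` away from the stored `v_s` before the larger of the two squared errors is penalized. A short sketch with hypothetical tensor names:

import torch

def clipped_value_loss(value, old_value, returns, eps_clip):
    # keep the new prediction within eps_clip of the old one,
    # then take the worse (larger) of the clipped and unclipped errors
    v_clip = old_value + (value - old_value).clamp(-eps_clip, eps_clip)
    vf1 = (returns - value).pow(2)
    vf2 = (returns - v_clip).pow(2)
    return torch.max(vf1, vf2).mean()

value, old_value = torch.tensor([0.5, 1.5]), torch.tensor([0.0, 0.0])
returns = torch.tensor([1.0, 1.0])
print(clipped_value_loss(value, old_value, returns, eps_clip=0.2))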
diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py
index 91d0f4bf0..fe3e101f7 100644
--- a/tianshou/policy/modelfree/qrdqn.py
+++ b/tianshou/policy/modelfree/qrdqn.py
@@ -1,11 +1,12 @@
-import torch
import warnings
+from typing import Any, Dict, Optional
+
import numpy as np
+import torch
import torch.nn.functional as F
-from typing import Any, Dict, Optional
-from tianshou.policy import DQNPolicy
from tianshou.data import Batch, ReplayBuffer
+from tianshou.policy import DQNPolicy
class QRDQNPolicy(DQNPolicy):
@@ -40,13 +41,16 @@ def __init__(
reward_normalization: bool = False,
**kwargs: Any,
) -> None:
- super().__init__(model, optim, discount_factor, estimation_step,
- target_update_freq, reward_normalization, **kwargs)
+ super().__init__(
+ model, optim, discount_factor, estimation_step, target_update_freq,
+ reward_normalization, **kwargs
+ )
assert num_quantiles > 1, "num_quantiles should be greater than 1"
self._num_quantiles = num_quantiles
tau = torch.linspace(0, 1, self._num_quantiles + 1)
self.tau_hat = torch.nn.Parameter(
- ((tau[:-1] + tau[1:]) / 2).view(1, -1, 1), requires_grad=False)
+ ((tau[:-1] + tau[1:]) / 2).view(1, -1, 1), requires_grad=False
+ )
warnings.filterwarnings("ignore", message="Using a target size")
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
@@ -77,9 +81,10 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
target_dist = batch.returns.unsqueeze(1)
# calculate each element's difference between curr_dist and target_dist
u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
- huber_loss = (u * (
- self.tau_hat - (target_dist - curr_dist).detach().le(0.).float()
- ).abs()).sum(-1).mean(1)
+ huber_loss = (
+ u * (self.tau_hat -
+ (target_dist - curr_dist).detach().le(0.).float()).abs()
+ ).sum(-1).mean(1)
loss = (huber_loss * weight).mean()
# ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
# blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/policy/modelfree/rainbow.py
index 7aa4c682d..9028258d7 100644
--- a/tianshou/policy/modelfree/rainbow.py
+++ b/tianshou/policy/modelfree/rainbow.py
@@ -1,7 +1,7 @@
from typing import Any, Dict
-from tianshou.policy import C51Policy
from tianshou.data import Batch
+from tianshou.policy import C51Policy
from tianshou.utils.net.discrete import sample_noise
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index 1858f27ed..2657a1eee 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -1,12 +1,13 @@
-import torch
-import numpy as np
from copy import deepcopy
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
from torch.distributions import Independent, Normal
-from typing import Any, Dict, Tuple, Union, Optional
-from tianshou.policy import DDPGPolicy
-from tianshou.exploration import BaseNoise
from tianshou.data import Batch, ReplayBuffer, to_torch_as
+from tianshou.exploration import BaseNoise
+from tianshou.policy import DDPGPolicy
class SACPolicy(DDPGPolicy):
@@ -26,7 +27,7 @@ class SACPolicy(DDPGPolicy):
:param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
regularization coefficient. Default to 0.2.
If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
- alpha is automatatically tuned.
+ alpha is automatically tuned.
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param BaseNoise exploration_noise: add a noise to action for exploration.
@@ -67,7 +68,8 @@ def __init__(
) -> None:
super().__init__(
None, None, None, None, tau, gamma, exploration_noise,
- reward_normalization, estimation_step, **kwargs)
+ reward_normalization, estimation_step, **kwargs
+ )
self.actor, self.actor_optim = actor, actor_optim
self.critic1, self.critic1_old = critic1, deepcopy(critic1)
self.critic1_old.eval()
@@ -123,15 +125,17 @@ def forward( # type: ignore
# in appendix C to get some understanding of this equation.
if self.action_scaling and self.action_space is not None:
action_scale = to_torch_as(
- (self.action_space.high - self.action_space.low) / 2.0, act)
+ (self.action_space.high - self.action_space.low) / 2.0, act
+ )
else:
action_scale = 1.0 # type: ignore
squashed_action = torch.tanh(act)
log_prob = log_prob - torch.log(
action_scale * (1 - squashed_action.pow(2)) + self.__eps
).sum(-1, keepdim=True)
- return Batch(logits=logits, act=squashed_action,
- state=h, dist=dist, log_prob=log_prob)
+ return Batch(
+ logits=logits, act=squashed_action, state=h, dist=dist, log_prob=log_prob
+ )
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
batch = buffer[indices] # batch.obs: s_{t+n}
@@ -146,9 +150,11 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
# critic 1&2
td1, critic1_loss = self._mse_optimizer(
- batch, self.critic1, self.critic1_optim)
+ batch, self.critic1, self.critic1_optim
+ )
td2, critic2_loss = self._mse_optimizer(
- batch, self.critic2, self.critic2_optim)
+ batch, self.critic2, self.critic2_optim
+ )
batch.weight = (td1 + td2) / 2.0 # prio-buffer
# actor
@@ -156,8 +162,10 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
a = obs_result.act
current_q1a = self.critic1(batch.obs, a).flatten()
current_q2a = self.critic2(batch.obs, a).flatten()
- actor_loss = (self._alpha * obs_result.log_prob.flatten()
- - torch.min(current_q1a, current_q2a)).mean()
+ actor_loss = (
+ self._alpha * obs_result.log_prob.flatten() -
+ torch.min(current_q1a, current_q2a)
+ ).mean()
self.actor_optim.zero_grad()
actor_loss.backward()
self.actor_optim.step()
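The reformatted `forward` above applies the tanh change-of-variables correction: the Gaussian log-density of the raw sample is reduced by log(action_scale * (1 - tanh(u)^2)) per action dimension. A minimal sketch of just that correction, assuming an unscaled action space (`action_scale = 1.0`):

import torch
from torch.distributions import Independent, Normal

eps = torch.finfo(torch.float32).eps
mu, sigma = torch.zeros(2, 3), torch.ones(2, 3)
dist = Independent(Normal(mu, sigma), 1)
u = dist.rsample()                          # raw, unbounded sample
log_prob = dist.log_prob(u).unsqueeze(-1)   # [batch, 1]
action_scale = 1.0                          # would be (high - low) / 2 with action_scaling
squashed_action = torch.tanh(u)
# subtract log |d tanh(u)/du| per dimension so log_prob matches the squashed action
log_prob = log_prob - torch.log(
    action_scale * (1 - squashed_action.pow(2)) + eps
).sum(-1, keepdim=True)
print(squashed_action.shape, log_prob.shape)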
diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py
index 3ca785bcc..a033237ea 100644
--- a/tianshou/policy/modelfree/td3.py
+++ b/tianshou/policy/modelfree/td3.py
@@ -1,11 +1,12 @@
-import torch
-import numpy as np
from copy import deepcopy
from typing import Any, Dict, Optional
-from tianshou.policy import DDPGPolicy
+import numpy as np
+import torch
+
from tianshou.data import Batch, ReplayBuffer
from tianshou.exploration import BaseNoise, GaussianNoise
+from tianshou.policy import DDPGPolicy
class TD3Policy(DDPGPolicy):
@@ -64,9 +65,10 @@ def __init__(
estimation_step: int = 1,
**kwargs: Any,
) -> None:
- super().__init__(actor, actor_optim, None, None, tau, gamma,
- exploration_noise, reward_normalization,
- estimation_step, **kwargs)
+ super().__init__(
+ actor, actor_optim, None, None, tau, gamma, exploration_noise,
+ reward_normalization, estimation_step, **kwargs
+ )
self.critic1, self.critic1_old = critic1, deepcopy(critic1)
self.critic1_old.eval()
self.critic1_optim = critic1_optim
@@ -103,16 +105,18 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
noise = noise.clamp(-self._noise_clip, self._noise_clip)
a_ += noise
target_q = torch.min(
- self.critic1_old(batch.obs_next, a_),
- self.critic2_old(batch.obs_next, a_))
+ self.critic1_old(batch.obs_next, a_), self.critic2_old(batch.obs_next, a_)
+ )
return target_q
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
# critic 1&2
td1, critic1_loss = self._mse_optimizer(
- batch, self.critic1, self.critic1_optim)
+ batch, self.critic1, self.critic1_optim
+ )
td2, critic2_loss = self._mse_optimizer(
- batch, self.critic2, self.critic2_optim)
+ batch, self.critic2, self.critic2_optim
+ )
batch.weight = (td1 + td2) / 2.0 # prio-buffer
# actor
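`_target_q` above is TD3's target policy smoothing: clipped Gaussian noise is added to the target action and the smaller of the two target critics is used. A toy sketch where the critics are plain callables acting only on the action (the real ones also take `obs_next`):

import torch

def td3_target_q(critic1_old, critic2_old, next_action, policy_noise, noise_clip):
    # clipped Gaussian noise on the target action, then the element-wise
    # minimum of the two target critics
    noise = (torch.randn_like(next_action) * policy_noise).clamp(-noise_clip, noise_clip)
    a_ = next_action + noise
    return torch.min(critic1_old(a_), critic2_old(a_))

q1 = lambda a: a.sum(dim=1, keepdim=True)      # toy critic, acts on the action only
q2 = lambda a: 2 * a.sum(dim=1, keepdim=True)  # toy critic
print(td3_target_q(q1, q2, torch.zeros(4, 2), policy_noise=0.2, noise_clip=0.5).shape)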
diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py
index b0ba63f11..75956d987 100644
--- a/tianshou/policy/modelfree/trpo.py
+++ b/tianshou/policy/modelfree/trpo.py
@@ -1,9 +1,9 @@
-import torch
import warnings
-import torch.nn.functional as F
from typing import Any, Dict, List, Type
-from torch.distributions import kl_divergence
+import torch
+import torch.nn.functional as F
+from torch.distributions import kl_divergence
from tianshou.data import Batch
from tianshou.policy import NPGPolicy
@@ -70,7 +70,7 @@ def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
actor_losses, vf_losses, step_sizes, kls = [], [], [], []
- for step in range(repeat):
+ for _ in range(repeat):
for b in batch.split(batch_size, merge_last=True):
# optimize actor
# direction: calculate vanilla gradient
@@ -79,7 +79,8 @@ def learn( # type: ignore
ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
actor_loss = -(ratio * b.adv).mean()
flat_grads = self._get_flat_grad(
- actor_loss, self.actor, retain_graph=True).detach()
+ actor_loss, self.actor, retain_graph=True
+ ).detach()
# direction: calculate natural gradient
with torch.no_grad():
@@ -89,26 +90,30 @@ def learn( # type: ignore
# calculate first order gradient of kl with respect to theta
flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True)
search_direction = -self._conjugate_gradients(
- flat_grads, flat_kl_grad, nsteps=10)
+ flat_grads, flat_kl_grad, nsteps=10
+ )
# stepsize: calculate max stepsize constrained by kl bound
- step_size = torch.sqrt(2 * self._delta / (
- search_direction * self._MVP(search_direction, flat_kl_grad)
- ).sum(0, keepdim=True))
+ step_size = torch.sqrt(
+ 2 * self._delta /
+ (search_direction *
+ self._MVP(search_direction, flat_kl_grad)).sum(0, keepdim=True)
+ )
# stepsize: linesearch stepsize
with torch.no_grad():
- flat_params = torch.cat([param.data.view(-1)
- for param in self.actor.parameters()])
+ flat_params = torch.cat(
+ [param.data.view(-1) for param in self.actor.parameters()]
+ )
for i in range(self._max_backtracks):
new_flat_params = flat_params + step_size * search_direction
self._set_from_flat_params(self.actor, new_flat_params)
# calculate kl and check that it stays in bound and the loss actually decreased
new_dist = self(b).dist
- new_dratio = (
- new_dist.log_prob(b.act) - b.logp_old).exp().float()
- new_dratio = new_dratio.reshape(
- new_dratio.size(0), -1).transpose(0, 1)
+ new_dratio = (new_dist.log_prob(b.act) -
+ b.logp_old).exp().float()
+ new_dratio = new_dratio.reshape(new_dratio.size(0),
+ -1).transpose(0, 1)
new_actor_loss = -(new_dratio * b.adv).mean()
kl = kl_divergence(old_dist, new_dist).mean()
@@ -121,8 +126,10 @@ def learn( # type: ignore
else:
self._set_from_flat_params(self.actor, new_flat_params)
step_size = torch.tensor([0.0])
- warnings.warn("Line search failed! It seems hyperparamters"
- " are poor and need to be changed.")
+ warnings.warn(
+ "Line search failed! It seems hyperparameters"
+ " are poor and need to be changed."
+ )
# optimize critic
for _ in range(self._optim_critic_iters):
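The `step_size` expression above is the largest step along the search direction that keeps the quadratic KL approximation within `delta`, i.e. sqrt(2 * delta / s^T H s); the backtracking loop then shrinks it if the exact KL or the loss misbehaves. A numeric sketch of that bound with an explicit stand-in Hessian:

import torch

def max_kl_step_size(search_direction, mvp, delta):
    # largest scalar step with 0.5 * step^2 * s^T H s <= delta,
    # the quadratic approximation of the KL divergence along s
    sHs = (search_direction * mvp(search_direction)).sum()
    return torch.sqrt(2 * delta / sHs)

H = torch.tensor([[2.0, 0.0], [0.0, 1.0]])  # stand-in for the KL Hessian
s = torch.tensor([1.0, 1.0])                # stand-in search direction
step = max_kl_step_size(s, lambda v: H @ v, delta=0.01)
print(step, 0.5 * step ** 2 * (s @ (H @ s)))  # second value recovers delta = 0.01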
diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py
index e7b50f07f..75705f4a3 100644
--- a/tianshou/policy/multiagent/mapolicy.py
+++ b/tianshou/policy/multiagent/mapolicy.py
@@ -1,8 +1,9 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+
import numpy as np
-from typing import Any, Dict, List, Tuple, Union, Optional
-from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer
+from tianshou.policy import BasePolicy
class MultiAgentPolicyManager(BasePolicy):
@@ -54,21 +55,22 @@ def process_fn(
tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
buffer._meta.rew = save_rew[:, policy.agent_id - 1]
results[f"agent_{policy.agent_id}"] = policy.process_fn(
- tmp_batch, buffer, tmp_indices)
+ tmp_batch, buffer, tmp_indices
+ )
if has_rew: # restore from save_rew
buffer._meta.rew = save_rew
return Batch(results)
- def exploration_noise(
- self, act: Union[np.ndarray, Batch], batch: Batch
- ) -> Union[np.ndarray, Batch]:
+ def exploration_noise(self, act: Union[np.ndarray, Batch],
+ batch: Batch) -> Union[np.ndarray, Batch]:
"""Add exploration noise from sub-policy onto act."""
for policy in self.policies:
agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
if len(agent_index) == 0:
continue
act[agent_index] = policy.exploration_noise(
- act[agent_index], batch[agent_index])
+ act[agent_index], batch[agent_index]
+ )
return act
def forward( # type: ignore
@@ -100,8 +102,8 @@ def forward( # type: ignore
"agent_n": xxx}
}
"""
- results: List[Tuple[bool, np.ndarray, Batch,
- Union[np.ndarray, Batch], Batch]] = []
+ results: List[Tuple[bool, np.ndarray, Batch, Union[np.ndarray, Batch],
+ Batch]] = []
for policy in self.policies:
# This part of the code is difficult to understand.
# Let's follow an example with two agents
@@ -119,20 +121,28 @@ def forward( # type: ignore
if isinstance(tmp_batch.rew, np.ndarray):
# reward can be empty Batch (after initial reset) or nparray.
tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
- out = policy(batch=tmp_batch, state=None if state is None
- else state["agent_" + str(policy.agent_id)],
- **kwargs)
+ out = policy(
+ batch=tmp_batch,
+ state=None if state is None else state["agent_" +
+ str(policy.agent_id)],
+ **kwargs
+ )
act = out.act
each_state = out.state \
if (hasattr(out, "state") and out.state is not None) \
else Batch()
results.append((True, agent_index, out, act, each_state))
- holder = Batch.cat([{"act": act} for
- (has_data, agent_index, out, act, each_state)
- in results if has_data])
+ holder = Batch.cat(
+ [
+ {
+ "act": act
+ } for (has_data, agent_index, out, act, each_state) in results
+ if has_data
+ ]
+ )
state_dict, out_dict = {}, {}
- for policy, (has_data, agent_index, out, act, state) in zip(
- self.policies, results):
+ for policy, (has_data, agent_index, out, act,
+ state) in zip(self.policies, results):
if has_data:
holder.act[agent_index] = act
state_dict["agent_" + str(policy.agent_id)] = state
@@ -141,9 +151,8 @@ def forward( # type: ignore
holder["state"] = state_dict
return holder
- def learn(
- self, batch: Batch, **kwargs: Any
- ) -> Dict[str, Union[float, List[float]]]:
+ def learn(self, batch: Batch,
+ **kwargs: Any) -> Dict[str, Union[float, List[float]]]:
"""Dispatch the data to all policies for learning.
:return: a dict with the following contents:
diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py
index 9c7f132af..dfb79564f 100644
--- a/tianshou/policy/random.py
+++ b/tianshou/policy/random.py
@@ -1,5 +1,6 @@
+from typing import Any, Dict, Optional, Union
+
import numpy as np
-from typing import Any, Dict, Union, Optional
from tianshou.data import Batch
from tianshou.policy import BasePolicy
diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py
index 9fa88fbc3..11b3a95ef 100644
--- a/tianshou/trainer/__init__.py
+++ b/tianshou/trainer/__init__.py
@@ -1,3 +1,7 @@
+"""Trainer package."""
+
+# isort:skip_file
+
from tianshou.trainer.utils import test_episode, gather_info
from tianshou.trainer.onpolicy import onpolicy_trainer
from tianshou.trainer.offpolicy import offpolicy_trainer
diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py
index a2bcf051a..72cd00d06 100644
--- a/tianshou/trainer/offline.py
+++ b/tianshou/trainer/offline.py
@@ -1,13 +1,14 @@
import time
-import tqdm
-import numpy as np
from collections import defaultdict
-from typing import Dict, Union, Callable, Optional
+from typing import Callable, Dict, Optional, Union
+
+import numpy as np
+import tqdm
-from tianshou.policy import BasePolicy
-from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
from tianshou.data import Collector, ReplayBuffer
-from tianshou.trainer import test_episode, gather_info
+from tianshou.policy import BasePolicy
+from tianshou.trainer import gather_info, test_episode
+from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
def offline_trainer(
@@ -74,17 +75,17 @@ def offline_trainer(
start_time = time.time()
test_collector.reset_stat()
- test_result = test_episode(policy, test_collector, test_fn, start_epoch,
- episode_per_test, logger, gradient_step, reward_metric)
+ test_result = test_episode(
+ policy, test_collector, test_fn, start_epoch, episode_per_test, logger,
+ gradient_step, reward_metric
+ )
best_epoch = start_epoch
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
for epoch in range(1 + start_epoch, 1 + max_epoch):
policy.train()
- with tqdm.trange(
- update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
- ) as t:
- for i in t:
+ with tqdm.trange(update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config) as t:
+ for _ in t:
gradient_step += 1
losses = policy.update(batch_size, buffer)
data = {"gradient_step": str(gradient_step)}
@@ -96,8 +97,9 @@ def offline_trainer(
t.set_postfix(**data)
# test
test_result = test_episode(
- policy, test_collector, test_fn, epoch, episode_per_test,
- logger, gradient_step, reward_metric)
+ policy, test_collector, test_fn, epoch, episode_per_test, logger,
+ gradient_step, reward_metric
+ )
rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch < 0 or best_reward < rew:
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
@@ -105,8 +107,10 @@ def offline_trainer(
save_fn(policy)
logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn)
if verbose:
- print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
- f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
+ print(
+ f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
+ f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
+ )
if stop_fn and stop_fn(best_reward):
break
return gather_info(start_time, None, test_collector, best_reward, best_reward_std)
diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py
index 2a576ccea..922646197 100644
--- a/tianshou/trainer/offpolicy.py
+++ b/tianshou/trainer/offpolicy.py
@@ -1,13 +1,14 @@
import time
-import tqdm
-import numpy as np
from collections import defaultdict
-from typing import Dict, Union, Callable, Optional
+from typing import Callable, Dict, Optional, Union
+
+import numpy as np
+import tqdm
from tianshou.data import Collector
from tianshou.policy import BasePolicy
-from tianshou.trainer import test_episode, gather_info
-from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
+from tianshou.trainer import gather_info, test_episode
+from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
def offpolicy_trainer(
@@ -43,7 +44,7 @@ def offpolicy_trainer(
:param int step_per_epoch: the number of transitions collected per epoch.
:param int step_per_collect: the number of transitions the collector would collect
before the network update, i.e., trainer will collect "step_per_collect"
- transitions and do some policy network update repeatly in each epoch.
+ transitions and do some policy network update repeatedly in each epoch.
:param episode_per_test: the number of episodes for one policy evaluation.
:param int batch_size: the batch size of sample data, which is going to feed in the
policy network.
@@ -91,8 +92,10 @@ def offpolicy_trainer(
train_collector.reset_stat()
test_collector.reset_stat()
test_in_train = test_in_train and train_collector.policy == policy
- test_result = test_episode(policy, test_collector, test_fn, start_epoch,
- episode_per_test, logger, env_step, reward_metric)
+ test_result = test_episode(
+ policy, test_collector, test_fn, start_epoch, episode_per_test, logger,
+ env_step, reward_metric
+ )
best_epoch = start_epoch
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
@@ -123,20 +126,23 @@ def offpolicy_trainer(
if result["n/ep"] > 0:
if test_in_train and stop_fn and stop_fn(result["rew"]):
test_result = test_episode(
- policy, test_collector, test_fn,
- epoch, episode_per_test, logger, env_step)
+ policy, test_collector, test_fn, epoch, episode_per_test,
+ logger, env_step
+ )
if stop_fn(test_result["rew"]):
if save_fn:
save_fn(policy)
logger.save_data(
- epoch, env_step, gradient_step, save_checkpoint_fn)
+ epoch, env_step, gradient_step, save_checkpoint_fn
+ )
t.set_postfix(**data)
return gather_info(
start_time, train_collector, test_collector,
- test_result["rew"], test_result["rew_std"])
+ test_result["rew"], test_result["rew_std"]
+ )
else:
policy.train()
- for i in range(round(update_per_step * result["n/st"])):
+ for _ in range(round(update_per_step * result["n/st"])):
gradient_step += 1
losses = policy.update(batch_size, train_collector.buffer)
for k in losses.keys():
@@ -148,8 +154,10 @@ def offpolicy_trainer(
if t.n <= t.total:
t.update()
# test
- test_result = test_episode(policy, test_collector, test_fn, epoch,
- episode_per_test, logger, env_step, reward_metric)
+ test_result = test_episode(
+ policy, test_collector, test_fn, epoch, episode_per_test, logger, env_step,
+ reward_metric
+ )
rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch < 0 or best_reward < rew:
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
@@ -157,9 +165,12 @@ def offpolicy_trainer(
save_fn(policy)
logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
if verbose:
- print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
- f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
+ print(
+ f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
+ f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
+ )
if stop_fn and stop_fn(best_reward):
break
- return gather_info(start_time, train_collector, test_collector,
- best_reward, best_reward_std)
+ return gather_info(
+ start_time, train_collector, test_collector, best_reward, best_reward_std
+ )
diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py
index 1696421e8..6788e8a65 100644
--- a/tianshou/trainer/onpolicy.py
+++ b/tianshou/trainer/onpolicy.py
@@ -1,13 +1,14 @@
import time
-import tqdm
-import numpy as np
from collections import defaultdict
-from typing import Dict, Union, Callable, Optional
+from typing import Callable, Dict, Optional, Union
+
+import numpy as np
+import tqdm
from tianshou.data import Collector
from tianshou.policy import BasePolicy
-from tianshou.trainer import test_episode, gather_info
-from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
+from tianshou.trainer import gather_info, test_episode
+from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config
def onpolicy_trainer(
@@ -50,10 +51,10 @@ def onpolicy_trainer(
policy network.
:param int step_per_collect: the number of transitions the collector would collect
before the network update, i.e., trainer will collect "step_per_collect"
- transitions and do some policy network update repeatly in each epoch.
+ transitions and do some policy network update repeatedly in each epoch.
:param int episode_per_collect: the number of episodes the collector would collect
before the network update, i.e., trainer will collect "episode_per_collect"
- episodes and do some policy network update repeatly in each epoch.
+ episodes and do some policy network update repeatedly in each epoch.
:param function train_fn: a hook called at the beginning of training in each epoch.
It can be used to perform custom additional operations, with the signature ``f(
num_epoch: int, step_idx: int) -> None``.
@@ -97,8 +98,10 @@ def onpolicy_trainer(
train_collector.reset_stat()
test_collector.reset_stat()
test_in_train = test_in_train and train_collector.policy == policy
- test_result = test_episode(policy, test_collector, test_fn, start_epoch,
- episode_per_test, logger, env_step, reward_metric)
+ test_result = test_episode(
+ policy, test_collector, test_fn, start_epoch, episode_per_test, logger,
+ env_step, reward_metric
+ )
best_epoch = start_epoch
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
@@ -111,8 +114,9 @@ def onpolicy_trainer(
while t.n < t.total:
if train_fn:
train_fn(epoch, env_step)
- result = train_collector.collect(n_step=step_per_collect,
- n_episode=episode_per_collect)
+ result = train_collector.collect(
+ n_step=step_per_collect, n_episode=episode_per_collect
+ )
if result["n/ep"] > 0 and reward_metric:
result["rews"] = reward_metric(result["rews"])
env_step += int(result["n/st"])
@@ -130,25 +134,32 @@ def onpolicy_trainer(
if result["n/ep"] > 0:
if test_in_train and stop_fn and stop_fn(result["rew"]):
test_result = test_episode(
- policy, test_collector, test_fn,
- epoch, episode_per_test, logger, env_step)
+ policy, test_collector, test_fn, epoch, episode_per_test,
+ logger, env_step
+ )
if stop_fn(test_result["rew"]):
if save_fn:
save_fn(policy)
logger.save_data(
- epoch, env_step, gradient_step, save_checkpoint_fn)
+ epoch, env_step, gradient_step, save_checkpoint_fn
+ )
t.set_postfix(**data)
return gather_info(
start_time, train_collector, test_collector,
- test_result["rew"], test_result["rew_std"])
+ test_result["rew"], test_result["rew_std"]
+ )
else:
policy.train()
losses = policy.update(
- 0, train_collector.buffer,
- batch_size=batch_size, repeat=repeat_per_collect)
+ 0,
+ train_collector.buffer,
+ batch_size=batch_size,
+ repeat=repeat_per_collect
+ )
train_collector.reset_buffer(keep_statistics=True)
- step = max([1] + [
- len(v) for v in losses.values() if isinstance(v, list)])
+ step = max(
+ [1] + [len(v) for v in losses.values() if isinstance(v, list)]
+ )
gradient_step += step
for k in losses.keys():
stat[k].add(losses[k])
@@ -159,8 +170,10 @@ def onpolicy_trainer(
if t.n <= t.total:
t.update()
# test
- test_result = test_episode(policy, test_collector, test_fn, epoch,
- episode_per_test, logger, env_step, reward_metric)
+ test_result = test_episode(
+ policy, test_collector, test_fn, epoch, episode_per_test, logger, env_step,
+ reward_metric
+ )
rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch < 0 or best_reward < rew:
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
@@ -168,9 +181,12 @@ def onpolicy_trainer(
save_fn(policy)
logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
if verbose:
- print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
- f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
+ print(
+ f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
+ f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
+ )
if stop_fn and stop_fn(best_reward):
break
- return gather_info(start_time, train_collector, test_collector,
- best_reward, best_reward_std)
+ return gather_info(
+ start_time, train_collector, test_collector, best_reward, best_reward_std
+ )
diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py
index 2e729feeb..a39a12fff 100644
--- a/tianshou/trainer/utils.py
+++ b/tianshou/trainer/utils.py
@@ -1,6 +1,7 @@
import time
+from typing import Any, Callable, Dict, Optional, Union
+
import numpy as np
-from typing import Any, Dict, Union, Callable, Optional
from tianshou.data import Collector
from tianshou.policy import BasePolicy
@@ -71,11 +72,13 @@ def gather_info(
if train_c is not None:
model_time -= train_c.collect_time
train_speed = train_c.collect_step / (duration - test_c.collect_time)
- result.update({
- "train_step": train_c.collect_step,
- "train_episode": train_c.collect_episode,
- "train_time/collector": f"{train_c.collect_time:.2f}s",
- "train_time/model": f"{model_time:.2f}s",
- "train_speed": f"{train_speed:.2f} step/s",
- })
+ result.update(
+ {
+ "train_step": train_c.collect_step,
+ "train_episode": train_c.collect_episode,
+ "train_time/collector": f"{train_c.collect_time:.2f}s",
+ "train_time/model": f"{model_time:.2f}s",
+ "train_speed": f"{train_speed:.2f} step/s",
+ }
+ )
return result
diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py
index 4ad73481c..5af038ab3 100644
--- a/tianshou/utils/__init__.py
+++ b/tianshou/utils/__init__.py
@@ -1,17 +1,12 @@
+"""Utils package."""
+
from tianshou.utils.config import tqdm_config
-from tianshou.utils.statistics import MovAvg, RunningMeanStd
from tianshou.utils.logger.base import BaseLogger, LazyLogger
-from tianshou.utils.logger.tensorboard import TensorboardLogger, BasicLogger
+from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger
from tianshou.utils.logger.wandb import WandBLogger
-
+from tianshou.utils.statistics import MovAvg, RunningMeanStd
__all__ = [
- "MovAvg",
- "RunningMeanStd",
- "tqdm_config",
- "BaseLogger",
- "TensorboardLogger",
- "BasicLogger",
- "LazyLogger",
- "WandBLogger"
+ "MovAvg", "RunningMeanStd", "tqdm_config", "BaseLogger", "TensorboardLogger",
+ "BasicLogger", "LazyLogger", "WandBLogger"
]
diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py
index c1ffe760d..9b89d5e88 100644
--- a/tianshou/utils/logger/base.py
+++ b/tianshou/utils/logger/base.py
@@ -1,7 +1,8 @@
-import numpy as np
-from numbers import Number
from abc import ABC, abstractmethod
-from typing import Dict, Tuple, Union, Callable, Optional
+from numbers import Number
+from typing import Callable, Dict, Optional, Tuple, Union
+
+import numpy as np
LOG_DATA_TYPE = Dict[str, Union[int, Number, np.number, np.ndarray]]
diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py
index c65576b41..86e873cda 100644
--- a/tianshou/utils/logger/tensorboard.py
+++ b/tianshou/utils/logger/tensorboard.py
@@ -1,10 +1,10 @@
import warnings
-from typing import Any, Tuple, Callable, Optional
+from typing import Any, Callable, Optional, Tuple
-from torch.utils.tensorboard import SummaryWriter
from tensorboard.backend.event_processing import event_accumulator
+from torch.utils.tensorboard import SummaryWriter
-from tianshou.utils.logger.base import BaseLogger, LOG_DATA_TYPE
+from tianshou.utils.logger.base import LOG_DATA_TYPE, BaseLogger
class TensorboardLogger(BaseLogger):
@@ -48,8 +48,10 @@ def save_data(
save_checkpoint_fn(epoch, env_step, gradient_step)
self.write("save/epoch", epoch, {"save/epoch": epoch})
self.write("save/env_step", env_step, {"save/env_step": env_step})
- self.write("save/gradient_step", gradient_step,
- {"save/gradient_step": gradient_step})
+ self.write(
+ "save/gradient_step", gradient_step,
+ {"save/gradient_step": gradient_step}
+ )
def restore_data(self) -> Tuple[int, int, int]:
ea = event_accumulator.EventAccumulator(self.writer.log_dir)
@@ -79,5 +81,6 @@ class BasicLogger(TensorboardLogger):
def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
- "Deprecated soon: BasicLogger has renamed to TensorboardLogger in #427.")
+ "Deprecated soon: BasicLogger has been renamed to TensorboardLogger in #427."
+ )
super().__init__(*args, **kwargs)
diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py
index cb11abc2e..b518a54a7 100644
--- a/tianshou/utils/net/common.py
+++ b/tianshou/utils/net/common.py
@@ -1,7 +1,8 @@
-import torch
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
+
import numpy as np
+import torch
from torch import nn
-from typing import Any, Dict, List, Type, Tuple, Union, Optional, Sequence
ModuleType = Type[nn.Module]
@@ -32,7 +33,7 @@ class MLP(nn.Module):
:param int input_dim: dimension of the input vector.
:param int output_dim: dimension of the output vector. If set to 0, there
is no final linear layer.
- :param hidden_sizes: shape of MLP passed in as a list, not incluing
+ :param hidden_sizes: shape of MLP passed in as a list, not including
input_dim and output_dim.
:param norm_layer: use which normalization before activation, e.g.,
``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization.
@@ -40,7 +41,7 @@ class MLP(nn.Module):
of hidden_sizes, to use different normalization module in different
layers. Default to no normalization.
:param activation: which activation to use after each layer, can be both
- the same actvition for all layers if passed in nn.Module, or different
+ the same activation for all layers if passed in nn.Module, or different
activation for different Modules if passed in a list. Default to
nn.ReLU.
:param device: which device to create this model on. Default to None.
@@ -64,8 +65,7 @@ def __init__(
assert len(norm_layer) == len(hidden_sizes)
norm_layer_list = norm_layer
else:
- norm_layer_list = [
- norm_layer for _ in range(len(hidden_sizes))]
+ norm_layer_list = [norm_layer for _ in range(len(hidden_sizes))]
else:
norm_layer_list = [None] * len(hidden_sizes)
if activation:
@@ -73,26 +73,22 @@ def __init__(
assert len(activation) == len(hidden_sizes)
activation_list = activation
else:
- activation_list = [
- activation for _ in range(len(hidden_sizes))]
+ activation_list = [activation for _ in range(len(hidden_sizes))]
else:
activation_list = [None] * len(hidden_sizes)
hidden_sizes = [input_dim] + list(hidden_sizes)
model = []
for in_dim, out_dim, norm, activ in zip(
- hidden_sizes[:-1], hidden_sizes[1:],
- norm_layer_list, activation_list):
+ hidden_sizes[:-1], hidden_sizes[1:], norm_layer_list, activation_list
+ ):
model += miniblock(in_dim, out_dim, norm, activ, linear_layer)
if output_dim > 0:
model += [linear_layer(hidden_sizes[-1], output_dim)]
self.output_dim = output_dim or hidden_sizes[-1]
self.model = nn.Sequential(*model)
- def forward(
- self, x: Union[np.ndarray, torch.Tensor]
- ) -> torch.Tensor:
- x = torch.as_tensor(
- x, device=self.device, dtype=torch.float32) # type: ignore
+ def forward(self, x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ x = torch.as_tensor(x, device=self.device, dtype=torch.float32) # type: ignore
return self.model(x.flatten(1))
@@ -111,7 +107,7 @@ class Net(nn.Module):
of hidden_sizes, to use different normalization module in different
layers. Default to no normalization.
:param activation: which activation to use after each layer, can be both
- the same actvition for all layers if passed in nn.Module, or different
+ the same activation for all layers if passed in nn.Module, or different
activation for different Modules if passed in a list. Default to
nn.ReLU.
:param device: specify the device when the network actually runs. Default
@@ -162,8 +158,9 @@ def __init__(
input_dim += action_dim
self.use_dueling = dueling_param is not None
output_dim = action_dim if not self.use_dueling and not concat else 0
- self.model = MLP(input_dim, output_dim, hidden_sizes,
- norm_layer, activation, device)
+ self.model = MLP(
+ input_dim, output_dim, hidden_sizes, norm_layer, activation, device
+ )
self.output_dim = self.model.output_dim
if self.use_dueling: # dueling DQN
q_kwargs, v_kwargs = dueling_param # type: ignore
@@ -172,10 +169,14 @@ def __init__(
q_output_dim, v_output_dim = action_dim, num_atoms
q_kwargs: Dict[str, Any] = {
**q_kwargs, "input_dim": self.output_dim,
- "output_dim": q_output_dim, "device": self.device}
+ "output_dim": q_output_dim,
+ "device": self.device
+ }
v_kwargs: Dict[str, Any] = {
**v_kwargs, "input_dim": self.output_dim,
- "output_dim": v_output_dim, "device": self.device}
+ "output_dim": v_output_dim,
+ "device": self.device
+ }
self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs)
self.output_dim = self.Q.output_dim
@@ -239,8 +240,7 @@ def forward(
training mode, s should be with shape ``[bsz, len, dim]``. See the code
and comment for more detail.
"""
- s = torch.as_tensor(
- s, device=self.device, dtype=torch.float32) # type: ignore
+ s = torch.as_tensor(s, device=self.device, dtype=torch.float32) # type: ignore
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
# In short, the tensor's shape in the training phase is longer than that
# in the evaluation phase.
@@ -253,9 +253,12 @@ def forward(
else:
# we store the stack data in [bsz, len, ...] format
# but pytorch rnn needs [len, bsz, ...]
- s, (h, c) = self.nn(s, (state["h"].transpose(0, 1).contiguous(),
- state["c"].transpose(0, 1).contiguous()))
+ s, (h, c) = self.nn(
+ s, (
+ state["h"].transpose(0, 1).contiguous(),
+ state["c"].transpose(0, 1).contiguous()
+ )
+ )
s = self.fc2(s[:, -1])
# please ensure the first dim is batch size: [bsz, len, ...]
- return s, {"h": h.transpose(0, 1).detach(),
- "c": c.transpose(0, 1).detach()}
+ return s, {"h": h.transpose(0, 1).detach(), "c": c.transpose(0, 1).detach()}
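When `dueling_param` is set, `Net` above builds separate `Q` and `V` MLP heads and combines them as Q(s, a) = V(s) + A(s, a) - mean_a A(s, a). A toy module showing just that aggregation, with illustrative sizes:

import torch
from torch import nn

class DuelingHead(nn.Module):
    # toy version of the aggregation Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)
    def __init__(self, feature_dim, num_actions):
        super().__init__()
        self.advantage = nn.Linear(feature_dim, num_actions)
        self.value = nn.Linear(feature_dim, 1)

    def forward(self, features):
        a = self.advantage(features)
        v = self.value(features)
        return v + a - a.mean(dim=1, keepdim=True)

head = DuelingHead(feature_dim=16, num_actions=4)
print(head(torch.randn(2, 16)).shape)  # torch.Size([2, 4])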
diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py
index 36c178612..1bb090cdf 100644
--- a/tianshou/utils/net/continuous.py
+++ b/tianshou/utils/net/continuous.py
@@ -1,11 +1,11 @@
-import torch
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
+
import numpy as np
+import torch
from torch import nn
-from typing import Any, Dict, Tuple, Union, Optional, Sequence
from tianshou.utils.net.common import MLP
-
SIGMA_MIN = -20
SIGMA_MAX = 2
@@ -47,10 +47,8 @@ def __init__(
self.device = device
self.preprocess = preprocess_net
self.output_dim = int(np.prod(action_shape))
- input_dim = getattr(preprocess_net, "output_dim",
- preprocess_net_output_dim)
- self.last = MLP(input_dim, self.output_dim,
- hidden_sizes, device=self.device)
+ input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
+ self.last = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
self._max = max_action
def forward(
@@ -97,8 +95,7 @@ def __init__(
self.device = device
self.preprocess = preprocess_net
self.output_dim = 1
- input_dim = getattr(preprocess_net, "output_dim",
- preprocess_net_output_dim)
+ input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
self.last = MLP(input_dim, 1, hidden_sizes, device=self.device)
def forward(
@@ -109,11 +106,15 @@ def forward(
) -> torch.Tensor:
"""Mapping: (s, a) -> logits -> Q(s, a)."""
s = torch.as_tensor(
- s, device=self.device, dtype=torch.float32 # type: ignore
+ s,
+ device=self.device, # type: ignore
+ dtype=torch.float32,
).flatten(1)
if a is not None:
a = torch.as_tensor(
- a, device=self.device, dtype=torch.float32 # type: ignore
+ a,
+ device=self.device, # type: ignore
+ dtype=torch.float32,
).flatten(1)
s = torch.cat([s, a], dim=1)
logits, h = self.preprocess(s)
@@ -163,14 +164,13 @@ def __init__(
self.preprocess = preprocess_net
self.device = device
self.output_dim = int(np.prod(action_shape))
- input_dim = getattr(preprocess_net, "output_dim",
- preprocess_net_output_dim)
- self.mu = MLP(input_dim, self.output_dim,
- hidden_sizes, device=self.device)
+ input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
+ self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
self._c_sigma = conditioned_sigma
if conditioned_sigma:
- self.sigma = MLP(input_dim, self.output_dim,
- hidden_sizes, device=self.device)
+ self.sigma = MLP(
+ input_dim, self.output_dim, hidden_sizes, device=self.device
+ )
else:
self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))
self._max = max_action
@@ -188,9 +188,7 @@ def forward(
if not self._unbounded:
mu = self._max * torch.tanh(mu)
if self._c_sigma:
- sigma = torch.clamp(
- self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX
- ).exp()
+ sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp()
else:
shape = [1] * len(mu.shape)
shape[1] = -1
@@ -241,8 +239,7 @@ def forward(
info: Dict[str, Any] = {},
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Dict[str, torch.Tensor]]:
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
- s = torch.as_tensor(
- s, device=self.device, dtype=torch.float32) # type: ignore
+ s = torch.as_tensor(s, device=self.device, dtype=torch.float32) # type: ignore
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
# In short, the tensor's shape in the training phase is longer than that
# in the evaluation phase.
@@ -254,23 +251,27 @@ def forward(
else:
# we store the stack data in [bsz, len, ...] format
# but pytorch rnn needs [len, bsz, ...]
- s, (h, c) = self.nn(s, (state["h"].transpose(0, 1).contiguous(),
- state["c"].transpose(0, 1).contiguous()))
+ s, (h, c) = self.nn(
+ s, (
+ state["h"].transpose(0, 1).contiguous(),
+ state["c"].transpose(0, 1).contiguous()
+ )
+ )
logits = s[:, -1]
mu = self.mu(logits)
if not self._unbounded:
mu = self._max * torch.tanh(mu)
if self._c_sigma:
- sigma = torch.clamp(
- self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX
- ).exp()
+ sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp()
else:
shape = [1] * len(mu.shape)
shape[1] = -1
sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
# please ensure the first dim is batch size: [bsz, len, ...]
- return (mu, sigma), {"h": h.transpose(0, 1).detach(),
- "c": c.transpose(0, 1).detach()}
+ return (mu, sigma), {
+ "h": h.transpose(0, 1).detach(),
+ "c": c.transpose(0, 1).detach()
+ }
class RecurrentCritic(nn.Module):
@@ -307,8 +308,7 @@ def forward(
info: Dict[str, Any] = {},
) -> torch.Tensor:
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
- s = torch.as_tensor(
- s, device=self.device, dtype=torch.float32) # type: ignore
+ s = torch.as_tensor(s, device=self.device, dtype=torch.float32) # type: ignore
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
# In short, the tensor's shape in the training phase is longer than that
# in the evaluation phase.
@@ -318,7 +318,10 @@ def forward(
s = s[:, -1]
if a is not None:
a = torch.as_tensor(
- a, device=self.device, dtype=torch.float32) # type: ignore
+ a,
+ device=self.device, # type: ignore
+ dtype=torch.float32,
+ )
s = torch.cat([s, a], dim=1)
s = self.fc2(s)
return s
diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py
index 200ae9d3a..bcc6531e3 100644
--- a/tianshou/utils/net/discrete.py
+++ b/tianshou/utils/net/discrete.py
@@ -1,8 +1,9 @@
-import torch
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
+
import numpy as np
-from torch import nn
+import torch
import torch.nn.functional as F
-from typing import Any, Dict, Tuple, Union, Optional, Sequence
+from torch import nn
from tianshou.data import Batch
from tianshou.utils.net.common import MLP
@@ -47,10 +48,8 @@ def __init__(
self.device = device
self.preprocess = preprocess_net
self.output_dim = int(np.prod(action_shape))
- input_dim = getattr(preprocess_net, "output_dim",
- preprocess_net_output_dim)
- self.last = MLP(input_dim, self.output_dim,
- hidden_sizes, device=self.device)
+ input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
+ self.last = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
self.softmax_output = softmax_output
def forward(
@@ -101,10 +100,8 @@ def __init__(
self.device = device
self.preprocess = preprocess_net
self.output_dim = last_size
- input_dim = getattr(preprocess_net, "output_dim",
- preprocess_net_output_dim)
- self.last = MLP(input_dim, last_size,
- hidden_sizes, device=self.device)
+ input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
+ self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)
def forward(
self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any
@@ -141,9 +138,8 @@ def forward(self, taus: torch.Tensor) -> torch.Tensor:
start=1, end=self.num_cosines + 1, dtype=taus.dtype, device=taus.device
).view(1, 1, self.num_cosines)
# Calculate cos(i * \pi * \tau).
- cosines = torch.cos(taus.view(batch_size, N, 1) * i_pi).view(
- batch_size * N, self.num_cosines
- )
+ cosines = torch.cos(taus.view(batch_size, N, 1) * i_pi
+ ).view(batch_size * N, self.num_cosines)
# Calculate embeddings of taus.
tau_embeddings = self.net(cosines).view(batch_size, N, self.embedding_dim)
return tau_embeddings
@@ -181,10 +177,12 @@ def __init__(
device: Union[str, int, torch.device] = "cpu"
) -> None:
last_size = np.prod(action_shape)
- super().__init__(preprocess_net, hidden_sizes, last_size,
- preprocess_net_output_dim, device)
- self.input_dim = getattr(preprocess_net, "output_dim",
- preprocess_net_output_dim)
+ super().__init__(
+ preprocess_net, hidden_sizes, last_size, preprocess_net_output_dim, device
+ )
+ self.input_dim = getattr(
+ preprocess_net, "output_dim", preprocess_net_output_dim
+ )
self.embed_model = CosineEmbeddingNetwork(num_cosines,
self.input_dim).to(device)
@@ -195,13 +193,12 @@ def forward( # type: ignore
logits, h = self.preprocess(s, state=kwargs.get("state", None))
# Sample fractions.
batch_size = logits.size(0)
- taus = torch.rand(batch_size, sample_size,
- dtype=logits.dtype, device=logits.device)
- embedding = (logits.unsqueeze(1) * self.embed_model(taus)).view(
- batch_size * sample_size, -1
+ taus = torch.rand(
+ batch_size, sample_size, dtype=logits.dtype, device=logits.device
)
- out = self.last(embedding).view(
- batch_size, sample_size, -1).transpose(1, 2)
+ embedding = (logits.unsqueeze(1) *
+ self.embed_model(taus)).view(batch_size * sample_size, -1)
+ out = self.last(embedding).view(batch_size, sample_size, -1).transpose(1, 2)
return (out, taus), h
@@ -270,20 +267,18 @@ def __init__(
device: Union[str, int, torch.device] = "cpu",
) -> None:
super().__init__(
- preprocess_net, action_shape, hidden_sizes,
- num_cosines, preprocess_net_output_dim, device
+ preprocess_net, action_shape, hidden_sizes, num_cosines,
+ preprocess_net_output_dim, device
)
def _compute_quantiles(
self, obs: torch.Tensor, taus: torch.Tensor
) -> torch.Tensor:
batch_size, sample_size = taus.shape
- embedding = (obs.unsqueeze(1) * self.embed_model(taus)).view(
- batch_size * sample_size, -1
- )
- quantiles = self.last(embedding).view(
- batch_size, sample_size, -1
- ).transpose(1, 2)
+ embedding = (obs.unsqueeze(1) *
+ self.embed_model(taus)).view(batch_size * sample_size, -1)
+ quantiles = self.last(embedding).view(batch_size, sample_size,
+ -1).transpose(1, 2)
return quantiles
def forward( # type: ignore
@@ -328,10 +323,8 @@ def __init__(
super().__init__()
# Learnable parameters.
- self.mu_W = nn.Parameter(
- torch.FloatTensor(out_features, in_features))
- self.sigma_W = nn.Parameter(
- torch.FloatTensor(out_features, in_features))
+ self.mu_W = nn.Parameter(torch.FloatTensor(out_features, in_features))
+ self.sigma_W = nn.Parameter(torch.FloatTensor(out_features, in_features))
self.mu_bias = nn.Parameter(torch.FloatTensor(out_features))
self.sigma_bias = nn.Parameter(torch.FloatTensor(out_features))
diff --git a/tianshou/utils/statistics.py b/tianshou/utils/statistics.py
index e0d06763a..a81af601e 100644
--- a/tianshou/utils/statistics.py
+++ b/tianshou/utils/statistics.py
@@ -1,8 +1,9 @@
-import torch
-import numpy as np
from numbers import Number
from typing import List, Union
+import numpy as np
+import torch
+
class MovAvg(object):
"""Class for moving average.
@@ -66,13 +67,15 @@ def std(self) -> float:
class RunningMeanStd(object):
- """Calulates the running mean and std of a data stream.
+ """Calculates the running mean and std of a data stream.
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
"""
def __init__(
- self, mean: Union[float, np.ndarray] = 0.0, std: Union[float, np.ndarray] = 1.0
+ self,
+ mean: Union[float, np.ndarray] = 0.0,
+ std: Union[float, np.ndarray] = 1.0
) -> None:
self.mean, self.var = mean, std
self.count = 0
@@ -88,7 +91,7 @@ def update(self, x: np.ndarray) -> None:
new_mean = self.mean + delta * batch_count / total_count
m_a = self.var * self.count
m_b = batch_var * batch_count
- m_2 = m_a + m_b + delta ** 2 * self.count * batch_count / total_count
+ m_2 = m_a + m_b + delta**2 * self.count * batch_count / total_count
new_var = m_2 / total_count
self.mean, self.var = new_mean, new_var
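The line reformatted above (`m_2 = m_a + m_b + delta**2 * ...`) is the combine step of the parallel variance algorithm referenced in the docstring. A minimal standalone sketch of that combination (function and variable names chosen for this example, not part of tianshou's API) is:

```python
import numpy as np

def combine(mean_a, var_a, count_a, batch):
    """Merge a running (mean, var, count) with a new batch, parallel-algorithm style."""
    batch_mean = batch.mean(axis=0)
    batch_var = batch.var(axis=0)
    batch_count = batch.shape[0]
    delta = batch_mean - mean_a
    total_count = count_a + batch_count
    new_mean = mean_a + delta * batch_count / total_count
    m_a = var_a * count_a
    m_b = batch_var * batch_count
    m_2 = m_a + m_b + delta**2 * count_a * batch_count / total_count
    return new_mean, m_2 / total_count, total_count

# Merging the two halves reproduces the statistics of the full array.
x = np.random.randn(100, 3)
mean, var, count = combine(x[:60].mean(axis=0), x[:60].var(axis=0), 60, x[60:])
assert np.allclose(mean, x.mean(axis=0)) and np.allclose(var, x.var(axis=0))
```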