diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index d72600676..280583538 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -3,14 +3,7 @@ + [ ] algorithm implementation fix + [ ] documentation modification + [ ] new feature +- [ ] I have reformatted the code using `make format` (**required**) +- [ ] I have checked the code using `make commit-checks` (**required**) - [ ] If applicable, I have mentioned the relevant/related issue(s) - -Less important but also useful: - -- [ ] I have visited the [source website](https://github.com/thu-ml/tianshou) -- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates -- [ ] I have mentioned version numbers, operating system and environment, where applicable: - ```python - import tianshou, torch, numpy, sys - print(tianshou.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform) - ``` +- [ ] If applicable, I have listed every item in this Pull Request below diff --git a/.github/workflows/lint_and_docs.yml b/.github/workflows/lint_and_docs.yml index 3a69bfa71..681654689 100644 --- a/.github/workflows/lint_and_docs.yml +++ b/.github/workflows/lint_and_docs.yml @@ -20,13 +20,14 @@ jobs: - name: Lint with flake8 run: | flake8 . --count --show-source --statistics + - name: Code formatter + run: | + yapf -r -d . + isort --check . - name: Type check run: | mypy - name: Documentation test run: | - pydocstyle tianshou - doc8 docs --max-line-length 1000 - cd docs - make html SPHINXOPTS="-W" - cd .. + make check-docstyle + make spelling diff --git a/.github/workflows/profile.yml b/.github/workflows/profile.yml index 3d62da417..82c793e99 100644 --- a/.github/workflows/profile.yml +++ b/.github/workflows/profile.yml @@ -5,6 +5,7 @@ on: [push, pull_request] jobs: profile: runs-on: ubuntu-latest + if: "!contains(github.event.head_commit.message, 'ci skip')" steps: - uses: actions/checkout@v2 - name: Set up Python 3.8 diff --git a/.gitignore b/.gitignore index be8453abd..e9510a1df 100644 --- a/.gitignore +++ b/.gitignore @@ -147,3 +147,4 @@ MUJOCO_LOG.TXT *.swp *.pkl *.hdf5 +wandb/ diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..da5030ccd --- /dev/null +++ b/Makefile @@ -0,0 +1,60 @@ +SHELL=/bin/bash +PROJECT_NAME=tianshou +PROJECT_PATH=${PROJECT_NAME}/ +LINT_PATHS=${PROJECT_PATH} test/ docs/conf.py examples/ setup.py + +check_install = python3 -c "import $(1)" || pip3 install $(1) --upgrade +check_install_extra = python3 -c "import $(1)" || pip3 install $(2) --upgrade + +pytest: + $(call check_install, pytest) + $(call check_install, pytest_cov) + $(call check_install, pytest_xdist) + pytest test --cov ${PROJECT_PATH} --durations 0 -v --cov-report term-missing + +mypy: + $(call check_install, mypy) + mypy ${PROJECT_NAME} + +lint: + $(call check_install, flake8) + $(call check_install_extra, bugbear, flake8_bugbear) + flake8 ${LINT_PATHS} --count --show-source --statistics + +format: + # sort imports + $(call check_install, isort) + isort ${LINT_PATHS} + # reformat using yapf + $(call check_install, yapf) + yapf -ir ${LINT_PATHS} + +check-codestyle: + $(call check_install, isort) + $(call check_install, yapf) + isort --check ${LINT_PATHS} && yapf -r -d ${LINT_PATHS} + +check-docstyle: + $(call check_install, pydocstyle) + $(call check_install, doc8) + $(call check_install, sphinx) + $(call check_install, sphinx_rtd_theme) + pydocstyle ${PROJECT_PATH} && doc8 docs && cd docs && make html SPHINXOPTS="-W" +
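As a rough illustration of the two helper macros defined at the top of this Makefile: `check_install` probes for an importable module and falls back to pip, while `check_install_extra` covers the case where the pip package name differs from the import name. The calls in the `lint` target, for example, expand to approximately the following shell commands:

    python3 -c "import flake8" || pip3 install flake8 --upgrade
    python3 -c "import bugbear" || pip3 install flake8_bugbear --upgrade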
+doc: + $(call check_install, sphinx) + $(call check_install, sphinx_rtd_theme) + cd docs && make html && cd _build/html && python3 -m http.server + +spelling: + $(call check_install, sphinx) + $(call check_install, sphinx_rtd_theme) + $(call check_install_extra, sphinxcontrib.spelling, sphinxcontrib.spelling pyenchant) + cd docs && make spelling SPHINXOPTS="-W" + +clean: + cd docs && make clean + +commit-checks: format lint mypy check-docstyle spelling + +.PHONY: clean spelling doc mypy lint format check-codestyle check-docstyle commit-checks diff --git a/docs/conf.py b/docs/conf.py index c258cf0a7..57d8b48df 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -14,9 +14,9 @@ # import sys # sys.path.insert(0, os.path.abspath('.')) +import sphinx_rtd_theme import tianshou -import sphinx_rtd_theme # Get the version string version = tianshou.__version__ @@ -30,7 +30,6 @@ # The full version, including alpha/beta/rc tags release = version - # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be @@ -51,7 +50,7 @@ # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] -source_suffix = [".rst", ".md"] +source_suffix = [".rst"] master_doc = "index" # List of patterns, relative to source directory, that match files and @@ -59,7 +58,8 @@ # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] autodoc_default_options = { - "special-members": ", ".join( + "special-members": + ", ".join( [ "__len__", "__call__", diff --git a/docs/contributing.rst b/docs/contributing.rst index cf015a95e..d1de0b65b 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -18,14 +18,26 @@ in the main directory. This installation is removable by $ python setup.py develop --uninstall -PEP8 Code Style Check ---------------------- +PEP8 Code Style Check and Code Formatter +---------------------------------------- -We follow PEP8 python code style. To check, in the main directory, run: +We follow the PEP8 Python code style, checked with flake8. To check, in the main directory, run: .. code-block:: bash - $ flake8 . --count --show-source --statistics + $ make lint + +We use isort and yapf to format all code. To format, in the main directory, run: + +.. code-block:: bash + + $ make format + +To check whether the code is formatted correctly, in the main directory, run: + +.. code-block:: bash + + $ make check-codestyle Type Check @@ -35,7 +47,7 @@ We use `mypy `_ to check the type annotations. .. code-block:: bash - $ mypy + $ make mypy Test Locally @@ -45,7 +57,7 @@ This command will run automatic tests in the main directory .. code-block:: bash - $ pytest test --cov tianshou -s --durations 0 -v + $ make pytest Test by GitHub Actions @@ -76,13 +88,13 @@ Documentations are written under the ``docs/`` directory as ReStructuredText (`` API References are automatically generated by `Sphinx `_ according to the outlines under ``docs/api/`` and should be modified when any code changes. -To compile documentation into webpages, run +To compile the documentation into a webpage, run .. code-block:: bash - $ make html + $ make doc -under the ``docs/`` directory. The generated webpages are in ``docs/_build`` and can be viewed with browsers. +The generated webpage is in ``docs/_build`` and can be viewed with a browser (http://0.0.0.0:8000/). Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/.
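As a quick reference for the contributing workflow described above, a typical pre-commit sequence with the new targets (a sketch, assuming the Makefile introduced in this change) is:

    $ make format         # isort + yapf, rewrites the files in place
    $ make commit-checks  # chains format, lint, mypy, check-docstyle and spelling
    $ make pytest         # runs the test suite with coverage reporting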
@@ -92,21 +104,14 @@ Documentation Generation Test We have the following three documentation tests: -1. pydocstyle: test docstrings under ``tianshou/``. To check, in the main directory, run: - -.. code-block:: bash - - $ pydocstyle tianshou - -2. doc8: test ReStructuredText formats. To check, in the main directory, run: +1. pydocstyle: test every docstring under ``tianshou/``; -.. code-block:: bash +2. doc8: test ReStructuredText format; - $ doc8 docs +3. sphinx test: test whether there is any error/warning when generating the front-end html documentation. -3. sphinx test: test if there is any errors/warnings when generating front-end html documentations. To check, in the main directory, run: +To check, in the main directory, run: .. code-block:: bash - $ cd docs - $ make html SPHINXOPTS="-W" + $ make check-docstyle diff --git a/docs/contributor.rst b/docs/contributor.rst index c594b2c0d..d71dc385b 100644 --- a/docs/contributor.rst +++ b/docs/contributor.rst @@ -4,7 +4,6 @@ Contributor We always welcome contributions to help make Tianshou better. Below are an incomplete list of our contributors (find more on `this page `_). * Jiayi Weng (`Trinkle23897 `_) -* Minghao Zhang (`Mehooz `_) * Alexis Duburcq (`duburcqa `_) * Kaichao You (`youkaichao `_) * Huayu Chen (`ChenDRAG `_) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt new file mode 100644 index 000000000..3649df71d --- /dev/null +++ b/docs/spelling_wordlist.txt @@ -0,0 +1,135 @@ +tianshou +arXiv +tanh +lr +logits +env +envs +optim +eps +timelimit +TimeLimit +maxsize +timestep +numpy +ndarray +stackoverflow +len +tac +fqf +iqn +qrdqn +rl +quantile +quantiles +dqn +param +async +subprocess +nn +equ +cql +fn +boolean +pre +np +rnn +rew +pre +perceptron +bsz +dataset +mujoco +jit +nstep +preprocess +repo +ReLU +namespace +th +utils +NaN +linesearch +hyperparameters +pseudocode +entropies +nn +config +cpu +rms +debias +indice +regularizer +miniblock +modularize +serializable +softmax +vectorized +optimizers +undiscounted +submodule +subclasses +submodules +tfevent +dirichlet +docstring +webpage +formatter +num +py +pythonic +中文文档位于 +conda +miniconda +Amir +Andreas +Antonoglou +Beattie +Bellemare +Charles +Daan +Demis +Dharshan +Fidjeland +Georg +Hassabis +Helen +Ioannis +Kavukcuoglu +King +Koray +Kumaran +Legg +Mnih +Ostrovski +Petersen +Riedmiller +Rusu +Sadik +Shane +Stig +Veness +Volodymyr +Wierstra +Lillicrap +Pritzel +Heess +Erez +Yuval +Tassa +Schulman +Filip +Wolski +Prafulla +Dhariwal +Radford +Oleg +Klimov +Kaichao +Jiayi +Weng +Duburcq +Huayu +Strens +Ornstein +Uhlenbeck diff --git a/docs/tutorials/batch.rst b/docs/tutorials/batch.rst index 49d913d9f..71f82f84e 100644 --- a/docs/tutorials/batch.rst +++ b/docs/tutorials/batch.rst @@ -60,7 +60,7 @@ The content of ``Batch`` objects can be defined by the following rules. 2. The keys are always strings (they are names of corresponding values). -3. The values can be scalars, tensors, or Batch objects. The recurse definition makes it possible to form a hierarchy of batches. +3. The values can be scalars, tensors, or Batch objects. The recursive definition makes it possible to form a hierarchy of batches. 4. Tensors are the most important values. In short, tensors are n-dimensional arrays of the same data type. We support two types of tensors: `PyTorch `_ tensor type ``torch.Tensor`` and `NumPy `_ tensor type ``np.ndarray``. @@ -348,7 +348,7 @@ The introduction of reserved keys gives rise to the need to check if a key is re
-The ``Batch.is_empty`` function has an option to decide whether to identify direct emptiness (just a ``Batch()``) or to identify recurse emptiness (a ``Batch`` object without any scalar/tensor leaf nodes). +The ``Batch.is_empty`` function has an option to decide whether to identify direct emptiness (just a ``Batch()``) or to identify recursive emptiness (a ``Batch`` object without any scalar/tensor leaf nodes). .. note:: @@ -492,7 +492,7 @@ Miscellaneous Notes
-2. It is often the case that the observations returned from the environment are NumPy ndarrays but the policy requires ``torch.Tensor`` for prediction and learning. In this regard, Tianshou provides helper functions to convert the stored data in-place into Numpy arrays or Torch tensors. +2. It is often the case that the observations returned from the environment are all of type NumPy ndarray, but the policy requires ``torch.Tensor`` for prediction and learning. In this regard, Tianshou provides helper functions to convert the stored data in-place into Numpy arrays or Torch tensors. 3. ``obj.stack_([a, b])`` is the same as ``Batch.stack([obj, a, b])``, and ``obj.cat_([a, b])`` is the same as ``Batch.cat([obj, a, b])``. Considering the frequent requirement of concatenating two ``Batch`` objects, Tianshou also supports ``obj.cat_(a)`` to be an alias of ``obj.cat_([a])``. diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index 38a0291f4..c224b193b 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -341,7 +341,7 @@ With the flexible core APIs, Tianshou can support multi-agent reinforcement lear Currently, we support three types of multi-agent reinforcement learning paradigms: -1. Simultaneous move: at each timestep, all the agents take their actions (example: moba games) +1. Simultaneous move: at each timestep, all the agents take their actions (example: MOBA games) 2. Cyclic move: players take action in turn (example: Go game) @@ -371,4 +371,4 @@ By constructing a new state ``state_ = (state, agent_id, mask)``, essentially we action = policy(state_) next_state_, reward = env.step(action) -Following this idea, we write a tiny example of playing `Tic Tac Toe `_ against a random player by using a Q-lerning algorithm. The tutorial is at :doc:`/tutorials/tictactoe`. +Following this idea, we write a tiny example of playing `Tic Tac Toe `_ against a random player by using a Q-learning algorithm. The tutorial is at :doc:`/tutorials/tictactoe`. diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index 7222d9116..b1f76deb7 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -219,7 +219,7 @@ Tianshou provides other type of data buffer such as :class:`~tianshou.data.Prior Policy ------ -Tianshou aims to modularizing RL algorithms. It comes into several classes of policies in Tianshou. All of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`. +Tianshou aims to modularize RL algorithms. It provides several classes of policies. All of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`.
A policy class typically has the following parts: diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py index b4cd5b62d..1be441013 100644 --- a/examples/atari/atari_bcq.py +++ b/examples/atari/atari_bcq.py @@ -1,21 +1,21 @@ +import argparse +import datetime import os -import torch import pickle import pprint -import datetime -import argparse + import numpy as np +import torch +from atari_network import DQN +from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv +from tianshou.policy import DiscreteBCQPolicy from tianshou.trainer import offline_trainer +from tianshou.utils import TensorboardLogger from tianshou.utils.net.discrete import Actor -from tianshou.policy import DiscreteBCQPolicy -from tianshou.data import Collector, VectorReplayBuffer - -from atari_network import DQN -from atari_wrapper import wrap_deepmind def get_args(): @@ -38,15 +38,19 @@ def get_args(): parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--watch", default=False, action="store_true", - help="watch the play of pre-trained policy only") + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="watch the play of pre-trained policy only" + ) parser.add_argument("--log-interval", type=int, default=100) parser.add_argument( - "--load-buffer-name", type=str, - default="./expert_DQN_PongNoFrameskip-v4.hdf5") + "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5" + ) parser.add_argument( - "--device", type=str, - default="cuda" if torch.cuda.is_available() else "cpu") + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) args = parser.parse_known_args()[0] return args @@ -56,8 +60,12 @@ def make_atari_env(args): def make_atari_env_watch(args): - return wrap_deepmind(args.task, frame_stack=args.frames_stack, - episode_life=False, clip_rewards=False) + return wrap_deepmind( + args.task, + frame_stack=args.frames_stack, + episode_life=False, + clip_rewards=False + ) def test_discrete_bcq(args=get_args()): @@ -69,32 +77,43 @@ def test_discrete_bcq(args=get_args()): print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # make environments - test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) - for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv( + [lambda: make_atari_env_watch(args) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - feature_net = DQN(*args.state_shape, args.action_shape, - device=args.device, features_only=True).to(args.device) + feature_net = DQN( + *args.state_shape, args.action_shape, device=args.device, features_only=True + ).to(args.device) policy_net = Actor( - feature_net, args.action_shape, device=args.device, - hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device) + feature_net, + args.action_shape, + device=args.device, + hidden_sizes=args.hidden_sizes, + softmax_output=False + ).to(args.device) imitation_net = Actor( - feature_net, args.action_shape, device=args.device, - hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device) + feature_net, + args.action_shape, + device=args.device, + 
hidden_sizes=args.hidden_sizes, + softmax_output=False + ).to(args.device) optim = torch.optim.Adam( - list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr) + list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr + ) # define policy policy = DiscreteBCQPolicy( policy_net, imitation_net, optim, args.gamma, args.n_step, - args.target_update_freq, args.eps_test, - args.unlikely_action_threshold, args.imitation_logits_penalty) + args.target_update_freq, args.eps_test, args.unlikely_action_threshold, + args.imitation_logits_penalty + ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load( - args.resume_path, map_location=args.device)) + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # buffer assert os.path.exists(args.load_buffer_name), \ @@ -113,7 +132,8 @@ def test_discrete_bcq(args=get_args()): # log log_path = os.path.join( args.logdir, args.task, 'bcq', - f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}') + f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}' + ) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger = TensorboardLogger(writer, update_interval=args.log_interval) @@ -132,8 +152,7 @@ def watch(): test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) rew = result["rews"].mean() print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') @@ -143,9 +162,17 @@ def watch(): exit(0) result = offline_trainer( - policy, buffer, test_collector, args.epoch, - args.update_per_epoch, args.test_num, args.batch_size, - stop_fn=stop_fn, save_fn=save_fn, logger=logger) + policy, + buffer, + test_collector, + args.epoch, + args.update_per_epoch, + args.test_num, + args.batch_size, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index b70cc6cc0..291fb7007 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -1,18 +1,18 @@ +import argparse import os -import torch import pprint -import argparse + import numpy as np +import torch +from atari_network import C51 +from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import C51Policy -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv +from tianshou.policy import C51Policy from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer - -from atari_network import C51 -from atari_wrapper import wrap_deepmind +from tianshou.utils import TensorboardLogger def get_args(): @@ -40,12 +40,16 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) parser.add_argument('--save-buffer-name', type=str, default=None) return parser.parse_args() @@ -55,8 +59,12 @@ def make_atari_env(args): def make_atari_env_watch(args): - return wrap_deepmind(args.task, frame_stack=args.frames_stack, - episode_life=False, clip_rewards=False) + return wrap_deepmind( + args.task, + frame_stack=args.frames_stack, + episode_life=False, + clip_rewards=False + ) def test_c51(args=get_args()): @@ -67,23 +75,30 @@ def test_c51(args=get_args()): print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # make environments - train_envs = SubprocVectorEnv([lambda: make_atari_env(args) - for _ in range(args.training_num)]) - test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) - for _ in range(args.test_num)]) + train_envs = SubprocVectorEnv( + [lambda: make_atari_env(args) for _ in range(args.training_num)] + ) + test_envs = SubprocVectorEnv( + [lambda: make_atari_env_watch(args) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # define model - net = C51(*args.state_shape, args.action_shape, - args.num_atoms, args.device) + net = C51(*args.state_shape, args.action_shape, args.num_atoms, args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy policy = C51Policy( - net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max, - args.n_step, target_update_freq=args.target_update_freq + net, + optim, + args.gamma, + args.num_atoms, + args.v_min, + args.v_max, + args.n_step, + target_update_freq=args.target_update_freq ).to(args.device) # load a previous policy if args.resume_path: @@ -92,8 +107,12 @@ def test_c51(args=get_args()): # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True, - save_only_last_obs=True, stack_num=args.frames_stack) + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) @@ -136,11 +155,13 @@ def watch(): if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(test_envs), - ignore_obs_next=True, save_only_last_obs=True, - stack_num=args.frames_stack) - collector = Collector(policy, test_envs, buffer, - exploration_noise=True) + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) + collector = Collector(policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will 
cause oom with 1M buffer size @@ -148,8 +169,9 @@ def watch(): else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect( + n_episode=args.test_num, render=args.render + ) rew = result["rews"].mean() print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') @@ -161,11 +183,22 @@ def watch(): train_collector.collect(n_step=args.batch_size * args.training_num) # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_cql.py b/examples/atari/atari_cql.py index cbab82029..db4e33a9a 100644 --- a/examples/atari/atari_cql.py +++ b/examples/atari/atari_cql.py @@ -1,20 +1,20 @@ +import argparse +import datetime import os -import torch import pickle import pprint -import datetime -import argparse + import numpy as np +import torch +from atari_network import QRDQN +from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv -from tianshou.trainer import offline_trainer from tianshou.policy import DiscreteCQLPolicy -from tianshou.data import Collector, VectorReplayBuffer - -from atari_network import QRDQN -from atari_wrapper import wrap_deepmind +from tianshou.trainer import offline_trainer +from tianshou.utils import TensorboardLogger def get_args(): @@ -37,15 +37,19 @@ def get_args(): parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) 
parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--watch", default=False, action="store_true", - help="watch the play of pre-trained policy only") + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="watch the play of pre-trained policy only" + ) parser.add_argument("--log-interval", type=int, default=100) parser.add_argument( - "--load-buffer-name", type=str, - default="./expert_DQN_PongNoFrameskip-v4.hdf5") + "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5" + ) parser.add_argument( - "--device", type=str, - default="cuda" if torch.cuda.is_available() else "cpu") + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) args = parser.parse_known_args()[0] return args @@ -55,8 +59,12 @@ def make_atari_env(args): def make_atari_env_watch(args): - return wrap_deepmind(args.task, frame_stack=args.frames_stack, - episode_life=False, clip_rewards=False) + return wrap_deepmind( + args.task, + frame_stack=args.frames_stack, + episode_life=False, + clip_rewards=False + ) def test_discrete_cql(args=get_args()): @@ -68,25 +76,29 @@ def test_discrete_cql(args=get_args()): print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # make environments - test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) - for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv( + [lambda: make_atari_env_watch(args) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - net = QRDQN(*args.state_shape, args.action_shape, - args.num_quantiles, args.device) + net = QRDQN(*args.state_shape, args.action_shape, args.num_quantiles, args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy policy = DiscreteCQLPolicy( - net, optim, args.gamma, args.num_quantiles, args.n_step, - args.target_update_freq, min_q_weight=args.min_q_weight + net, + optim, + args.gamma, + args.num_quantiles, + args.n_step, + args.target_update_freq, + min_q_weight=args.min_q_weight ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load( - args.resume_path, map_location=args.device)) + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # buffer assert os.path.exists(args.load_buffer_name), \ @@ -105,7 +117,8 @@ def test_discrete_cql(args=get_args()): # log log_path = os.path.join( args.logdir, args.task, 'cql', - f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}') + f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}' + ) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger = TensorboardLogger(writer, update_interval=args.log_interval) @@ -124,8 +137,7 @@ def watch(): test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) rew = result["rews"].mean() print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') @@ -135,9 +147,17 @@ def watch(): exit(0) result = offline_trainer( - policy, buffer, test_collector, args.epoch, - args.update_per_epoch, args.test_num, args.batch_size, - stop_fn=stop_fn, save_fn=save_fn, logger=logger) + policy, + buffer, + test_collector, + args.epoch, + args.update_per_epoch, + 
args.test_num, + args.batch_size, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_crr.py b/examples/atari/atari_crr.py index e8e1ba54e..06cde415b 100644 --- a/examples/atari/atari_crr.py +++ b/examples/atari/atari_crr.py @@ -1,21 +1,21 @@ +import argparse +import datetime import os -import torch import pickle import pprint -import datetime -import argparse + import numpy as np +import torch +from atari_network import DQN +from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv +from tianshou.policy import DiscreteCRRPolicy from tianshou.trainer import offline_trainer +from tianshou.utils import TensorboardLogger from tianshou.utils.net.discrete import Actor -from tianshou.policy import DiscreteCRRPolicy -from tianshou.data import Collector, VectorReplayBuffer - -from atari_network import DQN -from atari_wrapper import wrap_deepmind def get_args(): @@ -38,15 +38,19 @@ def get_args(): parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--watch", default=False, action="store_true", - help="watch the play of pre-trained policy only") + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="watch the play of pre-trained policy only" + ) parser.add_argument("--log-interval", type=int, default=100) parser.add_argument( - "--load-buffer-name", type=str, - default="./expert_DQN_PongNoFrameskip-v4.hdf5") + "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5" + ) parser.add_argument( - "--device", type=str, - default="cuda" if torch.cuda.is_available() else "cpu") + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) args = parser.parse_known_args()[0] return args @@ -56,8 +60,12 @@ def make_atari_env(args): def make_atari_env_watch(args): - return wrap_deepmind(args.task, frame_stack=args.frames_stack, - episode_life=False, clip_rewards=False) + return wrap_deepmind( + args.task, + frame_stack=args.frames_stack, + episode_life=False, + clip_rewards=False + ) def test_discrete_crr(args=get_args()): @@ -69,33 +77,44 @@ def test_discrete_crr(args=get_args()): print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # make environments - test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) - for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv( + [lambda: make_atari_env_watch(args) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - feature_net = DQN(*args.state_shape, args.action_shape, - device=args.device, features_only=True).to(args.device) - actor = Actor(feature_net, args.action_shape, device=args.device, - hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device) + feature_net = DQN( + *args.state_shape, args.action_shape, device=args.device, features_only=True + ).to(args.device) + actor = Actor( + feature_net, + args.action_shape, + device=args.device, + hidden_sizes=args.hidden_sizes, + softmax_output=False + ).to(args.device) critic = DQN(*args.state_shape, args.action_shape, device=args.device).to(args.device) - optim = torch.optim.Adam(list(actor.parameters()) + 
list(critic.parameters()), - lr=args.lr) + optim = torch.optim.Adam( + list(actor.parameters()) + list(critic.parameters()), lr=args.lr + ) # define policy policy = DiscreteCRRPolicy( - actor, critic, optim, args.gamma, + actor, + critic, + optim, + args.gamma, policy_improvement_mode=args.policy_improvement_mode, - ratio_upper_bound=args.ratio_upper_bound, beta=args.beta, + ratio_upper_bound=args.ratio_upper_bound, + beta=args.beta, min_q_weight=args.min_q_weight, target_update_freq=args.target_update_freq ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load( - args.resume_path, map_location=args.device)) + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # buffer assert os.path.exists(args.load_buffer_name), \ @@ -114,7 +133,8 @@ def test_discrete_crr(args=get_args()): # log log_path = os.path.join( args.logdir, args.task, 'crr', - f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}') + f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}' + ) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger = TensorboardLogger(writer, update_interval=args.log_interval) @@ -132,8 +152,7 @@ def watch(): test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) rew = result["rews"].mean() print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') @@ -143,9 +162,17 @@ def watch(): exit(0) result = offline_trainer( - policy, buffer, test_collector, args.epoch, - args.update_per_epoch, args.test_num, args.batch_size, - stop_fn=stop_fn, save_fn=save_fn, logger=logger) + policy, + buffer, + test_collector, + args.epoch, + args.update_per_epoch, + args.test_num, + args.batch_size, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 69ec08349..c9f74af8c 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -1,18 +1,18 @@ +import argparse import os -import torch import pprint -import argparse + import numpy as np +import torch +from atari_network import DQN +from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import DQNPolicy -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv +from tianshou.policy import DQNPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer - -from atari_network import DQN -from atari_wrapper import wrap_deepmind +from tianshou.utils import TensorboardLogger def get_args(): @@ -37,12 +37,16 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) parser.add_argument('--save-buffer-name', type=str, default=None) return parser.parse_args() @@ -52,8 +56,12 @@ def make_atari_env(args): def make_atari_env_watch(args): - return wrap_deepmind(args.task, frame_stack=args.frames_stack, - episode_life=False, clip_rewards=False) + return wrap_deepmind( + args.task, + frame_stack=args.frames_stack, + episode_life=False, + clip_rewards=False + ) def test_dqn(args=get_args()): @@ -64,22 +72,28 @@ def test_dqn(args=get_args()): print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # make environments - train_envs = SubprocVectorEnv([lambda: make_atari_env(args) - for _ in range(args.training_num)]) - test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) - for _ in range(args.test_num)]) + train_envs = SubprocVectorEnv( + [lambda: make_atari_env(args) for _ in range(args.training_num)] + ) + test_envs = SubprocVectorEnv( + [lambda: make_atari_env_watch(args) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # define model - net = DQN(*args.state_shape, - args.action_shape, args.device).to(args.device) + net = DQN(*args.state_shape, args.action_shape, args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy = DQNPolicy(net, optim, args.gamma, args.n_step, - target_update_freq=args.target_update_freq) + policy = DQNPolicy( + net, + optim, + args.gamma, + args.n_step, + target_update_freq=args.target_update_freq + ) # load a previous policy if args.resume_path: policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) @@ -87,8 +101,12 @@ def test_dqn(args=get_args()): # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True, - save_only_last_obs=True, stack_num=args.frames_stack) + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) @@ -131,11 +149,13 @@ def watch(): if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(test_envs), - ignore_obs_next=True, save_only_last_obs=True, - stack_num=args.frames_stack) - collector = Collector(policy, test_envs, buffer, - exploration_noise=True) + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) + collector = Collector(policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will 
cause oom with 1M buffer size @@ -143,8 +163,9 @@ def watch(): else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect( + n_episode=args.test_num, render=args.render + ) rew = result["rews"].mean() print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') @@ -156,11 +177,22 @@ def watch(): train_collector.collect(n_step=args.batch_size * args.training_num) # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 4a6e97c06..4629bede2 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -1,20 +1,20 @@ +import argparse import os -import torch import pprint -import argparse + import numpy as np +import torch +from atari_network import DQN +from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import FQFPolicy -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv +from tianshou.policy import FQFPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.utils import TensorboardLogger from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction -from atari_network import DQN -from atari_wrapper import wrap_deepmind - def get_args(): parser = argparse.ArgumentParser() @@ -43,12 +43,16 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) parser.add_argument('--save-buffer-name', type=str, default=None) return parser.parse_args() @@ -58,8 +62,12 @@ def make_atari_env(args): def make_atari_env_watch(args): - return wrap_deepmind(args.task, frame_stack=args.frames_stack, - episode_life=False, clip_rewards=False) + return wrap_deepmind( + args.task, + frame_stack=args.frames_stack, + episode_life=False, + clip_rewards=False + ) def test_fqf(args=get_args()): @@ -70,30 +78,43 @@ def test_fqf(args=get_args()): print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # make environments - train_envs = SubprocVectorEnv([lambda: make_atari_env(args) - for _ in range(args.training_num)]) - test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) - for _ in range(args.test_num)]) + train_envs = SubprocVectorEnv( + [lambda: make_atari_env(args) for _ in range(args.training_num)] + ) + test_envs = SubprocVectorEnv( + [lambda: make_atari_env_watch(args) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # define model - feature_net = DQN(*args.state_shape, args.action_shape, args.device, - features_only=True) + feature_net = DQN( + *args.state_shape, args.action_shape, args.device, features_only=True + ) net = FullQuantileFunction( - feature_net, args.action_shape, args.hidden_sizes, - args.num_cosines, device=args.device + feature_net, + args.action_shape, + args.hidden_sizes, + args.num_cosines, + device=args.device ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) - fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), - lr=args.fraction_lr) + fraction_optim = torch.optim.RMSprop( + fraction_net.parameters(), lr=args.fraction_lr + ) # define policy policy = FQFPolicy( - net, optim, fraction_net, fraction_optim, - args.gamma, args.num_fractions, args.ent_coef, args.n_step, + net, + optim, + fraction_net, + fraction_optim, + args.gamma, + args.num_fractions, + args.ent_coef, + args.n_step, target_update_freq=args.target_update_freq ).to(args.device) # load a previous policy @@ -103,8 +124,12 @@ def test_fqf(args=get_args()): # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), - ignore_obs_next=True, save_only_last_obs=True, stack_num=args.frames_stack) + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) @@ -147,11 +172,13 @@ def watch(): if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") buffer = VectorReplayBuffer( - args.buffer_size, 
buffer_num=len(test_envs), - ignore_obs_next=True, save_only_last_obs=True, - stack_num=args.frames_stack) - collector = Collector(policy, test_envs, buffer, - exploration_noise=True) + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) + collector = Collector(policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -159,8 +186,9 @@ def watch(): else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect( + n_episode=args.test_num, render=args.render + ) rew = result["rews"].mean() print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') @@ -172,11 +200,22 @@ def watch(): train_collector.collect(n_step=args.batch_size * args.training_num) # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index e5966a318..d0e7773d0 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -1,20 +1,20 @@ +import argparse import os -import torch import pprint -import argparse + import numpy as np +import torch +from atari_network import DQN +from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import IQNPolicy -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv +from tianshou.policy import IQNPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.utils import TensorboardLogger from tianshou.utils.net.discrete import ImplicitQuantileNetwork -from atari_network import DQN -from atari_wrapper import wrap_deepmind - def get_args(): parser = argparse.ArgumentParser() @@ -43,12 +43,16 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) parser.add_argument('--save-buffer-name', type=str, default=None) return parser.parse_args() @@ -58,8 +62,12 @@ def make_atari_env(args): def make_atari_env_watch(args): - return wrap_deepmind(args.task, frame_stack=args.frames_stack, - episode_life=False, clip_rewards=False) + return wrap_deepmind( + args.task, + frame_stack=args.frames_stack, + episode_life=False, + clip_rewards=False + ) def test_iqn(args=get_args()): @@ -70,27 +78,38 @@ def test_iqn(args=get_args()): print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # make environments - train_envs = SubprocVectorEnv([lambda: make_atari_env(args) - for _ in range(args.training_num)]) - test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) - for _ in range(args.test_num)]) + train_envs = SubprocVectorEnv( + [lambda: make_atari_env(args) for _ in range(args.training_num)] + ) + test_envs = SubprocVectorEnv( + [lambda: make_atari_env_watch(args) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # define model - feature_net = DQN(*args.state_shape, args.action_shape, args.device, - features_only=True) + feature_net = DQN( + *args.state_shape, args.action_shape, args.device, features_only=True + ) net = ImplicitQuantileNetwork( - feature_net, args.action_shape, args.hidden_sizes, - num_cosines=args.num_cosines, device=args.device + feature_net, + args.action_shape, + args.hidden_sizes, + num_cosines=args.num_cosines, + device=args.device ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy policy = IQNPolicy( - net, optim, args.gamma, args.sample_size, args.online_sample_size, - args.target_sample_size, args.n_step, + net, + optim, + args.gamma, + args.sample_size, + args.online_sample_size, + args.target_sample_size, + args.n_step, target_update_freq=args.target_update_freq ).to(args.device) # load a previous policy @@ -100,8 +119,12 @@ def test_iqn(args=get_args()): # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), - ignore_obs_next=True, save_only_last_obs=True, stack_num=args.frames_stack) + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) @@ -144,11 +167,13 @@ def watch(): if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(test_envs), - ignore_obs_next=True, save_only_last_obs=True, - stack_num=args.frames_stack) - collector = Collector(policy, test_envs, buffer, - exploration_noise=True) + args.buffer_size, + buffer_num=len(test_envs), + 
ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) + collector = Collector(policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -156,8 +181,9 @@ def watch(): else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect( + n_episode=args.test_num, render=args.render + ) rew = result["rews"].mean() print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') @@ -169,11 +195,22 @@ def watch(): train_collector.collect(n_step=args.batch_size * args.training_num) # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index 2eccf11af..4598fce11 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -1,7 +1,9 @@ -import torch +from typing import Any, Dict, Optional, Sequence, Tuple, Union + import numpy as np +import torch from torch import nn -from typing import Any, Dict, Tuple, Union, Optional, Sequence + from tianshou.utils.net.discrete import NoisyLinear @@ -27,15 +29,15 @@ def __init__( nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True), nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True), nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True), - nn.Flatten()) + nn.Flatten() + ) with torch.no_grad(): - self.output_dim = np.prod( - self.net(torch.zeros(1, c, h, w)).shape[1:]) + self.output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:]) if not features_only: self.net = nn.Sequential( - self.net, - nn.Linear(self.output_dim, 512), nn.ReLU(inplace=True), - nn.Linear(512, np.prod(action_shape))) + self.net, nn.Linear(self.output_dim, 512), nn.ReLU(inplace=True), + nn.Linear(512, np.prod(action_shape)) + ) self.output_dim = np.prod(action_shape) def forward( @@ -113,12 +115,14 @@ def linear(x, y): self.Q = nn.Sequential( linear(self.output_dim, 512), nn.ReLU(inplace=True), - linear(512, self.action_num * self.num_atoms)) + linear(512, self.action_num * self.num_atoms) + ) self._is_dueling = is_dueling if self._is_dueling: self.V = nn.Sequential( linear(self.output_dim, 512), nn.ReLU(inplace=True), - linear(512, self.num_atoms)) + linear(512, self.num_atoms) + ) self.output_dim = self.action_num * self.num_atoms def forward( diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 781f81d5d..23a7966eb 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -1,18 +1,18 @@ +import argparse import os -import torch import pprint -import argparse + import numpy as np +import torch +from atari_network import QRDQN +from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter 
-from tianshou.utils import TensorboardLogger -from tianshou.policy import QRDQNPolicy +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv +from tianshou.policy import QRDQNPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer - -from atari_network import QRDQN -from atari_wrapper import wrap_deepmind +from tianshou.utils import TensorboardLogger def get_args(): @@ -38,12 +38,16 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) parser.add_argument('--save-buffer-name', type=str, default=None) return parser.parse_args() @@ -53,8 +57,12 @@ def make_atari_env(args): def make_atari_env_watch(args): - return wrap_deepmind(args.task, frame_stack=args.frames_stack, - episode_life=False, clip_rewards=False) + return wrap_deepmind( + args.task, + frame_stack=args.frames_stack, + episode_life=False, + clip_rewards=False + ) def test_qrdqn(args=get_args()): @@ -65,23 +73,28 @@ def test_qrdqn(args=get_args()): print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # make environments - train_envs = SubprocVectorEnv([lambda: make_atari_env(args) - for _ in range(args.training_num)]) - test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) - for _ in range(args.test_num)]) + train_envs = SubprocVectorEnv( + [lambda: make_atari_env(args) for _ in range(args.training_num)] + ) + test_envs = SubprocVectorEnv( + [lambda: make_atari_env_watch(args) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # define model - net = QRDQN(*args.state_shape, args.action_shape, - args.num_quantiles, args.device) + net = QRDQN(*args.state_shape, args.action_shape, args.num_quantiles, args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy policy = QRDQNPolicy( - net, optim, args.gamma, args.num_quantiles, - args.n_step, target_update_freq=args.target_update_freq + net, + optim, + args.gamma, + args.num_quantiles, + args.n_step, + target_update_freq=args.target_update_freq ).to(args.device) # load a previous policy if args.resume_path: @@ -90,8 +103,12 @@ def test_qrdqn(args=get_args()): # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), - ignore_obs_next=True, save_only_last_obs=True, stack_num=args.frames_stack) + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) @@ -134,11 +151,13 @@ def watch(): if args.save_buffer_name: print(f"Generate buffer with size 
{args.buffer_size}") buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(test_envs), - ignore_obs_next=True, save_only_last_obs=True, - stack_num=args.frames_stack) - collector = Collector(policy, test_envs, buffer, - exploration_noise=True) + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) + collector = Collector(policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -146,8 +165,9 @@ def watch(): else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect( + n_episode=args.test_num, render=args.render + ) rew = result["rews"].mean() print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') @@ -159,11 +179,22 @@ def watch(): train_collector.collect(n_step=args.batch_size * args.training_num) # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index f2f44f0cd..b131cce5f 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -1,19 +1,19 @@ +import argparse +import datetime import os -import torch import pprint -import datetime -import argparse + import numpy as np +import torch +from atari_network import Rainbow +from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import RainbowPolicy -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer from tianshou.env import SubprocVectorEnv +from tianshou.policy import RainbowPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer - -from atari_network import Rainbow -from atari_wrapper import wrap_deepmind +from tianshou.utils import TensorboardLogger def get_args(): @@ -50,12 +50,16 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) parser.add_argument('--save-buffer-name', type=str, default=None) return parser.parse_args() @@ -65,8 +69,12 @@ def make_atari_env(args): def make_atari_env_watch(args): - return wrap_deepmind(args.task, frame_stack=args.frames_stack, - episode_life=False, clip_rewards=False) + return wrap_deepmind( + args.task, + frame_stack=args.frames_stack, + episode_life=False, + clip_rewards=False + ) def test_rainbow(args=get_args()): @@ -77,25 +85,38 @@ def test_rainbow(args=get_args()): print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # make environments - train_envs = SubprocVectorEnv([lambda: make_atari_env(args) - for _ in range(args.training_num)]) - test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) - for _ in range(args.test_num)]) + train_envs = SubprocVectorEnv( + [lambda: make_atari_env(args) for _ in range(args.training_num)] + ) + test_envs = SubprocVectorEnv( + [lambda: make_atari_env_watch(args) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # define model - net = Rainbow(*args.state_shape, args.action_shape, - args.num_atoms, args.noisy_std, args.device, - is_dueling=not args.no_dueling, - is_noisy=not args.no_noisy) + net = Rainbow( + *args.state_shape, + args.action_shape, + args.num_atoms, + args.noisy_std, + args.device, + is_dueling=not args.no_dueling, + is_noisy=not args.no_noisy + ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy policy = RainbowPolicy( - net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max, - args.n_step, target_update_freq=args.target_update_freq + net, + optim, + args.gamma, + args.num_atoms, + args.v_min, + args.v_max, + args.n_step, + target_update_freq=args.target_update_freq ).to(args.device) # load a previous policy if args.resume_path: @@ -105,20 +126,31 @@ def test_rainbow(args=get_args()): # when you have enough RAM if args.no_priority: buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True, - save_only_last_obs=True, stack_num=args.frames_stack) + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) else: buffer = PrioritizedVectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True, - save_only_last_obs=True, stack_num=args.frames_stack, alpha=args.alpha, - beta=args.beta, weight_norm=not args.no_weight_norm) + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + alpha=args.alpha, + beta=args.beta, + weight_norm=not args.no_weight_norm + ) # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) # log log_path = os.path.join( args.logdir, args.task, 'rainbow', - 
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}') + f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}' + ) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger = TensorboardLogger(writer) @@ -164,12 +196,15 @@ def watch(): if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") buffer = PrioritizedVectorReplayBuffer( - args.buffer_size, buffer_num=len(test_envs), - ignore_obs_next=True, save_only_last_obs=True, - stack_num=args.frames_stack, alpha=args.alpha, - beta=args.beta) - collector = Collector(policy, test_envs, buffer, - exploration_noise=True) + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + alpha=args.alpha, + beta=args.beta + ) + collector = Collector(policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -177,8 +212,9 @@ def watch(): else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect( + n_episode=args.test_num, render=args.render + ) rew = result["rews"].mean() print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') @@ -190,11 +226,22 @@ def watch(): train_collector.collect(n_step=args.batch_size * args.training_num) # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index 53a662613..333f9787a 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -1,10 +1,11 @@ # Borrow a lot from openai baselines: # https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py +from collections import deque + import cv2 import gym import numpy as np -from collections import deque class NoopResetEnv(gym.Wrapper): @@ -48,7 +49,7 @@ def step(self, action): reward, and max over last observations. 
""" obs_list, total_reward, done = [], 0., False - for i in range(self._skip): + for _ in range(self._skip): obs, reward, done, info = self.env.step(action) obs_list.append(obs) total_reward += reward @@ -127,13 +128,14 @@ def __init__(self, env): self.observation_space = gym.spaces.Box( low=np.min(env.observation_space.low), high=np.max(env.observation_space.high), - shape=(self.size, self.size), dtype=env.observation_space.dtype) + shape=(self.size, self.size), + dtype=env.observation_space.dtype + ) def observation(self, frame): """returns the current observation from a frame""" frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) - return cv2.resize(frame, (self.size, self.size), - interpolation=cv2.INTER_AREA) + return cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA) class ScaledFloatFrame(gym.ObservationWrapper): @@ -149,8 +151,8 @@ def __init__(self, env): self.bias = low self.scale = high - low self.observation_space = gym.spaces.Box( - low=0., high=1., shape=env.observation_space.shape, - dtype=np.float32) + low=0., high=1., shape=env.observation_space.shape, dtype=np.float32 + ) def observation(self, observation): return (observation - self.bias) / self.scale @@ -182,11 +184,13 @@ def __init__(self, env, n_frames): super().__init__(env) self.n_frames = n_frames self.frames = deque([], maxlen=n_frames) - shape = (n_frames,) + env.observation_space.shape + shape = (n_frames, ) + env.observation_space.shape self.observation_space = gym.spaces.Box( low=np.min(env.observation_space.low), high=np.max(env.observation_space.high), - shape=shape, dtype=env.observation_space.dtype) + shape=shape, + dtype=env.observation_space.dtype + ) def reset(self): obs = self.env.reset() @@ -205,8 +209,14 @@ def _get_ob(self): return np.stack(self.frames, axis=0) -def wrap_deepmind(env_id, episode_life=True, clip_rewards=True, - frame_stack=4, scale=False, warp_frame=True): +def wrap_deepmind( + env_id, + episode_life=True, + clip_rewards=True, + frame_stack=4, + scale=False, + warp_frame=True +): """Configure environment for DeepMind-style Atari. The observation is channel-first: (c, h, w) instead of (h, w, c). 
diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 4889f7eb2..76246fd3e 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -1,17 +1,18 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import DQNPolicy +from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer def get_args(): @@ -31,17 +32,19 @@ def get_args(): parser.add_argument('--update-per-step', type=float, default=0.01) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128]) - parser.add_argument('--dueling-q-hidden-sizes', type=int, - nargs='*', default=[128, 128]) - parser.add_argument('--dueling-v-hidden-sizes', type=int, - nargs='*', default=[128, 128]) + parser.add_argument( + '--dueling-q-hidden-sizes', type=int, nargs='*', default=[128, 128] + ) + parser.add_argument( + '--dueling-v-hidden-sizes', type=int, nargs='*', default=[128, 128] + ) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) return parser.parse_args() @@ -52,10 +55,12 @@ def test_dqn(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -64,18 +69,28 @@ def test_dqn(args=get_args()): # model Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes} V_param = {"hidden_sizes": args.dueling_v_hidden_sizes} - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device, - dueling_param=(Q_param, V_param)).to(args.device) + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + dueling_param=(Q_param, V_param) + ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DQNPolicy( - net, optim, args.gamma, args.n_step, - target_update_freq=args.target_update_freq) + net, + optim, + args.gamma, + args.n_step, + target_update_freq=args.target_update_freq + ) # collector train_collector = Collector( - policy, train_envs, + policy, + train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), - exploration_noise=True) + exploration_noise=True + ) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size * args.training_num) @@ -105,10 
+120,21 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, - update_per_step=args.update_per_step, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 4caa50b94..598622d01 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -1,17 +1,18 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv from tianshou.policy import SACPolicy +from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.env import SubprocVectorEnv -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import ActorProb, Critic @@ -32,16 +33,15 @@ def get_args(): parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=128) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument('--n-step', type=int, default=4) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--resume-path', type=str, default=None) return parser.parse_args() @@ -75,13 +75,16 @@ def test_sac_bipedal(args=get_args()): args.action_shape = env.action_space.shape or env.action_space.n args.max_action = env.action_space.high[0] - train_envs = SubprocVectorEnv([ - lambda: Wrapper(gym.make(args.task)) - for _ in range(args.training_num)]) + train_envs = SubprocVectorEnv( + [lambda: Wrapper(gym.make(args.task)) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv([ - lambda: Wrapper(gym.make(args.task), reward_scale=1, rm_done=False) - for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv( + [ + lambda: Wrapper(gym.make(args.task), reward_scale=1, rm_done=False) + for _ in range(args.test_num) + ] + ) # seed np.random.seed(args.seed) @@ -90,22 +93,33 @@ def test_sac_bipedal(args=get_args()): test_envs.seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) + net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb( - net_a, args.action_shape, max_action=args.max_action, - device=args.device, unbounded=True).to(args.device) + net_a, + args.action_shape, + max_action=args.max_action, + device=args.device, + unbounded=True + ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net_c1 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) + net_c1 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic1 = Critic(net_c1, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - net_c2 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) @@ -116,9 +130,18 @@ def test_sac_bipedal(args=get_args()): args.alpha = (target_entropy, log_alpha, alpha_optim) policy = SACPolicy( - actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - tau=args.tau, gamma=args.gamma, alpha=args.alpha, - estimation_step=args.n_step, action_space=env.action_space) + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + tau=args.tau, + gamma=args.gamma, + alpha=args.alpha, + estimation_step=args.n_step, + action_space=env.action_space + ) # load a previous policy if args.resume_path: policy.load_state_dict(torch.load(args.resume_path)) @@ -126,9 +149,11 @@ def test_sac_bipedal(args=get_args()): # collector train_collector = Collector( - policy, train_envs, + policy, + train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), - exploration_noise=True) + exploration_noise=True + ) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log @@ -144,10 +169,20 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - 
args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, - update_per_step=args.update_per_step, test_in_train=False, - stop_fn=stop_fn, save_fn=save_fn, logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + test_in_train=False, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) if __name__ == '__main__': pprint.pprint(result) @@ -155,8 +190,7 @@ def stop_fn(mean_rewards): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render) rews, lens = result["rews"], result["lens"] print(f"Final reward: {rews.mean()}, length: {lens.mean()}") diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index bb73ac615..88f4c397b 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -1,17 +1,18 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import DQNPolicy +from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer -from tianshou.env import DummyVectorEnv, SubprocVectorEnv def get_args(): @@ -31,19 +32,20 @@ def get_args(): parser.add_argument('--step-per-collect', type=int, default=16) parser.add_argument('--update-per-step', type=float, default=0.0625) parser.add_argument('--batch-size', type=int, default=128) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128]) - parser.add_argument('--dueling-q-hidden-sizes', type=int, - nargs='*', default=[128, 128]) - parser.add_argument('--dueling-v-hidden-sizes', type=int, - nargs='*', default=[128, 128]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) + parser.add_argument( + '--dueling-q-hidden-sizes', type=int, nargs='*', default=[128, 128] + ) + parser.add_argument( + '--dueling-v-hidden-sizes', type=int, nargs='*', default=[128, 128] + ) parser.add_argument('--training-num', type=int, default=16) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) return parser.parse_args() @@ -54,10 +56,12 @@ def test_dqn(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -66,18 +70,28 @@ def test_dqn(args=get_args()): # model Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes} V_param = {"hidden_sizes": args.dueling_v_hidden_sizes} - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device, - dueling_param=(Q_param, V_param)).to(args.device) + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + dueling_param=(Q_param, V_param) + ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DQNPolicy( - net, optim, args.gamma, args.n_step, - target_update_freq=args.target_update_freq) + net, + optim, + args.gamma, + args.n_step, + target_update_freq=args.target_update_freq + ) # collector train_collector = Collector( - policy, train_envs, + policy, + train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), - exploration_noise=True) + exploration_noise=True + ) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size * args.training_num) @@ -93,7 +107,7 @@ def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold def train_fn(epoch, env_step): # exp decay - eps = max(args.eps_train * (1 - 5e-6) ** env_step, args.eps_test) + eps = max(args.eps_train * (1 - 5e-6)**env_step, args.eps_test) policy.set_eps(eps) def test_fn(epoch, env_step): @@ -101,10 +115,21 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, - update_per_step=args.update_per_step, stop_fn=stop_fn, train_fn=train_fn, - test_fn=test_fn, save_fn=save_fn, logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + stop_fn=stop_fn, + train_fn=train_fn, + test_fn=test_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index a43728be5..0638e8f61 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import SACPolicy -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import OUNoise -from tianshou.utils.net.common import Net +from tianshou.policy import 
SACPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -34,16 +35,15 @@ def get_args(): parser.add_argument('--step-per-collect', type=int, default=5) parser.add_argument('--update-per-step', type=float, default=0.2) parser.add_argument('--batch-size', type=int, default=128) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) parser.add_argument('--training-num', type=int, default=5) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument('--rew-norm', type=bool, default=False) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) return parser.parse_args() @@ -54,31 +54,43 @@ def test_sac(args=get_args()): args.max_action = env.action_space.high[0] # train_envs = gym.make(args.task) train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb( - net, args.action_shape, - max_action=args.max_action, device=args.device, unbounded=True + net, + args.action_shape, + max_action=args.max_action, + device=args.device, + unbounded=True ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net_c1 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, concat=True, - device=args.device) + net_c1 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic1 = Critic(net_c1, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - net_c2 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, concat=True, - device=args.device) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) @@ -89,16 +101,26 @@ def test_sac(args=get_args()): args.alpha = (target_entropy, log_alpha, alpha_optim) policy = SACPolicy( - actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - tau=args.tau, gamma=args.gamma, alpha=args.alpha, + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + tau=args.tau, + gamma=args.gamma, + alpha=args.alpha, reward_normalization=args.rew_norm, exploration_noise=OUNoise(0.0, args.noise_std), - action_space=env.action_space) + action_space=env.action_space + ) # collector 
train_collector = Collector( - policy, train_envs, + policy, + train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), - exploration_noise=True) + exploration_noise=True + ) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log @@ -114,10 +136,19 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, - update_per_step=args.update_per_step, stop_fn=stop_fn, - save_fn=save_fn, logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/examples/mujoco/analysis.py b/examples/mujoco/analysis.py index 01a2cf678..ed0bb6872 100755 --- a/examples/mujoco/analysis.py +++ b/examples/mujoco/analysis.py @@ -1,11 +1,12 @@ #!/usr/bin/env python3 -import re import argparse +import re +from collections import defaultdict + import numpy as np from tabulate import tabulate -from collections import defaultdict -from tools import find_all_files, group_files, csv2numpy +from tools import csv2numpy, find_all_files, group_files def numerical_anysis(root_dir, xlim, norm=False): @@ -20,13 +21,16 @@ def numerical_anysis(root_dir, xlim, norm=False): for f in csv_files: result = csv2numpy(f) if norm: - result = np.stack([ - result['env_step'], - result['rew'] - result['rew'][0], - result['rew:shaded']]) + result = np.stack( + [ + result['env_step'], result['rew'] - result['rew'][0], + result['rew:shaded'] + ] + ) else: - result = np.stack([ - result['env_step'], result['rew'], result['rew:shaded']]) + result = np.stack( + [result['env_step'], result['rew'], result['rew:shaded']] + ) if result[0, -1] < xlim: continue @@ -79,11 +83,17 @@ def numerical_anysis(root_dir, xlim, norm=False): if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('--xlim', type=int, default=1000000, - help='x-axis limitation (default: 1000000)') + parser.add_argument( + '--xlim', + type=int, + default=1000000, + help='x-axis limitation (default: 1000000)' + ) parser.add_argument('--root-dir', type=str) parser.add_argument( - '--norm', action="store_true", - help="Normalize all results according to environment.") + '--norm', + action="store_true", + help="Normalize all results according to environment." 
+ ) args = parser.parse_args() numerical_anysis(args.root_dir, args.xlim, norm=args.norm) diff --git a/examples/mujoco/gen_json.py b/examples/mujoco/gen_json.py index 0c0b113e9..99cad74a3 100755 --- a/examples/mujoco/gen_json.py +++ b/examples/mujoco/gen_json.py @@ -1,15 +1,15 @@ #!/usr/bin/env python3 -import os import csv -import sys import json +import os +import sys def merge(rootdir): """format: $rootdir/$algo/*.csv""" result = [] - for path, dirnames, filenames in os.walk(rootdir): + for path, _, filenames in os.walk(rootdir): filenames = [f for f in filenames if f.endswith('.csv')] if len(filenames) == 0: continue @@ -19,12 +19,14 @@ def merge(rootdir): algo = os.path.relpath(path, rootdir).upper() reader = csv.DictReader(open(os.path.join(path, filenames[0]))) for row in reader: - result.append({ - 'env_step': int(row['env_step']), - 'rew': float(row['rew']), - 'rew_std': float(row['rew:shaded']), - 'Agent': algo, - }) + result.append( + { + 'env_step': int(row['env_step']), + 'rew': float(row['rew']), + 'rew_std': float(row['rew:shaded']), + 'Agent': algo, + } + ) open(os.path.join(rootdir, 'result.json'), 'w').write(json.dumps(result)) diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index e9debc906..02978697b 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -1,24 +1,25 @@ #!/usr/bin/env python3 +import argparse +import datetime import os -import gym -import torch import pprint -import datetime -import argparse + +import gym import numpy as np +import torch from torch import nn +from torch.distributions import Independent, Normal from torch.optim.lr_scheduler import LambdaLR from torch.utils.tensorboard import SummaryWriter -from torch.distributions import Independent, Normal +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv from tianshou.policy import A2CPolicy +from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import onpolicy_trainer from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer def get_args(): @@ -48,11 +49,15 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) return parser.parse_args() @@ -63,16 +68,18 @@ def test_a2c(args=get_args()): args.max_action = env.action_space.high[0] print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), - np.max(env.action_space.high)) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) # train_envs = gym.make(args.task) train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)], - norm_obs=True) + [lambda: gym.make(args.task) for _ in range(args.training_num)], norm_obs=True + ) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( [lambda: gym.make(args.task) for _ in range(args.test_num)], - norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False) + norm_obs=True, + obs_rms=train_envs.obs_rms, + update_obs_rms=False + ) # seed np.random.seed(args.seed) @@ -80,12 +87,25 @@ def test_a2c(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) - actor = ActorProb(net_a, args.action_shape, max_action=args.max_action, - unbounded=True, device=args.device).to(args.device) - net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) + net_a = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) + actor = ActorProb( + net_a, + args.action_shape, + max_action=args.max_action, + unbounded=True, + device=args.device + ).to(args.device) + net_c = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) critic = Critic(net_c, device=args.device).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in list(actor.modules()) + list(critic.modules()): @@ -101,27 +121,43 @@ def test_a2c(args=get_args()): torch.nn.init.zeros_(m.bias) m.weight.data.copy_(0.01 * m.weight.data) - optim = torch.optim.RMSprop(list(actor.parameters()) + list(critic.parameters()), - lr=args.lr, eps=1e-5, alpha=0.99) + optim = torch.optim.RMSprop( + list(actor.parameters()) + list(critic.parameters()), + lr=args.lr, + eps=1e-5, + alpha=0.99 + ) lr_scheduler = None if args.lr_decay: # decay learning rate to 0 linearly max_update_num = np.ceil( - args.step_per_epoch / args.step_per_collect) * args.epoch + args.step_per_epoch / args.step_per_collect + ) * args.epoch lr_scheduler = LambdaLR( - optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num + ) def dist(*logits): return Independent(Normal(*logits), 1) - policy = A2CPolicy(actor, critic, optim, dist, discount_factor=args.gamma, - gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, - vf_coef=args.vf_coef, ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, action_scaling=True, - action_bound_method=args.bound_action_method, - lr_scheduler=lr_scheduler, 
action_space=env.action_space) + policy = A2CPolicy( + actor, + critic, + optim, + dist, + discount_factor=args.gamma, + gae_lambda=args.gae_lambda, + max_grad_norm=args.max_grad_norm, + vf_coef=args.vf_coef, + ent_coef=args.ent_coef, + reward_normalization=args.rew_norm, + action_scaling=True, + action_bound_method=args.bound_action_method, + lr_scheduler=lr_scheduler, + action_space=env.action_space + ) # load a previous policy if args.resume_path: @@ -149,10 +185,19 @@ def save_fn(policy): if not args.watch: # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, args.step_per_epoch, - args.repeat_per_collect, args.test_num, args.batch_size, - step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger, - test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + step_per_collect=args.step_per_collect, + save_fn=save_fn, + logger=logger, + test_in_train=False + ) pprint.pprint(result) # Let's watch its performance! diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index 28bac056a..8d436b573 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -1,22 +1,23 @@ #!/usr/bin/env python3 +import argparse +import datetime import os -import gym -import torch import pprint -import datetime -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import DDPGPolicy -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.env import SubprocVectorEnv -from tianshou.utils.net.common import Net from tianshou.exploration import GaussianNoise +from tianshou.policy import DDPGPolicy from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer def get_args(): @@ -42,11 +43,15 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) return parser.parse_args() @@ -58,17 +63,18 @@ def test_ddpg(args=get_args()): args.exploration_noise = args.exploration_noise * args.max_action print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), - np.max(env.action_space.high)) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) # train_envs = gym.make(args.task) if args.training_num > 1: train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) else: train_envs = gym.make(args.task) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -77,19 +83,29 @@ def test_ddpg(args=get_args()): # model net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor( - net_a, args.action_shape, max_action=args.max_action, - device=args.device).to(args.device) + net_a, args.action_shape, max_action=args.max_action, device=args.device + ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net_c = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) + net_c = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic = Critic(net_c, device=args.device).to(args.device) critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) policy = DDPGPolicy( - actor, actor_optim, critic, critic_optim, - tau=args.tau, gamma=args.gamma, + actor, + actor_optim, + critic, + critic_optim, + tau=args.tau, + gamma=args.gamma, exploration_noise=GaussianNoise(sigma=args.exploration_noise), - estimation_step=args.n_step, action_space=env.action_space) + estimation_step=args.n_step, + action_space=env.action_space + ) # load a previous policy if args.resume_path: @@ -118,10 +134,19 @@ def save_fn(policy): if not args.watch: # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) pprint.pprint(result) # Let's watch its performance! 
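For reference, the pieces `mujoco_ddpg.py` assembles reduce to the sketch below. Observation/action shapes, hidden sizes and hyper-parameters are illustrative placeholders (a HalfCheetah-like 17-dim observation and 6-dim action space), not values read from this diff, and `action_space` is omitted where the script passes `env.action_space`:

```python
# Minimal DDPG wiring sketch mirroring mujoco_ddpg.py; all numbers are
# placeholders chosen for illustration.
import torch

from tianshou.exploration import GaussianNoise
from tianshou.policy import DDPGPolicy
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic

state_shape, action_shape, max_action = (17,), (6,), 1.0

# Deterministic actor: Net backbone -> Actor head bounded by max_action.
net_a = Net(state_shape, hidden_sizes=[256, 256])
actor = Actor(net_a, action_shape, max_action=max_action)
actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-3)

# Q-critic: concat=True feeds (state, action) pairs through the backbone.
net_c = Net(state_shape, action_shape, hidden_sizes=[256, 256], concat=True)
critic = Critic(net_c)
critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-3)

policy = DDPGPolicy(
    actor,
    actor_optim,
    critic,
    critic_optim,
    tau=0.005,
    gamma=0.99,
    # Noise scale expressed as a fraction of the action range, matching the
    # args.exploration_noise * args.max_action rescaling in the script.
    exploration_noise=GaussianNoise(sigma=0.1 * max_action),
    estimation_step=1,
)
```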
diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 00b2a1a2c..23883a119 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -1,24 +1,25 @@ #!/usr/bin/env python3 +import argparse +import datetime import os -import gym -import torch import pprint -import datetime -import argparse + +import gym import numpy as np +import torch from torch import nn +from torch.distributions import Independent, Normal from torch.optim.lr_scheduler import LambdaLR from torch.utils.tensorboard import SummaryWriter -from torch.distributions import Independent, Normal +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv from tianshou.policy import NPGPolicy +from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import onpolicy_trainer from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer def get_args(): @@ -26,8 +27,9 @@ def get_args(): parser.add_argument('--task', type=str, default='HalfCheetah-v3') parser.add_argument('--seed', type=int, default=0) parser.add_argument('--buffer-size', type=int, default=4096) - parser.add_argument('--hidden-sizes', type=int, nargs='*', - default=[64, 64]) # baselines [32, 32] + parser.add_argument( + '--hidden-sizes', type=int, nargs='*', default=[64, 64] + ) # baselines [32, 32] parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--epoch', type=int, default=100) @@ -49,11 +51,15 @@ def get_args(): parser.add_argument('--optim-critic-iters', type=int, default=20) parser.add_argument('--actor-step-size', type=float, default=0.1) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) return parser.parse_args() @@ -64,16 +70,18 @@ def test_npg(args=get_args()): args.max_action = env.action_space.high[0] print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), - np.max(env.action_space.high)) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) # train_envs = gym.make(args.task) train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)], - norm_obs=True) + [lambda: gym.make(args.task) for _ in range(args.training_num)], norm_obs=True + ) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( [lambda: gym.make(args.task) for _ in range(args.test_num)], - norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False) + norm_obs=True, + obs_rms=train_envs.obs_rms, + update_obs_rms=False + ) # seed np.random.seed(args.seed) @@ -81,12 +89,25 @@ def test_npg(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) - actor = ActorProb(net_a, 
args.action_shape, max_action=args.max_action, - unbounded=True, device=args.device).to(args.device) - net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) + net_a = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) + actor = ActorProb( + net_a, + args.action_shape, + max_action=args.max_action, + unbounded=True, + device=args.device + ).to(args.device) + net_c = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) critic = Critic(net_c, device=args.device).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in list(actor.modules()) + list(critic.modules()): @@ -107,22 +128,32 @@ def test_npg(args=get_args()): if args.lr_decay: # decay learning rate to 0 linearly max_update_num = np.ceil( - args.step_per_epoch / args.step_per_collect) * args.epoch + args.step_per_epoch / args.step_per_collect + ) * args.epoch lr_scheduler = LambdaLR( - optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num + ) def dist(*logits): return Independent(Normal(*logits), 1) - policy = NPGPolicy(actor, critic, optim, dist, discount_factor=args.gamma, - gae_lambda=args.gae_lambda, - reward_normalization=args.rew_norm, action_scaling=True, - action_bound_method=args.bound_action_method, - lr_scheduler=lr_scheduler, action_space=env.action_space, - advantage_normalization=args.norm_adv, - optim_critic_iters=args.optim_critic_iters, - actor_step_size=args.actor_step_size) + policy = NPGPolicy( + actor, + critic, + optim, + dist, + discount_factor=args.gamma, + gae_lambda=args.gae_lambda, + reward_normalization=args.rew_norm, + action_scaling=True, + action_bound_method=args.bound_action_method, + lr_scheduler=lr_scheduler, + action_space=env.action_space, + advantage_normalization=args.norm_adv, + optim_critic_iters=args.optim_critic_iters, + actor_step_size=args.actor_step_size + ) # load a previous policy if args.resume_path: @@ -150,10 +181,19 @@ def save_fn(policy): if not args.watch: # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, args.step_per_epoch, - args.repeat_per_collect, args.test_num, args.batch_size, - step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger, - test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + step_per_collect=args.step_per_collect, + save_fn=save_fn, + logger=logger, + test_in_train=False + ) pprint.pprint(result) # Let's watch its performance! 
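The on-policy MuJoCo examples (`mujoco_a2c.py`, `mujoco_npg.py`, and the PPO/REINFORCE scripts that follow) share the same `--lr-decay` logic: the schedule length is the total number of collect/update cycles, `ceil(step_per_epoch / step_per_collect) * epoch`, so the learning rate reaches roughly zero at the end of training. A self-contained sketch with made-up hyper-parameters:

```python
# Linear learning-rate decay as used by the on-policy examples; the
# optimizer, parameter and step counts below are placeholders.
import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR

params = [torch.nn.Parameter(torch.zeros(1))]
optim = torch.optim.Adam(params, lr=3e-4)

step_per_epoch, step_per_collect, epoch = 30000, 1024, 100
# One scheduler step per policy update, not per environment step.
max_update_num = np.ceil(step_per_epoch / step_per_collect) * epoch
lr_scheduler = LambdaLR(optim, lr_lambda=lambda e: 1 - e / max_update_num)

for _ in range(3):
    optim.step()          # placeholder update
    lr_scheduler.step()   # lr shrinks linearly toward zero
print(optim.param_groups[0]["lr"])  # slightly below the initial 3e-4
```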
diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index fb3e7a0a2..01dc5aa3f 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -1,24 +1,25 @@ #!/usr/bin/env python3 +import argparse +import datetime import os -import gym -import torch import pprint -import datetime -import argparse + +import gym import numpy as np +import torch from torch import nn +from torch.distributions import Independent, Normal from torch.optim.lr_scheduler import LambdaLR from torch.utils.tensorboard import SummaryWriter -from torch.distributions import Independent, Normal +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv from tianshou.policy import PPOPolicy +from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import onpolicy_trainer from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer def get_args(): @@ -53,11 +54,15 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) return parser.parse_args() @@ -68,16 +73,18 @@ def test_ppo(args=get_args()): args.max_action = env.action_space.high[0] print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), - np.max(env.action_space.high)) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) # train_envs = gym.make(args.task) train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)], - norm_obs=True) + [lambda: gym.make(args.task) for _ in range(args.training_num)], norm_obs=True + ) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( [lambda: gym.make(args.task) for _ in range(args.test_num)], - norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False) + norm_obs=True, + obs_rms=train_envs.obs_rms, + update_obs_rms=False + ) # seed np.random.seed(args.seed) @@ -85,12 +92,25 @@ def test_ppo(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) - actor = ActorProb(net_a, args.action_shape, max_action=args.max_action, - unbounded=True, device=args.device).to(args.device) - net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) + net_a = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) + actor = ActorProb( + net_a, + args.action_shape, + max_action=args.max_action, + unbounded=True, + device=args.device + ).to(args.device) + net_c = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) critic = Critic(net_c, 
device=args.device).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in list(actor.modules()) + list(critic.modules()): @@ -107,29 +127,44 @@ def test_ppo(args=get_args()): m.weight.data.copy_(0.01 * m.weight.data) optim = torch.optim.Adam( - list(actor.parameters()) + list(critic.parameters()), lr=args.lr) + list(actor.parameters()) + list(critic.parameters()), lr=args.lr + ) lr_scheduler = None if args.lr_decay: # decay learning rate to 0 linearly max_update_num = np.ceil( - args.step_per_epoch / args.step_per_collect) * args.epoch + args.step_per_epoch / args.step_per_collect + ) * args.epoch lr_scheduler = LambdaLR( - optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num + ) def dist(*logits): return Independent(Normal(*logits), 1) - policy = PPOPolicy(actor, critic, optim, dist, discount_factor=args.gamma, - gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, - vf_coef=args.vf_coef, ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, action_scaling=True, - action_bound_method=args.bound_action_method, - lr_scheduler=lr_scheduler, action_space=env.action_space, - eps_clip=args.eps_clip, value_clip=args.value_clip, - dual_clip=args.dual_clip, advantage_normalization=args.norm_adv, - recompute_advantage=args.recompute_adv) + policy = PPOPolicy( + actor, + critic, + optim, + dist, + discount_factor=args.gamma, + gae_lambda=args.gae_lambda, + max_grad_norm=args.max_grad_norm, + vf_coef=args.vf_coef, + ent_coef=args.ent_coef, + reward_normalization=args.rew_norm, + action_scaling=True, + action_bound_method=args.bound_action_method, + lr_scheduler=lr_scheduler, + action_space=env.action_space, + eps_clip=args.eps_clip, + value_clip=args.value_clip, + dual_clip=args.dual_clip, + advantage_normalization=args.norm_adv, + recompute_advantage=args.recompute_adv + ) # load a previous policy if args.resume_path: @@ -157,10 +192,19 @@ def save_fn(policy): if not args.watch: # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, args.step_per_epoch, - args.repeat_per_collect, args.test_num, args.batch_size, - step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger, - test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + step_per_collect=args.step_per_collect, + save_fn=save_fn, + logger=logger, + test_in_train=False + ) pprint.pprint(result) # Let's watch its performance! 
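These continuous-control examples also share the `dist` factory passed to the policy: a diagonal Gaussian built as `Independent(Normal(*logits), 1)`, so `log_prob` returns a single value per action vector rather than one per action dimension. A short sketch with made-up tensors:

```python
# Diagonal-Gaussian action distribution, as constructed in the scripts above;
# the (mu, sigma) tensors are placeholders for the actor network's outputs.
import torch
from torch.distributions import Independent, Normal


def dist(*logits):
    # Independent(..., 1) sums the per-dimension log-probs into one joint
    # log-prob per action vector.
    return Independent(Normal(*logits), 1)


mu = torch.zeros(4, 6)           # batch of 4 actions, 6 dimensions each
sigma = 0.5 * torch.ones(4, 6)
d = dist(mu, sigma)
act = d.sample()                 # shape (4, 6)
print(d.log_prob(act).shape)     # torch.Size([4]) -- one value per sample
```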
diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index b7698562a..914b46251 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -1,24 +1,25 @@ #!/usr/bin/env python3 +import argparse +import datetime import os -import gym -import torch import pprint -import datetime -import argparse + +import gym import numpy as np +import torch from torch import nn +from torch.distributions import Independent, Normal from torch.optim.lr_scheduler import LambdaLR from torch.utils.tensorboard import SummaryWriter -from torch.distributions import Independent, Normal +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv from tianshou.policy import PGPolicy +from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import onpolicy_trainer from tianshou.utils.net.continuous import ActorProb -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer def get_args(): @@ -45,11 +46,15 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) return parser.parse_args() @@ -60,16 +65,18 @@ def test_reinforce(args=get_args()): args.max_action = env.action_space.high[0] print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), - np.max(env.action_space.high)) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) # train_envs = gym.make(args.task) train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)], - norm_obs=True) + [lambda: gym.make(args.task) for _ in range(args.training_num)], norm_obs=True + ) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( [lambda: gym.make(args.task) for _ in range(args.test_num)], - norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False) + norm_obs=True, + obs_rms=train_envs.obs_rms, + update_obs_rms=False + ) # seed np.random.seed(args.seed) @@ -77,10 +84,19 @@ def test_reinforce(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) - actor = ActorProb(net_a, args.action_shape, max_action=args.max_action, - unbounded=True, device=args.device).to(args.device) + net_a = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) + actor = ActorProb( + net_a, + args.action_shape, + max_action=args.max_action, + unbounded=True, + device=args.device + ).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in actor.modules(): if isinstance(m, torch.nn.Linear): @@ -100,18 +116,27 @@ def test_reinforce(args=get_args()): if args.lr_decay: # decay learning rate to 0 linearly 
max_update_num = np.ceil( - args.step_per_epoch / args.step_per_collect) * args.epoch + args.step_per_epoch / args.step_per_collect + ) * args.epoch lr_scheduler = LambdaLR( - optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num + ) def dist(*logits): return Independent(Normal(*logits), 1) - policy = PGPolicy(actor, optim, dist, discount_factor=args.gamma, - reward_normalization=args.rew_norm, action_scaling=True, - action_bound_method=args.action_bound_method, - lr_scheduler=lr_scheduler, action_space=env.action_space) + policy = PGPolicy( + actor, + optim, + dist, + discount_factor=args.gamma, + reward_normalization=args.rew_norm, + action_scaling=True, + action_bound_method=args.action_bound_method, + lr_scheduler=lr_scheduler, + action_space=env.action_space + ) # load a previous policy if args.resume_path: @@ -139,10 +164,19 @@ def save_fn(policy): if not args.watch: # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, args.step_per_epoch, - args.repeat_per_collect, args.test_num, args.batch_size, - step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger, - test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + step_per_collect=args.step_per_collect, + save_fn=save_fn, + logger=logger, + test_in_train=False + ) pprint.pprint(result) # Let's watch its performance! diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index c2dfd3618..cb764f473 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -1,21 +1,22 @@ #!/usr/bin/env python3 +import argparse +import datetime import os -import gym -import torch import pprint -import datetime -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv from tianshou.policy import SACPolicy +from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import offpolicy_trainer from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer def get_args(): @@ -43,11 +44,15 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) return parser.parse_args() @@ -58,17 +63,18 @@ def test_sac(args=get_args()): args.max_action = env.action_space.high[0] print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), - np.max(env.action_space.high)) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) # train_envs = gym.make(args.task) if args.training_num > 1: train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) else: train_envs = gym.make(args.task) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -77,16 +83,28 @@ def test_sac(args=get_args()): # model net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb( - net_a, args.action_shape, max_action=args.max_action, - device=args.device, unbounded=True, conditioned_sigma=True + net_a, + args.action_shape, + max_action=args.max_action, + device=args.device, + unbounded=True, + conditioned_sigma=True ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net_c1 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) - net_c2 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) + net_c1 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic1 = Critic(net_c1, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) critic2 = Critic(net_c2, device=args.device).to(args.device) @@ -99,9 +117,18 @@ def test_sac(args=get_args()): args.alpha = (target_entropy, log_alpha, alpha_optim) policy = SACPolicy( - actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - tau=args.tau, gamma=args.gamma, alpha=args.alpha, - estimation_step=args.n_step, action_space=env.action_space) + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + tau=args.tau, + gamma=args.gamma, + alpha=args.alpha, + estimation_step=args.n_step, + action_space=env.action_space + ) # load a previous policy if args.resume_path: @@ -130,10 +157,19 @@ def save_fn(policy): if not args.watch: # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + 
args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) pprint.pprint(result) # Let's watch its performance! diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 9a0179899..9e0ca0d82 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -1,22 +1,23 @@ #!/usr/bin/env python3 +import argparse +import datetime import os -import gym -import torch import pprint -import datetime -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import TD3Policy -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.env import SubprocVectorEnv -from tianshou.utils.net.common import Net from tianshou.exploration import GaussianNoise +from tianshou.policy import TD3Policy from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer def get_args(): @@ -45,11 +46,15 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) return parser.parse_args() @@ -63,17 +68,18 @@ def test_td3(args=get_args()): args.noise_clip = args.noise_clip * args.max_action print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), - np.max(env.action_space.high)) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) # train_envs = gym.make(args.task) if args.training_num > 1: train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) else: train_envs = gym.make(args.task) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -82,27 +88,44 @@ def test_td3(args=get_args()): # model net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor( - net_a, args.action_shape, max_action=args.max_action, - device=args.device).to(args.device) + net_a, args.action_shape, max_action=args.max_action, device=args.device + ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net_c1 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) - net_c2 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) + net_c1 = Net( + args.state_shape, + 
args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic1 = Critic(net_c1, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = TD3Policy( - actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - tau=args.tau, gamma=args.gamma, + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + tau=args.tau, + gamma=args.gamma, exploration_noise=GaussianNoise(sigma=args.exploration_noise), - policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq, - noise_clip=args.noise_clip, estimation_step=args.n_step, - action_space=env.action_space) + policy_noise=args.policy_noise, + update_actor_freq=args.update_actor_freq, + noise_clip=args.noise_clip, + estimation_step=args.n_step, + action_space=env.action_space + ) # load a previous policy if args.resume_path: @@ -131,10 +154,19 @@ def save_fn(policy): if not args.watch: # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) pprint.pprint(result) # Let's watch its performance! 
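The TD3 script's reformatted policy construction passes policy_noise, noise_clip, and the twin critic pair straight through to TD3Policy; the behaviour those arguments control is target-policy smoothing and a pessimistic min over the two critic estimates. A self-contained sketch of those two ingredients in plain torch, using placeholder networks rather than tianshou's internal target copies:

import torch

torch.manual_seed(0)
obs_dim, act_dim, batch = 17, 6, 32
policy_noise, noise_clip, max_action = 0.2, 0.5, 1.0

# placeholder target networks; the library keeps its own actor_old / critic_old copies
actor_old = torch.nn.Linear(obs_dim, act_dim)
critic1_old = torch.nn.Linear(obs_dim + act_dim, 1)
critic2_old = torch.nn.Linear(obs_dim + act_dim, 1)

obs_next = torch.randn(batch, obs_dim)
with torch.no_grad():
    a_next = torch.tanh(actor_old(obs_next)) * max_action
    # clipped Gaussian noise smooths the target action
    noise = (torch.randn_like(a_next) * policy_noise).clamp(-noise_clip, noise_clip)
    a_next = (a_next + noise).clamp(-max_action, max_action)
    sa = torch.cat([obs_next, a_next], dim=1)
    # pessimistic target value: element-wise min of the twin critics
    target_q = torch.min(critic1_old(sa), critic2_old(sa))
print(target_q.shape)  # torch.Size([32, 1])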
diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index b00f2e3d6..aef324fd5 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -1,24 +1,25 @@ #!/usr/bin/env python3 +import argparse +import datetime import os -import gym -import torch import pprint -import datetime -import argparse + +import gym import numpy as np +import torch from torch import nn +from torch.distributions import Independent, Normal from torch.optim.lr_scheduler import LambdaLR from torch.utils.tensorboard import SummaryWriter -from torch.distributions import Independent, Normal +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv from tianshou.policy import TRPOPolicy +from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import onpolicy_trainer from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer def get_args(): @@ -26,8 +27,9 @@ def get_args(): parser.add_argument('--task', type=str, default='HalfCheetah-v3') parser.add_argument('--seed', type=int, default=0) parser.add_argument('--buffer-size', type=int, default=4096) - parser.add_argument('--hidden-sizes', type=int, nargs='*', - default=[64, 64]) # baselines [32, 32] + parser.add_argument( + '--hidden-sizes', type=int, nargs='*', default=[64, 64] + ) # baselines [32, 32] parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--epoch', type=int, default=100) @@ -52,11 +54,15 @@ def get_args(): parser.add_argument('--backtrack-coeff', type=float, default=0.8) parser.add_argument('--max-backtracks', type=int, default=10) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) return parser.parse_args() @@ -67,16 +73,18 @@ def test_trpo(args=get_args()): args.max_action = env.action_space.high[0] print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), - np.max(env.action_space.high)) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) # train_envs = gym.make(args.task) train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)], - norm_obs=True) + [lambda: gym.make(args.task) for _ in range(args.training_num)], norm_obs=True + ) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( [lambda: gym.make(args.task) for _ in range(args.test_num)], - norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False) + norm_obs=True, + obs_rms=train_envs.obs_rms, + update_obs_rms=False + ) # seed np.random.seed(args.seed) @@ -84,12 +92,25 @@ def test_trpo(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) - actor = ActorProb(net_a, 
args.action_shape, max_action=args.max_action, - unbounded=True, device=args.device).to(args.device) - net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) + net_a = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) + actor = ActorProb( + net_a, + args.action_shape, + max_action=args.max_action, + unbounded=True, + device=args.device + ).to(args.device) + net_c = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) critic = Critic(net_c, device=args.device).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in list(actor.modules()) + list(critic.modules()): @@ -110,24 +131,34 @@ def test_trpo(args=get_args()): if args.lr_decay: # decay learning rate to 0 linearly max_update_num = np.ceil( - args.step_per_epoch / args.step_per_collect) * args.epoch + args.step_per_epoch / args.step_per_collect + ) * args.epoch lr_scheduler = LambdaLR( - optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num + ) def dist(*logits): return Independent(Normal(*logits), 1) - policy = TRPOPolicy(actor, critic, optim, dist, discount_factor=args.gamma, - gae_lambda=args.gae_lambda, - reward_normalization=args.rew_norm, action_scaling=True, - action_bound_method=args.bound_action_method, - lr_scheduler=lr_scheduler, action_space=env.action_space, - advantage_normalization=args.norm_adv, - optim_critic_iters=args.optim_critic_iters, - max_kl=args.max_kl, - backtrack_coeff=args.backtrack_coeff, - max_backtracks=args.max_backtracks) + policy = TRPOPolicy( + actor, + critic, + optim, + dist, + discount_factor=args.gamma, + gae_lambda=args.gae_lambda, + reward_normalization=args.rew_norm, + action_scaling=True, + action_bound_method=args.bound_action_method, + lr_scheduler=lr_scheduler, + action_space=env.action_space, + advantage_normalization=args.norm_adv, + optim_critic_iters=args.optim_critic_iters, + max_kl=args.max_kl, + backtrack_coeff=args.backtrack_coeff, + max_backtracks=args.max_backtracks + ) # load a previous policy if args.resume_path: @@ -155,10 +186,19 @@ def save_fn(policy): if not args.watch: # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, args.step_per_epoch, - args.repeat_per_collect, args.test_num, args.batch_size, - step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger, - test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + step_per_collect=args.step_per_collect, + save_fn=save_fn, + logger=logger, + test_in_train=False + ) pprint.pprint(result) # Let's watch its performance! 
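PPO, REINFORCE, and TRPO above all build the action distribution the same way: dist(*logits) returns Independent(Normal(*logits), 1). The Independent wrapper reinterprets the last (action) dimension as part of the event, so log_prob yields one scalar per action instead of one per component. A short illustration with arbitrary shapes:

import torch
from torch.distributions import Independent, Normal

mu, sigma = torch.zeros(4, 3), torch.ones(4, 3)  # batch of 4, action dim 3

base = Normal(mu, sigma)          # 4 x 3 independent Gaussians
diag = Independent(base, 1)       # diagonal Gaussian over the 3-dim action

act = diag.sample()
print(base.log_prob(act).shape)   # torch.Size([4, 3]) -- per component
print(diag.log_prob(act).shape)   # torch.Size([4])    -- one value per action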
diff --git a/examples/mujoco/plotter.py b/examples/mujoco/plotter.py index 4ecd530c7..e3e7057e4 100755 --- a/examples/mujoco/plotter.py +++ b/examples/mujoco/plotter.py @@ -1,13 +1,13 @@ #!/usr/bin/env python3 -import re -import os import argparse -import numpy as np +import os +import re + import matplotlib.pyplot as plt import matplotlib.ticker as mticker - -from tools import find_all_files, group_files, csv2numpy +import numpy as np +from tools import csv2numpy, find_all_files, group_files def smooth(y, radius, mode='two_sided', valid_only=False): @@ -38,28 +38,49 @@ def smooth(y, radius, mode='two_sided', valid_only=False): return out -COLORS = ([ - # deepmind style - '#0072B2', - '#009E73', - '#D55E00', - '#CC79A7', - # '#F0E442', - '#d73027', # RED - # built-in color - 'blue', 'red', 'pink', 'cyan', 'magenta', 'yellow', 'black', 'purple', - 'brown', 'orange', 'teal', 'lightblue', 'lime', 'lavender', 'turquoise', - 'darkgreen', 'tan', 'salmon', 'gold', 'darkred', 'darkblue', 'green', - # personal color - '#313695', # DARK BLUE - '#74add1', # LIGHT BLUE - '#f46d43', # ORANGE - '#4daf4a', # GREEN - '#984ea3', # PURPLE - '#f781bf', # PINK - '#ffc832', # YELLOW - '#000000', # BLACK -]) +COLORS = ( + [ + # deepmind style + '#0072B2', + '#009E73', + '#D55E00', + '#CC79A7', + # '#F0E442', + '#d73027', # RED + # built-in color + 'blue', + 'red', + 'pink', + 'cyan', + 'magenta', + 'yellow', + 'black', + 'purple', + 'brown', + 'orange', + 'teal', + 'lightblue', + 'lime', + 'lavender', + 'turquoise', + 'darkgreen', + 'tan', + 'salmon', + 'gold', + 'darkred', + 'darkblue', + 'green', + # personal color + '#313695', # DARK BLUE + '#74add1', # LIGHT BLUE + '#f46d43', # ORANGE + '#4daf4a', # GREEN + '#984ea3', # PURPLE + '#f781bf', # PINK + '#ffc832', # YELLOW + '#000000', # BLACK + ] +) def plot_ax( @@ -76,6 +97,7 @@ def plot_ax( shaded_std=True, legend_outside=False, ): + def legend_fn(x): # return os.path.split(os.path.join( # args.root_dir, x))[0].replace('/', '_') + " (10)" @@ -96,8 +118,11 @@ def legend_fn(x): y_shaded = smooth(csv_dict[ykey + ':shaded'], radius=smooth_radius) ax.fill_between(x, y - y_shaded, y + y_shaded, color=color, alpha=.2) - ax.legend(legneds, loc=2 if legend_outside else None, - bbox_to_anchor=(1, 1) if legend_outside else None) + ax.legend( + legneds, + loc=2 if legend_outside else None, + bbox_to_anchor=(1, 1) if legend_outside else None + ) ax.xaxis.set_major_formatter(mticker.EngFormatter()) if xlim is not None: ax.set_xlim(xmin=0, xmax=xlim) @@ -127,8 +152,14 @@ def plot_figure( res = group_files(file_lists, group_pattern) row_n = int(np.ceil(len(res) / 3)) col_n = min(len(res), 3) - fig, axes = plt.subplots(row_n, col_n, sharex=sharex, sharey=sharey, figsize=( - fig_length * col_n, fig_width * row_n), squeeze=False) + fig, axes = plt.subplots( + row_n, + col_n, + sharex=sharex, + sharey=sharey, + figsize=(fig_length * col_n, fig_width * row_n), + squeeze=False + ) axes = axes.flatten() for i, (k, v) in enumerate(res.items()): plot_ax(axes[i], v, title=k, **kwargs) @@ -138,53 +169,95 @@ def plot_figure( if __name__ == "__main__": parser = argparse.ArgumentParser(description='plotter') - parser.add_argument('--fig-length', type=int, default=6, - help='matplotlib figure length (default: 6)') - parser.add_argument('--fig-width', type=int, default=6, - help='matplotlib figure width (default: 6)') - parser.add_argument('--style', default='seaborn', - help='matplotlib figure style (default: seaborn)') - parser.add_argument('--title', default=None, - help='matplotlib 
figure title (default: None)') - parser.add_argument('--xkey', default='env_step', - help='x-axis key in csv file (default: env_step)') - parser.add_argument('--ykey', default='rew', - help='y-axis key in csv file (default: rew)') - parser.add_argument('--smooth', type=int, default=0, - help='smooth radius of y axis (default: 0)') - parser.add_argument('--xlabel', default='Timesteps', - help='matplotlib figure xlabel') - parser.add_argument('--ylabel', default='Episode Reward', - help='matplotlib figure ylabel') - parser.add_argument( - '--shaded-std', action='store_true', - help='shaded region corresponding to standard deviation of the group') - parser.add_argument('--sharex', action='store_true', - help='whether to share x axis within multiple sub-figures') - parser.add_argument('--sharey', action='store_true', - help='whether to share y axis within multiple sub-figures') - parser.add_argument('--legend-outside', action='store_true', - help='place the legend outside of the figure') - parser.add_argument('--xlim', type=int, default=None, - help='x-axis limitation (default: None)') + parser.add_argument( + '--fig-length', + type=int, + default=6, + help='matplotlib figure length (default: 6)' + ) + parser.add_argument( + '--fig-width', + type=int, + default=6, + help='matplotlib figure width (default: 6)' + ) + parser.add_argument( + '--style', + default='seaborn', + help='matplotlib figure style (default: seaborn)' + ) + parser.add_argument( + '--title', default=None, help='matplotlib figure title (default: None)' + ) + parser.add_argument( + '--xkey', + default='env_step', + help='x-axis key in csv file (default: env_step)' + ) + parser.add_argument( + '--ykey', default='rew', help='y-axis key in csv file (default: rew)' + ) + parser.add_argument( + '--smooth', type=int, default=0, help='smooth radius of y axis (default: 0)' + ) + parser.add_argument( + '--xlabel', default='Timesteps', help='matplotlib figure xlabel' + ) + parser.add_argument( + '--ylabel', default='Episode Reward', help='matplotlib figure ylabel' + ) + parser.add_argument( + '--shaded-std', + action='store_true', + help='shaded region corresponding to standard deviation of the group' + ) + parser.add_argument( + '--sharex', + action='store_true', + help='whether to share x axis within multiple sub-figures' + ) + parser.add_argument( + '--sharey', + action='store_true', + help='whether to share y axis within multiple sub-figures' + ) + parser.add_argument( + '--legend-outside', + action='store_true', + help='place the legend outside of the figure' + ) + parser.add_argument( + '--xlim', type=int, default=None, help='x-axis limitation (default: None)' + ) parser.add_argument('--root-dir', default='./', help='root dir (default: ./)') parser.add_argument( - '--file-pattern', type=str, default=r".*/test_rew_\d+seeds.csv$", + '--file-pattern', + type=str, + default=r".*/test_rew_\d+seeds.csv$", help='regular expression to determine whether or not to include target csv ' - 'file, default to including all test_rew_{num}seeds.csv file under rootdir') + 'file, default to including all test_rew_{num}seeds.csv file under rootdir' + ) parser.add_argument( - '--group-pattern', type=str, default=r"(/|^)\w*?\-v(\d|$)", + '--group-pattern', + type=str, + default=r"(/|^)\w*?\-v(\d|$)", help='regular expression to group files in sub-figure, default to grouping ' - 'according to env_name dir, "" means no grouping') + 'according to env_name dir, "" means no grouping' + ) parser.add_argument( - '--legend-pattern', type=str, default=r".*", + 
'--legend-pattern', + type=str, + default=r".*", help='regular expression to extract legend from csv file path, default to ' - 'using file path as legend name.') + 'using file path as legend name.' + ) parser.add_argument('--show', action='store_true', help='show figure') - parser.add_argument('--output-path', type=str, - help='figure save path', default="./figure.png") - parser.add_argument('--dpi', type=int, default=200, - help='figure dpi (default: 200)') + parser.add_argument( + '--output-path', type=str, help='figure save path', default="./figure.png" + ) + parser.add_argument( + '--dpi', type=int, default=200, help='figure dpi (default: 200)' + ) args = parser.parse_args() file_lists = find_all_files(args.root_dir, re.compile(args.file_pattern)) file_lists = [os.path.relpath(f, args.root_dir) for f in file_lists] @@ -207,9 +280,9 @@ def plot_figure( sharey=args.sharey, smooth_radius=args.smooth, shaded_std=args.shaded_std, - legend_outside=args.legend_outside) + legend_outside=args.legend_outside + ) if args.output_path: - plt.savefig(args.output_path, - dpi=args.dpi, bbox_inches='tight') + plt.savefig(args.output_path, dpi=args.dpi, bbox_inches='tight') if args.show: plt.show() diff --git a/examples/mujoco/tools.py b/examples/mujoco/tools.py index 9e49206e0..3ed4791fd 100755 --- a/examples/mujoco/tools.py +++ b/examples/mujoco/tools.py @@ -1,12 +1,13 @@ #!/usr/bin/env python3 +import argparse +import csv import os import re -import csv -import tqdm -import argparse -import numpy as np from collections import defaultdict + +import numpy as np +import tqdm from tensorboard.backend.event_processing import event_accumulator @@ -66,11 +67,13 @@ def convert_tfevents_to_csv(root_dir, refresh=False): initial_time = ea._first_event_timestamp content = [["env_step", "rew", "time"]] for test_rew in ea.scalars.Items("test/rew"): - content.append([ - round(test_rew.step, 4), - round(test_rew.value, 4), - round(test_rew.wall_time - initial_time, 4), - ]) + content.append( + [ + round(test_rew.step, 4), + round(test_rew.value, 4), + round(test_rew.wall_time - initial_time, 4), + ] + ) csv.writer(open(output_file, 'w')).writerows(content) result[output_file] = content return result @@ -80,13 +83,15 @@ def merge_csv(csv_files, root_dir, remove_zero=False): """Merge result in csv_files into a single csv file.""" assert len(csv_files) > 0 if remove_zero: - for k, v in csv_files.items(): + for v in csv_files.values(): if v[1][0] == 0: v.pop(1) sorted_keys = sorted(csv_files.keys()) sorted_values = [csv_files[k][1:] for k in sorted_keys] - content = [["env_step", "rew", "rew:shaded"] + list(map( - lambda f: "rew:" + os.path.relpath(f, root_dir), sorted_keys))] + content = [ + ["env_step", "rew", "rew:shaded"] + + list(map(lambda f: "rew:" + os.path.relpath(f, root_dir), sorted_keys)) + ] for rows in zip(*sorted_values): array = np.array(rows) assert len(set(array[:, 0])) == 1, (set(array[:, 0]), array[:, 0]) @@ -101,11 +106,15 @@ def merge_csv(csv_files, root_dir, remove_zero=False): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - '--refresh', action="store_true", - help="Re-generate all csv files instead of using existing one.") + '--refresh', + action="store_true", + help="Re-generate all csv files instead of using existing one." + ) parser.add_argument( - '--remove-zero', action="store_true", - help="Remove the data point of env_step == 0.") + '--remove-zero', + action="store_true", + help="Remove the data point of env_step == 0." 
+ ) parser.add_argument('--root-dir', type=str) args = parser.parse_args() diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py index 017ab7750..290cb92e5 100644 --- a/examples/vizdoom/env.py +++ b/examples/vizdoom/env.py @@ -1,4 +1,5 @@ import os + import cv2 import gym import numpy as np @@ -33,9 +34,8 @@ def battle_button_comb(): class Env(gym.Env): - def __init__( - self, cfg_path, frameskip=4, res=(4, 40, 60), save_lmp=False - ): + + def __init__(self, cfg_path, frameskip=4, res=(4, 40, 60), save_lmp=False): super().__init__() self.save_lmp = save_lmp self.health_setting = "battle" in cfg_path @@ -75,8 +75,7 @@ def reset(self): self.obs_buffer = np.zeros(self.res, dtype=np.uint8) self.get_obs() self.health = self.game.get_game_variable(vzd.GameVariable.HEALTH) - self.killcount = self.game.get_game_variable( - vzd.GameVariable.KILLCOUNT) + self.killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT) self.ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2) return self.obs_buffer @@ -121,7 +120,7 @@ def close(self): obs = env.reset() print(env.spec.reward_threshold) print(obs.shape, action_num) - for i in range(4000): + for _ in range(4000): obs, rew, done, info = env.step(0) if done: env.reset() diff --git a/examples/vizdoom/maps/spectator.py b/examples/vizdoom/maps/spectator.py index d4d7e8c7c..2180ed7c7 100644 --- a/examples/vizdoom/maps/spectator.py +++ b/examples/vizdoom/maps/spectator.py @@ -10,11 +10,12 @@ from __future__ import print_function +from argparse import ArgumentParser from time import sleep + import vizdoom as vzd -from argparse import ArgumentParser -# import cv2 +# import cv2 if __name__ == "__main__": parser = ArgumentParser("ViZDoom example showing how to use SPECTATOR mode.") diff --git a/examples/vizdoom/replay.py b/examples/vizdoom/replay.py index a1e556fce..30cdb31bd 100755 --- a/examples/vizdoom/replay.py +++ b/examples/vizdoom/replay.py @@ -1,6 +1,7 @@ # import cv2 import sys import time + import tqdm import vizdoom as vzd diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 1123151c8..bb3a1f207 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -1,18 +1,18 @@ +import argparse import os -import torch import pprint -import argparse + import numpy as np +import torch +from env import Env +from network import C51 from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import C51Policy -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv +from tianshou.policy import C51Policy from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer - -from env import Env -from network import C51 +from tianshou.utils import TensorboardLogger def get_args(): @@ -40,15 +40,23 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument('--skip-num', type=int, default=4) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') - parser.add_argument('--save-lmp', default=False, action='store_true', - help='save lmp file for replay whole episode') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) + parser.add_argument( + '--save-lmp', + default=False, + action='store_true', + help='save lmp file for replay whole episode' + ) parser.add_argument('--save-buffer-name', type=str, default=None) return parser.parse_args() @@ -64,26 +72,36 @@ def test_c51(args=get_args()): print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # make environments - train_envs = SubprocVectorEnv([ - lambda: Env(args.cfg_path, args.frames_stack, args.res) - for _ in range(args.training_num)]) - test_envs = SubprocVectorEnv([ - lambda: Env(args.cfg_path, args.frames_stack, - args.res, args.save_lmp) - for _ in range(min(os.cpu_count() - 1, args.test_num))]) + train_envs = SubprocVectorEnv( + [ + lambda: Env(args.cfg_path, args.frames_stack, args.res) + for _ in range(args.training_num) + ] + ) + test_envs = SubprocVectorEnv( + [ + lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp) + for _ in range(min(os.cpu_count() - 1, args.test_num)) + ] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # define model - net = C51(*args.state_shape, args.action_shape, - args.num_atoms, args.device) + net = C51(*args.state_shape, args.action_shape, args.num_atoms, args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy policy = C51Policy( - net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max, - args.n_step, target_update_freq=args.target_update_freq + net, + optim, + args.gamma, + args.num_atoms, + args.v_min, + args.v_max, + args.n_step, + target_update_freq=args.target_update_freq ).to(args.device) # load a previous policy if args.resume_path: @@ -92,8 +110,12 @@ def test_c51(args=get_args()): # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True, - save_only_last_obs=True, stack_num=args.frames_stack) + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) @@ -136,11 +158,13 @@ def watch(): if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(test_envs), - ignore_obs_next=True, save_only_last_obs=True, - stack_num=args.frames_stack) - collector = Collector(policy, test_envs, buffer, - exploration_noise=True) + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) + collector = Collector(policy, 
test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -148,8 +172,9 @@ def watch(): else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect( + n_episode=args.test_num, render=args.render + ) rew = result["rews"].mean() lens = result["lens"].mean() * args.skip_num print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') @@ -163,11 +188,22 @@ def watch(): train_collector.collect(n_step=args.batch_size * args.training_num) # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) pprint.pprint(result) watch() diff --git a/setup.cfg b/setup.cfg index d485e6d06..7bb1f1e7c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,6 +8,18 @@ exclude = dist *.egg-info max-line-length = 87 +ignore = B305,W504,B006,B008 + +[yapf] +based_on_style = pep8 +dedent_closing_brackets = true +column_limit = 87 +blank_line_before_nested_class_or_def = true + +[isort] +profile = black +multi_line_output = 3 +line_length = 87 [mypy] files = tianshou/**/*.py diff --git a/setup.py b/setup.py index 208af8e14..bf48020e5 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,8 @@ # -*- coding: utf-8 -*- import os -from setuptools import setup, find_packages + +from setuptools import find_packages, setup def get_version() -> str: @@ -60,6 +61,9 @@ def get_version() -> str: "sphinx_rtd_theme", "sphinxcontrib-bibtex", "flake8", + "flake8-bugbear", + "yapf", + "isort", "pytest", "pytest-cov", "ray>=1.0.0", diff --git a/test/base/env.py b/test/base/env.py index 1151f5b76..cdcb51efa 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -1,19 +1,28 @@ -import gym -import time import random -import numpy as np -import networkx as nx +import time from copy import deepcopy -from gym.spaces import Discrete, MultiDiscrete, Box, Dict, Tuple + +import gym +import networkx as nx +import numpy as np +from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple class MyTestEnv(gym.Env): """This is a "going right" task. The task is to go right ``size`` steps. 
""" - def __init__(self, size, sleep=0, dict_state=False, recurse_state=False, - ma_rew=0, multidiscrete_action=False, random_sleep=False, - array_state=False): + def __init__( + self, + size, + sleep=0, + dict_state=False, + recurse_state=False, + ma_rew=0, + multidiscrete_action=False, + random_sleep=False, + array_state=False + ): assert dict_state + recurse_state + array_state <= 1, \ "dict_state / recurse_state / array_state can be only one true" self.size = size @@ -28,17 +37,32 @@ def __init__(self, size, sleep=0, dict_state=False, recurse_state=False, self.steps = 0 if dict_state: self.observation_space = Dict( - {"index": Box(shape=(1, ), low=0, high=size - 1), - "rand": Box(shape=(1,), low=0, high=1, dtype=np.float64)}) + { + "index": Box(shape=(1, ), low=0, high=size - 1), + "rand": Box(shape=(1, ), low=0, high=1, dtype=np.float64) + } + ) elif recurse_state: self.observation_space = Dict( - {"index": Box(shape=(1, ), low=0, high=size - 1), - "dict": Dict({ - "tuple": Tuple((Discrete(2), Box(shape=(2,), - low=0, high=1, dtype=np.float64))), - "rand": Box(shape=(1, 2), low=0, high=1, - dtype=np.float64)}) - }) + { + "index": + Box(shape=(1, ), low=0, high=size - 1), + "dict": + Dict( + { + "tuple": + Tuple( + ( + Discrete(2), + Box(shape=(2, ), low=0, high=1, dtype=np.float64) + ) + ), + "rand": + Box(shape=(1, 2), low=0, high=1, dtype=np.float64) + } + ) + } + ) elif array_state: self.observation_space = Box(shape=(4, 84, 84), low=0, high=255) else: @@ -70,13 +94,18 @@ def _get_reward(self): def _get_state(self): """Generate state(observation) of MyTestEnv""" if self.dict_state: - return {'index': np.array([self.index], dtype=np.float32), - 'rand': self.rng.rand(1)} + return { + 'index': np.array([self.index], dtype=np.float32), + 'rand': self.rng.rand(1) + } elif self.recurse_state: - return {'index': np.array([self.index], dtype=np.float32), - 'dict': {"tuple": (np.array([1], - dtype=int), self.rng.rand(2)), - "rand": self.rng.rand(1, 2)}} + return { + 'index': np.array([self.index], dtype=np.float32), + 'dict': { + "tuple": (np.array([1], dtype=int), self.rng.rand(2)), + "rand": self.rng.rand(1, 2) + } + } elif self.array_state: img = np.zeros([4, 84, 84], int) img[3, np.arange(84), np.arange(84)] = self.index @@ -112,6 +141,7 @@ def step(self, action): class NXEnv(gym.Env): + def __init__(self, size, obs_type, feat_dim=32): self.size = size self.feat_dim = feat_dim diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 15357f16c..53ee8ffa3 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -1,13 +1,14 @@ -import sys import copy -import torch import pickle -import pytest -import numpy as np -import networkx as nx +import sys from itertools import starmap -from tianshou.data import Batch, to_torch, to_numpy +import networkx as nx +import numpy as np +import pytest +import torch + +from tianshou.data import Batch, to_numpy, to_torch def test_batch(): @@ -99,10 +100,13 @@ def test_batch(): assert batch_item.a.c == batch_dict['c'] assert isinstance(batch_item.a.d, torch.Tensor) assert batch_item.a.d == batch_dict['d'] - batch2 = Batch(a=[{ - 'b': np.float64(1.0), - 'c': np.zeros(1), - 'd': Batch(e=np.array(3.0))}]) + batch2 = Batch( + a=[{ + 'b': np.float64(1.0), + 'c': np.zeros(1), + 'd': Batch(e=np.array(3.0)) + }] + ) assert len(batch2) == 1 assert Batch().shape == [] assert Batch(a=1).shape == [] @@ -141,9 +145,12 @@ def test_batch(): assert batch2_sum.a.d.f.is_empty() with pytest.raises(TypeError): batch2 += [1] - batch3 = Batch(a={ - 'c': 
np.zeros(1), - 'd': Batch(e=np.array([0.0]), f=np.array([3.0]))}) + batch3 = Batch( + a={ + 'c': np.zeros(1), + 'd': Batch(e=np.array([0.0]), f=np.array([3.0])) + } + ) batch3.a.d[0] = {'e': 4.0} assert batch3.a.d.e[0] == 4.0 batch3.a.d[0] = Batch(f=5.0) @@ -202,7 +209,7 @@ def test_batch_over_batch(): assert np.allclose(batch3.c, [6, 7, 8, 6, 7, 8]) assert np.allclose(batch3.b.a, [3, 4, 5, 3, 4, 5]) assert np.allclose(batch3.b.b, [4, 5, 6, 4, 5, 6]) - batch4 = Batch(({'a': {'b': np.array([1.0])}},)) + batch4 = Batch(({'a': {'b': np.array([1.0])}}, )) assert batch4.a.b.ndim == 2 assert batch4.a.b[0, 0] == 1.0 # advanced slicing @@ -239,14 +246,14 @@ def test_batch_cat_and_stack(): a = Batch(a=Batch(a=np.random.randn(3, 4))) assert np.allclose( np.concatenate([a.a.a, a.a.a]), - Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a) + Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a + ) # test cat with lens infer a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4)) b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4)) ans = Batch.cat([a, b, a]) - assert np.allclose(ans.a.a, - np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a])) + assert np.allclose(ans.a.a, np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a])) assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b])) assert ans.a.t.is_empty() @@ -258,51 +265,61 @@ def test_batch_cat_and_stack(): b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5))) b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) test = Batch.cat([b1, b2]) - ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]), - b=torch.cat([torch.zeros(3, 3), b2.b]), - common=Batch(c=np.concatenate([b1.common.c, b2.common.c]))) + ans = Batch( + a=np.concatenate([b1.a, np.zeros((4, 4))]), + b=torch.cat([torch.zeros(3, 3), b2.b]), + common=Batch(c=np.concatenate([b1.common.c, b2.common.c])) + ) assert np.allclose(test.a, ans.a) assert torch.allclose(test.b, ans.b) assert np.allclose(test.common.c, ans.common.c) # test cat with reserved keys (values are Batch()) b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5))) - b2 = Batch(a=Batch(), - b=torch.rand(4, 3), - common=Batch(c=np.random.rand(4, 5))) + b2 = Batch(a=Batch(), b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) test = Batch.cat([b1, b2]) - ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]), - b=torch.cat([torch.zeros(3, 3), b2.b]), - common=Batch(c=np.concatenate([b1.common.c, b2.common.c]))) + ans = Batch( + a=np.concatenate([b1.a, np.zeros((4, 4))]), + b=torch.cat([torch.zeros(3, 3), b2.b]), + common=Batch(c=np.concatenate([b1.common.c, b2.common.c])) + ) assert np.allclose(test.a, ans.a) assert torch.allclose(test.b, ans.b) assert np.allclose(test.common.c, ans.common.c) # test cat with all reserved keys (values are Batch()) b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5))) - b2 = Batch(a=Batch(), - b=torch.rand(4, 3), - common=Batch(c=np.random.rand(4, 5))) + b2 = Batch(a=Batch(), b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) test = Batch.cat([b1, b2]) - ans = Batch(a=Batch(), - b=torch.cat([torch.zeros(3, 3), b2.b]), - common=Batch(c=np.concatenate([b1.common.c, b2.common.c]))) + ans = Batch( + a=Batch(), + b=torch.cat([torch.zeros(3, 3), b2.b]), + common=Batch(c=np.concatenate([b1.common.c, b2.common.c])) + ) assert ans.a.is_empty() assert torch.allclose(test.b, ans.b) assert np.allclose(test.common.c, ans.common.c) # test stack with compatible keys - b3 = Batch(a=np.zeros((3, 4)), - b=torch.ones((2, 5)), - 
c=Batch(d=[[1], [2]])) - b4 = Batch(a=np.ones((3, 4)), - b=torch.ones((2, 5)), - c=Batch(d=[[0], [3]])) + b3 = Batch(a=np.zeros((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[1], [2]])) + b4 = Batch(a=np.ones((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[0], [3]])) b34_stack = Batch.stack((b3, b4), axis=1) assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1)) assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d)))) - b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}}, - {'a': True, 'b': {'c': 3.0}}]) + b5_dict = np.array( + [{ + 'a': False, + 'b': { + 'c': 2.0, + 'd': 1.0 + } + }, { + 'a': True, + 'b': { + 'c': 3.0 + } + }] + ) b5 = Batch(b5_dict) assert b5.a[0] == np.array(False) and b5.a[1] == np.array(True) assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0)) @@ -335,15 +352,16 @@ def test_batch_cat_and_stack(): test = Batch.stack([b1, b2], axis=-1) assert test.a.is_empty() assert test.b.is_empty() - assert np.allclose(test.common.c, - np.stack([b1.common.c, b2.common.c], axis=-1)) + assert np.allclose(test.common.c, np.stack([b1.common.c, b2.common.c], axis=-1)) b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5))) b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5))) test = Batch.stack([b1, b2]) - ans = Batch(a=np.stack([b1.a, np.zeros((4, 4))]), - b=torch.stack([torch.zeros(4, 6), b2.b]), - common=Batch(c=np.stack([b1.common.c, b2.common.c]))) + ans = Batch( + a=np.stack([b1.a, np.zeros((4, 4))]), + b=torch.stack([torch.zeros(4, 6), b2.b]), + common=Batch(c=np.stack([b1.common.c, b2.common.c])) + ) assert np.allclose(test.a, ans.a) assert torch.allclose(test.b, ans.b) assert np.allclose(test.common.c, ans.common.c) @@ -369,8 +387,8 @@ def test_batch_over_batch_to_torch(): batch = Batch( a=np.float64(1.0), b=Batch( - c=np.ones((1,), dtype=np.float32), - d=torch.ones((1,), dtype=torch.float64) + c=np.ones((1, ), dtype=np.float32), + d=torch.ones((1, ), dtype=torch.float64) ) ) batch.b.__dict__['e'] = 1 # bypass the check @@ -397,8 +415,8 @@ def test_utils_to_torch_numpy(): batch = Batch( a=np.float64(1.0), b=Batch( - c=np.ones((1,), dtype=np.float32), - d=torch.ones((1,), dtype=torch.float64) + c=np.ones((1, ), dtype=np.float32), + d=torch.ones((1, ), dtype=torch.float64) ) ) a_torch_float = to_torch(batch.a, dtype=torch.float32) @@ -464,8 +482,7 @@ def test_utils_to_torch_numpy(): def test_batch_pickle(): - batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), - np=np.zeros([3, 4])) + batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4])) batch_pk = pickle.loads(pickle.dumps(batch)) assert batch.obs.a == batch_pk.obs.a assert torch.all(batch.obs.c == batch_pk.obs.c) @@ -473,7 +490,7 @@ def test_batch_pickle(): def test_batch_from_to_numpy_without_copy(): - batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,)))) + batch = Batch(a=np.ones((1, )), b=Batch(c=np.ones((1, )))) a_mem_addr_orig = batch.a.__array_interface__['data'][0] c_mem_addr_orig = batch.b.c.__array_interface__['data'][0] batch.to_torch() @@ -517,19 +534,35 @@ def test_batch_copy(): def test_batch_empty(): - b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}}, - {'a': True, 'b': {'c': 3.0}}]) + b5_dict = np.array( + [{ + 'a': False, + 'b': { + 'c': 2.0, + 'd': 1.0 + } + }, { + 'a': True, + 'b': { + 'c': 3.0 + } + }] + ) b5 = Batch(b5_dict) b5[1] = Batch.empty(b5[0]) assert np.allclose(b5.a, [False, False]) assert np.allclose(b5.b.c, [2, 0]) assert np.allclose(b5.b.d, [1, 0]) - data = Batch(a=[False, True], 
- b={'c': np.array([2., 'st'], dtype=object), - 'd': [1, None], - 'e': [2., float('nan')]}, - c=np.array([1, 3, 4], dtype=int), - t=torch.tensor([4, 5, 6, 7.])) + data = Batch( + a=[False, True], + b={ + 'c': np.array([2., 'st'], dtype=object), + 'd': [1, None], + 'e': [2., float('nan')] + }, + c=np.array([1, 3, 4], dtype=int), + t=torch.tensor([4, 5, 6, 7.]) + ) data[-1] = Batch.empty(data[1]) assert np.allclose(data.c, [1, 3, 0]) assert np.allclose(data.a, [False, False]) @@ -550,9 +583,9 @@ def test_batch_empty(): def test_batch_standard_compatibility(): - batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), - b=Batch(), - c=np.array([5.0, 6.0])) + batch = Batch( + a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=Batch(), c=np.array([5.0, 6.0]) + ) batch_mean = np.mean(batch) assert isinstance(batch_mean, Batch) assert sorted(batch_mean.keys()) == ['a', 'b', 'c'] diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 39e5badd5..c1568c75d 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -1,18 +1,23 @@ import os -import h5py -import torch import pickle -import pytest import tempfile -import numpy as np from timeit import timeit -from tianshou.data.utils.converter import to_hdf5 -from tianshou.data import Batch, SegmentTree, ReplayBuffer -from tianshou.data import PrioritizedReplayBuffer -from tianshou.data import VectorReplayBuffer, CachedReplayBuffer -from tianshou.data import PrioritizedVectorReplayBuffer +import h5py +import numpy as np +import pytest +import torch +from tianshou.data import ( + Batch, + CachedReplayBuffer, + PrioritizedReplayBuffer, + PrioritizedVectorReplayBuffer, + ReplayBuffer, + SegmentTree, + VectorReplayBuffer, +) +from tianshou.data.utils.converter import to_hdf5 if __name__ == '__main__': from env import MyTestEnv @@ -29,8 +34,9 @@ def test_replaybuffer(size=10, bufsize=20): action_list = [1] * 5 + [0] * 10 + [1] * 10 for i, a in enumerate(action_list): obs_next, rew, done, info = env.step(a) - buf.add(Batch(obs=obs, act=[a], rew=rew, - done=done, obs_next=obs_next, info=info)) + buf.add( + Batch(obs=obs, act=[a], rew=rew, done=done, obs_next=obs_next, info=info) + ) obs = obs_next assert len(buf) == min(bufsize, i + 1) assert buf.act.dtype == int @@ -43,8 +49,20 @@ def test_replaybuffer(size=10, bufsize=20): # neg bsz should return empty index assert b.sample_indices(-1).tolist() == [] ptr, ep_rew, ep_len, ep_idx = b.add( - Batch(obs=1, act=1, rew=1, done=1, obs_next='str', - info={'a': 3, 'b': {'c': 5.0}})) + Batch( + obs=1, + act=1, + rew=1, + done=1, + obs_next='str', + info={ + 'a': 3, + 'b': { + 'c': 5.0 + } + } + ) + ) assert b.obs[0] == 1 assert b.done[0] assert b.obs_next[0] == 'str' @@ -54,13 +72,24 @@ def test_replaybuffer(size=10, bufsize=20): assert np.all(b.info.a[1:] == 0) assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == float assert np.all(b.info.b.c[1:] == 0.0) - assert ptr.shape == (1,) and ptr[0] == 0 - assert ep_rew.shape == (1,) and ep_rew[0] == 1 - assert ep_len.shape == (1,) and ep_len[0] == 1 - assert ep_idx.shape == (1,) and ep_idx[0] == 0 + assert ptr.shape == (1, ) and ptr[0] == 0 + assert ep_rew.shape == (1, ) and ep_rew[0] == 1 + assert ep_len.shape == (1, ) and ep_len[0] == 1 + assert ep_idx.shape == (1, ) and ep_idx[0] == 0 # test extra keys pop up, the buffer should handle it dynamically - batch = Batch(obs=2, act=2, rew=2, done=0, obs_next="str2", - info={"a": 4, "d": {"e": -np.inf}}) + batch = Batch( + obs=2, + act=2, + rew=2, + done=0, + obs_next="str2", + info={ + "a": 4, + "d": { + "e": 
-np.inf + } + } + ) b.add(batch) info_keys = ["a", "b", "d"] assert set(b.info.keys()) == set(info_keys) @@ -71,10 +100,10 @@ def test_replaybuffer(size=10, bufsize=20): batch.info.e = np.zeros([1, 4]) batch = Batch.stack([batch]) ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0]) - assert ptr.shape == (1,) and ptr[0] == 2 - assert ep_rew.shape == (1,) and ep_rew[0] == 4 - assert ep_len.shape == (1,) and ep_len[0] == 2 - assert ep_idx.shape == (1,) and ep_idx[0] == 1 + assert ptr.shape == (1, ) and ptr[0] == 2 + assert ep_rew.shape == (1, ) and ep_rew[0] == 4 + assert ep_len.shape == (1, ) and ep_len[0] == 2 + assert ep_idx.shape == (1, ) and ep_idx[0] == 1 assert set(b.info.keys()) == set(info_keys + ["e"]) assert b.info.e.shape == (b.maxsize, 1, 4) with pytest.raises(IndexError): @@ -92,14 +121,22 @@ def test_ignore_obs_next(size=10): # Issue 82 buf = ReplayBuffer(size, ignore_obs_next=True) for i in range(size): - buf.add(Batch(obs={'mask1': np.array([i, 1, 1, 0, 0]), - 'mask2': np.array([i + 4, 0, 1, 0, 0]), - 'mask': i}, - act={'act_id': i, - 'position_id': i + 3}, - rew=i, - done=i % 3 == 0, - info={'if': i})) + buf.add( + Batch( + obs={ + 'mask1': np.array([i, 1, 1, 0, 0]), + 'mask2': np.array([i + 4, 0, 1, 0, 0]), + 'mask': i + }, + act={ + 'act_id': i, + 'position_id': i + 3 + }, + rew=i, + done=i % 3 == 0, + info={'if': i} + ) + ) indices = np.arange(len(buf)) orig = np.arange(len(buf)) data = buf[indices] @@ -113,15 +150,25 @@ def test_ignore_obs_next(size=10): data = buf[indices] data2 = buf[indices] assert np.allclose(data.obs_next.mask, data2.obs_next.mask) - assert np.allclose(data.obs_next.mask, np.array([ - [0, 0, 0, 0], [1, 1, 1, 2], [1, 1, 2, 3], [1, 1, 2, 3], - [4, 4, 4, 5], [4, 4, 5, 6], [4, 4, 5, 6], - [7, 7, 7, 8], [7, 7, 8, 9], [7, 7, 8, 9]])) + assert np.allclose( + data.obs_next.mask, + np.array( + [ + [0, 0, 0, 0], [1, 1, 1, 2], [1, 1, 2, 3], [1, 1, 2, 3], [4, 4, 4, 5], + [4, 4, 5, 6], [4, 4, 5, 6], [7, 7, 7, 8], [7, 7, 8, 9], [7, 7, 8, 9] + ] + ) + ) assert np.allclose(data.info['if'], data2.info['if']) - assert np.allclose(data.info['if'], np.array([ - [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], - [4, 4, 4, 4], [4, 4, 4, 5], [4, 4, 5, 6], - [7, 7, 7, 7], [7, 7, 7, 8], [7, 7, 8, 9]])) + assert np.allclose( + data.info['if'], + np.array( + [ + [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], [4, 4, 4, 4], + [4, 4, 4, 5], [4, 4, 5, 6], [7, 7, 7, 7], [7, 7, 7, 8], [7, 7, 8, 9] + ] + ) + ) assert data.obs_next @@ -131,20 +178,30 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True) buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True) obs = env.reset(1) - for i in range(16): + for _ in range(16): obs_next, rew, done, info = env.step(1) buf.add(Batch(obs=obs, act=1, rew=rew, done=done, info=info)) buf2.add(Batch(obs=obs, act=1, rew=rew, done=done, info=info)) - buf3.add(Batch(obs=[obs, obs, obs], act=1, rew=rew, - done=done, obs_next=[obs, obs], info=info)) + buf3.add( + Batch( + obs=[obs, obs, obs], + act=1, + rew=rew, + done=done, + obs_next=[obs, obs], + info=info + ) + ) obs = obs_next if done: obs = env.reset(1) indices = np.arange(len(buf)) - assert np.allclose(buf.get(indices, 'obs')[..., 0], [ - [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], - [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], - [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]]) + assert np.allclose( + buf.get(indices, 'obs')[..., 0], [ + [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], [1, 1, 
1, 1], [1, 1, 1, 2], + [1, 1, 2, 3], [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1] + ] + ) assert np.allclose(buf.get(indices, 'obs'), buf3.get(indices, 'obs')) assert np.allclose(buf.get(indices, 'obs'), buf3.get(indices, 'obs_next')) _, indices = buf2.sample(0) @@ -165,8 +222,15 @@ def test_priortized_replaybuffer(size=32, bufsize=15): action_list = [1] * 5 + [0] * 10 + [1] * 10 for i, a in enumerate(action_list): obs_next, rew, done, info = env.step(a) - batch = Batch(obs=obs, act=a, rew=rew, done=done, obs_next=obs_next, - info=info, policy=np.random.randn() - 0.5) + batch = Batch( + obs=obs, + act=a, + rew=rew, + done=done, + obs_next=obs_next, + info=info, + policy=np.random.randn() - 0.5 + ) batch_stack = Batch.stack([batch, batch, batch]) buf.add(Batch.stack([batch]), buffer_ids=[0]) buf2.add(batch_stack, buffer_ids=[0, 1, 2]) @@ -179,12 +243,12 @@ def test_priortized_replaybuffer(size=32, bufsize=15): assert len(buf) == min(bufsize, i + 1) assert len(buf2) == min(bufsize, 3 * (i + 1)) # check single buffer's data - assert buf.info.key.shape == (buf.maxsize,) + assert buf.info.key.shape == (buf.maxsize, ) assert buf.rew.dtype == float assert buf.done.dtype == bool data, indices = buf.sample(len(buf) // 2) buf.update_weight(indices, -data.weight / 2) - assert np.allclose(buf.weight[indices], np.abs(-data.weight / 2) ** buf._alpha) + assert np.allclose(buf.weight[indices], np.abs(-data.weight / 2)**buf._alpha) # check multi buffer's data assert np.allclose(buf2[np.arange(buf2.maxsize)].weight, 1) batch, indices = buf2.sample(10) @@ -200,8 +264,15 @@ def test_update(): buf1 = ReplayBuffer(4, stack_num=2) buf2 = ReplayBuffer(4, stack_num=2) for i in range(5): - buf1.add(Batch(obs=np.array([i]), act=float(i), rew=i * i, - done=i % 2 == 0, info={'incident': 'found'})) + buf1.add( + Batch( + obs=np.array([i]), + act=float(i), + rew=i * i, + done=i % 2 == 0, + info={'incident': 'found'} + ) + ) assert len(buf1) > len(buf2) buf2.update(buf1) assert len(buf1) == len(buf2) @@ -242,11 +313,10 @@ def test_segtree(): naive[index] = value tree[index] = value assert np.allclose(realop(naive), tree.reduce()) - for i in range(10): + for _ in range(10): left = np.random.randint(actual_len) right = np.random.randint(left + 1, actual_len + 1) - assert np.allclose(realop(naive[left:right]), - tree.reduce(left, right)) + assert np.allclose(realop(naive[left:right]), tree.reduce(left, right)) # large test actual_len = 16384 tree = SegmentTree(actual_len) @@ -257,11 +327,10 @@ def test_segtree(): naive[index] = value tree[index] = value assert np.allclose(realop(naive), tree.reduce()) - for i in range(10): + for _ in range(10): left = np.random.randint(actual_len) right = np.random.randint(left + 1, actual_len + 1) - assert np.allclose(realop(naive[left:right]), - tree.reduce(left, right)) + assert np.allclose(realop(naive[left:right]), tree.reduce(left, right)) # test prefix-sum-idx actual_len = 8 @@ -280,8 +349,9 @@ def test_segtree(): assert naive[:index].sum() <= scalar <= naive[:index + 1].sum() tree = SegmentTree(10) tree[np.arange(3)] = np.array([0.1, 0, 0.1]) - assert np.allclose(tree.get_prefix_sum_idx( - np.array([0, .1, .1 + 1e-6, .2 - 1e-6])), [0, 0, 2, 2]) + assert np.allclose( + tree.get_prefix_sum_idx(np.array([0, .1, .1 + 1e-6, .2 - 1e-6])), [0, 0, 2, 2] + ) with pytest.raises(AssertionError): tree.get_prefix_sum_idx(.2) # test large prefix-sum-idx @@ -321,8 +391,15 @@ def test_pickle(): for i in range(4): vbuf.add(Batch(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0)) for i in range(5): 
- pbuf.add(Batch(obs=Batch(index=np.array([i])), - act=2, rew=rew, done=0, info=np.random.rand())) + pbuf.add( + Batch( + obs=Batch(index=np.array([i])), + act=2, + rew=rew, + done=0, + info=np.random.rand() + ) + ) # save & load _vbuf = pickle.loads(pickle.dumps(vbuf)) _pbuf = pickle.loads(pickle.dumps(pbuf)) @@ -330,8 +407,9 @@ def test_pickle(): assert len(_pbuf) == len(pbuf) and np.allclose(_pbuf.act, pbuf.act) # make sure the meta var is identical assert _vbuf.stack_num == vbuf.stack_num - assert np.allclose(_pbuf.weight[np.arange(len(_pbuf))], - pbuf.weight[np.arange(len(pbuf))]) + assert np.allclose( + _pbuf.weight[np.arange(len(_pbuf))], pbuf.weight[np.arange(len(pbuf))] + ) def test_hdf5(): @@ -349,7 +427,13 @@ def test_hdf5(): 'act': i, 'rew': np.array([1, 2]), 'done': i % 3 == 2, - 'info': {"number": {"n": i, "t": info_t}, 'extra': None}, + 'info': { + "number": { + "n": i, + "t": info_t + }, + 'extra': None + }, } buffers["array"].add(Batch(kwargs)) buffers["prioritized"].add(Batch(kwargs)) @@ -377,10 +461,8 @@ def test_hdf5(): assert isinstance(buffers[k].get(0, "info"), Batch) assert isinstance(_buffers[k].get(0, "info"), Batch) for k in ["array"]: - assert np.all( - buffers[k][:].info.number.n == _buffers[k][:].info.number.n) - assert np.all( - buffers[k][:].info.extra == _buffers[k][:].info.extra) + assert np.all(buffers[k][:].info.number.n == _buffers[k][:].info.number.n) + assert np.all(buffers[k][:].info.extra == _buffers[k][:].info.extra) # raise exception when value cannot be pickled data = {"not_supported": lambda x: x * x} @@ -423,15 +505,16 @@ def test_replaybuffermanager(): indices_next = buf.next(indices) assert np.allclose(indices_next, indices), indices_next data = np.array([0, 0, 0, 0]) - buf.add(Batch(obs=data, act=data, rew=data, done=data), - buffer_ids=[0, 1, 2, 3]) - buf.add(Batch(obs=data, act=data, rew=data, done=1 - data), - buffer_ids=[0, 1, 2, 3]) + buf.add(Batch(obs=data, act=data, rew=data, done=data), buffer_ids=[0, 1, 2, 3]) + buf.add( + Batch(obs=data, act=data, rew=data, done=1 - data), buffer_ids=[0, 1, 2, 3] + ) assert len(buf) == 12 - buf.add(Batch(obs=data, act=data, rew=data, done=data), - buffer_ids=[0, 1, 2, 3]) - buf.add(Batch(obs=data, act=data, rew=data, done=[0, 1, 0, 1]), - buffer_ids=[0, 1, 2, 3]) + buf.add(Batch(obs=data, act=data, rew=data, done=data), buffer_ids=[0, 1, 2, 3]) + buf.add( + Batch(obs=data, act=data, rew=data, done=[0, 1, 0, 1]), + buffer_ids=[0, 1, 2, 3] + ) assert len(buf) == 20 indices = buf.sample_indices(120000) assert np.bincount(indices).min() >= 5000 @@ -439,44 +522,135 @@ def test_replaybuffermanager(): indices = buf.sample_indices(0) assert np.allclose(indices, np.arange(len(buf))) # check the actual data stored in buf._meta - assert np.allclose(buf.done, [ - 0, 0, 1, 0, 0, - 0, 0, 1, 0, 1, - 1, 0, 1, 0, 0, - 1, 0, 1, 0, 1, - ]) - assert np.allclose(buf.prev(indices), [ - 0, 0, 1, 3, 3, - 5, 5, 6, 8, 8, - 10, 11, 11, 13, 13, - 15, 16, 16, 18, 18, - ]) - assert np.allclose(buf.next(indices), [ - 1, 2, 2, 4, 4, - 6, 7, 7, 9, 9, - 10, 12, 12, 14, 14, - 15, 17, 17, 19, 19, - ]) + assert np.allclose( + buf.done, [ + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 1, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 1, + 0, + 1, + ] + ) + assert np.allclose( + buf.prev(indices), [ + 0, + 0, + 1, + 3, + 3, + 5, + 5, + 6, + 8, + 8, + 10, + 11, + 11, + 13, + 13, + 15, + 16, + 16, + 18, + 18, + ] + ) + assert np.allclose( + buf.next(indices), [ + 1, + 2, + 2, + 4, + 4, + 6, + 7, + 7, + 9, + 9, + 10, + 12, + 12, + 14, + 14, + 15, + 
17, + 17, + 19, + 19, + ] + ) assert np.allclose(buf.unfinished_index(), [4, 14]) ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=[1], act=[1], rew=[1], done=[1]), buffer_ids=[2]) + Batch(obs=[1], act=[1], rew=[1], done=[1]), buffer_ids=[2] + ) assert np.all(ep_len == [3]) and np.all(ep_rew == [1]) assert np.all(ptr == [10]) and np.all(ep_idx == [13]) assert np.allclose(buf.unfinished_index(), [4]) indices = list(sorted(buf.sample_indices(0))) assert np.allclose(indices, np.arange(len(buf))) - assert np.allclose(buf.prev(indices), [ - 0, 0, 1, 3, 3, - 5, 5, 6, 8, 8, - 14, 11, 11, 13, 13, - 15, 16, 16, 18, 18, - ]) - assert np.allclose(buf.next(indices), [ - 1, 2, 2, 4, 4, - 6, 7, 7, 9, 9, - 10, 12, 12, 14, 10, - 15, 17, 17, 19, 19, - ]) + assert np.allclose( + buf.prev(indices), [ + 0, + 0, + 1, + 3, + 3, + 5, + 5, + 6, + 8, + 8, + 14, + 11, + 11, + 13, + 13, + 15, + 16, + 16, + 18, + 18, + ] + ) + assert np.allclose( + buf.next(indices), [ + 1, + 2, + 2, + 4, + 4, + 6, + 7, + 7, + 9, + 9, + 10, + 12, + 12, + 14, + 10, + 15, + 17, + 17, + 19, + 19, + ] + ) # corner case: list, int and -1 assert buf.prev(-1) == buf.prev([buf.maxsize - 1])[0] assert buf.next(-1) == buf.next([buf.maxsize - 1])[0] @@ -493,7 +667,8 @@ def test_cachedbuffer(): assert buf.sample_indices(0).tolist() == [] # check the normal function/usage/storage in CachedReplayBuffer ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=[1], act=[1], rew=[1], done=[0]), buffer_ids=[1]) + Batch(obs=[1], act=[1], rew=[1], done=[0]), buffer_ids=[1] + ) obs = np.zeros(buf.maxsize) obs[15] = 1 indices = buf.sample_indices(0) @@ -504,7 +679,8 @@ def test_cachedbuffer(): assert np.all(ep_len == [0]) and np.all(ep_rew == [0.0]) assert np.all(ptr == [15]) and np.all(ep_idx == [15]) ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=[2], act=[2], rew=[2], done=[1]), buffer_ids=[3]) + Batch(obs=[2], act=[2], rew=[2], done=[1]), buffer_ids=[3] + ) obs[[0, 25]] = 2 indices = buf.sample_indices(0) assert np.allclose(indices, [0, 15]) @@ -516,8 +692,8 @@ def test_cachedbuffer(): assert np.allclose(buf.unfinished_index(), [15]) assert np.allclose(buf.sample_indices(0), [0, 15]) ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=[3, 4], act=[3, 4], rew=[3, 4], done=[0, 1]), - buffer_ids=[3, 1]) + Batch(obs=[3, 4], act=[3, 4], rew=[3, 4], done=[0, 1]), buffer_ids=[3, 1] + ) assert np.all(ep_len == [0, 2]) and np.all(ep_rew == [0, 5.0]) assert np.all(ptr == [25, 2]) and np.all(ep_idx == [25, 1]) obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3] @@ -540,16 +716,35 @@ def test_cachedbuffer(): buf.add(Batch(obs=data, act=data, rew=rew, done=[1, 1, 1, 1])) buf.add(Batch(obs=data, act=data, rew=rew, done=[0, 0, 0, 0])) ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=data, act=data, rew=rew, done=[0, 1, 0, 1])) + Batch(obs=data, act=data, rew=rew, done=[0, 1, 0, 1]) + ) assert np.all(ptr == [1, -1, 11, -1]) and np.all(ep_idx == [0, -1, 10, -1]) assert np.all(ep_len == [0, 2, 0, 2]) assert np.all(ep_rew == [data, data + 2, data, data + 2]) - assert np.allclose(buf.done, [ - 0, 0, 1, 0, 0, - 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, - 0, 1, 0, 0, 0, - ]) + assert np.allclose( + buf.done, [ + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + ] + ) indices = buf.sample_indices(0) assert np.allclose(indices, [0, 1, 10, 11]) assert np.allclose(buf.prev(indices), [0, 0, 10, 10]) @@ -564,14 +759,16 @@ def test_multibuf_stack(): env = MyTestEnv(size) # test if CachedReplayBuffer can handle stack_num + ignore_obs_next buf4 
= CachedReplayBuffer( - ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True), - cached_num, size) + ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True), cached_num, + size + ) # test if CachedReplayBuffer can handle corner case: # buffer + stack_num + ignore_obs_next + sample_avail buf5 = CachedReplayBuffer( - ReplayBuffer(bufsize, stack_num=stack_num, - ignore_obs_next=True, sample_avail=True), - cached_num, size) + ReplayBuffer( + bufsize, stack_num=stack_num, ignore_obs_next=True, sample_avail=True + ), cached_num, size + ) obs = env.reset(1) for i in range(18): obs_next, rew, done, info = env.step(1) @@ -581,8 +778,14 @@ def test_multibuf_stack(): done_list = [done] * cached_num obs_next_list = -obs_list info_list = [info] * cached_num - batch = Batch(obs=obs_list, act=act_list, rew=rew_list, - done=done_list, obs_next=obs_next_list, info=info_list) + batch = Batch( + obs=obs_list, + act=act_list, + rew=rew_list, + done=done_list, + obs_next=obs_next_list, + info=info_list + ) buf5.add(batch) buf4.add(batch) assert np.all(buf4.obs == buf5.obs) @@ -591,35 +794,105 @@ def test_multibuf_stack(): if done: obs = env.reset(1) # check the `add` order is correct - assert np.allclose(buf4.obs.reshape(-1), [ - 12, 13, 14, 4, 6, 7, 8, 9, 11, # main_buffer - 1, 2, 3, 4, 0, # cached_buffer[0] - 6, 7, 8, 9, 0, # cached_buffer[1] - 11, 12, 13, 14, 0, # cached_buffer[2] - ]), buf4.obs - assert np.allclose(buf4.done, [ - 0, 0, 1, 1, 0, 0, 0, 1, 0, # main_buffer - 0, 0, 0, 1, 0, # cached_buffer[0] - 0, 0, 0, 1, 0, # cached_buffer[1] - 0, 0, 0, 1, 0, # cached_buffer[2] - ]), buf4.done + assert np.allclose( + buf4.obs.reshape(-1), + [ + 12, + 13, + 14, + 4, + 6, + 7, + 8, + 9, + 11, # main_buffer + 1, + 2, + 3, + 4, + 0, # cached_buffer[0] + 6, + 7, + 8, + 9, + 0, # cached_buffer[1] + 11, + 12, + 13, + 14, + 0, # cached_buffer[2] + ] + ), buf4.obs + assert np.allclose( + buf4.done, + [ + 0, + 0, + 1, + 1, + 0, + 0, + 0, + 1, + 0, # main_buffer + 0, + 0, + 0, + 1, + 0, # cached_buffer[0] + 0, + 0, + 0, + 1, + 0, # cached_buffer[1] + 0, + 0, + 0, + 1, + 0, # cached_buffer[2] + ] + ), buf4.done assert np.allclose(buf4.unfinished_index(), [10, 15, 20]) indices = sorted(buf4.sample_indices(0)) assert np.allclose(indices, list(range(bufsize)) + [9, 10, 14, 15, 19, 20]) - assert np.allclose(buf4[indices].obs[..., 0], [ - [11, 11, 11, 12], [11, 11, 12, 13], [11, 12, 13, 14], - [4, 4, 4, 4], [6, 6, 6, 6], [6, 6, 6, 7], - [6, 6, 7, 8], [6, 7, 8, 9], [11, 11, 11, 11], - [1, 1, 1, 1], [1, 1, 1, 2], [6, 6, 6, 6], - [6, 6, 6, 7], [11, 11, 11, 11], [11, 11, 11, 12], - ]) - assert np.allclose(buf4[indices].obs_next[..., 0], [ - [11, 11, 12, 13], [11, 12, 13, 14], [11, 12, 13, 14], - [4, 4, 4, 4], [6, 6, 6, 7], [6, 6, 7, 8], - [6, 7, 8, 9], [6, 7, 8, 9], [11, 11, 11, 12], - [1, 1, 1, 2], [1, 1, 1, 2], [6, 6, 6, 7], - [6, 6, 6, 7], [11, 11, 11, 12], [11, 11, 11, 12], - ]) + assert np.allclose( + buf4[indices].obs[..., 0], [ + [11, 11, 11, 12], + [11, 11, 12, 13], + [11, 12, 13, 14], + [4, 4, 4, 4], + [6, 6, 6, 6], + [6, 6, 6, 7], + [6, 6, 7, 8], + [6, 7, 8, 9], + [11, 11, 11, 11], + [1, 1, 1, 1], + [1, 1, 1, 2], + [6, 6, 6, 6], + [6, 6, 6, 7], + [11, 11, 11, 11], + [11, 11, 11, 12], + ] + ) + assert np.allclose( + buf4[indices].obs_next[..., 0], [ + [11, 11, 12, 13], + [11, 12, 13, 14], + [11, 12, 13, 14], + [4, 4, 4, 4], + [6, 6, 6, 7], + [6, 6, 7, 8], + [6, 7, 8, 9], + [6, 7, 8, 9], + [11, 11, 11, 12], + [1, 1, 1, 2], + [1, 1, 1, 2], + [6, 6, 6, 7], + [6, 6, 6, 7], + [11, 11, 11, 12], + [11, 11, 
11, 12], + ] + ) indices = buf5.sample_indices(0) assert np.allclose(sorted(indices), [2, 7]) assert np.all(np.isin(buf5.sample_indices(100), indices)) @@ -632,12 +905,24 @@ def test_multibuf_stack(): batch, _ = buf5.sample(0) # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next buf6 = CachedReplayBuffer( - ReplayBuffer(bufsize, stack_num=stack_num, - save_only_last_obs=True, ignore_obs_next=True), - cached_num, size) + ReplayBuffer( + bufsize, + stack_num=stack_num, + save_only_last_obs=True, + ignore_obs_next=True + ), cached_num, size + ) obs = np.random.rand(size, 4, 84, 84) - buf6.add(Batch(obs=[obs[2], obs[0]], act=[1, 1], rew=[0, 0], done=[0, 1], - obs_next=[obs[3], obs[1]]), buffer_ids=[1, 2]) + buf6.add( + Batch( + obs=[obs[2], obs[0]], + act=[1, 1], + rew=[0, 0], + done=[0, 1], + obs_next=[obs[3], obs[1]] + ), + buffer_ids=[1, 2] + ) assert buf6.obs.shape == (buf6.maxsize, 84, 84) assert np.allclose(buf6.obs[0], obs[0, -1]) assert np.allclose(buf6.obs[14], obs[2, -1]) @@ -660,12 +945,20 @@ def test_multibuf_hdf5(): 'act': i, 'rew': np.array([1, 2]), 'done': i % 3 == 2, - 'info': {"number": {"n": i, "t": info_t}, 'extra': None}, + 'info': { + "number": { + "n": i, + "t": info_t + }, + 'extra': None + }, } - buffers["vector"].add(Batch.stack([kwargs, kwargs, kwargs]), - buffer_ids=[0, 1, 2]) - buffers["cached"].add(Batch.stack([kwargs, kwargs, kwargs]), - buffer_ids=[0, 1, 2]) + buffers["vector"].add( + Batch.stack([kwargs, kwargs, kwargs]), buffer_ids=[0, 1, 2] + ) + buffers["cached"].add( + Batch.stack([kwargs, kwargs, kwargs]), buffer_ids=[0, 1, 2] + ) # save paths = {} @@ -696,7 +989,12 @@ def test_multibuf_hdf5(): 'act': 5, 'rew': np.array([2, 1]), 'done': False, - 'info': {"number": {"n": i}, 'Timelimit.truncate': True}, + 'info': { + "number": { + "n": i + }, + 'Timelimit.truncate': True + }, } buffers[k].add(Batch.stack([kwargs, kwargs, kwargs, kwargs])) act = np.zeros(buffers[k].maxsize) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 79d1430b8..61bd5a6fc 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -1,17 +1,19 @@ -import tqdm -import pytest import numpy as np +import pytest +import tqdm from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import BasePolicy -from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.data import Batch, Collector, AsyncCollector from tianshou.data import ( - ReplayBuffer, + AsyncCollector, + Batch, + CachedReplayBuffer, + Collector, PrioritizedReplayBuffer, + ReplayBuffer, VectorReplayBuffer, - CachedReplayBuffer, ) +from tianshou.env import DummyVectorEnv, SubprocVectorEnv +from tianshou.policy import BasePolicy if __name__ == '__main__': from env import MyTestEnv, NXEnv @@ -20,6 +22,7 @@ class MyPolicy(BasePolicy): + def __init__(self, dict_state=False, need_state=True): """ :param bool dict_state: if the observation of the environment is a dict @@ -44,6 +47,7 @@ def learn(self): class Logger: + def __init__(self, writer): self.cnt = 0 self.writer = writer @@ -56,8 +60,7 @@ def preprocess_fn(self, **kwargs): info = kwargs['info'] info.rew = kwargs['rew'] if 'key' in info.keys(): - self.writer.add_scalar( - 'key', np.mean(info.key), global_step=self.cnt) + self.writer.add_scalar('key', np.mean(info.key), global_step=self.cnt) self.cnt += 1 return Batch(info=info) else: @@ -91,13 +94,12 @@ def test_collector(): c0.collect(n_episode=3) assert len(c0.buffer) == 8 assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 
0, 0]) - assert np.allclose(c0.buffer[:].obs_next[..., 0], - [1, 2, 1, 2, 1, 2, 1, 2]) + assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) c0.collect(n_step=3, random=True) c1 = Collector( - policy, venv, - VectorReplayBuffer(total_size=100, buffer_num=4), - logger.preprocess_fn) + policy, venv, VectorReplayBuffer(total_size=100, buffer_num=4), + logger.preprocess_fn + ) c1.collect(n_step=8) obs = np.zeros(100) obs[[0, 1, 25, 26, 50, 51, 75, 76]] = [0, 1, 0, 1, 0, 1, 0, 1] @@ -108,13 +110,15 @@ def test_collector(): assert len(c1.buffer) == 16 obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4] assert np.allclose(c1.buffer.obs[:, 0], obs) - assert np.allclose(c1.buffer[:].obs_next[..., 0], - [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]) + assert np.allclose( + c1.buffer[:].obs_next[..., 0], + [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5] + ) c1.collect(n_episode=4, random=True) c2 = Collector( - policy, dum, - VectorReplayBuffer(total_size=100, buffer_num=4), - logger.preprocess_fn) + policy, dum, VectorReplayBuffer(total_size=100, buffer_num=4), + logger.preprocess_fn + ) c2.collect(n_episode=7) obs1 = obs.copy() obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2] @@ -139,10 +143,10 @@ def test_collector(): # test NXEnv for obs_type in ["array", "object"]: - envs = SubprocVectorEnv([ - lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]]) - c3 = Collector(policy, envs, - VectorReplayBuffer(total_size=100, buffer_num=4)) + envs = SubprocVectorEnv( + [lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]] + ) + c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) c3.collect(n_step=6) assert c3.buffer.obs.dtype == object @@ -151,23 +155,23 @@ def test_collector_with_async(): env_lens = [2, 3, 4, 5] writer = SummaryWriter('log/async_collector') logger = Logger(writer) - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) - for i in env_lens] + env_fns = [ + lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens + ] venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) policy = MyPolicy() bufsize = 60 c1 = AsyncCollector( - policy, venv, - VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4), - logger.preprocess_fn) + policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4), + logger.preprocess_fn + ) ptr = [0, 0, 0, 0] for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): result = c1.collect(n_episode=n_episode) assert result["n/ep"] >= n_episode # check buffer data, obs and obs_next, env_id - for i, count in enumerate( - np.bincount(result["lens"], minlength=6)[2:]): + for i, count in enumerate(np.bincount(result["lens"], minlength=6)[2:]): env_len = i + 2 total = env_len * count indices = np.arange(ptr[i], ptr[i] + total) % bufsize @@ -176,8 +180,7 @@ def test_collector_with_async(): buf = c1.buffer.buffers[i] assert np.all(buf.info.env_id[indices] == i) assert np.all(buf.obs[indices].reshape(count, env_len) == seq) - assert np.all(buf.obs_next[indices].reshape( - count, env_len) == seq + 1) + assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) # test async n_step, for now the buffer should be full of data for n_step in tqdm.trange(1, 15, desc="test async n_step"): result = c1.collect(n_step=n_step) @@ -196,21 +199,21 @@ def test_collector_with_async(): def test_collector_with_dict_state(): env = MyTestEnv(size=5, sleep=0, dict_state=True) policy = MyPolicy(dict_state=True) - c0 = Collector(policy, env, 
ReplayBuffer(size=100), - Logger.single_preprocess_fn) + c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn) c0.collect(n_step=3) c0.collect(n_episode=2) assert len(c0.buffer) == 10 - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True) - for i in [2, 3, 4, 5]] + env_fns = [ + lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True) for i in [2, 3, 4, 5] + ] envs = DummyVectorEnv(env_fns) envs.seed(666) obs = envs.reset() assert not np.isclose(obs[0]['rand'], obs[1]['rand']) c1 = Collector( - policy, envs, - VectorReplayBuffer(total_size=100, buffer_num=4), - Logger.single_preprocess_fn) + policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4), + Logger.single_preprocess_fn + ) c1.collect(n_step=12) result = c1.collect(n_episode=8) assert result['n/ep'] == 8 @@ -221,25 +224,104 @@ def test_collector_with_dict_state(): c0.buffer.update(c1.buffer) assert len(c0.buffer) in [42, 43] if len(c0.buffer) == 42: - assert np.all(c0.buffer[:].obs.index[..., 0] == [ - 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, - 0, 1, 0, 1, 0, 1, 0, 1, - 0, 1, 2, 0, 1, 2, - 0, 1, 2, 3, 0, 1, 2, 3, - 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, - ]), c0.buffer[:].obs.index[..., 0] + assert np.all( + c0.buffer[:].obs.index[..., 0] == [ + 0, + 1, + 2, + 3, + 4, + 0, + 1, + 2, + 3, + 4, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 2, + 0, + 1, + 2, + 0, + 1, + 2, + 3, + 0, + 1, + 2, + 3, + 0, + 1, + 2, + 3, + 4, + 0, + 1, + 2, + 3, + 4, + ] + ), c0.buffer[:].obs.index[..., 0] else: - assert np.all(c0.buffer[:].obs.index[..., 0] == [ - 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, - 0, 1, 0, 1, 0, 1, - 0, 1, 2, 0, 1, 2, 0, 1, 2, - 0, 1, 2, 3, 0, 1, 2, 3, - 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, - ]), c0.buffer[:].obs.index[..., 0] + assert np.all( + c0.buffer[:].obs.index[..., 0] == [ + 0, + 1, + 2, + 3, + 4, + 0, + 1, + 2, + 3, + 4, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 2, + 0, + 1, + 2, + 0, + 1, + 2, + 0, + 1, + 2, + 3, + 0, + 1, + 2, + 3, + 0, + 1, + 2, + 3, + 4, + 0, + 1, + 2, + 3, + 4, + ] + ), c0.buffer[:].obs.index[..., 0] c2 = Collector( - policy, envs, - VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), - Logger.single_preprocess_fn) + policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), + Logger.single_preprocess_fn + ) c2.collect(n_episode=10) batch, _ = c2.buffer.sample(10) @@ -247,20 +329,18 @@ def test_collector_with_dict_state(): def test_collector_with_ma(): env = MyTestEnv(size=5, sleep=0, ma_rew=4) policy = MyPolicy() - c0 = Collector(policy, env, ReplayBuffer(size=100), - Logger.single_preprocess_fn) + c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn) # n_step=3 will collect a full episode r = c0.collect(n_step=3)['rews'] assert len(r) == 0 r = c0.collect(n_episode=2)['rews'] assert r.shape == (2, 4) and np.all(r == 1) - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) - for i in [2, 3, 4, 5]] + env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]] envs = DummyVectorEnv(env_fns) c1 = Collector( - policy, envs, - VectorReplayBuffer(total_size=100, buffer_num=4), - Logger.single_preprocess_fn) + policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4), + Logger.single_preprocess_fn + ) r = c1.collect(n_step=12)['rews'] assert r.shape == (2, 4) and np.all(r == 1), r r = c1.collect(n_episode=8)['rews'] @@ -271,26 +351,101 @@ def test_collector_with_ma(): assert len(c0.buffer) in [42, 43] if len(c0.buffer) == 42: rew = [ - 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, - 0, 1, 0, 1, 0, 1, 0, 1, - 0, 0, 1, 
0, 0, 1, - 0, 0, 0, 1, 0, 0, 0, 1, - 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, ] else: rew = [ - 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, - 0, 1, 0, 1, 0, 1, - 0, 0, 1, 0, 0, 1, 0, 0, 1, - 0, 0, 0, 1, 0, 0, 0, 1, - 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, ] assert np.all(c0.buffer[:].rew == [[x] * 4 for x in rew]) assert np.all(c0.buffer[:].done == rew) c2 = Collector( - policy, envs, - VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), - Logger.single_preprocess_fn) + policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), + Logger.single_preprocess_fn + ) r = c2.collect(n_episode=10)['rews'] assert r.shape == (10, 4) and np.all(r == 1) batch, _ = c2.buffer.sample(10) @@ -326,22 +481,23 @@ def test_collector_with_atari_setting(): c2 = Collector( policy, env, - ReplayBuffer(size=100, ignore_obs_next=True, save_only_last_obs=True)) + ReplayBuffer(size=100, ignore_obs_next=True, save_only_last_obs=True) + ) c2.collect(n_step=8) assert c2.buffer.obs.shape == (100, 84, 84) obs = np.zeros_like(c2.buffer.obs) obs[np.arange(8)] = reference_obs[[0, 1, 2, 3, 4, 0, 1, 2], -1] assert np.all(c2.buffer.obs == obs) - assert np.allclose(c2.buffer[:].obs_next, - reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1]) + assert np.allclose( + c2.buffer[:].obs_next, reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1] + ) # atari multi buffer - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, array_state=True) - for i in [2, 3, 4, 5]] + env_fns = [ + lambda x=i: MyTestEnv(size=x, sleep=0, array_state=True) for i in [2, 3, 4, 5] + ] envs = DummyVectorEnv(env_fns) - c3 = Collector( - policy, envs, - VectorReplayBuffer(total_size=100, buffer_num=4)) + c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) c3.collect(n_step=12) result = c3.collect(n_episode=9) assert result["n/ep"] == 9 and result["n/st"] == 23 @@ -360,8 +516,14 @@ def test_collector_with_atari_setting(): assert np.all(obs_next == c3.buffer.obs_next) c4 = Collector( policy, envs, - VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4, - ignore_obs_next=True, save_only_last_obs=True)) + VectorReplayBuffer( + total_size=100, + buffer_num=4, + stack_num=4, + ignore_obs_next=True, + save_only_last_obs=True + ) + ) c4.collect(n_step=12) result = c4.collect(n_episode=9) assert result["n/ep"] == 9 and result["n/st"] == 23 @@ -374,12 +536,45 @@ def test_collector_with_atari_setting(): obs[np.arange(75, 85)] = slice_obs[[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]] assert np.all(c4.buffer.obs == obs) obs_next = np.zeros([len(c4.buffer), 4, 84, 84]) - ref_index = np.array([ - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 2, 2, 1, 2, 2, 1, 2, 2, - 1, 2, 3, 3, 1, 2, 3, 3, - 1, 2, 3, 4, 4, 1, 2, 3, 4, 4, - ]) + ref_index = np.array( + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 2, + 1, + 2, + 2, + 1, + 2, + 2, + 1, + 2, + 3, + 3, + 1, + 2, + 3, + 3, + 1, + 2, + 3, + 4, + 4, + 1, + 2, + 3, + 4, + 4, + ] + ) obs_next[:, -1] = slice_obs[ref_index] ref_index -= 1 ref_index[ref_index < 0] = 0 @@ -392,20 +587,25 @@ def test_collector_with_atari_setting(): obs_next[:, -4] = slice_obs[ref_index] assert np.all(obs_next == c4.buffer[:].obs_next) - buf = 
ReplayBuffer(100, stack_num=4, ignore_obs_next=True, - save_only_last_obs=True) + buf = ReplayBuffer(100, stack_num=4, ignore_obs_next=True, save_only_last_obs=True) c5 = Collector(policy, envs, CachedReplayBuffer(buf, 4, 10)) result_ = c5.collect(n_step=12) assert len(buf) == 5 and len(c5.buffer) == 12 result = c5.collect(n_episode=9) assert result["n/ep"] == 9 and result["n/st"] == 23 assert len(buf) == 35 - assert np.all(buf.obs[:len(buf)] == slice_obs[[ - 0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 3, 4, - 0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4]]) - assert np.all(buf[:].obs_next[:, -1] == slice_obs[[ - 1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 3, 4, 4, - 1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 2, 1, 2, 3, 4, 4]]) + assert np.all( + buf.obs[:len(buf)] == slice_obs[[ + 0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 3, 4, 0, 1, 0, 1, 2, 0, 1, 0, 1, + 2, 3, 0, 1, 2, 0, 1, 2, 3, 4 + ]] + ) + assert np.all( + buf[:].obs_next[:, -1] == slice_obs[[ + 1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 3, 4, 4, 1, 1, 1, 2, 2, 1, 1, 1, 2, + 3, 3, 1, 2, 2, 1, 2, 3, 4, 4 + ]] + ) assert len(buf) == len(c5.buffer) # test buffer=None diff --git a/test/base/test_env.py b/test/base/test_env.py index cc1dc84c7..b9d6489b6 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -1,10 +1,11 @@ import sys import time + import numpy as np from gym.spaces.discrete import Discrete + from tianshou.data import Batch -from tianshou.env import DummyVectorEnv, SubprocVectorEnv, \ - ShmemVectorEnv, RayVectorEnv +from tianshou.env import DummyVectorEnv, RayVectorEnv, ShmemVectorEnv, SubprocVectorEnv if __name__ == '__main__': from env import MyTestEnv, NXEnv @@ -24,17 +25,14 @@ def recurse_comp(a, b): try: if isinstance(a, np.ndarray): if a.dtype == object: - return np.array( - [recurse_comp(m, n) for m, n in zip(a, b)]).all() + return np.array([recurse_comp(m, n) for m, n in zip(a, b)]).all() else: return np.allclose(a, b) elif isinstance(a, (list, tuple)): - return np.array( - [recurse_comp(m, n) for m, n in zip(a, b)]).all() + return np.array([recurse_comp(m, n) for m, n in zip(a, b)]).all() elif isinstance(a, dict): - return np.array( - [recurse_comp(a[k], b[k]) for k in a.keys()]).all() - except(Exception): + return np.array([recurse_comp(a[k], b[k]) for k in a.keys()]).all() + except (Exception): return False @@ -75,7 +73,7 @@ def test_async_env(size=10000, num=8, sleep=0.1): # truncate env_ids with the first terms # typically len(env_ids) == len(A) == len(action), except for the # last batch when actions are not enough - env_ids = env_ids[: len(action)] + env_ids = env_ids[:len(action)] spent_time = time.time() - spent_time Batch.cat(o) v.close() @@ -85,10 +83,12 @@ def test_async_env(size=10000, num=8, sleep=0.1): def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7): - env_fns = [lambda: MyTestEnv(size=size, sleep=sleep * 2), - lambda: MyTestEnv(size=size, sleep=sleep * 3), - lambda: MyTestEnv(size=size, sleep=sleep * 5), - lambda: MyTestEnv(size=size, sleep=sleep * 7)] + env_fns = [ + lambda: MyTestEnv(size=size, sleep=sleep * 2), + lambda: MyTestEnv(size=size, sleep=sleep * 3), + lambda: MyTestEnv(size=size, sleep=sleep * 5), + lambda: MyTestEnv(size=size, sleep=sleep * 7) + ] test_cls = [SubprocVectorEnv, ShmemVectorEnv] if has_ray(): test_cls += [RayVectorEnv] @@ -113,8 +113,10 @@ def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7): t = time.time() - t ids = Batch(info).env_id print(ids, t) - if not (len(ids) == len(res) and np.allclose(sorted(ids), res) - and (t < timeout) == 
(len(res) == num - 1)): + if not ( + len(ids) == len(res) and np.allclose(sorted(ids), res) and + (t < timeout) == (len(res) == num - 1) + ): pass_check = 0 break total_pass += pass_check @@ -138,7 +140,7 @@ def test_vecenv(size=10, num=8, sleep=0.001): v.seed(0) action_list = [1] * 5 + [0] * 10 + [1] * 20 o = [v.reset() for v in venv] - for i, a in enumerate(action_list): + for a in action_list: o = [] for v in venv: A, B, C, D = v.step([a] * num) @@ -150,6 +152,7 @@ def test_vecenv(size=10, num=8, sleep=0.001): continue for info in infos: assert recurse_comp(infos[0], info) + if __name__ == '__main__': t = [0] * len(venv) for i, e in enumerate(venv): @@ -162,6 +165,7 @@ def test_vecenv(size=10, num=8, sleep=0.001): t[i] = time.time() - t[i] for i, v in enumerate(venv): print(f'{type(v)}: {t[i]:.6f}s') + for v in venv: assert v.size == list(range(size, size + num)) assert v.env_num == num @@ -172,8 +176,9 @@ def test_vecenv(size=10, num=8, sleep=0.001): def test_env_obs(): for obs_type in ["array", "object"]: - envs = SubprocVectorEnv([ - lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]]) + envs = SubprocVectorEnv( + [lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]] + ) obs = envs.reset() assert obs.dtype == object obs = envs.step([1, 1, 1, 1])[0] diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index b670d65e9..54b438507 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -1,17 +1,19 @@ # see issue #322 for detail -import gym import copy -import numpy as np from collections import Counter -from torch.utils.data import Dataset, DataLoader, DistributedSampler -from tianshou.policy import BasePolicy -from tianshou.data import Collector, Batch +import gym +import numpy as np +from torch.utils.data import DataLoader, Dataset, DistributedSampler + +from tianshou.data import Batch, Collector from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv +from tianshou.policy import BasePolicy class DummyDataset(Dataset): + def __init__(self, length): self.length = length self.episodes = [3 * i % 5 + 1 for i in range(self.length)] @@ -25,6 +27,7 @@ def __len__(self): class FiniteEnv(gym.Env): + def __init__(self, dataset, num_replicas, rank): self.dataset = dataset self.num_replicas = num_replicas @@ -32,7 +35,8 @@ def __init__(self, dataset, num_replicas, rank): self.loader = DataLoader( dataset, sampler=DistributedSampler(dataset, num_replicas, rank), - batch_size=None) + batch_size=None + ) self.iterator = None def reset(self): @@ -54,6 +58,7 @@ def step(self, action): class FiniteVectorEnv(BaseVectorEnv): + def __init__(self, env_fns, **kwargs): super().__init__(env_fns, **kwargs) self._alive_env_ids = set() @@ -79,6 +84,7 @@ def _get_default_obs(self): def _get_default_info(self): return copy.deepcopy(self._default_info) + # END def reset(self, id=None): @@ -147,6 +153,7 @@ class FiniteSubprocVectorEnv(FiniteVectorEnv, SubprocVectorEnv): class AnyPolicy(BasePolicy): + def forward(self, batch, state=None): return Batch(act=np.stack([1] * len(batch))) @@ -159,6 +166,7 @@ def _finite_env_factory(dataset, num_replicas, rank): class MetricTracker: + def __init__(self): self.counter = Counter() self.finished = set() @@ -179,30 +187,32 @@ def validate(self): def test_finite_dummy_vector_env(): dataset = DummyDataset(100) - envs = FiniteSubprocVectorEnv([ - _finite_env_factory(dataset, 5, i) for i in range(5)]) + envs = FiniteSubprocVectorEnv( + [_finite_env_factory(dataset, 5, i) for i in range(5)] + ) policy = 
AnyPolicy() test_collector = Collector(policy, envs, exploration_noise=True) for _ in range(3): envs.tracker = MetricTracker() try: - test_collector.collect(n_step=10 ** 18) + test_collector.collect(n_step=10**18) except StopIteration: envs.tracker.validate() def test_finite_subproc_vector_env(): dataset = DummyDataset(100) - envs = FiniteSubprocVectorEnv([ - _finite_env_factory(dataset, 5, i) for i in range(5)]) + envs = FiniteSubprocVectorEnv( + [_finite_env_factory(dataset, 5, i) for i in range(5)] + ) policy = AnyPolicy() test_collector = Collector(policy, envs, exploration_noise=True) for _ in range(3): envs.tracker = MetricTracker() try: - test_collector.collect(n_step=10 ** 18) + test_collector.collect(n_step=10**18) except StopIteration: envs.tracker.validate() diff --git a/test/base/test_returns.py b/test/base/test_returns.py index a104eb674..3adcdaf5c 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -1,9 +1,10 @@ -import torch -import numpy as np from timeit import timeit -from tianshou.policy import BasePolicy +import numpy as np +import torch + from tianshou.data import Batch, ReplayBuffer, to_numpy +from tianshou.policy import BasePolicy def compute_episodic_return_base(batch, gamma): @@ -24,8 +25,12 @@ def test_episodic_returns(size=2560): batch = Batch( done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]), rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]), - info=Batch({'TimeLimit.truncated': - np.array([False, False, False, False, False, True, False, False])}) + info=Batch( + { + 'TimeLimit.truncated': + np.array([False, False, False, False, False, True, False, False]) + } + ) ) for b in batch: b.obs = b.act = 1 @@ -65,28 +70,40 @@ def test_episodic_returns(size=2560): buf.add(b) v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3]) returns, _ = fn(batch, buf, buf.sample_indices(0), v, gamma=0.99, gae_lambda=0.95) - ground_truth = np.array([ - 454.8344, 376.1143, 291.298, 200., - 464.5610, 383.1085, 295.387, 201., - 474.2876, 390.1027, 299.476, 202.]) + ground_truth = np.array( + [ + 454.8344, 376.1143, 291.298, 200., 464.5610, 383.1085, 295.387, 201., + 474.2876, 390.1027, 299.476, 202. + ] + ) assert np.allclose(returns, ground_truth) buf.reset() batch = Batch( done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]), rew=np.array([101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]), - info=Batch({'TimeLimit.truncated': - np.array([False, False, False, True, False, False, - False, True, False, False, False, False])}) + info=Batch( + { + 'TimeLimit.truncated': + np.array( + [ + False, False, False, True, False, False, False, True, False, + False, False, False + ] + ) + } + ) ) for b in batch: b.obs = b.act = 1 buf.add(b) v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3]) returns, _ = fn(batch, buf, buf.sample_indices(0), v, gamma=0.99, gae_lambda=0.95) - ground_truth = np.array([ - 454.0109, 375.2386, 290.3669, 199.01, - 462.9138, 381.3571, 293.5248, 199.02, - 474.2876, 390.1027, 299.476, 202.]) + ground_truth = np.array( + [ + 454.0109, 375.2386, 290.3669, 199.01, 462.9138, 381.3571, 293.5248, 199.02, + 474.2876, 390.1027, 299.476, 202. 
+ ] + ) assert np.allclose(returns, ground_truth) if __name__ == '__main__': @@ -129,16 +146,17 @@ def compute_nstep_return_base(nstep, gamma, buffer, indices): real_step_n = nstep for n in range(nstep): idx = (indices[i] + n) % buf_len - r += buffer.rew[idx] * gamma ** n + r += buffer.rew[idx] * gamma**n if buffer.done[idx]: - if not (hasattr(buffer, 'info') and - buffer.info['TimeLimit.truncated'][idx]): + if not ( + hasattr(buffer, 'info') and buffer.info['TimeLimit.truncated'][idx] + ): flag = True real_step_n = n + 1 break if not flag: idx = (indices[i] + real_step_n - 1) % buf_len - r += to_numpy(target_q_fn(buffer, idx)) * gamma ** real_step_n + r += to_numpy(target_q_fn(buffer, idx)) * gamma**real_step_n returns[i] = r return returns @@ -152,89 +170,128 @@ def test_nstep_returns(size=10000): # rew: [11, 12, 3, 4, 5, 6, 7, 8, 9, 10] # done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0] # test nstep = 1 - returns = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn, gamma=.1, n_step=1 - ).pop('returns').reshape(-1)) + returns = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn, gamma=.1, n_step=1 + ).pop('returns').reshape(-1) + ) assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12]) r_ = compute_nstep_return_base(1, .1, buf, indices) assert np.allclose(returns, r_), (r_, returns) - returns_multidim = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=1 - ).pop('returns')) + returns_multidim = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=1 + ).pop('returns') + ) assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 2 - returns = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn, gamma=.1, n_step=2 - ).pop('returns').reshape(-1)) + returns = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn, gamma=.1, n_step=2 + ).pop('returns').reshape(-1) + ) assert np.allclose(returns, [3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12]) r_ = compute_nstep_return_base(2, .1, buf, indices) assert np.allclose(returns, r_) - returns_multidim = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=2 - ).pop('returns')) + returns_multidim = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=2 + ).pop('returns') + ) assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 10 - returns = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn, gamma=.1, n_step=10 - ).pop('returns').reshape(-1)) + returns = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn, gamma=.1, n_step=10 + ).pop('returns').reshape(-1) + ) assert np.allclose(returns, [3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12]) r_ = compute_nstep_return_base(10, .1, buf, indices) assert np.allclose(returns, r_) - returns_multidim = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=10 - ).pop('returns')) + returns_multidim = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=10 + ).pop('returns') + ) assert np.allclose(returns_multidim, returns[:, np.newaxis]) def test_nstep_returns_with_timelimit(size=10000): buf = ReplayBuffer(10) for i in range(12): - buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3, - 
info={"TimeLimit.truncated": i == 3})) + buf.add( + Batch( + obs=0, + act=0, + rew=i + 1, + done=i % 4 == 3, + info={"TimeLimit.truncated": i == 3} + ) + ) batch, indices = buf.sample(0) assert np.allclose(indices, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1]) # rew: [11, 12, 3, 4, 5, 6, 7, 8, 9, 10] # done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0] # test nstep = 1 - returns = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn, gamma=.1, n_step=1 - ).pop('returns').reshape(-1)) + returns = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn, gamma=.1, n_step=1 + ).pop('returns').reshape(-1) + ) assert np.allclose(returns, [2.6, 3.6, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12]) r_ = compute_nstep_return_base(1, .1, buf, indices) assert np.allclose(returns, r_), (r_, returns) - returns_multidim = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=1 - ).pop('returns')) + returns_multidim = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=1 + ).pop('returns') + ) assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 2 - returns = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn, gamma=.1, n_step=2 - ).pop('returns').reshape(-1)) + returns = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn, gamma=.1, n_step=2 + ).pop('returns').reshape(-1) + ) assert np.allclose(returns, [3.36, 3.6, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12]) r_ = compute_nstep_return_base(2, .1, buf, indices) assert np.allclose(returns, r_) - returns_multidim = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=2 - ).pop('returns')) + returns_multidim = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=2 + ).pop('returns') + ) assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 10 - returns = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn, gamma=.1, n_step=10 - ).pop('returns').reshape(-1)) - assert np.allclose(returns, [3.36, 3.6, 5.678, 6.78, - 7.8, 8, 10.122, 11.22, 12.2, 12]) + returns = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn, gamma=.1, n_step=10 + ).pop('returns').reshape(-1) + ) + assert np.allclose( + returns, [3.36, 3.6, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12] + ) r_ = compute_nstep_return_base(10, .1, buf, indices) assert np.allclose(returns, r_) - returns_multidim = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=10 - ).pop('returns')) + returns_multidim = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=10 + ).pop('returns') + ) assert np.allclose(returns_multidim, returns[:, np.newaxis]) if __name__ == '__main__': buf = ReplayBuffer(size) for i in range(int(size * 1.5)): - buf.add(Batch(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0, - info={"TimeLimit.truncated": i % 33 == 0})) + buf.add( + Batch( + obs=0, + act=0, + rew=i + 1, + done=np.random.randint(3) == 0, + info={"TimeLimit.truncated": i % 33 == 0} + ) + ) batch, indices = buf.sample(256) def vanilla(): @@ -242,7 +299,8 @@ def vanilla(): def optimized(): return BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn, gamma=.1, n_step=3) + batch, buf, indices, target_q_fn, gamma=.1, n_step=3 + ) cnt = 3000 
print('nstep vanilla', timeit(vanilla, setup=vanilla, number=cnt)) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index d0e1cefb3..38bf5d40e 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -1,9 +1,9 @@ -import torch import numpy as np +import torch -from tianshou.utils.net.common import MLP, Net -from tianshou.utils import MovAvg, RunningMeanStd from tianshou.exploration import GaussianNoise, OUNoise +from tianshou.utils import MovAvg, RunningMeanStd +from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic @@ -20,14 +20,14 @@ def test_moving_average(): stat = MovAvg(10) assert np.allclose(stat.get(), 0) assert np.allclose(stat.mean(), 0) - assert np.allclose(stat.std() ** 2, 0) + assert np.allclose(stat.std()**2, 0) stat.add(torch.tensor([1])) stat.add(np.array([2])) stat.add([3, 4]) stat.add(5.) assert np.allclose(stat.get(), 3) assert np.allclose(stat.mean(), 3) - assert np.allclose(stat.std() ** 2, 2) + assert np.allclose(stat.std()**2, 2) def test_rms(): @@ -55,23 +55,36 @@ def test_net(): action_shape = (5, ) data = torch.rand([bsz, *state_shape]) expect_output_shape = [bsz, *action_shape] - net = Net(state_shape, action_shape, hidden_sizes=[128, 128], - norm_layer=torch.nn.LayerNorm, activation=None) + net = Net( + state_shape, + action_shape, + hidden_sizes=[128, 128], + norm_layer=torch.nn.LayerNorm, + activation=None + ) assert list(net(data)[0].shape) == expect_output_shape assert str(net).count("LayerNorm") == 2 assert str(net).count("ReLU") == 0 Q_param = V_param = {"hidden_sizes": [128, 128]} - net = Net(state_shape, action_shape, hidden_sizes=[128, 128], - dueling_param=(Q_param, V_param)) + net = Net( + state_shape, + action_shape, + hidden_sizes=[128, 128], + dueling_param=(Q_param, V_param) + ) assert list(net(data)[0].shape) == expect_output_shape # concat - net = Net(state_shape, action_shape, hidden_sizes=[128], - concat=True) + net = Net(state_shape, action_shape, hidden_sizes=[128], concat=True) data = torch.rand([bsz, np.prod(state_shape) + np.prod(action_shape)]) expect_output_shape = [bsz, 128] assert list(net(data)[0].shape) == expect_output_shape - net = Net(state_shape, action_shape, hidden_sizes=[128], - concat=True, dueling_param=(Q_param, V_param)) + net = Net( + state_shape, + action_shape, + hidden_sizes=[128], + concat=True, + dueling_param=(Q_param, V_param) + ) assert list(net(data)[0].shape) == expect_output_shape # recurrent actor/critic data = torch.rand([bsz, *state_shape]).flatten(1) diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 3dbeb9091..e88c0869c 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.exploration import GaussianNoise from tianshou.policy import DDPGPolicy +from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import offpolicy_trainer -from tianshou.exploration import GaussianNoise -from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import Actor, Critic @@ -39,8 +40,8 @@ def get_args(): 
parser.add_argument('--rew-norm', action="store_true", default=False) parser.add_argument('--n-step', type=int, default=3) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) args = parser.parse_known_args()[0] return args @@ -56,36 +57,51 @@ def test_ddpg(args=get_args()): # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) - actor = Actor(net, args.action_shape, max_action=args.max_action, - device=args.device).to(args.device) + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = Actor( + net, args.action_shape, max_action=args.max_action, device=args.device + ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, concat=True, device=args.device) + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic = Critic(net, device=args.device).to(args.device) critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) policy = DDPGPolicy( - actor, actor_optim, critic, critic_optim, - tau=args.tau, gamma=args.gamma, + actor, + actor_optim, + critic, + critic_optim, + tau=args.tau, + gamma=args.gamma, exploration_noise=GaussianNoise(sigma=args.exploration_noise), reward_normalization=args.rew_norm, - estimation_step=args.n_step, action_space=env.action_space) + estimation_step=args.n_step, + action_space=env.action_space + ) # collector train_collector = Collector( - policy, train_envs, + policy, + train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), - exploration_noise=True) + exploration_noise=True + ) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'ddpg') @@ -100,10 +116,19 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, - update_per_step=args.update_per_step, stop_fn=stop_fn, - save_fn=save_fn, logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index ad758897f..1a9e82623 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -1,19 +1,20 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch import nn -from torch.utils.tensorboard import SummaryWriter from torch.distributions 
import Independent, Normal +from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import NPGPolicy +from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import ActorProb, Critic @@ -27,8 +28,9 @@ def get_args(): parser.add_argument('--epoch', type=int, default=5) parser.add_argument('--step-per-epoch', type=int, default=50000) parser.add_argument('--step-per-collect', type=int, default=2048) - parser.add_argument('--repeat-per-collect', type=int, - default=2) # theoretically it should be 1 + parser.add_argument( + '--repeat-per-collect', type=int, default=2 + ) # theoretically it should be 1 parser.add_argument('--batch-size', type=int, default=99999) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) parser.add_argument('--training-num', type=int, default=16) @@ -36,8 +38,8 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) # npg special parser.add_argument('--gae-lambda', type=float, default=0.95) parser.add_argument('--rew-norm', type=int, default=1) @@ -58,23 +60,40 @@ def test_npg(args=get_args()): # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) - actor = ActorProb(net, args.action_shape, max_action=args.max_action, - unbounded=True, device=args.device).to(args.device) - critic = Critic(Net( - args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device, - activation=nn.Tanh), device=args.device).to(args.device) + net = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) + actor = ActorProb( + net, + args.action_shape, + max_action=args.max_action, + unbounded=True, + device=args.device + ).to(args.device) + critic = Critic( + Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + activation=nn.Tanh + ), + device=args.device + ).to(args.device) # orthogonal initialization for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): @@ -88,7 +107,10 @@ def dist(*logits): return Independent(Normal(*logits), 1) policy = NPGPolicy( - actor, critic, optim, dist, + actor, + critic, + optim, + dist, discount_factor=args.gamma, reward_normalization=args.rew_norm, advantage_normalization=args.norm_adv, @@ -96,11 +118,12 @@ def dist(*logits): action_space=env.action_space, 
optim_critic_iters=args.optim_critic_iters, actor_step_size=args.actor_step_size, - deterministic_eval=True) + deterministic_eval=True + ) # collector train_collector = Collector( - policy, train_envs, - VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)) + ) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'npg') @@ -115,10 +138,19 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, - step_per_collect=args.step_per_collect, stop_fn=stop_fn, save_fn=save_fn, - logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + step_per_collect=args.step_per_collect, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 09c73fdd5..473222816 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np -from torch.utils.tensorboard import SummaryWriter +import torch from torch.distributions import Independent, Normal +from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import PPOPolicy +from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import ActorProb, Critic @@ -34,8 +35,8 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) # ppo special parser.add_argument('--vf-coef', type=float, default=0.25) parser.add_argument('--ent-coef', type=float, default=0.0) @@ -63,30 +64,34 @@ def test_ppo(args=get_args()): # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) - actor = ActorProb(net, args.action_shape, max_action=args.max_action, - device=args.device).to(args.device) - critic = Critic(Net( - args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device - ), device=args.device).to(args.device) + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = ActorProb( + net, args.action_shape, max_action=args.max_action, device=args.device + ).to(args.device) + critic = Critic( + Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device), + device=args.device + ).to(args.device) # orthogonal initialization for m in set(actor.modules()).union(critic.modules()): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) optim = torch.optim.Adam( - set(actor.parameters()).union(critic.parameters()), lr=args.lr) + set(actor.parameters()).union(critic.parameters()), lr=args.lr + ) # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward @@ -94,7 +99,10 @@ def dist(*logits): return Independent(Normal(*logits), 1) policy = PPOPolicy( - actor, critic, optim, dist, + actor, + critic, + optim, + dist, discount_factor=args.gamma, max_grad_norm=args.max_grad_norm, eps_clip=args.eps_clip, @@ -107,11 +115,12 @@ def dist(*logits): # dual clip cause monotonically increasing log_std :) value_clip=args.value_clip, gae_lambda=args.gae_lambda, - action_space=env.action_space) + action_space=env.action_space + ) # collector train_collector = Collector( - policy, train_envs, - VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)) + ) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'ppo') @@ -126,10 +135,12 @@ def stop_fn(mean_rewards): def save_checkpoint_fn(epoch, env_step, gradient_step): # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html - torch.save({ - 'model': policy.state_dict(), - 'optim': optim.state_dict(), - }, os.path.join(log_path, 'checkpoint.pth')) + torch.save( + { + 'model': policy.state_dict(), + 'optim': optim.state_dict(), + }, os.path.join(log_path, 'checkpoint.pth') + ) if args.resume: # load from existing checkpoint @@ -145,11 +156,21 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, args.step_per_epoch, - args.repeat_per_collect, args.test_num, 
args.batch_size, - episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, - logger=logger, resume_from_log=args.resume, - save_checkpoint_fn=save_checkpoint_fn) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + episode_per_collect=args.episode_per_collect, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + resume_from_log=args.resume, + save_checkpoint_fn=save_checkpoint_fn + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 5b6a79492..da20290ec 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -1,17 +1,18 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.utils.net.common import Net +from tianshou.policy import ImitationPolicy, SACPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer -from tianshou.policy import SACPolicy, ImitationPolicy +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, ActorProb, Critic @@ -34,10 +35,10 @@ def get_args(): parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=128) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128]) - parser.add_argument('--imitation-hidden-sizes', type=int, - nargs='*', default=[128, 128]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) + parser.add_argument( + '--imitation-hidden-sizes', type=int, nargs='*', default=[128, 128] + ) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') @@ -45,8 +46,8 @@ def get_args(): parser.add_argument('--rew-norm', action="store_true", default=False) parser.add_argument('--n-step', type=int, default=3) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) args = parser.parse_known_args()[0] return args @@ -62,29 +63,43 @@ def test_sac_with_il(args=get_args()): # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) - actor = ActorProb(net, args.action_shape, max_action=args.max_action, - device=args.device, unbounded=True).to(args.device) + net = Net(args.state_shape, 
hidden_sizes=args.hidden_sizes, device=args.device) + actor = ActorProb( + net, + args.action_shape, + max_action=args.max_action, + device=args.device, + unbounded=True + ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net_c1 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) + net_c1 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic1 = Critic(net_c1, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - net_c2 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) @@ -95,15 +110,26 @@ def test_sac_with_il(args=get_args()): args.alpha = (target_entropy, log_alpha, alpha_optim) policy = SACPolicy( - actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - tau=args.tau, gamma=args.gamma, alpha=args.alpha, + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + tau=args.tau, + gamma=args.gamma, + alpha=args.alpha, reward_normalization=args.rew_norm, - estimation_step=args.n_step, action_space=env.action_space) + estimation_step=args.n_step, + action_space=env.action_space + ) # collector train_collector = Collector( - policy, train_envs, + policy, + train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), - exploration_noise=True) + exploration_noise=True + ) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log @@ -119,10 +145,19 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, - update_per_step=args.update_per_step, stop_fn=stop_fn, - save_fn=save_fn, logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': @@ -140,23 +175,41 @@ def stop_fn(mean_rewards): if args.task == 'Pendulum-v0': env.spec.reward_threshold = -300 # lower the goal net = Actor( - Net(args.state_shape, hidden_sizes=args.imitation_hidden_sizes, - device=args.device), - args.action_shape, max_action=args.max_action, device=args.device + Net( + args.state_shape, + hidden_sizes=args.imitation_hidden_sizes, + device=args.device + ), + args.action_shape, + max_action=args.max_action, + device=args.device ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) il_policy = ImitationPolicy( - net, optim, action_space=env.action_space, - action_scaling=True, action_bound_method="clip") + net, + optim, + action_space=env.action_space, + action_scaling=True, + action_bound_method="clip" + ) il_test_collector = Collector( il_policy, DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) ) train_collector.reset() result = offpolicy_trainer( - il_policy, train_collector, il_test_collector, args.epoch, - args.il_step_per_epoch, 
args.step_per_collect, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger) + il_policy, + train_collector, + il_test_collector, + args.epoch, + args.il_step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 8bae1edfa..2e3ef7ba7 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.exploration import GaussianNoise from tianshou.policy import TD3Policy +from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import offpolicy_trainer -from tianshou.exploration import GaussianNoise -from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import Actor, Critic @@ -42,8 +43,8 @@ def get_args(): parser.add_argument('--rew-norm', action="store_true", default=False) parser.add_argument('--n-step', type=int, default=3) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) args = parser.parse_known_args()[0] return args @@ -59,46 +60,65 @@ def test_td3(args=get_args()): # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) - actor = Actor(net, args.action_shape, max_action=args.max_action, - device=args.device).to(args.device) + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = Actor( + net, args.action_shape, max_action=args.max_action, device=args.device + ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net_c1 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) + net_c1 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic1 = Critic(net_c1, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - net_c2 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) 
policy = TD3Policy( - actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - tau=args.tau, gamma=args.gamma, + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + tau=args.tau, + gamma=args.gamma, exploration_noise=GaussianNoise(sigma=args.exploration_noise), policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, reward_normalization=args.rew_norm, estimation_step=args.n_step, - action_space=env.action_space) + action_space=env.action_space + ) # collector train_collector = Collector( - policy, train_envs, + policy, + train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), - exploration_noise=True) + exploration_noise=True + ) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log @@ -114,10 +134,19 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, - update_per_step=args.update_per_step, stop_fn=stop_fn, - save_fn=save_fn, logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 65535fd50..4a4206f5f 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -1,19 +1,20 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch import nn -from torch.utils.tensorboard import SummaryWriter from torch.distributions import Independent, Normal +from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import TRPOPolicy +from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import ActorProb, Critic @@ -27,8 +28,9 @@ def get_args(): parser.add_argument('--epoch', type=int, default=5) parser.add_argument('--step-per-epoch', type=int, default=50000) parser.add_argument('--step-per-collect', type=int, default=2048) - parser.add_argument('--repeat-per-collect', type=int, - default=2) # theoretically it should be 1 + parser.add_argument( + '--repeat-per-collect', type=int, default=2 + ) # theoretically it should be 1 parser.add_argument('--batch-size', type=int, default=99999) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) parser.add_argument('--training-num', type=int, default=16) @@ -36,8 +38,8 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) # trpo special parser.add_argument('--gae-lambda', type=float, default=0.95) parser.add_argument('--rew-norm', type=int, default=1) @@ -61,23 +63,40 @@ def test_trpo(args=get_args()): # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) - actor = ActorProb(net, args.action_shape, max_action=args.max_action, - unbounded=True, device=args.device).to(args.device) - critic = Critic(Net( - args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device, - activation=nn.Tanh), device=args.device).to(args.device) + net = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) + actor = ActorProb( + net, + args.action_shape, + max_action=args.max_action, + unbounded=True, + device=args.device + ).to(args.device) + critic = Critic( + Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + activation=nn.Tanh + ), + device=args.device + ).to(args.device) # orthogonal initialization for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): @@ -91,7 +110,10 @@ def dist(*logits): return Independent(Normal(*logits), 1) policy = TRPOPolicy( - actor, critic, optim, dist, + actor, + critic, + optim, + dist, discount_factor=args.gamma, reward_normalization=args.rew_norm, advantage_normalization=args.norm_adv, @@ -100,11 +122,12 @@ def dist(*logits): optim_critic_iters=args.optim_critic_iters, max_kl=args.max_kl, backtrack_coeff=args.backtrack_coeff, - max_backtracks=args.max_backtracks) + max_backtracks=args.max_backtracks + ) # collector train_collector = Collector( - policy, train_envs, - VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)) + ) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'trpo') @@ -119,10 +142,19 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, - step_per_collect=args.step_per_collect, stop_fn=stop_fn, save_fn=save_fn, - logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + step_per_collect=args.step_per_collect, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index d11ce360c..745295826 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch 
import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv +from tianshou.policy import A2CPolicy, ImitationPolicy +from tianshou.trainer import offpolicy_trainer, onpolicy_trainer +from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.discrete import Actor, Critic -from tianshou.policy import A2CPolicy, ImitationPolicy -from tianshou.trainer import onpolicy_trainer, offpolicy_trainer def get_args(): @@ -31,17 +32,15 @@ def get_args(): parser.add_argument('--update-per-step', type=float, default=1 / 16) parser.add_argument('--repeat-per-collect', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[64, 64]) - parser.add_argument('--imitation-hidden-sizes', type=int, - nargs='*', default=[128]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) + parser.add_argument('--imitation-hidden-sizes', type=int, nargs='*', default=[128]) parser.add_argument('--training-num', type=int, default=16) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) # a2c special parser.add_argument('--vf-coef', type=float, default=0.5) parser.add_argument('--ent-coef', type=float, default=0.0) @@ -60,33 +59,42 @@ def test_a2c_with_il(args=get_args()): # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device) critic = Critic(net, device=args.device).to(args.device) optim = torch.optim.Adam( - set(actor.parameters()).union(critic.parameters()), lr=args.lr) + set(actor.parameters()).union(critic.parameters()), lr=args.lr + ) dist = torch.distributions.Categorical policy = A2CPolicy( - actor, critic, optim, dist, - discount_factor=args.gamma, gae_lambda=args.gae_lambda, - vf_coef=args.vf_coef, ent_coef=args.ent_coef, - max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm, - action_space=env.action_space) + actor, + critic, + optim, + dist, + discount_factor=args.gamma, + gae_lambda=args.gae_lambda, + vf_coef=args.vf_coef, + ent_coef=args.ent_coef, + max_grad_norm=args.max_grad_norm, + reward_normalization=args.rew_norm, + action_space=env.action_space + ) # collector train_collector = Collector( - policy, train_envs, - 
VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)) + ) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'a2c') @@ -101,10 +109,19 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, - episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, - logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + episode_per_collect=args.episode_per_collect, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': @@ -121,8 +138,7 @@ def stop_fn(mean_rewards): # here we define an imitation collector with a trivial policy if args.task == 'CartPole-v0': env.spec.reward_threshold = 190 # lower the goal - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) net = Actor(net, args.action_shape, device=args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) il_policy = ImitationPolicy(net, optim, action_space=env.action_space) @@ -132,9 +148,18 @@ def stop_fn(mean_rewards): ) train_collector.reset() result = offpolicy_trainer( - il_policy, train_collector, il_test_collector, args.epoch, - args.il_step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger) + il_policy, + train_collector, + il_test_collector, + args.epoch, + args.il_step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 3208e83c8..3c74723eb 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pickle import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import C51Policy +from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer def get_args(): @@ -34,20 +35,20 @@ def get_args(): parser.add_argument('--step-per-collect', type=int, default=8) parser.add_argument('--update-per-step', type=float, default=0.125) parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128, 128, 128]) + parser.add_argument( + '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128] + ) parser.add_argument('--training-num', type=int, default=8) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
- parser.add_argument('--prioritized-replay', - action="store_true", default=False) + parser.add_argument('--prioritized-replay', action="store_true", default=False) parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--beta', type=float, default=0.4) parser.add_argument('--resume', action="store_true") parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument("--save-interval", type=int, default=4) args = parser.parse_known_args()[0] return args @@ -60,29 +61,45 @@ def test_c51(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device, - softmax=True, num_atoms=args.num_atoms) + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + softmax=True, + num_atoms=args.num_atoms + ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = C51Policy( - net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max, - args.n_step, target_update_freq=args.target_update_freq + net, + optim, + args.gamma, + args.num_atoms, + args.v_min, + args.v_max, + args.n_step, + target_update_freq=args.target_update_freq ).to(args.device) # buffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), - alpha=args.alpha, beta=args.beta) + args.buffer_size, + buffer_num=len(train_envs), + alpha=args.alpha, + beta=args.beta + ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector @@ -117,12 +134,16 @@ def test_fn(epoch, env_step): def save_checkpoint_fn(epoch, env_step, gradient_step): # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html - torch.save({ - 'model': policy.state_dict(), - 'optim': optim.state_dict(), - }, os.path.join(log_path, 'checkpoint.pth')) - pickle.dump(train_collector.buffer, - open(os.path.join(log_path, 'train_buffer.pkl'), "wb")) + torch.save( + { + 'model': policy.state_dict(), + 'optim': optim.state_dict(), + }, os.path.join(log_path, 'checkpoint.pth') + ) + pickle.dump( + train_collector.buffer, + open(os.path.join(log_path, 'train_buffer.pkl'), "wb") + ) if args.resume: # load from existing checkpoint @@ -144,11 +165,23 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger, - resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + 
train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + resume_from_log=args.resume, + save_checkpoint_fn=save_checkpoint_fn + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index aae609ec6..6912a1933 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pickle import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import DQNPolicy +from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer def get_args(): @@ -31,22 +32,22 @@ def get_args(): parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128, 128, 128]) + parser.add_argument( + '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128] + ) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) - parser.add_argument('--prioritized-replay', - action="store_true", default=False) + parser.add_argument('--prioritized-replay', action="store_true", default=False) parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--beta', type=float, default=0.4) parser.add_argument( - '--save-buffer-name', type=str, - default="./expert_DQN_CartPole-v0.pkl") + '--save-buffer-name', type=str, default="./expert_DQN_CartPole-v0.pkl" + ) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) args = parser.parse_known_args()[0] return args @@ -58,10 +59,12 @@ def test_dqn(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -69,19 +72,29 @@ def test_dqn(args=get_args()): test_envs.seed(args.seed) # Q_param = V_param = {"hidden_sizes": [128]} # model - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device, - # dueling=(Q_param, V_param), - ).to(args.device) + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + # dueling=(Q_param, V_param), + ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DQNPolicy( - net, optim, args.gamma, args.n_step, - 
target_update_freq=args.target_update_freq) + net, + optim, + args.gamma, + args.n_step, + target_update_freq=args.target_update_freq + ) # buffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), - alpha=args.alpha, beta=args.beta) + args.buffer_size, + buffer_num=len(train_envs), + alpha=args.alpha, + beta=args.beta + ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector @@ -116,10 +129,21 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index aa4fbbe0f..064dbba24 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -1,17 +1,18 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import DQNPolicy -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv +from tianshou.policy import DQNPolicy from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Recurrent -from tianshou.data import Collector, VectorReplayBuffer def get_args(): @@ -37,8 +38,8 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) args = parser.parse_known_args()[0] return args @@ -50,26 +51,35 @@ def test_drqn(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Recurrent(args.layer_num, args.state_shape, - args.action_shape, args.device).to(args.device) + net = Recurrent(args.layer_num, args.state_shape, args.action_shape, + args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DQNPolicy( - net, optim, args.gamma, args.n_step, - target_update_freq=args.target_update_freq) + net, + optim, + args.gamma, + args.n_step, + target_update_freq=args.target_update_freq + ) # collector buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), - stack_num=args.stack_num, ignore_obs_next=True) + args.buffer_size, + buffer_num=len(train_envs), + stack_num=args.stack_num, + ignore_obs_next=True + ) train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) # the stack_num is for RNN training: sample framestack obs test_collector = Collector(policy, test_envs, exploration_noise=True) @@ -94,11 +104,21 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, update_per_step=args.update_per_step, - train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, - save_fn=save_fn, logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 0df2efb74..1763380f1 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import FQFPolicy +from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import offpolicy_trainer from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction -from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer def get_args(): @@ -35,19 +36,17 @@ def get_args(): parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--update-per-step', 
type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[64, 64, 64]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64, 64]) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) - parser.add_argument('--prioritized-replay', - action="store_true", default=False) + parser.add_argument('--prioritized-replay', action="store_true", default=False) parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--beta', type=float, default=0.4) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) args = parser.parse_known_args()[0] return args @@ -59,22 +58,31 @@ def test_fqf(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - feature_net = Net(args.state_shape, args.hidden_sizes[-1], - hidden_sizes=args.hidden_sizes[:-1], device=args.device, - softmax=False) + feature_net = Net( + args.state_shape, + args.hidden_sizes[-1], + hidden_sizes=args.hidden_sizes[:-1], + device=args.device, + softmax=False + ) net = FullQuantileFunction( - feature_net, args.action_shape, args.hidden_sizes, - num_cosines=args.num_cosines, device=args.device + feature_net, + args.action_shape, + args.hidden_sizes, + num_cosines=args.num_cosines, + device=args.device ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) @@ -82,14 +90,24 @@ def test_fqf(args=get_args()): fraction_net.parameters(), lr=args.fraction_lr ) policy = FQFPolicy( - net, optim, fraction_net, fraction_optim, args.gamma, args.num_fractions, - args.ent_coef, args.n_step, target_update_freq=args.target_update_freq + net, + optim, + fraction_net, + fraction_optim, + args.gamma, + args.num_fractions, + args.ent_coef, + args.n_step, + target_update_freq=args.target_update_freq ).to(args.device) # buffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), - alpha=args.alpha, beta=args.beta) + args.buffer_size, + buffer_num=len(train_envs), + alpha=args.alpha, + beta=args.beta + ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector @@ -124,11 +142,21 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + 
train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py index f404ce7dc..47540dadd 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/discrete/test_il_bcq.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pickle import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector -from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv -from tianshou.utils.net.common import Net -from tianshou.trainer import offline_trainer from tianshou.policy import DiscreteBCQPolicy +from tianshou.trainer import offline_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net def get_args(): @@ -29,17 +30,18 @@ def get_args(): parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--update-per-epoch", type=int, default=2000) parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[64, 64]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) parser.add_argument("--test-num", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) parser.add_argument( - "--load-buffer-name", type=str, + "--load-buffer-name", + type=str, default="./expert_DQN_CartPole-v0.pkl", ) parser.add_argument( - "--device", type=str, + "--device", + type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument("--resume", action="store_true") @@ -56,26 +58,39 @@ def test_discrete_bcq(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # model policy_net = Net( - args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device).to(args.device) + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device + ).to(args.device) imitation_net = Net( - args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device).to(args.device) + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device + ).to(args.device) optim = torch.optim.Adam( - list(policy_net.parameters()) + list(imitation_net.parameters()), - lr=args.lr) + list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr + ) policy = DiscreteBCQPolicy( - policy_net, imitation_net, optim, args.gamma, args.n_step, - args.target_update_freq, args.eps_test, - args.unlikely_action_threshold, args.imitation_logits_penalty, + policy_net, + imitation_net, + optim, + args.gamma, + args.n_step, + args.target_update_freq, + args.eps_test, + args.unlikely_action_threshold, + args.imitation_logits_penalty, ) # buffer assert os.path.exists(args.load_buffer_name), \ @@ -97,10 +112,12 @@ def stop_fn(mean_rewards): def save_checkpoint_fn(epoch, env_step, gradient_step): # see also: 
https://pytorch.org/tutorials/beginner/saving_loading_models.html - torch.save({ - 'model': policy.state_dict(), - 'optim': optim.state_dict(), - }, os.path.join(log_path, 'checkpoint.pth')) + torch.save( + { + 'model': policy.state_dict(), + 'optim': optim.state_dict(), + }, os.path.join(log_path, 'checkpoint.pth') + ) if args.resume: # load from existing checkpoint @@ -115,10 +132,19 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): print("Fail to restore policy and optim.") result = offline_trainer( - policy, buffer, test_collector, - args.epoch, args.update_per_epoch, args.test_num, args.batch_size, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn) + policy, + buffer, + test_collector, + args.epoch, + args.update_per_epoch, + args.test_num, + args.batch_size, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + resume_from_log=args.resume, + save_checkpoint_fn=save_checkpoint_fn + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_il_crr.py b/test/discrete/test_il_crr.py index 858d2b6f7..929469e8b 100644 --- a/test/discrete/test_il_crr.py +++ b/test/discrete/test_il_crr.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pickle import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector -from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv -from tianshou.utils.net.common import Net -from tianshou.trainer import offline_trainer from tianshou.policy import DiscreteCRRPolicy +from tianshou.trainer import offline_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net def get_args(): @@ -26,17 +27,18 @@ def get_args(): parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--update-per-epoch", type=int, default=1000) parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[64, 64]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) parser.add_argument("--test-num", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) 
parser.add_argument( - "--load-buffer-name", type=str, + "--load-buffer-name", + type=str, default="./expert_DQN_CartPole-v0.pkl", ) parser.add_argument( - "--device", type=str, + "--device", + type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) args = parser.parse_known_args()[0] @@ -51,23 +53,36 @@ def test_discrete_crr(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - actor = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device, - softmax=False) - critic = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device, - softmax=False) - optim = torch.optim.Adam(list(actor.parameters()) + list(critic.parameters()), - lr=args.lr) + actor = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + softmax=False + ) + critic = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + softmax=False + ) + optim = torch.optim.Adam( + list(actor.parameters()) + list(critic.parameters()), lr=args.lr + ) policy = DiscreteCRRPolicy( - actor, critic, optim, args.gamma, + actor, + critic, + optim, + args.gamma, target_update_freq=args.target_update_freq, ).to(args.device) # buffer @@ -89,9 +104,17 @@ def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold result = offline_trainer( - policy, buffer, test_collector, - args.epoch, args.update_per_epoch, args.test_num, args.batch_size, - stop_fn=stop_fn, save_fn=save_fn, logger=logger) + policy, + buffer, + test_collector, + args.epoch, + args.update_per_epoch, + args.test_num, + args.batch_size, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 0234c36f0..c93ddfc0d 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import IQNPolicy +from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import offpolicy_trainer from tianshou.utils.net.discrete import ImplicitQuantileNetwork -from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer def get_args(): @@ -35,19 +36,17 @@ def get_args(): parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[64, 64, 64]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64, 64]) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) 
parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) - parser.add_argument('--prioritized-replay', - action="store_true", default=False) + parser.add_argument('--prioritized-replay', action="store_true", default=False) parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--beta', type=float, default=0.4) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) args = parser.parse_known_args()[0] return args @@ -59,33 +58,50 @@ def test_iqn(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - feature_net = Net(args.state_shape, args.hidden_sizes[-1], - hidden_sizes=args.hidden_sizes[:-1], device=args.device, - softmax=False) + feature_net = Net( + args.state_shape, + args.hidden_sizes[-1], + hidden_sizes=args.hidden_sizes[:-1], + device=args.device, + softmax=False + ) net = ImplicitQuantileNetwork( - feature_net, args.action_shape, - num_cosines=args.num_cosines, device=args.device) + feature_net, + args.action_shape, + num_cosines=args.num_cosines, + device=args.device + ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = IQNPolicy( - net, optim, args.gamma, args.sample_size, args.online_sample_size, - args.target_sample_size, args.n_step, + net, + optim, + args.gamma, + args.sample_size, + args.online_sample_size, + args.target_sample_size, + args.n_step, target_update_freq=args.target_update_freq ).to(args.device) # buffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), - alpha=args.alpha, beta=args.beta) + args.buffer_size, + buffer_num=len(train_envs), + alpha=args.alpha, + beta=args.beta + ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector @@ -120,11 +136,21 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 3c36bb265..fafd7cc49 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -1,17 +1,18 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, VectorReplayBuffer 
+from tianshou.env import DummyVectorEnv from tianshou.policy import PGPolicy +from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer def get_args(): @@ -26,16 +27,15 @@ def get_args(): parser.add_argument('--episode-per-collect', type=int, default=8) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[64, 64]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) parser.add_argument('--training-num', type=int, default=8) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument('--rew-norm', type=int, default=1) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) args = parser.parse_known_args()[0] return args @@ -47,24 +47,35 @@ def test_pg(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - device=args.device, softmax=True).to(args.device) + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + softmax=True + ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) dist = torch.distributions.Categorical - policy = PGPolicy(net, optim, dist, args.gamma, - reward_normalization=args.rew_norm, - action_space=env.action_space) + policy = PGPolicy( + net, + optim, + dist, + args.gamma, + reward_normalization=args.rew_norm, + action_space=env.action_space + ) for m in net.modules(): if isinstance(m, torch.nn.Linear): # orthogonal initialization @@ -72,8 +83,8 @@ def test_pg(args=get_args()): torch.nn.init.zeros_(m.bias) # collector train_collector = Collector( - policy, train_envs, - VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)) + ) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'pg') @@ -88,10 +99,19 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, - episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, - logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + episode_per_collect=args.episode_per_collect, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger 
+ ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 3658f364b..96650b14b 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -1,17 +1,18 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import PPOPolicy +from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.discrete import Actor, Critic @@ -33,8 +34,8 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) # ppo special parser.add_argument('--vf-coef', type=float, default=0.5) parser.add_argument('--ent-coef', type=float, default=0.0) @@ -57,18 +58,19 @@ def test_ppo(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device) critic = Critic(net, device=args.device).to(args.device) # orthogonal initialization @@ -77,10 +79,14 @@ def test_ppo(args=get_args()): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) optim = torch.optim.Adam( - set(actor.parameters()).union(critic.parameters()), lr=args.lr) + set(actor.parameters()).union(critic.parameters()), lr=args.lr + ) dist = torch.distributions.Categorical policy = PPOPolicy( - actor, critic, optim, dist, + actor, + critic, + optim, + dist, discount_factor=args.gamma, max_grad_norm=args.max_grad_norm, eps_clip=args.eps_clip, @@ -93,11 +99,12 @@ def test_ppo(args=get_args()): action_space=env.action_space, deterministic_eval=True, advantage_normalization=args.norm_adv, - recompute_advantage=args.recompute_adv) + recompute_advantage=args.recompute_adv + ) # collector train_collector = Collector( - policy, train_envs, - VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)) + ) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'ppo') @@ -112,10 +119,19 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, 
- step_per_collect=args.step_per_collect, stop_fn=stop_fn, save_fn=save_fn, - logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + step_per_collect=args.step_per_collect, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 2385db0ee..cf8d22212 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pickle import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import TensorboardLogger -from tianshou.policy import QRDQNPolicy +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.utils.net.common import Net +from tianshou.policy import QRDQNPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net def get_args(): @@ -32,22 +33,22 @@ def get_args(): parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128, 128, 128]) + parser.add_argument( + '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128] + ) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
- parser.add_argument('--prioritized-replay', - action="store_true", default=False) + parser.add_argument('--prioritized-replay', action="store_true", default=False) parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--beta', type=float, default=0.4) parser.add_argument( - '--save-buffer-name', type=str, - default="./expert_QRDQN_CartPole-v0.pkl") + '--save-buffer-name', type=str, default="./expert_QRDQN_CartPole-v0.pkl" + ) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) args = parser.parse_known_args()[0] return args @@ -61,29 +62,43 @@ def test_qrdqn(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device, - softmax=False, num_atoms=args.num_quantiles) + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + softmax=False, + num_atoms=args.num_quantiles + ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = QRDQNPolicy( - net, optim, args.gamma, args.num_quantiles, - args.n_step, target_update_freq=args.target_update_freq + net, + optim, + args.gamma, + args.num_quantiles, + args.n_step, + target_update_freq=args.target_update_freq ).to(args.device) # buffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), - alpha=args.alpha, beta=args.beta) + args.buffer_size, + buffer_num=len(train_envs), + alpha=args.alpha, + beta=args.beta + ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector @@ -118,11 +133,21 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_qrdqn_il_cql.py b/test/discrete/test_qrdqn_il_cql.py index dbfd42aad..01b868f13 100644 --- a/test/discrete/test_qrdqn_il_cql.py +++ b/test/discrete/test_qrdqn_il_cql.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pickle import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector -from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv -from tianshou.utils.net.common import Net -from 
tianshou.trainer import offline_trainer from tianshou.policy import DiscreteCQLPolicy +from tianshou.trainer import offline_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net def get_args(): @@ -29,17 +30,18 @@ def get_args(): parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--update-per-epoch", type=int, default=1000) parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[64, 64]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) parser.add_argument("--test-num", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) parser.add_argument( - "--load-buffer-name", type=str, + "--load-buffer-name", + type=str, default="./expert_QRDQN_CartPole-v0.pkl", ) parser.add_argument( - "--device", type=str, + "--device", + type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) args = parser.parse_known_args()[0] @@ -54,20 +56,31 @@ def test_discrete_cql(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device, - softmax=False, num_atoms=args.num_quantiles) + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + softmax=False, + num_atoms=args.num_quantiles + ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DiscreteCQLPolicy( - net, optim, args.gamma, args.num_quantiles, args.n_step, - args.target_update_freq, min_q_weight=args.min_q_weight + net, + optim, + args.gamma, + args.num_quantiles, + args.n_step, + args.target_update_freq, + min_q_weight=args.min_q_weight ).to(args.device) # buffer assert os.path.exists(args.load_buffer_name), \ @@ -88,9 +101,17 @@ def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold result = offline_trainer( - policy, buffer, test_collector, - args.epoch, args.update_per_epoch, args.test_num, args.batch_size, - stop_fn=stop_fn, save_fn=save_fn, logger=logger) + policy, + buffer, + test_collector, + args.epoch, + args.update_per_epoch, + args.test_num, + args.batch_size, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 4fdcfd352..b226a025c 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -1,19 +1,20 @@ +import argparse import os -import gym -import torch import pickle import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import RainbowPolicy +from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import NoisyLinear -from tianshou.trainer import 
offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer def get_args(): @@ -36,21 +37,21 @@ def get_args(): parser.add_argument('--step-per-collect', type=int, default=8) parser.add_argument('--update-per-step', type=float, default=0.125) parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128, 128, 128]) + parser.add_argument( + '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128] + ) parser.add_argument('--training-num', type=int, default=8) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) - parser.add_argument('--prioritized-replay', - action="store_true", default=False) + parser.add_argument('--prioritized-replay', action="store_true", default=False) parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--beta', type=float, default=0.4) parser.add_argument('--beta-final', type=float, default=1.) parser.add_argument('--resume', action="store_true") parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument("--save-interval", type=int, default=4) args = parser.parse_known_args()[0] return args @@ -63,35 +64,56 @@ def test_rainbow(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # model def noisy_linear(x, y): return NoisyLinear(x, y, args.noisy_std) - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device, - softmax=True, num_atoms=args.num_atoms, - dueling_param=({"linear_layer": noisy_linear}, - {"linear_layer": noisy_linear})) + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + softmax=True, + num_atoms=args.num_atoms, + dueling_param=({ + "linear_layer": noisy_linear + }, { + "linear_layer": noisy_linear + }) + ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = RainbowPolicy( - net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max, - args.n_step, target_update_freq=args.target_update_freq + net, + optim, + args.gamma, + args.num_atoms, + args.v_min, + args.v_max, + args.n_step, + target_update_freq=args.target_update_freq ).to(args.device) # buffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), - alpha=args.alpha, beta=args.beta, weight_norm=True) + args.buffer_size, + buffer_num=len(train_envs), + alpha=args.alpha, + beta=args.beta, + weight_norm=True + ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector @@ -136,12 +158,16 @@ def test_fn(epoch, env_step): def save_checkpoint_fn(epoch, env_step, gradient_step): # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html - torch.save({ - 
'model': policy.state_dict(), - 'optim': optim.state_dict(), - }, os.path.join(log_path, 'checkpoint.pth')) - pickle.dump(train_collector.buffer, - open(os.path.join(log_path, 'train_buffer.pkl'), "wb")) + torch.save( + { + 'model': policy.state_dict(), + 'optim': optim.state_dict(), + }, os.path.join(log_path, 'checkpoint.pth') + ) + pickle.dump( + train_collector.buffer, + open(os.path.join(log_path, 'train_buffer.pkl'), "wb") + ) if args.resume: # load from existing checkpoint @@ -163,11 +189,23 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger, - resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + resume_from_log=args.resume, + save_checkpoint_fn=save_checkpoint_fn + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index d8abf48e4..41be36838 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv -from tianshou.utils.net.common import Net from tianshou.policy import DiscreteSACPolicy from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import Actor, Critic -from tianshou.data import Collector, VectorReplayBuffer def get_args(): @@ -40,8 +41,8 @@ def get_args(): parser.add_argument('--rew-norm', action="store_true", default=False) parser.add_argument('--n-step', type=int, default=3) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) args = parser.parse_known_args()[0] return args @@ -52,27 +53,26 @@ def test_discrete_sac(args=get_args()): args.action_shape = env.action_space.shape or env.action_space.n train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) test_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) - actor = Actor(net, args.action_shape, - softmax_output=False, device=args.device).to(args.device) + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = Actor(net, args.action_shape, softmax_output=False, + device=args.device).to(args.device) 
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net_c1 = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) + net_c1 = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) critic1 = Critic(net_c1, last_size=args.action_shape, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - net_c2 = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) + net_c2 = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) critic2 = Critic(net_c2, last_size=args.action_shape, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) @@ -85,13 +85,22 @@ def test_discrete_sac(args=get_args()): args.alpha = (target_entropy, log_alpha, alpha_optim) policy = DiscreteSACPolicy( - actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - args.tau, args.gamma, args.alpha, estimation_step=args.n_step, - reward_normalization=args.rew_norm) + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + args.tau, + args.gamma, + args.alpha, + estimation_step=args.n_step, + reward_normalization=args.rew_norm + ) # collector train_collector = Collector( - policy, train_envs, - VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)) + ) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log @@ -107,10 +116,20 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index 01710d827..3a50f36e9 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -1,15 +1,16 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import PSRLPolicy -from tianshou.trainer import onpolicy_trainer from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv +from tianshou.policy import PSRLPolicy +from tianshou.trainer import onpolicy_trainer def get_args(): @@ -42,10 +43,12 @@ def test_psrl(args=get_args()): args.action_shape = env.action_space.shape or env.action_space.n # train_envs = gym.make(args.task) train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -59,12 +62,15 @@ def test_psrl(args=get_args()): rew_std_prior = np.full((n_state, 
n_action), args.rew_std_prior) policy = PSRLPolicy( trans_count_prior, rew_mean_prior, rew_std_prior, args.gamma, args.eps, - args.add_done_loop) + args.add_done_loop + ) # collector train_collector = Collector( - policy, train_envs, + policy, + train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), - exploration_noise=True) + exploration_noise=True + ) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'psrl') @@ -80,11 +86,19 @@ def stop_fn(mean_rewards): train_collector.collect(n_step=args.buffer_size, random=True) # trainer, test it without logger result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, 1, args.test_num, 0, - episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + 1, + args.test_num, + 0, + episode_per_collect=args.episode_per_collect, + stop_fn=stop_fn, # logger=logger, - test_in_train=False) + test_in_train=False + ) if __name__ == '__main__': pprint.pprint(result) diff --git a/test/multiagent/Gomoku.py b/test/multiagent/Gomoku.py index 6418ee8ec..7d1754245 100644 --- a/test/multiagent/Gomoku.py +++ b/test/multiagent/Gomoku.py @@ -1,17 +1,17 @@ import os import pprint -import numpy as np from copy import deepcopy + +import numpy as np +from tic_tac_toe import get_agents, get_parser, train_agent, watch +from tic_tac_toe_env import TicTacToeEnv from torch.utils.tensorboard import SummaryWriter -from tianshou.env import DummyVectorEnv from tianshou.data import Collector +from tianshou.env import DummyVectorEnv from tianshou.policy import RandomPolicy from tianshou.utils import TensorboardLogger -from tic_tac_toe_env import TicTacToeEnv -from tic_tac_toe import get_parser, get_agents, train_agent, watch - def get_args(): parser = get_parser() @@ -39,6 +39,7 @@ def gomoku(args=get_args()): def env_func(): return TicTacToeEnv(args.board_size, args.win_size) + test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)]) for r in range(args.self_play_round): rews = [] @@ -65,11 +66,11 @@ def env_func(): # previous learner can only be used for forward agent.forward = opponent.forward args.model_save_path = os.path.join( - args.logdir, 'Gomoku', 'dqn', - f'policy_round_{r}_epoch_{epoch}.pth') + args.logdir, 'Gomoku', 'dqn', f'policy_round_{r}_epoch_{epoch}.pth' + ) result, agent_learn = train_agent( - args, agent_learn=agent_learn, - agent_opponent=agent, optim=optim) + args, agent_learn=agent_learn, agent_opponent=agent, optim=optim + ) print(f'round_{r}_epoch_{epoch}') pprint.pprint(result) learnt_agent = deepcopy(agent_learn) diff --git a/test/multiagent/test_tic_tac_toe.py b/test/multiagent/test_tic_tac_toe.py index 1cc06d374..aeb4644e1 100644 --- a/test/multiagent/test_tic_tac_toe.py +++ b/test/multiagent/test_tic_tac_toe.py @@ -1,4 +1,5 @@ import pprint + from tic_tac_toe import get_args, train_agent, watch diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index dcd293a06..02fd47cd7 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -1,20 +1,24 @@ -import os -import torch import argparse -import numpy as np +import os from copy import deepcopy from typing import Optional, Tuple + +import numpy as np +import torch +from tic_tac_toe_env import TicTacToeEnv from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer 
from tianshou.env import DummyVectorEnv -from tianshou.utils.net.common import Net +from tianshou.policy import ( + BasePolicy, + DQNPolicy, + MultiAgentPolicyManager, + RandomPolicy, +) from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer -from tianshou.policy import BasePolicy, DQNPolicy, RandomPolicy, \ - MultiAgentPolicyManager - -from tic_tac_toe_env import TicTacToeEnv +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net def get_parser() -> argparse.ArgumentParser: @@ -24,8 +28,9 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument('--eps-train', type=float, default=0.1) parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--lr', type=float, default=1e-3) - parser.add_argument('--gamma', type=float, default=0.9, - help='a smaller gamma favors earlier win') + parser.add_argument( + '--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win' + ) parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=20) @@ -33,31 +38,49 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128, 128, 128]) + parser.add_argument( + '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128] + ) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.1) parser.add_argument('--board-size', type=int, default=6) parser.add_argument('--win-size', type=int, default=4) - parser.add_argument('--win-rate', type=float, default=0.9, - help='the expected winning rate') - parser.add_argument('--watch', default=False, action='store_true', - help='no training, ' - 'watch the play of pre-trained models') - parser.add_argument('--agent-id', type=int, default=2, - help='the learned agent plays as the' - ' agent_id-th player. Choices are 1 and 2.') - parser.add_argument('--resume-path', type=str, default='', - help='the path of agent pth file ' - 'for resuming from a pre-trained agent') - parser.add_argument('--opponent-path', type=str, default='', - help='the path of opponent agent pth file ' - 'for resuming from a pre-trained agent') parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--win-rate', type=float, default=0.9, help='the expected winning rate' + ) + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='no training, ' + 'watch the play of pre-trained models' + ) + parser.add_argument( + '--agent-id', + type=int, + default=2, + help='the learned agent plays as the' + ' agent_id-th player. Choices are 1 and 2.' 
+ ) + parser.add_argument( + '--resume-path', + type=str, + default='', + help='the path of agent pth file ' + 'for resuming from a pre-trained agent' + ) + parser.add_argument( + '--opponent-path', + type=str, + default='', + help='the path of opponent agent pth file ' + 'for resuming from a pre-trained agent' + ) + parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) return parser @@ -77,14 +100,21 @@ def get_agents( args.action_shape = env.action_space.shape or env.action_space.n if agent_learn is None: # model - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device - ).to(args.device) + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device + ).to(args.device) if optim is None: optim = torch.optim.Adam(net.parameters(), lr=args.lr) agent_learn = DQNPolicy( - net, optim, args.gamma, args.n_step, - target_update_freq=args.target_update_freq) + net, + optim, + args.gamma, + args.n_step, + target_update_freq=args.target_update_freq + ) if args.resume_path: agent_learn.load_state_dict(torch.load(args.resume_path)) @@ -109,8 +139,10 @@ def train_agent( agent_opponent: Optional[BasePolicy] = None, optim: Optional[torch.optim.Optimizer] = None, ) -> Tuple[dict, BasePolicy]: + def env_func(): return TicTacToeEnv(args.board_size, args.win_size) + train_envs = DummyVectorEnv([env_func for _ in range(args.training_num)]) test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)]) # seed @@ -120,14 +152,16 @@ def env_func(): test_envs.seed(args.seed) policy, optim = get_agents( - args, agent_learn=agent_learn, - agent_opponent=agent_opponent, optim=optim) + args, agent_learn=agent_learn, agent_opponent=agent_opponent, optim=optim + ) # collector train_collector = Collector( - policy, train_envs, + policy, + train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), - exploration_noise=True) + exploration_noise=True + ) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size * args.training_num) @@ -142,10 +176,9 @@ def save_fn(policy): model_save_path = args.model_save_path else: model_save_path = os.path.join( - args.logdir, 'tic_tac_toe', 'dqn', 'policy.pth') - torch.save( - policy.policies[args.agent_id - 1].state_dict(), - model_save_path) + args.logdir, 'tic_tac_toe', 'dqn', 'policy.pth' + ) + torch.save(policy.policies[args.agent_id - 1].state_dict(), model_save_path) def stop_fn(mean_rewards): return mean_rewards >= args.win_rate @@ -161,11 +194,23 @@ def reward_metric(rews): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step, - logger=logger, test_in_train=False, reward_metric=reward_metric) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + update_per_step=args.update_per_step, + logger=logger, + test_in_train=False, + reward_metric=reward_metric + ) return result, policy.policies[args.agent_id - 1] @@ -177,7 +222,8 @@ def watch( ) -> None: env = TicTacToeEnv(args.board_size, args.win_size) policy, optim = get_agents( - args, agent_learn=agent_learn, 
agent_opponent=agent_opponent) + args, agent_learn=agent_learn, agent_opponent=agent_opponent + ) policy.eval() policy.policies[args.agent_id - 1].set_eps(args.eps_test) collector = Collector(policy, env, exploration_noise=True) diff --git a/test/multiagent/tic_tac_toe_env.py b/test/multiagent/tic_tac_toe_env.py index 2fc045afa..2c79d303d 100644 --- a/test/multiagent/tic_tac_toe_env.py +++ b/test/multiagent/tic_tac_toe_env.py @@ -1,7 +1,8 @@ +from functools import partial +from typing import Optional, Tuple + import gym import numpy as np -from functools import partial -from typing import Tuple, Optional from tianshou.env import MultiAgentEnv @@ -27,7 +28,8 @@ def __init__(self, size: int = 3, win_size: int = 3): f'be larger than board size {size}' self.convolve_kernel = np.ones(win_size) self.observation_space = gym.spaces.Box( - low=-1.0, high=1.0, shape=(size, size), dtype=np.float32) + low=-1.0, high=1.0, shape=(size, size), dtype=np.float32 + ) self.action_space = gym.spaces.Discrete(size * size) self.current_board = None self.current_agent = None @@ -45,11 +47,10 @@ def reset(self) -> dict: 'mask': self.current_board.flatten() == 0 } - def step(self, action: [int, np.ndarray] - ) -> Tuple[dict, np.ndarray, np.ndarray, dict]: + def step(self, action: [int, + np.ndarray]) -> Tuple[dict, np.ndarray, np.ndarray, dict]: if self.current_agent is None: - raise ValueError( - "calling step() of unreset environment is prohibited!") + raise ValueError("calling step() of unreset environment is prohibited!") assert 0 <= action < self.size * self.size assert self.current_board.item(action) == 0 _current_agent = self.current_agent @@ -97,18 +98,28 @@ def _test_win(self): rboard = self.current_board[row, :] cboard = self.current_board[:, col] current = self.current_board[row, col] - rightup = [self.current_board[row - i, col + i] - for i in range(1, self.size - col) if row - i >= 0] - leftdown = [self.current_board[row + i, col - i] - for i in range(1, col + 1) if row + i < self.size] + rightup = [ + self.current_board[row - i, col + i] for i in range(1, self.size - col) + if row - i >= 0 + ] + leftdown = [ + self.current_board[row + i, col - i] for i in range(1, col + 1) + if row + i < self.size + ] rdiag = np.array(leftdown[::-1] + [current] + rightup) - rightdown = [self.current_board[row + i, col + i] - for i in range(1, self.size - col) if row + i < self.size] - leftup = [self.current_board[row - i, col - i] - for i in range(1, col + 1) if row - i >= 0] + rightdown = [ + self.current_board[row + i, col + i] for i in range(1, self.size - col) + if row + i < self.size + ] + leftup = [ + self.current_board[row - i, col - i] for i in range(1, col + 1) + if row - i >= 0 + ] diag = np.array(leftup[::-1] + [current] + rightdown) - results = [np.convolve(k, self.convolve_kernel, mode='valid') - for k in (rboard, cboard, rdiag, diag)] + results = [ + np.convolve(k, self.convolve_kernel, mode='valid') + for k in (rboard, cboard, rdiag, diag) + ] return any([(np.abs(x) == self.win_size).any() for x in results]) def seed(self, seed: Optional[int] = None) -> int: @@ -128,6 +139,7 @@ def f(i, data): if number == -1: return 'O' if last_move else 'o' return '_' + for i, row in enumerate(self.current_board): print(pad + ' '.join(map(partial(f, i), enumerate(row))) + pad) print(top) diff --git a/test/throughput/test_batch_profile.py b/test/throughput/test_batch_profile.py index 9654f5838..fbd6fb89c 100644 --- a/test/throughput/test_batch_profile.py +++ b/test/throughput/test_batch_profile.py @@ -12,13 +12,20 @@ 
def data(): print("Initialising data...") np.random.seed(0) - batch_set = [Batch(a=[j for j in np.arange(1e3)], - b={'b1': (3.14, 3.14), 'b2': np.arange(1e3)}, - c=i) for i in np.arange(int(1e4))] + batch_set = [ + Batch( + a=[j for j in np.arange(1e3)], + b={ + 'b1': (3.14, 3.14), + 'b2': np.arange(1e3) + }, + c=i + ) for i in np.arange(int(1e4)) + ] batch0 = Batch( a=np.ones((3, 4), dtype=np.float64), b=Batch( - c=np.ones((1,), dtype=np.float64), + c=np.ones((1, ), dtype=np.float64), d=torch.ones((3, 3, 3), dtype=torch.float32), e=list(range(3)) ) @@ -26,19 +33,25 @@ def data(): batchs1 = [copy.deepcopy(batch0) for _ in np.arange(1e4)] batchs2 = [copy.deepcopy(batch0) for _ in np.arange(1e4)] batch_len = int(1e4) - batch3 = Batch(obs=[np.arange(20) for _ in np.arange(batch_len)], - reward=np.arange(batch_len)) - indexs = np.random.choice(batch_len, - size=batch_len // 10, replace=False) - slice_dict = {'obs': [np.arange(20) - for _ in np.arange(batch_len // 10)], - 'reward': np.arange(batch_len // 10)} - dict_set = [{'obs': np.arange(20), 'info': "this is info", 'reward': 0} - for _ in np.arange(1e2)] + batch3 = Batch( + obs=[np.arange(20) for _ in np.arange(batch_len)], reward=np.arange(batch_len) + ) + indexs = np.random.choice(batch_len, size=batch_len // 10, replace=False) + slice_dict = { + 'obs': [np.arange(20) for _ in np.arange(batch_len // 10)], + 'reward': np.arange(batch_len // 10) + } + dict_set = [ + { + 'obs': np.arange(20), + 'info': "this is info", + 'reward': 0 + } for _ in np.arange(1e2) + ] batch4 = Batch( a=np.ones((10000, 4), dtype=np.float64), b=Batch( - c=np.ones((1,), dtype=np.float64), + c=np.ones((1, ), dtype=np.float64), d=torch.ones((1000, 1000), dtype=torch.float32), e=np.arange(1000) ) diff --git a/test/throughput/test_buffer_profile.py b/test/throughput/test_buffer_profile.py index 40ce68889..2bb00c143 100644 --- a/test/throughput/test_buffer_profile.py +++ b/test/throughput/test_buffer_profile.py @@ -1,8 +1,10 @@ import sys -import gym import time -import tqdm + +import gym import numpy as np +import tqdm + from tianshou.data import Batch, ReplayBuffer, VectorReplayBuffer @@ -12,7 +14,7 @@ def test_replaybuffer(task="Pendulum-v0"): env = gym.make(task) buf = ReplayBuffer(10000) obs = env.reset() - for i in range(100000): + for _ in range(100000): act = env.action_space.sample() obs_next, rew, done, info = env.step(act) batch = Batch( @@ -35,7 +37,7 @@ def test_vectorbuffer(task="Pendulum-v0"): env = gym.make(task) buf = VectorReplayBuffer(total_size=10000, buffer_num=1) obs = env.reset() - for i in range(100000): + for _ in range(100000): act = env.action_space.sample() obs_next, rew, done, info = env.step(act) batch = Batch( diff --git a/test/throughput/test_collector_profile.py b/test/throughput/test_collector_profile.py index 6242e694b..bf9c4dc05 100644 --- a/test/throughput/test_collector_profile.py +++ b/test/throughput/test_collector_profile.py @@ -1,9 +1,9 @@ -import tqdm import numpy as np +import tqdm -from tianshou.policy import BasePolicy +from tianshou.data import AsyncCollector, Batch, Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.data import Batch, Collector, AsyncCollector, VectorReplayBuffer +from tianshou.policy import BasePolicy if __name__ == '__main__': from env import MyTestEnv @@ -12,6 +12,7 @@ class MyPolicy(BasePolicy): + def __init__(self, dict_state=False, need_state=True): """ :param bool dict_state: if the observation of the environment is a dict @@ -40,8 +41,7 @@ def 
test_collector_nstep(): env_fns = [lambda x=i: MyTestEnv(size=x) for i in np.arange(2, 11)] dum = DummyVectorEnv(env_fns) num = len(env_fns) - c3 = Collector(policy, dum, - VectorReplayBuffer(total_size=40000, buffer_num=num)) + c3 = Collector(policy, dum, VectorReplayBuffer(total_size=40000, buffer_num=num)) for i in tqdm.trange(1, 400, desc="test step collector n_step"): c3.reset() result = c3.collect(n_step=i * len(env_fns)) @@ -53,8 +53,7 @@ def test_collector_nepisode(): env_fns = [lambda x=i: MyTestEnv(size=x) for i in np.arange(2, 11)] dum = DummyVectorEnv(env_fns) num = len(env_fns) - c3 = Collector(policy, dum, - VectorReplayBuffer(total_size=40000, buffer_num=num)) + c3 = Collector(policy, dum, VectorReplayBuffer(total_size=40000, buffer_num=num)) for i in tqdm.trange(1, 400, desc="test step collector n_episode"): c3.reset() result = c3.collect(n_episode=i) @@ -64,22 +63,22 @@ def test_collector_nepisode(): def test_asynccollector(): env_lens = [2, 3, 4, 5] - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) - for i in env_lens] + env_fns = [ + lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens + ] venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) policy = MyPolicy() bufsize = 300 c1 = AsyncCollector( - policy, venv, - VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4)) + policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4) + ) ptr = [0, 0, 0, 0] for n_episode in tqdm.trange(1, 100, desc="test async n_episode"): result = c1.collect(n_episode=n_episode) assert result["n/ep"] >= n_episode # check buffer data, obs and obs_next, env_id - for i, count in enumerate( - np.bincount(result["lens"], minlength=6)[2:]): + for i, count in enumerate(np.bincount(result["lens"], minlength=6)[2:]): env_len = i + 2 total = env_len * count indices = np.arange(ptr[i], ptr[i] + total) % bufsize @@ -88,8 +87,7 @@ def test_asynccollector(): buf = c1.buffer.buffers[i] assert np.all(buf.info.env_id[indices] == i) assert np.all(buf.obs[indices].reshape(count, env_len) == seq) - assert np.all(buf.obs_next[indices].reshape( - count, env_len) == seq + 1) + assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) # test async n_step, for now the buffer should be full of data for n_step in tqdm.trange(1, 150, desc="test async n_step"): result = c1.collect(n_step=n_step) diff --git a/tianshou/__init__.py b/tianshou/__init__.py index fb362d054..3430fb09b 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,7 +1,6 @@ -from tianshou import data, env, utils, policy, trainer, exploration +from tianshou import data, env, exploration, policy, trainer, utils - -__version__ = "0.4.2" +__version__ = "0.4.3" __all__ = [ "env", diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index 75e02a940..89250d009 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -1,12 +1,19 @@ +"""Data package.""" +# isort:skip_file + from tianshou.data.batch import Batch from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as from tianshou.data.utils.segtree import SegmentTree from tianshou.data.buffer.base import ReplayBuffer from tianshou.data.buffer.prio import PrioritizedReplayBuffer -from tianshou.data.buffer.manager import ReplayBufferManager -from tianshou.data.buffer.manager import PrioritizedReplayBufferManager -from tianshou.data.buffer.vecbuf import VectorReplayBuffer -from tianshou.data.buffer.vecbuf import PrioritizedVectorReplayBuffer +from 
tianshou.data.buffer.manager import ( + ReplayBufferManager, + PrioritizedReplayBufferManager, +) +from tianshou.data.buffer.vecbuf import ( + VectorReplayBuffer, + PrioritizedVectorReplayBuffer, +) from tianshou.data.buffer.cached import CachedReplayBuffer from tianshou.data.collector import Collector, AsyncCollector diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index f0ce76c2b..98adf680b 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -1,11 +1,12 @@ -import torch import pprint import warnings -import numpy as np +from collections.abc import Collection from copy import deepcopy from numbers import Number -from collections.abc import Collection -from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, Sequence +from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Union + +import numpy as np +import torch IndexType = Union[slice, int, np.ndarray, List[int]] @@ -18,8 +19,7 @@ def _is_batch_set(data: Any) -> bool: # "for e in data" will just unpack the first dimension, # but data.tolist() will flatten ndarray of objects # so do not use data.tolist() - return data.dtype == object and all( - isinstance(e, (dict, Batch)) for e in data) + return data.dtype == object and all(isinstance(e, (dict, Batch)) for e in data) elif isinstance(data, (list, tuple)): if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data): return True @@ -53,7 +53,7 @@ def _to_array_with_correct_type(v: Any) -> np.ndarray: return v # most often case # convert the value to np.ndarray # convert to object data type if neither bool nor number - # raises an exception if array's elements are tensors themself + # raises an exception if array's elements are tensors themselves v = np.asanyarray(v) if not issubclass(v.dtype.type, (np.bool_, np.number)): v = v.astype(object) @@ -73,7 +73,9 @@ def _to_array_with_correct_type(v: Any) -> np.ndarray: def _create_value( - inst: Any, size: int, stack: bool = True + inst: Any, + size: int, + stack: bool = True, ) -> Union["Batch", np.ndarray, torch.Tensor]: """Create empty place-holders accroding to inst's shape. @@ -92,11 +94,10 @@ def _create_value( shape = (size, *inst.shape) if stack else (size, *inst.shape[1:]) if isinstance(inst, np.ndarray): target_type = inst.dtype.type if issubclass( - inst.dtype.type, (np.bool_, np.number)) else object + inst.dtype.type, (np.bool_, np.number) + ) else object return np.full( - shape, - fill_value=None if target_type == object else 0, - dtype=target_type + shape, fill_value=None if target_type == object else 0, dtype=target_type ) elif isinstance(inst, torch.Tensor): return torch.full(shape, fill_value=0, device=inst.device, dtype=inst.dtype) @@ -133,8 +134,10 @@ def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]: try: return torch.stack(v) # type: ignore except RuntimeError as e: - raise TypeError("Batch does not support non-stackable iterable" - " of torch.Tensor as unique value yet.") from e + raise TypeError( + "Batch does not support non-stackable iterable" + " of torch.Tensor as unique value yet." 
+ ) from e if _is_batch_set(v): v = Batch(v) # list of dict / Batch else: @@ -143,8 +146,10 @@ def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]: try: v = _to_array_with_correct_type(v) except ValueError as e: - raise TypeError("Batch does not support heterogeneous list/" - "tuple of tensors as unique value yet.") from e + raise TypeError( + "Batch does not support heterogeneous list/" + "tuple of tensors as unique value yet." + ) from e return v @@ -164,20 +169,18 @@ def _alloc_by_keys_diff( class Batch: """The internal data structure in Tianshou. - Batch is a kind of supercharged array (of temporal data) stored - individually in a (recursive) dictionary of object that can be either numpy - array, torch tensor, or batch themself. It is designed to make it extremely - easily to access, manipulate and set partial view of the heterogeneous data - conveniently. + Batch is a kind of supercharged array (of temporal data) stored individually in a + (recursive) dictionary of object that can be either numpy array, torch tensor, or + batch themselves. It is designed to make it extremely easily to access, manipulate + and set partial view of the heterogeneous data conveniently. For a detailed description, please refer to :ref:`batch_concept`. """ def __init__( self, - batch_dict: Optional[ - Union[dict, "Batch", Sequence[Union[dict, "Batch"]], np.ndarray] - ] = None, + batch_dict: Optional[Union[dict, "Batch", Sequence[Union[dict, "Batch"]], + np.ndarray]] = None, copy: bool = False, **kwargs: Any, ) -> None: @@ -248,11 +251,12 @@ def __setitem__(self, index: Union[str, IndexType], value: Any) -> None: self.__dict__[index] = value return if not isinstance(value, Batch): - raise ValueError("Batch does not supported tensor assignment. " - "Use a compatible Batch or dict instead.") - if not set(value.keys()).issubset(self.__dict__.keys()): raise ValueError( - "Creating keys is not supported by item assignment.") + "Batch does not supported tensor assignment. " + "Use a compatible Batch or dict instead." + ) + if not set(value.keys()).issubset(self.__dict__.keys()): + raise ValueError("Creating keys is not supported by item assignment.") for key, val in self.items(): try: self.__dict__[key][index] = value[key] @@ -368,9 +372,7 @@ def to_torch( v = v.type(dtype) self.__dict__[k] = v - def __cat( - self, batches: Sequence[Union[dict, "Batch"]], lens: List[int] - ) -> None: + def __cat(self, batches: Sequence[Union[dict, "Batch"]], lens: List[int]) -> None: """Private method for Batch.cat_. :: @@ -397,9 +399,11 @@ def __cat( sum_lens.append(sum_lens[-1] + x) # collect non-empty keys keys_map = [ - set(k for k, v in batch.items() - if not (isinstance(v, Batch) and v.is_empty())) - for batch in batches] + set( + k for k, v in batch.items() + if not (isinstance(v, Batch) and v.is_empty()) + ) for batch in batches + ] keys_shared = set.intersection(*keys_map) values_shared = [[e[k] for e in batches] for k in keys_shared] for k, v in zip(keys_shared, values_shared): @@ -433,8 +437,7 @@ def __cat( try: self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val except KeyError: - self.__dict__[k] = _create_value( - val, sum_lens[-1], stack=False) + self.__dict__[k] = _create_value(val, sum_lens[-1], stack=False) self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val def cat_(self, batches: Union["Batch", Sequence[Union[dict, "Batch"]]]) -> None: @@ -465,7 +468,8 @@ def cat_(self, batches: Union["Batch", Sequence[Union[dict, "Batch"]]]) -> None: raise ValueError( "Batch.cat_ meets an exception. 
Maybe because there is any " f"scalar in {batches} but Batch.cat_ does not support the " - "concatenation of scalar.") from e + "concatenation of scalar." + ) from e if not self.is_empty(): batches = [self] + list(batches) lens = [0 if self.is_empty(recurse=True) else len(self)] + lens @@ -506,8 +510,7 @@ def stack_(self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> None if not b.is_empty(): batch_list.append(b) else: - raise ValueError( - f"Cannot concatenate {type(b)} in Batch.stack_") + raise ValueError(f"Cannot concatenate {type(b)} in Batch.stack_") if len(batch_list) == 0: return batches = batch_list @@ -515,9 +518,11 @@ def stack_(self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> None batches = [self] + batches # collect non-empty keys keys_map = [ - set(k for k, v in batch.items() - if not (isinstance(v, Batch) and v.is_empty())) - for batch in batches] + set( + k for k, v in batch.items() + if not (isinstance(v, Batch) and v.is_empty()) + ) for batch in batches + ] keys_shared = set.intersection(*keys_map) values_shared = [[e[k] for e in batches] for k in keys_shared] for k, v in zip(keys_shared, values_shared): @@ -529,8 +534,10 @@ def stack_(self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> None try: self.__dict__[k] = _to_array_with_correct_type(np.stack(v, axis)) except ValueError: - warnings.warn("You are using tensors with different shape," - " fallback to dtype=object by default.") + warnings.warn( + "You are using tensors with different shape," + " fallback to dtype=object by default." + ) self.__dict__[k] = np.array(v, dtype=object) # all the keys keys_total = set.union(*[set(b.keys()) for b in batches]) @@ -543,7 +550,8 @@ def stack_(self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> None if keys_partial and axis != 0: raise ValueError( f"Stack of Batch with non-shared keys {keys_partial} is only " - f"supported with axis=0, but got axis={axis}!") + f"supported with axis=0, but got axis={axis}!" + ) for k in keys_reserve: # reserved keys self.__dict__[k] = Batch() @@ -625,8 +633,10 @@ def empty_(self, index: Optional[Union[slice, IndexType]] = None) -> "Batch": elif isinstance(v, Batch): self.__dict__[k].empty_(index=index) else: # scalar value - warnings.warn("You are calling Batch.empty on a NumPy scalar, " - "which may cause undefined behaviors.") + warnings.warn( + "You are calling Batch.empty on a NumPy scalar, " + "which may cause undefined behaviors." + ) if _is_number(v): self.__dict__[k] = v.__class__(0) else: @@ -701,7 +711,8 @@ def is_empty(self, recurse: bool = False) -> bool: return False return all( False if not isinstance(x, Batch) else x.is_empty(recurse=True) - for x in self.values()) + for x in self.values() + ) @property def shape(self) -> List[int]: @@ -718,9 +729,10 @@ def shape(self) -> List[int]: return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \ else data_shape[0] - def split( - self, size: int, shuffle: bool = True, merge_last: bool = False - ) -> Iterator["Batch"]: + def split(self, + size: int, + shuffle: bool = True, + merge_last: bool = False) -> Iterator["Batch"]: """Split whole data into multiple small batches. 
:param int size: divide the data batch with the given size, but one diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py index d7fde0e07..381581ce9 100644 --- a/tianshou/data/buffer/base.py +++ b/tianshou/data/buffer/base.py @@ -1,10 +1,11 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + import h5py import numpy as np -from typing import Any, Dict, List, Tuple, Union, Optional from tianshou.data import Batch -from tianshou.data.utils.converter import to_hdf5, from_hdf5 -from tianshou.data.batch import _create_value, _alloc_by_keys_diff +from tianshou.data.batch import _alloc_by_keys_diff, _create_value +from tianshou.data.utils.converter import from_hdf5, to_hdf5 class ReplayBuffer: @@ -81,9 +82,8 @@ def __setstate__(self, state: Dict[str, Any]) -> None: def __setattr__(self, key: str, value: Any) -> None: """Set self.key = value.""" - assert ( - key not in self._reserved_keys - ), "key '{}' is reserved and cannot be assigned".format(key) + assert (key not in self._reserved_keys + ), "key '{}' is reserved and cannot be assigned".format(key) super().__setattr__(key, value) def save_hdf5(self, path: str) -> None: @@ -160,9 +160,8 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray: self._meta[to_indices] = buffer._meta[from_indices] return to_indices - def _add_index( - self, rew: Union[float, np.ndarray], done: bool - ) -> Tuple[int, Union[float, np.ndarray], int, int]: + def _add_index(self, rew: Union[float, np.ndarray], + done: bool) -> Tuple[int, Union[float, np.ndarray], int, int]: """Maintain the buffer's state after adding one data batch. Return (index_to_be_modified, episode_reward, episode_length, @@ -183,7 +182,9 @@ def _add_index( return ptr, self._ep_rew * 0.0, 0, self._ep_idx def add( - self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None + self, + batch: Batch, + buffer_ids: Optional[Union[np.ndarray, List[int]]] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into replay buffer. @@ -246,7 +247,8 @@ def sample_indices(self, batch_size: int) -> np.ndarray: return np.random.choice(self._size, batch_size) elif batch_size == 0: # construct current available indices return np.concatenate( - [np.arange(self._index, self._size), np.arange(self._index)] + [np.arange(self._index, self._size), + np.arange(self._index)] ) else: return np.array([], int) @@ -254,7 +256,8 @@ def sample_indices(self, batch_size: int) -> np.ndarray: if batch_size < 0: return np.array([], int) all_indices = prev_indices = np.concatenate( - [np.arange(self._index, self._size), np.arange(self._index)] + [np.arange(self._index, self._size), + np.arange(self._index)] ) for _ in range(self.stack_num - 2): prev_indices = self.prev(prev_indices) diff --git a/tianshou/data/buffer/cached.py b/tianshou/data/buffer/cached.py index 49bb33bcf..19310dbd0 100644 --- a/tianshou/data/buffer/cached.py +++ b/tianshou/data/buffer/cached.py @@ -1,5 +1,6 @@ +from typing import List, Optional, Tuple, Union + import numpy as np -from typing import List, Tuple, Union, Optional from tianshou.data import Batch, ReplayBuffer, ReplayBufferManager @@ -45,7 +46,9 @@ def __init__( self.cached_buffer_num = cached_buffer_num def add( - self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None + self, + batch: Batch, + buffer_ids: Optional[Union[np.ndarray, List[int]]] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into CachedReplayBuffer. 
diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py index dc93b6867..70ebcab03 100644 --- a/tianshou/data/buffer/manager.py +++ b/tianshou/data/buffer/manager.py @@ -1,9 +1,10 @@ +from typing import List, Optional, Sequence, Tuple, Union + import numpy as np from numba import njit -from typing import List, Tuple, Union, Sequence, Optional -from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer -from tianshou.data.batch import _create_value, _alloc_by_keys_diff +from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer +from tianshou.data.batch import _alloc_by_keys_diff, _create_value class ReplayBufferManager(ReplayBuffer): @@ -63,33 +64,45 @@ def set_batch(self, batch: Batch) -> None: self._set_batch_for_children() def unfinished_index(self) -> np.ndarray: - return np.concatenate([ - buf.unfinished_index() + offset - for offset, buf in zip(self._offset, self.buffers) - ]) + return np.concatenate( + [ + buf.unfinished_index() + offset + for offset, buf in zip(self._offset, self.buffers) + ] + ) def prev(self, index: Union[int, np.ndarray]) -> np.ndarray: if isinstance(index, (list, np.ndarray)): - return _prev_index(np.asarray(index), self._extend_offset, - self.done, self.last_index, self._lengths) + return _prev_index( + np.asarray(index), self._extend_offset, self.done, self.last_index, + self._lengths + ) else: - return _prev_index(np.array([index]), self._extend_offset, - self.done, self.last_index, self._lengths)[0] + return _prev_index( + np.array([index]), self._extend_offset, self.done, self.last_index, + self._lengths + )[0] def next(self, index: Union[int, np.ndarray]) -> np.ndarray: if isinstance(index, (list, np.ndarray)): - return _next_index(np.asarray(index), self._extend_offset, - self.done, self.last_index, self._lengths) + return _next_index( + np.asarray(index), self._extend_offset, self.done, self.last_index, + self._lengths + ) else: - return _next_index(np.array([index]), self._extend_offset, - self.done, self.last_index, self._lengths)[0] + return _next_index( + np.array([index]), self._extend_offset, self.done, self.last_index, + self._lengths + )[0] def update(self, buffer: ReplayBuffer) -> np.ndarray: """The ReplayBufferManager cannot be updated by any buffer.""" raise NotImplementedError def add( - self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None + self, + batch: Batch, + buffer_ids: Optional[Union[np.ndarray, List[int]]] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into ReplayBufferManager. 
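The `prev`/`next`/`add` hunks above all route indices through per-sub-buffer offsets. Below is a small sketch of the `buffer_ids` convention they implement, using the `VectorReplayBuffer` wrapper that appears later in this diff; the concrete sizes, shapes, and ids are illustrative only.

    import numpy as np
    from tianshou.data import Batch, VectorReplayBuffer

    # 4 sub-buffers of size 5 each; global indices 0-4, 5-9, 10-14, 15-19.
    vbuf = VectorReplayBuffer(total_size=20, buffer_num=4)
    stacked = Batch(
        obs=np.zeros((2, 3)), act=np.zeros(2), rew=np.ones(2),
        done=np.array([False, True]), obs_next=np.zeros((2, 3)),
    )
    # Entry 0 of the stacked batch goes to sub-buffer 0, entry 1 to sub-buffer 3;
    # the returned ptr holds the corresponding global (offset) indices.
    ptr, ep_rew, ep_len, ep_idx = vbuf.add(stacked, buffer_ids=[0, 3])
    # prev()/next() never cross sub-buffer boundaries, thanks to the offsets above.
    vbuf.prev(ptr)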
@@ -145,10 +158,12 @@ def sample_indices(self, batch_size: int) -> np.ndarray: if batch_size < 0: return np.array([], int) if self._sample_avail and self.stack_num > 1: - all_indices = np.concatenate([ - buf.sample_indices(0) + offset - for offset, buf in zip(self._offset, self.buffers) - ]) + all_indices = np.concatenate( + [ + buf.sample_indices(0) + offset + for offset, buf in zip(self._offset, self.buffers) + ] + ) if batch_size == 0: return all_indices else: @@ -163,10 +178,12 @@ def sample_indices(self, batch_size: int) -> np.ndarray: # avoid batch_size > 0 and sample_num == 0 -> get child's all data sample_num[sample_num == 0] = -1 - return np.concatenate([ - buf.sample_indices(bsz) + offset - for offset, buf, bsz in zip(self._offset, self.buffers, sample_num) - ]) + return np.concatenate( + [ + buf.sample_indices(bsz) + offset + for offset, buf, bsz in zip(self._offset, self.buffers, sample_num) + ] + ) class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManager): diff --git a/tianshou/data/buffer/prio.py b/tianshou/data/buffer/prio.py index a4357d5ed..fa3c49be8 100644 --- a/tianshou/data/buffer/prio.py +++ b/tianshou/data/buffer/prio.py @@ -1,8 +1,9 @@ -import torch +from typing import Any, List, Optional, Tuple, Union + import numpy as np -from typing import Any, List, Tuple, Union, Optional +import torch -from tianshou.data import Batch, SegmentTree, to_numpy, ReplayBuffer +from tianshou.data import Batch, ReplayBuffer, SegmentTree, to_numpy class PrioritizedReplayBuffer(ReplayBuffer): @@ -39,7 +40,7 @@ def __init__( self._weight_norm = weight_norm def init_weight(self, index: Union[int, np.ndarray]) -> None: - self.weight[index] = self._max_prio ** self._alpha + self.weight[index] = self._max_prio**self._alpha def update(self, buffer: ReplayBuffer) -> np.ndarray: indices = super().update(buffer) @@ -47,7 +48,9 @@ def update(self, buffer: ReplayBuffer) -> np.ndarray: return indices def add( - self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None + self, + batch: Batch, + buffer_ids: Optional[Union[np.ndarray, List[int]]] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids) self.init_weight(ptr) @@ -63,14 +66,14 @@ def sample_indices(self, batch_size: int) -> np.ndarray: def get_weight(self, index: Union[int, np.ndarray]) -> Union[float, np.ndarray]: """Get the importance sampling weight. - The "weight" in the returned Batch is the weight on loss function to de-bias + The "weight" in the returned Batch is the weight on loss function to debias the sampling process (some transition tuples are sampled more often so their losses are weighted less). """ # important sampling weight calculation # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta)) # simplified formula: (p_j/p_min)**(-beta) - return (self.weight[index] / self._min_prio) ** (-self._beta) + return (self.weight[index] / self._min_prio)**(-self._beta) def update_weight( self, index: np.ndarray, new_weight: Union[np.ndarray, torch.Tensor] @@ -81,7 +84,7 @@ def update_weight( :param np.ndarray new_weight: new priority weight you want to update. 
""" weight = np.abs(to_numpy(new_weight)) + self.__eps - self.weight[index] = weight ** self._alpha + self.weight[index] = weight**self._alpha self._max_prio = max(self._max_prio, weight.max()) self._min_prio = min(self._min_prio, weight.min()) diff --git a/tianshou/data/buffer/vecbuf.py b/tianshou/data/buffer/vecbuf.py index 374765bdc..2d4831c06 100644 --- a/tianshou/data/buffer/vecbuf.py +++ b/tianshou/data/buffer/vecbuf.py @@ -1,8 +1,13 @@ -import numpy as np from typing import Any -from tianshou.data import ReplayBuffer, ReplayBufferManager -from tianshou.data import PrioritizedReplayBuffer, PrioritizedReplayBufferManager +import numpy as np + +from tianshou.data import ( + PrioritizedReplayBuffer, + PrioritizedReplayBufferManager, + ReplayBuffer, + ReplayBufferManager, +) class VectorReplayBuffer(ReplayBufferManager): diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 192213aa3..d52cc511d 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -1,21 +1,22 @@ -import gym import time -import torch import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import gym import numpy as np -from typing import Any, Dict, List, Union, Optional, Callable +import torch -from tianshou.policy import BasePolicy -from tianshou.data.batch import _alloc_by_keys_diff -from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.data import ( Batch, + CachedReplayBuffer, ReplayBuffer, ReplayBufferManager, VectorReplayBuffer, - CachedReplayBuffer, to_numpy, ) +from tianshou.data.batch import _alloc_by_keys_diff +from tianshou.env import BaseVectorEnv, DummyVectorEnv +from tianshou.policy import BasePolicy class Collector(object): @@ -97,8 +98,9 @@ def reset(self) -> None: """Reset all related variables in the collector.""" # use empty Batch for "state" so that self.data supports slicing # convert empty Batch to None when passing data to policy - self.data = Batch(obs={}, act={}, rew={}, done={}, - obs_next={}, info={}, policy={}) + self.data = Batch( + obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={} + ) self.reset_env() self.reset_buffer() self.reset_stat() @@ -115,8 +117,8 @@ def reset_env(self) -> None: """Reset all of the environments.""" obs = self.env.reset() if self.preprocess_fn: - obs = self.preprocess_fn( - obs=obs, env_id=np.arange(self.env_num)).get("obs", obs) + obs = self.preprocess_fn(obs=obs, + env_id=np.arange(self.env_num)).get("obs", obs) self.data.obs = obs def _reset_state(self, id: Union[int, List[int]]) -> None: @@ -184,8 +186,10 @@ def collect( ready_env_ids = np.arange(min(self.env_num, n_episode)) self.data = self.data[:min(self.env_num, n_episode)] else: - raise TypeError("Please specify at least one (either n_step or n_episode) " - "in AsyncCollector.collect().") + raise TypeError( + "Please specify at least one (either n_step or n_episode) " + "in AsyncCollector.collect()." 
+ ) start_time = time.time() @@ -203,7 +207,8 @@ def collect( # get the next action if random: self.data.update( - act=[self._action_space[i].sample() for i in ready_env_ids]) + act=[self._action_space[i].sample() for i in ready_env_ids] + ) else: if no_grad: with torch.no_grad(): # faster than retain_grad version @@ -225,19 +230,21 @@ def collect( # get bounded and remapped actions first (not saved into buffer) action_remap = self.policy.map_action(self.data.act) # step in env - obs_next, rew, done, info = self.env.step( - action_remap, ready_env_ids) # type: ignore + result = self.env.step(action_remap, ready_env_ids) # type: ignore + obs_next, rew, done, info = result self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) if self.preprocess_fn: - self.data.update(self.preprocess_fn( - obs_next=self.data.obs_next, - rew=self.data.rew, - done=self.data.done, - info=self.data.info, - policy=self.data.policy, - env_id=ready_env_ids, - )) + self.data.update( + self.preprocess_fn( + obs_next=self.data.obs_next, + rew=self.data.rew, + done=self.data.done, + info=self.data.info, + policy=self.data.policy, + env_id=ready_env_ids, + ) + ) if render: self.env.render() @@ -246,7 +253,8 @@ def collect( # add data into the buffer ptr, ep_rew, ep_len, ep_idx = self.buffer.add( - self.data, buffer_ids=ready_env_ids) + self.data, buffer_ids=ready_env_ids + ) # collect statistics step_count += len(ready_env_ids) @@ -263,7 +271,8 @@ def collect( obs_reset = self.env.reset(env_ind_global) if self.preprocess_fn: obs_reset = self.preprocess_fn( - obs=obs_reset, env_id=env_ind_global).get("obs", obs_reset) + obs=obs_reset, env_id=env_ind_global + ).get("obs", obs_reset) self.data.obs_next[env_ind_local] = obs_reset for i in env_ind_local: self._reset_state(i) @@ -290,13 +299,18 @@ def collect( self.collect_time += max(time.time() - start_time, 1e-9) if n_episode: - self.data = Batch(obs={}, act={}, rew={}, done={}, - obs_next={}, info={}, policy={}) + self.data = Batch( + obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={} + ) self.reset_env() if episode_count > 0: - rews, lens, idxs = list(map( - np.concatenate, [episode_rews, episode_lens, episode_start_indices])) + rews, lens, idxs = list( + map( + np.concatenate, + [episode_rews, episode_lens, episode_start_indices] + ) + ) else: rews, lens, idxs = np.array([]), np.array([], int), np.array([], int) @@ -377,8 +391,10 @@ def collect( elif n_episode is not None: assert n_episode > 0 else: - raise TypeError("Please specify at least one (either n_step or n_episode) " - "in AsyncCollector.collect().") + raise TypeError( + "Please specify at least one (either n_step or n_episode) " + "in AsyncCollector.collect()." 
+ ) warnings.warn("Using async setting may collect extra transitions into buffer.") ready_env_ids = self._ready_env_ids @@ -401,7 +417,8 @@ def collect( # get the next action if random: self.data.update( - act=[self._action_space[i].sample() for i in ready_env_ids]) + act=[self._action_space[i].sample() for i in ready_env_ids] + ) else: if no_grad: with torch.no_grad(): # faster than retain_grad version @@ -431,8 +448,8 @@ def collect( # get bounded and remapped actions first (not saved into buffer) action_remap = self.policy.map_action(self.data.act) # step in env - obs_next, rew, done, info = self.env.step( - action_remap, ready_env_ids) # type: ignore + result = self.env.step(action_remap, ready_env_ids) # type: ignore + obs_next, rew, done, info = result # change self.data here because ready_env_ids has changed ready_env_ids = np.array([i["env_id"] for i in info]) @@ -440,13 +457,15 @@ def collect( self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) if self.preprocess_fn: - self.data.update(self.preprocess_fn( - obs_next=self.data.obs_next, - rew=self.data.rew, - done=self.data.done, - info=self.data.info, - env_id=ready_env_ids, - )) + self.data.update( + self.preprocess_fn( + obs_next=self.data.obs_next, + rew=self.data.rew, + done=self.data.done, + info=self.data.info, + env_id=ready_env_ids, + ) + ) if render: self.env.render() @@ -455,7 +474,8 @@ def collect( # add data into the buffer ptr, ep_rew, ep_len, ep_idx = self.buffer.add( - self.data, buffer_ids=ready_env_ids) + self.data, buffer_ids=ready_env_ids + ) # collect statistics step_count += len(ready_env_ids) @@ -472,7 +492,8 @@ def collect( obs_reset = self.env.reset(env_ind_global) if self.preprocess_fn: obs_reset = self.preprocess_fn( - obs=obs_reset, env_id=env_ind_global).get("obs", obs_reset) + obs=obs_reset, env_id=env_ind_global + ).get("obs", obs_reset) self.data.obs_next[env_ind_local] = obs_reset for i in env_ind_local: self._reset_state(i) @@ -500,8 +521,12 @@ def collect( self.collect_time += max(time.time() - start_time, 1e-9) if episode_count > 0: - rews, lens, idxs = list(map( - np.concatenate, [episode_rews, episode_lens, episode_start_indices])) + rews, lens, idxs = list( + map( + np.concatenate, + [episode_rews, episode_lens, episode_start_indices] + ) + ) else: rews, lens, idxs = np.array([]), np.array([], int), np.array([], int) diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index 9f7d88a82..bd1dd5358 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -1,12 +1,13 @@ -import h5py -import torch import pickle -import numpy as np from copy import deepcopy from numbers import Number -from typing import Any, Dict, Union, Optional +from typing import Any, Dict, Optional, Union + +import h5py +import numpy as np +import torch -from tianshou.data.batch import _parse_value, Batch +from tianshou.data.batch import Batch, _parse_value def to_numpy(x: Any) -> Union[Batch, np.ndarray]: diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py index 5bb6fcc06..063675c53 100644 --- a/tianshou/data/utils/segtree.py +++ b/tianshou/data/utils/segtree.py @@ -1,6 +1,7 @@ +from typing import Optional, Union + import numpy as np from numba import njit -from typing import Union, Optional class SegmentTree: @@ -29,9 +30,7 @@ def __init__(self, size: int) -> None: def __len__(self) -> int: return self._size - def __getitem__( - self, index: Union[int, np.ndarray] - ) -> Union[float, np.ndarray]: + def __getitem__(self, index: 
Union[int, np.ndarray]) -> Union[float, np.ndarray]: """Return self[index].""" return self._value[index + self._bound] @@ -64,9 +63,8 @@ def reduce(self, start: int = 0, end: Optional[int] = None) -> float: end += self._size return _reduce(self._value, start + self._bound - 1, end + self._bound) - def get_prefix_sum_idx( - self, value: Union[float, np.ndarray] - ) -> Union[int, np.ndarray]: + def get_prefix_sum_idx(self, value: Union[float, + np.ndarray]) -> Union[int, np.ndarray]: r"""Find the index with given value. Return the minimum index for each ``v`` in ``value`` so that @@ -122,9 +120,7 @@ def _reduce(tree: np.ndarray, start: int, end: int) -> float: @njit -def _get_prefix_sum_idx( - value: np.ndarray, bound: int, sums: np.ndarray -) -> np.ndarray: +def _get_prefix_sum_idx(value: np.ndarray, bound: int, sums: np.ndarray) -> np.ndarray: """Numba version (v0.51), 5x speed up with size=100000 and bsz=64. vectorized np: 0.0923 (numpy best) -> 0.024 (now) diff --git a/tianshou/env/__init__.py b/tianshou/env/__init__.py index a25e06f86..c77c30c3f 100644 --- a/tianshou/env/__init__.py +++ b/tianshou/env/__init__.py @@ -1,6 +1,13 @@ -from tianshou.env.venvs import BaseVectorEnv, DummyVectorEnv, \ - SubprocVectorEnv, ShmemVectorEnv, RayVectorEnv +"""Env package.""" + from tianshou.env.maenv import MultiAgentEnv +from tianshou.env.venvs import ( + BaseVectorEnv, + DummyVectorEnv, + RayVectorEnv, + ShmemVectorEnv, + SubprocVectorEnv, +) __all__ = [ "BaseVectorEnv", diff --git a/tianshou/env/maenv.py b/tianshou/env/maenv.py index f6a454c3c..456bbca13 100644 --- a/tianshou/env/maenv.py +++ b/tianshou/env/maenv.py @@ -1,7 +1,8 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Tuple + import gym import numpy as np -from typing import Any, Dict, Tuple -from abc import ABC, abstractmethod class MultiAgentEnv(ABC, gym.Env): diff --git a/tianshou/env/utils.py b/tianshou/env/utils.py index 5c873ce36..cec159012 100644 --- a/tianshou/env/utils.py +++ b/tianshou/env/utils.py @@ -1,6 +1,7 @@ -import cloudpickle from typing import Any +import cloudpickle + class CloudpickleWrapper(object): """A cloudpickle wrapper used in SubprocVectorEnv.""" diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index f9349ff24..654f55b69 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -1,10 +1,15 @@ +from typing import Any, Callable, List, Optional, Tuple, Union + import gym import numpy as np -from typing import Any, List, Tuple, Union, Optional, Callable +from tianshou.env.worker import ( + DummyEnvWorker, + EnvWorker, + RayEnvWorker, + SubprocEnvWorker, +) from tianshou.utils import RunningMeanStd -from tianshou.env.worker import EnvWorker, DummyEnvWorker, SubprocEnvWorker, \ - RayEnvWorker class BaseVectorEnv(gym.Env): @@ -44,7 +49,7 @@ def seed(self, seed): Otherwise, the outputs of these envs may be the same with each other. - :param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith env. + :param env_fns: a list of callable envs, ``env_fns[i]()`` generates the i-th env. :param worker_fn: a callable worker, ``worker_fn(env_fns[i])`` generates a worker which contains the i-th env. :param int wait_num: use in asynchronous simulation if the time cost of @@ -56,7 +61,7 @@ def seed(self, seed): :param float timeout: use in asynchronous simulation same as above, in each vectorized step it only deal with those environments spending time within ``timeout`` seconds. 
- :param bool norm_obs: Whether to track mean/std of data and normalise observation + :param bool norm_obs: Whether to track mean/std of data and normalize observation on return. For now, observation normalization only support observation of type np.ndarray. :param obs_rms: class to track mean&std of observation. If not given, it will @@ -122,8 +127,9 @@ def __getattribute__(self, key: str) -> Any: ``action_space``. However, we would like the attribute lookup to go straight into the worker (in fact, this vector env's action_space is always None). """ - if key in ['metadata', 'reward_range', 'spec', 'action_space', - 'observation_space']: # reserved keys in gym.Env + if key in [ + 'metadata', 'reward_range', 'spec', 'action_space', 'observation_space' + ]: # reserved keys in gym.Env return self.__getattr__(key) else: return super().__getattribute__(key) @@ -137,7 +143,8 @@ def __getattr__(self, key: str) -> List[Any]: return [getattr(worker, key) for worker in self.workers] def _wrap_id( - self, id: Optional[Union[int, List[int], np.ndarray]] = None + self, + id: Optional[Union[int, List[int], np.ndarray]] = None ) -> Union[List[int], np.ndarray]: if id is None: return list(range(self.env_num)) @@ -222,7 +229,7 @@ def step( if action is not None: self._assert_id(id) assert len(action) == len(id) - for i, (act, env_id) in enumerate(zip(action, id)): + for act, env_id in zip(action, id): self.workers[env_id].send_action(act) self.waiting_conn.append(self.workers[env_id]) self.waiting_id.append(env_id) @@ -230,7 +237,8 @@ def step( ready_conns: List[EnvWorker] = [] while not ready_conns: ready_conns = self.worker_class.wait( - self.waiting_conn, self.wait_num, self.timeout) + self.waiting_conn, self.wait_num, self.timeout + ) result = [] for conn in ready_conns: waiting_index = self.waiting_conn.index(conn) @@ -246,13 +254,15 @@ def step( except ValueError: # different len(obs) obs_stack = np.array(obs_list, dtype=object) rew_stack, done_stack, info_stack = map( - np.stack, [rew_list, done_list, info_list]) + np.stack, [rew_list, done_list, info_list] + ) if self.obs_rms and self.update_obs_rms: self.obs_rms.update(obs_stack) return self.normalize_obs(obs_stack), rew_stack, done_stack, info_stack def seed( - self, seed: Optional[Union[int, List[int]]] = None + self, + seed: Optional[Union[int, List[int]]] = None ) -> List[Optional[List[int]]]: """Set the seed for all environments. @@ -279,7 +289,8 @@ def render(self, **kwargs: Any) -> List[Any]: if self.is_async and len(self.waiting_id) > 0: raise RuntimeError( f"Environments {self.waiting_id} are still stepping, cannot " - "render them now.") + "render them now." 
+ ) return [w.render(**kwargs) for w in self.workers] def close(self) -> None: @@ -324,6 +335,7 @@ class SubprocVectorEnv(BaseVectorEnv): """ def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None: + def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: return SubprocEnvWorker(fn, share_memory=False) @@ -341,6 +353,7 @@ class ShmemVectorEnv(BaseVectorEnv): """ def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None: + def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: return SubprocEnvWorker(fn, share_memory=True) diff --git a/tianshou/env/worker/__init__.py b/tianshou/env/worker/__init__.py index b9a20cf1e..1b1f37510 100644 --- a/tianshou/env/worker/__init__.py +++ b/tianshou/env/worker/__init__.py @@ -1,7 +1,7 @@ from tianshou.env.worker.base import EnvWorker from tianshou.env.worker.dummy import DummyEnvWorker -from tianshou.env.worker.subproc import SubprocEnvWorker from tianshou.env.worker.ray import RayEnvWorker +from tianshou.env.worker.subproc import SubprocEnvWorker __all__ = [ "EnvWorker", diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index dbf350a33..6fef9f68d 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -1,7 +1,8 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable, List, Optional, Tuple + import gym import numpy as np -from abc import ABC, abstractmethod -from typing import Any, List, Tuple, Optional, Callable class EnvWorker(ABC): @@ -11,7 +12,7 @@ def __init__(self, env_fn: Callable[[], gym.Env]) -> None: self._env_fn = env_fn self.is_closed = False self.result: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] - self.action_space = getattr(self, "action_space") + self.action_space = getattr(self, "action_space") # noqa: B009 @abstractmethod def __getattr__(self, key: str) -> Any: @@ -43,7 +44,9 @@ def step( @staticmethod def wait( - workers: List["EnvWorker"], wait_num: int, timeout: Optional[float] = None + workers: List["EnvWorker"], + wait_num: int, + timeout: Optional[float] = None ) -> List["EnvWorker"]: """Given a list of workers, return those ready ones.""" raise NotImplementedError diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py index d0579d162..9e68e9f04 100644 --- a/tianshou/env/worker/dummy.py +++ b/tianshou/env/worker/dummy.py @@ -1,6 +1,7 @@ +from typing import Any, Callable, List, Optional + import gym import numpy as np -from typing import Any, List, Callable, Optional from tianshou.env.worker import EnvWorker diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py index af7285b22..5d73763f2 100644 --- a/tianshou/env/worker/ray.py +++ b/tianshou/env/worker/ray.py @@ -1,6 +1,7 @@ +from typing import Any, Callable, List, Optional, Tuple + import gym import numpy as np -from typing import Any, List, Callable, Tuple, Optional from tianshou.env.worker import EnvWorker diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index 8b89b6c34..8ef264360 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -1,15 +1,15 @@ -import gym -import time import ctypes -import numpy as np +import time from collections import OrderedDict -from multiprocessing.context import Process from multiprocessing import Array, Pipe, connection -from typing import Any, List, Tuple, Union, Callable, Optional +from multiprocessing.context import Process +from typing import Any, Callable, List, Optional, Tuple, Union -from tianshou.env.worker import 
EnvWorker -from tianshou.env.utils import CloudpickleWrapper +import gym +import numpy as np +from tianshou.env.utils import CloudpickleWrapper +from tianshou.env.worker import EnvWorker _NP_TO_CT = { np.bool_: ctypes.c_bool, @@ -62,6 +62,7 @@ def _worker( env_fn_wrapper: CloudpickleWrapper, obs_bufs: Optional[Union[dict, tuple, ShArray]] = None, ) -> None: + def _encode_obs( obs: Union[dict, tuple, np.ndarray], buffer: Union[dict, tuple, ShArray] ) -> None: @@ -144,6 +145,7 @@ def __getattr__(self, key: str) -> Any: return self.parent_remote.recv() def _decode_obs(self) -> Union[dict, tuple, np.ndarray]: + def decode_obs( buffer: Optional[Union[dict, tuple, ShArray]] ) -> Union[dict, tuple, np.ndarray]: diff --git a/tianshou/exploration/random.py b/tianshou/exploration/random.py index a59085809..25316e98d 100644 --- a/tianshou/exploration/random.py +++ b/tianshou/exploration/random.py @@ -1,6 +1,7 @@ -import numpy as np from abc import ABC, abstractmethod -from typing import Union, Optional, Sequence +from typing import Optional, Sequence, Union + +import numpy as np class BaseNoise(ABC, object): @@ -20,7 +21,7 @@ def __call__(self, size: Sequence[int]) -> np.ndarray: class GaussianNoise(BaseNoise): - """The vanilla gaussian process, for exploration in DDPG by default.""" + """The vanilla Gaussian process, for exploration in DDPG by default.""" def __init__(self, mu: float = 0.0, sigma: float = 1.0) -> None: super().__init__() @@ -45,7 +46,7 @@ class OUNoise(BaseNoise): For required parameters, you can refer to the stackoverflow page. However, our experiment result shows that (similar to OpenAI SpinningUp) using - vanilla gaussian process has little difference from using the + vanilla Gaussian process has little difference from using the Ornstein-Uhlenbeck process. """ @@ -74,7 +75,8 @@ def __call__(self, size: Sequence[int], mu: Optional[float] = None) -> np.ndarra Return an numpy array which size is equal to ``size``. """ if self._x is None or isinstance( - self._x, np.ndarray) and self._x.shape != size: + self._x, np.ndarray + ) and self._x.shape != size: self._x = 0.0 if mu is None: mu = self._mu diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 8a9c6478e..6a842356f 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -1,3 +1,6 @@ +"""Policy package.""" +# isort:skip_file + from tianshou.policy.base import BasePolicy from tianshou.policy.random import RandomPolicy from tianshou.policy.modelfree.dqn import DQNPolicy @@ -22,7 +25,6 @@ from tianshou.policy.modelbased.psrl import PSRLPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager - __all__ = [ "BasePolicy", "RandomPolicy", diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 6dc3bdda8..feb6479ce 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -1,19 +1,20 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Optional, Tuple, Union + import gym -import torch import numpy as np -from torch import nn +import torch +from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete from numba import njit -from abc import ABC, abstractmethod -from typing import Any, Dict, Tuple, Union, Optional, Callable -from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary +from torch import nn -from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy +from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as class BasePolicy(ABC, nn.Module): """The base class for any RL policy. 
- Tianshou aims to modularizing RL algorithms. It comes into several classes of + Tianshou aims to modularize RL algorithms. It comes into several classes of policies in Tianshou. All of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`. @@ -84,9 +85,8 @@ def set_agent_id(self, agent_id: int) -> None: """Set self.agent_id = agent_id, for MARL.""" self.agent_id = agent_id - def exploration_noise( - self, act: Union[np.ndarray, Batch], batch: Batch - ) -> Union[np.ndarray, Batch]: + def exploration_noise(self, act: Union[np.ndarray, Batch], + batch: Batch) -> Union[np.ndarray, Batch]: """Modify the action from policy.forward with exploration noise. :param act: a data batch or numpy.ndarray which is the action taken by @@ -216,9 +216,8 @@ def post_process_fn( if hasattr(buffer, "update_weight") and hasattr(batch, "weight"): buffer.update_weight(indices, batch.weight) - def update( - self, sample_size: int, buffer: Optional[ReplayBuffer], **kwargs: Any - ) -> Dict[str, Any]: + def update(self, sample_size: int, buffer: Optional[ReplayBuffer], + **kwargs: Any) -> Dict[str, Any]: """Update the policy network and replay buffer. It includes 3 function steps: process_fn, learn, and post_process_fn. In @@ -286,7 +285,7 @@ def compute_episodic_return( :param Batch batch: a data batch which contains several episodes of data in sequential order. Mind that the end of each finished episode of batch should be marked by done flag, unfinished (or collecting) episodes will be - recongized by buffer.unfinished_index(). + recognized by buffer.unfinished_index(). :param numpy.ndarray indices: tell batch's location in buffer, batch is equal to buffer[indices]. :param np.ndarray v_s_: the value function of all next states :math:`V(s')`. diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index f94aa1d39..a5321acdf 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -1,7 +1,8 @@ -import torch +from typing import Any, Dict, Optional, Union + import numpy as np +import torch import torch.nn.functional as F -from typing import Any, Dict, Union, Optional from tianshou.data import Batch, to_torch from tianshou.policy import BasePolicy diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 671b8b080..d9cac65df 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -1,11 +1,12 @@ import math -import torch +from typing import Any, Dict, Optional, Union + import numpy as np +import torch import torch.nn.functional as F -from typing import Any, Dict, Union, Optional -from tianshou.policy import DQNPolicy from tianshou.data import Batch, ReplayBuffer, to_torch +from tianshou.policy import DQNPolicy class DiscreteBCQPolicy(DQNPolicy): @@ -14,7 +15,7 @@ class DiscreteBCQPolicy(DQNPolicy): :param torch.nn.Module model: a model following the rules in :class:`~tianshou.policy.BasePolicy`. (s -> q_value) :param torch.nn.Module imitator: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> imtation_logits) + :class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits) :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. :param float discount_factor: in [0, 1]. :param int estimation_step: the number of steps to look ahead. Default to 1. @@ -22,7 +23,7 @@ class DiscreteBCQPolicy(DQNPolicy): :param float eval_eps: the epsilon-greedy noise added in evaluation. 
:param float unlikely_action_threshold: the threshold (tau) for unlikely actions, as shown in Equ. (17) in the paper. Default to 0.3. - :param float imitation_logits_penalty: reguralization weight for imitation + :param float imitation_logits_penalty: regularization weight for imitation logits. Default to 1e-2. :param bool reward_normalization: normalize the reward to Normal(0, 1). Default to False. @@ -47,8 +48,10 @@ def __init__( reward_normalization: bool = False, **kwargs: Any, ) -> None: - super().__init__(model, optim, discount_factor, estimation_step, - target_update_freq, reward_normalization, **kwargs) + super().__init__( + model, optim, discount_factor, estimation_step, target_update_freq, + reward_normalization, **kwargs + ) assert target_update_freq > 0, "BCQ needs target network setting." self.imitator = imitator assert 0.0 <= unlikely_action_threshold < 1.0, \ @@ -93,8 +96,12 @@ def forward( # type: ignore mask = (ratio < self._log_tau).float() action = (q_value - np.inf * mask).argmax(dim=-1) - return Batch(act=action, state=state, q_value=q_value, - imitation_logits=imitation_logits) + return Batch( + act=action, + state=state, + q_value=q_value, + imitation_logits=imitation_logits + ) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: if self._iter % self._freq == 0: @@ -108,7 +115,9 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: act = to_torch(batch.act, dtype=torch.long, device=target_q.device) q_loss = F.smooth_l1_loss(current_q, target_q) i_loss = F.nll_loss( - F.log_softmax(imitation_logits, dim=-1), act) # type: ignore + F.log_softmax(imitation_logits, dim=-1), + act # type: ignore + ) reg_loss = imitation_logits.pow(2).mean() loss = q_loss + i_loss + self._weight_reg * reg_loss diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index c6e1b50d0..ad4ed19a3 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -1,10 +1,11 @@ -import torch +from typing import Any, Dict + import numpy as np +import torch import torch.nn.functional as F -from typing import Any, Dict -from tianshou.policy import QRDQNPolicy from tianshou.data import Batch, to_torch +from tianshou.policy import QRDQNPolicy class DiscreteCQLPolicy(QRDQNPolicy): @@ -40,8 +41,10 @@ def __init__( min_q_weight: float = 10.0, **kwargs: Any, ) -> None: - super().__init__(model, optim, discount_factor, num_quantiles, estimation_step, - target_update_freq, reward_normalization, **kwargs) + super().__init__( + model, optim, discount_factor, num_quantiles, estimation_step, + target_update_freq, reward_normalization, **kwargs + ) self._min_q_weight = min_q_weight def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: @@ -55,9 +58,10 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: target_dist = batch.returns.unsqueeze(1) # calculate each element's difference between curr_dist and target_dist u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") - huber_loss = (u * ( - self.tau_hat - (target_dist - curr_dist).detach().le(0.).float() - ).abs()).sum(-1).mean(1) + huber_loss = ( + u * (self.tau_hat - + (target_dist - curr_dist).detach().le(0.).float()).abs() + ).sum(-1).mean(1) qr_loss = (huber_loss * weight).mean() # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index 05e4b2655..6a149509e 
100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -1,11 +1,12 @@ -import torch from copy import deepcopy from typing import Any, Dict + +import torch import torch.nn.functional as F from torch.distributions import Categorical -from tianshou.policy.modelfree.pg import PGPolicy from tianshou.data import Batch, to_torch, to_torch_as +from tianshou.policy.modelfree.pg import PGPolicy class DiscreteCRRPolicy(PGPolicy): diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py index b438dbcbc..e5ea5c9a5 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/policy/modelbased/psrl.py @@ -1,6 +1,7 @@ -import torch +from typing import Any, Dict, Optional, Tuple, Union + import numpy as np -from typing import Any, Dict, Tuple, Union, Optional +import torch from tianshou.data import Batch from tianshou.policy import BasePolicy @@ -70,14 +71,16 @@ def observe( sum_count = self.rew_count + rew_count self.rew_mean = (self.rew_mean * self.rew_count + rew_sum) / sum_count self.rew_square_sum += rew_square_sum - raw_std2 = self.rew_square_sum / sum_count - self.rew_mean ** 2 - self.rew_std = np.sqrt(1 / ( - sum_count / (raw_std2 + self.__eps) + 1 / self.rew_std_prior ** 2)) + raw_std2 = self.rew_square_sum / sum_count - self.rew_mean**2 + self.rew_std = np.sqrt( + 1 / (sum_count / (raw_std2 + self.__eps) + 1 / self.rew_std_prior**2) + ) self.rew_count = sum_count def sample_trans_prob(self) -> np.ndarray: sample_prob = torch.distributions.Dirichlet( - torch.from_numpy(self.trans_count)).sample().numpy() + torch.from_numpy(self.trans_count) + ).sample().numpy() return sample_prob def sample_reward(self) -> np.ndarray: @@ -168,12 +171,10 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(**kwargs) - assert ( - 0.0 <= discount_factor <= 1.0 - ), "discount factor should be in [0, 1]" + assert (0.0 <= discount_factor <= 1.0), "discount factor should be in [0, 1]" self.model = PSRLModel( - trans_count_prior, rew_mean_prior, rew_std_prior, - discount_factor, epsilon) + trans_count_prior, rew_mean_prior, rew_std_prior, discount_factor, epsilon + ) self._add_done_loop = add_done_loop def forward( @@ -195,9 +196,7 @@ def forward( act = self.model(batch.obs, state=state, info=batch.info) return Batch(act=act) - def learn( - self, batch: Batch, *args: Any, **kwargs: Any - ) -> Dict[str, float]: + def learn(self, batch: Batch, *args: Any, **kwargs: Any) -> Dict[str, float]: n_s, n_a = self.model.n_state, self.model.n_action trans_count = np.zeros((n_s, n_a, n_s)) rew_sum = np.zeros((n_s, n_a)) @@ -207,7 +206,7 @@ def learn( obs, act, obs_next = b.obs, b.act, b.obs_next trans_count[obs, act, obs_next] += 1 rew_sum[obs, act] += b.rew - rew_square_sum[obs, act] += b.rew ** 2 + rew_square_sum[obs, act] += b.rew**2 rew_count[obs, act] += 1 if self._add_done_loop and b.done: # special operation for terminal states: add a self-loop diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 3e05ce0b6..e44b58b58 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -1,11 +1,12 @@ -import torch +from typing import Any, Dict, List, Optional, Type + import numpy as np -from torch import nn +import torch import torch.nn.functional as F -from typing import Any, Dict, List, Type, Optional +from torch import nn -from tianshou.policy import PGPolicy from tianshou.data import Batch, ReplayBuffer, to_torch_as +from tianshou.policy import PGPolicy class A2CPolicy(PGPolicy): @@ 
-96,8 +97,14 @@ def _compute_returns( v_s = v_s * np.sqrt(self.ret_rms.var + self._eps) v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) unnormalized_returns, advantages = self.compute_episodic_return( - batch, buffer, indices, v_s_, v_s, - gamma=self._gamma, gae_lambda=self._lambda) + batch, + buffer, + indices, + v_s_, + v_s, + gamma=self._gamma, + gae_lambda=self._lambda + ) if self._rew_norm: batch.returns = unnormalized_returns / \ np.sqrt(self.ret_rms.var + self._eps) @@ -130,7 +137,8 @@ def learn( # type: ignore if self._grad_norm: # clip large gradient nn.utils.clip_grad_norm_( set(self.actor.parameters()).union(self.critic.parameters()), - max_norm=self._grad_norm) + max_norm=self._grad_norm + ) self.optim.step() actor_losses.append(actor_loss.item()) vf_losses.append(vf_loss.item()) diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index a664096f5..4e79eb356 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -1,9 +1,10 @@ -import torch -import numpy as np from typing import Any, Dict, Optional -from tianshou.policy import DQNPolicy +import numpy as np +import torch + from tianshou.data import Batch, ReplayBuffer +from tianshou.policy import DQNPolicy class C51Policy(DQNPolicy): @@ -44,8 +45,10 @@ def __init__( reward_normalization: bool = False, **kwargs: Any, ) -> None: - super().__init__(model, optim, discount_factor, estimation_step, - target_update_freq, reward_normalization, **kwargs) + super().__init__( + model, optim, discount_factor, estimation_step, target_update_freq, + reward_normalization, **kwargs + ) assert num_atoms > 1, "num_atoms should be greater than 1" assert v_min < v_max, "v_max should be larger than v_min" self._num_atoms = num_atoms @@ -77,9 +80,10 @@ def _target_dist(self, batch: Batch) -> torch.Tensor: target_support = batch.returns.clamp(self._v_min, self._v_max) # An amazing trick for calculating the projection gracefully. 
# ref: https://github.com/ShangtongZhang/DeepRL - target_dist = (1 - (target_support.unsqueeze(1) - - self.support.view(1, -1, 1)).abs() / self.delta_z - ).clamp(0, 1) * next_dist.unsqueeze(1) + target_dist = ( + 1 - (target_support.unsqueeze(1) - self.support.view(1, -1, 1)).abs() / + self.delta_z + ).clamp(0, 1) * next_dist.unsqueeze(1) return target_dist.sum(-1) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: @@ -92,7 +96,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: curr_dist = self(batch).logits act = batch.act curr_dist = curr_dist[np.arange(len(act)), act, :] - cross_entropy = - (target_dist * torch.log(curr_dist + 1e-8)).sum(1) + cross_entropy = -(target_dist * torch.log(curr_dist + 1e-8)).sum(1) loss = (cross_entropy * weight).mean() # ref: https://github.com/Kaixhin/Rainbow/blob/master/agent.py L94-100 batch.weight = cross_entropy.detach() # prio-buffer diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index fc4a622cc..18bb81b6b 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -1,12 +1,13 @@ -import torch import warnings -import numpy as np from copy import deepcopy -from typing import Any, Dict, Tuple, Union, Optional +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch -from tianshou.policy import BasePolicy -from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.data import Batch, ReplayBuffer +from tianshou.exploration import BaseNoise, GaussianNoise +from tianshou.policy import BasePolicy class DDPGPolicy(BasePolicy): @@ -53,8 +54,11 @@ def __init__( action_bound_method: str = "clip", **kwargs: Any, ) -> None: - super().__init__(action_scaling=action_scaling, - action_bound_method=action_bound_method, **kwargs) + super().__init__( + action_scaling=action_scaling, + action_bound_method=action_bound_method, + **kwargs + ) assert action_bound_method != "tanh", "tanh mapping is not supported" \ "in policies where action is used as input of critic , because" \ "raw action in range (-inf, inf) will cause instability in training" @@ -96,21 +100,21 @@ def sync_weight(self) -> None: for o, n in zip(self.critic_old.parameters(), self.critic.parameters()): o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) - def _target_q( - self, buffer: ReplayBuffer, indices: np.ndarray - ) -> torch.Tensor: + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: batch = buffer[indices] # batch.obs_next: s_{t+n} target_q = self.critic_old( batch.obs_next, - self(batch, model='actor_old', input='obs_next').act) + self(batch, model='actor_old', input='obs_next').act + ) return target_q def process_fn( self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray ) -> Batch: batch = self.compute_nstep_return( - batch, buffer, indices, self._target_q, - self._gamma, self._n_step, self._rew_norm) + batch, buffer, indices, self._target_q, self._gamma, self._n_step, + self._rew_norm + ) return batch def forward( @@ -156,8 +160,7 @@ def _mse_optimizer( def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: # critic - td, critic_loss = self._mse_optimizer( - batch, self.critic, self.critic_optim) + td, critic_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) batch.weight = td # prio-buffer # actor action = self(batch).act @@ -171,9 +174,8 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: "loss/critic": critic_loss.item(), } - def exploration_noise( - self, act: 
Union[np.ndarray, Batch], batch: Batch - ) -> Union[np.ndarray, Batch]: + def exploration_noise(self, act: Union[np.ndarray, Batch], + batch: Batch) -> Union[np.ndarray, Batch]: if self._noise is None: return act if isinstance(act, np.ndarray): diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 18ac9fa12..7c580f37a 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -1,10 +1,11 @@ -import torch +from typing import Any, Dict, Optional, Tuple, Union + import numpy as np +import torch from torch.distributions import Categorical -from typing import Any, Dict, Tuple, Union, Optional -from tianshou.policy import SACPolicy from tianshou.data import Batch, ReplayBuffer, to_torch +from tianshou.policy import SACPolicy class DiscreteSACPolicy(SACPolicy): @@ -24,7 +25,7 @@ class DiscreteSACPolicy(SACPolicy): :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy regularization coefficient. Default to 0.2. If a tuple (target_entropy, log_alpha, alpha_optim) is provided, the - alpha is automatatically tuned. + alpha is automatically tuned. :param bool reward_normalization: normalize the reward to Normal(0, 1). Default to False. @@ -50,9 +51,21 @@ def __init__( **kwargs: Any, ) -> None: super().__init__( - actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - tau, gamma, alpha, reward_normalization, estimation_step, - action_scaling=False, action_bound_method="", **kwargs) + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + tau, + gamma, + alpha, + reward_normalization, + estimation_step, + action_scaling=False, + action_bound_method="", + **kwargs + ) self._alpha: Union[float, torch.Tensor] def forward( # type: ignore @@ -68,9 +81,7 @@ def forward( # type: ignore act = dist.sample() return Batch(logits=logits, act=act, state=h, dist=dist) - def _target_q( - self, buffer: ReplayBuffer, indices: np.ndarray - ) -> torch.Tensor: + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: batch = buffer[indices] # batch.obs: s_{t+n} obs_next_result = self(batch, input="obs_next") dist = obs_next_result.dist @@ -85,7 +96,8 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: weight = batch.pop("weight", 1.0) target_q = batch.returns.flatten() act = to_torch( - batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long) + batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long + ) # critic 1 current_q1 = self.critic1(batch.obs).gather(1, act).flatten() @@ -139,7 +151,6 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: return result - def exploration_noise( - self, act: Union[np.ndarray, Batch], batch: Batch - ) -> Union[np.ndarray, Batch]: + def exploration_noise(self, act: Union[np.ndarray, Batch], + batch: Batch) -> Union[np.ndarray, Batch]: return act diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 03c39d171..850490998 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -1,10 +1,11 @@ -import torch -import numpy as np from copy import deepcopy -from typing import Any, Dict, Union, Optional +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch +from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as from tianshou.policy import BasePolicy -from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy class DQNPolicy(BasePolicy): @@ -96,8 
+97,9 @@ def process_fn( :meth:`~tianshou.policy.BasePolicy.compute_nstep_return`. """ batch = self.compute_nstep_return( - batch, buffer, indices, self._target_q, - self._gamma, self._n_step, self._rew_norm) + batch, buffer, indices, self._target_q, self._gamma, self._n_step, + self._rew_norm + ) return batch def compute_q_value( @@ -173,9 +175,8 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: self._iter += 1 return {"loss": loss.item()} - def exploration_noise( - self, act: Union[np.ndarray, Batch], batch: Batch - ) -> Union[np.ndarray, Batch]: + def exploration_noise(self, act: Union[np.ndarray, Batch], + batch: Batch) -> Union[np.ndarray, Batch]: if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): bsz = len(act) rand_mask = np.random.rand(bsz) < self.eps diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index dc03365e9..3c015b3d0 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -1,10 +1,11 @@ -import torch +from typing import Any, Dict, Optional, Union + import numpy as np +import torch import torch.nn.functional as F -from typing import Any, Dict, Optional, Union +from tianshou.data import Batch, ReplayBuffer, to_numpy from tianshou.policy import DQNPolicy, QRDQNPolicy -from tianshou.data import Batch, to_numpy, ReplayBuffer from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction @@ -88,12 +89,14 @@ def forward( ) else: (logits, _, quantiles_tau), h = model( - obs_, propose_model=self.propose_model, fractions=fractions, - state=state, info=batch.info + obs_, + propose_model=self.propose_model, + fractions=fractions, + state=state, + info=batch.info ) - weighted_logits = ( - fractions.taus[:, 1:] - fractions.taus[:, :-1] - ).unsqueeze(1) * logits + weighted_logits = (fractions.taus[:, 1:] - + fractions.taus[:, :-1]).unsqueeze(1) * logits q = DQNPolicy.compute_q_value( self, weighted_logits.sum(2), getattr(obs, "mask", None) ) @@ -101,7 +104,10 @@ def forward( self.max_action_num = q.shape[1] act = to_numpy(q.max(dim=1)[1]) return Batch( - logits=logits, act=act, state=h, fractions=fractions, + logits=logits, + act=act, + state=h, + fractions=fractions, quantiles_tau=quantiles_tau ) @@ -117,9 +123,12 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: target_dist = batch.returns.unsqueeze(1) # calculate each element's difference between curr_dist and target_dist u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") - huber_loss = (u * ( - tau_hats.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.).float() - ).abs()).sum(-1).mean(1) + huber_loss = ( + u * ( + tau_hats.unsqueeze(2) - + (target_dist - curr_dist).detach().le(0.).float() + ).abs() + ).sum(-1).mean(1) quantile_loss = (huber_loss * weight).mean() # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 @@ -131,16 +140,18 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/fqf_agent.py L169 values_1 = sa_quantiles - sa_quantile_hats[:, :-1] - signs_1 = sa_quantiles > torch.cat([ - sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], dim=1) + signs_1 = sa_quantiles > torch.cat( + [sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], dim=1 + ) values_2 = sa_quantiles - sa_quantile_hats[:, 1:] - signs_2 = sa_quantiles < torch.cat([ - sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], dim=1) + signs_2 = sa_quantiles < 
torch.cat( + [sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], dim=1 + ) gradient_of_taus = ( - torch.where(signs_1, values_1, -values_1) - + torch.where(signs_2, values_2, -values_2) + torch.where(signs_1, values_1, -values_1) + + torch.where(signs_2, values_2, -values_2) ) fraction_loss = (gradient_of_taus * taus[:, 1:-1]).sum(1).mean() # calculate entropy loss diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index 4c54d3563..9d9777b98 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -1,10 +1,11 @@ -import torch +from typing import Any, Dict, Optional, Union + import numpy as np +import torch import torch.nn.functional as F -from typing import Any, Dict, Optional, Union -from tianshou.policy import QRDQNPolicy from tianshou.data import Batch, to_numpy +from tianshou.policy import QRDQNPolicy class IQNPolicy(QRDQNPolicy): @@ -45,8 +46,10 @@ def __init__( reward_normalization: bool = False, **kwargs: Any, ) -> None: - super().__init__(model, optim, discount_factor, sample_size, estimation_step, - target_update_freq, reward_normalization, **kwargs) + super().__init__( + model, optim, discount_factor, sample_size, estimation_step, + target_update_freq, reward_normalization, **kwargs + ) assert sample_size > 1, "sample_size should be greater than 1" assert online_sample_size > 1, "online_sample_size should be greater than 1" assert target_sample_size > 1, "target_sample_size should be greater than 1" @@ -71,9 +74,8 @@ def forward( model = getattr(self, model) obs = batch[input] obs_ = obs.obs if hasattr(obs, "obs") else obs - (logits, taus), h = model( - obs_, sample_size=sample_size, state=state, info=batch.info - ) + (logits, + taus), h = model(obs_, sample_size=sample_size, state=state, info=batch.info) q = self.compute_q_value(logits, getattr(obs, "mask", None)) if not hasattr(self, "max_action_num"): self.max_action_num = q.shape[1] @@ -92,9 +94,11 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: target_dist = batch.returns.unsqueeze(1) # calculate each element's difference between curr_dist and target_dist u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") - huber_loss = (u * ( - taus.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.).float() - ).abs()).sum(-1).mean(1) + huber_loss = ( + u * + (taus.unsqueeze(2) - + (target_dist - curr_dist).detach().le(0.).float()).abs() + ).sum(-1).mean(1) loss = (huber_loss * weight).mean() # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index 76da4bf0b..758093d1a 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -1,13 +1,13 @@ -import torch +from typing import Any, Dict, List, Type + import numpy as np -from torch import nn +import torch import torch.nn.functional as F -from typing import Any, Dict, List, Type +from torch import nn from torch.distributions import kl_divergence - -from tianshou.policy import A2CPolicy from tianshou.data import Batch, ReplayBuffer +from tianshou.policy import A2CPolicy class NPGPolicy(A2CPolicy): @@ -82,7 +82,7 @@ def learn( # type: ignore self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any ) -> Dict[str, List[float]]: actor_losses, vf_losses, kls = [], [], [] - for step in range(repeat): + for _ in range(repeat): for b in batch.split(batch_size, merge_last=True): # optimize actor # direction: calculate villia gradient @@ 
-91,7 +91,8 @@ def learn( # type: ignore log_prob = log_prob.reshape(log_prob.size(0), -1).transpose(0, 1) actor_loss = -(log_prob * b.adv).mean() flat_grads = self._get_flat_grad( - actor_loss, self.actor, retain_graph=True).detach() + actor_loss, self.actor, retain_graph=True + ).detach() # direction: calculate natural gradient with torch.no_grad(): @@ -101,12 +102,14 @@ def learn( # type: ignore # calculate first order gradient of kl with respect to theta flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True) search_direction = -self._conjugate_gradients( - flat_grads, flat_kl_grad, nsteps=10) + flat_grads, flat_kl_grad, nsteps=10 + ) # step with torch.no_grad(): - flat_params = torch.cat([param.data.view(-1) - for param in self.actor.parameters()]) + flat_params = torch.cat( + [param.data.view(-1) for param in self.actor.parameters()] + ) new_flat_params = flat_params + self._step_size * search_direction self._set_from_flat_params(self.actor, new_flat_params) new_dist = self(b).dist @@ -138,8 +141,8 @@ def _MVP(self, v: torch.Tensor, flat_kl_grad: torch.Tensor) -> torch.Tensor: """Matrix vector product.""" # caculate second order gradient of kl with respect to theta kl_v = (flat_kl_grad * v).sum() - flat_kl_grad_grad = self._get_flat_grad( - kl_v, self.actor, retain_graph=True).detach() + flat_kl_grad_grad = self._get_flat_grad(kl_v, self.actor, + retain_graph=True).detach() return flat_kl_grad_grad + v * self._damping def _conjugate_gradients( @@ -154,7 +157,7 @@ def _conjugate_gradients( # Note: should be 'r, p = b - MVP(x)', but for x=0, MVP(x)=0. # Change if doing warm start. rdotr = r.dot(r) - for i in range(nsteps): + for _ in range(nsteps): z = self._MVP(p, flat_kl_grad) alpha = rdotr / p.dot(z) x += alpha * p @@ -179,6 +182,7 @@ def _set_from_flat_params( for param in model.parameters(): flat_size = int(np.prod(list(param.size()))) param.data.copy_( - flat_params[prev_ind:prev_ind + flat_size].view(param.size())) + flat_params[prev_ind:prev_ind + flat_size].view(param.size()) + ) prev_ind += flat_size return model diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 1eb4dde4a..a64828874 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -1,10 +1,11 @@ -import torch +from typing import Any, Dict, List, Optional, Type, Union + import numpy as np -from typing import Any, Dict, List, Type, Union, Optional +import torch +from tianshou.data import Batch, ReplayBuffer, to_torch_as from tianshou.policy import BasePolicy from tianshou.utils import RunningMeanStd -from tianshou.data import Batch, ReplayBuffer, to_torch_as class PGPolicy(BasePolicy): @@ -47,8 +48,11 @@ def __init__( deterministic_eval: bool = False, **kwargs: Any, ) -> None: - super().__init__(action_scaling=action_scaling, - action_bound_method=action_bound_method, **kwargs) + super().__init__( + action_scaling=action_scaling, + action_bound_method=action_bound_method, + **kwargs + ) self.actor = model self.optim = optim self.lr_scheduler = lr_scheduler @@ -73,7 +77,8 @@ def process_fn( """ v_s_ = np.full(indices.shape, self.ret_rms.mean) unnormalized_returns, _ = self.compute_episodic_return( - batch, buffer, indices, v_s_=v_s_, gamma=self._gamma, gae_lambda=1.0) + batch, buffer, indices, v_s_=v_s_, gamma=self._gamma, gae_lambda=1.0 + ) if self._rew_norm: batch.returns = (unnormalized_returns - self.ret_rms.mean) / \ np.sqrt(self.ret_rms.var + self._eps) diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 
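# --- editor's sketch (not part of the patch): NPGPolicy._conjugate_gradients
# above solves F x = g without materialising the Fisher matrix F, using only
# Fisher-vector products (the _MVP method).  A generic, self-contained version
# of the same idea; the function name and the toy matrix are mine.
from typing import Callable

import torch


def conjugate_gradient(mvp: Callable[[torch.Tensor], torch.Tensor],
                       b: torch.Tensor, nsteps: int = 10,
                       residual_tol: float = 1e-10) -> torch.Tensor:
    """Approximately solve A x = b, where mvp(v) returns A @ v."""
    x = torch.zeros_like(b)
    r = b.clone()  # residual b - A x, with x = 0
    p = b.clone()  # search direction
    rdotr = r.dot(r)
    for _ in range(nsteps):
        z = mvp(p)
        alpha = rdotr / p.dot(z)
        x += alpha * p
        r -= alpha * z
        new_rdotr = r.dot(r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x


# toy usage with an explicit SPD matrix standing in for the Fisher matrix
A = torch.tensor([[4.0, 1.0], [1.0, 3.0]])
b = torch.tensor([1.0, 2.0])
x = conjugate_gradient(lambda v: A @ v, b)
print(torch.allclose(A @ x, b, atol=1e-4))  # True
# --- end of sketch; the diff continues below ---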
824e19ad5..e1e17aa2f 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -1,10 +1,11 @@ -import torch +from typing import Any, Dict, List, Optional, Type + import numpy as np +import torch from torch import nn -from typing import Any, Dict, List, Type, Optional -from tianshou.policy import A2CPolicy from tianshou.data import Batch, ReplayBuffer, to_torch_as +from tianshou.policy import A2CPolicy class PPOPolicy(A2CPolicy): @@ -124,8 +125,8 @@ def learn( # type: ignore # calculate loss for critic value = self.critic(b.obs).flatten() if self._value_clip: - v_clip = b.v_s + (value - b.v_s).clamp( - -self._eps_clip, self._eps_clip) + v_clip = b.v_s + (value - + b.v_s).clamp(-self._eps_clip, self._eps_clip) vf1 = (b.returns - value).pow(2) vf2 = (b.returns - v_clip).pow(2) vf_loss = torch.max(vf1, vf2).mean() @@ -140,7 +141,8 @@ def learn( # type: ignore if self._grad_norm: # clip large gradient nn.utils.clip_grad_norm_( set(self.actor.parameters()).union(self.critic.parameters()), - max_norm=self._grad_norm) + max_norm=self._grad_norm + ) self.optim.step() clip_losses.append(clip_loss.item()) vf_losses.append(vf_loss.item()) diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 91d0f4bf0..fe3e101f7 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -1,11 +1,12 @@ -import torch import warnings +from typing import Any, Dict, Optional + import numpy as np +import torch import torch.nn.functional as F -from typing import Any, Dict, Optional -from tianshou.policy import DQNPolicy from tianshou.data import Batch, ReplayBuffer +from tianshou.policy import DQNPolicy class QRDQNPolicy(DQNPolicy): @@ -40,13 +41,16 @@ def __init__( reward_normalization: bool = False, **kwargs: Any, ) -> None: - super().__init__(model, optim, discount_factor, estimation_step, - target_update_freq, reward_normalization, **kwargs) + super().__init__( + model, optim, discount_factor, estimation_step, target_update_freq, + reward_normalization, **kwargs + ) assert num_quantiles > 1, "num_quantiles should be greater than 1" self._num_quantiles = num_quantiles tau = torch.linspace(0, 1, self._num_quantiles + 1) self.tau_hat = torch.nn.Parameter( - ((tau[:-1] + tau[1:]) / 2).view(1, -1, 1), requires_grad=False) + ((tau[:-1] + tau[1:]) / 2).view(1, -1, 1), requires_grad=False + ) warnings.filterwarnings("ignore", message="Using a target size") def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: @@ -77,9 +81,10 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: target_dist = batch.returns.unsqueeze(1) # calculate each element's difference between curr_dist and target_dist u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") - huber_loss = (u * ( - self.tau_hat - (target_dist - curr_dist).detach().le(0.).float() - ).abs()).sum(-1).mean(1) + huber_loss = ( + u * (self.tau_hat - + (target_dist - curr_dist).detach().le(0.).float()).abs() + ).sum(-1).mean(1) loss = (huber_loss * weight).mean() # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/policy/modelfree/rainbow.py index 7aa4c682d..9028258d7 100644 --- a/tianshou/policy/modelfree/rainbow.py +++ b/tianshou/policy/modelfree/rainbow.py @@ -1,7 +1,7 @@ from typing import Any, Dict -from tianshou.policy import C51Policy from tianshou.data import Batch +from tianshou.policy import C51Policy from 
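# --- editor's sketch (not part of the patch): the reflowed lines in
# PPOPolicy.learn implement the optional value-clipping trick.  Standalone form
# of the same computation, with argument names of my own choosing.
import torch


def clipped_value_loss(value: torch.Tensor, old_value: torch.Tensor,
                       returns: torch.Tensor, eps_clip: float) -> torch.Tensor:
    # keep the new value prediction within +/- eps_clip of the old one
    v_clip = old_value + (value - old_value).clamp(-eps_clip, eps_clip)
    vf1 = (returns - value).pow(2)
    vf2 = (returns - v_clip).pow(2)
    # take the pessimistic (larger) of the unclipped / clipped squared errors
    return torch.max(vf1, vf2).mean()


value = torch.tensor([0.5, 1.5, -0.2])
old_value = torch.tensor([0.4, 1.0, 0.0])
returns = torch.tensor([1.0, 1.0, 1.0])
print(clipped_value_loss(value, old_value, returns, eps_clip=0.2))
# --- end of sketch; the diff continues below ---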
tianshou.utils.net.discrete import sample_noise diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 1858f27ed..2657a1eee 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -1,12 +1,13 @@ -import torch -import numpy as np from copy import deepcopy +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch from torch.distributions import Independent, Normal -from typing import Any, Dict, Tuple, Union, Optional -from tianshou.policy import DDPGPolicy -from tianshou.exploration import BaseNoise from tianshou.data import Batch, ReplayBuffer, to_torch_as +from tianshou.exploration import BaseNoise +from tianshou.policy import DDPGPolicy class SACPolicy(DDPGPolicy): @@ -26,7 +27,7 @@ class SACPolicy(DDPGPolicy): :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy regularization coefficient. Default to 0.2. If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then - alpha is automatatically tuned. + alpha is automatically tuned. :param bool reward_normalization: normalize the reward to Normal(0, 1). Default to False. :param BaseNoise exploration_noise: add a noise to action for exploration. @@ -67,7 +68,8 @@ def __init__( ) -> None: super().__init__( None, None, None, None, tau, gamma, exploration_noise, - reward_normalization, estimation_step, **kwargs) + reward_normalization, estimation_step, **kwargs + ) self.actor, self.actor_optim = actor, actor_optim self.critic1, self.critic1_old = critic1, deepcopy(critic1) self.critic1_old.eval() @@ -123,15 +125,17 @@ def forward( # type: ignore # in appendix C to get some understanding of this equation. if self.action_scaling and self.action_space is not None: action_scale = to_torch_as( - (self.action_space.high - self.action_space.low) / 2.0, act) + (self.action_space.high - self.action_space.low) / 2.0, act + ) else: action_scale = 1.0 # type: ignore squashed_action = torch.tanh(act) log_prob = log_prob - torch.log( action_scale * (1 - squashed_action.pow(2)) + self.__eps ).sum(-1, keepdim=True) - return Batch(logits=logits, act=squashed_action, - state=h, dist=dist, log_prob=log_prob) + return Batch( + logits=logits, act=squashed_action, state=h, dist=dist, log_prob=log_prob + ) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: batch = buffer[indices] # batch.obs: s_{t+n} @@ -146,9 +150,11 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: # critic 1&2 td1, critic1_loss = self._mse_optimizer( - batch, self.critic1, self.critic1_optim) + batch, self.critic1, self.critic1_optim + ) td2, critic2_loss = self._mse_optimizer( - batch, self.critic2, self.critic2_optim) + batch, self.critic2, self.critic2_optim + ) batch.weight = (td1 + td2) / 2.0 # prio-buffer # actor @@ -156,8 +162,10 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: a = obs_result.act current_q1a = self.critic1(batch.obs, a).flatten() current_q2a = self.critic2(batch.obs, a).flatten() - actor_loss = (self._alpha * obs_result.log_prob.flatten() - - torch.min(current_q1a, current_q2a)).mean() + actor_loss = ( + self._alpha * obs_result.log_prob.flatten() - + torch.min(current_q1a, current_q2a) + ).mean() self.actor_optim.zero_grad() actor_loss.backward() self.actor_optim.step() diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 3ca785bcc..a033237ea 100644 --- 
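# --- editor's sketch (not part of the patch): the block above in
# SACPolicy.forward applies the change-of-variables correction for a
# tanh-squashed Gaussian policy.  A minimal reproduction of just that
# correction; action_scale = 1.0 and eps = 1e-6 are assumptions for
# illustration (eps guards against log(0)).
import torch
from torch.distributions import Independent, Normal

mu = torch.zeros(5, 2)
sigma = torch.ones(5, 2)
dist = Independent(Normal(mu, sigma), 1)

act = dist.rsample()                        # pre-squash action
log_prob = dist.log_prob(act).unsqueeze(-1)
squashed = torch.tanh(act)                  # action actually sent to the env
action_scale, eps = 1.0, 1e-6
log_prob = log_prob - torch.log(
    action_scale * (1 - squashed.pow(2)) + eps
).sum(-1, keepdim=True)
print(squashed.shape, log_prob.shape)       # (5, 2), (5, 1)
# --- end of sketch; the diff continues below ---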
a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -1,11 +1,12 @@ -import torch -import numpy as np from copy import deepcopy from typing import Any, Dict, Optional -from tianshou.policy import DDPGPolicy +import numpy as np +import torch + from tianshou.data import Batch, ReplayBuffer from tianshou.exploration import BaseNoise, GaussianNoise +from tianshou.policy import DDPGPolicy class TD3Policy(DDPGPolicy): @@ -64,9 +65,10 @@ def __init__( estimation_step: int = 1, **kwargs: Any, ) -> None: - super().__init__(actor, actor_optim, None, None, tau, gamma, - exploration_noise, reward_normalization, - estimation_step, **kwargs) + super().__init__( + actor, actor_optim, None, None, tau, gamma, exploration_noise, + reward_normalization, estimation_step, **kwargs + ) self.critic1, self.critic1_old = critic1, deepcopy(critic1) self.critic1_old.eval() self.critic1_optim = critic1_optim @@ -103,16 +105,18 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: noise = noise.clamp(-self._noise_clip, self._noise_clip) a_ += noise target_q = torch.min( - self.critic1_old(batch.obs_next, a_), - self.critic2_old(batch.obs_next, a_)) + self.critic1_old(batch.obs_next, a_), self.critic2_old(batch.obs_next, a_) + ) return target_q def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: # critic 1&2 td1, critic1_loss = self._mse_optimizer( - batch, self.critic1, self.critic1_optim) + batch, self.critic1, self.critic1_optim + ) td2, critic2_loss = self._mse_optimizer( - batch, self.critic2, self.critic2_optim) + batch, self.critic2, self.critic2_optim + ) batch.weight = (td1 + td2) / 2.0 # prio-buffer # actor diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index b0ba63f11..75956d987 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -1,9 +1,9 @@ -import torch import warnings -import torch.nn.functional as F from typing import Any, Dict, List, Type -from torch.distributions import kl_divergence +import torch +import torch.nn.functional as F +from torch.distributions import kl_divergence from tianshou.data import Batch from tianshou.policy import NPGPolicy @@ -70,7 +70,7 @@ def learn( # type: ignore self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any ) -> Dict[str, List[float]]: actor_losses, vf_losses, step_sizes, kls = [], [], [], [] - for step in range(repeat): + for _ in range(repeat): for b in batch.split(batch_size, merge_last=True): # optimize actor # direction: calculate villia gradient @@ -79,7 +79,8 @@ def learn( # type: ignore ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1) actor_loss = -(ratio * b.adv).mean() flat_grads = self._get_flat_grad( - actor_loss, self.actor, retain_graph=True).detach() + actor_loss, self.actor, retain_graph=True + ).detach() # direction: calculate natural gradient with torch.no_grad(): @@ -89,26 +90,30 @@ def learn( # type: ignore # calculate first order gradient of kl with respect to theta flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True) search_direction = -self._conjugate_gradients( - flat_grads, flat_kl_grad, nsteps=10) + flat_grads, flat_kl_grad, nsteps=10 + ) # stepsize: calculate max stepsize constrained by kl bound - step_size = torch.sqrt(2 * self._delta / ( - search_direction * self._MVP(search_direction, flat_kl_grad) - ).sum(0, keepdim=True)) + step_size = torch.sqrt( + 2 * self._delta / + (search_direction * + self._MVP(search_direction, flat_kl_grad)).sum(0, keepdim=True) + ) # 
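# --- editor's sketch (not part of the patch): the reflowed target in
# TD3Policy._target_q combines target policy smoothing (clipped Gaussian noise
# on the target action) with the minimum over two target critics.  The toy
# actor and critics below are stand-ins of my own, only the noise/min structure
# matches the code above.
import torch

policy_noise, noise_clip = 0.2, 0.5
obs_next = torch.randn(8, 3)
a_ = torch.tanh(obs_next @ torch.randn(3, 2))   # stand-in target actor
noise = (torch.randn_like(a_) * policy_noise).clamp(-noise_clip, noise_clip)
a_ = a_ + noise


def critic1_old(s: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
    return s.sum(1, keepdim=True) + a.sum(1, keepdim=True)


def critic2_old(s: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
    return s.sum(1, keepdim=True) - a.sum(1, keepdim=True)


target_q = torch.min(critic1_old(obs_next, a_), critic2_old(obs_next, a_))
print(target_q.shape)  # torch.Size([8, 1])
# --- end of sketch; the diff continues below ---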
stepsize: linesearch stepsize with torch.no_grad(): - flat_params = torch.cat([param.data.view(-1) - for param in self.actor.parameters()]) + flat_params = torch.cat( + [param.data.view(-1) for param in self.actor.parameters()] + ) for i in range(self._max_backtracks): new_flat_params = flat_params + step_size * search_direction self._set_from_flat_params(self.actor, new_flat_params) # calculate kl and if in bound, loss actually down new_dist = self(b).dist - new_dratio = ( - new_dist.log_prob(b.act) - b.logp_old).exp().float() - new_dratio = new_dratio.reshape( - new_dratio.size(0), -1).transpose(0, 1) + new_dratio = (new_dist.log_prob(b.act) - + b.logp_old).exp().float() + new_dratio = new_dratio.reshape(new_dratio.size(0), + -1).transpose(0, 1) new_actor_loss = -(new_dratio * b.adv).mean() kl = kl_divergence(old_dist, new_dist).mean() @@ -121,8 +126,10 @@ def learn( # type: ignore else: self._set_from_flat_params(self.actor, new_flat_params) step_size = torch.tensor([0.0]) - warnings.warn("Line search failed! It seems hyperparamters" - " are poor and need to be changed.") + warnings.warn( + "Line search failed! It seems hyperparamters" + " are poor and need to be changed." + ) # optimize citirc for _ in range(self._optim_critic_iters): diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index e7b50f07f..75705f4a3 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -1,8 +1,9 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + import numpy as np -from typing import Any, Dict, List, Tuple, Union, Optional -from tianshou.policy import BasePolicy from tianshou.data import Batch, ReplayBuffer +from tianshou.policy import BasePolicy class MultiAgentPolicyManager(BasePolicy): @@ -54,21 +55,22 @@ def process_fn( tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1] buffer._meta.rew = save_rew[:, policy.agent_id - 1] results[f"agent_{policy.agent_id}"] = policy.process_fn( - tmp_batch, buffer, tmp_indices) + tmp_batch, buffer, tmp_indices + ) if has_rew: # restore from save_rew buffer._meta.rew = save_rew return Batch(results) - def exploration_noise( - self, act: Union[np.ndarray, Batch], batch: Batch - ) -> Union[np.ndarray, Batch]: + def exploration_noise(self, act: Union[np.ndarray, Batch], + batch: Batch) -> Union[np.ndarray, Batch]: """Add exploration noise from sub-policy onto act.""" for policy in self.policies: agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0] if len(agent_index) == 0: continue act[agent_index] = policy.exploration_noise( - act[agent_index], batch[agent_index]) + act[agent_index], batch[agent_index] + ) return act def forward( # type: ignore @@ -100,8 +102,8 @@ def forward( # type: ignore "agent_n": xxx} } """ - results: List[Tuple[bool, np.ndarray, Batch, - Union[np.ndarray, Batch], Batch]] = [] + results: List[Tuple[bool, np.ndarray, Batch, Union[np.ndarray, Batch], + Batch]] = [] for policy in self.policies: # This part of code is difficult to understand. # Let's follow an example with two agents @@ -119,20 +121,28 @@ def forward( # type: ignore if isinstance(tmp_batch.rew, np.ndarray): # reward can be empty Batch (after initial reset) or nparray. 
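# --- editor's sketch (not part of the patch): the loop above in TRPOPolicy.learn
# is a backtracking line search -- try the full natural-gradient step and shrink
# it until the KL constraint and the surrogate improvement both hold.  Generic
# version; the acceptance test and the shrink factor below are illustrative
# assumptions, not a quote of the library.
from typing import Callable, Tuple

import torch


def backtracking_line_search(
    evaluate: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
    flat_params: torch.Tensor,
    full_step: torch.Tensor,
    max_kl: float,
    old_loss: torch.Tensor,
    max_backtracks: int = 10,
    backtrack_coeff: float = 0.8,
) -> torch.Tensor:
    """evaluate(params) -> (surrogate_loss, kl) for candidate parameters."""
    step_size = torch.tensor(1.0)
    for _ in range(max_backtracks):
        new_params = flat_params + step_size * full_step
        new_loss, kl = evaluate(new_params)
        if kl < max_kl and new_loss < old_loss:  # accept this candidate step
            return new_params
        step_size = step_size * backtrack_coeff  # otherwise shrink and retry
    return flat_params                           # give up: keep old parameters


# toy usage: minimise a quadratic, with "KL" taken as squared parameter change
theta = torch.tensor([2.0])


def evaluate(params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    return (params ** 2).sum(), ((params - theta) ** 2).sum()


new_theta = backtracking_line_search(
    evaluate, theta, full_step=torch.tensor([-2.0]),
    max_kl=1.0, old_loss=(theta ** 2).sum()
)
print(new_theta)
# --- end of sketch; the diff continues below ---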
tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1] - out = policy(batch=tmp_batch, state=None if state is None - else state["agent_" + str(policy.agent_id)], - **kwargs) + out = policy( + batch=tmp_batch, + state=None if state is None else state["agent_" + + str(policy.agent_id)], + **kwargs + ) act = out.act each_state = out.state \ if (hasattr(out, "state") and out.state is not None) \ else Batch() results.append((True, agent_index, out, act, each_state)) - holder = Batch.cat([{"act": act} for - (has_data, agent_index, out, act, each_state) - in results if has_data]) + holder = Batch.cat( + [ + { + "act": act + } for (has_data, agent_index, out, act, each_state) in results + if has_data + ] + ) state_dict, out_dict = {}, {} - for policy, (has_data, agent_index, out, act, state) in zip( - self.policies, results): + for policy, (has_data, agent_index, out, act, + state) in zip(self.policies, results): if has_data: holder.act[agent_index] = act state_dict["agent_" + str(policy.agent_id)] = state @@ -141,9 +151,8 @@ def forward( # type: ignore holder["state"] = state_dict return holder - def learn( - self, batch: Batch, **kwargs: Any - ) -> Dict[str, Union[float, List[float]]]: + def learn(self, batch: Batch, + **kwargs: Any) -> Dict[str, Union[float, List[float]]]: """Dispatch the data to all policies for learning. :return: a dict with the following contents: diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py index 9c7f132af..dfb79564f 100644 --- a/tianshou/policy/random.py +++ b/tianshou/policy/random.py @@ -1,5 +1,6 @@ +from typing import Any, Dict, Optional, Union + import numpy as np -from typing import Any, Dict, Union, Optional from tianshou.data import Batch from tianshou.policy import BasePolicy diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index 9fa88fbc3..11b3a95ef 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -1,3 +1,7 @@ +"""Trainer package.""" + +# isort:skip_file + from tianshou.trainer.utils import test_episode, gather_info from tianshou.trainer.onpolicy import onpolicy_trainer from tianshou.trainer.offpolicy import offpolicy_trainer diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index a2bcf051a..72cd00d06 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -1,13 +1,14 @@ import time -import tqdm -import numpy as np from collections import defaultdict -from typing import Dict, Union, Callable, Optional +from typing import Callable, Dict, Optional, Union + +import numpy as np +import tqdm -from tianshou.policy import BasePolicy -from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger from tianshou.data import Collector, ReplayBuffer -from tianshou.trainer import test_episode, gather_info +from tianshou.policy import BasePolicy +from tianshou.trainer import gather_info, test_episode +from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config def offline_trainer( @@ -74,17 +75,17 @@ def offline_trainer( start_time = time.time() test_collector.reset_stat() - test_result = test_episode(policy, test_collector, test_fn, start_epoch, - episode_per_test, logger, gradient_step, reward_metric) + test_result = test_episode( + policy, test_collector, test_fn, start_epoch, episode_per_test, logger, + gradient_step, reward_metric + ) best_epoch = start_epoch best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] for epoch in range(1 + start_epoch, 1 + max_epoch): policy.train() - with tqdm.trange( - update_per_epoch, 
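# --- editor's sketch (not part of the patch): MultiAgentPolicyManager.forward
# above routes each row of the joint batch to the sub-policy that owns it, keyed
# on obs.agent_id, then scatters the per-agent actions back into one action
# array.  The indexing idea in plain numpy, with hypothetical constant
# "policies" of my own:
import numpy as np

agent_ids = np.array([1, 2, 1, 2, 1])    # which agent produced each row
obs = np.arange(5, dtype=np.float32)     # toy observations, one per row


def policy_1(o: np.ndarray) -> np.ndarray:
    return np.full_like(o, 10.0)


def policy_2(o: np.ndarray) -> np.ndarray:
    return np.full_like(o, 20.0)


act = np.empty_like(obs)
for agent_id, policy in {1: policy_1, 2: policy_2}.items():
    index = np.nonzero(agent_ids == agent_id)[0]
    if len(index) == 0:
        continue
    act[index] = policy(obs[index])      # scatter this agent's actions back
print(act)  # [10. 20. 10. 20. 10.]
# --- end of sketch; the diff continues below ---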
desc=f"Epoch #{epoch}", **tqdm_config - ) as t: - for i in t: + with tqdm.trange(update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config) as t: + for _ in t: gradient_step += 1 losses = policy.update(batch_size, buffer) data = {"gradient_step": str(gradient_step)} @@ -96,8 +97,9 @@ def offline_trainer( t.set_postfix(**data) # test test_result = test_episode( - policy, test_collector, test_fn, epoch, episode_per_test, - logger, gradient_step, reward_metric) + policy, test_collector, test_fn, epoch, episode_per_test, logger, + gradient_step, reward_metric + ) rew, rew_std = test_result["rew"], test_result["rew_std"] if best_epoch < 0 or best_reward < rew: best_epoch, best_reward, best_reward_std = epoch, rew, rew_std @@ -105,8 +107,10 @@ def offline_trainer( save_fn(policy) logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn) if verbose: - print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" - f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}") + print( + f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + ) if stop_fn and stop_fn(best_reward): break return gather_info(start_time, None, test_collector, best_reward, best_reward_std) diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 2a576ccea..922646197 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -1,13 +1,14 @@ import time -import tqdm -import numpy as np from collections import defaultdict -from typing import Dict, Union, Callable, Optional +from typing import Callable, Dict, Optional, Union + +import numpy as np +import tqdm from tianshou.data import Collector from tianshou.policy import BasePolicy -from tianshou.trainer import test_episode, gather_info -from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger +from tianshou.trainer import gather_info, test_episode +from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config def offpolicy_trainer( @@ -43,7 +44,7 @@ def offpolicy_trainer( :param int step_per_epoch: the number of transitions collected per epoch. :param int step_per_collect: the number of transitions the collector would collect before the network update, i.e., trainer will collect "step_per_collect" - transitions and do some policy network update repeatly in each epoch. + transitions and do some policy network update repeatedly in each epoch. :param episode_per_test: the number of episodes for one policy evaluation. :param int batch_size: the batch size of sample data, which is going to feed in the policy network. 
@@ -91,8 +92,10 @@ def offpolicy_trainer( train_collector.reset_stat() test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy - test_result = test_episode(policy, test_collector, test_fn, start_epoch, - episode_per_test, logger, env_step, reward_metric) + test_result = test_episode( + policy, test_collector, test_fn, start_epoch, episode_per_test, logger, + env_step, reward_metric + ) best_epoch = start_epoch best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] @@ -123,20 +126,23 @@ def offpolicy_trainer( if result["n/ep"] > 0: if test_in_train and stop_fn and stop_fn(result["rew"]): test_result = test_episode( - policy, test_collector, test_fn, - epoch, episode_per_test, logger, env_step) + policy, test_collector, test_fn, epoch, episode_per_test, + logger, env_step + ) if stop_fn(test_result["rew"]): if save_fn: save_fn(policy) logger.save_data( - epoch, env_step, gradient_step, save_checkpoint_fn) + epoch, env_step, gradient_step, save_checkpoint_fn + ) t.set_postfix(**data) return gather_info( start_time, train_collector, test_collector, - test_result["rew"], test_result["rew_std"]) + test_result["rew"], test_result["rew_std"] + ) else: policy.train() - for i in range(round(update_per_step * result["n/st"])): + for _ in range(round(update_per_step * result["n/st"])): gradient_step += 1 losses = policy.update(batch_size, train_collector.buffer) for k in losses.keys(): @@ -148,8 +154,10 @@ def offpolicy_trainer( if t.n <= t.total: t.update() # test - test_result = test_episode(policy, test_collector, test_fn, epoch, - episode_per_test, logger, env_step, reward_metric) + test_result = test_episode( + policy, test_collector, test_fn, epoch, episode_per_test, logger, env_step, + reward_metric + ) rew, rew_std = test_result["rew"], test_result["rew_std"] if best_epoch < 0 or best_reward < rew: best_epoch, best_reward, best_reward_std = epoch, rew, rew_std @@ -157,9 +165,12 @@ def offpolicy_trainer( save_fn(policy) logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) if verbose: - print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" - f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}") + print( + f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + ) if stop_fn and stop_fn(best_reward): break - return gather_info(start_time, train_collector, test_collector, - best_reward, best_reward_std) + return gather_info( + start_time, train_collector, test_collector, best_reward, best_reward_std + ) diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 1696421e8..6788e8a65 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -1,13 +1,14 @@ import time -import tqdm -import numpy as np from collections import defaultdict -from typing import Dict, Union, Callable, Optional +from typing import Callable, Dict, Optional, Union + +import numpy as np +import tqdm from tianshou.data import Collector from tianshou.policy import BasePolicy -from tianshou.trainer import test_episode, gather_info -from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger +from tianshou.trainer import gather_info, test_episode +from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config def onpolicy_trainer( @@ -50,10 +51,10 @@ def onpolicy_trainer( policy network. 
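# --- editor's sketch (not part of the patch): in offpolicy_trainer above, the
# number of gradient updates performed after each collect step is
# round(update_per_step * n_collected_transitions), i.e. update_per_step is a
# ratio of updates to environment steps.  Toy numbers for illustration:
update_per_step = 0.1
collected_transitions = 64                       # result["n/st"] from the collector
n_updates = round(update_per_step * collected_transitions)
print(n_updates)                                 # 6 gradient steps this round
# --- end of sketch; the diff continues below ---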
:param int step_per_collect: the number of transitions the collector would collect before the network update, i.e., trainer will collect "step_per_collect" - transitions and do some policy network update repeatly in each epoch. + transitions and do some policy network update repeatedly in each epoch. :param int episode_per_collect: the number of episodes the collector would collect before the network update, i.e., trainer will collect "episode_per_collect" - episodes and do some policy network update repeatly in each epoch. + episodes and do some policy network update repeatedly in each epoch. :param function train_fn: a hook called at the beginning of training in each epoch. It can be used to perform custom additional operations, with the signature ``f( num_epoch: int, step_idx: int) -> None``. @@ -97,8 +98,10 @@ def onpolicy_trainer( train_collector.reset_stat() test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy - test_result = test_episode(policy, test_collector, test_fn, start_epoch, - episode_per_test, logger, env_step, reward_metric) + test_result = test_episode( + policy, test_collector, test_fn, start_epoch, episode_per_test, logger, + env_step, reward_metric + ) best_epoch = start_epoch best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] @@ -111,8 +114,9 @@ def onpolicy_trainer( while t.n < t.total: if train_fn: train_fn(epoch, env_step) - result = train_collector.collect(n_step=step_per_collect, - n_episode=episode_per_collect) + result = train_collector.collect( + n_step=step_per_collect, n_episode=episode_per_collect + ) if result["n/ep"] > 0 and reward_metric: result["rews"] = reward_metric(result["rews"]) env_step += int(result["n/st"]) @@ -130,25 +134,32 @@ def onpolicy_trainer( if result["n/ep"] > 0: if test_in_train and stop_fn and stop_fn(result["rew"]): test_result = test_episode( - policy, test_collector, test_fn, - epoch, episode_per_test, logger, env_step) + policy, test_collector, test_fn, epoch, episode_per_test, + logger, env_step + ) if stop_fn(test_result["rew"]): if save_fn: save_fn(policy) logger.save_data( - epoch, env_step, gradient_step, save_checkpoint_fn) + epoch, env_step, gradient_step, save_checkpoint_fn + ) t.set_postfix(**data) return gather_info( start_time, train_collector, test_collector, - test_result["rew"], test_result["rew_std"]) + test_result["rew"], test_result["rew_std"] + ) else: policy.train() losses = policy.update( - 0, train_collector.buffer, - batch_size=batch_size, repeat=repeat_per_collect) + 0, + train_collector.buffer, + batch_size=batch_size, + repeat=repeat_per_collect + ) train_collector.reset_buffer(keep_statistics=True) - step = max([1] + [ - len(v) for v in losses.values() if isinstance(v, list)]) + step = max( + [1] + [len(v) for v in losses.values() if isinstance(v, list)] + ) gradient_step += step for k in losses.keys(): stat[k].add(losses[k]) @@ -159,8 +170,10 @@ def onpolicy_trainer( if t.n <= t.total: t.update() # test - test_result = test_episode(policy, test_collector, test_fn, epoch, - episode_per_test, logger, env_step, reward_metric) + test_result = test_episode( + policy, test_collector, test_fn, epoch, episode_per_test, logger, env_step, + reward_metric + ) rew, rew_std = test_result["rew"], test_result["rew_std"] if best_epoch < 0 or best_reward < rew: best_epoch, best_reward, best_reward_std = epoch, rew, rew_std @@ -168,9 +181,12 @@ def onpolicy_trainer( save_fn(policy) logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) if 
verbose: - print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" - f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}") + print( + f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + ) if stop_fn and stop_fn(best_reward): break - return gather_info(start_time, train_collector, test_collector, - best_reward, best_reward_std) + return gather_info( + start_time, train_collector, test_collector, best_reward, best_reward_std + ) diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 2e729feeb..a39a12fff 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -1,6 +1,7 @@ import time +from typing import Any, Callable, Dict, Optional, Union + import numpy as np -from typing import Any, Dict, Union, Callable, Optional from tianshou.data import Collector from tianshou.policy import BasePolicy @@ -71,11 +72,13 @@ def gather_info( if train_c is not None: model_time -= train_c.collect_time train_speed = train_c.collect_step / (duration - test_c.collect_time) - result.update({ - "train_step": train_c.collect_step, - "train_episode": train_c.collect_episode, - "train_time/collector": f"{train_c.collect_time:.2f}s", - "train_time/model": f"{model_time:.2f}s", - "train_speed": f"{train_speed:.2f} step/s", - }) + result.update( + { + "train_step": train_c.collect_step, + "train_episode": train_c.collect_episode, + "train_time/collector": f"{train_c.collect_time:.2f}s", + "train_time/model": f"{model_time:.2f}s", + "train_speed": f"{train_speed:.2f} step/s", + } + ) return result diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py index 4ad73481c..5af038ab3 100644 --- a/tianshou/utils/__init__.py +++ b/tianshou/utils/__init__.py @@ -1,17 +1,12 @@ +"""Utils package.""" + from tianshou.utils.config import tqdm_config -from tianshou.utils.statistics import MovAvg, RunningMeanStd from tianshou.utils.logger.base import BaseLogger, LazyLogger -from tianshou.utils.logger.tensorboard import TensorboardLogger, BasicLogger +from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger from tianshou.utils.logger.wandb import WandBLogger - +from tianshou.utils.statistics import MovAvg, RunningMeanStd __all__ = [ - "MovAvg", - "RunningMeanStd", - "tqdm_config", - "BaseLogger", - "TensorboardLogger", - "BasicLogger", - "LazyLogger", - "WandBLogger" + "MovAvg", "RunningMeanStd", "tqdm_config", "BaseLogger", "TensorboardLogger", + "BasicLogger", "LazyLogger", "WandBLogger" ] diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py index c1ffe760d..9b89d5e88 100644 --- a/tianshou/utils/logger/base.py +++ b/tianshou/utils/logger/base.py @@ -1,7 +1,8 @@ -import numpy as np -from numbers import Number from abc import ABC, abstractmethod -from typing import Dict, Tuple, Union, Callable, Optional +from numbers import Number +from typing import Callable, Dict, Optional, Tuple, Union + +import numpy as np LOG_DATA_TYPE = Dict[str, Union[int, Number, np.number, np.ndarray]] diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py index c65576b41..86e873cda 100644 --- a/tianshou/utils/logger/tensorboard.py +++ b/tianshou/utils/logger/tensorboard.py @@ -1,10 +1,10 @@ import warnings -from typing import Any, Tuple, Callable, Optional +from typing import Any, Callable, Optional, Tuple -from torch.utils.tensorboard import SummaryWriter from tensorboard.backend.event_processing import event_accumulator 
+from torch.utils.tensorboard import SummaryWriter -from tianshou.utils.logger.base import BaseLogger, LOG_DATA_TYPE +from tianshou.utils.logger.base import LOG_DATA_TYPE, BaseLogger class TensorboardLogger(BaseLogger): @@ -48,8 +48,10 @@ def save_data( save_checkpoint_fn(epoch, env_step, gradient_step) self.write("save/epoch", epoch, {"save/epoch": epoch}) self.write("save/env_step", env_step, {"save/env_step": env_step}) - self.write("save/gradient_step", gradient_step, - {"save/gradient_step": gradient_step}) + self.write( + "save/gradient_step", gradient_step, + {"save/gradient_step": gradient_step} + ) def restore_data(self) -> Tuple[int, int, int]: ea = event_accumulator.EventAccumulator(self.writer.log_dir) @@ -79,5 +81,6 @@ class BasicLogger(TensorboardLogger): def __init__(self, *args: Any, **kwargs: Any) -> None: warnings.warn( - "Deprecated soon: BasicLogger has renamed to TensorboardLogger in #427.") + "Deprecated soon: BasicLogger has renamed to TensorboardLogger in #427." + ) super().__init__(*args, **kwargs) diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index cb11abc2e..b518a54a7 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -1,7 +1,8 @@ -import torch +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union + import numpy as np +import torch from torch import nn -from typing import Any, Dict, List, Type, Tuple, Union, Optional, Sequence ModuleType = Type[nn.Module] @@ -32,7 +33,7 @@ class MLP(nn.Module): :param int input_dim: dimension of the input vector. :param int output_dim: dimension of the output vector. If set to 0, there is no final linear layer. - :param hidden_sizes: shape of MLP passed in as a list, not incluing + :param hidden_sizes: shape of MLP passed in as a list, not including input_dim and output_dim. :param norm_layer: use which normalization before activation, e.g., ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization. @@ -40,7 +41,7 @@ class MLP(nn.Module): of hidden_sizes, to use different normalization module in different layers. Default to no normalization. :param activation: which activation to use after each layer, can be both - the same actvition for all layers if passed in nn.Module, or different + the same activation for all layers if passed in nn.Module, or different activation for different Modules if passed in a list. Default to nn.ReLU. :param device: which device to create this model on. Default to None. 
@@ -64,8 +65,7 @@ def __init__( assert len(norm_layer) == len(hidden_sizes) norm_layer_list = norm_layer else: - norm_layer_list = [ - norm_layer for _ in range(len(hidden_sizes))] + norm_layer_list = [norm_layer for _ in range(len(hidden_sizes))] else: norm_layer_list = [None] * len(hidden_sizes) if activation: @@ -73,26 +73,22 @@ def __init__( assert len(activation) == len(hidden_sizes) activation_list = activation else: - activation_list = [ - activation for _ in range(len(hidden_sizes))] + activation_list = [activation for _ in range(len(hidden_sizes))] else: activation_list = [None] * len(hidden_sizes) hidden_sizes = [input_dim] + list(hidden_sizes) model = [] for in_dim, out_dim, norm, activ in zip( - hidden_sizes[:-1], hidden_sizes[1:], - norm_layer_list, activation_list): + hidden_sizes[:-1], hidden_sizes[1:], norm_layer_list, activation_list + ): model += miniblock(in_dim, out_dim, norm, activ, linear_layer) if output_dim > 0: model += [linear_layer(hidden_sizes[-1], output_dim)] self.output_dim = output_dim or hidden_sizes[-1] self.model = nn.Sequential(*model) - def forward( - self, x: Union[np.ndarray, torch.Tensor] - ) -> torch.Tensor: - x = torch.as_tensor( - x, device=self.device, dtype=torch.float32) # type: ignore + def forward(self, x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + x = torch.as_tensor(x, device=self.device, dtype=torch.float32) # type: ignore return self.model(x.flatten(1)) @@ -111,7 +107,7 @@ class Net(nn.Module): of hidden_sizes, to use different normalization module in different layers. Default to no normalization. :param activation: which activation to use after each layer, can be both - the same actvition for all layers if passed in nn.Module, or different + the same activation for all layers if passed in nn.Module, or different activation for different Modules if passed in a list. Default to nn.ReLU. :param device: specify the device when the network actually runs. Default @@ -162,8 +158,9 @@ def __init__( input_dim += action_dim self.use_dueling = dueling_param is not None output_dim = action_dim if not self.use_dueling and not concat else 0 - self.model = MLP(input_dim, output_dim, hidden_sizes, - norm_layer, activation, device) + self.model = MLP( + input_dim, output_dim, hidden_sizes, norm_layer, activation, device + ) self.output_dim = self.model.output_dim if self.use_dueling: # dueling DQN q_kwargs, v_kwargs = dueling_param # type: ignore @@ -172,10 +169,14 @@ def __init__( q_output_dim, v_output_dim = action_dim, num_atoms q_kwargs: Dict[str, Any] = { **q_kwargs, "input_dim": self.output_dim, - "output_dim": q_output_dim, "device": self.device} + "output_dim": q_output_dim, + "device": self.device + } v_kwargs: Dict[str, Any] = { **v_kwargs, "input_dim": self.output_dim, - "output_dim": v_output_dim, "device": self.device} + "output_dim": v_output_dim, + "device": self.device + } self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs) self.output_dim = self.Q.output_dim @@ -239,8 +240,7 @@ def forward( training mode, s should be with shape ``[bsz, len, dim]``. See the code and comment for more detail. """ - s = torch.as_tensor( - s, device=self.device, dtype=torch.float32) # type: ignore + s = torch.as_tensor(s, device=self.device, dtype=torch.float32) # type: ignore # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. @@ -253,9 +253,12 @@ def forward( else: # we store the stack data in [bsz, len, ...] 
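# --- editor's sketch (not part of the patch): the dueling branch above builds
# separate Q (advantage) and V (value) heads.  The usual dueling aggregation is
# not shown in this hunk, so treat the formula below as the textbook version
# rather than a quote of the library: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a).
import torch

advantage = torch.randn(4, 6)   # output of the Q/advantage head
value = torch.randn(4, 1)       # output of the V head
q = value + advantage - advantage.mean(dim=1, keepdim=True)
print(q.shape)  # torch.Size([4, 6])
# --- end of sketch; the diff continues below ---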
format # but pytorch rnn needs [len, bsz, ...] - s, (h, c) = self.nn(s, (state["h"].transpose(0, 1).contiguous(), - state["c"].transpose(0, 1).contiguous())) + s, (h, c) = self.nn( + s, ( + state["h"].transpose(0, 1).contiguous(), + state["c"].transpose(0, 1).contiguous() + ) + ) s = self.fc2(s[:, -1]) # please ensure the first dim is batch size: [bsz, len, ...] - return s, {"h": h.transpose(0, 1).detach(), - "c": c.transpose(0, 1).detach()} + return s, {"h": h.transpose(0, 1).detach(), "c": c.transpose(0, 1).detach()} diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 36c178612..1bb090cdf 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -1,11 +1,11 @@ -import torch +from typing import Any, Dict, Optional, Sequence, Tuple, Union + import numpy as np +import torch from torch import nn -from typing import Any, Dict, Tuple, Union, Optional, Sequence from tianshou.utils.net.common import MLP - SIGMA_MIN = -20 SIGMA_MAX = 2 @@ -47,10 +47,8 @@ def __init__( self.device = device self.preprocess = preprocess_net self.output_dim = int(np.prod(action_shape)) - input_dim = getattr(preprocess_net, "output_dim", - preprocess_net_output_dim) - self.last = MLP(input_dim, self.output_dim, - hidden_sizes, device=self.device) + input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) + self.last = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device) self._max = max_action def forward( @@ -97,8 +95,7 @@ def __init__( self.device = device self.preprocess = preprocess_net self.output_dim = 1 - input_dim = getattr(preprocess_net, "output_dim", - preprocess_net_output_dim) + input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) self.last = MLP(input_dim, 1, hidden_sizes, device=self.device) def forward( @@ -109,11 +106,15 @@ def forward( ) -> torch.Tensor: """Mapping: (s, a) -> logits -> Q(s, a).""" s = torch.as_tensor( - s, device=self.device, dtype=torch.float32 # type: ignore + s, + device=self.device, # type: ignore + dtype=torch.float32, ).flatten(1) if a is not None: a = torch.as_tensor( - a, device=self.device, dtype=torch.float32 # type: ignore + a, + device=self.device, # type: ignore + dtype=torch.float32, ).flatten(1) s = torch.cat([s, a], dim=1) logits, h = self.preprocess(s) @@ -163,14 +164,13 @@ def __init__( self.preprocess = preprocess_net self.device = device self.output_dim = int(np.prod(action_shape)) - input_dim = getattr(preprocess_net, "output_dim", - preprocess_net_output_dim) - self.mu = MLP(input_dim, self.output_dim, - hidden_sizes, device=self.device) + input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) + self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device) self._c_sigma = conditioned_sigma if conditioned_sigma: - self.sigma = MLP(input_dim, self.output_dim, - hidden_sizes, device=self.device) + self.sigma = MLP( + input_dim, self.output_dim, hidden_sizes, device=self.device + ) else: self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1)) self._max = max_action @@ -188,9 +188,7 @@ def forward( if not self._unbounded: mu = self._max * torch.tanh(mu) if self._c_sigma: - sigma = torch.clamp( - self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX - ).exp() + sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp() else: shape = [1] * len(mu.shape) shape[1] = -1 @@ -241,8 +239,7 @@ def forward( info: Dict[str, Any] = {}, ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Dict[str, 
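# --- editor's sketch (not part of the patch): ActorProb above clamps the
# predicted log-sigma to [SIGMA_MIN, SIGMA_MAX] = [-20, 2] before
# exponentiating, keeping the Gaussian's standard deviation in a numerically
# safe range.  In isolation, with made-up raw outputs:
import torch

SIGMA_MIN, SIGMA_MAX = -20, 2
log_sigma = torch.tensor([[-30.0, 0.0, 5.0]])    # raw network output
sigma = torch.clamp(log_sigma, min=SIGMA_MIN, max=SIGMA_MAX).exp()
print(sigma)  # approximately [[2.06e-09, 1.0, 7.39]]
# --- end of sketch; the diff continues below ---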
torch.Tensor]]: """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" - s = torch.as_tensor( - s, device=self.device, dtype=torch.float32) # type: ignore + s = torch.as_tensor(s, device=self.device, dtype=torch.float32) # type: ignore # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. @@ -254,23 +251,27 @@ def forward( else: # we store the stack data in [bsz, len, ...] format # but pytorch rnn needs [len, bsz, ...] - s, (h, c) = self.nn(s, (state["h"].transpose(0, 1).contiguous(), - state["c"].transpose(0, 1).contiguous())) + s, (h, c) = self.nn( + s, ( + state["h"].transpose(0, 1).contiguous(), + state["c"].transpose(0, 1).contiguous() + ) + ) logits = s[:, -1] mu = self.mu(logits) if not self._unbounded: mu = self._max * torch.tanh(mu) if self._c_sigma: - sigma = torch.clamp( - self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX - ).exp() + sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp() else: shape = [1] * len(mu.shape) shape[1] = -1 sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp() # please ensure the first dim is batch size: [bsz, len, ...] - return (mu, sigma), {"h": h.transpose(0, 1).detach(), - "c": c.transpose(0, 1).detach()} + return (mu, sigma), { + "h": h.transpose(0, 1).detach(), + "c": c.transpose(0, 1).detach() + } class RecurrentCritic(nn.Module): @@ -307,8 +308,7 @@ def forward( info: Dict[str, Any] = {}, ) -> torch.Tensor: """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" - s = torch.as_tensor( - s, device=self.device, dtype=torch.float32) # type: ignore + s = torch.as_tensor(s, device=self.device, dtype=torch.float32) # type: ignore # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. 
@@ -318,7 +318,10 @@ def forward( s = s[:, -1] if a is not None: a = torch.as_tensor( - a, device=self.device, dtype=torch.float32) # type: ignore + a, + device=self.device, # type: ignore + dtype=torch.float32, + ) s = torch.cat([s, a], dim=1) s = self.fc2(s) return s diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 200ae9d3a..bcc6531e3 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -1,8 +1,9 @@ -import torch +from typing import Any, Dict, Optional, Sequence, Tuple, Union + import numpy as np -from torch import nn +import torch import torch.nn.functional as F -from typing import Any, Dict, Tuple, Union, Optional, Sequence +from torch import nn from tianshou.data import Batch from tianshou.utils.net.common import MLP @@ -47,10 +48,8 @@ def __init__( self.device = device self.preprocess = preprocess_net self.output_dim = int(np.prod(action_shape)) - input_dim = getattr(preprocess_net, "output_dim", - preprocess_net_output_dim) - self.last = MLP(input_dim, self.output_dim, - hidden_sizes, device=self.device) + input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) + self.last = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device) self.softmax_output = softmax_output def forward( @@ -101,10 +100,8 @@ def __init__( self.device = device self.preprocess = preprocess_net self.output_dim = last_size - input_dim = getattr(preprocess_net, "output_dim", - preprocess_net_output_dim) - self.last = MLP(input_dim, last_size, - hidden_sizes, device=self.device) + input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) + self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device) def forward( self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any @@ -141,9 +138,8 @@ def forward(self, taus: torch.Tensor) -> torch.Tensor: start=1, end=self.num_cosines + 1, dtype=taus.dtype, device=taus.device ).view(1, 1, self.num_cosines) # Calculate cos(i * \pi * \tau). - cosines = torch.cos(taus.view(batch_size, N, 1) * i_pi).view( - batch_size * N, self.num_cosines - ) + cosines = torch.cos(taus.view(batch_size, N, 1) * i_pi + ).view(batch_size * N, self.num_cosines) # Calculate embeddings of taus. tau_embeddings = self.net(cosines).view(batch_size, N, self.embedding_dim) return tau_embeddings @@ -181,10 +177,12 @@ def __init__( device: Union[str, int, torch.device] = "cpu" ) -> None: last_size = np.prod(action_shape) - super().__init__(preprocess_net, hidden_sizes, last_size, - preprocess_net_output_dim, device) - self.input_dim = getattr(preprocess_net, "output_dim", - preprocess_net_output_dim) + super().__init__( + preprocess_net, hidden_sizes, last_size, preprocess_net_output_dim, device + ) + self.input_dim = getattr( + preprocess_net, "output_dim", preprocess_net_output_dim + ) self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim).to(device) @@ -195,13 +193,12 @@ def forward( # type: ignore logits, h = self.preprocess(s, state=kwargs.get("state", None)) # Sample fractions. 
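# --- editor's sketch (not part of the patch): CosineEmbeddingNetwork above
# embeds each sampled fraction tau as [cos(1*pi*tau), ..., cos(m*pi*tau)] before
# a small MLP.  Just the cosine feature construction, reproduced standalone with
# toy sizes of my own:
import numpy as np
import torch

batch_size, N, num_cosines = 2, 4, 8
taus = torch.rand(batch_size, N)
i_pi = np.pi * torch.arange(
    1, num_cosines + 1, dtype=taus.dtype
).view(1, 1, num_cosines)
cosines = torch.cos(taus.view(batch_size, N, 1) * i_pi)
cosines = cosines.view(batch_size * N, num_cosines)
print(cosines.shape)  # torch.Size([8, 8])
# --- end of sketch; the diff continues below ---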
batch_size = logits.size(0) - taus = torch.rand(batch_size, sample_size, - dtype=logits.dtype, device=logits.device) - embedding = (logits.unsqueeze(1) * self.embed_model(taus)).view( - batch_size * sample_size, -1 + taus = torch.rand( + batch_size, sample_size, dtype=logits.dtype, device=logits.device ) - out = self.last(embedding).view( - batch_size, sample_size, -1).transpose(1, 2) + embedding = (logits.unsqueeze(1) * + self.embed_model(taus)).view(batch_size * sample_size, -1) + out = self.last(embedding).view(batch_size, sample_size, -1).transpose(1, 2) return (out, taus), h @@ -270,20 +267,18 @@ def __init__( device: Union[str, int, torch.device] = "cpu", ) -> None: super().__init__( - preprocess_net, action_shape, hidden_sizes, - num_cosines, preprocess_net_output_dim, device + preprocess_net, action_shape, hidden_sizes, num_cosines, + preprocess_net_output_dim, device ) def _compute_quantiles( self, obs: torch.Tensor, taus: torch.Tensor ) -> torch.Tensor: batch_size, sample_size = taus.shape - embedding = (obs.unsqueeze(1) * self.embed_model(taus)).view( - batch_size * sample_size, -1 - ) - quantiles = self.last(embedding).view( - batch_size, sample_size, -1 - ).transpose(1, 2) + embedding = (obs.unsqueeze(1) * + self.embed_model(taus)).view(batch_size * sample_size, -1) + quantiles = self.last(embedding).view(batch_size, sample_size, + -1).transpose(1, 2) return quantiles def forward( # type: ignore @@ -328,10 +323,8 @@ def __init__( super().__init__() # Learnable parameters. - self.mu_W = nn.Parameter( - torch.FloatTensor(out_features, in_features)) - self.sigma_W = nn.Parameter( - torch.FloatTensor(out_features, in_features)) + self.mu_W = nn.Parameter(torch.FloatTensor(out_features, in_features)) + self.sigma_W = nn.Parameter(torch.FloatTensor(out_features, in_features)) self.mu_bias = nn.Parameter(torch.FloatTensor(out_features)) self.sigma_bias = nn.Parameter(torch.FloatTensor(out_features)) diff --git a/tianshou/utils/statistics.py b/tianshou/utils/statistics.py index e0d06763a..a81af601e 100644 --- a/tianshou/utils/statistics.py +++ b/tianshou/utils/statistics.py @@ -1,8 +1,9 @@ -import torch -import numpy as np from numbers import Number from typing import List, Union +import numpy as np +import torch + class MovAvg(object): """Class for moving average. @@ -66,13 +67,15 @@ def std(self) -> float: class RunningMeanStd(object): - """Calulates the running mean and std of a data stream. + """Calculates the running mean and std of a data stream. https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm """ def __init__( - self, mean: Union[float, np.ndarray] = 0.0, std: Union[float, np.ndarray] = 1.0 + self, + mean: Union[float, np.ndarray] = 0.0, + std: Union[float, np.ndarray] = 1.0 ) -> None: self.mean, self.var = mean, std self.count = 0 @@ -88,7 +91,7 @@ def update(self, x: np.ndarray) -> None: new_mean = self.mean + delta * batch_count / total_count m_a = self.var * self.count m_b = batch_var * batch_count - m_2 = m_a + m_b + delta ** 2 * self.count * batch_count / total_count + m_2 = m_a + m_b + delta**2 * self.count * batch_count / total_count new_var = m_2 / total_count self.mean, self.var = new_mean, new_var
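# --- editor's sketch (not part of the patch): RunningMeanStd.update above is the
# parallel mean/variance merge (Chan et al.).  A quick numerical check of the
# same formulas against a direct computation over the concatenated data; the
# random toy arrays are mine:
import numpy as np

rng = np.random.default_rng(0)
a, b = rng.normal(size=100), rng.normal(loc=2.0, size=50)

mean, var, count = a.mean(), a.var(), len(a)                   # running stats so far
batch_mean, batch_var, batch_count = b.mean(), b.var(), len(b)  # new batch stats

delta = batch_mean - mean
total = count + batch_count
new_mean = mean + delta * batch_count / total
# merge the sums of squared deviations, then normalise
m_2 = var * count + batch_var * batch_count + delta**2 * count * batch_count / total
new_var = m_2 / total

full = np.concatenate([a, b])
print(np.allclose([new_mean, new_var], [full.mean(), full.var()]))  # True
# --- end of sketch ---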