From 7620cd37e2cd151a41bce24c2eacec942350e56f Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Thu, 2 Sep 2021 13:32:31 -0400 Subject: [PATCH 01/11] add isort and yapf --- .github/workflows/lint_and_docs.yml | 6 ++++++ .gitignore | 1 + setup.cfg | 10 ++++++++++ setup.py | 2 ++ tianshou/__init__.py | 2 +- 5 files changed, 20 insertions(+), 1 deletion(-) diff --git a/.github/workflows/lint_and_docs.yml b/.github/workflows/lint_and_docs.yml index 3a69bfa71..6f6e21a6e 100644 --- a/.github/workflows/lint_and_docs.yml +++ b/.github/workflows/lint_and_docs.yml @@ -20,6 +20,12 @@ jobs: - name: Lint with flake8 run: | flake8 . --count --show-source --statistics + - name: yapf code formatter + run: | + yapf -r -d . + - name: isort code formatter + run: | + isort --check . - name: Type check run: | mypy diff --git a/.gitignore b/.gitignore index be8453abd..e9510a1df 100644 --- a/.gitignore +++ b/.gitignore @@ -147,3 +147,4 @@ MUJOCO_LOG.TXT *.swp *.pkl *.hdf5 +wandb/ diff --git a/setup.cfg b/setup.cfg index d485e6d06..e1420b9ad 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,6 +9,16 @@ exclude = *.egg-info max-line-length = 87 +[yapf] +based_on_style = pep8 +dedent_closing_brackets = true +column_limit = 87 + +[isort] +profile = black +multi_line_output = 3 +line_length = 87 + [mypy] files = tianshou/**/*.py allow_redefinition = True diff --git a/setup.py b/setup.py index 208af8e14..9488d27a0 100644 --- a/setup.py +++ b/setup.py @@ -60,6 +60,8 @@ def get_version() -> str: "sphinx_rtd_theme", "sphinxcontrib-bibtex", "flake8", + "yapf", + "isort", "pytest", "pytest-cov", "ray>=1.0.0", diff --git a/tianshou/__init__.py b/tianshou/__init__.py index fb362d054..98972ca9d 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,7 +1,7 @@ from tianshou import data, env, utils, policy, trainer, exploration -__version__ = "0.4.2" +__version__ = "0.4.3" __all__ = [ "env", From a1309ebb107518a78562d3c40645b8092b37176f Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Thu, 2 Sep 2021 13:34:30 -0400 Subject: [PATCH 02/11] format --- docs/conf.py | 6 +- examples/atari/atari_bcq.py | 99 ++-- examples/atari/atari_c51.py | 105 ++-- examples/atari/atari_cql.py | 82 +-- examples/atari/atari_crr.py | 97 ++-- examples/atari/atari_dqn.py | 104 ++-- examples/atari/atari_fqf.py | 119 +++-- examples/atari/atari_iqn.py | 113 ++-- examples/atari/atari_network.py | 28 +- examples/atari/atari_qrdqn.py | 103 ++-- examples/atari/atari_rainbow.py | 135 +++-- examples/atari/atari_wrapper.py | 38 +- examples/box2d/acrobot_dualdqn.py | 76 ++- examples/box2d/bipedal_hardcore_sac.py | 109 ++-- examples/box2d/lunarlander_dqn.py | 81 ++- examples/box2d/mcc_sac.py | 95 ++-- examples/mujoco/analysis.py | 36 +- examples/mujoco/gen_json.py | 18 +- examples/mujoco/mujoco_a2c.py | 119 +++-- examples/mujoco/mujoco_ddpg.py | 81 ++- examples/mujoco/mujoco_npg.py | 118 +++-- examples/mujoco/mujoco_ppo.py | 122 +++-- examples/mujoco/mujoco_reinforce.py | 96 ++-- examples/mujoco/mujoco_sac.py | 96 ++-- examples/mujoco/mujoco_td3.py | 98 ++-- examples/mujoco/mujoco_trpo.py | 122 +++-- examples/mujoco/plotter.py | 218 +++++--- examples/mujoco/tools.py | 39 +- examples/vizdoom/env.py | 8 +- examples/vizdoom/maps/spectator.py | 5 +- examples/vizdoom/replay.py | 1 + examples/vizdoom/vizdoom_c51.py | 114 ++-- setup.py | 3 +- test/base/env.py | 76 ++- test/base/test_batch.py | 159 +++--- test/base/test_buffer.py | 618 ++++++++++++++++------ test/base/test_collector.py | 398 ++++++++++---- test/base/test_env.py | 39 +- test/base/test_env_finite.py | 29 +- 
test/base/test_returns.py | 184 ++++--- test/base/test_utils.py | 39 +- test/continuous/test_ddpg.py | 77 ++- test/continuous/test_npg.py | 88 ++- test/continuous/test_ppo.py | 85 +-- test/continuous/test_sac_with_il.py | 137 +++-- test/continuous/test_td3.py | 89 ++-- test/continuous/test_trpo.py | 88 ++- test/discrete/test_a2c_with_il.py | 93 ++-- test/discrete/test_c51.py | 97 ++-- test/discrete/test_dqn.py | 80 ++- test/discrete/test_drqn.py | 62 ++- test/discrete/test_fqf.py | 84 ++- test/discrete/test_il_bcq.py | 82 ++- test/discrete/test_il_crr.py | 69 ++- test/discrete/test_iqn.py | 82 ++- test/discrete/test_pg.py | 68 ++- test/discrete/test_ppo.py | 58 +- test/discrete/test_qrdqn.py | 83 ++- test/discrete/test_qrdqn_il_cql.py | 59 ++- test/discrete/test_rainbow.py | 106 ++-- test/discrete/test_sac.py | 73 ++- test/modelbased/test_psrl.py | 42 +- test/multiagent/Gomoku.py | 19 +- test/multiagent/test_tic_tac_toe.py | 1 + test/multiagent/tic_tac_toe.py | 143 +++-- test/multiagent/tic_tac_toe_env.py | 47 +- test/throughput/test_batch_profile.py | 41 +- test/throughput/test_buffer_profile.py | 6 +- test/throughput/test_collector_profile.py | 27 +- tianshou/__init__.py | 3 +- tianshou/data/__init__.py | 13 +- tianshou/data/batch.py | 106 ++-- tianshou/data/buffer/base.py | 27 +- tianshou/data/buffer/cached.py | 8 +- tianshou/data/buffer/manager.py | 67 ++- tianshou/data/buffer/prio.py | 18 +- tianshou/data/buffer/vecbuf.py | 13 +- tianshou/data/collector.py | 117 ++-- tianshou/data/utils/converter.py | 12 +- tianshou/data/utils/segtree.py | 17 +- tianshou/env/__init__.py | 9 +- tianshou/env/maenv.py | 6 +- tianshou/env/utils.py | 4 +- tianshou/env/venvs.py | 36 +- tianshou/env/worker/__init__.py | 2 +- tianshou/env/worker/base.py | 10 +- tianshou/env/worker/dummy.py | 4 +- tianshou/env/worker/ray.py | 4 +- tianshou/env/worker/subproc.py | 16 +- tianshou/exploration/random.py | 11 +- tianshou/policy/__init__.py | 2 +- tianshou/policy/base.py | 24 +- tianshou/policy/imitation/base.py | 6 +- tianshou/policy/imitation/discrete_bcq.py | 25 +- tianshou/policy/imitation/discrete_cql.py | 21 +- tianshou/policy/imitation/discrete_crr.py | 6 +- tianshou/policy/modelbased/psrl.py | 31 +- tianshou/policy/modelfree/a2c.py | 23 +- tianshou/policy/modelfree/c51.py | 23 +- tianshou/policy/modelfree/ddpg.py | 39 +- tianshou/policy/modelfree/discrete_sac.py | 38 +- tianshou/policy/modelfree/dqn.py | 20 +- tianshou/policy/modelfree/fqf.py | 48 +- tianshou/policy/modelfree/iqn.py | 27 +- tianshou/policy/modelfree/npg.py | 29 +- tianshou/policy/modelfree/pg.py | 18 +- tianshou/policy/modelfree/ppo.py | 15 +- tianshou/policy/modelfree/qrdqn.py | 24 +- tianshou/policy/modelfree/rainbow.py | 3 +- tianshou/policy/modelfree/sac.py | 35 +- tianshou/policy/modelfree/td3.py | 25 +- tianshou/policy/modelfree/trpo.py | 40 +- tianshou/policy/multiagent/mapolicy.py | 50 +- tianshou/policy/random.py | 4 +- tianshou/trainer/__init__.py | 1 + tianshou/trainer/offline.py | 34 +- tianshou/trainer/offpolicy.py | 45 +- tianshou/trainer/onpolicy.py | 62 ++- tianshou/trainer/utils.py | 19 +- tianshou/utils/__init__.py | 15 +- tianshou/utils/logger/base.py | 9 +- tianshou/utils/logger/tensorboard.py | 17 +- tianshou/utils/logger/wandb.py | 1 - tianshou/utils/net/common.py | 52 +- tianshou/utils/net/continuous.py | 70 ++- tianshou/utils/net/discrete.py | 72 +-- tianshou/utils/statistics.py | 13 +- 127 files changed, 4973 insertions(+), 2728 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index c258cf0a7..bcf2a9b1c 100644 
--- a/docs/conf.py +++ b/docs/conf.py @@ -14,9 +14,9 @@ # import sys # sys.path.insert(0, os.path.abspath('.')) +import sphinx_rtd_theme import tianshou -import sphinx_rtd_theme # Get the version string version = tianshou.__version__ @@ -30,7 +30,6 @@ # The full version, including alpha/beta/rc tags release = version - # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be @@ -59,7 +58,8 @@ # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] autodoc_default_options = { - "special-members": ", ".join( + "special-members": + ", ".join( [ "__len__", "__call__", diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py index b4cd5b62d..1be441013 100644 --- a/examples/atari/atari_bcq.py +++ b/examples/atari/atari_bcq.py @@ -1,21 +1,21 @@ +import argparse +import datetime import os -import torch import pickle import pprint -import datetime -import argparse + import numpy as np +import torch +from atari_network import DQN +from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv +from tianshou.policy import DiscreteBCQPolicy from tianshou.trainer import offline_trainer +from tianshou.utils import TensorboardLogger from tianshou.utils.net.discrete import Actor -from tianshou.policy import DiscreteBCQPolicy -from tianshou.data import Collector, VectorReplayBuffer - -from atari_network import DQN -from atari_wrapper import wrap_deepmind def get_args(): @@ -38,15 +38,19 @@ def get_args(): parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) 
parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--watch", default=False, action="store_true", - help="watch the play of pre-trained policy only") + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="watch the play of pre-trained policy only" + ) parser.add_argument("--log-interval", type=int, default=100) parser.add_argument( - "--load-buffer-name", type=str, - default="./expert_DQN_PongNoFrameskip-v4.hdf5") + "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5" + ) parser.add_argument( - "--device", type=str, - default="cuda" if torch.cuda.is_available() else "cpu") + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) args = parser.parse_known_args()[0] return args @@ -56,8 +60,12 @@ def make_atari_env(args): def make_atari_env_watch(args): - return wrap_deepmind(args.task, frame_stack=args.frames_stack, - episode_life=False, clip_rewards=False) + return wrap_deepmind( + args.task, + frame_stack=args.frames_stack, + episode_life=False, + clip_rewards=False + ) def test_discrete_bcq(args=get_args()): @@ -69,32 +77,43 @@ def test_discrete_bcq(args=get_args()): print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # make environments - test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) - for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv( + [lambda: make_atari_env_watch(args) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - feature_net = DQN(*args.state_shape, args.action_shape, - device=args.device, features_only=True).to(args.device) + feature_net = DQN( + *args.state_shape, args.action_shape, device=args.device, features_only=True + ).to(args.device) policy_net = Actor( - feature_net, args.action_shape, device=args.device, - hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device) + feature_net, + args.action_shape, + device=args.device, + hidden_sizes=args.hidden_sizes, + softmax_output=False + ).to(args.device) imitation_net = Actor( - feature_net, args.action_shape, device=args.device, - hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device) + feature_net, + args.action_shape, + device=args.device, + hidden_sizes=args.hidden_sizes, + softmax_output=False + ).to(args.device) optim = torch.optim.Adam( - list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr) + list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr + ) # define policy policy = DiscreteBCQPolicy( policy_net, imitation_net, optim, args.gamma, args.n_step, - args.target_update_freq, args.eps_test, - args.unlikely_action_threshold, args.imitation_logits_penalty) + args.target_update_freq, args.eps_test, args.unlikely_action_threshold, + args.imitation_logits_penalty + ) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load( - args.resume_path, map_location=args.device)) + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # buffer assert os.path.exists(args.load_buffer_name), \ @@ -113,7 +132,8 @@ def test_discrete_bcq(args=get_args()): # log log_path = os.path.join( args.logdir, args.task, 'bcq', - f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}') + f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}' + ) writer = SummaryWriter(log_path) 
writer.add_text("args", str(args)) logger = TensorboardLogger(writer, update_interval=args.log_interval) @@ -132,8 +152,7 @@ def watch(): test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) rew = result["rews"].mean() print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') @@ -143,9 +162,17 @@ def watch(): exit(0) result = offline_trainer( - policy, buffer, test_collector, args.epoch, - args.update_per_epoch, args.test_num, args.batch_size, - stop_fn=stop_fn, save_fn=save_fn, logger=logger) + policy, + buffer, + test_collector, + args.epoch, + args.update_per_epoch, + args.test_num, + args.batch_size, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index b70cc6cc0..291fb7007 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -1,18 +1,18 @@ +import argparse import os -import torch import pprint -import argparse + import numpy as np +import torch +from atari_network import C51 +from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import C51Policy -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv +from tianshou.policy import C51Policy from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer - -from atari_network import C51 -from atari_wrapper import wrap_deepmind +from tianshou.utils import TensorboardLogger def get_args(): @@ -40,12 +40,16 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) parser.add_argument('--save-buffer-name', type=str, default=None) return parser.parse_args() @@ -55,8 +59,12 @@ def make_atari_env(args): def make_atari_env_watch(args): - return wrap_deepmind(args.task, frame_stack=args.frames_stack, - episode_life=False, clip_rewards=False) + return wrap_deepmind( + args.task, + frame_stack=args.frames_stack, + episode_life=False, + clip_rewards=False + ) def test_c51(args=get_args()): @@ -67,23 +75,30 @@ def test_c51(args=get_args()): print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # make environments - train_envs = SubprocVectorEnv([lambda: make_atari_env(args) - for _ in range(args.training_num)]) - test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) - for _ in range(args.test_num)]) + train_envs = SubprocVectorEnv( + [lambda: make_atari_env(args) for _ in range(args.training_num)] + ) + test_envs = SubprocVectorEnv( + [lambda: make_atari_env_watch(args) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # define model - net = C51(*args.state_shape, args.action_shape, - args.num_atoms, args.device) + net = C51(*args.state_shape, args.action_shape, args.num_atoms, args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy policy = C51Policy( - net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max, - args.n_step, target_update_freq=args.target_update_freq + net, + optim, + args.gamma, + args.num_atoms, + args.v_min, + args.v_max, + args.n_step, + target_update_freq=args.target_update_freq ).to(args.device) # load a previous policy if args.resume_path: @@ -92,8 +107,12 @@ def test_c51(args=get_args()): # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True, - save_only_last_obs=True, stack_num=args.frames_stack) + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) @@ -136,11 +155,13 @@ def watch(): if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(test_envs), - ignore_obs_next=True, save_only_last_obs=True, - stack_num=args.frames_stack) - collector = Collector(policy, test_envs, buffer, - exploration_noise=True) + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) + collector = Collector(policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will 
cause oom with 1M buffer size @@ -148,8 +169,9 @@ def watch(): else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect( + n_episode=args.test_num, render=args.render + ) rew = result["rews"].mean() print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') @@ -161,11 +183,22 @@ def watch(): train_collector.collect(n_step=args.batch_size * args.training_num) # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_cql.py b/examples/atari/atari_cql.py index cbab82029..db4e33a9a 100644 --- a/examples/atari/atari_cql.py +++ b/examples/atari/atari_cql.py @@ -1,20 +1,20 @@ +import argparse +import datetime import os -import torch import pickle import pprint -import datetime -import argparse + import numpy as np +import torch +from atari_network import QRDQN +from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv -from tianshou.trainer import offline_trainer from tianshou.policy import DiscreteCQLPolicy -from tianshou.data import Collector, VectorReplayBuffer - -from atari_network import QRDQN -from atari_wrapper import wrap_deepmind +from tianshou.trainer import offline_trainer +from tianshou.utils import TensorboardLogger def get_args(): @@ -37,15 +37,19 @@ def get_args(): parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) 
parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--watch", default=False, action="store_true", - help="watch the play of pre-trained policy only") + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="watch the play of pre-trained policy only" + ) parser.add_argument("--log-interval", type=int, default=100) parser.add_argument( - "--load-buffer-name", type=str, - default="./expert_DQN_PongNoFrameskip-v4.hdf5") + "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5" + ) parser.add_argument( - "--device", type=str, - default="cuda" if torch.cuda.is_available() else "cpu") + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) args = parser.parse_known_args()[0] return args @@ -55,8 +59,12 @@ def make_atari_env(args): def make_atari_env_watch(args): - return wrap_deepmind(args.task, frame_stack=args.frames_stack, - episode_life=False, clip_rewards=False) + return wrap_deepmind( + args.task, + frame_stack=args.frames_stack, + episode_life=False, + clip_rewards=False + ) def test_discrete_cql(args=get_args()): @@ -68,25 +76,29 @@ def test_discrete_cql(args=get_args()): print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # make environments - test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) - for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv( + [lambda: make_atari_env_watch(args) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - net = QRDQN(*args.state_shape, args.action_shape, - args.num_quantiles, args.device) + net = QRDQN(*args.state_shape, args.action_shape, args.num_quantiles, args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy policy = DiscreteCQLPolicy( - net, optim, args.gamma, args.num_quantiles, args.n_step, - args.target_update_freq, min_q_weight=args.min_q_weight + net, + optim, + args.gamma, + args.num_quantiles, + args.n_step, + args.target_update_freq, + min_q_weight=args.min_q_weight ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load( - args.resume_path, map_location=args.device)) + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # buffer assert os.path.exists(args.load_buffer_name), \ @@ -105,7 +117,8 @@ def test_discrete_cql(args=get_args()): # log log_path = os.path.join( args.logdir, args.task, 'cql', - f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}') + f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}' + ) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger = TensorboardLogger(writer, update_interval=args.log_interval) @@ -124,8 +137,7 @@ def watch(): test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) rew = result["rews"].mean() print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') @@ -135,9 +147,17 @@ def watch(): exit(0) result = offline_trainer( - policy, buffer, test_collector, args.epoch, - args.update_per_epoch, args.test_num, args.batch_size, - stop_fn=stop_fn, save_fn=save_fn, logger=logger) + policy, + buffer, + test_collector, + args.epoch, + args.update_per_epoch, + 
args.test_num, + args.batch_size, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_crr.py b/examples/atari/atari_crr.py index e8e1ba54e..06cde415b 100644 --- a/examples/atari/atari_crr.py +++ b/examples/atari/atari_crr.py @@ -1,21 +1,21 @@ +import argparse +import datetime import os -import torch import pickle import pprint -import datetime -import argparse + import numpy as np +import torch +from atari_network import DQN +from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv +from tianshou.policy import DiscreteCRRPolicy from tianshou.trainer import offline_trainer +from tianshou.utils import TensorboardLogger from tianshou.utils.net.discrete import Actor -from tianshou.policy import DiscreteCRRPolicy -from tianshou.data import Collector, VectorReplayBuffer - -from atari_network import DQN -from atari_wrapper import wrap_deepmind def get_args(): @@ -38,15 +38,19 @@ def get_args(): parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) parser.add_argument("--resume-path", type=str, default=None) - parser.add_argument("--watch", default=False, action="store_true", - help="watch the play of pre-trained policy only") + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="watch the play of pre-trained policy only" + ) parser.add_argument("--log-interval", type=int, default=100) parser.add_argument( - "--load-buffer-name", type=str, - default="./expert_DQN_PongNoFrameskip-v4.hdf5") + "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5" + ) parser.add_argument( - "--device", type=str, - default="cuda" if torch.cuda.is_available() else "cpu") + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) args = parser.parse_known_args()[0] return args @@ -56,8 +60,12 @@ def make_atari_env(args): def make_atari_env_watch(args): - return wrap_deepmind(args.task, frame_stack=args.frames_stack, - episode_life=False, clip_rewards=False) + return wrap_deepmind( + args.task, + frame_stack=args.frames_stack, + episode_life=False, + clip_rewards=False + ) def test_discrete_crr(args=get_args()): @@ -69,33 +77,44 @@ def test_discrete_crr(args=get_args()): print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # make environments - test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) - for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv( + [lambda: make_atari_env_watch(args) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - feature_net = DQN(*args.state_shape, args.action_shape, - device=args.device, features_only=True).to(args.device) - actor = Actor(feature_net, args.action_shape, device=args.device, - hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device) + feature_net = DQN( + *args.state_shape, args.action_shape, device=args.device, features_only=True + ).to(args.device) + actor = Actor( + feature_net, + args.action_shape, + device=args.device, + hidden_sizes=args.hidden_sizes, + softmax_output=False + ).to(args.device) critic = DQN(*args.state_shape, args.action_shape, device=args.device).to(args.device) - optim = torch.optim.Adam(list(actor.parameters()) + 
list(critic.parameters()), - lr=args.lr) + optim = torch.optim.Adam( + list(actor.parameters()) + list(critic.parameters()), lr=args.lr + ) # define policy policy = DiscreteCRRPolicy( - actor, critic, optim, args.gamma, + actor, + critic, + optim, + args.gamma, policy_improvement_mode=args.policy_improvement_mode, - ratio_upper_bound=args.ratio_upper_bound, beta=args.beta, + ratio_upper_bound=args.ratio_upper_bound, + beta=args.beta, min_q_weight=args.min_q_weight, target_update_freq=args.target_update_freq ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load( - args.resume_path, map_location=args.device)) + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # buffer assert os.path.exists(args.load_buffer_name), \ @@ -114,7 +133,8 @@ def test_discrete_crr(args=get_args()): # log log_path = os.path.join( args.logdir, args.task, 'crr', - f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}') + f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}' + ) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger = TensorboardLogger(writer, update_interval=args.log_interval) @@ -132,8 +152,7 @@ def watch(): test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) rew = result["rews"].mean() print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') @@ -143,9 +162,17 @@ def watch(): exit(0) result = offline_trainer( - policy, buffer, test_collector, args.epoch, - args.update_per_epoch, args.test_num, args.batch_size, - stop_fn=stop_fn, save_fn=save_fn, logger=logger) + policy, + buffer, + test_collector, + args.epoch, + args.update_per_epoch, + args.test_num, + args.batch_size, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 69ec08349..c9f74af8c 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -1,18 +1,18 @@ +import argparse import os -import torch import pprint -import argparse + import numpy as np +import torch +from atari_network import DQN +from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import DQNPolicy -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv +from tianshou.policy import DQNPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer - -from atari_network import DQN -from atari_wrapper import wrap_deepmind +from tianshou.utils import TensorboardLogger def get_args(): @@ -37,12 +37,16 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) parser.add_argument('--save-buffer-name', type=str, default=None) return parser.parse_args() @@ -52,8 +56,12 @@ def make_atari_env(args): def make_atari_env_watch(args): - return wrap_deepmind(args.task, frame_stack=args.frames_stack, - episode_life=False, clip_rewards=False) + return wrap_deepmind( + args.task, + frame_stack=args.frames_stack, + episode_life=False, + clip_rewards=False + ) def test_dqn(args=get_args()): @@ -64,22 +72,28 @@ def test_dqn(args=get_args()): print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # make environments - train_envs = SubprocVectorEnv([lambda: make_atari_env(args) - for _ in range(args.training_num)]) - test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) - for _ in range(args.test_num)]) + train_envs = SubprocVectorEnv( + [lambda: make_atari_env(args) for _ in range(args.training_num)] + ) + test_envs = SubprocVectorEnv( + [lambda: make_atari_env_watch(args) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # define model - net = DQN(*args.state_shape, - args.action_shape, args.device).to(args.device) + net = DQN(*args.state_shape, args.action_shape, args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy = DQNPolicy(net, optim, args.gamma, args.n_step, - target_update_freq=args.target_update_freq) + policy = DQNPolicy( + net, + optim, + args.gamma, + args.n_step, + target_update_freq=args.target_update_freq + ) # load a previous policy if args.resume_path: policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) @@ -87,8 +101,12 @@ def test_dqn(args=get_args()): # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True, - save_only_last_obs=True, stack_num=args.frames_stack) + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) @@ -131,11 +149,13 @@ def watch(): if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(test_envs), - ignore_obs_next=True, save_only_last_obs=True, - stack_num=args.frames_stack) - collector = Collector(policy, test_envs, buffer, - exploration_noise=True) + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) + collector = Collector(policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will 
cause oom with 1M buffer size @@ -143,8 +163,9 @@ def watch(): else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect( + n_episode=args.test_num, render=args.render + ) rew = result["rews"].mean() print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') @@ -156,11 +177,22 @@ def watch(): train_collector.collect(n_step=args.batch_size * args.training_num) # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 4a6e97c06..4629bede2 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -1,20 +1,20 @@ +import argparse import os -import torch import pprint -import argparse + import numpy as np +import torch +from atari_network import DQN +from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import FQFPolicy -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv +from tianshou.policy import FQFPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.utils import TensorboardLogger from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction -from atari_network import DQN -from atari_wrapper import wrap_deepmind - def get_args(): parser = argparse.ArgumentParser() @@ -43,12 +43,16 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) parser.add_argument('--save-buffer-name', type=str, default=None) return parser.parse_args() @@ -58,8 +62,12 @@ def make_atari_env(args): def make_atari_env_watch(args): - return wrap_deepmind(args.task, frame_stack=args.frames_stack, - episode_life=False, clip_rewards=False) + return wrap_deepmind( + args.task, + frame_stack=args.frames_stack, + episode_life=False, + clip_rewards=False + ) def test_fqf(args=get_args()): @@ -70,30 +78,43 @@ def test_fqf(args=get_args()): print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # make environments - train_envs = SubprocVectorEnv([lambda: make_atari_env(args) - for _ in range(args.training_num)]) - test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) - for _ in range(args.test_num)]) + train_envs = SubprocVectorEnv( + [lambda: make_atari_env(args) for _ in range(args.training_num)] + ) + test_envs = SubprocVectorEnv( + [lambda: make_atari_env_watch(args) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # define model - feature_net = DQN(*args.state_shape, args.action_shape, args.device, - features_only=True) + feature_net = DQN( + *args.state_shape, args.action_shape, args.device, features_only=True + ) net = FullQuantileFunction( - feature_net, args.action_shape, args.hidden_sizes, - args.num_cosines, device=args.device + feature_net, + args.action_shape, + args.hidden_sizes, + args.num_cosines, + device=args.device ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) - fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), - lr=args.fraction_lr) + fraction_optim = torch.optim.RMSprop( + fraction_net.parameters(), lr=args.fraction_lr + ) # define policy policy = FQFPolicy( - net, optim, fraction_net, fraction_optim, - args.gamma, args.num_fractions, args.ent_coef, args.n_step, + net, + optim, + fraction_net, + fraction_optim, + args.gamma, + args.num_fractions, + args.ent_coef, + args.n_step, target_update_freq=args.target_update_freq ).to(args.device) # load a previous policy @@ -103,8 +124,12 @@ def test_fqf(args=get_args()): # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), - ignore_obs_next=True, save_only_last_obs=True, stack_num=args.frames_stack) + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) @@ -147,11 +172,13 @@ def watch(): if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") buffer = VectorReplayBuffer( - args.buffer_size, 
buffer_num=len(test_envs), - ignore_obs_next=True, save_only_last_obs=True, - stack_num=args.frames_stack) - collector = Collector(policy, test_envs, buffer, - exploration_noise=True) + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) + collector = Collector(policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -159,8 +186,9 @@ def watch(): else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect( + n_episode=args.test_num, render=args.render + ) rew = result["rews"].mean() print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') @@ -172,11 +200,22 @@ def watch(): train_collector.collect(n_step=args.batch_size * args.training_num) # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index e5966a318..d0e7773d0 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -1,20 +1,20 @@ +import argparse import os -import torch import pprint -import argparse + import numpy as np +import torch +from atari_network import DQN +from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import IQNPolicy -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv +from tianshou.policy import IQNPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.utils import TensorboardLogger from tianshou.utils.net.discrete import ImplicitQuantileNetwork -from atari_network import DQN -from atari_wrapper import wrap_deepmind - def get_args(): parser = argparse.ArgumentParser() @@ -43,12 +43,16 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) parser.add_argument('--save-buffer-name', type=str, default=None) return parser.parse_args() @@ -58,8 +62,12 @@ def make_atari_env(args): def make_atari_env_watch(args): - return wrap_deepmind(args.task, frame_stack=args.frames_stack, - episode_life=False, clip_rewards=False) + return wrap_deepmind( + args.task, + frame_stack=args.frames_stack, + episode_life=False, + clip_rewards=False + ) def test_iqn(args=get_args()): @@ -70,27 +78,38 @@ def test_iqn(args=get_args()): print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # make environments - train_envs = SubprocVectorEnv([lambda: make_atari_env(args) - for _ in range(args.training_num)]) - test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) - for _ in range(args.test_num)]) + train_envs = SubprocVectorEnv( + [lambda: make_atari_env(args) for _ in range(args.training_num)] + ) + test_envs = SubprocVectorEnv( + [lambda: make_atari_env_watch(args) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # define model - feature_net = DQN(*args.state_shape, args.action_shape, args.device, - features_only=True) + feature_net = DQN( + *args.state_shape, args.action_shape, args.device, features_only=True + ) net = ImplicitQuantileNetwork( - feature_net, args.action_shape, args.hidden_sizes, - num_cosines=args.num_cosines, device=args.device + feature_net, + args.action_shape, + args.hidden_sizes, + num_cosines=args.num_cosines, + device=args.device ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy policy = IQNPolicy( - net, optim, args.gamma, args.sample_size, args.online_sample_size, - args.target_sample_size, args.n_step, + net, + optim, + args.gamma, + args.sample_size, + args.online_sample_size, + args.target_sample_size, + args.n_step, target_update_freq=args.target_update_freq ).to(args.device) # load a previous policy @@ -100,8 +119,12 @@ def test_iqn(args=get_args()): # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), - ignore_obs_next=True, save_only_last_obs=True, stack_num=args.frames_stack) + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) @@ -144,11 +167,13 @@ def watch(): if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(test_envs), - ignore_obs_next=True, save_only_last_obs=True, - stack_num=args.frames_stack) - collector = Collector(policy, test_envs, buffer, - exploration_noise=True) + args.buffer_size, + buffer_num=len(test_envs), + 
ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) + collector = Collector(policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -156,8 +181,9 @@ def watch(): else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect( + n_episode=args.test_num, render=args.render + ) rew = result["rews"].mean() print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') @@ -169,11 +195,22 @@ def watch(): train_collector.collect(n_step=args.batch_size * args.training_num) # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index 2eccf11af..3fb208d44 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -1,7 +1,9 @@ -import torch +from typing import Any, Dict, Optional, Sequence, Tuple, Union + import numpy as np +import torch from torch import nn -from typing import Any, Dict, Tuple, Union, Optional, Sequence + from tianshou.utils.net.discrete import NoisyLinear @@ -11,7 +13,6 @@ class DQN(nn.Module): For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. """ - def __init__( self, c: int, @@ -27,15 +28,15 @@ def __init__( nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True), nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True), nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True), - nn.Flatten()) + nn.Flatten() + ) with torch.no_grad(): - self.output_dim = np.prod( - self.net(torch.zeros(1, c, h, w)).shape[1:]) + self.output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:]) if not features_only: self.net = nn.Sequential( - self.net, - nn.Linear(self.output_dim, 512), nn.ReLU(inplace=True), - nn.Linear(512, np.prod(action_shape))) + self.net, nn.Linear(self.output_dim, 512), nn.ReLU(inplace=True), + nn.Linear(512, np.prod(action_shape)) + ) self.output_dim = np.prod(action_shape) def forward( @@ -55,7 +56,6 @@ class C51(DQN): For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. """ - def __init__( self, c: int, @@ -88,7 +88,6 @@ class Rainbow(DQN): For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. 
""" - def __init__( self, c: int, @@ -113,12 +112,14 @@ def linear(x, y): self.Q = nn.Sequential( linear(self.output_dim, 512), nn.ReLU(inplace=True), - linear(512, self.action_num * self.num_atoms)) + linear(512, self.action_num * self.num_atoms) + ) self._is_dueling = is_dueling if self._is_dueling: self.V = nn.Sequential( linear(self.output_dim, 512), nn.ReLU(inplace=True), - linear(512, self.num_atoms)) + linear(512, self.num_atoms) + ) self.output_dim = self.action_num * self.num_atoms def forward( @@ -148,7 +149,6 @@ class QRDQN(DQN): For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. """ - def __init__( self, c: int, diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 781f81d5d..23a7966eb 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -1,18 +1,18 @@ +import argparse import os -import torch import pprint -import argparse + import numpy as np +import torch +from atari_network import QRDQN +from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import TensorboardLogger -from tianshou.policy import QRDQNPolicy +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv +from tianshou.policy import QRDQNPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer - -from atari_network import QRDQN -from atari_wrapper import wrap_deepmind +from tianshou.utils import TensorboardLogger def get_args(): @@ -38,12 +38,16 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) parser.add_argument('--save-buffer-name', type=str, default=None) return parser.parse_args() @@ -53,8 +57,12 @@ def make_atari_env(args): def make_atari_env_watch(args): - return wrap_deepmind(args.task, frame_stack=args.frames_stack, - episode_life=False, clip_rewards=False) + return wrap_deepmind( + args.task, + frame_stack=args.frames_stack, + episode_life=False, + clip_rewards=False + ) def test_qrdqn(args=get_args()): @@ -65,23 +73,28 @@ def test_qrdqn(args=get_args()): print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # make environments - train_envs = SubprocVectorEnv([lambda: make_atari_env(args) - for _ in range(args.training_num)]) - test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) - for _ in range(args.test_num)]) + train_envs = SubprocVectorEnv( + [lambda: make_atari_env(args) for _ in range(args.training_num)] + ) + test_envs = SubprocVectorEnv( + [lambda: make_atari_env_watch(args) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # define model - net = QRDQN(*args.state_shape, args.action_shape, - args.num_quantiles, args.device) + net = QRDQN(*args.state_shape, args.action_shape, 
args.num_quantiles, args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy policy = QRDQNPolicy( - net, optim, args.gamma, args.num_quantiles, - args.n_step, target_update_freq=args.target_update_freq + net, + optim, + args.gamma, + args.num_quantiles, + args.n_step, + target_update_freq=args.target_update_freq ).to(args.device) # load a previous policy if args.resume_path: @@ -90,8 +103,12 @@ def test_qrdqn(args=get_args()): # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), - ignore_obs_next=True, save_only_last_obs=True, stack_num=args.frames_stack) + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) @@ -134,11 +151,13 @@ def watch(): if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(test_envs), - ignore_obs_next=True, save_only_last_obs=True, - stack_num=args.frames_stack) - collector = Collector(policy, test_envs, buffer, - exploration_noise=True) + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) + collector = Collector(policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -146,8 +165,9 @@ def watch(): else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect( + n_episode=args.test_num, render=args.render + ) rew = result["rews"].mean() print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') @@ -159,11 +179,22 @@ def watch(): train_collector.collect(n_step=args.batch_size * args.training_num) # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index f2f44f0cd..b131cce5f 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -1,19 +1,19 @@ +import argparse +import datetime import os -import torch import pprint -import datetime -import argparse + import numpy as np +import torch +from atari_network import Rainbow +from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import RainbowPolicy -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer from tianshou.env import SubprocVectorEnv +from tianshou.policy 
import RainbowPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer - -from atari_network import Rainbow -from atari_wrapper import wrap_deepmind +from tianshou.utils import TensorboardLogger def get_args(): @@ -50,12 +50,16 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) parser.add_argument('--save-buffer-name', type=str, default=None) return parser.parse_args() @@ -65,8 +69,12 @@ def make_atari_env(args): def make_atari_env_watch(args): - return wrap_deepmind(args.task, frame_stack=args.frames_stack, - episode_life=False, clip_rewards=False) + return wrap_deepmind( + args.task, + frame_stack=args.frames_stack, + episode_life=False, + clip_rewards=False + ) def test_rainbow(args=get_args()): @@ -77,25 +85,38 @@ def test_rainbow(args=get_args()): print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # make environments - train_envs = SubprocVectorEnv([lambda: make_atari_env(args) - for _ in range(args.training_num)]) - test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) - for _ in range(args.test_num)]) + train_envs = SubprocVectorEnv( + [lambda: make_atari_env(args) for _ in range(args.training_num)] + ) + test_envs = SubprocVectorEnv( + [lambda: make_atari_env_watch(args) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # define model - net = Rainbow(*args.state_shape, args.action_shape, - args.num_atoms, args.noisy_std, args.device, - is_dueling=not args.no_dueling, - is_noisy=not args.no_noisy) + net = Rainbow( + *args.state_shape, + args.action_shape, + args.num_atoms, + args.noisy_std, + args.device, + is_dueling=not args.no_dueling, + is_noisy=not args.no_noisy + ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy policy = RainbowPolicy( - net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max, - args.n_step, target_update_freq=args.target_update_freq + net, + optim, + args.gamma, + args.num_atoms, + args.v_min, + args.v_max, + args.n_step, + target_update_freq=args.target_update_freq ).to(args.device) # load a previous policy if args.resume_path: @@ -105,20 +126,31 @@ def test_rainbow(args=get_args()): # when you have enough RAM if args.no_priority: buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True, - save_only_last_obs=True, stack_num=args.frames_stack) + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) else: buffer = PrioritizedVectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True, - save_only_last_obs=True, stack_num=args.frames_stack, alpha=args.alpha, - beta=args.beta, weight_norm=not args.no_weight_norm) + 
args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + alpha=args.alpha, + beta=args.beta, + weight_norm=not args.no_weight_norm + ) # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) # log log_path = os.path.join( args.logdir, args.task, 'rainbow', - f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}') + f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}' + ) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) logger = TensorboardLogger(writer) @@ -164,12 +196,15 @@ def watch(): if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") buffer = PrioritizedVectorReplayBuffer( - args.buffer_size, buffer_num=len(test_envs), - ignore_obs_next=True, save_only_last_obs=True, - stack_num=args.frames_stack, alpha=args.alpha, - beta=args.beta) - collector = Collector(policy, test_envs, buffer, - exploration_noise=True) + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack, + alpha=args.alpha, + beta=args.beta + ) + collector = Collector(policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -177,8 +212,9 @@ def watch(): else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect( + n_episode=args.test_num, render=args.render + ) rew = result["rews"].mean() print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') @@ -190,11 +226,22 @@ def watch(): train_collector.collect(n_step=args.batch_size * args.training_num) # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) pprint.pprint(result) watch() diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index 53a662613..2c128c10f 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -1,10 +1,11 @@ # Borrow a lot from openai baselines: # https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py +from collections import deque + import cv2 import gym import numpy as np -from collections import deque class NoopResetEnv(gym.Wrapper): @@ -14,7 +15,6 @@ class NoopResetEnv(gym.Wrapper): :param gym.Env env: the environment to wrap. :param int noop_max: the maximum value of no-ops to run. """ - def __init__(self, env, noop_max=30): super().__init__(env) self.noop_max = noop_max @@ -38,7 +38,6 @@ class MaxAndSkipEnv(gym.Wrapper): :param gym.Env env: the environment to wrap. :param int skip: number of `skip`-th frame. 
""" - def __init__(self, env, skip=4): super().__init__(env) self._skip = skip @@ -64,7 +63,6 @@ class EpisodicLifeEnv(gym.Wrapper): :param gym.Env env: the environment to wrap. """ - def __init__(self, env): super().__init__(env) self.lives = 0 @@ -104,7 +102,6 @@ class FireResetEnv(gym.Wrapper): :param gym.Env env: the environment to wrap. """ - def __init__(self, env): super().__init__(env) assert env.unwrapped.get_action_meanings()[1] == 'FIRE' @@ -120,20 +117,20 @@ class WarpFrame(gym.ObservationWrapper): :param gym.Env env: the environment to wrap. """ - def __init__(self, env): super().__init__(env) self.size = 84 self.observation_space = gym.spaces.Box( low=np.min(env.observation_space.low), high=np.max(env.observation_space.high), - shape=(self.size, self.size), dtype=env.observation_space.dtype) + shape=(self.size, self.size), + dtype=env.observation_space.dtype + ) def observation(self, frame): """returns the current observation from a frame""" frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) - return cv2.resize(frame, (self.size, self.size), - interpolation=cv2.INTER_AREA) + return cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA) class ScaledFloatFrame(gym.ObservationWrapper): @@ -141,7 +138,6 @@ class ScaledFloatFrame(gym.ObservationWrapper): :param gym.Env env: the environment to wrap. """ - def __init__(self, env): super().__init__(env) low = np.min(env.observation_space.low) @@ -149,8 +145,8 @@ def __init__(self, env): self.bias = low self.scale = high - low self.observation_space = gym.spaces.Box( - low=0., high=1., shape=env.observation_space.shape, - dtype=np.float32) + low=0., high=1., shape=env.observation_space.shape, dtype=np.float32 + ) def observation(self, observation): return (observation - self.bias) / self.scale @@ -161,7 +157,6 @@ class ClipRewardEnv(gym.RewardWrapper): :param gym.Env env: the environment to wrap. """ - def __init__(self, env): super().__init__(env) self.reward_range = (-1, 1) @@ -177,16 +172,17 @@ class FrameStack(gym.Wrapper): :param gym.Env env: the environment to wrap. :param int n_frames: the number of frames to stack. """ - def __init__(self, env, n_frames): super().__init__(env) self.n_frames = n_frames self.frames = deque([], maxlen=n_frames) - shape = (n_frames,) + env.observation_space.shape + shape = (n_frames, ) + env.observation_space.shape self.observation_space = gym.spaces.Box( low=np.min(env.observation_space.low), high=np.max(env.observation_space.high), - shape=shape, dtype=env.observation_space.dtype) + shape=shape, + dtype=env.observation_space.dtype + ) def reset(self): obs = self.env.reset() @@ -205,8 +201,14 @@ def _get_ob(self): return np.stack(self.frames, axis=0) -def wrap_deepmind(env_id, episode_life=True, clip_rewards=True, - frame_stack=4, scale=False, warp_frame=True): +def wrap_deepmind( + env_id, + episode_life=True, + clip_rewards=True, + frame_stack=4, + scale=False, + warp_frame=True +): """Configure environment for DeepMind-style Atari. The observation is channel-first: (c, h, w) instead of (h, w, c). 
diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 4889f7eb2..76246fd3e 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -1,17 +1,18 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import DQNPolicy +from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer def get_args(): @@ -31,17 +32,19 @@ def get_args(): parser.add_argument('--update-per-step', type=float, default=0.01) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128]) - parser.add_argument('--dueling-q-hidden-sizes', type=int, - nargs='*', default=[128, 128]) - parser.add_argument('--dueling-v-hidden-sizes', type=int, - nargs='*', default=[128, 128]) + parser.add_argument( + '--dueling-q-hidden-sizes', type=int, nargs='*', default=[128, 128] + ) + parser.add_argument( + '--dueling-v-hidden-sizes', type=int, nargs='*', default=[128, 128] + ) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) return parser.parse_args() @@ -52,10 +55,12 @@ def test_dqn(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -64,18 +69,28 @@ def test_dqn(args=get_args()): # model Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes} V_param = {"hidden_sizes": args.dueling_v_hidden_sizes} - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device, - dueling_param=(Q_param, V_param)).to(args.device) + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + dueling_param=(Q_param, V_param) + ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DQNPolicy( - net, optim, args.gamma, args.n_step, - target_update_freq=args.target_update_freq) + net, + optim, + args.gamma, + args.n_step, + target_update_freq=args.target_update_freq + ) # collector train_collector = Collector( - policy, train_envs, + policy, + train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), - exploration_noise=True) + exploration_noise=True + ) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size * args.training_num) @@ -105,10 
+120,21 @@ def test_fn(epoch, env_step):
 
     # trainer
     result = offpolicy_trainer(
-        policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
-        update_per_step=args.update_per_step, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, save_fn=save_fn, logger=logger)
+        policy,
+        train_collector,
+        test_collector,
+        args.epoch,
+        args.step_per_epoch,
+        args.step_per_collect,
+        args.test_num,
+        args.batch_size,
+        update_per_step=args.update_per_step,
+        train_fn=train_fn,
+        test_fn=test_fn,
+        stop_fn=stop_fn,
+        save_fn=save_fn,
+        logger=logger
+    )
     assert stop_fn(result['best_reward'])
 
     if __name__ == '__main__':
diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py
index 4caa50b94..bb2449dce 100644
--- a/examples/box2d/bipedal_hardcore_sac.py
+++ b/examples/box2d/bipedal_hardcore_sac.py
@@ -1,17 +1,18 @@
+import argparse
 import os
-import gym
-import torch
 import pprint
-import argparse
+
+import gym
 import numpy as np
+import torch
 from torch.utils.tensorboard import SummaryWriter
 
+from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.env import SubprocVectorEnv
 from tianshou.policy import SACPolicy
+from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import Net
-from tianshou.env import SubprocVectorEnv
-from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.utils.net.continuous import ActorProb, Critic
 
 
@@ -32,23 +33,21 @@ def get_args():
     parser.add_argument('--step-per-collect', type=int, default=10)
     parser.add_argument('--update-per-step', type=float, default=0.1)
     parser.add_argument('--batch-size', type=int, default=128)
-    parser.add_argument('--hidden-sizes', type=int,
-                        nargs='*', default=[128, 128])
+    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128])
     parser.add_argument('--training-num', type=int, default=10)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--n-step', type=int, default=4) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--resume-path', type=str, default=None) return parser.parse_args() class Wrapper(gym.Wrapper): """Env wrapper for reward scale, action repeat and removing done penalty""" - def __init__(self, env, action_repeat=3, reward_scale=5, rm_done=True): super().__init__(env) self.action_repeat = action_repeat @@ -75,13 +74,16 @@ def test_sac_bipedal(args=get_args()): args.action_shape = env.action_space.shape or env.action_space.n args.max_action = env.action_space.high[0] - train_envs = SubprocVectorEnv([ - lambda: Wrapper(gym.make(args.task)) - for _ in range(args.training_num)]) + train_envs = SubprocVectorEnv( + [lambda: Wrapper(gym.make(args.task)) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv([ - lambda: Wrapper(gym.make(args.task), reward_scale=1, rm_done=False) - for _ in range(args.test_num)]) + test_envs = SubprocVectorEnv( + [ + lambda: Wrapper(gym.make(args.task), reward_scale=1, rm_done=False) + for _ in range(args.test_num) + ] + ) # seed np.random.seed(args.seed) @@ -90,22 +92,33 @@ def test_sac_bipedal(args=get_args()): test_envs.seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) + net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb( - net_a, args.action_shape, max_action=args.max_action, - device=args.device, unbounded=True).to(args.device) + net_a, + args.action_shape, + max_action=args.max_action, + device=args.device, + unbounded=True + ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net_c1 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) + net_c1 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic1 = Critic(net_c1, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - net_c2 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) @@ -116,9 +129,18 @@ def test_sac_bipedal(args=get_args()): args.alpha = (target_entropy, log_alpha, alpha_optim) policy = SACPolicy( - actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - tau=args.tau, gamma=args.gamma, alpha=args.alpha, - estimation_step=args.n_step, action_space=env.action_space) + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + tau=args.tau, + gamma=args.gamma, + alpha=args.alpha, + estimation_step=args.n_step, + action_space=env.action_space + ) # load a previous policy if args.resume_path: policy.load_state_dict(torch.load(args.resume_path)) @@ -126,9 +148,11 @@ def test_sac_bipedal(args=get_args()): # collector train_collector = Collector( - policy, train_envs, + policy, + train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), - exploration_noise=True) + exploration_noise=True + ) test_collector = 
Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log @@ -144,10 +168,20 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, - update_per_step=args.update_per_step, test_in_train=False, - stop_fn=stop_fn, save_fn=save_fn, logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + test_in_train=False, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) if __name__ == '__main__': pprint.pprint(result) @@ -155,8 +189,7 @@ def stop_fn(mean_rewards): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render) rews, lens = result["rews"], result["lens"] print(f"Final reward: {rews.mean()}, length: {lens.mean()}") diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index bb73ac615..88f4c397b 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -1,17 +1,18 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import DQNPolicy +from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer -from tianshou.env import DummyVectorEnv, SubprocVectorEnv def get_args(): @@ -31,19 +32,20 @@ def get_args(): parser.add_argument('--step-per-collect', type=int, default=16) parser.add_argument('--update-per-step', type=float, default=0.0625) parser.add_argument('--batch-size', type=int, default=128) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128]) - parser.add_argument('--dueling-q-hidden-sizes', type=int, - nargs='*', default=[128, 128]) - parser.add_argument('--dueling-v-hidden-sizes', type=int, - nargs='*', default=[128, 128]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) + parser.add_argument( + '--dueling-q-hidden-sizes', type=int, nargs='*', default=[128, 128] + ) + parser.add_argument( + '--dueling-v-hidden-sizes', type=int, nargs='*', default=[128, 128] + ) parser.add_argument('--training-num', type=int, default=16) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) return parser.parse_args() @@ -54,10 +56,12 @@ def test_dqn(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -66,18 +70,28 @@ def test_dqn(args=get_args()): # model Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes} V_param = {"hidden_sizes": args.dueling_v_hidden_sizes} - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device, - dueling_param=(Q_param, V_param)).to(args.device) + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + dueling_param=(Q_param, V_param) + ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DQNPolicy( - net, optim, args.gamma, args.n_step, - target_update_freq=args.target_update_freq) + net, + optim, + args.gamma, + args.n_step, + target_update_freq=args.target_update_freq + ) # collector train_collector = Collector( - policy, train_envs, + policy, + train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), - exploration_noise=True) + exploration_noise=True + ) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size * args.training_num) @@ -93,7 +107,7 @@ def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold def train_fn(epoch, env_step): # exp decay - eps = max(args.eps_train * (1 - 5e-6) ** env_step, args.eps_test) + eps = max(args.eps_train * (1 - 5e-6)**env_step, args.eps_test) policy.set_eps(eps) def test_fn(epoch, env_step): @@ -101,10 +115,21 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, - update_per_step=args.update_per_step, stop_fn=stop_fn, train_fn=train_fn, - test_fn=test_fn, save_fn=save_fn, logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + stop_fn=stop_fn, + train_fn=train_fn, + test_fn=test_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index a43728be5..0638e8f61 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import SACPolicy -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import OUNoise -from tianshou.utils.net.common import Net +from tianshou.policy import 
SACPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -34,16 +35,15 @@ def get_args(): parser.add_argument('--step-per-collect', type=int, default=5) parser.add_argument('--update-per-step', type=float, default=0.2) parser.add_argument('--batch-size', type=int, default=128) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) parser.add_argument('--training-num', type=int, default=5) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument('--rew-norm', type=bool, default=False) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) return parser.parse_args() @@ -54,31 +54,43 @@ def test_sac(args=get_args()): args.max_action = env.action_space.high[0] # train_envs = gym.make(args.task) train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb( - net, args.action_shape, - max_action=args.max_action, device=args.device, unbounded=True + net, + args.action_shape, + max_action=args.max_action, + device=args.device, + unbounded=True ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net_c1 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, concat=True, - device=args.device) + net_c1 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic1 = Critic(net_c1, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - net_c2 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, concat=True, - device=args.device) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) @@ -89,16 +101,26 @@ def test_sac(args=get_args()): args.alpha = (target_entropy, log_alpha, alpha_optim) policy = SACPolicy( - actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - tau=args.tau, gamma=args.gamma, alpha=args.alpha, + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + tau=args.tau, + gamma=args.gamma, + alpha=args.alpha, reward_normalization=args.rew_norm, exploration_noise=OUNoise(0.0, args.noise_std), - action_space=env.action_space) + action_space=env.action_space + ) # collector 
train_collector = Collector( - policy, train_envs, + policy, + train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), - exploration_noise=True) + exploration_noise=True + ) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log @@ -114,10 +136,19 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, - update_per_step=args.update_per_step, stop_fn=stop_fn, - save_fn=save_fn, logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/examples/mujoco/analysis.py b/examples/mujoco/analysis.py index 01a2cf678..ed0bb6872 100755 --- a/examples/mujoco/analysis.py +++ b/examples/mujoco/analysis.py @@ -1,11 +1,12 @@ #!/usr/bin/env python3 -import re import argparse +import re +from collections import defaultdict + import numpy as np from tabulate import tabulate -from collections import defaultdict -from tools import find_all_files, group_files, csv2numpy +from tools import csv2numpy, find_all_files, group_files def numerical_anysis(root_dir, xlim, norm=False): @@ -20,13 +21,16 @@ def numerical_anysis(root_dir, xlim, norm=False): for f in csv_files: result = csv2numpy(f) if norm: - result = np.stack([ - result['env_step'], - result['rew'] - result['rew'][0], - result['rew:shaded']]) + result = np.stack( + [ + result['env_step'], result['rew'] - result['rew'][0], + result['rew:shaded'] + ] + ) else: - result = np.stack([ - result['env_step'], result['rew'], result['rew:shaded']]) + result = np.stack( + [result['env_step'], result['rew'], result['rew:shaded']] + ) if result[0, -1] < xlim: continue @@ -79,11 +83,17 @@ def numerical_anysis(root_dir, xlim, norm=False): if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('--xlim', type=int, default=1000000, - help='x-axis limitation (default: 1000000)') + parser.add_argument( + '--xlim', + type=int, + default=1000000, + help='x-axis limitation (default: 1000000)' + ) parser.add_argument('--root-dir', type=str) parser.add_argument( - '--norm', action="store_true", - help="Normalize all results according to environment.") + '--norm', + action="store_true", + help="Normalize all results according to environment." 
+ ) args = parser.parse_args() numerical_anysis(args.root_dir, args.xlim, norm=args.norm) diff --git a/examples/mujoco/gen_json.py b/examples/mujoco/gen_json.py index 0c0b113e9..5429cf8d8 100755 --- a/examples/mujoco/gen_json.py +++ b/examples/mujoco/gen_json.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 -import os import csv -import sys import json +import os +import sys def merge(rootdir): @@ -19,12 +19,14 @@ def merge(rootdir): algo = os.path.relpath(path, rootdir).upper() reader = csv.DictReader(open(os.path.join(path, filenames[0]))) for row in reader: - result.append({ - 'env_step': int(row['env_step']), - 'rew': float(row['rew']), - 'rew_std': float(row['rew:shaded']), - 'Agent': algo, - }) + result.append( + { + 'env_step': int(row['env_step']), + 'rew': float(row['rew']), + 'rew_std': float(row['rew:shaded']), + 'Agent': algo, + } + ) open(os.path.join(rootdir, 'result.json'), 'w').write(json.dumps(result)) diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index e9debc906..02978697b 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -1,24 +1,25 @@ #!/usr/bin/env python3 +import argparse +import datetime import os -import gym -import torch import pprint -import datetime -import argparse + +import gym import numpy as np +import torch from torch import nn +from torch.distributions import Independent, Normal from torch.optim.lr_scheduler import LambdaLR from torch.utils.tensorboard import SummaryWriter -from torch.distributions import Independent, Normal +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv from tianshou.policy import A2CPolicy +from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import onpolicy_trainer from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer def get_args(): @@ -48,11 +49,15 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) return parser.parse_args() @@ -63,16 +68,18 @@ def test_a2c(args=get_args()): args.max_action = env.action_space.high[0] print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), - np.max(env.action_space.high)) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) # train_envs = gym.make(args.task) train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)], - norm_obs=True) + [lambda: gym.make(args.task) for _ in range(args.training_num)], norm_obs=True + ) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( [lambda: gym.make(args.task) for _ in range(args.test_num)], - norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False) + norm_obs=True, + obs_rms=train_envs.obs_rms, + update_obs_rms=False + ) # seed np.random.seed(args.seed) @@ -80,12 +87,25 @@ def test_a2c(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) - actor = ActorProb(net_a, args.action_shape, max_action=args.max_action, - unbounded=True, device=args.device).to(args.device) - net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) + net_a = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) + actor = ActorProb( + net_a, + args.action_shape, + max_action=args.max_action, + unbounded=True, + device=args.device + ).to(args.device) + net_c = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) critic = Critic(net_c, device=args.device).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in list(actor.modules()) + list(critic.modules()): @@ -101,27 +121,43 @@ def test_a2c(args=get_args()): torch.nn.init.zeros_(m.bias) m.weight.data.copy_(0.01 * m.weight.data) - optim = torch.optim.RMSprop(list(actor.parameters()) + list(critic.parameters()), - lr=args.lr, eps=1e-5, alpha=0.99) + optim = torch.optim.RMSprop( + list(actor.parameters()) + list(critic.parameters()), + lr=args.lr, + eps=1e-5, + alpha=0.99 + ) lr_scheduler = None if args.lr_decay: # decay learning rate to 0 linearly max_update_num = np.ceil( - args.step_per_epoch / args.step_per_collect) * args.epoch + args.step_per_epoch / args.step_per_collect + ) * args.epoch lr_scheduler = LambdaLR( - optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num + ) def dist(*logits): return Independent(Normal(*logits), 1) - policy = A2CPolicy(actor, critic, optim, dist, discount_factor=args.gamma, - gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, - vf_coef=args.vf_coef, ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, action_scaling=True, - action_bound_method=args.bound_action_method, - lr_scheduler=lr_scheduler, 
action_space=env.action_space) + policy = A2CPolicy( + actor, + critic, + optim, + dist, + discount_factor=args.gamma, + gae_lambda=args.gae_lambda, + max_grad_norm=args.max_grad_norm, + vf_coef=args.vf_coef, + ent_coef=args.ent_coef, + reward_normalization=args.rew_norm, + action_scaling=True, + action_bound_method=args.bound_action_method, + lr_scheduler=lr_scheduler, + action_space=env.action_space + ) # load a previous policy if args.resume_path: @@ -149,10 +185,19 @@ def save_fn(policy): if not args.watch: # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, args.step_per_epoch, - args.repeat_per_collect, args.test_num, args.batch_size, - step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger, - test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + step_per_collect=args.step_per_collect, + save_fn=save_fn, + logger=logger, + test_in_train=False + ) pprint.pprint(result) # Let's watch its performance! diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index 28bac056a..8d436b573 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -1,22 +1,23 @@ #!/usr/bin/env python3 +import argparse +import datetime import os -import gym -import torch import pprint -import datetime -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import DDPGPolicy -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.env import SubprocVectorEnv -from tianshou.utils.net.common import Net from tianshou.exploration import GaussianNoise +from tianshou.policy import DDPGPolicy from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer def get_args(): @@ -42,11 +43,15 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) return parser.parse_args() @@ -58,17 +63,18 @@ def test_ddpg(args=get_args()): args.exploration_noise = args.exploration_noise * args.max_action print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), - np.max(env.action_space.high)) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) # train_envs = gym.make(args.task) if args.training_num > 1: train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) else: train_envs = gym.make(args.task) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -77,19 +83,29 @@ def test_ddpg(args=get_args()): # model net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor( - net_a, args.action_shape, max_action=args.max_action, - device=args.device).to(args.device) + net_a, args.action_shape, max_action=args.max_action, device=args.device + ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net_c = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) + net_c = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic = Critic(net_c, device=args.device).to(args.device) critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) policy = DDPGPolicy( - actor, actor_optim, critic, critic_optim, - tau=args.tau, gamma=args.gamma, + actor, + actor_optim, + critic, + critic_optim, + tau=args.tau, + gamma=args.gamma, exploration_noise=GaussianNoise(sigma=args.exploration_noise), - estimation_step=args.n_step, action_space=env.action_space) + estimation_step=args.n_step, + action_space=env.action_space + ) # load a previous policy if args.resume_path: @@ -118,10 +134,19 @@ def save_fn(policy): if not args.watch: # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) pprint.pprint(result) # Let's watch its performance! 
diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 00b2a1a2c..23883a119 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -1,24 +1,25 @@ #!/usr/bin/env python3 +import argparse +import datetime import os -import gym -import torch import pprint -import datetime -import argparse + +import gym import numpy as np +import torch from torch import nn +from torch.distributions import Independent, Normal from torch.optim.lr_scheduler import LambdaLR from torch.utils.tensorboard import SummaryWriter -from torch.distributions import Independent, Normal +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv from tianshou.policy import NPGPolicy +from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import onpolicy_trainer from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer def get_args(): @@ -26,8 +27,9 @@ def get_args(): parser.add_argument('--task', type=str, default='HalfCheetah-v3') parser.add_argument('--seed', type=int, default=0) parser.add_argument('--buffer-size', type=int, default=4096) - parser.add_argument('--hidden-sizes', type=int, nargs='*', - default=[64, 64]) # baselines [32, 32] + parser.add_argument( + '--hidden-sizes', type=int, nargs='*', default=[64, 64] + ) # baselines [32, 32] parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--epoch', type=int, default=100) @@ -49,11 +51,15 @@ def get_args(): parser.add_argument('--optim-critic-iters', type=int, default=20) parser.add_argument('--actor-step-size', type=float, default=0.1) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) return parser.parse_args() @@ -64,16 +70,18 @@ def test_npg(args=get_args()): args.max_action = env.action_space.high[0] print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), - np.max(env.action_space.high)) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) # train_envs = gym.make(args.task) train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)], - norm_obs=True) + [lambda: gym.make(args.task) for _ in range(args.training_num)], norm_obs=True + ) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( [lambda: gym.make(args.task) for _ in range(args.test_num)], - norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False) + norm_obs=True, + obs_rms=train_envs.obs_rms, + update_obs_rms=False + ) # seed np.random.seed(args.seed) @@ -81,12 +89,25 @@ def test_npg(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) - actor = ActorProb(net_a, 
args.action_shape, max_action=args.max_action, - unbounded=True, device=args.device).to(args.device) - net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) + net_a = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) + actor = ActorProb( + net_a, + args.action_shape, + max_action=args.max_action, + unbounded=True, + device=args.device + ).to(args.device) + net_c = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) critic = Critic(net_c, device=args.device).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in list(actor.modules()) + list(critic.modules()): @@ -107,22 +128,32 @@ def test_npg(args=get_args()): if args.lr_decay: # decay learning rate to 0 linearly max_update_num = np.ceil( - args.step_per_epoch / args.step_per_collect) * args.epoch + args.step_per_epoch / args.step_per_collect + ) * args.epoch lr_scheduler = LambdaLR( - optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num + ) def dist(*logits): return Independent(Normal(*logits), 1) - policy = NPGPolicy(actor, critic, optim, dist, discount_factor=args.gamma, - gae_lambda=args.gae_lambda, - reward_normalization=args.rew_norm, action_scaling=True, - action_bound_method=args.bound_action_method, - lr_scheduler=lr_scheduler, action_space=env.action_space, - advantage_normalization=args.norm_adv, - optim_critic_iters=args.optim_critic_iters, - actor_step_size=args.actor_step_size) + policy = NPGPolicy( + actor, + critic, + optim, + dist, + discount_factor=args.gamma, + gae_lambda=args.gae_lambda, + reward_normalization=args.rew_norm, + action_scaling=True, + action_bound_method=args.bound_action_method, + lr_scheduler=lr_scheduler, + action_space=env.action_space, + advantage_normalization=args.norm_adv, + optim_critic_iters=args.optim_critic_iters, + actor_step_size=args.actor_step_size + ) # load a previous policy if args.resume_path: @@ -150,10 +181,19 @@ def save_fn(policy): if not args.watch: # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, args.step_per_epoch, - args.repeat_per_collect, args.test_num, args.batch_size, - step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger, - test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + step_per_collect=args.step_per_collect, + save_fn=save_fn, + logger=logger, + test_in_train=False + ) pprint.pprint(result) # Let's watch its performance! 
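For reference, the `dist(*logits)` helper defined in these on-policy mujoco scripts builds a diagonal Gaussian over the action vector; a self-contained sketch with illustrative batch and action sizes:

    import torch
    from torch.distributions import Independent, Normal


    def dist(*logits):
        # Same construction as in the scripts above: Independent(..., 1) treats the
        # last (action) dimension as one event, so log_prob gives one value per sample.
        return Independent(Normal(*logits), 1)


    mu, sigma = torch.zeros(3, 6), torch.ones(3, 6)  # batch of 3, 6-dim actions
    d = dist(mu, sigma)
    act = d.sample()              # shape (3, 6)
    print(d.log_prob(act).shape)  # torch.Size([3]), i.e. summed over the 6 dims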
diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index fb3e7a0a2..01dc5aa3f 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -1,24 +1,25 @@ #!/usr/bin/env python3 +import argparse +import datetime import os -import gym -import torch import pprint -import datetime -import argparse + +import gym import numpy as np +import torch from torch import nn +from torch.distributions import Independent, Normal from torch.optim.lr_scheduler import LambdaLR from torch.utils.tensorboard import SummaryWriter -from torch.distributions import Independent, Normal +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv from tianshou.policy import PPOPolicy +from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import onpolicy_trainer from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer def get_args(): @@ -53,11 +54,15 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) return parser.parse_args() @@ -68,16 +73,18 @@ def test_ppo(args=get_args()): args.max_action = env.action_space.high[0] print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), - np.max(env.action_space.high)) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) # train_envs = gym.make(args.task) train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)], - norm_obs=True) + [lambda: gym.make(args.task) for _ in range(args.training_num)], norm_obs=True + ) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( [lambda: gym.make(args.task) for _ in range(args.test_num)], - norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False) + norm_obs=True, + obs_rms=train_envs.obs_rms, + update_obs_rms=False + ) # seed np.random.seed(args.seed) @@ -85,12 +92,25 @@ def test_ppo(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) - actor = ActorProb(net_a, args.action_shape, max_action=args.max_action, - unbounded=True, device=args.device).to(args.device) - net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) + net_a = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) + actor = ActorProb( + net_a, + args.action_shape, + max_action=args.max_action, + unbounded=True, + device=args.device + ).to(args.device) + net_c = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) critic = Critic(net_c, 
device=args.device).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in list(actor.modules()) + list(critic.modules()): @@ -107,29 +127,44 @@ def test_ppo(args=get_args()): m.weight.data.copy_(0.01 * m.weight.data) optim = torch.optim.Adam( - list(actor.parameters()) + list(critic.parameters()), lr=args.lr) + list(actor.parameters()) + list(critic.parameters()), lr=args.lr + ) lr_scheduler = None if args.lr_decay: # decay learning rate to 0 linearly max_update_num = np.ceil( - args.step_per_epoch / args.step_per_collect) * args.epoch + args.step_per_epoch / args.step_per_collect + ) * args.epoch lr_scheduler = LambdaLR( - optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num + ) def dist(*logits): return Independent(Normal(*logits), 1) - policy = PPOPolicy(actor, critic, optim, dist, discount_factor=args.gamma, - gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, - vf_coef=args.vf_coef, ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, action_scaling=True, - action_bound_method=args.bound_action_method, - lr_scheduler=lr_scheduler, action_space=env.action_space, - eps_clip=args.eps_clip, value_clip=args.value_clip, - dual_clip=args.dual_clip, advantage_normalization=args.norm_adv, - recompute_advantage=args.recompute_adv) + policy = PPOPolicy( + actor, + critic, + optim, + dist, + discount_factor=args.gamma, + gae_lambda=args.gae_lambda, + max_grad_norm=args.max_grad_norm, + vf_coef=args.vf_coef, + ent_coef=args.ent_coef, + reward_normalization=args.rew_norm, + action_scaling=True, + action_bound_method=args.bound_action_method, + lr_scheduler=lr_scheduler, + action_space=env.action_space, + eps_clip=args.eps_clip, + value_clip=args.value_clip, + dual_clip=args.dual_clip, + advantage_normalization=args.norm_adv, + recompute_advantage=args.recompute_adv + ) # load a previous policy if args.resume_path: @@ -157,10 +192,19 @@ def save_fn(policy): if not args.watch: # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, args.step_per_epoch, - args.repeat_per_collect, args.test_num, args.batch_size, - step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger, - test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + step_per_collect=args.step_per_collect, + save_fn=save_fn, + logger=logger, + test_in_train=False + ) pprint.pprint(result) # Let's watch its performance! 
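The `--lr-decay` branch shared by these scripts anneals the learning rate linearly to zero over the whole run; a standalone sketch of that schedule (the numbers are placeholders rather than the argparse defaults, and one scheduler step per policy update is assumed):

    import numpy as np
    import torch
    from torch.optim.lr_scheduler import LambdaLR

    step_per_epoch, step_per_collect, epoch, lr = 30000, 2048, 100, 3e-4  # placeholders

    optim = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=lr)

    # Same formula as in the diffs: the multiplicative factor reaches 0 after
    # max_update_num scheduler steps.
    max_update_num = np.ceil(step_per_epoch / step_per_collect) * epoch
    lr_scheduler = LambdaLR(optim, lr_lambda=lambda e: 1 - e / max_update_num)

    for _ in range(5):
        optim.step()
        lr_scheduler.step()
    print(optim.param_groups[0]['lr'])  # slightly below the initial 3e-4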
diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index b7698562a..914b46251 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -1,24 +1,25 @@ #!/usr/bin/env python3 +import argparse +import datetime import os -import gym -import torch import pprint -import datetime -import argparse + +import gym import numpy as np +import torch from torch import nn +from torch.distributions import Independent, Normal from torch.optim.lr_scheduler import LambdaLR from torch.utils.tensorboard import SummaryWriter -from torch.distributions import Independent, Normal +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv from tianshou.policy import PGPolicy +from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import onpolicy_trainer from tianshou.utils.net.continuous import ActorProb -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer def get_args(): @@ -45,11 +46,15 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) return parser.parse_args() @@ -60,16 +65,18 @@ def test_reinforce(args=get_args()): args.max_action = env.action_space.high[0] print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), - np.max(env.action_space.high)) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) # train_envs = gym.make(args.task) train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)], - norm_obs=True) + [lambda: gym.make(args.task) for _ in range(args.training_num)], norm_obs=True + ) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( [lambda: gym.make(args.task) for _ in range(args.test_num)], - norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False) + norm_obs=True, + obs_rms=train_envs.obs_rms, + update_obs_rms=False + ) # seed np.random.seed(args.seed) @@ -77,10 +84,19 @@ def test_reinforce(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) - actor = ActorProb(net_a, args.action_shape, max_action=args.max_action, - unbounded=True, device=args.device).to(args.device) + net_a = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) + actor = ActorProb( + net_a, + args.action_shape, + max_action=args.max_action, + unbounded=True, + device=args.device + ).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in actor.modules(): if isinstance(m, torch.nn.Linear): @@ -100,18 +116,27 @@ def test_reinforce(args=get_args()): if args.lr_decay: # decay learning rate to 0 linearly 
max_update_num = np.ceil( - args.step_per_epoch / args.step_per_collect) * args.epoch + args.step_per_epoch / args.step_per_collect + ) * args.epoch lr_scheduler = LambdaLR( - optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num + ) def dist(*logits): return Independent(Normal(*logits), 1) - policy = PGPolicy(actor, optim, dist, discount_factor=args.gamma, - reward_normalization=args.rew_norm, action_scaling=True, - action_bound_method=args.action_bound_method, - lr_scheduler=lr_scheduler, action_space=env.action_space) + policy = PGPolicy( + actor, + optim, + dist, + discount_factor=args.gamma, + reward_normalization=args.rew_norm, + action_scaling=True, + action_bound_method=args.action_bound_method, + lr_scheduler=lr_scheduler, + action_space=env.action_space + ) # load a previous policy if args.resume_path: @@ -139,10 +164,19 @@ def save_fn(policy): if not args.watch: # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, args.step_per_epoch, - args.repeat_per_collect, args.test_num, args.batch_size, - step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger, - test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + step_per_collect=args.step_per_collect, + save_fn=save_fn, + logger=logger, + test_in_train=False + ) pprint.pprint(result) # Let's watch its performance! diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index c2dfd3618..cb764f473 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -1,21 +1,22 @@ #!/usr/bin/env python3 +import argparse +import datetime import os -import gym -import torch import pprint -import datetime -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv from tianshou.policy import SACPolicy +from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import offpolicy_trainer from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer def get_args(): @@ -43,11 +44,15 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) return parser.parse_args() @@ -58,17 +63,18 @@ def test_sac(args=get_args()): args.max_action = env.action_space.high[0] print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), - np.max(env.action_space.high)) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) # train_envs = gym.make(args.task) if args.training_num > 1: train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) else: train_envs = gym.make(args.task) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -77,16 +83,28 @@ def test_sac(args=get_args()): # model net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb( - net_a, args.action_shape, max_action=args.max_action, - device=args.device, unbounded=True, conditioned_sigma=True + net_a, + args.action_shape, + max_action=args.max_action, + device=args.device, + unbounded=True, + conditioned_sigma=True ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net_c1 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) - net_c2 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) + net_c1 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic1 = Critic(net_c1, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) critic2 = Critic(net_c2, device=args.device).to(args.device) @@ -99,9 +117,18 @@ def test_sac(args=get_args()): args.alpha = (target_entropy, log_alpha, alpha_optim) policy = SACPolicy( - actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - tau=args.tau, gamma=args.gamma, alpha=args.alpha, - estimation_step=args.n_step, action_space=env.action_space) + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + tau=args.tau, + gamma=args.gamma, + alpha=args.alpha, + estimation_step=args.n_step, + action_space=env.action_space + ) # load a previous policy if args.resume_path: @@ -130,10 +157,19 @@ def save_fn(policy): if not args.watch: # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + 
args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) pprint.pprint(result) # Let's watch its performance! diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 9a0179899..9e0ca0d82 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -1,22 +1,23 @@ #!/usr/bin/env python3 +import argparse +import datetime import os -import gym -import torch import pprint -import datetime -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import TD3Policy -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.env import SubprocVectorEnv -from tianshou.utils.net.common import Net from tianshou.exploration import GaussianNoise +from tianshou.policy import TD3Policy from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer def get_args(): @@ -45,11 +46,15 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) return parser.parse_args() @@ -63,17 +68,18 @@ def test_td3(args=get_args()): args.noise_clip = args.noise_clip * args.max_action print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), - np.max(env.action_space.high)) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) # train_envs = gym.make(args.task) if args.training_num > 1: train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) else: train_envs = gym.make(args.task) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -82,27 +88,44 @@ def test_td3(args=get_args()): # model net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor( - net_a, args.action_shape, max_action=args.max_action, - device=args.device).to(args.device) + net_a, args.action_shape, max_action=args.max_action, device=args.device + ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net_c1 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) - net_c2 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) + net_c1 = Net( + args.state_shape, + 
args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic1 = Critic(net_c1, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = TD3Policy( - actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - tau=args.tau, gamma=args.gamma, + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + tau=args.tau, + gamma=args.gamma, exploration_noise=GaussianNoise(sigma=args.exploration_noise), - policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq, - noise_clip=args.noise_clip, estimation_step=args.n_step, - action_space=env.action_space) + policy_noise=args.policy_noise, + update_actor_freq=args.update_actor_freq, + noise_clip=args.noise_clip, + estimation_step=args.n_step, + action_space=env.action_space + ) # load a previous policy if args.resume_path: @@ -131,10 +154,19 @@ def save_fn(policy): if not args.watch: # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) pprint.pprint(result) # Let's watch its performance! 
diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index b00f2e3d6..aef324fd5 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -1,24 +1,25 @@ #!/usr/bin/env python3 +import argparse +import datetime import os -import gym -import torch import pprint -import datetime -import argparse + +import gym import numpy as np +import torch from torch import nn +from torch.distributions import Independent, Normal from torch.optim.lr_scheduler import LambdaLR from torch.utils.tensorboard import SummaryWriter -from torch.distributions import Independent, Normal +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv from tianshou.policy import TRPOPolicy +from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import onpolicy_trainer from tianshou.utils.net.continuous import ActorProb, Critic -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer def get_args(): @@ -26,8 +27,9 @@ def get_args(): parser.add_argument('--task', type=str, default='HalfCheetah-v3') parser.add_argument('--seed', type=int, default=0) parser.add_argument('--buffer-size', type=int, default=4096) - parser.add_argument('--hidden-sizes', type=int, nargs='*', - default=[64, 64]) # baselines [32, 32] + parser.add_argument( + '--hidden-sizes', type=int, nargs='*', default=[64, 64] + ) # baselines [32, 32] parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--epoch', type=int, default=100) @@ -52,11 +54,15 @@ def get_args(): parser.add_argument('--backtrack-coeff', type=float, default=0.8) parser.add_argument('--max-backtracks', type=int, default=10) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) return parser.parse_args() @@ -67,16 +73,18 @@ def test_trpo(args=get_args()): args.max_action = env.action_space.high[0] print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), - np.max(env.action_space.high)) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) # train_envs = gym.make(args.task) train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)], - norm_obs=True) + [lambda: gym.make(args.task) for _ in range(args.training_num)], norm_obs=True + ) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( [lambda: gym.make(args.task) for _ in range(args.test_num)], - norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False) + norm_obs=True, + obs_rms=train_envs.obs_rms, + update_obs_rms=False + ) # seed np.random.seed(args.seed) @@ -84,12 +92,25 @@ def test_trpo(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) - actor = ActorProb(net_a, 
args.action_shape, max_action=args.max_action, - unbounded=True, device=args.device).to(args.device) - net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) + net_a = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) + actor = ActorProb( + net_a, + args.action_shape, + max_action=args.max_action, + unbounded=True, + device=args.device + ).to(args.device) + net_c = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) critic = Critic(net_c, device=args.device).to(args.device) torch.nn.init.constant_(actor.sigma_param, -0.5) for m in list(actor.modules()) + list(critic.modules()): @@ -110,24 +131,34 @@ def test_trpo(args=get_args()): if args.lr_decay: # decay learning rate to 0 linearly max_update_num = np.ceil( - args.step_per_epoch / args.step_per_collect) * args.epoch + args.step_per_epoch / args.step_per_collect + ) * args.epoch lr_scheduler = LambdaLR( - optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) + optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num + ) def dist(*logits): return Independent(Normal(*logits), 1) - policy = TRPOPolicy(actor, critic, optim, dist, discount_factor=args.gamma, - gae_lambda=args.gae_lambda, - reward_normalization=args.rew_norm, action_scaling=True, - action_bound_method=args.bound_action_method, - lr_scheduler=lr_scheduler, action_space=env.action_space, - advantage_normalization=args.norm_adv, - optim_critic_iters=args.optim_critic_iters, - max_kl=args.max_kl, - backtrack_coeff=args.backtrack_coeff, - max_backtracks=args.max_backtracks) + policy = TRPOPolicy( + actor, + critic, + optim, + dist, + discount_factor=args.gamma, + gae_lambda=args.gae_lambda, + reward_normalization=args.rew_norm, + action_scaling=True, + action_bound_method=args.bound_action_method, + lr_scheduler=lr_scheduler, + action_space=env.action_space, + advantage_normalization=args.norm_adv, + optim_critic_iters=args.optim_critic_iters, + max_kl=args.max_kl, + backtrack_coeff=args.backtrack_coeff, + max_backtracks=args.max_backtracks + ) # load a previous policy if args.resume_path: @@ -155,10 +186,19 @@ def save_fn(policy): if not args.watch: # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, args.step_per_epoch, - args.repeat_per_collect, args.test_num, args.batch_size, - step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger, - test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + step_per_collect=args.step_per_collect, + save_fn=save_fn, + logger=logger, + test_in_train=False + ) pprint.pprint(result) # Let's watch its performance! 
diff --git a/examples/mujoco/plotter.py b/examples/mujoco/plotter.py index 4ecd530c7..04f15aa7b 100755 --- a/examples/mujoco/plotter.py +++ b/examples/mujoco/plotter.py @@ -1,13 +1,13 @@ #!/usr/bin/env python3 -import re -import os import argparse -import numpy as np +import os +import re + import matplotlib.pyplot as plt import matplotlib.ticker as mticker - -from tools import find_all_files, group_files, csv2numpy +import numpy as np +from tools import csv2numpy, find_all_files, group_files def smooth(y, radius, mode='two_sided', valid_only=False): @@ -38,28 +38,49 @@ def smooth(y, radius, mode='two_sided', valid_only=False): return out -COLORS = ([ - # deepmind style - '#0072B2', - '#009E73', - '#D55E00', - '#CC79A7', - # '#F0E442', - '#d73027', # RED - # built-in color - 'blue', 'red', 'pink', 'cyan', 'magenta', 'yellow', 'black', 'purple', - 'brown', 'orange', 'teal', 'lightblue', 'lime', 'lavender', 'turquoise', - 'darkgreen', 'tan', 'salmon', 'gold', 'darkred', 'darkblue', 'green', - # personal color - '#313695', # DARK BLUE - '#74add1', # LIGHT BLUE - '#f46d43', # ORANGE - '#4daf4a', # GREEN - '#984ea3', # PURPLE - '#f781bf', # PINK - '#ffc832', # YELLOW - '#000000', # BLACK -]) +COLORS = ( + [ + # deepmind style + '#0072B2', + '#009E73', + '#D55E00', + '#CC79A7', + # '#F0E442', + '#d73027', # RED + # built-in color + 'blue', + 'red', + 'pink', + 'cyan', + 'magenta', + 'yellow', + 'black', + 'purple', + 'brown', + 'orange', + 'teal', + 'lightblue', + 'lime', + 'lavender', + 'turquoise', + 'darkgreen', + 'tan', + 'salmon', + 'gold', + 'darkred', + 'darkblue', + 'green', + # personal color + '#313695', # DARK BLUE + '#74add1', # LIGHT BLUE + '#f46d43', # ORANGE + '#4daf4a', # GREEN + '#984ea3', # PURPLE + '#f781bf', # PINK + '#ffc832', # YELLOW + '#000000', # BLACK + ] +) def plot_ax( @@ -96,8 +117,11 @@ def legend_fn(x): y_shaded = smooth(csv_dict[ykey + ':shaded'], radius=smooth_radius) ax.fill_between(x, y - y_shaded, y + y_shaded, color=color, alpha=.2) - ax.legend(legneds, loc=2 if legend_outside else None, - bbox_to_anchor=(1, 1) if legend_outside else None) + ax.legend( + legneds, + loc=2 if legend_outside else None, + bbox_to_anchor=(1, 1) if legend_outside else None + ) ax.xaxis.set_major_formatter(mticker.EngFormatter()) if xlim is not None: ax.set_xlim(xmin=0, xmax=xlim) @@ -127,8 +151,14 @@ def plot_figure( res = group_files(file_lists, group_pattern) row_n = int(np.ceil(len(res) / 3)) col_n = min(len(res), 3) - fig, axes = plt.subplots(row_n, col_n, sharex=sharex, sharey=sharey, figsize=( - fig_length * col_n, fig_width * row_n), squeeze=False) + fig, axes = plt.subplots( + row_n, + col_n, + sharex=sharex, + sharey=sharey, + figsize=(fig_length * col_n, fig_width * row_n), + squeeze=False + ) axes = axes.flatten() for i, (k, v) in enumerate(res.items()): plot_ax(axes[i], v, title=k, **kwargs) @@ -138,53 +168,95 @@ def plot_figure( if __name__ == "__main__": parser = argparse.ArgumentParser(description='plotter') - parser.add_argument('--fig-length', type=int, default=6, - help='matplotlib figure length (default: 6)') - parser.add_argument('--fig-width', type=int, default=6, - help='matplotlib figure width (default: 6)') - parser.add_argument('--style', default='seaborn', - help='matplotlib figure style (default: seaborn)') - parser.add_argument('--title', default=None, - help='matplotlib figure title (default: None)') - parser.add_argument('--xkey', default='env_step', - help='x-axis key in csv file (default: env_step)') - parser.add_argument('--ykey', default='rew', - 
help='y-axis key in csv file (default: rew)') - parser.add_argument('--smooth', type=int, default=0, - help='smooth radius of y axis (default: 0)') - parser.add_argument('--xlabel', default='Timesteps', - help='matplotlib figure xlabel') - parser.add_argument('--ylabel', default='Episode Reward', - help='matplotlib figure ylabel') - parser.add_argument( - '--shaded-std', action='store_true', - help='shaded region corresponding to standard deviation of the group') - parser.add_argument('--sharex', action='store_true', - help='whether to share x axis within multiple sub-figures') - parser.add_argument('--sharey', action='store_true', - help='whether to share y axis within multiple sub-figures') - parser.add_argument('--legend-outside', action='store_true', - help='place the legend outside of the figure') - parser.add_argument('--xlim', type=int, default=None, - help='x-axis limitation (default: None)') + parser.add_argument( + '--fig-length', + type=int, + default=6, + help='matplotlib figure length (default: 6)' + ) + parser.add_argument( + '--fig-width', + type=int, + default=6, + help='matplotlib figure width (default: 6)' + ) + parser.add_argument( + '--style', + default='seaborn', + help='matplotlib figure style (default: seaborn)' + ) + parser.add_argument( + '--title', default=None, help='matplotlib figure title (default: None)' + ) + parser.add_argument( + '--xkey', + default='env_step', + help='x-axis key in csv file (default: env_step)' + ) + parser.add_argument( + '--ykey', default='rew', help='y-axis key in csv file (default: rew)' + ) + parser.add_argument( + '--smooth', type=int, default=0, help='smooth radius of y axis (default: 0)' + ) + parser.add_argument( + '--xlabel', default='Timesteps', help='matplotlib figure xlabel' + ) + parser.add_argument( + '--ylabel', default='Episode Reward', help='matplotlib figure ylabel' + ) + parser.add_argument( + '--shaded-std', + action='store_true', + help='shaded region corresponding to standard deviation of the group' + ) + parser.add_argument( + '--sharex', + action='store_true', + help='whether to share x axis within multiple sub-figures' + ) + parser.add_argument( + '--sharey', + action='store_true', + help='whether to share y axis within multiple sub-figures' + ) + parser.add_argument( + '--legend-outside', + action='store_true', + help='place the legend outside of the figure' + ) + parser.add_argument( + '--xlim', type=int, default=None, help='x-axis limitation (default: None)' + ) parser.add_argument('--root-dir', default='./', help='root dir (default: ./)') parser.add_argument( - '--file-pattern', type=str, default=r".*/test_rew_\d+seeds.csv$", + '--file-pattern', + type=str, + default=r".*/test_rew_\d+seeds.csv$", help='regular expression to determine whether or not to include target csv ' - 'file, default to including all test_rew_{num}seeds.csv file under rootdir') + 'file, default to including all test_rew_{num}seeds.csv file under rootdir' + ) parser.add_argument( - '--group-pattern', type=str, default=r"(/|^)\w*?\-v(\d|$)", + '--group-pattern', + type=str, + default=r"(/|^)\w*?\-v(\d|$)", help='regular expression to group files in sub-figure, default to grouping ' - 'according to env_name dir, "" means no grouping') + 'according to env_name dir, "" means no grouping' + ) parser.add_argument( - '--legend-pattern', type=str, default=r".*", + '--legend-pattern', + type=str, + default=r".*", help='regular expression to extract legend from csv file path, default to ' - 'using file path as legend name.') + 'using file path as 
legend name.' + ) parser.add_argument('--show', action='store_true', help='show figure') - parser.add_argument('--output-path', type=str, - help='figure save path', default="./figure.png") - parser.add_argument('--dpi', type=int, default=200, - help='figure dpi (default: 200)') + parser.add_argument( + '--output-path', type=str, help='figure save path', default="./figure.png" + ) + parser.add_argument( + '--dpi', type=int, default=200, help='figure dpi (default: 200)' + ) args = parser.parse_args() file_lists = find_all_files(args.root_dir, re.compile(args.file_pattern)) file_lists = [os.path.relpath(f, args.root_dir) for f in file_lists] @@ -207,9 +279,9 @@ def plot_figure( sharey=args.sharey, smooth_radius=args.smooth, shaded_std=args.shaded_std, - legend_outside=args.legend_outside) + legend_outside=args.legend_outside + ) if args.output_path: - plt.savefig(args.output_path, - dpi=args.dpi, bbox_inches='tight') + plt.savefig(args.output_path, dpi=args.dpi, bbox_inches='tight') if args.show: plt.show() diff --git a/examples/mujoco/tools.py b/examples/mujoco/tools.py index 9e49206e0..bda7c6db1 100755 --- a/examples/mujoco/tools.py +++ b/examples/mujoco/tools.py @@ -1,12 +1,13 @@ #!/usr/bin/env python3 +import argparse +import csv import os import re -import csv -import tqdm -import argparse -import numpy as np from collections import defaultdict + +import numpy as np +import tqdm from tensorboard.backend.event_processing import event_accumulator @@ -66,11 +67,13 @@ def convert_tfevents_to_csv(root_dir, refresh=False): initial_time = ea._first_event_timestamp content = [["env_step", "rew", "time"]] for test_rew in ea.scalars.Items("test/rew"): - content.append([ - round(test_rew.step, 4), - round(test_rew.value, 4), - round(test_rew.wall_time - initial_time, 4), - ]) + content.append( + [ + round(test_rew.step, 4), + round(test_rew.value, 4), + round(test_rew.wall_time - initial_time, 4), + ] + ) csv.writer(open(output_file, 'w')).writerows(content) result[output_file] = content return result @@ -85,8 +88,10 @@ def merge_csv(csv_files, root_dir, remove_zero=False): v.pop(1) sorted_keys = sorted(csv_files.keys()) sorted_values = [csv_files[k][1:] for k in sorted_keys] - content = [["env_step", "rew", "rew:shaded"] + list(map( - lambda f: "rew:" + os.path.relpath(f, root_dir), sorted_keys))] + content = [ + ["env_step", "rew", "rew:shaded"] + + list(map(lambda f: "rew:" + os.path.relpath(f, root_dir), sorted_keys)) + ] for rows in zip(*sorted_values): array = np.array(rows) assert len(set(array[:, 0])) == 1, (set(array[:, 0]), array[:, 0]) @@ -101,11 +106,15 @@ def merge_csv(csv_files, root_dir, remove_zero=False): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - '--refresh', action="store_true", - help="Re-generate all csv files instead of using existing one.") + '--refresh', + action="store_true", + help="Re-generate all csv files instead of using existing one." + ) parser.add_argument( - '--remove-zero', action="store_true", - help="Remove the data point of env_step == 0.") + '--remove-zero', + action="store_true", + help="Remove the data point of env_step == 0." 
+ ) parser.add_argument('--root-dir', type=str) args = parser.parse_args() diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py index 017ab7750..0e8d995d1 100644 --- a/examples/vizdoom/env.py +++ b/examples/vizdoom/env.py @@ -1,4 +1,5 @@ import os + import cv2 import gym import numpy as np @@ -33,9 +34,7 @@ def battle_button_comb(): class Env(gym.Env): - def __init__( - self, cfg_path, frameskip=4, res=(4, 40, 60), save_lmp=False - ): + def __init__(self, cfg_path, frameskip=4, res=(4, 40, 60), save_lmp=False): super().__init__() self.save_lmp = save_lmp self.health_setting = "battle" in cfg_path @@ -75,8 +74,7 @@ def reset(self): self.obs_buffer = np.zeros(self.res, dtype=np.uint8) self.get_obs() self.health = self.game.get_game_variable(vzd.GameVariable.HEALTH) - self.killcount = self.game.get_game_variable( - vzd.GameVariable.KILLCOUNT) + self.killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT) self.ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2) return self.obs_buffer diff --git a/examples/vizdoom/maps/spectator.py b/examples/vizdoom/maps/spectator.py index d4d7e8c7c..2180ed7c7 100644 --- a/examples/vizdoom/maps/spectator.py +++ b/examples/vizdoom/maps/spectator.py @@ -10,11 +10,12 @@ from __future__ import print_function +from argparse import ArgumentParser from time import sleep + import vizdoom as vzd -from argparse import ArgumentParser -# import cv2 +# import cv2 if __name__ == "__main__": parser = ArgumentParser("ViZDoom example showing how to use SPECTATOR mode.") diff --git a/examples/vizdoom/replay.py b/examples/vizdoom/replay.py index a1e556fce..30cdb31bd 100755 --- a/examples/vizdoom/replay.py +++ b/examples/vizdoom/replay.py @@ -1,6 +1,7 @@ # import cv2 import sys import time + import tqdm import vizdoom as vzd diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 1123151c8..bb3a1f207 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -1,18 +1,18 @@ +import argparse import os -import torch import pprint -import argparse + import numpy as np +import torch +from env import Env +from network import C51 from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import C51Policy -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv +from tianshou.policy import C51Policy from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer - -from env import Env -from network import C51 +from tianshou.utils import TensorboardLogger def get_args(): @@ -40,15 +40,23 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument('--frames-stack', type=int, default=4) parser.add_argument('--skip-num', type=int, default=4) parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument('--watch', default=False, action='store_true', - help='watch the play of pre-trained policy only') - parser.add_argument('--save-lmp', default=False, action='store_true', - help='save lmp file for replay whole episode') + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) + parser.add_argument( + '--save-lmp', + default=False, + action='store_true', + help='save lmp file for replay whole episode' + ) parser.add_argument('--save-buffer-name', type=str, default=None) return parser.parse_args() @@ -64,26 +72,36 @@ def test_c51(args=get_args()): print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) # make environments - train_envs = SubprocVectorEnv([ - lambda: Env(args.cfg_path, args.frames_stack, args.res) - for _ in range(args.training_num)]) - test_envs = SubprocVectorEnv([ - lambda: Env(args.cfg_path, args.frames_stack, - args.res, args.save_lmp) - for _ in range(min(os.cpu_count() - 1, args.test_num))]) + train_envs = SubprocVectorEnv( + [ + lambda: Env(args.cfg_path, args.frames_stack, args.res) + for _ in range(args.training_num) + ] + ) + test_envs = SubprocVectorEnv( + [ + lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp) + for _ in range(min(os.cpu_count() - 1, args.test_num)) + ] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # define model - net = C51(*args.state_shape, args.action_shape, - args.num_atoms, args.device) + net = C51(*args.state_shape, args.action_shape, args.num_atoms, args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy policy = C51Policy( - net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max, - args.n_step, target_update_freq=args.target_update_freq + net, + optim, + args.gamma, + args.num_atoms, + args.v_min, + args.v_max, + args.n_step, + target_update_freq=args.target_update_freq ).to(args.device) # load a previous policy if args.resume_path: @@ -92,8 +110,12 @@ def test_c51(args=get_args()): # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True, - save_only_last_obs=True, stack_num=args.frames_stack) + args.buffer_size, + buffer_num=len(train_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) # collector train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) @@ -136,11 +158,13 @@ def watch(): if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(test_envs), - ignore_obs_next=True, save_only_last_obs=True, - stack_num=args.frames_stack) - collector = Collector(policy, test_envs, buffer, - exploration_noise=True) + args.buffer_size, + buffer_num=len(test_envs), + ignore_obs_next=True, + save_only_last_obs=True, + stack_num=args.frames_stack + ) + collector = Collector(policy, 
test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size @@ -148,8 +172,9 @@ def watch(): else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, - render=args.render) + result = test_collector.collect( + n_episode=args.test_num, render=args.render + ) rew = result["rews"].mean() lens = result["lens"].mean() * args.skip_num print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') @@ -163,11 +188,22 @@ def watch(): train_collector.collect(n_step=args.batch_size * args.training_num) # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) pprint.pprint(result) watch() diff --git a/setup.py b/setup.py index 9488d27a0..03a2d2186 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,8 @@ # -*- coding: utf-8 -*- import os -from setuptools import setup, find_packages + +from setuptools import find_packages, setup def get_version() -> str: diff --git a/test/base/env.py b/test/base/env.py index 1151f5b76..e14a7246a 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -1,19 +1,27 @@ -import gym -import time import random -import numpy as np -import networkx as nx +import time from copy import deepcopy -from gym.spaces import Discrete, MultiDiscrete, Box, Dict, Tuple + +import gym +import networkx as nx +import numpy as np +from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple class MyTestEnv(gym.Env): """This is a "going right" task. The task is to go right ``size`` steps. 
""" - - def __init__(self, size, sleep=0, dict_state=False, recurse_state=False, - ma_rew=0, multidiscrete_action=False, random_sleep=False, - array_state=False): + def __init__( + self, + size, + sleep=0, + dict_state=False, + recurse_state=False, + ma_rew=0, + multidiscrete_action=False, + random_sleep=False, + array_state=False + ): assert dict_state + recurse_state + array_state <= 1, \ "dict_state / recurse_state / array_state can be only one true" self.size = size @@ -28,17 +36,32 @@ def __init__(self, size, sleep=0, dict_state=False, recurse_state=False, self.steps = 0 if dict_state: self.observation_space = Dict( - {"index": Box(shape=(1, ), low=0, high=size - 1), - "rand": Box(shape=(1,), low=0, high=1, dtype=np.float64)}) + { + "index": Box(shape=(1, ), low=0, high=size - 1), + "rand": Box(shape=(1, ), low=0, high=1, dtype=np.float64) + } + ) elif recurse_state: self.observation_space = Dict( - {"index": Box(shape=(1, ), low=0, high=size - 1), - "dict": Dict({ - "tuple": Tuple((Discrete(2), Box(shape=(2,), - low=0, high=1, dtype=np.float64))), - "rand": Box(shape=(1, 2), low=0, high=1, - dtype=np.float64)}) - }) + { + "index": + Box(shape=(1, ), low=0, high=size - 1), + "dict": + Dict( + { + "tuple": + Tuple( + ( + Discrete(2), + Box(shape=(2, ), low=0, high=1, dtype=np.float64) + ) + ), + "rand": + Box(shape=(1, 2), low=0, high=1, dtype=np.float64) + } + ) + } + ) elif array_state: self.observation_space = Box(shape=(4, 84, 84), low=0, high=255) else: @@ -70,13 +93,18 @@ def _get_reward(self): def _get_state(self): """Generate state(observation) of MyTestEnv""" if self.dict_state: - return {'index': np.array([self.index], dtype=np.float32), - 'rand': self.rng.rand(1)} + return { + 'index': np.array([self.index], dtype=np.float32), + 'rand': self.rng.rand(1) + } elif self.recurse_state: - return {'index': np.array([self.index], dtype=np.float32), - 'dict': {"tuple": (np.array([1], - dtype=int), self.rng.rand(2)), - "rand": self.rng.rand(1, 2)}} + return { + 'index': np.array([self.index], dtype=np.float32), + 'dict': { + "tuple": (np.array([1], dtype=int), self.rng.rand(2)), + "rand": self.rng.rand(1, 2) + } + } elif self.array_state: img = np.zeros([4, 84, 84], int) img[3, np.arange(84), np.arange(84)] = self.index diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 15357f16c..53ee8ffa3 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -1,13 +1,14 @@ -import sys import copy -import torch import pickle -import pytest -import numpy as np -import networkx as nx +import sys from itertools import starmap -from tianshou.data import Batch, to_torch, to_numpy +import networkx as nx +import numpy as np +import pytest +import torch + +from tianshou.data import Batch, to_numpy, to_torch def test_batch(): @@ -99,10 +100,13 @@ def test_batch(): assert batch_item.a.c == batch_dict['c'] assert isinstance(batch_item.a.d, torch.Tensor) assert batch_item.a.d == batch_dict['d'] - batch2 = Batch(a=[{ - 'b': np.float64(1.0), - 'c': np.zeros(1), - 'd': Batch(e=np.array(3.0))}]) + batch2 = Batch( + a=[{ + 'b': np.float64(1.0), + 'c': np.zeros(1), + 'd': Batch(e=np.array(3.0)) + }] + ) assert len(batch2) == 1 assert Batch().shape == [] assert Batch(a=1).shape == [] @@ -141,9 +145,12 @@ def test_batch(): assert batch2_sum.a.d.f.is_empty() with pytest.raises(TypeError): batch2 += [1] - batch3 = Batch(a={ - 'c': np.zeros(1), - 'd': Batch(e=np.array([0.0]), f=np.array([3.0]))}) + batch3 = Batch( + a={ + 'c': np.zeros(1), + 'd': Batch(e=np.array([0.0]), 
f=np.array([3.0])) + } + ) batch3.a.d[0] = {'e': 4.0} assert batch3.a.d.e[0] == 4.0 batch3.a.d[0] = Batch(f=5.0) @@ -202,7 +209,7 @@ def test_batch_over_batch(): assert np.allclose(batch3.c, [6, 7, 8, 6, 7, 8]) assert np.allclose(batch3.b.a, [3, 4, 5, 3, 4, 5]) assert np.allclose(batch3.b.b, [4, 5, 6, 4, 5, 6]) - batch4 = Batch(({'a': {'b': np.array([1.0])}},)) + batch4 = Batch(({'a': {'b': np.array([1.0])}}, )) assert batch4.a.b.ndim == 2 assert batch4.a.b[0, 0] == 1.0 # advanced slicing @@ -239,14 +246,14 @@ def test_batch_cat_and_stack(): a = Batch(a=Batch(a=np.random.randn(3, 4))) assert np.allclose( np.concatenate([a.a.a, a.a.a]), - Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a) + Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a + ) # test cat with lens infer a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4)) b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4)) ans = Batch.cat([a, b, a]) - assert np.allclose(ans.a.a, - np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a])) + assert np.allclose(ans.a.a, np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a])) assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b])) assert ans.a.t.is_empty() @@ -258,51 +265,61 @@ def test_batch_cat_and_stack(): b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5))) b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) test = Batch.cat([b1, b2]) - ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]), - b=torch.cat([torch.zeros(3, 3), b2.b]), - common=Batch(c=np.concatenate([b1.common.c, b2.common.c]))) + ans = Batch( + a=np.concatenate([b1.a, np.zeros((4, 4))]), + b=torch.cat([torch.zeros(3, 3), b2.b]), + common=Batch(c=np.concatenate([b1.common.c, b2.common.c])) + ) assert np.allclose(test.a, ans.a) assert torch.allclose(test.b, ans.b) assert np.allclose(test.common.c, ans.common.c) # test cat with reserved keys (values are Batch()) b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5))) - b2 = Batch(a=Batch(), - b=torch.rand(4, 3), - common=Batch(c=np.random.rand(4, 5))) + b2 = Batch(a=Batch(), b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) test = Batch.cat([b1, b2]) - ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]), - b=torch.cat([torch.zeros(3, 3), b2.b]), - common=Batch(c=np.concatenate([b1.common.c, b2.common.c]))) + ans = Batch( + a=np.concatenate([b1.a, np.zeros((4, 4))]), + b=torch.cat([torch.zeros(3, 3), b2.b]), + common=Batch(c=np.concatenate([b1.common.c, b2.common.c])) + ) assert np.allclose(test.a, ans.a) assert torch.allclose(test.b, ans.b) assert np.allclose(test.common.c, ans.common.c) # test cat with all reserved keys (values are Batch()) b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5))) - b2 = Batch(a=Batch(), - b=torch.rand(4, 3), - common=Batch(c=np.random.rand(4, 5))) + b2 = Batch(a=Batch(), b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) test = Batch.cat([b1, b2]) - ans = Batch(a=Batch(), - b=torch.cat([torch.zeros(3, 3), b2.b]), - common=Batch(c=np.concatenate([b1.common.c, b2.common.c]))) + ans = Batch( + a=Batch(), + b=torch.cat([torch.zeros(3, 3), b2.b]), + common=Batch(c=np.concatenate([b1.common.c, b2.common.c])) + ) assert ans.a.is_empty() assert torch.allclose(test.b, ans.b) assert np.allclose(test.common.c, ans.common.c) # test stack with compatible keys - b3 = Batch(a=np.zeros((3, 4)), - b=torch.ones((2, 5)), - c=Batch(d=[[1], [2]])) - b4 = Batch(a=np.ones((3, 4)), - b=torch.ones((2, 5)), - c=Batch(d=[[0], [3]])) + b3 = Batch(a=np.zeros((3, 4)), 
b=torch.ones((2, 5)), c=Batch(d=[[1], [2]])) + b4 = Batch(a=np.ones((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[0], [3]])) b34_stack = Batch.stack((b3, b4), axis=1) assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1)) assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d)))) - b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}}, - {'a': True, 'b': {'c': 3.0}}]) + b5_dict = np.array( + [{ + 'a': False, + 'b': { + 'c': 2.0, + 'd': 1.0 + } + }, { + 'a': True, + 'b': { + 'c': 3.0 + } + }] + ) b5 = Batch(b5_dict) assert b5.a[0] == np.array(False) and b5.a[1] == np.array(True) assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0)) @@ -335,15 +352,16 @@ def test_batch_cat_and_stack(): test = Batch.stack([b1, b2], axis=-1) assert test.a.is_empty() assert test.b.is_empty() - assert np.allclose(test.common.c, - np.stack([b1.common.c, b2.common.c], axis=-1)) + assert np.allclose(test.common.c, np.stack([b1.common.c, b2.common.c], axis=-1)) b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5))) b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5))) test = Batch.stack([b1, b2]) - ans = Batch(a=np.stack([b1.a, np.zeros((4, 4))]), - b=torch.stack([torch.zeros(4, 6), b2.b]), - common=Batch(c=np.stack([b1.common.c, b2.common.c]))) + ans = Batch( + a=np.stack([b1.a, np.zeros((4, 4))]), + b=torch.stack([torch.zeros(4, 6), b2.b]), + common=Batch(c=np.stack([b1.common.c, b2.common.c])) + ) assert np.allclose(test.a, ans.a) assert torch.allclose(test.b, ans.b) assert np.allclose(test.common.c, ans.common.c) @@ -369,8 +387,8 @@ def test_batch_over_batch_to_torch(): batch = Batch( a=np.float64(1.0), b=Batch( - c=np.ones((1,), dtype=np.float32), - d=torch.ones((1,), dtype=torch.float64) + c=np.ones((1, ), dtype=np.float32), + d=torch.ones((1, ), dtype=torch.float64) ) ) batch.b.__dict__['e'] = 1 # bypass the check @@ -397,8 +415,8 @@ def test_utils_to_torch_numpy(): batch = Batch( a=np.float64(1.0), b=Batch( - c=np.ones((1,), dtype=np.float32), - d=torch.ones((1,), dtype=torch.float64) + c=np.ones((1, ), dtype=np.float32), + d=torch.ones((1, ), dtype=torch.float64) ) ) a_torch_float = to_torch(batch.a, dtype=torch.float32) @@ -464,8 +482,7 @@ def test_utils_to_torch_numpy(): def test_batch_pickle(): - batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), - np=np.zeros([3, 4])) + batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4])) batch_pk = pickle.loads(pickle.dumps(batch)) assert batch.obs.a == batch_pk.obs.a assert torch.all(batch.obs.c == batch_pk.obs.c) @@ -473,7 +490,7 @@ def test_batch_pickle(): def test_batch_from_to_numpy_without_copy(): - batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,)))) + batch = Batch(a=np.ones((1, )), b=Batch(c=np.ones((1, )))) a_mem_addr_orig = batch.a.__array_interface__['data'][0] c_mem_addr_orig = batch.b.c.__array_interface__['data'][0] batch.to_torch() @@ -517,19 +534,35 @@ def test_batch_copy(): def test_batch_empty(): - b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}}, - {'a': True, 'b': {'c': 3.0}}]) + b5_dict = np.array( + [{ + 'a': False, + 'b': { + 'c': 2.0, + 'd': 1.0 + } + }, { + 'a': True, + 'b': { + 'c': 3.0 + } + }] + ) b5 = Batch(b5_dict) b5[1] = Batch.empty(b5[0]) assert np.allclose(b5.a, [False, False]) assert np.allclose(b5.b.c, [2, 0]) assert np.allclose(b5.b.d, [1, 0]) - data = Batch(a=[False, True], - b={'c': np.array([2., 'st'], dtype=object), - 'd': [1, None], - 'e': [2., float('nan')]}, - c=np.array([1, 3, 4], dtype=int), - 
t=torch.tensor([4, 5, 6, 7.])) + data = Batch( + a=[False, True], + b={ + 'c': np.array([2., 'st'], dtype=object), + 'd': [1, None], + 'e': [2., float('nan')] + }, + c=np.array([1, 3, 4], dtype=int), + t=torch.tensor([4, 5, 6, 7.]) + ) data[-1] = Batch.empty(data[1]) assert np.allclose(data.c, [1, 3, 0]) assert np.allclose(data.a, [False, False]) @@ -550,9 +583,9 @@ def test_batch_empty(): def test_batch_standard_compatibility(): - batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), - b=Batch(), - c=np.array([5.0, 6.0])) + batch = Batch( + a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=Batch(), c=np.array([5.0, 6.0]) + ) batch_mean = np.mean(batch) assert isinstance(batch_mean, Batch) assert sorted(batch_mean.keys()) == ['a', 'b', 'c'] diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 39e5badd5..2c54cbbab 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -1,18 +1,23 @@ import os -import h5py -import torch import pickle -import pytest import tempfile -import numpy as np from timeit import timeit -from tianshou.data.utils.converter import to_hdf5 -from tianshou.data import Batch, SegmentTree, ReplayBuffer -from tianshou.data import PrioritizedReplayBuffer -from tianshou.data import VectorReplayBuffer, CachedReplayBuffer -from tianshou.data import PrioritizedVectorReplayBuffer +import h5py +import numpy as np +import pytest +import torch +from tianshou.data import ( + Batch, + CachedReplayBuffer, + PrioritizedReplayBuffer, + PrioritizedVectorReplayBuffer, + ReplayBuffer, + SegmentTree, + VectorReplayBuffer, +) +from tianshou.data.utils.converter import to_hdf5 if __name__ == '__main__': from env import MyTestEnv @@ -29,8 +34,9 @@ def test_replaybuffer(size=10, bufsize=20): action_list = [1] * 5 + [0] * 10 + [1] * 10 for i, a in enumerate(action_list): obs_next, rew, done, info = env.step(a) - buf.add(Batch(obs=obs, act=[a], rew=rew, - done=done, obs_next=obs_next, info=info)) + buf.add( + Batch(obs=obs, act=[a], rew=rew, done=done, obs_next=obs_next, info=info) + ) obs = obs_next assert len(buf) == min(bufsize, i + 1) assert buf.act.dtype == int @@ -43,8 +49,20 @@ def test_replaybuffer(size=10, bufsize=20): # neg bsz should return empty index assert b.sample_indices(-1).tolist() == [] ptr, ep_rew, ep_len, ep_idx = b.add( - Batch(obs=1, act=1, rew=1, done=1, obs_next='str', - info={'a': 3, 'b': {'c': 5.0}})) + Batch( + obs=1, + act=1, + rew=1, + done=1, + obs_next='str', + info={ + 'a': 3, + 'b': { + 'c': 5.0 + } + } + ) + ) assert b.obs[0] == 1 assert b.done[0] assert b.obs_next[0] == 'str' @@ -54,13 +72,24 @@ def test_replaybuffer(size=10, bufsize=20): assert np.all(b.info.a[1:] == 0) assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == float assert np.all(b.info.b.c[1:] == 0.0) - assert ptr.shape == (1,) and ptr[0] == 0 - assert ep_rew.shape == (1,) and ep_rew[0] == 1 - assert ep_len.shape == (1,) and ep_len[0] == 1 - assert ep_idx.shape == (1,) and ep_idx[0] == 0 + assert ptr.shape == (1, ) and ptr[0] == 0 + assert ep_rew.shape == (1, ) and ep_rew[0] == 1 + assert ep_len.shape == (1, ) and ep_len[0] == 1 + assert ep_idx.shape == (1, ) and ep_idx[0] == 0 # test extra keys pop up, the buffer should handle it dynamically - batch = Batch(obs=2, act=2, rew=2, done=0, obs_next="str2", - info={"a": 4, "d": {"e": -np.inf}}) + batch = Batch( + obs=2, + act=2, + rew=2, + done=0, + obs_next="str2", + info={ + "a": 4, + "d": { + "e": -np.inf + } + } + ) b.add(batch) info_keys = ["a", "b", "d"] assert set(b.info.keys()) == set(info_keys) @@ -71,10 +100,10 @@ def 
test_replaybuffer(size=10, bufsize=20): batch.info.e = np.zeros([1, 4]) batch = Batch.stack([batch]) ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0]) - assert ptr.shape == (1,) and ptr[0] == 2 - assert ep_rew.shape == (1,) and ep_rew[0] == 4 - assert ep_len.shape == (1,) and ep_len[0] == 2 - assert ep_idx.shape == (1,) and ep_idx[0] == 1 + assert ptr.shape == (1, ) and ptr[0] == 2 + assert ep_rew.shape == (1, ) and ep_rew[0] == 4 + assert ep_len.shape == (1, ) and ep_len[0] == 2 + assert ep_idx.shape == (1, ) and ep_idx[0] == 1 assert set(b.info.keys()) == set(info_keys + ["e"]) assert b.info.e.shape == (b.maxsize, 1, 4) with pytest.raises(IndexError): @@ -92,14 +121,22 @@ def test_ignore_obs_next(size=10): # Issue 82 buf = ReplayBuffer(size, ignore_obs_next=True) for i in range(size): - buf.add(Batch(obs={'mask1': np.array([i, 1, 1, 0, 0]), - 'mask2': np.array([i + 4, 0, 1, 0, 0]), - 'mask': i}, - act={'act_id': i, - 'position_id': i + 3}, - rew=i, - done=i % 3 == 0, - info={'if': i})) + buf.add( + Batch( + obs={ + 'mask1': np.array([i, 1, 1, 0, 0]), + 'mask2': np.array([i + 4, 0, 1, 0, 0]), + 'mask': i + }, + act={ + 'act_id': i, + 'position_id': i + 3 + }, + rew=i, + done=i % 3 == 0, + info={'if': i} + ) + ) indices = np.arange(len(buf)) orig = np.arange(len(buf)) data = buf[indices] @@ -113,15 +150,25 @@ def test_ignore_obs_next(size=10): data = buf[indices] data2 = buf[indices] assert np.allclose(data.obs_next.mask, data2.obs_next.mask) - assert np.allclose(data.obs_next.mask, np.array([ - [0, 0, 0, 0], [1, 1, 1, 2], [1, 1, 2, 3], [1, 1, 2, 3], - [4, 4, 4, 5], [4, 4, 5, 6], [4, 4, 5, 6], - [7, 7, 7, 8], [7, 7, 8, 9], [7, 7, 8, 9]])) + assert np.allclose( + data.obs_next.mask, + np.array( + [ + [0, 0, 0, 0], [1, 1, 1, 2], [1, 1, 2, 3], [1, 1, 2, 3], [4, 4, 4, 5], + [4, 4, 5, 6], [4, 4, 5, 6], [7, 7, 7, 8], [7, 7, 8, 9], [7, 7, 8, 9] + ] + ) + ) assert np.allclose(data.info['if'], data2.info['if']) - assert np.allclose(data.info['if'], np.array([ - [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], - [4, 4, 4, 4], [4, 4, 4, 5], [4, 4, 5, 6], - [7, 7, 7, 7], [7, 7, 7, 8], [7, 7, 8, 9]])) + assert np.allclose( + data.info['if'], + np.array( + [ + [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], [4, 4, 4, 4], + [4, 4, 4, 5], [4, 4, 5, 6], [7, 7, 7, 7], [7, 7, 7, 8], [7, 7, 8, 9] + ] + ) + ) assert data.obs_next @@ -135,16 +182,26 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): obs_next, rew, done, info = env.step(1) buf.add(Batch(obs=obs, act=1, rew=rew, done=done, info=info)) buf2.add(Batch(obs=obs, act=1, rew=rew, done=done, info=info)) - buf3.add(Batch(obs=[obs, obs, obs], act=1, rew=rew, - done=done, obs_next=[obs, obs], info=info)) + buf3.add( + Batch( + obs=[obs, obs, obs], + act=1, + rew=rew, + done=done, + obs_next=[obs, obs], + info=info + ) + ) obs = obs_next if done: obs = env.reset(1) indices = np.arange(len(buf)) - assert np.allclose(buf.get(indices, 'obs')[..., 0], [ - [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], - [1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3], - [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]]) + assert np.allclose( + buf.get(indices, 'obs')[..., 0], [ + [1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4], [1, 1, 1, 1], [1, 1, 1, 2], + [1, 1, 2, 3], [1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1] + ] + ) assert np.allclose(buf.get(indices, 'obs'), buf3.get(indices, 'obs')) assert np.allclose(buf.get(indices, 'obs'), buf3.get(indices, 'obs_next')) _, indices = buf2.sample(0) @@ -165,8 +222,15 @@ def test_priortized_replaybuffer(size=32, bufsize=15): 
action_list = [1] * 5 + [0] * 10 + [1] * 10 for i, a in enumerate(action_list): obs_next, rew, done, info = env.step(a) - batch = Batch(obs=obs, act=a, rew=rew, done=done, obs_next=obs_next, - info=info, policy=np.random.randn() - 0.5) + batch = Batch( + obs=obs, + act=a, + rew=rew, + done=done, + obs_next=obs_next, + info=info, + policy=np.random.randn() - 0.5 + ) batch_stack = Batch.stack([batch, batch, batch]) buf.add(Batch.stack([batch]), buffer_ids=[0]) buf2.add(batch_stack, buffer_ids=[0, 1, 2]) @@ -179,12 +243,12 @@ def test_priortized_replaybuffer(size=32, bufsize=15): assert len(buf) == min(bufsize, i + 1) assert len(buf2) == min(bufsize, 3 * (i + 1)) # check single buffer's data - assert buf.info.key.shape == (buf.maxsize,) + assert buf.info.key.shape == (buf.maxsize, ) assert buf.rew.dtype == float assert buf.done.dtype == bool data, indices = buf.sample(len(buf) // 2) buf.update_weight(indices, -data.weight / 2) - assert np.allclose(buf.weight[indices], np.abs(-data.weight / 2) ** buf._alpha) + assert np.allclose(buf.weight[indices], np.abs(-data.weight / 2)**buf._alpha) # check multi buffer's data assert np.allclose(buf2[np.arange(buf2.maxsize)].weight, 1) batch, indices = buf2.sample(10) @@ -200,8 +264,15 @@ def test_update(): buf1 = ReplayBuffer(4, stack_num=2) buf2 = ReplayBuffer(4, stack_num=2) for i in range(5): - buf1.add(Batch(obs=np.array([i]), act=float(i), rew=i * i, - done=i % 2 == 0, info={'incident': 'found'})) + buf1.add( + Batch( + obs=np.array([i]), + act=float(i), + rew=i * i, + done=i % 2 == 0, + info={'incident': 'found'} + ) + ) assert len(buf1) > len(buf2) buf2.update(buf1) assert len(buf1) == len(buf2) @@ -245,8 +316,7 @@ def test_segtree(): for i in range(10): left = np.random.randint(actual_len) right = np.random.randint(left + 1, actual_len + 1) - assert np.allclose(realop(naive[left:right]), - tree.reduce(left, right)) + assert np.allclose(realop(naive[left:right]), tree.reduce(left, right)) # large test actual_len = 16384 tree = SegmentTree(actual_len) @@ -260,8 +330,7 @@ def test_segtree(): for i in range(10): left = np.random.randint(actual_len) right = np.random.randint(left + 1, actual_len + 1) - assert np.allclose(realop(naive[left:right]), - tree.reduce(left, right)) + assert np.allclose(realop(naive[left:right]), tree.reduce(left, right)) # test prefix-sum-idx actual_len = 8 @@ -280,8 +349,9 @@ def test_segtree(): assert naive[:index].sum() <= scalar <= naive[:index + 1].sum() tree = SegmentTree(10) tree[np.arange(3)] = np.array([0.1, 0, 0.1]) - assert np.allclose(tree.get_prefix_sum_idx( - np.array([0, .1, .1 + 1e-6, .2 - 1e-6])), [0, 0, 2, 2]) + assert np.allclose( + tree.get_prefix_sum_idx(np.array([0, .1, .1 + 1e-6, .2 - 1e-6])), [0, 0, 2, 2] + ) with pytest.raises(AssertionError): tree.get_prefix_sum_idx(.2) # test large prefix-sum-idx @@ -321,8 +391,15 @@ def test_pickle(): for i in range(4): vbuf.add(Batch(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0)) for i in range(5): - pbuf.add(Batch(obs=Batch(index=np.array([i])), - act=2, rew=rew, done=0, info=np.random.rand())) + pbuf.add( + Batch( + obs=Batch(index=np.array([i])), + act=2, + rew=rew, + done=0, + info=np.random.rand() + ) + ) # save & load _vbuf = pickle.loads(pickle.dumps(vbuf)) _pbuf = pickle.loads(pickle.dumps(pbuf)) @@ -330,8 +407,9 @@ def test_pickle(): assert len(_pbuf) == len(pbuf) and np.allclose(_pbuf.act, pbuf.act) # make sure the meta var is identical assert _vbuf.stack_num == vbuf.stack_num - assert np.allclose(_pbuf.weight[np.arange(len(_pbuf))], - 
pbuf.weight[np.arange(len(pbuf))]) + assert np.allclose( + _pbuf.weight[np.arange(len(_pbuf))], pbuf.weight[np.arange(len(pbuf))] + ) def test_hdf5(): @@ -349,7 +427,13 @@ def test_hdf5(): 'act': i, 'rew': np.array([1, 2]), 'done': i % 3 == 2, - 'info': {"number": {"n": i, "t": info_t}, 'extra': None}, + 'info': { + "number": { + "n": i, + "t": info_t + }, + 'extra': None + }, } buffers["array"].add(Batch(kwargs)) buffers["prioritized"].add(Batch(kwargs)) @@ -377,10 +461,8 @@ def test_hdf5(): assert isinstance(buffers[k].get(0, "info"), Batch) assert isinstance(_buffers[k].get(0, "info"), Batch) for k in ["array"]: - assert np.all( - buffers[k][:].info.number.n == _buffers[k][:].info.number.n) - assert np.all( - buffers[k][:].info.extra == _buffers[k][:].info.extra) + assert np.all(buffers[k][:].info.number.n == _buffers[k][:].info.number.n) + assert np.all(buffers[k][:].info.extra == _buffers[k][:].info.extra) # raise exception when value cannot be pickled data = {"not_supported": lambda x: x * x} @@ -423,15 +505,16 @@ def test_replaybuffermanager(): indices_next = buf.next(indices) assert np.allclose(indices_next, indices), indices_next data = np.array([0, 0, 0, 0]) - buf.add(Batch(obs=data, act=data, rew=data, done=data), - buffer_ids=[0, 1, 2, 3]) - buf.add(Batch(obs=data, act=data, rew=data, done=1 - data), - buffer_ids=[0, 1, 2, 3]) + buf.add(Batch(obs=data, act=data, rew=data, done=data), buffer_ids=[0, 1, 2, 3]) + buf.add( + Batch(obs=data, act=data, rew=data, done=1 - data), buffer_ids=[0, 1, 2, 3] + ) assert len(buf) == 12 - buf.add(Batch(obs=data, act=data, rew=data, done=data), - buffer_ids=[0, 1, 2, 3]) - buf.add(Batch(obs=data, act=data, rew=data, done=[0, 1, 0, 1]), - buffer_ids=[0, 1, 2, 3]) + buf.add(Batch(obs=data, act=data, rew=data, done=data), buffer_ids=[0, 1, 2, 3]) + buf.add( + Batch(obs=data, act=data, rew=data, done=[0, 1, 0, 1]), + buffer_ids=[0, 1, 2, 3] + ) assert len(buf) == 20 indices = buf.sample_indices(120000) assert np.bincount(indices).min() >= 5000 @@ -439,44 +522,135 @@ def test_replaybuffermanager(): indices = buf.sample_indices(0) assert np.allclose(indices, np.arange(len(buf))) # check the actual data stored in buf._meta - assert np.allclose(buf.done, [ - 0, 0, 1, 0, 0, - 0, 0, 1, 0, 1, - 1, 0, 1, 0, 0, - 1, 0, 1, 0, 1, - ]) - assert np.allclose(buf.prev(indices), [ - 0, 0, 1, 3, 3, - 5, 5, 6, 8, 8, - 10, 11, 11, 13, 13, - 15, 16, 16, 18, 18, - ]) - assert np.allclose(buf.next(indices), [ - 1, 2, 2, 4, 4, - 6, 7, 7, 9, 9, - 10, 12, 12, 14, 14, - 15, 17, 17, 19, 19, - ]) + assert np.allclose( + buf.done, [ + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 1, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 1, + 0, + 1, + ] + ) + assert np.allclose( + buf.prev(indices), [ + 0, + 0, + 1, + 3, + 3, + 5, + 5, + 6, + 8, + 8, + 10, + 11, + 11, + 13, + 13, + 15, + 16, + 16, + 18, + 18, + ] + ) + assert np.allclose( + buf.next(indices), [ + 1, + 2, + 2, + 4, + 4, + 6, + 7, + 7, + 9, + 9, + 10, + 12, + 12, + 14, + 14, + 15, + 17, + 17, + 19, + 19, + ] + ) assert np.allclose(buf.unfinished_index(), [4, 14]) ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=[1], act=[1], rew=[1], done=[1]), buffer_ids=[2]) + Batch(obs=[1], act=[1], rew=[1], done=[1]), buffer_ids=[2] + ) assert np.all(ep_len == [3]) and np.all(ep_rew == [1]) assert np.all(ptr == [10]) and np.all(ep_idx == [13]) assert np.allclose(buf.unfinished_index(), [4]) indices = list(sorted(buf.sample_indices(0))) assert np.allclose(indices, np.arange(len(buf))) - assert np.allclose(buf.prev(indices), [ - 0, 0, 1, 3, 3, 
- 5, 5, 6, 8, 8, - 14, 11, 11, 13, 13, - 15, 16, 16, 18, 18, - ]) - assert np.allclose(buf.next(indices), [ - 1, 2, 2, 4, 4, - 6, 7, 7, 9, 9, - 10, 12, 12, 14, 10, - 15, 17, 17, 19, 19, - ]) + assert np.allclose( + buf.prev(indices), [ + 0, + 0, + 1, + 3, + 3, + 5, + 5, + 6, + 8, + 8, + 14, + 11, + 11, + 13, + 13, + 15, + 16, + 16, + 18, + 18, + ] + ) + assert np.allclose( + buf.next(indices), [ + 1, + 2, + 2, + 4, + 4, + 6, + 7, + 7, + 9, + 9, + 10, + 12, + 12, + 14, + 10, + 15, + 17, + 17, + 19, + 19, + ] + ) # corner case: list, int and -1 assert buf.prev(-1) == buf.prev([buf.maxsize - 1])[0] assert buf.next(-1) == buf.next([buf.maxsize - 1])[0] @@ -493,7 +667,8 @@ def test_cachedbuffer(): assert buf.sample_indices(0).tolist() == [] # check the normal function/usage/storage in CachedReplayBuffer ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=[1], act=[1], rew=[1], done=[0]), buffer_ids=[1]) + Batch(obs=[1], act=[1], rew=[1], done=[0]), buffer_ids=[1] + ) obs = np.zeros(buf.maxsize) obs[15] = 1 indices = buf.sample_indices(0) @@ -504,7 +679,8 @@ def test_cachedbuffer(): assert np.all(ep_len == [0]) and np.all(ep_rew == [0.0]) assert np.all(ptr == [15]) and np.all(ep_idx == [15]) ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=[2], act=[2], rew=[2], done=[1]), buffer_ids=[3]) + Batch(obs=[2], act=[2], rew=[2], done=[1]), buffer_ids=[3] + ) obs[[0, 25]] = 2 indices = buf.sample_indices(0) assert np.allclose(indices, [0, 15]) @@ -516,8 +692,8 @@ def test_cachedbuffer(): assert np.allclose(buf.unfinished_index(), [15]) assert np.allclose(buf.sample_indices(0), [0, 15]) ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=[3, 4], act=[3, 4], rew=[3, 4], done=[0, 1]), - buffer_ids=[3, 1]) + Batch(obs=[3, 4], act=[3, 4], rew=[3, 4], done=[0, 1]), buffer_ids=[3, 1] + ) assert np.all(ep_len == [0, 2]) and np.all(ep_rew == [0, 5.0]) assert np.all(ptr == [25, 2]) and np.all(ep_idx == [25, 1]) obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3] @@ -540,16 +716,35 @@ def test_cachedbuffer(): buf.add(Batch(obs=data, act=data, rew=rew, done=[1, 1, 1, 1])) buf.add(Batch(obs=data, act=data, rew=rew, done=[0, 0, 0, 0])) ptr, ep_rew, ep_len, ep_idx = buf.add( - Batch(obs=data, act=data, rew=rew, done=[0, 1, 0, 1])) + Batch(obs=data, act=data, rew=rew, done=[0, 1, 0, 1]) + ) assert np.all(ptr == [1, -1, 11, -1]) and np.all(ep_idx == [0, -1, 10, -1]) assert np.all(ep_len == [0, 2, 0, 2]) assert np.all(ep_rew == [data, data + 2, data, data + 2]) - assert np.allclose(buf.done, [ - 0, 0, 1, 0, 0, - 0, 1, 1, 0, 0, - 0, 0, 0, 0, 0, - 0, 1, 0, 0, 0, - ]) + assert np.allclose( + buf.done, [ + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + ] + ) indices = buf.sample_indices(0) assert np.allclose(indices, [0, 1, 10, 11]) assert np.allclose(buf.prev(indices), [0, 0, 10, 10]) @@ -564,14 +759,16 @@ def test_multibuf_stack(): env = MyTestEnv(size) # test if CachedReplayBuffer can handle stack_num + ignore_obs_next buf4 = CachedReplayBuffer( - ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True), - cached_num, size) + ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True), cached_num, + size + ) # test if CachedReplayBuffer can handle corner case: # buffer + stack_num + ignore_obs_next + sample_avail buf5 = CachedReplayBuffer( - ReplayBuffer(bufsize, stack_num=stack_num, - ignore_obs_next=True, sample_avail=True), - cached_num, size) + ReplayBuffer( + bufsize, stack_num=stack_num, ignore_obs_next=True, sample_avail=True + ), cached_num, size + ) 
obs = env.reset(1) for i in range(18): obs_next, rew, done, info = env.step(1) @@ -581,8 +778,14 @@ def test_multibuf_stack(): done_list = [done] * cached_num obs_next_list = -obs_list info_list = [info] * cached_num - batch = Batch(obs=obs_list, act=act_list, rew=rew_list, - done=done_list, obs_next=obs_next_list, info=info_list) + batch = Batch( + obs=obs_list, + act=act_list, + rew=rew_list, + done=done_list, + obs_next=obs_next_list, + info=info_list + ) buf5.add(batch) buf4.add(batch) assert np.all(buf4.obs == buf5.obs) @@ -591,35 +794,105 @@ def test_multibuf_stack(): if done: obs = env.reset(1) # check the `add` order is correct - assert np.allclose(buf4.obs.reshape(-1), [ - 12, 13, 14, 4, 6, 7, 8, 9, 11, # main_buffer - 1, 2, 3, 4, 0, # cached_buffer[0] - 6, 7, 8, 9, 0, # cached_buffer[1] - 11, 12, 13, 14, 0, # cached_buffer[2] - ]), buf4.obs - assert np.allclose(buf4.done, [ - 0, 0, 1, 1, 0, 0, 0, 1, 0, # main_buffer - 0, 0, 0, 1, 0, # cached_buffer[0] - 0, 0, 0, 1, 0, # cached_buffer[1] - 0, 0, 0, 1, 0, # cached_buffer[2] - ]), buf4.done + assert np.allclose( + buf4.obs.reshape(-1), + [ + 12, + 13, + 14, + 4, + 6, + 7, + 8, + 9, + 11, # main_buffer + 1, + 2, + 3, + 4, + 0, # cached_buffer[0] + 6, + 7, + 8, + 9, + 0, # cached_buffer[1] + 11, + 12, + 13, + 14, + 0, # cached_buffer[2] + ] + ), buf4.obs + assert np.allclose( + buf4.done, + [ + 0, + 0, + 1, + 1, + 0, + 0, + 0, + 1, + 0, # main_buffer + 0, + 0, + 0, + 1, + 0, # cached_buffer[0] + 0, + 0, + 0, + 1, + 0, # cached_buffer[1] + 0, + 0, + 0, + 1, + 0, # cached_buffer[2] + ] + ), buf4.done assert np.allclose(buf4.unfinished_index(), [10, 15, 20]) indices = sorted(buf4.sample_indices(0)) assert np.allclose(indices, list(range(bufsize)) + [9, 10, 14, 15, 19, 20]) - assert np.allclose(buf4[indices].obs[..., 0], [ - [11, 11, 11, 12], [11, 11, 12, 13], [11, 12, 13, 14], - [4, 4, 4, 4], [6, 6, 6, 6], [6, 6, 6, 7], - [6, 6, 7, 8], [6, 7, 8, 9], [11, 11, 11, 11], - [1, 1, 1, 1], [1, 1, 1, 2], [6, 6, 6, 6], - [6, 6, 6, 7], [11, 11, 11, 11], [11, 11, 11, 12], - ]) - assert np.allclose(buf4[indices].obs_next[..., 0], [ - [11, 11, 12, 13], [11, 12, 13, 14], [11, 12, 13, 14], - [4, 4, 4, 4], [6, 6, 6, 7], [6, 6, 7, 8], - [6, 7, 8, 9], [6, 7, 8, 9], [11, 11, 11, 12], - [1, 1, 1, 2], [1, 1, 1, 2], [6, 6, 6, 7], - [6, 6, 6, 7], [11, 11, 11, 12], [11, 11, 11, 12], - ]) + assert np.allclose( + buf4[indices].obs[..., 0], [ + [11, 11, 11, 12], + [11, 11, 12, 13], + [11, 12, 13, 14], + [4, 4, 4, 4], + [6, 6, 6, 6], + [6, 6, 6, 7], + [6, 6, 7, 8], + [6, 7, 8, 9], + [11, 11, 11, 11], + [1, 1, 1, 1], + [1, 1, 1, 2], + [6, 6, 6, 6], + [6, 6, 6, 7], + [11, 11, 11, 11], + [11, 11, 11, 12], + ] + ) + assert np.allclose( + buf4[indices].obs_next[..., 0], [ + [11, 11, 12, 13], + [11, 12, 13, 14], + [11, 12, 13, 14], + [4, 4, 4, 4], + [6, 6, 6, 7], + [6, 6, 7, 8], + [6, 7, 8, 9], + [6, 7, 8, 9], + [11, 11, 11, 12], + [1, 1, 1, 2], + [1, 1, 1, 2], + [6, 6, 6, 7], + [6, 6, 6, 7], + [11, 11, 11, 12], + [11, 11, 11, 12], + ] + ) indices = buf5.sample_indices(0) assert np.allclose(sorted(indices), [2, 7]) assert np.all(np.isin(buf5.sample_indices(100), indices)) @@ -632,12 +905,24 @@ def test_multibuf_stack(): batch, _ = buf5.sample(0) # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next buf6 = CachedReplayBuffer( - ReplayBuffer(bufsize, stack_num=stack_num, - save_only_last_obs=True, ignore_obs_next=True), - cached_num, size) + ReplayBuffer( + bufsize, + stack_num=stack_num, + save_only_last_obs=True, + ignore_obs_next=True + ), 
cached_num, size + ) obs = np.random.rand(size, 4, 84, 84) - buf6.add(Batch(obs=[obs[2], obs[0]], act=[1, 1], rew=[0, 0], done=[0, 1], - obs_next=[obs[3], obs[1]]), buffer_ids=[1, 2]) + buf6.add( + Batch( + obs=[obs[2], obs[0]], + act=[1, 1], + rew=[0, 0], + done=[0, 1], + obs_next=[obs[3], obs[1]] + ), + buffer_ids=[1, 2] + ) assert buf6.obs.shape == (buf6.maxsize, 84, 84) assert np.allclose(buf6.obs[0], obs[0, -1]) assert np.allclose(buf6.obs[14], obs[2, -1]) @@ -660,12 +945,20 @@ def test_multibuf_hdf5(): 'act': i, 'rew': np.array([1, 2]), 'done': i % 3 == 2, - 'info': {"number": {"n": i, "t": info_t}, 'extra': None}, + 'info': { + "number": { + "n": i, + "t": info_t + }, + 'extra': None + }, } - buffers["vector"].add(Batch.stack([kwargs, kwargs, kwargs]), - buffer_ids=[0, 1, 2]) - buffers["cached"].add(Batch.stack([kwargs, kwargs, kwargs]), - buffer_ids=[0, 1, 2]) + buffers["vector"].add( + Batch.stack([kwargs, kwargs, kwargs]), buffer_ids=[0, 1, 2] + ) + buffers["cached"].add( + Batch.stack([kwargs, kwargs, kwargs]), buffer_ids=[0, 1, 2] + ) # save paths = {} @@ -696,7 +989,12 @@ def test_multibuf_hdf5(): 'act': 5, 'rew': np.array([2, 1]), 'done': False, - 'info': {"number": {"n": i}, 'Timelimit.truncate': True}, + 'info': { + "number": { + "n": i + }, + 'Timelimit.truncate': True + }, } buffers[k].add(Batch.stack([kwargs, kwargs, kwargs, kwargs])) act = np.zeros(buffers[k].maxsize) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 79d1430b8..da54d4b3a 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -1,17 +1,19 @@ -import tqdm -import pytest import numpy as np +import pytest +import tqdm from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import BasePolicy -from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.data import Batch, Collector, AsyncCollector from tianshou.data import ( - ReplayBuffer, + AsyncCollector, + Batch, + CachedReplayBuffer, + Collector, PrioritizedReplayBuffer, + ReplayBuffer, VectorReplayBuffer, - CachedReplayBuffer, ) +from tianshou.env import DummyVectorEnv, SubprocVectorEnv +from tianshou.policy import BasePolicy if __name__ == '__main__': from env import MyTestEnv, NXEnv @@ -56,8 +58,7 @@ def preprocess_fn(self, **kwargs): info = kwargs['info'] info.rew = kwargs['rew'] if 'key' in info.keys(): - self.writer.add_scalar( - 'key', np.mean(info.key), global_step=self.cnt) + self.writer.add_scalar('key', np.mean(info.key), global_step=self.cnt) self.cnt += 1 return Batch(info=info) else: @@ -91,13 +92,12 @@ def test_collector(): c0.collect(n_episode=3) assert len(c0.buffer) == 8 assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0]) - assert np.allclose(c0.buffer[:].obs_next[..., 0], - [1, 2, 1, 2, 1, 2, 1, 2]) + assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) c0.collect(n_step=3, random=True) c1 = Collector( - policy, venv, - VectorReplayBuffer(total_size=100, buffer_num=4), - logger.preprocess_fn) + policy, venv, VectorReplayBuffer(total_size=100, buffer_num=4), + logger.preprocess_fn + ) c1.collect(n_step=8) obs = np.zeros(100) obs[[0, 1, 25, 26, 50, 51, 75, 76]] = [0, 1, 0, 1, 0, 1, 0, 1] @@ -108,13 +108,15 @@ def test_collector(): assert len(c1.buffer) == 16 obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4] assert np.allclose(c1.buffer.obs[:, 0], obs) - assert np.allclose(c1.buffer[:].obs_next[..., 0], - [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]) + assert np.allclose( + c1.buffer[:].obs_next[..., 
0], + [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5] + ) c1.collect(n_episode=4, random=True) c2 = Collector( - policy, dum, - VectorReplayBuffer(total_size=100, buffer_num=4), - logger.preprocess_fn) + policy, dum, VectorReplayBuffer(total_size=100, buffer_num=4), + logger.preprocess_fn + ) c2.collect(n_episode=7) obs1 = obs.copy() obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2] @@ -139,10 +141,10 @@ def test_collector(): # test NXEnv for obs_type in ["array", "object"]: - envs = SubprocVectorEnv([ - lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]]) - c3 = Collector(policy, envs, - VectorReplayBuffer(total_size=100, buffer_num=4)) + envs = SubprocVectorEnv( + [lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]] + ) + c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) c3.collect(n_step=6) assert c3.buffer.obs.dtype == object @@ -151,23 +153,23 @@ def test_collector_with_async(): env_lens = [2, 3, 4, 5] writer = SummaryWriter('log/async_collector') logger = Logger(writer) - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) - for i in env_lens] + env_fns = [ + lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens + ] venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) policy = MyPolicy() bufsize = 60 c1 = AsyncCollector( - policy, venv, - VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4), - logger.preprocess_fn) + policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4), + logger.preprocess_fn + ) ptr = [0, 0, 0, 0] for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): result = c1.collect(n_episode=n_episode) assert result["n/ep"] >= n_episode # check buffer data, obs and obs_next, env_id - for i, count in enumerate( - np.bincount(result["lens"], minlength=6)[2:]): + for i, count in enumerate(np.bincount(result["lens"], minlength=6)[2:]): env_len = i + 2 total = env_len * count indices = np.arange(ptr[i], ptr[i] + total) % bufsize @@ -176,8 +178,7 @@ def test_collector_with_async(): buf = c1.buffer.buffers[i] assert np.all(buf.info.env_id[indices] == i) assert np.all(buf.obs[indices].reshape(count, env_len) == seq) - assert np.all(buf.obs_next[indices].reshape( - count, env_len) == seq + 1) + assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) # test async n_step, for now the buffer should be full of data for n_step in tqdm.trange(1, 15, desc="test async n_step"): result = c1.collect(n_step=n_step) @@ -196,21 +197,21 @@ def test_collector_with_async(): def test_collector_with_dict_state(): env = MyTestEnv(size=5, sleep=0, dict_state=True) policy = MyPolicy(dict_state=True) - c0 = Collector(policy, env, ReplayBuffer(size=100), - Logger.single_preprocess_fn) + c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn) c0.collect(n_step=3) c0.collect(n_episode=2) assert len(c0.buffer) == 10 - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True) - for i in [2, 3, 4, 5]] + env_fns = [ + lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True) for i in [2, 3, 4, 5] + ] envs = DummyVectorEnv(env_fns) envs.seed(666) obs = envs.reset() assert not np.isclose(obs[0]['rand'], obs[1]['rand']) c1 = Collector( - policy, envs, - VectorReplayBuffer(total_size=100, buffer_num=4), - Logger.single_preprocess_fn) + policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4), + Logger.single_preprocess_fn + ) c1.collect(n_step=12) result = c1.collect(n_episode=8) assert result['n/ep'] == 8 @@ -221,25 +222,104 @@ def 
test_collector_with_dict_state(): c0.buffer.update(c1.buffer) assert len(c0.buffer) in [42, 43] if len(c0.buffer) == 42: - assert np.all(c0.buffer[:].obs.index[..., 0] == [ - 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, - 0, 1, 0, 1, 0, 1, 0, 1, - 0, 1, 2, 0, 1, 2, - 0, 1, 2, 3, 0, 1, 2, 3, - 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, - ]), c0.buffer[:].obs.index[..., 0] + assert np.all( + c0.buffer[:].obs.index[..., 0] == [ + 0, + 1, + 2, + 3, + 4, + 0, + 1, + 2, + 3, + 4, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 2, + 0, + 1, + 2, + 0, + 1, + 2, + 3, + 0, + 1, + 2, + 3, + 0, + 1, + 2, + 3, + 4, + 0, + 1, + 2, + 3, + 4, + ] + ), c0.buffer[:].obs.index[..., 0] else: - assert np.all(c0.buffer[:].obs.index[..., 0] == [ - 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, - 0, 1, 0, 1, 0, 1, - 0, 1, 2, 0, 1, 2, 0, 1, 2, - 0, 1, 2, 3, 0, 1, 2, 3, - 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, - ]), c0.buffer[:].obs.index[..., 0] + assert np.all( + c0.buffer[:].obs.index[..., 0] == [ + 0, + 1, + 2, + 3, + 4, + 0, + 1, + 2, + 3, + 4, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 2, + 0, + 1, + 2, + 0, + 1, + 2, + 0, + 1, + 2, + 3, + 0, + 1, + 2, + 3, + 0, + 1, + 2, + 3, + 4, + 0, + 1, + 2, + 3, + 4, + ] + ), c0.buffer[:].obs.index[..., 0] c2 = Collector( - policy, envs, - VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), - Logger.single_preprocess_fn) + policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), + Logger.single_preprocess_fn + ) c2.collect(n_episode=10) batch, _ = c2.buffer.sample(10) @@ -247,20 +327,18 @@ def test_collector_with_dict_state(): def test_collector_with_ma(): env = MyTestEnv(size=5, sleep=0, ma_rew=4) policy = MyPolicy() - c0 = Collector(policy, env, ReplayBuffer(size=100), - Logger.single_preprocess_fn) + c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn) # n_step=3 will collect a full episode r = c0.collect(n_step=3)['rews'] assert len(r) == 0 r = c0.collect(n_episode=2)['rews'] assert r.shape == (2, 4) and np.all(r == 1) - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) - for i in [2, 3, 4, 5]] + env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]] envs = DummyVectorEnv(env_fns) c1 = Collector( - policy, envs, - VectorReplayBuffer(total_size=100, buffer_num=4), - Logger.single_preprocess_fn) + policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4), + Logger.single_preprocess_fn + ) r = c1.collect(n_step=12)['rews'] assert r.shape == (2, 4) and np.all(r == 1), r r = c1.collect(n_episode=8)['rews'] @@ -271,26 +349,101 @@ def test_collector_with_ma(): assert len(c0.buffer) in [42, 43] if len(c0.buffer) == 42: rew = [ - 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, - 0, 1, 0, 1, 0, 1, 0, 1, - 0, 0, 1, 0, 0, 1, - 0, 0, 0, 1, 0, 0, 0, 1, - 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, ] else: rew = [ - 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, - 0, 1, 0, 1, 0, 1, - 0, 0, 1, 0, 0, 1, 0, 0, 1, - 0, 0, 0, 1, 0, 0, 0, 1, - 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, ] assert np.all(c0.buffer[:].rew == [[x] * 4 for x in rew]) assert np.all(c0.buffer[:].done == rew) c2 = Collector( - policy, envs, - VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), - 
Logger.single_preprocess_fn) + policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), + Logger.single_preprocess_fn + ) r = c2.collect(n_episode=10)['rews'] assert r.shape == (10, 4) and np.all(r == 1) batch, _ = c2.buffer.sample(10) @@ -326,22 +479,23 @@ def test_collector_with_atari_setting(): c2 = Collector( policy, env, - ReplayBuffer(size=100, ignore_obs_next=True, save_only_last_obs=True)) + ReplayBuffer(size=100, ignore_obs_next=True, save_only_last_obs=True) + ) c2.collect(n_step=8) assert c2.buffer.obs.shape == (100, 84, 84) obs = np.zeros_like(c2.buffer.obs) obs[np.arange(8)] = reference_obs[[0, 1, 2, 3, 4, 0, 1, 2], -1] assert np.all(c2.buffer.obs == obs) - assert np.allclose(c2.buffer[:].obs_next, - reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1]) + assert np.allclose( + c2.buffer[:].obs_next, reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1] + ) # atari multi buffer - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, array_state=True) - for i in [2, 3, 4, 5]] + env_fns = [ + lambda x=i: MyTestEnv(size=x, sleep=0, array_state=True) for i in [2, 3, 4, 5] + ] envs = DummyVectorEnv(env_fns) - c3 = Collector( - policy, envs, - VectorReplayBuffer(total_size=100, buffer_num=4)) + c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) c3.collect(n_step=12) result = c3.collect(n_episode=9) assert result["n/ep"] == 9 and result["n/st"] == 23 @@ -360,8 +514,14 @@ def test_collector_with_atari_setting(): assert np.all(obs_next == c3.buffer.obs_next) c4 = Collector( policy, envs, - VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4, - ignore_obs_next=True, save_only_last_obs=True)) + VectorReplayBuffer( + total_size=100, + buffer_num=4, + stack_num=4, + ignore_obs_next=True, + save_only_last_obs=True + ) + ) c4.collect(n_step=12) result = c4.collect(n_episode=9) assert result["n/ep"] == 9 and result["n/st"] == 23 @@ -374,12 +534,45 @@ def test_collector_with_atari_setting(): obs[np.arange(75, 85)] = slice_obs[[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]] assert np.all(c4.buffer.obs == obs) obs_next = np.zeros([len(c4.buffer), 4, 84, 84]) - ref_index = np.array([ - 1, 1, 1, 1, 1, 1, 1, 1, - 1, 2, 2, 1, 2, 2, 1, 2, 2, - 1, 2, 3, 3, 1, 2, 3, 3, - 1, 2, 3, 4, 4, 1, 2, 3, 4, 4, - ]) + ref_index = np.array( + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 2, + 2, + 1, + 2, + 2, + 1, + 2, + 2, + 1, + 2, + 3, + 3, + 1, + 2, + 3, + 3, + 1, + 2, + 3, + 4, + 4, + 1, + 2, + 3, + 4, + 4, + ] + ) obs_next[:, -1] = slice_obs[ref_index] ref_index -= 1 ref_index[ref_index < 0] = 0 @@ -392,20 +585,25 @@ def test_collector_with_atari_setting(): obs_next[:, -4] = slice_obs[ref_index] assert np.all(obs_next == c4.buffer[:].obs_next) - buf = ReplayBuffer(100, stack_num=4, ignore_obs_next=True, - save_only_last_obs=True) + buf = ReplayBuffer(100, stack_num=4, ignore_obs_next=True, save_only_last_obs=True) c5 = Collector(policy, envs, CachedReplayBuffer(buf, 4, 10)) result_ = c5.collect(n_step=12) assert len(buf) == 5 and len(c5.buffer) == 12 result = c5.collect(n_episode=9) assert result["n/ep"] == 9 and result["n/st"] == 23 assert len(buf) == 35 - assert np.all(buf.obs[:len(buf)] == slice_obs[[ - 0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 3, 4, - 0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4]]) - assert np.all(buf[:].obs_next[:, -1] == slice_obs[[ - 1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 3, 4, 4, - 1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 2, 1, 2, 3, 4, 4]]) + assert np.all( + buf.obs[:len(buf)] == slice_obs[[ + 0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 3, 4, 0, 1, 0, 1, 
2, 0, 1, 0, 1, + 2, 3, 0, 1, 2, 0, 1, 2, 3, 4 + ]] + ) + assert np.all( + buf[:].obs_next[:, -1] == slice_obs[[ + 1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 3, 4, 4, 1, 1, 1, 2, 2, 1, 1, 1, 2, + 3, 3, 1, 2, 2, 1, 2, 3, 4, 4 + ]] + ) assert len(buf) == len(c5.buffer) # test buffer=None diff --git a/test/base/test_env.py b/test/base/test_env.py index cc1dc84c7..4d5cc6a00 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -1,10 +1,11 @@ import sys import time + import numpy as np from gym.spaces.discrete import Discrete + from tianshou.data import Batch -from tianshou.env import DummyVectorEnv, SubprocVectorEnv, \ - ShmemVectorEnv, RayVectorEnv +from tianshou.env import DummyVectorEnv, RayVectorEnv, ShmemVectorEnv, SubprocVectorEnv if __name__ == '__main__': from env import MyTestEnv, NXEnv @@ -24,17 +25,14 @@ def recurse_comp(a, b): try: if isinstance(a, np.ndarray): if a.dtype == object: - return np.array( - [recurse_comp(m, n) for m, n in zip(a, b)]).all() + return np.array([recurse_comp(m, n) for m, n in zip(a, b)]).all() else: return np.allclose(a, b) elif isinstance(a, (list, tuple)): - return np.array( - [recurse_comp(m, n) for m, n in zip(a, b)]).all() + return np.array([recurse_comp(m, n) for m, n in zip(a, b)]).all() elif isinstance(a, dict): - return np.array( - [recurse_comp(a[k], b[k]) for k in a.keys()]).all() - except(Exception): + return np.array([recurse_comp(a[k], b[k]) for k in a.keys()]).all() + except (Exception): return False @@ -75,7 +73,7 @@ def test_async_env(size=10000, num=8, sleep=0.1): # truncate env_ids with the first terms # typically len(env_ids) == len(A) == len(action), except for the # last batch when actions are not enough - env_ids = env_ids[: len(action)] + env_ids = env_ids[:len(action)] spent_time = time.time() - spent_time Batch.cat(o) v.close() @@ -85,10 +83,12 @@ def test_async_env(size=10000, num=8, sleep=0.1): def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7): - env_fns = [lambda: MyTestEnv(size=size, sleep=sleep * 2), - lambda: MyTestEnv(size=size, sleep=sleep * 3), - lambda: MyTestEnv(size=size, sleep=sleep * 5), - lambda: MyTestEnv(size=size, sleep=sleep * 7)] + env_fns = [ + lambda: MyTestEnv(size=size, sleep=sleep * 2), + lambda: MyTestEnv(size=size, sleep=sleep * 3), + lambda: MyTestEnv(size=size, sleep=sleep * 5), + lambda: MyTestEnv(size=size, sleep=sleep * 7) + ] test_cls = [SubprocVectorEnv, ShmemVectorEnv] if has_ray(): test_cls += [RayVectorEnv] @@ -113,8 +113,10 @@ def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7): t = time.time() - t ids = Batch(info).env_id print(ids, t) - if not (len(ids) == len(res) and np.allclose(sorted(ids), res) - and (t < timeout) == (len(res) == num - 1)): + if not ( + len(ids) == len(res) and np.allclose(sorted(ids), res) and + (t < timeout) == (len(res) == num - 1) + ): pass_check = 0 break total_pass += pass_check @@ -172,8 +174,9 @@ def test_vecenv(size=10, num=8, sleep=0.001): def test_env_obs(): for obs_type in ["array", "object"]: - envs = SubprocVectorEnv([ - lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]]) + envs = SubprocVectorEnv( + [lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]] + ) obs = envs.reset() assert obs.dtype == object obs = envs.step([1, 1, 1, 1])[0] diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index b670d65e9..00b536cd3 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -1,14 +1,15 @@ # see issue #322 for detail -import gym import copy -import numpy as np from collections 
import Counter -from torch.utils.data import Dataset, DataLoader, DistributedSampler -from tianshou.policy import BasePolicy -from tianshou.data import Collector, Batch +import gym +import numpy as np +from torch.utils.data import DataLoader, Dataset, DistributedSampler + +from tianshou.data import Batch, Collector from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv +from tianshou.policy import BasePolicy class DummyDataset(Dataset): @@ -32,7 +33,8 @@ def __init__(self, dataset, num_replicas, rank): self.loader = DataLoader( dataset, sampler=DistributedSampler(dataset, num_replicas, rank), - batch_size=None) + batch_size=None + ) self.iterator = None def reset(self): @@ -79,6 +81,7 @@ def _get_default_obs(self): def _get_default_info(self): return copy.deepcopy(self._default_info) + # END def reset(self, id=None): @@ -179,30 +182,32 @@ def validate(self): def test_finite_dummy_vector_env(): dataset = DummyDataset(100) - envs = FiniteSubprocVectorEnv([ - _finite_env_factory(dataset, 5, i) for i in range(5)]) + envs = FiniteSubprocVectorEnv( + [_finite_env_factory(dataset, 5, i) for i in range(5)] + ) policy = AnyPolicy() test_collector = Collector(policy, envs, exploration_noise=True) for _ in range(3): envs.tracker = MetricTracker() try: - test_collector.collect(n_step=10 ** 18) + test_collector.collect(n_step=10**18) except StopIteration: envs.tracker.validate() def test_finite_subproc_vector_env(): dataset = DummyDataset(100) - envs = FiniteSubprocVectorEnv([ - _finite_env_factory(dataset, 5, i) for i in range(5)]) + envs = FiniteSubprocVectorEnv( + [_finite_env_factory(dataset, 5, i) for i in range(5)] + ) policy = AnyPolicy() test_collector = Collector(policy, envs, exploration_noise=True) for _ in range(3): envs.tracker = MetricTracker() try: - test_collector.collect(n_step=10 ** 18) + test_collector.collect(n_step=10**18) except StopIteration: envs.tracker.validate() diff --git a/test/base/test_returns.py b/test/base/test_returns.py index a104eb674..3adcdaf5c 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -1,9 +1,10 @@ -import torch -import numpy as np from timeit import timeit -from tianshou.policy import BasePolicy +import numpy as np +import torch + from tianshou.data import Batch, ReplayBuffer, to_numpy +from tianshou.policy import BasePolicy def compute_episodic_return_base(batch, gamma): @@ -24,8 +25,12 @@ def test_episodic_returns(size=2560): batch = Batch( done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]), rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]), - info=Batch({'TimeLimit.truncated': - np.array([False, False, False, False, False, True, False, False])}) + info=Batch( + { + 'TimeLimit.truncated': + np.array([False, False, False, False, False, True, False, False]) + } + ) ) for b in batch: b.obs = b.act = 1 @@ -65,28 +70,40 @@ def test_episodic_returns(size=2560): buf.add(b) v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3]) returns, _ = fn(batch, buf, buf.sample_indices(0), v, gamma=0.99, gae_lambda=0.95) - ground_truth = np.array([ - 454.8344, 376.1143, 291.298, 200., - 464.5610, 383.1085, 295.387, 201., - 474.2876, 390.1027, 299.476, 202.]) + ground_truth = np.array( + [ + 454.8344, 376.1143, 291.298, 200., 464.5610, 383.1085, 295.387, 201., + 474.2876, 390.1027, 299.476, 202. 
+ ] + ) assert np.allclose(returns, ground_truth) buf.reset() batch = Batch( done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]), rew=np.array([101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]), - info=Batch({'TimeLimit.truncated': - np.array([False, False, False, True, False, False, - False, True, False, False, False, False])}) + info=Batch( + { + 'TimeLimit.truncated': + np.array( + [ + False, False, False, True, False, False, False, True, False, + False, False, False + ] + ) + } + ) ) for b in batch: b.obs = b.act = 1 buf.add(b) v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3]) returns, _ = fn(batch, buf, buf.sample_indices(0), v, gamma=0.99, gae_lambda=0.95) - ground_truth = np.array([ - 454.0109, 375.2386, 290.3669, 199.01, - 462.9138, 381.3571, 293.5248, 199.02, - 474.2876, 390.1027, 299.476, 202.]) + ground_truth = np.array( + [ + 454.0109, 375.2386, 290.3669, 199.01, 462.9138, 381.3571, 293.5248, 199.02, + 474.2876, 390.1027, 299.476, 202. + ] + ) assert np.allclose(returns, ground_truth) if __name__ == '__main__': @@ -129,16 +146,17 @@ def compute_nstep_return_base(nstep, gamma, buffer, indices): real_step_n = nstep for n in range(nstep): idx = (indices[i] + n) % buf_len - r += buffer.rew[idx] * gamma ** n + r += buffer.rew[idx] * gamma**n if buffer.done[idx]: - if not (hasattr(buffer, 'info') and - buffer.info['TimeLimit.truncated'][idx]): + if not ( + hasattr(buffer, 'info') and buffer.info['TimeLimit.truncated'][idx] + ): flag = True real_step_n = n + 1 break if not flag: idx = (indices[i] + real_step_n - 1) % buf_len - r += to_numpy(target_q_fn(buffer, idx)) * gamma ** real_step_n + r += to_numpy(target_q_fn(buffer, idx)) * gamma**real_step_n returns[i] = r return returns @@ -152,89 +170,128 @@ def test_nstep_returns(size=10000): # rew: [11, 12, 3, 4, 5, 6, 7, 8, 9, 10] # done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0] # test nstep = 1 - returns = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn, gamma=.1, n_step=1 - ).pop('returns').reshape(-1)) + returns = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn, gamma=.1, n_step=1 + ).pop('returns').reshape(-1) + ) assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12]) r_ = compute_nstep_return_base(1, .1, buf, indices) assert np.allclose(returns, r_), (r_, returns) - returns_multidim = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=1 - ).pop('returns')) + returns_multidim = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=1 + ).pop('returns') + ) assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 2 - returns = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn, gamma=.1, n_step=2 - ).pop('returns').reshape(-1)) + returns = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn, gamma=.1, n_step=2 + ).pop('returns').reshape(-1) + ) assert np.allclose(returns, [3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12]) r_ = compute_nstep_return_base(2, .1, buf, indices) assert np.allclose(returns, r_) - returns_multidim = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=2 - ).pop('returns')) + returns_multidim = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=2 + ).pop('returns') + ) assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test 
nstep = 10 - returns = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn, gamma=.1, n_step=10 - ).pop('returns').reshape(-1)) + returns = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn, gamma=.1, n_step=10 + ).pop('returns').reshape(-1) + ) assert np.allclose(returns, [3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12]) r_ = compute_nstep_return_base(10, .1, buf, indices) assert np.allclose(returns, r_) - returns_multidim = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=10 - ).pop('returns')) + returns_multidim = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=10 + ).pop('returns') + ) assert np.allclose(returns_multidim, returns[:, np.newaxis]) def test_nstep_returns_with_timelimit(size=10000): buf = ReplayBuffer(10) for i in range(12): - buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3, - info={"TimeLimit.truncated": i == 3})) + buf.add( + Batch( + obs=0, + act=0, + rew=i + 1, + done=i % 4 == 3, + info={"TimeLimit.truncated": i == 3} + ) + ) batch, indices = buf.sample(0) assert np.allclose(indices, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1]) # rew: [11, 12, 3, 4, 5, 6, 7, 8, 9, 10] # done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0] # test nstep = 1 - returns = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn, gamma=.1, n_step=1 - ).pop('returns').reshape(-1)) + returns = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn, gamma=.1, n_step=1 + ).pop('returns').reshape(-1) + ) assert np.allclose(returns, [2.6, 3.6, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12]) r_ = compute_nstep_return_base(1, .1, buf, indices) assert np.allclose(returns, r_), (r_, returns) - returns_multidim = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=1 - ).pop('returns')) + returns_multidim = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=1 + ).pop('returns') + ) assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 2 - returns = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn, gamma=.1, n_step=2 - ).pop('returns').reshape(-1)) + returns = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn, gamma=.1, n_step=2 + ).pop('returns').reshape(-1) + ) assert np.allclose(returns, [3.36, 3.6, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12]) r_ = compute_nstep_return_base(2, .1, buf, indices) assert np.allclose(returns, r_) - returns_multidim = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=2 - ).pop('returns')) + returns_multidim = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=2 + ).pop('returns') + ) assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 10 - returns = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn, gamma=.1, n_step=10 - ).pop('returns').reshape(-1)) - assert np.allclose(returns, [3.36, 3.6, 5.678, 6.78, - 7.8, 8, 10.122, 11.22, 12.2, 12]) + returns = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn, gamma=.1, n_step=10 + ).pop('returns').reshape(-1) + ) + assert np.allclose( + returns, [3.36, 3.6, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12] + ) r_ = compute_nstep_return_base(10, .1, buf, indices) assert 
np.allclose(returns, r_) - returns_multidim = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=10 - ).pop('returns')) + returns_multidim = to_numpy( + BasePolicy.compute_nstep_return( + batch, buf, indices, target_q_fn_multidim, gamma=.1, n_step=10 + ).pop('returns') + ) assert np.allclose(returns_multidim, returns[:, np.newaxis]) if __name__ == '__main__': buf = ReplayBuffer(size) for i in range(int(size * 1.5)): - buf.add(Batch(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0, - info={"TimeLimit.truncated": i % 33 == 0})) + buf.add( + Batch( + obs=0, + act=0, + rew=i + 1, + done=np.random.randint(3) == 0, + info={"TimeLimit.truncated": i % 33 == 0} + ) + ) batch, indices = buf.sample(256) def vanilla(): @@ -242,7 +299,8 @@ def vanilla(): def optimized(): return BasePolicy.compute_nstep_return( - batch, buf, indices, target_q_fn, gamma=.1, n_step=3) + batch, buf, indices, target_q_fn, gamma=.1, n_step=3 + ) cnt = 3000 print('nstep vanilla', timeit(vanilla, setup=vanilla, number=cnt)) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index d0e1cefb3..38bf5d40e 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -1,9 +1,9 @@ -import torch import numpy as np +import torch -from tianshou.utils.net.common import MLP, Net -from tianshou.utils import MovAvg, RunningMeanStd from tianshou.exploration import GaussianNoise, OUNoise +from tianshou.utils import MovAvg, RunningMeanStd +from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic @@ -20,14 +20,14 @@ def test_moving_average(): stat = MovAvg(10) assert np.allclose(stat.get(), 0) assert np.allclose(stat.mean(), 0) - assert np.allclose(stat.std() ** 2, 0) + assert np.allclose(stat.std()**2, 0) stat.add(torch.tensor([1])) stat.add(np.array([2])) stat.add([3, 4]) stat.add(5.) 
assert np.allclose(stat.get(), 3) assert np.allclose(stat.mean(), 3) - assert np.allclose(stat.std() ** 2, 2) + assert np.allclose(stat.std()**2, 2) def test_rms(): @@ -55,23 +55,36 @@ def test_net(): action_shape = (5, ) data = torch.rand([bsz, *state_shape]) expect_output_shape = [bsz, *action_shape] - net = Net(state_shape, action_shape, hidden_sizes=[128, 128], - norm_layer=torch.nn.LayerNorm, activation=None) + net = Net( + state_shape, + action_shape, + hidden_sizes=[128, 128], + norm_layer=torch.nn.LayerNorm, + activation=None + ) assert list(net(data)[0].shape) == expect_output_shape assert str(net).count("LayerNorm") == 2 assert str(net).count("ReLU") == 0 Q_param = V_param = {"hidden_sizes": [128, 128]} - net = Net(state_shape, action_shape, hidden_sizes=[128, 128], - dueling_param=(Q_param, V_param)) + net = Net( + state_shape, + action_shape, + hidden_sizes=[128, 128], + dueling_param=(Q_param, V_param) + ) assert list(net(data)[0].shape) == expect_output_shape # concat - net = Net(state_shape, action_shape, hidden_sizes=[128], - concat=True) + net = Net(state_shape, action_shape, hidden_sizes=[128], concat=True) data = torch.rand([bsz, np.prod(state_shape) + np.prod(action_shape)]) expect_output_shape = [bsz, 128] assert list(net(data)[0].shape) == expect_output_shape - net = Net(state_shape, action_shape, hidden_sizes=[128], - concat=True, dueling_param=(Q_param, V_param)) + net = Net( + state_shape, + action_shape, + hidden_sizes=[128], + concat=True, + dueling_param=(Q_param, V_param) + ) assert list(net(data)[0].shape) == expect_output_shape # recurrent actor/critic data = torch.rand([bsz, *state_shape]).flatten(1) diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 3dbeb9091..e88c0869c 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.exploration import GaussianNoise from tianshou.policy import DDPGPolicy +from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import offpolicy_trainer -from tianshou.exploration import GaussianNoise -from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import Actor, Critic @@ -39,8 +40,8 @@ def get_args(): parser.add_argument('--rew-norm', action="store_true", default=False) parser.add_argument('--n-step', type=int, default=3) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) args = parser.parse_known_args()[0] return args @@ -56,36 +57,51 @@ def test_ddpg(args=get_args()): # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) 
train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) - actor = Actor(net, args.action_shape, max_action=args.max_action, - device=args.device).to(args.device) + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = Actor( + net, args.action_shape, max_action=args.max_action, device=args.device + ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, concat=True, device=args.device) + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic = Critic(net, device=args.device).to(args.device) critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) policy = DDPGPolicy( - actor, actor_optim, critic, critic_optim, - tau=args.tau, gamma=args.gamma, + actor, + actor_optim, + critic, + critic_optim, + tau=args.tau, + gamma=args.gamma, exploration_noise=GaussianNoise(sigma=args.exploration_noise), reward_normalization=args.rew_norm, - estimation_step=args.n_step, action_space=env.action_space) + estimation_step=args.n_step, + action_space=env.action_space + ) # collector train_collector = Collector( - policy, train_envs, + policy, + train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), - exploration_noise=True) + exploration_noise=True + ) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'ddpg') @@ -100,10 +116,19 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, - update_per_step=args.update_per_step, stop_fn=stop_fn, - save_fn=save_fn, logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index ad758897f..1a9e82623 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -1,19 +1,20 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch import nn -from torch.utils.tensorboard import SummaryWriter from torch.distributions import Independent, Normal +from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import NPGPolicy +from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import ActorProb, Critic @@ -27,8 +28,9 @@ def get_args(): parser.add_argument('--epoch', type=int, default=5) parser.add_argument('--step-per-epoch', type=int, default=50000) parser.add_argument('--step-per-collect', type=int, default=2048) - parser.add_argument('--repeat-per-collect', type=int, - default=2) # theoretically it should be 1 + parser.add_argument( + 
'--repeat-per-collect', type=int, default=2 + ) # theoretically it should be 1 parser.add_argument('--batch-size', type=int, default=99999) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) parser.add_argument('--training-num', type=int, default=16) @@ -36,8 +38,8 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) # npg special parser.add_argument('--gae-lambda', type=float, default=0.95) parser.add_argument('--rew-norm', type=int, default=1) @@ -58,23 +60,40 @@ def test_npg(args=get_args()): # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) - actor = ActorProb(net, args.action_shape, max_action=args.max_action, - unbounded=True, device=args.device).to(args.device) - critic = Critic(Net( - args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device, - activation=nn.Tanh), device=args.device).to(args.device) + net = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) + actor = ActorProb( + net, + args.action_shape, + max_action=args.max_action, + unbounded=True, + device=args.device + ).to(args.device) + critic = Critic( + Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + activation=nn.Tanh + ), + device=args.device + ).to(args.device) # orthogonal initialization for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): @@ -88,7 +107,10 @@ def dist(*logits): return Independent(Normal(*logits), 1) policy = NPGPolicy( - actor, critic, optim, dist, + actor, + critic, + optim, + dist, discount_factor=args.gamma, reward_normalization=args.rew_norm, advantage_normalization=args.norm_adv, @@ -96,11 +118,12 @@ def dist(*logits): action_space=env.action_space, optim_critic_iters=args.optim_critic_iters, actor_step_size=args.actor_step_size, - deterministic_eval=True) + deterministic_eval=True + ) # collector train_collector = Collector( - policy, train_envs, - VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)) + ) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'npg') @@ -115,10 +138,19 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, - step_per_collect=args.step_per_collect, stop_fn=stop_fn, save_fn=save_fn, - logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + 
step_per_collect=args.step_per_collect, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 09c73fdd5..473222816 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np -from torch.utils.tensorboard import SummaryWriter +import torch from torch.distributions import Independent, Normal +from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import PPOPolicy +from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import ActorProb, Critic @@ -34,8 +35,8 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) # ppo special parser.add_argument('--vf-coef', type=float, default=0.25) parser.add_argument('--ent-coef', type=float, default=0.0) @@ -63,30 +64,34 @@ def test_ppo(args=get_args()): # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) - actor = ActorProb(net, args.action_shape, max_action=args.max_action, - device=args.device).to(args.device) - critic = Critic(Net( - args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device - ), device=args.device).to(args.device) + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = ActorProb( + net, args.action_shape, max_action=args.max_action, device=args.device + ).to(args.device) + critic = Critic( + Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device), + device=args.device + ).to(args.device) # orthogonal initialization for m in set(actor.modules()).union(critic.modules()): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) optim = torch.optim.Adam( - set(actor.parameters()).union(critic.parameters()), lr=args.lr) + set(actor.parameters()).union(critic.parameters()), lr=args.lr + ) # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward @@ -94,7 +99,10 @@ def dist(*logits): return Independent(Normal(*logits), 1) policy = PPOPolicy( - actor, critic, optim, dist, + actor, + critic, + optim, + dist, discount_factor=args.gamma, max_grad_norm=args.max_grad_norm, 
eps_clip=args.eps_clip, @@ -107,11 +115,12 @@ def dist(*logits): # dual clip cause monotonically increasing log_std :) value_clip=args.value_clip, gae_lambda=args.gae_lambda, - action_space=env.action_space) + action_space=env.action_space + ) # collector train_collector = Collector( - policy, train_envs, - VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)) + ) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'ppo') @@ -126,10 +135,12 @@ def stop_fn(mean_rewards): def save_checkpoint_fn(epoch, env_step, gradient_step): # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html - torch.save({ - 'model': policy.state_dict(), - 'optim': optim.state_dict(), - }, os.path.join(log_path, 'checkpoint.pth')) + torch.save( + { + 'model': policy.state_dict(), + 'optim': optim.state_dict(), + }, os.path.join(log_path, 'checkpoint.pth') + ) if args.resume: # load from existing checkpoint @@ -145,11 +156,21 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, args.step_per_epoch, - args.repeat_per_collect, args.test_num, args.batch_size, - episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, - logger=logger, resume_from_log=args.resume, - save_checkpoint_fn=save_checkpoint_fn) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + episode_per_collect=args.episode_per_collect, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + resume_from_log=args.resume, + save_checkpoint_fn=save_checkpoint_fn + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 5b6a79492..da20290ec 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -1,17 +1,18 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.utils.net.common import Net +from tianshou.policy import ImitationPolicy, SACPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer -from tianshou.policy import SACPolicy, ImitationPolicy +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, ActorProb, Critic @@ -34,10 +35,10 @@ def get_args(): parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=128) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128]) - parser.add_argument('--imitation-hidden-sizes', type=int, - nargs='*', default=[128, 128]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) + parser.add_argument( + '--imitation-hidden-sizes', type=int, nargs='*', default=[128, 128] + ) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, 
default='log') @@ -45,8 +46,8 @@ def get_args(): parser.add_argument('--rew-norm', action="store_true", default=False) parser.add_argument('--n-step', type=int, default=3) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) args = parser.parse_known_args()[0] return args @@ -62,29 +63,43 @@ def test_sac_with_il(args=get_args()): # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) - actor = ActorProb(net, args.action_shape, max_action=args.max_action, - device=args.device, unbounded=True).to(args.device) + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = ActorProb( + net, + args.action_shape, + max_action=args.max_action, + device=args.device, + unbounded=True + ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net_c1 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) + net_c1 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic1 = Critic(net_c1, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - net_c2 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) @@ -95,15 +110,26 @@ def test_sac_with_il(args=get_args()): args.alpha = (target_entropy, log_alpha, alpha_optim) policy = SACPolicy( - actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - tau=args.tau, gamma=args.gamma, alpha=args.alpha, + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + tau=args.tau, + gamma=args.gamma, + alpha=args.alpha, reward_normalization=args.rew_norm, - estimation_step=args.n_step, action_space=env.action_space) + estimation_step=args.n_step, + action_space=env.action_space + ) # collector train_collector = Collector( - policy, train_envs, + policy, + train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), - exploration_noise=True) + exploration_noise=True + ) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log @@ -119,10 +145,19 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, - update_per_step=args.update_per_step, stop_fn=stop_fn, - save_fn=save_fn, logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, 
+ args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': @@ -140,23 +175,41 @@ def stop_fn(mean_rewards): if args.task == 'Pendulum-v0': env.spec.reward_threshold = -300 # lower the goal net = Actor( - Net(args.state_shape, hidden_sizes=args.imitation_hidden_sizes, - device=args.device), - args.action_shape, max_action=args.max_action, device=args.device + Net( + args.state_shape, + hidden_sizes=args.imitation_hidden_sizes, + device=args.device + ), + args.action_shape, + max_action=args.max_action, + device=args.device ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) il_policy = ImitationPolicy( - net, optim, action_space=env.action_space, - action_scaling=True, action_bound_method="clip") + net, + optim, + action_space=env.action_space, + action_scaling=True, + action_bound_method="clip" + ) il_test_collector = Collector( il_policy, DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) ) train_collector.reset() result = offpolicy_trainer( - il_policy, train_collector, il_test_collector, args.epoch, - args.il_step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger) + il_policy, + train_collector, + il_test_collector, + args.epoch, + args.il_step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 8bae1edfa..2e3ef7ba7 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.exploration import GaussianNoise from tianshou.policy import TD3Policy +from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import offpolicy_trainer -from tianshou.exploration import GaussianNoise -from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import Actor, Critic @@ -42,8 +43,8 @@ def get_args(): parser.add_argument('--rew-norm', action="store_true", default=False) parser.add_argument('--n-step', type=int, default=3) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) args = parser.parse_known_args()[0] return args @@ -59,46 +60,65 @@ def test_td3(args=get_args()): # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) 
test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) - actor = Actor(net, args.action_shape, max_action=args.max_action, - device=args.device).to(args.device) + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = Actor( + net, args.action_shape, max_action=args.max_action, device=args.device + ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net_c1 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) + net_c1 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic1 = Critic(net_c1, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - net_c2 = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, device=args.device) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) critic2 = Critic(net_c2, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = TD3Policy( - actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - tau=args.tau, gamma=args.gamma, + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + tau=args.tau, + gamma=args.gamma, exploration_noise=GaussianNoise(sigma=args.exploration_noise), policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, reward_normalization=args.rew_norm, estimation_step=args.n_step, - action_space=env.action_space) + action_space=env.action_space + ) # collector train_collector = Collector( - policy, train_envs, + policy, + train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), - exploration_noise=True) + exploration_noise=True + ) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log @@ -114,10 +134,19 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, - update_per_step=args.update_per_step, stop_fn=stop_fn, - save_fn=save_fn, logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 65535fd50..4a4206f5f 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -1,19 +1,20 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch import nn -from torch.utils.tensorboard import SummaryWriter from torch.distributions import Independent, Normal +from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import TRPOPolicy +from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer 
import onpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import ActorProb, Critic @@ -27,8 +28,9 @@ def get_args(): parser.add_argument('--epoch', type=int, default=5) parser.add_argument('--step-per-epoch', type=int, default=50000) parser.add_argument('--step-per-collect', type=int, default=2048) - parser.add_argument('--repeat-per-collect', type=int, - default=2) # theoretically it should be 1 + parser.add_argument( + '--repeat-per-collect', type=int, default=2 + ) # theoretically it should be 1 parser.add_argument('--batch-size', type=int, default=99999) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) parser.add_argument('--training-num', type=int, default=16) @@ -36,8 +38,8 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) # trpo special parser.add_argument('--gae-lambda', type=float, default=0.95) parser.add_argument('--rew-norm', type=int, default=1) @@ -61,23 +63,40 @@ def test_trpo(args=get_args()): # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - activation=nn.Tanh, device=args.device) - actor = ActorProb(net, args.action_shape, max_action=args.max_action, - unbounded=True, device=args.device).to(args.device) - critic = Critic(Net( - args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device, - activation=nn.Tanh), device=args.device).to(args.device) + net = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) + actor = ActorProb( + net, + args.action_shape, + max_action=args.max_action, + unbounded=True, + device=args.device + ).to(args.device) + critic = Critic( + Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + activation=nn.Tanh + ), + device=args.device + ).to(args.device) # orthogonal initialization for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): @@ -91,7 +110,10 @@ def dist(*logits): return Independent(Normal(*logits), 1) policy = TRPOPolicy( - actor, critic, optim, dist, + actor, + critic, + optim, + dist, discount_factor=args.gamma, reward_normalization=args.rew_norm, advantage_normalization=args.norm_adv, @@ -100,11 +122,12 @@ def dist(*logits): optim_critic_iters=args.optim_critic_iters, max_kl=args.max_kl, backtrack_coeff=args.backtrack_coeff, - max_backtracks=args.max_backtracks) + max_backtracks=args.max_backtracks + ) # collector train_collector = Collector( - policy, train_envs, - VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)) + ) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, 
args.task, 'trpo') @@ -119,10 +142,19 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, - step_per_collect=args.step_per_collect, stop_fn=stop_fn, save_fn=save_fn, - logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + step_per_collect=args.step_per_collect, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index d11ce360c..745295826 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv +from tianshou.policy import A2CPolicy, ImitationPolicy +from tianshou.trainer import offpolicy_trainer, onpolicy_trainer +from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.discrete import Actor, Critic -from tianshou.policy import A2CPolicy, ImitationPolicy -from tianshou.trainer import onpolicy_trainer, offpolicy_trainer def get_args(): @@ -31,17 +32,15 @@ def get_args(): parser.add_argument('--update-per-step', type=float, default=1 / 16) parser.add_argument('--repeat-per-collect', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[64, 64]) - parser.add_argument('--imitation-hidden-sizes', type=int, - nargs='*', default=[128]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) + parser.add_argument('--imitation-hidden-sizes', type=int, nargs='*', default=[128]) parser.add_argument('--training-num', type=int, default=16) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) # a2c special parser.add_argument('--vf-coef', type=float, default=0.5) parser.add_argument('--ent-coef', type=float, default=0.0) @@ -60,33 +59,42 @@ def test_a2c_with_il(args=get_args()): # you can also use tianshou.env.SubprocVectorEnv # train_envs = gym.make(args.task) train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device) critic = Critic(net, device=args.device).to(args.device) optim = torch.optim.Adam( - set(actor.parameters()).union(critic.parameters()), lr=args.lr) + set(actor.parameters()).union(critic.parameters()), lr=args.lr + ) dist = torch.distributions.Categorical policy = A2CPolicy( - actor, critic, optim, dist, - discount_factor=args.gamma, gae_lambda=args.gae_lambda, - vf_coef=args.vf_coef, ent_coef=args.ent_coef, - max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm, - action_space=env.action_space) + actor, + critic, + optim, + dist, + discount_factor=args.gamma, + gae_lambda=args.gae_lambda, + vf_coef=args.vf_coef, + ent_coef=args.ent_coef, + max_grad_norm=args.max_grad_norm, + reward_normalization=args.rew_norm, + action_space=env.action_space + ) # collector train_collector = Collector( - policy, train_envs, - VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)) + ) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'a2c') @@ -101,10 +109,19 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, - episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, - logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + episode_per_collect=args.episode_per_collect, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': @@ -121,8 +138,7 @@ def stop_fn(mean_rewards): # here we define an imitation collector with a trivial policy if args.task == 'CartPole-v0': env.spec.reward_threshold = 190 # lower the goal - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) net = Actor(net, args.action_shape, device=args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) il_policy = ImitationPolicy(net, optim, action_space=env.action_space) @@ -132,9 +148,18 @@ def stop_fn(mean_rewards): ) train_collector.reset() result = 
offpolicy_trainer( - il_policy, train_collector, il_test_collector, args.epoch, - args.il_step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger) + il_policy, + train_collector, + il_test_collector, + args.epoch, + args.il_step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 3208e83c8..3c74723eb 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pickle import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import C51Policy +from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer def get_args(): @@ -34,20 +35,20 @@ def get_args(): parser.add_argument('--step-per-collect', type=int, default=8) parser.add_argument('--update-per-step', type=float, default=0.125) parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128, 128, 128]) + parser.add_argument( + '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128] + ) parser.add_argument('--training-num', type=int, default=8) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
- parser.add_argument('--prioritized-replay', - action="store_true", default=False) + parser.add_argument('--prioritized-replay', action="store_true", default=False) parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--beta', type=float, default=0.4) parser.add_argument('--resume', action="store_true") parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument("--save-interval", type=int, default=4) args = parser.parse_known_args()[0] return args @@ -60,29 +61,45 @@ def test_c51(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device, - softmax=True, num_atoms=args.num_atoms) + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + softmax=True, + num_atoms=args.num_atoms + ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = C51Policy( - net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max, - args.n_step, target_update_freq=args.target_update_freq + net, + optim, + args.gamma, + args.num_atoms, + args.v_min, + args.v_max, + args.n_step, + target_update_freq=args.target_update_freq ).to(args.device) # buffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), - alpha=args.alpha, beta=args.beta) + args.buffer_size, + buffer_num=len(train_envs), + alpha=args.alpha, + beta=args.beta + ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector @@ -117,12 +134,16 @@ def test_fn(epoch, env_step): def save_checkpoint_fn(epoch, env_step, gradient_step): # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html - torch.save({ - 'model': policy.state_dict(), - 'optim': optim.state_dict(), - }, os.path.join(log_path, 'checkpoint.pth')) - pickle.dump(train_collector.buffer, - open(os.path.join(log_path, 'train_buffer.pkl'), "wb")) + torch.save( + { + 'model': policy.state_dict(), + 'optim': optim.state_dict(), + }, os.path.join(log_path, 'checkpoint.pth') + ) + pickle.dump( + train_collector.buffer, + open(os.path.join(log_path, 'train_buffer.pkl'), "wb") + ) if args.resume: # load from existing checkpoint @@ -144,11 +165,23 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger, - resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + 
train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + resume_from_log=args.resume, + save_checkpoint_fn=save_checkpoint_fn + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index aae609ec6..6912a1933 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pickle import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import DQNPolicy +from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer def get_args(): @@ -31,22 +32,22 @@ def get_args(): parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128, 128, 128]) + parser.add_argument( + '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128] + ) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) - parser.add_argument('--prioritized-replay', - action="store_true", default=False) + parser.add_argument('--prioritized-replay', action="store_true", default=False) parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--beta', type=float, default=0.4) parser.add_argument( - '--save-buffer-name', type=str, - default="./expert_DQN_CartPole-v0.pkl") + '--save-buffer-name', type=str, default="./expert_DQN_CartPole-v0.pkl" + ) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) args = parser.parse_known_args()[0] return args @@ -58,10 +59,12 @@ def test_dqn(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -69,19 +72,29 @@ def test_dqn(args=get_args()): test_envs.seed(args.seed) # Q_param = V_param = {"hidden_sizes": [128]} # model - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device, - # dueling=(Q_param, V_param), - ).to(args.device) + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + # dueling=(Q_param, V_param), + ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DQNPolicy( - net, optim, args.gamma, args.n_step, - 
target_update_freq=args.target_update_freq) + net, + optim, + args.gamma, + args.n_step, + target_update_freq=args.target_update_freq + ) # buffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), - alpha=args.alpha, beta=args.beta) + args.buffer_size, + buffer_num=len(train_envs), + alpha=args.alpha, + beta=args.beta + ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector @@ -116,10 +129,21 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index aa4fbbe0f..064dbba24 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -1,17 +1,18 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import DQNPolicy -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv +from tianshou.policy import DQNPolicy from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Recurrent -from tianshou.data import Collector, VectorReplayBuffer def get_args(): @@ -37,8 +38,8 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) args = parser.parse_known_args()[0] return args @@ -50,26 +51,35 @@ def test_drqn(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Recurrent(args.layer_num, args.state_shape, - args.action_shape, args.device).to(args.device) + net = Recurrent(args.layer_num, args.state_shape, args.action_shape, + args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DQNPolicy( - net, optim, args.gamma, args.n_step, - target_update_freq=args.target_update_freq) + net, + optim, + args.gamma, + args.n_step, + target_update_freq=args.target_update_freq + ) # collector buffer = VectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), - stack_num=args.stack_num, ignore_obs_next=True) + args.buffer_size, + buffer_num=len(train_envs), + stack_num=args.stack_num, + ignore_obs_next=True + ) train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) # the stack_num is for RNN training: sample framestack obs test_collector = Collector(policy, test_envs, exploration_noise=True) @@ -94,11 +104,21 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, update_per_step=args.update_per_step, - train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, - save_fn=save_fn, logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 0df2efb74..1763380f1 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import FQFPolicy +from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import offpolicy_trainer from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction -from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer def get_args(): @@ -35,19 +36,17 @@ def get_args(): parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--update-per-step', 
type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[64, 64, 64]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64, 64]) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) - parser.add_argument('--prioritized-replay', - action="store_true", default=False) + parser.add_argument('--prioritized-replay', action="store_true", default=False) parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--beta', type=float, default=0.4) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) args = parser.parse_known_args()[0] return args @@ -59,22 +58,31 @@ def test_fqf(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - feature_net = Net(args.state_shape, args.hidden_sizes[-1], - hidden_sizes=args.hidden_sizes[:-1], device=args.device, - softmax=False) + feature_net = Net( + args.state_shape, + args.hidden_sizes[-1], + hidden_sizes=args.hidden_sizes[:-1], + device=args.device, + softmax=False + ) net = FullQuantileFunction( - feature_net, args.action_shape, args.hidden_sizes, - num_cosines=args.num_cosines, device=args.device + feature_net, + args.action_shape, + args.hidden_sizes, + num_cosines=args.num_cosines, + device=args.device ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) @@ -82,14 +90,24 @@ def test_fqf(args=get_args()): fraction_net.parameters(), lr=args.fraction_lr ) policy = FQFPolicy( - net, optim, fraction_net, fraction_optim, args.gamma, args.num_fractions, - args.ent_coef, args.n_step, target_update_freq=args.target_update_freq + net, + optim, + fraction_net, + fraction_optim, + args.gamma, + args.num_fractions, + args.ent_coef, + args.n_step, + target_update_freq=args.target_update_freq ).to(args.device) # buffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), - alpha=args.alpha, beta=args.beta) + args.buffer_size, + buffer_num=len(train_envs), + alpha=args.alpha, + beta=args.beta + ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector @@ -124,11 +142,21 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + 
train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py index f404ce7dc..47540dadd 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/discrete/test_il_bcq.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pickle import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector -from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv -from tianshou.utils.net.common import Net -from tianshou.trainer import offline_trainer from tianshou.policy import DiscreteBCQPolicy +from tianshou.trainer import offline_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net def get_args(): @@ -29,17 +30,18 @@ def get_args(): parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--update-per-epoch", type=int, default=2000) parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[64, 64]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) parser.add_argument("--test-num", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) parser.add_argument( - "--load-buffer-name", type=str, + "--load-buffer-name", + type=str, default="./expert_DQN_CartPole-v0.pkl", ) parser.add_argument( - "--device", type=str, + "--device", + type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) parser.add_argument("--resume", action="store_true") @@ -56,26 +58,39 @@ def test_discrete_bcq(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # model policy_net = Net( - args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device).to(args.device) + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device + ).to(args.device) imitation_net = Net( - args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device).to(args.device) + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device + ).to(args.device) optim = torch.optim.Adam( - list(policy_net.parameters()) + list(imitation_net.parameters()), - lr=args.lr) + list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr + ) policy = DiscreteBCQPolicy( - policy_net, imitation_net, optim, args.gamma, args.n_step, - args.target_update_freq, args.eps_test, - args.unlikely_action_threshold, args.imitation_logits_penalty, + policy_net, + imitation_net, + optim, + args.gamma, + args.n_step, + args.target_update_freq, + args.eps_test, + args.unlikely_action_threshold, + args.imitation_logits_penalty, ) # buffer assert os.path.exists(args.load_buffer_name), \ @@ -97,10 +112,12 @@ def stop_fn(mean_rewards): def save_checkpoint_fn(epoch, env_step, gradient_step): # see also: 
https://pytorch.org/tutorials/beginner/saving_loading_models.html - torch.save({ - 'model': policy.state_dict(), - 'optim': optim.state_dict(), - }, os.path.join(log_path, 'checkpoint.pth')) + torch.save( + { + 'model': policy.state_dict(), + 'optim': optim.state_dict(), + }, os.path.join(log_path, 'checkpoint.pth') + ) if args.resume: # load from existing checkpoint @@ -115,10 +132,19 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): print("Fail to restore policy and optim.") result = offline_trainer( - policy, buffer, test_collector, - args.epoch, args.update_per_epoch, args.test_num, args.batch_size, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn) + policy, + buffer, + test_collector, + args.epoch, + args.update_per_epoch, + args.test_num, + args.batch_size, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + resume_from_log=args.resume, + save_checkpoint_fn=save_checkpoint_fn + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_il_crr.py b/test/discrete/test_il_crr.py index 858d2b6f7..929469e8b 100644 --- a/test/discrete/test_il_crr.py +++ b/test/discrete/test_il_crr.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pickle import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector -from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv -from tianshou.utils.net.common import Net -from tianshou.trainer import offline_trainer from tianshou.policy import DiscreteCRRPolicy +from tianshou.trainer import offline_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net def get_args(): @@ -26,17 +27,18 @@ def get_args(): parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--update-per-epoch", type=int, default=1000) parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[64, 64]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) parser.add_argument("--test-num", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) 
parser.add_argument( - "--load-buffer-name", type=str, + "--load-buffer-name", + type=str, default="./expert_DQN_CartPole-v0.pkl", ) parser.add_argument( - "--device", type=str, + "--device", + type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) args = parser.parse_known_args()[0] @@ -51,23 +53,36 @@ def test_discrete_crr(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - actor = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device, - softmax=False) - critic = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device, - softmax=False) - optim = torch.optim.Adam(list(actor.parameters()) + list(critic.parameters()), - lr=args.lr) + actor = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + softmax=False + ) + critic = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + softmax=False + ) + optim = torch.optim.Adam( + list(actor.parameters()) + list(critic.parameters()), lr=args.lr + ) policy = DiscreteCRRPolicy( - actor, critic, optim, args.gamma, + actor, + critic, + optim, + args.gamma, target_update_freq=args.target_update_freq, ).to(args.device) # buffer @@ -89,9 +104,17 @@ def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold result = offline_trainer( - policy, buffer, test_collector, - args.epoch, args.update_per_epoch, args.test_num, args.batch_size, - stop_fn=stop_fn, save_fn=save_fn, logger=logger) + policy, + buffer, + test_collector, + args.epoch, + args.update_per_epoch, + args.test_num, + args.batch_size, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 0234c36f0..c93ddfc0d 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import IQNPolicy +from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import offpolicy_trainer from tianshou.utils.net.discrete import ImplicitQuantileNetwork -from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer def get_args(): @@ -35,19 +36,17 @@ def get_args(): parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[64, 64, 64]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64, 64]) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) 
parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) - parser.add_argument('--prioritized-replay', - action="store_true", default=False) + parser.add_argument('--prioritized-replay', action="store_true", default=False) parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--beta', type=float, default=0.4) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) args = parser.parse_known_args()[0] return args @@ -59,33 +58,50 @@ def test_iqn(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - feature_net = Net(args.state_shape, args.hidden_sizes[-1], - hidden_sizes=args.hidden_sizes[:-1], device=args.device, - softmax=False) + feature_net = Net( + args.state_shape, + args.hidden_sizes[-1], + hidden_sizes=args.hidden_sizes[:-1], + device=args.device, + softmax=False + ) net = ImplicitQuantileNetwork( - feature_net, args.action_shape, - num_cosines=args.num_cosines, device=args.device) + feature_net, + args.action_shape, + num_cosines=args.num_cosines, + device=args.device + ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = IQNPolicy( - net, optim, args.gamma, args.sample_size, args.online_sample_size, - args.target_sample_size, args.n_step, + net, + optim, + args.gamma, + args.sample_size, + args.online_sample_size, + args.target_sample_size, + args.n_step, target_update_freq=args.target_update_freq ).to(args.device) # buffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), - alpha=args.alpha, beta=args.beta) + args.buffer_size, + buffer_num=len(train_envs), + alpha=args.alpha, + beta=args.beta + ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector @@ -120,11 +136,21 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 3c36bb265..fafd7cc49 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -1,17 +1,18 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, VectorReplayBuffer 
+from tianshou.env import DummyVectorEnv from tianshou.policy import PGPolicy +from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer def get_args(): @@ -26,16 +27,15 @@ def get_args(): parser.add_argument('--episode-per-collect', type=int, default=8) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[64, 64]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) parser.add_argument('--training-num', type=int, default=8) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument('--rew-norm', type=int, default=1) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) args = parser.parse_known_args()[0] return args @@ -47,24 +47,35 @@ def test_pg(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, - device=args.device, softmax=True).to(args.device) + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + softmax=True + ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) dist = torch.distributions.Categorical - policy = PGPolicy(net, optim, dist, args.gamma, - reward_normalization=args.rew_norm, - action_space=env.action_space) + policy = PGPolicy( + net, + optim, + dist, + args.gamma, + reward_normalization=args.rew_norm, + action_space=env.action_space + ) for m in net.modules(): if isinstance(m, torch.nn.Linear): # orthogonal initialization @@ -72,8 +83,8 @@ def test_pg(args=get_args()): torch.nn.init.zeros_(m.bias) # collector train_collector = Collector( - policy, train_envs, - VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)) + ) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'pg') @@ -88,10 +99,19 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, - episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, - logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + episode_per_collect=args.episode_per_collect, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger 
+ ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 3658f364b..96650b14b 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -1,17 +1,18 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import PPOPolicy +from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.discrete import Actor, Critic @@ -33,8 +34,8 @@ def get_args(): parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) # ppo special parser.add_argument('--vf-coef', type=float, default=0.5) parser.add_argument('--ent-coef', type=float, default=0.0) @@ -57,18 +58,19 @@ def test_ppo(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device) critic = Critic(net, device=args.device).to(args.device) # orthogonal initialization @@ -77,10 +79,14 @@ def test_ppo(args=get_args()): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) optim = torch.optim.Adam( - set(actor.parameters()).union(critic.parameters()), lr=args.lr) + set(actor.parameters()).union(critic.parameters()), lr=args.lr + ) dist = torch.distributions.Categorical policy = PPOPolicy( - actor, critic, optim, dist, + actor, + critic, + optim, + dist, discount_factor=args.gamma, max_grad_norm=args.max_grad_norm, eps_clip=args.eps_clip, @@ -93,11 +99,12 @@ def test_ppo(args=get_args()): action_space=env.action_space, deterministic_eval=True, advantage_normalization=args.norm_adv, - recompute_advantage=args.recompute_adv) + recompute_advantage=args.recompute_adv + ) # collector train_collector = Collector( - policy, train_envs, - VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)) + ) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'ppo') @@ -112,10 +119,19 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, 
- step_per_collect=args.step_per_collect, stop_fn=stop_fn, save_fn=save_fn, - logger=logger) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + step_per_collect=args.step_per_collect, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 2385db0ee..cf8d22212 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pickle import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import TensorboardLogger -from tianshou.policy import QRDQNPolicy +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer from tianshou.env import DummyVectorEnv -from tianshou.utils.net.common import Net +from tianshou.policy import QRDQNPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net def get_args(): @@ -32,22 +33,22 @@ def get_args(): parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128, 128, 128]) + parser.add_argument( + '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128] + ) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
- parser.add_argument('--prioritized-replay', - action="store_true", default=False) + parser.add_argument('--prioritized-replay', action="store_true", default=False) parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--beta', type=float, default=0.4) parser.add_argument( - '--save-buffer-name', type=str, - default="./expert_QRDQN_CartPole-v0.pkl") + '--save-buffer-name', type=str, default="./expert_QRDQN_CartPole-v0.pkl" + ) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) args = parser.parse_known_args()[0] return args @@ -61,29 +62,43 @@ def test_qrdqn(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device, - softmax=False, num_atoms=args.num_quantiles) + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + softmax=False, + num_atoms=args.num_quantiles + ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = QRDQNPolicy( - net, optim, args.gamma, args.num_quantiles, - args.n_step, target_update_freq=args.target_update_freq + net, + optim, + args.gamma, + args.num_quantiles, + args.n_step, + target_update_freq=args.target_update_freq ).to(args.device) # buffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), - alpha=args.alpha, beta=args.beta) + args.buffer_size, + buffer_num=len(train_envs), + alpha=args.alpha, + beta=args.beta + ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector @@ -118,11 +133,21 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_qrdqn_il_cql.py b/test/discrete/test_qrdqn_il_cql.py index dbfd42aad..01b868f13 100644 --- a/test/discrete/test_qrdqn_il_cql.py +++ b/test/discrete/test_qrdqn_il_cql.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pickle import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector -from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv -from tianshou.utils.net.common import Net -from 
tianshou.trainer import offline_trainer from tianshou.policy import DiscreteCQLPolicy +from tianshou.trainer import offline_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net def get_args(): @@ -29,17 +30,18 @@ def get_args(): parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--update-per-epoch", type=int, default=1000) parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[64, 64]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) parser.add_argument("--test-num", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) parser.add_argument( - "--load-buffer-name", type=str, + "--load-buffer-name", + type=str, default="./expert_QRDQN_CartPole-v0.pkl", ) parser.add_argument( - "--device", type=str, + "--device", + type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) args = parser.parse_known_args()[0] @@ -54,20 +56,31 @@ def test_discrete_cql(args=get_args()): args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device, - softmax=False, num_atoms=args.num_quantiles) + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + softmax=False, + num_atoms=args.num_quantiles + ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DiscreteCQLPolicy( - net, optim, args.gamma, args.num_quantiles, args.n_step, - args.target_update_freq, min_q_weight=args.min_q_weight + net, + optim, + args.gamma, + args.num_quantiles, + args.n_step, + args.target_update_freq, + min_q_weight=args.min_q_weight ).to(args.device) # buffer assert os.path.exists(args.load_buffer_name), \ @@ -88,9 +101,17 @@ def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold result = offline_trainer( - policy, buffer, test_collector, - args.epoch, args.update_per_epoch, args.test_num, args.batch_size, - stop_fn=stop_fn, save_fn=save_fn, logger=logger) + policy, + buffer, + test_collector, + args.epoch, + args.update_per_epoch, + args.test_num, + args.batch_size, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) assert stop_fn(result['best_reward']) diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 4fdcfd352..b226a025c 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -1,19 +1,20 @@ +import argparse import os -import gym -import torch import pickle import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.env import DummyVectorEnv from tianshou.policy import RainbowPolicy +from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import NoisyLinear -from tianshou.trainer import 
offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer def get_args(): @@ -36,21 +37,21 @@ def get_args(): parser.add_argument('--step-per-collect', type=int, default=8) parser.add_argument('--update-per-step', type=float, default=0.125) parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128, 128, 128]) + parser.add_argument( + '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128] + ) parser.add_argument('--training-num', type=int, default=8) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) - parser.add_argument('--prioritized-replay', - action="store_true", default=False) + parser.add_argument('--prioritized-replay', action="store_true", default=False) parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--beta', type=float, default=0.4) parser.add_argument('--beta-final', type=float, default=1.) parser.add_argument('--resume', action="store_true") parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) parser.add_argument("--save-interval", type=int, default=4) args = parser.parse_known_args()[0] return args @@ -63,35 +64,56 @@ def test_rainbow(args=get_args()): # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) + # model def noisy_linear(x, y): return NoisyLinear(x, y, args.noisy_std) - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device, - softmax=True, num_atoms=args.num_atoms, - dueling_param=({"linear_layer": noisy_linear}, - {"linear_layer": noisy_linear})) + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + softmax=True, + num_atoms=args.num_atoms, + dueling_param=({ + "linear_layer": noisy_linear + }, { + "linear_layer": noisy_linear + }) + ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = RainbowPolicy( - net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max, - args.n_step, target_update_freq=args.target_update_freq + net, + optim, + args.gamma, + args.num_atoms, + args.v_min, + args.v_max, + args.n_step, + target_update_freq=args.target_update_freq ).to(args.device) # buffer if args.prioritized_replay: buf = PrioritizedVectorReplayBuffer( - args.buffer_size, buffer_num=len(train_envs), - alpha=args.alpha, beta=args.beta, weight_norm=True) + args.buffer_size, + buffer_num=len(train_envs), + alpha=args.alpha, + beta=args.beta, + weight_norm=True + ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector @@ -136,12 +158,16 @@ def test_fn(epoch, env_step): def save_checkpoint_fn(epoch, env_step, gradient_step): # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html - torch.save({ - 
'model': policy.state_dict(), - 'optim': optim.state_dict(), - }, os.path.join(log_path, 'checkpoint.pth')) - pickle.dump(train_collector.buffer, - open(os.path.join(log_path, 'train_buffer.pkl'), "wb")) + torch.save( + { + 'model': policy.state_dict(), + 'optim': optim.state_dict(), + }, os.path.join(log_path, 'checkpoint.pth') + ) + pickle.dump( + train_collector.buffer, + open(os.path.join(log_path, 'train_buffer.pkl'), "wb") + ) if args.resume: # load from existing checkpoint @@ -163,11 +189,23 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger, - resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + resume_from_log=args.resume, + save_checkpoint_fn=save_checkpoint_fn + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index d8abf48e4..41be36838 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -1,18 +1,19 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv -from tianshou.utils.net.common import Net from tianshou.policy import DiscreteSACPolicy from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import Actor, Critic -from tianshou.data import Collector, VectorReplayBuffer def get_args(): @@ -40,8 +41,8 @@ def get_args(): parser.add_argument('--rew-norm', action="store_true", default=False) parser.add_argument('--n-step', type=int, default=3) parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) args = parser.parse_known_args()[0] return args @@ -52,27 +53,26 @@ def test_discrete_sac(args=get_args()): args.action_shape = env.action_space.shape or env.action_space.n train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) test_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) - actor = Actor(net, args.action_shape, - softmax_output=False, device=args.device).to(args.device) + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = Actor(net, args.action_shape, softmax_output=False, + device=args.device).to(args.device) 
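# --- Illustrative sketch, not part of the patch ------------------------------
# The save_checkpoint_fn in test_rainbow.py earlier in this patch writes the
# policy and optimizer state dicts under the keys 'model' and 'optim' and
# pickles the training buffer separately.  Restoring such a checkpoint on
# --resume could look like the following; the names (policy, optim,
# train_collector, log_path, args) follow that test script and rely on the
# os/pickle/torch imports it already has, so this is an assumed resume path
# rather than the script's exact code.
checkpoint = torch.load(
    os.path.join(log_path, 'checkpoint.pth'), map_location=args.device
)
policy.load_state_dict(checkpoint['model'])
optim.load_state_dict(checkpoint['optim'])
with open(os.path.join(log_path, 'train_buffer.pkl'), 'rb') as f:
    train_collector.buffer = pickle.load(f)
# -----------------------------------------------------------------------------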
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net_c1 = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) + net_c1 = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) critic1 = Critic(net_c1, last_size=args.action_shape, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - net_c2 = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) + net_c2 = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) critic2 = Critic(net_c2, last_size=args.action_shape, device=args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) @@ -85,13 +85,22 @@ def test_discrete_sac(args=get_args()): args.alpha = (target_entropy, log_alpha, alpha_optim) policy = DiscreteSACPolicy( - actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - args.tau, args.gamma, args.alpha, estimation_step=args.n_step, - reward_normalization=args.rew_norm) + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + args.tau, + args.gamma, + args.alpha, + estimation_step=args.n_step, + reward_normalization=args.rew_norm + ) # collector train_collector = Collector( - policy, train_envs, - VectorReplayBuffer(args.buffer_size, len(train_envs))) + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)) + ) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log @@ -107,10 +116,20 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step, test_in_train=False) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index 01710d827..3a50f36e9 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -1,15 +1,16 @@ +import argparse import os -import gym -import torch import pprint -import argparse + +import gym import numpy as np +import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.policy import PSRLPolicy -from tianshou.trainer import onpolicy_trainer from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv +from tianshou.policy import PSRLPolicy +from tianshou.trainer import onpolicy_trainer def get_args(): @@ -42,10 +43,12 @@ def test_psrl(args=get_args()): args.action_shape = env.action_space.shape or env.action_space.n # train_envs = gym.make(args.task) train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) # test_envs = gym.make(args.task) test_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -59,12 +62,15 @@ def test_psrl(args=get_args()): rew_std_prior = np.full((n_state, 
n_action), args.rew_std_prior) policy = PSRLPolicy( trans_count_prior, rew_mean_prior, rew_std_prior, args.gamma, args.eps, - args.add_done_loop) + args.add_done_loop + ) # collector train_collector = Collector( - policy, train_envs, + policy, + train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), - exploration_noise=True) + exploration_noise=True + ) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'psrl') @@ -80,11 +86,19 @@ def stop_fn(mean_rewards): train_collector.collect(n_step=args.buffer_size, random=True) # trainer, test it without logger result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, 1, args.test_num, 0, - episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + 1, + args.test_num, + 0, + episode_per_collect=args.episode_per_collect, + stop_fn=stop_fn, # logger=logger, - test_in_train=False) + test_in_train=False + ) if __name__ == '__main__': pprint.pprint(result) diff --git a/test/multiagent/Gomoku.py b/test/multiagent/Gomoku.py index 6418ee8ec..7d1754245 100644 --- a/test/multiagent/Gomoku.py +++ b/test/multiagent/Gomoku.py @@ -1,17 +1,17 @@ import os import pprint -import numpy as np from copy import deepcopy + +import numpy as np +from tic_tac_toe import get_agents, get_parser, train_agent, watch +from tic_tac_toe_env import TicTacToeEnv from torch.utils.tensorboard import SummaryWriter -from tianshou.env import DummyVectorEnv from tianshou.data import Collector +from tianshou.env import DummyVectorEnv from tianshou.policy import RandomPolicy from tianshou.utils import TensorboardLogger -from tic_tac_toe_env import TicTacToeEnv -from tic_tac_toe import get_parser, get_agents, train_agent, watch - def get_args(): parser = get_parser() @@ -39,6 +39,7 @@ def gomoku(args=get_args()): def env_func(): return TicTacToeEnv(args.board_size, args.win_size) + test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)]) for r in range(args.self_play_round): rews = [] @@ -65,11 +66,11 @@ def env_func(): # previous learner can only be used for forward agent.forward = opponent.forward args.model_save_path = os.path.join( - args.logdir, 'Gomoku', 'dqn', - f'policy_round_{r}_epoch_{epoch}.pth') + args.logdir, 'Gomoku', 'dqn', f'policy_round_{r}_epoch_{epoch}.pth' + ) result, agent_learn = train_agent( - args, agent_learn=agent_learn, - agent_opponent=agent, optim=optim) + args, agent_learn=agent_learn, agent_opponent=agent, optim=optim + ) print(f'round_{r}_epoch_{epoch}') pprint.pprint(result) learnt_agent = deepcopy(agent_learn) diff --git a/test/multiagent/test_tic_tac_toe.py b/test/multiagent/test_tic_tac_toe.py index 1cc06d374..aeb4644e1 100644 --- a/test/multiagent/test_tic_tac_toe.py +++ b/test/multiagent/test_tic_tac_toe.py @@ -1,4 +1,5 @@ import pprint + from tic_tac_toe import get_args, train_agent, watch diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index dcd293a06..8ecbd2878 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -1,20 +1,24 @@ -import os -import torch import argparse -import numpy as np +import os from copy import deepcopy from typing import Optional, Tuple + +import numpy as np +import torch +from tic_tac_toe_env import TicTacToeEnv from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import TensorboardLogger +from tianshou.data import Collector, VectorReplayBuffer 
from tianshou.env import DummyVectorEnv -from tianshou.utils.net.common import Net +from tianshou.policy import ( + BasePolicy, + DQNPolicy, + MultiAgentPolicyManager, + RandomPolicy, +) from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer -from tianshou.policy import BasePolicy, DQNPolicy, RandomPolicy, \ - MultiAgentPolicyManager - -from tic_tac_toe_env import TicTacToeEnv +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net def get_parser() -> argparse.ArgumentParser: @@ -24,8 +28,9 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument('--eps-train', type=float, default=0.1) parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--lr', type=float, default=1e-3) - parser.add_argument('--gamma', type=float, default=0.9, - help='a smaller gamma favors earlier win') + parser.add_argument( + '--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win' + ) parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=20) @@ -33,31 +38,49 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128, 128, 128]) + parser.add_argument( + '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128] + ) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.1) parser.add_argument('--board-size', type=int, default=6) parser.add_argument('--win-size', type=int, default=4) - parser.add_argument('--win-rate', type=float, default=0.9, - help='the expected winning rate') - parser.add_argument('--watch', default=False, action='store_true', - help='no training, ' - 'watch the play of pre-trained models') - parser.add_argument('--agent-id', type=int, default=2, - help='the learned agent plays as the' - ' agent_id-th player. Choices are 1 and 2.') - parser.add_argument('--resume-path', type=str, default='', - help='the path of agent pth file ' - 'for resuming from a pre-trained agent') - parser.add_argument('--opponent-path', type=str, default='', - help='the path of opponent agent pth file ' - 'for resuming from a pre-trained agent') parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') + '--win-rate', type=float, default=0.9, help='the expected winning rate' + ) + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='no training, ' + 'watch the play of pre-trained models' + ) + parser.add_argument( + '--agent-id', + type=int, + default=2, + help='the learned agent plays as the' + ' agent_id-th player. Choices are 1 and 2.' 
+ ) + parser.add_argument( + '--resume-path', + type=str, + default='', + help='the path of agent pth file ' + 'for resuming from a pre-trained agent' + ) + parser.add_argument( + '--opponent-path', + type=str, + default='', + help='the path of opponent agent pth file ' + 'for resuming from a pre-trained agent' + ) + parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) return parser @@ -77,14 +100,21 @@ def get_agents( args.action_shape = env.action_space.shape or env.action_space.n if agent_learn is None: # model - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, device=args.device - ).to(args.device) + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device + ).to(args.device) if optim is None: optim = torch.optim.Adam(net.parameters(), lr=args.lr) agent_learn = DQNPolicy( - net, optim, args.gamma, args.n_step, - target_update_freq=args.target_update_freq) + net, + optim, + args.gamma, + args.n_step, + target_update_freq=args.target_update_freq + ) if args.resume_path: agent_learn.load_state_dict(torch.load(args.resume_path)) @@ -111,6 +141,7 @@ def train_agent( ) -> Tuple[dict, BasePolicy]: def env_func(): return TicTacToeEnv(args.board_size, args.win_size) + train_envs = DummyVectorEnv([env_func for _ in range(args.training_num)]) test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)]) # seed @@ -120,14 +151,16 @@ def env_func(): test_envs.seed(args.seed) policy, optim = get_agents( - args, agent_learn=agent_learn, - agent_opponent=agent_opponent, optim=optim) + args, agent_learn=agent_learn, agent_opponent=agent_opponent, optim=optim + ) # collector train_collector = Collector( - policy, train_envs, + policy, + train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), - exploration_noise=True) + exploration_noise=True + ) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size * args.training_num) @@ -142,10 +175,9 @@ def save_fn(policy): model_save_path = args.model_save_path else: model_save_path = os.path.join( - args.logdir, 'tic_tac_toe', 'dqn', 'policy.pth') - torch.save( - policy.policies[args.agent_id - 1].state_dict(), - model_save_path) + args.logdir, 'tic_tac_toe', 'dqn', 'policy.pth' + ) + torch.save(policy.policies[args.agent_id - 1].state_dict(), model_save_path) def stop_fn(mean_rewards): return mean_rewards >= args.win_rate @@ -161,11 +193,23 @@ def reward_metric(rews): # trainer result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step, - logger=logger, test_in_train=False, reward_metric=reward_metric) + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + update_per_step=args.update_per_step, + logger=logger, + test_in_train=False, + reward_metric=reward_metric + ) return result, policy.policies[args.agent_id - 1] @@ -177,7 +221,8 @@ def watch( ) -> None: env = TicTacToeEnv(args.board_size, args.win_size) policy, optim = get_agents( - args, agent_learn=agent_learn, agent_opponent=agent_opponent) + args, agent_learn=agent_learn, agent_opponent=agent_opponent 
+ ) policy.eval() policy.policies[args.agent_id - 1].set_eps(args.eps_test) collector = Collector(policy, env, exploration_noise=True) diff --git a/test/multiagent/tic_tac_toe_env.py b/test/multiagent/tic_tac_toe_env.py index 2fc045afa..e39e1d13d 100644 --- a/test/multiagent/tic_tac_toe_env.py +++ b/test/multiagent/tic_tac_toe_env.py @@ -1,7 +1,8 @@ +from functools import partial +from typing import Optional, Tuple + import gym import numpy as np -from functools import partial -from typing import Tuple, Optional from tianshou.env import MultiAgentEnv @@ -16,7 +17,6 @@ class TicTacToeEnv(MultiAgentEnv): :param size: the size of the board (square board) :param win_size: how many units in a row is considered to win """ - def __init__(self, size: int = 3, win_size: int = 3): super().__init__() assert size > 0, f'board size should be positive, but got {size}' @@ -27,7 +27,8 @@ def __init__(self, size: int = 3, win_size: int = 3): f'be larger than board size {size}' self.convolve_kernel = np.ones(win_size) self.observation_space = gym.spaces.Box( - low=-1.0, high=1.0, shape=(size, size), dtype=np.float32) + low=-1.0, high=1.0, shape=(size, size), dtype=np.float32 + ) self.action_space = gym.spaces.Discrete(size * size) self.current_board = None self.current_agent = None @@ -45,11 +46,10 @@ def reset(self) -> dict: 'mask': self.current_board.flatten() == 0 } - def step(self, action: [int, np.ndarray] - ) -> Tuple[dict, np.ndarray, np.ndarray, dict]: + def step(self, action: [int, + np.ndarray]) -> Tuple[dict, np.ndarray, np.ndarray, dict]: if self.current_agent is None: - raise ValueError( - "calling step() of unreset environment is prohibited!") + raise ValueError("calling step() of unreset environment is prohibited!") assert 0 <= action < self.size * self.size assert self.current_board.item(action) == 0 _current_agent = self.current_agent @@ -97,18 +97,28 @@ def _test_win(self): rboard = self.current_board[row, :] cboard = self.current_board[:, col] current = self.current_board[row, col] - rightup = [self.current_board[row - i, col + i] - for i in range(1, self.size - col) if row - i >= 0] - leftdown = [self.current_board[row + i, col - i] - for i in range(1, col + 1) if row + i < self.size] + rightup = [ + self.current_board[row - i, col + i] for i in range(1, self.size - col) + if row - i >= 0 + ] + leftdown = [ + self.current_board[row + i, col - i] for i in range(1, col + 1) + if row + i < self.size + ] rdiag = np.array(leftdown[::-1] + [current] + rightup) - rightdown = [self.current_board[row + i, col + i] - for i in range(1, self.size - col) if row + i < self.size] - leftup = [self.current_board[row - i, col - i] - for i in range(1, col + 1) if row - i >= 0] + rightdown = [ + self.current_board[row + i, col + i] for i in range(1, self.size - col) + if row + i < self.size + ] + leftup = [ + self.current_board[row - i, col - i] for i in range(1, col + 1) + if row - i >= 0 + ] diag = np.array(leftup[::-1] + [current] + rightdown) - results = [np.convolve(k, self.convolve_kernel, mode='valid') - for k in (rboard, cboard, rdiag, diag)] + results = [ + np.convolve(k, self.convolve_kernel, mode='valid') + for k in (rboard, cboard, rdiag, diag) + ] return any([(np.abs(x) == self.win_size).any() for x in results]) def seed(self, seed: Optional[int] = None) -> int: @@ -128,6 +138,7 @@ def f(i, data): if number == -1: return 'O' if last_move else 'o' return '_' + for i, row in enumerate(self.current_board): print(pad + ' '.join(map(partial(f, i), enumerate(row))) + pad) print(top) diff --git 
a/test/throughput/test_batch_profile.py b/test/throughput/test_batch_profile.py index 9654f5838..fbd6fb89c 100644 --- a/test/throughput/test_batch_profile.py +++ b/test/throughput/test_batch_profile.py @@ -12,13 +12,20 @@ def data(): print("Initialising data...") np.random.seed(0) - batch_set = [Batch(a=[j for j in np.arange(1e3)], - b={'b1': (3.14, 3.14), 'b2': np.arange(1e3)}, - c=i) for i in np.arange(int(1e4))] + batch_set = [ + Batch( + a=[j for j in np.arange(1e3)], + b={ + 'b1': (3.14, 3.14), + 'b2': np.arange(1e3) + }, + c=i + ) for i in np.arange(int(1e4)) + ] batch0 = Batch( a=np.ones((3, 4), dtype=np.float64), b=Batch( - c=np.ones((1,), dtype=np.float64), + c=np.ones((1, ), dtype=np.float64), d=torch.ones((3, 3, 3), dtype=torch.float32), e=list(range(3)) ) @@ -26,19 +33,25 @@ def data(): batchs1 = [copy.deepcopy(batch0) for _ in np.arange(1e4)] batchs2 = [copy.deepcopy(batch0) for _ in np.arange(1e4)] batch_len = int(1e4) - batch3 = Batch(obs=[np.arange(20) for _ in np.arange(batch_len)], - reward=np.arange(batch_len)) - indexs = np.random.choice(batch_len, - size=batch_len // 10, replace=False) - slice_dict = {'obs': [np.arange(20) - for _ in np.arange(batch_len // 10)], - 'reward': np.arange(batch_len // 10)} - dict_set = [{'obs': np.arange(20), 'info': "this is info", 'reward': 0} - for _ in np.arange(1e2)] + batch3 = Batch( + obs=[np.arange(20) for _ in np.arange(batch_len)], reward=np.arange(batch_len) + ) + indexs = np.random.choice(batch_len, size=batch_len // 10, replace=False) + slice_dict = { + 'obs': [np.arange(20) for _ in np.arange(batch_len // 10)], + 'reward': np.arange(batch_len // 10) + } + dict_set = [ + { + 'obs': np.arange(20), + 'info': "this is info", + 'reward': 0 + } for _ in np.arange(1e2) + ] batch4 = Batch( a=np.ones((10000, 4), dtype=np.float64), b=Batch( - c=np.ones((1,), dtype=np.float64), + c=np.ones((1, ), dtype=np.float64), d=torch.ones((1000, 1000), dtype=torch.float32), e=np.arange(1000) ) diff --git a/test/throughput/test_buffer_profile.py b/test/throughput/test_buffer_profile.py index 40ce68889..57bd3f5b6 100644 --- a/test/throughput/test_buffer_profile.py +++ b/test/throughput/test_buffer_profile.py @@ -1,8 +1,10 @@ import sys -import gym import time -import tqdm + +import gym import numpy as np +import tqdm + from tianshou.data import Batch, ReplayBuffer, VectorReplayBuffer diff --git a/test/throughput/test_collector_profile.py b/test/throughput/test_collector_profile.py index 6242e694b..eced837a5 100644 --- a/test/throughput/test_collector_profile.py +++ b/test/throughput/test_collector_profile.py @@ -1,9 +1,9 @@ -import tqdm import numpy as np +import tqdm -from tianshou.policy import BasePolicy +from tianshou.data import AsyncCollector, Batch, Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.data import Batch, Collector, AsyncCollector, VectorReplayBuffer +from tianshou.policy import BasePolicy if __name__ == '__main__': from env import MyTestEnv @@ -40,8 +40,7 @@ def test_collector_nstep(): env_fns = [lambda x=i: MyTestEnv(size=x) for i in np.arange(2, 11)] dum = DummyVectorEnv(env_fns) num = len(env_fns) - c3 = Collector(policy, dum, - VectorReplayBuffer(total_size=40000, buffer_num=num)) + c3 = Collector(policy, dum, VectorReplayBuffer(total_size=40000, buffer_num=num)) for i in tqdm.trange(1, 400, desc="test step collector n_step"): c3.reset() result = c3.collect(n_step=i * len(env_fns)) @@ -53,8 +52,7 @@ def test_collector_nepisode(): env_fns = [lambda x=i: MyTestEnv(size=x) for 
i in np.arange(2, 11)] dum = DummyVectorEnv(env_fns) num = len(env_fns) - c3 = Collector(policy, dum, - VectorReplayBuffer(total_size=40000, buffer_num=num)) + c3 = Collector(policy, dum, VectorReplayBuffer(total_size=40000, buffer_num=num)) for i in tqdm.trange(1, 400, desc="test step collector n_episode"): c3.reset() result = c3.collect(n_episode=i) @@ -64,22 +62,22 @@ def test_collector_nepisode(): def test_asynccollector(): env_lens = [2, 3, 4, 5] - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) - for i in env_lens] + env_fns = [ + lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens + ] venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) policy = MyPolicy() bufsize = 300 c1 = AsyncCollector( - policy, venv, - VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4)) + policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4) + ) ptr = [0, 0, 0, 0] for n_episode in tqdm.trange(1, 100, desc="test async n_episode"): result = c1.collect(n_episode=n_episode) assert result["n/ep"] >= n_episode # check buffer data, obs and obs_next, env_id - for i, count in enumerate( - np.bincount(result["lens"], minlength=6)[2:]): + for i, count in enumerate(np.bincount(result["lens"], minlength=6)[2:]): env_len = i + 2 total = env_len * count indices = np.arange(ptr[i], ptr[i] + total) % bufsize @@ -88,8 +86,7 @@ def test_asynccollector(): buf = c1.buffer.buffers[i] assert np.all(buf.info.env_id[indices] == i) assert np.all(buf.obs[indices].reshape(count, env_len) == seq) - assert np.all(buf.obs_next[indices].reshape( - count, env_len) == seq + 1) + assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) # test async n_step, for now the buffer should be full of data for n_step in tqdm.trange(1, 150, desc="test async n_step"): result = c1.collect(n_step=n_step) diff --git a/tianshou/__init__.py b/tianshou/__init__.py index 98972ca9d..3430fb09b 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,5 +1,4 @@ -from tianshou import data, env, utils, policy, trainer, exploration - +from tianshou import data, env, exploration, policy, trainer, utils __version__ = "0.4.3" diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index 75e02a940..23a7e62ff 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -1,12 +1,17 @@ +"""isort:skip_file""" from tianshou.data.batch import Batch from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as from tianshou.data.utils.segtree import SegmentTree from tianshou.data.buffer.base import ReplayBuffer from tianshou.data.buffer.prio import PrioritizedReplayBuffer -from tianshou.data.buffer.manager import ReplayBufferManager -from tianshou.data.buffer.manager import PrioritizedReplayBufferManager -from tianshou.data.buffer.vecbuf import VectorReplayBuffer -from tianshou.data.buffer.vecbuf import PrioritizedVectorReplayBuffer +from tianshou.data.buffer.manager import ( + ReplayBufferManager, + PrioritizedReplayBufferManager, +) +from tianshou.data.buffer.vecbuf import ( + VectorReplayBuffer, + PrioritizedVectorReplayBuffer, +) from tianshou.data.buffer.cached import CachedReplayBuffer from tianshou.data.collector import Collector, AsyncCollector diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index f0ce76c2b..11acf1ecd 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -1,11 +1,12 @@ -import torch import pprint import warnings -import numpy as np +from collections.abc import Collection from copy import 
deepcopy from numbers import Number -from collections.abc import Collection -from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, Sequence +from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Union + +import numpy as np +import torch IndexType = Union[slice, int, np.ndarray, List[int]] @@ -18,8 +19,7 @@ def _is_batch_set(data: Any) -> bool: # "for e in data" will just unpack the first dimension, # but data.tolist() will flatten ndarray of objects # so do not use data.tolist() - return data.dtype == object and all( - isinstance(e, (dict, Batch)) for e in data) + return data.dtype == object and all(isinstance(e, (dict, Batch)) for e in data) elif isinstance(data, (list, tuple)): if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data): return True @@ -72,9 +72,9 @@ def _to_array_with_correct_type(v: Any) -> np.ndarray: return v -def _create_value( - inst: Any, size: int, stack: bool = True -) -> Union["Batch", np.ndarray, torch.Tensor]: +def _create_value(inst: Any, + size: int, + stack: bool = True) -> Union["Batch", np.ndarray, torch.Tensor]: """Create empty place-holders accroding to inst's shape. :param bool stack: whether to stack or to concatenate. E.g. if inst has shape of @@ -92,11 +92,10 @@ def _create_value( shape = (size, *inst.shape) if stack else (size, *inst.shape[1:]) if isinstance(inst, np.ndarray): target_type = inst.dtype.type if issubclass( - inst.dtype.type, (np.bool_, np.number)) else object + inst.dtype.type, (np.bool_, np.number) + ) else object return np.full( - shape, - fill_value=None if target_type == object else 0, - dtype=target_type + shape, fill_value=None if target_type == object else 0, dtype=target_type ) elif isinstance(inst, torch.Tensor): return torch.full(shape, fill_value=0, device=inst.device, dtype=inst.dtype) @@ -133,8 +132,10 @@ def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]: try: return torch.stack(v) # type: ignore except RuntimeError as e: - raise TypeError("Batch does not support non-stackable iterable" - " of torch.Tensor as unique value yet.") from e + raise TypeError( + "Batch does not support non-stackable iterable" + " of torch.Tensor as unique value yet." + ) from e if _is_batch_set(v): v = Batch(v) # list of dict / Batch else: @@ -143,8 +144,10 @@ def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]: try: v = _to_array_with_correct_type(v) except ValueError as e: - raise TypeError("Batch does not support heterogeneous list/" - "tuple of tensors as unique value yet.") from e + raise TypeError( + "Batch does not support heterogeneous list/" + "tuple of tensors as unique value yet." + ) from e return v @@ -172,12 +175,10 @@ class Batch: For a detailed description, please refer to :ref:`batch_concept`. """ - def __init__( self, - batch_dict: Optional[ - Union[dict, "Batch", Sequence[Union[dict, "Batch"]], np.ndarray] - ] = None, + batch_dict: Optional[Union[dict, "Batch", Sequence[Union[dict, "Batch"]], + np.ndarray]] = None, copy: bool = False, **kwargs: Any, ) -> None: @@ -248,11 +249,12 @@ def __setitem__(self, index: Union[str, IndexType], value: Any) -> None: self.__dict__[index] = value return if not isinstance(value, Batch): - raise ValueError("Batch does not supported tensor assignment. " - "Use a compatible Batch or dict instead.") - if not set(value.keys()).issubset(self.__dict__.keys()): raise ValueError( - "Creating keys is not supported by item assignment.") + "Batch does not supported tensor assignment. 
" + "Use a compatible Batch or dict instead." + ) + if not set(value.keys()).issubset(self.__dict__.keys()): + raise ValueError("Creating keys is not supported by item assignment.") for key, val in self.items(): try: self.__dict__[key][index] = value[key] @@ -368,9 +370,7 @@ def to_torch( v = v.type(dtype) self.__dict__[k] = v - def __cat( - self, batches: Sequence[Union[dict, "Batch"]], lens: List[int] - ) -> None: + def __cat(self, batches: Sequence[Union[dict, "Batch"]], lens: List[int]) -> None: """Private method for Batch.cat_. :: @@ -397,9 +397,11 @@ def __cat( sum_lens.append(sum_lens[-1] + x) # collect non-empty keys keys_map = [ - set(k for k, v in batch.items() - if not (isinstance(v, Batch) and v.is_empty())) - for batch in batches] + set( + k for k, v in batch.items() + if not (isinstance(v, Batch) and v.is_empty()) + ) for batch in batches + ] keys_shared = set.intersection(*keys_map) values_shared = [[e[k] for e in batches] for k in keys_shared] for k, v in zip(keys_shared, values_shared): @@ -433,8 +435,7 @@ def __cat( try: self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val except KeyError: - self.__dict__[k] = _create_value( - val, sum_lens[-1], stack=False) + self.__dict__[k] = _create_value(val, sum_lens[-1], stack=False) self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val def cat_(self, batches: Union["Batch", Sequence[Union[dict, "Batch"]]]) -> None: @@ -465,7 +466,8 @@ def cat_(self, batches: Union["Batch", Sequence[Union[dict, "Batch"]]]) -> None: raise ValueError( "Batch.cat_ meets an exception. Maybe because there is any " f"scalar in {batches} but Batch.cat_ does not support the " - "concatenation of scalar.") from e + "concatenation of scalar." + ) from e if not self.is_empty(): batches = [self] + list(batches) lens = [0 if self.is_empty(recurse=True) else len(self)] + lens @@ -506,8 +508,7 @@ def stack_(self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> None if not b.is_empty(): batch_list.append(b) else: - raise ValueError( - f"Cannot concatenate {type(b)} in Batch.stack_") + raise ValueError(f"Cannot concatenate {type(b)} in Batch.stack_") if len(batch_list) == 0: return batches = batch_list @@ -515,9 +516,11 @@ def stack_(self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> None batches = [self] + batches # collect non-empty keys keys_map = [ - set(k for k, v in batch.items() - if not (isinstance(v, Batch) and v.is_empty())) - for batch in batches] + set( + k for k, v in batch.items() + if not (isinstance(v, Batch) and v.is_empty()) + ) for batch in batches + ] keys_shared = set.intersection(*keys_map) values_shared = [[e[k] for e in batches] for k in keys_shared] for k, v in zip(keys_shared, values_shared): @@ -529,8 +532,10 @@ def stack_(self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> None try: self.__dict__[k] = _to_array_with_correct_type(np.stack(v, axis)) except ValueError: - warnings.warn("You are using tensors with different shape," - " fallback to dtype=object by default.") + warnings.warn( + "You are using tensors with different shape," + " fallback to dtype=object by default." + ) self.__dict__[k] = np.array(v, dtype=object) # all the keys keys_total = set.union(*[set(b.keys()) for b in batches]) @@ -543,7 +548,8 @@ def stack_(self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> None if keys_partial and axis != 0: raise ValueError( f"Stack of Batch with non-shared keys {keys_partial} is only " - f"supported with axis=0, but got axis={axis}!") + f"supported with axis=0, but got axis={axis}!" 
+ ) for k in keys_reserve: # reserved keys self.__dict__[k] = Batch() @@ -625,8 +631,10 @@ def empty_(self, index: Optional[Union[slice, IndexType]] = None) -> "Batch": elif isinstance(v, Batch): self.__dict__[k].empty_(index=index) else: # scalar value - warnings.warn("You are calling Batch.empty on a NumPy scalar, " - "which may cause undefined behaviors.") + warnings.warn( + "You are calling Batch.empty on a NumPy scalar, " + "which may cause undefined behaviors." + ) if _is_number(v): self.__dict__[k] = v.__class__(0) else: @@ -701,7 +709,8 @@ def is_empty(self, recurse: bool = False) -> bool: return False return all( False if not isinstance(x, Batch) else x.is_empty(recurse=True) - for x in self.values()) + for x in self.values() + ) @property def shape(self) -> List[int]: @@ -718,9 +727,10 @@ def shape(self) -> List[int]: return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \ else data_shape[0] - def split( - self, size: int, shuffle: bool = True, merge_last: bool = False - ) -> Iterator["Batch"]: + def split(self, + size: int, + shuffle: bool = True, + merge_last: bool = False) -> Iterator["Batch"]: """Split whole data into multiple small batches. :param int size: divide the data batch with the given size, but one diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py index d7fde0e07..381581ce9 100644 --- a/tianshou/data/buffer/base.py +++ b/tianshou/data/buffer/base.py @@ -1,10 +1,11 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + import h5py import numpy as np -from typing import Any, Dict, List, Tuple, Union, Optional from tianshou.data import Batch -from tianshou.data.utils.converter import to_hdf5, from_hdf5 -from tianshou.data.batch import _create_value, _alloc_by_keys_diff +from tianshou.data.batch import _alloc_by_keys_diff, _create_value +from tianshou.data.utils.converter import from_hdf5, to_hdf5 class ReplayBuffer: @@ -81,9 +82,8 @@ def __setstate__(self, state: Dict[str, Any]) -> None: def __setattr__(self, key: str, value: Any) -> None: """Set self.key = value.""" - assert ( - key not in self._reserved_keys - ), "key '{}' is reserved and cannot be assigned".format(key) + assert (key not in self._reserved_keys + ), "key '{}' is reserved and cannot be assigned".format(key) super().__setattr__(key, value) def save_hdf5(self, path: str) -> None: @@ -160,9 +160,8 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray: self._meta[to_indices] = buffer._meta[from_indices] return to_indices - def _add_index( - self, rew: Union[float, np.ndarray], done: bool - ) -> Tuple[int, Union[float, np.ndarray], int, int]: + def _add_index(self, rew: Union[float, np.ndarray], + done: bool) -> Tuple[int, Union[float, np.ndarray], int, int]: """Maintain the buffer's state after adding one data batch. Return (index_to_be_modified, episode_reward, episode_length, @@ -183,7 +182,9 @@ def _add_index( return ptr, self._ep_rew * 0.0, 0, self._ep_idx def add( - self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None + self, + batch: Batch, + buffer_ids: Optional[Union[np.ndarray, List[int]]] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into replay buffer. 
@@ -246,7 +247,8 @@ def sample_indices(self, batch_size: int) -> np.ndarray: return np.random.choice(self._size, batch_size) elif batch_size == 0: # construct current available indices return np.concatenate( - [np.arange(self._index, self._size), np.arange(self._index)] + [np.arange(self._index, self._size), + np.arange(self._index)] ) else: return np.array([], int) @@ -254,7 +256,8 @@ def sample_indices(self, batch_size: int) -> np.ndarray: if batch_size < 0: return np.array([], int) all_indices = prev_indices = np.concatenate( - [np.arange(self._index, self._size), np.arange(self._index)] + [np.arange(self._index, self._size), + np.arange(self._index)] ) for _ in range(self.stack_num - 2): prev_indices = self.prev(prev_indices) diff --git a/tianshou/data/buffer/cached.py b/tianshou/data/buffer/cached.py index 49bb33bcf..f62c41aba 100644 --- a/tianshou/data/buffer/cached.py +++ b/tianshou/data/buffer/cached.py @@ -1,5 +1,6 @@ +from typing import List, Optional, Tuple, Union + import numpy as np -from typing import List, Tuple, Union, Optional from tianshou.data import Batch, ReplayBuffer, ReplayBufferManager @@ -25,7 +26,6 @@ class CachedReplayBuffer(ReplayBufferManager): Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. """ - def __init__( self, main_buffer: ReplayBuffer, @@ -45,7 +45,9 @@ def __init__( self.cached_buffer_num = cached_buffer_num def add( - self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None + self, + batch: Batch, + buffer_ids: Optional[Union[np.ndarray, List[int]]] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into CachedReplayBuffer. diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py index dc93b6867..0009575c3 100644 --- a/tianshou/data/buffer/manager.py +++ b/tianshou/data/buffer/manager.py @@ -1,9 +1,10 @@ +from typing import List, Optional, Sequence, Tuple, Union + import numpy as np from numba import njit -from typing import List, Tuple, Union, Sequence, Optional -from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer -from tianshou.data.batch import _create_value, _alloc_by_keys_diff +from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer +from tianshou.data.batch import _alloc_by_keys_diff, _create_value class ReplayBufferManager(ReplayBuffer): @@ -19,7 +20,6 @@ class ReplayBufferManager(ReplayBuffer): Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. 
""" - def __init__(self, buffer_list: List[ReplayBuffer]) -> None: self.buffer_num = len(buffer_list) self.buffers = np.array(buffer_list, dtype=object) @@ -63,33 +63,45 @@ def set_batch(self, batch: Batch) -> None: self._set_batch_for_children() def unfinished_index(self) -> np.ndarray: - return np.concatenate([ - buf.unfinished_index() + offset - for offset, buf in zip(self._offset, self.buffers) - ]) + return np.concatenate( + [ + buf.unfinished_index() + offset + for offset, buf in zip(self._offset, self.buffers) + ] + ) def prev(self, index: Union[int, np.ndarray]) -> np.ndarray: if isinstance(index, (list, np.ndarray)): - return _prev_index(np.asarray(index), self._extend_offset, - self.done, self.last_index, self._lengths) + return _prev_index( + np.asarray(index), self._extend_offset, self.done, self.last_index, + self._lengths + ) else: - return _prev_index(np.array([index]), self._extend_offset, - self.done, self.last_index, self._lengths)[0] + return _prev_index( + np.array([index]), self._extend_offset, self.done, self.last_index, + self._lengths + )[0] def next(self, index: Union[int, np.ndarray]) -> np.ndarray: if isinstance(index, (list, np.ndarray)): - return _next_index(np.asarray(index), self._extend_offset, - self.done, self.last_index, self._lengths) + return _next_index( + np.asarray(index), self._extend_offset, self.done, self.last_index, + self._lengths + ) else: - return _next_index(np.array([index]), self._extend_offset, - self.done, self.last_index, self._lengths)[0] + return _next_index( + np.array([index]), self._extend_offset, self.done, self.last_index, + self._lengths + )[0] def update(self, buffer: ReplayBuffer) -> np.ndarray: """The ReplayBufferManager cannot be updated by any buffer.""" raise NotImplementedError def add( - self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None + self, + batch: Batch, + buffer_ids: Optional[Union[np.ndarray, List[int]]] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Add a batch of data into ReplayBufferManager. @@ -145,10 +157,12 @@ def sample_indices(self, batch_size: int) -> np.ndarray: if batch_size < 0: return np.array([], int) if self._sample_avail and self.stack_num > 1: - all_indices = np.concatenate([ - buf.sample_indices(0) + offset - for offset, buf in zip(self._offset, self.buffers) - ]) + all_indices = np.concatenate( + [ + buf.sample_indices(0) + offset + for offset, buf in zip(self._offset, self.buffers) + ] + ) if batch_size == 0: return all_indices else: @@ -163,10 +177,12 @@ def sample_indices(self, batch_size: int) -> np.ndarray: # avoid batch_size > 0 and sample_num == 0 -> get child's all data sample_num[sample_num == 0] = -1 - return np.concatenate([ - buf.sample_indices(bsz) + offset - for offset, buf, bsz in zip(self._offset, self.buffers, sample_num) - ]) + return np.concatenate( + [ + buf.sample_indices(bsz) + offset + for offset, buf, bsz in zip(self._offset, self.buffers, sample_num) + ] + ) class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManager): @@ -182,7 +198,6 @@ class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManage Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. 
""" - def __init__(self, buffer_list: Sequence[PrioritizedReplayBuffer]) -> None: ReplayBufferManager.__init__(self, buffer_list) # type: ignore kwargs = buffer_list[0].options diff --git a/tianshou/data/buffer/prio.py b/tianshou/data/buffer/prio.py index a4357d5ed..17353a514 100644 --- a/tianshou/data/buffer/prio.py +++ b/tianshou/data/buffer/prio.py @@ -1,8 +1,9 @@ -import torch +from typing import Any, List, Optional, Tuple, Union + import numpy as np -from typing import Any, List, Tuple, Union, Optional +import torch -from tianshou.data import Batch, SegmentTree, to_numpy, ReplayBuffer +from tianshou.data import Batch, ReplayBuffer, SegmentTree, to_numpy class PrioritizedReplayBuffer(ReplayBuffer): @@ -17,7 +18,6 @@ class PrioritizedReplayBuffer(ReplayBuffer): Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. """ - def __init__( self, size: int, @@ -39,7 +39,7 @@ def __init__( self._weight_norm = weight_norm def init_weight(self, index: Union[int, np.ndarray]) -> None: - self.weight[index] = self._max_prio ** self._alpha + self.weight[index] = self._max_prio**self._alpha def update(self, buffer: ReplayBuffer) -> np.ndarray: indices = super().update(buffer) @@ -47,7 +47,9 @@ def update(self, buffer: ReplayBuffer) -> np.ndarray: return indices def add( - self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None + self, + batch: Batch, + buffer_ids: Optional[Union[np.ndarray, List[int]]] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids) self.init_weight(ptr) @@ -70,7 +72,7 @@ def get_weight(self, index: Union[int, np.ndarray]) -> Union[float, np.ndarray]: # important sampling weight calculation # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta)) # simplified formula: (p_j/p_min)**(-beta) - return (self.weight[index] / self._min_prio) ** (-self._beta) + return (self.weight[index] / self._min_prio)**(-self._beta) def update_weight( self, index: np.ndarray, new_weight: Union[np.ndarray, torch.Tensor] @@ -81,7 +83,7 @@ def update_weight( :param np.ndarray new_weight: new priority weight you want to update. """ weight = np.abs(to_numpy(new_weight)) + self.__eps - self.weight[index] = weight ** self._alpha + self.weight[index] = weight**self._alpha self._max_prio = max(self._max_prio, weight.max()) self._min_prio = min(self._min_prio, weight.min()) diff --git a/tianshou/data/buffer/vecbuf.py b/tianshou/data/buffer/vecbuf.py index 374765bdc..3d5508fc9 100644 --- a/tianshou/data/buffer/vecbuf.py +++ b/tianshou/data/buffer/vecbuf.py @@ -1,8 +1,13 @@ -import numpy as np from typing import Any -from tianshou.data import ReplayBuffer, ReplayBufferManager -from tianshou.data import PrioritizedReplayBuffer, PrioritizedReplayBufferManager +import numpy as np + +from tianshou.data import ( + PrioritizedReplayBuffer, + PrioritizedReplayBufferManager, + ReplayBuffer, + ReplayBufferManager, +) class VectorReplayBuffer(ReplayBufferManager): @@ -22,7 +27,6 @@ class VectorReplayBuffer(ReplayBufferManager): Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. """ - def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: assert buffer_num > 0 size = int(np.ceil(total_size / buffer_num)) @@ -47,7 +51,6 @@ class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager): Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. 
""" - def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: assert buffer_num > 0 size = int(np.ceil(total_size / buffer_num)) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 192213aa3..1da6d08b1 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -1,21 +1,22 @@ -import gym import time -import torch import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import gym import numpy as np -from typing import Any, Dict, List, Union, Optional, Callable +import torch -from tianshou.policy import BasePolicy -from tianshou.data.batch import _alloc_by_keys_diff -from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.data import ( Batch, + CachedReplayBuffer, ReplayBuffer, ReplayBufferManager, VectorReplayBuffer, - CachedReplayBuffer, to_numpy, ) +from tianshou.data.batch import _alloc_by_keys_diff +from tianshou.env import BaseVectorEnv, DummyVectorEnv +from tianshou.policy import BasePolicy class Collector(object): @@ -46,7 +47,6 @@ class Collector(object): Please make sure the given environment has a time limitation if using n_episode collect option. """ - def __init__( self, policy: BasePolicy, @@ -97,8 +97,9 @@ def reset(self) -> None: """Reset all related variables in the collector.""" # use empty Batch for "state" so that self.data supports slicing # convert empty Batch to None when passing data to policy - self.data = Batch(obs={}, act={}, rew={}, done={}, - obs_next={}, info={}, policy={}) + self.data = Batch( + obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={} + ) self.reset_env() self.reset_buffer() self.reset_stat() @@ -115,8 +116,8 @@ def reset_env(self) -> None: """Reset all of the environments.""" obs = self.env.reset() if self.preprocess_fn: - obs = self.preprocess_fn( - obs=obs, env_id=np.arange(self.env_num)).get("obs", obs) + obs = self.preprocess_fn(obs=obs, + env_id=np.arange(self.env_num)).get("obs", obs) self.data.obs = obs def _reset_state(self, id: Union[int, List[int]]) -> None: @@ -184,8 +185,10 @@ def collect( ready_env_ids = np.arange(min(self.env_num, n_episode)) self.data = self.data[:min(self.env_num, n_episode)] else: - raise TypeError("Please specify at least one (either n_step or n_episode) " - "in AsyncCollector.collect().") + raise TypeError( + "Please specify at least one (either n_step or n_episode) " + "in AsyncCollector.collect()." 
+ ) start_time = time.time() @@ -203,7 +206,8 @@ def collect( # get the next action if random: self.data.update( - act=[self._action_space[i].sample() for i in ready_env_ids]) + act=[self._action_space[i].sample() for i in ready_env_ids] + ) else: if no_grad: with torch.no_grad(): # faster than retain_grad version @@ -226,18 +230,21 @@ def collect( action_remap = self.policy.map_action(self.data.act) # step in env obs_next, rew, done, info = self.env.step( - action_remap, ready_env_ids) # type: ignore + action_remap, ready_env_ids + ) # type: ignore self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) if self.preprocess_fn: - self.data.update(self.preprocess_fn( - obs_next=self.data.obs_next, - rew=self.data.rew, - done=self.data.done, - info=self.data.info, - policy=self.data.policy, - env_id=ready_env_ids, - )) + self.data.update( + self.preprocess_fn( + obs_next=self.data.obs_next, + rew=self.data.rew, + done=self.data.done, + info=self.data.info, + policy=self.data.policy, + env_id=ready_env_ids, + ) + ) if render: self.env.render() @@ -246,7 +253,8 @@ def collect( # add data into the buffer ptr, ep_rew, ep_len, ep_idx = self.buffer.add( - self.data, buffer_ids=ready_env_ids) + self.data, buffer_ids=ready_env_ids + ) # collect statistics step_count += len(ready_env_ids) @@ -263,7 +271,8 @@ def collect( obs_reset = self.env.reset(env_ind_global) if self.preprocess_fn: obs_reset = self.preprocess_fn( - obs=obs_reset, env_id=env_ind_global).get("obs", obs_reset) + obs=obs_reset, env_id=env_ind_global + ).get("obs", obs_reset) self.data.obs_next[env_ind_local] = obs_reset for i in env_ind_local: self._reset_state(i) @@ -290,13 +299,18 @@ def collect( self.collect_time += max(time.time() - start_time, 1e-9) if n_episode: - self.data = Batch(obs={}, act={}, rew={}, done={}, - obs_next={}, info={}, policy={}) + self.data = Batch( + obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={} + ) self.reset_env() if episode_count > 0: - rews, lens, idxs = list(map( - np.concatenate, [episode_rews, episode_lens, episode_start_indices])) + rews, lens, idxs = list( + map( + np.concatenate, + [episode_rews, episode_lens, episode_start_indices] + ) + ) else: rews, lens, idxs = np.array([]), np.array([], int), np.array([], int) @@ -315,7 +329,6 @@ class AsyncCollector(Collector): The arguments are exactly the same as :class:`~tianshou.data.Collector`, please refer to :class:`~tianshou.data.Collector` for more detailed explanation. """ - def __init__( self, policy: BasePolicy, @@ -377,8 +390,10 @@ def collect( elif n_episode is not None: assert n_episode > 0 else: - raise TypeError("Please specify at least one (either n_step or n_episode) " - "in AsyncCollector.collect().") + raise TypeError( + "Please specify at least one (either n_step or n_episode) " + "in AsyncCollector.collect()." 
+ ) warnings.warn("Using async setting may collect extra transitions into buffer.") ready_env_ids = self._ready_env_ids @@ -401,7 +416,8 @@ def collect( # get the next action if random: self.data.update( - act=[self._action_space[i].sample() for i in ready_env_ids]) + act=[self._action_space[i].sample() for i in ready_env_ids] + ) else: if no_grad: with torch.no_grad(): # faster than retain_grad version @@ -432,7 +448,8 @@ def collect( action_remap = self.policy.map_action(self.data.act) # step in env obs_next, rew, done, info = self.env.step( - action_remap, ready_env_ids) # type: ignore + action_remap, ready_env_ids + ) # type: ignore # change self.data here because ready_env_ids has changed ready_env_ids = np.array([i["env_id"] for i in info]) @@ -440,13 +457,15 @@ def collect( self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) if self.preprocess_fn: - self.data.update(self.preprocess_fn( - obs_next=self.data.obs_next, - rew=self.data.rew, - done=self.data.done, - info=self.data.info, - env_id=ready_env_ids, - )) + self.data.update( + self.preprocess_fn( + obs_next=self.data.obs_next, + rew=self.data.rew, + done=self.data.done, + info=self.data.info, + env_id=ready_env_ids, + ) + ) if render: self.env.render() @@ -455,7 +474,8 @@ def collect( # add data into the buffer ptr, ep_rew, ep_len, ep_idx = self.buffer.add( - self.data, buffer_ids=ready_env_ids) + self.data, buffer_ids=ready_env_ids + ) # collect statistics step_count += len(ready_env_ids) @@ -472,7 +492,8 @@ def collect( obs_reset = self.env.reset(env_ind_global) if self.preprocess_fn: obs_reset = self.preprocess_fn( - obs=obs_reset, env_id=env_ind_global).get("obs", obs_reset) + obs=obs_reset, env_id=env_ind_global + ).get("obs", obs_reset) self.data.obs_next[env_ind_local] = obs_reset for i in env_ind_local: self._reset_state(i) @@ -500,8 +521,12 @@ def collect( self.collect_time += max(time.time() - start_time, 1e-9) if episode_count > 0: - rews, lens, idxs = list(map( - np.concatenate, [episode_rews, episode_lens, episode_start_indices])) + rews, lens, idxs = list( + map( + np.concatenate, + [episode_rews, episode_lens, episode_start_indices] + ) + ) else: rews, lens, idxs = np.array([]), np.array([], int), np.array([], int) diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index 9f7d88a82..7a95169d2 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -1,12 +1,13 @@ -import h5py -import torch import pickle -import numpy as np from copy import deepcopy from numbers import Number -from typing import Any, Dict, Union, Optional +from typing import Any, Dict, Optional, Union + +import h5py +import numpy as np +import torch -from tianshou.data.batch import _parse_value, Batch +from tianshou.data.batch import Batch, _parse_value def to_numpy(x: Any) -> Union[Batch, np.ndarray]: @@ -79,7 +80,6 @@ def to_torch_as(x: Any, y: torch.Tensor) -> Union[Batch, torch.Tensor]: def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None: """Copy object into HDF5 group.""" - def to_hdf5_via_pickle(x: object, y: h5py.Group, key: str) -> None: """Pickle, convert to numpy array and write to HDF5 dataset.""" data = np.frombuffer(pickle.dumps(x), dtype=np.byte) diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py index 5bb6fcc06..0f9a5a5b2 100644 --- a/tianshou/data/utils/segtree.py +++ b/tianshou/data/utils/segtree.py @@ -1,6 +1,7 @@ +from typing import Optional, Union + import numpy as np from numba import njit -from typing import Union, 
Optional class SegmentTree: @@ -16,7 +17,6 @@ class SegmentTree: :param int size: the size of segment tree. """ - def __init__(self, size: int) -> None: bound = 1 while bound < size: @@ -29,9 +29,7 @@ def __init__(self, size: int) -> None: def __len__(self) -> int: return self._size - def __getitem__( - self, index: Union[int, np.ndarray] - ) -> Union[float, np.ndarray]: + def __getitem__(self, index: Union[int, np.ndarray]) -> Union[float, np.ndarray]: """Return self[index].""" return self._value[index + self._bound] @@ -64,9 +62,8 @@ def reduce(self, start: int = 0, end: Optional[int] = None) -> float: end += self._size return _reduce(self._value, start + self._bound - 1, end + self._bound) - def get_prefix_sum_idx( - self, value: Union[float, np.ndarray] - ) -> Union[int, np.ndarray]: + def get_prefix_sum_idx(self, value: Union[float, + np.ndarray]) -> Union[int, np.ndarray]: r"""Find the index with given value. Return the minimum index for each ``v`` in ``value`` so that @@ -122,9 +119,7 @@ def _reduce(tree: np.ndarray, start: int, end: int) -> float: @njit -def _get_prefix_sum_idx( - value: np.ndarray, bound: int, sums: np.ndarray -) -> np.ndarray: +def _get_prefix_sum_idx(value: np.ndarray, bound: int, sums: np.ndarray) -> np.ndarray: """Numba version (v0.51), 5x speed up with size=100000 and bsz=64. vectorized np: 0.0923 (numpy best) -> 0.024 (now) diff --git a/tianshou/env/__init__.py b/tianshou/env/__init__.py index a25e06f86..0e145ba29 100644 --- a/tianshou/env/__init__.py +++ b/tianshou/env/__init__.py @@ -1,6 +1,11 @@ -from tianshou.env.venvs import BaseVectorEnv, DummyVectorEnv, \ - SubprocVectorEnv, ShmemVectorEnv, RayVectorEnv from tianshou.env.maenv import MultiAgentEnv +from tianshou.env.venvs import ( + BaseVectorEnv, + DummyVectorEnv, + RayVectorEnv, + ShmemVectorEnv, + SubprocVectorEnv, +) __all__ = [ "BaseVectorEnv", diff --git a/tianshou/env/maenv.py b/tianshou/env/maenv.py index f6a454c3c..299051d0c 100644 --- a/tianshou/env/maenv.py +++ b/tianshou/env/maenv.py @@ -1,7 +1,8 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Tuple + import gym import numpy as np -from typing import Any, Dict, Tuple -from abc import ABC, abstractmethod class MultiAgentEnv(ABC, gym.Env): @@ -21,7 +22,6 @@ class MultiAgentEnv(ABC, gym.Env): The available action's mask is set to 1, otherwise it is set to 0. Further usage can be found at :ref:`marl_example`. """ - def __init__(self) -> None: pass diff --git a/tianshou/env/utils.py b/tianshou/env/utils.py index 5c873ce36..c0031ba02 100644 --- a/tianshou/env/utils.py +++ b/tianshou/env/utils.py @@ -1,10 +1,10 @@ -import cloudpickle from typing import Any +import cloudpickle + class CloudpickleWrapper(object): """A cloudpickle wrapper used in SubprocVectorEnv.""" - def __init__(self, data: Any) -> None: self.data = data diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index f9349ff24..3e37f56f3 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -1,10 +1,15 @@ +from typing import Any, Callable, List, Optional, Tuple, Union + import gym import numpy as np -from typing import Any, List, Tuple, Union, Optional, Callable +from tianshou.env.worker import ( + DummyEnvWorker, + EnvWorker, + RayEnvWorker, + SubprocEnvWorker, +) from tianshou.utils import RunningMeanStd -from tianshou.env.worker import EnvWorker, DummyEnvWorker, SubprocEnvWorker, \ - RayEnvWorker class BaseVectorEnv(gym.Env): @@ -64,7 +69,6 @@ def seed(self, seed): obs_rms should be passed in. Default to None. 
:param bool update_obs_rms: Whether to update obs_rms. Default to True. """ - def __init__( self, env_fns: List[Callable[[], gym.Env]], @@ -122,8 +126,9 @@ def __getattribute__(self, key: str) -> Any: ``action_space``. However, we would like the attribute lookup to go straight into the worker (in fact, this vector env's action_space is always None). """ - if key in ['metadata', 'reward_range', 'spec', 'action_space', - 'observation_space']: # reserved keys in gym.Env + if key in [ + 'metadata', 'reward_range', 'spec', 'action_space', 'observation_space' + ]: # reserved keys in gym.Env return self.__getattr__(key) else: return super().__getattribute__(key) @@ -137,7 +142,8 @@ def __getattr__(self, key: str) -> List[Any]: return [getattr(worker, key) for worker in self.workers] def _wrap_id( - self, id: Optional[Union[int, List[int], np.ndarray]] = None + self, + id: Optional[Union[int, List[int], np.ndarray]] = None ) -> Union[List[int], np.ndarray]: if id is None: return list(range(self.env_num)) @@ -230,7 +236,8 @@ def step( ready_conns: List[EnvWorker] = [] while not ready_conns: ready_conns = self.worker_class.wait( - self.waiting_conn, self.wait_num, self.timeout) + self.waiting_conn, self.wait_num, self.timeout + ) result = [] for conn in ready_conns: waiting_index = self.waiting_conn.index(conn) @@ -246,13 +253,15 @@ def step( except ValueError: # different len(obs) obs_stack = np.array(obs_list, dtype=object) rew_stack, done_stack, info_stack = map( - np.stack, [rew_list, done_list, info_list]) + np.stack, [rew_list, done_list, info_list] + ) if self.obs_rms and self.update_obs_rms: self.obs_rms.update(obs_stack) return self.normalize_obs(obs_stack), rew_stack, done_stack, info_stack def seed( - self, seed: Optional[Union[int, List[int]]] = None + self, + seed: Optional[Union[int, List[int]]] = None ) -> List[Optional[List[int]]]: """Set the seed for all environments. @@ -279,7 +288,8 @@ def render(self, **kwargs: Any) -> List[Any]: if self.is_async and len(self.waiting_id) > 0: raise RuntimeError( f"Environments {self.waiting_id} are still stepping, cannot " - "render them now.") + "render them now." + ) return [w.render(**kwargs) for w in self.workers] def close(self) -> None: @@ -310,7 +320,6 @@ class DummyVectorEnv(BaseVectorEnv): Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. """ - def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None: super().__init__(env_fns, DummyEnvWorker, **kwargs) @@ -322,7 +331,6 @@ class SubprocVectorEnv(BaseVectorEnv): Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. """ - def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None: def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: return SubprocEnvWorker(fn, share_memory=False) @@ -339,7 +347,6 @@ class ShmemVectorEnv(BaseVectorEnv): Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. """ - def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None: def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: return SubprocEnvWorker(fn, share_memory=True) @@ -356,7 +363,6 @@ class RayVectorEnv(BaseVectorEnv): Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. 
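    A construction sketch (requires ``ray`` to be installed; the CartPole factory
    and the number of workers are illustrative assumptions)::

        import gym
        import numpy as np
        import ray

        if not ray.is_initialized():
            ray.init()
        envs = RayVectorEnv([lambda: gym.make("CartPole-v0") for _ in range(8)])
        obs = envs.reset()
        act = np.array([sp.sample() for sp in envs.action_space])
        obs, rew, done, info = envs.step(act)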
""" - def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None: try: import ray diff --git a/tianshou/env/worker/__init__.py b/tianshou/env/worker/__init__.py index b9a20cf1e..1b1f37510 100644 --- a/tianshou/env/worker/__init__.py +++ b/tianshou/env/worker/__init__.py @@ -1,7 +1,7 @@ from tianshou.env.worker.base import EnvWorker from tianshou.env.worker.dummy import DummyEnvWorker -from tianshou.env.worker.subproc import SubprocEnvWorker from tianshou.env.worker.ray import RayEnvWorker +from tianshou.env.worker.subproc import SubprocEnvWorker __all__ = [ "EnvWorker", diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index dbf350a33..2a32ce961 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -1,12 +1,12 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable, List, Optional, Tuple + import gym import numpy as np -from abc import ABC, abstractmethod -from typing import Any, List, Tuple, Optional, Callable class EnvWorker(ABC): """An abstract worker for an environment.""" - def __init__(self, env_fn: Callable[[], gym.Env]) -> None: self._env_fn = env_fn self.is_closed = False @@ -43,7 +43,9 @@ def step( @staticmethod def wait( - workers: List["EnvWorker"], wait_num: int, timeout: Optional[float] = None + workers: List["EnvWorker"], + wait_num: int, + timeout: Optional[float] = None ) -> List["EnvWorker"]: """Given a list of workers, return those ready ones.""" raise NotImplementedError diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py index d0579d162..d964850f3 100644 --- a/tianshou/env/worker/dummy.py +++ b/tianshou/env/worker/dummy.py @@ -1,13 +1,13 @@ +from typing import Any, Callable, List, Optional + import gym import numpy as np -from typing import Any, List, Callable, Optional from tianshou.env.worker import EnvWorker class DummyEnvWorker(EnvWorker): """Dummy worker used in sequential vector environments.""" - def __init__(self, env_fn: Callable[[], gym.Env]) -> None: self.env = env_fn() super().__init__(env_fn) diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py index af7285b22..22f6bb46e 100644 --- a/tianshou/env/worker/ray.py +++ b/tianshou/env/worker/ray.py @@ -1,6 +1,7 @@ +from typing import Any, Callable, List, Optional, Tuple + import gym import numpy as np -from typing import Any, List, Callable, Tuple, Optional from tianshou.env.worker import EnvWorker @@ -12,7 +13,6 @@ class RayEnvWorker(EnvWorker): """Ray worker used in RayVectorEnv.""" - def __init__(self, env_fn: Callable[[], gym.Env]) -> None: self.env = ray.remote(gym.Wrapper).options(num_cpus=0).remote(env_fn()) super().__init__(env_fn) diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index 8b89b6c34..5f81ab452 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -1,15 +1,15 @@ -import gym -import time import ctypes -import numpy as np +import time from collections import OrderedDict -from multiprocessing.context import Process from multiprocessing import Array, Pipe, connection -from typing import Any, List, Tuple, Union, Callable, Optional +from multiprocessing.context import Process +from typing import Any, Callable, List, Optional, Tuple, Union -from tianshou.env.worker import EnvWorker -from tianshou.env.utils import CloudpickleWrapper +import gym +import numpy as np +from tianshou.env.utils import CloudpickleWrapper +from tianshou.env.worker import EnvWorker _NP_TO_CT = { np.bool_: ctypes.c_bool, @@ -28,7 +28,6 @@ class 
ShArray: """Wrapper of multiprocessing Array.""" - def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None: self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape))) # type: ignore self.dtype = dtype @@ -115,7 +114,6 @@ def _encode_obs( class SubprocEnvWorker(EnvWorker): """Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv.""" - def __init__( self, env_fn: Callable[[], gym.Env], share_memory: bool = False ) -> None: diff --git a/tianshou/exploration/random.py b/tianshou/exploration/random.py index a59085809..0c5fd3e4c 100644 --- a/tianshou/exploration/random.py +++ b/tianshou/exploration/random.py @@ -1,11 +1,11 @@ -import numpy as np from abc import ABC, abstractmethod -from typing import Union, Optional, Sequence +from typing import Optional, Sequence, Union + +import numpy as np class BaseNoise(ABC, object): """The action noise base class.""" - def __init__(self) -> None: super().__init__() @@ -21,7 +21,6 @@ def __call__(self, size: Sequence[int]) -> np.ndarray: class GaussianNoise(BaseNoise): """The vanilla gaussian process, for exploration in DDPG by default.""" - def __init__(self, mu: float = 0.0, sigma: float = 1.0) -> None: super().__init__() self._mu = mu @@ -48,7 +47,6 @@ class OUNoise(BaseNoise): vanilla gaussian process has little difference from using the Ornstein-Uhlenbeck process. """ - def __init__( self, mu: float = 0.0, @@ -74,7 +72,8 @@ def __call__(self, size: Sequence[int], mu: Optional[float] = None) -> np.ndarra Return an numpy array which size is equal to ``size``. """ if self._x is None or isinstance( - self._x, np.ndarray) and self._x.shape != size: + self._x, np.ndarray + ) and self._x.shape != size: self._x = 0.0 if mu is None: mu = self._mu diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 8a9c6478e..421898162 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -1,3 +1,4 @@ +"""isort:skip_file""" from tianshou.policy.base import BasePolicy from tianshou.policy.random import RandomPolicy from tianshou.policy.modelfree.dqn import DQNPolicy @@ -22,7 +23,6 @@ from tianshou.policy.modelbased.psrl import PSRLPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager - __all__ = [ "BasePolicy", "RandomPolicy", diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 6dc3bdda8..deed1f168 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -1,13 +1,14 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Optional, Tuple, Union + import gym -import torch import numpy as np -from torch import nn +import torch +from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete from numba import njit -from abc import ABC, abstractmethod -from typing import Any, Dict, Tuple, Union, Optional, Callable -from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary +from torch import nn -from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy +from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as class BasePolicy(ABC, nn.Module): @@ -56,7 +57,6 @@ class BasePolicy(ABC, nn.Module): torch.save(policy.state_dict(), "policy.pth") policy.load_state_dict(torch.load("policy.pth")) """ - def __init__( self, observation_space: Optional[gym.Space] = None, @@ -84,9 +84,8 @@ def set_agent_id(self, agent_id: int) -> None: """Set self.agent_id = agent_id, for MARL.""" self.agent_id = agent_id - def exploration_noise( - self, act: Union[np.ndarray, Batch], batch: Batch - ) -> Union[np.ndarray, Batch]: + def 
exploration_noise(self, act: Union[np.ndarray, Batch], + batch: Batch) -> Union[np.ndarray, Batch]: """Modify the action from policy.forward with exploration noise. :param act: a data batch or numpy.ndarray which is the action taken by @@ -216,9 +215,8 @@ def post_process_fn( if hasattr(buffer, "update_weight") and hasattr(batch, "weight"): buffer.update_weight(indices, batch.weight) - def update( - self, sample_size: int, buffer: Optional[ReplayBuffer], **kwargs: Any - ) -> Dict[str, Any]: + def update(self, sample_size: int, buffer: Optional[ReplayBuffer], + **kwargs: Any) -> Dict[str, Any]: """Update the policy network and replay buffer. It includes 3 function steps: process_fn, learn, and post_process_fn. In diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index f94aa1d39..35e665e14 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -1,7 +1,8 @@ -import torch +from typing import Any, Dict, Optional, Union + import numpy as np +import torch import torch.nn.functional as F -from typing import Any, Dict, Union, Optional from tianshou.data import Batch, to_torch from tianshou.policy import BasePolicy @@ -20,7 +21,6 @@ class ImitationPolicy(BasePolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. """ - def __init__( self, model: torch.nn.Module, diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 671b8b080..19c7b08da 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -1,11 +1,12 @@ import math -import torch +from typing import Any, Dict, Optional, Union + import numpy as np +import torch import torch.nn.functional as F -from typing import Any, Dict, Union, Optional -from tianshou.policy import DQNPolicy from tianshou.data import Batch, ReplayBuffer, to_torch +from tianshou.policy import DQNPolicy class DiscreteBCQPolicy(DQNPolicy): @@ -32,7 +33,6 @@ class DiscreteBCQPolicy(DQNPolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. """ - def __init__( self, model: torch.nn.Module, @@ -47,8 +47,10 @@ def __init__( reward_normalization: bool = False, **kwargs: Any, ) -> None: - super().__init__(model, optim, discount_factor, estimation_step, - target_update_freq, reward_normalization, **kwargs) + super().__init__( + model, optim, discount_factor, estimation_step, target_update_freq, + reward_normalization, **kwargs + ) assert target_update_freq > 0, "BCQ needs target network setting." 
self.imitator = imitator assert 0.0 <= unlikely_action_threshold < 1.0, \ @@ -93,8 +95,12 @@ def forward( # type: ignore mask = (ratio < self._log_tau).float() action = (q_value - np.inf * mask).argmax(dim=-1) - return Batch(act=action, state=state, q_value=q_value, - imitation_logits=imitation_logits) + return Batch( + act=action, + state=state, + q_value=q_value, + imitation_logits=imitation_logits + ) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: if self._iter % self._freq == 0: @@ -108,7 +114,8 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: act = to_torch(batch.act, dtype=torch.long, device=target_q.device) q_loss = F.smooth_l1_loss(current_q, target_q) i_loss = F.nll_loss( - F.log_softmax(imitation_logits, dim=-1), act) # type: ignore + F.log_softmax(imitation_logits, dim=-1), act + ) # type: ignore reg_loss = imitation_logits.pow(2).mean() loss = q_loss + i_loss + self._weight_reg * reg_loss diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index c6e1b50d0..cfd9e4cab 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -1,10 +1,11 @@ -import torch +from typing import Any, Dict + import numpy as np +import torch import torch.nn.functional as F -from typing import Any, Dict -from tianshou.policy import QRDQNPolicy from tianshou.data import Batch, to_torch +from tianshou.policy import QRDQNPolicy class DiscreteCQLPolicy(QRDQNPolicy): @@ -27,7 +28,6 @@ class DiscreteCQLPolicy(QRDQNPolicy): Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed explanation. """ - def __init__( self, model: torch.nn.Module, @@ -40,8 +40,10 @@ def __init__( min_q_weight: float = 10.0, **kwargs: Any, ) -> None: - super().__init__(model, optim, discount_factor, num_quantiles, estimation_step, - target_update_freq, reward_normalization, **kwargs) + super().__init__( + model, optim, discount_factor, num_quantiles, estimation_step, + target_update_freq, reward_normalization, **kwargs + ) self._min_q_weight = min_q_weight def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: @@ -55,9 +57,10 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: target_dist = batch.returns.unsqueeze(1) # calculate each element's difference between curr_dist and target_dist u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") - huber_loss = (u * ( - self.tau_hat - (target_dist - curr_dist).detach().le(0.).float() - ).abs()).sum(-1).mean(1) + huber_loss = ( + u * (self.tau_hat - + (target_dist - curr_dist).detach().le(0.).float()).abs() + ).sum(-1).mean(1) qr_loss = (huber_loss * weight).mean() # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index 05e4b2655..8506ba550 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -1,11 +1,12 @@ -import torch from copy import deepcopy from typing import Any, Dict + +import torch import torch.nn.functional as F from torch.distributions import Categorical -from tianshou.policy.modelfree.pg import PGPolicy from tianshou.data import Batch, to_torch, to_torch_as +from tianshou.policy.modelfree.pg import PGPolicy class DiscreteCRRPolicy(PGPolicy): @@ -33,7 +34,6 @@ class DiscreteCRRPolicy(PGPolicy): Please refer to :class:`~tianshou.policy.PGPolicy` for more detailed explanation. 
""" - def __init__( self, actor: torch.nn.Module, diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py index b438dbcbc..ce385aa28 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/policy/modelbased/psrl.py @@ -1,6 +1,7 @@ -import torch +from typing import Any, Dict, Optional, Tuple, Union + import numpy as np -from typing import Any, Dict, Tuple, Union, Optional +import torch from tianshou.data import Batch from tianshou.policy import BasePolicy @@ -18,7 +19,6 @@ class PSRLModel(object): :param float discount_factor: in [0, 1]. :param float epsilon: for precision control in value iteration. """ - def __init__( self, trans_count_prior: np.ndarray, @@ -70,14 +70,16 @@ def observe( sum_count = self.rew_count + rew_count self.rew_mean = (self.rew_mean * self.rew_count + rew_sum) / sum_count self.rew_square_sum += rew_square_sum - raw_std2 = self.rew_square_sum / sum_count - self.rew_mean ** 2 - self.rew_std = np.sqrt(1 / ( - sum_count / (raw_std2 + self.__eps) + 1 / self.rew_std_prior ** 2)) + raw_std2 = self.rew_square_sum / sum_count - self.rew_mean**2 + self.rew_std = np.sqrt( + 1 / (sum_count / (raw_std2 + self.__eps) + 1 / self.rew_std_prior**2) + ) self.rew_count = sum_count def sample_trans_prob(self) -> np.ndarray: sample_prob = torch.distributions.Dirichlet( - torch.from_numpy(self.trans_count)).sample().numpy() + torch.from_numpy(self.trans_count) + ).sample().numpy() return sample_prob def sample_reward(self) -> np.ndarray: @@ -156,7 +158,6 @@ class PSRLPolicy(BasePolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. """ - def __init__( self, trans_count_prior: np.ndarray, @@ -168,12 +169,10 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(**kwargs) - assert ( - 0.0 <= discount_factor <= 1.0 - ), "discount factor should be in [0, 1]" + assert (0.0 <= discount_factor <= 1.0), "discount factor should be in [0, 1]" self.model = PSRLModel( - trans_count_prior, rew_mean_prior, rew_std_prior, - discount_factor, epsilon) + trans_count_prior, rew_mean_prior, rew_std_prior, discount_factor, epsilon + ) self._add_done_loop = add_done_loop def forward( @@ -195,9 +194,7 @@ def forward( act = self.model(batch.obs, state=state, info=batch.info) return Batch(act=act) - def learn( - self, batch: Batch, *args: Any, **kwargs: Any - ) -> Dict[str, float]: + def learn(self, batch: Batch, *args: Any, **kwargs: Any) -> Dict[str, float]: n_s, n_a = self.model.n_state, self.model.n_action trans_count = np.zeros((n_s, n_a, n_s)) rew_sum = np.zeros((n_s, n_a)) @@ -207,7 +204,7 @@ def learn( obs, act, obs_next = b.obs, b.act, b.obs_next trans_count[obs, act, obs_next] += 1 rew_sum[obs, act] += b.rew - rew_square_sum[obs, act] += b.rew ** 2 + rew_square_sum[obs, act] += b.rew**2 rew_count[obs, act] += 1 if self._add_done_loop and b.done: # special operation for terminal states: add a self-loop diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 3e05ce0b6..1c59d2648 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -1,11 +1,12 @@ -import torch +from typing import Any, Dict, List, Optional, Type + import numpy as np -from torch import nn +import torch import torch.nn.functional as F -from typing import Any, Dict, List, Type, Optional +from torch import nn -from tianshou.policy import PGPolicy from tianshou.data import Batch, ReplayBuffer, to_torch_as +from tianshou.policy import PGPolicy class A2CPolicy(PGPolicy): @@ -47,7 +48,6 @@ class 
A2CPolicy(PGPolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. """ - def __init__( self, actor: torch.nn.Module, @@ -96,8 +96,14 @@ def _compute_returns( v_s = v_s * np.sqrt(self.ret_rms.var + self._eps) v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps) unnormalized_returns, advantages = self.compute_episodic_return( - batch, buffer, indices, v_s_, v_s, - gamma=self._gamma, gae_lambda=self._lambda) + batch, + buffer, + indices, + v_s_, + v_s, + gamma=self._gamma, + gae_lambda=self._lambda + ) if self._rew_norm: batch.returns = unnormalized_returns / \ np.sqrt(self.ret_rms.var + self._eps) @@ -130,7 +136,8 @@ def learn( # type: ignore if self._grad_norm: # clip large gradient nn.utils.clip_grad_norm_( set(self.actor.parameters()).union(self.critic.parameters()), - max_norm=self._grad_norm) + max_norm=self._grad_norm + ) self.optim.step() actor_losses.append(actor_loss.item()) vf_losses.append(vf_loss.item()) diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index a664096f5..17423031f 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -1,9 +1,10 @@ -import torch -import numpy as np from typing import Any, Dict, Optional -from tianshou.policy import DQNPolicy +import numpy as np +import torch + from tianshou.data import Batch, ReplayBuffer +from tianshou.policy import DQNPolicy class C51Policy(DQNPolicy): @@ -30,7 +31,6 @@ class C51Policy(DQNPolicy): Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed explanation. """ - def __init__( self, model: torch.nn.Module, @@ -44,8 +44,10 @@ def __init__( reward_normalization: bool = False, **kwargs: Any, ) -> None: - super().__init__(model, optim, discount_factor, estimation_step, - target_update_freq, reward_normalization, **kwargs) + super().__init__( + model, optim, discount_factor, estimation_step, target_update_freq, + reward_normalization, **kwargs + ) assert num_atoms > 1, "num_atoms should be greater than 1" assert v_min < v_max, "v_max should be larger than v_min" self._num_atoms = num_atoms @@ -77,9 +79,10 @@ def _target_dist(self, batch: Batch) -> torch.Tensor: target_support = batch.returns.clamp(self._v_min, self._v_max) # An amazing trick for calculating the projection gracefully. 
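        # Each projected atom i receives, from every target atom j, the mass
        # next_dist[j] scaled by max(0, 1 - |Tz_j - z_i| / delta_z), i.e. the
        # target return atoms are linearly interpolated onto the fixed support.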
# ref: https://github.com/ShangtongZhang/DeepRL - target_dist = (1 - (target_support.unsqueeze(1) - - self.support.view(1, -1, 1)).abs() / self.delta_z - ).clamp(0, 1) * next_dist.unsqueeze(1) + target_dist = ( + 1 - (target_support.unsqueeze(1) - self.support.view(1, -1, 1)).abs() / + self.delta_z + ).clamp(0, 1) * next_dist.unsqueeze(1) return target_dist.sum(-1) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: @@ -92,7 +95,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: curr_dist = self(batch).logits act = batch.act curr_dist = curr_dist[np.arange(len(act)), act, :] - cross_entropy = - (target_dist * torch.log(curr_dist + 1e-8)).sum(1) + cross_entropy = -(target_dist * torch.log(curr_dist + 1e-8)).sum(1) loss = (cross_entropy * weight).mean() # ref: https://github.com/Kaixhin/Rainbow/blob/master/agent.py L94-100 batch.weight = cross_entropy.detach() # prio-buffer diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index fc4a622cc..66ed4a780 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -1,12 +1,13 @@ -import torch import warnings -import numpy as np from copy import deepcopy -from typing import Any, Dict, Tuple, Union, Optional +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch -from tianshou.policy import BasePolicy -from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.data import Batch, ReplayBuffer +from tianshou.exploration import BaseNoise, GaussianNoise +from tianshou.policy import BasePolicy class DDPGPolicy(BasePolicy): @@ -37,7 +38,6 @@ class DDPGPolicy(BasePolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. """ - def __init__( self, actor: Optional[torch.nn.Module], @@ -53,8 +53,11 @@ def __init__( action_bound_method: str = "clip", **kwargs: Any, ) -> None: - super().__init__(action_scaling=action_scaling, - action_bound_method=action_bound_method, **kwargs) + super().__init__( + action_scaling=action_scaling, + action_bound_method=action_bound_method, + **kwargs + ) assert action_bound_method != "tanh", "tanh mapping is not supported" \ "in policies where action is used as input of critic , because" \ "raw action in range (-inf, inf) will cause instability in training" @@ -96,21 +99,21 @@ def sync_weight(self) -> None: for o, n in zip(self.critic_old.parameters(), self.critic.parameters()): o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) - def _target_q( - self, buffer: ReplayBuffer, indices: np.ndarray - ) -> torch.Tensor: + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: batch = buffer[indices] # batch.obs_next: s_{t+n} target_q = self.critic_old( batch.obs_next, - self(batch, model='actor_old', input='obs_next').act) + self(batch, model='actor_old', input='obs_next').act + ) return target_q def process_fn( self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray ) -> Batch: batch = self.compute_nstep_return( - batch, buffer, indices, self._target_q, - self._gamma, self._n_step, self._rew_norm) + batch, buffer, indices, self._target_q, self._gamma, self._n_step, + self._rew_norm + ) return batch def forward( @@ -156,8 +159,7 @@ def _mse_optimizer( def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: # critic - td, critic_loss = self._mse_optimizer( - batch, self.critic, self.critic_optim) + td, critic_loss = self._mse_optimizer(batch, self.critic, self.critic_optim) batch.weight = td # prio-buffer # 
actor action = self(batch).act @@ -171,9 +173,8 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: "loss/critic": critic_loss.item(), } - def exploration_noise( - self, act: Union[np.ndarray, Batch], batch: Batch - ) -> Union[np.ndarray, Batch]: + def exploration_noise(self, act: Union[np.ndarray, Batch], + batch: Batch) -> Union[np.ndarray, Batch]: if self._noise is None: return act if isinstance(act, np.ndarray): diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 18ac9fa12..9f2331711 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -1,10 +1,11 @@ -import torch +from typing import Any, Dict, Optional, Tuple, Union + import numpy as np +import torch from torch.distributions import Categorical -from typing import Any, Dict, Tuple, Union, Optional -from tianshou.policy import SACPolicy from tianshou.data import Batch, ReplayBuffer, to_torch +from tianshou.policy import SACPolicy class DiscreteSACPolicy(SACPolicy): @@ -33,7 +34,6 @@ class DiscreteSACPolicy(SACPolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. """ - def __init__( self, actor: torch.nn.Module, @@ -50,9 +50,21 @@ def __init__( **kwargs: Any, ) -> None: super().__init__( - actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - tau, gamma, alpha, reward_normalization, estimation_step, - action_scaling=False, action_bound_method="", **kwargs) + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + tau, + gamma, + alpha, + reward_normalization, + estimation_step, + action_scaling=False, + action_bound_method="", + **kwargs + ) self._alpha: Union[float, torch.Tensor] def forward( # type: ignore @@ -68,9 +80,7 @@ def forward( # type: ignore act = dist.sample() return Batch(logits=logits, act=act, state=h, dist=dist) - def _target_q( - self, buffer: ReplayBuffer, indices: np.ndarray - ) -> torch.Tensor: + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: batch = buffer[indices] # batch.obs: s_{t+n} obs_next_result = self(batch, input="obs_next") dist = obs_next_result.dist @@ -85,7 +95,8 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: weight = batch.pop("weight", 1.0) target_q = batch.returns.flatten() act = to_torch( - batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long) + batch.act[:, np.newaxis], device=target_q.device, dtype=torch.long + ) # critic 1 current_q1 = self.critic1(batch.obs).gather(1, act).flatten() @@ -139,7 +150,6 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: return result - def exploration_noise( - self, act: Union[np.ndarray, Batch], batch: Batch - ) -> Union[np.ndarray, Batch]: + def exploration_noise(self, act: Union[np.ndarray, Batch], + batch: Batch) -> Union[np.ndarray, Batch]: return act diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 03c39d171..ae573339d 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -1,10 +1,11 @@ -import torch -import numpy as np from copy import deepcopy -from typing import Any, Dict, Union, Optional +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch +from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as from tianshou.policy import BasePolicy -from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy class DQNPolicy(BasePolicy): @@ -31,7 +32,6 @@ class 
DQNPolicy(BasePolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. """ - def __init__( self, model: torch.nn.Module, @@ -96,8 +96,9 @@ def process_fn( :meth:`~tianshou.policy.BasePolicy.compute_nstep_return`. """ batch = self.compute_nstep_return( - batch, buffer, indices, self._target_q, - self._gamma, self._n_step, self._rew_norm) + batch, buffer, indices, self._target_q, self._gamma, self._n_step, + self._rew_norm + ) return batch def compute_q_value( @@ -173,9 +174,8 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: self._iter += 1 return {"loss": loss.item()} - def exploration_noise( - self, act: Union[np.ndarray, Batch], batch: Batch - ) -> Union[np.ndarray, Batch]: + def exploration_noise(self, act: Union[np.ndarray, Batch], + batch: Batch) -> Union[np.ndarray, Batch]: if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): bsz = len(act) rand_mask = np.random.rand(bsz) < self.eps diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index dc03365e9..df22acafc 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -1,10 +1,11 @@ -import torch +from typing import Any, Dict, Optional, Union + import numpy as np +import torch import torch.nn.functional as F -from typing import Any, Dict, Optional, Union +from tianshou.data import Batch, ReplayBuffer, to_numpy from tianshou.policy import DQNPolicy, QRDQNPolicy -from tianshou.data import Batch, to_numpy, ReplayBuffer from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction @@ -32,7 +33,6 @@ class FQFPolicy(QRDQNPolicy): Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed explanation. """ - def __init__( self, model: FullQuantileFunction, @@ -88,12 +88,14 @@ def forward( ) else: (logits, _, quantiles_tau), h = model( - obs_, propose_model=self.propose_model, fractions=fractions, - state=state, info=batch.info + obs_, + propose_model=self.propose_model, + fractions=fractions, + state=state, + info=batch.info ) - weighted_logits = ( - fractions.taus[:, 1:] - fractions.taus[:, :-1] - ).unsqueeze(1) * logits + weighted_logits = (fractions.taus[:, 1:] - + fractions.taus[:, :-1]).unsqueeze(1) * logits q = DQNPolicy.compute_q_value( self, weighted_logits.sum(2), getattr(obs, "mask", None) ) @@ -101,7 +103,10 @@ def forward( self.max_action_num = q.shape[1] act = to_numpy(q.max(dim=1)[1]) return Batch( - logits=logits, act=act, state=h, fractions=fractions, + logits=logits, + act=act, + state=h, + fractions=fractions, quantiles_tau=quantiles_tau ) @@ -117,9 +122,12 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: target_dist = batch.returns.unsqueeze(1) # calculate each element's difference between curr_dist and target_dist u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") - huber_loss = (u * ( - tau_hats.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.).float() - ).abs()).sum(-1).mean(1) + huber_loss = ( + u * ( + tau_hats.unsqueeze(2) - + (target_dist - curr_dist).detach().le(0.).float() + ).abs() + ).sum(-1).mean(1) quantile_loss = (huber_loss * weight).mean() # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 @@ -131,16 +139,18 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/fqf_agent.py L169 values_1 = sa_quantiles - sa_quantile_hats[:, :-1] - signs_1 = 
sa_quantiles > torch.cat([ - sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], dim=1) + signs_1 = sa_quantiles > torch.cat( + [sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], dim=1 + ) values_2 = sa_quantiles - sa_quantile_hats[:, 1:] - signs_2 = sa_quantiles < torch.cat([ - sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], dim=1) + signs_2 = sa_quantiles < torch.cat( + [sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], dim=1 + ) gradient_of_taus = ( - torch.where(signs_1, values_1, -values_1) - + torch.where(signs_2, values_2, -values_2) + torch.where(signs_1, values_1, -values_1) + + torch.where(signs_2, values_2, -values_2) ) fraction_loss = (gradient_of_taus * taus[:, 1:-1]).sum(1).mean() # calculate entropy loss diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index 4c54d3563..0647307cb 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -1,10 +1,11 @@ -import torch +from typing import Any, Dict, Optional, Union + import numpy as np +import torch import torch.nn.functional as F -from typing import Any, Dict, Optional, Union -from tianshou.policy import QRDQNPolicy from tianshou.data import Batch, to_numpy +from tianshou.policy import QRDQNPolicy class IQNPolicy(QRDQNPolicy): @@ -31,7 +32,6 @@ class IQNPolicy(QRDQNPolicy): Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed explanation. """ - def __init__( self, model: torch.nn.Module, @@ -45,8 +45,10 @@ def __init__( reward_normalization: bool = False, **kwargs: Any, ) -> None: - super().__init__(model, optim, discount_factor, sample_size, estimation_step, - target_update_freq, reward_normalization, **kwargs) + super().__init__( + model, optim, discount_factor, sample_size, estimation_step, + target_update_freq, reward_normalization, **kwargs + ) assert sample_size > 1, "sample_size should be greater than 1" assert online_sample_size > 1, "online_sample_size should be greater than 1" assert target_sample_size > 1, "target_sample_size should be greater than 1" @@ -71,9 +73,8 @@ def forward( model = getattr(self, model) obs = batch[input] obs_ = obs.obs if hasattr(obs, "obs") else obs - (logits, taus), h = model( - obs_, sample_size=sample_size, state=state, info=batch.info - ) + (logits, + taus), h = model(obs_, sample_size=sample_size, state=state, info=batch.info) q = self.compute_q_value(logits, getattr(obs, "mask", None)) if not hasattr(self, "max_action_num"): self.max_action_num = q.shape[1] @@ -92,9 +93,11 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: target_dist = batch.returns.unsqueeze(1) # calculate each element's difference between curr_dist and target_dist u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") - huber_loss = (u * ( - taus.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.).float() - ).abs()).sum(-1).mean(1) + huber_loss = ( + u * + (taus.unsqueeze(2) - + (target_dist - curr_dist).detach().le(0.).float()).abs() + ).sum(-1).mean(1) loss = (huber_loss * weight).mean() # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index 76da4bf0b..5e94baabc 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -1,13 +1,13 @@ -import torch +from typing import Any, Dict, List, Type + import numpy as np -from torch import nn +import torch import torch.nn.functional as F -from typing import Any, Dict, List, Type +from torch import nn from 
torch.distributions import kl_divergence - -from tianshou.policy import A2CPolicy from tianshou.data import Batch, ReplayBuffer +from tianshou.policy import A2CPolicy class NPGPolicy(A2CPolicy): @@ -45,7 +45,6 @@ class NPGPolicy(A2CPolicy): :param bool deterministic_eval: whether to use deterministic action instead of stochastic action sampled by the policy. Default to False. """ - def __init__( self, actor: torch.nn.Module, @@ -91,7 +90,8 @@ def learn( # type: ignore log_prob = log_prob.reshape(log_prob.size(0), -1).transpose(0, 1) actor_loss = -(log_prob * b.adv).mean() flat_grads = self._get_flat_grad( - actor_loss, self.actor, retain_graph=True).detach() + actor_loss, self.actor, retain_graph=True + ).detach() # direction: calculate natural gradient with torch.no_grad(): @@ -101,12 +101,14 @@ def learn( # type: ignore # calculate first order gradient of kl with respect to theta flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True) search_direction = -self._conjugate_gradients( - flat_grads, flat_kl_grad, nsteps=10) + flat_grads, flat_kl_grad, nsteps=10 + ) # step with torch.no_grad(): - flat_params = torch.cat([param.data.view(-1) - for param in self.actor.parameters()]) + flat_params = torch.cat( + [param.data.view(-1) for param in self.actor.parameters()] + ) new_flat_params = flat_params + self._step_size * search_direction self._set_from_flat_params(self.actor, new_flat_params) new_dist = self(b).dist @@ -138,8 +140,8 @@ def _MVP(self, v: torch.Tensor, flat_kl_grad: torch.Tensor) -> torch.Tensor: """Matrix vector product.""" # caculate second order gradient of kl with respect to theta kl_v = (flat_kl_grad * v).sum() - flat_kl_grad_grad = self._get_flat_grad( - kl_v, self.actor, retain_graph=True).detach() + flat_kl_grad_grad = self._get_flat_grad(kl_v, self.actor, + retain_graph=True).detach() return flat_kl_grad_grad + v * self._damping def _conjugate_gradients( @@ -179,6 +181,7 @@ def _set_from_flat_params( for param in model.parameters(): flat_size = int(np.prod(list(param.size()))) param.data.copy_( - flat_params[prev_ind:prev_ind + flat_size].view(param.size())) + flat_params[prev_ind:prev_ind + flat_size].view(param.size()) + ) prev_ind += flat_size return model diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 1eb4dde4a..d84217f6f 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -1,10 +1,11 @@ -import torch +from typing import Any, Dict, List, Optional, Type, Union + import numpy as np -from typing import Any, Dict, List, Type, Union, Optional +import torch +from tianshou.data import Batch, ReplayBuffer, to_torch_as from tianshou.policy import BasePolicy from tianshou.utils import RunningMeanStd -from tianshou.data import Batch, ReplayBuffer, to_torch_as class PGPolicy(BasePolicy): @@ -33,7 +34,6 @@ class PGPolicy(BasePolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. 
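    A construction sketch (``actor_net`` stands for any module mapping
    observations to a probability vector over discrete actions; it and the
    hyper-parameters below are assumptions, not defaults)::

        import torch
        from tianshou.policy import PGPolicy

        policy = PGPolicy(
            actor_net,
            torch.optim.Adam(actor_net.parameters(), lr=1e-3),
            dist_fn=torch.distributions.Categorical,
            discount_factor=0.99,
        )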
""" - def __init__( self, model: torch.nn.Module, @@ -47,8 +47,11 @@ def __init__( deterministic_eval: bool = False, **kwargs: Any, ) -> None: - super().__init__(action_scaling=action_scaling, - action_bound_method=action_bound_method, **kwargs) + super().__init__( + action_scaling=action_scaling, + action_bound_method=action_bound_method, + **kwargs + ) self.actor = model self.optim = optim self.lr_scheduler = lr_scheduler @@ -73,7 +76,8 @@ def process_fn( """ v_s_ = np.full(indices.shape, self.ret_rms.mean) unnormalized_returns, _ = self.compute_episodic_return( - batch, buffer, indices, v_s_=v_s_, gamma=self._gamma, gae_lambda=1.0) + batch, buffer, indices, v_s_=v_s_, gamma=self._gamma, gae_lambda=1.0 + ) if self._rew_norm: batch.returns = (unnormalized_returns - self.ret_rms.mean) / \ np.sqrt(self.ret_rms.var + self._eps) diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 824e19ad5..fe1d37b7a 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -1,10 +1,11 @@ -import torch +from typing import Any, Dict, List, Optional, Type + import numpy as np +import torch from torch import nn -from typing import Any, Dict, List, Type, Optional -from tianshou.policy import A2CPolicy from tianshou.data import Batch, ReplayBuffer, to_torch_as +from tianshou.policy import A2CPolicy class PPOPolicy(A2CPolicy): @@ -57,7 +58,6 @@ class PPOPolicy(A2CPolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. """ - def __init__( self, actor: torch.nn.Module, @@ -124,8 +124,8 @@ def learn( # type: ignore # calculate loss for critic value = self.critic(b.obs).flatten() if self._value_clip: - v_clip = b.v_s + (value - b.v_s).clamp( - -self._eps_clip, self._eps_clip) + v_clip = b.v_s + (value - + b.v_s).clamp(-self._eps_clip, self._eps_clip) vf1 = (b.returns - value).pow(2) vf2 = (b.returns - v_clip).pow(2) vf_loss = torch.max(vf1, vf2).mean() @@ -140,7 +140,8 @@ def learn( # type: ignore if self._grad_norm: # clip large gradient nn.utils.clip_grad_norm_( set(self.actor.parameters()).union(self.critic.parameters()), - max_norm=self._grad_norm) + max_norm=self._grad_norm + ) self.optim.step() clip_losses.append(clip_loss.item()) vf_losses.append(vf_loss.item()) diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 91d0f4bf0..1b3aaf1ef 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -1,11 +1,12 @@ -import torch import warnings +from typing import Any, Dict, Optional + import numpy as np +import torch import torch.nn.functional as F -from typing import Any, Dict, Optional -from tianshou.policy import DQNPolicy from tianshou.data import Batch, ReplayBuffer +from tianshou.policy import DQNPolicy class QRDQNPolicy(DQNPolicy): @@ -28,7 +29,6 @@ class QRDQNPolicy(DQNPolicy): Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed explanation. 
""" - def __init__( self, model: torch.nn.Module, @@ -40,13 +40,16 @@ def __init__( reward_normalization: bool = False, **kwargs: Any, ) -> None: - super().__init__(model, optim, discount_factor, estimation_step, - target_update_freq, reward_normalization, **kwargs) + super().__init__( + model, optim, discount_factor, estimation_step, target_update_freq, + reward_normalization, **kwargs + ) assert num_quantiles > 1, "num_quantiles should be greater than 1" self._num_quantiles = num_quantiles tau = torch.linspace(0, 1, self._num_quantiles + 1) self.tau_hat = torch.nn.Parameter( - ((tau[:-1] + tau[1:]) / 2).view(1, -1, 1), requires_grad=False) + ((tau[:-1] + tau[1:]) / 2).view(1, -1, 1), requires_grad=False + ) warnings.filterwarnings("ignore", message="Using a target size") def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: @@ -77,9 +80,10 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: target_dist = batch.returns.unsqueeze(1) # calculate each element's difference between curr_dist and target_dist u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") - huber_loss = (u * ( - self.tau_hat - (target_dist - curr_dist).detach().le(0.).float() - ).abs()).sum(-1).mean(1) + huber_loss = ( + u * (self.tau_hat - + (target_dist - curr_dist).detach().le(0.).float()).abs() + ).sum(-1).mean(1) loss = (huber_loss * weight).mean() # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/policy/modelfree/rainbow.py index 7aa4c682d..138c143e1 100644 --- a/tianshou/policy/modelfree/rainbow.py +++ b/tianshou/policy/modelfree/rainbow.py @@ -1,7 +1,7 @@ from typing import Any, Dict -from tianshou.policy import C51Policy from tianshou.data import Batch +from tianshou.policy import C51Policy from tianshou.utils.net.discrete import sample_noise @@ -29,7 +29,6 @@ class RainbowPolicy(C51Policy): Please refer to :class:`~tianshou.policy.C51Policy` for more detailed explanation. """ - def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: sample_noise(self.model) if self._target and sample_noise(self.model_old): diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 1858f27ed..729c3bf3c 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -1,12 +1,13 @@ -import torch -import numpy as np from copy import deepcopy +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch from torch.distributions import Independent, Normal -from typing import Any, Dict, Tuple, Union, Optional -from tianshou.policy import DDPGPolicy -from tianshou.exploration import BaseNoise from tianshou.data import Batch, ReplayBuffer, to_torch_as +from tianshou.exploration import BaseNoise +from tianshou.policy import DDPGPolicy class SACPolicy(DDPGPolicy): @@ -47,7 +48,6 @@ class SACPolicy(DDPGPolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. 
""" - def __init__( self, actor: torch.nn.Module, @@ -67,7 +67,8 @@ def __init__( ) -> None: super().__init__( None, None, None, None, tau, gamma, exploration_noise, - reward_normalization, estimation_step, **kwargs) + reward_normalization, estimation_step, **kwargs + ) self.actor, self.actor_optim = actor, actor_optim self.critic1, self.critic1_old = critic1, deepcopy(critic1) self.critic1_old.eval() @@ -123,15 +124,17 @@ def forward( # type: ignore # in appendix C to get some understanding of this equation. if self.action_scaling and self.action_space is not None: action_scale = to_torch_as( - (self.action_space.high - self.action_space.low) / 2.0, act) + (self.action_space.high - self.action_space.low) / 2.0, act + ) else: action_scale = 1.0 # type: ignore squashed_action = torch.tanh(act) log_prob = log_prob - torch.log( action_scale * (1 - squashed_action.pow(2)) + self.__eps ).sum(-1, keepdim=True) - return Batch(logits=logits, act=squashed_action, - state=h, dist=dist, log_prob=log_prob) + return Batch( + logits=logits, act=squashed_action, state=h, dist=dist, log_prob=log_prob + ) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: batch = buffer[indices] # batch.obs: s_{t+n} @@ -146,9 +149,11 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: # critic 1&2 td1, critic1_loss = self._mse_optimizer( - batch, self.critic1, self.critic1_optim) + batch, self.critic1, self.critic1_optim + ) td2, critic2_loss = self._mse_optimizer( - batch, self.critic2, self.critic2_optim) + batch, self.critic2, self.critic2_optim + ) batch.weight = (td1 + td2) / 2.0 # prio-buffer # actor @@ -156,8 +161,10 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: a = obs_result.act current_q1a = self.critic1(batch.obs, a).flatten() current_q2a = self.critic2(batch.obs, a).flatten() - actor_loss = (self._alpha * obs_result.log_prob.flatten() - - torch.min(current_q1a, current_q2a)).mean() + actor_loss = ( + self._alpha * obs_result.log_prob.flatten() - + torch.min(current_q1a, current_q2a) + ).mean() self.actor_optim.zero_grad() actor_loss.backward() self.actor_optim.step() diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 3ca785bcc..506c5cd03 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -1,11 +1,12 @@ -import torch -import numpy as np from copy import deepcopy from typing import Any, Dict, Optional -from tianshou.policy import DDPGPolicy +import numpy as np +import torch + from tianshou.data import Batch, ReplayBuffer from tianshou.exploration import BaseNoise, GaussianNoise +from tianshou.policy import DDPGPolicy class TD3Policy(DDPGPolicy): @@ -45,7 +46,6 @@ class TD3Policy(DDPGPolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. 
""" - def __init__( self, actor: torch.nn.Module, @@ -64,9 +64,10 @@ def __init__( estimation_step: int = 1, **kwargs: Any, ) -> None: - super().__init__(actor, actor_optim, None, None, tau, gamma, - exploration_noise, reward_normalization, - estimation_step, **kwargs) + super().__init__( + actor, actor_optim, None, None, tau, gamma, exploration_noise, + reward_normalization, estimation_step, **kwargs + ) self.critic1, self.critic1_old = critic1, deepcopy(critic1) self.critic1_old.eval() self.critic1_optim = critic1_optim @@ -103,16 +104,18 @@ def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: noise = noise.clamp(-self._noise_clip, self._noise_clip) a_ += noise target_q = torch.min( - self.critic1_old(batch.obs_next, a_), - self.critic2_old(batch.obs_next, a_)) + self.critic1_old(batch.obs_next, a_), self.critic2_old(batch.obs_next, a_) + ) return target_q def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: # critic 1&2 td1, critic1_loss = self._mse_optimizer( - batch, self.critic1, self.critic1_optim) + batch, self.critic1, self.critic1_optim + ) td2, critic2_loss = self._mse_optimizer( - batch, self.critic2, self.critic2_optim) + batch, self.critic2, self.critic2_optim + ) batch.weight = (td1 + td2) / 2.0 # prio-buffer # actor diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index b0ba63f11..92ecf2e7d 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -1,9 +1,9 @@ -import torch import warnings -import torch.nn.functional as F from typing import Any, Dict, List, Type -from torch.distributions import kl_divergence +import torch +import torch.nn.functional as F +from torch.distributions import kl_divergence from tianshou.data import Batch from tianshou.policy import NPGPolicy @@ -48,7 +48,6 @@ class TRPOPolicy(NPGPolicy): :param bool deterministic_eval: whether to use deterministic action instead of stochastic action sampled by the policy. Default to False. 
""" - def __init__( self, actor: torch.nn.Module, @@ -79,7 +78,8 @@ def learn( # type: ignore ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1) actor_loss = -(ratio * b.adv).mean() flat_grads = self._get_flat_grad( - actor_loss, self.actor, retain_graph=True).detach() + actor_loss, self.actor, retain_graph=True + ).detach() # direction: calculate natural gradient with torch.no_grad(): @@ -89,26 +89,30 @@ def learn( # type: ignore # calculate first order gradient of kl with respect to theta flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True) search_direction = -self._conjugate_gradients( - flat_grads, flat_kl_grad, nsteps=10) + flat_grads, flat_kl_grad, nsteps=10 + ) # stepsize: calculate max stepsize constrained by kl bound - step_size = torch.sqrt(2 * self._delta / ( - search_direction * self._MVP(search_direction, flat_kl_grad) - ).sum(0, keepdim=True)) + step_size = torch.sqrt( + 2 * self._delta / + (search_direction * + self._MVP(search_direction, flat_kl_grad)).sum(0, keepdim=True) + ) # stepsize: linesearch stepsize with torch.no_grad(): - flat_params = torch.cat([param.data.view(-1) - for param in self.actor.parameters()]) + flat_params = torch.cat( + [param.data.view(-1) for param in self.actor.parameters()] + ) for i in range(self._max_backtracks): new_flat_params = flat_params + step_size * search_direction self._set_from_flat_params(self.actor, new_flat_params) # calculate kl and if in bound, loss actually down new_dist = self(b).dist - new_dratio = ( - new_dist.log_prob(b.act) - b.logp_old).exp().float() - new_dratio = new_dratio.reshape( - new_dratio.size(0), -1).transpose(0, 1) + new_dratio = (new_dist.log_prob(b.act) - + b.logp_old).exp().float() + new_dratio = new_dratio.reshape(new_dratio.size(0), + -1).transpose(0, 1) new_actor_loss = -(new_dratio * b.adv).mean() kl = kl_divergence(old_dist, new_dist).mean() @@ -121,8 +125,10 @@ def learn( # type: ignore else: self._set_from_flat_params(self.actor, new_flat_params) step_size = torch.tensor([0.0]) - warnings.warn("Line search failed! It seems hyperparamters" - " are poor and need to be changed.") + warnings.warn( + "Line search failed! It seems hyperparamters" + " are poor and need to be changed." + ) # optimize citirc for _ in range(self._optim_critic_iters): diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index e7b50f07f..bd6272f80 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -1,8 +1,9 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + import numpy as np -from typing import Any, Dict, List, Tuple, Union, Optional -from tianshou.policy import BasePolicy from tianshou.data import Batch, ReplayBuffer +from tianshou.policy import BasePolicy class MultiAgentPolicyManager(BasePolicy): @@ -14,7 +15,6 @@ class MultiAgentPolicyManager(BasePolicy): and "learn": it splits the data and feeds them to each policy. A figure in :ref:`marl_example` can help you better understand this procedure. 
""" - def __init__(self, policies: List[BasePolicy], **kwargs: Any) -> None: super().__init__(**kwargs) self.policies = policies @@ -54,21 +54,22 @@ def process_fn( tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1] buffer._meta.rew = save_rew[:, policy.agent_id - 1] results[f"agent_{policy.agent_id}"] = policy.process_fn( - tmp_batch, buffer, tmp_indices) + tmp_batch, buffer, tmp_indices + ) if has_rew: # restore from save_rew buffer._meta.rew = save_rew return Batch(results) - def exploration_noise( - self, act: Union[np.ndarray, Batch], batch: Batch - ) -> Union[np.ndarray, Batch]: + def exploration_noise(self, act: Union[np.ndarray, Batch], + batch: Batch) -> Union[np.ndarray, Batch]: """Add exploration noise from sub-policy onto act.""" for policy in self.policies: agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0] if len(agent_index) == 0: continue act[agent_index] = policy.exploration_noise( - act[agent_index], batch[agent_index]) + act[agent_index], batch[agent_index] + ) return act def forward( # type: ignore @@ -100,8 +101,8 @@ def forward( # type: ignore "agent_n": xxx} } """ - results: List[Tuple[bool, np.ndarray, Batch, - Union[np.ndarray, Batch], Batch]] = [] + results: List[Tuple[bool, np.ndarray, Batch, Union[np.ndarray, Batch], + Batch]] = [] for policy in self.policies: # This part of code is difficult to understand. # Let's follow an example with two agents @@ -119,20 +120,28 @@ def forward( # type: ignore if isinstance(tmp_batch.rew, np.ndarray): # reward can be empty Batch (after initial reset) or nparray. tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1] - out = policy(batch=tmp_batch, state=None if state is None - else state["agent_" + str(policy.agent_id)], - **kwargs) + out = policy( + batch=tmp_batch, + state=None if state is None else state["agent_" + + str(policy.agent_id)], + **kwargs + ) act = out.act each_state = out.state \ if (hasattr(out, "state") and out.state is not None) \ else Batch() results.append((True, agent_index, out, act, each_state)) - holder = Batch.cat([{"act": act} for - (has_data, agent_index, out, act, each_state) - in results if has_data]) + holder = Batch.cat( + [ + { + "act": act + } for (has_data, agent_index, out, act, each_state) in results + if has_data + ] + ) state_dict, out_dict = {}, {} - for policy, (has_data, agent_index, out, act, state) in zip( - self.policies, results): + for policy, (has_data, agent_index, out, act, + state) in zip(self.policies, results): if has_data: holder.act[agent_index] = act state_dict["agent_" + str(policy.agent_id)] = state @@ -141,9 +150,8 @@ def forward( # type: ignore holder["state"] = state_dict return holder - def learn( - self, batch: Batch, **kwargs: Any - ) -> Dict[str, Union[float, List[float]]]: + def learn(self, batch: Batch, + **kwargs: Any) -> Dict[str, Union[float, List[float]]]: """Dispatch the data to all policies for learning. :return: a dict with the following contents: diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py index 9c7f132af..863f93229 100644 --- a/tianshou/policy/random.py +++ b/tianshou/policy/random.py @@ -1,5 +1,6 @@ +from typing import Any, Dict, Optional, Union + import numpy as np -from typing import Any, Dict, Union, Optional from tianshou.data import Batch from tianshou.policy import BasePolicy @@ -10,7 +11,6 @@ class RandomPolicy(BasePolicy): It randomly chooses an action from the legal action. 
""" - def forward( self, batch: Batch, diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index 9fa88fbc3..f3baf8499 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -1,3 +1,4 @@ +"""isort:skip_file""" from tianshou.trainer.utils import test_episode, gather_info from tianshou.trainer.onpolicy import onpolicy_trainer from tianshou.trainer.offpolicy import offpolicy_trainer diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index a2bcf051a..eb7f9a371 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -1,13 +1,14 @@ import time -import tqdm -import numpy as np from collections import defaultdict -from typing import Dict, Union, Callable, Optional +from typing import Callable, Dict, Optional, Union + +import numpy as np +import tqdm -from tianshou.policy import BasePolicy -from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger from tianshou.data import Collector, ReplayBuffer -from tianshou.trainer import test_episode, gather_info +from tianshou.policy import BasePolicy +from tianshou.trainer import gather_info, test_episode +from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config def offline_trainer( @@ -74,16 +75,16 @@ def offline_trainer( start_time = time.time() test_collector.reset_stat() - test_result = test_episode(policy, test_collector, test_fn, start_epoch, - episode_per_test, logger, gradient_step, reward_metric) + test_result = test_episode( + policy, test_collector, test_fn, start_epoch, episode_per_test, logger, + gradient_step, reward_metric + ) best_epoch = start_epoch best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] for epoch in range(1 + start_epoch, 1 + max_epoch): policy.train() - with tqdm.trange( - update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config - ) as t: + with tqdm.trange(update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config) as t: for i in t: gradient_step += 1 losses = policy.update(batch_size, buffer) @@ -96,8 +97,9 @@ def offline_trainer( t.set_postfix(**data) # test test_result = test_episode( - policy, test_collector, test_fn, epoch, episode_per_test, - logger, gradient_step, reward_metric) + policy, test_collector, test_fn, epoch, episode_per_test, logger, + gradient_step, reward_metric + ) rew, rew_std = test_result["rew"], test_result["rew_std"] if best_epoch < 0 or best_reward < rew: best_epoch, best_reward, best_reward_std = epoch, rew, rew_std @@ -105,8 +107,10 @@ def offline_trainer( save_fn(policy) logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn) if verbose: - print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" - f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}") + print( + f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + ) if stop_fn and stop_fn(best_reward): break return gather_info(start_time, None, test_collector, best_reward, best_reward_std) diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 2a576ccea..9dc6783eb 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -1,13 +1,14 @@ import time -import tqdm -import numpy as np from collections import defaultdict -from typing import Dict, Union, Callable, Optional +from typing import Callable, Dict, Optional, Union + +import numpy as np +import tqdm from tianshou.data import Collector from tianshou.policy import BasePolicy -from tianshou.trainer 
import test_episode, gather_info -from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger +from tianshou.trainer import gather_info, test_episode +from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config def offpolicy_trainer( @@ -91,8 +92,10 @@ def offpolicy_trainer( train_collector.reset_stat() test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy - test_result = test_episode(policy, test_collector, test_fn, start_epoch, - episode_per_test, logger, env_step, reward_metric) + test_result = test_episode( + policy, test_collector, test_fn, start_epoch, episode_per_test, logger, + env_step, reward_metric + ) best_epoch = start_epoch best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] @@ -123,17 +126,20 @@ def offpolicy_trainer( if result["n/ep"] > 0: if test_in_train and stop_fn and stop_fn(result["rew"]): test_result = test_episode( - policy, test_collector, test_fn, - epoch, episode_per_test, logger, env_step) + policy, test_collector, test_fn, epoch, episode_per_test, + logger, env_step + ) if stop_fn(test_result["rew"]): if save_fn: save_fn(policy) logger.save_data( - epoch, env_step, gradient_step, save_checkpoint_fn) + epoch, env_step, gradient_step, save_checkpoint_fn + ) t.set_postfix(**data) return gather_info( start_time, train_collector, test_collector, - test_result["rew"], test_result["rew_std"]) + test_result["rew"], test_result["rew_std"] + ) else: policy.train() for i in range(round(update_per_step * result["n/st"])): @@ -148,8 +154,10 @@ def offpolicy_trainer( if t.n <= t.total: t.update() # test - test_result = test_episode(policy, test_collector, test_fn, epoch, - episode_per_test, logger, env_step, reward_metric) + test_result = test_episode( + policy, test_collector, test_fn, epoch, episode_per_test, logger, env_step, + reward_metric + ) rew, rew_std = test_result["rew"], test_result["rew_std"] if best_epoch < 0 or best_reward < rew: best_epoch, best_reward, best_reward_std = epoch, rew, rew_std @@ -157,9 +165,12 @@ def offpolicy_trainer( save_fn(policy) logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) if verbose: - print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" - f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}") + print( + f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + ) if stop_fn and stop_fn(best_reward): break - return gather_info(start_time, train_collector, test_collector, - best_reward, best_reward_std) + return gather_info( + start_time, train_collector, test_collector, best_reward, best_reward_std + ) diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 1696421e8..fd845c506 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -1,13 +1,14 @@ import time -import tqdm -import numpy as np from collections import defaultdict -from typing import Dict, Union, Callable, Optional +from typing import Callable, Dict, Optional, Union + +import numpy as np +import tqdm from tianshou.data import Collector from tianshou.policy import BasePolicy -from tianshou.trainer import test_episode, gather_info -from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger +from tianshou.trainer import gather_info, test_episode +from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config def onpolicy_trainer( @@ -97,8 +98,10 @@ def onpolicy_trainer( 
train_collector.reset_stat() test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy - test_result = test_episode(policy, test_collector, test_fn, start_epoch, - episode_per_test, logger, env_step, reward_metric) + test_result = test_episode( + policy, test_collector, test_fn, start_epoch, episode_per_test, logger, + env_step, reward_metric + ) best_epoch = start_epoch best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] @@ -111,8 +114,9 @@ def onpolicy_trainer( while t.n < t.total: if train_fn: train_fn(epoch, env_step) - result = train_collector.collect(n_step=step_per_collect, - n_episode=episode_per_collect) + result = train_collector.collect( + n_step=step_per_collect, n_episode=episode_per_collect + ) if result["n/ep"] > 0 and reward_metric: result["rews"] = reward_metric(result["rews"]) env_step += int(result["n/st"]) @@ -130,25 +134,32 @@ def onpolicy_trainer( if result["n/ep"] > 0: if test_in_train and stop_fn and stop_fn(result["rew"]): test_result = test_episode( - policy, test_collector, test_fn, - epoch, episode_per_test, logger, env_step) + policy, test_collector, test_fn, epoch, episode_per_test, + logger, env_step + ) if stop_fn(test_result["rew"]): if save_fn: save_fn(policy) logger.save_data( - epoch, env_step, gradient_step, save_checkpoint_fn) + epoch, env_step, gradient_step, save_checkpoint_fn + ) t.set_postfix(**data) return gather_info( start_time, train_collector, test_collector, - test_result["rew"], test_result["rew_std"]) + test_result["rew"], test_result["rew_std"] + ) else: policy.train() losses = policy.update( - 0, train_collector.buffer, - batch_size=batch_size, repeat=repeat_per_collect) + 0, + train_collector.buffer, + batch_size=batch_size, + repeat=repeat_per_collect + ) train_collector.reset_buffer(keep_statistics=True) - step = max([1] + [ - len(v) for v in losses.values() if isinstance(v, list)]) + step = max( + [1] + [len(v) for v in losses.values() if isinstance(v, list)] + ) gradient_step += step for k in losses.keys(): stat[k].add(losses[k]) @@ -159,8 +170,10 @@ def onpolicy_trainer( if t.n <= t.total: t.update() # test - test_result = test_episode(policy, test_collector, test_fn, epoch, - episode_per_test, logger, env_step, reward_metric) + test_result = test_episode( + policy, test_collector, test_fn, epoch, episode_per_test, logger, env_step, + reward_metric + ) rew, rew_std = test_result["rew"], test_result["rew_std"] if best_epoch < 0 or best_reward < rew: best_epoch, best_reward, best_reward_std = epoch, rew, rew_std @@ -168,9 +181,12 @@ def onpolicy_trainer( save_fn(policy) logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) if verbose: - print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" - f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}") + print( + f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + ) if stop_fn and stop_fn(best_reward): break - return gather_info(start_time, train_collector, test_collector, - best_reward, best_reward_std) + return gather_info( + start_time, train_collector, test_collector, best_reward, best_reward_std + ) diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 2e729feeb..a39a12fff 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -1,6 +1,7 @@ import time +from typing import Any, Callable, Dict, Optional, Union + import numpy as np -from typing import Any, Dict, 
Union, Callable, Optional from tianshou.data import Collector from tianshou.policy import BasePolicy @@ -71,11 +72,13 @@ def gather_info( if train_c is not None: model_time -= train_c.collect_time train_speed = train_c.collect_step / (duration - test_c.collect_time) - result.update({ - "train_step": train_c.collect_step, - "train_episode": train_c.collect_episode, - "train_time/collector": f"{train_c.collect_time:.2f}s", - "train_time/model": f"{model_time:.2f}s", - "train_speed": f"{train_speed:.2f} step/s", - }) + result.update( + { + "train_step": train_c.collect_step, + "train_episode": train_c.collect_episode, + "train_time/collector": f"{train_c.collect_time:.2f}s", + "train_time/model": f"{model_time:.2f}s", + "train_speed": f"{train_speed:.2f} step/s", + } + ) return result diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py index 4ad73481c..64ae88328 100644 --- a/tianshou/utils/__init__.py +++ b/tianshou/utils/__init__.py @@ -1,17 +1,10 @@ from tianshou.utils.config import tqdm_config -from tianshou.utils.statistics import MovAvg, RunningMeanStd from tianshou.utils.logger.base import BaseLogger, LazyLogger -from tianshou.utils.logger.tensorboard import TensorboardLogger, BasicLogger +from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger from tianshou.utils.logger.wandb import WandBLogger - +from tianshou.utils.statistics import MovAvg, RunningMeanStd __all__ = [ - "MovAvg", - "RunningMeanStd", - "tqdm_config", - "BaseLogger", - "TensorboardLogger", - "BasicLogger", - "LazyLogger", - "WandBLogger" + "MovAvg", "RunningMeanStd", "tqdm_config", "BaseLogger", "TensorboardLogger", + "BasicLogger", "LazyLogger", "WandBLogger" ] diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py index c1ffe760d..fcb3c4b93 100644 --- a/tianshou/utils/logger/base.py +++ b/tianshou/utils/logger/base.py @@ -1,7 +1,8 @@ -import numpy as np -from numbers import Number from abc import ABC, abstractmethod -from typing import Dict, Tuple, Union, Callable, Optional +from numbers import Number +from typing import Callable, Dict, Optional, Tuple, Union + +import numpy as np LOG_DATA_TYPE = Dict[str, Union[int, Number, np.number, np.ndarray]] @@ -15,7 +16,6 @@ class BaseLogger(ABC): :param int test_interval: the log interval in log_test_data(). Default to 1. :param int update_interval: the log interval in log_update_data(). Default to 1000. """ - def __init__( self, train_interval: int = 1000, @@ -132,7 +132,6 @@ def restore_data(self) -> Tuple[int, int, int]: class LazyLogger(BaseLogger): """A logger that does nothing. Used as the placeholder in trainer.""" - def __init__(self) -> None: super().__init__() diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py index c65576b41..cee450168 100644 --- a/tianshou/utils/logger/tensorboard.py +++ b/tianshou/utils/logger/tensorboard.py @@ -1,10 +1,10 @@ import warnings -from typing import Any, Tuple, Callable, Optional +from typing import Any, Callable, Optional, Tuple -from torch.utils.tensorboard import SummaryWriter from tensorboard.backend.event_processing import event_accumulator +from torch.utils.tensorboard import SummaryWriter -from tianshou.utils.logger.base import BaseLogger, LOG_DATA_TYPE +from tianshou.utils.logger.base import LOG_DATA_TYPE, BaseLogger class TensorboardLogger(BaseLogger): @@ -18,7 +18,6 @@ class TensorboardLogger(BaseLogger): :param int save_interval: the save interval in save_data(). Default to 1 (save at the end of each epoch). 
""" - def __init__( self, writer: SummaryWriter, @@ -48,8 +47,10 @@ def save_data( save_checkpoint_fn(epoch, env_step, gradient_step) self.write("save/epoch", epoch, {"save/epoch": epoch}) self.write("save/env_step", env_step, {"save/env_step": env_step}) - self.write("save/gradient_step", gradient_step, - {"save/gradient_step": gradient_step}) + self.write( + "save/gradient_step", gradient_step, + {"save/gradient_step": gradient_step} + ) def restore_data(self) -> Tuple[int, int, int]: ea = event_accumulator.EventAccumulator(self.writer.log_dir) @@ -76,8 +77,8 @@ class BasicLogger(TensorboardLogger): This class is for compatibility. """ - def __init__(self, *args: Any, **kwargs: Any) -> None: warnings.warn( - "Deprecated soon: BasicLogger has renamed to TensorboardLogger in #427.") + "Deprecated soon: BasicLogger has renamed to TensorboardLogger in #427." + ) super().__init__(*args, **kwargs) diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py index 7a837c96c..db38cfcc2 100644 --- a/tianshou/utils/logger/wandb.py +++ b/tianshou/utils/logger/wandb.py @@ -30,7 +30,6 @@ class WandBLogger(BaseLogger): :param int update_interval: the log interval in log_update_data(). Default to 1000. """ - def __init__( self, train_interval: int = 1000, diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index cb11abc2e..d9b883c87 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -1,7 +1,8 @@ -import torch +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union + import numpy as np +import torch from torch import nn -from typing import Any, Dict, List, Type, Tuple, Union, Optional, Sequence ModuleType = Type[nn.Module] @@ -46,7 +47,6 @@ class MLP(nn.Module): :param device: which device to create this model on. Default to None. :param linear_layer: use this module as linear layer. Default to nn.Linear. """ - def __init__( self, input_dim: int, @@ -64,8 +64,7 @@ def __init__( assert len(norm_layer) == len(hidden_sizes) norm_layer_list = norm_layer else: - norm_layer_list = [ - norm_layer for _ in range(len(hidden_sizes))] + norm_layer_list = [norm_layer for _ in range(len(hidden_sizes))] else: norm_layer_list = [None] * len(hidden_sizes) if activation: @@ -73,26 +72,22 @@ def __init__( assert len(activation) == len(hidden_sizes) activation_list = activation else: - activation_list = [ - activation for _ in range(len(hidden_sizes))] + activation_list = [activation for _ in range(len(hidden_sizes))] else: activation_list = [None] * len(hidden_sizes) hidden_sizes = [input_dim] + list(hidden_sizes) model = [] for in_dim, out_dim, norm, activ in zip( - hidden_sizes[:-1], hidden_sizes[1:], - norm_layer_list, activation_list): + hidden_sizes[:-1], hidden_sizes[1:], norm_layer_list, activation_list + ): model += miniblock(in_dim, out_dim, norm, activ, linear_layer) if output_dim > 0: model += [linear_layer(hidden_sizes[-1], output_dim)] self.output_dim = output_dim or hidden_sizes[-1] self.model = nn.Sequential(*model) - def forward( - self, x: Union[np.ndarray, torch.Tensor] - ) -> torch.Tensor: - x = torch.as_tensor( - x, device=self.device, dtype=torch.float32) # type: ignore + def forward(self, x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + x = torch.as_tensor(x, device=self.device, dtype=torch.float32) # type: ignore return self.model(x.flatten(1)) @@ -138,7 +133,6 @@ class Net(nn.Module): :class:`~tianshou.utils.net.continuous.Critic`, etc, to see how it's suggested be used. 
""" - def __init__( self, state_shape: Union[int, Sequence[int]], @@ -162,8 +156,9 @@ def __init__( input_dim += action_dim self.use_dueling = dueling_param is not None output_dim = action_dim if not self.use_dueling and not concat else 0 - self.model = MLP(input_dim, output_dim, hidden_sizes, - norm_layer, activation, device) + self.model = MLP( + input_dim, output_dim, hidden_sizes, norm_layer, activation, device + ) self.output_dim = self.model.output_dim if self.use_dueling: # dueling DQN q_kwargs, v_kwargs = dueling_param # type: ignore @@ -172,10 +167,14 @@ def __init__( q_output_dim, v_output_dim = action_dim, num_atoms q_kwargs: Dict[str, Any] = { **q_kwargs, "input_dim": self.output_dim, - "output_dim": q_output_dim, "device": self.device} + "output_dim": q_output_dim, + "device": self.device + } v_kwargs: Dict[str, Any] = { **v_kwargs, "input_dim": self.output_dim, - "output_dim": v_output_dim, "device": self.device} + "output_dim": v_output_dim, + "device": self.device + } self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs) self.output_dim = self.Q.output_dim @@ -207,7 +206,6 @@ class Recurrent(nn.Module): For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. """ - def __init__( self, layer_num: int, @@ -239,8 +237,7 @@ def forward( training mode, s should be with shape ``[bsz, len, dim]``. See the code and comment for more detail. """ - s = torch.as_tensor( - s, device=self.device, dtype=torch.float32) # type: ignore + s = torch.as_tensor(s, device=self.device, dtype=torch.float32) # type: ignore # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. @@ -253,9 +250,12 @@ def forward( else: # we store the stack data in [bsz, len, ...] format # but pytorch rnn needs [len, bsz, ...] - s, (h, c) = self.nn(s, (state["h"].transpose(0, 1).contiguous(), - state["c"].transpose(0, 1).contiguous())) + s, (h, c) = self.nn( + s, ( + state["h"].transpose(0, 1).contiguous(), + state["c"].transpose(0, 1).contiguous() + ) + ) s = self.fc2(s[:, -1]) # please ensure the first dim is batch size: [bsz, len, ...] - return s, {"h": h.transpose(0, 1).detach(), - "c": c.transpose(0, 1).detach()} + return s, {"h": h.transpose(0, 1).detach(), "c": c.transpose(0, 1).detach()} diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 36c178612..f83f91873 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -1,11 +1,11 @@ -import torch +from typing import Any, Dict, Optional, Sequence, Tuple, Union + import numpy as np +import torch from torch import nn -from typing import Any, Dict, Tuple, Union, Optional, Sequence from tianshou.utils.net.common import MLP - SIGMA_MIN = -20 SIGMA_MAX = 2 @@ -33,7 +33,6 @@ class Actor(nn.Module): Please refer to :class:`~tianshou.utils.net.common.Net` as an instance of how preprocess_net is suggested to be defined. 
""" - def __init__( self, preprocess_net: nn.Module, @@ -47,10 +46,8 @@ def __init__( self.device = device self.preprocess = preprocess_net self.output_dim = int(np.prod(action_shape)) - input_dim = getattr(preprocess_net, "output_dim", - preprocess_net_output_dim) - self.last = MLP(input_dim, self.output_dim, - hidden_sizes, device=self.device) + input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) + self.last = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device) self._max = max_action def forward( @@ -85,7 +82,6 @@ class Critic(nn.Module): Please refer to :class:`~tianshou.utils.net.common.Net` as an instance of how preprocess_net is suggested to be defined. """ - def __init__( self, preprocess_net: nn.Module, @@ -97,8 +93,7 @@ def __init__( self.device = device self.preprocess = preprocess_net self.output_dim = 1 - input_dim = getattr(preprocess_net, "output_dim", - preprocess_net_output_dim) + input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) self.last = MLP(input_dim, 1, hidden_sizes, device=self.device) def forward( @@ -109,11 +104,15 @@ def forward( ) -> torch.Tensor: """Mapping: (s, a) -> logits -> Q(s, a).""" s = torch.as_tensor( - s, device=self.device, dtype=torch.float32 # type: ignore + s, + device=self.device, + dtype=torch.float32 # type: ignore ).flatten(1) if a is not None: a = torch.as_tensor( - a, device=self.device, dtype=torch.float32 # type: ignore + a, + device=self.device, + dtype=torch.float32 # type: ignore ).flatten(1) s = torch.cat([s, a], dim=1) logits, h = self.preprocess(s) @@ -147,7 +146,6 @@ class ActorProb(nn.Module): Please refer to :class:`~tianshou.utils.net.common.Net` as an instance of how preprocess_net is suggested to be defined. """ - def __init__( self, preprocess_net: nn.Module, @@ -163,14 +161,13 @@ def __init__( self.preprocess = preprocess_net self.device = device self.output_dim = int(np.prod(action_shape)) - input_dim = getattr(preprocess_net, "output_dim", - preprocess_net_output_dim) - self.mu = MLP(input_dim, self.output_dim, - hidden_sizes, device=self.device) + input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) + self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device) self._c_sigma = conditioned_sigma if conditioned_sigma: - self.sigma = MLP(input_dim, self.output_dim, - hidden_sizes, device=self.device) + self.sigma = MLP( + input_dim, self.output_dim, hidden_sizes, device=self.device + ) else: self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1)) self._max = max_action @@ -188,9 +185,7 @@ def forward( if not self._unbounded: mu = self._max * torch.tanh(mu) if self._c_sigma: - sigma = torch.clamp( - self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX - ).exp() + sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp() else: shape = [1] * len(mu.shape) shape[1] = -1 @@ -204,7 +199,6 @@ class RecurrentActorProb(nn.Module): For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. 
""" - def __init__( self, layer_num: int, @@ -241,8 +235,7 @@ def forward( info: Dict[str, Any] = {}, ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Dict[str, torch.Tensor]]: """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" - s = torch.as_tensor( - s, device=self.device, dtype=torch.float32) # type: ignore + s = torch.as_tensor(s, device=self.device, dtype=torch.float32) # type: ignore # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. @@ -254,23 +247,27 @@ def forward( else: # we store the stack data in [bsz, len, ...] format # but pytorch rnn needs [len, bsz, ...] - s, (h, c) = self.nn(s, (state["h"].transpose(0, 1).contiguous(), - state["c"].transpose(0, 1).contiguous())) + s, (h, c) = self.nn( + s, ( + state["h"].transpose(0, 1).contiguous(), + state["c"].transpose(0, 1).contiguous() + ) + ) logits = s[:, -1] mu = self.mu(logits) if not self._unbounded: mu = self._max * torch.tanh(mu) if self._c_sigma: - sigma = torch.clamp( - self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX - ).exp() + sigma = torch.clamp(self.sigma(logits), min=SIGMA_MIN, max=SIGMA_MAX).exp() else: shape = [1] * len(mu.shape) shape[1] = -1 sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp() # please ensure the first dim is batch size: [bsz, len, ...] - return (mu, sigma), {"h": h.transpose(0, 1).detach(), - "c": c.transpose(0, 1).detach()} + return (mu, sigma), { + "h": h.transpose(0, 1).detach(), + "c": c.transpose(0, 1).detach() + } class RecurrentCritic(nn.Module): @@ -279,7 +276,6 @@ class RecurrentCritic(nn.Module): For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. """ - def __init__( self, layer_num: int, @@ -307,8 +303,7 @@ def forward( info: Dict[str, Any] = {}, ) -> torch.Tensor: """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" - s = torch.as_tensor( - s, device=self.device, dtype=torch.float32) # type: ignore + s = torch.as_tensor(s, device=self.device, dtype=torch.float32) # type: ignore # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. @@ -318,7 +313,8 @@ def forward( s = s[:, -1] if a is not None: a = torch.as_tensor( - a, device=self.device, dtype=torch.float32) # type: ignore + a, device=self.device, dtype=torch.float32 + ) # type: ignore s = torch.cat([s, a], dim=1) s = self.fc2(s) return s diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 200ae9d3a..e642bafc5 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -1,8 +1,9 @@ -import torch +from typing import Any, Dict, Optional, Sequence, Tuple, Union + import numpy as np -from torch import nn +import torch import torch.nn.functional as F -from typing import Any, Dict, Tuple, Union, Optional, Sequence +from torch import nn from tianshou.data import Batch from tianshou.utils.net.common import MLP @@ -33,7 +34,6 @@ class Actor(nn.Module): Please refer to :class:`~tianshou.utils.net.common.Net` as an instance of how preprocess_net is suggested to be defined. 
""" - def __init__( self, preprocess_net: nn.Module, @@ -47,10 +47,8 @@ def __init__( self.device = device self.preprocess = preprocess_net self.output_dim = int(np.prod(action_shape)) - input_dim = getattr(preprocess_net, "output_dim", - preprocess_net_output_dim) - self.last = MLP(input_dim, self.output_dim, - hidden_sizes, device=self.device) + input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) + self.last = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device) self.softmax_output = softmax_output def forward( @@ -88,7 +86,6 @@ class Critic(nn.Module): Please refer to :class:`~tianshou.utils.net.common.Net` as an instance of how preprocess_net is suggested to be defined. """ - def __init__( self, preprocess_net: nn.Module, @@ -101,10 +98,8 @@ def __init__( self.device = device self.preprocess = preprocess_net self.output_dim = last_size - input_dim = getattr(preprocess_net, "output_dim", - preprocess_net_output_dim) - self.last = MLP(input_dim, last_size, - hidden_sizes, device=self.device) + input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) + self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device) def forward( self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any @@ -126,7 +121,6 @@ class CosineEmbeddingNetwork(nn.Module): From https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master /fqf_iqn_qrdqn/network.py . """ - def __init__(self, num_cosines: int, embedding_dim: int) -> None: super().__init__() self.net = nn.Sequential(nn.Linear(num_cosines, embedding_dim), nn.ReLU()) @@ -141,9 +135,8 @@ def forward(self, taus: torch.Tensor) -> torch.Tensor: start=1, end=self.num_cosines + 1, dtype=taus.dtype, device=taus.device ).view(1, 1, self.num_cosines) # Calculate cos(i * \pi * \tau). - cosines = torch.cos(taus.view(batch_size, N, 1) * i_pi).view( - batch_size * N, self.num_cosines - ) + cosines = torch.cos(taus.view(batch_size, N, 1) * i_pi + ).view(batch_size * N, self.num_cosines) # Calculate embeddings of taus. tau_embeddings = self.net(cosines).view(batch_size, N, self.embedding_dim) return tau_embeddings @@ -170,7 +163,6 @@ class ImplicitQuantileNetwork(Critic): The second item of the first return value is tau vector. """ - def __init__( self, preprocess_net: nn.Module, @@ -181,10 +173,12 @@ def __init__( device: Union[str, int, torch.device] = "cpu" ) -> None: last_size = np.prod(action_shape) - super().__init__(preprocess_net, hidden_sizes, last_size, - preprocess_net_output_dim, device) - self.input_dim = getattr(preprocess_net, "output_dim", - preprocess_net_output_dim) + super().__init__( + preprocess_net, hidden_sizes, last_size, preprocess_net_output_dim, device + ) + self.input_dim = getattr( + preprocess_net, "output_dim", preprocess_net_output_dim + ) self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim).to(device) @@ -195,13 +189,12 @@ def forward( # type: ignore logits, h = self.preprocess(s, state=kwargs.get("state", None)) # Sample fractions. 
batch_size = logits.size(0) - taus = torch.rand(batch_size, sample_size, - dtype=logits.dtype, device=logits.device) - embedding = (logits.unsqueeze(1) * self.embed_model(taus)).view( - batch_size * sample_size, -1 + taus = torch.rand( + batch_size, sample_size, dtype=logits.dtype, device=logits.device ) - out = self.last(embedding).view( - batch_size, sample_size, -1).transpose(1, 2) + embedding = (logits.unsqueeze(1) * + self.embed_model(taus)).view(batch_size * sample_size, -1) + out = self.last(embedding).view(batch_size, sample_size, -1).transpose(1, 2) return (out, taus), h @@ -216,7 +209,6 @@ class FractionProposalNetwork(nn.Module): Adapted from https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master /fqf_iqn_qrdqn/network.py . """ - def __init__(self, num_fractions: int, embedding_dim: int) -> None: super().__init__() self.net = nn.Linear(embedding_dim, num_fractions) @@ -259,7 +251,6 @@ class FullQuantileFunction(ImplicitQuantileNetwork): The first return value is a tuple of (quantiles, fractions, quantiles_tau), where fractions is a Batch(taus, tau_hats, entropies). """ - def __init__( self, preprocess_net: nn.Module, @@ -270,20 +261,18 @@ def __init__( device: Union[str, int, torch.device] = "cpu", ) -> None: super().__init__( - preprocess_net, action_shape, hidden_sizes, - num_cosines, preprocess_net_output_dim, device + preprocess_net, action_shape, hidden_sizes, num_cosines, + preprocess_net_output_dim, device ) def _compute_quantiles( self, obs: torch.Tensor, taus: torch.Tensor ) -> torch.Tensor: batch_size, sample_size = taus.shape - embedding = (obs.unsqueeze(1) * self.embed_model(taus)).view( - batch_size * sample_size, -1 - ) - quantiles = self.last(embedding).view( - batch_size, sample_size, -1 - ).transpose(1, 2) + embedding = (obs.unsqueeze(1) * + self.embed_model(taus)).view(batch_size * sample_size, -1) + quantiles = self.last(embedding).view(batch_size, sample_size, + -1).transpose(1, 2) return quantiles def forward( # type: ignore @@ -321,17 +310,14 @@ class NoisyLinear(nn.Module): Adapted from https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master /fqf_iqn_qrdqn/network.py . """ - def __init__( self, in_features: int, out_features: int, noisy_std: float = 0.5 ) -> None: super().__init__() # Learnable parameters. - self.mu_W = nn.Parameter( - torch.FloatTensor(out_features, in_features)) - self.sigma_W = nn.Parameter( - torch.FloatTensor(out_features, in_features)) + self.mu_W = nn.Parameter(torch.FloatTensor(out_features, in_features)) + self.sigma_W = nn.Parameter(torch.FloatTensor(out_features, in_features)) self.mu_bias = nn.Parameter(torch.FloatTensor(out_features)) self.sigma_bias = nn.Parameter(torch.FloatTensor(out_features)) diff --git a/tianshou/utils/statistics.py b/tianshou/utils/statistics.py index e0d06763a..574747776 100644 --- a/tianshou/utils/statistics.py +++ b/tianshou/utils/statistics.py @@ -1,8 +1,9 @@ -import torch -import numpy as np from numbers import Number from typing import List, Union +import numpy as np +import torch + class MovAvg(object): """Class for moving average. 
@@ -22,7 +23,6 @@ class MovAvg(object): >>> print(f'{stat.mean():.2f}±{stat.std():.2f}') 6.50±1.12 """ - def __init__(self, size: int = 100) -> None: super().__init__() self.size = size @@ -70,9 +70,10 @@ class RunningMeanStd(object): https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm """ - def __init__( - self, mean: Union[float, np.ndarray] = 0.0, std: Union[float, np.ndarray] = 1.0 + self, + mean: Union[float, np.ndarray] = 0.0, + std: Union[float, np.ndarray] = 1.0 ) -> None: self.mean, self.var = mean, std self.count = 0 @@ -88,7 +89,7 @@ def update(self, x: np.ndarray) -> None: new_mean = self.mean + delta * batch_count / total_count m_a = self.var * self.count m_b = batch_var * batch_count - m_2 = m_a + m_b + delta ** 2 * self.count * batch_count / total_count + m_2 = m_a + m_b + delta**2 * self.count * batch_count / total_count new_var = m_2 / total_count self.mean, self.var = new_mean, new_var From e15183265969b68b04c203806e36cf8d38a3a0a3 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Thu, 2 Sep 2021 14:10:55 -0400 Subject: [PATCH 03/11] add flake8-bugbear --- .github/workflows/lint_and_docs.yml | 4 +--- .github/workflows/profile.yml | 1 + examples/atari/atari_wrapper.py | 2 +- examples/mujoco/gen_json.py | 2 +- examples/mujoco/tools.py | 2 +- examples/vizdoom/env.py | 2 +- setup.cfg | 1 + setup.py | 1 + test/base/test_buffer.py | 6 +++--- test/base/test_env.py | 4 +++- test/throughput/test_buffer_profile.py | 4 ++-- tianshou/data/collector.py | 8 ++------ tianshou/env/venvs.py | 2 +- tianshou/env/worker/base.py | 2 +- tianshou/policy/imitation/discrete_bcq.py | 5 +++-- tianshou/policy/modelfree/npg.py | 4 ++-- tianshou/policy/modelfree/trpo.py | 2 +- tianshou/trainer/offline.py | 2 +- tianshou/trainer/offpolicy.py | 2 +- tianshou/utils/net/continuous.py | 16 +++------------- 20 files changed, 31 insertions(+), 41 deletions(-) diff --git a/.github/workflows/lint_and_docs.yml b/.github/workflows/lint_and_docs.yml index 6f6e21a6e..89daaaa6b 100644 --- a/.github/workflows/lint_and_docs.yml +++ b/.github/workflows/lint_and_docs.yml @@ -20,11 +20,9 @@ jobs: - name: Lint with flake8 run: | flake8 . --count --show-source --statistics - - name: yapf code formatter + - name: Code formatter run: | yapf -r -d . - - name: isort code formatter - run: | isort --check . - name: Type check run: | diff --git a/.github/workflows/profile.yml b/.github/workflows/profile.yml index 3d62da417..82c793e99 100644 --- a/.github/workflows/profile.yml +++ b/.github/workflows/profile.yml @@ -5,6 +5,7 @@ on: [push, pull_request] jobs: profile: runs-on: ubuntu-latest + if: "!contains(github.event.head_commit.message, 'ci skip')" steps: - uses: actions/checkout@v2 - name: Set up Python 3.8 diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index 2c128c10f..4fbe5fb4a 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -47,7 +47,7 @@ def step(self, action): reward, and max over last observations. 
""" obs_list, total_reward, done = [], 0., False - for i in range(self._skip): + for _ in range(self._skip): obs, reward, done, info = self.env.step(action) obs_list.append(obs) total_reward += reward diff --git a/examples/mujoco/gen_json.py b/examples/mujoco/gen_json.py index 5429cf8d8..99cad74a3 100755 --- a/examples/mujoco/gen_json.py +++ b/examples/mujoco/gen_json.py @@ -9,7 +9,7 @@ def merge(rootdir): """format: $rootdir/$algo/*.csv""" result = [] - for path, dirnames, filenames in os.walk(rootdir): + for path, _, filenames in os.walk(rootdir): filenames = [f for f in filenames if f.endswith('.csv')] if len(filenames) == 0: continue diff --git a/examples/mujoco/tools.py b/examples/mujoco/tools.py index bda7c6db1..3ed4791fd 100755 --- a/examples/mujoco/tools.py +++ b/examples/mujoco/tools.py @@ -83,7 +83,7 @@ def merge_csv(csv_files, root_dir, remove_zero=False): """Merge result in csv_files into a single csv file.""" assert len(csv_files) > 0 if remove_zero: - for k, v in csv_files.items(): + for v in csv_files.values(): if v[1][0] == 0: v.pop(1) sorted_keys = sorted(csv_files.keys()) diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py index 0e8d995d1..f7db3078c 100644 --- a/examples/vizdoom/env.py +++ b/examples/vizdoom/env.py @@ -119,7 +119,7 @@ def close(self): obs = env.reset() print(env.spec.reward_threshold) print(obs.shape, action_num) - for i in range(4000): + for _ in range(4000): obs, rew, done, info = env.step(0) if done: env.reset() diff --git a/setup.cfg b/setup.cfg index e1420b9ad..1370ddcbd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,6 +8,7 @@ exclude = dist *.egg-info max-line-length = 87 +ignore = B305,W504,B006,B008 [yapf] based_on_style = pep8 diff --git a/setup.py b/setup.py index 03a2d2186..bf48020e5 100644 --- a/setup.py +++ b/setup.py @@ -61,6 +61,7 @@ def get_version() -> str: "sphinx_rtd_theme", "sphinxcontrib-bibtex", "flake8", + "flake8-bugbear", "yapf", "isort", "pytest", diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 2c54cbbab..c1568c75d 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -178,7 +178,7 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True) buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True) obs = env.reset(1) - for i in range(16): + for _ in range(16): obs_next, rew, done, info = env.step(1) buf.add(Batch(obs=obs, act=1, rew=rew, done=done, info=info)) buf2.add(Batch(obs=obs, act=1, rew=rew, done=done, info=info)) @@ -313,7 +313,7 @@ def test_segtree(): naive[index] = value tree[index] = value assert np.allclose(realop(naive), tree.reduce()) - for i in range(10): + for _ in range(10): left = np.random.randint(actual_len) right = np.random.randint(left + 1, actual_len + 1) assert np.allclose(realop(naive[left:right]), tree.reduce(left, right)) @@ -327,7 +327,7 @@ def test_segtree(): naive[index] = value tree[index] = value assert np.allclose(realop(naive), tree.reduce()) - for i in range(10): + for _ in range(10): left = np.random.randint(actual_len) right = np.random.randint(left + 1, actual_len + 1) assert np.allclose(realop(naive[left:right]), tree.reduce(left, right)) diff --git a/test/base/test_env.py b/test/base/test_env.py index 4d5cc6a00..b9d6489b6 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -140,7 +140,7 @@ def test_vecenv(size=10, num=8, sleep=0.001): v.seed(0) action_list = [1] * 5 + [0] * 10 + [1] * 20 o = [v.reset() for v in venv] - for i, a in 
enumerate(action_list): + for a in action_list: o = [] for v in venv: A, B, C, D = v.step([a] * num) @@ -152,6 +152,7 @@ def test_vecenv(size=10, num=8, sleep=0.001): continue for info in infos: assert recurse_comp(infos[0], info) + if __name__ == '__main__': t = [0] * len(venv) for i, e in enumerate(venv): @@ -164,6 +165,7 @@ def test_vecenv(size=10, num=8, sleep=0.001): t[i] = time.time() - t[i] for i, v in enumerate(venv): print(f'{type(v)}: {t[i]:.6f}s') + for v in venv: assert v.size == list(range(size, size + num)) assert v.env_num == num diff --git a/test/throughput/test_buffer_profile.py b/test/throughput/test_buffer_profile.py index 57bd3f5b6..2bb00c143 100644 --- a/test/throughput/test_buffer_profile.py +++ b/test/throughput/test_buffer_profile.py @@ -14,7 +14,7 @@ def test_replaybuffer(task="Pendulum-v0"): env = gym.make(task) buf = ReplayBuffer(10000) obs = env.reset() - for i in range(100000): + for _ in range(100000): act = env.action_space.sample() obs_next, rew, done, info = env.step(act) batch = Batch( @@ -37,7 +37,7 @@ def test_vectorbuffer(task="Pendulum-v0"): env = gym.make(task) buf = VectorReplayBuffer(total_size=10000, buffer_num=1) obs = env.reset() - for i in range(100000): + for _ in range(100000): act = env.action_space.sample() obs_next, rew, done, info = env.step(act) batch = Batch( diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 1da6d08b1..f4bb278e4 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -229,9 +229,7 @@ def collect( # get bounded and remapped actions first (not saved into buffer) action_remap = self.policy.map_action(self.data.act) # step in env - obs_next, rew, done, info = self.env.step( - action_remap, ready_env_ids - ) # type: ignore + obs_next, rew, done, info = self.env.step(action_remap, ready_env_ids) self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) if self.preprocess_fn: @@ -447,9 +445,7 @@ def collect( # get bounded and remapped actions first (not saved into buffer) action_remap = self.policy.map_action(self.data.act) # step in env - obs_next, rew, done, info = self.env.step( - action_remap, ready_env_ids - ) # type: ignore + obs_next, rew, done, info = self.env.step(action_remap, ready_env_ids) # change self.data here because ready_env_ids has changed ready_env_ids = np.array([i["env_id"] for i in info]) diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 3e37f56f3..37f96c9df 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -228,7 +228,7 @@ def step( if action is not None: self._assert_id(id) assert len(action) == len(id) - for i, (act, env_id) in enumerate(zip(action, id)): + for act, env_id in zip(action, id): self.workers[env_id].send_action(act) self.waiting_conn.append(self.workers[env_id]) self.waiting_id.append(env_id) diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index 2a32ce961..162feb6b1 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -11,7 +11,7 @@ def __init__(self, env_fn: Callable[[], gym.Env]) -> None: self._env_fn = env_fn self.is_closed = False self.result: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] - self.action_space = getattr(self, "action_space") + self.action_space = getattr(self, "action_space") # noqa: B009 @abstractmethod def __getattr__(self, key: str) -> Any: diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 19c7b08da..a3fdeda17 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ 
b/tianshou/policy/imitation/discrete_bcq.py @@ -114,8 +114,9 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: act = to_torch(batch.act, dtype=torch.long, device=target_q.device) q_loss = F.smooth_l1_loss(current_q, target_q) i_loss = F.nll_loss( - F.log_softmax(imitation_logits, dim=-1), act - ) # type: ignore + F.log_softmax(imitation_logits, dim=-1), + act # type: ignore + ) reg_loss = imitation_logits.pow(2).mean() loss = q_loss + i_loss + self._weight_reg * reg_loss diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index 5e94baabc..72761bf1e 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -81,7 +81,7 @@ def learn( # type: ignore self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any ) -> Dict[str, List[float]]: actor_losses, vf_losses, kls = [], [], [] - for step in range(repeat): + for _ in range(repeat): for b in batch.split(batch_size, merge_last=True): # optimize actor # direction: calculate villia gradient @@ -156,7 +156,7 @@ def _conjugate_gradients( # Note: should be 'r, p = b - MVP(x)', but for x=0, MVP(x)=0. # Change if doing warm start. rdotr = r.dot(r) - for i in range(nsteps): + for _ in range(nsteps): z = self._MVP(p, flat_kl_grad) alpha = rdotr / p.dot(z) x += alpha * p diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index 92ecf2e7d..45d820dcb 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -69,7 +69,7 @@ def learn( # type: ignore self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any ) -> Dict[str, List[float]]: actor_losses, vf_losses, step_sizes, kls = [], [], [], [] - for step in range(repeat): + for _ in range(repeat): for b in batch.split(batch_size, merge_last=True): # optimize actor # direction: calculate villia gradient diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index eb7f9a371..72cd00d06 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -85,7 +85,7 @@ def offline_trainer( for epoch in range(1 + start_epoch, 1 + max_epoch): policy.train() with tqdm.trange(update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config) as t: - for i in t: + for _ in t: gradient_step += 1 losses = policy.update(batch_size, buffer) data = {"gradient_step": str(gradient_step)} diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 9dc6783eb..e9fe91ea4 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -142,7 +142,7 @@ def offpolicy_trainer( ) else: policy.train() - for i in range(round(update_per_step * result["n/st"])): + for _ in range(round(update_per_step * result["n/st"])): gradient_step += 1 losses = policy.update(batch_size, train_collector.buffer) for k in losses.keys(): diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index f83f91873..af0bb1c3d 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -103,17 +103,9 @@ def forward( info: Dict[str, Any] = {}, ) -> torch.Tensor: """Mapping: (s, a) -> logits -> Q(s, a).""" - s = torch.as_tensor( - s, - device=self.device, - dtype=torch.float32 # type: ignore - ).flatten(1) + s = torch.as_tensor(s, device=self.device, dtype=torch.float32).flatten(1) if a is not None: - a = torch.as_tensor( - a, - device=self.device, - dtype=torch.float32 # type: ignore - ).flatten(1) + a = torch.as_tensor(a, device=self.device, dtype=torch.float32).flatten(1) s = torch.cat([s, a], dim=1) logits, h 
= self.preprocess(s) logits = self.last(logits) @@ -312,9 +304,7 @@ def forward( s, (h, c) = self.nn(s) s = s[:, -1] if a is not None: - a = torch.as_tensor( - a, device=self.device, dtype=torch.float32 - ) # type: ignore + a = torch.as_tensor(a, device=self.device, dtype=torch.float32) s = torch.cat([s, a], dim=1) s = self.fc2(s) return s From dfb412f32d3a7145cbde4c53c0f6896f66636b4b Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Thu, 2 Sep 2021 14:14:04 -0400 Subject: [PATCH 04/11] fix D204 for yapf --- examples/atari/atari_network.py | 4 ++++ examples/atari/atari_wrapper.py | 8 ++++++++ examples/box2d/bipedal_hardcore_sac.py | 1 + examples/mujoco/plotter.py | 1 + examples/vizdoom/env.py | 1 + setup.cfg | 1 + test/base/env.py | 2 ++ test/base/test_collector.py | 2 ++ test/base/test_env_finite.py | 5 +++++ test/multiagent/tic_tac_toe.py | 1 + test/multiagent/tic_tac_toe_env.py | 1 + test/throughput/test_collector_profile.py | 1 + tianshou/data/batch.py | 1 + tianshou/data/buffer/cached.py | 1 + tianshou/data/buffer/manager.py | 2 ++ tianshou/data/buffer/prio.py | 1 + tianshou/data/buffer/vecbuf.py | 2 ++ tianshou/data/collector.py | 2 ++ tianshou/data/utils/converter.py | 1 + tianshou/data/utils/segtree.py | 1 + tianshou/env/maenv.py | 1 + tianshou/env/utils.py | 1 + tianshou/env/venvs.py | 7 +++++++ tianshou/env/worker/base.py | 1 + tianshou/env/worker/dummy.py | 1 + tianshou/env/worker/ray.py | 1 + tianshou/env/worker/subproc.py | 4 ++++ tianshou/exploration/random.py | 3 +++ tianshou/policy/base.py | 1 + tianshou/policy/imitation/base.py | 1 + tianshou/policy/imitation/discrete_bcq.py | 1 + tianshou/policy/imitation/discrete_cql.py | 1 + tianshou/policy/imitation/discrete_crr.py | 1 + tianshou/policy/modelbased/psrl.py | 2 ++ tianshou/policy/modelfree/a2c.py | 1 + tianshou/policy/modelfree/c51.py | 1 + tianshou/policy/modelfree/ddpg.py | 1 + tianshou/policy/modelfree/discrete_sac.py | 1 + tianshou/policy/modelfree/dqn.py | 1 + tianshou/policy/modelfree/fqf.py | 1 + tianshou/policy/modelfree/iqn.py | 1 + tianshou/policy/modelfree/npg.py | 1 + tianshou/policy/modelfree/pg.py | 1 + tianshou/policy/modelfree/ppo.py | 1 + tianshou/policy/modelfree/qrdqn.py | 1 + tianshou/policy/modelfree/rainbow.py | 1 + tianshou/policy/modelfree/sac.py | 1 + tianshou/policy/modelfree/td3.py | 1 + tianshou/policy/modelfree/trpo.py | 1 + tianshou/policy/multiagent/mapolicy.py | 1 + tianshou/policy/random.py | 1 + tianshou/utils/logger/base.py | 2 ++ tianshou/utils/logger/tensorboard.py | 2 ++ tianshou/utils/logger/wandb.py | 1 + tianshou/utils/net/common.py | 3 +++ tianshou/utils/net/continuous.py | 5 +++++ tianshou/utils/net/discrete.py | 7 +++++++ tianshou/utils/statistics.py | 2 ++ 58 files changed, 104 insertions(+) diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index 3fb208d44..4598fce11 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -13,6 +13,7 @@ class DQN(nn.Module): For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. """ + def __init__( self, c: int, @@ -56,6 +57,7 @@ class C51(DQN): For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. """ + def __init__( self, c: int, @@ -88,6 +90,7 @@ class Rainbow(DQN): For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. 
""" + def __init__( self, c: int, @@ -149,6 +152,7 @@ class QRDQN(DQN): For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. """ + def __init__( self, c: int, diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index 4fbe5fb4a..333f9787a 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -15,6 +15,7 @@ class NoopResetEnv(gym.Wrapper): :param gym.Env env: the environment to wrap. :param int noop_max: the maximum value of no-ops to run. """ + def __init__(self, env, noop_max=30): super().__init__(env) self.noop_max = noop_max @@ -38,6 +39,7 @@ class MaxAndSkipEnv(gym.Wrapper): :param gym.Env env: the environment to wrap. :param int skip: number of `skip`-th frame. """ + def __init__(self, env, skip=4): super().__init__(env) self._skip = skip @@ -63,6 +65,7 @@ class EpisodicLifeEnv(gym.Wrapper): :param gym.Env env: the environment to wrap. """ + def __init__(self, env): super().__init__(env) self.lives = 0 @@ -102,6 +105,7 @@ class FireResetEnv(gym.Wrapper): :param gym.Env env: the environment to wrap. """ + def __init__(self, env): super().__init__(env) assert env.unwrapped.get_action_meanings()[1] == 'FIRE' @@ -117,6 +121,7 @@ class WarpFrame(gym.ObservationWrapper): :param gym.Env env: the environment to wrap. """ + def __init__(self, env): super().__init__(env) self.size = 84 @@ -138,6 +143,7 @@ class ScaledFloatFrame(gym.ObservationWrapper): :param gym.Env env: the environment to wrap. """ + def __init__(self, env): super().__init__(env) low = np.min(env.observation_space.low) @@ -157,6 +163,7 @@ class ClipRewardEnv(gym.RewardWrapper): :param gym.Env env: the environment to wrap. """ + def __init__(self, env): super().__init__(env) self.reward_range = (-1, 1) @@ -172,6 +179,7 @@ class FrameStack(gym.Wrapper): :param gym.Env env: the environment to wrap. :param int n_frames: the number of frames to stack. 
""" + def __init__(self, env, n_frames): super().__init__(env) self.n_frames = n_frames diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index bb2449dce..598622d01 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -48,6 +48,7 @@ def get_args(): class Wrapper(gym.Wrapper): """Env wrapper for reward scale, action repeat and removing done penalty""" + def __init__(self, env, action_repeat=3, reward_scale=5, rm_done=True): super().__init__(env) self.action_repeat = action_repeat diff --git a/examples/mujoco/plotter.py b/examples/mujoco/plotter.py index 04f15aa7b..e3e7057e4 100755 --- a/examples/mujoco/plotter.py +++ b/examples/mujoco/plotter.py @@ -97,6 +97,7 @@ def plot_ax( shaded_std=True, legend_outside=False, ): + def legend_fn(x): # return os.path.split(os.path.join( # args.root_dir, x))[0].replace('/', '_') + " (10)" diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py index f7db3078c..290cb92e5 100644 --- a/examples/vizdoom/env.py +++ b/examples/vizdoom/env.py @@ -34,6 +34,7 @@ def battle_button_comb(): class Env(gym.Env): + def __init__(self, cfg_path, frameskip=4, res=(4, 40, 60), save_lmp=False): super().__init__() self.save_lmp = save_lmp diff --git a/setup.cfg b/setup.cfg index 1370ddcbd..7bb1f1e7c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,6 +14,7 @@ ignore = B305,W504,B006,B008 based_on_style = pep8 dedent_closing_brackets = true column_limit = 87 +blank_line_before_nested_class_or_def = true [isort] profile = black diff --git a/test/base/env.py b/test/base/env.py index e14a7246a..cdcb51efa 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -11,6 +11,7 @@ class MyTestEnv(gym.Env): """This is a "going right" task. The task is to go right ``size`` steps. 
""" + def __init__( self, size, @@ -140,6 +141,7 @@ def step(self, action): class NXEnv(gym.Env): + def __init__(self, size, obs_type, feat_dim=32): self.size = size self.feat_dim = feat_dim diff --git a/test/base/test_collector.py b/test/base/test_collector.py index da54d4b3a..61bd5a6fc 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -22,6 +22,7 @@ class MyPolicy(BasePolicy): + def __init__(self, dict_state=False, need_state=True): """ :param bool dict_state: if the observation of the environment is a dict @@ -46,6 +47,7 @@ def learn(self): class Logger: + def __init__(self, writer): self.cnt = 0 self.writer = writer diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index 00b536cd3..54b438507 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -13,6 +13,7 @@ class DummyDataset(Dataset): + def __init__(self, length): self.length = length self.episodes = [3 * i % 5 + 1 for i in range(self.length)] @@ -26,6 +27,7 @@ def __len__(self): class FiniteEnv(gym.Env): + def __init__(self, dataset, num_replicas, rank): self.dataset = dataset self.num_replicas = num_replicas @@ -56,6 +58,7 @@ def step(self, action): class FiniteVectorEnv(BaseVectorEnv): + def __init__(self, env_fns, **kwargs): super().__init__(env_fns, **kwargs) self._alive_env_ids = set() @@ -150,6 +153,7 @@ class FiniteSubprocVectorEnv(FiniteVectorEnv, SubprocVectorEnv): class AnyPolicy(BasePolicy): + def forward(self, batch, state=None): return Batch(act=np.stack([1] * len(batch))) @@ -162,6 +166,7 @@ def _finite_env_factory(dataset, num_replicas, rank): class MetricTracker: + def __init__(self): self.counter = Counter() self.finished = set() diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index 8ecbd2878..02fd47cd7 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -139,6 +139,7 @@ def train_agent( agent_opponent: Optional[BasePolicy] = None, optim: Optional[torch.optim.Optimizer] = None, ) -> Tuple[dict, BasePolicy]: + def env_func(): return TicTacToeEnv(args.board_size, args.win_size) diff --git a/test/multiagent/tic_tac_toe_env.py b/test/multiagent/tic_tac_toe_env.py index e39e1d13d..2c79d303d 100644 --- a/test/multiagent/tic_tac_toe_env.py +++ b/test/multiagent/tic_tac_toe_env.py @@ -17,6 +17,7 @@ class TicTacToeEnv(MultiAgentEnv): :param size: the size of the board (square board) :param win_size: how many units in a row is considered to win """ + def __init__(self, size: int = 3, win_size: int = 3): super().__init__() assert size > 0, f'board size should be positive, but got {size}' diff --git a/test/throughput/test_collector_profile.py b/test/throughput/test_collector_profile.py index eced837a5..bf9c4dc05 100644 --- a/test/throughput/test_collector_profile.py +++ b/test/throughput/test_collector_profile.py @@ -12,6 +12,7 @@ class MyPolicy(BasePolicy): + def __init__(self, dict_state=False, need_state=True): """ :param bool dict_state: if the observation of the environment is a dict diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 11acf1ecd..accc0c8c2 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -175,6 +175,7 @@ class Batch: For a detailed description, please refer to :ref:`batch_concept`. 
""" + def __init__( self, batch_dict: Optional[Union[dict, "Batch", Sequence[Union[dict, "Batch"]], diff --git a/tianshou/data/buffer/cached.py b/tianshou/data/buffer/cached.py index f62c41aba..19310dbd0 100644 --- a/tianshou/data/buffer/cached.py +++ b/tianshou/data/buffer/cached.py @@ -26,6 +26,7 @@ class CachedReplayBuffer(ReplayBufferManager): Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. """ + def __init__( self, main_buffer: ReplayBuffer, diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py index 0009575c3..70ebcab03 100644 --- a/tianshou/data/buffer/manager.py +++ b/tianshou/data/buffer/manager.py @@ -20,6 +20,7 @@ class ReplayBufferManager(ReplayBuffer): Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. """ + def __init__(self, buffer_list: List[ReplayBuffer]) -> None: self.buffer_num = len(buffer_list) self.buffers = np.array(buffer_list, dtype=object) @@ -198,6 +199,7 @@ class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManage Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. """ + def __init__(self, buffer_list: Sequence[PrioritizedReplayBuffer]) -> None: ReplayBufferManager.__init__(self, buffer_list) # type: ignore kwargs = buffer_list[0].options diff --git a/tianshou/data/buffer/prio.py b/tianshou/data/buffer/prio.py index 17353a514..c4d48be10 100644 --- a/tianshou/data/buffer/prio.py +++ b/tianshou/data/buffer/prio.py @@ -18,6 +18,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. """ + def __init__( self, size: int, diff --git a/tianshou/data/buffer/vecbuf.py b/tianshou/data/buffer/vecbuf.py index 3d5508fc9..2d4831c06 100644 --- a/tianshou/data/buffer/vecbuf.py +++ b/tianshou/data/buffer/vecbuf.py @@ -27,6 +27,7 @@ class VectorReplayBuffer(ReplayBufferManager): Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. """ + def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: assert buffer_num > 0 size = int(np.ceil(total_size / buffer_num)) @@ -51,6 +52,7 @@ class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager): Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. """ + def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: assert buffer_num > 0 size = int(np.ceil(total_size / buffer_num)) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index f4bb278e4..d2ca85929 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -47,6 +47,7 @@ class Collector(object): Please make sure the given environment has a time limitation if using n_episode collect option. """ + def __init__( self, policy: BasePolicy, @@ -327,6 +328,7 @@ class AsyncCollector(Collector): The arguments are exactly the same as :class:`~tianshou.data.Collector`, please refer to :class:`~tianshou.data.Collector` for more detailed explanation. 
""" + def __init__( self, policy: BasePolicy, diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index 7a95169d2..bd1dd5358 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -80,6 +80,7 @@ def to_torch_as(x: Any, y: torch.Tensor) -> Union[Batch, torch.Tensor]: def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None: """Copy object into HDF5 group.""" + def to_hdf5_via_pickle(x: object, y: h5py.Group, key: str) -> None: """Pickle, convert to numpy array and write to HDF5 dataset.""" data = np.frombuffer(pickle.dumps(x), dtype=np.byte) diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py index 0f9a5a5b2..063675c53 100644 --- a/tianshou/data/utils/segtree.py +++ b/tianshou/data/utils/segtree.py @@ -17,6 +17,7 @@ class SegmentTree: :param int size: the size of segment tree. """ + def __init__(self, size: int) -> None: bound = 1 while bound < size: diff --git a/tianshou/env/maenv.py b/tianshou/env/maenv.py index 299051d0c..456bbca13 100644 --- a/tianshou/env/maenv.py +++ b/tianshou/env/maenv.py @@ -22,6 +22,7 @@ class MultiAgentEnv(ABC, gym.Env): The available action's mask is set to 1, otherwise it is set to 0. Further usage can be found at :ref:`marl_example`. """ + def __init__(self) -> None: pass diff --git a/tianshou/env/utils.py b/tianshou/env/utils.py index c0031ba02..cec159012 100644 --- a/tianshou/env/utils.py +++ b/tianshou/env/utils.py @@ -5,6 +5,7 @@ class CloudpickleWrapper(object): """A cloudpickle wrapper used in SubprocVectorEnv.""" + def __init__(self, data: Any) -> None: self.data = data diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 37f96c9df..14918759b 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -69,6 +69,7 @@ def seed(self, seed): obs_rms should be passed in. Default to None. :param bool update_obs_rms: Whether to update obs_rms. Default to True. """ + def __init__( self, env_fns: List[Callable[[], gym.Env]], @@ -320,6 +321,7 @@ class DummyVectorEnv(BaseVectorEnv): Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. """ + def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None: super().__init__(env_fns, DummyEnvWorker, **kwargs) @@ -331,7 +333,9 @@ class SubprocVectorEnv(BaseVectorEnv): Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. """ + def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None: + def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: return SubprocEnvWorker(fn, share_memory=False) @@ -347,7 +351,9 @@ class ShmemVectorEnv(BaseVectorEnv): Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. """ + def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None: + def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: return SubprocEnvWorker(fn, share_memory=True) @@ -363,6 +369,7 @@ class RayVectorEnv(BaseVectorEnv): Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. 
""" + def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None: try: import ray diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index 162feb6b1..6fef9f68d 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -7,6 +7,7 @@ class EnvWorker(ABC): """An abstract worker for an environment.""" + def __init__(self, env_fn: Callable[[], gym.Env]) -> None: self._env_fn = env_fn self.is_closed = False diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py index d964850f3..9e68e9f04 100644 --- a/tianshou/env/worker/dummy.py +++ b/tianshou/env/worker/dummy.py @@ -8,6 +8,7 @@ class DummyEnvWorker(EnvWorker): """Dummy worker used in sequential vector environments.""" + def __init__(self, env_fn: Callable[[], gym.Env]) -> None: self.env = env_fn() super().__init__(env_fn) diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py index 22f6bb46e..5d73763f2 100644 --- a/tianshou/env/worker/ray.py +++ b/tianshou/env/worker/ray.py @@ -13,6 +13,7 @@ class RayEnvWorker(EnvWorker): """Ray worker used in RayVectorEnv.""" + def __init__(self, env_fn: Callable[[], gym.Env]) -> None: self.env = ray.remote(gym.Wrapper).options(num_cpus=0).remote(env_fn()) super().__init__(env_fn) diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index 5f81ab452..8ef264360 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -28,6 +28,7 @@ class ShArray: """Wrapper of multiprocessing Array.""" + def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None: self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape))) # type: ignore self.dtype = dtype @@ -61,6 +62,7 @@ def _worker( env_fn_wrapper: CloudpickleWrapper, obs_bufs: Optional[Union[dict, tuple, ShArray]] = None, ) -> None: + def _encode_obs( obs: Union[dict, tuple, np.ndarray], buffer: Union[dict, tuple, ShArray] ) -> None: @@ -114,6 +116,7 @@ def _encode_obs( class SubprocEnvWorker(EnvWorker): """Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv.""" + def __init__( self, env_fn: Callable[[], gym.Env], share_memory: bool = False ) -> None: @@ -142,6 +145,7 @@ def __getattr__(self, key: str) -> Any: return self.parent_remote.recv() def _decode_obs(self) -> Union[dict, tuple, np.ndarray]: + def decode_obs( buffer: Optional[Union[dict, tuple, ShArray]] ) -> Union[dict, tuple, np.ndarray]: diff --git a/tianshou/exploration/random.py b/tianshou/exploration/random.py index 0c5fd3e4c..03f863873 100644 --- a/tianshou/exploration/random.py +++ b/tianshou/exploration/random.py @@ -6,6 +6,7 @@ class BaseNoise(ABC, object): """The action noise base class.""" + def __init__(self) -> None: super().__init__() @@ -21,6 +22,7 @@ def __call__(self, size: Sequence[int]) -> np.ndarray: class GaussianNoise(BaseNoise): """The vanilla gaussian process, for exploration in DDPG by default.""" + def __init__(self, mu: float = 0.0, sigma: float = 1.0) -> None: super().__init__() self._mu = mu @@ -47,6 +49,7 @@ class OUNoise(BaseNoise): vanilla gaussian process has little difference from using the Ornstein-Uhlenbeck process. 
""" + def __init__( self, mu: float = 0.0, diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index deed1f168..09de4f68d 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -57,6 +57,7 @@ class BasePolicy(ABC, nn.Module): torch.save(policy.state_dict(), "policy.pth") policy.load_state_dict(torch.load("policy.pth")) """ + def __init__( self, observation_space: Optional[gym.Space] = None, diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index 35e665e14..a5321acdf 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -21,6 +21,7 @@ class ImitationPolicy(BasePolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. """ + def __init__( self, model: torch.nn.Module, diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index a3fdeda17..135fe8d84 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -33,6 +33,7 @@ class DiscreteBCQPolicy(DQNPolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. """ + def __init__( self, model: torch.nn.Module, diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index cfd9e4cab..ad4ed19a3 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -28,6 +28,7 @@ class DiscreteCQLPolicy(QRDQNPolicy): Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed explanation. """ + def __init__( self, model: torch.nn.Module, diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index 8506ba550..6a149509e 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -34,6 +34,7 @@ class DiscreteCRRPolicy(PGPolicy): Please refer to :class:`~tianshou.policy.PGPolicy` for more detailed explanation. """ + def __init__( self, actor: torch.nn.Module, diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py index ce385aa28..e5ea5c9a5 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/policy/modelbased/psrl.py @@ -19,6 +19,7 @@ class PSRLModel(object): :param float discount_factor: in [0, 1]. :param float epsilon: for precision control in value iteration. """ + def __init__( self, trans_count_prior: np.ndarray, @@ -158,6 +159,7 @@ class PSRLPolicy(BasePolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. """ + def __init__( self, trans_count_prior: np.ndarray, diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 1c59d2648..e44b58b58 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -48,6 +48,7 @@ class A2CPolicy(PGPolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. """ + def __init__( self, actor: torch.nn.Module, diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 17423031f..4e79eb356 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -31,6 +31,7 @@ class C51Policy(DQNPolicy): Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed explanation. 
""" + def __init__( self, model: torch.nn.Module, diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 66ed4a780..18bb81b6b 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -38,6 +38,7 @@ class DDPGPolicy(BasePolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. """ + def __init__( self, actor: Optional[torch.nn.Module], diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 9f2331711..5f087e5b9 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -34,6 +34,7 @@ class DiscreteSACPolicy(SACPolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. """ + def __init__( self, actor: torch.nn.Module, diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index ae573339d..850490998 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -32,6 +32,7 @@ class DQNPolicy(BasePolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. """ + def __init__( self, model: torch.nn.Module, diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index df22acafc..3c015b3d0 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -33,6 +33,7 @@ class FQFPolicy(QRDQNPolicy): Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed explanation. """ + def __init__( self, model: FullQuantileFunction, diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index 0647307cb..9d9777b98 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -32,6 +32,7 @@ class IQNPolicy(QRDQNPolicy): Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed explanation. """ + def __init__( self, model: torch.nn.Module, diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index 72761bf1e..758093d1a 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -45,6 +45,7 @@ class NPGPolicy(A2CPolicy): :param bool deterministic_eval: whether to use deterministic action instead of stochastic action sampled by the policy. Default to False. """ + def __init__( self, actor: torch.nn.Module, diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index d84217f6f..a64828874 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -34,6 +34,7 @@ class PGPolicy(BasePolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. """ + def __init__( self, model: torch.nn.Module, diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index fe1d37b7a..e1e17aa2f 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -58,6 +58,7 @@ class PPOPolicy(A2CPolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. """ + def __init__( self, actor: torch.nn.Module, diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 1b3aaf1ef..fe3e101f7 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -29,6 +29,7 @@ class QRDQNPolicy(DQNPolicy): Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed explanation. 
""" + def __init__( self, model: torch.nn.Module, diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/policy/modelfree/rainbow.py index 138c143e1..9028258d7 100644 --- a/tianshou/policy/modelfree/rainbow.py +++ b/tianshou/policy/modelfree/rainbow.py @@ -29,6 +29,7 @@ class RainbowPolicy(C51Policy): Please refer to :class:`~tianshou.policy.C51Policy` for more detailed explanation. """ + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: sample_noise(self.model) if self._target and sample_noise(self.model_old): diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 729c3bf3c..13702e88f 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -48,6 +48,7 @@ class SACPolicy(DDPGPolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. """ + def __init__( self, actor: torch.nn.Module, diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 506c5cd03..a033237ea 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -46,6 +46,7 @@ class TD3Policy(DDPGPolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. """ + def __init__( self, actor: torch.nn.Module, diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index 45d820dcb..75956d987 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -48,6 +48,7 @@ class TRPOPolicy(NPGPolicy): :param bool deterministic_eval: whether to use deterministic action instead of stochastic action sampled by the policy. Default to False. """ + def __init__( self, actor: torch.nn.Module, diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index bd6272f80..75705f4a3 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -15,6 +15,7 @@ class MultiAgentPolicyManager(BasePolicy): and "learn": it splits the data and feeds them to each policy. A figure in :ref:`marl_example` can help you better understand this procedure. """ + def __init__(self, policies: List[BasePolicy], **kwargs: Any) -> None: super().__init__(**kwargs) self.policies = policies diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py index 863f93229..dfb79564f 100644 --- a/tianshou/policy/random.py +++ b/tianshou/policy/random.py @@ -11,6 +11,7 @@ class RandomPolicy(BasePolicy): It randomly chooses an action from the legal action. """ + def forward( self, batch: Batch, diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py index fcb3c4b93..9b89d5e88 100644 --- a/tianshou/utils/logger/base.py +++ b/tianshou/utils/logger/base.py @@ -16,6 +16,7 @@ class BaseLogger(ABC): :param int test_interval: the log interval in log_test_data(). Default to 1. :param int update_interval: the log interval in log_update_data(). Default to 1000. """ + def __init__( self, train_interval: int = 1000, @@ -132,6 +133,7 @@ def restore_data(self) -> Tuple[int, int, int]: class LazyLogger(BaseLogger): """A logger that does nothing. Used as the placeholder in trainer.""" + def __init__(self) -> None: super().__init__() diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py index cee450168..86e873cda 100644 --- a/tianshou/utils/logger/tensorboard.py +++ b/tianshou/utils/logger/tensorboard.py @@ -18,6 +18,7 @@ class TensorboardLogger(BaseLogger): :param int save_interval: the save interval in save_data(). 
Default to 1 (save at the end of each epoch). """ + def __init__( self, writer: SummaryWriter, @@ -77,6 +78,7 @@ class BasicLogger(TensorboardLogger): This class is for compatibility. """ + def __init__(self, *args: Any, **kwargs: Any) -> None: warnings.warn( "Deprecated soon: BasicLogger has renamed to TensorboardLogger in #427." diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py index db38cfcc2..7a837c96c 100644 --- a/tianshou/utils/logger/wandb.py +++ b/tianshou/utils/logger/wandb.py @@ -30,6 +30,7 @@ class WandBLogger(BaseLogger): :param int update_interval: the log interval in log_update_data(). Default to 1000. """ + def __init__( self, train_interval: int = 1000, diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index d9b883c87..1c8cf7d0f 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -47,6 +47,7 @@ class MLP(nn.Module): :param device: which device to create this model on. Default to None. :param linear_layer: use this module as linear layer. Default to nn.Linear. """ + def __init__( self, input_dim: int, @@ -133,6 +134,7 @@ class Net(nn.Module): :class:`~tianshou.utils.net.continuous.Critic`, etc, to see how it's suggested be used. """ + def __init__( self, state_shape: Union[int, Sequence[int]], @@ -206,6 +208,7 @@ class Recurrent(nn.Module): For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. """ + def __init__( self, layer_num: int, diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index af0bb1c3d..ebc8cd773 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -33,6 +33,7 @@ class Actor(nn.Module): Please refer to :class:`~tianshou.utils.net.common.Net` as an instance of how preprocess_net is suggested to be defined. """ + def __init__( self, preprocess_net: nn.Module, @@ -82,6 +83,7 @@ class Critic(nn.Module): Please refer to :class:`~tianshou.utils.net.common.Net` as an instance of how preprocess_net is suggested to be defined. """ + def __init__( self, preprocess_net: nn.Module, @@ -138,6 +140,7 @@ class ActorProb(nn.Module): Please refer to :class:`~tianshou.utils.net.common.Net` as an instance of how preprocess_net is suggested to be defined. """ + def __init__( self, preprocess_net: nn.Module, @@ -191,6 +194,7 @@ class RecurrentActorProb(nn.Module): For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. """ + def __init__( self, layer_num: int, @@ -268,6 +272,7 @@ class RecurrentCritic(nn.Module): For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. """ + def __init__( self, layer_num: int, diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index e642bafc5..bcc6531e3 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -34,6 +34,7 @@ class Actor(nn.Module): Please refer to :class:`~tianshou.utils.net.common.Net` as an instance of how preprocess_net is suggested to be defined. """ + def __init__( self, preprocess_net: nn.Module, @@ -86,6 +87,7 @@ class Critic(nn.Module): Please refer to :class:`~tianshou.utils.net.common.Net` as an instance of how preprocess_net is suggested to be defined. """ + def __init__( self, preprocess_net: nn.Module, @@ -121,6 +123,7 @@ class CosineEmbeddingNetwork(nn.Module): From https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master /fqf_iqn_qrdqn/network.py . 
""" + def __init__(self, num_cosines: int, embedding_dim: int) -> None: super().__init__() self.net = nn.Sequential(nn.Linear(num_cosines, embedding_dim), nn.ReLU()) @@ -163,6 +166,7 @@ class ImplicitQuantileNetwork(Critic): The second item of the first return value is tau vector. """ + def __init__( self, preprocess_net: nn.Module, @@ -209,6 +213,7 @@ class FractionProposalNetwork(nn.Module): Adapted from https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master /fqf_iqn_qrdqn/network.py . """ + def __init__(self, num_fractions: int, embedding_dim: int) -> None: super().__init__() self.net = nn.Linear(embedding_dim, num_fractions) @@ -251,6 +256,7 @@ class FullQuantileFunction(ImplicitQuantileNetwork): The first return value is a tuple of (quantiles, fractions, quantiles_tau), where fractions is a Batch(taus, tau_hats, entropies). """ + def __init__( self, preprocess_net: nn.Module, @@ -310,6 +316,7 @@ class NoisyLinear(nn.Module): Adapted from https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master /fqf_iqn_qrdqn/network.py . """ + def __init__( self, in_features: int, out_features: int, noisy_std: float = 0.5 ) -> None: diff --git a/tianshou/utils/statistics.py b/tianshou/utils/statistics.py index 574747776..5c895303b 100644 --- a/tianshou/utils/statistics.py +++ b/tianshou/utils/statistics.py @@ -23,6 +23,7 @@ class MovAvg(object): >>> print(f'{stat.mean():.2f}±{stat.std():.2f}') 6.50±1.12 """ + def __init__(self, size: int = 100) -> None: super().__init__() self.size = size @@ -70,6 +71,7 @@ class RunningMeanStd(object): https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm """ + def __init__( self, mean: Union[float, np.ndarray] = 0.0, From 00c0aa5ed73cfc75a5587f96ad8b2445100da90a Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Thu, 2 Sep 2021 15:01:47 -0400 Subject: [PATCH 05/11] fix mypy --- tianshou/data/collector.py | 8 ++++++-- tianshou/utils/net/continuous.py | 18 +++++++++++++++--- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index d2ca85929..e7f47e88c 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -230,7 +230,9 @@ def collect( # get bounded and remapped actions first (not saved into buffer) action_remap = self.policy.map_action(self.data.act) # step in env - obs_next, rew, done, info = self.env.step(action_remap, ready_env_ids) + obs_next, rew, done, info = self.env.step( + action_remap, ready_env_ids # type: ignore + ) self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) if self.preprocess_fn: @@ -447,7 +449,9 @@ def collect( # get bounded and remapped actions first (not saved into buffer) action_remap = self.policy.map_action(self.data.act) # step in env - obs_next, rew, done, info = self.env.step(action_remap, ready_env_ids) + obs_next, rew, done, info = self.env.step( + action_remap, ready_env_ids # type: ignore + ) # change self.data here because ready_env_ids has changed ready_env_ids = np.array([i["env_id"] for i in info]) diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index ebc8cd773..1bb090cdf 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -105,9 +105,17 @@ def forward( info: Dict[str, Any] = {}, ) -> torch.Tensor: """Mapping: (s, a) -> logits -> Q(s, a).""" - s = torch.as_tensor(s, device=self.device, dtype=torch.float32).flatten(1) + s = torch.as_tensor( + s, + device=self.device, # type: ignore + dtype=torch.float32, + ).flatten(1) if a is 
not None: - a = torch.as_tensor(a, device=self.device, dtype=torch.float32).flatten(1) + a = torch.as_tensor( + a, + device=self.device, # type: ignore + dtype=torch.float32, + ).flatten(1) s = torch.cat([s, a], dim=1) logits, h = self.preprocess(s) logits = self.last(logits) @@ -309,7 +317,11 @@ def forward( s, (h, c) = self.nn(s) s = s[:, -1] if a is not None: - a = torch.as_tensor(a, device=self.device, dtype=torch.float32) + a = torch.as_tensor( + a, + device=self.device, # type: ignore + dtype=torch.float32, + ) s = torch.cat([s, a], dim=1) s = self.fc2(s) return s From 074bb27793754dd60acbbd6db2d9bcf2cbf06354 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Thu, 2 Sep 2021 15:10:20 -0400 Subject: [PATCH 06/11] fix D400 --- tianshou/data/__init__.py | 6 +++++- tianshou/data/collector.py | 10 ++++------ tianshou/env/__init__.py | 2 ++ tianshou/policy/__init__.py | 6 +++++- tianshou/trainer/__init__.py | 6 +++++- tianshou/utils/__init__.py | 2 ++ 6 files changed, 23 insertions(+), 9 deletions(-) diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index 23a7e62ff..1466587f9 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -1,4 +1,8 @@ -"""isort:skip_file""" +"""Data package. + +isort:skip_file +""" + from tianshou.data.batch import Batch from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as from tianshou.data.utils.segtree import SegmentTree diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index e7f47e88c..d52cc511d 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -230,9 +230,8 @@ def collect( # get bounded and remapped actions first (not saved into buffer) action_remap = self.policy.map_action(self.data.act) # step in env - obs_next, rew, done, info = self.env.step( - action_remap, ready_env_ids # type: ignore - ) + result = self.env.step(action_remap, ready_env_ids) # type: ignore + obs_next, rew, done, info = result self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) if self.preprocess_fn: @@ -449,9 +448,8 @@ def collect( # get bounded and remapped actions first (not saved into buffer) action_remap = self.policy.map_action(self.data.act) # step in env - obs_next, rew, done, info = self.env.step( - action_remap, ready_env_ids # type: ignore - ) + result = self.env.step(action_remap, ready_env_ids) # type: ignore + obs_next, rew, done, info = result # change self.data here because ready_env_ids has changed ready_env_ids = np.array([i["env_id"] for i in info]) diff --git a/tianshou/env/__init__.py b/tianshou/env/__init__.py index 0e145ba29..c77c30c3f 100644 --- a/tianshou/env/__init__.py +++ b/tianshou/env/__init__.py @@ -1,3 +1,5 @@ +"""Env package.""" + from tianshou.env.maenv import MultiAgentEnv from tianshou.env.venvs import ( BaseVectorEnv, diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 421898162..552422cab 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -1,4 +1,8 @@ -"""isort:skip_file""" +"""Policy package. + +isort:skip_file +""" + from tianshou.policy.base import BasePolicy from tianshou.policy.random import RandomPolicy from tianshou.policy.modelfree.dqn import DQNPolicy diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index f3baf8499..9ba22887e 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -1,4 +1,8 @@ -"""isort:skip_file""" +"""Trainer package. 
+ +isort:skip_file +""" + from tianshou.trainer.utils import test_episode, gather_info from tianshou.trainer.onpolicy import onpolicy_trainer from tianshou.trainer.offpolicy import offpolicy_trainer diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py index 64ae88328..5af038ab3 100644 --- a/tianshou/utils/__init__.py +++ b/tianshou/utils/__init__.py @@ -1,3 +1,5 @@ +"""Utils package.""" + from tianshou.utils.config import tqdm_config from tianshou.utils.logger.base import BaseLogger, LazyLogger from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger From 7d8f783a3292e7a75acffd06ed949786b7ba8f54 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Thu, 2 Sep 2021 15:58:11 -0400 Subject: [PATCH 07/11] add makefile --- Makefile | 60 ++++++++++ docs/conf.py | 2 +- docs/spelling_wordlist.txt | 130 ++++++++++++++++++++++ docs/tutorials/batch.rst | 6 +- docs/tutorials/cheatsheet.rst | 4 +- docs/tutorials/concepts.rst | 2 +- tianshou/data/__init__.py | 6 +- tianshou/data/batch.py | 19 ++-- tianshou/data/buffer/prio.py | 2 +- tianshou/env/venvs.py | 4 +- tianshou/exploration/random.py | 4 +- tianshou/policy/__init__.py | 6 +- tianshou/policy/base.py | 4 +- tianshou/policy/imitation/discrete_bcq.py | 4 +- tianshou/policy/modelfree/discrete_sac.py | 2 +- tianshou/policy/modelfree/sac.py | 2 +- tianshou/trainer/__init__.py | 5 +- tianshou/trainer/offpolicy.py | 2 +- tianshou/trainer/onpolicy.py | 4 +- tianshou/utils/net/common.py | 6 +- tianshou/utils/statistics.py | 2 +- 21 files changed, 231 insertions(+), 45 deletions(-) create mode 100644 Makefile create mode 100644 docs/spelling_wordlist.txt diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..a5a0f5d29 --- /dev/null +++ b/Makefile @@ -0,0 +1,60 @@ +SHELL=/bin/bash +PROJECT_NAME=tianshou +PROJECT_PATH=${PROJECT_NAME}/ +LINT_PATHS=${PROJECT_PATH} test/ docs/conf.py examples/ setup.py + +check_install = python3 -c "import $(1)" || pip3 install $(1) --upgrade +check_install_extra = python3 -c "import $(1)" || pip3 install $(2) --upgrade + +pytest: + $(call check_install, pytest) + $(call check_install, pytest_cov) + $(call check_install, pytest_xdist) + pytest tests --cov ${PROJECT_PATH} --durations 0 -v --cov-report term-missing + +mypy: + $(call check_install, mypy) + mypy ${PROJECT_NAME} + +lint: + $(call check_install, flake8) + $(call check_install_extra, bugbear, flake8_bugbear) + flake8 ${LINT_PATHS} --count --show-source --statistics + +format: + # sort imports + $(call check_install, isort) + isort ${LINT_PATHS} + # reformat using yapf + $(call check_install, yapf) + yapf -ir ${LINT_PATHS} + +check-codestyle: + $(call check_install, isort) + $(call check_install, yapf) + isort --check ${LINT_PATHS} && yapf -r -d ${LINT_PATHS} + +check-docstyle: + $(call check_install, pydocstyle) + $(call check_install, doc8) + $(call check_install, sphinx) + $(call check_install, sphinx_rtd_theme) + pydocstyle ${PROJECT_PATH} --convention=google && doc8 docs && cd docs && make html SPHINXOPTS="-W" + +doc: + $(call check_install, sphinx) + $(call check_install, sphinx_rtd_theme) + cd docs && make html && cd _build/html && python3 -m http.server + +spelling: + $(call check_install, sphinx) + $(call check_install, sphinx_rtd_theme) + $(call check_install_extra, sphinxcontrib.spelling, sphinxcontrib.spelling pyenchant) + cd docs && make spelling SPHINXOPTS="-W" + +clean: + cd docs && make clean + +commit-checks: format lint mypy check-docstyle spelling + +.PHONY: clean spelling doc mypy lint format 
check-codestyle check-docstyle commit-checks diff --git a/docs/conf.py b/docs/conf.py index bcf2a9b1c..57d8b48df 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -50,7 +50,7 @@ # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] -source_suffix = [".rst", ".md"] +source_suffix = [".rst"] master_doc = "index" # List of patterns, relative to source directory, that match files and diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt new file mode 100644 index 000000000..2a5c0676e --- /dev/null +++ b/docs/spelling_wordlist.txt @@ -0,0 +1,130 @@ +tianshou +arXiv +tanh +lr +logits +env +envs +optim +eps +timelimit +TimeLimit +maxsize +timestep +numpy +ndarray +stackoverflow +len +tac +fqf +iqn +qrdqn +rl +quantile +quantiles +dqn +param +async +subprocess +nn +equ +cql +fn +boolean +pre +np +rnn +rew +pre +perceptron +bsz +dataset +mujoco +jit +nstep +preprocess +repo +ReLU +namespace +th +utils +NaN +linesearch +hyperparameters +pseudocode +entropies +nn +config +cpu +rms +debias +indice +regularizer +miniblock +modularize +serializable +softmax +vectorized +optimizers +undiscounted +submodule +subclasses +submodules +tfevent +dirichlet +webpages +docstrings +num +py +pythonic +中文文档位于 +conda +miniconda +Amir +Andreas +Antonoglou +Beattie +Bellemare +Charles +Daan +Demis +Dharshan +Fidjeland +Georg +Hassabis +Helen +Ioannis +Kavukcuoglu +King +Koray +Kumaran +Legg +Mnih +Ostrovski +Petersen +Riedmiller +Rusu +Sadik +Shane +Stig +Veness +Volodymyr +Wierstra +Lillicrap +Pritzel +Heess +Erez +Yuval +Tassa +Schulman +Filip +Wolski +Prafulla +Dhariwal +Radford +Oleg +Klimov +Kaichao +Strens +Ornstein +Uhlenbeck diff --git a/docs/tutorials/batch.rst b/docs/tutorials/batch.rst index 49d913d9f..71f82f84e 100644 --- a/docs/tutorials/batch.rst +++ b/docs/tutorials/batch.rst @@ -60,7 +60,7 @@ The content of ``Batch`` objects can be defined by the following rules. 2. The keys are always strings (they are names of corresponding values). -3. The values can be scalars, tensors, or Batch objects. The recurse definition makes it possible to form a hierarchy of batches. +3. The values can be scalars, tensors, or Batch objects. The recursive definition makes it possible to form a hierarchy of batches. 4. Tensors are the most important values. In short, tensors are n-dimensional arrays of the same data type. We support two types of tensors: `PyTorch `_ tensor type ``torch.Tensor`` and `NumPy `_ tensor type ``np.ndarray``. @@ -348,7 +348,7 @@ The introduction of reserved keys gives rise to the need to check if a key is re
-The ``Batch.is_empty`` function has an option to decide whether to identify direct emptiness (just a ``Batch()``) or to identify recurse emptiness (a ``Batch`` object without any scalar/tensor leaf nodes). +The ``Batch.is_empty`` function has an option to decide whether to identify direct emptiness (just a ``Batch()``) or to identify recursive emptiness (a ``Batch`` object without any scalar/tensor leaf nodes). .. note:: @@ -492,7 +492,7 @@ Miscellaneous Notes
-2. It is often the case that the observations returned from the environment are NumPy ndarrays but the policy requires ``torch.Tensor`` for prediction and learning. In this regard, Tianshou provides helper functions to convert the stored data in-place into Numpy arrays or Torch tensors. +2. It is often the case that the observations returned from the environment are all NumPy ndarray but the policy requires ``torch.Tensor`` for prediction and learning. In this regard, Tianshou provides helper functions to convert the stored data in-place into Numpy arrays or Torch tensors. 3. ``obj.stack_([a, b])`` is the same as ``Batch.stack([obj, a, b])``, and ``obj.cat_([a, b])`` is the same as ``Batch.cat([obj, a, b])``. Considering the frequent requirement of concatenating two ``Batch`` objects, Tianshou also supports ``obj.cat_(a)`` to be an alias of ``obj.cat_([a])``. diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index 38a0291f4..c224b193b 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -341,7 +341,7 @@ With the flexible core APIs, Tianshou can support multi-agent reinforcement lear Currently, we support three types of multi-agent reinforcement learning paradigms: -1. Simultaneous move: at each timestep, all the agents take their actions (example: moba games) +1. Simultaneous move: at each timestep, all the agents take their actions (example: MOBA games) 2. Cyclic move: players take action in turn (example: Go game) @@ -371,4 +371,4 @@ By constructing a new state ``state_ = (state, agent_id, mask)``, essentially we action = policy(state_) next_state_, reward = env.step(action) -Following this idea, we write a tiny example of playing `Tic Tac Toe `_ against a random player by using a Q-lerning algorithm. The tutorial is at :doc:`/tutorials/tictactoe`. +Following this idea, we write a tiny example of playing `Tic Tac Toe `_ against a random player by using a Q-learning algorithm. The tutorial is at :doc:`/tutorials/tictactoe`. diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index 7222d9116..b1f76deb7 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -219,7 +219,7 @@ Tianshou provides other type of data buffer such as :class:`~tianshou.data.Prior Policy ------ -Tianshou aims to modularizing RL algorithms. It comes into several classes of policies in Tianshou. All of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`. +Tianshou aims to modularize RL algorithms. It comes into several classes of policies in Tianshou. All of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`. A policy class typically has the following parts: diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index 1466587f9..89250d009 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -1,7 +1,5 @@ -"""Data package. 
- -isort:skip_file -""" +"""Data package.""" +# isort:skip_file from tianshou.data.batch import Batch from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index accc0c8c2..98adf680b 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -53,7 +53,7 @@ def _to_array_with_correct_type(v: Any) -> np.ndarray: return v # most often case # convert the value to np.ndarray # convert to object data type if neither bool nor number - # raises an exception if array's elements are tensors themself + # raises an exception if array's elements are tensors themselves v = np.asanyarray(v) if not issubclass(v.dtype.type, (np.bool_, np.number)): v = v.astype(object) @@ -72,9 +72,11 @@ def _to_array_with_correct_type(v: Any) -> np.ndarray: return v -def _create_value(inst: Any, - size: int, - stack: bool = True) -> Union["Batch", np.ndarray, torch.Tensor]: +def _create_value( + inst: Any, + size: int, + stack: bool = True, +) -> Union["Batch", np.ndarray, torch.Tensor]: """Create empty place-holders accroding to inst's shape. :param bool stack: whether to stack or to concatenate. E.g. if inst has shape of @@ -167,11 +169,10 @@ def _alloc_by_keys_diff( class Batch: """The internal data structure in Tianshou. - Batch is a kind of supercharged array (of temporal data) stored - individually in a (recursive) dictionary of object that can be either numpy - array, torch tensor, or batch themself. It is designed to make it extremely - easily to access, manipulate and set partial view of the heterogeneous data - conveniently. + Batch is a kind of supercharged array (of temporal data) stored individually in a + (recursive) dictionary of object that can be either numpy array, torch tensor, or + batch themselves. It is designed to make it extremely easily to access, manipulate + and set partial view of the heterogeneous data conveniently. For a detailed description, please refer to :ref:`batch_concept`. """ diff --git a/tianshou/data/buffer/prio.py b/tianshou/data/buffer/prio.py index c4d48be10..fa3c49be8 100644 --- a/tianshou/data/buffer/prio.py +++ b/tianshou/data/buffer/prio.py @@ -66,7 +66,7 @@ def sample_indices(self, batch_size: int) -> np.ndarray: def get_weight(self, index: Union[int, np.ndarray]) -> Union[float, np.ndarray]: """Get the importance sampling weight. - The "weight" in the returned Batch is the weight on loss function to de-bias + The "weight" in the returned Batch is the weight on loss function to debias the sampling process (some transition tuples are sampled more often so their losses are weighted less). """ diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 14918759b..654f55b69 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -49,7 +49,7 @@ def seed(self, seed): Otherwise, the outputs of these envs may be the same with each other. - :param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith env. + :param env_fns: a list of callable envs, ``env_fns[i]()`` generates the i-th env. :param worker_fn: a callable worker, ``worker_fn(env_fns[i])`` generates a worker which contains the i-th env. :param int wait_num: use in asynchronous simulation if the time cost of @@ -61,7 +61,7 @@ def seed(self, seed): :param float timeout: use in asynchronous simulation same as above, in each vectorized step it only deal with those environments spending time within ``timeout`` seconds. 
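The wait_num and timeout parameters documented above switch the vectorized env into its asynchronous stepping mode. A hedged sketch of enabling it (the numbers are arbitrary):

    import gym
    from tianshou.env import SubprocVectorEnv

    env_fns = [lambda: gym.make("CartPole-v0") for _ in range(8)]
    venv = SubprocVectorEnv(env_fns, wait_num=4, timeout=0.2)
    # roughly: a vectorized step returns once 4 environments have finished,
    # or once the 0.2 s budget is exhausted, whichever comes first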
- :param bool norm_obs: Whether to track mean/std of data and normalise observation + :param bool norm_obs: Whether to track mean/std of data and normalize observation on return. For now, observation normalization only support observation of type np.ndarray. :param obs_rms: class to track mean&std of observation. If not given, it will diff --git a/tianshou/exploration/random.py b/tianshou/exploration/random.py index 03f863873..25316e98d 100644 --- a/tianshou/exploration/random.py +++ b/tianshou/exploration/random.py @@ -21,7 +21,7 @@ def __call__(self, size: Sequence[int]) -> np.ndarray: class GaussianNoise(BaseNoise): - """The vanilla gaussian process, for exploration in DDPG by default.""" + """The vanilla Gaussian process, for exploration in DDPG by default.""" def __init__(self, mu: float = 0.0, sigma: float = 1.0) -> None: super().__init__() @@ -46,7 +46,7 @@ class OUNoise(BaseNoise): For required parameters, you can refer to the stackoverflow page. However, our experiment result shows that (similar to OpenAI SpinningUp) using - vanilla gaussian process has little difference from using the + vanilla Gaussian process has little difference from using the Ornstein-Uhlenbeck process. """ diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 552422cab..6a842356f 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -1,7 +1,5 @@ -"""Policy package. - -isort:skip_file -""" +"""Policy package.""" +# isort:skip_file from tianshou.policy.base import BasePolicy from tianshou.policy.random import RandomPolicy diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 09de4f68d..feb6479ce 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -14,7 +14,7 @@ class BasePolicy(ABC, nn.Module): """The base class for any RL policy. - Tianshou aims to modularizing RL algorithms. It comes into several classes of + Tianshou aims to modularize RL algorithms. It comes into several classes of policies in Tianshou. All of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`. @@ -285,7 +285,7 @@ def compute_episodic_return( :param Batch batch: a data batch which contains several episodes of data in sequential order. Mind that the end of each finished episode of batch should be marked by done flag, unfinished (or collecting) episodes will be - recongized by buffer.unfinished_index(). + recognized by buffer.unfinished_index(). :param numpy.ndarray indices: tell batch's location in buffer, batch is equal to buffer[indices]. :param np.ndarray v_s_: the value function of all next states :math:`V(s')`. diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 135fe8d84..d9cac65df 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -15,7 +15,7 @@ class DiscreteBCQPolicy(DQNPolicy): :param torch.nn.Module model: a model following the rules in :class:`~tianshou.policy.BasePolicy`. (s -> q_value) :param torch.nn.Module imitator: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> imtation_logits) + :class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits) :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. :param float discount_factor: in [0, 1]. :param int estimation_step: the number of steps to look ahead. Default to 1. @@ -23,7 +23,7 @@ class DiscreteBCQPolicy(DQNPolicy): :param float eval_eps: the epsilon-greedy noise added in evaluation. 
:param float unlikely_action_threshold: the threshold (tau) for unlikely actions, as shown in Equ. (17) in the paper. Default to 0.3. - :param float imitation_logits_penalty: reguralization weight for imitation + :param float imitation_logits_penalty: regularization weight for imitation logits. Default to 1e-2. :param bool reward_normalization: normalize the reward to Normal(0, 1). Default to False. diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 5f087e5b9..7c580f37a 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -25,7 +25,7 @@ class DiscreteSACPolicy(SACPolicy): :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy regularization coefficient. Default to 0.2. If a tuple (target_entropy, log_alpha, alpha_optim) is provided, the - alpha is automatatically tuned. + alpha is automatically tuned. :param bool reward_normalization: normalize the reward to Normal(0, 1). Default to False. diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 13702e88f..2657a1eee 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -27,7 +27,7 @@ class SACPolicy(DDPGPolicy): :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy regularization coefficient. Default to 0.2. If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then - alpha is automatatically tuned. + alpha is automatically tuned. :param bool reward_normalization: normalize the reward to Normal(0, 1). Default to False. :param BaseNoise exploration_noise: add a noise to action for exploration. diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index 9ba22887e..11b3a95ef 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -1,7 +1,6 @@ -"""Trainer package. +"""Trainer package.""" -isort:skip_file -""" +# isort:skip_file from tianshou.trainer.utils import test_episode, gather_info from tianshou.trainer.onpolicy import onpolicy_trainer diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index e9fe91ea4..922646197 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -44,7 +44,7 @@ def offpolicy_trainer( :param int step_per_epoch: the number of transitions collected per epoch. :param int step_per_collect: the number of transitions the collector would collect before the network update, i.e., trainer will collect "step_per_collect" - transitions and do some policy network update repeatly in each epoch. + transitions and do some policy network update repeatedly in each epoch. :param episode_per_test: the number of episodes for one policy evaluation. :param int batch_size: the batch size of sample data, which is going to feed in the policy network. diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index fd845c506..6788e8a65 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -51,10 +51,10 @@ def onpolicy_trainer( policy network. :param int step_per_collect: the number of transitions the collector would collect before the network update, i.e., trainer will collect "step_per_collect" - transitions and do some policy network update repeatly in each epoch. + transitions and do some policy network update repeatedly in each epoch. 
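To make the automatically tuned alpha concrete, a small sketch of the tuple form described above; the Pendulum task and the learning rate are arbitrary choices for illustration:

.. code-block:: python

    import gym
    import numpy as np
    import torch

    env = gym.make("Pendulum-v0")
    # a tuple (target_entropy, log_alpha, alpha_optim) enables automatic tuning
    target_entropy = -np.prod(env.action_space.shape)
    log_alpha = torch.zeros(1, requires_grad=True)
    alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
    alpha = (target_entropy, log_alpha, alpha_optim)  # e.g. SACPolicy(..., alpha=alpha)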
:param int episode_per_collect: the number of episodes the collector would collect before the network update, i.e., trainer will collect "episode_per_collect" - episodes and do some policy network update repeatly in each epoch. + episodes and do some policy network update repeatedly in each epoch. :param function train_fn: a hook called at the beginning of training in each epoch. It can be used to perform custom additional operations, with the signature ``f( num_epoch: int, step_idx: int) -> None``. diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 1c8cf7d0f..b518a54a7 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -33,7 +33,7 @@ class MLP(nn.Module): :param int input_dim: dimension of the input vector. :param int output_dim: dimension of the output vector. If set to 0, there is no final linear layer. - :param hidden_sizes: shape of MLP passed in as a list, not incluing + :param hidden_sizes: shape of MLP passed in as a list, not including input_dim and output_dim. :param norm_layer: use which normalization before activation, e.g., ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization. @@ -41,7 +41,7 @@ class MLP(nn.Module): of hidden_sizes, to use different normalization module in different layers. Default to no normalization. :param activation: which activation to use after each layer, can be both - the same actvition for all layers if passed in nn.Module, or different + the same activation for all layers if passed in nn.Module, or different activation for different Modules if passed in a list. Default to nn.ReLU. :param device: which device to create this model on. Default to None. @@ -107,7 +107,7 @@ class Net(nn.Module): of hidden_sizes, to use different normalization module in different layers. Default to no normalization. :param activation: which activation to use after each layer, can be both - the same actvition for all layers if passed in nn.Module, or different + the same activation for all layers if passed in nn.Module, or different activation for different Modules if passed in a list. Default to nn.ReLU. :param device: specify the device when the network actually runs. Default diff --git a/tianshou/utils/statistics.py b/tianshou/utils/statistics.py index 5c895303b..a81af601e 100644 --- a/tianshou/utils/statistics.py +++ b/tianshou/utils/statistics.py @@ -67,7 +67,7 @@ def std(self) -> float: class RunningMeanStd(object): - """Calulates the running mean and std of a data stream. + """Calculates the running mean and std of a data stream. 
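A brief construction sketch for the ``MLP``-based ``Net`` documented above; the 4-dimensional observation and 2 discrete actions are assumed, CartPole-like, purely for illustration:

.. code-block:: python

    import numpy as np
    import torch

    from tianshou.utils.net.common import Net

    net = Net(
        state_shape=4, action_shape=2, hidden_sizes=[128, 128], activation=torch.nn.ReLU
    )
    logits, state = net(np.zeros((1, 4), dtype=np.float32))  # state is None here
    assert logits.shape == (1, 2)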
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm """ From 07466c067a525ab6eedfa13a2d3a7f0fb54b2c20 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Thu, 2 Sep 2021 16:12:16 -0400 Subject: [PATCH 08/11] finish --- .github/PULL_REQUEST_TEMPLATE.md | 13 +++------- Makefile | 2 +- docs/contributing.rst | 41 +++++++++++++++++++------------- 3 files changed, 29 insertions(+), 27 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index d72600676..280583538 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -3,14 +3,7 @@ + [ ] algorithm implementation fix + [ ] documentation modification + [ ] new feature +- [ ] I have reformatted the code using `make format` (**required**) +- [ ] I have checked the code using `make commit-checks` (**required**) - [ ] If applicable, I have mentioned the relevant/related issue(s) - -Less important but also useful: - -- [ ] I have visited the [source website](https://github.com/thu-ml/tianshou) -- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates -- [ ] I have mentioned version numbers, operating system and environment, where applicable: - ```python - import tianshou, torch, numpy, sys - print(tianshou.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform) - ``` +- [ ] If applicable, I have listed every items in this Pull Request below diff --git a/Makefile b/Makefile index a5a0f5d29..5e623e04f 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ pytest: $(call check_install, pytest) $(call check_install, pytest_cov) $(call check_install, pytest_xdist) - pytest tests --cov ${PROJECT_PATH} --durations 0 -v --cov-report term-missing + pytest test --cov ${PROJECT_PATH} --durations 0 -v --cov-report term-missing mypy: $(call check_install, mypy) diff --git a/docs/contributing.rst b/docs/contributing.rst index cf015a95e..09319a9ba 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -25,7 +25,23 @@ We follow PEP8 python code style. To check, in the main directory, run: .. code-block:: bash - $ flake8 . --count --show-source --statistics + $ make lint + + +Code Formatter +-------------- + +We use isort and yapf to format all codes. To format, in the main directory, run: + +.. code-block:: bash + + $ make format + +To check if formatted correctly, in the main directory, run: + +.. code-block:: bash + + $ make check-codestyle Type Check @@ -35,7 +51,7 @@ We use `mypy `_ to check the type annotations. .. code-block:: bash - $ mypy + $ make mypy Test Locally @@ -45,7 +61,7 @@ This command will run automatic tests in the main directory .. code-block:: bash - $ pytest test --cov tianshou -s --durations 0 -v + $ make pytest Test by GitHub Actions @@ -80,9 +96,9 @@ To compile documentation into webpages, run .. code-block:: bash - $ make html + $ make doc -under the ``docs/`` directory. The generated webpages are in ``docs/_build`` and can be viewed with browsers. +The generated webpages are in ``docs/_build`` and can be viewed with browsers. Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/. @@ -92,21 +108,14 @@ Documentation Generation Test We have the following three documentation tests: -1. pydocstyle: test docstrings under ``tianshou/``. To check, in the main directory, run: - -.. code-block:: bash - - $ pydocstyle tianshou +1. pydocstyle: test docstrings under ``tianshou/``. 2. doc8: test ReStructuredText formats. 
To check, in the main directory, run: -.. code-block:: bash - - $ doc8 docs - 3. sphinx test: test if there is any errors/warnings when generating front-end html documentations. To check, in the main directory, run: +To check, in the main directory, run: + .. code-block:: bash - $ cd docs - $ make html SPHINXOPTS="-W" + $ make check-docstyle From e4ac4613c4754e1d0b011772e6e60adcb0ea0079 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Thu, 2 Sep 2021 16:18:44 -0400 Subject: [PATCH 09/11] fix ci --- docs/contributing.rst | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/docs/contributing.rst b/docs/contributing.rst index 09319a9ba..6062cf025 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -18,19 +18,15 @@ in the main directory. This installation is removable by $ python setup.py develop --uninstall -PEP8 Code Style Check ---------------------- +PEP8 Code Style Check and Code Formatter +---------------------------------------- -We follow PEP8 python code style. To check, in the main directory, run: +We follow PEP8 python code style with flake8. To check, in the main directory, run: .. code-block:: bash $ make lint - -Code Formatter --------------- - We use isort and yapf to format all codes. To format, in the main directory, run: .. code-block:: bash @@ -108,11 +104,11 @@ Documentation Generation Test We have the following three documentation tests: -1. pydocstyle: test docstrings under ``tianshou/``. +1. pydocstyle: test docstrings under ``tianshou/``; -2. doc8: test ReStructuredText formats. To check, in the main directory, run: +2. doc8: test ReStructuredText formats; -3. sphinx test: test if there is any errors/warnings when generating front-end html documentations. To check, in the main directory, run: +3. sphinx test: test if there is any errors/warnings when generating front-end html documentations. To check, in the main directory, run: From b5d967170afd7f3c7e65e71408fd8209dd95d9ab Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Thu, 2 Sep 2021 16:30:57 -0400 Subject: [PATCH 10/11] test if make spelling work --- .github/workflows/lint_and_docs.yml | 7 ++----- Makefile | 2 +- docs/contributing.rst | 10 +++++----- docs/spelling_wordlist.txt | 5 +++-- 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/.github/workflows/lint_and_docs.yml b/.github/workflows/lint_and_docs.yml index 89daaaa6b..681654689 100644 --- a/.github/workflows/lint_and_docs.yml +++ b/.github/workflows/lint_and_docs.yml @@ -29,8 +29,5 @@ jobs: mypy - name: Documentation test run: | - pydocstyle tianshou - doc8 docs --max-line-length 1000 - cd docs - make html SPHINXOPTS="-W" - cd .. + make check-docstyle + make spelling diff --git a/Makefile b/Makefile index 5e623e04f..da5030ccd 100644 --- a/Makefile +++ b/Makefile @@ -39,7 +39,7 @@ check-docstyle: $(call check_install, doc8) $(call check_install, sphinx) $(call check_install, sphinx_rtd_theme) - pydocstyle ${PROJECT_PATH} --convention=google && doc8 docs && cd docs && make html SPHINXOPTS="-W" + pydocstyle ${PROJECT_PATH} && doc8 docs && cd docs && make html SPHINXOPTS="-W" doc: $(call check_install, sphinx) diff --git a/docs/contributing.rst b/docs/contributing.rst index 6062cf025..d1de0b65b 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -88,13 +88,13 @@ Documentations are written under the ``docs/`` directory as ReStructuredText (`` API References are automatically generated by `Sphinx `_ according to the outlines under ``docs/api/`` and should be modified when any code changes. 
-To compile documentation into webpages, run +To compile documentation into webpage, run .. code-block:: bash $ make doc -The generated webpages are in ``docs/_build`` and can be viewed with browsers. +The generated webpage is in ``docs/_build`` and can be viewed with browser (http://0.0.0.0:8000/). Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/. @@ -104,11 +104,11 @@ Documentation Generation Test We have the following three documentation tests: -1. pydocstyle: test docstrings under ``tianshou/``; +1. pydocstyle: test all docstring under ``tianshou/``; -2. doc8: test ReStructuredText formats; +2. doc8: test ReStructuredText format; -3. sphinx test: test if there is any errors/warnings when generating front-end html documentations. +3. sphinx test: test if there is any error/warning when generating front-end html documentation. To check, in the main directory, run: diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 2a5c0676e..34e84d01d 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -72,8 +72,9 @@ subclasses submodules tfevent dirichlet -webpages -docstrings +docstring +webpage +formatter num py pythonic From 83cce25cee02477bacc5c4d07317e5d957f0ec00 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Thu, 2 Sep 2021 16:36:14 -0400 Subject: [PATCH 11/11] fix spelling error --- docs/contributor.rst | 1 - docs/spelling_wordlist.txt | 4 ++++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/contributor.rst b/docs/contributor.rst index c594b2c0d..d71dc385b 100644 --- a/docs/contributor.rst +++ b/docs/contributor.rst @@ -4,7 +4,6 @@ Contributor We always welcome contributions to help make Tianshou better. Below are an incomplete list of our contributors (find more on `this page `_). * Jiayi Weng (`Trinkle23897 `_) -* Minghao Zhang (`Mehooz `_) * Alexis Duburcq (`duburcqa `_) * Kaichao You (`youkaichao `_) * Huayu Chen (`ChenDRAG `_) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 34e84d01d..3649df71d 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -126,6 +126,10 @@ Radford Oleg Klimov Kaichao +Jiayi +Weng +Duburcq +Huayu Strens Ornstein Uhlenbeck
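As a closing illustration of the isort plus yapf combination this series standardizes on, a sketch of running both formatters programmatically on a snippet; the sample source string and the default styles are assumptions, and the ``make format`` target above remains the normal entry point:

.. code-block:: python

    import isort
    from yapf.yapflib.yapf_api import FormatCode

    source = "import torch\nimport numpy as np\nx = {  'a':1 }\n"
    source = isort.code(source)           # sort and group the imports
    source, changed = FormatCode(source)  # reformat with yapf's default style
    print(source)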