From 971b4b739b3eb9a92f73d7951b80185b0d0a3247 Mon Sep 17 00:00:00 2001 From: Andriy Drozdyuk Date: Thu, 26 Aug 2021 19:36:22 -0400 Subject: [PATCH 1/5] Flake fixed --- README.md | 2 +- docs/tutorials/cheatsheet.rst | 4 +- docs/tutorials/dqn.rst | 4 +- docs/tutorials/tictactoe.rst | 4 +- examples/atari/atari_bcq.py | 4 +- examples/atari/atari_c51.py | 4 +- examples/atari/atari_cql.py | 4 +- examples/atari/atari_crr.py | 4 +- examples/atari/atari_dqn.py | 4 +- examples/atari/atari_fqf.py | 4 +- examples/atari/atari_iqn.py | 4 +- examples/atari/atari_qrdqn.py | 4 +- examples/box2d/acrobot_dualdqn.py | 4 +- examples/box2d/bipedal_hardcore_sac.py | 4 +- examples/box2d/lunarlander_dqn.py | 4 +- examples/box2d/mcc_sac.py | 4 +- examples/mujoco/mujoco_a2c.py | 4 +- examples/mujoco/mujoco_ddpg.py | 4 +- examples/mujoco/mujoco_npg.py | 4 +- examples/mujoco/mujoco_ppo.py | 4 +- examples/mujoco/mujoco_reinforce.py | 4 +- examples/mujoco/mujoco_sac.py | 4 +- examples/mujoco/mujoco_td3.py | 4 +- examples/mujoco/mujoco_trpo.py | 4 +- examples/vizdoom/vizdoom_c51.py | 4 +- test/continuous/test_ddpg.py | 4 +- test/continuous/test_npg.py | 4 +- test/continuous/test_ppo.py | 4 +- test/continuous/test_sac_with_il.py | 4 +- test/continuous/test_td3.py | 4 +- test/continuous/test_trpo.py | 4 +- test/discrete/test_a2c_with_il.py | 4 +- test/discrete/test_c51.py | 4 +- test/discrete/test_dqn.py | 4 +- test/discrete/test_drqn.py | 4 +- test/discrete/test_fqf.py | 4 +- test/discrete/test_il_bcq.py | 4 +- test/discrete/test_il_crr.py | 4 +- test/discrete/test_iqn.py | 4 +- test/discrete/test_pg.py | 4 +- test/discrete/test_ppo.py | 4 +- test/discrete/test_qrdqn.py | 4 +- test/discrete/test_qrdqn_il_cql.py | 4 +- test/discrete/test_sac.py | 4 +- test/modelbased/test_psrl.py | 2 - test/multiagent/Gomoku.py | 4 +- test/multiagent/tic_tac_toe.py | 4 +- tianshou/utils/__init__.py | 6 +- tianshou/utils/logger/__init__.py | 10 ++ tianshou/utils/logger/base.py | 90 +++++++++++++++++ .../{log_tools.py => logger/tensorboard.py} | 96 +------------------ tianshou/utils/logger/wandb.py | 86 +++++++++++++++++ 52 files changed, 285 insertions(+), 187 deletions(-) create mode 100644 tianshou/utils/logger/__init__.py create mode 100644 tianshou/utils/logger/base.py rename tianshou/utils/{log_tools.py => logger/tensorboard.py} (60%) create mode 100644 tianshou/utils/logger/wandb.py diff --git a/README.md b/README.md index ab9e43fd0..4c680232c 100644 --- a/README.md +++ b/README.md @@ -191,7 +191,7 @@ buffer_size = 20000 eps_train, eps_test = 0.1, 0.05 step_per_epoch, step_per_collect = 10000, 10 writer = SummaryWriter('log/dqn') # tensorboard is also supported! -logger = ts.utils.BasicLogger(writer) +logger = ts.utils.TensorboardLogger(writer) ``` Make environments: diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index 1fd59ab2a..38a0291f4 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -40,8 +40,8 @@ This is related to `Issue 349 `_. To resume training process from an existing checkpoint, you need to do the following things in the training process: 1. Make sure you write ``save_checkpoint_fn`` which saves everything needed in the training process, i.e., policy, optim, buffer; pass it to trainer; -2. Use ``BasicLogger`` which contains a tensorboard; -3. To adjust the save frequency, specify ``save_interval`` when initializing BasicLogger. +2. Use ``TensorboardLogger``; +3. To adjust the save frequency, specify ``save_interval`` when initializing TensorboardLogger. 
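For example, a minimal sketch of this checkpointing setup could look as follows (the log directory, the save interval, and the ``policy``/``optim`` objects are placeholders standing in for whatever the training script already defines)::

    import os
    import torch
    from torch.utils.tensorboard import SummaryWriter
    from tianshou.utils import TensorboardLogger

    log_path = 'log/dqn'  # placeholder directory
    writer = SummaryWriter(log_path)
    # save a checkpoint every 4 epochs instead of every epoch
    logger = TensorboardLogger(writer, save_interval=4)

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        # save everything needed to resume training: policy and optimizer
        # (and, if desired, the replay buffer), assuming `policy` and `optim`
        # were created earlier in the script
        torch.save({
            'model': policy.state_dict(),
            'optim': optim.state_dict(),
        }, os.path.join(log_path, 'checkpoint.pth'))

    # pass both hooks to the trainer, e.g.
    # result = offpolicy_trainer(..., save_checkpoint_fn=save_checkpoint_fn,
    #                            logger=logger)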
And to successfully resume from a checkpoint: diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index 412051e49..0126c7fca 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -148,9 +148,9 @@ The trainer supports `TensorBoard `_ for :: from torch.utils.tensorboard import SummaryWriter - from tianshou.utils import BasicLogger + from tianshou.utils import TensorboardLogger writer = SummaryWriter('log/dqn') - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) Pass the logger into the trainer, and the training result will be recorded into the TensorBoard. diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst index 64b0dfd70..58e69906a 100644 --- a/docs/tutorials/tictactoe.rst +++ b/docs/tutorials/tictactoe.rst @@ -176,7 +176,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modul import numpy as np from copy import deepcopy from torch.utils.tensorboard import SummaryWriter - from tianshou.utils import BasicLogger + from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net @@ -323,7 +323,7 @@ With the above preparation, we are close to the first learned agent. The followi log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn') writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) # ======== callback functions used during training ========= diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py index 142edd816..b4cd5b62d 100644 --- a/examples/atari/atari_bcq.py +++ b/examples/atari/atari_bcq.py @@ -7,7 +7,7 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import SubprocVectorEnv from tianshou.trainer import offline_trainer from tianshou.utils.net.discrete import Actor @@ -116,7 +116,7 @@ def test_discrete_bcq(args=get_args()): f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}') writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = BasicLogger(writer, update_interval=args.log_interval) + logger = TensorboardLogger(writer, update_interval=args.log_interval) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 558284f47..b70cc6cc0 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -6,7 +6,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import C51Policy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, VectorReplayBuffer @@ -101,7 +101,7 @@ def test_c51(args=get_args()): log_path = os.path.join(args.logdir, args.task, 'c51') writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/examples/atari/atari_cql.py b/examples/atari/atari_cql.py index ff8650608..cbab82029 100644 --- a/examples/atari/atari_cql.py +++ b/examples/atari/atari_cql.py @@ -7,7 +7,7 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import BasicLogger +from 
tianshou.utils import TensorboardLogger from tianshou.env import SubprocVectorEnv from tianshou.trainer import offline_trainer from tianshou.policy import DiscreteCQLPolicy @@ -108,7 +108,7 @@ def test_discrete_cql(args=get_args()): f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}') writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = BasicLogger(writer, update_interval=args.log_interval) + logger = TensorboardLogger(writer, update_interval=args.log_interval) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/examples/atari/atari_crr.py b/examples/atari/atari_crr.py index 6bd91678e..e8e1ba54e 100644 --- a/examples/atari/atari_crr.py +++ b/examples/atari/atari_crr.py @@ -7,7 +7,7 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import SubprocVectorEnv from tianshou.trainer import offline_trainer from tianshou.utils.net.discrete import Actor @@ -117,7 +117,7 @@ def test_discrete_crr(args=get_args()): f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}') writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = BasicLogger(writer, update_interval=args.log_interval) + logger = TensorboardLogger(writer, update_interval=args.log_interval) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index a7785c8c6..69ec08349 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -6,7 +6,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DQNPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, VectorReplayBuffer @@ -96,7 +96,7 @@ def test_dqn(args=get_args()): log_path = os.path.join(args.logdir, args.task, 'dqn') writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 110a57167..4a6e97c06 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -6,7 +6,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import FQFPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, VectorReplayBuffer @@ -112,7 +112,7 @@ def test_fqf(args=get_args()): log_path = os.path.join(args.logdir, args.task, 'fqf') writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index ad34c8f96..e5966a318 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -6,7 +6,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import IQNPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from 
tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, VectorReplayBuffer @@ -109,7 +109,7 @@ def test_iqn(args=get_args()): log_path = os.path.join(args.logdir, args.task, 'iqn') writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 6677f6837..781f81d5d 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -5,7 +5,7 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.policy import QRDQNPolicy from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer @@ -99,7 +99,7 @@ def test_qrdqn(args=get_args()): log_path = os.path.join(args.logdir, args.task, 'qrdqn') writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 58f1a3783..4889f7eb2 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -7,7 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DQNPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer @@ -82,7 +82,7 @@ def test_dqn(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'dqn') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 5678277db..4caa50b94 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -7,7 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import SACPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer @@ -134,7 +134,7 @@ def test_sac_bipedal(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'sac') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index f5a9d3bdf..bb73ac615 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -7,7 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DQNPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, VectorReplayBuffer @@ -84,7 +84,7 @@ def test_dqn(args=get_args()): # log log_path = 
os.path.join(args.logdir, args.task, 'dqn') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 8fbc4257d..a43728be5 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -7,7 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import SACPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.exploration import OUNoise from tianshou.utils.net.common import Net @@ -104,7 +104,7 @@ def test_sac(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'sac') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index bf039a187..e9debc906 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -13,7 +13,7 @@ from torch.distributions import Independent, Normal from tianshou.policy import A2CPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer @@ -141,7 +141,7 @@ def dist(*logits): log_path = os.path.join(args.logdir, args.task, 'a2c', log_file) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = BasicLogger(writer, update_interval=100, train_interval=100) + logger = TensorboardLogger(writer, update_interval=100, train_interval=100) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index ebd590fd7..28bac056a 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -10,7 +10,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DDPGPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net from tianshou.exploration import GaussianNoise @@ -110,7 +110,7 @@ def test_ddpg(args=get_args()): log_path = os.path.join(args.logdir, args.task, 'ddpg', log_file) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 03941c51b..00b2a1a2c 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -13,7 +13,7 @@ from torch.distributions import Independent, Normal from tianshou.policy import NPGPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer @@ -142,7 +142,7 @@ def dist(*logits): log_path = os.path.join(args.logdir, args.task, 'npg', log_file) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = BasicLogger(writer, update_interval=100, train_interval=100) + logger = 
TensorboardLogger(writer, update_interval=100, train_interval=100) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 681f626a1..fb3e7a0a2 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -13,7 +13,7 @@ from torch.distributions import Independent, Normal from tianshou.policy import PPOPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer @@ -149,7 +149,7 @@ def dist(*logits): log_path = os.path.join(args.logdir, args.task, 'ppo', log_file) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = BasicLogger(writer, update_interval=100, train_interval=100) + logger = TensorboardLogger(writer, update_interval=100, train_interval=100) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index ac9682918..b7698562a 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -13,7 +13,7 @@ from torch.distributions import Independent, Normal from tianshou.policy import PGPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer @@ -131,7 +131,7 @@ def dist(*logits): log_path = os.path.join(args.logdir, args.task, 'reinforce', log_file) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = BasicLogger(writer, update_interval=10, train_interval=100) + logger = TensorboardLogger(writer, update_interval=10, train_interval=100) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 46d7ac56e..c2dfd3618 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -10,7 +10,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import SACPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer @@ -122,7 +122,7 @@ def test_sac(args=get_args()): log_path = os.path.join(args.logdir, args.task, 'sac', log_file) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 97b4e0a0c..9a0179899 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -10,7 +10,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import TD3Policy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net from tianshou.exploration import GaussianNoise @@ -123,7 +123,7 @@ def test_td3(args=get_args()): log_path = os.path.join(args.logdir, args.task, 'td3', log_file) writer = SummaryWriter(log_path) writer.add_text("args", 
str(args)) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index ad99069fa..b00f2e3d6 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -13,7 +13,7 @@ from torch.distributions import Independent, Normal from tianshou.policy import TRPOPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer @@ -147,7 +147,7 @@ def dist(*logits): log_path = os.path.join(args.logdir, args.task, 'trpo', log_file) writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = BasicLogger(writer, update_interval=100, train_interval=100) + logger = TensorboardLogger(writer, update_interval=100, train_interval=100) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 7391de087..1123151c8 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -6,7 +6,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import C51Policy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, VectorReplayBuffer @@ -101,7 +101,7 @@ def test_c51(args=get_args()): log_path = os.path.join(args.logdir, args.task, 'c51') writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 021cc0419..3dbeb9091 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -7,7 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DDPGPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer @@ -90,7 +90,7 @@ def test_ddpg(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'ddpg') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index cfba19f11..ad758897f 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -9,7 +9,7 @@ from torch.distributions import Independent, Normal from tianshou.policy import NPGPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer @@ -105,7 +105,7 @@ def dist(*logits): # log log_path = os.path.join(args.logdir, args.task, 'npg') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git 
a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 6ab4717d6..1077a332f 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -8,7 +8,7 @@ from torch.distributions import Independent, Normal from tianshou.policy import PPOPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer @@ -116,7 +116,7 @@ def dist(*logits): # log log_path = os.path.join(args.logdir, args.task, 'ppo') writer = SummaryWriter(log_path) - logger = BasicLogger(writer, save_interval=args.save_interval) + logger = TensorboardLogger(writer, save_interval=args.save_interval) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index ad0b9af66..5b6a79492 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -6,7 +6,7 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer @@ -109,7 +109,7 @@ def test_sac_with_il(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'sac') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index ee6fa11de..8bae1edfa 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -7,7 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import TD3Policy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer @@ -104,7 +104,7 @@ def test_td3(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'td3') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 8c8387773..65535fd50 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -9,7 +9,7 @@ from torch.distributions import Independent, Normal from tianshou.policy import TRPOPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer @@ -109,7 +109,7 @@ def dist(*logits): # log log_path = os.path.join(args.logdir, args.task, 'trpo') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 9714219e9..098231da8 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -6,7 +6,7 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter -from tianshou.utils 
import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.data import Collector, VectorReplayBuffer @@ -91,7 +91,7 @@ def test_a2c_with_il(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'a2c') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index a7fdd922a..3208e83c8 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -8,7 +8,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import C51Policy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer @@ -93,7 +93,7 @@ def test_c51(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'c51') writer = SummaryWriter(log_path) - logger = BasicLogger(writer, save_interval=args.save_interval) + logger = TensorboardLogger(writer, save_interval=args.save_interval) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index cb7fac403..aae609ec6 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -8,7 +8,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DQNPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer @@ -92,7 +92,7 @@ def test_dqn(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'dqn') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index c04cbc396..aa4fbbe0f 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -7,7 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DQNPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.trainer import offpolicy_trainer from tianshou.utils.net.common import Recurrent @@ -78,7 +78,7 @@ def test_drqn(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'drqn') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 534927f12..0df2efb74 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -7,7 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import FQFPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer @@ -100,7 +100,7 @@ def test_fqf(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'fqf') writer 
= SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py index 1ea2db433..f404ce7dc 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/discrete/test_il_bcq.py @@ -8,7 +8,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offline_trainer @@ -87,7 +87,7 @@ def test_discrete_bcq(args=get_args()): log_path = os.path.join(args.logdir, args.task, 'discrete_bcq') writer = SummaryWriter(log_path) - logger = BasicLogger(writer, save_interval=args.save_interval) + logger = TensorboardLogger(writer, save_interval=args.save_interval) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/test/discrete/test_il_crr.py b/test/discrete/test_il_crr.py index 736edef5f..858d2b6f7 100644 --- a/test/discrete/test_il_crr.py +++ b/test/discrete/test_il_crr.py @@ -8,7 +8,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offline_trainer @@ -80,7 +80,7 @@ def test_discrete_crr(args=get_args()): log_path = os.path.join(args.logdir, args.task, 'discrete_cql') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index f40407fb3..0234c36f0 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -7,7 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import IQNPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer @@ -96,7 +96,7 @@ def test_iqn(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'iqn') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 193117150..3c36bb265 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -7,7 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import PGPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer @@ -78,7 +78,7 @@ def test_pg(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'pg') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index ee63b9b2a..de2418dce 100644 --- 
a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -7,7 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import PPOPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer @@ -102,7 +102,7 @@ def test_ppo(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'ppo') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 27c6d658d..36efe2fe3 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -7,7 +7,7 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.policy import QRDQNPolicy from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net @@ -92,7 +92,7 @@ def test_qrdqn(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'qrdqn') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/test/discrete/test_qrdqn_il_cql.py b/test/discrete/test_qrdqn_il_cql.py index 7a782dd95..b9779d022 100644 --- a/test/discrete/test_qrdqn_il_cql.py +++ b/test/discrete/test_qrdqn_il_cql.py @@ -8,7 +8,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offline_trainer @@ -79,7 +79,7 @@ def test_discrete_cql(args=get_args()): log_path = os.path.join(args.logdir, args.task, 'discrete_cql') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 0cb2ae018..d8abf48e4 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -6,7 +6,7 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net from tianshou.policy import DiscreteSACPolicy @@ -97,7 +97,7 @@ def test_discrete_sac(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'discrete_sac') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index ffeb6b911..01710d827 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -7,7 +7,6 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import PSRLPolicy -# from tianshou.utils import BasicLogger from tianshou.trainer import onpolicy_trainer from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv 
@@ -71,7 +70,6 @@ def test_psrl(args=get_args()): log_path = os.path.join(args.logdir, args.task, 'psrl') writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - # logger = BasicLogger(writer) def stop_fn(mean_rewards): if env.spec.reward_threshold: diff --git a/test/multiagent/Gomoku.py b/test/multiagent/Gomoku.py index 4c88656cb..6418ee8ec 100644 --- a/test/multiagent/Gomoku.py +++ b/test/multiagent/Gomoku.py @@ -7,7 +7,7 @@ from tianshou.env import DummyVectorEnv from tianshou.data import Collector from tianshou.policy import RandomPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tic_tac_toe_env import TicTacToeEnv from tic_tac_toe import get_parser, get_agents, train_agent, watch @@ -33,7 +33,7 @@ def gomoku(args=get_args()): # log log_path = os.path.join(args.logdir, 'Gomoku', 'dqn') writer = SummaryWriter(log_path) - args.logger = BasicLogger(writer) + args.logger = TensorboardLogger(writer) opponent_pool = [agent_opponent] diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index dc4a443a3..dcd293a06 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -6,7 +6,7 @@ from typing import Optional, Tuple from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer @@ -135,7 +135,7 @@ def env_func(): log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn') writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): if hasattr(args, 'model_save_path'): diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py index ccd873233..b16e37383 100644 --- a/tianshou/utils/__init__.py +++ b/tianshou/utils/__init__.py @@ -1,12 +1,14 @@ from tianshou.utils.config import tqdm_config from tianshou.utils.statistics import MovAvg, RunningMeanStd -from tianshou.utils.log_tools import BasicLogger, LazyLogger, BaseLogger +from tianshou.utils.logger import TensorboardLogger, LazyLogger, BaseLogger, \ + WandBLogger __all__ = [ "MovAvg", "RunningMeanStd", "tqdm_config", "BaseLogger", - "BasicLogger", + "TensorboardLogger", "LazyLogger", + "WandBLogger" ] diff --git a/tianshou/utils/logger/__init__.py b/tianshou/utils/logger/__init__.py new file mode 100644 index 000000000..e97d47ef9 --- /dev/null +++ b/tianshou/utils/logger/__init__.py @@ -0,0 +1,10 @@ +from tianshou.utils.logger.base import BaseLogger, LazyLogger +from tianshou.utils.logger.tensorboard import TensorboardLogger +from tianshou.utils.logger.wandb import WandBLogger + +__all__ = [ + "BaseLogger", + "TensorboardLogger", + "WandBLogger", + "LazyLogger" +] diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py new file mode 100644 index 000000000..e11980443 --- /dev/null +++ b/tianshou/utils/logger/base.py @@ -0,0 +1,90 @@ +import numpy as np +from numbers import Number +from abc import ABC, abstractmethod +from typing import Any, Tuple, Union, Callable, Optional + + +WRITE_TYPE = Union[int, Number, np.number, np.ndarray] + + +class BaseLogger(ABC): + """The base class for any logger which is compatible with trainer.""" + + def __init__(self, writer: Any) -> None: + super().__init__() + self.writer = writer + + @abstractmethod + def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None: + 
"""Specify how the writer is used to log data. + + :param str key: namespace which the input data tuple belongs to. + :param int x: stands for the ordinate of the input data tuple. + :param y: stands for the abscissa of the input data tuple. + """ + pass + + def log_train_data(self, collect_result: dict, step: int) -> None: + """Use writer to log statistics generated during training. + + :param collect_result: a dict containing information of data collected in + training stage, i.e., returns of collector.collect(). + :param int step: stands for the timestep the collect_result being logged. + """ + pass + + def log_update_data(self, update_result: dict, step: int) -> None: + """Use writer to log statistics generated during updating. + + :param update_result: a dict containing information of data collected in + updating stage, i.e., returns of policy.update(). + :param int step: stands for the timestep the collect_result being logged. + """ + pass + + def log_test_data(self, collect_result: dict, step: int) -> None: + """Use writer to log statistics generated during evaluating. + + :param collect_result: a dict containing information of data collected in + evaluating stage, i.e., returns of collector.collect(). + :param int step: stands for the timestep the collect_result being logged. + """ + pass + + def save_data( + self, + epoch: int, + env_step: int, + gradient_step: int, + save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + ) -> None: + """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer. + + :param int epoch: the epoch in trainer. + :param int env_step: the env_step in trainer. + :param int gradient_step: the gradient_step in trainer. + :param function save_checkpoint_fn: a hook defined by user, see trainer + documentation for detail. + """ + pass + + def restore_data(self) -> Tuple[int, int, int]: + """Return the metadata from existing log. + + If it finds nothing or an error occurs during the recover process, it will + return the default parameters. + + :return: epoch, env_step, gradient_step. + """ + pass + + +class LazyLogger(BaseLogger): + """A logger that does nothing. Used as the placeholder in trainer.""" + + def __init__(self) -> None: + super().__init__(None) # type: ignore + + def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None: + """The LazyLogger writes nothing.""" + pass diff --git a/tianshou/utils/log_tools.py b/tianshou/utils/logger/tensorboard.py similarity index 60% rename from tianshou/utils/log_tools.py rename to tianshou/utils/logger/tensorboard.py index 5a25f1394..4279573ec 100644 --- a/tianshou/utils/log_tools.py +++ b/tianshou/utils/logger/tensorboard.py @@ -1,88 +1,11 @@ -import numpy as np -from numbers import Number -from abc import ABC, abstractmethod from torch.utils.tensorboard import SummaryWriter -from typing import Any, Tuple, Union, Callable, Optional +from typing import Any, Tuple, Callable, Optional from tensorboard.backend.event_processing import event_accumulator +from tianshou.utils.logger.base import BaseLogger, WRITE_TYPE -WRITE_TYPE = Union[int, Number, np.number, np.ndarray] - - -class BaseLogger(ABC): - """The base class for any logger which is compatible with trainer.""" - - def __init__(self, writer: Any) -> None: - super().__init__() - self.writer = writer - - @abstractmethod - def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None: - """Specify how the writer is used to log data. - - :param str key: namespace which the input data tuple belongs to. 
- :param int x: stands for the ordinate of the input data tuple. - :param y: stands for the abscissa of the input data tuple. - """ - pass - - def log_train_data(self, collect_result: dict, step: int) -> None: - """Use writer to log statistics generated during training. - - :param collect_result: a dict containing information of data collected in - training stage, i.e., returns of collector.collect(). - :param int step: stands for the timestep the collect_result being logged. - """ - pass - - def log_update_data(self, update_result: dict, step: int) -> None: - """Use writer to log statistics generated during updating. - - :param update_result: a dict containing information of data collected in - updating stage, i.e., returns of policy.update(). - :param int step: stands for the timestep the collect_result being logged. - """ - pass - - def log_test_data(self, collect_result: dict, step: int) -> None: - """Use writer to log statistics generated during evaluating. - - :param collect_result: a dict containing information of data collected in - evaluating stage, i.e., returns of collector.collect(). - :param int step: stands for the timestep the collect_result being logged. - """ - pass - - def save_data( - self, - epoch: int, - env_step: int, - gradient_step: int, - save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, - ) -> None: - """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer. - - :param int epoch: the epoch in trainer. - :param int env_step: the env_step in trainer. - :param int gradient_step: the gradient_step in trainer. - :param function save_checkpoint_fn: a hook defined by user, see trainer - documentation for detail. - """ - pass - - def restore_data(self) -> Tuple[int, int, int]: - """Return the metadata from existing log. - - If it finds nothing or an error occurs during the recover process, it will - return the default parameters. - - :return: epoch, env_step, gradient_step. - """ - pass - - -class BasicLogger(BaseLogger): - """A loggger that relies on tensorboard SummaryWriter by default to visualize \ +class TensorboardLogger(BaseLogger): + """A logger that relies on tensorboard SummaryWriter by default to visualize \ and log statistics. You can also rewrite write() func to use your own writer. @@ -197,14 +120,3 @@ def restore_data(self) -> Tuple[int, int, int]: env_step = 0 return epoch, env_step, gradient_step - - -class LazyLogger(BasicLogger): - """A loggger that does nothing. Used as the placeholder in trainer.""" - - def __init__(self) -> None: - super().__init__(None) # type: ignore - - def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None: - """The LazyLogger writes nothing.""" - pass diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py new file mode 100644 index 000000000..3f6ec1d98 --- /dev/null +++ b/tianshou/utils/logger/wandb.py @@ -0,0 +1,86 @@ +from tianshou.utils import BaseLogger +import wandb + + +class WandBLogger(BaseLogger): + """Weights and Biases logger that sends data to https://www.wandb.com/ + Creates three panels with plots: train, test and update. 
+ Make sure to select the correct access for each panel in weights and biases: + + - `train/env_step` for train plots + - `test/env_ste` for test plots + - `update/gradient_step` for update plots + + Example of usage: + + with wandb.init(project="My Project"): + logger = WandBLogger() + + result = onpolicy_trainer(policy, train_collector, test_collector, + logger=logger) + + :param int train_interval: the log interval in log_train_data(). Default to 1000. + :param int test_interval: the log interval in log_test_data(). Default to 1. + :param int update_interval: the log interval in log_update_data(). + Default to 1000.""" + def __init__( + self, + train_interval: int = 1000, + test_interval: int = 1, + update_interval: int = 1000 + ) -> None: + super().__init__(writer=None) + + self.train_interval = train_interval + self.test_interval = test_interval + self.update_interval = update_interval + self.last_log_train_step = -1 + self.last_log_test_step = -1 + self.last_log_update_step = -1 + + def write(self, key, x, y, **kwargs): + pass + + def log_train_data(self, collect_result: dict, step: int) -> None: + if collect_result["n/ep"] > 0: + collect_result["rew"] = collect_result["rews"].mean() + collect_result["len"] = collect_result["lens"].mean() + if step - self.last_log_train_step >= self.train_interval: + + log_data = { + "train/env_step": step, + "train/episode": collect_result["n/ep"], + "train/reward": collect_result["rew"], + "train/length": collect_result["len"]} + wandb.log(log_data) + + self.last_log_train_step = step + + def log_test_data(self, collect_result: dict, step: int) -> None: + assert collect_result["n/ep"] > 0 + rews, lens = collect_result["rews"], collect_result["lens"] + rew, rew_std, len_, len_std = rews.mean(), rews.std(), lens.mean(), lens.std() + collect_result.update(rew=rew, rew_std=rew_std, len=len_, len_std=len_std) + if step - self.last_log_test_step >= self.test_interval: + + log_data = { + "test/env_step": step, + "test/reward": rew, + "test/length": len_, + "test/reward_std": rew_std, + "test/length_std": len_std} + + wandb.log(log_data) + self.last_log_test_step = step + + def log_update_data(self, update_result: dict, step: int) -> None: + if step - self.last_log_update_step >= self.update_interval: + log_data = {} + + for k, v in update_result.items(): + log_data[f'update/{k}'] = v + + log_data['update/gradient_step'] = step + wandb.log(log_data) + + self.last_log_update_step = step From d8a4d16e5b6de8a48d2d92a93fd9950dd3879216 Mon Sep 17 00:00:00 2001 From: Andriy Drozdyuk Date: Thu, 26 Aug 2021 19:43:10 -0400 Subject: [PATCH 2/5] Fixing all the tests and install deps --- setup.py | 1 + tianshou/utils/__init__.py | 6 ++++-- tianshou/utils/logger/__init__.py | 10 ---------- tianshou/utils/logger/base.py | 2 +- tianshou/utils/logger/wandb.py | 4 +++- 5 files changed, 9 insertions(+), 14 deletions(-) diff --git a/setup.py b/setup.py index 97ff849da..c45551afd 100644 --- a/setup.py +++ b/setup.py @@ -52,6 +52,7 @@ def get_version() -> str: "tensorboard>=2.5.0", "torch>=1.4.0", "numba>=0.51.0", + "wandb>=0.12.0", "h5py>=2.10.0", # to match tensorflow's minimal requirements ], extras_require={ diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py index b16e37383..b239b5ce2 100644 --- a/tianshou/utils/__init__.py +++ b/tianshou/utils/__init__.py @@ -1,7 +1,9 @@ from tianshou.utils.config import tqdm_config from tianshou.utils.statistics import MovAvg, RunningMeanStd -from tianshou.utils.logger import TensorboardLogger, LazyLogger, BaseLogger, \ 
- WandBLogger +from tianshou.utils.logger.base import BaseLogger, LazyLogger +from tianshou.utils.logger.tensorboard import TensorboardLogger +from tianshou.utils.logger.wandb import WandBLogger + __all__ = [ "MovAvg", diff --git a/tianshou/utils/logger/__init__.py b/tianshou/utils/logger/__init__.py index e97d47ef9..e69de29bb 100644 --- a/tianshou/utils/logger/__init__.py +++ b/tianshou/utils/logger/__init__.py @@ -1,10 +0,0 @@ -from tianshou.utils.logger.base import BaseLogger, LazyLogger -from tianshou.utils.logger.tensorboard import TensorboardLogger -from tianshou.utils.logger.wandb import WandBLogger - -__all__ = [ - "BaseLogger", - "TensorboardLogger", - "WandBLogger", - "LazyLogger" -] diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py index e11980443..24ffdc8ed 100644 --- a/tianshou/utils/logger/base.py +++ b/tianshou/utils/logger/base.py @@ -83,7 +83,7 @@ class LazyLogger(BaseLogger): """A logger that does nothing. Used as the placeholder in trainer.""" def __init__(self) -> None: - super().__init__(None) # type: ignore + super().__init__(None) def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None: """The LazyLogger writes nothing.""" diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py index 3f6ec1d98..6de681204 100644 --- a/tianshou/utils/logger/wandb.py +++ b/tianshou/utils/logger/wandb.py @@ -1,4 +1,6 @@ from tianshou.utils import BaseLogger +from tianshou.utils.logger.base import WRITE_TYPE +from typing import Any import wandb @@ -38,7 +40,7 @@ def __init__( self.last_log_test_step = -1 self.last_log_update_step = -1 - def write(self, key, x, y, **kwargs): + def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None: pass def log_train_data(self, collect_result: dict, step: int) -> None: From 72e8c4073de0754325e940ba86b643973384c808 Mon Sep 17 00:00:00 2001 From: Andriy Drozdyuk Date: Thu, 26 Aug 2021 20:02:11 -0400 Subject: [PATCH 3/5] fix docstring issues --- tianshou/utils/logger/wandb.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py index 6de681204..7395e9715 100644 --- a/tianshou/utils/logger/wandb.py +++ b/tianshou/utils/logger/wandb.py @@ -5,7 +5,8 @@ class WandBLogger(BaseLogger): - """Weights and Biases logger that sends data to https://www.wandb.com/ + """Weights and Biases logger that sends data to Weights and Biases. + Creates three panels with plots: train, test and update. Make sure to select the correct access for each panel in weights and biases: @@ -25,6 +26,7 @@ class WandBLogger(BaseLogger): :param int test_interval: the log interval in log_test_data(). Default to 1. :param int update_interval: the log interval in log_update_data(). 
Default to 1000.""" + def __init__( self, train_interval: int = 1000, From 1493cf7e2da02bf6277494cd270fb127390084ce Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Sun, 29 Aug 2021 14:34:51 -0400 Subject: [PATCH 4/5] fix ci --- examples/atari/atari_rainbow.py | 4 +-- test/discrete/test_rainbow.py | 4 +-- tianshou/utils/__init__.py | 3 +- tianshou/utils/logger/tensorboard.py | 18 +++++++++++- tianshou/utils/logger/wandb.py | 41 +++++++++++++--------------- 5 files changed, 42 insertions(+), 28 deletions(-) diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 941df5c80..f2f44f0cd 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -7,7 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import RainbowPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer @@ -121,7 +121,7 @@ def test_rainbow(args=get_args()): f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}') writer = SummaryWriter(log_path) writer.add_text("args", str(args)) - logger = BasicLogger(writer) + logger = TensorboardLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index cb48fe89c..4fdcfd352 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -8,7 +8,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import RainbowPolicy -from tianshou.utils import BasicLogger +from tianshou.utils import TensorboardLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.utils.net.discrete import NoisyLinear @@ -102,7 +102,7 @@ def noisy_linear(x, y): # log log_path = os.path.join(args.logdir, args.task, 'rainbow') writer = SummaryWriter(log_path) - logger = BasicLogger(writer, save_interval=args.save_interval) + logger = TensorboardLogger(writer, save_interval=args.save_interval) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py index b239b5ce2..4ad73481c 100644 --- a/tianshou/utils/__init__.py +++ b/tianshou/utils/__init__.py @@ -1,7 +1,7 @@ from tianshou.utils.config import tqdm_config from tianshou.utils.statistics import MovAvg, RunningMeanStd from tianshou.utils.logger.base import BaseLogger, LazyLogger -from tianshou.utils.logger.tensorboard import TensorboardLogger +from tianshou.utils.logger.tensorboard import TensorboardLogger, BasicLogger from tianshou.utils.logger.wandb import WandBLogger @@ -11,6 +11,7 @@ "tqdm_config", "BaseLogger", "TensorboardLogger", + "BasicLogger", "LazyLogger", "WandBLogger" ] diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py index 4279573ec..c0cbd5fa8 100644 --- a/tianshou/utils/logger/tensorboard.py +++ b/tianshou/utils/logger/tensorboard.py @@ -1,6 +1,9 @@ -from torch.utils.tensorboard import SummaryWriter +import warnings from typing import Any, Tuple, Callable, Optional + +from torch.utils.tensorboard import SummaryWriter from tensorboard.backend.event_processing import event_accumulator + from tianshou.utils.logger.base import BaseLogger, WRITE_TYPE @@ -120,3 +123,16 @@ def restore_data(self) -> Tuple[int, int, 
@@ -120,3 +123,16 @@ def restore_data(self) -> Tuple[int, int, int]:
             env_step = 0
         return epoch, env_step, gradient_step
+
+
+class BasicLogger(TensorboardLogger):
+    """BasicLogger has changed its name to TensorboardLogger in #427.
+
+    This class is for compatibility.
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        warnings.warn(
+            "BasicLogger has renamed to TensorboardLogger in #427.",
+            DeprecationWarning)
+        super().__init__(*args, **kwargs)
diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py
index 7395e9715..a5573ab54 100644
--- a/tianshou/utils/logger/wandb.py
+++ b/tianshou/utils/logger/wandb.py
@@ -1,40 +1,44 @@
 from tianshou.utils import BaseLogger
 from tianshou.utils.logger.base import WRITE_TYPE
 from typing import Any
-import wandb
+
+try:
+    import wandb
+except ImportError:
+    pass


 class WandBLogger(BaseLogger):
     """Weights and Biases logger that sends data to Weights and Biases.

-    Creates three panels with plots: train, test and update.
+    Creates three panels with plots: train, test, and update.
     Make sure to select the correct access for each panel in weights and biases:

-    - `train/env_step` for train plots
-    - `test/env_ste` for test plots
-    - `update/gradient_step` for update plots
+    - ``train/env_step`` for train plots
+    - ``test/env_step`` for test plots
+    - ``update/gradient_step`` for update plots

     Example of usage:
+    ::

         with wandb.init(project="My Project"):
             logger = WandBLogger()
-            result = onpolicy_trainer(policy, train_collector, test_collector, logger=logger)

     :param int train_interval: the log interval in log_train_data().
         Default to 1000.
     :param int test_interval: the log interval in log_test_data(). Default to 1.
     :param int update_interval: the log interval in log_update_data().
-        Default to 1000."""
+        Default to 1000.
+    """

     def __init__(
         self,
         train_interval: int = 1000,
         test_interval: int = 1,
-        update_interval: int = 1000
+        update_interval: int = 1000,
     ) -> None:
         super().__init__(writer=None)
-
         self.train_interval = train_interval
         self.test_interval = test_interval
         self.update_interval = update_interval
@@ -50,14 +54,13 @@ def log_train_data(self, collect_result: dict, step: int) -> None:
             collect_result["rew"] = collect_result["rews"].mean()
             collect_result["len"] = collect_result["lens"].mean()
             if step - self.last_log_train_step >= self.train_interval:
-
                 log_data = {
                     "train/env_step": step,
                     "train/episode": collect_result["n/ep"],
                     "train/reward": collect_result["rew"],
-                    "train/length": collect_result["len"]}
+                    "train/length": collect_result["len"],
+                }
                 wandb.log(log_data)
-
                 self.last_log_train_step = step

@@ -66,25 +69,19 @@ def log_test_data(self, collect_result: dict, step: int) -> None:
         rew, rew_std, len_, len_std = rews.mean(), rews.std(), lens.mean(), lens.std()
         collect_result.update(rew=rew, rew_std=rew_std, len=len_, len_std=len_std)
         if step - self.last_log_test_step >= self.test_interval:
-
             log_data = {
                 "test/env_step": step,
                 "test/reward": rew,
                 "test/length": len_,
                 "test/reward_std": rew_std,
-                "test/length_std": len_std}
-
+                "test/length_std": len_std,
+            }
             wandb.log(log_data)
             self.last_log_test_step = step

     def log_update_data(self, update_result: dict, step: int) -> None:
         if step - self.last_log_update_step >= self.update_interval:
-            log_data = {}
-
-            for k, v in update_result.items():
-                log_data[f'update/{k}'] = v
-
-            log_data['update/gradient_step'] = step
+            log_data = {f"update/{k}": v for k, v in update_result.items()}
+            log_data["update/gradient_step"] = step
             wandb.log(log_data)
-
             self.last_log_update_step = step
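The compatibility shim above can be exercised roughly as follows; this is only an illustrative sketch, the log directory name is made up, and it merely shows that the old class still constructs a TensorboardLogger while emitting a warning: ::

    import warnings

    from torch.utils.tensorboard import SummaryWriter
    from tianshou.utils import BasicLogger  # old name, now a thin subclass of TensorboardLogger

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # "log/compat_demo" is an arbitrary, hypothetical log directory
        logger = BasicLogger(SummaryWriter("log/compat_demo"))

    print(caught[0].category)  # DeprecationWarning (the explicit category is dropped in PATCH 5/5 below)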
From 268ce9d62de87f557005686ef760fcfec5b5865c Mon Sep 17 00:00:00 2001
From: Jiayi Weng
Date: Sun, 29 Aug 2021 15:32:30 -0400
Subject: [PATCH 5/5] refactor logger to be more concise

---
 setup.py                             |  2 +-
 tianshou/policy/modelfree/ppo.py     |  2 +-
 tianshou/utils/logger/base.py        | 93 +++++++++++++++++++++-------
 tianshou/utils/logger/tensorboard.py | 77 ++++-------------------
 tianshou/utils/logger/wandb.py       | 53 ++--------------
 5 files changed, 90 insertions(+), 137 deletions(-)

diff --git a/setup.py b/setup.py
index c45551afd..208af8e14 100644
--- a/setup.py
+++ b/setup.py
@@ -52,7 +52,6 @@ def get_version() -> str:
         "tensorboard>=2.5.0",
         "torch>=1.4.0",
         "numba>=0.51.0",
-        "wandb>=0.12.0",
         "h5py>=2.10.0",  # to match tensorflow's minimal requirements
     ],
     extras_require={
@@ -64,6 +63,7 @@
         "pytest",
         "pytest-cov",
         "ray>=1.0.0",
+        "wandb>=0.12.0",
         "networkx",
         "mypy",
         "pydocstyle",
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
index aaf34bd3e..e66a5f882 100644
--- a/tianshou/policy/modelfree/ppo.py
+++ b/tianshou/policy/modelfree/ppo.py
@@ -22,7 +22,7 @@ class PPOPolicy(A2CPolicy):
     :param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
         where c > 1 is a constant indicating the lower bound.
         Default to 5.0 (set None if you do not want to use it).
-    :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1.
+    :param bool value_clip: a parameter mentioned in arXiv:1811.02553v3 Sec. 4.1.
         Default to True.
     :param bool advantage_normalization: whether to do per mini-batch advantage
         normalization. Default to True.
diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py
index 24ffdc8ed..c1ffe760d 100644
--- a/tianshou/utils/logger/base.py
+++ b/tianshou/utils/logger/base.py
@@ -1,26 +1,42 @@
 import numpy as np
 from numbers import Number
 from abc import ABC, abstractmethod
-from typing import Any, Tuple, Union, Callable, Optional
+from typing import Dict, Tuple, Union, Callable, Optional

-
-WRITE_TYPE = Union[int, Number, np.number, np.ndarray]
+LOG_DATA_TYPE = Dict[str, Union[int, Number, np.number, np.ndarray]]


 class BaseLogger(ABC):
-    """The base class for any logger which is compatible with trainer."""
+    """The base class for any logger which is compatible with trainer.
+
+    Try to overwrite write() method to use your own writer.

-    def __init__(self, writer: Any) -> None:
+    :param int train_interval: the log interval in log_train_data(). Default to 1000.
+    :param int test_interval: the log interval in log_test_data(). Default to 1.
+    :param int update_interval: the log interval in log_update_data(). Default to 1000.
+    """
+
+    def __init__(
+        self,
+        train_interval: int = 1000,
+        test_interval: int = 1,
+        update_interval: int = 1000,
+    ) -> None:
         super().__init__()
-        self.writer = writer
+        self.train_interval = train_interval
+        self.test_interval = test_interval
+        self.update_interval = update_interval
+        self.last_log_train_step = -1
+        self.last_log_test_step = -1
+        self.last_log_update_step = -1

     @abstractmethod
-    def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
+    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
         """Specify how the writer is used to log data.

-        :param str key: namespace which the input data tuple belongs to.
-        :param int x: stands for the ordinate of the input data tuple.
-        :param y: stands for the abscissa of the input data tuple.
+        :param str step_type: namespace which the data dict belongs to.
+        :param int step: stands for the ordinate of the data dict.
+        :param dict data: the data to write with format ``{key: value}``.
         """
         pass

@@ -30,17 +46,22 @@ def log_train_data(self, collect_result: dict, step: int) -> None:
         :param collect_result: a dict containing information of data collected in
             training stage, i.e., returns of collector.collect().
         :param int step: stands for the timestep the collect_result being logged.
-        """
-        pass

-    def log_update_data(self, update_result: dict, step: int) -> None:
-        """Use writer to log statistics generated during updating.
+        .. note::

-        :param update_result: a dict containing information of data collected in
-            updating stage, i.e., returns of policy.update().
-        :param int step: stands for the timestep the collect_result being logged.
+            ``collect_result`` will be modified in-place with "rew" and "len" keys.
         """
-        pass
+        if collect_result["n/ep"] > 0:
+            collect_result["rew"] = collect_result["rews"].mean()
+            collect_result["len"] = collect_result["lens"].mean()
+            if step - self.last_log_train_step >= self.train_interval:
+                log_data = {
+                    "train/episode": collect_result["n/ep"],
+                    "train/reward": collect_result["rew"],
+                    "train/length": collect_result["len"],
+                }
+                self.write("train/env_step", step, log_data)
+                self.last_log_train_step = step

     def log_test_data(self, collect_result: dict, step: int) -> None:
         """Use writer to log statistics generated during evaluating.
@@ -48,8 +69,38 @@ def log_test_data(self, collect_result: dict, step: int) -> None:
         :param collect_result: a dict containing information of data collected in
             evaluating stage, i.e., returns of collector.collect().
         :param int step: stands for the timestep the collect_result being logged.
+
+        .. note::
+
+            ``collect_result`` will be modified in-place with "rew", "rew_std", "len",
+            and "len_std" keys.
         """
-        pass
+        assert collect_result["n/ep"] > 0
+        rews, lens = collect_result["rews"], collect_result["lens"]
+        rew, rew_std, len_, len_std = rews.mean(), rews.std(), lens.mean(), lens.std()
+        collect_result.update(rew=rew, rew_std=rew_std, len=len_, len_std=len_std)
+        if step - self.last_log_test_step >= self.test_interval:
+            log_data = {
+                "test/env_step": step,
+                "test/reward": rew,
+                "test/length": len_,
+                "test/reward_std": rew_std,
+                "test/length_std": len_std,
+            }
+            self.write("test/env_step", step, log_data)
+            self.last_log_test_step = step
+
+    def log_update_data(self, update_result: dict, step: int) -> None:
+        """Use writer to log statistics generated during updating.
+
+        :param update_result: a dict containing information of data collected in
+            updating stage, i.e., returns of policy.update().
+        :param int step: stands for the timestep the collect_result being logged.
+        """
+        if step - self.last_log_update_step >= self.update_interval:
+            log_data = {f"update/{k}": v for k, v in update_result.items()}
+            self.write("update/gradient_step", step, log_data)
+            self.last_log_update_step = step

     def save_data(
         self,
@@ -83,8 +134,8 @@ class LazyLogger(BaseLogger):
     """A logger that does nothing. Used as the placeholder in trainer."""

     def __init__(self) -> None:
-        super().__init__(None)
+        super().__init__()

-    def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
+    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
         """The LazyLogger writes nothing."""
         pass
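With the interval bookkeeping moved into ``BaseLogger``, a custom backend only has to implement ``write()``. A minimal sketch under that assumption, where ``PrintLogger`` is an invented name and not part of this patch: ::

    from tianshou.utils.logger.base import BaseLogger, LOG_DATA_TYPE

    class PrintLogger(BaseLogger):
        """Toy backend that dumps every record to stdout."""

        def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
            # step_type is the x-axis key, e.g. "update/gradient_step";
            # data maps scalar names to values, already prefixed by BaseLogger.
            scalars = ", ".join(f"{k}={v}" for k, v in data.items())
            print(f"{step_type}={step}: {scalars}")

    logger = PrintLogger(update_interval=1)
    logger.log_update_data({"loss": 0.5}, step=10)
    # -> update/gradient_step=10: update/loss=0.5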
- """ - assert collect_result["n/ep"] > 0 - rews, lens = collect_result["rews"], collect_result["lens"] - rew, rew_std, len_, len_std = rews.mean(), rews.std(), lens.mean(), lens.std() - collect_result.update(rew=rew, rew_std=rew_std, len=len_, len_std=len_std) - if step - self.last_log_test_step >= self.test_interval: - self.write("test/rew", step, rew) - self.write("test/len", step, len_) - self.write("test/rew_std", step, rew_std) - self.write("test/len_std", step, len_std) - self.last_log_test_step = step - - def log_update_data(self, update_result: dict, step: int) -> None: - if step - self.last_log_update_step >= self.update_interval: - for k, v in update_result.items(): - self.write(k, step, v) - self.last_log_update_step = step + def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: + for k, v in data.items(): + self.writer.add_scalar(k, v, global_step=step) def save_data( self, @@ -101,9 +46,10 @@ def save_data( if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval: self.last_save_step = epoch save_checkpoint_fn(epoch, env_step, gradient_step) - self.write("save/epoch", epoch, epoch) - self.write("save/env_step", env_step, env_step) - self.write("save/gradient_step", gradient_step, gradient_step) + self.write("save/epoch", epoch, {"save/epoch": epoch}) + self.write("save/env_step", env_step, {"save/env_step": env_step}) + self.write("save/gradient_step", gradient_step, + {"save/gradient_step": gradient_step}) def restore_data(self) -> Tuple[int, int, int]: ea = event_accumulator.EventAccumulator(self.writer.log_dir) @@ -133,6 +79,5 @@ class BasicLogger(TensorboardLogger): def __init__(self, *args: Any, **kwargs: Any) -> None: warnings.warn( - "BasicLogger has renamed to TensorboardLogger in #427.", - DeprecationWarning) + "Deprecated soon: BasicLogger has renamed to TensorboardLogger in #427.") super().__init__(*args, **kwargs) diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py index a5573ab54..7a837c96c 100644 --- a/tianshou/utils/logger/wandb.py +++ b/tianshou/utils/logger/wandb.py @@ -1,6 +1,5 @@ from tianshou.utils import BaseLogger -from tianshou.utils.logger.base import WRITE_TYPE -from typing import Any +from tianshou.utils.logger.base import LOG_DATA_TYPE try: import wandb @@ -38,50 +37,8 @@ def __init__( test_interval: int = 1, update_interval: int = 1000, ) -> None: - super().__init__(writer=None) - self.train_interval = train_interval - self.test_interval = test_interval - self.update_interval = update_interval - self.last_log_train_step = -1 - self.last_log_test_step = -1 - self.last_log_update_step = -1 + super().__init__(train_interval, test_interval, update_interval) - def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None: - pass - - def log_train_data(self, collect_result: dict, step: int) -> None: - if collect_result["n/ep"] > 0: - collect_result["rew"] = collect_result["rews"].mean() - collect_result["len"] = collect_result["lens"].mean() - if step - self.last_log_train_step >= self.train_interval: - log_data = { - "train/env_step": step, - "train/episode": collect_result["n/ep"], - "train/reward": collect_result["rew"], - "train/length": collect_result["len"], - } - wandb.log(log_data) - self.last_log_train_step = step - - def log_test_data(self, collect_result: dict, step: int) -> None: - assert collect_result["n/ep"] > 0 - rews, lens = collect_result["rews"], collect_result["lens"] - rew, rew_std, len_, len_std = rews.mean(), rews.std(), lens.mean(), lens.std() - 
diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py
index a5573ab54..7a837c96c 100644
--- a/tianshou/utils/logger/wandb.py
+++ b/tianshou/utils/logger/wandb.py
@@ -1,6 +1,5 @@
 from tianshou.utils import BaseLogger
-from tianshou.utils.logger.base import WRITE_TYPE
-from typing import Any
+from tianshou.utils.logger.base import LOG_DATA_TYPE

 try:
     import wandb
@@ -38,50 +37,8 @@ def __init__(
         test_interval: int = 1,
         update_interval: int = 1000,
     ) -> None:
-        super().__init__(writer=None)
-        self.train_interval = train_interval
-        self.test_interval = test_interval
-        self.update_interval = update_interval
-        self.last_log_train_step = -1
-        self.last_log_test_step = -1
-        self.last_log_update_step = -1
+        super().__init__(train_interval, test_interval, update_interval)

-    def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
-        pass
-
-    def log_train_data(self, collect_result: dict, step: int) -> None:
-        if collect_result["n/ep"] > 0:
-            collect_result["rew"] = collect_result["rews"].mean()
-            collect_result["len"] = collect_result["lens"].mean()
-            if step - self.last_log_train_step >= self.train_interval:
-                log_data = {
-                    "train/env_step": step,
-                    "train/episode": collect_result["n/ep"],
-                    "train/reward": collect_result["rew"],
-                    "train/length": collect_result["len"],
-                }
-                wandb.log(log_data)
-                self.last_log_train_step = step
-
-    def log_test_data(self, collect_result: dict, step: int) -> None:
-        assert collect_result["n/ep"] > 0
-        rews, lens = collect_result["rews"], collect_result["lens"]
-        rew, rew_std, len_, len_std = rews.mean(), rews.std(), lens.mean(), lens.std()
-        collect_result.update(rew=rew, rew_std=rew_std, len=len_, len_std=len_std)
-        if step - self.last_log_test_step >= self.test_interval:
-            log_data = {
-                "test/env_step": step,
-                "test/reward": rew,
-                "test/length": len_,
-                "test/reward_std": rew_std,
-                "test/length_std": len_std,
-            }
-            wandb.log(log_data)
-            self.last_log_test_step = step
-
-    def log_update_data(self, update_result: dict, step: int) -> None:
-        if step - self.last_log_update_step >= self.update_interval:
-            log_data = {f"update/{k}": v for k, v in update_result.items()}
-            log_data["update/gradient_step"] = step
-            wandb.log(log_data)
-            self.last_log_update_step = step
+    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
+        data[step_type] = step
+        wandb.log(data)
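After this change the whole wandb backend reduces to the two-line ``write()`` above: ``BaseLogger`` assembles the dictionary and ``WandBLogger`` only attaches the step key before calling ``wandb.log``. A rough sketch of the resulting payload, where the project name is invented and ``mode="offline"`` just avoids needing a W&B account: ::

    import numpy as np
    import wandb

    from tianshou.utils import WandBLogger

    wandb.init(project="tianshou-logger-demo", mode="offline")
    logger = WandBLogger()

    # minimal stand-in for a collector.collect() result with two test episodes
    collect_result = {"n/ep": 2, "rews": np.array([1.0, 3.0]), "lens": np.array([10, 20])}
    logger.log_test_data(collect_result, step=5)
    # wandb.log receives:
    # {"test/env_step": 5, "test/reward": 2.0, "test/length": 15.0,
    #  "test/reward_std": 1.0, "test/length_std": 5.0}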