diff --git a/README.md b/README.md
index f189f369e..9e2d7a9da 100644
--- a/README.md
+++ b/README.md
@@ -192,7 +192,7 @@ buffer_size = 20000
 eps_train, eps_test = 0.1, 0.05
 step_per_epoch, step_per_collect = 10000, 10
 writer = SummaryWriter('log/dqn')  # tensorboard is also supported!
-logger = ts.utils.BasicLogger(writer)
+logger = ts.utils.TensorboardLogger(writer)
 ```

 Make environments:
diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst
index 1fd59ab2a..38a0291f4 100644
--- a/docs/tutorials/cheatsheet.rst
+++ b/docs/tutorials/cheatsheet.rst
@@ -40,8 +40,8 @@ This is related to `Issue 349 <https://github.com/thu-ml/tianshou/issues/349>`_.

 To resume training process from an existing checkpoint, you need to do the following things in the training process:

 1. Make sure you write ``save_checkpoint_fn`` which saves everything needed in the training process, i.e., policy, optim, buffer; pass it to trainer;
-2. Use ``BasicLogger`` which contains a tensorboard;
-3. To adjust the save frequency, specify ``save_interval`` when initializing BasicLogger.
+2. Use ``TensorboardLogger``;
+3. To adjust the save frequency, specify ``save_interval`` when initializing TensorboardLogger.

 And to successfully resume from a checkpoint:
diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst
index 412051e49..0126c7fca 100644
--- a/docs/tutorials/dqn.rst
+++ b/docs/tutorials/dqn.rst
@@ -148,9 +148,9 @@ The trainer supports `TensorBoard <https://www.tensorflow.org/tensorboard>`_ for
 ::

     from torch.utils.tensorboard import SummaryWriter
-    from tianshou.utils import BasicLogger
+    from tianshou.utils import TensorboardLogger

     writer = SummaryWriter('log/dqn')
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

 Pass the logger into the trainer, and the training result will be recorded into the TensorBoard.
diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst
index 64b0dfd70..58e69906a 100644
--- a/docs/tutorials/tictactoe.rst
+++ b/docs/tutorials/tictactoe.rst
@@ -176,7 +176,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modul
     import numpy as np
     from copy import deepcopy
     from torch.utils.tensorboard import SummaryWriter

-    from tianshou.utils import BasicLogger
+    from tianshou.utils import TensorboardLogger
     from tianshou.env import DummyVectorEnv
     from tianshou.utils.net.common import Net
@@ -323,7 +323,7 @@ With the above preparation, we are close to the first learned agent. The followi
     log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     # ======== callback functions used during training =========
diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py
index 142edd816..b4cd5b62d 100644
--- a/examples/atari/atari_bcq.py
+++ b/examples/atari/atari_bcq.py
@@ -7,7 +7,7 @@
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offline_trainer
 from tianshou.utils.net.discrete import Actor
@@ -116,7 +116,7 @@ def test_discrete_bcq(args=get_args()):
         f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = BasicLogger(writer, update_interval=args.log_interval)
+    logger = TensorboardLogger(writer, update_interval=args.log_interval)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py
index 558284f47..b70cc6cc0 100644
--- a/examples/atari/atari_c51.py
+++ b/examples/atari/atari_c51.py
@@ -6,7 +6,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import C51Policy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, VectorReplayBuffer
@@ -101,7 +101,7 @@ def test_c51(args=get_args()):
     log_path = os.path.join(args.logdir, args.task, 'c51')
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/atari/atari_cql.py b/examples/atari/atari_cql.py
index ff8650608..cbab82029 100644
--- a/examples/atari/atari_cql.py
+++ b/examples/atari/atari_cql.py
@@ -7,7 +7,7 @@
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offline_trainer
 from tianshou.policy import DiscreteCQLPolicy
@@ -108,7 +108,7 @@ def test_discrete_cql(args=get_args()):
         f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = BasicLogger(writer, update_interval=args.log_interval)
+    logger = TensorboardLogger(writer, update_interval=args.log_interval)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/atari/atari_crr.py b/examples/atari/atari_crr.py
index 6bd91678e..e8e1ba54e 100644
--- a/examples/atari/atari_crr.py
+++ b/examples/atari/atari_crr.py
@@ -7,7 +7,7 @@
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offline_trainer
 from tianshou.utils.net.discrete import Actor
@@ -117,7 +117,7 @@ def test_discrete_crr(args=get_args()):
         f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = BasicLogger(writer, update_interval=args.log_interval)
+    logger = TensorboardLogger(writer, update_interval=args.log_interval)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py
index a7785c8c6..69ec08349 100644
--- a/examples/atari/atari_dqn.py
+++ b/examples/atari/atari_dqn.py
@@ -6,7 +6,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import DQNPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, VectorReplayBuffer
@@ -96,7 +96,7 @@ def test_dqn(args=get_args()):
     log_path = os.path.join(args.logdir, args.task, 'dqn')
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py
index 110a57167..4a6e97c06 100644
--- a/examples/atari/atari_fqf.py
+++ b/examples/atari/atari_fqf.py
@@ -6,7 +6,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import FQFPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, VectorReplayBuffer
@@ -112,7 +112,7 @@ def test_fqf(args=get_args()):
     log_path = os.path.join(args.logdir, args.task, 'fqf')
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py
index ad34c8f96..e5966a318 100644
--- a/examples/atari/atari_iqn.py
+++ b/examples/atari/atari_iqn.py
@@ -6,7 +6,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import IQNPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, VectorReplayBuffer
@@ -109,7 +109,7 @@ def test_iqn(args=get_args()):
     log_path = os.path.join(args.logdir, args.task, 'iqn')
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py
index 6677f6837..781f81d5d 100644
--- a/examples/atari/atari_qrdqn.py
+++ b/examples/atari/atari_qrdqn.py
@@ -5,7 +5,7 @@
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.policy import QRDQNPolicy
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offpolicy_trainer
@@ -99,7 +99,7 @@ def test_qrdqn(args=get_args()):
     log_path = os.path.join(args.logdir, args.task, 'qrdqn')
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py
index 941df5c80..f2f44f0cd 100644
--- a/examples/atari/atari_rainbow.py
+++ b/examples/atari/atari_rainbow.py
@@ -7,7 +7,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import RainbowPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
@@ -121,7 +121,7 @@ def test_rainbow(args=get_args()):
         f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py
index 58f1a3783..4889f7eb2 100644
--- a/examples/box2d/acrobot_dualdqn.py
+++ b/examples/box2d/acrobot_dualdqn.py
@@ -7,7 +7,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import DQNPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
@@ -82,7 +82,7 @@ def test_dqn(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'dqn')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py
index 5678277db..4caa50b94 100644
--- a/examples/box2d/bipedal_hardcore_sac.py
+++ b/examples/box2d/bipedal_hardcore_sac.py
@@ -7,7 +7,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import SACPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import Net
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offpolicy_trainer
@@ -134,7 +134,7 @@ def test_sac_bipedal(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'sac')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py
index f5a9d3bdf..bb73ac615 100644
--- a/examples/box2d/lunarlander_dqn.py
+++ b/examples/box2d/lunarlander_dqn.py
@@ -7,7 +7,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import DQNPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, VectorReplayBuffer
@@ -84,7 +84,7 @@ def test_dqn(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'dqn')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py
index 8fbc4257d..a43728be5 100644
--- a/examples/box2d/mcc_sac.py
+++ b/examples/box2d/mcc_sac.py
@@ -7,7 +7,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import SACPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.exploration import OUNoise
 from tianshou.utils.net.common import Net
@@ -104,7 +104,7 @@ def test_sac(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'sac')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py
index bf039a187..e9debc906 100755
--- a/examples/mujoco/mujoco_a2c.py
+++ b/examples/mujoco/mujoco_a2c.py
@@ -13,7 +13,7 @@
 from torch.distributions import Independent, Normal

 from tianshou.policy import A2CPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import onpolicy_trainer
@@ -141,7 +141,7 @@ def dist(*logits):
     log_path = os.path.join(args.logdir, args.task, 'a2c', log_file)
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = BasicLogger(writer, update_interval=100, train_interval=100)
+    logger = TensorboardLogger(writer, update_interval=100, train_interval=100)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py
index ebd590fd7..28bac056a 100755
--- a/examples/mujoco/mujoco_ddpg.py
+++ b/examples/mujoco/mujoco_ddpg.py
@@ -10,7 +10,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import DDPGPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.exploration import GaussianNoise
@@ -110,7 +110,7 @@ def test_ddpg(args=get_args()):
     log_path = os.path.join(args.logdir, args.task, 'ddpg', log_file)
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py
index 03941c51b..00b2a1a2c 100755
--- a/examples/mujoco/mujoco_npg.py
+++ b/examples/mujoco/mujoco_npg.py
@@ -13,7 +13,7 @@
 from torch.distributions import Independent, Normal

 from tianshou.policy import NPGPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import onpolicy_trainer
@@ -142,7 +142,7 @@ def dist(*logits):
     log_path = os.path.join(args.logdir, args.task, 'npg', log_file)
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = BasicLogger(writer, update_interval=100, train_interval=100)
+    logger = TensorboardLogger(writer, update_interval=100, train_interval=100)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py
index 681f626a1..fb3e7a0a2 100755
--- a/examples/mujoco/mujoco_ppo.py
+++ b/examples/mujoco/mujoco_ppo.py
@@ -13,7 +13,7 @@
 from torch.distributions import Independent, Normal

 from tianshou.policy import PPOPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import onpolicy_trainer
@@ -149,7 +149,7 @@ def dist(*logits):
     log_path = os.path.join(args.logdir, args.task, 'ppo', log_file)
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = BasicLogger(writer, update_interval=100, train_interval=100)
+    logger = TensorboardLogger(writer, update_interval=100, train_interval=100)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py
index ac9682918..b7698562a 100755
--- a/examples/mujoco/mujoco_reinforce.py
+++ b/examples/mujoco/mujoco_reinforce.py
@@ -13,7 +13,7 @@
 from torch.distributions import Independent, Normal

 from tianshou.policy import PGPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import onpolicy_trainer
@@ -131,7 +131,7 @@ def dist(*logits):
     log_path = os.path.join(args.logdir, args.task, 'reinforce', log_file)
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = BasicLogger(writer, update_interval=10, train_interval=100)
+    logger = TensorboardLogger(writer, update_interval=10, train_interval=100)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py
index 46d7ac56e..c2dfd3618 100755
--- a/examples/mujoco/mujoco_sac.py
+++ b/examples/mujoco/mujoco_sac.py
@@ -10,7 +10,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import SACPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
@@ -122,7 +122,7 @@ def test_sac(args=get_args()):
     log_path = os.path.join(args.logdir, args.task, 'sac', log_file)
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py
index 97b4e0a0c..9a0179899 100755
--- a/examples/mujoco/mujoco_td3.py
+++ b/examples/mujoco/mujoco_td3.py
@@ -10,7 +10,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import TD3Policy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.exploration import GaussianNoise
@@ -123,7 +123,7 @@ def test_td3(args=get_args()):
     log_path = os.path.join(args.logdir, args.task, 'td3', log_file)
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py
index ad99069fa..b00f2e3d6 100755
--- a/examples/mujoco/mujoco_trpo.py
+++ b/examples/mujoco/mujoco_trpo.py
@@ -13,7 +13,7 @@
 from torch.distributions import Independent, Normal

 from tianshou.policy import TRPOPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import onpolicy_trainer
@@ -147,7 +147,7 @@ def dist(*logits):
     log_path = os.path.join(args.logdir, args.task, 'trpo', log_file)
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = BasicLogger(writer, update_interval=100, train_interval=100)
+    logger = TensorboardLogger(writer, update_interval=100, train_interval=100)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py
index 7391de087..1123151c8 100644
--- a/examples/vizdoom/vizdoom_c51.py
+++ b/examples/vizdoom/vizdoom_c51.py
@@ -6,7 +6,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import C51Policy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, VectorReplayBuffer
@@ -101,7 +101,7 @@ def test_c51(args=get_args()):
     log_path = os.path.join(args.logdir, args.task, 'c51')
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/setup.py b/setup.py
index 97ff849da..208af8e14 100644
--- a/setup.py
+++ b/setup.py
@@ -63,6 +63,7 @@ def get_version() -> str:
             "pytest",
             "pytest-cov",
             "ray>=1.0.0",
+            "wandb>=0.12.0",
             "networkx",
             "mypy",
             "pydocstyle",
diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py
index 021cc0419..3dbeb9091 100644
--- a/test/continuous/test_ddpg.py
+++ b/test/continuous/test_ddpg.py
@@ -7,7 +7,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import DDPGPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
@@ -90,7 +90,7 @@ def test_ddpg(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'ddpg')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py
index cfba19f11..ad758897f 100644
--- a/test/continuous/test_npg.py
+++ b/test/continuous/test_npg.py
@@ -9,7 +9,7 @@
 from torch.distributions import Independent, Normal

 from tianshou.policy import NPGPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import onpolicy_trainer
@@ -105,7 +105,7 @@ def dist(*logits):
     # log
     log_path = os.path.join(args.logdir, args.task, 'npg')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py
index 98d69a269..09c73fdd5 100644
--- a/test/continuous/test_ppo.py
+++ b/test/continuous/test_ppo.py
@@ -8,7 +8,7 @@
 from torch.distributions import Independent, Normal

 from tianshou.policy import PPOPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import onpolicy_trainer
@@ -116,7 +116,7 @@ def dist(*logits):
     # log
     log_path = os.path.join(args.logdir, args.task, 'ppo')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer, save_interval=args.save_interval)
+    logger = TensorboardLogger(writer, save_interval=args.save_interval)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py
index ad0b9af66..5b6a79492 100644
--- a/test/continuous/test_sac_with_il.py
+++ b/test/continuous/test_sac_with_il.py
@@ -6,7 +6,7 @@
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
@@ -109,7 +109,7 @@ def test_sac_with_il(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'sac')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py
index ee6fa11de..8bae1edfa 100644
--- a/test/continuous/test_td3.py
+++ b/test/continuous/test_td3.py
@@ -7,7 +7,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import TD3Policy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
@@ -104,7 +104,7 @@ def test_td3(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'td3')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py
index 8c8387773..65535fd50 100644
--- a/test/continuous/test_trpo.py
+++ b/test/continuous/test_trpo.py
@@ -9,7 +9,7 @@
 from torch.distributions import Independent, Normal

 from tianshou.policy import TRPOPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import onpolicy_trainer
@@ -109,7 +109,7 @@ def dist(*logits):
     # log
     log_path = os.path.join(args.logdir, args.task, 'trpo')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py
index fde783da5..d11ce360c 100644
--- a/test/discrete/test_a2c_with_il.py
+++ b/test/discrete/test_a2c_with_il.py
@@ -6,7 +6,7 @@
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.data import Collector, VectorReplayBuffer
@@ -91,7 +91,7 @@ def test_a2c_with_il(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'a2c')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py
index a7fdd922a..3208e83c8 100644
--- a/test/discrete/test_c51.py
+++ b/test/discrete/test_c51.py
@@ -8,7 +8,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import C51Policy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
@@ -93,7 +93,7 @@ def test_c51(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'c51')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer, save_interval=args.save_interval)
+    logger = TensorboardLogger(writer, save_interval=args.save_interval)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py
index cb7fac403..aae609ec6 100644
--- a/test/discrete/test_dqn.py
+++ b/test/discrete/test_dqn.py
@@ -8,7 +8,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import DQNPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
@@ -92,7 +92,7 @@ def test_dqn(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'dqn')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py
index c04cbc396..aa4fbbe0f 100644
--- a/test/discrete/test_drqn.py
+++ b/test/discrete/test_drqn.py
@@ -7,7 +7,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import DQNPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils.net.common import Recurrent
@@ -78,7 +78,7 @@ def test_drqn(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'drqn')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py
index 534927f12..0df2efb74 100644
--- a/test/discrete/test_fqf.py
+++ b/test/discrete/test_fqf.py
@@ -7,7 +7,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import FQFPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
@@ -100,7 +100,7 @@ def test_fqf(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'fqf')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py
index 1ea2db433..f404ce7dc 100644
--- a/test/discrete/test_il_bcq.py
+++ b/test/discrete/test_il_bcq.py
@@ -8,7 +8,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.data import Collector
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offline_trainer
@@ -87,7 +87,7 @@ def test_discrete_bcq(args=get_args()):

     log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer, save_interval=args.save_interval)
+    logger = TensorboardLogger(writer, save_interval=args.save_interval)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/test/discrete/test_il_crr.py b/test/discrete/test_il_crr.py
index 736edef5f..858d2b6f7 100644
--- a/test/discrete/test_il_crr.py
+++ b/test/discrete/test_il_crr.py
@@ -8,7 +8,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.data import Collector
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offline_trainer
@@ -80,7 +80,7 @@ def test_discrete_crr(args=get_args()):

     log_path = os.path.join(args.logdir, args.task, 'discrete_cql')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py
index f40407fb3..0234c36f0 100644
--- a/test/discrete/test_iqn.py
+++ b/test/discrete/test_iqn.py
@@ -7,7 +7,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import IQNPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
@@ -96,7 +96,7 @@ def test_iqn(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'iqn')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py
index 193117150..3c36bb265 100644
--- a/test/discrete/test_pg.py
+++ b/test/discrete/test_pg.py
@@ -7,7 +7,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import PGPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import onpolicy_trainer
@@ -78,7 +78,7 @@ def test_pg(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'pg')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py
index 2cf6d49ca..3658f364b 100644
--- a/test/discrete/test_ppo.py
+++ b/test/discrete/test_ppo.py
@@ -7,7 +7,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import PPOPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import onpolicy_trainer
@@ -102,7 +102,7 @@ def test_ppo(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'ppo')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py
index f6bf5ae6b..2385db0ee 100644
--- a/test/discrete/test_qrdqn.py
+++ b/test/discrete/test_qrdqn.py
@@ -7,7 +7,7 @@
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.policy import QRDQNPolicy
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
@@ -94,7 +94,7 @@ def test_qrdqn(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'qrdqn')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/test/discrete/test_qrdqn_il_cql.py b/test/discrete/test_qrdqn_il_cql.py
index 423506bd7..dbfd42aad 100644
--- a/test/discrete/test_qrdqn_il_cql.py
+++ b/test/discrete/test_qrdqn_il_cql.py
@@ -8,7 +8,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.data import Collector
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offline_trainer
@@ -79,7 +79,7 @@ def test_discrete_cql(args=get_args()):

     log_path = os.path.join(args.logdir, args.task, 'discrete_cql')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py
index cb48fe89c..4fdcfd352 100644
--- a/test/discrete/test_rainbow.py
+++ b/test/discrete/test_rainbow.py
@@ -8,7 +8,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import RainbowPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.utils.net.discrete import NoisyLinear
@@ -102,7 +102,7 @@ def noisy_linear(x, y):
     # log
     log_path = os.path.join(args.logdir, args.task, 'rainbow')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer, save_interval=args.save_interval)
+    logger = TensorboardLogger(writer, save_interval=args.save_interval)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py
index 0cb2ae018..d8abf48e4 100644
--- a/test/discrete/test_sac.py
+++ b/test/discrete/test_sac.py
@@ -6,7 +6,7 @@
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.policy import DiscreteSACPolicy
@@ -97,7 +97,7 @@ def test_discrete_sac(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'discrete_sac')
     writer = SummaryWriter(log_path)
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py
index ffeb6b911..01710d827 100644
--- a/test/modelbased/test_psrl.py
+++ b/test/modelbased/test_psrl.py
@@ -7,7 +7,6 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import PSRLPolicy
-# from tianshou.utils import BasicLogger
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv, SubprocVectorEnv
@@ -71,7 +70,6 @@ def test_psrl(args=get_args()):
     log_path = os.path.join(args.logdir, args.task, 'psrl')
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    # logger = BasicLogger(writer)

     def stop_fn(mean_rewards):
         if env.spec.reward_threshold:
diff --git a/test/multiagent/Gomoku.py b/test/multiagent/Gomoku.py
index 4c88656cb..6418ee8ec 100644
--- a/test/multiagent/Gomoku.py
+++ b/test/multiagent/Gomoku.py
@@ -7,7 +7,7 @@
 from tianshou.env import DummyVectorEnv
 from tianshou.data import Collector
 from tianshou.policy import RandomPolicy
-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger

 from tic_tac_toe_env import TicTacToeEnv
 from tic_tac_toe import get_parser, get_agents, train_agent, watch
@@ -33,7 +33,7 @@ def gomoku(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, 'Gomoku', 'dqn')
     writer = SummaryWriter(log_path)
-    args.logger = BasicLogger(writer)
+    args.logger = TensorboardLogger(writer)

     opponent_pool = [agent_opponent]
diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py
index dc4a443a3..dcd293a06 100644
--- a/test/multiagent/tic_tac_toe.py
+++ b/test/multiagent/tic_tac_toe.py
@@ -6,7 +6,7 @@
 from typing import Optional, Tuple
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.utils import BasicLogger
+from tianshou.utils import TensorboardLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
@@ -135,7 +135,7 @@ def env_func():
     log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
-    logger = BasicLogger(writer)
+    logger = TensorboardLogger(writer)

     def save_fn(policy):
         if hasattr(args, 'model_save_path'):
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
index aaf34bd3e..e66a5f882 100644
--- a/tianshou/policy/modelfree/ppo.py
+++ b/tianshou/policy/modelfree/ppo.py
@@ -22,7 +22,7 @@ class PPOPolicy(A2CPolicy):
     :param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
         where c > 1 is a constant indicating the lower bound.
         Default to 5.0 (set None if you do not want to use it).
-    :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1.
+    :param bool value_clip: a parameter mentioned in arXiv:1811.02553v3 Sec. 4.1.
         Default to True.
     :param bool advantage_normalization: whether to do per mini-batch advantage
         normalization. Default to True.
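The hunks above are a mechanical rename: every call site switches its import from ``BasicLogger`` to ``TensorboardLogger`` while keeping the same constructor arguments. A minimal sketch of the new call, assuming a SummaryWriter at ``'log/dqn'`` as in the README example (the interval values shown are just the documented defaults):
::

    from torch.utils.tensorboard import SummaryWriter

    # old: from tianshou.utils import BasicLogger
    from tianshou.utils import TensorboardLogger

    writer = SummaryWriter('log/dqn')
    # BasicLogger(writer) still works through the compatibility shim added in
    # tianshou/utils/logger/tensorboard.py below, but it emits a warning
    # pointing to TensorboardLogger.
    logger = TensorboardLogger(writer, train_interval=1000, update_interval=1000)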
diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py
index ccd873233..4ad73481c 100644
--- a/tianshou/utils/__init__.py
+++ b/tianshou/utils/__init__.py
@@ -1,12 +1,17 @@
 from tianshou.utils.config import tqdm_config
 from tianshou.utils.statistics import MovAvg, RunningMeanStd
-from tianshou.utils.log_tools import BasicLogger, LazyLogger, BaseLogger
+from tianshou.utils.logger.base import BaseLogger, LazyLogger
+from tianshou.utils.logger.tensorboard import TensorboardLogger, BasicLogger
+from tianshou.utils.logger.wandb import WandBLogger
+

 __all__ = [
     "MovAvg",
     "RunningMeanStd",
     "tqdm_config",
     "BaseLogger",
+    "TensorboardLogger",
     "BasicLogger",
     "LazyLogger",
+    "WandBLogger"
 ]
diff --git a/tianshou/utils/log_tools.py b/tianshou/utils/log_tools.py
deleted file mode 100644
index 5a25f1394..000000000
--- a/tianshou/utils/log_tools.py
+++ /dev/null
@@ -1,210 +0,0 @@
-import numpy as np
-from numbers import Number
-from abc import ABC, abstractmethod
-from torch.utils.tensorboard import SummaryWriter
-from typing import Any, Tuple, Union, Callable, Optional
-from tensorboard.backend.event_processing import event_accumulator
-
-
-WRITE_TYPE = Union[int, Number, np.number, np.ndarray]
-
-
-class BaseLogger(ABC):
-    """The base class for any logger which is compatible with trainer."""
-
-    def __init__(self, writer: Any) -> None:
-        super().__init__()
-        self.writer = writer
-
-    @abstractmethod
-    def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
-        """Specify how the writer is used to log data.
-
-        :param str key: namespace which the input data tuple belongs to.
-        :param int x: stands for the ordinate of the input data tuple.
-        :param y: stands for the abscissa of the input data tuple.
-        """
-        pass
-
-    def log_train_data(self, collect_result: dict, step: int) -> None:
-        """Use writer to log statistics generated during training.
-
-        :param collect_result: a dict containing information of data collected in
-            training stage, i.e., returns of collector.collect().
-        :param int step: stands for the timestep the collect_result being logged.
-        """
-        pass
-
-    def log_update_data(self, update_result: dict, step: int) -> None:
-        """Use writer to log statistics generated during updating.
-
-        :param update_result: a dict containing information of data collected in
-            updating stage, i.e., returns of policy.update().
-        :param int step: stands for the timestep the collect_result being logged.
-        """
-        pass
-
-    def log_test_data(self, collect_result: dict, step: int) -> None:
-        """Use writer to log statistics generated during evaluating.
-
-        :param collect_result: a dict containing information of data collected in
-            evaluating stage, i.e., returns of collector.collect().
-        :param int step: stands for the timestep the collect_result being logged.
-        """
-        pass
-
-    def save_data(
-        self,
-        epoch: int,
-        env_step: int,
-        gradient_step: int,
-        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
-    ) -> None:
-        """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.
-
-        :param int epoch: the epoch in trainer.
-        :param int env_step: the env_step in trainer.
-        :param int gradient_step: the gradient_step in trainer.
-        :param function save_checkpoint_fn: a hook defined by user, see trainer
-            documentation for detail.
-        """
-        pass
-
-    def restore_data(self) -> Tuple[int, int, int]:
-        """Return the metadata from existing log.
-
-        If it finds nothing or an error occurs during the recover process, it will
-        return the default parameters.
-
-        :return: epoch, env_step, gradient_step.
-        """
-        pass
-
-
-class BasicLogger(BaseLogger):
-    """A loggger that relies on tensorboard SummaryWriter by default to visualize \
-    and log statistics.
-
-    You can also rewrite write() func to use your own writer.
-
-    :param SummaryWriter writer: the writer to log data.
-    :param int train_interval: the log interval in log_train_data(). Default to 1000.
-    :param int test_interval: the log interval in log_test_data(). Default to 1.
-    :param int update_interval: the log interval in log_update_data(). Default to 1000.
-    :param int save_interval: the save interval in save_data(). Default to 1 (save at
-        the end of each epoch).
-    """
-
-    def __init__(
-        self,
-        writer: SummaryWriter,
-        train_interval: int = 1000,
-        test_interval: int = 1,
-        update_interval: int = 1000,
-        save_interval: int = 1,
-    ) -> None:
-        super().__init__(writer)
-        self.train_interval = train_interval
-        self.test_interval = test_interval
-        self.update_interval = update_interval
-        self.save_interval = save_interval
-        self.last_log_train_step = -1
-        self.last_log_test_step = -1
-        self.last_log_update_step = -1
-        self.last_save_step = -1
-
-    def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
-        self.writer.add_scalar(key, y, global_step=x)
-
-    def log_train_data(self, collect_result: dict, step: int) -> None:
-        """Use writer to log statistics generated during training.
-
-        :param collect_result: a dict containing information of data collected in
-            training stage, i.e., returns of collector.collect().
-        :param int step: stands for the timestep the collect_result being logged.
-
-        .. note::
-
-            ``collect_result`` will be modified in-place with "rew" and "len" keys.
-        """
-        if collect_result["n/ep"] > 0:
-            collect_result["rew"] = collect_result["rews"].mean()
-            collect_result["len"] = collect_result["lens"].mean()
-            if step - self.last_log_train_step >= self.train_interval:
-                self.write("train/n/ep", step, collect_result["n/ep"])
-                self.write("train/rew", step, collect_result["rew"])
-                self.write("train/len", step, collect_result["len"])
-                self.last_log_train_step = step
-
-    def log_test_data(self, collect_result: dict, step: int) -> None:
-        """Use writer to log statistics generated during evaluating.
-
-        :param collect_result: a dict containing information of data collected in
-            evaluating stage, i.e., returns of collector.collect().
-        :param int step: stands for the timestep the collect_result being logged.
-
-        .. note::
-
-            ``collect_result`` will be modified in-place with "rew", "rew_std", "len",
-            and "len_std" keys.
-        """
-        assert collect_result["n/ep"] > 0
-        rews, lens = collect_result["rews"], collect_result["lens"]
-        rew, rew_std, len_, len_std = rews.mean(), rews.std(), lens.mean(), lens.std()
-        collect_result.update(rew=rew, rew_std=rew_std, len=len_, len_std=len_std)
-        if step - self.last_log_test_step >= self.test_interval:
-            self.write("test/rew", step, rew)
-            self.write("test/len", step, len_)
-            self.write("test/rew_std", step, rew_std)
-            self.write("test/len_std", step, len_std)
-            self.last_log_test_step = step
-
-    def log_update_data(self, update_result: dict, step: int) -> None:
-        if step - self.last_log_update_step >= self.update_interval:
-            for k, v in update_result.items():
-                self.write(k, step, v)
-            self.last_log_update_step = step
-
-    def save_data(
-        self,
-        epoch: int,
-        env_step: int,
-        gradient_step: int,
-        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
-    ) -> None:
-        if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
-            self.last_save_step = epoch
-            save_checkpoint_fn(epoch, env_step, gradient_step)
-            self.write("save/epoch", epoch, epoch)
-            self.write("save/env_step", env_step, env_step)
-            self.write("save/gradient_step", gradient_step, gradient_step)
-
-    def restore_data(self) -> Tuple[int, int, int]:
-        ea = event_accumulator.EventAccumulator(self.writer.log_dir)
-        ea.Reload()
-
-        try:  # epoch / gradient_step
-            epoch = ea.scalars.Items("save/epoch")[-1].step
-            self.last_save_step = self.last_log_test_step = epoch
-            gradient_step = ea.scalars.Items("save/gradient_step")[-1].step
-            self.last_log_update_step = gradient_step
-        except KeyError:
-            epoch, gradient_step = 0, 0
-        try:  # offline trainer doesn't have env_step
-            env_step = ea.scalars.Items("save/env_step")[-1].step
-            self.last_log_train_step = env_step
-        except KeyError:
-            env_step = 0
-
-        return epoch, env_step, gradient_step
-
-
-class LazyLogger(BasicLogger):
-    """A loggger that does nothing. Used as the placeholder in trainer."""
-
-    def __init__(self) -> None:
-        super().__init__(None)  # type: ignore
-
-    def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
-        """The LazyLogger writes nothing."""
-        pass
diff --git a/tianshou/utils/logger/__init__.py b/tianshou/utils/logger/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py
new file mode 100644
index 000000000..c1ffe760d
--- /dev/null
+++ b/tianshou/utils/logger/base.py
@@ -0,0 +1,141 @@
+import numpy as np
+from numbers import Number
+from abc import ABC, abstractmethod
+from typing import Dict, Tuple, Union, Callable, Optional
+
+LOG_DATA_TYPE = Dict[str, Union[int, Number, np.number, np.ndarray]]
+
+
+class BaseLogger(ABC):
+    """The base class for any logger which is compatible with trainer.
+
+    Try to overwrite write() method to use your own writer.
+
+    :param int train_interval: the log interval in log_train_data(). Default to 1000.
+    :param int test_interval: the log interval in log_test_data(). Default to 1.
+    :param int update_interval: the log interval in log_update_data(). Default to 1000.
+    """
+
+    def __init__(
+        self,
+        train_interval: int = 1000,
+        test_interval: int = 1,
+        update_interval: int = 1000,
+    ) -> None:
+        super().__init__()
+        self.train_interval = train_interval
+        self.test_interval = test_interval
+        self.update_interval = update_interval
+        self.last_log_train_step = -1
+        self.last_log_test_step = -1
+        self.last_log_update_step = -1
+
+    @abstractmethod
+    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
+        """Specify how the writer is used to log data.
+
+        :param str step_type: namespace which the data dict belongs to.
+        :param int step: stands for the ordinate of the data dict.
+        :param dict data: the data to write with format ``{key: value}``.
+        """
+        pass
+
+    def log_train_data(self, collect_result: dict, step: int) -> None:
+        """Use writer to log statistics generated during training.
+
+        :param collect_result: a dict containing information of data collected in
+            training stage, i.e., returns of collector.collect().
+        :param int step: stands for the timestep the collect_result being logged.
+
+        .. note::
+
+            ``collect_result`` will be modified in-place with "rew" and "len" keys.
+        """
+        if collect_result["n/ep"] > 0:
+            collect_result["rew"] = collect_result["rews"].mean()
+            collect_result["len"] = collect_result["lens"].mean()
+            if step - self.last_log_train_step >= self.train_interval:
+                log_data = {
+                    "train/episode": collect_result["n/ep"],
+                    "train/reward": collect_result["rew"],
+                    "train/length": collect_result["len"],
+                }
+                self.write("train/env_step", step, log_data)
+                self.last_log_train_step = step
+
+    def log_test_data(self, collect_result: dict, step: int) -> None:
+        """Use writer to log statistics generated during evaluating.
+
+        :param collect_result: a dict containing information of data collected in
+            evaluating stage, i.e., returns of collector.collect().
+        :param int step: stands for the timestep the collect_result being logged.
+
+        .. note::
+
+            ``collect_result`` will be modified in-place with "rew", "rew_std", "len",
+            and "len_std" keys.
+        """
+        assert collect_result["n/ep"] > 0
+        rews, lens = collect_result["rews"], collect_result["lens"]
+        rew, rew_std, len_, len_std = rews.mean(), rews.std(), lens.mean(), lens.std()
+        collect_result.update(rew=rew, rew_std=rew_std, len=len_, len_std=len_std)
+        if step - self.last_log_test_step >= self.test_interval:
+            log_data = {
+                "test/env_step": step,
+                "test/reward": rew,
+                "test/length": len_,
+                "test/reward_std": rew_std,
+                "test/length_std": len_std,
+            }
+            self.write("test/env_step", step, log_data)
+            self.last_log_test_step = step
+
+    def log_update_data(self, update_result: dict, step: int) -> None:
+        """Use writer to log statistics generated during updating.
+
+        :param update_result: a dict containing information of data collected in
+            updating stage, i.e., returns of policy.update().
+        :param int step: stands for the timestep the collect_result being logged.
+        """
+        if step - self.last_log_update_step >= self.update_interval:
+            log_data = {f"update/{k}": v for k, v in update_result.items()}
+            self.write("update/gradient_step", step, log_data)
+            self.last_log_update_step = step
+
+    def save_data(
+        self,
+        epoch: int,
+        env_step: int,
+        gradient_step: int,
+        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
+    ) -> None:
+        """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.
+
+        :param int epoch: the epoch in trainer.
+        :param int env_step: the env_step in trainer.
+        :param int gradient_step: the gradient_step in trainer.
+        :param function save_checkpoint_fn: a hook defined by user, see trainer
+            documentation for detail.
+        """
+        pass
+
+    def restore_data(self) -> Tuple[int, int, int]:
+        """Return the metadata from existing log.
+
+        If it finds nothing or an error occurs during the recover process, it will
+        return the default parameters.
+
+        :return: epoch, env_step, gradient_step.
+        """
+        pass
+
+
+class LazyLogger(BaseLogger):
+    """A logger that does nothing. Used as the placeholder in trainer."""
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
+        """The LazyLogger writes nothing."""
+        pass
diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py
new file mode 100644
index 000000000..c65576b41
--- /dev/null
+++ b/tianshou/utils/logger/tensorboard.py
@@ -0,0 +1,83 @@
+import warnings
+from typing import Any, Tuple, Callable, Optional
+
+from torch.utils.tensorboard import SummaryWriter
+from tensorboard.backend.event_processing import event_accumulator
+
+from tianshou.utils.logger.base import BaseLogger, LOG_DATA_TYPE
+
+
+class TensorboardLogger(BaseLogger):
+    """A logger that relies on tensorboard SummaryWriter by default to visualize \
+    and log statistics.
+
+    :param SummaryWriter writer: the writer to log data.
+    :param int train_interval: the log interval in log_train_data(). Default to 1000.
+    :param int test_interval: the log interval in log_test_data(). Default to 1.
+    :param int update_interval: the log interval in log_update_data(). Default to 1000.
+    :param int save_interval: the save interval in save_data(). Default to 1 (save at
+        the end of each epoch).
+    """
+
+    def __init__(
+        self,
+        writer: SummaryWriter,
+        train_interval: int = 1000,
+        test_interval: int = 1,
+        update_interval: int = 1000,
+        save_interval: int = 1,
+    ) -> None:
+        super().__init__(train_interval, test_interval, update_interval)
+        self.save_interval = save_interval
+        self.last_save_step = -1
+        self.writer = writer
+
+    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
+        for k, v in data.items():
+            self.writer.add_scalar(k, v, global_step=step)
+
+    def save_data(
+        self,
+        epoch: int,
+        env_step: int,
+        gradient_step: int,
+        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
+    ) -> None:
+        if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
+            self.last_save_step = epoch
+            save_checkpoint_fn(epoch, env_step, gradient_step)
+            self.write("save/epoch", epoch, {"save/epoch": epoch})
+            self.write("save/env_step", env_step, {"save/env_step": env_step})
+            self.write("save/gradient_step", gradient_step,
+                       {"save/gradient_step": gradient_step})
+
+    def restore_data(self) -> Tuple[int, int, int]:
+        ea = event_accumulator.EventAccumulator(self.writer.log_dir)
+        ea.Reload()
+
+        try:  # epoch / gradient_step
+            epoch = ea.scalars.Items("save/epoch")[-1].step
+            self.last_save_step = self.last_log_test_step = epoch
+            gradient_step = ea.scalars.Items("save/gradient_step")[-1].step
+            self.last_log_update_step = gradient_step
+        except KeyError:
+            epoch, gradient_step = 0, 0
+        try:  # offline trainer doesn't have env_step
+            env_step = ea.scalars.Items("save/env_step")[-1].step
+            self.last_log_train_step = env_step
+        except KeyError:
+            env_step = 0
+
+        return epoch, env_step, gradient_step
+
+
+class BasicLogger(TensorboardLogger):
+    """BasicLogger has changed its name to TensorboardLogger in #427.
+
+    This class is for compatibility.
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        warnings.warn(
+            "Deprecated soon: BasicLogger has renamed to TensorboardLogger in #427.")
+        super().__init__(*args, **kwargs)
diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py
new file mode 100644
index 000000000..7a837c96c
--- /dev/null
+++ b/tianshou/utils/logger/wandb.py
@@ -0,0 +1,44 @@
+from tianshou.utils import BaseLogger
+from tianshou.utils.logger.base import LOG_DATA_TYPE
+
+try:
+    import wandb
+except ImportError:
+    pass
+
+
+class WandBLogger(BaseLogger):
+    """Weights and Biases logger that sends data to Weights and Biases.
+
+    Creates three panels with plots: train, test, and update.
+    Make sure to select the correct access for each panel in weights and biases:
+
+    - ``train/env_step`` for train plots
+    - ``test/env_step`` for test plots
+    - ``update/gradient_step`` for update plots
+
+    Example of usage:
+    ::
+
+        with wandb.init(project="My Project"):
+            logger = WandBLogger()
+            result = onpolicy_trainer(policy, train_collector, test_collector,
+                    logger=logger)
+
+    :param int train_interval: the log interval in log_train_data(). Default to 1000.
+    :param int test_interval: the log interval in log_test_data(). Default to 1.
+    :param int update_interval: the log interval in log_update_data().
+        Default to 1000.
+    """
+
+    def __init__(
+        self,
+        train_interval: int = 1000,
+        test_interval: int = 1,
+        update_interval: int = 1000,
+    ) -> None:
+        super().__init__(train_interval, test_interval, update_interval)
+
+    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
+        data[step_type] = step
+        wandb.log(data)