diff --git a/README.md b/README.md index 80ee2ff12..5037a7868 100644 --- a/README.md +++ b/README.md @@ -197,6 +197,7 @@ buffer_size = 20000 eps_train, eps_test = 0.1, 0.05 step_per_epoch, step_per_collect = 10000, 10 writer = SummaryWriter('log/dqn') # tensorboard is also supported! +logger = ts.utils.BasicLogger(writer) ``` Make environments: @@ -237,7 +238,7 @@ result = ts.trainer.offpolicy_trainer( train_fn=lambda epoch, env_step: policy.set_eps(eps_train), test_fn=lambda epoch, env_step: policy.set_eps(eps_test), stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, - writer=writer) + logger=logger) print(f'Finished training! Use {result["duration"]}') ``` diff --git a/docs/contributor.rst b/docs/contributor.rst index b48d7ffb5..c594b2c0d 100644 --- a/docs/contributor.rst +++ b/docs/contributor.rst @@ -7,3 +7,4 @@ We always welcome contributions to help make Tianshou better. Below are an incom * Minghao Zhang (`Mehooz `_) * Alexis Duburcq (`duburcqa `_) * Kaichao You (`youkaichao `_) +* Huayu Chen (`ChenDRAG `_) diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index 40e4a399f..5c5d547d5 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -130,7 +130,7 @@ Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.t train_fn=lambda epoch, env_step: policy.set_eps(0.1), test_fn=lambda epoch, env_step: policy.set_eps(0.05), stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, - writer=None) + logger=None) print(f'Finished training! Use {result["duration"]}') The meaning of each parameter is as follows (full description can be found at :func:`~tianshou.trainer.offpolicy_trainer`): @@ -143,15 +143,17 @@ The meaning of each parameter is as follows (full description can be found at :f * ``train_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training". * ``test_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing". * ``stop_fn``: A function receives the average undiscounted returns of the testing result, return a boolean which indicates whether reaching the goal. -* ``writer``: See below. +* ``logger``: See below. The trainer supports `TensorBoard `_ for logging. It can be used as: :: from torch.utils.tensorboard import SummaryWriter + from tianshou.utils import BasicLogger writer = SummaryWriter('log/dqn') + logger = BasicLogger(writer) -Pass the writer into the trainer, and the training result will be recorded into the TensorBoard. +Pass the logger into the trainer, and the training result will be recorded into the TensorBoard. The returned result is a dictionary as follows: :: diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst index 3d7f28106..64b0dfd70 100644 --- a/docs/tutorials/tictactoe.rst +++ b/docs/tutorials/tictactoe.rst @@ -176,6 +176,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modul import numpy as np from copy import deepcopy from torch.utils.tensorboard import SummaryWriter + from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net @@ -319,11 +320,10 @@ With the above preparation, we are close to the first learned agent. 
The followi train_collector.collect(n_step=args.batch_size * args.training_num) # ======== tensorboard logging setup ========= - if not hasattr(args, 'writer'): - log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn') - writer = SummaryWriter(log_path) - else: - writer = args.writer + log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn') + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer) # ======== callback functions used during training ========= @@ -359,7 +359,7 @@ With the above preparation, we are close to the first learned agent. The followi args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step, - writer=writer, test_in_train=False, reward_metric=reward_metric) + logger=logger, test_in_train=False, reward_metric=reward_metric) agent = policy.policies[args.agent_id - 1] # let's watch the match! diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py index 72b8d1336..05633bc4a 100644 --- a/examples/atari/atari_bcq.py +++ b/examples/atari/atari_bcq.py @@ -2,10 +2,12 @@ import torch import pickle import pprint +import datetime import argparse import numpy as np from torch.utils.tensorboard import SummaryWriter +from tianshou.utils import BasicLogger from tianshou.env import SubprocVectorEnv from tianshou.trainer import offline_trainer from tianshou.utils.net.discrete import Actor @@ -39,7 +41,7 @@ def get_args(): parser.add_argument("--resume-path", type=str, default=None) parser.add_argument("--watch", default=False, action="store_true", help="watch the play of pre-trained policy only") - parser.add_argument("--log-interval", type=int, default=1000) + parser.add_argument("--log-interval", type=int, default=100) parser.add_argument( "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5", @@ -113,8 +115,13 @@ def test_discrete_bcq(args=get_args()): # collector test_collector = Collector(policy, test_envs, exploration_noise=True) - log_path = os.path.join(args.logdir, args.task, 'discrete_bcq') + # log + log_path = os.path.join( + args.logdir, args.task, 'bcq', + f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}') writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer, update_interval=args.log_interval) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -141,7 +148,7 @@ def watch(): result = offline_trainer( policy, buffer, test_collector, args.epoch, args.update_per_epoch, args.test_num, args.batch_size, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, + stop_fn=stop_fn, save_fn=save_fn, logger=logger, log_interval=args.log_interval, ) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index d0a7ab81d..0956d691c 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -6,6 +6,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import C51Policy +from tianshou.utils import BasicLogger from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, VectorReplayBuffer @@ -98,6 +99,8 @@ def test_c51(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'c51') writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), 
os.path.join(log_path, 'policy.pth')) @@ -118,7 +121,7 @@ def train_fn(epoch, env_step): else: eps = args.eps_train_final policy.set_eps(eps) - writer.add_scalar('train/eps', eps, global_step=env_step) + logger.write('train/eps', env_step, eps) def test_fn(epoch, env_step): policy.set_eps(args.eps_test) @@ -144,7 +147,7 @@ def watch(): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, + stop_fn=stop_fn, save_fn=save_fn, logger=logger, update_per_step=args.update_per_step, test_in_train=False) pprint.pprint(result) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 077d24891..b3f36c893 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -6,6 +6,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DQNPolicy +from tianshou.utils import BasicLogger from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, VectorReplayBuffer @@ -94,6 +95,8 @@ def test_dqn(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'dqn') writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -114,7 +117,7 @@ def train_fn(epoch, env_step): else: eps = args.eps_train_final policy.set_eps(eps) - writer.add_scalar('train/eps', eps, global_step=env_step) + logger.write('train/eps', env_step, eps) def test_fn(epoch, env_step): policy.set_eps(args.eps_test) @@ -154,7 +157,7 @@ def watch(): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, + stop_fn=stop_fn, save_fn=save_fn, logger=logger, update_per_step=args.update_per_step, test_in_train=False) pprint.pprint(result) diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index e2eed3cfd..ae2a26f4f 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -5,6 +5,7 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter +from tianshou.utils import BasicLogger from tianshou.policy import QRDQNPolicy from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer @@ -96,6 +97,8 @@ def test_qrdqn(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'qrdqn') writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -116,7 +119,7 @@ def train_fn(epoch, env_step): else: eps = args.eps_train_final policy.set_eps(eps) - writer.add_scalar('train/eps', eps, global_step=env_step) + logger.write('train/eps', env_step, eps) def test_fn(epoch, env_step): policy.set_eps(args.eps_test) @@ -142,7 +145,7 @@ def watch(): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, + stop_fn=stop_fn, save_fn=save_fn, logger=logger, update_per_step=args.update_per_step, test_in_train=False) pprint.pprint(result) diff --git a/examples/atari/runnable/pong_a2c.py 
b/examples/atari/runnable/pong_a2c.py index 5b760c14e..023824ce6 100644 --- a/examples/atari/runnable/pong_a2c.py +++ b/examples/atari/runnable/pong_a2c.py @@ -4,7 +4,9 @@ import argparse import numpy as np from torch.utils.tensorboard import SummaryWriter + from tianshou.policy import A2CPolicy +from tianshou.utils import BasicLogger from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer @@ -79,7 +81,9 @@ def test_a2c(args=get_args()): preprocess_fn=preprocess_fn, exploration_noise=True) test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn) # log - writer = SummaryWriter(os.path.join(args.logdir, args.task, 'a2c')) + log_path = os.path.join(args.logdir, args.task, 'a2c') + writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def stop_fn(mean_rewards): if env.env.spec.reward_threshold: @@ -91,7 +95,7 @@ def stop_fn(mean_rewards): result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, - episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer) + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, logger=logger) if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! diff --git a/examples/atari/runnable/pong_ppo.py b/examples/atari/runnable/pong_ppo.py index 8a2a6845d..36728de6f 100644 --- a/examples/atari/runnable/pong_ppo.py +++ b/examples/atari/runnable/pong_ppo.py @@ -6,11 +6,12 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import PPOPolicy +from tianshou.utils import BasicLogger from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.discrete import Actor, Critic +from tianshou.data import Collector, VectorReplayBuffer from atari import create_atari_environment, preprocess_fn @@ -84,7 +85,9 @@ def test_ppo(args=get_args()): preprocess_fn=preprocess_fn, exploration_noise=True) test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn) # log - writer = SummaryWriter(os.path.join(args.logdir, args.task, 'ppo')) + log_path = os.path.join(args.logdir, args.task, 'ppo') + writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def stop_fn(mean_rewards): if env.env.spec.reward_threshold: @@ -96,7 +99,8 @@ def stop_fn(mean_rewards): result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, - episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer) + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, logger=logger) + if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! 
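
Every example above follows the same migration pattern: the SummaryWriter is created exactly as before, wrapped in a BasicLogger, and the logger (rather than the raw writer) is handed to the trainer. A condensed sketch of that pattern, assuming the argparse namespace `args`, the policy, and the collectors have already been built as in the scripts above (the log path here is a placeholder):

    from torch.utils.tensorboard import SummaryWriter
    from tianshou.utils import BasicLogger
    from tianshou.trainer import offpolicy_trainer

    writer = SummaryWriter('log/dqn')      # unchanged: plain TensorBoard writer
    writer.add_text("args", str(args))     # optional: dump hyper-parameters as text
    logger = BasicLogger(writer)           # new: wrap the writer in a logger
    result = offpolicy_trainer(
        policy, train_collector, test_collector,
        args.epoch, args.step_per_epoch, args.step_per_collect,
        args.test_num, args.batch_size,
        logger=logger)                     # previously: writer=writer
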
diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 69d0bfbee..58f1a3783 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -7,6 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DQNPolicy +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer @@ -81,6 +82,7 @@ def test_dqn(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'dqn') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -106,7 +108,7 @@ def test_fn(epoch, env_step): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer) + stop_fn=stop_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index a903008ab..d5a8f0577 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -7,6 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import SACPolicy +from tianshou.utils import BasicLogger from tianshou.utils.net.common import Net from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer @@ -134,6 +135,7 @@ def test_sac_bipedal(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'sac') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -146,7 +148,7 @@ def stop_fn(mean_rewards): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, update_per_step=args.update_per_step, test_in_train=False, - stop_fn=stop_fn, save_fn=save_fn, writer=writer) + stop_fn=stop_fn, save_fn=save_fn, logger=logger) if __name__ == '__main__': pprint.pprint(result) diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 3d5d033f2..f5a9d3bdf 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -7,6 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DQNPolicy +from tianshou.utils import BasicLogger from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, VectorReplayBuffer @@ -83,6 +84,7 @@ def test_dqn(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'dqn') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -102,7 +104,7 @@ def test_fn(epoch, env_step): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, update_per_step=args.update_per_step, stop_fn=stop_fn, train_fn=train_fn, - test_fn=test_fn, save_fn=save_fn, writer=writer) + test_fn=test_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 
333dab49c..7fe3daed2 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -7,11 +7,12 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import SACPolicy -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.exploration import OUNoise from tianshou.utils.net.common import Net +from tianshou.trainer import offpolicy_trainer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import ActorProb, Critic @@ -103,6 +104,7 @@ def test_sac(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'sac') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -115,7 +117,7 @@ def stop_fn(mean_rewards): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, update_per_step=args.update_per_step, stop_fn=stop_fn, - save_fn=save_fn, writer=writer) + save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 22dda7d9b..d7ead224e 100644 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -2,11 +2,13 @@ import gym import torch import pprint +import datetime import argparse import numpy as np from torch.utils.tensorboard import SummaryWriter from tianshou.policy import SACPolicy +from tianshou.utils import BasicLogger from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer @@ -114,8 +116,11 @@ def test_sac(args=get_args()): exploration_noise=True) test_collector = Collector(policy, test_envs) # log - log_path = os.path.join(args.logdir, args.task, 'sac') + log_path = os.path.join(args.logdir, args.task, 'sac', 'seed_' + str( + args.seed) + '_' + datetime.datetime.now().strftime('%m%d-%H%M%S')) writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer, train_interval=args.log_interval) def watch(): # watch agent's performance @@ -141,8 +146,8 @@ def stop_fn(mean_rewards): result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer, - update_per_step=args.update_per_step, log_interval=args.log_interval) + args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger, + update_per_step=args.update_per_step) pprint.pprint(result) watch() diff --git a/examples/mujoco/runnable/ant_v2_ddpg.py b/examples/mujoco/runnable/ant_v2_ddpg.py index a83abc37b..bc75e1f51 100644 --- a/examples/mujoco/runnable/ant_v2_ddpg.py +++ b/examples/mujoco/runnable/ant_v2_ddpg.py @@ -1,3 +1,4 @@ +import os import gym import torch import pprint @@ -6,6 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DDPGPolicy +from tianshou.utils import BasicLogger from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer @@ -79,7 +81,9 @@ def test_ddpg(args=get_args()): exploration_noise=True) test_collector = Collector(policy, test_envs) # log - writer = SummaryWriter(args.logdir + '/' + 'ddpg') + log_path = 
os.path.join(args.logdir, args.task, 'ddpg') + writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold @@ -88,7 +92,7 @@ def stop_fn(mean_rewards): result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, stop_fn=stop_fn, writer=writer) + args.batch_size, stop_fn=stop_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) diff --git a/examples/mujoco/runnable/ant_v2_td3.py b/examples/mujoco/runnable/ant_v2_td3.py index 1bda7aa16..004b604a6 100644 --- a/examples/mujoco/runnable/ant_v2_td3.py +++ b/examples/mujoco/runnable/ant_v2_td3.py @@ -1,3 +1,4 @@ +import os import gym import torch import pprint @@ -6,6 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import TD3Policy +from tianshou.utils import BasicLogger from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net from tianshou.exploration import GaussianNoise @@ -88,7 +90,9 @@ def test_td3(args=get_args()): test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log - writer = SummaryWriter(args.logdir + '/' + 'td3') + log_path = os.path.join(args.logdir, args.task, 'td3') + writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold @@ -97,7 +101,7 @@ def stop_fn(mean_rewards): result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, stop_fn=stop_fn, writer=writer) + args.batch_size, stop_fn=stop_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) diff --git a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py index d64155dbe..2ed046294 100644 --- a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py +++ b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py @@ -2,12 +2,14 @@ import gym import torch import pprint +import datetime import argparse import numpy as np import pybullet_envs from torch.utils.tensorboard import SummaryWriter from tianshou.policy import SACPolicy +from tianshou.utils import BasicLogger from tianshou.utils.net.common import Net from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer @@ -88,8 +90,10 @@ def test_sac(args=get_args()): test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log - log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id) + log_path = os.path.join(args.logdir, args.task, 'sac', 'seed_' + str( + args.seed) + '_' + datetime.datetime.now().strftime('%m%d-%H%M%S')) writer = SummaryWriter(log_path) + logger = BasicLogger(writer, train_interval=args.log_interval) def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold @@ -99,7 +103,7 @@ def stop_fn(mean_rewards): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, - writer=writer, log_interval=args.log_interval) + logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) diff --git a/examples/mujoco/runnable/point_maze_td3.py b/examples/mujoco/runnable/point_maze_td3.py index 
f23e1b057..8e5f37be7 100644 --- a/examples/mujoco/runnable/point_maze_td3.py +++ b/examples/mujoco/runnable/point_maze_td3.py @@ -1,3 +1,4 @@ +import os import gym import torch import pprint @@ -6,6 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import TD3Policy +from tianshou.utils import BasicLogger from tianshou.utils.net.common import Net from tianshou.env import SubprocVectorEnv from tianshou.exploration import GaussianNoise @@ -93,7 +95,9 @@ def test_td3(args=get_args()): test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log - writer = SummaryWriter(args.logdir + '/' + 'td3') + log_path = os.path.join(args.logdir, args.task, 'td3') + writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def stop_fn(mean_rewards): if env.spec.reward_threshold: @@ -105,7 +109,7 @@ def stop_fn(mean_rewards): result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, stop_fn=stop_fn, writer=writer) + args.batch_size, stop_fn=stop_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 93d2a48ca..aa72272a9 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -2,7 +2,6 @@ import numpy as np from tianshou.utils import MovAvg -from tianshou.utils import SummaryWriter from tianshou.utils.net.common import MLP, Net from tianshou.exploration import GaussianNoise, OUNoise from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic @@ -77,25 +76,7 @@ def test_net(): assert list(net(data, act).shape) == [bsz, 1] -def test_summary_writer(): - # get first instance by key of `default` or your own key - writer1 = SummaryWriter.get_instance( - key="first", log_dir="log/test_sw/first") - assert writer1.log_dir == "log/test_sw/first" - writer2 = SummaryWriter.get_instance() - assert writer1 is writer2 - # create new instance by specify a new key - writer3 = SummaryWriter.get_instance( - key="second", log_dir="log/test_sw/second") - assert writer3.log_dir == "log/test_sw/second" - writer4 = SummaryWriter.get_instance(key="second") - assert writer3 is writer4 - assert writer1 is not writer3 - assert writer1.log_dir != writer4.log_dir - - if __name__ == '__main__': test_noise() test_moving_average() test_net() - test_summary_writer() diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index afe88e16b..311aa65a7 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -7,6 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DDPGPolicy +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer @@ -93,6 +94,7 @@ def test_ddpg(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'ddpg') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -105,7 +107,7 @@ def stop_fn(mean_rewards): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, update_per_step=args.update_per_step, stop_fn=stop_fn, - save_fn=save_fn, writer=writer) + save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == 
'__main__': pprint.pprint(result) diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 4f8ede1a0..b4fd383c7 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -8,6 +8,7 @@ from torch.distributions import Independent, Normal from tianshou.policy import PPOPolicy +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer @@ -111,6 +112,7 @@ def dist(*logits): # log log_path = os.path.join(args.logdir, args.task, 'ppo') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -123,7 +125,7 @@ def stop_fn(mean_rewards): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, - writer=writer) + logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 0a96dbfa9..1b4a977ef 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -6,6 +6,7 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer @@ -102,6 +103,7 @@ def test_sac_with_il(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'sac') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -114,7 +116,7 @@ def stop_fn(mean_rewards): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, update_per_step=args.update_per_step, stop_fn=stop_fn, - save_fn=save_fn, writer=writer) + save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) @@ -146,7 +148,7 @@ def stop_fn(mean_rewards): result = offpolicy_trainer( il_policy, train_collector, il_test_collector, args.epoch, args.il_step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index bbc32d912..d67818194 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -7,6 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import TD3Policy +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer @@ -106,6 +107,7 @@ def test_td3(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'td3') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -118,7 +120,7 @@ def stop_fn(mean_rewards): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, 
update_per_step=args.update_per_step, stop_fn=stop_fn, - save_fn=save_fn, writer=writer) + save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 1032b3176..196dc28b5 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -6,6 +6,7 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.data import Collector, VectorReplayBuffer @@ -89,6 +90,7 @@ def test_a2c_with_il(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'a2c') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -101,7 +103,7 @@ def stop_fn(mean_rewards): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, - writer=writer) + logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) @@ -130,7 +132,7 @@ def stop_fn(mean_rewards): result = offpolicy_trainer( il_policy, train_collector, il_test_collector, args.epoch, args.il_step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index fe0573e1d..53768dae6 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -7,6 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import C51Policy +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer @@ -89,6 +90,7 @@ def test_c51(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'c51') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -115,7 +117,7 @@ def test_fn(epoch, env_step): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer) + stop_fn=stop_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index e9104c152..c59910e84 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -8,6 +8,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DQNPolicy +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer @@ -91,6 +92,7 @@ def test_dqn(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'dqn') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -117,7 +119,7 @@ def 
test_fn(epoch, env_step): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index dc2e06c00..33f0432a3 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -7,6 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DQNPolicy +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.trainer import offpolicy_trainer from tianshou.utils.net.common import Recurrent @@ -77,6 +78,7 @@ def test_drqn(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'drqn') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -96,7 +98,7 @@ def test_fn(epoch, env_step): args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, - save_fn=save_fn, writer=writer) + save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py index 09c4c52d2..996f7d599 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/discrete/test_il_bcq.py @@ -8,6 +8,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offline_trainer @@ -82,6 +83,7 @@ def test_discrete_bcq(args=get_args()): log_path = os.path.join(args.logdir, args.task, 'discrete_bcq') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -92,7 +94,7 @@ def stop_fn(mean_rewards): result = offline_trainer( policy, buffer, test_collector, args.epoch, args.update_per_epoch, args.test_num, args.batch_size, - stop_fn=stop_fn, save_fn=save_fn, writer=writer) + stop_fn=stop_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 784ae70db..6ebeb2686 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -7,6 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import PGPolicy +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer @@ -72,6 +73,7 @@ def test_pg(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'pg') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -84,7 +86,7 @@ def stop_fn(mean_rewards): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, - writer=writer) + logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) 
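
The Atari examples above also show what happens to ad-hoc scalars that previously went through `writer.add_scalar`: they are now emitted with `logger.write(key, step, value)`. A sketch of the epsilon-annealing hook, loosely following the Atari DQN/C51/QRDQN scripts (the schedule constants and the surrounding `args`, `policy`, and `logger` names come from those scripts and are not defined here):

    def train_fn(epoch, env_step):
        # linearly decay epsilon over the first 1M environment steps, then keep it fixed
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        # replaces: writer.add_scalar('train/eps', eps, global_step=env_step)
        logger.write('train/eps', env_step, eps)
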
diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 35634c675..5821e7be8 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -7,6 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import PPOPolicy +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer @@ -98,6 +99,7 @@ def test_ppo(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'ppo') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -110,7 +112,7 @@ def stop_fn(mean_rewards): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, - writer=writer) + logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index e5ce61b98..2268b63de 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -6,6 +6,7 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter +from tianshou.utils import BasicLogger from tianshou.policy import QRDQNPolicy from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net @@ -87,6 +88,7 @@ def test_qrdqn(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'qrdqn') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -113,7 +115,7 @@ def test_fn(epoch, env_step): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, + stop_fn=stop_fn, save_fn=save_fn, logger=logger, update_per_step=args.update_per_step) assert stop_fn(result['best_reward']) diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index ebcb75157..d7f408ffe 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -6,12 +6,13 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter +from tianshou.utils import BasicLogger from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import DiscreteSACPolicy +from tianshou.trainer import offpolicy_trainer from tianshou.utils.net.discrete import Actor, Critic +from tianshou.data import Collector, VectorReplayBuffer def get_args(): @@ -99,6 +100,7 @@ def test_discrete_sac(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'discrete_sac') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -110,7 +112,7 @@ def stop_fn(mean_rewards): result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer, + args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger, update_per_step=args.update_per_step, test_in_train=False) assert 
stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 01ea98a58..c04261868 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -1,3 +1,4 @@ +import os import gym import torch import pprint @@ -6,6 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import PSRLPolicy +# from tianshou.utils import BasicLogger from tianshou.trainer import onpolicy_trainer from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv @@ -66,7 +68,10 @@ def test_psrl(args=get_args()): exploration_noise=True) test_collector = Collector(policy, test_envs) # log - writer = SummaryWriter(args.logdir + '/' + args.task) + log_path = os.path.join(args.logdir, args.task, 'psrl') + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + # logger = BasicLogger(writer) def stop_fn(mean_rewards): if env.spec.reward_threshold: @@ -75,11 +80,12 @@ def stop_fn(mean_rewards): return False train_collector.collect(n_step=args.buffer_size, random=True) - # trainer + # trainer, test it without logger result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, 1, args.test_num, 0, - episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, + # logger=logger, test_in_train=False) if __name__ == '__main__': diff --git a/test/multiagent/Gomoku.py b/test/multiagent/Gomoku.py index 0c0dcf58c..4c88656cb 100644 --- a/test/multiagent/Gomoku.py +++ b/test/multiagent/Gomoku.py @@ -7,6 +7,7 @@ from tianshou.env import DummyVectorEnv from tianshou.data import Collector from tianshou.policy import RandomPolicy +from tianshou.utils import BasicLogger from tic_tac_toe_env import TicTacToeEnv from tic_tac_toe import get_parser, get_agents, train_agent, watch @@ -31,7 +32,8 @@ def gomoku(args=get_args()): # log log_path = os.path.join(args.logdir, 'Gomoku', 'dqn') - args.writer = SummaryWriter(log_path) + writer = SummaryWriter(log_path) + args.logger = BasicLogger(writer) opponent_pool = [agent_opponent] diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index edf066e09..3e92838ab 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -6,6 +6,7 @@ from typing import Optional, Tuple from torch.utils.tensorboard import SummaryWriter +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer @@ -131,12 +132,10 @@ def env_func(): # policy.set_eps(1) train_collector.collect(n_step=args.batch_size * args.training_num) # log - if not hasattr(args, 'writer'): - log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn') - writer = SummaryWriter(log_path) - args.writer = writer - else: - writer = args.writer + log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn') + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer) def save_fn(policy): if hasattr(args, 'model_save_path'): @@ -166,7 +165,7 @@ def reward_metric(rews): args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step, - writer=writer, test_in_train=False, reward_metric=reward_metric) + logger=logger, test_in_train=False, 
reward_metric=reward_metric) return result, policy.policies[args.agent_id - 1] diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index b048c1ead..2ceb8081f 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -92,8 +92,6 @@ def __setstate__(self, state: Dict[str, Any]) -> None: ("buffer.__getattr__" is customized). """ self.__dict__.update(state) - # compatible with version == 0.3.1's HDF5 data format - self._indices = np.arange(self.maxsize) def __setattr__(self, key: str, value: Any) -> None: """Set self.key = value.""" diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index bb3239e0e..640210942 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -184,9 +184,9 @@ def collect( * ``n/ep`` collected number of episodes. * ``n/st`` collected number of steps. - * ``rews`` list of episode reward over collected episodes. - * ``lens`` list of episode length over collected episodes. - * ``idxs`` list of episode start index in buffer over collected episodes. + * ``rews`` array of episode reward over collected episodes. + * ``lens`` array of episode length over collected episodes. + * ``idxs`` array of episode start index in buffer over collected episodes. """ assert not self.env.is_async, "Please use AsyncCollector if using async venv." if n_step is not None: @@ -379,9 +379,9 @@ def collect( * ``n/ep`` collected number of episodes. * ``n/st`` collected number of steps. - * ``rews`` list of episode reward over collected episodes. - * ``lens`` list of episode length over collected episodes. - * ``idxs`` list of episode start index in buffer over collected episodes. + * ``rews`` array of episode reward over collected episodes. + * ``lens`` array of episode length over collected episodes. + * ``idxs`` array of episode start index in buffer over collected episodes. """ # collect at least n_step or n_episode if n_step is not None: diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 3b663d630..99a2f8a81 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -4,7 +4,7 @@ from torch import nn from numba import njit from abc import ABC, abstractmethod -from typing import Any, List, Union, Mapping, Optional, Callable +from typing import Any, Dict, Union, Optional, Callable from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy @@ -124,12 +124,10 @@ def process_fn( return batch @abstractmethod - def learn( - self, batch: Batch, **kwargs: Any - ) -> Mapping[str, Union[float, List[float]]]: + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]: """Update policy with a given batch of data. - :return: A dict which includes loss and its corresponding label. + :return: A dict, including the data needed to be logged (e.g., loss). .. note:: @@ -162,18 +160,20 @@ def post_process_fn( def update( self, sample_size: int, buffer: Optional[ReplayBuffer], **kwargs: Any - ) -> Mapping[str, Union[float, List[float]]]: + ) -> Dict[str, Any]: """Update the policy network and replay buffer. - It includes 3 function steps: process_fn, learn, and post_process_fn. - In addition, this function will change the value of ``self.updating``: - it will be False before this function and will be True when executing - :meth:`update`. Please refer to :ref:`policy_state` for more detailed - explanation. + It includes 3 function steps: process_fn, learn, and post_process_fn. 
In + addition, this function will change the value of ``self.updating``: it will be + False before this function and will be True when executing :meth:`update`. + Please refer to :ref:`policy_state` for more detailed explanation. - :param int sample_size: 0 means it will extract all the data from the - buffer, otherwise it will sample a batch with given sample_size. + :param int sample_size: 0 means it will extract all the data from the buffer, + otherwise it will sample a batch with given sample_size. :param ReplayBuffer buffer: the corresponding replay buffer. + + :return: A dict, including the data needed to be logged (e.g., loss) from + ``policy.learn()``. """ if buffer is None: return {} diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 61714f7a0..b3588ae5b 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -2,11 +2,10 @@ import tqdm import numpy as np from collections import defaultdict -from torch.utils.tensorboard import SummaryWriter from typing import Dict, Union, Callable, Optional from tianshou.policy import BasePolicy -from tianshou.utils import tqdm_config, MovAvg +from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger from tianshou.data import Collector, ReplayBuffer from tianshou.trainer import test_episode, gather_info @@ -23,8 +22,7 @@ def offline_trainer( stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None, reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, - writer: Optional[SummaryWriter] = None, - log_interval: int = 1, + logger: BaseLogger = LazyLogger(), verbose: bool = True, ) -> Dict[str, Union[float, str]]: """A wrapper for offline trainer procedure. @@ -55,9 +53,8 @@ def offline_trainer( result to monitor training in the multi-agent RL setting. This function specifies what is the desired metric, e.g., the reward of agent 1 or the average reward over all agents. - :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter; - if None is given, it will not write logs to TensorBoard. Default to None. - :param int log_interval: the log interval of the writer. Default to 1. + :param BaseLogger logger: A logger that logs statistics during updating/testing. + Default to a logger that doesn't log anything. :param bool verbose: whether to print the information. Default to True. :return: See :func:`~tianshou.trainer.gather_info`. 
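
Note that the trainers' `log_interval` argument is removed along with `writer`; logging frequency now lives on the logger itself, as the Atari BCQ and MuJoCo SAC examples above do with `update_interval` and `train_interval`. A sketch for the offline case, assuming `writer`, `args`, and the usual callbacks carry over from the surrounding script:

    # was: offline_trainer(..., writer=writer, log_interval=args.log_interval)
    logger = BasicLogger(writer, update_interval=args.log_interval)
    result = offline_trainer(
        policy, buffer, test_collector,
        args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
        stop_fn=stop_fn, save_fn=save_fn, logger=logger)
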
@@ -67,10 +64,9 @@ def offline_trainer(
     start_time = time.time()
     test_collector.reset_stat()
     test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
-                               writer, gradient_step, reward_metric)
+                               logger, gradient_step, reward_metric)
     best_epoch = 0
-    best_reward = test_result["rews"].mean()
-    best_reward_std = test_result["rews"].std()
+    best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
     for epoch in range(1, 1 + max_epoch):
         policy.train()
         with tqdm.trange(
@@ -82,27 +78,23 @@ def offline_trainer(
                 data = {"gradient_step": str(gradient_step)}
                 for k in losses.keys():
                     stat[k].add(losses[k])
-                    data[k] = f"{stat[k].get():.6f}"
-                    if writer and gradient_step % log_interval == 0:
-                        writer.add_scalar(
-                            "train/" + k, stat[k].get(),
-                            global_step=gradient_step)
+                    losses[k] = stat[k].get()
+                    data[k] = f"{losses[k]:.6f}"
+                logger.log_update_data(losses, gradient_step)
                 t.set_postfix(**data)
         # test
-        test_result = test_episode(policy, test_collector, test_fn, epoch,
-                                   episode_per_test, writer, gradient_step,
-                                   reward_metric)
-        if best_epoch == -1 or best_reward < test_result["rews"].mean():
-            best_reward = test_result["rews"].mean()
-            best_reward_std = test_result['rews'].std()
+        test_result = test_episode(
+            policy, test_collector, test_fn, epoch, episode_per_test,
+            logger, gradient_step, reward_metric)
+        rew, rew_std = test_result["rew"], test_result["rew_std"]
+        if best_epoch == -1 or best_reward < rew:
+            best_reward, best_reward_std = rew, rew_std
             best_epoch = epoch
             if save_fn:
                 save_fn(policy)
         if verbose:
-            print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± "
-                  f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± "
-                  f"{best_reward_std:.6f} in #{best_epoch}")
+            print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
+                  f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
         if stop_fn and stop_fn(best_reward):
             break
-    return gather_info(start_time, None, test_collector,
-                       best_reward, best_reward_std)
+    return gather_info(start_time, None, test_collector, best_reward, best_reward_std)
diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py
index 54e7cb166..5f233bfef 100644
--- a/tianshou/trainer/offpolicy.py
+++ b/tianshou/trainer/offpolicy.py
@@ -2,12 +2,11 @@
 import tqdm
 import numpy as np
 from collections import defaultdict
-from torch.utils.tensorboard import SummaryWriter
 from typing import Dict, Union, Callable, Optional
 
 from tianshou.data import Collector
 from tianshou.policy import BasePolicy
-from tianshou.utils import tqdm_config, MovAvg
+from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
 from tianshou.trainer import test_episode, gather_info
 
 
@@ -26,8 +25,7 @@ def offpolicy_trainer(
     stop_fn: Optional[Callable[[float], bool]] = None,
     save_fn: Optional[Callable[[BasePolicy], None]] = None,
     reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
-    writer: Optional[SummaryWriter] = None,
-    log_interval: int = 1,
+    logger: BaseLogger = LazyLogger(),
     verbose: bool = True,
     test_in_train: bool = True,
 ) -> Dict[str, Union[float, str]]:
@@ -70,25 +68,24 @@ def offpolicy_trainer(
         result to monitor training in the multi-agent RL setting. This function
         specifies what is the desired metric, e.g., the reward of agent 1 or the
         average reward over all agents.
-    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter;
-        if None is given, it will not write logs to TensorBoard. Default to None.
-    :param int log_interval: the log interval of the writer. Default to 1.
+    :param BaseLogger logger: A logger that logs statistics during
+        training/testing/updating. Default to a logger that doesn't log anything.
     :param bool verbose: whether to print the information. Default to True.
     :param bool test_in_train: whether to test in the training phase. Default to
        True.
     :return: See :func:`~tianshou.trainer.gather_info`.
     """
     env_step, gradient_step = 0, 0
+    last_rew, last_len = 0.0, 0
     stat: Dict[str, MovAvg] = defaultdict(MovAvg)
     start_time = time.time()
     train_collector.reset_stat()
     test_collector.reset_stat()
     test_in_train = test_in_train and train_collector.policy == policy
     test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
-                               writer, env_step, reward_metric)
+                               logger, env_step, reward_metric)
     best_epoch = 0
-    best_reward = test_result["rews"].mean()
-    best_reward_std = test_result["rews"].std()
+    best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
     for epoch in range(1, 1 + max_epoch):
         # train
         policy.train()
@@ -99,34 +96,32 @@ def offpolicy_trainer(
                 if train_fn:
                     train_fn(epoch, env_step)
                 result = train_collector.collect(n_step=step_per_collect)
-                if len(result["rews"]) > 0 and reward_metric:
+                if result["n/ep"] > 0 and reward_metric:
                     result["rews"] = reward_metric(result["rews"])
                 env_step += int(result["n/st"])
                 t.update(result["n/st"])
+                logger.log_train_data(result, env_step)
+                last_rew = result['rew'] if 'rew' in result else last_rew
+                last_len = result['len'] if 'len' in result else last_len
                 data = {
                     "env_step": str(env_step),
-                    "rew": f"{result['rews'].mean():.2f}",
-                    "len": str(result["lens"].mean()),
+                    "rew": f"{last_rew:.2f}",
+                    "len": str(last_len),
                     "n/ep": str(int(result["n/ep"])),
                     "n/st": str(int(result["n/st"])),
                 }
                 if result["n/ep"] > 0:
-                    if writer and env_step % log_interval == 0:
-                        writer.add_scalar(
-                            "train/rew", result['rews'].mean(), global_step=env_step)
-                        writer.add_scalar(
-                            "train/len", result['lens'].mean(), global_step=env_step)
-                    if test_in_train and stop_fn and stop_fn(result["rews"].mean()):
+                    if test_in_train and stop_fn and stop_fn(result["rew"]):
                         test_result = test_episode(
                             policy, test_collector, test_fn,
-                            epoch, episode_per_test, writer, env_step)
-                        if stop_fn(test_result["rews"].mean()):
+                            epoch, episode_per_test, logger, env_step)
+                        if stop_fn(test_result["rew"]):
                             if save_fn:
                                 save_fn(policy)
                             t.set_postfix(**data)
                             return gather_info(
                                 start_time, train_collector, test_collector,
-                                test_result["rews"].mean(), test_result["rews"].std())
+                                test_result["rew"], test_result["rew_std"])
                         else:
                             policy.train()
                 for i in range(round(update_per_step * result["n/st"])):
@@ -134,26 +129,24 @@ def offpolicy_trainer(
                     losses = policy.update(batch_size, train_collector.buffer)
                     for k in losses.keys():
                         stat[k].add(losses[k])
-                        data[k] = f"{stat[k].get():.6f}"
-                        if writer and gradient_step % log_interval == 0:
-                            writer.add_scalar(
-                                k, stat[k].get(), global_step=gradient_step)
+                        losses[k] = stat[k].get()
+                        data[k] = f"{losses[k]:.6f}"
+                    logger.log_update_data(losses, gradient_step)
                     t.set_postfix(**data)
             if t.n <= t.total:
                 t.update()
         # test
         test_result = test_episode(policy, test_collector, test_fn, epoch,
-                                   episode_per_test, writer, env_step, reward_metric)
-        if best_epoch == -1 or best_reward < test_result["rews"].mean():
-            best_reward = test_result["rews"].mean()
-            best_reward_std = test_result['rews'].std()
+                                   episode_per_test, logger, env_step, reward_metric)
+        rew, rew_std = test_result["rew"], test_result["rew_std"]
+        if best_epoch == -1 or best_reward < rew:
+            best_reward, best_reward_std = rew, rew_std
             best_epoch = epoch
             if save_fn:
                 save_fn(policy)
         if verbose:
-            print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± "
-                  f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± "
-                  f"{best_reward_std:.6f} in #{best_epoch}")
+            print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
+                  f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
         if stop_fn and stop_fn(best_reward):
             break
     return gather_info(start_time, train_collector, test_collector,
diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py
index 43fcc8738..5f5254d66 100644
--- a/tianshou/trainer/onpolicy.py
+++ b/tianshou/trainer/onpolicy.py
@@ -2,12 +2,11 @@
 import tqdm
 import numpy as np
 from collections import defaultdict
-from torch.utils.tensorboard import SummaryWriter
 from typing import Dict, Union, Callable, Optional
 
 from tianshou.data import Collector
 from tianshou.policy import BasePolicy
-from tianshou.utils import tqdm_config, MovAvg
+from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
 from tianshou.trainer import test_episode, gather_info
 
 
@@ -27,8 +26,7 @@ def onpolicy_trainer(
     stop_fn: Optional[Callable[[float], bool]] = None,
     save_fn: Optional[Callable[[BasePolicy], None]] = None,
     reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
-    writer: Optional[SummaryWriter] = None,
-    log_interval: int = 1,
+    logger: BaseLogger = LazyLogger(),
     verbose: bool = True,
     test_in_train: bool = True,
 ) -> Dict[str, Union[float, str]]:
@@ -72,9 +70,8 @@ def onpolicy_trainer(
         result to monitor training in the multi-agent RL setting. This function
         specifies what is the desired metric, e.g., the reward of agent 1 or the
         average reward over all agents.
-    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter;
-        if None is given, it will not write logs to TensorBoard. Default to None.
-    :param int log_interval: the log interval of the writer. Default to 1.
+    :param BaseLogger logger: A logger that logs statistics during
+        training/testing/updating. Default to a logger that doesn't log anything.
     :param bool verbose: whether to print the information. Default to True.
     :param bool test_in_train: whether to test in the training phase. Default to
         True.
@@ -85,16 +82,16 @@ def onpolicy_trainer(
         Only either one of step_per_collect and episode_per_collect can be specified.
     """
     env_step, gradient_step = 0, 0
+    last_rew, last_len = 0.0, 0
     stat: Dict[str, MovAvg] = defaultdict(MovAvg)
     start_time = time.time()
     train_collector.reset_stat()
     test_collector.reset_stat()
     test_in_train = test_in_train and train_collector.policy == policy
     test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
-                               writer, env_step, reward_metric)
+                               logger, env_step, reward_metric)
     best_epoch = 0
-    best_reward = test_result["rews"].mean()
-    best_reward_std = test_result["rews"].std()
+    best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
     for epoch in range(1, 1 + max_epoch):
         # train
         policy.train()
@@ -110,29 +107,27 @@ def onpolicy_trainer(
                     result["rews"] = reward_metric(result["rews"])
                 env_step += int(result["n/st"])
                 t.update(result["n/st"])
+                logger.log_train_data(result, env_step)
+                last_rew = result['rew'] if 'rew' in result else last_rew
+                last_len = result['len'] if 'len' in result else last_len
                 data = {
                     "env_step": str(env_step),
-                    "rew": f"{result['rews'].mean():.2f}",
-                    "len": str(int(result["lens"].mean())),
+                    "rew": f"{last_rew:.2f}",
+                    "len": str(last_len),
                     "n/ep": str(int(result["n/ep"])),
                     "n/st": str(int(result["n/st"])),
                 }
-                if writer and env_step % log_interval == 0:
-                    writer.add_scalar(
-                        "train/rew", result['rews'].mean(), global_step=env_step)
-                    writer.add_scalar(
-                        "train/len", result['lens'].mean(), global_step=env_step)
-                if test_in_train and stop_fn and stop_fn(result["rews"].mean()):
+                if test_in_train and stop_fn and stop_fn(result["rew"]):
                     test_result = test_episode(
                         policy, test_collector, test_fn,
-                        epoch, episode_per_test, writer, env_step)
-                    if stop_fn(test_result["rews"].mean()):
+                        epoch, episode_per_test, logger, env_step)
+                    if stop_fn(test_result["rew"]):
                         if save_fn:
                             save_fn(policy)
                         t.set_postfix(**data)
                         return gather_info(
                             start_time, train_collector, test_collector,
-                            test_result["rews"].mean(), test_result["rews"].std())
+                            test_result["rew"], test_result["rew_std"])
                     else:
                         policy.train()
                 losses = policy.update(
@@ -144,26 +139,24 @@ def onpolicy_trainer(
                 gradient_step += step
                 for k in losses.keys():
                     stat[k].add(losses[k])
-                    data[k] = f"{stat[k].get():.6f}"
-                    if writer and gradient_step % log_interval == 0:
-                        writer.add_scalar(
-                            k, stat[k].get(), global_step=gradient_step)
+                    losses[k] = stat[k].get()
+                    data[k] = f"{losses[k]:.6f}"
+                logger.log_update_data(losses, gradient_step)
                 t.set_postfix(**data)
             if t.n <= t.total:
                 t.update()
         # test
         test_result = test_episode(policy, test_collector, test_fn, epoch,
-                                   episode_per_test, writer, env_step)
-        if best_epoch == -1 or best_reward < test_result["rews"].mean():
-            best_reward = test_result["rews"].mean()
-            best_reward_std = test_result['rews'].std()
+                                   episode_per_test, logger, env_step)
+        rew, rew_std = test_result["rew"], test_result["rew_std"]
+        if best_epoch == -1 or best_reward < rew:
+            best_reward, best_reward_std = rew, rew_std
             best_epoch = epoch
             if save_fn:
                 save_fn(policy)
         if verbose:
-            print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± "
-                  f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± "
-                  f"{best_reward_std:.6f} in #{best_epoch}")
+            print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
+                  f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
         if stop_fn and stop_fn(best_reward):
             break
     return gather_info(start_time, train_collector, test_collector,
diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py
index 72803bef0..2e729feeb 100644
--- a/tianshou/trainer/utils.py
+++ b/tianshou/trainer/utils.py
@@ -1,10 +1,10 @@
 import time
 import numpy as np
-from torch.utils.tensorboard import SummaryWriter
 from typing import Any, Dict, Union, Callable, Optional
 
 from tianshou.data import Collector
 from tianshou.policy import BasePolicy
+from tianshou.utils import BaseLogger
 
 
 def test_episode(
@@ -13,7 +13,7 @@ def test_episode(
     test_fn: Optional[Callable[[int, Optional[int]], None]],
     epoch: int,
     n_episode: int,
-    writer: Optional[SummaryWriter] = None,
+    logger: Optional[BaseLogger] = None,
     global_step: Optional[int] = None,
     reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
 ) -> Dict[str, Any]:
@@ -26,12 +26,8 @@ def test_episode(
     result = collector.collect(n_episode=n_episode)
     if reward_metric:
         result["rews"] = reward_metric(result["rews"])
-    if writer is not None and global_step is not None:
-        rews, lens = result["rews"], result["lens"]
-        writer.add_scalar("test/rew", rews.mean(), global_step=global_step)
-        writer.add_scalar("test/rew_std", rews.std(), global_step=global_step)
-        writer.add_scalar("test/len", lens.mean(), global_step=global_step)
-        writer.add_scalar("test/len_std", lens.std(), global_step=global_step)
+    if logger and global_step is not None:
+        logger.log_test_data(result, global_step)
     return result
 
 
diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py
index d3a371577..b8cfa2315 100644
--- a/tianshou/utils/__init__.py
+++ b/tianshou/utils/__init__.py
@@ -1,9 +1,11 @@
 from tianshou.utils.config import tqdm_config
 from tianshou.utils.moving_average import MovAvg
-from tianshou.utils.log_tools import SummaryWriter
+from tianshou.utils.log_tools import BasicLogger, LazyLogger, BaseLogger
 
 __all__ = [
     "MovAvg",
     "tqdm_config",
-    "SummaryWriter",
+    "BaseLogger",
+    "BasicLogger",
+    "LazyLogger",
 ]
diff --git a/tianshou/utils/log_tools.py b/tianshou/utils/log_tools.py
index bbbd82e67..7605a27ba 100644
--- a/tianshou/utils/log_tools.py
+++ b/tianshou/utils/log_tools.py
@@ -1,47 +1,159 @@
-import threading
-from torch.utils import tensorboard
-from typing import Any, Dict, Optional
-
-
-class SummaryWriter(tensorboard.SummaryWriter):
-    """A more convenient Summary Writer(`tensorboard.SummaryWriter`).
-
-    You can get the same instance of summary writer everywhere after you
-    created one.
-    ::
-
-        >>> writer1 = SummaryWriter.get_instance(
-                key="first", log_dir="log/test_sw/first")
-        >>> writer2 = SummaryWriter.get_instance()
-        >>> writer1 is writer2
-        True
-        >>> writer4 = SummaryWriter.get_instance(
-                key="second", log_dir="log/test_sw/second")
-        >>> writer5 = SummaryWriter.get_instance(key="second")
-        >>> writer1 is not writer4
-        True
-        >>> writer4 is writer5
-        True
+import numpy as np
+from numbers import Number
+from typing import Any, Union
+from abc import ABC, abstractmethod
+from torch.utils.tensorboard import SummaryWriter
+
+
+class BaseLogger(ABC):
+    """The base class for any logger which is compatible with the trainer."""
+
+    def __init__(self, writer: Any) -> None:
+        super().__init__()
+        self.writer = writer
+
+    @abstractmethod
+    def write(
+        self,
+        key: str,
+        x: Union[Number, np.number, np.ndarray],
+        y: Union[Number, np.number, np.ndarray],
+        **kwargs: Any,
+    ) -> None:
+        """Specify how the writer is used to log data.
+
+        :param key: namespace which the input data tuple belongs to.
+        :param x: stands for the abscissa of the input data tuple.
+        :param y: stands for the ordinate of the input data tuple.
+        """
+        pass
+
+    def log_train_data(self, collect_result: dict, step: int) -> None:
+        """Use writer to log statistics generated during training.
+
+        :param collect_result: a dict containing information of data collected in
+            the training stage, i.e., returns of collector.collect().
+        :param int step: the timestep at which the collect_result is logged.
+        """
+        pass
+
+    def log_update_data(self, update_result: dict, step: int) -> None:
+        """Use writer to log statistics generated during updating.
+
+        :param update_result: a dict containing information of data collected in
+            the updating stage, i.e., returns of policy.update().
+        :param int step: the timestep at which the update_result is logged.
+        """
+        pass
+
+    def log_test_data(self, collect_result: dict, step: int) -> None:
+        """Use writer to log statistics generated during evaluation.
+
+        :param collect_result: a dict containing information of data collected in
+            the evaluation stage, i.e., returns of collector.collect().
+        :param int step: the timestep at which the collect_result is logged.
+        """
+        pass
+
+
+class BasicLogger(BaseLogger):
+    """A logger that relies on tensorboard SummaryWriter by default to visualize \
+    and log statistics.
+
+    You can also override the write() function to use your own writer.
+
+    :param SummaryWriter writer: the writer to log data.
+    :param int train_interval: the log interval in log_train_data(). Default to 1.
+    :param int test_interval: the log interval in log_test_data(). Default to 1.
+    :param int update_interval: the log interval in log_update_data(). Default to 1000.
     """
 
-    _mutex_lock = threading.Lock()
-    _default_key: str
-    _instance: Optional[Dict[str, "SummaryWriter"]] = None
+    def __init__(
+        self,
+        writer: SummaryWriter,
+        train_interval: int = 1,
+        test_interval: int = 1,
+        update_interval: int = 1000,
+    ) -> None:
+        super().__init__(writer)
+        self.train_interval = train_interval
+        self.test_interval = test_interval
+        self.update_interval = update_interval
+        self.last_log_train_step = -1
+        self.last_log_test_step = -1
+        self.last_log_update_step = -1
+
+    def write(
+        self,
+        key: str,
+        x: Union[Number, np.number, np.ndarray],
+        y: Union[Number, np.number, np.ndarray],
+        **kwargs: Any,
+    ) -> None:
+        self.writer.add_scalar(key, y, global_step=x)
+
+    def log_train_data(self, collect_result: dict, step: int) -> None:
+        """Use writer to log statistics generated during training.
+
+        :param collect_result: a dict containing information of data collected in
+            the training stage, i.e., returns of collector.collect().
+        :param int step: the timestep at which the collect_result is logged.
+
+        .. note::
+
+            ``collect_result`` will be modified in-place with "rew" and "len" keys.
+        """
+        if collect_result["n/ep"] > 0:
+            collect_result["rew"] = collect_result["rews"].mean()
+            collect_result["len"] = collect_result["lens"].mean()
+            if step - self.last_log_train_step >= self.train_interval:
+                self.write("train/n/ep", step, collect_result["n/ep"])
+                self.write("train/rew", step, collect_result["rew"])
+                self.write("train/len", step, collect_result["len"])
+                self.last_log_train_step = step
+
+    def log_test_data(self, collect_result: dict, step: int) -> None:
+        """Use writer to log statistics generated during evaluation.
+
+        :param collect_result: a dict containing information of data collected in
+            the evaluation stage, i.e., returns of collector.collect().
+        :param int step: the timestep at which the collect_result is logged.
+
+        .. note::
+
+            ``collect_result`` will be modified in-place with "rew", "rew_std", "len",
+            and "len_std" keys.
+        """
+        assert collect_result["n/ep"] > 0
+        rews, lens = collect_result["rews"], collect_result["lens"]
+        rew, rew_std, len_, len_std = rews.mean(), rews.std(), lens.mean(), lens.std()
+        collect_result.update(rew=rew, rew_std=rew_std, len=len_, len_std=len_std)
+        if step - self.last_log_test_step >= self.test_interval:
+            self.write("test/rew", step, rew)
+            self.write("test/len", step, len_)
+            self.write("test/rew_std", step, rew_std)
+            self.write("test/len_std", step, len_std)
+            self.last_log_test_step = step
+
+    def log_update_data(self, update_result: dict, step: int) -> None:
+        if step - self.last_log_update_step >= self.update_interval:
+            for k, v in update_result.items():
+                self.write("train/" + k, step, v)  # save in train/
+            self.last_log_update_step = step
+
+
+class LazyLogger(BasicLogger):
+    """A logger that does nothing. Used as the placeholder in trainer."""
+
+    def __init__(self) -> None:
+        super().__init__(None)  # type: ignore
 
-    @classmethod
-    def get_instance(
-        cls,
-        key: Optional[str] = None,
-        *args: Any,
+    def write(
+        self,
+        key: str,
+        x: Union[Number, np.number, np.ndarray],
+        y: Union[Number, np.number, np.ndarray],
         **kwargs: Any,
-    ) -> "SummaryWriter":
-        """Get instance of torch.utils.tensorboard.SummaryWriter by key."""
-        with SummaryWriter._mutex_lock:
-            if key is None:
-                key = SummaryWriter._default_key
-            if SummaryWriter._instance is None:
-                SummaryWriter._instance = {}
-                SummaryWriter._default_key = key
-            if key not in SummaryWriter._instance.keys():
-                SummaryWriter._instance[key] = SummaryWriter(*args, **kwargs)
-            return SummaryWriter._instance[key]
+    ) -> None:
+        """The LazyLogger writes nothing."""
+        pass
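Since ``BasicLogger`` funnels every scalar through ``write()``, swapping in a different backend only requires overriding that one method. A minimal sketch (not part of the patch; the ``PrintLogger`` name and the stdout backend are purely illustrative)::

    from numbers import Number
    from typing import Any, Union

    import numpy as np

    from tianshou.utils import BasicLogger


    class PrintLogger(BasicLogger):
        """Send scalars to stdout instead of a TensorBoard SummaryWriter."""

        def __init__(self) -> None:
            super().__init__(None)  # type: ignore  # no SummaryWriter needed

        def write(
            self,
            key: str,
            x: Union[Number, np.number, np.ndarray],
            y: Union[Number, np.number, np.ndarray],
            **kwargs: Any,
        ) -> None:
            print(f"step={x} {key}={y}")


    logger = PrintLogger()
    # exercise the interface with a fake collect() result
    logger.log_test_data({"n/ep": 4, "rews": np.ones(4), "lens": np.full(4, 200)}, 0)

The interval bookkeeping (``train_interval``, ``test_interval``, ``update_interval`` and the ``last_log_*`` counters) lives in ``BasicLogger``, so a subclass like this keeps the same rate limiting as the TensorBoard version.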