diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index f71d34cd5..7f8095e1d 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -30,6 +30,34 @@ Customize Training Process See :ref:`customized_trainer`. +.. _resume_training: + +Resume Training Process +----------------------- + +This is related to `Issue 349 `_. + +To resume the training process from an existing checkpoint, you need to do the following things during training: + +1. Make sure you write a ``save_checkpoint_fn`` that saves everything the training process needs, i.e., policy, optim, and buffer, and pass it to the trainer; +2. Use ``BasicLogger``, which wraps a tensorboard ``SummaryWriter``; +3. To adjust the save frequency, specify ``save_interval`` when initializing ``BasicLogger``. + +And to successfully resume from a checkpoint: + +1. Load everything the training process needs **before trainer initialization**, i.e., policy, optim, and buffer; +2. Set ``resume_from_log=True`` when calling the trainer. + +We provide an example to show how these steps work: check out `test_c51.py `_, `test_ppo.py `_ or `test_il_bcq.py `_ by running + +.. code-block:: console + + $ python3 test/discrete/test_c51.py # train for some epochs + $ python3 test/discrete/test_c51.py --resume # restore from the existing log and continue training + + +To correctly render the data (including several tfevent files), we highly recommend using ``tensorboard >= 2.5.0`` (see `here `_ for the reason). Otherwise, you may run into overlapping issues that you need to handle manually. + .. _parallel_sampling: Parallel Sampling diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py index 2c42c46c7..f16db4618 100644 --- a/examples/atari/atari_bcq.py +++ b/examples/atari/atari_bcq.py @@ -85,8 +85,7 @@ def test_discrete_bcq(args=get_args()): feature_net, args.action_shape, device=args.device, hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device) optim = torch.optim.Adam( - set(policy_net.parameters()).union(imitation_net.parameters()), - lr=args.lr) + list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr) # define policy policy = DiscreteBCQPolicy( policy_net, imitation_net, optim, args.gamma, args.n_step, diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 076dfca5c..bf039a187 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -101,7 +101,7 @@ def test_a2c(args=get_args()): torch.nn.init.zeros_(m.bias) m.weight.data.copy_(0.01 * m.weight.data) - optim = torch.optim.RMSprop(set(actor.parameters()).union(critic.parameters()), + optim = torch.optim.RMSprop(list(actor.parameters()) + list(critic.parameters()), lr=args.lr, eps=1e-5, alpha=0.99) lr_scheduler = None diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 3974c2e63..681f626a1 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -106,8 +106,8 @@ def test_ppo(args=get_args()): torch.nn.init.zeros_(m.bias) m.weight.data.copy_(0.01 * m.weight.data) - optim = torch.optim.Adam(set( - actor.parameters()).union(critic.parameters()), lr=args.lr) + optim = torch.optim.Adam( + list(actor.parameters()) + list(critic.parameters()), lr=args.lr) lr_scheduler = None if args.lr_decay: diff --git a/setup.py b/setup.py index 04149ea1e..284ae0f56 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ def get_version() -> str: "gym>=0.15.4", "tqdm", "numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793 - "tensorboard", + "tensorboard>=2.5.0",
"torch>=1.4.0", "numba>=0.51.0", "h5py>=2.10.0", # to match tensorflow's minimal requirements diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 5030abfe5..021cc0419 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -105,6 +105,7 @@ def stop_fn(mean_rewards): update_per_step=args.update_per_step, stop_fn=stop_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) + if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index 9243313dc..cfba19f11 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -80,8 +80,7 @@ def test_npg(args=get_args()): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) - optim = torch.optim.Adam(set( - actor.parameters()).union(critic.parameters()), lr=args.lr) + optim = torch.optim.Adam(critic.parameters(), lr=args.lr) # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index b661bb756..6ab4717d6 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -47,6 +47,8 @@ def get_args(): parser.add_argument('--value-clip', type=int, default=1) parser.add_argument('--norm-adv', type=int, default=1) parser.add_argument('--recompute-adv', type=int, default=0) + parser.add_argument('--resume', action="store_true") + parser.add_argument("--save-interval", type=int, default=4) args = parser.parse_known_args()[0] return args @@ -83,8 +85,8 @@ def test_ppo(args=get_args()): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) - optim = torch.optim.Adam(set( - actor.parameters()).union(critic.parameters()), lr=args.lr) + optim = torch.optim.Adam( + list(actor.parameters()) + list(critic.parameters()), lr=args.lr) # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward @@ -114,7 +116,7 @@ def dist(*logits): # log log_path = os.path.join(args.logdir, args.task, 'ppo') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = BasicLogger(writer, save_interval=args.save_interval) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -122,13 +124,34 @@ def save_fn(policy): def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold + def save_checkpoint_fn(epoch, env_step, gradient_step): + # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html + torch.save({ + 'model': policy.state_dict(), + 'optim': optim.state_dict(), + }, os.path.join(log_path, 'checkpoint.pth')) + + if args.resume: + # load from existing checkpoint + print(f"Loading agent under {log_path}") + ckpt_path = os.path.join(log_path, 'checkpoint.pth') + if os.path.exists(ckpt_path): + checkpoint = torch.load(ckpt_path, map_location=args.device) + policy.load_state_dict(checkpoint['model']) + optim.load_state_dict(checkpoint['optim']) + print("Successfully restore policy and optim.") + else: + print("Fail to restore policy and optim.") + # trainer result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + policy, train_collector, test_collector, args.epoch, args.step_per_epoch, + args.repeat_per_collect, 
args.test_num, args.batch_size, episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, - logger=logger) + logger=logger, resume_from_log=args.resume, + save_checkpoint_fn=save_checkpoint_fn) assert stop_fn(result['best_reward']) + if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! @@ -140,5 +163,10 @@ def stop_fn(mean_rewards): print(f"Final reward: {rews.mean()}, length: {lens.mean()}") +def test_ppo_resume(args=get_args()): + args.resume = True + test_ppo(args) + + if __name__ == '__main__': test_ppo() diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 7ed05cf07..ad0b9af66 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -124,6 +124,7 @@ def stop_fn(mean_rewards): update_per_step=args.update_per_step, stop_fn=stop_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) + if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 2e0674372..ee6fa11de 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -119,6 +119,7 @@ def stop_fn(mean_rewards): update_per_step=args.update_per_step, stop_fn=stop_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) + if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 9db4f449c..8c8387773 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -27,7 +27,8 @@ def get_args(): parser.add_argument('--epoch', type=int, default=5) parser.add_argument('--step-per-epoch', type=int, default=50000) parser.add_argument('--step-per-collect', type=int, default=2048) - parser.add_argument('--repeat-per-collect', type=int, default=1) + parser.add_argument('--repeat-per-collect', type=int, + default=2) # theoretically it should be 1 parser.add_argument('--batch-size', type=int, default=99999) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) parser.add_argument('--training-num', type=int, default=16) @@ -82,8 +83,7 @@ def test_trpo(args=get_args()): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) - optim = torch.optim.Adam(set( - actor.parameters()).union(critic.parameters()), lr=args.lr) + optim = torch.optim.Adam(critic.parameters(), lr=args.lr) # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index ff34f9e8d..9714219e9 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -74,8 +74,8 @@ def test_a2c_with_il(args=get_args()): device=args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device) critic = Critic(net, device=args.device).to(args.device) - optim = torch.optim.Adam(set( - actor.parameters()).union(critic.parameters()), lr=args.lr) + optim = torch.optim.Adam( + list(actor.parameters()) + list(critic.parameters()), lr=args.lr) dist = torch.distributions.Categorical policy = A2CPolicy( actor, critic, optim, dist, @@ -106,6 +106,7 @@ def stop_fn(mean_rewards): episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) + if __name__ == '__main__': pprint.pprint(result) # Let's 
watch its performance! @@ -135,6 +136,7 @@ def stop_fn(mean_rewards): args.il_step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) + if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 1d0c4cc0a..a7fdd922a 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -1,6 +1,7 @@ import os import gym import torch +import pickle import pprint import argparse import numpy as np @@ -43,9 +44,11 @@ def get_args(): action="store_true", default=False) parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--beta', type=float, default=0.4) + parser.add_argument('--resume', action="store_true") parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') + parser.add_argument("--save-interval", type=int, default=4) args = parser.parse_known_args()[0] return args @@ -90,7 +93,7 @@ def test_c51(args=get_args()): # log log_path = os.path.join(args.logdir, args.task, 'c51') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = BasicLogger(writer, save_interval=args.save_interval) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -112,14 +115,42 @@ def train_fn(epoch, env_step): def test_fn(epoch, env_step): policy.set_eps(args.eps_test) + def save_checkpoint_fn(epoch, env_step, gradient_step): + # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html + torch.save({ + 'model': policy.state_dict(), + 'optim': optim.state_dict(), + }, os.path.join(log_path, 'checkpoint.pth')) + pickle.dump(train_collector.buffer, + open(os.path.join(log_path, 'train_buffer.pkl'), "wb")) + + if args.resume: + # load from existing checkpoint + print(f"Loading agent under {log_path}") + ckpt_path = os.path.join(log_path, 'checkpoint.pth') + if os.path.exists(ckpt_path): + checkpoint = torch.load(ckpt_path, map_location=args.device) + policy.load_state_dict(checkpoint['model']) + policy.optim.load_state_dict(checkpoint['optim']) + print("Successfully restore policy and optim.") + else: + print("Fail to restore policy and optim.") + buffer_path = os.path.join(log_path, 'train_buffer.pkl') + if os.path.exists(buffer_path): + train_collector.buffer = pickle.load(open(buffer_path, "rb")) + print("Successfully restore buffer.") + else: + print("Fail to restore buffer.") + # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, - test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger) - + test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger, + resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn) assert stop_fn(result['best_reward']) + if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! 
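The resume logic added to ``test_c51.py`` above can be read as a reusable pattern: build a ``save_checkpoint_fn`` that persists policy, optim, and buffer, and restore those objects before the trainer is constructed. Below is a minimal sketch of that pattern; ``make_save_checkpoint_fn`` and ``try_resume`` are illustrative helper names (not part of Tianshou), and the checkpoint/buffer file names simply follow the test scripts.

.. code-block:: python

    import os
    import pickle

    import torch


    def make_save_checkpoint_fn(log_path, policy, optim, train_collector=None):
        """Build a ``save_checkpoint_fn`` that stores policy, optim and (optionally) the buffer."""
        def save_checkpoint_fn(epoch, env_step, gradient_step):
            # the trainer passes epoch/env_step/gradient_step; here we only need
            # to persist the objects themselves, as in test_c51.py above
            torch.save(
                {"model": policy.state_dict(), "optim": optim.state_dict()},
                os.path.join(log_path, "checkpoint.pth"),
            )
            if train_collector is not None:
                with open(os.path.join(log_path, "train_buffer.pkl"), "wb") as f:
                    pickle.dump(train_collector.buffer, f)
        return save_checkpoint_fn


    def try_resume(log_path, policy, optim, train_collector=None, device="cpu"):
        """Restore policy/optim/buffer; must run before the trainer is constructed."""
        ckpt_path = os.path.join(log_path, "checkpoint.pth")
        if os.path.exists(ckpt_path):
            checkpoint = torch.load(ckpt_path, map_location=device)
            policy.load_state_dict(checkpoint["model"])
            optim.load_state_dict(checkpoint["optim"])
        buffer_path = os.path.join(log_path, "train_buffer.pkl")
        if train_collector is not None and os.path.exists(buffer_path):
            with open(buffer_path, "rb") as f:
                train_collector.buffer = pickle.load(f)

The returned ``save_checkpoint_fn`` is then passed to the trainer together with ``resume_from_log=args.resume``, and ``try_resume`` runs before trainer initialization, exactly as the cheatsheet section above requires.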
@@ -132,6 +163,11 @@ def test_fn(epoch, env_step): print(f"Final reward: {rews.mean()}, length: {lens.mean()}") +def test_c51_resume(args=get_args()): + args.resume = True + test_c51(args) + + def test_pc51(args=get_args()): args.prioritized_replay = True args.gamma = .95 diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index bd88b2b5b..cb7fac403 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -120,7 +120,6 @@ def test_fn(epoch, env_step): args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger) - assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 39bef8dbc..c04cbc396 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -99,8 +99,8 @@ def test_fn(epoch, env_step): args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger) - assert stop_fn(result['best_reward']) + if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py index c3bebb53a..1ea2db433 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/discrete/test_il_bcq.py @@ -42,6 +42,8 @@ def get_args(): "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", ) + parser.add_argument("--resume", action="store_true") + parser.add_argument("--save-interval", type=int, default=4) args = parser.parse_known_args()[0] return args @@ -67,7 +69,7 @@ def test_discrete_bcq(args=get_args()): args.state_shape, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device).to(args.device) optim = torch.optim.Adam( - set(policy_net.parameters()).union(imitation_net.parameters()), + list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr) policy = DiscreteBCQPolicy( @@ -85,7 +87,7 @@ def test_discrete_bcq(args=get_args()): log_path = os.path.join(args.logdir, args.task, 'discrete_bcq') writer = SummaryWriter(log_path) - logger = BasicLogger(writer) + logger = BasicLogger(writer, save_interval=args.save_interval) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -93,11 +95,30 @@ def save_fn(policy): def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold + def save_checkpoint_fn(epoch, env_step, gradient_step): + # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html + torch.save({ + 'model': policy.state_dict(), + 'optim': optim.state_dict(), + }, os.path.join(log_path, 'checkpoint.pth')) + + if args.resume: + # load from existing checkpoint + print(f"Loading agent under {log_path}") + ckpt_path = os.path.join(log_path, 'checkpoint.pth') + if os.path.exists(ckpt_path): + checkpoint = torch.load(ckpt_path, map_location=args.device) + policy.load_state_dict(checkpoint['model']) + optim.load_state_dict(checkpoint['optim']) + print("Successfully restore policy and optim.") + else: + print("Fail to restore policy and optim.") + result = offline_trainer( policy, buffer, test_collector, args.epoch, args.update_per_epoch, args.test_num, args.batch_size, - stop_fn=stop_fn, save_fn=save_fn, logger=logger) - + stop_fn=stop_fn, save_fn=save_fn, logger=logger, + resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn) assert stop_fn(result['best_reward']) 
if __name__ == '__main__': @@ -112,5 +133,10 @@ def stop_fn(mean_rewards): print(f"Final reward: {rews.mean()}, length: {lens.mean()}") +def test_discrete_bcq_resume(args=get_args()): + args.resume = True + test_discrete_bcq(args) + + if __name__ == "__main__": test_discrete_bcq(get_args()) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 8ce73b5cf..193117150 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -93,6 +93,7 @@ def stop_fn(mean_rewards): episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) + if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index af9c3a75d..9fbd7d3b1 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -75,8 +75,8 @@ def test_ppo(args=get_args()): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) - optim = torch.optim.Adam(set( - actor.parameters()).union(critic.parameters()), lr=args.lr) + optim = torch.optim.Adam( + list(actor.parameters()) + list(critic.parameters()), lr=args.lr) dist = torch.distributions.Categorical policy = PPOPolicy( actor, critic, optim, dist, @@ -114,6 +114,7 @@ def stop_fn(mean_rewards): episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) + if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 59e26b197..2847ac2b8 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -117,8 +117,8 @@ def test_fn(epoch, env_step): args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger, update_per_step=args.update_per_step) - assert stop_fn(result['best_reward']) + if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index aaa731547..0cb2ae018 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -112,6 +112,7 @@ def stop_fn(mean_rewards): args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger, update_per_step=args.update_per_step, test_in_train=False) assert stop_fn(result['best_reward']) + if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 802e34963..aad7299ca 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -1,5 +1,6 @@ import time import tqdm +import warnings import numpy as np from collections import defaultdict from typing import Dict, Union, Callable, Optional @@ -21,6 +22,8 @@ def offline_trainer( test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None, + save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + resume_from_log: bool = False, reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, logger: BaseLogger = LazyLogger(), verbose: bool = True, @@ -44,6 +47,12 @@ def offline_trainer( :param function save_fn: a hook called when the undiscounted average mean reward in evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> None``. 
+ :param function save_checkpoint_fn: a function to save training process, with the + signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can + save whatever you want. Because offline-RL doesn't have env_step, the env_step + is always 0 here. + :param bool resume_from_log: resume gradient_step and other metadata from existing + tensorboard log. Default to False. :param function stop_fn: a function with signature ``f(mean_rewards: float) -> bool``, receives the average undiscounted returns of the testing result, returns a boolean which indicates whether reaching the goal. @@ -59,15 +68,22 @@ def offline_trainer( :return: See :func:`~tianshou.trainer.gather_info`. """ - gradient_step = 0 + if save_fn: + warnings.warn("Please consider using save_checkpoint_fn instead of save_fn.") + + start_epoch, gradient_step = 0, 0 + if resume_from_log: + start_epoch, _, gradient_step = logger.restore_data() stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() test_collector.reset_stat() - test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test, - logger, gradient_step, reward_metric) - best_epoch = 0 + + test_result = test_episode(policy, test_collector, test_fn, start_epoch, + episode_per_test, logger, gradient_step, reward_metric) + best_epoch = start_epoch best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] - for epoch in range(1, 1 + max_epoch): + + for epoch in range(1 + start_epoch, 1 + max_epoch): policy.train() with tqdm.trange( update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config @@ -87,15 +103,14 @@ def offline_trainer( policy, test_collector, test_fn, epoch, episode_per_test, logger, gradient_step, reward_metric) rew, rew_std = test_result["rew"], test_result["rew_std"] - if best_epoch == -1 or best_reward < rew: - best_reward, best_reward_std = rew, rew_std - best_epoch = epoch + if best_epoch < 0 or best_reward < rew: + best_epoch, best_reward, best_reward_std = epoch, rew, rew_std if save_fn: save_fn(policy) + logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn) if verbose: - print( - f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:" - f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}") + print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break return gather_info(start_time, None, test_collector, best_reward, best_reward_std) diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 444dbc08a..5df2d6f29 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -1,13 +1,14 @@ import time import tqdm +import warnings import numpy as np from collections import defaultdict from typing import Dict, Union, Callable, Optional from tianshou.data import Collector from tianshou.policy import BasePolicy -from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger from tianshou.trainer import test_episode, gather_info +from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger def offpolicy_trainer( @@ -24,6 +25,8 @@ def offpolicy_trainer( test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None, + save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + resume_from_log: bool = False, reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = 
None, logger: BaseLogger = LazyLogger(), verbose: bool = True, @@ -57,8 +60,13 @@ def offpolicy_trainer( It can be used to perform custom additional operations, with the signature ``f( num_epoch: int, step_idx: int) -> None``. :param function save_fn: a hook called when the undiscounted average mean reward in - evaluation phase gets better, with the signature ``f(policy:BasePolicy) -> + evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> None``. + :param function save_checkpoint_fn: a function to save training process, with the + signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can + save whatever you want. + :param bool resume_from_log: resume env_step/gradient_step and other metadata from + existing tensorboard log. Default to False. :param function stop_fn: a function with signature ``f(mean_rewards: float) -> bool``, receives the average undiscounted returns of the testing result, returns a boolean which indicates whether reaching the goal. @@ -75,18 +83,24 @@ def offpolicy_trainer( :return: See :func:`~tianshou.trainer.gather_info`. """ - env_step, gradient_step = 0, 0 + if save_fn: + warnings.warn("Please consider using save_checkpoint_fn instead of save_fn.") + + start_epoch, env_step, gradient_step = 0, 0, 0 + if resume_from_log: + start_epoch, env_step, gradient_step = logger.restore_data() last_rew, last_len = 0.0, 0 stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy - test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test, - logger, env_step, reward_metric) - best_epoch = 0 + test_result = test_episode(policy, test_collector, test_fn, start_epoch, + episode_per_test, logger, env_step, reward_metric) + best_epoch = start_epoch best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] - for epoch in range(1, 1 + max_epoch): + + for epoch in range(1 + start_epoch, 1 + max_epoch): # train policy.train() with tqdm.tqdm( @@ -118,6 +132,8 @@ def offpolicy_trainer( if stop_fn(test_result["rew"]): if save_fn: save_fn(policy) + logger.save_data( + epoch, env_step, gradient_step, save_checkpoint_fn) t.set_postfix(**data) return gather_info( start_time, train_collector, test_collector, @@ -139,15 +155,14 @@ def offpolicy_trainer( test_result = test_episode(policy, test_collector, test_fn, epoch, episode_per_test, logger, env_step, reward_metric) rew, rew_std = test_result["rew"], test_result["rew_std"] - if best_epoch == -1 or best_reward < rew: - best_reward, best_reward_std = rew, rew_std - best_epoch = epoch + if best_epoch < 0 or best_reward < rew: + best_epoch, best_reward, best_reward_std = epoch, rew, rew_std if save_fn: save_fn(policy) + logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) if verbose: - print( - f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:" - f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}") + print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break return gather_info(start_time, train_collector, test_collector, diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index e396295d0..379aba68e 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -1,13 +1,14 @@ import time import tqdm +import warnings import 
numpy as np from collections import defaultdict from typing import Dict, Union, Callable, Optional from tianshou.data import Collector from tianshou.policy import BasePolicy -from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger from tianshou.trainer import test_episode, gather_info +from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger def onpolicy_trainer( @@ -25,6 +26,8 @@ def onpolicy_trainer( test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None, + save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + resume_from_log: bool = False, reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, logger: BaseLogger = LazyLogger(), verbose: bool = True, @@ -61,6 +64,11 @@ def onpolicy_trainer( :param function save_fn: a hook called when the undiscounted average mean reward in evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> None``. + :param function save_checkpoint_fn: a function to save training process, with the + signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can + save whatever you want. + :param bool resume_from_log: resume env_step/gradient_step and other metadata from + existing tensorboard log. Default to False. :param function stop_fn: a function with signature ``f(mean_rewards: float) -> bool``, receives the average undiscounted returns of the testing result, returns a boolean which indicates whether reaching the goal. @@ -81,18 +89,24 @@ def onpolicy_trainer( Only either one of step_per_collect and episode_per_collect can be specified. """ - env_step, gradient_step = 0, 0 + if save_fn: + warnings.warn("Please consider using save_checkpoint_fn instead of save_fn.") + + start_epoch, env_step, gradient_step = 0, 0, 0 + if resume_from_log: + start_epoch, env_step, gradient_step = logger.restore_data() last_rew, last_len = 0.0, 0 stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy - test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test, - logger, env_step, reward_metric) - best_epoch = 0 + test_result = test_episode(policy, test_collector, test_fn, start_epoch, + episode_per_test, logger, env_step, reward_metric) + best_epoch = start_epoch best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] - for epoch in range(1, 1 + max_epoch): + + for epoch in range(1 + start_epoch, 1 + max_epoch): # train policy.train() with tqdm.tqdm( @@ -125,6 +139,8 @@ def onpolicy_trainer( if stop_fn(test_result["rew"]): if save_fn: save_fn(policy) + logger.save_data( + epoch, env_step, gradient_step, save_checkpoint_fn) t.set_postfix(**data) return gather_info( start_time, train_collector, test_collector, @@ -150,15 +166,14 @@ def onpolicy_trainer( test_result = test_episode(policy, test_collector, test_fn, epoch, episode_per_test, logger, env_step, reward_metric) rew, rew_std = test_result["rew"], test_result["rew_std"] - if best_epoch == -1 or best_reward < rew: - best_reward, best_reward_std = rew, rew_std - best_epoch = epoch + if best_epoch < 0 or best_reward < rew: + best_epoch, best_reward, best_reward_std = epoch, rew, rew_std if save_fn: save_fn(policy) + logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) if verbose: - print( - f"Epoch #{epoch}: test_reward: {rew:.6f} 
± {rew_std:.6f}, best_reward:" - f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}") + print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break return gather_info(start_time, train_collector, test_collector, diff --git a/tianshou/utils/log_tools.py b/tianshou/utils/log_tools.py index fcd1d5575..2a8d3a251 100644 --- a/tianshou/utils/log_tools.py +++ b/tianshou/utils/log_tools.py @@ -1,8 +1,12 @@ import numpy as np from numbers import Number -from typing import Any, Union from abc import ABC, abstractmethod from torch.utils.tensorboard import SummaryWriter +from typing import Any, Tuple, Union, Callable, Optional +from tensorboard.backend.event_processing import event_accumulator + + +WRITE_TYPE = Union[int, Number, np.number, np.ndarray] class BaseLogger(ABC): @@ -13,9 +17,7 @@ def __init__(self, writer: Any) -> None: self.writer = writer @abstractmethod - def write( - self, key: str, x: int, y: Union[Number, np.number, np.ndarray], **kwargs: Any - ) -> None: + def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None: """Specify how the writer is used to log data. :param str key: namespace which the input data tuple belongs to. @@ -51,6 +53,33 @@ def log_test_data(self, collect_result: dict, step: int) -> None: """ pass + def save_data( + self, + epoch: int, + env_step: int, + gradient_step: int, + save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + ) -> None: + """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer. + + :param int epoch: the epoch in trainer. + :param int env_step: the env_step in trainer. + :param int gradient_step: the gradient_step in trainer. + :param function save_checkpoint_fn: a hook defined by user, see trainer + documentation for detail. + """ + pass + + def restore_data(self) -> Tuple[int, int, int]: + """Return the metadata from existing log. + + If it finds nothing or an error occurs during the recover process, it will + return the default parameters. + + :return: epoch, env_step, gradient_step. + """ + pass + class BasicLogger(BaseLogger): """A loggger that relies on tensorboard SummaryWriter by default to visualize \ @@ -62,6 +91,8 @@ class BasicLogger(BaseLogger): :param int train_interval: the log interval in log_train_data(). Default to 1. :param int test_interval: the log interval in log_test_data(). Default to 1. :param int update_interval: the log interval in log_update_data(). Default to 1000. + :param int save_interval: the save interval in save_data(). Default to 1 (save at + the end of each epoch). 
""" def __init__( @@ -70,18 +101,19 @@ def __init__( train_interval: int = 1, test_interval: int = 1, update_interval: int = 1000, + save_interval: int = 1, ) -> None: super().__init__(writer) self.train_interval = train_interval self.test_interval = test_interval self.update_interval = update_interval + self.save_interval = save_interval self.last_log_train_step = -1 self.last_log_test_step = -1 self.last_log_update_step = -1 + self.last_save_step = -1 - def write( - self, key: str, x: int, y: Union[Number, np.number, np.ndarray], **kwargs: Any - ) -> None: + def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None: self.writer.add_scalar(key, y, global_step=x) def log_train_data(self, collect_result: dict, step: int) -> None: @@ -133,6 +165,39 @@ def log_update_data(self, update_result: dict, step: int) -> None: self.write(k, step, v) self.last_log_update_step = step + def save_data( + self, + epoch: int, + env_step: int, + gradient_step: int, + save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + ) -> None: + if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval: + self.last_save_step = epoch + save_checkpoint_fn(epoch, env_step, gradient_step) + self.write("save/epoch", epoch, epoch) + self.write("save/env_step", env_step, env_step) + self.write("save/gradient_step", gradient_step, gradient_step) + + def restore_data(self) -> Tuple[int, int, int]: + ea = event_accumulator.EventAccumulator(self.writer.log_dir) + ea.Reload() + + try: # epoch / gradient_step + epoch = ea.scalars.Items("save/epoch")[-1].step + self.last_save_step = self.last_log_test_step = epoch + gradient_step = ea.scalars.Items("save/gradient_step")[-1].step + self.last_log_update_step = gradient_step + except KeyError: + epoch, gradient_step = 0, 0 + try: # offline trainer doesn't have env_step + env_step = ea.scalars.Items("save/env_step")[-1].step + self.last_log_train_step = env_step + except KeyError: + env_step = 0 + + return epoch, env_step, gradient_step + class LazyLogger(BasicLogger): """A loggger that does nothing. Used as the placeholder in trainer.""" @@ -140,8 +205,6 @@ class LazyLogger(BasicLogger): def __init__(self) -> None: super().__init__(None) # type: ignore - def write( - self, key: str, x: int, y: Union[Number, np.number, np.ndarray], **kwargs: Any - ) -> None: + def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None: """The LazyLogger writes nothing.""" pass