diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index 08aac4451..f02b34074 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -48,7 +48,7 @@ And to successfully resume from a checkpoint: 1. Load everything needed in the training process **before trainer initialization**, i.e., policy, optim, buffer; 2. Set ``resume_from_log=True`` with trainer; -We provide an example to show how these steps work: checkout `test_c51.py `_, `test_ppo.py `_ or `test_il_bcq.py `_ by running +We provide an example to show how these steps work: checkout `test_c51.py `_, `test_ppo.py `_ or `test_discrete_bcq.py `_ by running .. code-block:: console diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 7935d9b4b..a426ae1b5 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -192,7 +192,7 @@ def test_fn(epoch, env_step): def save_checkpoint_fn(epoch, env_step, gradient_step): # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html - ckpt_path = os.path.join(log_path, "checkpoint.pth") + ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save({"model": policy.state_dict()}, ckpt_path) return ckpt_path diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 22eff1034..8ef69af15 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -222,7 +222,7 @@ def stop_fn(mean_rewards): def save_checkpoint_fn(epoch, env_step, gradient_step): # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html - ckpt_path = os.path.join(log_path, "checkpoint.pth") + ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save({"model": policy.state_dict()}, ckpt_path) return ckpt_path diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index e0858edf1..80cf6e330 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -117,7 +117,7 @@ def dist(*logits): dual_clip=args.dual_clip, value_clip=args.value_clip, gae_lambda=args.gae_lambda, - action_space=env.action_space + action_space=env.action_space, ) # collector train_collector = Collector( @@ -125,33 +125,37 @@ def dist(*logits): ) test_collector = Collector(policy, test_envs) # log - log_path = os.path.join(args.logdir, args.task, 'ppo') + log_path = os.path.join(args.logdir, args.task, "ppo") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) def save_best_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards): return mean_rewards >= args.reward_threshold def save_checkpoint_fn(epoch, env_step, gradient_step): # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html + ckpt_path = os.path.join(log_path, "checkpoint.pth") + # Example: saving by epoch num + # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( { - 'model': policy.state_dict(), - 'optim': optim.state_dict(), - }, os.path.join(log_path, 'checkpoint.pth') + "model": policy.state_dict(), + "optim": optim.state_dict(), + }, ckpt_path ) + return ckpt_path if args.resume: # load from existing checkpoint print(f"Loading agent under {log_path}") - ckpt_path = os.path.join(log_path, 'checkpoint.pth') + ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) - 
policy.load_state_dict(checkpoint['model']) - optim.load_state_dict(checkpoint['optim']) + policy.load_state_dict(checkpoint["model"]) + optim.load_state_dict(checkpoint["optim"]) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") @@ -171,7 +175,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): save_best_fn=save_best_fn, logger=logger, resume_from_log=args.resume, - save_checkpoint_fn=save_checkpoint_fn + save_checkpoint_fn=save_checkpoint_fn, ) for epoch, epoch_stat, info in trainer: @@ -181,7 +185,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): assert stop_fn(info["best_reward"]) - if __name__ == '__main__': + if __name__ == "__main__": pprint.pprint(info) # Let's watch its performance! env = gym.make(args.task) @@ -197,5 +201,5 @@ def test_ppo_resume(args=get_args()): test_ppo(args) -if __name__ == '__main__': +if __name__ == "__main__": test_ppo() diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 993c4a80e..3912b0740 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -85,7 +85,7 @@ def test_c51(args=get_args()): hidden_sizes=args.hidden_sizes, device=args.device, softmax=True, - num_atoms=args.num_atoms + num_atoms=args.num_atoms, ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = C51Policy( @@ -96,7 +96,7 @@ def test_c51(args=get_args()): args.v_min, args.v_max, args.n_step, - target_update_freq=args.target_update_freq + target_update_freq=args.target_update_freq, ).to(args.device) # buffer if args.prioritized_replay: @@ -104,7 +104,7 @@ def test_c51(args=get_args()): args.buffer_size, buffer_num=len(train_envs), alpha=args.alpha, - beta=args.beta + beta=args.beta, ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) @@ -114,12 +114,12 @@ def test_c51(args=get_args()): # policy.set_eps(1) train_collector.collect(n_step=args.batch_size * args.training_num) # log - log_path = os.path.join(args.logdir, args.task, 'c51') + log_path = os.path.join(args.logdir, args.task, "c51") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) def save_best_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards): return mean_rewards >= args.reward_threshold @@ -140,29 +140,31 @@ def test_fn(epoch, env_step): def save_checkpoint_fn(epoch, env_step, gradient_step): # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html + ckpt_path = os.path.join(log_path, "checkpoint.pth") + # Example: saving by epoch num + # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( { - 'model': policy.state_dict(), - 'optim': optim.state_dict(), - }, os.path.join(log_path, 'checkpoint.pth') - ) - pickle.dump( - train_collector.buffer, - open(os.path.join(log_path, 'train_buffer.pkl'), "wb") + "model": policy.state_dict(), + "optim": optim.state_dict(), + }, ckpt_path ) + buffer_path = os.path.join(log_path, "train_buffer.pkl") + pickle.dump(train_collector.buffer, open(buffer_path, "wb")) + return ckpt_path if args.resume: # load from existing checkpoint print(f"Loading agent under {log_path}") - ckpt_path = os.path.join(log_path, 'checkpoint.pth') + ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) - policy.load_state_dict(checkpoint['model']) - 
policy.optim.load_state_dict(checkpoint['optim']) + policy.load_state_dict(checkpoint["model"]) + policy.optim.load_state_dict(checkpoint["optim"]) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") - buffer_path = os.path.join(log_path, 'train_buffer.pkl') + buffer_path = os.path.join(log_path, "train_buffer.pkl") if os.path.exists(buffer_path): train_collector.buffer = pickle.load(open(buffer_path, "rb")) print("Successfully restore buffer.") @@ -186,11 +188,11 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): save_best_fn=save_best_fn, logger=logger, resume_from_log=args.resume, - save_checkpoint_fn=save_checkpoint_fn + save_checkpoint_fn=save_checkpoint_fn, ) - assert stop_fn(result['best_reward']) + assert stop_fn(result["best_reward"]) - if __name__ == '__main__': + if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) @@ -214,5 +216,5 @@ def test_pc51(args=get_args()): test_c51(args) -if __name__ == '__main__': +if __name__ == "__main__": test_c51(get_args()) diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 5e2345300..78b43dff9 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -98,7 +98,7 @@ def noisy_linear(x, y): "linear_layer": noisy_linear }, { "linear_layer": noisy_linear - }) + }), ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = RainbowPolicy( @@ -109,7 +109,7 @@ def noisy_linear(x, y): args.v_min, args.v_max, args.n_step, - target_update_freq=args.target_update_freq + target_update_freq=args.target_update_freq, ).to(args.device) # buffer if args.prioritized_replay: @@ -118,7 +118,7 @@ def noisy_linear(x, y): buffer_num=len(train_envs), alpha=args.alpha, beta=args.beta, - weight_norm=True + weight_norm=True, ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) @@ -128,12 +128,12 @@ def noisy_linear(x, y): # policy.set_eps(1) train_collector.collect(n_step=args.batch_size * args.training_num) # log - log_path = os.path.join(args.logdir, args.task, 'rainbow') + log_path = os.path.join(args.logdir, args.task, "rainbow") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) def save_best_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards): return mean_rewards >= args.reward_threshold @@ -164,21 +164,23 @@ def test_fn(epoch, env_step): def save_checkpoint_fn(epoch, env_step, gradient_step): # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html + ckpt_path = os.path.join(log_path, "checkpoint.pth") + # Example: saving by epoch num + # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( { - 'model': policy.state_dict(), - 'optim': optim.state_dict(), - }, os.path.join(log_path, 'checkpoint.pth') - ) - pickle.dump( - train_collector.buffer, - open(os.path.join(log_path, 'train_buffer.pkl'), "wb") + "model": policy.state_dict(), + "optim": optim.state_dict(), + }, ckpt_path ) + buffer_path = os.path.join(log_path, "train_buffer.pkl") + pickle.dump(train_collector.buffer, open(buffer_path, "wb")) + return ckpt_path if args.resume: # load from existing checkpoint print(f"Loading agent under {log_path}") - ckpt_path = os.path.join(log_path, 'checkpoint.pth') + ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = 
torch.load(ckpt_path, map_location=args.device) policy.load_state_dict(checkpoint['model']) @@ -186,7 +188,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") - buffer_path = os.path.join(log_path, 'train_buffer.pkl') + buffer_path = os.path.join(log_path, "train_buffer.pkl") if os.path.exists(buffer_path): train_collector.buffer = pickle.load(open(buffer_path, "rb")) print("Successfully restore buffer.") @@ -210,11 +212,11 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): save_best_fn=save_best_fn, logger=logger, resume_from_log=args.resume, - save_checkpoint_fn=save_checkpoint_fn + save_checkpoint_fn=save_checkpoint_fn, ) - assert stop_fn(result['best_reward']) + assert stop_fn(result["best_reward"]) - if __name__ == '__main__': + if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) @@ -238,5 +240,5 @@ def test_prainbow(args=get_args()): test_rainbow(args) -if __name__ == '__main__': +if __name__ == "__main__": test_rainbow(get_args()) diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 51380fb62..66e42f3e9 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -25,7 +25,7 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="CartPole-v0") - parser.add_argument('--reward-threshold', type=float, default=None) + parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--eps-test", type=float, default=0.001) parser.add_argument("--lr", type=float, default=3e-4) @@ -37,7 +37,7 @@ def get_args(): parser.add_argument("--epoch", type=int, default=5) parser.add_argument("--update-per-epoch", type=int, default=2000) parser.add_argument("--batch-size", type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) + parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) parser.add_argument("--test-num", type=int, default=100) parser.add_argument("--logdir", type=str, default="log") parser.add_argument("--render", type=float, default=0.) 
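The hunks above and below all move ``save_checkpoint_fn`` to the same pattern: build the checkpoint path first, save whatever state is needed, then return the path so the trainer's logger can record it. A minimal sketch of that hook, assuming ``log_path``, ``policy``, ``optim`` and ``train_collector`` are already defined in the surrounding training script (as they are in the tests touched by this patch):

.. code-block:: python

    import os
    import pickle

    import torch

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        # build the path first so it can be returned to the caller
        ckpt_path = os.path.join(log_path, "checkpoint.pth")
        # or keep one file per epoch:
        # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
        torch.save(
            {
                "model": policy.state_dict(),
                "optim": optim.state_dict(),
            }, ckpt_path
        )
        # optionally persist the replay buffer as well, as the off-policy tests do
        buffer_path = os.path.join(log_path, "train_buffer.pkl")
        with open(buffer_path, "wb") as f:
            pickle.dump(train_collector.buffer, f)
        return ckpt_path

Unlike the patch, this sketch closes the buffer file via a ``with`` block; the saved content is the same.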
@@ -104,33 +104,37 @@ def test_discrete_bcq(args=get_args()): # collector test_collector = Collector(policy, test_envs, exploration_noise=True) - log_path = os.path.join(args.logdir, args.task, 'discrete_bcq') + log_path = os.path.join(args.logdir, args.task, "discrete_bcq") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) def save_best_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards): return mean_rewards >= args.reward_threshold def save_checkpoint_fn(epoch, env_step, gradient_step): # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html + ckpt_path = os.path.join(log_path, "checkpoint.pth") + # Example: saving by epoch num + # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( { - 'model': policy.state_dict(), - 'optim': optim.state_dict(), - }, os.path.join(log_path, 'checkpoint.pth') + "model": policy.state_dict(), + "optim": optim.state_dict(), + }, ckpt_path ) + return ckpt_path if args.resume: # load from existing checkpoint print(f"Loading agent under {log_path}") - ckpt_path = os.path.join(log_path, 'checkpoint.pth') + ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) - policy.load_state_dict(checkpoint['model']) - optim.load_state_dict(checkpoint['optim']) + policy.load_state_dict(checkpoint["model"]) + optim.load_state_dict(checkpoint["optim"]) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") @@ -147,11 +151,11 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): save_best_fn=save_best_fn, logger=logger, resume_from_log=args.resume, - save_checkpoint_fn=save_checkpoint_fn + save_checkpoint_fn=save_checkpoint_fn, ) - assert stop_fn(result['best_reward']) + assert stop_fn(result["best_reward"]) - if __name__ == '__main__': + if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! 
env = gym.make(args.task) diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 3ccaeaca7..aa123ec87 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -163,33 +163,37 @@ def dist(*logits): ) test_collector = Collector(policy, test_envs) # log - log_path = os.path.join(args.logdir, args.task, 'gail') + log_path = os.path.join(args.logdir, args.task, "gail") writer = SummaryWriter(log_path) logger = TensorboardLogger(writer, save_interval=args.save_interval) def save_best_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards): return mean_rewards >= args.reward_threshold def save_checkpoint_fn(epoch, env_step, gradient_step): # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html + ckpt_path = os.path.join(log_path, "checkpoint.pth") + # Example: saving by epoch num + # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth") torch.save( { - 'model': policy.state_dict(), - 'optim': optim.state_dict(), - }, os.path.join(log_path, 'checkpoint.pth') + "model": policy.state_dict(), + "optim": optim.state_dict(), + }, ckpt_path ) + return ckpt_path if args.resume: # load from existing checkpoint print(f"Loading agent under {log_path}") - ckpt_path = os.path.join(log_path, 'checkpoint.pth') + ckpt_path = os.path.join(log_path, "checkpoint.pth") if os.path.exists(ckpt_path): checkpoint = torch.load(ckpt_path, map_location=args.device) - policy.load_state_dict(checkpoint['model']) - optim.load_state_dict(checkpoint['optim']) + policy.load_state_dict(checkpoint["model"]) + optim.load_state_dict(checkpoint["optim"]) print("Successfully restore policy and optim.") else: print("Fail to restore policy and optim.") @@ -211,9 +215,9 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn, ) - assert stop_fn(result['best_reward']) + assert stop_fn(result["best_reward"]) - if __name__ == '__main__': + if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! env = gym.make(args.task) @@ -224,5 +228,5 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): print(f"Final reward: {rews.mean()}, length: {lens.mean()}") -if __name__ == '__main__': +if __name__ == "__main__": test_gail() diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 25040562c..b4dcfb6be 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -58,9 +58,9 @@ class BaseTrainer(ABC): :param function save_best_fn: a hook called when the undiscounted average mean reward in evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously. - :param function save_checkpoint_fn: a function to save training process, with - the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; - you can save whatever you want. + :param function save_checkpoint_fn: a function to save training process and + return the saved checkpoint path, with the signature ``f(epoch: int, + env_step: int, gradient_step: int) -> str``; you can save whatever you want. :param bool resume_from_log: resume env_step/gradient_step and other metadata from existing tensorboard log. Default to False. 
:param function stop_fn: a function with signature ``f(mean_rewards: float) -> @@ -147,7 +147,7 @@ def __init__( test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_best_fn: Optional[Callable[[BasePolicy], None]] = None, - save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None, resume_from_log: bool = False, reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, logger: BaseLogger = LazyLogger(), @@ -259,7 +259,7 @@ def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]: if self.iter_num > 1: # iterator exhaustion check - if self.epoch >= self.max_epoch: + if self.epoch > self.max_epoch: raise StopIteration # exit flag 1, when stop_fn succeeds in train_step or test_step diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index e60b909b1..33028ca99 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -31,10 +31,10 @@ class OfflineTrainer(BaseTrainer): :param function save_best_fn: a hook called when the undiscounted average mean reward in evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously. - :param function save_checkpoint_fn: a function to save training process, - with the signature ``f(epoch: int, env_step: int, gradient_step: int) -> - None``; you can save whatever you want. Because offline-RL doesn't have - env_step, the env_step is always 0 here. + :param function save_checkpoint_fn: a function to save training process and + return the saved checkpoint path, with the signature ``f(epoch: int, + env_step: int, gradient_step: int) -> str``; you can save whatever you want. + Because offline-RL doesn't have env_step, the env_step is always 0 here. :param bool resume_from_log: resume gradient_step and other metadata from existing tensorboard log. Default to False. :param function stop_fn: a function with signature ``f(mean_rewards: float) -> @@ -67,7 +67,7 @@ def __init__( test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_best_fn: Optional[Callable[[BasePolicy], None]] = None, - save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None, resume_from_log: bool = False, reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, logger: BaseLogger = LazyLogger(), diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index e6a42ae9a..628239929 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -40,9 +40,9 @@ class OffpolicyTrainer(BaseTrainer): :param function save_best_fn: a hook called when the undiscounted average mean reward in evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously. - :param function save_checkpoint_fn: a function to save training process, with - the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; - you can save whatever you want. + :param function save_checkpoint_fn: a function to save training process and + return the saved checkpoint path, with the signature ``f(epoch: int, + env_step: int, gradient_step: int) -> str``; you can save whatever you want. :param bool resume_from_log: resume env_step/gradient_step and other metadata from existing tensorboard log. Default to False. 
:param function stop_fn: a function with signature ``f(mean_rewards: float) -> @@ -80,7 +80,7 @@ def __init__( test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_best_fn: Optional[Callable[[BasePolicy], None]] = None, - save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None, resume_from_log: bool = False, reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, logger: BaseLogger = LazyLogger(), diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 0739cd99a..641bda3bf 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -42,9 +42,9 @@ class OnpolicyTrainer(BaseTrainer): :param function save_best_fn: a hook called when the undiscounted average mean reward in evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously. - :param function save_checkpoint_fn: a function to save training process, - with the signature ``f(epoch: int, env_step: int, gradient_step: int) - -> None``; you can save whatever you want. + :param function save_checkpoint_fn: a function to save training process and + return the saved checkpoint path, with the signature ``f(epoch: int, + env_step: int, gradient_step: int) -> str``; you can save whatever you want. :param bool resume_from_log: resume env_step/gradient_step and other metadata from existing tensorboard log. Default to False. :param function stop_fn: a function with signature ``f(mean_rewards: float) -> @@ -88,7 +88,7 @@ def __init__( test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_best_fn: Optional[Callable[[BasePolicy], None]] = None, - save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None, resume_from_log: bool = False, reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, logger: BaseLogger = LazyLogger(), diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py index e6d15a14c..0eca53c0c 100644 --- a/tianshou/utils/logger/base.py +++ b/tianshou/utils/logger/base.py @@ -94,7 +94,7 @@ def save_data( epoch: int, env_step: int, gradient_step: int, - save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None, ) -> None: """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer. 
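On the resume side, the tests above pair the hook with ``resume_from_log=True``, which lets the logger restore the epoch/env_step/gradient_step metadata, while the model, optimizer and (optionally) replay buffer are reloaded manually before the trainer is built. A sketch of that loading block, again assuming ``log_path``, ``policy``, ``optim``, ``train_collector`` and ``args.device`` come from the surrounding script:

.. code-block:: python

    import os
    import pickle

    import torch

    ckpt_path = os.path.join(log_path, "checkpoint.pth")
    if os.path.exists(ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location=args.device)
        policy.load_state_dict(checkpoint["model"])
        optim.load_state_dict(checkpoint["optim"])
    buffer_path = os.path.join(log_path, "train_buffer.pkl")
    if os.path.exists(buffer_path):
        with open(buffer_path, "rb") as f:
            train_collector.buffer = pickle.load(f)
    # the trainer is then constructed with resume_from_log=True and
    # save_checkpoint_fn=save_checkpoint_fn, as in the tests above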
diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py index 843ff012e..2583673e6 100644 --- a/tianshou/utils/logger/tensorboard.py +++ b/tianshou/utils/logger/tensorboard.py @@ -47,7 +47,7 @@ def save_data( epoch: int, env_step: int, gradient_step: int, - save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None, ) -> None: if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval: self.last_save_step = epoch diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py index e63a7bc7f..c8e78d5d0 100644 --- a/tianshou/utils/logger/wandb.py +++ b/tianshou/utils/logger/wandb.py @@ -97,7 +97,7 @@ def save_data( epoch: int, env_step: int, gradient_step: int, - save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None, ) -> None: """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer. @@ -118,7 +118,7 @@ def save_data( "save/epoch": epoch, "save/env_step": env_step, "save/gradient_step": gradient_step, - "checkpoint_path": str(checkpoint_path) + "checkpoint_path": str(checkpoint_path), } ) checkpoint_artifact.add_file(str(checkpoint_path)) @@ -126,7 +126,7 @@ def save_data( def restore_data(self) -> Tuple[int, int, int]: checkpoint_artifact = self.wandb_run.use_artifact( # type: ignore - 'run_' + self.wandb_run.id + '_checkpoint:latest' # type: ignore + f"run_{self.wandb_run.id}_checkpoint:latest" # type: ignore ) assert checkpoint_artifact is not None, "W&B dataset artifact doesn't exist"
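The ``Callable[[int, int, int], str]`` change is what lets the logging backends act on the checkpoint: ``save_data`` calls the hook at most once per ``save_interval`` epochs and, in the W&B logger, attaches the returned file to a checkpoint artifact. A simplified, framework-free sketch of that contract (the class and names here are illustrative, not Tianshou API):

.. code-block:: python

    from typing import Callable, Optional

    class CheckpointRecorder:
        """Illustrative stand-in for a logger's checkpoint bookkeeping."""

        def __init__(self, save_interval: int = 1) -> None:
            self.save_interval = save_interval
            self.last_save_step = -1

        def save_data(
            self,
            epoch: int,
            env_step: int,
            gradient_step: int,
            save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
        ) -> None:
            # checkpoint at most once per `save_interval` epochs, as the real loggers do
            if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
                self.last_save_step = epoch
                checkpoint_path = save_checkpoint_fn(epoch, env_step, gradient_step)
                # the returned path is what e.g. the W&B logger uploads as an artifact
                print(
                    f"saved epoch={epoch}, env_step={env_step}, "
                    f"gradient_step={gradient_step} -> {checkpoint_path}"
                )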