From 8bad065425133ba79f623cb5f5f082d745268a55 Mon Sep 17 00:00:00 2001 From: R107333 Date: Sat, 5 Mar 2022 10:09:53 +0100 Subject: [PATCH 01/40] add docstring :param buffer to offline_trainer in offline.py --- tianshou/trainer/offline.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index d2f85bc2a..07fb3dc47 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -33,6 +33,8 @@ def offline_trainer( The "step" in offline trainer means a gradient step. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. + This buffer must be populated with experiences for offline RL. :param Collector test_collector: the collector used for testing. If it's None, then no testing will be performed. :param int max_epoch: the maximum number of epochs for training. The training From ff9c0c9343230eb0aff1e82fa42a40cc454eb7f4 Mon Sep 17 00:00:00 2001 From: R107333 Date: Sat, 5 Mar 2022 13:43:02 +0100 Subject: [PATCH 02/40] Add param yield_epoch to trainers. if True, converts the function into a generator that yields a 3-tuple (epoch, stats, info) of train results on every epoch. --- tianshou/trainer/offline.py | 35 ++++++++++++++++++++++++------ tianshou/trainer/offpolicy.py | 40 +++++++++++++++++++++++++++++------ tianshou/trainer/onpolicy.py | 40 +++++++++++++++++++++++++++++------ 3 files changed, 97 insertions(+), 18 deletions(-) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 07fb3dc47..87fb69ee3 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -27,6 +27,7 @@ def offline_trainer( reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, logger: BaseLogger = LazyLogger(), verbose: bool = True, + yield_epoch: bool = False, ) -> Dict[str, Union[float, str]]: """A wrapper for offline trainer procedure. @@ -68,6 +69,8 @@ def offline_trainer( :param BaseLogger logger: A logger that logs statistics during updating/testing. Default to a logger that doesn't log anything. :param bool verbose: whether to print the information. Default to True. + :param bool yield_epoch: if True, converts the function into a generator that yields + a 3-tuple (epoch, stats, info) of train results on every epoch :return: See :func:`~tianshou.trainer.gather_info`. 
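+
+    A minimal sketch of the ``yield_epoch=True`` usage (illustrative only; it
+    assumes ``policy``, a pre-filled ``buffer``, ``test_collector`` and the
+    usual hyper-parameters are already set up)::
+
+        trainer = offline_trainer(
+            policy, buffer, test_collector, max_epoch, update_per_epoch,
+            episode_per_test, batch_size, yield_epoch=True,
+        )
+        for epoch, epoch_stat, info in trainer:
+            print(f"Epoch #{epoch}: {epoch_stat}")
+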
""" @@ -91,6 +94,7 @@ def offline_trainer( for epoch in range(1 + start_epoch, 1 + max_epoch): policy.train() + with tqdm.trange(update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config) as t: for _ in t: gradient_step += 1 @@ -102,7 +106,10 @@ def offline_trainer( data[k] = f"{losses[k]:.3f}" logger.log_update_data(losses, gradient_step) t.set_postfix(**data) + logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn) + # epoch_stat for yield clause + epoch_stat = {**stat, "gradient_step": gradient_step} # test if test_collector is not None: test_result = test_episode( @@ -119,15 +126,31 @@ def offline_trainer( f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" ) + epoch_stat.update({"test_reward": rew, + "test_reward_std": rew_std, + "best_reward": best_reward, + "best_reward_std": best_reward_std, + "best_epoch": best_epoch + }) if stop_fn and stop_fn(best_reward): break + if yield_epoch: + if test_collector is None: + info = gather_info(start_time, None, None, 0.0, 0.0) + else: + info = gather_info( + start_time, None, test_collector, best_reward, best_reward_std + ) + yield epoch, epoch_stat, info + if test_collector is None and save_fn: save_fn(policy) - if test_collector is None: - return gather_info(start_time, None, None, 0.0, 0.0) - else: - return gather_info( - start_time, None, test_collector, best_reward, best_reward_std - ) + if not yield_epoch: + if test_collector is None: + return gather_info(start_time, None, None, 0.0, 0.0) + else: + return gather_info( + start_time, None, test_collector, best_reward, best_reward_std + ) diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 9b8727b24..1e92ef644 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -31,6 +31,7 @@ def offpolicy_trainer( logger: BaseLogger = LazyLogger(), verbose: bool = True, test_in_train: bool = True, + yield_epoch: bool = False, ) -> Dict[str, Union[float, str]]: """A wrapper for off-policy trainer procedure. @@ -81,6 +82,8 @@ def offpolicy_trainer( training/testing/updating. Default to a logger that doesn't log anything. :param bool verbose: whether to print the information. Default to True. :param bool test_in_train: whether to test in the training phase. Default to True. + :param bool yield_epoch: if True, converts the function into a generator that yields + a 3-tuple (epoch, stats, info) of train results on every epoch :return: See :func:`~tianshou.trainer.gather_info`. 
""" @@ -163,6 +166,15 @@ def offpolicy_trainer( if t.n <= t.total: t.update() logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) + # epoch_stat for yield clause + epoch_stat = {**stat, "gradient_step": gradient_step} + epoch_stat.update({ + "env_step": env_step, + "rew": last_rew, + "len": int(last_len), + "n/ep": int(result["n/ep"]), + "n/st": int(result["n/st"]), + }) # test if test_collector is not None: test_result = test_episode( @@ -179,15 +191,31 @@ def offpolicy_trainer( f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" ) + epoch_stat.update({"test_reward": rew, + "test_reward_std": rew_std, + "best_reward": best_reward, + "best_reward_std": best_reward_std, + "best_epoch": best_epoch + }) if stop_fn and stop_fn(best_reward): break + if yield_epoch: + if test_collector is None: + info = gather_info(start_time, train_collector, None, 0.0, 0.0) + else: + info = gather_info( + start_time, train_collector, test_collector, best_reward, best_reward_std + ) + yield epoch, epoch_stat, info + if test_collector is None and save_fn: save_fn(policy) - if test_collector is None: - return gather_info(start_time, train_collector, None, 0.0, 0.0) - else: - return gather_info( - start_time, train_collector, test_collector, best_reward, best_reward_std - ) + if not yield_epoch: + if test_collector is None: + return gather_info(start_time, train_collector, None, 0.0, 0.0) + else: + return gather_info( + start_time, train_collector, test_collector, best_reward, best_reward_std + ) diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 251c55637..04e8bc41f 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -32,6 +32,7 @@ def onpolicy_trainer( logger: BaseLogger = LazyLogger(), verbose: bool = True, test_in_train: bool = True, + yield_epoch: bool = False, ) -> Dict[str, Union[float, str]]: """A wrapper for on-policy trainer procedure. @@ -83,6 +84,8 @@ def onpolicy_trainer( training/testing/updating. Default to a logger that doesn't log anything. :param bool verbose: whether to print the information. Default to True. :param bool test_in_train: whether to test in the training phase. Default to True. + :param bool yield_epoch: if True, converts the function into a generator that yields + a 3-tuple (epoch, stats, info) of train results on every epoch :return: See :func:`~tianshou.trainer.gather_info`. 
@@ -179,6 +182,15 @@ def onpolicy_trainer( if t.n <= t.total: t.update() logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) + # epoch_stat for yield clause + epoch_stat = {**stat, "gradient_step": gradient_step} + epoch_stat.update({ + "env_step": env_step, + "rew": last_rew, + "len": int(last_len), + "n/ep": int(result["n/ep"]), + "n/st": int(result["n/st"]), + }) # test if test_collector is not None: test_result = test_episode( @@ -195,15 +207,31 @@ def onpolicy_trainer( f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" ) + epoch_stat.update({"test_reward": rew, + "test_reward_std": rew_std, + "best_reward": best_reward, + "best_reward_std": best_reward_std, + "best_epoch": best_epoch + }) if stop_fn and stop_fn(best_reward): break + if yield_epoch: + if test_collector is None: + info = gather_info(start_time, train_collector, None, 0.0, 0.0) + else: + info = gather_info( + start_time, train_collector, test_collector, best_reward, best_reward_std + ) + yield epoch, epoch_stat, info + if test_collector is None and save_fn: save_fn(policy) - if test_collector is None: - return gather_info(start_time, train_collector, None, 0.0, 0.0) - else: - return gather_info( - start_time, train_collector, test_collector, best_reward, best_reward_std - ) + if not yield_epoch: + if test_collector is None: + return gather_info(start_time, train_collector, None, 0.0, 0.0) + else: + return gather_info( + start_time, train_collector, test_collector, best_reward, best_reward_std + ) From 2b72992bb3313844e7007f0ee3ff4ed1cc0d3bf3 Mon Sep 17 00:00:00 2001 From: R107333 Date: Sat, 5 Mar 2022 17:49:46 +0100 Subject: [PATCH 03/40] Add trainer geneators for offline.py, offpolicy.py and onpolicy.py . Add tests for trainer generators. 
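The generator entry points mirror the argument lists of the existing trainers
and yield (epoch, epoch_stat, info) once per epoch, where info is the running
gather_info() summary. A usage sketch, assuming a policy and collectors built
as in the new tests:

    from tianshou.trainer import offpolicy_trainer_generator

    trainer = offpolicy_trainer_generator(
        policy, train_collector, test_collector, max_epoch, step_per_epoch,
        step_per_collect, episode_per_test, batch_size,
        update_per_step=update_per_step, stop_fn=stop_fn, logger=logger,
    )
    for epoch, epoch_stat, info in trainer:
        print(f"Epoch #{epoch}: {epoch_stat}")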
--- test/continuous/test_ppo_trainer_generator.py | 199 ++++++++++++++ .../test_sac_with_il_trainer_generator.py | 243 ++++++++++++++++++ test/offline/test_cql_trainer_generator.py | 226 ++++++++++++++++ tianshou/trainer/__init__.py | 9 +- tianshou/trainer/offline.py | 142 +++++++++- tianshou/trainer/offpolicy.py | 237 +++++++++++++++-- tianshou/trainer/onpolicy.py | 232 ++++++++++++++++- 7 files changed, 1261 insertions(+), 27 deletions(-) create mode 100644 test/continuous/test_ppo_trainer_generator.py create mode 100644 test/continuous/test_sac_with_il_trainer_generator.py create mode 100644 test/offline/test_cql_trainer_generator.py diff --git a/test/continuous/test_ppo_trainer_generator.py b/test/continuous/test_ppo_trainer_generator.py new file mode 100644 index 000000000..3a064997f --- /dev/null +++ b/test/continuous/test_ppo_trainer_generator.py @@ -0,0 +1,199 @@ +import argparse +import os +import pprint + +import gym +import numpy as np +import torch +from torch.distributions import Independent, Normal +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.policy import PPOPolicy +from tianshou.trainer import onpolicy_trainer_generator +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import ActorCritic, Net +from tianshou.utils.net.continuous import ActorProb, Critic + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='Pendulum-v0') + parser.add_argument('--seed', type=int, default=1) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--gamma', type=float, default=0.95) + parser.add_argument('--epoch', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=150000) + parser.add_argument('--episode-per-collect', type=int, default=16) + parser.add_argument('--repeat-per-collect', type=int, default=2) + parser.add_argument('--batch-size', type=int, default=128) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) + parser.add_argument('--training-num', type=int, default=16) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) 
+ parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + # ppo special + parser.add_argument('--vf-coef', type=float, default=0.25) + parser.add_argument('--ent-coef', type=float, default=0.0) + parser.add_argument('--eps-clip', type=float, default=0.2) + parser.add_argument('--max-grad-norm', type=float, default=0.5) + parser.add_argument('--gae-lambda', type=float, default=0.95) + parser.add_argument('--rew-norm', type=int, default=1) + parser.add_argument('--dual-clip', type=float, default=None) + parser.add_argument('--value-clip', type=int, default=1) + parser.add_argument('--norm-adv', type=int, default=1) + parser.add_argument('--recompute-adv', type=int, default=0) + parser.add_argument('--resume', action="store_true") + parser.add_argument("--save-interval", type=int, default=4) + args = parser.parse_known_args()[0] + return args + + +def test_ppo(args=get_args()): + env = gym.make(args.task) + if args.task == 'Pendulum-v0': + env.spec.reward_threshold = -250 + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] + # you can also use tianshou.env.SubprocVectorEnv + # train_envs = gym.make(args.task) + train_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) + # test_envs = gym.make(args.task) + test_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = ActorProb( + net, args.action_shape, max_action=args.max_action, device=args.device + ).to(args.device) + critic = Critic( + Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device), + device=args.device + ).to(args.device) + actor_critic = ActorCritic(actor, critic) + # orthogonal initialization + for m in actor_critic.modules(): + if isinstance(m, torch.nn.Linear): + torch.nn.init.orthogonal_(m.weight) + torch.nn.init.zeros_(m.bias) + optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) + + # replace DiagGuassian with Independent(Normal) which is equivalent + # pass *logits to be consistent with policy.forward + def dist(*logits): + return Independent(Normal(*logits), 1) + + policy = PPOPolicy( + actor, + critic, + optim, + dist, + discount_factor=args.gamma, + max_grad_norm=args.max_grad_norm, + eps_clip=args.eps_clip, + vf_coef=args.vf_coef, + ent_coef=args.ent_coef, + reward_normalization=args.rew_norm, + advantage_normalization=args.norm_adv, + recompute_advantage=args.recompute_adv, + dual_clip=args.dual_clip, + value_clip=args.value_clip, + gae_lambda=args.gae_lambda, + action_space=env.action_space + ) + # collector + train_collector = Collector( + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)) + ) + test_collector = Collector(policy, test_envs) + # log + log_path = os.path.join(args.logdir, args.task, 'ppo') + writer = SummaryWriter(log_path) + logger = TensorboardLogger(writer, save_interval=args.save_interval) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + def save_checkpoint_fn(epoch, env_step, gradient_step): + # see also: 
https://pytorch.org/tutorials/beginner/saving_loading_models.html + torch.save( + { + 'model': policy.state_dict(), + 'optim': optim.state_dict(), + }, os.path.join(log_path, 'checkpoint.pth') + ) + + if args.resume: + # load from existing checkpoint + print(f"Loading agent under {log_path}") + ckpt_path = os.path.join(log_path, 'checkpoint.pth') + if os.path.exists(ckpt_path): + checkpoint = torch.load(ckpt_path, map_location=args.device) + policy.load_state_dict(checkpoint['model']) + optim.load_state_dict(checkpoint['optim']) + print("Successfully restore policy and optim.") + else: + print("Fail to restore policy and optim.") + + # trainer + trainer = onpolicy_trainer_generator( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + episode_per_collect=args.episode_per_collect, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + resume_from_log=args.resume, + save_checkpoint_fn=save_checkpoint_fn + ) + print(trainer) + + for epoch, epoch_stat, info in trainer: + print(f"Epoch: {epoch}") + print(epoch_stat) + print(info) + + result = info + assert stop_fn(result['best_reward']) + + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! + env = gym.make(args.task) + policy.eval() + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + + +def test_ppo_resume(args=get_args()): + args.resume = True + test_ppo(args) + + +if __name__ == '__main__': + test_ppo() diff --git a/test/continuous/test_sac_with_il_trainer_generator.py b/test/continuous/test_sac_with_il_trainer_generator.py new file mode 100644 index 000000000..465ec9981 --- /dev/null +++ b/test/continuous/test_sac_with_il_trainer_generator.py @@ -0,0 +1,243 @@ +import argparse +import os +import pprint + +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.policy import ImitationPolicy, SACPolicy +from tianshou.trainer import offpolicy_trainer_generator +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import Actor, ActorProb, Critic + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='Pendulum-v0') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--actor-lr', type=float, default=1e-3) + parser.add_argument('--critic-lr', type=float, default=1e-3) + parser.add_argument('--il-lr', type=float, default=1e-3) + parser.add_argument('--gamma', type=float, default=0.99) + parser.add_argument('--tau', type=float, default=0.005) + parser.add_argument('--alpha', type=float, default=0.2) + parser.add_argument('--auto-alpha', type=int, default=1) + parser.add_argument('--alpha-lr', type=float, default=3e-4) + parser.add_argument('--epoch', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=24000) + parser.add_argument('--il-step-per-epoch', type=int, default=500) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) + parser.add_argument('--batch-size', type=int, default=128) + 
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) + parser.add_argument( + '--imitation-hidden-sizes', type=int, nargs='*', default=[128, 128] + ) + parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) + parser.add_argument('--rew-norm', action="store_true", default=False) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + args = parser.parse_known_args()[0] + return args + + +def test_sac_with_il(args=get_args()): + torch.set_num_threads(1) # we just need only one thread for NN + env = gym.make(args.task) + if args.task == 'Pendulum-v0': + env.spec.reward_threshold = -250 + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] + # you can also use tianshou.env.SubprocVectorEnv + # train_envs = gym.make(args.task) + train_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) + # test_envs = gym.make(args.task) + test_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = ActorProb( + net, + args.action_shape, + max_action=args.max_action, + device=args.device, + unbounded=True + ).to(args.device) + actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + net_c1 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) + critic1 = Critic(net_c1, device=args.device).to(args.device) + critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) + critic2 = Critic(net_c2, device=args.device).to(args.device) + critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + + if args.auto_alpha: + target_entropy = -np.prod(env.action_space.shape) + log_alpha = torch.zeros(1, requires_grad=True, device=args.device) + alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) + args.alpha = (target_entropy, log_alpha, alpha_optim) + + policy = SACPolicy( + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + tau=args.tau, + gamma=args.gamma, + alpha=args.alpha, + reward_normalization=args.rew_norm, + estimation_step=args.n_step, + action_space=env.action_space + ) + # collector + train_collector = Collector( + policy, + train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True + ) + test_collector = Collector(policy, test_envs) + # train_collector.collect(n_step=args.buffer_size) + # log + log_path = os.path.join(args.logdir, args.task, 'sac') + writer = SummaryWriter(log_path) + logger = TensorboardLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + # trainer + trainer = offpolicy_trainer_generator( + policy, + 
train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) + print(trainer) + + for epoch, epoch_stat, info in trainer: + print(f"Epoch: {epoch}") + print(epoch_stat) + print(info) + + result = info + assert stop_fn(result['best_reward']) + + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! + env = gym.make(args.task) + policy.eval() + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + + # here we define an imitation collector with a trivial policy + policy.eval() + if args.task == 'Pendulum-v0': + env.spec.reward_threshold = -300 # lower the goal + net = Actor( + Net( + args.state_shape, + hidden_sizes=args.imitation_hidden_sizes, + device=args.device + ), + args.action_shape, + max_action=args.max_action, + device=args.device + ).to(args.device) + optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) + il_policy = ImitationPolicy( + net, + optim, + action_space=env.action_space, + action_scaling=True, + action_bound_method="clip" + ) + il_test_collector = Collector( + il_policy, + DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) + ) + train_collector.reset() + trainer = offpolicy_trainer_generator( + il_policy, + train_collector, + il_test_collector, + args.epoch, + args.il_step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) + print(trainer) + + for epoch, epoch_stat, info in trainer: + print(f"Epoch: {epoch}") + print(epoch_stat) + print(info) + + result = info + assert stop_fn(result['best_reward']) + + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! 
+ env = gym.make(args.task) + il_policy.eval() + collector = Collector(il_policy, env) + result = collector.collect(n_episode=1, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + + +if __name__ == '__main__': + test_sac_with_il() diff --git a/test/offline/test_cql_trainer_generator.py b/test/offline/test_cql_trainer_generator.py new file mode 100644 index 000000000..20490743f --- /dev/null +++ b/test/offline/test_cql_trainer_generator.py @@ -0,0 +1,226 @@ +import argparse +import datetime +import os +import pickle +import pprint + +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv +from tianshou.policy import CQLPolicy +from tianshou.trainer import offline_trainer_generator +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import ActorProb, Critic + +if __name__ == "__main__": + from gather_pendulum_data import expert_file_name, gather_data +else: # pytest + from test.offline.gather_pendulum_data import expert_file_name, gather_data + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='Pendulum-v0') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) + parser.add_argument('--actor-lr', type=float, default=1e-3) + parser.add_argument('--critic-lr', type=float, default=1e-3) + parser.add_argument('--alpha', type=float, default=0.2) + parser.add_argument('--auto-alpha', default=True, action='store_true') + parser.add_argument('--alpha-lr', type=float, default=1e-3) + parser.add_argument('--cql-alpha-lr', type=float, default=1e-3) + parser.add_argument("--start-timesteps", type=int, default=10000) + parser.add_argument('--epoch', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=500) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--batch-size', type=int, default=64) + + parser.add_argument("--tau", type=float, default=0.005) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--cql-weight", type=float, default=1.0) + parser.add_argument("--with-lagrange", type=bool, default=True) + parser.add_argument("--lagrange-threshold", type=float, default=10.0) + parser.add_argument("--gamma", type=float, default=0.99) + + parser.add_argument("--eval-freq", type=int, default=1) + parser.add_argument('--test-num', type=int, default=10) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=1 / 35) + parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only', + ) + parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) + args = parser.parse_known_args()[0] + return args + + +def test_cql_trainer_generator(args=get_args()): + if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): + if args.load_buffer_name.endswith(".hdf5"): + buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) + else: + buffer = 
pickle.load(open(args.load_buffer_name, "rb")) + else: + buffer = gather_data() + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] # float + if args.task == 'Pendulum-v0': + env.spec.reward_threshold = -1200 # too low? + + args.state_dim = args.state_shape[0] + args.action_dim = args.action_shape[0] + # test_envs = gym.make(args.task) + test_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + test_envs.seed(args.seed) + + # model + # actor network + net_a = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + ) + actor = ActorProb( + net_a, + action_shape=args.action_shape, + max_action=args.max_action, + device=args.device, + unbounded=True, + conditioned_sigma=True, + ).to(args.device) + actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + + # critic network + net_c1 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + ) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + ) + critic1 = Critic(net_c1, device=args.device).to(args.device) + critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic2 = Critic(net_c2, device=args.device).to(args.device) + critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + + if args.auto_alpha: + target_entropy = -np.prod(env.action_space.shape) + log_alpha = torch.zeros(1, requires_grad=True, device=args.device) + alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) + args.alpha = (target_entropy, log_alpha, alpha_optim) + + policy = CQLPolicy( + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + cql_alpha_lr=args.cql_alpha_lr, + cql_weight=args.cql_weight, + tau=args.tau, + gamma=args.gamma, + alpha=args.alpha, + temperature=args.temperature, + with_lagrange=args.with_lagrange, + lagrange_threshold=args.lagrange_threshold, + min_action=np.min(env.action_space.low), + max_action=np.max(env.action_space.high), + device=args.device, + ) + + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + + # collector + # buffer has been gathered + # train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs) + # log + t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") + log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_cql' + log_path = os.path.join(args.logdir, args.task, 'cql', log_file) + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = TensorboardLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + def watch(): + policy.load_state_dict( + torch.load( + os.path.join(log_path, 'policy.pth'), map_location=torch.device('cpu') + ) + ) + policy.eval() + collector = Collector(policy, env) + collector.collect(n_episode=1, render=1 / 35) + + # trainer + trainer = offline_trainer_generator( + policy, + buffer, + test_collector, + 
args.epoch, + args.step_per_epoch, + args.test_num, + args.batch_size, + save_fn=save_fn, + stop_fn=stop_fn, + logger=logger, + ) + print(trainer) + for epoch, epoch_stat, info in trainer: + print(f"Epoch: {epoch}") + print(epoch_stat) + print(info) + + result = info + assert stop_fn(result['best_reward']) + + # Let's watch its performance! + if __name__ == '__main__': + pprint.pprint(result) + env = gym.make(args.task) + policy.eval() + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + + +if __name__ == '__main__': + test_cql_trainer_generator()
diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index 11b3a95ef..4f542b865 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -3,14 +3,17 @@ # isort:skip_file from tianshou.trainer.utils import test_episode, gather_info -from tianshou.trainer.onpolicy import onpolicy_trainer -from tianshou.trainer.offpolicy import offpolicy_trainer -from tianshou.trainer.offline import offline_trainer +from tianshou.trainer.onpolicy import onpolicy_trainer, onpolicy_trainer_generator +from tianshou.trainer.offpolicy import offpolicy_trainer, offpolicy_trainer_generator +from tianshou.trainer.offline import offline_trainer, offline_trainer_generator __all__ = [ "offpolicy_trainer", + "offpolicy_trainer_generator", "onpolicy_trainer", + "onpolicy_trainer_generator", "offline_trainer", + "offline_trainer_generator", "test_episode", "gather_info", ]
diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 87fb69ee3..56e3cceac 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -1,6 +1,6 @@ import time from collections import defaultdict -from typing import Callable, Dict, Optional, Union +from typing import Callable, Dict, Optional, Union, Generator, Tuple import numpy as np import tqdm @@ -80,6 +80,9 @@ def offline_trainer( stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() + if yield_epoch: + yield 0, {}, {} + if test_collector is not None: test_c: Collector = test_collector test_collector.reset_stat() @@ -154,3 +157,140 @@ def offline_trainer( return gather_info( start_time, None, test_collector, best_reward, best_reward_std ) + + +def offline_trainer_generator( + policy: BasePolicy, + buffer: ReplayBuffer, + test_collector: Optional[Collector], + max_epoch: int, + update_per_epoch: int, + episode_per_test: int, + batch_size: int, + test_fn: Optional[Callable[[int, Optional[int]], None]] = None, + stop_fn: Optional[Callable[[float], bool]] = None, + save_fn: Optional[Callable[[BasePolicy], None]] = None, + save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + resume_from_log: bool = False, + reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, + logger: BaseLogger = LazyLogger(), + verbose: bool = True, +) -> Generator[Tuple[int, Dict[str, Union[float, str]], Dict[str, Union[float, str]]], None, None]: + """A wrapper for offline trainer procedure. + Returns a generator that yields a 3-tuple (epoch, stats, info) of train results + on every epoch. + + The "step" in offline trainer means a gradient step. + + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. + This buffer must be populated with experiences for offline RL.
+ :param Collector test_collector: the collector used for testing. If it's None, then + no testing will be performed. + :param int max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. + :param int update_per_epoch: the number of policy network updates, so-called + gradient steps, per epoch. + :param episode_per_test: the number of episodes for one policy evaluation. + :param int batch_size: the batch size of sample data, which is going to feed in + the policy network. + :param function test_fn: a hook called at the beginning of testing in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean reward in + evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> + None``. + :param function save_checkpoint_fn: a function to save training process, with the + signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can + save whatever you want. Because offline-RL doesn't have env_step, the env_step + is always 0 here. + :param bool resume_from_log: resume gradient_step and other metadata from existing + tensorboard log. Default to False. + :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether reaching the goal. + :param function reward_metric: a function with signature ``f(rewards: np.ndarray + with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, + used in multi-agent RL. We need to return a single scalar for each episode's + result to monitor training in the multi-agent RL setting. This function + specifies what is the desired metric, e.g., the reward of agent 1 or the + average reward over all agents. + :param BaseLogger logger: A logger that logs statistics during updating/testing. + Default to a logger that doesn't log anything. + :param bool verbose: whether to print the information. Default to True. + + :return: See :func:`~tianshou.trainer.gather_info`. 
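+
+    A minimal usage sketch (assuming ``policy``, a pre-filled ``buffer`` and a
+    ``test_collector`` constructed as in
+    ``test/offline/test_cql_trainer_generator.py``)::
+
+        trainer = offline_trainer_generator(
+            policy, buffer, test_collector, max_epoch, update_per_epoch,
+            episode_per_test, batch_size, stop_fn=stop_fn, logger=logger,
+        )
+        for epoch, epoch_stat, info in trainer:
+            print(f"Epoch #{epoch}: {epoch_stat}")
+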
+ """ + start_epoch, gradient_step = 0, 0 + if resume_from_log: + start_epoch, _, gradient_step = logger.restore_data() + stat: Dict[str, MovAvg] = defaultdict(MovAvg) + start_time = time.time() + + if test_collector is not None: + test_c: Collector = test_collector + test_collector.reset_stat() + test_result = test_episode( + policy, test_c, test_fn, start_epoch, episode_per_test, logger, + gradient_step, reward_metric + ) + best_epoch = start_epoch + best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] + if save_fn: + save_fn(policy) + + for epoch in range(1 + start_epoch, 1 + max_epoch): + policy.train() + + with tqdm.trange(update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config) as t: + for _ in t: + gradient_step += 1 + losses = policy.update(batch_size, buffer) + data = {"gradient_step": str(gradient_step)} + for k in losses.keys(): + stat[k].add(losses[k]) + losses[k] = stat[k].get() + data[k] = f"{losses[k]:.3f}" + logger.log_update_data(losses, gradient_step) + t.set_postfix(**data) + + logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn) + # epoch_stat for yield clause + epoch_stat = {**{k: v.get() for k, v in stat.items()}, "gradient_step": gradient_step} + # test + if test_collector is not None: + test_result = test_episode( + policy, test_c, test_fn, epoch, episode_per_test, logger, + gradient_step, reward_metric + ) + rew, rew_std = test_result["rew"], test_result["rew_std"] + if best_epoch < 0 or best_reward < rew: + best_epoch, best_reward, best_reward_std = epoch, rew, rew_std + if save_fn: + save_fn(policy) + if verbose: + print( + f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + ) + epoch_stat.update({"test_reward": rew, + "test_reward_std": rew_std, + "best_reward": best_reward, + "best_reward_std": best_reward_std, + "best_epoch": best_epoch + }) + + if test_collector is None: + info = gather_info(start_time, None, None, 0.0, 0.0) + else: + info = gather_info( + start_time, None, test_collector, best_reward, best_reward_std + ) + yield epoch, epoch_stat, info + + if test_collector is not None and stop_fn and stop_fn(best_reward): + break + + if test_collector is None and save_fn: + save_fn(policy) + diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 1e92ef644..6fbe88810 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -1,6 +1,6 @@ import time from collections import defaultdict -from typing import Callable, Dict, Optional, Union +from typing import Callable, Dict, Optional, Union, Generator, Tuple import numpy as np import tqdm @@ -31,7 +31,6 @@ def offpolicy_trainer( logger: BaseLogger = LazyLogger(), verbose: bool = True, test_in_train: bool = True, - yield_epoch: bool = False, ) -> Dict[str, Union[float, str]]: """A wrapper for off-policy trainer procedure. @@ -82,8 +81,6 @@ def offpolicy_trainer( training/testing/updating. Default to a logger that doesn't log anything. :param bool verbose: whether to print the information. Default to True. :param bool test_in_train: whether to test in the training phase. Default to True. - :param bool yield_epoch: if True, converts the function into a generator that yields - a 3-tuple (epoch, stats, info) of train results on every epoch :return: See :func:`~tianshou.trainer.gather_info`. 
""" @@ -166,8 +163,210 @@ def offpolicy_trainer( if t.n <= t.total: t.update() logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) + + # test + if test_collector is not None: + test_result = test_episode( + policy, test_c, test_fn, epoch, episode_per_test, logger, env_step, + reward_metric + ) + rew, rew_std = test_result["rew"], test_result["rew_std"] + if best_epoch < 0 or best_reward < rew: + best_epoch, best_reward, best_reward_std = epoch, rew, rew_std + if save_fn: + save_fn(policy) + if verbose: + print( + f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + ) + epoch_stat.update({"test_reward": rew, + "test_reward_std": rew_std, + "best_reward": best_reward, + "best_reward_std": best_reward_std, + "best_epoch": best_epoch + }) + if stop_fn and stop_fn(best_reward): + break + + if test_collector is None and save_fn: + save_fn(policy) + + if test_collector is None: + return gather_info(start_time, train_collector, None, 0.0, 0.0) + else: + return gather_info( + start_time, train_collector, test_collector, best_reward, best_reward_std + ) + + +def offpolicy_trainer_generator( + policy: BasePolicy, + train_collector: Collector, + test_collector: Optional[Collector], + max_epoch: int, + step_per_epoch: int, + step_per_collect: int, + episode_per_test: int, + batch_size: int, + update_per_step: Union[int, float] = 1, + train_fn: Optional[Callable[[int, int], None]] = None, + test_fn: Optional[Callable[[int, Optional[int]], None]] = None, + stop_fn: Optional[Callable[[float], bool]] = None, + save_fn: Optional[Callable[[BasePolicy], None]] = None, + save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + resume_from_log: bool = False, + reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, + logger: BaseLogger = LazyLogger(), + verbose: bool = True, + test_in_train: bool = True, +) -> Generator[Tuple[int, Dict[str, Union[float, str]], Dict[str, Union[float, str]]], None, None]: + """A wrapper for off-policy trainer procedure. + Returns a generator that yields a 3-tuple (epoch, stats, info) of train results + on every epoch. + + The "step" in trainer means an environment step (a.k.a. transition). + + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param Collector train_collector: the collector used for training. + :param Collector test_collector: the collector used for testing. If it's None, then + no testing will be performed. + :param int max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. + :param int step_per_epoch: the number of transitions collected per epoch. + :param int step_per_collect: the number of transitions the collector would collect + before the network update, i.e., trainer will collect "step_per_collect" + transitions and do some policy network update repeatedly in each epoch. + :param episode_per_test: the number of episodes for one policy evaluation. + :param int batch_size: the batch size of sample data, which is going to feed in the + policy network. + :param int/float update_per_step: the number of times the policy network would be + updated per transition after (step_per_collect) transitions are collected, + e.g., if update_per_step set to 0.3, and step_per_collect is 256, policy will + be updated round(256 * 0.3 = 76.8) = 77 times after 256 transitions are + collected by the collector. Default to 1. 
+ :param function train_fn: a hook called at the beginning of training in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean reward in + evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> + None``. + :param function save_checkpoint_fn: a function to save training process, with the + signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can + save whatever you want. + :param bool resume_from_log: resume env_step/gradient_step and other metadata from + existing tensorboard log. Default to False. + :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether reaching the goal. + :param function reward_metric: a function with signature ``f(rewards: np.ndarray + with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, + used in multi-agent RL. We need to return a single scalar for each episode's + result to monitor training in the multi-agent RL setting. This function + specifies what is the desired metric, e.g., the reward of agent 1 or the + average reward over all agents. + :param BaseLogger logger: A logger that logs statistics during + training/testing/updating. Default to a logger that doesn't log anything. + :param bool verbose: whether to print the information. Default to True. + :param bool test_in_train: whether to test in the training phase. Default to True. + + :return: See :func:`~tianshou.trainer.gather_info`. 
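+
+    Illustrative sketch (assuming the usual off-policy setup and that ``torch``
+    is imported): since epochs are yielded one at a time, the caller can run
+    its own logic between epochs, e.g. periodic snapshots::
+
+        for epoch, epoch_stat, info in offpolicy_trainer_generator(
+            policy, train_collector, test_collector, max_epoch, step_per_epoch,
+            step_per_collect, episode_per_test, batch_size,
+            update_per_step=update_per_step, logger=logger,
+        ):
+            print(f"Epoch #{epoch}: {epoch_stat['env_step']} env steps")
+            if epoch % 10 == 0:
+                # hypothetical caller-side snapshot path
+                torch.save(policy.state_dict(), "policy_snapshot.pth")
+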
+ """ + start_epoch, env_step, gradient_step = 0, 0, 0 + if resume_from_log: + start_epoch, env_step, gradient_step = logger.restore_data() + last_rew, last_len = 0.0, 0 + stat: Dict[str, MovAvg] = defaultdict(MovAvg) + start_time = time.time() + train_collector.reset_stat() + test_in_train = test_in_train and ( + train_collector.policy == policy and test_collector is not None + ) + + if test_collector is not None: + test_c: Collector = test_collector # for mypy + test_collector.reset_stat() + test_result = test_episode( + policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step, + reward_metric + ) + best_epoch = start_epoch + best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] + if save_fn: + save_fn(policy) + + for epoch in range(1 + start_epoch, 1 + max_epoch): + # train + policy.train() + with tqdm.tqdm( + total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config + ) as t: + while t.n < t.total: + if train_fn: + train_fn(epoch, env_step) + result = train_collector.collect(n_step=step_per_collect) + if result["n/ep"] > 0 and reward_metric: + rew = reward_metric(result["rews"]) + result.update(rews=rew, rew=rew.mean(), rew_std=rew.std()) + env_step += int(result["n/st"]) + t.update(result["n/st"]) + logger.log_train_data(result, env_step) + last_rew = result['rew'] if result["n/ep"] > 0 else last_rew + last_len = result['len'] if result["n/ep"] > 0 else last_len + data = { + "env_step": str(env_step), + "rew": f"{last_rew:.2f}", + "len": str(int(last_len)), + "n/ep": str(int(result["n/ep"])), + "n/st": str(int(result["n/st"])), + } + if result["n/ep"] > 0: + if test_in_train and stop_fn and stop_fn(result["rew"]): + test_result = test_episode( + policy, test_c, test_fn, epoch, episode_per_test, logger, + env_step + ) + if stop_fn(test_result["rew"]): + if save_fn: + save_fn(policy) + logger.save_data( + epoch, env_step, gradient_step, save_checkpoint_fn + ) + t.set_postfix(**data) + # epoch_stat for yield clause + epoch_stat = {**{k: v.get() for k, v in stat.items()}, "gradient_step": gradient_step} + epoch_stat.update({ + "env_step": env_step, + "rew": last_rew, + "len": int(last_len), + "n/ep": int(result["n/ep"]), + "n/st": int(result["n/st"]), + }) + info = gather_info( + start_time, train_collector, test_collector, + test_result["rew"], test_result["rew_std"] + ) + yield epoch, epoch_stat, info + return + else: + policy.train() + for _ in range(round(update_per_step * result["n/st"])): + gradient_step += 1 + losses = policy.update(batch_size, train_collector.buffer) + for k in losses.keys(): + stat[k].add(losses[k]) + losses[k] = stat[k].get() + data[k] = f"{losses[k]:.3f}" + logger.log_update_data(losses, gradient_step) + t.set_postfix(**data) + if t.n <= t.total: + t.update() + logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) # epoch_stat for yield clause - epoch_stat = {**stat, "gradient_step": gradient_step} + epoch_stat = {**{k: v.get() for k, v in stat.items()}, "gradient_step": gradient_step} epoch_stat.update({ "env_step": env_step, "rew": last_rew, @@ -197,25 +396,19 @@ def offpolicy_trainer( "best_reward_std": best_reward_std, "best_epoch": best_epoch }) - if stop_fn and stop_fn(best_reward): - break - - if yield_epoch: - if test_collector is None: - info = gather_info(start_time, train_collector, None, 0.0, 0.0) - else: - info = gather_info( - start_time, train_collector, test_collector, best_reward, best_reward_std - ) - yield epoch, epoch_stat, info - - if test_collector is None and save_fn: - save_fn(policy) 
- if not yield_epoch: if test_collector is None: - return gather_info(start_time, train_collector, None, 0.0, 0.0) + info = gather_info(start_time, train_collector, None, 0.0, 0.0) else: - return gather_info( + info = gather_info( start_time, train_collector, test_collector, best_reward, best_reward_std ) + yield epoch, epoch_stat, info + + if test_collector is not None and stop_fn and stop_fn(best_reward): + break + + if test_collector is None and save_fn: + save_fn(policy) + + diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 04e8bc41f..d2634e2d9 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -1,6 +1,6 @@ import time from collections import defaultdict -from typing import Callable, Dict, Optional, Union +from typing import Callable, Dict, Optional, Union, Generator, Tuple import numpy as np import tqdm @@ -235,3 +235,233 @@ def onpolicy_trainer( return gather_info( start_time, train_collector, test_collector, best_reward, best_reward_std ) + + +def onpolicy_trainer_generator( + policy: BasePolicy, + train_collector: Collector, + test_collector: Optional[Collector], + max_epoch: int, + step_per_epoch: int, + repeat_per_collect: int, + episode_per_test: int, + batch_size: int, + step_per_collect: Optional[int] = None, + episode_per_collect: Optional[int] = None, + train_fn: Optional[Callable[[int, int], None]] = None, + test_fn: Optional[Callable[[int, Optional[int]], None]] = None, + stop_fn: Optional[Callable[[float], bool]] = None, + save_fn: Optional[Callable[[BasePolicy], None]] = None, + save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + resume_from_log: bool = False, + reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, + logger: BaseLogger = LazyLogger(), + verbose: bool = True, + test_in_train: bool = True, +) -> Generator[Tuple[int, Dict[str, Union[float, str]], Dict[str, Union[float, str]]], None, None]: + """A wrapper for on-policy trainer procedure. + Returns a generator that yields a 3-tuple (epoch, stats, info) of train results + on every epoch. + + The "step" in trainer means an environment step (a.k.a. transition). + + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param Collector train_collector: the collector used for training. + :param Collector test_collector: the collector used for testing. If it's None, then + no testing will be performed. + :param int max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. + :param int step_per_epoch: the number of transitions collected per epoch. + :param int repeat_per_collect: the number of repeat time for policy learning, for + example, set it to 2 means the policy needs to learn each given batch data + twice. + :param int episode_per_test: the number of episodes for one policy evaluation. + :param int batch_size: the batch size of sample data, which is going to feed in the + policy network. + :param int step_per_collect: the number of transitions the collector would collect + before the network update, i.e., trainer will collect "step_per_collect" + transitions and do some policy network update repeatedly in each epoch. + :param int episode_per_collect: the number of episodes the collector would collect + before the network update, i.e., trainer will collect "episode_per_collect" + episodes and do some policy network update repeatedly in each epoch. 
+ :param function train_fn: a hook called at the beginning of training in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean reward in + evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> + None``. + :param function save_checkpoint_fn: a function to save training process, with the + signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can + save whatever you want. + :param bool resume_from_log: resume env_step/gradient_step and other metadata from + existing tensorboard log. Default to False. + :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether reaching the goal. + :param function reward_metric: a function with signature ``f(rewards: np.ndarray + with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, + used in multi-agent RL. We need to return a single scalar for each episode's + result to monitor training in the multi-agent RL setting. This function + specifies what is the desired metric, e.g., the reward of agent 1 or the + average reward over all agents. + :param BaseLogger logger: A logger that logs statistics during + training/testing/updating. Default to a logger that doesn't log anything. + :param bool verbose: whether to print the information. Default to True. + :param bool test_in_train: whether to test in the training phase. Default to True. + + :return: See :func:`~tianshou.trainer.gather_info`. + + .. note:: + + Only either one of step_per_collect and episode_per_collect can be specified. 
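+
+    A usage sketch, following the PPO setup in
+    ``test/continuous/test_ppo_trainer_generator.py`` (all shown variables are
+    assumed to be prepared beforehand)::
+
+        trainer = onpolicy_trainer_generator(
+            policy, train_collector, test_collector, max_epoch, step_per_epoch,
+            repeat_per_collect, episode_per_test, batch_size,
+            episode_per_collect=episode_per_collect, stop_fn=stop_fn,
+            logger=logger,
+        )
+        for epoch, epoch_stat, info in trainer:
+            print(f"Epoch #{epoch}: {epoch_stat}")
+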
+ """ + start_epoch, env_step, gradient_step = 0, 0, 0 + if resume_from_log: + start_epoch, env_step, gradient_step = logger.restore_data() + last_rew, last_len = 0.0, 0 + stat: Dict[str, MovAvg] = defaultdict(MovAvg) + start_time = time.time() + train_collector.reset_stat() + test_in_train = test_in_train and ( + train_collector.policy == policy and test_collector is not None + ) + + if test_collector is not None: + test_c: Collector = test_collector # for mypy + test_collector.reset_stat() + test_result = test_episode( + policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step, + reward_metric + ) + best_epoch = start_epoch + best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] + if save_fn: + save_fn(policy) + + for epoch in range(1 + start_epoch, 1 + max_epoch): + # train + policy.train() + with tqdm.tqdm( + total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config + ) as t: + while t.n < t.total: + if train_fn: + train_fn(epoch, env_step) + result = train_collector.collect( + n_step=step_per_collect, n_episode=episode_per_collect + ) + if result["n/ep"] > 0 and reward_metric: + rew = reward_metric(result["rews"]) + result.update(rews=rew, rew=rew.mean(), rew_std=rew.std()) + env_step += int(result["n/st"]) + t.update(result["n/st"]) + logger.log_train_data(result, env_step) + last_rew = result['rew'] if result["n/ep"] > 0 else last_rew + last_len = result['len'] if result["n/ep"] > 0 else last_len + data = { + "env_step": str(env_step), + "rew": f"{last_rew:.2f}", + "len": str(int(last_len)), + "n/ep": str(int(result["n/ep"])), + "n/st": str(int(result["n/st"])), + } + if result["n/ep"] > 0: + if test_in_train and stop_fn and stop_fn(result["rew"]): + test_result = test_episode( + policy, test_c, test_fn, epoch, episode_per_test, logger, + env_step + ) + if stop_fn(test_result["rew"]): + if save_fn: + save_fn(policy) + logger.save_data( + epoch, env_step, gradient_step, save_checkpoint_fn + ) + t.set_postfix(**data) + # epoch_stat for yield clause + epoch_stat = {**{k: v.get() for k, v in stat.items()}, "gradient_step": gradient_step} + epoch_stat.update({ + "env_step": env_step, + "rew": last_rew, + "len": int(last_len), + "n/ep": int(result["n/ep"]), + "n/st": int(result["n/st"]), + }) + info = gather_info( + start_time, train_collector, test_collector, + test_result["rew"], test_result["rew_std"] + ) + yield epoch, epoch_stat, info + return + else: + policy.train() + losses = policy.update( + 0, + train_collector.buffer, + batch_size=batch_size, + repeat=repeat_per_collect + ) + train_collector.reset_buffer(keep_statistics=True) + step = max( + [1] + [len(v) for v in losses.values() if isinstance(v, list)] + ) + gradient_step += step + for k in losses.keys(): + stat[k].add(losses[k]) + losses[k] = stat[k].get() + data[k] = f"{losses[k]:.3f}" + logger.log_update_data(losses, gradient_step) + t.set_postfix(**data) + if t.n <= t.total: + t.update() + logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) + # epoch_stat for yield clause + epoch_stat = {**{k: v.get() for k, v in stat.items()}, "gradient_step": gradient_step} + epoch_stat.update({ + "env_step": env_step, + "rew": last_rew, + "len": int(last_len), + "n/ep": int(result["n/ep"]), + "n/st": int(result["n/st"]), + }) + # test + if test_collector is not None: + test_result = test_episode( + policy, test_c, test_fn, epoch, episode_per_test, logger, env_step, + reward_metric + ) + rew, rew_std = test_result["rew"], test_result["rew_std"] + if best_epoch < 0 or best_reward < 
rew: + best_epoch, best_reward, best_reward_std = epoch, rew, rew_std + if save_fn: + save_fn(policy) + if verbose: + print( + f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + ) + epoch_stat.update({"test_reward": rew, + "test_reward_std": rew_std, + "best_reward": best_reward, + "best_reward_std": best_reward_std, + "best_epoch": best_epoch + }) + + if test_collector is None: + info = gather_info(start_time, train_collector, None, 0.0, 0.0) + else: + info = gather_info( + start_time, train_collector, test_collector, best_reward, best_reward_std + ) + yield epoch, epoch_stat, info + + if test_collector is not None and stop_fn and stop_fn(best_reward): + break + + if test_collector is None and save_fn: + save_fn(policy) + + From 9a6a72bb0caf45ca0924e72169ee59954091c5c0 Mon Sep 17 00:00:00 2001 From: R107333 Date: Sat, 5 Mar 2022 18:06:07 +0100 Subject: [PATCH 04/40] fix PEP8 fix offline.py --- tianshou/trainer/offline.py | 43 ++++++++---------------------- tianshou/trainer/offpolicy.py | 20 +++++++------- tianshou/trainer/onpolicy.py | 49 +++++++++-------------------------- 3 files changed, 32 insertions(+), 80 deletions(-) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 56e3cceac..6416bd754 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -27,7 +27,6 @@ def offline_trainer( reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, logger: BaseLogger = LazyLogger(), verbose: bool = True, - yield_epoch: bool = False, ) -> Dict[str, Union[float, str]]: """A wrapper for offline trainer procedure. @@ -69,8 +68,6 @@ def offline_trainer( :param BaseLogger logger: A logger that logs statistics during updating/testing. Default to a logger that doesn't log anything. :param bool verbose: whether to print the information. Default to True. - :param bool yield_epoch: if True, converts the function into a generator that yields - a 3-tuple (epoch, stats, info) of train results on every epoch :return: See :func:`~tianshou.trainer.gather_info`. 
""" @@ -80,9 +77,6 @@ def offline_trainer( stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() - if yield_epoch: - yield 0, {}, {} - if test_collector is not None: test_c: Collector = test_collector test_collector.reset_stat() @@ -111,8 +105,6 @@ def offline_trainer( t.set_postfix(**data) logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn) - # epoch_stat for yield clause - epoch_stat = {**stat, "gradient_step": gradient_step} # test if test_collector is not None: test_result = test_episode( @@ -129,34 +121,18 @@ def offline_trainer( f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" ) - epoch_stat.update({"test_reward": rew, - "test_reward_std": rew_std, - "best_reward": best_reward, - "best_reward_std": best_reward_std, - "best_epoch": best_epoch - }) if stop_fn and stop_fn(best_reward): break - if yield_epoch: - if test_collector is None: - info = gather_info(start_time, None, None, 0.0, 0.0) - else: - info = gather_info( - start_time, None, test_collector, best_reward, best_reward_std - ) - yield epoch, epoch_stat, info - if test_collector is None and save_fn: save_fn(policy) - if not yield_epoch: - if test_collector is None: - return gather_info(start_time, None, None, 0.0, 0.0) - else: - return gather_info( - start_time, None, test_collector, best_reward, best_reward_std - ) + if test_collector is None: + return gather_info(start_time, None, None, 0.0, 0.0) + else: + return gather_info( + start_time, None, test_collector, best_reward, best_reward_std + ) def offline_trainer_generator( @@ -175,7 +151,9 @@ def offline_trainer_generator( reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, logger: BaseLogger = LazyLogger(), verbose: bool = True, -) -> Generator[Tuple[int, Dict[str, Union[float, str]], Dict[str, Union[float, str]]], None, None]: +) -> Generator[Tuple[int, + Dict[str, Union[float, str]], + Dict[str, Union[float, str]]], None, None]: """A wrapper for offline trainer procedure. Returns a generator that yields a 3-tuple (epoch, stats, info) of train results on every epoch. 
@@ -256,7 +234,8 @@ def offline_trainer_generator( logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn) # epoch_stat for yield clause - epoch_stat = {**{k: v.get() for k, v in stat.items()}, "gradient_step": gradient_step} + epoch_stat = {**{k: v.get() for k, v in stat.items()}, + "gradient_step": gradient_step} # test if test_collector is not None: test_result = test_episode( diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 6fbe88810..b7ec6e79b 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -163,7 +163,6 @@ def offpolicy_trainer( if t.n <= t.total: t.update() logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) - # test if test_collector is not None: test_result = test_episode( @@ -180,12 +179,6 @@ def offpolicy_trainer( f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" ) - epoch_stat.update({"test_reward": rew, - "test_reward_std": rew_std, - "best_reward": best_reward, - "best_reward_std": best_reward_std, - "best_epoch": best_epoch - }) if stop_fn and stop_fn(best_reward): break @@ -220,7 +213,9 @@ def offpolicy_trainer_generator( logger: BaseLogger = LazyLogger(), verbose: bool = True, test_in_train: bool = True, -) -> Generator[Tuple[int, Dict[str, Union[float, str]], Dict[str, Union[float, str]]], None, None]: +) -> Generator[Tuple[int, + Dict[str, Union[float, str]], + Dict[str, Union[float, str]]], None, None]: """A wrapper for off-policy trainer procedure. Returns a generator that yields a 3-tuple (epoch, stats, info) of train results on every epoch. @@ -337,7 +332,8 @@ def offpolicy_trainer_generator( ) t.set_postfix(**data) # epoch_stat for yield clause - epoch_stat = {**{k: v.get() for k, v in stat.items()}, "gradient_step": gradient_step} + epoch_stat = {**{k: v.get() for k, v in stat.items()}, + "gradient_step": gradient_step} epoch_stat.update({ "env_step": env_step, "rew": last_rew, @@ -366,7 +362,8 @@ def offpolicy_trainer_generator( t.update() logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) # epoch_stat for yield clause - epoch_stat = {**{k: v.get() for k, v in stat.items()}, "gradient_step": gradient_step} + epoch_stat = {**{k: v.get() for k, v in stat.items()}, + "gradient_step": gradient_step} epoch_stat.update({ "env_step": env_step, "rew": last_rew, @@ -401,7 +398,8 @@ def offpolicy_trainer_generator( info = gather_info(start_time, train_collector, None, 0.0, 0.0) else: info = gather_info( - start_time, train_collector, test_collector, best_reward, best_reward_std + start_time, train_collector, test_collector, + best_reward, best_reward_std ) yield epoch, epoch_stat, info diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index d2634e2d9..0bc2fdbbb 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -32,7 +32,6 @@ def onpolicy_trainer( logger: BaseLogger = LazyLogger(), verbose: bool = True, test_in_train: bool = True, - yield_epoch: bool = False, ) -> Dict[str, Union[float, str]]: """A wrapper for on-policy trainer procedure. @@ -84,8 +83,6 @@ def onpolicy_trainer( training/testing/updating. Default to a logger that doesn't log anything. :param bool verbose: whether to print the information. Default to True. :param bool test_in_train: whether to test in the training phase. Default to True. 
- :param bool yield_epoch: if True, converts the function into a generator that yields - a 3-tuple (epoch, stats, info) of train results on every epoch :return: See :func:`~tianshou.trainer.gather_info`. @@ -182,15 +179,6 @@ def onpolicy_trainer( if t.n <= t.total: t.update() logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) - # epoch_stat for yield clause - epoch_stat = {**stat, "gradient_step": gradient_step} - epoch_stat.update({ - "env_step": env_step, - "rew": last_rew, - "len": int(last_len), - "n/ep": int(result["n/ep"]), - "n/st": int(result["n/st"]), - }) # test if test_collector is not None: test_result = test_episode( @@ -207,34 +195,18 @@ def onpolicy_trainer( f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" ) - epoch_stat.update({"test_reward": rew, - "test_reward_std": rew_std, - "best_reward": best_reward, - "best_reward_std": best_reward_std, - "best_epoch": best_epoch - }) if stop_fn and stop_fn(best_reward): break - if yield_epoch: - if test_collector is None: - info = gather_info(start_time, train_collector, None, 0.0, 0.0) - else: - info = gather_info( - start_time, train_collector, test_collector, best_reward, best_reward_std - ) - yield epoch, epoch_stat, info - if test_collector is None and save_fn: save_fn(policy) - if not yield_epoch: - if test_collector is None: - return gather_info(start_time, train_collector, None, 0.0, 0.0) - else: - return gather_info( - start_time, train_collector, test_collector, best_reward, best_reward_std - ) + if test_collector is None: + return gather_info(start_time, train_collector, None, 0.0, 0.0) + else: + return gather_info( + start_time, train_collector, test_collector, best_reward, best_reward_std + ) def onpolicy_trainer_generator( @@ -382,7 +354,8 @@ def onpolicy_trainer_generator( ) t.set_postfix(**data) # epoch_stat for yield clause - epoch_stat = {**{k: v.get() for k, v in stat.items()}, "gradient_step": gradient_step} + epoch_stat = {**{k: v.get() for k, v in stat.items()}, + "gradient_step": gradient_step} epoch_stat.update({ "env_step": env_step, "rew": last_rew, @@ -419,7 +392,8 @@ def onpolicy_trainer_generator( t.update() logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) # epoch_stat for yield clause - epoch_stat = {**{k: v.get() for k, v in stat.items()}, "gradient_step": gradient_step} + epoch_stat = {**{k: v.get() for k, v in stat.items()}, + "gradient_step": gradient_step} epoch_stat.update({ "env_step": env_step, "rew": last_rew, @@ -454,7 +428,8 @@ def onpolicy_trainer_generator( info = gather_info(start_time, train_collector, None, 0.0, 0.0) else: info = gather_info( - start_time, train_collector, test_collector, best_reward, best_reward_std + start_time, train_collector, test_collector, + best_reward, best_reward_std ) yield epoch, epoch_stat, info From d05f0e01d7a3c40d58b77dbdd1c3ddbc9e99c679 Mon Sep 17 00:00:00 2001 From: R107333 Date: Sat, 5 Mar 2022 18:10:16 +0100 Subject: [PATCH 05/40] fix PEP8 fix onpolicy.py --- tianshou/trainer/onpolicy.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 0bc2fdbbb..1b21b846a 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -230,7 +230,9 @@ def onpolicy_trainer_generator( logger: BaseLogger = LazyLogger(), verbose: bool = True, test_in_train: bool = True, -) -> Generator[Tuple[int, Dict[str, Union[float, str]], Dict[str, Union[float, 
str]]], None, None]: +) -> Generator[Tuple[int, + Dict[str, Union[float, str]], + Dict[str, Union[float, str]]], None, None]: """A wrapper for on-policy trainer procedure. Returns a generator that yields a 3-tuple (epoch, stats, info) of train results on every epoch. From 5566be0aaac54f9f978e3bdbb027d0dd87a482de Mon Sep 17 00:00:00 2001 From: R107333 Date: Sat, 5 Mar 2022 18:18:09 +0100 Subject: [PATCH 06/40] fix PEP8 --- test/offline/test_cql_trainer_generator.py | 2 +- tianshou/trainer/offline.py | 1 - tianshou/trainer/offpolicy.py | 2 -- tianshou/trainer/onpolicy.py | 2 -- 4 files changed, 1 insertion(+), 6 deletions(-) diff --git a/test/offline/test_cql_trainer_generator.py b/test/offline/test_cql_trainer_generator.py index 20490743f..061f6ea51 100644 --- a/test/offline/test_cql_trainer_generator.py +++ b/test/offline/test_cql_trainer_generator.py @@ -66,7 +66,7 @@ def get_args(): return args -def test_cql_trainer_generator(args=get_args()): +def test_cql(args=get_args()): if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): if args.load_buffer_name.endswith(".hdf5"): buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 6416bd754..2d3773971 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -272,4 +272,3 @@ def offline_trainer_generator( if test_collector is None and save_fn: save_fn(policy) - diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index b7ec6e79b..8889ba097 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -408,5 +408,3 @@ def offpolicy_trainer_generator( if test_collector is None and save_fn: save_fn(policy) - - diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 1b21b846a..c7d40ad18 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -440,5 +440,3 @@ def onpolicy_trainer_generator( if test_collector is None and save_fn: save_fn(policy) - - From 185c006fa1eb16ada3a3018de7643ece06e9d87b Mon Sep 17 00:00:00 2001 From: R107333 Date: Sat, 5 Mar 2022 18:30:17 +0100 Subject: [PATCH 07/40] fix yapf --- tianshou/trainer/offline.py | 26 +++++++------ tianshou/trainer/offpolicy.py | 69 ++++++++++++++++++++--------------- tianshou/trainer/onpolicy.py | 69 ++++++++++++++++++++--------------- 3 files changed, 95 insertions(+), 69 deletions(-) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 2d3773971..c887c2efa 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -151,9 +151,8 @@ def offline_trainer_generator( reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, logger: BaseLogger = LazyLogger(), verbose: bool = True, -) -> Generator[Tuple[int, - Dict[str, Union[float, str]], - Dict[str, Union[float, str]]], None, None]: +) -> Generator[Tuple[int, Dict[str, Union[float, str]], Dict[str, Union[float, str]]], + None, None]: """A wrapper for offline trainer procedure. Returns a generator that yields a 3-tuple (epoch, stats, info) of train results on every epoch. 
@@ -234,8 +233,10 @@ def offline_trainer_generator( logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn) # epoch_stat for yield clause - epoch_stat = {**{k: v.get() for k, v in stat.items()}, - "gradient_step": gradient_step} + epoch_stat = { + **{k: v.get() + for k, v in stat.items()}, "gradient_step": gradient_step + } # test if test_collector is not None: test_result = test_episode( @@ -252,12 +253,15 @@ def offline_trainer_generator( f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" ) - epoch_stat.update({"test_reward": rew, - "test_reward_std": rew_std, - "best_reward": best_reward, - "best_reward_std": best_reward_std, - "best_epoch": best_epoch - }) + epoch_stat.update( + { + "test_reward": rew, + "test_reward_std": rew_std, + "best_reward": best_reward, + "best_reward_std": best_reward_std, + "best_epoch": best_epoch + } + ) if test_collector is None: info = gather_info(start_time, None, None, 0.0, 0.0) diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 8889ba097..aab72f65b 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -213,9 +213,8 @@ def offpolicy_trainer_generator( logger: BaseLogger = LazyLogger(), verbose: bool = True, test_in_train: bool = True, -) -> Generator[Tuple[int, - Dict[str, Union[float, str]], - Dict[str, Union[float, str]]], None, None]: +) -> Generator[Tuple[int, Dict[str, Union[float, str]], Dict[str, Union[float, str]]], + None, None]: """A wrapper for off-policy trainer procedure. Returns a generator that yields a 3-tuple (epoch, stats, info) of train results on every epoch. @@ -332,15 +331,20 @@ def offpolicy_trainer_generator( ) t.set_postfix(**data) # epoch_stat for yield clause - epoch_stat = {**{k: v.get() for k, v in stat.items()}, - "gradient_step": gradient_step} - epoch_stat.update({ - "env_step": env_step, - "rew": last_rew, - "len": int(last_len), - "n/ep": int(result["n/ep"]), - "n/st": int(result["n/st"]), - }) + epoch_stat = { + **{k: v.get() + for k, v in stat.items()}, "gradient_step": + gradient_step + } + epoch_stat.update( + { + "env_step": env_step, + "rew": last_rew, + "len": int(last_len), + "n/ep": int(result["n/ep"]), + "n/st": int(result["n/st"]), + } + ) info = gather_info( start_time, train_collector, test_collector, test_result["rew"], test_result["rew_std"] @@ -362,15 +366,19 @@ def offpolicy_trainer_generator( t.update() logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) # epoch_stat for yield clause - epoch_stat = {**{k: v.get() for k, v in stat.items()}, - "gradient_step": gradient_step} - epoch_stat.update({ - "env_step": env_step, - "rew": last_rew, - "len": int(last_len), - "n/ep": int(result["n/ep"]), - "n/st": int(result["n/st"]), - }) + epoch_stat = { + **{k: v.get() + for k, v in stat.items()}, "gradient_step": gradient_step + } + epoch_stat.update( + { + "env_step": env_step, + "rew": last_rew, + "len": int(last_len), + "n/ep": int(result["n/ep"]), + "n/st": int(result["n/st"]), + } + ) # test if test_collector is not None: test_result = test_episode( @@ -387,19 +395,22 @@ def offpolicy_trainer_generator( f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" ) - epoch_stat.update({"test_reward": rew, - "test_reward_std": rew_std, - "best_reward": best_reward, - "best_reward_std": best_reward_std, - "best_epoch": best_epoch - }) + epoch_stat.update( + { + 
"test_reward": rew, + "test_reward_std": rew_std, + "best_reward": best_reward, + "best_reward_std": best_reward_std, + "best_epoch": best_epoch + } + ) if test_collector is None: info = gather_info(start_time, train_collector, None, 0.0, 0.0) else: info = gather_info( - start_time, train_collector, test_collector, - best_reward, best_reward_std + start_time, train_collector, test_collector, best_reward, + best_reward_std ) yield epoch, epoch_stat, info diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index c7d40ad18..3d2f6eb59 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -230,9 +230,8 @@ def onpolicy_trainer_generator( logger: BaseLogger = LazyLogger(), verbose: bool = True, test_in_train: bool = True, -) -> Generator[Tuple[int, - Dict[str, Union[float, str]], - Dict[str, Union[float, str]]], None, None]: +) -> Generator[Tuple[int, Dict[str, Union[float, str]], Dict[str, Union[float, str]]], + None, None]: """A wrapper for on-policy trainer procedure. Returns a generator that yields a 3-tuple (epoch, stats, info) of train results on every epoch. @@ -356,15 +355,20 @@ def onpolicy_trainer_generator( ) t.set_postfix(**data) # epoch_stat for yield clause - epoch_stat = {**{k: v.get() for k, v in stat.items()}, - "gradient_step": gradient_step} - epoch_stat.update({ - "env_step": env_step, - "rew": last_rew, - "len": int(last_len), - "n/ep": int(result["n/ep"]), - "n/st": int(result["n/st"]), - }) + epoch_stat = { + **{k: v.get() + for k, v in stat.items()}, "gradient_step": + gradient_step + } + epoch_stat.update( + { + "env_step": env_step, + "rew": last_rew, + "len": int(last_len), + "n/ep": int(result["n/ep"]), + "n/st": int(result["n/st"]), + } + ) info = gather_info( start_time, train_collector, test_collector, test_result["rew"], test_result["rew_std"] @@ -394,15 +398,19 @@ def onpolicy_trainer_generator( t.update() logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) # epoch_stat for yield clause - epoch_stat = {**{k: v.get() for k, v in stat.items()}, - "gradient_step": gradient_step} - epoch_stat.update({ - "env_step": env_step, - "rew": last_rew, - "len": int(last_len), - "n/ep": int(result["n/ep"]), - "n/st": int(result["n/st"]), - }) + epoch_stat = { + **{k: v.get() + for k, v in stat.items()}, "gradient_step": gradient_step + } + epoch_stat.update( + { + "env_step": env_step, + "rew": last_rew, + "len": int(last_len), + "n/ep": int(result["n/ep"]), + "n/st": int(result["n/st"]), + } + ) # test if test_collector is not None: test_result = test_episode( @@ -419,19 +427,22 @@ def onpolicy_trainer_generator( f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" ) - epoch_stat.update({"test_reward": rew, - "test_reward_std": rew_std, - "best_reward": best_reward, - "best_reward_std": best_reward_std, - "best_epoch": best_epoch - }) + epoch_stat.update( + { + "test_reward": rew, + "test_reward_std": rew_std, + "best_reward": best_reward, + "best_reward_std": best_reward_std, + "best_epoch": best_epoch + } + ) if test_collector is None: info = gather_info(start_time, train_collector, None, 0.0, 0.0) else: info = gather_info( - start_time, train_collector, test_collector, - best_reward, best_reward_std + start_time, train_collector, test_collector, best_reward, + best_reward_std ) yield epoch, epoch_stat, info From 79f050a2593312199cd801e99951f4cd0d4bc6cd Mon Sep 17 00:00:00 2001 From: R107333 Date: Sat, 5 Mar 2022 18:56:25 +0100 
Subject: [PATCH 08/40] removed comments in format section of Makefile. It produces errors on windows make. --- Makefile | 2 -- 1 file changed, 2 deletions(-) diff --git a/Makefile b/Makefile index da5030ccd..f876ee697 100644 --- a/Makefile +++ b/Makefile @@ -22,10 +22,8 @@ lint: flake8 ${LINT_PATHS} --count --show-source --statistics format: - # sort imports $(call check_install, isort) isort ${LINT_PATHS} - # reformat using yapf $(call check_install, yapf) yapf -ir ${LINT_PATHS} From 4cbc7c8dd190b42ce7c132aad7cbd84e698cd700 Mon Sep 17 00:00:00 2001 From: R107333 Date: Sat, 5 Mar 2022 18:57:02 +0100 Subject: [PATCH 09/40] fix isort --- tianshou/trainer/offline.py | 2 +- tianshou/trainer/offpolicy.py | 2 +- tianshou/trainer/onpolicy.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index c887c2efa..4c8bf6a11 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -1,6 +1,6 @@ import time from collections import defaultdict -from typing import Callable, Dict, Optional, Union, Generator, Tuple +from typing import Callable, Dict, Generator, Optional, Tuple, Union import numpy as np import tqdm diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index aab72f65b..3a98800f5 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -1,6 +1,6 @@ import time from collections import defaultdict -from typing import Callable, Dict, Optional, Union, Generator, Tuple +from typing import Callable, Dict, Generator, Optional, Tuple, Union import numpy as np import tqdm diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 3d2f6eb59..9923cabe4 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -1,6 +1,6 @@ import time from collections import defaultdict -from typing import Callable, Dict, Optional, Union, Generator, Tuple +from typing import Callable, Dict, Generator, Optional, Tuple, Union import numpy as np import tqdm From ffbe30a62b72f232379f894fee11bfdc9389d512 Mon Sep 17 00:00:00 2001 From: R107333 Date: Sat, 5 Mar 2022 19:57:10 +0100 Subject: [PATCH 10/40] fix rare error with dict with mypy --- tianshou/trainer/offline.py | 11 ++++------- tianshou/trainer/offpolicy.py | 18 ++++++------------ tianshou/trainer/onpolicy.py | 18 ++++++------------ 3 files changed, 16 insertions(+), 31 deletions(-) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 4c8bf6a11..e5dcc009c 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -1,6 +1,6 @@ import time from collections import defaultdict -from typing import Callable, Dict, Generator, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union import numpy as np import tqdm @@ -151,8 +151,7 @@ def offline_trainer_generator( reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, logger: BaseLogger = LazyLogger(), verbose: bool = True, -) -> Generator[Tuple[int, Dict[str, Union[float, str]], Dict[str, Union[float, str]]], - None, None]: +) -> Generator[Tuple[int, Dict[str, Any], Dict[str, Any]], None, None]: """A wrapper for offline trainer procedure. Returns a generator that yields a 3-tuple (epoch, stats, info) of train results on every epoch. 
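The switch from Dict[str, Union[float, str]] to Dict[str, Any] in the return annotation reflects what epoch_stat actually holds: smoothed loss values plus counters and test statistics. A rough sketch of one yielded epoch_stat follows (illustrative only, not a diff hunk); the values are invented, the "loss/..." key depends on the policy being trained, and the remaining keys come from the epoch_stat construction in these hunks.

    # Shape of a typical epoch_stat yielded by offline_trainer_generator.
    # Values are invented for illustration; loss keys vary by policy.
    epoch_stat = {
        "loss/actor": 0.123,
        "gradient_step": 2500,
        "test_reward": -246.3,
        "test_reward_std": 31.8,
        "best_reward": -231.9,
        "best_reward_std": 27.4,
        "best_epoch": 3,
    }
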
@@ -233,10 +232,8 @@ def offline_trainer_generator( logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn) # epoch_stat for yield clause - epoch_stat = { - **{k: v.get() - for k, v in stat.items()}, "gradient_step": gradient_step - } + epoch_stat: Dict[str, Any] = {k: v.get() for k, v in stat.items()} + epoch_stat["gradient_step"] = gradient_step # test if test_collector is not None: test_result = test_episode( diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 3a98800f5..53e47b6cb 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -1,6 +1,6 @@ import time from collections import defaultdict -from typing import Callable, Dict, Generator, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union import numpy as np import tqdm @@ -213,8 +213,7 @@ def offpolicy_trainer_generator( logger: BaseLogger = LazyLogger(), verbose: bool = True, test_in_train: bool = True, -) -> Generator[Tuple[int, Dict[str, Union[float, str]], Dict[str, Union[float, str]]], - None, None]: +) -> Generator[Tuple[int, Dict[str, Any], Dict[str, Any]], None, None]: """A wrapper for off-policy trainer procedure. Returns a generator that yields a 3-tuple (epoch, stats, info) of train results on every epoch. @@ -331,11 +330,8 @@ def offpolicy_trainer_generator( ) t.set_postfix(**data) # epoch_stat for yield clause - epoch_stat = { - **{k: v.get() - for k, v in stat.items()}, "gradient_step": - gradient_step - } + epoch_stat: Dict[str, Any] = {k: v.get() for k, v in stat.items()} + epoch_stat["gradient_step"] = gradient_step epoch_stat.update( { "env_step": env_step, @@ -366,10 +362,8 @@ def offpolicy_trainer_generator( t.update() logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) # epoch_stat for yield clause - epoch_stat = { - **{k: v.get() - for k, v in stat.items()}, "gradient_step": gradient_step - } + epoch_stat = {k: v.get() for k, v in stat.items()} + epoch_stat["gradient_step"] = gradient_step epoch_stat.update( { "env_step": env_step, diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 9923cabe4..80b4e2b72 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -1,6 +1,6 @@ import time from collections import defaultdict -from typing import Callable, Dict, Generator, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union import numpy as np import tqdm @@ -230,8 +230,7 @@ def onpolicy_trainer_generator( logger: BaseLogger = LazyLogger(), verbose: bool = True, test_in_train: bool = True, -) -> Generator[Tuple[int, Dict[str, Union[float, str]], Dict[str, Union[float, str]]], - None, None]: +) -> Generator[Tuple[int, Dict[str, Any], Dict[str, Any]], None, None]: """A wrapper for on-policy trainer procedure. Returns a generator that yields a 3-tuple (epoch, stats, info) of train results on every epoch. 
@@ -355,11 +354,8 @@ def onpolicy_trainer_generator( ) t.set_postfix(**data) # epoch_stat for yield clause - epoch_stat = { - **{k: v.get() - for k, v in stat.items()}, "gradient_step": - gradient_step - } + epoch_stat: Dict[str, Any] = {k: v.get() for k, v in stat.items()} + epoch_stat["gradient_step"] = gradient_step epoch_stat.update( { "env_step": env_step, @@ -398,10 +394,8 @@ def onpolicy_trainer_generator( t.update() logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) # epoch_stat for yield clause - epoch_stat = { - **{k: v.get() - for k, v in stat.items()}, "gradient_step": gradient_step - } + epoch_stat = {k: v.get() for k, v in stat.items()} + epoch_stat["gradient_step"] = gradient_step epoch_stat.update( { "env_step": env_step, From 23f00d224b9698210a884690f1ac8d9a542a37ab Mon Sep 17 00:00:00 2001 From: R107333 Date: Sat, 5 Mar 2022 20:06:10 +0100 Subject: [PATCH 11/40] fix rare error with dict with mypy --- tianshou/trainer/offpolicy.py | 5 ++++- tianshou/trainer/onpolicy.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 53e47b6cb..e76e26cca 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -330,7 +330,10 @@ def offpolicy_trainer_generator( ) t.set_postfix(**data) # epoch_stat for yield clause - epoch_stat: Dict[str, Any] = {k: v.get() for k, v in stat.items()} + epoch_stat: Dict[str, Any] = { + k: v.get() + for k, v in stat.items() + } epoch_stat["gradient_step"] = gradient_step epoch_stat.update( { diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 80b4e2b72..16197e077 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -354,7 +354,10 @@ def onpolicy_trainer_generator( ) t.set_postfix(**data) # epoch_stat for yield clause - epoch_stat: Dict[str, Any] = {k: v.get() for k, v in stat.items()} + epoch_stat: Dict[str, Any] = { + k: v.get() + for k, v in stat.items() + } epoch_stat["gradient_step"] = gradient_step epoch_stat.update( { From f64eb2dad6a7a0d573630680fb1f8b0ecf3c295c Mon Sep 17 00:00:00 2001 From: R107333 Date: Sat, 5 Mar 2022 20:12:44 +0100 Subject: [PATCH 12/40] fix docstrings --- tianshou/trainer/offline.py | 3 ++- tianshou/trainer/offpolicy.py | 3 ++- tianshou/trainer/onpolicy.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index e5dcc009c..226f40e3f 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -152,7 +152,8 @@ def offline_trainer_generator( logger: BaseLogger = LazyLogger(), verbose: bool = True, ) -> Generator[Tuple[int, Dict[str, Any], Dict[str, Any]], None, None]: - """A wrapper for offline trainer procedure. + """A generator wrapper for offline trainer procedure. + Returns a generator that yields a 3-tuple (epoch, stats, info) of train results on every epoch. diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index e76e26cca..39c92a782 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -214,7 +214,8 @@ def offpolicy_trainer_generator( verbose: bool = True, test_in_train: bool = True, ) -> Generator[Tuple[int, Dict[str, Any], Dict[str, Any]], None, None]: - """A wrapper for off-policy trainer procedure. + """A generator wrapper for off-policy trainer procedure. + Returns a generator that yields a 3-tuple (epoch, stats, info) of train results on every epoch. 
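As a concrete illustration of the reward_metric hook these trainer docstrings describe (mapping rewards of shape (num_episode, agent_num) to one scalar per episode for multi-agent RL), the two metrics the docstrings mention could be written as below; this is illustrative only, not part of the patch, and the function names are hypothetical.

    import numpy as np

    # Example reward_metric hooks matching the documented contract:
    # f(rewards with shape (num_episode, agent_num)) -> shape (num_episode,)
    def agent_one_reward(rewards: np.ndarray) -> np.ndarray:
        return rewards[:, 0]           # reward of agent 1

    def mean_agent_reward(rewards: np.ndarray) -> np.ndarray:
        return rewards.mean(axis=1)    # average reward over all agents
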
diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 16197e077..c7141d84b 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -231,7 +231,8 @@ def onpolicy_trainer_generator( verbose: bool = True, test_in_train: bool = True, ) -> Generator[Tuple[int, Dict[str, Any], Dict[str, Any]], None, None]: - """A wrapper for on-policy trainer procedure. + """A generator wrapper for on-policy trainer procedure. + Returns a generator that yields a 3-tuple (epoch, stats, info) of train results on every epoch. From b6b0ed7bba62c702492ae5c3f7b85201f8977328 Mon Sep 17 00:00:00 2001 From: R107333 Date: Sun, 6 Mar 2022 02:21:59 +0100 Subject: [PATCH 13/40] refactored offline.py to one iterator class --- test/offline/test_cql.py | 24 +- test/offline/test_cql_trainer_generator.py | 226 ----------- tianshou/trainer/__init__.py | 4 +- tianshou/trainer/offline.py | 418 +++++++++------------ 4 files changed, 206 insertions(+), 466 deletions(-) delete mode 100644 test/offline/test_cql_trainer_generator.py diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index ce780ac0f..79ac55c58 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import SubprocVectorEnv from tianshou.policy import CQLPolicy -from tianshou.trainer import offline_trainer +from tianshou.trainer import offline_trainer, offline_trainer_iter from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -204,6 +204,28 @@ def watch(): ) assert stop_fn(result['best_reward']) + # trainer + trainer = offline_trainer_iter( + policy, + buffer, + test_collector, + args.epoch, + args.step_per_epoch, + args.test_num, + args.batch_size, + save_fn=save_fn, + stop_fn=stop_fn, + logger=logger, + ) + + for epoch, epoch_stat, info in trainer: + print(f"Epoch: {epoch}") + print(epoch_stat) + print(info) + + result_iter = info + assert stop_fn(result_iter['best_reward']) + # Let's watch its performance! 
if __name__ == '__main__': pprint.pprint(result) diff --git a/test/offline/test_cql_trainer_generator.py b/test/offline/test_cql_trainer_generator.py deleted file mode 100644 index 061f6ea51..000000000 --- a/test/offline/test_cql_trainer_generator.py +++ /dev/null @@ -1,226 +0,0 @@ -import argparse -import datetime -import os -import pickle -import pprint - -import gym -import numpy as np -import torch -from torch.utils.tensorboard import SummaryWriter - -from tianshou.data import Collector, VectorReplayBuffer -from tianshou.env import SubprocVectorEnv -from tianshou.policy import CQLPolicy -from tianshou.trainer import offline_trainer_generator -from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic - -if __name__ == "__main__": - from gather_pendulum_data import expert_file_name, gather_data -else: # pytest - from test.offline.gather_pendulum_data import expert_file_name, gather_data - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='Pendulum-v0') - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) - parser.add_argument('--actor-lr', type=float, default=1e-3) - parser.add_argument('--critic-lr', type=float, default=1e-3) - parser.add_argument('--alpha', type=float, default=0.2) - parser.add_argument('--auto-alpha', default=True, action='store_true') - parser.add_argument('--alpha-lr', type=float, default=1e-3) - parser.add_argument('--cql-alpha-lr', type=float, default=1e-3) - parser.add_argument("--start-timesteps", type=int, default=10000) - parser.add_argument('--epoch', type=int, default=5) - parser.add_argument('--step-per-epoch', type=int, default=500) - parser.add_argument('--n-step', type=int, default=3) - parser.add_argument('--batch-size', type=int, default=64) - - parser.add_argument("--tau", type=float, default=0.005) - parser.add_argument("--temperature", type=float, default=1.0) - parser.add_argument("--cql-weight", type=float, default=1.0) - parser.add_argument("--with-lagrange", type=bool, default=True) - parser.add_argument("--lagrange-threshold", type=float, default=10.0) - parser.add_argument("--gamma", type=float, default=0.99) - - parser.add_argument("--eval-freq", type=int, default=1) - parser.add_argument('--test-num', type=int, default=10) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=1 / 35) - parser.add_argument( - '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' - ) - parser.add_argument('--resume-path', type=str, default=None) - parser.add_argument( - '--watch', - default=False, - action='store_true', - help='watch the play of pre-trained policy only', - ) - parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) - args = parser.parse_known_args()[0] - return args - - -def test_cql(args=get_args()): - if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): - if args.load_buffer_name.endswith(".hdf5"): - buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) - else: - buffer = pickle.load(open(args.load_buffer_name, "rb")) - else: - buffer = gather_data() - env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] # float - if 
args.task == 'Pendulum-v0': - env.spec.reward_threshold = -1200 # too low? - - args.state_dim = args.state_shape[0] - args.action_dim = args.action_shape[0] - # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)] - ) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) - test_envs.seed(args.seed) - - # model - # actor network - net_a = Net( - args.state_shape, - args.action_shape, - hidden_sizes=args.hidden_sizes, - device=args.device, - ) - actor = ActorProb( - net_a, - action_shape=args.action_shape, - max_action=args.max_action, - device=args.device, - unbounded=True, - conditioned_sigma=True, - ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - - # critic network - net_c1 = Net( - args.state_shape, - args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, - device=args.device, - ) - net_c2 = Net( - args.state_shape, - args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, - device=args.device, - ) - critic1 = Critic(net_c1, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic(net_c2, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) - - if args.auto_alpha: - target_entropy = -np.prod(env.action_space.shape) - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) - - policy = CQLPolicy( - actor, - actor_optim, - critic1, - critic1_optim, - critic2, - critic2_optim, - cql_alpha_lr=args.cql_alpha_lr, - cql_weight=args.cql_weight, - tau=args.tau, - gamma=args.gamma, - alpha=args.alpha, - temperature=args.temperature, - with_lagrange=args.with_lagrange, - lagrange_threshold=args.lagrange_threshold, - min_action=np.min(env.action_space.low), - max_action=np.max(env.action_space.high), - device=args.device, - ) - - # load a previous policy - if args.resume_path: - policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) - print("Loaded agent from: ", args.resume_path) - - # collector - # buffer has been gathered - # train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs) - # log - t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") - log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_cql' - log_path = os.path.join(args.logdir, args.task, 'cql', log_file) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - logger = TensorboardLogger(writer) - - def save_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) - - def stop_fn(mean_rewards): - return mean_rewards >= env.spec.reward_threshold - - def watch(): - policy.load_state_dict( - torch.load( - os.path.join(log_path, 'policy.pth'), map_location=torch.device('cpu') - ) - ) - policy.eval() - collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35) - - # trainer - trainer = offline_trainer_generator( - policy, - buffer, - test_collector, - args.epoch, - args.step_per_epoch, - args.test_num, - args.batch_size, - save_fn=save_fn, - stop_fn=stop_fn, - logger=logger, - ) - print(trainer) - for epoch, epoch_stat, info in trainer: - print(f"Epoch: {epoch}") - print(epoch_stat) - print(info) - - result = info - assert stop_fn(result['best_reward']) - 
- # Let's watch its performance! - if __name__ == '__main__': - pprint.pprint(result) - env = gym.make(args.task) - policy.eval() - collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") - - -if __name__ == '__main__': - test_cql() diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index 4f542b865..01599bbc5 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -5,7 +5,7 @@ from tianshou.trainer.utils import test_episode, gather_info from tianshou.trainer.onpolicy import onpolicy_trainer, onpolicy_trainer_generator from tianshou.trainer.offpolicy import offpolicy_trainer, offpolicy_trainer_generator -from tianshou.trainer.offline import offline_trainer, offline_trainer_generator +from tianshou.trainer.offline import offline_trainer, offline_trainer_iter __all__ = [ "offpolicy_trainer", @@ -13,7 +13,7 @@ "onpolicy_trainer", "onpolicy_trainer_generator", "offline_trainer", - "offline_trainer_generator", + "offline_trainer_iter", "test_episode", "gather_info", ] diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 226f40e3f..d388e2ad1 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -1,6 +1,6 @@ import time from collections import defaultdict -from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import numpy as np import tqdm @@ -11,266 +11,210 @@ from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config -def offline_trainer( - policy: BasePolicy, - buffer: ReplayBuffer, - test_collector: Optional[Collector], - max_epoch: int, - update_per_epoch: int, - episode_per_test: int, - batch_size: int, - test_fn: Optional[Callable[[int, Optional[int]], None]] = None, - stop_fn: Optional[Callable[[float], bool]] = None, - save_fn: Optional[Callable[[BasePolicy], None]] = None, - save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, - resume_from_log: bool = False, - reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, - logger: BaseLogger = LazyLogger(), - verbose: bool = True, -) -> Dict[str, Union[float, str]]: - """A wrapper for offline trainer procedure. +class OffLineTrainer: + + def __init__( + self, + policy: BasePolicy, + buffer: ReplayBuffer, + test_collector: Optional[Collector], + max_epoch: int, + update_per_epoch: int, + episode_per_test: int, + batch_size: int, + test_fn: Optional[Callable[[int, Optional[int]], None]] = None, + stop_fn: Optional[Callable[[float], bool]] = None, + save_fn: Optional[Callable[[BasePolicy], None]] = None, + save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + resume_from_log: bool = False, + reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, + logger: BaseLogger = LazyLogger(), + verbose: bool = True, + ): + """A generator wrapper for offline trainer procedure. + + Returns a generator that yields a 3-tuple (epoch, stats, info) of train results + on every epoch. + + The "step" in offline trainer means a gradient step. + + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. + This buffer must be populated with experiences for offline RL. + :param Collector test_collector: the collector used for testing. 
If it's None, + then no testing will be performed. + :param int max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is + set. + :param int update_per_epoch: the number of policy network updates, so-called + gradient steps, per epoch. + :param episode_per_test: the number of episodes for one policy evaluation. + :param int batch_size: the batch size of sample data, which is going to feed in + the policy network. + :param function test_fn: a hook called at the beginning of testing in each + epoch. + It can be used to perform custom additional operations, with the signature + ``f(num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean + reward in evaluation phase gets better, with the signature + ``f(policy: BasePolicy) -> None``. + :param function save_checkpoint_fn: a function to save training process, + with the signature ``f(epoch: int, env_step: int, + gradient_step: int) -> None``; you can save whatever you want. Because + offline-RL doesn't have env_step, the env_step is always 0 here. + :param bool resume_from_log: resume gradient_step and other metadata from + existing tensorboard log. Default to False. + :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether reaching the goal. + :param function reward_metric: a function with signature ``f(rewards: + np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape + (num_episode,)``, used in multi-agent RL. We need to return a single scalar + for each episode's result to monitor training in the multi-agent RL + setting. This function specifies what is the desired metric, e.g., the + reward of agent 1 or the average reward over all agents. + :param BaseLogger logger: A logger that logs statistics during + updating/testing. Default to a logger that doesn't log anything. + :param bool verbose: whether to print the information. Default to True. + + :return: See :func:`~tianshou.trainer.gather_info`. + """ + self.is_run = False + self.policy = policy + self.buffer = buffer + self.test_collector = test_collector + self.max_epoch = max_epoch + self.update_per_epoch = update_per_epoch + self.episode_per_test = episode_per_test + self.batch_size = batch_size + self.test_fn = test_fn + self.stop_fn = stop_fn + self.save_fn = save_fn + self.save_checkpoint_fn = save_checkpoint_fn + + self.reward_metric = reward_metric + self.logger = logger + self.verbose = verbose + + self.start_epoch, self.gradient_step = 0, 0 + if resume_from_log: + self.start_epoch, _, self.gradient_step = logger.restore_data() + self.stat: Dict[str, MovAvg] = defaultdict(MovAvg) + self.start_time = time.time() - The "step" in offline trainer means a gradient step. - - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. - This buffer must be populated with experiences for offline RL. - :param Collector test_collector: the collector used for testing. If it's None, then - no testing will be performed. - :param int max_epoch: the maximum number of epochs for training. The training - process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. - :param int update_per_epoch: the number of policy network updates, so-called - gradient steps, per epoch. 
- :param episode_per_test: the number of episodes for one policy evaluation. - :param int batch_size: the batch size of sample data, which is going to feed in - the policy network. - :param function test_fn: a hook called at the beginning of testing in each epoch. - It can be used to perform custom additional operations, with the signature ``f( - num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean reward in - evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> - None``. - :param function save_checkpoint_fn: a function to save training process, with the - signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can - save whatever you want. Because offline-RL doesn't have env_step, the env_step - is always 0 here. - :param bool resume_from_log: resume gradient_step and other metadata from existing - tensorboard log. Default to False. - :param function stop_fn: a function with signature ``f(mean_rewards: float) -> - bool``, receives the average undiscounted returns of the testing result, - returns a boolean which indicates whether reaching the goal. - :param function reward_metric: a function with signature ``f(rewards: np.ndarray - with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, - used in multi-agent RL. We need to return a single scalar for each episode's - result to monitor training in the multi-agent RL setting. This function - specifies what is the desired metric, e.g., the reward of agent 1 or the - average reward over all agents. - :param BaseLogger logger: A logger that logs statistics during updating/testing. - Default to a logger that doesn't log anything. - :param bool verbose: whether to print the information. Default to True. - - :return: See :func:`~tianshou.trainer.gather_info`. 
- """ - start_epoch, gradient_step = 0, 0 - if resume_from_log: - start_epoch, _, gradient_step = logger.restore_data() - stat: Dict[str, MovAvg] = defaultdict(MovAvg) - start_time = time.time() - - if test_collector is not None: - test_c: Collector = test_collector - test_collector.reset_stat() - test_result = test_episode( - policy, test_c, test_fn, start_epoch, episode_per_test, logger, - gradient_step, reward_metric - ) - best_epoch = start_epoch - best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] - if save_fn: - save_fn(policy) - - for epoch in range(1 + start_epoch, 1 + max_epoch): - policy.train() - - with tqdm.trange(update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config) as t: - for _ in t: - gradient_step += 1 - losses = policy.update(batch_size, buffer) - data = {"gradient_step": str(gradient_step)} - for k in losses.keys(): - stat[k].add(losses[k]) - losses[k] = stat[k].get() - data[k] = f"{losses[k]:.3f}" - logger.log_update_data(losses, gradient_step) - t.set_postfix(**data) - - logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn) - # test if test_collector is not None: + self.test_c: Collector = test_collector + test_collector.reset_stat() test_result = test_episode( - policy, test_c, test_fn, epoch, episode_per_test, logger, - gradient_step, reward_metric + self.policy, self.test_c, test_fn, self.start_epoch, + self.episode_per_test, self.logger, self.gradient_step, + self.reward_metric ) - rew, rew_std = test_result["rew"], test_result["rew_std"] - if best_epoch < 0 or best_reward < rew: - best_epoch, best_reward, best_reward_std = epoch, rew, rew_std - if save_fn: - save_fn(policy) - if verbose: - print( - f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" - f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" - ) - if stop_fn and stop_fn(best_reward): - break + self.best_epoch = self.start_epoch + self.best_reward, self.best_reward_std = test_result["rew"], test_result[ + "rew_std"] - if test_collector is None and save_fn: - save_fn(policy) + if self.save_fn: + self.save_fn(policy) - if test_collector is None: - return gather_info(start_time, None, None, 0.0, 0.0) - else: - return gather_info( - start_time, None, test_collector, best_reward, best_reward_std - ) + self.epoch = self.start_epoch + def __iter__(self): # type: ignore + return self -def offline_trainer_generator( - policy: BasePolicy, - buffer: ReplayBuffer, - test_collector: Optional[Collector], - max_epoch: int, - update_per_epoch: int, - episode_per_test: int, - batch_size: int, - test_fn: Optional[Callable[[int, Optional[int]], None]] = None, - stop_fn: Optional[Callable[[float], bool]] = None, - save_fn: Optional[Callable[[BasePolicy], None]] = None, - save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, - resume_from_log: bool = False, - reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, - logger: BaseLogger = LazyLogger(), - verbose: bool = True, -) -> Generator[Tuple[int, Dict[str, Any], Dict[str, Any]], None, None]: - """A generator wrapper for offline trainer procedure. + def __next__(self) -> Tuple[int, Dict[str, Any], Dict[str, Any]]: + self.epoch += 1 + if self.epoch >= self.max_epoch: + if self.test_collector is None and self.save_fn: + self.save_fn(self.policy) + raise StopIteration - Returns a generator that yields a 3-tuple (epoch, stats, info) of train results - on every epoch. + self.policy.train() - The "step" in offline trainer means a gradient step. 
- - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. - This buffer must be populated with experiences for offline RL. - :param Collector test_collector: the collector used for testing. If it's None, then - no testing will be performed. - :param int max_epoch: the maximum number of epochs for training. The training - process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. - :param int update_per_epoch: the number of policy network updates, so-called - gradient steps, per epoch. - :param episode_per_test: the number of episodes for one policy evaluation. - :param int batch_size: the batch size of sample data, which is going to feed in - the policy network. - :param function test_fn: a hook called at the beginning of testing in each epoch. - It can be used to perform custom additional operations, with the signature ``f( - num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean reward in - evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> - None``. - :param function save_checkpoint_fn: a function to save training process, with the - signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can - save whatever you want. Because offline-RL doesn't have env_step, the env_step - is always 0 here. - :param bool resume_from_log: resume gradient_step and other metadata from existing - tensorboard log. Default to False. - :param function stop_fn: a function with signature ``f(mean_rewards: float) -> - bool``, receives the average undiscounted returns of the testing result, - returns a boolean which indicates whether reaching the goal. - :param function reward_metric: a function with signature ``f(rewards: np.ndarray - with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, - used in multi-agent RL. We need to return a single scalar for each episode's - result to monitor training in the multi-agent RL setting. This function - specifies what is the desired metric, e.g., the reward of agent 1 or the - average reward over all agents. - :param BaseLogger logger: A logger that logs statistics during updating/testing. - Default to a logger that doesn't log anything. - :param bool verbose: whether to print the information. Default to True. - - :return: See :func:`~tianshou.trainer.gather_info`. 
- """ - start_epoch, gradient_step = 0, 0 - if resume_from_log: - start_epoch, _, gradient_step = logger.restore_data() - stat: Dict[str, MovAvg] = defaultdict(MovAvg) - start_time = time.time() - - if test_collector is not None: - test_c: Collector = test_collector - test_collector.reset_stat() - test_result = test_episode( - policy, test_c, test_fn, start_epoch, episode_per_test, logger, - gradient_step, reward_metric - ) - best_epoch = start_epoch - best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] - if save_fn: - save_fn(policy) - - for epoch in range(1 + start_epoch, 1 + max_epoch): - policy.train() - - with tqdm.trange(update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config) as t: + with tqdm.trange( + self.update_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config + ) as t: for _ in t: - gradient_step += 1 - losses = policy.update(batch_size, buffer) - data = {"gradient_step": str(gradient_step)} + self.gradient_step += 1 + losses = self.policy.update(self.batch_size, self.buffer) + data = {"gradient_step": str(self.gradient_step)} for k in losses.keys(): - stat[k].add(losses[k]) - losses[k] = stat[k].get() + self.stat[k].add(losses[k]) + losses[k] = self.stat[k].get() data[k] = f"{losses[k]:.3f}" - logger.log_update_data(losses, gradient_step) + self.logger.log_update_data(losses, self.gradient_step) t.set_postfix(**data) - logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn) - # epoch_stat for yield clause - epoch_stat: Dict[str, Any] = {k: v.get() for k, v in stat.items()} - epoch_stat["gradient_step"] = gradient_step + self.logger.save_data( + self.epoch, 0, self.gradient_step, self.save_checkpoint_fn + ) + if not self.is_run: + # epoch_stat for yield clause + epoch_stat: Dict[str, Any] = {k: v.get() for k, v in self.stat.items()} + epoch_stat["gradient_step"] = self.gradient_step # test - if test_collector is not None: + if self.test_collector is not None: test_result = test_episode( - policy, test_c, test_fn, epoch, episode_per_test, logger, - gradient_step, reward_metric + self.policy, self.test_c, self.test_fn, self.epoch, + self.episode_per_test, self.logger, self.gradient_step, + self.reward_metric ) rew, rew_std = test_result["rew"], test_result["rew_std"] - if best_epoch < 0 or best_reward < rew: - best_epoch, best_reward, best_reward_std = epoch, rew, rew_std - if save_fn: - save_fn(policy) - if verbose: + if self.best_epoch < 0 or self.best_reward < rew: + self.best_epoch = self.epoch + self.best_reward = rew + self.best_reward_std = rew_std + if self.save_fn: + self.save_fn(self.policy) + if self.verbose: print( - f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" - f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}," + f" best_reward: {self.best_reward:.6f} ± " + f"{self.best_reward_std:.6f} in #{self.best_epoch}" ) - epoch_stat.update( - { - "test_reward": rew, - "test_reward_std": rew_std, - "best_reward": best_reward, - "best_reward_std": best_reward_std, - "best_epoch": best_epoch - } - ) + if not self.is_run: + epoch_stat.update( + { + "test_reward": rew, + "test_reward_std": rew_std, + "best_reward": self.best_reward, + "best_reward_std": self.best_reward_std, + "best_epoch": self.best_epoch + } + ) + info = gather_info( + self.start_time, None, self.test_collector, self.best_reward, + self.best_reward_std + ) + return self.epoch, epoch_stat, info + else: + return 0, {}, {} - if test_collector is None: - info = 
gather_info(start_time, None, None, 0.0, 0.0) + else: + if not self.is_run: + info = gather_info(self.start_time, None, None, 0.0, 0.0) + return self.epoch, epoch_stat, info + else: + return 0, {}, {} + + def run(self) -> Dict[str, Union[float, str]]: + self.is_run = True + for _ in iter(self): + pass + + if self.test_collector is None: + info = gather_info(self.start_time, None, None, 0.0, 0.0) else: info = gather_info( - start_time, None, test_collector, best_reward, best_reward_std + self.start_time, None, self.test_collector, self.best_reward, + self.best_reward_std ) - yield epoch, epoch_stat, info - if test_collector is not None and stop_fn and stop_fn(best_reward): - break + return info + + +def offline_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore + return OffLineTrainer(*args, **kwargs).run() + - if test_collector is None and save_fn: - save_fn(policy) +offline_trainer_iter = OffLineTrainer From 0f39eac750cced0be855c7fe0855c2bf1ba86273 Mon Sep 17 00:00:00 2001 From: R107333 Date: Sun, 6 Mar 2022 20:13:12 +0100 Subject: [PATCH 14/40] drop test_sac_with_il_trainer_generator.py --- .../test_sac_with_il_trainer_generator.py | 243 ------------------ 1 file changed, 243 deletions(-) delete mode 100644 test/continuous/test_sac_with_il_trainer_generator.py diff --git a/test/continuous/test_sac_with_il_trainer_generator.py b/test/continuous/test_sac_with_il_trainer_generator.py deleted file mode 100644 index 465ec9981..000000000 --- a/test/continuous/test_sac_with_il_trainer_generator.py +++ /dev/null @@ -1,243 +0,0 @@ -import argparse -import os -import pprint - -import gym -import numpy as np -import torch -from torch.utils.tensorboard import SummaryWriter - -from tianshou.data import Collector, VectorReplayBuffer -from tianshou.env import DummyVectorEnv -from tianshou.policy import ImitationPolicy, SACPolicy -from tianshou.trainer import offpolicy_trainer_generator -from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import Actor, ActorProb, Critic - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='Pendulum-v0') - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--actor-lr', type=float, default=1e-3) - parser.add_argument('--critic-lr', type=float, default=1e-3) - parser.add_argument('--il-lr', type=float, default=1e-3) - parser.add_argument('--gamma', type=float, default=0.99) - parser.add_argument('--tau', type=float, default=0.005) - parser.add_argument('--alpha', type=float, default=0.2) - parser.add_argument('--auto-alpha', type=int, default=1) - parser.add_argument('--alpha-lr', type=float, default=3e-4) - parser.add_argument('--epoch', type=int, default=5) - parser.add_argument('--step-per-epoch', type=int, default=24000) - parser.add_argument('--il-step-per-epoch', type=int, default=500) - parser.add_argument('--step-per-collect', type=int, default=10) - parser.add_argument('--update-per-step', type=float, default=0.1) - parser.add_argument('--batch-size', type=int, default=128) - parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument( - '--imitation-hidden-sizes', type=int, nargs='*', default=[128, 128] - ) - parser.add_argument('--training-num', type=int, default=10) - parser.add_argument('--test-num', type=int, default=100) - parser.add_argument('--logdir', type=str, default='log') - 
parser.add_argument('--render', type=float, default=0.) - parser.add_argument('--rew-norm', action="store_true", default=False) - parser.add_argument('--n-step', type=int, default=3) - parser.add_argument( - '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' - ) - args = parser.parse_known_args()[0] - return args - - -def test_sac_with_il(args=get_args()): - torch.set_num_threads(1) # we just need only one thread for NN - env = gym.make(args.task) - if args.task == 'Pendulum-v0': - env.spec.reward_threshold = -250 - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] - # you can also use tianshou.env.SubprocVectorEnv - # train_envs = gym.make(args.task) - train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)] - ) - # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)] - ) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) - train_envs.seed(args.seed) - test_envs.seed(args.seed) - # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb( - net, - args.action_shape, - max_action=args.max_action, - device=args.device, - unbounded=True - ).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net_c1 = Net( - args.state_shape, - args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, - device=args.device - ) - critic1 = Critic(net_c1, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - net_c2 = Net( - args.state_shape, - args.action_shape, - hidden_sizes=args.hidden_sizes, - concat=True, - device=args.device - ) - critic2 = Critic(net_c2, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) - - if args.auto_alpha: - target_entropy = -np.prod(env.action_space.shape) - log_alpha = torch.zeros(1, requires_grad=True, device=args.device) - alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - args.alpha = (target_entropy, log_alpha, alpha_optim) - - policy = SACPolicy( - actor, - actor_optim, - critic1, - critic1_optim, - critic2, - critic2_optim, - tau=args.tau, - gamma=args.gamma, - alpha=args.alpha, - reward_normalization=args.rew_norm, - estimation_step=args.n_step, - action_space=env.action_space - ) - # collector - train_collector = Collector( - policy, - train_envs, - VectorReplayBuffer(args.buffer_size, len(train_envs)), - exploration_noise=True - ) - test_collector = Collector(policy, test_envs) - # train_collector.collect(n_step=args.buffer_size) - # log - log_path = os.path.join(args.logdir, args.task, 'sac') - writer = SummaryWriter(log_path) - logger = TensorboardLogger(writer) - - def save_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) - - def stop_fn(mean_rewards): - return mean_rewards >= env.spec.reward_threshold - - # trainer - trainer = offpolicy_trainer_generator( - policy, - train_collector, - test_collector, - args.epoch, - args.step_per_epoch, - args.step_per_collect, - args.test_num, - args.batch_size, - update_per_step=args.update_per_step, - stop_fn=stop_fn, - save_fn=save_fn, - logger=logger - ) - print(trainer) - - for epoch, epoch_stat, info in trainer: - print(f"Epoch: {epoch}") - print(epoch_stat) - print(info) - - result 
= info - assert stop_fn(result['best_reward']) - - if __name__ == '__main__': - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - policy.eval() - collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") - - # here we define an imitation collector with a trivial policy - policy.eval() - if args.task == 'Pendulum-v0': - env.spec.reward_threshold = -300 # lower the goal - net = Actor( - Net( - args.state_shape, - hidden_sizes=args.imitation_hidden_sizes, - device=args.device - ), - args.action_shape, - max_action=args.max_action, - device=args.device - ).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) - il_policy = ImitationPolicy( - net, - optim, - action_space=env.action_space, - action_scaling=True, - action_bound_method="clip" - ) - il_test_collector = Collector( - il_policy, - DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) - ) - train_collector.reset() - trainer = offpolicy_trainer_generator( - il_policy, - train_collector, - il_test_collector, - args.epoch, - args.il_step_per_epoch, - args.step_per_collect, - args.test_num, - args.batch_size, - stop_fn=stop_fn, - save_fn=save_fn, - logger=logger - ) - print(trainer) - - for epoch, epoch_stat, info in trainer: - print(f"Epoch: {epoch}") - print(epoch_stat) - print(info) - - result = info - assert stop_fn(result['best_reward']) - - if __name__ == '__main__': - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - il_policy.eval() - collector = Collector(il_policy, env) - result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") - - -if __name__ == '__main__': - test_sac_with_il() From 21cdbe677cc196020db092a4a5878d7439ea28c6 Mon Sep 17 00:00:00 2001 From: R107333 Date: Sun, 6 Mar 2022 20:14:26 +0100 Subject: [PATCH 15/40] improve offline.py with best practices on exhausting iterator and cleaner less code keeping the same functionality --- tianshou/trainer/offline.py | 67 ++++++++++++++++++++----------------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index d388e2ad1..83630fef8 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -1,5 +1,5 @@ import time -from collections import defaultdict +from collections import defaultdict, deque from typing import Any, Callable, Dict, Optional, Tuple, Union import numpy as np @@ -12,6 +12,15 @@ class OffLineTrainer: + """ + An iterator wrapper for offline training procedure. + + Returns an iterator that yields a 3 tuple (epoch, stats, info) of train results + on every epoch. + + The "step" in offline trainer means a gradient step. + + """ def __init__( self, @@ -31,13 +40,7 @@ def __init__( logger: BaseLogger = LazyLogger(), verbose: bool = True, ): - """A generator wrapper for offline trainer procedure. - - Returns a generator that yields a 3-tuple (epoch, stats, info) of train results - on every epoch. - - The "step" in offline trainer means a gradient step. - + """ :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. This buffer must be populated with experiences for offline RL. 
@@ -76,8 +79,6 @@ def __init__( :param BaseLogger logger: A logger that logs statistics during updating/testing. Default to a logger that doesn't log anything. :param bool verbose: whether to print the information. Default to True. - - :return: See :func:`~tianshou.trainer.gather_info`. """ self.is_run = False self.policy = policy @@ -97,6 +98,8 @@ def __init__( self.verbose = verbose self.start_epoch, self.gradient_step = 0, 0 + self.best_reward, self.best_reward_std = 0.0, 0.0 + if resume_from_log: self.start_epoch, _, self.gradient_step = logger.restore_data() self.stat: Dict[str, MovAvg] = defaultdict(MovAvg) @@ -124,13 +127,17 @@ def __iter__(self): # type: ignore def __next__(self) -> Tuple[int, Dict[str, Any], Dict[str, Any]]: self.epoch += 1 + + # iterator exhaustion check if self.epoch >= self.max_epoch: if self.test_collector is None and self.save_fn: self.save_fn(self.policy) raise StopIteration + # set policy in train mode self.policy.train() + # Performs n update_per_epoch with tqdm.trange( self.update_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config ) as t: @@ -148,10 +155,11 @@ def __next__(self) -> Tuple[int, Dict[str, Any], Dict[str, Any]]: self.logger.save_data( self.epoch, 0, self.gradient_step, self.save_checkpoint_fn ) + if not self.is_run: - # epoch_stat for yield clause epoch_stat: Dict[str, Any] = {k: v.get() for k, v in self.stat.items()} epoch_stat["gradient_step"] = self.gradient_step + # test if self.test_collector is not None: test_result = test_episode( @@ -182,33 +190,32 @@ def __next__(self) -> Tuple[int, Dict[str, Any], Dict[str, Any]]: "best_epoch": self.best_epoch } ) - info = gather_info( - self.start_time, None, self.test_collector, self.best_reward, - self.best_reward_std - ) - return self.epoch, epoch_stat, info - else: - return 0, {}, {} + # return iterator -> next(self) + if not self.is_run: + info = gather_info( + self.start_time, None, self.test_collector, self.best_reward, + self.best_reward_std + ) + return self.epoch, epoch_stat, info else: - if not self.is_run: - info = gather_info(self.start_time, None, None, 0.0, 0.0) - return self.epoch, epoch_stat, info - else: - return 0, {}, {} + return 0, {}, {} def run(self) -> Dict[str, Union[float, str]]: - self.is_run = True - for _ in iter(self): - pass - - if self.test_collector is None: - info = gather_info(self.start_time, None, None, 0.0, 0.0) - else: + """ + Consume iterator, see itertools-recipes. Use functions that consume + iterators at C speed (feed the entire iterator into a zero-length deque). 
+ """ + try: + self.is_run = True + i = iter(self) + deque(i, maxlen=0) # feed the entire iterator into a zero-length deque info = gather_info( self.start_time, None, self.test_collector, self.best_reward, self.best_reward_std ) + finally: + self.is_run = False return info From 2483dea4aa08db13d985f1f3a5f6b4c1819e14e3 Mon Sep 17 00:00:00 2001 From: R107333 Date: Sun, 6 Mar 2022 20:15:30 +0100 Subject: [PATCH 16/40] Create an Iterator class instead of a generator function, following the sketch in offline.py --- tianshou/trainer/offpolicy.py | 648 +++++++++++++++------------------- 1 file changed, 282 insertions(+), 366 deletions(-) diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 39c92a782..528d23e43 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -1,6 +1,6 @@ import time -from collections import defaultdict -from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union +from collections import defaultdict, deque +from typing import Any, Callable, Dict, Optional, Tuple, Union import numpy as np import tqdm @@ -11,409 +11,325 @@ from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config -def offpolicy_trainer( - policy: BasePolicy, - train_collector: Collector, - test_collector: Optional[Collector], - max_epoch: int, - step_per_epoch: int, - step_per_collect: int, - episode_per_test: int, - batch_size: int, - update_per_step: Union[int, float] = 1, - train_fn: Optional[Callable[[int, int], None]] = None, - test_fn: Optional[Callable[[int, Optional[int]], None]] = None, - stop_fn: Optional[Callable[[float], bool]] = None, - save_fn: Optional[Callable[[BasePolicy], None]] = None, - save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, - resume_from_log: bool = False, - reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, - logger: BaseLogger = LazyLogger(), - verbose: bool = True, - test_in_train: bool = True, -) -> Dict[str, Union[float, str]]: - """A wrapper for off-policy trainer procedure. - - The "step" in trainer means an environment step (a.k.a. transition). - - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param Collector train_collector: the collector used for training. - :param Collector test_collector: the collector used for testing. If it's None, then - no testing will be performed. - :param int max_epoch: the maximum number of epochs for training. The training - process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. - :param int step_per_epoch: the number of transitions collected per epoch. - :param int step_per_collect: the number of transitions the collector would collect - before the network update, i.e., trainer will collect "step_per_collect" - transitions and do some policy network update repeatedly in each epoch. - :param episode_per_test: the number of episodes for one policy evaluation. - :param int batch_size: the batch size of sample data, which is going to feed in the - policy network. - :param int/float update_per_step: the number of times the policy network would be - updated per transition after (step_per_collect) transitions are collected, - e.g., if update_per_step set to 0.3, and step_per_collect is 256, policy will - be updated round(256 * 0.3 = 76.8) = 77 times after 256 transitions are - collected by the collector. Default to 1. - :param function train_fn: a hook called at the beginning of training in each epoch. 
- It can be used to perform custom additional operations, with the signature ``f( - num_epoch: int, step_idx: int) -> None``. - :param function test_fn: a hook called at the beginning of testing in each epoch. - It can be used to perform custom additional operations, with the signature ``f( - num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean reward in - evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> - None``. - :param function save_checkpoint_fn: a function to save training process, with the - signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can - save whatever you want. - :param bool resume_from_log: resume env_step/gradient_step and other metadata from - existing tensorboard log. Default to False. - :param function stop_fn: a function with signature ``f(mean_rewards: float) -> - bool``, receives the average undiscounted returns of the testing result, - returns a boolean which indicates whether reaching the goal. - :param function reward_metric: a function with signature ``f(rewards: np.ndarray - with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, - used in multi-agent RL. We need to return a single scalar for each episode's - result to monitor training in the multi-agent RL setting. This function - specifies what is the desired metric, e.g., the reward of agent 1 or the - average reward over all agents. - :param BaseLogger logger: A logger that logs statistics during - training/testing/updating. Default to a logger that doesn't log anything. - :param bool verbose: whether to print the information. Default to True. - :param bool test_in_train: whether to test in the training phase. Default to True. - - :return: See :func:`~tianshou.trainer.gather_info`. - """ - start_epoch, env_step, gradient_step = 0, 0, 0 - if resume_from_log: - start_epoch, env_step, gradient_step = logger.restore_data() - last_rew, last_len = 0.0, 0 - stat: Dict[str, MovAvg] = defaultdict(MovAvg) - start_time = time.time() - train_collector.reset_stat() - test_in_train = test_in_train and ( - train_collector.policy == policy and test_collector is not None - ) - - if test_collector is not None: - test_c: Collector = test_collector # for mypy - test_collector.reset_stat() - test_result = test_episode( - policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step, - reward_metric +class OffPolicyTrainer: + """An iterator wrapper for off-policy trainer procedure. + + Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results + on every epoch. + + The "step" in trainer means an environment step (a.k.a. 
transition).""" + + def __init__( + self, + policy: BasePolicy, + train_collector: Collector, + test_collector: Optional[Collector], + max_epoch: int, + step_per_epoch: int, + step_per_collect: int, + episode_per_test: int, + batch_size: int, + update_per_step: Union[int, float] = 1, + train_fn: Optional[Callable[[int, int], None]] = None, + test_fn: Optional[Callable[[int, Optional[int]], None]] = None, + stop_fn: Optional[Callable[[float], bool]] = None, + save_fn: Optional[Callable[[BasePolicy], None]] = None, + save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + resume_from_log: bool = False, + reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, + logger: BaseLogger = LazyLogger(), + verbose: bool = True, + test_in_train: bool = True, + ): + """ + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param Collector train_collector: the collector used for training. + :param Collector test_collector: the collector used for testing. If it's None, + then no testing will be performed. + :param int max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is + set. + :param int step_per_epoch: the number of transitions collected per epoch. + :param int step_per_collect: the number of transitions the collector would + collect before the network update, i.e., trainer will collect + "step_per_collect" transitions and do some policy network update repeatedly + in each epoch. + :param episode_per_test: the number of episodes for one policy evaluation. + :param int batch_size: the batch size of sample data, which is going to feed in + the policy network. + :param int/float update_per_step: the number of times the policy network would + be updated per transition after (step_per_collect) transitions are + collected, e.g., if update_per_step set to 0.3, and step_per_collect is 256 + , policy will be updated round(256 * 0.3 = 76.8) = 77 times after 256 + transitions are collected by the collector. Default to 1. + :param function train_fn: a hook called at the beginning of training in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean + reward in evaluation phase gets better, with the signature + ``f(policy: BasePolicy) -> None``. + :param function save_checkpoint_fn: a function to save training process, with + the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; + you can save whatever you want. + :param bool resume_from_log: resume env_step/gradient_step and other metadata + from existing tensorboard log. Default to False. + :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether reaching the goal. + :param function reward_metric: a function with signature + ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> + np.ndarray with shape (num_episode,)``, used in multi-agent RL. We need to + return a single scalar for each episode's result to monitor training in the + multi-agent RL setting. 
This function specifies what is the desired metric, + e.g., the reward of agent 1 or the average reward over all agents. + :param BaseLogger logger: A logger that logs statistics during + training/testing/updating. Default to a logger that doesn't log anything. + :param bool verbose: whether to print the information. Default to True. + :param bool test_in_train: whether to test in the training phase. + Default to True. + + """ + self.is_run = False + self.policy = policy + + self.train_collector = train_collector + self.test_collector = test_collector + + self.max_epoch = max_epoch + self.step_per_epoch = step_per_epoch + self.step_per_collect = step_per_collect + self.episode_per_test = episode_per_test + self.batch_size = batch_size + self.update_per_step = update_per_step + + self.train_fn = train_fn + self.test_fn = test_fn + self.stop_fn = stop_fn + self.save_fn = save_fn + self.save_checkpoint_fn = save_checkpoint_fn + + self.reward_metric = reward_metric + self.logger = logger + self.verbose = verbose + self.test_in_train = test_in_train + + self.start_epoch, self.env_step, self.gradient_step = 0, 0, 0 + self.best_reward, self.best_reward_std = 0.0, 0.0 + + if resume_from_log: + self.start_epoch, self.env_step, self.gradient_step = logger.restore_data() + self.last_rew, self.last_len = 0.0, 0 + self.stat: Dict[str, MovAvg] = defaultdict(MovAvg) + self.start_time = time.time() + self.train_collector.reset_stat() + self.test_in_train = self.test_in_train and ( + self.train_collector.policy == policy and self.test_collector is not None ) - best_epoch = start_epoch - best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] - if save_fn: - save_fn(policy) - - for epoch in range(1 + start_epoch, 1 + max_epoch): - # train - policy.train() - with tqdm.tqdm( - total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config - ) as t: - while t.n < t.total: - if train_fn: - train_fn(epoch, env_step) - result = train_collector.collect(n_step=step_per_collect) - if result["n/ep"] > 0 and reward_metric: - rew = reward_metric(result["rews"]) - result.update(rews=rew, rew=rew.mean(), rew_std=rew.std()) - env_step += int(result["n/st"]) - t.update(result["n/st"]) - logger.log_train_data(result, env_step) - last_rew = result['rew'] if result["n/ep"] > 0 else last_rew - last_len = result['len'] if result["n/ep"] > 0 else last_len - data = { - "env_step": str(env_step), - "rew": f"{last_rew:.2f}", - "len": str(int(last_len)), - "n/ep": str(int(result["n/ep"])), - "n/st": str(int(result["n/st"])), - } - if result["n/ep"] > 0: - if test_in_train and stop_fn and stop_fn(result["rew"]): - test_result = test_episode( - policy, test_c, test_fn, epoch, episode_per_test, logger, - env_step - ) - if stop_fn(test_result["rew"]): - if save_fn: - save_fn(policy) - logger.save_data( - epoch, env_step, gradient_step, save_checkpoint_fn - ) - t.set_postfix(**data) - return gather_info( - start_time, train_collector, test_collector, - test_result["rew"], test_result["rew_std"] - ) - else: - policy.train() - for _ in range(round(update_per_step * result["n/st"])): - gradient_step += 1 - losses = policy.update(batch_size, train_collector.buffer) - for k in losses.keys(): - stat[k].add(losses[k]) - losses[k] = stat[k].get() - data[k] = f"{losses[k]:.3f}" - logger.log_update_data(losses, gradient_step) - t.set_postfix(**data) - if t.n <= t.total: - t.update() - logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) - # test + if test_collector is not None: + self.test_c: Collector = 
test_collector # for mypy + self.test_collector.reset_stat() test_result = test_episode( - policy, test_c, test_fn, epoch, episode_per_test, logger, env_step, - reward_metric + self.policy, self.test_c, self.test_fn, self.start_epoch, + self.episode_per_test, self.logger, self.env_step, self.reward_metric ) - rew, rew_std = test_result["rew"], test_result["rew_std"] - if best_epoch < 0 or best_reward < rew: - best_epoch, best_reward, best_reward_std = epoch, rew, rew_std - if save_fn: - save_fn(policy) - if verbose: - print( - f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" - f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" - ) - if stop_fn and stop_fn(best_reward): - break + self.best_epoch = self.start_epoch + self.best_reward, self.best_reward_std = test_result["rew"], test_result[ + "rew_std"] + if save_fn: + save_fn(policy) - if test_collector is None and save_fn: - save_fn(policy) + self.epoch = self.start_epoch + self.exit_flag = 0 - if test_collector is None: - return gather_info(start_time, train_collector, None, 0.0, 0.0) - else: - return gather_info( - start_time, train_collector, test_collector, best_reward, best_reward_std - ) + def __iter__(self): # type: ignore + return self + def __next__(self) -> Tuple[int, Dict[str, Any], Dict[str, Any]]: + self.epoch += 1 -def offpolicy_trainer_generator( - policy: BasePolicy, - train_collector: Collector, - test_collector: Optional[Collector], - max_epoch: int, - step_per_epoch: int, - step_per_collect: int, - episode_per_test: int, - batch_size: int, - update_per_step: Union[int, float] = 1, - train_fn: Optional[Callable[[int, int], None]] = None, - test_fn: Optional[Callable[[int, Optional[int]], None]] = None, - stop_fn: Optional[Callable[[float], bool]] = None, - save_fn: Optional[Callable[[BasePolicy], None]] = None, - save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, - resume_from_log: bool = False, - reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, - logger: BaseLogger = LazyLogger(), - verbose: bool = True, - test_in_train: bool = True, -) -> Generator[Tuple[int, Dict[str, Any], Dict[str, Any]], None, None]: - """A generator wrapper for off-policy trainer procedure. - - Returns a generator that yields a 3-tuple (epoch, stats, info) of train results - on every epoch. - - The "step" in trainer means an environment step (a.k.a. transition). - - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param Collector train_collector: the collector used for training. - :param Collector test_collector: the collector used for testing. If it's None, then - no testing will be performed. - :param int max_epoch: the maximum number of epochs for training. The training - process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. - :param int step_per_epoch: the number of transitions collected per epoch. - :param int step_per_collect: the number of transitions the collector would collect - before the network update, i.e., trainer will collect "step_per_collect" - transitions and do some policy network update repeatedly in each epoch. - :param episode_per_test: the number of episodes for one policy evaluation. - :param int batch_size: the batch size of sample data, which is going to feed in the - policy network. 
- :param int/float update_per_step: the number of times the policy network would be - updated per transition after (step_per_collect) transitions are collected, - e.g., if update_per_step set to 0.3, and step_per_collect is 256, policy will - be updated round(256 * 0.3 = 76.8) = 77 times after 256 transitions are - collected by the collector. Default to 1. - :param function train_fn: a hook called at the beginning of training in each epoch. - It can be used to perform custom additional operations, with the signature ``f( - num_epoch: int, step_idx: int) -> None``. - :param function test_fn: a hook called at the beginning of testing in each epoch. - It can be used to perform custom additional operations, with the signature ``f( - num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean reward in - evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> - None``. - :param function save_checkpoint_fn: a function to save training process, with the - signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can - save whatever you want. - :param bool resume_from_log: resume env_step/gradient_step and other metadata from - existing tensorboard log. Default to False. - :param function stop_fn: a function with signature ``f(mean_rewards: float) -> - bool``, receives the average undiscounted returns of the testing result, - returns a boolean which indicates whether reaching the goal. - :param function reward_metric: a function with signature ``f(rewards: np.ndarray - with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, - used in multi-agent RL. We need to return a single scalar for each episode's - result to monitor training in the multi-agent RL setting. This function - specifies what is the desired metric, e.g., the reward of agent 1 or the - average reward over all agents. - :param BaseLogger logger: A logger that logs statistics during - training/testing/updating. Default to a logger that doesn't log anything. - :param bool verbose: whether to print the information. Default to True. - :param bool test_in_train: whether to test in the training phase. Default to True. - - :return: See :func:`~tianshou.trainer.gather_info`. 
- """ - start_epoch, env_step, gradient_step = 0, 0, 0 - if resume_from_log: - start_epoch, env_step, gradient_step = logger.restore_data() - last_rew, last_len = 0.0, 0 - stat: Dict[str, MovAvg] = defaultdict(MovAvg) - start_time = time.time() - train_collector.reset_stat() - test_in_train = test_in_train and ( - train_collector.policy == policy and test_collector is not None - ) - - if test_collector is not None: - test_c: Collector = test_collector # for mypy - test_collector.reset_stat() - test_result = test_episode( - policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step, - reward_metric - ) - best_epoch = start_epoch - best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] - if save_fn: - save_fn(policy) - - for epoch in range(1 + start_epoch, 1 + max_epoch): - # train - policy.train() + # exit flag 1, when test_in_train and stop_fn succeeds on result["rew"] + if self.test_in_train and self.stop_fn and self.exit_flag == 1: + raise StopIteration + + # iterator exhaustion check + if self.epoch >= self.max_epoch: + if self.test_collector is None and self.save_fn: + self.save_fn(self.policy) + raise StopIteration + + # stop_fn criterion + if self.test_collector is not None and self.stop_fn and self.stop_fn( + self.best_reward + ): + raise StopIteration + + # set policy in train mode + self.policy.train() + + # Performs n step_per_epoch with tqdm.tqdm( - total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config + total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config ) as t: while t.n < t.total: - if train_fn: - train_fn(epoch, env_step) - result = train_collector.collect(n_step=step_per_collect) - if result["n/ep"] > 0 and reward_metric: - rew = reward_metric(result["rews"]) + if self.train_fn: + self.train_fn(self.epoch, self.env_step) + result = self.train_collector.collect(n_step=self.step_per_collect) + if result["n/ep"] > 0 and self.reward_metric: + rew = self.reward_metric(result["rews"]) result.update(rews=rew, rew=rew.mean(), rew_std=rew.std()) - env_step += int(result["n/st"]) + self.env_step += int(result["n/st"]) t.update(result["n/st"]) - logger.log_train_data(result, env_step) - last_rew = result['rew'] if result["n/ep"] > 0 else last_rew - last_len = result['len'] if result["n/ep"] > 0 else last_len + self.logger.log_train_data(result, self.env_step) + self.last_rew = result['rew'] if result["n/ep"] > 0 else self.last_rew + self.last_len = result['len'] if result["n/ep"] > 0 else self.last_len data = { - "env_step": str(env_step), - "rew": f"{last_rew:.2f}", - "len": str(int(last_len)), + "env_step": str(self.env_step), + "rew": f"{self.last_rew:.2f}", + "len": str(int(self.last_len)), "n/ep": str(int(result["n/ep"])), "n/st": str(int(result["n/st"])), } if result["n/ep"] > 0: - if test_in_train and stop_fn and stop_fn(result["rew"]): + if self.test_in_train and self.stop_fn and self.stop_fn( + result["rew"] + ): test_result = test_episode( - policy, test_c, test_fn, epoch, episode_per_test, logger, - env_step + self.policy, self.test_c, self.test_fn, self.epoch, + self.episode_per_test, self.logger, self.env_step ) - if stop_fn(test_result["rew"]): - if save_fn: - save_fn(policy) - logger.save_data( - epoch, env_step, gradient_step, save_checkpoint_fn + if self.stop_fn(test_result["rew"]): + if self.save_fn: + self.save_fn(self.policy) + self.logger.save_data( + self.epoch, self.env_step, self.gradient_step, + self.save_checkpoint_fn ) t.set_postfix(**data) - # epoch_stat for yield clause - epoch_stat: Dict[str, Any] = 
{ - k: v.get() - for k, v in stat.items() - } - epoch_stat["gradient_step"] = gradient_step - epoch_stat.update( - { - "env_step": env_step, - "rew": last_rew, - "len": int(last_len), - "n/ep": int(result["n/ep"]), - "n/st": int(result["n/st"]), + if not self.is_run: + epoch_stat: Dict[str, Any] = { + k: v.get() + for k, v in self.stat.items() } - ) - info = gather_info( - start_time, train_collector, test_collector, - test_result["rew"], test_result["rew_std"] - ) - yield epoch, epoch_stat, info - return + epoch_stat["gradient_step"] = self.gradient_step + epoch_stat.update( + { + "env_step": self.env_step, + "rew": self.last_rew, + "len": int(self.last_len), + "n/ep": int(result["n/ep"]), + "n/st": int(result["n/st"]), + } + ) + self.exit_flag = 1 + self.best_reward = test_result["rew"] + self.best_reward_std = test_result["rew_std"] + + if not self.is_run: + info = gather_info( + self.start_time, self.train_collector, + self.test_collector, self.best_reward, + self.best_reward_std + ) + return self.epoch, epoch_stat, info + else: + return 0, {}, {} else: - policy.train() - for _ in range(round(update_per_step * result["n/st"])): - gradient_step += 1 - losses = policy.update(batch_size, train_collector.buffer) + self.policy.train() + for _ in range(round(self.update_per_step * result["n/st"])): + self.gradient_step += 1 + losses = self.policy.update( + self.batch_size, self.train_collector.buffer + ) for k in losses.keys(): - stat[k].add(losses[k]) - losses[k] = stat[k].get() + self.stat[k].add(losses[k]) + losses[k] = self.stat[k].get() data[k] = f"{losses[k]:.3f}" - logger.log_update_data(losses, gradient_step) + self.logger.log_update_data(losses, self.gradient_step) t.set_postfix(**data) if t.n <= t.total: t.update() - logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) - # epoch_stat for yield clause - epoch_stat = {k: v.get() for k, v in stat.items()} - epoch_stat["gradient_step"] = gradient_step - epoch_stat.update( - { - "env_step": env_step, - "rew": last_rew, - "len": int(last_len), - "n/ep": int(result["n/ep"]), - "n/st": int(result["n/st"]), - } + self.logger.save_data( + self.epoch, self.env_step, self.gradient_step, self.save_checkpoint_fn ) + + if not self.is_run: + epoch_stat = {k: v.get() for k, v in self.stat.items()} + epoch_stat["gradient_step"] = self.gradient_step + epoch_stat.update( + { + "env_step": self.env_step, + "rew": self.last_rew, + "len": int(self.last_len), + "n/ep": int(result["n/ep"]), + "n/st": int(result["n/st"]), + } + ) + # test - if test_collector is not None: + if self.test_collector is not None: test_result = test_episode( - policy, test_c, test_fn, epoch, episode_per_test, logger, env_step, - reward_metric + self.policy, self.test_c, self.test_fn, self.epoch, + self.episode_per_test, self.logger, self.env_step, self.reward_metric ) rew, rew_std = test_result["rew"], test_result["rew_std"] - if best_epoch < 0 or best_reward < rew: - best_epoch, best_reward, best_reward_std = epoch, rew, rew_std - if save_fn: - save_fn(policy) - if verbose: + if self.best_epoch < 0 or self.best_reward < rew: + self.best_epoch = self.epoch + self.best_reward = rew + self.best_reward_std = rew_std + if self.save_fn: + self.save_fn(self.policy) + if self.verbose: print( - f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" - f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}," + f" best_reward: {self.best_reward:.6f} ± " + f"{self.best_reward_std:.6f} in 
#{self.best_epoch}" + ) + if not self.is_run: + epoch_stat.update( + { + "test_reward": rew, + "test_reward_std": rew_std, + "best_reward": self.best_reward, + "best_reward_std": self.best_reward_std, + "best_epoch": self.best_epoch + } ) - epoch_stat.update( - { - "test_reward": rew, - "test_reward_std": rew_std, - "best_reward": best_reward, - "best_reward_std": best_reward_std, - "best_epoch": best_epoch - } - ) - if test_collector is None: - info = gather_info(start_time, train_collector, None, 0.0, 0.0) + # return iterator -> next(self) + if not self.is_run: + info = gather_info( + self.start_time, self.train_collector, self.test_collector, + self.best_reward, self.best_reward_std + ) + return self.epoch, epoch_stat, info else: + return 0, {}, {} + + def run(self) -> Dict[str, Union[float, str]]: + """ + Consume iterator, see itertools-recipes. Use functions that consume + iterators at C speed (feed the entire iterator into a zero-length deque). + """ + try: + self.is_run = True + i = iter(self) + deque(i, maxlen=0) # feed the entire iterator into a zero-length deque info = gather_info( - start_time, train_collector, test_collector, best_reward, - best_reward_std + self.start_time, None, self.test_collector, self.best_reward, + self.best_reward_std ) - yield epoch, epoch_stat, info + finally: + self.is_run = False + + return info + + +def offpolicy_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore + return OffPolicyTrainer(*args, **kwargs).run() - if test_collector is not None and stop_fn and stop_fn(best_reward): - break - if test_collector is None and save_fn: - save_fn(policy) +offpolicy_trainer_iter = OffPolicyTrainer From 88cb63c35ee8da7ee0fdb9f88f0f18c2233ab88a Mon Sep 17 00:00:00 2001 From: R107333 Date: Sun, 6 Mar 2022 20:16:21 +0100 Subject: [PATCH 17/40] Expose new _iter versions and Iterator Classes --- tianshou/trainer/__init__.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index 01599bbc5..d3e49ecef 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -4,16 +4,20 @@ from tianshou.trainer.utils import test_episode, gather_info from tianshou.trainer.onpolicy import onpolicy_trainer, onpolicy_trainer_generator -from tianshou.trainer.offpolicy import offpolicy_trainer, offpolicy_trainer_generator -from tianshou.trainer.offline import offline_trainer, offline_trainer_iter +from tianshou.trainer.offpolicy import offpolicy_trainer, offpolicy_trainer_iter,\ + OffPolicyTrainer +from tianshou.trainer.offline import offline_trainer, offline_trainer_iter,\ + OffLineTrainer __all__ = [ "offpolicy_trainer", - "offpolicy_trainer_generator", + "offpolicy_trainer_iter", + "OffPolicyTrainer", "onpolicy_trainer", "onpolicy_trainer_generator", "offline_trainer", "offline_trainer_iter", + "OffLineTrainer", "test_episode", "gather_info", ] From 34feb5b18d1eb91c33abcf65af134573a78185f9 Mon Sep 17 00:00:00 2001 From: R107333 Date: Sun, 6 Mar 2022 20:53:42 +0100 Subject: [PATCH 18/40] Add OffPolicyTrainer as Iterator adn add testing in test_td3.py --- test/continuous/test_td3.py | 25 ++++++++++++++++++++++++- tianshou/trainer/offpolicy.py | 33 +++++++++++++++++---------------- 2 files changed, 41 insertions(+), 17 deletions(-) diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 2e3ef7ba7..80caf0728 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -11,7 +11,7 @@ from tianshou.env import DummyVectorEnv 
from tianshou.exploration import GaussianNoise from tianshou.policy import TD3Policy -from tianshou.trainer import offpolicy_trainer +from tianshou.trainer import offpolicy_trainer, offpolicy_trainer_iter from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic @@ -149,6 +149,29 @@ def stop_fn(mean_rewards): ) assert stop_fn(result['best_reward']) + # Iterator trainer + trainer = offpolicy_trainer_iter( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) + for epoch, epoch_stat, info in trainer: + print(f"Epoch: {epoch}") + print(epoch_stat) + print(info) + + result_iter = info + assert stop_fn(result_iter['best_reward']) + if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 528d23e43..7ce4cf833 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -129,8 +129,8 @@ def __init__( self.train_collector.policy == policy and self.test_collector is not None ) - if test_collector is not None: - self.test_c: Collector = test_collector # for mypy + if self.test_collector is not None: + self.test_c: Collector = self.test_collector # for mypy self.test_collector.reset_stat() test_result = test_episode( self.policy, self.test_c, self.test_fn, self.start_epoch, @@ -151,21 +151,22 @@ def __iter__(self): # type: ignore def __next__(self) -> Tuple[int, Dict[str, Any], Dict[str, Any]]: self.epoch += 1 - # exit flag 1, when test_in_train and stop_fn succeeds on result["rew"] - if self.test_in_train and self.stop_fn and self.exit_flag == 1: - raise StopIteration + if self.epoch > 1: + # exit flag 1, when test_in_train and stop_fn succeeds on result["rew"] + if self.test_in_train and self.stop_fn and self.exit_flag == 1: + raise StopIteration - # iterator exhaustion check - if self.epoch >= self.max_epoch: - if self.test_collector is None and self.save_fn: - self.save_fn(self.policy) - raise StopIteration + # iterator exhaustion check + if self.epoch >= self.max_epoch: + if self.test_collector is None and self.save_fn: + self.save_fn(self.policy) + raise StopIteration - # stop_fn criterion - if self.test_collector is not None and self.stop_fn and self.stop_fn( - self.best_reward - ): - raise StopIteration + # stop_fn criterion + if self.test_collector is not None and self.stop_fn and self.stop_fn( + self.best_reward + ): + raise StopIteration # set policy in train mode self.policy.train() @@ -278,7 +279,7 @@ def __next__(self) -> Tuple[int, Dict[str, Any], Dict[str, Any]]: rew, rew_std = test_result["rew"], test_result["rew_std"] if self.best_epoch < 0 or self.best_reward < rew: self.best_epoch = self.epoch - self.best_reward = rew + self.best_reward = float(rew) self.best_reward_std = rew_std if self.save_fn: self.save_fn(self.policy) From 1c7eaef4f115311bda0f4c8f94a357d2095dfd90 Mon Sep 17 00:00:00 2001 From: R107333 Date: Sun, 6 Mar 2022 21:11:49 +0100 Subject: [PATCH 19/40] fix doc format --- tianshou/trainer/offline.py | 13 +++++++------ tianshou/trainer/offpolicy.py | 14 +++++++++----- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 83630fef8..3dd50c547 100644 --- a/tianshou/trainer/offline.py +++ 
b/tianshou/trainer/offline.py
@@ -12,8 +12,7 @@
 
 
 class OffLineTrainer:
-    """
-    An iterator wrapper for offline training procedure.
+    """An iterator wrapper for offline training procedure.
 
     Returns an iterator that yields a 3 tuple (epoch, stats, info) of train results
     on every epoch.
@@ -40,7 +39,8 @@ def __init__(
         logger: BaseLogger = LazyLogger(),
         verbose: bool = True,
     ):
-        """
+        """Create an iterator wrapper for offline training procedure.
+
         :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
         :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
             This buffer must be populated with experiences for offline RL.
@@ -202,9 +202,10 @@ def __next__(self) -> Tuple[int, Dict[str, Any], Dict[str, Any]]:
             return 0, {}, {}
 
     def run(self) -> Dict[str, Union[float, str]]:
-        """
-        Consume iterator, see itertools-recipes. Use functions that consume
-        iterators at C speed (feed the entire iterator into a zero-length deque).
+        """Consume iterator.
+
+        See itertools - recipes. Use functions that consume iterators at C speed
+        (feed the entire iterator into a zero-length deque).
         """
         try:
             self.is_run = True
diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py
index 7ce4cf833..c9c0b69ad 100644
--- a/tianshou/trainer/offpolicy.py
+++ b/tianshou/trainer/offpolicy.py
@@ -17,7 +17,9 @@ class OffPolicyTrainer:
     Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results
     on every epoch.
 
-    The "step" in trainer means an environment step (a.k.a. transition)."""
+    The "step" in trainer means an environment step (a.k.a. transition).
+
+    """
 
     def __init__(
         self,
@@ -41,7 +43,8 @@ def __init__(
         verbose: bool = True,
         test_in_train: bool = True,
     ):
-        """
+        """Create an iterator wrapper for off-policy training procedure.
+
         :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
         :param Collector train_collector: the collector used for training.
         :param Collector test_collector: the collector used for testing. If it's None,
@@ -311,9 +314,10 @@ def __next__(self) -> Tuple[int, Dict[str, Any], Dict[str, Any]]:
             return 0, {}, {}
 
     def run(self) -> Dict[str, Union[float, str]]:
-        """
-        Consume iterator, see itertools-recipes. Use functions that consume
-        iterators at C speed (feed the entire iterator into a zero-length deque).
+        """Consume iterator.
+
+        See itertools - recipes. Use functions that consume iterators at C speed
+        (feed the entire iterator into a zero-length deque).
         """
         try:
             self.is_run = True

From 5ca6fb8ebe3fbed8b0cbe268754b338af415a45f Mon Sep 17 00:00:00 2001
From: R107333
Date: Tue, 8 Mar 2022 20:08:44 +0100
Subject: [PATCH 20/40] * Refactored trainers into One BaseTrainer class.
* All the procedures are so similar that keeping them separate would only
  produce unnecessary, duplicated and complex code
* Included tests in test_ppo.py, test_cql.py and test_td3.py
* It can be simplified even more, but that would break backward API compatibility
---
 test/continuous/test_ppo.py                   |  27 +-
 test/continuous/test_ppo_trainer_generator.py | 199 -------
 tianshou/trainer/__init__.py                  |   4 +-
 tianshou/trainer/base.py                      | 440 ++++++++++++++
 tianshou/trainer/offline.py                   | 175 +-----
 tianshou/trainer/offpolicy.py                 | 281 ++-------
 tianshou/trainer/onpolicy.py                  | 544 ++++--------------
 7 files changed, 642 insertions(+), 1028 deletions(-)
 delete mode 100644 test/continuous/test_ppo_trainer_generator.py
 create mode 100644 tianshou/trainer/base.py

diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py
index 32d8411c2..218bc45a6 100644
--- a/test/continuous/test_ppo.py
+++ b/test/continuous/test_ppo.py
@@ -11,7 +11,7 @@
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import PPOPolicy
-from tianshou.trainer import onpolicy_trainer
+from tianshou.trainer import onpolicy_trainer, onpolicy_trainer_iter
 from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.continuous import ActorProb, Critic
@@ -171,6 +171,31 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
     )
     assert stop_fn(result['best_reward'])
 
+    trainer = onpolicy_trainer_iter(
+        policy,
+        train_collector,
+        test_collector,
+        args.epoch,
+        args.step_per_epoch,
+        args.repeat_per_collect,
+        args.test_num,
+        args.batch_size,
+        episode_per_collect=args.episode_per_collect,
+        stop_fn=stop_fn,
+        save_fn=save_fn,
+        logger=logger,
+        resume_from_log=args.resume,
+        save_checkpoint_fn=save_checkpoint_fn
+    )
+
+    for epoch, epoch_stat, info in trainer:
+        print(f"Epoch: {epoch}")
+        print(epoch_stat)
+        print(info)
+
+    result_iter = info
+    assert stop_fn(result_iter['best_reward'])
+
     if __name__ == '__main__':
         pprint.pprint(result)
         # Let's watch its performance!
diff --git a/test/continuous/test_ppo_trainer_generator.py b/test/continuous/test_ppo_trainer_generator.py deleted file mode 100644 index 3a064997f..000000000 --- a/test/continuous/test_ppo_trainer_generator.py +++ /dev/null @@ -1,199 +0,0 @@ -import argparse -import os -import pprint - -import gym -import numpy as np -import torch -from torch.distributions import Independent, Normal -from torch.utils.tensorboard import SummaryWriter - -from tianshou.data import Collector, VectorReplayBuffer -from tianshou.env import DummyVectorEnv -from tianshou.policy import PPOPolicy -from tianshou.trainer import onpolicy_trainer_generator -from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.continuous import ActorProb, Critic - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='Pendulum-v0') - parser.add_argument('--seed', type=int, default=1) - parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--lr', type=float, default=1e-3) - parser.add_argument('--gamma', type=float, default=0.95) - parser.add_argument('--epoch', type=int, default=5) - parser.add_argument('--step-per-epoch', type=int, default=150000) - parser.add_argument('--episode-per-collect', type=int, default=16) - parser.add_argument('--repeat-per-collect', type=int, default=2) - parser.add_argument('--batch-size', type=int, default=128) - parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) - parser.add_argument('--training-num', type=int, default=16) - parser.add_argument('--test-num', type=int, default=100) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) 
- parser.add_argument( - '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' - ) - # ppo special - parser.add_argument('--vf-coef', type=float, default=0.25) - parser.add_argument('--ent-coef', type=float, default=0.0) - parser.add_argument('--eps-clip', type=float, default=0.2) - parser.add_argument('--max-grad-norm', type=float, default=0.5) - parser.add_argument('--gae-lambda', type=float, default=0.95) - parser.add_argument('--rew-norm', type=int, default=1) - parser.add_argument('--dual-clip', type=float, default=None) - parser.add_argument('--value-clip', type=int, default=1) - parser.add_argument('--norm-adv', type=int, default=1) - parser.add_argument('--recompute-adv', type=int, default=0) - parser.add_argument('--resume', action="store_true") - parser.add_argument("--save-interval", type=int, default=4) - args = parser.parse_known_args()[0] - return args - - -def test_ppo(args=get_args()): - env = gym.make(args.task) - if args.task == 'Pendulum-v0': - env.spec.reward_threshold = -250 - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] - # you can also use tianshou.env.SubprocVectorEnv - # train_envs = gym.make(args.task) - train_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)] - ) - # test_envs = gym.make(args.task) - test_envs = DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)] - ) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) - train_envs.seed(args.seed) - test_envs.seed(args.seed) - # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb( - net, args.action_shape, max_action=args.max_action, device=args.device - ).to(args.device) - critic = Critic( - Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device), - device=args.device - ).to(args.device) - actor_critic = ActorCritic(actor, critic) - # orthogonal initialization - for m in actor_critic.modules(): - if isinstance(m, torch.nn.Linear): - torch.nn.init.orthogonal_(m.weight) - torch.nn.init.zeros_(m.bias) - optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) - - # replace DiagGuassian with Independent(Normal) which is equivalent - # pass *logits to be consistent with policy.forward - def dist(*logits): - return Independent(Normal(*logits), 1) - - policy = PPOPolicy( - actor, - critic, - optim, - dist, - discount_factor=args.gamma, - max_grad_norm=args.max_grad_norm, - eps_clip=args.eps_clip, - vf_coef=args.vf_coef, - ent_coef=args.ent_coef, - reward_normalization=args.rew_norm, - advantage_normalization=args.norm_adv, - recompute_advantage=args.recompute_adv, - dual_clip=args.dual_clip, - value_clip=args.value_clip, - gae_lambda=args.gae_lambda, - action_space=env.action_space - ) - # collector - train_collector = Collector( - policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)) - ) - test_collector = Collector(policy, test_envs) - # log - log_path = os.path.join(args.logdir, args.task, 'ppo') - writer = SummaryWriter(log_path) - logger = TensorboardLogger(writer, save_interval=args.save_interval) - - def save_fn(policy): - torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) - - def stop_fn(mean_rewards): - return mean_rewards >= env.spec.reward_threshold - - def save_checkpoint_fn(epoch, env_step, gradient_step): - # see also: 
https://pytorch.org/tutorials/beginner/saving_loading_models.html - torch.save( - { - 'model': policy.state_dict(), - 'optim': optim.state_dict(), - }, os.path.join(log_path, 'checkpoint.pth') - ) - - if args.resume: - # load from existing checkpoint - print(f"Loading agent under {log_path}") - ckpt_path = os.path.join(log_path, 'checkpoint.pth') - if os.path.exists(ckpt_path): - checkpoint = torch.load(ckpt_path, map_location=args.device) - policy.load_state_dict(checkpoint['model']) - optim.load_state_dict(checkpoint['optim']) - print("Successfully restore policy and optim.") - else: - print("Fail to restore policy and optim.") - - # trainer - trainer = onpolicy_trainer_generator( - policy, - train_collector, - test_collector, - args.epoch, - args.step_per_epoch, - args.repeat_per_collect, - args.test_num, - args.batch_size, - episode_per_collect=args.episode_per_collect, - stop_fn=stop_fn, - save_fn=save_fn, - logger=logger, - resume_from_log=args.resume, - save_checkpoint_fn=save_checkpoint_fn - ) - print(trainer) - - for epoch, epoch_stat, info in trainer: - print(f"Epoch: {epoch}") - print(epoch_stat) - print(info) - - result = info - assert stop_fn(result['best_reward']) - - if __name__ == '__main__': - pprint.pprint(result) - # Let's watch its performance! - env = gym.make(args.task) - policy.eval() - collector = Collector(policy, env) - result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") - - -def test_ppo_resume(args=get_args()): - args.resume = True - test_ppo(args) - - -if __name__ == '__main__': - test_ppo() diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index d3e49ecef..76449a95e 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -3,7 +3,7 @@ # isort:skip_file from tianshou.trainer.utils import test_episode, gather_info -from tianshou.trainer.onpolicy import onpolicy_trainer, onpolicy_trainer_generator +from tianshou.trainer.onpolicy import onpolicy_trainer, onpolicy_trainer_iter from tianshou.trainer.offpolicy import offpolicy_trainer, offpolicy_trainer_iter,\ OffPolicyTrainer from tianshou.trainer.offline import offline_trainer, offline_trainer_iter,\ @@ -14,7 +14,7 @@ "offpolicy_trainer_iter", "OffPolicyTrainer", "onpolicy_trainer", - "onpolicy_trainer_generator", + "onpolicy_trainer_iter", "offline_trainer", "offline_trainer_iter", "OffLineTrainer", diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py new file mode 100644 index 000000000..50dc10f8e --- /dev/null +++ b/tianshou/trainer/base.py @@ -0,0 +1,440 @@ +import time +from collections import defaultdict, deque +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import numpy as np +import tqdm + +from tianshou.data import Collector, ReplayBuffer +from tianshou.policy import BasePolicy +from tianshou.trainer import gather_info, test_episode +from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config + + +class BaseTrainer: + """An iterator base class for trainers procedure. + + Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results + on every epoch. + + The "step" in trainer means an environment step (a.k.a. transition). 
+ There are three types of learning iterators: + (1) offpolicy learning trainer + (2) onpolicy learning trainer + (3) offpolicy learning trainer + + """ + + learning_types: Dict[Union[int, str], Union[int, str]] = { + 0: "offpolicy", + "offpolicy": 0, + 1: "onpolicy", + "onpolicy": 1, + 2: "offline", + "offline": 2, + } + + def __init__( + self, + learning_type: Union[int, str], + policy: BasePolicy, + max_epoch: int, + batch_size: int, + train_collector: Optional[Collector] = None, + test_collector: Optional[Collector] = None, + buffer: Optional[ReplayBuffer] = None, + step_per_epoch: Optional[int] = None, + repeat_per_collect: Optional[int] = None, + episode_per_test: Optional[int] = None, + update_per_step: Union[int, float] = 1, + update_per_epoch: Optional[int] = None, + step_per_collect: Optional[int] = None, + episode_per_collect: Optional[int] = None, + train_fn: Optional[Callable[[int, int], None]] = None, + test_fn: Optional[Callable[[int, Optional[int]], None]] = None, + stop_fn: Optional[Callable[[float], bool]] = None, + save_fn: Optional[Callable[[BasePolicy], None]] = None, + save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + resume_from_log: bool = False, + reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, + logger: BaseLogger = LazyLogger(), + verbose: bool = True, + test_in_train: bool = True, + ): + """Create an iterator wrapper for training procedure. + + :param learning_type int|str: type of learning iterator, 0,1,2 for "offpolicy", + "onpolicy" and "offline" respectively + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param Collector train_collector: the collector used for training. + :param Collector test_collector: the collector used for testing. If it's None, + then no testing will be performed. + :param int max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` + is set. + :param int step_per_epoch: the number of transitions collected per epoch. + :param int repeat_per_collect: the number of repeat time for policy learning, + for example, set it to 2 means the policy needs to learn each given batch + data twice. + :param int episode_per_test: the number of episodes for one policy evaluation. + :param int batch_size: the batch size of sample data, which is going to feed in + the policy network. + :param int step_per_collect: the number of transitions the collector would + collect before the network update, i.e., trainer will collect + "step_per_collect" transitions and do some policy network update repeatedly + in each epoch. + :param int episode_per_collect: the number of episodes the collector would + collect before the network update, i.e., trainer will collect + "episode_per_collect" episodes and do some policy network update repeatedly + in each epoch. + :param function train_fn: a hook called at the beginning of training in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean + reward in evaluation phase gets better, with the signature + ``f(policy: BasePolicy) -> None``. 
+ :param function save_checkpoint_fn: a function to save training process, with + the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; + you can save whatever you want. + :param bool resume_from_log: resume env_step/gradient_step and other metadata + from existing tensorboard log. Default to False. + :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether reaching the goal. + :param function reward_metric: a function with signature + ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray + with shape (num_episode,)``, used in multi-agent RL. We need to return a + single scalar for each episode's result to monitor training in the + multi-agent RL setting. This function specifies what is the desired metric, + e.g., the reward of agent 1 or the average reward over all agents. + :param BaseLogger logger: A logger that logs statistics during + training/testing/updating. Default to a logger that doesn't log anything. + :param bool verbose: whether to print the information. Default to True. + :param bool test_in_train: whether to test in the training phase. + Default to True. + + """ + self.policy = policy + self.buffer = buffer + + self.train_collector = train_collector + self.test_collector = test_collector + + self.logger = logger + self.start_time = time.time() + self.stat: Dict[str, MovAvg] = defaultdict(MovAvg) + self.best_reward = 0.0 + self.best_reward_std = 0.0 + self.start_epoch = 0 + self.gradient_step = 0 + + self.max_epoch = max_epoch + self.step_per_epoch = step_per_epoch + + # either on of these two + self.step_per_collect = step_per_collect + self.episode_per_collect = episode_per_collect + + self.update_per_step = update_per_step + self.repeat_per_collect = repeat_per_collect + + self.episode_per_test = episode_per_test + + self.batch_size = batch_size + + self.train_fn = train_fn + self.test_fn = test_fn + self.stop_fn = stop_fn + self.save_fn = save_fn + self.save_checkpoint_fn = save_checkpoint_fn + + self.reward_metric = reward_metric + self.verbose = verbose + self.test_in_train = test_in_train + self.resume_from_log = resume_from_log + + self.is_run = False + self.last_rew, self.last_len = 0.0, 0 + self.env_step = 0 + self.test_c = self.test_collector + self.epoch = self.start_epoch + self.best_epoch = self.start_epoch + self.stop_fn_flag = 0 + + self.update_function: Dict[Union[int, str], Callable] = { + 0: self.offpolicy_update, + "offpolicy": self.offpolicy_update, + 1: self.onpolicy_update, + "onpolicy": self.onpolicy_update, + 2: self.offline_update, + "offline": self.offline_update, + } + assert learning_type in self.learning_types + self.learning_type = learning_type + self.policy_update_fn = self.update_function[self.learning_type] + + def reset(self) -> None: + """Initialize or reset the instance to yield a new iterator from zero.""" + self.is_run = False + self.env_step = 0 + if self.resume_from_log: + self.start_epoch, self.env_step, self.gradient_step =\ + self.logger.restore_data() + + self.last_rew, self.last_len = 0.0, 0 + + if self.train_collector is not None: + self.train_collector.reset_stat() + self.test_in_train = ( + self.test_in_train and + self.train_collector.policy == self.policy and + self.test_collector is not None + ) + + else: + self.test_in_train = False + + if self.test_collector is not None: + assert self.episode_per_test + self.test_collector.reset_stat() + 
test_result = test_episode( + self.policy, self.test_collector, self.test_fn, self.start_epoch, + self.episode_per_test, self.logger, self.env_step, self.reward_metric + ) + self.best_epoch = self.start_epoch + self.best_reward, self.best_reward_std = test_result["rew"], test_result[ + "rew_std"] + if self.save_fn: + self.save_fn(self.policy) + self.epoch = self.start_epoch + self.stop_fn_flag = 0 + + def __iter__(self): # type: ignore + self.reset() + return self + + def __next__(self) -> Tuple[int, Dict[str, Any], Dict[str, Any]]: + self.epoch += 1 + + if self.epoch > 1: + + # iterator exhaustion check + if self.epoch >= self.max_epoch: + if self.test_collector is None and self.save_fn: + self.save_fn(self.policy) + raise StopIteration + + # exit flag 1, when test_in_train and stop_fn succeeds on result["rew"] + if self.test_in_train and self.stop_fn and self.stop_fn_flag == 1: + raise StopIteration + + # stop_fn criterion + if self.test_collector is not None and self.stop_fn and self.stop_fn( + self.best_reward + ): + raise StopIteration + + # set policy in train mode + self.policy.train() + + epoch_stat: Dict[str, Any] = dict() + # Performs n step_per_epoch + with tqdm.tqdm( + total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config + ) as t: + + while t.n < t.total and not self.stop_fn_flag: + data: Dict[str, Any] = dict() + result: Dict[str, Any] = dict() + if self.train_collector is not None: + data, result, self.stop_fn_flag = self.train_step() + t.update(result["n/st"]) + else: + assert self.buffer + result["n/ep"] = len(self.buffer) + result["n/st"] = int(self.gradient_step) + t.update() + + self.policy_update_fn(data, result) + t.set_postfix(**data) + + if t.n <= t.total: + t.update() + + self.logger.save_data( + self.epoch, self.env_step, self.gradient_step, self.save_checkpoint_fn + ) + + if not self.is_run: + epoch_stat.update({k: v.get() for k, v in self.stat.items()}) + epoch_stat["gradient_step"] = self.gradient_step + epoch_stat.update( + { + "env_step": self.env_step, + "rew": self.last_rew, + "len": int(self.last_len), + "n/ep": int(result["n/ep"]), + "n/st": int(result["n/st"]), + } + ) + + if self.stop_fn_flag: + if not self.is_run: + info = gather_info( + self.start_time, self.train_collector, self.test_collector, + self.best_reward, self.best_reward_std + ) + return self.epoch, epoch_stat, info + else: + return 0, {}, {} + + # test + if self.test_collector is not None: + test_stat = self.test_step() + epoch_stat.update(test_stat) + + # return iterator -> next(self) + if not self.is_run: + info = gather_info( + self.start_time, self.train_collector, self.test_collector, + self.best_reward, self.best_reward_std + ) + return self.epoch, epoch_stat, info + else: + return 0, {}, {} + + def test_step(self) -> Dict[str, Any]: + """Performs a testing step.""" + assert self.episode_per_test is not None + assert self.test_collector is not None + test_result = test_episode( + self.policy, self.test_collector, self.test_fn, self.epoch, + self.episode_per_test, self.logger, self.env_step, self.reward_metric + ) + rew, rew_std = test_result["rew"], test_result["rew_std"] + if self.best_epoch < 0 or self.best_reward < rew: + self.best_epoch = self.epoch + self.best_reward = float(rew) + self.best_reward_std = rew_std + if self.save_fn: + self.save_fn(self.policy) + if self.verbose: + print( + f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}," + f" best_reward: {self.best_reward:.6f} ± " + f"{self.best_reward_std:.6f} in #{self.best_epoch}" + ) + if not 
self.is_run: + test_stat = { + "test_reward": rew, + "test_reward_std": rew_std, + "best_reward": self.best_reward, + "best_reward_std": self.best_reward_std, + "best_epoch": self.best_epoch + } + else: + test_stat = {} + return test_stat + + def train_step(self) -> Tuple[Dict[str, Any], Dict[str, Any], bool]: + """Performs 1 training step.""" + assert self.episode_per_test is not None + assert self.train_collector is not None + stop_fn_flag = False + if self.train_fn: + self.train_fn(self.epoch, self.env_step) + result = self.train_collector.collect( + n_step=self.step_per_collect, n_episode=self.episode_per_collect + ) + if result["n/ep"] > 0 and self.reward_metric: + rew = self.reward_metric(result["rews"]) + result.update(rews=rew, rew=rew.mean(), rew_std=rew.std()) + self.env_step += int(result["n/st"]) + self.logger.log_train_data(result, self.env_step) + self.last_rew = result['rew'] if result["n/ep"] > 0 else self.last_rew + self.last_len = result['len'] if result["n/ep"] > 0 else self.last_len + data = { + "env_step": str(self.env_step), + "rew": f"{self.last_rew:.2f}", + "len": str(int(self.last_len)), + "n/ep": str(int(result["n/ep"])), + "n/st": str(int(result["n/st"])), + } + if result["n/ep"] > 0: + if self.test_in_train and self.stop_fn and self.stop_fn(result["rew"]): + assert self.test_c is not None + test_result = test_episode( + self.policy, self.test_c, self.test_fn, self.epoch, + self.episode_per_test, self.logger, self.env_step + ) + if self.stop_fn(test_result["rew"]): + stop_fn_flag = True + self.best_reward = test_result["rew"] + self.best_reward_std = test_result["rew_std"] + else: + self.policy.train() + + return data, result, stop_fn_flag + + def log_update_data(self, data: Dict[str, Any], losses: Dict[str, Any]) -> None: + """Log losses to current logger.""" + for k in losses.keys(): + self.stat[k].add(losses[k]) + losses[k] = self.stat[k].get() + data[k] = f"{losses[k]:.3f}" + self.logger.log_update_data(losses, self.gradient_step) + + def offpolicy_update(self, data: Dict[str, Any], result: Dict[str, Any]) -> None: + """Performs off-policy updates.""" + assert self.train_collector is not None + for _ in range(round(self.update_per_step * result["n/st"])): + self.gradient_step += 1 + losses = self.policy.update(self.batch_size, self.train_collector.buffer) + self.log_update_data(data, losses) + + def onpolicy_update( + self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None + ) -> None: + """Performs on-policy updates.""" + assert self.train_collector is not None + losses = self.policy.update( + 0, + self.train_collector.buffer, + batch_size=self.batch_size, + repeat=self.repeat_per_collect + ) + self.train_collector.reset_buffer(keep_statistics=True) + step = max([1] + [len(v) for v in losses.values() if isinstance(v, list)]) + self.gradient_step += step + self.log_update_data(data, losses) + + def offline_update( + self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None + ) -> None: + """Performs off-line policy update.""" + self.gradient_step += 1 + losses = self.policy.update(self.batch_size, self.buffer) + data.update({"gradient_step": str(self.gradient_step)}) + self.log_update_data(data, losses) + + def run(self) -> Dict[str, Union[float, str]]: + """Consume iterator. + + See itertools - recipes. Use functions that consume iterators at C speed + (feed the entire iterator into a zero-length deque). 
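The zero-length deque is the standard "consume" recipe from the itertools documentation; a self-contained sketch of the idiom, independent of the trainer code:

    from collections import deque

    def consume(iterator):
        # maxlen=0 makes the deque discard every element it is fed, so this
        # simply advances the iterator to exhaustion without storing results
        deque(iterator, maxlen=0)

    consume(iter(range(10)))  # runs the iterator to completion, keeps nothing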
+ """ + try: + self.is_run = True + i = iter(self) + deque(i, maxlen=0) # feed the entire iterator into a zero-length deque + info = gather_info( + self.start_time, None, self.test_collector, self.best_reward, + self.best_reward_std + ) + finally: + self.is_run = False + + return info diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 3dd50c547..2aa8937b2 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -1,18 +1,16 @@ -import time -from collections import defaultdict, deque -from typing import Any, Callable, Dict, Optional, Tuple, Union +from functools import wraps +from typing import Callable, Dict, Optional, Union import numpy as np -import tqdm from tianshou.data import Collector, ReplayBuffer from tianshou.policy import BasePolicy -from tianshou.trainer import gather_info, test_episode -from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config +from tianshou.trainer.base import BaseTrainer +from tianshou.utils import BaseLogger, LazyLogger -class OffLineTrainer: - """An iterator wrapper for offline training procedure. +class OffLineTrainer(BaseTrainer): + """An iterator wrapper for off-line training procedure. Returns an iterator that yields a 3 tuple (epoch, stats, info) of train results on every epoch. @@ -39,7 +37,7 @@ def __init__( logger: BaseLogger = LazyLogger(), verbose: bool = True, ): - """Create an iterator wrapper for offline training procedure. + """Create an iterator wrapper for off-line training procedure. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. @@ -80,148 +78,31 @@ def __init__( updating/testing. Default to a logger that doesn't log anything. :param bool verbose: whether to print the information. Default to True. 
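A hedged usage sketch of the iterator form (argument names follow the documentation above; ``policy``, a pre-populated ``buf: ReplayBuffer`` and ``test_collector`` are assumed to exist):

    trainer = OffLineTrainer(
        policy=policy,
        buffer=buf,                    # experiences collected beforehand (offline RL)
        test_collector=test_collector,
        max_epoch=10,
        update_per_epoch=500,          # gradient steps per epoch
        episode_per_test=10,
        batch_size=64,
    )
    for epoch, epoch_stat, info in trainer:   # or trainer.run() for the one-shot result
        print(epoch, epoch_stat["gradient_step"], info["best_reward"])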
""" - self.is_run = False - self.policy = policy - self.buffer = buffer - self.test_collector = test_collector - self.max_epoch = max_epoch - self.update_per_epoch = update_per_epoch - self.episode_per_test = episode_per_test - self.batch_size = batch_size - self.test_fn = test_fn - self.stop_fn = stop_fn - self.save_fn = save_fn - self.save_checkpoint_fn = save_checkpoint_fn - - self.reward_metric = reward_metric - self.logger = logger - self.verbose = verbose - - self.start_epoch, self.gradient_step = 0, 0 - self.best_reward, self.best_reward_std = 0.0, 0.0 - - if resume_from_log: - self.start_epoch, _, self.gradient_step = logger.restore_data() - self.stat: Dict[str, MovAvg] = defaultdict(MovAvg) - self.start_time = time.time() - - if test_collector is not None: - self.test_c: Collector = test_collector - test_collector.reset_stat() - test_result = test_episode( - self.policy, self.test_c, test_fn, self.start_epoch, - self.episode_per_test, self.logger, self.gradient_step, - self.reward_metric - ) - self.best_epoch = self.start_epoch - self.best_reward, self.best_reward_std = test_result["rew"], test_result[ - "rew_std"] - - if self.save_fn: - self.save_fn(policy) - - self.epoch = self.start_epoch - - def __iter__(self): # type: ignore - return self - - def __next__(self) -> Tuple[int, Dict[str, Any], Dict[str, Any]]: - self.epoch += 1 - - # iterator exhaustion check - if self.epoch >= self.max_epoch: - if self.test_collector is None and self.save_fn: - self.save_fn(self.policy) - raise StopIteration - - # set policy in train mode - self.policy.train() - - # Performs n update_per_epoch - with tqdm.trange( - self.update_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config - ) as t: - for _ in t: - self.gradient_step += 1 - losses = self.policy.update(self.batch_size, self.buffer) - data = {"gradient_step": str(self.gradient_step)} - for k in losses.keys(): - self.stat[k].add(losses[k]) - losses[k] = self.stat[k].get() - data[k] = f"{losses[k]:.3f}" - self.logger.log_update_data(losses, self.gradient_step) - t.set_postfix(**data) - - self.logger.save_data( - self.epoch, 0, self.gradient_step, self.save_checkpoint_fn + learning_type = super().learning_types["offline"] + super().__init__( + learning_type=learning_type, + policy=policy, + buffer=buffer, + test_collector=test_collector, + max_epoch=max_epoch, + update_per_epoch=update_per_epoch, + step_per_epoch=update_per_epoch, + episode_per_test=episode_per_test, + batch_size=batch_size, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + save_checkpoint_fn=save_checkpoint_fn, + resume_from_log=resume_from_log, + reward_metric=reward_metric, + logger=logger, + verbose=verbose, ) - if not self.is_run: - epoch_stat: Dict[str, Any] = {k: v.get() for k, v in self.stat.items()} - epoch_stat["gradient_step"] = self.gradient_step - - # test - if self.test_collector is not None: - test_result = test_episode( - self.policy, self.test_c, self.test_fn, self.epoch, - self.episode_per_test, self.logger, self.gradient_step, - self.reward_metric - ) - rew, rew_std = test_result["rew"], test_result["rew_std"] - if self.best_epoch < 0 or self.best_reward < rew: - self.best_epoch = self.epoch - self.best_reward = rew - self.best_reward_std = rew_std - if self.save_fn: - self.save_fn(self.policy) - if self.verbose: - print( - f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}," - f" best_reward: {self.best_reward:.6f} ± " - f"{self.best_reward_std:.6f} in #{self.best_epoch}" - ) - if not self.is_run: - epoch_stat.update( - { - 
"test_reward": rew, - "test_reward_std": rew_std, - "best_reward": self.best_reward, - "best_reward_std": self.best_reward_std, - "best_epoch": self.best_epoch - } - ) - - # return iterator -> next(self) - if not self.is_run: - info = gather_info( - self.start_time, None, self.test_collector, self.best_reward, - self.best_reward_std - ) - return self.epoch, epoch_stat, info - else: - return 0, {}, {} - - def run(self) -> Dict[str, Union[float, str]]: - """Consume iterator. - - See itertools - recipes. Use functions that consume iterators at C speed - (feed the entire iterator into a zero-length deque). - """ - try: - self.is_run = True - i = iter(self) - deque(i, maxlen=0) # feed the entire iterator into a zero-length deque - info = gather_info( - self.start_time, None, self.test_collector, self.best_reward, - self.best_reward_std - ) - finally: - self.is_run = False - - return info - +@wraps(OffLineTrainer.__init__) def offline_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore + """Wrapper for offline_trainer run method.""" return OffLineTrainer(*args, **kwargs).run() diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index c9c0b69ad..b02cce836 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -1,25 +1,23 @@ -import time -from collections import defaultdict, deque -from typing import Any, Callable, Dict, Optional, Tuple, Union +from functools import wraps +from typing import Callable, Dict, Optional, Union import numpy as np -import tqdm from tianshou.data import Collector from tianshou.policy import BasePolicy -from tianshou.trainer import gather_info, test_episode -from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config +from tianshou.trainer.base import BaseTrainer +from tianshou.utils import BaseLogger, LazyLogger -class OffPolicyTrainer: +class OffPolicyTrainer(BaseTrainer): """An iterator wrapper for off-policy trainer procedure. - Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results - on every epoch. + Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results + on every epoch. - The "step" in trainer means an environment step (a.k.a. transition). + The "step" in trainer means an environment step (a.k.a. transition). - """ + """ def __init__( self, @@ -43,7 +41,7 @@ def __init__( verbose: bool = True, test_in_train: bool = True, ): - """Create an iterator wrapper for offline training procedure. + """Create an iterator wrapper for off-policy training procedure. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param Collector train_collector: the collector used for training. @@ -95,245 +93,34 @@ def __init__( Default to True. 
""" - self.is_run = False - self.policy = policy - - self.train_collector = train_collector - self.test_collector = test_collector - - self.max_epoch = max_epoch - self.step_per_epoch = step_per_epoch - self.step_per_collect = step_per_collect - self.episode_per_test = episode_per_test - self.batch_size = batch_size - self.update_per_step = update_per_step - - self.train_fn = train_fn - self.test_fn = test_fn - self.stop_fn = stop_fn - self.save_fn = save_fn - self.save_checkpoint_fn = save_checkpoint_fn - - self.reward_metric = reward_metric - self.logger = logger - self.verbose = verbose - self.test_in_train = test_in_train - - self.start_epoch, self.env_step, self.gradient_step = 0, 0, 0 - self.best_reward, self.best_reward_std = 0.0, 0.0 - - if resume_from_log: - self.start_epoch, self.env_step, self.gradient_step = logger.restore_data() - self.last_rew, self.last_len = 0.0, 0 - self.stat: Dict[str, MovAvg] = defaultdict(MovAvg) - self.start_time = time.time() - self.train_collector.reset_stat() - self.test_in_train = self.test_in_train and ( - self.train_collector.policy == policy and self.test_collector is not None - ) - - if self.test_collector is not None: - self.test_c: Collector = self.test_collector # for mypy - self.test_collector.reset_stat() - test_result = test_episode( - self.policy, self.test_c, self.test_fn, self.start_epoch, - self.episode_per_test, self.logger, self.env_step, self.reward_metric - ) - self.best_epoch = self.start_epoch - self.best_reward, self.best_reward_std = test_result["rew"], test_result[ - "rew_std"] - if save_fn: - save_fn(policy) - - self.epoch = self.start_epoch - self.exit_flag = 0 - - def __iter__(self): # type: ignore - return self - - def __next__(self) -> Tuple[int, Dict[str, Any], Dict[str, Any]]: - self.epoch += 1 - - if self.epoch > 1: - # exit flag 1, when test_in_train and stop_fn succeeds on result["rew"] - if self.test_in_train and self.stop_fn and self.exit_flag == 1: - raise StopIteration - - # iterator exhaustion check - if self.epoch >= self.max_epoch: - if self.test_collector is None and self.save_fn: - self.save_fn(self.policy) - raise StopIteration - - # stop_fn criterion - if self.test_collector is not None and self.stop_fn and self.stop_fn( - self.best_reward - ): - raise StopIteration - - # set policy in train mode - self.policy.train() - - # Performs n step_per_epoch - with tqdm.tqdm( - total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config - ) as t: - while t.n < t.total: - if self.train_fn: - self.train_fn(self.epoch, self.env_step) - result = self.train_collector.collect(n_step=self.step_per_collect) - if result["n/ep"] > 0 and self.reward_metric: - rew = self.reward_metric(result["rews"]) - result.update(rews=rew, rew=rew.mean(), rew_std=rew.std()) - self.env_step += int(result["n/st"]) - t.update(result["n/st"]) - self.logger.log_train_data(result, self.env_step) - self.last_rew = result['rew'] if result["n/ep"] > 0 else self.last_rew - self.last_len = result['len'] if result["n/ep"] > 0 else self.last_len - data = { - "env_step": str(self.env_step), - "rew": f"{self.last_rew:.2f}", - "len": str(int(self.last_len)), - "n/ep": str(int(result["n/ep"])), - "n/st": str(int(result["n/st"])), - } - if result["n/ep"] > 0: - if self.test_in_train and self.stop_fn and self.stop_fn( - result["rew"] - ): - test_result = test_episode( - self.policy, self.test_c, self.test_fn, self.epoch, - self.episode_per_test, self.logger, self.env_step - ) - if self.stop_fn(test_result["rew"]): - if self.save_fn: - 
self.save_fn(self.policy) - self.logger.save_data( - self.epoch, self.env_step, self.gradient_step, - self.save_checkpoint_fn - ) - t.set_postfix(**data) - if not self.is_run: - epoch_stat: Dict[str, Any] = { - k: v.get() - for k, v in self.stat.items() - } - epoch_stat["gradient_step"] = self.gradient_step - epoch_stat.update( - { - "env_step": self.env_step, - "rew": self.last_rew, - "len": int(self.last_len), - "n/ep": int(result["n/ep"]), - "n/st": int(result["n/st"]), - } - ) - self.exit_flag = 1 - self.best_reward = test_result["rew"] - self.best_reward_std = test_result["rew_std"] - - if not self.is_run: - info = gather_info( - self.start_time, self.train_collector, - self.test_collector, self.best_reward, - self.best_reward_std - ) - return self.epoch, epoch_stat, info - else: - return 0, {}, {} - else: - self.policy.train() - for _ in range(round(self.update_per_step * result["n/st"])): - self.gradient_step += 1 - losses = self.policy.update( - self.batch_size, self.train_collector.buffer - ) - for k in losses.keys(): - self.stat[k].add(losses[k]) - losses[k] = self.stat[k].get() - data[k] = f"{losses[k]:.3f}" - self.logger.log_update_data(losses, self.gradient_step) - t.set_postfix(**data) - if t.n <= t.total: - t.update() - self.logger.save_data( - self.epoch, self.env_step, self.gradient_step, self.save_checkpoint_fn + learning_type = super().learning_types["offpolicy"] + super().__init__( + learning_type=learning_type, + policy=policy, + train_collector=train_collector, + test_collector=test_collector, + max_epoch=max_epoch, + step_per_epoch=step_per_epoch, + step_per_collect=step_per_collect, + episode_per_test=episode_per_test, + batch_size=batch_size, + update_per_step=update_per_step, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + save_checkpoint_fn=save_checkpoint_fn, + resume_from_log=resume_from_log, + reward_metric=reward_metric, + logger=logger, + verbose=verbose, + test_in_train=test_in_train, ) - if not self.is_run: - epoch_stat = {k: v.get() for k, v in self.stat.items()} - epoch_stat["gradient_step"] = self.gradient_step - epoch_stat.update( - { - "env_step": self.env_step, - "rew": self.last_rew, - "len": int(self.last_len), - "n/ep": int(result["n/ep"]), - "n/st": int(result["n/st"]), - } - ) - - # test - if self.test_collector is not None: - test_result = test_episode( - self.policy, self.test_c, self.test_fn, self.epoch, - self.episode_per_test, self.logger, self.env_step, self.reward_metric - ) - rew, rew_std = test_result["rew"], test_result["rew_std"] - if self.best_epoch < 0 or self.best_reward < rew: - self.best_epoch = self.epoch - self.best_reward = float(rew) - self.best_reward_std = rew_std - if self.save_fn: - self.save_fn(self.policy) - if self.verbose: - print( - f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}," - f" best_reward: {self.best_reward:.6f} ± " - f"{self.best_reward_std:.6f} in #{self.best_epoch}" - ) - if not self.is_run: - epoch_stat.update( - { - "test_reward": rew, - "test_reward_std": rew_std, - "best_reward": self.best_reward, - "best_reward_std": self.best_reward_std, - "best_epoch": self.best_epoch - } - ) - - # return iterator -> next(self) - if not self.is_run: - info = gather_info( - self.start_time, self.train_collector, self.test_collector, - self.best_reward, self.best_reward_std - ) - return self.epoch, epoch_stat, info - else: - return 0, {}, {} - - def run(self) -> Dict[str, Union[float, str]]: - """Consume iterator. - - See itertools - recipes. 
Use functions that consume iterators at C speed - (feed the entire iterator into a zero-length deque). - """ - try: - self.is_run = True - i = iter(self) - deque(i, maxlen=0) # feed the entire iterator into a zero-length deque - info = gather_info( - self.start_time, None, self.test_collector, self.best_reward, - self.best_reward_std - ) - finally: - self.is_run = False - - return info - +@wraps(OffPolicyTrainer.__init__) def offpolicy_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore + """Wrapper for OffPolicyTrainer run method.""" return OffPolicyTrainer(*args, **kwargs).run() diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index c7141d84b..504aa7cb1 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -1,451 +1,131 @@ -import time -from collections import defaultdict -from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union +from functools import wraps +from typing import Callable, Dict, Optional, Union import numpy as np -import tqdm from tianshou.data import Collector from tianshou.policy import BasePolicy -from tianshou.trainer import gather_info, test_episode -from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config +from tianshou.trainer.base import BaseTrainer +from tianshou.utils import BaseLogger, LazyLogger -def onpolicy_trainer( - policy: BasePolicy, - train_collector: Collector, - test_collector: Optional[Collector], - max_epoch: int, - step_per_epoch: int, - repeat_per_collect: int, - episode_per_test: int, - batch_size: int, - step_per_collect: Optional[int] = None, - episode_per_collect: Optional[int] = None, - train_fn: Optional[Callable[[int, int], None]] = None, - test_fn: Optional[Callable[[int, Optional[int]], None]] = None, - stop_fn: Optional[Callable[[float], bool]] = None, - save_fn: Optional[Callable[[BasePolicy], None]] = None, - save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, - resume_from_log: bool = False, - reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, - logger: BaseLogger = LazyLogger(), - verbose: bool = True, - test_in_train: bool = True, -) -> Dict[str, Union[float, str]]: - """A wrapper for on-policy trainer procedure. +class OnPolicyTrainer(BaseTrainer): + """An iterator wrapper for On-policy trainer procedure. - The "step" in trainer means an environment step (a.k.a. transition). - - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param Collector train_collector: the collector used for training. - :param Collector test_collector: the collector used for testing. If it's None, then - no testing will be performed. - :param int max_epoch: the maximum number of epochs for training. The training - process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. - :param int step_per_epoch: the number of transitions collected per epoch. - :param int repeat_per_collect: the number of repeat time for policy learning, for - example, set it to 2 means the policy needs to learn each given batch data - twice. - :param int episode_per_test: the number of episodes for one policy evaluation. - :param int batch_size: the batch size of sample data, which is going to feed in the - policy network. - :param int step_per_collect: the number of transitions the collector would collect - before the network update, i.e., trainer will collect "step_per_collect" - transitions and do some policy network update repeatedly in each epoch. 
- :param int episode_per_collect: the number of episodes the collector would collect - before the network update, i.e., trainer will collect "episode_per_collect" - episodes and do some policy network update repeatedly in each epoch. - :param function train_fn: a hook called at the beginning of training in each epoch. - It can be used to perform custom additional operations, with the signature ``f( - num_epoch: int, step_idx: int) -> None``. - :param function test_fn: a hook called at the beginning of testing in each epoch. - It can be used to perform custom additional operations, with the signature ``f( - num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean reward in - evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> - None``. - :param function save_checkpoint_fn: a function to save training process, with the - signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can - save whatever you want. - :param bool resume_from_log: resume env_step/gradient_step and other metadata from - existing tensorboard log. Default to False. - :param function stop_fn: a function with signature ``f(mean_rewards: float) -> - bool``, receives the average undiscounted returns of the testing result, - returns a boolean which indicates whether reaching the goal. - :param function reward_metric: a function with signature ``f(rewards: np.ndarray - with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, - used in multi-agent RL. We need to return a single scalar for each episode's - result to monitor training in the multi-agent RL setting. This function - specifies what is the desired metric, e.g., the reward of agent 1 or the - average reward over all agents. - :param BaseLogger logger: A logger that logs statistics during - training/testing/updating. Default to a logger that doesn't log anything. - :param bool verbose: whether to print the information. Default to True. - :param bool test_in_train: whether to test in the training phase. Default to True. - - :return: See :func:`~tianshou.trainer.gather_info`. - - .. note:: - - Only either one of step_per_collect and episode_per_collect can be specified. 
- """ - start_epoch, env_step, gradient_step = 0, 0, 0 - if resume_from_log: - start_epoch, env_step, gradient_step = logger.restore_data() - last_rew, last_len = 0.0, 0 - stat: Dict[str, MovAvg] = defaultdict(MovAvg) - start_time = time.time() - train_collector.reset_stat() - test_in_train = test_in_train and ( - train_collector.policy == policy and test_collector is not None - ) - - if test_collector is not None: - test_c: Collector = test_collector # for mypy - test_collector.reset_stat() - test_result = test_episode( - policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step, - reward_metric - ) - best_epoch = start_epoch - best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] - if save_fn: - save_fn(policy) - - for epoch in range(1 + start_epoch, 1 + max_epoch): - # train - policy.train() - with tqdm.tqdm( - total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config - ) as t: - while t.n < t.total: - if train_fn: - train_fn(epoch, env_step) - result = train_collector.collect( - n_step=step_per_collect, n_episode=episode_per_collect - ) - if result["n/ep"] > 0 and reward_metric: - rew = reward_metric(result["rews"]) - result.update(rews=rew, rew=rew.mean(), rew_std=rew.std()) - env_step += int(result["n/st"]) - t.update(result["n/st"]) - logger.log_train_data(result, env_step) - last_rew = result['rew'] if result["n/ep"] > 0 else last_rew - last_len = result['len'] if result["n/ep"] > 0 else last_len - data = { - "env_step": str(env_step), - "rew": f"{last_rew:.2f}", - "len": str(int(last_len)), - "n/ep": str(int(result["n/ep"])), - "n/st": str(int(result["n/st"])), - } - if result["n/ep"] > 0: - if test_in_train and stop_fn and stop_fn(result["rew"]): - test_result = test_episode( - policy, test_c, test_fn, epoch, episode_per_test, logger, - env_step - ) - if stop_fn(test_result["rew"]): - if save_fn: - save_fn(policy) - logger.save_data( - epoch, env_step, gradient_step, save_checkpoint_fn - ) - t.set_postfix(**data) - return gather_info( - start_time, train_collector, test_collector, - test_result["rew"], test_result["rew_std"] - ) - else: - policy.train() - losses = policy.update( - 0, - train_collector.buffer, - batch_size=batch_size, - repeat=repeat_per_collect - ) - train_collector.reset_buffer(keep_statistics=True) - step = max( - [1] + [len(v) for v in losses.values() if isinstance(v, list)] - ) - gradient_step += step - for k in losses.keys(): - stat[k].add(losses[k]) - losses[k] = stat[k].get() - data[k] = f"{losses[k]:.3f}" - logger.log_update_data(losses, gradient_step) - t.set_postfix(**data) - if t.n <= t.total: - t.update() - logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) - # test - if test_collector is not None: - test_result = test_episode( - policy, test_c, test_fn, epoch, episode_per_test, logger, env_step, - reward_metric - ) - rew, rew_std = test_result["rew"], test_result["rew_std"] - if best_epoch < 0 or best_reward < rew: - best_epoch, best_reward, best_reward_std = epoch, rew, rew_std - if save_fn: - save_fn(policy) - if verbose: - print( - f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" - f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" - ) - if stop_fn and stop_fn(best_reward): - break - - if test_collector is None and save_fn: - save_fn(policy) - - if test_collector is None: - return gather_info(start_time, train_collector, None, 0.0, 0.0) - else: - return gather_info( - start_time, train_collector, test_collector, best_reward, best_reward_std - ) - - -def 
onpolicy_trainer_generator( - policy: BasePolicy, - train_collector: Collector, - test_collector: Optional[Collector], - max_epoch: int, - step_per_epoch: int, - repeat_per_collect: int, - episode_per_test: int, - batch_size: int, - step_per_collect: Optional[int] = None, - episode_per_collect: Optional[int] = None, - train_fn: Optional[Callable[[int, int], None]] = None, - test_fn: Optional[Callable[[int, Optional[int]], None]] = None, - stop_fn: Optional[Callable[[float], bool]] = None, - save_fn: Optional[Callable[[BasePolicy], None]] = None, - save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, - resume_from_log: bool = False, - reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, - logger: BaseLogger = LazyLogger(), - verbose: bool = True, - test_in_train: bool = True, -) -> Generator[Tuple[int, Dict[str, Any], Dict[str, Any]], None, None]: - """A generator wrapper for on-policy trainer procedure. - - Returns a generator that yields a 3-tuple (epoch, stats, info) of train results + Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results on every epoch. The "step" in trainer means an environment step (a.k.a. transition). - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param Collector train_collector: the collector used for training. - :param Collector test_collector: the collector used for testing. If it's None, then - no testing will be performed. - :param int max_epoch: the maximum number of epochs for training. The training - process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. - :param int step_per_epoch: the number of transitions collected per epoch. - :param int repeat_per_collect: the number of repeat time for policy learning, for - example, set it to 2 means the policy needs to learn each given batch data - twice. - :param int episode_per_test: the number of episodes for one policy evaluation. - :param int batch_size: the batch size of sample data, which is going to feed in the - policy network. - :param int step_per_collect: the number of transitions the collector would collect - before the network update, i.e., trainer will collect "step_per_collect" - transitions and do some policy network update repeatedly in each epoch. - :param int episode_per_collect: the number of episodes the collector would collect - before the network update, i.e., trainer will collect "episode_per_collect" - episodes and do some policy network update repeatedly in each epoch. - :param function train_fn: a hook called at the beginning of training in each epoch. - It can be used to perform custom additional operations, with the signature ``f( - num_epoch: int, step_idx: int) -> None``. - :param function test_fn: a hook called at the beginning of testing in each epoch. - It can be used to perform custom additional operations, with the signature ``f( - num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean reward in - evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> - None``. - :param function save_checkpoint_fn: a function to save training process, with the - signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can - save whatever you want. - :param bool resume_from_log: resume env_step/gradient_step and other metadata from - existing tensorboard log. Default to False. 
- :param function stop_fn: a function with signature ``f(mean_rewards: float) -> - bool``, receives the average undiscounted returns of the testing result, - returns a boolean which indicates whether reaching the goal. - :param function reward_metric: a function with signature ``f(rewards: np.ndarray - with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, - used in multi-agent RL. We need to return a single scalar for each episode's - result to monitor training in the multi-agent RL setting. This function - specifies what is the desired metric, e.g., the reward of agent 1 or the - average reward over all agents. - :param BaseLogger logger: A logger that logs statistics during - training/testing/updating. Default to a logger that doesn't log anything. - :param bool verbose: whether to print the information. Default to True. - :param bool test_in_train: whether to test in the training phase. Default to True. - - :return: See :func:`~tianshou.trainer.gather_info`. - - .. note:: - - Only either one of step_per_collect and episode_per_collect can be specified. """ - start_epoch, env_step, gradient_step = 0, 0, 0 - if resume_from_log: - start_epoch, env_step, gradient_step = logger.restore_data() - last_rew, last_len = 0.0, 0 - stat: Dict[str, MovAvg] = defaultdict(MovAvg) - start_time = time.time() - train_collector.reset_stat() - test_in_train = test_in_train and ( - train_collector.policy == policy and test_collector is not None - ) - if test_collector is not None: - test_c: Collector = test_collector # for mypy - test_collector.reset_stat() - test_result = test_episode( - policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step, - reward_metric + def __init__( + self, + policy: BasePolicy, + train_collector: Collector, + test_collector: Optional[Collector], + max_epoch: int, + step_per_epoch: int, + repeat_per_collect: int, + episode_per_test: int, + batch_size: int, + step_per_collect: Optional[int] = None, + episode_per_collect: Optional[int] = None, + train_fn: Optional[Callable[[int, int], None]] = None, + test_fn: Optional[Callable[[int, Optional[int]], None]] = None, + stop_fn: Optional[Callable[[float], bool]] = None, + save_fn: Optional[Callable[[BasePolicy], None]] = None, + save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + resume_from_log: bool = False, + reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, + logger: BaseLogger = LazyLogger(), + verbose: bool = True, + test_in_train: bool = True, + ): + """Create an iterator wrapper for on-policy training procedure. + + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param Collector train_collector: the collector used for training. + :param Collector test_collector: the collector used for testing. If it's None, + then no testing will be performed. + :param int max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is + set. + :param int step_per_epoch: the number of transitions collected per epoch. + :param int repeat_per_collect: the number of repeat time for policy learning, + for example, set it to 2 means the policy needs to learn each given batch + data twice. + :param int episode_per_test: the number of episodes for one policy evaluation. + :param int batch_size: the batch size of sample data, which is going to feed in + the policy network. 
+ :param int step_per_collect: the number of transitions the collector would + collect before the network update, i.e., trainer will collect + "step_per_collect" transitions and do some policy network update repeatedly + in each epoch. + :param int episode_per_collect: the number of episodes the collector would + collect before the network update, i.e., trainer will collect + "episode_per_collect" episodes and do some policy network update repeatedly + in each epoch. + :param function train_fn: a hook called at the beginning of training in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean + reward in evaluation phase gets better, with the signature + ``f(policy: BasePolicy) -> None``. + :param function save_checkpoint_fn: a function to save training process, + with the signature ``f(epoch: int, env_step: int, gradient_step: int) + -> None``; you can save whatever you want. + :param bool resume_from_log: resume env_step/gradient_step and other metadata + from existing tensorboard log. Default to False. + :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether reaching the goal. + :param function reward_metric: a function with signature + ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> + np.ndarray with shape (num_episode,)``, used in multi-agent RL. + We need to return a single scalar for each episode's result to monitor + training in the multi-agent RL setting. This function specifies what is the + desired metric, e.g., the reward of agent 1 or the average reward over + all agents. + :param BaseLogger logger: A logger that logs statistics during + training/testing/updating. Default to a logger that doesn't log anything. + :param bool verbose: whether to print the information. Default to True. + :param bool test_in_train: whether to test in the training phase. Default to + True. 
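After this refactor the functional entry point is a thin wrapper around the class, so the two call styles below are interchangeable; a hedged sketch (``policy`` and the two collectors are assumed to be set up as in the PPO example earlier in the series, and the numeric values are placeholders):

    result = onpolicy_trainer(
        policy, train_collector, test_collector,
        max_epoch=10, step_per_epoch=50000, repeat_per_collect=10,
        episode_per_test=10, batch_size=64, episode_per_collect=16,
    )
    # ... is the same as constructing the iterator explicitly and draining it:
    result = OnPolicyTrainer(
        policy, train_collector, test_collector,
        max_epoch=10, step_per_epoch=50000, repeat_per_collect=10,
        episode_per_test=10, batch_size=64, episode_per_collect=16,
    ).run()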
+ """ + learning_type = super().learning_types["onpolicy"] + super().__init__( + learning_type=learning_type, + policy=policy, + train_collector=train_collector, + test_collector=test_collector, + max_epoch=max_epoch, + step_per_epoch=step_per_epoch, + repeat_per_collect=repeat_per_collect, + episode_per_test=episode_per_test, + batch_size=batch_size, + step_per_collect=step_per_collect, + episode_per_collect=episode_per_collect, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + save_checkpoint_fn=save_checkpoint_fn, + resume_from_log=resume_from_log, + reward_metric=reward_metric, + logger=logger, + verbose=verbose, + test_in_train=test_in_train, ) - best_epoch = start_epoch - best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] - if save_fn: - save_fn(policy) - for epoch in range(1 + start_epoch, 1 + max_epoch): - # train - policy.train() - with tqdm.tqdm( - total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config - ) as t: - while t.n < t.total: - if train_fn: - train_fn(epoch, env_step) - result = train_collector.collect( - n_step=step_per_collect, n_episode=episode_per_collect - ) - if result["n/ep"] > 0 and reward_metric: - rew = reward_metric(result["rews"]) - result.update(rews=rew, rew=rew.mean(), rew_std=rew.std()) - env_step += int(result["n/st"]) - t.update(result["n/st"]) - logger.log_train_data(result, env_step) - last_rew = result['rew'] if result["n/ep"] > 0 else last_rew - last_len = result['len'] if result["n/ep"] > 0 else last_len - data = { - "env_step": str(env_step), - "rew": f"{last_rew:.2f}", - "len": str(int(last_len)), - "n/ep": str(int(result["n/ep"])), - "n/st": str(int(result["n/st"])), - } - if result["n/ep"] > 0: - if test_in_train and stop_fn and stop_fn(result["rew"]): - test_result = test_episode( - policy, test_c, test_fn, epoch, episode_per_test, logger, - env_step - ) - if stop_fn(test_result["rew"]): - if save_fn: - save_fn(policy) - logger.save_data( - epoch, env_step, gradient_step, save_checkpoint_fn - ) - t.set_postfix(**data) - # epoch_stat for yield clause - epoch_stat: Dict[str, Any] = { - k: v.get() - for k, v in stat.items() - } - epoch_stat["gradient_step"] = gradient_step - epoch_stat.update( - { - "env_step": env_step, - "rew": last_rew, - "len": int(last_len), - "n/ep": int(result["n/ep"]), - "n/st": int(result["n/st"]), - } - ) - info = gather_info( - start_time, train_collector, test_collector, - test_result["rew"], test_result["rew_std"] - ) - yield epoch, epoch_stat, info - return - else: - policy.train() - losses = policy.update( - 0, - train_collector.buffer, - batch_size=batch_size, - repeat=repeat_per_collect - ) - train_collector.reset_buffer(keep_statistics=True) - step = max( - [1] + [len(v) for v in losses.values() if isinstance(v, list)] - ) - gradient_step += step - for k in losses.keys(): - stat[k].add(losses[k]) - losses[k] = stat[k].get() - data[k] = f"{losses[k]:.3f}" - logger.log_update_data(losses, gradient_step) - t.set_postfix(**data) - if t.n <= t.total: - t.update() - logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) - # epoch_stat for yield clause - epoch_stat = {k: v.get() for k, v in stat.items()} - epoch_stat["gradient_step"] = gradient_step - epoch_stat.update( - { - "env_step": env_step, - "rew": last_rew, - "len": int(last_len), - "n/ep": int(result["n/ep"]), - "n/st": int(result["n/st"]), - } - ) - # test - if test_collector is not None: - test_result = test_episode( - policy, test_c, test_fn, epoch, episode_per_test, logger, env_step, 
- reward_metric - ) - rew, rew_std = test_result["rew"], test_result["rew_std"] - if best_epoch < 0 or best_reward < rew: - best_epoch, best_reward, best_reward_std = epoch, rew, rew_std - if save_fn: - save_fn(policy) - if verbose: - print( - f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" - f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" - ) - epoch_stat.update( - { - "test_reward": rew, - "test_reward_std": rew_std, - "best_reward": best_reward, - "best_reward_std": best_reward_std, - "best_epoch": best_epoch - } - ) - if test_collector is None: - info = gather_info(start_time, train_collector, None, 0.0, 0.0) - else: - info = gather_info( - start_time, train_collector, test_collector, best_reward, - best_reward_std - ) - yield epoch, epoch_stat, info +@wraps(OnPolicyTrainer.__init__) +def onpolicy_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore + """Wrapper for OnPolicyTrainer run method.""" + return OnPolicyTrainer(*args, **kwargs).run() - if test_collector is not None and stop_fn and stop_fn(best_reward): - break - if test_collector is None and save_fn: - save_fn(policy) +onpolicy_trainer_iter = OnPolicyTrainer From b4fa395c9b37a4214c163762837f87a80a1aafe1 Mon Sep 17 00:00:00 2001 From: R107333 Date: Tue, 8 Mar 2022 21:43:35 +0100 Subject: [PATCH 21/40] fix formatting --- tianshou/trainer/base.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 50dc10f8e..18f9b27fb 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -185,20 +185,15 @@ def reset(self) -> None: self.logger.restore_data() self.last_rew, self.last_len = 0.0, 0 - if self.train_collector is not None: self.train_collector.reset_stat() - self.test_in_train = ( - self.test_in_train and - self.train_collector.policy == self.policy and - self.test_collector is not None - ) - - else: - self.test_in_train = False + if self.train_collector.policy != self.policy: + self.test_in_train = False + elif self.test_collector is None: + self.test_in_train = False if self.test_collector is not None: - assert self.episode_per_test + assert self.episode_per_test is not None self.test_collector.reset_stat() test_result = test_episode( self.policy, self.test_collector, self.test_fn, self.start_epoch, From c1f5f25740281624754ec00246e21ecc40a4ed48 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Tue, 8 Mar 2022 18:13:36 -0500 Subject: [PATCH 22/40] docs --- Makefile | 2 +- tianshou/trainer/base.py | 15 +++++++-------- tianshou/trainer/offline.py | 8 ++++---- tianshou/trainer/offpolicy.py | 12 +++++------- tianshou/trainer/onpolicy.py | 14 +++++++------- 5 files changed, 24 insertions(+), 27 deletions(-) diff --git a/Makefile b/Makefile index b3d8d96b2..b4c4a623e 100644 --- a/Makefile +++ b/Makefile @@ -55,6 +55,6 @@ doc-clean: clean: doc-clean -commit-checks: format lint mypy check-docstyle spelling +commit-checks: lint check-codestyle mypy check-docstyle spelling .PHONY: clean spelling doc mypy lint format check-codestyle check-docstyle commit-checks diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 18f9b27fb..81b94c678 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -11,7 +11,7 @@ from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config -class BaseTrainer: +class BaseTrainer(object): """An iterator base class for trainers procedure. 
Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results @@ -19,10 +19,10 @@ class BaseTrainer: The "step" in trainer means an environment step (a.k.a. transition). There are three types of learning iterators: - (1) offpolicy learning trainer - (2) onpolicy learning trainer - (3) offpolicy learning trainer + 1. offpolicy learning trainer + 2. onpolicy learning trainer + 3. offpolicy learning trainer """ learning_types: Dict[Union[int, str], Union[int, str]] = { @@ -82,11 +82,11 @@ def __init__( :param int step_per_collect: the number of transitions the collector would collect before the network update, i.e., trainer will collect "step_per_collect" transitions and do some policy network update repeatedly - in each epoch. + in each epoch. :param int episode_per_collect: the number of episodes the collector would collect before the network update, i.e., trainer will collect "episode_per_collect" episodes and do some policy network update repeatedly - in each epoch. + in each epoch. :param function train_fn: a hook called at the beginning of training in each epoch. It can be used to perform custom additional operations, with the signature ``f(num_epoch: int, step_idx: int) -> None``. @@ -114,8 +114,7 @@ def __init__( training/testing/updating. Default to a logger that doesn't log anything. :param bool verbose: whether to print the information. Default to True. :param bool test_in_train: whether to test in the training phase. - Default to True. - + Default to True. """ self.policy = policy self.buffer = buffer diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 2aa8937b2..4c6732d18 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -60,11 +60,11 @@ def __init__( reward in evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> None``. :param function save_checkpoint_fn: a function to save training process, - with the signature ``f(epoch: int, env_step: int, - gradient_step: int) -> None``; you can save whatever you want. Because - offline-RL doesn't have env_step, the env_step is always 0 here. + with the signature ``f(epoch: int, env_step: int, gradient_step: int) -> + None``; you can save whatever you want. Because offline-RL doesn't have + env_step, the env_step is always 0 here. :param bool resume_from_log: resume gradient_step and other metadata from - existing tensorboard log. Default to False. + existing tensorboard log. Default to False. :param function stop_fn: a function with signature ``f(mean_rewards: float) -> bool``, receives the average undiscounted returns of the testing result, returns a boolean which indicates whether reaching the goal. diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index b02cce836..ac087a9e9 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -16,7 +16,6 @@ class OffPolicyTrainer(BaseTrainer): on every epoch. The "step" in trainer means an environment step (a.k.a. transition). - """ def __init__( @@ -54,7 +53,7 @@ def __init__( :param int step_per_collect: the number of transitions the collector would collect before the network update, i.e., trainer will collect "step_per_collect" transitions and do some policy network update repeatedly - in each epoch. + in each epoch. :param episode_per_test: the number of episodes for one policy evaluation. :param int batch_size: the batch size of sample data, which is going to feed in the policy network. 
@@ -65,16 +64,16 @@ def __init__( transitions are collected by the collector. Default to 1. :param function train_fn: a hook called at the beginning of training in each epoch. It can be used to perform custom additional operations, with the - signature ``f(num_epoch: int, step_idx: int) -> None``. + signature ``f(num_epoch: int, step_idx: int) -> None``. :param function test_fn: a hook called at the beginning of testing in each epoch. It can be used to perform custom additional operations, with the signature ``f(num_epoch: int, step_idx: int) -> None``. :param function save_fn: a hook called when the undiscounted average mean - reward in evaluation phase gets better, with the signature - ``f(policy: BasePolicy) -> None``. + reward in evaluation phase gets better, with the signature + ``f(policy: BasePolicy) -> None``. :param function save_checkpoint_fn: a function to save training process, with the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; - you can save whatever you want. + you can save whatever you want. :param bool resume_from_log: resume env_step/gradient_step and other metadata from existing tensorboard log. Default to False. :param function stop_fn: a function with signature ``f(mean_rewards: float) -> @@ -91,7 +90,6 @@ def __init__( :param bool verbose: whether to print the information. Default to True. :param bool test_in_train: whether to test in the training phase. Default to True. - """ learning_type = super().learning_types["offpolicy"] super().__init__( diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 504aa7cb1..847c26666 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -59,12 +59,12 @@ def __init__( :param int batch_size: the batch size of sample data, which is going to feed in the policy network. :param int step_per_collect: the number of transitions the collector would - collect before the network update, i.e., trainer will collect - "step_per_collect" transitions and do some policy network update repeatedly - in each epoch. + collect before the network update, i.e., trainer will collect + "step_per_collect" transitions and do some policy network update repeatedly + in each epoch. :param int episode_per_collect: the number of episodes the collector would - collect before the network update, i.e., trainer will collect - "episode_per_collect" episodes and do some policy network update repeatedly + collect before the network update, i.e., trainer will collect + "episode_per_collect" episodes and do some policy network update repeatedly in each epoch. :param function train_fn: a hook called at the beginning of training in each epoch. It can be used to perform custom additional operations, with the @@ -74,7 +74,7 @@ def __init__( signature ``f(num_epoch: int, step_idx: int) -> None``. :param function save_fn: a hook called when the undiscounted average mean reward in evaluation phase gets better, with the signature - ``f(policy: BasePolicy) -> None``. + ``f(policy: BasePolicy) -> None``. :param function save_checkpoint_fn: a function to save training process, with the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can save whatever you want. @@ -84,7 +84,7 @@ def __init__( bool``, receives the average undiscounted returns of the testing result, returns a boolean which indicates whether reaching the goal. 
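# Minimal example implementations of the ``stop_fn`` hook documented above and
# the ``reward_metric`` hook documented next (sketches only; the reward
# threshold and the averaging over agents are arbitrary choices).
import numpy as np

def stop_fn(mean_rewards: float) -> bool:
    # stop training once the average test reward reaches a task-specific bar
    return mean_rewards >= 195.0

def reward_metric(rewards: np.ndarray) -> np.ndarray:
    # (num_episode, agent_num) -> (num_episode,): mean reward over all agents
    return rewards.mean(axis=1)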
:param function reward_metric: a function with signature - ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> + ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, used in multi-agent RL. We need to return a single scalar for each episode's result to monitor training in the multi-agent RL setting. This function specifies what is the From b12beb12588b679ee404e1a9aa643d8e036e6f5b Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Tue, 8 Mar 2022 22:09:56 -0500 Subject: [PATCH 23/40] fix missing import --- docs/spelling_wordlist.txt | 1 + tianshou/trainer/__init__.py | 27 +++++++++++++++++++-------- tianshou/trainer/base.py | 8 ++++---- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 86bdb281e..a3f96b433 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -30,6 +30,7 @@ dqn param async subprocess +deque nn equ cql diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index 76449a95e..895f59ab3 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -1,20 +1,31 @@ """Trainer package.""" -# isort:skip_file - -from tianshou.trainer.utils import test_episode, gather_info -from tianshou.trainer.onpolicy import onpolicy_trainer, onpolicy_trainer_iter -from tianshou.trainer.offpolicy import offpolicy_trainer, offpolicy_trainer_iter,\ - OffPolicyTrainer -from tianshou.trainer.offline import offline_trainer, offline_trainer_iter,\ - OffLineTrainer +from tianshou.trainer.base import BaseTrainer +from tianshou.trainer.offline import ( + OffLineTrainer, + offline_trainer, + offline_trainer_iter, +) +from tianshou.trainer.offpolicy import ( + OffPolicyTrainer, + offpolicy_trainer, + offpolicy_trainer_iter, +) +from tianshou.trainer.onpolicy import ( + OnPolicyTrainer, + onpolicy_trainer, + onpolicy_trainer_iter, +) +from tianshou.trainer.utils import gather_info, test_episode __all__ = [ + "BaseTrainer", "offpolicy_trainer", "offpolicy_trainer_iter", "OffPolicyTrainer", "onpolicy_trainer", "onpolicy_trainer_iter", + "OnPolicyTrainer", "offline_trainer", "offline_trainer_iter", "OffLineTrainer", diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 81b94c678..853ceeb59 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -7,7 +7,7 @@ from tianshou.data import Collector, ReplayBuffer from tianshou.policy import BasePolicy -from tianshou.trainer import gather_info, test_episode +from tianshou.trainer.utils import gather_info, test_episode from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config @@ -20,9 +20,9 @@ class BaseTrainer(object): The "step" in trainer means an environment step (a.k.a. transition). There are three types of learning iterators: - 1. offpolicy learning trainer - 2. onpolicy learning trainer - 3. offpolicy learning trainer + 1. off-policy learning trainer + 2. on-policy learning trainer + 3. 
offline learning trainer """ learning_types: Dict[Union[int, str], Union[int, str]] = { From a4ae2e39bfc8eabaf3f500bb8c845b8ede55bd65 Mon Sep 17 00:00:00 2001 From: R107333 Date: Sat, 12 Mar 2022 14:43:06 +0100 Subject: [PATCH 24/40] * fix formatting * fix docs * fix drop "test_c" --- tianshou/trainer/base.py | 49 ++++++++++++++++++++++++++++++----- tianshou/trainer/offline.py | 1 + tianshou/trainer/offpolicy.py | 1 + tianshou/trainer/onpolicy.py | 1 + 4 files changed, 45 insertions(+), 7 deletions(-) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 853ceeb59..b5ea3aa6d 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -1,6 +1,6 @@ import time from collections import defaultdict, deque -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Union import numpy as np import tqdm @@ -23,6 +23,36 @@ class BaseTrainer(object): 1. off-policy learning trainer 2. on-policy learning trainer 3. offline learning trainer + + Examples: + + :: + + trainer = onpolicy_trainer_generator(...) + for epoch, epoch_stat, info in trainer: + print(f"Epoch: {epoch}") + print(epoch_stat) + print(info) + do_something_with_policy() + query_something_about_policy() + make_a_plot_with(epoch_stat) + display(info) + + + - epoch int: the epoch number + - epoch_stat dict: a large collection of metrics of the current epoch + - info dict: gather_info + + You can even iterate on several trainers at the same time: + + :: + + trainer1 = onpolicy_trainer_generator(...) + trainer2 = onpolicy_trainer_generator(...) + for result1,result2,... in zip(trainer1,trainer2,...): + compare_results(result1,result2,...) + + """ learning_types: Dict[Union[int, str], Union[int, str]] = { @@ -116,6 +146,7 @@ def __init__( :param bool test_in_train: whether to test in the training phase. Default to True. 
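# A self-contained toy mirroring the iterator contract sketched in the
# Examples block above: every iteration yields a 3-tuple
# (epoch, epoch_stat, info).  ``dummy_trainer`` is purely illustrative and not
# part of the library.
from typing import Any, Dict, Iterator, Tuple

def dummy_trainer(max_epoch: int) -> Iterator[Tuple[int, Dict[str, Any], Dict[str, Any]]]:
    for epoch in range(1, max_epoch + 1):
        epoch_stat = {"loss": 1.0 / epoch}      # stands in for per-epoch metrics
        info = {"best_reward": float(epoch)}    # stands in for gather_info() output
        yield epoch, epoch_stat, info

for epoch, epoch_stat, info in dummy_trainer(3):
    print(f"Epoch: {epoch}", epoch_stat, info)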
""" + self.policy = policy self.buffer = buffer @@ -124,7 +155,7 @@ def __init__( self.logger = logger self.start_time = time.time() - self.stat: Dict[str, MovAvg] = defaultdict(MovAvg) + self.stat: DefaultDict[str, MovAvg] = defaultdict(MovAvg) self.best_reward = 0.0 self.best_reward_std = 0.0 self.start_epoch = 0 @@ -158,10 +189,10 @@ def __init__( self.is_run = False self.last_rew, self.last_len = 0.0, 0 self.env_step = 0 - self.test_c = self.test_collector self.epoch = self.start_epoch self.best_epoch = self.start_epoch self.stop_fn_flag = 0 + self.iter_num = 0 self.update_function: Dict[Union[int, str], Callable] = { 0: self.offpolicy_update, @@ -173,7 +204,9 @@ def __init__( } assert learning_type in self.learning_types self.learning_type = learning_type - self.policy_update_fn = self.update_function[self.learning_type] + self.policy_update_fn: Callable[[Any, Any], + None] = self.update_function[self.learning_type + ] def reset(self) -> None: """Initialize or reset the instance to yield a new iterator from zero.""" @@ -205,6 +238,7 @@ def reset(self) -> None: self.save_fn(self.policy) self.epoch = self.start_epoch self.stop_fn_flag = 0 + self.iter_num = 0 def __iter__(self): # type: ignore self.reset() @@ -212,8 +246,9 @@ def __iter__(self): # type: ignore def __next__(self) -> Tuple[int, Dict[str, Any], Dict[str, Any]]: self.epoch += 1 + self.iter_num += 1 - if self.epoch > 1: + if self.iter_num > 1: # iterator exhaustion check if self.epoch >= self.max_epoch: @@ -359,9 +394,9 @@ def train_step(self) -> Tuple[Dict[str, Any], Dict[str, Any], bool]: } if result["n/ep"] > 0: if self.test_in_train and self.stop_fn and self.stop_fn(result["rew"]): - assert self.test_c is not None + assert self.test_collector is not None test_result = test_episode( - self.policy, self.test_c, self.test_fn, self.epoch, + self.policy, self.test_collector, self.test_fn, self.epoch, self.episode_per_test, self.logger, self.env_step ) if self.stop_fn(test_result["rew"]): diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 4c6732d18..03fa1bdb4 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -18,6 +18,7 @@ class OffLineTrainer(BaseTrainer): The "step" in offline trainer means a gradient step. """ + __doc__ = BaseTrainer.__doc__ def __init__( self, diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index ac087a9e9..7b7acbfe3 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -17,6 +17,7 @@ class OffPolicyTrainer(BaseTrainer): The "step" in trainer means an environment step (a.k.a. transition). """ + __doc__ = BaseTrainer.__doc__ def __init__( self, diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 847c26666..ec79ab90f 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -18,6 +18,7 @@ class OnPolicyTrainer(BaseTrainer): The "step" in trainer means an environment step (a.k.a. transition). 
""" + __doc__ = BaseTrainer.__doc__ def __init__( self, From c902d6199c08dc53ed01e627d1d288bb2cc80a85 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Sat, 12 Mar 2022 10:44:48 -0500 Subject: [PATCH 25/40] update docs --- docs/spelling_wordlist.txt | 2 + tianshou/trainer/__init__.py | 12 +-- tianshou/trainer/base.py | 188 ++++++++++++++++++---------------- tianshou/trainer/offline.py | 102 +++++++++--------- tianshou/trainer/offpolicy.py | 121 +++++++++++----------- tianshou/trainer/onpolicy.py | 128 +++++++++++------------ 6 files changed, 273 insertions(+), 280 deletions(-) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index a3f96b433..820fb8867 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -24,6 +24,8 @@ fqf iqn qrdqn rl +offpolicy +onpolicy quantile quantiles dqn diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index 895f59ab3..8f1361bec 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -2,17 +2,17 @@ from tianshou.trainer.base import BaseTrainer from tianshou.trainer.offline import ( - OffLineTrainer, + OfflineTrainer, offline_trainer, offline_trainer_iter, ) from tianshou.trainer.offpolicy import ( - OffPolicyTrainer, + OffpolicyTrainer, offpolicy_trainer, offpolicy_trainer_iter, ) from tianshou.trainer.onpolicy import ( - OnPolicyTrainer, + OnpolicyTrainer, onpolicy_trainer, onpolicy_trainer_iter, ) @@ -22,13 +22,13 @@ "BaseTrainer", "offpolicy_trainer", "offpolicy_trainer_iter", - "OffPolicyTrainer", + "OffpolicyTrainer", "onpolicy_trainer", "onpolicy_trainer_iter", - "OnPolicyTrainer", + "OnpolicyTrainer", "offline_trainer", "offline_trainer_iter", - "OffLineTrainer", + "OfflineTrainer", "test_episode", "gather_info", ] diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index b5ea3aa6d..bc110f7a6 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -17,52 +17,116 @@ class BaseTrainer(object): Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results on every epoch. - The "step" in trainer means an environment step (a.k.a. transition). - There are three types of learning iterators: + :param learning_type int|str: type of learning iterator, 0,1,2 for "offpolicy", + "onpolicy" and "offline" respectively + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param Collector train_collector: the collector used for training. + :param Collector test_collector: the collector used for testing. If it's None, + then no testing will be performed. + :param int max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` + is set. + :param int step_per_epoch: the number of transitions collected per epoch. + :param int repeat_per_collect: the number of repeat time for policy learning, + for example, set it to 2 means the policy needs to learn each given batch + data twice. + :param int episode_per_test: the number of episodes for one policy evaluation. + :param int batch_size: the batch size of sample data, which is going to feed in + the policy network. + :param int step_per_collect: the number of transitions the collector would + collect before the network update, i.e., trainer will collect + "step_per_collect" transitions and do some policy network update repeatedly + in each epoch. 
+ :param int episode_per_collect: the number of episodes the collector would + collect before the network update, i.e., trainer will collect + "episode_per_collect" episodes and do some policy network update repeatedly + in each epoch. + :param function train_fn: a hook called at the beginning of training in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean + reward in evaluation phase gets better, with the signature + ``f(policy: BasePolicy) -> None``. + :param function save_checkpoint_fn: a function to save training process, with + the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; + you can save whatever you want. + :param bool resume_from_log: resume env_step/gradient_step and other metadata + from existing tensorboard log. Default to False. + :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether reaching the goal. + :param function reward_metric: a function with signature + ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray + with shape (num_episode,)``, used in multi-agent RL. We need to return a + single scalar for each episode's result to monitor training in the + multi-agent RL setting. This function specifies what is the desired metric, + e.g., the reward of agent 1 or the average reward over all agents. + :param BaseLogger logger: A logger that logs statistics during + training/testing/updating. Default to a logger that doesn't log anything. + :param bool verbose: whether to print the information. Default to True. + :param bool test_in_train: whether to test in the training phase. + Default to True. + """ - 1. off-policy learning trainer - 2. on-policy learning trainer - 3. offline learning trainer + learning_types: Dict[Union[int, str], Union[int, str]] = { + 0: "offpolicy", + "offpolicy": 0, + 1: "onpolicy", + "onpolicy": 1, + 2: "offline", + "offline": 2, + } - Examples: + @staticmethod + def gen_doc(learning_type: Union[int, str]) -> str: + if isinstance(learning_type, int): + learning_type = BaseTrainer.learning_types[learning_type] - :: + step_means = f'The "step" in {learning_type} trainer means ' + if learning_type != 2: + step_means += "an environment step (a.k.a. transition)." + else: # offline + step_means += "a gradient step." - trainer = onpolicy_trainer_generator(...) - for epoch, epoch_stat, info in trainer: - print(f"Epoch: {epoch}") - print(epoch_stat) - print(info) - do_something_with_policy() - query_something_about_policy() - make_a_plot_with(epoch_stat) - display(info) + trainer_name = learning_type.capitalize() + "Trainer" # type: ignore + return f"""An iterator class for {learning_type} trainer procedure. - - epoch int: the epoch number - - epoch_stat dict: a large collection of metrics of the current epoch - - info dict: gather_info + Returns an iterator that yields a 3-tuple (epoch, stats, info) of + train results on every epoch. - You can even iterate on several trainers at the same time: + {step_means} - :: + Example usage: - trainer1 = onpolicy_trainer_generator(...) 
- trainer2 = onpolicy_trainer_generator(...) - for result1,result2,... in zip(trainer1,trainer2,...): - compare_results(result1,result2,...) + :: + trainer = {trainer_name}(...) + for epoch, epoch_stat, info in trainer: + print("Epoch:", epoch) + print(epoch_stat) + print(info) + do_something_with_policy() + query_something_about_policy() + make_a_plot_with(epoch_stat) + display(info) - """ + - epoch int: the epoch number + - epoch_stat dict: a large collection of metrics of the current epoch + - info dict: :func:`~tianshou.trainer.gather_info` result - learning_types: Dict[Union[int, str], Union[int, str]] = { - 0: "offpolicy", - "offpolicy": 0, - 1: "onpolicy", - "onpolicy": 1, - 2: "offline", - "offline": 2, - } + You can even iterate on several trainers at the same time: + + :: + + trainer1 = {trainer_name}(...) + trainer2 = {trainer_name}(...) + for result1, result2, ... in zip(trainer1, trainer2, ...): + compare_results(result1, result2, ...) + """ def __init__( self, @@ -91,62 +155,6 @@ def __init__( verbose: bool = True, test_in_train: bool = True, ): - """Create an iterator wrapper for training procedure. - - :param learning_type int|str: type of learning iterator, 0,1,2 for "offpolicy", - "onpolicy" and "offline" respectively - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param Collector train_collector: the collector used for training. - :param Collector test_collector: the collector used for testing. If it's None, - then no testing will be performed. - :param int max_epoch: the maximum number of epochs for training. The training - process might be finished before reaching ``max_epoch`` if ``stop_fn`` - is set. - :param int step_per_epoch: the number of transitions collected per epoch. - :param int repeat_per_collect: the number of repeat time for policy learning, - for example, set it to 2 means the policy needs to learn each given batch - data twice. - :param int episode_per_test: the number of episodes for one policy evaluation. - :param int batch_size: the batch size of sample data, which is going to feed in - the policy network. - :param int step_per_collect: the number of transitions the collector would - collect before the network update, i.e., trainer will collect - "step_per_collect" transitions and do some policy network update repeatedly - in each epoch. - :param int episode_per_collect: the number of episodes the collector would - collect before the network update, i.e., trainer will collect - "episode_per_collect" episodes and do some policy network update repeatedly - in each epoch. - :param function train_fn: a hook called at the beginning of training in each - epoch. It can be used to perform custom additional operations, with the - signature ``f(num_epoch: int, step_idx: int) -> None``. - :param function test_fn: a hook called at the beginning of testing in each - epoch. It can be used to perform custom additional operations, with the - signature ``f(num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean - reward in evaluation phase gets better, with the signature - ``f(policy: BasePolicy) -> None``. - :param function save_checkpoint_fn: a function to save training process, with - the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; - you can save whatever you want. - :param bool resume_from_log: resume env_step/gradient_step and other metadata - from existing tensorboard log. Default to False. 
- :param function stop_fn: a function with signature ``f(mean_rewards: float) -> - bool``, receives the average undiscounted returns of the testing result, - returns a boolean which indicates whether reaching the goal. - :param function reward_metric: a function with signature - ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray - with shape (num_episode,)``, used in multi-agent RL. We need to return a - single scalar for each episode's result to monitor training in the - multi-agent RL setting. This function specifies what is the desired metric, - e.g., the reward of agent 1 or the average reward over all agents. - :param BaseLogger logger: A logger that logs statistics during - training/testing/updating. Default to a logger that doesn't log anything. - :param bool verbose: whether to print the information. Default to True. - :param bool test_in_train: whether to test in the training phase. - Default to True. - """ - self.policy = policy self.buffer = buffer diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 03fa1bdb4..78f802743 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -1,4 +1,3 @@ -from functools import wraps from typing import Callable, Dict, Optional, Union import numpy as np @@ -9,16 +8,50 @@ from tianshou.utils import BaseLogger, LazyLogger -class OffLineTrainer(BaseTrainer): - """An iterator wrapper for off-line training procedure. - - Returns an iterator that yields a 3 tuple (epoch, stats, info) of train results - on every epoch. - - The "step" in offline trainer means a gradient step. +class OfflineTrainer(BaseTrainer): + """Create an iterator class for offline training procedure. + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. + This buffer must be populated with experiences for offline RL. + :param Collector test_collector: the collector used for testing. If it's None, + then no testing will be performed. + :param int max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is + set. + :param int update_per_epoch: the number of policy network updates, so-called + gradient steps, per epoch. + :param episode_per_test: the number of episodes for one policy evaluation. + :param int batch_size: the batch size of sample data, which is going to feed in + the policy network. + :param function test_fn: a hook called at the beginning of testing in each + epoch. + It can be used to perform custom additional operations, with the signature + ``f(num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean + reward in evaluation phase gets better, with the signature + ``f(policy: BasePolicy) -> None``. + :param function save_checkpoint_fn: a function to save training process, + with the signature ``f(epoch: int, env_step: int, gradient_step: int) -> + None``; you can save whatever you want. Because offline-RL doesn't have + env_step, the env_step is always 0 here. + :param bool resume_from_log: resume gradient_step and other metadata from + existing tensorboard log. Default to False. + :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether reaching the goal. 
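# A sketch of a ``save_checkpoint_fn`` hook matching the signature documented
# above (the file name and the choice to save only the counters are
# illustrative; in practice one would also save the policy/optimizer state).
import torch

def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> None:
    # invoked via the logger when training progress is saved;
    # env_step is always 0 for offline RL, per the docstring above
    torch.save(
        {"epoch": epoch, "env_step": env_step, "gradient_step": gradient_step},
        "checkpoint.pth",
    )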
+ :param function reward_metric: a function with signature ``f(rewards: + np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape + (num_episode,)``, used in multi-agent RL. We need to return a single scalar + for each episode's result to monitor training in the multi-agent RL + setting. This function specifies what is the desired metric, e.g., the + reward of agent 1 or the average reward over all agents. + :param BaseLogger logger: A logger that logs statistics during + updating/testing. Default to a logger that doesn't log anything. + :param bool verbose: whether to print the information. Default to True. """ - __doc__ = BaseTrainer.__doc__ + + __doc__ = BaseTrainer.gen_doc("offline") + "\n".join(__doc__.split("\n")[1:]) def __init__( self, @@ -38,47 +71,6 @@ def __init__( logger: BaseLogger = LazyLogger(), verbose: bool = True, ): - """Create an iterator wrapper for off-line training procedure. - - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. - This buffer must be populated with experiences for offline RL. - :param Collector test_collector: the collector used for testing. If it's None, - then no testing will be performed. - :param int max_epoch: the maximum number of epochs for training. The training - process might be finished before reaching ``max_epoch`` if ``stop_fn`` is - set. - :param int update_per_epoch: the number of policy network updates, so-called - gradient steps, per epoch. - :param episode_per_test: the number of episodes for one policy evaluation. - :param int batch_size: the batch size of sample data, which is going to feed in - the policy network. - :param function test_fn: a hook called at the beginning of testing in each - epoch. - It can be used to perform custom additional operations, with the signature - ``f(num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean - reward in evaluation phase gets better, with the signature - ``f(policy: BasePolicy) -> None``. - :param function save_checkpoint_fn: a function to save training process, - with the signature ``f(epoch: int, env_step: int, gradient_step: int) -> - None``; you can save whatever you want. Because offline-RL doesn't have - env_step, the env_step is always 0 here. - :param bool resume_from_log: resume gradient_step and other metadata from - existing tensorboard log. Default to False. - :param function stop_fn: a function with signature ``f(mean_rewards: float) -> - bool``, receives the average undiscounted returns of the testing result, - returns a boolean which indicates whether reaching the goal. - :param function reward_metric: a function with signature ``f(rewards: - np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape - (num_episode,)``, used in multi-agent RL. We need to return a single scalar - for each episode's result to monitor training in the multi-agent RL - setting. This function specifies what is the desired metric, e.g., the - reward of agent 1 or the average reward over all agents. - :param BaseLogger logger: A logger that logs statistics during - updating/testing. Default to a logger that doesn't log anything. - :param bool verbose: whether to print the information. Default to True. 
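# A small self-contained illustration of the ``__doc__`` composition used just
# below: a header generated by the base class is prepended while the first
# line of the subclass docstring is dropped and the parameter docs are kept
# verbatim.  ``Base`` and ``Sub`` are toy stand-ins, not library classes.
class Base:
    @staticmethod
    def gen_doc(kind: str) -> str:
        return f"An iterator class for {kind} trainer procedure.\n"

class Sub(Base):
    """This first line is replaced by the generated header.

    :param example: parameter documentation is kept as written.
    """
    __doc__ = Base.gen_doc("offline") + "\n".join(__doc__.split("\n")[1:])

print(Sub.__doc__)  # begins with "An iterator class for offline trainer procedure."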
- """ learning_type = super().learning_types["offline"] super().__init__( learning_type=learning_type, @@ -101,10 +93,12 @@ def __init__( ) -@wraps(OffLineTrainer.__init__) def offline_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore - """Wrapper for offline_trainer run method.""" - return OffLineTrainer(*args, **kwargs).run() + """Wrapper for offline_trainer run method. + + It is identical to ``OfflineTrainer(...).run()``. + """ + return OfflineTrainer(*args, **kwargs).run() -offline_trainer_iter = OffLineTrainer +offline_trainer_iter = OfflineTrainer diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 7b7acbfe3..775cafad2 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -1,4 +1,3 @@ -from functools import wraps from typing import Callable, Dict, Optional, Union import numpy as np @@ -9,15 +8,60 @@ from tianshou.utils import BaseLogger, LazyLogger -class OffPolicyTrainer(BaseTrainer): - """An iterator wrapper for off-policy trainer procedure. +class OffpolicyTrainer(BaseTrainer): + """Create an iterator wrapper for off-policy training procedure. - Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results - on every epoch. - - The "step" in trainer means an environment step (a.k.a. transition). + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param Collector train_collector: the collector used for training. + :param Collector test_collector: the collector used for testing. If it's None, + then no testing will be performed. + :param int max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is + set. + :param int step_per_epoch: the number of transitions collected per epoch. + :param int step_per_collect: the number of transitions the collector would + collect before the network update, i.e., trainer will collect + "step_per_collect" transitions and do some policy network update repeatedly + in each epoch. + :param episode_per_test: the number of episodes for one policy evaluation. + :param int batch_size: the batch size of sample data, which is going to feed in + the policy network. + :param int/float update_per_step: the number of times the policy network would + be updated per transition after (step_per_collect) transitions are + collected, e.g., if update_per_step set to 0.3, and step_per_collect is 256 + , policy will be updated round(256 * 0.3 = 76.8) = 77 times after 256 + transitions are collected by the collector. Default to 1. + :param function train_fn: a hook called at the beginning of training in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean + reward in evaluation phase gets better, with the signature + ``f(policy: BasePolicy) -> None``. + :param function save_checkpoint_fn: a function to save training process, with + the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; + you can save whatever you want. + :param bool resume_from_log: resume env_step/gradient_step and other metadata + from existing tensorboard log. Default to False. 
+ :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether reaching the goal. + :param function reward_metric: a function with signature + ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> + np.ndarray with shape (num_episode,)``, used in multi-agent RL. We need to + return a single scalar for each episode's result to monitor training in the + multi-agent RL setting. This function specifies what is the desired metric, + e.g., the reward of agent 1 or the average reward over all agents. + :param BaseLogger logger: A logger that logs statistics during + training/testing/updating. Default to a logger that doesn't log anything. + :param bool verbose: whether to print the information. Default to True. + :param bool test_in_train: whether to test in the training phase. + Default to True. """ - __doc__ = BaseTrainer.__doc__ + + __doc__ = BaseTrainer.gen_doc("offpolicy") + "\n".join(__doc__.split("\n")[1:]) def __init__( self, @@ -41,57 +85,6 @@ def __init__( verbose: bool = True, test_in_train: bool = True, ): - """Create an iterator wrapper for off-policy training procedure. - - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param Collector train_collector: the collector used for training. - :param Collector test_collector: the collector used for testing. If it's None, - then no testing will be performed. - :param int max_epoch: the maximum number of epochs for training. The training - process might be finished before reaching ``max_epoch`` if ``stop_fn`` is - set. - :param int step_per_epoch: the number of transitions collected per epoch. - :param int step_per_collect: the number of transitions the collector would - collect before the network update, i.e., trainer will collect - "step_per_collect" transitions and do some policy network update repeatedly - in each epoch. - :param episode_per_test: the number of episodes for one policy evaluation. - :param int batch_size: the batch size of sample data, which is going to feed in - the policy network. - :param int/float update_per_step: the number of times the policy network would - be updated per transition after (step_per_collect) transitions are - collected, e.g., if update_per_step set to 0.3, and step_per_collect is 256 - , policy will be updated round(256 * 0.3 = 76.8) = 77 times after 256 - transitions are collected by the collector. Default to 1. - :param function train_fn: a hook called at the beginning of training in each - epoch. It can be used to perform custom additional operations, with the - signature ``f(num_epoch: int, step_idx: int) -> None``. - :param function test_fn: a hook called at the beginning of testing in each - epoch. It can be used to perform custom additional operations, with the - signature ``f(num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean - reward in evaluation phase gets better, with the signature - ``f(policy: BasePolicy) -> None``. - :param function save_checkpoint_fn: a function to save training process, with - the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; - you can save whatever you want. - :param bool resume_from_log: resume env_step/gradient_step and other metadata - from existing tensorboard log. Default to False. 
- :param function stop_fn: a function with signature ``f(mean_rewards: float) -> - bool``, receives the average undiscounted returns of the testing result, - returns a boolean which indicates whether reaching the goal. - :param function reward_metric: a function with signature - ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> - np.ndarray with shape (num_episode,)``, used in multi-agent RL. We need to - return a single scalar for each episode's result to monitor training in the - multi-agent RL setting. This function specifies what is the desired metric, - e.g., the reward of agent 1 or the average reward over all agents. - :param BaseLogger logger: A logger that logs statistics during - training/testing/updating. Default to a logger that doesn't log anything. - :param bool verbose: whether to print the information. Default to True. - :param bool test_in_train: whether to test in the training phase. - Default to True. - """ learning_type = super().learning_types["offpolicy"] super().__init__( learning_type=learning_type, @@ -117,10 +110,12 @@ def __init__( ) -@wraps(OffPolicyTrainer.__init__) def offpolicy_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore - """Wrapper for OffPolicyTrainer run method.""" - return OffPolicyTrainer(*args, **kwargs).run() + """Wrapper for OffPolicyTrainer run method. + + It is identical to ``OffpolicyTrainer(...).run()``. + """ + return OffpolicyTrainer(*args, **kwargs).run() -offpolicy_trainer_iter = OffPolicyTrainer +offpolicy_trainer_iter = OffpolicyTrainer diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index ec79ab90f..3e286a5fc 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -1,4 +1,3 @@ -from functools import wraps from typing import Callable, Dict, Optional, Union import numpy as np @@ -9,16 +8,63 @@ from tianshou.utils import BaseLogger, LazyLogger -class OnPolicyTrainer(BaseTrainer): - """An iterator wrapper for On-policy trainer procedure. - - Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results - on every epoch. - - The "step" in trainer means an environment step (a.k.a. transition). +class OnpolicyTrainer(BaseTrainer): + """Create an iterator wrapper for on-policy training procedure. + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param Collector train_collector: the collector used for training. + :param Collector test_collector: the collector used for testing. If it's None, + then no testing will be performed. + :param int max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is + set. + :param int step_per_epoch: the number of transitions collected per epoch. + :param int repeat_per_collect: the number of repeat time for policy learning, + for example, set it to 2 means the policy needs to learn each given batch + data twice. + :param int episode_per_test: the number of episodes for one policy evaluation. + :param int batch_size: the batch size of sample data, which is going to feed in + the policy network. + :param int step_per_collect: the number of transitions the collector would + collect before the network update, i.e., trainer will collect + "step_per_collect" transitions and do some policy network update repeatedly + in each epoch. 
+ :param int episode_per_collect: the number of episodes the collector would + collect before the network update, i.e., trainer will collect + "episode_per_collect" episodes and do some policy network update repeatedly + in each epoch. + :param function train_fn: a hook called at the beginning of training in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean + reward in evaluation phase gets better, with the signature + ``f(policy: BasePolicy) -> None``. + :param function save_checkpoint_fn: a function to save training process, + with the signature ``f(epoch: int, env_step: int, gradient_step: int) + -> None``; you can save whatever you want. + :param bool resume_from_log: resume env_step/gradient_step and other metadata + from existing tensorboard log. Default to False. + :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether reaching the goal. + :param function reward_metric: a function with signature + ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> + np.ndarray with shape (num_episode,)``, used in multi-agent RL. + We need to return a single scalar for each episode's result to monitor + training in the multi-agent RL setting. This function specifies what is the + desired metric, e.g., the reward of agent 1 or the average reward over + all agents. + :param BaseLogger logger: A logger that logs statistics during + training/testing/updating. Default to a logger that doesn't log anything. + :param bool verbose: whether to print the information. Default to True. + :param bool test_in_train: whether to test in the training phase. Default to + True. """ - __doc__ = BaseTrainer.__doc__ + + __doc__ = BaseTrainer.gen_doc("onpolicy") + "\n".join(__doc__.split("\n")[1:]) def __init__( self, @@ -43,60 +89,6 @@ def __init__( verbose: bool = True, test_in_train: bool = True, ): - """Create an iterator wrapper for on-policy training procedure. - - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param Collector train_collector: the collector used for training. - :param Collector test_collector: the collector used for testing. If it's None, - then no testing will be performed. - :param int max_epoch: the maximum number of epochs for training. The training - process might be finished before reaching ``max_epoch`` if ``stop_fn`` is - set. - :param int step_per_epoch: the number of transitions collected per epoch. - :param int repeat_per_collect: the number of repeat time for policy learning, - for example, set it to 2 means the policy needs to learn each given batch - data twice. - :param int episode_per_test: the number of episodes for one policy evaluation. - :param int batch_size: the batch size of sample data, which is going to feed in - the policy network. - :param int step_per_collect: the number of transitions the collector would - collect before the network update, i.e., trainer will collect - "step_per_collect" transitions and do some policy network update repeatedly - in each epoch. 
- :param int episode_per_collect: the number of episodes the collector would - collect before the network update, i.e., trainer will collect - "episode_per_collect" episodes and do some policy network update repeatedly - in each epoch. - :param function train_fn: a hook called at the beginning of training in each - epoch. It can be used to perform custom additional operations, with the - signature ``f(num_epoch: int, step_idx: int) -> None``. - :param function test_fn: a hook called at the beginning of testing in each - epoch. It can be used to perform custom additional operations, with the - signature ``f(num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean - reward in evaluation phase gets better, with the signature - ``f(policy: BasePolicy) -> None``. - :param function save_checkpoint_fn: a function to save training process, - with the signature ``f(epoch: int, env_step: int, gradient_step: int) - -> None``; you can save whatever you want. - :param bool resume_from_log: resume env_step/gradient_step and other metadata - from existing tensorboard log. Default to False. - :param function stop_fn: a function with signature ``f(mean_rewards: float) -> - bool``, receives the average undiscounted returns of the testing result, - returns a boolean which indicates whether reaching the goal. - :param function reward_metric: a function with signature - ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> - np.ndarray with shape (num_episode,)``, used in multi-agent RL. - We need to return a single scalar for each episode's result to monitor - training in the multi-agent RL setting. This function specifies what is the - desired metric, e.g., the reward of agent 1 or the average reward over - all agents. - :param BaseLogger logger: A logger that logs statistics during - training/testing/updating. Default to a logger that doesn't log anything. - :param bool verbose: whether to print the information. Default to True. - :param bool test_in_train: whether to test in the training phase. Default to - True. - """ learning_type = super().learning_types["onpolicy"] super().__init__( learning_type=learning_type, @@ -123,10 +115,12 @@ def __init__( ) -@wraps(OnPolicyTrainer.__init__) def onpolicy_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore - """Wrapper for OnPolicyTrainer run method.""" - return OnPolicyTrainer(*args, **kwargs).run() + """Wrapper for OnpolicyTrainer run method. + + It is identical to ``OnpolicyTrainer(...).run()``. + """ + return OnpolicyTrainer(*args, **kwargs).run() -onpolicy_trainer_iter = OnPolicyTrainer +onpolicy_trainer_iter = OnpolicyTrainer From a3e7e2cb339a137f3faa8256a9ff909a29d55d66 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Sat, 12 Mar 2022 11:06:47 -0500 Subject: [PATCH 26/40] update rst --- docs/api/tianshou.trainer.rst | 44 ++++++++++++++++++++++++++++++++++- docs/tutorials/concepts.rst | 20 ++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/docs/api/tianshou.trainer.rst b/docs/api/tianshou.trainer.rst index 9deed5053..13c6d66c9 100644 --- a/docs/api/tianshou.trainer.rst +++ b/docs/api/tianshou.trainer.rst @@ -1,7 +1,49 @@ tianshou.trainer ================ -.. automodule:: tianshou.trainer + +On-policy +--------- + +.. autoclass:: tianshou.trainer.OnpolicyTrainer + :members: + :undoc-members: + :show-inheritance: + +.. autofunction:: tianshou.trainer.onpolicy_trainer + +.. 
autoclass:: tianshou.trainer.onpolicy_trainer_iter + + +Off-policy +---------- + +.. autoclass:: tianshou.trainer.OffpolicyTrainer + :members: + :undoc-members: + :show-inheritance: + +.. autofunction:: tianshou.trainer.offpolicy_trainer + +.. autoclass:: tianshou.trainer.offpolicy_trainer_iter + + +Offline +------- + +.. autoclass:: tianshou.trainer.OfflineTrainer :members: :undoc-members: :show-inheritance: + +.. autofunction:: tianshou.trainer.offline_trainer + +.. autoclass:: tianshou.trainer.offline_trainer_iter + + +utils +----- + +.. autofunction:: tianshou.trainer.test_episode + +.. autofunction:: tianshou.trainer.gather_info diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index d500787ee..cb6d616fe 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -380,6 +380,26 @@ Once you have a collector and a policy, you can start writing the training metho Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` for on-policy algorithms such as Policy Gradient, :func:`~tianshou.trainer.offpolicy_trainer` for off-policy algorithms such as DQN, and :func:`~tianshou.trainer.offline_trainer` for offline algorithms such as BCQ. Please check out :doc:`/api/tianshou.trainer` for the usage. +We also provide the corresponding iterator-based trainer classes :class:`~tianshou.trainer.OnpolicyTrainer`, :class:`~tianshou.trainer.OffpolicyTrainer`, :class:`~tianshou.trainer.OfflineTrainer` to facilitate users writing more flexible training logic: +:: + + trainer = OnpolicyTrainer(...) + for epoch, epoch_stat, info in trainer: + print(f"Epoch: {epoch}") + print(epoch_stat) + print(info) + do_something_with_policy() + query_something_about_policy() + make_a_plot_with(epoch_stat) + display(info) + + # or even iterate on several trainers at the same time + + trainer1 = OnpolicyTrainer(...) + trainer2 = OnpolicyTrainer(...) + for result1, result2, ... in zip(trainer1, trainer2, ...): + compare_results(result1, result2, ...) + .. 
_pseudocode: From 651726f954c59be0dff9f58251dc81b73d7d9f34 Mon Sep 17 00:00:00 2001 From: R107333 Date: Sat, 12 Mar 2022 20:03:48 +0100 Subject: [PATCH 27/40] fix early stopping during train [train_step] --- tianshou/trainer/base.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index bc110f7a6..e8dd389d8 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -168,7 +168,7 @@ def __init__( self.best_reward_std = 0.0 self.start_epoch = 0 self.gradient_step = 0 - + self.env_step = 0 self.max_epoch = max_epoch self.step_per_epoch = step_per_epoch @@ -196,7 +196,7 @@ def __init__( self.is_run = False self.last_rew, self.last_len = 0.0, 0 - self.env_step = 0 + self.epoch = self.start_epoch self.best_epoch = self.start_epoch self.stop_fn_flag = 0 @@ -225,8 +225,10 @@ def reset(self) -> None: self.logger.restore_data() self.last_rew, self.last_len = 0.0, 0 + self.start_time = time.time() if self.train_collector is not None: self.train_collector.reset_stat() + if self.train_collector.policy != self.policy: self.test_in_train = False elif self.test_collector is None: @@ -244,6 +246,7 @@ def reset(self) -> None: "rew_std"] if self.save_fn: self.save_fn(self.policy) + self.epoch = self.start_epoch self.stop_fn_flag = 0 self.iter_num = 0 @@ -288,6 +291,8 @@ def __next__(self) -> Tuple[int, Dict[str, Any], Dict[str, Any]]: result: Dict[str, Any] = dict() if self.train_collector is not None: data, result, self.stop_fn_flag = self.train_step() + if self.stop_fn_flag: + break t.update(result["n/st"]) else: assert self.buffer From e6b00e2e28ba63461bfac601834adf412d0747bc Mon Sep 17 00:00:00 2001 From: R107333 Date: Sat, 12 Mar 2022 20:26:12 +0100 Subject: [PATCH 28/40] * fix early stopping during train train_step * Simplify return logic --- tianshou/trainer/base.py | 29 ++++++++--------------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index e8dd389d8..d04a0ac7a 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -306,9 +306,14 @@ def __next__(self) -> Tuple[int, Dict[str, Any], Dict[str, Any]]: if t.n <= t.total: t.update() - self.logger.save_data( - self.epoch, self.env_step, self.gradient_step, self.save_checkpoint_fn - ) + if not self.stop_fn_flag: + self.logger.save_data( + self.epoch, self.env_step, self.gradient_step, self.save_checkpoint_fn + ) + # test + if self.test_collector is not None: + test_stat = self.test_step() + epoch_stat.update(test_stat) if not self.is_run: epoch_stat.update({k: v.get() for k, v in self.stat.items()}) @@ -322,24 +327,6 @@ def __next__(self) -> Tuple[int, Dict[str, Any], Dict[str, Any]]: "n/st": int(result["n/st"]), } ) - - if self.stop_fn_flag: - if not self.is_run: - info = gather_info( - self.start_time, self.train_collector, self.test_collector, - self.best_reward, self.best_reward_std - ) - return self.epoch, epoch_stat, info - else: - return 0, {}, {} - - # test - if self.test_collector is not None: - test_stat = self.test_step() - epoch_stat.update(test_stat) - - # return iterator -> next(self) - if not self.is_run: info = gather_info( self.start_time, self.train_collector, self.test_collector, self.best_reward, self.best_reward_std From 4d768431c2fc8100dc979424ce0c165ce2ae9786 Mon Sep 17 00:00:00 2001 From: R107333 Date: Sat, 12 Mar 2022 20:37:23 +0100 Subject: [PATCH 29/40] * fix early stopping during train train_step * Simplify return logic --- tianshou/trainer/base.py 
| 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index d04a0ac7a..6fd3d6d44 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -295,7 +295,7 @@ def __next__(self) -> Tuple[int, Dict[str, Any], Dict[str, Any]]: break t.update(result["n/st"]) else: - assert self.buffer + assert self.buffer, "No train_collector or buffer specified" result["n/ep"] = len(self.buffer) result["n/st"] = int(self.gradient_step) t.update() @@ -444,6 +444,7 @@ def offline_update( self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None ) -> None: """Performs off-line policy update.""" + assert self.buffer self.gradient_step += 1 losses = self.policy.update(self.batch_size, self.buffer) data.update({"gradient_step": str(self.gradient_step)}) From 23ce4836a32208ffc917e7f45916c07cc874fb2e Mon Sep 17 00:00:00 2001 From: R107333 Date: Sat, 12 Mar 2022 20:52:24 +0100 Subject: [PATCH 30/40] * fix early stopping during train train_step * Simplify return logic --- tianshou/trainer/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 6fd3d6d44..5c86dc9ef 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -255,7 +255,7 @@ def __iter__(self): # type: ignore self.reset() return self - def __next__(self) -> Tuple[int, Dict[str, Any], Dict[str, Any]]: + def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]: self.epoch += 1 self.iter_num += 1 @@ -333,7 +333,7 @@ def __next__(self) -> Tuple[int, Dict[str, Any], Dict[str, Any]]: ) return self.epoch, epoch_stat, info else: - return 0, {}, {} + return None def test_step(self) -> Dict[str, Any]: """Performs a testing step.""" From 1d707f8c240360f1af3a61c184fd6421bd582adf Mon Sep 17 00:00:00 2001 From: R107333 Date: Sun, 13 Mar 2022 08:51:42 +0100 Subject: [PATCH 31/40] * fix early stopping during train train_step * Simplify return logic --- tianshou/trainer/base.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 5c86dc9ef..f0ffdadc6 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -292,6 +292,7 @@ def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]: if self.train_collector is not None: data, result, self.stop_fn_flag = self.train_step() if self.stop_fn_flag: + t.set_postfix(**data) break t.update(result["n/st"]) else: @@ -306,10 +307,11 @@ def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]: if t.n <= t.total: t.update() + self.logger.save_data( + self.epoch, self.env_step, self.gradient_step, self.save_checkpoint_fn + ) + if not self.stop_fn_flag: - self.logger.save_data( - self.epoch, self.env_step, self.gradient_step, self.save_checkpoint_fn - ) # test if self.test_collector is not None: test_stat = self.test_step() From 479b794ae6fc0e20f72c8011fe21ba75ac67dd6e Mon Sep 17 00:00:00 2001 From: R107333 Date: Sun, 13 Mar 2022 11:32:38 +0100 Subject: [PATCH 32/40] * fix early stopping during train train_step * Simplify return logic --- tianshou/trainer/base.py | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index f0ffdadc6..ac247e3dd 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -199,7 +199,7 @@ def __init__( self.epoch = self.start_epoch self.best_epoch = self.start_epoch - 
self.stop_fn_flag = 0 + self.stop_fn_flag = False self.iter_num = 0 self.update_function: Dict[Union[int, str], Callable] = { @@ -248,7 +248,7 @@ def reset(self) -> None: self.save_fn(self.policy) self.epoch = self.start_epoch - self.stop_fn_flag = 0 + self.stop_fn_flag = False self.iter_num = 0 def __iter__(self): # type: ignore @@ -267,8 +267,8 @@ def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]: self.save_fn(self.policy) raise StopIteration - # exit flag 1, when test_in_train and stop_fn succeeds on result["rew"] - if self.test_in_train and self.stop_fn and self.stop_fn_flag == 1: + # exit flag 1, when stop_fn succeeds in train_step or test_step + if self.stop_fn_flag: raise StopIteration # stop_fn criterion @@ -285,16 +285,15 @@ def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]: with tqdm.tqdm( total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config ) as t: - while t.n < t.total and not self.stop_fn_flag: data: Dict[str, Any] = dict() result: Dict[str, Any] = dict() if self.train_collector is not None: data, result, self.stop_fn_flag = self.train_step() + t.update(result["n/st"]) if self.stop_fn_flag: t.set_postfix(**data) break - t.update(result["n/st"]) else: assert self.buffer, "No train_collector or buffer specified" result["n/ep"] = len(self.buffer) @@ -304,18 +303,18 @@ def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]: self.policy_update_fn(data, result) t.set_postfix(**data) - if t.n <= t.total: + if t.n <= t.total and not self.stop_fn_flag: t.update() - self.logger.save_data( - self.epoch, self.env_step, self.gradient_step, self.save_checkpoint_fn - ) - if not self.stop_fn_flag: + self.logger.save_data( + self.epoch, self.env_step, self.gradient_step, self.save_checkpoint_fn + ) # test if self.test_collector is not None: - test_stat = self.test_step() - epoch_stat.update(test_stat) + test_stat, self.stop_fn_flag = self.test_step() + if not self.is_run: + epoch_stat.update(test_stat) if not self.is_run: epoch_stat.update({k: v.get() for k, v in self.stat.items()}) @@ -337,10 +336,11 @@ def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]: else: return None - def test_step(self) -> Dict[str, Any]: + def test_step(self) -> Tuple[Dict[str, Any], bool]: """Performs a testing step.""" assert self.episode_per_test is not None assert self.test_collector is not None + stop_fn_flag = False test_result = test_episode( self.policy, self.test_collector, self.test_fn, self.epoch, self.episode_per_test, self.logger, self.env_step, self.reward_metric @@ -368,7 +368,10 @@ def test_step(self) -> Dict[str, Any]: } else: test_stat = {} - return test_stat + if self.stop_fn and self.stop_fn(self.best_reward): + stop_fn_flag = True + + return test_stat, stop_fn_flag def train_step(self) -> Tuple[Dict[str, Any], Dict[str, Any], bool]: """Performs 1 training step.""" @@ -460,8 +463,7 @@ def run(self) -> Dict[str, Union[float, str]]: """ try: self.is_run = True - i = iter(self) - deque(i, maxlen=0) # feed the entire iterator into a zero-length deque + deque(self, maxlen=0) # feed the entire iterator into a zero-length deque info = gather_info( self.start_time, None, self.test_collector, self.best_reward, self.best_reward_std From 08f65a6ffcfe8896fde3738a08eb8df048581545 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Wed, 16 Mar 2022 10:44:54 -0400 Subject: [PATCH 33/40] fix a bug in BaseTrainer.run return value missing --- test/continuous/test_ppo.py | 27 
++++----------------------- test/continuous/test_td3.py | 30 ++++++------------------------ test/offline/test_cql.py | 26 +++++--------------------- tianshou/trainer/base.py | 34 +++++++++++++++++----------------- 4 files changed, 32 insertions(+), 85 deletions(-) diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 715460658..f187b0f68 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -11,7 +11,7 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import PPOPolicy -from tianshou.trainer import onpolicy_trainer, onpolicy_trainer_iter +from tianshou.trainer import OnpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -157,25 +157,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): print("Fail to restore policy and optim.") # trainer - result = onpolicy_trainer( - policy, - train_collector, - test_collector, - args.epoch, - args.step_per_epoch, - args.repeat_per_collect, - args.test_num, - args.batch_size, - episode_per_collect=args.episode_per_collect, - stop_fn=stop_fn, - save_fn=save_fn, - logger=logger, - resume_from_log=args.resume, - save_checkpoint_fn=save_checkpoint_fn - ) - assert stop_fn(result['best_reward']) - - trainer = onpolicy_trainer_iter( + trainer = OnpolicyTrainer( policy, train_collector, test_collector, @@ -197,11 +179,10 @@ def save_checkpoint_fn(epoch, env_step, gradient_step): print(epoch_stat) print(info) - result_iter = info - assert stop_fn(result_iter['best_reward']) + assert stop_fn(info["best_reward"]) if __name__ == '__main__': - pprint.pprint(result) + pprint.pprint(info) # Let's watch its performance! env = gym.make(args.task) policy.eval() diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index e7eb19f34..d7ee186fc 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -11,7 +11,7 @@ from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise from tianshou.policy import TD3Policy -from tianshou.trainer import offpolicy_trainer, offpolicy_trainer_iter +from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import Actor, Critic @@ -135,25 +135,8 @@ def save_fn(policy): def stop_fn(mean_rewards): return mean_rewards >= args.reward_threshold - # trainer - result = offpolicy_trainer( - policy, - train_collector, - test_collector, - args.epoch, - args.step_per_epoch, - args.step_per_collect, - args.test_num, - args.batch_size, - update_per_step=args.update_per_step, - stop_fn=stop_fn, - save_fn=save_fn, - logger=logger - ) - assert stop_fn(result['best_reward']) - # Iterator trainer - trainer = offpolicy_trainer_iter( + trainer = OffpolicyTrainer( policy, train_collector, test_collector, @@ -165,18 +148,17 @@ def stop_fn(mean_rewards): update_per_step=args.update_per_step, stop_fn=stop_fn, save_fn=save_fn, - logger=logger + logger=logger, ) for epoch, epoch_stat, info in trainer: print(f"Epoch: {epoch}") print(epoch_stat) print(info) - result_iter = info - assert stop_fn(result_iter['best_reward']) + assert stop_fn(info["best_reward"]) - if __name__ == '__main__': - pprint.pprint(result) + if __name__ == "__main__": + pprint.pprint(info) # Let's watch its performance! 
env = gym.make(args.task) policy.eval() diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 32875ff79..91a2784df 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -12,7 +12,7 @@ from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import CQLPolicy -from tianshou.trainer import offline_trainer, offline_trainer_iter +from tianshou.trainer import OfflineTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic @@ -195,22 +195,7 @@ def watch(): collector.collect(n_episode=1, render=1 / 35) # trainer - result = offline_trainer( - policy, - buffer, - test_collector, - args.epoch, - args.step_per_epoch, - args.test_num, - args.batch_size, - save_fn=save_fn, - stop_fn=stop_fn, - logger=logger, - ) - assert stop_fn(result['best_reward']) - - # trainer - trainer = offline_trainer_iter( + trainer = OfflineTrainer( policy, buffer, test_collector, @@ -228,12 +213,11 @@ def watch(): print(epoch_stat) print(info) - result_iter = info - assert stop_fn(result_iter['best_reward']) + assert stop_fn(info["best_reward"]) # Let's watch its performance! - if __name__ == '__main__': - pprint.pprint(result) + if __name__ == "__main__": + pprint.pprint(info) env = gym.make(args.task) policy.eval() collector = Collector(policy, env) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index ac247e3dd..cb474e731 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -116,7 +116,7 @@ def gen_doc(learning_type: Union[int, str]) -> str: - epoch int: the epoch number - epoch_stat dict: a large collection of metrics of the current epoch - - info dict: :func:`~tianshou.trainer.gather_info` result + - info dict: result returned from :func:`~tianshou.trainer.gather_info` You can even iterate on several trainers at the same time: @@ -202,7 +202,7 @@ def __init__( self.stop_fn_flag = False self.iter_num = 0 - self.update_function: Dict[Union[int, str], Callable] = { + update_function: Dict[Union[int, str], Callable] = { 0: self.offpolicy_update, "offpolicy": self.offpolicy_update, 1: self.onpolicy_update, @@ -212,16 +212,15 @@ def __init__( } assert learning_type in self.learning_types self.learning_type = learning_type - self.policy_update_fn: Callable[[Any, Any], - None] = self.update_function[self.learning_type - ] + self.policy_update_fn: Callable[[Any, Any], None] = \ + update_function[self.learning_type] def reset(self) -> None: """Initialize or reset the instance to yield a new iterator from zero.""" self.is_run = False self.env_step = 0 if self.resume_from_log: - self.start_epoch, self.env_step, self.gradient_step =\ + self.start_epoch, self.env_step, self.gradient_step = \ self.logger.restore_data() self.last_rew, self.last_len = 0.0, 0 @@ -242,8 +241,8 @@ def reset(self) -> None: self.episode_per_test, self.logger, self.env_step, self.reward_metric ) self.best_epoch = self.start_epoch - self.best_reward, self.best_reward_std = test_result["rew"], test_result[ - "rew_std"] + self.best_reward, self.best_reward_std = \ + test_result["rew"], test_result["rew_std"] if self.save_fn: self.save_fn(self.policy) @@ -256,6 +255,7 @@ def __iter__(self): # type: ignore return self def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]: + """Perform one epoch (both train and eval).""" self.epoch += 1 self.iter_num += 1 @@ -281,7 +281,7 @@ def __next__(self) -> Union[None, 
Tuple[int, Dict[str, Any], Dict[str, Any]]]: self.policy.train() epoch_stat: Dict[str, Any] = dict() - # Performs n step_per_epoch + # perform n step_per_epoch with tqdm.tqdm( total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config ) as t: @@ -337,7 +337,7 @@ def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]: return None def test_step(self) -> Tuple[Dict[str, Any], bool]: - """Performs a testing step.""" + """Perform one testing step.""" assert self.episode_per_test is not None assert self.test_collector is not None stop_fn_flag = False @@ -374,7 +374,7 @@ def test_step(self) -> Tuple[Dict[str, Any], bool]: return test_stat, stop_fn_flag def train_step(self) -> Tuple[Dict[str, Any], Dict[str, Any], bool]: - """Performs 1 training step.""" + """Perform one training step.""" assert self.episode_per_test is not None assert self.train_collector is not None stop_fn_flag = False @@ -422,7 +422,7 @@ def log_update_data(self, data: Dict[str, Any], losses: Dict[str, Any]) -> None: self.logger.log_update_data(losses, self.gradient_step) def offpolicy_update(self, data: Dict[str, Any], result: Dict[str, Any]) -> None: - """Performs off-policy updates.""" + """Perform off-policy updates.""" assert self.train_collector is not None for _ in range(round(self.update_per_step * result["n/st"])): self.gradient_step += 1 @@ -432,13 +432,13 @@ def offpolicy_update(self, data: Dict[str, Any], result: Dict[str, Any]) -> None def onpolicy_update( self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None ) -> None: - """Performs on-policy updates.""" + """Perform one on-policy update.""" assert self.train_collector is not None losses = self.policy.update( 0, self.train_collector.buffer, batch_size=self.batch_size, - repeat=self.repeat_per_collect + repeat=self.repeat_per_collect, ) self.train_collector.reset_buffer(keep_statistics=True) step = max([1] + [len(v) for v in losses.values() if isinstance(v, list)]) @@ -448,7 +448,7 @@ def onpolicy_update( def offline_update( self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None ) -> None: - """Performs off-line policy update.""" + """Perform one off-line policy update.""" assert self.buffer self.gradient_step += 1 losses = self.policy.update(self.batch_size, self.buffer) @@ -465,8 +465,8 @@ def run(self) -> Dict[str, Union[float, str]]: self.is_run = True deque(self, maxlen=0) # feed the entire iterator into a zero-length deque info = gather_info( - self.start_time, None, self.test_collector, self.best_reward, - self.best_reward_std + self.start_time, self.train_collector, self.test_collector, + self.best_reward, self.best_reward_std ) finally: self.is_run = False From 5ec4eb3e55942359615c978b81fc7ee7e79ef7a4 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Wed, 16 Mar 2022 10:54:06 -0400 Subject: [PATCH 34/40] change seed to pass ci --- .github/ISSUE_TEMPLATE.md | 4 ++-- test/continuous/test_sac_with_il.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index e1488688c..c5c07bf3f 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -7,6 +7,6 @@ - [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates - [ ] I have mentioned version numbers, operating system and environment, where applicable: ```python - import tianshou, torch, numpy, sys - print(tianshou.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform) + import tianshou, gym, 
torch, numpy, sys + print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform) ``` diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index b2287a2c5..8930b0839 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -24,7 +24,7 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='Pendulum-v0') parser.add_argument('--reward-threshold', type=float, default=None) - parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--seed', type=int, default=1) parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--actor-lr', type=float, default=1e-3) parser.add_argument('--critic-lr', type=float, default=1e-3) From 89ce44f69366f4f0ef0233378117664bdd98f847 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Wed, 16 Mar 2022 11:02:31 -0400 Subject: [PATCH 35/40] learning_type: str --- tianshou/trainer/base.py | 37 ++++++++--------------------------- tianshou/trainer/offline.py | 3 +-- tianshou/trainer/offpolicy.py | 3 +-- tianshou/trainer/onpolicy.py | 3 +-- 4 files changed, 11 insertions(+), 35 deletions(-) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index cb474e731..c1523f4ee 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -17,8 +17,8 @@ class BaseTrainer(object): Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results on every epoch. - :param learning_type int|str: type of learning iterator, 0,1,2 for "offpolicy", - "onpolicy" and "offline" respectively + :param learning_type str: type of learning iterator, available choices are + "offpolicy", "onpolicy" and "offline". :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param Collector train_collector: the collector used for training. :param Collector test_collector: the collector used for testing. If it's None, @@ -71,27 +71,17 @@ class BaseTrainer(object): Default to True. """ - learning_types: Dict[Union[int, str], Union[int, str]] = { - 0: "offpolicy", - "offpolicy": 0, - 1: "onpolicy", - "onpolicy": 1, - 2: "offline", - "offline": 2, - } - @staticmethod - def gen_doc(learning_type: Union[int, str]) -> str: - if isinstance(learning_type, int): - learning_type = BaseTrainer.learning_types[learning_type] + def gen_doc(learning_type: str) -> str: + """Document string for subclass trainer.""" step_means = f'The "step" in {learning_type} trainer means ' - if learning_type != 2: + if learning_type != "offline": step_means += "an environment step (a.k.a. transition)." else: # offline step_means += "a gradient step." - trainer_name = learning_type.capitalize() + "Trainer" # type: ignore + trainer_name = learning_type.capitalize() + "Trainer" return f"""An iterator class for {learning_type} trainer procedure. 
@@ -130,7 +120,7 @@ def gen_doc(learning_type: Union[int, str]) -> str: def __init__( self, - learning_type: Union[int, str], + learning_type: str, policy: BasePolicy, max_epoch: int, batch_size: int, @@ -202,18 +192,7 @@ def __init__( self.stop_fn_flag = False self.iter_num = 0 - update_function: Dict[Union[int, str], Callable] = { - 0: self.offpolicy_update, - "offpolicy": self.offpolicy_update, - 1: self.onpolicy_update, - "onpolicy": self.onpolicy_update, - 2: self.offline_update, - "offline": self.offline_update, - } - assert learning_type in self.learning_types - self.learning_type = learning_type - self.policy_update_fn: Callable[[Any, Any], None] = \ - update_function[self.learning_type] + self.policy_update_fn = getattr(self, f"{learning_type}_update") def reset(self) -> None: """Initialize or reset the instance to yield a new iterator from zero.""" diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 78f802743..dd742092f 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -71,9 +71,8 @@ def __init__( logger: BaseLogger = LazyLogger(), verbose: bool = True, ): - learning_type = super().learning_types["offline"] super().__init__( - learning_type=learning_type, + learning_type="offline", policy=policy, buffer=buffer, test_collector=test_collector, diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 775cafad2..2f457f89c 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -85,9 +85,8 @@ def __init__( verbose: bool = True, test_in_train: bool = True, ): - learning_type = super().learning_types["offpolicy"] super().__init__( - learning_type=learning_type, + learning_type="offpolicy", policy=policy, train_collector=train_collector, test_collector=test_collector, diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 3e286a5fc..cda8b60d4 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -89,9 +89,8 @@ def __init__( verbose: bool = True, test_in_train: bool = True, ): - learning_type = super().learning_types["onpolicy"] super().__init__( - learning_type=learning_type, + learning_type="onpolicy", policy=policy, train_collector=train_collector, test_collector=test_collector, From a320e68f617c70e3f73ecb03afe4b6bd318488b7 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Wed, 16 Mar 2022 11:26:50 -0400 Subject: [PATCH 36/40] fix ci --- tianshou/trainer/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index c1523f4ee..b44aa5a9a 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -74,7 +74,6 @@ class BaseTrainer(object): @staticmethod def gen_doc(learning_type: str) -> str: """Document string for subclass trainer.""" - step_means = f'The "step" in {learning_type} trainer means ' if learning_type != "offline": step_means += "an environment step (a.k.a. transition)." 
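PATCH 35 above replaces the integer/string lookup table with a plain `learning_type` string and resolves the update routine by name through `getattr(self, f"{learning_type}_update")`. The sketch below is not Tianshou code: it is a minimal, self-contained illustration of that name-based dispatch (the class and its toy update bodies are hypothetical), included only to make the pattern concrete before PATCH 37 below swaps it for an abstract method.

```python
from typing import Any, Dict


class _DemoTrainer:
    """Hypothetical stand-in; only the getattr dispatch mirrors the patch."""

    def __init__(self, learning_type: str) -> None:
        # "offpolicy" -> self.offpolicy_update, "onpolicy" -> self.onpolicy_update, ...
        self.policy_update_fn = getattr(self, f"{learning_type}_update")

    def offpolicy_update(self, data: Dict[str, Any], result: Dict[str, Any]) -> None:
        data["mode"] = "offpolicy"

    def onpolicy_update(self, data: Dict[str, Any], result: Dict[str, Any]) -> None:
        data["mode"] = "onpolicy"

    def offline_update(self, data: Dict[str, Any], result: Dict[str, Any]) -> None:
        data["mode"] = "offline"


data: Dict[str, Any] = {}
_DemoTrainer("offpolicy").policy_update_fn(data, {})
assert data["mode"] == "offpolicy"  # an unknown learning_type raises AttributeError
```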
From 6df93659333d3a83f6769fc0451af06b18269641 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Thu, 17 Mar 2022 08:07:11 -0400 Subject: [PATCH 37/40] reorg some code --- tianshou/trainer/base.py | 57 ++++++++--------------------------- tianshou/trainer/offline.py | 12 +++++++- tianshou/trainer/offpolicy.py | 10 +++++- tianshou/trainer/onpolicy.py | 18 ++++++++++- 4 files changed, 49 insertions(+), 48 deletions(-) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index b44aa5a9a..6d478d41c 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -1,4 +1,5 @@ import time +from abc import ABC, abstractmethod from collections import defaultdict, deque from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Union @@ -11,7 +12,7 @@ from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config -class BaseTrainer(object): +class BaseTrainer(ABC): """An iterator base class for trainers procedure. Returns an iterator that yields a 3-tuple (epoch, stats, info) of train results @@ -191,8 +192,6 @@ def __init__( self.stop_fn_flag = False self.iter_num = 0 - self.policy_update_fn = getattr(self, f"{learning_type}_update") - def reset(self) -> None: """Initialize or reset the instance to yield a new iterator from zero.""" self.is_run = False @@ -241,7 +240,7 @@ def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]: # iterator exhaustion check if self.epoch >= self.max_epoch: - if self.test_collector is None and self.save_fn: + if self.save_fn: self.save_fn(self.policy) raise StopIteration @@ -249,12 +248,6 @@ def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]: if self.stop_fn_flag: raise StopIteration - # stop_fn criterion - if self.test_collector is not None and self.stop_fn and self.stop_fn( - self.best_reward - ): - raise StopIteration - # set policy in train mode self.policy.train() @@ -366,8 +359,8 @@ def train_step(self) -> Tuple[Dict[str, Any], Dict[str, Any], bool]: result.update(rews=rew, rew=rew.mean(), rew_std=rew.std()) self.env_step += int(result["n/st"]) self.logger.log_train_data(result, self.env_step) - self.last_rew = result['rew'] if result["n/ep"] > 0 else self.last_rew - self.last_len = result['len'] if result["n/ep"] > 0 else self.last_len + self.last_rew = result["rew"] if result["n/ep"] > 0 else self.last_rew + self.last_len = result["len"] if result["n/ep"] > 0 else self.last_len data = { "env_step": str(self.env_step), "rew": f"{self.last_rew:.2f}", @@ -399,39 +392,13 @@ def log_update_data(self, data: Dict[str, Any], losses: Dict[str, Any]) -> None: data[k] = f"{losses[k]:.3f}" self.logger.log_update_data(losses, self.gradient_step) - def offpolicy_update(self, data: Dict[str, Any], result: Dict[str, Any]) -> None: - """Perform off-policy updates.""" - assert self.train_collector is not None - for _ in range(round(self.update_per_step * result["n/st"])): - self.gradient_step += 1 - losses = self.policy.update(self.batch_size, self.train_collector.buffer) - self.log_update_data(data, losses) - - def onpolicy_update( - self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None - ) -> None: - """Perform one on-policy update.""" - assert self.train_collector is not None - losses = self.policy.update( - 0, - self.train_collector.buffer, - batch_size=self.batch_size, - repeat=self.repeat_per_collect, - ) - self.train_collector.reset_buffer(keep_statistics=True) - step = max([1] + [len(v) for v in losses.values() if isinstance(v, list)]) - self.gradient_step += step - 
self.log_update_data(data, losses) - - def offline_update( - self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None - ) -> None: - """Perform one off-line policy update.""" - assert self.buffer - self.gradient_step += 1 - losses = self.policy.update(self.batch_size, self.buffer) - data.update({"gradient_step": str(self.gradient_step)}) - self.log_update_data(data, losses) + @abstractmethod + def policy_update_fn(self, data: Dict[str, Any], result: Dict[str, Any]) -> None: + """Policy update function for different trainer implementation. + + :param data: information in progress bar. + :param result: collector's return value. + """ def run(self) -> Dict[str, Union[float, str]]: """Consume iterator. diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index dd742092f..a5670dc4b 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional, Union import numpy as np @@ -91,6 +91,16 @@ def __init__( verbose=verbose, ) + def policy_update_fn( + self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None + ) -> None: + """Perform one off-line policy update.""" + assert self.buffer + self.gradient_step += 1 + losses = self.policy.update(self.batch_size, self.buffer) + data.update({"gradient_step": str(self.gradient_step)}) + self.log_update_data(data, losses) + def offline_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore """Wrapper for offline_trainer run method. diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 2f457f89c..feb132513 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional, Union import numpy as np @@ -108,6 +108,14 @@ def __init__( test_in_train=test_in_train, ) + def policy_update_fn(self, data: Dict[str, Any], result: Dict[str, Any]) -> None: + """Perform off-policy updates.""" + assert self.train_collector is not None + for _ in range(round(self.update_per_step * result["n/st"])): + self.gradient_step += 1 + losses = self.policy.update(self.batch_size, self.train_collector.buffer) + self.log_update_data(data, losses) + def offpolicy_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore """Wrapper for OffPolicyTrainer run method. diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index cda8b60d4..f718f3a25 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Optional, Union import numpy as np @@ -113,6 +113,22 @@ def __init__( test_in_train=test_in_train, ) + def policy_update_fn( + self, data: Dict[str, Any], result: Optional[Dict[str, Any]] = None + ) -> None: + """Perform one on-policy update.""" + assert self.train_collector is not None + losses = self.policy.update( + 0, + self.train_collector.buffer, + batch_size=self.batch_size, + repeat=self.repeat_per_collect, + ) + self.train_collector.reset_buffer(keep_statistics=True) + step = max([1] + [len(v) for v in losses.values() if isinstance(v, list)]) + self.gradient_step += step + self.log_update_data(data, losses) + def onpolicy_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: ignore """Wrapper for OnpolicyTrainer run method. 
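PATCH 37 above turns `BaseTrainer` into an ABC: the epoch loop stays in the base class, while `policy_update_fn` becomes an `@abstractmethod` that `OfflineTrainer`, `OffpolicyTrainer`, and `OnpolicyTrainer` each implement. Below is a minimal sketch of that shape, assuming nothing beyond the standard library; the `_Sketch` classes and their toy update logic are hypothetical stand-ins, not the real trainers.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class _BaseTrainerSketch(ABC):
    def __init__(self) -> None:
        self.gradient_step = 0

    @abstractmethod
    def policy_update_fn(self, data: Dict[str, Any], result: Dict[str, Any]) -> None:
        """Run the learning-type-specific update; called from the epoch loop."""

    def one_step(self) -> Dict[str, Any]:
        # stand-in for one pass of the progress-bar loop in __next__
        data: Dict[str, Any] = {}
        self.policy_update_fn(data, {"n/st": 1})
        return data


class _OfflineTrainerSketch(_BaseTrainerSketch):
    def policy_update_fn(self, data: Dict[str, Any], result: Dict[str, Any]) -> None:
        # offline trainers count one gradient step per update call
        self.gradient_step += 1
        data["gradient_step"] = str(self.gradient_step)


print(_OfflineTrainerSketch().one_step())  # {'gradient_step': '1'}
```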
From 7a00dafbf582ac11058ced999ded01d4e803f3d7 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Thu, 17 Mar 2022 08:27:16 -0400 Subject: [PATCH 38/40] revert --- tianshou/trainer/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 6d478d41c..fca1036ff 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -240,7 +240,7 @@ def __next__(self) -> Union[None, Tuple[int, Dict[str, Any], Dict[str, Any]]]: # iterator exhaustion check if self.epoch >= self.max_epoch: - if self.save_fn: + if self.test_collector is None and self.save_fn: self.save_fn(self.policy) raise StopIteration From 3ce4f6df958d6065409f0952378cbac9ffdfdf95 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Thu, 17 Mar 2022 08:34:57 -0400 Subject: [PATCH 39/40] missing docs for on-policy trainer --- tianshou/trainer/onpolicy.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index f718f3a25..449cca117 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -62,6 +62,10 @@ class OnpolicyTrainer(BaseTrainer): :param bool verbose: whether to print the information. Default to True. :param bool test_in_train: whether to test in the training phase. Default to True. + + .. note:: + + Only either one of step_per_collect and episode_per_collect can be specified. """ __doc__ = BaseTrainer.gen_doc("onpolicy") + "\n".join(__doc__.split("\n")[1:]) From a62cf84f228dae3e200a6922d437a7494e8753b5 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Thu, 17 Mar 2022 08:37:07 -0400 Subject: [PATCH 40/40] missing docs --- tianshou/trainer/offline.py | 2 ++ tianshou/trainer/offpolicy.py | 2 ++ tianshou/trainer/onpolicy.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index a5670dc4b..890429a8b 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -106,6 +106,8 @@ def offline_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: i """Wrapper for offline_trainer run method. It is identical to ``OfflineTrainer(...).run()``. + + :return: See :func:`~tianshou.trainer.gather_info`. """ return OfflineTrainer(*args, **kwargs).run() diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index feb132513..c3580397a 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -121,6 +121,8 @@ def offpolicy_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: """Wrapper for OffPolicyTrainer run method. It is identical to ``OffpolicyTrainer(...).run()``. + + :return: See :func:`~tianshou.trainer.gather_info`. """ return OffpolicyTrainer(*args, **kwargs).run() diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 449cca117..46b195a70 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -138,6 +138,8 @@ def onpolicy_trainer(*args, **kwargs) -> Dict[str, Union[float, str]]: # type: """Wrapper for OnpolicyTrainer run method. It is identical to ``OnpolicyTrainer(...).run()``. + + :return: See :func:`~tianshou.trainer.gather_info`. """ return OnpolicyTrainer(*args, **kwargs).run()
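Taken together, the series leaves three equivalent entry points: the legacy functional wrappers, an explicit `Trainer(...).run()`, and epoch-wise iteration over the trainer object. The sketch below shows only those call patterns for the off-policy case; `policy`, `train_collector`, `test_collector`, `stop_fn`, `logger`, and the numeric settings are placeholders assumed to be constructed as in the test scripts above, so the snippet is illustrative rather than runnable on its own.

```python
from tianshou.trainer import OffpolicyTrainer, offpolicy_trainer

# placeholders: build these as in test_td3.py (policy, collectors, logger, ...)
trainer_args = (
    policy, train_collector, test_collector,
    max_epoch, step_per_epoch, step_per_collect, episode_per_test, batch_size,
)

# 1) functional wrapper -- identical to OffpolicyTrainer(...).run()
result = offpolicy_trainer(*trainer_args, stop_fn=stop_fn, logger=logger)

# 2) explicit run(), returning the same gather_info() dict
result = OffpolicyTrainer(*trainer_args, stop_fn=stop_fn, logger=logger).run()

# 3) epoch-wise iteration, yielding (epoch, epoch_stat, info) every epoch
for epoch, epoch_stat, info in OffpolicyTrainer(
    *trainer_args, stop_fn=stop_fn, logger=logger
):
    print(f"Epoch #{epoch}: best_reward={info['best_reward']}")
```

The same three patterns apply to `OnpolicyTrainer`/`onpolicy_trainer` and to `OfflineTrainer`/`offline_trainer`, except that the offline variant takes a pre-filled replay buffer in place of a train collector.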