From 97fac1bcc4299aabf2e6afe8559f193ae50e87f1 Mon Sep 17 00:00:00 2001 From: Yi Su Date: Sat, 27 Nov 2021 03:03:38 +0800 Subject: [PATCH 01/18] update atari_crr network to be consistent with other algorithms --- examples/atari/README.md | 4 +-- examples/atari/atari_bcq.py | 6 ++-- examples/atari/atari_crr.py | 18 +++++++----- test/discrete/test_il_bcq.py | 18 +++++------- test/discrete/test_il_crr.py | 25 +++++----------- tianshou/policy/imitation/discrete_crr.py | 4 +-- tianshou/trainer/offline.py | 36 +++++++++++++---------- 7 files changed, 52 insertions(+), 59 deletions(-) diff --git a/examples/atari/README.md b/examples/atari/README.md index ffccecba3..a1ba48130 100644 --- a/examples/atari/README.md +++ b/examples/atari/README.md @@ -154,7 +154,7 @@ We test our CRR implementation on two example tasks (different from author's ver | Task | Online QRDQN | Behavioral | CRR | CRR w/ CQL | parameters | | ---------------------- | ---------- | ---------- | ---------------- | ----------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 16.1 (epoch 5) | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | -| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 26.4 (epoch 12) | 125.0 (epoch 12) | `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` | +| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 17.7 (epoch 5) | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | +| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 23.3 (epoch 12) | 76.9 (epoch 12) | `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` | Note that CRR itself does not work well in Atari tasks but adding CQL loss/regularizer helps. 
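A recurring change in this series is replacing `list(a.parameters()) + list(b.parameters())` with a single `ActorCritic` wrapper before building the optimizer, as in the script diffs below. The following sketch uses a stand-in container written in plain `torch.nn` (named `ActorCriticLike` here; it is not Tianshou's actual class) to show why both forms hand the optimizer the same set of tensors.

```python
import torch
import torch.nn as nn


class ActorCriticLike(nn.Module):
    """Stand-in wrapper that owns both networks as registered sub-modules."""

    def __init__(self, actor: nn.Module, critic: nn.Module) -> None:
        super().__init__()
        self.actor = actor
        self.critic = critic


actor = nn.Linear(4, 2)   # toy actor head
critic = nn.Linear(4, 1)  # toy critic head
wrapper = ActorCriticLike(actor, critic)

# Because the networks are registered sub-modules, wrapper.parameters() yields
# exactly the union of both parameter sets, so one Adam updates them jointly.
assert {id(p) for p in wrapper.parameters()} == \
    {id(p) for p in list(actor.parameters()) + list(critic.parameters())}
optim = torch.optim.Adam(wrapper.parameters(), lr=1e-3)
```

Besides being shorter, a single wrapping module also gives one object on which to call `.to(device)` or collect a combined `state_dict` when that is needed.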
diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py index ec89243b4..022d9bb44 100644 --- a/examples/atari/atari_bcq.py +++ b/examples/atari/atari_bcq.py @@ -15,6 +15,7 @@ from tianshou.policy import DiscreteBCQPolicy from tianshou.trainer import offline_trainer from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor @@ -102,9 +103,8 @@ def test_discrete_bcq(args=get_args()): hidden_sizes=args.hidden_sizes, softmax_output=False ).to(args.device) - optim = torch.optim.Adam( - list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr - ) + actor_critic = ActorCritic(policy_net, imitation_net) + optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) # define policy policy = DiscreteBCQPolicy( policy_net, imitation_net, optim, args.gamma, args.n_step, diff --git a/examples/atari/atari_crr.py b/examples/atari/atari_crr.py index 8905c7e58..5a3814a60 100644 --- a/examples/atari/atari_crr.py +++ b/examples/atari/atari_crr.py @@ -15,7 +15,8 @@ from tianshou.policy import DiscreteCRRPolicy from tianshou.trainer import offline_trainer from tianshou.utils import TensorboardLogger -from tianshou.utils.net.discrete import Actor +from tianshou.utils.net.common import ActorCritic +from tianshou.utils.net.discrete import Actor, Critic def get_args(): @@ -91,15 +92,18 @@ def test_discrete_crr(args=get_args()): actor = Actor( feature_net, args.action_shape, - device=args.device, hidden_sizes=args.hidden_sizes, + device=args.device, softmax_output=False ).to(args.device) - critic = DQN(*args.state_shape, args.action_shape, - device=args.device).to(args.device) - optim = torch.optim.Adam( - list(actor.parameters()) + list(critic.parameters()), lr=args.lr - ) + critic = Critic( + feature_net, + hidden_sizes=args.hidden_sizes, + last_size=np.prod(args.action_shape), + device=args.device + ).to(args.device) + actor_critic = ActorCritic(actor, critic) + optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) # define policy policy = DiscreteCRRPolicy( actor, diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py index 47540dadd..e236f6f37 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/discrete/test_il_bcq.py @@ -13,7 +13,8 @@ from tianshou.policy import DiscreteBCQPolicy from tianshou.trainer import offline_trainer from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import ActorCritic, Net +from tianshou.utils.net.discrete import Actor def get_args(): @@ -65,21 +66,16 @@ def test_discrete_bcq(args=get_args()): torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - policy_net = Net( + net = Net( args.state_shape, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device - ).to(args.device) - imitation_net = Net( - args.state_shape, - args.action_shape, - hidden_sizes=args.hidden_sizes, - device=args.device - ).to(args.device) - optim = torch.optim.Adam( - list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr ) + policy_net = Actor(net, args.action_shape, device=args.device).to(args.device) + imitation_net = Actor(net, args.action_shape, device=args.device).to(args.device) + actor_critic = ActorCritic(policy_net, imitation_net) + optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) policy = DiscreteBCQPolicy( policy_net, diff --git a/test/discrete/test_il_crr.py b/test/discrete/test_il_crr.py index 929469e8b..d11909b4d 100644 --- 
a/test/discrete/test_il_crr.py +++ b/test/discrete/test_il_crr.py @@ -13,7 +13,8 @@ from tianshou.policy import DiscreteCRRPolicy from tianshou.trainer import offline_trainer from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import ActorCritic, Net +from tianshou.utils.net.discrete import Actor, Critic def get_args(): @@ -60,23 +61,11 @@ def test_discrete_crr(args=get_args()): torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - actor = Net( - args.state_shape, - args.action_shape, - hidden_sizes=args.hidden_sizes, - device=args.device, - softmax=False - ) - critic = Net( - args.state_shape, - args.action_shape, - hidden_sizes=args.hidden_sizes, - device=args.device, - softmax=False - ) - optim = torch.optim.Adam( - list(actor.parameters()) + list(critic.parameters()), lr=args.lr - ) + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = Actor(net, args.action_shape, device=args.device, softmax_output=False) + critic = Critic(net, last_size=np.prod(args.action_shape), device=args.device) + actor_critic = ActorCritic(actor, critic) + optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) policy = DiscreteCRRPolicy( actor, diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index 6a149509e..dd4efe78b 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -83,14 +83,14 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: # type: ignor if self._target and self._iter % self._freq == 0: self.sync_weight() self.optim.zero_grad() - q_t, _ = self.critic(batch.obs) + q_t = self.critic(batch.obs) act = to_torch(batch.act, dtype=torch.long, device=q_t.device) qa_t = q_t.gather(1, act.unsqueeze(1)) # Critic loss with torch.no_grad(): target_a_t, _ = self.actor_old(batch.obs_next) target_m = Categorical(logits=target_a_t) - q_t_target, _ = self.critic_old(batch.obs_next) + q_t_target = self.critic_old(batch.obs_next) rew = to_torch_as(batch.rew, q_t_target) expected_target_q = (q_t_target * target_m.probs).sum(-1, keepdim=True) expected_target_q[batch.done > 0] = 0.0 diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index f5b454fa2..d2980ddcf 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -27,6 +27,7 @@ def offline_trainer( reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, logger: BaseLogger = LazyLogger(), verbose: bool = True, + test_in_train: bool = True, ) -> Dict[str, Union[float, str]]: """A wrapper for offline trainer procedure. @@ -65,6 +66,7 @@ def offline_trainer( :param BaseLogger logger: A logger that logs statistics during updating/testing. Default to a logger that doesn't log anything. :param bool verbose: whether to print the information. Default to True. + :param bool test_in_train: whether to test in the training phase. Default to True. :return: See :func:`~tianshou.trainer.gather_info`. 
""" @@ -75,12 +77,13 @@ def offline_trainer( start_time = time.time() test_collector.reset_stat() - test_result = test_episode( - policy, test_collector, test_fn, start_epoch, episode_per_test, logger, - gradient_step, reward_metric - ) - best_epoch = start_epoch - best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] + if test_in_train: + test_result = test_episode( + policy, test_collector, test_fn, start_epoch, episode_per_test, logger, + gradient_step, reward_metric + ) + best_epoch = start_epoch + best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] if save_fn: save_fn(policy) @@ -98,17 +101,18 @@ def offline_trainer( logger.log_update_data(losses, gradient_step) t.set_postfix(**data) # test - test_result = test_episode( - policy, test_collector, test_fn, epoch, episode_per_test, logger, - gradient_step, reward_metric - ) - rew, rew_std = test_result["rew"], test_result["rew_std"] - if best_epoch < 0 or best_reward < rew: - best_epoch, best_reward, best_reward_std = epoch, rew, rew_std - if save_fn: - save_fn(policy) + if test_in_train: + test_result = test_episode( + policy, test_collector, test_fn, epoch, episode_per_test, logger, + gradient_step, reward_metric + ) + rew, rew_std = test_result["rew"], test_result["rew_std"] + if best_epoch < 0 or best_reward < rew: + best_epoch, best_reward, best_reward_std = epoch, rew, rew_std + if save_fn: + save_fn(policy) logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn) - if verbose: + if verbose and test_in_train: print( f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" From 01d366da044bb7285539f87a497d1baa2f22e660 Mon Sep 17 00:00:00 2001 From: Yi Su Date: Sat, 27 Nov 2021 03:03:38 +0800 Subject: [PATCH 02/18] update atari_crr network to be consistent with other algorithms --- examples/atari/README.md | 4 +-- examples/atari/atari_bcq.py | 6 ++-- examples/atari/atari_crr.py | 18 +++++++----- test/discrete/test_il_bcq.py | 18 +++++------- test/discrete/test_il_crr.py | 25 +++++----------- tianshou/policy/imitation/discrete_crr.py | 4 +-- tianshou/trainer/offline.py | 36 +++++++++++++---------- 7 files changed, 52 insertions(+), 59 deletions(-) diff --git a/examples/atari/README.md b/examples/atari/README.md index ffccecba3..a1ba48130 100644 --- a/examples/atari/README.md +++ b/examples/atari/README.md @@ -154,7 +154,7 @@ We test our CRR implementation on two example tasks (different from author's ver | Task | Online QRDQN | Behavioral | CRR | CRR w/ CQL | parameters | | ---------------------- | ---------- | ---------- | ---------------- | ----------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 16.1 (epoch 5) | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | -| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 26.4 (epoch 12) | 125.0 (epoch 12) | `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` | +| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 17.7 (epoch 5) | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | +| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 23.3 (epoch 12) | 76.9 (epoch 12) | `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name 
log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` | Note that CRR itself does not work well in Atari tasks but adding CQL loss/regularizer helps. diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py index ec89243b4..022d9bb44 100644 --- a/examples/atari/atari_bcq.py +++ b/examples/atari/atari_bcq.py @@ -15,6 +15,7 @@ from tianshou.policy import DiscreteBCQPolicy from tianshou.trainer import offline_trainer from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor @@ -102,9 +103,8 @@ def test_discrete_bcq(args=get_args()): hidden_sizes=args.hidden_sizes, softmax_output=False ).to(args.device) - optim = torch.optim.Adam( - list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr - ) + actor_critic = ActorCritic(policy_net, imitation_net) + optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) # define policy policy = DiscreteBCQPolicy( policy_net, imitation_net, optim, args.gamma, args.n_step, diff --git a/examples/atari/atari_crr.py b/examples/atari/atari_crr.py index 8905c7e58..5a3814a60 100644 --- a/examples/atari/atari_crr.py +++ b/examples/atari/atari_crr.py @@ -15,7 +15,8 @@ from tianshou.policy import DiscreteCRRPolicy from tianshou.trainer import offline_trainer from tianshou.utils import TensorboardLogger -from tianshou.utils.net.discrete import Actor +from tianshou.utils.net.common import ActorCritic +from tianshou.utils.net.discrete import Actor, Critic def get_args(): @@ -91,15 +92,18 @@ def test_discrete_crr(args=get_args()): actor = Actor( feature_net, args.action_shape, - device=args.device, hidden_sizes=args.hidden_sizes, + device=args.device, softmax_output=False ).to(args.device) - critic = DQN(*args.state_shape, args.action_shape, - device=args.device).to(args.device) - optim = torch.optim.Adam( - list(actor.parameters()) + list(critic.parameters()), lr=args.lr - ) + critic = Critic( + feature_net, + hidden_sizes=args.hidden_sizes, + last_size=np.prod(args.action_shape), + device=args.device + ).to(args.device) + actor_critic = ActorCritic(actor, critic) + optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) # define policy policy = DiscreteCRRPolicy( actor, diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py index 47540dadd..e236f6f37 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/discrete/test_il_bcq.py @@ -13,7 +13,8 @@ from tianshou.policy import DiscreteBCQPolicy from tianshou.trainer import offline_trainer from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import ActorCritic, Net +from tianshou.utils.net.discrete import Actor def get_args(): @@ -65,21 +66,16 @@ def test_discrete_bcq(args=get_args()): torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - policy_net = Net( + net = Net( args.state_shape, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device - ).to(args.device) - imitation_net = Net( - args.state_shape, - args.action_shape, - hidden_sizes=args.hidden_sizes, - device=args.device - ).to(args.device) - optim = torch.optim.Adam( - list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr ) + policy_net = Actor(net, args.action_shape, device=args.device).to(args.device) + imitation_net = Actor(net, args.action_shape, device=args.device).to(args.device) + actor_critic = ActorCritic(policy_net, imitation_net) + optim = 
torch.optim.Adam(actor_critic.parameters(), lr=args.lr) policy = DiscreteBCQPolicy( policy_net, diff --git a/test/discrete/test_il_crr.py b/test/discrete/test_il_crr.py index 929469e8b..d11909b4d 100644 --- a/test/discrete/test_il_crr.py +++ b/test/discrete/test_il_crr.py @@ -13,7 +13,8 @@ from tianshou.policy import DiscreteCRRPolicy from tianshou.trainer import offline_trainer from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import ActorCritic, Net +from tianshou.utils.net.discrete import Actor, Critic def get_args(): @@ -60,23 +61,11 @@ def test_discrete_crr(args=get_args()): torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - actor = Net( - args.state_shape, - args.action_shape, - hidden_sizes=args.hidden_sizes, - device=args.device, - softmax=False - ) - critic = Net( - args.state_shape, - args.action_shape, - hidden_sizes=args.hidden_sizes, - device=args.device, - softmax=False - ) - optim = torch.optim.Adam( - list(actor.parameters()) + list(critic.parameters()), lr=args.lr - ) + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = Actor(net, args.action_shape, device=args.device, softmax_output=False) + critic = Critic(net, last_size=np.prod(args.action_shape), device=args.device) + actor_critic = ActorCritic(actor, critic) + optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) policy = DiscreteCRRPolicy( actor, diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index 6a149509e..dd4efe78b 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -83,14 +83,14 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: # type: ignor if self._target and self._iter % self._freq == 0: self.sync_weight() self.optim.zero_grad() - q_t, _ = self.critic(batch.obs) + q_t = self.critic(batch.obs) act = to_torch(batch.act, dtype=torch.long, device=q_t.device) qa_t = q_t.gather(1, act.unsqueeze(1)) # Critic loss with torch.no_grad(): target_a_t, _ = self.actor_old(batch.obs_next) target_m = Categorical(logits=target_a_t) - q_t_target, _ = self.critic_old(batch.obs_next) + q_t_target = self.critic_old(batch.obs_next) rew = to_torch_as(batch.rew, q_t_target) expected_target_q = (q_t_target * target_m.probs).sum(-1, keepdim=True) expected_target_q[batch.done > 0] = 0.0 diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index f5b454fa2..d2980ddcf 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -27,6 +27,7 @@ def offline_trainer( reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, logger: BaseLogger = LazyLogger(), verbose: bool = True, + test_in_train: bool = True, ) -> Dict[str, Union[float, str]]: """A wrapper for offline trainer procedure. @@ -65,6 +66,7 @@ def offline_trainer( :param BaseLogger logger: A logger that logs statistics during updating/testing. Default to a logger that doesn't log anything. :param bool verbose: whether to print the information. Default to True. + :param bool test_in_train: whether to test in the training phase. Default to True. :return: See :func:`~tianshou.trainer.gather_info`. 
""" @@ -75,12 +77,13 @@ def offline_trainer( start_time = time.time() test_collector.reset_stat() - test_result = test_episode( - policy, test_collector, test_fn, start_epoch, episode_per_test, logger, - gradient_step, reward_metric - ) - best_epoch = start_epoch - best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] + if test_in_train: + test_result = test_episode( + policy, test_collector, test_fn, start_epoch, episode_per_test, logger, + gradient_step, reward_metric + ) + best_epoch = start_epoch + best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] if save_fn: save_fn(policy) @@ -98,17 +101,18 @@ def offline_trainer( logger.log_update_data(losses, gradient_step) t.set_postfix(**data) # test - test_result = test_episode( - policy, test_collector, test_fn, epoch, episode_per_test, logger, - gradient_step, reward_metric - ) - rew, rew_std = test_result["rew"], test_result["rew_std"] - if best_epoch < 0 or best_reward < rew: - best_epoch, best_reward, best_reward_std = epoch, rew, rew_std - if save_fn: - save_fn(policy) + if test_in_train: + test_result = test_episode( + policy, test_collector, test_fn, epoch, episode_per_test, logger, + gradient_step, reward_metric + ) + rew, rew_std = test_result["rew"], test_result["rew_std"] + if best_epoch < 0 or best_reward < rew: + best_epoch, best_reward, best_reward_std = epoch, rew, rew_std + if save_fn: + save_fn(policy) logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn) - if verbose: + if verbose and test_in_train: print( f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" From 1e8ad473fb5a7489e62003598bb3a582a31c0922 Mon Sep 17 00:00:00 2001 From: Yi Su Date: Sat, 27 Nov 2021 07:50:35 +0800 Subject: [PATCH 03/18] move atari offline examples to offline directory --- examples/__init__.py | 0 examples/atari/README.md | 63 ------------------- examples/atari/__init__.py | 0 examples/offline/README.md | 63 ++++++++++++++++++- examples/offline/__init__.py | 0 examples/{atari => offline}/atari_bcq.py | 4 +- examples/{atari => offline}/atari_cql.py | 4 +- examples/{atari => offline}/atari_crr.py | 4 +- test/{discrete => offline}/test_il_bcq.py | 26 +++++--- .../test_il_cql.py} | 9 ++- test/{discrete => offline}/test_il_crr.py | 30 ++++++--- tianshou/trainer/offline.py | 15 +++-- 12 files changed, 122 insertions(+), 96 deletions(-) create mode 100644 examples/__init__.py create mode 100644 examples/atari/__init__.py create mode 100644 examples/offline/__init__.py rename examples/{atari => offline}/atari_bcq.py (98%) rename examples/{atari => offline}/atari_cql.py (98%) rename examples/{atari => offline}/atari_crr.py (98%) rename test/{discrete => offline}/test_il_bcq.py (90%) rename test/{discrete/test_qrdqn_il_cql.py => offline/test_il_cql.py} (94%) rename test/{discrete => offline}/test_il_crr.py (82%) diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/atari/README.md b/examples/atari/README.md index a1ba48130..b8d6fd5de 100644 --- a/examples/atari/README.md +++ b/examples/atari/README.md @@ -95,66 +95,3 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. 
| MsPacmanNoFrameskip-v4 | 3101 | ![](results/rainbow/MsPacman_rew.png) | `python3 atari_rainbow.py --task "MsPacmanNoFrameskip-v4"` | | SeaquestNoFrameskip-v4 | 2126 | ![](results/rainbow/Seaquest_rew.png) | `python3 atari_rainbow.py --task "SeaquestNoFrameskip-v4"` | | SpaceInvadersNoFrameskip-v4 | 1794.5 | ![](results/rainbow/SpaceInvaders_rew.png) | `python3 atari_rainbow.py --task "SpaceInvadersNoFrameskip-v4"` | - -# BCQ - -To running BCQ algorithm on Atari, you need to do the following things: - -- Train an expert, by using the command listed in the above DQN section; -- Generate buffer with noise: `python3 atari_dqn.py --task {your_task} --watch --resume-path log/{your_task}/dqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that 1M Atari buffer cannot be saved as `.pkl` format because it is too large and will cause error); -- Train BCQ: `python3 atari_bcq.py --task {your_task} --load-buffer-name expert.hdf5`. - -We test our BCQ implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step): - -| Task | Online DQN | Behavioral | BCQ | -| ---------------------- | ---------- | ---------- | --------------------------------- | -| PongNoFrameskip-v4 | 21 | 7.7 | 21 (epoch 5) | -| BreakoutNoFrameskip-v4 | 303 | 61 | 167.4 (epoch 12, could be higher) | - -# CQL - -To running CQL algorithm on Atari, you need to do the following things: - -- Train an expert, by using the command listed in the above QRDQN section; -- Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that 1M Atari buffer cannot be saved as `.pkl` format because it is too large and will cause error); -- Train CQL: `python3 atari_cql.py --task {your_task} --load-buffer-name expert.hdf5`. 
- -We test our CQL implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step): - -| Task | Online QRDQN | Behavioral | CQL | parameters | -| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20.5 | 6.8 | 19.5 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | -| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 248.3 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` | - -We reduce the size of the offline data to 10% and 1% of the above and get: - -Buffer size 100000: - -| Task | Online QRDQN | Behavioral | CQL | parameters | -| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20.5 | 5.8 | 21 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` | -| BreakoutNoFrameskip-v4 | 394.3 | 41.4 | 40.8 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` | - -Buffer size 10000: - -| Task | Online QRDQN | Behavioral | CQL | parameters | -| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` | -| BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` | - -# CRR - -To running CRR algorithm on Atari, you need to do the following things: - -- Train an expert, by using the command listed in the above QRDQN section; -- Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that 1M Atari buffer cannot be saved as `.pkl` format because it is too large and will cause error); -- Train CQL: `python3 atari_crr.py --task {your_task} --load-buffer-name expert.hdf5`. 
- -We test our CRR implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step): - -| Task | Online QRDQN | Behavioral | CRR | CRR w/ CQL | parameters | -| ---------------------- | ---------- | ---------- | ---------------- | ----------------- | ------------------------------------------------------------ | -| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 17.7 (epoch 5) | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | -| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 23.3 (epoch 12) | 76.9 (epoch 12) | `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` | - -Note that CRR itself does not work well in Atari tasks but adding CQL loss/regularizer helps. diff --git a/examples/atari/__init__.py b/examples/atari/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/offline/README.md b/examples/offline/README.md index 8995ee6e2..8166d2adf 100644 --- a/examples/offline/README.md +++ b/examples/offline/README.md @@ -2,9 +2,11 @@ In offline reinforcement learning setting, the agent learns a policy from a fixed dataset which is collected once with any policy. And the agent does not interact with environment anymore. -Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train offline agent. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets. +## Continous control -## Train +Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train offline agent for continuous control. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets. + +### Train Tianshou provides an `offline_trainer` for offline reinforcement learning. You can parse d4rl datasets into a `ReplayBuffer` , and set it as the parameter `buffer` of `offline_trainer`. `offline_bcq.py` is an example of offline RL using the d4rl dataset. @@ -26,3 +28,60 @@ After 1M steps: | --------------------- | --------------- | | halfcheetah-expert-v1 | 10624.0 ± 181.4 | +## Discrete control + +For discrete control, we currently use ad hod Atari data generated from a trained QRDQN agent. In the future, we can switch +to better benchmarks such as the Atari portion of [RL Unplugged](https://github.com/deepmind/deepmind-research/tree/master/rl_unplugged). + +### Gather Data + +To running CQL algorithm on Atari, you need to do the following things: + +- Train an expert, by using the command listed in the QRDQN section of Atari examples: `python3 atari_qrdqn.py --task {your_task}` +- Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that 1M Atari buffer cannot be saved as `.pkl` format because it is too large and will cause error); +- Train offline model: `python3 atari_{bcq,cql,crr}.py --task {your_task} --load-buffer-name expert.hdf5`. 
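As a complement to the data-gathering steps above, here is a minimal sketch of reading a generated `expert.hdf5` back into a replay buffer, which the offline scripts then feed to `offline_trainer`. `VectorReplayBuffer.load_hdf5` is assumed from Tianshou's HDF5 buffer support; treat the snippet as illustrative rather than the scripts' exact code.

```python
from tianshou.data import VectorReplayBuffer

# Load the noisy expert data produced by
# `atari_qrdqn.py --watch ... --save-buffer-name expert.hdf5`.
buffer = VectorReplayBuffer.load_hdf5("expert.hdf5")
print(f"{len(buffer)} transitions available for offline training")
```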
+ +### BCQ + +We test our BCQ implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step): + +| Task | Online QRDQN | Behavioral | BCQ | parameters | +| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 20.5 | 6.8 | 20.1 (epoch 5) | `python3 atari_bcq.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | +| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 64.6 (epoch 12, could be higher) | `python3 atari_bcq.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12` | + +### CQL + +We test our CQL implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step): + +| Task | Online QRDQN | Behavioral | CQL | parameters | +| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 20.5 | 6.8 | 20.4 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | +| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 129.4 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` | + +We reduce the size of the offline data to 10% and 1% of the above and get: + +Buffer size 100000: + +| Task | Online QRDQN | Behavioral | CQL | parameters | +| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 20.5 | 5.8 | 21 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` | +| BreakoutNoFrameskip-v4 | 394.3 | 41.4 | 40.8 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` | + +Buffer size 10000: + +| Task | Online QRDQN | Behavioral | CQL | parameters | +| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` | +| BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` | + +### CRR + +We test our CRR implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step): + +| Task | Online QRDQN | Behavioral | CRR | CRR w/ CQL | parameters | +| ---------------------- | ---------- | ---------- | ---------------- | ----------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 17.7 (epoch 5) | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` | +| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 23.3 (epoch 12) | 76.9 (epoch 12) 
| `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` | + +Note that CRR itself does not work well in Atari tasks but adding CQL loss/regularizer helps. diff --git a/examples/offline/__init__.py b/examples/offline/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/atari/atari_bcq.py b/examples/offline/atari_bcq.py similarity index 98% rename from examples/atari/atari_bcq.py rename to examples/offline/atari_bcq.py index 022d9bb44..deba538b5 100644 --- a/examples/atari/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -6,10 +6,10 @@ import numpy as np import torch -from atari_network import DQN -from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter +from examples.atari.atari_network import DQN +from examples.atari.atari_wrapper import wrap_deepmind from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import ShmemVectorEnv from tianshou.policy import DiscreteBCQPolicy diff --git a/examples/atari/atari_cql.py b/examples/offline/atari_cql.py similarity index 98% rename from examples/atari/atari_cql.py rename to examples/offline/atari_cql.py index 685e006db..7e46c5fb3 100644 --- a/examples/atari/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -6,10 +6,10 @@ import numpy as np import torch -from atari_network import QRDQN -from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter +from examples.atari.atari_network import QRDQN +from examples.atari.atari_wrapper import wrap_deepmind from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import ShmemVectorEnv from tianshou.policy import DiscreteCQLPolicy diff --git a/examples/atari/atari_crr.py b/examples/offline/atari_crr.py similarity index 98% rename from examples/atari/atari_crr.py rename to examples/offline/atari_crr.py index 5a3814a60..aa69fb145 100644 --- a/examples/atari/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -6,10 +6,10 @@ import numpy as np import torch -from atari_network import DQN -from atari_wrapper import wrap_deepmind from torch.utils.tensorboard import SummaryWriter +from examples.atari.atari_network import DQN +from examples.atari.atari_wrapper import wrap_deepmind from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import ShmemVectorEnv from tianshou.policy import DiscreteCRRPolicy diff --git a/test/discrete/test_il_bcq.py b/test/offline/test_il_bcq.py similarity index 90% rename from test/discrete/test_il_bcq.py rename to test/offline/test_il_bcq.py index e236f6f37..56759e69a 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/offline/test_il_bcq.py @@ -16,6 +16,8 @@ from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.discrete import Actor +from .gather_cartpole_data import gather_data + def get_args(): parser = argparse.ArgumentParser() @@ -38,7 +40,7 @@ def get_args(): parser.add_argument( "--load-buffer-name", type=str, - default="./expert_DQN_CartPole-v0.pkl", + default="./expert_QRDQN_CartPole-v0.pkl", ) parser.add_argument( "--device", @@ -66,14 +68,19 @@ def test_discrete_bcq(args=get_args()): torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - net = Net( - args.state_shape, + net = Net(args.state_shape, args.hidden_sizes[0], device=args.device) + policy_net = Actor( + net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device - ) - policy_net = Actor(net, args.action_shape, 
device=args.device).to(args.device) - imitation_net = Actor(net, args.action_shape, device=args.device).to(args.device) + ).to(args.device) + imitation_net = Actor( + net, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device + ).to(args.device) actor_critic = ActorCritic(policy_net, imitation_net) optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) @@ -89,9 +96,10 @@ def test_discrete_bcq(args=get_args()): args.imitation_logits_penalty, ) # buffer - assert os.path.exists(args.load_buffer_name), \ - "Please run test_dqn.py first to get expert's data buffer." - buffer = pickle.load(open(args.load_buffer_name, "rb")) + if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): + buffer = pickle.load(open(args.load_buffer_name, "rb")) + else: + buffer = gather_data() # collector test_collector = Collector(policy, test_envs, exploration_noise=True) diff --git a/test/discrete/test_qrdqn_il_cql.py b/test/offline/test_il_cql.py similarity index 94% rename from test/discrete/test_qrdqn_il_cql.py rename to test/offline/test_il_cql.py index 01b868f13..9d4800c15 100644 --- a/test/discrete/test_qrdqn_il_cql.py +++ b/test/offline/test_il_cql.py @@ -15,6 +15,8 @@ from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net +from .gather_cartpole_data import gather_data + def get_args(): parser = argparse.ArgumentParser() @@ -83,9 +85,10 @@ def test_discrete_cql(args=get_args()): min_q_weight=args.min_q_weight ).to(args.device) # buffer - assert os.path.exists(args.load_buffer_name), \ - "Please run test_qrdqn.py first to get expert's data buffer." - buffer = pickle.load(open(args.load_buffer_name, "rb")) + if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): + buffer = pickle.load(open(args.load_buffer_name, "rb")) + else: + buffer = gather_data() # collector test_collector = Collector(policy, test_envs, exploration_noise=True) diff --git a/test/discrete/test_il_crr.py b/test/offline/test_il_crr.py similarity index 82% rename from test/discrete/test_il_crr.py rename to test/offline/test_il_crr.py index d11909b4d..69ebe2571 100644 --- a/test/discrete/test_il_crr.py +++ b/test/offline/test_il_crr.py @@ -16,6 +16,8 @@ from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.discrete import Actor, Critic +from .gather_cartpole_data import gather_data + def get_args(): parser = argparse.ArgumentParser() @@ -35,7 +37,7 @@ def get_args(): parser.add_argument( "--load-buffer-name", type=str, - default="./expert_DQN_CartPole-v0.pkl", + default="./expert_QRDQN_CartPole-v0.pkl", ) parser.add_argument( "--device", @@ -61,9 +63,20 @@ def test_discrete_crr(args=get_args()): torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = Actor(net, args.action_shape, device=args.device, softmax_output=False) - critic = Critic(net, last_size=np.prod(args.action_shape), device=args.device) + net = Net(args.state_shape, args.hidden_sizes[0], device=args.device) + actor = Actor( + net, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + softmax_output=False + ) + critic = Critic( + net, + hidden_sizes=args.hidden_sizes, + last_size=np.prod(args.action_shape), + device=args.device + ) actor_critic = ActorCritic(actor, critic) optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) @@ -75,14 +88,15 @@ def test_discrete_crr(args=get_args()): 
target_update_freq=args.target_update_freq, ).to(args.device) # buffer - assert os.path.exists(args.load_buffer_name), \ - "Please run test_dqn.py first to get expert's data buffer." - buffer = pickle.load(open(args.load_buffer_name, "rb")) + if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): + buffer = pickle.load(open(args.load_buffer_name, "rb")) + else: + buffer = gather_data() # collector test_collector = Collector(policy, test_envs, exploration_noise=True) - log_path = os.path.join(args.logdir, args.task, 'discrete_cql') + log_path = os.path.join(args.logdir, args.task, 'discrete_crr') writer = SummaryWriter(log_path) logger = TensorboardLogger(writer) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index d2980ddcf..ddce7b518 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -27,7 +27,7 @@ def offline_trainer( reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, logger: BaseLogger = LazyLogger(), verbose: bool = True, - test_in_train: bool = True, + disable_test: bool = False, ) -> Dict[str, Union[float, str]]: """A wrapper for offline trainer procedure. @@ -66,7 +66,7 @@ def offline_trainer( :param BaseLogger logger: A logger that logs statistics during updating/testing. Default to a logger that doesn't log anything. :param bool verbose: whether to print the information. Default to True. - :param bool test_in_train: whether to test in the training phase. Default to True. + :param bool disable_test: whether to run tests at all. Default to False. :return: See :func:`~tianshou.trainer.gather_info`. """ @@ -76,8 +76,9 @@ def offline_trainer( stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() test_collector.reset_stat() + best_reward, best_reward_std = 0, 0 - if test_in_train: + if not disable_test: test_result = test_episode( policy, test_collector, test_fn, start_epoch, episode_per_test, logger, gradient_step, reward_metric @@ -101,7 +102,7 @@ def offline_trainer( logger.log_update_data(losses, gradient_step) t.set_postfix(**data) # test - if test_in_train: + if not disable_test: test_result = test_episode( policy, test_collector, test_fn, epoch, episode_per_test, logger, gradient_step, reward_metric @@ -112,11 +113,15 @@ def offline_trainer( if save_fn: save_fn(policy) logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn) - if verbose and test_in_train: + if verbose and not disable_test: print( f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" ) if stop_fn and stop_fn(best_reward): break + + if disable_test and save_fn: + save_fn(policy) + return gather_info(start_time, None, test_collector, best_reward, best_reward_std) From ce33933f1a1a88e1f7268c2ccc70add4b07ba38a Mon Sep 17 00:00:00 2001 From: Yi Su Date: Fri, 26 Nov 2021 17:57:26 -0800 Subject: [PATCH 04/18] fix code format --- test/offline/test_il_bcq.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/test/offline/test_il_bcq.py b/test/offline/test_il_bcq.py index 56759e69a..a8f990d77 100644 --- a/test/offline/test_il_bcq.py +++ b/test/offline/test_il_bcq.py @@ -70,16 +70,10 @@ def test_discrete_bcq(args=get_args()): # model net = Net(args.state_shape, args.hidden_sizes[0], device=args.device) policy_net = Actor( - net, - args.action_shape, - hidden_sizes=args.hidden_sizes, - device=args.device + net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device ).to(args.device) 
imitation_net = Actor( - net, - args.action_shape, - hidden_sizes=args.hidden_sizes, - device=args.device + net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device ).to(args.device) actor_critic = ActorCritic(policy_net, imitation_net) optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) From c18595949807d07565524f893770f8d8bc1a9110 Mon Sep 17 00:00:00 2001 From: Yi Su Date: Fri, 26 Nov 2021 19:02:22 -0800 Subject: [PATCH 05/18] check in missing file --- test/offline/gather_cartpole_data.py | 161 +++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) create mode 100644 test/offline/gather_cartpole_data.py diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py new file mode 100644 index 000000000..fffaf8eb8 --- /dev/null +++ b/test/offline/gather_cartpole_data.py @@ -0,0 +1,161 @@ +import argparse +import os +import pickle +import pprint + +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.policy import QRDQNPolicy +from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--seed', type=int, default=1) + parser.add_argument('--eps-test', type=float, default=0.05) + parser.add_argument('--eps-train', type=float, default=0.1) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--gamma', type=float, default=0.9) + parser.add_argument('--num-quantiles', type=int, default=200) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--target-update-freq', type=int, default=320) + parser.add_argument('--epoch', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument( + '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128] + ) + parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) 
+ parser.add_argument('--prioritized-replay', action="store_true", default=False) + parser.add_argument('--alpha', type=float, default=0.6) + parser.add_argument('--beta', type=float, default=0.4) + parser.add_argument( + '--save-buffer-name', type=str, default="./expert_QRDQN_CartPole-v0.pkl" + ) + parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + args = parser.parse_known_args()[0] + return args + + +def gather_data(): + args = get_args() + env = gym.make(args.task) + if args.task == 'CartPole-v0': + env.spec.reward_threshold = 190 # lower the goal + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # train_envs = gym.make(args.task) + # you can also use tianshou.env.SubprocVectorEnv + train_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) + # test_envs = gym.make(args.task) + test_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + softmax=False, + num_atoms=args.num_quantiles + ) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + policy = QRDQNPolicy( + net, + optim, + args.gamma, + args.num_quantiles, + args.n_step, + target_update_freq=args.target_update_freq + ).to(args.device) + # buffer + if args.prioritized_replay: + buf = PrioritizedVectorReplayBuffer( + args.buffer_size, + buffer_num=len(train_envs), + alpha=args.alpha, + beta=args.beta + ) + else: + buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) + # collector + train_collector = Collector(policy, train_envs, buf, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + # policy.set_eps(1) + train_collector.collect(n_step=args.batch_size * args.training_num) + # log + log_path = os.path.join(args.logdir, args.task, 'qrdqn') + writer = SummaryWriter(log_path) + logger = TensorboardLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + def train_fn(epoch, env_step): + # eps annnealing, just a demo + if env_step <= 10000: + policy.set_eps(args.eps_train) + elif env_step <= 50000: + eps = args.eps_train - (env_step - 10000) / \ + 40000 * (0.9 * args.eps_train) + policy.set_eps(eps) + else: + policy.set_eps(0.1 * args.eps_train) + + def test_fn(epoch, env_step): + policy.set_eps(args.eps_test) + + # trainer + result = offpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step + ) + assert stop_fn(result['best_reward']) + + # save buffer in pickle format, for imitation learning unittest + buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs)) + policy.set_eps(0.2) + collector = Collector(policy, test_envs, buf, exploration_noise=True) + result = collector.collect(n_step=args.buffer_size) + pickle.dump(buf, open(args.save_buffer_name, "wb")) + print(result["rews"].mean()) + return buf From 
797febfb8261296dd5e6866804ca11baece0bcf8 Mon Sep 17 00:00:00 2001 From: Yi Su Date: Fri, 26 Nov 2021 19:12:32 -0800 Subject: [PATCH 06/18] make linter happy --- test/offline/gather_cartpole_data.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index fffaf8eb8..9d7d79316 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -1,7 +1,6 @@ import argparse import os import pickle -import pprint import gym import numpy as np From 6522c1a12748d82e307de439d0622fe569db4302 Mon Sep 17 00:00:00 2001 From: Yi Su Date: Sat, 27 Nov 2021 12:07:45 +0800 Subject: [PATCH 07/18] remove unused code --- test/discrete/test_dqn.py | 11 ----------- test/discrete/test_qrdqn.py | 11 ----------- tianshou/trainer/utils.py | 5 ++++- 3 files changed, 4 insertions(+), 23 deletions(-) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 6912a1933..b52b0a6d3 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -42,9 +42,6 @@ def get_args(): parser.add_argument('--prioritized-replay', action="store_true", default=False) parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--beta', type=float, default=0.4) - parser.add_argument( - '--save-buffer-name', type=str, default="./expert_DQN_CartPole-v0.pkl" - ) parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' ) @@ -157,14 +154,6 @@ def test_fn(epoch, env_step): rews, lens = result["rews"], result["lens"] print(f"Final reward: {rews.mean()}, length: {lens.mean()}") - # save buffer in pickle format, for imitation learning unittest - buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs)) - policy.set_eps(0.2) - collector = Collector(policy, test_envs, buf, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) - pickle.dump(buf, open(args.save_buffer_name, "wb")) - print(result["rews"].mean()) - def test_pdqn(args=get_args()): args.prioritized_replay = True diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index cf8d22212..45a5188b0 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -43,9 +43,6 @@ def get_args(): parser.add_argument('--prioritized-replay', action="store_true", default=False) parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--beta', type=float, default=0.4) - parser.add_argument( - '--save-buffer-name', type=str, default="./expert_QRDQN_CartPole-v0.pkl" - ) parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' ) @@ -161,14 +158,6 @@ def test_fn(epoch, env_step): rews, lens = result["rews"], result["lens"] print(f"Final reward: {rews.mean()}, length: {lens.mean()}") - # save buffer in pickle format, for imitation learning unittest - buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs)) - policy.set_eps(0.9) # 10% of expert data as demonstrated in the original paper - collector = Collector(policy, test_envs, buf, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) - pickle.dump(buf, open(args.save_buffer_name, "wb")) - print(result["rews"].mean()) - def test_pqrdqn(args=get_args()): args.prioritized_replay = True diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 6ad2f0f2a..104658d10 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -59,7 +59,10 @@ def gather_info( """ duration = time.time() - 
start_time model_time = duration - test_c.collect_time - test_speed = test_c.collect_step / test_c.collect_time + try: + test_speed = test_c.collect_step / test_c.collect_time + except ZeroDivisionError: + test_speed = 0.0 result: Dict[str, Union[float, str]] = { "test_step": test_c.collect_step, "test_episode": test_c.collect_episode, From 3df590008fd49a676c179df3890474273afbe724 Mon Sep 17 00:00:00 2001 From: Yi Su Date: Sat, 27 Nov 2021 12:10:12 +0800 Subject: [PATCH 08/18] make linter happy --- test/discrete/test_dqn.py | 1 - test/discrete/test_qrdqn.py | 1 - 2 files changed, 2 deletions(-) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index b52b0a6d3..2b254fb28 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -1,6 +1,5 @@ import argparse import os -import pickle import pprint import gym diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 45a5188b0..f869f018b 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -1,6 +1,5 @@ import argparse import os -import pickle import pprint import gym From 8e0799737f31f9d45b28dabe747d23f4ff021edd Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Sat, 27 Nov 2021 14:39:53 -0500 Subject: [PATCH 09/18] format --- examples/atari/README.md | 2 +- examples/offline/README.md | 3 +-- examples/offline/atari_bcq.py | 6 +++--- examples/offline/atari_cql.py | 4 ++-- examples/offline/atari_crr.py | 8 ++++---- test/discrete/test_dqn.py | 6 +++--- test/discrete/test_qrdqn.py | 8 ++++---- test/offline/gather_cartpole_data.py | 8 ++++---- test/offline/{test_il_bcq.py => test_discrete_bcq.py} | 0 test/offline/{test_il_cql.py => test_discrete_cql.py} | 0 test/offline/{test_il_crr.py => test_discrete_crr.py} | 0 11 files changed, 22 insertions(+), 23 deletions(-) rename test/offline/{test_il_bcq.py => test_discrete_bcq.py} (100%) rename test/offline/{test_il_cql.py => test_discrete_cql.py} (100%) rename test/offline/{test_il_crr.py => test_discrete_crr.py} (100%) diff --git a/examples/atari/README.md b/examples/atari/README.md index b8d6fd5de..51b1af931 100644 --- a/examples/atari/README.md +++ b/examples/atari/README.md @@ -1,4 +1,4 @@ -# Atari General +# Atari The sample speed is \~3000 env step per second (\~12000 Atari frame per second in fact since we use frame_stack=4) under the normal mode (use a CNN policy and a collector, also storing data into the buffer). The main bottleneck is training the convolutional neural network. diff --git a/examples/offline/README.md b/examples/offline/README.md index 8166d2adf..c0a07fab0 100644 --- a/examples/offline/README.md +++ b/examples/offline/README.md @@ -30,8 +30,7 @@ After 1M steps: ## Discrete control -For discrete control, we currently use ad hod Atari data generated from a trained QRDQN agent. In the future, we can switch -to better benchmarks such as the Atari portion of [RL Unplugged](https://github.com/deepmind/deepmind-research/tree/master/rl_unplugged). +For discrete control, we currently use ad hoc Atari data generated from a trained QRDQN agent. In the future, we can switch to better benchmarks such as the Atari portion of [RL Unplugged](https://github.com/deepmind/deepmind-research/tree/master/rl_unplugged). 
### Gather Data diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index deba538b5..83865d18d 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -94,14 +94,14 @@ def test_discrete_bcq(args=get_args()): args.action_shape, device=args.device, hidden_sizes=args.hidden_sizes, - softmax_output=False + softmax_output=False, ).to(args.device) imitation_net = Actor( feature_net, args.action_shape, device=args.device, hidden_sizes=args.hidden_sizes, - softmax_output=False + softmax_output=False, ).to(args.device) actor_critic = ActorCritic(policy_net, imitation_net) optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) @@ -171,7 +171,7 @@ def watch(): args.batch_size, stop_fn=stop_fn, save_fn=save_fn, - logger=logger + logger=logger, ) pprint.pprint(result) diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 7e46c5fb3..22ef7b253 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -94,7 +94,7 @@ def test_discrete_cql(args=get_args()): args.num_quantiles, args.n_step, args.target_update_freq, - min_q_weight=args.min_q_weight + min_q_weight=args.min_q_weight, ).to(args.device) # load a previous policy if args.resume_path: @@ -156,7 +156,7 @@ def watch(): args.batch_size, stop_fn=stop_fn, save_fn=save_fn, - logger=logger + logger=logger, ) pprint.pprint(result) diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index aa69fb145..0214bf2f5 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -94,13 +94,13 @@ def test_discrete_crr(args=get_args()): args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, - softmax_output=False + softmax_output=False, ).to(args.device) critic = Critic( feature_net, hidden_sizes=args.hidden_sizes, last_size=np.prod(args.action_shape), - device=args.device + device=args.device, ).to(args.device) actor_critic = ActorCritic(actor, critic) optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) @@ -114,7 +114,7 @@ def test_discrete_crr(args=get_args()): ratio_upper_bound=args.ratio_upper_bound, beta=args.beta, min_q_weight=args.min_q_weight, - target_update_freq=args.target_update_freq + target_update_freq=args.target_update_freq, ).to(args.device) # load a previous policy if args.resume_path: @@ -175,7 +175,7 @@ def watch(): args.batch_size, stop_fn=stop_fn, save_fn=save_fn, - logger=logger + logger=logger, ) pprint.pprint(result) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 2b254fb28..c02866493 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -81,7 +81,7 @@ def test_dqn(args=get_args()): optim, args.gamma, args.n_step, - target_update_freq=args.target_update_freq + target_update_freq=args.target_update_freq, ) # buffer if args.prioritized_replay: @@ -89,7 +89,7 @@ def test_dqn(args=get_args()): args.buffer_size, buffer_num=len(train_envs), alpha=args.alpha, - beta=args.beta + beta=args.beta, ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) @@ -138,7 +138,7 @@ def test_fn(epoch, env_step): test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, - logger=logger + logger=logger, ) assert stop_fn(result['best_reward']) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index f869f018b..956cb03fd 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -76,7 +76,7 @@ def test_qrdqn(args=get_args()): hidden_sizes=args.hidden_sizes, device=args.device, softmax=False, - 
num_atoms=args.num_quantiles + num_atoms=args.num_quantiles, ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = QRDQNPolicy( @@ -85,7 +85,7 @@ def test_qrdqn(args=get_args()): args.gamma, args.num_quantiles, args.n_step, - target_update_freq=args.target_update_freq + target_update_freq=args.target_update_freq, ).to(args.device) # buffer if args.prioritized_replay: @@ -93,7 +93,7 @@ def test_qrdqn(args=get_args()): args.buffer_size, buffer_num=len(train_envs), alpha=args.alpha, - beta=args.beta + beta=args.beta, ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) @@ -142,7 +142,7 @@ def test_fn(epoch, env_step): stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step + update_per_step=args.update_per_step, ) assert stop_fn(result['best_reward']) diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 9d7d79316..78d85170d 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -80,7 +80,7 @@ def gather_data(): hidden_sizes=args.hidden_sizes, device=args.device, softmax=False, - num_atoms=args.num_quantiles + num_atoms=args.num_quantiles, ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = QRDQNPolicy( @@ -89,7 +89,7 @@ def gather_data(): args.gamma, args.num_quantiles, args.n_step, - target_update_freq=args.target_update_freq + target_update_freq=args.target_update_freq, ).to(args.device) # buffer if args.prioritized_replay: @@ -97,7 +97,7 @@ def gather_data(): args.buffer_size, buffer_num=len(train_envs), alpha=args.alpha, - beta=args.beta + beta=args.beta, ) else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) @@ -146,7 +146,7 @@ def test_fn(epoch, env_step): stop_fn=stop_fn, save_fn=save_fn, logger=logger, - update_per_step=args.update_per_step + update_per_step=args.update_per_step, ) assert stop_fn(result['best_reward']) diff --git a/test/offline/test_il_bcq.py b/test/offline/test_discrete_bcq.py similarity index 100% rename from test/offline/test_il_bcq.py rename to test/offline/test_discrete_bcq.py diff --git a/test/offline/test_il_cql.py b/test/offline/test_discrete_cql.py similarity index 100% rename from test/offline/test_il_cql.py rename to test/offline/test_discrete_cql.py diff --git a/test/offline/test_il_crr.py b/test/offline/test_discrete_crr.py similarity index 100% rename from test/offline/test_il_crr.py rename to test/offline/test_discrete_crr.py From 28665c1cb45cf36e259574e90bf2bb670a57b35c Mon Sep 17 00:00:00 2001 From: Yi Su Date: Sun, 28 Nov 2021 10:04:28 +0800 Subject: [PATCH 10/18] use test_collector is None to turn off testing --- tianshou/trainer/offline.py | 25 ++++++++++++------------- tianshou/trainer/utils.py | 34 +++++++++++++++++++++------------- 2 files changed, 33 insertions(+), 26 deletions(-) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index ddce7b518..211d9903f 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -14,7 +14,7 @@ def offline_trainer( policy: BasePolicy, buffer: ReplayBuffer, - test_collector: Collector, + test_collector: Optional[Collector], max_epoch: int, update_per_epoch: int, episode_per_test: int, @@ -27,14 +27,14 @@ def offline_trainer( reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, logger: BaseLogger = LazyLogger(), verbose: bool = True, - disable_test: bool = False, ) -> Dict[str, Union[float, str]]: """A wrapper for offline trainer procedure. 
The "step" in offline trainer means a gradient step. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param Collector test_collector: the collector used for testing. + :param Collector test_collector: the collector used for testing. If it's None, then + no testing will be performed. :param int max_epoch: the maximum number of epochs for training. The training process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. :param int update_per_epoch: the number of policy network updates, so-called @@ -66,7 +66,6 @@ def offline_trainer( :param BaseLogger logger: A logger that logs statistics during updating/testing. Default to a logger that doesn't log anything. :param bool verbose: whether to print the information. Default to True. - :param bool disable_test: whether to run tests at all. Default to False. :return: See :func:`~tianshou.trainer.gather_info`. """ @@ -75,10 +74,10 @@ def offline_trainer( start_epoch, _, gradient_step = logger.restore_data() stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() - test_collector.reset_stat() best_reward, best_reward_std = 0, 0 - if not disable_test: + if test_collector is not None: + test_collector.reset_stat() test_result = test_episode( policy, test_collector, test_fn, start_epoch, episode_per_test, logger, gradient_step, reward_metric @@ -102,7 +101,7 @@ def offline_trainer( logger.log_update_data(losses, gradient_step) t.set_postfix(**data) # test - if not disable_test: + if test_collector is not None: test_result = test_episode( policy, test_collector, test_fn, epoch, episode_per_test, logger, gradient_step, reward_metric @@ -112,16 +111,16 @@ def offline_trainer( best_epoch, best_reward, best_reward_std = epoch, rew, rew_std if save_fn: save_fn(policy) + if verbose: + print( + f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + ) logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn) - if verbose and not disable_test: - print( - f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" - f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" - ) if stop_fn and stop_fn(best_reward): break - if disable_test and save_fn: + if test_collector is None and save_fn: save_fn(policy) return gather_info(start_time, None, test_collector, best_reward, best_reward_std) diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 104658d10..9bb841248 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -36,7 +36,7 @@ def test_episode( def gather_info( start_time: float, train_c: Optional[Collector], - test_c: Collector, + test_c: Optional[Collector], best_reward: float, best_reward_std: float, ) -> Dict[str, Union[float, str]]: @@ -58,24 +58,32 @@ def gather_info( * ``duration`` the total elapsed time. 
""" duration = time.time() - start_time - model_time = duration - test_c.collect_time - try: - test_speed = test_c.collect_step / test_c.collect_time - except ZeroDivisionError: - test_speed = 0.0 + model_time = duration result: Dict[str, Union[float, str]] = { - "test_step": test_c.collect_step, - "test_episode": test_c.collect_episode, - "test_time": f"{test_c.collect_time:.2f}s", - "test_speed": f"{test_speed:.2f} step/s", - "best_reward": best_reward, - "best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}", "duration": f"{duration:.2f}s", "train_time/model": f"{model_time:.2f}s", } + if test_c is not None: + model_time = duration - test_c.collect_time + test_speed = test_c.collect_step / test_c.collect_time + result.update( + { + "test_step": test_c.collect_step, + "test_episode": test_c.collect_episode, + "test_time": f"{test_c.collect_time:.2f}s", + "test_speed": f"{test_speed:.2f} step/s", + "best_reward": best_reward, + "best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}", + "duration": f"{duration:.2f}s", + "train_time/model": f"{model_time:.2f}s", + } + ) if train_c is not None: model_time -= train_c.collect_time - train_speed = train_c.collect_step / (duration - test_c.collect_time) + if test_c is not None: + train_speed = train_c.collect_step / (duration - test_c.collect_time) + else: + train_speed = train_c.collect_step / duration result.update( { "train_step": train_c.collect_step, From 71825e93e6a6d38e3174b55752a6f3e633e02800 Mon Sep 17 00:00:00 2001 From: Yi Su Date: Sun, 28 Nov 2021 10:32:56 +0800 Subject: [PATCH 11/18] make test_collector optional in all trainers --- tianshou/trainer/offpolicy.py | 58 +++++++++++++++++++++-------------- tianshou/trainer/onpolicy.py | 58 +++++++++++++++++++++-------------- 2 files changed, 70 insertions(+), 46 deletions(-) diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index eb14dfb1f..32612dc01 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -14,7 +14,7 @@ def offpolicy_trainer( policy: BasePolicy, train_collector: Collector, - test_collector: Collector, + test_collector: Optional[Collector], max_epoch: int, step_per_epoch: int, step_per_collect: int, @@ -38,7 +38,8 @@ def offpolicy_trainer( :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param Collector train_collector: the collector used for training. - :param Collector test_collector: the collector used for testing. + :param Collector test_collector: the collector used for testing. If it's None, then + no testing will be performed. :param int max_epoch: the maximum number of epochs for training. The training process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. :param int step_per_epoch: the number of transitions collected per epoch. 
@@ -90,14 +91,20 @@ def offpolicy_trainer( stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() - test_collector.reset_stat() - test_in_train = test_in_train and train_collector.policy == policy - test_result = test_episode( - policy, test_collector, test_fn, start_epoch, episode_per_test, logger, - env_step, reward_metric + best_reward, best_reward_std = 0, 0 + test_in_train = ( + test_in_train and train_collector.policy == policy + and test_collector is not None ) - best_epoch = start_epoch - best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] + + if test_collector is not None: + test_collector.reset_stat() + test_result = test_episode( + policy, test_collector, test_fn, start_epoch, episode_per_test, logger, + env_step, reward_metric + ) + best_epoch = start_epoch + best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] if save_fn: save_fn(policy) @@ -157,23 +164,28 @@ def offpolicy_trainer( if t.n <= t.total: t.update() # test - test_result = test_episode( - policy, test_collector, test_fn, epoch, episode_per_test, logger, env_step, - reward_metric - ) - rew, rew_std = test_result["rew"], test_result["rew_std"] - if best_epoch < 0 or best_reward < rew: - best_epoch, best_reward, best_reward_std = epoch, rew, rew_std - if save_fn: - save_fn(policy) - logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) - if verbose: - print( - f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" - f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + if test_collector is not None: + test_result = test_episode( + policy, test_collector, test_fn, epoch, episode_per_test, logger, + env_step, reward_metric ) + rew, rew_std = test_result["rew"], test_result["rew_std"] + if best_epoch < 0 or best_reward < rew: + best_epoch, best_reward, best_reward_std = epoch, rew, rew_std + if save_fn: + save_fn(policy) + if verbose: + print( + f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + ) + logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) if stop_fn and stop_fn(best_reward): break + + if test_collector is None and save_fn: + save_fn(policy) + return gather_info( start_time, train_collector, test_collector, best_reward, best_reward_std ) diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 2c539a2e7..0f9a522df 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -14,7 +14,7 @@ def onpolicy_trainer( policy: BasePolicy, train_collector: Collector, - test_collector: Collector, + test_collector: Optional[Collector], max_epoch: int, step_per_epoch: int, repeat_per_collect: int, @@ -39,7 +39,8 @@ def onpolicy_trainer( :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param Collector train_collector: the collector used for training. - :param Collector test_collector: the collector used for testing. + :param Collector test_collector: the collector used for testing. If it's None, then + no testing will be performed. :param int max_epoch: the maximum number of epochs for training. The training process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. :param int step_per_epoch: the number of transitions collected per epoch. 
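The on-policy trainer gets the same None handling. Continuing the off-policy sketch above, the returned dict then carries only timing and training statistics; the key names below come from the gather_info change in patch 10, and the snippet is only an illustration of which entries exist:

    # Only timing and training statistics are reported without a test collector.
    print(result["duration"], result["train_time/model"], result["train_step"])
    # Entries such as these are added only when a test collector is supplied:
    for key in ("test_step", "test_episode", "test_time", "test_speed",
                "best_reward", "best_result"):
        assert key not in result
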
@@ -96,14 +97,20 @@ def onpolicy_trainer( stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() - test_collector.reset_stat() - test_in_train = test_in_train and train_collector.policy == policy - test_result = test_episode( - policy, test_collector, test_fn, start_epoch, episode_per_test, logger, - env_step, reward_metric + best_reward, best_reward_std = 0, 0 + test_in_train = ( + test_in_train and train_collector.policy == policy + and test_collector is not None ) - best_epoch = start_epoch - best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] + + if test_collector is not None: + test_collector.reset_stat() + test_result = test_episode( + policy, test_collector, test_fn, start_epoch, episode_per_test, logger, + env_step, reward_metric + ) + best_epoch = start_epoch + best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] if save_fn: save_fn(policy) @@ -173,23 +180,28 @@ def onpolicy_trainer( if t.n <= t.total: t.update() # test - test_result = test_episode( - policy, test_collector, test_fn, epoch, episode_per_test, logger, env_step, - reward_metric - ) - rew, rew_std = test_result["rew"], test_result["rew_std"] - if best_epoch < 0 or best_reward < rew: - best_epoch, best_reward, best_reward_std = epoch, rew, rew_std - if save_fn: - save_fn(policy) - logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) - if verbose: - print( - f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" - f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + if test_collector is not None: + test_result = test_episode( + policy, test_collector, test_fn, epoch, episode_per_test, logger, + env_step, reward_metric ) + rew, rew_std = test_result["rew"], test_result["rew_std"] + if best_epoch < 0 or best_reward < rew: + best_epoch, best_reward, best_reward_std = epoch, rew, rew_std + if save_fn: + save_fn(policy) + if verbose: + print( + f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + ) + logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) if stop_fn and stop_fn(best_reward): break + + if test_collector is None and save_fn: + save_fn(policy) + return gather_info( start_time, train_collector, test_collector, best_reward, best_reward_std ) From 00f1b8e5567b1a8f69b5742accb2c1af58110b3a Mon Sep 17 00:00:00 2001 From: Yi Su Date: Sun, 28 Nov 2021 10:37:41 +0800 Subject: [PATCH 12/18] make linter and formatter happy --- tianshou/trainer/offpolicy.py | 5 ++--- tianshou/trainer/onpolicy.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 32612dc01..dc4cf80d8 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -92,9 +92,8 @@ def offpolicy_trainer( start_time = time.time() train_collector.reset_stat() best_reward, best_reward_std = 0, 0 - test_in_train = ( - test_in_train and train_collector.policy == policy - and test_collector is not None + test_in_train = test_in_train and ( + train_collector.policy == policy and test_collector is not None ) if test_collector is not None: diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 0f9a522df..a77771726 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -98,9 +98,8 @@ def onpolicy_trainer( start_time = time.time() train_collector.reset_stat() best_reward, best_reward_std = 
0, 0 - test_in_train = ( - test_in_train and train_collector.policy == policy - and test_collector is not None + test_in_train = test_in_train and ( + train_collector.policy == policy and test_collector is not None ) if test_collector is not None: From e930a965ed4da54c61ec1da8cab2064758469501 Mon Sep 17 00:00:00 2001 From: Yi Su Date: Sun, 28 Nov 2021 11:07:07 +0800 Subject: [PATCH 13/18] make mypy happy --- tianshou/trainer/offline.py | 5 +++-- tianshou/trainer/offpolicy.py | 13 +++++++------ tianshou/trainer/onpolicy.py | 7 ++++--- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 211d9903f..deb5139d0 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -77,9 +77,10 @@ def offline_trainer( best_reward, best_reward_std = 0, 0 if test_collector is not None: + test_c: Collector = test_collector test_collector.reset_stat() test_result = test_episode( - policy, test_collector, test_fn, start_epoch, episode_per_test, logger, + policy, test_c, test_fn, start_epoch, episode_per_test, logger, gradient_step, reward_metric ) best_epoch = start_epoch @@ -103,7 +104,7 @@ def offline_trainer( # test if test_collector is not None: test_result = test_episode( - policy, test_collector, test_fn, epoch, episode_per_test, logger, + policy, test_c, test_fn, epoch, episode_per_test, logger, gradient_step, reward_metric ) rew, rew_std = test_result["rew"], test_result["rew_std"] diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index dc4cf80d8..b62431b32 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -97,10 +97,11 @@ def offpolicy_trainer( ) if test_collector is not None: + test_c: Collector = test_collector # for mypy test_collector.reset_stat() test_result = test_episode( - policy, test_collector, test_fn, start_epoch, episode_per_test, logger, - env_step, reward_metric + policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step, + reward_metric ) best_epoch = start_epoch best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] @@ -135,8 +136,8 @@ def offpolicy_trainer( if result["n/ep"] > 0: if test_in_train and stop_fn and stop_fn(result["rew"]): test_result = test_episode( - policy, test_collector, test_fn, epoch, episode_per_test, - logger, env_step + policy, test_c, test_fn, epoch, episode_per_test, logger, + env_step ) if stop_fn(test_result["rew"]): if save_fn: @@ -165,8 +166,8 @@ def offpolicy_trainer( # test if test_collector is not None: test_result = test_episode( - policy, test_collector, test_fn, epoch, episode_per_test, logger, - env_step, reward_metric + policy, test_c, test_fn, epoch, episode_per_test, logger, env_step, + reward_metric ) rew, rew_std = test_result["rew"], test_result["rew_std"] if best_epoch < 0 or best_reward < rew: diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index a77771726..6a80b935f 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -103,9 +103,10 @@ def onpolicy_trainer( ) if test_collector is not None: + test_c: Collector = test_collector # for mypy test_collector.reset_stat() test_result = test_episode( - policy, test_collector, test_fn, start_epoch, episode_per_test, logger, + policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step, reward_metric ) best_epoch = start_epoch @@ -143,7 +144,7 @@ def onpolicy_trainer( if result["n/ep"] > 0: if test_in_train and stop_fn and stop_fn(result["rew"]): 
test_result = test_episode( - policy, test_collector, test_fn, epoch, episode_per_test, + policy, test_c, test_fn, epoch, episode_per_test, logger, env_step ) if stop_fn(test_result["rew"]): @@ -181,7 +182,7 @@ def onpolicy_trainer( # test if test_collector is not None: test_result = test_episode( - policy, test_collector, test_fn, epoch, episode_per_test, logger, + policy, test_c, test_fn, epoch, episode_per_test, logger, env_step, reward_metric ) rew, rew_std = test_result["rew"], test_result["rew_std"] From a8afc7dda27aa6c8e532a1f35dd80665cb195b84 Mon Sep 17 00:00:00 2001 From: Yi Su Date: Sat, 27 Nov 2021 19:11:11 -0800 Subject: [PATCH 14/18] code format --- tianshou/trainer/onpolicy.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 6a80b935f..88f36bfcd 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -106,8 +106,8 @@ def onpolicy_trainer( test_c: Collector = test_collector # for mypy test_collector.reset_stat() test_result = test_episode( - policy, test_c, test_fn, start_epoch, episode_per_test, logger, - env_step, reward_metric + policy, test_c, test_fn, start_epoch, episode_per_test, logger, env_step, + reward_metric ) best_epoch = start_epoch best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] @@ -144,8 +144,8 @@ def onpolicy_trainer( if result["n/ep"] > 0: if test_in_train and stop_fn and stop_fn(result["rew"]): test_result = test_episode( - policy, test_c, test_fn, epoch, episode_per_test, - logger, env_step + policy, test_c, test_fn, epoch, episode_per_test, logger, + env_step ) if stop_fn(test_result["rew"]): if save_fn: @@ -182,8 +182,8 @@ def onpolicy_trainer( # test if test_collector is not None: test_result = test_episode( - policy, test_c, test_fn, epoch, episode_per_test, logger, - env_step, reward_metric + policy, test_c, test_fn, epoch, episode_per_test, logger, env_step, + reward_metric ) rew, rew_std = test_result["rew"], test_result["rew_std"] if best_epoch < 0 or best_reward < rew: From a9a33f4057707a183be70d008bc38d3d3aad373e Mon Sep 17 00:00:00 2001 From: Yi Su Date: Sun, 28 Nov 2021 11:52:22 +0800 Subject: [PATCH 15/18] make mypy happy --- tianshou/trainer/offline.py | 2 +- tianshou/trainer/offpolicy.py | 2 +- tianshou/trainer/onpolicy.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index deb5139d0..417ca5670 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -74,7 +74,7 @@ def offline_trainer( start_epoch, _, gradient_step = logger.restore_data() stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() - best_reward, best_reward_std = 0, 0 + best_reward, best_reward_std = 0.0, 0.0 if test_collector is not None: test_c: Collector = test_collector diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index b62431b32..2df6bb356 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -91,7 +91,7 @@ def offpolicy_trainer( stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() - best_reward, best_reward_std = 0, 0 + best_reward, best_reward_std = 0.0, 0.0 test_in_train = test_in_train and ( train_collector.policy == policy and test_collector is not None ) diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 88f36bfcd..7c8ca5bed 100644 --- a/tianshou/trainer/onpolicy.py +++ 
b/tianshou/trainer/onpolicy.py @@ -97,7 +97,7 @@ def onpolicy_trainer( stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() - best_reward, best_reward_std = 0, 0 + best_reward, best_reward_std = 0.0, 0.0 test_in_train = test_in_train and ( train_collector.policy == policy and test_collector is not None ) From 5fb2d43a15536467eabed560adbc7a0253ade3e6 Mon Sep 17 00:00:00 2001 From: Yi Su Date: Sun, 28 Nov 2021 12:14:26 +0800 Subject: [PATCH 16/18] make mypy happy --- tianshou/trainer/offline.py | 8 ++++++-- tianshou/trainer/offpolicy.py | 10 ++++++---- tianshou/trainer/onpolicy.py | 10 ++++++---- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 417ca5670..c630669d1 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -74,7 +74,6 @@ def offline_trainer( start_epoch, _, gradient_step = logger.restore_data() stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() - best_reward, best_reward_std = 0.0, 0.0 if test_collector is not None: test_c: Collector = test_collector @@ -124,4 +123,9 @@ def offline_trainer( if test_collector is None and save_fn: save_fn(policy) - return gather_info(start_time, None, test_collector, best_reward, best_reward_std) + if test_collector is None: + return gather_info(start_time, None, None, 0.0, 0.0) + else: + return gather_info( + start_time, None, test_collector, best_reward, best_reward_std + ) diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 2df6bb356..f0d385e7c 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -91,7 +91,6 @@ def offpolicy_trainer( stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() - best_reward, best_reward_std = 0.0, 0.0 test_in_train = test_in_train and ( train_collector.policy == policy and test_collector is not None ) @@ -186,6 +185,9 @@ def offpolicy_trainer( if test_collector is None and save_fn: save_fn(policy) - return gather_info( - start_time, train_collector, test_collector, best_reward, best_reward_std - ) + if test_collector is None: + return gather_info(start_time, train_collector, None, 0.0, 0.0) + else: + return gather_info( + start_time, train_collector, test_collector, best_reward, best_reward_std + ) diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 7c8ca5bed..845bcaf81 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -97,7 +97,6 @@ def onpolicy_trainer( stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() - best_reward, best_reward_std = 0.0, 0.0 test_in_train = test_in_train and ( train_collector.policy == policy and test_collector is not None ) @@ -202,6 +201,9 @@ def onpolicy_trainer( if test_collector is None and save_fn: save_fn(policy) - return gather_info( - start_time, train_collector, test_collector, best_reward, best_reward_std - ) + if test_collector is None: + return gather_info(start_time, train_collector, None, 0.0, 0.0) + else: + return gather_info( + start_time, train_collector, test_collector, best_reward, best_reward_std + ) From 8d33a2815acfc4d80c020acefcadf05a38dec433 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Sun, 28 Nov 2021 09:32:40 -0500 Subject: [PATCH 17/18] fix exception when test_collector=None --- test/offline/test_discrete_bcq.py | 5 ++++- test/offline/test_discrete_cql.py | 5 ++++- 
test/offline/test_discrete_crr.py | 5 ++++- tianshou/trainer/offline.py | 6 +++--- tianshou/trainer/offpolicy.py | 6 +++--- tianshou/trainer/onpolicy.py | 6 +++--- tianshou/utils/logger/tensorboard.py | 1 + 7 files changed, 22 insertions(+), 12 deletions(-) diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index a8f990d77..460ddb304 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -16,7 +16,10 @@ from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.discrete import Actor -from .gather_cartpole_data import gather_data +if __name__ == "__main__": + from gather_cartpole_data import gather_data +else: # pytest + from test.offline.gather_cartpole_data import gather_data def get_args(): diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 9d4800c15..c97f45628 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -15,7 +15,10 @@ from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net -from .gather_cartpole_data import gather_data +if __name__ == "__main__": + from gather_cartpole_data import gather_data +else: # pytest + from test.offline.gather_cartpole_data import gather_data def get_args(): diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index 69ebe2571..0b4e7c63d 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -16,7 +16,10 @@ from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.discrete import Actor, Critic -from .gather_cartpole_data import gather_data +if __name__ == "__main__": + from gather_cartpole_data import gather_data +else: # pytest + from test.offline.gather_cartpole_data import gather_data def get_args(): diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index c630669d1..d2f85bc2a 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -100,6 +100,7 @@ def offline_trainer( data[k] = f"{losses[k]:.3f}" logger.log_update_data(losses, gradient_step) t.set_postfix(**data) + logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn) # test if test_collector is not None: test_result = test_episode( @@ -116,9 +117,8 @@ def offline_trainer( f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" ) - logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn) - if stop_fn and stop_fn(best_reward): - break + if stop_fn and stop_fn(best_reward): + break if test_collector is None and save_fn: save_fn(policy) diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index f0d385e7c..9b8727b24 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -162,6 +162,7 @@ def offpolicy_trainer( t.set_postfix(**data) if t.n <= t.total: t.update() + logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) # test if test_collector is not None: test_result = test_episode( @@ -178,9 +179,8 @@ def offpolicy_trainer( f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" ) - logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) - if stop_fn and stop_fn(best_reward): - break + if stop_fn and stop_fn(best_reward): + break if test_collector is None and save_fn: save_fn(policy) diff --git a/tianshou/trainer/onpolicy.py 
b/tianshou/trainer/onpolicy.py index 845bcaf81..251c55637 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -178,6 +178,7 @@ def onpolicy_trainer( t.set_postfix(**data) if t.n <= t.total: t.update() + logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) # test if test_collector is not None: test_result = test_episode( @@ -194,9 +195,8 @@ def onpolicy_trainer( f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" ) - logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) - if stop_fn and stop_fn(best_reward): - break + if stop_fn and stop_fn(best_reward): + break if test_collector is None and save_fn: save_fn(policy) diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py index 86e873cda..469d32765 100644 --- a/tianshou/utils/logger/tensorboard.py +++ b/tianshou/utils/logger/tensorboard.py @@ -35,6 +35,7 @@ def __init__( def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: for k, v in data.items(): self.writer.add_scalar(k, v, global_step=step) + self.writer.flush() # issue #482 def save_data( self, From 0b61ba0303a93fea73c4ddb09331049c130195c1 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Sun, 28 Nov 2021 09:35:50 -0500 Subject: [PATCH 18/18] bump to 0.4.5 --- tianshou/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/__init__.py b/tianshou/__init__.py index bc61c15fc..c9416847d 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,6 +1,6 @@ from tianshou import data, env, exploration, policy, trainer, utils -__version__ = "0.4.4" +__version__ = "0.4.5" __all__ = [ "env",
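
Taken together, this series makes it possible to run the offline trainer purely from a saved buffer, with no evaluation environments at all. A minimal sketch (not part of the patches): the buffer path is a placeholder, and `policy` is assumed to be an offline policy such as a DiscreteCQLPolicy built as in test/offline/test_discrete_cql.py.

    import pickle

    import torch

    from tianshou.trainer import offline_trainer

    # placeholder path to a pickled expert replay buffer
    buffer = pickle.load(open("expert_QRDQN_CartPole-v0.pkl", "rb"))

    result = offline_trainer(
        policy,
        buffer,
        None,                # no test collector: testing is skipped entirely
        max_epoch=5,
        update_per_epoch=1000,
        episode_per_test=0,  # unused when test_collector is None
        batch_size=64,
        save_fn=lambda p: torch.save(p.state_dict(), "policy.pth"),
    )
    # Without a test collector, gather_info reports only "duration" and
    # "train_time/model" for the offline trainer.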