diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py
index ab731270e..2325ce6ee 100644
--- a/tianshou/trainer/onpolicy.py
+++ b/tianshou/trainer/onpolicy.py
@@ -103,7 +103,7 @@ def onpolicy_trainer(
                     train_fn(epoch, env_step)
                 result = train_collector.collect(n_step=step_per_collect,
                                                  n_episode=episode_per_collect)
-                if reward_metric:
+                if result["n/ep"] > 0 and reward_metric:
                     result["rews"] = reward_metric(result["rews"])
                 env_step += int(result["n/st"])
                 t.update(result["n/st"])
@@ -117,19 +117,20 @@ def onpolicy_trainer(
                     "n/ep": str(int(result["n/ep"])),
                     "n/st": str(int(result["n/st"])),
                 }
-                if test_in_train and stop_fn and stop_fn(result["rew"]):
-                    test_result = test_episode(
-                        policy, test_collector, test_fn,
-                        epoch, episode_per_test, logger, env_step)
-                    if stop_fn(test_result["rew"]):
-                        if save_fn:
-                            save_fn(policy)
-                        t.set_postfix(**data)
-                        return gather_info(
-                            start_time, train_collector, test_collector,
-                            test_result["rew"], test_result["rew_std"])
-                    else:
-                        policy.train()
+                if result["n/ep"] > 0:
+                    if test_in_train and stop_fn and stop_fn(result["rew"]):
+                        test_result = test_episode(
+                            policy, test_collector, test_fn,
+                            epoch, episode_per_test, logger, env_step)
+                        if stop_fn(test_result["rew"]):
+                            if save_fn:
+                                save_fn(policy)
+                            t.set_postfix(**data)
+                            return gather_info(
+                                start_time, train_collector, test_collector,
+                                test_result["rew"], test_result["rew_std"])
+                        else:
+                            policy.train()
                 losses = policy.update(
                     0, train_collector.buffer,
                     batch_size=batch_size, repeat=repeat_per_collect)
@@ -147,7 +148,7 @@ def onpolicy_trainer(
                     t.update()
         # test
        test_result = test_episode(policy, test_collector, test_fn, epoch,
-                                   episode_per_test, logger, env_step)
+                                   episode_per_test, logger, env_step, reward_metric)
         rew, rew_std = test_result["rew"], test_result["rew_std"]
         if best_epoch == -1 or best_reward < rew:
             best_reward, best_reward_std = rew, rew_std
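
Note on the hunks above: the first two guard the reward-dependent branches with result["n/ep"] > 0, since a collect driven purely by n_step can finish zero episodes, leaving the episode-level reward statistics empty or undefined. The last hunk forwards reward_metric into test_episode, which presupposes that the helper in tianshou/trainer/utils.py accepts an optional reward_metric argument; that change is not part of this diff. A minimal sketch of what such a helper could look like, under that assumption:

# Hypothetical sketch, not part of this diff: assumes test_episode accepts an
# optional reward_metric and refreshes the aggregate statistics the trainer
# later reads via test_result["rew"] / test_result["rew_std"].
from typing import Any, Callable, Dict, Optional

import numpy as np


def test_episode(
    policy, collector, test_fn, epoch, n_episode,
    logger=None, global_step=None,
    reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
) -> Dict[str, Any]:
    collector.reset_env()
    collector.reset_buffer()
    policy.eval()
    if test_fn:
        test_fn(epoch, global_step)
    result = collector.collect(n_episode=n_episode)
    if reward_metric:
        # Map per-episode (possibly multi-agent) rewards to scalars and
        # recompute the mean/std that the trainer consumes.
        rews = reward_metric(result["rews"])
        result.update(rews=rews, rew=rews.mean(), rew_std=rews.std())
    if logger and global_step is not None:
        logger.log_test_data(result, global_step)
    return result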