From 47b966aca41079decef4f17b13d8fb79bc7ec73e Mon Sep 17 00:00:00 2001
From: anyongjin
Date: Tue, 13 Aug 2024 10:40:02 +0800
Subject: [PATCH 1/4] add `evaluate_test_fn` to `BaseTrainer`

---
 tianshou/data/stats.py    |  2 ++
 tianshou/trainer/base.py  | 38 ++++++++++++++++++++++++++++++++------
 tianshou/trainer/utils.py |  2 ++
 3 files changed, 36 insertions(+), 6 deletions(-)

diff --git a/tianshou/data/stats.py b/tianshou/data/stats.py
index 4685f5730..409807286 100644
--- a/tianshou/data/stats.py
+++ b/tianshou/data/stats.py
@@ -56,6 +56,8 @@ class InfoStats(DataclassPPrintMixin):
     gradient_step: int
     """The total gradient step."""
+    best_score: float
+    """The best score over the test results."""
     best_reward: float
     """The best reward over the test results."""
     best_reward_std: float
diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py
index 242f2b028..09eda98c8 100644
--- a/tianshou/trainer/base.py
+++ b/tianshou/trainer/base.py
@@ -73,6 +73,7 @@ class BaseTrainer(ABC):
     :param test_fn: a hook called at the beginning of testing in each epoch.
         It can be used to perform custom additional operations, with the
         signature ``f(num_epoch: int, step_idx: int) -> None``.
+    :param evaluate_test_fn: Calculate the test batch performance score to determine whether it is the best model
     :param save_best_fn: a hook called when the undiscounted average mean reward in
         evaluation phase gets better, with the signature
         ``f(policy: BasePolicy) -> None``.
@@ -164,6 +165,7 @@ def __init__(
         train_fn: Callable[[int, int], None] | None = None,
         test_fn: Callable[[int, int | None], None] | None = None,
         stop_fn: Callable[[float], bool] | None = None,
+        evaluate_test_fn: Callable[[CollectStats], float] | None = None,
         save_best_fn: Callable[[BasePolicy], None] | None = None,
         save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
         resume_from_log: bool = False,
@@ -185,6 +187,7 @@ def __init__(
         self.logger = logger
         self.start_time = time.time()
         self.stat: defaultdict[str, MovAvg] = defaultdict(MovAvg)
+        self.best_score = 0.0
         self.best_reward = 0.0
         self.best_reward_std = 0.0
         self.start_epoch = 0
@@ -210,6 +213,7 @@ def __init__(
         self.train_fn = train_fn
         self.test_fn = test_fn
         self.stop_fn = stop_fn
+        self.evaluate_test_fn = evaluate_test_fn
         self.save_best_fn = save_best_fn
         self.save_checkpoint_fn = save_checkpoint_fn
@@ -273,6 +277,10 @@ def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> No
                 test_result.returns_stat.mean,
                 test_result.returns_stat.std,
             )
+            if self.evaluate_test_fn:
+                self.best_score = self.evaluate_test_fn(test_result)
+            else:
+                self.best_score = self.best_reward
             if self.save_best_fn:
                 self.save_best_fn(self.policy)
@@ -351,6 +359,7 @@ def __next__(self) -> EpochStats:
             start_time=self.start_time,
             policy_update_time=self.policy_update_time,
             gradient_step=self._gradient_step,
+            best_score=self.best_score,
             best_reward=self.best_reward,
             best_reward_std=self.best_reward_std,
             train_collector=self.train_collector,
@@ -384,17 +393,29 @@ def test_step(self) -> tuple[CollectStats, bool]:
         )
         assert test_stat.returns_stat is not None  # for mypy
         rew, rew_std = test_stat.returns_stat.mean, test_stat.returns_stat.std
-        if self.best_epoch < 0 or self.best_reward < rew:
+        if self.evaluate_test_fn:
+            score = self.evaluate_test_fn(test_stat)
+        else:
+            score = float(rew)
+        if self.best_epoch < 0 or self.best_score < score:
+            self.best_score = score
             self.best_epoch = self.epoch
             self.best_reward = float(rew)
             self.best_reward_std = rew_std
             if self.save_best_fn:
                 self.save_best_fn(self.policy)
-        log_msg = (
-            f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
-            f" best_reward: {self.best_reward:.6f} ± "
-            f"{self.best_reward_std:.6f} in #{self.best_epoch}"
-        )
+        if self.evaluate_test_fn:
+            log_msg = (
+                f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, score: {score:.6f},"
+                f" best_reward: {self.best_reward:.6f} ± "
+                f"{self.best_reward_std:.6f}, score: {self.best_score:.6f} in #{self.best_epoch}"
+            )
+        else:
+            log_msg = (
+                f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
+                f" best_reward: {self.best_reward:.6f} ± "
+                f"{self.best_reward_std:.6f} in #{self.best_epoch}"
+            )
         log.info(log_msg)
         if self.verbose:
             print(log_msg, flush=True)
@@ -506,6 +527,10 @@ def _update_best_reward_and_return_should_stop_training(
             should_stop_training = True
             self.best_reward = test_result.returns_stat.mean
             self.best_reward_std = test_result.returns_stat.std
+            if self.evaluate_test_fn:
+                self.best_score = self.evaluate_test_fn(test_result)
+            else:
+                self.best_score = self.best_reward
         return should_stop_training
@@ -562,6 +587,7 @@ def run(self, reset_prior_to_run: bool = True) -> InfoStats:
             start_time=self.start_time,
             policy_update_time=self.policy_update_time,
             gradient_step=self._gradient_step,
+            best_score=self.best_score,
             best_reward=self.best_reward,
             best_reward_std=self.best_reward_std,
             train_collector=self.train_collector,
diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py
index de730cee2..1f4369f72 100644
--- a/tianshou/trainer/utils.py
+++ b/tianshou/trainer/utils.py
@@ -42,6 +42,7 @@ def gather_info(
     start_time: float,
     policy_update_time: float,
     gradient_step: int,
+    best_score: float,
     best_reward: float,
     best_reward_std: float,
     train_collector: BaseCollector | None = None,
@@ -75,6 +76,7 @@ def gather_info(
     return InfoStats(
         gradient_step=gradient_step,
+        best_score=best_score,
         best_reward=best_reward,
         best_reward_std=best_reward_std,
         train_step=train_collector.collect_step if train_collector is not None else 0,

From 0dd252aa5b11f7df80808049d70c598e0769a38a Mon Sep 17 00:00:00 2001
From: anyongjin
Date: Wed, 14 Aug 2024 11:30:01 +0800
Subject: [PATCH 2/4] rename to `compute_score_fn`

---
 tianshou/data/stats.py   |  2 +-
 tianshou/trainer/base.py | 27 +++++++++++----------------
 2 files changed, 12 insertions(+), 17 deletions(-)

diff --git a/tianshou/data/stats.py b/tianshou/data/stats.py
index 409807286..ed64a429d 100644
--- a/tianshou/data/stats.py
+++ b/tianshou/data/stats.py
@@ -57,7 +57,7 @@ class InfoStats(DataclassPPrintMixin):
     gradient_step: int
     """The total gradient step."""
     best_score: float
-    """The best score over the test results."""
+    """The best score over the test results; the model with the highest score is considered the best."""
     best_reward: float
     """The best reward over the test results."""
     best_reward_std: float
diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py
index 09eda98c8..76c4aec15 100644
--- a/tianshou/trainer/base.py
+++ b/tianshou/trainer/base.py
@@ -73,7 +73,8 @@ class BaseTrainer(ABC):
     :param test_fn: a hook called at the beginning of testing in each epoch.
         It can be used to perform custom additional operations, with the
         signature ``f(num_epoch: int, step_idx: int) -> None``.
-    :param evaluate_test_fn: Calculate the test batch performance score to determine whether it is the best model
+    :param compute_score_fn: a function computing the test batch performance score, used
+        to determine whether the current model is the best so far; if not provided, the mean reward is used as the score.
     :param save_best_fn: a hook called when the undiscounted average mean reward in
         evaluation phase gets better, with the signature
         ``f(policy: BasePolicy) -> None``.
@@ -165,7 +166,7 @@ def __init__(
         train_fn: Callable[[int, int], None] | None = None,
         test_fn: Callable[[int, int | None], None] | None = None,
         stop_fn: Callable[[float], bool] | None = None,
-        evaluate_test_fn: Callable[[CollectStats], float] | None = None,
+        compute_score_fn: Callable[[CollectStats], float] | None = None,
         save_best_fn: Callable[[BasePolicy], None] | None = None,
         save_checkpoint_fn: Callable[[int, int, int], str] | None = None,
         resume_from_log: bool = False,
@@ -213,7 +214,9 @@ def __init__(
         self.train_fn = train_fn
         self.test_fn = test_fn
         self.stop_fn = stop_fn
-        self.evaluate_test_fn = evaluate_test_fn
+        self.compute_score_fn = compute_score_fn
+        if self.compute_score_fn is None:
+            self.compute_score_fn = lambda stat: stat.returns_stat.mean
         self.save_best_fn = save_best_fn
         self.save_checkpoint_fn = save_checkpoint_fn
@@ -277,10 +280,7 @@ def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> No
                 test_result.returns_stat.mean,
                 test_result.returns_stat.std,
             )
-            if self.evaluate_test_fn:
-                self.best_score = self.evaluate_test_fn(test_result)
-            else:
-                self.best_score = self.best_reward
+            self.best_score = self.compute_score_fn(test_result)
             if self.save_best_fn:
                 self.save_best_fn(self.policy)
@@ -393,10 +393,7 @@ def test_step(self) -> tuple[CollectStats, bool]:
         )
         assert test_stat.returns_stat is not None  # for mypy
         rew, rew_std = test_stat.returns_stat.mean, test_stat.returns_stat.std
-        if self.evaluate_test_fn:
-            score = self.evaluate_test_fn(test_stat)
-        else:
-            score = float(rew)
+        score = self.compute_score_fn(test_stat)
         if self.best_epoch < 0 or self.best_score < score:
             self.best_score = score
             self.best_epoch = self.epoch
@@ -404,7 +401,8 @@ def test_step(self) -> tuple[CollectStats, bool]:
             self.best_reward_std = rew_std
             if self.save_best_fn:
                 self.save_best_fn(self.policy)
-        if self.evaluate_test_fn:
+        if score != rew:
+            # a custom score function is in use
             log_msg = (
                 f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, score: {score:.6f},"
                 f" best_reward: {self.best_reward:.6f} ± "
                 f"{self.best_reward_std:.6f}, score: {self.best_score:.6f} in #{self.best_epoch}"
             )
@@ -527,10 +525,7 @@ def _update_best_reward_and_return_should_stop_training(
             should_stop_training = True
             self.best_reward = test_result.returns_stat.mean
             self.best_reward_std = test_result.returns_stat.std
-            if self.evaluate_test_fn:
-                self.best_score = self.evaluate_test_fn(test_result)
-            else:
-                self.best_score = self.best_reward
+            self.best_score = self.compute_score_fn(test_result)
         return should_stop_training

From 9f98f572808346f2024547a20f6f58f5bbd6712d Mon Sep 17 00:00:00 2001
From: anyongjin
Date: Wed, 14 Aug 2024 12:07:33 +0800
Subject: [PATCH 3/4] fix mypy error

---
 tianshou/trainer/base.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py
index 76c4aec15..189ea2934 100644
--- a/tianshou/trainer/base.py
+++ b/tianshou/trainer/base.py
@@ -214,9 +214,12 @@ def __init__(
         self.train_fn = train_fn
         self.test_fn = test_fn
         self.stop_fn = stop_fn
+        self.compute_score_fn: Callable[[CollectStats], float]
+        if compute_score_fn is None:
+            def compute_score_fn(stat: CollectStats) -> float:
+                assert stat.returns_stat is not None  # for mypy
+                return stat.returns_stat.mean
         self.compute_score_fn = compute_score_fn
-        if self.compute_score_fn is None:
-            self.compute_score_fn = lambda stat: stat.returns_stat.mean
         self.save_best_fn = save_best_fn
         self.save_checkpoint_fn = save_checkpoint_fn
From 68aadc98341cd1ed70d16979f8a8eab5f0934aa7 Mon Sep 17 00:00:00 2001
From: anyongjin
Date: Wed, 14 Aug 2024 12:26:51 +0800
Subject: [PATCH 4/4] fix black format error

---
 tianshou/trainer/base.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py
index 189ea2934..a6679fa20 100644
--- a/tianshou/trainer/base.py
+++ b/tianshou/trainer/base.py
@@ -216,9 +216,11 @@ def __init__(
         self.stop_fn = stop_fn
         self.compute_score_fn: Callable[[CollectStats], float]
         if compute_score_fn is None:
+
             def compute_score_fn(stat: CollectStats) -> float:
                 assert stat.returns_stat is not None  # for mypy
                 return stat.returns_stat.mean
+
         self.compute_score_fn = compute_score_fn
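
Usage sketch for reviewers (not part of the patches): a minimal illustration of how the new `compute_score_fn` parameter could be used. The trainer class named below, the function `risk_adjusted_score`, and the 0.5 penalty weight are assumptions made for this example; only the `compute_score_fn` parameter and the `CollectStats`/`returns_stat` fields come from the series above.

# Hypothetical example: select the best model by a risk-adjusted score
# instead of the raw mean reward.
from tianshou.data import CollectStats


def risk_adjusted_score(stat: CollectStats) -> float:
    # Mean test return minus a penalty on its spread, so that of two
    # policies with equal mean reward the more stable one wins; the 0.5
    # weight is an arbitrary choice for illustration.
    assert stat.returns_stat is not None  # for mypy, mirroring the patch
    return float(stat.returns_stat.mean - 0.5 * stat.returns_stat.std)


# Passed to any BaseTrainer subclass, e.g. (other arguments elided):
# trainer = OffpolicyTrainer(..., compute_score_fn=risk_adjusted_score)
# info = trainer.run()
# info.best_score then holds the best risk-adjusted score, while
# info.best_reward reports the reward of that same best-scoring model.

When the argument is omitted, the default installed in PATCH 3/4 returns `returns_stat.mean`, so `score == rew`, the extended log line is skipped, and `best_score` tracks `best_reward` exactly as before.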