From 03e932bccc6a34f228baf4b7c508e2a75b475c10 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Tue, 29 Mar 2022 19:29:47 -0400 Subject: [PATCH 1/2] add write_flush in tflogger, fix argument passing in wandblogger --- tianshou/utils/logger/tensorboard.py | 7 ++++++- tianshou/utils/logger/wandb.py | 11 ++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py index bc43bae1a..c72dfb3e5 100644 --- a/tianshou/utils/logger/tensorboard.py +++ b/tianshou/utils/logger/tensorboard.py @@ -17,6 +17,8 @@ class TensorboardLogger(BaseLogger): :param int update_interval: the log interval in log_update_data(). Default to 1000. :param int save_interval: the save interval in save_data(). Default to 1 (save at the end of each epoch). + :param bool write_flush: whether to flush tensorboard result after each "add_*" + operation. Default to True. """ def __init__( @@ -26,16 +28,19 @@ def __init__( test_interval: int = 1, update_interval: int = 1000, save_interval: int = 1, + write_flush: bool = True ) -> None: super().__init__(train_interval, test_interval, update_interval) self.save_interval = save_interval + self.write_flush = write_flush self.last_save_step = -1 self.writer = writer def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: for k, v in data.items(): self.writer.add_scalar(k, v, global_step=step) - self.writer.flush() # issue #482 + if self.write_flush: # issue 580 + self.writer.flush() # issue #482 def save_data( self, diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py index 32a89d2a0..062a26a6c 100644 --- a/tianshou/utils/logger/wandb.py +++ b/tianshou/utils/logger/wandb.py @@ -31,6 +31,10 @@ class WandbLogger(BaseLogger): :param int test_interval: the log interval in log_test_data(). Default to 1. :param int update_interval: the log interval in log_update_data(). Default to 1000. 
+ :param int save_interval: the save interval in save_data(). Default to 1 (save at + the end of each epoch). + :param bool write_flush: whether to flush tensorboard result after each "add_*" + operation. Default to True. :param str project: W&B project name. Default to "tianshou". :param str name: W&B run name. Default to None. If None, random name is assigned. :param str entity: W&B team/organization name. Default to None. @@ -44,6 +48,7 @@ def __init__( test_interval: int = 1, update_interval: int = 1000, save_interval: int = 1000, + write_flush: bool = True, project: Optional[str] = None, name: Optional[str] = None, entity: Optional[str] = None, @@ -53,6 +58,7 @@ def __init__( super().__init__(train_interval, test_interval, update_interval) self.last_save_step = -1 self.save_interval = save_interval + self.write_flush = write_flush self.restored = False if project is None: project = os.getenv("WANDB_PROJECT", "tianshou") @@ -72,7 +78,10 @@ def __init__( def load(self, writer: SummaryWriter) -> None: self.writer = writer - self.tensorboard_logger = TensorboardLogger(writer) + self.tensorboard_logger = TensorboardLogger( + writer, self.train_interval, self.test_interval, self.update_interval, + self.save_interval, self.write_flush + ) def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: if self.tensorboard_logger is None: From 7f9e9734996069b7e2b332ac04aa5bf891dc9b3b Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Tue, 29 Mar 2022 19:34:03 -0400 Subject: [PATCH 2/2] clarify write_flush docstrings and add missing trailing comma --- tianshou/utils/logger/tensorboard.py | 6 +++--- tianshou/utils/logger/wandb.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py index c72dfb3e5..843ff012e 100644 --- a/tianshou/utils/logger/tensorboard.py +++ b/tianshou/utils/logger/tensorboard.py @@ -17,8 +17,8 @@ class TensorboardLogger(BaseLogger): :param int update_interval: the log interval in log_update_data().
Default to 1000. :param int save_interval: the save interval in save_data(). Default to 1 (save at the end of each epoch). - :param bool write_flush: whether to flush tensorboard result after each "add_*" - operation. Default to True. + :param bool write_flush: whether to flush tensorboard result after each + add_scalar operation. Default to True. """ def __init__( @@ -28,7 +28,7 @@ def __init__( test_interval: int = 1, update_interval: int = 1000, save_interval: int = 1, - write_flush: bool = True + write_flush: bool = True, ) -> None: super().__init__(train_interval, test_interval, update_interval) self.save_interval = save_interval diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py index 062a26a6c..e63a7bc7f 100644 --- a/tianshou/utils/logger/wandb.py +++ b/tianshou/utils/logger/wandb.py @@ -33,8 +33,8 @@ class WandbLogger(BaseLogger): Default to 1000. :param int save_interval: the save interval in save_data(). Default to 1 (save at the end of each epoch). - :param bool write_flush: whether to flush tensorboard result after each "add_*" - operation. Default to True. + :param bool write_flush: whether to flush tensorboard result after each + add_scalar operation. Default to True. :param str project: W&B project name. Default to "tianshou". :param str name: W&B run name. Default to None. If None, random name is assigned. :param str entity: W&B team/organization name. Default to None.