
add write_flush in tflogger, fix argument passing in wandblogger #581


Merged 2 commits on Mar 30, 2022
7 changes: 6 additions & 1 deletion tianshou/utils/logger/tensorboard.py
@@ -17,6 +17,8 @@ class TensorboardLogger(BaseLogger):
     :param int update_interval: the log interval in log_update_data(). Default to 1000.
     :param int save_interval: the save interval in save_data(). Default to 1 (save at
         the end of each epoch).
+    :param bool write_flush: whether to flush tensorboard result after each
+        add_scalar operation. Default to True.
     """
 
     def __init__(
@@ -26,16 +28,19 @@ def __init__(
         test_interval: int = 1,
         update_interval: int = 1000,
         save_interval: int = 1,
+        write_flush: bool = True,
     ) -> None:
         super().__init__(train_interval, test_interval, update_interval)
         self.save_interval = save_interval
+        self.write_flush = write_flush
         self.last_save_step = -1
         self.writer = writer
 
     def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
         for k, v in data.items():
             self.writer.add_scalar(k, v, global_step=step)
-        self.writer.flush()  # issue #482
+        if self.write_flush:  # issue 580
+            self.writer.flush()  # issue #482
 
     def save_data(
         self,
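For context, a minimal usage sketch of the new flag, assuming the standard torch SummaryWriter and that TensorboardLogger is importable from tianshou.utils (the log directory below is made up, not part of this PR). Passing write_flush=False skips the per-write flush that was added for issue #482, trading log freshness for less disk I/O:

from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import TensorboardLogger

writer = SummaryWriter("log/experiment")  # hypothetical log directory
# With write_flush=False, events reach disk only via SummaryWriter's own
# periodic flush (flush_secs) or an explicit writer.flush()/writer.close().
logger = TensorboardLogger(writer, write_flush=False)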
11 changes: 10 additions & 1 deletion tianshou/utils/logger/wandb.py
@@ -31,6 +31,10 @@ class WandbLogger(BaseLogger):
     :param int test_interval: the log interval in log_test_data(). Default to 1.
     :param int update_interval: the log interval in log_update_data().
         Default to 1000.
+    :param int save_interval: the save interval in save_data(). Default to 1 (save at
+        the end of each epoch).
+    :param bool write_flush: whether to flush tensorboard result after each
+        add_scalar operation. Default to True.
     :param str project: W&B project name. Default to "tianshou".
     :param str name: W&B run name. Default to None. If None, random name is assigned.
     :param str entity: W&B team/organization name. Default to None.
@@ -44,6 +48,7 @@ def __init__(
         test_interval: int = 1,
         update_interval: int = 1000,
         save_interval: int = 1000,
+        write_flush: bool = True,
         project: Optional[str] = None,
         name: Optional[str] = None,
         entity: Optional[str] = None,
@@ -53,6 +58,7 @@
         super().__init__(train_interval, test_interval, update_interval)
         self.last_save_step = -1
         self.save_interval = save_interval
+        self.write_flush = write_flush
         self.restored = False
         if project is None:
             project = os.getenv("WANDB_PROJECT", "tianshou")
@@ -72,7 +78,10 @@
 
     def load(self, writer: SummaryWriter) -> None:
         self.writer = writer
-        self.tensorboard_logger = TensorboardLogger(writer)
+        self.tensorboard_logger = TensorboardLogger(
+            writer, self.train_interval, self.test_interval, self.update_interval,
+            self.save_interval, self.write_flush
+        )
Comment on lines +81 to +84, @Trinkle23897 (Collaborator, Author), Mar 29, 2022:

    @vwxyzjn this is a bug fix (#558)
 
     def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
         if self.tensorboard_logger is None:
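A matching sketch for the WandbLogger side of the fix, under the same assumptions plus an installed and logged-in wandb (the constructor may start a W&B run). The point is that the intervals and write_flush configured on WandbLogger now reach the TensorboardLogger that load() builds, which is the #558 bug this hunk fixes:

from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import WandbLogger

# These settings now propagate to the internal TensorboardLogger created by
# load(); previously it was constructed as TensorboardLogger(writer) with
# default intervals, silently ignoring the values configured here.
logger = WandbLogger(update_interval=100, save_interval=10, write_flush=False)
logger.load(SummaryWriter("log/wandb_experiment"))  # hypothetical log directory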