diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 04a4a7803..a758c4769 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -271,6 +271,8 @@ v_s_
 obs
 obs_next
 dtype
+iqm
+kwarg
 entrypoint
 interquantile
 init
diff --git a/examples/mujoco/mujoco_ppo_hl_multi.py b/examples/mujoco/mujoco_ppo_hl_multi.py
index 33ec34a73..6d140386e 100644
--- a/examples/mujoco/mujoco_ppo_hl_multi.py
+++ b/examples/mujoco/mujoco_ppo_hl_multi.py
@@ -13,6 +13,7 @@
 """
 
 import os
+import warnings
 
 import torch
 
@@ -38,16 +39,25 @@
 def main(
     num_experiments: int = 5,
-    run_experiments_sequentially: bool = False,
+    run_experiments_sequentially: bool = True,
+    logger_type: str = "wandb",
 ) -> RLiableExperimentResult:
     """:param num_experiments: the number of experiments to run. The experiments differ exclusively in the seeds.
     :param run_experiments_sequentially: if True, the experiments are run sequentially, otherwise in parallel.
         If a single experiment is set to use all available CPU cores, it might be undesired to run multiple experiments
         in parallel on the same machine,
+    :param logger_type: the type of logger to use. Currently, "wandb" and "tensorboard" are supported.
     :return: an object containing rliable-based evaluation results
     """
+    if not run_experiments_sequentially and logger_type == "wandb":
+        warnings.warn(
+            "Parallel execution with wandb logger is still under development. Falling back to tensorboard.",
+        )
+        logger_type = "tensorboard"
+
     task = "Ant-v4"
-    persistence_dir = os.path.abspath(os.path.join("log", task, "ppo", datetime_tag()))
+    tag = datetime_tag()
+    persistence_dir = os.path.abspath(os.path.join("log", task, "ppo", tag))
 
     experiment_config = ExperimentConfig(persistence_base_dir=persistence_dir, watch=False)
 
@@ -72,6 +82,21 @@ def main(
 
     hidden_sizes = (64, 64)
 
+    match logger_type:
+        case "wandb":
+            job_type = f"ppo/{tag}"
+            logger_factory = LoggerFactoryDefault(
+                logger_type="wandb",
+                wandb_project="tianshou",
+                group=task,
+                job_type=job_type,
+                save_interval=1,
+            )
+        case "tensorboard":
+            logger_factory = LoggerFactoryDefault("tensorboard")
+        case _:
+            raise ValueError(f"Unknown logger type: {logger_type}")
+
     experiment_collection = (
         PPOExperimentBuilder(env_factory, experiment_config, sampling_config)
         .with_ppo_params(
@@ -95,7 +120,7 @@ def main(
         )
         .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
         .with_critic_factory_default(hidden_sizes, torch.nn.Tanh)
-        .with_logger_factory(LoggerFactoryDefault("tensorboard"))
+        .with_logger_factory(logger_factory)
         .build_seeded_collection(num_experiments)
     )
 
diff --git a/tianshou/data/stats.py b/tianshou/data/stats.py
index b77318602..4685f5730 100644
--- a/tianshou/data/stats.py
+++ b/tianshou/data/stats.py
@@ -22,6 +22,8 @@ class SequenceSummaryStats(DataclassPPrintMixin):
 
     @classmethod
     def from_sequence(cls, sequence: Sequence[float | int] | np.ndarray) -> "SequenceSummaryStats":
+        if len(sequence) == 0:
+            return cls(mean=0.0, std=0.0, max=0.0, min=0.0)
         return cls(
             mean=float(np.mean(sequence)),
             std=float(np.std(sequence)),
diff --git a/tianshou/evaluation/launcher.py b/tianshou/evaluation/launcher.py
index 4dddae5af..52a21cb37 100644
--- a/tianshou/evaluation/launcher.py
+++ b/tianshou/evaluation/launcher.py
@@ -94,12 +94,11 @@ def _launch(self, experiments: Sequence[Experiment]) -> list[InfoStats | None]:
         successful_exp_stats = []
         failed_exps = []
         for exp in experiments:
-            for exp in experiments:
-                exp_stats = self._safe_execute(exp)
-                if exp_stats == "failed":
-                    failed_exps.append(exp)
-                else:
-                    successful_exp_stats.append(exp_stats)
+            exp_stats = self._safe_execute(exp)
+            if exp_stats == "failed":
+                failed_exps.append(exp)
+            else:
+                successful_exp_stats.append(exp_stats)
 
         # noinspection PyTypeChecker
         return self._return_from_successful_and_failed_exps(successful_exp_stats, failed_exps)
diff --git a/tianshou/evaluation/rliable_evaluation_hl.py b/tianshou/evaluation/rliable_evaluation_hl.py
index 884176bd1..e5fdf8eb4 100644
--- a/tianshou/evaluation/rliable_evaluation_hl.py
+++ b/tianshou/evaluation/rliable_evaluation_hl.py
@@ -3,7 +3,8 @@
 """
 
 import os
-from dataclasses import asdict, dataclass, fields
+from dataclasses import dataclass, fields
+from typing import Literal
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -12,7 +13,7 @@
 from rliable import plot_utils
 
 from tianshou.highlevel.experiment import Experiment
-from tianshou.utils import logging
+from tianshou.utils import TensorboardLogger, logging
 from tianshou.utils.logger.base import DataScope
 
 log = logging.getLogger(__name__)
@@ -61,17 +62,35 @@ class RLiableExperimentResult:
     test_episode_returns_RE: np.ndarray
     """The test episodes for each run of the experiment where each row corresponds to one run."""
 
+    train_episode_returns_RE: np.ndarray
+    """The training episodes for each run of the experiment where each row corresponds to one run."""
+
     env_steps_E: np.ndarray
     """The number of environment steps at which the test episodes were evaluated."""
 
+    env_steps_train_E: np.ndarray
+    """The number of environment steps at which the training episodes were evaluated."""
+
     @classmethod
-    def load_from_disk(cls, exp_dir: str) -> "RLiableExperimentResult":
+    def load_from_disk(
+        cls,
+        exp_dir: str,
+        max_env_step: int | None = None,
+    ) -> "RLiableExperimentResult":
         """Load the experiment result from disk.
 
         :param exp_dir: The directory from where the experiment results are restored.
+        :param max_env_step: The maximum number of environment steps to consider. If None, all data is considered.
+            Note: if the experiments have different numbers of steps, the minimum number is used.
         """
         test_episode_returns = []
+        train_episode_returns = []
         env_step_at_test = None
+        """The number of steps of the test run,
+        will try extracting it either from the loaded stats or from loaded arrays."""
+        env_step_at_train = None
+        """The number of steps of the training run,
+        will try extracting it from the loaded stats or from loaded arrays."""
 
         # TODO: env_step_at_test should not be defined in a loop and overwritten at each iteration
         #  just for retrieving them. We might need a cleaner directory structure.
@@ -79,43 +98,109 @@ def load_from_disk(cls, exp_dir: str) -> "RLiableExperimentResult":
             if entry.name.startswith(".") or not entry.is_dir():
                 continue
 
-            exp = Experiment.from_directory(entry.path)
-            logger = exp.logger_factory.create_logger(
-                entry.path,
-                entry.name,
-                None,
-                asdict(exp.config),
-            )
-            data = logger.restore_logged_data(entry.path)
-
-            if DataScope.TEST.value not in data or not data[DataScope.TEST.value]:
-                continue
-            restored_test_data = data[DataScope.TEST.value]
-            if not isinstance(restored_test_data, dict):
-                raise RuntimeError(
-                    f"Expected entry with key {DataScope.TEST.value} data to be a dictionary, "
-                    f"but got {restored_test_data=}.",
+            try:
+                # TODO: fix
+                logger_factory = Experiment.from_directory(entry.path).logger_factory
+                # only retrieve logger class to prevent creating another tfevent file
+                logger_cls = logger_factory.get_logger_class()
+            # a FileNotFoundError usually means the experiment was run with the low-level API
+            except FileNotFoundError:
+                log.info(
+                    f"Could not find persisted experiment in {entry.path}, using default logger.",
                 )
+                logger_cls = TensorboardLogger
+
+            data = logger_cls.restore_logged_data(entry.path)
+            # TODO: align low-level and high-level dir structure. This is a hack!
+            if not data:
+                dirs = [
+                    d for d in os.listdir(entry.path) if os.path.isdir(os.path.join(entry.path, d))
+                ]
+                if len(dirs) != 1:
+                    raise ValueError(
+                        f"Could not restore data from {entry.path}, "
+                        f"expected either events or exactly one subdirectory.",
+                    )
+                data = logger_cls.restore_logged_data(os.path.join(entry.path, dirs[0]))
+            if not data:
+                raise ValueError(f"Could not restore data from {entry.path}.")
+
+            if DataScope.TEST not in data or not data[DataScope.TEST]:
+                continue
+            restored_test_data = data[DataScope.TEST]
+            restored_train_data = data[DataScope.TRAIN]
+
+            assert isinstance(restored_test_data, dict)
+            assert isinstance(restored_train_data, dict)
+
+            for restored_data, scope in zip(
+                [restored_test_data, restored_train_data],
+                [DataScope.TEST, DataScope.TRAIN],
+                strict=True,
+            ):
+                if not isinstance(restored_data, dict):
+                    raise RuntimeError(
+                        f"Expected entry with key {scope} data to be a dictionary, "
+                        f"but got {restored_data=}.",
+                    )
             test_data = LoggedCollectStats.from_data_dict(restored_test_data)
+            train_data = LoggedCollectStats.from_data_dict(restored_train_data)
 
-            if test_data.returns_stat is None:
-                continue
-            test_episode_returns.append(test_data.returns_stat.mean)
-            env_step_at_test = test_data.env_step
+            if test_data.returns_stat is not None:
+                test_episode_returns.append(test_data.returns_stat.mean)
+                env_step_at_test = test_data.env_step
+
+            if train_data.returns_stat is not None:
+                train_episode_returns.append(train_data.returns_stat.mean)
+                env_step_at_train = train_data.env_step
 
+        test_data_found = True
+        train_data_found = True
         if not test_episode_returns or env_step_at_test is None:
-            raise ValueError(f"No experiment data found in {exp_dir}.")
+            log.warning(f"No test experiment data found in {exp_dir}.")
+            test_data_found = False
+        if not train_episode_returns or env_step_at_train is None:
+            log.warning(f"No train experiment data found in {exp_dir}.")
+            train_data_found = False
+
+        if not test_data_found and not train_data_found:
+            raise RuntimeError(f"No test or train data found in {exp_dir}.")
+
+        min_train_len = min([len(arr) for arr in train_episode_returns])
+        if max_env_step is not None:
+            min_train_len = min(min_train_len, max_env_step)
+        min_test_len = min([len(arr) for arr in test_episode_returns])
+        if max_env_step is not None:
+            min_test_len = min(min_test_len, max_env_step)
+
+        assert env_step_at_test is not None
+        assert env_step_at_train is not None
+
+        env_step_at_test = env_step_at_test[:min_test_len]
+        env_step_at_train = env_step_at_train[:min_train_len]
+        if max_env_step:
+            # find the index at which the maximum env step is reached with searchsorted
+            min_test_len = int(np.searchsorted(env_step_at_test, max_env_step))
+            min_train_len = int(np.searchsorted(env_step_at_train, max_env_step))
+            env_step_at_test = env_step_at_test[:min_test_len]
+            env_step_at_train = env_step_at_train[:min_train_len]
+
+        test_episode_returns = np.array([arr[:min_test_len] for arr in test_episode_returns])
+        train_episode_returns = np.array([arr[:min_train_len] for arr in train_episode_returns])
 
         return cls(
-            test_episode_returns_RE=np.array(test_episode_returns),
-            env_steps_E=np.array(env_step_at_test),
+            test_episode_returns_RE=test_episode_returns,
+            env_steps_E=env_step_at_test,
             exp_dir=exp_dir,
+            train_episode_returns_RE=train_episode_returns,
+            env_steps_train_E=env_step_at_train,
         )
 
     def _get_rliable_data(
         self,
         algo_name: str | None = None,
         score_thresholds: np.ndarray | None = None,
+        scope: DataScope = DataScope.TEST,
     ) -> tuple[dict, np.ndarray, np.ndarray]:
         """Return the data in the format expected by the rliable library.
 
@@ -126,19 +211,25 @@ def _get_rliable_data(
 
         :return: A tuple score_dict, env_steps, and score_thresholds.
         """
+        if scope == DataScope.TEST:
+            env_steps, returns = self.env_steps_E, self.test_episode_returns_RE
+        elif scope == DataScope.TRAIN:
+            env_steps, returns = self.env_steps_train_E, self.train_episode_returns_RE
+        else:
+            raise ValueError(f"Invalid scope {scope}, should be either 'TEST' or 'TRAIN'.")
         if score_thresholds is None:
             score_thresholds = np.linspace(
-                np.min(self.test_episode_returns_RE),
-                np.max(self.test_episode_returns_RE),
+                np.min(returns),
+                np.max(returns),
                 101,
             )
 
         if algo_name is None:
             algo_name = os.path.basename(self.exp_dir)
 
-        score_dict = {algo_name: self.test_episode_returns_RE}
+        score_dict = {algo_name: returns}
 
-        return score_dict, self.env_steps_E, score_thresholds
+        return score_dict, env_steps, score_thresholds
 
     def eval_results(
         self,
@@ -146,6 +237,10 @@ def eval_results(
         score_thresholds: np.ndarray | None = None,
         save_plots: bool = False,
         show_plots: bool = True,
+        scope: DataScope = DataScope.TEST,
+        ax_iqm: plt.Axes | None = None,
+        ax_profile: plt.Axes | None = None,
+        algo2color: dict[str, str] | None = None,
     ) -> tuple[plt.Figure, plt.Axes, plt.Figure, plt.Axes]:
         """Evaluate the results of an experiment and create a sample efficiency curve and a performance profile.
 
@@ -155,19 +250,30 @@ def eval_results(
             from the minimum and maximum test episode returns.
         :param save_plots: If True, the figures are saved to the experiment directory.
         :param show_plots: If True, the figures are shown.
-
-        :return: The created figures and axes.
+        :param scope: The scope of the evaluation, either 'TEST' or 'TRAIN'.
+        :param ax_iqm: The axis to plot the IQM sample efficiency curve on. If None, a new figure is created.
+        :param ax_profile: The axis to plot the performance profile on. If None, a new figure is created.
+        :param algo2color: A dictionary mapping algorithm names to colors. Useful for plotting
+            the evaluations of multiple algorithms in the same figure, e.g., by first creating an ax_iqm and ax_profile
+            with one evaluation and then passing them into the other evaluation. Same as the `colors`
+            kwarg in the rliable plotting utils.
+
+        :return: The created figures and axes in the order: fig_iqm, ax_iqm, fig_profile, ax_profile.
         """
         score_dict, env_steps, score_thresholds = self._get_rliable_data(
             algo_name,
             score_thresholds,
+            scope,
         )
 
         iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0)
         iqm_scores, iqm_cis = rly.get_interval_estimates(score_dict, iqm)
 
         # Plot IQM sample efficiency curve
-        fig_iqm, ax_iqm = plt.subplots(ncols=1, figsize=(7, 5), constrained_layout=True)
+        if ax_iqm is None:
+            fig_iqm, ax_iqm = plt.subplots(ncols=1, figsize=(7, 5), constrained_layout=True)
+        else:
+            fig_iqm = ax_iqm.get_figure()  # type: ignore
         plot_utils.plot_sample_efficiency_curve(
             env_steps,
             iqm_scores,
@@ -176,6 +282,7 @@ def eval_results(
             xlabel="env step",
             ylabel="IQM episode return",
             ax=ax_iqm,
+            colors=algo2color,
         )
         if show_plots:
             plt.show(block=False)
@@ -197,7 +304,10 @@ def eval_results(
         )
 
         # Plot score distributions
-        fig_profile, ax_profile = plt.subplots(ncols=1, figsize=(7, 5), constrained_layout=True)
+        if ax_profile is None:
+            fig_profile, ax_profile = plt.subplots(ncols=1, figsize=(7, 5), constrained_layout=True)
+        else:
+            fig_profile = ax_profile.get_figure()  # type: ignore
         plot_utils.plot_performance_profiles(
             score_distributions,
             score_thresholds,
@@ -216,3 +326,30 @@ def eval_results(
             plt.show(block=False)
 
         return fig_iqm, ax_iqm, fig_profile, ax_profile
+
+
+def load_and_eval_experiments(
+    log_dir: str,
+    show_plots: bool = True,
+    save_plots: bool = True,
+    scope: DataScope | Literal["both"] = DataScope.TEST,
+    max_env_step: int | None = None,
+) -> RLiableExperimentResult:
+    """Evaluate the experiments in the given log directory using the rliable API and return the loaded results object.
+
+    If neither `show_plots` nor `save_plots` is set to `True`, this is equivalent to just loading the results from disk.
+
+    :param log_dir: The directory containing the experiment results.
+    :param show_plots: whether to display plots.
+    :param save_plots: whether to save plots to the `log_dir`.
+    :param scope: The scope of the evaluation, either 'test', 'train' or 'both'.
+    :param max_env_step: The maximum number of environment steps to consider. If None, all data is considered.
+        Note: if the experiments have different numbers of steps, the minimum number is used.
+    """
+    rliable_result = RLiableExperimentResult.load_from_disk(log_dir, max_env_step=max_env_step)
+    if scope == "both":
+        for scope in [DataScope.TEST, DataScope.TRAIN]:
+            rliable_result.eval_results(show_plots=show_plots, save_plots=save_plots, scope=scope)
+    else:
+        rliable_result.eval_results(show_plots=show_plots, save_plots=save_plots, scope=scope)
+    return rliable_result
diff --git a/tianshou/highlevel/logger.py b/tianshou/highlevel/logger.py
index 2223d81f7..39b5fc2e7 100644
--- a/tianshou/highlevel/logger.py
+++ b/tianshou/highlevel/logger.py
@@ -17,60 +17,90 @@
     def create_logger(
         self,
         log_dir: str,
         experiment_name: str,
         run_id: str | None,
-        config_dict: dict,
+        config_dict: dict | None = None,
     ) -> TLogger:
         """Creates the logger.
 
         :param log_dir: path to the directory in which log data is to be stored
-        :param experiment_name: the name of the job, which may contain `os.path.sep`
+        :param experiment_name: the name of the job, which may contain `os.path.delimiter`
         :param run_id: a unique name, which, depending on the logging framework, may be used to identify the logger
         :param config_dict: a dictionary with data that is to be logged
         :return: the logger
         """
 
+    @abstractmethod
+    def get_logger_class(self) -> type[TLogger]:
+        """Returns the class of the logger that is to be created."""
+
 
 class LoggerFactoryDefault(LoggerFactory):
     def __init__(
         self,
         logger_type: Literal["tensorboard", "wandb", "pandas"] = "tensorboard",
+        wandb_entity: str | None = None,
         wandb_project: str | None = None,
+        group: str | None = None,
+        job_type: str | None = None,
+        save_interval: int = 1,
     ):
         if logger_type == "wandb" and wandb_project is None:
             raise ValueError("Must provide 'wandb_project'")
         self.logger_type = logger_type
+        self.wandb_entity = wandb_entity
         self.wandb_project = wandb_project
+        self.group = group
+        self.job_type = job_type
+        self.save_interval = save_interval
 
     def create_logger(
         self,
         log_dir: str,
         experiment_name: str,
         run_id: str | None,
-        config_dict: dict,
+        config_dict: dict | None = None,
     ) -> TLogger:
-        if self.logger_type in ["wandb", "tensorboard"]:
-            writer = SummaryWriter(log_dir)
-            writer.add_text(
-                "args",
-                str(
-                    dict(
-                        log_dir=log_dir,
-                        logger_type=self.logger_type,
-                        wandb_project=self.wandb_project,
-                    ),
-                ),
-            )
         match self.logger_type:
             case "wandb":
-                wandb_logger = WandbLogger(
-                    save_interval=1,
+                logger = WandbLogger(
+                    save_interval=self.save_interval,
                     name=experiment_name.replace(os.path.sep, "__"),
                     run_id=run_id,
                     config=config_dict,
+                    entity=self.wandb_entity,
                     project=self.wandb_project,
+                    group=self.group,
+                    job_type=self.job_type,
+                    log_dir=log_dir,
                 )
-                wandb_logger.load(writer)
-                return wandb_logger
+                writer = self._create_writer(log_dir)  # writer has to be created after wandb.init!
+                logger.load(writer)
+                return logger
             case "tensorboard":
+                writer = self._create_writer(log_dir)
                 return TensorboardLogger(writer)
             case _:
                 raise ValueError(f"Unknown logger type '{self.logger_type}'")
+
+    def _create_writer(self, log_dir: str) -> SummaryWriter:
+        """Creates a tensorboard writer and adds a text artifact."""
+        writer = SummaryWriter(log_dir)
+        writer.add_text(
+            "args",
+            str(
+                dict(
+                    log_dir=log_dir,
+                    logger_type=self.logger_type,
+                    wandb_project=self.wandb_project,
+                ),
+            ),
+        )
+        return writer
+
+    def get_logger_class(self) -> type[TLogger]:
+        match self.logger_type:
+            case "wandb":
+                return WandbLogger
+            case "tensorboard":
+                return TensorboardLogger
+            case _:
+                raise ValueError(f"Unknown logger type '{self.logger_type}'")
diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py
index fadb0ad57..2ff6c6760 100644
--- a/tianshou/utils/logger/base.py
+++ b/tianshou/utils/logger/base.py
@@ -1,7 +1,7 @@
 import typing
 from abc import ABC, abstractmethod
 from collections.abc import Callable
-from enum import Enum
+from enum import StrEnum
 from numbers import Number
 
 import numpy as np
@@ -13,7 +13,7 @@
 TRestoredData = dict[str, np.ndarray | dict[str, "TRestoredData"]]
 
 
-class DataScope(Enum):
+class DataScope(StrEnum):
     TRAIN = "train"
     TEST = "test"
     UPDATE = "update"
@@ -67,6 +67,10 @@ def prepare_dict_for_logging(self, log_data: dict) -> dict[str, VALID_LOG_VALS_T
 
         :return: the prepared dict.
""" + @abstractmethod + def finalize(self) -> None: + """Finalize the logger, e.g., close writers and connections.""" + def log_train_data(self, log_data: dict, step: int) -> None: """Use writer to log statistics generated during training. @@ -76,7 +80,7 @@ def log_train_data(self, log_data: dict, step: int) -> None: # TODO: move interval check to calling method if step - self.last_log_train_step >= self.train_interval: log_data = self.prepare_dict_for_logging(log_data) - self.write(f"{DataScope.TRAIN.value}/env_step", step, log_data) + self.write(f"{DataScope.TRAIN}/env_step", step, log_data) self.last_log_train_step = step def log_test_data(self, log_data: dict, step: int) -> None: @@ -88,7 +92,7 @@ def log_test_data(self, log_data: dict, step: int) -> None: # TODO: move interval check to calling method (stupid because log_test_data is only called from function in utils.py, not from BaseTrainer) if step - self.last_log_test_step >= self.test_interval: log_data = self.prepare_dict_for_logging(log_data) - self.write(f"{DataScope.TEST.value}/env_step", step, log_data) + self.write(f"{DataScope.TEST}/env_step", step, log_data) self.last_log_test_step = step def log_update_data(self, log_data: dict, step: int) -> None: @@ -100,7 +104,7 @@ def log_update_data(self, log_data: dict, step: int) -> None: # TODO: move interval check to calling method if step - self.last_log_update_step >= self.update_interval: log_data = self.prepare_dict_for_logging(log_data) - self.write(f"{DataScope.UPDATE.value}/gradient_step", step, log_data) + self.write(f"{DataScope.UPDATE}/gradient_step", step, log_data) self.last_log_update_step = step def log_info_data(self, log_data: dict, step: int) -> None: @@ -113,7 +117,7 @@ def log_info_data(self, log_data: dict, step: int) -> None: step - self.last_log_info_step >= self.info_interval ): # TODO: move interval check to calling method log_data = self.prepare_dict_for_logging(log_data) - self.write(f"{DataScope.INFO.value}/epoch", step, log_data) + self.write(f"{DataScope.INFO}/epoch", step, log_data) self.last_log_info_step = step @abstractmethod @@ -143,9 +147,9 @@ def restore_data(self) -> tuple[int, int, int]: :return: epoch, env_step, gradient_step. """ + @staticmethod @abstractmethod def restore_logged_data( - self, log_path: str, ) -> TRestoredData: """Load the logged data from disk for post-processing. @@ -153,9 +157,6 @@ def restore_logged_data( :return: a dict containing the logged data. """ - def finalize(self) -> None: - """Finalize the logger, e.g. close the file handler.""" - class LazyLogger(BaseLogger): """A logger that does nothing. 
 
     Used as the placeholder in trainer."""
 
@@ -172,6 +173,9 @@ def prepare_dict_for_logging(
     def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE]) -> None:
         """The LazyLogger writes nothing."""
 
+    def finalize(self) -> None:
+        pass
+
     def save_data(
         self,
         epoch: int,
@@ -184,5 +188,6 @@ def restore_data(self) -> tuple[int, int, int]:
         return 0, 0, 0
 
-    def restore_logged_data(self, log_path: str) -> dict:
+    @staticmethod
+    def restore_logged_data(log_path: str) -> dict:
         return {}
diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py
index d824d862d..1406cfbb8 100644
--- a/tianshou/utils/logger/tensorboard.py
+++ b/tianshou/utils/logger/tensorboard.py
@@ -99,6 +99,9 @@ def write(self, step_type: str, step: int, data: dict[str, Any]) -> None:
             if self.write_flush:  # issue 580
                 self.writer.flush()  # issue #482
 
+    def finalize(self) -> None:
+        self.writer.close()
+
     def save_data(
         self,
         epoch: int,
@@ -136,8 +139,8 @@ def restore_data(self) -> tuple[int, int, int]:
 
         return epoch, env_step, gradient_step
 
+    @staticmethod
     def restore_logged_data(
-        self,
         log_path: str,
     ) -> TRestoredData:
         """Restores the logged data from the tensorboard log directory.
diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py
index 74d844fa9..9172bf54b 100644
--- a/tianshou/utils/logger/wandb.py
+++ b/tianshou/utils/logger/wandb.py
@@ -1,5 +1,6 @@
 import argparse
 import contextlib
+import logging
 import os
 from collections.abc import Callable
 
@@ -11,6 +12,8 @@
 with contextlib.suppress(ImportError):
     import wandb
 
+log = logging.getLogger(__name__)
+
 
 class WandbLogger(BaseLogger):
     """Weights and Biases logger that sends data to https://wandb.ai/.
@@ -54,8 +57,12 @@ def __init__(
         name: str | None = None,
         entity: str | None = None,
         run_id: str | None = None,
+        group: str | None = None,
+        job_type: str | None = None,
         config: argparse.Namespace | dict | None = None,
         monitor_gym: bool = True,
+        disable_stats: bool = False,
+        log_dir: str | None = None,
     ) -> None:
         super().__init__(train_interval, test_interval, update_interval, info_interval)
         self.last_save_step = -1
@@ -68,13 +75,17 @@ def __init__(
         self.wandb_run = (
             wandb.init(
                 project=project,
+                group=group,
+                job_type=job_type,
                 name=name,
                 id=run_id,
                 resume="allow",
                 entity=entity,
                 sync_tensorboard=True,
-                monitor_gym=monitor_gym,
+                # monitor_gym=monitor_gym,  # currently disabled until gymnasium version is bumped to >1.0.0 https://github.com/wandb/wandb/issues/7047
+                dir=log_dir,
                 config=config,  # type: ignore
+                settings=wandb.Settings(_disable_stats=disable_stats),
             )
             if not wandb.run
            else wandb.run
@@ -111,6 +122,12 @@ def write(self, step_type: str, step: int, data: dict[str, VALID_LOG_VALS_TYPE])
             )
         self.tensorboard_logger.write(step_type, step, data)
 
+    def finalize(self) -> None:
+        if self.wandb_run is not None:
+            self.wandb_run.finish()
+        if self.tensorboard_logger is not None:
+            self.tensorboard_logger.finalize()
+
     def save_data(
         self,
         epoch: int,
@@ -167,11 +184,10 @@ def restore_data(self) -> tuple[int, int, int]:
             env_step = 0
         return epoch, env_step, gradient_step
 
-    def restore_logged_data(self, log_path: str) -> TRestoredData:
-        if self.tensorboard_logger is None:
-            raise NotImplementedError(
-                "Restoring logged data directly from W&B is not yet implemented."
- "Try instantiating the internal TensorboardLogger by calling something" - "like `logger.load(SummaryWriter(log_path))`", - ) - return self.tensorboard_logger.restore_logged_data(log_path) + @staticmethod + def restore_logged_data(log_path: str) -> TRestoredData: + log.warning( + "Logging data directly from W&B is not yet implemented, will use the " + "TensorboardLogger to restore it from disc instead.", + ) + return TensorboardLogger.restore_logged_data(log_path)