这是indexloc提供的服务,不要输入任何密码
Skip to content

wandb logger #426

@drozzy

Description

@drozzy

If anyone needs a Weights & Biases (wandb) logger, here is one I hacked together.
Note: in your https://wandb.ai/ workspace you have to select a different step value for each of the train/test/update panels (the logger records `train/env_step`, `test/env_step` and `update/gradient_step` respectively):

from tianshou.utils import BaseLogger
from tianshou.utils.log_tools import WRITE_TYPE
from typing import Any, Tuple, Union, Callable, Optional
import wandb

class WandBLogger(BaseLogger):
    """Logger that sends train/test/update statistics to Weights & Biases.

    The three streams (train, test, update) are throttled independently: a
    record is passed to ``wandb.log`` only when the ``step`` given to the
    corresponding ``log_*_data`` call is at least the configured interval
    ahead of the step at which that stream was last logged.

    NOTE(review): this class never calls ``wandb.init``; presumably the caller
    must start a run before training -- confirm.

    :param int train_interval: minimum step gap between two records emitted by
        ``log_train_data``. Default to 1000.
    :param int test_interval: minimum step gap between two records emitted by
        ``log_test_data``. Default to 1.
    :param int update_interval: minimum step gap between two records emitted
        by ``log_update_data``. Default to 1000.
    """

    def __init__(
        self,
        train_interval: int = 1000,
        test_interval: int = 1,
        update_interval: int = 1000
    ) -> None:
        # No writer object is given to the base class; every record goes
        # through wandb.log instead.
        super().__init__(writer=None)

        self.train_interval, self.test_interval, self.update_interval = (
            train_interval,
            test_interval,
            update_interval,
        )
        # -1 marks each stream as "never logged yet".
        self.last_log_train_step = -1
        self.last_log_test_step = -1
        self.last_log_update_step = -1

    def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
        """Do nothing; records are sent as whole dicts by the ``log_*`` methods."""

    def log_train_data(self, collect_result: dict, step: int) -> None:
        """Log training statistics if the train interval has elapsed.

        When ``collect_result["n/ep"]`` is positive, the means of
        ``collect_result["rews"]`` and ``collect_result["lens"]`` are written
        back into ``collect_result`` under ``"rew"`` and ``"len"``, whether or
        not a record is emitted.

        :param collect_result: dict read for the keys ``"n/ep"``, ``"rews"``
            and ``"lens"``; the latter two must provide ``.mean()``.
        :param int step: value logged as ``train/env_step`` and used for
            throttling.
        """
        if not collect_result["n/ep"] > 0:
            return

        mean_reward = collect_result["rews"].mean()
        collect_result["rew"] = mean_reward
        mean_length = collect_result["lens"].mean()
        collect_result["len"] = mean_length

        if step - self.last_log_train_step < self.train_interval:
            return

        wandb.log({
            "train/env_step": step,
            "train/episode": collect_result["n/ep"],
            "train/reward": mean_reward,
            "train/length": mean_length,
        })
        self.last_log_train_step = step

    def log_test_data(self, collect_result: dict, step: int) -> None:
        """Log evaluation statistics if the test interval has elapsed.

        The mean and standard deviation of ``collect_result["rews"]`` and
        ``collect_result["lens"]`` are always stored back into
        ``collect_result`` under ``"rew"``, ``"rew_std"``, ``"len"`` and
        ``"len_std"``.

        :param collect_result: dict read for the keys ``"n/ep"`` (asserted to
            be positive), ``"rews"`` and ``"lens"``; the latter two must
            provide ``.mean()`` and ``.std()``.
        :param int step: value logged as ``test/env_step`` and used for
            throttling.
        """
        assert collect_result["n/ep"] > 0
        rewards = collect_result["rews"]
        lengths = collect_result["lens"]
        reward_mean = rewards.mean()
        reward_std = rewards.std()
        length_mean = lengths.mean()
        length_std = lengths.std()
        collect_result.update({
            "rew": reward_mean,
            "rew_std": reward_std,
            "len": length_mean,
            "len_std": length_std,
        })

        if step - self.last_log_test_step < self.test_interval:
            return

        wandb.log({
            "test/env_step": step,
            "test/reward": reward_mean,
            "test/length": length_mean,
            "test/reward_std": reward_std,
            "test/length_std": length_std,
        })
        self.last_log_test_step = step

    def log_update_data(self, update_result: dict, step: int) -> None:
        """Log update statistics if the update interval has elapsed.

        Every entry of ``update_result`` is logged under ``update/<key>``,
        together with ``update/gradient_step`` set to ``step``.

        :param update_result: dict whose items are forwarded to wandb.
        :param int step: value logged as ``update/gradient_step`` and used for
            throttling.
        """
        if step - self.last_log_update_step < self.update_interval:
            return

        record = {f'update/{name}': value for name, value in update_result.items()}
        # Assigned last so the step wins over any "gradient_step" entry
        # coming from update_result.
        record['update/gradient_step'] = step
        wandb.log(record)

        self.last_log_update_step = step

Metadata

Metadata

Assignees

No one assigned

    Labels

    enhancement: Feature that is not a new algorithm or an algorithm enhancement

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions