diff --git a/.github/workflows/extra_sys.yml b/.github/workflows/extra_sys.yml index 124ec40d4..df21abc31 100644 --- a/.github/workflows/extra_sys.yml +++ b/.github/workflows/extra_sys.yml @@ -22,6 +22,9 @@ jobs: - name: Install dependencies run: | python -m pip install ".[dev]" --upgrade + - name: wandb login + run: | + wandb login ${{ secrets.WANDB_API_KEY }} - name: Test with pytest run: | pytest test/base test/continuous --cov=tianshou --durations=0 -v diff --git a/.github/workflows/gputest.yml b/.github/workflows/gputest.yml index b973c34d5..8032bd3f9 100644 --- a/.github/workflows/gputest.yml +++ b/.github/workflows/gputest.yml @@ -18,6 +18,9 @@ jobs: - name: Install dependencies run: | python -m pip install ".[dev]" --upgrade + - name: wandb login + run: | + wandb login ${{ secrets.WANDB_API_KEY }} - name: Test with pytest # ignore test/throughput which only profiles the code run: | diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index ac52b1979..102219bf4 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -21,6 +21,9 @@ jobs: - name: Install dependencies run: | python -m pip install ".[dev]" --upgrade + - name: wandb login + run: | + wandb login ${{ secrets.WANDB_API_KEY }} - name: Test with pytest # ignore test/throughput which only profiles the code run: | diff --git a/README.md b/README.md index 9e2d7a9da..40fcdd314 100644 --- a/README.md +++ b/README.md @@ -47,12 +47,13 @@ Here is Tianshou's other features: - Elegant framework, using only ~4000 lines of code - State-of-the-art [MuJoCo benchmark](https://github.com/thu-ml/tianshou/tree/master/examples/mujoco) for REINFORCE/A2C/TRPO/PPO/DDPG/TD3/SAC algorithms -- Support parallel environment simulation (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#parallel-sampling) -- Support recurrent state representation in actor network and critic 
network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training) -- Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation) -- Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process) +- Support parallel environment simulation (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#parallel-sampling) +- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#rnn-style-training) +- Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation) +- Support customized training process [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#customize-training-process) - Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation -- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html##multi-agent-reinforcement-learning) +- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#multi-agent-reinforcement-learning) +- Support both [TensorBoard](https://www.tensorflow.org/tensorboard) and [W&B](https://wandb.ai/) log tools - Comprehensive documentation, PEP8 code-style checking, type checking and [unit tests](https://github.com/thu-ml/tianshou/actions) In Chinese, 
Tianshou means divinely ordained and is derived to the gift of being born with. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. So taking "Tianshou" means that there is no teacher to study with, but rather to learn by themselves through constant interaction with the environment. @@ -191,8 +192,7 @@ gamma, n_step, target_freq = 0.9, 3, 320 buffer_size = 20000 eps_train, eps_test = 0.1, 0.05 step_per_epoch, step_per_collect = 10000, 10 -writer = SummaryWriter('log/dqn') # tensorboard is also supported! -logger = ts.utils.TensorboardLogger(writer) +logger = ts.utils.TensorboardLogger(SummaryWriter('log/dqn')) # TensorBoard is supported! ``` Make environments: @@ -208,7 +208,7 @@ Define the network: ```python from tianshou.utils.net.common import Net # you can define other net by following the API: -# https://tianshou.readthedocs.io/en/latest/tutorials/dqn.html#build-the-network +# https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network env = gym.make(task) state_shape = env.observation_space.shape or env.observation_space.n action_shape = env.action_space.shape or env.action_space.n @@ -273,7 +273,7 @@ $ python3 test/discrete/test_pg.py --seed 0 --render 0.03 ## Contributing -Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [this link](https://tianshou.readthedocs.io/en/latest/contributing.html). +Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [this link](https://tianshou.readthedocs.io/en/master/contributing.html). ## Citing Tianshou @@ -281,7 +281,7 @@ If you find Tianshou useful, please cite it in your publications. 
```latex @article{weng2021tianshou, - title={Tianshou: a Highly Modularized Deep Reinforcement Learning Library}, + title={Tianshou: A Highly Modularized Deep Reinforcement Learning Library}, author={Weng, Jiayi and Chen, Huayu and Yan, Dong and You, Kaichao and Duburcq, Alexis and Zhang, Minghao and Su, Hang and Zhu, Jun}, journal={arXiv preprint arXiv:2107.14171}, year={2021} diff --git a/docs/index.rst b/docs/index.rst index 5c332451d..a4dae5ba1 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -44,9 +44,10 @@ Here is Tianshou's other features: * Support :ref:`customize_training` * Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation * Support :doc:`/tutorials/tictactoe` +* Support both `TensorBoard <https://www.tensorflow.org/tensorboard>`_ and `W&B <https://wandb.ai/>`_ log tools * Comprehensive `unit tests <https://github.com/thu-ml/tianshou/actions>`_, including functional checking, RL pipeline checking, documentation checking, PEP8 code-style checking, and type checking -中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ <https://tianshou.readthedocs.io/zh/latest/>`_ +中文文档位于 `https://tianshou.readthedocs.io/zh/master/ <https://tianshou.readthedocs.io/zh/master/>`_ Installation diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 67a44d002..2cd26f824 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -12,7 +12,7 @@ from tianshou.env import ShmemVectorEnv from tianshou.policy import DQNPolicy from tianshou.trainer import offpolicy_trainer -from tianshou.utils import TensorboardLogger +from tianshou.utils import TensorboardLogger, WandbLogger def get_args(): @@ -41,6 +41,13 @@ def get_args(): ) parser.add_argument('--frames-stack', type=int, default=4) 
parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument('--resume-id', type=str, default=None) + parser.add_argument( + '--logger', + type=str, + default="tensorboard", + 
choices=["tensorboard", "wandb"], + ) parser.add_argument( '--watch', default=False, @@ -112,9 +119,18 @@ def test_dqn(args=get_args()): test_collector = Collector(policy, test_envs, exploration_noise=True) # log log_path = os.path.join(args.logdir, args.task, 'dqn') - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - logger = TensorboardLogger(writer) + if args.logger == "tensorboard": + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = TensorboardLogger(writer) + else: + logger = WandbLogger( + save_interval=1, + project=args.task, + name='dqn', + run_id=args.resume_id, + config=args, + ) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -141,6 +157,12 @@ def train_fn(epoch, env_step): def test_fn(epoch, env_step): policy.set_eps(args.eps_test) + def save_checkpoint_fn(epoch, env_step, gradient_step): + # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html + ckpt_path = os.path.join(log_path, 'checkpoint.pth') + torch.save({'model': policy.state_dict()}, ckpt_path) + return ckpt_path + # watch agent's performance def watch(): print("Setup test envs ...") @@ -192,7 +214,9 @@ def watch(): save_fn=save_fn, logger=logger, update_per_step=args.update_per_step, - test_in_train=False + test_in_train=False, + resume_from_log=args.resume_id is not None, + save_checkpoint_fn=save_checkpoint_fn, ) pprint.pprint(result) diff --git a/setup.py b/setup.py index bf48020e5..80209edcb 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ def get_version() -> str: exclude=["test", "test.*", "examples", "examples.*", "docs", "docs.*"] ), install_requires=[ - "gym>=0.15.4", + "gym>=0.15.4,<0.20", "tqdm", "numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793 "tensorboard>=2.5.0", diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index 3a50f36e9..b6b05b48e 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ 
-11,6 +11,7 @@ from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import PSRLPolicy from tianshou.trainer import onpolicy_trainer +from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger def get_args(): @@ -30,6 +31,12 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--eps', type=float, default=0.01) parser.add_argument('--add-done-loop', action="store_true", default=False) + parser.add_argument( + '--logger', + type=str, + default="wandb", + choices=["wandb", "tensorboard", "none"], + ) return parser.parse_known_args()[0] @@ -72,10 +79,18 @@ def test_psrl(args=get_args()): exploration_noise=True ) test_collector = Collector(policy, test_envs) - # log - log_path = os.path.join(args.logdir, args.task, 'psrl') - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) + # Logger + if args.logger == "wandb": + logger = WandbLogger( + save_interval=1, project='psrl', name='wandb_test', config=args + ) + elif args.logger == "tensorboard": + log_path = os.path.join(args.logdir, args.task, 'psrl') + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = TensorboardLogger(writer) + else: + logger = LazyLogger() def stop_fn(mean_rewards): if env.spec.reward_threshold: @@ -96,8 +111,8 @@ def stop_fn(mean_rewards): 0, episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, - # logger=logger, - test_in_train=False + logger=logger, + test_in_train=False, ) if __name__ == '__main__': diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py index 5af038ab3..25ceda186 100644 --- a/tianshou/utils/__init__.py +++ b/tianshou/utils/__init__.py @@ -3,10 +3,10 @@ from tianshou.utils.config import tqdm_config from tianshou.utils.logger.base import BaseLogger, LazyLogger from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger -from tianshou.utils.logger.wandb import WandBLogger +from tianshou.utils.logger.wandb import 
WandbLogger from tianshou.utils.statistics import MovAvg, RunningMeanStd __all__ = [ "MovAvg", "RunningMeanStd", "tqdm_config", "BaseLogger", "TensorboardLogger", - "BasicLogger", "LazyLogger", "WandBLogger" + "BasicLogger", "LazyLogger", "WandbLogger" ] diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py index 7a837c96c..f9c047c59 100644 --- a/tianshou/utils/logger/wandb.py +++ b/tianshou/utils/logger/wandb.py @@ -1,3 +1,7 @@ +import argparse +import os +from typing import Callable, Optional, Tuple + from tianshou.utils import BaseLogger from tianshou.utils.logger.base import LOG_DATA_TYPE @@ -7,10 +11,10 @@ pass -class WandBLogger(BaseLogger): - """Weights and Biases logger that sends data to Weights and Biases. +class WandbLogger(BaseLogger): + """Weights and Biases logger that sends data to https://wandb.ai/. - Creates three panels with plots: train, test, and update. + This logger creates three panels with plots: train, test, and update. Make sure to select the correct access for each panel in weights and biases: - ``train/env_step`` for train plots @@ -29,6 +33,11 @@ class WandBLogger(BaseLogger): :param int test_interval: the log interval in log_test_data(). Default to 1. :param int update_interval: the log interval in log_update_data(). Default to 1000. + :param str project: W&B project name. Default to "tianshou". + :param str name: W&B run name. Default to None. If None, random name is assigned. + :param str entity: W&B team/organization name. Default to None. + :param str run_id: run id of W&B run to be resumed. Default to None. + :param argparse.Namespace config: experiment configurations. Default to None. 
""" def __init__( @@ -36,9 +45,86 @@ def __init__( train_interval: int = 1000, test_interval: int = 1, update_interval: int = 1000, + save_interval: int = 1000, + project: str = 'tianshou', + name: Optional[str] = None, + entity: Optional[str] = None, + run_id: Optional[str] = None, + config: Optional[argparse.Namespace] = None, ) -> None: super().__init__(train_interval, test_interval, update_interval) + self.last_save_step = -1 + self.save_interval = save_interval + self.restored = False + + self.wandb_run = wandb.init( + project=project, + name=name, + id=run_id, + resume="allow", + entity=entity, + monitor_gym=True, + config=config, # type: ignore + ) if not wandb.run else wandb.run + self.wandb_run._label(repo="tianshou") # type: ignore def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: data[step_type] = step wandb.log(data) + + def save_data( + self, + epoch: int, + env_step: int, + gradient_step: int, + save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + ) -> None: + """Upload the checkpoint from ``save_checkpoint_fn`` as a W&B artifact. + + :param int epoch: the epoch in trainer. + :param int env_step: the env_step in trainer. + :param int gradient_step: the gradient_step in trainer. + :param function save_checkpoint_fn: a hook defined by user, see trainer + documentation for detail. It must return the checkpoint file path. 
+ """ + if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval: + self.last_save_step = epoch + checkpoint_path = save_checkpoint_fn(epoch, env_step, gradient_step) + + checkpoint_artifact = wandb.Artifact( + 'run_' + self.wandb_run.id + '_checkpoint', # type: ignore + type='model', + metadata={ + "save/epoch": epoch, + "save/env_step": env_step, + "save/gradient_step": gradient_step, + "checkpoint_path": str(checkpoint_path) + } + ) + checkpoint_artifact.add_file(str(checkpoint_path)) + self.wandb_run.log_artifact(checkpoint_artifact) # type: ignore + + def restore_data(self) -> Tuple[int, int, int]: + """Restore step counters from the latest W&B checkpoint artifact.""" + checkpoint_artifact = self.wandb_run.use_artifact( # type: ignore + 'run_' + self.wandb_run.id + '_checkpoint:latest' # type: ignore + ) + assert checkpoint_artifact is not None, "W&B checkpoint artifact doesn't exist" + + checkpoint_artifact.download( + os.path.dirname(checkpoint_artifact.metadata['checkpoint_path']) + ) + + try: # epoch / gradient_step + epoch = checkpoint_artifact.metadata["save/epoch"] + self.last_save_step = self.last_log_test_step = epoch + gradient_step = checkpoint_artifact.metadata["save/gradient_step"] + self.last_log_update_step = gradient_step + except KeyError: + epoch, gradient_step = 0, 0 + try: # offline trainer doesn't have env_step + env_step = checkpoint_artifact.metadata["save/env_step"] + self.last_log_train_step = env_step + except KeyError: + env_step = 0 + return epoch, env_step, gradient_step