
W&B: add artifacts support #441


Merged: 31 commits, merged on Sep 24, 2021

Commits (31)
- 3a46660  add artifacts support (AyushExel, Sep 7, 2021)
- 32b3507  remove print (AyushExel, Sep 7, 2021)
- 1cc306e  Update tianshou/utils/logger/wandb.py (AyushExel, Sep 7, 2021)
- 5cc42b4  monitor gym (AyushExel, Sep 8, 2021)
- 5617fe2  monitor gym (AyushExel, Sep 8, 2021)
- c751375  Merge branch 'master' into wandb (Trinkle23897, Sep 8, 2021)
- 94de701  Update logger (AyushExel, Sep 15, 2021)
- 0ac506c  repo label (AyushExel, Sep 15, 2021)
- 00e4afb  add test (AyushExel, Sep 16, 2021)
- e8616ff  update gym req. (AyushExel, Sep 16, 2021)
- 8e5c71f  ignore mypy checks (AyushExel, Sep 16, 2021)
- acc0d1b  flake8 (AyushExel, Sep 16, 2021)
- 3f4e9f8  update ci file (Trinkle23897, Sep 20, 2021)
- 67943e3  try to fix ci (Trinkle23897, Sep 20, 2021)
- c7ad697  fix ci (AyushExel, Sep 20, 2021)
- 86a274b  try to fix ci (AyushExel, Sep 20, 2021)
- 7d423cd  try to fix ci (AyushExel, Sep 20, 2021)
- 761c40b  try ci fix (AyushExel, Sep 20, 2021)
- f3fc3ea  try ci fix (AyushExel, Sep 20, 2021)
- 3c09be7  try ci fix (AyushExel, Sep 20, 2021)
- 0931a08  update docs (Trinkle23897, Sep 23, 2021)
- 0bd9134  Update wandb.py (AyushExel, Sep 23, 2021)
- 261ee46  unify logger test on psrl (Trinkle23897, Sep 23, 2021)
- 8d7e423  Merge branch 'wandb' of github.com:AyushExel/tianshou into wandb (Trinkle23897, Sep 23, 2021)
- b1698c0  Update wandb.py (AyushExel, Sep 23, 2021)
- 684afbb  fix format (Trinkle23897, Sep 23, 2021)
- ebf6015  add config logging (AyushExel, Sep 23, 2021)
- 0853f6b  Merge branch 'wandb' of https://github.com/AyushExel/tianshou into wandb (AyushExel, Sep 23, 2021)
- 397de8a  update (AyushExel, Sep 23, 2021)
- 1635000  merge atari_wandb into original file (Trinkle23897, Sep 23, 2021)
- dac8958  fix format (Trinkle23897, Sep 23, 2021)
Files changed
3 changes: 3 additions & 0 deletions .github/workflows/extra_sys.yml
@@ -22,6 +22,9 @@ jobs:
- name: Install dependencies
run: |
python -m pip install ".[dev]" --upgrade
- name: wandb login
run: |
wandb login e2366d661b89f2bee877c40bee15502d67b7abef
- name: Test with pytest
run: |
pytest test/base test/continuous --cov=tianshou --durations=0 -v
3 changes: 3 additions & 0 deletions .github/workflows/gputest.yml
@@ -18,6 +18,9 @@ jobs:
- name: Install dependencies
run: |
python -m pip install ".[dev]" --upgrade
- name: wandb login
run: |
wandb login e2366d661b89f2bee877c40bee15502d67b7abef
- name: Test with pytest
# ignore test/throughput which only profiles the code
run: |
3 changes: 3 additions & 0 deletions .github/workflows/pytest.yml
@@ -21,6 +21,9 @@ jobs:
- name: Install dependencies
run: |
python -m pip install ".[dev]" --upgrade
- name: wandb login
run: |
wandb login e2366d661b89f2bee877c40bee15502d67b7abef
- name: Test with pytest
# ignore test/throughput which only profiles the code
run: |
20 changes: 10 additions & 10 deletions README.md
@@ -47,12 +47,13 @@ Here is Tianshou's other features:

- Elegant framework, using only ~4000 lines of code
- State-of-the-art [MuJoCo benchmark](https://github.com/thu-ml/tianshou/tree/master/examples/mujoco) for REINFORCE/A2C/TRPO/PPO/DDPG/TD3/SAC algorithms
- Support parallel environment simulation (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#parallel-sampling)
- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training)
- Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
- Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process)
- Support parallel environment simulation (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#parallel-sampling)
- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#rnn-style-training)
- Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
- Support customized training process [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#customize-training-process)
- Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html##multi-agent-reinforcement-learning)
- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html##multi-agent-reinforcement-learning)
- Support both [TensorBoard](https://www.tensorflow.org/tensorboard) and [W&B](https://wandb.ai/) log tools
- Comprehensive documentation, PEP8 code-style checking, type checking and [unit tests](https://github.com/thu-ml/tianshou/actions)

In Chinese, Tianshou means divinely ordained and is derived to the gift of being born with. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. So taking "Tianshou" means that there is no teacher to study with, but rather to learn by themselves through constant interaction with the environment.
@@ -191,8 +192,7 @@ gamma, n_step, target_freq = 0.9, 3, 320
buffer_size = 20000
eps_train, eps_test = 0.1, 0.05
step_per_epoch, step_per_collect = 10000, 10
writer = SummaryWriter('log/dqn') # tensorboard is also supported!
logger = ts.utils.TensorboardLogger(writer)
logger = ts.utils.TensorboardLogger(SummaryWriter('log/dqn')) # TensorBoard is supported!
```
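
For completeness, since the README now lists W&B alongside TensorBoard, here is a minimal W&B counterpart of the logger line above. This is only a sketch and not part of this PR's README; it assumes a logged-in W&B account, and the project and run names are placeholders.

```python
import tianshou as ts

# Hypothetical W&B analogue of the TensorBoard one-liner above.
# Run `wandb login` first; 'tianshou' and 'dqn-quickstart' are placeholder names.
logger = ts.utils.WandbLogger(project='tianshou', name='dqn-quickstart')
```

It is passed to the trainer through the same `logger=...` argument shown later in the quickstart.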

Make environments:
@@ -208,7 +208,7 @@ Define the network:
```python
from tianshou.utils.net.common import Net
# you can define other net by following the API:
# https://tianshou.readthedocs.io/en/latest/tutorials/dqn.html#build-the-network
# https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network
env = gym.make(task)
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
@@ -273,15 +273,15 @@ $ python3 test/discrete/test_pg.py --seed 0 --render 0.03

## Contributing

Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [this link](https://tianshou.readthedocs.io/en/latest/contributing.html).
Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [this link](https://tianshou.readthedocs.io/en/master/contributing.html).

## Citing Tianshou

If you find Tianshou useful, please cite it in your publications.

```latex
@article{weng2021tianshou,
title={Tianshou: a Highly Modularized Deep Reinforcement Learning Library},
title={Tianshou: A Highly Modularized Deep Reinforcement Learning Library},
author={Weng, Jiayi and Chen, Huayu and Yan, Dong and You, Kaichao and Duburcq, Alexis and Zhang, Minghao and Su, Hang and Zhu, Jun},
journal={arXiv preprint arXiv:2107.14171},
year={2021}
3 changes: 2 additions & 1 deletion docs/index.rst
@@ -44,9 +44,10 @@ Here is Tianshou's other features:
* Support :ref:`customize_training`
* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
* Support :doc:`/tutorials/tictactoe`
* Support both `TensorBoard <https://www.tensorflow.org/tensorboard>`_ and `W&B <https://wandb.ai/>`_ log tools
* Comprehensive `unit tests <https://github.com/thu-ml/tianshou/actions>`_, including functional checking, RL pipeline checking, documentation checking, PEP8 code-style checking, and type checking

中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ <https://tianshou.readthedocs.io/zh/latest/>`_
中文文档位于 `https://tianshou.readthedocs.io/zh/master/ <https://tianshou.readthedocs.io/zh/master/>`_


Installation
34 changes: 29 additions & 5 deletions examples/atari/atari_dqn.py
@@ -12,7 +12,7 @@
from tianshou.env import ShmemVectorEnv
from tianshou.policy import DQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils import TensorboardLogger, WandbLogger


def get_args():
@@ -41,6 +41,13 @@ def get_args():
)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--resume-id', type=str, default=None)
parser.add_argument(
'--logger',
Review thread on this hunk:

Collaborator: and maybe we can add some instructions on how to use wandb (including resume) in examples/atari/README.md instead of tensorboard?

Contributor (author): Can we add the instructions in doc also? I can make a separate PR tomorrow for adding the detailed instructions in examples/atari/README.md as well as docs.

Collaborator: Just append to this pr

type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument(
'--watch',
default=False,
@@ -112,9 +119,18 @@ def test_dqn(args=get_args()):
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_path = os.path.join(args.logdir, args.task, 'dqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
if args.logger == "tensorboard":
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
else:
logger = WandbLogger(
save_interval=1,
project=args.task,
name='dqn',
run_id=args.resume_id,
config=args,
)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -141,6 +157,12 @@ def train_fn(epoch, env_step):
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)

def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
torch.save({'model': policy.state_dict()}, ckpt_path)
return ckpt_path

# watch agent's performance
def watch():
print("Setup test envs ...")
@@ -192,7 +214,9 @@ def watch():
save_fn=save_fn,
logger=logger,
update_per_step=args.update_per_step,
test_in_train=False
test_in_train=False,
resume_from_log=args.resume_id is not None,
save_checkpoint_fn=save_checkpoint_fn,
)

pprint.pprint(result)
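
Following up on the review thread above about documenting wandb usage (including resume): below is a rough sketch of how the pieces added in this file fit together. It is not the full script; the CLI equivalent would be roughly `python3 atari_dqn.py --logger wandb --resume-id <run-id>`, and the placeholder policy, task, and run id are illustrative only.

```python
import argparse
import os

import torch

from tianshou.utils import WandbLogger

# Illustrative stand-ins for what atari_dqn.py builds via get_args() and DQNPolicy.
args = argparse.Namespace(task='PongNoFrameskip-v4', logger='wandb', resume_id='abc123')
policy = torch.nn.Linear(4, 2)  # placeholder for the real DQN policy network
log_path = os.path.join('log', args.task, 'dqn')
os.makedirs(log_path, exist_ok=True)

# Same constructor call as in the diff; a non-None run_id resumes that W&B run.
logger = WandbLogger(
    save_interval=1,
    project=args.task,
    name='dqn',
    run_id=args.resume_id,
    config=args,
)

def save_checkpoint_fn(epoch, env_step, gradient_step):
    # WandbLogger.save_data() uploads the file at the returned path as a W&B artifact.
    ckpt_path = os.path.join(log_path, 'checkpoint.pth')
    torch.save({'model': policy.state_dict()}, ckpt_path)
    return ckpt_path

# In the script these are forwarded to offpolicy_trainer(..., logger=logger,
# resume_from_log=args.resume_id is not None, save_checkpoint_fn=save_checkpoint_fn).
```

With `resume_from_log=True`, the trainer can then ask the logger to restore the saved epoch, env_step, and gradient_step before continuing; see `restore_data()` in the wandb.py diff below.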
2 changes: 1 addition & 1 deletion setup.py
@@ -47,7 +47,7 @@ def get_version() -> str:
exclude=["test", "test.*", "examples", "examples.*", "docs", "docs.*"]
),
install_requires=[
"gym>=0.15.4",
"gym>=0.15.4,<0.20",
"tqdm",
"numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793
"tensorboard>=2.5.0",
27 changes: 21 additions & 6 deletions test/modelbased/test_psrl.py
@@ -11,6 +11,7 @@
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
from tianshou.policy import PSRLPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger


def get_args():
@@ -30,6 +31,12 @@ def get_args():
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--eps', type=float, default=0.01)
parser.add_argument('--add-done-loop', action="store_true", default=False)
parser.add_argument(
'--logger',
type=str,
default="wandb",
choices=["wandb", "tensorboard", "none"],
)
return parser.parse_known_args()[0]


@@ -72,10 +79,18 @@ def test_psrl(args=get_args()):
exploration_noise=True
)
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'psrl')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
# Logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1, project='psrl', name='wandb_test', config=args
)
elif args.logger == "tensorboard":
log_path = os.path.join(args.logdir, args.task, 'psrl')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
else:
logger = LazyLogger()

def stop_fn(mean_rewards):
if env.spec.reward_threshold:
Expand All @@ -96,8 +111,8 @@ def stop_fn(mean_rewards):
0,
episode_per_collect=args.episode_per_collect,
stop_fn=stop_fn,
# logger=logger,
test_in_train=False
logger=logger,
test_in_train=False,
)

if __name__ == '__main__':
4 changes: 2 additions & 2 deletions tianshou/utils/__init__.py
@@ -3,10 +3,10 @@
from tianshou.utils.config import tqdm_config
from tianshou.utils.logger.base import BaseLogger, LazyLogger
from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger
from tianshou.utils.logger.wandb import WandBLogger
from tianshou.utils.logger.wandb import WandbLogger
from tianshou.utils.statistics import MovAvg, RunningMeanStd

__all__ = [
"MovAvg", "RunningMeanStd", "tqdm_config", "BaseLogger", "TensorboardLogger",
"BasicLogger", "LazyLogger", "WandBLogger"
"BasicLogger", "LazyLogger", "WandbLogger"
]
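
The export rename from WandBLogger to WandbLogger breaks any existing `from tianshou.utils import WandBLogger` imports. This PR does not add a shim, but if backward compatibility were wanted, a deprecation alias along these lines (purely a sketch, not part of the diff) would be one option:

```python
import warnings

from tianshou.utils.logger.wandb import WandbLogger


class WandBLogger(WandbLogger):
    """Deprecated alias kept so old imports keep working (hypothetical, not in this PR)."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "WandBLogger has been renamed to WandbLogger; please update your imports.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
```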
91 changes: 88 additions & 3 deletions tianshou/utils/logger/wandb.py
@@ -1,3 +1,7 @@
import argparse
import os
from typing import Callable, Optional, Tuple

from tianshou.utils import BaseLogger
from tianshou.utils.logger.base import LOG_DATA_TYPE

@@ -7,10 +11,10 @@
pass


class WandBLogger(BaseLogger):
"""Weights and Biases logger that sends data to Weights and Biases.
class WandbLogger(BaseLogger):
"""Weights and Biases logger that sends data to https://wandb.ai/.

Creates three panels with plots: train, test, and update.
This logger creates three panels with plots: train, test, and update.
Make sure to select the correct access for each panel in weights and biases:

Review comment (Collaborator): maybe we can here show where the example script is?

- ``train/env_step`` for train plots
@@ -29,16 +33,97 @@ class WandBLogger(BaseLogger):
:param int test_interval: the log interval in log_test_data(). Default to 1.
:param int update_interval: the log interval in log_update_data().
Default to 1000.
:param str project: W&B project name. Default to "tianshou".
:param str name: W&B run name. Default to None. If None, random name is assigned.
:param str entity: W&B team/organization name. Default to None.
:param str run_id: run id of W&B run to be resumed. Default to None.
:param argparse.Namespace config: experiment configurations. Default to None.
"""

def __init__(
self,
train_interval: int = 1000,
test_interval: int = 1,
update_interval: int = 1000,
save_interval: int = 1000,
project: str = 'tianshou',
name: Optional[str] = None,
entity: Optional[str] = None,
run_id: Optional[str] = None,
config: Optional[argparse.Namespace] = None,
) -> None:
super().__init__(train_interval, test_interval, update_interval)
self.last_save_step = -1
self.save_interval = save_interval
self.restored = False

self.wandb_run = wandb.init(
project=project,
name=name,
id=run_id,
resume="allow",
entity=entity,
monitor_gym=True,
config=config, # type: ignore
) if not wandb.run else wandb.run
self.wandb_run._label(repo="tianshou") # type: ignore

def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
data[step_type] = step
wandb.log(data)

def save_data(
self,
epoch: int,
env_step: int,
gradient_step: int,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
) -> None:
"""Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.

:param int epoch: the epoch in trainer.
:param int env_step: the env_step in trainer.
:param int gradient_step: the gradient_step in trainer.
:param function save_checkpoint_fn: a hook defined by user, see trainer
documentation for detail.
"""
if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
self.last_save_step = epoch
checkpoint_path = save_checkpoint_fn(epoch, env_step, gradient_step)

checkpoint_artifact = wandb.Artifact(
'run_' + self.wandb_run.id + '_checkpoint', # type: ignore
type='model',
metadata={
"save/epoch": epoch,
"save/env_step": env_step,
"save/gradient_step": gradient_step,
"checkpoint_path": str(checkpoint_path)
}
)
checkpoint_artifact.add_file(str(checkpoint_path))
self.wandb_run.log_artifact(checkpoint_artifact) # type: ignore

def restore_data(self) -> Tuple[int, int, int]:
checkpoint_artifact = self.wandb_run.use_artifact( # type: ignore
'run_' + self.wandb_run.id + '_checkpoint:latest' # type: ignore
)
assert checkpoint_artifact is not None, "W&B dataset artifact doesn't exist"

checkpoint_artifact.download(
os.path.dirname(checkpoint_artifact.metadata['checkpoint_path'])
)

try: # epoch / gradient_step
epoch = checkpoint_artifact.metadata["save/epoch"]
self.last_save_step = self.last_log_test_step = epoch
gradient_step = checkpoint_artifact.metadata["save/gradient_step"]
self.last_log_update_step = gradient_step
except KeyError:
epoch, gradient_step = 0, 0
try: # offline trainer doesn't have env_step
env_step = checkpoint_artifact.metadata["save/env_step"]
self.last_log_train_step = env_step
except KeyError:
env_step = 0
return epoch, env_step, gradient_step
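
To make the checkpointing round trip concrete: below is a simplified sketch of the trainer-side protocol that `save_data()` and `restore_data()` plug into, under the assumption that the trainer calls `restore_data()` once when `resume_from_log` is set and `save_data()` once per epoch (the real trainer loop is more involved).

```python
def run_training(logger, save_checkpoint_fn, resume_from_log, max_epoch=100):
    # Simplified stand-in for tianshou's trainer loop, shown only to illustrate
    # when the two WandbLogger methods above are invoked.
    start_epoch, env_step, gradient_step = 0, 0, 0
    if resume_from_log:
        # For WandbLogger: downloads the latest 'run_<id>_checkpoint' artifact and
        # returns the counters stored in its metadata.
        start_epoch, env_step, gradient_step = logger.restore_data()
    for epoch in range(start_epoch + 1, max_epoch + 1):
        # ... collect transitions, update the policy, call logger.log_*_data() ...
        # Every `save_interval` epochs, save_data() runs save_checkpoint_fn and
        # uploads the returned checkpoint file as a W&B artifact.
        logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
```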