
Add Weights and Biases Logger #427

Merged: 7 commits, merged on Aug 30, 2021
2 changes: 1 addition & 1 deletion README.md
@@ -192,7 +192,7 @@ buffer_size = 20000
eps_train, eps_test = 0.1, 0.05
step_per_epoch, step_per_collect = 10000, 10
writer = SummaryWriter('log/dqn') # tensorboard is also supported!
- logger = ts.utils.BasicLogger(writer)
+ logger = ts.utils.TensorboardLogger(writer)
```

Make environments:
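The README's next step (elided here) builds vectorized environments. A minimal sketch of what that step looks like, assuming the usual `ts.env.DummyVectorEnv` API; the task name and environment counts are illustrative, not taken from this diff:

```python
import gym
import tianshou as ts

task = 'CartPole-v0'
train_num, test_num = 10, 100
# a vectorized env wraps a list of factories, one per parallel environment
train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
```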
4 changes: 2 additions & 2 deletions docs/tutorials/cheatsheet.rst
@@ -40,8 +40,8 @@ This is related to `Issue 349 <https://github.com/thu-ml/tianshou/issues/349>`_.
To resume training process from an existing checkpoint, you need to do the following things in the training process:

1. Make sure you write ``save_checkpoint_fn`` which saves everything needed in the training process, i.e., policy, optim, buffer; pass it to trainer;
- 2. Use ``BasicLogger`` which contains a tensorboard;
- 3. To adjust the save frequency, specify ``save_interval`` when initializing BasicLogger.
+ 2. Use ``TensorboardLogger``;
+ 3. To adjust the save frequency, specify ``save_interval`` when initializing TensorboardLogger.
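A minimal sketch of these three steps, assuming the trainer accepts ``save_checkpoint_fn`` and a ``resume_from_log`` flag as described around this part of the docs; ``policy`` and ``optim`` come from the surrounding training script:

```python
import os

import torch
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger

log_path = 'log/dqn_resume'
writer = SummaryWriter(log_path)
# save_interval controls how often (in epochs) save_checkpoint_fn is invoked
logger = TensorboardLogger(writer, save_interval=1)

def save_checkpoint_fn(epoch, env_step, gradient_step):
    # save everything needed to resume: policy, optimizer (and buffer if used)
    torch.save({
        'model': policy.state_dict(),
        'optim': optim.state_dict(),
    }, os.path.join(log_path, 'checkpoint.pth'))

# pass both to the trainer; on restart, resume_from_log=True lets the logger
# restore the epoch / env_step counters it recorded
# result = offpolicy_trainer(..., logger=logger,
#                            save_checkpoint_fn=save_checkpoint_fn,
#                            resume_from_log=True)
```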
Collaborator:

Is it possible to resume an experiment with wandb? like #350

Contributor Author:

I think so, but I've never dealt with it before. I suggest waiting for someone with that specific need to raise an issue.
To resume a run you would need to start it differently, and that may be extra complexity not everyone needs.
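For reference, the logger added by this PR should drop in wherever ``TensorboardLogger`` is used above. A hypothetical sketch only — the class name ``WandbLogger`` and its keyword arguments are assumptions, so check the new logger module in this diff for the actual interface:

```python
# hypothetical usage; class name and arguments are assumptions, not quoted from this PR
from tianshou.utils import WandbLogger

logger = WandbLogger(project='tianshou', name='dqn-CartPole')
# then pass it to a trainer exactly like TensorboardLogger:
# result = offpolicy_trainer(..., logger=logger)
```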


And to successfully resume from a checkpoint:

4 changes: 2 additions & 2 deletions docs/tutorials/dqn.rst
@@ -148,9 +148,9 @@ The trainer supports `TensorBoard <https://www.tensorflow.org/tensorboard>`_ for
::

from torch.utils.tensorboard import SummaryWriter
- from tianshou.utils import BasicLogger
+ from tianshou.utils import TensorboardLogger
writer = SummaryWriter('log/dqn')
- logger = BasicLogger(writer)
+ logger = TensorboardLogger(writer)
Pass the logger into the trainer, and the training result will be recorded into the TensorBoard.
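For example, a sketch of that call with illustrative hyperparameters (``policy`` and the collectors are defined earlier in the tutorial; the values shown are not part of this diff):

```python
from tianshou.trainer import offpolicy_trainer

result = offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=10000, step_per_collect=10,
    episode_per_test=100, batch_size=64,
    logger=logger,  # training statistics are written to the SummaryWriter above
)
```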

4 changes: 2 additions & 2 deletions docs/tutorials/tictactoe.rst
@@ -176,7 +176,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modules
import numpy as np
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter
- from tianshou.utils import BasicLogger
+ from tianshou.utils import TensorboardLogger

from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
@@ -323,7 +323,7 @@ With the above preparation, we are close to the first learned agent. The following
log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
- logger = BasicLogger(writer)
+ logger = TensorboardLogger(writer)

# ======== callback functions used during training =========

4 changes: 2 additions & 2 deletions examples/atari/atari_bcq.py
@@ -7,7 +7,7 @@
import numpy as np
from torch.utils.tensorboard import SummaryWriter

- from tianshou.utils import BasicLogger
+ from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offline_trainer
from tianshou.utils.net.discrete import Actor
@@ -116,7 +116,7 @@ def test_discrete_bcq(args=get_args()):
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
- logger = BasicLogger(writer, update_interval=args.log_interval)
+ logger = TensorboardLogger(writer, update_interval=args.log_interval)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
4 changes: 2 additions & 2 deletions examples/atari/atari_c51.py
@@ -6,7 +6,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import C51Policy
- from tianshou.utils import BasicLogger
+ from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
@@ -101,7 +101,7 @@ def test_c51(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'c51')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
- logger = BasicLogger(writer)
+ logger = TensorboardLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
4 changes: 2 additions & 2 deletions examples/atari/atari_cql.py
@@ -7,7 +7,7 @@
import numpy as np
from torch.utils.tensorboard import SummaryWriter

- from tianshou.utils import BasicLogger
+ from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offline_trainer
from tianshou.policy import DiscreteCQLPolicy
@@ -108,7 +108,7 @@ def test_discrete_cql(args=get_args()):
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
- logger = BasicLogger(writer, update_interval=args.log_interval)
+ logger = TensorboardLogger(writer, update_interval=args.log_interval)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
4 changes: 2 additions & 2 deletions examples/atari/atari_crr.py
@@ -7,7 +7,7 @@
import numpy as np
from torch.utils.tensorboard import SummaryWriter

- from tianshou.utils import BasicLogger
+ from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offline_trainer
from tianshou.utils.net.discrete import Actor
@@ -117,7 +117,7 @@ def test_discrete_crr(args=get_args()):
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
- logger = BasicLogger(writer, update_interval=args.log_interval)
+ logger = TensorboardLogger(writer, update_interval=args.log_interval)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
4 changes: 2 additions & 2 deletions examples/atari/atari_dqn.py
@@ -6,7 +6,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import DQNPolicy
- from tianshou.utils import BasicLogger
+ from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
@@ -96,7 +96,7 @@ def test_dqn(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'dqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
- logger = BasicLogger(writer)
+ logger = TensorboardLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
4 changes: 2 additions & 2 deletions examples/atari/atari_fqf.py
@@ -6,7 +6,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import FQFPolicy
- from tianshou.utils import BasicLogger
+ from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
@@ -112,7 +112,7 @@ def test_fqf(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'fqf')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
- logger = BasicLogger(writer)
+ logger = TensorboardLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
4 changes: 2 additions & 2 deletions examples/atari/atari_iqn.py
@@ -6,7 +6,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import IQNPolicy
- from tianshou.utils import BasicLogger
+ from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
@@ -109,7 +109,7 @@ def test_iqn(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'iqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
- logger = BasicLogger(writer)
+ logger = TensorboardLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
4 changes: 2 additions & 2 deletions examples/atari/atari_qrdqn.py
@@ -5,7 +5,7 @@
import numpy as np
from torch.utils.tensorboard import SummaryWriter

- from tianshou.utils import BasicLogger
+ from tianshou.utils import TensorboardLogger
from tianshou.policy import QRDQNPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
@@ -99,7 +99,7 @@ def test_qrdqn(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'qrdqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
- logger = BasicLogger(writer)
+ logger = TensorboardLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
4 changes: 2 additions & 2 deletions examples/atari/atari_rainbow.py
@@ -7,7 +7,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import RainbowPolicy
- from tianshou.utils import BasicLogger
+ from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
@@ -121,7 +121,7 @@ def test_rainbow(args=get_args()):
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
- logger = BasicLogger(writer)
+ logger = TensorboardLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
4 changes: 2 additions & 2 deletions examples/box2d/acrobot_dualdqn.py
@@ -7,7 +7,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import DQNPolicy
- from tianshou.utils import BasicLogger
+ from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@@ -82,7 +82,7 @@ def test_dqn(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'dqn')
writer = SummaryWriter(log_path)
- logger = BasicLogger(writer)
+ logger = TensorboardLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
4 changes: 2 additions & 2 deletions examples/box2d/bipedal_hardcore_sac.py
@@ -7,7 +7,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import SACPolicy
- from tianshou.utils import BasicLogger
+ from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
@@ -134,7 +134,7 @@ def test_sac_bipedal(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'sac')
writer = SummaryWriter(log_path)
- logger = BasicLogger(writer)
+ logger = TensorboardLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
4 changes: 2 additions & 2 deletions examples/box2d/lunarlander_dqn.py
@@ -7,7 +7,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import DQNPolicy
- from tianshou.utils import BasicLogger
+ from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
@@ -84,7 +84,7 @@ def test_dqn(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'dqn')
writer = SummaryWriter(log_path)
- logger = BasicLogger(writer)
+ logger = TensorboardLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
4 changes: 2 additions & 2 deletions examples/box2d/mcc_sac.py
@@ -7,7 +7,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import SACPolicy
- from tianshou.utils import BasicLogger
+ from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.exploration import OUNoise
from tianshou.utils.net.common import Net
@@ -104,7 +104,7 @@ def test_sac(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'sac')
writer = SummaryWriter(log_path)
- logger = BasicLogger(writer)
+ logger = TensorboardLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_a2c.py
@@ -13,7 +13,7 @@
from torch.distributions import Independent, Normal

from tianshou.policy import A2CPolicy
- from tianshou.utils import BasicLogger
+ from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
@@ -141,7 +141,7 @@ def dist(*logits):
log_path = os.path.join(args.logdir, args.task, 'a2c', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
- logger = BasicLogger(writer, update_interval=100, train_interval=100)
+ logger = TensorboardLogger(writer, update_interval=100, train_interval=100)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_ddpg.py
@@ -10,7 +10,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import DDPGPolicy
- from tianshou.utils import BasicLogger
+ from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.exploration import GaussianNoise
@@ -110,7 +110,7 @@ def test_ddpg(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'ddpg', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
- logger = BasicLogger(writer)
+ logger = TensorboardLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_npg.py
@@ -13,7 +13,7 @@
from torch.distributions import Independent, Normal

from tianshou.policy import NPGPolicy
- from tianshou.utils import BasicLogger
+ from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
@@ -142,7 +142,7 @@ def dist(*logits):
log_path = os.path.join(args.logdir, args.task, 'npg', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
- logger = BasicLogger(writer, update_interval=100, train_interval=100)
+ logger = TensorboardLogger(writer, update_interval=100, train_interval=100)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_ppo.py
@@ -13,7 +13,7 @@
from torch.distributions import Independent, Normal

from tianshou.policy import PPOPolicy
- from tianshou.utils import BasicLogger
+ from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
@@ -149,7 +149,7 @@ def dist(*logits):
log_path = os.path.join(args.logdir, args.task, 'ppo', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
- logger = BasicLogger(writer, update_interval=100, train_interval=100)
+ logger = TensorboardLogger(writer, update_interval=100, train_interval=100)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_reinforce.py
@@ -13,7 +13,7 @@
from torch.distributions import Independent, Normal

from tianshou.policy import PGPolicy
- from tianshou.utils import BasicLogger
+ from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
@@ -131,7 +131,7 @@ def dist(*logits):
log_path = os.path.join(args.logdir, args.task, 'reinforce', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
- logger = BasicLogger(writer, update_interval=10, train_interval=100)
+ logger = TensorboardLogger(writer, update_interval=10, train_interval=100)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
4 changes: 2 additions & 2 deletions examples/mujoco/mujoco_sac.py
@@ -10,7 +10,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import SACPolicy
- from tianshou.utils import BasicLogger
+ from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@@ -122,7 +122,7 @@ def test_sac(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'sac', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
- logger = BasicLogger(writer)
+ logger = TensorboardLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))