
Use global_step as the x-axis for wandb #558


Merged: 15 commits, Mar 6, 2022
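
This change makes the Tianshou Atari example scripts report metrics to Weights & Biases against the trainer's global step rather than wandb's internal step counter. As a point of reference only (this is not the code added by the PR, which routes the existing WandbLogger through a SummaryWriter as shown in the diffs below), a minimal standalone sketch of the same idea using wandb's public define_metric API:

    import wandb

    # Standalone illustration: use "global_step" as the x-axis for every
    # metric instead of wandb's internal step counter.
    run = wandb.init(project="atari.benchmark", name="x-axis-demo")
    run.define_metric("global_step")
    run.define_metric("*", step_metric="global_step")

    for global_step in range(0, 1000, 100):
        # Log the step value alongside the metric so wandb can align them.
        run.log({"global_step": global_step, "train/reward": global_step * 0.01})

    run.finish()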
4 changes: 4 additions & 0 deletions docs/tutorials/logger.rst
@@ -34,8 +34,12 @@ WandbLogger
::

from tianshou.utils import WandbLogger
from torch.utils.tensorboard import SummaryWriter

logger = WandbLogger(...)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger.load(writer)
result = trainer(..., logger=logger)

Please refer to :class:`~tianshou.utils.WandbLogger` documentation for advanced configuration.
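
The Atari examples updated in this PR choose between TensorboardLogger and WandbLogger with a command-line flag. A condensed sketch of that pattern (placeholders as in the snippet above), assuming an args namespace carrying the logger, resume_id, and wandb_project fields added below, with log_name and log_path built as in the scripts:

    import os

    from torch.utils.tensorboard import SummaryWriter

    from tianshou.utils import TensorboardLogger, WandbLogger

    if args.logger == "wandb":
        # Create the wandb run first; the SummaryWriter is attached afterwards.
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:  # wandb
        logger.load(writer)
    result = trainer(..., logger=logger)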
100 changes: 63 additions & 37 deletions examples/atari/atari_c51.py
@@ -1,4 +1,5 @@
import argparse
import datetime
import os
import pprint

@@ -11,46 +12,54 @@
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import C51Policy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils import TensorboardLogger, WandbLogger


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--scale-obs', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--num-atoms', type=int, default=51)
parser.add_argument('--v-min', type=float, default=-10.)
parser.add_argument('--v-max', type=float, default=10.)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--scale-obs", type=int, default=0)
parser.add_argument("--eps-test", type=float, default=0.005)
parser.add_argument("--eps-train", type=float, default=1.)
parser.add_argument("--eps-train-final", type=float, default=0.05)
parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--num-atoms", type=int, default=51)
parser.add_argument("--v-min", type=float, default=-10.)
parser.add_argument("--v-max", type=float, default=10.)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--training-num", type=int, default=10)
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument(
'--watch',
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
parser.add_argument(
"--watch",
default=False,
action='store_true',
help='watch the play of pre-trained policy only'
action="store_true",
help="watch the play of pre-trained policy only"
)
parser.add_argument('--save-buffer-name', type=str, default=None)
parser.add_argument("--save-buffer-name", type=str, default=None)
return parser.parse_args()


@@ -101,19 +110,36 @@ def test_c51(args=get_args()):
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)

# log
log_path = os.path.join(args.logdir, args.task, 'c51')
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
args.algo_name = "c51"
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
log_path = os.path.join(args.logdir, log_name)

# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards):
if env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
elif "Pong" in args.task:
return mean_rewards >= 20
else:
return False
@@ -159,7 +185,7 @@ def watch():
n_episode=args.test_num, render=args.render
)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")

if args.watch:
watch()
@@ -190,5 +216,5 @@ def watch():
watch()


if __name__ == '__main__':
if __name__ == "__main__":
test_c51(get_args())
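
Both scripts now derive the log directory from the task, algorithm name, seed, and a timestamp, and reuse that relative path (with separators replaced by "__") as the wandb run name. A small runnable illustration with hypothetical argument values:

    import datetime
    import os

    # Hypothetical values standing in for the parsed arguments.
    task, algo_name, seed, logdir = "PongNoFrameskip-v4", "c51", 0, "log"

    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    log_name = os.path.join(task, algo_name, str(seed), now)
    log_path = os.path.join(logdir, log_name)

    # e.g. log/PongNoFrameskip-v4/c51/0/220306-120000 on disk, and
    # PongNoFrameskip-v4__c51__0__220306-120000 as the wandb run name.
    print(log_path)
    print(log_name.replace(os.path.sep, "__"))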
109 changes: 59 additions & 50 deletions examples/atari/atari_dqn.py
@@ -1,4 +1,5 @@
import argparse
import datetime
import os
import pprint

@@ -18,62 +19,63 @@

def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--scale-obs', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--scale-obs", type=int, default=0)
parser.add_argument("--eps-test", type=float, default=0.005)
parser.add_argument("--eps-train", type=float, default=1.)
parser.add_argument("--eps-train-final", type=float, default=0.05)
parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--step-per-epoch", type=int, default=100000)
parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--training-num", type=int, default=10)
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--resume-id', type=str, default=None)
parser.add_argument("--frames-stack", type=int, default=4)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument(
'--logger',
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="atari.benchmark")
parser.add_argument(
'--watch',
"--watch",
default=False,
action='store_true',
help='watch the play of pre-trained policy only'
action="store_true",
help="watch the play of pre-trained policy only"
)
parser.add_argument('--save-buffer-name', type=str, default=None)
parser.add_argument("--save-buffer-name", type=str, default=None)
parser.add_argument(
'--icm-lr-scale',
"--icm-lr-scale",
type=float,
default=0.,
help='use intrinsic curiosity module with this lr scale'
help="use intrinsic curiosity module with this lr scale"
)
parser.add_argument(
'--icm-reward-scale',
"--icm-reward-scale",
type=float,
default=0.01,
help='scaling factor for intrinsic curiosity reward'
help="scaling factor for intrinsic curiosity reward"
)
parser.add_argument(
'--icm-forward-loss-weight',
"--icm-forward-loss-weight",
type=float,
default=0.2,
help='weight for the forward model loss in ICM'
help="weight for the forward model loss in ICM"
)
return parser.parse_args()

@@ -140,29 +142,36 @@ def test_dqn(args=get_args()):
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)

# log
log_name = 'dqn_icm' if args.icm_lr_scale > 0 else 'dqn'
log_path = os.path.join(args.logdir, args.task, log_name)
if args.logger == "tensorboard":
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
else:
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
args.algo_name = "dqn_icm" if args.icm_lr_scale > 0 else "dqn"
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
log_path = os.path.join(args.logdir, log_name)

# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
project=args.task,
name=log_name,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards):
if env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
elif "Pong" in args.task:
return mean_rewards >= 20
else:
return False
@@ -183,8 +192,8 @@ def test_fn(epoch, env_step):

def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
torch.save({'model': policy.state_dict()}, ckpt_path)
ckpt_path = os.path.join(log_path, "checkpoint.pth")
torch.save({"model": policy.state_dict()}, ckpt_path)
return ckpt_path

# watch agent's performance
@@ -214,7 +223,7 @@ def watch():
n_episode=args.test_num, render=args.render
)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")

if args.watch:
watch()
@@ -247,5 +256,5 @@ def watch():
watch()


if __name__ == '__main__':
if __name__ == "__main__":
test_dqn(get_args())
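
The DQN example writes a checkpoint each epoch via save_checkpoint_fn and exposes --resume-id so an interrupted run can continue logging to the same wandb run. The following is a minimal sketch of the wandb side of that resume flow, using standard wandb behaviour and a hypothetical run id and checkpoint path rather than Tianshou-specific code:

    import os

    import torch
    import wandb

    # Hypothetical id of an earlier run (what --resume-id would carry).
    previous_run_id = "abc123xy"

    # Re-initialising with the same id continues the existing run, so new
    # points extend the old curves instead of starting a fresh run.
    run = wandb.init(project="atari.benchmark", id=previous_run_id, resume="allow")

    # Model weights are restored separately from the checkpoint written by
    # save_checkpoint_fn, which stores {"model": policy.state_dict()}.
    ckpt_path = os.path.join("log", "checkpoint.pth")
    if os.path.exists(ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        # policy.load_state_dict(checkpoint["model"])  # assuming a policy object

    run.finish()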