
Trainer refactor : flexible logger #295


Merged (14 commits) on Feb 24, 2021
3 changes: 2 additions & 1 deletion README.md
@@ -197,6 +197,7 @@ buffer_size = 20000
eps_train, eps_test = 0.1, 0.05
step_per_epoch, step_per_collect = 10000, 10
writer = SummaryWriter('log/dqn') # tensorboard is also supported!
logger = ts.utils.BasicLogger(writer)
```

Make environments:
@@ -237,7 +238,7 @@ result = ts.trainer.offpolicy_trainer(
train_fn=lambda epoch, env_step: policy.set_eps(eps_train),
test_fn=lambda epoch, env_step: policy.set_eps(eps_test),
stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
writer=writer)
logger=logger)
print(f'Finished training! Use {result["duration"]}')
```
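The change above replaces the trainer's `writer` argument with a `logger` object that wraps the `SummaryWriter`. A minimal sketch of the new wiring, assuming the `write(key, step, value)` helper that the example scripts in this PR call on `BasicLogger`:

```python
# Sketch only: exercises the logger wiring shown in the README change above.
from torch.utils.tensorboard import SummaryWriter
import tianshou as ts

writer = SummaryWriter('log/dqn')      # plain TensorBoard writer, unchanged
logger = ts.utils.BasicLogger(writer)  # new: wrap the writer in a logger

# The object passed to the trainer as `logger=logger` can also record
# ad-hoc scalars directly (call shape as used in examples/atari/atari_dqn.py):
logger.write('train/eps', 0, 0.1)      # tag, global step, value
```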

1 change: 1 addition & 0 deletions docs/contributor.rst
@@ -7,3 +7,4 @@ We always welcome contributions to help make Tianshou better. Below are an incom
* Minghao Zhang (`Mehooz <https://github.com/Mehooz>`_)
* Alexis Duburcq (`duburcqa <https://github.com/duburcqa>`_)
* Kaichao You (`youkaichao <https://github.com/youkaichao>`_)
* Huayu Chen (`ChenDRAG <https://github.com/ChenDRAG>`_)
8 changes: 5 additions & 3 deletions docs/tutorials/dqn.rst
@@ -130,7 +130,7 @@ Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.t
train_fn=lambda epoch, env_step: policy.set_eps(0.1),
test_fn=lambda epoch, env_step: policy.set_eps(0.05),
stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
writer=None)
logger=None)
print(f'Finished training! Use {result["duration"]}')

The meaning of each parameter is as follows (full description can be found at :func:`~tianshou.trainer.offpolicy_trainer`):
@@ -143,15 +143,17 @@ The meaning of each parameter is as follows (full description can be found at :f
* ``train_fn``: A function that receives the current epoch number and step index and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training".
* ``test_fn``: A function that receives the current epoch number and step index and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing".
* ``stop_fn``: A function that receives the average undiscounted return of the test results and returns a boolean indicating whether the goal has been reached.
* ``writer``: See below.
* ``logger``: See below.

The trainer supports `TensorBoard <https://www.tensorflow.org/tensorboard>`_ for logging. It can be used as:
::

from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
writer = SummaryWriter('log/dqn')
logger = BasicLogger(writer)

Pass the writer into the trainer, and the training result will be recorded into TensorBoard.
Pass the logger into the trainer, and the training result will be recorded into TensorBoard.

The returned result is a dictionary as follows:
::
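To see how the tutorial's callbacks and the new logger fit together, here is a compact sketch of the full trainer call. It mirrors the call shape used by the example scripts in this PR; the numeric hyperparameters are placeholders, and `policy`, `train_collector`, `test_collector`, and `env` are assumed to be built as in the earlier tutorial steps.

```python
# Sketch: the tutorial's offpolicy_trainer call with the new logger argument.
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.trainer import offpolicy_trainer

logger = BasicLogger(SummaryWriter('log/dqn'))

result = offpolicy_trainer(
    policy, train_collector, test_collector,
    10, 10000, 10, 100, 64,  # epoch, step_per_epoch, step_per_collect,
                             # test_num, batch_size (placeholder values)
    train_fn=lambda epoch, env_step: policy.set_eps(0.1),
    test_fn=lambda epoch, env_step: policy.set_eps(0.05),
    stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
    logger=logger)  # formerly writer=writer
print(f'Finished training! Use {result["duration"]}')
```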
12 changes: 6 additions & 6 deletions docs/tutorials/tictactoe.rst
@@ -176,6 +176,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modul
import numpy as np
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger

from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
@@ -319,11 +320,10 @@ With the above preparation, we are close to the first learned agent. The followi
train_collector.collect(n_step=args.batch_size * args.training_num)

# ======== tensorboard logging setup =========
if not hasattr(args, 'writer'):
log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
writer = SummaryWriter(log_path)
else:
writer = args.writer
log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)

# ======== callback functions used during training =========

@@ -359,7 +359,7 @@ With the above preparation, we are close to the first learned agent. The followi
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step,
writer=writer, test_in_train=False, reward_metric=reward_metric)
logger=logger, test_in_train=False, reward_metric=reward_metric)

agent = policy.policies[args.agent_id - 1]
# let's watch the match!
13 changes: 10 additions & 3 deletions examples/atari/atari_bcq.py
@@ -2,10 +2,12 @@
import torch
import pickle
import pprint
import datetime
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offline_trainer
from tianshou.utils.net.discrete import Actor
@@ -39,7 +41,7 @@ def get_args():
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--watch", default=False, action="store_true",
help="watch the play of pre-trained policy only")
parser.add_argument("--log-interval", type=int, default=1000)
parser.add_argument("--log-interval", type=int, default=100)
parser.add_argument(
"--load-buffer-name", type=str,
default="./expert_DQN_PongNoFrameskip-v4.hdf5",
@@ -113,8 +115,13 @@ def test_discrete_bcq(args=get_args()):
# collector
test_collector = Collector(policy, test_envs, exploration_noise=True)

log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
# log
log_path = os.path.join(
args.logdir, args.task, 'bcq',
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer, update_interval=args.log_interval)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -141,7 +148,7 @@ def watch():
result = offline_trainer(
policy, buffer, test_collector,
args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
stop_fn=stop_fn, save_fn=save_fn, writer=writer,
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
log_interval=args.log_interval,
)
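The atari_bcq.py change above also timestamps the log directory per seed and throttles TensorBoard output through `update_interval`. A self-contained sketch of that logging setup, with literal values standing in for the script's `args` fields:

```python
# Sketch of the logging setup introduced above (examples/atari/atari_bcq.py),
# with literal placeholders for args.logdir, args.task, args.seed, args.log_interval.
import os
import datetime
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger

seed = 0
# one run directory per seed and launch time, e.g. log/PongNoFrameskip-v4/bcq/seed_0_0224-1530
log_path = os.path.join(
    'log', 'PongNoFrameskip-v4', 'bcq',
    f'seed_{seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
writer = SummaryWriter(log_path)
writer.add_text("args", "<hyperparameter namespace as text>")  # keep the config with the run
# update_interval (the script's --log-interval, default 100) presumably controls
# how often training/update statistics are flushed to the writer.
logger = BasicLogger(writer, update_interval=100)
```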

7 changes: 5 additions & 2 deletions examples/atari/atari_c51.py
@@ -6,6 +6,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import C51Policy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
@@ -98,6 +99,8 @@ def test_c51(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'c51')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -118,7 +121,7 @@ def train_fn(epoch, env_step):
else:
eps = args.eps_train_final
policy.set_eps(eps)
writer.add_scalar('train/eps', eps, global_step=env_step)
logger.write('train/eps', env_step, eps)

def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
@@ -144,7 +147,7 @@ def watch():
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, writer=writer,
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step, test_in_train=False)

pprint.pprint(result)
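In the C51/DQN/QRDQN examples, the per-step epsilon schedule now reports through the logger instead of calling `writer.add_scalar` directly. A sketch of that `train_fn`, written as a factory so it stands alone; the linear decay over the first 1e6 steps is an assumption, since the schedule itself falls outside the visible hunk:

```python
# Sketch of the train_fn pattern from examples/atari/atari_c51.py after this PR.
# The linear-decay schedule is assumed; only the final-eps branch, policy.set_eps,
# and logger.write appear in the hunk above.
def make_train_fn(policy, logger, eps_train=1.0, eps_train_final=0.05,
                  decay_steps=1_000_000):
    def train_fn(epoch, env_step):
        if env_step <= decay_steps:
            # anneal epsilon linearly from eps_train down to eps_train_final
            eps = eps_train - env_step / decay_steps * (eps_train - eps_train_final)
        else:
            eps = eps_train_final
        policy.set_eps(eps)
        # custom scalars go through the logger now, not writer.add_scalar(...)
        logger.write('train/eps', env_step, eps)
    return train_fn
```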
7 changes: 5 additions & 2 deletions examples/atari/atari_dqn.py
@@ -6,6 +6,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import DQNPolicy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
@@ -94,6 +95,8 @@ def test_dqn(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'dqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -114,7 +117,7 @@ def train_fn(epoch, env_step):
else:
eps = args.eps_train_final
policy.set_eps(eps)
writer.add_scalar('train/eps', eps, global_step=env_step)
logger.write('train/eps', env_step, eps)

def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
@@ -154,7 +157,7 @@ def watch():
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, writer=writer,
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step, test_in_train=False)

pprint.pprint(result)
7 changes: 5 additions & 2 deletions examples/atari/atari_qrdqn.py
@@ -5,6 +5,7 @@
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import BasicLogger
from tianshou.policy import QRDQNPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
@@ -96,6 +97,8 @@ def test_qrdqn(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'qrdqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -116,7 +119,7 @@ def train_fn(epoch, env_step):
else:
eps = args.eps_train_final
policy.set_eps(eps)
writer.add_scalar('train/eps', eps, global_step=env_step)
logger.write('train/eps', env_step, eps)

def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
@@ -142,7 +145,7 @@ def watch():
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, writer=writer,
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step, test_in_train=False)

pprint.pprint(result)
8 changes: 6 additions & 2 deletions examples/atari/runnable/pong_a2c.py
@@ -4,7 +4,9 @@
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import A2CPolicy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
@@ -79,7 +81,9 @@ def test_a2c(args=get_args()):
preprocess_fn=preprocess_fn, exploration_noise=True)
test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
# log
writer = SummaryWriter(os.path.join(args.logdir, args.task, 'a2c'))
log_path = os.path.join(args.logdir, args.task, 'a2c')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)

def stop_fn(mean_rewards):
if env.env.spec.reward_threshold:
@@ -91,7 +95,7 @@ def stop_fn(mean_rewards):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer)
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, logger=logger)
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
10 changes: 7 additions & 3 deletions examples/atari/runnable/pong_ppo.py
@@ -6,11 +6,12 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import PPOPolicy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils.net.discrete import Actor, Critic
from tianshou.data import Collector, VectorReplayBuffer

from atari import create_atari_environment, preprocess_fn

@@ -84,7 +85,9 @@ def test_ppo(args=get_args()):
preprocess_fn=preprocess_fn, exploration_noise=True)
test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
# log
writer = SummaryWriter(os.path.join(args.logdir, args.task, 'ppo'))
log_path = os.path.join(args.logdir, args.task, 'ppo')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)

def stop_fn(mean_rewards):
if env.env.spec.reward_threshold:
@@ -96,7 +99,8 @@ def stop_fn(mean_rewards):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer)
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, logger=logger)

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
4 changes: 3 additions & 1 deletion examples/box2d/acrobot_dualdqn.py
@@ -7,6 +7,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import DQNPolicy
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@@ -81,6 +82,7 @@ def test_dqn(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'dqn')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -106,7 +108,7 @@ def test_fn(epoch, env_step):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
update_per_step=args.update_per_step, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
stop_fn=stop_fn, save_fn=save_fn, logger=logger)

assert stop_fn(result['best_reward'])
if __name__ == '__main__':
4 changes: 3 additions & 1 deletion examples/box2d/bipedal_hardcore_sac.py
@@ -7,6 +7,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import SACPolicy
from tianshou.utils import BasicLogger
from tianshou.utils.net.common import Net
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
@@ -134,6 +135,7 @@ def test_sac_bipedal(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'sac')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -146,7 +148,7 @@ def stop_fn(mean_rewards):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
update_per_step=args.update_per_step, test_in_train=False,
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
stop_fn=stop_fn, save_fn=save_fn, logger=logger)

if __name__ == '__main__':
pprint.pprint(result)
4 changes: 3 additions & 1 deletion examples/box2d/lunarlander_dqn.py
@@ -7,6 +7,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import DQNPolicy
from tianshou.utils import BasicLogger
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
@@ -83,6 +84,7 @@ def test_dqn(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'dqn')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -102,7 +104,7 @@ def test_fn(epoch, env_step):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
update_per_step=args.update_per_step, stop_fn=stop_fn, train_fn=train_fn,
test_fn=test_fn, save_fn=save_fn, writer=writer)
test_fn=test_fn, save_fn=save_fn, logger=logger)

assert stop_fn(result['best_reward'])
if __name__ == '__main__':