diff --git a/README.md b/README.md index c50321f98..80ee2ff12 100644 --- a/README.md +++ b/README.md @@ -191,11 +191,11 @@ Define some hyper-parameters: ```python task = 'CartPole-v0' lr, epoch, batch_size = 1e-3, 10, 64 -train_num, test_num = 8, 100 +train_num, test_num = 10, 100 gamma, n_step, target_freq = 0.9, 3, 320 buffer_size = 20000 eps_train, eps_test = 0.1, 0.05 -step_per_epoch, collect_per_step = 1000, 8 +step_per_epoch, step_per_collect = 10000, 10 writer = SummaryWriter('log/dqn') # tensorboard is also supported! ``` @@ -232,8 +232,8 @@ Let's train it: ```python result = ts.trainer.offpolicy_trainer( - policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step, - test_num, batch_size, + policy, train_collector, test_collector, epoch, step_per_epoch, step_per_collect, + test_num, batch_size, update_per_step=1 / step_per_collect, train_fn=lambda epoch, env_step: policy.set_eps(eps_train), test_fn=lambda epoch, env_step: policy.set_eps(eps_test), stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index b3e126352..26ee8d285 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -284,7 +284,7 @@ policy.process_fn The ``process_fn`` function computes some variables that depends on time-series. For example, compute the N-step or GAE returns. -Take 2-step return DQN as an example. The 2-step return DQN compute each frame's return as: +Take 2-step return DQN as an example. The 2-step return DQN compute each transition's return as: .. math:: diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index 361f79f3c..40e4a399f 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -35,10 +35,10 @@ If you want to use the original ``gym.Env``: Tianshou supports parallel sampling for all algorithms. It provides four types of vectorized environment wrapper: :class:`~tianshou.env.DummyVectorEnv`, :class:`~tianshou.env.SubprocVectorEnv`, :class:`~tianshou.env.ShmemVectorEnv`, and :class:`~tianshou.env.RayVectorEnv`. It can be used as follows: (more explanation can be found at :ref:`parallel_sampling`) :: - train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(8)]) + train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)]) test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(100)]) -Here, we set up 8 environments in ``train_envs`` and 100 environments in ``test_envs``. +Here, we set up 10 environments in ``train_envs`` and 100 environments in ``test_envs``. For the demonstration, here we use the second code-block. @@ -87,7 +87,7 @@ Tianshou supports any user-defined PyTorch networks and optimizers. Yet, of cour net = Net(state_shape, action_shape) optim = torch.optim.Adam(net.parameters(), lr=1e-3) -It is also possible to use pre-defined MLP networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are: +You can also use pre-defined MLP networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are: 1. Input: observation ``obs`` (may be a ``numpy.ndarray``, ``torch.Tensor``, dict, or self-defined class), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment. 2. 
Output: some ``logits``, the next hidden state ``state``. The logits could be a tuple instead of a ``torch.Tensor``, or some other useful variables or results during the policy forwarding procedure. It depends on how the policy class process the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for Gaussian policy. @@ -113,7 +113,7 @@ The collector is a key concept in Tianshou. It allows the policy to interact wit In each step, the collector will let the policy perform (at least) a specified number of steps or episodes and store the data in a replay buffer. :: - train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(20000, 8), exploration_noise=True) + train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(20000, 10), exploration_noise=True) test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True) @@ -125,8 +125,8 @@ Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.t result = ts.trainer.offpolicy_trainer( policy, train_collector, test_collector, - max_epoch=10, step_per_epoch=1000, collect_per_step=10, - episode_per_test=100, batch_size=64, + max_epoch=10, step_per_epoch=10000, step_per_collect=10, + update_per_step=0.1, episode_per_test=100, batch_size=64, train_fn=lambda epoch, env_step: policy.set_eps(0.1), test_fn=lambda epoch, env_step: policy.set_eps(0.05), stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, @@ -136,8 +136,8 @@ Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.t The meaning of each parameter is as follows (full description can be found at :func:`~tianshou.trainer.offpolicy_trainer`): * ``max_epoch``: The maximum of epochs for training. The training process might be finished before reaching the ``max_epoch``; -* ``step_per_epoch``: The number of step for updating policy network in one epoch; -* ``collect_per_step``: The number of frames the collector would collect before the network update. For example, the code above means "collect 10 frames and do one policy network update"; +* ``step_per_epoch``: The number of environment step (a.k.a. transition) collected per epoch; +* ``step_per_collect``: The number of transition the collector would collect before the network update. For example, the code above means "collect 10 transitions and do one policy network update"; * ``episode_per_test``: The number of episodes for one policy evaluation. * ``batch_size``: The batch size of sample data, which is going to feed in the policy network. * ``train_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training". @@ -205,7 +205,7 @@ Train a Policy with Customized Codes Tianshou supports user-defined training code. 
Here is the code snippet: :: - # pre-collect at least 5000 frames with random action before training + # pre-collect at least 5000 transitions with random action before training train_collector.collect(n_step=5000, random=True) policy.set_eps(0.1) diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst index c656c1ee2..3d7f28106 100644 --- a/docs/tutorials/tictactoe.rst +++ b/docs/tutorials/tictactoe.rst @@ -200,8 +200,9 @@ The explanation of each Tianshou class/function will be deferred to their first parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=500) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=5000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) @@ -293,7 +294,7 @@ With the above preparation, we are close to the first learned agent. The followi policy.policies[args.agent_id - 1].set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward: {result["rews"][:, args.agent_id - 1].mean()}, length: {result["lens"].mean()}') if args.watch: watch(args) exit(0) @@ -355,10 +356,10 @@ With the above preparation, we are close to the first learned agent. The followi # start training, this may require about three minutes result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, reward_metric=reward_metric, - writer=writer, test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step, + writer=writer, test_in_train=False, reward_metric=reward_metric) agent = policy.policies[args.agent_id - 1] # let's watch the match! 
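The arithmetic behind the renamed trainer arguments is easy to check by hand. The sketch below is plain Python using the hyper-parameter values from the updated README snippet above; the variable names are illustrative bookkeeping, not part of the Tianshou API. It shows how ``step_per_epoch``, ``step_per_collect``, and ``update_per_step`` combine into collect rounds and gradient steps per epoch:

```python
# Illustrative back-of-the-envelope sketch, not code from the repository.
epoch, step_per_epoch = 10, 10000        # step_per_epoch now counts environment steps
step_per_collect = 10                    # transitions gathered per collect round
update_per_step = 1 / step_per_collect   # gradient steps per environment step

collects_per_epoch = step_per_epoch // step_per_collect              # nominally 1000 collect rounds
gradient_steps_per_epoch = round(step_per_epoch * update_per_step)   # nominally 1000 network updates
env_steps_total = epoch * step_per_epoch                             # 100000 transitions over training

print(collects_per_epoch, gradient_steps_per_epoch, env_steps_total)
```

With ``update_per_step = 1 / step_per_collect`` this reproduces the old behaviour of one network update per collect round, while other values trade off collection against gradient steps.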
diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py index e2b8f0778..72b8d1336 100644 --- a/examples/atari/atari_bcq.py +++ b/examples/atari/atari_bcq.py @@ -28,7 +28,7 @@ def get_args(): parser.add_argument("--unlikely-action-threshold", type=float, default=0.3) parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=10000) + parser.add_argument("--update-per-epoch", type=int, default=10000) parser.add_argument("--batch-size", type=int, default=32) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512]) @@ -140,7 +140,7 @@ def watch(): result = offline_trainer( policy, buffer, test_collector, - args.epoch, args.step_per_epoch, args.test_num, args.batch_size, + args.epoch, args.update_per_epoch, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer, log_interval=args.log_interval, ) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 42cffa9be..d0a7ab81d 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -30,8 +30,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=10000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=100000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=32) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=10) @@ -141,9 +142,10 @@ def watch(): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, writer=writer, + update_per_step=args.update_per_step, test_in_train=False) pprint.pprint(result) watch() diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 559b0878e..077d24891 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -27,8 +27,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=10000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=100000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=32) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=10) @@ -151,9 +152,10 @@ def watch(): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, 
train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, writer=writer, + update_per_step=args.update_per_step, test_in_train=False) pprint.pprint(result) watch() diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index ed356381f..e2eed3cfd 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -28,8 +28,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=10000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=100000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=32) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=10) @@ -139,9 +140,10 @@ def watch(): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, writer=writer, + update_per_step=args.update_per_step, test_in_train=False) pprint.pprint(result) watch() diff --git a/examples/atari/runnable/pong_a2c.py b/examples/atari/runnable/pong_a2c.py index 0b81cecd6..5b760c14e 100644 --- a/examples/atari/runnable/pong_a2c.py +++ b/examples/atari/runnable/pong_a2c.py @@ -4,7 +4,6 @@ import argparse import numpy as np from torch.utils.tensorboard import SummaryWriter - from tianshou.policy import A2CPolicy from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net @@ -24,7 +23,7 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.9) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--episode-per-collect', type=int, default=10) parser.add_argument('--repeat-per-collect', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, @@ -91,8 +90,8 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) + args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer) if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! 
diff --git a/examples/atari/runnable/pong_ppo.py b/examples/atari/runnable/pong_ppo.py index 8ed04c21e..8a2a6845d 100644 --- a/examples/atari/runnable/pong_ppo.py +++ b/examples/atari/runnable/pong_ppo.py @@ -24,7 +24,7 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--episode-per-collect', type=int, default=10) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, @@ -95,8 +95,8 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) + args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer) if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 444c357f8..69d0bfbee 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -25,8 +25,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=100) + parser.add_argument('--step-per-epoch', type=int, default=100000) + parser.add_argument('--step-per-collect', type=int, default=100) + parser.add_argument('--update-per-step', type=float, default=0.01) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128]) parser.add_argument('--dueling-q-hidden-sizes', type=int, @@ -103,8 +104,8 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 0bf802d3b..a903008ab 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -27,8 +27,9 @@ def get_args(): parser.add_argument('--auto-alpha', type=int, default=1) parser.add_argument('--alpha-lr', type=float, default=3e-4) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=10000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=100000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 
128]) @@ -143,9 +144,9 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer, - test_in_train=False) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, test_in_train=False, + stop_fn=stop_fn, save_fn=save_fn, writer=writer) if __name__ == '__main__': pprint.pprint(result) diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index de88aa315..3d5d033f2 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -26,8 +26,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=4) parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=5000) - parser.add_argument('--collect-per-step', type=int, default=16) + parser.add_argument('--step-per-epoch', type=int, default=80000) + parser.add_argument('--step-per-collect', type=int, default=16) + parser.add_argument('--update-per-step', type=float, default=0.0625) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -99,10 +100,9 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, - test_in_train=False) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, stop_fn=stop_fn, train_fn=train_fn, + test_fn=test_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 14e5095e7..333dab49c 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -29,8 +29,9 @@ def get_args(): parser.add_argument('--auto_alpha', type=int, default=1) parser.add_argument('--alpha', type=float, default=0.2) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=12000) + parser.add_argument('--step-per-collect', type=int, default=5) + parser.add_argument('--update-per-step', type=float, default=0.2) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -112,8 +113,10 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, stop_fn=stop_fn, + save_fn=save_fn, writer=writer) + assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index ba9cd79a3..22dda7d9b 100644 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py 
@@ -28,8 +28,9 @@ def get_args(): parser.add_argument('--alpha-lr', type=float, default=3e-4) parser.add_argument('--n-step', type=int, default=2) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=10000) - parser.add_argument('--collect-per-step', type=int, default=4) + parser.add_argument('--step-per-epoch', type=int, default=40000) + parser.add_argument('--step-per-collect', type=int, default=4) + parser.add_argument('--update-per-step', type=float, default=0.25) parser.add_argument('--update-per-step', type=int, default=1) parser.add_argument('--pre-collect-step', type=int, default=10000) parser.add_argument('--batch-size', type=int, default=256) @@ -139,10 +140,9 @@ def stop_fn(mean_rewards): train_collector.collect(n_step=args.pre_collect_step, random=True) result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, args.update_per_step, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, - log_interval=args.log_interval) + args.step_per_epoch, args.step_per_collect, args.test_num, + args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer, + update_per_step=args.update_per_step, log_interval=args.log_interval) pprint.pprint(result) watch() diff --git a/examples/mujoco/runnable/ant_v2_ddpg.py b/examples/mujoco/runnable/ant_v2_ddpg.py index b9a6e0118..a83abc37b 100644 --- a/examples/mujoco/runnable/ant_v2_ddpg.py +++ b/examples/mujoco/runnable/ant_v2_ddpg.py @@ -26,7 +26,7 @@ def get_args(): parser.add_argument('--exploration-noise', type=float, default=0.1) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=4) + parser.add_argument('--step-per-collect', type=int, default=4) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -87,7 +87,7 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/examples/mujoco/runnable/ant_v2_td3.py b/examples/mujoco/runnable/ant_v2_td3.py index 2f8370217..1bda7aa16 100644 --- a/examples/mujoco/runnable/ant_v2_td3.py +++ b/examples/mujoco/runnable/ant_v2_td3.py @@ -29,7 +29,7 @@ def get_args(): parser.add_argument('--update-actor-freq', type=int, default=2) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -96,7 +96,7 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git 
a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py index 6f34ce0ad..d64155dbe 100644 --- a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py +++ b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py @@ -28,7 +28,7 @@ def get_args(): parser.add_argument('--alpha', type=float, default=0.2) parser.add_argument('--epoch', type=int, default=200) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -97,7 +97,7 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, log_interval=args.log_interval) assert stop_fn(result['best_reward']) diff --git a/examples/mujoco/runnable/point_maze_td3.py b/examples/mujoco/runnable/point_maze_td3.py index 76271f40f..f23e1b057 100644 --- a/examples/mujoco/runnable/point_maze_td3.py +++ b/examples/mujoco/runnable/point_maze_td3.py @@ -31,7 +31,7 @@ def get_args(): parser.add_argument('--update-actor-freq', type=int, default=2) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -104,7 +104,7 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 68d3fc433..afe88e16b 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -26,8 +26,9 @@ def get_args(): parser.add_argument('--tau', type=float, default=0.005) parser.add_argument('--exploration-noise', type=float, default=0.1) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=4) + parser.add_argument('--step-per-epoch', type=int, default=9600) + parser.add_argument('--step-per-collect', type=int, default=4) + parser.add_argument('--update-per-step', type=float, default=0.25) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -102,8 +103,9 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, stop_fn=stop_fn, + save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ 
== '__main__': pprint.pprint(result) diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 45a59f425..4f8ede1a0 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -23,8 +23,8 @@ def get_args(): parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=16) + parser.add_argument('--step-per-epoch', type=int, default=150000) + parser.add_argument('--episode-per-collect', type=int, default=16) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, @@ -121,8 +121,8 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, + args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 6e075bb5c..0a96dbfa9 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -26,8 +26,10 @@ def get_args(): parser.add_argument('--tau', type=float, default=0.005) parser.add_argument('--alpha', type=float, default=0.2) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=24000) + parser.add_argument('--il-step-per-epoch', type=int, default=500) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -110,8 +112,9 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, stop_fn=stop_fn, + save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) @@ -142,7 +145,7 @@ def stop_fn(mean_rewards): train_collector.reset() result = offpolicy_trainer( il_policy, train_collector, il_test_collector, args.epoch, - args.step_per_epoch // 5, args.collect_per_step, args.test_num, + args.il_step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index c90a92aa4..bbc32d912 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -29,8 +29,9 @@ def get_args(): parser.add_argument('--noise-clip', type=float, default=0.5) 
parser.add_argument('--update-actor-freq', type=int, default=2) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=20000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -115,8 +116,9 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, stop_fn=stop_fn, + save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index ea7e6b6ad..1032b3176 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -23,8 +23,11 @@ def get_args(): parser.add_argument('--il-lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.9) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=8) + parser.add_argument('--step-per-epoch', type=int, default=50000) + parser.add_argument('--il-step-per-epoch', type=int, default=1000) + parser.add_argument('--episode-per-collect', type=int, default=8) + parser.add_argument('--step-per-collect', type=int, default=8) + parser.add_argument('--update-per-step', type=float, default=0.125) parser.add_argument('--repeat-per-collect', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, @@ -96,8 +99,8 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, + args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': @@ -121,13 +124,12 @@ def stop_fn(mean_rewards): il_policy = ImitationPolicy(net, optim, mode='discrete') il_test_collector = Collector( il_policy, - DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) ) train_collector.reset() result = offpolicy_trainer( il_policy, train_collector, il_test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.il_step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 684ce9696..fe0573e1d 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -28,8 +28,9 @@ def get_args(): 
parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=8) + parser.add_argument('--step-per-epoch', type=int, default=8000) + parser.add_argument('--step-per-collect', type=int, default=8) + parser.add_argument('--update-per-step', type=float, default=0.125) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) @@ -112,7 +113,7 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index df02684c0..e9104c152 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -25,9 +25,10 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.9) parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) - parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--epoch', type=int, default=20) + parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) @@ -114,9 +115,9 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, + args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, + test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 420f8e6cd..dc2e06c00 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -26,8 +26,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=4) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--update-per-step', type=float, default=0.1) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--layer-num', type=int, default=3) parser.add_argument('--training-num', type=int, default=10) @@ -92,9 +93,10 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, 
args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, + args.batch_size, update_per_step=args.update_per_step, + train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, + save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py index 3dd2b8d63..09c4c52d2 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/discrete/test_il_bcq.py @@ -26,7 +26,7 @@ def get_args(): parser.add_argument("--unlikely-action-threshold", type=float, default=0.3) parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=1000) + parser.add_argument("--update-per-epoch", type=int, default=1000) parser.add_argument("--batch-size", type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128]) @@ -91,7 +91,7 @@ def stop_fn(mean_rewards): result = offline_trainer( policy, buffer, test_collector, - args.epoch, args.step_per_epoch, args.test_num, args.batch_size, + args.epoch, args.update_per_epoch, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index a413111b7..784ae70db 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -21,8 +21,8 @@ def get_args(): parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.95) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=8) + parser.add_argument('--step-per-epoch', type=int, default=40000) + parser.add_argument('--episode-per-collect', type=int, default=8) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, @@ -82,8 +82,8 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, + args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 9862ea7f7..35634c675 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -22,8 +22,8 @@ def get_args(): parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=2000) - parser.add_argument('--collect-per-step', type=int, default=20) + parser.add_argument('--step-per-epoch', type=int, default=50000) + parser.add_argument('--episode-per-collect', type=int, default=20) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=64) 
parser.add_argument('--hidden-sizes', type=int, @@ -108,8 +108,8 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, + args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 006dd827b..e5ce61b98 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -26,8 +26,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) @@ -110,9 +111,10 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer) + stop_fn=stop_fn, save_fn=save_fn, writer=writer, + update_per_step=args.update_per_step) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 16ab54cb2..ebcb75157 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -27,8 +27,9 @@ def get_args(): parser.add_argument('--alpha', type=float, default=0.05) parser.add_argument('--auto_alpha', type=int, default=0) parser.add_argument('--epoch', type=int, default=5) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=5000) + parser.add_argument('--step-per-collect', type=int, default=5) + parser.add_argument('--update-per-step', type=float, default=0.2) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -108,9 +109,9 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer, - test_in_train=False) + update_per_step=args.update_per_step, test_in_train=False) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 5a813874f..01ea98a58 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -17,8 +17,8 @@ def get_args(): parser.add_argument('--seed', type=int, default=1626) 
parser.add_argument('--buffer-size', type=int, default=50000) parser.add_argument('--epoch', type=int, default=5) - parser.add_argument('--step-per-epoch', type=int, default=5) - parser.add_argument('--collect-per-step', type=int, default=1) + parser.add_argument('--step-per-epoch', type=int, default=1000) + parser.add_argument('--episode-per-collect', type=int, default=1) parser.add_argument('--training-num', type=int, default=1) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') @@ -78,8 +78,8 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, 1, - args.test_num, 0, stop_fn=stop_fn, writer=writer, + args.step_per_epoch, 1, args.test_num, 0, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer, test_in_train=False) if __name__ == '__main__': diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index b081a5055..edf066e09 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -28,8 +28,9 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=500) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=5000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) @@ -162,10 +163,10 @@ def reward_metric(rews): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, reward_metric=reward_metric, - writer=writer, test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step, + writer=writer, test_in_train=False, reward_metric=reward_metric) return result, policy.policies[args.agent_id - 1] @@ -183,4 +184,4 @@ def watch( collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}") diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index b24e8c6b2..477a2531a 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -605,7 +605,7 @@ def __init__(self, buffer_list: Sequence[PrioritizedReplayBuffer]) -> None: class VectorReplayBuffer(ReplayBufferManager): """VectorReplayBuffer contains n ReplayBuffer with the same size. - It is used for storing data frame from different environments yet keeping the order + It is used for storing transition from different environments yet keeping the order of time. :param int total_size: the total size of VectorReplayBuffer. 
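The reworded docstring reflects how the buffer is meant to be used: one sub-buffer per parallel environment, so transitions from different environments never interleave inside a sub-buffer. A small sketch, assuming CartPole and the same sizes as the DQN tutorial above (``buffer_num`` should match the number of training environments):

```python
import gym
import tianshou as ts

# 10 sub-buffers of 2000 transitions each; buffer_num matches the number of
# parallel training environments so per-environment time order is preserved.
train_envs = ts.env.DummyVectorEnv(
    [lambda: gym.make('CartPole-v0') for _ in range(10)])
buf = ts.data.VectorReplayBuffer(total_size=20000, buffer_num=10)
```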
@@ -631,7 +631,7 @@ def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager): """PrioritizedVectorReplayBuffer contains n PrioritizedReplayBuffer with same size. - It is used for storing data frame from different environments yet keeping the order + It is used for storing transition from different environments yet keeping the order of time. :param int total_size: the total size of PrioritizedVectorReplayBuffer. diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 74ee72d11..bb3239e0e 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -198,7 +198,7 @@ def collect( if not n_step % self.env_num == 0: warnings.warn( f"n_step={n_step} is not a multiple of #env ({self.env_num}), " - "which may cause extra frame collected into the buffer." + "which may cause extra transitions collected into the buffer." ) ready_env_ids = np.arange(self.env_num) elif n_episode is not None: @@ -357,9 +357,9 @@ def collect( ) -> Dict[str, Any]: """Collect a specified number of step or episode with async env setting. - This function doesn't collect exactly n_step or n_episode number of frames. - Instead, in order to support async setting, it may collect more than given - n_step or n_episode frames and save into buffer. + This function doesn't collect exactly n_step or n_episode number of + transitions. Instead, in order to support async setting, it may collect more + than given n_step or n_episode transitions and save into buffer. :param int n_step: how many steps you want to collect. :param int n_episode: how many episodes you want to collect. @@ -395,7 +395,7 @@ def collect( else: raise TypeError("Please specify at least one (either n_step or n_episode) " "in AsyncCollector.collect().") - warnings.warn("Using async setting may collect extra frames into buffer.") + warnings.warn("Using async setting may collect extra transitions into buffer.") ready_env_ids = self._ready_env_ids diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index be6d8216b..9023bf47b 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -220,21 +220,17 @@ def compute_episodic_return( Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438) to calculate q function/reward to go of given batch. - :param batch: a data batch which contains several episodes of data + :param Batch batch: a data batch which contains several episodes of data in sequential order. Mind that the end of each finished episode of batch should be marked by done flag, unfinished (or collecting) episodes will be recongized by buffer.unfinished_index(). - :type batch: :class:`~tianshou.data.Batch` - :param numpy.ndarray indice: tell batch's location in buffer, batch is + :param np.ndarray indice: tell batch's location in buffer, batch is equal to buffer[indice]. - :param v_s_: the value function of all next states :math:`V(s')`. - :type v_s_: numpy.ndarray - :param float gamma: the discount factor, should be in [0, 1], defaults - to 0.99. - :param float gae_lambda: the parameter for Generalized Advantage - Estimation, should be in [0, 1], defaults to 0.95. - :param bool rew_norm: normalize the reward to Normal(0, 1), defaults - to False. + :param np.ndarray v_s_: the value function of all next states :math:`V(s')`. + :param float gamma: the discount factor, should be in [0, 1]. Default to 0.99. + :param float gae_lambda: the parameter for Generalized Advantage Estimation, + should be in [0, 1]. 
Default to 0.95. + :param bool rew_norm: normalize the reward to Normal(0, 1). Default to False. :return: a Batch. The result will be stored in batch.returns as a numpy array with shape (bsz, ). @@ -273,18 +269,14 @@ def compute_nstep_return( where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`, :math:`d_t` is the done flag of step :math:`t`. - :param batch: a data batch, which is equal to buffer[indice]. - :type batch: :class:`~tianshou.data.Batch` - :param buffer: the data buffer. - :type buffer: :class:`~tianshou.data.ReplayBuffer` + :param Batch batch: a data batch, which is equal to buffer[indice]. + :param ReplayBuffer buffer: the data buffer. :param function target_q_fn: a function which compute target Q value of "obs_next" given data buffer and wanted indices. - :param float gamma: the discount factor, should be in [0, 1], defaults - to 0.99. - :param int n_step: the number of estimation step, should be an int - greater than 0, defaults to 1. - :param bool rew_norm: normalize the reward to Normal(0, 1), defaults - to False. + :param float gamma: the discount factor, should be in [0, 1]. Default to 0.99. + :param int n_step: the number of estimation step, should be an int greater + than 0. Default to 1. + :param bool rew_norm: normalize the reward to Normal(0, 1), Default to False. :return: a Batch. The result will be stored in batch.returns as a torch.Tensor with the same shape as target_q_fn's return tensor. diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 92b7f5d89..82fb9f704 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -45,7 +45,7 @@ def __init__( def process_fn( self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray ) -> Batch: - r"""Compute the discounted returns for each frame. + r"""Compute the discounted returns for each transition. .. math:: G_t = \sum_{i=t}^T \gamma^{i-t}r_i diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py index 13f9159f8..9c7f132af 100644 --- a/tianshou/policy/random.py +++ b/tianshou/policy/random.py @@ -38,5 +38,5 @@ def forward( return Batch(act=logits.argmax(axis=-1)) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: - """Since a random agent learn nothing, it returns an empty dict.""" + """Since a random agent learns nothing, it returns an empty dict.""" return {} diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 7eb6ec5b5..61714f7a0 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -16,7 +16,7 @@ def offline_trainer( buffer: ReplayBuffer, test_collector: Collector, max_epoch: int, - step_per_epoch: int, + update_per_epoch: int, episode_per_test: int, batch_size: int, test_fn: Optional[Callable[[int, Optional[int]], None]] = None, @@ -29,50 +29,52 @@ def offline_trainer( ) -> Dict[str, Union[float, str]]: """A wrapper for offline trainer procedure. - The "step" in trainer means a policy network update. + The "step" in offline trainer means a gradient step. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param test_collector: the collector used for testing. - :type test_collector: :class:`~tianshou.data.Collector` - :param int max_epoch: the maximum number of epochs for training. The - training process might be finished before reaching the ``max_epoch``. - :param int step_per_epoch: the number of policy network updates, so-called + :param Collector test_collector: the collector used for testing. 
+ :param int max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. + :param int update_per_epoch: the number of policy network updates, so-called gradient steps, per epoch. :param episode_per_test: the number of episodes for one policy evaluation. - :param int batch_size: the batch size of sample data, which is going to - feed in the policy network. - :param function test_fn: a hook called at the beginning of testing in each - epoch. It can be used to perform custom additional operations, with the - signature ``f(num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean - reward in evaluation phase gets better, with the signature ``f(policy: - BasePolicy) -> None``. - :param function stop_fn: a function with signature ``f(mean_rewards: float) - -> bool``, receives the average undiscounted returns of the testing - result, returns a boolean which indicates whether reaching the goal. + :param int batch_size: the batch size of sample data, which is going to be fed + into the policy network. + :param function test_fn: a hook called at the beginning of testing in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean reward in + evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> + None``. + :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether the goal is reached. :param function reward_metric: a function with signature ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, used in multi-agent RL. We need to return a single scalar for each episode's result to monitor training in the multi-agent RL setting. This function specifies what is the desired metric, e.g., the reward of agent 1 or the average reward over all agents. - :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard - SummaryWriter; if None is given, it will not write logs to TensorBoard. - :param int log_interval: the log interval of the writer. - :param bool verbose: whether to print the information. + :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter; + if None is given, it will not write logs to TensorBoard. Default to None. + :param int log_interval: the log interval of the writer. Default to 1. + :param bool verbose: whether to print the information. Default to True. :return: See :func:`~tianshou.trainer.gather_info`.
""" gradient_step = 0 - best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() test_collector.reset_stat() - + test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test, + writer, gradient_step, reward_metric) + best_epoch = 0 + best_reward = test_result["rews"].mean() + best_reward_std = test_result["rews"].std() for epoch in range(1, 1 + max_epoch): policy.train() with tqdm.trange( - step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config + update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config ) as t: for i in t: gradient_step += 1 @@ -87,16 +89,18 @@ def offline_trainer( global_step=gradient_step) t.set_postfix(**data) # test - result = test_episode(policy, test_collector, test_fn, epoch, - episode_per_test, writer, gradient_step, reward_metric) - if best_epoch == -1 or best_reward < result["rews"].mean(): - best_reward, best_reward_std = result["rews"].mean(), result['rews'].std() + test_result = test_episode(policy, test_collector, test_fn, epoch, + episode_per_test, writer, gradient_step, + reward_metric) + if best_epoch == -1 or best_reward < test_result["rews"].mean(): + best_reward = test_result["rews"].mean() + best_reward_std = test_result['rews'].std() best_epoch = epoch if save_fn: save_fn(policy) if verbose: - print(f"Epoch #{epoch}: test_reward: {result['rews'].mean():.6f} ± " - f"{result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " + print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± " + f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " f"{best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index ba08c2e0b..54e7cb166 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -17,10 +17,10 @@ def offpolicy_trainer( test_collector: Collector, max_epoch: int, step_per_epoch: int, - collect_per_step: int, + step_per_collect: int, episode_per_test: int, batch_size: int, - update_per_step: int = 1, + update_per_step: Union[int, float] = 1, train_fn: Optional[Callable[[int, int], None]] = None, test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, @@ -33,60 +33,62 @@ def offpolicy_trainer( ) -> Dict[str, Union[float, str]]: """A wrapper for off-policy trainer procedure. - The "step" in trainer means a policy network update. + The "step" in trainer means an environment step (a.k.a. transition). :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param train_collector: the collector used for training. - :type train_collector: :class:`~tianshou.data.Collector` - :param test_collector: the collector used for testing. - :type test_collector: :class:`~tianshou.data.Collector` - :param int max_epoch: the maximum number of epochs for training. The - training process might be finished before reaching the ``max_epoch``. - :param int step_per_epoch: the number of policy network updates, so-called - gradient steps, per epoch. - :param int collect_per_step: the number of frames the collector would - collect before the network update. In other words, collect some frames - and do some policy network update. + :param Collector train_collector: the collector used for training. + :param Collector test_collector: the collector used for testing. + :param int max_epoch: the maximum number of epochs for training. 
The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. + :param int step_per_epoch: the number of transitions collected per epoch. + :param int step_per_collect: the number of transitions the collector would collect + before the network update, i.e., trainer will collect "step_per_collect" + transitions and do some policy network updates repeatedly in each epoch. :param episode_per_test: the number of episodes for one policy evaluation. - :param int batch_size: the batch size of sample data, which is going to - feed in the policy network. - :param int update_per_step: the number of times the policy network would - be updated after frames are collected, for example, set it to 256 means - it updates policy 256 times once after ``collect_per_step`` frames are - collected. - :param function train_fn: a hook called at the beginning of training in - each epoch. It can be used to perform custom additional operations, - with the signature ``f(num_epoch: int, step_idx: int) -> None``. - :param function test_fn: a hook called at the beginning of testing in each - epoch. It can be used to perform custom additional operations, with the - signature ``f(num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean - reward in evaluation phase gets better, with the signature ``f(policy: - BasePolicy) -> None``. - :param function stop_fn: a function with signature ``f(mean_rewards: float) - -> bool``, receives the average undiscounted returns of the testing - result, returns a boolean which indicates whether reaching the goal. + :param int batch_size: the batch size of sample data, which is going to be fed into + the policy network. + :param int/float update_per_step: the number of times the policy network would be + updated per transition after (step_per_collect) transitions are collected, + e.g., if update_per_step is set to 0.3 and step_per_collect is 256, the policy + will be updated round(256 * 0.3) = 77 times after 256 transitions are + collected by the collector. Default to 1. + :param function train_fn: a hook called at the beginning of training in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean reward in + evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> + None``. + :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether the goal is reached. :param function reward_metric: a function with signature ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, used in multi-agent RL. We need to return a single scalar for each episode's result to monitor training in the multi-agent RL setting. This function specifies what is the desired metric, e.g., the reward of agent 1 or the average reward over all agents. - :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard - SummaryWriter; if None is given, it will not write logs to TensorBoard. - :param int log_interval: the log interval of the writer.
- :param bool verbose: whether to print the information. - :param bool test_in_train: whether to test in the training phase. + :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter; + if None is given, it will not write logs to TensorBoard. Default to None. + :param int log_interval: the log interval of the writer. Default to 1. + :param bool verbose: whether to print the information. Default to True. + :param bool test_in_train: whether to test in the training phase. Default to True. :return: See :func:`~tianshou.trainer.gather_info`. """ env_step, gradient_step = 0, 0 - best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy + test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test, + writer, env_step, reward_metric) + best_epoch = 0 + best_reward = test_result["rews"].mean() + best_reward_std = test_result["rews"].std() for epoch in range(1, 1 + max_epoch): # train policy.train() @@ -96,10 +98,11 @@ def offpolicy_trainer( while t.n < t.total: if train_fn: train_fn(epoch, env_step) - result = train_collector.collect(n_step=collect_per_step) + result = train_collector.collect(n_step=step_per_collect) if len(result["rews"]) > 0 and reward_metric: result["rews"] = reward_metric(result["rews"]) env_step += int(result["n/st"]) + t.update(result["n/st"]) data = { "env_step": str(env_step), "rew": f"{result['rews'].mean():.2f}", @@ -126,8 +129,7 @@ def offpolicy_trainer( test_result["rews"].mean(), test_result["rews"].std()) else: policy.train() - for i in range(update_per_step * min( - result["n/st"] // collect_per_step, t.total - t.n)): + for i in range(round(update_per_step * result["n/st"])): gradient_step += 1 losses = policy.update(batch_size, train_collector.buffer) for k in losses.keys(): @@ -136,21 +138,21 @@ def offpolicy_trainer( if writer and gradient_step % log_interval == 0: writer.add_scalar( k, stat[k].get(), global_step=gradient_step) - t.update(1) t.set_postfix(**data) if t.n <= t.total: t.update() # test - result = test_episode(policy, test_collector, test_fn, epoch, - episode_per_test, writer, env_step, reward_metric) - if best_epoch == -1 or best_reward < result["rews"].mean(): - best_reward, best_reward_std = result["rews"].mean(), result["rews"].std() + test_result = test_episode(policy, test_collector, test_fn, epoch, + episode_per_test, writer, env_step, reward_metric) + if best_epoch == -1 or best_reward < test_result["rews"].mean(): + best_reward = test_result["rews"].mean() + best_reward_std = test_result['rews'].std() best_epoch = epoch if save_fn: save_fn(policy) if verbose: - print(f"Epoch #{epoch}: test_reward: {result['rews'].mean():.6f} ± " - f"{result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " + print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± " + f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " f"{best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index b951f9a9e..43fcc8738 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -17,10 +17,11 @@ def onpolicy_trainer( test_collector: Collector, max_epoch: int, step_per_epoch: int, - collect_per_step: int, repeat_per_collect: int, episode_per_test: int, batch_size: int, + 
step_per_collect: Optional[int] = None, + episode_per_collect: Optional[int] = None, train_fn: Optional[Callable[[int, int], None]] = None, test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, @@ -33,60 +34,67 @@ ) -> Dict[str, Union[float, str]]: """A wrapper for on-policy trainer procedure. - The "step" in trainer means a policy network update. + The "step" in trainer means an environment step (a.k.a. transition). :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param train_collector: the collector used for training. - :type train_collector: :class:`~tianshou.data.Collector` - :param test_collector: the collector used for testing. - :type test_collector: :class:`~tianshou.data.Collector` - :param int max_epoch: the maximum number of epochs for training. The - training process might be finished before reaching the ``max_epoch``. - :param int step_per_epoch: the number of policy network updates, so-called - gradient steps, per epoch. - :param int collect_per_step: the number of episodes the collector would - collect before the network update. In other words, collect some - episodes and do one policy network update. - :param int repeat_per_collect: the number of repeat time for policy - learning, for example, set it to 2 means the policy needs to learn each - given batch data twice. - :param episode_per_test: the number of episodes for one policy evaluation. - :type episode_per_test: int or list of ints - :param int batch_size: the batch size of sample data, which is going to - feed in the policy network. - :param function train_fn: a hook called at the beginning of training in - each epoch. It can be used to perform custom additional operations, - with the signature ``f(num_epoch: int, step_idx: int) -> None``. - :param function test_fn: a hook called at the beginning of testing in each - epoch. It can be used to perform custom additional operations, with the - signature ``f(num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean - reward in evaluation phase gets better, with the signature ``f(policy: - BasePolicy) -> None``. - :param function stop_fn: a function with signature ``f(mean_rewards: float) - -> bool``, receives the average undiscounted returns of the testing - result, returns a boolean which indicates whether reaching the goal. + :param Collector train_collector: the collector used for training. + :param Collector test_collector: the collector used for testing. + :param int max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. + :param int step_per_epoch: the number of transitions collected per epoch. + :param int repeat_per_collect: the number of repeat times for policy learning; for + example, setting it to 2 means the policy needs to learn each given batch of + data twice. + :param int episode_per_test: the number of episodes for one policy evaluation. + :param int batch_size: the batch size of sample data, which is going to be fed into + the policy network. + :param int step_per_collect: the number of transitions the collector would collect + before the network update, i.e., trainer will collect "step_per_collect" + transitions and do some policy network updates repeatedly in each epoch.
+ :param int episode_per_collect: the number of episodes the collector would collect + before the network update, i.e., trainer will collect "episode_per_collect" + episodes and do some policy network updates repeatedly in each epoch. + :param function train_fn: a hook called at the beginning of training in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean reward in + evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> + None``. + :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether the goal is reached. :param function reward_metric: a function with signature ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, used in multi-agent RL. We need to return a single scalar for each episode's result to monitor training in the multi-agent RL setting. This function specifies what is the desired metric, e.g., the reward of agent 1 or the average reward over all agents. - :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard - SummaryWriter; if None is given, it will not write logs to TensorBoard. - :param int log_interval: the log interval of the writer. - :param bool verbose: whether to print the information. - :param bool test_in_train: whether to test in the training phase. + :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter; + if None is given, it will not write logs to TensorBoard. Default to None. + :param int log_interval: the log interval of the writer. Default to 1. + :param bool verbose: whether to print the information. Default to True. + :param bool test_in_train: whether to test in the training phase. Default to True. :return: See :func:`~tianshou.trainer.gather_info`. + + .. note:: + + Only one of step_per_collect and episode_per_collect can be specified.
""" env_step, gradient_step = 0, 0 - best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy + test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test, + writer, env_step, reward_metric) + best_epoch = 0 + best_reward = test_result["rews"].mean() + best_reward_std = test_result["rews"].std() for epoch in range(1, 1 + max_epoch): # train policy.train() @@ -96,10 +104,12 @@ def onpolicy_trainer( while t.n < t.total: if train_fn: train_fn(epoch, env_step) - result = train_collector.collect(n_episode=collect_per_step) + result = train_collector.collect(n_step=step_per_collect, + n_episode=episode_per_collect) if reward_metric: result["rews"] = reward_metric(result["rews"]) env_step += int(result["n/st"]) + t.update(result["n/st"]) data = { "env_step": str(env_step), "rew": f"{result['rews'].mean():.2f}", @@ -138,21 +148,21 @@ def onpolicy_trainer( if writer and gradient_step % log_interval == 0: writer.add_scalar( k, stat[k].get(), global_step=gradient_step) - t.update(step) t.set_postfix(**data) if t.n <= t.total: t.update() # test - result = test_episode(policy, test_collector, test_fn, epoch, - episode_per_test, writer, env_step) - if best_epoch == -1 or best_reward < result["rews"].mean(): - best_reward, best_reward_std = result["rews"].mean(), result["rews"].std() + test_result = test_episode(policy, test_collector, test_fn, epoch, + episode_per_test, writer, env_step) + if best_epoch == -1 or best_reward < test_result["rews"].mean(): + best_reward = test_result["rews"].mean() + best_reward_std = test_result['rews'].std() best_epoch = epoch if save_fn: save_fn(policy) if verbose: - print(f"Epoch #{epoch}: test_reward: {result['rews'].mean():.6f} ± " - f"{result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " + print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± " + f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " f"{best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 2cdeb15fe..72803bef0 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -48,14 +48,14 @@ def gather_info( * ``train_step`` the total collected step of training collector; * ``train_episode`` the total collected episode of training collector; - * ``train_time/collector`` the time for collecting frames in the \ + * ``train_time/collector`` the time for collecting transitions in the \ training collector; * ``train_time/model`` the time for training models; - * ``train_speed`` the speed of training (frames per second); + * ``train_speed`` the speed of training (env_step per second); * ``test_step`` the total collected step of test collector; * ``test_episode`` the total collected episode of test collector; * ``test_time`` the time for testing; - * ``test_speed`` the speed of testing (frames per second); + * ``test_speed`` the speed of testing (env_step per second); * ``best_reward`` the best reward over the test results; * ``duration`` the total elapsed time. """