From e8e0ec188d7bb287206324e8e85f39fc9f4834dd Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Fri, 19 Feb 2021 11:08:52 +0800 Subject: [PATCH 01/25] rebase --- README.md | 4 +-- docs/tutorials/dqn.rst | 4 +-- docs/tutorials/tictactoe.rst | 2 +- examples/atari/atari_bcq.py | 4 +-- examples/atari/atari_c51.py | 4 +-- examples/atari/atari_dqn.py | 4 +-- examples/atari/atari_qrdqn.py | 4 +-- examples/atari/runnable/pong_a2c.py | 6 ++--- examples/atari/runnable/pong_ppo.py | 4 +-- examples/box2d/acrobot_dualdqn.py | 4 +-- examples/box2d/bipedal_hardcore_sac.py | 4 +-- examples/box2d/lunarlander_dqn.py | 4 +-- examples/box2d/mcc_sac.py | 4 +-- examples/mujoco/mujoco_sac.py | 4 +-- examples/mujoco/runnable/ant_v2_ddpg.py | 4 +-- examples/mujoco/runnable/ant_v2_td3.py | 4 +-- .../runnable/halfcheetahBullet_v0_sac.py | 4 +-- examples/mujoco/runnable/point_maze_td3.py | 4 +-- test/continuous/test_ddpg.py | 4 +-- test/continuous/test_ppo.py | 6 ++--- test/continuous/test_sac_with_il.py | 6 ++--- test/continuous/test_td3.py | 4 +-- test/discrete/test_a2c_with_il.py | 8 +++--- test/discrete/test_c51.py | 4 +-- test/discrete/test_dqn.py | 4 +-- test/discrete/test_drqn.py | 12 +++++---- test/discrete/test_il_bcq.py | 4 +-- test/discrete/test_pg.py | 4 +-- test/discrete/test_ppo.py | 4 +-- test/discrete/test_qrdqn.py | 4 +-- test/discrete/test_sac.py | 4 +-- test/modelbase/test_psrl.py | 6 ++--- test/multiagent/tic_tac_toe.py | 4 +-- tianshou/trainer/offline.py | 6 ++--- tianshou/trainer/offpolicy.py | 27 ++++++++++--------- tianshou/trainer/onpolicy.py | 22 ++++++++++----- 36 files changed, 108 insertions(+), 97 deletions(-) diff --git a/README.md b/README.md index c50321f98..d543515db 100644 --- a/README.md +++ b/README.md @@ -195,7 +195,7 @@ train_num, test_num = 8, 100 gamma, n_step, target_freq = 0.9, 3, 320 buffer_size = 20000 eps_train, eps_test = 0.1, 0.05 -step_per_epoch, collect_per_step = 1000, 8 +step_per_epoch, step_per_collect = 1000, 8 writer = SummaryWriter('log/dqn') # tensorboard is also supported! ``` @@ -232,7 +232,7 @@ Let's train it: ```python result = ts.trainer.offpolicy_trainer( - policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step, + policy, train_collector, test_collector, epoch, step_per_epoch, step_per_collect, test_num, batch_size, train_fn=lambda epoch, env_step: policy.set_eps(eps_train), test_fn=lambda epoch, env_step: policy.set_eps(eps_test), diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index 361f79f3c..31ee2ce46 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -125,7 +125,7 @@ Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.t result = ts.trainer.offpolicy_trainer( policy, train_collector, test_collector, - max_epoch=10, step_per_epoch=1000, collect_per_step=10, + max_epoch=10, step_per_epoch=1000, step_per_collect=10, episode_per_test=100, batch_size=64, train_fn=lambda epoch, env_step: policy.set_eps(0.1), test_fn=lambda epoch, env_step: policy.set_eps(0.05), @@ -137,7 +137,7 @@ The meaning of each parameter is as follows (full description can be found at :f * ``max_epoch``: The maximum of epochs for training. The training process might be finished before reaching the ``max_epoch``; * ``step_per_epoch``: The number of step for updating policy network in one epoch; -* ``collect_per_step``: The number of frames the collector would collect before the network update. 
For example, the code above means "collect 10 frames and do one policy network update"; +* ``step_per_collect``: The number of frames the collector would collect before the network update. For example, the code above means "collect 10 frames and do one policy network update"; * ``episode_per_test``: The number of episodes for one policy evaluation. * ``batch_size``: The batch size of sample data, which is going to feed in the policy network. * ``train_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training". diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst index c656c1ee2..c0f683623 100644 --- a/docs/tutorials/tictactoe.rst +++ b/docs/tutorials/tictactoe.rst @@ -355,7 +355,7 @@ With the above preparation, we are close to the first learned agent. The followi # start training, this may require about three minutes result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, reward_metric=reward_metric, writer=writer, test_in_train=False) diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py index e2b8f0778..72b8d1336 100644 --- a/examples/atari/atari_bcq.py +++ b/examples/atari/atari_bcq.py @@ -28,7 +28,7 @@ def get_args(): parser.add_argument("--unlikely-action-threshold", type=float, default=0.3) parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=10000) + parser.add_argument("--update-per-epoch", type=int, default=10000) parser.add_argument("--batch-size", type=int, default=32) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512]) @@ -140,7 +140,7 @@ def watch(): result = offline_trainer( policy, buffer, test_collector, - args.epoch, args.step_per_epoch, args.test_num, args.batch_size, + args.epoch, args.update_per_epoch, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer, log_interval=args.log_interval, ) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 42cffa9be..9c046c50e 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -31,7 +31,7 @@ def get_args(): parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=10000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=32) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=10) @@ -141,7 +141,7 @@ def watch(): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 559b0878e..2b0153891 100644 --- 
a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -28,7 +28,7 @@ def get_args(): parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=10000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=32) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=10) @@ -151,7 +151,7 @@ def watch(): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False) diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index ed356381f..8a6d97fa4 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -29,7 +29,7 @@ def get_args(): parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=10000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=32) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=10) @@ -139,7 +139,7 @@ def watch(): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False) diff --git a/examples/atari/runnable/pong_a2c.py b/examples/atari/runnable/pong_a2c.py index 0b81cecd6..3a0e7e690 100644 --- a/examples/atari/runnable/pong_a2c.py +++ b/examples/atari/runnable/pong_a2c.py @@ -4,7 +4,7 @@ import argparse import numpy as np from torch.utils.tensorboard import SummaryWriter - +#TODO why bug from tianshou.policy import A2CPolicy from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net @@ -24,7 +24,7 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.9) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--repeat-per-collect', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, @@ -91,7 +91,7 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, + args.step_per_epoch, args.step_per_collect, args.repeat_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) if __name__ == '__main__': pprint.pprint(result) diff --git a/examples/atari/runnable/pong_ppo.py b/examples/atari/runnable/pong_ppo.py index 8ed04c21e..c1f293879 100644 --- a/examples/atari/runnable/pong_ppo.py +++ 
b/examples/atari/runnable/pong_ppo.py @@ -24,7 +24,7 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, @@ -95,7 +95,7 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, + args.step_per_epoch, args.step_per_collect, args.repeat_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) if __name__ == '__main__': pprint.pprint(result) diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 444c357f8..68c2eac88 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -26,7 +26,7 @@ def get_args(): parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=100) + parser.add_argument('--step-per-collect', type=int, default=100) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128]) parser.add_argument('--dueling-q-hidden-sizes', type=int, @@ -103,7 +103,7 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer) diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 0bf802d3b..a0baf575f 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -28,7 +28,7 @@ def get_args(): parser.add_argument('--alpha-lr', type=float, default=3e-4) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=10000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -143,7 +143,7 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False) diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index de88aa315..5a4c61b42 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -27,7 +27,7 @@ def get_args(): parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=5000) - parser.add_argument('--collect-per-step', 
type=int, default=16) + parser.add_argument('--step-per-collect', type=int, default=16) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -99,7 +99,7 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False) diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 14e5095e7..cb31b7001 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -30,7 +30,7 @@ def get_args(): parser.add_argument('--alpha', type=float, default=0.2) parser.add_argument('--epoch', type=int, default=20) parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=5) + parser.add_argument('--step-per-collect', type=int, default=5) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -112,7 +112,7 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index ba9cd79a3..558d24c05 100644 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -29,7 +29,7 @@ def get_args(): parser.add_argument('--n-step', type=int, default=2) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=10000) - parser.add_argument('--collect-per-step', type=int, default=4) + parser.add_argument('--step-per-collect', type=int, default=4) parser.add_argument('--update-per-step', type=int, default=1) parser.add_argument('--pre-collect-step', type=int, default=10000) parser.add_argument('--batch-size', type=int, default=256) @@ -139,7 +139,7 @@ def stop_fn(mean_rewards): train_collector.collect(n_step=args.pre_collect_step, random=True) result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, args.update_per_step, stop_fn=stop_fn, save_fn=save_fn, writer=writer, log_interval=args.log_interval) diff --git a/examples/mujoco/runnable/ant_v2_ddpg.py b/examples/mujoco/runnable/ant_v2_ddpg.py index b9a6e0118..a83abc37b 100644 --- a/examples/mujoco/runnable/ant_v2_ddpg.py +++ b/examples/mujoco/runnable/ant_v2_ddpg.py @@ -26,7 +26,7 @@ def get_args(): parser.add_argument('--exploration-noise', type=float, default=0.1) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=4) + parser.add_argument('--step-per-collect', type=int, default=4) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -87,7 +87,7 @@ def stop_fn(mean_rewards): # trainer result 
= offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/examples/mujoco/runnable/ant_v2_td3.py b/examples/mujoco/runnable/ant_v2_td3.py index 2f8370217..1bda7aa16 100644 --- a/examples/mujoco/runnable/ant_v2_td3.py +++ b/examples/mujoco/runnable/ant_v2_td3.py @@ -29,7 +29,7 @@ def get_args(): parser.add_argument('--update-actor-freq', type=int, default=2) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -96,7 +96,7 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py index 6f34ce0ad..d64155dbe 100644 --- a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py +++ b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py @@ -28,7 +28,7 @@ def get_args(): parser.add_argument('--alpha', type=float, default=0.2) parser.add_argument('--epoch', type=int, default=200) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -97,7 +97,7 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, log_interval=args.log_interval) assert stop_fn(result['best_reward']) diff --git a/examples/mujoco/runnable/point_maze_td3.py b/examples/mujoco/runnable/point_maze_td3.py index 76271f40f..f23e1b057 100644 --- a/examples/mujoco/runnable/point_maze_td3.py +++ b/examples/mujoco/runnable/point_maze_td3.py @@ -31,7 +31,7 @@ def get_args(): parser.add_argument('--update-actor-freq', type=int, default=2) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -104,7 +104,7 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) assert 
stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 68d3fc433..2ad173beb 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -27,7 +27,7 @@ def get_args(): parser.add_argument('--exploration-noise', type=float, default=0.1) parser.add_argument('--epoch', type=int, default=20) parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=4) + parser.add_argument('--step-per-collect', type=int, default=4) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -102,7 +102,7 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 45a59f425..a90471f82 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -23,8 +23,8 @@ def get_args(): parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=16) + parser.add_argument('--step-per-epoch', type=int, default=24000) + parser.add_argument('--step-per-collect', type=int, default=16) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, @@ -121,7 +121,7 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, + args.step_per_epoch, args.step_per_collect, args.repeat_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 6e075bb5c..71086894e 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -27,7 +27,7 @@ def get_args(): parser.add_argument('--alpha', type=float, default=0.2) parser.add_argument('--epoch', type=int, default=20) parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -110,7 +110,7 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': @@ -142,7 +142,7 @@ def stop_fn(mean_rewards): train_collector.reset() result = offpolicy_trainer( il_policy, train_collector, il_test_collector, args.epoch, - 
args.step_per_epoch // 5, args.collect_per_step, args.test_num, + args.step_per_epoch // 5, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index c90a92aa4..fc5bf037e 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -30,7 +30,7 @@ def get_args(): parser.add_argument('--update-actor-freq', type=int, default=2) parser.add_argument('--epoch', type=int, default=20) parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -115,7 +115,7 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index ea7e6b6ad..7652dbe35 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -23,8 +23,8 @@ def get_args(): parser.add_argument('--il-lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.9) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=8) + parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--step-per-collect', type=int, default=8) parser.add_argument('--repeat-per-collect', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, @@ -96,7 +96,7 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, + args.step_per_epoch, args.step_per_collect, args.repeat_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) @@ -127,7 +127,7 @@ def stop_fn(mean_rewards): train_collector.reset() result = offpolicy_trainer( il_policy, train_collector, il_test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 684ce9696..3f801be90 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -29,7 +29,7 @@ def get_args(): parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=8) + parser.add_argument('--step-per-collect', type=int, default=8) parser.add_argument('--batch-size', type=int, default=64) 
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) @@ -112,7 +112,7 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index df02684c0..f3923e6a9 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -27,7 +27,7 @@ def get_args(): parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) @@ -114,7 +114,7 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 420f8e6cd..6f1191f88 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -26,8 +26,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=4) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=5000) + parser.add_argument('--update-per-step', type=float, default=0.1) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--layer-num', type=int, default=3) parser.add_argument('--training-num', type=int, default=10) @@ -92,9 +93,10 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, + args.batch_size, update_per_step = args.update_per_step, + train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, + save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py index 3dd2b8d63..09c4c52d2 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/discrete/test_il_bcq.py @@ -26,7 +26,7 @@ def get_args(): parser.add_argument("--unlikely-action-threshold", type=float, default=0.3) parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=1000) + parser.add_argument("--update-per-epoch", type=int, default=1000) parser.add_argument("--batch-size", type=int, default=64) 
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128]) @@ -91,7 +91,7 @@ def stop_fn(mean_rewards): result = offline_trainer( policy, buffer, test_collector, - args.epoch, args.step_per_epoch, args.test_num, args.batch_size, + args.epoch, args.update_per_epoch, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index a413111b7..ffbe7a9d4 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -22,7 +22,7 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.95) parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=8) + parser.add_argument('--step-per-collect', type=int, default=8) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, @@ -82,7 +82,7 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, + args.step_per_epoch, args.step_per_collect, args.repeat_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 9862ea7f7..0dce4416b 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -23,7 +23,7 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=2000) - parser.add_argument('--collect-per-step', type=int, default=20) + parser.add_argument('--step-per-collect', type=int, default=20) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, @@ -108,7 +108,7 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, + args.step_per_epoch, args.step_per_collect, args.repeat_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 006dd827b..942d4ee07 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -27,7 +27,7 @@ def get_args(): parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) @@ -110,7 +110,7 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, 
save_fn=save_fn, writer=writer) diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 16ab54cb2..3f7f2a39a 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -28,7 +28,7 @@ def get_args(): parser.add_argument('--auto_alpha', type=int, default=0) parser.add_argument('--epoch', type=int, default=5) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=5) + parser.add_argument('--step-per-collect', type=int, default=5) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -108,7 +108,7 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False) assert stop_fn(result['best_reward']) diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 5a813874f..9ea2c530f 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -17,8 +17,8 @@ def get_args(): parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--buffer-size', type=int, default=50000) parser.add_argument('--epoch', type=int, default=5) - parser.add_argument('--step-per-epoch', type=int, default=5) - parser.add_argument('--collect-per-step', type=int, default=1) + parser.add_argument('--step-per-epoch', type=int, default=1000) + parser.add_argument('--step-per-collect', type=int, default=1) parser.add_argument('--training-num', type=int, default=1) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') @@ -78,7 +78,7 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, 1, + args.step_per_epoch, args.step_per_collect, 1, args.test_num, 0, stop_fn=stop_fn, writer=writer, test_in_train=False) diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index b081a5055..1b471f62e 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -29,7 +29,7 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=20) parser.add_argument('--step-per-epoch', type=int, default=500) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) @@ -162,7 +162,7 @@ def reward_metric(rews): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, reward_metric=reward_metric, writer=writer, test_in_train=False) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 7eb6ec5b5..93597c876 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -16,7 +16,7 @@ def offline_trainer( buffer: ReplayBuffer, test_collector: Collector, 
max_epoch: int, - step_per_epoch: int, + update_per_epoch: int, episode_per_test: int, batch_size: int, test_fn: Optional[Callable[[int, Optional[int]], None]] = None, @@ -36,7 +36,7 @@ def offline_trainer( :type test_collector: :class:`~tianshou.data.Collector` :param int max_epoch: the maximum number of epochs for training. The training process might be finished before reaching the ``max_epoch``. - :param int step_per_epoch: the number of policy network updates, so-called + :param int update_per_epoch: the number of policy network updates, so-called gradient steps, per epoch. :param episode_per_test: the number of episodes for one policy evaluation. :param int batch_size: the batch size of sample data, which is going to @@ -72,7 +72,7 @@ def offline_trainer( for epoch in range(1, 1 + max_epoch): policy.train() with tqdm.trange( - step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config + update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config ) as t: for i in t: gradient_step += 1 diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index ba08c2e0b..470e91892 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -17,10 +17,10 @@ def offpolicy_trainer( test_collector: Collector, max_epoch: int, step_per_epoch: int, - collect_per_step: int, + step_per_collect: int, episode_per_test: int, batch_size: int, - update_per_step: int = 1, + update_per_step: Union[int, float] = 1, train_fn: Optional[Callable[[int, int], None]] = None, test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, @@ -42,18 +42,18 @@ def offpolicy_trainer( :type test_collector: :class:`~tianshou.data.Collector` :param int max_epoch: the maximum number of epochs for training. The training process might be finished before reaching the ``max_epoch``. - :param int step_per_epoch: the number of policy network updates, so-called - gradient steps, per epoch. - :param int collect_per_step: the number of frames the collector would + :param int step_per_epoch: the number of environment frames collected per epoch. + :param int step_per_collect: the number of frames the collector would collect before the network update. In other words, collect some frames and do some policy network update. :param episode_per_test: the number of episodes for one policy evaluation. :param int batch_size: the batch size of sample data, which is going to feed in the policy network. - :param int update_per_step: the number of times the policy network would - be updated after frames are collected, for example, set it to 256 means - it updates policy 256 times once after ``collect_per_step`` frames are - collected. + :param int/float update_per_step: the number of times the policy network would + be updated per environment frame after (step_per_collect) frames are collected, + for example, if update_per_step set to 0.3, and step_per_collect is 256, + policy will be updated round(256 * 0.3 = 76.8) = 77 times after 256 frames are + collected by the collector. Default to 1. :param function train_fn: a hook called at the beginning of training in each epoch. It can be used to perform custom additional operations, with the signature ``f(num_epoch: int, step_idx: int) -> None``. 
@@ -87,6 +87,8 @@ def offpolicy_trainer( train_collector.reset_stat() test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy + test_episode(policy, test_collector, test_fn, 0, episode_per_test, + writer, env_step) for epoch in range(1, 1 + max_epoch): # train policy.train() @@ -96,10 +98,11 @@ def offpolicy_trainer( while t.n < t.total: if train_fn: train_fn(epoch, env_step) - result = train_collector.collect(n_step=collect_per_step) + result = train_collector.collect(n_step=step_per_collect) if len(result["rews"]) > 0 and reward_metric: result["rews"] = reward_metric(result["rews"]) env_step += int(result["n/st"]) + t.update(result["n/st"]) data = { "env_step": str(env_step), "rew": f"{result['rews'].mean():.2f}", @@ -126,8 +129,7 @@ def offpolicy_trainer( test_result["rews"].mean(), test_result["rews"].std()) else: policy.train() - for i in range(update_per_step * min( - result["n/st"] // collect_per_step, t.total - t.n)): + for i in range(round(update_per_step * result["n/st"])): gradient_step += 1 losses = policy.update(batch_size, train_collector.buffer) for k in losses.keys(): @@ -136,7 +138,6 @@ def offpolicy_trainer( if writer and gradient_step % log_interval == 0: writer.add_scalar( k, stat[k].get(), global_step=gradient_step) - t.update(1) t.set_postfix(**data) if t.n <= t.total: t.update() diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index b951f9a9e..31c63f605 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -17,7 +17,7 @@ def onpolicy_trainer( test_collector: Collector, max_epoch: int, step_per_epoch: int, - collect_per_step: int, + step_per_collect: int, repeat_per_collect: int, episode_per_test: int, batch_size: int, @@ -30,6 +30,7 @@ def onpolicy_trainer( log_interval: int = 1, verbose: bool = True, test_in_train: bool = True, + collect_method = "episode", ) -> Dict[str, Union[float, str]]: """A wrapper for on-policy trainer procedure. @@ -42,9 +43,12 @@ def onpolicy_trainer( :type test_collector: :class:`~tianshou.data.Collector` :param int max_epoch: the maximum number of epochs for training. The training process might be finished before reaching the ``max_epoch``. - :param int step_per_epoch: the number of policy network updates, so-called - gradient steps, per epoch. - :param int collect_per_step: the number of episodes the collector would + :param int step_per_epoch: the number of environment frames collected per epoch. + :param int step_per_collect: the number of episodes the collector would + collect before the network update in "episode" collect mode(defalut), + the number of frames the collector would collect in "step" collect + mode. + :param int step_per_collect: the number of episodes the collector would collect before the network update. In other words, collect some episodes and do one policy network update. :param int repeat_per_collect: the number of repeat time for policy @@ -77,6 +81,8 @@ def onpolicy_trainer( :param int log_interval: the log interval of the writer. :param bool verbose: whether to print the information. :param bool test_in_train: whether to test in the training phase. + :param string collect_method: specifies collect mode. Can be either "episode" + or "step". :return: See :func:`~tianshou.trainer.gather_info`. 
""" @@ -87,6 +93,8 @@ def onpolicy_trainer( train_collector.reset_stat() test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy + test_episode(policy, test_collector, test_fn, 0, episode_per_test, + writer, env_step) for epoch in range(1, 1 + max_epoch): # train policy.train() @@ -96,10 +104,11 @@ def onpolicy_trainer( while t.n < t.total: if train_fn: train_fn(epoch, env_step) - result = train_collector.collect(n_episode=collect_per_step) + result = train_collector.collect(**{"n_" + collect_method : step_per_collect}) if reward_metric: result["rews"] = reward_metric(result["rews"]) env_step += int(result["n/st"]) + t.update(result["n/st"]) data = { "env_step": str(env_step), "rew": f"{result['rews'].mean():.2f}", @@ -138,12 +147,11 @@ def onpolicy_trainer( if writer and gradient_step % log_interval == 0: writer.add_scalar( k, stat[k].get(), global_step=gradient_step) - t.update(step) t.set_postfix(**data) if t.n <= t.total: t.update() # test - result = test_episode(policy, test_collector, test_fn, epoch, + test_result = test_episode(policy, test_collector, test_fn, epoch, episode_per_test, writer, env_step) if best_epoch == -1 or best_reward < result["rews"].mean(): best_reward, best_reward_std = result["rews"].mean(), result["rews"].std() From 3f215f9df99ea1bfa48c15f34ed5c226588859e4 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Fri, 19 Feb 2021 11:21:22 +0800 Subject: [PATCH 02/25] pep8 fix --- examples/atari/runnable/pong_a2c.py | 1 - test/discrete/test_drqn.py | 2 +- tianshou/trainer/offpolicy.py | 4 ++-- tianshou/trainer/onpolicy.py | 11 ++++++----- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/atari/runnable/pong_a2c.py b/examples/atari/runnable/pong_a2c.py index 3a0e7e690..0b7c00118 100644 --- a/examples/atari/runnable/pong_a2c.py +++ b/examples/atari/runnable/pong_a2c.py @@ -4,7 +4,6 @@ import argparse import numpy as np from torch.utils.tensorboard import SummaryWriter -#TODO why bug from tianshou.policy import A2CPolicy from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 6f1191f88..5eddb0e8d 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -94,7 +94,7 @@ def test_fn(epoch, env_step): result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, update_per_step = args.update_per_step, + args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer) diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 470e91892..638b8b388 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -87,8 +87,8 @@ def offpolicy_trainer( train_collector.reset_stat() test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy - test_episode(policy, test_collector, test_fn, 0, episode_per_test, - writer, env_step) + test_episode(policy, test_collector, test_fn, 0, + episode_per_test, writer, env_step) for epoch in range(1, 1 + max_epoch): # train policy.train() diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 31c63f605..a5b0e1792 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -30,7 +30,7 @@ def onpolicy_trainer( log_interval: int = 1, verbose: bool = True, test_in_train: bool = 
True, - collect_method = "episode", + collect_method: str = "episode", ) -> Dict[str, Union[float, str]]: """A wrapper for on-policy trainer procedure. @@ -93,8 +93,8 @@ def onpolicy_trainer( train_collector.reset_stat() test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy - test_episode(policy, test_collector, test_fn, 0, episode_per_test, - writer, env_step) + test_episode(policy, test_collector, test_fn, 0, + episode_per_test, writer, env_step) for epoch in range(1, 1 + max_epoch): # train policy.train() @@ -104,7 +104,8 @@ def onpolicy_trainer( while t.n < t.total: if train_fn: train_fn(epoch, env_step) - result = train_collector.collect(**{"n_" + collect_method : step_per_collect}) + result = train_collector.collect( + **{"n_" + collect_method: step_per_collect}) if reward_metric: result["rews"] = reward_metric(result["rews"]) env_step += int(result["n/st"]) @@ -152,7 +153,7 @@ def onpolicy_trainer( t.update() # test test_result = test_episode(policy, test_collector, test_fn, epoch, - episode_per_test, writer, env_step) + episode_per_test, writer, env_step) if best_epoch == -1 or best_reward < result["rews"].mean(): best_reward, best_reward_std = result["rews"].mean(), result["rews"].std() best_epoch = epoch From 652a5dea89b1492ad8d04a2ed7011be6a51bd80a Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Fri, 19 Feb 2021 11:40:35 +0800 Subject: [PATCH 03/25] remove collect method --- examples/atari/runnable/pong_a2c.py | 6 +++--- examples/atari/runnable/pong_ppo.py | 6 +++--- test/continuous/test_ppo.py | 6 +++--- test/discrete/test_a2c_with_il.py | 6 +++--- test/discrete/test_pg.py | 6 +++--- test/discrete/test_ppo.py | 6 +++--- test/modelbase/test_psrl.py | 6 +++--- tianshou/trainer/onpolicy.py | 23 ++++++++++------------- 8 files changed, 31 insertions(+), 34 deletions(-) diff --git a/examples/atari/runnable/pong_a2c.py b/examples/atari/runnable/pong_a2c.py index 0b7c00118..5b760c14e 100644 --- a/examples/atari/runnable/pong_a2c.py +++ b/examples/atari/runnable/pong_a2c.py @@ -23,7 +23,7 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.9) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--episode-per-collect', type=int, default=10) parser.add_argument('--repeat-per-collect', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, @@ -90,8 +90,8 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) + args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer) if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! 
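To make the on-policy change in this patch concrete before the remaining example diffs: `onpolicy_trainer` drops the `collect_method` string and instead takes mutually exclusive `step_per_collect` / `episode_per_collect` keyword arguments, forwarding whichever one is set to `Collector.collect` as `n_step` / `n_episode`. The sketch below is a self-contained toy (the `fake_collect` function and the 25-frame episode length are invented for illustration), not the real trainer code:

```python
# Schematic only -- a toy stand-in for the dispatch onpolicy_trainer performs
# after this patch; fake_collect is NOT tianshou.data.Collector.collect.
from typing import Optional


def fake_collect(n_step: Optional[int] = None,
                 n_episode: Optional[int] = None) -> dict:
    # Pretend every episode lasts 25 frames and report the frame count under
    # the same key ("n/st") used by the collector results in the diffs above.
    assert (n_step is None) != (n_episode is None), \
        "only one of step_per_collect / episode_per_collect may be given"
    return {"n/st": n_step if n_step is not None else 25 * n_episode}


def one_collect(step_per_collect: Optional[int] = None,
                episode_per_collect: Optional[int] = None) -> dict:
    # Mirrors the updated trainer line:
    #   train_collector.collect(n_step=step_per_collect, n_episode=episode_per_collect)
    return fake_collect(n_step=step_per_collect, n_episode=episode_per_collect)


print(one_collect(episode_per_collect=16))  # episode-driven collect, as in the examples
print(one_collect(step_per_collect=2048))   # step-driven collect
```
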
diff --git a/examples/atari/runnable/pong_ppo.py b/examples/atari/runnable/pong_ppo.py index c1f293879..8a2a6845d 100644 --- a/examples/atari/runnable/pong_ppo.py +++ b/examples/atari/runnable/pong_ppo.py @@ -24,7 +24,7 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--epoch', type=int, default=100) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--episode-per-collect', type=int, default=10) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, @@ -95,8 +95,8 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) + args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer) if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index a90471f82..503b8eb73 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -24,7 +24,7 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--epoch', type=int, default=20) parser.add_argument('--step-per-epoch', type=int, default=24000) - parser.add_argument('--step-per-collect', type=int, default=16) + parser.add_argument('--episode-per-collect', type=int, default=16) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, @@ -121,8 +121,8 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, + args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 7652dbe35..96d727f15 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -24,7 +24,7 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.9) parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=10000) - parser.add_argument('--step-per-collect', type=int, default=8) + parser.add_argument('--episode-per-collect', type=int, default=8) parser.add_argument('--repeat-per-collect', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, @@ -96,8 +96,8 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, + args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + episode_per_collect=args.episode_per_collect, 
stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index ffbe7a9d4..b8cb6367e 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -22,7 +22,7 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.95) parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--step-per-collect', type=int, default=8) + parser.add_argument('--episode-per-collect', type=int, default=8) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, @@ -82,8 +82,8 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, + args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 0dce4416b..316e4c69c 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -23,7 +23,7 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=2000) - parser.add_argument('--step-per-collect', type=int, default=20) + parser.add_argument('--episode-per-collect', type=int, default=20) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, @@ -108,8 +108,8 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, + args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 9ea2c530f..01ea98a58 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -18,7 +18,7 @@ def get_args(): parser.add_argument('--buffer-size', type=int, default=50000) parser.add_argument('--epoch', type=int, default=5) parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--step-per-collect', type=int, default=1) + parser.add_argument('--episode-per-collect', type=int, default=1) parser.add_argument('--training-num', type=int, default=1) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') @@ -78,8 +78,8 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, 1, - args.test_num, 0, stop_fn=stop_fn, writer=writer, + args.step_per_epoch, 1, args.test_num, 0, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer, 
test_in_train=False) if __name__ == '__main__': diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index a5b0e1792..c92bce6b1 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -17,10 +17,11 @@ def onpolicy_trainer( test_collector: Collector, max_epoch: int, step_per_epoch: int, - step_per_collect: int, repeat_per_collect: int, episode_per_test: int, batch_size: int, + step_per_collect: Optional[int] = None, + episode_per_collect: Optional[int] = None, train_fn: Optional[Callable[[int, int], None]] = None, test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, @@ -30,7 +31,6 @@ def onpolicy_trainer( log_interval: int = 1, verbose: bool = True, test_in_train: bool = True, - collect_method: str = "episode", ) -> Dict[str, Union[float, str]]: """A wrapper for on-policy trainer procedure. @@ -44,13 +44,6 @@ def onpolicy_trainer( :param int max_epoch: the maximum number of epochs for training. The training process might be finished before reaching the ``max_epoch``. :param int step_per_epoch: the number of environment frames collected per epoch. - :param int step_per_collect: the number of episodes the collector would - collect before the network update in "episode" collect mode(defalut), - the number of frames the collector would collect in "step" collect - mode. - :param int step_per_collect: the number of episodes the collector would - collect before the network update. In other words, collect some - episodes and do one policy network update. :param int repeat_per_collect: the number of repeat time for policy learning, for example, set it to 2 means the policy needs to learn each given batch data twice. @@ -58,6 +51,12 @@ def onpolicy_trainer( :type episode_per_test: int or list of ints :param int batch_size: the batch size of sample data, which is going to feed in the policy network. + :param int step_per_collect: the number of episodes the collector would + collect before the network update. Only either one of step_per_collect + and episode_per_collect can be specified. + :param int episode_per_collect: the number of episodes the collector would + collect before the network update. Only either one of step_per_collect + and episode_per_collect can be specified. :param function train_fn: a hook called at the beginning of training in each epoch. It can be used to perform custom additional operations, with the signature ``f(num_epoch: int, step_idx: int) -> None``. @@ -81,8 +80,6 @@ def onpolicy_trainer( :param int log_interval: the log interval of the writer. :param bool verbose: whether to print the information. :param bool test_in_train: whether to test in the training phase. - :param string collect_method: specifies collect mode. Can be either "episode" - or "step". :return: See :func:`~tianshou.trainer.gather_info`. 
""" @@ -104,8 +101,8 @@ def onpolicy_trainer( while t.n < t.total: if train_fn: train_fn(epoch, env_step) - result = train_collector.collect( - **{"n_" + collect_method: step_per_collect}) + result = train_collector.collect(n_step=step_per_collect, + n_episode=episode_per_collect) if reward_metric: result["rews"] = reward_metric(result["rews"]) env_step += int(result["n/st"]) From cfdefe183e91a86dd3a8aab0901d8e2612987470 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Fri, 19 Feb 2021 11:52:05 +0800 Subject: [PATCH 04/25] small fix --- docs/tutorials/dqn.rst | 2 +- tianshou/trainer/offline.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index 31ee2ce46..57dc9c07d 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -136,7 +136,7 @@ Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.t The meaning of each parameter is as follows (full description can be found at :func:`~tianshou.trainer.offpolicy_trainer`): * ``max_epoch``: The maximum of epochs for training. The training process might be finished before reaching the ``max_epoch``; -* ``step_per_epoch``: The number of step for updating policy network in one epoch; +* ``step_per_epoch``: The number of environment frames collected per epoch; * ``step_per_collect``: The number of frames the collector would collect before the network update. For example, the code above means "collect 10 frames and do one policy network update"; * ``episode_per_test``: The number of episodes for one policy evaluation. * ``batch_size``: The batch size of sample data, which is going to feed in the policy network. diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 93597c876..e50f7f0cc 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -68,7 +68,8 @@ def offline_trainer( stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() test_collector.reset_stat() - + test_episode(policy, test_collector, test_fn, 0, + episode_per_test, writer, gradient_step, reward_metric) for epoch in range(1, 1 + max_epoch): policy.train() with tqdm.trange( From b9a7597785409211cc9434ba0f01baf0bcd51bf0 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Fri, 19 Feb 2021 12:47:15 +0800 Subject: [PATCH 05/25] test fix --- test/discrete/test_a2c_with_il.py | 1 + test/discrete/test_dqn.py | 7 ++++--- test/multiagent/tic_tac_toe.py | 7 ++++--- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 96d727f15..020ff4f60 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -25,6 +25,7 @@ def get_args(): parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=10000) parser.add_argument('--episode-per-collect', type=int, default=8) + parser.add_argument('--step-per-collect', type=int, default=8) parser.add_argument('--repeat-per-collect', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index f3923e6a9..3e3aee767 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -26,8 +26,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) - 
parser.add_argument('--step-per-epoch', type=int, default=1000) + parser.add_argument('--step-per-epoch', type=int, default=10000) parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) @@ -115,8 +116,8 @@ def test_fn(epoch, env_step): result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, + test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index 1b471f62e..0b9b66ddd 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -28,8 +28,9 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=500) + parser.add_argument('--step-per-epoch', type=int, default=5000) parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) @@ -164,8 +165,8 @@ def reward_metric(rews): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, reward_metric=reward_metric, - writer=writer, test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step, + writer=writer, test_in_train=False, reward_metric=reward_metric) return result, policy.policies[args.agent_id - 1] From 237f16aea0405a77bd3fe8a9e1181e3f634ad141 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Fri, 19 Feb 2021 12:53:11 +0800 Subject: [PATCH 06/25] fix a bug --- tianshou/trainer/offline.py | 13 +++++++------ tianshou/trainer/offpolicy.py | 11 ++++++----- tianshou/trainer/onpolicy.py | 9 +++++---- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index e50f7f0cc..9f0ee98b6 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -88,16 +88,17 @@ def offline_trainer( global_step=gradient_step) t.set_postfix(**data) # test - result = test_episode(policy, test_collector, test_fn, epoch, - episode_per_test, writer, gradient_step, reward_metric) - if best_epoch == -1 or best_reward < result["rews"].mean(): - best_reward, best_reward_std = result["rews"].mean(), result['rews'].std() + test_result = test_episode(policy, test_collector, test_fn, epoch, + episode_per_test, writer, gradient_step, reward_metric) + if best_epoch == -1 or best_reward < test_result["rews"].mean(): + best_reward = test_result["rews"].mean() + best_reward_std = test_result['rews'].std() best_epoch = epoch if save_fn: save_fn(policy) if verbose: - print(f"Epoch #{epoch}: test_reward: {result['rews'].mean():.6f} ± " - 
f"{result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " + print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± " + f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " f"{best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 638b8b388..b0fed7cae 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -142,16 +142,17 @@ def offpolicy_trainer( if t.n <= t.total: t.update() # test - result = test_episode(policy, test_collector, test_fn, epoch, + test_result = test_episode(policy, test_collector, test_fn, epoch, episode_per_test, writer, env_step, reward_metric) - if best_epoch == -1 or best_reward < result["rews"].mean(): - best_reward, best_reward_std = result["rews"].mean(), result["rews"].std() + if best_epoch == -1 or best_reward < test_result["rews"].mean(): + best_reward = test_result["rews"].mean() + best_reward_std = test_result['rews'].std() best_epoch = epoch if save_fn: save_fn(policy) if verbose: - print(f"Epoch #{epoch}: test_reward: {result['rews'].mean():.6f} ± " - f"{result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " + print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± " + f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " f"{best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index c92bce6b1..eff1a2250 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -151,14 +151,15 @@ def onpolicy_trainer( # test test_result = test_episode(policy, test_collector, test_fn, epoch, episode_per_test, writer, env_step) - if best_epoch == -1 or best_reward < result["rews"].mean(): - best_reward, best_reward_std = result["rews"].mean(), result["rews"].std() + if best_epoch == -1 or best_reward < test_result["rews"].mean(): + best_reward = test_result["rews"].mean() + best_reward_std = test_result['rews'].std() best_epoch = epoch if save_fn: save_fn(policy) if verbose: - print(f"Epoch #{epoch}: test_reward: {result['rews'].mean():.6f} ± " - f"{result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " + print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± " + f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " f"{best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break From 9ac0228fb45052b70a697c828801a0bdc9fa3684 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Fri, 19 Feb 2021 12:57:31 +0800 Subject: [PATCH 07/25] pep8 fix --- tianshou/trainer/offline.py | 3 ++- tianshou/trainer/offpolicy.py | 4 ++-- tianshou/trainer/onpolicy.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 9f0ee98b6..6960ebcf2 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -89,7 +89,8 @@ def offline_trainer( t.set_postfix(**data) # test test_result = test_episode(policy, test_collector, test_fn, epoch, - episode_per_test, writer, gradient_step, reward_metric) + episode_per_test, writer, gradient_step, + reward_metric) if best_epoch == -1 or best_reward < test_result["rews"].mean(): best_reward = test_result["rews"].mean() best_reward_std = test_result['rews'].std() diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index b0fed7cae..775030522 100644 --- 
a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -88,7 +88,7 @@ def offpolicy_trainer( test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy test_episode(policy, test_collector, test_fn, 0, - episode_per_test, writer, env_step) + episode_per_test, writer, env_step, reward_metric) for epoch in range(1, 1 + max_epoch): # train policy.train() @@ -143,7 +143,7 @@ def offpolicy_trainer( t.update() # test test_result = test_episode(policy, test_collector, test_fn, epoch, - episode_per_test, writer, env_step, reward_metric) + episode_per_test, writer, env_step, reward_metric) if best_epoch == -1 or best_reward < test_result["rews"].mean(): best_reward = test_result["rews"].mean() best_reward_std = test_result['rews'].std() diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index eff1a2250..84ae6e332 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -91,7 +91,7 @@ def onpolicy_trainer( test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy test_episode(policy, test_collector, test_fn, 0, - episode_per_test, writer, env_step) + episode_per_test, writer, env_step, reward_metric) for epoch in range(1, 1 + max_epoch): # train policy.train() From 27568a30643182c0135a5911b7a5e95a2ac5053c Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Fri, 19 Feb 2021 16:33:32 +0800 Subject: [PATCH 08/25] adjust update option to be consistent with history --- examples/atari/atari_c51.py | 6 ++++-- examples/atari/atari_dqn.py | 6 ++++-- examples/atari/atari_qrdqn.py | 6 ++++-- examples/box2d/mcc_sac.py | 8 +++++--- examples/mujoco/mujoco_sac.py | 8 ++++---- test/continuous/test_ddpg.py | 8 +++++--- test/continuous/test_sac_with_il.py | 8 +++++--- test/continuous/test_td3.py | 8 +++++--- test/discrete/test_a2c_with_il.py | 1 + test/discrete/test_c51.py | 3 ++- test/discrete/test_drqn.py | 2 +- test/discrete/test_qrdqn.py | 3 ++- test/discrete/test_sac.py | 3 ++- test/multiagent/tic_tac_toe.py | 2 +- 14 files changed, 45 insertions(+), 27 deletions(-) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 9c046c50e..303befa6b 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -30,8 +30,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--step-per-epoch', type=int, default=100000) parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=32) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=10) @@ -143,7 +144,8 @@ def watch(): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, writer=writer, + update_per_step=args.update_per_step,test_in_train=False) pprint.pprint(result) watch() diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 2b0153891..e3807c98f 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -27,8 +27,9 @@ def 
get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--step-per-epoch', type=int, default=100000) parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=32) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=10) @@ -153,7 +154,8 @@ def watch(): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, writer=writer, + update_per_step=args.update_per_step,test_in_train=False) pprint.pprint(result) watch() diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 8a6d97fa4..96626922e 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -28,8 +28,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--step-per-epoch', type=int, default=100000) parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=32) parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=10) @@ -141,7 +142,8 @@ def watch(): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, writer=writer, + update_per_step=args.update_per_step,test_in_train=False) pprint.pprint(result) watch() diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index cb31b7001..bdc98cfde 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -29,8 +29,9 @@ def get_args(): parser.add_argument('--auto_alpha', type=int, default=1) parser.add_argument('--alpha', type=float, default=0.2) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=2400) + parser.add_argument('--step-per-epoch', type=int, default=12000) parser.add_argument('--step-per-collect', type=int, default=5) + parser.add_argument('--update-per-step', type=float, default=0.2) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -112,8 +113,9 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, stop_fn=stop_fn, + save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': 
pprint.pprint(result) diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 558d24c05..22dda7d9b 100644 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -28,8 +28,9 @@ def get_args(): parser.add_argument('--alpha-lr', type=float, default=3e-4) parser.add_argument('--n-step', type=int, default=2) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--step-per-epoch', type=int, default=40000) parser.add_argument('--step-per-collect', type=int, default=4) + parser.add_argument('--update-per-step', type=float, default=0.25) parser.add_argument('--update-per-step', type=int, default=1) parser.add_argument('--pre-collect-step', type=int, default=10000) parser.add_argument('--batch-size', type=int, default=256) @@ -140,9 +141,8 @@ def stop_fn(mean_rewards): result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, args.update_per_step, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, - log_interval=args.log_interval) + args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer, + update_per_step=args.update_per_step, log_interval=args.log_interval) pprint.pprint(result) watch() diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 2ad173beb..afe88e16b 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -26,8 +26,9 @@ def get_args(): parser.add_argument('--tau', type=float, default=0.005) parser.add_argument('--exploration-noise', type=float, default=0.1) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=2400) + parser.add_argument('--step-per-epoch', type=int, default=9600) parser.add_argument('--step-per-collect', type=int, default=4) + parser.add_argument('--update-per-step', type=float, default=0.25) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -102,8 +103,9 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, stop_fn=stop_fn, + save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 71086894e..1fde511fc 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -26,8 +26,9 @@ def get_args(): parser.add_argument('--tau', type=float, default=0.005) parser.add_argument('--alpha', type=float, default=0.2) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=2400) + parser.add_argument('--step-per-epoch', type=int, default=24000) parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -110,8 +111,9 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, 
train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, stop_fn=stop_fn, + save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index fc5bf037e..8a46e9005 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -29,8 +29,9 @@ def get_args(): parser.add_argument('--noise-clip', type=float, default=0.5) parser.add_argument('--update-actor-freq', type=int, default=2) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=2400) + parser.add_argument('--step-per-epoch', type=int, default=24000) parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -115,8 +116,9 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, stop_fn=stop_fn, + save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 020ff4f60..a49a77448 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -26,6 +26,7 @@ def get_args(): parser.add_argument('--step-per-epoch', type=int, default=10000) parser.add_argument('--episode-per-collect', type=int, default=8) parser.add_argument('--step-per-collect', type=int, default=8) + parser.add_argument('--update-per-step', type=float, default=0.125) parser.add_argument('--repeat-per-collect', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 3f801be90..fe0573e1d 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -28,8 +28,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) + parser.add_argument('--step-per-epoch', type=int, default=8000) parser.add_argument('--step-per-collect', type=int, default=8) + parser.add_argument('--update-per-step', type=float, default=0.125) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 5eddb0e8d..7dd73df44 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -26,7 +26,7 @@ def get_args(): parser.add_argument('--n-step', type=int, default=4) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, 
default=10) - parser.add_argument('--step-per-epoch', type=int, default=5000) + parser.add_argument('--step-per-epoch', type=int, default=50000) parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=64) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 942d4ee07..3a5324a67 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -26,8 +26,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) + parser.add_argument('--step-per-epoch', type=int, default=10000) parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 3f7f2a39a..f9ecdaf94 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -27,8 +27,9 @@ def get_args(): parser.add_argument('--alpha', type=float, default=0.05) parser.add_argument('--auto_alpha', type=int, default=0) parser.add_argument('--epoch', type=int, default=5) - parser.add_argument('--step-per-epoch', type=int, default=1000) + parser.add_argument('--step-per-epoch', type=int, default=5000) parser.add_argument('--step-per-collect', type=int, default=5) + parser.add_argument('--update-per-step', type=float, default=0.2) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index 0b9b66ddd..8a7eadf94 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -28,7 +28,7 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=5000) + parser.add_argument('--step-per-epoch', type=int, default=50000) parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) From 044f909934edf341f8ed768fc7472b7e574a8701 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Fri, 19 Feb 2021 16:42:33 +0800 Subject: [PATCH 09/25] pep8 fix --- examples/atari/atari_c51.py | 2 +- examples/atari/atari_dqn.py | 2 +- examples/atari/atari_qrdqn.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 303befa6b..d0a7ab81d 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -145,7 +145,7 @@ def watch(): args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer, - update_per_step=args.update_per_step,test_in_train=False) + update_per_step=args.update_per_step, test_in_train=False) pprint.pprint(result) watch() diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py 
index e3807c98f..077d24891 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -155,7 +155,7 @@ def watch(): args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer, - update_per_step=args.update_per_step,test_in_train=False) + update_per_step=args.update_per_step, test_in_train=False) pprint.pprint(result) watch() diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 96626922e..e2eed3cfd 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -143,7 +143,7 @@ def watch(): args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer, - update_per_step=args.update_per_step,test_in_train=False) + update_per_step=args.update_per_step, test_in_train=False) pprint.pprint(result) watch() From fe412acddc48b4cdf51ccf6a59fe0d3c3ea00e7e Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Fri, 19 Feb 2021 17:07:32 +0800 Subject: [PATCH 10/25] some other change --- test/continuous/test_td3.py | 2 +- test/discrete/test_drqn.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 8a46e9005..bbc32d912 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -29,7 +29,7 @@ def get_args(): parser.add_argument('--noise-clip', type=float, default=0.5) parser.add_argument('--update-actor-freq', type=int, default=2) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=24000) + parser.add_argument('--step-per-epoch', type=int, default=20000) parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=128) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 7dd73df44..dc2e06c00 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -26,7 +26,7 @@ def get_args(): parser.add_argument('--n-step', type=int, default=4) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=50000) + parser.add_argument('--step-per-epoch', type=int, default=10000) parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=64) From 62622399cabe6b9b1ef0dadff02f4f8ca994e1b4 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Fri, 19 Feb 2021 18:18:42 +0800 Subject: [PATCH 11/25] fix test --- test/discrete/test_sac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index f9ecdaf94..ebcb75157 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -111,7 +111,7 @@ def stop_fn(mean_rewards): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer, - test_in_train=False) + update_per_step=args.update_per_step, test_in_train=False) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) From 75750e800fcaf9aace76f2fe9fec9f39a6691e74 Mon Sep 17 
00:00:00 2001 From: chy <308604256@qq.com> Date: Fri, 19 Feb 2021 18:22:27 +0800 Subject: [PATCH 12/25] update change --- test/discrete/test_qrdqn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 3a5324a67..e5ce61b98 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -113,7 +113,8 @@ def test_fn(epoch, env_step): policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer) + stop_fn=stop_fn, save_fn=save_fn, writer=writer, + update_per_step=args.update_per_step) assert stop_fn(result['best_reward']) if __name__ == '__main__': From 602fa2de3d32ff3453e48678a1bb5156e751b38f Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Fri, 19 Feb 2021 19:01:44 +0800 Subject: [PATCH 13/25] restart test --- test/discrete/test_dqn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 3e3aee767..e9104c152 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -25,7 +25,7 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.9) parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) - parser.add_argument('--epoch', type=int, default=10) + parser.add_argument('--epoch', type=int, default=20) parser.add_argument('--step-per-epoch', type=int, default=10000) parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--update-per-step', type=float, default=0.1) From 7642f44f269b37f43ad55e55320192c77169c427 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Fri, 19 Feb 2021 20:13:25 +0800 Subject: [PATCH 14/25] fix --- docs/tutorials/tictactoe.rst | 11 ++++++----- test/multiagent/tic_tac_toe.py | 4 ++-- tianshou/trainer/onpolicy.py | 8 ++++---- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst index c0f683623..3d7f28106 100644 --- a/docs/tutorials/tictactoe.rst +++ b/docs/tutorials/tictactoe.rst @@ -200,8 +200,9 @@ The explanation of each Tianshou class/function will be deferred to their first parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=500) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=5000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) @@ -293,7 +294,7 @@ With the above preparation, we are close to the first learned agent. 
The followi policy.policies[args.agent_id - 1].set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}') + print(f'Final reward: {result["rews"][:, args.agent_id - 1].mean()}, length: {result["lens"].mean()}') if args.watch: watch(args) exit(0) @@ -357,8 +358,8 @@ With the above preparation, we are close to the first learned agent. The followi policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, reward_metric=reward_metric, - writer=writer, test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step, + writer=writer, test_in_train=False, reward_metric=reward_metric) agent = policy.policies[args.agent_id - 1] # let's watch the match! diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index 8a7eadf94..edf066e09 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -28,7 +28,7 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=50000) + parser.add_argument('--step-per-epoch', type=int, default=5000) parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) @@ -184,4 +184,4 @@ def watch( collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}") diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 84ae6e332..185560c76 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -52,11 +52,11 @@ def onpolicy_trainer( :param int batch_size: the batch size of sample data, which is going to feed in the policy network. :param int step_per_collect: the number of episodes the collector would - collect before the network update. Only either one of step_per_collect - and episode_per_collect can be specified. + collect before the network update. Only either one of step_per_collect + and episode_per_collect can be specified. :param int episode_per_collect: the number of episodes the collector would - collect before the network update. Only either one of step_per_collect - and episode_per_collect can be specified. + collect before the network update. Only either one of step_per_collect + and episode_per_collect can be specified. :param function train_fn: a hook called at the beginning of training in each epoch. It can be used to perform custom additional operations, with the signature ``f(num_epoch: int, step_idx: int) -> None``. 
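For reference, a minimal sketch of the two mutually exclusive ways the reworked `onpolicy_trainer` signature can be driven after this series. It is not part of the diff: `policy`, the collectors, `stop_fn`/`save_fn`, `writer` and the scalar hyper-parameters are assumed to be set up as in the test scripts touched above, and the concrete numbers are illustrative only.

```python
from tianshou.trainer import onpolicy_trainer

# Style 1: collect a fixed number of episodes before every update round,
# as the PG/PPO/A2C tests in this series now do.
result = onpolicy_trainer(
    policy, train_collector, test_collector, max_epoch, step_per_epoch,
    repeat_per_collect, episode_per_test, batch_size,
    episode_per_collect=16,      # e.g. 16 episodes per collect
    stop_fn=stop_fn, save_fn=save_fn, writer=writer)

# Style 2: collect a fixed number of environment frames instead.
# Only one of step_per_collect / episode_per_collect may be specified.
result = onpolicy_trainer(
    policy, train_collector, test_collector, max_epoch, step_per_epoch,
    repeat_per_collect, episode_per_test, batch_size,
    step_per_collect=2048,       # e.g. 2048 frames per collect
    stop_fn=stop_fn, save_fn=save_fn, writer=writer)
```

Internally both options are forwarded unchanged to `train_collector.collect(n_step=step_per_collect, n_episode=episode_per_collect)`, which is why only one of the two keyword arguments should be given and the other left as `None`.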
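Similarly, an illustrative sketch of how the new off-policy `update_per_step` option is sized throughout this series: every updated script pairs it with `step_per_collect` so that `update_per_step = 1 / step_per_collect`, i.e. roughly one gradient step per collect round. Again this is a usage sketch rather than part of the diff; `policy`, the collectors, the `*_fn` hooks, `writer` and the remaining hyper-parameters are assumed to be constructed as in those scripts.

```python
from tianshou.trainer import offpolicy_trainer

step_per_collect = 10                     # frames gathered before each update round
update_per_step = 1 / step_per_collect    # ~1 gradient step per 10 collected frames

result = offpolicy_trainer(
    policy, train_collector, test_collector, max_epoch, step_per_epoch,
    step_per_collect, episode_per_test, batch_size,
    update_per_step=update_per_step,
    train_fn=train_fn, test_fn=test_fn,
    stop_fn=stop_fn, save_fn=save_fn, writer=writer)
```

Expressing the update budget as a fraction of collected frames keeps the number of gradient steps proportional to however many frames a collect round actually returns, which is what lets these examples keep their historical update frequency after the switch from `collect_per_step` to `step_per_collect`.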
From cccea728f08f7b1d3609aaffe30f01e6aa966798 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Fri, 19 Feb 2021 22:38:34 +0800 Subject: [PATCH 15/25] solve review --- tianshou/trainer/offline.py | 8 +++++--- tianshou/trainer/offpolicy.py | 8 +++++--- tianshou/trainer/onpolicy.py | 10 ++++++---- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 6960ebcf2..760d14bcc 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -64,12 +64,14 @@ def offline_trainer( :return: See :func:`~tianshou.trainer.gather_info`. """ gradient_step = 0 - best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() test_collector.reset_stat() - test_episode(policy, test_collector, test_fn, 0, - episode_per_test, writer, gradient_step, reward_metric) + test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test, + writer, gradient_step, reward_metric) + best_epoch = 0 + best_reward = test_result["rews"].mean() + best_reward_std = test_result["rews"].std() for epoch in range(1, 1 + max_epoch): policy.train() with tqdm.trange( diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 775030522..d7cf10485 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -81,14 +81,16 @@ def offpolicy_trainer( :return: See :func:`~tianshou.trainer.gather_info`. """ env_step, gradient_step = 0, 0 - best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy - test_episode(policy, test_collector, test_fn, 0, - episode_per_test, writer, env_step, reward_metric) + test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test, + writer, env_step, reward_metric) + best_epoch = 0 + best_reward = test_result["rews"].mean() + best_reward_std = test_result["rews"].std() for epoch in range(1, 1 + max_epoch): # train policy.train() diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 185560c76..3cc26539e 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -51,7 +51,7 @@ def onpolicy_trainer( :type episode_per_test: int or list of ints :param int batch_size: the batch size of sample data, which is going to feed in the policy network. - :param int step_per_collect: the number of episodes the collector would + :param int step_per_collect: the number of frames the collector would collect before the network update. Only either one of step_per_collect and episode_per_collect can be specified. :param int episode_per_collect: the number of episodes the collector would @@ -84,14 +84,16 @@ def onpolicy_trainer( :return: See :func:`~tianshou.trainer.gather_info`. 
""" env_step, gradient_step = 0, 0 - best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy - test_episode(policy, test_collector, test_fn, 0, - episode_per_test, writer, env_step, reward_metric) + test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test, + writer, env_step, reward_metric) + best_epoch = 0 + best_reward = test_result["rews"].mean() + best_reward_std = test_result["rews"].std() for epoch in range(1, 1 + max_epoch): # train policy.train() From 670d3df5aea46689cca6b90e7d1bf436a7f74811 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Fri, 19 Feb 2021 22:40:59 +0800 Subject: [PATCH 16/25] fix review --- tianshou/trainer/offline.py | 2 +- tianshou/trainer/offpolicy.py | 2 +- tianshou/trainer/onpolicy.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 760d14bcc..4e14afe24 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -29,7 +29,7 @@ def offline_trainer( ) -> Dict[str, Union[float, str]]: """A wrapper for offline trainer procedure. - The "step" in trainer means a policy network update. + The "step" in trainer means an environment frame. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param test_collector: the collector used for testing. diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index d7cf10485..6df269970 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -33,7 +33,7 @@ def offpolicy_trainer( ) -> Dict[str, Union[float, str]]: """A wrapper for off-policy trainer procedure. - The "step" in trainer means a policy network update. + The "step" in trainer means an environment frame. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param train_collector: the collector used for training. diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 3cc26539e..8911619c4 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -34,7 +34,7 @@ def onpolicy_trainer( ) -> Dict[str, Union[float, str]]: """A wrapper for on-policy trainer procedure. - The "step" in trainer means a policy network update. + The "step" in trainer means an environment frame. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param train_collector: the collector used for training. From 3abdaff77d71f3aaf4eeab021b0781ecd6432b6f Mon Sep 17 00:00:00 2001 From: ChenDRAG <40993476+ChenDRAG@users.noreply.github.com> Date: Sat, 20 Feb 2021 08:59:48 +0800 Subject: [PATCH 17/25] Update tianshou/trainer/offline.py Co-authored-by: n+e --- tianshou/trainer/offline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 4e14afe24..8a3c1a931 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -29,7 +29,7 @@ def offline_trainer( ) -> Dict[str, Union[float, str]]: """A wrapper for offline trainer procedure. - The "step" in trainer means an environment frame. + The "step" in offline trainer means a gradient step. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param test_collector: the collector used for testing. 
From d383229e151897363faa97863073a187dfc37613 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Sat, 20 Feb 2021 09:06:34 +0800 Subject: [PATCH 18/25] update option in box2d --- examples/box2d/acrobot_dualdqn.py | 7 ++++--- examples/box2d/bipedal_hardcore_sac.py | 9 +++++---- examples/box2d/lunarlander_dqn.py | 10 +++++----- examples/box2d/mcc_sac.py | 1 + 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 68c2eac88..527ae752f 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -25,8 +25,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) + parser.add_argument('--step-per-epoch', type=int, default=100000) parser.add_argument('--step-per-collect', type=int, default=100) + parser.add_argument('--update-per-step', type=float, default=0.01) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128]) parser.add_argument('--dueling-q-hidden-sizes', type=int, @@ -103,8 +104,8 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index a0baf575f..a903008ab 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -27,8 +27,9 @@ def get_args(): parser.add_argument('--auto-alpha', type=int, default=1) parser.add_argument('--alpha-lr', type=float, default=3e-4) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--step-per-epoch', type=int, default=100000) parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -143,9 +144,9 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer, - test_in_train=False) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, test_in_train=False, + stop_fn=stop_fn, save_fn=save_fn, writer=writer) if __name__ == '__main__': pprint.pprint(result) diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 5a4c61b42..3d5d033f2 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -26,8 +26,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=4) parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=10) - 
parser.add_argument('--step-per-epoch', type=int, default=5000) + parser.add_argument('--step-per-epoch', type=int, default=80000) parser.add_argument('--step-per-collect', type=int, default=16) + parser.add_argument('--update-per-step', type=float, default=0.0625) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -99,10 +100,9 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, - test_in_train=False) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, stop_fn=stop_fn, train_fn=train_fn, + test_fn=test_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index bdc98cfde..333dab49c 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -116,6 +116,7 @@ def stop_fn(mean_rewards): args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, update_per_step=args.update_per_step, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) From f3727a0b2cd5dcd963bc81f334aca2326f3e42a1 Mon Sep 17 00:00:00 2001 From: chy <308604256@qq.com> Date: Sat, 20 Feb 2021 09:59:03 +0800 Subject: [PATCH 19/25] pep8fix --- examples/box2d/acrobot_dualdqn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 527ae752f..69d0bfbee 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -105,7 +105,7 @@ def test_fn(epoch, env_step): result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, - update_per_step=args.update_per_step, train_fn=train_fn, test_fn=test_fn, + update_per_step=args.update_per_step, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) From 9f7d410a145cd3f2a8ea663c173f5da7216c5733 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Sat, 20 Feb 2021 16:06:06 +0800 Subject: [PATCH 20/25] fix test --- test/continuous/test_ppo.py | 2 +- test/continuous/test_sac_with_il.py | 3 ++- test/discrete/test_a2c_with_il.py | 8 ++++---- test/discrete/test_pg.py | 2 +- test/discrete/test_ppo.py | 2 +- 5 files changed, 9 insertions(+), 8 deletions(-) diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 503b8eb73..4f8ede1a0 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -23,7 +23,7 @@ def get_args(): parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=24000) + parser.add_argument('--step-per-epoch', type=int, default=150000) parser.add_argument('--episode-per-collect', type=int, default=16) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=128) diff --git a/test/continuous/test_sac_with_il.py 
b/test/continuous/test_sac_with_il.py index 1fde511fc..0a96dbfa9 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -27,6 +27,7 @@ def get_args(): parser.add_argument('--alpha', type=float, default=0.2) parser.add_argument('--epoch', type=int, default=20) parser.add_argument('--step-per-epoch', type=int, default=24000) + parser.add_argument('--il-step-per-epoch', type=int, default=500) parser.add_argument('--step-per-collect', type=int, default=10) parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=128) @@ -144,7 +145,7 @@ def stop_fn(mean_rewards): train_collector.reset() result = offpolicy_trainer( il_policy, train_collector, il_test_collector, args.epoch, - args.step_per_epoch // 5, args.step_per_collect, args.test_num, + args.il_step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index a49a77448..1032b3176 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -23,7 +23,8 @@ def get_args(): parser.add_argument('--il-lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.9) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--step-per-epoch', type=int, default=50000) + parser.add_argument('--il-step-per-epoch', type=int, default=1000) parser.add_argument('--episode-per-collect', type=int, default=8) parser.add_argument('--step-per-collect', type=int, default=8) parser.add_argument('--update-per-step', type=float, default=0.125) @@ -123,13 +124,12 @@ def stop_fn(mean_rewards): il_policy = ImitationPolicy(net, optim, mode='discrete') il_test_collector = Collector( il_policy, - DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) ) train_collector.reset() result = offpolicy_trainer( il_policy, train_collector, il_test_collector, args.epoch, - args.step_per_epoch, args.step_per_collect, args.test_num, + args.il_step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) if __name__ == '__main__': diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index b8cb6367e..784ae70db 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -21,7 +21,7 @@ def get_args(): parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.95) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) + parser.add_argument('--step-per-epoch', type=int, default=40000) parser.add_argument('--episode-per-collect', type=int, default=8) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=64) diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 316e4c69c..35634c675 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -22,7 +22,7 @@ def get_args(): parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--epoch', type=int, 
default=10) - parser.add_argument('--step-per-epoch', type=int, default=2000) + parser.add_argument('--step-per-epoch', type=int, default=50000) parser.add_argument('--episode-per-collect', type=int, default=20) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=64) From 2d9fc0256ece08d8379f7ec04f5934db52f77107 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Sat, 20 Feb 2021 16:24:26 +0800 Subject: [PATCH 21/25] fix doc --- README.md | 6 +++--- docs/tutorials/dqn.rst | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index d543515db..80ee2ff12 100644 --- a/README.md +++ b/README.md @@ -191,11 +191,11 @@ Define some hyper-parameters: ```python task = 'CartPole-v0' lr, epoch, batch_size = 1e-3, 10, 64 -train_num, test_num = 8, 100 +train_num, test_num = 10, 100 gamma, n_step, target_freq = 0.9, 3, 320 buffer_size = 20000 eps_train, eps_test = 0.1, 0.05 -step_per_epoch, step_per_collect = 1000, 8 +step_per_epoch, step_per_collect = 10000, 10 writer = SummaryWriter('log/dqn') # tensorboard is also supported! ``` @@ -233,7 +233,7 @@ Let's train it: ```python result = ts.trainer.offpolicy_trainer( policy, train_collector, test_collector, epoch, step_per_epoch, step_per_collect, - test_num, batch_size, + test_num, batch_size, update_per_step=1 / step_per_collect, train_fn=lambda epoch, env_step: policy.set_eps(eps_train), test_fn=lambda epoch, env_step: policy.set_eps(eps_test), stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index 57dc9c07d..bb99f3032 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -35,10 +35,10 @@ If you want to use the original ``gym.Env``: Tianshou supports parallel sampling for all algorithms. It provides four types of vectorized environment wrapper: :class:`~tianshou.env.DummyVectorEnv`, :class:`~tianshou.env.SubprocVectorEnv`, :class:`~tianshou.env.ShmemVectorEnv`, and :class:`~tianshou.env.RayVectorEnv`. It can be used as follows: (more explanation can be found at :ref:`parallel_sampling`) :: - train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(8)]) + train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)]) test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(100)]) -Here, we set up 8 environments in ``train_envs`` and 100 environments in ``test_envs``. +Here, we set up 10 environments in ``train_envs`` and 100 environments in ``test_envs``. For the demonstration, here we use the second code-block. @@ -87,7 +87,7 @@ Tianshou supports any user-defined PyTorch networks and optimizers. Yet, of cour net = Net(state_shape, action_shape) optim = torch.optim.Adam(net.parameters(), lr=1e-3) -It is also possible to use pre-defined MLP networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are: +You can also use pre-defined MLP networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are: 1. Input: observation ``obs`` (may be a ``numpy.ndarray``, ``torch.Tensor``, dict, or self-defined class), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment. 2. Output: some ``logits``, the next hidden state ``state``. 
The logits could be a tuple instead of a ``torch.Tensor``, or some other useful variables or results during the policy forwarding procedure. It depends on how the policy class process the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for Gaussian policy. @@ -113,7 +113,7 @@ The collector is a key concept in Tianshou. It allows the policy to interact wit In each step, the collector will let the policy perform (at least) a specified number of steps or episodes and store the data in a replay buffer. :: - train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(20000, 8), exploration_noise=True) + train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(20000, 10), exploration_noise=True) test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True) @@ -125,8 +125,8 @@ Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.t result = ts.trainer.offpolicy_trainer( policy, train_collector, test_collector, - max_epoch=10, step_per_epoch=1000, step_per_collect=10, - episode_per_test=100, batch_size=64, + max_epoch=10, step_per_epoch=10000, step_per_collect=10, + update_per_step=0.1, episode_per_test=100, batch_size=64, train_fn=lambda epoch, env_step: policy.set_eps(0.1), test_fn=lambda epoch, env_step: policy.set_eps(0.05), stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, From 28f5f8fda11b3adc252280cc9e61f87f3cf6dc29 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Sat, 20 Feb 2021 16:59:22 +0800 Subject: [PATCH 22/25] fix doc --- tianshou/policy/base.py | 34 ++++++---------- tianshou/policy/random.py | 2 +- tianshou/trainer/offline.py | 37 +++++++++-------- tianshou/trainer/offpolicy.py | 66 +++++++++++++++--------------- tianshou/trainer/onpolicy.py | 75 ++++++++++++++++++----------------- tianshou/trainer/utils.py | 4 +- 6 files changed, 104 insertions(+), 114 deletions(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index be6d8216b..451bb3f84 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -220,21 +220,17 @@ def compute_episodic_return( Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438) to calculate q function/reward to go of given batch. - :param batch: a data batch which contains several episodes of data + :param Batch atch: a data batch which contains several episodes of data in sequential order. Mind that the end of each finished episode of batch should be marked by done flag, unfinished (or collecting) episodes will be recongized by buffer.unfinished_index(). - :type batch: :class:`~tianshou.data.Batch` - :param numpy.ndarray indice: tell batch's location in buffer, batch is + :param np.ndarray indice: tell batch's location in buffer, batch is equal to buffer[indice]. - :param v_s_: the value function of all next states :math:`V(s')`. - :type v_s_: numpy.ndarray - :param float gamma: the discount factor, should be in [0, 1], defaults - to 0.99. - :param float gae_lambda: the parameter for Generalized Advantage - Estimation, should be in [0, 1], defaults to 0.95. - :param bool rew_norm: normalize the reward to Normal(0, 1), defaults - to False. + :param np.ndarray v_s_: the value function of all next states :math:`V(s')`. + :param float gamma: the discount factor, should be in [0, 1]. Default to 0.99. + :param float gae_lambda: the parameter for Generalized Advantage Estimation, + should be in [0, 1]. Default to 0.95. 
+ :param bool rew_norm: normalize the reward to Normal(0, 1). Default to False. :return: a Batch. The result will be stored in batch.returns as a numpy array with shape (bsz, ). @@ -273,18 +269,14 @@ def compute_nstep_return( where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`, :math:`d_t` is the done flag of step :math:`t`. - :param batch: a data batch, which is equal to buffer[indice]. - :type batch: :class:`~tianshou.data.Batch` - :param buffer: the data buffer. - :type buffer: :class:`~tianshou.data.ReplayBuffer` + :param Batch batch: a data batch, which is equal to buffer[indice]. + :param ReplayBuffer buffer: the data buffer. :param function target_q_fn: a function which compute target Q value of "obs_next" given data buffer and wanted indices. - :param float gamma: the discount factor, should be in [0, 1], defaults - to 0.99. - :param int n_step: the number of estimation step, should be an int - greater than 0, defaults to 1. - :param bool rew_norm: normalize the reward to Normal(0, 1), defaults - to False. + :param float gamma: the discount factor, should be in [0, 1]. Default to 0.99. + :param int n_step: the number of estimation step, should be an int greater + than 0. Default to 1. + :param bool rew_norm: normalize the reward to Normal(0, 1), Default to False. :return: a Batch. The result will be stored in batch.returns as a torch.Tensor with the same shape as target_q_fn's return tensor. diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py index 13f9159f8..9c7f132af 100644 --- a/tianshou/policy/random.py +++ b/tianshou/policy/random.py @@ -38,5 +38,5 @@ def forward( return Batch(act=logits.argmax(axis=-1)) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: - """Since a random agent learn nothing, it returns an empty dict.""" + """Since a random agent learns nothing, it returns an empty dict.""" return {} diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 8a3c1a931..9e68295b8 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -32,34 +32,33 @@ def offline_trainer( The "step" in offline trainer means a gradient step. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param test_collector: the collector used for testing. - :type test_collector: :class:`~tianshou.data.Collector` - :param int max_epoch: the maximum number of epochs for training. The - training process might be finished before reaching the ``max_epoch``. + :param Collector test_collector: the collector used for testing. + :param int max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` sets. :param int update_per_epoch: the number of policy network updates, so-called gradient steps, per epoch. :param episode_per_test: the number of episodes for one policy evaluation. - :param int batch_size: the batch size of sample data, which is going to - feed in the policy network. - :param function test_fn: a hook called at the beginning of testing in each - epoch. It can be used to perform custom additional operations, with the - signature ``f(num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean - reward in evaluation phase gets better, with the signature ``f(policy: - BasePolicy) -> None``. 
- :param function stop_fn: a function with signature ``f(mean_rewards: float) - -> bool``, receives the average undiscounted returns of the testing - result, returns a boolean which indicates whether reaching the goal. + :param int batch_size: the batch size of sample data, which is going to feed in + the policy network. + :param function test_fn: a hook called at the beginning of testing in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean reward in + evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> + None``. + :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether reaching the goal. :param function reward_metric: a function with signature ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, used in multi-agent RL. We need to return a single scalar for each episode's result to monitor training in the multi-agent RL setting. This function specifies what is the desired metric, e.g., the reward of agent 1 or the average reward over all agents. - :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard - SummaryWriter; if None is given, it will not write logs to TensorBoard. - :param int log_interval: the log interval of the writer. - :param bool verbose: whether to print the information. + :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter; + if None is given, it will not write logs to TensorBoard. Default to None. + :param int log_interval: the log interval of the writer. Default to 1. + :param bool verbose: whether to print the information. Default to True. :return: See :func:`~tianshou.trainer.gather_info`. """ diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 6df269970..6df9df549 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -33,50 +33,48 @@ def offpolicy_trainer( ) -> Dict[str, Union[float, str]]: """A wrapper for off-policy trainer procedure. - The "step" in trainer means an environment frame. + The "step" in trainer means an environment step. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param train_collector: the collector used for training. - :type train_collector: :class:`~tianshou.data.Collector` - :param test_collector: the collector used for testing. - :type test_collector: :class:`~tianshou.data.Collector` - :param int max_epoch: the maximum number of epochs for training. The - training process might be finished before reaching the ``max_epoch``. + :param Collector train_collector: the collector used for training. + :param Collector test_collector: the collector used for testing. + :param int max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` sets. :param int step_per_epoch: the number of environment frames collected per epoch. - :param int step_per_collect: the number of frames the collector would - collect before the network update. In other words, collect some frames - and do some policy network update. 
+ :param int step_per_collect: the number of frames the collector would collect + before the network update, i.e., trainer will collect "step_per_collect" frames + and do some policy network update repeatly in each epoch. :param episode_per_test: the number of episodes for one policy evaluation. - :param int batch_size: the batch size of sample data, which is going to - feed in the policy network. - :param int/float update_per_step: the number of times the policy network would - be updated per environment frame after (step_per_collect) frames are collected, - for example, if update_per_step set to 0.3, and step_per_collect is 256, - policy will be updated round(256 * 0.3 = 76.8) = 77 times after 256 frames are - collected by the collector. Default to 1. - :param function train_fn: a hook called at the beginning of training in - each epoch. It can be used to perform custom additional operations, - with the signature ``f(num_epoch: int, step_idx: int) -> None``. - :param function test_fn: a hook called at the beginning of testing in each - epoch. It can be used to perform custom additional operations, with the - signature ``f(num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean - reward in evaluation phase gets better, with the signature ``f(policy: - BasePolicy) -> None``. - :param function stop_fn: a function with signature ``f(mean_rewards: float) - -> bool``, receives the average undiscounted returns of the testing - result, returns a boolean which indicates whether reaching the goal. + :param int batch_size: the batch size of sample data, which is going to feed in the + policy network. + :param int/float update_per_step: the number of times the policy network would be + updated per environment frame after (step_per_collect) frames are collected, + e.g., if update_per_step set to 0.3, and step_per_collect is 256, policy will + be updated round(256 * 0.3 = 76.8) = 77 times after 256 frames are collected by + the collector. Default to 1. + :param function train_fn: a hook called at the beginning of training in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean reward in + evaluation phase gets better, with the signature ``f(policy:BasePolicy) -> + None``. + :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether reaching the goal. :param function reward_metric: a function with signature ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, used in multi-agent RL. We need to return a single scalar for each episode's result to monitor training in the multi-agent RL setting. This function specifies what is the desired metric, e.g., the reward of agent 1 or the average reward over all agents. - :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard - SummaryWriter; if None is given, it will not write logs to TensorBoard. - :param int log_interval: the log interval of the writer. - :param bool verbose: whether to print the information. 
- :param bool test_in_train: whether to test in the training phase. + :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter; + if None is given, it will not write logs to TensorBoard. Default to None. + :param int log_interval: the log interval of the writer. Default to 1. + :param bool verbose: whether to print the information. Default to True. + :param bool test_in_train: whether to test in the training phase. Default to True. :return: See :func:`~tianshou.trainer.gather_info`. """ diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 8911619c4..f6319cc1b 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -34,54 +34,55 @@ def onpolicy_trainer( ) -> Dict[str, Union[float, str]]: """A wrapper for on-policy trainer procedure. - The "step" in trainer means an environment frame. + The "step" in trainer means an environment step. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param train_collector: the collector used for training. - :type train_collector: :class:`~tianshou.data.Collector` - :param test_collector: the collector used for testing. - :type test_collector: :class:`~tianshou.data.Collector` - :param int max_epoch: the maximum number of epochs for training. The - training process might be finished before reaching the ``max_epoch``. + :param Collector train_collector: the collector used for training. + :param Collector test_collector: the collector used for testing. + :param int max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` sets. :param int step_per_epoch: the number of environment frames collected per epoch. - :param int repeat_per_collect: the number of repeat time for policy - learning, for example, set it to 2 means the policy needs to learn each - given batch data twice. - :param episode_per_test: the number of episodes for one policy evaluation. - :type episode_per_test: int or list of ints - :param int batch_size: the batch size of sample data, which is going to - feed in the policy network. - :param int step_per_collect: the number of frames the collector would - collect before the network update. Only either one of step_per_collect - and episode_per_collect can be specified. - :param int episode_per_collect: the number of episodes the collector would - collect before the network update. Only either one of step_per_collect - and episode_per_collect can be specified. - :param function train_fn: a hook called at the beginning of training in - each epoch. It can be used to perform custom additional operations, - with the signature ``f(num_epoch: int, step_idx: int) -> None``. - :param function test_fn: a hook called at the beginning of testing in each - epoch. It can be used to perform custom additional operations, with the - signature ``f(num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean - reward in evaluation phase gets better, with the signature ``f(policy: - BasePolicy) -> None``. - :param function stop_fn: a function with signature ``f(mean_rewards: float) - -> bool``, receives the average undiscounted returns of the testing - result, returns a boolean which indicates whether reaching the goal. + :param int repeat_per_collect: the number of repeat time for policy learning, for + example, set it to 2 means the policy needs to learn each given batch data + twice. 
+ :param int episode_per_test: the number of episodes for one policy evaluation. + :param int batch_size: the batch size of sample data, which is going to feed in the + policy network. + :param int step_per_collect: the number of frames the collector would collect + before the network update, i.e., trainer will collect "step_per_collect" frames + and do some policy network update repeatly in each epoch. + :param int episode_per_collect: the number of episodes the collector would collect + before the network update, i.e., trainer will collect "episode_per_collect" + episodes and do some policy network update repeatly in each epoch. + :param function train_fn: a hook called at the beginning of training in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean reward in + evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> + None``. + :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether reaching the goal. :param function reward_metric: a function with signature ``f(rewards: np.ndarray with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, used in multi-agent RL. We need to return a single scalar for each episode's result to monitor training in the multi-agent RL setting. This function specifies what is the desired metric, e.g., the reward of agent 1 or the average reward over all agents. - :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard - SummaryWriter; if None is given, it will not write logs to TensorBoard. - :param int log_interval: the log interval of the writer. - :param bool verbose: whether to print the information. - :param bool test_in_train: whether to test in the training phase. + :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter; + if None is given, it will not write logs to TensorBoard. Default to None. + :param int log_interval: the log interval of the writer. Default to 1. + :param bool verbose: whether to print the information. Default to True. + :param bool test_in_train: whether to test in the training phase. Default to True. :return: See :func:`~tianshou.trainer.gather_info`. + + .. note:: + + Only either one of step_per_collect and episode_per_collect can be specified. 
""" env_step, gradient_step = 0, 0 stat: Dict[str, MovAvg] = defaultdict(MovAvg) diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 2cdeb15fe..8162d603a 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -51,11 +51,11 @@ def gather_info( * ``train_time/collector`` the time for collecting frames in the \ training collector; * ``train_time/model`` the time for training models; - * ``train_speed`` the speed of training (frames per second); + * ``train_speed`` the speed of training (env_step per second); * ``test_step`` the total collected step of test collector; * ``test_episode`` the total collected episode of test collector; * ``test_time`` the time for testing; - * ``test_speed`` the speed of testing (frames per second); + * ``test_speed`` the speed of testing (env_step per second); * ``best_reward`` the best reward over the test results; * ``duration`` the total elapsed time. """ From 6330c51609ff28f8dd747d52ce3c30646c6a6b52 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Sat, 20 Feb 2021 17:03:24 +0800 Subject: [PATCH 23/25] fix doc --- tianshou/trainer/offline.py | 2 +- tianshou/trainer/offpolicy.py | 2 +- tianshou/trainer/onpolicy.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index 9e68295b8..61714f7a0 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -34,7 +34,7 @@ def offline_trainer( :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param Collector test_collector: the collector used for testing. :param int max_epoch: the maximum number of epochs for training. The training - process might be finished before reaching ``max_epoch`` if ``stop_fn`` sets. + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. :param int update_per_epoch: the number of policy network updates, so-called gradient steps, per epoch. :param episode_per_test: the number of episodes for one policy evaluation. diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 6df9df549..d4a307673 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -39,7 +39,7 @@ def offpolicy_trainer( :param Collector train_collector: the collector used for training. :param Collector test_collector: the collector used for testing. :param int max_epoch: the maximum number of epochs for training. The training - process might be finished before reaching ``max_epoch`` if ``stop_fn`` sets. + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. :param int step_per_epoch: the number of environment frames collected per epoch. :param int step_per_collect: the number of frames the collector would collect before the network update, i.e., trainer will collect "step_per_collect" frames diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index f6319cc1b..c143441d0 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -40,7 +40,7 @@ def onpolicy_trainer( :param Collector train_collector: the collector used for training. :param Collector test_collector: the collector used for testing. :param int max_epoch: the maximum number of epochs for training. The training - process might be finished before reaching ``max_epoch`` if ``stop_fn`` sets. + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. :param int step_per_epoch: the number of environment frames collected per epoch. 
:param int repeat_per_collect: the number of repeat time for policy learning, for example, set it to 2 means the policy needs to learn each given batch data From 4597367dd5a5c7183f1bc23be83d97d69f925fdd Mon Sep 17 00:00:00 2001 From: n+e Date: Sat, 20 Feb 2021 20:16:04 +0800 Subject: [PATCH 24/25] Update tianshou/policy/base.py Co-authored-by: danagi <420147879@qq.com> --- tianshou/policy/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 451bb3f84..9023bf47b 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -220,7 +220,7 @@ def compute_episodic_return( Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438) to calculate q function/reward to go of given batch. - :param Batch atch: a data batch which contains several episodes of data + :param Batch batch: a data batch which contains several episodes of data in sequential order. Mind that the end of each finished episode of batch should be marked by done flag, unfinished (or collecting) episodes will be recongized by buffer.unfinished_index(). From fdb4e5787d15788cd620fefa5d224cb042f7303d Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Sat, 20 Feb 2021 23:53:33 +0800 Subject: [PATCH 25/25] replace frame with transition --- docs/tutorials/concepts.rst | 2 +- docs/tutorials/dqn.rst | 6 +++--- tianshou/data/buffer.py | 4 ++-- tianshou/data/collector.py | 10 +++++----- tianshou/policy/modelfree/pg.py | 2 +- tianshou/trainer/offpolicy.py | 16 ++++++++-------- tianshou/trainer/onpolicy.py | 10 +++++----- tianshou/trainer/utils.py | 2 +- 8 files changed, 26 insertions(+), 26 deletions(-) diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index b3e126352..26ee8d285 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -284,7 +284,7 @@ policy.process_fn The ``process_fn`` function computes some variables that depends on time-series. For example, compute the N-step or GAE returns. -Take 2-step return DQN as an example. The 2-step return DQN compute each frame's return as: +Take 2-step return DQN as an example. The 2-step return DQN compute each transition's return as: .. math:: diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index bb99f3032..40e4a399f 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -136,8 +136,8 @@ Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.t The meaning of each parameter is as follows (full description can be found at :func:`~tianshou.trainer.offpolicy_trainer`): * ``max_epoch``: The maximum of epochs for training. The training process might be finished before reaching the ``max_epoch``; -* ``step_per_epoch``: The number of environment frames collected per epoch; -* ``step_per_collect``: The number of frames the collector would collect before the network update. For example, the code above means "collect 10 frames and do one policy network update"; +* ``step_per_epoch``: The number of environment step (a.k.a. transition) collected per epoch; +* ``step_per_collect``: The number of transition the collector would collect before the network update. For example, the code above means "collect 10 transitions and do one policy network update"; * ``episode_per_test``: The number of episodes for one policy evaluation. * ``batch_size``: The batch size of sample data, which is going to feed in the policy network. 
* ``train_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training". @@ -205,7 +205,7 @@ Train a Policy with Customized Codes Tianshou supports user-defined training code. Here is the code snippet: :: - # pre-collect at least 5000 frames with random action before training + # pre-collect at least 5000 transitions with random action before training train_collector.collect(n_step=5000, random=True) policy.set_eps(0.1) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index b24e8c6b2..477a2531a 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -605,7 +605,7 @@ def __init__(self, buffer_list: Sequence[PrioritizedReplayBuffer]) -> None: class VectorReplayBuffer(ReplayBufferManager): """VectorReplayBuffer contains n ReplayBuffer with the same size. - It is used for storing data frame from different environments yet keeping the order + It is used for storing transition from different environments yet keeping the order of time. :param int total_size: the total size of VectorReplayBuffer. @@ -631,7 +631,7 @@ def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager): """PrioritizedVectorReplayBuffer contains n PrioritizedReplayBuffer with same size. - It is used for storing data frame from different environments yet keeping the order + It is used for storing transition from different environments yet keeping the order of time. :param int total_size: the total size of PrioritizedVectorReplayBuffer. diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 74ee72d11..bb3239e0e 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -198,7 +198,7 @@ def collect( if not n_step % self.env_num == 0: warnings.warn( f"n_step={n_step} is not a multiple of #env ({self.env_num}), " - "which may cause extra frame collected into the buffer." + "which may cause extra transitions collected into the buffer." ) ready_env_ids = np.arange(self.env_num) elif n_episode is not None: @@ -357,9 +357,9 @@ def collect( ) -> Dict[str, Any]: """Collect a specified number of step or episode with async env setting. - This function doesn't collect exactly n_step or n_episode number of frames. - Instead, in order to support async setting, it may collect more than given - n_step or n_episode frames and save into buffer. + This function doesn't collect exactly n_step or n_episode number of + transitions. Instead, in order to support async setting, it may collect more + than given n_step or n_episode transitions and save into buffer. :param int n_step: how many steps you want to collect. :param int n_episode: how many episodes you want to collect. 
@@ -395,7 +395,7 @@ def collect( else: raise TypeError("Please specify at least one (either n_step or n_episode) " "in AsyncCollector.collect().") - warnings.warn("Using async setting may collect extra frames into buffer.") + warnings.warn("Using async setting may collect extra transitions into buffer.") ready_env_ids = self._ready_env_ids diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 92b7f5d89..82fb9f704 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -45,7 +45,7 @@ def __init__( def process_fn( self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray ) -> Batch: - r"""Compute the discounted returns for each frame. + r"""Compute the discounted returns for each transition. .. math:: G_t = \sum_{i=t}^T \gamma^{i-t}r_i diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index d4a307673..54e7cb166 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -33,25 +33,25 @@ def offpolicy_trainer( ) -> Dict[str, Union[float, str]]: """A wrapper for off-policy trainer procedure. - The "step" in trainer means an environment step. + The "step" in trainer means an environment step (a.k.a. transition). :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param Collector train_collector: the collector used for training. :param Collector test_collector: the collector used for testing. :param int max_epoch: the maximum number of epochs for training. The training process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. - :param int step_per_epoch: the number of environment frames collected per epoch. - :param int step_per_collect: the number of frames the collector would collect - before the network update, i.e., trainer will collect "step_per_collect" frames - and do some policy network update repeatly in each epoch. + :param int step_per_epoch: the number of transitions collected per epoch. + :param int step_per_collect: the number of transitions the collector would collect + before the network update, i.e., trainer will collect "step_per_collect" + transitions and do some policy network update repeatly in each epoch. :param episode_per_test: the number of episodes for one policy evaluation. :param int batch_size: the batch size of sample data, which is going to feed in the policy network. :param int/float update_per_step: the number of times the policy network would be - updated per environment frame after (step_per_collect) frames are collected, + updated per transition after (step_per_collect) transitions are collected, e.g., if update_per_step set to 0.3, and step_per_collect is 256, policy will - be updated round(256 * 0.3 = 76.8) = 77 times after 256 frames are collected by - the collector. Default to 1. + be updated round(256 * 0.3 = 76.8) = 77 times after 256 transitions are + collected by the collector. Default to 1. :param function train_fn: a hook called at the beginning of training in each epoch. It can be used to perform custom additional operations, with the signature ``f( num_epoch: int, step_idx: int) -> None``. diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index c143441d0..43fcc8738 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -34,23 +34,23 @@ def onpolicy_trainer( ) -> Dict[str, Union[float, str]]: """A wrapper for on-policy trainer procedure. - The "step" in trainer means an environment step. + The "step" in trainer means an environment step (a.k.a. 
transition). :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param Collector train_collector: the collector used for training. :param Collector test_collector: the collector used for testing. :param int max_epoch: the maximum number of epochs for training. The training process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. - :param int step_per_epoch: the number of environment frames collected per epoch. + :param int step_per_epoch: the number of transitions collected per epoch. :param int repeat_per_collect: the number of repeat time for policy learning, for example, set it to 2 means the policy needs to learn each given batch data twice. :param int episode_per_test: the number of episodes for one policy evaluation. :param int batch_size: the batch size of sample data, which is going to feed in the policy network. - :param int step_per_collect: the number of frames the collector would collect - before the network update, i.e., trainer will collect "step_per_collect" frames - and do some policy network update repeatly in each epoch. + :param int step_per_collect: the number of transitions the collector would collect + before the network update, i.e., trainer will collect "step_per_collect" + transitions and do some policy network update repeatly in each epoch. :param int episode_per_collect: the number of episodes the collector would collect before the network update, i.e., trainer will collect "episode_per_collect" episodes and do some policy network update repeatly in each epoch. diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 8162d603a..72803bef0 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -48,7 +48,7 @@ def gather_info( * ``train_step`` the total collected step of training collector; * ``train_episode`` the total collected episode of training collector; - * ``train_time/collector`` the time for collecting frames in the \ + * ``train_time/collector`` the time for collecting transitions in the \ training collector; * ``train_time/model`` the time for training models; * ``train_speed`` the speed of training (env_step per second);
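
The `update_per_step` arithmetic documented in the offpolicy trainer docstring above can be sanity-checked in isolation. The snippet below is a minimal sketch grounded only in that docstring (it assumes the trainer performs `round(step_per_collect * update_per_step)` gradient steps per collect phase, exactly as the `256 * 0.3 = 76.8 -> 77` example states); the helper name is illustrative and not part of the Tianshou API.

```python
# Minimal sketch of the update_per_step semantics from the offpolicy_trainer
# docstring. Assumption: the trainer runs round(step_per_collect * update_per_step)
# gradient steps after each collect phase, matching the 256 * 0.3 -> 77 example.

def gradient_steps_per_collect(step_per_collect: int, update_per_step: float) -> int:
    """Number of policy network updates performed after one collect phase."""
    return round(step_per_collect * update_per_step)

# Docstring example: 256 transitions collected with update_per_step=0.3 -> 77 updates.
assert gradient_steps_per_collect(256, 0.3) == 77

# The tutorial setting update_per_step=0.1 with step_per_collect=10 (equivalently
# update_per_step=1 / step_per_collect) yields exactly one update per collect.
assert gradient_steps_per_collect(10, 0.1) == 1
```

Setting `update_per_step=1 / step_per_collect`, as in the updated README and DQN tutorial above, therefore reproduces the old "collect some transitions, then do one policy network update" behaviour, while larger values trade more gradient steps for fewer environment interactions.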