From 8546de3c9b57a440a5f1c07d9159c846ad4fdfe9 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Fri, 25 Sep 2020 16:21:56 +0800 Subject: [PATCH 1/8] change train_fn(epoch) -> train_fn(env_step) and test_fn(epoch) -> test_fn() --- README.md | 7 ++-- docs/tutorials/dqn.rst | 10 +++--- docs/tutorials/tictactoe.rst | 8 ++--- examples/atari/atari_dqn.py | 17 +++++---- examples/atari/runnable/pong_a2c.py | 6 ++-- examples/atari/runnable/pong_ppo.py | 6 ++-- examples/box2d/acrobot_dualdqn.py | 20 +++++------ examples/box2d/bipedal_hardcore_sac.py | 8 ++--- examples/box2d/lunarlander_dqn.py | 14 ++++---- examples/box2d/mcc_sac.py | 4 +-- examples/mujoco/ant_v2_ddpg.py | 4 +-- examples/mujoco/ant_v2_sac.py | 4 +-- examples/mujoco/ant_v2_td3.py | 4 +-- examples/mujoco/halfcheetahBullet_v0_sac.py | 4 +-- examples/mujoco/point_maze_td3.py | 4 +-- test/continuous/test_ddpg.py | 4 +-- test/continuous/test_ppo.py | 4 +-- test/continuous/test_sac_with_il.py | 4 +-- test/continuous/test_td3.py | 4 +-- test/discrete/test_a2c_with_il.py | 4 +-- test/discrete/test_dqn.py | 16 ++++----- test/discrete/test_drqn.py | 8 ++--- test/discrete/test_pg.py | 4 +-- test/discrete/test_ppo.py | 4 +-- test/discrete/test_sac.py | 4 +-- test/modelbase/test_psrl.py | 4 +-- test/multiagent/tic_tac_toe.py | 39 +++++++++++---------- tianshou/trainer/offpolicy.py | 15 ++++---- tianshou/trainer/onpolicy.py | 17 ++++----- tianshou/trainer/utils.py | 4 +-- 30 files changed, 126 insertions(+), 129 deletions(-) diff --git a/README.md b/README.md index 360043418..55c073e39 100644 --- a/README.md +++ b/README.md @@ -229,9 +229,10 @@ Let's train it: ```python result = ts.trainer.offpolicy_trainer( policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step, - test_num, batch_size, train_fn=lambda e: policy.set_eps(eps_train), - test_fn=lambda e: policy.set_eps(eps_test), - stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer, task=task) + test_num, batch_size, train_fn=lambda env_step: policy.set_eps(eps_train), + test_fn=lambda: policy.set_eps(eps_test), + stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, + writer=writer, task=task) print(f'Finished training! Use {result["duration"]}') ``` diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index d923b5669..90fd20050 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -123,9 +123,9 @@ Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tians policy, train_collector, test_collector, max_epoch=10, step_per_epoch=1000, collect_per_step=10, episode_per_test=100, batch_size=64, - train_fn=lambda e: policy.set_eps(0.1), - test_fn=lambda e: policy.set_eps(0.05), - stop_fn=lambda x: x >= env.spec.reward_threshold, + train_fn=lambda env_step: policy.set_eps(0.1), + test_fn=lambda: policy.set_eps(0.05), + stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, writer=None) print(f'Finished training! Use {result["duration"]}') @@ -136,8 +136,8 @@ The meaning of each parameter is as follows (full description can be found at :m * ``collect_per_step``: The number of frames the collector would collect before the network update. For example, the code above means "collect 10 frames and do one policy network update"; * ``episode_per_test``: The number of episodes for one policy evaluation. * ``batch_size``: The batch size of sample data, which is going to feed in the policy network. -* ``train_fn``: A function receives the current number of epoch index and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training". -* ``test_fn``: A function receives the current number of epoch index and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing". +* ``train_fn``: A function receives the current number of step index and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training". +* ``test_fn``: A function performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing". * ``stop_fn``: A function receives the average undiscounted returns of the testing result, return a boolean which indicates whether reaching the goal. * ``writer``: See below. diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst index cc4116deb..a19713442 100644 --- a/docs/tutorials/tictactoe.rst +++ b/docs/tutorials/tictactoe.rst @@ -334,15 +334,15 @@ With the above preparation, we are close to the first learned agent. The followi policy.policies[args.agent_id - 1].state_dict(), model_save_path) - def stop_fn(x): - return x >= args.win_rate # 95% winning rate by default + def stop_fn(mean_rewards): + return mean_rewards >= args.win_rate # 95% winning rate by default # the default args.win_rate is 0.9, but the reward is [-1, 1] # instead of [0, 1], so args.win_rate == 0.9 is equal to 95% win rate. - def train_fn(x): + def train_fn(env_step): policy.policies[args.agent_id - 1].set_eps(args.eps_train) - def test_fn(x): + def test_fn(): policy.policies[args.agent_id - 1].set_eps(args.eps_test) # start training, this may require about three minutes diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 92bc04551..5ba3cc593 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -95,26 +95,25 @@ def test_dqn(args=get_args()): def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) - def stop_fn(x): + def stop_fn(mean_rewards): if env.env.spec.reward_threshold: - return x >= env.spec.reward_threshold + return mean_rewards >= env.spec.reward_threshold elif 'Pong' in args.task: - return x >= 20 + return mean_rewards >= 20 else: return False - def train_fn(x): + def train_fn(env_step): # nature DQN setting, linear decay in the first 1M steps - now = x * args.collect_per_step * args.step_per_epoch - if now <= 1e6: - eps = args.eps_train - now / 1e6 * \ + if env_step <= 1e6: + eps = args.eps_train - env_step / 1e6 * \ (args.eps_train - args.eps_train_final) else: eps = args.eps_train_final policy.set_eps(eps) - writer.add_scalar('train/eps', eps, global_step=now) + writer.add_scalar('train/eps', eps, global_step=env_step) - def test_fn(x): + def test_fn(): policy.set_eps(args.eps_test) # watch agent's performance diff --git a/examples/atari/runnable/pong_a2c.py b/examples/atari/runnable/pong_a2c.py index f4b0a3031..55ed15a0d 100644 --- a/examples/atari/runnable/pong_a2c.py +++ b/examples/atari/runnable/pong_a2c.py @@ -76,11 +76,11 @@ def test_a2c(args=get_args()): preprocess_fn=preprocess_fn) test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn) # log - writer = SummaryWriter(args.logdir + '/' + 'a2c') + writer = SummaryWriter(args.logdir + '/a2c') - def stop_fn(x): + def stop_fn(mean_rewards): if env.env.spec.reward_threshold: - return x >= env.spec.reward_threshold + return mean_rewards >= env.spec.reward_threshold else: return False diff --git a/examples/atari/runnable/pong_ppo.py b/examples/atari/runnable/pong_ppo.py index 9d5563fe1..109e8130b 100644 --- a/examples/atari/runnable/pong_ppo.py +++ b/examples/atari/runnable/pong_ppo.py @@ -80,11 +80,11 @@ def test_ppo(args=get_args()): preprocess_fn=preprocess_fn) test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn) # log - writer = SummaryWriter(args.logdir + '/' + 'ppo') + writer = SummaryWriter(args.logdir + '/ppo') - def stop_fn(x): + def stop_fn(mean_rewards): if env.env.spec.reward_threshold: - return x >= env.spec.reward_threshold + return mean_rewards >= env.spec.reward_threshold else: return False diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 6345d62eb..b0afe4ee3 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -6,11 +6,11 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter -from tianshou.env import DummyVectorEnv from tianshou.policy import DQNPolicy +from tianshou.env import DummyVectorEnv +from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer -from tianshou.utils.net.common import Net def get_args(): @@ -75,20 +75,20 @@ def test_dqn(args=get_args()): def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) - def stop_fn(x): - return x >= env.spec.reward_threshold + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold - def train_fn(x): - if x <= int(0.1 * args.epoch): + def train_fn(env_step): + if env_step <= 100000: policy.set_eps(args.eps_train) - elif x <= int(0.5 * args.epoch): - eps = args.eps_train - (x - 0.1 * args.epoch) / \ - (0.4 * args.epoch) * (0.5 * args.eps_train) + elif env_step <= 500000: + eps = args.eps_train - (env_step - 100000) / \ + 400000 * (0.5 * args.eps_train) policy.set_eps(eps) else: policy.set_eps(0.5 * args.eps_train) - def test_fn(x): + def test_fn(): policy.set_eps(args.eps_test) # trainer diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index ffd3e8fef..b4da185a4 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -74,9 +74,6 @@ def step(self, action): def test_sac_bipedal(args=get_args()): env = EnvWrapper(args.task) - def IsStop(reward): - return reward >= env.spec.reward_threshold - args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n args.max_action = env.action_space.high[0] @@ -141,11 +138,14 @@ def IsStop(reward): def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=IsStop, save_fn=save_fn, writer=writer, + args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False) if __name__ == '__main__': diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index aa0f5888c..e111fb730 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -18,7 +18,7 @@ def get_args(): # the parameters are found by Optuna parser.add_argument('--task', type=str, default='LunarLander-v2') parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--eps-test', type=float, default=0.05) + parser.add_argument('--eps-test', type=float, default=0.01) parser.add_argument('--eps-train', type=float, default=0.73) parser.add_argument('--buffer-size', type=int, default=100000) parser.add_argument('--lr', type=float, default=0.013) @@ -77,14 +77,14 @@ def test_dqn(args=get_args()): def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) - def stop_fn(x): - return x >= env.spec.reward_threshold + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold - def train_fn(x): - args.eps_train = max(args.eps_train * 0.6, 0.01) - policy.set_eps(args.eps_train) + def train_fn(env_step): # exp decay + eps = max(args.eps_train * (1 - 5e-6) ** env_step, args.eps_test) + policy.set_eps(eps) - def test_fn(x): + def test_fn(): policy.set_eps(args.eps_test) # trainer diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 9ca6845b5..b9481a5cf 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -98,8 +98,8 @@ def test_sac(args=get_args()): def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) - def stop_fn(x): - return x >= env.spec.reward_threshold + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold # trainer result = offpolicy_trainer( diff --git a/examples/mujoco/ant_v2_ddpg.py b/examples/mujoco/ant_v2_ddpg.py index 948ceee4c..528ae4bf4 100644 --- a/examples/mujoco/ant_v2_ddpg.py +++ b/examples/mujoco/ant_v2_ddpg.py @@ -77,8 +77,8 @@ def test_ddpg(args=get_args()): # log writer = SummaryWriter(args.logdir + '/' + 'ddpg') - def stop_fn(x): - return x >= env.spec.reward_threshold + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold # trainer result = offpolicy_trainer( diff --git a/examples/mujoco/ant_v2_sac.py b/examples/mujoco/ant_v2_sac.py index a86bcffbf..156be7e37 100644 --- a/examples/mujoco/ant_v2_sac.py +++ b/examples/mujoco/ant_v2_sac.py @@ -86,8 +86,8 @@ def test_sac(args=get_args()): def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) - def stop_fn(x): - return x >= env.spec.reward_threshold + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold # trainer result = offpolicy_trainer( diff --git a/examples/mujoco/ant_v2_td3.py b/examples/mujoco/ant_v2_td3.py index 7165315d5..370e591c0 100644 --- a/examples/mujoco/ant_v2_td3.py +++ b/examples/mujoco/ant_v2_td3.py @@ -88,8 +88,8 @@ def test_td3(args=get_args()): # log writer = SummaryWriter(args.logdir + '/' + 'td3') - def stop_fn(x): - return x >= env.spec.reward_threshold + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold # trainer result = offpolicy_trainer( diff --git a/examples/mujoco/halfcheetahBullet_v0_sac.py b/examples/mujoco/halfcheetahBullet_v0_sac.py index 97b3bc74f..41c33dcf3 100644 --- a/examples/mujoco/halfcheetahBullet_v0_sac.py +++ b/examples/mujoco/halfcheetahBullet_v0_sac.py @@ -91,8 +91,8 @@ def test_sac(args=get_args()): log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id) writer = SummaryWriter(log_path) - def stop_fn(x): - return x >= env.spec.reward_threshold + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold # trainer result = offpolicy_trainer( diff --git a/examples/mujoco/point_maze_td3.py b/examples/mujoco/point_maze_td3.py index 6de2b20bf..48a26e265 100644 --- a/examples/mujoco/point_maze_td3.py +++ b/examples/mujoco/point_maze_td3.py @@ -91,9 +91,9 @@ def test_td3(args=get_args()): # log writer = SummaryWriter(args.logdir + '/' + 'td3') - def stop_fn(x): + def stop_fn(mean_rewards): if env.spec.reward_threshold: - return x >= env.spec.reward_threshold + return mean_rewards >= env.spec.reward_threshold else: return False diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 979444fd1..09915629c 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -96,8 +96,8 @@ def test_ddpg(args=get_args()): def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) - def stop_fn(x): - return x >= env.spec.reward_threshold + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold # trainer result = offpolicy_trainer( diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index bee5af4e7..ef3692aab 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -113,8 +113,8 @@ def dist(*logits): def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) - def stop_fn(x): - return x >= env.spec.reward_threshold + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold # trainer result = onpolicy_trainer( diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 20679739f..009218cc0 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -96,8 +96,8 @@ def test_sac_with_il(args=get_args()): def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) - def stop_fn(x): - return x >= env.spec.reward_threshold + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold # trainer result = offpolicy_trainer( diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index a6215e08e..847971474 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -103,8 +103,8 @@ def test_td3(args=get_args()): def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) - def stop_fn(x): - return x >= env.spec.reward_threshold + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold # trainer result = offpolicy_trainer( diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index dda770419..b0c31e3cb 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -84,8 +84,8 @@ def test_a2c_with_il(args=get_args()): def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) - def stop_fn(x): - return x >= env.spec.reward_threshold + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold # trainer result = onpolicy_trainer( diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 4d28d3828..8fcafddc8 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -85,21 +85,21 @@ def test_dqn(args=get_args()): def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) - def stop_fn(x): - return x >= env.spec.reward_threshold + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold - def train_fn(x): + def train_fn(env_step): # eps annnealing, just a demo - if x <= int(0.1 * args.epoch): + if env_step <= 10000: policy.set_eps(args.eps_train) - elif x <= int(0.5 * args.epoch): - eps = args.eps_train - (x - 0.1 * args.epoch) / \ - (0.4 * args.epoch) * (0.9 * args.eps_train) + elif env_step <= 50000: + eps = args.eps_train - (env_step - 10000) / \ + 40000 * (0.9 * args.eps_train) policy.set_eps(eps) else: policy.set_eps(0.1 * args.eps_train) - def test_fn(x): + def test_fn(): policy.set_eps(args.eps_test) # trainer diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 5ef6c1624..e1daa13ac 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -79,13 +79,13 @@ def test_drqn(args=get_args()): def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) - def stop_fn(x): - return x >= env.spec.reward_threshold + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold - def train_fn(x): + def train_fn(env_step): policy.set_eps(args.eps_train) - def test_fn(x): + def test_fn(): policy.set_eps(args.eps_test) # trainer diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 3604adbc6..d84130b35 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -73,8 +73,8 @@ def test_pg(args=get_args()): def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) - def stop_fn(x): - return x >= env.spec.reward_threshold + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold # trainer result = onpolicy_trainer( diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index c8d849448..f6d23fecd 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -98,8 +98,8 @@ def test_ppo(args=get_args()): def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) - def stop_fn(x): - return x >= env.spec.reward_threshold + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold # trainer result = onpolicy_trainer( diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 865924a1a..4b8607386 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -93,8 +93,8 @@ def test_discrete_sac(args=get_args()): def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) - def stop_fn(x): - return x >= env.spec.reward_threshold + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold # trainer result = offpolicy_trainer( diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 6fb0e16ad..29dfb6b8c 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -66,9 +66,9 @@ def test_psrl(args=get_args()): # log writer = SummaryWriter(args.logdir + '/' + args.task) - def stop_fn(x): + def stop_fn(mean_rewards): if env.spec.reward_threshold: - return x >= env.spec.reward_threshold + return mean_rewards >= env.spec.reward_threshold else: return False diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index 9110b9d5b..a5a608d58 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -64,11 +64,12 @@ def get_args() -> argparse.Namespace: return args -def get_agents(args: argparse.Namespace = get_args(), - agent_learn: Optional[BasePolicy] = None, - agent_opponent: Optional[BasePolicy] = None, - optim: Optional[torch.optim.Optimizer] = None, - ) -> Tuple[BasePolicy, torch.optim.Optimizer]: +def get_agents( + args: argparse.Namespace = get_args(), + agent_learn: Optional[BasePolicy] = None, + agent_opponent: Optional[BasePolicy] = None, + optim: Optional[torch.optim.Optimizer] = None, +) -> Tuple[BasePolicy, torch.optim.Optimizer]: env = TicTacToeEnv(args.board_size, args.win_size) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n @@ -99,11 +100,12 @@ def get_agents(args: argparse.Namespace = get_args(), return policy, optim -def train_agent(args: argparse.Namespace = get_args(), - agent_learn: Optional[BasePolicy] = None, - agent_opponent: Optional[BasePolicy] = None, - optim: Optional[torch.optim.Optimizer] = None, - ) -> Tuple[dict, BasePolicy]: +def train_agent( + args: argparse.Namespace = get_args(), + agent_learn: Optional[BasePolicy] = None, + agent_opponent: Optional[BasePolicy] = None, + optim: Optional[torch.optim.Optimizer] = None, +) -> Tuple[dict, BasePolicy]: def env_func(): return TicTacToeEnv(args.board_size, args.win_size) train_envs = DummyVectorEnv([env_func for _ in range(args.training_num)]) @@ -142,13 +144,13 @@ def save_fn(policy): policy.policies[args.agent_id - 1].state_dict(), model_save_path) - def stop_fn(x): - return x >= args.win_rate + def stop_fn(mean_rewards): + return mean_rewards >= args.win_rate - def train_fn(x): + def train_fn(env_step): policy.policies[args.agent_id - 1].set_eps(args.eps_train) - def test_fn(x): + def test_fn(): policy.policies[args.agent_id - 1].set_eps(args.eps_test) # trainer @@ -162,10 +164,11 @@ def test_fn(x): return result, policy.policies[args.agent_id - 1] -def watch(args: argparse.Namespace = get_args(), - agent_learn: Optional[BasePolicy] = None, - agent_opponent: Optional[BasePolicy] = None, - ) -> None: +def watch( + args: argparse.Namespace = get_args(), + agent_learn: Optional[BasePolicy] = None, + agent_opponent: Optional[BasePolicy] = None, +) -> None: env = TicTacToeEnv(args.board_size, args.win_size) policy, optim = get_agents( args, agent_learn=agent_learn, agent_opponent=agent_opponent) diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 170fd6835..bb5d411cb 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -20,7 +20,7 @@ def offpolicy_trainer( batch_size: int, update_per_step: int = 1, train_fn: Optional[Callable[[int], None]] = None, - test_fn: Optional[Callable[[int], None]] = None, + test_fn: Optional[Callable[[], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None, writer: Optional[SummaryWriter] = None, @@ -52,12 +52,11 @@ def offpolicy_trainer( be updated after frames are collected, for example, set it to 256 means it updates policy 256 times once after ``collect_per_step`` frames are collected. - :param function train_fn: a function receives the current number of epoch + :param function train_fn: a function receives the current number of step index and performs some operations at the beginning of training in this epoch. - :param function test_fn: a function receives the current number of epoch - index and performs some operations at the beginning of testing in this - epoch. + :param function test_fn: a function performs some operations at the + beginning of testing in this epoch. :param function save_fn: a function for saving policy when the undiscounted average mean reward in evaluation phase gets better. :param function stop_fn: a function receives the average undiscounted @@ -81,12 +80,12 @@ def offpolicy_trainer( for epoch in range(1, 1 + max_epoch): # train policy.train() - if train_fn: - train_fn(epoch) with tqdm.tqdm( total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config ) as t: while t.n < t.total: + if train_fn: + train_fn(global_step) result = train_collector.collect(n_step=collect_per_step) data = {} if test_in_train and stop_fn and stop_fn(result["rew"]): @@ -104,8 +103,6 @@ def offpolicy_trainer( test_result["rew"]) else: policy.train() - if train_fn: - train_fn(epoch) for i in range(update_per_step * min( result["n/st"] // collect_per_step, t.total - t.n)): global_step += collect_per_step diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 877c6348c..59a1f1f4f 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -20,7 +20,7 @@ def onpolicy_trainer( episode_per_test: Union[int, List[int]], batch_size: int, train_fn: Optional[Callable[[int], None]] = None, - test_fn: Optional[Callable[[int], None]] = None, + test_fn: Optional[Callable[[], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None, writer: Optional[SummaryWriter] = None, @@ -52,12 +52,11 @@ def onpolicy_trainer( :type episode_per_test: int or list of ints :param int batch_size: the batch size of sample data, which is going to feed in the policy network. - :param function train_fn: a function receives the current number of epoch + :param function train_fn: a function receives the current number of step index and performs some operations at the beginning of training in this - epoch. - :param function test_fn: a function receives the current number of epoch - index and performs some operations at the beginning of testing in this - epoch. + poch. + :param function test_fn: a function performs some operations at the + beginning of testing in this epoch. :param function save_fn: a function for saving policy when the undiscounted average mean reward in evaluation phase gets better. :param function stop_fn: a function receives the average undiscounted @@ -81,12 +80,12 @@ def onpolicy_trainer( for epoch in range(1, 1 + max_epoch): # train policy.train() - if train_fn: - train_fn(epoch) with tqdm.tqdm( total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config ) as t: while t.n < t.total: + if train_fn: + train_fn(global_step) result = train_collector.collect(n_episode=collect_per_step) data = {} if test_in_train and stop_fn and stop_fn(result["rew"]): @@ -104,8 +103,6 @@ def onpolicy_trainer( test_result["rew"]) else: policy.train() - if train_fn: - train_fn(epoch) losses = policy.update( 0, train_collector.buffer, batch_size=batch_size, repeat=repeat_per_collect) diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 0c5d2dd9b..30ff17be5 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -10,7 +10,7 @@ def test_episode( policy: BasePolicy, collector: Collector, - test_fn: Optional[Callable[[int], None]], + test_fn: Optional[Callable[[], None]], epoch: int, n_episode: Union[int, List[int]], writer: Optional[SummaryWriter] = None, @@ -21,7 +21,7 @@ def test_episode( collector.reset_buffer() policy.eval() if test_fn: - test_fn(epoch) + test_fn() if collector.get_env_num() > 1 and isinstance(n_episode, int): n = collector.get_env_num() n_ = np.zeros(n) + n_episode // n From f0be2c3bc6f4ae560bc5322cb29f68671a973471 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Fri, 25 Sep 2020 16:25:39 +0800 Subject: [PATCH 2/8] version file --- tianshou/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/__init__.py b/tianshou/__init__.py index 0b9c0e942..cc37f2773 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,7 +1,7 @@ from tianshou import data, env, utils, policy, trainer, exploration -__version__ = "0.3.0rc0" +__version__ = "0.3.0" __all__ = [ "env", From 5dc0b61237563bd9d9eb840800011dddeb4c3611 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Fri, 25 Sep 2020 16:42:59 +0800 Subject: [PATCH 3/8] train_fn(env_step) -> train_fn(epoch, env_step) --- README.md | 2 +- docs/tutorials/dqn.rst | 2 +- docs/tutorials/tictactoe.rst | 2 +- examples/atari/atari_dqn.py | 2 +- examples/box2d/acrobot_dualdqn.py | 2 +- examples/box2d/lunarlander_dqn.py | 2 +- test/discrete/test_dqn.py | 2 +- test/discrete/test_drqn.py | 2 +- test/multiagent/tic_tac_toe.py | 2 +- tianshou/trainer/offpolicy.py | 10 +++++----- tianshou/trainer/onpolicy.py | 10 +++++----- 11 files changed, 19 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 55c073e39..61cfb9c7b 100644 --- a/README.md +++ b/README.md @@ -229,7 +229,7 @@ Let's train it: ```python result = ts.trainer.offpolicy_trainer( policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step, - test_num, batch_size, train_fn=lambda env_step: policy.set_eps(eps_train), + test_num, batch_size, train_fn=lambda epoch, env_step: policy.set_eps(eps_train), test_fn=lambda: policy.set_eps(eps_test), stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, writer=writer, task=task) diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index 90fd20050..35fe4fa7e 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -123,7 +123,7 @@ Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tians policy, train_collector, test_collector, max_epoch=10, step_per_epoch=1000, collect_per_step=10, episode_per_test=100, batch_size=64, - train_fn=lambda env_step: policy.set_eps(0.1), + train_fn=lambda epoch, env_step: policy.set_eps(0.1), test_fn=lambda: policy.set_eps(0.05), stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, writer=None) diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst index a19713442..b8e077e37 100644 --- a/docs/tutorials/tictactoe.rst +++ b/docs/tutorials/tictactoe.rst @@ -339,7 +339,7 @@ With the above preparation, we are close to the first learned agent. The followi # the default args.win_rate is 0.9, but the reward is [-1, 1] # instead of [0, 1], so args.win_rate == 0.9 is equal to 95% win rate. - def train_fn(env_step): + def train_fn(epoch, env_step): policy.policies[args.agent_id - 1].set_eps(args.eps_train) def test_fn(): diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 5ba3cc593..97a8993c2 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -103,7 +103,7 @@ def stop_fn(mean_rewards): else: return False - def train_fn(env_step): + def train_fn(epoch, env_step): # nature DQN setting, linear decay in the first 1M steps if env_step <= 1e6: eps = args.eps_train - env_step / 1e6 * \ diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index b0afe4ee3..d72d01827 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -78,7 +78,7 @@ def save_fn(policy): def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold - def train_fn(env_step): + def train_fn(epoch, env_step): if env_step <= 100000: policy.set_eps(args.eps_train) elif env_step <= 500000: diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index e111fb730..e7979034b 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -80,7 +80,7 @@ def save_fn(policy): def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold - def train_fn(env_step): # exp decay + def train_fn(epoch, env_step): # exp decay eps = max(args.eps_train * (1 - 5e-6) ** env_step, args.eps_test) policy.set_eps(eps) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 8fcafddc8..9a961cfc5 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -88,7 +88,7 @@ def save_fn(policy): def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold - def train_fn(env_step): + def train_fn(epoch, env_step): # eps annnealing, just a demo if env_step <= 10000: policy.set_eps(args.eps_train) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index e1daa13ac..d6a00bf6d 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -82,7 +82,7 @@ def save_fn(policy): def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold - def train_fn(env_step): + def train_fn(epoch, env_step): policy.set_eps(args.eps_train) def test_fn(): diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index a5a608d58..83be2973c 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -147,7 +147,7 @@ def save_fn(policy): def stop_fn(mean_rewards): return mean_rewards >= args.win_rate - def train_fn(env_step): + def train_fn(epoch, env_step): policy.policies[args.agent_id - 1].set_eps(args.eps_train) def test_fn(): diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index bb5d411cb..66eccf743 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -19,7 +19,7 @@ def offpolicy_trainer( episode_per_test: Union[int, List[int]], batch_size: int, update_per_step: int = 1, - train_fn: Optional[Callable[[int], None]] = None, + train_fn: Optional[Callable[[int, int], None]] = None, test_fn: Optional[Callable[[], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None, @@ -52,9 +52,9 @@ def offpolicy_trainer( be updated after frames are collected, for example, set it to 256 means it updates policy 256 times once after ``collect_per_step`` frames are collected. - :param function train_fn: a function receives the current number of step - index and performs some operations at the beginning of training in this - epoch. + :param function train_fn: a function receives the current number of epoch + and step index and performs some operations at the beginning of + training in this epoch. :param function test_fn: a function performs some operations at the beginning of testing in this epoch. :param function save_fn: a function for saving policy when the undiscounted @@ -85,7 +85,7 @@ def offpolicy_trainer( ) as t: while t.n < t.total: if train_fn: - train_fn(global_step) + train_fn(epoch, global_step) result = train_collector.collect(n_step=collect_per_step) data = {} if test_in_train and stop_fn and stop_fn(result["rew"]): diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 59a1f1f4f..8bf1986fc 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -19,7 +19,7 @@ def onpolicy_trainer( repeat_per_collect: int, episode_per_test: Union[int, List[int]], batch_size: int, - train_fn: Optional[Callable[[int], None]] = None, + train_fn: Optional[Callable[[int, int], None]] = None, test_fn: Optional[Callable[[], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None, @@ -52,9 +52,9 @@ def onpolicy_trainer( :type episode_per_test: int or list of ints :param int batch_size: the batch size of sample data, which is going to feed in the policy network. - :param function train_fn: a function receives the current number of step - index and performs some operations at the beginning of training in this - poch. + :param function train_fn: a function receives the current number of epoch + and step index, and performs some operations at the beginning of + training in this poch. :param function test_fn: a function performs some operations at the beginning of testing in this epoch. :param function save_fn: a function for saving policy when the undiscounted @@ -85,7 +85,7 @@ def onpolicy_trainer( ) as t: while t.n < t.total: if train_fn: - train_fn(global_step) + train_fn(epoch, global_step) result = train_collector.collect(n_episode=collect_per_step) data = {} if test_in_train and stop_fn and stop_fn(result["rew"]): From 6a4f333c70ad052bd47f57c6103dcb4ffa6f128b Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Fri, 25 Sep 2020 16:47:04 +0800 Subject: [PATCH 4/8] fix --- docs/tutorials/dqn.rst | 2 +- tianshou/trainer/offpolicy.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index 35fe4fa7e..fbcca1ac6 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -136,7 +136,7 @@ The meaning of each parameter is as follows (full description can be found at :m * ``collect_per_step``: The number of frames the collector would collect before the network update. For example, the code above means "collect 10 frames and do one policy network update"; * ``episode_per_test``: The number of episodes for one policy evaluation. * ``batch_size``: The batch size of sample data, which is going to feed in the policy network. -* ``train_fn``: A function receives the current number of step index and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training". +* ``train_fn``: A function receives the current number of epoch and step index and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training". * ``test_fn``: A function performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing". * ``stop_fn``: A function receives the average undiscounted returns of the testing result, return a boolean which indicates whether reaching the goal. * ``writer``: See below. diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 66eccf743..0516cf47e 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -53,7 +53,7 @@ def offpolicy_trainer( it updates policy 256 times once after ``collect_per_step`` frames are collected. :param function train_fn: a function receives the current number of epoch - and step index and performs some operations at the beginning of + and step index, and performs some operations at the beginning of training in this epoch. :param function test_fn: a function performs some operations at the beginning of testing in this epoch. From f280797531a9f4717443819c11062e1ede2662b9 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Fri, 25 Sep 2020 17:16:56 +0800 Subject: [PATCH 5/8] test(epoch, num_env_step) --- README.md | 5 +++-- docs/tutorials/dqn.rst | 8 ++++---- docs/tutorials/tictactoe.rst | 4 ++-- examples/atari/atari_dqn.py | 10 +++++----- examples/box2d/acrobot_dualdqn.py | 10 +++++----- examples/box2d/lunarlander_dqn.py | 6 +++--- test/discrete/test_dqn.py | 10 +++++----- test/discrete/test_drqn.py | 4 ++-- test/multiagent/tic_tac_toe.py | 4 ++-- tianshou/trainer/offpolicy.py | 7 ++++--- tianshou/trainer/onpolicy.py | 7 ++++--- tianshou/trainer/utils.py | 4 ++-- 12 files changed, 41 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index 61cfb9c7b..5ff434b2f 100644 --- a/README.md +++ b/README.md @@ -229,8 +229,9 @@ Let's train it: ```python result = ts.trainer.offpolicy_trainer( policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step, - test_num, batch_size, train_fn=lambda epoch, env_step: policy.set_eps(eps_train), - test_fn=lambda: policy.set_eps(eps_test), + test_num, batch_size, + train_fn=lambda epoch, num_env_step: policy.set_eps(eps_train), + test_fn=lambda epoch, num_env_step: policy.set_eps(eps_test), stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, writer=writer, task=task) print(f'Finished training! Use {result["duration"]}') diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index fbcca1ac6..a73db48e3 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -123,8 +123,8 @@ Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tians policy, train_collector, test_collector, max_epoch=10, step_per_epoch=1000, collect_per_step=10, episode_per_test=100, batch_size=64, - train_fn=lambda epoch, env_step: policy.set_eps(0.1), - test_fn=lambda: policy.set_eps(0.05), + train_fn=lambda epoch, num_env_step: policy.set_eps(0.1), + test_fn=lambda epoch, num_env_step: policy.set_eps(0.05), stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, writer=None) print(f'Finished training! Use {result["duration"]}') @@ -136,8 +136,8 @@ The meaning of each parameter is as follows (full description can be found at :m * ``collect_per_step``: The number of frames the collector would collect before the network update. For example, the code above means "collect 10 frames and do one policy network update"; * ``episode_per_test``: The number of episodes for one policy evaluation. * ``batch_size``: The batch size of sample data, which is going to feed in the policy network. -* ``train_fn``: A function receives the current number of epoch and step index and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training". -* ``test_fn``: A function performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing". +* ``train_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training". +* ``test_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing". * ``stop_fn``: A function receives the average undiscounted returns of the testing result, return a boolean which indicates whether reaching the goal. * ``writer``: See below. diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst index b8e077e37..a9c5b6e05 100644 --- a/docs/tutorials/tictactoe.rst +++ b/docs/tutorials/tictactoe.rst @@ -339,10 +339,10 @@ With the above preparation, we are close to the first learned agent. The followi # the default args.win_rate is 0.9, but the reward is [-1, 1] # instead of [0, 1], so args.win_rate == 0.9 is equal to 95% win rate. - def train_fn(epoch, env_step): + def train_fn(epoch, num_env_step): policy.policies[args.agent_id - 1].set_eps(args.eps_train) - def test_fn(): + def test_fn(epoch, num_env_step): policy.policies[args.agent_id - 1].set_eps(args.eps_test) # start training, this may require about three minutes diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 97a8993c2..6e9de5b49 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -103,17 +103,17 @@ def stop_fn(mean_rewards): else: return False - def train_fn(epoch, env_step): + def train_fn(epoch, num_env_step): # nature DQN setting, linear decay in the first 1M steps - if env_step <= 1e6: - eps = args.eps_train - env_step / 1e6 * \ + if num_env_step <= 1e6: + eps = args.eps_train - num_env_step / 1e6 * \ (args.eps_train - args.eps_train_final) else: eps = args.eps_train_final policy.set_eps(eps) - writer.add_scalar('train/eps', eps, global_step=env_step) + writer.add_scalar('train/eps', eps, global_step=num_env_step) - def test_fn(): + def test_fn(epoch, num_env_step): policy.set_eps(args.eps_test) # watch agent's performance diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index d72d01827..85c15726e 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -78,17 +78,17 @@ def save_fn(policy): def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold - def train_fn(epoch, env_step): - if env_step <= 100000: + def train_fn(epoch, num_env_step): + if num_env_step <= 100000: policy.set_eps(args.eps_train) - elif env_step <= 500000: - eps = args.eps_train - (env_step - 100000) / \ + elif num_env_step <= 500000: + eps = args.eps_train - (num_env_step - 100000) / \ 400000 * (0.5 * args.eps_train) policy.set_eps(eps) else: policy.set_eps(0.5 * args.eps_train) - def test_fn(): + def test_fn(epoch, num_env_step): policy.set_eps(args.eps_test) # trainer diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index e7979034b..f959a1bd3 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -80,11 +80,11 @@ def save_fn(policy): def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold - def train_fn(epoch, env_step): # exp decay - eps = max(args.eps_train * (1 - 5e-6) ** env_step, args.eps_test) + def train_fn(epoch, num_env_step): # exp decay + eps = max(args.eps_train * (1 - 5e-6) ** num_env_step, args.eps_test) policy.set_eps(eps) - def test_fn(): + def test_fn(epoch, num_env_step): policy.set_eps(args.eps_test) # trainer diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 9a961cfc5..8be03ae53 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -88,18 +88,18 @@ def save_fn(policy): def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold - def train_fn(epoch, env_step): + def train_fn(epoch, num_env_step): # eps annnealing, just a demo - if env_step <= 10000: + if num_env_step <= 10000: policy.set_eps(args.eps_train) - elif env_step <= 50000: - eps = args.eps_train - (env_step - 10000) / \ + elif num_env_step <= 50000: + eps = args.eps_train - (num_env_step - 10000) / \ 40000 * (0.9 * args.eps_train) policy.set_eps(eps) else: policy.set_eps(0.1 * args.eps_train) - def test_fn(): + def test_fn(epoch, num_env_step): policy.set_eps(args.eps_test) # trainer diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index d6a00bf6d..5667e8c11 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -82,10 +82,10 @@ def save_fn(policy): def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold - def train_fn(epoch, env_step): + def train_fn(epoch, num_env_step): policy.set_eps(args.eps_train) - def test_fn(): + def test_fn(epoch, num_env_step): policy.set_eps(args.eps_test) # trainer diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index 83be2973c..fc82246c9 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -147,10 +147,10 @@ def save_fn(policy): def stop_fn(mean_rewards): return mean_rewards >= args.win_rate - def train_fn(epoch, env_step): + def train_fn(epoch, num_env_step): policy.policies[args.agent_id - 1].set_eps(args.eps_train) - def test_fn(): + def test_fn(epoch, num_env_step): policy.policies[args.agent_id - 1].set_eps(args.eps_test) # trainer diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 0516cf47e..75fd6cfd4 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -20,7 +20,7 @@ def offpolicy_trainer( batch_size: int, update_per_step: int = 1, train_fn: Optional[Callable[[int, int], None]] = None, - test_fn: Optional[Callable[[], None]] = None, + test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None, writer: Optional[SummaryWriter] = None, @@ -55,8 +55,9 @@ def offpolicy_trainer( :param function train_fn: a function receives the current number of epoch and step index, and performs some operations at the beginning of training in this epoch. - :param function test_fn: a function performs some operations at the - beginning of testing in this epoch. + :param function test_fn: a function receives the current number of epoch + and step index, and performs some operations at the beginning of + testing in this epoch. :param function save_fn: a function for saving policy when the undiscounted average mean reward in evaluation phase gets better. :param function stop_fn: a function receives the average undiscounted diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 8bf1986fc..023dd2985 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -20,7 +20,7 @@ def onpolicy_trainer( episode_per_test: Union[int, List[int]], batch_size: int, train_fn: Optional[Callable[[int, int], None]] = None, - test_fn: Optional[Callable[[], None]] = None, + test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None, writer: Optional[SummaryWriter] = None, @@ -55,8 +55,9 @@ def onpolicy_trainer( :param function train_fn: a function receives the current number of epoch and step index, and performs some operations at the beginning of training in this poch. - :param function test_fn: a function performs some operations at the - beginning of testing in this epoch. + :param function test_fn: a function receives the current number of epoch + and step index, and performs some operations at the beginning of + testing in this epoch. :param function save_fn: a function for saving policy when the undiscounted average mean reward in evaluation phase gets better. :param function stop_fn: a function receives the average undiscounted diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 30ff17be5..2c2fb5438 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -10,7 +10,7 @@ def test_episode( policy: BasePolicy, collector: Collector, - test_fn: Optional[Callable[[], None]], + test_fn: Optional[Callable[[int, Optional[int]], None]], epoch: int, n_episode: Union[int, List[int]], writer: Optional[SummaryWriter] = None, @@ -21,7 +21,7 @@ def test_episode( collector.reset_buffer() policy.eval() if test_fn: - test_fn() + test_fn(epoch, global_step) if collector.get_env_num() > 1 and isinstance(n_episode, int): n = collector.get_env_num() n_ = np.zeros(n) + n_episode // n From b77104ea00527724c6a60d2af1c34b5dbd8a9f21 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Fri, 25 Sep 2020 20:11:26 +0800 Subject: [PATCH 6/8] check every example is runnable --- examples/atari/README.md | 18 +++++++++--------- examples/atari/atari_dqn.py | 20 ++++++++++---------- examples/mujoco/ant_v2_ddpg.py | 6 +++--- examples/mujoco/ant_v2_sac.py | 4 ++-- examples/mujoco/ant_v2_td3.py | 6 +++--- examples/mujoco/halfcheetahBullet_v0_sac.py | 10 +++------- examples/mujoco/point_maze_td3.py | 8 ++++---- 7 files changed, 34 insertions(+), 38 deletions(-) diff --git a/examples/atari/README.md b/examples/atari/README.md index 40c025c4f..474f74c42 100644 --- a/examples/atari/README.md +++ b/examples/atari/README.md @@ -12,14 +12,14 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. | task | best reward | reward curve | parameters | time cost | | --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | ------------------- | -| PongNoFrameskip-v4 | 20 | ![](results/dqn/Pong_rew.png) | `python3 atari_dqn.py --task "PongNoFrameskip-v4" --batch_size 64` | ~30 min (~15 epoch) | -| BreakoutNoFrameskip-v4 | 316 | ![](results/dqn/Breakout_rew.png) | `python3 atari_dqn.py --task "BreakoutNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) | -| EnduroNoFrameskip-v4 | 670 | ![](results/dqn/Enduro_rew.png) | `python3 atari_dqn.py --task "EnduroNoFrameskip-v4 " --test_num 100` | 3~4h (100 epoch) | -| QbertNoFrameskip-v4 | 7307 | ![](results/dqn/Qbert_rew.png) | `python3 atari_dqn.py --task "QbertNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) | -| MsPacmanNoFrameskip-v4 | 2107 | ![](results/dqn/MsPacman_rew.png) | `python3 atari_dqn.py --task "MsPacmanNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) | -| SeaquestNoFrameskip-v4 | 2088 | ![](results/dqn/Seaquest_rew.png) | `python3 atari_dqn.py --task "SeaquestNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) | -| SpaceInvadersNoFrameskip-v4 | 812.2 | ![](results/dqn/SpaceInvader_rew.png) | `python3 atari_dqn.py --task "SpaceInvadersNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) | - -Note: The eps_train_final and eps_test in the original DQN paper is 0.1 and 0.01, but [some works](https://github.com/google/dopamine/tree/master/baselines) found that smaller eps helps improve the performance. Also, a large batchsize (say 64 instead of 32) will help faster convergence but will slow down the training speed. +| PongNoFrameskip-v4 | 20 | ![](results/dqn/Pong_rew.png) | `python3 atari_dqn.py --task "PongNoFrameskip-v4" --batch-size 64` | ~30 min (~15 epoch) | +| BreakoutNoFrameskip-v4 | 316 | ![](results/dqn/Breakout_rew.png) | `python3 atari_dqn.py --task "BreakoutNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) | +| EnduroNoFrameskip-v4 | 670 | ![](results/dqn/Enduro_rew.png) | `python3 atari_dqn.py --task "EnduroNoFrameskip-v4 " --test-num 100` | 3~4h (100 epoch) | +| QbertNoFrameskip-v4 | 7307 | ![](results/dqn/Qbert_rew.png) | `python3 atari_dqn.py --task "QbertNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) | +| MsPacmanNoFrameskip-v4 | 2107 | ![](results/dqn/MsPacman_rew.png) | `python3 atari_dqn.py --task "MsPacmanNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) | +| SeaquestNoFrameskip-v4 | 2088 | ![](results/dqn/Seaquest_rew.png) | `python3 atari_dqn.py --task "SeaquestNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) | +| SpaceInvadersNoFrameskip-v4 | 812.2 | ![](results/dqn/SpaceInvader_rew.png) | `python3 atari_dqn.py --task "SpaceInvadersNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) | + +Note: The `eps_train_final` and `eps_test` in the original DQN paper is 0.1 and 0.01, but [some works](https://github.com/google/dopamine/tree/master/baselines) found that smaller eps helps improve the performance. Also, a large batchsize (say 64 instead of 32) will help faster convergence but will slow down the training speed. We haven't tuned this result to the best, so have fun with playing these hyperparameters! diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 6e9de5b49..5ba0c3838 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -18,20 +18,20 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--eps_test', type=float, default=0.005) - parser.add_argument('--eps_train', type=float, default=1.) - parser.add_argument('--eps_train_final', type=float, default=0.05) + parser.add_argument('--eps-test', type=float, default=0.005) + parser.add_argument('--eps-train', type=float, default=1.) + parser.add_argument('--eps-train-final', type=float, default=0.05) parser.add_argument('--buffer-size', type=int, default=100000) parser.add_argument('--lr', type=float, default=0.0001) parser.add_argument('--gamma', type=float, default=0.99) - parser.add_argument('--n_step', type=int, default=3) - parser.add_argument('--target_update_freq', type=int, default=500) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step_per_epoch', type=int, default=10000) - parser.add_argument('--collect_per_step', type=int, default=10) - parser.add_argument('--batch_size', type=int, default=32) - parser.add_argument('--training_num', type=int, default=16) - parser.add_argument('--test_num', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--batch-size', type=int, default=32) + parser.add_argument('--training-num', type=int, default=16) + parser.add_argument('--test-num', type=int, default=10) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) parser.add_argument( diff --git a/examples/mujoco/ant_v2_ddpg.py b/examples/mujoco/ant_v2_ddpg.py index 528ae4bf4..dd4486de7 100644 --- a/examples/mujoco/ant_v2_ddpg.py +++ b/examples/mujoco/ant_v2_ddpg.py @@ -6,11 +6,11 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DDPGPolicy -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer from tianshou.env import SubprocVectorEnv -from tianshou.exploration import GaussianNoise from tianshou.utils.net.common import Net +from tianshou.trainer import offpolicy_trainer +from tianshou.exploration import GaussianNoise +from tianshou.data import Collector, ReplayBuffer from tianshou.utils.net.continuous import Actor, Critic diff --git a/examples/mujoco/ant_v2_sac.py b/examples/mujoco/ant_v2_sac.py index 156be7e37..819c7454e 100644 --- a/examples/mujoco/ant_v2_sac.py +++ b/examples/mujoco/ant_v2_sac.py @@ -7,10 +7,10 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import SACPolicy -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net +from tianshou.trainer import offpolicy_trainer +from tianshou.data import Collector, ReplayBuffer from tianshou.utils.net.continuous import ActorProb, Critic diff --git a/examples/mujoco/ant_v2_td3.py b/examples/mujoco/ant_v2_td3.py index 370e591c0..9e43f1027 100644 --- a/examples/mujoco/ant_v2_td3.py +++ b/examples/mujoco/ant_v2_td3.py @@ -6,11 +6,11 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import TD3Policy -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer from tianshou.env import SubprocVectorEnv -from tianshou.exploration import GaussianNoise from tianshou.utils.net.common import Net +from tianshou.exploration import GaussianNoise +from tianshou.trainer import offpolicy_trainer +from tianshou.data import Collector, ReplayBuffer from tianshou.utils.net.continuous import Actor, Critic diff --git a/examples/mujoco/halfcheetahBullet_v0_sac.py b/examples/mujoco/halfcheetahBullet_v0_sac.py index 41c33dcf3..05591676f 100644 --- a/examples/mujoco/halfcheetahBullet_v0_sac.py +++ b/examples/mujoco/halfcheetahBullet_v0_sac.py @@ -4,18 +4,14 @@ import pprint import argparse import numpy as np +import pybullet_envs from torch.utils.tensorboard import SummaryWriter -from tianshou.env import SubprocVectorEnv from tianshou.policy import SACPolicy +from tianshou.utils.net.common import Net +from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer - -try: - import pybullet_envs -except ImportError: - pass -from tianshou.utils.net.common import Net from tianshou.utils.net.continuous import ActorProb, Critic diff --git a/examples/mujoco/point_maze_td3.py b/examples/mujoco/point_maze_td3.py index 48a26e265..ff42716c5 100644 --- a/examples/mujoco/point_maze_td3.py +++ b/examples/mujoco/point_maze_td3.py @@ -6,12 +6,13 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import TD3Policy -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.utils.net.common import Net from tianshou.env import SubprocVectorEnv from tianshou.exploration import GaussianNoise -from tianshou.utils.net.common import Net +from tianshou.trainer import offpolicy_trainer +from tianshou.data import Collector, ReplayBuffer from tianshou.utils.net.continuous import Actor, Critic + from mujoco.register import reg @@ -40,7 +41,6 @@ def get_args(): parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') - parser.add_argument('--max_episode_steps', type=int, default=2000) return parser.parse_args() From d745c44e980ddbe1fd007fde3b821301221e9875 Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Fri, 25 Sep 2020 21:42:09 +0800 Subject: [PATCH 7/8] os.path --- examples/atari/runnable/pong_a2c.py | 3 ++- examples/atari/runnable/pong_ppo.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/atari/runnable/pong_a2c.py b/examples/atari/runnable/pong_a2c.py index 55ed15a0d..290d11d9c 100644 --- a/examples/atari/runnable/pong_a2c.py +++ b/examples/atari/runnable/pong_a2c.py @@ -1,3 +1,4 @@ +import os import torch import pprint import argparse @@ -76,7 +77,7 @@ def test_a2c(args=get_args()): preprocess_fn=preprocess_fn) test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn) # log - writer = SummaryWriter(args.logdir + '/a2c') + writer = SummaryWriter(os.path.join(args.logdir, args.task, 'a2c')) def stop_fn(mean_rewards): if env.env.spec.reward_threshold: diff --git a/examples/atari/runnable/pong_ppo.py b/examples/atari/runnable/pong_ppo.py index 109e8130b..4e898c8bf 100644 --- a/examples/atari/runnable/pong_ppo.py +++ b/examples/atari/runnable/pong_ppo.py @@ -1,3 +1,4 @@ +import os import torch import pprint import argparse @@ -80,7 +81,7 @@ def test_ppo(args=get_args()): preprocess_fn=preprocess_fn) test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn) # log - writer = SummaryWriter(args.logdir + '/ppo') + writer = SummaryWriter(os.path.join(args.logdir, args.task, 'ppo')) def stop_fn(mean_rewards): if env.env.spec.reward_threshold: From 4b54dad94ff0f935e9b42782edd1562fd5c5ae0e Mon Sep 17 00:00:00 2001 From: Trinkle23897 Date: Sat, 26 Sep 2020 06:59:52 +0800 Subject: [PATCH 8/8] num_env_step -> env_step --- README.md | 4 ++-- docs/tutorials/dqn.rst | 4 ++-- docs/tutorials/tictactoe.rst | 4 ++-- examples/atari/atari_dqn.py | 10 +++++----- examples/box2d/acrobot_dualdqn.py | 10 +++++----- examples/box2d/lunarlander_dqn.py | 6 +++--- test/discrete/test_dqn.py | 10 +++++----- test/discrete/test_drqn.py | 4 ++-- test/multiagent/tic_tac_toe.py | 4 ++-- tianshou/data/buffer.py | 3 ++- 10 files changed, 30 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 5ff434b2f..21e8cdbc7 100644 --- a/README.md +++ b/README.md @@ -230,8 +230,8 @@ Let's train it: result = ts.trainer.offpolicy_trainer( policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step, test_num, batch_size, - train_fn=lambda epoch, num_env_step: policy.set_eps(eps_train), - test_fn=lambda epoch, num_env_step: policy.set_eps(eps_test), + train_fn=lambda epoch, env_step: policy.set_eps(eps_train), + test_fn=lambda epoch, env_step: policy.set_eps(eps_test), stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, writer=writer, task=task) print(f'Finished training! Use {result["duration"]}') diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index a73db48e3..49f6260b0 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -123,8 +123,8 @@ Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tians policy, train_collector, test_collector, max_epoch=10, step_per_epoch=1000, collect_per_step=10, episode_per_test=100, batch_size=64, - train_fn=lambda epoch, num_env_step: policy.set_eps(0.1), - test_fn=lambda epoch, num_env_step: policy.set_eps(0.05), + train_fn=lambda epoch, env_step: policy.set_eps(0.1), + test_fn=lambda epoch, env_step: policy.set_eps(0.05), stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, writer=None) print(f'Finished training! Use {result["duration"]}') diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst index a9c5b6e05..a511af229 100644 --- a/docs/tutorials/tictactoe.rst +++ b/docs/tutorials/tictactoe.rst @@ -339,10 +339,10 @@ With the above preparation, we are close to the first learned agent. The followi # the default args.win_rate is 0.9, but the reward is [-1, 1] # instead of [0, 1], so args.win_rate == 0.9 is equal to 95% win rate. - def train_fn(epoch, num_env_step): + def train_fn(epoch, env_step): policy.policies[args.agent_id - 1].set_eps(args.eps_train) - def test_fn(epoch, num_env_step): + def test_fn(epoch, env_step): policy.policies[args.agent_id - 1].set_eps(args.eps_test) # start training, this may require about three minutes diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 5ba0c3838..e44890308 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -103,17 +103,17 @@ def stop_fn(mean_rewards): else: return False - def train_fn(epoch, num_env_step): + def train_fn(epoch, env_step): # nature DQN setting, linear decay in the first 1M steps - if num_env_step <= 1e6: - eps = args.eps_train - num_env_step / 1e6 * \ + if env_step <= 1e6: + eps = args.eps_train - env_step / 1e6 * \ (args.eps_train - args.eps_train_final) else: eps = args.eps_train_final policy.set_eps(eps) - writer.add_scalar('train/eps', eps, global_step=num_env_step) + writer.add_scalar('train/eps', eps, global_step=env_step) - def test_fn(epoch, num_env_step): + def test_fn(epoch, env_step): policy.set_eps(args.eps_test) # watch agent's performance diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 85c15726e..4408283ad 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -78,17 +78,17 @@ def save_fn(policy): def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold - def train_fn(epoch, num_env_step): - if num_env_step <= 100000: + def train_fn(epoch, env_step): + if env_step <= 100000: policy.set_eps(args.eps_train) - elif num_env_step <= 500000: - eps = args.eps_train - (num_env_step - 100000) / \ + elif env_step <= 500000: + eps = args.eps_train - (env_step - 100000) / \ 400000 * (0.5 * args.eps_train) policy.set_eps(eps) else: policy.set_eps(0.5 * args.eps_train) - def test_fn(epoch, num_env_step): + def test_fn(epoch, env_step): policy.set_eps(args.eps_test) # trainer diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index f959a1bd3..0bdb2283c 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -80,11 +80,11 @@ def save_fn(policy): def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold - def train_fn(epoch, num_env_step): # exp decay - eps = max(args.eps_train * (1 - 5e-6) ** num_env_step, args.eps_test) + def train_fn(epoch, env_step): # exp decay + eps = max(args.eps_train * (1 - 5e-6) ** env_step, args.eps_test) policy.set_eps(eps) - def test_fn(epoch, num_env_step): + def test_fn(epoch, env_step): policy.set_eps(args.eps_test) # trainer diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 8be03ae53..7564c08cb 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -88,18 +88,18 @@ def save_fn(policy): def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold - def train_fn(epoch, num_env_step): + def train_fn(epoch, env_step): # eps annnealing, just a demo - if num_env_step <= 10000: + if env_step <= 10000: policy.set_eps(args.eps_train) - elif num_env_step <= 50000: - eps = args.eps_train - (num_env_step - 10000) / \ + elif env_step <= 50000: + eps = args.eps_train - (env_step - 10000) / \ 40000 * (0.9 * args.eps_train) policy.set_eps(eps) else: policy.set_eps(0.1 * args.eps_train) - def test_fn(epoch, num_env_step): + def test_fn(epoch, env_step): policy.set_eps(args.eps_test) # trainer diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 5667e8c11..f3f00e69f 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -82,10 +82,10 @@ def save_fn(policy): def stop_fn(mean_rewards): return mean_rewards >= env.spec.reward_threshold - def train_fn(epoch, num_env_step): + def train_fn(epoch, env_step): policy.set_eps(args.eps_train) - def test_fn(epoch, num_env_step): + def test_fn(epoch, env_step): policy.set_eps(args.eps_test) # trainer diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index fc82246c9..13218333a 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -147,10 +147,10 @@ def save_fn(policy): def stop_fn(mean_rewards): return mean_rewards >= args.win_rate - def train_fn(epoch, num_env_step): + def train_fn(epoch, env_step): policy.policies[args.agent_id - 1].set_eps(args.eps_train) - def test_fn(epoch, num_env_step): + def test_fn(epoch, env_step): policy.policies[args.agent_id - 1].set_eps(args.eps_test) # trainer diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 8cf4f6bc9..c8c572a22 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -230,7 +230,8 @@ def add( obs = obs[-1] self._add_to_buffer("obs", obs) self._add_to_buffer("act", act) - self._add_to_buffer("rew", rew) + # make sure the reward is a float instead of an int + self._add_to_buffer("rew", rew * 1.0) # type: ignore self._add_to_buffer("done", done) if self._save_s_: if obs_next is None: