
change API of train_fn and test_fn #229


Merged (8 commits) on Sep 26, 2020
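
At a glance, the change is to the hook signatures (a sketch inferred from the diffs below; these type aliases are illustrative and not part of Tianshou's public API):

```python
from typing import Callable, Optional

# Before this PR: train_fn / test_fn received only the epoch index.
OldHookFn = Callable[[int], None]           # train_fn(epoch), test_fn(epoch)

# After this PR: they also receive the cumulative environment step, so
# exploration schedules can be expressed per step rather than per epoch.
NewHookFn = Callable[[int, int], None]      # train_fn(epoch, env_step), test_fn(epoch, env_step)

# stop_fn keeps its shape; only the argument name changes to say what it is.
StopFn = Optional[Callable[[float], bool]]  # stop_fn(mean_rewards)
```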
README.md (8 changes: 5 additions & 3 deletions)
@@ -229,9 +229,11 @@ Let's train it:
```python
result = ts.trainer.offpolicy_trainer(
policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step,
test_num, batch_size, train_fn=lambda e: policy.set_eps(eps_train),
test_fn=lambda e: policy.set_eps(eps_test),
stop_fn=lambda x: x >= env.spec.reward_threshold, writer=writer, task=task)
test_num, batch_size,
train_fn=lambda epoch, env_step: policy.set_eps(eps_train),
test_fn=lambda epoch, env_step: policy.set_eps(eps_test),
stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
writer=writer, task=task)
print(f'Finished training! Use {result["duration"]}')
```
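
The new `env_step` argument is what makes per-step schedules possible inside the callbacks. Below is a minimal sketch of a linear epsilon decay, modeled on the `examples/atari/atari_dqn.py` change later in this PR; `policy` stands for the DQN policy constructed earlier in the README (not defined here) and the 1M-step horizon is just the value used in that example:

```python
eps_train, eps_train_final = 1.0, 0.05  # values borrowed from the atari example below

def train_fn(epoch, env_step):
    # Nature-DQN style schedule: decay epsilon linearly over the first
    # 1M environment steps, then hold it at its final value.
    if env_step <= 1_000_000:
        eps = eps_train - env_step / 1_000_000 * (eps_train - eps_train_final)
    else:
        eps = eps_train_final
    policy.set_eps(eps)  # `policy` is the DQNPolicy built earlier in the README
```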

docs/tutorials/dqn.rst (10 changes: 5 additions & 5 deletions)
@@ -123,9 +123,9 @@ Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tians
policy, train_collector, test_collector,
max_epoch=10, step_per_epoch=1000, collect_per_step=10,
episode_per_test=100, batch_size=64,
train_fn=lambda e: policy.set_eps(0.1),
test_fn=lambda e: policy.set_eps(0.05),
stop_fn=lambda x: x >= env.spec.reward_threshold,
train_fn=lambda epoch, env_step: policy.set_eps(0.1),
test_fn=lambda epoch, env_step: policy.set_eps(0.05),
stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
writer=None)
print(f'Finished training! Use {result["duration"]}')

@@ -136,8 +136,8 @@ The meaning of each parameter is as follows (full description can be found at :m
* ``collect_per_step``: The number of frames the collector would collect before the network update. For example, the code above means "collect 10 frames and do one policy network update";
* ``episode_per_test``: The number of episodes for one policy evaluation.
* ``batch_size``: The batch size of sample data, which is going to feed in the policy network.
* ``train_fn``: A function receives the current number of epoch index and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training".
* ``test_fn``: A function receives the current number of epoch index and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing".
* ``train_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training".
* ``test_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing".
* ``stop_fn``: A function receives the average undiscounted returns of the testing result, return a boolean which indicates whether reaching the goal.
* ``writer``: See below.
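
One practical use of ``env_step`` inside these hooks is as a TensorBoard global step, mirroring the logging added in `examples/atari/atari_dqn.py` below. A small sketch, assuming `policy` and `writer` come from the surrounding tutorial setup and using a purely illustrative schedule:

```python
def train_fn(epoch, env_step):
    # Anneal epsilon by environment step and log the current value so the
    # schedule can be inspected in TensorBoard.
    eps = max(0.1 * (1 - env_step / 100_000), 0.05)
    policy.set_eps(eps)
    writer.add_scalar('train/eps', eps, global_step=env_step)
```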

docs/tutorials/tictactoe.rst (8 changes: 4 additions & 4 deletions)
@@ -334,15 +334,15 @@ With the above preparation, we are close to the first learned agent. The followi
policy.policies[args.agent_id - 1].state_dict(),
model_save_path)

def stop_fn(x):
return x >= args.win_rate # 95% winning rate by default
def stop_fn(mean_rewards):
return mean_rewards >= args.win_rate # 95% winning rate by default
# the default args.win_rate is 0.9, but the reward is [-1, 1]
# instead of [0, 1], so args.win_rate == 0.9 is equal to 95% win rate.

def train_fn(x):
def train_fn(epoch, env_step):
policy.policies[args.agent_id - 1].set_eps(args.eps_train)

def test_fn(x):
def test_fn(epoch, env_step):
policy.policies[args.agent_id - 1].set_eps(args.eps_test)

# start training, this may require about three minutes
examples/atari/README.md (18 changes: 9 additions & 9 deletions)
@@ -12,14 +12,14 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.

| task | best reward | reward curve | parameters | time cost |
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | ------------------- |
| PongNoFrameskip-v4 | 20 | ![](results/dqn/Pong_rew.png) | `python3 atari_dqn.py --task "PongNoFrameskip-v4" --batch_size 64` | ~30 min (~15 epoch) |
| BreakoutNoFrameskip-v4 | 316 | ![](results/dqn/Breakout_rew.png) | `python3 atari_dqn.py --task "BreakoutNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
| EnduroNoFrameskip-v4 | 670 | ![](results/dqn/Enduro_rew.png) | `python3 atari_dqn.py --task "EnduroNoFrameskip-v4 " --test_num 100` | 3~4h (100 epoch) |
| QbertNoFrameskip-v4 | 7307 | ![](results/dqn/Qbert_rew.png) | `python3 atari_dqn.py --task "QbertNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
| MsPacmanNoFrameskip-v4 | 2107 | ![](results/dqn/MsPacman_rew.png) | `python3 atari_dqn.py --task "MsPacmanNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
| SeaquestNoFrameskip-v4 | 2088 | ![](results/dqn/Seaquest_rew.png) | `python3 atari_dqn.py --task "SeaquestNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
| SpaceInvadersNoFrameskip-v4 | 812.2 | ![](results/dqn/SpaceInvader_rew.png) | `python3 atari_dqn.py --task "SpaceInvadersNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |

Note: The eps_train_final and eps_test in the original DQN paper is 0.1 and 0.01, but [some works](https://github.com/google/dopamine/tree/master/baselines) found that smaller eps helps improve the performance. Also, a large batchsize (say 64 instead of 32) will help faster convergence but will slow down the training speed.
| PongNoFrameskip-v4 | 20 | ![](results/dqn/Pong_rew.png) | `python3 atari_dqn.py --task "PongNoFrameskip-v4" --batch-size 64` | ~30 min (~15 epoch) |
| BreakoutNoFrameskip-v4 | 316 | ![](results/dqn/Breakout_rew.png) | `python3 atari_dqn.py --task "BreakoutNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
| EnduroNoFrameskip-v4 | 670 | ![](results/dqn/Enduro_rew.png) | `python3 atari_dqn.py --task "EnduroNoFrameskip-v4 " --test-num 100` | 3~4h (100 epoch) |
| QbertNoFrameskip-v4 | 7307 | ![](results/dqn/Qbert_rew.png) | `python3 atari_dqn.py --task "QbertNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
| MsPacmanNoFrameskip-v4 | 2107 | ![](results/dqn/MsPacman_rew.png) | `python3 atari_dqn.py --task "MsPacmanNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
| SeaquestNoFrameskip-v4 | 2088 | ![](results/dqn/Seaquest_rew.png) | `python3 atari_dqn.py --task "SeaquestNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |
| SpaceInvadersNoFrameskip-v4 | 812.2 | ![](results/dqn/SpaceInvader_rew.png) | `python3 atari_dqn.py --task "SpaceInvadersNoFrameskip-v4" --test-num 100` | 3~4h (100 epoch) |

Note: The `eps_train_final` and `eps_test` in the original DQN paper is 0.1 and 0.01, but [some works](https://github.com/google/dopamine/tree/master/baselines) found that smaller eps helps improve the performance. Also, a large batchsize (say 64 instead of 32) will help faster convergence but will slow down the training speed.

We haven't tuned this result to the best, so have fun with playing these hyperparameters!
examples/atari/atari_dqn.py (37 changes: 18 additions & 19 deletions)
@@ -18,20 +18,20 @@ def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--eps_test', type=float, default=0.005)
parser.add_argument('--eps_train', type=float, default=1.)
parser.add_argument('--eps_train_final', type=float, default=0.05)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--n_step', type=int, default=3)
parser.add_argument('--target_update_freq', type=int, default=500)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step_per_epoch', type=int, default=10000)
parser.add_argument('--collect_per_step', type=int, default=10)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--training_num', type=int, default=16)
parser.add_argument('--test_num', type=int, default=10)
parser.add_argument('--step-per-epoch', type=int, default=10000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
@@ -95,26 +95,25 @@ def test_dqn(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(x):
def stop_fn(mean_rewards):
if env.env.spec.reward_threshold:
return x >= env.spec.reward_threshold
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
return x >= 20
return mean_rewards >= 20
else:
return False

def train_fn(x):
def train_fn(epoch, env_step):
# nature DQN setting, linear decay in the first 1M steps
now = x * args.collect_per_step * args.step_per_epoch
if now <= 1e6:
eps = args.eps_train - now / 1e6 * \
if env_step <= 1e6:
eps = args.eps_train - env_step / 1e6 * \
(args.eps_train - args.eps_train_final)
else:
eps = args.eps_train_final
policy.set_eps(eps)
writer.add_scalar('train/eps', eps, global_step=now)
writer.add_scalar('train/eps', eps, global_step=env_step)

def test_fn(x):
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)

# watch agent's performance
examples/atari/runnable/pong_a2c.py (7 changes: 4 additions & 3 deletions)
@@ -1,3 +1,4 @@
import os
import torch
import pprint
import argparse
@@ -76,11 +77,11 @@ def test_a2c(args=get_args()):
preprocess_fn=preprocess_fn)
test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
# log
writer = SummaryWriter(args.logdir + '/' + 'a2c')
writer = SummaryWriter(os.path.join(args.logdir, args.task, 'a2c'))

def stop_fn(x):
def stop_fn(mean_rewards):
if env.env.spec.reward_threshold:
return x >= env.spec.reward_threshold
return mean_rewards >= env.spec.reward_threshold
else:
return False

examples/atari/runnable/pong_ppo.py (7 changes: 4 additions & 3 deletions)
@@ -1,3 +1,4 @@
import os
import torch
import pprint
import argparse
@@ -80,11 +81,11 @@ def test_ppo(args=get_args()):
preprocess_fn=preprocess_fn)
test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
# log
writer = SummaryWriter(args.logdir + '/' + 'ppo')
writer = SummaryWriter(os.path.join(args.logdir, args.task, 'ppo'))

def stop_fn(x):
def stop_fn(mean_rewards):
if env.env.spec.reward_threshold:
return x >= env.spec.reward_threshold
return mean_rewards >= env.spec.reward_threshold
else:
return False

examples/box2d/acrobot_dualdqn.py (20 changes: 10 additions & 10 deletions)
@@ -6,11 +6,11 @@
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.env import DummyVectorEnv
from tianshou.policy import DQNPolicy
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.utils.net.common import Net


def get_args():
@@ -75,20 +75,20 @@ def test_dqn(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold

def train_fn(x):
if x <= int(0.1 * args.epoch):
def train_fn(epoch, env_step):
if env_step <= 100000:
policy.set_eps(args.eps_train)
elif x <= int(0.5 * args.epoch):
eps = args.eps_train - (x - 0.1 * args.epoch) / \
(0.4 * args.epoch) * (0.5 * args.eps_train)
elif env_step <= 500000:
eps = args.eps_train - (env_step - 100000) / \
400000 * (0.5 * args.eps_train)
policy.set_eps(eps)
else:
policy.set_eps(0.5 * args.eps_train)

def test_fn(x):
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)

# trainer
examples/box2d/bipedal_hardcore_sac.py (8 changes: 4 additions & 4 deletions)
@@ -74,9 +74,6 @@ def step(self, action):
def test_sac_bipedal(args=get_args()):
env = EnvWrapper(args.task)

def IsStop(reward):
return reward >= env.spec.reward_threshold

args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
@@ -141,11 +138,14 @@ def IsStop(reward):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold

# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=IsStop, save_fn=save_fn, writer=writer,
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer,
test_in_train=False)

if __name__ == '__main__':
examples/box2d/lunarlander_dqn.py (14 changes: 7 additions & 7 deletions)
@@ -18,7 +18,7 @@ def get_args():
# the parameters are found by Optuna
parser.add_argument('--task', type=str, default='LunarLander-v2')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-test', type=float, default=0.01)
parser.add_argument('--eps-train', type=float, default=0.73)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.013)
@@ -77,14 +77,14 @@ def test_dqn(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold

def train_fn(x):
args.eps_train = max(args.eps_train * 0.6, 0.01)
policy.set_eps(args.eps_train)
def train_fn(epoch, env_step): # exp decay
eps = max(args.eps_train * (1 - 5e-6) ** env_step, args.eps_test)
policy.set_eps(eps)

def test_fn(x):
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)

# trainer
examples/box2d/mcc_sac.py (4 changes: 2 additions & 2 deletions)
@@ -98,8 +98,8 @@ def test_sac(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold

# trainer
result = offpolicy_trainer(
examples/mujoco/ant_v2_ddpg.py (10 changes: 5 additions & 5 deletions)
@@ -6,11 +6,11 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import DDPGPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.exploration import GaussianNoise
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.exploration import GaussianNoise
from tianshou.data import Collector, ReplayBuffer
from tianshou.utils.net.continuous import Actor, Critic


@@ -77,8 +77,8 @@ def test_ddpg(args=get_args()):
# log
writer = SummaryWriter(args.logdir + '/' + 'ddpg')

def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold

# trainer
result = offpolicy_trainer(
examples/mujoco/ant_v2_sac.py (8 changes: 4 additions & 4 deletions)
@@ -7,10 +7,10 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import SACPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.utils.net.continuous import ActorProb, Critic


@@ -86,8 +86,8 @@ def test_sac(args=get_args()):
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold

# trainer
result = offpolicy_trainer(
examples/mujoco/ant_v2_td3.py (10 changes: 5 additions & 5 deletions)
@@ -6,11 +6,11 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import TD3Policy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.exploration import GaussianNoise
from tianshou.utils.net.common import Net
from tianshou.exploration import GaussianNoise
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.utils.net.continuous import Actor, Critic


@@ -88,8 +88,8 @@ def test_td3(args=get_args()):
# log
writer = SummaryWriter(args.logdir + '/' + 'td3')

def stop_fn(x):
return x >= env.spec.reward_threshold
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold

# trainer
result = offpolicy_trainer(
Expand Down
Loading