Add QR-DQN algorithm #276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Jan 28, 2021
1 change: 1 addition & 0 deletions README.md
@@ -24,6 +24,7 @@
- [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
- [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf)
- [C51](https://arxiv.org/pdf/1707.06887.pdf)
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
1 change: 1 addition & 0 deletions docs/index.rst
@@ -14,6 +14,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN <https://arxiv.org/pdf/1511.06581.pdf>`_
* :class:`~tianshou.policy.C51Policy` `C51 <https://arxiv.org/pdf/1707.06887.pdf>`_
* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
15 changes: 14 additions & 1 deletion examples/atari/README.md
@@ -40,6 +40,20 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.

Note: The selection of `n_step` is based on Figure 6 in the [Rainbow](https://arxiv.org/abs/1710.02298) paper.

# QRDQN (single run)

One epoch here is equal to 100,000 env steps; 100 epochs stand for 10M env steps.

| task | best reward | reward curve | parameters |
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20 | ![](results/qrdqn/Pong_rew.png) | `python3 atari_qrdqn.py --task "PongNoFrameskip-v4" --batch-size 64` |
| BreakoutNoFrameskip-v4 | 409.2 | ![](results/qrdqn/Breakout_rew.png) | `python3 atari_qrdqn.py --task "BreakoutNoFrameskip-v4" --n-step 1` |
| EnduroNoFrameskip-v4 | 1055.9 | ![](results/qrdqn/Enduro_rew.png) | `python3 atari_qrdqn.py --task "EnduroNoFrameskip-v4"` |
| QbertNoFrameskip-v4 | 14990 | ![](results/qrdqn/Qbert_rew.png) | `python3 atari_qrdqn.py --task "QbertNoFrameskip-v4"` |
| MsPacmanNoFrameskip-v4 | 2886 | ![](results/qrdqn/MsPacman_rew.png) | `python3 atari_qrdqn.py --task "MsPacmanNoFrameskip-v4"` |
| SeaquestNoFrameskip-v4 | 5676 | ![](results/qrdqn/Seaquest_rew.png) | `python3 atari_qrdqn.py --task "SeaquestNoFrameskip-v4"` |
| SpaceInvadersNoFrameskip-v4 | 938 | ![](results/qrdqn/SpaceInvader_rew.png) | `python3 atari_qrdqn.py --task "SpaceInvadersNoFrameskip-v4"` |

# BCQ

TODO: after the `done` issue is fixed, the results should be re-tuned and placed here.
@@ -49,4 +63,3 @@ To run the BCQ algorithm on Atari, you need to do the following things:
- Train an expert using the command listed in the DQN section above;
- Generate a buffer with noise: `python3 atari_dqn.py --task {your_task} --watch --resume-path log/{your_task}/dqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M-transition Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
- Train BCQ: `python3 atari_bcq.py --task {your_task} --load-buffer-name expert.hdf5`.

39 changes: 36 additions & 3 deletions examples/atari/atari_network.py
@@ -64,8 +64,8 @@ def __init__(
         num_atoms: int = 51,
         device: Union[str, int, torch.device] = "cpu",
     ) -> None:
-        super().__init__(c, h, w, [np.prod(action_shape) * num_atoms], device)
-        self.action_shape = action_shape
+        self.action_num = np.prod(action_shape)
+        super().__init__(c, h, w, [self.action_num * num_atoms], device)
         self.num_atoms = num_atoms
 
     def forward(
@@ -77,5 +77,38 @@ def forward(
         r"""Mapping: x -> Z(x, \*)."""
         x, state = super().forward(x)
         x = x.view(-1, self.num_atoms).softmax(dim=-1)
-        x = x.view(-1, np.prod(self.action_shape), self.num_atoms)
+        x = x.view(-1, self.action_num, self.num_atoms)
         return x, state
+
+
+class QRDQN(DQN):
+    """Reference: Distributional Reinforcement Learning with Quantile \
+    Regression.
+
+    For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+    """
+
+    def __init__(
+        self,
+        c: int,
+        h: int,
+        w: int,
+        action_shape: Sequence[int],
+        num_quantiles: int = 200,
+        device: Union[str, int, torch.device] = "cpu",
+    ) -> None:
+        self.action_num = np.prod(action_shape)
+        super().__init__(c, h, w, [self.action_num * num_quantiles], device)
+        self.num_quantiles = num_quantiles
+
+    def forward(
+        self,
+        x: Union[np.ndarray, torch.Tensor],
+        state: Optional[Any] = None,
+        info: Dict[str, Any] = {},
+    ) -> Tuple[torch.Tensor, Any]:
+        r"""Mapping: x -> Z(x, \*)."""
+        x, state = super().forward(x)
+        x = x.view(-1, self.action_num, self.num_quantiles)
+        return x, state
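`QRDQN.forward` returns a tensor of shape `(batch, action_num, num_quantiles)`; unlike C51 there is no softmax, because each entry is a quantile value rather than a probability. The policy that consumes this tensor (`tianshou/policy/modelfree/qrdqn.py`) is not shown in this diff, so the following is only a sketch, assuming it follows the QR-DQN paper: the Q-value of each action is the mean of its equally weighted quantiles, and the greedy action is their argmax. It uses NumPy instead of torch, and the function names are hypothetical.

```python
import numpy as np


def q_values_from_quantiles(quantiles: np.ndarray) -> np.ndarray:
    """Collapse quantile estimates into expected returns.

    `quantiles` is assumed to have shape (batch, action_num, num_quantiles),
    i.e. the layout QRDQN.forward produces with
    x.view(-1, action_num, num_quantiles). Every quantile is weighted
    equally (1 / num_quantiles), so Q(s, a) is the mean over the last axis.
    """
    return quantiles.mean(axis=-1)


def greedy_actions(quantiles: np.ndarray) -> np.ndarray:
    """Return argmax_a Q(s, a) for every state in the batch."""
    return q_values_from_quantiles(quantiles).argmax(axis=-1)


# Made-up output for 1 state, 3 actions and 4 quantiles per action.
z = np.array([[[0.0, 1.0, 2.0, 3.0],     # action 0: mean 1.5
               [-1.0, 0.0, 4.0, 5.0],    # action 1: mean 2.0
               [1.0, 1.0, 1.0, 1.0]]])   # action 2: mean 1.0
print(q_values_from_quantiles(z))
print(greedy_actions(z))  # [1]
```

Action 1 wins on the mean even though it has the lowest single quantile, which illustrates that greedy selection over mean quantiles is risk-neutral.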
153 changes: 153 additions & 0 deletions examples/atari/atari_qrdqn.py
@@ -0,0 +1,153 @@
import os
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import QRDQNPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer

from atari_network import QRDQN
from atari_wrapper import wrap_deepmind


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--eps-test', type=float, default=0.005)
    parser.add_argument('--eps-train', type=float, default=1.)
    parser.add_argument('--eps-train-final', type=float, default=0.05)
    parser.add_argument('--buffer-size', type=int, default=100000)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--num-quantiles', type=int, default=200)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=500)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=10000)
    parser.add_argument('--collect-per-step', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--training-num', type=int, default=16)
    parser.add_argument('--test-num', type=int, default=10)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--frames-stack', type=int, default=4)
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument('--watch', default=False, action='store_true',
                        help='watch the play of pre-trained policy only')
    return parser.parse_args()


def make_atari_env(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack,
                         episode_life=False, clip_rewards=False)


def test_qrdqn(args=get_args()):
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.env.action_space.shape or env.env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
                                   for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
                                  for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    net = QRDQN(*args.state_shape, args.action_shape,
                args.num_quantiles, args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = QRDQNPolicy(
        net, optim, args.gamma, args.num_quantiles,
        args.n_step, target_update_freq=args.target_update_freq
    ).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(
            args.resume_path, map_location=args.device
        ))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_only_last_obs` and `stack_num` can be removed
    # together when you have enough RAM
    buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True,
                          save_only_last_obs=True, stack_num=args.frames_stack)
    # collector
    train_collector = Collector(policy, train_envs, buffer)
    test_collector = Collector(policy, test_envs)
    # log
    log_path = os.path.join(args.logdir, args.task, 'qrdqn')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return mean_rewards >= 20
        else:
            return False

    def train_fn(epoch, env_step):
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        writer.add_scalar('train/eps', eps, global_step=env_step)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Testing agent ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=[1] * args.test_num,
                                        render=args.render)
        pprint.pprint(result)

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * 4)
    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False)

    pprint.pprint(result)
    watch()


if __name__ == '__main__':
    test_qrdqn(get_args())
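The `train_fn` in `atari_qrdqn.py` anneals epsilon linearly over the first 1M environment steps (the Nature DQN schedule) and then holds it. A minimal sketch of the same schedule as a pure function, which is easier to unit-test than a closure over `policy` and `writer`; the name `nature_dqn_eps` is hypothetical, and its defaults simply mirror `--eps-train 1.0` and `--eps-train-final 0.05`:

```python
def nature_dqn_eps(env_step: int, eps_start: float = 1.0,
                   eps_final: float = 0.05,
                   decay_steps: int = 1_000_000) -> float:
    """Linearly anneal epsilon from eps_start to eps_final, then hold it.

    Hypothetical helper mirroring train_fn in atari_qrdqn.py.
    """
    if env_step <= decay_steps:
        return eps_start - env_step / decay_steps * (eps_start - eps_final)
    return eps_final


for step in (0, 250_000, 500_000, 1_000_000, 5_000_000):
    print(step, round(nature_dqn_eps(step), 4))
```

With these defaults epsilon is roughly 0.525 halfway through the decay and stays at 0.05 after 1M steps; at evaluation time `test_fn` pins it to `--eps-test` instead.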
Binary file added examples/atari/results/qrdqn/Breakout_rew.png
Binary file added examples/atari/results/qrdqn/Enduro_rew.png
Binary file added examples/atari/results/qrdqn/MsPacman_rew.png
Binary file added examples/atari/results/qrdqn/Pong_rew.png
Binary file added examples/atari/results/qrdqn/Qbert_rew.png
Binary file added examples/atari/results/qrdqn/Seaquest_rew.png
136 changes: 136 additions & 0 deletions test/discrete/test_qrdqn.py
@@ -0,0 +1,136 @@
import os
import gym
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import QRDQNPolicy
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--num-quantiles', type=int, default=200)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=10)
    parser.add_argument('--step-per-epoch', type=int, default=1000)
    parser.add_argument('--collect-per-step', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--hidden-sizes', type=int,
                        nargs='*', default=[128, 128, 128, 128])
    parser.add_argument('--training-num', type=int, default=8)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument('--prioritized-replay',
                        action="store_true", default=False)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_known_args()[0]
    return args


def test_qrdqn(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.state_shape, args.action_shape,
              hidden_sizes=args.hidden_sizes, device=args.device,
              softmax=False, num_atoms=args.num_quantiles)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = QRDQNPolicy(
        net, optim, args.gamma, args.num_quantiles,
        args.n_step, target_update_freq=args.target_update_freq
    ).to(args.device)
    # buffer
    if args.prioritized_replay:
        buf = PrioritizedReplayBuffer(
            args.buffer_size, alpha=args.alpha, beta=args.beta)
    else:
        buf = ReplayBuffer(args.buffer_size)
    # collector
    train_collector = Collector(policy, train_envs, buf)
    test_collector = Collector(policy, test_envs)
    # policy.set_eps(1)
    train_collector.collect(n_step=args.batch_size)
    # log
    log_path = os.path.join(args.logdir, args.task, 'qrdqn')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

    def train_fn(epoch, env_step):
        # eps annealing, just a demo
        if env_step <= 10000:
            policy.set_eps(args.eps_train)
        elif env_step <= 50000:
            eps = args.eps_train - (env_step - 10000) / \
                40000 * (0.9 * args.eps_train)
            policy.set_eps(eps)
        else:
            policy.set_eps(0.1 * args.eps_train)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, writer=writer)

    assert stop_fn(result['best_reward'])
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        policy.set_eps(args.eps_test)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')


def test_pqrdqn(args=get_args()):
    args.prioritized_replay = True
    args.gamma = .95
    args.seed = 1
    test_qrdqn(args)


if __name__ == '__main__':
    test_pqrdqn(get_args())
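With `--prioritized-replay`, `test_qrdqn.py` swaps in `PrioritizedReplayBuffer(args.buffer_size, alpha=args.alpha, beta=args.beta)`. The buffer's internals are not part of this diff; the sketch below assumes the proportional scheme from Schaul et al., "Prioritized Experience Replay" (2016), in which `alpha` controls how strongly priorities skew sampling and `beta` how strongly importance-sampling weights correct the resulting bias. The helper names are hypothetical and the code uses NumPy only.

```python
import numpy as np


def per_probs_and_weights(priorities: np.ndarray, alpha: float = 0.6,
                          beta: float = 0.4):
    """Proportional prioritization (assumed scheme, NumPy sketch).

    P(i) = p_i^alpha / sum_k p_k^alpha
    w_i  = (N * P(i))^(-beta) / max_k w_k
    """
    scaled = np.power(priorities, alpha)
    probs = scaled / scaled.sum()
    weights = np.power(len(priorities) * probs, -beta)
    return probs, weights / weights.max()


def sample_batch(priorities: np.ndarray, batch_size: int,
                 alpha: float = 0.6, beta: float = 0.4, seed: int = 0):
    """Draw indices according to P(i) and return their IS weights."""
    probs, weights = per_probs_and_weights(priorities, alpha, beta)
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(priorities), size=batch_size, p=probs)
    return idx, weights[idx]


probs, weights = per_probs_and_weights(np.array([1.0, 4.0]),
                                       alpha=0.5, beta=1.0)
print(probs)    # roughly [0.333, 0.667]
print(weights)  # roughly [1.0, 0.5]
```

Setting `alpha=0` recovers uniform sampling with all weights equal to 1, which is a quick sanity check for this kind of sketch.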
2 changes: 2 additions & 0 deletions tianshou/policy/__init__.py
@@ -3,6 +3,7 @@
from tianshou.policy.imitation.base import ImitationPolicy
from tianshou.policy.modelfree.dqn import DQNPolicy
from tianshou.policy.modelfree.c51 import C51Policy
from tianshou.policy.modelfree.qrdqn import QRDQNPolicy
from tianshou.policy.modelfree.pg import PGPolicy
from tianshou.policy.modelfree.a2c import A2CPolicy
from tianshou.policy.modelfree.ddpg import DDPGPolicy
@@ -21,6 +22,7 @@
    "ImitationPolicy",
    "DQNPolicy",
    "C51Policy",
    "QRDQNPolicy",
    "PGPolicy",
    "A2CPolicy",
    "DDPGPolicy",
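The `__init__.py` change exports `QRDQNPolicy` from `tianshou.policy.modelfree.qrdqn`, but that module's source does not appear on this page. For orientation only, here is a NumPy sketch of the quantile Huber loss described in the QR-DQN paper linked in the README (arXiv 1710.10044), assuming κ = 1 and quantile midpoints τ̂_i = (2i + 1) / (2N). It is not the PR's implementation, and the function name is hypothetical.

```python
import numpy as np


def quantile_huber_loss(pred: np.ndarray, target: np.ndarray,
                        kappa: float = 1.0) -> float:
    """Quantile Huber loss following the QR-DQN paper (NumPy sketch).

    pred:   (batch, N) quantile estimates theta_i for the taken action.
    target: (batch, N) target quantiles, e.g. r + gamma * theta'_j.
    """
    n = pred.shape[-1]
    # Quantile midpoints tau_hat_i = (2i + 1) / (2N), i = 0..N-1
    tau_hat = (np.arange(n) + 0.5) / n
    # Pairwise TD errors u[b, i, j] = target[b, j] - pred[b, i]
    u = target[:, None, :] - pred[:, :, None]
    abs_u = np.abs(u)
    # Huber loss L_kappa(u)
    huber = np.where(abs_u <= kappa, 0.5 * u ** 2,
                     kappa * (abs_u - 0.5 * kappa))
    # Asymmetric quantile weight |tau_hat_i - 1{u < 0}|
    weight = np.abs(tau_hat[None, :, None] - (u < 0).astype(u.dtype))
    rho = weight * huber / kappa
    # Sum over predicted quantiles i, average over targets j and the batch
    return float(rho.sum(axis=1).mean(axis=1).mean())


# Single-quantile check: the median estimate (tau_hat = 0.5) is off by 2.
print(quantile_huber_loss(np.array([[0.0]]), np.array([[2.0]])))  # 0.75
```

With `--num-quantiles 200`, `pred` and `target` would each hold 200 values per sample; the asymmetric weight is what pushes each θ_i toward its own quantile instead of toward the mean.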