这是indexloc提供的服务,不要输入任何密码
Skip to content

Add Rainbow DQN #386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 40 commits into from
Aug 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
8c0a4c0
implement Rainbow DQN
Jun 26, 2021
9d8f565
make linter happy
Jun 30, 2021
45d8bb1
make mypy happy
Jun 30, 2021
46ed079
fix a bug about #381
Trinkle23897 Jun 30, 2021
403404e
address review comments
Jun 30, 2021
1a44548
add a test for rainbow
Jun 30, 2021
a22f474
Merge branch 'master' into rainbow
Trinkle23897 Jul 5, 2021
2762945
fix documentation
Jul 1, 2021
7f7b136
control the timing of sampling noises
Jul 5, 2021
05ac8f2
fix a bug in noisy linear
Jul 5, 2021
45874a4
fix doc and test
Jul 6, 2021
14c9c18
update exp results
nuance1979 Jul 8, 2021
5e6c46d
make pydocstyle happy
nuance1979 Jul 8, 2021
03d3f73
minor fix
Jul 9, 2021
f85a584
minor fix about sample_noise on model_old
Jul 11, 2021
c2c12ce
remove eps hack in prio buffer
Jul 13, 2021
f9d4347
revert eps hack and scale weights instead
Jul 16, 2021
7900450
remove weight scaling by magic number in favor of weight normalization
Jul 18, 2021
1ea40a1
fix test failure
Jul 18, 2021
3772c0f
use np.max() to maximize compatibility
Jul 18, 2021
42b4023
move weight norm to the policy side
Jul 18, 2021
104d476
move weight norm back to buffer side as an option
Jul 19, 2021
a3fc666
anneal beta parameter of prio buffer
Jul 27, 2021
18c1391
cosmetic change
Jul 27, 2021
5641d37
change beta annealing schedule
Jul 31, 2021
9a458d0
update current rainbow results; still bad on some tasks
Aug 2, 2021
f178b0e
fix a minor bug
Aug 5, 2021
d16dbb9
Merge branch 'master' into rainbow
Trinkle23897 Aug 6, 2021
0ed3f21
separate log dirs
Aug 17, 2021
a204fda
Merge branch 'master' into rainbow
nuance1979 Aug 17, 2021
4cb94f2
Merge branch 'rainbow' of https://github.com/nuance1979/tianshou into…
nuance1979 Aug 17, 2021
ed8552f
update results
Aug 22, 2021
96a5b86
update plots
nuance1979 Aug 22, 2021
8599d1e
Merge branch 'master' into rainbow
nuance1979 Aug 22, 2021
2211296
fix test failure
nuance1979 Aug 22, 2021
f2384eb
fix test failure again
nuance1979 Aug 22, 2021
5ecf6d3
fix more test failure
nuance1979 Aug 22, 2021
4d2debf
fix a bug about explore_noise
Aug 24, 2021
c946542
update plots
nuance1979 Aug 24, 2021
a040a6d
make linter happy
Aug 24, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
- [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
- [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf)
- [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf)
- [Rainbow DQN (Rainbow)](https://arxiv.org/pdf/1710.02298.pdf)
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
- [Implicit Quantile Network (IQN)](https://arxiv.org/pdf/1806.06923.pdf)
- [Fully-parameterized Quantile Function (FQF)](https://arxiv.org/pdf/1911.02140.pdf)
Expand Down
5 changes: 5 additions & 0 deletions docs/api/tianshou.policy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ DQN Family
:undoc-members:
:show-inheritance:

.. autoclass:: tianshou.policy.RainbowPolicy
:members:
:undoc-members:
:show-inheritance:

.. autoclass:: tianshou.policy.QRDQNPolicy
:members:
:undoc-members:
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN <https://arxiv.org/pdf/1511.06581.pdf>`_
* :class:`~tianshou.policy.C51Policy` `Categorical DQN <https://arxiv.org/pdf/1707.06887.pdf>`_
* :class:`~tianshou.policy.RainbowPolicy` `Rainbow DQN <https://arxiv.org/pdf/1707.02298.pdf>`_
* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
* :class:`~tianshou.policy.IQNPolicy` `Implicit Quantile Network <https://arxiv.org/pdf/1806.06923.pdf>`_
* :class:`~tianshou.policy.FQFPolicy` `Fully-parameterized Quantile Function <https://arxiv.org/pdf/1911.02140.pdf>`_
Expand Down
14 changes: 14 additions & 0 deletions examples/atari/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,20 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
| SeaquestNoFrameskip-v4 | 10775 | ![](results/fqf/Seaquest_rew.png) | `python3 atari_fqf.py --task "SeaquestNoFrameskip-v4"` |
| SpaceInvadersNoFrameskip-v4 | 2482 | ![](results/fqf/SpaceInvaders_rew.png) | `python3 atari_fqf.py --task "SpaceInvadersNoFrameskip-v4"` |

# Rainbow (single run)

One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.

| task | best reward | reward curve | parameters |
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 21 | ![](results/rainbow/Pong_rew.png) | `python3 atari_rainbow.py --task "PongNoFrameskip-v4" --batch-size 64` |
| BreakoutNoFrameskip-v4 | 684.6 | ![](results/rainbow/Breakout_rew.png) | `python3 atari_rainbow.py --task "BreakoutNoFrameskip-v4" --n-step 1` |
| EnduroNoFrameskip-v4 | 1625.9 | ![](results/rainbow/Enduro_rew.png) | `python3 atari_rainbow.py --task "EnduroNoFrameskip-v4"` |
| QbertNoFrameskip-v4 | 16192.5 | ![](results/rainbow/Qbert_rew.png) | `python3 atari_rainbow.py --task "QbertNoFrameskip-v4"` |
| MsPacmanNoFrameskip-v4 | 3101 | ![](results/rainbow/MsPacman_rew.png) | `python3 atari_rainbow.py --task "MsPacmanNoFrameskip-v4"` |
| SeaquestNoFrameskip-v4 | 2126 | ![](results/rainbow/Seaquest_rew.png) | `python3 atari_rainbow.py --task "SeaquestNoFrameskip-v4"` |
| SpaceInvadersNoFrameskip-v4 | 1794.5 | ![](results/rainbow/SpaceInvaders_rew.png) | `python3 atari_rainbow.py --task "SpaceInvadersNoFrameskip-v4"` |

# BCQ

To running BCQ algorithm on Atari, you need to do the following things:
Expand Down
60 changes: 60 additions & 0 deletions examples/atari/atari_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from torch import nn
from typing import Any, Dict, Tuple, Union, Optional, Sequence
from tianshou.utils.net.discrete import NoisyLinear


class DQN(nn.Module):
Expand Down Expand Up @@ -81,6 +82,65 @@ def forward(
return x, state


class Rainbow(DQN):
"""Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning.

For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
"""

def __init__(
self,
c: int,
h: int,
w: int,
action_shape: Sequence[int],
num_atoms: int = 51,
noisy_std: float = 0.5,
device: Union[str, int, torch.device] = "cpu",
is_dueling: bool = True,
is_noisy: bool = True,
) -> None:
super().__init__(c, h, w, action_shape, device, features_only=True)
self.action_num = np.prod(action_shape)
self.num_atoms = num_atoms

def linear(x, y):
if is_noisy:
return NoisyLinear(x, y, noisy_std)
else:
return nn.Linear(x, y)

self.Q = nn.Sequential(
linear(self.output_dim, 512), nn.ReLU(inplace=True),
linear(512, self.action_num * self.num_atoms))
self._is_dueling = is_dueling
if self._is_dueling:
self.V = nn.Sequential(
linear(self.output_dim, 512), nn.ReLU(inplace=True),
linear(512, self.num_atoms))
self.output_dim = self.action_num * self.num_atoms

def forward(
self,
x: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Z(x, \*)."""
x, state = super().forward(x)
q = self.Q(x)
q = q.view(-1, self.action_num, self.num_atoms)
if self._is_dueling:
v = self.V(x)
v = v.view(-1, 1, self.num_atoms)
logits = q - q.mean(dim=1, keepdim=True) + v
else:
logits = q
y = logits.softmax(dim=2)
return y, state


class QRDQN(DQN):
"""Reference: Distributional Reinforcement Learning with Quantile \
Regression.
Expand Down
204 changes: 204 additions & 0 deletions examples/atari/atari_rainbow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import os
import torch
import pprint
import datetime
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import RainbowPolicy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer

from atari_network import Rainbow
from atari_wrapper import wrap_deepmind


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0000625)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--num-atoms', type=int, default=51)
parser.add_argument('--v-min', type=float, default=-10.)
parser.add_argument('--v-max', type=float, default=10.)
parser.add_argument('--noisy-std', type=float, default=0.1)
parser.add_argument('--no-dueling', action='store_true', default=False)
parser.add_argument('--no-noisy', action='store_true', default=False)
parser.add_argument('--no-priority', action='store_true', default=False)
parser.add_argument('--alpha', type=float, default=0.5)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument('--beta-final', type=float, default=1.)
parser.add_argument('--beta-anneal-step', type=int, default=5000000)
parser.add_argument('--no-weight-norm', action='store_true', default=False)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
parser.add_argument('--save-buffer-name', type=str, default=None)
return parser.parse_args()


def make_atari_env(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack,
episode_life=False, clip_rewards=False)


def test_rainbow(args=get_args()):
env = make_atari_env(args)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
for _ in range(args.training_num)])
test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
net = Rainbow(*args.state_shape, args.action_shape,
args.num_atoms, args.noisy_std, args.device,
is_dueling=not args.no_dueling,
is_noisy=not args.no_noisy)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy = RainbowPolicy(
net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
args.n_step, target_update_freq=args.target_update_freq
).to(args.device)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# replay buffer: `save_last_obs` and `stack_num` can be removed together
# when you have enough RAM
if args.no_priority:
buffer = VectorReplayBuffer(
args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True,
save_only_last_obs=True, stack_num=args.frames_stack)
else:
buffer = PrioritizedVectorReplayBuffer(
args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True,
save_only_last_obs=True, stack_num=args.frames_stack, alpha=args.alpha,
beta=args.beta, weight_norm=not args.no_weight_norm)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_path = os.path.join(
args.logdir, args.task, 'rainbow',
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(mean_rewards):
if env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
return mean_rewards >= 20
else:
return False

def train_fn(epoch, env_step):
# nature DQN setting, linear decay in the first 1M steps
if env_step <= 1e6:
eps = args.eps_train - env_step / 1e6 * \
(args.eps_train - args.eps_train_final)
else:
eps = args.eps_train_final
policy.set_eps(eps)
logger.write('train/eps', env_step, eps)
if not args.no_priority:
if env_step <= args.beta_anneal_step:
beta = args.beta - env_step / args.beta_anneal_step * \
(args.beta - args.beta_final)
else:
beta = args.beta_final
buffer.set_beta(beta)
logger.write('train/beta', env_step, beta)

def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)

# watch agent's performance
def watch():
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
buffer = PrioritizedVectorReplayBuffer(
args.buffer_size, buffer_num=len(test_envs),
ignore_obs_next=True, save_only_last_obs=True,
stack_num=args.frames_stack, alpha=args.alpha,
beta=args.beta)
collector = Collector(policy, test_envs, buffer,
exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num,
render=args.render)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

if args.watch:
watch()
exit(0)

# test train_collector and start filling replay buffer
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step, test_in_train=False)

pprint.pprint(result)
watch()


if __name__ == '__main__':
test_rainbow(get_args())
Binary file added examples/atari/results/rainbow/Breakout_rew.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/atari/results/rainbow/Enduro_rew.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/atari/results/rainbow/MsPacman_rew.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/atari/results/rainbow/Pong_rew.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/atari/results/rainbow/Qbert_rew.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/atari/results/rainbow/Seaquest_rew.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion test/base/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
mask = np.isin(np.arange(buf2.maxsize), indices)
assert np.all(weight[mask] == weight[mask][0])
assert np.all(weight[~mask] == weight[~mask][0])
assert weight[~mask][0] < weight[mask][0] and weight[mask][0] < 1
assert weight[~mask][0] < weight[mask][0] and weight[mask][0] <= 1


def test_update():
Expand Down
2 changes: 2 additions & 0 deletions test/discrete/test_qrdqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def get_args():

def test_qrdqn(args=get_args()):
env = gym.make(args.task)
if args.task == 'CartPole-v0':
env.spec.reward_threshold = 190 # lower the goal
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
Expand Down
2 changes: 1 addition & 1 deletion test/discrete/test_qrdqn_il_cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_discrete_cql(args=get_args()):
# envs
env = gym.make(args.task)
if args.task == 'CartPole-v0':
env.spec.reward_threshold = 190 # lower the goal
env.spec.reward_threshold = 185 # lower the goal
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
test_envs = DummyVectorEnv(
Expand Down
Loading