Add BranchDQN for large discrete action spaces #618

Merged on May 15, 2022
54 commits
1d94337
Fixed hardcoded reward_threshold
Mar 2, 2022
11c2875
Added reward_threshold as argument for offline test scripts
Mar 2, 2022
8288a10
reward_threshold as argument for offline test scripts
Mar 2, 2022
8467afd
reward_threshold defaults to None
Mar 3, 2022
c5daa06
correction
Mar 3, 2022
b5c2d1d
BDQ implemented, not tested yet
Mar 11, 2022
abc5db2
Implement BDQ algorithm
Apr 4, 2022
c4f1d21
Add bdq to policy package
Apr 4, 2022
64e1c3f
Change some training parameters
Apr 27, 2022
da59357
Deleted test_d4rl.py file that was committed by mistake
Apr 27, 2022
b3e582b
Documents formatted with yapf
Apr 27, 2022
250ac5c
Pass tianshou contrib requirements
Apr 27, 2022
424de0b
Merge commit '250ac5cec2ee60dd94f0480b245bf48efbde11ba'
Apr 27, 2022
7a4c69d
Merge branch 'master' of https://github.com/thu-ml/tianshou
Apr 27, 2022
6542e88
Fixed merge conflicts
Apr 27, 2022
1d1ec6f
Merge commit '6542e8805d576dc0908b325b686ee47925739321'
Apr 27, 2022
abe26b9
Remove unused imports
Apr 27, 2022
89d9c6e
Merge commit 'abe26b9db47d41b1f495ad8cb5528700e151e1d3'
Apr 27, 2022
3ccf47e
Merge branch 'master' into bdq
Apr 27, 2022
dba85b4
Fix some linting issues
Apr 29, 2022
af6b5dc
revert all unrelated things
Trinkle23897 Apr 29, 2022
1b705d0
revert again
Trinkle23897 Apr 29, 2022
61e08a2
Merge branch 'master' into bdq
Trinkle23897 Apr 29, 2022
74be256
update
Trinkle23897 Apr 29, 2022
1c5c2fb
Change name of env wrapper to ContinuousToDiscrete
Apr 29, 2022
605f8b6
Fix magic number, now it's an argument
Apr 29, 2022
f92df57
BDQPolicy -> BranchingDQNPolicy, BDQNet -> BranchingNet
Apr 29, 2022
65a03eb
Accommodate for one-dimensional discrete action space
May 9, 2022
2590ce3
Test bdq on pendulum environment
May 9, 2022
aabe4a2
Run local contrib tests
May 9, 2022
39e6772
Merge branch 'master' into bdq
Trinkle23897 May 9, 2022
05fcf50
fix code format and resolve linter complaints
May 14, 2022
480b064
fix code format and resolve linter complaints
May 14, 2022
8d78b88
Add BranchingDQN (BDQ) entry
May 14, 2022
1f09c6d
Add BDQ bipedal result
May 14, 2022
864f214
Merge branch 'bdq' of https://github.com/BFAnas/tianshou into bdq
May 14, 2022
fcb0c76
Add type: ignore, and v1 for Pendulum
May 14, 2022
0bc2d20
Fix lint two spaces comment error
May 14, 2022
0f5a8c3
Resolve docstyle error
May 14, 2022
7ee8939
Resolve docstyle error: indentation
May 14, 2022
ae5709a
Fix torch.to(device) error
May 14, 2022
1706b64
Merge branch 'master' into bdq
Trinkle23897 May 14, 2022
8d565ca
polish
Trinkle23897 May 14, 2022
29a01b8
fix spelling error
Trinkle23897 May 14, 2022
db193a5
fix lint
Trinkle23897 May 14, 2022
baec43f
Tuned hyperparameters for BDQ test
May 14, 2022
8bb0ba7
Pendulum-v1 not v0
May 14, 2022
df75756
make test faster
Trinkle23897 May 14, 2022
cc81575
simplify cql code
Trinkle23897 May 15, 2022
b6d7a9e
sync
Trinkle23897 May 15, 2022
47a2376
Update tianshou/policy/modelfree/bdq.py
BFAnas May 15, 2022
3b40b47
Revert to_torch_as, to resolve type error
May 15, 2022
8060866
add no_type_check for to_*
Trinkle23897 May 15, 2022
5b6cd12
use to_torch in bdq
Trinkle23897 May 15, 2022
1 change: 1 addition & 0 deletions README.md
@@ -12,6 +12,7 @@
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
- [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
- [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf)
- [Branching DQN](https://arxiv.org/pdf/1711.08946.pdf)
- [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf)
- [Rainbow DQN (Rainbow)](https://arxiv.org/pdf/1710.02298.pdf)
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
8 changes: 8 additions & 0 deletions docs/api/tianshou.env.rst
@@ -49,6 +49,14 @@ RayVectorEnv
Wrapper
-------

ContinuousToDiscrete
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: tianshou.env.ContinuousToDiscrete
:members:
:undoc-members:
:show-inheritance:

VectorEnvWrapper
~~~~~~~~~~~~~~~~

5 changes: 5 additions & 0 deletions docs/api/tianshou.policy.rst
@@ -25,6 +25,11 @@ DQN Family
:undoc-members:
:show-inheritance:

.. autoclass:: tianshou.policy.BranchingDQNPolicy
:members:
:undoc-members:
:show-inheritance:

.. autoclass:: tianshou.policy.C51Policy
:members:
:undoc-members:
1 change: 1 addition & 0 deletions docs/index.rst
@@ -12,6 +12,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN <https://arxiv.org/pdf/1511.06581.pdf>`_
* :class:`~tianshou.policy.BranchingDQNPolicy` `Branching DQN <https://arxiv.org/pdf/1711.08946.pdf>`_
* :class:`~tianshou.policy.C51Policy` `Categorical DQN <https://arxiv.org/pdf/1707.06887.pdf>`_
* :class:`~tianshou.policy.RainbowPolicy` `Rainbow DQN <https://arxiv.org/pdf/1710.02298.pdf>`_
* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
7 changes: 7 additions & 0 deletions examples/box2d/README.md
@@ -4,3 +4,10 @@
- If the done penalty is not removed, it converges much slower than before, about 200 epochs (20M env steps) to reach the same performance (\~200 reward)

![](results/sac/BipedalHardcore.png)


# BipedalWalker-BDQ

- To demonstrate BDQ's ability to scale to large discrete action spaces, we run it on a discretized version of the BipedalWalker-v3 environment with 25 possible actions per dimension, for a total of 25^4 = 390,625 joint actions. A standard DQN architecture would need 25^4 output neurons for its Q-network, scaling exponentially with the number of action dimensions, whereas the branching architecture scales linearly and needs only 25*4 = 100 output neurons (see the sketch below).

![](results/bdq/BipedalWalker.png)
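
The snippet below is a minimal, illustrative sketch of where that head-size difference comes from; it is not the repository's `BranchingNet` implementation, and the 24-dimensional observation and 256-unit shared layer are assumptions chosen for the example (the 4 branches and 25 bins match the setup above).

```python
# Illustrative sketch only, not tianshou's BranchingNet: contrast the output
# head of a flat DQN with a branching head for the discretized BipedalWalker-v3
# setup above (4 action dimensions, 25 bins each).
import torch
import torch.nn as nn

obs_dim, hidden = 24, 256              # assumed sizes for this example
num_branches, bins_per_branch = 4, 25  # matches --action-per-branch above

shared = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())

# Flat DQN head: one Q-value per joint action -> 25**4 = 390,625 outputs.
flat_head = nn.Linear(hidden, bins_per_branch ** num_branches)

# Branching head: one small head per action dimension -> 4 * 25 = 100 outputs.
branch_heads = nn.ModuleList(
    nn.Linear(hidden, bins_per_branch) for _ in range(num_branches)
)

feats = shared(torch.randn(1, obs_dim))
q_per_branch = torch.stack([head(feats) for head in branch_heads], dim=1)
print(flat_head.out_features)  # 390625
print(q_per_branch.shape)      # torch.Size([1, 4, 25])
```

Greedy action selection then takes an argmax over each branch independently, so both the output-layer parameter count and the action-selection cost grow linearly with the number of action dimensions.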
163 changes: 163 additions & 0 deletions examples/box2d/bipedal_bdq.py
@@ -0,0 +1,163 @@
import argparse
import datetime
import os
import pprint

import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ContinuousToDiscrete, SubprocVectorEnv
from tianshou.policy import BranchingDQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import BranchingNet


def get_args():
parser = argparse.ArgumentParser()
# task
parser.add_argument("--task", type=str, default="BipedalWalker-v3")
# network architecture
parser.add_argument(
"--common-hidden-sizes", type=int, nargs="*", default=[512, 256]
)
parser.add_argument("--action-hidden-sizes", type=int, nargs="*", default=[128])
parser.add_argument("--value-hidden-sizes", type=int, nargs="*", default=[128])
parser.add_argument("--action-per-branch", type=int, default=25)
# training hyperparameters
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--eps-test", type=float, default=0.)
parser.add_argument("--eps-train", type=float, default=0.73)
parser.add_argument("--eps-decay", type=float, default=5e-6)
parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--target-update-freq", type=int, default=1000)
parser.add_argument("--epoch", type=int, default=1000)
parser.add_argument("--step-per-epoch", type=int, default=80000)
parser.add_argument("--step-per-collect", type=int, default=16)
parser.add_argument("--update-per-step", type=float, default=0.0625)
parser.add_argument("--batch-size", type=int, default=512)
parser.add_argument("--training-num", type=int, default=20)
parser.add_argument("--test-num", type=int, default=10)
# other
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
return parser.parse_args()


def test_bdq(args=get_args()):
env = gym.make(args.task)
env = ContinuousToDiscrete(env, args.action_per_branch)

args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.num_branches = args.action_shape if isinstance(args.action_shape,
int) else args.action_shape[0]

print("Observations shape:", args.state_shape)
print("Num branches:", args.num_branches)
print("Actions per branch:", args.action_per_branch)

# train_envs = ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
    # you can also use tianshou.env.DummyVectorEnv
train_envs = SubprocVectorEnv(
[
lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
for _ in range(args.training_num)
]
)
# test_envs = ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
test_envs = SubprocVectorEnv(
[
lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
for _ in range(args.test_num)
]
)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
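    # BranchingNet: a shared trunk (common_hidden_sizes) feeds a state-value
    # branch (value_hidden_sizes) and one advantage branch per action dimension
    # (action_hidden_sizes), following the dueling layout of the BDQ paper.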
net = BranchingNet(
args.state_shape,
args.num_branches,
args.action_per_branch,
args.common_hidden_sizes,
args.value_hidden_sizes,
args.action_hidden_sizes,
device=args.device,
).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = BranchingDQNPolicy(
net, optim, args.gamma, target_update_freq=args.target_update_freq
)
# collector
train_collector = Collector(
policy,
train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
exploration_noise=True
)
test_collector = Collector(policy, test_envs, exploration_noise=False)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_path = os.path.join(args.logdir, "bdq", args.task, current_time)
writer = SummaryWriter(log_path)
logger = TensorboardLogger(writer)

def save_best_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

def train_fn(epoch, env_step): # exp decay
eps = max(args.eps_train * (1 - args.eps_decay)**env_step, args.eps_test)
policy.set_eps(eps)

def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)

# trainer
result = offpolicy_trainer(
policy,
train_collector,
test_collector,
args.epoch,
args.step_per_epoch,
args.step_per_collect,
args.test_num,
args.batch_size,
update_per_step=args.update_per_step,
# stop_fn=stop_fn,
train_fn=train_fn,
test_fn=test_fn,
save_best_fn=save_best_fn,
logger=logger
)

# assert stop_fn(result["best_reward"])
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


if __name__ == "__main__":
test_bdq(get_args())
Binary file added examples/box2d/results/bdq/BipedalWalker.png
150 changes: 150 additions & 0 deletions test/discrete/test_bdq.py
@@ -0,0 +1,150 @@
import argparse
import pprint

import gym
import numpy as np
import torch

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ContinuousToDiscrete, DummyVectorEnv
from tianshou.policy import BranchingDQNPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils.net.common import BranchingNet


def get_args():
parser = argparse.ArgumentParser()
# task
parser.add_argument("--task", type=str, default="Pendulum-v1")
parser.add_argument('--reward-threshold', type=float, default=None)
# network architecture
parser.add_argument("--common-hidden-sizes", type=int, nargs="*", default=[64, 64])
parser.add_argument("--action-hidden-sizes", type=int, nargs="*", default=[64])
parser.add_argument("--value-hidden-sizes", type=int, nargs="*", default=[64])
parser.add_argument("--action-per-branch", type=int, default=40)
# training hyperparameters
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.01)
parser.add_argument("--eps-train", type=float, default=0.76)
parser.add_argument("--eps-decay", type=float, default=1e-4)
parser.add_argument("--buffer-size", type=int, default=20000)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--gamma", type=float, default=0.9)
parser.add_argument("--target-update-freq", type=int, default=200)
parser.add_argument("--epoch", type=int, default=10)
parser.add_argument("--step-per-epoch", type=int, default=80000)
parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--training-num", type=int, default=10)
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
args = parser.parse_known_args()[0]
return args


def test_bdq(args=get_args()):
env = gym.make(args.task)
env = ContinuousToDiscrete(env, args.action_per_branch)

args.state_shape = env.observation_space.shape or env.observation_space.n
args.num_branches = env.action_space.shape[0]

if args.reward_threshold is None:
default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250}
args.reward_threshold = default_reward_threshold.get(
args.task, env.spec.reward_threshold
)

print("Observations shape:", args.state_shape)
print("Num branches:", args.num_branches)
print("Actions per branch:", args.action_per_branch)

train_envs = DummyVectorEnv(
[
lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
for _ in range(args.training_num)
]
)
test_envs = DummyVectorEnv(
[
lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
for _ in range(args.test_num)
]
)

# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = BranchingNet(
args.state_shape,
args.num_branches,
args.action_per_branch,
args.common_hidden_sizes,
args.value_hidden_sizes,
args.action_hidden_sizes,
device=args.device,
).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = BranchingDQNPolicy(
net, optim, args.gamma, target_update_freq=args.target_update_freq
)
# collector
train_collector = Collector(
policy,
train_envs,
VectorReplayBuffer(args.buffer_size, args.training_num),
exploration_noise=True
)
test_collector = Collector(policy, test_envs, exploration_noise=False)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)

def train_fn(epoch, env_step): # exp decay
eps = max(args.eps_train * (1 - args.eps_decay)**env_step, args.eps_test)
policy.set_eps(eps)

def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)

def stop_fn(mean_rewards):
return mean_rewards >= args.reward_threshold

# trainer
result = offpolicy_trainer(
policy,
train_collector,
test_collector,
args.epoch,
args.step_per_epoch,
args.step_per_collect,
args.test_num,
args.batch_size,
update_per_step=args.update_per_step,
train_fn=train_fn,
test_fn=test_fn,
stop_fn=stop_fn,
)

# assert stop_fn(result["best_reward"])
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


if __name__ == "__main__":
test_bdq(get_args())