
Implement ContinuousBCQPolicy and offline_bcq example #480


Merged: 32 commits, Nov 22, 2021

Changes from all commits (32 commits):
- `fb84b4b` finish ContinuousBCQPolicy and mujoco_bcq code (thkkk, Nov 16, 2021)
- `1029f64` finish ContinuousBCQPolicy and mujoco_bcq example code (thkkk, Nov 17, 2021)
- `f4328c2` update docstring (thkkk, Nov 17, 2021)
- `69c9c87` formatted (thkkk, Nov 17, 2021)
- `aba37b2` reset (thkkk, Nov 18, 2021)
- `0f20e18` Implement ContinuousBCQPolicy and offline_bcq example (thkkk, Nov 18, 2021)
- `c16e1da` update some comments (thkkk, Nov 18, 2021)
- `51072e3` Merge branch 'master' into master (Trinkle23897, Nov 18, 2021)
- `615fd01` Add ContinuousBCQ test and update offline_bcq example (thkkk, Nov 19, 2021)
- `cebaf9d` Merge branch 'master' of github.com:thkkk/tianshou (thkkk, Nov 19, 2021)
- `5e407d3` Update ContinuousBCQ test (thkkk, Nov 19, 2021)
- `d1b8e8a` Rename ContinuousBCQ to BCQ (thkkk, Nov 19, 2021)
- `209f887` Add BCQ (thkkk, Nov 20, 2021)
- `632ce9b` fix docs (thkkk, Nov 20, 2021)
- `a2fe98d` fix: update readme of offline example (thkkk, Nov 20, 2021)
- `70c7406` fix docstring (thkkk, Nov 21, 2021)
- `50e7400` modify some comments (thkkk, Nov 21, 2021)
- `eeb2bfa` Add parameters in BCQ (thkkk, Nov 21, 2021)
- `4972e50` Move VAE and Pertubation to utils/net/continuous.py (thkkk, Nov 21, 2021)
- `05d7adf` Add an arg in offline_bcq (thkkk, Nov 21, 2021)
- `8d08351` code format (thkkk, Nov 21, 2021)
- `76a1d83` Add gather_pendulum_data for unittest (thkkk, Nov 21, 2021)
- `032e41f` simplify (Trinkle23897, Nov 22, 2021)
- `8cb5bc2` move all bcq tests under offline/ (Trinkle23897, Nov 22, 2021)
- `1e671f2` fix mypy (Trinkle23897, Nov 22, 2021)
- `731f4db` remove dill and use offline_trainer to test bcq (Trinkle23897, Nov 22, 2021)
- `c90ac1d` skip win/mac vecenv test (Trinkle23897, Nov 22, 2021)
- `52ac650` polish (Trinkle23897, Nov 22, 2021)
- `9ff232d` polish (Trinkle23897, Nov 22, 2021)
- `029f869` Modify VAE and Perturbation, in order to adapt to more dimensional input (thkkk, Nov 22, 2021)
- `1cc3966` Merge branch 'master' of github.com:thkkk/tianshou (thkkk, Nov 22, 2021)
- `aa884a6` polish (Trinkle23897, Nov 22, 2021)
1 change: 1 addition & 0 deletions README.md
@@ -36,6 +36,7 @@
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
- Vanilla Imitation Learning
- [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf)
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
- [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
- [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf)
5 changes: 5 additions & 0 deletions docs/api/tianshou.policy.rst
@@ -109,6 +109,11 @@ Imitation
    :undoc-members:
    :show-inheritance:

.. autoclass:: tianshou.policy.BCQPolicy
    :members:
    :undoc-members:
    :show-inheritance:

.. autoclass:: tianshou.policy.DiscreteBCQPolicy
    :members:
    :undoc-members:
1 change: 1 addition & 0 deletions docs/index.rst
@@ -27,6 +27,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
* :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1812.02900.pdf>`_
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_
28 changes: 28 additions & 0 deletions examples/offline/README.md
@@ -0,0 +1,28 @@
# Offline

In the offline reinforcement learning setting, the agent learns a policy from a fixed dataset that was collected in advance by some behavior policy; the agent does not interact with the environment during training.

Once collected, the dataset is never changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train the offline agent; refer to the [d4rl](https://github.com/rail-berkeley/d4rl) documentation for how to install and use these datasets.

## Train

Tianshou provides an `offline_trainer` for offline reinforcement learning. You can parse a d4rl dataset into a `ReplayBuffer` and pass it as the `buffer` parameter of `offline_trainer`. `offline_bcq.py` is a complete example of offline RL on a d4rl dataset.
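For reference, here is a minimal sketch of that conversion; it mirrors the code in `offline_bcq.py`, and assumes `policy`, `test_collector`, `save_fn`, and `logger` have already been built as in that script:

```python
import d4rl
import gym

from tianshou.data import Batch, ReplayBuffer
from tianshou.trainer import offline_trainer

env = gym.make("halfcheetah-expert-v1")
dataset = d4rl.qlearning_dataset(env)  # dict of numpy arrays, one entry per field
buffer = ReplayBuffer(dataset["rewards"].size)
for i in range(dataset["rewards"].size):
    # copy one transition at a time into the Tianshou buffer
    buffer.add(
        Batch(
            obs=dataset["observations"][i],
            act=dataset["actions"][i],
            rew=dataset["rewards"][i],
            done=dataset["terminals"][i],
            obs_next=dataset["next_observations"][i],
        )
    )

# `policy`, `test_collector`, `save_fn`, and `logger` are assumed to be set up
# exactly as in offline_bcq.py; the numbers are that script's defaults.
result = offline_trainer(
    policy, buffer, test_collector,
    200,   # max_epoch
    5000,  # step_per_epoch
    10,    # episode_per_test
    256,   # batch_size
    save_fn=save_fn,
    logger=logger,
)
```

The same conversion is performed inside `offline_bcq.py`, so in practice you only need to run that script.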

To train an agent with the BCQ algorithm:

```bash
python offline_bcq.py --task halfcheetah-expert-v1
```

After 1M steps:

![halfcheetah-expert-v1_reward](results/bcq/halfcheetah-expert-v1_reward.png)

`halfcheetah-expert-v1` is a MuJoCo environment. The hyperparameter settings are similar to those used for the off-policy algorithms on the MuJoCo tasks.

## Results

| Environment | BCQ |
| --------------------- | --------------- |
| halfcheetah-expert-v1 | 10624.0 ± 181.4 |

241 changes: 241 additions & 0 deletions examples/offline/offline_bcq.py
@@ -0,0 +1,241 @@
#!/usr/bin/env python3
import argparse
import datetime
import os
import pprint

import d4rl
import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.policy import BCQPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import BasicLogger
from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.continuous import VAE, Critic, Perturbation


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='halfcheetah-expert-v1')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--buffer-size', type=int, default=1000000)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[400, 300])
    parser.add_argument('--actor-lr', type=float, default=1e-3)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
    parser.add_argument("--start-timesteps", type=int, default=10000)
    parser.add_argument('--epoch', type=int, default=200)
    parser.add_argument('--step-per-epoch', type=int, default=5000)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=10)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=1 / 35)

    parser.add_argument("--vae-hidden-sizes", type=int, nargs='*', default=[750, 750])
    # default to 2 * action_dim
    parser.add_argument('--latent-dim', type=int)
    parser.add_argument("--gamma", default=0.99)
    parser.add_argument("--tau", default=0.005)
    # Weighting for Clipped Double Q-learning in BCQ
    parser.add_argument("--lmbda", default=0.75)
    # Max perturbation hyper-parameter for BCQ
    parser.add_argument("--phi", default=0.05)
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument(
        '--watch',
        default=False,
        action='store_true',
        help='watch the play of pre-trained policy only',
    )
    return parser.parse_args()


def test_bcq():
    args = get_args()
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]  # float
    print("device:", args.device)
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))

    args.state_dim = args.state_shape[0]
    args.action_dim = args.action_shape[0]
    print("Max_action", args.max_action)

    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)]
    )
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)

    # model
    # perturbation network
    net_a = MLP(
        input_dim=args.state_dim + args.action_dim,
        output_dim=args.action_dim,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
    )
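    # Perturbation wraps net_a: given (state, action) it outputs a bounded
    # correction (scaled by phi and max_action) that is added to actions
    # sampled from the VAE, as in the BCQ paper.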
    actor = Perturbation(
        net_a, max_action=args.max_action, device=args.device, phi=args.phi
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

    net_c1 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    net_c2 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic1 = Critic(net_c1, device=args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    # vae
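    # The VAE is trained to reconstruct the dataset's actions conditioned on the
    # state; its decoder generates candidate actions that stay close to the
    # behavior policy which collected the offline data.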
    # output_dim = 0, so the last Module in the encoder is ReLU
    vae_encoder = MLP(
        input_dim=args.state_dim + args.action_dim,
        hidden_sizes=args.vae_hidden_sizes,
        device=args.device,
    )
    if not args.latent_dim:
        args.latent_dim = args.action_dim * 2
    vae_decoder = MLP(
        input_dim=args.state_dim + args.latent_dim,
        output_dim=args.action_dim,
        hidden_sizes=args.vae_hidden_sizes,
        device=args.device,
    )
    vae = VAE(
        vae_encoder,
        vae_decoder,
        hidden_dim=args.vae_hidden_sizes[-1],
        latent_dim=args.latent_dim,
        max_action=args.max_action,
        device=args.device,
    ).to(args.device)
    vae_optim = torch.optim.Adam(vae.parameters())
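    # BCQPolicy below ties the three components together: the VAE proposes
    # in-distribution actions, the perturbation actor adjusts them within
    # +/- phi * max_action, and the twin critics select actions via
    # lmbda-weighted clipped double Q-learning.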

    policy = BCQPolicy(
        actor,
        actor_optim,
        critic1,
        critic1_optim,
        critic2,
        critic2_optim,
        vae,
        vae_optim,
        device=args.device,
        gamma=args.gamma,
        tau=args.tau,
        lmbda=args.lmbda,
    )

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    if args.training_num > 1:
        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    else:
        buffer = ReplayBuffer(args.buffer_size)
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    train_collector.collect(n_step=args.start_timesteps, random=True)
    # log
    t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
    log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_bcq'
    log_path = os.path.join(args.logdir, args.task, 'bcq', log_file)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = BasicLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def watch():
        if args.resume_path is None:
            args.resume_path = os.path.join(log_path, 'policy.pth')

        policy.load_state_dict(
            torch.load(args.resume_path, map_location=torch.device('cpu'))
        )
        policy.eval()
        collector = Collector(policy, env)
        collector.collect(n_episode=1, render=1 / 35)

    if not args.watch:
        dataset = d4rl.qlearning_dataset(env)
        dataset_size = dataset['rewards'].size

        print("dataset_size", dataset_size)
        replay_buffer = ReplayBuffer(dataset_size)

        for i in range(dataset_size):
            replay_buffer.add(
                Batch(
                    obs=dataset['observations'][i],
                    act=dataset['actions'][i],
                    rew=dataset['rewards'][i],
                    done=dataset['terminals'][i],
                    obs_next=dataset['next_observations'][i],
                )
            )
        print("dataset loaded")
        # trainer
        result = offline_trainer(
            policy,
            replay_buffer,
            test_collector,
            args.epoch,
            args.step_per_epoch,
            args.test_num,
            args.batch_size,
            save_fn=save_fn,
            logger=logger,
        )
        pprint.pprint(result)
    else:
        watch()

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')


if __name__ == '__main__':
    test_bcq()
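To evaluate a saved checkpoint without retraining, pass `--watch` together with `--resume-path` pointing at a saved `policy.pth`; the `watch()` helper above then loads the policy and renders a single episode.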
2 changes: 1 addition & 1 deletion test/base/test_env.py
@@ -134,7 +134,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
         SubprocVectorEnv(env_fns),
         ShmemVectorEnv(env_fns),
     ]
-    if has_ray():
+    if has_ray() and sys.platform == "linux":
         venv += [RayVectorEnv(env_fns)]
     for v in venv:
         v.seed(0)
Empty file added test/offline/__init__.py