diff --git a/README.md b/README.md index 512cd7697..13cfc191f 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf) - [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf) - Vanilla Imitation Learning +- [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf) - [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf) - [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf) - [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf) diff --git a/docs/api/tianshou.policy.rst b/docs/api/tianshou.policy.rst index b05f5be42..7292afdcc 100644 --- a/docs/api/tianshou.policy.rst +++ b/docs/api/tianshou.policy.rst @@ -109,6 +109,11 @@ Imitation :undoc-members: :show-inheritance: +.. autoclass:: tianshou.policy.BCQPolicy + :members: + :undoc-members: + :show-inheritance: + .. autoclass:: tianshou.policy.DiscreteBCQPolicy :members: :undoc-members: diff --git a/docs/index.rst b/docs/index.rst index b56bce367..a7fa0da26 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -27,6 +27,7 @@ Welcome to Tianshou! * :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_ * :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_ * :class:`~tianshou.policy.ImitationPolicy` Imitation Learning +* :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1812.02900.pdf>`_ * :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_ * :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_ * :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_ diff --git a/examples/offline/README.md b/examples/offline/README.md new file mode 100644 index 000000000..8995ee6e2 --- /dev/null +++ b/examples/offline/README.md @@ -0,0 +1,28 @@ +# Offline + +In the offline reinforcement learning setting, the agent learns a policy from a fixed dataset that was collected once by an arbitrary behavior policy, and it never interacts with the environment. + +Once the dataset is collected, it is not changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train the offline agent; refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use its datasets. + +## Train + +Tianshou provides an `offline_trainer` for offline reinforcement learning. You can parse a d4rl dataset into a `ReplayBuffer` and pass it as the `buffer` parameter of `offline_trainer`, as shown in the sketch at the end of this section. `offline_bcq.py` is a complete example of offline RL using a d4rl dataset. + +To train an agent with the BCQ algorithm: + +```bash +python offline_bcq.py --task halfcheetah-expert-v1 +``` + +After 1M steps: + +![halfcheetah-expert-v1_reward](results/bcq/halfcheetah-expert-v1_reward.png) + +`halfcheetah-expert-v1` is a mujoco environment. The hyperparameter settings are similar to those used for the off-policy algorithms in the mujoco examples.
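+
+The snippet below is a minimal sketch of the dataset-loading step mentioned above; it simply mirrors what `offline_bcq.py` does before calling `offline_trainer`:
+
+```python
+import d4rl  # importing d4rl registers its environments in gym
+import gym
+
+from tianshou.data import Batch, ReplayBuffer
+
+env = gym.make("halfcheetah-expert-v1")
+dataset = d4rl.qlearning_dataset(env)
+
+# copy every transition of the fixed dataset into a Tianshou ReplayBuffer
+buffer = ReplayBuffer(dataset["rewards"].size)
+for i in range(dataset["rewards"].size):
+    buffer.add(
+        Batch(
+            obs=dataset["observations"][i],
+            act=dataset["actions"][i],
+            rew=dataset["rewards"][i],
+            done=dataset["terminals"][i],
+            obs_next=dataset["next_observations"][i],
+        )
+    )
+
+# the filled buffer is then passed to offline_trainer(policy, buffer, ...)
+```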
+ +## Results + +| Environment | BCQ | +| --------------------- | --------------- | +| halfcheetah-expert-v1 | 10624.0 ± 181.4 | + diff --git a/examples/offline/offline_bcq.py b/examples/offline/offline_bcq.py new file mode 100644 index 000000000..e488489e2 --- /dev/null +++ b/examples/offline/offline_bcq.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 +import argparse +import datetime +import os +import pprint + +import d4rl +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv +from tianshou.policy import BCQPolicy +from tianshou.trainer import offline_trainer +from tianshou.utils import BasicLogger +from tianshou.utils.net.common import MLP, Net +from tianshou.utils.net.continuous import VAE, Critic, Perturbation + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='halfcheetah-expert-v1') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--buffer-size', type=int, default=1000000) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[400, 300]) + parser.add_argument('--actor-lr', type=float, default=1e-3) + parser.add_argument('--critic-lr', type=float, default=1e-3) + parser.add_argument("--start-timesteps", type=int, default=10000) + parser.add_argument('--epoch', type=int, default=200) + parser.add_argument('--step-per-epoch', type=int, default=5000) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--batch-size', type=int, default=256) + parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--test-num', type=int, default=10) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=1 / 35) + + parser.add_argument("--vae-hidden-sizes", type=int, nargs='*', default=[750, 750]) + # default to 2 * action_dim + parser.add_argument('--latent-dim', type=int) + parser.add_argument("--gamma", default=0.99) + parser.add_argument("--tau", default=0.005) + # Weighting for Clipped Double Q-learning in BCQ + parser.add_argument("--lmbda", default=0.75) + # Max perturbation hyper-parameter for BCQ + parser.add_argument("--phi", default=0.05) + parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only', + ) + return parser.parse_args() + + +def test_bcq(): + args = get_args() + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] # float + print("device:", args.device) + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) + + args.state_dim = args.state_shape[0] + args.action_dim = args.action_shape[0] + print("Max_action", args.max_action) + + # train_envs = gym.make(args.task) + train_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) + # test_envs = gym.make(args.task) + test_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + 
np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + + # model + # perturbation network + net_a = MLP( + input_dim=args.state_dim + args.action_dim, + output_dim=args.action_dim, + hidden_sizes=args.hidden_sizes, + device=args.device, + ) + actor = Perturbation( + net_a, max_action=args.max_action, device=args.device, phi=args.phi + ).to(args.device) + actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + + net_c1 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + ) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + ) + critic1 = Critic(net_c1, device=args.device).to(args.device) + critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic2 = Critic(net_c2, device=args.device).to(args.device) + critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + + # vae + # output_dim = 0, so the last Module in the encoder is ReLU + vae_encoder = MLP( + input_dim=args.state_dim + args.action_dim, + hidden_sizes=args.vae_hidden_sizes, + device=args.device, + ) + if not args.latent_dim: + args.latent_dim = args.action_dim * 2 + vae_decoder = MLP( + input_dim=args.state_dim + args.latent_dim, + output_dim=args.action_dim, + hidden_sizes=args.vae_hidden_sizes, + device=args.device, + ) + vae = VAE( + vae_encoder, + vae_decoder, + hidden_dim=args.vae_hidden_sizes[-1], + latent_dim=args.latent_dim, + max_action=args.max_action, + device=args.device, + ).to(args.device) + vae_optim = torch.optim.Adam(vae.parameters()) + + policy = BCQPolicy( + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + vae, + vae_optim, + device=args.device, + gamma=args.gamma, + tau=args.tau, + lmbda=args.lmbda, + ) + + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + + # collector + if args.training_num > 1: + buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) + else: + buffer = ReplayBuffer(args.buffer_size) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs) + train_collector.collect(n_step=args.start_timesteps, random=True) + # log + t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") + log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_bcq' + log_path = os.path.join(args.logdir, args.task, 'bcq', log_file) + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def watch(): + if args.resume_path is None: + args.resume_path = os.path.join(log_path, 'policy.pth') + + policy.load_state_dict( + torch.load(args.resume_path, map_location=torch.device('cpu')) + ) + policy.eval() + collector = Collector(policy, env) + collector.collect(n_episode=1, render=1 / 35) + + if not args.watch: + dataset = d4rl.qlearning_dataset(env) + dataset_size = dataset['rewards'].size + + print("dataset_size", dataset_size) + replay_buffer = ReplayBuffer(dataset_size) + + for i in range(dataset_size): + replay_buffer.add( + Batch( + obs=dataset['observations'][i], + act=dataset['actions'][i], + rew=dataset['rewards'][i], + done=dataset['terminals'][i], + 
obs_next=dataset['next_observations'][i], + ) + ) + print("dataset loaded") + # trainer + result = offline_trainer( + policy, + replay_buffer, + test_collector, + args.epoch, + args.step_per_epoch, + args.test_num, + args.batch_size, + save_fn=save_fn, + logger=logger, + ) + pprint.pprint(result) + else: + watch() + + # Let's watch its performance! + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, render=args.render) + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + + +if __name__ == '__main__': + test_bcq() diff --git a/examples/offline/results/bcq/halfcheetah-expert-v1_reward.png b/examples/offline/results/bcq/halfcheetah-expert-v1_reward.png new file mode 100644 index 000000000..5afa6a3ad Binary files /dev/null and b/examples/offline/results/bcq/halfcheetah-expert-v1_reward.png differ diff --git a/examples/offline/results/bcq/halfcheetah-expert-v1_reward.svg b/examples/offline/results/bcq/halfcheetah-expert-v1_reward.svg new file mode 100644 index 000000000..87ede75ed --- /dev/null +++ b/examples/offline/results/bcq/halfcheetah-expert-v1_reward.svg @@ -0,0 +1 @@ +1e+32e+33e+34e+35e+36e+37e+38e+39e+31e+40100k200k300k400k500k600k700k800k900k1M1.1M \ No newline at end of file diff --git a/test/base/test_env.py b/test/base/test_env.py index 7f47501c3..dbd651d14 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -134,7 +134,7 @@ def test_vecenv(size=10, num=8, sleep=0.001): SubprocVectorEnv(env_fns), ShmemVectorEnv(env_fns), ] - if has_ray(): + if has_ray() and sys.platform == "linux": venv += [RayVectorEnv(env_fns)] for v in venv: v.seed(0) diff --git a/test/offline/__init__.py b/test/offline/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py new file mode 100644 index 000000000..4c0275e69 --- /dev/null +++ b/test/offline/gather_pendulum_data.py @@ -0,0 +1,170 @@ +import argparse +import os +import pickle + +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.policy import SACPolicy +from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import ActorProb, Critic + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='Pendulum-v0') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--buffer-size', type=int, default=200000) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) + parser.add_argument('--actor-lr', type=float, default=1e-3) + parser.add_argument('--critic-lr', type=float, default=1e-3) + parser.add_argument('--epoch', type=int, default=7) + parser.add_argument('--step-per-epoch', type=int, default=8000) + parser.add_argument('--batch-size', type=int, default=256) + parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--test-num', type=int, default=10) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.125) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) 
+ + parser.add_argument("--gamma", default=0.99) + parser.add_argument("--tau", default=0.005) + parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) + # sac: + parser.add_argument('--alpha', type=float, default=0.2) + parser.add_argument('--auto-alpha', type=int, default=1) + parser.add_argument('--alpha-lr', type=float, default=3e-4) + parser.add_argument('--rew-norm', action="store_true", default=False) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument( + "--save-buffer-name", type=str, default="./expert_SAC_Pendulum-v0.pkl" + ) + args = parser.parse_known_args()[0] + return args + + +def gather_data(): + """Return expert buffer data.""" + args = get_args() + env = gym.make(args.task) + if args.task == 'Pendulum-v0': + env.spec.reward_threshold = -250 + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] + # you can also use tianshou.env.SubprocVectorEnv + # train_envs = gym.make(args.task) + train_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) + # test_envs = gym.make(args.task) + test_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = ActorProb( + net, + args.action_shape, + max_action=args.max_action, + device=args.device, + unbounded=True, + ).to(args.device) + actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + net_c1 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + ) + critic1 = Critic(net_c1, device=args.device).to(args.device) + critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + ) + critic2 = Critic(net_c2, device=args.device).to(args.device) + critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + + if args.auto_alpha: + target_entropy = -np.prod(env.action_space.shape) + log_alpha = torch.zeros(1, requires_grad=True, device=args.device) + alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) + args.alpha = (target_entropy, log_alpha, alpha_optim) + + policy = SACPolicy( + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + tau=args.tau, + gamma=args.gamma, + alpha=args.alpha, + reward_normalization=args.rew_norm, + estimation_step=args.n_step, + action_space=env.action_space, + ) + # collector + buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs) + # train_collector.collect(n_step=args.buffer_size) + # log + log_path = os.path.join(args.logdir, args.task, 'sac') + writer = SummaryWriter(log_path) + logger = TensorboardLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def 
stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + # trainer + offpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + save_fn=save_fn, + stop_fn=stop_fn, + logger=logger, + ) + train_collector.reset() + result = train_collector.collect(n_step=args.buffer_size) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + pickle.dump(buffer, open(args.save_buffer_name, "wb")) + return buffer diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py new file mode 100644 index 000000000..ab98e497a --- /dev/null +++ b/test/offline/test_bcq.py @@ -0,0 +1,221 @@ +import argparse +import datetime +import os +import pickle +import pprint + +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector +from tianshou.env import SubprocVectorEnv +from tianshou.policy import BCQPolicy +from tianshou.trainer import offline_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import MLP, Net +from tianshou.utils.net.continuous import VAE, Critic, Perturbation + +if __name__ == "__main__": + from gather_pendulum_data import gather_data +else: # pytest + from test.offline.gather_pendulum_data import gather_data + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='Pendulum-v0') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[200, 150]) + parser.add_argument('--actor-lr', type=float, default=1e-3) + parser.add_argument('--critic-lr', type=float, default=1e-3) + parser.add_argument('--epoch', type=int, default=7) + parser.add_argument('--step-per-epoch', type=int, default=2000) + parser.add_argument('--batch-size', type=int, default=256) + parser.add_argument('--test-num', type=int, default=10) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) 
+ + parser.add_argument("--vae-hidden-sizes", type=int, nargs='*', default=[375, 375]) + # default to 2 * action_dim + parser.add_argument('--latent_dim', type=int, default=None) + parser.add_argument("--gamma", default=0.99) + parser.add_argument("--tau", default=0.005) + # Weighting for Clipped Double Q-learning in BCQ + parser.add_argument("--lmbda", default=0.75) + # Max perturbation hyper-parameter for BCQ + parser.add_argument("--phi", default=0.05) + parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only', + ) + parser.add_argument( + "--load-buffer-name", type=str, default="./expert_SAC_Pendulum-v0.pkl" + ) + args = parser.parse_known_args()[0] + return args + + +def test_bcq(args=get_args()): + if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): + buffer = pickle.load(open(args.load_buffer_name, "rb")) + else: + buffer = gather_data() + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] # float + if args.task == 'Pendulum-v0': + env.spec.reward_threshold = -800 # too low? + + args.state_dim = args.state_shape[0] + args.action_dim = args.action_shape[0] + # test_envs = gym.make(args.task) + test_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + test_envs.seed(args.seed) + + # model + # perturbation network + net_a = MLP( + input_dim=args.state_dim + args.action_dim, + output_dim=args.action_dim, + hidden_sizes=args.hidden_sizes, + device=args.device, + ) + actor = Perturbation( + net_a, max_action=args.max_action, device=args.device, phi=args.phi + ).to(args.device) + actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + + net_c1 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + ) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + ) + critic1 = Critic(net_c1, device=args.device).to(args.device) + critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic2 = Critic(net_c2, device=args.device).to(args.device) + critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + + # vae + # output_dim = 0, so the last Module in the encoder is ReLU + vae_encoder = MLP( + input_dim=args.state_dim + args.action_dim, + hidden_sizes=args.vae_hidden_sizes, + device=args.device, + ) + if not args.latent_dim: + args.latent_dim = args.action_dim * 2 + vae_decoder = MLP( + input_dim=args.state_dim + args.latent_dim, + output_dim=args.action_dim, + hidden_sizes=args.vae_hidden_sizes, + device=args.device, + ) + vae = VAE( + vae_encoder, + vae_decoder, + hidden_dim=args.vae_hidden_sizes[-1], + latent_dim=args.latent_dim, + max_action=args.max_action, + device=args.device, + ).to(args.device) + vae_optim = torch.optim.Adam(vae.parameters()) + + policy = BCQPolicy( + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + vae, + vae_optim, + device=args.device, + gamma=args.gamma, + tau=args.tau, + lmbda=args.lmbda, + ) + + # load a previous 
policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + + # collector + # buffer has been gathered + # train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs) + # log + t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") + log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_bcq' + log_path = os.path.join(args.logdir, args.task, 'bcq', log_file) + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = TensorboardLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + def watch(): + policy.load_state_dict( + torch.load( + os.path.join(log_path, 'policy.pth'), map_location=torch.device('cpu') + ) + ) + policy.eval() + collector = Collector(policy, env) + collector.collect(n_episode=1, render=1 / 35) + + # trainer + result = offline_trainer( + policy, + buffer, + test_collector, + args.epoch, + args.step_per_epoch, + args.test_num, + args.batch_size, + save_fn=save_fn, + stop_fn=stop_fn, + logger=logger, + ) + assert stop_fn(result['best_reward']) + + # Let's watch its performance! + if __name__ == '__main__': + pprint.pprint(result) + env = gym.make(args.task) + policy.eval() + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + + +if __name__ == '__main__': + test_bcq() diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 6a842356f..174762e25 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -19,6 +19,7 @@ from tianshou.policy.modelfree.sac import SACPolicy from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy from tianshou.policy.imitation.base import ImitationPolicy +from tianshou.policy.imitation.bcq import BCQPolicy from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy @@ -44,6 +45,7 @@ "SACPolicy", "DiscreteSACPolicy", "ImitationPolicy", + "BCQPolicy", "DiscreteBCQPolicy", "DiscreteCQLPolicy", "DiscreteCRRPolicy", diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py new file mode 100644 index 000000000..2aeeb323d --- /dev/null +++ b/tianshou/policy/imitation/bcq.py @@ -0,0 +1,213 @@ +import copy +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from tianshou.data import Batch, to_torch +from tianshou.policy import BasePolicy +from tianshou.utils.net.continuous import VAE + + +class BCQPolicy(BasePolicy): + """Implementation of BCQ algorithm. arXiv:1812.02900. + + :param Perturbation actor: the actor perturbation. (s, a -> perturbed a) + :param torch.optim.Optimizer actor_optim: the optimizer for actor network. + :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a)) + :param torch.optim.Optimizer critic1_optim: the optimizer for the first + critic network. + :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a)) + :param torch.optim.Optimizer critic2_optim: the optimizer for the second + critic network. 
+ :param VAE vae: the VAE network, generating actions similar + to those in batch. (s, a -> generated a) + :param torch.optim.Optimizer vae_optim: the optimizer for the VAE network. + :param Union[str, torch.device] device: which device to create this model on. + Default to "cpu". + :param float gamma: discount factor, in [0, 1]. Default to 0.99. + :param float tau: param for soft update of the target network. + Default to 0.005. + :param float lmbda: param for Clipped Double Q-learning. Default to 0.75. + :param int forward_sampled_times: the number of sampled actions in forward + function. The policy samples many actions and takes the action with the + max value. Default to 100. + :param int num_sampled_action: the number of sampled actions in calculating + target Q. The algorithm samples several actions using VAE, and perturbs + each action to get the target Q. Default to 10. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__( + self, + actor: torch.nn.Module, + actor_optim: torch.optim.Optimizer, + critic1: torch.nn.Module, + critic1_optim: torch.optim.Optimizer, + critic2: torch.nn.Module, + critic2_optim: torch.optim.Optimizer, + vae: VAE, + vae_optim: torch.optim.Optimizer, + device: Union[str, torch.device] = "cpu", + gamma: float = 0.99, + tau: float = 0.005, + lmbda: float = 0.75, + forward_sampled_times: int = 100, + num_sampled_action: int = 10, + **kwargs: Any + ) -> None: + # actor is Perturbation! + super().__init__(**kwargs) + self.actor = actor + self.actor_target = copy.deepcopy(self.actor) + self.actor_optim = actor_optim + + self.critic1 = critic1 + self.critic1_target = copy.deepcopy(self.critic1) + self.critic1_optim = critic1_optim + + self.critic2 = critic2 + self.critic2_target = copy.deepcopy(self.critic2) + self.critic2_optim = critic2_optim + + self.vae = vae + self.vae_optim = vae_optim + + self.gamma = gamma + self.tau = tau + self.lmbda = lmbda + self.device = device + self.forward_sampled_times = forward_sampled_times + self.num_sampled_action = num_sampled_action + + def train(self, mode: bool = True) -> "BCQPolicy": + """Set the module in training mode, except for the target network.""" + self.training = mode + self.actor.train(mode) + self.critic1.train(mode) + self.critic2.train(mode) + return self + + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + **kwargs: Any, + ) -> Batch: + """Compute action over the given batch data.""" + # There is "obs" in the Batch + # obs_group: several groups. Each group has a state. 
+ obs_group: torch.Tensor = to_torch( # type: ignore + batch.obs, device=self.device + ) + act = [] + for obs in obs_group: + # now obs is (state_dim) + obs = (obs.reshape(1, -1)).repeat(self.forward_sampled_times, 1) + # now obs is (forward_sampled_times, state_dim) + + # decode(obs) generates action and actor perturbs it + action = self.actor(obs, self.vae.decode(obs)) + # now action is (forward_sampled_times, action_dim) + q1 = self.critic1(obs, action) + # q1 is (forward_sampled_times, 1) + ind = q1.argmax(0) + act.append(action[ind].cpu().data.numpy().flatten()) + act = np.array(act) + return Batch(act=act) + + def sync_weight(self) -> None: + """Soft-update the weight for the target network.""" + for net, net_target in [ + [self.critic1, self.critic1_target], [self.critic2, self.critic2_target], + [self.actor, self.actor_target] + ]: + for param, target_param in zip(net.parameters(), net_target.parameters()): + target_param.data.copy_( + self.tau * param.data + (1 - self.tau) * target_param.data + ) + + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: + # batch: obs, act, rew, done, obs_next. (numpy array) + # (batch_size, state_dim) + batch: Batch = to_torch( # type: ignore + batch, dtype=torch.float, device=self.device + ) + obs, act = batch.obs, batch.act + batch_size = obs.shape[0] + + # mean, std: (state.shape[0], latent_dim) + recon, mean, std = self.vae(obs, act) + recon_loss = F.mse_loss(act, recon) + # (....) is D_KL( N(mu, sigma) || N(0,1) ) + KL_loss = (-torch.log(std) + (std.pow(2) + mean.pow(2) - 1) / 2).mean() + vae_loss = recon_loss + KL_loss / 2 + + self.vae_optim.zero_grad() + vae_loss.backward() + self.vae_optim.step() + + # critic training: + with torch.no_grad(): + # repeat num_sampled_action times + obs_next = batch.obs_next.repeat_interleave(self.num_sampled_action, dim=0) + # now obs_next: (num_sampled_action * batch_size, state_dim) + + # perturbed action generated by VAE + act_next = self.vae.decode(obs_next) + # now obs_next: (num_sampled_action * batch_size, action_dim) + target_Q1 = self.critic1_target(obs_next, act_next) + target_Q2 = self.critic2_target(obs_next, act_next) + + # Clipped Double Q-learning + target_Q = \ + self.lmbda * torch.min(target_Q1, target_Q2) + \ + (1 - self.lmbda) * torch.max(target_Q1, target_Q2) + # now target_Q: (num_sampled_action * batch_size, 1) + + # the max value of Q + target_Q = target_Q.reshape(batch_size, -1).max(dim=1)[0].reshape(-1, 1) + # now target_Q: (batch_size, 1) + + target_Q = \ + batch.rew.reshape(-1, 1) + \ + (1 - batch.done).reshape(-1, 1) * self.gamma * target_Q + + current_Q1 = self.critic1(obs, act) + current_Q2 = self.critic2(obs, act) + + critic1_loss = F.mse_loss(current_Q1, target_Q) + critic2_loss = F.mse_loss(current_Q2, target_Q) + + self.critic1_optim.zero_grad() + self.critic2_optim.zero_grad() + critic1_loss.backward() + critic2_loss.backward() + self.critic1_optim.step() + self.critic2_optim.step() + + sampled_act = self.vae.decode(obs) + perturbed_act = self.actor(obs, sampled_act) + + # max + actor_loss = -self.critic1(obs, perturbed_act).mean() + + self.actor_optim.zero_grad() + actor_loss.backward() + self.actor_optim.step() + + # update target network + self.sync_weight() + + result = { + "loss/actor": actor_loss.item(), + "loss/critic1": critic1_loss.item(), + "loss/critic2": critic2_loss.item(), + "loss/vae": vae_loss.item(), + } + return result diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 1bb090cdf..761540502 100644 --- 
a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -325,3 +325,122 @@ def forward( s = torch.cat([s, a], dim=1) s = self.fc2(s) return s + + +class Perturbation(nn.Module): + """Implementation of the perturbation network in the BCQ algorithm. Given a state and \ an action, it can generate a perturbed action. + + :param torch.nn.Module preprocess_net: a self-defined preprocess_net which outputs a + flattened hidden state. + :param float max_action: the maximum value of each dimension of action. + :param Union[str, int, torch.device] device: which device to create this model on. + Default to cpu. + :param float phi: max perturbation parameter for BCQ. Default to 0.05. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + + .. seealso:: + + You can refer to `examples/offline/offline_bcq.py` to see how to use it. + """ + + def __init__( + self, + preprocess_net: nn.Module, + max_action: float, + device: Union[str, int, torch.device] = "cpu", + phi: float = 0.05 + ): + # preprocess_net: input_dim=state_dim+action_dim, output_dim=action_dim + super(Perturbation, self).__init__() + self.preprocess_net = preprocess_net + self.device = device + self.max_action = max_action + self.phi = phi + + def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor: + # preprocess_net + logits = self.preprocess_net(torch.cat([state, action], -1))[0] + a = self.phi * self.max_action * torch.tanh(logits) + # clip to [-max_action, max_action] + return (a + action).clamp(-self.max_action, self.max_action) + + +class VAE(nn.Module): + """Implementation of VAE. It models the distribution of actions. Given a \ state, it can generate actions similar to those in the batch. It is used \ in the BCQ algorithm. + + :param torch.nn.Module encoder: the encoder in VAE. Its input_dim must be + state_dim + action_dim, and output_dim must be hidden_dim. + :param torch.nn.Module decoder: the decoder in VAE. Its input_dim must be + state_dim + latent_dim, and output_dim must be action_dim. + :param int hidden_dim: the size of the last linear layer in the encoder. + :param int latent_dim: the size of the latent layer. + :param float max_action: the maximum value of each dimension of action. + :param Union[str, torch.device] device: which device to create this model on. + Default to "cpu". + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + + .. seealso:: + + You can refer to `examples/offline/offline_bcq.py` to see how to use it.
+ """ + + def __init__( + self, + encoder: nn.Module, + decoder: nn.Module, + hidden_dim: int, + latent_dim: int, + max_action: float, + device: Union[str, torch.device] = "cpu" + ): + super(VAE, self).__init__() + self.encoder = encoder + + self.mean = nn.Linear(hidden_dim, latent_dim) + self.log_std = nn.Linear(hidden_dim, latent_dim) + + self.decoder = decoder + + self.max_action = max_action + self.latent_dim = latent_dim + self.device = device + + def forward( + self, state: torch.Tensor, action: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # [state, action] -> z , [state, z] -> action + z = self.encoder(torch.cat([state, action], -1)) + # shape of z: (state.shape[:-1], hidden_dim) + + mean = self.mean(z) + # Clamped for numerical stability + log_std = self.log_std(z).clamp(-4, 15) + std = torch.exp(log_std) + # shape of mean, std: (state.shape[:-1], latent_dim) + + z = mean + std * torch.randn_like(std) # (state.shape[:-1], latent_dim) + + u = self.decode(state, z) # (state.shape[:-1], action_dim) + return u, mean, std + + def decode( + self, + state: torch.Tensor, + z: Union[torch.Tensor, None] = None + ) -> torch.Tensor: + # decode(state) -> action + if z is None: + # state.shape[0] may be batch_size + # latent vector clipped to [-0.5, 0.5] + z = torch.randn(state.shape[:-1] + (self.latent_dim, )) \ + .to(self.device).clamp(-0.5, 0.5) + + # decode z with state! + return self.max_action * torch.tanh(self.decoder(torch.cat([state, z], -1)))