diff --git a/README.md b/README.md
index 1144a0cab..11f7b51b9 100644
--- a/README.md
+++ b/README.md
@@ -41,6 +41,7 @@
 - [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
 - [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
 - [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf)
+- [Generative Adversarial Imitation Learning (GAIL)](https://arxiv.org/pdf/1606.03476.pdf)
 - [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
 - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
 - [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)
diff --git a/docs/api/tianshou.policy.rst b/docs/api/tianshou.policy.rst
index c3063665c..a0d9bed99 100644
--- a/docs/api/tianshou.policy.rst
+++ b/docs/api/tianshou.policy.rst
@@ -134,6 +134,11 @@ Imitation
     :undoc-members:
     :show-inheritance:
 
+.. autoclass:: tianshou.policy.GAILPolicy
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 Model-based
 -----------
 
diff --git a/docs/index.rst b/docs/index.rst
index 15ac74037..b131ff402 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -32,6 +32,7 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
 * :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
 * :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_
+* :class:`~tianshou.policy.GAILPolicy` `Generative Adversarial Imitation Learning <https://arxiv.org/pdf/1606.03476.pdf>`_
 * :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
 * :class:`~tianshou.policy.ICMPolicy` `Intrinsic Curiosity Module <https://arxiv.org/pdf/1705.05363.pdf>`_
 * :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 9f0377b43..86bdb281e 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -140,3 +140,6 @@ Strens
 Ornstein
 Uhlenbeck
 mse
+gail
+airl
+ppo
diff --git a/examples/inverse/README.md b/examples/inverse/README.md
new file mode 100644
index 000000000..8e5276bbd
--- /dev/null
+++ b/examples/inverse/README.md
@@ -0,0 +1,27 @@
+# Inverse Reinforcement Learning
+
+In the inverse reinforcement learning (IRL) setting, the agent learns a policy from interaction with an environment that provides no reward signal, together with a fixed dataset collected by an expert policy.
+
+## Continuous control
+
+Once the dataset is collected, it is not changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train agents for continuous control; refer to [d4rl](https://github.com/rail-berkeley/d4rl) for details on how to use these datasets.
+
+We provide an implementation of the GAIL algorithm for continuous control.
+
+### Train
+
+You can parse a d4rl dataset into a `ReplayBuffer` and pass it as the `expert_buffer` parameter of `GAILPolicy`. `irl_gail.py` is a full example of inverse RL on d4rl data.
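+
+For reference, here is a minimal sketch of that conversion (it mirrors what `irl_gail.py` does; the dataset name is just an example):
+
+```python
+import d4rl  # importing d4rl registers its environments with gym
+import gym
+
+from tianshou.data import Batch, ReplayBuffer
+
+# load expert transitions from d4rl, e.g. halfcheetah-expert-v2
+dataset = d4rl.qlearning_dataset(gym.make("halfcheetah-expert-v2"))
+dataset_size = dataset["rewards"].size
+
+expert_buffer = ReplayBuffer(dataset_size)
+for i in range(dataset_size):
+    expert_buffer.add(
+        Batch(
+            obs=dataset["observations"][i],
+            act=dataset["actions"][i],
+            rew=dataset["rewards"][i],
+            done=dataset["terminals"][i],
+            obs_next=dataset["next_observations"][i],
+        )
+    )
+# expert_buffer can now be passed as the expert_buffer argument of GAILPolicy
+```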
+
+To train an agent with the GAIL algorithm:
+
+```bash
+python irl_gail.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2
+```
+
+## GAIL (single run)
+
+| task           | best reward | reward curve                             | parameters                                                                               |
+| -------------- | ----------- | ---------------------------------------- | ---------------------------------------------------------------------------------------- |
+| HalfCheetah-v2 | 5177.07     | ![](results/gail/HalfCheetah-v2_rew.png) | `python3 irl_gail.py --task "HalfCheetah-v2" --expert-data-task "halfcheetah-expert-v2"` |
+| Hopper-v2      | 1761.44     | ![](results/gail/Hopper-v2_rew.png)      | `python3 irl_gail.py --task "Hopper-v2" --expert-data-task "hopper-expert-v2"`           |
+| Walker2d-v2    | 2020.77     | ![](results/gail/Walker2d-v2_rew.png)    | `python3 irl_gail.py --task "Walker2d-v2" --expert-data-task "walker2d-expert-v2"`       |
diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py
new file mode 100644
index 000000000..1e4c0389c
--- /dev/null
+++ b/examples/inverse/irl_gail.py
@@ -0,0 +1,277 @@
+#!/usr/bin/env python3
+
+import argparse
+import datetime
+import os
+import pprint
+
+import d4rl
+import gym
+import numpy as np
+import torch
+from torch import nn
+from torch.distributions import Independent, Normal
+from torch.optim.lr_scheduler import LambdaLR
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer
+from tianshou.env import SubprocVectorEnv
+from tianshou.policy import GAILPolicy
+from tianshou.trainer import onpolicy_trainer
+from tianshou.utils import TensorboardLogger
+from tianshou.utils.net.common import ActorCritic, Net
+from tianshou.utils.net.continuous import ActorProb, Critic
+
+
+class NoRewardEnv(gym.RewardWrapper):
+    """Set the reward to 0.
+
+    :param gym.Env env: the environment to wrap.
+    """
+
+    def __init__(self, env):
+        super().__init__(env)
+
+    def reward(self, reward):
+        """Set reward to 0."""
+        return np.zeros_like(reward)
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='HalfCheetah-v2')
+    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument(
+        '--expert-data-task', type=str, default='halfcheetah-expert-v2'
+    )
+    parser.add_argument('--buffer-size', type=int, default=4096)
+    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
+    parser.add_argument('--lr', type=float, default=3e-4)
+    parser.add_argument('--disc-lr', type=float, default=2.5e-5)
+    parser.add_argument('--gamma', type=float, default=0.99)
+    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--step-per-epoch', type=int, default=30000)
+    parser.add_argument('--step-per-collect', type=int, default=2048)
+    parser.add_argument('--repeat-per-collect', type=int, default=10)
+    parser.add_argument('--disc-update-num', type=int, default=2)
+    parser.add_argument('--batch-size', type=int, default=64)
+    parser.add_argument('--training-num', type=int, default=64)
+    parser.add_argument('--test-num', type=int, default=10)
+    # ppo special
+    parser.add_argument('--rew-norm', type=int, default=True)
+    # In theory, `vf-coef` will not make any difference if using the Adam optimizer.
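+    # (Here the actor and critic do not share parameters, so `vf-coef` only
+    # rescales the critic's gradients, and Adam's per-parameter step size is
+    # largely invariant to such a constant rescaling.)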
+ parser.add_argument('--vf-coef', type=float, default=0.25) + parser.add_argument('--ent-coef', type=float, default=0.001) + parser.add_argument('--gae-lambda', type=float, default=0.95) + parser.add_argument('--bound-action-method', type=str, default="clip") + parser.add_argument('--lr-decay', type=int, default=True) + parser.add_argument('--max-grad-norm', type=float, default=0.5) + parser.add_argument('--eps-clip', type=float, default=0.2) + parser.add_argument('--dual-clip', type=float, default=None) + parser.add_argument('--value-clip', type=int, default=0) + parser.add_argument('--norm-adv', type=int, default=0) + parser.add_argument('--recompute-adv', type=int, default=1) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) + parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) + return parser.parse_args() + + +def test_gail(args=get_args()): + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) + # train_envs = gym.make(args.task) + train_envs = SubprocVectorEnv( + [lambda: NoRewardEnv(gym.make(args.task)) for _ in range(args.training_num)], + norm_obs=True + ) + # test_envs = gym.make(args.task) + test_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)], + norm_obs=True, + obs_rms=train_envs.obs_rms, + update_obs_rms=False + ) + + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + net_a = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) + actor = ActorProb( + net_a, + args.action_shape, + max_action=args.max_action, + unbounded=True, + device=args.device + ).to(args.device) + net_c = Net( + args.state_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device + ) + critic = Critic(net_c, device=args.device).to(args.device) + torch.nn.init.constant_(actor.sigma_param, -0.5) + for m in list(actor.modules()) + list(critic.modules()): + if isinstance(m, torch.nn.Linear): + # orthogonal initialization + torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) + torch.nn.init.zeros_(m.bias) + # do last policy layer scaling, this will make initial actions have (close to) + # 0 mean and std, and will help boost performances, + # see https://arxiv.org/abs/2006.05990, Fig.24 for details + for m in actor.mu.modules(): + if isinstance(m, torch.nn.Linear): + torch.nn.init.zeros_(m.bias) + m.weight.data.copy_(0.01 * m.weight.data) + + optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr) + # discriminator + net_d = Net( + args.state_shape, + action_shape=args.action_shape, + hidden_sizes=args.hidden_sizes, + activation=nn.Tanh, + device=args.device, + concat=True + ) + disc_net = Critic(net_d, device=args.device).to(args.device) + for m in disc_net.modules(): + if isinstance(m, torch.nn.Linear): + # orthogonal initialization + 
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) + torch.nn.init.zeros_(m.bias) + disc_optim = torch.optim.Adam(disc_net.parameters(), lr=args.disc_lr) + + lr_scheduler = None + if args.lr_decay: + # decay learning rate to 0 linearly + max_update_num = np.ceil( + args.step_per_epoch / args.step_per_collect + ) * args.epoch + + lr_scheduler = LambdaLR( + optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num + ) + + def dist(*logits): + return Independent(Normal(*logits), 1) + + # expert replay buffer + dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task)) + dataset_size = dataset['rewards'].size + + print("dataset_size", dataset_size) + expert_buffer = ReplayBuffer(dataset_size) + + for i in range(dataset_size): + expert_buffer.add( + Batch( + obs=dataset['observations'][i], + act=dataset['actions'][i], + rew=dataset['rewards'][i], + done=dataset['terminals'][i], + obs_next=dataset['next_observations'][i], + ) + ) + print("dataset loaded") + + policy = GAILPolicy( + actor, + critic, + optim, + dist, + expert_buffer, + disc_net, + disc_optim, + disc_update_num=args.disc_update_num, + discount_factor=args.gamma, + gae_lambda=args.gae_lambda, + max_grad_norm=args.max_grad_norm, + vf_coef=args.vf_coef, + ent_coef=args.ent_coef, + reward_normalization=args.rew_norm, + action_scaling=True, + action_bound_method=args.bound_action_method, + lr_scheduler=lr_scheduler, + action_space=env.action_space, + eps_clip=args.eps_clip, + value_clip=args.value_clip, + dual_clip=args.dual_clip, + advantage_normalization=args.norm_adv, + recompute_advantage=args.recompute_adv + ) + + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + + # collector + if args.training_num > 1: + buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) + else: + buffer = ReplayBuffer(args.buffer_size) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs) + # log + t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") + log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_gail' + log_path = os.path.join(args.logdir, args.task, 'gail', log_file) + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = TensorboardLogger(writer, update_interval=100, train_interval=100) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + if not args.watch: + # trainer + result = onpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + step_per_collect=args.step_per_collect, + save_fn=save_fn, + logger=logger, + test_in_train=False + ) + pprint.pprint(result) + + # Let's watch its performance! 
+ policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, render=args.render) + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + + +if __name__ == '__main__': + test_gail() diff --git a/examples/inverse/results/gail/HalfCheetah-v2_rew.png b/examples/inverse/results/gail/HalfCheetah-v2_rew.png new file mode 100644 index 000000000..f900a8c1a Binary files /dev/null and b/examples/inverse/results/gail/HalfCheetah-v2_rew.png differ diff --git a/examples/inverse/results/gail/Hopper-v2_rew.png b/examples/inverse/results/gail/Hopper-v2_rew.png new file mode 100644 index 000000000..1cf54253b Binary files /dev/null and b/examples/inverse/results/gail/Hopper-v2_rew.png differ diff --git a/examples/inverse/results/gail/Walker2d-v2_rew.png b/examples/inverse/results/gail/Walker2d-v2_rew.png new file mode 100644 index 000000000..46e711ba7 Binary files /dev/null and b/examples/inverse/results/gail/Walker2d-v2_rew.png differ diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py new file mode 100644 index 000000000..f68b7ae74 --- /dev/null +++ b/test/offline/test_gail.py @@ -0,0 +1,228 @@ +import argparse +import os +import pickle +import pprint + +import gym +import numpy as np +import torch +from torch.distributions import Independent, Normal +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.policy import GAILPolicy +from tianshou.trainer import onpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import ActorCritic, Net +from tianshou.utils.net.continuous import ActorProb, Critic + +if __name__ == "__main__": + from gather_pendulum_data import expert_file_name, gather_data +else: # pytest + from test.offline.gather_pendulum_data import expert_file_name, gather_data + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='Pendulum-v1') + parser.add_argument("--reward-threshold", type=float, default=None) + parser.add_argument('--seed', type=int, default=1) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--disc-lr', type=float, default=5e-4) + parser.add_argument('--gamma', type=float, default=0.95) + parser.add_argument('--epoch', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=150000) + parser.add_argument('--episode-per-collect', type=int, default=16) + parser.add_argument('--repeat-per-collect', type=int, default=2) + parser.add_argument('--disc-update-num', type=int, default=2) + parser.add_argument('--batch-size', type=int, default=128) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) + parser.add_argument('--training-num', type=int, default=16) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) 
+ parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + # ppo special + parser.add_argument('--vf-coef', type=float, default=0.25) + parser.add_argument('--ent-coef', type=float, default=0.0) + parser.add_argument('--eps-clip', type=float, default=0.2) + parser.add_argument('--max-grad-norm', type=float, default=0.5) + parser.add_argument('--gae-lambda', type=float, default=0.95) + parser.add_argument('--rew-norm', type=int, default=1) + parser.add_argument('--dual-clip', type=float, default=None) + parser.add_argument('--value-clip', type=int, default=1) + parser.add_argument('--norm-adv', type=int, default=1) + parser.add_argument('--recompute-adv', type=int, default=0) + parser.add_argument('--resume', action="store_true") + parser.add_argument("--save-interval", type=int, default=4) + parser.add_argument("--load-buffer-name", type=str, default=expert_file_name()) + args = parser.parse_known_args()[0] + return args + + +def test_gail(args=get_args()): + if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): + if args.load_buffer_name.endswith(".hdf5"): + buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) + else: + buffer = pickle.load(open(args.load_buffer_name, "rb")) + else: + buffer = gather_data() + env = gym.make(args.task) + if args.reward_threshold is None: + default_reward_threshold = {"Pendulum-v0": -1100, "Pendulum-v1": -1100} + args.reward_threshold = default_reward_threshold.get( + args.task, env.spec.reward_threshold + ) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] + # you can also use tianshou.env.SubprocVectorEnv + # train_envs = gym.make(args.task) + train_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) + # test_envs = gym.make(args.task) + test_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = ActorProb( + net, args.action_shape, max_action=args.max_action, device=args.device + ).to(args.device) + critic = Critic( + Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device), + device=args.device + ).to(args.device) + actor_critic = ActorCritic(actor, critic) + # orthogonal initialization + for m in actor_critic.modules(): + if isinstance(m, torch.nn.Linear): + torch.nn.init.orthogonal_(m.weight) + torch.nn.init.zeros_(m.bias) + optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) + # discriminator + disc_net = Critic( + Net( + args.state_shape, + action_shape=args.action_shape, + hidden_sizes=args.hidden_sizes, + activation=torch.nn.Tanh, + device=args.device, + concat=True, + ), + device=args.device + ).to(args.device) + for m in disc_net.modules(): + if isinstance(m, torch.nn.Linear): + # orthogonal initialization + torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) + torch.nn.init.zeros_(m.bias) + disc_optim = torch.optim.Adam(disc_net.parameters(), lr=args.disc_lr) + + # replace DiagGuassian with Independent(Normal) which is equivalent + # pass *logits to be consistent with policy.forward + def dist(*logits): + return Independent(Normal(*logits), 1) + + policy = GAILPolicy( + actor, + critic, + optim, + 
dist, + buffer, + disc_net, + disc_optim, + disc_update_num=args.disc_update_num, + discount_factor=args.gamma, + max_grad_norm=args.max_grad_norm, + eps_clip=args.eps_clip, + vf_coef=args.vf_coef, + ent_coef=args.ent_coef, + reward_normalization=args.rew_norm, + advantage_normalization=args.norm_adv, + recompute_advantage=args.recompute_adv, + dual_clip=args.dual_clip, + value_clip=args.value_clip, + gae_lambda=args.gae_lambda, + action_space=env.action_space, + ) + # collector + train_collector = Collector( + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)) + ) + test_collector = Collector(policy, test_envs) + # log + log_path = os.path.join(args.logdir, args.task, 'gail') + writer = SummaryWriter(log_path) + logger = TensorboardLogger(writer, save_interval=args.save_interval) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return mean_rewards >= args.reward_threshold + + def save_checkpoint_fn(epoch, env_step, gradient_step): + # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html + torch.save( + { + 'model': policy.state_dict(), + 'optim': optim.state_dict(), + }, os.path.join(log_path, 'checkpoint.pth') + ) + + if args.resume: + # load from existing checkpoint + print(f"Loading agent under {log_path}") + ckpt_path = os.path.join(log_path, 'checkpoint.pth') + if os.path.exists(ckpt_path): + checkpoint = torch.load(ckpt_path, map_location=args.device) + policy.load_state_dict(checkpoint['model']) + optim.load_state_dict(checkpoint['optim']) + print("Successfully restore policy and optim.") + else: + print("Fail to restore policy and optim.") + + # trainer + result = onpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + episode_per_collect=args.episode_per_collect, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + resume_from_log=args.resume, + save_checkpoint_fn=save_checkpoint_fn, + ) + assert stop_fn(result['best_reward']) + + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! 
+ env = gym.make(args.task) + policy.eval() + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + + +if __name__ == '__main__': + test_gail() diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index ced11aff5..dae9da638 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -24,6 +24,7 @@ from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy +from tianshou.policy.imitation.gail import GAILPolicy from tianshou.policy.modelbased.psrl import PSRLPolicy from tianshou.policy.modelbased.icm import ICMPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager @@ -52,6 +53,7 @@ "DiscreteBCQPolicy", "DiscreteCQLPolicy", "DiscreteCRRPolicy", + "GAILPolicy", "PSRLPolicy", "ICMPolicy", "MultiAgentPolicyManager", diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py new file mode 100644 index 000000000..d4779326b --- /dev/null +++ b/tianshou/policy/imitation/gail.py @@ -0,0 +1,139 @@ +from typing import Any, Dict, List, Optional, Type + +import numpy as np +import torch +import torch.nn.functional as F + +from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch +from tianshou.policy import PPOPolicy + + +class GAILPolicy(PPOPolicy): + r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476. + + :param torch.nn.Module actor: the actor network following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param torch.nn.Module critic: the critic network. (s -> V(s)) + :param torch.optim.Optimizer optim: the optimizer for actor and critic network. + :param dist_fn: distribution class for computing the action. + :type dist_fn: Type[torch.distributions.Distribution] + :param ReplayBuffer expert_buffer: the replay buffer contains expert experience. + :param torch.nn.Module disc_net: the discriminator network with input dim equals + state dim plus action dim and output dim equals 1. + :param torch.optim.Optimizer disc_optim: the optimizer for the discriminator + network. + :param int disc_update_num: the number of discriminator grad steps per model grad + step. Default to 4. + :param float discount_factor: in [0, 1]. Default to 0.99. + :param float eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original + paper. Default to 0.2. + :param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5, + where c > 1 is a constant indicating the lower bound. + Default to 5.0 (set None if you do not want to use it). + :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1. + Default to True. + :param bool advantage_normalization: whether to do per mini-batch advantage + normalization. Default to True. + :param bool recompute_advantage: whether to recompute advantage every update + repeat according to https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5. + Default to False. + :param float vf_coef: weight for value loss. Default to 0.5. + :param float ent_coef: weight for entropy loss. Default to 0.01. + :param float max_grad_norm: clipping gradients in back propagation. Default to + None. + :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation. + Default to 0.95. 
+ :param bool reward_normalization: normalize estimated values to have std close + to 1, also normalize the advantage to Normal(0, 1). Default to False. + :param int max_batchsize: the maximum size of the batch when computing GAE, + depends on the size of available memory and the memory cost of the model; + should be as large as possible within the memory constraint. Default to 256. + :param bool action_scaling: whether to map actions from range [-1, 1] to range + [action_spaces.low, action_spaces.high]. Default to True. + :param str action_bound_method: method to bound action to range [-1, 1], can be + either "clip" (for simply clipping the action), "tanh" (for applying tanh + squashing) for now, or empty string for no bounding. Default to "clip". + :param Optional[gym.Space] action_space: env's action space, mandatory if you want + to use option "action_scaling" or "action_bound_method". Default to None. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). Default to None (no lr_scheduler). + :param bool deterministic_eval: whether to use deterministic action instead of + stochastic action sampled by the policy. Default to False. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.PPOPolicy` for more detailed + explanation. + """ + + def __init__( + self, + actor: torch.nn.Module, + critic: torch.nn.Module, + optim: torch.optim.Optimizer, + dist_fn: Type[torch.distributions.Distribution], + expert_buffer: ReplayBuffer, + disc_net: torch.nn.Module, + disc_optim: torch.optim.Optimizer, + disc_update_num: int = 4, + eps_clip: float = 0.2, + dual_clip: Optional[float] = None, + value_clip: bool = False, + advantage_normalization: bool = True, + recompute_advantage: bool = False, + **kwargs: Any, + ) -> None: + super().__init__( + actor, critic, optim, dist_fn, eps_clip, dual_clip, value_clip, + advantage_normalization, recompute_advantage, **kwargs + ) + self.disc_net = disc_net + self.disc_optim = disc_optim + self.disc_update_num = disc_update_num + self.expert_buffer = expert_buffer + self.action_dim = actor.output_dim + + def process_fn( + self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray + ) -> Batch: + """Pre-process the data from the provided replay buffer. + + Used in :meth:`update`. Check out :ref:`process_fn` for more information. 
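+
+        In GAIL, the reward stored in ``batch`` is first replaced by the
+        discriminator-based reward ``-logsigmoid(-D(s, a))``, so the subsequent
+        PPO advantage estimation uses the learned reward rather than the
+        environment reward.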
+ """ + # update reward + with torch.no_grad(): + batch.rew = to_numpy(-F.logsigmoid(-self.disc(batch)).flatten()) + return super().process_fn(batch, buffer, indices) + + def disc(self, batch: Batch) -> torch.Tensor: + obs = to_torch(batch.obs, device=self.disc_net.device) # type: ignore + act = to_torch(batch.act, device=self.disc_net.device) # type: ignore + return self.disc_net(torch.cat([obs, act], dim=1)) # type: ignore + + def learn( # type: ignore + self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any + ) -> Dict[str, List[float]]: + # update discriminator + losses = [] + acc_pis = [] + acc_exps = [] + bsz = len(batch) // self.disc_update_num + for b in batch.split(bsz, merge_last=True): + logits_pi = self.disc(b) + exp_b = self.expert_buffer.sample(bsz)[0] + logits_exp = self.disc(exp_b) + loss_pi = -F.logsigmoid(-logits_pi).mean() + loss_exp = -F.logsigmoid(logits_exp).mean() + loss_disc = loss_pi + loss_exp + self.disc_optim.zero_grad() + loss_disc.backward() + self.disc_optim.step() + losses.append(loss_disc.item()) + acc_pis.append((logits_pi < 0).float().mean().item()) + acc_exps.append((logits_exp > 0).float().mean().item()) + # update policy + res = super().learn(batch, batch_size, repeat, **kwargs) + res["loss/disc"] = losses + res["stats/acc_pi"] = acc_pis + res["stats/acc_exp"] = acc_exps + return res