diff --git a/README.md b/README.md
index 861788514..f6d79c56d 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,7 @@
 - [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf)
 - [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
 - [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
+- [Natural Policy Gradient (NPG)](https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf)
 - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
 - [Trust Region Policy Optimization (TRPO)](https://arxiv.org/pdf/1502.05477.pdf)
 - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
diff --git a/docs/api/tianshou.policy.rst b/docs/api/tianshou.policy.rst
index 0664e0107..aa24897c4 100644
--- a/docs/api/tianshou.policy.rst
+++ b/docs/api/tianshou.policy.rst
@@ -43,6 +43,11 @@ On-policy
    :undoc-members:
    :show-inheritance:
 
+.. autoclass:: tianshou.policy.NPGPolicy
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 .. autoclass:: tianshou.policy.A2CPolicy
    :members:
    :undoc-members:
diff --git a/docs/index.rst b/docs/index.rst
index bd19d5618..08ed3245d 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -15,6 +15,7 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.C51Policy` `Categorical DQN <https://arxiv.org/pdf/1707.06887.pdf>`_
 * :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
 * :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
+* :class:`~tianshou.policy.NPGPolicy` `Natural Policy Gradient <https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf>`_
 * :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
 * :class:`~tianshou.policy.TRPOPolicy` `Trust Region Policy Optimization <https://arxiv.org/pdf/1502.05477.pdf>`_
 * :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py
new file mode 100644
index 000000000..d5172fa8b
--- /dev/null
+++ b/test/continuous/test_npg.py
@@ -0,0 +1,136 @@
+import os
+import gym
+import torch
+import pprint
+import argparse
+import numpy as np
+from torch import nn
+from torch.utils.tensorboard import SummaryWriter
+from torch.distributions import Independent, Normal
+
+from tianshou.policy import NPGPolicy
+from tianshou.utils import BasicLogger
+from tianshou.env import DummyVectorEnv
+from tianshou.utils.net.common import Net
+from tianshou.trainer import onpolicy_trainer
+from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.utils.net.continuous import ActorProb, Critic
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='Pendulum-v0')
+    parser.add_argument('--seed', type=int, default=1)
+    parser.add_argument('--buffer-size', type=int, default=50000)
+    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--gamma', type=float, default=0.95)
+    parser.add_argument('--epoch', type=int, default=5)
+    parser.add_argument('--step-per-epoch', type=int, default=50000)
+    parser.add_argument('--step-per-collect', type=int, default=2048)
+    parser.add_argument('--repeat-per-collect', type=int,
+                        default=2)  # theoretically it should be 1
+    parser.add_argument('--batch-size', type=int, default=99999)
+    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
+    parser.add_argument('--training-num', type=int, default=16)
+    parser.add_argument('--test-num', type=int, default=10)
+    parser.add_argument('--logdir', type=str, default='log')
+    parser.add_argument('--render', type=float, default=0.)
+    parser.add_argument(
+        '--device', type=str,
+        default='cuda' if torch.cuda.is_available() else 'cpu')
+    # npg special
+    parser.add_argument('--gae-lambda', type=float, default=0.95)
+    parser.add_argument('--rew-norm', type=int, default=1)
+    parser.add_argument('--norm-adv', type=int, default=1)
+    parser.add_argument('--optim-critic-iters', type=int, default=5)
+    parser.add_argument('--actor-step-size', type=float, default=0.5)
+    args = parser.parse_known_args()[0]
+    return args
+
+
+def test_npg(args=get_args()):
+    env = gym.make(args.task)
+    if args.task == 'Pendulum-v0':
+        env.spec.reward_threshold = -250
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    args.max_action = env.action_space.high[0]
+    # you can also use tianshou.env.SubprocVectorEnv
+    # train_envs = gym.make(args.task)
+    train_envs = DummyVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.training_num)])
+    # test_envs = gym.make(args.task)
+    test_envs = DummyVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.test_num)])
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    train_envs.seed(args.seed)
+    test_envs.seed(args.seed)
+    # model
+    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+              activation=nn.Tanh, device=args.device)
+    actor = ActorProb(net, args.action_shape, max_action=args.max_action,
+                      unbounded=True, device=args.device).to(args.device)
+    critic = Critic(Net(
+        args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device,
+        activation=nn.Tanh), device=args.device).to(args.device)
+    # orthogonal initialization
+    for m in list(actor.modules()) + list(critic.modules()):
+        if isinstance(m, torch.nn.Linear):
+            torch.nn.init.orthogonal_(m.weight)
+            torch.nn.init.zeros_(m.bias)
+    optim = torch.optim.Adam(set(
+        actor.parameters()).union(critic.parameters()), lr=args.lr)
+
+    # replace DiagGaussian with Independent(Normal) which is equivalent
+    # pass *logits to be consistent with policy.forward
+    def dist(*logits):
+        return Independent(Normal(*logits), 1)
+
+    policy = NPGPolicy(
+        actor, critic, optim, dist,
+        discount_factor=args.gamma,
+        reward_normalization=args.rew_norm,
+        advantage_normalization=args.norm_adv,
+        gae_lambda=args.gae_lambda,
+        action_space=env.action_space,
+        optim_critic_iters=args.optim_critic_iters,
+        actor_step_size=args.actor_step_size)
+    # collector
+    train_collector = Collector(
+        policy, train_envs,
+        VectorReplayBuffer(args.buffer_size, len(train_envs)))
+    test_collector = Collector(policy, test_envs)
+    # log
+    log_path = os.path.join(args.logdir, args.task, 'npg')
+    writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold
+
+    # trainer
+    result = onpolicy_trainer(
+        policy, train_collector, test_collector, args.epoch,
+        args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
+        step_per_collect=args.step_per_collect, stop_fn=stop_fn, save_fn=save_fn,
+        logger=logger)
+    assert stop_fn(result['best_reward'])
+
+    if __name__ == '__main__':
+        pprint.pprint(result)
+        # Let's watch its performance!
+        env = gym.make(args.task)
+        policy.eval()
+        collector = Collector(policy, env)
+        result = collector.collect(n_episode=1, render=args.render)
+        rews, lens = result["rews"], result["lens"]
+        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
+
+
+if __name__ == '__main__':
+    test_npg()
diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py
index 4b2dc08dc..9db4f449c 100644
--- a/test/continuous/test_trpo.py
+++ b/test/continuous/test_trpo.py
@@ -27,8 +27,7 @@ def get_args():
     parser.add_argument('--epoch', type=int, default=5)
     parser.add_argument('--step-per-epoch', type=int, default=50000)
     parser.add_argument('--step-per-collect', type=int, default=2048)
-    parser.add_argument('--repeat-per-collect', type=int,
-                        default=2)  # theoretically it should be 1
+    parser.add_argument('--repeat-per-collect', type=int, default=1)
     parser.add_argument('--batch-size', type=int, default=99999)
     parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
     parser.add_argument('--training-num', type=int, default=16)
@@ -43,7 +42,7 @@ def get_args():
     parser.add_argument('--rew-norm', type=int, default=1)
     parser.add_argument('--norm-adv', type=int, default=1)
    parser.add_argument('--optim-critic-iters', type=int, default=5)
-    parser.add_argument('--max-kl', type=float, default=0.01)
+    parser.add_argument('--max-kl', type=float, default=0.005)
     parser.add_argument('--backtrack-coeff', type=float, default=0.8)
     parser.add_argument('--max-backtracks', type=int, default=10)
 
diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py
index d087a89b1..f0177d9cb 100644
--- a/tianshou/policy/__init__.py
+++ b/tianshou/policy/__init__.py
@@ -5,6 +5,7 @@
 from tianshou.policy.modelfree.qrdqn import QRDQNPolicy
 from tianshou.policy.modelfree.pg import PGPolicy
 from tianshou.policy.modelfree.a2c import A2CPolicy
+from tianshou.policy.modelfree.npg import NPGPolicy
 from tianshou.policy.modelfree.ddpg import DDPGPolicy
 from tianshou.policy.modelfree.ppo import PPOPolicy
 from tianshou.policy.modelfree.trpo import TRPOPolicy
@@ -25,6 +26,7 @@
     "QRDQNPolicy",
     "PGPolicy",
     "A2CPolicy",
+    "NPGPolicy",
     "DDPGPolicy",
     "PPOPolicy",
     "TRPOPolicy",
diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py
new file mode 100644
index 000000000..abb6396af
--- /dev/null
+++ b/tianshou/policy/modelfree/npg.py
@@ -0,0 +1,182 @@
+import torch
+import numpy as np
+from torch import nn
+import torch.nn.functional as F
+from typing import Any, Dict, List, Type
+from torch.distributions import kl_divergence
+
+
+from tianshou.policy import A2CPolicy
+from tianshou.data import Batch, ReplayBuffer
+
+
+class NPGPolicy(A2CPolicy):
+    """Implementation of Natural Policy Gradient.
+
+    https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf
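+
+    The natural gradient preconditions the vanilla policy gradient with the
+    inverse Fisher information matrix of the policy (obtained here from the
+    Hessian of the KL divergence), so the update direction does not depend on
+    how the policy is parameterized.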
+
+    :param torch.nn.Module actor: the actor network following the rules in
+        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
+    :param torch.nn.Module critic: the critic network. (s -> V(s))
+    :param torch.optim.Optimizer optim: the optimizer for actor and critic network.
+    :param dist_fn: distribution class for computing the action.
+    :type dist_fn: Type[torch.distributions.Distribution]
+    :param bool advantage_normalization: whether to do per mini-batch advantage
+        normalization. Default to True.
+    :param int optim_critic_iters: Number of times to optimize critic network per
+        update. Default to 5.
+    :param float actor_step_size: step size for the natural gradient update of the
+        actor parameters. Default to 0.5.
+    :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
+        Default to 0.95.
+    :param bool reward_normalization: normalize estimated values to have std close to
+        1. Default to False.
+    :param int max_batchsize: the maximum size of the batch when computing GAE,
+        depends on the size of available memory and the memory cost of the
+        model; should be as large as possible within the memory constraint.
+        Default to 256.
+    :param bool action_scaling: whether to map actions from range [-1, 1] to range
+        [action_spaces.low, action_spaces.high]. Default to True.
+    :param str action_bound_method: method to bound action to range [-1, 1], can be
+        either "clip" (for simply clipping the action), "tanh" (for applying tanh
+        squashing) for now, or empty string for no bounding. Default to "clip".
+    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
+        to use option "action_scaling" or "action_bound_method". Default to None.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
+    """
+
+    def __init__(
+        self,
+        actor: torch.nn.Module,
+        critic: torch.nn.Module,
+        optim: torch.optim.Optimizer,
+        dist_fn: Type[torch.distributions.Distribution],
+        advantage_normalization: bool = True,
+        optim_critic_iters: int = 5,
+        actor_step_size: float = 0.5,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(actor, critic, optim, dist_fn, **kwargs)
+        del self._weight_vf, self._weight_ent, self._grad_norm
+        self._norm_adv = advantage_normalization
+        self._optim_critic_iters = optim_critic_iters
+        self._step_size = actor_step_size
+        # adjusts Hessian-vector product calculation for numerical stability
+        self._damping = 0.1
+
+    def process_fn(
+        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
+    ) -> Batch:
+        batch = super().process_fn(batch, buffer, indice)
+        old_log_prob = []
+        with torch.no_grad():
+            for b in batch.split(self._batch, shuffle=False, merge_last=True):
+                old_log_prob.append(self(b).dist.log_prob(b.act))
+        batch.logp_old = torch.cat(old_log_prob, dim=0)
+        if self._norm_adv:
+            batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std()
+        return batch
+
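+    # learn() below implements the natural gradient update
+    #     theta <- theta + actor_step_size * F^-1 g,
+    # where g is the vanilla policy gradient and F is the Fisher information
+    # matrix, obtained as the Hessian of the KL divergence of the policy with
+    # respect to the actor parameters. F^-1 g is computed with conjugate
+    # gradients (_conjugate_gradients) from Fisher-vector products (_MVP),
+    # so F is never built explicitly.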
+    def learn(  # type: ignore
+        self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
+    ) -> Dict[str, List[float]]:
+        actor_losses, vf_losses, kls = [], [], []
+        for step in range(repeat):
+            for b in batch.split(batch_size, merge_last=True):
+                # optimize actor
+                # direction: calculate vanilla gradient
+                dist = self(b).dist  # TODO could come from batch
+                ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
+                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
+                actor_loss = -(ratio * b.adv).mean()
+                flat_grads = self._get_flat_grad(
+                    actor_loss, self.actor, retain_graph=True).detach()
+
+                # direction: calculate natural gradient
+                with torch.no_grad():
+                    old_dist = self(b).dist
+
+                kl = kl_divergence(old_dist, dist).mean()
+                # calculate first order gradient of kl with respect to theta
+                flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True)
+                search_direction = -self._conjugate_gradients(
+                    flat_grads, flat_kl_grad, nsteps=10)
+
+                # step
+                with torch.no_grad():
+                    flat_params = torch.cat([param.data.view(-1)
+                                             for param in self.actor.parameters()])
+                    new_flat_params = flat_params + self._step_size * search_direction
+                    self._set_from_flat_params(self.actor, new_flat_params)
+                    new_dist = self(b).dist
+                    kl = kl_divergence(old_dist, new_dist).mean()
+
+                # optimize critic
+                for _ in range(self._optim_critic_iters):
+                    value = self.critic(b.obs).flatten()
+                    vf_loss = F.mse_loss(b.returns, value)
+                    self.optim.zero_grad()
+                    vf_loss.backward()
+                    self.optim.step()
+
+                actor_losses.append(actor_loss.item())
+                vf_losses.append(vf_loss.item())
+                kls.append(kl.item())
+
+        # update learning rate if lr_scheduler is given
+        if self.lr_scheduler is not None:
+            self.lr_scheduler.step()
+
+        return {
+            "loss/actor": actor_losses,
+            "loss/vf": vf_losses,
+            "kl": kls,
+        }
+
+    def _MVP(self, v: torch.Tensor, flat_kl_grad: torch.Tensor) -> torch.Tensor:
+        """Matrix vector product."""
+        # calculate second order gradient of kl with respect to theta
+        kl_v = (flat_kl_grad * v).sum()
+        flat_kl_grad_grad = self._get_flat_grad(
+            kl_v, self.actor, retain_graph=True).detach()
+        return flat_kl_grad_grad + v * self._damping
+
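+    # Conjugate gradient solver for F x = b: each iteration only needs the
+    # (damped) Fisher-vector product _MVP(p, flat_kl_grad); nsteps bounds the
+    # number of iterations and residual_tol allows early termination.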
+    def _conjugate_gradients(
+        self,
+        b: torch.Tensor,
+        flat_kl_grad: torch.Tensor,
+        nsteps: int = 10,
+        residual_tol: float = 1e-10
+    ) -> torch.Tensor:
+        x = torch.zeros_like(b)
+        r, p = b.clone(), b.clone()
+        # Note: should be 'r, p = b - MVP(x)', but for x=0, MVP(x)=0.
+        # Change if doing warm start.
+        rdotr = r.dot(r)
+        for i in range(nsteps):
+            z = self._MVP(p, flat_kl_grad)
+            alpha = rdotr / p.dot(z)
+            x += alpha * p
+            r -= alpha * z
+            new_rdotr = r.dot(r)
+            if new_rdotr < residual_tol:
+                break
+            p = r + new_rdotr / rdotr * p
+            rdotr = new_rdotr
+        return x
+
+    def _get_flat_grad(
+        self, y: torch.Tensor, model: nn.Module, **kwargs: Any
+    ) -> torch.Tensor:
+        grads = torch.autograd.grad(y, model.parameters(), **kwargs)  # type: ignore
+        return torch.cat([grad.reshape(-1) for grad in grads])
+
+    def _set_from_flat_params(
+        self, model: nn.Module, flat_params: torch.Tensor
+    ) -> nn.Module:
+        prev_ind = 0
+        for param in model.parameters():
+            flat_size = int(np.prod(list(param.size())))
+            param.data.copy_(
+                flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
+            prev_ind += flat_size
+        return model
diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py
index 32ba13976..9d456878c 100644
--- a/tianshou/policy/modelfree/trpo.py
+++ b/tianshou/policy/modelfree/trpo.py
@@ -1,56 +1,15 @@
 import torch
 import warnings
-import numpy as np
-from torch import nn
 import torch.nn.functional as F
+from typing import Any, Dict, List, Type
 from torch.distributions import kl_divergence
-from typing import Any, Dict, List, Type, Callable
-
-
-from tianshou.policy import A2CPolicy
-from tianshou.data import Batch, ReplayBuffer
-
-
-def _conjugate_gradients(
-    Avp: Callable[[torch.Tensor], torch.Tensor],
-    b: torch.Tensor,
-    nsteps: int = 10,
-    residual_tol: float = 1e-10
-) -> torch.Tensor:
-    x = torch.zeros_like(b)
-    r, p = b.clone(), b.clone()
-    # Note: should be 'r, p = b - A(x)', but for x=0, A(x)=0.
-    # Change if doing warm start.
-    rdotr = r.dot(r)
-    for i in range(nsteps):
-        z = Avp(p)
-        alpha = rdotr / p.dot(z)
-        x += alpha * p
-        r -= alpha * z
-        new_rdotr = r.dot(r)
-        if new_rdotr < residual_tol:
-            break
-        p = r + new_rdotr / rdotr * p
-        rdotr = new_rdotr
-    return x
-
-
-def _get_flat_grad(y: torch.Tensor, model: nn.Module, **kwargs: Any) -> torch.Tensor:
-    grads = torch.autograd.grad(y, model.parameters(), **kwargs)  # type: ignore
-    return torch.cat([grad.reshape(-1) for grad in grads])
-
-
-def _set_from_flat_params(model: nn.Module, flat_params: torch.Tensor) -> nn.Module:
-    prev_ind = 0
-    for param in model.parameters():
-        flat_size = int(np.prod(list(param.size())))
-        param.data.copy_(
-            flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
-        prev_ind += flat_size
-    return model
-
-
-class TRPOPolicy(A2CPolicy):
+
+
+from tianshou.data import Batch
+from tianshou.policy import NPGPolicy
+
+
+class TRPOPolicy(NPGPolicy):
     """Implementation of Trust Region Policy Optimization. arXiv:1502.05477.
 
     :param torch.nn.Module actor: the actor network following the rules in
@@ -94,35 +53,16 @@ def __init__(
         critic: torch.nn.Module,
         optim: torch.optim.Optimizer,
         dist_fn: Type[torch.distributions.Distribution],
-        advantage_normalization: bool = True,
-        optim_critic_iters: int = 5,
         max_kl: float = 0.01,
         backtrack_coeff: float = 0.8,
         max_backtracks: int = 10,
         **kwargs: Any,
     ) -> None:
         super().__init__(actor, critic, optim, dist_fn, **kwargs)
-        del self._weight_vf, self._weight_ent, self._grad_norm
-        self._norm_adv = advantage_normalization
-        self._optim_critic_iters = optim_critic_iters
+        del self._step_size
         self._max_backtracks = max_backtracks
         self._delta = max_kl
         self._backtrack_coeff = backtrack_coeff
-        # adjusts Hessian-vector product calculation for numerical stability
-        self.__damping = 0.1
-
-    def process_fn(
-        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
-    ) -> Batch:
-        batch = super().process_fn(batch, buffer, indice)
-        old_log_prob = []
-        with torch.no_grad():
-            for b in batch.split(self._batch, shuffle=False, merge_last=True):
-                old_log_prob.append(self(b).dist.log_prob(b.act))
-        batch.logp_old = torch.cat(old_log_prob, dim=0)
-        if self._norm_adv:
-            batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std()
-        return batch
 
     def learn(  # type: ignore
         self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
@@ -136,7 +76,7 @@ def learn(  # type: ignore
                 ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
                 ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                 actor_loss = -(ratio * b.adv).mean()
-                flat_grads = _get_flat_grad(
+                flat_grads = self._get_flat_grad(
                     actor_loss, self.actor, retain_graph=True).detach()
 
                 # direction: calculate natural gradient
@@ -145,20 +85,14 @@ def learn(  # type: ignore
                 kl = kl_divergence(old_dist, dist).mean()
                 # calculate first order gradient of kl with respect to theta
-                flat_kl_grad = _get_flat_grad(kl, self.actor, create_graph=True)
-
-                def MVP(v: torch.Tensor) -> torch.Tensor:  # matrix vector product
-                    # caculate second order gradient of kl with respect to theta
-                    kl_v = (flat_kl_grad * v).sum()
-                    flat_kl_grad_grad = _get_flat_grad(
-                        kl_v, self.actor, retain_graph=True).detach()
-                    return flat_kl_grad_grad + v * self.__damping
-
-                search_direction = -_conjugate_gradients(MVP, flat_grads, nsteps=10)
+                flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True)
+                search_direction = -self._conjugate_gradients(
+                    flat_grads, flat_kl_grad, nsteps=10)
 
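+                # The largest step size beta along s = search_direction that
+                # keeps the quadratic KL approximation 0.5 * beta^2 * s^T F s
+                # within the trust region delta is beta = sqrt(2 * delta / s^T F s).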
                 # stepsize: calculate max stepsize constrained by kl bound
                 step_size = torch.sqrt(2 * self._delta / (
-                    search_direction * MVP(search_direction)).sum(0, keepdim=True))
+                    search_direction * self._MVP(search_direction, flat_kl_grad)
+                    ).sum(0, keepdim=True))
 
                 # stepsize: linesearch stepsize
                 with torch.no_grad():
@@ -166,7 +100,7 @@ def MVP(v: torch.Tensor) -> torch.Tensor:  # matrix vector product
                     flat_params = torch.cat([param.data.view(-1)
                                              for param in self.actor.parameters()])
                     for i in range(self._max_backtracks):
                         new_flat_params = flat_params + step_size * search_direction
-                        _set_from_flat_params(self.actor, new_flat_params)
+                        self._set_from_flat_params(self.actor, new_flat_params)
                         # calculate kl and if in bound, loss actually down
                         new_dist = self(b).dist
                         new_dratio = (
@@ -183,7 +117,7 @@ def MVP(v: torch.Tensor) -> torch.Tensor:  # matrix vector product
                     elif i < self._max_backtracks - 1:
                         step_size = step_size * self._backtrack_coeff
                     else:
-                        _set_from_flat_params(self.actor, new_flat_params)
+                        self._set_from_flat_params(self.actor, new_flat_params)
                         step_size = torch.tensor([0.0])
                         warnings.warn("Line search failed! It seems hyperparamters"
                                       " are poor and need to be changed.")