From 27f8abe1208fc843092b3dda58c35a219af2dd29 Mon Sep 17 00:00:00 2001 From: Yi Su Date: Sat, 30 Apr 2022 01:06:23 +0800 Subject: [PATCH 1/3] implement REDQ based on original contribution by @Jimenius --- README.md | 1 + docs/api/tianshou.policy.rst | 5 + docs/index.rst | 1 + examples/mujoco/mujoco_redq.py | 192 +++++++++++++++++++++++++++ test/continuous/test_redq.py | 178 +++++++++++++++++++++++++ tianshou/policy/__init__.py | 2 + tianshou/policy/modelfree/redq.py | 207 ++++++++++++++++++++++++++++++ tianshou/utils/net/common.py | 49 ++++++- tianshou/utils/net/continuous.py | 11 +- 9 files changed, 642 insertions(+), 4 deletions(-) create mode 100755 examples/mujoco/mujoco_redq.py create mode 100644 test/continuous/test_redq.py create mode 100644 tianshou/policy/modelfree/redq.py diff --git a/README.md b/README.md index aaf610549..46e3f9a15 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,7 @@ - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf) - [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf) - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf) +- [Randomized Ensembled Double Q-Learning (REDQ)](https://arxiv.org/pdf/2101.05982.pdf) - [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf) - Vanilla Imitation Learning - [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf) diff --git a/docs/api/tianshou.policy.rst b/docs/api/tianshou.policy.rst index a0d9bed99..27197570b 100644 --- a/docs/api/tianshou.policy.rst +++ b/docs/api/tianshou.policy.rst @@ -96,6 +96,11 @@ Off-policy :undoc-members: :show-inheritance: +.. autoclass:: tianshou.policy.REDQPolicy + :members: + :undoc-members: + :show-inheritance: + .. autoclass:: tianshou.policy.DiscreteSACPolicy :members: :undoc-members: diff --git a/docs/index.rst b/docs/index.rst index 6a82e6d52..49453ad62 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -25,6 +25,7 @@ Welcome to Tianshou! 
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient `_ * :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG `_ * :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic `_ +* :class:`~tianshou.policy.REDQPolicy` `Randomized Ensembled Double Q-Learning `_ * :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic `_ * :class:`~tianshou.policy.ImitationPolicy` Imitation Learning * :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning `_ diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py new file mode 100755 index 000000000..d580ef9b3 --- /dev/null +++ b/examples/mujoco/mujoco_redq.py @@ -0,0 +1,192 @@ +#!/usr/bin/env python3 + +import argparse +import datetime +import os +import pprint + +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv +from tianshou.policy import REDQPolicy +from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import EnsembleLinear, Net +from tianshou.utils.net.continuous import ActorProb, Critic + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='Ant-v3') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--buffer-size', type=int, default=1000000) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[256, 256]) + parser.add_argument('--ensemble-size', type=int, default=10) + parser.add_argument('--subset-size', type=int, default=2) + parser.add_argument('--actor-lr', type=float, default=1e-3) + parser.add_argument('--critic-lr', type=float, default=1e-3) + parser.add_argument('--gamma', type=float, default=0.99) + parser.add_argument('--tau', type=float, default=0.005) + parser.add_argument('--alpha', type=float, default=0.2) + parser.add_argument('--auto-alpha', default=False, action='store_true') + parser.add_argument('--alpha-lr', type=float, default=3e-4) + parser.add_argument("--start-timesteps", type=int, default=10000) + parser.add_argument('--epoch', type=int, default=200) + parser.add_argument('--step-per-epoch', type=int, default=5000) + parser.add_argument('--step-per-collect', type=int, default=1) + parser.add_argument('--update-per-step', type=int, default=20) + parser.add_argument('--n-step', type=int, default=1) + parser.add_argument('--batch-size', type=int, default=256) + parser.add_argument( + '--target-mode', type=str, choices=('min', 'mean'), default='min' + ) + parser.add_argument('--training-num', type=int, default=1) + parser.add_argument('--test-num', type=int, default=10) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) 
+ parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) + return parser.parse_args() + + +def test_redq(args=get_args()): + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) + # train_envs = gym.make(args.task) + if args.training_num > 1: + train_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) + else: + train_envs = gym.make(args.task) + # test_envs = gym.make(args.task) + test_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = ActorProb( + net_a, + args.action_shape, + max_action=args.max_action, + device=args.device, + unbounded=True, + conditioned_sigma=True + ).to(args.device) + actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + + def linear(x, y): + return EnsembleLinear(args.ensemble_size, x, y) + + net_c = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + linear_layer=linear, + ) + critics = Critic( + net_c, + device=args.device, + linear_layer=linear, + flatten_input=False, + ).to(args.device) + critics_optim = torch.optim.Adam(critics.parameters(), lr=args.critic_lr) + + if args.auto_alpha: + target_entropy = -np.prod(env.action_space.shape) + log_alpha = torch.zeros(1, requires_grad=True, device=args.device) + alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) + args.alpha = (target_entropy, log_alpha, alpha_optim) + + policy = REDQPolicy( + actor, + actor_optim, + critics, + critics_optim, + args.ensemble_size, + args.subset_size, + tau=args.tau, + gamma=args.gamma, + alpha=args.alpha, + estimation_step=args.n_step, + actor_delay=args.update_per_step, + target_mode=args.target_mode, + action_space=env.action_space, + ) + + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + + # collector + if args.training_num > 1: + buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) + else: + buffer = ReplayBuffer(args.buffer_size) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs) + train_collector.collect(n_step=args.start_timesteps, random=True) + # log + t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") + log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_redq' + log_path = os.path.join(args.logdir, args.task, 'redq', log_file) + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = TensorboardLogger(writer) + + def save_best_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + if not args.watch: + # trainer + 
result = offpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) + pprint.pprint(result) + + # Let's watch its performance! + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, render=args.render) + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + + +if __name__ == '__main__': + test_redq() diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py new file mode 100644 index 000000000..8b649a2be --- /dev/null +++ b/test/continuous/test_redq.py @@ -0,0 +1,178 @@ +import argparse +import os +import pprint + +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.policy import REDQPolicy +from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import EnsembleLinear, Net +from tianshou.utils.net.continuous import ActorProb, Critic + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='Pendulum-v1') + parser.add_argument('--reward-threshold', type=float, default=None) + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--ensemble-size', type=int, default=4) + parser.add_argument('--subset-size', type=int, default=2) + parser.add_argument('--actor-lr', type=float, default=1e-4) + parser.add_argument('--critic-lr', type=float, default=1e-3) + parser.add_argument('--gamma', type=float, default=0.99) + parser.add_argument('--tau', type=float, default=0.005) + parser.add_argument('--alpha', type=float, default=0.2) + parser.add_argument('--auto-alpha', action='store_true', default=False) + parser.add_argument('--alpha-lr', type=float, default=3e-4) + parser.add_argument("--start-timesteps", type=int, default=1000) + parser.add_argument('--epoch', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=5000) + parser.add_argument('--step-per-collect', type=int, default=1) + parser.add_argument('--update-per-step', type=int, default=3) + parser.add_argument('--n-step', type=int, default=1) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument( + '--target-mode', type=str, choices=('min', 'mean'), default='min' + ) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) + parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) 
+ parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + args = parser.parse_known_args()[0] + return args + + +def test_redq(args=get_args()): + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] + if args.reward_threshold is None: + default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} + args.reward_threshold = default_reward_threshold.get( + args.task, env.spec.reward_threshold + ) + # you can also use tianshou.env.SubprocVectorEnv + # train_envs = gym.make(args.task) + train_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) + # test_envs = gym.make(args.task) + test_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = ActorProb( + net, + args.action_shape, + max_action=args.max_action, + device=args.device, + unbounded=True, + conditioned_sigma=True + ).to(args.device) + actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + + def linear(x, y): + return EnsembleLinear(args.ensemble_size, x, y) + + net_c = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + linear_layer=linear, + ) + critic = Critic( + net_c, device=args.device, linear_layer=linear, flatten_input=False + ).to(args.device) + critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) + + if args.auto_alpha: + target_entropy = -np.prod(env.action_space.shape) + log_alpha = torch.zeros(1, requires_grad=True, device=args.device) + alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) + args.alpha = (target_entropy, log_alpha, alpha_optim) + + policy = REDQPolicy( + actor, + actor_optim, + critic, + critic_optim, + args.ensemble_size, + args.subset_size, + tau=args.tau, + gamma=args.gamma, + alpha=args.alpha, + estimation_step=args.n_step, + actor_delay=args.update_per_step, + target_mode=args.target_mode, + action_space=env.action_space, + ) + # collector + train_collector = Collector( + policy, + train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True + ) + test_collector = Collector(policy, test_envs) + train_collector.collect(n_step=args.start_timesteps, random=True) + # log + log_path = os.path.join(args.logdir, args.task, 'redq') + writer = SummaryWriter(log_path) + logger = TensorboardLogger(writer) + + def save_best_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return mean_rewards >= args.reward_threshold + + # trainer + result = offpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + update_per_step=args.update_per_step, + stop_fn=stop_fn, + save_best_fn=save_best_fn, + logger=logger + ) + assert stop_fn(result['best_reward']) + + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! 
+ env = gym.make(args.task) + policy.eval() + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + + +if __name__ == '__main__': + test_redq() diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index dae9da638..f7774cba1 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -17,6 +17,7 @@ from tianshou.policy.modelfree.trpo import TRPOPolicy from tianshou.policy.modelfree.td3 import TD3Policy from tianshou.policy.modelfree.sac import SACPolicy +from tianshou.policy.modelfree.redq import REDQPolicy from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy from tianshou.policy.imitation.base import ImitationPolicy from tianshou.policy.imitation.bcq import BCQPolicy @@ -46,6 +47,7 @@ "TRPOPolicy", "TD3Policy", "SACPolicy", + "REDQPolicy", "DiscreteSACPolicy", "ImitationPolicy", "BCQPolicy", diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py new file mode 100644 index 000000000..88d1cd95e --- /dev/null +++ b/tianshou/policy/modelfree/redq.py @@ -0,0 +1,207 @@ +from copy import deepcopy +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch +from torch.distributions import Independent, Normal + +from tianshou.data import Batch, ReplayBuffer, to_torch_as +from tianshou.exploration import BaseNoise +from tianshou.policy import DDPGPolicy + + +class REDQPolicy(DDPGPolicy): + """Implementation of REDQ. arXiv:2101.05982. + + :param torch.nn.Module actor: the actor network following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param torch.optim.Optimizer actor_optim: the optimizer for actor network. + :param torch.nn.Module critics: critic ensemble networks. + :param torch.optim.Optimizer critics_optim: the optimizer for the critic networks. + :param int ensemble_size: Number of sub-networks in the critic ensemble. + Default to 10. + :param int subset_size: Number of networks in the subset. Default to 2. + :param float tau: param for soft update of the target network. Default to 0.005. + :param float gamma: discount factor, in [0, 1]. Default to 0.99. + :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy + regularization coefficient. Default to 0.2. + If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then + alpha is automatically tuned. + :param bool reward_normalization: normalize the reward to Normal(0, 1). + Default to False. + :param int actor_delay: Number of critic updates before an actor update. + Default to 20. + :param BaseNoise exploration_noise: add a noise to action for exploration. + Default to None. This is useful when solving hard-exploration problem. + :param bool deterministic_eval: whether to use deterministic action (mean + of Gaussian policy) instead of stochastic action sampled by the policy. + Default to True. + :param str target_mode: methods to integrate critic values in the subset, + currently support minimum and average. Default to min. + :param bool action_scaling: whether to map actions from range [-1, 1] to range + [action_spaces.low, action_spaces.high]. Default to True. + :param str action_bound_method: method to bound action to range [-1, 1], can be + either "clip" (for simply clipping the action) or empty string for no bounding. + Default to "clip". 
+ :param Optional[gym.Space] action_space: env's action space, mandatory if you want + to use option "action_scaling" or "action_bound_method". Default to None. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__( + self, + actor: torch.nn.Module, + actor_optim: torch.optim.Optimizer, + critics: torch.nn.Module, + critics_optim: torch.optim.Optimizer, + ensemble_size: int = 10, + subset_size: int = 2, + tau: float = 0.005, + gamma: float = 0.99, + alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2, + reward_normalization: bool = False, + estimation_step: int = 1, + actor_delay: int = 20, + exploration_noise: Optional[BaseNoise] = None, + deterministic_eval: bool = True, + target_mode: str = 'min', + **kwargs: Any, + ) -> None: + super().__init__( + None, None, None, None, tau, gamma, exploration_noise, + reward_normalization, estimation_step, **kwargs + ) + self.actor, self.actor_optim = actor, actor_optim + self.critics, self.critics_old = critics, deepcopy(critics) + self.critics_old.eval() + self.critics_optim = critics_optim + assert 0 < subset_size <= ensemble_size, \ + 'Invalid choice of ensemble size or subset size.' + self.ensemble_size = ensemble_size + self.subset_size = subset_size + + self._is_auto_alpha = False + self._alpha: Union[float, torch.Tensor] + if isinstance(alpha, tuple): + self._is_auto_alpha = True + self._target_entropy, self._log_alpha, self._alpha_optim = alpha + assert alpha[1].shape == torch.Size([1]) and alpha[1].requires_grad + self._alpha = self._log_alpha.detach().exp() + else: + self._alpha = alpha + + if target_mode in ('min', 'mean'): + self.target_mode = target_mode + else: + raise ValueError('Unsupported mode of Q target computing.') + + self.critic_gradient_step = 0 + self.actor_delay = actor_delay + self._deterministic_eval = deterministic_eval + self.__eps = np.finfo(np.float32).eps.item() + + def train(self, mode: bool = True) -> "REDQPolicy": + self.training = mode + self.actor.train(mode) + self.critics.train(mode) + return self + + def sync_weight(self) -> None: + for o, n in zip(self.critics_old.parameters(), self.critics.parameters()): + o.data.copy_(o.data * (1.0 - self.tau) + n.data * self.tau) + + def forward( # type: ignore + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + input: str = "obs", + **kwargs: Any, + ) -> Batch: + obs = batch[input] + logits, h = self.actor(obs, state=state, info=batch.info) + assert isinstance(logits, tuple) + dist = Independent(Normal(*logits), 1) + if self._deterministic_eval and not self.training: + act = logits[0] + else: + act = dist.rsample() + log_prob = dist.log_prob(act).unsqueeze(-1) + # apply correction for Tanh squashing when computing logprob from Gaussian + # You can check out the original SAC paper (arXiv 1801.01290): Eq 21. + # in appendix C to get some understanding of this equation. 
+        if self.action_scaling and self.action_space is not None:
+            action_scale = to_torch_as(
+                (self.action_space.high - self.action_space.low) / 2.0, act
+            )
+        else:
+            action_scale = 1.0  # type: ignore
+        squashed_action = torch.tanh(act)
+        log_prob = log_prob - torch.log(
+            action_scale * (1 - squashed_action.pow(2)) + self.__eps
+        ).sum(-1, keepdim=True)
+        return Batch(
+            logits=logits, act=squashed_action, state=h, dist=dist, log_prob=log_prob
+        )
+
+    def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
+        batch = buffer[indices]  # batch.obs: s_{t+n}
+        obs_next_result = self(batch, input='obs_next')
+        a_ = obs_next_result.act
+        sample_ensemble_idx = np.random.choice(
+            self.ensemble_size, self.subset_size, replace=False
+        )
+        qs = self.critics_old(batch.obs_next, a_)[sample_ensemble_idx, ...]
+        if self.target_mode == 'min':
+            target_q, _ = torch.min(qs, dim=0)
+        elif self.target_mode == 'mean':
+            target_q = torch.mean(qs, dim=0)
+        target_q -= self._alpha * obs_next_result.log_prob
+
+        return target_q
+
+    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+        # critic ensemble
+        weight = getattr(batch, "weight", 1.0)
+        current_qs = self.critics(batch.obs, batch.act).flatten(1)
+        target_q = batch.returns.flatten()
+        td = current_qs - target_q
+        critic_loss = (td.pow(2) * weight).mean()
+        self.critics_optim.zero_grad()
+        critic_loss.backward()
+        self.critics_optim.step()
+        batch.weight = torch.mean(td, dim=0)  # prio-buffer
+        self.critic_gradient_step += 1
+
+        # actor
+        if self.critic_gradient_step % self.actor_delay == 0:
+            obs_result = self(batch)
+            a = obs_result.act
+            current_qa = self.critics(batch.obs, a).mean(dim=0).flatten()
+            actor_loss = (self._alpha * obs_result.log_prob.flatten() -
+                          current_qa).mean()
+            self.actor_optim.zero_grad()
+            actor_loss.backward()
+            self.actor_optim.step()
+
+            if self._is_auto_alpha:
+                log_prob = obs_result.log_prob.detach() + self._target_entropy
+                alpha_loss = -(self._log_alpha * log_prob).mean()
+                self._alpha_optim.zero_grad()
+                alpha_loss.backward()
+                self._alpha_optim.step()
+                self._alpha = self._log_alpha.detach().exp()
+
+        self.sync_weight()
+
+        result = {"loss/critics": critic_loss.item()}
+        if self.critic_gradient_step % self.actor_delay == 0:
+            result["loss/actor"] = actor_loss.item()
+            if self._is_auto_alpha:
+                result["loss/alpha"] = alpha_loss.item()
+                result["alpha"] = self._alpha.item()  # type: ignore
+
+        return result
diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py
index 9d03b6120..477bb14f0 100644
--- a/tianshou/utils/net/common.py
+++ b/tianshou/utils/net/common.py
@@ -46,6 +46,7 @@ class MLP(nn.Module):
         nn.ReLU.
     :param device: which device to create this model on. Default to None.
    :param linear_layer: use this module as linear layer. Default to nn.Linear.
+    :param bool flatten_input: whether to flatten input data. Default to True.
""" def __init__( @@ -57,6 +58,7 @@ def __init__( activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU, device: Optional[Union[str, int, torch.device]] = None, linear_layer: Type[nn.Linear] = nn.Linear, + flatten_input: bool = True, ) -> None: super().__init__() self.device = device @@ -86,6 +88,7 @@ def __init__( model += [linear_layer(hidden_sizes[-1], output_dim)] self.output_dim = output_dim or hidden_sizes[-1] self.model = nn.Sequential(*model) + self.flatten_input = flatten_input def forward(self, obs: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: if self.device is not None: @@ -94,7 +97,9 @@ def forward(self, obs: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: device=self.device, # type: ignore dtype=torch.float32, ) - return self.model(obs.flatten(1)) # type: ignore + if self.flatten_input: + obs = obs.flatten(1) + return self.model(obs) # type: ignore class Net(nn.Module): @@ -129,6 +134,7 @@ class Net(nn.Module): pass a tuple of two dict (first for Q and second for V) stating self-defined arguments as stated in class:`~tianshou.utils.net.common.MLP`. Default to None. + :param linear_layer: use this module as linear layer. Default to nn.Linear. .. seealso:: @@ -152,6 +158,7 @@ def __init__( concat: bool = False, num_atoms: int = 1, dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None, + linear_layer: ModuleType = nn.Linear, ) -> None: super().__init__() self.device = device @@ -164,7 +171,8 @@ def __init__( self.use_dueling = dueling_param is not None output_dim = action_dim if not self.use_dueling and not concat else 0 self.model = MLP( - input_dim, output_dim, hidden_sizes, norm_layer, activation, device + input_dim, output_dim, hidden_sizes, norm_layer, activation, device, + linear_layer ) self.output_dim = self.model.output_dim if self.use_dueling: # dueling DQN @@ -311,3 +319,40 @@ def forward(self, obs: Union[np.ndarray, torch.Tensor], *args: Any, if not isinstance(obs, torch.Tensor): obs = torch.as_tensor(obs, dtype=torch.float32) return self.net(obs=obs.cuda(), *args, **kwargs) + + +class EnsembleLinear(nn.Module): + """Linear Layer of Ensemble network. + + :param int ensemble_size: Number of subnets in the ensemble. + :param int inp_feature: dimension of the input vector. + :param int out_feature: dimension of the output vector. + :param bool bias: whether to include an additive bias, default to be True. + """ + + def __init__( + self, + ensemble_size: int, + in_feature: int, + out_feature: int, + bias: bool = True, + ) -> None: + super().__init__() + + # To be consistent with PyTorch default initializer + k = np.sqrt(1. 
/ in_feature) + weight_data = torch.rand((ensemble_size, in_feature, out_feature)) * 2 * k - k + self.weight = nn.Parameter(weight_data, requires_grad=True) + + self.bias: Union[nn.Parameter, None] + if bias: + bias_data = torch.rand((ensemble_size, 1, out_feature)) * 2 * k - k + self.bias = nn.Parameter(bias_data, requires_grad=True) + else: + self.bias = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.matmul(x, self.weight) + if self.bias is not None: + x = x + self.bias + return x diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index d68f3856f..360f7baeb 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union import numpy as np import torch @@ -79,6 +79,9 @@ class Critic(nn.Module): only a single linear layer). :param int preprocess_net_output_dim: the output dimension of preprocess_net. + :param linear_layer: use this module as linear layer. Default to nn.Linear. + :param bool flatten_input: whether to flatten input data for the last layer. + Default to True. For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. @@ -95,6 +98,8 @@ def __init__( hidden_sizes: Sequence[int] = (), device: Union[str, int, torch.device] = "cpu", preprocess_net_output_dim: Optional[int] = None, + linear_layer: Type[nn.Module] = nn.Linear, + flatten_input: bool = True, ) -> None: super().__init__() self.device = device @@ -105,7 +110,9 @@ def __init__( input_dim, # type: ignore 1, hidden_sizes, - device=self.device + device=self.device, + linear_layer=linear_layer, + flatten_input=flatten_input, ) def forward( From e3f506af233ca507b23befde18874c0c0f0855ed Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Sat, 30 Apr 2022 11:31:29 -0400 Subject: [PATCH 2/3] fix ci --- tianshou/env/pettingzoo_env.py | 9 ++++++--- tianshou/policy/modelfree/redq.py | 27 ++++++++++----------------- tianshou/utils/net/common.py | 23 +++++++++++++++-------- tianshou/utils/net/continuous.py | 2 +- 4 files changed, 32 insertions(+), 29 deletions(-) diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py index 25c34f994..c406872dc 100644 --- a/tianshou/env/pettingzoo_env.py +++ b/tianshou/env/pettingzoo_env.py @@ -55,8 +55,8 @@ def __init__(self, env: BaseWrapper): self.reset() - def reset(self) -> dict: - self.env.reset() + def reset(self, *args: Any, **kwargs: Any) -> dict: + self.env.reset(*args, **kwargs) observation = self.env.observe(self.env.agent_selection) if isinstance(observation, dict) and 'action_mask' in observation: return { @@ -103,7 +103,10 @@ def close(self) -> None: self.env.close() def seed(self, seed: Any = None) -> None: - self.env.seed(seed) + try: + self.env.seed(seed) + except NotImplementedError: + self.env.reset(seed=seed) def render(self, mode: str = "human") -> Any: return self.env.render(mode) diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index 88d1cd95e..fdf27d462 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -5,7 +5,7 @@ import torch from torch.distributions import Independent, Normal -from tianshou.data import Batch, ReplayBuffer, to_torch_as +from tianshou.data import Batch, ReplayBuffer from tianshou.exploration import BaseNoise from tianshou.policy import DDPGPolicy @@ -68,7 +68,7 @@ def __init__( actor_delay: int = 20, 
exploration_noise: Optional[BaseNoise] = None, deterministic_eval: bool = True, - target_mode: str = 'min', + target_mode: str = "min", **kwargs: Any, ) -> None: super().__init__( @@ -80,7 +80,7 @@ def __init__( self.critics_old.eval() self.critics_optim = critics_optim assert 0 < subset_size <= ensemble_size, \ - 'Invalid choice of ensemble size or subset size.' + "Invalid choice of ensemble size or subset size." self.ensemble_size = ensemble_size self.subset_size = subset_size @@ -94,10 +94,10 @@ def __init__( else: self._alpha = alpha - if target_mode in ('min', 'mean'): + if target_mode in ("min", "mean"): self.target_mode = target_mode else: - raise ValueError('Unsupported mode of Q target computing.') + raise ValueError("Unsupported mode of Q target computing.") self.critic_gradient_step = 0 self.actor_delay = actor_delay @@ -133,31 +133,24 @@ def forward( # type: ignore # apply correction for Tanh squashing when computing logprob from Gaussian # You can check out the original SAC paper (arXiv 1801.01290): Eq 21. # in appendix C to get some understanding of this equation. - if self.action_scaling and self.action_space is not None: - action_scale = to_torch_as( - (self.action_space.high - self.action_space.low) / 2.0, act - ) - else: - action_scale = 1.0 # type: ignore squashed_action = torch.tanh(act) - log_prob = log_prob - torch.log( - action_scale * (1 - squashed_action.pow(2)) + self.__eps - ).sum(-1, keepdim=True) + log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + + self.__eps).sum(-1, keepdim=True) return Batch( logits=logits, act=squashed_action, state=h, dist=dist, log_prob=log_prob ) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: batch = buffer[indices] # batch.obs: s_{t+n} - obs_next_result = self(batch, input='obs_next') + obs_next_result = self(batch, input="obs_next") a_ = obs_next_result.act sample_ensemble_idx = np.random.choice( self.ensemble_size, self.subset_size, replace=False ) qs = self.critics_old(batch.obs_next, a_)[sample_ensemble_idx, ...] 
- if self.target_mode == 'min': + if self.target_mode == "min": target_q, _ = torch.min(qs, dim=0) - elif self.target_mode == 'mean': + elif self.target_mode == "mean": target_q = torch.mean(qs, dim=0) target_q -= self._alpha * obs_next_result.log_prob diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 477bb14f0..64283b44f 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -1,4 +1,14 @@ -from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import ( + Any, + Dict, + List, + Optional, + Sequence, + Tuple, + Type, + Union, + no_type_check, +) import numpy as np import torch @@ -90,16 +100,13 @@ def __init__( self.model = nn.Sequential(*model) self.flatten_input = flatten_input + @no_type_check def forward(self, obs: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: if self.device is not None: - obs = torch.as_tensor( - obs, - device=self.device, # type: ignore - dtype=torch.float32, - ) + obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32) if self.flatten_input: obs = obs.flatten(1) - return self.model(obs) # type: ignore + return self.model(obs) class Net(nn.Module): @@ -158,7 +165,7 @@ def __init__( concat: bool = False, num_atoms: int = 1, dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None, - linear_layer: ModuleType = nn.Linear, + linear_layer: Type[nn.Linear] = nn.Linear, ) -> None: super().__init__() self.device = device diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 360f7baeb..bf083c32a 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -98,7 +98,7 @@ def __init__( hidden_sizes: Sequence[int] = (), device: Union[str, int, torch.device] = "cpu", preprocess_net_output_dim: Optional[int] = None, - linear_layer: Type[nn.Module] = nn.Linear, + linear_layer: Type[nn.Linear] = nn.Linear, flatten_input: bool = True, ) -> None: super().__init__() From d25da3b28507b613a46aa403116430bba411b4a9 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Sat, 30 Apr 2022 11:36:45 -0400 Subject: [PATCH 3/3] fix ci --- docs/spelling_wordlist.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 31e53fce0..471f262b5 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -157,3 +157,4 @@ Nvidia Enduro Qbert Seaquest +subnets
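A minimal usage sketch of the API added in this series, mirroring test/continuous/test_redq.py: an ActorProb policy head plus a Critic whose linear layers are EnsembleLinear, wired into REDQPolicy. The environment name, hidden sizes, and learning rates below are illustrative only, not tuned recommendations.

import gym
import torch

from tianshou.policy import REDQPolicy
from tianshou.utils.net.common import EnsembleLinear, Net
from tianshou.utils.net.continuous import ActorProb, Critic

env = gym.make("Pendulum-v1")  # illustrative task
state_shape = env.observation_space.shape
action_shape = env.action_space.shape
max_action = env.action_space.high[0]
ensemble_size, subset_size = 10, 2  # N and M in the REDQ paper

# stochastic actor: MLP -> (mu, sigma) for a tanh-squashed Gaussian policy
net_a = Net(state_shape, hidden_sizes=[256, 256])
actor = ActorProb(
    net_a,
    action_shape,
    max_action=max_action,
    unbounded=True,
    conditioned_sigma=True,
)
actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-3)


def ensemble_linear(in_dim: int, out_dim: int) -> EnsembleLinear:
    # every linear layer of the critic is ensembled, so one forward pass
    # evaluates all `ensemble_size` Q-networks at once
    return EnsembleLinear(ensemble_size, in_dim, out_dim)


net_c = Net(
    state_shape,
    action_shape,
    hidden_sizes=[256, 256],
    concat=True,
    linear_layer=ensemble_linear,
)
critics = Critic(net_c, linear_layer=ensemble_linear, flatten_input=False)
critics_optim = torch.optim.Adam(critics.parameters(), lr=1e-3)

policy = REDQPolicy(
    actor,
    actor_optim,
    critics,
    critics_optim,
    ensemble_size,
    subset_size,
    tau=0.005,
    gamma=0.99,
    alpha=0.2,  # or pass (target_entropy, log_alpha, alpha_optim) for auto-tuning
    estimation_step=1,
    actor_delay=20,  # one actor update per 20 critic updates
    target_mode="min",
    action_space=env.action_space,
)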