diff --git a/README.md b/README.md index 8f442e212..295ca96db 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ - [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf) - [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf) - [Implicit Quantile Network (IQN)](https://arxiv.org/pdf/1806.06923.pdf) +- [Fully-parameterized Quantile Function (FQF)](https://arxiv.org/pdf/1911.02140.pdf) - [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf) - [Natural Policy Gradient (NPG)](https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf) - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/) diff --git a/docs/api/tianshou.policy.rst b/docs/api/tianshou.policy.rst index d8579435c..39f478ca5 100644 --- a/docs/api/tianshou.policy.rst +++ b/docs/api/tianshou.policy.rst @@ -40,6 +40,11 @@ DQN Family :undoc-members: :show-inheritance: +.. autoclass:: tianshou.policy.FQFPolicy + :members: + :undoc-members: + :show-inheritance: + On-policy ~~~~~~~~~ diff --git a/docs/index.rst b/docs/index.rst index 87189fd5f..4afe03aa8 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -15,6 +15,7 @@ Welcome to Tianshou! * :class:`~tianshou.policy.C51Policy` `Categorical DQN `_ * :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN `_ * :class:`~tianshou.policy.IQNPolicy` `Implicit Quantile Network `_ +* :class:`~tianshou.policy.FQFPolicy` `Fully-parameterized Quantile Function `_ * :class:`~tianshou.policy.PGPolicy` `Policy Gradient `_ * :class:`~tianshou.policy.NPGPolicy` `Natural Policy Gradient `_ * :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic `_ diff --git a/examples/atari/README.md b/examples/atari/README.md index 971231b65..a46fcac5c 100644 --- a/examples/atari/README.md +++ b/examples/atari/README.md @@ -68,6 +68,20 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. | SeaquestNoFrameskip-v4 | 4874 | ![](results/iqn/Seaquest_rew.png) | `python3 atari_iqn.py --task "SeaquestNoFrameskip-v4"` | | SpaceInvadersNoFrameskip-v4 | 1498.5 | ![](results/iqn/SpaceInvaders_rew.png) | `python3 atari_iqn.py --task "SpaceInvadersNoFrameskip-v4"` | +# FQF (single run) + +One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. 
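+
+FQF-specific flags in `atari_fqf.py` are `--num-fractions`, `--fraction-lr` (the RMSprop learning rate of the fraction proposal network) and `--ent-coef` (the entropy regularization coefficient applied to the proposed fractions); see the command column below for per-task overrides.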
+ +| task | best reward | reward curve | parameters | +| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 20.7 | ![](results/fqf/Pong_rew.png) | `python3 atari_fqf.py --task "PongNoFrameskip-v4" --batch-size 64` | +| BreakoutNoFrameskip-v4 | 517.3 | ![](results/fqf/Breakout_rew.png) | `python3 atari_fqf.py --task "BreakoutNoFrameskip-v4" --n-step 1` | +| EnduroNoFrameskip-v4 | 2240.5 | ![](results/fqf/Enduro_rew.png) | `python3 atari_fqf.py --task "EnduroNoFrameskip-v4"` | +| QbertNoFrameskip-v4 | 16172.5 | ![](results/fqf/Qbert_rew.png) | `python3 atari_fqf.py --task "QbertNoFrameskip-v4"` | +| MsPacmanNoFrameskip-v4 | 2429 | ![](results/fqf/MsPacman_rew.png) | `python3 atari_fqf.py --task "MsPacmanNoFrameskip-v4"` | +| SeaquestNoFrameskip-v4 | 10775 | ![](results/fqf/Seaquest_rew.png) | `python3 atari_fqf.py --task "SeaquestNoFrameskip-v4"` | +| SpaceInvadersNoFrameskip-v4 | 2482 | ![](results/fqf/SpaceInvaders_rew.png) | `python3 atari_fqf.py --task "SpaceInvadersNoFrameskip-v4"` | + # BCQ To running BCQ algorithm on Atari, you need to do the following things: diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py new file mode 100644 index 000000000..110a57167 --- /dev/null +++ b/examples/atari/atari_fqf.py @@ -0,0 +1,186 @@ +import os +import torch +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.policy import FQFPolicy +from tianshou.utils import BasicLogger +from tianshou.env import SubprocVectorEnv +from tianshou.trainer import offpolicy_trainer +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction + +from atari_network import DQN +from atari_wrapper import wrap_deepmind + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') + parser.add_argument('--seed', type=int, default=3128) + parser.add_argument('--eps-test', type=float, default=0.005) + parser.add_argument('--eps-train', type=float, default=1.) + parser.add_argument('--eps-train-final', type=float, default=0.05) + parser.add_argument('--buffer-size', type=int, default=100000) + parser.add_argument('--lr', type=float, default=5e-5) + parser.add_argument('--fraction-lr', type=float, default=2.5e-9) + parser.add_argument('--gamma', type=float, default=0.99) + parser.add_argument('--num-fractions', type=int, default=32) + parser.add_argument('--num-cosines', type=int, default=64) + parser.add_argument('--ent-coef', type=float, default=10.) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512]) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--target-update-freq', type=int, default=500) + parser.add_argument('--epoch', type=int, default=100) + parser.add_argument('--step-per-epoch', type=int, default=100000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) + parser.add_argument('--batch-size', type=int, default=32) + parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--test-num', type=int, default=10) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) 
+ parser.add_argument( + '--device', type=str, + default='cuda' if torch.cuda.is_available() else 'cpu') + parser.add_argument('--frames-stack', type=int, default=4) + parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument('--watch', default=False, action='store_true', + help='watch the play of pre-trained policy only') + parser.add_argument('--save-buffer-name', type=str, default=None) + return parser.parse_args() + + +def make_atari_env(args): + return wrap_deepmind(args.task, frame_stack=args.frames_stack) + + +def make_atari_env_watch(args): + return wrap_deepmind(args.task, frame_stack=args.frames_stack, + episode_life=False, clip_rewards=False) + + +def test_fqf(args=get_args()): + env = make_atari_env(args) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + # make environments + train_envs = SubprocVectorEnv([lambda: make_atari_env(args) + for _ in range(args.training_num)]) + test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) + for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # define model + feature_net = DQN(*args.state_shape, args.action_shape, args.device, + features_only=True) + net = FullQuantileFunction( + feature_net, args.action_shape, args.hidden_sizes, + args.num_cosines, device=args.device + ).to(args.device) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) + fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), + lr=args.fraction_lr) + # define policy + policy = FQFPolicy( + net, optim, fraction_net, fraction_optim, + args.gamma, args.num_fractions, args.ent_coef, args.n_step, + target_update_freq=args.target_update_freq + ).to(args.device) + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together + # when you have enough RAM + buffer = VectorReplayBuffer( + args.buffer_size, buffer_num=len(train_envs), + ignore_obs_next=True, save_only_last_obs=True, stack_num=args.frames_stack) + # collector + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + # log + log_path = os.path.join(args.logdir, args.task, 'fqf') + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + if env.spec.reward_threshold: + return mean_rewards >= env.spec.reward_threshold + elif 'Pong' in args.task: + return mean_rewards >= 20 + else: + return False + + def train_fn(epoch, env_step): + # nature DQN setting, linear decay in the first 1M steps + if env_step <= 1e6: + eps = args.eps_train - env_step / 1e6 * \ + (args.eps_train - args.eps_train_final) + else: + eps = args.eps_train_final + policy.set_eps(eps) + logger.write('train/eps', env_step, eps) + + def test_fn(epoch, env_step): + policy.set_eps(args.eps_test) + + # watch agent's performance + def watch(): + print("Setup 
test envs ...") + policy.eval() + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + if args.save_buffer_name: + print(f"Generate buffer with size {args.buffer_size}") + buffer = VectorReplayBuffer( + args.buffer_size, buffer_num=len(test_envs), + ignore_obs_next=True, save_only_last_obs=True, + stack_num=args.frames_stack) + collector = Collector(policy, test_envs, buffer, + exploration_noise=True) + result = collector.collect(n_step=args.buffer_size) + print(f"Save buffer into {args.save_buffer_name}") + # Unfortunately, pickle will cause oom with 1M buffer size + buffer.save_hdf5(args.save_buffer_name) + else: + print("Testing agent ...") + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, + render=args.render) + rew = result["rews"].mean() + print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + + if args.watch: + watch() + exit(0) + + # test train_collector and start filling replay buffer + train_collector.collect(n_step=args.batch_size * args.training_num) + # trainer + result = offpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.step_per_collect, args.test_num, + args.batch_size, train_fn=train_fn, test_fn=test_fn, + stop_fn=stop_fn, save_fn=save_fn, logger=logger, + update_per_step=args.update_per_step, test_in_train=False) + + pprint.pprint(result) + watch() + + +if __name__ == '__main__': + test_fqf(get_args()) diff --git a/examples/atari/results/fqf/Breakout_rew.png b/examples/atari/results/fqf/Breakout_rew.png new file mode 100644 index 000000000..409a2adf5 Binary files /dev/null and b/examples/atari/results/fqf/Breakout_rew.png differ diff --git a/examples/atari/results/fqf/Enduro_rew.png b/examples/atari/results/fqf/Enduro_rew.png new file mode 100644 index 000000000..3d5125983 Binary files /dev/null and b/examples/atari/results/fqf/Enduro_rew.png differ diff --git a/examples/atari/results/fqf/MsPacman_rew.png b/examples/atari/results/fqf/MsPacman_rew.png new file mode 100644 index 000000000..6832017fb Binary files /dev/null and b/examples/atari/results/fqf/MsPacman_rew.png differ diff --git a/examples/atari/results/fqf/Pong_rew.png b/examples/atari/results/fqf/Pong_rew.png new file mode 100644 index 000000000..e3a4e44fd Binary files /dev/null and b/examples/atari/results/fqf/Pong_rew.png differ diff --git a/examples/atari/results/fqf/Qbert_rew.png b/examples/atari/results/fqf/Qbert_rew.png new file mode 100644 index 000000000..c0a2fec58 Binary files /dev/null and b/examples/atari/results/fqf/Qbert_rew.png differ diff --git a/examples/atari/results/fqf/Seaquest_rew.png b/examples/atari/results/fqf/Seaquest_rew.png new file mode 100644 index 000000000..2accab38e Binary files /dev/null and b/examples/atari/results/fqf/Seaquest_rew.png differ diff --git a/examples/atari/results/fqf/SpaceInvaders_rew.png b/examples/atari/results/fqf/SpaceInvaders_rew.png new file mode 100644 index 000000000..fe271ecc1 Binary files /dev/null and b/examples/atari/results/fqf/SpaceInvaders_rew.png differ diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py new file mode 100644 index 000000000..534927f12 --- /dev/null +++ b/test/discrete/test_fqf.py @@ -0,0 +1,153 @@ +import os +import gym +import torch +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.policy import FQFPolicy +from tianshou.utils import BasicLogger +from tianshou.env import DummyVectorEnv +from tianshou.utils.net.common import Net +from tianshou.trainer 
import offpolicy_trainer +from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction +from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--eps-test', type=float, default=0.05) + parser.add_argument('--eps-train', type=float, default=0.1) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--lr', type=float, default=3e-3) + parser.add_argument('--fraction-lr', type=float, default=2.5e-9) + parser.add_argument('--gamma', type=float, default=0.9) + parser.add_argument('--num-fractions', type=int, default=32) + parser.add_argument('--num-cosines', type=int, default=64) + parser.add_argument('--ent-coef', type=float, default=10.) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--target-update-freq', type=int, default=320) + parser.add_argument('--epoch', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--hidden-sizes', type=int, + nargs='*', default=[64, 64, 64]) + parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) + parser.add_argument('--prioritized-replay', + action="store_true", default=False) + parser.add_argument('--alpha', type=float, default=0.6) + parser.add_argument('--beta', type=float, default=0.4) + parser.add_argument( + '--device', type=str, + default='cuda' if torch.cuda.is_available() else 'cpu') + args = parser.parse_known_args()[0] + return args + + +def test_fqf(args=get_args()): + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # train_envs = gym.make(args.task) + # you can also use tianshou.env.SubprocVectorEnv + train_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)]) + # test_envs = gym.make(args.task) + test_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + feature_net = Net(args.state_shape, args.hidden_sizes[-1], + hidden_sizes=args.hidden_sizes[:-1], device=args.device, + softmax=False) + net = FullQuantileFunction( + feature_net, args.action_shape, args.hidden_sizes, + num_cosines=args.num_cosines, device=args.device + ) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) + fraction_optim = torch.optim.RMSprop( + fraction_net.parameters(), lr=args.fraction_lr + ) + policy = FQFPolicy( + net, optim, fraction_net, fraction_optim, args.gamma, args.num_fractions, + args.ent_coef, args.n_step, target_update_freq=args.target_update_freq + ).to(args.device) + # buffer + if args.prioritized_replay: + buf = PrioritizedVectorReplayBuffer( + args.buffer_size, buffer_num=len(train_envs), + alpha=args.alpha, 
beta=args.beta) + else: + buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) + # collector + train_collector = Collector(policy, train_envs, buf, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + # policy.set_eps(1) + train_collector.collect(n_step=args.batch_size * args.training_num) + # log + log_path = os.path.join(args.logdir, args.task, 'fqf') + writer = SummaryWriter(log_path) + logger = BasicLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + def train_fn(epoch, env_step): + # eps annnealing, just a demo + if env_step <= 10000: + policy.set_eps(args.eps_train) + elif env_step <= 50000: + eps = args.eps_train - (env_step - 10000) / \ + 40000 * (0.9 * args.eps_train) + policy.set_eps(eps) + else: + policy.set_eps(0.1 * args.eps_train) + + def test_fn(epoch, env_step): + policy.set_eps(args.eps_test) + + # trainer + result = offpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.step_per_collect, args.test_num, + args.batch_size, train_fn=train_fn, test_fn=test_fn, + stop_fn=stop_fn, save_fn=save_fn, logger=logger, + update_per_step=args.update_per_step) + assert stop_fn(result['best_reward']) + + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! + env = gym.make(args.task) + policy.eval() + policy.set_eps(args.eps_test) + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + + +def test_pfqf(args=get_args()): + args.prioritized_replay = True + args.gamma = .95 + test_fqf(args) + + +if __name__ == '__main__': + test_fqf(get_args()) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index a35089131..9dd879cb1 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -4,6 +4,7 @@ from tianshou.policy.modelfree.c51 import C51Policy from tianshou.policy.modelfree.qrdqn import QRDQNPolicy from tianshou.policy.modelfree.iqn import IQNPolicy +from tianshou.policy.modelfree.fqf import FQFPolicy from tianshou.policy.modelfree.pg import PGPolicy from tianshou.policy.modelfree.a2c import A2CPolicy from tianshou.policy.modelfree.npg import NPGPolicy @@ -28,6 +29,7 @@ "C51Policy", "QRDQNPolicy", "IQNPolicy", + "FQFPolicy", "PGPolicy", "A2CPolicy", "NPGPolicy", diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py new file mode 100644 index 000000000..db110cc79 --- /dev/null +++ b/tianshou/policy/modelfree/fqf.py @@ -0,0 +1,161 @@ +import torch +import numpy as np +import torch.nn.functional as F +from typing import Any, Dict, Optional, Union + +from tianshou.policy import DQNPolicy, QRDQNPolicy +from tianshou.data import Batch, to_numpy, ReplayBuffer +from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction + + +class FQFPolicy(QRDQNPolicy): + """Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140. + + :param torch.nn.Module model: a model following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. + :param FractionProposalNetwork fraction_model: a FractionProposalNetwork for + proposing fractions/quantiles given state. 
+ :param torch.optim.Optimizer fraction_optim: a torch.optim for optimizing + the fraction model above. + :param float discount_factor: in [0, 1]. + :param int num_fractions: the number of fractions to use. Default to 32. + :param float ent_coef: the coefficient for entropy loss. Default to 0. + :param int estimation_step: the number of steps to look ahead. Default to 1. + :param int target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param bool reward_normalization: normalize the reward to Normal(0, 1). + Default to False. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed + explanation. + """ + + def __init__( + self, + model: FullQuantileFunction, + optim: torch.optim.Optimizer, + fraction_model: FractionProposalNetwork, + fraction_optim: torch.optim.Optimizer, + discount_factor: float = 0.99, + num_fractions: int = 32, + ent_coef: float = 0.0, + estimation_step: int = 1, + target_update_freq: int = 0, + reward_normalization: bool = False, + **kwargs: Any, + ) -> None: + super().__init__( + model, optim, discount_factor, num_fractions, estimation_step, + target_update_freq, reward_normalization, **kwargs + ) + self.propose_model = fraction_model + self._ent_coef = ent_coef + self._fraction_optim = fraction_optim + + def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor: + batch = buffer[indice] # batch.obs_next: s_{t+n} + if self._target: + result = self(batch, input="obs_next") + a, fractions = result.act, result.fractions + next_dist = self( + batch, model="model_old", input="obs_next", fractions=fractions + ).logits + else: + next_b = self(batch, input="obs_next") + a = next_b.act + next_dist = next_b.logits + next_dist = next_dist[np.arange(len(a)), a, :] + return next_dist # shape: [bsz, num_quantiles] + + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + model: str = "model", + input: str = "obs", + fractions: Optional[Batch] = None, + **kwargs: Any, + ) -> Batch: + model = getattr(self, model) + obs = batch[input] + obs_ = obs.obs if hasattr(obs, "obs") else obs + if fractions is None: + (logits, fractions, quantiles_tau), h = model( + obs_, propose_model=self.propose_model, state=state, info=batch.info + ) + else: + (logits, _, quantiles_tau), h = model( + obs_, propose_model=self.propose_model, fractions=fractions, + state=state, info=batch.info + ) + weighted_logits = ( + fractions.taus[:, 1:] - fractions.taus[:, :-1] + ).unsqueeze(1) * logits + q = DQNPolicy.compute_q_value( + self, weighted_logits.sum(2), getattr(obs, "mask", None) + ) + if not hasattr(self, "max_action_num"): + self.max_action_num = q.shape[1] + act = to_numpy(q.max(dim=1)[1]) + return Batch( + logits=logits, act=act, state=h, fractions=fractions, + quantiles_tau=quantiles_tau + ) + + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: + if self._target and self._iter % self._freq == 0: + self.sync_weight() + weight = batch.pop("weight", 1.0) + out = self(batch) + curr_dist_orig = out.logits + taus, tau_hats = out.fractions.taus, out.fractions.tau_hats + act = batch.act + curr_dist = curr_dist_orig[np.arange(len(act)), act, :].unsqueeze(2) + target_dist = batch.returns.unsqueeze(1) + # calculate each element's difference between curr_dist and target_dist + u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") + huber_loss = (u * ( + tau_hats.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.).float() + 
).abs()).sum(-1).mean(1) + quantile_loss = (huber_loss * weight).mean() + # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ + # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 + batch.weight = u.detach().abs().sum(-1).mean(1) # prio-buffer + # calculate fraction loss + with torch.no_grad(): + sa_quantile_hats = curr_dist_orig[np.arange(len(act)), act, :] + sa_quantiles = out.quantiles_tau[np.arange(len(act)), act, :] + # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ + # blob/master/fqf_iqn_qrdqn/agent/fqf_agent.py L169 + values_1 = sa_quantiles - sa_quantile_hats[:, :-1] + signs_1 = sa_quantiles > torch.cat([ + sa_quantile_hats[:, :1], sa_quantiles[:, :-1]], dim=1) + + values_2 = sa_quantiles - sa_quantile_hats[:, 1:] + signs_2 = sa_quantiles < torch.cat([ + sa_quantiles[:, 1:], sa_quantile_hats[:, -1:]], dim=1) + + gradient_of_taus = ( + torch.where(signs_1, values_1, -values_1) + + torch.where(signs_2, values_2, -values_2) + ) + fraction_loss = (gradient_of_taus * taus[:, 1:-1]).sum(1).mean() + # calculate entropy loss + entropy_loss = out.fractions.entropies.mean() + fraction_entropy_loss = fraction_loss - self._ent_coef * entropy_loss + self._fraction_optim.zero_grad() + fraction_entropy_loss.backward(retain_graph=True) + self._fraction_optim.step() + self.optim.zero_grad() + quantile_loss.backward() + self.optim.step() + self._iter += 1 + return { + "loss": quantile_loss.item() + fraction_entropy_loss.item(), + "loss/quantile": quantile_loss.item(), + "loss/fraction": fraction_loss.item(), + "loss/entropy": entropy_loss.item() + } diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 14cc85e50..b104b4ee8 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -4,6 +4,7 @@ import torch.nn.functional as F from typing import Any, Dict, Tuple, Union, Optional, Sequence +from tianshou.data import Batch from tianshou.utils.net.common import MLP @@ -199,6 +200,110 @@ def forward( # type: ignore embedding = (logits.unsqueeze(1) * self.embed_model(taus)).view( batch_size * sample_size, -1 ) - out = self.last(embedding).view(batch_size, - sample_size, -1).transpose(1, 2) + out = self.last(embedding).view( + batch_size, sample_size, -1).transpose(1, 2) return (out, taus), h + + +class FractionProposalNetwork(nn.Module): + """Fraction proposal network for FQF. + + :param num_fractions: the number of factions to propose. + :param embedding_dim: the dimension of the embedding/input. + + .. note:: + + Adapted from https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master + /fqf_iqn_qrdqn/network.py . + """ + + def __init__(self, num_fractions: int, embedding_dim: int) -> None: + super().__init__() + self.net = nn.Linear(embedding_dim, num_fractions) + torch.nn.init.xavier_uniform_(self.net.weight, gain=0.01) + torch.nn.init.constant_(self.net.bias, 0) + self.num_fractions = num_fractions + self.embedding_dim = embedding_dim + + def forward( + self, state_embeddings: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Calculate (log of) probabilities q_i in the paper. + m = torch.distributions.Categorical(logits=self.net(state_embeddings)) + taus_1_N = torch.cumsum(m.probs, dim=1) + # Calculate \tau_i (i=0,...,N). + taus = F.pad(taus_1_N, (1, 0)) + # Calculate \hat \tau_i (i=0,...,N-1). + tau_hats = (taus[:, :-1] + taus[:, 1:]).detach() / 2.0 + # Calculate entropies of value distributions. 
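+        # (FQFPolicy.learn consumes these entropies: it subtracts
+        #  ent_coef * entropies.mean() from the fraction loss as an
+        #  entropy regularizer on the proposal distribution.)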
+ entropies = m.entropy() + return taus, tau_hats, entropies + + +class FullQuantileFunction(ImplicitQuantileNetwork): + """Full(y parameterized) Quantile Function. + + :param preprocess_net: a self-defined preprocess_net which output a + flattened hidden state. + :param int action_dim: the dimension of action space. + :param hidden_sizes: a sequence of int for constructing the MLP after + preprocess_net. Default to empty sequence (where the MLP now contains + only a single linear layer). + :param int num_cosines: the number of cosines to use for cosine embedding. + Default to 64. + :param int preprocess_net_output_dim: the output dimension of + preprocess_net. + + .. note:: + + The first return value is a tuple of (quantiles, fractions, quantiles_tau), + where fractions is a Batch(taus, tau_hats, entropies). + """ + + def __init__( + self, + preprocess_net: nn.Module, + action_shape: Sequence[int], + hidden_sizes: Sequence[int] = (), + num_cosines: int = 64, + preprocess_net_output_dim: Optional[int] = None, + device: Union[str, int, torch.device] = "cpu", + ) -> None: + super().__init__( + preprocess_net, action_shape, hidden_sizes, + num_cosines, preprocess_net_output_dim, device + ) + + def _compute_quantiles( + self, obs: torch.Tensor, taus: torch.Tensor + ) -> torch.Tensor: + batch_size, sample_size = taus.shape + embedding = (obs.unsqueeze(1) * self.embed_model(taus)).view( + batch_size * sample_size, -1 + ) + quantiles = self.last(embedding).view( + batch_size, sample_size, -1 + ).transpose(1, 2) + return quantiles + + def forward( # type: ignore + self, s: Union[np.ndarray, torch.Tensor], + propose_model: FractionProposalNetwork, + fractions: Optional[Batch] = None, + **kwargs: Any + ) -> Tuple[Any, torch.Tensor]: + r"""Mapping: s -> Q(s, \*).""" + logits, h = self.preprocess(s, state=kwargs.get("state", None)) + # Propose fractions + if fractions is None: + taus, tau_hats, entropies = propose_model(logits.detach()) + fractions = Batch(taus=taus, tau_hats=tau_hats, entropies=entropies) + else: + taus, tau_hats = fractions.taus, fractions.tau_hats + quantiles = self._compute_quantiles(logits, tau_hats) + # Calculate quantiles_tau for computing fraction grad + quantiles_tau = None + if self.training: + with torch.no_grad(): + quantiles_tau = self._compute_quantiles(logits, taus[:, 1:-1]) + return (quantiles, fractions, quantiles_tau), h
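
A short note on the fraction loss, to make the `learn` step easier to review: `gradient_of_taus` in `FQFPolicy.learn` implements, up to the sign handling in `signs_1`/`signs_2` for non-monotonic quantile estimates, the derivative of the 1-Wasserstein error with respect to the proposed fractions as derived in the FQF paper (arXiv:1911.02140):

\[
\frac{\partial W_1}{\partial \tau_i}
= 2F_Z^{-1}(\tau_i) - F_Z^{-1}(\hat\tau_i) - F_Z^{-1}(\hat\tau_{i-1}),
\qquad i = 1, \dots, N-1,
\]

where \(F_Z^{-1}\) is evaluated by the detached network outputs `sa_quantiles` (at \(\tau_i\)) and `sa_quantile_hats` (at \(\hat\tau\)). The fraction proposal network is therefore trained through `fraction_loss = (gradient_of_taus * taus[:, 1:-1]).sum(1).mean()` without ever computing \(W_1\) itself, while `ent_coef` weights an entropy bonus on the proposed fractions.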