diff --git a/README.md b/README.md index 2848f9c12..7226d7c52 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ - [Double DQN](https://arxiv.org/pdf/1509.06461.pdf) - [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf) - [C51](https://arxiv.org/pdf/1707.06887.pdf) +- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf) - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/) - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf) - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf) diff --git a/docs/index.rst b/docs/index.rst index 72704f49c..a3acbe64c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -14,6 +14,7 @@ Welcome to Tianshou! * :class:`~tianshou.policy.DQNPolicy` `Double DQN `_ * :class:`~tianshou.policy.DQNPolicy` `Dueling DQN `_ * :class:`~tianshou.policy.C51Policy` `C51 `_ +* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN `_ * :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic `_ * :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient `_ * :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization `_ diff --git a/examples/atari/README.md b/examples/atari/README.md index 933415e58..281f72ea1 100644 --- a/examples/atari/README.md +++ b/examples/atari/README.md @@ -40,6 +40,20 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. Note: The selection of `n_step` is based on Figure 6 in the [Rainbow](https://arxiv.org/abs/1710.02298) paper. +# QRDQN (single run) + +One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. + +| task | best reward | reward curve | parameters | +| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 20 | ![](results/qrdqn/Pong_rew.png) | `python3 atari_qrdqn.py --task "PongNoFrameskip-v4" --batch-size 64` | +| BreakoutNoFrameskip-v4 | 409.2 | ![](results/qrdqn/Breakout_rew.png) | `python3 atari_qrdqn.py --task "BreakoutNoFrameskip-v4" --n-step 1` | +| EnduroNoFrameskip-v4 | 1055.9 | ![](results/qrdqn/Enduro_rew.png) | `python3 atari_qrdqn.py --task "EnduroNoFrameskip-v4"` | +| QbertNoFrameskip-v4 | 14990 | ![](results/qrdqn/Qbert_rew.png) | `python3 atari_qrdqn.py --task "QbertNoFrameskip-v4"` | +| MsPacmanNoFrameskip-v4 | 2886 | ![](results/qrdqn/MsPacman_rew.png) | `python3 atari_qrdqn.py --task "MsPacmanNoFrameskip-v4"` | +| SeaquestNoFrameskip-v4 | 5676 | ![](results/qrdqn/Seaquest_rew.png) | `python3 atari_qrdqn.py --task "SeaquestNoFrameskip-v4"` | +| SpaceInvadersNoFrameskip-v4 | 938 | ![](results/qrdqn/SpaceInvader_rew.png) | `python3 atari_qrdqn.py --task "SpaceInvadersNoFrameskip-v4"` | + # BCQ TODO: after the `done` issue fixed, the result should be re-tuned and place here. @@ -49,4 +63,3 @@ To running BCQ algorithm on Atari, you need to do the following things: - Train an expert, by using the command listed in the above DQN section; - Generate buffer with noise: `python3 atari_dqn.py --task {your_task} --watch --resume-path log/{your_task}/dqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that 1M Atari buffer cannot be saved as `.pkl` format because it is too large and will cause error); - Train BCQ: `python3 atari_bcq.py --task {your_task} --load-buffer-name expert.hdf5`. 
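For orientation before the code changes: the pieces added below (the `QRDQN` network head in `atari_network.py`, the new `QRDQNPolicy`, and the `atari_qrdqn.py` example script) fit together as in the following minimal sketch. It only uses constructor signatures that appear in this diff and assumes it is run from `examples/atari/` so that `atari_network` is importable; the shapes and hyperparameters are illustrative, not tuned.

```python
import torch

from atari_network import QRDQN          # convolutional Q-network head added in this PR
from tianshou.policy import QRDQNPolicy  # policy added in this PR

# Illustrative shapes: 4 stacked 84x84 frames and 6 discrete actions.
c, h, w, action_shape = 4, 84, 84, (6,)
num_quantiles = 200

net = QRDQN(c, h, w, action_shape, num_quantiles, device="cpu")
optim = torch.optim.Adam(net.parameters(), lr=1e-4)
policy = QRDQNPolicy(
    net, optim, 0.99, num_quantiles,
    estimation_step=3,              # n-step return, as in atari_qrdqn.py
    target_update_freq=500,
)

# The network maps observations to a [batch, n_actions, n_quantiles] tensor;
# the policy turns that into Q-values by averaging over the quantile axis.
obs = torch.zeros((1, c, h, w))
quantiles, _ = net(obs)
print(quantiles.shape)              # expected: torch.Size([1, 6, 200])
```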
- diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index c31a6c8cf..f531d96e5 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -64,8 +64,8 @@ def __init__( num_atoms: int = 51, device: Union[str, int, torch.device] = "cpu", ) -> None: - super().__init__(c, h, w, [np.prod(action_shape) * num_atoms], device) - self.action_shape = action_shape + self.action_num = np.prod(action_shape) + super().__init__(c, h, w, [self.action_num * num_atoms], device) self.num_atoms = num_atoms def forward( @@ -77,5 +77,38 @@ def forward( r"""Mapping: x -> Z(x, \*).""" x, state = super().forward(x) x = x.view(-1, self.num_atoms).softmax(dim=-1) - x = x.view(-1, np.prod(self.action_shape), self.num_atoms) + x = x.view(-1, self.action_num, self.num_atoms) + return x, state + + +class QRDQN(DQN): + """Reference: Distributional Reinforcement Learning with Quantile \ + Regression. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + + def __init__( + self, + c: int, + h: int, + w: int, + action_shape: Sequence[int], + num_quantiles: int = 200, + device: Union[str, int, torch.device] = "cpu", + ) -> None: + self.action_num = np.prod(action_shape) + super().__init__(c, h, w, [self.action_num * num_quantiles], device) + self.num_quantiles = num_quantiles + + def forward( + self, + x: Union[np.ndarray, torch.Tensor], + state: Optional[Any] = None, + info: Dict[str, Any] = {}, + ) -> Tuple[torch.Tensor, Any]: + r"""Mapping: x -> Z(x, \*).""" + x, state = super().forward(x) + x = x.view(-1, self.action_num, self.num_quantiles) return x, state diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py new file mode 100644 index 000000000..08c34733c --- /dev/null +++ b/examples/atari/atari_qrdqn.py @@ -0,0 +1,153 @@ +import os +import torch +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.policy import QRDQNPolicy +from tianshou.env import SubprocVectorEnv +from tianshou.trainer import offpolicy_trainer +from tianshou.data import Collector, ReplayBuffer + +from atari_network import QRDQN +from atari_wrapper import wrap_deepmind + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--eps-test', type=float, default=0.005) + parser.add_argument('--eps-train', type=float, default=1.) + parser.add_argument('--eps-train-final', type=float, default=0.05) + parser.add_argument('--buffer-size', type=int, default=100000) + parser.add_argument('--lr', type=float, default=0.0001) + parser.add_argument('--gamma', type=float, default=0.99) + parser.add_argument('--num-quantiles', type=int, default=200) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--target-update-freq', type=int, default=500) + parser.add_argument('--epoch', type=int, default=100) + parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--batch-size', type=int, default=32) + parser.add_argument('--training-num', type=int, default=16) + parser.add_argument('--test-num', type=int, default=10) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) 
+ parser.add_argument( + '--device', type=str, + default='cuda' if torch.cuda.is_available() else 'cpu') + parser.add_argument('--frames-stack', type=int, default=4) + parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument('--watch', default=False, action='store_true', + help='watch the play of pre-trained policy only') + return parser.parse_args() + + +def make_atari_env(args): + return wrap_deepmind(args.task, frame_stack=args.frames_stack) + + +def make_atari_env_watch(args): + return wrap_deepmind(args.task, frame_stack=args.frames_stack, + episode_life=False, clip_rewards=False) + + +def test_qrdqn(args=get_args()): + env = make_atari_env(args) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.env.action_space.shape or env.env.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + # make environments + train_envs = SubprocVectorEnv([lambda: make_atari_env(args) + for _ in range(args.training_num)]) + test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) + for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # define model + net = QRDQN(*args.state_shape, args.action_shape, + args.num_quantiles, args.device) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + # define policy + policy = QRDQNPolicy( + net, optim, args.gamma, args.num_quantiles, + args.n_step, target_update_freq=args.target_update_freq + ).to(args.device) + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load( + args.resume_path, map_location=args.device + )) + print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together + # when you have enough RAM + buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True, + save_only_last_obs=True, stack_num=args.frames_stack) + # collector + train_collector = Collector(policy, train_envs, buffer) + test_collector = Collector(policy, test_envs) + # log + log_path = os.path.join(args.logdir, args.task, 'qrdqn') + writer = SummaryWriter(log_path) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + if env.env.spec.reward_threshold: + return mean_rewards >= env.spec.reward_threshold + elif 'Pong' in args.task: + return mean_rewards >= 20 + else: + return False + + def train_fn(epoch, env_step): + # nature DQN setting, linear decay in the first 1M steps + if env_step <= 1e6: + eps = args.eps_train - env_step / 1e6 * \ + (args.eps_train - args.eps_train_final) + else: + eps = args.eps_train_final + policy.set_eps(eps) + writer.add_scalar('train/eps', eps, global_step=env_step) + + def test_fn(epoch, env_step): + policy.set_eps(args.eps_test) + + # watch agent's performance + def watch(): + print("Testing agent ...") + policy.eval() + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) + pprint.pprint(result) + + if args.watch: + watch() + exit(0) + + # test train_collector and start filling replay buffer + train_collector.collect(n_step=args.batch_size * 4) + # trainer + result = offpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.collect_per_step, args.test_num, + 
args.batch_size, train_fn=train_fn, test_fn=test_fn, + stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False) + + pprint.pprint(result) + watch() + + +if __name__ == '__main__': + test_qrdqn(get_args()) diff --git a/examples/atari/results/qrdqn/Breakout_rew.png b/examples/atari/results/qrdqn/Breakout_rew.png new file mode 100644 index 000000000..5cc1916e6 Binary files /dev/null and b/examples/atari/results/qrdqn/Breakout_rew.png differ diff --git a/examples/atari/results/qrdqn/Enduro_rew.png b/examples/atari/results/qrdqn/Enduro_rew.png new file mode 100644 index 000000000..640e60d2d Binary files /dev/null and b/examples/atari/results/qrdqn/Enduro_rew.png differ diff --git a/examples/atari/results/qrdqn/MsPacman_rew.png b/examples/atari/results/qrdqn/MsPacman_rew.png new file mode 100644 index 000000000..0afd25787 Binary files /dev/null and b/examples/atari/results/qrdqn/MsPacman_rew.png differ diff --git a/examples/atari/results/qrdqn/Pong_rew.png b/examples/atari/results/qrdqn/Pong_rew.png new file mode 100644 index 000000000..30a02375e Binary files /dev/null and b/examples/atari/results/qrdqn/Pong_rew.png differ diff --git a/examples/atari/results/qrdqn/Qbert_rew.png b/examples/atari/results/qrdqn/Qbert_rew.png new file mode 100644 index 000000000..fbd25c732 Binary files /dev/null and b/examples/atari/results/qrdqn/Qbert_rew.png differ diff --git a/examples/atari/results/qrdqn/Seaquest_rew.png b/examples/atari/results/qrdqn/Seaquest_rew.png new file mode 100644 index 000000000..7e9d47af2 Binary files /dev/null and b/examples/atari/results/qrdqn/Seaquest_rew.png differ diff --git a/examples/atari/results/qrdqn/SpaceInvader_rew.png b/examples/atari/results/qrdqn/SpaceInvader_rew.png new file mode 100644 index 000000000..4751768fb Binary files /dev/null and b/examples/atari/results/qrdqn/SpaceInvader_rew.png differ diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py new file mode 100644 index 000000000..7020df275 --- /dev/null +++ b/test/discrete/test_qrdqn.py @@ -0,0 +1,136 @@ +import os +import gym +import torch +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.policy import QRDQNPolicy +from tianshou.env import DummyVectorEnv +from tianshou.utils.net.common import Net +from tianshou.trainer import offpolicy_trainer +from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--eps-test', type=float, default=0.05) + parser.add_argument('--eps-train', type=float, default=0.1) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--gamma', type=float, default=0.9) + parser.add_argument('--num-quantiles', type=int, default=200) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--target-update-freq', type=int, default=320) + parser.add_argument('--epoch', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=1000) + parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--hidden-sizes', type=int, + nargs='*', default=[128, 128, 128, 128]) + parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--test-num', type=int, 
default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) + parser.add_argument('--prioritized-replay', + action="store_true", default=False) + parser.add_argument('--alpha', type=float, default=0.6) + parser.add_argument('--beta', type=float, default=0.4) + parser.add_argument( + '--device', type=str, + default='cuda' if torch.cuda.is_available() else 'cpu') + args = parser.parse_known_args()[0] + return args + + +def test_qrdqn(args=get_args()): + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # train_envs = gym.make(args.task) + # you can also use tianshou.env.SubprocVectorEnv + train_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)]) + # test_envs = gym.make(args.task) + test_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + net = Net(args.state_shape, args.action_shape, + hidden_sizes=args.hidden_sizes, device=args.device, + softmax=False, num_atoms=args.num_quantiles) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + policy = QRDQNPolicy( + net, optim, args.gamma, args.num_quantiles, + args.n_step, target_update_freq=args.target_update_freq + ).to(args.device) + # buffer + if args.prioritized_replay: + buf = PrioritizedReplayBuffer( + args.buffer_size, alpha=args.alpha, beta=args.beta) + else: + buf = ReplayBuffer(args.buffer_size) + # collector + train_collector = Collector(policy, train_envs, buf) + test_collector = Collector(policy, test_envs) + # policy.set_eps(1) + train_collector.collect(n_step=args.batch_size) + # log + log_path = os.path.join(args.logdir, args.task, 'qrdqn') + writer = SummaryWriter(log_path) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + def train_fn(epoch, env_step): + # eps annnealing, just a demo + if env_step <= 10000: + policy.set_eps(args.eps_train) + elif env_step <= 50000: + eps = args.eps_train - (env_step - 10000) / \ + 40000 * (0.9 * args.eps_train) + policy.set_eps(eps) + else: + policy.set_eps(0.1 * args.eps_train) + + def test_fn(epoch, env_step): + policy.set_eps(args.eps_test) + + # trainer + result = offpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.collect_per_step, args.test_num, + args.batch_size, train_fn=train_fn, test_fn=test_fn, + stop_fn=stop_fn, save_fn=save_fn, writer=writer) + + assert stop_fn(result['best_reward']) + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! 
+ env = gym.make(args.task) + policy.eval() + policy.set_eps(args.eps_test) + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + print(f'Final reward: {result["rew"]}, length: {result["len"]}') + + +def test_pqrdqn(args=get_args()): + args.prioritized_replay = True + args.gamma = .95 + args.seed = 1 + test_qrdqn(args) + + +if __name__ == '__main__': + test_pqrdqn(get_args()) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 968aaf69b..01b7019af 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -3,6 +3,7 @@ from tianshou.policy.imitation.base import ImitationPolicy from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.modelfree.c51 import C51Policy +from tianshou.policy.modelfree.qrdqn import QRDQNPolicy from tianshou.policy.modelfree.pg import PGPolicy from tianshou.policy.modelfree.a2c import A2CPolicy from tianshou.policy.modelfree.ddpg import DDPGPolicy @@ -21,6 +22,7 @@ "ImitationPolicy", "DQNPolicy", "C51Policy", + "QRDQNPolicy", "PGPolicy", "A2CPolicy", "DDPGPolicy", diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 6e172dfd7..99e16544b 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -257,7 +257,8 @@ def compute_nstep_return( mean, std = 0.0, 1.0 buf_len = len(buffer) terminal = (indice + n_step - 1) % buf_len - target_q_torch = target_q_fn(buffer, terminal) # (bsz, ?) + with torch.no_grad(): + target_q_torch = target_q_fn(buffer, terminal) # (bsz, ?) target_q = to_numpy(target_q_torch) target_q = _nstep_return(rew, buffer.done, target_q, indice, diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index faae34aaf..688a9901d 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -74,10 +74,9 @@ def _target_q( ) -> torch.Tensor: batch = buffer[indice] # batch.obs_next: s_{t+n} # target_Q = Q_old(s_, argmax(Q_new(s_, *))) - with torch.no_grad(): - act = self(batch, input="obs_next", eps=0.0).act - target_q, _ = self.model_old(batch.obs_next) - target_q = target_q[np.arange(len(act)), act] + act = self(batch, input="obs_next", eps=0.0).act + target_q, _ = self.model_old(batch.obs_next) + target_q = target_q[np.arange(len(act)), act] return target_q def forward( # type: ignore diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 706872efb..b0e94c616 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -1,9 +1,9 @@ import torch import numpy as np -from typing import Any, Dict, Union, Optional +from typing import Any, Dict from tianshou.policy import DQNPolicy -from tianshou.data import Batch, ReplayBuffer, to_numpy +from tianshou.data import Batch, ReplayBuffer class C51Policy(DQNPolicy): @@ -63,46 +63,9 @@ def _target_q( ) -> torch.Tensor: return self.support.repeat(len(indice), 1) # shape: [bsz, num_atoms] - def forward( - self, - batch: Batch, - state: Optional[Union[dict, Batch, np.ndarray]] = None, - model: str = "model", - input: str = "obs", - **kwargs: Any, - ) -> Batch: - """Compute action over the given batch data. - - :return: A :class:`~tianshou.data.Batch` which has 2 keys: - - * ``act`` the action. - * ``state`` the hidden state. - - .. seealso:: - - Please refer to :meth:`~tianshou.policy.DQNPolicy.forward` for - more detailed explanation. 
- """ - model = getattr(self, model) - obs = batch[input] - obs_ = obs.obs if hasattr(obs, "obs") else obs - dist, h = model(obs_, state=state, info=batch.info) - q = (dist * self.support).sum(2) - act: np.ndarray = to_numpy(q.max(dim=1)[1]) - if hasattr(obs, "mask"): - # some of actions are masked, they cannot be selected - q_: np.ndarray = to_numpy(q) - q_[~obs.mask] = -np.inf - act = q_.argmax(axis=1) - # add eps to act in training or testing phase - if not self.updating and not np.isclose(self.eps, 0.0): - for i in range(len(q)): - if np.random.rand() < self.eps: - q_ = np.random.rand(*q[i].shape) - if hasattr(obs, "mask"): - q_[~obs.mask[i]] = -np.inf - act[i] = q_.argmax() - return Batch(logits=dist, act=act, state=h) + def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor: + """Compute the q value based on the network's raw output logits.""" + return (logits * self.support).sum(2) def _target_dist(self, batch: Batch) -> torch.Tensor: if self._target: diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index ab28b6b7d..1ac293c88 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -102,10 +102,9 @@ def _target_q( self, buffer: ReplayBuffer, indice: np.ndarray ) -> torch.Tensor: batch = buffer[indice] # batch.obs_next: s_{t+n} - with torch.no_grad(): - target_q = self.critic_old( - batch.obs_next, - self(batch, model='actor_old', input='obs_next').act) + target_q = self.critic_old( + batch.obs_next, + self(batch, model='actor_old', input='obs_next').act) return target_q def process_fn( diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index f71b18b54..4db2dc9dc 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -78,14 +78,13 @@ def _target_q( self, buffer: ReplayBuffer, indice: np.ndarray ) -> torch.Tensor: batch = buffer[indice] # batch.obs: s_{t+n} - with torch.no_grad(): - obs_next_result = self(batch, input="obs_next") - dist = obs_next_result.dist - target_q = dist.probs * torch.min( - self.critic1_old(batch.obs_next), - self.critic2_old(batch.obs_next), - ) - target_q = target_q.sum(dim=-1) + self._alpha * dist.entropy() + obs_next_result = self(batch, input="obs_next") + dist = obs_next_result.dist + target_q = dist.probs * torch.min( + self.critic1_old(batch.obs_next), + self.critic2_old(batch.obs_next), + ) + target_q = target_q.sum(dim=-1) + self._alpha * dist.entropy() return target_q def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index a8f705b81..54397915c 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -79,15 +79,14 @@ def _target_q( ) -> torch.Tensor: batch = buffer[indice] # batch.obs_next: s_{t+n} # target_Q = Q_old(s_, argmax(Q_new(s_, *))) - with torch.no_grad(): - if self._target: - a = self(batch, input="obs_next").act - target_q = self( - batch, model="model_old", input="obs_next" - ).logits - target_q = target_q[np.arange(len(a)), a] - else: - target_q = self(batch, input="obs_next").logits.max(dim=1)[0] + if self._target: + a = self(batch, input="obs_next").act + target_q = self( + batch, model="model_old", input="obs_next" + ).logits + target_q = target_q[np.arange(len(a)), a] + else: + target_q = self(batch, input="obs_next").logits.max(dim=1)[0] return target_q def process_fn( @@ -103,6 +102,10 @@ def process_fn( self._gamma, self._n_step, 
self._rew_norm) return batch + def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor: + """Compute the q value based on the network's raw output logits.""" + return logits + def forward( self, batch: Batch, @@ -143,7 +146,8 @@ def forward( model = getattr(self, model) obs = batch[input] obs_ = obs.obs if hasattr(obs, "obs") else obs - q, h = model(obs_, state=state, info=batch.info) + logits, h = model(obs_, state=state, info=batch.info) + q = self.compute_q_value(logits) act: np.ndarray = to_numpy(q.max(dim=1)[1]) if hasattr(obs, "mask"): # some of actions are masked, they cannot be selected @@ -158,7 +162,7 @@ def forward( if hasattr(obs, "mask"): q_[~obs.mask[i]] = -np.inf act[i] = q_.argmax() - return Batch(logits=q, act=act, state=h) + return Batch(logits=logits, act=act, state=h) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: if self._target and self._iter % self._freq == 0: diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py new file mode 100644 index 000000000..754d9acce --- /dev/null +++ b/tianshou/policy/modelfree/qrdqn.py @@ -0,0 +1,94 @@ +import torch +import warnings +import numpy as np +from typing import Any, Dict +import torch.nn.functional as F + +from tianshou.policy import DQNPolicy +from tianshou.data import Batch, ReplayBuffer + + +class QRDQNPolicy(DQNPolicy): + """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044. + + :param torch.nn.Module model: a model following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. + :param float discount_factor: in [0, 1]. + :param int num_quantiles: the number of quantile midpoints in the inverse + cumulative distribution function of the value, defaults to 200. + :param int estimation_step: greater than 1, the number of steps to look + ahead. + :param int target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param bool reward_normalization: normalize the reward to Normal(0, 1), + defaults to False. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed + explanation. 
+ """ + + def __init__( + self, + model: torch.nn.Module, + optim: torch.optim.Optimizer, + discount_factor: float = 0.99, + num_quantiles: int = 200, + estimation_step: int = 1, + target_update_freq: int = 0, + reward_normalization: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(model, optim, discount_factor, estimation_step, + target_update_freq, reward_normalization, **kwargs) + assert num_quantiles > 1, "num_quantiles should be greater than 1" + self._num_quantiles = num_quantiles + tau = torch.linspace(0, 1, self._num_quantiles + 1) + self.tau_hat = torch.nn.Parameter( + ((tau[:-1] + tau[1:]) / 2).view(1, -1, 1), requires_grad=False) + warnings.filterwarnings("ignore", message="Using a target size") + + def _target_q( + self, buffer: ReplayBuffer, indice: np.ndarray + ) -> torch.Tensor: + batch = buffer[indice] # batch.obs_next: s_{t+n} + if self._target: + a = self(batch, input="obs_next").act + next_dist = self( + batch, model="model_old", input="obs_next" + ).logits + else: + next_b = self(batch, input="obs_next") + a = next_b.act + next_dist = next_b.logits + next_dist = next_dist[np.arange(len(a)), a, :] + return next_dist # shape: [bsz, num_quantiles] + + def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor: + """Compute the q value based on the network's raw output logits.""" + return logits.mean(2) + + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: + if self._target and self._iter % self._freq == 0: + self.sync_weight() + self.optim.zero_grad() + weight = batch.pop("weight", 1.0) + curr_dist = self(batch).logits + act = batch.act + curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2) + target_dist = batch.returns.unsqueeze(1) + # calculate each element's difference between curr_dist and target_dist + u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") + huber_loss = (u * ( + self.tau_hat - (target_dist - curr_dist).detach().le(0.).float() + ).abs()).sum(-1).mean(1) + loss = (huber_loss * weight).mean() + # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ + # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 + batch.weight = u.detach().abs().sum(-1).mean(1) # prio-buffer + loss.backward() + self.optim.step() + self._iter += 1 + return {"loss": loss.item()} diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 983df7158..fbdd12297 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -140,13 +140,12 @@ def _target_q( self, buffer: ReplayBuffer, indice: np.ndarray ) -> torch.Tensor: batch = buffer[indice] # batch.obs: s_{t+n} - with torch.no_grad(): - obs_next_result = self(batch, input='obs_next') - a_ = obs_next_result.act - target_q = torch.min( - self.critic1_old(batch.obs_next, a_), - self.critic2_old(batch.obs_next, a_), - ) - self._alpha * obs_next_result.log_prob + obs_next_result = self(batch, input='obs_next') + a_ = obs_next_result.act + target_q = torch.min( + self.critic1_old(batch.obs_next, a_), + self.critic2_old(batch.obs_next, a_), + ) - self._alpha * obs_next_result.log_prob return target_q def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index d7cf9c09a..f79c2a0d5 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -104,17 +104,16 @@ def _target_q( self, buffer: ReplayBuffer, indice: np.ndarray ) -> torch.Tensor: batch = buffer[indice] # batch.obs: s_{t+n} - with torch.no_grad(): - a_ = 
self(batch, model="actor_old", input="obs_next").act - dev = a_.device - noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise - if self._noise_clip > 0.0: - noise = noise.clamp(-self._noise_clip, self._noise_clip) - a_ += noise - a_ = a_.clamp(self._range[0], self._range[1]) - target_q = torch.min( - self.critic1_old(batch.obs_next, a_), - self.critic2_old(batch.obs_next, a_)) + a_ = self(batch, model="actor_old", input="obs_next").act + dev = a_.device + noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise + if self._noise_clip > 0.0: + noise = noise.clamp(-self._noise_clip, self._noise_clip) + a_ += noise + a_ = a_.clamp(self._range[0], self._range[1]) + target_q = torch.min( + self.critic1_old(batch.obs_next, a_), + self.critic2_old(batch.obs_next, a_)) return target_q def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
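For reference, `QRDQNPolicy.learn` above implements the quantile Huber loss of the QR-DQN paper. With quantile midpoints $\hat\tau_i = \frac{2i-1}{2N}$ (the `tau_hat` buffer), current quantile estimates $\theta_i$, target quantiles $\theta'_j$ (`batch.returns`), and pairwise TD errors $u_{ij} = \theta'_j - \theta_i$, the per-sample loss produced by the `sum(-1).mean(1)` expression is, up to the $\kappa = 1$ convention of `smooth_l1_loss`,

$$
\mathcal{L} \;=\; \frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^{N}
\bigl|\hat{\tau}_i - \mathbb{1}\{u_{ij} \le 0\}\bigr|\; L_{\kappa}(u_{ij}),
\qquad u_{ij} = \theta'_j - \theta_i,
$$

where $L_{\kappa}$ is the Huber loss. The indicator is detached, so gradients flow only through $L_{\kappa}(u_{ij})$, and the detached $L_{\kappa}(u_{ij})$ values, aggregated the same way, are written back to `batch.weight` as priorities for a prioritized replay buffer.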
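The refactor in `dqn.py` and `c51.py` introduces a single `compute_q_value` hook: `DQNPolicy.forward` now keeps the raw network output in `logits` and derives greedy actions from `compute_q_value(logits)`, so C51 and QRDQN reuse the same action-masking and eps-greedy code instead of overriding `forward`. The three overrides touched in this diff reduce to the following standalone restatement (for illustration only, with the C51 support passed in explicitly):

```python
import torch


def dqn_q_value(logits: torch.Tensor) -> torch.Tensor:
    # plain DQN: the network output already is Q(s, a), shape [batch, n_actions]
    return logits


def c51_q_value(logits: torch.Tensor, support: torch.Tensor) -> torch.Tensor:
    # C51: logits is a [batch, n_actions, n_atoms] categorical distribution;
    # Q is its expectation over the fixed support
    return (logits * support).sum(2)


def qrdqn_q_value(logits: torch.Tensor) -> torch.Tensor:
    # QRDQN: logits is [batch, n_actions, n_quantiles]; the quantiles are
    # equally weighted, so Q is simply their mean
    return logits.mean(2)


# toy check: 1 sample, 2 actions, 5 quantiles
quantiles = torch.arange(10.0).view(1, 2, 5)
print(qrdqn_q_value(quantiles))  # tensor([[2., 7.]])
```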
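Relatedly, the scattered `with torch.no_grad():` blocks are removed from the individual `_target_q` implementations (DQN, DDPG, SAC, discrete SAC, TD3, discrete BCQ) because `compute_nstep_return` in `base.py` now wraps the `target_q_fn` call itself. A hypothetical custom policy written against this contract would therefore look like the sketch below; the subclass name is made up for illustration.

```python
import torch
import numpy as np

from tianshou.data import ReplayBuffer
from tianshou.policy import DQNPolicy


class MyDQNPolicy(DQNPolicy):
    """Hypothetical subclass, shown only to illustrate the new contract."""

    def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
        batch = buffer[indice]  # batch.obs_next: s_{t+n}
        # No explicit torch.no_grad() is needed here any more:
        # compute_nstep_return already disables gradients around this call.
        return self(batch, input="obs_next").logits.max(dim=1)[0]
```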