diff --git a/README.md b/README.md index a8fb2f051..3ff1f8576 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,7 @@ - [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf) - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf) - [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf) +- [Randomized Ensembled Double Q-Learning (REDQ)](https://arxiv.org/pdf/2101.05982.pdf) - Vanilla Imitation Learning - [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf) - [Conservative Q-Learning (CQL)](https://arxiv.org/pdf/2006.04779.pdf) @@ -45,6 +46,7 @@ - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf) - [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf) - [Intrinsic Curiosity Module (ICM)](https://arxiv.org/pdf/1705.05363.pdf) +- [Model-Based Policy Optimization (MBPO)](https://arxiv.org/pdf/1906.08253.pdf) Here is Tianshou's other features: diff --git a/docs/api/tianshou.policy.rst b/docs/api/tianshou.policy.rst index c3063665c..065ad5c57 100644 --- a/docs/api/tianshou.policy.rst +++ b/docs/api/tianshou.policy.rst @@ -101,6 +101,11 @@ Off-policy :undoc-members: :show-inheritance: +.. autoclass:: tianshou.policy.REDQPolicy + :members: + :undoc-members: + :show-inheritance: + Imitation --------- @@ -147,6 +152,11 @@ Model-based :undoc-members: :show-inheritance: +.. autoclass:: tianshou.policy.MBPOPolicy + :members: + :undoc-members: + :show-inheritance: + Multi-agent ----------- diff --git a/docs/index.rst b/docs/index.rst index 13adaed4c..31d23d92a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -26,6 +26,7 @@ Welcome to Tianshou! * :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG `_ * :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic `_ * :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic `_ +* :class:`~tianshou.policy.REDQPolicy` `Randomized Ensembled Double Q-Learning `_ * :class:`~tianshou.policy.ImitationPolicy` Imitation Learning * :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning `_ * :class:`~tianshou.policy.CQLPolicy` `Conservative Q-Learning `_ @@ -34,6 +35,7 @@ Welcome to Tianshou! 
* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression `_ * :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning `_ * :class:`~tianshou.policy.ICMPolicy` `Intrinsic Curiosity Module `_ +* :class:`~tianshou.policy.MBPOPolicy` `Model-Based Policy Optimization `_ * :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay `_ * :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator `_ diff --git a/examples/modelbased/mujoco_mbpo.py b/examples/modelbased/mujoco_mbpo.py new file mode 100644 index 000000000..587347d0e --- /dev/null +++ b/examples/modelbased/mujoco_mbpo.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python3 + +import argparse +import datetime +import os +import pprint + +import gym +import numpy as np +import torch +from torch import nn +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import ( + Collector, + ReplayBuffer, + RolloutsCollector, + SimpleReplayBuffer, + VectorReplayBuffer, +) +from tianshou.env import SubprocVectorEnv +from tianshou.env.fake import FakeEnv, GaussianModel +from tianshou.env.mujoco.static import TERMINAL_FUNCTIONS +from tianshou.policy import MBPOPolicy +from tianshou.trainer import dyna_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import EnsembleMLPGaussian, Net +from tianshou.utils.net.continuous import ActorProb, Critic + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='Hopper-v2') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--buffer-size', type=int, default=1000000) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[256, 256]) + parser.add_argument( + '--model-hidden-sizes', type=int, nargs='*', default=[200, 200, 200, 200] + ) + parser.add_argument( + '--model-net-decays', + type=float, + nargs='*', + default=[0.000025, 0.00005, 0.000075, 0.000075, 0.0001] + ) + parser.add_argument('--ensemble-size', type=int, default=7) + parser.add_argument('--num-elites', type=int, default=5) + parser.add_argument('--actor-lr', type=float, default=1e-3) + parser.add_argument('--critic-lr', type=float, default=1e-3) + parser.add_argument('--model-lr', type=float, default=1e-3) + parser.add_argument('--gamma', type=float, default=0.99) + parser.add_argument('--tau', type=float, default=0.005) + parser.add_argument('--alpha', type=float, default=0.2) + parser.add_argument('--auto-alpha', default=False, action='store_true') + parser.add_argument('--alpha-lr', type=float, default=3e-4) + parser.add_argument("--start-timesteps", type=int, default=5000) + parser.add_argument('--epoch', type=int, default=200) + parser.add_argument('--step-per-epoch', type=int, default=1000) + parser.add_argument('--step-per-collect', type=int, default=1) + parser.add_argument('--update-per-step', type=float, default=20.) 
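    # Note on the model hyper-parameters above: `--model-net-decays` supplies one weight-decay
    # value per linear layer of the dynamics network (the hidden layers from
    # `--model-hidden-sizes` plus the output layer, hence five values for four hidden sizes);
    # the script below groups the model parameters layer by layer and passes these values as
    # per-group `weight_decay` options to torch.optim.Adam. `--update-per-step` is MBPO's
    # update-to-data ratio: the dyna trainer performs round(update_per_step * collected_steps)
    # policy updates after every collect step.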
+ parser.add_argument('--n-step', type=int, default=1) + parser.add_argument('--batch-size', type=int, default=256) + parser.add_argument('--real-ratio', type=float, default=0.1) + parser.add_argument('--training-num', type=int, default=1) + parser.add_argument('--test-num', type=int, default=10) + parser.add_argument('--rollout-batch-size', type=int, default=100000) + parser.add_argument( + '--rollout-schedule', type=int, nargs='*', default=[1, 100, 1, 1] + ) + parser.add_argument('--model-train-freq', type=int, default=250) + parser.add_argument('--model-retain-epochs', type=int, default=1) + parser.add_argument('--deterministic', default=False, action='store_true') + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) + parser.add_argument( + '--device', + type=str, + default='cuda' if torch.cuda.is_available() else 'cpu', + ) + parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) + return parser.parse_args() + + +def test_mbpo(args=get_args()): + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) + if args.training_num > 1: + train_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) + else: + train_envs = gym.make(args.task) + test_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = ActorProb( + net_a, + args.action_shape, + max_action=args.max_action, + device=args.device, + unbounded=True, + conditioned_sigma=True + ).to(args.device) + actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + net_c1 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) + critic1 = Critic(net_c1, device=args.device).to(args.device) + critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic2 = Critic(net_c2, device=args.device).to(args.device) + critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + model_net = EnsembleMLPGaussian( + args.ensemble_size, + args.state_shape, + args.action_shape, + hidden_sizes=args.model_hidden_sizes, + activation=nn.SiLU, + device=args.device + ).to(args.device) + assert len(args.model_net_decays) == len(args.model_hidden_sizes) + 1 + parameters = [] + layer = -1 + for name, param in model_net.named_parameters(): + if name.endswith('.weight'): + layer += 1 + option = { + 'params': param, + 'weight_decay': args.model_net_decays[layer], + } + parameters.append(option) + else: + parameters.append({'params': param}) + model_net_optim = torch.optim.Adam( + parameters, + lr=args.model_lr, + ) + model = GaussianModel( + args.ensemble_size, + model_net, + model_net_optim, + 
device=args.device, + num_elites=args.num_elites, + batch_size=args.batch_size, + deterministic=args.deterministic + ) + domain = args.task.split("-")[0] + terminal_fn = TERMINAL_FUNCTIONS[domain] + + if args.auto_alpha: + target_entropy = -np.prod(env.action_space.shape) + log_alpha = torch.zeros(1, requires_grad=True, device=args.device) + alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) + args.alpha = (target_entropy, log_alpha, alpha_optim) + + policy = MBPOPolicy( + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + tau=args.tau, + gamma=args.gamma, + alpha=args.alpha, + estimation_step=args.n_step, + action_space=env.action_space + ) + + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + + # collector + if args.training_num > 1: + buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) + else: + buffer = ReplayBuffer(args.buffer_size) + fake_env = FakeEnv(model, buffer, terminal_fn, args.rollout_batch_size) + model_buffer = SimpleReplayBuffer(args.rollout_batch_size) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs) + model_collector = RolloutsCollector(policy, fake_env, model_buffer) + # log + t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") + log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_mbpo' + log_path = os.path.join(args.logdir, args.task, 'mbpo', log_file) + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = TensorboardLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + if not args.watch: + # trainer + result = dyna_trainer( + policy, + model, + train_collector, + test_collector, + model_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + args.rollout_batch_size, + args.rollout_schedule, + args.real_ratio, + args.start_timesteps, + model_train_freq=args.model_train_freq, + model_retain_epochs=args.model_retain_epochs, + update_per_step=args.update_per_step, + save_fn=save_fn, + logger=logger, + test_in_train=False + ) + torch.save(model.network.state_dict(), os.path.join(log_path, 'model.pth')) + pprint.pprint(result) + + # Let's watch its performance! 
+ policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, render=args.render) + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + + +if __name__ == '__main__': + test_mbpo() diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py new file mode 100755 index 000000000..91458c825 --- /dev/null +++ b/examples/mujoco/mujoco_redq.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python3 + +import argparse +import datetime +import os +import pprint + +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv +from tianshou.policy import REDQPolicy +from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import EnsembleNet, Net +from tianshou.utils.net.continuous import ActorProb, EnsembleCritic + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='Ant-v3') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--buffer-size', type=int, default=1000000) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[256, 256]) + parser.add_argument('--ensemble-size', type=int, default=10) + parser.add_argument('--subset-size', type=int, default=2) + parser.add_argument('--actor-lr', type=float, default=1e-3) + parser.add_argument('--critic-lr', type=float, default=1e-3) + parser.add_argument('--gamma', type=float, default=0.99) + parser.add_argument('--tau', type=float, default=0.005) + parser.add_argument('--alpha', type=float, default=0.2) + parser.add_argument('--auto-alpha', default=False, action='store_true') + parser.add_argument('--alpha-lr', type=float, default=3e-4) + parser.add_argument("--start-timesteps", type=int, default=10000) + parser.add_argument('--epoch', type=int, default=200) + parser.add_argument('--step-per-epoch', type=int, default=5000) + parser.add_argument('--step-per-collect', type=int, default=1) + parser.add_argument('--update-per-step', type=int, default=20) + parser.add_argument('--n-step', type=int, default=1) + parser.add_argument('--batch-size', type=int, default=256) + parser.add_argument( + '--target-mode', type=str, choices=('min', 'mean'), default='min' + ) + parser.add_argument('--training-num', type=int, default=1) + parser.add_argument('--test-num', type=int, default=10) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) 
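    # Note on the REDQ hyper-parameters above: `--ensemble-size` (N = 10) and `--subset-size`
    # (M = 2) follow the paper's notation -- N critics are trained and each Q target uses a
    # random subset of M of them. `--update-per-step` is the update-to-data ratio G and is
    # also passed to REDQPolicy as `actor_delay` below, so the actor (and alpha, when
    # auto-tuned) is updated once per G critic gradient steps.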
+ parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only' + ) + return parser.parse_args() + + +def test_redq(args=get_args()): + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) + # train_envs = gym.make(args.task) + if args.training_num > 1: + train_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) + else: + train_envs = gym.make(args.task) + # test_envs = gym.make(args.task) + test_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = ActorProb( + net_a, + args.action_shape, + max_action=args.max_action, + device=args.device, + unbounded=True, + conditioned_sigma=True + ).to(args.device) + actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + net_c = EnsembleNet( + args.ensemble_size, + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device + ) + critics = EnsembleCritic( + args.ensemble_size, + net_c, + device=args.device, + ).to(args.device) + critics_optim = torch.optim.Adam(critics.parameters(), lr=args.critic_lr) + + if args.auto_alpha: + target_entropy = -np.prod(env.action_space.shape) + log_alpha = torch.zeros(1, requires_grad=True, device=args.device) + alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) + args.alpha = (target_entropy, log_alpha, alpha_optim) + + policy = REDQPolicy( + actor, + actor_optim, + critics, + critics_optim, + args.ensemble_size, + args.subset_size, + tau=args.tau, + gamma=args.gamma, + alpha=args.alpha, + estimation_step=args.n_step, + actor_delay=args.update_per_step, + target_mode=args.target_mode, + action_space=env.action_space, + ) + + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + + # collector + if args.training_num > 1: + buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) + else: + buffer = ReplayBuffer(args.buffer_size) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs) + train_collector.collect(n_step=args.start_timesteps, random=True) + # log + t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") + log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_redq' + log_path = os.path.join(args.logdir, args.task, 'redq', log_file) + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = TensorboardLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + if not args.watch: + # trainer + result = offpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + 
args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + save_fn=save_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False + ) + pprint.pprint(result) + + # Let's watch its performance! + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, render=args.render) + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + + +if __name__ == '__main__': + test_redq() diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index 89250d009..178f1646b 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -15,7 +15,12 @@ PrioritizedVectorReplayBuffer, ) from tianshou.data.buffer.cached import CachedReplayBuffer -from tianshou.data.collector import Collector, AsyncCollector +from tianshou.data.buffer.simple import SimpleReplayBuffer +from tianshou.data.collector import ( + Collector, + AsyncCollector, + RolloutsCollector, +) __all__ = [ "Batch", @@ -30,6 +35,8 @@ "VectorReplayBuffer", "PrioritizedVectorReplayBuffer", "CachedReplayBuffer", + "SimpleReplayBuffer", "Collector", "AsyncCollector", + "RolloutsCollector", ] diff --git a/tianshou/data/buffer/simple.py b/tianshou/data/buffer/simple.py new file mode 100644 index 000000000..aa39cd51d --- /dev/null +++ b/tianshou/data/buffer/simple.py @@ -0,0 +1,168 @@ +from typing import Any, List, Tuple, Union + +import numpy as np + +from tianshou.data import Batch, ReplayBuffer +from tianshou.data.batch import _alloc_by_keys_diff, _create_value + + +class SimpleReplayBuffer(ReplayBuffer): + """:class:`~tianshou.data.SimpleReplayBuffer` stores data generated from interaction \ + between the policy and environment. + + SimpleReplayBuffer adds a sequence of data by directly filling in samples. It \ + ignores sequence information in an episode. + + :param int size: the maximum size of replay buffer. + """ + + def __init__( + self, + size: int, + ) -> None: + self.maxsize = size + self._meta: Batch = Batch() + self.reset() + + def reset(self) -> None: + """Clear all the data in replay buffer.""" + self._index = self._size = 0 + + def unfinished_index(self) -> np.ndarray: + """Return the index of unfinished episode.""" + return np.arange(self._size)[~self.done[:self._size]] + + def prev(self, index: Union[int, np.ndarray]) -> np.ndarray: + """Return the input index.""" + return np.array(index) + + def next(self, index: Union[int, np.ndarray]) -> np.ndarray: + """Return the input index.""" + return np.array(index) + + def update(self, buffer: "ReplayBuffer") -> np.ndarray: + """Move the data from the given buffer to current buffer. + + Return the updated indices. If update fails, return an empty array. + """ + if len(buffer) == 0 or self.maxsize == 0: + return np.array([], int) + self.add(buffer._meta) + num_samples = len(buffer) + to_indices = np.arange(self._index, self._index + num_samples) % self.maxsize + return to_indices + + def _add_index(self, rew: Union[float, np.ndarray], + done: bool) -> Tuple[int, Union[float, np.ndarray], int, int]: + """Deprecated.""" + raise NotImplementedError + + def add( + self, + batch: Batch, + ) -> Tuple[int, int, int, int]: + """Add a batch of data into SimpleReplayBuffer. + + :param Batch batch: the input data batch. Its keys must belong to the 7 + reserved keys, and "obs", "act", "rew", "done" is required. 
+ + Return current_index and constants to keep compatability + """ + # preprocess batch + b = Batch() + for key in set(self._reserved_keys).intersection(batch.keys()): + b.__dict__[key] = batch[key] + batch = b + assert set(["obs", "act", "rew", "done"]).issubset(batch.keys()) + + num_samples = len(batch) + ptr = self._index + indices = np.arange(self._index, self._index + num_samples) % self.maxsize + self._size = min(self._size + num_samples, self.maxsize) + self._index = (self._index + num_samples) % self.maxsize + try: + self._meta[indices] = batch + except ValueError: + stack = False + batch.rew = batch.rew.astype(float) + batch.done = batch.done.astype(bool) + if self._meta.is_empty(): + self._meta = _create_value( # type: ignore + batch, self.maxsize, stack + ) + else: # dynamic key pops up in batch + _alloc_by_keys_diff(self._meta, batch, self.maxsize, stack) + self._meta[indices] = batch + return ptr, 0, 0, 0 + + def sample_indices(self, batch_size: int) -> np.ndarray: + """Get a random sample of index with size = batch_size. + + Return all available indices in the buffer if batch_size is 0; return an empty + numpy array if batch_size < 0 or no available index can be sampled. + """ + if batch_size > 0: + return np.random.choice(self._size, batch_size) + elif batch_size == 0: # construct current available indices + return np.concatenate( + [np.arange(self._index, self._size), + np.arange(self._index)] + ) + else: + return np.array([], int) + + def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: + """Get a random sample from buffer with size = batch_size. + + Return all the data in the buffer if batch_size is 0. + + :return: Sample data and its corresponding index inside the buffer. + """ + indices = self.sample_indices(batch_size) + return self[indices], indices + + def get( + self, + index: Union[int, List[int], np.ndarray], + key: str, + default_value: Any = None, + ) -> Union[Batch, np.ndarray]: + """Return self.key[index] or default_value. + + E.g., if you set ``key = "obs", stack_num = 4, index = t``, it returns the + stacked result as ``[obs[t-3], obs[t-2], obs[t-1], obs[t]]``. + + :param index: the index for getting stacked data. + :param str key: the key to get, should be one of the reserved_keys. + :param default_value: if the given key's data is not found and default_value is + set, return this default_value. 
+ """ + if key not in self._meta and default_value is not None: + return default_value + val = self._meta[key] + try: + return val[index] + except IndexError as e: + if not (isinstance(val, Batch) and val.is_empty()): + raise e # val != Batch() + return Batch() + + def __getitem__(self, index: Union[slice, int, List[int], np.ndarray]) -> Batch: + """Return a data batch: self[index].""" + if isinstance(index, slice): # change slice to np array + # buffer[:] will get all available data + indices = self.sample_indices(0) if index == slice(None) \ + else self._indices[:len(self)][index] + else: + indices = index + # raise KeyError first instead of AttributeError, + # to support np.array([SimpleReplayBuffer()]) + return Batch( + obs=self.obs[indices], + act=self.act[indices], + rew=self.rew[indices], + done=self.done[indices], + obs_next=self.obs_next[indices], + info=self.get(indices, "info", Batch()), + policy=self.get(indices, "policy", Batch()), + ) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 82545a041..3fe4645fd 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -11,6 +11,7 @@ CachedReplayBuffer, ReplayBuffer, ReplayBufferManager, + SimpleReplayBuffer, VectorReplayBuffer, to_numpy, ) @@ -571,3 +572,118 @@ def collect( "rew_std": rew_std, "len_std": len_std, } + + +class RolloutsCollector(Collector): + """RolloutsCollector collects model rollouts. + + The arguments are exactly the same as :class:`~tianshou.data.Collector`, please + refer to :class:`~tianshou.data.Collector` for more detailed explanation. + """ + + def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None: + """Only SimpleReplayBuffer is used.""" + assert isinstance(buffer, SimpleReplayBuffer) + self.buffer = buffer + + def reset_buffer(self) -> None: + """Reset the data buffer.""" + self.buffer.reset() + + def collect( + self, + n_step: Optional[int] = None, + no_grad: bool = True, + ) -> Dict[str, Any]: + """Collect a specified number of step or episode. + + :param int n_step: how many steps you want to collect. + :param bool no_grad: whether to retain gradient in policy.forward(). Default to + True (no gradient retaining). + + :return: A dict including the following keys + + * ``n/ep`` collected number of episodes. + * ``n/st`` collected number of steps. + """ + if n_step is not None: + assert n_step > 0 + if not n_step % self.env_num == 0: + warnings.warn( + f"n_step={n_step} is not a multiple of #env ({self.env_num}), " + "which may cause extra transitions collected into the buffer." 
+ ) + else: + raise TypeError("Please specify n_step in RolloutsCollector.collect().") + + start_time = time.time() + + step_count = 0 + episode_count = 0 + + while True: + # restore the state: if the last state is None, it won't store + last_state = self.data.policy.pop("hidden_state", None) + + # get the next action + if no_grad: + with torch.no_grad(): # faster than retain_grad version + # self.data.obs will be used by agent to get result + result = self.policy(self.data, last_state) + else: + result = self.policy(self.data, last_state) + # update state / act / policy into self.data + policy = result.get("policy", Batch()) + assert isinstance(policy, Batch) + state = result.get("state", None) + if state is not None: + policy.hidden_state = state # save state into buffer + act = to_numpy(result.act) + if self.exploration_noise: + act = self.policy.exploration_noise(act, self.data) + self.data.update(policy=policy, act=act) + + # get bounded and remapped actions first (not saved into buffer) + action_remap = self.policy.map_action(self.data.act) + # step in env + result = self.env.step(action_remap) # type: ignore + obs_next, rew, done, info = result + + self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) + if self.preprocess_fn: + self.data.update( + self.preprocess_fn( + obs_next=self.data.obs_next, + rew=self.data.rew, + done=self.data.done, + info=self.data.info, + policy=self.data.policy, + ) + ) + + # add data into the buffer + self.buffer.add(self.data) + + # collect statistics + step_count += self.env_num + + if np.any(done): + env_ind_local = np.where(done)[0] + episode_count += len(env_ind_local) + for i in env_ind_local: + self._reset_state(i) + + self.data.obs = self.data.obs_next + + if n_step and step_count >= n_step: + break + + # generate statistics + self.collect_step += step_count + self.collect_episode += episode_count + self.collect_time += max(time.time() - start_time, 1e-9) + + return { + "n/ep": episode_count, + "n/st": step_count, + } diff --git a/tianshou/data/dataset.py b/tianshou/data/dataset.py new file mode 100644 index 000000000..86752c7e0 --- /dev/null +++ b/tianshou/data/dataset.py @@ -0,0 +1,47 @@ +from typing import Tuple + +import torch + +from tianshou.data import Batch + + +class TransitionDataset(torch.utils.data.Dataset): + """Construct transition dataset.""" + + def __init__( + self, + batch: Batch, + ) -> None: + self.size = len(batch) + observation = batch.obs + action = batch.act + reward = batch.rew[:, None] + next_observation = batch.obs_next + delta_observation = next_observation - observation + + self.inputs = torch.cat((observation, action), dim=-1) + self.outputs = torch.cat((delta_observation, reward), dim=-1) + + def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]: + return self.inputs[idx], self.outputs[idx] + + def __len__(self) -> int: + return self.size + + +class TransitionEnsembleDataset(TransitionDataset): + """Construct transition dataset with data randomly shuffled.""" + + def __init__( + self, + batch: Batch, + ensemble_size: int = 1, + ) -> None: + super().__init__(batch=batch, ) + + indices = torch.randint(self.size, (ensemble_size, self.size)) + self.inputs = self.inputs[indices] + self.outputs = self.outputs[indices] + + def __getitem__(self, idx): + return self.inputs[:, idx, :], self.outputs[:, idx, :] diff --git a/tianshou/env/fake.py b/tianshou/env/fake.py new file mode 100644 index 000000000..8eb45b0a1 --- /dev/null +++ b/tianshou/env/fake.py @@ -0,0 +1,270 @@ +import itertools +from typing 
import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch.distributions import Independent, Normal +from torch.utils.data import DataLoader + +from tianshou.data import Batch, ReplayBuffer +from tianshou.data.dataset import TransitionDataset, TransitionEnsembleDataset +from tianshou.utils.net.common import GaussianMLELoss + + +class GaussianModel(object): + """Wrapper of Gaussian model. + + :param int ensemble_size: number of subnets in the ensemble. + :param torch.nn.Module network: core network of learned model + :param torch.nn.Optimizer optimizer: network optimizer + :param Optional[Union[str, int, torch.device]] device: + :param int env_num: number of environments to be executed in parallel. + :param int batch_size: model training batch size. + :param float ratio: train-validation split ratio. + :param bool deterministic: whether to predict the next observation + deterministically. + :param Optional[int] max_epoch: maximum number of epochs of each training. + :param int max_static_epoch: If validation error is not reduced by a certain + threshold for max_static_epoch epochs, the training is early-stopped. + """ + + def __init__( + self, + ensemble_size: int, + network: torch.nn.Module, + optimizer: torch.optim.Optimizer, + device: Optional[Union[str, int, torch.device]] = None, + num_elites: int = 1, + batch_size: int = 64, + ratio: float = 0.8, + deterministic: bool = False, + max_epoch: Optional[int] = None, + max_static_epoch: int = 5, + ) -> None: + self.ensemble_size = ensemble_size + self.network = network + self.optimizer = optimizer + self.device = device + self.num_elites = num_elites + self.batch_size = batch_size + self.ratio = ratio + self.deterministic = deterministic + self.max_epoch = max_epoch + self.max_static_epoch = max_static_epoch + self.best = [1e10] * ensemble_size + + def train( + self, + batch: Batch, + ) -> Dict[str, Union[float, int]]: + """Train the dynamics model. + + :param tianshou.data.Batch batch: Training data + + :return: Training information including training loss, validation loss and + number of training epochs. 
+ """ + batch.to_torch(dtype=torch.float32) + total_num = batch.obs.shape[0] + train_num = int(total_num * self.ratio) + permutation = np.random.permutation(total_num) + train_dataset = TransitionEnsembleDataset( + batch=batch[permutation[:train_num]], + ensemble_size=self.ensemble_size, + ) + val_dataset = TransitionDataset(batch=batch[permutation[train_num:]], ) + train_dl = DataLoader( + dataset=train_dataset, + batch_size=self.batch_size, + shuffle=True, + ) + val_dl = DataLoader( + dataset=val_dataset, + batch_size=self.batch_size, + ) + + epoch_iter: Iterable + if self.max_epoch is None: + epoch_iter = itertools.count() + else: + epoch_iter = range(self.max_epoch) + + loss_fn = GaussianMLELoss(coeff=0.01) + epochs_since_update = 0 + epoch_this_train = 0 + for _ in epoch_iter: + self.network.train() + for x, y in train_dl: + # Input shape is (batch_size, ensemble_size, data_dimension) + x = x.transpose(0, 1).to(self.device) + y = y.transpose(0, 1).to(self.device) + mean, logvar, max_logvar, min_logvar = self.network(x) + loss = loss_fn(mean, logvar, max_logvar, min_logvar, y) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + self.network.eval() + mse = torch.zeros(self.ensemble_size) + with torch.no_grad(): + for i, (x, y) in enumerate(val_dl): + x = x.to(self.device) + y = y.to(self.device) + mean, _, _, _ = self.network(x) + batch_mse = torch.mean(torch.square(mean - y), dim=(1, 2)).cpu() + mse = (mse * i + batch_mse) / (i + 1) + + epoch_this_train += 1 + updated = False + for i in range(len(mse)): + mse_item = mse[i].item() + improvement = (self.best[i] - mse_item) / self.best[i] + if improvement > 0.01: + updated = True + self.best[i] = mse_item + if updated: + epochs_since_update = 0 + else: + epochs_since_update += 1 + + if epochs_since_update > self.max_static_epoch: + break + + # Select elites + self.elite_indice = torch.argsort(mse)[:self.num_elites] + elite_mse = mse[self.elite_indice].mean().item() + + # Collect training info to be logged. + train_info = { + "model/train_loss": loss.item(), + "model/val_loss": elite_mse, + "model/train_epoch": epoch_this_train, + } + + return train_info + + def predict( + self, + batch: Batch, + ) -> Batch: + """Predict a step forward. + + :param tianshou.data.Batch batch: prediction input + + :return: The input batch with next observation, reward and info added. + """ + batch.to_torch(dtype=torch.float32, device=self.device) + self.network.eval() + observation = batch.obs + action = batch.act + inputs = torch.cat((observation, action), dim=-1) + with torch.no_grad(): + mean, logvar, _, _ = self.network(inputs) + std = torch.sqrt(torch.exp(logvar)) + dist = Independent(Normal(mean, std), 1) + if self.deterministic: + sample = mean + else: + sample = dist.rsample() + log_prob = dist.log_prob(sample) + # For each input, choose a network from the ensemble + _, batch_size, _ = sample.shape + indice = torch.randint(self.num_elites, size=(batch_size, )) + choice_indice = self.elite_indice[indice] + batch_indice = torch.arange(batch_size) + next_observation = observation + \ + sample[choice_indice, batch_indice, :-1] + reward = sample[choice_indice, batch_indice, -1] + log_prob = log_prob[choice_indice, batch_indice] + info = list( + map(lambda x: {"log_prob": x.item()}, torch.split(log_prob, 1)) + ) + + batch.obs_next = next_observation + batch.rew = reward + batch.info = info + + return batch + + +class FakeEnv(object): + """Virtual environment with learned model. + + :param model: transition model. 
+ :param buffer: environment buffer to sample the initial observations. + :param function terminal_fn: terminal function + :param int env_num: Number of environments to be executed in parallel. + """ + + def __init__( + self, + model: GaussianModel, + buffer: ReplayBuffer, + terminal_fn: Callable, + env_num: int = 1, + ) -> None: + self.model = model + self.buffer = buffer + self.env_num = env_num + self.terminal_fn = terminal_fn + + # To be compatible with Collector + self.action_space = None + + def __len__(self) -> int: + return self.env_num + + def reset( + self, + *args: Any, + **kwargs: Any, + ) -> Union[np.ndarray, None]: + """Reset the virtual environments. + + Sampling observations from the buffer. + """ + if len(self.buffer) == 0: + return None + + batch, _ = self.buffer.sample(batch_size=self.env_num, ) + self.state = batch.obs.copy() + + return self.state.copy() + + def step( + self, + action: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List]: + """Take a step in every virtual environment. + + If an environment is terminated, it is automatically reset. + + :param np.ndarray action: Actions of shape (batch_size, action_dim) + + :return: Vectorized results, similar to that of OpenAI Gym environments. + """ + observation = self.state.copy() + batch = Batch( + obs=observation, + act=action, + ) + batch: Batch = self.model.predict(batch) + batch.to_numpy() + reward = batch.rew + next_observation = batch.obs_next + done: np.ndarray = self.terminal_fn(observation, action, next_observation) + + # Reset terminal environments. + if np.any(done): + done_indices = np.where(done)[0] + batch_sampled, _ = self.buffer.sample(batch_size=len(done_indices), ) + observation_reset = batch_sampled.obs.copy() + next_observation[done_indices] = observation_reset + + info = batch.info + self.state = next_observation.copy() + + return next_observation, reward, done, info + + def close(self) -> None: + pass diff --git a/tianshou/env/mujoco/static.py b/tianshou/env/mujoco/static.py new file mode 100644 index 000000000..740b40523 --- /dev/null +++ b/tianshou/env/mujoco/static.py @@ -0,0 +1,58 @@ +import numpy as np + + +def halfcheetah_terminal_fn( + obs: np.ndarray, + act: np.ndarray, + next_obs: np.ndarray, +) -> np.ndarray: + """Terminal condition function for HalfCheetah.""" + assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 + + done = np.array([False]).repeat(len(obs)) + return done + + +def hopper_terminal_fn( + obs: np.ndarray, + act: np.ndarray, + next_obs: np.ndarray, +) -> np.ndarray: + """Terminal condition function for Hopper.""" + assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 + + height = next_obs[:, 0] + angle = next_obs[:, 1] + not_done = \ + np.isfinite(next_obs).all(axis=-1) * \ + np.abs(next_obs[:, 1:] < 100).all(axis=-1) * \ + (height > .7) * \ + (np.abs(angle) < .2) + + done = ~not_done + return done + + +def walker2d_terminal_fn( + obs: np.ndarray, act: np.ndarray, next_obs: np.ndarray +) -> np.ndarray: + """Terminal condition function for Walker2d.""" + assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == 2 + + height = next_obs[:, 0] + angle = next_obs[:, 1] + not_done = \ + (height > 0.8) * \ + (height < 2.0) * \ + (angle > -1.0) * \ + (angle < 1.0) + + done = ~not_done + return done + + +TERMINAL_FUNCTIONS = { + 'HalfCheetah': halfcheetah_terminal_fn, + 'Hopper': hopper_terminal_fn, + 'Walker2d': walker2d_terminal_fn, +} diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 
ced11aff5..373ed61e2 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -18,6 +18,7 @@ from tianshou.policy.modelfree.td3 import TD3Policy from tianshou.policy.modelfree.sac import SACPolicy from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy +from tianshou.policy.modelfree.redq import REDQPolicy from tianshou.policy.imitation.base import ImitationPolicy from tianshou.policy.imitation.bcq import BCQPolicy from tianshou.policy.imitation.cql import CQLPolicy @@ -26,6 +27,7 @@ from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy from tianshou.policy.modelbased.psrl import PSRLPolicy from tianshou.policy.modelbased.icm import ICMPolicy +from tianshou.policy.modelbased.mbpo import MBPOPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager __all__ = [ @@ -46,6 +48,7 @@ "TD3Policy", "SACPolicy", "DiscreteSACPolicy", + "REDQPolicy", "ImitationPolicy", "BCQPolicy", "CQLPolicy", @@ -54,5 +57,6 @@ "DiscreteCRRPolicy", "PSRLPolicy", "ICMPolicy", + "MBPOPolicy", "MultiAgentPolicyManager", ] diff --git a/tianshou/policy/modelbased/mbpo.py b/tianshou/policy/modelbased/mbpo.py new file mode 100644 index 000000000..e823ae483 --- /dev/null +++ b/tianshou/policy/modelbased/mbpo.py @@ -0,0 +1,35 @@ +from typing import Any, Dict + +from tianshou.data import Batch +from tianshou.data.buffer.base import ReplayBuffer +from tianshou.policy import SACPolicy + + +class MBPOPolicy(SACPolicy): + """Implementation of Model-Based Policy Optimization. arXiv:1906.08253. + + MBPO builds on SAC with different training scheme. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.SACPolicy` + """ + + def update( + self, env_sample_size: int, env_buffer: ReplayBuffer, model_sample_size: int, + model_buffer: ReplayBuffer, **kwargs: Any + ) -> Dict[str, Any]: + """MBPO collects samples from both the environment and model rollouts.""" + env_batch, env_indice = env_buffer.sample(env_sample_size) + model_batch, model_indice = model_buffer.sample(model_sample_size) + self.updating = True + env_batch = self.process_fn(env_batch, env_buffer, env_indice) + model_batch = self.process_fn(model_batch, model_buffer, model_indice) + batch = Batch.cat([env_batch, model_batch]) + result = self.learn(batch, **kwargs) + env_batch = batch[:env_sample_size] + model_batch = batch[env_sample_size:] + self.post_process_fn(env_batch, env_buffer, env_indice) + self.post_process_fn(model_batch, model_buffer, model_indice) + self.updating = False + return result diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py new file mode 100644 index 000000000..28a0bf711 --- /dev/null +++ b/tianshou/policy/modelfree/redq.py @@ -0,0 +1,207 @@ +from copy import deepcopy +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch +from torch.distributions import Independent, Normal + +from tianshou.data import Batch, ReplayBuffer, to_torch_as +from tianshou.exploration import BaseNoise +from tianshou.policy import DDPGPolicy + + +class REDQPolicy(DDPGPolicy): + """Implementation of REDQ. arXiv:2101.05982. + + :param torch.nn.Module actor: the actor network following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param torch.optim.Optimizer actor_optim: the optimizer for actor network. + :param torch.nn.Module critics: critic ensemble networks. + :param torch.optim.Optimizer critics_optim: the optimizer for the critic networks. 
+ :param int ensemble_size: Number of sub-networks in the critic ensemble. + Default to 10. + :param int subset_size: Number of networks in the subset. Default to 2. + :param float tau: param for soft update of the target network. Default to 0.005. + :param float gamma: discount factor, in [0, 1]. Default to 0.99. + :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy + regularization coefficient. Default to 0.2. + If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then + alpha is automatically tuned. + :param bool reward_normalization: normalize the reward to Normal(0, 1). + Default to False. + :param int actor_delay: Number of critic updates before an actor update. + Default to 20. + :param BaseNoise exploration_noise: add a noise to action for exploration. + Default to None. This is useful when solving hard-exploration problem. + :param bool deterministic_eval: whether to use deterministic action (mean + of Gaussian policy) instead of stochastic action sampled by the policy. + Default to True. + :param str target_mode: methods to integrate critic values in the subset, + currently support minimum and average. Default to min. + :param bool action_scaling: whether to map actions from range [-1, 1] to range + [action_spaces.low, action_spaces.high]. Default to True. + :param str action_bound_method: method to bound action to range [-1, 1], can be + either "clip" (for simply clipping the action) or empty string for no bounding. + Default to "clip". + :param Optional[gym.Space] action_space: env's action space, mandatory if you want + to use option "action_scaling" or "action_bound_method". Default to None. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. + """ + + def __init__( + self, + actor: torch.nn.Module, + actor_optim: torch.optim.Optimizer, + critics: torch.nn.Module, + critics_optim: torch.optim.Optimizer, + ensemble_size: int = 10, + subset_size: int = 2, + tau: float = 0.005, + gamma: float = 0.99, + alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2, + reward_normalization: bool = False, + estimation_step: int = 1, + actor_delay: int = 20, + exploration_noise: Optional[BaseNoise] = None, + deterministic_eval: bool = True, + target_mode: str = 'min', + **kwargs: Any, + ) -> None: + super().__init__( + None, None, None, None, tau, gamma, exploration_noise, + reward_normalization, estimation_step, **kwargs + ) + self.actor, self.actor_optim = actor, actor_optim + self.critics, self.critics_old = critics, deepcopy(critics) + self.critics_old.eval() + self.critics_optim = critics_optim + assert 0 < subset_size <= ensemble_size, \ + 'Invalid choice of ensemble size or subset size.' 
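        # All N critics live in `self.critics` / `self.critics_old`; `_target_q` below draws
        # a random subset of `subset_size` target critics without replacement and reduces
        # them with `target_mode` (min or mean) before subtracting alpha * log_prob as in
        # SAC. This in-target randomization is what distinguishes REDQ from twin-critic SAC.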
+ self.ensemble_size = ensemble_size + self.subset_size = subset_size + + self._is_auto_alpha = False + self._alpha: Union[float, torch.Tensor] + if isinstance(alpha, tuple): + self._is_auto_alpha = True + self._target_entropy, self._log_alpha, self._alpha_optim = alpha + assert alpha[1].shape == torch.Size([1]) and alpha[1].requires_grad + self._alpha = self._log_alpha.detach().exp() + else: + self._alpha = alpha + + if target_mode in ('min', 'mean'): + self.target_mode = target_mode + else: + raise ValueError('Unsupported mode of Q target computing.') + + self.critic_gradient_step = 0 + self.actor_delay = actor_delay + self._deterministic_eval = deterministic_eval + self.__eps = np.finfo(np.float32).eps.item() + + def train(self, mode: bool = True) -> "REDQPolicy": + self.training = mode + self.actor.train(mode) + self.critics.train(mode) + return self + + def sync_weight(self) -> None: + for o, n in zip(self.critics_old.parameters(), self.critics.parameters()): + o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) + + def forward( # type: ignore + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + input: str = "obs", + **kwargs: Any, + ) -> Batch: + obs = batch[input] + logits, h = self.actor(obs, state=state, info=batch.info) + assert isinstance(logits, tuple) + dist = Independent(Normal(*logits), 1) + if self._deterministic_eval and not self.training: + act = logits[0] + else: + act = dist.rsample() + log_prob = dist.log_prob(act).unsqueeze(-1) + # apply correction for Tanh squashing when computing logprob from Gaussian + # You can check out the original SAC paper (arXiv 1801.01290): Eq 21. + # in appendix C to get some understanding of this equation. + if self.action_scaling and self.action_space is not None: + action_scale = to_torch_as( + (self.action_space.high - self.action_space.low) / 2.0, act + ) + else: + action_scale = 1.0 # type: ignore + squashed_action = torch.tanh(act) + log_prob = log_prob - torch.log( + action_scale * (1 - squashed_action.pow(2)) + self.__eps + ).sum(-1, keepdim=True) + return Batch( + logits=logits, act=squashed_action, state=h, dist=dist, log_prob=log_prob + ) + + def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: + batch = buffer[indices] # batch.obs: s_{t+n} + obs_next_result = self(batch, input='obs_next') + a_ = obs_next_result.act + sample_ensemble_idx = np.random.choice( + self.ensemble_size, self.subset_size, replace=False + ) + qs = self.critics_old(batch.obs_next, a_)[sample_ensemble_idx, ...] 
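            # Shape note (assuming the critic ensemble returns values of shape
            # (ensemble_size, batch_size, 1), which matches the reductions used in `learn`):
            # `qs` is (subset_size, batch_size, 1), and the min/mean below collapses the
            # sampled-critic dimension, leaving one target value per transition.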
+ if self.target_mode == 'min': + target_q, _ = torch.min(qs, dim=0) + elif self.target_mode == 'mean': + target_q = torch.mean(qs, dim=0) + target_q -= self._alpha * obs_next_result.log_prob + + return target_q + + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: + # critic ensemble + weight = getattr(batch, "weight", 1.0) + current_qs = self.critics(batch.obs, batch.act).flatten(1) + target_q = batch.returns.flatten() + td = current_qs - target_q + critic_loss = (td.pow(2) * weight).mean() + self.critics_optim.zero_grad() + critic_loss.backward() + self.critics_optim.step() + batch.weight = torch.mean(td, dim=0) # prio-buffer + self.critic_gradient_step += 1 + + # actor + if self.critic_gradient_step % self.actor_delay == 0: + obs_result = self(batch) + a = obs_result.act + current_qa = self.critics(batch.obs, a).mean(dim=0).flatten() + actor_loss = (self._alpha * obs_result.log_prob.flatten() - + current_qa).mean() + self.actor_optim.zero_grad() + actor_loss.backward() + self.actor_optim.step() + + if self._is_auto_alpha: + log_prob = obs_result.log_prob.detach() + self._target_entropy + alpha_loss = -(self._log_alpha * log_prob).mean() + self._alpha_optim.zero_grad() + alpha_loss.backward() + self._alpha_optim.step() + self._alpha = self._log_alpha.detach().exp() + + self.sync_weight() + + result = {"loss/critics": critic_loss.item()} + if self.critic_gradient_step % self.actor_delay == 0: + result["loss/actor"] = actor_loss.item(), + if self._is_auto_alpha: + result["loss/alpha"] = alpha_loss.item() + result["alpha"] = self._alpha.item() # type: ignore + + return result diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index 11b3a95ef..8fc093a60 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -3,11 +3,13 @@ # isort:skip_file from tianshou.trainer.utils import test_episode, gather_info +from tianshou.trainer.dyna import dyna_trainer from tianshou.trainer.onpolicy import onpolicy_trainer from tianshou.trainer.offpolicy import offpolicy_trainer from tianshou.trainer.offline import offline_trainer __all__ = [ + "dyna_trainer", "offpolicy_trainer", "onpolicy_trainer", "offline_trainer", diff --git a/tianshou/trainer/dyna.py b/tianshou/trainer/dyna.py new file mode 100644 index 000000000..642176869 --- /dev/null +++ b/tianshou/trainer/dyna.py @@ -0,0 +1,243 @@ +import time +from collections import defaultdict +from typing import Callable, Dict, Optional, Sequence, Union + +import numpy as np +import tqdm + +from tianshou.data import Collector, SimpleReplayBuffer +from tianshou.env.fake import GaussianModel +from tianshou.policy import BasePolicy, MBPOPolicy +from tianshou.trainer import gather_info, test_episode +from tianshou.utils import BaseLogger, LazyLogger, MovAvg, tqdm_config + + +def dyna_trainer( + policy: MBPOPolicy, + model: GaussianModel, + train_collector: Collector, + test_collector: Collector, + model_collector: Collector, + max_epoch: int, + step_per_epoch: int, + step_per_collect: int, + episode_per_test: int, + batch_size: int = 256, + rollout_batch_size: int = 100000, + rollout_schedule: Sequence[int] = (1, 1, 1, 1), + real_ratio: float = 0.1, + start_timesteps: int = 0, + model_train_freq: int = 250, + model_retain_epochs: int = 1, + update_per_step: Union[int, float] = 1, + train_fn: Optional[Callable[[int, int], None]] = None, + test_fn: Optional[Callable[[int, Optional[int]], None]] = None, + stop_fn: Optional[Callable[[float], bool]] = None, + save_fn: Optional[Callable[[BasePolicy], None]] = 
None, + save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None, + resume_from_log: bool = False, + reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, + logger: BaseLogger = LazyLogger(), + verbose: bool = True, + test_in_train: bool = True, +) -> Dict[str, Union[float, str]]: + """A wrapper for dyna style trainer procedure. + + The "step" in trainer means an environment step (a.k.a. transition). + + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param Collector train_collector: the collector used for training. + :param Collector test_collector: the collector used for testing. + :param Collector model_collector: the collector used for collecting model rollouts. + :param int max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. + :param int step_per_epoch: the number of transitions collected per epoch. + :param int step_per_collect: the number of transitions the collector would collect + before the network update, i.e., trainer will collect "step_per_collect" + transitions and do some policy network update repeatly in each epoch. + :param episode_per_test: the number of episodes for one policy evaluation. + :param int batch_size: the batch size of sample data, which is going to feed in the + policy network. + :param int rollout_batch_size: the batch size of rollouts in parallel. + :param Sequence rollout_schedule: scheduler for rollout length of each epoch. + :param float real_ratio: ratio of samples from real environment interactions in + each gradient update. + :param int model_retain_epochs: Number of epochs that retains samples in the + buffer. + :param int/float update_per_step: the number of times the policy network would be + updated per transition after (step_per_collect) transitions are collected, + e.g., if update_per_step set to 0.3, and step_per_collect is 256, policy will + be updated round(256 * 0.3 = 76.8) = 77 times after 256 transitions are + collected by the collector. Default to 1. + :param function train_fn: a hook called at the beginning of training in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean reward in + evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> + None``. + :param function save_checkpoint_fn: a function to save training process, with the + signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can + save whatever you want. + :param bool resume_from_log: resume env_step/gradient_step and other metadata from + existing tensorboard log. Default to False. + :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether reaching the goal. + :param function reward_metric: a function with signature ``f(rewards: np.ndarray + with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, + used in multi-agent RL. We need to return a single scalar for each episode's + result to monitor training in the multi-agent RL setting. 
This function + specifies what is the desired metric, e.g., the reward of agent 1 or the + average reward over all agents. + :param BaseLogger logger: A logger that logs statistics during + training/testing/updating. Default to a logger that doesn't log anything. + :param bool verbose: whether to print the information. Default to True. + :param bool test_in_train: whether to test in the training phase. Default to True. + + :return: See :func:`~tianshou.trainer.gather_info`. + """ + # Initial steps + train_collector.collect(n_step=start_timesteps, random=True) + + env_batch_size = int(batch_size * real_ratio) + model_batch_size = batch_size - env_batch_size + assert env_batch_size > 0 and model_batch_size > 0 + + start_epoch, env_step, gradient_step, last_train_step = 0, 0, 0, 0 + if resume_from_log: + start_epoch, env_step, gradient_step = logger.restore_data() + last_rew, last_len = 0.0, 0 + stat: Dict[str, MovAvg] = defaultdict(MovAvg) + start_time = time.time() + train_collector.reset_stat() + test_collector.reset_stat() + test_in_train = test_in_train and train_collector.policy == policy + test_result = test_episode( + policy, test_collector, test_fn, start_epoch, episode_per_test, logger, + env_step, reward_metric + ) + best_epoch = start_epoch + best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] + if save_fn: + save_fn(policy) + + epoch_low, epoch_high, min_length, max_length = rollout_schedule + rollouts_per_epoch = rollout_batch_size * step_per_epoch / model_train_freq + + for epoch in range(1 + start_epoch, 1 + max_epoch): + # train + policy.train() + + # Determine rollout length + if epoch <= epoch_low: + rollout_length = min_length + else: + dx = (epoch - epoch_low) / (epoch_high - epoch_low) + dx = min(dx, 1) + rollout_length = int(dx * (max_length - min_length) + min_length) + + with tqdm.tqdm( + total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config + ) as t: + while t.n < t.total: + if ( + env_step - last_train_step >= model_train_freq or env_step == 0 + ) and 0. 
<= model.ratio < 1.: + last_train_step = env_step + # Train model + batch, _ = train_collector.buffer.sample(batch_size=0) + train_info = model.train(batch) + train_info["model/rollout_length"] = rollout_length + logger.write( + step_type="", + step=env_step, + data=train_info, + ) + # Rollout + model_steps_per_epoch = int(rollout_length * rollouts_per_epoch) + new_size = model_retain_epochs * model_steps_per_epoch + if model_collector.buffer.maxsize < new_size: + temp_buffer = model_collector.buffer + model_collector.buffer = SimpleReplayBuffer(new_size) + model_collector.buffer.update(temp_buffer) + model_collector.reset_env() + model_collector.collect(n_step=rollout_batch_size * rollout_length) + + if train_fn: + train_fn(epoch, env_step) + result = train_collector.collect(n_step=step_per_collect) + if result["n/ep"] > 0 and reward_metric: + rew = reward_metric(result["rews"]) + result.update(rews=rew, rew=rew.mean(), rew_std=rew.std()) + env_step += int(result["n/st"]) + t.update(result["n/st"]) + logger.log_train_data(result, env_step) + last_rew = result["rew"] if result["n/ep"] > 0 else last_rew + last_len = result["len"] if result["n/ep"] > 0 else last_len + data = { + "env_step": str(env_step), + "rew": f"{last_rew:.2f}", + "len": str(int(last_len)), + "n/ep": str(int(result["n/ep"])), + "n/st": str(int(result["n/st"])), + } + if result["n/ep"] > 0: + if test_in_train and stop_fn and stop_fn(result["rew"]): + test_result = test_episode( + policy, test_collector, test_fn, epoch, episode_per_test, + logger, env_step + ) + if stop_fn(test_result["rew"]): + if save_fn: + save_fn(policy) + logger.save_data( + epoch, env_step, gradient_step, save_checkpoint_fn + ) + t.set_postfix(**data) + return gather_info( + start_time, train_collector, test_collector, + test_result["rew"], test_result["rew_std"] + ) + else: + policy.train() + + for _ in range(round(update_per_step * result["n/st"])): + gradient_step += 1 + losses = policy.update( + env_batch_size, + train_collector.buffer, + model_batch_size, + model_collector.buffer, + ) + for k in losses.keys(): + stat[k].add(losses[k]) + losses[k] = stat[k].get() + data[k] = f"{losses[k]:.3f}" + logger.log_update_data(losses, gradient_step) + t.set_postfix(**data) + if t.n <= t.total: + t.update() + # test + test_result = test_episode( + policy, test_collector, test_fn, epoch, episode_per_test, logger, env_step, + reward_metric + ) + rew, rew_std = test_result["rew"], test_result["rew_std"] + if best_epoch < 0 or best_reward < rew: + best_epoch, best_reward, best_reward_std = epoch, rew, rew_std + if save_fn: + save_fn(policy) + logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn) + if verbose: + print( + f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}" + ) + if stop_fn and stop_fn(best_reward): + break + return gather_info( + start_time, train_collector, test_collector, best_reward, best_reward_std + ) diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 9d03b6120..90eaa66ca 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -2,6 +2,7 @@ import numpy as np import torch +import torch.nn.functional as F from torch import nn ModuleType = Type[nn.Module] @@ -311,3 +312,302 @@ def forward(self, obs: Union[np.ndarray, torch.Tensor], *args: Any, if not isinstance(obs, torch.Tensor): obs = torch.as_tensor(obs, dtype=torch.float32) return self.net(obs=obs.cuda(), *args, **kwargs) + + +class 
EnsembleLinear(nn.Module):
+    """Linear layer of an ensemble network.
+
+    :param int ensemble_size: number of subnets in the ensemble.
+    :param int in_feature: dimension of the input vector.
+    :param int out_feature: dimension of the output vector.
+    :param bool bias: whether to include an additive bias. Default to True.
+    """
+
+    def __init__(
+        self,
+        ensemble_size: int,
+        in_feature: int,
+        out_feature: int,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+
+        # To be consistent with PyTorch default initializer
+        k = np.sqrt(1. / in_feature)
+        weight_data = torch.rand((ensemble_size, in_feature, out_feature)) * 2 * k - k
+        self.weight = nn.Parameter(weight_data, requires_grad=True)
+
+        self.bias: Union[nn.Parameter, None]
+        if bias:
+            bias_data = torch.rand((ensemble_size, 1, out_feature)) * 2 * k - k
+            self.bias = nn.Parameter(bias_data, requires_grad=True)
+        else:
+            self.bias = None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (ensemble_size, batch_size, in_feature), or broadcastable to it
+        x = torch.matmul(x, self.weight)
+        if self.bias is not None:
+            x = x + self.bias
+        return x
+
+
+class EnsembleMLP(nn.Module):
+    """Ensemble MLP backbone.
+
+    Create an ensemble of ``ensemble_size`` MLPs, each of size input_dim *
+    hidden_sizes[0] * hidden_sizes[1] * ... * hidden_sizes[-1] * output_dim.
+
+    :param int ensemble_size: number of subnets in the ensemble.
+    :param int input_dim: dimension of the input vector.
+    :param int output_dim: dimension of the output vector. If set to 0, there
+        is no final linear layer.
+    :param hidden_sizes: shape of MLP passed in as a list, not including
+        input_dim and output_dim.
+    :param norm_layer: use which normalization before activation, e.g.,
+        ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization.
+        You can also pass a list of normalization modules with the same length
+        of hidden_sizes, to use different normalization modules in different
+        layers.
+    :param activation: which activation to use after each layer. Pass a single
+        nn.Module class to use the same activation for all layers, or a list to
+        use a different activation in each layer. Default to nn.ReLU.
+    :param device: which device to create this model on. Default to None.
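+
+    A minimal usage sketch (the ensemble/batch/feature sizes below are only
+    illustrative)::
+
+        net = EnsembleMLP(ensemble_size=7, input_dim=14, output_dim=12,
+                          hidden_sizes=[200, 200])
+        x = torch.zeros(7, 64, 14)  # (ensemble_size, batch_size, input_dim)
+        y = net(x)                  # shape: (7, 64, 12)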
+    """
+
+    def __init__(
+        self,
+        ensemble_size: int,
+        input_dim: int,
+        output_dim: int,
+        hidden_sizes: Sequence[int] = (),
+        norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
+        activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
+        device: Optional[Union[str, int, torch.device]] = None,
+    ) -> None:
+        super().__init__()
+        self.ensemble_size = ensemble_size
+        self.device = device
+        if norm_layer:
+            if isinstance(norm_layer, list):
+                assert len(norm_layer) == len(hidden_sizes)
+                norm_layer_list = norm_layer
+            else:
+                norm_layer_list = [norm_layer for _ in range(len(hidden_sizes))]
+        else:
+            norm_layer_list = [None] * len(hidden_sizes)
+        if activation:
+            if isinstance(activation, list):
+                assert len(activation) == len(hidden_sizes)
+                activation_list = activation
+            else:
+                activation_list = [activation for _ in range(len(hidden_sizes))]
+        else:
+            activation_list = [None] * len(hidden_sizes)
+        hidden_sizes = [input_dim] + list(hidden_sizes)
+        model = []
+        for in_dim, out_dim, norm, activ in zip(
+            hidden_sizes[:-1], hidden_sizes[1:], norm_layer_list, activation_list
+        ):
+            model += [EnsembleLinear(ensemble_size, in_dim, out_dim)]
+            if norm is not None:
+                model += [norm(out_dim)]
+            if activ is not None:
+                model += [activ()]
+        if output_dim > 0:
+            model += [EnsembleLinear(ensemble_size, hidden_sizes[-1], output_dim)]
+        self.output_dim = output_dim or hidden_sizes[-1]
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+        x = torch.as_tensor(x, device=self.device, dtype=torch.float32)  # type: ignore
+        return self.model(x)
+
+
+class EnsembleNet(nn.Module):
+    """Wrapper of EnsembleMLP to support more specific DRL usage.
+
+    For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+
+    :param int ensemble_size: number of subnets in the ensemble.
+    :param state_shape: int or a sequence of int of the shape of state.
+    :param action_shape: int or a sequence of int of the shape of action.
+    :param hidden_sizes: shape of MLP passed in as a list.
+    :param norm_layer: use which normalization before activation, e.g.,
+        ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization.
+        You can also pass a list of normalization modules with the same length
+        of hidden_sizes, to use different normalization modules in different
+        layers.
+    :param activation: which activation to use after each layer. Pass a single
+        nn.Module class to use the same activation for all layers, or a list to
+        use a different activation in each layer. Default to nn.ReLU.
+    :param device: specify the device when the network actually runs. Default
+        to "cpu".
+    :param bool softmax: whether to apply a softmax layer over the last layer's
+        output.
+    :param bool concat: whether the input shape is concatenated by state_shape
+        and action_shape. If it is True, ``action_shape`` is not the output
+        shape, but affects the input shape only.
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.utils.net.common.EnsembleMLP` for more
+        detailed explanation on the usage of activation, norm_layer, etc.
+
+        You can also refer to
+        :class:`~tianshou.utils.net.continuous.EnsembleCritic`, etc., to see
+        how it's suggested to be used.
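+
+    A minimal usage sketch with ``concat=True``, as a critic backbone (sizes
+    are only illustrative)::
+
+        net = EnsembleNet(ensemble_size=7, state_shape=11, action_shape=3,
+                          hidden_sizes=[256, 256], concat=True)
+        obs_act = torch.zeros(64, 11 + 3)  # batch of concatenated (state, action)
+        logits, _ = net(obs_act)           # shape: (7, 64, 256)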
+    """
+
+    def __init__(
+        self,
+        ensemble_size: int,
+        state_shape: Union[int, Sequence[int]],
+        action_shape: Union[int, Sequence[int]] = 0,
+        hidden_sizes: Sequence[int] = (),
+        norm_layer: Optional[ModuleType] = None,
+        activation: Optional[ModuleType] = nn.ReLU,
+        device: Union[str, int, torch.device] = "cpu",
+        softmax: bool = False,
+        concat: bool = False,
+    ) -> None:
+        super().__init__()
+        self.device = device
+        self.softmax = softmax
+        input_dim = int(np.prod(state_shape))
+        action_dim = int(np.prod(action_shape))
+        if concat:
+            input_dim += action_dim
+        output_dim = action_dim if not concat else 0
+        self.model = EnsembleMLP(
+            ensemble_size,
+            input_dim,
+            output_dim,
+            hidden_sizes,
+            norm_layer,
+            activation,
+            device=device,
+        )
+        self.output_dim = self.model.output_dim
+
+    def forward(
+        self,
+        obs: Union[np.ndarray, torch.Tensor],
+        state: Any = None,
+    ) -> Tuple[torch.Tensor, Any]:
+        """Mapping: s -> flatten (inside MLP) -> logits."""
+        logits = self.model(obs)
+        if self.softmax:
+            logits = torch.softmax(logits, dim=-1)
+        return logits, state
+
+
+class EnsembleMLPGaussian(nn.Module):
+    """Ensemble of Gaussian distribution networks.
+
+    Each subnet maps a concatenated (state, action) input to the mean and
+    log-variance of a diagonal Gaussian over the model target.
+
+    :param int ensemble_size: number of subnets in the ensemble.
+    :param Sequence[int] state_shape: a sequence of int for the shape of state.
+    :param Sequence[int] action_shape: a sequence of int for the shape of action.
+    :param hidden_sizes: a sequence of int for constructing the ensemble MLP,
+        not including input_dim and output_dim. Default to empty sequence.
+    :param norm_layer: use which normalization before activation. You can also
+        pass a list of normalization modules with the same length of hidden_sizes,
+        to use different normalization modules in different layers.
+        Default to no normalization.
+    :param activation: which activation to use after each layer. Pass a single
+        nn.Module class to use the same activation for all layers, or a list to
+        use a different activation in each layer. Default to nn.ReLU.
+    :param float init_max: initial value of the learned upper bound on the
+        predicted log-variance.
+    :param float init_min: initial value of the learned lower bound on the
+        predicted log-variance.
+    :param Union[str, torch.device] device: which device to create this model on.
+        Default to None.
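+
+    A minimal usage sketch (sizes are only illustrative; the extra output
+    dimension is conventionally used for the predicted reward)::
+
+        model = EnsembleMLPGaussian(ensemble_size=7, state_shape=(11,),
+                                    action_shape=(3,), hidden_sizes=[200, 200])
+        x = torch.zeros(7, 64, 11 + 3)  # concatenated (state, action)
+        mean, logvar, max_logvar, min_logvar = model(x)
+        # mean and logvar both have shape (7, 64, 11 + 1)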
+    """
+
+    def __init__(
+        self,
+        ensemble_size: int,
+        state_shape: Sequence[int],
+        action_shape: Sequence[int],
+        hidden_sizes: Sequence[int] = (),
+        norm_layer: Optional[ModuleType] = None,
+        activation: Optional[ModuleType] = nn.ReLU,
+        init_max: float = 0.5,
+        init_min: float = -10.,
+        device: Optional[Union[str, int, torch.device]] = None,
+    ) -> None:
+        super().__init__()
+        self.device = device
+        input_dim = int(np.prod(state_shape)) + int(np.prod(action_shape))
+        # one extra output dimension on top of the state prediction (e.g. reward)
+        output_dim = int(np.prod(state_shape)) + 1
+        self.output_dim = output_dim
+        self.model = EnsembleMLP(
+            ensemble_size=ensemble_size,
+            input_dim=input_dim,
+            output_dim=output_dim * 2,  # two heads: mean and log-variance
+            hidden_sizes=hidden_sizes,
+            norm_layer=norm_layer,
+            activation=activation,
+            device=device,
+        )
+
+        max_logvar = torch.ones(
+            (1, 1, output_dim),
+            device=device,
+            dtype=torch.float32,
+        ) * init_max
+        min_logvar = torch.ones(
+            (1, 1, output_dim),
+            device=device,
+            dtype=torch.float32,
+        ) * init_min
+        self.max_logvar = nn.Parameter(max_logvar, requires_grad=True)
+        self.min_logvar = nn.Parameter(min_logvar, requires_grad=True)
+
+    def forward(
+        self,
+        x: Union[np.ndarray, torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor, nn.Parameter, nn.Parameter]:
+        x = torch.as_tensor(x, device=self.device, dtype=torch.float32)
+        x = self.model(x)
+        mean, logvar = torch.split(x, self.output_dim, dim=-1)
+        # softly clamp the predicted log-variance into [min_logvar, max_logvar]
+        logvar = self.max_logvar - F.softplus(self.max_logvar - logvar)
+        logvar = self.min_logvar + F.softplus(logvar - self.min_logvar)
+        return mean, logvar, self.max_logvar, self.min_logvar
+
+
+class GaussianMLELoss(object):
+    """Gaussian negative log-likelihood loss for maximum likelihood estimation.
+
+    :param float coeff: coefficient of the regularization term on the learned
+        log-variance bounds (``max_logvar``/``min_logvar``). Default to 0.01.
+    """
+
+    def __init__(self, coeff: float = 0.01) -> None:
+        self.opt_coeff = coeff
+
+    def __call__(
+        self,
+        mean: torch.Tensor,
+        logvar: torch.Tensor,
+        max_logvar: torch.Tensor,
+        min_logvar: torch.Tensor,
+        y: torch.Tensor,
+    ) -> torch.Tensor:
+        """Execute loss calculation.
+
+        :param torch.Tensor mean: tensor of mean from the network output.
+        :param torch.Tensor logvar: tensor of logarithm of variance
+            from the network output.
+        :param torch.Tensor max_logvar: tensor of maximum logarithm of variance.
+        :param torch.Tensor min_logvar: tensor of minimum logarithm of variance.
+        :param torch.Tensor y: tensor of target.
+        """
+        inv_var = torch.exp(-logvar)
+        mse = torch.mean(torch.square(mean - y) * inv_var)
+        var_loss = torch.mean(logvar)
+        opt_loss = self.opt_coeff * (torch.sum(max_logvar) - torch.sum(min_logvar))
+        loss = mse + var_loss + opt_loss
+
+        return loss
diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py
index d68f3856f..210d669b9 100644
--- a/tianshou/utils/net/continuous.py
+++ b/tianshou/utils/net/continuous.py
@@ -4,7 +4,7 @@
 import torch
 from torch import nn
 
-from tianshou.utils.net.common import MLP
+from tianshou.utils.net.common import MLP, EnsembleMLP
 
 SIGMA_MIN = -20
 SIGMA_MAX = 2
@@ -471,3 +471,70 @@ def decode(
         # decode z with state!
         return self.max_action * \
             torch.tanh(self.decoder(torch.cat([state, latent_z], -1)))
+
+
+class EnsembleCritic(nn.Module):
+    """Ensemble critic network. Will create a critic operating in continuous \
+    action space with structure of preprocess_net ---> 1(q value).
+
+    :param int ensemble_size: number of subnets in the ensemble.
+    :param preprocess_net: a self-defined preprocess_net which outputs
+        a hidden state.
+ :param hidden_sizes: a sequence of int for constructing the MLP after + preprocess_net. Default to empty sequence (where the MLP now contains + only a single linear layer). + :param int preprocess_net_output_dim: the output dimension of + preprocess_net. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + + .. seealso:: + + Please refer to :class:`~tianshou.utils.net.common.Net` as an instance + of how preprocess_net is suggested to be defined. + """ + + def __init__( + self, + ensemble_size: int, + preprocess_net: nn.Module, + hidden_sizes: Sequence[int] = (), + device: Union[str, int, torch.device] = "cpu", + preprocess_net_output_dim: Optional[int] = None, + ) -> None: + super().__init__() + self.device = device + self.preprocess = preprocess_net + self.output_dim = 1 + input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim) + self.last = EnsembleMLP( + ensemble_size, + input_dim, # type: ignore + 1, + hidden_sizes, + device=self.device, + ) + + def forward( + self, + state: Union[np.ndarray, torch.Tensor], + action: Optional[Union[np.ndarray, torch.Tensor]] = None, + info: Dict[str, Any] = {}, + ) -> torch.Tensor: + """Mapping: (s, a) -> logits -> Q(s, a).""" + state = torch.as_tensor( + state, + device=self.device, # type: ignore + dtype=torch.float32, + ).flatten(1) + if action is not None: + action = torch.as_tensor( + action, + device=self.device, # type: ignore + dtype=torch.float32, + ).flatten(1) + state = torch.cat([state, action], dim=1) + logits, h = self.preprocess(state) + logits = self.last(logits) + return logits
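# A minimal usage sketch (illustrative only, not part of the diff above) of how
# EnsembleMLPGaussian and GaussianMLELoss defined in this patch are expected to
# fit together for a single dynamics-model update. The target layout
# (next-state prediction plus one reward dimension) and the hyper-parameters
# are assumptions borrowed from the usual MBPO setup; the actual wiring lives
# in tianshou.env.fake.GaussianModel.
import torch

from tianshou.utils.net.common import EnsembleMLPGaussian, GaussianMLELoss

ensemble_size, batch_size = 7, 256
state_dim, action_dim = 11, 3

model = EnsembleMLPGaussian(
    ensemble_size=ensemble_size,
    state_shape=(state_dim,),
    action_shape=(action_dim,),
    hidden_sizes=[200, 200, 200, 200],
)
loss_fn = GaussianMLELoss(coeff=0.01)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# Dummy batch broadcast to every subnet: inputs are concatenated (s, a),
# targets are the next-state prediction plus reward -- illustrative data only.
inputs = torch.randn(ensemble_size, batch_size, state_dim + action_dim)
targets = torch.randn(ensemble_size, batch_size, state_dim + 1)

mean, logvar, max_logvar, min_logvar = model(inputs)
loss = loss_fn(mean, logvar, max_logvar, min_logvar, targets)

optim.zero_grad()
loss.backward()
optim.step()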