diff --git a/README.md b/README.md index d23a21f49..7319d4346 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,7 @@ - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf) - [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf) - [Intrinsic Curiosity Module (ICM)](https://arxiv.org/pdf/1705.05363.pdf) +- [Hindsight Experience Replay (HER)](https://arxiv.org/pdf/1707.01495.pdf) Here are Tianshou's other features: diff --git a/docs/api/tianshou.data.rst b/docs/api/tianshou.data.rst index 77c69aa15..63dfa91d6 100644 --- a/docs/api/tianshou.data.rst +++ b/docs/api/tianshou.data.rst @@ -30,6 +30,14 @@ PrioritizedReplayBuffer :undoc-members: :show-inheritance: +HERReplayBuffer +~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: tianshou.data.HERReplayBuffer + :members: + :undoc-members: + :show-inheritance: + ReplayBufferManager ~~~~~~~~~~~~~~~~~~~ @@ -46,6 +54,15 @@ PrioritizedReplayBufferManager :undoc-members: :show-inheritance: + +HERReplayBufferManager +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: tianshou.data.HERReplayBufferManager + :members: + :undoc-members: + :show-inheritance: + VectorReplayBuffer ~~~~~~~~~~~~~~~~~~ @@ -62,6 +79,14 @@ PrioritizedVectorReplayBuffer :undoc-members: :show-inheritance: +HERVectorReplayBuffer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: tianshou.data.HERVectorReplayBuffer + :members: + :undoc-members: + :show-inheritance: + CachedReplayBuffer ~~~~~~~~~~~~~~~~~~ diff --git a/docs/index.rst b/docs/index.rst index 7fce12f6b..09098cb94 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -40,6 +40,7 @@ Welcome to Tianshou! * :class:`~tianshou.policy.ICMPolicy` `Intrinsic Curiosity Module `_ * :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay `_ * :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator `_ +* :class:`~tianshou.data.HERReplayBuffer` `Hindsight Experience Replay `_ Here is Tianshou's other features: diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index eb3ac4a23..c63486213 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -52,6 +52,7 @@ mujoco jit nstep preprocess +preprocessing repo ReLU namespace diff --git a/examples/mujoco/README.md b/examples/mujoco/README.md index 12b480a9d..e38107e6b 100644 --- a/examples/mujoco/README.md +++ b/examples/mujoco/README.md @@ -20,6 +20,7 @@ Supported algorithms are listed below: - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/), [commit id](https://github.com/thu-ml/tianshou/tree/1730a9008ad6bb67cac3b21347bed33b532b17bc) - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/6426a39796db052bafb7cabe85c764db20a722b0) - [Trust Region Policy Optimization (TRPO)](https://arxiv.org/pdf/1502.05477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/5057b5c89e6168220272c9c28a15b758a72efc32) +- [Hindsight Experience Replay (HER)](https://arxiv.org/abs/1707.01495) ## EnvPool @@ -304,6 +305,18 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai 1. All shared hyperparameters are exactly the same as TRPO, regarding how similar these two algorithms are. 2. 
We found different games in Mujoco may require quite different `actor-step-size`: Reacher/Swimmer are insensitive to step-size in range (0.1~1.0), while InvertedDoublePendulum / InvertedPendulum / Humanoid are quite sensitive to step size, and even 0.1 is too large. Other games may require `actor-step-size` in range (0.1~0.4), but aren't that sensitive in general. +## Others + +### HER +| Environment | DDPG without HER | DDPG with HER | +| :--------------------: | :--------------: | :--------------: | +| FetchReach | -49.9±0.2 | **-17.6±21.7** | + +#### Hints for HER +1. The HER technique is designed for task-based (goal-conditioned) environments, so its results cannot be compared with the non-task-based MuJoCo benchmarks above. The environment used in this evaluation is ``FetchReach-v3``, which requires an extra [installation](https://github.com/Farama-Foundation/Gymnasium-Robotics); a condensed setup sketch follows this file's diff. +2. Simple hyperparameter optimization was done for both settings, DDPG with and without HER. However, since *DDPG without HER* failed in every experiment, the best hyperparameters found for *DDPG with HER* are used in the evaluation of both settings. +3. The scores are the mean reward ± 1 standard deviation over 16 seeds. The minimum reward for ``FetchReach-v3`` is -50, which implies that *DDPG without HER* performs no better than a random policy. *DDPG with HER* achieves a better mean reward, but its standard deviation is quite high: in this setting the agent either fails completely (-50 reward) or successfully learns the task (close to 0 reward), and it learned successfully in about 70% of the 16 seeds. + ## Note [1] Supported environments include HalfCheetah-v3, Hopper-v3, Swimmer-v3, Walker2d-v3, Ant-v3, Humanoid-v3, Reacher-v2, InvertedPendulum-v2 and InvertedDoublePendulum-v2. Pusher, Thrower, Striker and HumanoidStandup are not supported because they are not commonly seen in literatures.
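A condensed sketch of the HER setup evaluated above. Every name and default below (``HERReplayBuffer``, ``TruncatedAsTerminated``, the ``compute_reward_fn`` pattern, horizon 50, future_k 8, buffer size 100000) is taken from the ``fetch_her_ddpg.py`` script added in the next file of this diff; the single, non-vectorized environment is a simplification of the script, which switches to ``ShmemVectorEnv`` plus ``HERVectorReplayBuffer`` when ``--training-num`` is greater than 1.

```python
# Sketch: the only change HER makes to the plain DDPG benchmark is the replay
# buffer, which relabels desired goals at sampling time via compute_reward_fn.
import gym
import numpy as np

from tianshou.data import HERReplayBuffer
from tianshou.env import TruncatedAsTerminated

# FetchReach-v3 comes from Gymnasium-Robotics (the extra installation in hint 1).
env = TruncatedAsTerminated(gym.make("FetchReach-v3"))


def compute_reward_fn(achieved_goal: np.ndarray, desired_goal: np.ndarray) -> np.ndarray:
    # Delegate to the goal-based environment's own reward function.
    return env.compute_reward(achieved_goal, desired_goal, {})


buffer = HERReplayBuffer(
    size=100000,  # --buffer-size default in the script
    compute_reward_fn=compute_reward_fn,
    horizon=50,   # --her-horizon default
    future_k=8,   # --her-future-k default
)
```

The two table columns correspond to running the script with ``--replay-buffer her`` (the default) or ``--replay-buffer normal``, repeated over the 16 seeds mentioned in hint 3.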
diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py new file mode 100644 index 000000000..893912566 --- /dev/null +++ b/examples/mujoco/fetch_her_ddpg.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 + +import argparse +import datetime +import os +import pprint + +import gym +import numpy as np +import torch +import wandb +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import ( + Collector, + HERReplayBuffer, + HERVectorReplayBuffer, + ReplayBuffer, + VectorReplayBuffer, +) +from tianshou.env import ShmemVectorEnv, TruncatedAsTerminated +from tianshou.exploration import GaussianNoise +from tianshou.policy import DDPGPolicy +from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger, WandbLogger +from tianshou.utils.net.common import Net, get_dict_state_decorator +from tianshou.utils.net.continuous import Actor, Critic + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="FetchReach-v3") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--buffer-size", type=int, default=100000) + parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) + parser.add_argument("--actor-lr", type=float, default=1e-3) + parser.add_argument("--critic-lr", type=float, default=3e-3) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--tau", type=float, default=0.005) + parser.add_argument("--exploration-noise", type=float, default=0.1) + parser.add_argument("--start-timesteps", type=int, default=25000) + parser.add_argument("--epoch", type=int, default=10) + parser.add_argument("--step-per-epoch", type=int, default=5000) + parser.add_argument("--step-per-collect", type=int, default=1) + parser.add_argument("--update-per-step", type=int, default=1) + parser.add_argument("--n-step", type=int, default=1) + parser.add_argument("--batch-size", type=int, default=512) + parser.add_argument( + "--replay-buffer", type=str, default="her", choices=["normal", "her"] + ) + parser.add_argument("--her-horizon", type=int, default=50) + parser.add_argument("--her-future-k", type=int, default=8) + parser.add_argument("--training-num", type=int, default=1) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.) 
+ parser.add_argument( + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="HER-benchmark") + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="watch the play of pre-trained policy only", + ) + return parser.parse_args() + + +def make_fetch_env(task, training_num, test_num): + env = TruncatedAsTerminated(gym.make(task)) + train_envs = ShmemVectorEnv( + [lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(training_num)] + ) + test_envs = ShmemVectorEnv( + [lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(test_num)] + ) + return env, train_envs, test_envs + + +def test_ddpg(args=get_args()): + # log + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "ddpg" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + if args.logger == "wandb": + logger = WandbLogger( + save_interval=1, + name=log_name.replace(os.path.sep, "__"), + run_id=args.resume_id, + config=args, + project=args.wandb_project, + ) + logger.wandb_run.config.setdefaults(vars(args)) + args = argparse.Namespace(**wandb.config) + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: # wandb + logger.load(writer) + + env, train_envs, test_envs = make_fetch_env( + args.task, args.training_num, args.test_num + ) + args.state_shape = { + 'observation': env.observation_space['observation'].shape, + 'achieved_goal': env.observation_space['achieved_goal'].shape, + 'desired_goal': env.observation_space['desired_goal'].shape, + } + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] + args.exploration_noise = args.exploration_noise * args.max_action + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + # model + dict_state_dec, flat_state_shape = get_dict_state_decorator( + state_shape=args.state_shape, + keys=['observation', 'achieved_goal', 'desired_goal'] + ) + net_a = dict_state_dec(Net)( + flat_state_shape, hidden_sizes=args.hidden_sizes, device=args.device + ) + actor = dict_state_dec(Actor)( + net_a, args.action_shape, max_action=args.max_action, device=args.device + ).to(args.device) + actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + net_c = dict_state_dec(Net)( + flat_state_shape, + action_shape=args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + ) + critic = dict_state_dec(Critic)(net_c, device=args.device).to(args.device) + critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) + policy = DDPGPolicy( + actor, + actor_optim, + critic, + critic_optim, + tau=args.tau, + gamma=args.gamma, + exploration_noise=GaussianNoise(sigma=args.exploration_noise), + estimation_step=args.n_step, + action_space=env.action_space, + ) + + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, 
map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + + # collector + def compute_reward_fn(ag: np.ndarray, g: np.ndarray): + return env.compute_reward(ag, g, {}) + + if args.replay_buffer == "normal": + if args.training_num > 1: + buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) + else: + buffer = ReplayBuffer(args.buffer_size) + else: + if args.training_num > 1: + buffer = HERVectorReplayBuffer( + args.buffer_size, + len(train_envs), + compute_reward_fn=compute_reward_fn, + horizon=args.her_horizon, + future_k=args.her_future_k, + ) + else: + buffer = HERReplayBuffer( + args.buffer_size, + compute_reward_fn=compute_reward_fn, + horizon=args.her_horizon, + future_k=args.her_future_k, + ) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs) + train_collector.collect(n_step=args.start_timesteps, random=True) + + def save_best_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + + if not args.watch: + # trainer + result = offpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + ) + pprint.pprint(result) + + # Let's watch its performance! + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, render=args.render) + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + + +if __name__ == "__main__": + test_ddpg() diff --git a/test/base/env.py b/test/base/env.py index d7a96035d..8c6333d0b 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -166,3 +166,48 @@ def step(self, action): for i in range(self.size): self.graph.nodes[i]["data"] = next_graph_state[i] return self._encode_obs(), 1.0, 0, 0, {} + + +class MyGoalEnv(MyTestEnv): + + def __init__(self, *args, **kwargs): + assert kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0, \ + "dict_state / recurse_state not supported" + super().__init__(*args, **kwargs) + obs, _ = super().reset(state=0) + obs, _, _, _, _ = super().step(1) + self._goal = obs * self.size + super_obsv = self.observation_space + self.observation_space = gym.spaces.Dict( + { + 'observation': super_obsv, + 'achieved_goal': super_obsv, + 'desired_goal': super_obsv, + } + ) + + def reset(self, *args, **kwargs): + obs, info = super().reset(*args, **kwargs) + new_obs = { + 'observation': obs, + 'achieved_goal': obs, + 'desired_goal': self._goal + } + return new_obs, info + + def step(self, *args, **kwargs): + obs_next, rew, terminated, truncated, info = super().step(*args, **kwargs) + new_obs_next = { + 'observation': obs_next, + 'achieved_goal': obs_next, + 'desired_goal': self._goal + } + return new_obs_next, rew, terminated, truncated, info + + def compute_reward_fn( + self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info: dict + ) -> np.ndarray: + axis = -1 + if self.array_state: + axis = (-3, -2, -1) + return (achieved_goal == desired_goal).all(axis=axis) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index f6c96e95b..02d140d1e 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -11,6 +11,8 @@ from tianshou.data import ( Batch, CachedReplayBuffer, + HERReplayBuffer, + HERVectorReplayBuffer, PrioritizedReplayBuffer, PrioritizedVectorReplayBuffer, 
ReplayBuffer, @@ -20,9 +22,9 @@ from tianshou.data.utils.converter import to_hdf5 if __name__ == '__main__': - from env import MyTestEnv + from env import MyGoalEnv, MyTestEnv else: # pytest - from test.base.env import MyTestEnv + from test.base.env import MyGoalEnv, MyTestEnv def test_replaybuffer(size=10, bufsize=20): @@ -300,6 +302,142 @@ def test_priortized_replaybuffer(size=32, bufsize=15): assert weight[~mask][0] < weight[mask][0] and weight[mask][0] <= 1 +def test_herreplaybuffer(size=10, bufsize=100, sample_sz=4): + env_size = size + env = MyGoalEnv(env_size, array_state=True) + + def compute_reward_fn(ag, g): + return env.compute_reward_fn(ag, g, {}) + + buf = HERReplayBuffer( + bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8 + ) + buf2 = HERVectorReplayBuffer( + bufsize, + buffer_num=3, + compute_reward_fn=compute_reward_fn, + horizon=30, + future_k=8 + ) + # Apply her on every episodes sampled (Hacky but necessary for deterministic test) + buf.future_p = 1 + for buf2_buf in buf2.buffers: + buf2_buf.future_p = 1 + + obs, _ = env.reset() + action_list = [1] * 5 + [0] * 10 + [1] * 10 + for i, act in enumerate(action_list): + obs_next, rew, terminated, truncated, info = env.step(act) + batch = Batch( + obs=obs, + act=[act], + rew=rew, + terminated=terminated, + truncated=truncated, + obs_next=obs_next, + info=info + ) + buf.add(batch) + buf2.add(Batch.stack([batch, batch, batch]), buffer_ids=[0, 1, 2]) + obs = obs_next + assert len(buf) == min(bufsize, i + 1) + assert len(buf2) == min(bufsize, 3 * (i + 1)) + + batch, indices = buf.sample(sample_sz) + + # Check that goals are the same for the episode (only 1 ep in buffer) + tmp_indices = indices.copy() + for _ in range(2 * env_size): + obs = buf[tmp_indices].obs + obs_next = buf[tmp_indices].obs_next + rew = buf[tmp_indices].rew + g = obs.desired_goal.reshape(sample_sz, -1)[:, 0] + ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0] + g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0] + assert np.all(g == g[0]) + assert np.all(g_next == g_next[0]) + assert np.all(rew == (ag_next == g).astype(np.float32)) + tmp_indices = buf.next(tmp_indices) + + # Check that goals are correctly restored + buf._restore_cache() + tmp_indices = indices.copy() + for _ in range(2 * env_size): + obs = buf[tmp_indices].obs + obs_next = buf[tmp_indices].obs_next + g = obs.desired_goal.reshape(sample_sz, -1)[:, 0] + g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0] + assert np.all(g == env_size) + assert np.all(g_next == g_next[0]) + assert np.all(g == g[0]) + tmp_indices = buf.next(tmp_indices) + + # Test vector buffer + batch, indices = buf2.sample(sample_sz) + + # Check that goals are the same for the episode (only 1 ep in buffer) + tmp_indices = indices.copy() + for _ in range(2 * env_size): + obs = buf2[tmp_indices].obs + obs_next = buf2[tmp_indices].obs_next + rew = buf2[tmp_indices].rew + g = obs.desired_goal.reshape(sample_sz, -1)[:, 0] + ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0] + g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0] + assert np.all(g == g_next) + assert np.all(rew == (ag_next == g).astype(np.float32)) + tmp_indices = buf2.next(tmp_indices) + + # Check that goals are correctly restored + buf2._restore_cache() + tmp_indices = indices.copy() + for _ in range(2 * env_size): + obs = buf2[tmp_indices].obs + obs_next = buf2[tmp_indices].obs_next + g = obs.desired_goal.reshape(sample_sz, -1)[:, 0] + g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0] + 
assert np.all(g == env_size) + assert np.all(g_next == g_next[0]) + assert np.all(g == g[0]) + tmp_indices = buf2.next(tmp_indices) + + # Test handling cycled indices + env_size = size + bufsize = 15 + env = MyGoalEnv(env_size, array_state=False) + + def compute_reward_fn(ag, g): + return env.compute_reward_fn(ag, g, {}) + + buf = HERReplayBuffer( + bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8 + ) + buf._index = 5 # shifted start index + buf.future_p = 1 + action_list = [1] * 10 + for ep_len in [5, 10]: + obs, _ = env.reset() + for i in range(ep_len): + act = 1 + obs_next, rew, terminated, truncated, info = env.step(act) + batch = Batch( + obs=obs, + act=[act], + rew=rew, + terminated=(i == ep_len - 1), + truncated=(i == ep_len - 1), + obs_next=obs_next, + info=info + ) + buf.add(batch) + obs = obs_next + batch, indices = buf.sample(0) + assert np.all(buf[:5].obs.desired_goal == buf[0].obs.desired_goal) + assert np.all(buf[5:10].obs.desired_goal == buf[5].obs.desired_goal) + assert np.all(buf[10:].obs.desired_goal == buf[0].obs.desired_goal) # (same ep) + assert np.all(buf[0].obs.desired_goal != buf[5].obs.desired_goal) # (diff ep) + + def test_update(): buf1 = ReplayBuffer(4, stack_num=2) buf2 = ReplayBuffer(4, stack_num=2) @@ -1180,3 +1318,4 @@ def test_from_data(): test_multibuf_stack() test_multibuf_hdf5() test_from_data() + test_herreplaybuffer() diff --git a/test/base/test_env.py b/test/base/test_env.py index 6a222709d..1c91c0fa1 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -16,6 +16,7 @@ SubprocVectorEnv, VectorEnvNormObs, ) +from tianshou.env.gym_wrappers import TruncatedAsTerminated from tianshou.utils import RunningMeanStd if __name__ == "__main__": @@ -347,6 +348,10 @@ def __init__(self): self.action_space = gym.spaces.Box( low=-1.0, high=2.0, shape=(4, ), dtype=np.float32 ) + self.observation_space = gym.spaces.Discrete(2) + + def step(self, act): + return self.observation_space.sample(), -1, False, True, {} bsz = 10 action_per_branch = [4, 6, 10, 7] @@ -374,6 +379,14 @@ def __init__(self): env_d.action(np.array([env_d.action_space.n - 1] * bsz)), np.array([env_m.action_space.nvec - 1] * bsz), ) + # check truncate is True when terminated + try: + env_t = TruncatedAsTerminated(env) + except EnvironmentError: + env_t = None + if env_t is not None: + _, _, truncated, _, _ = env_t.step(env_t.action_space.sample()) + assert truncated @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index 89250d009..7a86ce857 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -6,13 +6,16 @@ from tianshou.data.utils.segtree import SegmentTree from tianshou.data.buffer.base import ReplayBuffer from tianshou.data.buffer.prio import PrioritizedReplayBuffer +from tianshou.data.buffer.her import HERReplayBuffer from tianshou.data.buffer.manager import ( ReplayBufferManager, PrioritizedReplayBufferManager, + HERReplayBufferManager, ) from tianshou.data.buffer.vecbuf import ( - VectorReplayBuffer, + HERVectorReplayBuffer, PrioritizedVectorReplayBuffer, + VectorReplayBuffer, ) from tianshou.data.buffer.cached import CachedReplayBuffer from tianshou.data.collector import Collector, AsyncCollector @@ -25,10 +28,13 @@ "SegmentTree", "ReplayBuffer", "PrioritizedReplayBuffer", + "HERReplayBuffer", "ReplayBufferManager", "PrioritizedReplayBufferManager", + "HERReplayBufferManager", "VectorReplayBuffer", "PrioritizedVectorReplayBuffer", + 
"HERVectorReplayBuffer", "CachedReplayBuffer", "Collector", "AsyncCollector", diff --git a/tianshou/data/buffer/her.py b/tianshou/data/buffer/her.py new file mode 100644 index 000000000..8c5c37166 --- /dev/null +++ b/tianshou/data/buffer/her.py @@ -0,0 +1,186 @@ +from typing import Any, Callable, List, Optional, Tuple, Union + +import numpy as np + +from tianshou.data import Batch, ReplayBuffer + + +class HERReplayBuffer(ReplayBuffer): + """Implementation of Hindsight Experience Replay. arXiv:1707.01495. + + HERReplayBuffer is to be used with goal-based environment where the + observation is a dictionary with keys ``observation``, ``achieved_goal`` and + ``desired_goal``. Currently support only HER's future strategy, online sampling. + + :param int size: the size of the replay buffer. + :param compute_reward_fn: a function that takes 2 ``np.array`` arguments, + ``acheived_goal`` and ``desired_goal``, and returns rewards as ``np.array``. + The two arguments are of shape (batch_size, ...original_shape) and the returned + rewards must be of shape (batch_size,). + :param int horizon: the maximum number of steps in an episode. + :param int future_k: the 'k' parameter introduced in the paper. In short, there + will be at most k episodes that are re-written for every 1 unaltered episode + during the sampling. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. + """ + + def __init__( + self, + size: int, + compute_reward_fn: Callable[[np.ndarray, np.ndarray], np.ndarray], + horizon: int, + future_k: float = 8.0, + **kwargs: Any, + ) -> None: + super().__init__(size, **kwargs) + self.horizon = horizon + self.future_p = 1 - 1 / future_k + self.compute_reward_fn = compute_reward_fn + self._original_meta = Batch() + self._altered_indices = np.array([]) + + def _restore_cache(self) -> None: + """Write cached original meta back to `self._meta`. + + It's called everytime before 'writing', 'sampling' or 'saving' the buffer. + """ + if not hasattr(self, '_altered_indices'): + return + + if self._altered_indices.size == 0: + return + self._meta[self._altered_indices] = self._original_meta + # Clean + self._original_meta = Batch() + self._altered_indices = np.array([]) + + def reset(self, keep_statistics: bool = False) -> None: + self._restore_cache() + return super().reset(keep_statistics) + + def save_hdf5(self, path: str, compression: Optional[str] = None) -> None: + self._restore_cache() + return super().save_hdf5(path, compression) + + def set_batch(self, batch: Batch) -> None: + self._restore_cache() + return super().set_batch(batch) + + def update(self, buffer: Union["HERReplayBuffer", "ReplayBuffer"]) -> np.ndarray: + self._restore_cache() + return super().update(buffer) + + def add( + self, + batch: Batch, + buffer_ids: Optional[Union[np.ndarray, List[int]]] = None + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + self._restore_cache() + return super().add(batch, buffer_ids) + + def sample_indices(self, batch_size: int) -> np.ndarray: + """Get a random sample of index with size = batch_size. + + Return all available indices in the buffer if batch_size is 0; return an \ + empty numpy array if batch_size < 0 or no available index can be sampled. \ + Additionally, some episodes of the sampled transitions will be re-written \ + according to HER. 
+ """ + self._restore_cache() + indices = super().sample_indices(batch_size=batch_size) + self.rewrite_transitions(indices.copy()) + return indices + + def rewrite_transitions(self, indices: np.ndarray) -> None: + """Re-write the goal of some sampled transitions' episodes according to HER. + + Currently applies only HER's 'future' strategy. The new goals will be written \ + directly to the internal batch data temporarily and will be restored right \ + before the next sampling or when using some of the buffer's method (e.g. \ + `add`, `save_hdf5`, etc.). This is to make sure that n-step returns \ + calculation etc., performs correctly without additional alteration. + """ + if indices.size == 0: + return + + # Sort indices keeping chronological order + indices[indices < self._index] += self.maxsize + indices = np.sort(indices) + indices[indices >= self.maxsize] -= self.maxsize + + # Construct episode trajectories + indices = [indices] + for _ in range(self.horizon - 1): + indices.append(self.next(indices[-1])) + indices = np.stack(indices) + + # Calculate future timestep to use + current = indices[0] + terminal = indices[-1] + future_offset = np.random.uniform(size=len(indices[0])) * (terminal - current) + future_offset = future_offset.astype(int) + future_t = (current + future_offset) + + # Compute indices + # open indices are used to find longest, unique trajectories among + # presented episodes + unique_ep_open_indices = np.sort(np.unique(terminal, return_index=True)[1]) + unique_ep_indices = indices[:, unique_ep_open_indices] + # close indices are used to find max future_t among presented episodes + unique_ep_close_indices = np.hstack( + [(unique_ep_open_indices - 1)[1:], + len(terminal) - 1] + ) + # episode indices that will be altered + her_ep_indices = np.random.choice( + len(unique_ep_open_indices), + size=int(len(unique_ep_open_indices) * self.future_p), + replace=False + ) + + # Cache original meta + self._altered_indices = unique_ep_indices.copy() + self._original_meta = self._meta[self._altered_indices].copy() + + # Copy original obs, ep_rew (and obs_next), and obs of future time step + ep_obs = self[unique_ep_indices].obs + ep_rew = self[unique_ep_indices].rew + if self._save_obs_next: + ep_obs_next = self[unique_ep_indices].obs_next + future_obs = self[future_t[unique_ep_close_indices]].obs_next + else: + future_obs = self[self.next(future_t[unique_ep_close_indices])].obs + + # Re-assign goals and rewards via broadcast assignment + ep_obs.desired_goal[:, her_ep_indices] = \ + future_obs.achieved_goal[None, her_ep_indices] + if self._save_obs_next: + ep_obs_next.desired_goal[:, her_ep_indices] = \ + future_obs.achieved_goal[None, her_ep_indices] + ep_rew[:, her_ep_indices] = \ + self._compute_reward(ep_obs_next)[:, her_ep_indices] + else: + tmp_ep_obs_next = self[self.next(unique_ep_indices)].obs + ep_rew[:, her_ep_indices] = \ + self._compute_reward(tmp_ep_obs_next)[:, her_ep_indices] + + # Sanity check + assert ep_obs.desired_goal.shape[:2] == unique_ep_indices.shape + assert ep_obs.achieved_goal.shape[:2] == unique_ep_indices.shape + assert ep_rew.shape == unique_ep_indices.shape + + # Re-write meta + self._meta.obs[unique_ep_indices] = ep_obs + if self._save_obs_next: + self._meta.obs_next[unique_ep_indices] = ep_obs_next + self._meta.rew[unique_ep_indices] = ep_rew.astype(np.float32) + + def _compute_reward(self, obs: Batch, lead_dims: int = 2) -> np.ndarray: + lead_shape = obs.observation.shape[:lead_dims] + g = obs.desired_goal.reshape(-1, 
*obs.desired_goal.shape[lead_dims:]) + ag = obs.achieved_goal.reshape(-1, *obs.achieved_goal.shape[lead_dims:]) + rewards = self.compute_reward_fn(ag, g) + return rewards.reshape(*lead_shape, *rewards.shape[1:]) diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py index 2df50fa96..b694c1abe 100644 --- a/tianshou/data/buffer/manager.py +++ b/tianshou/data/buffer/manager.py @@ -3,7 +3,7 @@ import numpy as np from numba import njit -from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer +from tianshou.data import Batch, HERReplayBuffer, PrioritizedReplayBuffer, ReplayBuffer from tianshou.data.batch import _alloc_by_keys_diff, _create_value @@ -21,7 +21,9 @@ class ReplayBufferManager(ReplayBuffer): Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. """ - def __init__(self, buffer_list: List[ReplayBuffer]) -> None: + def __init__( + self, buffer_list: Union[List[ReplayBuffer], List[HERReplayBuffer]] + ) -> None: self.buffer_num = len(buffer_list) self.buffers = np.array(buffer_list, dtype=object) offset, size = [], 0 @@ -212,6 +214,48 @@ def __init__(self, buffer_list: Sequence[PrioritizedReplayBuffer]) -> None: PrioritizedReplayBuffer.__init__(self, self.maxsize, **kwargs) +class HERReplayBufferManager(ReplayBufferManager): + """HERReplayBufferManager contains a list of HERReplayBuffer with \ + exactly the same configuration. + + These replay buffers have contiguous memory layout, and the storage space each + buffer has is a shallow copy of the topmost memory. + + :param buffer_list: a list of HERReplayBuffer needed to be handled. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. + """ + + def __init__(self, buffer_list: List[HERReplayBuffer]) -> None: + super().__init__(buffer_list) + + def _restore_cache(self) -> None: + for buf in self.buffers: + buf._restore_cache() + + def save_hdf5(self, path: str, compression: Optional[str] = None) -> None: + self._restore_cache() + return super().save_hdf5(path, compression) + + def set_batch(self, batch: Batch) -> None: + self._restore_cache() + return super().set_batch(batch) + + def update(self, buffer: Union["HERReplayBuffer", "ReplayBuffer"]) -> np.ndarray: + self._restore_cache() + return super().update(buffer) + + def add( + self, + batch: Batch, + buffer_ids: Optional[Union[np.ndarray, List[int]]] = None + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + self._restore_cache() + return super().add(batch, buffer_ids) + + @njit def _prev_index( index: np.ndarray, diff --git a/tianshou/data/buffer/vecbuf.py b/tianshou/data/buffer/vecbuf.py index 2d4831c06..a08ad9857 100644 --- a/tianshou/data/buffer/vecbuf.py +++ b/tianshou/data/buffer/vecbuf.py @@ -3,6 +3,8 @@ import numpy as np from tianshou.data import ( + HERReplayBuffer, + HERReplayBufferManager, PrioritizedReplayBuffer, PrioritizedReplayBufferManager, ReplayBuffer, @@ -64,3 +66,26 @@ def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: def set_beta(self, beta: float) -> None: for buffer in self.buffers: buffer.set_beta(beta) + + +class HERVectorReplayBuffer(HERReplayBufferManager): + """HERVectorReplayBuffer contains n HERReplayBuffer with same size. + + It is used for storing transition from different environments yet keeping the order + of time. + + :param int total_size: the total size of HERVectorReplayBuffer. + :param int buffer_num: the number of HERReplayBuffer it uses, which are + under the same configuration. 
+ + Other input arguments are the same as :class:`~tianshou.data.HERReplayBuffer`. + + .. seealso:: + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. + """ + + def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: + assert buffer_num > 0 + size = int(np.ceil(total_size / buffer_num)) + buffer_list = [HERReplayBuffer(size, **kwargs) for _ in range(buffer_num)] + super().__init__(buffer_list) diff --git a/tianshou/env/__init__.py b/tianshou/env/__init__.py index 6abea3280..a00c3cd38 100644 --- a/tianshou/env/__init__.py +++ b/tianshou/env/__init__.py @@ -1,6 +1,10 @@ """Env package.""" -from tianshou.env.gym_wrappers import ContinuousToDiscrete, MultiDiscreteToDiscrete +from tianshou.env.gym_wrappers import ( + ContinuousToDiscrete, + MultiDiscreteToDiscrete, + TruncatedAsTerminated, +) from tianshou.env.venv_wrappers import VectorEnvNormObs, VectorEnvWrapper from tianshou.env.venvs import ( BaseVectorEnv, @@ -26,4 +30,5 @@ "PettingZooEnv", "ContinuousToDiscrete", "MultiDiscreteToDiscrete", + "TruncatedAsTerminated", ] diff --git a/tianshou/env/gym_wrappers.py b/tianshou/env/gym_wrappers.py index 5b98e77af..b906b79b9 100644 --- a/tianshou/env/gym_wrappers.py +++ b/tianshou/env/gym_wrappers.py @@ -1,7 +1,8 @@ -from typing import List, Union +from typing import Any, Dict, List, Tuple, Union import gym import numpy as np +from packaging import version class ContinuousToDiscrete(gym.ActionWrapper): @@ -55,3 +56,25 @@ def action(self, act: np.ndarray) -> np.ndarray: converted_act.append(act // b) act = act % b return np.array(converted_act).transpose() + + +class TruncatedAsTerminated(gym.Wrapper): + """A wrapper that set ``terminated = terminated or truncated`` for ``step()``. + + It's intended to use with ``gym.wrappers.TimeLimit``. + + :param gym.Env env: gym environment. + """ + + def __init__(self, env: gym.Env): + super().__init__(env) + if not version.parse(gym.__version__) >= version.parse('0.26.0'): + raise EnvironmentError( + f"TruncatedAsTerminated is not applicable with gym version \ + {gym.__version__}" + ) + + def step(self, act: np.ndarray) -> Tuple[Any, float, bool, bool, Dict[Any, Any]]: + observation, reward, terminated, truncated, info = super().step(act) + terminated = (terminated or truncated) + return observation, reward, terminated, truncated, info diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py index 0eca53c0c..2dc47ed21 100644 --- a/tianshou/utils/logger/base.py +++ b/tianshou/utils/logger/base.py @@ -89,6 +89,7 @@ def log_update_data(self, update_result: dict, step: int) -> None: self.write("update/gradient_step", step, log_data) self.last_log_update_step = step + @abstractmethod def save_data( self, epoch: int, @@ -106,6 +107,7 @@ def save_data( """ pass + @abstractmethod def restore_data(self) -> Tuple[int, int, int]: """Return the metadata from existing log. 
@@ -126,3 +128,15 @@ def __init__(self) -> None: def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: """The LazyLogger writes nothing.""" pass + + def save_data( + self, + epoch: int, + env_step: int, + gradient_step: int, + save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None, + ) -> None: + pass + + def restore_data(self) -> Tuple[int, int, int]: + pass diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 1fc58f768..a6a018535 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -1,5 +1,6 @@ from typing import ( Any, + Callable, Dict, List, Optional, @@ -14,6 +15,8 @@ import torch from torch import nn +from tianshou.data.batch import Batch + ModuleType = Type[nn.Module] @@ -262,7 +265,7 @@ def forward( """ obs = torch.as_tensor( obs, - device=self.device, # type: ignore + device=self.device, dtype=torch.float32, ) # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation) @@ -453,3 +456,61 @@ def forward( action_scores = action_scores - torch.mean(action_scores, 2, keepdim=True) logits = value_out + action_scores return logits, state + + +def get_dict_state_decorator( + state_shape: Dict[str, Union[int, Sequence[int]]], keys: Sequence[str] +) -> Tuple[Callable, int]: + """A helper function to make Net or equivalent classes (e.g. Actor, Critic) \ + applicable to dict state. + + The first return item, ``decorator_fn``, will alter the implementation of the + ``forward`` function of the given class by preprocessing the observation. The + preprocessing basically flattens the observations and concatenates them in the + order given by ``keys``. The batch dimension is preserved if present. The + resulting observation shape will be equal to ``new_state_shape``, the second + return item. + + :param state_shape: A dictionary indicating each state's shape. + :param keys: A list of state keys. The flattened observation will be concatenated \ in this list order.
+ :returns: a 2-items tuple ``decorator_fn`` and ``new_state_shape`` + """ + original_shape = state_shape + flat_state_shapes = [] + for k in keys: + flat_state_shapes.append(int(np.prod(state_shape[k]))) + new_state_shape = sum(flat_state_shapes) + + def preprocess_obs( + obs: Union[Batch, dict, torch.Tensor, np.ndarray] + ) -> torch.Tensor: + if isinstance(obs, dict) or (isinstance(obs, Batch) and keys[0] in obs): + if original_shape[keys[0]] == obs[keys[0]].shape: + # No batch dim + new_obs = torch.Tensor([obs[k] for k in keys]).flatten() + # new_obs = torch.Tensor([obs[k] for k in keys]).reshape(1, -1) + else: + bsz = obs[keys[0]].shape[0] + new_obs = torch.cat( + [torch.Tensor(obs[k].reshape(bsz, -1)) for k in keys], dim=1 + ) + else: + new_obs = torch.Tensor(obs) + return new_obs + + @no_type_check + def decorator_fn(net_class): + + class new_net_class(net_class): + + def forward( + self, + obs: Union[np.ndarray, torch.Tensor], + *args, + **kwargs, + ) -> Any: + return super().forward(preprocess_obs(obs), *args, **kwargs) + + return new_net_class + + return decorator_fn, new_state_shape diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index bf083c32a..fb75e3317 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -124,13 +124,13 @@ def forward( """Mapping: (s, a) -> logits -> Q(s, a).""" obs = torch.as_tensor( obs, - device=self.device, # type: ignore + device=self.device, dtype=torch.float32, ).flatten(1) if act is not None: act = torch.as_tensor( act, - device=self.device, # type: ignore + device=self.device, dtype=torch.float32, ).flatten(1) obs = torch.cat([obs, act], dim=1) @@ -266,7 +266,7 @@ def forward( """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" obs = torch.as_tensor( obs, - device=self.device, # type: ignore + device=self.device, dtype=torch.float32, ) # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation) @@ -339,7 +339,7 @@ def forward( """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" obs = torch.as_tensor( obs, - device=self.device, # type: ignore + device=self.device, dtype=torch.float32, ) # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation) @@ -352,7 +352,7 @@ def forward( if act is not None: act = torch.as_tensor( act, - device=self.device, # type: ignore + device=self.device, dtype=torch.float32, ) obs = torch.cat([obs, act], dim=1)
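The ``get_dict_state_decorator`` helper added to ``tianshou/utils/net/common.py`` is what lets the existing ``Net``/``Actor``/``Critic`` classes consume the dict observations produced by goal-based environments. Below is a minimal usage sketch mirroring ``fetch_her_ddpg.py`` above; the concrete shapes and the batch of zero observations are illustrative placeholders, not values fixed by this diff.

```python
# Sketch: wrap Net/Actor so their forward() accepts dict observations with keys
# {"observation", "achieved_goal", "desired_goal"}, as done in fetch_her_ddpg.py.
import numpy as np

from tianshou.utils.net.common import Net, get_dict_state_decorator
from tianshou.utils.net.continuous import Actor

# Illustrative shapes only; the real ones come from env.observation_space.
state_shape = {
    "observation": (10,),
    "achieved_goal": (3,),
    "desired_goal": (3,),
}
action_shape = (4,)

dict_state_dec, flat_state_shape = get_dict_state_decorator(
    state_shape=state_shape,
    keys=["observation", "achieved_goal", "desired_goal"],
)

# The decorator returns a subclass whose forward() first flattens and
# concatenates the dict observation into a (batch, flat_state_shape) tensor.
net = dict_state_dec(Net)(flat_state_shape, hidden_sizes=[256, 256], device="cpu")
actor = dict_state_dec(Actor)(net, action_shape, max_action=1.0, device="cpu")

obs_batch = {
    "observation": np.zeros((2, 10), dtype=np.float32),
    "achieved_goal": np.zeros((2, 3), dtype=np.float32),
    "desired_goal": np.zeros((2, 3), dtype=np.float32),
}
act, _ = actor(obs_batch)  # act has shape (2, 4), squashed into [-1, 1] by tanh
```

Decorating the network classes keeps the rest of the pipeline unchanged, which is presumably why this diff touches neither ``DDPGPolicy`` nor the collectors: only the networks need to know about the dict observation layout.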