From 3d596a70d078598e113f17284a2275fb5c9519ea Mon Sep 17 00:00:00 2001 From: juno-t Date: Sat, 1 Oct 2022 07:00:05 +0000 Subject: [PATCH 01/20] init her replaybuffer --- test/base/env.py | 71 ++++++++++++ test/base/test_buffer.py | 108 +++++++++++++++++- tianshou/data/__init__.py | 8 +- tianshou/data/buffer/her.py | 190 ++++++++++++++++++++++++++++++++ tianshou/data/buffer/manager.py | 44 +++++++- tianshou/data/buffer/vecbuf.py | 21 ++++ 6 files changed, 437 insertions(+), 5 deletions(-) create mode 100644 tianshou/data/buffer/her.py diff --git a/test/base/env.py b/test/base/env.py index d7a96035d..b815b646f 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -7,6 +7,8 @@ import numpy as np from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple +from tianshou.data import Batch + class MyTestEnv(gym.Env): """This is a "going right" task. The task is to go right ``size`` steps. @@ -166,3 +168,72 @@ def step(self, action): for i in range(self.size): self.graph.nodes[i]["data"] = next_graph_state[i] return self._encode_obs(), 1.0, 0, 0, {} + + +class MyGoalEnv(MyTestEnv): + + def __init__(self, *args, **kwargs): + assert kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0, \ + "dict_state / recurse_state not supported" + super().__init__(*args, **kwargs) + obs, _ = super().reset(state=0) + obs, _, _, _, _ = super().step(1) + self._goal = obs * self.size + super_obsv = self.observation_space + self.observation_space = Box( + shape=(super_obsv.shape[0] * 3, *super_obsv.shape[1:]), + low=0, + high=self.size + ) + + def reset(self, *args, **kwargs): + obs, info = super().reset(*args, **kwargs) + new_obs = np.concatenate([obs, obs, self._goal], axis=0) + return new_obs, info + + def step(self, *args, **kwargs): + obs_next, rew, terminated, truncated, info = super().step(*args, **kwargs) + new_obs_next = np.concatenate([obs_next, obs_next, self._goal], axis=0) + return new_obs_next, rew, terminated, truncated, info + + def deconstruct_obs_fn(self, obs: np.ndarray) -> Batch: + """Deconstruct observation into observation, acheived_goal, goal + obs: shape(bsz, *observation_shape) + return: Batch( + o=shape(bsz, *o.shape), + ag=shape(bsz, *ag.shape), + g=shape(bsz, *g.shape) + ) + """ + state_sz = 1 + if self.array_state: + state_sz = 4 + return Batch( + o=obs[:, :state_sz], + ag=obs[:, state_sz:2 * state_sz], + g=obs[:, 2 * state_sz:], + ) + + def flatten_obs_fn(self, obs: Batch) -> np.ndarray: + """Reconstruct observation + obs: Batch( + o=shape(bsz, *o.shape), + ag=shape(bsz, *ag.shape), + g=shape(bsz, *g.shape) + ) + return: shape(bsz, *observation_shape) + """ + return np.concatenate((obs.o, obs.ag, obs.g), axis=1) + + def compute_reward_fn(self, obs: Batch) -> np.ndarray: + """Compute rewards from deconstructed obs + obs: Batch( + o=shape(bsz, *o.shape), + ag=shape(bsz, *ag.shape), + g=shape(bsz, *g.shape) + ) + return: shape(bsz,) + """ + ag_sum = obs.ag.reshape(obs.ag.shape[0], -1).sum(axis=1) + g_sum = obs.g.reshape(obs.g.shape[0], -1).sum(axis=1) + return (ag_sum == g_sum) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index f6c96e95b..028729ef0 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -11,6 +11,8 @@ from tianshou.data import ( Batch, CachedReplayBuffer, + HERReplayBuffer, + HERVectorReplayBuffer, PrioritizedReplayBuffer, PrioritizedVectorReplayBuffer, ReplayBuffer, @@ -20,9 +22,9 @@ from tianshou.data.utils.converter import to_hdf5 if __name__ == '__main__': - from env import MyTestEnv + from env import MyGoalEnv, 
MyTestEnv else: # pytest - from test.base.env import MyTestEnv + from test.base.env import MyGoalEnv, MyTestEnv def test_replaybuffer(size=10, bufsize=20): @@ -300,6 +302,108 @@ def test_priortized_replaybuffer(size=32, bufsize=15): assert weight[~mask][0] < weight[mask][0] and weight[mask][0] <= 1 +def test_herreplaybuffer(size=10, bufsize=100, sample_sz=4): + env_size = size + env = MyGoalEnv(env_size, array_state=True) + buf = HERReplayBuffer( + bufsize, + deconstruct_obs_fn=env.deconstruct_obs_fn, + flatten_obs_fn=env.flatten_obs_fn, + compute_reward_fn=env.compute_reward_fn, + horizon=30, + future_k=8 + ) + buf2 = HERVectorReplayBuffer( + bufsize, + deconstruct_obs_fn=env.deconstruct_obs_fn, + flatten_obs_fn=env.flatten_obs_fn, + compute_reward_fn=env.compute_reward_fn, + horizon=30, + future_k=8 + ) + # Apply her on every episodes sampled (Hacky but necessary for deterministic test) + buf.future_p = 1 + for buf2_buf in buf2.buffers: + buf2_buf.future_p = 1 + + obs, _ = env.reset() + action_list = [1] * 5 + [0] * 10 + [1] * 10 + for i, act in enumerate(action_list): + obs_next, rew, terminated, truncated, info = env.step(act) + batch = Batch( + obs=obs, + act=[act], + rew=rew, + terminated=terminated, + truncated=truncated, + obs_next=obs_next, + info=info + ) + buf.add(batch) + buf2.add(Batch.stack([batch, batch, batch]), buffer_ids=[0, 1, 2]) + obs = obs_next + assert len(buf) == min(bufsize, i + 1) + assert len(buf2) == min(bufsize, 3 * (i + 1)) + + batch, indices = buf.sample(sample_sz) + + # Check that goals are the same for the episode (only 1 ep in buffer) + tmp_indices = indices.copy() + for _ in range(2 * env_size): + obs = env.deconstruct_obs_fn(buf[tmp_indices].obs) + obs_next = env.deconstruct_obs_fn(buf[tmp_indices].obs_next) + rew = buf[tmp_indices].rew + g = obs.g.reshape(sample_sz, -1)[:, 0] + ag_next = obs_next.ag.reshape(sample_sz, -1)[:, 0] + g_next = obs_next.g.reshape(sample_sz, -1)[:, 0] + assert np.all(g == g[0]) + assert np.all(g_next == g_next[0]) + assert np.all(rew == (ag_next == g).astype(np.float32)) + tmp_indices = buf.next(tmp_indices) + + # Check that goals are correctly restored + buf._restore_cache() + tmp_indices = indices.copy() + for _ in range(2 * env_size): + obs = env.deconstruct_obs_fn(buf[tmp_indices].obs) + obs_next = env.deconstruct_obs_fn(buf[tmp_indices].obs_next) + g = obs.g.reshape(sample_sz, -1)[:, 0] + g_next = obs_next.g.reshape(sample_sz, -1)[:, 0] + assert np.all(g == env_size) + assert np.all(g_next == g_next[0]) + assert np.all(g == g[0]) + tmp_indices = buf.next(tmp_indices) + + # Test vector buffer + batch, indices = buf2.sample(sample_sz) + + # Check that goals are the same for the episode (only 1 ep in buffer) + tmp_indices = indices.copy() + for _ in range(2 * env_size): + obs = env.deconstruct_obs_fn(buf2[tmp_indices].obs) + obs_next = env.deconstruct_obs_fn(buf2[tmp_indices].obs_next) + rew = buf2[tmp_indices].rew + g = obs.g.reshape(sample_sz, -1)[:, 0] + ag_next = obs_next.ag.reshape(sample_sz, -1)[:, 0] + g_next = obs_next.g.reshape(sample_sz, -1)[:, 0] + assert np.all(g == g_next) + assert np.all(rew == (ag_next == g).astype(np.float32)) + tmp_indices = buf2.next(tmp_indices) + + # Check that goals are correctly restored + buf2._restore_cache() + tmp_indices = indices.copy() + for _ in range(2 * env_size): + obs = env.deconstruct_obs_fn(buf2[tmp_indices].obs) + g = obs.g.reshape(sample_sz, -1)[:, 0] + g_next = env.deconstruct_obs_fn(buf2[tmp_indices].obs_next + ).g.reshape(sample_sz, -1)[:, 0] + assert np.all(g == 
env_size) + assert np.all(g_next == g_next[0]) + assert np.all(g == g[0]) + tmp_indices = buf2.next(tmp_indices) + + def test_update(): buf1 = ReplayBuffer(4, stack_num=2) buf2 = ReplayBuffer(4, stack_num=2) diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index 89250d009..7a86ce857 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -6,13 +6,16 @@ from tianshou.data.utils.segtree import SegmentTree from tianshou.data.buffer.base import ReplayBuffer from tianshou.data.buffer.prio import PrioritizedReplayBuffer +from tianshou.data.buffer.her import HERReplayBuffer from tianshou.data.buffer.manager import ( ReplayBufferManager, PrioritizedReplayBufferManager, + HERReplayBufferManager, ) from tianshou.data.buffer.vecbuf import ( - VectorReplayBuffer, + HERVectorReplayBuffer, PrioritizedVectorReplayBuffer, + VectorReplayBuffer, ) from tianshou.data.buffer.cached import CachedReplayBuffer from tianshou.data.collector import Collector, AsyncCollector @@ -25,10 +28,13 @@ "SegmentTree", "ReplayBuffer", "PrioritizedReplayBuffer", + "HERReplayBuffer", "ReplayBufferManager", "PrioritizedReplayBufferManager", + "HERReplayBufferManager", "VectorReplayBuffer", "PrioritizedVectorReplayBuffer", + "HERVectorReplayBuffer", "CachedReplayBuffer", "Collector", "AsyncCollector", diff --git a/tianshou/data/buffer/her.py b/tianshou/data/buffer/her.py new file mode 100644 index 000000000..910b3646f --- /dev/null +++ b/tianshou/data/buffer/her.py @@ -0,0 +1,190 @@ +from typing import Any, Callable, List, Optional, Tuple, Union + +import numpy as np + +from tianshou.data import Batch, ReplayBuffer + + +class HERReplayBuffer(ReplayBuffer): + + def __init__( + self, + size: int, + deconstruct_obs_fn: Callable[[np.ndarray], Batch], + flatten_obs_fn: Callable[[Batch], np.ndarray], + compute_reward_fn: Callable[[Batch], np.ndarray], + horizon: int, + future_k: float = 8.0, + **kwargs: Any, + ) -> None: + super().__init__(size, **kwargs) + self.horizon = horizon + self.future_p = 1 - 1 / future_k + self.deconstruct_obs_fn = deconstruct_obs_fn + self.flatten_obs_fn = flatten_obs_fn + self.compute_reward_fn = compute_reward_fn + self._original_meta = Batch() + self._altered_indices = np.array([]) + + def _restore_cache(self) -> None: + """ + Write cached original meta back to self._meta + Do this everytime before 'writing', 'sampling' or 'saving' the buffer. 
+ """ + if not hasattr(self, '_altered_indices'): + return + + if self._altered_indices.size == 0: + return + self._meta[self._altered_indices] = self._original_meta + # Clean + del self._original_meta, self._altered_indices + self._original_meta = Batch() + self._altered_indices = np.array([]) + + def reset(self, keep_statistics: bool = False) -> None: + self._restore_cache() + return super().reset(keep_statistics) + + def save_hdf5(self, path: str, compression: Optional[str] = None) -> None: + self._restore_cache() + return super().save_hdf5(path, compression) + + def set_batch(self, batch: Batch) -> None: + self._restore_cache() + return super().set_batch(batch) + + def update(self, buffer: Union["HERReplayBuffer", "ReplayBuffer"]) -> np.ndarray: + self._restore_cache() + return super().update(buffer) + + def add( + self, + batch: Batch, + buffer_ids: Optional[Union[np.ndarray, List[int]]] = None + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + self._restore_cache() + return super().add(batch, buffer_ids) + + def sample_indices(self, batch_size: int) -> np.ndarray: + """Get a random sample of index with size = batch_size. + Return all available indices in the buffer if batch_size is 0; return an empty + numpy array if batch_size < 0 or no available index can be sampled. + """ + self._restore_cache() + indices = np.sort(super().sample_indices(batch_size=batch_size)) + self.rewrite_transitions(indices) + return indices + + def rewrite_transitions(self, indices: np.ndarray) -> None: + """ Re-write the goal of some sampled transitions' episodes according to HER's + 'future' strategy. The new goals will be written directly to the internal + batch data temporarily and will be restored right before the next sampling or + when using some of the buffer's method (such as `add` or `save_hdf5`). This is + to make sure that n-step returns calculation etc. performs correctly without + alteration. 
+ """ + if indices.size == 0: + return + # Construct episode trajectories + indices = [indices] + for _ in range(self.horizon - 1): + indices.append(self.next(indices[-1])) + indices = np.stack(indices) + + # Calculate future timestep to use + current = indices[0] + terminal = indices[-1] + future_offset = np.random.uniform(size=len(indices[0])) * (terminal - current) + future_offset = future_offset.astype(int) + future_t = (current + future_offset) + + # Compute indices + # open indices are used to find longest, unique trajectories among + # presented episodes + unique_ep_open_indices = np.unique(terminal, return_index=True)[1] + unique_ep_indices = indices[:, unique_ep_open_indices] + # close indices are used to find max future_t among presented episodes + unique_ep_close_indices = np.hstack( + [(unique_ep_open_indices - 1)[1:], + len(terminal) - 1] + ) + # episode indices that will be altered + her_ep_indices = np.random.choice( + len(unique_ep_open_indices), + size=int(len(unique_ep_open_indices) * self.future_p), + replace=False + ) + + # Copy original obs, ep_rew (and obs_next), and obs of future time step + ep_obs = self._deconstruct_obs(self[unique_ep_indices].obs) + ep_rew = self[unique_ep_indices].rew + if self._save_obs_next: + ep_obs_next = self._deconstruct_obs(self[unique_ep_indices].obs_next) + future_obs = self._deconstruct_obs( + self[future_t[unique_ep_close_indices]].obs_next, lead_dims=1 + ) + else: + future_obs = self._deconstruct_obs( + self[self.next(future_t[unique_ep_close_indices])].obs, lead_dims=1 + ) + + # Re-assign goals and rewards via broadcast assignment + ep_obs.g[:, her_ep_indices] = future_obs.ag[None, her_ep_indices] + if self._save_obs_next: + ep_obs_next.g[:, her_ep_indices] = future_obs.ag[None, her_ep_indices] + ep_rew[:, + her_ep_indices] = self._compute_reward(ep_obs_next)[:, + her_ep_indices] + else: + tmp_ep_obs_next = self._deconstruct_obs( + self[self.next(unique_ep_indices)].obs + ) + ep_rew[:, her_ep_indices] = self._compute_reward(tmp_ep_obs_next + )[:, her_ep_indices] + + # Sanity check + assert ep_obs.g.shape[:2] == unique_ep_indices.shape + assert ep_obs.ag.shape[:2] == unique_ep_indices.shape + assert ep_rew.shape == unique_ep_indices.shape + assert np.all(future_t >= indices[0]) + + # Cache original meta + self._altered_indices = unique_ep_indices.copy() + self._original_meta = self._meta[self._altered_indices].copy() + + # Re-write meta + self._meta.obs[unique_ep_indices] = self._flatten_obs(ep_obs) + if self._save_obs_next: + self._meta.obs_next[unique_ep_indices] = self._flatten_obs(ep_obs_next) + self._meta.rew[unique_ep_indices] = ep_rew.astype(np.float32) + + # Reshaping obs into (bsz, *shape) instead of (..., *shape) before + # calling the provided functions. 
+ def _deconstruct_obs(self, obs: np.ndarray, lead_dims: int = 2) -> Batch: + lead_shape = obs.shape[:lead_dims] + flatten_obs = obs.reshape(-1, *obs.shape[lead_dims:]) + de_obs = self.deconstruct_obs_fn(flatten_obs) + de_obs.o = de_obs.o.reshape(*lead_shape, *de_obs.o.shape[1:]) + de_obs.g = de_obs.g.reshape(*lead_shape, *de_obs.g.shape[1:]) + de_obs.ag = de_obs.ag.reshape(*lead_shape, *de_obs.ag.shape[1:]) + return de_obs + + def _flatten_obs(self, de_obs: Batch, lead_dims: int = 2) -> np.ndarray: + lead_shape = de_obs.o.shape[:lead_dims] + de_obs.o = de_obs.o.reshape(-1, *de_obs.o.shape[lead_dims:]) + de_obs.g = de_obs.g.reshape(-1, *de_obs.g.shape[lead_dims:]) + de_obs.ag = de_obs.ag.reshape(-1, *de_obs.ag.shape[lead_dims:]) + flatten_obs = self.flatten_obs_fn(de_obs) + return flatten_obs.reshape(*lead_shape, *flatten_obs.shape[1:]) + + def _compute_reward(self, de_obs: Batch, lead_dims: int = 2) -> np.ndarray: + lead_shape = de_obs.o.shape[:lead_dims] + de_obs.o = de_obs.o.reshape(-1, *de_obs.o.shape[lead_dims:]) + de_obs.g = de_obs.g.reshape(-1, *de_obs.g.shape[lead_dims:]) + de_obs.ag = de_obs.ag.reshape(-1, *de_obs.ag.shape[lead_dims:]) + rewards = self.compute_reward_fn(de_obs) + de_obs.o = de_obs.o.reshape(*lead_shape, *de_obs.o.shape[1:]) + de_obs.g = de_obs.g.reshape(*lead_shape, *de_obs.g.shape[1:]) + de_obs.ag = de_obs.ag.reshape(*lead_shape, *de_obs.ag.shape[1:]) + return rewards.reshape(*lead_shape, *rewards.shape[1:]) diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py index 2df50fa96..baa00af93 100644 --- a/tianshou/data/buffer/manager.py +++ b/tianshou/data/buffer/manager.py @@ -3,7 +3,7 @@ import numpy as np from numba import njit -from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer +from tianshou.data import Batch, HERReplayBuffer, PrioritizedReplayBuffer, ReplayBuffer from tianshou.data.batch import _alloc_by_keys_diff, _create_value @@ -21,7 +21,9 @@ class ReplayBufferManager(ReplayBuffer): Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. """ - def __init__(self, buffer_list: List[ReplayBuffer]) -> None: + def __init__( + self, buffer_list: Union[List[ReplayBuffer], List[HERReplayBuffer]] + ) -> None: self.buffer_num = len(buffer_list) self.buffers = np.array(buffer_list, dtype=object) offset, size = [], 0 @@ -212,6 +214,44 @@ def __init__(self, buffer_list: Sequence[PrioritizedReplayBuffer]) -> None: PrioritizedReplayBuffer.__init__(self, self.maxsize, **kwargs) +class HERReplayBufferManager(ReplayBufferManager): + """HERReplayBufferManager contains a list of HERReplayBuffer with \ + exactly the same configuration. + These replay buffers have contiguous memory layout, and the storage space each + buffer has is a shallow copy of the topmost memory. + :param buffer_list: a list of HERReplayBuffer needed to be handled. + .. seealso:: + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. 
+ """ + + def __init__(self, buffer_list: List[HERReplayBuffer]) -> None: + super().__init__(buffer_list) + + def _restore_cache(self) -> None: + for buf in self.buffers: + buf._restore_cache() + + def save_hdf5(self, path: str, compression: Optional[str] = None) -> None: + self._restore_cache() + return super().save_hdf5(path, compression) + + def set_batch(self, batch: Batch) -> None: + self._restore_cache() + return super().set_batch(batch) + + def update(self, buffer: Union["HERReplayBuffer", "ReplayBuffer"]) -> np.ndarray: + self._restore_cache() + return super().update(buffer) + + def add( + self, + batch: Batch, + buffer_ids: Optional[Union[np.ndarray, List[int]]] = None + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + self._restore_cache() + return super().add(batch, buffer_ids) + + @njit def _prev_index( index: np.ndarray, diff --git a/tianshou/data/buffer/vecbuf.py b/tianshou/data/buffer/vecbuf.py index 2d4831c06..0457e21fc 100644 --- a/tianshou/data/buffer/vecbuf.py +++ b/tianshou/data/buffer/vecbuf.py @@ -3,6 +3,8 @@ import numpy as np from tianshou.data import ( + HERReplayBuffer, + HERReplayBufferManager, PrioritizedReplayBuffer, PrioritizedReplayBufferManager, ReplayBuffer, @@ -64,3 +66,22 @@ def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: def set_beta(self, beta: float) -> None: for buffer in self.buffers: buffer.set_beta(beta) + + +class HERVectorReplayBuffer(HERReplayBufferManager): + """HERVectorReplayBuffer contains n HERReplayBuffer with same size. + It is used for storing transition from different environments yet keeping the order + of time. + :param int total_size: the total size of HERVectorReplayBuffer. + :param int buffer_num: the number of HERReplayBuffer it uses, which are + under the same configuration. + Other input arguments are the same as :class:`~tianshou.data.HERReplayBuffer`. + .. seealso:: + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. 
+ """ + + def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: + assert buffer_num > 0 + size = int(np.ceil(total_size / buffer_num)) + buffer_list = [HERReplayBuffer(size, **kwargs) for _ in range(buffer_num)] + super().__init__(buffer_list) From c0d938e66e42993575d324cc9d4d53b7e888bd1c Mon Sep 17 00:00:00 2001 From: juno-t Date: Sat, 1 Oct 2022 07:16:41 +0000 Subject: [PATCH 02/20] debug test her --- test/base/test_buffer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 028729ef0..392d2231e 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -315,6 +315,7 @@ def test_herreplaybuffer(size=10, bufsize=100, sample_sz=4): ) buf2 = HERVectorReplayBuffer( bufsize, + buffer_num=3, deconstruct_obs_fn=env.deconstruct_obs_fn, flatten_obs_fn=env.flatten_obs_fn, compute_reward_fn=env.compute_reward_fn, @@ -1276,6 +1277,7 @@ def test_from_data(): test_stack() test_segtree() test_priortized_replaybuffer() + test_herreplaybuffer() test_update() test_pickle() test_hdf5() From b4d4182e0278e38d6739d13753374a884b02db0d Mon Sep 17 00:00:00 2001 From: juno-t Date: Sun, 2 Oct 2022 03:11:57 +0000 Subject: [PATCH 03/20] add goal env wrapper, test --- test/base/test_env.py | 53 ++++++++++++++++ tianshou/env/__init__.py | 7 ++- tianshou/env/gym_wrappers.py | 114 ++++++++++++++++++++++++++++++++++- 3 files changed, 172 insertions(+), 2 deletions(-) diff --git a/test/base/test_env.py b/test/base/test_env.py index 6a222709d..851acda2c 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -10,6 +10,7 @@ from tianshou.env import ( ContinuousToDiscrete, DummyVectorEnv, + GoalEnvWrapper, MultiDiscreteToDiscrete, RayVectorEnv, ShmemVectorEnv, @@ -376,6 +377,57 @@ def __init__(self): ) +def test_goal_env_wrapper(): + + class FetchLikeEnv(gym.Env): + + def __init__(self): + self.observation_space = gym.spaces.Dict( + { + 'observation': gym.spaces.Box(-1, 1, (7, 2)), + 'achieved_goal': gym.spaces.Box(-1, 1, (2, 3)), + 'desired_goal': gym.spaces.Box(-1, 1, (2, 3)), + } + ) + self.action_space = gym.spaces.Discrete(2) + + def compute_reward(self, ag, g, info): + return (ag == g).all(axis=(-2, -1)).astype(np.int64) + + def reset(self): + return self.observation_space.sample(), {} + + env = FetchLikeEnv() + env = GoalEnvWrapper(env, compute_reward=env.compute_reward) + + # Test single observation + obs, _ = env.reset() + assert isinstance(obs, np.ndarray) + assert obs.shape == (7 * 2 + 2 * 3 + 2 * 3, ) + de_obs = env.deconstruct_obs_fn(obs) + assert de_obs.o.shape == (7, 2) + assert de_obs.ag.shape == (2, 3) + assert de_obs.g.shape == (2, 3) + fl_obs = env.flatten_obs_fn(de_obs) + assert np.array_equal(obs, fl_obs) + rew = env.compute_reward_fn(de_obs) + assert rew.shape == () + + # Test batch observation + obs2, _ = env.reset() + assert isinstance(obs, np.ndarray) + assert obs2.shape == (7 * 2 + 2 * 3 + 2 * 3, ) + batch_obs = np.array([obs, obs2]) + de_obs = env.deconstruct_obs_fn(batch_obs) + assert de_obs.o.shape == (2, 7, 2) + assert de_obs.ag.shape == (2, 2, 3) + assert de_obs.g.shape == (2, 2, 3) + fl_obs = env.flatten_obs_fn(de_obs) + assert np.array_equal(batch_obs, fl_obs) + rew = env.compute_reward_fn(de_obs) + assert rew.shape == (2, ) + + @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_venv_wrapper_envpool(): raw = envpool.make_gym("Ant-v3", num_envs=4) @@ -415,3 +467,4 @@ def test_venv_wrapper_envpool_gym_reset_return_info(): test_async_check_id() 
test_env_reset_optional_kwargs() test_gym_wrappers() + test_goal_env_wrapper() diff --git a/tianshou/env/__init__.py b/tianshou/env/__init__.py index 6abea3280..82d6c4ace 100644 --- a/tianshou/env/__init__.py +++ b/tianshou/env/__init__.py @@ -1,6 +1,10 @@ """Env package.""" -from tianshou.env.gym_wrappers import ContinuousToDiscrete, MultiDiscreteToDiscrete +from tianshou.env.gym_wrappers import ( + ContinuousToDiscrete, + GoalEnvWrapper, + MultiDiscreteToDiscrete, +) from tianshou.env.venv_wrappers import VectorEnvNormObs, VectorEnvWrapper from tianshou.env.venvs import ( BaseVectorEnv, @@ -26,4 +30,5 @@ "PettingZooEnv", "ContinuousToDiscrete", "MultiDiscreteToDiscrete", + "GoalEnvWrapper", ] diff --git a/tianshou/env/gym_wrappers.py b/tianshou/env/gym_wrappers.py index 5b98e77af..27c69bb7c 100644 --- a/tianshou/env/gym_wrappers.py +++ b/tianshou/env/gym_wrappers.py @@ -1,8 +1,10 @@ -from typing import List, Union +from typing import List, Union, Callable import gym import numpy as np +from tianshou.data import Batch + class ContinuousToDiscrete(gym.ActionWrapper): """Gym environment wrapper to take discrete action in a continuous environment. @@ -55,3 +57,113 @@ def action(self, act: np.ndarray) -> np.ndarray: converted_act.append(act // b) act = act % b return np.array(converted_act).transpose() + + +class GoalEnv(gym.Env): + observation_space: gym.spaces.Space[gym.spaces.Dict] + + +class GoalEnvWrapper(gym.ObservationWrapper): + + def __init__( + self, + env: GoalEnv, + compute_reward: Callable[[np.ndarray, np.ndarray, dict], np.ndarray], + obs_space_keys: List[str] = ['observation', 'achieved_goal', 'desired_goal'] + # obs_space_keys must be in o, ag, g order. + ) -> None: + super().__init__(env) + self.env = env + self.compute_reward = compute_reward + self.obs_space_keys = obs_space_keys + self.original_space: gym.spaces.Space[gym.spaces.Dict] \ + = self.env.observation_space + + self.observation_space = self.calculate_obs_space() + + def calculate_obs_space(self) -> gym.Space: + for k in self.obs_space_keys: + assert isinstance(self.original_space[k], gym.spaces.Box) + new_low = np.concatenate( + [self.original_space[k].low.flatten() for k in self.obs_space_keys], + axis=0 + ) + + new_high = np.concatenate( + [self.original_space[k].high.flatten() for k in self.obs_space_keys], + axis=0 + ) + new_shape = np.concatenate( + [self.original_space[k].sample().flatten() for k in self.obs_space_keys], + axis=0 + ).shape + samples = [ + self.original_space[k].sample().flatten() for k in self.obs_space_keys + ] + self.partitions = np.cumsum([0] + [len(s) for s in samples], dtype=int) + return gym.spaces.Box(new_low, new_high, new_shape) + + def deconstruct_obs_fn(self, obs: np.ndarray) -> Batch: + """Deconstruct observation into observation, acheived_goal, goal. The first + dimension (bsz) is optional. + obs: shape(bsz, *observation_shape) + return: Batch( + o=shape(bsz, *o.shape), + ag=shape(bsz, *ag.shape), + g=shape(bsz, *g.shape) + ) or Batch without the first dim (bsz) according to the input. 
+ """ + new_shapes = [ + [*self.original_space[self.obs_space_keys[0]].shape], + [*self.original_space[self.obs_space_keys[1]].shape], + [*self.original_space[self.obs_space_keys[2]].shape], + ] + if len(obs.shape) == 2: + new_shapes = [[-1] + s for s in new_shapes] + batch = Batch( + o=obs[..., self.partitions[0]:self.partitions[1]].reshape(*new_shapes[0]), + ag=obs[..., self.partitions[1]:self.partitions[2]].reshape(*new_shapes[1]), + g=obs[..., self.partitions[2]:self.partitions[3]].reshape(*new_shapes[2]), + ) + return batch + + def flatten_obs_fn(self, obs: Batch) -> np.ndarray: + """Reconstruct observation. The first dim (bsz) is optional + obs: Batch( + o=shape(bsz, *o.shape), + ag=shape(bsz, *ag.shape), + g=shape(bsz, *g.shape) + ) + return: shape(bsz, *observation_shape) + """ + new_shape = [-1] + if len(obs.o.shape) > len(self.original_space[self.obs_space_keys[0]].shape): + bsz = obs.shape[0] + new_shape = [bsz, -1] + return np.concatenate( + [ + obs.o.reshape(*new_shape), + obs.ag.reshape(*new_shape), + obs.g.reshape(*new_shape) + ], + axis=-1 + ) + + def compute_reward_fn(self, obs: Batch) -> np.ndarray: + """Compute rewards from deconstructed obs. The first dim (bsz) is optional + obs: Batch( + o=shape(bsz, *o.shape), + ag=shape(bsz, *ag.shape), + g=shape(bsz, *g.shape) + ) + return: shape(bsz,) + """ + ag = obs.ag + g = obs.g + return self.compute_reward(ag, g, {}) + + def observation(self, observation: dict) -> np.ndarray: + o = observation[self.obs_space_keys[0]].flatten() + ag = observation[self.obs_space_keys[1]].flatten() + g = observation[self.obs_space_keys[2]].flatten() + return np.concatenate([o, ag, g]) From f0480a3fb6df791b06eb1359fae9fa289ccf11cd Mon Sep 17 00:00:00 2001 From: juno-t Date: Sun, 2 Oct 2022 07:56:06 +0000 Subject: [PATCH 04/20] update HER to use dict obs, and test --- test/base/env.py | 70 +++++++-------------- test/base/test_buffer.py | 47 ++++++-------- test/base/test_env.py | 61 +++---------------- tianshou/data/buffer/her.py | 86 ++++++++------------------ tianshou/env/__init__.py | 4 +- tianshou/env/gym_wrappers.py | 115 +++-------------------------------- 6 files changed, 87 insertions(+), 296 deletions(-) diff --git a/test/base/env.py b/test/base/env.py index b815b646f..8c6333d0b 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -7,8 +7,6 @@ import numpy as np from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple -from tianshou.data import Batch - class MyTestEnv(gym.Env): """This is a "going right" task. The task is to go right ``size`` steps. 
@@ -180,60 +178,36 @@ def __init__(self, *args, **kwargs): obs, _, _, _, _ = super().step(1) self._goal = obs * self.size super_obsv = self.observation_space - self.observation_space = Box( - shape=(super_obsv.shape[0] * 3, *super_obsv.shape[1:]), - low=0, - high=self.size + self.observation_space = gym.spaces.Dict( + { + 'observation': super_obsv, + 'achieved_goal': super_obsv, + 'desired_goal': super_obsv, + } ) def reset(self, *args, **kwargs): obs, info = super().reset(*args, **kwargs) - new_obs = np.concatenate([obs, obs, self._goal], axis=0) + new_obs = { + 'observation': obs, + 'achieved_goal': obs, + 'desired_goal': self._goal + } return new_obs, info def step(self, *args, **kwargs): obs_next, rew, terminated, truncated, info = super().step(*args, **kwargs) - new_obs_next = np.concatenate([obs_next, obs_next, self._goal], axis=0) + new_obs_next = { + 'observation': obs_next, + 'achieved_goal': obs_next, + 'desired_goal': self._goal + } return new_obs_next, rew, terminated, truncated, info - def deconstruct_obs_fn(self, obs: np.ndarray) -> Batch: - """Deconstruct observation into observation, acheived_goal, goal - obs: shape(bsz, *observation_shape) - return: Batch( - o=shape(bsz, *o.shape), - ag=shape(bsz, *ag.shape), - g=shape(bsz, *g.shape) - ) - """ - state_sz = 1 + def compute_reward_fn( + self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info: dict + ) -> np.ndarray: + axis = -1 if self.array_state: - state_sz = 4 - return Batch( - o=obs[:, :state_sz], - ag=obs[:, state_sz:2 * state_sz], - g=obs[:, 2 * state_sz:], - ) - - def flatten_obs_fn(self, obs: Batch) -> np.ndarray: - """Reconstruct observation - obs: Batch( - o=shape(bsz, *o.shape), - ag=shape(bsz, *ag.shape), - g=shape(bsz, *g.shape) - ) - return: shape(bsz, *observation_shape) - """ - return np.concatenate((obs.o, obs.ag, obs.g), axis=1) - - def compute_reward_fn(self, obs: Batch) -> np.ndarray: - """Compute rewards from deconstructed obs - obs: Batch( - o=shape(bsz, *o.shape), - ag=shape(bsz, *ag.shape), - g=shape(bsz, *g.shape) - ) - return: shape(bsz,) - """ - ag_sum = obs.ag.reshape(obs.ag.shape[0], -1).sum(axis=1) - g_sum = obs.g.reshape(obs.g.shape[0], -1).sum(axis=1) - return (ag_sum == g_sum) + axis = (-3, -2, -1) + return (achieved_goal == desired_goal).all(axis=axis) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 392d2231e..00f62a75e 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -306,18 +306,11 @@ def test_herreplaybuffer(size=10, bufsize=100, sample_sz=4): env_size = size env = MyGoalEnv(env_size, array_state=True) buf = HERReplayBuffer( - bufsize, - deconstruct_obs_fn=env.deconstruct_obs_fn, - flatten_obs_fn=env.flatten_obs_fn, - compute_reward_fn=env.compute_reward_fn, - horizon=30, - future_k=8 + bufsize, compute_reward_fn=env.compute_reward_fn, horizon=30, future_k=8 ) buf2 = HERVectorReplayBuffer( bufsize, buffer_num=3, - deconstruct_obs_fn=env.deconstruct_obs_fn, - flatten_obs_fn=env.flatten_obs_fn, compute_reward_fn=env.compute_reward_fn, horizon=30, future_k=8 @@ -351,12 +344,12 @@ def test_herreplaybuffer(size=10, bufsize=100, sample_sz=4): # Check that goals are the same for the episode (only 1 ep in buffer) tmp_indices = indices.copy() for _ in range(2 * env_size): - obs = env.deconstruct_obs_fn(buf[tmp_indices].obs) - obs_next = env.deconstruct_obs_fn(buf[tmp_indices].obs_next) + obs = buf[tmp_indices].obs + obs_next = buf[tmp_indices].obs_next rew = buf[tmp_indices].rew - g = obs.g.reshape(sample_sz, -1)[:, 0] - ag_next = 
obs_next.ag.reshape(sample_sz, -1)[:, 0] - g_next = obs_next.g.reshape(sample_sz, -1)[:, 0] + g = obs.desired_goal.reshape(sample_sz, -1)[:, 0] + ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0] + g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0] assert np.all(g == g[0]) assert np.all(g_next == g_next[0]) assert np.all(rew == (ag_next == g).astype(np.float32)) @@ -366,10 +359,10 @@ def test_herreplaybuffer(size=10, bufsize=100, sample_sz=4): buf._restore_cache() tmp_indices = indices.copy() for _ in range(2 * env_size): - obs = env.deconstruct_obs_fn(buf[tmp_indices].obs) - obs_next = env.deconstruct_obs_fn(buf[tmp_indices].obs_next) - g = obs.g.reshape(sample_sz, -1)[:, 0] - g_next = obs_next.g.reshape(sample_sz, -1)[:, 0] + obs = buf[tmp_indices].obs + obs_next = buf[tmp_indices].obs_next + g = obs.desired_goal.reshape(sample_sz, -1)[:, 0] + g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0] assert np.all(g == env_size) assert np.all(g_next == g_next[0]) assert np.all(g == g[0]) @@ -381,12 +374,12 @@ def test_herreplaybuffer(size=10, bufsize=100, sample_sz=4): # Check that goals are the same for the episode (only 1 ep in buffer) tmp_indices = indices.copy() for _ in range(2 * env_size): - obs = env.deconstruct_obs_fn(buf2[tmp_indices].obs) - obs_next = env.deconstruct_obs_fn(buf2[tmp_indices].obs_next) + obs = buf2[tmp_indices].obs + obs_next = buf2[tmp_indices].obs_next rew = buf2[tmp_indices].rew - g = obs.g.reshape(sample_sz, -1)[:, 0] - ag_next = obs_next.ag.reshape(sample_sz, -1)[:, 0] - g_next = obs_next.g.reshape(sample_sz, -1)[:, 0] + g = obs.desired_goal.reshape(sample_sz, -1)[:, 0] + ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0] + g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0] assert np.all(g == g_next) assert np.all(rew == (ag_next == g).astype(np.float32)) tmp_indices = buf2.next(tmp_indices) @@ -395,10 +388,10 @@ def test_herreplaybuffer(size=10, bufsize=100, sample_sz=4): buf2._restore_cache() tmp_indices = indices.copy() for _ in range(2 * env_size): - obs = env.deconstruct_obs_fn(buf2[tmp_indices].obs) - g = obs.g.reshape(sample_sz, -1)[:, 0] - g_next = env.deconstruct_obs_fn(buf2[tmp_indices].obs_next - ).g.reshape(sample_sz, -1)[:, 0] + obs = buf2[tmp_indices].obs + obs_next = buf2[tmp_indices].obs_next + g = obs.desired_goal.reshape(sample_sz, -1)[:, 0] + g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0] assert np.all(g == env_size) assert np.all(g_next == g_next[0]) assert np.all(g == g[0]) @@ -1277,7 +1270,6 @@ def test_from_data(): test_stack() test_segtree() test_priortized_replaybuffer() - test_herreplaybuffer() test_update() test_pickle() test_hdf5() @@ -1286,3 +1278,4 @@ def test_from_data(): test_multibuf_stack() test_multibuf_hdf5() test_from_data() + test_herreplaybuffer() diff --git a/test/base/test_env.py b/test/base/test_env.py index 851acda2c..3e6b20c49 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -10,13 +10,13 @@ from tianshou.env import ( ContinuousToDiscrete, DummyVectorEnv, - GoalEnvWrapper, MultiDiscreteToDiscrete, RayVectorEnv, ShmemVectorEnv, SubprocVectorEnv, VectorEnvNormObs, ) +from tianshou.env.gym_wrappers import TruncatedAsTerminated from tianshou.utils import RunningMeanStd if __name__ == "__main__": @@ -349,6 +349,9 @@ def __init__(self): low=-1.0, high=2.0, shape=(4, ), dtype=np.float32 ) + def step(self, act): + return np.array([0]), -1, False, True, {} + bsz = 10 action_per_branch = [4, 6, 10, 7] env = DummyEnv() @@ -375,57 +378,10 @@ def 
__init__(self): env_d.action(np.array([env_d.action_space.n - 1] * bsz)), np.array([env_m.action_space.nvec - 1] * bsz), ) - - -def test_goal_env_wrapper(): - - class FetchLikeEnv(gym.Env): - - def __init__(self): - self.observation_space = gym.spaces.Dict( - { - 'observation': gym.spaces.Box(-1, 1, (7, 2)), - 'achieved_goal': gym.spaces.Box(-1, 1, (2, 3)), - 'desired_goal': gym.spaces.Box(-1, 1, (2, 3)), - } - ) - self.action_space = gym.spaces.Discrete(2) - - def compute_reward(self, ag, g, info): - return (ag == g).all(axis=(-2, -1)).astype(np.int64) - - def reset(self): - return self.observation_space.sample(), {} - - env = FetchLikeEnv() - env = GoalEnvWrapper(env, compute_reward=env.compute_reward) - - # Test single observation - obs, _ = env.reset() - assert isinstance(obs, np.ndarray) - assert obs.shape == (7 * 2 + 2 * 3 + 2 * 3, ) - de_obs = env.deconstruct_obs_fn(obs) - assert de_obs.o.shape == (7, 2) - assert de_obs.ag.shape == (2, 3) - assert de_obs.g.shape == (2, 3) - fl_obs = env.flatten_obs_fn(de_obs) - assert np.array_equal(obs, fl_obs) - rew = env.compute_reward_fn(de_obs) - assert rew.shape == () - - # Test batch observation - obs2, _ = env.reset() - assert isinstance(obs, np.ndarray) - assert obs2.shape == (7 * 2 + 2 * 3 + 2 * 3, ) - batch_obs = np.array([obs, obs2]) - de_obs = env.deconstruct_obs_fn(batch_obs) - assert de_obs.o.shape == (2, 7, 2) - assert de_obs.ag.shape == (2, 2, 3) - assert de_obs.g.shape == (2, 2, 3) - fl_obs = env.flatten_obs_fn(de_obs) - assert np.array_equal(batch_obs, fl_obs) - rew = env.compute_reward_fn(de_obs) - assert rew.shape == (2, ) + # check truncate is True when terminated + env_t = TruncatedAsTerminated(env) + _, _, truncated, _, _ = env_t.step(0) + assert truncated @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") @@ -467,4 +423,3 @@ def test_venv_wrapper_envpool_gym_reset_return_info(): test_async_check_id() test_env_reset_optional_kwargs() test_gym_wrappers() - test_goal_env_wrapper() diff --git a/tianshou/data/buffer/her.py b/tianshou/data/buffer/her.py index 910b3646f..3aa9ea9a1 100644 --- a/tianshou/data/buffer/her.py +++ b/tianshou/data/buffer/her.py @@ -10,9 +10,7 @@ class HERReplayBuffer(ReplayBuffer): def __init__( self, size: int, - deconstruct_obs_fn: Callable[[np.ndarray], Batch], - flatten_obs_fn: Callable[[Batch], np.ndarray], - compute_reward_fn: Callable[[Batch], np.ndarray], + compute_reward_fn: Callable[[np.ndarray, np.ndarray, dict], np.ndarray], horizon: int, future_k: float = 8.0, **kwargs: Any, @@ -20,8 +18,6 @@ def __init__( super().__init__(size, **kwargs) self.horizon = horizon self.future_p = 1 - 1 / future_k - self.deconstruct_obs_fn = deconstruct_obs_fn - self.flatten_obs_fn = flatten_obs_fn self.compute_reward_fn = compute_reward_fn self._original_meta = Batch() self._altered_indices = np.array([]) @@ -116,75 +112,47 @@ def rewrite_transitions(self, indices: np.ndarray) -> None: replace=False ) + # Cache original meta + self._altered_indices = unique_ep_indices.copy() + self._original_meta = self._meta[self._altered_indices].copy() + # Copy original obs, ep_rew (and obs_next), and obs of future time step - ep_obs = self._deconstruct_obs(self[unique_ep_indices].obs) + ep_obs = self[unique_ep_indices].obs ep_rew = self[unique_ep_indices].rew if self._save_obs_next: - ep_obs_next = self._deconstruct_obs(self[unique_ep_indices].obs_next) - future_obs = self._deconstruct_obs( - self[future_t[unique_ep_close_indices]].obs_next, lead_dims=1 - ) + ep_obs_next = 
self[unique_ep_indices].obs_next + future_obs = self[future_t[unique_ep_close_indices]].obs_next else: - future_obs = self._deconstruct_obs( - self[self.next(future_t[unique_ep_close_indices])].obs, lead_dims=1 - ) + future_obs = self[self.next(future_t[unique_ep_close_indices])].obs # Re-assign goals and rewards via broadcast assignment - ep_obs.g[:, her_ep_indices] = future_obs.ag[None, her_ep_indices] + ep_obs.desired_goal[:, her_ep_indices] = \ + future_obs.achieved_goal[None, her_ep_indices] if self._save_obs_next: - ep_obs_next.g[:, her_ep_indices] = future_obs.ag[None, her_ep_indices] - ep_rew[:, - her_ep_indices] = self._compute_reward(ep_obs_next)[:, - her_ep_indices] + ep_obs_next.desired_goal[:, her_ep_indices] = \ + future_obs.achieved_goal[None, her_ep_indices] + ep_rew[:, her_ep_indices] = \ + self._compute_reward(ep_obs_next)[:, her_ep_indices] else: - tmp_ep_obs_next = self._deconstruct_obs( - self[self.next(unique_ep_indices)].obs - ) - ep_rew[:, her_ep_indices] = self._compute_reward(tmp_ep_obs_next - )[:, her_ep_indices] + tmp_ep_obs_next = self[self.next(unique_ep_indices)].obs + ep_rew[:, her_ep_indices] = \ + self._compute_reward(tmp_ep_obs_next)[:, her_ep_indices] # Sanity check - assert ep_obs.g.shape[:2] == unique_ep_indices.shape - assert ep_obs.ag.shape[:2] == unique_ep_indices.shape + assert ep_obs.desired_goal.shape[:2] == unique_ep_indices.shape + assert ep_obs.achieved_goal.shape[:2] == unique_ep_indices.shape assert ep_rew.shape == unique_ep_indices.shape assert np.all(future_t >= indices[0]) - # Cache original meta - self._altered_indices = unique_ep_indices.copy() - self._original_meta = self._meta[self._altered_indices].copy() - # Re-write meta - self._meta.obs[unique_ep_indices] = self._flatten_obs(ep_obs) + self._meta.obs[unique_ep_indices] = ep_obs if self._save_obs_next: - self._meta.obs_next[unique_ep_indices] = self._flatten_obs(ep_obs_next) + self._meta.obs_next[unique_ep_indices] = ep_obs_next self._meta.rew[unique_ep_indices] = ep_rew.astype(np.float32) - # Reshaping obs into (bsz, *shape) instead of (..., *shape) before - # calling the provided functions. 
- def _deconstruct_obs(self, obs: np.ndarray, lead_dims: int = 2) -> Batch: - lead_shape = obs.shape[:lead_dims] - flatten_obs = obs.reshape(-1, *obs.shape[lead_dims:]) - de_obs = self.deconstruct_obs_fn(flatten_obs) - de_obs.o = de_obs.o.reshape(*lead_shape, *de_obs.o.shape[1:]) - de_obs.g = de_obs.g.reshape(*lead_shape, *de_obs.g.shape[1:]) - de_obs.ag = de_obs.ag.reshape(*lead_shape, *de_obs.ag.shape[1:]) - return de_obs - - def _flatten_obs(self, de_obs: Batch, lead_dims: int = 2) -> np.ndarray: - lead_shape = de_obs.o.shape[:lead_dims] - de_obs.o = de_obs.o.reshape(-1, *de_obs.o.shape[lead_dims:]) - de_obs.g = de_obs.g.reshape(-1, *de_obs.g.shape[lead_dims:]) - de_obs.ag = de_obs.ag.reshape(-1, *de_obs.ag.shape[lead_dims:]) - flatten_obs = self.flatten_obs_fn(de_obs) - return flatten_obs.reshape(*lead_shape, *flatten_obs.shape[1:]) - - def _compute_reward(self, de_obs: Batch, lead_dims: int = 2) -> np.ndarray: - lead_shape = de_obs.o.shape[:lead_dims] - de_obs.o = de_obs.o.reshape(-1, *de_obs.o.shape[lead_dims:]) - de_obs.g = de_obs.g.reshape(-1, *de_obs.g.shape[lead_dims:]) - de_obs.ag = de_obs.ag.reshape(-1, *de_obs.ag.shape[lead_dims:]) - rewards = self.compute_reward_fn(de_obs) - de_obs.o = de_obs.o.reshape(*lead_shape, *de_obs.o.shape[1:]) - de_obs.g = de_obs.g.reshape(*lead_shape, *de_obs.g.shape[1:]) - de_obs.ag = de_obs.ag.reshape(*lead_shape, *de_obs.ag.shape[1:]) + def _compute_reward(self, obs: Batch, lead_dims: int = 2) -> np.ndarray: + lead_shape = obs.observation.shape[:lead_dims] + g = obs.desired_goal.reshape(-1, *obs.desired_goal.shape[lead_dims:]) + ag = obs.achieved_goal.reshape(-1, *obs.achieved_goal.shape[lead_dims:]) + rewards = self.compute_reward_fn(g, ag, {}) return rewards.reshape(*lead_shape, *rewards.shape[1:]) diff --git a/tianshou/env/__init__.py b/tianshou/env/__init__.py index 82d6c4ace..a00c3cd38 100644 --- a/tianshou/env/__init__.py +++ b/tianshou/env/__init__.py @@ -2,8 +2,8 @@ from tianshou.env.gym_wrappers import ( ContinuousToDiscrete, - GoalEnvWrapper, MultiDiscreteToDiscrete, + TruncatedAsTerminated, ) from tianshou.env.venv_wrappers import VectorEnvNormObs, VectorEnvWrapper from tianshou.env.venvs import ( @@ -30,5 +30,5 @@ "PettingZooEnv", "ContinuousToDiscrete", "MultiDiscreteToDiscrete", - "GoalEnvWrapper", + "TruncatedAsTerminated", ] diff --git a/tianshou/env/gym_wrappers.py b/tianshou/env/gym_wrappers.py index 27c69bb7c..44dfc0718 100644 --- a/tianshou/env/gym_wrappers.py +++ b/tianshou/env/gym_wrappers.py @@ -1,9 +1,9 @@ -from typing import List, Union, Callable +from typing import List, Union import gym import numpy as np -from tianshou.data import Batch +from tianshou.env.utils import gym_new_venv_step_type class ContinuousToDiscrete(gym.ActionWrapper): @@ -59,111 +59,12 @@ def action(self, act: np.ndarray) -> np.ndarray: return np.array(converted_act).transpose() -class GoalEnv(gym.Env): - observation_space: gym.spaces.Space[gym.spaces.Dict] +class TruncatedAsTerminated(gym.Wrapper): - -class GoalEnvWrapper(gym.ObservationWrapper): - - def __init__( - self, - env: GoalEnv, - compute_reward: Callable[[np.ndarray, np.ndarray, dict], np.ndarray], - obs_space_keys: List[str] = ['observation', 'achieved_goal', 'desired_goal'] - # obs_space_keys must be in o, ag, g order. 
- ) -> None: + def __init__(self, env: gym.Env): super().__init__(env) - self.env = env - self.compute_reward = compute_reward - self.obs_space_keys = obs_space_keys - self.original_space: gym.spaces.Space[gym.spaces.Dict] \ - = self.env.observation_space - - self.observation_space = self.calculate_obs_space() - - def calculate_obs_space(self) -> gym.Space: - for k in self.obs_space_keys: - assert isinstance(self.original_space[k], gym.spaces.Box) - new_low = np.concatenate( - [self.original_space[k].low.flatten() for k in self.obs_space_keys], - axis=0 - ) - new_high = np.concatenate( - [self.original_space[k].high.flatten() for k in self.obs_space_keys], - axis=0 - ) - new_shape = np.concatenate( - [self.original_space[k].sample().flatten() for k in self.obs_space_keys], - axis=0 - ).shape - samples = [ - self.original_space[k].sample().flatten() for k in self.obs_space_keys - ] - self.partitions = np.cumsum([0] + [len(s) for s in samples], dtype=int) - return gym.spaces.Box(new_low, new_high, new_shape) - - def deconstruct_obs_fn(self, obs: np.ndarray) -> Batch: - """Deconstruct observation into observation, acheived_goal, goal. The first - dimension (bsz) is optional. - obs: shape(bsz, *observation_shape) - return: Batch( - o=shape(bsz, *o.shape), - ag=shape(bsz, *ag.shape), - g=shape(bsz, *g.shape) - ) or Batch without the first dim (bsz) according to the input. - """ - new_shapes = [ - [*self.original_space[self.obs_space_keys[0]].shape], - [*self.original_space[self.obs_space_keys[1]].shape], - [*self.original_space[self.obs_space_keys[2]].shape], - ] - if len(obs.shape) == 2: - new_shapes = [[-1] + s for s in new_shapes] - batch = Batch( - o=obs[..., self.partitions[0]:self.partitions[1]].reshape(*new_shapes[0]), - ag=obs[..., self.partitions[1]:self.partitions[2]].reshape(*new_shapes[1]), - g=obs[..., self.partitions[2]:self.partitions[3]].reshape(*new_shapes[2]), - ) - return batch - - def flatten_obs_fn(self, obs: Batch) -> np.ndarray: - """Reconstruct observation. The first dim (bsz) is optional - obs: Batch( - o=shape(bsz, *o.shape), - ag=shape(bsz, *ag.shape), - g=shape(bsz, *g.shape) - ) - return: shape(bsz, *observation_shape) - """ - new_shape = [-1] - if len(obs.o.shape) > len(self.original_space[self.obs_space_keys[0]].shape): - bsz = obs.shape[0] - new_shape = [bsz, -1] - return np.concatenate( - [ - obs.o.reshape(*new_shape), - obs.ag.reshape(*new_shape), - obs.g.reshape(*new_shape) - ], - axis=-1 - ) - - def compute_reward_fn(self, obs: Batch) -> np.ndarray: - """Compute rewards from deconstructed obs. 
The first dim (bsz) is optional - obs: Batch( - o=shape(bsz, *o.shape), - ag=shape(bsz, *ag.shape), - g=shape(bsz, *g.shape) - ) - return: shape(bsz,) - """ - ag = obs.ag - g = obs.g - return self.compute_reward(ag, g, {}) - - def observation(self, observation: dict) -> np.ndarray: - o = observation[self.obs_space_keys[0]].flatten() - ag = observation[self.obs_space_keys[1]].flatten() - g = observation[self.obs_space_keys[2]].flatten() - return np.concatenate([o, ag, g]) + def step(self, act: np.ndarray) -> gym_new_venv_step_type: + observation, reward, terminated, truncated, info = super().step(act) + terminated = (terminated or truncated) + return observation, reward, terminated, truncated, info From da2f8b3507e7e85ca36c30e8c4d438df2b06e171 Mon Sep 17 00:00:00 2001 From: juno-t Date: Sun, 2 Oct 2022 10:29:35 +0000 Subject: [PATCH 05/20] debug typehint, add example --- examples/offline/fetch_her_ddpg.py | 256 +++++++++++++++++++++++++++++ tianshou/env/gym_wrappers.py | 6 +- 2 files changed, 258 insertions(+), 4 deletions(-) create mode 100644 examples/offline/fetch_her_ddpg.py diff --git a/examples/offline/fetch_her_ddpg.py b/examples/offline/fetch_her_ddpg.py new file mode 100644 index 000000000..2e37fa446 --- /dev/null +++ b/examples/offline/fetch_her_ddpg.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 + +import argparse +import datetime +import os +import pprint +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector, HERReplayBuffer, HERVectorReplayBuffer +from tianshou.data.batch import Batch +from tianshou.env import ShmemVectorEnv, TruncatedAsTerminated +from tianshou.exploration import GaussianNoise +from tianshou.policy import DDPGPolicy +from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger, WandbLogger +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import Actor, Critic + + +def get_args(): + # python3 fetch_her_ddpg.py --task FetchReach-v3 --seed 0 --horizon 50 + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="FetchReach-v3") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--buffer-size", type=int, default=1000000) + parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) + parser.add_argument("--actor-lr", type=float, default=1e-3) + parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--tau", type=float, default=0.005) + parser.add_argument("--exploration-noise", type=float, default=0.1) + parser.add_argument("--start-timesteps", type=int, default=25000) + parser.add_argument("--epoch", type=int, default=200) + parser.add_argument("--step-per-epoch", type=int, default=5000) + parser.add_argument("--step-per-collect", type=int, default=1) + parser.add_argument("--update-per-step", type=int, default=1) + parser.add_argument("--n-step", type=int, default=1) + parser.add_argument("--batch-size", type=int, default=256) + parser.add_argument("--horizon", type=int, default=50) + parser.add_argument("--future-k", type=int, default=8) + parser.add_argument("--training-num", type=int, default=1) + parser.add_argument("--test-num", type=int, default=10) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.) 
+ parser.add_argument( + "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu" + ) + parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--resume-id", type=str, default=None) + parser.add_argument( + "--logger", + type=str, + default="tensorboard", + choices=["tensorboard", "wandb"], + ) + parser.add_argument("--wandb-project", type=str, default="HER-benchmark") + parser.add_argument( + "--watch", + default=False, + action="store_true", + help="watch the play of pre-trained policy only", + ) + return parser.parse_args() + + +def make_fetch_env(task, training_num, test_num): + env = TruncatedAsTerminated(gym.make(task)) + train_envs = ShmemVectorEnv( + [lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(training_num)] + ) + test_envs = ShmemVectorEnv( + [lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(test_num)] + ) + return env, train_envs, test_envs + + +class DictStateNet(Net): + + def __init__( + self, state_shape: Dict[str, Union[int, Sequence[int]]], keys: Sequence[str], + **kwargs + ) -> None: + self.keys = keys + self.original_shape = state_shape + flat_state_shape = [] + for k in self.keys: + flat_state_shape.append(int(np.prod(state_shape[k]))) + super().__init__(sum(flat_state_shape), **kwargs) + + def preprocess_obs(self, obs): + if isinstance(obs, dict) or (isinstance(obs, Batch) and self.keys[0] in obs): + if self.original_shape[self.keys[0]] == obs[self.keys[0]].shape: + # No batch dim + new_obs = torch.Tensor([obs[k] for k in self.keys]).flatten() + # new_obs = torch.Tensor([obs[k] for k in self.keys]).reshape(1, -1) + else: + bsz = obs[self.keys[0]].shape[0] + new_obs = torch.cat( + [torch.Tensor(obs[k].reshape(bsz, -1)) for k in self.keys], axis=1 + ) + else: + new_obs = obs + return new_obs + + def forward( + self, + obs: Union[Dict[str, Union[np.ndarray, torch.Tensor]], Union[np.ndarray, + torch.Tensor]], + state: Any = None, + info: Dict[str, Any] = {}, + ) -> Tuple[torch.Tensor, Any]: + return super().forward(self.preprocess_obs(obs), state, info) + + +class GoalStateCritic(Critic): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + obs: Union[np.ndarray, torch.Tensor], + act: Optional[Union[np.ndarray, torch.Tensor]] = None, + info: Dict[str, Any] = {}, + ) -> torch.Tensor: + return super().forward(self.preprocess.preprocess_obs(obs), act, info) + + +def test_ddpg(args=get_args()): + env, train_envs, test_envs = make_fetch_env( + args.task, args.training_num, args.test_num + ) + args.state_shape = { + 'observation': env.observation_space['observation'].shape, + 'achieved_goal': env.observation_space['achieved_goal'].shape, + 'desired_goal': env.observation_space['desired_goal'].shape, + } + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] + args.exploration_noise = args.exploration_noise * args.max_action + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + # model + net_a = DictStateNet( + args.state_shape, + keys=['observation', 'achieved_goal', 'desired_goal'], + hidden_sizes=args.hidden_sizes, + device=args.device + ) + actor = Actor( + net_a, args.action_shape, max_action=args.max_action, device=args.device + ).to(args.device) + actor_optim = torch.optim.Adam(actor.parameters(), 
lr=args.actor_lr) + net_c = DictStateNet( + args.state_shape, + keys=['observation', 'achieved_goal', 'desired_goal'], + action_shape=args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + ) + critic = GoalStateCritic(net_c, device=args.device).to(args.device) + critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) + policy = DDPGPolicy( + actor, + actor_optim, + critic, + critic_optim, + tau=args.tau, + gamma=args.gamma, + exploration_noise=GaussianNoise(sigma=args.exploration_noise), + estimation_step=args.n_step, + action_space=env.action_space, + ) + + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + + # collector + if args.training_num > 1: + buffer = HERVectorReplayBuffer(args.buffer_size, len(train_envs)) + else: + buffer = HERReplayBuffer( + args.buffer_size, + compute_reward_fn=env.compute_reward, + horizon=args.horizon, + future_k=args.future_k, + ) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs) + train_collector.collect(n_step=args.start_timesteps, random=True) + + # log + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "ddpg" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + + # logger + if args.logger == "wandb": + logger = WandbLogger( + save_interval=1, + name=log_name.replace(os.path.sep, "__"), + run_id=args.resume_id, + config=args, + project=args.wandb_project, + ) + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: # wandb + logger.load(writer) + + def save_best_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + + if not args.watch: + # trainer + result = offpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + save_best_fn=save_best_fn, + logger=logger, + update_per_step=args.update_per_step, + test_in_train=False, + ) + pprint.pprint(result) + + # Let's watch its performance! + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, render=args.render) + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + + +if __name__ == "__main__": + test_ddpg() diff --git a/tianshou/env/gym_wrappers.py b/tianshou/env/gym_wrappers.py index 44dfc0718..d70bbbdd0 100644 --- a/tianshou/env/gym_wrappers.py +++ b/tianshou/env/gym_wrappers.py @@ -1,10 +1,8 @@ -from typing import List, Union +from typing import Any, Dict, List, Tuple, Union import gym import numpy as np -from tianshou.env.utils import gym_new_venv_step_type - class ContinuousToDiscrete(gym.ActionWrapper): """Gym environment wrapper to take discrete action in a continuous environment. 
@@ -64,7 +62,7 @@ class TruncatedAsTerminated(gym.Wrapper): def __init__(self, env: gym.Env): super().__init__(env) - def step(self, act: np.ndarray) -> gym_new_venv_step_type: + def step(self, act: np.ndarray) -> Tuple[Any, float, bool, bool, Dict[Any, Any]]: observation, reward, terminated, truncated, info = super().step(act) terminated = (terminated or truncated) return observation, reward, terminated, truncated, info From 7fc49b368158f0931c54726b7f8f9e297cf844b0 Mon Sep 17 00:00:00 2001 From: Juno T <42699114+Juno-T@users.noreply.github.com> Date: Sun, 2 Oct 2022 20:27:41 +0900 Subject: [PATCH 06/20] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 32b0c3b00..29f355640 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,7 @@ - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf) - [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf) - [Intrinsic Curiosity Module (ICM)](https://arxiv.org/pdf/1705.05363.pdf) +- [Hindsight Experience Replay (HER)](https://arxiv.org/pdf/1707.01495.pdf) Here are Tianshou's other features: From 4dcb98d09653ca9c9b77100c8546d956452e3e26 Mon Sep 17 00:00:00 2001 From: juno-t Date: Sun, 2 Oct 2022 12:25:27 +0000 Subject: [PATCH 07/20] add docstring --- tianshou/data/buffer/her.py | 44 ++++++++++++++++++++++++--------- tianshou/data/buffer/manager.py | 4 +++ tianshou/data/buffer/vecbuf.py | 4 +++ tianshou/env/gym_wrappers.py | 6 +++++ 4 files changed, 47 insertions(+), 11 deletions(-) diff --git a/tianshou/data/buffer/her.py b/tianshou/data/buffer/her.py index 3aa9ea9a1..01595071a 100644 --- a/tianshou/data/buffer/her.py +++ b/tianshou/data/buffer/her.py @@ -6,6 +6,24 @@ class HERReplayBuffer(ReplayBuffer): + """Implementation of Hindsight Experience Replay. arXiv:1707.01495. + + Currently support only 'future' strategy of HER. + + :param int size: the size of the replay buffer. + :param compute_reward_fn: a function that takes 3 arguments: \ + `acheived_goal`, `desired_goal`, `info` and returns the reward(s). + Note that the goal arguments can have extra batch_size dimension and in that \ + case, the rewards of size batch_size should be returned + :param int horizon: the maximum number of steps in an episode. + :param int future_k: the 'k' parameter introduced in the paper. In short, there \ + will be at most k episodes that are re-written for every 1 unaltered episode \ + during the sampling. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. + """ def __init__( self, @@ -23,9 +41,9 @@ def __init__( self._altered_indices = np.array([]) def _restore_cache(self) -> None: - """ - Write cached original meta back to self._meta - Do this everytime before 'writing', 'sampling' or 'saving' the buffer. + """Write cached original meta back to `self._meta`. + + It's called everytime before 'writing', 'sampling' or 'saving' the buffer. """ if not hasattr(self, '_altered_indices'): return @@ -64,8 +82,11 @@ def add( def sample_indices(self, batch_size: int) -> np.ndarray: """Get a random sample of index with size = batch_size. - Return all available indices in the buffer if batch_size is 0; return an empty - numpy array if batch_size < 0 or no available index can be sampled. + + Return all available indices in the buffer if batch_size is 0; return an \ + empty numpy array if batch_size < 0 or no available index can be sampled. 
\ + Additionally, some episodes of the sampled transitions will be re-written \ + according to HER. """ self._restore_cache() indices = np.sort(super().sample_indices(batch_size=batch_size)) @@ -73,12 +94,13 @@ def sample_indices(self, batch_size: int) -> np.ndarray: return indices def rewrite_transitions(self, indices: np.ndarray) -> None: - """ Re-write the goal of some sampled transitions' episodes according to HER's - 'future' strategy. The new goals will be written directly to the internal - batch data temporarily and will be restored right before the next sampling or - when using some of the buffer's method (such as `add` or `save_hdf5`). This is - to make sure that n-step returns calculation etc. performs correctly without - alteration. + """Re-write the goal of some sampled transitions' episodes according to HER. + + Currently applies only HER's 'future' strategy. The new goals will be written \ + directly to the internal batch data temporarily and will be restored right \ + before the next sampling or when using some of the buffer's method (e.g. \ + `add`, `save_hdf5`, etc.). This is to make sure that n-step returns \ + calculation etc., performs correctly without additional alteration. """ if indices.size == 0: return diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py index baa00af93..b694c1abe 100644 --- a/tianshou/data/buffer/manager.py +++ b/tianshou/data/buffer/manager.py @@ -217,10 +217,14 @@ def __init__(self, buffer_list: Sequence[PrioritizedReplayBuffer]) -> None: class HERReplayBufferManager(ReplayBufferManager): """HERReplayBufferManager contains a list of HERReplayBuffer with \ exactly the same configuration. + These replay buffers have contiguous memory layout, and the storage space each buffer has is a shallow copy of the topmost memory. + :param buffer_list: a list of HERReplayBuffer needed to be handled. + .. seealso:: + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. """ diff --git a/tianshou/data/buffer/vecbuf.py b/tianshou/data/buffer/vecbuf.py index 0457e21fc..a08ad9857 100644 --- a/tianshou/data/buffer/vecbuf.py +++ b/tianshou/data/buffer/vecbuf.py @@ -70,12 +70,16 @@ def set_beta(self, beta: float) -> None: class HERVectorReplayBuffer(HERReplayBufferManager): """HERVectorReplayBuffer contains n HERReplayBuffer with same size. + It is used for storing transition from different environments yet keeping the order of time. + :param int total_size: the total size of HERVectorReplayBuffer. :param int buffer_num: the number of HERReplayBuffer it uses, which are under the same configuration. + Other input arguments are the same as :class:`~tianshou.data.HERReplayBuffer`. + .. seealso:: Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. """ diff --git a/tianshou/env/gym_wrappers.py b/tianshou/env/gym_wrappers.py index d70bbbdd0..8f5e4c01c 100644 --- a/tianshou/env/gym_wrappers.py +++ b/tianshou/env/gym_wrappers.py @@ -58,6 +58,12 @@ def action(self, act: np.ndarray) -> np.ndarray: class TruncatedAsTerminated(gym.Wrapper): + """A wrapper that set `terminated = terminated or truncated` for `step()`. + + It's intended to use with `gym.wrappers.TimeLimit`. + + :param gym.Env env: gym environment. 
+ """ def __init__(self, env: gym.Env): super().__init__(env) From 1071dab0dfd12810e278f60aa29d3ea7779ce46a Mon Sep 17 00:00:00 2001 From: juno-t Date: Mon, 3 Oct 2022 00:36:55 +0000 Subject: [PATCH 08/20] correct wrapper test --- test/base/test_env.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/base/test_env.py b/test/base/test_env.py index 3e6b20c49..328ca6dc5 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -348,9 +348,10 @@ def __init__(self): self.action_space = gym.spaces.Box( low=-1.0, high=2.0, shape=(4, ), dtype=np.float32 ) + self.observation_space = gym.spaces.Discrete(2) def step(self, act): - return np.array([0]), -1, False, True, {} + return self.observation_space.sample(), -1, False, True, {} bsz = 10 action_per_branch = [4, 6, 10, 7] @@ -380,7 +381,7 @@ def step(self, act): ) # check truncate is True when terminated env_t = TruncatedAsTerminated(env) - _, _, truncated, _, _ = env_t.step(0) + _, _, truncated, _, _ = env_t.step(env_t.action_space.sample()) assert truncated From 63e65c087fa4c070861f12ba4d5fea130fc56740 Mon Sep 17 00:00:00 2001 From: juno-t Date: Wed, 5 Oct 2022 08:45:08 +0000 Subject: [PATCH 09/20] update from feedback --- examples/offline/fetch_her_ddpg.py | 13 +++++++++++-- test/base/test_buffer.py | 8 ++++++-- test/base/test_env.py | 10 +++++++--- tianshou/data/buffer/her.py | 17 +++++++++-------- tianshou/env/gym_wrappers.py | 6 ++++++ 5 files changed, 39 insertions(+), 15 deletions(-) diff --git a/examples/offline/fetch_her_ddpg.py b/examples/offline/fetch_her_ddpg.py index 2e37fa446..c08a6ff21 100644 --- a/examples/offline/fetch_her_ddpg.py +++ b/examples/offline/fetch_her_ddpg.py @@ -188,12 +188,21 @@ def test_ddpg(args=get_args()): print("Loaded agent from: ", args.resume_path) # collector + def compute_reward_fn(ag: np.ndarray, g: np.ndarray): + return env.compute_reward(ag, g, {}) + if args.training_num > 1: - buffer = HERVectorReplayBuffer(args.buffer_size, len(train_envs)) + buffer = HERVectorReplayBuffer( + args.buffer_size, + len(train_envs), + compute_reward_fn=compute_reward_fn, + horizon=args.horizon, + future_k=args.future_k, + ) else: buffer = HERReplayBuffer( args.buffer_size, - compute_reward_fn=env.compute_reward, + compute_reward_fn=compute_reward_fn, horizon=args.horizon, future_k=args.future_k, ) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 00f62a75e..1cc2458ab 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -305,13 +305,17 @@ def test_priortized_replaybuffer(size=32, bufsize=15): def test_herreplaybuffer(size=10, bufsize=100, sample_sz=4): env_size = size env = MyGoalEnv(env_size, array_state=True) + + def compute_reward_fn(ag, g): + return env.compute_reward_fn(ag, g, {}) + buf = HERReplayBuffer( - bufsize, compute_reward_fn=env.compute_reward_fn, horizon=30, future_k=8 + bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8 ) buf2 = HERVectorReplayBuffer( bufsize, buffer_num=3, - compute_reward_fn=env.compute_reward_fn, + compute_reward_fn=compute_reward_fn, horizon=30, future_k=8 ) diff --git a/test/base/test_env.py b/test/base/test_env.py index 328ca6dc5..1c91c0fa1 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -380,9 +380,13 @@ def step(self, act): np.array([env_m.action_space.nvec - 1] * bsz), ) # check truncate is True when terminated - env_t = TruncatedAsTerminated(env) - _, _, truncated, _, _ = env_t.step(env_t.action_space.sample()) - assert truncated + try: + env_t = TruncatedAsTerminated(env) 
+ except EnvironmentError: + env_t = None + if env_t is not None: + _, _, truncated, _, _ = env_t.step(env_t.action_space.sample()) + assert truncated @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") diff --git a/tianshou/data/buffer/her.py b/tianshou/data/buffer/her.py index 01595071a..512322b30 100644 --- a/tianshou/data/buffer/her.py +++ b/tianshou/data/buffer/her.py @@ -11,10 +11,10 @@ class HERReplayBuffer(ReplayBuffer): Currently support only 'future' strategy of HER. :param int size: the size of the replay buffer. - :param compute_reward_fn: a function that takes 3 arguments: \ - `acheived_goal`, `desired_goal`, `info` and returns the reward(s). - Note that the goal arguments can have extra batch_size dimension and in that \ - case, the rewards of size batch_size should be returned + :param compute_reward_fn: a function that takes 2 `np.array` arguments, \ + `acheived_goal` and `desired_goal`, and returns rewards as `np.array`. + The two arguments are of shape (batch_size, *original_shape) and the returned \ + rewards must be of shape (batch_size,). :param int horizon: the maximum number of steps in an episode. :param int future_k: the 'k' parameter introduced in the paper. In short, there \ will be at most k episodes that are re-written for every 1 unaltered episode \ @@ -28,7 +28,7 @@ class HERReplayBuffer(ReplayBuffer): def __init__( self, size: int, - compute_reward_fn: Callable[[np.ndarray, np.ndarray, dict], np.ndarray], + compute_reward_fn: Callable[[np.ndarray, np.ndarray], np.ndarray], horizon: int, future_k: float = 8.0, **kwargs: Any, @@ -52,7 +52,6 @@ def _restore_cache(self) -> None: return self._meta[self._altered_indices] = self._original_meta # Clean - del self._original_meta, self._altered_indices self._original_meta = Batch() self._altered_indices = np.array([]) @@ -89,7 +88,7 @@ def sample_indices(self, batch_size: int) -> np.ndarray: according to HER. 
""" self._restore_cache() - indices = np.sort(super().sample_indices(batch_size=batch_size)) + indices = super().sample_indices(batch_size=batch_size) self.rewrite_transitions(indices) return indices @@ -104,6 +103,8 @@ def rewrite_transitions(self, indices: np.ndarray) -> None: """ if indices.size == 0: return + indices = np.sort(indices) + # Construct episode trajectories indices = [indices] for _ in range(self.horizon - 1): @@ -176,5 +177,5 @@ def _compute_reward(self, obs: Batch, lead_dims: int = 2) -> np.ndarray: lead_shape = obs.observation.shape[:lead_dims] g = obs.desired_goal.reshape(-1, *obs.desired_goal.shape[lead_dims:]) ag = obs.achieved_goal.reshape(-1, *obs.achieved_goal.shape[lead_dims:]) - rewards = self.compute_reward_fn(g, ag, {}) + rewards = self.compute_reward_fn(ag, g) return rewards.reshape(*lead_shape, *rewards.shape[1:]) diff --git a/tianshou/env/gym_wrappers.py b/tianshou/env/gym_wrappers.py index 8f5e4c01c..8a035194c 100644 --- a/tianshou/env/gym_wrappers.py +++ b/tianshou/env/gym_wrappers.py @@ -2,6 +2,7 @@ import gym import numpy as np +from packaging import version class ContinuousToDiscrete(gym.ActionWrapper): @@ -67,6 +68,11 @@ class TruncatedAsTerminated(gym.Wrapper): def __init__(self, env: gym.Env): super().__init__(env) + if not version.parse(gym.__version__) >= version.parse('0.26.0'): + raise EnvironmentError( + f"TruncatedAsTerminated is not applicable with gym version \ + {gym.__version__}" + ) def step(self, act: np.ndarray) -> Tuple[Any, float, bool, bool, Dict[Any, Any]]: observation, reward, terminated, truncated, info = super().step(act) From 9d52936b6b368abbdd8736835300880fdeeda101 Mon Sep 17 00:00:00 2001 From: juno-t Date: Wed, 5 Oct 2022 09:03:24 +0000 Subject: [PATCH 10/20] add doc --- docs/index.rst | 1 + tianshou/data/buffer/her.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/index.rst b/docs/index.rst index 7fce12f6b..09098cb94 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -40,6 +40,7 @@ Welcome to Tianshou! * :class:`~tianshou.policy.ICMPolicy` `Intrinsic Curiosity Module `_ * :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay `_ * :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator `_ +* :class:`~tianshou.data.HERReplayBuffer` `Hindsight Experience Replay `_ Here is Tianshou's other features: diff --git a/tianshou/data/buffer/her.py b/tianshou/data/buffer/her.py index 512322b30..a4df2d936 100644 --- a/tianshou/data/buffer/her.py +++ b/tianshou/data/buffer/her.py @@ -8,7 +8,9 @@ class HERReplayBuffer(ReplayBuffer): """Implementation of Hindsight Experience Replay. arXiv:1707.01495. - Currently support only 'future' strategy of HER. + `HERReplayBuffer` is to be used with goal-based environment where the \ + observation is a dictionary with keys `observation`, `achieved_goal` and \ + `desired_goal`. Currently support only HER's future strategy, online sampling. :param int size: the size of the replay buffer. 
:param compute_reward_fn: a function that takes 2 `np.array` arguments, \ From 6ee7d08f414456ce2274e8ca0e6714e5cbe420c7 Mon Sep 17 00:00:00 2001 From: juno-t Date: Thu, 6 Oct 2022 04:08:52 +0000 Subject: [PATCH 11/20] fix indices calculation, add test --- test/base/test_buffer.py | 62 +++++++++++++++++++++++++++++-------- tianshou/data/buffer/her.py | 14 +++++---- 2 files changed, 57 insertions(+), 19 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 1cc2458ab..4ed6f74b5 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -401,6 +401,42 @@ def compute_reward_fn(ag, g): assert np.all(g == g[0]) tmp_indices = buf2.next(tmp_indices) + # Test handling cycled indices + env_size = size + bufsize = 15 + env = MyGoalEnv(env_size, array_state=False) + + def compute_reward_fn(ag, g): + return env.compute_reward_fn(ag, g, {}) + + buf = HERReplayBuffer( + bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8 + ) + buf._index = 5 # shifted start index + buf.future_p = 1 + action_list = [1] * 10 + for ep_len in [5, 10]: + obs, _ = env.reset() + for i in range(ep_len): + act = 1 + obs_next, rew, terminated, truncated, info = env.step(act) + batch = Batch( + obs=obs, + act=[act], + rew=rew, + terminated=(i == ep_len - 1), + truncated=(i == ep_len - 1), + obs_next=obs_next, + info=info + ) + buf.add(batch) + obs = obs_next + batch, indices = buf.sample(0) + assert np.all(buf[:5].obs.desired_goal == buf[0].obs.desired_goal) + assert np.all(buf[5:10].obs.desired_goal == buf[5].obs.desired_goal) + assert np.all(buf[10:].obs.desired_goal == buf[0].obs.desired_goal) # (same ep) + assert np.all(buf[0].obs.desired_goal != buf[5].obs.desired_goal) # (diff ep) + def test_update(): buf1 = ReplayBuffer(4, stack_num=2) @@ -1269,17 +1305,17 @@ def test_from_data(): if __name__ == '__main__': - test_replaybuffer() - test_ignore_obs_next() - test_stack() - test_segtree() - test_priortized_replaybuffer() - test_update() - test_pickle() - test_hdf5() - test_replaybuffermanager() - test_cachedbuffer() - test_multibuf_stack() - test_multibuf_hdf5() - test_from_data() + # test_replaybuffer() + # test_ignore_obs_next() + # test_stack() + # test_segtree() + # test_priortized_replaybuffer() + # test_update() + # test_pickle() + # test_hdf5() + # test_replaybuffermanager() + # test_cachedbuffer() + # test_multibuf_stack() + # test_multibuf_hdf5() + # test_from_data() test_herreplaybuffer() diff --git a/tianshou/data/buffer/her.py b/tianshou/data/buffer/her.py index a4df2d936..1f6f0c8c3 100644 --- a/tianshou/data/buffer/her.py +++ b/tianshou/data/buffer/her.py @@ -91,7 +91,7 @@ def sample_indices(self, batch_size: int) -> np.ndarray: """ self._restore_cache() indices = super().sample_indices(batch_size=batch_size) - self.rewrite_transitions(indices) + self.rewrite_transitions(indices.copy()) return indices def rewrite_transitions(self, indices: np.ndarray) -> None: @@ -105,7 +105,12 @@ def rewrite_transitions(self, indices: np.ndarray) -> None: """ if indices.size == 0: return + + # Sort indices keeping chronological order + mask = indices >= self._index + indices[~mask] += self.maxsize indices = np.sort(indices) + indices[indices >= self.maxsize] -= self.maxsize # Construct episode trajectories indices = [indices] @@ -126,10 +131,8 @@ def rewrite_transitions(self, indices: np.ndarray) -> None: unique_ep_open_indices = np.unique(terminal, return_index=True)[1] unique_ep_indices = indices[:, unique_ep_open_indices] # close indices are used to find max future_t 
among presented episodes - unique_ep_close_indices = np.hstack( - [(unique_ep_open_indices - 1)[1:], - len(terminal) - 1] - ) + unique_ep_close_indices = unique_ep_open_indices - 1 + unique_ep_close_indices[unique_ep_close_indices < 0] += len(indices[0]) # episode indices that will be altered her_ep_indices = np.random.choice( len(unique_ep_open_indices), @@ -167,7 +170,6 @@ def rewrite_transitions(self, indices: np.ndarray) -> None: assert ep_obs.desired_goal.shape[:2] == unique_ep_indices.shape assert ep_obs.achieved_goal.shape[:2] == unique_ep_indices.shape assert ep_rew.shape == unique_ep_indices.shape - assert np.all(future_t >= indices[0]) # Re-write meta self._meta.obs[unique_ep_indices] = ep_obs From 7802451a0200bdf4e4c0399a204038d9dcf74925 Mon Sep 17 00:00:00 2001 From: juno-t Date: Thu, 6 Oct 2022 04:19:19 +0000 Subject: [PATCH 12/20] update doc, uncomment tests --- test/base/test_buffer.py | 26 +++++++++++++------------- tianshou/data/buffer/her.py | 16 ++++++++-------- tianshou/env/gym_wrappers.py | 4 ++-- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 4ed6f74b5..02d140d1e 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -1305,17 +1305,17 @@ def test_from_data(): if __name__ == '__main__': - # test_replaybuffer() - # test_ignore_obs_next() - # test_stack() - # test_segtree() - # test_priortized_replaybuffer() - # test_update() - # test_pickle() - # test_hdf5() - # test_replaybuffermanager() - # test_cachedbuffer() - # test_multibuf_stack() - # test_multibuf_hdf5() - # test_from_data() + test_replaybuffer() + test_ignore_obs_next() + test_stack() + test_segtree() + test_priortized_replaybuffer() + test_update() + test_pickle() + test_hdf5() + test_replaybuffermanager() + test_cachedbuffer() + test_multibuf_stack() + test_multibuf_hdf5() + test_from_data() test_herreplaybuffer() diff --git a/tianshou/data/buffer/her.py b/tianshou/data/buffer/her.py index 1f6f0c8c3..ab1a64163 100644 --- a/tianshou/data/buffer/her.py +++ b/tianshou/data/buffer/her.py @@ -8,18 +8,18 @@ class HERReplayBuffer(ReplayBuffer): """Implementation of Hindsight Experience Replay. arXiv:1707.01495. - `HERReplayBuffer` is to be used with goal-based environment where the \ - observation is a dictionary with keys `observation`, `achieved_goal` and \ - `desired_goal`. Currently support only HER's future strategy, online sampling. + HERReplayBuffer is to be used with goal-based environment where the + observation is a dictionary with keys ``observation``, ``achieved_goal`` and + ``desired_goal``. Currently support only HER's future strategy, online sampling. :param int size: the size of the replay buffer. - :param compute_reward_fn: a function that takes 2 `np.array` arguments, \ - `acheived_goal` and `desired_goal`, and returns rewards as `np.array`. - The two arguments are of shape (batch_size, *original_shape) and the returned \ + :param compute_reward_fn: a function that takes 2 ``np.array`` arguments, + ``acheived_goal`` and ``desired_goal``, and returns rewards as ``np.array``. + The two arguments are of shape (batch_size, *original_shape) and the returned rewards must be of shape (batch_size,). :param int horizon: the maximum number of steps in an episode. - :param int future_k: the 'k' parameter introduced in the paper. In short, there \ - will be at most k episodes that are re-written for every 1 unaltered episode \ + :param int future_k: the 'k' parameter introduced in the paper. 
In short, there + will be at most k episodes that are re-written for every 1 unaltered episode during the sampling. .. seealso:: diff --git a/tianshou/env/gym_wrappers.py b/tianshou/env/gym_wrappers.py index 8a035194c..b906b79b9 100644 --- a/tianshou/env/gym_wrappers.py +++ b/tianshou/env/gym_wrappers.py @@ -59,9 +59,9 @@ def action(self, act: np.ndarray) -> np.ndarray: class TruncatedAsTerminated(gym.Wrapper): - """A wrapper that set `terminated = terminated or truncated` for `step()`. + """A wrapper that set ``terminated = terminated or truncated`` for ``step()``. - It's intended to use with `gym.wrappers.TimeLimit`. + It's intended to use with ``gym.wrappers.TimeLimit``. :param gym.Env env: gym environment. """ From 808760a1f64bb96f270a0d2977de5a6fde8eff50 Mon Sep 17 00:00:00 2001 From: juno-t Date: Sat, 22 Oct 2022 04:45:41 +0000 Subject: [PATCH 13/20] reorganize --- examples/{offline => mujoco}/fetch_her_ddpg.py | 0 tianshou/data/buffer/her.py | 3 +-- 2 files changed, 1 insertion(+), 2 deletions(-) rename examples/{offline => mujoco}/fetch_her_ddpg.py (100%) diff --git a/examples/offline/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py similarity index 100% rename from examples/offline/fetch_her_ddpg.py rename to examples/mujoco/fetch_her_ddpg.py diff --git a/tianshou/data/buffer/her.py b/tianshou/data/buffer/her.py index ab1a64163..1d1a42c26 100644 --- a/tianshou/data/buffer/her.py +++ b/tianshou/data/buffer/her.py @@ -107,8 +107,7 @@ def rewrite_transitions(self, indices: np.ndarray) -> None: return # Sort indices keeping chronological order - mask = indices >= self._index - indices[~mask] += self.maxsize + indices[indices < self._index] += self.maxsize indices = np.sort(indices) indices[indices >= self.maxsize] -= self.maxsize From 3a98932595ea2f32ab6023f9e63eadcb1a9a3215 Mon Sep 17 00:00:00 2001 From: juno-t Date: Tue, 25 Oct 2022 10:47:13 +0000 Subject: [PATCH 14/20] debug her --- tianshou/data/buffer/her.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tianshou/data/buffer/her.py b/tianshou/data/buffer/her.py index 1d1a42c26..d2931ec61 100644 --- a/tianshou/data/buffer/her.py +++ b/tianshou/data/buffer/her.py @@ -127,11 +127,13 @@ def rewrite_transitions(self, indices: np.ndarray) -> None: # Compute indices # open indices are used to find longest, unique trajectories among # presented episodes - unique_ep_open_indices = np.unique(terminal, return_index=True)[1] + unique_ep_open_indices = np.sort(np.unique(terminal, return_index=True)[1]) unique_ep_indices = indices[:, unique_ep_open_indices] # close indices are used to find max future_t among presented episodes - unique_ep_close_indices = unique_ep_open_indices - 1 - unique_ep_close_indices[unique_ep_close_indices < 0] += len(indices[0]) + unique_ep_close_indices = np.hstack( + [(unique_ep_open_indices - 1)[1:], + len(terminal) - 1] + ) # episode indices that will be altered her_ep_indices = np.random.choice( len(unique_ep_open_indices), From 4a88d7646b0cb8e70d7a31d05970f317cbdf1ee9 Mon Sep 17 00:00:00 2001 From: juno-t Date: Thu, 27 Oct 2022 05:37:10 +0000 Subject: [PATCH 15/20] refactor example --- examples/mujoco/fetch_her_ddpg.py | 180 ++++++++++++------------------ tianshou/utils/net/common.py | 51 +++++++++ 2 files changed, 125 insertions(+), 106 deletions(-) diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index c08a6ff21..d133328f7 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -4,45 +4,55 @@ import datetime 
import os import pprint -from typing import Any, Dict, Optional, Sequence, Tuple, Union import gym import numpy as np +import wandb import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, HERReplayBuffer, HERVectorReplayBuffer -from tianshou.data.batch import Batch +from tianshou.data import ( + Collector, + HERReplayBuffer, + HERVectorReplayBuffer, + ReplayBuffer, + VectorReplayBuffer +) from tianshou.env import ShmemVectorEnv, TruncatedAsTerminated from tianshou.exploration import GaussianNoise from tianshou.policy import DDPGPolicy from tianshou.trainer import offpolicy_trainer from tianshou.utils import TensorboardLogger, WandbLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import Net, get_dict_state_decorator from tianshou.utils.net.continuous import Actor, Critic def get_args(): - # python3 fetch_her_ddpg.py --task FetchReach-v3 --seed 0 --horizon 50 parser = argparse.ArgumentParser() parser.add_argument("--task", type=str, default="FetchReach-v3") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--buffer-size", type=int, default=1000000) + parser.add_argument("--buffer-size", type=int, default=100000) parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256]) parser.add_argument("--actor-lr", type=float, default=1e-3) - parser.add_argument("--critic-lr", type=float, default=1e-3) + parser.add_argument("--critic-lr", type=float, default=3e-3) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--tau", type=float, default=0.005) parser.add_argument("--exploration-noise", type=float, default=0.1) parser.add_argument("--start-timesteps", type=int, default=25000) - parser.add_argument("--epoch", type=int, default=200) + parser.add_argument("--epoch", type=int, default=10) parser.add_argument("--step-per-epoch", type=int, default=5000) parser.add_argument("--step-per-collect", type=int, default=1) parser.add_argument("--update-per-step", type=int, default=1) parser.add_argument("--n-step", type=int, default=1) - parser.add_argument("--batch-size", type=int, default=256) - parser.add_argument("--horizon", type=int, default=50) - parser.add_argument("--future-k", type=int, default=8) + parser.add_argument("--batch-size", type=int, default=512) + parser.add_argument( + "--replay-buffer", + type=str, + default="her", + choices=["normal", "her"] + ) + parser.add_argument("--her-horizon", type=int, default=50) + parser.add_argument("--her-future-k", type=int, default=8) parser.add_argument("--training-num", type=int, default=1) parser.add_argument("--test-num", type=int, default=10) parser.add_argument("--logdir", type=str, default="log") @@ -79,59 +89,31 @@ def make_fetch_env(task, training_num, test_num): return env, train_envs, test_envs -class DictStateNet(Net): - - def __init__( - self, state_shape: Dict[str, Union[int, Sequence[int]]], keys: Sequence[str], - **kwargs - ) -> None: - self.keys = keys - self.original_shape = state_shape - flat_state_shape = [] - for k in self.keys: - flat_state_shape.append(int(np.prod(state_shape[k]))) - super().__init__(sum(flat_state_shape), **kwargs) - - def preprocess_obs(self, obs): - if isinstance(obs, dict) or (isinstance(obs, Batch) and self.keys[0] in obs): - if self.original_shape[self.keys[0]] == obs[self.keys[0]].shape: - # No batch dim - new_obs = torch.Tensor([obs[k] for k in self.keys]).flatten() - # new_obs = torch.Tensor([obs[k] for k in self.keys]).reshape(1, -1) - else: - bsz = 
obs[self.keys[0]].shape[0] - new_obs = torch.cat( - [torch.Tensor(obs[k].reshape(bsz, -1)) for k in self.keys], axis=1 - ) - else: - new_obs = obs - return new_obs - - def forward( - self, - obs: Union[Dict[str, Union[np.ndarray, torch.Tensor]], Union[np.ndarray, - torch.Tensor]], - state: Any = None, - info: Dict[str, Any] = {}, - ) -> Tuple[torch.Tensor, Any]: - return super().forward(self.preprocess_obs(obs), state, info) - - -class GoalStateCritic(Critic): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward( - self, - obs: Union[np.ndarray, torch.Tensor], - act: Optional[Union[np.ndarray, torch.Tensor]] = None, - info: Dict[str, Any] = {}, - ) -> torch.Tensor: - return super().forward(self.preprocess.preprocess_obs(obs), act, info) +def test_ddpg(args=get_args()): + # log + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + args.algo_name = "ddpg" + log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) + log_path = os.path.join(args.logdir, log_name) + # logger + if args.logger == "wandb": + logger = WandbLogger( + save_interval=1, + name=log_name.replace(os.path.sep, "__"), + run_id=args.resume_id, + config=args, + project=args.wandb_project, + ) + logger.wandb_run.config.setdefaults(vars(args)) + args = argparse.Namespace(**wandb.config) + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + if args.logger == "tensorboard": + logger = TensorboardLogger(writer) + else: # wandb + logger.load(writer) -def test_ddpg(args=get_args()): env, train_envs, test_envs = make_fetch_env( args.task, args.training_num, args.test_num ) @@ -150,25 +132,27 @@ def test_ddpg(args=get_args()): np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = DictStateNet( - args.state_shape, - keys=['observation', 'achieved_goal', 'desired_goal'], + dict_state_dec, flat_state_shape = get_dict_state_decorator( + state_shape=args.state_shape, + keys=['observation', 'achieved_goal', 'desired_goal'] + ) + net_a = dict_state_dec(Net)( + flat_state_shape, hidden_sizes=args.hidden_sizes, device=args.device ) - actor = Actor( + actor = dict_state_dec(Actor)( net_a, args.action_shape, max_action=args.max_action, device=args.device ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net_c = DictStateNet( - args.state_shape, - keys=['observation', 'achieved_goal', 'desired_goal'], + net_c = dict_state_dec(Net)( + flat_state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, ) - critic = GoalStateCritic(net_c, device=args.device).to(args.device) + critic = dict_state_dec(Critic)(net_c, device=args.device).to(args.device) critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) policy = DDPGPolicy( actor, @@ -191,47 +175,31 @@ def test_ddpg(args=get_args()): def compute_reward_fn(ag: np.ndarray, g: np.ndarray): return env.compute_reward(ag, g, {}) - if args.training_num > 1: - buffer = HERVectorReplayBuffer( - args.buffer_size, - len(train_envs), - compute_reward_fn=compute_reward_fn, - horizon=args.horizon, - future_k=args.future_k, - ) + if args.replay_buffer == "normal": + if args.training_num > 1: + buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) + else: + buffer = ReplayBuffer(args.buffer_size) else: - buffer = HERReplayBuffer( - args.buffer_size, - compute_reward_fn=compute_reward_fn, - horizon=args.horizon, - future_k=args.future_k, - ) + if args.training_num > 1: + buffer = HERVectorReplayBuffer( + 
args.buffer_size, + len(train_envs), + compute_reward_fn=compute_reward_fn, + horizon=args.her_horizon, + future_k=args.her_future_k, + ) + else: + buffer = HERReplayBuffer( + args.buffer_size, + compute_reward_fn=compute_reward_fn, + horizon=args.her_horizon, + future_k=args.her_future_k, + ) train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, test_envs) train_collector.collect(n_step=args.start_timesteps, random=True) - # log - now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") - args.algo_name = "ddpg" - log_name = os.path.join(args.task, args.algo_name, str(args.seed), now) - log_path = os.path.join(args.logdir, log_name) - - # logger - if args.logger == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) - def save_best_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 1fc58f768..2c50dacab 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -9,11 +9,14 @@ Union, no_type_check, ) +from types import MethodType import numpy as np import torch from torch import nn +from tianshou.data.batch import Batch + ModuleType = Type[nn.Module] @@ -453,3 +456,51 @@ def forward( action_scores = action_scores - torch.mean(action_scores, 2, keepdim=True) logits = value_out + action_scores return logits, state + + +def get_dict_state_decorator( + state_shape: Dict[str, Union[int, Sequence[int]]], + keys: Sequence[str] +) -> Net: + """A helper function to make Net or equivalent classes applicable to dict state. + + :param state_shape: A dictionary indicating each state's shape + :param keys: A list of state's keys. The flatten observation will be according to + this list order. 
+ :returns: a 2-items tuple decorator_fn and new_state_shape + """ + original_shape = state_shape + flat_state_shapes = [] + for k in keys: + flat_state_shapes.append(int(np.prod(state_shape[k]))) + new_state_shape = sum(flat_state_shapes) + + def preprocess_obs( + obs: Union[Batch, dict, torch.Tensor, np.ndarray] + ) -> torch.Tensor: + if isinstance(obs, dict) or (isinstance(obs, Batch) and keys[0] in obs): + if original_shape[keys[0]] == obs[keys[0]].shape: + # No batch dim + new_obs = torch.Tensor([obs[k] for k in keys]).flatten() + # new_obs = torch.Tensor([obs[k] for k in keys]).reshape(1, -1) + else: + bsz = obs[keys[0]].shape[0] + new_obs = torch.cat( + [torch.Tensor(obs[k].reshape(bsz, -1)) for k in keys], axis=1 + ) + else: + new_obs = obs + return new_obs + + def decorator_fn(net_class) -> Net: + class new_net_class(net_class): + def forward( + self, + obs: Union[np.ndarray, torch.Tensor], + *args, + **kwargs, + ) -> Any: + return super().forward(preprocess_obs(obs), *args, **kwargs) + return new_net_class + + return decorator_fn, new_state_shape From b08fac795ecc7d0edfe9a804c0e7174eeeb75fa8 Mon Sep 17 00:00:00 2001 From: juno-t Date: Thu, 27 Oct 2022 08:48:37 +0000 Subject: [PATCH 16/20] format --- docs/api/tianshou.data.rst | 25 +++++++++++++++++++++++++ examples/mujoco/fetch_her_ddpg.py | 13 ++++--------- tianshou/data/buffer/her.py | 2 +- tianshou/utils/net/common.py | 30 ++++++++++++++++++++---------- 4 files changed, 50 insertions(+), 20 deletions(-) diff --git a/docs/api/tianshou.data.rst b/docs/api/tianshou.data.rst index 77c69aa15..63dfa91d6 100644 --- a/docs/api/tianshou.data.rst +++ b/docs/api/tianshou.data.rst @@ -30,6 +30,14 @@ PrioritizedReplayBuffer :undoc-members: :show-inheritance: +HERReplayBuffer +~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: tianshou.data.HERReplayBuffer + :members: + :undoc-members: + :show-inheritance: + ReplayBufferManager ~~~~~~~~~~~~~~~~~~~ @@ -46,6 +54,15 @@ PrioritizedReplayBufferManager :undoc-members: :show-inheritance: + +HERReplayBufferManager +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: tianshou.data.HERReplayBufferManager + :members: + :undoc-members: + :show-inheritance: + VectorReplayBuffer ~~~~~~~~~~~~~~~~~~ @@ -62,6 +79,14 @@ PrioritizedVectorReplayBuffer :undoc-members: :show-inheritance: +HERVectorReplayBuffer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: tianshou.data.HERVectorReplayBuffer + :members: + :undoc-members: + :show-inheritance: + CachedReplayBuffer ~~~~~~~~~~~~~~~~~~ diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index d133328f7..2d4f9434a 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -7,16 +7,16 @@ import gym import numpy as np -import wandb import torch from torch.utils.tensorboard import SummaryWriter +import wandb from tianshou.data import ( Collector, HERReplayBuffer, HERVectorReplayBuffer, ReplayBuffer, - VectorReplayBuffer + VectorReplayBuffer, ) from tianshou.env import ShmemVectorEnv, TruncatedAsTerminated from tianshou.exploration import GaussianNoise @@ -46,10 +46,7 @@ def get_args(): parser.add_argument("--n-step", type=int, default=1) parser.add_argument("--batch-size", type=int, default=512) parser.add_argument( - "--replay-buffer", - type=str, - default="her", - choices=["normal", "her"] + "--replay-buffer", type=str, default="her", choices=["normal", "her"] ) parser.add_argument("--her-horizon", type=int, default=50) parser.add_argument("--her-future-k", type=int, default=8) @@ -137,9 +134,7 @@ def test_ddpg(args=get_args()): keys=['observation', 'achieved_goal', 'desired_goal'] ) net_a = dict_state_dec(Net)( - flat_state_shape, - hidden_sizes=args.hidden_sizes, - device=args.device + flat_state_shape, hidden_sizes=args.hidden_sizes, device=args.device ) actor = dict_state_dec(Actor)( net_a, args.action_shape, max_action=args.max_action, device=args.device diff --git a/tianshou/data/buffer/her.py b/tianshou/data/buffer/her.py index d2931ec61..8c5c37166 100644 --- a/tianshou/data/buffer/her.py +++ b/tianshou/data/buffer/her.py @@ -15,7 +15,7 @@ class HERReplayBuffer(ReplayBuffer): :param int size: the size of the replay buffer. :param compute_reward_fn: a function that takes 2 ``np.array`` arguments, ``acheived_goal`` and ``desired_goal``, and returns rewards as ``np.array``. - The two arguments are of shape (batch_size, *original_shape) and the returned + The two arguments are of shape (batch_size, ...original_shape) and the returned rewards must be of shape (batch_size,). :param int horizon: the maximum number of steps in an episode. :param int future_k: the 'k' parameter introduced in the paper. In short, there diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 2c50dacab..b0c30806d 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -1,5 +1,6 @@ from typing import ( Any, + Callable, Dict, List, Optional, @@ -9,7 +10,6 @@ Union, no_type_check, ) -from types import MethodType import numpy as np import torch @@ -459,15 +459,21 @@ def forward( def get_dict_state_decorator( - state_shape: Dict[str, Union[int, Sequence[int]]], - keys: Sequence[str] -) -> Net: - """A helper function to make Net or equivalent classes applicable to dict state. + state_shape: Dict[str, Union[int, Sequence[int]]], keys: Sequence[str] +) -> Tuple[Callable, int]: + """A helper function to make Net or equivalent classes (e.g. Actor, Critic) \ + applicable to dict state. + + The first return item, ``decorator_fn``, will alter the implementation of forward + function of the given class by preprocessing the observation. The preprocessing is + basically flatten the observation and concatenate them based on the ``keys`` order. + The batch dimension is preserved if presented. The result observation shape will + be equal to ``new_state_shape``, the second return item. 
:param state_shape: A dictionary indicating each state's shape - :param keys: A list of state's keys. The flatten observation will be according to + :param keys: A list of state's keys. The flatten observation will be according to \ this list order. - :returns: a 2-items tuple decorator_fn and new_state_shape + :returns: a 2-items tuple ``decorator_fn`` and ``new_state_shape`` """ original_shape = state_shape flat_state_shapes = [] @@ -486,14 +492,17 @@ def preprocess_obs( else: bsz = obs[keys[0]].shape[0] new_obs = torch.cat( - [torch.Tensor(obs[k].reshape(bsz, -1)) for k in keys], axis=1 + [torch.Tensor(obs[k].reshape(bsz, -1)) for k in keys], dim=1 ) else: - new_obs = obs + new_obs = torch.Tensor(obs) return new_obs - def decorator_fn(net_class) -> Net: + @no_type_check + def decorator_fn(net_class): + class new_net_class(net_class): + def forward( self, obs: Union[np.ndarray, torch.Tensor], @@ -501,6 +510,7 @@ def forward( **kwargs, ) -> Any: return super().forward(preprocess_obs(obs), *args, **kwargs) + return new_net_class return decorator_fn, new_state_shape From 1798b908ae4755540c9fdc56ef58d07c0b296252 Mon Sep 17 00:00:00 2001 From: Juno T <42699114+Juno-T@users.noreply.github.com> Date: Thu, 27 Oct 2022 18:21:06 +0900 Subject: [PATCH 17/20] add HER section --- examples/mujoco/README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/examples/mujoco/README.md b/examples/mujoco/README.md index 12b480a9d..e38107e6b 100644 --- a/examples/mujoco/README.md +++ b/examples/mujoco/README.md @@ -20,6 +20,7 @@ Supported algorithms are listed below: - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/), [commit id](https://github.com/thu-ml/tianshou/tree/1730a9008ad6bb67cac3b21347bed33b532b17bc) - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/6426a39796db052bafb7cabe85c764db20a722b0) - [Trust Region Policy Optimization (TRPO)](https://arxiv.org/pdf/1502.05477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/5057b5c89e6168220272c9c28a15b758a72efc32) +- [Hindsight Experience Replay (HER)](https://arxiv.org/abs/1707.01495) ## EnvPool @@ -304,6 +305,18 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai 1. All shared hyperparameters are exactly the same as TRPO, regarding how similar these two algorithms are. 2. We found different games in Mujoco may require quite different `actor-step-size`: Reacher/Swimmer are insensitive to step-size in range (0.1~1.0), while InvertedDoublePendulum / InvertedPendulum / Humanoid are quite sensitive to step size, and even 0.1 is too large. Other games may require `actor-step-size` in range (0.1~0.4), but aren't that sensitive in general. +## Others + +### HER +| Environment | DDPG without HER | DDPG with HER | +| :--------------------: | :--------------: | :--------------: | +| FetchReach | -49.9±0.2. | **-17.6±21.7** | + +#### Hints for HER +1. The HER technique is proposed for solving task-based environments, so it cannot be compared with non-task-based mujoco benchmarks. The environment used in this evaluation is ``FetchReach-v3`` which requires an extra [installation](https://github.com/Farama-Foundation/Gymnasium-Robotics). +2. Simple hyperparameters optimizations are done for both settings, DDPG with and without HER. However, since *DDPG without HER* failed in every experiment, the best hyperparameters for *DDPG with HER* are used in the evaluation of both settings. +3. 
The scores are the mean reward ± 1 standard deviation of 16 seeds. The minimum reward for ``FetchReach-v3`` is -50 which we can imply that *DDPG without HER* performs as good as a random policy. *DDPG with HER* although has a better mean reward, the standard deviation is quite high. This is because in this setting, the agent will either fail completely (-50 reward) or successfully learn the task (close to 0 reward). This means that the agent successfully learned in about 70% of the 16 seeds. + ## Note [1] Supported environments include HalfCheetah-v3, Hopper-v3, Swimmer-v3, Walker2d-v3, Ant-v3, Humanoid-v3, Reacher-v2, InvertedPendulum-v2 and InvertedDoublePendulum-v2. Pusher, Thrower, Striker and HumanoidStandup are not supported because they are not commonly seen in literatures. From 9082db779eacade5ecd27e76fc76df9f312a82cb Mon Sep 17 00:00:00 2001 From: Yi Su Date: Sat, 29 Oct 2022 09:36:09 -0700 Subject: [PATCH 18/20] make linter happy --- examples/mujoco/fetch_her_ddpg.py | 2 +- tianshou/utils/logger/base.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index 2d4f9434a..893912566 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -8,9 +8,9 @@ import gym import numpy as np import torch +import wandb from torch.utils.tensorboard import SummaryWriter -import wandb from tianshou.data import ( Collector, HERReplayBuffer, diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py index 0eca53c0c..2dc47ed21 100644 --- a/tianshou/utils/logger/base.py +++ b/tianshou/utils/logger/base.py @@ -89,6 +89,7 @@ def log_update_data(self, update_result: dict, step: int) -> None: self.write("update/gradient_step", step, log_data) self.last_log_update_step = step + @abstractmethod def save_data( self, epoch: int, @@ -106,6 +107,7 @@ def save_data( """ pass + @abstractmethod def restore_data(self) -> Tuple[int, int, int]: """Return the metadata from existing log. 
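With `save_data` and `restore_data` now abstract, every concrete logger has to implement them (the `LazyLogger` hunk below gets trivial bodies). A minimal sketch of a conforming subclass, assuming the `BaseLogger` and `LOG_DATA_TYPE` names defined in this module; the in-memory checkpoint bookkeeping is purely illustrative and not something this patch adds:

```python
# Illustrative only: a minimal logger satisfying the now-abstract contract.
from typing import Callable, Optional, Tuple

from tianshou.utils.logger.base import LOG_DATA_TYPE, BaseLogger


class InMemoryLogger(BaseLogger):
    def __init__(self) -> None:
        super().__init__()
        self._ckpt = (0, 0, 0)

    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
        pass  # drop scalar data; a real logger would persist it somewhere

    def save_data(
        self,
        epoch: int,
        env_step: int,
        gradient_step: int,
        save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
    ) -> None:
        if save_checkpoint_fn:
            save_checkpoint_fn(epoch, env_step, gradient_step)
        self._ckpt = (epoch, env_step, gradient_step)

    def restore_data(self) -> Tuple[int, int, int]:
        return self._ckpt
```
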
@@ -126,3 +128,15 @@ def __init__(self) -> None: def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: """The LazyLogger writes nothing.""" pass + + def save_data( + self, + epoch: int, + env_step: int, + gradient_step: int, + save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None, + ) -> None: + pass + + def restore_data(self) -> Tuple[int, int, int]: + pass From 20ba3aba83bed4f1acf363d20c513809b510d57f Mon Sep 17 00:00:00 2001 From: Yi Su Date: Sat, 29 Oct 2022 10:21:04 -0700 Subject: [PATCH 19/20] make mypy happy --- tianshou/utils/net/common.py | 2 +- tianshou/utils/net/continuous.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index b0c30806d..a6a018535 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -265,7 +265,7 @@ def forward( """ obs = torch.as_tensor( obs, - device=self.device, # type: ignore + device=self.device, dtype=torch.float32, ) # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation) diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index bf083c32a..fb75e3317 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -124,13 +124,13 @@ def forward( """Mapping: (s, a) -> logits -> Q(s, a).""" obs = torch.as_tensor( obs, - device=self.device, # type: ignore + device=self.device, dtype=torch.float32, ).flatten(1) if act is not None: act = torch.as_tensor( act, - device=self.device, # type: ignore + device=self.device, dtype=torch.float32, ).flatten(1) obs = torch.cat([obs, act], dim=1) @@ -266,7 +266,7 @@ def forward( """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" obs = torch.as_tensor( obs, - device=self.device, # type: ignore + device=self.device, dtype=torch.float32, ) # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation) @@ -339,7 +339,7 @@ def forward( """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" obs = torch.as_tensor( obs, - device=self.device, # type: ignore + device=self.device, dtype=torch.float32, ) # obs [bsz, len, dim] (training) or [bsz, dim] (evaluation) @@ -352,7 +352,7 @@ def forward( if act is not None: act = torch.as_tensor( act, - device=self.device, # type: ignore + device=self.device, dtype=torch.float32, ) obs = torch.cat([obs, act], dim=1) From aaf0f73708dce245c2dc2fd29d758ff24cd75b40 Mon Sep 17 00:00:00 2001 From: Yi Su Date: Sat, 29 Oct 2022 10:44:33 -0700 Subject: [PATCH 20/20] add to word list --- docs/spelling_wordlist.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index eb3ac4a23..c63486213 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -52,6 +52,7 @@ mujoco jit nstep preprocess +preprocessing repo ReLU namespace
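
Taken together, the series wires HER into Tianshou end to end: a goal-based environment wrapped with `TruncatedAsTerminated`, an `HERReplayBuffer` whose `compute_reward_fn` wraps the environment's own `compute_reward`, and a collector feeding the buffer. A hedged sketch of that minimal wiring (buffer size, horizon and `future_k` are placeholders; policy construction is omitted, see `examples/mujoco/fetch_her_ddpg.py` above):

```python
# Hedged usage sketch mirroring the pieces added by this patch series.
# FetchReach-v3 needs the extra Gymnasium-Robotics installation noted in the README.
import gym
import numpy as np

from tianshou.data import HERReplayBuffer
from tianshou.env import TruncatedAsTerminated

env = TruncatedAsTerminated(gym.make("FetchReach-v3"))  # requires gym >= 0.26.0


def compute_reward_fn(achieved_goal: np.ndarray, desired_goal: np.ndarray) -> np.ndarray:
    # HER relabels goals in batches: both inputs are (batch_size, *goal_shape)
    # and the returned rewards must be of shape (batch_size,).
    return env.compute_reward(achieved_goal, desired_goal, {})


buffer = HERReplayBuffer(
    size=100_000,                        # placeholder capacity
    compute_reward_fn=compute_reward_fn,
    horizon=50,                          # max episode length of the task
    future_k=8,                          # 'future' strategy relabelling ratio
)
# A policy would be built as in fetch_her_ddpg.py; the collector then simply
# takes the HER buffer in place of a plain ReplayBuffer:
# collector = tianshou.data.Collector(policy, env, buffer, exploration_noise=True)
# collector.collect(n_step=25_000, random=True)
```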