From d3e2d157af1cc412c46e7b6aeb384f01a88a61b1 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Mon, 7 Feb 2022 09:32:25 -0500 Subject: [PATCH 1/5] change internal API for venv worker --- docs/spelling_wordlist.txt | 1 + tianshou/env/venvs.py | 8 +++--- tianshou/env/worker/base.py | 37 +++++++++++++++++++--------- tianshou/env/worker/dummy.py | 7 ++++-- tianshou/env/worker/ray.py | 15 ++++++++---- tianshou/env/worker/subproc.py | 45 ++++++++++++++++++---------------- 6 files changed, 69 insertions(+), 44 deletions(-) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 7a4a3f7d8..b6d419e00 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -48,6 +48,7 @@ preprocess repo ReLU namespace +recv th utils NaN diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 39421468f..62f385e3b 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -258,10 +258,10 @@ def step( if not self.is_async: assert len(action) == len(id) for i, j in enumerate(id): - self.workers[j].send_action(action[i]) + self.workers[j].send(action[i]) result = [] for j in id: - obs, rew, done, info = self.workers[j].get_result() + obs, rew, done, info = self.workers[j].recv() info["env_id"] = j result.append((obs, rew, done, info)) else: @@ -269,7 +269,7 @@ def step( self._assert_id(id) assert len(action) == len(id) for act, env_id in zip(action, id): - self.workers[env_id].send_action(act) + self.workers[env_id].send(act) self.waiting_conn.append(self.workers[env_id]) self.waiting_id.append(env_id) self.ready_id = [x for x in self.ready_id if x not in id] @@ -283,7 +283,7 @@ def step( waiting_index = self.waiting_conn.index(conn) self.waiting_conn.pop(waiting_index) env_id = self.waiting_id.pop(waiting_index) - obs, rew, done, info = conn.get_result() + obs, rew, done, info = conn.recv() info["env_id"] = env_id result.append((obs, rew, done, info)) self.ready_id.append(env_id) diff --git a/tianshou/env/worker/base.py 
b/tianshou/env/worker/base.py index 3c63be997..f23f8969a 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple, Union import gym import numpy as np @@ -23,28 +23,41 @@ def set_env_attr(self, key: str, value: Any) -> None: pass @abstractmethod - def reset(self) -> Any: - pass + def send(self, action: Optional[np.ndarray]) -> None: + """Send action signal to low-level worker. - @abstractmethod - def send_action(self, action: np.ndarray) -> None: + When action is None, it sends a "reset" signal; otherwise + it sends a "step" signal. The paired return value from "recv" + depends on which of these two signals was sent. + """ pass - def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + def recv( + self + ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]: + """Receive result from low-level worker. + + If the last "send" call passed None as the action, it returns only a + single observation; otherwise it returns a tuple of (obs, rew, done, + info). + """ return self.result + def reset(self) -> np.ndarray: + self.send(None) + return self.recv() # type: ignore + def step( self, action: np.ndarray ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Perform one timestep of the environment's dynamic. - "send_action" and "get_result" are coupled in sync simulation, so - typically users only call "step" function. But they can be called - separately in async simulation, i.e. someone calls "send_action" first, - and calls "get_result" later. + "send" and "recv" are coupled in sync simulation, so users only call + "step" function. But they can be called separately in async + simulation, i.e. someone calls "send" first, and calls "recv" later.
""" - self.send_action(action) - return self.get_result() + self.send(action) + return self.recv() # type: ignore @staticmethod def wait( diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py index 542c70210..958f6e907 100644 --- a/tianshou/env/worker/dummy.py +++ b/tianshou/env/worker/dummy.py @@ -29,8 +29,11 @@ def wait( # type: ignore # Sequential EnvWorker objects are always ready return workers - def send_action(self, action: np.ndarray) -> None: - self.result = self.env.step(action) + def send(self, action: Optional[np.ndarray]) -> None: + if action is None: + self.result = self.env.reset() + else: + self.result = self.env.step(action) def seed(self, seed: Optional[int] = None) -> List[int]: super().seed(seed) diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py index 7917683be..1bc7991aa 100644 --- a/tianshou/env/worker/ray.py +++ b/tianshou/env/worker/ray.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple, Union import gym import numpy as np @@ -44,11 +44,16 @@ def wait( # type: ignore ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout) return [workers[results.index(result)] for result in ready_results] - def send_action(self, action: np.ndarray) -> None: + def send(self, action: Optional[np.ndarray]) -> None: # self.action is actually a handle - self.result = self.env.step.remote(action) - - def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + if action is None: + self.result = self.env.reset.remote() + else: + self.result = self.env.step.remote(action) + + def recv( + self + ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]: return ray.get(self.result) def seed(self, seed: Optional[int] = None) -> List[int]: diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index 61a69cafd..779b78e34 100644 --- a/tianshou/env/worker/subproc.py +++ 
b/tianshou/env/worker/subproc.py @@ -86,17 +86,17 @@ def _encode_obs( p.close() break if cmd == "step": - obs, reward, done, info = env.step(data) + if data is None: # reset + obs = env.reset() + else: + obs, reward, done, info = env.step(data) if obs_bufs is not None: _encode_obs(obs, obs_bufs) obs = None - p.send((obs, reward, done, info)) - elif cmd == "reset": - obs = env.reset() - if obs_bufs is not None: - _encode_obs(obs, obs_bufs) - obs = None - p.send(obs) + if data is None: + p.send(obs) + else: + p.send((obs, reward, done, info)) elif cmd == "close": p.send(env.close()) p.close() @@ -140,6 +140,7 @@ def __init__( self.process = Process(target=_worker, args=args, daemon=True) self.process.start() self.child_remote.close() + self.is_reset = False super().__init__(env_fn) def get_env_attr(self, key: str) -> Any: @@ -165,13 +166,6 @@ def decode_obs( return decode_obs(self.buffer) - def reset(self) -> Any: - self.parent_remote.send(["reset", None]) - obs = self.parent_remote.recv() - if self.share_memory: - obs = self._decode_obs() - return obs - @staticmethod def wait( # type: ignore workers: List["SubprocEnvWorker"], @@ -192,14 +186,23 @@ def wait( # type: ignore remain_conns = [conn for conn in remain_conns if conn not in ready_conns] return [workers[conns.index(con)] for con in ready_conns] - def send_action(self, action: np.ndarray) -> None: + def send(self, action: Optional[np.ndarray]) -> None: self.parent_remote.send(["step", action]) - def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - obs, rew, done, info = self.parent_remote.recv() - if self.share_memory: - obs = self._decode_obs() - return obs, rew, done, info + def recv( + self + ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]: + result = self.parent_remote.recv() + if isinstance(result, tuple): + obs, rew, done, info = result + if self.share_memory: + obs = self._decode_obs() + return obs, rew, done, info + else: + obs = result + if 
self.share_memory: + obs = self._decode_obs() + return obs def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: super().seed(seed) From 36eacdf595d505ab57da394b2b8acf08cef104f9 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Mon, 7 Feb 2022 09:45:11 -0500 Subject: [PATCH 2/5] add time check for venvs.reset() --- test/base/env.py | 12 ++++++++---- test/base/test_env.py | 5 +++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/test/base/env.py b/test/base/env.py index 3ae031c59..0a649e546 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -81,6 +81,7 @@ def seed(self, seed=0): def reset(self, state=0): self.done = False + self.do_sleep() self.index = state return self._get_state() @@ -116,16 +117,19 @@ def _get_state(self): else: return np.array([self.index], dtype=np.float32) + def do_sleep(self): + if self.sleep > 0: + sleep_time = random.random() if self.random_sleep else 1 + sleep_time *= self.sleep + time.sleep(sleep_time) + def step(self, action): self.steps += 1 if self._md_action: action = action[0] if self.done: raise ValueError('step after done !!!') - if self.sleep > 0: - sleep_time = random.random() if self.random_sleep else 1 - sleep_time *= self.sleep - time.sleep(sleep_time) + self.do_sleep() if self.index == self.size: self.done = True return self._get_state(), self._get_reward(), self.done, {} diff --git a/test/base/test_env.py b/test/base/test_env.py index f0471e797..e9fc728d9 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -96,7 +96,12 @@ def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7): for cls in test_cls: pass_check = 1 v = cls(env_fns, wait_num=num - 1, timeout=timeout) + t = time.time() v.reset() + t = time.time() - t + print(f"{cls} reset {t}") + if t > sleep * 9: # larger than the maximum sleep time (7 sleep) + pass_check = 0 expect_result = [ [0, 1], [0, 1, 2], From fd8eee9fe72cfe834ac8fdf1955cc6ce86236d53 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Mon, 7 Feb 2022 09:47:57
-0500 Subject: [PATCH 3/5] concurrent reset --- tianshou/env/venvs.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 62f385e3b..3799d0af9 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -209,7 +209,13 @@ def reset( id = self._wrap_id(id) if self.is_async: self._assert_id(id) - obs_list = [self.workers[i].reset() for i in id] + # send(None) == reset() in worker + for i, j in enumerate(id): + self.workers[j].send(None) + obs_list = [] + for j in id: + obs = self.workers[j].recv() + obs_list.append(obs) try: obs = np.stack(obs_list) except ValueError: # different len(obs) From 41bd71c3b71b4aec369dee683ce23ab12d7e77a9 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Mon, 7 Feb 2022 10:04:03 -0500 Subject: [PATCH 4/5] fix lint --- tianshou/env/venvs.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 3799d0af9..15c727d86 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -210,14 +210,11 @@ def reset( if self.is_async: self._assert_id(id) # send(None) == reset() in worker - for i, j in enumerate(id): - self.workers[j].send(None) - obs_list = [] - for j in id: - obs = self.workers[j].recv() - obs_list.append(obs) + for i in id: + self.workers[i].send(None) + obs_list = [self.workers[i].recv() for i in id] try: - obs = np.stack(obs_list) + obs = np.stack(obs_list) # type: ignore except ValueError: # different len(obs) obs = np.array(obs_list, dtype=object) if self.obs_rms and self.update_obs_rms: From 7cc34876e267e2e0886ee31f0f91f1d4afcd8ad1 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Mon, 7 Feb 2022 10:10:52 -0500 Subject: [PATCH 5/5] fix lint --- tianshou/env/venvs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 15c727d86..c668109b6 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -214,7 
+214,7 @@ def reset( self.workers[i].send(None) obs_list = [self.workers[i].recv() for i in id] try: - obs = np.stack(obs_list) # type: ignore + obs = np.stack(obs_list) except ValueError: # different len(obs) obs = np.array(obs_list, dtype=object) if self.obs_rms and self.update_obs_rms: