这是indexloc提供的服务,不要输入任何密码
Skip to content

Enable venvs.reset() concurrent execution #517

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ preprocess
repo
ReLU
namespace
recv
th
utils
NaN
Expand Down
12 changes: 8 additions & 4 deletions test/base/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def seed(self, seed=0):

def reset(self, state=0):
# Restart the episode at position `state` and return the initial state.
self.done = False
# Sleep (if configured) just like step() does, so reset() has non-zero
# latency; the async reset-timing check in test_async_check_id appears to
# rely on this to verify that resets run concurrently.
self.do_sleep()
self.index = state
return self._get_state()

Expand Down Expand Up @@ -116,16 +117,19 @@ def _get_state(self):
else:
return np.array([self.index], dtype=np.float32)

def do_sleep(self):
# Simulate environment latency. No-op when self.sleep <= 0; otherwise
# sleep for self.sleep seconds, or for a uniformly random fraction of it
# (random.random() is in [0, 1)) when self.random_sleep is set.
if self.sleep > 0:
sleep_time = random.random() if self.random_sleep else 1
sleep_time *= self.sleep
time.sleep(sleep_time)

def step(self, action):
self.steps += 1
if self._md_action:
action = action[0]
if self.done:
raise ValueError('step after done !!!')
if self.sleep > 0:
sleep_time = random.random() if self.random_sleep else 1
sleep_time *= self.sleep
time.sleep(sleep_time)
self.do_sleep()
if self.index == self.size:
self.done = True
return self._get_state(), self._get_reward(), self.done, {}
Expand Down
5 changes: 5 additions & 0 deletions test/base/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,12 @@ def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7):
for cls in test_cls:
pass_check = 1
v = cls(env_fns, wait_num=num - 1, timeout=timeout)
t = time.time()
v.reset()
t = time.time() - t
print(f"{cls} reset {t}")
if t > sleep * 9: # huge than maximum sleep time (7 sleep)
pass_check = 0
expect_result = [
[0, 1],
[0, 1, 2],
Expand Down
13 changes: 8 additions & 5 deletions tianshou/env/venvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,10 @@ def reset(
id = self._wrap_id(id)
if self.is_async:
self._assert_id(id)
obs_list = [self.workers[i].reset() for i in id]
# send(None) == reset() in worker
for i in id:
self.workers[i].send(None)
obs_list = [self.workers[i].recv() for i in id]
try:
obs = np.stack(obs_list)
except ValueError: # different len(obs)
Expand Down Expand Up @@ -258,18 +261,18 @@ def step(
if not self.is_async:
assert len(action) == len(id)
for i, j in enumerate(id):
self.workers[j].send_action(action[i])
self.workers[j].send(action[i])
result = []
for j in id:
obs, rew, done, info = self.workers[j].get_result()
obs, rew, done, info = self.workers[j].recv()
info["env_id"] = j
result.append((obs, rew, done, info))
else:
if action is not None:
self._assert_id(id)
assert len(action) == len(id)
for act, env_id in zip(action, id):
self.workers[env_id].send_action(act)
self.workers[env_id].send(act)
self.waiting_conn.append(self.workers[env_id])
self.waiting_id.append(env_id)
self.ready_id = [x for x in self.ready_id if x not in id]
Expand All @@ -283,7 +286,7 @@ def step(
waiting_index = self.waiting_conn.index(conn)
self.waiting_conn.pop(waiting_index)
env_id = self.waiting_id.pop(waiting_index)
obs, rew, done, info = conn.get_result()
obs, rew, done, info = conn.recv()
info["env_id"] = env_id
result.append((obs, rew, done, info))
self.ready_id.append(env_id)
Expand Down
37 changes: 25 additions & 12 deletions tianshou/env/worker/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple, Union

import gym
import numpy as np
Expand All @@ -23,28 +23,41 @@ def set_env_attr(self, key: str, value: Any) -> None:
pass

@abstractmethod
def reset(self) -> Any:
pass
def send(self, action: Optional[np.ndarray]) -> None:
"""Send action signal to low-level worker.

@abstractmethod
def send_action(self, action: np.ndarray) -> None:
When action is None, it indicates sending "reset" signal; otherwise
it indicates "step" signal. The paired return value from "recv"
function is determined by such kind of different signal.
"""
pass

def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
def recv(
self
) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
"""Receive result from low-level worker.

If the last "send" call passed ``None`` (a reset signal), only a single
observation is returned; otherwise (a step signal) a tuple of
(obs, rew, done, info) is returned.

This default implementation simply returns ``self.result``, which
subclasses such as the dummy worker set inside ``send``; workers that run
the environment remotely override this method.
"""
return self.result

def reset(self) -> np.ndarray:
# Synchronous reset built on the send/recv protocol: send(None) is the
# reset signal, and recv() then returns the single initial observation.
self.send(None)
# recv() is annotated as a Union of obs and (obs, rew, done, info); after a
# reset it yields the observation only, hence the type: ignore.
return self.recv() # type: ignore

def step(
self, action: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Perform one timestep of the environment's dynamic.

"send_action" and "get_result" are coupled in sync simulation, so
typically users only call "step" function. But they can be called
separately in async simulation, i.e. someone calls "send_action" first,
and calls "get_result" later.
"send" and "recv" are coupled in sync simulation, so users only call
"step" function. But they can be called separately in async
simulation, i.e. someone calls "send" first, and calls "recv" later.
"""
self.send_action(action)
return self.get_result()
self.send(action)
return self.recv() # type: ignore

@staticmethod
def wait(
Expand Down
7 changes: 5 additions & 2 deletions tianshou/env/worker/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@ def wait( # type: ignore
# Sequential EnvWorker objects are always ready
return workers

def send_action(self, action: np.ndarray) -> None:
self.result = self.env.step(action)
def send(self, action: Optional[np.ndarray]) -> None:
# Run the request synchronously in this process: None means reset,
# anything else is a step action. The outcome is stored in self.result,
# presumably returned later by the base-class recv().
if action is None:
self.result = self.env.reset()
else:
self.result = self.env.step(action)

def seed(self, seed: Optional[int] = None) -> List[int]:
super().seed(seed)
Expand Down
15 changes: 10 additions & 5 deletions tianshou/env/worker/ray.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple, Union

import gym
import numpy as np
Expand Down Expand Up @@ -44,11 +44,16 @@ def wait( # type: ignore
ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout)
return [workers[results.index(result)] for result in ready_results]

def send_action(self, action: np.ndarray) -> None:
def send(self, action: Optional[np.ndarray]) -> None:
# self.action is actually a handle
self.result = self.env.step.remote(action)

def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
if action is None:
self.result = self.env.reset.remote()
else:
self.result = self.env.step.remote(action)

def recv(
self
) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
# self.result is the Ray object handle from the remote reset()/step() call
# issued in send(); ray.get waits for it and returns its value: a bare
# observation after a reset, or (obs, rew, done, info) after a step.
return ray.get(self.result)

def seed(self, seed: Optional[int] = None) -> List[int]:
Expand Down
45 changes: 24 additions & 21 deletions tianshou/env/worker/subproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,17 @@ def _encode_obs(
p.close()
break
if cmd == "step":
obs, reward, done, info = env.step(data)
if data is None: # reset
obs = env.reset()
else:
obs, reward, done, info = env.step(data)
if obs_bufs is not None:
_encode_obs(obs, obs_bufs)
obs = None
p.send((obs, reward, done, info))
elif cmd == "reset":
obs = env.reset()
if obs_bufs is not None:
_encode_obs(obs, obs_bufs)
obs = None
p.send(obs)
if data is None:
p.send(obs)
else:
p.send((obs, reward, done, info))
elif cmd == "close":
p.send(env.close())
p.close()
Expand Down Expand Up @@ -140,6 +140,7 @@ def __init__(
self.process = Process(target=_worker, args=args, daemon=True)
self.process.start()
self.child_remote.close()
self.is_reset = False
super().__init__(env_fn)

def get_env_attr(self, key: str) -> Any:
Expand All @@ -165,13 +166,6 @@ def decode_obs(

return decode_obs(self.buffer)

def reset(self) -> Any:
self.parent_remote.send(["reset", None])
obs = self.parent_remote.recv()
if self.share_memory:
obs = self._decode_obs()
return obs

@staticmethod
def wait( # type: ignore
workers: List["SubprocEnvWorker"],
Expand All @@ -192,14 +186,23 @@ def wait( # type: ignore
remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
return [workers[conns.index(con)] for con in ready_conns]

def send_action(self, action: np.ndarray) -> None:
def send(self, action: Optional[np.ndarray]) -> None:
self.parent_remote.send(["step", action])

def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
obs, rew, done, info = self.parent_remote.recv()
if self.share_memory:
obs = self._decode_obs()
return obs, rew, done, info
def recv(
self
) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
# Read the worker process's reply. A tuple answers a step request and is
# unpacked as (obs, rew, done, info); anything else is the bare
# observation answering a reset request.
# With shared memory the worker writes the observation into the shared
# buffer and sends obs=None (see the obs_bufs branch of the worker loop),
# so the real observation is read back via _decode_obs().
# NOTE(review): without shared memory, a reset observation that is itself
# a tuple would be mistaken for a step result by the isinstance check
# below -- confirm such observations cannot occur, or tag replies.
result = self.parent_remote.recv()
if isinstance(result, tuple):
obs, rew, done, info = result
if self.share_memory:
obs = self._decode_obs()
return obs, rew, done, info
else:
obs = result
if self.share_memory:
obs = self._decode_obs()
return obs

def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
super().seed(seed)
Expand Down