这是indexloc提供的服务,不要输入任何密码
Skip to content

fix #166, assure close in vec env #178

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tianshou/env/vecenv/asyncenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,12 @@ def _assert_and_transform_id(self,
return id

def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
    """Reset the selected sub-environments and return their observations.

    The instance is first checked with ``_assert_is_closed``. ``id`` is then
    normalized by ``_assert_and_transform_id``, and the normalized value is
    forwarded to the parent class's ``reset``.
    """
    self._assert_is_closed()
    target_ids = self._assert_and_transform_id(id)
    return super().reset(target_ids)

def render(self, **kwargs) -> List[Any]:
self._assert_is_closed()
if len(self.waiting_id) > 0:
raise RuntimeError(
f"Environments {self.waiting_id} are still "
Expand All @@ -80,6 +82,7 @@ def step(self,
(initially they are env_ids of all the environments). If action is
``None``, fetch unfinished step() calls instead.
"""
self._assert_is_closed()
if action is not None:
id = self._assert_and_transform_id(id)
assert len(action) == len(id)
Expand Down
4 changes: 4 additions & 0 deletions tianshou/env/vecenv/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def __len__(self) -> int:
"""Return len(self), which is the number of environments."""
return self.env_num

def __del__(self):
    """Finalizer: release resources by delegating to ``close()``.

    Invoked when the instance is about to be garbage collected. Per the
    Python data model, an exception raised here is not propagated to any
    caller; it is reported on stderr as "Exception ignored in ...".
    """
    # Look the method up through the instance, exactly like ``self.close()``.
    closer = self.close
    closer()

def __getattribute__(self, key: str):
"""Switch between the default attribute getter or one
looking at wrapped environment level depending on the key."""
Expand Down
2 changes: 2 additions & 0 deletions tianshou/env/vecenv/shmemenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def step(self,
action: np.ndarray,
id: Optional[Union[int, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
self._assert_is_closed()
if id is None:
id = range(self.env_num)
elif np.isscalar(id):
Expand All @@ -140,6 +141,7 @@ def step(self,
return obs, rew, done, info

def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
self._assert_is_closed()
if id is None:
id = range(self.env_num)
elif np.isscalar(id):
Expand Down
10 changes: 10 additions & 0 deletions tianshou/env/vecenv/subproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
c.close()

def __getattr__(self, key):
    """Fetch attribute ``key`` from every sub-environment.

    Python only calls ``__getattr__`` after normal attribute lookup fails.
    A ``['getattr', key]`` request is sent down each parent pipe, then one
    reply per pipe is collected, in pipe order.

    Raises:
        AttributeError: if ``key`` is ``closed`` or ``parent_remote``, i.e.
            the instance's own bookkeeping attribute is missing (for
            example when ``__init__`` raised before assigning it).
        AssertionError: from ``_assert_is_closed`` when the vector env has
            already been closed.
    """
    # Both names are read below (``closed`` inside _assert_is_closed, and
    # ``parent_remote`` in the loop). If they are absent, delegating would
    # re-enter __getattr__ for them forever and end in RecursionError, so
    # report the missing attribute the normal way instead.
    if key in ('closed', 'parent_remote'):
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{key}'")
    self._assert_is_closed()
    for p in self.parent_remote:
        p.send(['getattr', key])
    return [p.recv() for p in self.parent_remote]
Expand All @@ -68,6 +69,7 @@ def step(self,
action: np.ndarray,
id: Optional[Union[int, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
self._assert_is_closed()
if id is None:
id = range(self.env_num)
elif np.isscalar(id):
Expand All @@ -80,6 +82,7 @@ def step(self,
return obs, rew, done, info

def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
self._assert_is_closed()
if id is None:
id = range(self.env_num)
elif np.isscalar(id):
Expand All @@ -90,6 +93,7 @@ def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
return obs

def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
self._assert_is_closed()
if np.isscalar(seed):
seed = [seed + _ for _ in range(self.env_num)]
elif seed is None:
Expand All @@ -99,11 +103,13 @@ def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
return [p.recv() for p in self.parent_remote]

def render(self, **kwargs) -> List[Any]:
    """Render every sub-environment and gather the results.

    After checking the instance with ``_assert_is_closed``, a
    ``['render', kwargs]`` command is sent down each parent pipe. Only then
    is one reply read back per pipe, so the returned list follows pipe
    order.
    """
    self._assert_is_closed()
    pipes = self.parent_remote
    # Dispatch all commands first so the workers can render concurrently.
    for conn in pipes:
        conn.send(['render', kwargs])
    return [conn.recv() for conn in pipes]

def close(self) -> List[Any]:
"Instances of SubprocVectorEnv should not called again once"
if self.closed:
return []
for p in self.parent_remote:
Expand All @@ -113,3 +119,7 @@ def close(self) -> List[Any]:
for p in self.processes:
p.join()
return result

def _assert_is_closed(self):
    """Ensure the vector env has not been closed before it is used.

    Despite its name, this checks that the instance is *not* closed.

    Raises:
        AssertionError: if ``self.closed`` is truthy. It is raised
            explicitly rather than via an ``assert`` statement so the guard
            still runs under ``python -O``. The exception type is kept for
            callers that already catch ``AssertionError``.
    """
    if self.closed:
        raise AssertionError(
            f"Methods of {self.__class__.__name__} "
            "should not be called after closed.")