diff --git a/tianshou/env/vecenv/asyncenv.py b/tianshou/env/vecenv/asyncenv.py index 00d1e51ca..75ffc4465 100644 --- a/tianshou/env/vecenv/asyncenv.py +++ b/tianshou/env/vecenv/asyncenv.py @@ -52,10 +52,12 @@ def _assert_and_transform_id(self, return id def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray: + self._assert_is_not_closed() id = self._assert_and_transform_id(id) return super().reset(id) def render(self, **kwargs) -> List[Any]: + self._assert_is_not_closed() if len(self.waiting_id) > 0: raise RuntimeError( f"Environments {self.waiting_id} are still " @@ -80,6 +82,7 @@ def step(self, (initially they are env_ids of all the environments). If action is ``None``, fetch unfinished step() calls instead. """ + self._assert_is_not_closed() if action is not None: id = self._assert_and_transform_id(id) assert len(action) == len(id) diff --git a/tianshou/env/vecenv/base.py b/tianshou/env/vecenv/base.py index b6c160dab..f4897b5c4 100644 --- a/tianshou/env/vecenv/base.py +++ b/tianshou/env/vecenv/base.py @@ -48,6 +48,10 @@ def __len__(self) -> int: """Return len(self), which is the number of environments.""" return self.env_num + def __del__(self) -> None: + """Close the environments before being garbage collected.""" + self.close() + def __getattribute__(self, key: str): """Switch between the default attribute getter or one looking at wrapped environment level depending on the key.""" diff --git a/tianshou/env/vecenv/shmemenv.py b/tianshou/env/vecenv/shmemenv.py index a764ba6d3..9eda9045e 100644 --- a/tianshou/env/vecenv/shmemenv.py +++ b/tianshou/env/vecenv/shmemenv.py @@ -124,6 +124,7 @@ def step(self, action: np.ndarray, id: Optional[Union[int, List[int]]] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + self._assert_is_not_closed() if id is None: id = range(self.env_num) elif np.isscalar(id): @@ -140,6 +141,7 @@ def step(self, return obs, rew, done, info def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray: +
self._assert_is_not_closed() if id is None: id = range(self.env_num) elif np.isscalar(id): diff --git a/tianshou/env/vecenv/subproc.py b/tianshou/env/vecenv/subproc.py index 9b8d8e2f3..de476afea 100644 --- a/tianshou/env/vecenv/subproc.py +++ b/tianshou/env/vecenv/subproc.py @@ -60,6 +60,7 @@ def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None: c.close() def __getattr__(self, key): + self._assert_is_not_closed() for p in self.parent_remote: p.send(['getattr', key]) return [p.recv() for p in self.parent_remote] @@ -68,6 +69,7 @@ def step(self, action: np.ndarray, id: Optional[Union[int, List[int]]] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + self._assert_is_not_closed() if id is None: id = range(self.env_num) elif np.isscalar(id): @@ -80,6 +82,7 @@ def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray: + self._assert_is_not_closed() if id is None: id = range(self.env_num) elif np.isscalar(id): @@ -90,6 +93,7 @@ def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray: return obs def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]: + self._assert_is_not_closed() if np.isscalar(seed): seed = [seed + _ for _ in range(self.env_num)] elif seed is None: @@ -99,11 +103,13 @@ def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]: return [p.recv() for p in self.parent_remote] def render(self, **kwargs) -> List[Any]: + self._assert_is_not_closed() for p in self.parent_remote: p.send(['render', kwargs]) return [p.recv() for p in self.parent_remote] def close(self) -> List[Any]: + """Close all environments; return [] if they are already closed.""" if self.closed: return [] for p in self.parent_remote: @@ -113,3 +119,8 @@ def close(self) -> List[Any]: for p in self.processes: p.join() return result + + def _assert_is_not_closed(self) -> None: + """Assert that this env is not closed (``self.closed`` is falsy).""" + assert not self.closed, \ + f"Methods of {self.__class__.__name__} cannot be called after close()."