diff --git a/test/base/test_env.py b/test/base/test_env.py index b55b3034b..0a1aac6b7 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -18,7 +18,7 @@ ) from tianshou.utils import RunningMeanStd -if __name__ == '__main__': +if __name__ == "__main__": from env import MyTestEnv, NXEnv else: # pytest from test.base.env import MyTestEnv, NXEnv @@ -80,7 +80,7 @@ def test_async_env(size=10000, num=8, sleep=0.1): spent_time = time.time() while current_idx_start < len(action_list): A, B, C, D = v.step(action=act, id=env_ids) - b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D}) + b = Batch({"obs": A, "rew": B, "done": C, "info": D}) env_ids = b.info.env_id o.append(b) current_idx_start += len(act) @@ -175,7 +175,7 @@ def test_vecenv(size=10, num=8, sleep=0.001): for info in infos: assert recurse_comp(infos[0], info) - if __name__ == '__main__': + if __name__ == "__main__": t = [0] * len(venv) for i, e in enumerate(venv): t[i] = time.time() @@ -186,7 +186,7 @@ def test_vecenv(size=10, num=8, sleep=0.001): e.reset(np.where(done)[0]) t[i] = time.time() - t[i] for i, v in enumerate(venv): - print(f'{type(v)}: {t[i]:.6f}s') + print(f"{type(v)}: {t[i]:.6f}s") def assert_get(v, expected): assert v.get_env_attr("size") == expected @@ -242,6 +242,19 @@ def test_env_reset_optional_kwargs(size=10000, num=8): assert isinstance(info[0], dict) +def test_venv_wrapper_gym(num_envs: int = 4): + # Issue 697 + envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(num_envs)]) + envs = VectorEnvNormObs(envs) + obs_ref = envs.reset(return_info=False) + obs, info = envs.reset(return_info=True) + assert isinstance(obs_ref, np.ndarray) + assert isinstance(obs, np.ndarray) + assert isinstance(info, list) + assert isinstance(info[0], dict) + assert obs_ref.shape[0] == obs.shape[0] == len(info) == num_envs + + def run_align_norm_obs(raw_env, train_env, test_env, action_list): eps = np.finfo(np.float32).eps.item() raw_obs, train_obs = [raw_env.reset()], [train_env.reset()] @@ -309,7 +322,7 @@ def __init__(self): # check conversion is working properly for a batch of actions np.testing.assert_allclose( env_m.action(np.array([env_m.action_space.nvec - 1] * bsz)), - np.array([original_act] * bsz) + np.array([original_act] * bsz), ) # convert multidiscrete with different action number per # dimension to discrete action space @@ -321,7 +334,7 @@ def __init__(self): # check conversion is working properly for a batch of actions np.testing.assert_allclose( env_d.action(np.array([env_d.action_space.n - 1] * bsz)), - np.array([env_m.action_space.nvec - 1] * bsz) + np.array([env_m.action_space.nvec - 1] * bsz), ) @@ -352,9 +365,11 @@ def test_venv_wrapper_envpool_gym_reset_return_info(): assert v.shape[0] == num_envs -if __name__ == '__main__': +if __name__ == "__main__": test_venv_norm_obs() + test_venv_wrapper_gym() test_venv_wrapper_envpool() + test_venv_wrapper_envpool_gym_reset_return_info() test_env_obs_dtype() test_vecenv() test_attr_unwrapped() diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 7971d5b0c..a9b3763c8 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -135,7 +135,7 @@ def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None: gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {} rval = self.env.reset(**gym_reset_kwargs) returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and ( - isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore + isinstance(rval[1], dict) or isinstance(rval[1][0], dict) ) if returns_info: obs, info = rval @@ -173,7 +173,7 @@ def _reset_env_with_ids( gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {} rval = self.env.reset(global_ids, **gym_reset_kwargs) returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and ( - isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore + isinstance(rval[1], dict) or isinstance(rval[1][0], dict) ) if returns_info: obs_reset, info = rval diff --git a/tianshou/env/venv_wrappers.py b/tianshou/env/venv_wrappers.py index 0098b8283..7bf2888ff 100644 --- a/tianshou/env/venv_wrappers.py +++ b/tianshou/env/venv_wrappers.py @@ -41,7 +41,7 @@ def reset( self, id: Optional[Union[int, List[int], np.ndarray]] = None, **kwargs: Any, - ) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]: + ) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]: return self.venv.reset(id, **kwargs) def step( @@ -84,15 +84,15 @@ def reset( self, id: Optional[Union[int, List[int], np.ndarray]] = None, **kwargs: Any, - ) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]: - retval = self.venv.reset(id, **kwargs) - reset_returns_info = isinstance( - retval, (tuple, list) - ) and len(retval) == 2 and isinstance(retval[1], dict) - if reset_returns_info: - obs, info = retval + ) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]: + rval = self.venv.reset(id, **kwargs) + returns_info = isinstance(rval, (tuple, list)) and (len(rval) == 2) and ( + isinstance(rval[1], dict) or isinstance(rval[1][0], dict) + ) + if returns_info: + obs, info = rval else: - obs = retval + obs = rval if isinstance(obs, tuple): raise TypeError( @@ -103,7 +103,7 @@ def reset( if self.obs_rms and self.update_obs_rms: self.obs_rms.update(obs) obs = self._norm_obs(obs) - if reset_returns_info: + if returns_info: return obs, info else: return obs diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 1f12d3fdf..5e785f2bb 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -185,7 +185,7 @@ def reset( self, id: Optional[Union[int, List[int], np.ndarray]] = None, **kwargs: Any, - ) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]: + ) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]: """Reset the state of some envs and return initial observations. If id is None, reset the state of all the environments and return