这是indexloc提供的服务,不要输入任何密码
Skip to content

Fix venv wrapper reset retval error with gym env #712

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions test/base/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)
from tianshou.utils import RunningMeanStd

if __name__ == '__main__':
if __name__ == "__main__":
from env import MyTestEnv, NXEnv
else: # pytest
from test.base.env import MyTestEnv, NXEnv
Expand Down Expand Up @@ -80,7 +80,7 @@ def test_async_env(size=10000, num=8, sleep=0.1):
spent_time = time.time()
while current_idx_start < len(action_list):
A, B, C, D = v.step(action=act, id=env_ids)
b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D})
b = Batch({"obs": A, "rew": B, "done": C, "info": D})
env_ids = b.info.env_id
o.append(b)
current_idx_start += len(act)
Expand Down Expand Up @@ -175,7 +175,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
for info in infos:
assert recurse_comp(infos[0], info)

if __name__ == '__main__':
if __name__ == "__main__":
t = [0] * len(venv)
for i, e in enumerate(venv):
t[i] = time.time()
Expand All @@ -186,7 +186,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
e.reset(np.where(done)[0])
t[i] = time.time() - t[i]
for i, v in enumerate(venv):
print(f'{type(v)}: {t[i]:.6f}s')
print(f"{type(v)}: {t[i]:.6f}s")

def assert_get(v, expected):
assert v.get_env_attr("size") == expected
Expand Down Expand Up @@ -242,6 +242,19 @@ def test_env_reset_optional_kwargs(size=10000, num=8):
assert isinstance(info[0], dict)


def test_venv_wrapper_gym(num_envs: int = 4):
    """Check ``reset`` return values of a wrapped vector env over gym envs.

    Builds ``num_envs`` CartPole-v1 environments inside a ``DummyVectorEnv``,
    wraps them in ``VectorEnvNormObs``, and verifies both reset call styles:

    * ``reset(return_info=False)`` yields a single ``np.ndarray``;
    * ``reset(return_info=True)`` yields an ``(obs, info)`` pair where
      ``obs`` is an ``np.ndarray`` and ``info`` is a list of dicts.

    Every returned container must have one entry per environment.

    :param num_envs: number of environments to create.
    """
    # Issue 697
    env_fns = [lambda: gym.make("CartPole-v1") for _ in range(num_envs)]
    envs = VectorEnvNormObs(DummyVectorEnv(env_fns))

    obs_ref = envs.reset(return_info=False)
    obs, info = envs.reset(return_info=True)

    # Observations come back as arrays in both call styles.
    for observations in (obs_ref, obs):
        assert isinstance(observations, np.ndarray)

    # The info payload is a per-environment list of dicts.
    assert isinstance(info, list)
    assert isinstance(info[0], dict)

    # One leading entry per environment in every returned container.
    assert obs_ref.shape[0] == num_envs
    assert obs.shape[0] == num_envs
    assert len(info) == num_envs


def run_align_norm_obs(raw_env, train_env, test_env, action_list):
eps = np.finfo(np.float32).eps.item()
raw_obs, train_obs = [raw_env.reset()], [train_env.reset()]
Expand Down Expand Up @@ -309,7 +322,7 @@ def __init__(self):
# check conversion is working properly for a batch of actions
np.testing.assert_allclose(
env_m.action(np.array([env_m.action_space.nvec - 1] * bsz)),
np.array([original_act] * bsz)
np.array([original_act] * bsz),
)
# convert multidiscrete with different action number per
# dimension to discrete action space
Expand All @@ -321,7 +334,7 @@ def __init__(self):
# check conversion is working properly for a batch of actions
np.testing.assert_allclose(
env_d.action(np.array([env_d.action_space.n - 1] * bsz)),
np.array([env_m.action_space.nvec - 1] * bsz)
np.array([env_m.action_space.nvec - 1] * bsz),
)


Expand Down Expand Up @@ -352,9 +365,11 @@ def test_venv_wrapper_envpool_gym_reset_return_info():
assert v.shape[0] == num_envs


if __name__ == '__main__':
if __name__ == "__main__":
test_venv_norm_obs()
test_venv_wrapper_gym()
test_venv_wrapper_envpool()
test_venv_wrapper_envpool_gym_reset_return_info()
test_env_obs_dtype()
test_vecenv()
test_attr_unwrapped()
Expand Down
4 changes: 2 additions & 2 deletions tianshou/data/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None:
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
rval = self.env.reset(**gym_reset_kwargs)
returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore
isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
)
if returns_info:
obs, info = rval
Expand Down Expand Up @@ -173,7 +173,7 @@ def _reset_env_with_ids(
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
rval = self.env.reset(global_ids, **gym_reset_kwargs)
returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore
isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
)
if returns_info:
obs_reset, info = rval
Expand Down
20 changes: 10 additions & 10 deletions tianshou/env/venv_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def reset(
self,
id: Optional[Union[int, List[int], np.ndarray]] = None,
**kwargs: Any,
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
return self.venv.reset(id, **kwargs)

def step(
Expand Down Expand Up @@ -84,15 +84,15 @@ def reset(
self,
id: Optional[Union[int, List[int], np.ndarray]] = None,
**kwargs: Any,
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
retval = self.venv.reset(id, **kwargs)
reset_returns_info = isinstance(
retval, (tuple, list)
) and len(retval) == 2 and isinstance(retval[1], dict)
if reset_returns_info:
obs, info = retval
) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
rval = self.venv.reset(id, **kwargs)
returns_info = isinstance(rval, (tuple, list)) and (len(rval) == 2) and (
isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
)
if returns_info:
obs, info = rval
else:
obs = retval
obs = rval

if isinstance(obs, tuple):
raise TypeError(
Expand All @@ -103,7 +103,7 @@ def reset(
if self.obs_rms and self.update_obs_rms:
self.obs_rms.update(obs)
obs = self._norm_obs(obs)
if reset_returns_info:
if returns_info:
return obs, info
else:
return obs
Expand Down
2 changes: 1 addition & 1 deletion tianshou/env/venvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def reset(
self,
id: Optional[Union[int, List[int], np.ndarray]] = None,
**kwargs: Any,
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
"""Reset the state of some envs and return initial observations.

If id is None, reset the state of all the environments and return
Expand Down