diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 4380db458..804de4d6f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,11 +9,11 @@ repos:
   #       pass_filenames: false
   #       args: [--config-file=setup.cfg, tianshou]

-  - repo: https://github.com/pre-commit/mirrors-yapf
+  - repo: https://github.com/google/yapf
     rev: v0.32.0
     hooks:
       - id: yapf
-        args: [-r]
+        args: [-r, -i]

   - repo: https://github.com/pycqa/isort
     rev: 5.10.1
@@ -21,7 +21,7 @@ repos:
       - id: isort
         name: isort

-  - repo: https://gitlab.com/PyCQA/flake8
+  - repo: https://github.com/PyCQA/flake8
     rev: 4.0.1
     hooks:
       - id: flake8
diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py
index 293c1cec9..901b166ef 100644
--- a/examples/atari/atari_wrapper.py
+++ b/examples/atari/atari_wrapper.py
@@ -16,6 +16,16 @@
     envpool = None


+def _parse_reset_result(reset_result):
+    contains_info = (
+        isinstance(reset_result, tuple) and len(reset_result) == 2
+        and isinstance(reset_result[1], dict)
+    )
+    if contains_info:
+        return reset_result[0], reset_result[1], contains_info
+    return reset_result, {}, contains_info
+
+
 class NoopResetEnv(gym.Wrapper):
     """Sample initial states by taking random number of no-ops on reset.

     No-op is assumed to be action 0.
@@ -30,16 +40,23 @@ def __init__(self, env, noop_max=30):
         self.noop_action = 0
         assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

-    def reset(self):
-        self.env.reset()
+    def reset(self, **kwargs):
+        _, info, return_info = _parse_reset_result(self.env.reset(**kwargs))
         if hasattr(self.unwrapped.np_random, "integers"):
             noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
         else:
             noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
         for _ in range(noops):
-            obs, _, done, _ = self.env.step(self.noop_action)
+            step_result = self.env.step(self.noop_action)
+            if len(step_result) == 4:
+                obs, rew, done, info = step_result
+            else:
+                obs, rew, term, trunc, info = step_result
+                done = term or trunc
             if done:
-                obs = self.env.reset()
+                obs, info, _ = _parse_reset_result(self.env.reset())
+        if return_info:
+            return obs, info
         return obs

@@ -59,14 +76,24 @@ def step(self, action):
         """Step the environment with the given action.

         Repeat action, sum reward, and max over last observations.
         """
-        obs_list, total_reward, done = [], 0., False
+        obs_list, total_reward = [], 0.
+        new_step_api = False
         for _ in range(self._skip):
-            obs, reward, done, info = self.env.step(action)
+            step_result = self.env.step(action)
+            if len(step_result) == 4:
+                obs, reward, done, info = step_result
+            else:
+                obs, reward, term, trunc, info = step_result
+                done = term or trunc
+                new_step_api = True
             obs_list.append(obs)
             total_reward += reward
             if done:
                 break
         max_frame = np.max(obs_list[-2:], axis=0)
+        if new_step_api:
+            return max_frame, total_reward, term, trunc, info
+
         return max_frame, total_reward, done, info

@@ -81,9 +108,18 @@ def __init__(self, env):
         super().__init__(env)
         self.lives = 0
         self.was_real_done = True
+        self._return_info = False

     def step(self, action):
-        obs, reward, done, info = self.env.step(action)
+        step_result = self.env.step(action)
+        if len(step_result) == 4:
+            obs, reward, done, info = step_result
+            new_step_api = False
+        else:
+            obs, reward, term, trunc, info = step_result
+            done = term or trunc
+            new_step_api = True
+
         self.was_real_done = done
         # check current lives, make loss of life terminal, then update lives to
         # handle bonus lives
@@ -93,7 +129,10 @@ def step(self, action):
             # frames, so its important to keep lives > 0, so that we only reset
             # once the environment is actually done.
             done = True
+            term = True
         self.lives = lives
+        if new_step_api:
+            return obs, reward, term, trunc, info
         return obs, reward, done, info

     def reset(self):
@@ -102,12 +141,16 @@ def reset(self):
         the learner need not know about any of this behind-the-scenes.
         """
         if self.was_real_done:
-            obs = self.env.reset()
+            obs, info, self._return_info = _parse_reset_result(self.env.reset())
         else:
             # no-op step to advance from terminal/lost life state
-            obs = self.env.step(0)[0]
+            step_result = self.env.step(0)
+            obs, info = step_result[0], step_result[-1]
         self.lives = self.env.unwrapped.ale.lives()
-        return obs
+        if self._return_info:
+            return obs, info
+        else:
+            return obs


 class FireResetEnv(gym.Wrapper):
@@ -123,8 +166,9 @@ def __init__(self, env):
         assert len(env.unwrapped.get_action_meanings()) >= 3

     def reset(self):
-        self.env.reset()
-        return self.env.step(1)[0]
+        _, _, return_info = _parse_reset_result(self.env.reset())
+        obs = self.env.step(1)[0]
+        return (obs, {}) if return_info else obs


 class WarpFrame(gym.ObservationWrapper):
@@ -204,14 +248,22 @@ def __init__(self, env, n_frames):
         )

     def reset(self):
-        obs = self.env.reset()
+        obs, info, return_info = _parse_reset_result(self.env.reset())
         for _ in range(self.n_frames):
             self.frames.append(obs)
-        return self._get_ob()
+        return (self._get_ob(), info) if return_info else self._get_ob()

     def step(self, action):
-        obs, reward, done, info = self.env.step(action)
+        step_result = self.env.step(action)
+        if len(step_result) == 4:
+            obs, reward, done, info = step_result
+            new_step_api = False
+        else:
+            obs, reward, term, trunc, info = step_result
+            new_step_api = True
         self.frames.append(obs)
+        if new_step_api:
+            return self._get_ob(), reward, term, trunc, info
         return self._get_ob(), reward, done, info

     def _get_ob(self):
diff --git a/setup.cfg b/setup.cfg
index 96315242b..296035903 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -8,7 +8,7 @@ exclude =
     dist
     *.egg-info
 max-line-length = 87
-ignore = B305,W504,B006,B008,B024
+ignore = B305,W504,B006,B008,B024,W503

 [yapf]
 based_on_style = pep8
diff --git a/setup.py b/setup.py
index cbed99da2..3d96d75a5 100644
--- a/setup.py
+++ b/setup.py
@@ -23,6 +23,7 @@ def get_install_requires() -> str:
         "numba>=0.51.0",
         "h5py>=2.10.0",  # to match tensorflow's minimal requirements
         "protobuf~=3.19.0",  # breaking change, sphinx fail
+        "packaging",
     ]
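
For context, the atari_wrapper.py changes above all apply the same pattern: probe whether the wrapped environment uses the old Gym API (reset() -> obs, step() -> 4-tuple) or the new one (reset() -> (obs, info), step() -> 5-tuple with terminated/truncated) and mirror whichever it finds. The snippet below is a standalone sketch of that probing logic, not part of the patch; the environment id is only an example and assumes the Atari ROMs are installed.

    import gym

    env = gym.make("PongNoFrameskip-v4")  # example id; any Gym env works

    # reset(): old API returns `obs`, new API returns `(obs, info)`.
    reset_result = env.reset()
    if isinstance(reset_result, tuple) and len(reset_result) == 2 \
            and isinstance(reset_result[1], dict):
        obs, info = reset_result
    else:
        obs, info = reset_result, {}

    # step(): old API returns 4 values, new API returns 5 (terminated/truncated).
    step_result = env.step(env.action_space.sample())
    if len(step_result) == 4:
        obs, reward, done, info = step_result
    else:
        obs, reward, term, trunc, info = step_result
        done = term or trunc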