
Updated atari wrappers, fixed pre-commit #781


Merged: 4 commits, merged on Dec 4, 2022
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -9,19 +9,19 @@ repos:
# pass_filenames: false
# args: [--config-file=setup.cfg, tianshou]

- repo: https://github.com/pre-commit/mirrors-yapf
- repo: https://github.com/google/yapf
rev: v0.32.0
hooks:
- id: yapf
args: [-r]
args: [-r, -i]

- repo: https://github.com/pycqa/isort
rev: 5.10.1
hooks:
- id: isort
name: isort

- repo: https://gitlab.com/PyCQA/flake8
- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
hooks:
- id: flake8
82 changes: 67 additions & 15 deletions examples/atari/atari_wrapper.py
@@ -16,6 +16,16 @@
envpool = None


def _parse_reset_result(reset_result):
contains_info = (
isinstance(reset_result, tuple) and len(reset_result) == 2
and isinstance(reset_result[1], dict)
)
if contains_info:
return reset_result[0], reset_result[1], contains_info
return reset_result, {}, contains_info

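The new _parse_reset_result helper is what lets the wrappers below accept both the old Gym reset API (observation only) and the new one (an (obs, info) tuple). A minimal sketch of its behaviour, using plain Python lists as stand-in observations:

# Old-style reset result: just the observation.
obs, info, has_info = _parse_reset_result([0, 0, 0])
# -> obs == [0, 0, 0], info == {}, has_info is False

# New-style reset result: an (obs, info) tuple is passed straight through.
obs, info, has_info = _parse_reset_result(([0, 0, 0], {"lives": 3}))
# -> obs == [0, 0, 0], info == {"lives": 3}, has_info is True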

class NoopResetEnv(gym.Wrapper):
"""Sample initial states by taking random number of no-ops on reset.
No-op is assumed to be action 0.
@@ -30,16 +40,23 @@ def __init__(self, env, noop_max=30):
self.noop_action = 0
assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

def reset(self):
self.env.reset()
def reset(self, **kwargs):
_, info, return_info = _parse_reset_result(self.env.reset(**kwargs))
if hasattr(self.unwrapped.np_random, "integers"):
noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
else:
noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
for _ in range(noops):
obs, _, done, _ = self.env.step(self.noop_action)
step_result = self.env.step(self.noop_action)
if len(step_result) == 4:
obs, rew, done, info = step_result
else:
obs, rew, term, trunc, info = step_result
done = term or trunc
if done:
obs = self.env.reset()
obs, info, _ = _parse_reset_result(self.env.reset())
if return_info:
return obs, info
return obs


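With this change NoopResetEnv.reset performs its random no-op steps against either step API and only returns (obs, info) when the wrapped env itself did. A hedged usage sketch; the env id and the availability of the ALE Atari envs are assumptions:

import gym  # assumes a gym build with the ALE Atari envs installed

env = NoopResetEnv(gym.make("PongNoFrameskip-v4"), noop_max=30)
result = env.reset()
# Old reset API: result is the observation.
# New reset API: result mirrors the inner env and is an (obs, info) tuple.
obs = result[0] if isinstance(result, tuple) else result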
@@ -59,14 +76,24 @@ def step(self, action):
"""Step the environment with the given action. Repeat action, sum
reward, and max over last observations.
"""
obs_list, total_reward, done = [], 0., False
obs_list, total_reward = [], 0.
new_step_api = False
for _ in range(self._skip):
obs, reward, done, info = self.env.step(action)
step_result = self.env.step(action)
if len(step_result) == 4:
obs, reward, done, info = step_result
else:
obs, reward, term, trunc, info = step_result
done = term or trunc
new_step_api = True
obs_list.append(obs)
total_reward += reward
if done:
break
max_frame = np.max(obs_list[-2:], axis=0)
if new_step_api:
return max_frame, total_reward, term, trunc, info

return max_frame, total_reward, done, info


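The element-wise max over the last two skipped frames is what removes Atari sprite flicker; the change above only adds handling for the 5-tuple step result. A tiny sketch of the max-pooling piece with toy one-pixel frames:

import numpy as np

# Toy stand-ins for the last few frames collected in obs_list.
obs_list = [np.array([0]), np.array([5]), np.array([3])]

# Element-wise max of the last two frames, so sprites drawn only on
# alternate frames are never lost.
max_frame = np.max(obs_list[-2:], axis=0)
print(max_frame)  # [5]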
@@ -81,9 +108,18 @@ def __init__(self, env):
super().__init__(env)
self.lives = 0
self.was_real_done = True
self._return_info = False

def step(self, action):
obs, reward, done, info = self.env.step(action)
step_result = self.env.step(action)
if len(step_result) == 4:
obs, reward, done, info = step_result
new_step_api = False
else:
obs, reward, term, trunc, info = step_result
done = term or trunc
new_step_api = True

self.was_real_done = done
# check current lives, make loss of life terminal, then update lives to
# handle bonus lives
@@ -93,7 +129,10 @@ def step(self, action):
# frames, so its important to keep lives > 0, so that we only reset
# once the environment is actually done.
done = True
term = True
self.lives = lives
if new_step_api:
return obs, reward, term, trunc, info
return obs, reward, done, info

def reset(self):
@@ -102,12 +141,16 @@ def reset(self):
the learner need not know about any of this behind-the-scenes.
"""
if self.was_real_done:
obs = self.env.reset()
obs, info, self._return_info = _parse_reset_result(self.env.reset())
else:
# no-op step to advance from terminal/lost life state
obs = self.env.step(0)[0]
step_result = self.env.step(0)
obs, info = step_result[0], step_result[-1]
self.lives = self.env.unwrapped.ale.lives()
return obs
if self._return_info:
return obs, info
else:
return obs


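EpisodicLifeEnv still signals termination on every lost life while was_real_done records whether the underlying game actually ended, and it now remembers (via _return_info) which reset signature to mirror. A hedged usage sketch; the env id is an assumption:

import gym  # assumes a gym build with the ALE Atari envs installed

env = EpisodicLifeEnv(gym.make("BreakoutNoFrameskip-v4"))
obs = env.reset()
step_result = env.step(env.action_space.sample())
done = step_result[2] if len(step_result) == 4 else (step_result[2] or step_result[3])

if done and not env.was_real_done:
    # Only a life was lost: reset() just takes a no-op step on the inner env.
    obs = env.reset()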
class FireResetEnv(gym.Wrapper):
@@ -123,8 +166,9 @@ def __init__(self, env):
assert len(env.unwrapped.get_action_meanings()) >= 3

def reset(self):
self.env.reset()
return self.env.step(1)[0]
_, _, return_info = _parse_reset_result(self.env.reset())
obs = self.env.step(1)[0]
return (obs, {}) if return_info else obs


class WarpFrame(gym.ObservationWrapper):
@@ -204,14 +248,22 @@ def __init__(self, env, n_frames):
)

def reset(self):
obs = self.env.reset()
obs, info, return_info = _parse_reset_result(self.env.reset())
for _ in range(self.n_frames):
self.frames.append(obs)
return self._get_ob()
return (self._get_ob(), info) if return_info else self._get_ob()

def step(self, action):
obs, reward, done, info = self.env.step(action)
step_result = self.env.step(action)
if len(step_result) == 4:
obs, reward, done, info = step_result
new_step_api = False
else:
obs, reward, term, trunc, info = step_result
new_step_api = True
self.frames.append(obs)
if new_step_api:
return self._get_ob(), reward, term, trunc, info
return self._get_ob(), reward, done, info

def _get_ob(self):
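Taken together, the updated wrappers can still be stacked in the usual DeepMind-style preprocessing order regardless of which reset/step API the base env exposes. A hedged sketch; the env id, wrapper order, and exact constructor arguments are assumptions rather than anything mandated by this diff:

import gym  # assumes a gym build with the ALE Atari envs installed
import numpy as np

env = gym.make("PongNoFrameskip-v4")
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
env = EpisodicLifeEnv(env)
env = FireResetEnv(env)
env = WarpFrame(env)
env = FrameStack(env, 4)

result = env.reset()
obs = result[0] if isinstance(result, tuple) else result
print(np.asarray(obs).shape)  # expected to be roughly (4, 84, 84)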
2 changes: 1 addition & 1 deletion setup.cfg
@@ -8,7 +8,7 @@ exclude =
dist
*.egg-info
max-line-length = 87
ignore = B305,W504,B006,B008,B024
ignore = B305,W504,B006,B008,B024,W503

[yapf]
based_on_style = pep8
1 change: 1 addition & 0 deletions setup.py
@@ -23,6 +23,7 @@ def get_install_requires() -> str:
"numba>=0.51.0",
"h5py>=2.10.0", # to match tensorflow's minimal requirements
"protobuf~=3.19.0", # breaking change, sphinx fail
"packaging",
]


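The new packaging dependency is presumably used for runtime checks against the installed gym version (an assumption; the import site is outside the hunks shown here). A typical guard of that kind looks like:

import gym
from packaging import version

# Hypothetical guard: only unpack the new (obs, info) reset signature on
# gym releases that actually provide it.
if version.parse(gym.__version__) >= version.parse("0.26.0"):
    obs, info = gym.make("CartPole-v1").reset()
else:
    obs = gym.make("CartPole-v1").reset()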