Upgrade gym #613
Merged: 33 commits (upgrade-gym into master), Jun 27, 2022

Commits (33)
558a5ea  upgrade version of cartpole (Apr 26, 2022)
bbbcd81  np_random.randint --> np_random.integers (Apr 26, 2022)
c296427  try to use env.reset(seed=seed) instead of env.seed(seed) (Apr 27, 2022)
95d6d87  commit checks (Apr 27, 2022)
f0c5945  Merge branch 'master' into upgrade-gym (Trinkle23897, Apr 27, 2022)
5812300  Merge branch 'master' into upgrade-gym (ycheng517, May 11, 2022)
b1ceefd  Revert "upgrade version of cartpole" (ycheng517, May 11, 2022)
4eae57c  fix merge error (ycheng517, May 11, 2022)
61358d5  make venvs and env workers to support reset()->[obs, info] (ycheng517, May 12, 2022)
fcc2cfc  clean up (ycheng517, May 13, 2022)
2563439  add test case for reset with optional kwargs (ycheng517, May 13, 2022)
b82eb0e  satisfy checks (ycheng517, May 13, 2022)
d644280  pettingzoo reset supports return_info (ycheng517, May 13, 2022)
8ffd633  small addition (ycheng517, May 13, 2022)
861a1ba  fix mypy (ycheng517, May 13, 2022)
2e7485e  Merge branch 'master' into upgrade-gym (Trinkle23897, May 14, 2022)
364b46e  return info based on the return type of env.reset (ycheng517, May 21, 2022)
d99b5ae  switch tuple observation Exception to TypeError (ycheng517, May 21, 2022)
88d865a  remove debug prints (ycheng517, May 21, 2022)
c5c7246  Merge branch 'master' into upgrade-gym (Trinkle23897, May 21, 2022)
662a68a  check reset_returns_info once (ycheng517, May 22, 2022)
75ecd18  support reset returns info in collector (ycheng517, May 25, 2022)
c2ff71f  dynamically check reset retval in collector (ycheng517, Jun 6, 2022)
12cf50f  bump gym version to 0.23.1 and fix mypy (ycheng517, Jun 6, 2022)
6c60d53  fix lint check (ycheng517, Jun 6, 2022)
fe06182  undo changes to test_sac_with_il (ycheng517, Jun 6, 2022)
144e88a  doc formatting (ycheng517, Jun 6, 2022)
f34cfec  Merge branch 'master' into upgrade-gym (Trinkle23897, Jun 6, 2022)
5d4b9a0  undo changes to test_sac_with_il.py (ycheng517, Jun 8, 2022)
f5eef9c  undo changes to test/continuous/test_sac_with_il.py (ycheng517, Jun 8, 2022)
be7148a  test/continuous/test_sac_with_il.py (ycheng517, Jun 8, 2022)
784f749  Merge remote-tracking branch 'origin/master' into upgrade-gym (ycheng517, Jun 22, 2022)
edfcbcf  undo caching reset_return_info in subproc (ycheng517, Jun 22, 2022)
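For context on the commit sequence above: the heart of the upgrade is the gym reset API, which moved seeding into reset() and added an optional info return. A minimal sketch of the two calling conventions against gym 0.23.x, the version this PR pins (illustrative, not part of the diff):

import gym

env = gym.make("CartPole-v1")

# Legacy API (gym < 0.21): seeding is a separate call and reset returns
# only the observation.
# env.seed(42)
# obs = env.reset()

# New API (gym 0.23.x): the seed is passed to reset, and return_info=True
# makes reset return an (obs, info) tuple.
obs = env.reset(seed=42)
obs, info = env.reset(seed=42, return_info=True)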
Changes from all commits

examples/atari/atari_wrapper.py (4 additions, 1 deletion)

@@ -32,7 +32,10 @@ def __init__(self, env, noop_max=30):

     def reset(self):
         self.env.reset()
-        noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
+        if hasattr(self.unwrapped.np_random, "integers"):
+            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
+        else:
+            noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
         for _ in range(noops):
             obs, _, done, _ = self.env.step(self.noop_action)
             if done:
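A note on the hunk above: older gym versions expose a np.random.RandomState as env.np_random (which has .randint), while newer versions use a np.random.Generator (which has .integers; see commit bbbcd81). The hasattr check keeps the wrapper compatible with both. A standalone sketch of the same pattern, with sample_noops as a hypothetical helper name:

import numpy as np


def sample_noops(rng, noop_max=30):
    # Draw a noop count from either RNG flavor, mirroring the wrapper above.
    if hasattr(rng, "integers"):  # np.random.Generator (newer gym)
        return rng.integers(1, noop_max + 1)
    return rng.randint(1, noop_max + 1)  # np.random.RandomState (older gym)


print(sample_noops(np.random.default_rng(0)))  # Generator path
print(sample_noops(np.random.RandomState(0)))  # RandomState path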
setup.py (1 addition, 1 deletion)

@@ -15,7 +15,7 @@ def get_version() -> str:

 def get_install_requires() -> str:
     return [
-        "gym>=0.15.4",
+        "gym>=0.23.1",
         "tqdm",
         "numpy>1.16.0",  # https://github.com/numpy/numpy/issues/12793
         "tensorboard>=2.5.0",
test/base/env.py (7 additions, 2 deletions)

@@ -79,11 +79,16 @@ def seed(self, seed=0):
         self.rng = np.random.RandomState(seed)
         return [seed]

-    def reset(self, state=0):
+    def reset(self, state=0, seed=None, return_info=False):
+        if seed is not None:
+            self.rng = np.random.RandomState(seed)
         self.done = False
         self.do_sleep()
         self.index = state
-        return self._get_state()
+        if return_info:
+            return self._get_state(), {'key': 1, 'env': self}
+        else:
+            return self._get_state()

     def _get_reward(self):
         """Generate a non-scalar reward if ma_rew is True."""
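With this change the test env mirrors the new gym reset semantics while staying backward compatible. A usage sketch, assuming it runs from test/base/ so that env.py is importable:

from env import MyTestEnv

env = MyTestEnv(size=5)
obs = env.reset()                                # old-style call still works
obs, info = env.reset(seed=0, return_info=True)  # new optional kwargs
assert info == {'key': 1, 'env': env}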
test/base/test_collector.py (110 additions, 26 deletions)

@@ -15,6 +15,11 @@
 from tianshou.env import DummyVectorEnv, SubprocVectorEnv
 from tianshou.policy import BasePolicy

+try:
+    import envpool
+except ImportError:
+    envpool = None
+
 if __name__ == '__main__':
     from env import MyTestEnv, NXEnv
 else:  # pytest
@@ -23,14 +28,15 @@

 class MyPolicy(BasePolicy):

-    def __init__(self, dict_state=False, need_state=True):
+    def __init__(self, dict_state=False, need_state=True, action_shape=None):
         """
         :param bool dict_state: if the observation of the environment is a dict
         :param bool need_state: if the policy needs the hidden state (for RNN)
         """
         super().__init__()
         self.dict_state = dict_state
         self.need_state = need_state
+        self.action_shape = action_shape

     def forward(self, batch, state=None):
         if self.need_state:
@@ -39,8 +45,12 @@ def forward(self, batch, state=None):
         else:
             state += 1
         if self.dict_state:
-            return Batch(act=np.ones(len(batch.obs['index'])), state=state)
-        return Batch(act=np.ones(len(batch.obs)), state=state)
+            action_shape = self.action_shape if self.action_shape else len(
+                batch.obs['index']
+            )
+            return Batch(act=np.ones(action_shape), state=state)
+        action_shape = self.action_shape if self.action_shape else len(batch.obs)
+        return Batch(act=np.ones(action_shape), state=state)

     def learn(self):
         pass
@@ -77,7 +87,8 @@ def single_preprocess_fn(**kwargs):
     return Batch()


-def test_collector():
+@pytest.mark.parametrize("gym_reset_kwargs", [None, dict(return_info=True)])
+def test_collector(gym_reset_kwargs):
     writer = SummaryWriter('log/collector')
     logger = Logger(writer)
     env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]
@@ -86,52 +97,102 @@ def test_collector():
     dum = DummyVectorEnv(env_fns)
     policy = MyPolicy()
     env = env_fns[0]()
-    c0 = Collector(policy, env, ReplayBuffer(size=100), logger.preprocess_fn)
-    c0.collect(n_step=3)
+    c0 = Collector(
+        policy,
+        env,
+        ReplayBuffer(size=100),
+        logger.preprocess_fn,
+    )
+    c0.collect(n_step=3, gym_reset_kwargs=gym_reset_kwargs)
     assert len(c0.buffer) == 3
     assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0])
     assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1])
-    c0.collect(n_episode=3)
+    keys = np.zeros(100)
+    keys[:3] = 1
+    assert np.allclose(c0.buffer.info["key"], keys)
+    for e in c0.buffer.info["env"][:3]:
+        assert isinstance(e, MyTestEnv)
+    assert np.allclose(c0.buffer.info["env_id"], 0)
+    rews = np.zeros(100)
+    rews[:3] = [0, 1, 0]
+    assert np.allclose(c0.buffer.info["rew"], rews)
+    c0.collect(n_episode=3, gym_reset_kwargs=gym_reset_kwargs)
     assert len(c0.buffer) == 8
     assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0])
     assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
-    c0.collect(n_step=3, random=True)
+    assert np.allclose(c0.buffer.info["key"][:8], 1)
+    for e in c0.buffer.info["env"][:8]:
+        assert isinstance(e, MyTestEnv)
+    assert np.allclose(c0.buffer.info["env_id"][:8], 0)
+    assert np.allclose(c0.buffer.info["rew"][:8], [0, 1, 0, 1, 0, 1, 0, 1])
+    c0.collect(n_step=3, random=True, gym_reset_kwargs=gym_reset_kwargs)
+
     c1 = Collector(
         policy, venv, VectorReplayBuffer(total_size=100, buffer_num=4),
         logger.preprocess_fn
     )
-    c1.collect(n_step=8)
+    c1.collect(n_step=8, gym_reset_kwargs=gym_reset_kwargs)
     obs = np.zeros(100)
-    obs[[0, 1, 25, 26, 50, 51, 75, 76]] = [0, 1, 0, 1, 0, 1, 0, 1]
+
+    valid_indices = [0, 1, 25, 26, 50, 51, 75, 76]
+    obs[valid_indices] = [0, 1, 0, 1, 0, 1, 0, 1]
     assert np.allclose(c1.buffer.obs[:, 0], obs)
     assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
-    c1.collect(n_episode=4)
+    keys = np.zeros(100)
+    keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1]
+    assert np.allclose(c1.buffer.info["key"], keys)
+    for e in c1.buffer.info["env"][valid_indices]:
+        assert isinstance(e, MyTestEnv)
+    env_ids = np.zeros(100)
+    env_ids[valid_indices] = [0, 0, 1, 1, 2, 2, 3, 3]
+    assert np.allclose(c1.buffer.info["env_id"], env_ids)
+    rews = np.zeros(100)
+    rews[valid_indices] = [0, 1, 0, 0, 0, 0, 0, 0]
+    assert np.allclose(c1.buffer.info["rew"], rews)
+    c1.collect(n_episode=4, gym_reset_kwargs=gym_reset_kwargs)
     assert len(c1.buffer) == 16
+    valid_indices = [2, 3, 27, 52, 53, 77, 78, 79]
     obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4]
     assert np.allclose(c1.buffer.obs[:, 0], obs)
     assert np.allclose(
         c1.buffer[:].obs_next[..., 0],
         [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]
     )
-    c1.collect(n_episode=4, random=True)
+    keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1]
+    assert np.allclose(c1.buffer.info["key"], keys)
+    for e in c1.buffer.info["env"][valid_indices]:
+        assert isinstance(e, MyTestEnv)
+    env_ids[valid_indices] = [0, 0, 1, 2, 2, 3, 3, 3]
+    assert np.allclose(c1.buffer.info["env_id"], env_ids)
+    rews[valid_indices] = [0, 1, 1, 0, 1, 0, 0, 1]
+    assert np.allclose(c1.buffer.info["rew"], rews)
+    c1.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs)

     c2 = Collector(
         policy, dum, VectorReplayBuffer(total_size=100, buffer_num=4),
         logger.preprocess_fn
     )
-    c2.collect(n_episode=7)
+    c2.collect(n_episode=7, gym_reset_kwargs=gym_reset_kwargs)
     obs1 = obs.copy()
     obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2]
     obs2 = obs.copy()
     obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3]
     c2obs = c2.buffer.obs[:, 0]
     assert np.all(c2obs == obs1) or np.all(c2obs == obs2)
-    c2.reset_env()
+    c2.reset_env(gym_reset_kwargs=gym_reset_kwargs)
     c2.reset_buffer()
-    assert c2.collect(n_episode=8)['n/ep'] == 8
-    obs[[4, 5, 28, 29, 30, 54, 55, 56, 57]] = [0, 1, 0, 1, 2, 0, 1, 2, 3]
+    assert c2.collect(n_episode=8, gym_reset_kwargs=gym_reset_kwargs)['n/ep'] == 8
+    valid_indices = [4, 5, 28, 29, 30, 54, 55, 56, 57]
+    obs[valid_indices] = [0, 1, 0, 1, 2, 0, 1, 2, 3]
     assert np.all(c2.buffer.obs[:, 0] == obs)
-    c2.collect(n_episode=4, random=True)
+    keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1, 1]
+    assert np.allclose(c2.buffer.info["key"], keys)
+    for e in c2.buffer.info["env"][valid_indices]:
+        assert isinstance(e, MyTestEnv)
+    env_ids[valid_indices] = [0, 0, 1, 1, 1, 2, 2, 2, 2]
+    assert np.allclose(c2.buffer.info["env_id"], env_ids)
+    rews[valid_indices] = [0, 1, 0, 0, 1, 0, 0, 0, 1]
+    assert np.allclose(c2.buffer.info["rew"], rews)
+    c2.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs)

     # test corner case
     with pytest.raises(TypeError):
@@ -147,11 +208,12 @@ def test_collector():
             [lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]]
         )
         c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4))
-        c3.collect(n_step=6)
+        c3.collect(n_step=6, gym_reset_kwargs=gym_reset_kwargs)
         assert c3.buffer.obs.dtype == object


-def test_collector_with_async():
+@pytest.mark.parametrize("gym_reset_kwargs", [None, dict(return_info=True)])
+def test_collector_with_async(gym_reset_kwargs):
     env_lens = [2, 3, 4, 5]
     writer = SummaryWriter('log/async_collector')
     logger = Logger(writer)
@@ -163,12 +225,14 @@ def test_collector_with_async():
     policy = MyPolicy()
     bufsize = 60
     c1 = AsyncCollector(
-        policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4),
-        logger.preprocess_fn
+        policy,
+        venv,
+        VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4),
+        logger.preprocess_fn,
     )
     ptr = [0, 0, 0, 0]
     for n_episode in tqdm.trange(1, 30, desc="test async n_episode"):
-        result = c1.collect(n_episode=n_episode)
+        result = c1.collect(n_episode=n_episode, gym_reset_kwargs=gym_reset_kwargs)
         assert result["n/ep"] >= n_episode
         # check buffer data, obs and obs_next, env_id
         for i, count in enumerate(np.bincount(result["lens"], minlength=6)[2:]):
@@ -183,7 +247,7 @@ def test_collector_with_async():
             assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
     # test async n_step, for now the buffer should be full of data
     for n_step in tqdm.trange(1, 15, desc="test async n_step"):
-        result = c1.collect(n_step=n_step)
+        result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs)
         assert result["n/st"] >= n_step
         for i in range(4):
             env_len = i + 2
@@ -618,9 +682,29 @@ def test_collector_with_atari_setting():
     assert np.allclose(result2[key], result[key])


+@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
+def test_collector_envpool_gym_reset_return_info():
+    envs = envpool.make_gym("Pendulum-v0", num_envs=4, gym_reset_return_info=True)
+    policy = MyPolicy(action_shape=(len(envs), 1))
+
+    c0 = Collector(
+        policy,
+        envs,
+        VectorReplayBuffer(len(envs) * 10, len(envs)),
+        exploration_noise=True
+    )
+    c0.collect(n_step=8)
+    env_ids = np.zeros(len(envs) * 10)
+    env_ids[[0, 1, 10, 11, 20, 21, 30, 31]] = [0, 0, 1, 1, 2, 2, 3, 3]
+    assert np.allclose(c0.buffer.info["env_id"], env_ids)
+
+
 if __name__ == '__main__':
-    test_collector()
+    test_collector(gym_reset_kwargs=None)
+    test_collector(gym_reset_kwargs=dict(return_info=True))
     test_collector_with_dict_state()
     test_collector_with_ma()
     test_collector_with_atari_setting()
-    test_collector_with_async()
+    test_collector_with_async(gym_reset_kwargs=None)
+    test_collector_with_async(gym_reset_kwargs=dict(return_info=True))
+    test_collector_envpool_gym_reset_return_info()
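The tests above exercise the new gym_reset_kwargs argument, which the collector forwards verbatim to each env.reset() call so callers can opt into the new reset signature. A minimal sketch in the context of this test file (MyTestEnv and MyPolicy as defined above):

from tianshou.data import Collector, ReplayBuffer

c = Collector(MyPolicy(), MyTestEnv(size=4), ReplayBuffer(size=100))
# Forwarded to env.reset(); works with both old- and new-style signatures.
c.collect(n_step=3, gym_reset_kwargs=dict(return_info=True))
c.reset_env(gym_reset_kwargs=dict(return_info=True))

For envpool-backed envs the switch is made at construction time instead, as in test_collector_envpool_gym_reset_return_info above (requires envpool to be installed):

import envpool

# reset() on this vector env returns (obs, info) rather than just obs.
envs = envpool.make_gym("Pendulum-v0", num_envs=4, gym_reset_return_info=True)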
test/base/test_env.py (27 additions, 1 deletion)

@@ -222,6 +222,18 @@ def test_env_obs_dtype():
     assert obs.dtype == object


+def test_env_reset_optional_kwargs(size=10000, num=8):
+    env_fns = [lambda i=i: MyTestEnv(size=i) for i in range(size, size + num)]
+    test_cls = [DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv]
+    if has_ray():
+        test_cls += [RayVectorEnv]
+    for cls in test_cls:
+        v = cls(env_fns, wait_num=num // 2, timeout=1e-3)
+        _, info = v.reset(seed=1, return_info=True)
+        assert len(info) == len(env_fns)
+        assert isinstance(info[0], dict)
+
+
 def run_align_norm_obs(raw_env, train_env, test_env, action_list):
     eps = np.finfo(np.float32).eps.item()
     raw_obs, train_obs = [raw_env.reset()], [train_env.reset()]
@@ -319,11 +331,25 @@ def test_venv_wrapper_envpool():
     run_align_norm_obs(raw, train, test, actions)


-if __name__ == "__main__":
+@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
+def test_venv_wrapper_envpool_gym_reset_return_info():
+    num_envs = 4
+    env = VectorEnvNormObs(
+        envpool.make_gym("Ant-v3", num_envs=num_envs, gym_reset_return_info=True)
+    )
+    obs, info = env.reset()
+    assert obs.shape[0] == num_envs
+    for _, v in info.items():
+        if not isinstance(v, dict):
+            assert v.shape[0] == num_envs
+
+
+if __name__ == '__main__':
     test_venv_norm_obs()
     test_venv_wrapper_envpool()
     test_env_obs_dtype()
     test_vecenv()
     test_async_env()
     test_async_check_id()
+    test_env_reset_optional_kwargs()
     test_gym_wrappers()
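The vector env layer accepts the same optional reset kwargs across all worker types, which is what test_env_reset_optional_kwargs checks. A minimal sketch, again assuming it runs from test/base/ so that env.py is importable:

from tianshou.env import DummyVectorEnv

from env import MyTestEnv

venv = DummyVectorEnv([lambda i=i: MyTestEnv(size=i) for i in [2, 3]])
obs, info = venv.reset(seed=1, return_info=True)  # one info dict per sub-env
assert len(info) == 2 and isinstance(info[0], dict)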