From 35b12b8c83d3626917fd911fe13c84f9930c2c3a Mon Sep 17 00:00:00 2001
From: Yuge Zhang
Date: Thu, 25 Mar 2021 21:42:21 +0800
Subject: [PATCH 1/4] Test env finite

---
 test/base/test_env_finite.py | 212 +++++++++++++++++++++++++++++++++++
 1 file changed, 212 insertions(+)
 create mode 100644 test/base/test_env_finite.py

diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py
new file mode 100644
index 000000000..cc52b34eb
--- /dev/null
+++ b/test/base/test_env_finite.py
@@ -0,0 +1,212 @@
+import copy
+import random
+from collections import Counter
+from typing import Callable, Optional, Union, List
+
+import numpy as np
+import torch
+from gym import Env
+from tianshou.data import Collector, Batch
+from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv
+from tianshou.policy import BasePolicy
+from torch.utils.data import Dataset, DataLoader, DistributedSampler
+
+
+class DummyDataset(Dataset):
+    def __init__(self, length):
+        self.length = length
+        self.episodes = [3 * i % 5 + 1 for i in range(self.length)]
+
+    def __getitem__(self, index):
+        assert 0 <= index < self.length
+        return index, self.episodes[index]
+
+    def __len__(self):
+        return self.length
+
+
+class FiniteEnv(Env):
+    def __init__(self, dataset, num_replicas, rank):
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.loader = DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas, rank), batch_size=None)
+        self.iterator = None
+
+    def reset(self):
+        if self.iterator is None:
+            self.iterator = iter(self.loader)
+        try:
+            self.current_sample, self.step_count = next(self.iterator)
+            self.current_step = 0
+            return self.current_sample
+        except StopIteration:
+            self.iterator = None
+            return None
+
+    def step(self, action):
+        self.current_step += 1
+        assert self.current_step <= self.step_count
+        return 0, 1.0, self.current_step >= self.step_count, {'sample': self.current_sample, 'action': action, 'metric': 2.0}
+
+
+class FiniteVectorEnv(BaseVectorEnv):
+    def __init__(
+        self,
+        env_fns: List[Callable[[], Env]],
+        wait_num: Optional[int] = None,
+        timeout: Optional[float] = None,
+    ):
+        super().__init__(env_fns, wait_num, timeout)
+        self._alive_env_ids = set()
+        self._reset_alive_envs()
+        self._default_obs = self._default_info = None
+
+    def _reset_alive_envs(self):
+        if not self._alive_env_ids:
+            # starting or running out
+            self._alive_env_ids = set(range(self.env_num))
+
+    ### to workaround with tianshou's buffer and batch ###
+    def _set_default_obs(self, obs):
+        if obs is not None and self._default_obs is None:
+            self._default_obs = copy.deepcopy(obs)
+
+    def _set_default_info(self, info):
+        if info is not None and self._default_info is None:
+            self._default_info = copy.deepcopy(info)
+
+    def _get_default_obs(self):
+        return copy.deepcopy(self._default_obs)
+
+    def _get_default_info(self):
+        return copy.deepcopy(self._default_info)
+    ### END ###
+
+    def reset(self, id: Optional[Union[int, List[int], np.ndarray]] = None):
+        id = self._wrap_id(id)
+        self._reset_alive_envs()
+
+        # ask super to reset alive envs and remap to current index
+        request_id = list(filter(lambda i: i in self._alive_env_ids, id))
+        obs = [None] * len(id)
+        id2idx = {i: k for k, i in enumerate(id)}
+        if request_id:
+            for i, o in zip(request_id, super().reset(request_id)):
+                obs[id2idx[i]] = o
+        for i, o in zip(id, obs):
+            if o is None and i in self._alive_env_ids:
+                self._alive_env_ids.remove(i)
+
+        # fill empty observation with default(fake) observation
+        for o in obs:
+            self._set_default_obs(o)
+        for i in range(len(obs)):
+            if obs[i] is None:
+                obs[i] = self._get_default_obs()
+
+        if not self._alive_env_ids:
+            self.reset()
+            raise StopIteration
+
+        return np.stack(obs)
+
+    def step(self, action: np.ndarray, id: Optional[Union[int, List[int], np.ndarray]] = None):
+        id = self._wrap_id(id)
+        id2idx = {i: k for k, i in enumerate(id)}
+        request_id = list(filter(lambda i: i in self._alive_env_ids, id))
+        result = [[None, 0., False, None] for _ in range(len(id))]
+
+        # ask super to step alive envs and remap to current index
+        if request_id:
+            valid_act = np.stack([action[id2idx[i]] for i in request_id])
+            for i, r in zip(request_id, zip(*super().step(valid_act, request_id))):
+                result[id2idx[i]] = r
+
+        # logging
+        for i, r in zip(id, result):
+            if i in self._alive_env_ids:
+                self.tracker.log(*r)
+
+        # fill empty observation/info with default(fake)
+        for _, __, ___, i in result:
+            self._set_default_info(i)
+        for i in range(len(result)):
+            if result[i][0] is None:
+                result[i][0] = self._get_default_obs()
+            if result[i][3] is None:
+                result[i][3] = self._get_default_info()
+
+        return list(map(np.stack, zip(*result)))
+
+
+class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv):
+    pass
+
+
+class FiniteSubprocVectorEnv(FiniteVectorEnv, SubprocVectorEnv):
+    pass
+
+
+class AnyPolicy(BasePolicy):
+    def forward(self, batch, state=None):
+        return Batch(act=np.stack([1] * len(batch)))
+
+    def learn(self, batch):
+        pass
+
+
+def _finite_env_factory(dataset, num_replicas, rank):
+    return lambda: FiniteEnv(dataset, num_replicas, rank)
+
+
+class MetricTracker:
+    def __init__(self):
+        self.counter = Counter()
+        self.finished = set()
+
+    def log(self, obs, rew, done, info):
+        assert rew == 1.
+        index = info['sample']
+        if done:
+            assert index not in self.finished
+            self.finished.add(index)
+        self.counter[index] += 1
+
+    def validate(self):
+        assert len(self.finished) == 100
+        for k, v in self.counter.items():
+            assert v == k * 3 % 5 + 1
+
+
+def test_finite_dummy_vector_env():
+    dataset = DummyDataset(100)
+    envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)])
+    policy = AnyPolicy()
+    test_collector = Collector(policy, envs, exploration_noise=True)
+
+    for epoch in range(3):
+        envs.tracker = MetricTracker()
+        try:
+            test_collector.collect(n_step=10 ** 18)
+        except StopIteration:
+            envs.tracker.validate()
+
+
+def test_finite_subproc_vector_env():
+    dataset = DummyDataset(100)
+    envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)])
+    policy = AnyPolicy()
+    test_collector = Collector(policy, envs, exploration_noise=True)
+
+    for epoch in range(3):
+        envs.tracker = MetricTracker()
+        try:
+            test_collector.collect(n_step=10 ** 18)
+        except StopIteration:
+            envs.tracker.validate()
+
+
+if __name__ == '__main__':
+    test_finite_dummy_vector_env()
+    test_finite_subproc_vector_env()

From 580967cbd79ea9f1d7187f1e3f955abfa48f9b94 Mon Sep 17 00:00:00 2001
From: Yuge Zhang
Date: Thu, 25 Mar 2021 21:51:35 +0800
Subject: [PATCH 2/4] Fight with linter

---
 test/base/test_env_finite.py | 28 +++++++++++++++++-----------
 1 file changed, 17 insertions(+), 11 deletions(-)

diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py
index cc52b34eb..dafa01b54 100644
--- a/test/base/test_env_finite.py
+++ b/test/base/test_env_finite.py
@@ -1,10 +1,8 @@
 import copy
-import random
 from collections import Counter
 from typing import Callable, Optional, Union, List
 
 import numpy as np
-import torch
 from gym import Env
 from tianshou.data import Collector, Batch
 from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv
@@ -30,7 +28,10 @@ def __init__(self, dataset, num_replicas, rank):
         self.dataset = dataset
         self.num_replicas = num_replicas
         self.rank = rank
-        self.loader = DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas, rank), batch_size=None)
+        self.loader = DataLoader(dataset,
+                                 sampler=DistributedSampler(
+                                     dataset, num_replicas, rank),
+                                 batch_size=None)
         self.iterator = None
 
     def reset(self):
@@ -47,7 +48,8 @@ def reset(self):
     def step(self, action):
         self.current_step += 1
         assert self.current_step <= self.step_count
-        return 0, 1.0, self.current_step >= self.step_count, {'sample': self.current_sample, 'action': action, 'metric': 2.0}
+        return 0, 1.0, self.current_step >= self.step_count, \
+            {'sample': self.current_sample, 'action': action, 'metric': 2.0}
 
 
 class FiniteVectorEnv(BaseVectorEnv):
@@ -67,7 +69,7 @@ def _reset_alive_envs(self):
             # starting or running out
             self._alive_env_ids = set(range(self.env_num))
 
-    ### to workaround with tianshou's buffer and batch ###
+    # to workaround with tianshou's buffer and batch
     def _set_default_obs(self, obs):
         if obs is not None and self._default_obs is None:
             self._default_obs = copy.deepcopy(obs)
@@ -81,7 +83,7 @@ def _get_default_obs(self):
 
     def _get_default_info(self):
         return copy.deepcopy(self._default_info)
-    ### END ###
+    # END
 
     def reset(self, id: Optional[Union[int, List[int], np.ndarray]] = None):
         id = self._wrap_id(id)
@@ -111,7 +113,9 @@ def reset(self, id: Optional[Union[int, List[int], np.ndarray]] = None):
 
         return np.stack(obs)
 
-    def step(self, action: np.ndarray, id: Optional[Union[int, List[int], np.ndarray]] = None):
+    def step(self,
+             action: np.ndarray,
+             id: Optional[Union[int, List[int], np.ndarray]] = None):
         id = self._wrap_id(id)
         id2idx = {i: k for k, i in enumerate(id)}
         request_id = list(filter(lambda i: i in self._alive_env_ids, id))
@@ -181,11 +185,12 @@ def validate(self):
 
 def test_finite_dummy_vector_env():
     dataset = DummyDataset(100)
-    envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)])
+    envs = FiniteSubprocVectorEnv([
+        _finite_env_factory(dataset, 5, i) for i in range(5)])
     policy = AnyPolicy()
     test_collector = Collector(policy, envs, exploration_noise=True)
 
-    for epoch in range(3):
+    for _ in range(3):
         envs.tracker = MetricTracker()
         try:
             test_collector.collect(n_step=10 ** 18)
@@ -195,11 +200,12 @@ def test_finite_dummy_vector_env():
 
 def test_finite_subproc_vector_env():
     dataset = DummyDataset(100)
-    envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)])
+    envs = FiniteSubprocVectorEnv([
+        _finite_env_factory(dataset, 5, i) for i in range(5)])
     policy = AnyPolicy()
     test_collector = Collector(policy, envs, exploration_noise=True)
 
-    for epoch in range(3):
+    for _ in range(3):
         envs.tracker = MetricTracker()
         try:
             test_collector.collect(n_step=10 ** 18)

From 5383e424481ff937e1a7d89bc68c32023f056d68 Mon Sep 17 00:00:00 2001
From: Yuge Zhang
Date: Thu, 25 Mar 2021 22:12:47 +0800
Subject: [PATCH 3/4] Try to fix init type error

---
 test/base/test_env_finite.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py
index dafa01b54..a2fc8e46b 100644
--- a/test/base/test_env_finite.py
+++ b/test/base/test_env_finite.py
@@ -53,13 +53,8 @@ def step(self, action):
 
 
 class FiniteVectorEnv(BaseVectorEnv):
-    def __init__(
-        self,
-        env_fns: List[Callable[[], Env]],
-        wait_num: Optional[int] = None,
-        timeout: Optional[float] = None,
-    ):
-        super().__init__(env_fns, wait_num, timeout)
+    def __init__(self, env_fns: List[Callable[[], Env]], **kwargs):
+        super().__init__(env_fns, **kwargs)
         self._alive_env_ids = set()
         self._reset_alive_envs()
         self._default_obs = self._default_info = None

From aa795a19569e2bb8890c37d8d7b4ff3e4a814be9 Mon Sep 17 00:00:00 2001
From: Trinkle23897
Date: Thu, 25 Mar 2021 22:43:20 +0800
Subject: [PATCH 4/4] minor fix

---
 test/base/test_env_finite.py | 29 ++++++++++++++---------------
 1 file changed, 14 insertions(+), 15 deletions(-)

diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py
index a2fc8e46b..b670d65e9 100644
--- a/test/base/test_env_finite.py
+++ b/test/base/test_env_finite.py
@@ -1,13 +1,14 @@
+# see issue #322 for detail
+
+import gym
 import copy
+import numpy as np
 from collections import Counter
-from typing import Callable, Optional, Union, List
+from torch.utils.data import Dataset, DataLoader, DistributedSampler
 
-import numpy as np
-from gym import Env
+from tianshou.policy import BasePolicy
 from tianshou.data import Collector, Batch
 from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv
-from tianshou.policy import BasePolicy
-from torch.utils.data import Dataset, DataLoader, DistributedSampler
 
 
 class DummyDataset(Dataset):
@@ -23,15 +24,15 @@ def __init__(self, dataset, num_replicas, rank):
         return self.length
 
 
-class FiniteEnv(Env):
+class FiniteEnv(gym.Env):
     def __init__(self, dataset, num_replicas, rank):
         self.dataset = dataset
         self.num_replicas = num_replicas
         self.rank = rank
-        self.loader = DataLoader(dataset,
-                                 sampler=DistributedSampler(
-                                     dataset, num_replicas, rank),
-                                 batch_size=None)
+        self.loader = DataLoader(
+            dataset,
+            sampler=DistributedSampler(dataset, num_replicas, rank),
+            batch_size=None)
         self.iterator = None
 
     def reset(self):
@@ -53,7 +54,7 @@ def step(self, action):
 
 
 class FiniteVectorEnv(BaseVectorEnv):
-    def __init__(self, env_fns: List[Callable[[], Env]], **kwargs):
+    def __init__(self, env_fns, **kwargs):
         super().__init__(env_fns, **kwargs)
         self._alive_env_ids = set()
         self._reset_alive_envs()
         self._default_obs = self._default_info = None
@@ -80,7 +81,7 @@ def _get_default_info(self):
         return copy.deepcopy(self._default_info)
     # END
 
-    def reset(self, id: Optional[Union[int, List[int], np.ndarray]] = None):
+    def reset(self, id=None):
         id = self._wrap_id(id)
         self._reset_alive_envs()
 
@@ -108,9 +109,7 @@ def reset(self, id=None):
 
         return np.stack(obs)
 
-    def step(self,
-             action: np.ndarray,
-             id: Optional[Union[int, List[int], np.ndarray]] = None):
+    def step(self, action, id=None):
         id = self._wrap_id(id)
         id2idx = {i: k for k, i in enumerate(id)}
         request_id = list(filter(lambda i: i in self._alive_env_ids, id))