这是indexloc提供的服务,不要输入任何密码
Skip to content

fix vecenv action_space randomness #300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/mujoco/mujoco_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=40000)
parser.add_argument('--step-per-collect', type=int, default=4)
parser.add_argument('--update-per-step', type=float, default=0.25)
parser.add_argument('--update-per-step', type=int, default=1)
parser.add_argument('--pre-collect-step', type=int, default=10000)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--hidden-sizes', type=int,
Expand Down
2 changes: 1 addition & 1 deletion tianshou/data/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ class ReplayBufferManager(ReplayBuffer):

def __init__(self, buffer_list: List[ReplayBuffer]) -> None:
self.buffer_num = len(buffer_list)
self.buffers = np.array(buffer_list)
self.buffers = np.array(buffer_list, dtype=np.object)
offset, size = [], 0
buffer_type = type(self.buffers[0])
kwargs = self.buffers[0].options
Expand Down
53 changes: 20 additions & 33 deletions tianshou/env/venvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ def seed(self, seed):

Otherwise, the outputs of these envs may be the same with each other.

:param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith
env.
:param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith env.
:param worker_fn: a callable worker, ``worker_fn(env_fns[i])`` generates a
worker which contains the i-th env.
:param int wait_num: use in asynchronous simulation if the time cost of
Expand Down Expand Up @@ -75,13 +74,11 @@ def __init__(

self.env_num = len(env_fns)
self.wait_num = wait_num or len(env_fns)
assert (
1 <= self.wait_num <= len(env_fns)
), f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
assert 1 <= self.wait_num <= len(env_fns), \
f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
self.timeout = timeout
assert (
self.timeout is None or self.timeout > 0
), f"timeout is {timeout}, it should be positive if provided!"
assert self.timeout is None or self.timeout > 0, \
f"timeout is {timeout}, it should be positive if provided!"
self.is_async = self.wait_num != len(env_fns) or timeout is not None
self.waiting_conn: List[EnvWorker] = []
# environments in self.ready_id is actually ready
Expand All @@ -94,9 +91,8 @@ def __init__(
self.is_closed = False

def _assert_is_not_closed(self) -> None:
assert not self.is_closed, (
f"Methods of {self.__class__.__name__} cannot be called after "
"close.")
assert not self.is_closed, \
f"Methods of {self.__class__.__name__} cannot be called after close."

def __len__(self) -> int:
"""Return len(self), which is the number of environments."""
Expand All @@ -106,9 +102,8 @@ def __getattribute__(self, key: str) -> Any:
"""Switch the attribute getter depending on the key.

Any class who inherits ``gym.Env`` will inherit some attributes, like
``action_space``. However, we would like the attribute lookup to go
straight into the worker (in fact, this vector env's action_space is
always None).
``action_space``. However, we would like the attribute lookup to go straight
into the worker (in fact, this vector env's action_space is always None).
"""
if key in ['metadata', 'reward_range', 'spec', 'action_space',
'observation_space']: # reserved keys in gym.Env
Expand All @@ -119,9 +114,8 @@ def __getattribute__(self, key: str) -> Any:
def __getattr__(self, key: str) -> List[Any]:
"""Fetch a list of env attributes.

This function tries to retrieve an attribute from each individual
wrapped environment, if it does not belong to the wrapping vector
environment class.
This function tries to retrieve an attribute from each individual wrapped
environment, if it does not belong to the wrapping vector environment class.
"""
return [getattr(worker, key) for worker in self.workers]

Expand All @@ -136,12 +130,10 @@ def _wrap_id(

def _assert_id(self, id: List[int]) -> None:
for i in id:
assert (
i not in self.waiting_id
), f"Cannot interact with environment {i} which is stepping now."
assert (
i in self.ready_id
), f"Can only interact with ready environments {self.ready_id}."
assert i not in self.waiting_id, \
f"Cannot interact with environment {i} which is stepping now."
assert i in self.ready_id, \
f"Can only interact with ready environments {self.ready_id}."

def reset(
self, id: Optional[Union[int, List[int], np.ndarray]] = None
Expand Down Expand Up @@ -178,8 +170,7 @@ def step(

:return: A tuple including four items:

* ``obs`` a numpy.ndarray, the agent's observation of current \
environments
* ``obs`` a numpy.ndarray, the agent's observation of current environments
* ``rew`` a numpy.ndarray, the amount of rewards returned after \
previous actions
* ``done`` a numpy.ndarray, whether these episodes have ended, in \
Expand Down Expand Up @@ -294,8 +285,7 @@ def __init__(
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
) -> None:
super().__init__(
env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
super().__init__(env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)


class SubprocVectorEnv(BaseVectorEnv):
Expand All @@ -316,8 +306,7 @@ def __init__(
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=False)

super().__init__(
env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)


class ShmemVectorEnv(BaseVectorEnv):
Expand All @@ -340,8 +329,7 @@ def __init__(
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=True)

super().__init__(
env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)


class RayVectorEnv(BaseVectorEnv):
Expand Down Expand Up @@ -369,5 +357,4 @@ def __init__(
) from e
if not ray.is_initialized():
ray.init()
super().__init__(
env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout)
super().__init__(env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout)
21 changes: 11 additions & 10 deletions tianshou/env/worker/subproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@ class ShArray:
"""Wrapper of multiprocessing Array."""

def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
self.arr = Array(
_NP_TO_CT[dtype.type],
int(np.prod(shape)),
)
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
self.dtype = dtype
self.shape = shape

Expand Down Expand Up @@ -143,10 +140,14 @@ def __init__(
self.process = Process(target=_worker, args=args, daemon=True)
self.process.start()
self.child_remote.close()
self._seed = None

def __getattr__(self, key: str) -> Any:
self.parent_remote.send(["getattr", key])
return self.parent_remote.recv()
result = self.parent_remote.recv()
if key == "action_space": # issue #299
result.seed(self._seed)
return result

def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
def decode_obs(
Expand Down Expand Up @@ -185,11 +186,9 @@ def wait( # type: ignore
if remain_time <= 0:
break
# connection.wait hangs if the list is empty
new_ready_conns = connection.wait(
remain_conns, timeout=remain_time)
new_ready_conns = connection.wait(remain_conns, timeout=remain_time)
ready_conns.extend(new_ready_conns) # type: ignore
remain_conns = [
conn for conn in remain_conns if conn not in ready_conns]
remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
return [workers[conns.index(con)] for con in ready_conns]

def send_action(self, action: np.ndarray) -> None:
Expand All @@ -205,7 +204,9 @@ def get_result(

def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
self.parent_remote.send(["seed", seed])
return self.parent_remote.recv()
result = self.parent_remote.recv()
self._seed = result[0] if result is not None else seed
return result

def render(self, **kwargs: Any) -> Any:
self.parent_remote.send(["render", kwargs])
Expand Down