Patch summary (free text ahead of the first "diff --git" header; not part of any hunk).

  * examples/mujoco/mujoco_sac.py
      Removes the second, duplicate registration of --update-per-step.
  * tianshou/data/buffer.py
      Builds ReplayBufferManager.buffers with an explicit object dtype. The
      builtin `object` is used instead of the `np.object` alias, which is
      deprecated since NumPy 1.20.
  * tianshou/env/venvs.py
      Re-wraps long lines only (assert statements, docstrings and
      super().__init__ calls).
  * tianshou/env/worker/subproc.py
      Stores the last seed on the worker as `_seed` and calls
      `seed(self._seed)` on the `action_space` object received from the
      worker before returning it (issue #299). Also re-wraps long lines.

diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py
index d7ead224e..c07877324 100644
--- a/examples/mujoco/mujoco_sac.py
+++ b/examples/mujoco/mujoco_sac.py
@@ -33,7 +33,6 @@ def get_args():
     parser.add_argument('--step-per-epoch', type=int, default=40000)
     parser.add_argument('--step-per-collect', type=int, default=4)
     parser.add_argument('--update-per-step', type=float, default=0.25)
-    parser.add_argument('--update-per-step', type=int, default=1)
     parser.add_argument('--pre-collect-step', type=int, default=10000)
     parser.add_argument('--batch-size', type=int, default=256)
     parser.add_argument('--hidden-sizes', type=int,
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 05f5ecef9..542d47452 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -454,7 +454,7 @@ class ReplayBufferManager(ReplayBuffer):
 
     def __init__(self, buffer_list: List[ReplayBuffer]) -> None:
         self.buffer_num = len(buffer_list)
-        self.buffers = np.array(buffer_list)
+        self.buffers = np.array(buffer_list, dtype=object)
         offset, size = [], 0
         buffer_type = type(self.buffers[0])
         kwargs = self.buffers[0].options
diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py
index ddf7acaf2..36c613e5f 100644
--- a/tianshou/env/venvs.py
+++ b/tianshou/env/venvs.py
@@ -43,8 +43,7 @@ def seed(self, seed):
 
         Otherwise, the outputs of these envs may be the same with each other.
 
-    :param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith
-        env.
+    :param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith env.
     :param worker_fn: a callable worker, ``worker_fn(env_fns[i])`` generates a
         worker which contains the i-th env.
     :param int wait_num: use in asynchronous simulation if the time cost of
@@ -75,13 +74,11 @@ def __init__(
 
         self.env_num = len(env_fns)
         self.wait_num = wait_num or len(env_fns)
-        assert (
-            1 <= self.wait_num <= len(env_fns)
-        ), f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
+        assert 1 <= self.wait_num <= len(env_fns), \
+            f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
         self.timeout = timeout
-        assert (
-            self.timeout is None or self.timeout > 0
-        ), f"timeout is {timeout}, it should be positive if provided!"
+        assert self.timeout is None or self.timeout > 0, \
+            f"timeout is {timeout}, it should be positive if provided!"
         self.is_async = self.wait_num != len(env_fns) or timeout is not None
         self.waiting_conn: List[EnvWorker] = []
         # environments in self.ready_id is actually ready
@@ -94,9 +91,8 @@ def __init__(
         self.is_closed = False
 
     def _assert_is_not_closed(self) -> None:
-        assert not self.is_closed, (
-            f"Methods of {self.__class__.__name__} cannot be called after "
-            "close.")
+        assert not self.is_closed, \
+            f"Methods of {self.__class__.__name__} cannot be called after close."
 
     def __len__(self) -> int:
         """Return len(self), which is the number of environments."""
@@ -106,9 +102,8 @@ def __getattribute__(self, key: str) -> Any:
         """Switch the attribute getter depending on the key.
 
         Any class who inherits ``gym.Env`` will inherit some attributes, like
-        ``action_space``. However, we would like the attribute lookup to go
-        straight into the worker (in fact, this vector env's action_space is
-        always None).
+        ``action_space``. However, we would like the attribute lookup to go straight
+        into the worker (in fact, this vector env's action_space is always None).
         """
         if key in ['metadata', 'reward_range', 'spec', 'action_space',
                    'observation_space']:  # reserved keys in gym.Env
@@ -119,9 +114,8 @@ def __getattribute__(self, key: str) -> Any:
     def __getattr__(self, key: str) -> List[Any]:
         """Fetch a list of env attributes.
 
-        This function tries to retrieve an attribute from each individual
-        wrapped environment, if it does not belong to the wrapping vector
-        environment class.
+        This function tries to retrieve an attribute from each individual wrapped
+        environment, if it does not belong to the wrapping vector environment class.
         """
         return [getattr(worker, key) for worker in self.workers]
 
@@ -136,12 +130,10 @@ def _wrap_id(
 
     def _assert_id(self, id: List[int]) -> None:
         for i in id:
-            assert (
-                i not in self.waiting_id
-            ), f"Cannot interact with environment {i} which is stepping now."
-            assert (
-                i in self.ready_id
-            ), f"Can only interact with ready environments {self.ready_id}."
+            assert i not in self.waiting_id, \
+                f"Cannot interact with environment {i} which is stepping now."
+            assert i in self.ready_id, \
+                f"Can only interact with ready environments {self.ready_id}."
 
     def reset(
         self, id: Optional[Union[int, List[int], np.ndarray]] = None
@@ -178,8 +170,7 @@ def step(
 
         :return: A tuple including four items:
 
-            * ``obs`` a numpy.ndarray, the agent's observation of current \
-                environments
+            * ``obs`` a numpy.ndarray, the agent's observation of current environments
             * ``rew`` a numpy.ndarray, the amount of rewards returned after \
                 previous actions
             * ``done`` a numpy.ndarray, whether these episodes have ended, in \
@@ -294,8 +285,7 @@ def __init__(
         wait_num: Optional[int] = None,
         timeout: Optional[float] = None,
     ) -> None:
-        super().__init__(
-            env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
+        super().__init__(env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
 
 
 class SubprocVectorEnv(BaseVectorEnv):
@@ -316,8 +306,7 @@ def __init__(
         def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
             return SubprocEnvWorker(fn, share_memory=False)
 
-        super().__init__(
-            env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
+        super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
 
 
 class ShmemVectorEnv(BaseVectorEnv):
@@ -340,8 +329,7 @@ def __init__(
         def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
             return SubprocEnvWorker(fn, share_memory=True)
 
-        super().__init__(
-            env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
+        super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
 
 
 class RayVectorEnv(BaseVectorEnv):
@@ -369,5 +357,4 @@ def __init__(
             ) from e
         if not ray.is_initialized():
             ray.init()
-        super().__init__(
-            env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout)
+        super().__init__(env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout)
diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py
index 02acc21fb..226c314ff 100644
--- a/tianshou/env/worker/subproc.py
+++ b/tianshou/env/worker/subproc.py
@@ -31,10 +31,7 @@ class ShArray:
     """Wrapper of multiprocessing Array."""
 
     def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
-        self.arr = Array(
-            _NP_TO_CT[dtype.type],
-            int(np.prod(shape)),
-        )
+        self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
         self.dtype = dtype
         self.shape = shape
 
@@ -143,10 +140,14 @@ def __init__(
         self.process = Process(target=_worker, args=args, daemon=True)
         self.process.start()
         self.child_remote.close()
+        self._seed = None
 
     def __getattr__(self, key: str) -> Any:
         self.parent_remote.send(["getattr", key])
-        return self.parent_remote.recv()
+        result = self.parent_remote.recv()
+        if key == "action_space":  # issue #299
+            result.seed(self._seed)
+        return result
 
     def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
         def decode_obs(
@@ -185,11 +186,9 @@ def wait(  # type: ignore
                 if remain_time <= 0:
                     break
             # connection.wait hangs if the list is empty
-            new_ready_conns = connection.wait(
-                remain_conns, timeout=remain_time)
+            new_ready_conns = connection.wait(remain_conns, timeout=remain_time)
             ready_conns.extend(new_ready_conns)  # type: ignore
-            remain_conns = [
-                conn for conn in remain_conns if conn not in ready_conns]
+            remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
         return [workers[conns.index(con)] for con in ready_conns]
 
     def send_action(self, action: np.ndarray) -> None:
@@ -205,7 +204,9 @@ def get_result(
 
     def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
         self.parent_remote.send(["seed", seed])
-        return self.parent_remote.recv()
+        result = self.parent_remote.recv()
+        self._seed = result[0] if result is not None else seed
+        return result
 
     def render(self, **kwargs: Any) -> Any:
         self.parent_remote.send(["render", kwargs])