diff --git a/setup.py b/setup.py
index 83ac36580..ae0b936de 100644
--- a/setup.py
+++ b/setup.py
@@ -46,6 +46,7 @@ def get_extras_require() -> str:
         "doc8",
         "scipy",
         "pillow",
+        "supersuit",  # for padding pettingzoo envs
         "pettingzoo>=1.17",
         "pygame>=2.1.0",  # pettingzoo test cases pistonball
         "pymunk>=6.2.1",  # pettingzoo test cases pistonball
diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py
index c406872dc..956b3a4c0 100644
--- a/tianshou/env/pettingzoo_env.py
+++ b/tianshou/env/pettingzoo_env.py
@@ -2,6 +2,7 @@
 from typing import Any, Dict, List, Tuple
 
 import gym.spaces
+import supersuit as ss
 from pettingzoo.utils.env import AECEnv
 from pettingzoo.utils.wrappers import BaseWrapper
 
@@ -24,8 +25,18 @@ class PettingZooEnv(AECEnv, ABC):
     Further usage can be found at :ref:`marl_example`.
     """
 
-    def __init__(self, env: BaseWrapper):
+    def __init__(
+        self,
+        env: BaseWrapper,
+        pad_observation_space: bool = False,
+        pad_action_space: bool = False,
+    ):
         super().__init__()
+        if pad_observation_space:
+            env = ss.pad_observations_v0(env)
+        if pad_action_space:
+            env = ss.pad_action_space_v0(env)
+
         self.env = env
         # agent idx list
         self.agents = self.env.possible_agents
@@ -35,65 +46,85 @@ def __init__(self, env: BaseWrapper):
 
         self.rewards = [0] * len(self.agents)
 
+        # Get state space if the env provides one, assuming it is shared
+        # by all agents
+        self.state_space: Any = self.env.state_space if hasattr(
+            self.env, "state_space"
+        ) else None
+
         # Get first observation space, assuming all agents have equal space
         self.observation_space: Any = self.env.observation_space(self.agents[0])
 
         # Get first action space, assuming all agents have equal space
         self.action_space: Any = self.env.action_space(self.agents[0])
 
-        assert all(self.env.observation_space(agent) == self.observation_space
-                   for agent in self.agents), \
-            "Observation spaces for all agents must be identical. Perhaps " \
-            "SuperSuit's pad_observations wrapper can help (useage: " \
+        assert all(
+            self.env.observation_space(agent) == self.observation_space
+            for agent in self.agents
+        ), (
+            "Observation spaces for all agents must be identical. Perhaps "
+            "SuperSuit's pad_observations wrapper can help (usage: "
             "`supersuit.aec_wrappers.pad_observations(env)`"
+        )
 
-        assert all(self.env.action_space(agent) == self.action_space
-                   for agent in self.agents), \
-            "Action spaces for all agents must be identical. Perhaps " \
-            "SuperSuit's pad_action_space wrapper can help (useage: " \
+        assert all(
+            self.env.action_space(agent) == self.action_space for agent in self.agents
+        ), (
+            "Action spaces for all agents must be identical. Perhaps "
+            "SuperSuit's pad_action_space wrapper can help (usage: "
             "`supersuit.aec_wrappers.pad_action_space(env)`"
+        )
 
         self.reset()
 
     def reset(self, *args: Any, **kwargs: Any) -> dict:
         self.env.reset(*args, **kwargs)
         observation = self.env.observe(self.env.agent_selection)
-        if isinstance(observation, dict) and 'action_mask' in observation:
-            return {
-                'agent_id': self.env.agent_selection,
-                'obs': observation['observation'],
-                'mask':
-                [True if obm == 1 else False for obm in observation['action_mask']]
+        if isinstance(observation, dict) and "action_mask" in observation:
+            ret = {
+                "agent_id": self.env.agent_selection,
+                "obs": observation["observation"],
+                "mask":
+                [True if obm == 1 else False for obm in observation["action_mask"]],
             }
         else:
             if isinstance(self.action_space, gym.spaces.Discrete):
-                return {
-                    'agent_id': self.env.agent_selection,
-                    'obs': observation,
-                    'mask': [True] * self.env.action_space(self.env.agent_selection).n
+                ret = {
+                    "agent_id": self.env.agent_selection,
+                    "obs": observation,
+                    "mask": [True] * self.env.action_space(self.env.agent_selection).n,
                 }
             else:
-                return {'agent_id': self.env.agent_selection, 'obs': observation}
+                ret = {"agent_id": self.env.agent_selection, "obs": observation}
+        if hasattr(self.env, "state"):
+            ret["state"] = self.env.state()
+        return ret
 
     def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]:
         self.env.step(action)
         observation, rew, done, info = self.env.last()
-        if isinstance(observation, dict) and 'action_mask' in observation:
+        if isinstance(observation, dict) and "action_mask" in observation:
             obs = {
-                'agent_id': self.env.agent_selection,
-                'obs': observation['observation'],
-                'mask':
-                [True if obm == 1 else False for obm in observation['action_mask']]
+                "agent_id": self.env.agent_selection,
+                "obs": observation["observation"],
+                "mask":
+                [True if obm == 1 else False for obm in observation["action_mask"]],
             }
         else:
             if isinstance(self.action_space, gym.spaces.Discrete):
                 obs = {
-                    'agent_id': self.env.agent_selection,
-                    'obs': observation,
-                    'mask': [True] * self.env.action_space(self.env.agent_selection).n
+                    "agent_id": self.env.agent_selection,
+                    "obs": observation,
+                    "mask": [True] * self.env.action_space(self.env.agent_selection).n,
                 }
             else:
-                obs = {'agent_id': self.env.agent_selection, 'obs': observation}
+                obs = {"agent_id": self.env.agent_selection, "obs": observation}
+
+        if hasattr(self.env, "state"):
+            obs["state"] = self.env.state()
 
         for agent_id, reward in self.env.rewards.items():
             self.rewards[self.agent_idx[agent_id]] = reward
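
A minimal usage sketch of the new padding flags (not part of the diff). The choice of env is an assumption for illustration: simple_world_comm_v2 is one PettingZoo env whose agents have unequal observation spaces, so without padding it would trip the asserts above.

    from pettingzoo.mpe import simple_world_comm_v2

    from tianshou.env import PettingZooEnv

    # pad_observation_space / pad_action_space apply SuperSuit's pad
    # wrappers before the identical-space asserts run.
    env = PettingZooEnv(
        simple_world_comm_v2.env(),
        pad_observation_space=True,
        pad_action_space=True,
    )
    obs = env.reset()  # {"agent_id": ..., "obs": ..., "mask": ...}
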
Perhaps " + "SuperSuit's pad_action_space wrapper can help (useage: " "`supersuit.aec_wrappers.pad_action_space(env)`" + ) self.reset() def reset(self, *args: Any, **kwargs: Any) -> dict: self.env.reset(*args, **kwargs) observation = self.env.observe(self.env.agent_selection) - if isinstance(observation, dict) and 'action_mask' in observation: - return { - 'agent_id': self.env.agent_selection, - 'obs': observation['observation'], - 'mask': - [True if obm == 1 else False for obm in observation['action_mask']] + if isinstance(observation, dict) and "action_mask" in observation: + ret = { + "agent_id": + self.env.agent_selection, + "obs": + observation["observation"], + "mask": + [True if obm == 1 else False for obm in observation["action_mask"]], } else: if isinstance(self.action_space, gym.spaces.Discrete): - return { - 'agent_id': self.env.agent_selection, - 'obs': observation, - 'mask': [True] * self.env.action_space(self.env.agent_selection).n + ret = { + "agent_id": self.env.agent_selection, + "obs": observation, + "mask": [True] * self.env.action_space(self.env.agent_selection).n, } else: - return {'agent_id': self.env.agent_selection, 'obs': observation} + ret = {"agent_id": self.env.agent_selection, "obs": observation} + if hasattr(self.env, "state"): + ret["state"] = self.env.state() + return ret def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]: self.env.step(action) observation, rew, done, info = self.env.last() - if isinstance(observation, dict) and 'action_mask' in observation: + if isinstance(observation, dict) and "action_mask" in observation: obs = { - 'agent_id': self.env.agent_selection, - 'obs': observation['observation'], - 'mask': - [True if obm == 1 else False for obm in observation['action_mask']] + "agent_id": + self.env.agent_selection, + "obs": + observation["observation"], + "mask": + [True if obm == 1 else False for obm in observation["action_mask"]], } else: if isinstance(self.action_space, gym.spaces.Discrete): obs = { - 'agent_id': self.env.agent_selection, - 'obs': observation, - 'mask': [True] * self.env.action_space(self.env.agent_selection).n + "agent_id": self.env.agent_selection, + "obs": observation, + "mask": [True] * self.env.action_space(self.env.agent_selection).n, } else: - obs = {'agent_id': self.env.agent_selection, 'obs': observation} + obs = {"agent_id": self.env.agent_selection, "obs": observation} + + if hasattr(self.env, "state"): + obs["state"] = self.env.state() for agent_id, reward in self.env.rewards.items(): self.rewards[self.agent_idx[agent_id]] = reward diff --git a/tianshou/env/pettingzoo_env_ma_wrapper.py b/tianshou/env/pettingzoo_env_ma_wrapper.py new file mode 100644 index 000000000..93cfc649b --- /dev/null +++ b/tianshou/env/pettingzoo_env_ma_wrapper.py @@ -0,0 +1,117 @@ +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + +import gym +import numpy as np + +from tianshou.env import BaseVectorEnv, PettingZooEnv + + +class MAEnvWrapper(PettingZooEnv): + """wrap pettingzoo env to act as dummy env""" + + def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]: + """ + :param Any action: + :return Tuple[Dict, List[int], bool, Dict] + + Append env_id to the returned info. 
+ """ + obs, rew, done, info = super().step(action) + info["env_id"] = self.agent_idx[obs["agent_id"]] + + return obs, rew, done, info + + def __len__(self) -> int: + return self.num_agents + + +def ma_venv_init( + self: BaseVectorEnv, p_cls: Type[BaseVectorEnv], + env_fns: List[Callable[[], gym.Env]], **kwargs: Any +) -> None: + """add agents relevant attrs + + :param BaseVectorEnv self + :param Type[BaseVectorEnv] p_cls + :param List[Callable[[], gym.Env]] env_fns + """ + p_cls.__init__(self, env_fns, **kwargs) + + setattr(self, "p_lcs", p_cls) + + agents = self.get_env_attr("agents", [0])[0] + agent_idx = self.get_env_attr("agent_idx", [0])[0] + + setattr(self, "agents", agents) + setattr(self, "agent_idx", agent_idx) + setattr(self, "agent_num", len(agent_idx)) + + +def ma_venv_len(self: BaseVectorEnv) -> int: + """ + :param BaseVectorEnv self + :return int: num_agent * env_num + """ + return sum(self.get_env_attr("num_agents")) + + +def ma_venv_step( + self: BaseVectorEnv, + action: np.ndarray, + id: Optional[Union[int, List[int], np.ndarray]] = None, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """ + + :param BaseVectorEnv self: + :param np.ndarray action: + :param Optional[Union[int, List[int], np.ndarray]] id: , defaults to None + :return Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + + ma_env_id is set to true env_id when taking step. + + and (agent_id, env_id) is set back to ma_env_id in returned info + """ + if id is not None: + if isinstance(id, int): + id = [id] + id = np.array(id) + for _i, _id in enumerate(id): + id[_i] = _id % self.env_num + obs_stack, rew_stack, done_stack, info_stack = self.p_cls.step(self, action, id) + for obs, info in zip(obs_stack, info_stack): + # self.env_num is the number of environments, + # while the env_num in collector is + # `the number of agents` * `the number of environments` + info["env_id"] = ( + self.agent_idx[obs["agent_id"]] * self.env_num + info["env_id"] + ) + return obs_stack, rew_stack, done_stack, info_stack + + +def get_MA_VectorEnv_cls(p_cls: Type[BaseVectorEnv]) -> Type[BaseVectorEnv]: + """ + Get the class of Multi-Agent VectorEnv. + MAVectorEnv has the layout [(agent0, env0), (agent0, env1), ..., + (agent1, env0), (agent1, env1), ...] + """ + + def init_func( + self: BaseVectorEnv, env_fns: List[Callable[[], gym.Env]], **kwargs: Any + ) -> None: + ma_venv_init(self, p_cls, env_fns, **kwargs) + + name = "MA" + p_cls.__name__ + + attr_dict = {"__init__": init_func, "__len__": ma_venv_len, "step": ma_venv_step} + + return type(name, (p_cls, ), attr_dict) + + +def get_MA_VectorEnv( + p_cls: Type[BaseVectorEnv], env_fns: List[Callable[[], gym.Env]], **kwargs: Any +) -> BaseVectorEnv: + """ + Get an instance of Multi-Agent VectorEnv. + """ + + return get_MA_VectorEnv_cls(p_cls)(env_fns, **kwargs)