
Support Basic MARL #619


Closed · wants to merge 5 commits
1 change: 1 addition & 0 deletions setup.py
@@ -46,6 +46,7 @@ def get_extras_require() -> str:
"doc8",
"scipy",
"pillow",
"supersuit", # for padding pettingzoo
"pettingzoo>=1.17",
"pygame>=2.1.0", # pettingzoo test cases pistonball
"pymunk>=6.2.1", # pettingzoo test cases pistonball
89 changes: 60 additions & 29 deletions tianshou/env/pettingzoo_env.py
@@ -2,6 +2,7 @@
from typing import Any, Dict, List, Tuple

import gym.spaces
import supersuit as ss
from pettingzoo.utils.env import AECEnv
from pettingzoo.utils.wrappers import BaseWrapper

@@ -24,8 +25,18 @@ class PettingZooEnv(AECEnv, ABC):
Further usage can be found at :ref:`marl_example`.
"""

def __init__(self, env: BaseWrapper):
def __init__(
self,
env: BaseWrapper,
pad_observation_space: bool = False,
pad_action_space: bool = False,
):
super().__init__()
if pad_observation_space:
env = ss.pad_observations_v0(env)
if pad_action_space:
env = ss.pad_action_space_v0(env)

self.env = env
# agent idx list
self.agents = self.env.possible_agents
@@ -35,65 +46,85 @@ def __init__(self, env: BaseWrapper):

self.rewards = [0] * len(self.agents)

        # Get the global state space, if the env defines one
self.state_space: Any = self.env.state_space if hasattr(
self.env, "state_space"
) else None

# Get first observation space, assuming all agents have equal space
self.observation_space: Any = self.env.observation_space(self.agents[0])

# Get first action space, assuming all agents have equal space
self.action_space: Any = self.env.action_space(self.agents[0])

assert all(self.env.observation_space(agent) == self.observation_space
for agent in self.agents), \
"Observation spaces for all agents must be identical. Perhaps " \
"SuperSuit's pad_observations wrapper can help (useage: " \
assert all(
self.env.observation_space(agent) == self.observation_space
for agent in self.agents
), (
"Observation spaces for all agents must be identical. Perhaps "
"SuperSuit's pad_observations wrapper can help (useage: "
"`supersuit.aec_wrappers.pad_observations(env)`"
)

assert all(self.env.action_space(agent) == self.action_space
for agent in self.agents), \
"Action spaces for all agents must be identical. Perhaps " \
"SuperSuit's pad_action_space wrapper can help (useage: " \
assert all(
self.env.action_space(agent) == self.action_space for agent in self.agents
), (
"Action spaces for all agents must be identical. Perhaps "
"SuperSuit's pad_action_space wrapper can help (useage: "
"`supersuit.aec_wrappers.pad_action_space(env)`"
)

self.reset()

def reset(self, *args: Any, **kwargs: Any) -> dict:
self.env.reset(*args, **kwargs)
observation = self.env.observe(self.env.agent_selection)
if isinstance(observation, dict) and 'action_mask' in observation:
return {
'agent_id': self.env.agent_selection,
'obs': observation['observation'],
'mask':
[True if obm == 1 else False for obm in observation['action_mask']]
if isinstance(observation, dict) and "action_mask" in observation:
ret = {
"agent_id":
self.env.agent_selection,
"obs":
observation["observation"],
"mask":
[True if obm == 1 else False for obm in observation["action_mask"]],
}
else:
if isinstance(self.action_space, gym.spaces.Discrete):
return {
'agent_id': self.env.agent_selection,
'obs': observation,
'mask': [True] * self.env.action_space(self.env.agent_selection).n
ret = {
"agent_id": self.env.agent_selection,
"obs": observation,
"mask": [True] * self.env.action_space(self.env.agent_selection).n,
}
else:
return {'agent_id': self.env.agent_selection, 'obs': observation}
ret = {"agent_id": self.env.agent_selection, "obs": observation}
if hasattr(self.env, "state"):
ret["state"] = self.env.state()
return ret

def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]:
self.env.step(action)
observation, rew, done, info = self.env.last()
if isinstance(observation, dict) and 'action_mask' in observation:
if isinstance(observation, dict) and "action_mask" in observation:
obs = {
'agent_id': self.env.agent_selection,
'obs': observation['observation'],
'mask':
[True if obm == 1 else False for obm in observation['action_mask']]
"agent_id":
self.env.agent_selection,
"obs":
observation["observation"],
"mask":
[True if obm == 1 else False for obm in observation["action_mask"]],
}
else:
if isinstance(self.action_space, gym.spaces.Discrete):
obs = {
'agent_id': self.env.agent_selection,
'obs': observation,
'mask': [True] * self.env.action_space(self.env.agent_selection).n
"agent_id": self.env.agent_selection,
"obs": observation,
"mask": [True] * self.env.action_space(self.env.agent_selection).n,
}
else:
obs = {'agent_id': self.env.agent_selection, 'obs': observation}
obs = {"agent_id": self.env.agent_selection, "obs": observation}

if hasattr(self.env, "state"):
obs["state"] = self.env.state()

for agent_id, reward in self.env.rewards.items():
self.rewards[self.agent_idx[agent_id]] = reward
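For reference, here is a minimal usage sketch of the new constructor flags. It is not part of the PR; pistonball_v6 is assumed only because pistonball already appears in the test dependencies above, and whether it exposes state() depends on the installed PettingZoo version.

from pettingzoo.butterfly import pistonball_v6

from tianshou.env import PettingZooEnv

# With the flags set, SuperSuit padding is applied inside PettingZooEnv,
# so envs with heterogeneous per-agent spaces pass the identical-space asserts.
env = PettingZooEnv(
    pistonball_v6.env(),
    pad_observation_space=True,
    pad_action_space=True,
)

obs = env.reset()
print(obs["agent_id"], obs["obs"].shape)
if "state" in obs:  # only present when the wrapped env implements state()
    print(obs["state"].shape)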
117 changes: 117 additions & 0 deletions tianshou/env/pettingzoo_env_ma_wrapper.py
@@ -0,0 +1,117 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import gym
import numpy as np

from tianshou.env import BaseVectorEnv, PettingZooEnv


class MAEnvWrapper(PettingZooEnv):
"""wrap pettingzoo env to act as dummy env"""

def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]:
"""
:param Any action:
:return Tuple[Dict, List[int], bool, Dict]

Append env_id to the returned info.
"""
obs, rew, done, info = super().step(action)
info["env_id"] = self.agent_idx[obs["agent_id"]]

return obs, rew, done, info

def __len__(self) -> int:
return self.num_agents


def ma_venv_init(
self: BaseVectorEnv, p_cls: Type[BaseVectorEnv],
env_fns: List[Callable[[], gym.Env]], **kwargs: Any
) -> None:
"""add agents relevant attrs

:param BaseVectorEnv self
:param Type[BaseVectorEnv] p_cls
:param List[Callable[[], gym.Env]] env_fns
"""
p_cls.__init__(self, env_fns, **kwargs)

setattr(self, "p_lcs", p_cls)

agents = self.get_env_attr("agents", [0])[0]
agent_idx = self.get_env_attr("agent_idx", [0])[0]

setattr(self, "agents", agents)
setattr(self, "agent_idx", agent_idx)
setattr(self, "agent_num", len(agent_idx))


def ma_venv_len(self: BaseVectorEnv) -> int:
"""
:param BaseVectorEnv self
:return int: num_agent * env_num
"""
return sum(self.get_env_attr("num_agents"))


def ma_venv_step(
self: BaseVectorEnv,
action: np.ndarray,
id: Optional[Union[int, List[int], np.ndarray]] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""

:param BaseVectorEnv self:
:param np.ndarray action:
:param Optional[Union[int, List[int], np.ndarray]] id: , defaults to None
:return Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:

ma_env_id is set to true env_id when taking step.

and (agent_id, env_id) is set back to ma_env_id in returned info
"""
if id is not None:
if isinstance(id, int):
id = [id]
id = np.array(id)
for _i, _id in enumerate(id):
id[_i] = _id % self.env_num
obs_stack, rew_stack, done_stack, info_stack = self.p_cls.step(self, action, id)
for obs, info in zip(obs_stack, info_stack):
# self.env_num is the number of environments,
# while the env_num in collector is
# `the number of agents` * `the number of environments`
info["env_id"] = (
self.agent_idx[obs["agent_id"]] * self.env_num + info["env_id"]
)
return obs_stack, rew_stack, done_stack, info_stack


def get_MA_VectorEnv_cls(p_cls: Type[BaseVectorEnv]) -> Type[BaseVectorEnv]:
"""
Get the class of Multi-Agent VectorEnv.
MAVectorEnv has the layout [(agent0, env0), (agent0, env1), ...,
(agent1, env0), (agent1, env1), ...]
"""

def init_func(
self: BaseVectorEnv, env_fns: List[Callable[[], gym.Env]], **kwargs: Any
) -> None:
ma_venv_init(self, p_cls, env_fns, **kwargs)

name = "MA" + p_cls.__name__

attr_dict = {"__init__": init_func, "__len__": ma_venv_len, "step": ma_venv_step}

return type(name, (p_cls, ), attr_dict)


def get_MA_VectorEnv(
p_cls: Type[BaseVectorEnv], env_fns: List[Callable[[], gym.Env]], **kwargs: Any
) -> BaseVectorEnv:
"""
Get an instance of Multi-Agent VectorEnv.
"""

return get_MA_VectorEnv_cls(p_cls)(env_fns, **kwargs)
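A quick usage sketch of the helpers above (illustrative only; tictactoe_v3 is an assumed example env and is not part of this PR):

from pettingzoo.classic import tictactoe_v3

from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env_ma_wrapper import MAEnvWrapper, get_MA_VectorEnv


def make_env() -> MAEnvWrapper:
    # tictactoe_v3 is a two-agent env used purely for illustration
    return MAEnvWrapper(tictactoe_v3.env())


# Two copies of a two-agent env: len(venv) == num_agents * env_num == 4.
# The flat slot for (agent, env) is agent_idx * env_num + env_id, matching
# the layout [(agent0, env0), (agent0, env1), (agent1, env0), (agent1, env1)].
venv = get_MA_VectorEnv(DummyVectorEnv, [make_env, make_env])
print(len(venv), venv.agents)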