diff --git a/setup.py b/setup.py
index 80209edcb..7054619e6 100644
--- a/setup.py
+++ b/setup.py
@@ -54,6 +54,7 @@ def get_version() -> str:
         "torch>=1.4.0",
         "numba>=0.51.0",
         "h5py>=2.10.0",  # to match tensorflow's minimal requirements
+        "pettingzoo>=1.11.0",
     ],
     extras_require={
         "dev": [
@@ -73,6 +74,7 @@ def get_version() -> str:
             "pydocstyle",
             "doc8",
             "scipy",
+            "pillow",
         ],
         "atari": ["atari_py", "opencv-python"],
         "mujoco": ["mujoco_py"],
diff --git a/test/pettingzoo/pistonball_train.py b/test/pettingzoo/pistonball_train.py
new file mode 100644
index 000000000..54149c691
--- /dev/null
+++ b/test/pettingzoo/pistonball_train.py
@@ -0,0 +1,33 @@
+from pettingzoo.butterfly import pistonball_v4
+from tianshou.env.pettingzoo_env import PettingZooEnv
+from tianshou.env import DummyVectorEnv
+from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.policy import RandomPolicy, MultiAgentPolicyManager
+import numpy as np
+import torch
+
+
+def get_env():
+    return PettingZooEnv(pistonball_v4.env(continuous=False))
+
+
+train_envs = DummyVectorEnv([get_env for _ in range(10)])
+# test_envs = DummyVectorEnv([get_env for _ in range(100)])
+
+# seed
+np.random.seed(1626)
+torch.manual_seed(1626)
+train_envs.seed(1626)
+# test_envs.seed(1626)
+
+policy = MultiAgentPolicyManager([RandomPolicy()
+                                  for _ in range(len(get_env().agents))],
+                                 get_env())
+
+# collector
+train_collector = Collector(policy, train_envs,
+                            VectorReplayBuffer(6, len(train_envs)),
+                            exploration_noise=True)
+# test_collector = Collector(policy, test_envs, exploration_noise=True)
+# policy.set_eps(1)
+train_collector.collect(n_step=640, render=0.0001)
diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py
new file mode 100644
index 000000000..6924b0a4d
--- /dev/null
+++ b/tianshou/env/pettingzoo_env.py
@@ -0,0 +1,83 @@
+from pettingzoo.utils.env import AECEnv
+
+
+class PettingZooEnv(AECEnv):
+
+    def __init__(self, env):
+        self.env = env
+        # agent idx list
+        self.agents = self.env.possible_agents
+        self.agent_idx = {}
+        for i, agent_id in enumerate(self.agents):
+            self.agent_idx[agent_id] = i
+        # Get dictionaries of obs_spaces and act_spaces
+        self.observation_spaces = self.env.observation_spaces
+        self.action_spaces = self.env.action_spaces
+
+        self.rewards = [0] * len(self.agents)
+
+        # Get first observation space, assuming all agents have equal space
+        self.observation_space = self.observation_spaces[self.agents[0]]
+
+        # Get first action space, assuming all agents have equal space
+        self.action_space = self.action_spaces[self.agents[0]]
+
+        assert all(obs_space == self.observation_space
+                   for obs_space
+                   in self.env.observation_spaces.values()), \
+            "Observation spaces for all agents must be identical. Perhaps " \
+            "SuperSuit's pad_observations wrapper can help (usage: " \
+            "`supersuit.aec_wrappers.pad_observations(env)`)"
+
+        assert all(act_space == self.action_space
+                   for act_space in self.env.action_spaces.values()), \
+            "Action spaces for all agents must be identical. Perhaps " \
+            "SuperSuit's pad_action_space wrapper can help (usage: " \
+            "`supersuit.aec_wrappers.pad_action_space(env)`)"
+
+        self.reset()
+
+    def reset(self):
+        self.env.reset()
+        observation = self.env.observe(self.env.agent_selection)
+        if isinstance(observation, dict) and 'action_mask' in observation:
+            return {
+                'agent_id': self.env.agent_selection,
+                'obs': observation['observation'],
+                'mask': observation['action_mask']
+            }
+        else:
+            return {
+                'agent_id': self.env.agent_selection,
+                'obs': observation,
+                'mask': [True] * self.action_spaces[self.env.agent_selection].n
+            }
+
+    def step(self, action):
+        self.env.step(action)
+        observation, rew, done, info = self.env.last()
+        if isinstance(observation, dict) and 'action_mask' in observation:
+            obs = {
+                'agent_id': self.env.agent_selection,
+                'obs': observation['observation'],
+                'mask': observation['action_mask']
+            }
+        else:
+            obs = {
+                'agent_id': self.env.agent_selection,
+                'obs': observation,
+                'mask': [True] * self.action_spaces[self.env.agent_selection].n
+            }
+
+        for agent_id, reward in self.env.rewards.items():
+            self.rewards[self.agent_idx[agent_id]] = reward
+        return obs, self.rewards, done, info
+
+    def close(self):
+        self.env.close()
+
+    def seed(self, seed=None):
+        self.env.seed(seed)
+
+    def render(self, mode="human"):
+        return self.env.render(mode)
diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py
index 75705f4a3..dfd96acd2 100644
--- a/tianshou/policy/multiagent/mapolicy.py
+++ b/tianshou/policy/multiagent/mapolicy.py
@@ -1,9 +1,9 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
-
 import numpy as np
+from pettingzoo.utils.env import AECEnv
+from typing import Any, Dict, List, Tuple, Union, Optional
 
-from tianshou.data import Batch, ReplayBuffer
 from tianshou.policy import BasePolicy
+from tianshou.data import Batch, ReplayBuffer
 
 
 class MultiAgentPolicyManager(BasePolicy):
@@ -16,21 +16,25 @@ class MultiAgentPolicyManager(BasePolicy):
     :ref:`marl_example` can help you better understand this procedure.
     """
 
-    def __init__(self, policies: List[BasePolicy], **kwargs: Any) -> None:
+    def __init__(self, policies: List[BasePolicy], env: AECEnv, **kwargs: Any) -> None:
         super().__init__(**kwargs)
-        self.policies = policies
+        assert len(policies) == len(env.agents), \
+            "One policy must be assigned for each agent."
+        self.agent_idx = env.agent_idx
         for i, policy in enumerate(policies):
             # agent_id 0 is reserved for the environment proxy
             # (this MultiAgentPolicyManager)
-            policy.set_agent_id(i + 1)
+            policy.set_agent_id(env.agents[i])
+
+        self.policies = dict(zip(env.agents, policies))
 
     def replace_policy(self, policy: BasePolicy, agent_id: int) -> None:
         """Replace the "agent_id"th policy in this manager."""
-        self.policies[agent_id - 1] = policy
         policy.set_agent_id(agent_id)
+        self.policies[agent_id] = policy
 
     def process_fn(
-        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
+        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
         """Dispatch batch data from obs.agent_id to every policy's process_fn.
 
@@ -45,32 +49,31 @@ def process_fn(
             # Since we do not override buffer.__setattr__, here we use _meta to
             # change buffer.rew, otherwise buffer.rew = Batch() has no effect.
             save_rew, buffer._meta.rew = buffer.rew, Batch()
-        for policy in self.policies:
-            agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
+        for agent_id, policy in self.policies.items():
+            agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0]
             if len(agent_index) == 0:
-                results[f"agent_{policy.agent_id}"] = Batch()
+                results[agent_id] = Batch()
                 continue
-            tmp_batch, tmp_indices = batch[agent_index], indices[agent_index]
+            tmp_batch, tmp_indice = batch[agent_index], indice[agent_index]
             if has_rew:
-                tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
-                buffer._meta.rew = save_rew[:, policy.agent_id - 1]
-            results[f"agent_{policy.agent_id}"] = policy.process_fn(
-                tmp_batch, buffer, tmp_indices
-            )
+                tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]]
+                buffer._meta.rew = save_rew[:, self.agent_idx[agent_id]]
+            results[agent_id] = policy.process_fn(
+                tmp_batch, buffer, tmp_indice)
         if has_rew:  # restore from save_rew
             buffer._meta.rew = save_rew
         return Batch(results)
 
-    def exploration_noise(self, act: Union[np.ndarray, Batch],
-                          batch: Batch) -> Union[np.ndarray, Batch]:
+    def exploration_noise(
+        self, act: Union[np.ndarray, Batch], batch: Batch
+    ) -> Union[np.ndarray, Batch]:
         """Add exploration noise from sub-policy onto act."""
-        for policy in self.policies:
-            agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
+        for agent_id, policy in self.policies.items():
+            agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0]
             if len(agent_index) == 0:
                 continue
             act[agent_index] = policy.exploration_noise(
-                act[agent_index], batch[agent_index]
-            )
+                act[agent_index], batch[agent_index])
         return act
 
     def forward(  # type: ignore
@@ -102,9 +105,9 @@ def forward(  # type: ignore
             "agent_n": xxx}
         }
         """
-        results: List[Tuple[bool, np.ndarray, Batch, Union[np.ndarray, Batch],
-                            Batch]] = []
-        for policy in self.policies:
+        results: List[Tuple[bool, np.ndarray, Batch,
+                            Union[np.ndarray, Batch], Batch]] = []
+        for agent_id, policy in self.policies.items():
             # This part of code is difficult to understand.
             # Let's follow an example with two agents
             # batch.obs.agent_id is [1, 2, 1, 2, 1, 2] (with batch_size == 6)
@@ -112,7 +115,7 @@ def forward(  # type: ignore
             # agent_index for agent 1 is [0, 2, 4]
             # agent_index for agent 2 is [1, 3, 5]
             # we separate the transition of each agent according to agent_id
-            agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
+            agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0]
             if len(agent_index) == 0:
                 # (has_data, agent_index, out, act, state)
                 results.append((False, np.array([-1]), Batch(), Batch(), Batch()))
@@ -120,39 +123,32 @@ def forward(  # type: ignore
             tmp_batch = batch[agent_index]
             if isinstance(tmp_batch.rew, np.ndarray):
                 # reward can be empty Batch (after initial reset) or nparray.
-                tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
-                out = policy(
-                    batch=tmp_batch,
-                    state=None if state is None else state["agent_" +
-                                                           str(policy.agent_id)],
-                    **kwargs
-                )
+                tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]]
+                out = policy(batch=tmp_batch, state=None if state is None
+                             else state[agent_id],
+                             **kwargs)
             act = out.act
             each_state = out.state \
                 if (hasattr(out, "state") and out.state is not None) \
                 else Batch()
             results.append((True, agent_index, out, act, each_state))
-        holder = Batch.cat(
-            [
-                {
-                    "act": act
-                } for (has_data, agent_index, out, act, each_state) in results
-                if has_data
-            ]
-        )
+        holder = Batch.cat([{"act": act} for
+                            (has_data, agent_index, out, act, each_state)
+                            in results if has_data])
         state_dict, out_dict = {}, {}
-        for policy, (has_data, agent_index, out, act,
-                     state) in zip(self.policies, results):
+        for (agent_id, _), (has_data, agent_index, out, act, state) in zip(
+                self.policies.items(), results):
             if has_data:
                 holder.act[agent_index] = act
-            state_dict["agent_" + str(policy.agent_id)] = state
-            out_dict["agent_" + str(policy.agent_id)] = out
+            state_dict[agent_id] = state
+            out_dict[agent_id] = out
         holder["out"] = out_dict
         holder["state"] = state_dict
         return holder
 
-    def learn(self, batch: Batch,
-              **kwargs: Any) -> Dict[str, Union[float, List[float]]]:
+    def learn(
+        self, batch: Batch, **kwargs: Any
+    ) -> Dict[str, Union[float, List[float]]]:
         """Dispatch the data to all policies for learning.
 
         :return: a dict with the following contents:
@@ -168,10 +164,10 @@ def learn(self, batch: Batch,
         }
         """
         results = {}
-        for policy in self.policies:
-            data = batch[f"agent_{policy.agent_id}"]
+        for agent_id, policy in self.policies.items():
+            data = batch[agent_id]
             if not data.is_empty():
                 out = policy.learn(batch=data, **kwargs)
                 for k, v in out.items():
-                    results["agent_" + str(policy.agent_id) + "/" + k] = v
+                    results[agent_id + "/" + k] = v
         return results