From d40e6c6234dae1e0846dc34b5c11cc554a819bd4 Mon Sep 17 00:00:00 2001 From: Rodrigo de Lazcano Date: Sat, 11 Sep 2021 15:10:21 -0400 Subject: [PATCH 1/5] pettingzoo wrapper --- test/pettingzoo/pistonball_train.py | 29 ++++++++ tianshou/env/pettingzoo_env.py | 83 +++++++++++++++++++++++ tianshou/policy/multiagent/mapolicy.py | 94 ++++++++++++-------------- 3 files changed, 157 insertions(+), 49 deletions(-) create mode 100644 test/pettingzoo/pistonball_train.py create mode 100644 tianshou/env/pettingzoo_env.py diff --git a/test/pettingzoo/pistonball_train.py b/test/pettingzoo/pistonball_train.py new file mode 100644 index 000000000..43d6fdb28 --- /dev/null +++ b/test/pettingzoo/pistonball_train.py @@ -0,0 +1,29 @@ +from pettingzoo.butterfly import pistonball_v4 +from tianshou.env.pettingzoo_env import PettingZooEnv +from tianshou.env import DummyVectorEnv +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.policy import RandomPolicy, MultiAgentPolicyManager +import numpy as np +import torch + +def get_env(): + return PettingZooEnv(pistonball_v4.env(continuous=False)) + +train_envs = DummyVectorEnv([get_env for _ in range(10)]) +# test_envs = DummyVectorEnv([get_env for _ in range(100)]) + +# seed +np.random.seed(1626) +torch.manual_seed(1626) +train_envs.seed(1626) +# test_envs.seed(1626) + +policy = MultiAgentPolicyManager([RandomPolicy() for _ in range(len(get_env().agents))], get_env()) + +# collector +train_collector = Collector(policy, train_envs, + VectorReplayBuffer(6, len(train_envs)), + exploration_noise=True) +# test_collector = Collector(policy, test_envs, exploration_noise=True) + # policy.set_eps(1) +train_collector.collect(n_step=640, render=0.0001) \ No newline at end of file diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py new file mode 100644 index 000000000..2ef06ba9b --- /dev/null +++ b/tianshou/env/pettingzoo_env.py @@ -0,0 +1,83 @@ +from pettingzoo.utils.env import AECEnv + +class PettingZooEnv(AECEnv): + + def __init__(self, env): + self.env = env + # agent idx list + self.agents = self.env.possible_agents + self.agent_idx = {} + for i, agent_id in enumerate(self.agents): + self.agent_idx[agent_id] = i + # Get dictionaries of obs_spaces and act_spaces + self.observation_spaces = self.env.observation_spaces + self.action_spaces = self.env.action_spaces + + self.rewards = [0] * len(self.agents) + + # Get first observation space, assuming all agents have equal space + self.observation_space = self.observation_spaces[self.agents[0]] + + # Get first action space, assuming all agents have equal space + self.action_space = self.action_spaces[self.agents[0]] + + assert all(obs_space == self.observation_space + for obs_space + in self.env.observation_spaces.values()), \ + "Observation spaces for all agents must be identical. Perhaps " \ + "SuperSuit's pad_observations wrapper can help (useage: " \ + "`supersuit.aec_wrappers.pad_observations(env)`" + + assert all(act_space == self.action_space + for act_space in self.env.action_spaces.values()), \ + "Action spaces for all agents must be identical. 
Perhaps " \ + "SuperSuit's pad_action_space wrapper can help (useage: " \ + "`supersuit.aec_wrappers.pad_action_space(env)`" + + self.reset() + + def reset(self): + self.env.reset() + observation = self.env.observe(self.env.agent_selection) + if isinstance(observation, dict) and 'action_mask' in observation: + return { + 'agent_id': self.env.agent_selection, + 'obs': observation['observation'], + 'mask': observation['action_mask'] + } + else: + return { + 'agent_id': self.env.agent_selection, + 'obs': observation, + 'mask': [True] * self.action_spaces[self.env.agent_selection].n + } + + def step(self, action): + self.env.step(action) + observation, rew, done, info = self.env.last() + if isinstance(observation, dict) and 'action_mask' in observation: + obs = { + 'agent_id': self.env.agent_selection, + 'obs': observation['observation'], + 'mask': observation['action_mask'] + } + else: + obs = { + 'agent_id': self.env.agent_selection, + 'obs': observation, + 'mask': [True] * self.action_spaces[self.env.agent_selection].n + } + + for agent_id, reward in self.env.rewards.items(): + self.rewards[self.agent_idx[agent_id]] = reward + return obs, self.rewards, done, info + + def close(self): + self.env.close() + + def seed(self, seed=None): + self.env.seed(seed) + + def render(self, mode="human"): + return self.env.render(mode) + diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index 75705f4a3..dc9297244 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -1,9 +1,10 @@ -from typing import Any, Dict, List, Optional, Tuple, Union - +import gym import numpy as np +from pettingzoo.utils.env import AECEnv +from typing import Any, Dict, List, Tuple, Union, Optional -from tianshou.data import Batch, ReplayBuffer from tianshou.policy import BasePolicy +from tianshou.data import Batch, ReplayBuffer class MultiAgentPolicyManager(BasePolicy): @@ -16,21 +17,24 @@ class MultiAgentPolicyManager(BasePolicy): :ref:`marl_example` can help you better understand this procedure. """ - def __init__(self, policies: List[BasePolicy], **kwargs: Any) -> None: + def __init__(self, policies: List[BasePolicy], env: AECEnv, **kwargs: Any) -> None: super().__init__(**kwargs) - self.policies = policies + assert (len(policies) == len(env.agents)), "One policy must be assigned for each agent." + self.agent_idx = env.agent_idx for i, policy in enumerate(policies): # agent_id 0 is reserved for the environment proxy # (this MultiAgentPolicyManager) - policy.set_agent_id(i + 1) + policy.set_agent_id(env.agents[i]) + + self.policies = dict(zip(env.agents, policies)) def replace_policy(self, policy: BasePolicy, agent_id: int) -> None: """Replace the "agent_id"th policy in this manager.""" - self.policies[agent_id - 1] = policy policy.set_agent_id(agent_id) + self.policies[agent_id] = policy def process_fn( - self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray + self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray ) -> Batch: """Dispatch batch data from obs.agent_id to every policy's process_fn. @@ -45,32 +49,31 @@ def process_fn( # Since we do not override buffer.__setattr__, here we use _meta to # change buffer.rew, otherwise buffer.rew = Batch() has no effect. 
save_rew, buffer._meta.rew = buffer.rew, Batch() - for policy in self.policies: - agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0] + for agent_id, policy in self.policies.items(): + agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0] if len(agent_index) == 0: - results[f"agent_{policy.agent_id}"] = Batch() + results[agent_id] = Batch() continue - tmp_batch, tmp_indices = batch[agent_index], indices[agent_index] + tmp_batch, tmp_indice = batch[agent_index], indice[agent_index] if has_rew: tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1] buffer._meta.rew = save_rew[:, policy.agent_id - 1] - results[f"agent_{policy.agent_id}"] = policy.process_fn( - tmp_batch, buffer, tmp_indices - ) + results[policy.agent_id] = policy.process_fn( + tmp_batch, buffer, tmp_indice) if has_rew: # restore from save_rew buffer._meta.rew = save_rew return Batch(results) - def exploration_noise(self, act: Union[np.ndarray, Batch], - batch: Batch) -> Union[np.ndarray, Batch]: + def exploration_noise( + self, act: Union[np.ndarray, Batch], batch: Batch + ) -> Union[np.ndarray, Batch]: """Add exploration noise from sub-policy onto act.""" - for policy in self.policies: - agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0] + for agent_id, policy in self.policies.items(): + agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0] if len(agent_index) == 0: continue act[agent_index] = policy.exploration_noise( - act[agent_index], batch[agent_index] - ) + act[agent_index], batch[agent_index]) return act def forward( # type: ignore @@ -102,9 +105,9 @@ def forward( # type: ignore "agent_n": xxx} } """ - results: List[Tuple[bool, np.ndarray, Batch, Union[np.ndarray, Batch], - Batch]] = [] - for policy in self.policies: + results: List[Tuple[bool, np.ndarray, Batch, + Union[np.ndarray, Batch], Batch]] = [] + for agent_id, policy in self.policies.items(): # This part of code is difficult to understand. # Let's follow an example with two agents # batch.obs.agent_id is [1, 2, 1, 2, 1, 2] (with batch_size == 6) @@ -112,7 +115,7 @@ def forward( # type: ignore # agent_index for agent 1 is [0, 2, 4] # agent_index for agent 2 is [1, 3, 5] # we separate the transition of each agent according to agent_id - agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0] + agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0] if len(agent_index) == 0: # (has_data, agent_index, out, act, state) results.append((False, np.array([-1]), Batch(), Batch(), Batch())) @@ -120,39 +123,32 @@ def forward( # type: ignore tmp_batch = batch[agent_index] if isinstance(tmp_batch.rew, np.ndarray): # reward can be empty Batch (after initial reset) or nparray. 
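                # rew holds one entry per agent for every transition; the line that
                # follows selects this agent's column (indexed via env.agent_idx
                # rather than the old numeric agent_id)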
- tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1] - out = policy( - batch=tmp_batch, - state=None if state is None else state["agent_" + - str(policy.agent_id)], - **kwargs - ) + tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]] + out = policy(batch=tmp_batch, state=None if state is None + else state[agent_id], + **kwargs) act = out.act each_state = out.state \ if (hasattr(out, "state") and out.state is not None) \ else Batch() results.append((True, agent_index, out, act, each_state)) - holder = Batch.cat( - [ - { - "act": act - } for (has_data, agent_index, out, act, each_state) in results - if has_data - ] - ) + holder = Batch.cat([{"act": act} for + (has_data, agent_index, out, act, each_state) + in results if has_data]) state_dict, out_dict = {}, {} - for policy, (has_data, agent_index, out, act, - state) in zip(self.policies, results): + for (agent_id, policy), (has_data, agent_index, out, act, state) in zip( + self.policies.items(), results): if has_data: holder.act[agent_index] = act - state_dict["agent_" + str(policy.agent_id)] = state - out_dict["agent_" + str(policy.agent_id)] = out + state_dict[agent_id] = state + out_dict[agent_id] = out holder["out"] = out_dict holder["state"] = state_dict return holder - def learn(self, batch: Batch, - **kwargs: Any) -> Dict[str, Union[float, List[float]]]: + def learn( + self, batch: Batch, **kwargs: Any + ) -> Dict[str, Union[float, List[float]]]: """Dispatch the data to all policies for learning. :return: a dict with the following contents: @@ -168,10 +164,10 @@ def learn(self, batch: Batch, } """ results = {} - for policy in self.policies: - data = batch[f"agent_{policy.agent_id}"] + for agent_id, policy in self.policies.items(): + data = batch[agent_id] if not data.is_empty(): out = policy.learn(batch=data, **kwargs) for k, v in out.items(): - results["agent_" + str(policy.agent_id) + "/" + k] = v + results[agent_id + "/" + k] = v return results From 952a6ae13bc0b8f27496bcf4ef8ecd444eeb2874 Mon Sep 17 00:00:00 2001 From: J K Terry Date: Thu, 23 Sep 2021 11:56:17 -0400 Subject: [PATCH 2/5] try fixing flake8 --- test/pettingzoo/pistonball_train.py | 6 ++++-- tianshou/env/pettingzoo_env.py | 2 +- tianshou/policy/multiagent/mapolicy.py | 1 - 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/test/pettingzoo/pistonball_train.py b/test/pettingzoo/pistonball_train.py index 43d6fdb28..dcc2e34cb 100644 --- a/test/pettingzoo/pistonball_train.py +++ b/test/pettingzoo/pistonball_train.py @@ -6,9 +6,11 @@ import numpy as np import torch + def get_env(): return PettingZooEnv(pistonball_v4.env(continuous=False)) + train_envs = DummyVectorEnv([get_env for _ in range(10)]) # test_envs = DummyVectorEnv([get_env for _ in range(100)]) @@ -25,5 +27,5 @@ def get_env(): VectorReplayBuffer(6, len(train_envs)), exploration_noise=True) # test_collector = Collector(policy, test_envs, exploration_noise=True) - # policy.set_eps(1) -train_collector.collect(n_step=640, render=0.0001) \ No newline at end of file +# policy.set_eps(1) +train_collector.collect(n_step=640, render=0.0001) diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py index 2ef06ba9b..8eac23777 100644 --- a/tianshou/env/pettingzoo_env.py +++ b/tianshou/env/pettingzoo_env.py @@ -1,5 +1,6 @@ from pettingzoo.utils.env import AECEnv + class PettingZooEnv(AECEnv): def __init__(self, env): @@ -80,4 +81,3 @@ def seed(self, seed=None): def render(self, mode="human"): return self.env.render(mode) - diff --git a/tianshou/policy/multiagent/mapolicy.py 
b/tianshou/policy/multiagent/mapolicy.py index dc9297244..4407944cb 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -1,4 +1,3 @@ -import gym import numpy as np from pettingzoo.utils.env import AECEnv from typing import Any, Dict, List, Tuple, Union, Optional From 67af2848f0f96d6f5a02a662b8412723b099b9ec Mon Sep 17 00:00:00 2001 From: J K Terry Date: Thu, 23 Sep 2021 12:07:42 -0400 Subject: [PATCH 3/5] fix pettingzoo dependency --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index bf48020e5..6fb93728c 100644 --- a/setup.py +++ b/setup.py @@ -54,6 +54,7 @@ def get_version() -> str: "torch>=1.4.0", "numba>=0.51.0", "h5py>=2.10.0", # to match tensorflow's minimal requirements + "pettingzoo>=1.11.0", ], extras_require={ "dev": [ From 135a259a8a57c8c73a6d421c73cd350c52efff00 Mon Sep 17 00:00:00 2001 From: J K Terry Date: Thu, 23 Sep 2021 19:53:54 -0400 Subject: [PATCH 4/5] try more fixes --- setup.py | 1 + test/pettingzoo/pistonball_train.py | 8 +++++--- tianshou/env/pettingzoo_env.py | 2 +- tianshou/policy/multiagent/mapolicy.py | 5 +++-- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index 6fb93728c..0ebf735fb 100644 --- a/setup.py +++ b/setup.py @@ -74,6 +74,7 @@ def get_version() -> str: "pydocstyle", "doc8", "scipy", + "pillow", ], "atari": ["atari_py", "opencv-python"], "mujoco": ["mujoco_py"], diff --git a/test/pettingzoo/pistonball_train.py b/test/pettingzoo/pistonball_train.py index dcc2e34cb..424dc7fcc 100644 --- a/test/pettingzoo/pistonball_train.py +++ b/test/pettingzoo/pistonball_train.py @@ -20,12 +20,14 @@ def get_env(): train_envs.seed(1626) # test_envs.seed(1626) -policy = MultiAgentPolicyManager([RandomPolicy() for _ in range(len(get_env().agents))], get_env()) +policy = MultiAgentPolicyManager([RandomPolicy() + for _ in range(len(get_env().agents))], + get_env()) # collector train_collector = Collector(policy, train_envs, - VectorReplayBuffer(6, len(train_envs)), - exploration_noise=True) + VectorReplayBuffer(6, len(train_envs)), + exploration_noise=True) # test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.collect(n_step=640, render=0.0001) diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py index 8eac23777..6924b0a4d 100644 --- a/tianshou/env/pettingzoo_env.py +++ b/tianshou/env/pettingzoo_env.py @@ -68,7 +68,7 @@ def step(self, action): 'obs': observation, 'mask': [True] * self.action_spaces[self.env.agent_selection].n } - + for agent_id, reward in self.env.rewards.items(): self.rewards[self.agent_idx[agent_id]] = reward return obs, self.rewards, done, info diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index 4407944cb..3a5691524 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -18,7 +18,8 @@ class MultiAgentPolicyManager(BasePolicy): def __init__(self, policies: List[BasePolicy], env: AECEnv, **kwargs: Any) -> None: super().__init__(**kwargs) - assert (len(policies) == len(env.agents)), "One policy must be assigned for each agent." + assert (len(policies) == len(env.agents)), "One policy must be \ + assigned for each agent." 
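        # agent_idx is built by PettingZooEnv: it maps each agent name to its
        # index in the per-agent reward vector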
self.agent_idx = env.agent_idx for i, policy in enumerate(policies): # agent_id 0 is reserved for the environment proxy @@ -105,7 +106,7 @@ def forward( # type: ignore } """ results: List[Tuple[bool, np.ndarray, Batch, - Union[np.ndarray, Batch], Batch]] = [] + Union[np.ndarray, Batch], Batch]] = [] for agent_id, policy in self.policies.items(): # This part of code is difficult to understand. # Let's follow an example with two agents From f556ceb73a4d194733e3527af3eedc4f8bfbc90d Mon Sep 17 00:00:00 2001 From: J K Terry Date: Fri, 24 Sep 2021 11:38:19 -0400 Subject: [PATCH 5/5] fix flake8 --- test/pettingzoo/pistonball_train.py | 2 +- tianshou/policy/multiagent/mapolicy.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/pettingzoo/pistonball_train.py b/test/pettingzoo/pistonball_train.py index 424dc7fcc..54149c691 100644 --- a/test/pettingzoo/pistonball_train.py +++ b/test/pettingzoo/pistonball_train.py @@ -20,7 +20,7 @@ def get_env(): train_envs.seed(1626) # test_envs.seed(1626) -policy = MultiAgentPolicyManager([RandomPolicy() +policy = MultiAgentPolicyManager([RandomPolicy() for _ in range(len(get_env().agents))], get_env()) diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index 3a5691524..dfd96acd2 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -136,7 +136,7 @@ def forward( # type: ignore (has_data, agent_index, out, act, each_state) in results if has_data]) state_dict, out_dict = {}, {} - for (agent_id, policy), (has_data, agent_index, out, act, state) in zip( + for (agent_id, _), (has_data, agent_index, out, act, state) in zip( self.policies.items(), results): if has_data: holder.act[agent_index] = act
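For reference, a minimal sketch of the interface the patches above introduce: PettingZooEnv returns observations as {'agent_id', 'obs', 'mask'} dicts and exposes agents / agent_idx, while MultiAgentPolicyManager now takes the wrapped env and keys its sub-policies by agent name. The snippet below is illustrative only and is not part of the patch series; it assumes pistonball_v4 and RandomPolicy, as in the test script above.

from pettingzoo.butterfly import pistonball_v4

from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import MultiAgentPolicyManager, RandomPolicy

# Wrap the AEC environment. reset() returns a dict describing the agent
# whose turn it is, rather than a bare observation array.
env = PettingZooEnv(pistonball_v4.env(continuous=False))
obs = env.reset()
print(obs['agent_id'])   # name of the agent to act next, e.g. 'piston_0'
print(obs['obs'].shape)  # that agent's observation
print(len(obs['mask']))  # all-True mask when the env defines no action_mask

# One sub-policy per agent; the manager now takes the wrapped env so it can
# key policies and rewards by agent name via env.agents / env.agent_idx.
policy = MultiAgentPolicyManager([RandomPolicy() for _ in env.agents], env)

# step() returns the next agent's obs dict plus the full per-agent reward
# list, indexed by env.agent_idx[agent_id].
for _ in range(len(env.agents)):
    obs, rewards, done, info = env.step(env.action_space.sample())
    if done:
        break
env.close()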