
pettingzoo wrapper #447


Closed. Wants to merge 7 commits.
2 changes: 2 additions & 0 deletions setup.py
@@ -54,6 +54,7 @@ def get_version() -> str:
"torch>=1.4.0",
"numba>=0.51.0",
"h5py>=2.10.0", # to match tensorflow's minimal requirements
"pettingzoo>=1.11.0",
],
extras_require={
"dev": [
@@ -73,6 +74,7 @@ def get_version() -> str:
"pydocstyle",
"doc8",
"scipy",
"pillow",
],
"atari": ["atari_py", "opencv-python"],
"mujoco": ["mujoco_py"],
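A note on the two pins: pettingzoo lands in install_requires, so it becomes a hard runtime dependency, while pillow only joins the "dev" extra for the tests. A quick sanity check of the runtime pin could look like the sketch below (this is illustration only, not part of the diff; the `packaging` dependency is an assumption):

# Hypothetical check: verify the installed pettingzoo satisfies
# the new ">=1.11.0" pin from setup.py.
import pettingzoo
from packaging.version import Version

assert Version(pettingzoo.__version__) >= Version("1.11.0")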
33 changes: 33 additions & 0 deletions test/pettingzoo/pistonball_train.py
@@ -0,0 +1,33 @@
from pettingzoo.butterfly import pistonball_v4
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.env import DummyVectorEnv
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import RandomPolicy, MultiAgentPolicyManager
import numpy as np
import torch


def get_env():
    return PettingZooEnv(pistonball_v4.env(continuous=False))


train_envs = DummyVectorEnv([get_env for _ in range(10)])
# test_envs = DummyVectorEnv([get_env for _ in range(100)])

# seed
np.random.seed(1626)
torch.manual_seed(1626)
train_envs.seed(1626)
# test_envs.seed(1626)

policy = MultiAgentPolicyManager([RandomPolicy()
                                  for _ in range(len(get_env().agents))],
                                 get_env())

# collector
train_collector = Collector(policy, train_envs,
                            VectorReplayBuffer(6, len(train_envs)),
                            exploration_noise=True)
# test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=640, render=0.0001)
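For readers trying the script, a minimal manual rollout shows the same wrapper driven without a Collector. This is a sketch under the assumption that PettingZooEnv behaves as implemented in this PR (dict observations with agent_id/obs/mask and a shared per-agent rewards list); it is not part of the diff:

# Hypothetical single-env rollout using the wrapper from this PR.
from pettingzoo.butterfly import pistonball_v4
from tianshou.env.pettingzoo_env import PettingZooEnv

env = PettingZooEnv(pistonball_v4.env(continuous=False))
obs = env.reset()  # {'agent_id': ..., 'obs': ..., 'mask': ...}
done = False
while not done:
    # pistonball has no action mask, so any sampled action is legal here
    obs, rewards, done, info = env.step(env.action_space.sample())
print(rewards)  # one reward per agent, ordered by env.agent_idx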
83 changes: 83 additions & 0 deletions tianshou/env/pettingzoo_env.py
@@ -0,0 +1,83 @@
from pettingzoo.utils.env import AECEnv


class PettingZooEnv(AECEnv):

    def __init__(self, env):
        self.env = env
        # agent idx list
        self.agents = self.env.possible_agents
        self.agent_idx = {}
        for i, agent_id in enumerate(self.agents):
            self.agent_idx[agent_id] = i
        # Get dictionaries of obs_spaces and act_spaces
        self.observation_spaces = self.env.observation_spaces
        self.action_spaces = self.env.action_spaces

        self.rewards = [0] * len(self.agents)

        # Get first observation space, assuming all agents have equal space
        self.observation_space = self.observation_spaces[self.agents[0]]

        # Get first action space, assuming all agents have equal space
        self.action_space = self.action_spaces[self.agents[0]]

        assert all(obs_space == self.observation_space
                   for obs_space
                   in self.env.observation_spaces.values()), \
            "Observation spaces for all agents must be identical. Perhaps " \
            "SuperSuit's pad_observations wrapper can help (usage: " \
            "`supersuit.aec_wrappers.pad_observations(env)`)."

        assert all(act_space == self.action_space
                   for act_space in self.env.action_spaces.values()), \
            "Action spaces for all agents must be identical. Perhaps " \
            "SuperSuit's pad_action_space wrapper can help (usage: " \
            "`supersuit.aec_wrappers.pad_action_space(env)`)."

        self.reset()

    def reset(self):
        self.env.reset()
        observation = self.env.observe(self.env.agent_selection)
        if isinstance(observation, dict) and 'action_mask' in observation:
            return {
                'agent_id': self.env.agent_selection,
                'obs': observation['observation'],
                'mask': observation['action_mask']
            }
        else:
            return {
                'agent_id': self.env.agent_selection,
                'obs': observation,
                # assumes a Discrete action space
                'mask': [True] * self.action_spaces[self.env.agent_selection].n
            }

    def step(self, action):
        self.env.step(action)
        observation, rew, done, info = self.env.last()
        if isinstance(observation, dict) and 'action_mask' in observation:
            obs = {
                'agent_id': self.env.agent_selection,
                'obs': observation['observation'],
                'mask': observation['action_mask']
            }
        else:
            obs = {
                'agent_id': self.env.agent_selection,
                'obs': observation,
                # assumes a Discrete action space
                'mask': [True] * self.action_spaces[self.env.agent_selection].n
            }

        for agent_id, reward in self.env.rewards.items():
            self.rewards[self.agent_idx[agent_id]] = reward
        return obs, self.rewards, done, info

    def close(self):
        self.env.close()

    def seed(self, seed=None):
        self.env.seed(seed)

    def render(self, mode="human"):
        return self.env.render(mode)
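To make the two branches in reset() and step() concrete: classic games such as tic-tac-toe return a dict observation that carries an action_mask, while pistonball returns a raw array and gets an all-True mask of size action_space.n. A small sketch (tictactoe_v3 is just one example of a masked environment, an assumption here; any PettingZoo classic game with an action_mask would do):

# Hypothetical illustration of the masked branch, not part of the diff.
from pettingzoo.classic import tictactoe_v3
from tianshou.env.pettingzoo_env import PettingZooEnv

env = PettingZooEnv(tictactoe_v3.env())
obs = env.reset()
# 'mask' is forwarded from the underlying env's action_mask here;
# for pistonball it would instead be [True] * action_space.n
print(obs['agent_id'], obs['mask'])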
94 changes: 45 additions & 49 deletions tianshou/policy/multiagent/mapolicy.py
@@ -1,9 +1,9 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
-
 import numpy as np
+from pettingzoo.utils.env import AECEnv
+from typing import Any, Dict, List, Tuple, Union, Optional
 
-from tianshou.data import Batch, ReplayBuffer
 from tianshou.policy import BasePolicy
+from tianshou.data import Batch, ReplayBuffer
 
 
 class MultiAgentPolicyManager(BasePolicy):
@@ -16,21 +16,25 @@ class MultiAgentPolicyManager(BasePolicy):
     :ref:`marl_example` can help you better understand this procedure.
     """
 
-    def __init__(self, policies: List[BasePolicy], **kwargs: Any) -> None:
+    def __init__(self, policies: List[BasePolicy], env: AECEnv, **kwargs: Any) -> None:
         super().__init__(**kwargs)
-        self.policies = policies
+        assert (len(policies) == len(env.agents)), "One policy must be \
+            assigned for each agent."
+        self.agent_idx = env.agent_idx
         for i, policy in enumerate(policies):
             # agent_id 0 is reserved for the environment proxy
             # (this MultiAgentPolicyManager)
-            policy.set_agent_id(i + 1)
+            policy.set_agent_id(env.agents[i])
+
+        self.policies = dict(zip(env.agents, policies))
 
     def replace_policy(self, policy: BasePolicy, agent_id: int) -> None:
         """Replace the "agent_id"th policy in this manager."""
-        self.policies[agent_id - 1] = policy
         policy.set_agent_id(agent_id)
+        self.policies[agent_id] = policy
 
     def process_fn(
-        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
+        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
         """Dispatch batch data from obs.agent_id to every policy's process_fn.
 
@@ -45,32 +49,31 @@ def process_fn(
         # Since we do not override buffer.__setattr__, here we use _meta to
         # change buffer.rew, otherwise buffer.rew = Batch() has no effect.
         save_rew, buffer._meta.rew = buffer.rew, Batch()
-        for policy in self.policies:
-            agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
+        for agent_id, policy in self.policies.items():
+            agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0]
             if len(agent_index) == 0:
-                results[f"agent_{policy.agent_id}"] = Batch()
+                results[agent_id] = Batch()
                 continue
-            tmp_batch, tmp_indices = batch[agent_index], indices[agent_index]
+            tmp_batch, tmp_indice = batch[agent_index], indice[agent_index]
             if has_rew:
                 tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
                 buffer._meta.rew = save_rew[:, policy.agent_id - 1]
-            results[f"agent_{policy.agent_id}"] = policy.process_fn(
-                tmp_batch, buffer, tmp_indices
-            )
+            results[policy.agent_id] = policy.process_fn(
+                tmp_batch, buffer, tmp_indice)
         if has_rew:  # restore from save_rew
             buffer._meta.rew = save_rew
         return Batch(results)
 
-    def exploration_noise(self, act: Union[np.ndarray, Batch],
-                          batch: Batch) -> Union[np.ndarray, Batch]:
+    def exploration_noise(
+        self, act: Union[np.ndarray, Batch], batch: Batch
+    ) -> Union[np.ndarray, Batch]:
         """Add exploration noise from sub-policy onto act."""
-        for policy in self.policies:
-            agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
+        for agent_id, policy in self.policies.items():
+            agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0]
             if len(agent_index) == 0:
                 continue
             act[agent_index] = policy.exploration_noise(
-                act[agent_index], batch[agent_index]
-            )
+                act[agent_index], batch[agent_index])
         return act
 
     def forward(  # type: ignore
@@ -102,57 +105,50 @@ def forward(  # type: ignore
"agent_n": xxx}
}
"""
results: List[Tuple[bool, np.ndarray, Batch, Union[np.ndarray, Batch],
Batch]] = []
for policy in self.policies:
results: List[Tuple[bool, np.ndarray, Batch,
Union[np.ndarray, Batch], Batch]] = []
for agent_id, policy in self.policies.items():
# This part of code is difficult to understand.
# Let's follow an example with two agents
# batch.obs.agent_id is [1, 2, 1, 2, 1, 2] (with batch_size == 6)
# each agent plays for three transitions
# agent_index for agent 1 is [0, 2, 4]
# agent_index for agent 2 is [1, 3, 5]
# we separate the transition of each agent according to agent_id
agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0]
if len(agent_index) == 0:
# (has_data, agent_index, out, act, state)
results.append((False, np.array([-1]), Batch(), Batch(), Batch()))
continue
tmp_batch = batch[agent_index]
if isinstance(tmp_batch.rew, np.ndarray):
# reward can be empty Batch (after initial reset) or nparray.
tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
out = policy(
batch=tmp_batch,
state=None if state is None else state["agent_" +
str(policy.agent_id)],
**kwargs
)
tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]]
out = policy(batch=tmp_batch, state=None if state is None
else state[agent_id],
**kwargs)
act = out.act
each_state = out.state \
if (hasattr(out, "state") and out.state is not None) \
else Batch()
results.append((True, agent_index, out, act, each_state))
holder = Batch.cat(
[
{
"act": act
} for (has_data, agent_index, out, act, each_state) in results
if has_data
]
)
holder = Batch.cat([{"act": act} for
(has_data, agent_index, out, act, each_state)
in results if has_data])
state_dict, out_dict = {}, {}
for policy, (has_data, agent_index, out, act,
state) in zip(self.policies, results):
for (agent_id, _), (has_data, agent_index, out, act, state) in zip(
self.policies.items(), results):
if has_data:
holder.act[agent_index] = act
state_dict["agent_" + str(policy.agent_id)] = state
out_dict["agent_" + str(policy.agent_id)] = out
state_dict[agent_id] = state
out_dict[agent_id] = out
holder["out"] = out_dict
holder["state"] = state_dict
return holder

def learn(self, batch: Batch,
**kwargs: Any) -> Dict[str, Union[float, List[float]]]:
def learn(
self, batch: Batch, **kwargs: Any
) -> Dict[str, Union[float, List[float]]]:
"""Dispatch the data to all policies for learning.

:return: a dict with the following contents:
@@ -168,10 +164,10 @@ def learn(self, batch: Batch,
         }
         """
         results = {}
-        for policy in self.policies:
-            data = batch[f"agent_{policy.agent_id}"]
+        for agent_id, policy in self.policies.items():
+            data = batch[agent_id]
             if not data.is_empty():
                 out = policy.learn(batch=data, **kwargs)
                 for k, v in out.items():
-                    results["agent_" + str(policy.agent_id) + "/" + k] = v
+                    results[agent_id + "/" + k] = v
         return results
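The practical effect of the mapolicy.py changes is that sub-policies, forward output, and learn results are keyed by the PettingZoo agent-id string instead of f"agent_{i}". A short sketch of the new keying, assuming the wrapper and manager behave as written in this diff:

# Hypothetical usage of the dict-keyed manager from this PR.
from pettingzoo.butterfly import pistonball_v4
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import RandomPolicy, MultiAgentPolicyManager

env = PettingZooEnv(pistonball_v4.env(continuous=False))
policy = MultiAgentPolicyManager(
    [RandomPolicy() for _ in env.agents], env)
# sub-policies are now looked up by agent name, not integer index:
assert set(policy.policies.keys()) == set(env.agents)
# replace_policy now takes the agent-id string as well:
policy.replace_policy(RandomPolicy(), env.agents[0])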