diff --git a/setup.py b/setup.py index 00af4fefa..2e8a5bca3 100644 --- a/setup.py +++ b/setup.py @@ -55,6 +55,7 @@ def get_version() -> str: "torch>=1.4.0", "numba>=0.51.0", "h5py>=2.10.0", # to match tensorflow's minimal requirements + "pettingzoo>=1.12,<=1.13", ], extras_require={ "dev": [ @@ -74,6 +75,9 @@ def get_version() -> str: "pydocstyle", "doc8", "scipy", + "pillow", + "pygame>=2.1.0", # pettingzoo test cases pistonball + "pymunk>=6.2.1", # pettingzoo test cases pistonball ], "atari": ["atari_py", "opencv-python"], "mujoco": ["mujoco_py"], diff --git a/test/multiagent/Gomoku.py b/test/multiagent/Gomoku.py deleted file mode 100644 index 91d3cbc63..000000000 --- a/test/multiagent/Gomoku.py +++ /dev/null @@ -1,87 +0,0 @@ -import os -import pprint -from copy import deepcopy - -import numpy as np -from tic_tac_toe import get_agents, get_parser, train_agent, watch -from tic_tac_toe_env import TicTacToeEnv -from torch.utils.tensorboard import SummaryWriter - -from tianshou.data import Collector -from tianshou.env import DummyVectorEnv -from tianshou.policy import RandomPolicy -from tianshou.utils import TensorboardLogger - - -def get_args(): - parser = get_parser() - parser.add_argument('--self_play_round', type=int, default=20) - args = parser.parse_known_args()[0] - return args - - -def gomoku(args=get_args()): - Collector._default_rew_metric = lambda x: x[args.agent_id - 1] - if args.watch: - watch(args) - return - - policy, optim = get_agents(args) - agent_learn = policy.policies[args.agent_id - 1] - agent_opponent = policy.policies[2 - args.agent_id] - - # log - log_path = os.path.join(args.logdir, 'Gomoku', 'dqn') - writer = SummaryWriter(log_path) - args.logger = TensorboardLogger(writer) - - opponent_pool = [agent_opponent] - - def env_func(): - return TicTacToeEnv(args.board_size, args.win_size) - - test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)]) - for round in range(args.self_play_round): - rews = [] - agent_learn.set_eps(0.0) - # compute the reward over previous learner - for opponent in opponent_pool: - policy.replace_policy(opponent, 3 - args.agent_id) - test_collector = Collector(policy, test_envs) - results = test_collector.collect(n_episode=100) - rews.append(results['rews'].mean()) - rews = np.array(rews) - # weight opponent by their difficulty level - rews = np.exp(-rews * 10.0) - rews /= np.sum(rews) - total_epoch = args.epoch - args.epoch = 1 - for epoch in range(total_epoch): - # sample one opponent - opp_id = np.random.choice(len(opponent_pool), size=1, p=rews) - print(f'selection probability {rews.tolist()}') - print(f'selected opponent {opp_id}') - opponent = opponent_pool[opp_id.item(0)] - agent = RandomPolicy() - # previous learner can only be used for forward - agent.forward = opponent.forward - args.model_save_path = os.path.join( - args.logdir, 'Gomoku', 'dqn', f'policy_round_{round}_epoch_{epoch}.pth' - ) - result, agent_learn = train_agent( - args, agent_learn=agent_learn, agent_opponent=agent, optim=optim - ) - print(f'round_{round}_epoch_{epoch}') - pprint.pprint(result) - learnt_agent = deepcopy(agent_learn) - learnt_agent.set_eps(0.0) - opponent_pool.append(learnt_agent) - args.epoch = total_epoch - if __name__ == '__main__': - # Let's watch its performance! 
- opponent = opponent_pool[-2] - watch(args, agent_learn, opponent) - - -if __name__ == '__main__': - gomoku(get_args()) diff --git a/test/multiagent/tic_tac_toe_env.py b/test/multiagent/tic_tac_toe_env.py deleted file mode 100644 index 2c79d303d..000000000 --- a/test/multiagent/tic_tac_toe_env.py +++ /dev/null @@ -1,148 +0,0 @@ -from functools import partial -from typing import Optional, Tuple - -import gym -import numpy as np - -from tianshou.env import MultiAgentEnv - - -class TicTacToeEnv(MultiAgentEnv): - """This is a simple implementation of the Tic-Tac-Toe game, where two - agents play against each other. - - The implementation is intended to show how to wrap an environment to - satisfy the interface of :class:`~tianshou.env.MultiAgentEnv`. - - :param size: the size of the board (square board) - :param win_size: how many units in a row is considered to win - """ - - def __init__(self, size: int = 3, win_size: int = 3): - super().__init__() - assert size > 0, f'board size should be positive, but got {size}' - self.size = size - assert win_size > 0, f'win-size should be positive, but got {win_size}' - self.win_size = win_size - assert win_size <= size, f'win-size {win_size} should not ' \ - f'be larger than board size {size}' - self.convolve_kernel = np.ones(win_size) - self.observation_space = gym.spaces.Box( - low=-1.0, high=1.0, shape=(size, size), dtype=np.float32 - ) - self.action_space = gym.spaces.Discrete(size * size) - self.current_board = None - self.current_agent = None - self._last_move = None - self.step_num = None - - def reset(self) -> dict: - self.current_board = np.zeros((self.size, self.size), dtype=np.int32) - self.current_agent = 1 - self._last_move = (-1, -1) - self.step_num = 0 - return { - 'agent_id': self.current_agent, - 'obs': np.array(self.current_board), - 'mask': self.current_board.flatten() == 0 - } - - def step(self, action: [int, - np.ndarray]) -> Tuple[dict, np.ndarray, np.ndarray, dict]: - if self.current_agent is None: - raise ValueError("calling step() of unreset environment is prohibited!") - assert 0 <= action < self.size * self.size - assert self.current_board.item(action) == 0 - _current_agent = self.current_agent - self._move(action) - mask = self.current_board.flatten() == 0 - is_win, is_opponent_win = False, False - is_win = self._test_win() - # the game is over when one wins or there is only one empty place - done = is_win - if sum(mask) == 1: - done = True - self._move(np.where(mask)[0][0]) - is_opponent_win = self._test_win() - if is_win: - reward = 1 - elif is_opponent_win: - reward = -1 - else: - reward = 0 - obs = { - 'agent_id': self.current_agent, - 'obs': np.array(self.current_board), - 'mask': mask - } - rew_agent_1 = reward if _current_agent == 1 else (-reward) - rew_agent_2 = reward if _current_agent == 2 else (-reward) - vec_rew = np.array([rew_agent_1, rew_agent_2], dtype=np.float32) - if done: - self.current_agent = None - return obs, vec_rew, np.array(done), {} - - def _move(self, action): - row, col = action // self.size, action % self.size - if self.current_agent == 1: - self.current_board[row, col] = 1 - else: - self.current_board[row, col] = -1 - self.current_agent = 3 - self.current_agent - self._last_move = (row, col) - self.step_num += 1 - - def _test_win(self): - """test if someone wins by checking the situation around last move""" - row, col = self._last_move - rboard = self.current_board[row, :] - cboard = self.current_board[:, col] - current = self.current_board[row, col] - rightup = [ - self.current_board[row - i, 
col + i] for i in range(1, self.size - col) - if row - i >= 0 - ] - leftdown = [ - self.current_board[row + i, col - i] for i in range(1, col + 1) - if row + i < self.size - ] - rdiag = np.array(leftdown[::-1] + [current] + rightup) - rightdown = [ - self.current_board[row + i, col + i] for i in range(1, self.size - col) - if row + i < self.size - ] - leftup = [ - self.current_board[row - i, col - i] for i in range(1, col + 1) - if row - i >= 0 - ] - diag = np.array(leftup[::-1] + [current] + rightdown) - results = [ - np.convolve(k, self.convolve_kernel, mode='valid') - for k in (rboard, cboard, rdiag, diag) - ] - return any([(np.abs(x) == self.win_size).any() for x in results]) - - def seed(self, seed: Optional[int] = None) -> int: - pass - - def render(self, **kwargs) -> None: - print(f'board (step {self.step_num}):') - pad = '===' - top = pad + '=' * (2 * self.size - 1) + pad - print(top) - - def f(i, data): - j, number = data - last_move = i == self._last_move[0] and j == self._last_move[1] - if number == 1: - return 'X' if last_move else 'x' - if number == -1: - return 'O' if last_move else 'o' - return '_' - - for i, row in enumerate(self.current_board): - print(pad + ' '.join(map(partial(f, i), enumerate(row))) + pad) - print(top) - - def close(self) -> None: - pass diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py new file mode 100644 index 000000000..2658b677a --- /dev/null +++ b/test/pettingzoo/pistonball.py @@ -0,0 +1,185 @@ +import argparse +import os +from typing import List, Optional, Tuple + +import gym +import numpy as np +import pettingzoo.butterfly.pistonball_v4 as pistonball_v4 +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.env.pettingzoo_env import PettingZooEnv +from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager +from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net + + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--eps-test', type=float, default=0.05) + parser.add_argument('--eps-train', type=float, default=0.1) + parser.add_argument('--buffer-size', type=int, default=2000) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument( + '--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win' + ) + parser.add_argument( + '--n-pistons', + type=int, + default=3, + help='Number of pistons(agents) in the env' + ) + parser.add_argument('--n-step', type=int, default=100) + parser.add_argument('--target-update-freq', type=int, default=320) + parser.add_argument('--epoch', type=int, default=3) + parser.add_argument('--step-per-epoch', type=int, default=500) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) + parser.add_argument('--batch-size', type=int, default=100) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) + parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.0) + + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='no training, ' + 
'watch the play of pre-trained models' + ) + parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + return parser + + +def get_args() -> argparse.Namespace: + parser = get_parser() + return parser.parse_known_args()[0] + + +def get_env(args: argparse.Namespace = get_args()): + return PettingZooEnv(pistonball_v4.env(continuous=False, n_pistons=args.n_pistons)) + + +def get_agents( + args: argparse.Namespace = get_args(), + agents: Optional[List[BasePolicy]] = None, + optims: Optional[List[torch.optim.Optimizer]] = None, +) -> Tuple[BasePolicy, List[torch.optim.Optimizer], List]: + env = get_env() + observation_space = env.observation_space['observation'] if isinstance( + env.observation_space, gym.spaces.Dict + ) else env.observation_space + args.state_shape = observation_space.shape or observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + if agents is None: + agents = [] + optims = [] + for _ in range(args.n_pistons): + # model + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device + ).to(args.device) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + agent = DQNPolicy( + net, + optim, + args.gamma, + args.n_step, + target_update_freq=args.target_update_freq + ) + agents.append(agent) + optims.append(optim) + + policy = MultiAgentPolicyManager(agents, env) + return policy, optims, env.agents + + +def train_agent( + args: argparse.Namespace = get_args(), + agents: Optional[List[BasePolicy]] = None, + optims: Optional[List[torch.optim.Optimizer]] = None, +) -> Tuple[dict, BasePolicy]: + train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) + test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + + policy, optim, agents = get_agents(args, agents=agents, optims=optims) + + # collector + train_collector = Collector( + policy, + train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True + ) + test_collector = Collector(policy, test_envs, exploration_noise=True) + train_collector.collect(n_step=args.batch_size * args.training_num) + # log + log_path = os.path.join(args.logdir, 'pistonball', 'dqn') + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = TensorboardLogger(writer) + + def save_fn(policy): + pass + + def stop_fn(mean_rewards): + return False + + def train_fn(epoch, env_step): + [agent.set_eps(args.eps_train) for agent in policy.policies.values()] + + def test_fn(epoch, env_step): + [agent.set_eps(args.eps_test) for agent in policy.policies.values()] + + def reward_metric(rews): + return rews[:, 0] + + # trainer + result = offpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + update_per_step=args.update_per_step, + logger=logger, + test_in_train=False, + reward_metric=reward_metric + ) + + return result, policy + + +def watch( + args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None +) -> None: + env = get_env() + policy.eval() + [agent.set_eps(args.eps_test) for agent in policy.policies.values()] + collector = Collector(policy, env, exploration_noise=True) + result = collector.collect(n_episode=1, 
render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()}") diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py new file mode 100644 index 000000000..a7a2983c9 --- /dev/null +++ b/test/pettingzoo/pistonball_continuous.py @@ -0,0 +1,276 @@ +import argparse +import os +from typing import Any, Dict, List, Optional, Tuple, Union + +import gym +import numpy as np +import pettingzoo.butterfly.pistonball_v4 as pistonball_v4 +import torch +import torch.nn as nn +from torch.distributions import Independent, Normal +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.env.pettingzoo_env import PettingZooEnv +from tianshou.policy import BasePolicy, MultiAgentPolicyManager, PPOPolicy +from tianshou.trainer import onpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.continuous import ActorProb, Critic + + +class DQN(nn.Module): + """Reference: Human-level control through deep reinforcement learning. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + + def __init__( + self, + c: int, + h: int, + w: int, + device: Union[str, int, torch.device] = "cpu", + ) -> None: + super().__init__() + self.device = device + self.c = c + self.h = h + self.w = w + self.net = nn.Sequential( + nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True), + nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True), + nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True), + nn.Flatten() + ) + with torch.no_grad(): + self.output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:]) + + def forward( + self, + x: Union[np.ndarray, torch.Tensor], + state: Optional[Any] = None, + info: Dict[str, Any] = {}, + ) -> Tuple[torch.Tensor, Any]: + r"""Mapping: x -> Q(x, \*).""" + x = torch.as_tensor(x, device=self.device, dtype=torch.float32) + return self.net(x.reshape(-1, self.c, self.w, self.h)), state + + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--eps-test', type=float, default=0.05) + parser.add_argument('--eps-train', type=float, default=0.1) + parser.add_argument('--buffer-size', type=int, default=2000) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument( + '--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win' + ) + parser.add_argument( + '--n-pistons', + type=int, + default=3, + help='Number of pistons(agents) in the env' + ) + parser.add_argument('--n-step', type=int, default=100) + parser.add_argument('--target-update-freq', type=int, default=320) + parser.add_argument('--epoch', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=500) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--episode-per-collect', type=int, default=16) + parser.add_argument('--repeat-per-collect', type=int, default=2) + parser.add_argument('--update-per-step', type=float, default=0.1) + parser.add_argument('--batch-size', type=int, default=1000) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) + parser.add_argument('--training-num', type=int, default=1000) + parser.add_argument('--test-num', type=int, default=100) + 
parser.add_argument('--logdir', type=str, default='log') + + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='no training, ' + 'watch the play of pre-trained models' + ) + parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + # ppo special + parser.add_argument('--vf-coef', type=float, default=0.25) + parser.add_argument('--ent-coef', type=float, default=0.0) + parser.add_argument('--eps-clip', type=float, default=0.2) + parser.add_argument('--max-grad-norm', type=float, default=0.5) + parser.add_argument('--gae-lambda', type=float, default=0.95) + parser.add_argument('--rew-norm', type=int, default=1) + parser.add_argument('--dual-clip', type=float, default=None) + parser.add_argument('--value-clip', type=int, default=1) + parser.add_argument('--norm-adv', type=int, default=1) + parser.add_argument('--recompute-adv', type=int, default=0) + parser.add_argument('--resume', action="store_true") + parser.add_argument("--save-interval", type=int, default=4) + parser.add_argument('--render', type=float, default=0.0) + + return parser + + +def get_args() -> argparse.Namespace: + parser = get_parser() + return parser.parse_known_args()[0] + + +def get_env(args: argparse.Namespace = get_args()): + return PettingZooEnv(pistonball_v4.env(continuous=True, n_pistons=args.n_pistons)) + + +def get_agents( + args: argparse.Namespace = get_args(), + agents: Optional[List[BasePolicy]] = None, + optims: Optional[List[torch.optim.Optimizer]] = None, +) -> Tuple[BasePolicy, List[torch.optim.Optimizer], List]: + env = get_env() + observation_space = env.observation_space['observation'] if isinstance( + env.observation_space, gym.spaces.Dict + ) else env.observation_space + args.state_shape = observation_space.shape or observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] + + if agents is None: + agents = [] + optims = [] + for _ in range(args.n_pistons): + # model + net = DQN( + observation_space.shape[2], + observation_space.shape[1], + observation_space.shape[0], + device=args.device + ).to(args.device) + + actor = ActorProb( + net, args.action_shape, max_action=args.max_action, device=args.device + ).to(args.device) + net2 = DQN( + observation_space.shape[2], + observation_space.shape[1], + observation_space.shape[0], + device=args.device + ).to(args.device) + critic = Critic(net2, device=args.device).to(args.device) + for m in set(actor.modules()).union(critic.modules()): + if isinstance(m, torch.nn.Linear): + torch.nn.init.orthogonal_(m.weight) + torch.nn.init.zeros_(m.bias) + optim = torch.optim.Adam( + set(actor.parameters()).union(critic.parameters()), lr=args.lr + ) + + def dist(*logits): + return Independent(Normal(*logits), 1) + + agent = PPOPolicy( + actor, + critic, + optim, + dist, + discount_factor=args.gamma, + max_grad_norm=args.max_grad_norm, + eps_clip=args.eps_clip, + vf_coef=args.vf_coef, + ent_coef=args.ent_coef, + reward_normalization=args.rew_norm, + advantage_normalization=args.norm_adv, + recompute_advantage=args.recompute_adv, + # dual_clip=args.dual_clip, + # dual clip cause monotonically increasing log_std :) + value_clip=args.value_clip, + gae_lambda=args.gae_lambda, + action_space=env.action_space + ) + + agents.append(agent) + optims.append(optim) + + policy = MultiAgentPolicyManager( + agents, env, action_scaling=True, action_bound_method='clip' + ) + return policy, optims, env.agents + + +def train_agent( + args: 
argparse.Namespace = get_args(), + agents: Optional[List[BasePolicy]] = None, + optims: Optional[List[torch.optim.Optimizer]] = None, +) -> Tuple[dict, BasePolicy]: + train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) + test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + + policy, optim, agents = get_agents(args, agents=agents, optims=optims) + + # collector + train_collector = Collector( + policy, + train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=False # True + ) + test_collector = Collector(policy, test_envs) + # train_collector.collect(n_step=args.batch_size * args.training_num) + # log + log_path = os.path.join(args.logdir, 'pistonball', 'dqn') + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = TensorboardLogger(writer) + + def save_fn(policy): + pass + + def stop_fn(mean_rewards): + return False + + def train_fn(epoch, env_step): + [agent.set_eps(args.eps_train) for agent in policy.policies.values()] + + def test_fn(epoch, env_step): + [agent.set_eps(args.eps_test) for agent in policy.policies.values()] + + def reward_metric(rews): + return rews[:, 0] + + # trainer + result = onpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + episode_per_collect=args.episode_per_collect, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger, + resume_from_log=args.resume + ) + + return result, policy + + +def watch( + args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None +) -> None: + env = get_env() + policy.eval() + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()}") diff --git a/test/pettingzoo/test_pistonball.py b/test/pettingzoo/test_pistonball.py new file mode 100644 index 000000000..19b3cff93 --- /dev/null +++ b/test/pettingzoo/test_pistonball.py @@ -0,0 +1,21 @@ +import pprint + +from pistonball import get_args, train_agent, watch + + +def test_piston_ball(args=get_args()): + if args.watch: + watch(args) + return + + result, agent = train_agent(args) + # assert result["best_reward"] >= args.win_rate + + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! + watch(args, agent) + + +if __name__ == '__main__': + test_piston_ball(get_args()) diff --git a/test/pettingzoo/test_pistonball_continuous.py b/test/pettingzoo/test_pistonball_continuous.py new file mode 100644 index 000000000..e62b07521 --- /dev/null +++ b/test/pettingzoo/test_pistonball_continuous.py @@ -0,0 +1,21 @@ +import pprint + +from pistonball_continuous import get_args, train_agent, watch + + +def test_piston_ball_continuous(args=get_args()): + if args.watch: + watch(args) + return + + result, agent = train_agent(args) + assert result["best_reward"] >= 30.0 + + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! 
+ watch(args, agent) + + +if __name__ == '__main__': + test_piston_ball_continuous(get_args()) diff --git a/test/multiagent/test_tic_tac_toe.py b/test/pettingzoo/test_tic_tac_toe.py similarity index 95% rename from test/multiagent/test_tic_tac_toe.py rename to test/pettingzoo/test_tic_tac_toe.py index aeb4644e1..e4b517abf 100644 --- a/test/multiagent/test_tic_tac_toe.py +++ b/test/pettingzoo/test_tic_tac_toe.py @@ -1,21 +1,21 @@ -import pprint - -from tic_tac_toe import get_args, train_agent, watch - - -def test_tic_tac_toe(args=get_args()): - if args.watch: - watch(args) - return - - result, agent = train_agent(args) - assert result["best_reward"] >= args.win_rate - - if __name__ == '__main__': - pprint.pprint(result) - # Let's watch its performance! - watch(args, agent) - - -if __name__ == '__main__': - test_tic_tac_toe(get_args()) +import pprint + +from tic_tac_toe import get_args, train_agent, watch + + +def test_tic_tac_toe(args=get_args()): + if args.watch: + watch(args) + return + + result, agent = train_agent(args) + assert result["best_reward"] >= args.win_rate + + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! + watch(args, agent) + + +if __name__ == '__main__': + test_tic_tac_toe(get_args()) diff --git a/test/multiagent/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py similarity index 78% rename from test/multiagent/tic_tac_toe.py rename to test/pettingzoo/tic_tac_toe.py index 02fd47cd7..aed65a3cf 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -1,232 +1,241 @@ -import argparse -import os -from copy import deepcopy -from typing import Optional, Tuple - -import numpy as np -import torch -from tic_tac_toe_env import TicTacToeEnv -from torch.utils.tensorboard import SummaryWriter - -from tianshou.data import Collector, VectorReplayBuffer -from tianshou.env import DummyVectorEnv -from tianshou.policy import ( - BasePolicy, - DQNPolicy, - MultiAgentPolicyManager, - RandomPolicy, -) -from tianshou.trainer import offpolicy_trainer -from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net - - -def get_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--eps-test', type=float, default=0.05) - parser.add_argument('--eps-train', type=float, default=0.1) - parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--lr', type=float, default=1e-3) - parser.add_argument( - '--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win' - ) - parser.add_argument('--n-step', type=int, default=3) - parser.add_argument('--target-update-freq', type=int, default=320) - parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=5000) - parser.add_argument('--step-per-collect', type=int, default=10) - parser.add_argument('--update-per-step', type=float, default=0.1) - parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument( - '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128] - ) - parser.add_argument('--training-num', type=int, default=10) - parser.add_argument('--test-num', type=int, default=100) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.1) - parser.add_argument('--board-size', type=int, default=6) - parser.add_argument('--win-size', type=int, default=4) - parser.add_argument( - '--win-rate', 
type=float, default=0.9, help='the expected winning rate' - ) - parser.add_argument( - '--watch', - default=False, - action='store_true', - help='no training, ' - 'watch the play of pre-trained models' - ) - parser.add_argument( - '--agent-id', - type=int, - default=2, - help='the learned agent plays as the' - ' agent_id-th player. Choices are 1 and 2.' - ) - parser.add_argument( - '--resume-path', - type=str, - default='', - help='the path of agent pth file ' - 'for resuming from a pre-trained agent' - ) - parser.add_argument( - '--opponent-path', - type=str, - default='', - help='the path of opponent agent pth file ' - 'for resuming from a pre-trained agent' - ) - parser.add_argument( - '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' - ) - return parser - - -def get_args() -> argparse.Namespace: - parser = get_parser() - return parser.parse_known_args()[0] - - -def get_agents( - args: argparse.Namespace = get_args(), - agent_learn: Optional[BasePolicy] = None, - agent_opponent: Optional[BasePolicy] = None, - optim: Optional[torch.optim.Optimizer] = None, -) -> Tuple[BasePolicy, torch.optim.Optimizer]: - env = TicTacToeEnv(args.board_size, args.win_size) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - if agent_learn is None: - # model - net = Net( - args.state_shape, - args.action_shape, - hidden_sizes=args.hidden_sizes, - device=args.device - ).to(args.device) - if optim is None: - optim = torch.optim.Adam(net.parameters(), lr=args.lr) - agent_learn = DQNPolicy( - net, - optim, - args.gamma, - args.n_step, - target_update_freq=args.target_update_freq - ) - if args.resume_path: - agent_learn.load_state_dict(torch.load(args.resume_path)) - - if agent_opponent is None: - if args.opponent_path: - agent_opponent = deepcopy(agent_learn) - agent_opponent.load_state_dict(torch.load(args.opponent_path)) - else: - agent_opponent = RandomPolicy() - - if args.agent_id == 1: - agents = [agent_learn, agent_opponent] - else: - agents = [agent_opponent, agent_learn] - policy = MultiAgentPolicyManager(agents) - return policy, optim - - -def train_agent( - args: argparse.Namespace = get_args(), - agent_learn: Optional[BasePolicy] = None, - agent_opponent: Optional[BasePolicy] = None, - optim: Optional[torch.optim.Optimizer] = None, -) -> Tuple[dict, BasePolicy]: - - def env_func(): - return TicTacToeEnv(args.board_size, args.win_size) - - train_envs = DummyVectorEnv([env_func for _ in range(args.training_num)]) - test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)]) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) - train_envs.seed(args.seed) - test_envs.seed(args.seed) - - policy, optim = get_agents( - args, agent_learn=agent_learn, agent_opponent=agent_opponent, optim=optim - ) - - # collector - train_collector = Collector( - policy, - train_envs, - VectorReplayBuffer(args.buffer_size, len(train_envs)), - exploration_noise=True - ) - test_collector = Collector(policy, test_envs, exploration_noise=True) - # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) - # log - log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn') - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - logger = TensorboardLogger(writer) - - def save_fn(policy): - if hasattr(args, 'model_save_path'): - model_save_path = args.model_save_path - else: - model_save_path = os.path.join( - args.logdir, 'tic_tac_toe', 
'dqn', 'policy.pth' - ) - torch.save(policy.policies[args.agent_id - 1].state_dict(), model_save_path) - - def stop_fn(mean_rewards): - return mean_rewards >= args.win_rate - - def train_fn(epoch, env_step): - policy.policies[args.agent_id - 1].set_eps(args.eps_train) - - def test_fn(epoch, env_step): - policy.policies[args.agent_id - 1].set_eps(args.eps_test) - - def reward_metric(rews): - return rews[:, args.agent_id - 1] - - # trainer - result = offpolicy_trainer( - policy, - train_collector, - test_collector, - args.epoch, - args.step_per_epoch, - args.step_per_collect, - args.test_num, - args.batch_size, - train_fn=train_fn, - test_fn=test_fn, - stop_fn=stop_fn, - save_fn=save_fn, - update_per_step=args.update_per_step, - logger=logger, - test_in_train=False, - reward_metric=reward_metric - ) - - return result, policy.policies[args.agent_id - 1] - - -def watch( - args: argparse.Namespace = get_args(), - agent_learn: Optional[BasePolicy] = None, - agent_opponent: Optional[BasePolicy] = None, -) -> None: - env = TicTacToeEnv(args.board_size, args.win_size) - policy, optim = get_agents( - args, agent_learn=agent_learn, agent_opponent=agent_opponent - ) - policy.eval() - policy.policies[args.agent_id - 1].set_eps(args.eps_test) - collector = Collector(policy, env, exploration_noise=True) - result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}") +import argparse +import os +from copy import deepcopy +from typing import Optional, Tuple + +import gym +import numpy as np +import torch +from pettingzoo.classic import tictactoe_v3 +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.env.pettingzoo_env import PettingZooEnv +from tianshou.policy import ( + BasePolicy, + DQNPolicy, + MultiAgentPolicyManager, + RandomPolicy, +) +from tianshou.trainer import offpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net + + +def get_env(): + return PettingZooEnv(tictactoe_v3.env()) + + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--eps-test', type=float, default=0.05) + parser.add_argument('--eps-train', type=float, default=0.1) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--lr', type=float, default=1e-4) + parser.add_argument( + '--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win' + ) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--target-update-freq', type=int, default=320) + parser.add_argument('--epoch', type=int, default=50) + parser.add_argument('--step-per-epoch', type=int, default=1000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument( + '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128] + ) + parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.1) + parser.add_argument( + '--win-rate', + type=float, + default=0.6, + help='the expected 
winning rate: Optimal policy can get 0.7' + ) + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='no training, ' + 'watch the play of pre-trained models' + ) + parser.add_argument( + '--agent-id', + type=int, + default=2, + help='the learned agent plays as the' + ' agent_id-th player. Choices are 1 and 2.' + ) + parser.add_argument( + '--resume-path', + type=str, + default='', + help='the path of agent pth file ' + 'for resuming from a pre-trained agent' + ) + parser.add_argument( + '--opponent-path', + type=str, + default='', + help='the path of opponent agent pth file ' + 'for resuming from a pre-trained agent' + ) + parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + return parser + + +def get_args() -> argparse.Namespace: + parser = get_parser() + return parser.parse_known_args()[0] + + +def get_agents( + args: argparse.Namespace = get_args(), + agent_learn: Optional[BasePolicy] = None, + agent_opponent: Optional[BasePolicy] = None, + optim: Optional[torch.optim.Optimizer] = None, +) -> Tuple[BasePolicy, torch.optim.Optimizer, list]: + env = get_env() + observation_space = env.observation_space['observation'] if isinstance( + env.observation_space, gym.spaces.Dict + ) else env.observation_space + args.state_shape = observation_space.shape or observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + if agent_learn is None: + # model + net = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device + ).to(args.device) + if optim is None: + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + agent_learn = DQNPolicy( + net, + optim, + args.gamma, + args.n_step, + target_update_freq=args.target_update_freq + ) + if args.resume_path: + agent_learn.load_state_dict(torch.load(args.resume_path)) + + if agent_opponent is None: + if args.opponent_path: + agent_opponent = deepcopy(agent_learn) + agent_opponent.load_state_dict(torch.load(args.opponent_path)) + else: + agent_opponent = RandomPolicy() + + if args.agent_id == 1: + agents = [agent_learn, agent_opponent] + else: + agents = [agent_opponent, agent_learn] + policy = MultiAgentPolicyManager(agents, env) + return policy, optim, env.agents + + +def train_agent( + args: argparse.Namespace = get_args(), + agent_learn: Optional[BasePolicy] = None, + agent_opponent: Optional[BasePolicy] = None, + optim: Optional[torch.optim.Optimizer] = None, +) -> Tuple[dict, BasePolicy]: + + train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) + test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + + policy, optim, agents = get_agents( + args, agent_learn=agent_learn, agent_opponent=agent_opponent, optim=optim + ) + + # collector + train_collector = Collector( + policy, + train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True + ) + test_collector = Collector(policy, test_envs, exploration_noise=True) + # policy.set_eps(1) + train_collector.collect(n_step=args.batch_size * args.training_num) + # log + log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn') + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = TensorboardLogger(writer) + + def save_fn(policy): + if hasattr(args, 'model_save_path'): + model_save_path = args.model_save_path + else: + model_save_path = 
os.path.join( + args.logdir, 'tic_tac_toe', 'dqn', 'policy.pth' + ) + torch.save( + policy.policies[agents[args.agent_id - 1]].state_dict(), model_save_path + ) + + def stop_fn(mean_rewards): + return mean_rewards >= args.win_rate + + def train_fn(epoch, env_step): + policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_train) + + def test_fn(epoch, env_step): + policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) + + def reward_metric(rews): + return rews[:, args.agent_id - 1] + + # trainer + result = offpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.step_per_collect, + args.test_num, + args.batch_size, + train_fn=train_fn, + test_fn=test_fn, + stop_fn=stop_fn, + save_fn=save_fn, + update_per_step=args.update_per_step, + logger=logger, + test_in_train=False, + reward_metric=reward_metric + ) + + return result, policy.policies[agents[args.agent_id - 1]] + + +def watch( + args: argparse.Namespace = get_args(), + agent_learn: Optional[BasePolicy] = None, + agent_opponent: Optional[BasePolicy] = None, +) -> None: + env = get_env() + policy, optim, agents = get_agents( + args, agent_learn=agent_learn, agent_opponent=agent_opponent + ) + policy.eval() + policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) + collector = Collector(policy, env, exploration_noise=True) + result = collector.collect(n_episode=1, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}") diff --git a/tianshou/env/__init__.py b/tianshou/env/__init__.py index c77c30c3f..f32e0cff0 100644 --- a/tianshou/env/__init__.py +++ b/tianshou/env/__init__.py @@ -1,6 +1,6 @@ """Env package.""" -from tianshou.env.maenv import MultiAgentEnv +from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.env.venvs import ( BaseVectorEnv, DummyVectorEnv, @@ -15,5 +15,5 @@ "SubprocVectorEnv", "ShmemVectorEnv", "RayVectorEnv", - "MultiAgentEnv", + "PettingZooEnv", ] diff --git a/tianshou/env/maenv.py b/tianshou/env/maenv.py deleted file mode 100644 index 6fa1ad2f8..000000000 --- a/tianshou/env/maenv.py +++ /dev/null @@ -1,65 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict, Tuple - -import gym -import numpy as np - - -class MultiAgentEnv(ABC, gym.Env): - """The interface for multi-agent environments. - - Multi-agent environments must be wrapped as - :class:`~tianshou.env.MultiAgentEnv`. Here is the usage: - :: - - env = MultiAgentEnv(...) - # obs is a dict containing obs, agent_id, and mask - obs = env.reset() - act = policy(obs) - obs, rew, done, info = env.step(act) - env.close() - - The available action's mask is set to 1, otherwise it is set to 0. Further - usage can be found at :ref:`marl_example`. - """ - - def __init__(self) -> None: - pass - - @abstractmethod - def reset(self) -> dict: - """Reset the state. - - Return the initial state, first agent_id, and the initial action set, - for example, ``{'obs': obs, 'agent_id': agent_id, 'mask': mask}``. - """ - pass - - @abstractmethod - def step( - self, action: np.ndarray - ) -> Tuple[Dict[str, Any], np.ndarray, np.ndarray, np.ndarray]: - """Run one timestep of the environment’s dynamics. - - When the end of episode is reached, you are responsible for calling - reset() to reset the environment’s state. - - Accept action and return a tuple (obs, rew, done, info). - - :param numpy.ndarray action: action provided by a agent. 
- - :return: A tuple including four items: - - * ``obs`` a dict containing obs, agent_id, and mask, which means \ - that it is the ``agent_id`` player's turn to play with ``obs``\ - observation and ``mask``. - * ``rew`` a numpy.ndarray, the amount of rewards returned after \ - previous actions. Depending on the specific environment, this \ - can be either a scalar reward for current agent or a vector \ - reward for all the agents. - * ``done`` a numpy.ndarray, whether the episode has ended, in \ - which case further step() calls will return undefined results - * ``info`` a numpy.ndarray, contains auxiliary diagnostic \ - information (helpful for debugging, and sometimes learning) - """ - pass diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py new file mode 100644 index 000000000..de752516a --- /dev/null +++ b/tianshou/env/pettingzoo_env.py @@ -0,0 +1,112 @@ +from abc import ABC +from typing import Any, Dict, List, Tuple + +import gym.spaces +from pettingzoo.utils.env import AECEnv +from pettingzoo.utils.wrappers import BaseWrapper + + +class PettingZooEnv(AECEnv, gym.Env, ABC): + """The interface for PettingZoo environments. + + Multi-agent environments must be wrapped as + :class:`~tianshou.env.PettingZooEnv`. Here is the usage: + :: + + env = PettingZooEnv(...) + # obs is a dict containing obs, agent_id, and mask + obs = env.reset() + action = policy(obs) + obs, rew, done, info = env.step(action) + env.close() + + The action mask is set to True for available actions and False otherwise. + Further usage can be found at :ref:`marl_example`. + """ + + def __init__(self, env: BaseWrapper): + super().__init__() + self.env = env + # agent idx list + self.agents = self.env.possible_agents + self.agent_idx = {} + for i, agent_id in enumerate(self.agents): + self.agent_idx[agent_id] = i + # Get dictionaries of obs_spaces and act_spaces + self.observation_spaces = self.env.observation_spaces + self.action_spaces = self.env.action_spaces + + self.rewards = [0] * len(self.agents) + + # Get first observation space, assuming all agents have equal space + self.observation_space: Any = self.observation_space(self.agents[0]) + + # Get first action space, assuming all agents have equal space + self.action_space: Any = self.action_space(self.agents[0]) + + assert all(self.env.observation_space(agent) == self.observation_space + for agent in self.agents), \ + "Observation spaces for all agents must be identical. Perhaps " \ + "SuperSuit's pad_observations wrapper can help (usage: " \ + "`supersuit.aec_wrappers.pad_observations(env)`)" + + assert all(self.env.action_space(agent) == self.action_space + for agent in self.agents), \ + "Action spaces for all agents must be identical. 
Perhaps " \ + "SuperSuit's pad_action_space wrapper can help (usage: " \ + "`supersuit.aec_wrappers.pad_action_space(env)`)" + + self.reset() + + def reset(self) -> dict: + self.env.reset() + observation = self.env.observe(self.env.agent_selection) + if isinstance(observation, dict) and 'action_mask' in observation: + return { + 'agent_id': self.env.agent_selection, + 'obs': observation['observation'], + 'mask': + [True if obm == 1 else False for obm in observation['action_mask']] + } + else: + if isinstance(self.action_space, gym.spaces.Discrete): + return { + 'agent_id': self.env.agent_selection, + 'obs': observation, + 'mask': [True] * self.env.action_space(self.env.agent_selection).n + } + else: + return {'agent_id': self.env.agent_selection, 'obs': observation} + + def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]: + self.env.step(action) + observation, rew, done, info = self.env.last() + if isinstance(observation, dict) and 'action_mask' in observation: + obs = { + 'agent_id': self.env.agent_selection, + 'obs': observation['observation'], + 'mask': + [True if obm == 1 else False for obm in observation['action_mask']] + } + else: + if isinstance(self.action_space, gym.spaces.Discrete): + obs = { + 'agent_id': self.env.agent_selection, + 'obs': observation, + 'mask': [True] * self.env.action_space(self.env.agent_selection).n + } + else: + obs = {'agent_id': self.env.agent_selection, 'obs': observation} + + for agent_id, reward in self.env.rewards.items(): + self.rewards[self.agent_idx[agent_id]] = reward + return obs, self.rewards, done, info + + def close(self) -> None: + self.env.close() + + def seed(self, seed: Any = None) -> None: + self.env.seed(seed) + + def render(self, mode: str = "human") -> Any: + return self.env.render(mode) diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index c668109b6..02b50b183 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -2,6 +2,7 @@ import gym import numpy as np +import pettingzoo from tianshou.env.worker import ( DummyEnvWorker, @@ -364,7 +365,10 @@ class DummyVectorEnv(BaseVectorEnv): Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. """ - def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None: + def __init__( + self, env_fns: List[Callable[[], Union[gym.Env, pettingzoo.AECEnv]]], + **kwargs: Any + ) -> None: super().__init__(env_fns, DummyEnvWorker, **kwargs) diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index 75705f4a3..6c69a8d36 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -3,6 +3,7 @@ import numpy as np from tianshou.data import Batch, ReplayBuffer +from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.policy import BasePolicy @@ -16,21 +17,29 @@ class MultiAgentPolicyManager(BasePolicy): :ref:`marl_example` can help you better understand this procedure. """ - def __init__(self, policies: List[BasePolicy], **kwargs: Any) -> None: - super().__init__(**kwargs) - self.policies = policies + def __init__( + self, policies: List[BasePolicy], env: PettingZooEnv, **kwargs: Any + ) -> None: + super().__init__(action_space=env.action_space, **kwargs) + assert ( + len(policies) == len(env.agents) + ), "One policy must be assigned for each agent."
+ + self.agent_idx = env.agent_idx for i, policy in enumerate(policies): # agent_id 0 is reserved for the environment proxy # (this MultiAgentPolicyManager) - policy.set_agent_id(i + 1) + policy.set_agent_id(env.agents[i]) + + self.policies = dict(zip(env.agents, policies)) def replace_policy(self, policy: BasePolicy, agent_id: int) -> None: """Replace the "agent_id"th policy in this manager.""" - self.policies[agent_id - 1] = policy policy.set_agent_id(agent_id) + self.policies[agent_id] = policy def process_fn( - self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray + self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray ) -> Batch: """Dispatch batch data from obs.agent_id to every policy's process_fn. @@ -45,18 +54,21 @@ def process_fn( # Since we do not override buffer.__setattr__, here we use _meta to # change buffer.rew, otherwise buffer.rew = Batch() has no effect. save_rew, buffer._meta.rew = buffer.rew, Batch() - for policy in self.policies: - agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0] + for agent, policy in self.policies.items(): + agent_index = np.nonzero(batch.obs.agent_id == agent)[0] if len(agent_index) == 0: - results[f"agent_{policy.agent_id}"] = Batch() + results[agent] = Batch() continue - tmp_batch, tmp_indices = batch[agent_index], indices[agent_index] + tmp_batch, tmp_indice = batch[agent_index], indice[agent_index] if has_rew: - tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1] - buffer._meta.rew = save_rew[:, policy.agent_id - 1] - results[f"agent_{policy.agent_id}"] = policy.process_fn( - tmp_batch, buffer, tmp_indices - ) + tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent]] + buffer._meta.rew = save_rew[:, self.agent_idx[agent]] + if not hasattr(tmp_batch.obs, "mask"): + if hasattr(tmp_batch.obs, 'obs'): + tmp_batch.obs = tmp_batch.obs.obs + if hasattr(tmp_batch.obs_next, 'obs'): + tmp_batch.obs_next = tmp_batch.obs_next.obs + results[agent] = policy.process_fn(tmp_batch, buffer, tmp_indice) if has_rew: # restore from save_rew buffer._meta.rew = save_rew return Batch(results) @@ -64,8 +76,8 @@ def process_fn( def exploration_noise(self, act: Union[np.ndarray, Batch], batch: Batch) -> Union[np.ndarray, Batch]: """Add exploration noise from sub-policy onto act.""" - for policy in self.policies: - agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0] + for agent_id, policy in self.policies.items(): + agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0] if len(agent_index) == 0: continue act[agent_index] = policy.exploration_noise( @@ -104,7 +116,7 @@ def forward( # type: ignore """ results: List[Tuple[bool, np.ndarray, Batch, Union[np.ndarray, Batch], Batch]] = [] - for policy in self.policies: + for agent_id, policy in self.policies.items(): # This part of code is difficult to understand. 
# Let's follow an example with two agents # batch.obs.agent_id is [1, 2, 1, 2, 1, 2] (with batch_size == 6) @@ -112,7 +124,7 @@ def forward( # type: ignore # agent_index for agent 1 is [0, 2, 4] # agent_index for agent 2 is [1, 3, 5] # we separate the transition of each agent according to agent_id - agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0] + agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0] if len(agent_index) == 0: # (has_data, agent_index, out, act, state) results.append((False, np.array([-1]), Batch(), Batch(), Batch())) @@ -120,11 +132,15 @@ def forward( # type: ignore tmp_batch = batch[agent_index] if isinstance(tmp_batch.rew, np.ndarray): # reward can be empty Batch (after initial reset) or nparray. - tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1] + tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]] + if not hasattr(tmp_batch.obs, "mask"): + if hasattr(tmp_batch.obs, 'obs'): + tmp_batch.obs = tmp_batch.obs.obs + if hasattr(tmp_batch.obs_next, 'obs'): + tmp_batch.obs_next = tmp_batch.obs_next.obs out = policy( batch=tmp_batch, - state=None if state is None else state["agent_" + - str(policy.agent_id)], + state=None if state is None else state[agent_id], **kwargs ) act = out.act @@ -141,12 +157,12 @@ def forward( # type: ignore ] ) state_dict, out_dict = {}, {} - for policy, (has_data, agent_index, out, act, - state) in zip(self.policies, results): + for (agent_id, _), (has_data, agent_index, out, act, + state) in zip(self.policies.items(), results): if has_data: holder.act[agent_index] = act - state_dict["agent_" + str(policy.agent_id)] = state - out_dict["agent_" + str(policy.agent_id)] = out + state_dict[agent_id] = state + out_dict[agent_id] = out holder["out"] = out_dict holder["state"] = state_dict return holder @@ -168,10 +184,10 @@ def learn(self, batch: Batch, } """ results = {} - for policy in self.policies: - data = batch[f"agent_{policy.agent_id}"] + for agent_id, policy in self.policies.items(): + data = batch[agent_id] if not data.is_empty(): out = policy.learn(batch=data, **kwargs) for k, v in out.items(): - results["agent_" + str(policy.agent_id) + "/" + k] = v + results[agent_id + "/" + k] = v return results
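Note (not part of the diff): the sketch below shows how the pieces added above are meant to fit together after this change — a PettingZoo AEC environment is wrapped in PettingZooEnv, and MultiAgentPolicyManager now takes the env as its second argument and keys its policies dict by the env's agent names. The RandomPolicy instances merely stand in for trained policies, and agent names such as 'player_1'/'player_2' come from the wrapped PettingZoo environment; treat this as a minimal usage sketch, not as test code from this PR.

# Minimal usage sketch (illustration only, not part of this diff).
from pettingzoo.classic import tictactoe_v3

from tianshou.data import Collector
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import MultiAgentPolicyManager, RandomPolicy


def make_env():
    # wrap a PettingZoo AEC env so each observation is
    # {'agent_id': ..., 'obs': ..., 'mask': ...}
    return PettingZooEnv(tictactoe_v3.env())


env = make_env()
# one sub-policy per agent; the manager stores them in a dict keyed by
# agent name (e.g. env.agents == ['player_1', 'player_2'])
policy = MultiAgentPolicyManager([RandomPolicy(), RandomPolicy()], env)
envs = DummyVectorEnv([make_env])
collector = Collector(policy, envs)
result = collector.collect(n_episode=1)
# rewards come back with one column per agent, e.g. result['rews'][:, 0]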