diff --git a/.gitignore b/.gitignore index 42152539b..0ecb650d3 100644 --- a/.gitignore +++ b/.gitignore @@ -143,3 +143,5 @@ MUJOCO_LOG.TXT *.pth .vscode/ .DS_Store +*.zip +*.pstats diff --git a/README.md b/README.md index fbc8bbbcd..5ba7c8159 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ Here is Tianshou's other features: - Support any type of environment state (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation) - Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process) - Support n-step returns estimation for all Q-learning based algorithms +- Support multi-agent RL easily [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#multi-agent-reinforcement-learning) In Chinese, Tianshou means divinely ordained and is derived to the gift of being born with. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. So taking "Tianshou" means that there is no teacher to study with, but rather to learn by themselves through constant interaction with the environment. diff --git a/docs/_static/images/marl.png b/docs/_static/images/marl.png new file mode 100644 index 000000000..cf368d5ef Binary files /dev/null and b/docs/_static/images/marl.png differ diff --git a/docs/_static/images/tic-tac-toe.png b/docs/_static/images/tic-tac-toe.png new file mode 100644 index 000000000..071fa9f5c Binary files /dev/null and b/docs/_static/images/tic-tac-toe.png differ diff --git a/docs/contributor.rst b/docs/contributor.rst index 044ea23c3..62a69277e 100644 --- a/docs/contributor.rst +++ b/docs/contributor.rst @@ -6,3 +6,4 @@ We always welcome contributions to help make Tianshou better. 
Below are an incom * Jiayi Weng (`Trinkle23897 `_) * Minghao Zhang (`Mehooz `_) * Alexis Duburcq (`duburcqa `_) +* Kaichao You (`youkaichao `_) \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index ef2342d86..27ef950a1 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -28,6 +28,7 @@ Here is Tianshou's other features: * Support any type of environment state (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env` * Support customized training process: :ref:`customize_training` * Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` for all Q-learning based algorithms +* Support multi-agent RL easily (a tutorial is available at :doc:`/tutorials/tictactoe`) 中文文档位于 https://tianshou.readthedocs.io/zh/latest/ @@ -71,6 +72,7 @@ Tianshou is still under development, you can also check out the documents in sta tutorials/dqn tutorials/concepts tutorials/batch + tutorials/tictactoe tutorials/trick tutorials/cheatsheet diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index d193ae3c9..bae20b19c 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -244,3 +244,46 @@ But the state stored in the buffer may be a shallow-copy. To make sure each of y def step(a): ... return copy.deepcopy(self.graph), reward, done, {} + +.. _marl_example: + +Multi-Agent Reinforcement Learning +---------------------------------- + +This is related to `Issue 121 `_. The discussion is still ongoing. + +With the flexible core APIs, Tianshou can support multi-agent reinforcement learning with minimal effort. + +Currently, we support three types of multi-agent reinforcement learning paradigms: + +1. Simultaneous move: at each timestep, all the agents take their actions (example: MOBA games) + +2. Cyclic move: players take actions in turn (example: the game of Go) + +3. Conditional move: at each timestep, the environment conditionally selects an agent to take action. 
(example: `Pig Game `_) + +We mainly address these multi-agent RL problems by converting them into traditional RL formulations. + +For simultaneous moves, the solution is simple: we can just add a ``num_agent`` dimension to state, action, and reward. Nothing else needs to change. + +Cases 2 and 3 (cyclic move and conditional move) can be unified into a single framework: at each timestep, the environment selects an agent with id ``agent_id`` to play. Since multiple agents are usually wrapped into one object (which we call the "abstract agent"), we can pass the ``agent_id`` to the "abstract agent", which then calls the specific agent. + +In addition, legal actions in multi-agent RL often vary with the timestep (as in Go), so the environment should also pass the legal action mask to the "abstract agent", where the mask is a boolean array that is "True" for available actions and "False" for illegal actions at the current step. Below is a figure that explains the abstract agent. + +.. image:: /_static/images/marl.png + :align: center + :height: 300 + +The above description gives rise to the following formulation of multi-agent RL: +:: + + action = policy(state, agent_id, mask) + (next_state, next_agent_id, next_mask), reward = env.step(action) + +By constructing a new state ``state_ = (state, agent_id, mask)``, we can essentially return to the typical formulation of RL: +:: + + action = policy(state_) + next_state_, reward = env.step(action) + +Following this idea, we wrote a tiny example of playing `Tic Tac Toe `_ against a random player using a Q-learning algorithm. The tutorial is at :doc:`/tutorials/tictactoe`. 
diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index d981a1cb3..2da0ed738 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -88,7 +88,7 @@ We use the defined ``net`` and ``optim``, with extra policy hyper-parameters, to policy = ts.policy.DQNPolicy(net, optim, discount_factor=0.9, estimation_step=3, - use_target_network=True, target_update_freq=320) + target_update_freq=320) Setup Collector diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst new file mode 100644 index 000000000..eb76bd540 --- /dev/null +++ b/docs/tutorials/tictactoe.rst @@ -0,0 +1,660 @@ +Multi-Agent RL +============== + +In this section, we describe how to use Tianshou to implement multi-agent reinforcement learning. Specifically, we will design an algorithm to learn how to play `Tic Tac Toe `_ (see the image below) against a random opponent. + +.. image:: ../_static/images/tic-tac-toe.png + :align: center + +Tic-Tac-Toe Environment +----------------------- + +The scripts are located at ``test/multiagent/``. We have implemented a Tic-Tac-Toe environment that inherits from :class:`~tianshou.env.MultiAgentEnv` and supports Tic-Tac-Toe of any scale. Let's first explore the environment. The 3x3 Tic-Tac-Toe is too easy, so we will focus on 6x6 Tic-Tac-Toe, where 4 identical signs in a row win. 
+:: + + >>> from tic_tac_toe_env import TicTacToeEnv # the module tic_tac_toe_env is in test/multiagent/ + >>> board_size = 6 # the size of the board + >>> win_size = 4 # how many signs in a row are considered to win + >>> + >>> # This board has 6 rows and 6 cols (36 places in total) + >>> # Players place 'x' and 'o' in turn on the board + >>> # The player who first gets 4 consecutive 'x's or 'o's wins + >>> + >>> env = TicTacToeEnv(size=board_size, win_size=win_size) + >>> obs = env.reset() + >>> env.render() # render the empty board + board (step 0): + ================= + ===_ _ _ _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ _ _ _=== + ================= + >>> print(obs) # let's see the structure of the observation + {'agent_id': 1, + 'obs': array([[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]], dtype=int32), + 'mask': array([ True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True])} + +The observation variable ``obs`` returned from the environment is a ``dict`` with three keys: ``agent_id``, ``obs``, and ``mask``. This is a general structure in multi-agent RL where agents take turns. The meanings of these keys are: + +- ``agent_id``: the id of the currently acting agent, where agent_id :math:`\in [1, N]` and N is the number of agents. In our Tic-Tac-Toe case, N is 2. The agent_id starts at 1 because we reserve 0 for the environment itself. Sometimes the developer may want to control the behavior of the environment, for example, to determine how to deal cards in Poker. + +- ``obs``: the actual observation of the environment. In the Tic-Tac-Toe game above, the observation variable ``obs`` is a ``np.ndarray`` with the shape of (6, 6). 
The values can be "0/1/-1": 0 for empty, 1 for ``x``, -1 for ``o``. Agent 1 places ``x`` on the board, while agent 2 places ``o`` on the board. + +- ``mask``: the action mask at the current timestep. In board games or card games, the set of legal actions varies with time. The mask is a boolean array. For Tic-Tac-Toe, index ``i`` corresponds to row ``i // board_size`` and column ``i % board_size`` (both 0-indexed). If ``mask[i] == True``, the player can place an ``x`` or ``o`` at that position. Since the board is empty now, the mask is all ``True``, i.e., every position on the board is available. + +.. note:: + + There is no prescribed formulation of ``mask`` for either discrete or continuous action spaces. You could also use action spaces such as ``gym.spaces.Discrete`` or ``gym.spaces.Box`` to represent the available actions. Currently, we use a boolean array. + +Let's play two steps to get an intuitive understanding of the environment. + +:: + + >>> import numpy as np + >>> action = 0 # action is either an integer, or an np.ndarray with one element + >>> obs, reward, done, info = env.step(action) # env.step follows the OpenAI Gym API + >>> print(obs) # notice the change in the observation + {'agent_id': 2, + 'obs': array([[1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]], dtype=int32), + 'mask': array([False, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True, + True, True, True, True, True, True, True, True, True])} + >>> # reward has two items, one for each player: 1 for a win, -1 for a loss, and 0 otherwise + >>> print(reward) + [0. 0.] + >>> print(done) # done indicates whether the game is over + False + >>> # info is always an empty dict in Tic-Tac-Toe, but may contain useful information in other environments. 
+ >>> print(info) + {} + +One case worth noting is that the game ends when only one empty position remains, rather than when no empty positions remain. This is because the player would have only one choice (literally no choice) at that point. +:: + + >>> # omitted actions: 6, 1, 7, 2, 8 + >>> obs, reward, done, info = env.step(3) # player 1 wins + >>> print((reward, done)) + (array([ 1., -1.], dtype=float32), array(True)) + >>> env.render() # 'X' and 'O' indicate the last action + board (step 7): + ================= + ===x x x X _ _=== + ===o o o _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ _ _ _=== + ================= + +Now that we are familiar with the environment, let's first try playing with random agents! + +Two Random Agents +----------------- + +.. sidebar:: The relationship between MultiAgentPolicyManager (Manager) and BasePolicy (Agent) + + .. Figure:: ../_static/images/marl.png + +Tianshou already provides some built-in classes for multi-agent learning. You can check out the API documentation for details. Here we use :class:`~tianshou.policy.RandomPolicy` and :class:`~tianshou.policy.MultiAgentPolicyManager`. The figure on the right gives an intuitive explanation. 
+ +:: + + >>> from tianshou.data import Collector + >>> from tianshou.policy import RandomPolicy, MultiAgentPolicyManager + >>> + >>> # agents should be wrapped into one policy, + >>> # which is responsible for calling the acting agent correctly + >>> # here we use two random agents + >>> policy = MultiAgentPolicyManager([RandomPolicy(), RandomPolicy()]) + >>> + >>> # use a collector to collect one episode of trajectories + >>> # the reward is a vector, so we need a scalar metric to monitor the training + >>> collector = Collector(policy, env, reward_metric=lambda x: x[0]) + >>> + >>> # you will see a long trajectory showing the board status at each timestep + >>> result = collector.collect(n_episode=1, render=.1) + (only the last 3 steps are shown) + board (step 20): + ================= + ===o x _ o o o=== + ===_ _ x _ _ x=== + ===x _ o o x _=== + ===O _ o o x _=== + ===x _ o _ _ _=== + ===x _ _ _ x x=== + ================= + board (step 21): + ================= + ===o x _ o o o=== + ===_ _ x _ _ x=== + ===x _ o o x _=== + ===o _ o o x _=== + ===x _ o X _ _=== + ===x _ _ _ x x=== + ================= + board (step 22): + ================= + ===o x _ o o o=== + ===_ O x _ _ x=== + ===x _ o o x _=== + ===o _ o o x _=== + ===x _ o x _ _=== + ===x _ _ _ x x=== + ================= + >>> collector.close() + +Random agents perform badly. In the above game, although agent 2 eventually wins, it is clear that a smart agent 1 would have placed an ``x`` at row 4, column 4 (0-indexed) in step 21 to win immediately. + +Train a MARL Agent +------------------ + +So let's start training our Tic-Tac-Toe agent! First, import some required modules. 
+:: + + import os + import torch + import argparse + import numpy as np + from copy import deepcopy + from torch.utils.tensorboard import SummaryWriter + + from tianshou.env import VectorEnv + from tianshou.utils.net.common import Net + from tianshou.trainer import offpolicy_trainer + from tianshou.data import Collector, ReplayBuffer + from tianshou.policy import BasePolicy, RandomPolicy, DQNPolicy, MultiAgentPolicyManager + + from tic_tac_toe_env import TicTacToeEnv + +The explanation of each Tianshou class/function will be deferred to their first usages. Here we define some arguments and hyperparameters of the experiment. The meaning of arguments is clear by just looking at their names. +:: + + def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--eps-test', type=float, default=0.05) + parser.add_argument('--eps-train', type=float, default=0.1) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--gamma', type=float, default=0.1, + help='a smaller gamma favors earlier win') + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--target-update-freq', type=int, default=320) + parser.add_argument('--epoch', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=1000) + parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--layer-num', type=int, default=3) + parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.1) + parser.add_argument('--board_size', type=int, default=6) + parser.add_argument('--win_size', type=int, default=4) + parser.add_argument('--win-rate', type=float, 
default=np.float32(0.9), + help='the expected winning rate') + parser.add_argument('--watch', default=False, action='store_true', + help='no training, watch the play of pre-trained models') + parser.add_argument('--agent_id', type=int, default=2, + help='the learned agent plays as the agent_id-th player. choices are 1 and 2.') + parser.add_argument('--resume_path', type=str, default='', + help='the path of agent pth file for resuming from a pre-trained agent') + parser.add_argument('--opponent_path', type=str, default='', + help='the path of opponent agent pth file for resuming from a pre-trained agent') + parser.add_argument('--device', type=str, + default='cuda' if torch.cuda.is_available() else 'cpu') + args = parser.parse_known_args()[0] + return args + +.. sidebar:: The relationship between MultiAgentPolicyManager (Manager) and BasePolicy (Agent) + + .. Figure:: ../_static/images/marl.png + +The following ``get_agents`` function returns the multi-agent policy (both agents wrapped in a :class:`~tianshou.policy.MultiAgentPolicyManager`) together with the optimizer of the learning agent. Each agent is either constructed from scratch, loaded from disk, or taken from the passed-in arguments. For the models: + +- The action model we use is an instance of :class:`~tianshou.utils.net.common.Net`, essentially a multi-layer perceptron with the ReLU activation function; +- The network model is passed to a :class:`~tianshou.policy.DQNPolicy`, where actions are selected according to both the action mask and their Q-values; +- The opponent can be either a random agent :class:`~tianshou.policy.RandomPolicy` that randomly chooses an action from legal actions, or a pre-trained :class:`~tianshou.policy.DQNPolicy`, which allows the learned agent to play against itself. + +Both agents are passed to :class:`~tianshou.policy.MultiAgentPolicyManager`, which is responsible for calling the correct agent according to the ``agent_id`` in the observation. 
:class:`~tianshou.policy.MultiAgentPolicyManager` also dispatches data to each agent according to ``agent_id``, so that each agent seems to play with a virtual single-agent environment. + +Here it is: +:: + + def get_agents(args=get_args(), + agent_learn=None, # BasePolicy + agent_opponent=None, # BasePolicy + optim=None, # torch.optim.Optimizer + ): # return a tuple of (BasePolicy, torch.optim.Optimizer) + env = TicTacToeEnv(args.board_size, args.win_size) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + + if agent_learn is None: + net = Net(args.layer_num, args.state_shape, args.action_shape, args.device).to(args.device) + if optim is None: + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + agent_learn = DQNPolicy( + net, optim, args.gamma, args.n_step, + target_update_freq=args.target_update_freq) + if args.resume_path: + agent_learn.load_state_dict(torch.load(args.resume_path)) + + if agent_opponent is None: + if args.opponent_path: + agent_opponent = deepcopy(agent_learn) + agent_opponent.load_state_dict(torch.load(args.opponent_path)) + else: + agent_opponent = RandomPolicy() + + if args.agent_id == 1: + agents = [agent_learn, agent_opponent] + else: + agents = [agent_opponent, agent_learn] + policy = MultiAgentPolicyManager(agents) + return policy, optim + +With the above preparation, we are close to the first learned agent. The following code is almost the same as the code in the DQN tutorial. + +:: + + args = get_args() + # the reward is a vector, we need a scalar metric to monitor the training. 
+ # we choose the reward of the learning agent + Collector._default_rew_metric = lambda x: x[args.agent_id - 1] + + # ======== a test function that tests a pre-trained agent and exit ====== + def watch(args=get_args(), + agent_learn=None, # BasePolicy + agent_opponent=None): # BasePolicy + env = TicTacToeEnv(args.board_size, args.win_size) + policy, optim = get_agents( + args, agent_learn=agent_learn, agent_opponent=agent_opponent) + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + print(f'Final reward: {result["rew"]}, length: {result["len"]}') + collector.close() + if args.watch: + watch(args) + exit(0) + + # ======== environment setup ========= + env_func = lambda: TicTacToeEnv(args.board_size, args.win_size) + train_envs = VectorEnv([env_func for _ in range(args.training_num)]) + test_envs = VectorEnv([env_func for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + + # ======== agent setup ========= + policy, optim = get_agents() + + # ======== collector setup ========= + train_collector = Collector(policy, train_envs, ReplayBuffer(args.buffer_size)) + test_collector = Collector(policy, test_envs) + train_collector.collect(n_step=args.batch_size) + + # ======== tensorboard logging setup ========= + if not hasattr(args, 'writer'): + log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn') + writer = SummaryWriter(log_path) + else: + writer = args.writer + + # ======== callback functions used during training ========= + + def save_fn(policy): + if hasattr(args, 'model_save_path'): + model_save_path = args.model_save_path + else: + model_save_path = os.path.join( + args.logdir, 'tic_tac_toe', 'dqn', 'policy.pth') + torch.save( + policy.policies[args.agent_id - 1].state_dict(), + model_save_path) + + def stop_fn(x): + return x >= args.win_rate # 95% winning rate by default + # the default args.win_rate is 0.9, 
but the reward is [-1, 1] + # instead of [0, 1], so args.win_rate == 0.9 is equal to 95% win rate. + + def train_fn(x): + policy.policies[args.agent_id - 1].set_eps(args.eps_train) + + def test_fn(x): + policy.policies[args.agent_id - 1].set_eps(args.eps_test) + + # start training, this may require about three minutes + result = offpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.collect_per_step, args.test_num, + args.batch_size, train_fn=train_fn, test_fn=test_fn, + stop_fn=stop_fn, save_fn=save_fn, writer=writer, + test_in_train=False) + + train_collector.close() + test_collector.close() + + agent = policy.policies[args.agent_id - 1] + # let's watch the match! + watch(args, agent) + +That's it. By executing the code, you will see a progress bar indicating the progress of training. After about less than 1 minute, the agent has finished training, and you can see how it plays against the random agent. Here is an example: + +.. raw:: html + +
+ Play with random agent + +:: + + board (step 1): + ================= + ===_ _ _ X _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ _ _ _=== + ================= + board (step 2): + ================= + ===_ _ _ x _ _=== + ===_ _ _ _ _ _=== + ===_ _ O _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ _ _ _=== + ================= + board (step 3): + ================= + ===_ _ _ x _ _=== + ===_ _ _ _ _ _=== + ===_ _ o _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ X _ _=== + ===_ _ _ _ _ _=== + ================= + board (step 4): + ================= + ===_ _ _ x _ _=== + ===_ _ _ _ _ _=== + ===_ _ o _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ x _ _=== + ===_ _ O _ _ _=== + ================= + board (step 5): + ================= + ===_ _ _ x _ _=== + ===_ _ _ _ X _=== + ===_ _ o _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ x _ _=== + ===_ _ o _ _ _=== + ================= + board (step 6): + ================= + ===_ _ _ x _ _=== + ===_ _ _ _ x _=== + ===_ _ o _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ O x _ _=== + ===_ _ o _ _ _=== + ================= + board (step 7): + ================= + ===_ _ _ x _ X=== + ===_ _ _ _ x _=== + ===_ _ o _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ o x _ _=== + ===_ _ o _ _ _=== + ================= + board (step 8): + ================= + ===_ _ _ x _ x=== + ===_ _ _ _ x _=== + ===_ _ o _ _ _=== + ===_ _ _ _ O _=== + ===_ _ o x _ _=== + ===_ _ o _ _ _=== + ================= + board (step 9): + ================= + ===_ _ _ x _ x=== + ===_ _ _ _ x _=== + ===_ _ o _ _ _=== + ===_ _ _ _ o _=== + ===X _ o x _ _=== + ===_ _ o _ _ _=== + ================= + board (step 10): + ================= + ===_ _ _ x _ x=== + ===_ _ _ _ x _=== + ===_ _ o _ _ _=== + ===_ _ O _ o _=== + ===x _ o x _ _=== + ===_ _ o _ _ _=== + ================= + Final reward: 1.0, length: 10.0 + +.. raw:: html + +

+ +Notice that our learned agent plays the role of agent 2, placing ``o`` on the board. The agent performs pretty well against the random opponent! It learns the rules of the game by trial and error, discovers that four consecutive ``o`` signs mean a win, and plays accordingly! + +The above code can be executed in a Python shell or saved as a script file (a refactored version is saved in ``test/multiagent/tic_tac_toe.py``, with the entry point in ``test/multiagent/test_tic_tac_toe.py``). In the latter case, you can train an agent by running + +.. code-block:: console + + $ python test_tic_tac_toe.py + +By default, the trained agent is stored in ``log/tic_tac_toe/dqn/policy.pth``. You can also make the trained agent play against itself by running + +.. code-block:: console + + $ python test_tic_tac_toe.py --watch --resume_path=log/tic_tac_toe/dqn/policy.pth --opponent_path=log/tic_tac_toe/dqn/policy.pth + +Here is our output: + +.. raw:: html + +
+ The trained agent play against itself + +:: + + board (step 1): + ================= + ===_ _ _ _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ X _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ _ _ _=== + ================= + board (step 2): + ================= + ===_ _ _ _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ x _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ O _ _ _=== + ================= + board (step 3): + ================= + ===_ _ _ _ _ _=== + ===_ _ X _ _ _=== + ===_ _ x _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ o _ _ _=== + ================= + board (step 4): + ================= + ===_ _ _ _ _ _=== + ===_ _ x _ _ _=== + ===_ _ x _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ o O _ _=== + ================= + board (step 5): + ================= + ===_ _ _ _ _ _=== + ===_ _ x _ _ _=== + ===_ _ x _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ X _ _=== + ===_ _ o o _ _=== + ================= + board (step 6): + ================= + ===_ _ _ _ _ _=== + ===_ _ x _ _ _=== + ===_ _ x _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ x _ _=== + ===_ _ o o O _=== + ================= + board (step 7): + ================= + ===_ _ _ _ _ _=== + ===_ _ x _ X _=== + ===_ _ x _ _ _=== + ===_ _ _ _ _ _=== + ===_ _ _ x _ _=== + ===_ _ o o o _=== + ================= + board (step 8): + ================= + ===_ _ _ _ _ _=== + ===_ _ x _ x _=== + ===_ _ x _ _ _=== + ===O _ _ _ _ _=== + ===_ _ _ x _ _=== + ===_ _ o o o _=== + ================= + board (step 9): + ================= + ===_ _ _ _ _ _=== + ===_ _ x _ x _=== + ===_ _ x _ _ _=== + ===o _ _ X _ _=== + ===_ _ _ x _ _=== + ===_ _ o o o _=== + ================= + board (step 10): + ================= + ===_ O _ _ _ _=== + ===_ _ x _ x _=== + ===_ _ x _ _ _=== + ===o _ _ x _ _=== + ===_ _ _ x _ _=== + ===_ _ o o o _=== + ================= + board (step 11): + ================= + ===_ o _ _ _ _=== + ===_ _ x _ x _=== + ===_ _ x _ _ X=== + ===o _ _ x _ _=== + ===_ _ _ x _ _=== + ===_ _ o o o _=== + 
================= + board (step 12): + ================= + ===_ o O _ _ _=== + ===_ _ x _ x _=== + ===_ _ x _ _ x=== + ===o _ _ x _ _=== + ===_ _ _ x _ _=== + ===_ _ o o o _=== + ================= + board (step 13): + ================= + ===_ o o _ _ _=== + ===_ _ x _ x _=== + ===_ _ x _ _ x=== + ===o _ _ x X _=== + ===_ _ _ x _ _=== + ===_ _ o o o _=== + ================= + board (step 14): + ================= + ===O o o _ _ _=== + ===_ _ x _ x _=== + ===_ _ x _ _ x=== + ===o _ _ x x _=== + ===_ _ _ x _ _=== + ===_ _ o o o _=== + ================= + board (step 15): + ================= + ===o o o _ _ _=== + ===_ _ x _ x _=== + ===_ _ x _ _ x=== + ===o _ _ x x _=== + ===X _ _ x _ _=== + ===_ _ o o o _=== + ================= + board (step 16): + ================= + ===o o o _ _ _=== + ===_ O x _ x _=== + ===_ _ x _ _ x=== + ===o _ _ x x _=== + ===x _ _ x _ _=== + ===_ _ o o o _=== + ================= + board (step 17): + ================= + ===o o o _ _ _=== + ===_ o x _ x _=== + ===_ _ x _ _ x=== + ===o _ _ x x _=== + ===x _ X x _ _=== + ===_ _ o o o _=== + ================= + board (step 18): + ================= + ===o o o _ _ _=== + ===_ o x _ x _=== + ===_ _ x _ _ x=== + ===o _ _ x x _=== + ===x _ x x _ _=== + ===_ O o o o _=== + ================= + +.. raw:: html + +

+ +Well, although the learned agent plays well against the random agent, it is still far from truly intelligent. + +Next, you could try to build more intelligent agents by letting the agent learn from self-play, just like AlphaZero! + +In this tutorial, we showed an example of how to use Tianshou for multi-agent RL. Tianshou is a flexible and easy-to-use RL library. Make the most of Tianshou yourself! diff --git a/examples/pong_dqn.py b/examples/pong_dqn.py index e99c04edc..d46d1af96 100644 --- a/examples/pong_dqn.py +++ b/examples/pong_dqn.py @@ -65,7 +65,6 @@ def test_dqn(args=get_args()): optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DQNPolicy( net, optim, args.gamma, args.n_step, - use_target_network=args.target_update_freq > 0, target_update_freq=args.target_update_freq) # collector train_collector = Collector( diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 96ddb70e2..0a3a5b067 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -62,7 +62,6 @@ def test_dqn(args=get_args()): optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DQNPolicy( net, optim, args.gamma, args.n_step, - use_target_network=args.target_update_freq > 0, target_update_freq=args.target_update_freq) # collector train_collector = Collector( diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 42ee53495..48573e736 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -63,7 +63,6 @@ def test_drqn(args=get_args()): optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DQNPolicy( net, optim, args.gamma, args.n_step, - use_target_network=args.target_update_freq > 0, target_update_freq=args.target_update_freq) # collector train_collector = Collector( diff --git a/test/discrete/test_pdqn.py b/test/discrete/test_pdqn.py index 22fa34764..b614f248a 100644 --- a/test/discrete/test_pdqn.py +++ b/test/discrete/test_pdqn.py @@ -65,7 +65,6 @@ def test_pdqn(args=get_args()): 
optim = 
torch.optim.Adam(net.parameters(), lr=args.lr) policy = DQNPolicy( net, optim, args.gamma, args.n_step, - use_target_network=args.target_update_freq > 0, target_update_freq=args.target_update_freq) # collector if args.prioritized_replay > 0: diff --git a/test/multiagent/Gomoku.py b/test/multiagent/Gomoku.py new file mode 100644 index 000000000..23793914d --- /dev/null +++ b/test/multiagent/Gomoku.py @@ -0,0 +1,84 @@ +import os +import pprint +import numpy as np +from copy import deepcopy +from torch.utils.tensorboard import SummaryWriter + +from tianshou.env import VectorEnv +from tianshou.data import Collector +from tianshou.policy import RandomPolicy + +from tic_tac_toe_env import TicTacToeEnv +from tic_tac_toe import get_parser, get_agents, train_agent, watch + + +def get_args(): + parser = get_parser() + parser.add_argument('--self_play_round', type=int, default=20) + args = parser.parse_known_args()[0] + return args + + +def gomoku(args=get_args()): + Collector._default_rew_metric = lambda x: x[args.agent_id - 1] + if args.watch: + watch(args) + return + + policy, optim = get_agents(args) + agent_learn = policy.policies[args.agent_id - 1] + agent_opponent = policy.policies[2 - args.agent_id] + + # log + log_path = os.path.join(args.logdir, 'Gomoku', 'dqn') + args.writer = SummaryWriter(log_path) + + opponent_pool = [agent_opponent] + + def env_func(): + return TicTacToeEnv(args.board_size, args.win_size) + test_envs = VectorEnv([env_func for _ in range(args.test_num)]) + for r in range(args.self_play_round): + rews = [] + agent_learn.set_eps(0.0) + # compute the reward over previous learner + for opponent in opponent_pool: + policy.replace_policy(opponent, 3 - args.agent_id) + test_collector = Collector(policy, test_envs) + results = test_collector.collect(n_episode=100) + rews.append(results['rew']) + rews = np.array(rews) + # weight opponent by their difficulty level + rews = np.exp(-rews * 10.0) + rews /= np.sum(rews) + total_epoch = args.epoch + args.epoch 
= 1 + for epoch in range(total_epoch): + # sample one opponent + opp_id = np.random.choice(len(opponent_pool), size=1, p=rews) + print(f'selection probability {rews.tolist()}') + print(f'selected opponent {opp_id}') + opponent = opponent_pool[opp_id.item(0)] + agent = RandomPolicy() + # previous learner can only be used for forward + agent.forward = opponent.forward + args.model_save_path = os.path.join( + args.logdir, 'Gomoku', 'dqn', + f'policy_round_{r}_epoch_{epoch}.pth') + result, agent_learn = train_agent( + args, agent_learn=agent_learn, + agent_opponent=agent, optim=optim) + print(f'round_{r}_epoch_{epoch}') + pprint.pprint(result) + learnt_agent = deepcopy(agent_learn) + learnt_agent.set_eps(0.0) + opponent_pool.append(learnt_agent) + args.epoch = total_epoch + if __name__ == '__main__': + # Let's watch its performance! + opponent = opponent_pool[-2] + watch(args, agent_learn, opponent) + + +if __name__ == '__main__': + gomoku(get_args()) diff --git a/test/multiagent/test_tic_tac_toe.py b/test/multiagent/test_tic_tac_toe.py new file mode 100644 index 000000000..92ecb97c6 --- /dev/null +++ b/test/multiagent/test_tic_tac_toe.py @@ -0,0 +1,22 @@ +import pprint +from tianshou.data import Collector +from tic_tac_toe import get_args, train_agent, watch + + +def test_tic_tac_toe(args=get_args()): + Collector._default_rew_metric = lambda x: x[args.agent_id - 1] + if args.watch: + watch(args) + return + + result, agent = train_agent(args) + assert result["best_reward"] >= args.win_rate + + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! 
+ watch(args, agent) + + +if __name__ == '__main__': + test_tic_tac_toe(get_args()) diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py new file mode 100644 index 000000000..5b718ce6b --- /dev/null +++ b/test/multiagent/tic_tac_toe.py @@ -0,0 +1,178 @@ +import os +import torch +import argparse +import numpy as np +from copy import deepcopy +from typing import Optional, Tuple +from torch.utils.tensorboard import SummaryWriter + +from tianshou.env import VectorEnv +from tianshou.utils.net.common import Net +from tianshou.trainer import offpolicy_trainer +from tianshou.data import Collector, ReplayBuffer +from tianshou.policy import BasePolicy, DQNPolicy, RandomPolicy, \ + MultiAgentPolicyManager + +from tic_tac_toe_env import TicTacToeEnv + + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--eps-test', type=float, default=0.05) + parser.add_argument('--eps-train', type=float, default=0.1) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--gamma', type=float, default=0.1, + help='a smaller gamma favors earlier win') + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--target-update-freq', type=int, default=320) + parser.add_argument('--epoch', type=int, default=20) + parser.add_argument('--step-per-epoch', type=int, default=500) + parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--layer-num', type=int, default=3) + parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.1) + parser.add_argument('--board_size', type=int, default=6) + 
parser.add_argument('--win_size', type=int, default=4) + parser.add_argument('--win-rate', type=float, default=0.8, + help='the expected winning rate') + parser.add_argument('--watch', default=False, action='store_true', + help='no training, ' + 'watch the play of pre-trained models') + parser.add_argument('--agent_id', type=int, default=2, + help='the learned agent plays as the' + ' agent_id-th player. choices are 1 and 2.') + parser.add_argument('--resume_path', type=str, default='', + help='the path of agent pth file ' + 'for resuming from a pre-trained agent') + parser.add_argument('--opponent_path', type=str, default='', + help='the path of opponent agent pth file ' + 'for resuming from a pre-trained agent') + parser.add_argument( + '--device', type=str, + default='cuda' if torch.cuda.is_available() else 'cpu') + return parser + + +def get_args() -> argparse.Namespace: + parser = get_parser() + args = parser.parse_known_args()[0] + return args + + +def get_agents(args: argparse.Namespace = get_args(), + agent_learn: Optional[BasePolicy] = None, + agent_opponent: Optional[BasePolicy] = None, + optim: Optional[torch.optim.Optimizer] = None, + ) -> Tuple[BasePolicy, torch.optim.Optimizer]: + env = TicTacToeEnv(args.board_size, args.win_size) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + if agent_learn is None: + # model + net = Net(args.layer_num, args.state_shape, args.action_shape, + args.device).to(args.device) + if optim is None: + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + agent_learn = DQNPolicy( + net, optim, args.gamma, args.n_step, + target_update_freq=args.target_update_freq) + if args.resume_path: + agent_learn.load_state_dict(torch.load(args.resume_path)) + + if agent_opponent is None: + if args.opponent_path: + agent_opponent = deepcopy(agent_learn) + agent_opponent.load_state_dict(torch.load(args.opponent_path)) + else: + agent_opponent = 
RandomPolicy() + + if args.agent_id == 1: + agents = [agent_learn, agent_opponent] + else: + agents = [agent_opponent, agent_learn] + policy = MultiAgentPolicyManager(agents) + return policy, optim + + +def train_agent(args: argparse.Namespace = get_args(), + agent_learn: Optional[BasePolicy] = None, + agent_opponent: Optional[BasePolicy] = None, + optim: Optional[torch.optim.Optimizer] = None, + ) -> Tuple[dict, BasePolicy]: + def env_func(): + return TicTacToeEnv(args.board_size, args.win_size) + train_envs = VectorEnv([env_func for _ in range(args.training_num)]) + test_envs = VectorEnv([env_func for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + + policy, optim = get_agents( + args, agent_learn=agent_learn, + agent_opponent=agent_opponent, optim=optim) + + # collector + train_collector = Collector( + policy, train_envs, ReplayBuffer(args.buffer_size)) + test_collector = Collector(policy, test_envs) + # policy.set_eps(1) + train_collector.collect(n_step=args.batch_size) + # log + if not hasattr(args, 'writer'): + log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn') + writer = SummaryWriter(log_path) + args.writer = writer + else: + writer = args.writer + + def save_fn(policy): + if hasattr(args, 'model_save_path'): + model_save_path = args.model_save_path + else: + model_save_path = os.path.join( + args.logdir, 'tic_tac_toe', 'dqn', 'policy.pth') + torch.save( + policy.policies[args.agent_id - 1].state_dict(), + model_save_path) + + def stop_fn(x): + return x >= args.win_rate + + def train_fn(x): + policy.policies[args.agent_id - 1].set_eps(args.eps_train) + + def test_fn(x): + policy.policies[args.agent_id - 1].set_eps(args.eps_test) + + # trainer + result = offpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.collect_per_step, args.test_num, + args.batch_size, train_fn=train_fn, test_fn=test_fn, + 
stop_fn=stop_fn, save_fn=save_fn, writer=writer, + test_in_train=False) + + train_collector.close() + test_collector.close() + + return result, policy.policies[args.agent_id - 1] + + +def watch(args: argparse.Namespace = get_args(), + agent_learn: Optional[BasePolicy] = None, + agent_opponent: Optional[BasePolicy] = None, + ) -> None: + env = TicTacToeEnv(args.board_size, args.win_size) + policy, optim = get_agents( + args, agent_learn=agent_learn, agent_opponent=agent_opponent) + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + print(f'Final reward: {result["rew"]}, length: {result["len"]}') + collector.close() diff --git a/test/multiagent/tic_tac_toe_env.py b/test/multiagent/tic_tac_toe_env.py new file mode 100644 index 000000000..2fc045afa --- /dev/null +++ b/test/multiagent/tic_tac_toe_env.py @@ -0,0 +1,136 @@ +import gym +import numpy as np +from functools import partial +from typing import Tuple, Optional + +from tianshou.env import MultiAgentEnv + + +class TicTacToeEnv(MultiAgentEnv): + """This is a simple implementation of the Tic-Tac-Toe game, where two + agents play against each other. + + The implementation is intended to show how to wrap an environment to + satisfy the interface of :class:`~tianshou.env.MultiAgentEnv`. 
+ + :param size: the size of the board (square board) + :param win_size: how many units in a row is considered to win + """ + + def __init__(self, size: int = 3, win_size: int = 3): + super().__init__() + assert size > 0, f'board size should be positive, but got {size}' + self.size = size + assert win_size > 0, f'win-size should be positive, but got {win_size}' + self.win_size = win_size + assert win_size <= size, f'win-size {win_size} should not ' \ + f'be larger than board size {size}' + self.convolve_kernel = np.ones(win_size) + self.observation_space = gym.spaces.Box( + low=-1.0, high=1.0, shape=(size, size), dtype=np.float32) + self.action_space = gym.spaces.Discrete(size * size) + self.current_board = None + self.current_agent = None + self._last_move = None + self.step_num = None + + def reset(self) -> dict: + self.current_board = np.zeros((self.size, self.size), dtype=np.int32) + self.current_agent = 1 + self._last_move = (-1, -1) + self.step_num = 0 + return { + 'agent_id': self.current_agent, + 'obs': np.array(self.current_board), + 'mask': self.current_board.flatten() == 0 + } + + def step(self, action: [int, np.ndarray] + ) -> Tuple[dict, np.ndarray, np.ndarray, dict]: + if self.current_agent is None: + raise ValueError( + "calling step() of unreset environment is prohibited!") + assert 0 <= action < self.size * self.size + assert self.current_board.item(action) == 0 + _current_agent = self.current_agent + self._move(action) + mask = self.current_board.flatten() == 0 + is_win, is_opponent_win = False, False + is_win = self._test_win() + # the game is over when one wins or there is only one empty place + done = is_win + if sum(mask) == 1: + done = True + self._move(np.where(mask)[0][0]) + is_opponent_win = self._test_win() + if is_win: + reward = 1 + elif is_opponent_win: + reward = -1 + else: + reward = 0 + obs = { + 'agent_id': self.current_agent, + 'obs': np.array(self.current_board), + 'mask': mask + } + rew_agent_1 = reward if _current_agent == 1 
else (-reward) + rew_agent_2 = reward if _current_agent == 2 else (-reward) + vec_rew = np.array([rew_agent_1, rew_agent_2], dtype=np.float32) + if done: + self.current_agent = None + return obs, vec_rew, np.array(done), {} + + def _move(self, action): + row, col = action // self.size, action % self.size + if self.current_agent == 1: + self.current_board[row, col] = 1 + else: + self.current_board[row, col] = -1 + self.current_agent = 3 - self.current_agent + self._last_move = (row, col) + self.step_num += 1 + + def _test_win(self): + """test if someone wins by checking the situation around last move""" + row, col = self._last_move + rboard = self.current_board[row, :] + cboard = self.current_board[:, col] + current = self.current_board[row, col] + rightup = [self.current_board[row - i, col + i] + for i in range(1, self.size - col) if row - i >= 0] + leftdown = [self.current_board[row + i, col - i] + for i in range(1, col + 1) if row + i < self.size] + rdiag = np.array(leftdown[::-1] + [current] + rightup) + rightdown = [self.current_board[row + i, col + i] + for i in range(1, self.size - col) if row + i < self.size] + leftup = [self.current_board[row - i, col - i] + for i in range(1, col + 1) if row - i >= 0] + diag = np.array(leftup[::-1] + [current] + rightdown) + results = [np.convolve(k, self.convolve_kernel, mode='valid') + for k in (rboard, cboard, rdiag, diag)] + return any([(np.abs(x) == self.win_size).any() for x in results]) + + def seed(self, seed: Optional[int] = None) -> int: + pass + + def render(self, **kwargs) -> None: + print(f'board (step {self.step_num}):') + pad = '===' + top = pad + '=' * (2 * self.size - 1) + pad + print(top) + + def f(i, data): + j, number = data + last_move = i == self._last_move[0] and j == self._last_move[1] + if number == 1: + return 'X' if last_move else 'x' + if number == -1: + return 'O' if last_move else 'o' + return '_' + for i, row in enumerate(self.current_board): + print(pad + ' '.join(map(partial(f, i), 
enumerate(row))) + pad) + print(top) + + def close(self) -> None: + pass diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 7e602bba8..f147ca326 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -548,6 +548,7 @@ def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': (2, 4, 5) .. note:: + If there are keys that are not shared across all batches, ``stack`` with ``axis != 0`` is undefined, and will cause an exception. """ diff --git a/tianshou/env/__init__.py b/tianshou/env/__init__.py index 98a09a4dc..43bca7fdc 100644 --- a/tianshou/env/__init__.py +++ b/tianshou/env/__init__.py @@ -1,9 +1,11 @@ -from tianshou.env.vecenv import BaseVectorEnv, VectorEnv, \ - SubprocVectorEnv, RayVectorEnv +from tianshou.env.basevecenv import BaseVectorEnv +from tianshou.env.vecenv import VectorEnv, SubprocVectorEnv, RayVectorEnv +from tianshou.env.maenv import MultiAgentEnv __all__ = [ 'BaseVectorEnv', 'VectorEnv', 'SubprocVectorEnv', 'RayVectorEnv', + 'MultiAgentEnv', ] diff --git a/tianshou/env/basevecenv.py b/tianshou/env/basevecenv.py new file mode 100644 index 000000000..60394e3de --- /dev/null +++ b/tianshou/env/basevecenv.py @@ -0,0 +1,116 @@ +import gym +import numpy as np +from abc import ABC, abstractmethod +from typing import List, Tuple, Union, Optional, Callable + + +class BaseVectorEnv(ABC, gym.Env): + """Base class for vectorized environments wrapper. Usage: + :: + + env_num = 8 + envs = VectorEnv([lambda: gym.make(task) for _ in range(env_num)]) + assert len(envs) == env_num + + It accepts a list of environment generators. In other words, an environment + generator ``efn`` of a specific task means that ``efn()`` returns the + environment of the given task, for example, ``gym.make(task)``. + + All of the VectorEnv must inherit :class:`~tianshou.env.BaseVectorEnv`. 
+ Here are some other usages: + :: + + envs.seed(2) # which is equal to the next line + envs.seed([2, 3, 4, 5, 6, 7, 8, 9]) # set specific seed for each env + obs = envs.reset() # reset all environments + obs = envs.reset([0, 5, 7]) # reset 3 specific environments + obs, rew, done, info = envs.step([1] * 8) # step synchronously + envs.render() # render all environments + envs.close() # close all environments + """ + + def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None: + self._env_fns = env_fns + self.env_num = len(env_fns) + + def __len__(self) -> int: + """Return len(self), which is the number of environments.""" + return self.env_num + + def __getattribute__(self, key: str): + """Switch between the default attribute getter or one + looking at wrapped environment level depending on the key.""" + if key not in ('observation_space', 'action_space'): + return super().__getattribute__(key) + else: + return self.__getattr__(key) + + @abstractmethod + def __getattr__(self, key: str): + """Try to retrieve an attribute from each individual wrapped + environment, if it does not belong to the wrapping vector + environment class.""" + pass + + @abstractmethod + def reset(self, id: Optional[Union[int, List[int]]] = None): + """Reset the state of all the environments and return initial + observations if id is ``None``, otherwise reset the specific + environments with given id, either an int or a list. + """ + pass + + @abstractmethod + def step(self, + action: np.ndarray, + id: Optional[Union[int, List[int]]] = None + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Run one timestep of all the environments’ dynamics if id is + ``None``, otherwise run one timestep for some environments + with given id, either an int or a list. When the end of + episode is reached, you are responsible for calling reset(id) + to reset this environment’s state. + + Accept a batch of action and return a tuple (obs, rew, done, info). 
+ + :param numpy.ndarray action: a batch of action provided by the agent. + + :return: A tuple including four items: + + * ``obs`` a numpy.ndarray, the agent's observation of current \ + environments + * ``rew`` a numpy.ndarray, the amount of rewards returned after \ + previous actions + * ``done`` a numpy.ndarray, whether these episodes have ended, in \ + which case further step() calls will return undefined results + * ``info`` a numpy.ndarray, contains auxiliary diagnostic \ + information (helpful for debugging, and sometimes learning) + """ + pass + + @abstractmethod + def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]: + """Set the seed for all environments. + + Accept ``None``, an int (which will extend ``i`` to + ``[i, i + 1, i + 2, ...]``) or a list. + + :return: The list of seeds used in this env's random number \ + generators. The first value in the list should be the "main" seed, or \ + the value which a reproducer pass to "seed". + """ + pass + + @abstractmethod + def render(self, **kwargs) -> None: + """Render all of the environments.""" + pass + + @abstractmethod + def close(self) -> None: + """Close all of the environments. + + Environments will automatically close() themselves when garbage + collected or when the program exits. + """ + pass diff --git a/tianshou/env/maenv.py b/tianshou/env/maenv.py new file mode 100644 index 000000000..ea9284ea7 --- /dev/null +++ b/tianshou/env/maenv.py @@ -0,0 +1,59 @@ +import gym +import numpy as np +from typing import Tuple +from abc import ABC, abstractmethod + + +class MultiAgentEnv(ABC, gym.Env): + """The interface for multi-agent environments. Multi-agent environments + must be wrapped as :class:`~tianshou.env.MultiAgentEnv`. Here is the usage: + :: + + env = MultiAgentEnv(...) 
+ # obs is a dict containing obs, agent_id, and mask + obs = env.reset() + action = policy(obs) + obs, rew, done, info = env.step(action) + env.close() + + The available action's mask is set to 1, otherwise it is set to 0. Further + usage can be found at :ref:`marl_example`. + """ + + def __init__(self, **kwargs) -> None: + pass + + @abstractmethod + def reset(self) -> dict: + """Reset the state. Return the initial state, first agent_id, and the + initial action set, for example, + ``{'obs': obs, 'agent_id': agent_id, 'mask': mask}`` + """ + pass + + @abstractmethod + def step(self, action: np.ndarray + ) -> Tuple[dict, np.ndarray, np.ndarray, np.ndarray]: + """Run one timestep of the environment’s dynamics. When the end of + episode is reached, you are responsible for calling reset() to reset + the environment’s state. + + Accept action and return a tuple (obs, rew, done, info). + + :param numpy.ndarray action: action provided by a agent. + + :return: A tuple including four items: + + * ``obs`` a dict containing obs, agent_id, and mask, which means \ + that it is the ``agent_id`` player's turn to play with ``obs``\ + observation and ``mask``. + * ``rew`` a numpy.ndarray, the amount of rewards returned after \ + previous actions. Depending on the specific environment, this \ + can be either a scalar reward for current agent or a vector \ + reward for all the agents. 
+ * ``done`` a numpy.ndarray, whether the episode has ended, in \ + which case further step() calls will return undefined results + * ``info`` a numpy.ndarray, contains auxiliary diagnostic \ + information (helpful for debugging, and sometimes learning) + """ + pass diff --git a/tianshou/env/vecenv.py b/tianshou/env/vecenv.py index 93c388209..2f07ebb06 100644 --- a/tianshou/env/vecenv.py +++ b/tianshou/env/vecenv.py @@ -1,6 +1,5 @@ import gym import numpy as np -from abc import ABC, abstractmethod from multiprocessing import Process, Pipe from typing import List, Tuple, Union, Optional, Callable, Any @@ -9,121 +8,10 @@ except ImportError: pass +from tianshou.env import BaseVectorEnv from tianshou.env.utils import CloudpickleWrapper -class BaseVectorEnv(ABC, gym.Env): - """Base class for vectorized environments wrapper. Usage: - :: - - env_num = 8 - envs = VectorEnv([lambda: gym.make(task) for _ in range(env_num)]) - assert len(envs) == env_num - - It accepts a list of environment generators. In other words, an environment - generator ``efn`` of a specific task means that ``efn()`` returns the - environment of the given task, for example, ``gym.make(task)``. - - All of the VectorEnv must inherit :class:`~tianshou.env.BaseVectorEnv`. 
- Here are some other usages: - :: - - envs.seed(2) # which is equal to the next line - envs.seed([2, 3, 4, 5, 6, 7, 8, 9]) # set specific seed for each env - obs = envs.reset() # reset all environments - obs = envs.reset([0, 5, 7]) # reset 3 specific environments - obs, rew, done, info = envs.step([1] * 8) # step synchronously - envs.render() # render all environments - envs.close() # close all environments - """ - - def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None: - self._env_fns = env_fns - self.env_num = len(env_fns) - - def __len__(self) -> int: - """Return len(self), which is the number of environments.""" - return self.env_num - - def __getattribute__(self, key: str): - """Switch between the default attribute getter or one - looking at wrapped environment level depending on the key.""" - if key not in ('observation_space', 'action_space'): - return super().__getattribute__(key) - else: - return self.__getattr__(key) - - @abstractmethod - def __getattr__(self, key: str): - """Try to retrieve an attribute from each individual wrapped - environment, if it does not belong to the wrapping vector - environment class.""" - pass - - @abstractmethod - def reset(self, id: Optional[Union[int, List[int]]] = None): - """Reset the state of all the environments and return initial - observations if id is ``None``, otherwise reset the specific - environments with given id, either an int or a list. - """ - pass - - @abstractmethod - def step(self, - action: np.ndarray, - id: Optional[Union[int, List[int]]] = None - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """Run one timestep of all the environments’ dynamics if id is - ``None``, otherwise run one timestep for some environments - with given id, either an int or a list. When the end of - episode is reached, you are responsible for calling reset(id) - to reset this environment’s state. - - Accept a batch of action and return a tuple (obs, rew, done, info). 
- - :param numpy.ndarray action: a batch of action provided by the agent. - - :return: A tuple including four items: - - * ``obs`` a numpy.ndarray, the agent's observation of current \ - environments - * ``rew`` a numpy.ndarray, the amount of rewards returned after \ - previous actions - * ``done`` a numpy.ndarray, whether these episodes have ended, in \ - which case further step() calls will return undefined results - * ``info`` a numpy.ndarray, contains auxiliary diagnostic \ - information (helpful for debugging, and sometimes learning) - """ - pass - - @abstractmethod - def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]: - """Set the seed for all environments. - - Accept ``None``, an int (which will extend ``i`` to - ``[i, i + 1, i + 2, ...]``) or a list. - - :return: The list of seeds used in this env's random number \ - generators. The first value in the list should be the "main" seed, or \ - the value which a reproducer pass to "seed". - """ - pass - - @abstractmethod - def render(self, **kwargs) -> None: - """Render all of the environments.""" - pass - - @abstractmethod - def close(self) -> None: - """Close all of the environments. - - Environments will automatically close() themselves when garbage - collected or when the program exits. - """ - pass - - class VectorEnv(BaseVectorEnv): """Dummy vectorized environment wrapper, implemented in for-loop. 
diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 37f11e92b..95b7f0eeb 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -1,4 +1,5 @@ from tianshou.policy.base import BasePolicy +from tianshou.policy.random import RandomPolicy from tianshou.policy.imitation.base import ImitationPolicy from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.modelfree.pg import PGPolicy @@ -7,9 +8,12 @@ from tianshou.policy.modelfree.ppo import PPOPolicy from tianshou.policy.modelfree.td3 import TD3Policy from tianshou.policy.modelfree.sac import SACPolicy +from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager + __all__ = [ 'BasePolicy', + 'RandomPolicy', 'ImitationPolicy', 'DQNPolicy', 'PGPolicy', @@ -18,4 +22,5 @@ 'PPOPolicy', 'TD3Policy', 'SACPolicy', + 'MultiAgentPolicyManager', ] diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index e75374228..c9763621c 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -53,6 +53,11 @@ def __init__(self, **kwargs) -> None: super().__init__() self.observation_space = kwargs.get('observation_space') self.action_space = kwargs.get('action_space') + self.agent_id = 0 + + def set_agent_id(self, agent_id: int) -> None: + """set self.agent_id = agent_id, for MARL.""" + self.agent_id = agent_id def process_fn(self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray) -> Batch: diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index eb6f29878..214a453ee 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -85,17 +85,8 @@ def _target_q(self, buffer: ReplayBuffer, def process_fn(self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray) -> Batch: - r"""Compute the n-step return for Q-learning targets: - - .. 
math:: - G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i + - \gamma^n (1 - d_{t + n}) \max_a Q_{old}(s_{t + n}, \arg\max_a - (Q_{new}(s_{t + n}, a))) - - , where :math:`\gamma` is the discount factor, - :math:`\gamma \in [0, 1]`, :math:`d_t` is the done flag of step - :math:`t`. If there is no target network, the :math:`Q_{old}` is equal - to :math:`Q_{new}`. + """Compute the n-step return for Q-learning targets. More details can + be found at :meth:`~tianshou.policy.BasePolicy.compute_nstep_return`. """ batch = self.compute_nstep_return( batch, buffer, indice, self._target_q, @@ -111,7 +102,20 @@ def forward(self, batch: Batch, input: str = 'obs', eps: Optional[float] = None, **kwargs) -> Batch: - """Compute action over the given batch data. + """Compute action over the given batch data. If you need to mask the + action, please add a "mask" into batch.obs, for example, if we have an + environment that has "0/1/2" three actions: + :: + + batch == Batch( + obs=Batch( + obs="original obs, with batch_size=1 for demonstration", + mask=np.array([[False, True, False]]), + # action 1 is available + # action 0 and 2 are unavailable + ), + ... + ) :param float eps: in [0, 1], for epsilon-greedy exploration method. 
@@ -128,15 +132,25 @@ def forward(self, batch: Batch, """ model = getattr(self, model) obs = getattr(batch, input) - q, h = model(obs, state=state, info=batch.info) + obs_ = obs.obs if hasattr(obs, 'obs') else obs + q, h = model(obs_, state=state, info=batch.info) act = to_numpy(q.max(dim=1)[1]) + has_mask = hasattr(obs, 'mask') + if has_mask: + # some of actions are masked, they cannot be selected + q_ = to_numpy(q) + q_[~obs.mask] = -np.inf + act = q_.argmax(axis=1) # add eps to act if eps is None: eps = self.eps if not np.isclose(eps, 0): for i in range(len(q)): if np.random.rand() < eps: - act[i] = np.random.randint(q.shape[1]) + q_ = np.random.rand(*q[i].shape) + if has_mask: + q_[~obs.mask[i]] = -np.inf + act[i] = q_.argmax() return Batch(logits=q, act=act, state=h) def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: diff --git a/tianshou/policy/multiagent/__init__.py b/tianshou/policy/multiagent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py new file mode 100644 index 000000000..f6329888d --- /dev/null +++ b/tianshou/policy/multiagent/mapolicy.py @@ -0,0 +1,140 @@ +import numpy as np +from typing import Union, Optional, Dict, List + +from tianshou.policy import BasePolicy +from tianshou.data import Batch, ReplayBuffer + + +class MultiAgentPolicyManager(BasePolicy): + """This multi-agent policy manager accepts a list of + :class:`~tianshou.policy.BasePolicy`. It dispatches the batch data to each + of these policies when the "forward" is called. The same as "process_fn" + and "learn": it splits the data and feeds them to each policy. A figure in + :ref:`marl_example` can help you better understand this procedure. 
+ """ + + def __init__(self, policies: List[BasePolicy]): + super().__init__() + self.policies = policies + for i, policy in enumerate(policies): + # agent_id 0 is reserved for the environment proxy + # (this MultiAgentPolicyManager) + policy.set_agent_id(i + 1) + + def replace_policy(self, policy, agent_id): + """Replace the "agent_id"th policy in this manager.""" + self.policies[agent_id - 1] = policy + policy.set_agent_id(agent_id) + + def process_fn(self, batch: Batch, buffer: ReplayBuffer, + indice: np.ndarray) -> Batch: + """Save original multi-dimensional rew in "save_rew", set rew to the + reward of each agent during their ``process_fn``, and restore the + original reward afterwards. + """ + results = {} + # reward can be empty Batch (after initial reset) or nparray. + has_rew = isinstance(buffer.rew, np.ndarray) + if has_rew: # save the original reward in save_rew + save_rew, buffer.rew = buffer.rew, Batch() + for policy in self.policies: + agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0] + if len(agent_index) == 0: + results[f'agent_{policy.agent_id}'] = Batch() + continue + tmp_batch, tmp_indice = batch[agent_index], indice[agent_index] + if has_rew: + tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1] + buffer.rew = save_rew[:, policy.agent_id - 1] + results[f'agent_{policy.agent_id}'] = \ + policy.process_fn(tmp_batch, buffer, tmp_indice) + if has_rew: # restore from save_rew + buffer.rew = save_rew + return Batch(results) + + def forward(self, batch: Batch, + state: Optional[Union[dict, Batch]] = None, + **kwargs) -> Batch: + """:param state: if None, it means all agents have no state. If not + None, it should contain keys of "agent_1", "agent_2", ... + + :return: a Batch with the following contents: + + :: + + { + "act": actions corresponding to the input + "state":{ + "agent_1": output state of agent_1's policy for the state + "agent_2": xxx + ... 
+ "agent_n": xxx} + "out":{ + "agent_1": output of agent_1's policy for the input + "agent_2": xxx + ... + "agent_n": xxx} + } + """ + results = [] + for policy in self.policies: + # This part of code is difficult to understand. + # Let's follow an example with two agents + # batch.obs.agent_id is [1, 2, 1, 2, 1, 2] (with batch_size == 6) + # each agent plays for three transitions + # agent_index for agent 1 is [0, 2, 4] + # agent_index for agent 2 is [1, 3, 5] + # we separate the transition of each agent according to agent_id + agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0] + if len(agent_index) == 0: + # (has_data, agent_index, out, act, state) + results.append((False, None, Batch(), None, Batch())) + continue + tmp_batch = batch[agent_index] + if isinstance(tmp_batch.rew, np.ndarray): + # reward can be empty Batch (after initial reset) or nparray. + tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1] + out = policy(batch=tmp_batch, state=None if state is None + else state["agent_" + str(policy.agent_id)], + **kwargs) + act = out.act + each_state = out.state \ + if (hasattr(out, 'state') and out.state is not None) \ + else Batch() + results.append((True, agent_index, out, act, each_state)) + holder = Batch.cat([{'act': act} for + (has_data, agent_index, out, act, each_state) + in results if has_data]) + state_dict, out_dict = {}, {} + for policy, (has_data, agent_index, out, act, state) in \ + zip(self.policies, results): + if has_data: + holder.act[agent_index] = act + state_dict["agent_" + str(policy.agent_id)] = state + out_dict["agent_" + str(policy.agent_id)] = out + holder["out"] = out_dict + holder["state"] = state_dict + return holder + + def learn(self, batch: Batch, **kwargs + ) -> Dict[str, Union[float, List[float]]]: + """:return: a dict with the following contents: + + :: + + { + "agent_1/item1": item 1 of agent_1's policy.learn output + "agent_1/item2": item 2 of agent_1's policy.learn output + "agent_2/xxx": xxx + ... 
+ "agent_n/xxx": xxx + } + """ + results = {} + for policy in self.policies: + data = batch[f'agent_{policy.agent_id}'] + if not data.is_empty(): + out = policy.learn(batch=data, **kwargs) + for k, v in out.items(): + results["agent_" + str(policy.agent_id) + '/' + k] = v + return results diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py new file mode 100644 index 000000000..a300e8c92 --- /dev/null +++ b/tianshou/policy/random.py @@ -0,0 +1,40 @@ +import numpy as np +from typing import Union, Optional, Dict, List + +from tianshou.data import Batch +from tianshou.policy import BasePolicy + + +class RandomPolicy(BasePolicy): + """A random agent used in multi-agent learning. It randomly chooses an + action from the legal action. + """ + + def forward(self, batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + **kwargs) -> Batch: + """Compute the random action over the given batch data. The input + should contain a mask in batch.obs, with "True" to be available and + "False" to be unavailable. + For example, ``batch.obs.mask == np.array([[False, True, False]])`` + means with batch size 1, action "1" is available but action "0" and + "2" are unavailable. + + :return: A :class:`~tianshou.data.Batch` with "act" key, containing + the random action. + + .. seealso:: + + Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for + more detailed explanation. 
+ """ + mask = batch.obs.mask + logits = np.random.rand(*mask.shape) + logits[~mask] = -np.inf + return Batch(act=logits.argmax(axis=-1)) + + def learn(self, batch: Batch, **kwargs + ) -> Dict[str, Union[float, List[float]]]: + """No need of a learn function for a random agent, so it returns an + empty dict.""" + return {} diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 5cecd0570..edcf0fdb8 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -27,6 +27,7 @@ def offpolicy_trainer( writer: Optional[SummaryWriter] = None, log_interval: int = 1, verbose: bool = True, + test_in_train: bool = True, ) -> Dict[str, Union[float, str]]: """A wrapper for off-policy trainer procedure. @@ -65,6 +66,7 @@ def offpolicy_trainer( SummaryWriter. :param int log_interval: the log interval of the writer. :param bool verbose: whether to print the information. + :param bool test_in_train: whether to test in the training phase. :return: See :func:`~tianshou.trainer.gather_info`. """ @@ -72,7 +74,7 @@ def offpolicy_trainer( best_epoch, best_reward = -1, -1 stat = {} start_time = time.time() - test_in_train = train_collector.policy == policy + test_in_train = test_in_train and train_collector.policy == policy for epoch in range(1, 1 + max_epoch): # train policy.train() diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 5f7ae7694..b0d68ff2a 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -27,6 +27,7 @@ def onpolicy_trainer( writer: Optional[SummaryWriter] = None, log_interval: int = 1, verbose: bool = True, + test_in_train: bool = True, ) -> Dict[str, Union[float, str]]: """A wrapper for on-policy trainer procedure. @@ -66,6 +67,7 @@ def onpolicy_trainer( SummaryWriter. :param int log_interval: the log interval of the writer. :param bool verbose: whether to print the information. + :param bool test_in_train: whether to test in the training phase. 
:return: See :func:`~tianshou.trainer.gather_info`. """ @@ -73,7 +75,7 @@ def onpolicy_trainer( best_epoch, best_reward = -1, -1 stat = {} start_time = time.time() - test_in_train = train_collector.policy == policy + test_in_train = test_in_train and train_collector.policy == policy for epoch in range(1, 1 + max_epoch): # train policy.train()