diff --git a/README.md b/README.md index 7226d7c52..5037a7868 100644 --- a/README.md +++ b/README.md @@ -158,13 +158,14 @@ Currently, the overall code of Tianshou platform is less than 2500 lines. Most o ```python result = collector.collect(n_step=n) ``` - -If you have 3 environments in total and want to collect 1 episode in the first environment, 3 for the third environment: +If you have 3 environments in total and want to collect 4 episodes: ```python -result = collector.collect(n_episode=[1, 0, 3]) +result = collector.collect(n_episode=4) ``` +Collector will collect exactly 4 episodes without any bias of episode length despite we only have 3 parallel environments. + If you want to train the given policy with a sampled batch: ```python @@ -190,12 +191,13 @@ Define some hyper-parameters: ```python task = 'CartPole-v0' lr, epoch, batch_size = 1e-3, 10, 64 -train_num, test_num = 8, 100 +train_num, test_num = 10, 100 gamma, n_step, target_freq = 0.9, 3, 320 buffer_size = 20000 eps_train, eps_test = 0.1, 0.05 -step_per_epoch, collect_per_step = 1000, 10 +step_per_epoch, step_per_collect = 10000, 10 writer = SummaryWriter('log/dqn') # tensorboard is also supported! +logger = ts.utils.BasicLogger(writer) ``` Make environments: @@ -223,20 +225,20 @@ Setup policy and collectors: ```python policy = ts.policy.DQNPolicy(net, optim, gamma, n_step, target_update_freq=target_freq) -train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(buffer_size)) -test_collector = ts.data.Collector(policy, test_envs) +train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(buffer_size, train_num), exploration_noise=True) +test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True) # because DQN uses epsilon-greedy method ``` Let's train it: ```python result = ts.trainer.offpolicy_trainer( - policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step, - test_num, batch_size, + policy, train_collector, test_collector, epoch, step_per_epoch, step_per_collect, + test_num, batch_size, update_per_step=1 / step_per_collect, train_fn=lambda epoch, env_step: policy.set_eps(eps_train), test_fn=lambda epoch, env_step: policy.set_eps(eps_test), stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, - writer=writer) + logger=logger) print(f'Finished training! Use {result["duration"]}') ``` @@ -252,7 +254,7 @@ Watch the performance with 35 FPS: ```python policy.eval() policy.set_eps(eps_test) -collector = ts.data.Collector(policy, env) +collector = ts.data.Collector(policy, env, exploration_noise=True) collector.collect(n_episode=1, render=1 / 35) ``` diff --git a/docs/api/tianshou.data.rst b/docs/api/tianshou.data.rst index fa1d5c738..eea262a76 100644 --- a/docs/api/tianshou.data.rst +++ b/docs/api/tianshou.data.rst @@ -1,7 +1,90 @@ tianshou.data ============= -.. automodule:: tianshou.data + +Batch +----- + +.. autoclass:: tianshou.data.Batch + :members: + :undoc-members: + :show-inheritance: + + +Buffer +------ + +ReplayBuffer +~~~~~~~~~~~~ + +.. autoclass:: tianshou.data.ReplayBuffer + :members: + :undoc-members: + :show-inheritance: + +PrioritizedReplayBuffer +~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: tianshou.data.PrioritizedReplayBuffer + :members: + :undoc-members: + :show-inheritance: + +ReplayBufferManager +~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: tianshou.data.ReplayBufferManager + :members: + :undoc-members: + :show-inheritance: + +PrioritizedReplayBufferManager +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: tianshou.data.PrioritizedReplayBufferManager + :members: + :undoc-members: + :show-inheritance: + +VectorReplayBuffer +~~~~~~~~~~~~~~~~~~ + +.. autoclass:: tianshou.data.VectorReplayBuffer + :members: + :undoc-members: + :show-inheritance: + +PrioritizedVectorReplayBuffer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: tianshou.data.PrioritizedVectorReplayBuffer + :members: + :undoc-members: + :show-inheritance: + +CachedReplayBuffer +~~~~~~~~~~~~~~~~~~ + +.. autoclass:: tianshou.data.CachedReplayBuffer + :members: + :undoc-members: + :show-inheritance: + +Collector +--------- + +Collector +~~~~~~~~~ + +.. autoclass:: tianshou.data.Collector + :members: + :undoc-members: + :show-inheritance: + +AsyncCollector +~~~~~~~~~~~~~~ + +.. autoclass:: tianshou.data.AsyncCollector :members: :undoc-members: :show-inheritance: diff --git a/docs/api/tianshou.env.rst b/docs/api/tianshou.env.rst index 7201bae46..04848a778 100644 --- a/docs/api/tianshou.env.rst +++ b/docs/api/tianshou.env.rst @@ -1,12 +1,82 @@ tianshou.env ============ -.. automodule:: tianshou.env + +VectorEnv +--------- + +BaseVectorEnv +~~~~~~~~~~~~~ + +.. autoclass:: tianshou.env.BaseVectorEnv :members: :undoc-members: :show-inheritance: -.. automodule:: tianshou.env.worker +DummyVectorEnv +~~~~~~~~~~~~~~ + +.. autoclass:: tianshou.env.DummyVectorEnv + :members: + :undoc-members: + :show-inheritance: + +SubprocVectorEnv +~~~~~~~~~~~~~~~~ + +.. autoclass:: tianshou.env.SubprocVectorEnv + :members: + :undoc-members: + :show-inheritance: + +ShmemVectorEnv +~~~~~~~~~~~~~~ + +.. autoclass:: tianshou.env.ShmemVectorEnv + :members: + :undoc-members: + :show-inheritance: + +RayVectorEnv +~~~~~~~~~~~~ + +.. autoclass:: tianshou.env.RayVectorEnv + :members: + :undoc-members: + :show-inheritance: + + +Worker +------ + +EnvWorker +~~~~~~~~~ + +.. autoclass:: tianshou.env.worker.EnvWorker + :members: + :undoc-members: + :show-inheritance: + +DummyEnvWorker +~~~~~~~~~~~~~~ + +.. autoclass:: tianshou.env.worker.DummyEnvWorker + :members: + :undoc-members: + :show-inheritance: + +SubprocEnvWorker +~~~~~~~~~~~~~~~~ + +.. autoclass:: tianshou.env.worker.SubprocEnvWorker + :members: + :undoc-members: + :show-inheritance: + +RayEnvWorker +~~~~~~~~~~~~ + +.. autoclass:: tianshou.env.worker.RayEnvWorker :members: :undoc-members: :show-inheritance: diff --git a/docs/api/tianshou.policy.rst b/docs/api/tianshou.policy.rst index f492953ac..818253775 100644 --- a/docs/api/tianshou.policy.rst +++ b/docs/api/tianshou.policy.rst @@ -1,7 +1,106 @@ tianshou.policy =============== -.. automodule:: tianshou.policy +Base +---- + +.. autoclass:: tianshou.policy.BasePolicy + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: tianshou.policy.RandomPolicy + :members: + :undoc-members: + :show-inheritance: + +Model-free +---------- + +DQN Family +~~~~~~~~~~ + +.. autoclass:: tianshou.policy.DQNPolicy + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: tianshou.policy.C51Policy + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: tianshou.policy.QRDQNPolicy + :members: + :undoc-members: + :show-inheritance: + +On-policy +~~~~~~~~~ + +.. autoclass:: tianshou.policy.PGPolicy + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: tianshou.policy.A2CPolicy + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: tianshou.policy.PPOPolicy + :members: + :undoc-members: + :show-inheritance: + +Off-policy +~~~~~~~~~~ + +.. 
autoclass:: tianshou.policy.DDPGPolicy + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: tianshou.policy.TD3Policy + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: tianshou.policy.SACPolicy + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: tianshou.policy.DiscreteSACPolicy + :members: + :undoc-members: + :show-inheritance: + +Imitation +--------- + +.. autoclass:: tianshou.policy.ImitationPolicy + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: tianshou.policy.DiscreteBCQPolicy + :members: + :undoc-members: + :show-inheritance: + +Model-based +----------- + +.. autoclass:: tianshou.policy.PSRLPolicy + :members: + :undoc-members: + :show-inheritance: + +Multi-agent +----------- + +.. autoclass:: tianshou.policy.MultiAgentPolicyManager :members: :undoc-members: :show-inheritance: diff --git a/docs/api/tianshou.utils.rst b/docs/api/tianshou.utils.rst index 3a293b1c1..38a8e7ca6 100644 --- a/docs/api/tianshou.utils.rst +++ b/docs/api/tianshou.utils.rst @@ -6,16 +6,29 @@ tianshou.utils :undoc-members: :show-inheritance: + +Pre-defined Networks +-------------------- + +Common +~~~~~~ + .. automodule:: tianshou.utils.net.common :members: :undoc-members: :show-inheritance: +Discrete +~~~~~~~~ + .. automodule:: tianshou.utils.net.discrete :members: :undoc-members: :show-inheritance: +Continuous +~~~~~~~~~~ + .. automodule:: tianshou.utils.net.continuous :members: :undoc-members: diff --git a/docs/conf.py b/docs/conf.py index f7bcc562d..b981eb4d4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -70,6 +70,7 @@ ] ) } +autodoc_member_order = "bysource" bibtex_bibfiles = ['refs.bib'] # -- Options for HTML output ------------------------------------------------- diff --git a/docs/contributor.rst b/docs/contributor.rst index b48d7ffb5..c594b2c0d 100644 --- a/docs/contributor.rst +++ b/docs/contributor.rst @@ -7,3 +7,4 @@ We always welcome contributions to help make Tianshou better. Below are an incom * Minghao Zhang (`Mehooz `_) * Alexis Duburcq (`duburcqa `_) * Kaichao You (`youkaichao `_) +* Huayu Chen (`ChenDRAG `_) diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index 07155f356..fefb30934 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -144,7 +144,7 @@ And finally, :: test_processor = MyProcessor(size=100) - collector = Collector(policy, env, buffer, test_processor.preprocess_fn) + collector = Collector(policy, env, buffer, preprocess_fn=test_processor.preprocess_fn) Some examples are in `test/base/test_collector.py `_. @@ -156,7 +156,7 @@ RNN-style Training This is related to `Issue 19 `_. -First, add an argument "stack_num" to :class:`~tianshou.data.ReplayBuffer`: +First, add an argument "stack_num" to :class:`~tianshou.data.ReplayBuffer`, :class:`~tianshou.data.VectorReplayBuffer`, or other types of buffer you are using, like: :: buf = ReplayBuffer(size=size, stack_num=stack_num) @@ -206,14 +206,13 @@ The state can be a ``numpy.ndarray`` or a Python dictionary. Take "FetchReach-v1 It shows that the state is a dictionary which has 3 keys. 
It will stored in :class:`~tianshou.data.ReplayBuffer` as: :: - >>> from tianshou.data import ReplayBuffer + >>> from tianshou.data import Batch, ReplayBuffer >>> b = ReplayBuffer(size=3) - >>> b.add(obs=e.reset(), act=0, rew=0, done=0) + >>> b.add(Batch(obs=e.reset(), act=0, rew=0, done=0)) >>> print(b) ReplayBuffer( act: array([0, 0, 0]), - done: array([0, 0, 0]), - info: Batch(), + done: array([False, False, False]), obs: Batch( achieved_goal: array([[1.34183265, 0.74910039, 0.53472272], [0. , 0. , 0. ], @@ -234,7 +233,6 @@ It shows that the state is a dictionary which has 3 keys. It will stored in :cla 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]), ), - policy: Batch(), rew: array([0, 0, 0]), ) >>> print(b.obs.achieved_goal) @@ -278,7 +276,7 @@ For self-defined class, the replay buffer will store the reference into a ``nump >>> import networkx as nx >>> b = ReplayBuffer(size=3) - >>> b.add(obs=nx.Graph(), act=0, rew=0, done=0) + >>> b.add(Batch(obs=nx.Graph(), act=0, rew=0, done=0)) >>> print(b) ReplayBuffer( act: array([0, 0, 0]), @@ -299,6 +297,10 @@ But the state stored in the buffer may be a shallow-copy. To make sure each of y ... return copy.deepcopy(self.graph), reward, done, {} +.. note :: + + Please make sure this variable is numpy-compatible, e.g., np.array([variable]) will not result in an empty array. Otherwise, ReplayBuffer cannot create an numpy array to store it. + .. _marl_example: diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index 8ea2d272e..888e7bfc3 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -53,11 +53,163 @@ In short, you can define a :class:`~tianshou.data.Batch` with any key-value pair Buffer ------ -.. automodule:: tianshou.data.ReplayBuffer - :members: - :noindex: +:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. ReplayBuffer can be considered as a specialized form (or management) of :class:`~tianshou.data.Batch`. It stores all the data in a batch with circular-queue style. -Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``). Check out :class:`~tianshou.data.ReplayBuffer` for more detail. +The current implementation of Tianshou typically use 7 reserved keys in +:class:`~tianshou.data.Batch`: + +* ``obs`` the observation of step :math:`t` ; +* ``act`` the action of step :math:`t` ; +* ``rew`` the reward of step :math:`t` ; +* ``done`` the done flag of step :math:`t` ; +* ``obs_next`` the observation of step :math:`t+1` ; +* ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` function returns 4 arguments, and the last one is ``info``); +* ``policy`` the data computed by policy in step :math:`t`; + +The following code snippet illustrates its usage, including: + +- the basic data storage: ``add()``; +- get attribute, get slicing data, ...; +- sample from buffer: ``sample_index(batch_size)`` and ``sample(batch_size)``; +- get previous/next transition index within episodes: ``prev(index)`` and ``next(index)``; +- save/load data from buffer: pickle and HDF5; + +:: + + >>> import pickle, numpy as np + >>> from tianshou.data import ReplayBuffer + >>> buf = ReplayBuffer(size=20) + >>> for i in range(3): + ... buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}) + + >>> buf.obs + # since we set size = 20, len(buf.obs) == 20. 
+ array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + >>> # but there are only three valid items, so len(buf) == 3. + >>> len(buf) + 3 + >>> # save to file "buf.pkl" + >>> pickle.dump(buf, open('buf.pkl', 'wb')) + >>> # save to HDF5 file + >>> buf.save_hdf5('buf.hdf5') + + >>> buf2 = ReplayBuffer(size=10) + >>> for i in range(15): + ... done = i % 4 == 0 + ... buf2.add(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={}) + >>> len(buf2) + 10 + >>> buf2.obs + # since its size = 10, it only stores the last 10 steps' result. + array([10, 11, 12, 13, 14, 5, 6, 7, 8, 9]) + + >>> # move buf2's result into buf (meanwhile keep it chronologically) + >>> buf.update(buf2) + >>> buf.obs + array([ 0, 1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, 0, 0, + 0, 0, 0, 0]) + + >>> # get all available index by using batch_size = 0 + >>> indice = buf.sample_index(0) + >>> indice + array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) + >>> # get one step previous/next transition + >>> buf.prev(indice) + array([ 0, 0, 1, 2, 3, 4, 5, 7, 7, 8, 9, 11, 11]) + >>> buf.next(indice) + array([ 1, 2, 3, 4, 5, 6, 6, 8, 9, 10, 10, 12, 12]) + + >>> # get a random sample from buffer + >>> # the batch_data is equal to buf[indice]. + >>> batch_data, indice = buf.sample(batch_size=4) + >>> batch_data.obs == buf[indice].obs + array([ True, True, True, True]) + >>> len(buf) + 13 + + >>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl" + >>> len(buf) + 3 + >>> # load complete buffer from HDF5 file + >>> buf = ReplayBuffer.load_hdf5('buf.hdf5') + >>> len(buf) + 3 + +:class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling (typically for RNN usage, see issue#19), ignoring storing the next observation (save memory in Atari tasks), and multi-modal observation (see issue#38): + +.. raw:: html + +
+ Advance usage of ReplayBuffer + +.. code-block:: python + + >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True) + >>> for i in range(16): + ... done = i % 5 == 0 + ... ep_len, ep_rew = buf.add(obs={'id': i}, act=i, rew=i, + ... done=done, obs_next={'id': i + 1}) + ... print(i, ep_len, ep_rew) + 0 1 0.0 + 1 0 0.0 + 2 0 0.0 + 3 0 0.0 + 4 0 0.0 + 5 5 15.0 + 6 0 0.0 + 7 0 0.0 + 8 0 0.0 + 9 0 0.0 + 10 5 40.0 + 11 0 0.0 + 12 0 0.0 + 13 0 0.0 + 14 0 0.0 + 15 5 65.0 + >>> print(buf) # you can see obs_next is not saved in buf + ReplayBuffer( + obs: Batch( + id: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8]), + ), + act: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8]), + rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]), + done: array([False, True, False, False, False, False, True, False, + False]), + info: Batch(), + policy: Batch(), + ) + >>> index = np.arange(len(buf)) + >>> print(buf.get(index, 'obs').id) + [[ 7 7 8 9] + [ 7 8 9 10] + [11 11 11 11] + [11 11 11 12] + [11 11 12 13] + [11 12 13 14] + [12 13 14 15] + [ 7 7 7 7] + [ 7 7 7 8]] + >>> # here is another way to get the stacked data + >>> # (stack only for obs and obs_next) + >>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum() + 0 + >>> # we can get obs_next through __getitem__, even if it doesn't exist + >>> print(buf[:].obs_next.id) + [[ 7 8 9 10] + [ 7 8 9 10] + [11 11 11 12] + [11 11 12 13] + [11 12 13 14] + [12 13 14 15] + [12 13 14 15] + [ 7 7 7 8] + [ 7 7 8 9]] + +.. raw:: html + +

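The prioritized variant mentioned in the following paragraph keeps the same ``add``/``sample`` interface. Here is a minimal usage sketch (not taken from this diff; the ``alpha``/``beta`` values and the priorities fed back are placeholder choices):
::

    import numpy as np
    from tianshou.data import Batch, PrioritizedReplayBuffer

    # alpha/beta below are illustrative values, not recommendations
    buf = PrioritizedReplayBuffer(size=10, alpha=0.6, beta=0.4)
    for i in range(5):
        buf.add(Batch(obs=i, act=i, rew=i, done=(i == 4), obs_next=i + 1, info={}))

    # sampled transitions carry importance-sampling weights in batch.weight
    batch, indices = buf.sample(batch_size=4)

    # after computing new TD errors, write them back as priorities
    new_priority = np.abs(np.random.randn(len(indices)))  # stand-in for |TD error|
    buf.update_weight(indices, new_priority)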
+ +Tianshou provides other type of data buffer such as :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``) and :class:`~tianshou.data.VectorReplayBuffer` (add different episodes' data but without losing chronological order). Check out :class:`~tianshou.data.ReplayBuffer` for more detail. Policy @@ -132,7 +284,7 @@ policy.process_fn The ``process_fn`` function computes some variables that depends on time-series. For example, compute the N-step or GAE returns. -Take 2-step return DQN as an example. The 2-step return DQN compute each frame's return as: +Take 2-step return DQN as an example. The 2-step return DQN compute each transition's return as: .. math:: @@ -187,13 +339,34 @@ Collector The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently. -:meth:`~tianshou.data.Collector.collect` is the main method of Collector: it let the policy perform (at least) a specified number of step ``n_step`` or episode ``n_episode`` and store the data in the replay buffer. +:meth:`~tianshou.data.Collector.collect` is the main method of Collector: it let the policy perform a specified number of step ``n_step`` or episode ``n_episode`` and store the data in the replay buffer, then return the statistics of the collected data such as episode's total reward. + +The general explanation is listed in :ref:`pseudocode`. Other usages of collector are listed in :class:`~tianshou.data.Collector` documentation. Here are some example usages: +:: + + policy = PGPolicy(...) # or other policies if you wish + env = gym.make("CartPole-v0") + + replay_buffer = ReplayBuffer(size=10000) + + # here we set up a collector with a single environment + collector = Collector(policy, env, buffer=replay_buffer) -Why do we mention **at least** here? For multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically. + # the collector supports vectorized environments as well + vec_buffer = VectorReplayBuffer(total_size=10000, buffer_num=3) + # buffer_num should be equal to (suggested) or larger than #envs + envs = DummyVectorEnv([lambda: gym.make("CartPole-v0") for _ in range(3)]) + collector = Collector(policy, envs, buffer=vec_buffer) -The proposed solution is to add some cache buffers inside the collector. Once collecting **a full episode of trajectory**, it will move the stored data from the cache buffer to the main buffer. To satisfy this condition, the collector will interact with environments that may exceed the given step number or episode number. + # collect 3 episodes + collector.collect(n_episode=3) + # collect at least 2 steps + collector.collect(n_step=2) + # collect episodes with visual rendering ("render" is the sleep time between + # rendering consecutive frames) + collector.collect(n_episode=1, render=0.03) -The general explanation is listed in :ref:`pseudocode`. Other usages of collector are listed in :class:`~tianshou.data.Collector` documentation. +There is also another type of collector :class:`~tianshou.data.AsyncCollector` which supports asynchronous environment setting (for those taking a long time to step). However, AsyncCollector only supports **at least** ``n_step`` or ``n_episode`` collection due to the property of asynchronous environments. 
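For the asynchronous setting described above, a minimal sketch could look like the following. It reuses the ``policy`` object from the previous snippet; the environment factory and the ``wait_num``/``timeout`` values are arbitrary placeholders:
::

    import gym
    from tianshou.env import SubprocVectorEnv
    from tianshou.data import AsyncCollector, VectorReplayBuffer

    # any callable returning an env works; a slow simulator is the typical use case
    env_fns = [lambda: gym.make("CartPole-v0") for _ in range(4)]
    # hand control back once 2 of the 4 envs finish stepping, or after 0.2 seconds
    envs = SubprocVectorEnv(env_fns, wait_num=2, timeout=0.2)

    buf = VectorReplayBuffer(total_size=10000, buffer_num=4)
    collector = AsyncCollector(policy, envs, buffer=buf)

    # because stepping is asynchronous, this may collect a bit more than 100 steps
    collector.collect(n_step=100)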
Trainer diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index faea6a869..5c5d547d5 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -35,10 +35,10 @@ If you want to use the original ``gym.Env``: Tianshou supports parallel sampling for all algorithms. It provides four types of vectorized environment wrapper: :class:`~tianshou.env.DummyVectorEnv`, :class:`~tianshou.env.SubprocVectorEnv`, :class:`~tianshou.env.ShmemVectorEnv`, and :class:`~tianshou.env.RayVectorEnv`. It can be used as follows: (more explanation can be found at :ref:`parallel_sampling`) :: - train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(8)]) + train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)]) test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(100)]) -Here, we set up 8 environments in ``train_envs`` and 100 environments in ``test_envs``. +Here, we set up 10 environments in ``train_envs`` and 100 environments in ``test_envs``. For the demonstration, here we use the second code-block. @@ -87,7 +87,7 @@ Tianshou supports any user-defined PyTorch networks and optimizers. Yet, of cour net = Net(state_shape, action_shape) optim = torch.optim.Adam(net.parameters(), lr=1e-3) -It is also possible to use pre-defined MLP networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are: +You can also use pre-defined MLP networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are: 1. Input: observation ``obs`` (may be a ``numpy.ndarray``, ``torch.Tensor``, dict, or self-defined class), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment. 2. Output: some ``logits``, the next hidden state ``state``. The logits could be a tuple instead of a ``torch.Tensor``, or some other useful variables or results during the policy forwarding procedure. It depends on how the policy class process the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for Gaussian policy. @@ -113,8 +113,8 @@ The collector is a key concept in Tianshou. It allows the policy to interact wit In each step, the collector will let the policy perform (at least) a specified number of steps or episodes and store the data in a replay buffer. :: - train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(size=20000)) - test_collector = ts.data.Collector(policy, test_envs) + train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(20000, 10), exploration_noise=True) + test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True) Train Policy with a Trainer @@ -125,33 +125,35 @@ Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.t result = ts.trainer.offpolicy_trainer( policy, train_collector, test_collector, - max_epoch=10, step_per_epoch=1000, collect_per_step=10, - episode_per_test=100, batch_size=64, + max_epoch=10, step_per_epoch=10000, step_per_collect=10, + update_per_step=0.1, episode_per_test=100, batch_size=64, train_fn=lambda epoch, env_step: policy.set_eps(0.1), test_fn=lambda epoch, env_step: policy.set_eps(0.05), stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold, - writer=None) + logger=None) print(f'Finished training! 
Use {result["duration"]}') The meaning of each parameter is as follows (full description can be found at :func:`~tianshou.trainer.offpolicy_trainer`): * ``max_epoch``: The maximum of epochs for training. The training process might be finished before reaching the ``max_epoch``; -* ``step_per_epoch``: The number of step for updating policy network in one epoch; -* ``collect_per_step``: The number of frames the collector would collect before the network update. For example, the code above means "collect 10 frames and do one policy network update"; +* ``step_per_epoch``: The number of environment step (a.k.a. transition) collected per epoch; +* ``step_per_collect``: The number of transition the collector would collect before the network update. For example, the code above means "collect 10 transitions and do one policy network update"; * ``episode_per_test``: The number of episodes for one policy evaluation. * ``batch_size``: The batch size of sample data, which is going to feed in the policy network. * ``train_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training". * ``test_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing". * ``stop_fn``: A function receives the average undiscounted returns of the testing result, return a boolean which indicates whether reaching the goal. -* ``writer``: See below. +* ``logger``: See below. The trainer supports `TensorBoard `_ for logging. It can be used as: :: from torch.utils.tensorboard import SummaryWriter + from tianshou.utils import BasicLogger writer = SummaryWriter('log/dqn') + logger = BasicLogger(writer) -Pass the writer into the trainer, and the training result will be recorded into the TensorBoard. +Pass the logger into the trainer, and the training result will be recorded into the TensorBoard. The returned result is a dictionary as follows: :: @@ -191,7 +193,7 @@ Watch the Agent's Performance policy.eval() policy.set_eps(0.05) - collector = ts.data.Collector(policy, env) + collector = ts.data.Collector(policy, env, exploration_noise=True) collector.collect(n_episode=1, render=1 / 35) @@ -205,9 +207,8 @@ Train a Policy with Customized Codes Tianshou supports user-defined training code. Here is the code snippet: :: - # pre-collect at least 5000 frames with random action before training - policy.set_eps(1) - train_collector.collect(n_step=5000) + # pre-collect at least 5000 transitions with random action before training + train_collector.collect(n_step=5000, random=True) policy.set_eps(0.1) for i in range(int(1e6)): # total step @@ -215,11 +216,11 @@ Tianshou supports user-defined training code. Here is the code snippet: # once if the collected episodes' mean returns reach the threshold, # or every 1000 steps, we test it on test_collector - if collect_result['rew'] >= env.spec.reward_threshold or i % 1000 == 0: + if collect_result['rews'].mean() >= env.spec.reward_threshold or i % 1000 == 0: policy.set_eps(0.05) result = test_collector.collect(n_episode=100) - if result['rew'] >= env.spec.reward_threshold: - print(f'Finished training! Test mean returns: {result["rew"]}') + if result['rews'].mean() >= env.spec.reward_threshold: + print(f'Finished training! 
Test mean returns: {result["rews"].mean()}') break else: # back to training eps diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst index ac9adc118..64b0dfd70 100644 --- a/docs/tutorials/tictactoe.rst +++ b/docs/tutorials/tictactoe.rst @@ -128,7 +128,7 @@ Tianshou already provides some builtin classes for multi-agent learning. You can >>> >>> # use collectors to collect a episode of trajectories >>> # the reward is a vector, so we need a scalar metric to monitor the training - >>> collector = Collector(policy, env, reward_metric=lambda x: x[0]) + >>> collector = Collector(policy, env) >>> >>> # you will see a long trajectory showing the board status at each timestep >>> result = collector.collect(n_episode=1, render=.1) @@ -176,11 +176,12 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modul import numpy as np from copy import deepcopy from torch.utils.tensorboard import SummaryWriter + from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer - from tianshou.data import Collector, ReplayBuffer + from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import BasePolicy, RandomPolicy, DQNPolicy, MultiAgentPolicyManager from tic_tac_toe_env import TicTacToeEnv @@ -199,27 +200,28 @@ The explanation of each Tianshou class/function will be deferred to their first help='a smaller gamma favors earlier win') parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) - parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--epoch', type=int, default=20) + parser.add_argument('--step-per-epoch', type=int, default=5000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.1) - parser.add_argument('--board_size', type=int, default=6) - parser.add_argument('--win_size', type=int, default=4) - parser.add_argument('--win-rate', type=float, default=np.float32(0.9), + parser.add_argument('--board-size', type=int, default=6) + parser.add_argument('--win-size', type=int, default=4) + parser.add_argument('--win-rate', type=float, default=0.9, help='the expected winning rate') parser.add_argument('--watch', default=False, action='store_true', help='no training, watch the play of pre-trained models') - parser.add_argument('--agent_id', type=int, default=2, + parser.add_argument('--agent-id', type=int, default=2, help='the learned agent plays as the agent_id-th player. 
Choices are 1 and 2.') - parser.add_argument('--resume_path', type=str, default='', + parser.add_argument('--resume-path', type=str, default='', help='the path of agent pth file for resuming from a pre-trained agent') - parser.add_argument('--opponent_path', type=str, default='', + parser.add_argument('--opponent-path', type=str, default='', help='the path of opponent agent pth file for resuming from a pre-trained agent') parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') @@ -240,11 +242,13 @@ Both agents are passed to :class:`~tianshou.policy.MultiAgentPolicyManager`, whi Here it is: :: - def get_agents(args=get_args(), - agent_learn=None, # BasePolicy - agent_opponent=None, # BasePolicy - optim=None, # torch.optim.Optimizer - ): # return a tuple of (BasePolicy, torch.optim.Optimizer) + def get_agents( + args=get_args(), + agent_learn=None, # BasePolicy + agent_opponent=None, # BasePolicy + optim=None, # torch.optim.Optimizer + ): # return a tuple of (BasePolicy, torch.optim.Optimizer) + env = TicTacToeEnv(args.board_size, args.win_size) args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n @@ -279,9 +283,6 @@ With the above preparation, we are close to the first learned agent. The followi :: args = get_args() - # the reward is a vector, we need a scalar metric to monitor the training. - # we choose the reward of the learning agent - Collector._default_rew_metric = lambda x: x[args.agent_id - 1] # ======== a test function that tests a pre-trained agent and exit ====== def watch(args=get_args(), @@ -294,7 +295,7 @@ With the above preparation, we are close to the first learned agent. The followi policy.policies[args.agent_id - 1].set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + print(f'Final reward: {result["rews"][:, args.agent_id - 1].mean()}, length: {result["lens"].mean()}') if args.watch: watch(args) exit(0) @@ -313,16 +314,16 @@ With the above preparation, we are close to the first learned agent. The followi policy, optim = get_agents() # ======== collector setup ========= - train_collector = Collector(policy, train_envs, ReplayBuffer(args.buffer_size)) - test_collector = Collector(policy, test_envs) - train_collector.collect(n_step=args.batch_size) + buffer = VectorReplayBuffer(args.buffer_size, args.training_num) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + train_collector.collect(n_step=args.batch_size * args.training_num) # ======== tensorboard logging setup ========= - if not hasattr(args, 'writer'): - log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn') - writer = SummaryWriter(log_path) - else: - writer = args.writer + log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn') + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer) # ======== callback functions used during training ========= @@ -347,13 +348,18 @@ With the above preparation, we are close to the first learned agent. The followi def test_fn(epoch, env_step): policy.policies[args.agent_id - 1].set_eps(args.eps_test) + # the reward is a vector, we need a scalar metric to monitor the training. 
+ # we choose the reward of the learning agent + def reward_metric(rews): + return rews[:, args.agent_id - 1] + # start training, this may require about three minutes result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, - test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step, + logger=logger, test_in_train=False, reward_metric=reward_metric) agent = policy.policies[args.agent_id - 1] # let's watch the match! @@ -476,7 +482,7 @@ By default, the trained agent is stored in ``log/tic_tac_toe/dqn/policy.pth``. Y .. code-block:: console - $ python test_tic_tac_toe.py --watch --resume_path=log/tic_tac_toe/dqn/policy.pth --opponent_path=log/tic_tac_toe/dqn/policy.pth + $ python test_tic_tac_toe.py --watch --resume-path log/tic_tac_toe/dqn/policy.pth --opponent-path log/tic_tac_toe/dqn/policy.pth Here is our output: diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py index e9edb8310..05633bc4a 100644 --- a/examples/atari/atari_bcq.py +++ b/examples/atari/atari_bcq.py @@ -2,10 +2,12 @@ import torch import pickle import pprint +import datetime import argparse import numpy as np from torch.utils.tensorboard import SummaryWriter +from tianshou.utils import BasicLogger from tianshou.env import SubprocVectorEnv from tianshou.trainer import offline_trainer from tianshou.utils.net.discrete import Actor @@ -28,7 +30,7 @@ def get_args(): parser.add_argument("--unlikely-action-threshold", type=float, default=0.3) parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) parser.add_argument("--epoch", type=int, default=100) - parser.add_argument("--step-per-epoch", type=int, default=10000) + parser.add_argument("--update-per-epoch", type=int, default=10000) parser.add_argument("--batch-size", type=int, default=32) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512]) @@ -39,7 +41,7 @@ def get_args(): parser.add_argument("--resume-path", type=str, default=None) parser.add_argument("--watch", default=False, action="store_true", help="watch the play of pre-trained policy only") - parser.add_argument("--log-interval", type=int, default=1000) + parser.add_argument("--log-interval", type=int, default=100) parser.add_argument( "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5", @@ -111,10 +113,15 @@ def test_discrete_bcq(args=get_args()): exit(0) # collector - test_collector = Collector(policy, test_envs) + test_collector = Collector(policy, test_envs, exploration_noise=True) - log_path = os.path.join(args.logdir, args.task, 'discrete_bcq') + # log + log_path = os.path.join( + args.logdir, args.task, 'bcq', + f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}') writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer, update_interval=args.log_interval) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -130,7 +137,7 @@ def watch(): test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, + result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) @@ -140,8 +147,8 @@ def watch(): result = offline_trainer( policy, buffer, 
test_collector, - args.epoch, args.step_per_epoch, args.test_num, args.batch_size, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, + args.epoch, args.update_per_epoch, args.test_num, args.batch_size, + stop_fn=stop_fn, save_fn=save_fn, logger=logger, log_interval=args.log_interval, ) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index c74da8f13..0956d691c 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -6,9 +6,10 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import C51Policy +from tianshou.utils import BasicLogger from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from atari_network import C51 from atari_wrapper import wrap_deepmind @@ -30,10 +31,11 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=10000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=100000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=32) - parser.add_argument('--training-num', type=int, default=16) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=10) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) @@ -84,20 +86,21 @@ def test_c51(args=get_args()): ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load( - args.resume_path, map_location=args.device - )) + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM - buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True, - save_only_last_obs=True, stack_num=args.frames_stack) + buffer = VectorReplayBuffer( + args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True, + save_only_last_obs=True, stack_num=args.frames_stack) # collector - train_collector = Collector(policy, train_envs, buffer) - test_collector = Collector(policy, test_envs) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) # log log_path = os.path.join(args.logdir, args.task, 'c51') writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -118,7 +121,7 @@ def train_fn(epoch, env_step): else: eps = args.eps_train_final policy.set_eps(eps) - writer.add_scalar('train/eps', eps, global_step=env_step) + logger.write('train/eps', env_step, eps) def test_fn(epoch, env_step): policy.set_eps(args.eps_test) @@ -130,8 +133,7 @@ def watch(): policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, - render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render) 
pprint.pprint(result) if args.watch: @@ -139,13 +141,14 @@ def watch(): exit(0) # test train_collector and start filling replay buffer - train_collector.collect(n_step=args.batch_size * 4) + train_collector.collect(n_step=args.batch_size * args.training_num) # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, logger=logger, + update_per_step=args.update_per_step, test_in_train=False) pprint.pprint(result) watch() diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 4f4c4f2df..b3f36c893 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -6,9 +6,10 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DQNPolicy +from tianshou.utils import BasicLogger from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from atari_network import DQN from atari_wrapper import wrap_deepmind @@ -27,10 +28,11 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=10000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=100000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=32) - parser.add_argument('--training-num', type=int, default=16) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=10) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
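A side note on the new ``--update-per-step`` flag used throughout these example scripts: together with ``--step-per-collect`` it determines how many gradient updates happen per collect phase. The arithmetic below is only an illustration based on the defaults above, not code from any file:

```python
step_per_collect = 10   # transitions gathered between two update phases
update_per_step = 0.1   # gradient steps per collected transition
# each collect phase therefore triggers step_per_collect * update_per_step updates
grad_steps_per_collect = int(step_per_collect * update_per_step)  # -> 1
```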
@@ -80,20 +82,21 @@ def test_dqn(args=get_args()): target_update_freq=args.target_update_freq) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load( - args.resume_path, map_location=args.device - )) + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM - buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True, - save_only_last_obs=True, stack_num=args.frames_stack) + buffer = VectorReplayBuffer( + args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True, + save_only_last_obs=True, stack_num=args.frames_stack) # collector - train_collector = Collector(policy, train_envs, buffer) - test_collector = Collector(policy, test_envs) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) # log log_path = os.path.join(args.logdir, args.task, 'dqn') writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -114,7 +117,7 @@ def train_fn(epoch, env_step): else: eps = args.eps_train_final policy.set_eps(eps) - writer.add_scalar('train/eps', eps, global_step=env_step) + logger.write('train/eps', env_step, eps) def test_fn(epoch, env_step): policy.set_eps(args.eps_test) @@ -127,9 +130,10 @@ def watch(): test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") - buffer = ReplayBuffer( - args.buffer_size, ignore_obs_next=True, - save_only_last_obs=True, stack_num=args.frames_stack) + buffer = VectorReplayBuffer( + args.buffer_size, buffer_num=len(test_envs), + ignore_obs_next=True, save_only_last_obs=True, + stack_num=args.frames_stack) collector = Collector(policy, test_envs, buffer) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") @@ -138,7 +142,7 @@ def watch(): else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, + result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) @@ -147,13 +151,14 @@ def watch(): exit(0) # test train_collector and start filling replay buffer - train_collector.collect(n_step=args.batch_size * 4) + train_collector.collect(n_step=args.batch_size * args.training_num) # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, logger=logger, + update_per_step=args.update_per_step, test_in_train=False) pprint.pprint(result) watch() diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 08c34733c..ae2a26f4f 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -5,10 +5,11 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter +from tianshou.utils import BasicLogger from tianshou.policy import QRDQNPolicy from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from 
tianshou.data import Collector, VectorReplayBuffer from atari_network import QRDQN from atari_wrapper import wrap_deepmind @@ -28,10 +29,11 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=10000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=100000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=32) - parser.add_argument('--training-num', type=int, default=16) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=10) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) @@ -82,20 +84,21 @@ def test_qrdqn(args=get_args()): ).to(args.device) # load a previous policy if args.resume_path: - policy.load_state_dict(torch.load( - args.resume_path, map_location=args.device - )) + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) print("Loaded agent from: ", args.resume_path) # replay buffer: `save_last_obs` and `stack_num` can be removed together # when you have enough RAM - buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True, - save_only_last_obs=True, stack_num=args.frames_stack) + buffer = VectorReplayBuffer( + args.buffer_size, buffer_num=len(train_envs), + ignore_obs_next=True, save_only_last_obs=True, stack_num=args.frames_stack) # collector - train_collector = Collector(policy, train_envs, buffer) - test_collector = Collector(policy, test_envs) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) # log log_path = os.path.join(args.logdir, args.task, 'qrdqn') writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -116,7 +119,7 @@ def train_fn(epoch, env_step): else: eps = args.eps_train_final policy.set_eps(eps) - writer.add_scalar('train/eps', eps, global_step=env_step) + logger.write('train/eps', env_step, eps) def test_fn(epoch, env_step): policy.set_eps(args.eps_test) @@ -128,8 +131,7 @@ def watch(): policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, - render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) if args.watch: @@ -137,13 +139,14 @@ def watch(): exit(0) # test train_collector and start filling replay buffer - train_collector.collect(n_step=args.batch_size * 4) + train_collector.collect(n_step=args.batch_size * args.training_num) # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, logger=logger, + update_per_step=args.update_per_step, test_in_train=False) pprint.pprint(result) watch() diff --git 
a/examples/atari/runnable/atari.py b/examples/atari/runnable/atari.py deleted file mode 100644 index 8e2ea5168..000000000 --- a/examples/atari/runnable/atari.py +++ /dev/null @@ -1,133 +0,0 @@ -import cv2 -import gym -import numpy as np -from gym.spaces.box import Box -from tianshou.data import Batch - -SIZE = 84 -FRAME = 4 - - -def create_atari_environment(name=None, sticky_actions=True, - max_episode_steps=2000): - game_version = 'v0' if sticky_actions else 'v4' - name = '{}NoFrameskip-{}'.format(name, game_version) - env = gym.make(name) - env = env.env - env = preprocessing(env, max_episode_steps=max_episode_steps) - return env - - -def preprocess_fn(obs=None, act=None, rew=None, done=None, - obs_next=None, info=None, policy=None, **kwargs): - if obs_next is not None: - obs_next = np.reshape(obs_next, (-1, *obs_next.shape[2:])) - obs_next = np.moveaxis(obs_next, 0, -1) - obs_next = cv2.resize(obs_next, (SIZE, SIZE)) - obs_next = np.asanyarray(obs_next, dtype=np.uint8) - obs_next = np.reshape(obs_next, (-1, FRAME, SIZE, SIZE)) - obs_next = np.moveaxis(obs_next, 1, -1) - elif obs is not None: - obs = np.reshape(obs, (-1, *obs.shape[2:])) - obs = np.moveaxis(obs, 0, -1) - obs = cv2.resize(obs, (SIZE, SIZE)) - obs = np.asanyarray(obs, dtype=np.uint8) - obs = np.reshape(obs, (-1, FRAME, SIZE, SIZE)) - obs = np.moveaxis(obs, 1, -1) - - return Batch(obs=obs, act=act, rew=rew, done=done, - obs_next=obs_next, info=info) - - -class preprocessing(object): - def __init__(self, env, frame_skip=4, terminal_on_life_loss=False, - size=84, max_episode_steps=2000): - self.max_episode_steps = max_episode_steps - self.env = env - self.terminal_on_life_loss = terminal_on_life_loss - self.frame_skip = frame_skip - self.size = size - self.count = 0 - obs_dims = self.env.observation_space - - self.screen_buffer = [ - np.empty((obs_dims.shape[0], obs_dims.shape[1]), dtype=np.uint8), - np.empty((obs_dims.shape[0], obs_dims.shape[1]), dtype=np.uint8) - ] - - self.game_over = False - self.lives = 0 - - @property - def observation_space(self): - return Box(low=0, high=255, - shape=(self.size, self.size, self.frame_skip), - dtype=np.uint8) - - def action_space(self): - return self.env.action_space - - def reward_range(self): - return self.env.reward_range - - def metadata(self): - return self.env.metadata - - def close(self): - return self.env.close() - - def reset(self): - self.count = 0 - self.env.reset() - self.lives = self.env.ale.lives() - self._grayscale_obs(self.screen_buffer[0]) - self.screen_buffer[1].fill(0) - - return np.array([self._pool_and_resize() - for _ in range(self.frame_skip)]) - - def render(self, mode='human'): - return self.env.render(mode) - - def step(self, action): - total_reward = 0. 
- observation = [] - for t in range(self.frame_skip): - self.count += 1 - _, reward, terminal, info = self.env.step(action) - total_reward += reward - - if self.terminal_on_life_loss: - lives = self.env.ale.lives() - is_terminal = terminal or lives < self.lives - self.lives = lives - else: - is_terminal = terminal - - if is_terminal: - break - elif t >= self.frame_skip - 2: - t_ = t - (self.frame_skip - 2) - self._grayscale_obs(self.screen_buffer[t_]) - - observation.append(self._pool_and_resize()) - if len(observation) == 0: - observation = [self._pool_and_resize() - for _ in range(self.frame_skip)] - while len(observation) > 0 and \ - len(observation) < self.frame_skip: - observation.append(observation[-1]) - terminal = self.count >= self.max_episode_steps - return np.array(observation), total_reward, \ - (terminal or is_terminal), info - - def _grayscale_obs(self, output): - self.env.ale.getScreenGrayscale(output) - return output - - def _pool_and_resize(self): - if self.frame_skip > 1: - np.maximum(self.screen_buffer[0], self.screen_buffer[1], - out=self.screen_buffer[0]) - - return self.screen_buffer[0] diff --git a/examples/atari/runnable/pong_a2c.py b/examples/atari/runnable/pong_a2c.py deleted file mode 100644 index ffed1694d..000000000 --- a/examples/atari/runnable/pong_a2c.py +++ /dev/null @@ -1,105 +0,0 @@ -import os -import torch -import pprint -import argparse -import numpy as np -from torch.utils.tensorboard import SummaryWriter - -from tianshou.policy import A2CPolicy -from tianshou.env import SubprocVectorEnv -from tianshou.utils.net.common import Net -from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, ReplayBuffer -from tianshou.utils.net.discrete import Actor, Critic - -from atari import create_atari_environment, preprocess_fn - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='Pong') - parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--lr', type=float, default=3e-4) - parser.add_argument('--gamma', type=float, default=0.9) - parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) - parser.add_argument('--repeat-per-collect', type=int, default=1) - parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128, 128]) - parser.add_argument('--training-num', type=int, default=8) - parser.add_argument('--test-num', type=int, default=8) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) 
- - parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') - # a2c special - parser.add_argument('--vf-coef', type=float, default=0.5) - parser.add_argument('--ent-coef', type=float, default=0.001) - parser.add_argument('--max-grad-norm', type=float, default=None) - parser.add_argument('--max-episode-steps', type=int, default=2000) - return parser.parse_args() - - -def test_a2c(args=get_args()): - env = create_atari_environment(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.env.action_space.shape or env.env.action_space.n - # train_envs = gym.make(args.task) - train_envs = SubprocVectorEnv( - [lambda: create_atari_environment(args.task) - for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv( - [lambda: create_atari_environment(args.task) - for _ in range(args.test_num)]) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) - train_envs.seed(args.seed) - test_envs.seed(args.seed) - # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) - actor = Actor(net, args.action_shape, device=args.device).to(args.device) - critic = Critic(net, device=args.device).to(args.device) - optim = torch.optim.Adam(set( - actor.parameters()).union(critic.parameters()), lr=args.lr) - dist = torch.distributions.Categorical - policy = A2CPolicy( - actor, critic, optim, dist, args.gamma, vf_coef=args.vf_coef, - ent_coef=args.ent_coef, max_grad_norm=args.max_grad_norm) - # collector - train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size), - preprocess_fn=preprocess_fn) - test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn) - # log - writer = SummaryWriter(os.path.join(args.logdir, args.task, 'a2c')) - - def stop_fn(mean_rewards): - if env.env.spec.reward_threshold: - return mean_rewards >= env.spec.reward_threshold - else: - return False - - # trainer - result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) - if __name__ == '__main__': - pprint.pprint(result) - # Let's watch its performance! 
- env = create_atari_environment(args.task) - collector = Collector(policy, env, preprocess_fn=preprocess_fn) - result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') - - -if __name__ == '__main__': - test_a2c() diff --git a/examples/atari/runnable/pong_ppo.py b/examples/atari/runnable/pong_ppo.py deleted file mode 100644 index 35ed0e749..000000000 --- a/examples/atari/runnable/pong_ppo.py +++ /dev/null @@ -1,109 +0,0 @@ -import os -import torch -import pprint -import argparse -import numpy as np -from torch.utils.tensorboard import SummaryWriter - -from tianshou.policy import PPOPolicy -from tianshou.env import SubprocVectorEnv -from tianshou.utils.net.common import Net -from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, ReplayBuffer -from tianshou.utils.net.discrete import Actor, Critic - -from atari import create_atari_environment, preprocess_fn - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='Pong') - parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--lr', type=float, default=1e-3) - parser.add_argument('--gamma', type=float, default=0.99) - parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) - parser.add_argument('--repeat-per-collect', type=int, default=2) - parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=8) - parser.add_argument('--test-num', type=int, default=8) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) 
- parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') - # ppo special - parser.add_argument('--vf-coef', type=float, default=0.5) - parser.add_argument('--ent-coef', type=float, default=0.0) - parser.add_argument('--eps-clip', type=float, default=0.2) - parser.add_argument('--max-grad-norm', type=float, default=0.5) - parser.add_argument('--max-episode-steps', type=int, default=2000) - return parser.parse_args() - - -def test_ppo(args=get_args()): - env = create_atari_environment(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space().shape or env.action_space().n - # train_envs = gym.make(args.task) - train_envs = SubprocVectorEnv([ - lambda: create_atari_environment(args.task) - for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv([ - lambda: create_atari_environment(args.task) - for _ in range(args.test_num)]) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) - train_envs.seed(args.seed) - test_envs.seed(args.seed) - # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) - actor = Actor(net, args.action_shape, device=args.device).to(args.device) - critic = Critic(net, device=args.device).to(args.device) - optim = torch.optim.Adam(set( - actor.parameters()).union(critic.parameters()), lr=args.lr) - dist = torch.distributions.Categorical - policy = PPOPolicy( - actor, critic, optim, dist, args.gamma, - max_grad_norm=args.max_grad_norm, - eps_clip=args.eps_clip, - vf_coef=args.vf_coef, - ent_coef=args.ent_coef, - action_range=None) - # collector - train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size), - preprocess_fn=preprocess_fn) - test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn) - # log - writer = SummaryWriter(os.path.join(args.logdir, args.task, 'ppo')) - - def stop_fn(mean_rewards): - if env.env.spec.reward_threshold: - return mean_rewards >= env.spec.reward_threshold - else: - return False - - # trainer - result = onpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer) - if __name__ == '__main__': - pprint.pprint(result) - # Let's watch its performance! 
- env = create_atari_environment(args.task) - collector = Collector(policy, env, preprocess_fn=preprocess_fn) - result = collector.collect(n_step=2000, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') - - -if __name__ == '__main__': - test_ppo() diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 071990069..58f1a3783 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -7,10 +7,11 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DQNPolicy +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer def get_args(): @@ -25,16 +26,16 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=100) + parser.add_argument('--step-per-epoch', type=int, default=100000) + parser.add_argument('--step-per-collect', type=int, default=100) + parser.add_argument('--update-per-step', type=float, default=0.01) parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128]) parser.add_argument('--dueling-q-hidden-sizes', type=int, nargs='*', default=[128, 128]) parser.add_argument('--dueling-v-hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
@@ -72,13 +73,16 @@ def test_dqn(args=get_args()): target_update_freq=args.target_update_freq) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) - test_collector = Collector(policy, test_envs) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size) + train_collector.collect(n_step=args.batch_size * args.training_num) # log log_path = os.path.join(args.logdir, args.task, 'dqn') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -102,9 +106,9 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, train_fn=train_fn, test_fn=test_fn, + stop_fn=stop_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': @@ -114,9 +118,9 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, - render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + result = test_collector.collect(n_episode=args.test_num, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 58eb98ec9..d5a8f0577 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -7,10 +7,11 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import SACPolicy +from tianshou.utils import BasicLogger from tianshou.utils.net.common import Net from tianshou.env import SubprocVectorEnv from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import ActorProb, Critic @@ -27,12 +28,13 @@ def get_args(): parser.add_argument('--auto-alpha', type=int, default=1) parser.add_argument('--alpha-lr', type=float, default=3e-4) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=10000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=100000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
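The example updates above (and the ones that follow) all apply the same migration: the single `ReplayBuffer` becomes a `VectorReplayBuffer` shared across the vectorized training envs, collectors opt into `exploration_noise=True`, the `SummaryWriter` is wrapped in a `BasicLogger`, and the trainer takes `step_per_collect`/`update_per_step` plus `logger=` instead of `collect_per_step` and `writer=`. Below is a minimal sketch of the new wiring, assuming `policy`, `train_envs`, and `test_envs` already exist; the epoch/step counts are copied from the Acrobot example above, while the buffer size and log path are illustrative placeholders, not part of this diff:

```python
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import BasicLogger

buffer_size, epoch, batch_size, test_num = 20000, 10, 64, 100
step_per_epoch, step_per_collect, update_per_step = 100000, 100, 0.01

# One VectorReplayBuffer, sized once and split across the training envs,
# replaces the old per-collector ReplayBuffer.
train_collector = Collector(
    policy, train_envs,
    VectorReplayBuffer(buffer_size, len(train_envs)),
    exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)

# The TensorBoard writer is now wrapped in a logger object and passed as logger=.
logger = BasicLogger(SummaryWriter('log/dqn'))

result = offpolicy_trainer(
    policy, train_collector, test_collector, epoch,
    step_per_epoch, step_per_collect, test_num, batch_size,
    update_per_step=update_per_step,  # 0.01 -> one gradient update per 100 collected steps
    logger=logger)
```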
@@ -125,12 +127,15 @@ def test_sac_bipedal(args=get_args()): # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, 'sac') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -141,9 +146,9 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer, - test_in_train=False) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, test_in_train=False, + stop_fn=stop_fn, save_fn=save_fn, logger=logger) if __name__ == '__main__': pprint.pprint(result) @@ -151,9 +156,10 @@ def stop_fn(mean_rewards): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, + result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 5c98fb779..f5a9d3bdf 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -7,9 +7,10 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DQNPolicy +from tianshou.utils import BasicLogger from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv @@ -26,8 +27,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=4) parser.add_argument('--target-update-freq', type=int, default=500) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=5000) - parser.add_argument('--collect-per-step', type=int, default=16) + parser.add_argument('--step-per-epoch', type=int, default=80000) + parser.add_argument('--step-per-collect', type=int, default=16) + parser.add_argument('--update-per-step', type=float, default=0.0625) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) @@ -35,7 +37,7 @@ def get_args(): nargs='*', default=[128, 128]) parser.add_argument('--dueling-v-hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--training-num', type=int, default=16) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
@@ -73,13 +75,16 @@ def test_dqn(args=get_args()): target_update_freq=args.target_update_freq) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) - test_collector = Collector(policy, test_envs) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size) + train_collector.collect(n_step=args.batch_size * args.training_num) # log log_path = os.path.join(args.logdir, args.task, 'dqn') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -97,10 +102,9 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, - test_in_train=False) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, stop_fn=stop_fn, train_fn=train_fn, + test_fn=test_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': @@ -110,9 +114,9 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, - render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + result = test_collector.collect(n_episode=args.test_num, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 47ca4e25c..f22c7846e 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -7,11 +7,12 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import SACPolicy -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.exploration import OUNoise from tianshou.utils.net.common import Net +from tianshou.trainer import offpolicy_trainer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import ActorProb, Critic @@ -29,12 +30,13 @@ def get_args(): parser.add_argument('--auto_alpha', type=int, default=1) parser.add_argument('--alpha', type=float, default=0.2) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=12000) + parser.add_argument('--step-per-collect', type=int, default=5) + parser.add_argument('--update-per-step', type=float, default=0.2) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=16) + parser.add_argument('--training-num', type=int, default=5) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
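The evaluation epilogue of these scripts changes in lockstep: `collect()` now takes a plain integer `n_episode` instead of a per-environment list, and the result exposes per-episode arrays under `rews`/`lens`. A minimal sketch of the new pattern, assuming the DQN-style `policy`, `test_envs`, `test_collector`, and `args` from the examples above are already in place:

```python
# Put the policy into evaluation mode before watching it.
policy.eval()
policy.set_eps(args.eps_test)  # epsilon-greedy policies only; the SAC examples skip this
test_envs.seed(args.seed)
test_collector.reset()

# n_episode is a single total count now, not a [1, 1, ...] list per env.
result = test_collector.collect(n_episode=args.test_num, render=args.render)

# "rews" and "lens" are per-episode arrays, so aggregate them explicitly.
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
```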
@@ -90,16 +92,19 @@ def test_sac(args=get_args()): actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, action_range=[env.action_space.low[0], env.action_space.high[0]], tau=args.tau, gamma=args.gamma, alpha=args.alpha, - reward_normalization=args.rew_norm, ignore_done=True, + reward_normalization=args.rew_norm, exploration_noise=OUNoise(0.0, args.noise_std)) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, 'sac') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -110,8 +115,10 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, stop_fn=stop_fn, + save_fn=save_fn, logger=logger) + assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) @@ -119,9 +126,9 @@ def stop_fn(mean_rewards): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, - render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + result = test_collector.collect(n_episode=args.test_num, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 0b09c81a5..c07877324 100644 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -2,15 +2,17 @@ import gym import torch import pprint +import datetime import argparse import numpy as np from torch.utils.tensorboard import SummaryWriter from tianshou.policy import SACPolicy +from tianshou.utils import BasicLogger from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import ActorProb, Critic @@ -28,14 +30,14 @@ def get_args(): parser.add_argument('--alpha-lr', type=float, default=3e-4) parser.add_argument('--n-step', type=int, default=2) parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=10000) - parser.add_argument('--collect-per-step', type=int, default=4) - parser.add_argument('--update-per-step', type=int, default=1) + parser.add_argument('--step-per-epoch', type=int, default=40000) + parser.add_argument('--step-per-collect', type=int, default=4) + parser.add_argument('--update-per-step', type=float, default=0.25) parser.add_argument('--pre-collect-step', type=int, default=10000) parser.add_argument('--batch-size', type=int, default=256) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=16) + parser.add_argument('--training-num', type=int, default=4) 
parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) @@ -108,11 +110,16 @@ def test_sac(args=get_args()): # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # log - log_path = os.path.join(args.logdir, args.task, 'sac') + log_path = os.path.join(args.logdir, args.task, 'sac', 'seed_' + str( + args.seed) + '_' + datetime.datetime.now().strftime('%m%d-%H%M%S')) writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer, train_interval=args.log_interval) def watch(): # watch agent's performance @@ -120,8 +127,7 @@ def watch(): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, - render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render) pprint.pprint(result) def save_fn(policy): @@ -138,10 +144,9 @@ def stop_fn(mean_rewards): train_collector.collect(n_step=args.pre_collect_step, random=True) result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, args.update_per_step, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, - log_interval=args.log_interval) + args.step_per_epoch, args.step_per_collect, args.test_num, + args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger, + update_per_step=args.update_per_step) pprint.pprint(result) watch() diff --git a/examples/mujoco/runnable/ant_v2_ddpg.py b/examples/mujoco/runnable/ant_v2_ddpg.py deleted file mode 100644 index 13192dbc9..000000000 --- a/examples/mujoco/runnable/ant_v2_ddpg.py +++ /dev/null @@ -1,103 +0,0 @@ -import gym -import torch -import pprint -import argparse -import numpy as np -from torch.utils.tensorboard import SummaryWriter - -from tianshou.policy import DDPGPolicy -from tianshou.env import SubprocVectorEnv -from tianshou.utils.net.common import Net -from tianshou.trainer import offpolicy_trainer -from tianshou.exploration import GaussianNoise -from tianshou.data import Collector, ReplayBuffer -from tianshou.utils.net.continuous import Actor, Critic - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='Ant-v2') - parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--actor-lr', type=float, default=1e-4) - parser.add_argument('--critic-lr', type=float, default=1e-3) - parser.add_argument('--gamma', type=float, default=0.99) - parser.add_argument('--tau', type=float, default=0.005) - parser.add_argument('--exploration-noise', type=float, default=0.1) - parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=4) - parser.add_argument('--batch-size', type=int, default=128) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=8) - parser.add_argument('--test-num', type=int, default=100) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) 
- parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') - return parser.parse_args() - - -def test_ddpg(args=get_args()): - env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] - # train_envs = gym.make(args.task) - train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) - train_envs.seed(args.seed) - test_envs.seed(args.seed) - # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) - actor = Actor(net, args.action_shape, max_action=args.max_action, - device=args.device).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, concat=True, device=args.device) - critic = Critic(net, device=args.device).to(args.device) - critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) - policy = DDPGPolicy( - actor, actor_optim, critic, critic_optim, - action_range=[env.action_space.low[0], env.action_space.high[0]], - tau=args.tau, gamma=args.gamma, - exploration_noise=GaussianNoise(sigma=args.exploration_noise), - reward_normalization=True, ignore_done=True) - # collector - train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) - test_collector = Collector(policy, test_envs) - # log - writer = SummaryWriter(args.logdir + '/' + 'ddpg') - - def stop_fn(mean_rewards): - return mean_rewards >= env.spec.reward_threshold - - # trainer - result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, writer=writer) - assert stop_fn(result['best_reward']) - if __name__ == '__main__': - pprint.pprint(result) - # Let's watch its performance! 
- policy.eval() - test_envs.seed(args.seed) - test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, - render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') - - -if __name__ == '__main__': - test_ddpg() diff --git a/examples/mujoco/runnable/ant_v2_td3.py b/examples/mujoco/runnable/ant_v2_td3.py deleted file mode 100644 index 5ed45506a..000000000 --- a/examples/mujoco/runnable/ant_v2_td3.py +++ /dev/null @@ -1,112 +0,0 @@ -import gym -import torch -import pprint -import argparse -import numpy as np -from torch.utils.tensorboard import SummaryWriter - -from tianshou.policy import TD3Policy -from tianshou.env import SubprocVectorEnv -from tianshou.utils.net.common import Net -from tianshou.exploration import GaussianNoise -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer -from tianshou.utils.net.continuous import Actor, Critic - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='Ant-v2') - parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--actor-lr', type=float, default=3e-4) - parser.add_argument('--critic-lr', type=float, default=1e-3) - parser.add_argument('--gamma', type=float, default=0.99) - parser.add_argument('--tau', type=float, default=0.005) - parser.add_argument('--exploration-noise', type=float, default=0.1) - parser.add_argument('--policy-noise', type=float, default=0.2) - parser.add_argument('--noise-clip', type=float, default=0.5) - parser.add_argument('--update-actor-freq', type=int, default=2) - parser.add_argument('--epoch', type=int, default=100) - parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=10) - parser.add_argument('--batch-size', type=int, default=128) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=8) - parser.add_argument('--test-num', type=int, default=100) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--render', type=float, default=0.) 
- parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') - return parser.parse_args() - - -def test_td3(args=get_args()): - env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] - # train_envs = gym.make(args.task) - train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) - train_envs.seed(args.seed) - test_envs.seed(args.seed) - # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) - actor = Actor(net, args.action_shape, max_action=args.max_action, - device=args.device).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, concat=True, device=args.device) - critic1 = Critic(net, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic(net, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) - policy = TD3Policy( - actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - action_range=[env.action_space.low[0], env.action_space.high[0]], - tau=args.tau, gamma=args.gamma, - exploration_noise=GaussianNoise(sigma=args.exploration_noise), - policy_noise=args.policy_noise, - update_actor_freq=args.update_actor_freq, - noise_clip=args.noise_clip, - reward_normalization=True, ignore_done=True) - # collector - train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) - test_collector = Collector(policy, test_envs) - # train_collector.collect(n_step=args.buffer_size) - # log - writer = SummaryWriter(args.logdir + '/' + 'td3') - - def stop_fn(mean_rewards): - return mean_rewards >= env.spec.reward_threshold - - # trainer - result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, writer=writer) - assert stop_fn(result['best_reward']) - if __name__ == '__main__': - pprint.pprint(result) - # Let's watch its performance! 
- policy.eval() - test_envs.seed(args.seed) - test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, - render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') - - -if __name__ == '__main__': - test_td3() diff --git a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py deleted file mode 100644 index b669d264a..000000000 --- a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py +++ /dev/null @@ -1,115 +0,0 @@ -import os -import gym -import torch -import pprint -import argparse -import numpy as np -import pybullet_envs -from torch.utils.tensorboard import SummaryWriter - -from tianshou.policy import SACPolicy -from tianshou.utils.net.common import Net -from tianshou.env import SubprocVectorEnv -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer -from tianshou.utils.net.continuous import ActorProb, Critic - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--task', type=str, default='HalfCheetahBulletEnv-v0') - parser.add_argument('--run-id', type=str, default='test') - parser.add_argument('--seed', type=int, default=1626) - parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--actor-lr', type=float, default=3e-4) - parser.add_argument('--critic-lr', type=float, default=1e-3) - parser.add_argument('--gamma', type=float, default=0.99) - parser.add_argument('--tau', type=float, default=0.005) - parser.add_argument('--alpha', type=float, default=0.2) - parser.add_argument('--epoch', type=int, default=200) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) - parser.add_argument('--batch-size', type=int, default=128) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=8) - parser.add_argument('--test-num', type=int, default=4) - parser.add_argument('--logdir', type=str, default='log') - parser.add_argument('--log-interval', type=int, default=100) - parser.add_argument('--render', type=float, default=0.) 
- parser.add_argument( - '--device', type=str, - default='cuda' if torch.cuda.is_available() else 'cpu') - return parser.parse_args() - - -def test_sac(args=get_args()): - torch.set_num_threads(1) - env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] - # you can also use tianshou.env.SubprocVectorEnv - # train_envs = gym.make(args.task) - train_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.training_num)]) - # test_envs = gym.make(args.task) - test_envs = SubprocVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) - # seed - np.random.seed(args.seed) - torch.manual_seed(args.seed) - train_envs.seed(args.seed) - test_envs.seed(args.seed) - # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, - device=args.device) - actor = ActorProb(net, args.action_shape, max_action=args.max_action, - device=args.device, unbounded=True).to(args.device) - actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, concat=True, device=args.device) - critic1 = Critic(net, device=args.device).to(args.device) - critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - net = Net(args.state_shape, args.action_shape, - hidden_sizes=args.hidden_sizes, concat=True, device=args.device) - critic2 = Critic(net, device=args.device).to(args.device) - critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) - policy = SACPolicy( - actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - action_range=[env.action_space.low[0], env.action_space.high[0]], - tau=args.tau, gamma=args.gamma, alpha=args.alpha, - reward_normalization=True, ignore_done=True) - # collector - train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) - test_collector = Collector(policy, test_envs) - # train_collector.collect(n_step=args.buffer_size) - # log - log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id) - writer = SummaryWriter(log_path) - - def stop_fn(mean_rewards): - return mean_rewards >= env.spec.reward_threshold - - # trainer - result = offpolicy_trainer( - policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, - writer=writer, log_interval=args.log_interval) - assert stop_fn(result['best_reward']) - if __name__ == '__main__': - pprint.pprint(result) - # Let's watch its performance! 
- policy.eval() - test_envs.seed(args.seed) - test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, - render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') - - -if __name__ == '__main__': - __all__ = ('pybullet_envs',) # Avoid F401 error :) - test_sac() diff --git a/examples/mujoco/runnable/mujoco/assets/point.xml b/examples/mujoco/runnable/mujoco/assets/point.xml deleted file mode 100644 index 38cc64407..000000000 --- a/examples/mujoco/runnable/mujoco/assets/point.xml +++ /dev/null @@ -1,34 +0,0 @@ - - - diff --git a/examples/mujoco/runnable/mujoco/maze_env_utils.py b/examples/mujoco/runnable/mujoco/maze_env_utils.py deleted file mode 100644 index dafce77f5..000000000 --- a/examples/mujoco/runnable/mujoco/maze_env_utils.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Adapted from rllab maze_env_utils.py.""" -import math - - -class Move(object): - X = 11 - Y = 12 - Z = 13 - XY = 14 - XZ = 15 - YZ = 16 - XYZ = 17 - SpinXY = 18 - - -def can_move_x(movable): - return movable in [Move.X, Move.XY, Move.XZ, Move.XYZ, - Move.SpinXY] - - -def can_move_y(movable): - return movable in [Move.Y, Move.XY, Move.YZ, Move.XYZ, - Move.SpinXY] - - -def can_move_z(movable): - return movable in [Move.Z, Move.XZ, Move.YZ, Move.XYZ] - - -def can_spin(movable): - return movable in [Move.SpinXY] - - -def can_move(movable): - return can_move_x(movable) or can_move_y(movable) or can_move_z(movable) - - -def construct_maze(maze_id='Maze'): - if maze_id == 'Maze': - structure = [ - [1, 1, 1, 1, 1], - [1, 'r', 0, 0, 1], - [1, 1, 1, 0, 1], - [1, 'g', 0, 0, 1], - [1, 1, 1, 1, 1], - ] - elif maze_id == 'Maze1': - structure = [ - [1, 1, 1, 1, 1, 1, 1, 1], - [1, 'r', 1, 0, 0, 0, 0, 1], - [1, 0, 0, 0, 1, 1, 0, 1], - [1, 1, 1, 1, 1, 0, 0, 1], - [1, 0, 0, 0, 1, 0, 1, 1], - [1, 0, 0, 0, 1, 0, 1, 1], - [1, 0, 1, 0, 0, 0, 0, 1], - [1, 1, 1, 1, 1, 1, 1, 1], - ] - elif maze_id == 'Maze2': - structure = [ - [0, 0, 0, 0, 0], - [0, 'r', 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - ] - # transfer maze - elif maze_id == 'Maze3': - structure = [ - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1], - [1, 0, 'r', 0, 0, 1, 0, 0, 0, 0, 1], - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], - [1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1], - [1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1], - [1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1], - [1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1], - [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1], - [1, 0, 0, 0, 0, 1, 0, 0, 0, 'g', 1], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - ] - elif maze_id == 'Maze4': - structure = [ - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1], - [1, 0, 'r', 0, 0, 1, 0, 0, 0, 0, 1], - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], - [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1], - [1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1], - [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1], - [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1], - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], - [1, 0, 0, 0, 0, 1, 0, 0, 0, 'g', 1], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - ] - elif maze_id == 'Push': - structure = [ - [1, 1, 1, 1, 1], - [1, 0, 'r', 1, 1], - [1, 0, Move.XY, 0, 1], - [1, 1, 0, 1, 1], - [1, 1, 1, 1, 1], - ] - elif maze_id == 'Fall': - structure = [ - [1, 1, 1, 1], - [1, 'r', 0, 1], - [1, 0, Move.YZ, 1], - [1, -1, -1, 1], - [1, 0, 0, 1], - [1, 1, 1, 1], - ] - elif maze_id == 'Block': - structure = [ - [1, 1, 1, 1, 1], - [1, 'r', 0, 0, 1], - [1, 0, 0, 0, 1], - [1, 0, 0, 0, 1], - [1, 1, 1, 1, 1], - ] - elif maze_id == 'BlockMaze': - structure = [ - [1, 1, 1, 1], - [1, 'r', 0, 1], - [1, 1, 0, 1], - [1, 0, 0, 1], - [1, 
1, 1, 1], - ] - else: - raise NotImplementedError( - 'The provided MazeId %s is not recognized' % maze_id) - - return structure - - -def line_intersect(pt1, pt2, ptA, ptB): - """ - Taken from https://www.cs.hmc.edu/ACM/lectures/intersections.html - this returns the intersection of Line(pt1,pt2) and Line(ptA,ptB) - """ - - DET_TOLERANCE = 0.00000001 - - # the first line is pt1 + r*(pt2-pt1) - # in component form: - x1, y1 = pt1 - x2, y2 = pt2 - dx1 = x2 - x1 - dy1 = y2 - y1 - - # the second line is ptA + s*(ptB-ptA) - x, y = ptA - xB, yB = ptB - dx = xB - x - dy = yB - y - - DET = (-dx1 * dy + dy1 * dx) - - if math.fabs(DET) < DET_TOLERANCE: - return (0, 0, 0, 0, 0) - - # now, the determinant should be OK - DETinv = 1.0 / DET - - # find the scalar amount along the "self" segment - r = DETinv * (-dy * (x - x1) + dx * (y - y1)) - - # find the scalar amount along the input line - s = DETinv * (-dy1 * (x - x1) + dx1 * (y - y1)) - - # return the average of the two descriptions - xi = (x1 + r * dx1 + x + s * dx) / 2.0 - yi = (y1 + r * dy1 + y + s * dy) / 2.0 - return (xi, yi, 1, r, s) - - -def ray_segment_intersect(ray, segment): - """ - Check if the ray originated from (x, y) with direction theta - intersects the line segment (x1, y1) -- (x2, y2), and return - the intersection point if there is one - """ - (x, y), theta = ray - # (x1, y1), (x2, y2) = segment - pt1 = (x, y) - len = 1 - pt2 = (x + len * math.cos(theta), y + len * math.sin(theta)) - xo, yo, valid, r, s = line_intersect(pt1, pt2, *segment) - if valid and r >= 0 and 0 <= s <= 1: - return (xo, yo) - return None - - -def point_distance(p1, p2): - x1, y1 = p1 - x2, y2 = p2 - return ((x1 - x2) ** 2 + (y1 - y2) ** 2) ** 0.5 diff --git a/examples/mujoco/runnable/mujoco/point.py b/examples/mujoco/runnable/mujoco/point.py deleted file mode 100644 index 2a6a08d41..000000000 --- a/examples/mujoco/runnable/mujoco/point.py +++ /dev/null @@ -1,93 +0,0 @@ -"""Wrapper for creating the ant environment in gym_mujoco.""" - -import math -import numpy as np -from gym import utils -from gym.envs.mujoco import mujoco_env - - -class PointEnv(mujoco_env.MujocoEnv, utils.EzPickle): - FILE = "point.xml" - ORI_IND = 2 - - def __init__(self, file_path=None, expose_all_qpos=True, noisy_init=False): - self._expose_all_qpos = expose_all_qpos - self.noisy_init = noisy_init - mujoco_env.MujocoEnv.__init__(self, file_path, 1) - utils.EzPickle.__init__(self) - - @property - def physics(self): - return self.model - - def _step(self, a): - return self.step(a) - - def step(self, action): - # action[0] is velocity, action[1] is direction - action[0] = 0.2 * action[0] - qpos = np.copy(self.data.qpos) - qpos[2] += action[1] - ori = qpos[2] - # compute increment in each direction - dx = math.cos(ori) * action[0] - dy = math.sin(ori) * action[0] - # ensure that the robot is within reasonable range - qpos[0] = np.clip(qpos[0] + dx, -100, 100) - qpos[1] = np.clip(qpos[1] + dy, -100, 100) - qvel = np.squeeze(self.data.qvel) - self.set_state(qpos, qvel) - for _ in range(0, self.frame_skip): - # self.physics.step() - self.sim.step() - next_obs = self._get_obs() - reward = 0 - done = False - info = {} - return next_obs, reward, done, info - - def _get_obs(self): - if self._expose_all_qpos: - return np.concatenate([ - self.data.qpos.flat[:3], # Only point-relevant coords. 
- self.data.qvel.flat[:3]]) - return np.concatenate([ - self.data.qpos.flat[2:3], - self.data.qvel.flat[:3]]) - - def reset_model(self): - if self.noisy_init: - qpos = self.init_qpos + self.np_random.uniform( - size=self.model.nq, low=-.1, high=.1) - qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1 - - else: - qpos = self.init_qpos - qvel = self.init_qvel - - # Set everything other than point to original position and 0 velocity. - qpos[3:] = self.init_qpos[3:] - qvel[3:] = 0. - self.set_state(qpos, qvel) - return self._get_obs() - - def get_ori(self): - return self.data.qpos[self.__class__.ORI_IND] - - def set_xy(self, xy): - qpos = np.copy(self.data.qpos) - qpos[0] = xy[0] - qpos[1] = xy[1] - - qvel = self.data.qvel - self.set_state(qpos, qvel) - - def get_xy(self): - qpos = np.copy(self.data.qpos) - return qpos[:2] - - def viewer_setup(self): - - self.viewer.cam.trackbodyid = -1 - self.viewer.cam.distance = 80 - self.viewer.cam.elevation = -90 diff --git a/examples/mujoco/runnable/mujoco/point_maze_env.py b/examples/mujoco/runnable/mujoco/point_maze_env.py deleted file mode 100644 index c8e8ef84b..000000000 --- a/examples/mujoco/runnable/mujoco/point_maze_env.py +++ /dev/null @@ -1,568 +0,0 @@ -"""Adapted from rllab maze_env.py.""" - -import os -import tempfile -import xml.etree.ElementTree as ET -import math -import numpy as np -import gym -from . import maze_env_utils -from .point import PointEnv -from gym.utils import seeding - -# Directory that contains mujoco xml files. -MODEL_DIR = os.path.join(os.path.dirname(__file__), 'assets') - - -class PointMazeEnv(gym.Env): - MODEL_CLASS = PointEnv - - MAZE_HEIGHT = None - MAZE_SIZE_SCALING = None - - def __init__( - self, - maze_id=None, - maze_height=0.5, - maze_size_scaling=8, - n_bins=0, - sensor_range=3., - sensor_span=2 * math.pi, - observe_blocks=False, - put_spin_near_agent=False, - top_down_view=False, - manual_collision=False, - goal=None, - EPS=0.25, - max_episode_steps=2000, - *args, - **kwargs): - self._maze_id = maze_id - - model_cls = self.__class__.MODEL_CLASS - if model_cls is None: - raise "MODEL_CLASS unspecified!" - xml_path = os.path.join(MODEL_DIR, model_cls.FILE) - self.tree = tree = ET.parse(xml_path) - self.worldbody = worldbody = tree.find(".//worldbody") - self.visualize_goal = False - self.max_episode_steps = max_episode_steps - self.t = 0 - self.MAZE_HEIGHT = height = maze_height - self.MAZE_SIZE_SCALING = size_scaling = maze_size_scaling - self._n_bins = n_bins - self._sensor_range = sensor_range * size_scaling - self._sensor_span = sensor_span - self._observe_blocks = observe_blocks - self._put_spin_near_agent = put_spin_near_agent - self._top_down_view = top_down_view - self._manual_collision = manual_collision - - self.MAZE_STRUCTURE = structure = maze_env_utils.construct_maze( - maze_id=self._maze_id) - # Elevate the maze to allow for falling. - self.elevated = any(-1 in row for row in structure) - self.blocks = any( - any(maze_env_utils.can_move(r) for r in row) - for row in structure) # Are there any movable blocks? - - torso_x, torso_y = self._find_robot() # x, y coordinates - self._init_torso_x = torso_x - self._init_torso_y = torso_y - self._init_positions = [ - (x - torso_x, y - torso_y) - for x, y in self._find_all_robots()] - - self._view = np.zeros([5, 5, 3]) - - height_offset = 0. 
- if self.elevated: - height_offset = height * size_scaling - torso = tree.find(".//body[@name='torso']") - torso.set('pos', '0 0 %.2f' % (0.75 + height_offset)) - if self.blocks: - default = tree.find(".//default") - default.find('.//geom').set('solimp', '.995 .995 .01') - - self.movable_blocks = [] - for i in range(len(structure)): - for j in range(len(structure[0])): - struct = structure[i][j] - if struct == 'r' and self._put_spin_near_agent: - struct = maze_env_utils.Move.SpinXY - if self.elevated and struct not in [-1]: - # Create elevated platform. - ET.SubElement( - worldbody, "geom", - name="elevated_%d_%d" % (i, j), - pos="%f %f %f" % (j * size_scaling - torso_x, - i * size_scaling - torso_y, - height / 2 * size_scaling), - size="%f %f %f" % (0.5 * size_scaling, - 0.5 * size_scaling, - height / 2 * size_scaling), - type="box", - material="", - contype="1", - conaffinity="1", - rgba="0.9 0.9 0.9 1", - ) - if struct == 1: # Unmovable block. - # Offset all coordinates so that robot starts at the origin - ET.SubElement( - worldbody, "geom", - name="block_%d_%d" % (i, j), - pos="%f %f %f" % (j * size_scaling - torso_x, - i * size_scaling - torso_y, - height_offset + - height / 2 * size_scaling), - size="%f %f %f" % (0.5 * size_scaling, - 0.5 * size_scaling, - height / 2 * size_scaling), - type="box", - material="", - contype="1", - conaffinity="1", - rgba="0.4 0.4 0.4 1", - ) - elif maze_env_utils.can_move(struct): - name = "movable_%d_%d" % (i, j) - self.movable_blocks.append((name, struct)) - falling = maze_env_utils.can_move_z(struct) - spinning = maze_env_utils.can_spin(struct) - x_offset = 0.25 * size_scaling if spinning else 0.0 - y_offset = 0.0 - shrink = 0.1 if spinning else 0.99 if falling else 1.0 - height_shrink = 0.1 if spinning else 1.0 - _x = j * size_scaling - torso_x + x_offset - _y = i * size_scaling - torso_y + y_offset - _z = height / 2 * size_scaling * height_shrink - movable_body = ET.SubElement( - worldbody, "body", - name=name, - pos="%f %f %f" % (_x, _y, height_offset + _z), - ) - ET.SubElement( - movable_body, "geom", - name="block_%d_%d" % (i, j), - pos="0 0 0", - size="%f %f %f" % (0.5 * size_scaling * shrink, - 0.5 * size_scaling * shrink, - _z), - type="box", - material="", - mass="0.001" if falling else "0.0002", - contype="1", - conaffinity="1", - rgba="0.9 0.1 0.1 1" - ) - if maze_env_utils.can_move_x(struct): - ET.SubElement( - movable_body, "joint", - armature="0", - axis="1 0 0", - damping="0.0", - limited="true" if falling else "false", - range="%f %f" % (-size_scaling, size_scaling), - margin="0.01", - name="movable_x_%d_%d" % (i, j), - pos="0 0 0", - type="slide" - ) - if maze_env_utils.can_move_y(struct): - ET.SubElement( - movable_body, "joint", - armature="0", - axis="0 1 0", - damping="0.0", - limited="true" if falling else "false", - range="%f %f" % (-size_scaling, size_scaling), - margin="0.01", - name="movable_y_%d_%d" % (i, j), - pos="0 0 0", - type="slide" - ) - if maze_env_utils.can_move_z(struct): - ET.SubElement( - movable_body, "joint", - armature="0", - axis="0 0 1", - damping="0.0", - limited="true", - range="%f 0" % (-height_offset), - margin="0.01", - name="movable_z_%d_%d" % (i, j), - pos="0 0 0", - type="slide" - ) - if maze_env_utils.can_spin(struct): - ET.SubElement( - movable_body, "joint", - armature="0", - axis="0 0 1", - damping="0.0", - limited="false", - name="spinable_%d_%d" % (i, j), - pos="0 0 0", - type="ball" - ) - - torso = tree.find(".//body[@name='torso']") - geoms = torso.findall(".//geom") - for geom in geoms: - 
if 'name' not in geom.attrib: - raise Exception("Every geom of the torso must have a name " - "defined") - - _, file_path = tempfile.mkstemp(text=True, suffix='.xml') - tree.write(file_path) - - self.wrapped_env = model_cls(*args, file_path=file_path, **kwargs) - self.args = args - self.kwargs = kwargs - self.GOAL = goal - if self.GOAL is not None: - self.GOAL = self.unwrapped._rowcol_to_xy(*self.GOAL) - self.EPS = EPS - - def get_ori(self): - return self.wrapped_env.get_ori() - - def get_top_down_view(self): - self._view = np.zeros_like(self._view) - - def valid(row, col): - return self._view.shape[0] > row >= 0 \ - and self._view.shape[1] > col >= 0 - - def update_view(x, y, d, row=None, col=None): - if row is None or col is None: - x = x - self._robot_x - y = y - self._robot_y - - row, col = self._xy_to_rowcol(x, y) - update_view(x, y, d, row=row, col=col) - return - - row, row_frac, col, col_frac = int(row), row % 1, int(col), col % 1 - if row_frac < 0: - row_frac += 1 - if col_frac < 0: - col_frac += 1 - - if valid(row, col): - self._view[row, col, d] += ( - (min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) * - (min(1., col_frac + 0.5) - max(0., col_frac - 0.5))) - if valid(row - 1, col): - self._view[row - 1, col, d] += ( - (max(0., 0.5 - row_frac)) * - (min(1., col_frac + 0.5) - max(0., col_frac - 0.5))) - if valid(row + 1, col): - self._view[row + 1, col, d] += ( - (max(0., row_frac - 0.5)) * - (min(1., col_frac + 0.5) - max(0., col_frac - 0.5))) - if valid(row, col - 1): - self._view[row, col - 1, d] += ( - (min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) * - (max(0., 0.5 - col_frac))) - if valid(row, col + 1): - self._view[row, col + 1, d] += ( - (min(1., row_frac + 0.5) - max(0., row_frac - 0.5)) * - (max(0., col_frac - 0.5))) - if valid(row - 1, col - 1): - self._view[row - 1, col - 1, d] += ( - (max(0., 0.5 - row_frac)) * max(0., 0.5 - col_frac)) - if valid(row - 1, col + 1): - self._view[row - 1, col + 1, d] += ( - (max(0., 0.5 - row_frac)) * max(0., col_frac - 0.5)) - if valid(row + 1, col + 1): - self._view[row + 1, col + 1, d] += ( - (max(0., row_frac - 0.5)) * max(0., col_frac - 0.5)) - if valid(row + 1, col - 1): - self._view[row + 1, col - 1, d] += ( - (max(0., row_frac - 0.5)) * max(0., 0.5 - col_frac)) - - # Draw ant. - robot_x, robot_y = self.wrapped_env.get_body_com("torso")[:2] - self._robot_x = robot_x - self._robot_y = robot_y - self._robot_ori = self.get_ori() - - structure = self.MAZE_STRUCTURE - size_scaling = self.MAZE_SIZE_SCALING - - # Draw immovable blocks and chasms. - for i in range(len(structure)): - for j in range(len(structure[0])): - if structure[i][j] == 1: # Wall. - update_view(j * size_scaling - self._init_torso_x, - i * size_scaling - self._init_torso_y, - 0) - if structure[i][j] == -1: # Chasm. - update_view(j * size_scaling - self._init_torso_x, - i * size_scaling - self._init_torso_y, - 1) - - # Draw movable blocks. 
- for block_name, block_type in self.movable_blocks: - block_x, block_y = self.wrapped_env.get_body_com(block_name)[:2] - update_view(block_x, block_y, 2) - - import cv2 - cv2.imshow('x.jpg', cv2.resize( - np.uint8(self._view * 255), (512, 512), - interpolation=cv2.INTER_CUBIC)) - cv2.waitKey(0) - - return self._view - - def get_range_sensor_obs(self): - """Returns egocentric range sensor observations of maze.""" - robot_x, robot_y, robot_z = self.wrapped_env.get_body_com("torso")[:3] - ori = self.get_ori() - - structure = self.MAZE_STRUCTURE - size_scaling = self.MAZE_SIZE_SCALING - height = self.MAZE_HEIGHT - - segments = [] - # Get line segments (corresponding to outer boundary) of each immovable - # block or drop-off. - for i in range(len(structure)): - for j in range(len(structure[0])): - if structure[i][j] in [1, -1]: # There's a wall or drop-off. - cx = j * size_scaling - self._init_torso_x - cy = i * size_scaling - self._init_torso_y - x1 = cx - 0.5 * size_scaling - x2 = cx + 0.5 * size_scaling - y1 = cy - 0.5 * size_scaling - y2 = cy + 0.5 * size_scaling - struct_segments = [ - ((x1, y1), (x2, y1)), - ((x2, y1), (x2, y2)), - ((x2, y2), (x1, y2)), - ((x1, y2), (x1, y1)), - ] - for seg in struct_segments: - segments.append(dict( - segment=seg, - type=structure[i][j], - )) - - for block_name, block_type in self.movable_blocks: - block_x, block_y, block_z = \ - self.wrapped_env.get_body_com(block_name)[:3] - if (block_z + height * size_scaling / 2 >= robot_z and - robot_z >= block_z - height * size_scaling / 2): - # Block in view. - x1 = block_x - 0.5 * size_scaling - x2 = block_x + 0.5 * size_scaling - y1 = block_y - 0.5 * size_scaling - y2 = block_y + 0.5 * size_scaling - struct_segments = [ - ((x1, y1), (x2, y1)), - ((x2, y1), (x2, y2)), - ((x2, y2), (x1, y2)), - ((x1, y2), (x1, y1)), - ] - for seg in struct_segments: - segments.append(dict( - segment=seg, - type=block_type, - )) - - # 3 for wall, drop-off, block - sensor_readings = np.zeros((self._n_bins, 3)) - for ray_idx in range(self._n_bins): - ray_ori = (ori - self._sensor_span * 0.5 + ( - 2 * ray_idx + 1.0) / - (2 * self._n_bins) * self._sensor_span) - ray_segments = [] - # Get all segments that intersect with ray. - for seg in segments: - p = maze_env_utils.ray_segment_intersect( - ray=((robot_x, robot_y), ray_ori), - segment=seg["segment"]) - if p is not None: - ray_segments.append(dict( - segment=seg["segment"], - type=seg["type"], - ray_ori=ray_ori, - distance=maze_env_utils.point_distance( - p, (robot_x, robot_y)), - )) - if len(ray_segments) > 0: - # Find out which segment is intersected first. - first_seg = sorted( - ray_segments, key=lambda x: x["distance"])[0] - seg_type = first_seg["type"] - idx = (0 if seg_type == 1 else # Wall. - 1 if seg_type == -1 else # Drop-off. - 2 if maze_env_utils.can_move(seg_type) else # Block. 
- None) - if first_seg["distance"] <= self._sensor_range: - sensor_readings[ray_idx][idx] = \ - (self._sensor_range - first_seg[ - "distance"]) / self._sensor_range - return sensor_readings - - def _get_obs(self): - wrapped_obs = self.wrapped_env._get_obs() - if self._top_down_view: - self.get_top_down_view() - - if self._observe_blocks: - additional_obs = [] - for block_name, block_type in self.movable_blocks: - additional_obs.append( - self.wrapped_env.get_body_com(block_name)) - wrapped_obs = np.concatenate([wrapped_obs[:3]] + additional_obs + - [wrapped_obs[3:]]) - - self.get_range_sensor_obs() - return wrapped_obs - - def seed(self, seed=None): - self.np_random, seed = seeding.np_random(seed) - return [seed] - - def reset(self, goal=None): - self.goal = goal - - if self.visualize_goal: # remove the prev goal and add a new goal - goal_x, goal_y = goal[0], goal[1] - size_scaling = self.MAZE_SIZE_SCALING - # remove the original goal - try: - self.worldbody.remove(self.goal_element) - except AttributeError: - pass - # offset all coordinates so that robot starts at the origin - self.goal_element = \ - ET.SubElement( - self.worldbody, "geom", - name="goal_%d_%d" % (goal_x, goal_y), - pos="%f %f %f" % (goal_x, - goal_y, - self.MAZE_HEIGHT / 2 * size_scaling), - # smaller than the block to prevent collision - size="%f %f %f" % (0.1 * size_scaling, - 0.1 * size_scaling, - self.MAZE_HEIGHT / 2 * size_scaling), - type="box", - material="", - contype="1", - conaffinity="1", - rgba="1.0 0.0 0.0 0.5" - ) - # Note: running the lines below will make the robot position wrong! - # (because the graph is rebuilt) - torso = self.tree.find(".//body[@name='torso']") - geoms = torso.findall(".//geom") - for geom in geoms: - if 'name' not in geom.attrib: - raise Exception("Every geom of the torso must have a name " - "defined") - _, file_path = tempfile.mkstemp(text=True, suffix='.xml') - self.tree.write(file_path) - # here we write a temporal file with the robot specifications. - # Why not the original one?? - - model_cls = self.__class__.MODEL_CLASS - # file to the robot specifications; model_cls is AntEnv - self.wrapped_env = model_cls( - *self.args, file_path=file_path, **self.kwargs) - - self.t = 0 - self.trajectory = [] - self.wrapped_env.reset() - if len(self._init_positions) > 1: - xy = self._init_positions[self.np_random.randint( - len(self._init_positions))] - self.wrapped_env.set_xy(xy) - return self._get_obs() - - @property - def viewer(self): - return self.wrapped_env.viewer - - def render(self, *args, **kwargs): - return self.wrapped_env.render(*args, **kwargs) - - @property - def observation_space(self): - shape = self._get_obs().shape - high = np.inf * np.ones(shape) - low = -high - return gym.spaces.Box(low, high) - - @property - def action_space(self): - return self.wrapped_env.action_space - - def _find_robot(self): - structure = self.MAZE_STRUCTURE - size_scaling = self.MAZE_SIZE_SCALING - for i in range(len(structure)): - for j in range(len(structure[0])): - if structure[i][j] == 'r': - return j * size_scaling, i * size_scaling - assert False, 'No robot in maze specification.' 
-
-    def _find_all_robots(self):
-        structure = self.MAZE_STRUCTURE
-        size_scaling = self.MAZE_SIZE_SCALING
-        coords = []
-        for i in range(len(structure)):
-            for j in range(len(structure[0])):
-                if structure[i][j] == 'r':
-                    coords.append((j * size_scaling, i * size_scaling))
-        return coords
-
-    def _is_in_collision(self, pos):
-        x, y = pos
-        structure = self.MAZE_STRUCTURE
-        scale = self.MAZE_SIZE_SCALING
-        for i in range(len(structure)):
-            for j in range(len(structure[0])):
-                if structure[i][j] == 1:
-                    minx = j * scale - scale * 0.5 - self._init_torso_x
-                    maxx = j * scale + scale * 0.5 - self._init_torso_x
-                    miny = i * scale - scale * 0.5 - self._init_torso_y
-                    maxy = i * scale + scale * 0.5 - self._init_torso_y
-                    if minx <= x <= maxx and miny <= y <= maxy:
-                        return True
-        return False
-
-    def _rowcol_to_xy(self, j, i):
-        scale = self.MAZE_SIZE_SCALING
-        minx = j * scale - scale * 0.5 - self._init_torso_x
-        maxx = j * scale + scale * 0.5 - self._init_torso_x
-        miny = i * scale - scale * 0.5 - self._init_torso_y
-        maxy = i * scale + scale * 0.5 - self._init_torso_y
-        return (minx + maxx) / 2, (miny + maxy) / 2
-
-    def step(self, action):
-        self.t += 1
-        if self._manual_collision:
-            old_pos = self.wrapped_env.get_xy()
-            inner_next_obs, inner_reward, inner_done, info = \
-                self.wrapped_env.step(action)
-            new_pos = self.wrapped_env.get_xy()
-            if self._is_in_collision(new_pos):
-                self.wrapped_env.set_xy(old_pos)
-        else:
-            inner_next_obs, inner_reward, inner_done, info = \
-                self.wrapped_env.step(action)
-        next_obs = self._get_obs()
-        done = False
-        if self.goal is not None:
-            done = bool(((next_obs[:2] - self.goal[:2]) ** 2).sum() < self.EPS)
-
-        new_pos = self.wrapped_env.get_xy()
-        if self._is_in_collision(new_pos) or inner_done:
-            done = True
-        if self.t >= self.max_episode_steps:
-            done = True
-        return next_obs, inner_reward, done, info
diff --git a/examples/mujoco/runnable/mujoco/register.py b/examples/mujoco/runnable/mujoco/register.py
deleted file mode 100644
index 82acac2af..000000000
--- a/examples/mujoco/runnable/mujoco/register.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from gym.envs.registration import register
-
-
-def reg():
-    register(
-        id='PointMaze-v0',
-        entry_point='mujoco.point_maze_env:PointMazeEnv',
-        kwargs={
-            "maze_size_scaling": 4,
-            "maze_id": "Maze2",
-            "maze_height": 0.5,
-            "manual_collision": True,
-            "goal": (1, 3),
-        }
-    )
-
-    register(
-        id='PointMaze-v1',
-        entry_point='mujoco.point_maze_env:PointMazeEnv',
-        kwargs={
-            "maze_size_scaling": 2,
-            "maze_id": "Maze2",
-            "maze_height": 0.5,
-            "manual_collision": True,
-            "goal": (1, 3),
-        }
-    )
diff --git a/examples/mujoco/runnable/point_maze_td3.py b/examples/mujoco/runnable/point_maze_td3.py
deleted file mode 100644
index dbb612fc8..000000000
--- a/examples/mujoco/runnable/point_maze_td3.py
+++ /dev/null
@@ -1,120 +0,0 @@
-import gym
-import torch
-import pprint
-import argparse
-import numpy as np
-from torch.utils.tensorboard import SummaryWriter
-
-from tianshou.policy import TD3Policy
-from tianshou.utils.net.common import Net
-from tianshou.env import SubprocVectorEnv
-from tianshou.exploration import GaussianNoise
-from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, ReplayBuffer
-from tianshou.utils.net.continuous import Actor, Critic
-
-from mujoco.register import reg
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--task', type=str, default='PointMaze-v1')
-    parser.add_argument('--seed', type=int, default=1626)
-    parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--actor-lr', type=float, default=3e-5)
-    parser.add_argument('--critic-lr', type=float, default=1e-4)
-    parser.add_argument('--gamma', type=float, default=0.99)
-    parser.add_argument('--tau', type=float, default=0.005)
-    parser.add_argument('--exploration-noise', type=float, default=0.1)
-    parser.add_argument('--policy-noise', type=float, default=0.2)
-    parser.add_argument('--noise-clip', type=float, default=0.5)
-    parser.add_argument('--update-actor-freq', type=int, default=2)
-    parser.add_argument('--epoch', type=int, default=100)
-    parser.add_argument('--step-per-epoch', type=int, default=2400)
-    parser.add_argument('--collect-per-step', type=int, default=10)
-    parser.add_argument('--batch-size', type=int, default=128)
-    parser.add_argument('--hidden-sizes', type=int,
-                        nargs='*', default=[128, 128])
-    parser.add_argument('--training-num', type=int, default=8)
-    parser.add_argument('--test-num', type=int, default=100)
-    parser.add_argument('--logdir', type=str, default='log')
-    parser.add_argument('--render', type=float, default=0.)
-    parser.add_argument(
-        '--device', type=str,
-        default='cuda' if torch.cuda.is_available() else 'cpu')
-    return parser.parse_args()
-
-
-def test_td3(args=get_args()):
-    reg()
-    env = gym.make(args.task)
-    args.state_shape = env.observation_space.shape or env.observation_space.n
-    args.action_shape = env.action_space.shape or env.action_space.n
-    args.max_action = env.action_space.high[0]
-    # train_envs = gym.make(args.task)
-    train_envs = SubprocVectorEnv(
-        [lambda: gym.make(args.task) for _ in range(args.training_num)])
-    # test_envs = gym.make(args.task)
-    test_envs = SubprocVectorEnv(
-        [lambda: gym.make(args.task) for _ in range(args.test_num)])
-    # seed
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    train_envs.seed(args.seed)
-    test_envs.seed(args.seed)
-    # model
-    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
-              device=args.device)
-    actor = Actor(net, args.action_shape, max_action=args.max_action,
-                  device=args.device).to(args.device)
-    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    net = Net(args.state_shape, args.action_shape,
-              hidden_sizes=args.hidden_sizes, concat=True, device=args.device)
-    critic1 = Critic(net, device=args.device).to(args.device)
-    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    net = Net(args.state_shape, args.action_shape,
-              hidden_sizes=args.hidden_sizes, concat=True, device=args.device)
-    critic2 = Critic(net, device=args.device).to(args.device)
-    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
-    policy = TD3Policy(
-        actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        action_range=[env.action_space.low[0], env.action_space.high[0]],
-        tau=args.tau, gamma=args.gamma,
-        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
-        policy_noise=args.policy_noise,
-        update_actor_freq=args.update_actor_freq,
-        noise_clip=args.noise_clip,
-        reward_normalization=True, ignore_done=True)
-    # collector
-    train_collector = Collector(
-        policy, train_envs, ReplayBuffer(args.buffer_size))
-    test_collector = Collector(policy, test_envs)
-    # train_collector.collect(n_step=args.buffer_size)
-    # log
-    writer = SummaryWriter(args.logdir + '/' + 'td3')
-
-    def stop_fn(mean_rewards):
-        if env.spec.reward_threshold:
-            return mean_rewards >= env.spec.reward_threshold
-        else:
-            return False
-
-    # trainer
-    result = offpolicy_trainer(
-        policy, train_collector, test_collector, args.epoch,
-        args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, stop_fn=stop_fn, writer=writer)
-    assert stop_fn(result['best_reward'])
-    if __name__ == '__main__':
-        pprint.pprint(result)
-        # Let's watch its performance!
-        policy.eval()
-        test_envs.seed(args.seed)
-        test_collector.reset()
-        result = test_collector.collect(n_episode=[1] * args.test_num,
-                                        render=args.render)
-        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
-
-
-if __name__ == '__main__':
-    test_td3()
diff --git a/setup.cfg b/setup.cfg
index 0a4742891..d485e6d06 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,3 +1,14 @@
+[flake8]
+exclude =
+    .git
+    log
+    __pycache__
+    docs
+    build
+    dist
+    *.egg-info
+max-line-length = 87
+
 [mypy]
 files = tianshou/**/*.py
 allow_redefinition = True
diff --git a/test/base/env.py b/test/base/env.py
index f0907f8e1..e71957a7e 100644
--- a/test/base/env.py
+++ b/test/base/env.py
@@ -10,15 +10,16 @@ class MyTestEnv(gym.Env):
     """

     def __init__(self, size, sleep=0, dict_state=False, recurse_state=False,
-                 ma_rew=0, multidiscrete_action=False, random_sleep=False):
-        assert not (
-            dict_state and recurse_state), \
-            "dict_state and recurse_state cannot both be true"
+                 ma_rew=0, multidiscrete_action=False, random_sleep=False,
+                 array_state=False):
+        assert dict_state + recurse_state + array_state <= 1, \
+            "dict_state / recurse_state / array_state can be only one true"
         self.size = size
         self.sleep = sleep
         self.random_sleep = random_sleep
         self.dict_state = dict_state
         self.recurse_state = recurse_state
+        self.array_state = array_state
         self.ma_rew = ma_rew
         self._md_action = multidiscrete_action
         # how many steps this env has stepped
@@ -36,6 +37,8 @@ def __init__(self, size, sleep=0, dict_state=False, recurse_state=False,
                     "rand": Box(shape=(1, 2), low=0, high=1,
                                 dtype=np.float64)})
             })
+        elif array_state:
+            self.observation_space = Box(shape=(4, 84, 84), low=0, high=255)
         else:
             self.observation_space = Box(shape=(1, ), low=0, high=size - 1)
         if multidiscrete_action:
@@ -72,6 +75,13 @@ def _get_state(self):
                 'dict': {"tuple": (np.array([1], dtype=np.int64),
                                    self.rng.rand(2)),
                          "rand": self.rng.rand(1, 2)}}
+        elif self.array_state:
+            img = np.zeros([4, 84, 84], np.int)
+            img[3, np.arange(84), np.arange(84)] = self.index
+            img[2, np.arange(84)] = self.index
+            img[1, :, np.arange(84)] = self.index
+            img[0] = self.index
+            return img
         else:
             return np.array([self.index], dtype=np.float32)

diff --git a/test/base/test_batch.py b/test/base/test_batch.py
index 4553edff7..0898e154c 100644
--- a/test/base/test_batch.py
+++ b/test/base/test_batch.py
@@ -21,6 +21,8 @@ def test_batch():
     assert not Batch(a=[1, 2, 3]).is_empty()
     b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None])
     assert b.c.dtype == np.object
+    b = Batch(d=[None], e=[starmap], f=Batch)
+    assert b.d.dtype == b.e.dtype == np.object and b.f == Batch
     b = Batch()
     b.update()
     assert b.is_empty()
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index 1695195b4..225375d0a 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -1,15 +1,18 @@
 import os
+import h5py
 import torch
 import pickle
 import pytest
 import tempfile
-import h5py
 import numpy as np
 from timeit import timeit

-from tianshou.data import Batch, SegmentTree, \
-    ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
 from tianshou.data.utils.converter import to_hdf5
+from tianshou.data import Batch, SegmentTree, ReplayBuffer
+from tianshou.data import PrioritizedReplayBuffer
+from tianshou.data import
VectorReplayBuffer, CachedReplayBuffer +from tianshou.data import PrioritizedVectorReplayBuffer + if __name__ == '__main__': from env import MyTestEnv @@ -26,46 +29,77 @@ def test_replaybuffer(size=10, bufsize=20): action_list = [1] * 5 + [0] * 10 + [1] * 10 for i, a in enumerate(action_list): obs_next, rew, done, info = env.step(a) - buf.add(obs, [a], rew, done, obs_next, info) + buf.add(Batch(obs=obs, act=[a], rew=rew, + done=done, obs_next=obs_next, info=info)) obs = obs_next assert len(buf) == min(bufsize, i + 1) - with pytest.raises(ValueError): - buf._add_to_buffer('rew', np.array([1, 2, 3])) - assert buf.act.dtype == np.object - assert isinstance(buf.act[0], list) + assert buf.act.dtype == np.int + assert buf.act.shape == (bufsize, 1) data, indice = buf.sample(bufsize * 2) assert (indice < len(buf)).all() assert (data.obs < size).all() assert (0 <= data.done).all() and (data.done <= 1).all() b = ReplayBuffer(size=10) - b.add(1, 1, 1, 'str', 1, {'a': 3, 'b': {'c': 5.0}}) + # neg bsz should return empty index + assert b.sample_index(-1).tolist() == [] + ptr, ep_rew, ep_len, ep_idx = b.add( + Batch(obs=1, act=1, rew=1, done=1, obs_next='str', + info={'a': 3, 'b': {'c': 5.0}})) assert b.obs[0] == 1 - assert b.done[0] == 'str' + assert b.done[0] + assert b.obs_next[0] == 'str' assert np.all(b.obs[1:] == 0) - assert np.all(b.done[1:] == np.array(None)) + assert np.all(b.obs_next[1:] == np.array(None)) assert b.info.a[0] == 3 and b.info.a.dtype == np.integer assert np.all(b.info.a[1:] == 0) assert b.info.b.c[0] == 5.0 and b.info.b.c.dtype == np.inexact assert np.all(b.info.b.c[1:] == 0.0) + assert ptr.shape == (1,) and ptr[0] == 0 + assert ep_rew.shape == (1,) and ep_rew[0] == 1 + assert ep_len.shape == (1,) and ep_len[0] == 1 + assert ep_idx.shape == (1,) and ep_idx[0] == 0 + # test extra keys pop up, the buffer should handle it dynamically + batch = Batch(obs=2, act=2, rew=2, done=0, obs_next="str2", + info={"a": 4, "d": {"e": -np.inf}}) + b.add(batch) + info_keys = ["a", "b", "d"] + assert set(b.info.keys()) == set(info_keys) + assert b.info.a[1] == 4 and b.info.b.c[1] == 0 + assert b.info.d.e[1] == -np.inf + # test batch-style adding method, where len(batch) == 1 + batch.done = [1] + batch.info.e = np.zeros([1, 4]) + batch = Batch.stack([batch]) + ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0]) + assert ptr.shape == (1,) and ptr[0] == 2 + assert ep_rew.shape == (1,) and ep_rew[0] == 4 + assert ep_len.shape == (1,) and ep_len[0] == 2 + assert ep_idx.shape == (1,) and ep_idx[0] == 1 + assert set(b.info.keys()) == set(info_keys + ["e"]) + assert b.info.e.shape == (b.maxsize, 1, 4) with pytest.raises(IndexError): b[22] - b = ListReplayBuffer() - with pytest.raises(NotImplementedError): - b.sample(0) + # test prev / next + assert np.all(b.prev(np.array([0, 1, 2])) == [0, 1, 1]) + assert np.all(b.next(np.array([0, 1, 2])) == [0, 2, 2]) + batch.done = [0] + b.add(batch, buffer_ids=[0]) + assert np.all(b.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3]) + assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3]) def test_ignore_obs_next(size=10): # Issue 82 buf = ReplayBuffer(size, ignore_obs_next=True) for i in range(size): - buf.add(obs={'mask1': np.array([i, 1, 1, 0, 0]), - 'mask2': np.array([i + 4, 0, 1, 0, 0]), - 'mask': i}, - act={'act_id': i, - 'position_id': i + 3}, - rew=i, - done=i % 3 == 0, - info={'if': i}) + buf.add(Batch(obs={'mask1': np.array([i, 1, 1, 0, 0]), + 'mask2': np.array([i + 4, 0, 1, 0, 0]), + 'mask': i}, + act={'act_id': i, + 'position_id': i + 3}, + rew=i, 
+ done=i % 3 == 0, + info={'if': i})) indice = np.arange(len(buf)) orig = np.arange(len(buf)) data = buf[indice] @@ -91,7 +125,7 @@ def test_ignore_obs_next(size=10): assert data.obs_next -def test_stack(size=5, bufsize=9, stack_num=4): +def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3): env = MyTestEnv(size) buf = ReplayBuffer(bufsize, stack_num=stack_num) buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True) @@ -99,9 +133,10 @@ def test_stack(size=5, bufsize=9, stack_num=4): obs = env.reset(1) for i in range(16): obs_next, rew, done, info = env.step(1) - buf.add(obs, 1, rew, done, None, info) - buf2.add(obs, 1, rew, done, None, info) - buf3.add([None, None, obs], 1, rew, done, [None, obs], info) + buf.add(Batch(obs=obs, act=1, rew=rew, done=done, info=info)) + buf2.add(Batch(obs=obs, act=1, rew=rew, done=done, info=info)) + buf3.add(Batch(obs=[obs, obs, obs], act=1, rew=rew, + done=done, obs_next=[obs, obs], info=info)) obs = obs_next if done: obs = env.reset(1) @@ -115,7 +150,9 @@ def test_stack(size=5, bufsize=9, stack_num=4): _, indice = buf2.sample(0) assert indice.tolist() == [2, 6] _, indice = buf2.sample(1) - assert indice in [2, 6] + assert indice[0] in [2, 6] + batch, indice = buf2.sample(-1) # neg bsz -> no data + assert indice.tolist() == [] and len(batch) == 0 with pytest.raises(IndexError): buf[bufsize * 2] @@ -123,11 +160,16 @@ def test_stack(size=5, bufsize=9, stack_num=4): def test_priortized_replaybuffer(size=32, bufsize=15): env = MyTestEnv(size) buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5) + buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5) obs = env.reset() action_list = [1] * 5 + [0] * 10 + [1] * 10 for i, a in enumerate(action_list): obs_next, rew, done, info = env.step(a) - buf.add(obs, a, rew, done, obs_next, info, np.random.randn() - 0.5) + batch = Batch(obs=obs, act=a, rew=rew, done=done, obs_next=obs_next, + info=info, policy=np.random.randn() - 0.5) + batch_stack = Batch.stack([batch, batch, batch]) + buf.add(Batch.stack([batch]), buffer_ids=[0]) + buf2.add(batch_stack, buffer_ids=[0, 1, 2]) obs = obs_next data, indice = buf.sample(len(buf) // 2) if len(buf) // 2 == 0: @@ -135,23 +177,39 @@ def test_priortized_replaybuffer(size=32, bufsize=15): else: assert len(data) == len(buf) // 2 assert len(buf) == min(bufsize, i + 1) + assert len(buf2) == min(bufsize, 3 * (i + 1)) + # check single buffer's data + assert buf.info.key.shape == (buf.maxsize,) + assert buf.rew.dtype == np.float + assert buf.done.dtype == np.bool_ data, indice = buf.sample(len(buf) // 2) buf.update_weight(indice, -data.weight / 2) - assert np.allclose( - buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha) + assert np.allclose(buf.weight[indice], np.abs(-data.weight / 2) ** buf._alpha) + # check multi buffer's data + assert np.allclose(buf2[np.arange(buf2.maxsize)].weight, 1) + batch, indice = buf2.sample(10) + buf2.update_weight(indice, batch.weight * 0) + weight = buf2[np.arange(buf2.maxsize)].weight + mask = np.isin(np.arange(buf2.maxsize), indice) + assert np.all(weight[mask] == weight[mask][0]) + assert np.all(weight[~mask] == weight[~mask][0]) + assert weight[~mask][0] < weight[mask][0] and weight[mask][0] < 1 def test_update(): buf1 = ReplayBuffer(4, stack_num=2) buf2 = ReplayBuffer(4, stack_num=2) for i in range(5): - buf1.add(obs=np.array([i]), act=float(i), rew=i * i, - done=i % 2 == 0, info={'incident': 'found'}) + buf1.add(Batch(obs=np.array([i]), act=float(i), rew=i * i, + done=i % 2 == 0, info={'incident': 
'found'})) assert len(buf1) > len(buf2) buf2.update(buf1) assert len(buf1) == len(buf2) assert (buf2[0].obs == buf1[1].obs).all() assert (buf2[-1].obs == buf1[0].obs).all() + b = CachedReplayBuffer(ReplayBuffer(10), 4, 5) + with pytest.raises(NotImplementedError): + b.update(b) def test_segtree(): @@ -258,23 +316,17 @@ def sample_tree(): def test_pickle(): size = 100 vbuf = ReplayBuffer(size, stack_num=2) - lbuf = ListReplayBuffer() pbuf = PrioritizedReplayBuffer(size, 0.6, 0.4) - device = 'cuda' if torch.cuda.is_available() else 'cpu' - rew = torch.tensor([1.]).to(device) + rew = np.array([1, 1]) for i in range(4): - vbuf.add(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0) - for i in range(3): - lbuf.add(obs=Batch(index=np.array([i])), act=1, rew=rew, done=0) + vbuf.add(Batch(obs=Batch(index=np.array([i])), act=0, rew=rew, done=0)) for i in range(5): - pbuf.add(obs=Batch(index=np.array([i])), - act=2, rew=rew, done=0, weight=np.random.rand()) + pbuf.add(Batch(obs=Batch(index=np.array([i])), + act=2, rew=rew, done=0, info=np.random.rand())) # save & load _vbuf = pickle.loads(pickle.dumps(vbuf)) - _lbuf = pickle.loads(pickle.dumps(lbuf)) _pbuf = pickle.loads(pickle.dumps(pbuf)) assert len(_vbuf) == len(vbuf) and np.allclose(_vbuf.act, vbuf.act) - assert len(_lbuf) == len(lbuf) and np.allclose(_lbuf.act, lbuf.act) assert len(_pbuf) == len(pbuf) and np.allclose(_pbuf.act, pbuf.act) # make sure the meta var is identical assert _vbuf.stack_num == vbuf.stack_num @@ -286,23 +338,21 @@ def test_hdf5(): size = 100 buffers = { "array": ReplayBuffer(size, stack_num=2), - "list": ListReplayBuffer(), - "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4) + "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4), } buffer_types = {k: b.__class__ for k, b in buffers.items()} device = 'cuda' if torch.cuda.is_available() else 'cpu' - rew = torch.tensor([1.]).to(device) + info_t = torch.tensor([1.]).to(device) for i in range(4): kwargs = { 'obs': Batch(index=np.array([i])), 'act': i, - 'rew': rew, - 'done': 0, - 'info': {"number": {"n": i}, 'extra': None}, + 'rew': np.array([1, 2]), + 'done': i % 3 == 2, + 'info': {"number": {"n": i, "t": info_t}, 'extra': None}, } - buffers["array"].add(**kwargs) - buffers["list"].add(**kwargs) - buffers["prioritized"].add(weight=np.random.rand(), **kwargs) + buffers["array"].add(Batch(kwargs)) + buffers["prioritized"].add(Batch(kwargs)) # save paths = {} @@ -320,10 +370,10 @@ def test_hdf5(): assert len(_buffers[k]) == len(buffers[k]) assert np.allclose(_buffers[k].act, buffers[k].act) assert _buffers[k].stack_num == buffers[k].stack_num - assert _buffers[k]._maxsize == buffers[k]._maxsize - assert _buffers[k]._index == buffers[k]._index + assert _buffers[k].maxsize == buffers[k].maxsize assert np.all(_buffers[k]._indices == buffers[k]._indices) for k in ["array", "prioritized"]: + assert _buffers[k]._index == buffers[k]._index assert isinstance(buffers[k].get(0, "info"), Batch) assert isinstance(_buffers[k].get(0, "info"), Batch) for k in ["array"]: @@ -332,28 +382,353 @@ def test_hdf5(): assert np.all( buffers[k][:].info.extra == _buffers[k][:].info.extra) - for path in paths.values(): - os.remove(path) - # raise exception when value cannot be pickled - data = {"not_supported": lambda x: x*x} + data = {"not_supported": lambda x: x * x} grp = h5py.Group with pytest.raises(NotImplementedError): to_hdf5(data, grp) # ndarray with data type not supported by HDF5 that cannot be pickled - data = {"not_supported": np.array(lambda x: x*x)} + data = {"not_supported": 
np.array(lambda x: x * x)} grp = h5py.Group with pytest.raises(RuntimeError): to_hdf5(data, grp) +def test_replaybuffermanager(): + buf = VectorReplayBuffer(20, 4) + batch = Batch(obs=[1, 2, 3], act=[1, 2, 3], rew=[1, 2, 3], done=[0, 0, 1]) + ptr, ep_rew, ep_len, ep_idx = buf.add(batch, buffer_ids=[0, 1, 2]) + assert np.all(ep_len == [0, 0, 1]) and np.all(ep_rew == [0, 0, 3]) + assert np.all(ptr == [0, 5, 10]) and np.all(ep_idx == [0, 5, 10]) + with pytest.raises(NotImplementedError): + # ReplayBufferManager cannot be updated + buf.update(buf) + # sample index / prev / next / unfinished_index + indice = buf.sample_index(11000) + assert np.bincount(indice)[[0, 5, 10]].min() >= 3000 # uniform sample + batch, indice = buf.sample(0) + assert np.allclose(indice, [0, 5, 10]) + indice_prev = buf.prev(indice) + assert np.allclose(indice_prev, indice), indice_prev + indice_next = buf.next(indice) + assert np.allclose(indice_next, indice), indice_next + assert np.allclose(buf.unfinished_index(), [0, 5]) + buf.add(Batch(obs=[4], act=[4], rew=[4], done=[1]), buffer_ids=[3]) + assert np.allclose(buf.unfinished_index(), [0, 5]) + batch, indice = buf.sample(10) + batch, indice = buf.sample(0) + assert np.allclose(indice, [0, 5, 10, 15]) + indice_prev = buf.prev(indice) + assert np.allclose(indice_prev, indice), indice_prev + indice_next = buf.next(indice) + assert np.allclose(indice_next, indice), indice_next + data = np.array([0, 0, 0, 0]) + buf.add(Batch(obs=data, act=data, rew=data, done=data), + buffer_ids=[0, 1, 2, 3]) + buf.add(Batch(obs=data, act=data, rew=data, done=1 - data), + buffer_ids=[0, 1, 2, 3]) + assert len(buf) == 12 + buf.add(Batch(obs=data, act=data, rew=data, done=data), + buffer_ids=[0, 1, 2, 3]) + buf.add(Batch(obs=data, act=data, rew=data, done=[0, 1, 0, 1]), + buffer_ids=[0, 1, 2, 3]) + assert len(buf) == 20 + indice = buf.sample_index(120000) + assert np.bincount(indice).min() >= 5000 + batch, indice = buf.sample(10) + indice = buf.sample_index(0) + assert np.allclose(indice, np.arange(len(buf))) + # check the actual data stored in buf._meta + assert np.allclose(buf.done, [ + 0, 0, 1, 0, 0, + 0, 0, 1, 0, 1, + 1, 0, 1, 0, 0, + 1, 0, 1, 0, 1, + ]) + assert np.allclose(buf.prev(indice), [ + 0, 0, 1, 3, 3, + 5, 5, 6, 8, 8, + 10, 11, 11, 13, 13, + 15, 16, 16, 18, 18, + ]) + assert np.allclose(buf.next(indice), [ + 1, 2, 2, 4, 4, + 6, 7, 7, 9, 9, + 10, 12, 12, 14, 14, + 15, 17, 17, 19, 19, + ]) + assert np.allclose(buf.unfinished_index(), [4, 14]) + ptr, ep_rew, ep_len, ep_idx = buf.add( + Batch(obs=[1], act=[1], rew=[1], done=[1]), buffer_ids=[2]) + assert np.all(ep_len == [3]) and np.all(ep_rew == [1]) + assert np.all(ptr == [10]) and np.all(ep_idx == [13]) + assert np.allclose(buf.unfinished_index(), [4]) + indice = list(sorted(buf.sample_index(0))) + assert np.allclose(indice, np.arange(len(buf))) + assert np.allclose(buf.prev(indice), [ + 0, 0, 1, 3, 3, + 5, 5, 6, 8, 8, + 14, 11, 11, 13, 13, + 15, 16, 16, 18, 18, + ]) + assert np.allclose(buf.next(indice), [ + 1, 2, 2, 4, 4, + 6, 7, 7, 9, 9, + 10, 12, 12, 14, 10, + 15, 17, 17, 19, 19, + ]) + # corner case: list, int and -1 + assert buf.prev(-1) == buf.prev([buf.maxsize - 1])[0] + assert buf.next(-1) == buf.next([buf.maxsize - 1])[0] + batch = buf._meta + batch.info = np.ones(buf.maxsize) + buf.set_batch(batch) + assert np.allclose(buf.buffers[-1].info, [1] * 5) + assert buf.sample_index(-1).tolist() == [] + assert np.array([ReplayBuffer(0, ignore_obs_next=True)]).dtype == np.object + + +def test_cachedbuffer(): + buf = 
CachedReplayBuffer(ReplayBuffer(10), 4, 5) + assert buf.sample_index(0).tolist() == [] + # check the normal function/usage/storage in CachedReplayBuffer + ptr, ep_rew, ep_len, ep_idx = buf.add( + Batch(obs=[1], act=[1], rew=[1], done=[0]), buffer_ids=[1]) + obs = np.zeros(buf.maxsize) + obs[15] = 1 + indice = buf.sample_index(0) + assert np.allclose(indice, [15]) + assert np.allclose(buf.prev(indice), [15]) + assert np.allclose(buf.next(indice), [15]) + assert np.allclose(buf.obs, obs) + assert np.all(ep_len == [0]) and np.all(ep_rew == [0.0]) + assert np.all(ptr == [15]) and np.all(ep_idx == [15]) + ptr, ep_rew, ep_len, ep_idx = buf.add( + Batch(obs=[2], act=[2], rew=[2], done=[1]), buffer_ids=[3]) + obs[[0, 25]] = 2 + indice = buf.sample_index(0) + assert np.allclose(indice, [0, 15]) + assert np.allclose(buf.prev(indice), [0, 15]) + assert np.allclose(buf.next(indice), [0, 15]) + assert np.allclose(buf.obs, obs) + assert np.all(ep_len == [1]) and np.all(ep_rew == [2.0]) + assert np.all(ptr == [0]) and np.all(ep_idx == [0]) + assert np.allclose(buf.unfinished_index(), [15]) + assert np.allclose(buf.sample_index(0), [0, 15]) + ptr, ep_rew, ep_len, ep_idx = buf.add( + Batch(obs=[3, 4], act=[3, 4], rew=[3, 4], done=[0, 1]), + buffer_ids=[3, 1]) + assert np.all(ep_len == [0, 2]) and np.all(ep_rew == [0, 5.0]) + assert np.all(ptr == [25, 2]) and np.all(ep_idx == [25, 1]) + obs[[0, 1, 2, 15, 16, 25]] = [2, 1, 4, 1, 4, 3] + assert np.allclose(buf.obs, obs) + assert np.allclose(buf.unfinished_index(), [25]) + indice = buf.sample_index(0) + assert np.allclose(indice, [0, 1, 2, 25]) + assert np.allclose(buf.done[indice], [1, 0, 1, 0]) + assert np.allclose(buf.prev(indice), [0, 1, 1, 25]) + assert np.allclose(buf.next(indice), [0, 2, 2, 25]) + indice = buf.sample_index(10000) + assert np.bincount(indice)[[0, 1, 2, 25]].min() > 2000 # uniform sample + # cached buffer with main_buffer size == 0 (no update) + # used in test_collector + buf = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5) + data = np.zeros(4) + rew = np.ones([4, 4]) + buf.add(Batch(obs=data, act=data, rew=rew, done=[0, 0, 1, 1])) + buf.add(Batch(obs=data, act=data, rew=rew, done=[0, 0, 0, 0])) + buf.add(Batch(obs=data, act=data, rew=rew, done=[1, 1, 1, 1])) + buf.add(Batch(obs=data, act=data, rew=rew, done=[0, 0, 0, 0])) + ptr, ep_rew, ep_len, ep_idx = buf.add( + Batch(obs=data, act=data, rew=rew, done=[0, 1, 0, 1])) + assert np.all(ptr == [1, -1, 11, -1]) and np.all(ep_idx == [0, -1, 10, -1]) + assert np.all(ep_len == [0, 2, 0, 2]) + assert np.all(ep_rew == [data, data + 2, data, data + 2]) + assert np.allclose(buf.done, [ + 0, 0, 1, 0, 0, + 0, 1, 1, 0, 0, + 0, 0, 0, 0, 0, + 0, 1, 0, 0, 0, + ]) + indice = buf.sample_index(0) + assert np.allclose(indice, [0, 1, 10, 11]) + assert np.allclose(buf.prev(indice), [0, 0, 10, 10]) + assert np.allclose(buf.next(indice), [1, 1, 11, 11]) + + +def test_multibuf_stack(): + size = 5 + bufsize = 9 + stack_num = 4 + cached_num = 3 + env = MyTestEnv(size) + # test if CachedReplayBuffer can handle stack_num + ignore_obs_next + buf4 = CachedReplayBuffer( + ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True), + cached_num, size) + # test if CachedReplayBuffer can handle corner case: + # buffer + stack_num + ignore_obs_next + sample_avail + buf5 = CachedReplayBuffer( + ReplayBuffer(bufsize, stack_num=stack_num, + ignore_obs_next=True, sample_avail=True), + cached_num, size) + obs = env.reset(1) + for i in range(18): + obs_next, rew, done, info = env.step(1) + obs_list = 
np.array([obs + size * i for i in range(cached_num)]) + act_list = [1] * cached_num + rew_list = [rew] * cached_num + done_list = [done] * cached_num + obs_next_list = -obs_list + info_list = [info] * cached_num + batch = Batch(obs=obs_list, act=act_list, rew=rew_list, + done=done_list, obs_next=obs_next_list, info=info_list) + buf5.add(batch) + buf4.add(batch) + assert np.all(buf4.obs == buf5.obs) + assert np.all(buf4.done == buf5.done) + obs = obs_next + if done: + obs = env.reset(1) + # check the `add` order is correct + assert np.allclose(buf4.obs.reshape(-1), [ + 12, 13, 14, 4, 6, 7, 8, 9, 11, # main_buffer + 1, 2, 3, 4, 0, # cached_buffer[0] + 6, 7, 8, 9, 0, # cached_buffer[1] + 11, 12, 13, 14, 0, # cached_buffer[2] + ]), buf4.obs + assert np.allclose(buf4.done, [ + 0, 0, 1, 1, 0, 0, 0, 1, 0, # main_buffer + 0, 0, 0, 1, 0, # cached_buffer[0] + 0, 0, 0, 1, 0, # cached_buffer[1] + 0, 0, 0, 1, 0, # cached_buffer[2] + ]), buf4.done + assert np.allclose(buf4.unfinished_index(), [10, 15, 20]) + indice = sorted(buf4.sample_index(0)) + assert np.allclose(indice, list(range(bufsize)) + [9, 10, 14, 15, 19, 20]) + assert np.allclose(buf4[indice].obs[..., 0], [ + [11, 11, 11, 12], [11, 11, 12, 13], [11, 12, 13, 14], + [4, 4, 4, 4], [6, 6, 6, 6], [6, 6, 6, 7], + [6, 6, 7, 8], [6, 7, 8, 9], [11, 11, 11, 11], + [1, 1, 1, 1], [1, 1, 1, 2], [6, 6, 6, 6], + [6, 6, 6, 7], [11, 11, 11, 11], [11, 11, 11, 12], + ]) + assert np.allclose(buf4[indice].obs_next[..., 0], [ + [11, 11, 12, 13], [11, 12, 13, 14], [11, 12, 13, 14], + [4, 4, 4, 4], [6, 6, 6, 7], [6, 6, 7, 8], + [6, 7, 8, 9], [6, 7, 8, 9], [11, 11, 11, 12], + [1, 1, 1, 2], [1, 1, 1, 2], [6, 6, 6, 7], + [6, 6, 6, 7], [11, 11, 11, 12], [11, 11, 11, 12], + ]) + indice = buf5.sample_index(0) + assert np.allclose(sorted(indice), [2, 7]) + assert np.all(np.isin(buf5.sample_index(100), indice)) + # manually change the stack num + buf5.stack_num = 2 + for buf in buf5.buffers: + buf.stack_num = 2 + indice = buf5.sample_index(0) + assert np.allclose(sorted(indice), [0, 1, 2, 5, 6, 7, 10, 15, 20]) + batch, _ = buf5.sample(0) + # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next + buf6 = CachedReplayBuffer( + ReplayBuffer(bufsize, stack_num=stack_num, + save_only_last_obs=True, ignore_obs_next=True), + cached_num, size) + obs = np.random.rand(size, 4, 84, 84) + buf6.add(Batch(obs=[obs[2], obs[0]], act=[1, 1], rew=[0, 0], done=[0, 1], + obs_next=[obs[3], obs[1]]), buffer_ids=[1, 2]) + assert buf6.obs.shape == (buf6.maxsize, 84, 84) + assert np.allclose(buf6.obs[0], obs[0, -1]) + assert np.allclose(buf6.obs[14], obs[2, -1]) + assert np.allclose(buf6.obs[19], obs[0, -1]) + assert buf6[0].obs.shape == (4, 84, 84) + + +def test_multibuf_hdf5(): + size = 100 + buffers = { + "vector": VectorReplayBuffer(size * 4, 4), + "cached": CachedReplayBuffer(ReplayBuffer(size), 4, size) + } + buffer_types = {k: b.__class__ for k, b in buffers.items()} + device = 'cuda' if torch.cuda.is_available() else 'cpu' + info_t = torch.tensor([1.]).to(device) + for i in range(4): + kwargs = { + 'obs': Batch(index=np.array([i])), + 'act': i, + 'rew': np.array([1, 2]), + 'done': i % 3 == 2, + 'info': {"number": {"n": i, "t": info_t}, 'extra': None}, + } + buffers["vector"].add(Batch.stack([kwargs, kwargs, kwargs]), + buffer_ids=[0, 1, 2]) + buffers["cached"].add(Batch.stack([kwargs, kwargs, kwargs]), + buffer_ids=[0, 1, 2]) + + # save + paths = {} + for k, buf in buffers.items(): + f, path = tempfile.mkstemp(suffix='.hdf5') + os.close(f) + buf.save_hdf5(path) + 
paths[k] = path + + # load replay buffer + _buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths.keys()} + + # compare + for k in buffers.keys(): + assert len(_buffers[k]) == len(buffers[k]) + assert np.allclose(_buffers[k].act, buffers[k].act) + assert _buffers[k].stack_num == buffers[k].stack_num + assert _buffers[k].maxsize == buffers[k].maxsize + assert np.all(_buffers[k]._indices == buffers[k]._indices) + # check shallow copy in VectorReplayBuffer + for k in ["vector", "cached"]: + buffers[k].info.number.n[0] = -100 + assert buffers[k].buffers[0].info.number.n[0] == -100 + # check if still behave normally + for k in ["vector", "cached"]: + kwargs = { + 'obs': Batch(index=np.array([5])), + 'act': 5, + 'rew': np.array([2, 1]), + 'done': False, + 'info': {"number": {"n": i}, 'Timelimit.truncate': True}, + } + buffers[k].add(Batch.stack([kwargs, kwargs, kwargs, kwargs])) + act = np.zeros(buffers[k].maxsize) + if k == "vector": + act[np.arange(5)] = np.array([0, 1, 2, 3, 5]) + act[np.arange(5) + size] = np.array([0, 1, 2, 3, 5]) + act[np.arange(5) + size * 2] = np.array([0, 1, 2, 3, 5]) + act[size * 3] = 5 + elif k == "cached": + act[np.arange(9)] = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2]) + act[np.arange(3) + size] = np.array([3, 5, 2]) + act[np.arange(3) + size * 2] = np.array([3, 5, 2]) + act[np.arange(3) + size * 3] = np.array([3, 5, 2]) + act[size * 4] = 5 + assert np.allclose(buffers[k].act, act) + info_keys = ["number", "extra", "Timelimit.truncate"] + assert set(buffers[k].info.keys()) == set(info_keys) + + for path in paths.values(): + os.remove(path) + + if __name__ == '__main__': - test_hdf5() test_replaybuffer() test_ignore_obs_next() test_stack() - test_pickle() test_segtree() test_priortized_replaybuffer() - test_priortized_replaybuffer(233333, 200000) test_update() + test_pickle() + test_hdf5() + test_replaybuffermanager() + test_cachedbuffer() + test_multibuf_stack() + test_multibuf_hdf5() diff --git a/test/base/test_collector.py b/test/base/test_collector.py index e7e11759f..b9d789193 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -1,9 +1,17 @@ +import tqdm +import pytest import numpy as np from torch.utils.tensorboard import SummaryWriter from tianshou.policy import BasePolicy from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.data import Collector, Batch, ReplayBuffer +from tianshou.data import Batch, Collector, AsyncCollector +from tianshou.data import ( + ReplayBuffer, + PrioritizedReplayBuffer, + VectorReplayBuffer, + CachedReplayBuffer, +) if __name__ == '__main__': from env import MyTestEnv @@ -12,7 +20,7 @@ class MyPolicy(BasePolicy): - def __init__(self, dict_state: bool = False, need_state: bool = True): + def __init__(self, dict_state=False, need_state=True): """ :param bool dict_state: if the observation of the environment is a dict :param bool need_state: if the policy needs the hidden state (for RNN) @@ -43,15 +51,13 @@ def __init__(self, writer): def preprocess_fn(self, **kwargs): # modify info before adding into the buffer, and recorded into tfb # if only obs exist -> reset - # if obs/act/rew/done/... 
exist -> normal step + # if obs_next/rew/done/info exist -> normal step if 'rew' in kwargs: - n = len(kwargs['obs']) info = kwargs['info'] - for i in range(n): - info[i].update(rew=kwargs['rew'][i]) + info.rew = kwargs['rew'] if 'key' in info.keys(): - self.writer.add_scalar('key', np.mean( - info['key']), global_step=self.cnt) + self.writer.add_scalar( + 'key', np.mean(info.key), global_step=self.cnt) self.cnt += 1 return Batch(info=info) else: @@ -61,10 +67,8 @@ def preprocess_fn(self, **kwargs): def single_preprocess_fn(**kwargs): # same as above, without tfb if 'rew' in kwargs: - n = len(kwargs['obs']) info = kwargs['info'] - for i in range(n): - info[i].update(rew=kwargs['rew'][i]) + info.rew = kwargs['rew'] return Batch(info=info) else: return Batch() @@ -79,110 +83,105 @@ def test_collector(): dum = DummyVectorEnv(env_fns) policy = MyPolicy() env = env_fns[0]() - c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False), - logger.preprocess_fn) + c0 = Collector(policy, env, ReplayBuffer(size=100), logger.preprocess_fn) c0.collect(n_step=3) - assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 1]) - assert np.allclose(c0.buffer[:4].obs_next[..., 0], [1, 2, 1, 2]) + assert len(c0.buffer) == 3 + assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0]) + assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1]) c0.collect(n_episode=3) - assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]) - assert np.allclose(c0.buffer[:10].obs_next[..., 0], - [1, 2, 1, 2, 1, 2, 1, 2, 1, 2]) + assert len(c0.buffer) == 8 + assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0]) + assert np.allclose(c0.buffer[:].obs_next[..., 0], + [1, 2, 1, 2, 1, 2, 1, 2]) c0.collect(n_step=3, random=True) - c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False), - logger.preprocess_fn) - c1.collect(n_step=6) - assert np.allclose(c1.buffer.obs[:11, 0], - [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3]) - assert np.allclose(c1.buffer[:11].obs_next[..., 0], - [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4]) - c1.collect(n_episode=2) - assert np.allclose(c1.buffer.obs[11:21, 0], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2]) - assert np.allclose(c1.buffer[11:21].obs_next[..., 0], - [1, 2, 3, 4, 5, 1, 2, 1, 2, 3]) - c1.collect(n_episode=3, random=True) - c2 = Collector(policy, dum, ReplayBuffer(size=100, ignore_obs_next=False), - logger.preprocess_fn) - c2.collect(n_episode=[1, 2, 2, 2]) - assert np.allclose(c2.buffer.obs_next[:26, 0], [ - 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5, - 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]) - c2.reset_env() - c2.collect(n_episode=[2, 2, 2, 2]) - assert np.allclose(c2.buffer.obs_next[26:54, 0], [ - 1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5, - 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]) - c2.collect(n_episode=[1, 1, 1, 1], random=True) - + c1 = Collector( + policy, venv, + VectorReplayBuffer(total_size=100, buffer_num=4), + logger.preprocess_fn) + c1.collect(n_step=8) + obs = np.zeros(100) + obs[[0, 1, 25, 26, 50, 51, 75, 76]] = [0, 1, 0, 1, 0, 1, 0, 1] -def test_collector_with_exact_episodes(): - env_lens = [2, 6, 3, 10] - writer = SummaryWriter('log/exact_collector') - logger = Logger(writer) - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.1, random_sleep=True) - for i in env_lens] + assert np.allclose(c1.buffer.obs[:, 0], obs) + assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) + c1.collect(n_episode=4) + assert len(c1.buffer) == 16 + obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4] + assert np.allclose(c1.buffer.obs[:, 0], 
obs) + assert np.allclose(c1.buffer[:].obs_next[..., 0], + [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]) + c1.collect(n_episode=4, random=True) + c2 = Collector( + policy, dum, + VectorReplayBuffer(total_size=100, buffer_num=4), + logger.preprocess_fn) + c2.collect(n_episode=7) + obs1 = obs.copy() + obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2] + obs2 = obs.copy() + obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3] + c2obs = c2.buffer.obs[:, 0] + assert np.all(c2obs == obs1) or np.all(c2obs == obs2) + c2.reset_env() + c2.reset_buffer() + assert c2.collect(n_episode=8)['n/ep'] == 8 + obs[[4, 5, 28, 29, 30, 54, 55, 56, 57]] = [0, 1, 0, 1, 2, 0, 1, 2, 3] + assert np.all(c2.buffer.obs[:, 0] == obs) + c2.collect(n_episode=4, random=True) - venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) - policy = MyPolicy() - c1 = Collector(policy, venv, - ReplayBuffer(size=1000, ignore_obs_next=False), - logger.preprocess_fn) - n_episode1 = [2, 0, 5, 1] - n_episode2 = [1, 3, 2, 0] - c1.collect(n_episode=n_episode1) - expected_steps = sum([a * b for a, b in zip(env_lens, n_episode1)]) - actual_steps = sum(venv.steps) - assert expected_steps == actual_steps - c1.collect(n_episode=n_episode2) - expected_steps = sum( - [a * (b + c) for a, b, c in zip(env_lens, n_episode1, n_episode2)]) - actual_steps = sum(venv.steps) - assert expected_steps == actual_steps + # test corner case + with pytest.raises(TypeError): + Collector(policy, dum, ReplayBuffer(10)) + with pytest.raises(TypeError): + Collector(policy, dum, PrioritizedReplayBuffer(10, 0.5, 0.5)) + with pytest.raises(TypeError): + c2.collect() def test_collector_with_async(): env_lens = [2, 3, 4, 5] writer = SummaryWriter('log/async_collector') logger = Logger(writer) - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.1, random_sleep=True) + env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens] venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) policy = MyPolicy() - c1 = Collector(policy, venv, - ReplayBuffer(size=1000, ignore_obs_next=False), - logger.preprocess_fn) - c1.collect(n_episode=10) - # check if the data in the buffer is chronological - # i.e. 
data in the buffer are full episodes, and each episode is - # returned by the same environment - env_id = c1.buffer.info['env_id'] - size = len(c1.buffer) - obs = c1.buffer.obs[:size] - done = c1.buffer.done[:size] - obs_ground_truth = [] - i = 0 - while i < size: - # i is the start of an episode - if done[i]: - # this episode has one transition - assert env_lens[env_id[i]] == 1 - i += 1 - continue - j = i - while True: - j += 1 - # in one episode, the environment id is the same - assert env_id[j] == env_id[i] - if done[j]: - break - j = j + 1 # j is the start of the next episode - assert j - i == env_lens[env_id[i]] - obs_ground_truth += list(range(j - i)) - i = j - obs_ground_truth = np.expand_dims( - np.array(obs_ground_truth), axis=-1) - assert np.allclose(obs, obs_ground_truth) + bufsize = 60 + c1 = AsyncCollector( + policy, venv, + VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4), + logger.preprocess_fn) + ptr = [0, 0, 0, 0] + for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): + result = c1.collect(n_episode=n_episode) + assert result["n/ep"] >= n_episode + # check buffer data, obs and obs_next, env_id + for i, count in enumerate( + np.bincount(result["lens"], minlength=6)[2:]): + env_len = i + 2 + total = env_len * count + indices = np.arange(ptr[i], ptr[i] + total) % bufsize + ptr[i] = (ptr[i] + total) % bufsize + seq = np.arange(env_len) + buf = c1.buffer.buffers[i] + assert np.all(buf.info.env_id[indices] == i) + assert np.all(buf.obs[indices].reshape(count, env_len) == seq) + assert np.all(buf.obs_next[indices].reshape( + count, env_len) == seq + 1) + # test async n_step, for now the buffer should be full of data + for n_step in tqdm.trange(1, 15, desc="test async n_step"): + result = c1.collect(n_step=n_step) + assert result["n/st"] >= n_step + for i in range(4): + env_len = i + 2 + seq = np.arange(env_len) + buf = c1.buffer.buffers[i] + assert np.all(buf.info.env_id == i) + assert np.all(buf.obs.reshape(-1, env_len) == seq) + assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1) + with pytest.raises(TypeError): + c1.collect() def test_collector_with_dict_state(): @@ -192,72 +191,227 @@ def test_collector_with_dict_state(): Logger.single_preprocess_fn) c0.collect(n_step=3) c0.collect(n_episode=2) + assert len(c0.buffer) == 10 env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True) for i in [2, 3, 4, 5]] envs = DummyVectorEnv(env_fns) envs.seed(666) obs = envs.reset() assert not np.isclose(obs[0]['rand'], obs[1]['rand']) - c1 = Collector(policy, envs, ReplayBuffer(size=100), - Logger.single_preprocess_fn) - c1.collect(n_step=10) - c1.collect(n_episode=[2, 1, 1, 2]) + c1 = Collector( + policy, envs, + VectorReplayBuffer(total_size=100, buffer_num=4), + Logger.single_preprocess_fn) + c1.collect(n_step=12) + result = c1.collect(n_episode=8) + assert result['n/ep'] == 8 + lens = np.bincount(result['lens']) + assert result['n/st'] == 21 and np.all(lens == [0, 0, 2, 2, 2, 2]) or \ + result['n/st'] == 20 and np.all(lens == [0, 0, 3, 1, 2, 2]) batch, _ = c1.buffer.sample(10) - print(batch) c0.buffer.update(c1.buffer) - assert np.allclose(c0.buffer[:len(c0.buffer)].obs.index[..., 0], [ - 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., - 0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0., - 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.]) - c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4), - Logger.single_preprocess_fn) - c2.collect(n_episode=[0, 0, 0, 10]) + assert len(c0.buffer) in [42, 43] + 
if len(c0.buffer) == 42: + assert np.all(c0.buffer[:].obs.index[..., 0] == [ + 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, + 0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 2, 0, 1, 2, + 0, 1, 2, 3, 0, 1, 2, 3, + 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, + ]), c0.buffer[:].obs.index[..., 0] + else: + assert np.all(c0.buffer[:].obs.index[..., 0] == [ + 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, + 0, 1, 0, 1, 0, 1, + 0, 1, 2, 0, 1, 2, 0, 1, 2, + 0, 1, 2, 3, 0, 1, 2, 3, + 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, + ]), c0.buffer[:].obs.index[..., 0] + c2 = Collector( + policy, envs, + VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), + Logger.single_preprocess_fn) + c2.collect(n_episode=10) batch, _ = c2.buffer.sample(10) def test_collector_with_ma(): - def reward_metric(x): - return x.sum() env = MyTestEnv(size=5, sleep=0, ma_rew=4) policy = MyPolicy() c0 = Collector(policy, env, ReplayBuffer(size=100), - Logger.single_preprocess_fn, reward_metric=reward_metric) + Logger.single_preprocess_fn) # n_step=3 will collect a full episode - r = c0.collect(n_step=3)['rew'] - assert np.asanyarray(r).size == 1 and r == 4. - r = c0.collect(n_episode=2)['rew'] - assert np.asanyarray(r).size == 1 and r == 4. + r = c0.collect(n_step=3)['rews'] + assert len(r) == 0 + r = c0.collect(n_episode=2)['rews'] + assert r.shape == (2, 4) and np.all(r == 1) env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]] envs = DummyVectorEnv(env_fns) - c1 = Collector(policy, envs, ReplayBuffer(size=100), - Logger.single_preprocess_fn, reward_metric=reward_metric) - r = c1.collect(n_step=10)['rew'] - assert np.asanyarray(r).size == 1 and r == 4. - r = c1.collect(n_episode=[2, 1, 1, 2])['rew'] - assert np.asanyarray(r).size == 1 and r == 4. + c1 = Collector( + policy, envs, + VectorReplayBuffer(total_size=100, buffer_num=4), + Logger.single_preprocess_fn) + r = c1.collect(n_step=12)['rews'] + assert r.shape == (2, 4) and np.all(r == 1), r + r = c1.collect(n_episode=8)['rews'] + assert r.shape == (8, 4) and np.all(r == 1) batch, _ = c1.buffer.sample(10) print(batch) c0.buffer.update(c1.buffer) - assert np.allclose(c0.buffer[:len(c0.buffer)].obs[..., 0], [ - 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., - 0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0., - 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.]) - rew = [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, - 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, - 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1] - assert np.allclose(c0.buffer[:len(c0.buffer)].rew, - [[x] * 4 for x in rew]) - c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4), - Logger.single_preprocess_fn, reward_metric=reward_metric) - r = c2.collect(n_episode=[0, 0, 0, 10])['rew'] - assert np.asanyarray(r).size == 1 and r == 4. 
+ assert len(c0.buffer) in [42, 43] + if len(c0.buffer) == 42: + rew = [ + 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1, + 0, 0, 1, 0, 0, 1, + 0, 0, 0, 1, 0, 0, 0, 1, + 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, + ] + else: + rew = [ + 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, + 0, 1, 0, 1, 0, 1, + 0, 0, 1, 0, 0, 1, 0, 0, 1, + 0, 0, 0, 1, 0, 0, 0, 1, + 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, + ] + assert np.all(c0.buffer[:].rew == [[x] * 4 for x in rew]) + assert np.all(c0.buffer[:].done == rew) + c2 = Collector( + policy, envs, + VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), + Logger.single_preprocess_fn) + r = c2.collect(n_episode=10)['rews'] + assert r.shape == (10, 4) and np.all(r == 1) batch, _ = c2.buffer.sample(10) +def test_collector_with_atari_setting(): + reference_obs = np.zeros([6, 4, 84, 84]) + for i in range(6): + reference_obs[i, 3, np.arange(84), np.arange(84)] = i + reference_obs[i, 2, np.arange(84)] = i + reference_obs[i, 1, :, np.arange(84)] = i + reference_obs[i, 0] = i + + # atari single buffer + env = MyTestEnv(size=5, sleep=0, array_state=True) + policy = MyPolicy() + c0 = Collector(policy, env, ReplayBuffer(size=100)) + c0.collect(n_step=6) + c0.collect(n_episode=2) + assert c0.buffer.obs.shape == (100, 4, 84, 84) + assert c0.buffer.obs_next.shape == (100, 4, 84, 84) + assert len(c0.buffer) == 15 + obs = np.zeros_like(c0.buffer.obs) + obs[np.arange(15)] = reference_obs[np.arange(15) % 5] + assert np.all(obs == c0.buffer.obs) + + c1 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=True)) + c1.collect(n_episode=3) + assert np.allclose(c0.buffer.obs, c1.buffer.obs) + with pytest.raises(AttributeError): + c1.buffer.obs_next + assert np.all(reference_obs[[1, 2, 3, 4, 4] * 3] == c1.buffer[:].obs_next) + + c2 = Collector( + policy, env, + ReplayBuffer(size=100, ignore_obs_next=True, save_only_last_obs=True)) + c2.collect(n_step=8) + assert c2.buffer.obs.shape == (100, 84, 84) + obs = np.zeros_like(c2.buffer.obs) + obs[np.arange(8)] = reference_obs[[0, 1, 2, 3, 4, 0, 1, 2], -1] + assert np.all(c2.buffer.obs == obs) + assert np.allclose(c2.buffer[:].obs_next, + reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1]) + + # atari multi buffer + env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, array_state=True) + for i in [2, 3, 4, 5]] + envs = DummyVectorEnv(env_fns) + c3 = Collector( + policy, envs, + VectorReplayBuffer(total_size=100, buffer_num=4)) + c3.collect(n_step=12) + result = c3.collect(n_episode=9) + assert result["n/ep"] == 9 and result["n/st"] == 23 + assert c3.buffer.obs.shape == (100, 4, 84, 84) + obs = np.zeros_like(c3.buffer.obs) + obs[np.arange(8)] = reference_obs[[0, 1, 0, 1, 0, 1, 0, 1]] + obs[np.arange(25, 34)] = reference_obs[[0, 1, 2, 0, 1, 2, 0, 1, 2]] + obs[np.arange(50, 58)] = reference_obs[[0, 1, 2, 3, 0, 1, 2, 3]] + obs[np.arange(75, 85)] = reference_obs[[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]] + assert np.all(obs == c3.buffer.obs) + obs_next = np.zeros_like(c3.buffer.obs_next) + obs_next[np.arange(8)] = reference_obs[[1, 2, 1, 2, 1, 2, 1, 2]] + obs_next[np.arange(25, 34)] = reference_obs[[1, 2, 3, 1, 2, 3, 1, 2, 3]] + obs_next[np.arange(50, 58)] = reference_obs[[1, 2, 3, 4, 1, 2, 3, 4]] + obs_next[np.arange(75, 85)] = reference_obs[[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]] + assert np.all(obs_next == c3.buffer.obs_next) + c4 = Collector( + policy, envs, + VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4, + ignore_obs_next=True, save_only_last_obs=True)) + c4.collect(n_step=12) + result = c4.collect(n_episode=9) + assert result["n/ep"] == 9 and result["n/st"] 
== 23 + assert c4.buffer.obs.shape == (100, 84, 84) + obs = np.zeros_like(c4.buffer.obs) + slice_obs = reference_obs[:, -1] + obs[np.arange(8)] = slice_obs[[0, 1, 0, 1, 0, 1, 0, 1]] + obs[np.arange(25, 34)] = slice_obs[[0, 1, 2, 0, 1, 2, 0, 1, 2]] + obs[np.arange(50, 58)] = slice_obs[[0, 1, 2, 3, 0, 1, 2, 3]] + obs[np.arange(75, 85)] = slice_obs[[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]] + assert np.all(c4.buffer.obs == obs) + obs_next = np.zeros([len(c4.buffer), 4, 84, 84]) + ref_index = np.array([ + 1, 1, 1, 1, 1, 1, 1, 1, + 1, 2, 2, 1, 2, 2, 1, 2, 2, + 1, 2, 3, 3, 1, 2, 3, 3, + 1, 2, 3, 4, 4, 1, 2, 3, 4, 4, + ]) + obs_next[:, -1] = slice_obs[ref_index] + ref_index -= 1 + ref_index[ref_index < 0] = 0 + obs_next[:, -2] = slice_obs[ref_index] + ref_index -= 1 + ref_index[ref_index < 0] = 0 + obs_next[:, -3] = slice_obs[ref_index] + ref_index -= 1 + ref_index[ref_index < 0] = 0 + obs_next[:, -4] = slice_obs[ref_index] + assert np.all(obs_next == c4.buffer[:].obs_next) + + buf = ReplayBuffer(100, stack_num=4, ignore_obs_next=True, + save_only_last_obs=True) + c5 = Collector(policy, envs, CachedReplayBuffer(buf, 4, 10)) + result_ = c5.collect(n_step=12) + assert len(buf) == 5 and len(c5.buffer) == 12 + result = c5.collect(n_episode=9) + assert result["n/ep"] == 9 and result["n/st"] == 23 + assert len(buf) == 35 + assert np.all(buf.obs[:len(buf)] == slice_obs[[ + 0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 3, 4, + 0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4]]) + assert np.all(buf[:].obs_next[:, -1] == slice_obs[[ + 1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 3, 4, 4, + 1, 1, 1, 2, 2, 1, 1, 1, 2, 3, 3, 1, 2, 2, 1, 2, 3, 4, 4]]) + assert len(buf) == len(c5.buffer) + + # test buffer=None + c6 = Collector(policy, envs) + result1 = c6.collect(n_step=12) + for key in ["n/ep", "n/st", "rews", "lens"]: + assert np.allclose(result1[key], result_[key]) + result2 = c6.collect(n_episode=9) + for key in ["n/ep", "n/st", "rews", "lens"]: + assert np.allclose(result2[key], result[key]) + + if __name__ == '__main__': test_collector() test_collector_with_dict_state() test_collector_with_ma() + test_collector_with_atari_setting() test_collector_with_async() - test_collector_with_exact_episodes() diff --git a/test/base/test_returns.py b/test/base/test_returns.py index a8bdc7c9d..e8d70de5c 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -20,50 +20,91 @@ def compute_episodic_return_base(batch, gamma): def test_episodic_returns(size=2560): fn = BasePolicy.compute_episodic_return + buf = ReplayBuffer(20) batch = Batch( done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]), rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]), + info=Batch({'TimeLimit.truncated': + np.array([False, False, False, False, False, True, False, False])}) ) - batch = fn(batch, None, gamma=.1, gae_lambda=1) + for b in batch: + b.obs = b.act = 1 + buf.add(b) + batch = fn(batch, buf, buf.sample_index(0), None, gamma=.1, gae_lambda=1) ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7]) assert np.allclose(batch.returns, ans) + buf.reset() batch = Batch( done=np.array([0, 1, 0, 1, 0, 1, 0.]), rew=np.array([7, 6, 1, 2, 3, 4, 5.]), ) - batch = fn(batch, None, gamma=.1, gae_lambda=1) + for b in batch: + b.obs = b.act = 1 + buf.add(b) + batch = fn(batch, buf, buf.sample_index(0), None, gamma=.1, gae_lambda=1) ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5]) assert np.allclose(batch.returns, ans) + buf.reset() batch = Batch( done=np.array([0, 1, 0, 1, 0, 0, 1.]), rew=np.array([7, 6, 1, 2, 3, 4, 5.]), ) - batch = fn(batch, None, gamma=.1, gae_lambda=1) + for b in 
batch: + b.obs = b.act = 1 + buf.add(b) + batch = fn(batch, buf, buf.sample_index(0), None, gamma=.1, gae_lambda=1) ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5]) assert np.allclose(batch.returns, ans) + buf.reset() batch = Batch( done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]), - rew=np.array([ - 101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]) + rew=np.array([101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]), ) + for b in batch: + b.obs = b.act = 1 + buf.add(b) v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3]) - ret = fn(batch, v, gamma=0.99, gae_lambda=0.95) + ret = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95) returns = np.array([ 454.8344, 376.1143, 291.298, 200., 464.5610, 383.1085, 295.387, 201., 474.2876, 390.1027, 299.476, 202.]) assert np.allclose(ret.returns, returns) + buf.reset() + batch = Batch( + done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]), + rew=np.array([101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]), + info=Batch({'TimeLimit.truncated': + np.array([False, False, False, True, False, False, + False, True, False, False, False, False])}) + ) + for b in batch: + b.obs = b.act = 1 + buf.add(b) + v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3]) + ret = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95) + returns = np.array([ + 454.0109, 375.2386, 290.3669, 199.01, + 462.9138, 381.3571, 293.5248, 199.02, + 474.2876, 390.1027, 299.476, 202.]) + assert np.allclose(ret.returns, returns) + if __name__ == '__main__': + buf = ReplayBuffer(size) batch = Batch( done=np.random.randint(100, size=size) == 0, rew=np.random.random(size), ) + for b in batch: + b.obs = b.act = 1 + buf.add(b) + indice = buf.sample_index(0) def vanilla(): return compute_episodic_return_base(batch, gamma=.1) def optimized(): - return fn(batch, gamma=.1) + return fn(batch, buf, indice, gamma=.1, gae_lambda=1.0) cnt = 3000 print('GAE vanilla', timeit(vanilla, setup=vanilla, number=cnt)) @@ -72,7 +113,7 @@ def optimized(): def target_q_fn(buffer, indice): # return the next reward - indice = (indice + 1 - buffer.done[indice]) % len(buffer) + indice = buffer.next(indice) return torch.tensor(-buffer.rew[indice], dtype=torch.float32) @@ -85,15 +126,19 @@ def compute_nstep_return_base(nstep, gamma, buffer, indice): buf_len = len(buffer) for i in range(len(indice)): flag, r = False, 0. 
+ real_step_n = nstep for n in range(nstep): idx = (indice[i] + n) % buf_len r += buffer.rew[idx] * gamma ** n if buffer.done[idx]: - flag = True + if not (hasattr(buffer, 'info') and + buffer.info['TimeLimit.truncated'][idx]): + flag = True + real_step_n = n + 1 break if not flag: - idx = (indice[i] + nstep - 1) % buf_len - r += to_numpy(target_q_fn(buffer, idx)) * gamma ** nstep + idx = (indice[i] + real_step_n - 1) % buf_len + r += to_numpy(target_q_fn(buffer, idx)) * gamma ** real_step_n returns[i] = r return returns @@ -101,14 +146,15 @@ def compute_nstep_return_base(nstep, gamma, buffer, indice): def test_nstep_returns(size=10000): buf = ReplayBuffer(10) for i in range(12): - buf.add(obs=0, act=0, rew=i + 1, done=i % 4 == 3) + buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3)) batch, indice = buf.sample(0) assert np.allclose(indice, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1]) - # rew: [10, 11, 2, 3, 4, 5, 6, 7, 8, 9] + # rew: [11, 12, 3, 4, 5, 6, 7, 8, 9, 10] # done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0] # test nstep = 1 returns = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indice, target_q_fn, gamma=.1, n_step=1).pop('returns')) + batch, buf, indice, target_q_fn, gamma=.1, n_step=1 + ).pop('returns').reshape(-1)) assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12]) r_ = compute_nstep_return_base(1, .1, buf, indice) assert np.allclose(returns, r_), (r_, returns) @@ -118,9 +164,53 @@ def test_nstep_returns(size=10000): assert np.allclose(returns_multidim, returns[:, np.newaxis]) # test nstep = 2 returns = to_numpy(BasePolicy.compute_nstep_return( - batch, buf, indice, target_q_fn, gamma=.1, n_step=2).pop('returns')) - assert np.allclose(returns, [ - 3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12]) + batch, buf, indice, target_q_fn, gamma=.1, n_step=2 + ).pop('returns').reshape(-1)) + assert np.allclose(returns, [3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12]) + r_ = compute_nstep_return_base(2, .1, buf, indice) + assert np.allclose(returns, r_) + returns_multidim = to_numpy(BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=2 + ).pop('returns')) + assert np.allclose(returns_multidim, returns[:, np.newaxis]) + # test nstep = 10 + returns = to_numpy(BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn, gamma=.1, n_step=10 + ).pop('returns').reshape(-1)) + assert np.allclose(returns, [3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12]) + r_ = compute_nstep_return_base(10, .1, buf, indice) + assert np.allclose(returns, r_) + returns_multidim = to_numpy(BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=10 + ).pop('returns')) + assert np.allclose(returns_multidim, returns[:, np.newaxis]) + + +def test_nstep_returns_with_timelimit(size=10000): + buf = ReplayBuffer(10) + for i in range(12): + buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3, + info={"TimeLimit.truncated": i == 3})) + batch, indice = buf.sample(0) + assert np.allclose(indice, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1]) + # rew: [11, 12, 3, 4, 5, 6, 7, 8, 9, 10] + # done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0] + # test nstep = 1 + returns = to_numpy(BasePolicy.compute_nstep_return( + batch, buf, indice, target_q_fn, gamma=.1, n_step=1 + ).pop('returns').reshape(-1)) + assert np.allclose(returns, [2.6, 3.6, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12]) + r_ = compute_nstep_return_base(1, .1, buf, indice) + assert np.allclose(returns, r_), (r_, returns) + returns_multidim = to_numpy(BasePolicy.compute_nstep_return( + batch, 
buf, indice, target_q_fn_multidim, gamma=.1, n_step=1
+    ).pop('returns'))
+    assert np.allclose(returns_multidim, returns[:, np.newaxis])
+    # test nstep = 2
+    returns = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn, gamma=.1, n_step=2
+    ).pop('returns').reshape(-1))
+    assert np.allclose(returns, [3.36, 3.6, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12])
     r_ = compute_nstep_return_base(2, .1, buf, indice)
     assert np.allclose(returns, r_)
     returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
@@ -129,9 +219,10 @@ def test_nstep_returns(size=10000):
     assert np.allclose(returns_multidim, returns[:, np.newaxis])
     # test nstep = 10
     returns = to_numpy(BasePolicy.compute_nstep_return(
-        batch, buf, indice, target_q_fn, gamma=.1, n_step=10).pop('returns'))
-    assert np.allclose(returns, [
-        3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12])
+        batch, buf, indice, target_q_fn, gamma=.1, n_step=10
+    ).pop('returns').reshape(-1))
+    assert np.allclose(returns, [3.36, 3.6, 5.678, 6.78,
+                                 7.8, 8, 10.122, 11.22, 12.2, 12])
     r_ = compute_nstep_return_base(10, .1, buf, indice)
     assert np.allclose(returns, r_)
     returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
@@ -142,7 +233,8 @@ def test_nstep_returns(size=10000):
     if __name__ == '__main__':
         buf = ReplayBuffer(size)
         for i in range(int(size * 1.5)):
-            buf.add(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0)
+            buf.add(Batch(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0,
+                          info={"TimeLimit.truncated": i % 33 == 0}))
         batch, indice = buf.sample(256)

         def vanilla():
@@ -159,4 +251,5 @@ def optimized():

 if __name__ == '__main__':
     test_nstep_returns()
+    test_nstep_returns_with_timelimit()
     test_episodic_returns()
diff --git a/test/base/test_utils.py b/test/base/test_utils.py
index 93d2a48ca..aa72272a9 100644
--- a/test/base/test_utils.py
+++ b/test/base/test_utils.py
@@ -2,7 +2,6 @@
 import numpy as np

 from tianshou.utils import MovAvg
-from tianshou.utils import SummaryWriter
 from tianshou.utils.net.common import MLP, Net
 from tianshou.exploration import GaussianNoise, OUNoise
 from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
@@ -77,25 +76,7 @@ def test_net():
     assert list(net(data, act).shape) == [bsz, 1]


-def test_summary_writer():
-    # get first instance by key of `default` or your own key
-    writer1 = SummaryWriter.get_instance(
-        key="first", log_dir="log/test_sw/first")
-    assert writer1.log_dir == "log/test_sw/first"
-    writer2 = SummaryWriter.get_instance()
-    assert writer1 is writer2
-    # create new instance by specify a new key
-    writer3 = SummaryWriter.get_instance(
-        key="second", log_dir="log/test_sw/second")
-    assert writer3.log_dir == "log/test_sw/second"
-    writer4 = SummaryWriter.get_instance(key="second")
-    assert writer3 is writer4
-    assert writer1 is not writer3
-    assert writer1.log_dir != writer4.log_dir
-
-
 if __name__ == '__main__':
     test_noise()
     test_moving_average()
     test_net()
-    test_summary_writer()
diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py
index 52257342d..093aed196 100644
--- a/test/continuous/test_ddpg.py
+++ b/test/continuous/test_ddpg.py
@@ -7,11 +7,12 @@
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import DDPGPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
 from tianshou.exploration import GaussianNoise
-from tianshou.data import Collector, ReplayBuffer
+from tianshou.data import
Collector, VectorReplayBuffer from tianshou.utils.net.continuous import Actor, Critic @@ -25,20 +26,18 @@ def get_args(): parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--tau', type=float, default=0.005) parser.add_argument('--exploration-noise', type=float, default=0.1) - parser.add_argument('--test-noise', type=float, default=0.1) - parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=4) + parser.add_argument('--epoch', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=20000) + parser.add_argument('--step-per-collect', type=int, default=8) + parser.add_argument('--update-per-step', type=float, default=0.125) parser.add_argument('--batch-size', type=int, default=128) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) parser.add_argument('--training-num', type=int, default=8) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) - parser.add_argument('--rew-norm', type=int, default=1) - parser.add_argument('--ignore-done', type=int, default=1) - parser.add_argument('--n-step', type=int, default=1) + parser.add_argument('--rew-norm', action="store_true", default=False) + parser.add_argument('--n-step', type=int, default=3) parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') @@ -82,16 +81,17 @@ def test_ddpg(args=get_args()): tau=args.tau, gamma=args.gamma, exploration_noise=GaussianNoise(sigma=args.exploration_noise), reward_normalization=args.rew_norm, - ignore_done=args.ignore_done, estimation_step=args.n_step) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) - test_collector = Collector( - policy, test_envs, action_noise=GaussianNoise(sigma=args.test_noise)) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) + test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'ddpg') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -102,8 +102,9 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, stop_fn=stop_fn, + save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) @@ -112,7 +113,8 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index d0987cdbf..762c58838 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -8,23 +8,24 @@ from 
torch.distributions import Independent, Normal from tianshou.policy import PPOPolicy +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import ActorProb, Critic def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='Pendulum-v0') - parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--seed', type=int, default=1) parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.99) - parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=1) + parser.add_argument('--epoch', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=150000) + parser.add_argument('--episode-per-collect', type=int, default=16) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, @@ -104,11 +105,14 @@ def dist(*logits): gae_lambda=args.gae_lambda) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'ppo') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -119,9 +123,9 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, - writer=writer) + args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, + logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) @@ -130,7 +134,8 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index f35e18497..ac533fcf4 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -6,10 +6,11 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import SACPolicy, ImitationPolicy from tianshou.utils.net.continuous import Actor, ActorProb, Critic @@ -25,20 +26,21 @@ def get_args(): 
parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--tau', type=float, default=0.005) parser.add_argument('--alpha', type=float, default=0.2) - parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--epoch', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=24000) + parser.add_argument('--il-step-per-epoch', type=int, default=500) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=128) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) parser.add_argument('--imitation-hidden-sizes', type=int, nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) - parser.add_argument('--rew-norm', type=int, default=1) - parser.add_argument('--ignore-done', type=int, default=1) + parser.add_argument('--rew-norm', action="store_true", default=False) parser.add_argument('--n-step', type=int, default=4) parser.add_argument( '--device', type=str, @@ -88,16 +90,18 @@ def test_sac_with_il(args=get_args()): action_range=[env.action_space.low[0], env.action_space.high[0]], tau=args.tau, gamma=args.gamma, alpha=args.alpha, reward_normalization=args.rew_norm, - ignore_done=args.ignore_done, estimation_step=args.n_step) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, 'sac') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -108,8 +112,9 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, stop_fn=stop_fn, + save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) @@ -118,7 +123,8 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") # here we define an imitation collector with a trivial policy policy.eval() @@ -139,8 +145,8 @@ def stop_fn(mean_rewards): train_collector.reset() result = offpolicy_trainer( il_policy, train_collector, il_test_collector, args.epoch, - args.step_per_epoch // 5, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.il_step_per_epoch, args.step_per_collect, 
args.test_num, + args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) @@ -149,7 +155,8 @@ def stop_fn(mean_rewards): il_policy.eval() collector = Collector(il_policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 32c0efd43..86331e993 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -7,20 +7,21 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import TD3Policy +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer from tianshou.exploration import GaussianNoise -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.continuous import Actor, Critic def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='Pendulum-v0') - parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--seed', type=int, default=1) parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--actor-lr', type=float, default=3e-4) + parser.add_argument('--actor-lr', type=float, default=1e-4) parser.add_argument('--critic-lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--tau', type=float, default=0.005) @@ -28,19 +29,18 @@ def get_args(): parser.add_argument('--policy-noise', type=float, default=0.2) parser.add_argument('--noise-clip', type=float, default=0.5) parser.add_argument('--update-actor-freq', type=int, default=2) - parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=2400) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--epoch', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=20000) + parser.add_argument('--step-per-collect', type=int, default=8) + parser.add_argument('--update-per-step', type=float, default=0.125) parser.add_argument('--batch-size', type=int, default=128) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128]) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) parser.add_argument('--training-num', type=int, default=8) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
- parser.add_argument('--rew-norm', type=int, default=1) - parser.add_argument('--ignore-done', type=int, default=1) - parser.add_argument('--n-step', type=int, default=1) + parser.add_argument('--rew-norm', action="store_true", default=False) + parser.add_argument('--n-step', type=int, default=3) parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') @@ -93,16 +93,18 @@ def test_td3(args=get_args()): update_actor_freq=args.update_actor_freq, noise_clip=args.noise_clip, reward_normalization=args.rew_norm, - ignore_done=args.ignore_done, estimation_step=args.n_step) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, 'td3') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -113,8 +115,9 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, + update_per_step=args.update_per_step, stop_fn=stop_fn, + save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) @@ -123,7 +126,8 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 08759f92e..882cb440a 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -6,9 +6,10 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.discrete import Actor, Critic from tianshou.policy import A2CPolicy, ImitationPolicy from tianshou.trainer import onpolicy_trainer, offpolicy_trainer @@ -23,8 +24,11 @@ def get_args(): parser.add_argument('--il-lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.9) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=50000) + parser.add_argument('--il-step-per-epoch', type=int, default=1000) + parser.add_argument('--episode-per-collect', type=int, default=8) + parser.add_argument('--step-per-collect', type=int, default=8) + parser.add_argument('--update-per-step', type=float, default=0.125) parser.add_argument('--repeat-per-collect', type=int, default=1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, @@ -43,7 
+47,7 @@ def get_args(): parser.add_argument('--ent-coef', type=float, default=0.0) parser.add_argument('--max-grad-norm', type=float, default=None) parser.add_argument('--gae-lambda', type=float, default=1.) - parser.add_argument('--rew-norm', type=bool, default=False) + parser.add_argument('--rew-norm', action="store_true", default=False) args = parser.parse_known_args()[0] return args @@ -79,11 +83,14 @@ def test_a2c_with_il(args=get_args()): max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'a2c') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -94,9 +101,9 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, - writer=writer) + args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, + logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) @@ -105,7 +112,8 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") policy.eval() # here we define an imitation collector with a trivial policy @@ -118,14 +126,13 @@ def stop_fn(mean_rewards): il_policy = ImitationPolicy(net, optim, mode='discrete') il_test_collector = Collector( il_policy, - DummyVectorEnv( - [lambda: gym.make(args.task) for _ in range(args.test_num)]) + DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)]) ) train_collector.reset() result = offpolicy_trainer( il_policy, train_collector, il_test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.il_step_per_epoch, args.step_per_collect, args.test_num, + args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) @@ -134,7 +141,8 @@ def stop_fn(mean_rewards): il_policy.eval() collector = Collector(il_policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 32a41d0df..1d0c4cc0a 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -7,10 +7,11 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import C51Policy +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer -from tianshou.data import 
Collector, ReplayBuffer, PrioritizedReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer def get_args(): @@ -28,8 +29,9 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=8000) + parser.add_argument('--step-per-collect', type=int, default=8) + parser.add_argument('--update-per-step', type=float, default=0.125) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) @@ -75,18 +77,20 @@ def test_c51(args=get_args()): ).to(args.device) # buffer if args.prioritized_replay: - buf = PrioritizedReplayBuffer( - args.buffer_size, alpha=args.alpha, beta=args.beta) + buf = PrioritizedVectorReplayBuffer( + args.buffer_size, buffer_num=len(train_envs), + alpha=args.alpha, beta=args.beta) else: - buf = ReplayBuffer(args.buffer_size) + buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector - train_collector = Collector(policy, train_envs, buf) - test_collector = Collector(policy, test_envs) + train_collector = Collector(policy, train_envs, buf, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size) + train_collector.collect(n_step=args.batch_size * args.training_num) # log log_path = os.path.join(args.logdir, args.task, 'c51') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -111,9 +115,9 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, + args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, + test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': @@ -124,7 +128,8 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") def test_pc51(args=get_args()): diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 1e9f08984..c59910e84 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -8,10 +8,11 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DQNPolicy +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer def get_args(): @@ -25,13 +26,14 @@ def get_args(): parser.add_argument('--gamma', 
type=float, default=0.9) parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) - parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--epoch', type=int, default=20) + parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) @@ -77,18 +79,20 @@ def test_dqn(args=get_args()): target_update_freq=args.target_update_freq) # buffer if args.prioritized_replay: - buf = PrioritizedReplayBuffer( - args.buffer_size, alpha=args.alpha, beta=args.beta) + buf = PrioritizedVectorReplayBuffer( + args.buffer_size, buffer_num=len(train_envs), + alpha=args.alpha, beta=args.beta) else: - buf = ReplayBuffer(args.buffer_size) + buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector - train_collector = Collector(policy, train_envs, buf) - test_collector = Collector(policy, test_envs) + train_collector = Collector(policy, train_envs, buf, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size) + train_collector.collect(n_step=args.batch_size * args.training_num) # log log_path = os.path.join(args.logdir, args.task, 'dqn') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -113,9 +117,9 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, + args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, + test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) @@ -127,10 +131,11 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") # save buffer in pickle format, for imitation learning unittest - buf = ReplayBuffer(args.buffer_size) + buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs)) collector = Collector(policy, test_envs, buf) collector.collect(n_step=args.buffer_size) pickle.dump(buf, open(args.save_buffer_name, "wb")) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index f3f00e69f..39bef8dbc 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -7,10 +7,11 @@ from torch.utils.tensorboard import 
SummaryWriter from tianshou.policy import DQNPolicy +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.trainer import offpolicy_trainer from tianshou.utils.net.common import Recurrent -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer def get_args(): @@ -23,14 +24,15 @@ def get_args(): parser.add_argument('--stack-num', type=int, default=4) parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.95) - parser.add_argument('--n-step', type=int, default=4) + parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) - parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) - parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--layer-num', type=int, default=3) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--epoch', type=int, default=5) + parser.add_argument('--step-per-epoch', type=int, default=20000) + parser.add_argument('--update-per-step', type=float, default=1 / 16) + parser.add_argument('--step-per-collect', type=int, default=16) + parser.add_argument('--batch-size', type=int, default=128) + parser.add_argument('--layer-num', type=int, default=2) + parser.add_argument('--training-num', type=int, default=16) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) @@ -65,16 +67,18 @@ def test_drqn(args=get_args()): net, optim, args.gamma, args.n_step, target_update_freq=args.target_update_freq) # collector - train_collector = Collector( - policy, train_envs, ReplayBuffer( - args.buffer_size, stack_num=args.stack_num, ignore_obs_next=True)) + buffer = VectorReplayBuffer( + args.buffer_size, buffer_num=len(train_envs), + stack_num=args.stack_num, ignore_obs_next=True) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) # the stack_num is for RNN training: sample framestack obs - test_collector = Collector(policy, test_envs) + test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size) + train_collector.collect(n_step=args.batch_size * args.training_num) # log log_path = os.path.join(args.logdir, args.task, 'drqn') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -91,9 +95,10 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.step_per_epoch, args.step_per_collect, args.test_num, + args.batch_size, update_per_step=args.update_per_step, + train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, + save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': @@ -103,7 +108,8 @@ def test_fn(epoch, env_step): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + rews, lens = 
result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py index e9e857ef1..996f7d599 100644 --- a/test/discrete/test_il_bcq.py +++ b/test/discrete/test_il_bcq.py @@ -8,6 +8,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offline_trainer @@ -26,7 +27,7 @@ def get_args(): parser.add_argument("--unlikely-action-threshold", type=float, default=0.3) parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) parser.add_argument("--epoch", type=int, default=5) - parser.add_argument("--step-per-epoch", type=int, default=1000) + parser.add_argument("--update-per-epoch", type=int, default=1000) parser.add_argument("--batch-size", type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128]) @@ -78,10 +79,11 @@ def test_discrete_bcq(args=get_args()): buffer = pickle.load(open(args.load_buffer_name, "rb")) # collector - test_collector = Collector(policy, test_envs) + test_collector = Collector(policy, test_envs, exploration_noise=True) log_path = os.path.join(args.logdir, args.task, 'discrete_bcq') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -91,8 +93,8 @@ def stop_fn(mean_rewards): result = offline_trainer( policy, buffer, test_collector, - args.epoch, args.step_per_epoch, args.test_num, args.batch_size, - stop_fn=stop_fn, save_fn=save_fn, writer=writer) + args.epoch, args.update_per_epoch, args.test_num, args.batch_size, + stop_fn=stop_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward']) @@ -104,7 +106,8 @@ def stop_fn(mean_rewards): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == "__main__": diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index b26213c59..6ebeb2686 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -7,10 +7,11 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import PGPolicy +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer def get_args(): @@ -21,8 +22,8 @@ def get_args(): parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.95) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=40000) + parser.add_argument('--episode-per-collect', type=int, default=8) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, @@ -65,11 +66,14 @@ def test_pg(args=get_args()): 
reward_normalization=args.rew_norm) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'pg') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -80,9 +84,9 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, - writer=writer) + args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, + logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) @@ -91,7 +95,8 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index e2e671c99..5821e7be8 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -7,10 +7,11 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import PPOPolicy +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.utils.net.discrete import Actor, Critic @@ -22,8 +23,8 @@ def get_args(): parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=2000) - parser.add_argument('--collect-per-step', type=int, default=20) + parser.add_argument('--step-per-epoch', type=int, default=50000) + parser.add_argument('--episode-per-collect', type=int, default=20) parser.add_argument('--repeat-per-collect', type=int, default=2) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, @@ -91,11 +92,14 @@ def test_ppo(args=get_args()): value_clip=args.value_clip) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, 'ppo') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -106,9 +110,9 @@ def stop_fn(mean_rewards): # trainer result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.repeat_per_collect, - args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, - writer=writer) + args.step_per_epoch, args.repeat_per_collect, args.test_num, 
args.batch_size, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, + logger=logger) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) @@ -117,7 +121,8 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 7020df275..2268b63de 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -6,17 +6,18 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter +from tianshou.utils import BasicLogger from tianshou.policy import QRDQNPolicy from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='CartPole-v0') - parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--seed', type=int, default=1) parser.add_argument('--eps-test', type=float, default=0.05) parser.add_argument('--eps-train', type=float, default=0.1) parser.add_argument('--buffer-size', type=int, default=20000) @@ -26,12 +27,13 @@ def get_args(): parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=10) - parser.add_argument('--step-per-epoch', type=int, default=1000) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
@@ -73,18 +75,20 @@ def test_qrdqn(args=get_args()): ).to(args.device) # buffer if args.prioritized_replay: - buf = PrioritizedReplayBuffer( - args.buffer_size, alpha=args.alpha, beta=args.beta) + buf = PrioritizedVectorReplayBuffer( + args.buffer_size, buffer_num=len(train_envs), + alpha=args.alpha, beta=args.beta) else: - buf = ReplayBuffer(args.buffer_size) + buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector - train_collector = Collector(policy, train_envs, buf) - test_collector = Collector(policy, test_envs) + train_collector = Collector(policy, train_envs, buf, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size) + train_collector.collect(n_step=args.batch_size * args.training_num) # log log_path = os.path.join(args.logdir, args.task, 'qrdqn') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -109,9 +113,10 @@ def test_fn(epoch, env_step): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer) + stop_fn=stop_fn, save_fn=save_fn, logger=logger, + update_per_step=args.update_per_step) assert stop_fn(result['best_reward']) if __name__ == '__main__': @@ -122,7 +127,8 @@ def test_fn(epoch, env_step): policy.set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") def test_pqrdqn(args=get_args()): diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 3d3df6f2c..ad594dbfc 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -6,12 +6,13 @@ import numpy as np from torch.utils.tensorboard import SummaryWriter +from tianshou.utils import BasicLogger from tianshou.env import SubprocVectorEnv from tianshou.utils.net.common import Net -from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer from tianshou.policy import DiscreteSACPolicy +from tianshou.trainer import offpolicy_trainer from tianshou.utils.net.discrete import Actor, Critic +from tianshou.data import Collector, VectorReplayBuffer def get_args(): @@ -19,25 +20,25 @@ def get_args(): parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--actor-lr', type=float, default=3e-4) + parser.add_argument('--actor-lr', type=float, default=1e-4) parser.add_argument('--critic-lr', type=float, default=1e-3) parser.add_argument('--alpha-lr', type=float, default=3e-4) parser.add_argument('--gamma', type=float, default=0.95) parser.add_argument('--tau', type=float, default=0.005) parser.add_argument('--alpha', type=float, default=0.05) - parser.add_argument('--auto_alpha', type=int, default=0) + parser.add_argument('--auto-alpha', action="store_true", default=False) parser.add_argument('--epoch', type=int, default=5) - parser.add_argument('--step-per-epoch', type=int, default=1000) - 
parser.add_argument('--collect-per-step', type=int, default=5) - parser.add_argument('--batch-size', type=int, default=64) - parser.add_argument('--hidden-sizes', type=int, - nargs='*', default=[128, 128]) - parser.add_argument('--training-num', type=int, default=16) + parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) + parser.add_argument('--batch-size', type=int, default=128) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.0) - parser.add_argument('--rew-norm', type=int, default=0) - parser.add_argument('--ignore-done', type=int, default=0) + parser.add_argument('--rew-norm', action="store_true", default=False) + parser.add_argument('--n-step', type=int, default=3) parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') @@ -85,17 +86,19 @@ def test_discrete_sac(args=get_args()): policy = DiscreteSACPolicy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - args.tau, args.gamma, args.alpha, - reward_normalization=args.rew_norm, - ignore_done=args.ignore_done) + args.tau, args.gamma, args.alpha, estimation_step=args.n_step, + reward_normalization=args.rew_norm) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, 'discrete_sac') writer = SummaryWriter(log_path) + logger = BasicLogger(writer) def save_fn(policy): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) @@ -106,9 +109,9 @@ def stop_fn(mean_rewards): # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer, - test_in_train=False) + args.step_per_epoch, args.step_per_collect, args.test_num, + args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger, + update_per_step=args.update_per_step, test_in_train=False) assert stop_fn(result['best_reward']) if __name__ == '__main__': pprint.pprint(result) @@ -117,7 +120,8 @@ def stop_fn(mean_rewards): policy.eval() collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") if __name__ == '__main__': diff --git a/test/modelbase/test_psrl.py b/test/modelbase/test_psrl.py index 29dfb6b8c..d89a7f4bc 100644 --- a/test/modelbase/test_psrl.py +++ b/test/modelbase/test_psrl.py @@ -1,3 +1,4 @@ +import os import gym import torch import pprint @@ -6,19 +7,20 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import PSRLPolicy +# from tianshou.utils import BasicLogger from tianshou.trainer import onpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from 
tianshou.env import DummyVectorEnv, SubprocVectorEnv def get_args(): parser = argparse.ArgumentParser() parser.add_argument('--task', type=str, default='NChain-v0') - parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--seed', type=int, default=1) parser.add_argument('--buffer-size', type=int, default=50000) parser.add_argument('--epoch', type=int, default=5) - parser.add_argument('--step-per-epoch', type=int, default=5) - parser.add_argument('--collect-per-step', type=int, default=1) + parser.add_argument('--step-per-epoch', type=int, default=1000) + parser.add_argument('--episode-per-collect', type=int, default=1) parser.add_argument('--training-num', type=int, default=1) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') @@ -27,7 +29,7 @@ def get_args(): parser.add_argument('--rew-std-prior', type=float, default=1.0) parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--eps', type=float, default=0.01) - parser.add_argument('--add-done-loop', action='store_true') + parser.add_argument('--add-done-loop', action="store_true", default=False) return parser.parse_known_args()[0] @@ -61,10 +63,15 @@ def test_psrl(args=get_args()): args.add_done_loop) # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # log - writer = SummaryWriter(args.logdir + '/' + args.task) + log_path = os.path.join(args.logdir, args.task, 'psrl') + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + # logger = BasicLogger(writer) def stop_fn(mean_rewards): if env.spec.reward_threshold: @@ -73,11 +80,12 @@ def stop_fn(mean_rewards): return False train_collector.collect(n_step=args.buffer_size, random=True) - # trainer + # trainer, test it without logger result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, 1, - args.test_num, 0, stop_fn=stop_fn, writer=writer, + args.step_per_epoch, 1, args.test_num, 0, + episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, + # logger=logger, test_in_train=False) if __name__ == '__main__': @@ -86,9 +94,9 @@ def stop_fn(mean_rewards): policy.eval() test_envs.seed(args.seed) test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, - render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + result = test_collector.collect(n_episode=args.test_num, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") elif env.spec.reward_threshold: assert result["best_reward"] >= env.spec.reward_threshold diff --git a/test/multiagent/Gomoku.py b/test/multiagent/Gomoku.py index 53652a2e4..4c88656cb 100644 --- a/test/multiagent/Gomoku.py +++ b/test/multiagent/Gomoku.py @@ -7,6 +7,7 @@ from tianshou.env import DummyVectorEnv from tianshou.data import Collector from tianshou.policy import RandomPolicy +from tianshou.utils import BasicLogger from tic_tac_toe_env import TicTacToeEnv from tic_tac_toe import get_parser, get_agents, train_agent, watch @@ -31,7 +32,8 @@ def gomoku(args=get_args()): # log log_path = os.path.join(args.logdir, 'Gomoku', 'dqn') - args.writer = SummaryWriter(log_path) + writer = SummaryWriter(log_path) + args.logger = BasicLogger(writer) opponent_pool = 
[agent_opponent] @@ -46,7 +48,7 @@ def env_func(): policy.replace_policy(opponent, 3 - args.agent_id) test_collector = Collector(policy, test_envs) results = test_collector.collect(n_episode=100) - rews.append(results['rew']) + rews.append(results['rews'].mean()) rews = np.array(rews) # weight opponent by their difficulty level rews = np.exp(-rews * 10.0) diff --git a/test/multiagent/test_tic_tac_toe.py b/test/multiagent/test_tic_tac_toe.py index 92ecb97c6..1cc06d374 100644 --- a/test/multiagent/test_tic_tac_toe.py +++ b/test/multiagent/test_tic_tac_toe.py @@ -1,10 +1,8 @@ import pprint -from tianshou.data import Collector from tic_tac_toe import get_args, train_agent, watch def test_tic_tac_toe(args=get_args()): - Collector._default_rew_metric = lambda x: x[args.agent_id - 1] if args.watch: watch(args) return diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index f9d6af104..3e92838ab 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -6,10 +6,11 @@ from typing import Optional, Tuple from torch.utils.tensorboard import SummaryWriter +from tianshou.utils import BasicLogger from tianshou.env import DummyVectorEnv from tianshou.utils.net.common import Net from tianshou.trainer import offpolicy_trainer -from tianshou.data import Collector, ReplayBuffer +from tianshou.data import Collector, VectorReplayBuffer from tianshou.policy import BasePolicy, DQNPolicy, RandomPolicy, \ MultiAgentPolicyManager @@ -28,29 +29,30 @@ def get_parser() -> argparse.ArgumentParser: parser.add_argument('--n-step', type=int, default=3) parser.add_argument('--target-update-freq', type=int, default=320) parser.add_argument('--epoch', type=int, default=20) - parser.add_argument('--step-per-epoch', type=int, default=500) - parser.add_argument('--collect-per-step', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=5000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]) - parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--training-num', type=int, default=10) parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.1) - parser.add_argument('--board_size', type=int, default=6) - parser.add_argument('--win_size', type=int, default=4) - parser.add_argument('--win_rate', type=float, default=0.9, + parser.add_argument('--board-size', type=int, default=6) + parser.add_argument('--win-size', type=int, default=4) + parser.add_argument('--win-rate', type=float, default=0.9, help='the expected winning rate') parser.add_argument('--watch', default=False, action='store_true', help='no training, ' 'watch the play of pre-trained models') - parser.add_argument('--agent_id', type=int, default=2, + parser.add_argument('--agent-id', type=int, default=2, help='the learned agent plays as the' - ' agent_id-th player. choices are 1 and 2.') - parser.add_argument('--resume_path', type=str, default='', + ' agent_id-th player. 
Choices are 1 and 2.') + parser.add_argument('--resume-path', type=str, default='', help='the path of agent pth file ' 'for resuming from a pre-trained agent') - parser.add_argument('--opponent_path', type=str, default='', + parser.add_argument('--opponent-path', type=str, default='', help='the path of opponent agent pth file ' 'for resuming from a pre-trained agent') parser.add_argument( @@ -61,8 +63,7 @@ def get_parser() -> argparse.ArgumentParser: def get_args() -> argparse.Namespace: parser = get_parser() - args = parser.parse_known_args()[0] - return args + return parser.parse_known_args()[0] def get_agents( @@ -124,17 +125,17 @@ def env_func(): # collector train_collector = Collector( - policy, train_envs, ReplayBuffer(args.buffer_size)) + policy, train_envs, + VectorReplayBuffer(args.buffer_size, len(train_envs)), + exploration_noise=True) test_collector = Collector(policy, test_envs) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size) + train_collector.collect(n_step=args.batch_size * args.training_num) # log - if not hasattr(args, 'writer'): - log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn') - writer = SummaryWriter(log_path) - args.writer = writer - else: - writer = args.writer + log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn') + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer) def save_fn(policy): if hasattr(args, 'model_save_path'): @@ -155,13 +156,16 @@ def train_fn(epoch, env_step): def test_fn(epoch, env_step): policy.policies[args.agent_id - 1].set_eps(args.eps_test) + def reward_metric(rews): + return rews[:, args.agent_id - 1] + # trainer result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, - args.step_per_epoch, args.collect_per_step, args.test_num, + args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size, train_fn=train_fn, test_fn=test_fn, - stop_fn=stop_fn, save_fn=save_fn, writer=writer, - test_in_train=False) + stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step, + logger=logger, test_in_train=False, reward_metric=reward_metric) return result, policy.policies[args.agent_id - 1] @@ -178,4 +182,5 @@ def watch( policy.policies[args.agent_id - 1].set_eps(args.eps_test) collector = Collector(policy, env) result = collector.collect(n_episode=1, render=args.render) - print(f'Final reward: {result["rew"]}, length: {result["len"]}') + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}") diff --git a/test/throughput/env.py b/test/throughput/env.py new file mode 120000 index 000000000..9a57534db --- /dev/null +++ b/test/throughput/env.py @@ -0,0 +1 @@ +../base/env.py \ No newline at end of file diff --git a/test/throughput/test_buffer_profile.py b/test/throughput/test_buffer_profile.py index 3134004f1..40ce68889 100644 --- a/test/throughput/test_buffer_profile.py +++ b/test/throughput/test_buffer_profile.py @@ -1,89 +1,61 @@ -import pytest +import sys +import gym +import time +import tqdm import numpy as np - -from tianshou.data import (ListReplayBuffer, PrioritizedReplayBuffer, - ReplayBuffer, SegmentTree) - - -@pytest.fixture(scope="module") -def data(): - np.random.seed(0) - obs = {'observable': np.random.rand(100, 100), - 'hidden': np.random.randint(1000, size=200)} - info = {'policy': "dqn", 'base': np.arange(10)} - add_data = {'obs': obs, 'rew': 1., 'act': np.random.rand(30), - 'done': False, 'obs_next': obs, 'info': info} - buffer = 
ReplayBuffer(int(1e3), stack_num=100) - buffer2 = ReplayBuffer(int(1e4), stack_num=100) - indexes = np.random.choice(int(1e3), size=3, replace=False) - return { - 'add_data': add_data, - 'buffer': buffer, - 'buffer2': buffer2, - 'slice': slice(-3000, -1000, 2), - 'indexes': indexes, - } - - -def test_init(): - for _ in np.arange(1e5): - _ = ReplayBuffer(1e5) - _ = PrioritizedReplayBuffer(size=int(1e5), alpha=0.5, beta=0.5) - _ = ListReplayBuffer() - - -def test_add(data): - buffer = data['buffer'] - for _ in np.arange(1e5): - buffer.add(**data['add_data']) - - -def test_update(data): - buffer = data['buffer'] - buffer2 = data['buffer2'] - for _ in np.arange(1e2): - buffer2.update(buffer) - - -def test_getitem_slice(data): - Slice = data['slice'] - buffer = data['buffer'] - for _ in np.arange(1e3): - _ = buffer[Slice] - - -def test_getitem_indexes(data): - indexes = data['indexes'] - buffer = data['buffer'] - for _ in np.arange(1e2): - _ = buffer[indexes] - - -def test_get(data): - indexes = data['indexes'] - buffer = data['buffer'] - for _ in np.arange(3e2): - buffer.get(indexes, 'obs') - buffer.get(indexes, 'rew') - buffer.get(indexes, 'done') - buffer.get(indexes, 'info') - - -def test_sample(data): - buffer = data['buffer'] - for _ in np.arange(1e1): - buffer.sample(int(1e2)) - - -def test_segtree(data): - size = 100000 - tree = SegmentTree(size) - tree[np.arange(size)] = np.random.rand(size) - - for i in np.arange(1e5): - scalar = np.random.rand(64) * tree.reduce() - tree.get_prefix_sum_idx(scalar) +from tianshou.data import Batch, ReplayBuffer, VectorReplayBuffer + + +def test_replaybuffer(task="Pendulum-v0"): + total_count = 5 + for _ in tqdm.trange(total_count, desc="ReplayBuffer"): + env = gym.make(task) + buf = ReplayBuffer(10000) + obs = env.reset() + for i in range(100000): + act = env.action_space.sample() + obs_next, rew, done, info = env.step(act) + batch = Batch( + obs=np.array([obs]), + act=np.array([act]), + rew=np.array([rew]), + done=np.array([done]), + obs_next=np.array([obs_next]), + info=np.array([info]), + ) + buf.add(batch, buffer_ids=[0]) + obs = obs_next + if done: + obs = env.reset() + + +def test_vectorbuffer(task="Pendulum-v0"): + total_count = 5 + for _ in tqdm.trange(total_count, desc="VectorReplayBuffer"): + env = gym.make(task) + buf = VectorReplayBuffer(total_size=10000, buffer_num=1) + obs = env.reset() + for i in range(100000): + act = env.action_space.sample() + obs_next, rew, done, info = env.step(act) + batch = Batch( + obs=np.array([obs]), + act=np.array([act]), + rew=np.array([rew]), + done=np.array([done]), + obs_next=np.array([obs_next]), + info=np.array([info]), + ) + buf.add(batch) + obs = obs_next + if done: + obs = env.reset() if __name__ == '__main__': - pytest.main(["-s", "-k buffer_profile", "--durations=0", "-v"]) + t0 = time.time() + test_replaybuffer(sys.argv[-1]) + print("test replaybuffer: ", time.time() - t0) + t0 = time.time() + test_vectorbuffer(sys.argv[-1]) + print("test vectorbuffer: ", time.time() - t0) diff --git a/test/throughput/test_collector_profile.py b/test/throughput/test_collector_profile.py index 4036472f7..6242e694b 100644 --- a/test/throughput/test_collector_profile.py +++ b/test/throughput/test_collector_profile.py @@ -1,144 +1,109 @@ -import gym +import tqdm import numpy as np -import pytest -from gym.spaces.discrete import Discrete -from gym.utils import seeding -from tianshou.data import Batch, Collector, ReplayBuffer -from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import 
BasePolicy +from tianshou.env import DummyVectorEnv, SubprocVectorEnv +from tianshou.data import Batch, Collector, AsyncCollector, VectorReplayBuffer - -class SimpleEnv(gym.Env): - """A simplest example of self-defined env, used to minimize - data collect time and profile collector.""" - - def __init__(self): - self.action_space = Discrete(200) - self._fake_data = np.ones((10, 10, 1)) - self.seed(0) - self.reset() - - def reset(self): - self._index = 0 - self.done = np.random.randint(3, high=200) - return {'observable': np.zeros((10, 10, 1)), 'hidden': self._index} - - def step(self, action): - if self._index == self.done: - raise ValueError('step after done !!!') - self._index += 1 - return {'observable': self._fake_data, 'hidden': self._index}, -1, \ - self._index == self.done, {} - - def seed(self, seed=None): - self.np_random, seed = seeding.np_random(seed) - return [seed] - - -class SimplePolicy(BasePolicy): - """A simplest example of self-defined policy, used - to minimize data collect time.""" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def learn(self, batch, **kwargs): - return super().learn(batch, **kwargs) - - def forward(self, batch, state=None, **kwargs): - return Batch(act=np.array([30] * len(batch)), state=None, logits=None) - - -@pytest.fixture(scope="module") -def data(): - np.random.seed(0) - env = SimpleEnv() - env.seed(0) - env_vec = DummyVectorEnv([lambda: SimpleEnv() for _ in range(100)]) - env_vec.seed(np.random.randint(1000, size=100).tolist()) - env_subproc = SubprocVectorEnv([lambda: SimpleEnv() for _ in range(8)]) - env_subproc.seed(np.random.randint(1000, size=100).tolist()) - env_subproc_init = SubprocVectorEnv( - [lambda: SimpleEnv() for _ in range(8)]) - env_subproc_init.seed(np.random.randint(1000, size=100).tolist()) - buffer = ReplayBuffer(50000) - policy = SimplePolicy() - collector = Collector(policy, env, ReplayBuffer(50000)) - collector_vec = Collector(policy, env_vec, ReplayBuffer(50000)) - collector_subproc = Collector(policy, env_subproc, ReplayBuffer(50000)) - return { - "env": env, - "env_vec": env_vec, - "env_subproc": env_subproc, - "env_subproc_init": env_subproc_init, - "policy": policy, - "buffer": buffer, - "collector": collector, - "collector_vec": collector_vec, - "collector_subproc": collector_subproc, - } - - -def test_init(data): - for _ in range(5000): - Collector(data["policy"], data["env"], data["buffer"]) - - -def test_reset(data): - for _ in range(5000): - data["collector"].reset() - - -def test_collect_st(data): - for _ in range(50): - data["collector"].collect(n_step=1000) - - -def test_collect_ep(data): - for _ in range(50): - data["collector"].collect(n_episode=10) - - -def test_init_vec_env(data): - for _ in range(5000): - Collector(data["policy"], data["env_vec"], data["buffer"]) - - -def test_reset_vec_env(data): - for _ in range(5000): - data["collector_vec"].reset() - - -def test_collect_vec_env_st(data): - for _ in range(50): - data["collector_vec"].collect(n_step=1000) - - -def test_collect_vec_env_ep(data): - for _ in range(50): - data["collector_vec"].collect(n_episode=10) - - -def test_init_subproc_env(data): - for _ in range(5000): - Collector(data["policy"], data["env_subproc_init"], data["buffer"]) - - -def test_reset_subproc_env(data): - for _ in range(5000): - data["collector_subproc"].reset() - - -def test_collect_subproc_env_st(data): - for _ in range(50): - data["collector_subproc"].collect(n_step=1000) - - -def test_collect_subproc_env_ep(data): - for _ in range(50): - 
data["collector_subproc"].collect(n_episode=10) +if __name__ == '__main__': + from env import MyTestEnv +else: # pytest + from test.base.env import MyTestEnv + + +class MyPolicy(BasePolicy): + def __init__(self, dict_state=False, need_state=True): + """ + :param bool dict_state: if the observation of the environment is a dict + :param bool need_state: if the policy needs the hidden state (for RNN) + """ + super().__init__() + self.dict_state = dict_state + self.need_state = need_state + + def forward(self, batch, state=None): + if self.need_state: + if state is None: + state = np.zeros((len(batch.obs), 2)) + else: + state += 1 + if self.dict_state: + return Batch(act=np.ones(len(batch.obs['index'])), state=state) + return Batch(act=np.ones(len(batch.obs)), state=state) + + def learn(self): + pass + + +def test_collector_nstep(): + policy = MyPolicy() + env_fns = [lambda x=i: MyTestEnv(size=x) for i in np.arange(2, 11)] + dum = DummyVectorEnv(env_fns) + num = len(env_fns) + c3 = Collector(policy, dum, + VectorReplayBuffer(total_size=40000, buffer_num=num)) + for i in tqdm.trange(1, 400, desc="test step collector n_step"): + c3.reset() + result = c3.collect(n_step=i * len(env_fns)) + assert result['n/st'] >= i + + +def test_collector_nepisode(): + policy = MyPolicy() + env_fns = [lambda x=i: MyTestEnv(size=x) for i in np.arange(2, 11)] + dum = DummyVectorEnv(env_fns) + num = len(env_fns) + c3 = Collector(policy, dum, + VectorReplayBuffer(total_size=40000, buffer_num=num)) + for i in tqdm.trange(1, 400, desc="test step collector n_episode"): + c3.reset() + result = c3.collect(n_episode=i) + assert result['n/ep'] == i + assert result['n/st'] == len(c3.buffer) + + +def test_asynccollector(): + env_lens = [2, 3, 4, 5] + env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) + for i in env_lens] + + venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) + policy = MyPolicy() + bufsize = 300 + c1 = AsyncCollector( + policy, venv, + VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4)) + ptr = [0, 0, 0, 0] + for n_episode in tqdm.trange(1, 100, desc="test async n_episode"): + result = c1.collect(n_episode=n_episode) + assert result["n/ep"] >= n_episode + # check buffer data, obs and obs_next, env_id + for i, count in enumerate( + np.bincount(result["lens"], minlength=6)[2:]): + env_len = i + 2 + total = env_len * count + indices = np.arange(ptr[i], ptr[i] + total) % bufsize + ptr[i] = (ptr[i] + total) % bufsize + seq = np.arange(env_len) + buf = c1.buffer.buffers[i] + assert np.all(buf.info.env_id[indices] == i) + assert np.all(buf.obs[indices].reshape(count, env_len) == seq) + assert np.all(buf.obs_next[indices].reshape( + count, env_len) == seq + 1) + # test async n_step, for now the buffer should be full of data + for n_step in tqdm.trange(1, 150, desc="test async n_step"): + result = c1.collect(n_step=n_step) + assert result["n/st"] >= n_step + for i in range(4): + env_len = i + 2 + seq = np.arange(env_len) + buf = c1.buffer.buffers[i] + assert np.all(buf.info.env_id == i) + assert np.all(buf.obs.reshape(-1, env_len) == seq) + assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1) if __name__ == '__main__': - pytest.main(["-s", "-k collector_profile", "--durations=0", "-v"]) + test_collector_nstep() + test_collector_nepisode() + test_asynccollector() diff --git a/tianshou/__init__.py b/tianshou/__init__.py index 380167b2c..689f50edf 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,7 +1,7 @@ from tianshou import data, env, utils, policy, trainer, 
exploration -__version__ = "0.3.2" +__version__ = "0.4.0" __all__ = [ "env", diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index e51b8d161..75e02a940 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -1,9 +1,14 @@ from tianshou.data.batch import Batch from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as from tianshou.data.utils.segtree import SegmentTree -from tianshou.data.buffer import ReplayBuffer, \ - ListReplayBuffer, PrioritizedReplayBuffer -from tianshou.data.collector import Collector +from tianshou.data.buffer.base import ReplayBuffer +from tianshou.data.buffer.prio import PrioritizedReplayBuffer +from tianshou.data.buffer.manager import ReplayBufferManager +from tianshou.data.buffer.manager import PrioritizedReplayBufferManager +from tianshou.data.buffer.vecbuf import VectorReplayBuffer +from tianshou.data.buffer.vecbuf import PrioritizedVectorReplayBuffer +from tianshou.data.buffer.cached import CachedReplayBuffer +from tianshou.data.collector import Collector, AsyncCollector __all__ = [ "Batch", @@ -12,7 +17,12 @@ "to_torch_as", "SegmentTree", "ReplayBuffer", - "ListReplayBuffer", "PrioritizedReplayBuffer", + "ReplayBufferManager", + "PrioritizedReplayBufferManager", + "VectorReplayBuffer", + "PrioritizedVectorReplayBuffer", + "CachedReplayBuffer", "Collector", + "AsyncCollector", ] diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 4f15622ab..a07ad67ed 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -5,8 +5,7 @@ from copy import deepcopy from numbers import Number from collections.abc import Collection -from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \ - Sequence +from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, Sequence def _is_batch_set(data: Any) -> bool: @@ -35,8 +34,8 @@ def _is_scalar(value: Any) -> bool: if isinstance(value, torch.Tensor): return value.numel() == 1 and not value.shape else: - value = np.asanyarray(value) - return value.size == 1 and not value.shape + # np.asanyarray will cause dead loop in some cases + return np.isscalar(value) def _is_number(value: Any) -> bool: @@ -48,10 +47,8 @@ def _is_number(value: Any) -> bool: def _to_array_with_correct_type(v: Any) -> np.ndarray: - if isinstance(v, np.ndarray) and issubclass( - v.dtype.type, (np.bool_, np.number) - ): # most often case - return v + if isinstance(v, np.ndarray) and issubclass(v.dtype.type, (np.bool_, np.number)): + return v # most often case # convert the value to np.ndarray # convert to np.object data type if neither bool nor number # raises an exception if array's elements are tensors themself @@ -66,9 +63,7 @@ def _to_array_with_correct_type(v: Any) -> np.ndarray: # array([{}, array({}, dtype=object)], dtype=object) if not v.shape: v = v.item(0) - elif any( - isinstance(e, (np.ndarray, torch.Tensor)) for e in v.reshape(-1) - ): + elif any(isinstance(e, (np.ndarray, torch.Tensor)) for e in v.reshape(-1)): raise ValueError("Numpy arrays of tensors are not supported yet.") return v @@ -78,20 +73,16 @@ def _create_value( ) -> Union["Batch", np.ndarray, torch.Tensor]: """Create empty place-holders accroding to inst's shape. - :param bool stack: whether to stack or to concatenate. E.g. if inst has - shape of (3, 5), size = 10, stack=True returns an np.ndarry with shape - of (10, 3, 5), otherwise (10, 5) + :param bool stack: whether to stack or to concatenate. E.g. 
if inst has shape of + (3, 5), size = 10, stack=True returns an np.ndarry with shape of (10, 3, 5), + otherwise (10, 5) """ has_shape = isinstance(inst, (np.ndarray, torch.Tensor)) is_scalar = _is_scalar(inst) if not stack and is_scalar: - # _create_value(Batch(a={}, b=[1, 2, 3]), 10, False) will fail here - if isinstance(inst, Batch) and inst.is_empty(recurse=True): - return inst - # should never hit since it has already checked in Batch.cat_ - # here we do not consider scalar types, following the behavior of numpy - # which does not support concatenation of zero-dimensional arrays - # (scalars) + # should never hit since it has already checked in Batch.cat_ , here we do not + # consider scalar types, following the behavior of numpy which does not support + # concatenation of zero-dimensional arrays (scalars) raise TypeError(f"cannot concatenate with {inst} which is scalar") if has_shape: shape = (size, *inst.shape) if stack else (size, *inst.shape[1:]) @@ -106,9 +97,7 @@ def _create_value( dtype=target_type ) elif isinstance(inst, torch.Tensor): - return torch.full( - shape, fill_value=0, device=inst.device, dtype=inst.dtype - ) + return torch.full(shape, fill_value=0, device=inst.device, dtype=inst.dtype) elif isinstance(inst, (dict, Batch)): zero_batch = Batch() for key, val in inst.items(): @@ -121,9 +110,8 @@ def _create_value( def _assert_type_keys(keys: Iterable[str]) -> None: - assert all( - isinstance(e, str) for e in keys - ), f"keys should all be string, but got {keys}" + assert all(isinstance(e, str) for e in keys), \ + f"keys should all be string, but got {keys}" def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]: @@ -158,6 +146,19 @@ def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]: return v +def _alloc_by_keys_diff( + meta: "Batch", batch: "Batch", size: int, stack: bool = True +) -> None: + for key in batch.keys(): + if key in meta.keys(): + if isinstance(meta[key], Batch) and isinstance(batch[key], Batch): + _alloc_by_keys_diff(meta[key], batch[key], size, stack) + elif isinstance(meta[key], Batch) and meta[key].is_empty(): + meta[key] = _create_value(batch[key], size, stack) + else: + meta[key] = _create_value(batch[key], size, stack) + + class Batch: """The internal data structure in Tianshou. @@ -440,9 +441,7 @@ def __cat( val, sum_lens[-1], stack=False) self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val - def cat_( - self, batches: Union["Batch", Sequence[Union[dict, "Batch"]]] - ) -> None: + def cat_(self, batches: Union["Batch", Sequence[Union[dict, "Batch"]]]) -> None: """Concatenate a list of (or one) Batch objects into current batch.""" if isinstance(batches, Batch): batches = [batches] @@ -498,9 +497,7 @@ def cat(batches: Sequence[Union[dict, "Batch"]]) -> "Batch": batch.cat_(batches) return batch - def stack_( - self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0 - ) -> None: + def stack_(self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> None: """Stack a list of Batch object into current batch.""" # check input format batch_list = [] @@ -564,9 +561,7 @@ def stack_( self.__dict__[k][i] = val @staticmethod - def stack( - batches: Sequence[Union[dict, "Batch"]], axis: int = 0 - ) -> "Batch": + def stack(batches: Sequence[Union[dict, "Batch"]], axis: int = 0) -> "Batch": """Stack a list of Batch object into a single new batch. 
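Most of the `batch.py` hunks above and below are line-length reflows around the `Batch.cat`/`Batch.stack` machinery that the new buffers rely on. A minimal sketch of how the two operations differ, assuming plain NumPy values (the shapes in the comments are illustrative, not taken from the patch):

```python
import numpy as np
from tianshou.data import Batch

b1 = Batch(a=np.array([1.0, 2.0]), b=np.array([3.0, 4.0]))
b2 = Batch(a=np.array([5.0, 6.0]), b=np.array([7.0, 8.0]))

stacked = Batch.stack([b1, b2])  # new leading axis: stacked.a has shape (2, 2)
catted = Batch.cat([b1, b2])     # concatenation:    catted.a has shape (4,)
print(stacked.a.shape, catted.a.shape)
```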
For keys that are not shared across all batches, batches that do not @@ -593,10 +588,7 @@ def stack( return batch def empty_( - self, - index: Union[ - str, slice, int, np.integer, np.ndarray, List[int] - ] = None, + self, index: Union[str, slice, int, np.integer, np.ndarray, List[int]] = None ) -> "Batch": """Return an empty Batch object with 0 or None filled. @@ -646,9 +638,7 @@ def empty_( @staticmethod def empty( batch: "Batch", - index: Union[ - str, slice, int, np.integer, np.ndarray, List[int] - ] = None, + index: Union[str, slice, int, np.integer, np.ndarray, List[int]] = None, ) -> "Batch": """Return an empty Batch object with 0 or None filled. @@ -674,9 +664,7 @@ def __len__(self) -> int: for v in self.__dict__.values(): if isinstance(v, Batch) and v.is_empty(recurse=True): continue - elif hasattr(v, "__len__") and ( - isinstance(v, Batch) or v.ndim > 0 - ): + elif hasattr(v, "__len__") and (isinstance(v, Batch) or v.ndim > 0): r.append(len(v)) else: raise TypeError(f"Object {v} in {self} has no len()") diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py deleted file mode 100644 index 1d36ce13a..000000000 --- a/tianshou/data/buffer.py +++ /dev/null @@ -1,534 +0,0 @@ -import h5py -import torch -import warnings -import numpy as np -from numbers import Number -from typing import Any, Dict, List, Tuple, Union, Optional - -from tianshou.data.batch import _create_value -from tianshou.data import Batch, SegmentTree, to_numpy -from tianshou.data.utils.converter import to_hdf5, from_hdf5 - - -class ReplayBuffer: - """:class:`~tianshou.data.ReplayBuffer` stores data generated from \ - interaction between the policy and environment. - - The current implementation of Tianshou typically use 7 reserved keys in - :class:`~tianshou.data.Batch`: - - * ``obs`` the observation of step :math:`t` ; - * ``act`` the action of step :math:`t` ; - * ``rew`` the reward of step :math:`t` ; - * ``done`` the done flag of step :math:`t` ; - * ``obs_next`` the observation of step :math:`t+1` ; - * ``info`` the info of step :math:`t` (in ``gym.Env``, the ``env.step()`` \ - function returns 4 arguments, and the last one is ``info``); - * ``policy`` the data computed by policy in step :math:`t`; - - The following code snippet illustrates its usage: - :: - - >>> import pickle, numpy as np - >>> from tianshou.data import ReplayBuffer - >>> buf = ReplayBuffer(size=20) - >>> for i in range(3): - ... buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={}) - >>> buf.obs - # since we set size = 20, len(buf.obs) == 20. - array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0.]) - >>> # but there are only three valid items, so len(buf) == 3. - >>> len(buf) - 3 - >>> # save to file "buf.pkl" - >>> pickle.dump(buf, open('buf.pkl', 'wb')) - >>> # save to HDF5 file - >>> buf.save_hdf5('buf.hdf5') - >>> buf2 = ReplayBuffer(size=10) - >>> for i in range(15): - ... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={}) - >>> len(buf2) - 10 - >>> buf2.obs - # since its size = 10, it only stores the last 10 steps' result. - array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.]) - - >>> # move buf2's result into buf (meanwhile keep it chronologically) - >>> buf.update(buf2) - array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., - 0., 0., 0., 0., 0., 0., 0.]) - - >>> # get a random sample from buffer - >>> # the batch_data is equal to buf[indice]. 
- >>> batch_data, indice = buf.sample(batch_size=4) - >>> batch_data.obs == buf[indice].obs - array([ True, True, True, True]) - >>> len(buf) - 13 - >>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl" - >>> len(buf) - 3 - >>> # load complete buffer from HDF5 file - >>> buf = ReplayBuffer.load_hdf5('buf.hdf5') - >>> len(buf) - 3 - >>> # load contents of HDF5 file into existing buffer - >>> # (only possible if size of buffer and data in file match) - >>> buf.load_contents_hdf5('buf.hdf5') - >>> len(buf) - 3 - - :class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling - (typically for RNN usage, see issue#19), ignoring storing the next - observation (save memory in atari tasks), and multi-modal observation (see - issue#38): - :: - - >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True) - >>> for i in range(16): - ... done = i % 5 == 0 - ... buf.add(obs={'id': i}, act=i, rew=i, done=done, - ... obs_next={'id': i + 1}) - >>> print(buf) # you can see obs_next is not saved in buf - ReplayBuffer( - act: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]), - done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]), - info: Batch(), - obs: Batch( - id: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]), - ), - policy: Batch(), - rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]), - ) - >>> index = np.arange(len(buf)) - >>> print(buf.get(index, 'obs').id) - [[ 7. 7. 8. 9.] - [ 7. 8. 9. 10.] - [11. 11. 11. 11.] - [11. 11. 11. 12.] - [11. 11. 12. 13.] - [11. 12. 13. 14.] - [12. 13. 14. 15.] - [ 7. 7. 7. 7.] - [ 7. 7. 7. 8.]] - >>> # here is another way to get the stacked data - >>> # (stack only for obs and obs_next) - >>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum() - 0.0 - >>> # we can get obs_next through __getitem__, even if it doesn't exist - >>> print(buf[:].obs_next.id) - [[ 7. 8. 9. 10.] - [ 7. 8. 9. 10.] - [11. 11. 11. 12.] - [11. 11. 12. 13.] - [11. 12. 13. 14.] - [12. 13. 14. 15.] - [12. 13. 14. 15.] - [ 7. 7. 7. 8.] - [ 7. 7. 8. 9.]] - - :param int size: the size of replay buffer. - :param int stack_num: the frame-stack sampling argument, should be greater - than or equal to 1, defaults to 1 (no stacking). - :param bool ignore_obs_next: whether to store obs_next, defaults to False. - :param bool save_only_last_obs: only save the last obs/obs_next when it has - a shape of (timestep, ...) because of temporal stacking, defaults to - False. - :param bool sample_avail: the parameter indicating sampling only available - index when using frame-stack sampling method, defaults to False. - This feature is not supported in Prioritized Replay Buffer currently. 
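The keyword-style `add(obs=..., act=...)` documented in the removed docstring above is what this patch replaces: the new `tianshou/data/buffer/base.py` further down takes a `Batch` plus `buffer_ids`. A rough sketch of the migration under the new signature, assuming a single-environment buffer (the values are placeholders):

```python
import numpy as np
from tianshou.data import Batch, ReplayBuffer

buf = ReplayBuffer(size=20)
# pre-0.4.0 (removed): buf.add(obs=i, act=i, rew=i, done=False, obs_next=i + 1, info={})
# 0.4.0: wrap one transition in a Batch whose first dimension is 1
for i in range(3):
    buf.add(Batch(obs=np.array([i]), act=np.array([i]), rew=np.array([1.0]),
                  done=np.array([False]), obs_next=np.array([i + 1]),
                  info=np.array([{}])), buffer_ids=[0])
print(len(buf))  # 3
```

The same pattern appears in the rewritten `test/throughput/test_buffer_profile.py` earlier in this diff.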
- """ - - def __init__( - self, - size: int, - stack_num: int = 1, - ignore_obs_next: bool = False, - save_only_last_obs: bool = False, - sample_avail: bool = False, - ) -> None: - super().__init__() - self._maxsize = size - self._indices = np.arange(size) - self.stack_num = stack_num - self._avail = sample_avail and stack_num > 1 - self._avail_index: List[int] = [] - self._save_s_ = not ignore_obs_next - self._last_obs = save_only_last_obs - self._index = 0 - self._size = 0 - self._meta: Batch = Batch() - self.reset() - - def __len__(self) -> int: - """Return len(self).""" - return self._size - - def __repr__(self) -> str: - """Return str(self).""" - return self.__class__.__name__ + self._meta.__repr__()[5:] - - def __getattr__(self, key: str) -> Any: - """Return self.key.""" - try: - return self._meta[key] - except KeyError as e: - raise AttributeError from e - - def __setstate__(self, state: Dict[str, Any]) -> None: - """Unpickling interface. - - We need it because pickling buffer does not work out-of-the-box - ("buffer.__getattr__" is customized). - """ - self._indices = np.arange(state["_maxsize"]) - self.__dict__.update(state) - - def __getstate__(self) -> dict: - exclude = {"_indices"} - state = {k: v for k, v in self.__dict__.items() if k not in exclude} - return state - - def _add_to_buffer(self, name: str, inst: Any) -> None: - try: - value = self._meta.__dict__[name] - except KeyError: - self._meta.__dict__[name] = _create_value(inst, self._maxsize) - value = self._meta.__dict__[name] - if isinstance(inst, (torch.Tensor, np.ndarray)): - if inst.shape != value.shape[1:]: - raise ValueError( - "Cannot add data to a buffer with different shape with key" - f" {name}, expect {value.shape[1:]}, given {inst.shape}." - ) - try: - value[self._index] = inst - except ValueError: - for key in set(inst.keys()).difference(value.__dict__.keys()): - value.__dict__[key] = _create_value(inst[key], self._maxsize) - value[self._index] = inst - - @property - def stack_num(self) -> int: - return self._stack - - @stack_num.setter - def stack_num(self, num: int) -> None: - assert num > 0, "stack_num should greater than 0" - self._stack = num - - def update(self, buffer: "ReplayBuffer") -> None: - """Move the data from the given buffer to self.""" - if len(buffer) == 0: - return - i = begin = buffer._index % len(buffer) - stack_num_orig = buffer.stack_num - buffer.stack_num = 1 - while True: - self.add(**buffer[i]) # type: ignore - i = (i + 1) % len(buffer) - if i == begin: - break - buffer.stack_num = stack_num_orig - - def add( - self, - obs: Any, - act: Any, - rew: Union[Number, np.number, np.ndarray], - done: Union[Number, np.number, np.bool_], - obs_next: Any = None, - info: Optional[Union[dict, Batch]] = {}, - policy: Optional[Union[dict, Batch]] = {}, - **kwargs: Any, - ) -> None: - """Add a batch of data into replay buffer.""" - assert isinstance( - info, (dict, Batch) - ), "You should return a dict in the last argument of env.step()." 
- if self._last_obs: - obs = obs[-1] - self._add_to_buffer("obs", obs) - self._add_to_buffer("act", act) - # make sure the reward is a float instead of an int - self._add_to_buffer("rew", rew * 1.0) # type: ignore - self._add_to_buffer("done", done) - if self._save_s_: - if obs_next is None: - obs_next = Batch() - elif self._last_obs: - obs_next = obs_next[-1] - self._add_to_buffer("obs_next", obs_next) - self._add_to_buffer("info", info) - self._add_to_buffer("policy", policy) - - # maintain available index for frame-stack sampling - if self._avail: - # update current frame - avail = sum(self.done[i] for i in range( - self._index - self.stack_num + 1, self._index)) == 0 - if self._size < self.stack_num - 1: - avail = False - if avail and self._index not in self._avail_index: - self._avail_index.append(self._index) - elif not avail and self._index in self._avail_index: - self._avail_index.remove(self._index) - # remove the later available frame because of broken storage - t = (self._index + self.stack_num - 1) % self._maxsize - if t in self._avail_index: - self._avail_index.remove(t) - - if self._maxsize > 0: - self._size = min(self._size + 1, self._maxsize) - self._index = (self._index + 1) % self._maxsize - else: - self._size = self._index = self._index + 1 - - def reset(self) -> None: - """Clear all the data in replay buffer.""" - self._index = 0 - self._size = 0 - self._avail_index = [] - - def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: - """Get a random sample from buffer with size equal to batch_size. - - Return all the data in the buffer if batch_size is 0. - - :return: Sample data and its corresponding index inside the buffer. - """ - if batch_size > 0: - _all = self._avail_index if self._avail else self._size - indice = np.random.choice(_all, batch_size) - else: - if self._avail: - indice = np.array(self._avail_index) - else: - indice = np.concatenate([ - np.arange(self._index, self._size), - np.arange(0, self._index), - ]) - assert len(indice) > 0, "No available indice can be sampled." - return self[indice], indice - - def get( - self, - indice: Union[slice, int, np.integer, np.ndarray], - key: str, - stack_num: Optional[int] = None, - ) -> Union[Batch, np.ndarray]: - """Return the stacked result. - - E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the - indice. The stack_num (here equals to 4) is given from buffer - initialization procedure. 
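Frame stacking (the `stack_num` behaviour described in the removed `get()` docstring just above) is preserved by the new `get()` in `buffer/base.py` below, now driven by `prev()` instead of manual index arithmetic. A small sketch under the new `Batch`-based `add()`; the printed value is approximate and assumes a fresh buffer:

```python
import numpy as np
from tianshou.data import Batch, ReplayBuffer

buf = ReplayBuffer(size=10, stack_num=4)
for i in range(6):
    buf.add(Batch(obs=np.array([i]), act=np.array([0]), rew=np.array([0.0]),
                  done=np.array([i == 5])), buffer_ids=[0])
# stacked observation at index 5: [obs[2], obs[3], obs[4], obs[5]]
print(buf.get(np.array([5]), "obs"))  # roughly [[2 3 4 5]]
```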
- """ - if stack_num is None: - stack_num = self.stack_num - if stack_num == 1: # the most often case - if key != "obs_next" or self._save_s_: - val = self._meta.__dict__[key] - try: - return val[indice] - except IndexError as e: - if not (isinstance(val, Batch) and val.is_empty()): - raise e # val != Batch() - return Batch() - indice = self._indices[:self._size][indice] - done = self._meta.__dict__["done"] - if key == "obs_next" and not self._save_s_: - indice += 1 - done[indice].astype(np.int) - indice[indice == self._size] = 0 - key = "obs" - val = self._meta.__dict__[key] - try: - if stack_num == 1: - return val[indice] - stack: List[Any] = [] - for _ in range(stack_num): - stack = [val[indice]] + stack - pre_indice = np.asarray(indice - 1) - pre_indice[pre_indice == -1] = self._size - 1 - indice = np.asarray( - pre_indice + done[pre_indice].astype(np.int)) - indice[indice == self._size] = 0 - if isinstance(val, Batch): - return Batch.stack(stack, axis=indice.ndim) - else: - return np.stack(stack, axis=indice.ndim) - except IndexError as e: - if not (isinstance(val, Batch) and val.is_empty()): - raise e # val != Batch() - return Batch() - - def __getitem__( - self, index: Union[slice, int, np.integer, np.ndarray] - ) -> Batch: - """Return a data batch: self[index]. - - If stack_num is larger than 1, return the stacked obs and obs_next with - shape (batch, len, ...). - """ - return Batch( - obs=self.get(index, "obs"), - act=self.act[index], - rew=self.rew[index], - done=self.done[index], - obs_next=self.get(index, "obs_next"), - info=self.get(index, "info"), - policy=self.get(index, "policy"), - ) - - def save_hdf5(self, path: str) -> None: - """Save replay buffer to HDF5 file.""" - with h5py.File(path, "w") as f: - to_hdf5(self.__getstate__(), f) - - @classmethod - def load_hdf5( - cls, path: str, device: Optional[str] = None - ) -> "ReplayBuffer": - """Load replay buffer from HDF5 file.""" - with h5py.File(path, "r") as f: - buf = cls.__new__(cls) - buf.__setstate__(from_hdf5(f, device=device)) - return buf - - -class ListReplayBuffer(ReplayBuffer): - """List-based replay buffer. - - The function of :class:`~tianshou.data.ListReplayBuffer` is almost the same - as :class:`~tianshou.data.ReplayBuffer`. The only difference is that - :class:`~tianshou.data.ListReplayBuffer` is based on list. Therefore, - it does not support advanced indexing, which means you cannot sample a - batch of data out of it. It is typically used for storing data. - - .. seealso:: - - Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed - explanation. - """ - - def __init__(self, **kwargs: Any) -> None: - super().__init__(size=0, ignore_obs_next=False, **kwargs) - warnings.warn("ListReplayBuffer will be removed in version 0.4.0.") - - def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: - raise NotImplementedError("ListReplayBuffer cannot be sampled!") - - def _add_to_buffer( - self, name: str, inst: Union[dict, Batch, np.ndarray, float, int, bool] - ) -> None: - if self._meta.__dict__.get(name) is None: - self._meta.__dict__[name] = [] - self._meta.__dict__[name].append(inst) - - def reset(self) -> None: - self._index = self._size = 0 - for k in list(self._meta.__dict__.keys()): - if isinstance(self._meta.__dict__[k], list): - self._meta.__dict__[k] = [] - - -class PrioritizedReplayBuffer(ReplayBuffer): - """Implementation of Prioritized Experience Replay. arXiv:1511.05952. - - :param float alpha: the prioritization exponent. - :param float beta: the importance sample soft coefficient. 
- - .. seealso:: - - Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed - explanation. - """ - - def __init__( - self, size: int, alpha: float, beta: float, **kwargs: Any - ) -> None: - super().__init__(size, **kwargs) - assert alpha > 0.0 and beta >= 0.0 - self._alpha, self._beta = alpha, beta - self._max_prio = self._min_prio = 1.0 - # save weight directly in this class instead of self._meta - self.weight = SegmentTree(size) - self.__eps = np.finfo(np.float32).eps.item() - - def add( - self, - obs: Any, - act: Any, - rew: Union[Number, np.number, np.ndarray], - done: Union[Number, np.number, np.bool_], - obs_next: Any = None, - info: Optional[Union[dict, Batch]] = {}, - policy: Optional[Union[dict, Batch]] = {}, - weight: Optional[Union[Number, np.number]] = None, - **kwargs: Any, - ) -> None: - """Add a batch of data into replay buffer.""" - if weight is None: - weight = self._max_prio - else: - weight = np.abs(weight) - self._max_prio = max(self._max_prio, weight) - self._min_prio = min(self._min_prio, weight) - self.weight[self._index] = weight ** self._alpha - super().add(obs, act, rew, done, obs_next, info, policy, **kwargs) - - def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: - """Get a random sample from buffer with priority probability. - - Return all the data in the buffer if batch_size is 0. - - :return: Sample data and its corresponding index inside the buffer. - - The "weight" in the returned Batch is the weight on loss function - to de-bias the sampling process (some transition tuples are sampled - more often so their losses are weighted less). - """ - assert self._size > 0, "Cannot sample a buffer with 0 size!" - if batch_size == 0: - indice = np.concatenate([ - np.arange(self._index, self._size), - np.arange(0, self._index), - ]) - else: - scalar = np.random.rand(batch_size) * self.weight.reduce() - indice = self.weight.get_prefix_sum_idx(scalar) - batch = self[indice] - # important sampling weight calculation - # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta)) - # simplified formula: (p_j/p_min)**(-beta) - batch.weight = (batch.weight / self._min_prio) ** (-self._beta) - return batch, indice - - def update_weight( - self, - indice: Union[np.ndarray], - new_weight: Union[np.ndarray, torch.Tensor] - ) -> None: - """Update priority weight by indice in this buffer. - - :param np.ndarray indice: indice you want to update weight. - :param np.ndarray new_weight: new priority weight you want to update. 
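`update_weight()` keeps the same role in the new `buffer/prio.py` at the end of this diff: after a gradient step, the absolute TD error is written back as the new priority. A hedged sketch of the usual sample / learn / re-prioritize loop; the TD errors below are random placeholders, not real learning signals:

```python
import numpy as np
from tianshou.data import Batch, PrioritizedReplayBuffer

buf = PrioritizedReplayBuffer(size=100, alpha=0.6, beta=0.4)
for i in range(10):
    buf.add(Batch(obs=np.array([i]), act=np.array([0]), rew=np.array([1.0]),
                  done=np.array([False])), buffer_ids=[0])

batch, indices = buf.sample(4)          # batch.weight carries the IS correction
td_error = np.random.rand(4)            # stand-in for the real |TD error|
buf.update_weight(indices, td_error)    # new priority ~ |td_error| ** alpha
```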
- """ - weight = np.abs(to_numpy(new_weight)) + self.__eps - self.weight[indice] = weight ** self._alpha - self._max_prio = max(self._max_prio, weight.max()) - self._min_prio = min(self._min_prio, weight.min()) - - def __getitem__( - self, index: Union[slice, int, np.integer, np.ndarray] - ) -> Batch: - return Batch( - obs=self.get(index, "obs"), - act=self.act[index], - rew=self.rew[index], - done=self.done[index], - obs_next=self.get(index, "obs_next"), - info=self.get(index, "info"), - policy=self.get(index, "policy"), - weight=self.weight[index], - ) diff --git a/examples/mujoco/runnable/mujoco/__init__.py b/tianshou/data/buffer/__init__.py similarity index 100% rename from examples/mujoco/runnable/mujoco/__init__.py rename to tianshou/data/buffer/__init__.py diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py new file mode 100644 index 000000000..2ee931ec8 --- /dev/null +++ b/tianshou/data/buffer/base.py @@ -0,0 +1,344 @@ +import h5py +import numpy as np +from typing import Any, Dict, List, Tuple, Union, Optional + +from tianshou.data import Batch +from tianshou.data.utils.converter import to_hdf5, from_hdf5 +from tianshou.data.batch import _create_value, _alloc_by_keys_diff + + +class ReplayBuffer: + """:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction \ + between the policy and environment. + + ReplayBuffer can be considered as a specialized form (or management) of Batch. It + stores all the data in a batch with circular-queue style. + + For the example usage of ReplayBuffer, please check out Section Buffer in + :doc:`/tutorials/concepts`. + + :param int size: the maximum size of replay buffer. + :param int stack_num: the frame-stack sampling argument, should be greater than or + equal to 1. Default to 1 (no stacking). + :param bool ignore_obs_next: whether to store obs_next. Default to False. + :param bool save_only_last_obs: only save the last obs/obs_next when it has a shape + of (timestep, ...) because of temporal stacking. Default to False. + :param bool sample_avail: the parameter indicating sampling only available index + when using frame-stack sampling method. Default to False. + """ + + _reserved_keys = ("obs", "act", "rew", "done", "obs_next", "info", "policy") + + def __init__( + self, + size: int, + stack_num: int = 1, + ignore_obs_next: bool = False, + save_only_last_obs: bool = False, + sample_avail: bool = False, + **kwargs: Any, # otherwise PrioritizedVectorReplayBuffer will cause TypeError + ) -> None: + self.options: Dict[str, Any] = { + "stack_num": stack_num, + "ignore_obs_next": ignore_obs_next, + "save_only_last_obs": save_only_last_obs, + "sample_avail": sample_avail, + } + super().__init__() + self.maxsize = size + assert stack_num > 0, "stack_num should be greater than 0" + self.stack_num = stack_num + self._indices = np.arange(size) + self._save_obs_next = not ignore_obs_next + self._save_only_last_obs = save_only_last_obs + self._sample_avail = sample_avail + self._meta: Batch = Batch() + self.reset() + + def __len__(self) -> int: + """Return len(self).""" + return self._size + + def __repr__(self) -> str: + """Return str(self).""" + return self.__class__.__name__ + self._meta.__repr__()[5:] + + def __getattr__(self, key: str) -> Any: + """Return self.key.""" + try: + return self._meta[key] + except KeyError as e: + raise AttributeError from e + + def __setstate__(self, state: Dict[str, Any]) -> None: + """Unpickling interface. 
+ + We need it because pickling buffer does not work out-of-the-box + ("buffer.__getattr__" is customized). + """ + self.__dict__.update(state) + + def __setattr__(self, key: str, value: Any) -> None: + """Set self.key = value.""" + assert ( + key not in self._reserved_keys + ), "key '{}' is reserved and cannot be assigned".format(key) + super().__setattr__(key, value) + + def save_hdf5(self, path: str) -> None: + """Save replay buffer to HDF5 file.""" + with h5py.File(path, "w") as f: + to_hdf5(self.__dict__, f) + + @classmethod + def load_hdf5(cls, path: str, device: Optional[str] = None) -> "ReplayBuffer": + """Load replay buffer from HDF5 file.""" + with h5py.File(path, "r") as f: + buf = cls.__new__(cls) + buf.__setstate__(from_hdf5(f, device=device)) + return buf + + def reset(self) -> None: + """Clear all the data in replay buffer and episode statistics.""" + self.last_index = np.array([0]) + self._index = self._size = 0 + self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, 0 + + def set_batch(self, batch: Batch) -> None: + """Manually choose the batch you want the ReplayBuffer to manage.""" + assert len(batch) == self.maxsize and set(batch.keys()).issubset( + self._reserved_keys + ), "Input batch doesn't meet ReplayBuffer's data form requirement." + self._meta = batch + + def unfinished_index(self) -> np.ndarray: + """Return the index of unfinished episode.""" + last = (self._index - 1) % self._size if self._size else 0 + return np.array([last] if not self.done[last] and self._size else [], np.int) + + def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: + """Return the index of previous transition. + + The index won't be modified if it is the beginning of an episode. + """ + index = (index - 1) % self._size + end_flag = self.done[index] | (index == self.last_index[0]) + return (index + end_flag) % self._size + + def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: + """Return the index of next transition. + + The index won't be modified if it is the end of an episode. + """ + end_flag = self.done[index] | (index == self.last_index[0]) + return (index + (1 - end_flag)) % self._size + + def update(self, buffer: "ReplayBuffer") -> np.ndarray: + """Move the data from the given buffer to current buffer. + + Return the updated indices. If update fails, return an empty array. + """ + if len(buffer) == 0 or self.maxsize == 0: + return np.array([], np.int) + stack_num, buffer.stack_num = buffer.stack_num, 1 + from_indices = buffer.sample_index(0) # get all available indices + buffer.stack_num = stack_num + if len(from_indices) == 0: + return np.array([], np.int) + to_indices = [] + for _ in range(len(from_indices)): + to_indices.append(self._index) + self.last_index[0] = self._index + self._index = (self._index + 1) % self.maxsize + self._size = min(self._size + 1, self.maxsize) + to_indices = np.array(to_indices) + if self._meta.is_empty(): + self._meta = _create_value( # type: ignore + buffer._meta, self.maxsize, stack=False) + self._meta[to_indices] = buffer._meta[from_indices] + return to_indices + + def _add_index( + self, rew: Union[float, np.ndarray], done: bool + ) -> Tuple[int, Union[float, np.ndarray], int, int]: + """Maintain the buffer's state after adding one data batch. + + Return (index_to_be_modified, episode_reward, episode_length, + episode_start_index). 
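`_add_index()` is what lets `add()` (defined right below) report per-episode statistics without a separate bookkeeping class: the running reward and length are flushed only on the step where `done` is True. A small sketch of the return values, assuming a single sub-buffer:

```python
import numpy as np
from tianshou.data import Batch, ReplayBuffer

buf = ReplayBuffer(size=10)
for i in range(3):
    ptr, ep_rew, ep_len, ep_idx = buf.add(
        Batch(obs=np.array([i]), act=np.array([0]), rew=np.array([1.0]),
              done=np.array([i == 2])), buffer_ids=[0])
# on the final (done) step: ptr == [2], ep_rew == [3.], ep_len == [3], ep_idx == [0]
print(ptr, ep_rew, ep_len, ep_idx)
```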
+ """ + self.last_index[0] = ptr = self._index + self._size = min(self._size + 1, self.maxsize) + self._index = (self._index + 1) % self.maxsize + + self._ep_rew += rew + self._ep_len += 1 + + if done: + result = ptr, self._ep_rew, self._ep_len, self._ep_idx + self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, self._index + return result + else: + return ptr, self._ep_rew * 0.0, 0, self._ep_idx + + def add( + self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Add a batch of data into replay buffer. + + :param Batch batch: the input data batch. Its keys must belong to the 7 + reserved keys, and "obs", "act", "rew", "done" is required. + :param buffer_ids: to make consistent with other buffer's add function; if it + is not None, we assume the input batch's first dimension is always 1. + + Return (current_index, episode_reward, episode_length, episode_start_index). If + the episode is not finished, the return value of episode_length and + episode_reward is 0. + """ + # preprocess batch + b = Batch() + for key in set(self._reserved_keys).intersection(batch.keys()): + b.__dict__[key] = batch[key] + batch = b + assert set(["obs", "act", "rew", "done"]).issubset(batch.keys()) + stacked_batch = buffer_ids is not None + if stacked_batch: + assert len(batch) == 1 + if self._save_only_last_obs: + batch.obs = batch.obs[:, -1] if stacked_batch else batch.obs[-1] + if not self._save_obs_next: + batch.pop("obs_next", None) + elif self._save_only_last_obs: + batch.obs_next = ( + batch.obs_next[:, -1] if stacked_batch else batch.obs_next[-1] + ) + # get ptr + if stacked_batch: + rew, done = batch.rew[0], batch.done[0] + else: + rew, done = batch.rew, batch.done + ptr, ep_rew, ep_len, ep_idx = list( + map(lambda x: np.array([x]), self._add_index(rew, done)) + ) + try: + self._meta[ptr] = batch + except ValueError: + stack = not stacked_batch + batch.rew = batch.rew.astype(np.float) + batch.done = batch.done.astype(np.bool_) + if self._meta.is_empty(): + self._meta = _create_value( # type: ignore + batch, self.maxsize, stack) + else: # dynamic key pops up in batch + _alloc_by_keys_diff(self._meta, batch, self.maxsize, stack) + self._meta[ptr] = batch + return ptr, ep_rew, ep_len, ep_idx + + def sample_index(self, batch_size: int) -> np.ndarray: + """Get a random sample of index with size = batch_size. + + Return all available indices in the buffer if batch_size is 0; return an empty + numpy array if batch_size < 0 or no available index can be sampled. + """ + if self.stack_num == 1 or not self._sample_avail: # most often case + if batch_size > 0: + return np.random.choice(self._size, batch_size) + elif batch_size == 0: # construct current available indices + return np.concatenate( + [np.arange(self._index, self._size), np.arange(self._index)] + ) + else: + return np.array([], np.int) + else: + if batch_size < 0: + return np.array([], np.int) + all_indices = prev_indices = np.concatenate( + [np.arange(self._index, self._size), np.arange(self._index)] + ) + for _ in range(self.stack_num - 2): + prev_indices = self.prev(prev_indices) + all_indices = all_indices[prev_indices != self.prev(prev_indices)] + if batch_size > 0: + return np.random.choice(all_indices, batch_size) + else: + return all_indices + + def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: + """Get a random sample from buffer with size = batch_size. + + Return all the data in the buffer if batch_size is 0. 
+ + :return: Sample data and its corresponding index inside the buffer. + """ + indices = self.sample_index(batch_size) + return self[indices], indices + + def get( + self, + index: Union[int, np.integer, np.ndarray], + key: str, + default_value: Optional[Any] = None, + stack_num: Optional[int] = None, + ) -> Union[Batch, np.ndarray]: + """Return the stacked result. + + E.g., if you set ``key = "obs", stack_num = 4, index = t``, it returns the + stacked result as ``[obs[t-3], obs[t-2], obs[t-1], obs[t]]``. + + :param index: the index for getting stacked data. + :param str key: the key to get, should be one of the reserved_keys. + :param default_value: if the given key's data is not found and default_value is + set, return this default_value. + :param int stack_num: Default to self.stack_num. + """ + if key not in self._meta and default_value is not None: + return default_value + val = self._meta[key] + if stack_num is None: + stack_num = self.stack_num + try: + if stack_num == 1: # the most often case + return val[index] + stack: List[Any] = [] + if isinstance(index, list): + indice = np.array(index) + else: + indice = index + for _ in range(stack_num): + stack = [val[indice]] + stack + indice = self.prev(indice) + if isinstance(val, Batch): + return Batch.stack(stack, axis=indice.ndim) + else: + return np.stack(stack, axis=indice.ndim) + except IndexError as e: + if not (isinstance(val, Batch) and val.is_empty()): + raise e # val != Batch() + return Batch() + + def __getitem__(self, index: Union[slice, int, np.integer, np.ndarray]) -> Batch: + """Return a data batch: self[index]. + + If stack_num is larger than 1, return the stacked obs and obs_next with shape + (batch, len, ...). + """ + if isinstance(index, slice): # change slice to np array + if index == slice(None): # buffer[:] will get all available data + index = self.sample_index(0) + else: + index = self._indices[:len(self)][index] + # raise KeyError first instead of AttributeError, + # to support np.array([ReplayBuffer()]) + obs = self.get(index, "obs") + if self._save_obs_next: + obs_next = self.get(index, "obs_next", Batch()) + else: + obs_next = self.get(self.next(index), "obs", Batch()) + return Batch( + obs=obs, + act=self.act[index], + rew=self.rew[index], + done=self.done[index], + obs_next=obs_next, + info=self.get(index, "info", Batch()), + policy=self.get(index, "policy", Batch()), + ) diff --git a/tianshou/data/buffer/cached.py b/tianshou/data/buffer/cached.py new file mode 100644 index 000000000..acbae6f9a --- /dev/null +++ b/tianshou/data/buffer/cached.py @@ -0,0 +1,81 @@ +import numpy as np +from typing import List, Tuple, Union, Optional + +from tianshou.data import Batch, ReplayBuffer, ReplayBufferManager + + +class CachedReplayBuffer(ReplayBufferManager): + """CachedReplayBuffer contains a given main buffer and n cached buffers, \ + ``cached_buffer_num * ReplayBuffer(size=max_episode_length)``. + + The memory layout is: ``| main_buffer | cached_buffers[0] | cached_buffers[1] | ... + | cached_buffers[cached_buffer_num - 1] |``. + + The data is first stored in cached buffers. When an episode is terminated, the data + will move to the main buffer and the corresponding cached buffer will be reset. + + :param ReplayBuffer main_buffer: the main buffer whose ``.update()`` function + behaves normally. + :param int cached_buffer_num: number of ReplayBuffer needs to be created for cached + buffer. + :param int max_episode_length: the maximum length of one episode, used in each + cached buffer's maxsize. + + .. 
seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. + """ + + def __init__( + self, + main_buffer: ReplayBuffer, + cached_buffer_num: int, + max_episode_length: int, + ) -> None: + assert cached_buffer_num > 0 and max_episode_length > 0 + assert type(main_buffer) == ReplayBuffer + kwargs = main_buffer.options + buffers = [main_buffer] + [ + ReplayBuffer(max_episode_length, **kwargs) + for _ in range(cached_buffer_num) + ] + super().__init__(buffer_list=buffers) + self.main_buffer = self.buffers[0] + self.cached_buffers = self.buffers[1:] + self.cached_buffer_num = cached_buffer_num + + def add( + self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Add a batch of data into CachedReplayBuffer. + + Each of the data's length (first dimension) must equal to the length of + buffer_ids. By default the buffer_ids is [0, 1, ..., cached_buffer_num - 1]. + + Return (current_index, episode_reward, episode_length, episode_start_index) + with each of the shape (len(buffer_ids), ...), where (current_index[i], + episode_reward[i], episode_length[i], episode_start_index[i]) refers to the + cached_buffer_ids[i]th cached buffer's corresponding episode result. + """ + if buffer_ids is None: + buffer_ids = np.arange(1, 1 + self.cached_buffer_num) + else: # make sure it is np.ndarray + buffer_ids = np.asarray(buffer_ids) + 1 + ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids=buffer_ids) + # find the terminated episode, move data from cached buf to main buf + updated_ptr, updated_ep_idx = [], [] + done = batch.done.astype(np.bool_) + for buffer_idx in buffer_ids[done]: + index = self.main_buffer.update(self.buffers[buffer_idx]) + if len(index) == 0: # unsuccessful move, replace with -1 + index = [-1] + updated_ep_idx.append(index[0]) + updated_ptr.append(index[-1]) + self.buffers[buffer_idx].reset() + self._lengths[0] = len(self.main_buffer) + self._lengths[buffer_idx] = 0 + self.last_index[0] = index[-1] + self.last_index[buffer_idx] = self._offset[buffer_idx] + ptr[done] = updated_ptr + ep_idx[done] = updated_ep_idx + return ptr, ep_rew, ep_len, ep_idx diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py new file mode 100644 index 000000000..ccd03eb98 --- /dev/null +++ b/tianshou/data/buffer/manager.py @@ -0,0 +1,232 @@ +import numpy as np +from numba import njit +from typing import List, Tuple, Union, Sequence, Optional + +from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer +from tianshou.data.batch import _create_value, _alloc_by_keys_diff + + +class ReplayBufferManager(ReplayBuffer): + """ReplayBufferManager contains a list of ReplayBuffer with exactly the same \ + configuration. + + These replay buffers have contiguous memory layout, and the storage space each + buffer has is a shallow copy of the topmost memory. + + :param buffer_list: a list of ReplayBuffer needed to be handled. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. 
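`VectorReplayBuffer` (imported from `buffer/vecbuf.py` in the `__init__.py` hunk earlier in this diff, not shown in this excerpt) is essentially this manager built from `buffer_num` equal-sized `ReplayBuffer`s, so each environment writes into its own contiguous slice of one flat `Batch`. A hedged sketch of the offset layout; the printed pointers assume a fresh buffer:

```python
import numpy as np
from tianshou.data import Batch, VectorReplayBuffer

# 4 sub-buffers of size 5 sharing one flat storage of size 20;
# transitions from env i land in the slice [5 * i, 5 * (i + 1))
buf = VectorReplayBuffer(total_size=20, buffer_num=4)
batch = Batch(obs=np.zeros((4, 3)), act=np.zeros(4), rew=np.ones(4),
              done=np.array([False, True, False, False]))
ptr, ep_rew, ep_len, ep_idx = buf.add(batch)  # one transition per sub-buffer
print(ptr)  # [ 0  5 10 15] -- each env writes at its own offset
```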
+ """ + + def __init__(self, buffer_list: List[ReplayBuffer]) -> None: + self.buffer_num = len(buffer_list) + self.buffers = np.array(buffer_list, dtype=np.object) + offset, size = [], 0 + buffer_type = type(self.buffers[0]) + kwargs = self.buffers[0].options + for buf in self.buffers: + assert buf._meta.is_empty() + assert isinstance(buf, buffer_type) and buf.options == kwargs + offset.append(size) + size += buf.maxsize + self._offset = np.array(offset) + self._extend_offset = np.array(offset + [size]) + self._lengths = np.zeros_like(offset) + super().__init__(size=size, **kwargs) + self._compile() + self._meta: Batch + + def _compile(self) -> None: + lens = last = index = np.array([0]) + offset = np.array([0, 1]) + done = np.array([False, False]) + _prev_index(index, offset, done, last, lens) + _next_index(index, offset, done, last, lens) + + def __len__(self) -> int: + return self._lengths.sum() + + def reset(self) -> None: + self.last_index = self._offset.copy() + self._lengths = np.zeros_like(self._offset) + for buf in self.buffers: + buf.reset() + + def _set_batch_for_children(self) -> None: + for offset, buf in zip(self._offset, self.buffers): + buf.set_batch(self._meta[offset:offset + buf.maxsize]) + + def set_batch(self, batch: Batch) -> None: + super().set_batch(batch) + self._set_batch_for_children() + + def unfinished_index(self) -> np.ndarray: + return np.concatenate([ + buf.unfinished_index() + offset + for offset, buf in zip(self._offset, self.buffers) + ]) + + def prev(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: + if isinstance(index, (list, np.ndarray)): + return _prev_index(np.asarray(index), self._extend_offset, + self.done, self.last_index, self._lengths) + else: + return _prev_index(np.array([index]), self._extend_offset, + self.done, self.last_index, self._lengths)[0] + + def next(self, index: Union[int, np.integer, np.ndarray]) -> np.ndarray: + if isinstance(index, (list, np.ndarray)): + return _next_index(np.asarray(index), self._extend_offset, + self.done, self.last_index, self._lengths) + else: + return _next_index(np.array([index]), self._extend_offset, + self.done, self.last_index, self._lengths)[0] + + def update(self, buffer: ReplayBuffer) -> np.ndarray: + """The ReplayBufferManager cannot be updated by any buffer.""" + raise NotImplementedError + + def add( + self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Add a batch of data into ReplayBufferManager. + + Each of the data's length (first dimension) must equal to the length of + buffer_ids. By default buffer_ids is [0, 1, ..., buffer_num - 1]. + + Return (current_index, episode_reward, episode_length, episode_start_index). If + the episode is not finished, the return value of episode_length and + episode_reward is 0. 
+ """ + # preprocess batch + b = Batch() + for key in set(self._reserved_keys).intersection(batch.keys()): + b.__dict__[key] = batch[key] + batch = b + assert set(["obs", "act", "rew", "done"]).issubset(batch.keys()) + if self._save_only_last_obs: + batch.obs = batch.obs[:, -1] + if not self._save_obs_next: + batch.pop("obs_next", None) + elif self._save_only_last_obs: + batch.obs_next = batch.obs_next[:, -1] + # get index + if buffer_ids is None: + buffer_ids = np.arange(self.buffer_num) + ptrs, ep_lens, ep_rews, ep_idxs = [], [], [], [] + for batch_idx, buffer_id in enumerate(buffer_ids): + ptr, ep_rew, ep_len, ep_idx = self.buffers[buffer_id]._add_index( + batch.rew[batch_idx], batch.done[batch_idx] + ) + ptrs.append(ptr + self._offset[buffer_id]) + ep_lens.append(ep_len) + ep_rews.append(ep_rew) + ep_idxs.append(ep_idx + self._offset[buffer_id]) + self.last_index[buffer_id] = ptr + self._offset[buffer_id] + self._lengths[buffer_id] = len(self.buffers[buffer_id]) + ptrs = np.array(ptrs) + try: + self._meta[ptrs] = batch + except ValueError: + batch.rew = batch.rew.astype(np.float) + batch.done = batch.done.astype(np.bool_) + if self._meta.is_empty(): + self._meta = _create_value( # type: ignore + batch, self.maxsize, stack=False) + else: # dynamic key pops up in batch + _alloc_by_keys_diff(self._meta, batch, self.maxsize, False) + self._set_batch_for_children() + self._meta[ptrs] = batch + return ptrs, np.array(ep_rews), np.array(ep_lens), np.array(ep_idxs) + + def sample_index(self, batch_size: int) -> np.ndarray: + if batch_size < 0: + return np.array([], np.int) + if self._sample_avail and self.stack_num > 1: + all_indices = np.concatenate([ + buf.sample_index(0) + offset + for offset, buf in zip(self._offset, self.buffers) + ]) + if batch_size == 0: + return all_indices + else: + return np.random.choice(all_indices, batch_size) + if batch_size == 0: # get all available indices + sample_num = np.zeros(self.buffer_num, np.int) + else: + buffer_idx = np.random.choice( + self.buffer_num, batch_size, p=self._lengths / self._lengths.sum() + ) + sample_num = np.bincount(buffer_idx, minlength=self.buffer_num) + # avoid batch_size > 0 and sample_num == 0 -> get child's all data + sample_num[sample_num == 0] = -1 + + return np.concatenate([ + buf.sample_index(bsz) + offset + for offset, buf, bsz in zip(self._offset, self.buffers, sample_num) + ]) + + +class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManager): + """PrioritizedReplayBufferManager contains a list of PrioritizedReplayBuffer with \ + exactly the same configuration. + + These replay buffers have contiguous memory layout, and the storage space each + buffer has is a shallow copy of the topmost memory. + + :param buffer_list: a list of PrioritizedReplayBuffer needed to be handled. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. 
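+
+    A construction sketch (sizes and alpha/beta values are illustrative
+    assumptions); in practice :class:`~tianshou.data.PrioritizedVectorReplayBuffer`
+    builds such a list for you from ``total_size`` and ``buffer_num``:
+    ::
+
+        from tianshou.data import (
+            PrioritizedReplayBuffer, PrioritizedReplayBufferManager)
+
+        buffers = [
+            PrioritizedReplayBuffer(500, alpha=0.6, beta=0.4) for _ in range(4)]
+        buf = PrioritizedReplayBufferManager(buffers)
+        assert buf.maxsize == 2000 and buf.buffer_num == 4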
+ """ + + def __init__(self, buffer_list: Sequence[PrioritizedReplayBuffer]) -> None: + ReplayBufferManager.__init__(self, buffer_list) # type: ignore + kwargs = buffer_list[0].options + for buf in buffer_list: + del buf.weight + PrioritizedReplayBuffer.__init__(self, self.maxsize, **kwargs) + + +@njit +def _prev_index( + index: np.ndarray, + offset: np.ndarray, + done: np.ndarray, + last_index: np.ndarray, + lengths: np.ndarray, +) -> np.ndarray: + index = index % offset[-1] + prev_index = np.zeros_like(index) + for start, end, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index): + mask = (start <= index) & (index < end) + cur_len = max(1, cur_len) + if np.sum(mask) > 0: + subind = index[mask] + subind = (subind - start - 1) % cur_len + end_flag = done[subind + start] | (subind + start == last) + prev_index[mask] = (subind + end_flag) % cur_len + start + return prev_index + + +@njit +def _next_index( + index: np.ndarray, + offset: np.ndarray, + done: np.ndarray, + last_index: np.ndarray, + lengths: np.ndarray, +) -> np.ndarray: + index = index % offset[-1] + next_index = np.zeros_like(index) + for start, end, cur_len, last in zip(offset[:-1], offset[1:], lengths, last_index): + mask = (start <= index) & (index < end) + cur_len = max(1, cur_len) + if np.sum(mask) > 0: + subind = index[mask] + end_flag = done[subind] | (subind == last) + next_index[mask] = (subind - start + 1 - end_flag) % cur_len + start + return next_index diff --git a/tianshou/data/buffer/prio.py b/tianshou/data/buffer/prio.py new file mode 100644 index 000000000..46c0be5e4 --- /dev/null +++ b/tianshou/data/buffer/prio.py @@ -0,0 +1,82 @@ +import torch +import numpy as np +from typing import Any, List, Tuple, Union, Optional + +from tianshou.data import Batch, SegmentTree, to_numpy, ReplayBuffer + + +class PrioritizedReplayBuffer(ReplayBuffer): + """Implementation of Prioritized Experience Replay. arXiv:1511.05952. + + :param float alpha: the prioritization exponent. + :param float beta: the importance sample soft coefficient. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. 
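+
+    A minimal usage sketch (the size, alpha/beta and TD-error values below are
+    illustrative assumptions):
+    ::
+
+        import numpy as np
+        from tianshou.data import Batch, PrioritizedReplayBuffer
+
+        buf = PrioritizedReplayBuffer(size=10, alpha=0.6, beta=0.4)
+        for i in range(3):
+            buf.add(Batch(obs=[i], act=[0], rew=[0.0], done=[False],
+                          obs_next=[i + 1]))
+        indices = buf.sample_index(2)  # sampled in proportion to priority**alpha
+        batch = buf[indices]           # batch.weight holds importance sampling weights
+        buf.update_weight(indices, np.array([0.5, 2.0]))  # new absolute TD errors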
+ """ + + def __init__(self, size: int, alpha: float, beta: float, **kwargs: Any) -> None: + # will raise KeyError in PrioritizedVectorReplayBuffer + # super().__init__(size, **kwargs) + ReplayBuffer.__init__(self, size, **kwargs) + assert alpha > 0.0 and beta >= 0.0 + self._alpha, self._beta = alpha, beta + self._max_prio = self._min_prio = 1.0 + # save weight directly in this class instead of self._meta + self.weight = SegmentTree(size) + self.__eps = np.finfo(np.float32).eps.item() + self.options.update(alpha=alpha, beta=beta) + + def init_weight(self, index: Union[int, np.ndarray]) -> None: + self.weight[index] = self._max_prio ** self._alpha + + def update(self, buffer: ReplayBuffer) -> np.ndarray: + indices = super().update(buffer) + self.init_weight(indices) + + def add( + self, batch: Batch, buffer_ids: Optional[Union[np.ndarray, List[int]]] = None + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids) + self.init_weight(ptr) + return ptr, ep_rew, ep_len, ep_idx + + def sample_index(self, batch_size: int) -> np.ndarray: + if batch_size > 0 and len(self) > 0: + scalar = np.random.rand(batch_size) * self.weight.reduce() + return self.weight.get_prefix_sum_idx(scalar) + else: + return super().sample_index(batch_size) + + def get_weight( + self, index: Union[slice, int, np.integer, np.ndarray] + ) -> np.ndarray: + """Get the importance sampling weight. + + The "weight" in the returned Batch is the weight on loss function to de-bias + the sampling process (some transition tuples are sampled more often so their + losses are weighted less). + """ + # important sampling weight calculation + # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta)) + # simplified formula: (p_j/p_min)**(-beta) + return (self.weight[index] / self._min_prio) ** (-self._beta) + + def update_weight( + self, index: np.ndarray, new_weight: Union[np.ndarray, torch.Tensor] + ) -> None: + """Update priority weight by index in this buffer. + + :param np.ndarray index: index you want to update weight. + :param np.ndarray new_weight: new priority weight you want to update. + """ + weight = np.abs(to_numpy(new_weight)) + self.__eps + self.weight[index] = weight ** self._alpha + self._max_prio = max(self._max_prio, weight.max()) + self._min_prio = min(self._min_prio, weight.min()) + + def __getitem__(self, index: Union[slice, int, np.integer, np.ndarray]) -> Batch: + batch = super().__getitem__(index) + batch.weight = self.get_weight(index) + return batch diff --git a/tianshou/data/buffer/vecbuf.py b/tianshou/data/buffer/vecbuf.py new file mode 100644 index 000000000..1cfeae9d6 --- /dev/null +++ b/tianshou/data/buffer/vecbuf.py @@ -0,0 +1,57 @@ +import numpy as np +from typing import Any + +from tianshou.data import ReplayBuffer, ReplayBufferManager +from tianshou.data import PrioritizedReplayBuffer, PrioritizedReplayBufferManager + + +class VectorReplayBuffer(ReplayBufferManager): + """VectorReplayBuffer contains n ReplayBuffer with the same size. + + It is used for storing transition from different environments yet keeping the order + of time. + + :param int total_size: the total size of VectorReplayBuffer. + :param int buffer_num: the number of ReplayBuffer it uses, which are under the same + configuration. + + Other input arguments (stack_num/ignore_obs_next/save_only_last_obs/sample_avail) + are the same as :class:`~tianshou.data.ReplayBuffer`. + + .. 
seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. + """ + + def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: + assert buffer_num > 0 + size = int(np.ceil(total_size / buffer_num)) + buffer_list = [ReplayBuffer(size, **kwargs) for _ in range(buffer_num)] + super().__init__(buffer_list) + + +class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager): + """PrioritizedVectorReplayBuffer contains n PrioritizedReplayBuffer with same size. + + It is used for storing transition from different environments yet keeping the order + of time. + + :param int total_size: the total size of PrioritizedVectorReplayBuffer. + :param int buffer_num: the number of PrioritizedReplayBuffer it uses, which are + under the same configuration. + + Other input arguments (alpha/beta/stack_num/ignore_obs_next/save_only_last_obs/ + sample_avail) are the same as :class:`~tianshou.data.PrioritizedReplayBuffer`. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. + """ + + def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None: + assert buffer_num > 0 + size = int(np.ceil(total_size / buffer_num)) + buffer_list = [ + PrioritizedReplayBuffer(size, **kwargs) for _ in range(buffer_num) + ] + super().__init__(buffer_list) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 25f268dc0..3a1b05d26 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -3,77 +3,47 @@ import torch import warnings import numpy as np -from copy import deepcopy -from numbers import Number -from typing import Dict, List, Union, Optional, Callable +from typing import Any, Dict, List, Union, Optional, Callable from tianshou.policy import BasePolicy -from tianshou.exploration import BaseNoise -from tianshou.data.batch import _create_value +from tianshou.data.batch import _alloc_by_keys_diff from tianshou.env import BaseVectorEnv, DummyVectorEnv -from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy +from tianshou.data import ( + Batch, + ReplayBuffer, + ReplayBufferManager, + VectorReplayBuffer, + CachedReplayBuffer, + to_numpy, +) class Collector(object): - """Collector enables the policy to interact with different types of envs. + """Collector enables the policy to interact with different types of envs with \ + exact number of steps or episodes. - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` - class. + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param env: a ``gym.Env`` environment or an instance of the :class:`~tianshou.env.BaseVectorEnv` class. - :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` - class. If set to ``None`` (testing phase), it will not store the data. - :param function preprocess_fn: a function called before the data has been - added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults - to None. - :param BaseNoise action_noise: add a noise to continuous action. Normally - a policy already has a noise param for exploration in training phase, - so this is recommended to use in test collector for some purpose. - :param function reward_metric: to be used in multi-agent RL. The reward to - report is of shape [agent_num], but we need to return a single scalar - to monitor training. This function specifies what is the desired - metric, e.g., the reward of agent 1 or the average reward over all - agents. 
By default, the behavior is to select the reward of agent 1. - - The ``preprocess_fn`` is a function called before the data has been added - to the buffer with batch format, which receives up to 7 keys as listed in - :class:`~tianshou.data.Batch`. It will receive with only ``obs`` when the - collector resets the environment. It returns either a dict or a - :class:`~tianshou.data.Batch` with the modified keys and values. Examples - are in "test/base/test_collector.py". - - Here is the example: - :: - - policy = PGPolicy(...) # or other policies if you wish - env = gym.make('CartPole-v0') - replay_buffer = ReplayBuffer(size=10000) - # here we set up a collector with a single environment - collector = Collector(policy, env, buffer=replay_buffer) - - # the collector supports vectorized environments as well - envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') - for _ in range(3)]) - collector = Collector(policy, envs, buffer=replay_buffer) - - # collect 3 episodes - collector.collect(n_episode=3) - # collect 1 episode for the first env, 3 for the third env - collector.collect(n_episode=[1, 0, 3]) - # collect at least 2 steps - collector.collect(n_step=2) - # collect episodes with visual rendering (the render argument is the - # sleep time between rendering consecutive frames) - collector.collect(n_episode=1, render=0.03) - - Collected data always consist of full episodes. So if only ``n_step`` - argument is give, the collector may return the data more than the - ``n_step`` limitation. Same as ``n_episode`` for the multiple environment - case. + :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. + If set to None, it will not store the data. Default to None. + :param function preprocess_fn: a function called before the data has been added to + the buffer, see issue #42 and :ref:`preprocess_fn`. Default to None. + :param bool exploration_noise: determine whether the action needs to be modified + with corresponding policy's exploration noise. If so, "policy. + exploration_noise(act, batch)" will be called automatically to add the + exploration noise into action. Default to False. + + The "preprocess_fn" is a function called before the data has been added to the + buffer with batch format. It will receive with only "obs" when the collector resets + the environment, and will receive four keys "obs_next", "rew", "done", "info" in a + normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with + the modified keys and values. Examples are in "test/base/test_collector.py". .. note:: - Please make sure the given environment has a time limitation. + Please make sure the given environment has a time limitation if using n_episode + collect option. 
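+
+    A usage sketch under the new buffer API (the environment, policy and sizes
+    below are illustrative assumptions):
+    ::
+
+        import gym
+        from tianshou.data import Collector, VectorReplayBuffer
+        from tianshou.env import DummyVectorEnv
+        from tianshou.policy import PGPolicy
+
+        policy = PGPolicy(...)  # or any other policy you wish
+        envs = DummyVectorEnv(
+            [lambda: gym.make("CartPole-v0") for _ in range(3)])
+        buf = VectorReplayBuffer(total_size=20000, buffer_num=3)
+        collector = Collector(policy, envs, buf, exploration_noise=True)
+
+        collector.collect(n_step=300)   # at least 300 steps (a multiple of 3 envs)
+        collector.collect(n_episode=9)  # exactly 9 episodes, no episode-length bias
+        collector.collect(n_episode=1, render=1 / 35)  # watch one episode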
""" def __init__( @@ -82,322 +52,449 @@ def __init__( env: Union[gym.Env, BaseVectorEnv], buffer: Optional[ReplayBuffer] = None, preprocess_fn: Optional[Callable[..., Batch]] = None, - action_noise: Optional[BaseNoise] = None, - reward_metric: Optional[Callable[[np.ndarray], float]] = None, + exploration_noise: bool = False, ) -> None: super().__init__() if not isinstance(env, BaseVectorEnv): env = DummyVectorEnv([lambda: env]) self.env = env self.env_num = len(env) - # environments that are available in step() - # this means all environments in synchronous simulation - # but only a subset of environments in asynchronous simulation - self._ready_env_ids = np.arange(self.env_num) - # self.async is a flag to indicate whether this collector works - # with asynchronous simulation - self.is_async = env.is_async - # need cache buffers before storing in the main buffer - self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)] - self.buffer = buffer + self.exploration_noise = exploration_noise + self._assign_buffer(buffer) self.policy = policy self.preprocess_fn = preprocess_fn - self.process_fn = policy.process_fn self._action_space = env.action_space - self._action_noise = action_noise - self._rew_metric = reward_metric or Collector._default_rew_metric # avoid creating attribute outside __init__ self.reset() - @staticmethod - def _default_rew_metric( - x: Union[Number, np.number] - ) -> Union[Number, np.number]: - # this internal function is designed for single-agent RL - # for multi-agent RL, a reward_metric must be provided - assert np.asanyarray(x).size == 1, ( - "Please specify the reward_metric " - "since the reward is not a scalar." - ) - return x + def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None: + """Check if the buffer matches the constraint.""" + if buffer is None: + buffer = VectorReplayBuffer(self.env_num, self.env_num) + elif isinstance(buffer, ReplayBufferManager): + assert buffer.buffer_num >= self.env_num + if isinstance(buffer, CachedReplayBuffer): + assert buffer.cached_buffer_num >= self.env_num + else: # ReplayBuffer or PrioritizedReplayBuffer + assert buffer.maxsize > 0 + if self.env_num > 1: + if type(buffer) == ReplayBuffer: + buffer_type = "ReplayBuffer" + vector_type = "VectorReplayBuffer" + else: + buffer_type = "PrioritizedReplayBuffer" + vector_type = "PrioritizedVectorReplayBuffer" + raise TypeError( + f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) to collect " + f"{self.env_num} envs,\n\tplease use {vector_type}(total_size=" + f"{buffer.maxsize}, buffer_num={self.env_num}, ...) instead." 
+ ) + self.buffer = buffer def reset(self) -> None: """Reset all related variables in the collector.""" - # use empty Batch for ``state`` so that ``self.data`` supports slicing + # use empty Batch for "state" so that self.data supports slicing # convert empty Batch to None when passing data to policy - self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={}, - obs_next={}, policy={}) + self.data = Batch(obs={}, act={}, rew={}, done={}, + obs_next={}, info={}, policy={}) self.reset_env() self.reset_buffer() self.reset_stat() - if self._action_noise is not None: - self._action_noise.reset() def reset_stat(self) -> None: """Reset the statistic variables.""" - self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0 + self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 def reset_buffer(self) -> None: - """Reset the main data buffer.""" - if self.buffer is not None: - self.buffer.reset() - - def get_env_num(self) -> int: - """Return the number of environments the collector have.""" - return self.env_num + """Reset the data buffer.""" + self.buffer.reset() def reset_env(self) -> None: - """Reset all of the environment(s)' states and the cache buffers.""" - self._ready_env_ids = np.arange(self.env_num) + """Reset all of the environments.""" obs = self.env.reset() if self.preprocess_fn: obs = self.preprocess_fn(obs=obs).get("obs", obs) self.data.obs = obs - for b in self._cached_buf: - b.reset() def _reset_state(self, id: Union[int, List[int]]) -> None: """Reset the hidden state: self.data.state[id].""" - state = self.data.state # it is a reference - if isinstance(state, torch.Tensor): - state[id].zero_() - elif isinstance(state, np.ndarray): - state[id] = None if state.dtype == np.object else 0 - elif isinstance(state, Batch): - state.empty_(id) + if hasattr(self.data.policy, "hidden_state"): + state = self.data.policy.hidden_state # it is a reference + if isinstance(state, torch.Tensor): + state[id].zero_() + elif isinstance(state, np.ndarray): + state[id] = None if state.dtype == np.object else 0 + elif isinstance(state, Batch): + state.empty_(id) def collect( self, n_step: Optional[int] = None, - n_episode: Optional[Union[int, List[int]]] = None, + n_episode: Optional[int] = None, random: bool = False, render: Optional[float] = None, no_grad: bool = True, - ) -> Dict[str, float]: + ) -> Dict[str, Any]: """Collect a specified number of step or episode. + To ensure unbiased sampling result with n_episode option, this function will + first collect ``n_episode - env_num`` episodes, then for the last ``env_num`` + episodes, they will be collected evenly from each env. + :param int n_step: how many steps you want to collect. - :param n_episode: how many episodes you want to collect. If it is an - int, it means to collect at lease ``n_episode`` episodes; if it is - a list, it means to collect exactly ``n_episode[i]`` episodes in - the i-th environment - :param bool random: whether to use random policy for collecting data, - defaults to False. - :param float render: the sleep time between rendering consecutive - frames, defaults to None (no rendering). - :param bool no_grad: whether to retain gradient in policy.forward, - defaults to True (no gradient retaining). + :param int n_episode: how many episodes you want to collect. + :param bool random: whether to use random policy for collecting data. Default + to False. + :param float render: the sleep time between rendering consecutive frames. + Default to None (no rendering). 
+ :param bool no_grad: whether to retain gradient in policy.forward(). Default to + True (no gradient retaining). .. note:: - One and only one collection number specification is permitted, - either ``n_step`` or ``n_episode``. + One and only one collection number specification is permitted, either + ``n_step`` or ``n_episode``. :return: A dict including the following keys - * ``n/ep`` the collected number of episodes. - * ``n/st`` the collected number of steps. - * ``v/st`` the speed of steps per second. - * ``v/ep`` the speed of episode per second. - * ``rew`` the mean reward over collected episodes. - * ``len`` the mean length over collected episodes. + * ``n/ep`` collected number of episodes. + * ``n/st`` collected number of steps. + * ``rews`` array of episode reward over collected episodes. + * ``lens`` array of episode length over collected episodes. + * ``idxs`` array of episode start index in buffer over collected episodes. """ - assert (n_step is not None and n_episode is None and n_step > 0) or ( - n_step is None and n_episode is not None and np.sum(n_episode) > 0 - ), "Only one of n_step or n_episode is allowed in Collector.collect, " - f"got n_step = {n_step}, n_episode = {n_episode}." + assert not self.env.is_async, "Please use AsyncCollector if using async venv." + if n_step is not None: + assert n_episode is None, ( + f"Only one of n_step or n_episode is allowed in Collector." + f"collect, got n_step={n_step}, n_episode={n_episode}." + ) + assert n_step > 0 + if not n_step % self.env_num == 0: + warnings.warn( + f"n_step={n_step} is not a multiple of #env ({self.env_num}), " + "which may cause extra transitions collected into the buffer." + ) + ready_env_ids = np.arange(self.env_num) + elif n_episode is not None: + assert n_episode > 0 + ready_env_ids = np.arange(min(self.env_num, n_episode)) + self.data = self.data[:min(self.env_num, n_episode)] + else: + raise TypeError("Please specify at least one (either n_step or n_episode) " + "in AsyncCollector.collect().") + start_time = time.time() + step_count = 0 - # episode of each environment - episode_count = np.zeros(self.env_num) - # If n_episode is a list, and some envs have collected the required - # number of episodes, these envs will be recorded in this list, and - # they will not be stepped. - finished_env_ids = [] - rewards = [] - whole_data = Batch() - if isinstance(n_episode, list): - assert len(n_episode) == self.get_env_num() - finished_env_ids = [ - i for i in self._ready_env_ids if n_episode[i] <= 0] - self._ready_env_ids = np.array( - [x for x in self._ready_env_ids if x not in finished_env_ids]) + episode_count = 0 + episode_rews = [] + episode_lens = [] + episode_start_indices = [] + while True: - if step_count >= 100000 and episode_count.sum() == 0: - warnings.warn( - "There are already many steps in an episode. 
" - "You should add a time limitation to your environment!", - Warning) - - is_async = self.is_async or len(finished_env_ids) > 0 - if is_async: - # self.data are the data for all environments in async - # simulation or some envs have finished, - # **only a subset of data are disposed**, - # so we store the whole data in ``whole_data``, let self.data - # to be the data available in ready environments, and finally - # set these back into all the data - whole_data = self.data - self.data = self.data[self._ready_env_ids] - - # restore the state and the input data - last_state = self.data.state - if isinstance(last_state, Batch) and last_state.is_empty(): - last_state = None - self.data.update(state=Batch(), obs_next=Batch(), policy=Batch()) - - # calculate the next action + assert len(self.data) == len(ready_env_ids) + # restore the state: if the last state is None, it won't store + last_state = self.data.policy.pop("hidden_state", None) + + # get the next action if random: - spaces = self._action_space - result = Batch( - act=[spaces[i].sample() for i in self._ready_env_ids]) + self.data.update( + act=[self._action_space[i].sample() for i in ready_env_ids]) else: if no_grad: with torch.no_grad(): # faster than retain_grad version + # self.data.obs will be used by agent to get result result = self.policy(self.data, last_state) else: result = self.policy(self.data, last_state) - - state = result.get("state", Batch()) - # convert None to Batch(), since None is reserved for 0-init - if state is None: - state = Batch() - self.data.update(state=state, policy=result.get("policy", Batch())) - # save hidden state to policy._state, in order to save into buffer - if not (isinstance(state, Batch) and state.is_empty()): - self.data.policy._state = self.data.state - - self.data.act = to_numpy(result.act) - if self._action_noise is not None: - assert isinstance(self.data.act, np.ndarray) - self.data.act += self._action_noise(self.data.act.shape) + # update state / act / policy into self.data + policy = result.get("policy", Batch()) + assert isinstance(policy, Batch) + state = result.get("state", None) + if state is not None: + policy.hidden_state = state # save state into buffer + act = to_numpy(result.act) + if self.exploration_noise: + act = self.policy.exploration_noise(act, self.data) + self.data.update(policy=policy, act=act) # step in env - if not is_async: - obs_next, rew, done, info = self.env.step(self.data.act) - else: - # store computed actions, states, etc - _batch_set_item( - whole_data, self._ready_env_ids, self.data, self.env_num) - # fetch finished data - obs_next, rew, done, info = self.env.step( - self.data.act, id=self._ready_env_ids) - self._ready_env_ids = np.array([i["env_id"] for i in info]) - # get the stepped data - self.data = whole_data[self._ready_env_ids] - # move data to self.data + obs_next, rew, done, info = self.env.step(self.data.act, id=ready_env_ids) + self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) + if self.preprocess_fn: + self.data.update(self.preprocess_fn( + obs_next=self.data.obs_next, + rew=self.data.rew, + done=self.data.done, + info=self.data.info, + )) if render: self.env.render() - time.sleep(render) + if render > 0 and not np.isclose(render, 0): + time.sleep(render) # add data into the buffer - if self.preprocess_fn: - result = self.preprocess_fn(**self.data) # type: ignore - self.data.update(result) - - for j, i in enumerate(self._ready_env_ids): - # j is the index in current ready_env_ids - # i is the index in all environments - if 
self.buffer is None: - # users do not want to store data, so we store - # small fake data here to make the code clean - self._cached_buf[i].add(obs=0, act=0, rew=rew[j], done=0) - else: - self._cached_buf[i].add(**self.data[j]) - - if done[j]: - if not (isinstance(n_episode, list) - and episode_count[i] >= n_episode[i]): - episode_count[i] += 1 - rewards.append(self._rew_metric( - np.sum(self._cached_buf[i].rew, axis=0))) - step_count += len(self._cached_buf[i]) - if self.buffer is not None: - self.buffer.update(self._cached_buf[i]) - if isinstance(n_episode, list) and \ - episode_count[i] >= n_episode[i]: - # env i has collected enough data, it has finished - finished_env_ids.append(i) - self._cached_buf[i].reset() - self._reset_state(j) - obs_next = self.data.obs_next - if sum(done): + ptr, ep_rew, ep_len, ep_idx = self.buffer.add( + self.data, buffer_ids=ready_env_ids) + + # collect statistics + step_count += len(ready_env_ids) + + if np.any(done): env_ind_local = np.where(done)[0] - env_ind_global = self._ready_env_ids[env_ind_local] + env_ind_global = ready_env_ids[env_ind_local] + episode_count += len(env_ind_local) + episode_lens.append(ep_len[env_ind_local]) + episode_rews.append(ep_rew[env_ind_local]) + episode_start_indices.append(ep_idx[env_ind_local]) + # now we copy obs_next to obs, but since there might be + # finished episodes, we have to reset finished envs first. obs_reset = self.env.reset(env_ind_global) if self.preprocess_fn: - obs_reset = self.preprocess_fn( - obs=obs_reset).get("obs", obs_reset) - obs_next[env_ind_local] = obs_reset - self.data.obs = obs_next - if is_async: - # set data back - whole_data = deepcopy(whole_data) # avoid reference in ListBuf - _batch_set_item( - whole_data, self._ready_env_ids, self.data, self.env_num) - # let self.data be the data in all environments again - self.data = whole_data - self._ready_env_ids = np.array( - [x for x in self._ready_env_ids if x not in finished_env_ids]) - if n_step: - if step_count >= n_step: - break - else: - if isinstance(n_episode, int) and \ - episode_count.sum() >= n_episode: - break - if isinstance(n_episode, list) and \ - (episode_count >= n_episode).all(): - break - - # finished envs are ready, and can be used for the next collection - self._ready_env_ids = np.array( - self._ready_env_ids.tolist() + finished_env_ids) - - # generate the statistics - episode_count = sum(episode_count) - duration = max(time.time() - start_time, 1e-9) + obs_reset = self.preprocess_fn(obs=obs_reset).get("obs", obs_reset) + self.data.obs_next[env_ind_local] = obs_reset + for i in env_ind_local: + self._reset_state(i) + + # remove surplus env id from ready_env_ids + # to avoid bias in selecting environments + if n_episode: + surplus_env_num = len(ready_env_ids) - (n_episode - episode_count) + if surplus_env_num > 0: + mask = np.ones_like(ready_env_ids, np.bool) + mask[env_ind_local[:surplus_env_num]] = False + ready_env_ids = ready_env_ids[mask] + self.data = self.data[mask] + + self.data.obs = self.data.obs_next + + if (n_step and step_count >= n_step) or \ + (n_episode and episode_count >= n_episode): + break + + # generate statistics self.collect_step += step_count self.collect_episode += episode_count - self.collect_time += duration + self.collect_time += max(time.time() - start_time, 1e-9) + + if n_episode: + self.data = Batch(obs={}, act={}, rew={}, done={}, + obs_next={}, info={}, policy={}) + self.reset_env() + + if episode_count > 0: + rews, lens, idxs = list(map( + np.concatenate, [episode_rews, episode_lens, 
episode_start_indices])) + else: + rews, lens, idxs = np.array([]), np.array([], np.int), np.array([], np.int) + return { "n/ep": episode_count, "n/st": step_count, - "v/st": step_count / duration, - "v/ep": episode_count / duration, - "rew": np.mean(rewards), - "rew_std": np.std(rewards), - "len": step_count / episode_count, + "rews": rews, + "lens": lens, + "idxs": idxs, } -def _batch_set_item( - source: Batch, indices: np.ndarray, target: Batch, size: int -) -> None: - # for any key chain k, there are four cases - # 1. source[k] is non-reserved, but target[k] does not exist or is reserved - # 2. source[k] does not exist or is reserved, but target[k] is non-reserved - # 3. both source[k] and target[k] are non-reserved - # 4. both source[k] and target[k] do not exist or are reserved, do nothing. - # A special case in case 4, if target[k] is reserved but source[k] does - # not exist, make source[k] reserved, too. - for k, vt in target.items(): - if not isinstance(vt, Batch) or not vt.is_empty(): - # target[k] is non-reserved - vs = source.get(k, Batch()) - if isinstance(vs, Batch): - if vs.is_empty(): - # case 2, use __dict__ to avoid many type checks - source.__dict__[k] = _create_value(vt[0], size) +class AsyncCollector(Collector): + """Async Collector handles async vector environment. + + The arguments are exactly the same as :class:`~tianshou.data.Collector`, please + refer to :class:`~tianshou.data.Collector` for more detailed explanation. + """ + + def __init__( + self, + policy: BasePolicy, + env: BaseVectorEnv, + buffer: Optional[ReplayBuffer] = None, + preprocess_fn: Optional[Callable[..., Batch]] = None, + exploration_noise: bool = False, + ) -> None: + assert env.is_async + super().__init__(policy, env, buffer, preprocess_fn, exploration_noise) + + def reset_env(self) -> None: + super().reset_env() + self._ready_env_ids = np.arange(self.env_num) + + def collect( + self, + n_step: Optional[int] = None, + n_episode: Optional[int] = None, + random: bool = False, + render: Optional[float] = None, + no_grad: bool = True, + ) -> Dict[str, Any]: + """Collect a specified number of step or episode with async env setting. + + This function doesn't collect exactly n_step or n_episode number of + transitions. Instead, in order to support async setting, it may collect more + than given n_step or n_episode transitions and save into buffer. + + :param int n_step: how many steps you want to collect. + :param int n_episode: how many episodes you want to collect. + :param bool random: whether to use random policy for collecting data. Default + to False. + :param float render: the sleep time between rendering consecutive frames. + Default to None (no rendering). + :param bool no_grad: whether to retain gradient in policy.forward(). Default to + True (no gradient retaining). + + .. note:: + + One and only one collection number specification is permitted, either + ``n_step`` or ``n_episode``. + + :return: A dict including the following keys + + * ``n/ep`` collected number of episodes. + * ``n/st`` collected number of steps. + * ``rews`` array of episode reward over collected episodes. + * ``lens`` array of episode length over collected episodes. + * ``idxs`` array of episode start index in buffer over collected episodes. + """ + # collect at least n_step or n_episode + if n_step is not None: + assert n_episode is None, ( + "Only one of n_step or n_episode is allowed in Collector." + f"collect, got n_step={n_step}, n_episode={n_episode}." 
+ ) + assert n_step > 0 + elif n_episode is not None: + assert n_episode > 0 + else: + raise TypeError("Please specify at least one (either n_step or n_episode) " + "in AsyncCollector.collect().") + warnings.warn("Using async setting may collect extra transitions into buffer.") + + ready_env_ids = self._ready_env_ids + + start_time = time.time() + + step_count = 0 + episode_count = 0 + episode_rews = [] + episode_lens = [] + episode_start_indices = [] + + while True: + whole_data = self.data + self.data = self.data[ready_env_ids] + assert len(whole_data) == self.env_num # major difference + # restore the state: if the last state is None, it won't store + last_state = self.data.policy.pop("hidden_state", None) + + # get the next action + if random: + self.data.update( + act=[self._action_space[i].sample() for i in ready_env_ids]) + else: + if no_grad: + with torch.no_grad(): # faster than retain_grad version + # self.data.obs will be used by agent to get result + result = self.policy(self.data, last_state) else: - assert isinstance(vt, Batch) - _batch_set_item(source.__dict__[k], indices, vt, size) + result = self.policy(self.data, last_state) + # update state / act / policy into self.data + policy = result.get("policy", Batch()) + assert isinstance(policy, Batch) + state = result.get("state", None) + if state is not None: + policy.hidden_state = state # save state into buffer + act = to_numpy(result.act) + if self.exploration_noise: + act = self.policy.exploration_noise(act, self.data) + self.data.update(policy=policy, act=act) + + # save act/policy before env.step + try: + whole_data.act[ready_env_ids] = self.data.act + whole_data.policy[ready_env_ids] = self.data.policy + except ValueError: + _alloc_by_keys_diff(whole_data, self.data, self.env_num, False) + whole_data[ready_env_ids] = self.data # lots of overhead + + # step in env + obs_next, rew, done, info = self.env.step(self.data.act, id=ready_env_ids) + + # change self.data here because ready_env_ids has changed + ready_env_ids = np.array([i["env_id"] for i in info]) + self.data = whole_data[ready_env_ids] + + self.data.update(obs_next=obs_next, rew=rew, done=done, info=info) + if self.preprocess_fn: + self.data.update(self.preprocess_fn( + obs_next=self.data.obs_next, + rew=self.data.rew, + done=self.data.done, + info=self.data.info, + )) + + if render: + self.env.render() + if render > 0 and not np.isclose(render, 0): + time.sleep(render) + + # add data into the buffer + ptr, ep_rew, ep_len, ep_idx = self.buffer.add( + self.data, buffer_ids=ready_env_ids) + + # collect statistics + step_count += len(ready_env_ids) + + if np.any(done): + env_ind_local = np.where(done)[0] + env_ind_global = ready_env_ids[env_ind_local] + episode_count += len(env_ind_local) + episode_lens.append(ep_len[env_ind_local]) + episode_rews.append(ep_rew[env_ind_local]) + episode_start_indices.append(ep_idx[env_ind_local]) + # now we copy obs_next to obs, but since there might be + # finished episodes, we have to reset finished envs first. 
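+                # the reset observations then overwrite obs_next for those envs, so
+                # the obs_next -> obs copy below starts their new episodes, while
+                # _reset_state clears any recurrent hidden state kept in data.policy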
+ obs_reset = self.env.reset(env_ind_global) + if self.preprocess_fn: + obs_reset = self.preprocess_fn(obs=obs_reset).get("obs", obs_reset) + self.data.obs_next[env_ind_local] = obs_reset + for i in env_ind_local: + self._reset_state(i) + + try: + whole_data.obs[ready_env_ids] = self.data.obs_next + whole_data.rew[ready_env_ids] = self.data.rew + whole_data.done[ready_env_ids] = self.data.done + whole_data.info[ready_env_ids] = self.data.info + except ValueError: + _alloc_by_keys_diff(whole_data, self.data, self.env_num, False) + self.data.obs = self.data.obs_next + whole_data[ready_env_ids] = self.data # lots of overhead + self.data = whole_data + + if (n_step and step_count >= n_step) or \ + (n_episode and episode_count >= n_episode): + break + + self._ready_env_ids = ready_env_ids + + # generate statistics + self.collect_step += step_count + self.collect_episode += episode_count + self.collect_time += max(time.time() - start_time, 1e-9) + + if episode_count > 0: + rews, lens, idxs = list(map( + np.concatenate, [episode_rews, episode_lens, episode_start_indices])) else: - # target[k] is reserved - # case 1 or special case of case 4 - if k not in source.__dict__: - source.__dict__[k] = Batch() - continue - source.__dict__[k][indices] = vt + rews, lens, idxs = np.array([]), np.array([], np.int), np.array([], np.int) + + return { + "n/ep": episode_count, + "n/st": step_count, + "rews": rews, + "lens": lens, + "idxs": idxs, + } diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index ddf7acaf2..36c613e5f 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -43,8 +43,7 @@ def seed(self, seed): Otherwise, the outputs of these envs may be the same with each other. - :param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith - env. + :param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith env. :param worker_fn: a callable worker, ``worker_fn(env_fns[i])`` generates a worker which contains the i-th env. :param int wait_num: use in asynchronous simulation if the time cost of @@ -75,13 +74,11 @@ def __init__( self.env_num = len(env_fns) self.wait_num = wait_num or len(env_fns) - assert ( - 1 <= self.wait_num <= len(env_fns) - ), f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}" + assert 1 <= self.wait_num <= len(env_fns), \ + f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}" self.timeout = timeout - assert ( - self.timeout is None or self.timeout > 0 - ), f"timeout is {timeout}, it should be positive if provided!" + assert self.timeout is None or self.timeout > 0, \ + f"timeout is {timeout}, it should be positive if provided!" self.is_async = self.wait_num != len(env_fns) or timeout is not None self.waiting_conn: List[EnvWorker] = [] # environments in self.ready_id is actually ready @@ -94,9 +91,8 @@ def __init__( self.is_closed = False def _assert_is_not_closed(self) -> None: - assert not self.is_closed, ( - f"Methods of {self.__class__.__name__} cannot be called after " - "close.") + assert not self.is_closed, \ + f"Methods of {self.__class__.__name__} cannot be called after close." def __len__(self) -> int: """Return len(self), which is the number of environments.""" @@ -106,9 +102,8 @@ def __getattribute__(self, key: str) -> Any: """Switch the attribute getter depending on the key. Any class who inherits ``gym.Env`` will inherit some attributes, like - ``action_space``. However, we would like the attribute lookup to go - straight into the worker (in fact, this vector env's action_space is - always None). 
+ ``action_space``. However, we would like the attribute lookup to go straight + into the worker (in fact, this vector env's action_space is always None). """ if key in ['metadata', 'reward_range', 'spec', 'action_space', 'observation_space']: # reserved keys in gym.Env @@ -119,9 +114,8 @@ def __getattribute__(self, key: str) -> Any: def __getattr__(self, key: str) -> List[Any]: """Fetch a list of env attributes. - This function tries to retrieve an attribute from each individual - wrapped environment, if it does not belong to the wrapping vector - environment class. + This function tries to retrieve an attribute from each individual wrapped + environment, if it does not belong to the wrapping vector environment class. """ return [getattr(worker, key) for worker in self.workers] @@ -136,12 +130,10 @@ def _wrap_id( def _assert_id(self, id: List[int]) -> None: for i in id: - assert ( - i not in self.waiting_id - ), f"Cannot interact with environment {i} which is stepping now." - assert ( - i in self.ready_id - ), f"Can only interact with ready environments {self.ready_id}." + assert i not in self.waiting_id, \ + f"Cannot interact with environment {i} which is stepping now." + assert i in self.ready_id, \ + f"Can only interact with ready environments {self.ready_id}." def reset( self, id: Optional[Union[int, List[int], np.ndarray]] = None @@ -178,8 +170,7 @@ def step( :return: A tuple including four items: - * ``obs`` a numpy.ndarray, the agent's observation of current \ - environments + * ``obs`` a numpy.ndarray, the agent's observation of current environments * ``rew`` a numpy.ndarray, the amount of rewards returned after \ previous actions * ``done`` a numpy.ndarray, whether these episodes have ended, in \ @@ -294,8 +285,7 @@ def __init__( wait_num: Optional[int] = None, timeout: Optional[float] = None, ) -> None: - super().__init__( - env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout) + super().__init__(env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout) class SubprocVectorEnv(BaseVectorEnv): @@ -316,8 +306,7 @@ def __init__( def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: return SubprocEnvWorker(fn, share_memory=False) - super().__init__( - env_fns, worker_fn, wait_num=wait_num, timeout=timeout) + super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout) class ShmemVectorEnv(BaseVectorEnv): @@ -340,8 +329,7 @@ def __init__( def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: return SubprocEnvWorker(fn, share_memory=True) - super().__init__( - env_fns, worker_fn, wait_num=wait_num, timeout=timeout) + super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout) class RayVectorEnv(BaseVectorEnv): @@ -369,5 +357,4 @@ def __init__( ) from e if not ray.is_initialized(): ray.init() - super().__init__( - env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout) + super().__init__(env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout) diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index f13fe37d1..d22d60b62 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -11,6 +11,7 @@ def __init__(self, env_fn: Callable[[], gym.Env]) -> None: self._env_fn = env_fn self.is_closed = False self.result: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] + self.action_space = getattr(self, "action_space") @abstractmethod def __getattr__(self, key: str) -> Any: @@ -51,9 +52,8 @@ def wait( """Given a list of workers, return those ready ones.""" raise NotImplementedError - @abstractmethod 
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: - pass + return self.action_space.seed(seed) # issue 299 @abstractmethod def render(self, **kwargs: Any) -> Any: diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py index 9e0c3c5c7..eafa690b1 100644 --- a/tianshou/env/worker/dummy.py +++ b/tianshou/env/worker/dummy.py @@ -9,8 +9,8 @@ class DummyEnvWorker(EnvWorker): """Dummy worker used in sequential vector environments.""" def __init__(self, env_fn: Callable[[], gym.Env]) -> None: - super().__init__(env_fn) self.env = env_fn() + super().__init__(env_fn) def __getattr__(self, key: str) -> Any: return getattr(self.env, key) @@ -30,13 +30,12 @@ def wait( # type: ignore def send_action(self, action: np.ndarray) -> None: self.result = self.env.step(action) - def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: - return self.env.seed(seed) if hasattr(self.env, "seed") else None + def seed(self, seed: Optional[int] = None) -> List[int]: + super().seed(seed) + return self.env.seed(seed) def render(self, **kwargs: Any) -> Any: - return ( - self.env.render(**kwargs) if hasattr(self.env, "render") else None - ) + return self.env.render(**kwargs) def close_env(self) -> None: self.env.close() diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py index 165e104a3..8139ed9d5 100644 --- a/tianshou/env/worker/ray.py +++ b/tianshou/env/worker/ray.py @@ -14,8 +14,8 @@ class RayEnvWorker(EnvWorker): """Ray worker used in RayVectorEnv.""" def __init__(self, env_fn: Callable[[], gym.Env]) -> None: - super().__init__(env_fn) self.env = ray.remote(gym.Wrapper).options(num_cpus=0).remote(env_fn()) + super().__init__(env_fn) def __getattr__(self, key: str) -> Any: return ray.get(self.env.__getattr__.remote(key)) @@ -30,28 +30,22 @@ def wait( # type: ignore timeout: Optional[float] = None, ) -> List["RayEnvWorker"]: results = [x.result for x in workers] - ready_results, _ = ray.wait( - results, num_returns=wait_num, timeout=timeout - ) + ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout) return [workers[results.index(result)] for result in ready_results] def send_action(self, action: np.ndarray) -> None: # self.action is actually a handle self.result = self.env.step.remote(action) - def get_result( - self, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: return ray.get(self.result) - def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: - if hasattr(self.env, "seed"): - return ray.get(self.env.seed.remote(seed)) - return None + def seed(self, seed: Optional[int] = None) -> List[int]: + super().seed(seed) + return ray.get(self.env.seed.remote(seed)) def render(self, **kwargs: Any) -> Any: - if hasattr(self.env, "render"): - return ray.get(self.env.render.remote(**kwargs)) + return ray.get(self.env.render.remote(**kwargs)) def close_env(self) -> None: ray.get(self.env.close.remote()) diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index 02acc21fb..822d65ccf 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -31,10 +31,7 @@ class ShArray: """Wrapper of multiprocessing Array.""" def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None: - self.arr = Array( - _NP_TO_CT[dtype.type], - int(np.prod(shape)), - ) + self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape))) self.dtype = dtype self.shape = shape @@ -124,7 +121,6 @@ class 
SubprocEnvWorker(EnvWorker): def __init__( self, env_fn: Callable[[], gym.Env], share_memory: bool = False ) -> None: - super().__init__(env_fn) self.parent_remote, self.child_remote = Pipe() self.share_memory = share_memory self.buffer: Optional[Union[dict, tuple, ShArray]] = None @@ -143,6 +139,7 @@ def __init__( self.process = Process(target=_worker, args=args, daemon=True) self.process.start() self.child_remote.close() + super().__init__(env_fn) def __getattr__(self, key: str) -> Any: self.parent_remote.send(["getattr", key]) @@ -185,25 +182,22 @@ def wait( # type: ignore if remain_time <= 0: break # connection.wait hangs if the list is empty - new_ready_conns = connection.wait( - remain_conns, timeout=remain_time) + new_ready_conns = connection.wait(remain_conns, timeout=remain_time) ready_conns.extend(new_ready_conns) # type: ignore - remain_conns = [ - conn for conn in remain_conns if conn not in ready_conns] + remain_conns = [conn for conn in remain_conns if conn not in ready_conns] return [workers[conns.index(con)] for con in ready_conns] def send_action(self, action: np.ndarray) -> None: self.parent_remote.send(["step", action]) - def get_result( - self, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + def get_result(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: obs, rew, done, info = self.parent_remote.recv() if self.share_memory: obs = self._decode_obs() return obs, rew, done, info def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: + super().seed(seed) self.parent_remote.send(["seed", seed]) return self.parent_remote.recv() diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 01b7019af..a3625ca3f 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -1,6 +1,5 @@ from tianshou.policy.base import BasePolicy from tianshou.policy.random import RandomPolicy -from tianshou.policy.imitation.base import ImitationPolicy from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.modelfree.c51 import C51Policy from tianshou.policy.modelfree.qrdqn import QRDQNPolicy @@ -11,6 +10,7 @@ from tianshou.policy.modelfree.td3 import TD3Policy from tianshou.policy.modelfree.sac import SACPolicy from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy +from tianshou.policy.imitation.base import ImitationPolicy from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy from tianshou.policy.modelbase.psrl import PSRLPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager @@ -19,7 +19,6 @@ __all__ = [ "BasePolicy", "RandomPolicy", - "ImitationPolicy", "DQNPolicy", "C51Policy", "QRDQNPolicy", @@ -30,6 +29,7 @@ "TD3Policy", "SACPolicy", "DiscreteSACPolicy", + "ImitationPolicy", "DiscreteBCQPolicy", "PSRLPolicy", "MultiAgentPolicyManager", diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 99e16544b..730ee28b0 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -4,7 +4,7 @@ from torch import nn from numba import njit from abc import ABC, abstractmethod -from typing import Any, List, Union, Mapping, Optional, Callable +from typing import Any, Dict, Union, Optional, Callable from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy @@ -67,6 +67,20 @@ def set_agent_id(self, agent_id: int) -> None: """Set self.agent_id = agent_id, for MARL.""" self.agent_id = agent_id + def exploration_noise( + self, act: Union[np.ndarray, Batch], batch: Batch + ) -> Union[np.ndarray, Batch]: + """Modify the action from 
policy.forward with exploration noise. + + :param act: a data batch or numpy.ndarray which is the action taken by + policy.forward. + :param batch: the input batch for policy.forward, kept for advanced usage. + + :return: action in the same form of input "act" but with added exploration + noise. + """ + return act + @abstractmethod def forward( self, @@ -76,8 +90,7 @@ def forward( ) -> Batch: """Compute action over the given batch data. - :return: A :class:`~tianshou.data.Batch` which MUST have the following\ - keys: + :return: A :class:`~tianshou.data.Batch` which MUST have the following keys: * ``act`` an numpy.ndarray or a torch.Tensor, the action over \ given batch data. @@ -106,18 +119,15 @@ def process_fn( ) -> Batch: """Pre-process the data from the provided replay buffer. - Used in :meth:`update`. Check out :ref:`process_fn` for more - information. + Used in :meth:`update`. Check out :ref:`process_fn` for more information. """ return batch @abstractmethod - def learn( - self, batch: Batch, **kwargs: Any - ) -> Mapping[str, Union[float, List[float]]]: + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]: """Update policy with a given batch of data. - :return: A dict which includes loss and its corresponding label. + :return: A dict, including the data needed to be logged (e.g., loss). .. note:: @@ -150,18 +160,20 @@ def post_process_fn( def update( self, sample_size: int, buffer: Optional[ReplayBuffer], **kwargs: Any - ) -> Mapping[str, Union[float, List[float]]]: + ) -> Dict[str, Any]: """Update the policy network and replay buffer. - It includes 3 function steps: process_fn, learn, and post_process_fn. - In addition, this function will change the value of ``self.updating``: - it will be False before this function and will be True when executing - :meth:`update`. Please refer to :ref:`policy_state` for more detailed - explanation. + It includes 3 function steps: process_fn, learn, and post_process_fn. In + addition, this function will change the value of ``self.updating``: it will be + False before this function and will be True when executing :meth:`update`. + Please refer to :ref:`policy_state` for more detailed explanation. - :param int sample_size: 0 means it will extract all the data from the - buffer, otherwise it will sample a batch with given sample_size. + :param int sample_size: 0 means it will extract all the data from the buffer, + otherwise it will sample a batch with given sample_size. :param ReplayBuffer buffer: the corresponding replay buffer. + + :return: A dict, including the data needed to be logged (e.g., loss) from + ``policy.learn()``. """ if buffer is None: return {} @@ -173,36 +185,71 @@ def update( self.updating = False return result + @staticmethod + def value_mask(buffer: ReplayBuffer, indice: np.ndarray) -> np.ndarray: + """Value mask determines whether the obs_next of buffer[indice] is valid. + + For instance, usually "obs_next" after "done" flag is considered to be invalid, + and its q/advantage value can provide meaningless (even misleading) + information, and should be set to 0 by hand. But if "done" flag is generated + because timelimit of game length (info["TimeLimit.truncated"] is set to True in + gym's settings), "obs_next" will instead be valid. Value mask is typically used + for assisting in calculating the correct q/advantage value. + + :param ReplayBuffer buffer: the corresponding replay buffer. + :param numpy.ndarray indice: indices of replay buffer whose "obs_next" will be + judged. 
+ + :return: A bool type numpy.ndarray in the same shape with indice. "True" means + "obs_next" of that buffer[indice] is valid. + """ + mask = ~buffer.done[indice].astype(np.bool) + # info['TimeLimit.truncated'] will be set to True if 'done' flag is generated + # because of timelimit of environments. Checkout gym.wrappers.TimeLimit. + if hasattr(buffer, 'info') and 'TimeLimit.truncated' in buffer.info: + mask = mask | buffer.info['TimeLimit.truncated'][indice] + return mask + @staticmethod def compute_episodic_return( batch: Batch, + buffer: ReplayBuffer, + indice: np.ndarray, v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None, gamma: float = 0.99, gae_lambda: float = 0.95, rew_norm: bool = False, ) -> Batch: - """Compute returns over given full-length episodes. - - Implementation of Generalized Advantage Estimator (arXiv:1506.02438). - - :param batch: a data batch which contains several full-episode data - chronologically. - :type batch: :class:`~tianshou.data.Batch` - :param v_s_: the value function of all next states :math:`V(s')`. - :type v_s_: numpy.ndarray - :param float gamma: the discount factor, should be in [0, 1], defaults - to 0.99. - :param float gae_lambda: the parameter for Generalized Advantage - Estimation, should be in [0, 1], defaults to 0.95. - :param bool rew_norm: normalize the reward to Normal(0, 1), defaults - to False. + """Compute returns over given batch. + + Use Implementation of Generalized Advantage Estimator (arXiv:1506.02438) + to calculate q function/reward to go of given batch. + + :param Batch batch: a data batch which contains several episodes of data in + sequential order. Mind that the end of each finished episode of batch + should be marked by done flag, unfinished (or collecting) episodes will be + recongized by buffer.unfinished_index(). + :param numpy.ndarray indice: tell batch's location in buffer, batch is equal to + buffer[indice]. + :param np.ndarray v_s_: the value function of all next states :math:`V(s')`. + :param float gamma: the discount factor, should be in [0, 1]. Default to 0.99. + :param float gae_lambda: the parameter for Generalized Advantage Estimation, + should be in [0, 1]. Default to 0.95. + :param bool rew_norm: normalize the reward to Normal(0, 1). Default to False. :return: a Batch. The result will be stored in batch.returns as a numpy array with shape (bsz, ). """ rew = batch.rew - v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_.flatten()) - returns = _episodic_return(v_s_, rew, batch.done, gamma, gae_lambda) + if v_s_ is None: + assert np.isclose(gae_lambda, 1.0) + v_s_ = np.zeros_like(rew) + else: + v_s_ = to_numpy(v_s_.flatten()) * BasePolicy.value_mask(buffer, indice) + + end_flag = batch.done.copy() + end_flag[np.isin(indice, buffer.unfinished_index())] = True + returns = _episodic_return(v_s_, rew, end_flag, gamma, gae_lambda) if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2): returns = (returns - returns.mean()) / returns.std() batch.returns = returns @@ -224,45 +271,40 @@ def compute_nstep_return( G_t = \sum_{i = t}^{t + n - 1} \gamma^{i - t}(1 - d_i)r_i + \gamma^n (1 - d_{t + n}) Q_{\mathrm{target}}(s_{t + n}) - where :math:`\gamma` is the discount factor, - :math:`\gamma \in [0, 1]`, :math:`d_t` is the done flag of step - :math:`t`. - - :param batch: a data batch, which is equal to buffer[indice]. - :type batch: :class:`~tianshou.data.Batch` - :param buffer: a data buffer which contains several full-episode data - chronologically. 
- :type buffer: :class:`~tianshou.data.ReplayBuffer` - :param indice: sampled timestep. - :type indice: numpy.ndarray - :param function target_q_fn: a function receives :math:`t+n-1` step's - data and compute target Q value. - :param float gamma: the discount factor, should be in [0, 1], defaults - to 0.99. - :param int n_step: the number of estimation step, should be an int - greater than 0, defaults to 1. - :param bool rew_norm: normalize the reward to Normal(0, 1), defaults - to False. + where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`, + :math:`d_t` is the done flag of step :math:`t`. + + :param Batch batch: a data batch, which is equal to buffer[indice]. + :param ReplayBuffer buffer: the data buffer. + :param function target_q_fn: a function which compute target Q value + of "obs_next" given data buffer and wanted indices. + :param float gamma: the discount factor, should be in [0, 1]. Default to 0.99. + :param int n_step: the number of estimation step, should be an int greater + than 0. Default to 1. + :param bool rew_norm: normalize the reward to Normal(0, 1), Default to False. :return: a Batch. The result will be stored in batch.returns as a torch.Tensor with the same shape as target_q_fn's return tensor. """ + assert not rew_norm, \ + "Reward normalization in computing n-step returns is unsupported now." rew = buffer.rew - if rew_norm: - bfr = rew[:min(len(buffer), 1000)] # avoid large buffer - mean, std = bfr.mean(), bfr.std() - if np.isclose(std, 0, 1e-2): - mean, std = 0.0, 1.0 - else: - mean, std = 0.0, 1.0 - buf_len = len(buffer) - terminal = (indice + n_step - 1) % buf_len + bsz = len(indice) + indices = [indice] + for _ in range(n_step - 1): + indices.append(buffer.next(indices[-1])) + indices = np.stack(indices) + # terminal indicates buffer indexes nstep after 'indice', + # and are truncated at the end of each episode + terminal = indices[-1] with torch.no_grad(): target_q_torch = target_q_fn(buffer, terminal) # (bsz, ?) 
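The index chaining above replaces the old modular arithmetic: `buffer.next` advances one step but stays put at episode boundaries, so `terminal` is automatically truncated at the end of each episode. A toy sketch with a stand-in for `buffer.next` (the real method also wraps around the circular buffer):

```python
import numpy as np

def toy_next(idx: np.ndarray, done: np.ndarray) -> np.ndarray:
    """Stand-in for ReplayBuffer.next: advance one step, stay put where done is True."""
    return np.where(done[idx], idx, idx + 1)

done = np.array([False, False, True, False, False, False, True])
indice = np.array([0, 3, 5])
n_step = 3

indices = [indice]
for _ in range(n_step - 1):
    indices.append(toy_next(indices[-1], done))
indices = np.stack(indices)   # [[0 3 5] [1 4 6] [2 5 6]]
terminal = indices[-1]        # [2 5 6]; the chain starting at 5 stops at 6 (episode end)
```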
- target_q = to_numpy(target_q_torch) + target_q = to_numpy(target_q_torch.reshape(bsz, -1)) + target_q = target_q * BasePolicy.value_mask(buffer, terminal).reshape(-1, 1) + end_flag = buffer.done.copy() + end_flag[buffer.unfinished_index()] = True + target_q = _nstep_return(rew, end_flag, target_q, indices, gamma, n_step) - target_q = _nstep_return(rew, buffer.done, target_q, indice, - gamma, n_step, len(buffer), mean, std) batch.returns = to_torch_as(target_q, target_q_torch) if hasattr(batch, "weight"): # prio buffer update batch.weight = to_torch_as(batch.weight, target_q_torch) @@ -272,57 +314,68 @@ def _compile(self) -> None: f64 = np.array([0, 1], dtype=np.float64) f32 = np.array([0, 1], dtype=np.float32) b = np.array([False, True], dtype=np.bool_) - i64 = np.array([0, 1], dtype=np.int64) + i64 = np.array([[0, 1]], dtype=np.int64) + _gae_return(f64, f64, f64, b, 0.1, 0.1) + _gae_return(f32, f32, f64, b, 0.1, 0.1) _episodic_return(f64, f64, b, 0.1, 0.1) _episodic_return(f32, f64, b, 0.1, 0.1) - _nstep_return(f64, b, f32, i64, 0.1, 1, 4, 0.0, 1.0) + _nstep_return(f64, b, f32.reshape(-1, 1), i64, 0.1, 1) @njit -def _episodic_return( +def _gae_return( + v_s: np.ndarray, v_s_: np.ndarray, rew: np.ndarray, - done: np.ndarray, + end_flag: np.ndarray, gamma: float, gae_lambda: float, ) -> np.ndarray: - """Numba speedup: 4.1s -> 0.057s.""" - returns = np.roll(v_s_, 1) - m = (1.0 - done) * gamma - delta = rew + v_s_ * m - returns - m *= gae_lambda + returns = np.zeros(rew.shape) + delta = rew + v_s_ * gamma - v_s + m = (1.0 - end_flag) * (gamma * gae_lambda) gae = 0.0 for i in range(len(rew) - 1, -1, -1): gae = delta[i] + m[i] * gae - returns[i] += gae + returns[i] = gae return returns +@njit +def _episodic_return( + v_s_: np.ndarray, + rew: np.ndarray, + end_flag: np.ndarray, + gamma: float, + gae_lambda: float, +) -> np.ndarray: + """Numba speedup: 4.1s -> 0.057s.""" + v_s = np.roll(v_s_, 1) + return _gae_return(v_s, v_s_, rew, end_flag, gamma, gae_lambda) + v_s + + @njit def _nstep_return( rew: np.ndarray, - done: np.ndarray, + end_flag: np.ndarray, target_q: np.ndarray, - indice: np.ndarray, + indices: np.ndarray, gamma: float, n_step: int, - buf_len: int, - mean: float, - std: float, ) -> np.ndarray: - """Numba speedup: 0.3s -> 0.15s.""" + gamma_buffer = np.ones(n_step + 1) + for i in range(1, n_step + 1): + gamma_buffer[i] = gamma_buffer[i - 1] * gamma target_shape = target_q.shape bsz = target_shape[0] # change target_q to 2d array target_q = target_q.reshape(bsz, -1) returns = np.zeros(target_q.shape) - gammas = np.full(indice.shape, n_step) + gammas = np.full(indices[0].shape, n_step) for n in range(n_step - 1, -1, -1): - now = (indice + n) % buf_len - gammas[done[now] > 0] = n - returns[done[now] > 0] = 0.0 - returns = (rew[now].reshape(-1, 1) - mean) / std + gamma * returns - target_q[gammas != n_step] = 0.0 - gammas = gammas.reshape(-1, 1) - target_q = target_q * (gamma ** gammas) + returns + now = indices[n] + gammas[end_flag[now] > 0] = n + 1 + returns[end_flag[now] > 0] = 0.0 + returns = rew[now].reshape(bsz, 1) + gamma * returns + target_q = target_q * gamma_buffer[gammas].reshape(bsz, 1) + returns return target_q.reshape(target_shape) diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index 954bc81f6..a618dd480 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -14,7 +14,7 @@ class ImitationPolicy(BasePolicy): :class:`~tianshou.policy.BasePolicy`. 
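A plain-numpy restatement of the `_gae_return` kernel above (without numba), with made-up values; `_episodic_return` is then just this advantage plus `V(s)`:

```python
import numpy as np

def gae_return(v_s, v_s_, rew, end_flag, gamma=0.99, gae_lambda=0.95):
    returns = np.zeros_like(rew)
    delta = rew + v_s_ * gamma - v_s               # one-step TD error
    m = (1.0 - end_flag) * (gamma * gae_lambda)
    gae = 0.0
    for i in range(len(rew) - 1, -1, -1):          # accumulate advantages backwards
        gae = delta[i] + m[i] * gae
        returns[i] = gae
    return returns

rew = np.array([1.0, 1.0, 1.0, 1.0])
v_s = np.array([0.5, 0.5, 0.5, 0.5])               # V(s_t)
v_s_ = np.array([0.5, 0.5, 0.5, 0.0])              # V(s_{t+1}), zeroed by the value mask
end_flag = np.array([0.0, 0.0, 0.0, 1.0])
adv = gae_return(v_s, v_s_, rew, end_flag)
lambda_return = adv + v_s                          # what ends up in batch.returns
```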
(s -> a) :param torch.optim.Optimizer optim: for optimizing the model. :param str mode: indicate the imitation type ("continuous" or "discrete" - action space), defaults to "continuous". + action space). Default to "continuous". .. seealso:: @@ -32,9 +32,8 @@ def __init__( super().__init__(**kwargs) self.model = model self.optim = optim - assert ( - mode in ["continuous", "discrete"] - ), f"Mode {mode} is not in ['continuous', 'discrete']." + assert mode in ["continuous", "discrete"], \ + f"Mode {mode} is not in ['continuous', 'discrete']." self.mode = mode def forward( diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 688a9901d..5d7082243 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -17,16 +17,15 @@ class DiscreteBCQPolicy(DQNPolicy): :class:`~tianshou.policy.BasePolicy`. (s -> imtation_logits) :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. :param float discount_factor: in [0, 1]. - :param int estimation_step: greater than 1, the number of steps to look - ahead. + :param int estimation_step: the number of steps to look ahead. Default to 1. :param int target_update_freq: the target network update frequency. :param float eval_eps: the epsilon-greedy noise added in evaluation. :param float unlikely_action_threshold: the threshold (tau) for unlikely - actions, as shown in Equ. (17) in the paper, defaults to 0.3. + actions, as shown in Equ. (17) in the paper. Default to 0.3. :param float imitation_logits_penalty: reguralization weight for imitation - logits, defaults to 1e-2. - :param bool reward_normalization: normalize the reward to Normal(0, 1), - defaults to False. + logits. Default to 1e-2. + :param bool reward_normalization: normalize the reward to Normal(0, 1). + Default to False. .. seealso:: @@ -52,9 +51,8 @@ def __init__( target_update_freq, reward_normalization, **kwargs) assert target_update_freq > 0, "BCQ needs target network setting." 
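The `unlikely_action_threshold` (tau) described in the docstring above masks out actions whose imitation logit falls too far below the best one before the Q-argmax is taken. A toy illustration of that rule with invented numbers, using `masked_fill` as an equivalent of the `q_value - np.inf * mask` expression in the forward pass further down:

```python
import math
import torch

q_value = torch.tensor([[1.0, 5.0, 3.0, 0.5]])
imitation_logits = torch.tensor([[2.0, -3.0, 1.5, 0.0]])
log_tau = math.log(0.3)                    # unlikely_action_threshold = 0.3

ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values
mask = ratio < log_tau                     # [[False, True, False, True]]
action = q_value.masked_fill(mask, -float("inf")).argmax(dim=-1)
print(action)  # tensor([2]): action 1 has the best Q but the imitator rules it out
```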
self.imitator = imitator - assert ( - 0.0 <= unlikely_action_threshold < 1.0 - ), "unlikely_action_threshold should be in [0, 1)" + assert 0.0 <= unlikely_action_threshold < 1.0, \ + "unlikely_action_threshold should be in [0, 1)" if unlikely_action_threshold > 0: self._log_tau = math.log(unlikely_action_threshold) else: @@ -69,12 +67,10 @@ def train(self, mode: bool = True) -> "DiscreteBCQPolicy": self.imitator.train(mode) return self - def _target_q( - self, buffer: ReplayBuffer, indice: np.ndarray - ) -> torch.Tensor: + def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor: batch = buffer[indice] # batch.obs_next: s_{t+n} # target_Q = Q_old(s_, argmax(Q_new(s_, *))) - act = self(batch, input="obs_next", eps=0.0).act + act = self(batch, input="obs_next").act target_q, _ = self.model_old(batch.obs_next) target_q = target_q[np.arange(len(act)), act] return target_q @@ -84,39 +80,38 @@ def forward( # type: ignore batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, input: str = "obs", - eps: Optional[float] = None, **kwargs: Any, ) -> Batch: - if eps is None: - eps = self._eps obs = batch[input] q_value, state = self.model(obs, state=state, info=batch.info) + if not hasattr(self, "max_action_num"): + self.max_action_num = q_value.shape[1] imitation_logits, _ = self.imitator(obs, state=state, info=batch.info) # mask actions for argmax - ratio = imitation_logits - imitation_logits.max( - dim=-1, keepdim=True).values + ratio = imitation_logits - imitation_logits.max(dim=-1, keepdim=True).values mask = (ratio < self._log_tau).float() action = (q_value - np.inf * mask).argmax(dim=-1) - # add eps to act - if not np.isclose(eps, 0.0): - bsz, action_num = q_value.shape - mask = np.random.rand(bsz) < eps - action_rand = torch.randint( - action_num, size=[bsz], device=action.device) - action[mask] = action_rand[mask] - return Batch(act=action, state=state, q_value=q_value, imitation_logits=imitation_logits) + def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray: + # add eps to act + if not np.isclose(self._eps, 0.0): + bsz = len(act) + mask = np.random.rand(bsz) < self._eps + act_rand = np.random.randint(self.max_action_num, size=[bsz]) + act[mask] = act_rand[mask] + return act + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: if self._iter % self._freq == 0: self.sync_weight() self._iter += 1 target_q = batch.returns.flatten() - result = self(batch, eps=0.0) + result = self(batch) imitation_logits = result.imitation_logits current_q = result.q_value[np.arange(len(target_q)), batch.act] act = to_torch(batch.act, dtype=torch.long, device=target_q.device) diff --git a/tianshou/policy/modelbase/psrl.py b/tianshou/policy/modelbase/psrl.py index dcf6a5d05..4a565976f 100644 --- a/tianshou/policy/modelbase/psrl.py +++ b/tianshou/policy/modelbase/psrl.py @@ -149,7 +149,7 @@ class PSRLPolicy(BasePolicy): :param float discount_factor: in [0, 1]. :param float epsilon: for precision control in value iteration. :param bool add_done_loop: whether to add an extra self-loop for the - terminal state in MDP, defaults to False. + terminal state in MDP. Default to False. .. 
seealso:: diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index e215c6bbd..f79682789 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -2,7 +2,7 @@ import numpy as np from torch import nn import torch.nn.functional as F -from typing import Any, Dict, List, Union, Optional, Callable +from typing import Any, Dict, List, Type, Union, Optional from tianshou.policy import PGPolicy from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy @@ -17,20 +17,20 @@ class A2CPolicy(PGPolicy): :param torch.optim.Optimizer optim: the optimizer for actor and critic network. :param dist_fn: distribution class for computing the action. - :type dist_fn: Callable[[], torch.distributions.Distribution] - :param float discount_factor: in [0, 1], defaults to 0.99. - :param float vf_coef: weight for value loss, defaults to 0.5. - :param float ent_coef: weight for entropy loss, defaults to 0.01. - :param float max_grad_norm: clipping gradients in back propagation, - defaults to None. + :type dist_fn: Type[torch.distributions.Distribution] + :param float discount_factor: in [0, 1]. Default to 0.99. + :param float vf_coef: weight for value loss. Default to 0.5. + :param float ent_coef: weight for entropy loss. Default to 0.01. + :param float max_grad_norm: clipping gradients in back propagation. + Default to None. :param float gae_lambda: in [0, 1], param for Generalized Advantage - Estimation, defaults to 0.95. - :param bool reward_normalization: normalize the reward to Normal(0, 1), - defaults to False. + Estimation. Default to 0.95. + :param bool reward_normalization: normalize the reward to Normal(0, 1). + Default to False. :param int max_batchsize: the maximum size of the batch when computing GAE, depends on the size of available memory and the memory cost of the - model; should be as large as possible within the memory constraint; - defaults to 256. + model; should be as large as possible within the memory constraint. + Default to 256. .. 
seealso:: @@ -43,7 +43,7 @@ def __init__( actor: torch.nn.Module, critic: torch.nn.Module, optim: torch.optim.Optimizer, - dist_fn: Callable[[], torch.distributions.Distribution], + dist_fn: Type[torch.distributions.Distribution], discount_factor: float = 0.99, vf_coef: float = 0.5, ent_coef: float = 0.01, @@ -69,15 +69,16 @@ def process_fn( ) -> Batch: if self._lambda in [0.0, 1.0]: return self.compute_episodic_return( - batch, None, gamma=self._gamma, gae_lambda=self._lambda) + batch, buffer, indice, + None, gamma=self._gamma, gae_lambda=self._lambda) v_ = [] with torch.no_grad(): for b in batch.split(self._batch, shuffle=False, merge_last=True): v_.append(to_numpy(self.critic(b.obs_next))) v_ = np.concatenate(v_, axis=0) return self.compute_episodic_return( - batch, v_, gamma=self._gamma, gae_lambda=self._lambda, - rew_norm=self._rew_norm) + batch, buffer, indice, v_, + gamma=self._gamma, gae_lambda=self._lambda, rew_norm=self._rew_norm) def forward( self, @@ -103,7 +104,7 @@ def forward( if isinstance(logits, tuple): dist = self.dist_fn(*logits) else: - dist = self.dist_fn(logits) # type: ignore + dist = self.dist_fn(logits) act = dist.sample() return Batch(logits=logits, act=act, state=h, dist=dist) @@ -122,13 +123,11 @@ def learn( # type: ignore a_loss = -(log_prob * (r - v).detach()).mean() vf_loss = F.mse_loss(r, v) # type: ignore ent_loss = dist.entropy().mean() - loss = a_loss + self._weight_vf * vf_loss - \ - self._weight_ent * ent_loss + loss = a_loss + self._weight_vf * vf_loss - self._weight_ent * ent_loss loss.backward() if self._grad_norm is not None: nn.utils.clip_grad_norm_( - list(self.actor.parameters()) - + list(self.critic.parameters()), + list(self.actor.parameters()) + list(self.critic.parameters()), max_norm=self._grad_norm, ) self.optim.step() diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index b0e94c616..20ef89c1a 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -14,17 +14,16 @@ class C51Policy(DQNPolicy): :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. :param float discount_factor: in [0, 1]. :param int num_atoms: the number of atoms in the support set of the - value distribution, defaults to 51. - :param float v_min: the value of the smallest atom in the support set, - defaults to -10.0. - :param float v_max: the value of the largest atom in the support set, - defaults to 10.0. - :param int estimation_step: greater than 1, the number of steps to look - ahead. + value distribution. Default to 51. + :param float v_min: the value of the smallest atom in the support set. + Default to -10.0. + :param float v_max: the value of the largest atom in the support set. + Default to 10.0. + :param int estimation_step: the number of steps to look ahead. Default to 1. :param int target_update_freq: the target network update frequency (0 if - you do not use the target network). - :param bool reward_normalization: normalize the reward to Normal(0, 1), - defaults to False. + you do not use the target network). Default to 0. + :param bool reward_normalization: normalize the reward to Normal(0, 1). + Default to False. .. 
seealso:: @@ -70,9 +69,7 @@ def compute_q_value(self, logits: torch.Tensor) -> torch.Tensor: def _target_dist(self, batch: Batch) -> torch.Tensor: if self._target: a = self(batch, input="obs_next").act - next_dist = self( - batch, model="model_old", input="obs_next" - ).logits + next_dist = self(batch, model="model_old", input="obs_next").logits else: next_b = self(batch, input="obs_next") a = next_b.act diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 1ac293c88..9a4dad062 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -5,7 +5,7 @@ from tianshou.policy import BasePolicy from tianshou.exploration import BaseNoise, GaussianNoise -from tianshou.data import Batch, ReplayBuffer, to_torch_as +from tianshou.data import Batch, ReplayBuffer class DDPGPolicy(BasePolicy): @@ -15,21 +15,16 @@ class DDPGPolicy(BasePolicy): :class:`~tianshou.policy.BasePolicy`. (s -> logits) :param torch.optim.Optimizer actor_optim: the optimizer for actor network. :param torch.nn.Module critic: the critic network. (s, a -> Q(s, a)) - :param torch.optim.Optimizer critic_optim: the optimizer for critic - network. + :param torch.optim.Optimizer critic_optim: the optimizer for critic network. :param action_range: the action range (minimum, maximum). :type action_range: Tuple[float, float] - :param float tau: param for soft update of the target network, defaults to - 0.005. - :param float gamma: discount factor, in [0, 1], defaults to 0.99. + :param float tau: param for soft update of the target network. Default to 0.005. + :param float gamma: discount factor, in [0, 1]. Default to 0.99. :param BaseNoise exploration_noise: the exploration noise, - add to the action, defaults to ``GaussianNoise(sigma=0.1)``. + add to the action. Default to ``GaussianNoise(sigma=0.1)``. :param bool reward_normalization: normalize the reward to Normal(0, 1), - defaults to False. - :param bool ignore_done: ignore the done flag while training the policy, - defaults to False. - :param int estimation_step: greater than 1, the number of steps to look - ahead. + Default to False. + :param int estimation_step: the number of steps to look ahead. Default to 1. .. 
seealso:: @@ -48,7 +43,6 @@ def __init__( gamma: float = 0.99, exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1), reward_normalization: bool = False, - ignore_done: bool = False, estimation_step: int = 1, **kwargs: Any, ) -> None: @@ -73,9 +67,7 @@ def __init__( self._action_scale = (action_range[1] - action_range[0]) / 2.0 # it is only a little difference to use GaussianNoise # self.noise = OUNoise() - self._rm_done = ignore_done self._rew_norm = reward_normalization - assert estimation_step > 0, "estimation_step should be greater than 0" self._n_step = estimation_step def set_exp_noise(self, noise: Optional[BaseNoise]) -> None: @@ -93,9 +85,7 @@ def sync_weight(self) -> None: """Soft-update the weight for the target network.""" for o, n in zip(self.actor_old.parameters(), self.actor.parameters()): o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) - for o, n in zip( - self.critic_old.parameters(), self.critic.parameters() - ): + for o, n in zip(self.critic_old.parameters(), self.critic.parameters()): o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) def _target_q( @@ -110,8 +100,6 @@ def _target_q( def process_fn( self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray ) -> Batch: - if self._rm_done: - batch.done = batch.done * 0.0 batch = self.compute_nstep_return( batch, buffer, indice, self._target_q, self._gamma, self._n_step, self._rew_norm) @@ -141,9 +129,6 @@ def forward( obs = batch[input] actions, h = model(obs, state=state, info=batch.info) actions += self._action_bias - if self._noise and not self.updating: - actions += to_torch_as(self._noise(actions.shape), actions) - actions = actions.clamp(self._range[0], self._range[1]) return Batch(act=actions, state=h) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: @@ -166,3 +151,9 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: "loss/actor": actor_loss.item(), "loss/critic": critic_loss.item(), } + + def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray: + if self._noise: + act = act + self._noise(act.shape) + act = act.clip(self._range[0], self._range[1]) + return act diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 4db2dc9dc..a53bbbbf8 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -19,17 +19,14 @@ class DiscreteSACPolicy(SACPolicy): :param torch.nn.Module critic2: the second critic network. (s -> Q(s)) :param torch.optim.Optimizer critic2_optim: the optimizer for the second critic network. - :param float tau: param for soft update of the target network, defaults to - 0.005. - :param float gamma: discount factor, in [0, 1], defaults to 0.99. + :param float tau: param for soft update of the target network. Default to 0.005. + :param float gamma: discount factor, in [0, 1]. Default to 0.99. :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy - regularization coefficient, default to 0.2. - If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then + regularization coefficient. Default to 0.2. + If a tuple (target_entropy, log_alpha, alpha_optim) is provided, the alpha is automatatically tuned. - :param bool reward_normalization: normalize the reward to Normal(0, 1), - defaults to ``False``. - :param bool ignore_done: ignore the done flag while training the policy, - defaults to ``False``. + :param bool reward_normalization: normalize the reward to Normal(0, 1). + Default to False. .. 
seealso:: @@ -47,17 +44,14 @@ def __init__( critic2_optim: torch.optim.Optimizer, tau: float = 0.005, gamma: float = 0.99, - alpha: Union[ - float, Tuple[float, torch.Tensor, torch.optim.Optimizer] - ] = 0.2, + alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2, reward_normalization: bool = False, - ignore_done: bool = False, estimation_step: int = 1, **kwargs: Any, ) -> None: super().__init__(actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, (-np.inf, np.inf), tau, gamma, alpha, - reward_normalization, ignore_done, estimation_step, + reward_normalization, estimation_step, **kwargs) self._alpha: Union[float, torch.Tensor] @@ -119,8 +113,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: current_q1a = self.critic1(batch.obs) current_q2a = self.critic2(batch.obs) q = torch.min(current_q1a, current_q2a) - actor_loss = -(self._alpha * entropy - + (dist.probs * q).sum(dim=-1)).mean() + actor_loss = -(self._alpha * entropy + (dist.probs * q).sum(dim=-1)).mean() self.actor_optim.zero_grad() actor_loss.backward() self.actor_optim.step() @@ -145,3 +138,8 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: result["alpha"] = self._alpha.item() # type: ignore return result + + def exploration_noise( + self, act: Union[np.ndarray, Batch], batch: Batch + ) -> Union[np.ndarray, Batch]: + return act diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 54397915c..bd1fea14a 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -19,12 +19,11 @@ class DQNPolicy(BasePolicy): :class:`~tianshou.policy.BasePolicy`. (s -> logits) :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. :param float discount_factor: in [0, 1]. - :param int estimation_step: greater than 1, the number of steps to look - ahead. + :param int estimation_step: the number of steps to look ahead. Default to 1. :param int target_update_freq: the target network update frequency (0 if - you do not use the target network). - :param bool reward_normalization: normalize the reward to Normal(0, 1), - defaults to False. + you do not use the target network). Default to 0. + :param bool reward_normalization: normalize the reward to Normal(0, 1). + Default to False. .. 
seealso:: @@ -46,9 +45,7 @@ def __init__( self.model = model self.optim = optim self.eps = 0.0 - assert ( - 0.0 <= discount_factor <= 1.0 - ), "discount factor should be in [0, 1]" + assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]" self._gamma = discount_factor assert estimation_step > 0, "estimation_step should be greater than 0" self._n_step = estimation_step @@ -74,16 +71,12 @@ def sync_weight(self) -> None: """Synchronize the weight for the target network.""" self.model_old.load_state_dict(self.model.state_dict()) - def _target_q( - self, buffer: ReplayBuffer, indice: np.ndarray - ) -> torch.Tensor: + def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor: batch = buffer[indice] # batch.obs_next: s_{t+n} # target_Q = Q_old(s_, argmax(Q_new(s_, *))) if self._target: a = self(batch, input="obs_next").act - target_q = self( - batch, model="model_old", input="obs_next" - ).logits + target_q = self(batch, model="model_old", input="obs_next").logits target_q = target_q[np.arange(len(a)), a] else: target_q = self(batch, input="obs_next").logits.max(dim=1)[0] @@ -148,20 +141,14 @@ def forward( obs_ = obs.obs if hasattr(obs, "obs") else obs logits, h = model(obs_, state=state, info=batch.info) q = self.compute_q_value(logits) + if not hasattr(self, "max_action_num"): + self.max_action_num = q.shape[1] act: np.ndarray = to_numpy(q.max(dim=1)[1]) if hasattr(obs, "mask"): # some of actions are masked, they cannot be selected q_: np.ndarray = to_numpy(q) q_[~obs.mask] = -np.inf act = q_.argmax(axis=1) - # add eps to act in training or testing phase - if not self.updating and not np.isclose(self.eps, 0.0): - for i in range(len(q)): - if np.random.rand() < self.eps: - q_ = np.random.rand(*q[i].shape) - if hasattr(obs, "mask"): - q_[~obs.mask[i]] = -np.inf - act[i] = q_.argmax() return Batch(logits=logits, act=act, state=h) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: @@ -179,3 +166,13 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: self.optim.step() self._iter += 1 return {"loss": loss.item()} + + def exploration_noise(self, act: np.ndarray, batch: Batch) -> np.ndarray: + if not np.isclose(self.eps, 0.0): + for i in range(len(act)): + if np.random.rand() < self.eps: + q_ = np.random.rand(self.max_action_num) + if hasattr(batch["obs"], "mask"): + q_[~batch["obs"].mask[i]] = -np.inf + act[i] = q_.argmax() + return act diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 2f3304658..080ba70a2 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -1,6 +1,6 @@ import torch import numpy as np -from typing import Any, Dict, List, Union, Optional, Callable +from typing import Any, Dict, List, Type, Union, Optional from tianshou.policy import BasePolicy from tianshou.data import Batch, ReplayBuffer, to_torch_as @@ -13,8 +13,8 @@ class PGPolicy(BasePolicy): :class:`~tianshou.policy.BasePolicy`. (s -> logits) :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. :param dist_fn: distribution class for computing the action. - :type dist_fn: Callable[[], torch.distributions.Distribution] - :param float discount_factor: in [0, 1]. + :type dist_fn: Type[torch.distributions.Distribution] + :param float discount_factor: in [0, 1]. Default to 0.99. .. 
seealso:: @@ -26,7 +26,7 @@ def __init__( self, model: Optional[torch.nn.Module], optim: torch.optim.Optimizer, - dist_fn: Callable[[], torch.distributions.Distribution], + dist_fn: Type[torch.distributions.Distribution], discount_factor: float = 0.99, reward_normalization: bool = False, **kwargs: Any, @@ -36,16 +36,14 @@ def __init__( self.model: torch.nn.Module = model self.optim = optim self.dist_fn = dist_fn - assert ( - 0.0 <= discount_factor <= 1.0 - ), "discount factor should be in [0, 1]" + assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]" self._gamma = discount_factor self._rew_norm = reward_normalization def process_fn( self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray ) -> Batch: - r"""Compute the discounted returns for each frame. + r"""Compute the discounted returns for each transition. .. math:: G_t = \sum_{i=t}^T \gamma^{i-t}r_i @@ -56,7 +54,8 @@ def process_fn( # batch.returns = self._vanilla_returns(batch) # batch.returns = self._vectorized_returns(batch) return self.compute_episodic_return( - batch, gamma=self._gamma, gae_lambda=1.0, rew_norm=self._rew_norm) + batch, buffer, indice, gamma=self._gamma, + gae_lambda=1.0, rew_norm=self._rew_norm) def forward( self, @@ -82,7 +81,7 @@ def forward( if isinstance(logits, tuple): dist = self.dist_fn(*logits) else: - dist = self.dist_fn(logits) # type: ignore + dist = self.dist_fn(logits) act = dist.sample() return Batch(logits=logits, act=act, state=h, dist=dist) diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 4cf9f9054..953829195 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -1,7 +1,7 @@ import torch import numpy as np from torch import nn -from typing import Any, Dict, List, Tuple, Union, Optional, Callable +from typing import Any, Dict, List, Type, Tuple, Union, Optional from tianshou.policy import PGPolicy from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as @@ -13,32 +13,31 @@ class PPOPolicy(PGPolicy): :param torch.nn.Module actor: the actor network following the rules in :class:`~tianshou.policy.BasePolicy`. (s -> logits) :param torch.nn.Module critic: the critic network. (s -> V(s)) - :param torch.optim.Optimizer optim: the optimizer for actor and critic - network. + :param torch.optim.Optimizer optim: the optimizer for actor and critic network. :param dist_fn: distribution class for computing the action. - :type dist_fn: Callable[[], torch.distributions.Distribution] - :param float discount_factor: in [0, 1], defaults to 0.99. - :param float max_grad_norm: clipping gradients in back propagation, - defaults to None. + :type dist_fn: Type[torch.distributions.Distribution] + :param float discount_factor: in [0, 1]. Default to 0.99. + :param float max_grad_norm: clipping gradients in back propagation. + Default to None. :param float eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` in the original - paper, defaults to 0.2. - :param float vf_coef: weight for value loss, defaults to 0.5. - :param float ent_coef: weight for entropy loss, defaults to 0.01. + paper. Default to 0.2. + :param float vf_coef: weight for value loss. Default to 0.5. + :param float ent_coef: weight for entropy loss. Default to 0.01. :param action_range: the action range (minimum, maximum). :type action_range: (float, float) :param float gae_lambda: in [0, 1], param for Generalized Advantage - Estimation, defaults to 0.95. + Estimation. Default to 0.95. :param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 
5, - where c > 1 is a constant indicating the lower bound, - defaults to 5.0 (set ``None`` if you do not want to use it). - :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1, - defaults to True. - :param bool reward_normalization: normalize the returns to Normal(0, 1), - defaults to True. + where c > 1 is a constant indicating the lower bound. + Default to 5.0 (set None if you do not want to use it). + :param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1. + Default to True. + :param bool reward_normalization: normalize the returns to Normal(0, 1). + Default to True. :param int max_batchsize: the maximum size of the batch when computing GAE, depends on the size of available memory and the memory cost of the - model; should be as large as possible within the memory constraint; - defaults to 256. + model; should be as large as possible within the memory constraint. + Default to 256. .. seealso:: @@ -51,7 +50,7 @@ def __init__( actor: torch.nn.Module, critic: torch.nn.Module, optim: torch.optim.Optimizer, - dist_fn: Callable[[], torch.distributions.Distribution], + dist_fn: Type[torch.distributions.Distribution], discount_factor: float = 0.99, max_grad_norm: Optional[float] = None, eps_clip: float = 0.2, @@ -76,9 +75,8 @@ def __init__( self._batch = max_batchsize assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]." self._lambda = gae_lambda - assert ( - dual_clip is None or dual_clip > 1.0 - ), "Dual-clip PPO parameter should greater than 1.0." + assert dual_clip is None or dual_clip > 1.0, \ + "Dual-clip PPO parameter should greater than 1.0." self._dual_clip = dual_clip self._value_clip = value_clip self._rew_norm = reward_normalization @@ -95,13 +93,11 @@ def process_fn( for b in batch.split(self._batch, shuffle=False, merge_last=True): v_.append(self.critic(b.obs_next)) v.append(self.critic(b.obs)) - old_log_prob.append( - self(b).dist.log_prob(to_torch_as(b.act, v[0])) - ) + old_log_prob.append(self(b).dist.log_prob(to_torch_as(b.act, v[0]))) v_ = to_numpy(torch.cat(v_, dim=0)) batch = self.compute_episodic_return( - batch, v_, gamma=self._gamma, gae_lambda=self._lambda, - rew_norm=self._rew_norm) + batch, buffer, indice, v_, gamma=self._gamma, + gae_lambda=self._lambda, rew_norm=self._rew_norm) batch.v = torch.cat(v, dim=0).flatten() # old value batch.act = to_torch_as(batch.act, v[0]) batch.logp_old = torch.cat(old_log_prob, dim=0) @@ -137,7 +133,7 @@ def forward( if isinstance(logits, tuple): dist = self.dist_fn(*logits) else: - dist = self.dist_fn(logits) # type: ignore + dist = self.dist_fn(logits) act = dist.sample() if self._range: act = act.clamp(self._range[0], self._range[1]) @@ -154,8 +150,7 @@ def learn( # type: ignore ratio = (dist.log_prob(b.act) - b.logp_old).exp().float() ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1) surr1 = ratio * b.adv - surr2 = ratio.clamp(1.0 - self._eps_clip, - 1.0 + self._eps_clip) * b.adv + surr2 = ratio.clamp(1.0 - self._eps_clip, 1.0 + self._eps_clip) * b.adv if self._dual_clip: clip_loss = -torch.max( torch.min(surr1, surr2), self._dual_clip * b.adv @@ -164,8 +159,7 @@ def learn( # type: ignore clip_loss = -torch.min(surr1, surr2).mean() clip_losses.append(clip_loss.item()) if self._value_clip: - v_clip = b.v + (value - b.v).clamp( - -self._eps_clip, self._eps_clip) + v_clip = b.v + (value - b.v).clamp(-self._eps_clip, self._eps_clip) vf1 = (b.returns - value).pow(2) vf2 = (b.returns - v_clip).pow(2) vf_loss = 0.5 * torch.max(vf1, vf2).mean() @@ -174,15 +168,14 @@ def learn( 
# type: ignore vf_losses.append(vf_loss.item()) e_loss = dist.entropy().mean() ent_losses.append(e_loss.item()) - loss = clip_loss + self._weight_vf * vf_loss - \ - self._weight_ent * e_loss + loss = clip_loss + self._weight_vf * vf_loss \ + - self._weight_ent * e_loss losses.append(loss.item()) self.optim.zero_grad() loss.backward() if self._max_grad_norm: nn.utils.clip_grad_norm_( - list(self.actor.parameters()) - + list(self.critic.parameters()), + list(self.actor.parameters()) + list(self.critic.parameters()), self._max_grad_norm) self.optim.step() return { diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index 754d9acce..7e154e7f7 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -16,13 +16,12 @@ class QRDQNPolicy(DQNPolicy): :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. :param float discount_factor: in [0, 1]. :param int num_quantiles: the number of quantile midpoints in the inverse - cumulative distribution function of the value, defaults to 200. - :param int estimation_step: greater than 1, the number of steps to look - ahead. + cumulative distribution function of the value. Default to 200. + :param int estimation_step: the number of steps to look ahead. Default to 1. :param int target_update_freq: the target network update frequency (0 if you do not use the target network). - :param bool reward_normalization: normalize the reward to Normal(0, 1), - defaults to False. + :param bool reward_normalization: normalize the reward to Normal(0, 1). + Default to False. .. seealso:: @@ -50,15 +49,11 @@ def __init__( ((tau[:-1] + tau[1:]) / 2).view(1, -1, 1), requires_grad=False) warnings.filterwarnings("ignore", message="Using a target size") - def _target_q( - self, buffer: ReplayBuffer, indice: np.ndarray - ) -> torch.Tensor: + def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor: batch = buffer[indice] # batch.obs_next: s_{t+n} if self._target: a = self(batch, input="obs_next").act - next_dist = self( - batch, model="model_old", input="obs_next" - ).logits + next_dist = self(batch, model="model_old", input="obs_next").logits else: next_b = self(batch, input="obs_next") a = next_b.act diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index fbdd12297..68bef3971 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -6,7 +6,7 @@ from tianshou.policy import DDPGPolicy from tianshou.exploration import BaseNoise -from tianshou.data import Batch, ReplayBuffer, to_torch_as +from tianshou.data import Batch, ReplayBuffer class SACPolicy(DDPGPolicy): @@ -15,32 +15,27 @@ class SACPolicy(DDPGPolicy): :param torch.nn.Module actor: the actor network following the rules in :class:`~tianshou.policy.BasePolicy`. (s -> logits) :param torch.optim.Optimizer actor_optim: the optimizer for actor network. - :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, - a)) + :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a)) :param torch.optim.Optimizer critic1_optim: the optimizer for the first critic network. - :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, - a)) + :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a)) :param torch.optim.Optimizer critic2_optim: the optimizer for the second critic network. :param action_range: the action range (minimum, maximum). 
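The PPO losses computed in `learn` above boil down to the clipped surrogate plus an (optionally clipped) value loss. A minimal torch sketch with invented ratios, advantages and values:

```python
import torch

eps_clip = 0.2
ratio = torch.tensor([0.7, 1.0, 1.4])      # pi_new(a|s) / pi_old(a|s)
adv = torch.tensor([1.0, -2.0, 0.5])       # advantages from compute_episodic_return

surr1 = ratio * adv
surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * adv
clip_loss = -torch.min(surr1, surr2).mean()

# value clipping (arXiv:1811.02553 Sec. 4.1): penalize whichever of the raw or the
# clipped value prediction is further from the return
v_old = torch.tensor([0.3, 0.1, 0.4])      # batch.v, the critic output before the update
value = torch.tensor([0.9, 0.0, 0.5])
returns = torch.tensor([1.0, -1.0, 0.6])
v_clip = v_old + (value - v_old).clamp(-eps_clip, eps_clip)
vf_loss = 0.5 * torch.max((returns - value) ** 2, (returns - v_clip) ** 2).mean()
```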
:type action_range: Tuple[float, float] - :param float tau: param for soft update of the target network, defaults to - 0.005. - :param float gamma: discount factor, in [0, 1], defaults to 0.99. + :param float tau: param for soft update of the target network. Default to 0.005. + :param float gamma: discount factor, in [0, 1]. Default to 0.99. :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy - regularization coefficient, default to 0.2. + regularization coefficient. Default to 0.2. If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then alpha is automatatically tuned. - :param bool reward_normalization: normalize the reward to Normal(0, 1), - defaults to False. - :param bool ignore_done: ignore the done flag while training the policy, - defaults to False. - :param BaseNoise exploration_noise: add a noise to action for exploration, - defaults to None. This is useful when solving hard-exploration problem. + :param bool reward_normalization: normalize the reward to Normal(0, 1). + Default to False. + :param BaseNoise exploration_noise: add a noise to action for exploration. + Default to None. This is useful when solving hard-exploration problem. :param bool deterministic_eval: whether to use deterministic action (mean - of Gaussian policy) instead of stochastic action sampled by the policy, - defaults to True. + of Gaussian policy) instead of stochastic action sampled by the policy. + Default to True. .. seealso:: @@ -59,18 +54,15 @@ def __init__( action_range: Tuple[float, float], tau: float = 0.005, gamma: float = 0.99, - alpha: Union[ - float, Tuple[float, torch.Tensor, torch.optim.Optimizer] - ] = 0.2, + alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2, reward_normalization: bool = False, - ignore_done: bool = False, estimation_step: int = 1, exploration_noise: Optional[BaseNoise] = None, deterministic_eval: bool = True, **kwargs: Any, ) -> None: super().__init__(None, None, None, None, action_range, tau, gamma, - exploration_noise, reward_normalization, ignore_done, + exploration_noise, reward_normalization, estimation_step, **kwargs) self.actor, self.actor_optim = actor, actor_optim self.critic1, self.critic1_old = critic1, deepcopy(critic1) @@ -101,13 +93,9 @@ def train(self, mode: bool = True) -> "SACPolicy": return self def sync_weight(self) -> None: - for o, n in zip( - self.critic1_old.parameters(), self.critic1.parameters() - ): + for o, n in zip(self.critic1_old.parameters(), self.critic1.parameters()): o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) - for o, n in zip( - self.critic2_old.parameters(), self.critic2.parameters() - ): + for o, n in zip(self.critic2_old.parameters(), self.critic2.parameters()): o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) def forward( # type: ignore @@ -130,11 +118,8 @@ def forward( # type: ignore y = self._action_scale * (1 - y.pow(2)) + self.__eps log_prob = dist.log_prob(x).unsqueeze(-1) log_prob = log_prob - torch.log(y).sum(-1, keepdim=True) - if self._noise is not None and self.training and not self.updating: - act += to_torch_as(self._noise(act.shape), act) - act = act.clamp(self._range[0], self._range[1]) - return Batch( - logits=logits, act=act, state=h, dist=dist, log_prob=log_prob) + + return Batch(logits=logits, act=act, state=h, dist=dist, log_prob=log_prob) def _target_q( self, buffer: ReplayBuffer, indice: np.ndarray diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index f79c2a0d5..bd6572205 100644 --- 
a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -14,31 +14,26 @@ class TD3Policy(DDPGPolicy): :param torch.nn.Module actor: the actor network following the rules in :class:`~tianshou.policy.BasePolicy`. (s -> logits) :param torch.optim.Optimizer actor_optim: the optimizer for actor network. - :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, - a)) + :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a)) :param torch.optim.Optimizer critic1_optim: the optimizer for the first critic network. - :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, - a)) + :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a)) :param torch.optim.Optimizer critic2_optim: the optimizer for the second critic network. :param action_range: the action range (minimum, maximum). :type action_range: Tuple[float, float] - :param float tau: param for soft update of the target network, defaults to - 0.005. - :param float gamma: discount factor, in [0, 1], defaults to 0.99. - :param float exploration_noise: the exploration noise, add to the action, - defaults to ``GaussianNoise(sigma=0.1)`` - :param float policy_noise: the noise used in updating policy network, - default to 0.2. - :param int update_actor_freq: the update frequency of actor network, - default to 2. - :param float noise_clip: the clipping range used in updating policy - network, default to 0.5. - :param bool reward_normalization: normalize the reward to Normal(0, 1), - defaults to False. - :param bool ignore_done: ignore the done flag while training the policy, - defaults to False. + :param float tau: param for soft update of the target network. Default to 0.005. + :param float gamma: discount factor, in [0, 1]. Default to 0.99. + :param float exploration_noise: the exploration noise, add to the action. + Default to ``GaussianNoise(sigma=0.1)`` + :param float policy_noise: the noise used in updating policy network. + Default to 0.2. + :param int update_actor_freq: the update frequency of actor network. + Default to 2. + :param float noise_clip: the clipping range used in updating policy network. + Default to 0.5. + :param bool reward_normalization: normalize the reward to Normal(0, 1). + Default to False. .. 
seealso:: @@ -62,13 +57,12 @@ def __init__( update_actor_freq: int = 2, noise_clip: float = 0.5, reward_normalization: bool = False, - ignore_done: bool = False, estimation_step: int = 1, **kwargs: Any, ) -> None: - super().__init__(actor, actor_optim, None, None, action_range, - tau, gamma, exploration_noise, reward_normalization, - ignore_done, estimation_step, **kwargs) + super().__init__(actor, actor_optim, None, None, action_range, tau, gamma, + exploration_noise, reward_normalization, + estimation_step, **kwargs) self.critic1, self.critic1_old = critic1, deepcopy(critic1) self.critic1_old.eval() self.critic1_optim = critic1_optim @@ -91,18 +85,12 @@ def train(self, mode: bool = True) -> "TD3Policy": def sync_weight(self) -> None: for o, n in zip(self.actor_old.parameters(), self.actor.parameters()): o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) - for o, n in zip( - self.critic1_old.parameters(), self.critic1.parameters() - ): + for o, n in zip(self.critic1_old.parameters(), self.critic1.parameters()): o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) - for o, n in zip( - self.critic2_old.parameters(), self.critic2.parameters() - ): + for o, n in zip(self.critic2_old.parameters(), self.critic2.parameters()): o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) - def _target_q( - self, buffer: ReplayBuffer, indice: np.ndarray - ) -> torch.Tensor: + def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor: batch = buffer[indice] # batch.obs: s_{t+n} a_ = self(batch, model="actor_old", input="obs_next").act dev = a_.device @@ -137,8 +125,7 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: self.critic2_optim.step() batch.weight = (td1 + td2) / 2.0 # prio-buffer if self._cnt % self._freq == 0: - actor_loss = -self.critic1( - batch.obs, self(batch, eps=0.0).act).mean() + actor_loss = -self.critic1(batch.obs, self(batch, eps=0.0).act).mean() self.actor_optim.zero_grad() actor_loss.backward() self._last = actor_loss.item() diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index 5bfd5e9b2..7aa1f661c 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -59,6 +59,18 @@ def process_fn( buffer._meta.rew = save_rew return Batch(results) + def exploration_noise( + self, act: Union[np.ndarray, Batch], batch: Batch + ) -> Union[np.ndarray, Batch]: + """Add exploration noise from sub-policy onto act.""" + for policy in self.policies: + agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0] + if len(agent_index) == 0: + continue + act[agent_index] = policy.exploration_noise( + act[agent_index], batch[agent_index]) + return act + def forward( self, batch: Batch, diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py index 13f9159f8..9c7f132af 100644 --- a/tianshou/policy/random.py +++ b/tianshou/policy/random.py @@ -38,5 +38,5 @@ def forward( return Batch(act=logits.argmax(axis=-1)) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: - """Since a random agent learn nothing, it returns an empty dict.""" + """Since a random agent learns nothing, it returns an empty dict.""" return {} diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index 22fc1eea1..9fa88fbc3 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -4,9 +4,9 @@ from tianshou.trainer.offline import offline_trainer __all__ = [ - "gather_info", - "test_episode", - "onpolicy_trainer", 
"offpolicy_trainer", + "onpolicy_trainer", "offline_trainer", + "test_episode", + "gather_info", ] diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py index e69364135..13f96faeb 100644 --- a/tianshou/trainer/offline.py +++ b/tianshou/trainer/offline.py @@ -1,11 +1,11 @@ import time import tqdm +import numpy as np from collections import defaultdict -from torch.utils.tensorboard import SummaryWriter -from typing import Dict, List, Union, Callable, Optional +from typing import Dict, Union, Callable, Optional from tianshou.policy import BasePolicy -from tianshou.utils import tqdm_config, MovAvg +from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger from tianshou.data import Collector, ReplayBuffer from tianshou.trainer import test_episode, gather_info @@ -15,57 +15,62 @@ def offline_trainer( buffer: ReplayBuffer, test_collector: Collector, max_epoch: int, - step_per_epoch: int, - episode_per_test: Union[int, List[int]], + update_per_epoch: int, + episode_per_test: int, batch_size: int, test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None, - writer: Optional[SummaryWriter] = None, - log_interval: int = 1, + reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, + logger: BaseLogger = LazyLogger(), verbose: bool = True, ) -> Dict[str, Union[float, str]]: """A wrapper for offline trainer procedure. - The "step" in trainer means a policy network update. + The "step" in offline trainer means a gradient step. - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` - class. - :param test_collector: the collector used for testing. - :type test_collector: :class:`~tianshou.data.Collector` - :param int max_epoch: the maximum number of epochs for training. The - training process might be finished before reaching the ``max_epoch``. - :param int step_per_epoch: the number of policy network updates, so-called + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param Collector test_collector: the collector used for testing. + :param int max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. + :param int update_per_epoch: the number of policy network updates, so-called gradient steps, per epoch. :param episode_per_test: the number of episodes for one policy evaluation. - :param int batch_size: the batch size of sample data, which is going to - feed in the policy network. - :param function test_fn: a hook called at the beginning of testing in each - epoch. It can be used to perform custom additional operations, with the - signature ``f(num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean - reward in evaluation phase gets better, with the signature ``f(policy: - BasePolicy) -> None``. - :param function stop_fn: a function with signature ``f(mean_rewards: float) - -> bool``, receives the average undiscounted returns of the testing - result, returns a boolean which indicates whether reaching the goal. - :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard - SummaryWriter; if None is given, it will not write logs to TensorBoard. - :param int log_interval: the log interval of the writer. - :param bool verbose: whether to print the information. 
+ :param int batch_size: the batch size of sample data, which is going to feed in + the policy network. + :param function test_fn: a hook called at the beginning of testing in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean reward in + evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> + None``. + :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether reaching the goal. + :param function reward_metric: a function with signature ``f(rewards: np.ndarray + with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, + used in multi-agent RL. We need to return a single scalar for each episode's + result to monitor training in the multi-agent RL setting. This function + specifies what is the desired metric, e.g., the reward of agent 1 or the + average reward over all agents. + :param BaseLogger logger: A logger that logs statistics during updating/testing. + Default to a logger that doesn't log anything. + :param bool verbose: whether to print the information. Default to True. :return: See :func:`~tianshou.trainer.gather_info`. """ gradient_step = 0 - best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() test_collector.reset_stat() - + test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test, + logger, gradient_step, reward_metric) + best_epoch = 0 + best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] for epoch in range(1, 1 + max_epoch): policy.train() with tqdm.trange( - step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config + update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config ) as t: for i in t: gradient_step += 1 @@ -73,25 +78,23 @@ def offline_trainer( data = {"gradient_step": str(gradient_step)} for k in losses.keys(): stat[k].add(losses[k]) - data[k] = f"{stat[k].get():.6f}" - if writer and gradient_step % log_interval == 0: - writer.add_scalar( - "train/" + k, stat[k].get(), - global_step=gradient_step) + losses[k] = stat[k].get() + data[k] = f"{losses[k]:.3f}" + logger.log_update_data(losses, gradient_step) t.set_postfix(**data) # test - result = test_episode(policy, test_collector, test_fn, epoch, - episode_per_test, writer, gradient_step) - if best_epoch == -1 or best_reward < result["rew"]: - best_reward, best_reward_std = result["rew"], result["rew_std"] + test_result = test_episode( + policy, test_collector, test_fn, epoch, episode_per_test, + logger, gradient_step, reward_metric) + rew, rew_std = test_result["rew"], test_result["rew_std"] + if best_epoch == -1 or best_reward < rew: + best_reward, best_reward_std = rew, rew_std best_epoch = epoch if save_fn: save_fn(policy) if verbose: - print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± " - f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± " - f"{best_reward_std:.6f} in #{best_epoch}") + print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break - return gather_info(start_time, None, test_collector, - best_reward, best_reward_std) + return gather_info(start_time, None, test_collector, best_reward, best_reward_std) diff 
--git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index f34f5b281..72a243d9a 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -1,12 +1,12 @@ import time import tqdm +import numpy as np from collections import defaultdict -from torch.utils.tensorboard import SummaryWriter -from typing import Dict, List, Union, Callable, Optional +from typing import Dict, Union, Callable, Optional from tianshou.data import Collector from tianshou.policy import BasePolicy -from tianshou.utils import tqdm_config, MovAvg +from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger from tianshou.trainer import test_episode, gather_info @@ -16,70 +16,76 @@ def offpolicy_trainer( test_collector: Collector, max_epoch: int, step_per_epoch: int, - collect_per_step: int, - episode_per_test: Union[int, List[int]], + step_per_collect: int, + episode_per_test: int, batch_size: int, - update_per_step: int = 1, + update_per_step: Union[int, float] = 1, train_fn: Optional[Callable[[int, int], None]] = None, test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None, - writer: Optional[SummaryWriter] = None, - log_interval: int = 1, + reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, + logger: BaseLogger = LazyLogger(), verbose: bool = True, test_in_train: bool = True, ) -> Dict[str, Union[float, str]]: """A wrapper for off-policy trainer procedure. - The "step" in trainer means a policy network update. + The "step" in trainer means an environment step (a.k.a. transition). - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` - class. - :param train_collector: the collector used for training. - :type train_collector: :class:`~tianshou.data.Collector` - :param test_collector: the collector used for testing. - :type test_collector: :class:`~tianshou.data.Collector` - :param int max_epoch: the maximum number of epochs for training. The - training process might be finished before reaching the ``max_epoch``. - :param int step_per_epoch: the number of policy network updates, so-called - gradient steps, per epoch. - :param int collect_per_step: the number of frames the collector would - collect before the network update. In other words, collect some frames - and do some policy network update. + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param Collector train_collector: the collector used for training. + :param Collector test_collector: the collector used for testing. + :param int max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. + :param int step_per_epoch: the number of transitions collected per epoch. + :param int step_per_collect: the number of transitions the collector would collect + before the network update, i.e., trainer will collect "step_per_collect" + transitions and do some policy network update repeatly in each epoch. :param episode_per_test: the number of episodes for one policy evaluation. - :param int batch_size: the batch size of sample data, which is going to - feed in the policy network. - :param int update_per_step: the number of times the policy network would - be updated after frames are collected, for example, set it to 256 means - it updates policy 256 times once after ``collect_per_step`` frames are - collected. 
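The `reward_metric` hook (documented for the offline trainer above and repeated for the off-policy trainer below) reduces per-agent episode rewards to one scalar per episode for logging and `stop_fn`. A minimal sketch with made-up numbers:

```python
import numpy as np

def reward_metric(rews: np.ndarray) -> np.ndarray:
    """Reduce (num_episode, agent_num) rewards to one scalar per episode."""
    return rews[:, 0]              # e.g. follow agent 0; rews.mean(axis=1) also works

rews = np.array([[1.0, -1.0], [0.0, 0.0], [-1.0, 1.0]])   # 3 episodes, 2 agents
print(reward_metric(rews))         # [ 1.  0. -1.]
```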
- :param function train_fn: a hook called at the beginning of training in - each epoch. It can be used to perform custom additional operations, - with the signature ``f(num_epoch: int, step_idx: int) -> None``. - :param function test_fn: a hook called at the beginning of testing in each - epoch. It can be used to perform custom additional operations, with the - signature ``f(num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean - reward in evaluation phase gets better, with the signature ``f(policy: - BasePolicy) -> None``. - :param function stop_fn: a function with signature ``f(mean_rewards: float) - -> bool``, receives the average undiscounted returns of the testing - result, returns a boolean which indicates whether reaching the goal. - :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard - SummaryWriter; if None is given, it will not write logs to TensorBoard. - :param int log_interval: the log interval of the writer. - :param bool verbose: whether to print the information. - :param bool test_in_train: whether to test in the training phase. + :param int batch_size: the batch size of sample data, which is going to feed in the + policy network. + :param int/float update_per_step: the number of times the policy network would be + updated per transition after (step_per_collect) transitions are collected, + e.g., if update_per_step set to 0.3, and step_per_collect is 256, policy will + be updated round(256 * 0.3 = 76.8) = 77 times after 256 transitions are + collected by the collector. Default to 1. + :param function train_fn: a hook called at the beginning of training in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean reward in + evaluation phase gets better, with the signature ``f(policy:BasePolicy) -> + None``. + :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether reaching the goal. + :param function reward_metric: a function with signature ``f(rewards: np.ndarray + with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, + used in multi-agent RL. We need to return a single scalar for each episode's + result to monitor training in the multi-agent RL setting. This function + specifies what is the desired metric, e.g., the reward of agent 1 or the + average reward over all agents. + :param BaseLogger logger: A logger that logs statistics during + training/testing/updating. Default to a logger that doesn't log anything. + :param bool verbose: whether to print the information. Default to True. + :param bool test_in_train: whether to test in the training phase. Default to True. :return: See :func:`~tianshou.trainer.gather_info`. 
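As a quick check of the `update_per_step` arithmetic described above, a short sketch (numbers taken from the docstring example, not defaults):

```python
# With step_per_collect transitions collected, the off-policy trainer performs
# round(update_per_step * n_collected) gradient steps per collect round.
step_per_collect = 256
update_per_step = 0.3
n_gradient_steps = round(update_per_step * step_per_collect)  # round(76.8) == 77
assert n_gradient_steps == 77
```

Setting `update_per_step` to `1 / step_per_collect` therefore yields exactly one gradient step per collect round.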
""" env_step, gradient_step = 0, 0 - best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 + last_rew, last_len = 0.0, 0 stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy + test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test, + logger, env_step, reward_metric) + best_epoch = 0 + best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] for epoch in range(1, 1 + max_epoch): # train policy.train() @@ -89,62 +95,58 @@ def offpolicy_trainer( while t.n < t.total: if train_fn: train_fn(epoch, env_step) - result = train_collector.collect(n_step=collect_per_step) + result = train_collector.collect(n_step=step_per_collect) + if result["n/ep"] > 0 and reward_metric: + result["rews"] = reward_metric(result["rews"]) env_step += int(result["n/st"]) + t.update(result["n/st"]) + logger.log_train_data(result, env_step) + last_rew = result['rew'] if 'rew' in result else last_rew + last_len = result['len'] if 'len' in result else last_len data = { "env_step": str(env_step), - "rew": f"{result['rew']:.2f}", - "len": str(int(result["len"])), + "rew": f"{last_rew:.2f}", + "len": str(int(last_len)), "n/ep": str(int(result["n/ep"])), "n/st": str(int(result["n/st"])), - "v/ep": f"{result['v/ep']:.2f}", - "v/st": f"{result['v/st']:.2f}", } - if writer and env_step % log_interval == 0: - for k in result.keys(): - writer.add_scalar( - "train/" + k, result[k], global_step=env_step) - if test_in_train and stop_fn and stop_fn(result["rew"]): - test_result = test_episode( - policy, test_collector, test_fn, - epoch, episode_per_test, writer, env_step) - if stop_fn(test_result["rew"]): - if save_fn: - save_fn(policy) - for k in result.keys(): - data[k] = f"{result[k]:.2f}" - t.set_postfix(**data) - return gather_info( - start_time, train_collector, test_collector, - test_result["rew"], test_result["rew_std"]) - else: - policy.train() - for i in range(update_per_step * min( - result["n/st"] // collect_per_step, t.total - t.n)): + if result["n/ep"] > 0: + if test_in_train and stop_fn and stop_fn(result["rew"]): + test_result = test_episode( + policy, test_collector, test_fn, + epoch, episode_per_test, logger, env_step) + if stop_fn(test_result["rew"]): + if save_fn: + save_fn(policy) + t.set_postfix(**data) + return gather_info( + start_time, train_collector, test_collector, + test_result["rew"], test_result["rew_std"]) + else: + policy.train() + for i in range(round(update_per_step * result["n/st"])): gradient_step += 1 losses = policy.update(batch_size, train_collector.buffer) for k in losses.keys(): stat[k].add(losses[k]) - data[k] = f"{stat[k].get():.6f}" - if writer and gradient_step % log_interval == 0: - writer.add_scalar( - k, stat[k].get(), global_step=gradient_step) - t.update(1) + losses[k] = stat[k].get() + data[k] = f"{losses[k]:.3f}" + logger.log_update_data(losses, gradient_step) t.set_postfix(**data) if t.n <= t.total: t.update() # test - result = test_episode(policy, test_collector, test_fn, epoch, - episode_per_test, writer, env_step) - if best_epoch == -1 or best_reward < result["rew"]: - best_reward, best_reward_std = result["rew"], result["rew_std"] + test_result = test_episode(policy, test_collector, test_fn, epoch, + episode_per_test, logger, env_step, reward_metric) + rew, rew_std = test_result["rew"], test_result["rew_std"] + if best_epoch == -1 or best_reward < rew: + best_reward, best_reward_std = rew, 
rew_std best_epoch = epoch if save_fn: save_fn(policy) if verbose: - print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± " - f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± " - f"{best_reward_std:.6f} in #{best_epoch}") + print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break return gather_info(start_time, train_collector, test_collector, diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index f094ddd7d..dae20a741 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -1,12 +1,12 @@ import time import tqdm +import numpy as np from collections import defaultdict -from torch.utils.tensorboard import SummaryWriter -from typing import Dict, List, Union, Callable, Optional +from typing import Dict, Union, Callable, Optional from tianshou.data import Collector from tianshou.policy import BasePolicy -from tianshou.utils import tqdm_config, MovAvg +from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger from tianshou.trainer import test_episode, gather_info @@ -16,70 +16,82 @@ def onpolicy_trainer( test_collector: Collector, max_epoch: int, step_per_epoch: int, - collect_per_step: int, repeat_per_collect: int, - episode_per_test: Union[int, List[int]], + episode_per_test: int, batch_size: int, + step_per_collect: Optional[int] = None, + episode_per_collect: Optional[int] = None, train_fn: Optional[Callable[[int, int], None]] = None, test_fn: Optional[Callable[[int, Optional[int]], None]] = None, stop_fn: Optional[Callable[[float], bool]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None, - writer: Optional[SummaryWriter] = None, - log_interval: int = 1, + reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, + logger: BaseLogger = LazyLogger(), verbose: bool = True, test_in_train: bool = True, ) -> Dict[str, Union[float, str]]: """A wrapper for on-policy trainer procedure. - The "step" in trainer means a policy network update. + The "step" in trainer means an environment step (a.k.a. transition). - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` - class. - :param train_collector: the collector used for training. - :type train_collector: :class:`~tianshou.data.Collector` - :param test_collector: the collector used for testing. - :type test_collector: :class:`~tianshou.data.Collector` - :param int max_epoch: the maximum number of epochs for training. The - training process might be finished before reaching the ``max_epoch``. - :param int step_per_epoch: the number of policy network updates, so-called - gradient steps, per epoch. - :param int collect_per_step: the number of episodes the collector would - collect before the network update. In other words, collect some - episodes and do one policy network update. - :param int repeat_per_collect: the number of repeat time for policy - learning, for example, set it to 2 means the policy needs to learn each - given batch data twice. - :param episode_per_test: the number of episodes for one policy evaluation. - :type episode_per_test: int or list of ints - :param int batch_size: the batch size of sample data, which is going to - feed in the policy network. - :param function train_fn: a hook called at the beginning of training in - each epoch. It can be used to perform custom additional operations, - with the signature ``f(num_epoch: int, step_idx: int) -> None``. 
- :param function test_fn: a hook called at the beginning of testing in each - epoch. It can be used to perform custom additional operations, with the - signature ``f(num_epoch: int, step_idx: int) -> None``. - :param function save_fn: a hook called when the undiscounted average mean - reward in evaluation phase gets better, with the signature ``f(policy: - BasePolicy) -> None``. - :param function stop_fn: a function with signature ``f(mean_rewards: float) - -> bool``, receives the average undiscounted returns of the testing - result, returns a boolean which indicates whether reaching the goal. - :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard - SummaryWriter; if None is given, it will not write logs to TensorBoard. - :param int log_interval: the log interval of the writer. - :param bool verbose: whether to print the information. - :param bool test_in_train: whether to test in the training phase. + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param Collector train_collector: the collector used for training. + :param Collector test_collector: the collector used for testing. + :param int max_epoch: the maximum number of epochs for training. The training + process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set. + :param int step_per_epoch: the number of transitions collected per epoch. + :param int repeat_per_collect: the number of repeat time for policy learning, for + example, set it to 2 means the policy needs to learn each given batch data + twice. + :param int episode_per_test: the number of episodes for one policy evaluation. + :param int batch_size: the batch size of sample data, which is going to feed in the + policy network. + :param int step_per_collect: the number of transitions the collector would collect + before the network update, i.e., trainer will collect "step_per_collect" + transitions and do some policy network update repeatly in each epoch. + :param int episode_per_collect: the number of episodes the collector would collect + before the network update, i.e., trainer will collect "episode_per_collect" + episodes and do some policy network update repeatly in each epoch. + :param function train_fn: a hook called at the beginning of training in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each epoch. + It can be used to perform custom additional operations, with the signature ``f( + num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean reward in + evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> + None``. + :param function stop_fn: a function with signature ``f(mean_rewards: float) -> + bool``, receives the average undiscounted returns of the testing result, + returns a boolean which indicates whether reaching the goal. + :param function reward_metric: a function with signature ``f(rewards: np.ndarray + with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``, + used in multi-agent RL. We need to return a single scalar for each episode's + result to monitor training in the multi-agent RL setting. This function + specifies what is the desired metric, e.g., the reward of agent 1 or the + average reward over all agents. + :param BaseLogger logger: A logger that logs statistics during + training/testing/updating. 
Default to a logger that doesn't log anything. + :param bool verbose: whether to print the information. Default to True. + :param bool test_in_train: whether to test in the training phase. Default to True. :return: See :func:`~tianshou.trainer.gather_info`. + + .. note:: + + Only either one of step_per_collect and episode_per_collect can be specified. """ env_step, gradient_step = 0, 0 - best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 + last_rew, last_len = 0.0, 0 stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() test_collector.reset_stat() test_in_train = test_in_train and train_collector.policy == policy + test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test, + logger, env_step, reward_metric) + best_epoch = 0 + best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] for epoch in range(1, 1 + max_epoch): # train policy.train() @@ -89,30 +101,29 @@ def onpolicy_trainer( while t.n < t.total: if train_fn: train_fn(epoch, env_step) - result = train_collector.collect(n_episode=collect_per_step) + result = train_collector.collect(n_step=step_per_collect, + n_episode=episode_per_collect) + if reward_metric: + result["rews"] = reward_metric(result["rews"]) env_step += int(result["n/st"]) + t.update(result["n/st"]) + logger.log_train_data(result, env_step) + last_rew = result['rew'] if 'rew' in result else last_rew + last_len = result['len'] if 'len' in result else last_len data = { "env_step": str(env_step), - "rew": f"{result['rew']:.2f}", - "len": str(int(result["len"])), + "rew": f"{last_rew:.2f}", + "len": str(int(last_len)), "n/ep": str(int(result["n/ep"])), "n/st": str(int(result["n/st"])), - "v/ep": f"{result['v/ep']:.2f}", - "v/st": f"{result['v/st']:.2f}", } - if writer and env_step % log_interval == 0: - for k in result.keys(): - writer.add_scalar( - "train/" + k, result[k], global_step=env_step) if test_in_train and stop_fn and stop_fn(result["rew"]): test_result = test_episode( policy, test_collector, test_fn, - epoch, episode_per_test, writer, env_step) + epoch, episode_per_test, logger, env_step) if stop_fn(test_result["rew"]): if save_fn: save_fn(policy) - for k in result.keys(): - data[k] = f"{result[k]:.2f}" t.set_postfix(**data) return gather_info( start_time, train_collector, test_collector, @@ -128,26 +139,24 @@ def onpolicy_trainer( gradient_step += step for k in losses.keys(): stat[k].add(losses[k]) - data[k] = f"{stat[k].get():.6f}" - if writer and gradient_step % log_interval == 0: - writer.add_scalar( - k, stat[k].get(), global_step=gradient_step) - t.update(step) + losses[k] = stat[k].get() + data[k] = f"{losses[k]:.3f}" + logger.log_update_data(losses, gradient_step) t.set_postfix(**data) if t.n <= t.total: t.update() # test - result = test_episode(policy, test_collector, test_fn, epoch, - episode_per_test, writer, env_step) - if best_epoch == -1 or best_reward < result["rew"]: - best_reward, best_reward_std = result["rew"], result["rew_std"] + test_result = test_episode(policy, test_collector, test_fn, epoch, + episode_per_test, logger, env_step) + rew, rew_std = test_result["rew"], test_result["rew_std"] + if best_epoch == -1 or best_reward < rew: + best_reward, best_reward_std = rew, rew_std best_epoch = epoch if save_fn: save_fn(policy) if verbose: - print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± " - f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± " - f"{best_reward_std:.6f} in #{best_epoch}") + print(f"Epoch #{epoch}: test_reward: 
{rew:.6f} ± {rew_std:.6f}, best_rew" + f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break return gather_info(start_time, train_collector, test_collector, diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index dfffd71a4..2e729feeb 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -1,10 +1,10 @@ import time import numpy as np -from torch.utils.tensorboard import SummaryWriter -from typing import Dict, List, Union, Callable, Optional +from typing import Any, Dict, Union, Callable, Optional from tianshou.data import Collector from tianshou.policy import BasePolicy +from tianshou.utils import BaseLogger def test_episode( @@ -12,25 +12,22 @@ def test_episode( collector: Collector, test_fn: Optional[Callable[[int, Optional[int]], None]], epoch: int, - n_episode: Union[int, List[int]], - writer: Optional[SummaryWriter] = None, + n_episode: int, + logger: Optional[BaseLogger] = None, global_step: Optional[int] = None, -) -> Dict[str, float]: + reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, +) -> Dict[str, Any]: """A simple wrapper of testing policy in collector.""" collector.reset_env() collector.reset_buffer() policy.eval() if test_fn: test_fn(epoch, global_step) - if collector.get_env_num() > 1 and isinstance(n_episode, int): - n = collector.get_env_num() - n_ = np.zeros(n) + n_episode // n - n_[:n_episode % n] += 1 - n_episode = list(n_) result = collector.collect(n_episode=n_episode) - if writer is not None and global_step is not None: - for k in result.keys(): - writer.add_scalar("test/" + k, result[k], global_step=global_step) + if reward_metric: + result["rews"] = reward_metric(result["rews"]) + if logger and global_step is not None: + logger.log_test_data(result, global_step) return result @@ -47,14 +44,14 @@ def gather_info( * ``train_step`` the total collected step of training collector; * ``train_episode`` the total collected episode of training collector; - * ``train_time/collector`` the time for collecting frames in the \ + * ``train_time/collector`` the time for collecting transitions in the \ training collector; * ``train_time/model`` the time for training models; - * ``train_speed`` the speed of training (frames per second); + * ``train_speed`` the speed of training (env_step per second); * ``test_step`` the total collected step of test collector; * ``test_episode`` the total collected episode of test collector; * ``test_time`` the time for testing; - * ``test_speed`` the speed of testing (frames per second); + * ``test_speed`` the speed of testing (env_step per second); * ``best_reward`` the best reward over the test results; * ``duration`` the total elapsed time. 
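To make these keys concrete, a hypothetical result dict could be consumed like this (all values are made up for illustration; only the keys come from the list above):

```python
# Hypothetical values, internally consistent with the key descriptions above.
info = {
    "train_step": 100000, "train_episode": 950,
    "train_time/collector": 85.2, "train_time/model": 41.7,
    "train_speed": 788.0,                      # 100000 / (85.2 + 41.7)
    "test_step": 12000, "test_episode": 600,
    "test_time": 9.6, "test_speed": 1250.0,    # 12000 / 9.6
    "best_reward": 199.4, "duration": "136.50s",
}
print(f"best reward {info['best_reward']:.2f}, "
      f"trained at {info['train_speed']:.1f} env_step/s in {info['duration']}")
```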
""" diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py index d3a371577..b8cfa2315 100644 --- a/tianshou/utils/__init__.py +++ b/tianshou/utils/__init__.py @@ -1,9 +1,11 @@ from tianshou.utils.config import tqdm_config from tianshou.utils.moving_average import MovAvg -from tianshou.utils.log_tools import SummaryWriter +from tianshou.utils.log_tools import BasicLogger, LazyLogger, BaseLogger __all__ = [ "MovAvg", "tqdm_config", - "SummaryWriter", + "BaseLogger", + "BasicLogger", + "LazyLogger", ] diff --git a/tianshou/utils/log_tools.py b/tianshou/utils/log_tools.py index bbbd82e67..c50c8ebae 100644 --- a/tianshou/utils/log_tools.py +++ b/tianshou/utils/log_tools.py @@ -1,47 +1,159 @@ -import threading -from torch.utils import tensorboard -from typing import Any, Dict, Optional - - -class SummaryWriter(tensorboard.SummaryWriter): - """A more convenient Summary Writer(`tensorboard.SummaryWriter`). - - You can get the same instance of summary writer everywhere after you - created one. - :: - - >>> writer1 = SummaryWriter.get_instance( - key="first", log_dir="log/test_sw/first") - >>> writer2 = SummaryWriter.get_instance() - >>> writer1 is writer2 - True - >>> writer4 = SummaryWriter.get_instance( - key="second", log_dir="log/test_sw/second") - >>> writer5 = SummaryWriter.get_instance(key="second") - >>> writer1 is not writer4 - True - >>> writer4 is writer5 - True +import numpy as np +from numbers import Number +from typing import Any, Union +from abc import ABC, abstractmethod +from torch.utils.tensorboard import SummaryWriter + + +class BaseLogger(ABC): + """The base class for any logger which is compatible with trainer.""" + + def __init__(self, writer: Any) -> None: + super().__init__() + self.writer = writer + + @abstractmethod + def write( + self, + key: str, + x: Union[Number, np.number, np.ndarray], + y: Union[Number, np.number, np.ndarray], + **kwargs: Any, + ) -> None: + """Specify how the writer is used to log data. + + :param key: namespace which the input data tuple belongs to. + :param x: stands for the ordinate of the input data tuple. + :param y: stands for the abscissa of the input data tuple. + """ + pass + + def log_train_data(self, collect_result: dict, step: int) -> None: + """Use writer to log statistics generated during training. + + :param collect_result: a dict containing information of data collected in + training stage, i.e., returns of collector.collect(). + :param int step: stands for the timestep the collect_result being logged. + """ + pass + + def log_update_data(self, update_result: dict, step: int) -> None: + """Use writer to log statistics generated during updating. + + :param update_result: a dict containing information of data collected in + updating stage, i.e., returns of policy.update(). + :param int step: stands for the timestep the collect_result being logged. + """ + pass + + def log_test_data(self, collect_result: dict, step: int) -> None: + """Use writer to log statistics generated during evaluating. + + :param collect_result: a dict containing information of data collected in + evaluating stage, i.e., returns of collector.collect(). + :param int step: stands for the timestep the collect_result being logged. + """ + pass + + +class BasicLogger(BaseLogger): + """A loggger that relies on tensorboard SummaryWriter by default to visualize \ + and log statistics. + + You can also rewrite write() func to use your own writer. + + :param SummaryWriter writer: the writer to log data. 
+    :param int train_interval: the log interval in log_train_data(). Default to 1.
+    :param int test_interval: the log interval in log_test_data(). Default to 1.
+    :param int update_interval: the log interval in log_update_data(). Default to 1000.
     """
-    _mutex_lock = threading.Lock()
-    _default_key: str
-    _instance: Optional[Dict[str, "SummaryWriter"]] = None
+    def __init__(
+        self,
+        writer: SummaryWriter,
+        train_interval: int = 1,
+        test_interval: int = 1,
+        update_interval: int = 1000,
+    ) -> None:
+        super().__init__(writer)
+        self.train_interval = train_interval
+        self.test_interval = test_interval
+        self.update_interval = update_interval
+        self.last_log_train_step = -1
+        self.last_log_test_step = -1
+        self.last_log_update_step = -1
+
+    def write(
+        self,
+        key: str,
+        x: Union[Number, np.number, np.ndarray],
+        y: Union[Number, np.number, np.ndarray],
+        **kwargs: Any,
+    ) -> None:
+        self.writer.add_scalar(key, y, global_step=x)
+
+    def log_train_data(self, collect_result: dict, step: int) -> None:
+        """Use writer to log statistics generated during training.
+
+        :param collect_result: a dict containing information of data collected in
+            the training stage, i.e., returns of collector.collect().
+        :param int step: the timestep at which the collect_result is being logged.
+
+        .. note::
+
+            ``collect_result`` will be modified in-place with "rew" and "len" keys.
+        """
+        if collect_result["n/ep"] > 0:
+            collect_result["rew"] = collect_result["rews"].mean()
+            collect_result["len"] = collect_result["lens"].mean()
+            if step - self.last_log_train_step >= self.train_interval:
+                self.write("train/n/ep", step, collect_result["n/ep"])
+                self.write("train/rew", step, collect_result["rew"])
+                self.write("train/len", step, collect_result["len"])
+                self.last_log_train_step = step
+
+    def log_test_data(self, collect_result: dict, step: int) -> None:
+        """Use writer to log statistics generated during evaluating.
+
+        :param collect_result: a dict containing information of data collected in
+            the evaluating stage, i.e., returns of collector.collect().
+        :param int step: the timestep at which the collect_result is being logged.
+
+        .. note::
+
+            ``collect_result`` will be modified in-place with "rew", "rew_std", "len",
+            and "len_std" keys.
+        """
+        assert collect_result["n/ep"] > 0
+        rews, lens = collect_result["rews"], collect_result["lens"]
+        rew, rew_std, len_, len_std = rews.mean(), rews.std(), lens.mean(), lens.std()
+        collect_result.update(rew=rew, rew_std=rew_std, len=len_, len_std=len_std)
+        if step - self.last_log_test_step >= self.test_interval:
+            self.write("test/rew", step, rew)
+            self.write("test/len", step, len_)
+            self.write("test/rew_std", step, rew_std)
+            self.write("test/len_std", step, len_std)
+            self.last_log_test_step = step
+
+    def log_update_data(self, update_result: dict, step: int) -> None:
+        if step - self.last_log_update_step >= self.update_interval:
+            for k, v in update_result.items():
+                self.write(k, step, v)
+            self.last_log_update_step = step
+
+
+class LazyLogger(BasicLogger):
+    """A logger that does nothing.
Used as the placeholder in trainer.""" + + def __init__(self) -> None: + super().__init__(None) # type: ignore - @classmethod - def get_instance( - cls, - key: Optional[str] = None, - *args: Any, + def write( + self, + key: str, + x: Union[Number, np.number, np.ndarray], + y: Union[Number, np.number, np.ndarray], **kwargs: Any, - ) -> "SummaryWriter": - """Get instance of torch.utils.tensorboard.SummaryWriter by key.""" - with SummaryWriter._mutex_lock: - if key is None: - key = SummaryWriter._default_key - if SummaryWriter._instance is None: - SummaryWriter._instance = {} - SummaryWriter._default_key = key - if key not in SummaryWriter._instance.keys(): - SummaryWriter._instance[key] = SummaryWriter(*args, **kwargs) - return SummaryWriter._instance[key] + ) -> None: + """The LazyLogger writes nothing.""" + pass diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 1da33650b..b41346e9a 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -34,7 +34,7 @@ class MLP(nn.Module): :param hidden_sizes: shape of MLP passed in as a list, not incluing input_dim and output_dim. :param norm_layer: use which normalization before activation, e.g., - ``nn.LayerNorm`` and ``nn.BatchNorm1d``, defaults to no normalization. + ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization. You can also pass a list of normalization modules with the same length of hidden_sizes, to use different normalization module in different layers. Default to no normalization. @@ -103,7 +103,7 @@ class Net(nn.Module): :param action_shape: int or a sequence of int of the shape of action. :param hidden_sizes: shape of MLP passed in as a list. :param norm_layer: use which normalization before activation, e.g., - ``nn.LayerNorm`` and ``nn.BatchNorm1d``, defaults to no normalization. + ``nn.LayerNorm`` and ``nn.BatchNorm1d``. Default to no normalization. You can also pass a list of normalization modules with the same length of hidden_sizes, to use different normalization module in different layers. Default to no normalization. @@ -118,13 +118,13 @@ class Net(nn.Module): :param bool concat: whether the input shape is concatenated by state_shape and action_shape. If it is True, ``action_shape`` is not the output shape, but affects the input shape only. - :param int num_atoms: in order to expand to the net of distributional RL, - defaults to 1 (not use). + :param int num_atoms: in order to expand to the net of distributional RL. + Default to 1 (not use). :param bool dueling_param: whether to use dueling network to calculate Q values (for Dueling DQN). If you want to use dueling option, you should pass a tuple of two dict (first for Q and second for V) stating self-defined arguments as stated in - class:`~tianshou.utils.net.common.MLP`. Defaults to None. + class:`~tianshou.utils.net.common.MLP`. Default to None. .. seealso::
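Two closing sketches may help tie the pieces above together. First, the `dueling_param` option just described takes a tuple of two dicts of MLP keyword arguments (the first for the Q branch, the second for the V branch); the sizes below are arbitrary and the call is a hedged sketch rather than canonical usage:

```python
import numpy as np
from tianshou.utils.net.common import Net

Q_param = {"hidden_sizes": [128]}   # kwargs forwarded to the Q-branch MLP
V_param = {"hidden_sizes": [128]}   # kwargs forwarded to the V-branch MLP
net = Net(state_shape=4, action_shape=2, hidden_sizes=[64, 64],
          dueling_param=(Q_param, V_param))
logits, state = net(np.zeros((1, 4), dtype=np.float32))  # logits.shape == (1, 2)
```

Second, since `BasicLogger` documents that `write()` can be overridden to use a different backend, a minimal custom logger under that interface might look as follows (`PrintLogger` is an illustrative name, not part of the library):

```python
from typing import Any
from tianshou.utils import BasicLogger

class PrintLogger(BasicLogger):
    """Keep BasicLogger's log_* bookkeeping but print scalars to stdout."""

    def __init__(self) -> None:
        super().__init__(None)  # type: ignore  # no SummaryWriter needed

    def write(self, key: str, x: Any, y: Any, **kwargs: Any) -> None:
        print(f"{key} @ step {x}: {y}")
```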