Step collector implementation #280

Merged on Feb 19, 2021 (122 commits)

Commits
8229b47
add first version of cached replay buffer(baseline), add standard api…
ChenDRAG Jan 19, 2021
483404f
add cached buffer, vec buffer
ChenDRAG Jan 20, 2021
942c2a3
simple pep8 fix
ChenDRAG Jan 20, 2021
ec4b246
init
ChenDRAG Jan 20, 2021
e7f631e
Merge branch 'master(net/utils change #275)' into cached
ChenDRAG Jan 20, 2021
da564be
Merge branch 'master' into cached
Trinkle23897 Jan 20, 2021
36e799e
some change
ChenDRAG Jan 22, 2021
50b20a0
update ReplayBuffer
Trinkle23897 Jan 22, 2021
3e487dc
Merge branch 'cached' of github.com:ChenDRAG/tianshou into cached
ChenDRAG Jan 22, 2021
0ac97af
refactor ReplayBuffer
ChenDRAG Jan 22, 2021
bd90b97
refactor vec/cached buffer
ChenDRAG Jan 22, 2021
2b7c227
pep8 fix
ChenDRAG Jan 22, 2021
3afb2cb
update VectorReplayBuffer and add test
Trinkle23897 Jan 23, 2021
443969d
update cached
Trinkle23897 Jan 24, 2021
ee51e64
order change, small fix
ChenDRAG Jan 25, 2021
a5bc4ad
try unittest
Trinkle23897 Jan 25, 2021
17e3612
add more test and fix bugs
Trinkle23897 Jan 25, 2021
7eba23d
fix a bug and add some corner-case tests
Trinkle23897 Jan 25, 2021
b9f4f2a
re-implement sample_avail function and add test for CachedReplayBuffe…
Trinkle23897 Jan 26, 2021
8fe85f8
improve documents
Trinkle23897 Jan 26, 2021
f59a530
ReplayBuffers._offset
Trinkle23897 Jan 26, 2021
425c2bd
fix atari-style update; support CachedBuffer with main_buffer==PrioBu…
Trinkle23897 Jan 27, 2021
2361755
assert _meta.is_empty() in ReplayBuffers init
Trinkle23897 Jan 27, 2021
0160a7f
Merge branch 'master' into cached
Trinkle23897 Jan 28, 2021
75d581b
small fix
Trinkle23897 Jan 28, 2021
b5d93f3
small fix
Trinkle23897 Jan 28, 2021
c8f27c9
improve coverage
Trinkle23897 Jan 28, 2021
9c879f2
small buffer change
ChenDRAG Jan 28, 2021
679fe27
draft of step_collector, not finished yet
ChenDRAG Jan 28, 2021
26bb74c
pep8 fix
ChenDRAG Jan 28, 2021
74df1c5
fix ci
Trinkle23897 Jan 28, 2021
39463ba
recover speed to 2000+
Trinkle23897 Jan 28, 2021
31f0c94
improve documents
Trinkle23897 Jan 28, 2021
720da29
ReplayBuffers -> ReplayBufferManager; alloc_fn -> _buffer_allocator; …
Trinkle23897 Jan 28, 2021
b9385d8
re-organize test_buffer.py
Trinkle23897 Jan 28, 2021
16bb42e
improve test
Trinkle23897 Jan 29, 2021
cada7cc
test if can be faster
Trinkle23897 Jan 29, 2021
ea50ed1
Merge branch 'cached' into step_collector
ChenDRAG Jan 29, 2021
13683a0
update tianshou/trainer
ChenDRAG Jan 29, 2021
2242089
Merge branch 'master' into step_collector
Trinkle23897 Jan 29, 2021
1fed4a6
small change
ChenDRAG Jan 29, 2021
f44287d
Merge branch 'step_collector' of github.com:ChenDRAG/tianshou into st…
ChenDRAG Jan 29, 2021
cf8a738
collector's buffer type check
ChenDRAG Jan 29, 2021
174f037
fix syntax err
Trinkle23897 Jan 30, 2021
8810c23
change collector API in all files
ChenDRAG Jan 30, 2021
a0c880b
Merge branch 'step_collector' of github.com:chenDRAG/tianshou into st…
Trinkle23897 Jan 30, 2021
7b62804
rewrite multibuf.add and buffer.update
Trinkle23897 Jan 30, 2021
3f62101
buffer update (draft)
Trinkle23897 Feb 1, 2021
036cf14
buffer update: fix test
Trinkle23897 Feb 2, 2021
4b86274
VectorBuffer
Trinkle23897 Feb 2, 2021
4ffdb82
vectorbuf
Trinkle23897 Feb 2, 2021
5ce4684
merge master
Trinkle23897 Feb 2, 2021
13a41fa
ReplayBuffer now support add batch-style data when buffer_ids is not …
Trinkle23897 Feb 3, 2021
20f2549
finish step collector, time to write unittest
Trinkle23897 Feb 4, 2021
03a1d2f
add test
Trinkle23897 Feb 4, 2021
43e69a7
finish step_collector test
Trinkle23897 Feb 4, 2021
311fba6
add more test about atari-style buffer setting and CachedReplayBuffer
Trinkle23897 Feb 4, 2021
a99a613
fix Collector(buffer=None)
Trinkle23897 Feb 4, 2021
220854f
collector enhance
ChenDRAG Feb 4, 2021
59abc05
Merge branch 'step_collector' of github.com:ChenDRAG/tianshou into st…
ChenDRAG Feb 4, 2021
1b8275e
add expl nosie in collector
ChenDRAG Feb 7, 2021
509e417
refactor basepolicy to coordinate with step collector
ChenDRAG Feb 7, 2021
a1c9851
fix bug, test_collector passed now
ChenDRAG Feb 8, 2021
6c6d503
change buffer setup in all files, switch buffer to vec buffer
ChenDRAG Feb 8, 2021
e683c08
change all api to coordinate with new collector and vec buffer
ChenDRAG Feb 8, 2021
8e12504
small update
Trinkle23897 Feb 8, 2021
05140aa
fix some bug
ChenDRAG Feb 8, 2021
59a63da
Merge branch 'step_collector' of github.com:ChenDRAG/tianshou into st…
ChenDRAG Feb 8, 2021
f96f81b
coordinate test files and fix on bug on summary writer
ChenDRAG Feb 8, 2021
1f105d5
AsyncCollector, still need more test
Trinkle23897 Feb 8, 2021
d30bfce
Merge branch 'step_collector' of github.com:chenDRAG/tianshou into st…
Trinkle23897 Feb 8, 2021
0a0dd26
AsyncCollector test
Trinkle23897 Feb 8, 2021
3dbb65f
flake8 maxlen 119
Trinkle23897 Feb 9, 2021
6155844
black format
Trinkle23897 Feb 9, 2021
c6431cf
fix test and docs for collector
Trinkle23897 Feb 9, 2021
581e5c8
some other bugs fix
ChenDRAG Feb 9, 2021
e1e9ba2
Merge branch 'step_collector' of github.com:ChenDRAG/tianshou into st…
ChenDRAG Feb 9, 2021
26dc2a1
88
Trinkle23897 Feb 9, 2021
d3dd8ea
small update
Trinkle23897 Feb 9, 2021
c37bf24
fix cuda bug
ChenDRAG Feb 9, 2021
9656a1e
another fix
ChenDRAG Feb 9, 2021
15bd6e7
Apply suggestions from code review
Trinkle23897 Feb 9, 2021
38496cf
last fix
ChenDRAG Feb 9, 2021
0575136
Merge branch 'cudafix' of github.com:ChenDRAG/tianshou into cudafix
ChenDRAG Feb 9, 2021
5d60979
Merge branch 'cudafix' into step_collector
ChenDRAG Feb 9, 2021
a40700d
pep8 fix
ChenDRAG Feb 9, 2021
5c46415
fix test_ppo
Trinkle23897 Feb 9, 2021
4df7bd1
Merge branch 'step_collector' of github.com:chenDRAG/tianshou into st…
Trinkle23897 Feb 9, 2021
b91e44f
fix 4096
Trinkle23897 Feb 9, 2021
d75f5b5
fix a2c and pg
Trinkle23897 Feb 9, 2021
f16c5ff
fix print
Trinkle23897 Feb 9, 2021
924fde6
fix some mypy
Trinkle23897 Feb 9, 2021
7f5b51c
fix multidim target_q nstep error
Trinkle23897 Feb 9, 2021
4ab067f
fix throughput
Trinkle23897 Feb 9, 2021
4f6a983
open nstep njit
Trinkle23897 Feb 9, 2021
cbb0fda
fix assert False
Trinkle23897 Feb 9, 2021
6c0b52c
fix a strange bug in buffer init
ChenDRAG Feb 10, 2021
e8a71a7
some bugs in nstep
Trinkle23897 Feb 10, 2021
ea6b3c3
fix n_episode test
Trinkle23897 Feb 10, 2021
3365eec
value mask bug fix
ChenDRAG Feb 10, 2021
d408243
merge0.3.2
ChenDRAG Feb 16, 2021
8a61dd9
update doc
ChenDRAG Feb 17, 2021
dc3e150
pep8 fix
ChenDRAG Feb 17, 2021
f8e7310
drqn
ChenDRAG Feb 17, 2021
2bc9663
Merge branch 'dev' into step_collector
Trinkle23897 Feb 18, 2021
4ffbf58
Apply suggestions from code review
Trinkle23897 Feb 18, 2021
51f06b8
add reward_metric in trainer
Trinkle23897 Feb 18, 2021
d0357c1
add exploration_noise in mapolicy
Trinkle23897 Feb 18, 2021
64fe3ce
it works!
Trinkle23897 Feb 18, 2021
64e04df
fix test
Trinkle23897 Feb 18, 2021
2275efb
fix dqn family eps-test
Trinkle23897 Feb 18, 2021
1145359
fix dead loop in creating new Batch (drqn _is_scalar replace np.asany…
Trinkle23897 Feb 18, 2021
c7a624f
split exploration_noise in bcq
Trinkle23897 Feb 18, 2021
d5fa008
add test for priovecbuf
Trinkle23897 Feb 18, 2021
c3b35d4
improve coverage
Trinkle23897 Feb 18, 2021
0cae28c
fix several bugs of documentation
Trinkle23897 Feb 18, 2021
b7efc68
fix test
Trinkle23897 Feb 18, 2021
b2098ee
add a test of batch
Trinkle23897 Feb 18, 2021
4f181cc
fix test
Trinkle23897 Feb 18, 2021
8be5b2c
add a note
Trinkle23897 Feb 18, 2021
cb4dbda
remove redundant code
Trinkle23897 Feb 19, 2021
9b46678
polish docs organization
Trinkle23897 Feb 19, 2021
15 changes: 8 additions & 7 deletions README.md
@@ -158,13 +158,14 @@ Currently, the overall code of Tianshou platform is less than 2500 lines. Most o
```python
result = collector.collect(n_step=n)
```

If you have 3 environments in total and want to collect 1 episode in the first environment, 3 for the third environment:
If you have 3 environments in total and want to collect 4 episodes:

```python
result = collector.collect(n_episode=[1, 0, 3])
result = collector.collect(n_episode=4)
```

Collector will collect exactly 4 episodes without any bias of episode length, even though we only have 3 parallel environments.
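
As a quick sketch of how the returned statistics can be used (assuming, as in the tutorial snippets further below, that the result dict exposes per-episode `rews` and `lens` arrays):

```python
result = collector.collect(n_episode=4)
# one entry per finished episode, so both arrays have length 4 here
print(result["rews"].mean(), result["lens"].mean())
```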

If you want to train the given policy with a sampled batch:

```python
@@ -194,7 +195,7 @@ train_num, test_num = 8, 100
gamma, n_step, target_freq = 0.9, 3, 320
buffer_size = 20000
eps_train, eps_test = 0.1, 0.05
step_per_epoch, collect_per_step = 1000, 10
step_per_epoch, collect_per_step = 1000, 8
writer = SummaryWriter('log/dqn') # tensorboard is also supported!
```

@@ -223,8 +224,8 @@ Setup policy and collectors:

```python
policy = ts.policy.DQNPolicy(net, optim, gamma, n_step, target_update_freq=target_freq)
train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(buffer_size))
test_collector = ts.data.Collector(policy, test_envs)
train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(buffer_size, train_num), exploration_noise=True)
test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True) # because DQN uses epsilon-greedy method
```

Let's train it:
@@ -252,7 +253,7 @@ Watch the performance with 35 FPS:
```python
policy.eval()
policy.set_eps(eps_test)
collector = ts.data.Collector(policy, env)
collector = ts.data.Collector(policy, env, exploration_noise=True)
collector.collect(n_episode=1, render=1 / 35)
```

24 changes: 23 additions & 1 deletion docs/api/tianshou.data.rst
@@ -1,7 +1,29 @@
tianshou.data
=============

.. automodule:: tianshou.data

Batch
-----

.. automodule:: tianshou.data.batch
:members:
:undoc-members:
:show-inheritance:


Buffer
------

.. automodule:: tianshou.data.buffer
:members:
:undoc-members:
:show-inheritance:


Collector
---------

.. automodule:: tianshou.data.collector
:members:
:undoc-members:
:show-inheritance:
8 changes: 8 additions & 0 deletions docs/api/tianshou.env.rst
@@ -1,11 +1,19 @@
tianshou.env
============


VectorEnv
---------

.. automodule:: tianshou.env
:members:
:undoc-members:
:show-inheritance:


Worker
------

.. automodule:: tianshou.env.worker
:members:
:undoc-members:
4 changes: 4 additions & 0 deletions docs/api/tianshou.utils.rst
@@ -6,6 +6,10 @@ tianshou.utils
:undoc-members:
:show-inheritance:


Pre-defined Networks
--------------------

.. automodule:: tianshou.utils.net.common
:members:
:undoc-members:
1 change: 1 addition & 0 deletions docs/conf.py
@@ -70,6 +70,7 @@
]
)
}
autodoc_member_order = "bysource"
bibtex_bibfiles = ['refs.bib']

# -- Options for HTML output -------------------------------------------------
18 changes: 10 additions & 8 deletions docs/tutorials/cheatsheet.rst
@@ -144,7 +144,7 @@ And finally,
::

test_processor = MyProcessor(size=100)
collector = Collector(policy, env, buffer, test_processor.preprocess_fn)
collector = Collector(policy, env, buffer, preprocess_fn=test_processor.preprocess_fn)

Some examples are in `test/base/test_collector.py <https://github.com/thu-ml/tianshou/blob/master/test/base/test_collector.py>`_.
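
A minimal sketch of what such a processor could look like (the ``preprocess_fn`` keyword-argument convention here is an assumption; see the linked test file for real examples):
::

from tianshou.data import Batch

class MyProcessor:
    """Illustrative stand-in for the processor used above."""
    def __init__(self, size=100):
        self.episode_log = []

    def preprocess_fn(self, **kwargs):
        # called by the collector after each environment step (and on reset);
        # returning a Batch with selected keys overrides them before they are
        # stored in the replay buffer
        if 'rew' in kwargs:
            return Batch(rew=kwargs['rew'] / 10.0)  # e.g. simple reward scaling
        return Batch()  # nothing to modify on reset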

@@ -156,7 +156,7 @@ RNN-style Training

This is related to `Issue 19 <https://github.com/thu-ml/tianshou/issues/19>`_.

First, add an argument "stack_num" to :class:`~tianshou.data.ReplayBuffer`:
First, add an argument "stack_num" to :class:`~tianshou.data.ReplayBuffer`, :class:`~tianshou.data.VectorReplayBuffer`, or other types of buffer you are using, like:
::

buf = ReplayBuffer(size=size, stack_num=stack_num)
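
# A hypothetical continuation for illustration: once the collector has filled
# the buffer, sampling returns frame-stacked observations for the RNN policy.
batch, indices = buf.sample(batch_size)
# batch.obs is expected to have shape (batch_size, stack_num, *obs_shape)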
@@ -206,14 +206,13 @@ The state can be a ``numpy.ndarray`` or a Python dictionary. Take "FetchReach-v1
It shows that the state is a dictionary which has 3 keys. It will be stored in :class:`~tianshou.data.ReplayBuffer` as:
::

>>> from tianshou.data import ReplayBuffer
>>> from tianshou.data import Batch, ReplayBuffer
>>> b = ReplayBuffer(size=3)
>>> b.add(obs=e.reset(), act=0, rew=0, done=0)
>>> b.add(Batch(obs=e.reset(), act=0, rew=0, done=0))
>>> print(b)
ReplayBuffer(
act: array([0, 0, 0]),
done: array([0, 0, 0]),
info: Batch(),
done: array([False, False, False]),
obs: Batch(
achieved_goal: array([[1.34183265, 0.74910039, 0.53472272],
[0. , 0. , 0. ],
@@ -234,7 +233,6 @@ It shows that the state is a dictionary which has 3 keys. It will stored in :cla
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00]]),
),
policy: Batch(),
rew: array([0, 0, 0]),
)
>>> print(b.obs.achieved_goal)
@@ -278,7 +276,7 @@ For self-defined class, the replay buffer will store the reference into a ``nump

>>> import networkx as nx
>>> b = ReplayBuffer(size=3)
>>> b.add(obs=nx.Graph(), act=0, rew=0, done=0)
>>> b.add(Batch(obs=nx.Graph(), act=0, rew=0, done=0))
>>> print(b)
ReplayBuffer(
act: array([0, 0, 0]),
@@ -299,6 +297,10 @@ But the state stored in the buffer may be a shallow-copy. To make sure each of y
...
return copy.deepcopy(self.graph), reward, done, {}

.. note ::

Please make sure this variable is numpy-compatible, i.e., ``np.array([variable])`` does not result in an empty array. Otherwise, ReplayBuffer cannot create a numpy array to store it.
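
A quick way to check this (a sketch; ``GraphLikeState`` is a made-up placeholder class):
::

>>> import numpy as np
>>> class GraphLikeState: ...
>>> np.array([GraphLikeState()]).shape  # a one-element object array: safe to store
(1,)
>>> np.array([[]]).shape  # an empty array: cannot be stored in the buffer
(1, 0)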


.. _marl_example:

12 changes: 5 additions & 7 deletions docs/tutorials/concepts.rst
@@ -53,7 +53,7 @@ In short, you can define a :class:`~tianshou.data.Batch` with any key-value pair
Buffer
------

:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. ReplayBuffer can be considered as a specialized form (or management) of Batch. It stores all the data in a batch with circular-queue style.
:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. ReplayBuffer can be considered as a specialized form (or management) of :class:`~tianshou.data.Batch`. It stores all the data in a batch with circular-queue style.

The current implementation of Tianshou typically uses 7 reserved keys in
:class:`~tianshou.data.Batch`:
@@ -209,7 +209,7 @@ The following code snippet illustrates its usage, including:

</details><br>

Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``), :class:`~tianshou.data.CachedReplayBuffer` (add different episodes' data but without losing chronological order). Check out :class:`~tianshou.data.ReplayBuffer` for more detail.
Tianshou provides other types of data buffer, such as :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``) and :class:`~tianshou.data.VectorReplayBuffer` (which adds data from different episodes without losing chronological order). Check out :class:`~tianshou.data.ReplayBuffer` for more detail.
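
A minimal sketch of the vectorized buffer (the constructor follows the README example above; the batch-style ``add`` with ``buffer_ids`` follows the commit messages, so treat the exact keyword names as assumptions):
::

from tianshou.data import Batch, VectorReplayBuffer

# a total capacity of 1000 transitions, split across 4 sub-buffers (one per env)
buf = VectorReplayBuffer(1000, 4)
# batch-style add: one transition for each sub-buffer listed in buffer_ids
buf.add(Batch(obs=[0, 0], act=[1, 1], rew=[0.0, 1.0], done=[False, True],
              obs_next=[1, 1], info=[{}, {}]),
        buffer_ids=[0, 3])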


Policy
@@ -339,14 +339,12 @@ Collector

The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently.

:meth:`~tianshou.data.Collector.collect` is the main method of Collector: it let the policy perform (at least) a specified number of step ``n_step`` or episode ``n_episode`` and store the data in the replay buffer.

Why do we mention **at least** here? For multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically.

The proposed solution is to add some cache buffers inside the collector. Once collecting **a full episode of trajectory**, it will move the stored data from the cache buffer to the main buffer. To satisfy this condition, the collector will interact with environments that may exceed the given step number or episode number.
:meth:`~tianshou.data.Collector.collect` is the main method of Collector: it lets the policy perform a specified number of steps ``n_step`` or episodes ``n_episode``, store the data in the replay buffer, and then return the statistics of the collected data, such as the episodes' total rewards.

The general explanation is listed in :ref:`pseudocode`. Other usages of collector are listed in :class:`~tianshou.data.Collector` documentation.

There is also another type of collector, :class:`~tianshou.data.AsyncCollector`, which supports the asynchronous environment setting (for environments that take a long time to step). However, AsyncCollector only supports collecting **at least** ``n_step`` steps or ``n_episode`` episodes, due to the nature of asynchronous environments.
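
A rough usage sketch of the synchronous collector (it assumes a ``policy`` object has already been built as in the DQN tutorial; see the :class:`~tianshou.data.Collector` documentation for the exact step-count rounding behaviour):
::

import gym
from tianshou.env import DummyVectorEnv
from tianshou.data import Collector, VectorReplayBuffer

envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(4)])
collector = Collector(policy, envs, VectorReplayBuffer(20000, 4), exploration_noise=True)

collector.collect(n_step=100)    # 100 transitions, spread over the 4 envs
collector.collect(n_episode=10)  # 10 complete episodes, without episode-length bias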


Trainer
-------
15 changes: 7 additions & 8 deletions docs/tutorials/dqn.rst
@@ -113,8 +113,8 @@ The collector is a key concept in Tianshou. It allows the policy to interact wit
In each step, the collector will let the policy perform (at least) a specified number of steps or episodes and store the data in a replay buffer.
::

train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(size=20000))
test_collector = ts.data.Collector(policy, test_envs)
train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(20000, 8), exploration_noise=True)
test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True)


Train Policy with a Trainer
@@ -191,7 +191,7 @@ Watch the Agent's Performance

policy.eval()
policy.set_eps(0.05)
collector = ts.data.Collector(policy, env)
collector = ts.data.Collector(policy, env, exploration_noise=True)
collector.collect(n_episode=1, render=1 / 35)


@@ -206,20 +206,19 @@ Tianshou supports user-defined training code. Here is the code snippet:
::

# pre-collect at least 5000 frames with random action before training
policy.set_eps(1)
train_collector.collect(n_step=5000)
train_collector.collect(n_step=5000, random=True)

policy.set_eps(0.1)
for i in range(int(1e6)): # total step
collect_result = train_collector.collect(n_step=10)

# once if the collected episodes' mean returns reach the threshold,
# or every 1000 steps, we test it on test_collector
if collect_result['rew'] >= env.spec.reward_threshold or i % 1000 == 0:
if collect_result['rews'].mean() >= env.spec.reward_threshold or i % 1000 == 0:
policy.set_eps(0.05)
result = test_collector.collect(n_episode=100)
if result['rew'] >= env.spec.reward_threshold:
print(f'Finished training! Test mean returns: {result["rew"]}')
if result['rews'].mean() >= env.spec.reward_threshold:
print(f'Finished training! Test mean returns: {result["rews"].mean()}')
break
else:
# back to training eps
57 changes: 31 additions & 26 deletions docs/tutorials/tictactoe.rst
@@ -128,7 +128,7 @@ Tianshou already provides some builtin classes for multi-agent learning. You can
>>>
>>> # use collectors to collect a episode of trajectories
>>> # the reward is a vector, so we need a scalar metric to monitor the training
>>> collector = Collector(policy, env, reward_metric=lambda x: x[0])
>>> collector = Collector(policy, env)
>>>
>>> # you will see a long trajectory showing the board status at each timestep
>>> result = collector.collect(n_episode=1, render=.1)
@@ -180,7 +180,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modul
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import BasePolicy, RandomPolicy, DQNPolicy, MultiAgentPolicyManager

from tic_tac_toe_env import TicTacToeEnv
@@ -199,27 +199,27 @@ The explanation of each Tianshou class/function will be deferred to their first
help='a smaller gamma favors earlier win')
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=10)
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--epoch', type=int, default=20)
parser.add_argument('--step-per-epoch', type=int, default=500)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128, 128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.1)
parser.add_argument('--board_size', type=int, default=6)
parser.add_argument('--win_size', type=int, default=4)
parser.add_argument('--win-rate', type=float, default=np.float32(0.9),
parser.add_argument('--board-size', type=int, default=6)
parser.add_argument('--win-size', type=int, default=4)
parser.add_argument('--win-rate', type=float, default=0.9,
help='the expected winning rate')
parser.add_argument('--watch', default=False, action='store_true',
help='no training, watch the play of pre-trained models')
parser.add_argument('--agent_id', type=int, default=2,
parser.add_argument('--agent-id', type=int, default=2,
help='the learned agent plays as the agent_id-th player. Choices are 1 and 2.')
parser.add_argument('--resume_path', type=str, default='',
parser.add_argument('--resume-path', type=str, default='',
help='the path of agent pth file for resuming from a pre-trained agent')
parser.add_argument('--opponent_path', type=str, default='',
parser.add_argument('--opponent-path', type=str, default='',
help='the path of opponent agent pth file for resuming from a pre-trained agent')
parser.add_argument('--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -240,11 +240,13 @@ Both agents are passed to :class:`~tianshou.policy.MultiAgentPolicyManager`, whi
Here it is:
::

def get_agents(args=get_args(),
agent_learn=None, # BasePolicy
agent_opponent=None, # BasePolicy
optim=None, # torch.optim.Optimizer
): # return a tuple of (BasePolicy, torch.optim.Optimizer)
def get_agents(
args=get_args(),
agent_learn=None, # BasePolicy
agent_opponent=None, # BasePolicy
optim=None, # torch.optim.Optimizer
): # return a tuple of (BasePolicy, torch.optim.Optimizer)

env = TicTacToeEnv(args.board_size, args.win_size)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
@@ -279,9 +281,6 @@ With the above preparation, we are close to the first learned agent. The followi
::

args = get_args()
# the reward is a vector, we need a scalar metric to monitor the training.
# we choose the reward of the learning agent
Collector._default_rew_metric = lambda x: x[args.agent_id - 1]

# ======== a test function that tests a pre-trained agent and exit ======
def watch(args=get_args(),
@@ -294,7 +293,7 @@ With the above preparation, we are close to the first learned agent. The followi
policy.policies[args.agent_id - 1].set_eps(args.eps_test)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
print(f'Final reward:{result["rews"].mean()}, length: {result["lens"].mean()}')
if args.watch:
watch(args)
exit(0)
@@ -313,9 +312,10 @@ With the above preparation, we are close to the first learned agent. The followi
policy, optim = get_agents()

# ======== collector setup =========
train_collector = Collector(policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
train_collector.collect(n_step=args.batch_size)
buffer = VectorReplayBuffer(args.buffer_size, args.training_num)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector.collect(n_step=args.batch_size * args.training_num)

# ======== tensorboard logging setup =========
if not hasattr(args, 'writer'):
@@ -347,13 +347,18 @@ With the above preparation, we are close to the first learned agent. The followi
def test_fn(epoch, env_step):
policy.policies[args.agent_id - 1].set_eps(args.eps_test)

# the reward is a vector, we need a scalar metric to monitor the training.
# we choose the reward of the learning agent
def reward_metric(rews):
return rews[:, args.agent_id - 1]

# start training, this may require about three minutes
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, writer=writer,
test_in_train=False)
stop_fn=stop_fn, save_fn=save_fn, reward_metric=reward_metric,
writer=writer, test_in_train=False)

agent = policy.policies[args.agent_id - 1]
# let's watch the match!
@@ -476,7 +481,7 @@ By default, the trained agent is stored in ``log/tic_tac_toe/dqn/policy.pth``. Y

.. code-block:: console

$ python test_tic_tac_toe.py --watch --resume_path=log/tic_tac_toe/dqn/policy.pth --opponent_path=log/tic_tac_toe/dqn/policy.pth
$ python test_tic_tac_toe.py --watch --resume-path log/tic_tac_toe/dqn/policy.pth --opponent-path log/tic_tac_toe/dqn/policy.pth

Here is our output:
