set policy.eval() before collector.collect #204

Merged (2 commits), Sep 6, 2020
README.md (4 changes: 1 addition & 3 deletions)
@@ -56,10 +56,8 @@ $ pip install tianshou
You can also install with the newest version through GitHub:

```bash
-# latest release
+# latest version
 $ pip install git+https://github.com/thu-ml/tianshou.git@master
-# develop version
-$ pip install git+https://github.com/thu-ml/tianshou.git@dev
```

If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:
docs/contributing.rst (8 changes: 1 addition & 7 deletions)
@@ -8,7 +8,6 @@ To install Tianshou in an "editable" mode, run

.. code-block:: bash

-   $ git checkout dev
$ pip install -e ".[dev]"

in the main directory. This installation is removable by
@@ -70,9 +69,4 @@ To compile documentation into webpages, run

under the ``docs/`` directory. The generated webpages are in ``docs/_build`` and can be viewed with browsers.

-Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/, and the develop version of documentation is in https://tianshou.readthedocs.io/en/dev/.
-
-Pull Request
-------------
-
-All of the commits should merge through the pull request to the ``dev`` branch. The pull request must have 2 approvals before merging.
+Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/.
docs/index.rst (6 changes: 2 additions & 4 deletions)
@@ -46,10 +46,8 @@ You can also install with the newest version through GitHub:

.. code-block:: bash

-   # latest release
+   # latest version
    $ pip install git+https://github.com/thu-ml/tianshou.git@master
-   # develop version
-   $ pip install git+https://github.com/thu-ml/tianshou.git@dev

If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:

@@ -70,7 +68,7 @@ After installation, open your python console and type

If no error occurs, you have successfully installed Tianshou.

-Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_ and the develop version through `tianshou.readthedocs.io/en/dev/ <https://tianshou.readthedocs.io/en/dev/>`_.
+Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_.

.. toctree::
:maxdepth: 1
tianshou/data/collector.py (8 changes: 7 additions & 1 deletion)
@@ -173,6 +173,7 @@ def collect(self,
                 n_episode: Optional[Union[int, List[int]]] = None,
                 random: bool = False,
                 render: Optional[float] = None,
+                no_grad: bool = True,
                 ) -> Dict[str, float]:
         """Collect a specified number of step or episode.

@@ -185,6 +186,8 @@
             defaults to ``False``.
         :param float render: the sleep time between rendering consecutive
             frames, defaults to ``None`` (no rendering).
+        :param bool no_grad: whether to retain gradient in policy.forward,
+            defaults to ``True`` (no gradient retaining).

         .. note::

@@ -252,7 +255,10 @@
                 result = Batch(
                     act=[spaces[i].sample() for i in self._ready_env_ids])
             else:
-                with torch.no_grad():
+                if no_grad:
+                    with torch.no_grad():  # faster than retain_grad version
+                        result = self.policy(self.data, last_state)
+                else:
                     result = self.policy(self.data, last_state)

             state = result.get('state', Batch())
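For context, the new `no_grad` flag only controls whether the policy's forward pass runs inside `torch.no_grad()` while collecting. A minimal standalone sketch of that pattern (the `forward_policy` helper below is hypothetical and not Tianshou's actual `Collector` code):

```python
import torch
import torch.nn as nn


def forward_policy(policy: nn.Module, obs: torch.Tensor,
                   no_grad: bool = True) -> torch.Tensor:
    """Run the policy forward pass, optionally without building a graph."""
    if no_grad:
        with torch.no_grad():  # skips autograd bookkeeping, so it is faster
            return policy(obs)
    return policy(obs)  # keeps the graph for callers that need gradients


policy = nn.Linear(4, 2)
obs = torch.randn(8, 4)
print(forward_policy(policy, obs).requires_grad)                 # False
print(forward_policy(policy, obs, no_grad=False).requires_grad)  # True
```

Passing `no_grad=False` is the escape hatch for algorithms that need gradients to flow through the forward pass performed during collection; the default keeps the faster graph-free behavior.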
tianshou/trainer/offpolicy.py (11 changes: 6 additions & 5 deletions)
@@ -76,13 +76,13 @@ def offpolicy_trainer(
     start_time = time.time()
     test_in_train = test_in_train and train_collector.policy == policy
     for epoch in range(1, 1 + max_epoch):
-        # train
-        policy.train()
-        if train_fn:
-            train_fn(epoch)
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
+                # collect
+                if train_fn:
+                    train_fn(epoch)
+                policy.eval()
                 result = train_collector.collect(n_step=collect_per_step)
                 data = {}
                 if test_in_train and stop_fn and stop_fn(result['rew']):
@@ -99,9 +99,10 @@
                             start_time, train_collector, test_collector,
                             test_result['rew'])
                     else:
-                        policy.train()
                         if train_fn:
                             train_fn(epoch)
+                # train
+                policy.train()
                 for i in range(update_per_step * min(
                         result['n/st'] // collect_per_step, t.total - t.n)):
                     global_step += collect_per_step
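The scheduling change above is easier to see with the tqdm, test-in-train, and logging machinery stripped away. A hypothetical skeleton of the order the trainer now follows (`collect_fn` and `update_fn` are placeholders for illustration, not Tianshou APIs):

```python
import torch.nn as nn


def train_loop(policy: nn.Module, collect_fn, update_fn, max_epoch: int,
               steps_per_epoch: int, train_fn=None) -> None:
    for epoch in range(1, max_epoch + 1):
        for _ in range(steps_per_epoch):
            # collect: apply per-epoch hooks (e.g. exploration settings),
            # then switch to eval mode so environment interaction is not
            # affected by train-mode layers such as Dropout or BatchNorm
            if train_fn:
                train_fn(epoch)
            policy.eval()
            batch = collect_fn()
            # train: switch back to train mode before the gradient updates
            policy.train()
            update_fn(batch)


policy = nn.Sequential(nn.Linear(4, 16), nn.Dropout(0.2), nn.Linear(16, 2))
train_loop(policy, collect_fn=lambda: None, update_fn=lambda batch: None,
           max_epoch=2, steps_per_epoch=3)
```

Previously `policy.train()` and `train_fn` ran only once at the top of each epoch, so collection could happen with a train-mode network; moving the calls inside the loop keeps the collect and update phases cleanly separated on every iteration.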
tianshou/trainer/onpolicy.py (11 changes: 6 additions & 5 deletions)
@@ -76,13 +76,13 @@ def onpolicy_trainer(
     start_time = time.time()
     test_in_train = test_in_train and train_collector.policy == policy
     for epoch in range(1, 1 + max_epoch):
-        # train
-        policy.train()
-        if train_fn:
-            train_fn(epoch)
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
+                # collect
+                if train_fn:
+                    train_fn(epoch)
+                policy.eval()
                 result = train_collector.collect(n_episode=collect_per_step)
                 data = {}
                 if test_in_train and stop_fn and stop_fn(result['rew']):
@@ -99,9 +99,10 @@
                             start_time, train_collector, test_collector,
                             test_result['rew'])
                     else:
-                        policy.train()
                         if train_fn:
                             train_fn(epoch)
+                # train
+                policy.train()
                 losses = policy.update(
                     0, train_collector.buffer, batch_size, repeat_per_collect)
                 train_collector.reset_buffer()
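The on-policy trainer gets the identical treatment. One common reason the eval/train toggle matters is that layers such as Dropout are stochastic in train mode, so collecting data with a train-mode network would add unintended randomness to the actions. A quick self-contained demonstration of that difference (illustrative only, not tied to Tianshou):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 32), nn.Dropout(p=0.5), nn.Linear(32, 2))
x = torch.ones(1, 4)

net.train()                          # mode used for gradient updates
print(torch.equal(net(x), net(x)))   # almost always False: dropout masks differ

net.eval()                           # mode the collector should now see
print(torch.equal(net(x), net(x)))   # True: dropout disabled, output deterministic
```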