这是indexloc提供的服务,不要输入任何密码
Skip to content

type check in unit test #200

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/lint_and_docs.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: PEP8 and Docs Check
name: PEP8, Types and Docs Check

on: [push, pull_request]

Expand All @@ -20,6 +20,9 @@ jobs:
- name: Lint with flake8
run: |
flake8 . --count --show-source --statistics
- name: Type check
run: |
mypy
- name: Documentation test
run: |
pydocstyle tianshou
Expand Down
10 changes: 10 additions & 0 deletions docs/contributing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@ We follow PEP8 python code style. To check, in the main directory, run:
$ flake8 . --count --show-source --statistics


Type Check
----------

We use `mypy <https://github.com/python/mypy/>`_ to check the type annotations. To check, in the main directory, run:

.. code-block:: bash

$ mypy


Test Locally
------------

Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Here is Tianshou's other features:
* Support :ref:`customize_training`
* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
* Support :doc:`/tutorials/tictactoe`
* Comprehensive `unit tests <https://github.com/thu-ml/tianshou/actions>`_, including functional checking, RL pipeline checking, documentation checking, PEP8 code-style checking, and type checking

中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ <https://tianshou.readthedocs.io/zh/latest/>`_

Expand Down
23 changes: 23 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,3 +1,26 @@
# Global mypy settings. When mypy is run with no paths on the command line,
# it type-checks the files matched by `files` below.
[mypy]
files = tianshou/**/*.py
allow_redefinition = True
check_untyped_defs = True
disallow_incomplete_defs = True
disallow_untyped_defs = True
ignore_missing_imports = True
no_implicit_optional = True
pretty = True
show_error_codes = True
show_error_context = True
show_traceback = True
strict_equality = True
strict_optional = True
warn_no_return = True
warn_redundant_casts = True
warn_unreachable = True
warn_unused_configs = True
warn_unused_ignores = True

# Per-module override: non-fatal errors reported for modules matching
# tianshou.utils.net.* are suppressed.
[mypy-tianshou.utils.net.*]
ignore_errors = True

[pydocstyle]
ignore = D100,D102,D104,D105,D107,D203,D213,D401,D402

Expand Down
15 changes: 0 additions & 15 deletions test/throughput/test_collector_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,6 @@ def test_collect_ep(data):
data["collector"].collect(n_episode=10)


# Throughput case: call sample(256) on data["collector"] 5000 times in a row.
# The returned value is discarded; nothing is asserted.
def test_sample(data):
for _ in range(5000):
data["collector"].sample(256)


# Throughput case: construct a Collector 5000 times from data["policy"],
# data["env_vec"] and data["buffer"]. Each instance is discarded immediately.
def test_init_vec_env(data):
for _ in range(5000):
Collector(data["policy"], data["env_vec"], data["buffer"])
Expand All @@ -125,11 +120,6 @@ def test_collect_vec_env_ep(data):
data["collector_vec"].collect(n_episode=10)


# Throughput case: call sample(256) on data["collector_vec"] 5000 times in a
# row. The returned value is discarded; nothing is asserted.
def test_sample_vec_env(data):
for _ in range(5000):
data["collector_vec"].sample(256)


# Throughput case: construct a Collector 5000 times from data["policy"],
# data["env_subproc_init"] and data["buffer"]. Each instance is discarded.
def test_init_subproc_env(data):
for _ in range(5000):
Collector(data["policy"], data["env_subproc_init"], data["buffer"])
Expand All @@ -150,10 +140,5 @@ def test_collect_subproc_env_ep(data):
data["collector_subproc"].collect(n_episode=10)


# Throughput case: call sample(256) on data["collector_subproc"] 5000 times in
# a row. The returned value is discarded; nothing is asserted.
def test_sample_subproc_env(data):
for _ in range(5000):
data["collector_subproc"].sample(256)


if __name__ == '__main__':
pytest.main(["-s", "-k collector_profile", "--durations=0", "-v"])
2 changes: 1 addition & 1 deletion tianshou/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from tianshou import data, env, utils, policy, trainer, exploration


__version__ = "0.2.7"
__version__ = "0.3.0rc0"

__all__ = [
"env",
Expand Down
51 changes: 17 additions & 34 deletions tianshou/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from numbers import Number
from collections.abc import Collection
from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \
Sequence, KeysView, ValuesView, ItemsView
Sequence

# Disable pickle warning related to torch, since it has been removed
# on torch master branch. See Pull Request #39003 for details:
Expand Down Expand Up @@ -144,7 +144,7 @@ def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]:
if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \
len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v):
try:
return torch.stack(v)
return torch.stack(v) # type: ignore
except RuntimeError as e:
raise TypeError("Batch does not support non-stackable iterable"
" of torch.Tensor as unique value yet.") from e
Expand Down Expand Up @@ -191,12 +191,20 @@ def __init__(
elif _is_batch_set(batch_dict):
self.stack_(batch_dict)
if len(kwargs) > 0:
self.__init__(kwargs, copy=copy)
self.__init__(kwargs, copy=copy) # type: ignore

def __setattr__(self, key: str, value: Any) -> None:
"""Set self.key = value."""
# The value is first passed through _parse_value, and the result is written
# straight into the instance dict under ``key``.
self.__dict__[key] = _parse_value(value)

# Fallback attribute hook: Python calls __getattr__ only after normal
# attribute lookup has failed, so names already stored in self.__dict__ are
# resolved before this method is ever reached.
def __getattr__(self, key: str) -> Any:
"""Return self.key. The "Any" return type is needed for mypy."""
# The lookup is forwarded to the dict object itself, not to its entries:
# ``key`` is resolved as an attribute of ``self.__dict__`` (for example one of
# the dict's own methods). A name the dict does not provide makes getattr
# raise AttributeError.
return getattr(self.__dict__, key)

def __contains__(self, key: str) -> bool:
"""Return key in self."""
# Plain dict membership test on the instance dict; values are not searched.
return key in self.__dict__

def __getstate__(self) -> Dict[str, Any]:
"""Pickling interface.

Expand All @@ -215,11 +223,11 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
At this point, self is an empty Batch instance that has not been
initialized, so it can safely be initialized by the pickle state.
"""
self.__init__(**state)
self.__init__(**state) # type: ignore

def __getitem__(
self, index: Union[str, slice, int, np.integer, np.ndarray, List[int]]
) -> Union["Batch", np.ndarray, torch.Tensor]:
) -> Any:
"""Return self[index]."""
if isinstance(index, str):
return self.__dict__[index]
Expand All @@ -245,7 +253,7 @@ def __setitem__(
if isinstance(index, str):
self.__dict__[index] = value
return
if isinstance(value, (np.ndarray, torch.Tensor)):
if not isinstance(value, Batch):
raise ValueError("Batch does not supported tensor assignment. "
"Use a compatible Batch or dict instead.")
if not set(value.keys()).issubset(self.__dict__.keys()):
Expand Down Expand Up @@ -330,30 +338,6 @@ def __repr__(self) -> str:
s = self.__class__.__name__ + "()"
return s

def __contains__(self, key: str) -> bool:
"""Return key in self."""
return key in self.__dict__

def keys(self) -> KeysView[str]:
"""Return self.keys()."""
return self.__dict__.keys()

def values(self) -> ValuesView[Any]:
"""Return self.values()."""
return self.__dict__.values()

def items(self) -> ItemsView[str, Any]:
"""Return self.items()."""
return self.__dict__.items()

def get(self, k: str, d: Optional[Any] = None) -> Any:
"""Return self[k] if k in self else d. d defaults to None."""
return self.__dict__.get(k, d)

def pop(self, k: str, d: Optional[Any] = None) -> Any:
"""Return & remove self[k] if k in self else d. d defaults to None."""
return self.__dict__.pop(k, d)

def to_numpy(self) -> None:
"""Change all torch.Tensor to numpy.ndarray in-place."""
for k, v in self.items():
Expand All @@ -375,7 +359,6 @@ def to_torch(
if isinstance(v, torch.Tensor):
if dtype is not None and v.dtype != dtype or \
v.device.type != device.type or \
device.index is not None and \
device.index != v.device.index:
if dtype is not None:
v = v.type(dtype)
Expand Down Expand Up @@ -517,7 +500,7 @@ def stack_(
return
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
if not self.is_empty():
batches = [self] + list(batches)
batches = [self] + batches
# collect non-empty keys
keys_map = [
set(k for k, v in batch.items()
Expand Down Expand Up @@ -672,8 +655,8 @@ def __len__(self) -> int:
for v in self.__dict__.values():
if isinstance(v, Batch) and v.is_empty(recurse=True):
continue
elif hasattr(v, "__len__") and (not isinstance(
v, (np.ndarray, torch.Tensor)) or v.ndim > 0
elif hasattr(v, "__len__") and (
isinstance(v, Batch) or v.ndim > 0
):
r.append(len(v))
else:
Expand Down
20 changes: 10 additions & 10 deletions tianshou/data/buffer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import numpy as np
from numbers import Number
from typing import Any, Dict, Tuple, Union, Optional
from typing import Any, Dict, List, Tuple, Union, Optional

from tianshou.data import Batch, SegmentTree, to_numpy
from tianshou.data.batch import _create_value
Expand Down Expand Up @@ -138,7 +138,7 @@ def __init__(
self._indices = np.arange(size)
self.stack_num = stack_num
self._avail = sample_avail and stack_num > 1
self._avail_index = []
self._avail_index: List[int] = []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we really using variable type annotations in Tianshou?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the purpose of this PR. Removing this annotation will cause a mypy type-check error:

tianshou/data/buffer.py: note: In member "__init__" of class "ReplayBuffer":
tianshou/data/buffer.py:140: error: Need type annotation for '_avail_index' (hint: "_avail_index: List[<type>] = ...")  [var-annotated]
            self._avail_index = []
            ^

self._save_s_ = not ignore_obs_next
self._last_obs = save_only_last_obs
self._index = 0
Expand Down Expand Up @@ -175,12 +175,12 @@ def _add_to_buffer(self, name: str, inst: Any) -> None:
except KeyError:
self._meta.__dict__[name] = _create_value(inst, self._maxsize)
value = self._meta.__dict__[name]
if isinstance(inst, (torch.Tensor, np.ndarray)) \
and inst.shape != value.shape[1:]:
raise ValueError(
"Cannot add data to a buffer with different shape, with key "
f"{name}, expect {value.shape[1:]}, given {inst.shape}."
)
if isinstance(inst, (torch.Tensor, np.ndarray)):
if inst.shape != value.shape[1:]:
raise ValueError(
"Cannot add data to a buffer with different shape with key"
f" {name}, expect {value.shape[1:]}, given {inst.shape}."
)
try:
value[self._index] = inst
except KeyError:
Expand All @@ -205,7 +205,7 @@ def update(self, buffer: "ReplayBuffer") -> None:
stack_num_orig = buffer.stack_num
buffer.stack_num = 1
while True:
self.add(**buffer[i])
self.add(**buffer[i]) # type: ignore
i = (i + 1) % len(buffer)
if i == begin:
break
Expand Down Expand Up @@ -323,7 +323,7 @@ def get(
try:
if stack_num == 1:
return val[indice]
stack = []
stack: List[Any] = []
for _ in range(stack_num):
stack = [val[indice]] + stack
pre_indice = np.asarray(indice - 1)
Expand Down
46 changes: 10 additions & 36 deletions tianshou/data/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,8 @@ def collect(
finished_env_ids = []
reward_total = 0.0
whole_data = Batch()
list_n_episode = False
if n_episode is not None and not np.isscalar(n_episode):
if isinstance(n_episode, list):
assert len(n_episode) == self.get_env_num()
list_n_episode = True
finished_env_ids = [
i for i in self._ready_env_ids if n_episode[i] <= 0]
self._ready_env_ids = np.array(
Expand Down Expand Up @@ -266,7 +264,8 @@ def collect(
self.data.policy._state = self.data.state

self.data.act = to_numpy(result.act)
if self._action_noise is not None: # noqa
if self._action_noise is not None:
assert isinstance(self.data.act, np.ndarray)
self.data.act += self._action_noise(self.data.act.shape)

# step in env
Expand All @@ -291,7 +290,7 @@ def collect(

# add data into the buffer
if self.preprocess_fn:
result = self.preprocess_fn(**self.data)
result = self.preprocess_fn(**self.data) # type: ignore
self.data.update(result)

for j, i in enumerate(self._ready_env_ids):
Expand All @@ -305,14 +304,14 @@ def collect(
self._cached_buf[i].add(**self.data[j])

if done[j]:
if not (list_n_episode and
episode_count[i] >= n_episode[i]):
if not (isinstance(n_episode, list)
and episode_count[i] >= n_episode[i]):
episode_count[i] += 1
reward_total += np.sum(self._cached_buf[i].rew, axis=0)
step_count += len(self._cached_buf[i])
if self.buffer is not None:
self.buffer.update(self._cached_buf[i])
if list_n_episode and \
if isinstance(n_episode, list) and \
episode_count[i] >= n_episode[i]:
# env i has collected enough data, it has finished
finished_env_ids.append(i)
Expand All @@ -324,10 +323,9 @@ def collect(
env_ind_global = self._ready_env_ids[env_ind_local]
obs_reset = self.env.reset(env_ind_global)
if self.preprocess_fn:
obs_next[env_ind_local] = self.preprocess_fn(
obs_reset = self.preprocess_fn(
obs=obs_reset).get("obs", obs_reset)
else:
obs_next[env_ind_local] = obs_reset
obs_next[env_ind_local] = obs_reset
self.data.obs = obs_next
if is_async:
# set data back
Expand Down Expand Up @@ -362,7 +360,7 @@ def collect(
# average reward across the number of episodes
reward_avg = reward_total / episode_count
if np.asanyarray(reward_avg).size > 1: # non-scalar reward_avg
reward_avg = self._rew_metric(reward_avg)
reward_avg = self._rew_metric(reward_avg) # type: ignore
return {
"n/ep": episode_count,
"n/st": step_count,
Expand All @@ -372,30 +370,6 @@ def collect(
"len": step_count / episode_count,
}

def sample(self, batch_size: int) -> Batch:
"""Sample a data batch from the internal replay buffer.

It will call :meth:`~tianshou.policy.BasePolicy.process_fn` before
returning the final batch data.

:param int batch_size: ``0`` means it will extract all the data from
the buffer, otherwise it will extract the data with the given
batch_size.
"""
# Deprecated entry point: the warning below is issued on every call, before
# any sampling is attempted.
warnings.warn(
"Collector.sample is deprecated and will cause error if you use "
"prioritized experience replay! Collector.sample will be removed "
"upon version 0.3. Use policy.update instead!", Warning)
# This guards against a missing buffer (None). It does not check whether the
# buffer holds any data, despite the wording of the message.
assert self.buffer is not None, "Cannot get sample from empty buffer!"
# buffer.sample returns the batch together with the indices it drew; both are
# passed on to process_fn along with the buffer itself.
batch_data, indice = self.buffer.sample(batch_size)
batch_data = self.process_fn(batch_data, self.buffer, indice)
return batch_data

# Deprecated no-op: the body only emits the warning below and releases nothing.
def close(self) -> None:
warnings.warn(
"Collector.close is deprecated and will be removed upon version "
"0.3.", Warning)


def _batch_set_item(
source: Batch, indices: np.ndarray, target: Batch, size: int
Expand Down
Loading