add cross-platform test and release 0.4.1 #331

Merged: 8 commits merged on Mar 31, 2021

4 changes: 2 additions & 2 deletions .github/ISSUE_TEMPLATE.md
@@ -7,6 +7,6 @@
 - [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
 - [ ] I have mentioned version numbers, operating system and environment, where applicable:
 ```python
-import tianshou, torch, sys
-print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
+import tianshou, torch, numpy, sys
+print(tianshou.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
 ```
4 changes: 2 additions & 2 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -11,6 +11,6 @@ Less important but also useful:
 - [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
 - [ ] I have mentioned version numbers, operating system and environment, where applicable:
 ```python
-import tianshou, torch, sys
-print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
+import tianshou, torch, numpy, sys
+print(tianshou.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
 ```
27 changes: 27 additions & 0 deletions .github/workflows/extra_sys.yml
@@ -0,0 +1,27 @@
+name: Unittest
+
+on: [push, pull_request]
+
+jobs:
+  build:
+    runs-on: ${{ matrix.os }}
+    if: "!contains(github.event.head_commit.message, 'ci skip')"
+    strategy:
+      matrix:
+        os: [macos-latest, windows-latest]
+        python-version: [3.6, 3.7, 3.8]
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Upgrade pip
+      run: |
+        python -m pip install --upgrade pip setuptools wheel
+    - name: Install dependencies
+      run: |
+        python -m pip install ".[dev]" --upgrade
+    - name: Test with pytest
+      run: |
+        pytest test/base test/continuous --ignore-glob "*env.py" --cov=tianshou --durations=0 -v
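The new `extra_sys.yml` workflow runs the `test/base` and `test/continuous` suites on macOS and Windows. GitHub Actions expands `strategy.matrix` into the Cartesian product of its axes, so this file schedules 2 × 3 = 6 jobs, each guarded by the `if:` condition on the head commit message. The sketch below models both behaviours; `expand_matrix` and `should_run` are hypothetical helpers, and `should_run` assumes that `contains()` compares case-insensitively (as the GitHub expression docs describe) and that a `pull_request` payload carries no `head_commit`, in which case the guard does not skip the job.

```python
from itertools import product
from typing import Dict, List, Optional


def expand_matrix(matrix: Dict[str, List[str]]) -> List[Dict[str, str]]:
    """Return one job configuration per combination of the matrix axes."""
    keys = list(matrix)
    return [dict(zip(keys, values))
            for values in product(*(matrix[key] for key in keys))]


def should_run(head_commit_message: Optional[str]) -> bool:
    """Approximate `!contains(github.event.head_commit.message, 'ci skip')`.

    Assumptions: contains() is case-insensitive, and a missing head_commit
    (e.g. on pull_request events) behaves like an empty message.
    """
    message = (head_commit_message or "").lower()
    return "ci skip" not in message


# Axes copied from extra_sys.yml; Python versions are kept as strings here.
matrix = {"os": ["macos-latest", "windows-latest"],
          "python-version": ["3.6", "3.7", "3.8"]}
jobs = expand_matrix(matrix)
print(len(jobs))  # 6
print(jobs[0])    # {'os': 'macos-latest', 'python-version': '3.6'}
```

Under these assumptions, a push whose commit message contains `[ci skip]` in any letter case skips all six jobs, while pull-request runs always proceed.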
6 changes: 0 additions & 6 deletions .github/workflows/pytest.yml
@@ -15,12 +15,6 @@ jobs:
       uses: actions/setup-python@v2
       with:
         python-version: ${{ matrix.python-version }}
-    # - uses: actions/cache@v2
-    #   with:
-    #     path: /opt/hostedtoolcache/Python/
-    #     key: ${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py') }}
-    #     restore-keys: |
-    #       ${{ runner.os }}-${{ matrix.python-version }}-
     - name: Upgrade pip
       run: |
         python -m pip install --upgrade pip setuptools wheel
4 changes: 2 additions & 2 deletions test/base/env.py
@@ -73,10 +73,10 @@ def _get_state(self):
         elif self.recurse_state:
             return {'index': np.array([self.index], dtype=np.float32),
                     'dict': {"tuple": (np.array([1],
-                                                dtype=np.int64), self.rng.rand(2)),
+                                                dtype=int), self.rng.rand(2)),
                              "rand": self.rng.rand(1, 2)}}
         elif self.array_state:
-            img = np.zeros([4, 84, 84], np.int)
+            img = np.zeros([4, 84, 84], int)
             img[3, np.arange(84), np.arange(84)] = self.index
             img[2, np.arange(84)] = self.index
             img[1, :, np.arange(84)] = self.index
11 changes: 9 additions & 2 deletions test/base/test_batch.py
@@ -1,3 +1,4 @@
+import sys
 import copy
 import torch
 import pickle
@@ -373,7 +374,10 @@ def test_batch_over_batch_to_torch():
     assert batch.a.dtype == torch.float64
     assert batch.b.c.dtype == torch.float32
     assert batch.b.d.dtype == torch.float64
-    assert batch.b.e.dtype == torch.int64
+    if sys.platform in ["win32", "cygwin"]:  # windows
+        assert batch.b.e.dtype == torch.int32
+    else:
+        assert batch.b.e.dtype == torch.int64
     batch.to_torch(dtype=torch.float32)
     assert batch.a.dtype == torch.float32
     assert batch.b.c.dtype == torch.float32
@@ -439,7 +443,10 @@ def test_utils_to_torch_numpy():
     assert to_numpy(to_numpy).item() == to_numpy
     # additional test for to_torch, for code-coverage
     assert isinstance(to_torch(1), torch.Tensor)
-    assert to_torch(1).dtype == torch.int64
+    if sys.platform in ["win32", "cygwin"]:  # windows
+        assert to_torch(1).dtype == torch.int32
+    else:
+        assert to_torch(1).dtype == torch.int64
     assert to_torch(1.).dtype == torch.float64
     assert isinstance(to_torch({'a': [1]})['a'], torch.Tensor)
     with pytest.raises(TypeError):
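The `sys.platform` branches in `test_batch.py` reflect a NumPy 1.x behaviour: a Python `int` becomes NumPy's default integer `np.int_`, which is a C `long`, and a C `long` is 32 bits on Windows but 64 bits on 64-bit Linux and macOS. Because `to_torch` presumably converts through NumPy, the resulting tensor is `torch.int32` on Windows and `torch.int64` elsewhere. The stdlib-only sketch below reports which width applies on the current machine; `numpy1_default_int_name` is a hypothetical helper, and the mapping assumes NumPy < 2.0 (NumPy 2.0 made the default integer 64-bit on 64-bit Windows as well).

```python
import ctypes


def numpy1_default_int_name() -> str:
    """Name of the dtype NumPy 1.x picks for a Python int on this platform.

    Assumes NumPy < 2.0, where the default integer (np.int_) is a C `long`:
    32 bits on Windows (LLP64), 64 bits on 64-bit Linux/macOS (LP64).
    """
    return "int{}".format(ctypes.sizeof(ctypes.c_long) * 8)


# Prints "int32" on Windows and "int64" on 64-bit Linux or macOS.
print(numpy1_default_int_name())
```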
6 changes: 3 additions & 3 deletions test/base/test_env.py
@@ -13,16 +13,16 @@

 def has_ray():
     try:
-        import ray
-        return hasattr(ray, 'init')  # avoid PEP8 F401 Error
+        import ray  # noqa: F401
+        return True
     except ImportError:
         return False


 def recurse_comp(a, b):
     try:
         if isinstance(a, np.ndarray):
-            if a.dtype == np.object:
+            if a.dtype == object:
                 return np.array(
                     [recurse_comp(m, n) for m, n in zip(a, b)]).all()
             else:
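The `has_ray()` change replaces the `hasattr(ray, 'init')` workaround with a plain `import ray  # noqa: F401`, where the `noqa` comment tells flake8 to ignore its "imported but unused" warning. A different approach, not the one this PR uses, is to probe for the package with `importlib.util.find_spec`, which never imports it; `has_module` below is a hypothetical helper name.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if a top-level module can be found, without importing it.

    Unlike `import ray`, this never executes the module, so a package that is
    installed but raises at import time still reports True.
    """
    return importlib.util.find_spec(name) is not None


print(has_module("json"))                # True (standard library)
print(has_module("no_such_module_xyz"))  # False unless such a package exists
```

The trade-off is that `find_spec` only tells you the package is present, whereas the PR's `import ray` also confirms that it imports cleanly.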
2 changes: 1 addition & 1 deletion test/base/test_returns.py
@@ -122,7 +122,7 @@ def target_q_fn_multidim(buffer, indice):


 def compute_nstep_return_base(nstep, gamma, buffer, indice):
-    returns = np.zeros_like(indice, dtype=np.float)
+    returns = np.zeros_like(indice, dtype=float)
     buf_len = len(buffer)
     for i in range(len(indice)):
         flag, r = False, 0.
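The `env.py`, `test_env.py`, and `test_returns.py` hunks replace the NumPy aliases `np.int`, `np.float`, and `np.object` with the Python builtins they stood for. NumPy 1.20 deprecated these aliases, and NumPy 1.24 later removed them. As a rough aid for finding remaining uses elsewhere, here is a regex-based sketch. `find_deprecated_aliases` and `replace_deprecated_aliases` are hypothetical helpers that belong to neither Tianshou nor NumPy. Note that a regex also matches inside strings and comments.

```python
import re
from typing import List, Tuple

# The three aliases replaced in this PR. Each was a plain alias of the Python
# builtin of the same name, so `np.float` -> `float` keeps behaviour unchanged.
# The trailing \b keeps np.int64, np.float32, np.object_, etc. from matching.
_ALIAS_RE = re.compile(r"\bnp\.(int|float|object)\b")


def find_deprecated_aliases(source: str) -> List[Tuple[int, str]]:
    """Return (line number, matched text) for each alias use in `source`."""
    hits = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        for match in _ALIAS_RE.finditer(line):
            hits.append((lineno, match.group(0)))
    return hits


def replace_deprecated_aliases(source: str) -> str:
    """Rewrite `np.int` -> `int`, `np.float` -> `float`, `np.object` -> `object`."""
    return _ALIAS_RE.sub(lambda m: m.group(1), source)


print(replace_deprecated_aliases("returns = np.zeros_like(indice, dtype=np.float)"))
# returns = np.zeros_like(indice, dtype=float)
```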
2 changes: 1 addition & 1 deletion tianshou/__init__.py
@@ -1,7 +1,7 @@
 from tianshou import data, env, utils, policy, trainer, exploration


-__version__ = "0.4.0"
+__version__ = "0.4.1"

 __all__ = [
     "env",