Add profile workflow #143
Merged
48 commits:
- e976d74 Improve Batch (#126) [youkaichao]
- d1a2037 Improve Batch (#128) [youkaichao]
- a55ad33 Fix padding of inconsistent keys with Batch.stack and Batch.cat (#130) [youkaichao]
- 885fbc1 Improve collector (#125) [youkaichao]
- cee8088 Vector env enable select worker (#132) [duburcqa]
- f8ad6df Standardized behavior of Batch.cat and misc code refactor (#137) [youkaichao]
- 1b6210b add profile code for batch&comment fix [ChenDRAG]
- 2f182a6 add profile code for batch&comment fix [ChenDRAG]
- ce0a189 enhance batch profile [ChenDRAG]
- c486a07 change workflow [ChenDRAG]
- 4a5669b update with #125 [ChenDRAG]
- c78a030 python args bug fix [ChenDRAG]
- 9117426 Merge branch 'dev' into addtest [ChenDRAG]
- f6e6e8a *split all batch profile test. *delete print() [ChenDRAG]
- c0d5137 Merge branch 'addtest' of github.com:ChenDRAG/tianshou into addtest [ChenDRAG]
- 6417241 pep8 fix [ChenDRAG]
- f783199 WIP, adding buffer_profile
- fa542f8 write tutorials to specify the standard of Batch (#142) [youkaichao]
- 3d04559 Merge branch 'dev' into addtest [Trinkle23897]
- 7f9a914 Merge branch 'dev' into addtest [Trinkle23897]
- c835261 buffer update bug fix [ChenDRAG]
- 0403780 some fix in buffer update [ChenDRAG]
- 7e9c952 Merge branch 'buffer_bug_fix' into addtest [ChenDRAG]
- 80f8d85 add test buffer profile [ChenDRAG]
- 6e771db Merge branch 'addtest' of github.com:ChenDRAG/tianshou into addtest [ChenDRAG]
- 16ede77 Merge branch 'dev' into addtest [Trinkle23897]
- 00a7bdb merge with dev [ChenDRAG]
- f9caa48 pep8 fix [ChenDRAG]
- 74bddd6 revert [Trinkle23897]
- c28f368 Merge branch 'dev' into addtest [Trinkle23897]
- 27b2869 Merge branch 'dev' into addtest [Trinkle23897]
- 073f330 Merge branch 'dev' into addtest [Trinkle23897]
- a23b930 Merge branch 'dev' into addtest [Trinkle23897]
- e173f2e add collector profile [ChenDRAG]
- c546764 Merge branch 'dev' into addtest [ChenDRAG]
- fa7dfe2 some fix [ChenDRAG]
- f9e5548 Merge branch 'addtest' of github.com:ChenDRAG/tianshou into addtest [ChenDRAG]
- 659d409 yet another fix [ChenDRAG]
- dcd1b9d add subproc, some fix [ChenDRAG]
- 98bb7ad add subprocenv test [ChenDRAG]
- bded637 add subproc_env profile, fix batch stack_ profile [ChenDRAG]
- 6334893 Merge branch 'dev' into addtest [ChenDRAG]
- 9d32e99 Merge branch 'dev' into addtest [Trinkle23897]
- dcef91a add subproc test, stack_fix [ChenDRAG]
- d9981fb Merge branch 'dev' into addtest [Trinkle23897]
- 29a510e style fix [ChenDRAG]
- d1ec533 fix [ChenDRAG]
- ee95324 Merge branch 'dev' into addtest [Trinkle23897]
@@ -0,0 +1,22 @@

```yaml
name: Data Profile

on: [push, pull_request]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python 3.8
      uses: actions/setup-python@v2
      with:
        python-version: 3.8
    - name: Upgrade pip
      run: |
        python -m pip install --upgrade pip setuptools wheel
    - name: Install dependencies
      run: |
        pip install ".[dev]" --upgrade
    - name: Test with pytest
      run: |
        pytest test/throughput --durations=0 -v
```
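The workflow step relies on pytest's `--durations=0` flag, which lists the runtime of every test instead of only the slowest N, so each profile test below effectively reports how long its loop took. When iterating on a single operation locally, the same measurement can be approximated with the standard library. This is a minimal sketch; the helper name `profile_loop` is hypothetical and is not part of tianshou or pytest.

```python
import time
from typing import Callable


def profile_loop(fn: Callable[[], object], repeat: int) -> float:
    """Call ``fn`` ``repeat`` times and return the total wall-clock seconds.

    Hypothetical helper: it mirrors the ``for _ in range(N): op()`` loops in
    the profile tests, timed with the monotonic high-resolution clock.
    """
    start = time.perf_counter()
    for _ in range(repeat):
        fn()
    return time.perf_counter() - start


if __name__ == "__main__":
    calls = []
    elapsed = profile_loop(lambda: calls.append(1), repeat=1000)
    print(f"{len(calls)} calls in {elapsed:.6f}s")
```

Unlike `--durations`, this times only the loop body, so fixture construction is never included; pytest instead reports fixture setup as a separate "setup" entry.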
Empty file.
@@ -0,0 +1,120 @@

```python
import copy
import pickle

import numpy as np
import pytest
import torch

from tianshou.data import Batch


@pytest.fixture(scope="module")
def data():
    print("Initialising data...")
    np.random.seed(0)
    batch_set = [Batch(a=[j for j in np.arange(1e3)],
                       b={'b1': (3.14, 3.14), 'b2': np.arange(1e3)},
                       c=i) for i in np.arange(int(1e4))]
    batch0 = Batch(
        a=np.ones((3, 4), dtype=np.float64),
        b=Batch(
            c=np.ones((1,), dtype=np.float64),
            d=torch.ones((3, 3, 3), dtype=torch.float32),
            e=list(range(3))
        )
    )
    batchs1 = [copy.deepcopy(batch0) for _ in np.arange(1e4)]
    batchs2 = [copy.deepcopy(batch0) for _ in np.arange(1e4)]
    batch_len = int(1e4)
    batch3 = Batch(obs=[np.arange(20) for _ in np.arange(batch_len)],
                   reward=np.arange(batch_len))
    indexs = np.random.choice(batch_len,
                              size=batch_len//10, replace=False)
    slice_dict = {'obs': [np.arange(20)
                          for _ in np.arange(batch_len//10)],
                  'reward': np.arange(batch_len//10)}
    dict_set = [{'obs': np.arange(20), 'info': "this is info", 'reward': 0}
                for _ in np.arange(1e2)]
    batch4 = Batch(
        a=np.ones((10000, 4), dtype=np.float64),
        b=Batch(
            c=np.ones((1,), dtype=np.float64),
            d=torch.ones((1000, 1000), dtype=torch.float32),
            e=np.arange(1000)
        )
    )

    print("Initialised")
    return {'batch_set': batch_set,
            'batch0': batch0,
            'batchs1': batchs1,
            'batchs2': batchs2,
            'batch3': batch3,
            'indexs': indexs,
            'dict_set': dict_set,
            'slice_dict': slice_dict,
            'batch4': batch4
            }
```
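The module-scoped fixture builds `batchs1` and `batchs2` as 10,000 independent deep copies of `batch0`, and `test_cat`/`test_stack` apply the in-place `cat_`/`stack_` to a different copy on each iteration. This appears intended to keep every in-place call working on an input of the same size: if a single shared object were mutated repeatedly, it would grow on every call and later iterations would be timed on ever larger data. The sketch below illustrates that effect with plain Python lists standing in for `Batch`; the `cat_inplace` helper is hypothetical and only an analogy, not tianshou code.

```python
import copy
from typing import Dict, List


def cat_inplace(target: Dict[str, List[float]],
                other: Dict[str, List[float]]) -> None:
    """Hypothetical stand-in for an in-place concatenation such as cat_."""
    for key, value in other.items():
        target[key].extend(value)


base = {"a": [1.0] * 12}  # 12 elements, like a flattened (3, 4) array

# One shared target: it grows on every call.
shared = copy.deepcopy(base)
for _ in range(100):
    cat_inplace(shared, base)

# A fresh deep copy per iteration, as in the profile fixture.
copies = [copy.deepcopy(base) for _ in range(100)]
for target in copies:
    cat_inplace(target, base)

print(len(shared["a"]))      # 12 + 100 * 12 = 1212
print(len(copies[-1]["a"]))  # 12 + 12 = 24
```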
```python
def test_init(data):
    """Test Batch __init__()."""
    for _ in np.arange(10):
        _ = Batch(data['batch_set'])


def test_get_item(data):
    """Test get with item."""
    for _ in np.arange(1e5):
        _ = data['batch3'][data['indexs']]


def test_get_attr(data):
    """Test get with attr."""
    for _ in np.arange(1e6):
        data['batch3'].get('obs')
        data['batch3'].get('reward')
        _, _ = data['batch3'].obs, data['batch3'].reward


def test_set_item(data):
    """Test set with item."""
    for _ in np.arange(1e4):
        data['batch3'][data['indexs']] = data['slice_dict']


def test_set_attr(data):
    """Test set with attr."""
    for _ in np.arange(1e4):
        data['batch3'].c = np.arange(1e3)
        data['batch3'].obs = data['dict_set']


def test_numpy_torch_convert(data):
    """Test conversion between numpy and torch."""
    for _ in np.arange(1e5):
        data['batch4'].to_torch()
        data['batch4'].to_numpy()


def test_pickle(data):
    for _ in np.arange(1e4):
        pickle.loads(pickle.dumps(data['batch4']))


def test_cat(data):
    """Test cat"""
    for i in range(10000):
        Batch.cat((data['batch0'], data['batch0']))
        data['batchs1'][i].cat_(data['batch0'])


def test_stack(data):
    """Test stack"""
    for i in range(10000):
        Batch.stack((data['batch0'], data['batch0']))
        data['batchs2'][i].stack_([data['batch0']])


if __name__ == '__main__':
    pytest.main(["-s", "-k batch_profile", "--durations=0", "-v"])
```
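`test_cat` and `test_stack` time the two ways of combining batches. For array fields, `Batch.cat` joins along the existing first axis while `Batch.stack` adds a new leading axis, in the same way as `np.concatenate` and `np.stack`; for `batch0.a` of shape (3, 4), concatenating two copies gives (6, 4) and stacking gives (2, 3, 4). A stdlib-only sketch of that shape rule, using nested lists and the hypothetical helpers `cat`, `stack`, and `shape` (none of them tianshou functions):

```python
from typing import Sequence


def shape(x) -> tuple:
    """Shape of a rectangular nested list (hypothetical helper)."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


def cat(arrays: Sequence[list]) -> list:
    """Join along the existing first axis, like np.concatenate(axis=0)."""
    out: list = []
    for a in arrays:
        out.extend(a)
    return out


def stack(arrays: Sequence[list]) -> list:
    """Add a new leading axis, like np.stack(axis=0)."""
    return list(arrays)


a = [[1.0] * 4 for _ in range(3)]  # shape (3, 4), like batch0.a
print(shape(cat([a, a])))    # (6, 4)
print(shape(stack([a, a])))  # (2, 3, 4)
```

Because stacking only adds an outer container while concatenation must copy rows into one longer sequence, the two operations can have noticeably different costs, which is why profiling them separately is useful.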
@@ -0,0 +1,81 @@

```python
import numpy as np
import pytest

from tianshou.data import (ListReplayBuffer, PrioritizedReplayBuffer,
                           ReplayBuffer)


@pytest.fixture(scope="module")
def data():
    np.random.seed(0)
    obs = {'observable': np.random.rand(
        100, 100), 'hidden': np.random.randint(1000, size=200)}
    info = {'policy': "dqn", 'base': np.arange(10)}
    add_data = {'obs': obs, 'rew': 1., 'act': np.random.rand(30),
                'done': False, 'obs_next': obs, 'info': info}
    buffer = ReplayBuffer(int(1e3), stack_num=100)
    buffer2 = ReplayBuffer(int(1e4), stack_num=100)
    indexes = np.random.choice(int(1e3), size=3, replace=False)
    return{
        'add_data': add_data,
        'buffer': buffer,
        'buffer2': buffer2,
        'slice': slice(-3000, -1000, 2),
        'indexes': indexes
    }
```
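The fixture's `slice(-3000, -1000, 2)` uses negative bounds, and under Python and NumPy slicing rules the number of positions it selects depends on the length it is resolved against, because out-of-range bounds are clamped instead of raising an error. The sketch below shows this with the built-in `slice.indices`; the lengths 1,000 and 10,000 match the capacities of `buffer` and `buffer2`, but whether `ReplayBuffer.__getitem__` resolves slices exactly this way is an assumption here, not something this diff shows. The helper `selected` is hypothetical.

```python
def selected(length: int, s: slice) -> range:
    """Positions a slice selects from a sequence of the given length."""
    return range(*s.indices(length))


s = slice(-3000, -1000, 2)
print(len(selected(1000, s)))   # 0: both bounds clamp to position 0
print(len(selected(10000, s)))  # 1000: positions 7000, 7002, ..., 8998
```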
```python
def test_init():
    for _ in np.arange(1e5):
        _ = ReplayBuffer(1e5)
        _ = PrioritizedReplayBuffer(
            size=int(1e5), alpha=0.5,
            beta=0.5, repeat_sample=True)
        _ = ListReplayBuffer()


def test_add(data):
    buffer = data['buffer']
    for _ in np.arange(1e5):
        buffer.add(**data['add_data'])


def test_update(data):
    buffer = data['buffer']
    buffer2 = data['buffer2']
    for _ in np.arange(1e2):
        buffer2.update(buffer)


def test_getitem_slice(data):
    Slice = data['slice']
    buffer = data['buffer']
    for _ in np.arange(1e3):
        _ = buffer[Slice]


def test_getitem_indexes(data):
    indexes = data['indexes']
    buffer = data['buffer']
    for _ in np.arange(1e2):
        _ = buffer[indexes]


def test_get(data):
    indexes = data['indexes']
    buffer = data['buffer']
    for _ in np.arange(3e2):
        buffer.get(indexes, 'obs')
        buffer.get(indexes, 'rew')
        buffer.get(indexes, 'done')
        buffer.get(indexes, 'info')


def test_sample(data):
    buffer = data['buffer']
    for _ in np.arange(1e1):
        buffer.sample(int(1e2))


if __name__ == '__main__':
    pytest.main(["-s", "-k buffer_profile", "--durations=0", "-v"])
```
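`test_add` inserts 100,000 transitions into `buffer`, whose capacity is 1,000, so most of the loop measures adds into an already full buffer. `ReplayBuffer` behaves as a fixed-size circular buffer that overwrites its oldest entry once full. A minimal sketch of that overwrite logic follows; the `RingBuffer` class is hypothetical and assumes nothing about tianshou's actual storage layout (which keeps per-key arrays rather than a Python list).

```python
from typing import Any, List, Optional


class RingBuffer:
    """Fixed-capacity buffer that overwrites the oldest entry when full."""

    def __init__(self, size: int) -> None:
        self._maxsize = size
        self._data: List[Optional[Any]] = [None] * size
        self._index = 0  # slot the next item is written to
        self._size = 0   # number of valid items stored

    def add(self, item: Any) -> None:
        self._data[self._index] = item
        self._index = (self._index + 1) % self._maxsize  # wrap around
        self._size = min(self._size + 1, self._maxsize)

    def __len__(self) -> int:
        return self._size

    def ordered(self) -> List[Any]:
        """Return stored items from oldest to newest."""
        if self._size < self._maxsize:
            return self._data[:self._size]
        return self._data[self._index:] + self._data[:self._index]


buf = RingBuffer(3)
for i in range(5):
    buf.add(i)
print(len(buf), buf.ordered())  # 3 [2, 3, 4]
```

Each `add` is O(1) regardless of how many items were inserted before, so the per-call cost measured by `test_add` should be dominated by copying the transition data rather than by buffer growth.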