Improve PER #159

Merged on Aug 6, 2020 (43 commits)

Commits:
9e636ed first segtree without test (Trinkle23897, Jul 23, 2020)
9208895 test some code (Trinkle23897, Jul 24, 2020)
addc7a9 test prefix-sum-idx (Trinkle23897, Jul 24, 2020)
6f7e0d1 Merge branch 'dev' into prio-buffer (Trinkle23897, Jul 27, 2020)
f580910 Merge branch 'dev' into prio-buffer (Trinkle23897, Jul 29, 2020)
c0c1290 finish test segtree (Trinkle23897, Aug 1, 2020)
322a520 fix test (Trinkle23897, Aug 1, 2020)
e15803f change prio-buffer (Trinkle23897, Aug 1, 2020)
a3e037a align PER (Trinkle23897, Aug 1, 2020)
0f0c58b fix test (Trinkle23897, Aug 1, 2020)
bbf60cd fix a bug (Trinkle23897, Aug 1, 2020)
31c25c3 remove mintree (Trinkle23897, Aug 1, 2020)
c3443ef DQN and DDPG (Trinkle23897, Aug 2, 2020)
4f72573 TD3 and SAC (Trinkle23897, Aug 2, 2020)
15c53be docs (Trinkle23897, Aug 2, 2020)
0b8a316 revert gitignore (Trinkle23897, Aug 2, 2020)
8301c74 Merge branch 'dev' into prio-buffer (Trinkle23897, Aug 2, 2020)
c3c5a64 merge pdqn test to dqn (Trinkle23897, Aug 2, 2020)
81ba208 Merge branch 'prio-buffer' of github.com:Trinkle23897/tianshou into p… (Trinkle23897, Aug 2, 2020)
d9cd78c rm <<>> (Trinkle23897, Aug 2, 2020)
036e54c fix test (Trinkle23897, Aug 2, 2020)
27d51ce Merge branch 'dev' into prio-buffer (Trinkle23897, Aug 2, 2020)
3c0cb2e change op (Trinkle23897, Aug 2, 2020)
2bd7fac Merge branch 'prio-buffer' of github.com:Trinkle23897/tianshou into p… (Trinkle23897, Aug 2, 2020)
c772009 size assert (Trinkle23897, Aug 2, 2020)
1f98d01 minor fix (Trinkle23897, Aug 2, 2020)
a7472ac minor fix (Trinkle23897, Aug 2, 2020)
b6e0651 fix corner case (Trinkle23897, Aug 2, 2020)
a6b2e2d fix (Trinkle23897, Aug 2, 2020)
6f5c4f6 fix test (Trinkle23897, Aug 2, 2020)
1226e2f fix numba part (Trinkle23897, Aug 2, 2020)
370802a add to profile test (Trinkle23897, Aug 2, 2020)
d6765d7 doc polish and remove intricate xor operators (youkaichao, Aug 2, 2020)
0ced37a code refactor for _get_prefix_sum_idx (youkaichao, Aug 2, 2020)
cebcc2d leave todo and doc fix (youkaichao, Aug 3, 2020)
30a6619 small fix for torch ones like (youkaichao, Aug 3, 2020)
61ea9f0 minor fix (Trinkle23897, Aug 4, 2020)
687ccbb minor fix (Trinkle23897, Aug 4, 2020)
316974f doc improve for sample (youkaichao, Aug 4, 2020)
87bb133 fix dqn local test (Trinkle23897, Aug 4, 2020)
1377a0d Merge branch 'dev' into prio-buffer (Trinkle23897, Aug 4, 2020)
fd2ee87 Merge branch 'dev' into prio-buffer (Trinkle23897, Aug 5, 2020)
eb307eb fix weight update in buffer.add (Trinkle23897, Aug 5, 2020)
2 changes: 1 addition & 1 deletion README.md
@@ -38,7 +38,7 @@ Here is Tianshou's other features:
 - Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training)
 - Support any type of environment state (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
 - Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process)
-- Support n-step returns estimation for all Q-learning based algorithms
+- Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms
 - Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html##multi-agent-reinforcement-learning)
 
 In Chinese, Tianshou means divinely ordained and is derived to the gift of being born with. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. So taking "Tianshou" means that there is no teacher to study with, but rather to learn by themselves through constant interaction with the environment.
2 changes: 1 addition & 1 deletion docs/index.rst
@@ -28,7 +28,7 @@ Here is Tianshou's other features:
 * Support recurrent state representation in actor network and critic network (RNN-style training for POMDP): :ref:`rnn_training`
 * Support any type of environment state (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env`
 * Support customized training process: :ref:`customize_training`
-* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` for all Q-learning based algorithms
+* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay for all Q-learning based algorithms
 * Support multi-agent RL: :doc:`/tutorials/tictactoe`
 
 中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ <https://tianshou.readthedocs.io/zh/latest/>`_
106 changes: 105 additions & 1 deletion test/base/test_buffer.py
@@ -1,6 +1,9 @@
 import pytest
 import numpy as np
+from timeit import timeit
+
-from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer
+from tianshou.data import Batch, PrioritizedReplayBuffer, \
+    ReplayBuffer, SegmentTree
 
 if __name__ == '__main__':
     from env import MyTestEnv
@@ -112,9 +115,110 @@ def test_update():
     assert (buf2[-1].obs == buf1[0].obs).all()
 
 
+def test_segtree():
+    for op, init in zip(['sum', 'max', 'min'], [0., -np.inf, np.inf]):
+        realop = getattr(np, op)
+        # small test
+        actual_len = 8
+        tree = SegmentTree(actual_len, op)  # 1-15. 8-15 are leaf nodes
+        assert np.all([tree[i] == init for i in range(actual_len)])
+        with pytest.raises(IndexError):
+            tree[actual_len]
+        naive = np.full([actual_len], init)
+        for _ in range(1000):
+            # random choose a place to perform single update
+            index = np.random.randint(actual_len)
+            value = np.random.rand()
+            naive[index] = value
+            tree[index] = value
+            for i in range(actual_len):
+                for j in range(i + 1, actual_len):
+                    ref = realop(naive[i:j])
+                    out = tree.reduce(i, j)
+                    assert np.allclose(ref, out)
+        # batch setitem
+        for _ in range(1000):
+            index = np.random.choice(actual_len, size=4)
+            value = np.random.rand(4)
+            naive[index] = value
+            tree[index] = value
+            assert np.allclose(realop(naive), tree.reduce())
+            for i in range(10):
+                left = np.random.randint(actual_len)
+                right = np.random.randint(left + 1, actual_len + 1)
+                assert np.allclose(realop(naive[left:right]),
+                                   tree.reduce(left, right))
+        # large test
+        actual_len = 16384
+        tree = SegmentTree(actual_len, op)
+        naive = np.full([actual_len], init)
+        for _ in range(1000):
+            index = np.random.choice(actual_len, size=64)
+            value = np.random.rand(64)
+            naive[index] = value
+            tree[index] = value
+            assert np.allclose(realop(naive), tree.reduce())
+            for i in range(10):
+                left = np.random.randint(actual_len)
+                right = np.random.randint(left + 1, actual_len + 1)
+                assert np.allclose(realop(naive[left:right]),
+                                   tree.reduce(left, right))
+
+    # test prefix-sum-idx
+    actual_len = 8
+    tree = SegmentTree(actual_len)
+    naive = np.random.rand(actual_len)
+    tree[np.arange(actual_len)] = naive
+    for _ in range(1000):
+        scalar = np.random.rand() * naive.sum()
+        index = tree.get_prefix_sum_idx(scalar)
+        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
+    # corner case here
+    naive = np.ones(actual_len, np.int)
+    tree[np.arange(actual_len)] = naive
+    for scalar in range(actual_len):
+        index = tree.get_prefix_sum_idx(scalar * 1.)
+        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
+    tree = SegmentTree(10)
+    tree[np.arange(3)] = np.array([0.1, 0, 0.1])
+    assert np.allclose(tree.get_prefix_sum_idx(
+        np.array([0, .1, .1 + 1e-6, .2 - 1e-6])), [0, 0, 2, 2])
+    with pytest.raises(AssertionError):
+        tree.get_prefix_sum_idx(.2)
+    # test large prefix-sum-idx
+    actual_len = 16384
+    tree = SegmentTree(actual_len)
+    naive = np.random.rand(actual_len)
+    tree[np.arange(actual_len)] = naive
+    for _ in range(1000):
+        scalar = np.random.rand() * naive.sum()
+        index = tree.get_prefix_sum_idx(scalar)
+        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
+
+    # profile
+    if __name__ == '__main__':
+        size = 100000
+        bsz = 64
+        naive = np.random.rand(size)
+        tree = SegmentTree(size)
+        tree[np.arange(size)] = naive
+
+        def sample_npbuf():
+            return np.random.choice(size, bsz, p=naive / naive.sum())
+
+        def sample_tree():
+            scalar = np.random.rand(bsz) * tree.reduce()
+            return tree.get_prefix_sum_idx(scalar)
+
+        print('npbuf', timeit(sample_npbuf, setup=sample_npbuf, number=1000))
+        print('tree', timeit(sample_tree, setup=sample_tree, number=1000))
Review comment from a collaborator on lines +198 to +214:
Would it be better to make it a separate function?

+
+
 if __name__ == '__main__':
     test_replaybuffer()
     test_ignore_obs_next()
     test_stack()
+    test_segtree()
     test_priortized_replaybuffer()
     test_priortized_replaybuffer(233333, 200000)
     test_update()
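The new `test_segtree` pins down the `SegmentTree` contract: leaves start at the operator's identity (`0.`, `-np.inf`, `np.inf` for `'sum'`, `'max'`, `'min'`), reading past the end raises `IndexError`, and `reduce(start, end)` must equal the NumPy reduction over the half-open slice `naive[start:end]`. As a rough illustration of how a tree can meet that contract, here is a minimal pure-Python sketch assuming an array-backed binary tree whose leaves sit at a power-of-two offset (the `# 1-15. 8-15 are leaf nodes` comment hints at a similar layout). `SegmentTreeSketch` and its `(n, op, identity)` constructor are hypothetical and are not tianshou's API; this sketch only handles scalar indices.

```python
import numpy as np


class SegmentTreeSketch:
    """Hypothetical array-backed segment tree (illustration only).

    Node k has children 2k and 2k + 1; leaf i lives at index i + size,
    where size is the smallest power of two >= n.  Unused padding leaves
    keep the identity element, so they never affect a reduction.
    """

    def __init__(self, n, op=np.add, identity=0.0):
        self.n, self.op, self.identity = n, op, identity
        self.size = 1
        while self.size < n:
            self.size *= 2
        self.value = np.full(2 * self.size, identity, dtype=np.float64)

    def __getitem__(self, index):
        if not 0 <= index < self.n:
            raise IndexError(index)
        return self.value[index + self.size]

    def __setitem__(self, index, v):
        if not 0 <= index < self.n:
            raise IndexError(index)
        k = index + self.size
        self.value[k] = v
        k //= 2
        while k >= 1:  # recompute every ancestor up to the root (k == 1)
            self.value[k] = self.op(self.value[2 * k], self.value[2 * k + 1])
            k //= 2

    def reduce(self, start=0, end=None):
        """Combine the leaves in the half-open range [start, end)."""
        end = self.n if end is None else end
        result = self.identity
        lo, hi = start + self.size, end + self.size
        while lo < hi:
            if lo & 1:  # lo is a right child: its parent reaches left of start
                result = self.op(result, self.value[lo])
                lo += 1
            if hi & 1:  # hi - 1 is a left child: its parent reaches past end
                hi -= 1
                result = self.op(result, self.value[hi])
            lo //= 2
            hi //= 2
        return result
```

A point update rewrites only the ancestors on one leaf-to-root path, and `reduce` takes at most two nodes per level, so both run in O(log n) rather than the O(n) cost of re-reducing the raw array.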
26 changes: 20 additions & 6 deletions test/discrete/test_dqn.py
@@ -8,9 +8,9 @@
 
 from tianshou.env import VectorEnv
 from tianshou.policy import DQNPolicy
-from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, ReplayBuffer
 from tianshou.utils.net.common import Net
+from tianshou.trainer import offpolicy_trainer
+from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer
 
 
 def get_args():
@@ -33,6 +33,9 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
+    parser.add_argument('--prioritized-replay', type=int, default=0)
+    parser.add_argument('--alpha', type=float, default=0.6)
+    parser.add_argument('--beta', type=float, default=0.4)
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -58,15 +61,20 @@ def test_dqn(args=get_args()):
     test_envs.seed(args.seed)
     # model
     net = Net(args.layer_num, args.state_shape,
-              args.action_shape, args.device,
-              dueling=(2, 2)).to(args.device)
+              args.action_shape, args.device,  # dueling=(1, 1)
+              ).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     policy = DQNPolicy(
         net, optim, args.gamma, args.n_step,
         target_update_freq=args.target_update_freq)
+    # buffer
+    if args.prioritized_replay > 0:
+        buf = PrioritizedReplayBuffer(
+            args.buffer_size, alpha=args.alpha, beta=args.beta)
+    else:
+        buf = ReplayBuffer(args.buffer_size)
     # collector
-    train_collector = Collector(
-        policy, train_envs, ReplayBuffer(args.buffer_size))
+    train_collector = Collector(policy, train_envs, buf)
     test_collector = Collector(policy, test_envs)
     # policy.set_eps(1)
     train_collector.collect(n_step=args.batch_size)
@@ -114,5 +122,11 @@ def test_fn(x):
         collector.close()
 
 
+def test_pdqn(args=get_args()):
+    args.prioritized_replay = 1
+    args.gamma = .95
+    test_dqn(args)
+
+
 if __name__ == '__main__':
     test_dqn(get_args())
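`test_pdqn` reruns `test_dqn` with `--prioritized-replay` switched on, and the new `--alpha` (default 0.6) and `--beta` (default 0.4) flags are passed straight to `PrioritizedReplayBuffer`. Assuming the buffer follows the proportional scheme of the original PER paper (Schaul et al., 2016), α controls how strongly priorities skew sampling and β controls the importance-sampling correction applied to the loss. The sketch below computes both quantities with plain NumPy; `per_probs_and_weights` is a hypothetical helper, and whether tianshou normalises its weights exactly this way is not shown in this diff.

```python
import numpy as np


def per_probs_and_weights(priorities, alpha=0.6, beta=0.4):
    """Hypothetical helper: proportional PER sampling probabilities and
    importance-sampling weights, following Schaul et al. (2016).

    P(i) = p_i**alpha / sum_k p_k**alpha
    w_i  = (N * P(i))**(-beta) / max_j (N * P(j))**(-beta)
    """
    scaled = np.asarray(priorities, dtype=np.float64) ** alpha
    probs = scaled / scaled.sum()
    # Rarely sampled transitions get larger weights; dividing by the max
    # keeps every weight in (0, 1], so the correction only scales losses down.
    weights = (len(probs) * probs) ** (-beta)
    return probs, weights / weights.max()
```

For priorities `[1, 2, 3, 4]` with α = β = 1, the sampling probabilities are 0.1, 0.2, 0.3, 0.4 and the normalised weights are 1, 0.5, 1/3, 0.25, so the least likely transition keeps its full loss while the most likely one is scaled down the most. With α = 0 every probability is equal and every weight is 1, which is plain uniform replay.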
118 changes: 0 additions & 118 deletions test/discrete/test_pdqn.py

This file was deleted.

16 changes: 13 additions & 3 deletions test/throughput/test_buffer_profile.py
@@ -1,8 +1,8 @@
-import numpy as np
 import pytest
+import numpy as np
 
 from tianshou.data import (ListReplayBuffer, PrioritizedReplayBuffer,
-                           ReplayBuffer)
+                           ReplayBuffer, SegmentTree)
 
 
 @pytest.fixture(scope="module")
@@ -21,7 +21,7 @@ def data():
         'buffer': buffer,
         'buffer2': buffer2,
         'slice': slice(-3000, -1000, 2),
-        'indexes': indexes
+        'indexes': indexes,
     }
 
 
@@ -77,5 +77,15 @@ def test_sample(data):
buffer.sample(int(1e2))


+def test_segtree(data):
+    size = 100000
+    tree = SegmentTree(size)
+    tree[np.arange(size)] = np.random.rand(size)
+
+    for i in np.arange(1e5):
+        scalar = np.random.rand(64) * tree.reduce()
+        tree.get_prefix_sum_idx(scalar)
+
+
 if __name__ == '__main__':
     pytest.main(["-s", "-k buffer_profile", "--durations=0", "-v"])
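Both this profiling test and `sample_tree` in `test/base/test_buffer.py` pass a whole batch of scalars to `get_prefix_sum_idx`, so the search is expected to work on arrays. One way to vectorise it, sketched here assuming a heap-ordered sum tree padded to a power of two, is to walk every query down the tree in lockstep: at each level a query moves to the right child only when the left child's sum is strictly smaller than its remaining mass, and that mass is reduced accordingly. This yields the smallest leaf `i` with `sum(w[:i + 1]) >= v`, matching the invariant `naive[:index].sum() <= scalar <= naive[:index + 1].sum()` checked in the tests. `build_sum_tree` and `batched_prefix_sum_idx` are hypothetical names, not tianshou functions.

```python
import numpy as np


def build_sum_tree(weights):
    """Hypothetical helper: heap-ordered node sums, leaves at [size, 2*size)."""
    weights = np.asarray(weights, dtype=np.float64)
    size = 1
    while size < len(weights):
        size *= 2
    value = np.zeros(2 * size)
    value[size:size + len(weights)] = weights
    for k in range(size - 1, 0, -1):  # fill internal nodes bottom-up
        value[k] = value[2 * k] + value[2 * k + 1]
    return value, size


def batched_prefix_sum_idx(value, size, scalar):
    """For each query v (0 <= v < value[1]), return the smallest leaf i with
    sum(weights[:i + 1]) >= v, descending one tree level per iteration."""
    scalar = np.array(scalar, dtype=np.float64, ndmin=1)  # private copy
    assert np.all(scalar >= 0.0) and np.all(scalar < value[1])
    index = np.ones(scalar.shape, dtype=np.int64)  # every query starts at root
    while index[0] < size:  # all queries sit at the same depth
        index *= 2                 # move to the left child
        left = value[index]
        go_right = left < scalar   # remaining mass exceeds the left subtree
        scalar -= left * go_right  # skip the left subtree's mass
        index += go_right          # left child 2k -> right child 2k + 1
    return index - size
```

With leaf weights `[0.1, 0, 0.1]` the sketch maps the queries `0`, `.1`, `.1 + 1e-6` and `.2 - 1e-6` to leaves `0, 0, 2, 2` and rejects `.2`, the same outcomes the corner-case test expects. The `while` loop runs once per tree level, so a batch of 64 queries costs about log2(size) vectorised steps rather than 64 separate Python-level descents.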
6 changes: 4 additions & 2 deletions tianshou/data/__init__.py
@@ -1,6 +1,7 @@
 from tianshou.data.batch import Batch
-from tianshou.data.utils import to_numpy, to_torch, \
+from tianshou.data.utils.converter import to_numpy, to_torch, \
     to_torch_as
+from tianshou.data.utils.segtree import SegmentTree
 from tianshou.data.buffer import ReplayBuffer, \
     ListReplayBuffer, PrioritizedReplayBuffer
 from tianshou.data.collector import Collector
@@ -10,8 +11,9 @@
     'to_numpy',
     'to_torch',
     'to_torch_as',
+    'SegmentTree',
     'ReplayBuffer',
     'ListReplayBuffer',
     'PrioritizedReplayBuffer',
-    'Collector'
+    'Collector',
 ]