From 9e636ed5359677924eea516e132111b86d9927da Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Thu, 23 Jul 2020 21:30:02 +0800
Subject: [PATCH 01/35] add SegmentTree (no tests yet); move data utils into a package
---
tianshou/data/__init__.py | 4 +-
tianshou/data/buffer.py | 6 +-
tianshou/data/utils/__init__.py | 0
.../data/{utils.py => utils/converter.py} | 0
tianshou/data/utils/segtree.py | 87 +++++++++++++++++++
5 files changed, 91 insertions(+), 6 deletions(-)
create mode 100644 tianshou/data/utils/__init__.py
rename tianshou/data/{utils.py => utils/converter.py} (100%)
create mode 100644 tianshou/data/utils/segtree.py
diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py
index 5d097a03b..67e1b0ca1 100644
--- a/tianshou/data/__init__.py
+++ b/tianshou/data/__init__.py
@@ -1,5 +1,5 @@
from tianshou.data.batch import Batch
-from tianshou.data.utils import to_numpy, to_torch, \
+from tianshou.data.utils.converter import to_numpy, to_torch, \
to_torch_as
from tianshou.data.buffer import ReplayBuffer, \
ListReplayBuffer, PrioritizedReplayBuffer
@@ -13,5 +13,5 @@
'ReplayBuffer',
'ListReplayBuffer',
'PrioritizedReplayBuffer',
- 'Collector'
+ 'Collector',
]
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index b7ddcff38..1441e395c 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -372,7 +372,6 @@ def __init__(self, size: int, alpha: float, beta: float,
self._alpha = alpha
self._beta = beta
self._weight_sum = 0.0
- self._amortization_freq = 50
self._replace = replace
self._meta.weight = np.zeros(size, dtype=np.float64)
@@ -413,8 +412,7 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
# sampling weight
p = (self.weight / self.weight.sum())[:self._size]
indice = np.random.choice(
- self._size, batch_size, p=p,
- replace=self._replace)
+ self._size, batch_size, p=p, replace=self._replace)
p = p[indice] # weight of each sample
elif batch_size == 0:
p = np.full(shape=self._size, fill_value=1.0 / self._size)
@@ -427,7 +425,7 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
f"batch_size should be less than {len(self)}, \
or set replace=True")
batch = self[indice]
- batch["impt_weight"] = (self._size * p) ** (-self._beta)
+ batch.impt_weight = (self._size * p) ** (-self._beta)
return batch, indice
def update_weight(self, indice: Union[slice, np.ndarray],
diff --git a/tianshou/data/utils/__init__.py b/tianshou/data/utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tianshou/data/utils.py b/tianshou/data/utils/converter.py
similarity index 100%
rename from tianshou/data/utils.py
rename to tianshou/data/utils/converter.py
diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py
new file mode 100644
index 000000000..792b8a1da
--- /dev/null
+++ b/tianshou/data/utils/segtree.py
@@ -0,0 +1,87 @@
+import numpy as np
+from typing import Union, Optional
+
+
+class SegmentTree:
+ """Implementation of Segment Tree. The procedure is as follows:
+
+ 1. Find out the smallest n which safisfies ``size <= 2^n``, and let \
+ ``bound = 2^n``. This is to ensure that all leaf nodes are in the same \
+ depth inside the segment tree.
+ 2. Store the original value to leaf nodes in ``[bound:bound * 2]``, and \
+ the union of elementary to internal nodes in ``[1:bound]``. The internal \
+ node follows the rule: \
+ ``value[i] = operation(value[i * 2], value[i * 2 + 1])``.
+ 3. Update a node takes O(log(bound)) time complexity.
+ 4. Query an interval [l, r] with the default operation takes O(log(bound))
+
+ :param int size: the size of segment tree.
+ :param operation: the operation of segment tree. Choose one of "sum", "min"
+ and "max", defaults to "sum".
+ """
+
+ def __init__(self, size: int,
+ operation: Union[sum, min, max] = sum) -> None:
+ bound = 1
+ while bound < size:
+ bound <<= 1
+ self._bound = bound
+ assert operation in [sum, min, max], f"Unknown operation {operation}."
+ self._op = operation
+ self._init_value = {sum: 0, min: np.inf, max: -np.inf}[self._op]
+ self._value = np.zeros([bound << 1]) + self._init_value
+
+ def __getitem__(self, index: int) -> float:
+ """Return self[index]"""
+ assert isinstance(index, int) and 0 <= index < self._bound
+ return self._value[index + self._bound]
+
+ def __setitem__(self, index: Union[int, np.ndarray],
+ value: Union[float, np.ndarray]) -> None:
+ """Insert or overwrite a (or some) value in this segment tree."""
+ if isinstance(index, int) and isinstance(value, float):
+ index, value = np.array([index]), np.array([value])
+ assert isinstance(index, np.ndarray) and isinstance(value, np.ndarray)
+ assert ((0 <= index) & (index < self._bound) & (value >= 0.)).all()
+ index += self._bound
+ self._value[index] = value
+ while index > 1:
+ index >>= 1
+ self._value[index] = self._op(self._value[index << 1],
+ self._value[index << 1 | 1])
+
+ def reduce(self, start: Optional[int] = 0,
+ end: Optional[int] = 0) -> float:
+ """Return operation(value[start:end])."""
+ if start == end == 0:
+ return self._value[1]
+ if end <= 0:
+ end += self._bound
+ start, end = start + self._bound - 1, end + self._bound
+ result = self._init_value
+ while start ^ end ^ 1 != 0:
+ if start % 2 == 0:
+ result = self._op(result, self._value[start ^ 1])
+ if end % 2 == 1:
+ result = self._op(result, self._value[end ^ 1])
+ start, end = start >> 1, end >> 1
+ return result
+
+ def get_prefix_sum_idx(
+ self, value: Union[float, np.ndarray]) -> Union[int, np.ndarray]:
+ """Return the index ``i`` which satisfies
+ ``sum(value[:i]) <= value < sum(value[:i + 1])``.
+ """
+ assert self._op == sum
+ single = False
+ if not isinstance(value, np.ndarray):
+ value = np.array([value])
+ single = True
+ assert (value <= self._value[1]).all()
+ index = np.ones(value.shape, dtype=np.int)
+ while index[0] < self._bound:
+ index <<= 1
+ direct = self._value[index] < value
+ value -= self._value[index] * direct
+ index += direct
+ return index.item() if single else index
From 92088957f78929e16f4c4a3348a7072e003310a4 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Fri, 24 Jul 2020 11:43:54 +0800
Subject: [PATCH 02/35] add SegmentTree tests; select ops by name backed by numpy reductions
---
test/base/test_buffer.py | 64 +++++++++++++++++++++++++++++++++-
tianshou/data/__init__.py | 2 ++
tianshou/data/utils/segtree.py | 40 ++++++++++++---------
3 files changed, 88 insertions(+), 18 deletions(-)
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index 6178a3299..ddbcbb5ec 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -1,6 +1,8 @@
+import pytest
import numpy as np
-from tianshou.data import Batch, PrioritizedReplayBuffer, ReplayBuffer
+from tianshou.data import Batch, PrioritizedReplayBuffer, \
+ ReplayBuffer, SegmentTree
if __name__ == '__main__':
from env import MyTestEnv
@@ -112,9 +114,69 @@ def test_update():
assert (buf2[-1].obs == buf1[0].obs).all()
+def test_segtree():
+ for op, init in zip(['sum', 'max', 'min'], [0., -np.inf, np.inf]):
+ realop = getattr(np, op)
+ # small test
+ tree = SegmentTree(6, op) # 1-15. 8-15 are leaf nodes
+ actual_len = 8
+ assert np.all([tree[i] == init for i in range(actual_len)])
+ with pytest.raises(AssertionError):
+ tree[-1]
+ with pytest.raises(AssertionError):
+ tree[actual_len]
+ naive = np.zeros([actual_len]) + init
+ for _ in range(1000):
+ # random choose a place to perform single update
+ index = np.random.randint(actual_len)
+ value = np.random.rand()
+ naive[index] = value
+ tree[index] = value
+ for i in range(actual_len):
+ for j in range(i, actual_len):
+ try:
+ ref = realop(naive[i:j])
+ except ValueError:
+ continue
+ out = tree.reduce(i, j)
+ assert np.allclose(ref, out), (i, j, ref, out)
+ # batch setitem
+ for _ in range(1000):
+ index = np.random.choice(actual_len, size=4)
+ value = np.random.rand(4)
+ naive[index] = value
+ tree[index] = value
+ assert np.allclose(realop(naive), tree.reduce())
+ for i in range(10):
+ left = right = 0
+ while left >= right:
+ left = np.random.randint(actual_len)
+ right = np.random.randint(actual_len)
+ assert np.allclose(realop(naive[left:right]),
+ tree.reduce(left, right))
+ # large test
+ tree = SegmentTree(10000, op)
+ actual_len = 16384
+ naive = np.zeros([actual_len]) + init
+ for _ in range(1000):
+ index = np.random.choice(actual_len, size=64)
+ value = np.random.rand(64)
+ naive[index] = value
+ tree[index] = value
+ assert np.allclose(realop(naive), tree.reduce())
+ for i in range(10):
+ left = right = 0
+ while left >= right:
+ left = np.random.randint(actual_len)
+ right = np.random.randint(actual_len)
+ assert np.allclose(realop(naive[left:right]),
+ tree.reduce(left, right))
+
+
if __name__ == '__main__':
test_replaybuffer()
test_ignore_obs_next()
test_stack()
+ test_segtree()
test_priortized_replaybuffer(233333, 200000)
test_update()
diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py
index 67e1b0ca1..f5f68e9e0 100644
--- a/tianshou/data/__init__.py
+++ b/tianshou/data/__init__.py
@@ -1,6 +1,7 @@
from tianshou.data.batch import Batch
from tianshou.data.utils.converter import to_numpy, to_torch, \
to_torch_as
+from tianshou.data.utils.segtree import SegmentTree
from tianshou.data.buffer import ReplayBuffer, \
ListReplayBuffer, PrioritizedReplayBuffer
from tianshou.data.collector import Collector
@@ -10,6 +11,7 @@
'to_numpy',
'to_torch',
'to_torch_as',
+ 'SegmentTree',
'ReplayBuffer',
'ListReplayBuffer',
'PrioritizedReplayBuffer',
diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py
index 792b8a1da..cc8f17833 100644
--- a/tianshou/data/utils/segtree.py
+++ b/tianshou/data/utils/segtree.py
@@ -16,19 +16,23 @@ class SegmentTree:
4. Query an interval [l, r] with the default operation takes O(log(bound))
:param int size: the size of segment tree.
- :param operation: the operation of segment tree. Choose one of "sum", "min"
- and "max", defaults to "sum".
+ :param str operation: the operation of segment tree. Choose one of "sum",
+ "min" and "max", defaults to "sum".
"""
def __init__(self, size: int,
- operation: Union[sum, min, max] = sum) -> None:
+ operation: str = "sum") -> None:
bound = 1
while bound < size:
bound <<= 1
self._bound = bound
- assert operation in [sum, min, max], f"Unknown operation {operation}."
- self._op = operation
- self._init_value = {sum: 0, min: np.inf, max: -np.inf}[self._op]
+ assert operation in ["sum", "min", "max"], \
+ f"Unknown operation {operation}."
+ (self._op, self._init_value) = {
+ "sum": (np.sum, 0.),
+ "min": (np.min, np.inf),
+ "max": (np.max, -np.inf),
+ }[operation]
self._value = np.zeros([bound << 1]) + self._init_value
def __getitem__(self, index: int) -> float:
@@ -38,32 +42,34 @@ def __getitem__(self, index: int) -> float:
def __setitem__(self, index: Union[int, np.ndarray],
value: Union[float, np.ndarray]) -> None:
- """Insert or overwrite a (or some) value in this segment tree."""
+ """Insert or overwrite a (or some) value(s) in this segment tree."""
if isinstance(index, int) and isinstance(value, float):
index, value = np.array([index]), np.array([value])
assert isinstance(index, np.ndarray) and isinstance(value, np.ndarray)
- assert ((0 <= index) & (index < self._bound) & (value >= 0.)).all()
+ assert ((0 <= index) & (index < self._bound)).all()
index += self._bound
self._value[index] = value
- while index > 1:
+ while index[0] > 1:
index >>= 1
- self._value[index] = self._op(self._value[index << 1],
- self._value[index << 1 | 1])
+ self._value[index] = self._op(
+ [self._value[index << 1], self._value[index << 1 | 1]], axis=0)
def reduce(self, start: Optional[int] = 0,
- end: Optional[int] = 0) -> float:
+ end: Optional[int] = None) -> float:
"""Return operation(value[start:end])."""
- if start == end == 0:
+ if start == 0 and end is None:
return self._value[1]
- if end <= 0:
+ if end is None:
+ end = self._bound
+ if end < 0:
end += self._bound
start, end = start + self._bound - 1, end + self._bound
result = self._init_value
while start ^ end ^ 1 != 0:
if start % 2 == 0:
- result = self._op(result, self._value[start ^ 1])
+ result = self._op([result, self._value[start ^ 1]])
if end % 2 == 1:
- result = self._op(result, self._value[end ^ 1])
+ result = self._op([result, self._value[end ^ 1]])
start, end = start >> 1, end >> 1
return result
@@ -72,7 +78,7 @@ def get_prefix_sum_idx(
"""Return the index ``i`` which satisfies
``sum(value[:i]) <= value < sum(value[:i + 1])``.
"""
- assert self._op == sum
+ assert self._op == np.sum
single = False
if not isinstance(value, np.ndarray):
value = np.array([value])
From addc7a9fc3eac069c6fa03b542d43b789495afb3 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Fri, 24 Jul 2020 17:20:16 +0800
Subject: [PATCH 03/35] test get_prefix_sum_idx and fix its boundary handling
---
test/base/test_buffer.py | 30 ++++++++++++++++++++++++++++++
tianshou/data/utils/segtree.py | 5 +++--
2 files changed, 33 insertions(+), 2 deletions(-)
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index ddbcbb5ec..7a668744c 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -1,5 +1,6 @@
import pytest
import numpy as np
+from timeit import timeit
from tianshou.data import Batch, PrioritizedReplayBuffer, \
ReplayBuffer, SegmentTree
@@ -172,6 +173,35 @@ def test_segtree():
assert np.allclose(realop(naive[left:right]),
tree.reduce(left, right))
+ # test prefix-sum-idx
+ tree = SegmentTree(6)
+ actual_len = 8
+ naive = np.random.rand(actual_len)
+ tree[np.arange(actual_len)] = naive
+ for _ in range(1000):
+ scalar = np.random.rand() * naive.sum()
+ index = tree.get_prefix_sum_idx(scalar)
+ assert naive[:index].sum() <= scalar < naive[:index + 1].sum()
+ # corner case here
+ naive = np.ones(actual_len, np.int)
+ tree[np.arange(actual_len)] = naive
+ for scalar in range(actual_len):
+ index = tree.get_prefix_sum_idx(scalar * 1.)
+ assert naive[:index].sum() <= scalar < naive[:index + 1].sum()
+ # test large prefix-sum-idx
+ tree = SegmentTree(10000)
+ actual_len = 16384
+ naive = np.random.rand(actual_len)
+ tree[np.arange(actual_len)] = naive
+ for _ in range(1000):
+ scalar = np.random.rand() * naive.sum()
+ index = tree.get_prefix_sum_idx(scalar)
+ assert naive[:index].sum() <= scalar < naive[:index + 1].sum()
+ # profile
+ if __name__ == '__main__':
+ naive = np.zeros(10000)
+ tree = SegmentTree(10000)
+
if __name__ == '__main__':
test_replaybuffer()
diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py
index cc8f17833..834e06cf4 100644
--- a/tianshou/data/utils/segtree.py
+++ b/tianshou/data/utils/segtree.py
@@ -78,7 +78,7 @@ def get_prefix_sum_idx(
"""Return the index ``i`` which satisfies
``sum(value[:i]) <= value < sum(value[:i + 1])``.
"""
- assert self._op == np.sum
+ assert self._op is np.sum
single = False
if not isinstance(value, np.ndarray):
value = np.array([value])
@@ -87,7 +87,8 @@ def get_prefix_sum_idx(
index = np.ones(value.shape, dtype=np.int)
while index[0] < self._bound:
index <<= 1
- direct = self._value[index] < value
+ direct = self._value[index] <= value
value -= self._value[index] * direct
index += direct
+ index -= self._bound
return index.item() if single else index
From c0c1290a8b840a3d896157a71def83583714d2bb Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sat, 1 Aug 2020 09:11:22 +0800
Subject: [PATCH 04/35] benchmark SegmentTree sampling against np.random.choice; add numba draft
---
.gitignore | 1 +
test/base/test_buffer.py | 17 +++++++++++++++--
tianshou/data/utils/segtree.py | 23 +++++++++++++++++++++++
3 files changed, 39 insertions(+), 2 deletions(-)
diff --git a/.gitignore b/.gitignore
index 0ecb650d3..2aa0c739b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -145,3 +145,4 @@ MUJOCO_LOG.TXT
.DS_Store
*.zip
*.pstats
+output.png
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index 7a668744c..bea78e6c2 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -199,8 +199,21 @@ def test_segtree():
assert naive[:index].sum() <= scalar < naive[:index + 1].sum()
# profile
if __name__ == '__main__':
- naive = np.zeros(10000)
- tree = SegmentTree(10000)
+ size = 100000
+ bsz = 64
+ naive = np.random.rand(size)
+ tree = SegmentTree(size)
+ tree[np.arange(size)] = naive
+
+ def sample_npbuf():
+ return np.random.choice(size, bsz, p=naive / naive.sum())
+
+ def sample_tree():
+ scalar = np.random.rand() * tree.reduce()
+ return tree.get_prefix_sum_idx(scalar)
+
+ print('npbuf', timeit(sample_npbuf, setup=sample_npbuf, number=1000))
+ print('tree', timeit(sample_tree, setup=sample_tree, number=1000))
if __name__ == '__main__':
diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py
index 834e06cf4..3f5744e41 100644
--- a/tianshou/data/utils/segtree.py
+++ b/tianshou/data/utils/segtree.py
@@ -1,5 +1,6 @@
import numpy as np
from typing import Union, Optional
+# from numba import njit
class SegmentTree:
@@ -90,5 +91,27 @@ def get_prefix_sum_idx(
direct = self._value[index] <= value
value -= self._value[index] * direct
index += direct
+ # index = self.__class__._get_prefix_sum_idx(
+ # index, value, self._bound, self._value)
index -= self._bound
return index.item() if single else index
+
+ # numba version, 10x speed up
+ # @njit
+ # def _get_prefix_sum_idx(index, scalar, bound, weight):
+ # # while index[0] < bound:
+ # # index <<= 1
+ # # direct = weight[index] <= scalar
+ # # scalar -= weight[index] * direct
+ # # index += direct
+ # for _, s in enumerate(scalar):
+ # i = 1
+ # while i < bound:
+ # l = i * 2
+ # if weight[l] > s:
+ # i = l
+ # else:
+ # s = s - weight[l]
+ # i = l + 1
+ # index[_] = i
+ # return index
From 322a520893b1290a84cc8013f674ff028a9b4364 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sat, 1 Aug 2020 09:25:23 +0800
Subject: [PATCH 05/35] batch the tree benchmark; move prefix-sum search into a helper
---
test/base/test_buffer.py | 2 +-
tianshou/data/utils/segtree.py | 49 ++++++++++++++++------------------
2 files changed, 24 insertions(+), 27 deletions(-)
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index bea78e6c2..f9547e421 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -209,7 +209,7 @@ def sample_npbuf():
return np.random.choice(size, bsz, p=naive / naive.sum())
def sample_tree():
- scalar = np.random.rand() * tree.reduce()
+ scalar = np.random.rand(bsz) * tree.reduce()
return tree.get_prefix_sum_idx(scalar)
print('npbuf', timeit(sample_npbuf, setup=sample_npbuf, number=1000))
diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py
index 3f5744e41..9c4faa056 100644
--- a/tianshou/data/utils/segtree.py
+++ b/tianshou/data/utils/segtree.py
@@ -86,32 +86,29 @@ def get_prefix_sum_idx(
single = True
assert (value <= self._value[1]).all()
index = np.ones(value.shape, dtype=np.int)
- while index[0] < self._bound:
- index <<= 1
- direct = self._value[index] <= value
- value -= self._value[index] * direct
- index += direct
- # index = self.__class__._get_prefix_sum_idx(
- # index, value, self._bound, self._value)
- index -= self._bound
+ index = self.__class__._get_prefix_sum_idx(
+ index, value, self._bound, self._value)
return index.item() if single else index
- # numba version, 10x speed up
+ # numba version, 5x speed up
+ # with size=100000 and bsz=64
+ # first block (vectorized np): 0.0923 (now) -> 0.0251
+ # second block (for-loop): 0.2914 -> 0.0192 (future)
# @njit
- # def _get_prefix_sum_idx(index, scalar, bound, weight):
- # # while index[0] < bound:
- # # index <<= 1
- # # direct = weight[index] <= scalar
- # # scalar -= weight[index] * direct
- # # index += direct
- # for _, s in enumerate(scalar):
- # i = 1
- # while i < bound:
- # l = i * 2
- # if weight[l] > s:
- # i = l
- # else:
- # s = s - weight[l]
- # i = l + 1
- # index[_] = i
- # return index
+ def _get_prefix_sum_idx(index, scalar, bound, weight):
+ while index[0] < bound:
+ index <<= 1
+ direct = weight[index] <= scalar
+ scalar -= weight[index] * direct
+ index += direct
+ # for _, s in enumerate(scalar):
+ # i = 1
+ # while i < bound:
+ # l = i * 2
+ # if weight[l] > s:
+ # i = l
+ # else:
+ # s = s - weight[l]
+ # i = l + 1
+ # index[_] = i
+ return index - bound
From e15803f8eaedf77c5f1170497675e8821b0a69b5 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sat, 1 Aug 2020 11:29:55 +0800
Subject: [PATCH 06/35] rebuild PrioritizedReplayBuffer on top of SegmentTree
---
tianshou/data/buffer.py | 68 +++++++++-------------------------
tianshou/data/utils/segtree.py | 7 +++-
2 files changed, 22 insertions(+), 53 deletions(-)
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 6443b3e10..0d6dab64c 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -1,7 +1,8 @@
import numpy as np
from typing import Any, Tuple, Union, Optional
-from tianshou.data.batch import Batch, _create_value
+from tianshou.data import Batch, SegmentTree
+from tianshou.data.batch import _create_value
class ReplayBuffer:
@@ -353,12 +354,10 @@ def reset(self) -> None:
class PrioritizedReplayBuffer(ReplayBuffer):
- """Prioritized replay buffer implementation.
+ """Prioritized replay buffer implementation, using segment tree.
:param float alpha: the prioritization exponent.
:param float beta: the importance sample soft coefficient.
- :param str mode: defaults to ``weight``.
- :param bool replace: whether to sample with replacement
.. seealso::
@@ -366,17 +365,11 @@ class PrioritizedReplayBuffer(ReplayBuffer):
detailed explanation.
"""
- def __init__(self, size: int, alpha: float, beta: float,
- mode: str = 'weight',
- replace: bool = False, **kwargs) -> None:
- if mode != 'weight':
- raise NotImplementedError
+ def __init__(self, size: int, alpha: float, beta: float, **kwargs) -> None:
super().__init__(size, **kwargs)
- self._alpha = alpha
- self._beta = beta
- self._weight_sum = 0.0
- self._replace = replace
- self._meta.weight = np.zeros(size, dtype=np.float64)
+ self._alpha, self._beta = alpha, beta
+ # bypass the check
+ self._meta.__dict__['weight'] = SegmentTree(size, 'sum')
def add(self,
obs: Union[dict, np.ndarray],
@@ -389,64 +382,37 @@ def add(self,
weight: float = 1.0,
**kwargs) -> None:
"""Add a batch of data into replay buffer."""
- # we have to sacrifice some convenience for speed
- self._weight_sum += np.abs(weight) ** self._alpha - \
- self._meta.weight[self._index]
- self._add_to_buffer('weight', np.abs(weight) ** self._alpha)
+ weight = np.abs(weight) ** self._alpha
+ self.weight[self._index] = weight
super().add(obs, act, rew, done, obs_next, info, policy)
- @property
- def replace(self):
- return self._replace
-
- @replace.setter
- def replace(self, v: bool):
- self._replace = v
-
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
"""Get a random sample from buffer with priority probability. \
Return all the data in the buffer if batch_size is ``0``.
:return: Sample data and its corresponding index inside the buffer.
"""
- assert self._size > 0, 'cannot sample a buffer with size == 0 !'
- p = None
- if batch_size > 0 and (self._replace or batch_size <= self._size):
- # sampling weight
- p = (self.weight / self.weight.sum())[:self._size]
- indice = np.random.choice(
- self._size, batch_size, p=p, replace=self._replace)
- p = p[indice] # weight of each sample
- elif batch_size == 0:
- p = np.full(shape=self._size, fill_value=1.0 / self._size)
+ assert self._size > 0, 'Cannot sample a buffer with 0 size!'
+ if batch_size == 0:
indice = np.concatenate([
np.arange(self._index, self._size),
np.arange(0, self._index),
])
else:
- raise ValueError(
- f"batch_size should be less than {len(self)}, \
- or set replace=True")
+ scalar = np.random.rand(batch_size) * self.weight.reduce()
+ indice = self.weight.get_prefix_sum_idx(scalar)
batch = self[indice]
- batch.impt_weight = (self._size * p) ** (-self._beta)
+ batch.impt_weight = (self._size * batch.weight) ** (-self._beta)
return batch, indice
- def update_weight(self, indice: Union[slice, np.ndarray],
+ def update_weight(self, indice: Union[np.ndarray],
new_weight: np.ndarray) -> None:
"""Update priority weight by indice in this buffer.
:param np.ndarray indice: indice you want to update weight
:param np.ndarray new_weight: new priority weight you want to update
"""
- if self._replace:
- if isinstance(indice, slice):
- # convert slice to ndarray
- indice = np.arange(indice.stop)[indice]
- # remove the same values in indice
- indice, unique_indice = np.unique(
- indice, return_index=True)
- new_weight = new_weight[unique_indice]
- self.weight[indice] = np.power(np.abs(new_weight), self._alpha)
+ self.weight[indice] = np.abs(new_weight) ** self._alpha
def __getitem__(self, index: Union[
slice, int, np.integer, np.ndarray]) -> Batch:
@@ -457,6 +423,6 @@ def __getitem__(self, index: Union[
done=self.done[index],
obs_next=self.get(index, 'obs_next'),
info=self.get(index, 'info'),
- weight=self.weight[index],
policy=self.get(index, 'policy'),
+ weight=self.weight[index],
)
diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py
index 9c4faa056..470a90db3 100644
--- a/tianshou/data/utils/segtree.py
+++ b/tianshou/data/utils/segtree.py
@@ -36,9 +36,12 @@ def __init__(self, size: int,
}[operation]
self._value = np.zeros([bound << 1]) + self._init_value
- def __getitem__(self, index: int) -> float:
+ def __len__(self):
+ return self._bound
+
+ def __getitem__(self, index: Union[int, np.ndarray]
+ ) -> Union[float, np.ndarray]:
"""Return self[index]"""
- assert isinstance(index, int) and 0 <= index < self._bound
return self._value[index + self._bound]
def __setitem__(self, index: Union[int, np.ndarray],
From a3e037a922637a10529092cdb6fbdd9549dfe422 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sat, 1 Aug 2020 16:55:16 +0800
Subject: [PATCH 07/35] align PER
---
tianshou/data/buffer.py | 21 ++++++++++++++++-----
1 file changed, 16 insertions(+), 5 deletions(-)
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 0d6dab64c..54a2caeb9 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -354,7 +354,7 @@ def reset(self) -> None:
class PrioritizedReplayBuffer(ReplayBuffer):
- """Prioritized replay buffer implementation, using segment tree.
+ """Implementation of Prioritized Experience Replay. arXiv:1511.05952
:param float alpha: the prioritization exponent.
:param float beta: the importance sample soft coefficient.
@@ -367,9 +367,12 @@ class PrioritizedReplayBuffer(ReplayBuffer):
def __init__(self, size: int, alpha: float, beta: float, **kwargs) -> None:
super().__init__(size, **kwargs)
+ assert alpha > 0. and beta >= 0.
self._alpha, self._beta = alpha, beta
+ self._max_prio = 1.
# bypass the check
self._meta.__dict__['weight'] = SegmentTree(size, 'sum')
+ self._meta.__dict__['weight_min'] = SegmentTree(size, 'min')
def add(self,
obs: Union[dict, np.ndarray],
@@ -379,11 +382,13 @@ def add(self,
obs_next: Optional[Union[dict, np.ndarray]] = None,
info: dict = {},
policy: Optional[Union[dict, Batch]] = {},
- weight: float = 1.0,
+ weight: float = None,
**kwargs) -> None:
"""Add a batch of data into replay buffer."""
+ if weight is None:
+ weight = self._max_prio
weight = np.abs(weight) ** self._alpha
- self.weight[self._index] = weight
+ self.weight[self._index] = self.weight_min[self._index] = weight
super().add(obs, act, rew, done, obs_next, info, policy)
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
@@ -402,7 +407,11 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
scalar = np.random.rand(batch_size) * self.weight.reduce()
indice = self.weight.get_prefix_sum_idx(scalar)
batch = self[indice]
- batch.impt_weight = (self._size * batch.weight) ** (-self._beta)
+ # impt_weight
+ # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
+ # simplified formula: (p_j/p_min)**(-beta)
+ batch.impt_weight = (
+ batch.weight / self.weight_min.reduce()) ** (-self._beta)
return batch, indice
def update_weight(self, indice: Union[np.ndarray],
@@ -412,7 +421,9 @@ def update_weight(self, indice: Union[np.ndarray],
:param np.ndarray indice: indice you want to update weight
:param np.ndarray new_weight: new priority weight you want to update
"""
- self.weight[indice] = np.abs(new_weight) ** self._alpha
+ self.weight[indice] = self.weight_min[indice] = np.abs(
+ new_weight) ** self._alpha
+ self._max_prio = max(self._max_prio, np.abs(new_weight).max())
def __getitem__(self, index: Union[
slice, int, np.integer, np.ndarray]) -> Batch:
From 0f0c58bab690e18bf4798b20f05bcc8a733c2c57 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sat, 1 Aug 2020 17:02:24 +0800
Subject: [PATCH 08/35] expect IndexError for out-of-range SegmentTree access in tests
---
test/base/test_buffer.py | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index f9547e421..d6c269d33 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -122,9 +122,7 @@ def test_segtree():
tree = SegmentTree(6, op) # 1-15. 8-15 are leaf nodes
actual_len = 8
assert np.all([tree[i] == init for i in range(actual_len)])
- with pytest.raises(AssertionError):
- tree[-1]
- with pytest.raises(AssertionError):
+ with pytest.raises(IndexError):
tree[actual_len]
naive = np.zeros([actual_len]) + init
for _ in range(1000):
From bbf60cd26893394ce955b333433b6fc04c76fb55 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sat, 1 Aug 2020 22:17:28 +0800
Subject: [PATCH 09/35] split chained weight assignment in update_weight; run PER test in __main__
---
test/base/test_buffer.py | 1 +
tianshou/data/buffer.py | 13 +++++++------
2 files changed, 8 insertions(+), 6 deletions(-)
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index d6c269d33..74cb30d0e 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -215,6 +215,7 @@ def sample_tree():
if __name__ == '__main__':
+ test_priortized_replaybuffer()
test_replaybuffer()
test_ignore_obs_next()
test_stack()
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 54a2caeb9..8b13551b5 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -392,8 +392,8 @@ def add(self,
super().add(obs, act, rew, done, obs_next, info, policy)
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
- """Get a random sample from buffer with priority probability. \
- Return all the data in the buffer if batch_size is ``0``.
+ """Get a random sample from buffer with priority probability. Return
+ all the data in the buffer if batch_size is ``0``.
:return: Sample data and its corresponding index inside the buffer.
"""
@@ -418,11 +418,12 @@ def update_weight(self, indice: Union[np.ndarray],
new_weight: np.ndarray) -> None:
"""Update priority weight by indice in this buffer.
- :param np.ndarray indice: indice you want to update weight
- :param np.ndarray new_weight: new priority weight you want to update
+ :param np.ndarray indice: indice you want to update weight.
+ :param np.ndarray new_weight: new priority weight you want to update.
"""
- self.weight[indice] = self.weight_min[indice] = np.abs(
- new_weight) ** self._alpha
+ weight = np.abs(new_weight) ** self._alpha
+ self.weight[indice] = weight
+ self.weight_min[indice] = weight
self._max_prio = max(self._max_prio, np.abs(new_weight).max())
def __getitem__(self, index: Union[
From 31c25c362282d28ea316512cba9870fd7d1d93d4 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sat, 1 Aug 2020 22:37:50 +0800
Subject: [PATCH 10/35] replace the min segment tree with a running minimum priority
---
tianshou/data/buffer.py | 17 ++++++++---------
tianshou/data/utils/segtree.py | 2 +-
2 files changed, 9 insertions(+), 10 deletions(-)
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 8b13551b5..501f99317 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -327,8 +327,8 @@ class ListReplayBuffer(ReplayBuffer):
.. seealso::
- Please refer to :class:`~tianshou.data.ReplayBuffer` for more
- detailed explanation.
+ Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed
+ explanation.
"""
def __init__(self, **kwargs) -> None:
@@ -361,8 +361,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
.. seealso::
- Please refer to :class:`~tianshou.data.ReplayBuffer` for more
- detailed explanation.
+ Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed
+ explanation.
"""
def __init__(self, size: int, alpha: float, beta: float, **kwargs) -> None:
@@ -370,9 +370,9 @@ def __init__(self, size: int, alpha: float, beta: float, **kwargs) -> None:
assert alpha > 0. and beta >= 0.
self._alpha, self._beta = alpha, beta
self._max_prio = 1.
+ self._min_prio = 1.
# bypass the check
self._meta.__dict__['weight'] = SegmentTree(size, 'sum')
- self._meta.__dict__['weight_min'] = SegmentTree(size, 'min')
def add(self,
obs: Union[dict, np.ndarray],
@@ -388,7 +388,7 @@ def add(self,
if weight is None:
weight = self._max_prio
weight = np.abs(weight) ** self._alpha
- self.weight[self._index] = self.weight_min[self._index] = weight
+ self.weight[self._index] = weight
super().add(obs, act, rew, done, obs_next, info, policy)
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
@@ -410,8 +410,7 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
# impt_weight
# original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
# simplified formula: (p_j/p_min)**(-beta)
- batch.impt_weight = (
- batch.weight / self.weight_min.reduce()) ** (-self._beta)
+ batch.impt_weight = (batch.weight / self._min_prio) ** (-self._beta)
return batch, indice
def update_weight(self, indice: Union[np.ndarray],
@@ -423,8 +422,8 @@ def update_weight(self, indice: Union[np.ndarray],
"""
weight = np.abs(new_weight) ** self._alpha
self.weight[indice] = weight
- self.weight_min[indice] = weight
self._max_prio = max(self._max_prio, np.abs(new_weight).max())
+ self._min_prio = min(self._min_prio, np.abs(new_weight).min())
def __getitem__(self, index: Union[
slice, int, np.integer, np.ndarray]) -> Batch:
diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py
index 470a90db3..710a714ee 100644
--- a/tianshou/data/utils/segtree.py
+++ b/tianshou/data/utils/segtree.py
@@ -51,7 +51,7 @@ def __setitem__(self, index: Union[int, np.ndarray],
index, value = np.array([index]), np.array([value])
assert isinstance(index, np.ndarray) and isinstance(value, np.ndarray)
assert ((0 <= index) & (index < self._bound)).all()
- index += self._bound
+ index = index + self._bound
self._value[index] = value
while index[0] > 1:
index >>= 1
From c3443ef5142ef606a3c5624a9e580fa05b473959 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sun, 2 Aug 2020 11:06:36 +0800
Subject: [PATCH 11/35] support prioritized replay in DQN and DDPG
---
tianshou/data/buffer.py | 9 +++++----
tianshou/policy/base.py | 9 ++++++++-
tianshou/policy/modelfree/ddpg.py | 7 ++++++-
tianshou/policy/modelfree/dqn.py | 14 +++++---------
4 files changed, 24 insertions(+), 15 deletions(-)
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 501f99317..f0c8622f4 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -314,7 +314,7 @@ def __getitem__(self, index: Union[
done=self.done[index],
obs_next=self.get(index, 'obs_next'),
info=self.get(index, 'info'),
- policy=self.get(index, 'policy')
+ policy=self.get(index, 'policy'),
)
@@ -373,6 +373,7 @@ def __init__(self, size: int, alpha: float, beta: float, **kwargs) -> None:
self._min_prio = 1.
# bypass the check
self._meta.__dict__['weight'] = SegmentTree(size, 'sum')
+ self.__eps = np.finfo(np.float32).eps.item()
def add(self,
obs: Union[dict, np.ndarray],
@@ -410,7 +411,7 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
# impt_weight
# original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
# simplified formula: (p_j/p_min)**(-beta)
- batch.impt_weight = (batch.weight / self._min_prio) ** (-self._beta)
+ batch.weight = (batch.weight / self._min_prio) ** (-self._beta)
return batch, indice
def update_weight(self, indice: Union[np.ndarray],
@@ -420,8 +421,8 @@ def update_weight(self, indice: Union[np.ndarray],
:param np.ndarray indice: indice you want to update weight.
:param np.ndarray new_weight: new priority weight you want to update.
"""
- weight = np.abs(new_weight) ** self._alpha
- self.weight[indice] = weight
+ weight = np.abs(new_weight) + self.__eps
+ self.weight[indice] = weight ** self._alpha
self._max_prio = max(self._max_prio, np.abs(new_weight).max())
self._min_prio = min(self._min_prio, np.abs(new_weight).min())
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 9d7711c20..65ae496fb 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -4,7 +4,8 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Union, Optional, Callable
-from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
+from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
+ to_torch_as, to_numpy
class BasePolicy(ABC, nn.Module):
@@ -213,4 +214,10 @@ def compute_nstep_return(
returns = to_torch_as(returns, target_q)
gammas = to_torch_as(gamma ** gammas, target_q)
batch.returns = target_q * gammas + returns
+ # prio buffer update
+ if isinstance(buffer, PrioritizedReplayBuffer):
+ batch.update_weight = buffer.update_weight
+ batch.indice = indice
+ else:
+ batch.weight = 1.
return batch
diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py
index da4833a69..4d139cfa5 100644
--- a/tianshou/policy/modelfree/ddpg.py
+++ b/tianshou/policy/modelfree/ddpg.py
@@ -144,7 +144,12 @@ def forward(self, batch: Batch,
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
current_q = self.critic(batch.obs, batch.act).flatten()
target_q = batch.returns.flatten()
- critic_loss = F.mse_loss(current_q, target_q)
+ td = current_q - target_q
+ if hasattr(batch, 'update_weight'): # prio-buffer
+ batch.update_weight(batch.indice, to_numpy(td))
+ weight = to_torch_as(batch.weight, target_q)
+ critic_loss = (td.pow(2) * weight).mean()
+ # critic_loss = F.mse_loss(current_q, target_q)
self.critic_optim.zero_grad()
critic_loss.backward()
self.critic_optim.step()
diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py
index 9bf60a6c9..1c4754064 100644
--- a/tianshou/policy/modelfree/dqn.py
+++ b/tianshou/policy/modelfree/dqn.py
@@ -95,9 +95,6 @@ def process_fn(self, batch: Batch, buffer: ReplayBuffer,
batch = self.compute_nstep_return(
batch, buffer, indice, self._target_q,
self._gamma, self._n_step, self._rew_norm)
- if isinstance(buffer, PrioritizedReplayBuffer):
- batch.update_weight = buffer.update_weight
- batch.indice = indice
return batch
def forward(self, batch: Batch,
@@ -164,13 +161,12 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
q = self(batch, eps=0.).logits
q = q[np.arange(len(q)), batch.act]
r = to_torch_as(batch.returns, q).flatten()
- if hasattr(batch, 'update_weight'):
- td = r - q
+ td = r - q
+ if hasattr(batch, 'update_weight'): # prio-buffer
batch.update_weight(batch.indice, to_numpy(td))
- impt_weight = to_torch_as(batch.impt_weight, q)
- loss = (td.pow(2) * impt_weight).mean()
- else:
- loss = F.mse_loss(q, r)
+ weight = to_torch_as(batch.weight, q)
+ loss = (td.pow(2) * weight).mean()
+ # loss = F.mse_loss(q, r)
loss.backward()
self.optim.step()
self._cnt += 1
From 4f725732089e763b66fc7390d2cbb22de6bb6a03 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sun, 2 Aug 2020 12:12:57 +0800
Subject: [PATCH 12/35] support prioritized replay in TD3 and SAC
---
README.md | 2 +-
tianshou/data/buffer.py | 11 ++++++-----
tianshou/policy/base.py | 3 ++-
tianshou/policy/modelfree/ddpg.py | 6 ++----
tianshou/policy/modelfree/dqn.py | 9 +++------
tianshou/policy/modelfree/sac.py | 12 +++++++++---
tianshou/policy/modelfree/td3.py | 11 ++++++++---
7 files changed, 31 insertions(+), 23 deletions(-)
diff --git a/README.md b/README.md
index 414e288c3..2c027d506 100644
--- a/README.md
+++ b/README.md
@@ -38,7 +38,7 @@ Here is Tianshou's other features:
- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training)
- Support any type of environment state (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
- Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process)
-- Support n-step returns estimation for all Q-learning based algorithms
+- Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms
- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html##multi-agent-reinforcement-learning)
In Chinese, Tianshou means divinely ordained and is derived to the gift of being born with. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. So taking "Tianshou" means that there is no teacher to study with, but rather to learn by themselves through constant interaction with the environment.
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index f0c8622f4..e295f5da7 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -1,7 +1,8 @@
+import torch
import numpy as np
from typing import Any, Tuple, Union, Optional
-from tianshou.data import Batch, SegmentTree
+from tianshou.data import Batch, SegmentTree, to_numpy
from tianshou.data.batch import _create_value
@@ -415,16 +416,16 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
return batch, indice
def update_weight(self, indice: Union[np.ndarray],
- new_weight: np.ndarray) -> None:
+ new_weight: Union[np.ndarray, torch.Tensor]) -> None:
"""Update priority weight by indice in this buffer.
:param np.ndarray indice: indice you want to update weight.
:param np.ndarray new_weight: new priority weight you want to update.
"""
- weight = np.abs(new_weight) + self.__eps
+ weight = np.abs(to_numpy(new_weight)) + self.__eps
self.weight[indice] = weight ** self._alpha
- self._max_prio = max(self._max_prio, np.abs(new_weight).max())
- self._min_prio = min(self._min_prio, np.abs(new_weight).min())
+ self._max_prio = max(self._max_prio, weight.max())
+ self._min_prio = min(self._min_prio, weight.min())
def __getitem__(self, index: Union[
slice, int, np.integer, np.ndarray]) -> Batch:
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 65ae496fb..eef01e337 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -218,6 +218,7 @@ def compute_nstep_return(
if isinstance(buffer, PrioritizedReplayBuffer):
batch.update_weight = buffer.update_weight
batch.indice = indice
+ batch.weight = to_torch_as(batch.weight, target_q)
else:
- batch.weight = 1.
+ batch.weight = to_torch_as(1., target_q)
return batch
diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py
index 4d139cfa5..2205102f2 100644
--- a/tianshou/policy/modelfree/ddpg.py
+++ b/tianshou/policy/modelfree/ddpg.py
@@ -1,7 +1,6 @@
import torch
import numpy as np
from copy import deepcopy
-import torch.nn.functional as F
from typing import Dict, Tuple, Union, Optional
from tianshou.policy import BasePolicy
@@ -146,9 +145,8 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
target_q = batch.returns.flatten()
td = current_q - target_q
if hasattr(batch, 'update_weight'): # prio-buffer
- batch.update_weight(batch.indice, to_numpy(td))
- weight = to_torch_as(batch.weight, target_q)
- critic_loss = (td.pow(2) * weight).mean()
+ batch.update_weight(batch.indice, td)
+ critic_loss = (td.pow(2) * batch.weight).mean()
# critic_loss = F.mse_loss(current_q, target_q)
self.critic_optim.zero_grad()
critic_loss.backward()
diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py
index 1c4754064..c37dac515 100644
--- a/tianshou/policy/modelfree/dqn.py
+++ b/tianshou/policy/modelfree/dqn.py
@@ -1,12 +1,10 @@
import torch
import numpy as np
from copy import deepcopy
-import torch.nn.functional as F
from typing import Dict, Union, Optional
from tianshou.policy import BasePolicy
-from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
- to_torch_as, to_numpy
+from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
class DQNPolicy(BasePolicy):
@@ -163,9 +161,8 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
r = to_torch_as(batch.returns, q).flatten()
td = r - q
if hasattr(batch, 'update_weight'): # prio-buffer
- batch.update_weight(batch.indice, to_numpy(td))
- weight = to_torch_as(batch.weight, q)
- loss = (td.pow(2) * weight).mean()
+ batch.update_weight(batch.indice, td)
+ loss = (td.pow(2) * batch.weight).mean()
# loss = F.mse_loss(q, r)
loss.backward()
self.optim.step()
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index b67a95b90..ce4a5baf0 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -1,7 +1,6 @@
import torch
import numpy as np
from copy import deepcopy
-import torch.nn.functional as F
from typing import Dict, Tuple, Union, Optional
from tianshou.policy import DDPGPolicy
@@ -141,16 +140,23 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
# critic 1
current_q1 = self.critic1(batch.obs, batch.act).flatten()
target_q = batch.returns.flatten()
- critic1_loss = F.mse_loss(current_q1, target_q)
+ td1 = current_q1 - target_q
+ critic1_loss = (td1.pow(2) * batch.weight).mean()
+ # critic1_loss = F.mse_loss(current_q1, target_q)
self.critic1_optim.zero_grad()
critic1_loss.backward()
self.critic1_optim.step()
# critic 2
current_q2 = self.critic2(batch.obs, batch.act).flatten()
- critic2_loss = F.mse_loss(current_q2, target_q)
+ td2 = current_q2 - target_q
+ critic2_loss = (td2.pow(2) * batch.weight).mean()
+ # critic2_loss = F.mse_loss(current_q2, target_q)
self.critic2_optim.zero_grad()
critic2_loss.backward()
self.critic2_optim.step()
+ # prio-buffer
+ if hasattr(batch, 'update_weight'):
+ batch.update_weight(batch.indice, (td1 + td2) / 2.)
# actor
obs_result = self(batch, explorating=False)
a = obs_result.act
diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py
index d90f51087..698145f1d 100644
--- a/tianshou/policy/modelfree/td3.py
+++ b/tianshou/policy/modelfree/td3.py
@@ -1,7 +1,6 @@
import torch
import numpy as np
from copy import deepcopy
-import torch.nn.functional as F
from typing import Dict, Tuple, Optional
from tianshou.policy import DDPGPolicy
@@ -119,16 +118,22 @@ def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
# critic 1
current_q1 = self.critic1(batch.obs, batch.act).flatten()
target_q = batch.returns.flatten()
- critic1_loss = F.mse_loss(current_q1, target_q)
+ td1 = current_q1 - target_q
+ critic1_loss = (td1.pow(2) * batch.weight).mean()
+ # critic1_loss = F.mse_loss(current_q1, target_q)
self.critic1_optim.zero_grad()
critic1_loss.backward()
self.critic1_optim.step()
# critic 2
current_q2 = self.critic2(batch.obs, batch.act).flatten()
- critic2_loss = F.mse_loss(current_q2, target_q)
+ td2 = current_q2 - target_q
+ critic2_loss = (td2.pow(2) * batch.weight).mean()
+ # critic2_loss = F.mse_loss(current_q2, target_q)
self.critic2_optim.zero_grad()
critic2_loss.backward()
self.critic2_optim.step()
+ if hasattr(batch, 'update_weight'): # prio-buffer
+ batch.update_weight(batch.indice, (td1 + td2) / 2.)
if self._cnt % self._freq == 0:
actor_loss = -self.critic1(
batch.obs, self(batch, eps=0).act).mean()
From 15c53be915597f510ad92fcc69c168d4ce14ce01 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sun, 2 Aug 2020 12:15:22 +0800
Subject: [PATCH 13/35] docs: mention prioritized experience replay in feature list
---
docs/index.rst | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/index.rst b/docs/index.rst
index 25f041085..9ef598a81 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -28,7 +28,7 @@ Here is Tianshou's other features:
* Support recurrent state representation in actor network and critic network (RNN-style training for POMDP): :ref:`rnn_training`
* Support any type of environment state (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env`
* Support customized training process: :ref:`customize_training`
-* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` for all Q-learning based algorithms
+* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay for all Q-learning based algorithms
* Support multi-agent RL: :doc:`/tutorials/tictactoe`
中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ <https://tianshou.readthedocs.io/zh/latest/>`_
From 0b8a316bb043be11e82509a8efe96ebb29cd335a Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sun, 2 Aug 2020 12:26:34 +0800
Subject: [PATCH 14/35] revert gitignore
---
.gitignore | 1 -
1 file changed, 1 deletion(-)
diff --git a/.gitignore b/.gitignore
index 2aa0c739b..0ecb650d3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -145,4 +145,3 @@ MUJOCO_LOG.TXT
.DS_Store
*.zip
*.pstats
-output.png
From c3c5a6446aa245e7c39af2ca05c0120d845ec94a Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sun, 2 Aug 2020 16:05:52 +0800
Subject: [PATCH 15/35] merge pdqn test to dqn
---
test/discrete/test_dqn.py | 24 ++++++--
test/discrete/test_pdqn.py | 118 -------------------------------------
2 files changed, 18 insertions(+), 124 deletions(-)
delete mode 100644 test/discrete/test_pdqn.py
diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py
index 0455f7059..9356fb9f3 100644
--- a/test/discrete/test_dqn.py
+++ b/test/discrete/test_dqn.py
@@ -8,9 +8,9 @@
from tianshou.env import VectorEnv
from tianshou.policy import DQNPolicy
-from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, ReplayBuffer
from tianshou.utils.net.common import Net
+from tianshou.trainer import offpolicy_trainer
+from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer
def get_args():
@@ -33,6 +33,9 @@ def get_args():
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
+ parser.add_argument('--prioritized-replay', type=int, default=0)
+ parser.add_argument('--alpha', type=float, default=0.6)
+ parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -58,15 +61,19 @@ def test_dqn(args=get_args()):
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape,
- args.action_shape, args.device,
- dueling=(2, 2)).to(args.device)
+ args.action_shape, args.device, dueling=(1, 1)).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = DQNPolicy(
net, optim, args.gamma, args.n_step,
target_update_freq=args.target_update_freq)
+ # buffer
+ if args.prioritized_replay > 0:
+ buf = PrioritizedReplayBuffer(
+ args.buffer_size, alpha=args.alpha, beta=args.beta)
+ else:
+ buf = ReplayBuffer(args.buffer_size)
# collector
- train_collector = Collector(
- policy, train_envs, ReplayBuffer(args.buffer_size))
+ train_collector = Collector(policy, train_envs, buf)
test_collector = Collector(policy, test_envs)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size)
@@ -114,5 +121,10 @@ def test_fn(x):
collector.close()
+def test_pdqn(args=get_args()):
+ args.prioritized_replay = 1
+ test_dqn(args)
+
+
if __name__ == '__main__':
test_dqn(get_args())
diff --git a/test/discrete/test_pdqn.py b/test/discrete/test_pdqn.py
deleted file mode 100644
index b614f248a..000000000
--- a/test/discrete/test_pdqn.py
+++ /dev/null
@@ -1,118 +0,0 @@
-import os
-import gym
-import torch
-import pprint
-import argparse
-import numpy as np
-from torch.utils.tensorboard import SummaryWriter
-
-from tianshou.utils.net.common import Net
-from tianshou.env import VectorEnv
-from tianshou.policy import DQNPolicy
-from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer
-
-
-def get_args():
- parser = argparse.ArgumentParser()
- parser.add_argument('--task', type=str, default='CartPole-v0')
- parser.add_argument('--seed', type=int, default=1626)
- parser.add_argument('--eps-test', type=float, default=0.05)
- parser.add_argument('--eps-train', type=float, default=0.1)
- parser.add_argument('--buffer-size', type=int, default=20000)
- parser.add_argument('--lr', type=float, default=1e-3)
- parser.add_argument('--gamma', type=float, default=0.9)
- parser.add_argument('--n-step', type=int, default=3)
- parser.add_argument('--target-update-freq', type=int, default=320)
- parser.add_argument('--epoch', type=int, default=10)
- parser.add_argument('--step-per-epoch', type=int, default=1000)
- parser.add_argument('--collect-per-step', type=int, default=10)
- parser.add_argument('--batch-size', type=int, default=64)
- parser.add_argument('--layer-num', type=int, default=3)
- parser.add_argument('--training-num', type=int, default=8)
- parser.add_argument('--test-num', type=int, default=100)
- parser.add_argument('--logdir', type=str, default='log')
- parser.add_argument('--render', type=float, default=0.)
- parser.add_argument('--prioritized-replay', type=int, default=1)
- parser.add_argument('--alpha', type=float, default=0.5)
- parser.add_argument('--beta', type=float, default=0.5)
- parser.add_argument(
- '--device', type=str,
- default='cuda' if torch.cuda.is_available() else 'cpu')
- args = parser.parse_known_args()[0]
- return args
-
-
-def test_pdqn(args=get_args()):
- env = gym.make(args.task)
- args.state_shape = env.observation_space.shape or env.observation_space.n
- args.action_shape = env.action_space.shape or env.action_space.n
- # train_envs = gym.make(args.task)
- # you can also use tianshou.env.SubprocVectorEnv
- train_envs = VectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
- # test_envs = gym.make(args.task)
- test_envs = VectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.test_num)])
- # seed
- np.random.seed(args.seed)
- torch.manual_seed(args.seed)
- train_envs.seed(args.seed)
- test_envs.seed(args.seed)
- # model
- net = Net(args.layer_num, args.state_shape,
- args.action_shape, args.device).to(args.device)
- optim = torch.optim.Adam(net.parameters(), lr=args.lr)
- policy = DQNPolicy(
- net, optim, args.gamma, args.n_step,
- target_update_freq=args.target_update_freq)
- # collector
- if args.prioritized_replay > 0:
- buf = PrioritizedReplayBuffer(
- args.buffer_size, alpha=args.alpha,
- beta=args.alpha, repeat_sample=True)
- else:
- buf = ReplayBuffer(args.buffer_size)
- train_collector = Collector(
- policy, train_envs, buf)
- test_collector = Collector(policy, test_envs)
- # policy.set_eps(1)
- train_collector.collect(n_step=args.batch_size)
- # log
- log_path = os.path.join(args.logdir, args.task, 'dqn')
- writer = SummaryWriter(log_path)
-
- def save_fn(policy):
- torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
-
- def stop_fn(x):
- return x >= env.spec.reward_threshold
-
- def train_fn(x):
- policy.set_eps(args.eps_train)
-
- def test_fn(x):
- policy.set_eps(args.eps_test)
-
- # trainer
- result = offpolicy_trainer(
- policy, train_collector, test_collector, args.epoch,
- args.step_per_epoch, args.collect_per_step, args.test_num,
- args.batch_size, train_fn=train_fn, test_fn=test_fn,
- stop_fn=stop_fn, save_fn=save_fn, writer=writer)
-
- assert stop_fn(result['best_reward'])
- train_collector.close()
- test_collector.close()
- if __name__ == '__main__':
- pprint.pprint(result)
- # Let's watch its performance!
- env = gym.make(args.task)
- collector = Collector(policy, env)
- result = collector.collect(n_episode=1, render=args.render)
- print(f'Final reward: {result["rew"]}, length: {result["len"]}')
- collector.close()
-
-
-if __name__ == '__main__':
- test_pdqn(get_args())
From d9cd78c1d347aa7fcc98cda6d6dffdc30dbe4401 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sun, 2 Aug 2020 16:20:57 +0800
Subject: [PATCH 16/35] replace bit-shift operators (<<, >>) with arithmetic in SegmentTree
---
tianshou/data/utils/segtree.py | 30 +++++++++++++++---------------
1 file changed, 15 insertions(+), 15 deletions(-)
diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py
index 710a714ee..8de3e884b 100644
--- a/tianshou/data/utils/segtree.py
+++ b/tianshou/data/utils/segtree.py
@@ -13,7 +13,7 @@ class SegmentTree:
the union of elementary to internal nodes in ``[1:bound]``. The internal \
node follows the rule: \
``value[i] = operation(value[i * 2], value[i * 2 + 1])``.
- 3. Update a node takes O(log(bound)) time complexity.
+ 3. Update a (or some) node(s) takes O(log(bound)) time complexity.
4. Query an interval [l, r] with the default operation takes O(log(bound))
:param int size: the size of segment tree.
@@ -25,19 +25,17 @@ def __init__(self, size: int,
operation: str = "sum") -> None:
bound = 1
while bound < size:
- bound <<= 1
+ bound *= 2
+ self._size = size
self._bound = bound
assert operation in ["sum", "min", "max"], \
f"Unknown operation {operation}."
- (self._op, self._init_value) = {
- "sum": (np.sum, 0.),
- "min": (np.min, np.inf),
- "max": (np.max, -np.inf),
- }[operation]
- self._value = np.zeros([bound << 1]) + self._init_value
+ self._op = getattr(np, operation)
+ self._init_value = {'sum': 0, 'min': np.inf, 'max': -np.inf}[operation]
+ self._value = np.full([bound * 2], self._init_value, dtype=np.float64)
def __len__(self):
- return self._bound
+ return self._size
def __getitem__(self, index: Union[int, np.ndarray]
) -> Union[float, np.ndarray]:
@@ -46,17 +44,19 @@ def __getitem__(self, index: Union[int, np.ndarray]
def __setitem__(self, index: Union[int, np.ndarray],
value: Union[float, np.ndarray]) -> None:
- """Insert or overwrite a (or some) value(s) in this segment tree."""
+ """Insert or overwrite a (or some) value(s) in this segment tree. The
+ duplicate values are handled as numpy array, in other words, we only
+ keep the last value and ignore the previous same value.
+ """
if isinstance(index, int) and isinstance(value, float):
index, value = np.array([index]), np.array([value])
- assert isinstance(index, np.ndarray) and isinstance(value, np.ndarray)
assert ((0 <= index) & (index < self._bound)).all()
index = index + self._bound
self._value[index] = value
while index[0] > 1:
- index >>= 1
+ index //= 2
self._value[index] = self._op(
- [self._value[index << 1], self._value[index << 1 | 1]], axis=0)
+ [self._value[index * 2], self._value[index * 2 + 1]], axis=0)
def reduce(self, start: Optional[int] = 0,
end: Optional[int] = None) -> float:
@@ -74,7 +74,7 @@ def reduce(self, start: Optional[int] = 0,
result = self._op([result, self._value[start ^ 1]])
if end % 2 == 1:
result = self._op([result, self._value[end ^ 1]])
- start, end = start >> 1, end >> 1
+ start, end = start // 2, end // 2
return result
def get_prefix_sum_idx(
@@ -100,7 +100,7 @@ def get_prefix_sum_idx(
# @njit
def _get_prefix_sum_idx(index, scalar, bound, weight):
while index[0] < bound:
- index <<= 1
+ index *= 2
direct = weight[index] <= scalar
scalar -= weight[index] * direct
index += direct
From 036e54c7349eba032f4337a2f7a61ed231434f38 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sun, 2 Aug 2020 17:03:19 +0800
Subject: [PATCH 17/35] fix prioritized DQN test (disable dueling net, use gamma 0.95)
---
test/discrete/test_dqn.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py
index 9356fb9f3..2487c81d5 100644
--- a/test/discrete/test_dqn.py
+++ b/test/discrete/test_dqn.py
@@ -61,7 +61,8 @@ def test_dqn(args=get_args()):
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape,
- args.action_shape, args.device, dueling=(1, 1)).to(args.device)
+ args.action_shape, args.device, # dueling=(1, 1)
+ ).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = DQNPolicy(
net, optim, args.gamma, args.n_step,
@@ -123,8 +124,9 @@ def test_fn(x):
def test_pdqn(args=get_args()):
args.prioritized_replay = 1
+ args.gamma = .95
test_dqn(args)
if __name__ == '__main__':
- test_dqn(get_args())
+ test_pdqn(get_args())
From 3c0cb2e820ba4ed98be70c905b208168903bd097 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sun, 2 Aug 2020 18:46:35 +0800
Subject: [PATCH 18/35] use numpy ufuncs (np.add/np.minimum/np.maximum) as SegmentTree operations
---
tianshou/data/utils/segtree.py | 25 +++++++++++++++----------
1 file changed, 15 insertions(+), 10 deletions(-)
diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py
index 8de3e884b..78ca5aa31 100644
--- a/tianshou/data/utils/segtree.py
+++ b/tianshou/data/utils/segtree.py
@@ -22,17 +22,22 @@ class SegmentTree:
"""
def __init__(self, size: int,
- operation: str = "sum") -> None:
+ operation: str = 'sum') -> None:
bound = 1
while bound < size:
bound *= 2
self._size = size
self._bound = bound
- assert operation in ["sum", "min", "max"], \
- f"Unknown operation {operation}."
- self._op = getattr(np, operation)
- self._init_value = {'sum': 0, 'min': np.inf, 'max': -np.inf}[operation]
- self._value = np.full([bound * 2], self._init_value, dtype=np.float64)
+ assert operation in ['sum', 'min', 'max'], \
+ f'Unknown operation {operation}.'
+ if operation == 'sum':
+ self._op, self._init_value = np.add, 0.
+ elif operation == 'min':
+ self._op, self._init_value = np.minimum, np.inf
+ else:
+ self._op, self._init_value = np.maximum, -np.inf
+ # assert isinstance(self._op, np.ufunc)
+ self._value = np.full([bound * 2], self._init_value)
def __len__(self):
return self._size
@@ -56,7 +61,7 @@ def __setitem__(self, index: Union[int, np.ndarray],
while index[0] > 1:
index //= 2
self._value[index] = self._op(
- [self._value[index * 2], self._value[index * 2 + 1]], axis=0)
+ self._value[index * 2], self._value[index * 2 + 1])
def reduce(self, start: Optional[int] = 0,
end: Optional[int] = None) -> float:
@@ -71,9 +76,9 @@ def reduce(self, start: Optional[int] = 0,
result = self._init_value
while start ^ end ^ 1 != 0:
if start % 2 == 0:
- result = self._op([result, self._value[start ^ 1]])
+ result = self._op(result, self._value[start ^ 1])
if end % 2 == 1:
- result = self._op([result, self._value[end ^ 1]])
+ result = self._op(result, self._value[end ^ 1])
start, end = start // 2, end // 2
return result
@@ -82,7 +87,7 @@ def get_prefix_sum_idx(
"""Return the index ``i`` which satisfies
``sum(value[:i]) <= value < sum(value[:i + 1])``.
"""
- assert self._op is np.sum
+ assert self._op is np.add
single = False
if not isinstance(value, np.ndarray):
value = np.array([value])
From c772009e3d993dbb97fd96251c72effb6c38571c Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sun, 2 Aug 2020 18:51:47 +0800
Subject: [PATCH 19/35] bound-check SegmentTree indices against size instead of bound
---
test/base/test_buffer.py | 8 ++++----
tianshou/data/utils/segtree.py | 2 +-
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index 74cb30d0e..1568ef3b0 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -119,8 +119,8 @@ def test_segtree():
for op, init in zip(['sum', 'max', 'min'], [0., -np.inf, np.inf]):
realop = getattr(np, op)
# small test
- tree = SegmentTree(6, op) # 1-15. 8-15 are leaf nodes
actual_len = 8
+ tree = SegmentTree(actual_len, op) # 1-15. 8-15 are leaf nodes
assert np.all([tree[i] == init for i in range(actual_len)])
with pytest.raises(IndexError):
tree[actual_len]
@@ -154,8 +154,8 @@ def test_segtree():
assert np.allclose(realop(naive[left:right]),
tree.reduce(left, right))
# large test
- tree = SegmentTree(10000, op)
actual_len = 16384
+ tree = SegmentTree(actual_len, op)
naive = np.zeros([actual_len]) + init
for _ in range(1000):
index = np.random.choice(actual_len, size=64)
@@ -172,8 +172,8 @@ def test_segtree():
tree.reduce(left, right))
# test prefix-sum-idx
- tree = SegmentTree(6)
actual_len = 8
+ tree = SegmentTree(actual_len)
naive = np.random.rand(actual_len)
tree[np.arange(actual_len)] = naive
for _ in range(1000):
@@ -187,8 +187,8 @@ def test_segtree():
index = tree.get_prefix_sum_idx(scalar * 1.)
assert naive[:index].sum() <= scalar < naive[:index + 1].sum()
# test large prefix-sum-idx
- tree = SegmentTree(10000)
actual_len = 16384
+ tree = SegmentTree(actual_len)
naive = np.random.rand(actual_len)
tree[np.arange(actual_len)] = naive
for _ in range(1000):
diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py
index 78ca5aa31..200b3abba 100644
--- a/tianshou/data/utils/segtree.py
+++ b/tianshou/data/utils/segtree.py
@@ -55,7 +55,7 @@ def __setitem__(self, index: Union[int, np.ndarray],
"""
if isinstance(index, int) and isinstance(value, float):
index, value = np.array([index]), np.array([value])
- assert ((0 <= index) & (index < self._bound)).all()
+ assert ((0 <= index) & (index < self._size)).all()
index = index + self._bound
self._value[index] = value
while index[0] > 1:
From 1f98d01e5b50a6fe05a91bc12139d72d2e8a67a4 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sun, 2 Aug 2020 18:55:46 +0800
Subject: [PATCH 20/35] accept a scalar int index with any value in setitem
---
tianshou/data/utils/segtree.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py
index 200b3abba..0da437a96 100644
--- a/tianshou/data/utils/segtree.py
+++ b/tianshou/data/utils/segtree.py
@@ -51,10 +51,10 @@ def __setitem__(self, index: Union[int, np.ndarray],
value: Union[float, np.ndarray]) -> None:
"""Insert or overwrite a (or some) value(s) in this segment tree. The
duplicate values are handled as numpy array, in other words, we only
- keep the last value and ignore the previous same value.
+ keep the last index-value pair and ignore the previous same indexes.
"""
- if isinstance(index, int) and isinstance(value, float):
- index, value = np.array([index]), np.array([value])
+ if isinstance(index, int):
+ index = np.array([index])
assert ((0 <= index) & (index < self._size)).all()
index = index + self._bound
self._value[index] = value
From a7472ac13621054927154555c4c39dd77bf13663 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sun, 2 Aug 2020 18:57:34 +0800
Subject: [PATCH 21/35] expand SegmentTree class docstring
---
tianshou/data/utils/segtree.py | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py
index 0da437a96..570b0278f 100644
--- a/tianshou/data/utils/segtree.py
+++ b/tianshou/data/utils/segtree.py
@@ -4,7 +4,11 @@
class SegmentTree:
- """Implementation of Segment Tree. The procedure is as follows:
+ """Implementation of Segment Tree: store an array ``arr`` with size ``n``
+ in a segment tree, support value update and fast query of ``min/max/sum``
+ ``arr[left:right]`` in O(log n) time.
+
+ The detailed procedure is as follows:
1. Find out the smallest n which safisfies ``size <= 2^n``, and let \
``bound = 2^n``. This is to ensure that all leaf nodes are in the same \
From b6e0651f544565a674e4f602de801fa500d27a56 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sun, 2 Aug 2020 19:16:16 +0800
Subject: [PATCH 22/35] fix corner case
---
test/base/test_buffer.py | 4 ++++
tianshou/data/utils/segtree.py | 9 +++++++--
2 files changed, 11 insertions(+), 2 deletions(-)
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index 1568ef3b0..be20d916e 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -186,6 +186,10 @@ def test_segtree():
for scalar in range(actual_len):
index = tree.get_prefix_sum_idx(scalar * 1.)
assert naive[:index].sum() <= scalar < naive[:index + 1].sum()
+ tree = SegmentTree(10)
+ tree[np.arange(3)] = np.array([0.1, 0, 0.1])
+ assert np.allclose(tree.get_prefix_sum_idx(
+ np.array([0, .1, .2])), [0, 2, 9])
# test large prefix-sum-idx
actual_len = 16384
tree = SegmentTree(actual_len)
diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py
index 570b0278f..8bffe54b2 100644
--- a/tianshou/data/utils/segtree.py
+++ b/tianshou/data/utils/segtree.py
@@ -59,7 +59,9 @@ def __setitem__(self, index: Union[int, np.ndarray],
"""
if isinstance(index, int):
index = np.array([index])
- assert ((0 <= index) & (index < self._size)).all()
+ assert np.all(0 <= index) and np.all(index < self._size)
+ if self._op is np.add:
+ assert np.all(0 <= value)
index = index + self._bound
self._value[index] = value
while index[0] > 1:
@@ -89,9 +91,11 @@ def reduce(self, start: Optional[int] = 0,
def get_prefix_sum_idx(
self, value: Union[float, np.ndarray]) -> Union[int, np.ndarray]:
"""Return the index ``i`` which satisfies
- ``sum(value[:i]) <= value < sum(value[:i + 1])``.
+ ``sum(value[:i]) <= value < sum(value[:i + 1])``. If multiple indexes
+ meet this condition, return the biggest one.
"""
assert self._op is np.add
+ assert np.all(value >= 0)
single = False
if not isinstance(value, np.ndarray):
value = np.array([value])
@@ -100,6 +104,7 @@ def get_prefix_sum_idx(
index = np.ones(value.shape, dtype=np.int)
index = self.__class__._get_prefix_sum_idx(
index, value, self._bound, self._value)
+ index[index >= self._size] = self._size - 1
return index.item() if single else index
# numba version, 5x speed up
From a6b2e2df28d20d7deab94a3b58993d57c47c5a35 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sun, 2 Aug 2020 19:36:03 +0800
Subject: [PATCH 23/35] make get_prefix_sum_idx return the minimum matching index
---
test/base/test_buffer.py | 6 ++++--
tianshou/data/utils/segtree.py | 12 +++++-------
2 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index be20d916e..e5d02a300 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -185,11 +185,13 @@ def test_segtree():
tree[np.arange(actual_len)] = naive
for scalar in range(actual_len):
index = tree.get_prefix_sum_idx(scalar * 1.)
- assert naive[:index].sum() <= scalar < naive[:index + 1].sum()
+ assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
tree = SegmentTree(10)
tree[np.arange(3)] = np.array([0.1, 0, 0.1])
assert np.allclose(tree.get_prefix_sum_idx(
- np.array([0, .1, .2])), [0, 2, 9])
+ np.array([0, .1, .1 + 1e-6, .2 - 1e-6])), [0, 0, 2, 2])
+ with pytest.raises(AssertionError):
+ tree.get_prefix_sum_idx(.2)
# test large prefix-sum-idx
actual_len = 16384
tree = SegmentTree(actual_len)
diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py
index 8bffe54b2..b2f93167b 100644
--- a/tianshou/data/utils/segtree.py
+++ b/tianshou/data/utils/segtree.py
@@ -61,7 +61,7 @@ def __setitem__(self, index: Union[int, np.ndarray],
index = np.array([index])
assert np.all(0 <= index) and np.all(index < self._size)
if self._op is np.add:
- assert np.all(0 <= value)
+ assert np.all(0. <= value)
index = index + self._bound
self._value[index] = value
while index[0] > 1:
@@ -90,12 +90,11 @@ def reduce(self, start: Optional[int] = 0,
def get_prefix_sum_idx(
self, value: Union[float, np.ndarray]) -> Union[int, np.ndarray]:
- """Return the index ``i`` which satisfies
- ``sum(value[:i]) <= value < sum(value[:i + 1])``. If multiple indexes
- meet this condition, return the biggest one.
+ """Return the minimum index ``i`` which satisfies
+ ``sum(value[:i]) <= value <= sum(value[:i + 1])``.
"""
assert self._op is np.add
- assert np.all(value >= 0)
+ assert np.all(value >= 0.) and np.all(value < self._value[1])
single = False
if not isinstance(value, np.ndarray):
value = np.array([value])
@@ -104,7 +103,6 @@ def get_prefix_sum_idx(
index = np.ones(value.shape, dtype=np.int)
index = self.__class__._get_prefix_sum_idx(
index, value, self._bound, self._value)
- index[index >= self._size] = self._size - 1
return index.item() if single else index
# numba version, 5x speed up
@@ -115,7 +113,7 @@ def get_prefix_sum_idx(
def _get_prefix_sum_idx(index, scalar, bound, weight):
while index[0] < bound:
index *= 2
- direct = weight[index] <= scalar
+ direct = weight[index] < scalar
scalar -= weight[index] * direct
index += direct
# for _, s in enumerate(scalar):
From 6f5c4f61f0453338dc708196d3a05a2020dada93 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sun, 2 Aug 2020 19:48:51 +0800
Subject: [PATCH 24/35] relax prefix-sum-idx test upper bound to <=
---
test/base/test_buffer.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index e5d02a300..ea2130b3a 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -179,7 +179,7 @@ def test_segtree():
for _ in range(1000):
scalar = np.random.rand() * naive.sum()
index = tree.get_prefix_sum_idx(scalar)
- assert naive[:index].sum() <= scalar < naive[:index + 1].sum()
+ assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
# corner case here
naive = np.ones(actual_len, np.int)
tree[np.arange(actual_len)] = naive
@@ -200,7 +200,7 @@ def test_segtree():
for _ in range(1000):
scalar = np.random.rand() * naive.sum()
index = tree.get_prefix_sum_idx(scalar)
- assert naive[:index].sum() <= scalar < naive[:index + 1].sum()
+ assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
# profile
if __name__ == '__main__':
size = 100000
From 1226e2feb82831b87d943955a23cafcb545f068f Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sun, 2 Aug 2020 19:52:52 +0800
Subject: [PATCH 25/35] use >= in the commented-out numba prefix-sum loop
---
tianshou/data/utils/segtree.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py
index b2f93167b..970d6c662 100644
--- a/tianshou/data/utils/segtree.py
+++ b/tianshou/data/utils/segtree.py
@@ -120,7 +120,7 @@ def _get_prefix_sum_idx(index, scalar, bound, weight):
# i = 1
# while i < bound:
# l = i * 2
- # if weight[l] > s:
+ # if weight[l] >= s:
# i = l
# else:
# s = s - weight[l]
From 370802aa91f489b6fcdcabbccdb32b7c1f6bd1d8 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sun, 2 Aug 2020 21:09:26 +0800
Subject: [PATCH 26/35] add to profile test
---
test/throughput/test_buffer_profile.py | 16 +++++++++++++---
1 file changed, 13 insertions(+), 3 deletions(-)
diff --git a/test/throughput/test_buffer_profile.py b/test/throughput/test_buffer_profile.py
index aec32682a..88abdcb64 100644
--- a/test/throughput/test_buffer_profile.py
+++ b/test/throughput/test_buffer_profile.py
@@ -1,8 +1,8 @@
-import numpy as np
import pytest
+import numpy as np
from tianshou.data import (ListReplayBuffer, PrioritizedReplayBuffer,
- ReplayBuffer)
+ ReplayBuffer, SegmentTree)
@pytest.fixture(scope="module")
@@ -21,7 +21,7 @@ def data():
'buffer': buffer,
'buffer2': buffer2,
'slice': slice(-3000, -1000, 2),
- 'indexes': indexes
+ 'indexes': indexes,
}
@@ -77,5 +77,15 @@ def test_sample(data):
buffer.sample(int(1e2))
+def test_segtree(data):
+ size = 100000
+ tree = SegmentTree(size)
+ tree[np.arange(size)] = np.random.rand(size)
+
+ for i in np.arange(1e5):
+ scalar = np.random.rand(64) * tree.reduce()
+ tree.get_prefix_sum_idx(scalar)
+
+
if __name__ == '__main__':
pytest.main(["-s", "-k buffer_profile", "--durations=0", "-v"])
From d6765d7891754a3d3ec656011de95d7b413bc4e0 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Mon, 3 Aug 2020 01:03:00 +0800
Subject: [PATCH 27/35] doc polish and remove intricate xor operators
---
tianshou/data/utils/segtree.py | 91 ++++++++++++++++++----------------
1 file changed, 47 insertions(+), 44 deletions(-)
diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py
index 970d6c662..d7d04eba7 100644
--- a/tianshou/data/utils/segtree.py
+++ b/tianshou/data/utils/segtree.py
@@ -3,26 +3,44 @@
# from numba import njit
+# numba version, 5x speed up
+# with size=100000 and bsz=64
+# first block (vectorized np): 0.0923 (now) -> 0.0251
+# second block (for-loop): 0.2914 -> 0.0192 (future)
+# @njit
+def _get_prefix_sum_idx(index, scalar, bound, weight):
+ while index[0] < bound:
+ index *= 2
+ direct = weight[index] < scalar
+ scalar -= weight[index] * direct
+ index += direct
+ # for _, s in enumerate(scalar):
+ # i = 1
+ # while i < bound:
+ # l = i * 2
+ # if weight[l] >= s:
+ # i = l
+ # else:
+ # s = s - weight[l]
+ # i = l + 1
+ # index[_] = i
+ return index - bound
+
+
class SegmentTree:
"""Implementation of Segment Tree: store an array ``arr`` with size ``n``
in a segment tree, support value update and fast query of ``min/max/sum``
- ``arr[left:right]`` in O(log n) time.
+ for the interval ``[left, right)`` in O(log n) time.
The detailed procedure is as follows:
- 1. Find out the smallest n which safisfies ``size <= 2^n``, and let \
- ``bound = 2^n``. This is to ensure that all leaf nodes are in the same \
- depth inside the segment tree.
- 2. Store the original value to leaf nodes in ``[bound:bound * 2]``, and \
- the union of elementary to internal nodes in ``[1:bound]``. The internal \
- node follows the rule: \
- ``value[i] = operation(value[i * 2], value[i * 2 + 1])``.
- 3. Update a (or some) node(s) takes O(log(bound)) time complexity.
- 4. Query an interval [l, r] with the default operation takes O(log(bound))
+ 1. Pad the array to have length of power of 2, so that leaf nodes in the\
+ segment tree have the same depth.
+ 2. Store the segment tree in a binary heap.
:param int size: the size of segment tree.
- :param str operation: the operation of segment tree. Choose one of "sum",
- "min" and "max", defaults to "sum".
+ :param str operation: the operation of segment tree. Choices are "sum",
+ "min" and "max". Default: "sum".
"""
def __init__(self, size: int,
@@ -53,9 +71,16 @@ def __getitem__(self, index: Union[int, np.ndarray]
def __setitem__(self, index: Union[int, np.ndarray],
value: Union[float, np.ndarray]) -> None:
- """Insert or overwrite a (or some) value(s) in this segment tree. The
- duplicate values are handled as numpy array, in other words, we only
- keep the last index-value pair and ignore the previous same indexes.
+ """Duplicate values in ``index`` are handled by numpy: later index
+ overwrites previous ones.
+
+ ::
+
+ >>> a = np.array([1, 2, 3, 4])
+ >>> a[[0, 1, 0, 1]] = [4, 5, 6, 7]
+ >>> print(a)
+ [6 7 3 4]
+
"""
if isinstance(index, int):
index = np.array([index])
@@ -75,16 +100,17 @@ def reduce(self, start: Optional[int] = 0,
if start == 0 and end is None:
return self._value[1]
if end is None:
- end = self._bound
+ end = self._size
if end < 0:
- end += self._bound
+ end += self._size
+ # nodes in (start, end) should be aggregated
start, end = start + self._bound - 1, end + self._bound
result = self._init_value
- while start ^ end ^ 1 != 0:
+ while end - start > 1: # (start, end) interval is not empty
if start % 2 == 0:
- result = self._op(result, self._value[start ^ 1])
+ result = self._op(result, self._value[start + 1])
if end % 2 == 1:
- result = self._op(result, self._value[end ^ 1])
+ result = self._op(result, self._value[end - 1])
start, end = start // 2, end // 2
return result
@@ -101,29 +127,6 @@ def get_prefix_sum_idx(
single = True
assert (value <= self._value[1]).all()
index = np.ones(value.shape, dtype=np.int)
- index = self.__class__._get_prefix_sum_idx(
+ index = _get_prefix_sum_idx(
index, value, self._bound, self._value)
return index.item() if single else index
-
- # numba version, 5x speed up
- # with size=100000 and bsz=64
- # first block (vectorized np): 0.0923 (now) -> 0.0251
- # second block (for-loop): 0.2914 -> 0.0192 (future)
- # @njit
- def _get_prefix_sum_idx(index, scalar, bound, weight):
- while index[0] < bound:
- index *= 2
- direct = weight[index] < scalar
- scalar -= weight[index] * direct
- index += direct
- # for _, s in enumerate(scalar):
- # i = 1
- # while i < bound:
- # l = i * 2
- # if weight[l] >= s:
- # i = l
- # else:
- # s = s - weight[l]
- # i = l + 1
- # index[_] = i
- return index - bound
From 0ced37a9fa7bb50f2127d7726f362329286d7ea2 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Mon, 3 Aug 2020 01:13:03 +0800
Subject: [PATCH 28/35] code refactor for _get_prefix_sum_idx
---
tianshou/data/utils/segtree.py | 24 +++++++++++-------------
1 file changed, 11 insertions(+), 13 deletions(-)
diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py
index d7d04eba7..70228488a 100644
--- a/tianshou/data/utils/segtree.py
+++ b/tianshou/data/utils/segtree.py
@@ -8,23 +8,25 @@
# first block (vectorized np): 0.0923 (now) -> 0.0251
# second block (for-loop): 0.2914 -> 0.0192 (future)
# @njit
-def _get_prefix_sum_idx(index, scalar, bound, weight):
+def _get_prefix_sum_idx(value, bound, sums):
+ index = np.ones(value.shape, dtype=np.int64)
while index[0] < bound:
index *= 2
- direct = weight[index] < scalar
- scalar -= weight[index] * direct
+ direct = sums[index] < value
+ value -= sums[index] * direct
index += direct
- # for _, s in enumerate(scalar):
+ # for _, s in enumerate(value):
# i = 1
# while i < bound:
# l = i * 2
- # if weight[l] >= s:
+ # if sums[l] >= s:
# i = l
# else:
- # s = s - weight[l]
+ # s = s - sums[l]
# i = l + 1
# index[_] = i
- return index - bound
+ index -= bound
+ return index
class SegmentTree:
@@ -116,8 +118,7 @@ def reduce(self, start: Optional[int] = 0,
def get_prefix_sum_idx(
self, value: Union[float, np.ndarray]) -> Union[int, np.ndarray]:
- """Return the minimum index ``i`` which satisfies
- ``sum(value[:i]) <= value <= sum(value[:i + 1])``.
+ """Return the minimum index for each value so that ``value <= sum[i]``
"""
assert self._op is np.add
assert np.all(value >= 0.) and np.all(value < self._value[1])
@@ -125,8 +126,5 @@ def get_prefix_sum_idx(
if not isinstance(value, np.ndarray):
value = np.array([value])
single = True
- assert (value <= self._value[1]).all()
- index = np.ones(value.shape, dtype=np.int)
- index = _get_prefix_sum_idx(
- index, value, self._bound, self._value)
+ index = _get_prefix_sum_idx(value, self._bound, self._value)
return index.item() if single else index
From cebcc2d604935041f4f8038170cbeaa20def7ce4 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Mon, 3 Aug 2020 10:37:28 +0800
Subject: [PATCH 29/35] leave todo and doc fix
---
tianshou/data/utils/segtree.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py
index 70228488a..60a60dd50 100644
--- a/tianshou/data/utils/segtree.py
+++ b/tianshou/data/utils/segtree.py
@@ -84,6 +84,7 @@ def __setitem__(self, index: Union[int, np.ndarray],
[6 7 3 4]
"""
+ # TODO numba njit version
if isinstance(index, int):
index = np.array([index])
assert np.all(0 <= index) and np.all(index < self._size)
@@ -99,6 +100,7 @@ def __setitem__(self, index: Union[int, np.ndarray],
def reduce(self, start: Optional[int] = 0,
end: Optional[int] = None) -> float:
"""Return operation(value[start:end])."""
+ # TODO numba njit version
if start == 0 and end is None:
return self._value[1]
if end is None:
@@ -118,7 +120,8 @@ def reduce(self, start: Optional[int] = 0,
def get_prefix_sum_idx(
self, value: Union[float, np.ndarray]) -> Union[int, np.ndarray]:
- """Return the minimum index for each value so that ``value <= sum[i]``
+ """Return the minimum index for each ``v`` in ``value`` so that
+ ``v <= sums[i]``, where sums[i] = \\sum_{j=0}^{i} arr[j].
"""
assert self._op is np.add
assert np.all(value >= 0.) and np.all(value < self._value[1])
From 30a66190e06972b82795168b101cc5e021b51bdd Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 4 Aug 2020 00:17:17 +0800
Subject: [PATCH 30/35] use torch.ones_like for default weight in compute_nstep_return
---
tianshou/policy/base.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index eef01e337..cc1a59326 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -220,5 +220,5 @@ def compute_nstep_return(
batch.indice = indice
batch.weight = to_torch_as(batch.weight, target_q)
else:
- batch.weight = to_torch_as(1., target_q)
+ batch.weight = torch.ones_like(target_q)
return batch
From 61ea9f01a7300be2fc0571522fc8b397be99ed1f Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Tue, 4 Aug 2020 08:17:28 +0800
Subject: [PATCH 31/35] minor fix
---
test/base/test_buffer.py | 14 ++++++--------
tianshou/data/buffer.py | 8 +++++++-
2 files changed, 13 insertions(+), 9 deletions(-)
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index ea2130b3a..9cfce29b3 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -124,7 +124,7 @@ def test_segtree():
assert np.all([tree[i] == init for i in range(actual_len)])
with pytest.raises(IndexError):
tree[actual_len]
- naive = np.zeros([actual_len]) + init
+ naive = np.full([actual_len], init)
for _ in range(1000):
# random choose a place to perform single update
index = np.random.randint(actual_len)
@@ -132,13 +132,10 @@ def test_segtree():
naive[index] = value
tree[index] = value
for i in range(actual_len):
- for j in range(i, actual_len):
- try:
- ref = realop(naive[i:j])
- except ValueError:
- continue
+ for j in range(i + 1, actual_len):
+ ref = realop(naive[i:j])
out = tree.reduce(i, j)
- assert np.allclose(ref, out), (i, j, ref, out)
+ assert np.allclose(ref, out)
# batch setitem
for _ in range(1000):
index = np.random.choice(actual_len, size=4)
@@ -156,7 +153,7 @@ def test_segtree():
# large test
actual_len = 16384
tree = SegmentTree(actual_len, op)
- naive = np.zeros([actual_len]) + init
+ naive = np.full([actual_len], init)
for _ in range(1000):
index = np.random.choice(actual_len, size=64)
value = np.random.rand(64)
@@ -201,6 +198,7 @@ def test_segtree():
scalar = np.random.rand() * naive.sum()
index = tree.get_prefix_sum_idx(scalar)
assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
+
# profile
if __name__ == '__main__':
size = 100000
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index e295f5da7..cb4a0a7fa 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -373,9 +373,15 @@ def __init__(self, size: int, alpha: float, beta: float, **kwargs) -> None:
self._max_prio = 1.
self._min_prio = 1.
# bypass the check
- self._meta.__dict__['weight'] = SegmentTree(size, 'sum')
+ self._weight = SegmentTree(size)
self.__eps = np.finfo(np.float32).eps.item()
+ def __getattr__(self, key: str) -> Union['Batch', Any]:
+ """Return self.key"""
+ if key == 'weight':
+ return self._weight
+ return self._meta.__dict__[key]
+
def add(self,
obs: Union[dict, np.ndarray],
act: Union[np.ndarray, float],
From 687ccbb607e87befb7cd21f7b5d1bc566283f7a5 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Tue, 4 Aug 2020 09:48:53 +0800
Subject: [PATCH 32/35] simplify random interval sampling in segtree test
---
test/base/test_buffer.py | 14 +++++---------
1 file changed, 5 insertions(+), 9 deletions(-)
diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index 9cfce29b3..3f732d872 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -144,10 +144,8 @@ def test_segtree():
tree[index] = value
assert np.allclose(realop(naive), tree.reduce())
for i in range(10):
- left = right = 0
- while left >= right:
- left = np.random.randint(actual_len)
- right = np.random.randint(actual_len)
+ left = np.random.randint(actual_len)
+ right = np.random.randint(left + 1, actual_len + 1)
assert np.allclose(realop(naive[left:right]),
tree.reduce(left, right))
# large test
@@ -161,10 +159,8 @@ def test_segtree():
tree[index] = value
assert np.allclose(realop(naive), tree.reduce())
for i in range(10):
- left = right = 0
- while left >= right:
- left = np.random.randint(actual_len)
- right = np.random.randint(actual_len)
+ left = np.random.randint(actual_len)
+ right = np.random.randint(left + 1, actual_len + 1)
assert np.allclose(realop(naive[left:right]),
tree.reduce(left, right))
@@ -219,10 +215,10 @@ def sample_tree():
if __name__ == '__main__':
- test_priortized_replaybuffer()
test_replaybuffer()
test_ignore_obs_next()
test_stack()
test_segtree()
+ test_priortized_replaybuffer()
test_priortized_replaybuffer(233333, 200000)
test_update()
From 316974fae6555e36b0cdf40b2f1e8c9d532a362f Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 4 Aug 2020 12:38:26 +0800
Subject: [PATCH 33/35] document the returned loss weight in sample docstring
---
tianshou/data/buffer.py | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index cb4a0a7fa..5c7a58ab6 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -404,6 +404,10 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
all the data in the buffer if batch_size is ``0``.
:return: Sample data and its corresponding index inside the buffer.
+
+ The ``weight`` in the returned Batch is the weight on loss function
+ to de-bias the sampling process (some transition tuples are sampled
+ more often so their losses are weighted less).
"""
assert self._size > 0, 'Cannot sample a buffer with 0 size!'
if batch_size == 0:
From 87bb1332d5bedf6806af157364eb1f03c926dcdc Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Tue, 4 Aug 2020 12:52:24 +0800
Subject: [PATCH 34/35] fix dqn local test
---
test/discrete/test_dqn.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py
index 2487c81d5..ae4c4ce0c 100644
--- a/test/discrete/test_dqn.py
+++ b/test/discrete/test_dqn.py
@@ -129,4 +129,4 @@ def test_pdqn(args=get_args()):
if __name__ == '__main__':
- test_pdqn(get_args())
+ test_dqn(get_args())
From eb307eb40676dd663ea737f9bbce74f86b65631e Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Wed, 5 Aug 2020 15:25:26 +0800
Subject: [PATCH 35/35] fix weight update in buffer.add
---
tianshou/data/buffer.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 5c7a58ab6..1d7a80f3c 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -395,8 +395,11 @@ def add(self,
"""Add a batch of data into replay buffer."""
if weight is None:
weight = self._max_prio
- weight = np.abs(weight) ** self._alpha
- self.weight[self._index] = weight
+ else:
+ weight = np.abs(weight)
+ self._max_prio = max(self._max_prio, weight)
+ self._min_prio = min(self._min_prio, weight)
+ self.weight[self._index] = weight ** self._alpha
super().add(obs, act, rew, done, obs_next, info, policy)
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: