Improve PER #159
Trinkle23897
commented
Jul 23, 2020
- rewrite the previous PrioReplayBuffer code using a segment tree, and add a test for it
- enable all Q-learning algorithms to use PER
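For readers unfamiliar with the data structure, here is a minimal sketch of a sum segment tree supporting the two operations the profile code in this PR relies on (`reduce` and `get_prefix_sum_idx`). The class name `SumSegmentTree` and its internals are assumptions made for illustration; this is not the `SegmentTree` added by the PR, which also accepts array indices.

```python
import numpy as np


class SumSegmentTree:
    """Illustrative sum segment tree (hypothetical; not the PR's SegmentTree)."""

    def __init__(self, size):
        # Round the capacity up to a power of two so the tree is complete.
        bound = 1
        while bound < size:
            bound *= 2
        self._bound = bound
        # Node i has children 2i and 2i + 1; leaves live in [bound, 2 * bound).
        self._value = np.zeros(2 * bound)

    def __setitem__(self, index, value):
        # Write one leaf, then recompute the sums on the path up to the root.
        i = index + self._bound
        self._value[i] = value
        i //= 2
        while i >= 1:
            self._value[i] = self._value[2 * i] + self._value[2 * i + 1]
            i //= 2

    def __getitem__(self, index):
        return self._value[index + self._bound]

    def reduce(self):
        """Return the sum of all leaves (stored at the root)."""
        return self._value[1]

    def get_prefix_sum_idx(self, scalar):
        """Return the smallest i with sum(leaves[:i + 1]) > scalar.

        Assumes 0 <= scalar < reduce().
        """
        i = 1
        while i < self._bound:
            left = 2 * i
            if self._value[left] > scalar:
                i = left  # the target lies in the left subtree
            else:
                scalar -= self._value[left]  # skip the left subtree's mass
                i = left + 1
        return i - self._bound


# Usage: draw indices with probability proportional to priority.
tree = SumSegmentTree(4)
for i, p in enumerate([1.0, 2.0, 3.0, 4.0]):
    tree[i] = p
rng = np.random.default_rng(0)
sampled = [tree.get_prefix_sum_idx(u) for u in rng.random(5) * tree.reduce()]
```

Both a single-leaf update and a prefix-sum lookup touch one node per tree level, so each costs O(log n), whereas renormalizing a flat priority array costs O(n) per sample call.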
Codecov Report
@@ Coverage Diff @@
## dev #159 +/- ##
==========================================
+ Coverage 88.63% 89.50% +0.87%
==========================================
Files 38 38
Lines 2226 2278 +52
==========================================
+ Hits 1973 2039 +66
+ Misses 253 239 -14
This is only the constraint for
# profile
if __name__ == '__main__':
    size = 100000
    bsz = 64
    naive = np.random.rand(size)
    tree = SegmentTree(size)
    tree[np.arange(size)] = naive

    def sample_npbuf():
        return np.random.choice(size, bsz, p=naive / naive.sum())

    def sample_tree():
        scalar = np.random.rand(bsz) * tree.reduce()
        return tree.get_prefix_sum_idx(scalar)

    print('npbuf', timeit(sample_npbuf, setup=sample_npbuf, number=1000))
    print('tree', timeit(sample_tree, setup=sample_tree, number=1000))
Would it be better to make this a separate function?
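One possible reading of this suggestion, sketched under assumptions: move the timing logic into a function that takes the samplers as arguments, so it can be called from a test or a script. The names `profile` and `make_samplers` are hypothetical, and a `np.cumsum` plus `np.searchsorted` sampler stands in for the segment tree so that the sketch runs without the PR's code.

```python
from timeit import timeit

import numpy as np


def profile(samplers, number=1000):
    """Time each zero-argument sampler; return {name: seconds for `number` calls}."""
    results = {}
    for name, fn in samplers.items():
        fn()  # warm-up call, excluded from the measurement
        results[name] = timeit(fn, number=number)
    return results


def make_samplers(size=100000, bsz=64, seed=0):
    """Build two proportional samplers over the same random priorities."""
    rng = np.random.default_rng(seed)
    naive = rng.random(size)
    prob = naive / naive.sum()
    cumsum = np.cumsum(naive)  # stand-in for the segment tree (assumption)

    def sample_choice():
        return rng.choice(size, bsz, p=prob)

    def sample_cumsum():
        scalar = rng.random(bsz) * cumsum[-1]
        idx = np.searchsorted(cumsum, scalar, side="right")
        # Guard against a rounding edge case where scalar equals the total.
        return np.minimum(idx, size - 1)

    return {"choice": sample_choice, "cumsum": sample_cumsum}


if __name__ == "__main__":
    for name, seconds in profile(make_samplers()).items():
        print(name, seconds)
```

With the real tree, a third entry wrapping `tree.get_prefix_sum_idx` could be added to the dictionary without touching `profile`.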
# prio buffer update
if isinstance(buffer, PrioritizedReplayBuffer):
    batch.update_weight = buffer.update_weight
Mounting the function buffer.update_weight as a field of batch objects is a hack and should be avoided. Since it has been here for a while and this PR is already very large, I will open a new PR to deal with it.
The solution may be something like the following: add a function BasePolicy.update and a function BasePolicy.post_process_fn. The weight update into the buffer can then be done in BasePolicy.post_process_fn, and trainer functions only have to call BasePolicy.update.
def update(self, buffer, batch_size):
    batch, indices = buffer.sample(batch_size)
    self.process_fn(batch, buffer, indices)
    self.learn(batch)
    self.post_process_fn(batch, buffer, indices)
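To make the proposal concrete, here is a self-contained toy version of that flow. `ToyPrioBuffer` and `ToyPolicy` are hypothetical stand-ins written for this illustration, not Tianshou classes; the "TD error" is faked, and the priority formula `(|td| + eps) ** alpha` is an assumption borrowed from the usual PER formulation.

```python
import numpy as np


class ToyPrioBuffer:
    """Hypothetical stand-in for a prioritized buffer (not Tianshou's class)."""

    def __init__(self, data, alpha=0.6):
        self.data = np.asarray(data, dtype=float)
        self.alpha = alpha
        self.weight = np.ones(len(self.data))  # unnormalized priorities

    def sample(self, batch_size, rng):
        # Sample indices with probability proportional to priority.
        p = self.weight / self.weight.sum()
        indices = rng.choice(len(self.data), batch_size, p=p)
        return {"obs": self.data[indices]}, indices

    def update_weight(self, indices, td_error):
        # New priority from the magnitude of the TD error (assumed formula).
        self.weight[indices] = (np.abs(td_error) + 1e-6) ** self.alpha


class ToyPolicy:
    """Hypothetical policy showing the proposed update / post_process_fn split."""

    def process_fn(self, batch, buffer, indices):
        return batch  # no pre-processing in this toy example

    def learn(self, batch):
        # Fake "TD error" so the example has something to write back.
        batch["td_error"] = batch["obs"] - 1.0
        return batch

    def post_process_fn(self, batch, buffer, indices):
        # The buffer is updated here instead of via a method mounted on batch.
        if hasattr(buffer, "update_weight") and "td_error" in batch:
            buffer.update_weight(indices, batch["td_error"])

    def update(self, buffer, batch_size, rng):
        batch, indices = buffer.sample(batch_size, rng)
        batch = self.process_fn(batch, buffer, indices)
        self.learn(batch)
        self.post_process_fn(batch, buffer, indices)
        return indices


# Usage: the trainer only calls update(); the buffer weights change as a side effect.
buffer = ToyPrioBuffer([0.0, 1.0, 2.0, 3.0])
policy = ToyPolicy()
used = policy.update(buffer, batch_size=8, rng=np.random.default_rng(0))
```

The point of the split is that the trainer no longer needs to know whether the buffer is prioritized; that check lives in one place inside the policy.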
A nice PR to improve the efficiency of prioritized buffer!
@duburcqa It should be okay now; please take a look.