
Improve PER #159

Merged
Trinkle23897 merged 43 commits into thu-ml:dev from prio-buffer on Aug 6, 2020
Conversation

Trinkle23897 (Collaborator)

  1. Use a segment tree to rewrite the previous PrioReplayBuffer code and add tests (a minimal sketch of the idea is shown below).
  2. Enable all Q-learning algorithms to use PER.
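
For reference, here is a minimal sketch of the prefix-sum sampling idea behind the rewrite, assuming a simplified array-backed sum tree with a power-of-two size. It is an illustration only, not the actual tianshou SegmentTree (which also supports vectorized access):

import numpy as np

class SumTree:
    """Simplified sum tree: leaves hold priorities, node i sums its subtree."""

    def __init__(self, size):
        assert size > 0 and size & (size - 1) == 0, "power-of-two size in this sketch"
        self._size = size
        self._value = np.zeros(2 * size)  # leaves live at [size, 2 * size)

    def __setitem__(self, idx, value):
        assert value >= 0.0, "get_prefix_sum_idx requires non-negative priorities"
        idx += self._size
        self._value[idx] = value
        while idx > 1:  # propagate the new sum up to the root
            idx //= 2
            self._value[idx] = self._value[2 * idx] + self._value[2 * idx + 1]

    def reduce(self):
        return self._value[1]  # the root holds the total priority mass

    def get_prefix_sum_idx(self, scalar):
        """Find the smallest i with value[0] + ... + value[i] > scalar."""
        idx = 1
        while idx < self._size:  # descend; each step halves the search range
            left = 2 * idx
            if self._value[left] > scalar:
                idx = left  # scalar falls inside the left subtree
            else:
                scalar -= self._value[left]
                idx = left + 1
        return idx - self._size

# sampling proportional to priority: O(log n) per draw instead of O(n)
tree = SumTree(8)
for i, p in enumerate([0.1, 0.4, 0.0, 1.5, 0.2, 0.8, 0.0, 1.0]):
    tree[i] = p
scalars = np.random.rand(64) * tree.reduce()
indices = [tree.get_prefix_sum_idx(s) for s in scalars]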

@Trinkle23897 Trinkle23897 changed the title WIP: Prio Experience Replay WIP: Improve PER Jul 23, 2020
codecov-commenter commented Aug 1, 2020

Codecov Report

Merging #159 into dev will increase coverage by 0.87%.
The diff coverage is 94.59%.

Impacted file tree graph

@@            Coverage Diff             @@
##              dev     #159      +/-   ##
==========================================
+ Coverage   88.63%   89.50%   +0.87%     
==========================================
  Files          38       38              
  Lines        2226     2278      +52     
==========================================
+ Hits         1973     2039      +66     
+ Misses        253      239      -14     
Flag         Coverage Δ
#unittests   89.50% <94.59%> (+0.87%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files                        Coverage Δ
tianshou/data/utils/converter.py      87.23% <ø> (ø)
tianshou/policy/modelfree/ddpg.py     97.46% <75.00%> (-1.24%) ⬇️
tianshou/policy/modelfree/sac.py      85.41% <83.33%> (-0.61%) ⬇️
tianshou/policy/modelfree/td3.py      98.59% <83.33%> (-1.41%) ⬇️
tianshou/data/utils/segtree.py        94.82% <94.82%> (ø)
tianshou/data/__init__.py             100.00% <100.00%> (ø)
tianshou/data/buffer.py               96.62% <100.00%> (+4.03%) ⬆️
tianshou/policy/base.py               95.38% <100.00%> (+0.38%) ⬆️
tianshou/policy/modelfree/dqn.py      97.50% <100.00%> (-0.18%) ⬇️
tianshou/utils/__init__.py
... and 3 more

Continue to review full report at Codecov.

Legend:
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 99a1d40...0b8a316. Read the comment docs.

@Trinkle23897 Trinkle23897 changed the title WIP: Improve PER Improve PER Aug 2, 2020
duburcqa previously approved these changes Aug 2, 2020
Trinkle23897 (Collaborator, Author) commented Aug 2, 2020

Discussion: is it necessary to support the segment tree with min/max operators? By definition, a segment tree supports any associative binary operator, i.e. op(a, op(b, c)) = op(op(a, b), c). But the sum tree we use is not an exact segment tree: we additionally require the elements in the tree to be non-negative.

This constraint applies only to get_prefix_sum_idx; for the normal usage of a segment tree (computing op over value[l:r]), no such constraint exists.
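
To illustrate that point, the same array-backed layout supports a range reduce under any associative operator, with no sign restriction on the elements. A hypothetical min-tree sketch (not part of this PR):

import numpy as np

class MinTree:
    """Same layout as the sum tree, but reduces with min instead of +."""

    def __init__(self, size):
        assert size > 0 and size & (size - 1) == 0
        self._size = size
        self._value = np.full(2 * size, np.inf)  # inf is the identity for min

    def __setitem__(self, idx, value):
        # note: no non-negativity requirement, unlike get_prefix_sum_idx
        idx += self._size
        self._value[idx] = value
        while idx > 1:
            idx //= 2
            self._value[idx] = min(self._value[2 * idx], self._value[2 * idx + 1])

    def reduce(self, left, right):
        """min(value[left:right]) in O(log n), the classic iterative query."""
        result = np.inf
        left += self._size
        right += self._size
        while left < right:
            if left & 1:  # left is a right child: take it and move past it
                result = min(result, self._value[left])
                left += 1
            if right & 1:  # right is exclusive: step back onto the node
                right -= 1
                result = min(result, self._value[right])
            left //= 2
            right //= 2
        return result

tree = MinTree(8)
tree[3] = -2.5  # negative values are fine for min/max reduction
assert tree.reduce(0, 8) == -2.5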

Comment on lines +204 to +220
# profile: compare batched sampling via np.random.choice vs. the tree
if __name__ == '__main__':
    import numpy as np
    from timeit import timeit

    size = 100000
    bsz = 64
    naive = np.random.rand(size)
    tree = SegmentTree(size)
    tree[np.arange(size)] = naive

    def sample_npbuf():
        # O(n) per batch: np.random.choice walks the full probability vector
        return np.random.choice(size, bsz, p=naive / naive.sum())

    def sample_tree():
        # O(bsz * log n): one tree descent per sampled prefix-sum scalar
        scalar = np.random.rand(bsz) * tree.reduce()
        return tree.get_prefix_sum_idx(scalar)

    print('npbuf', timeit(sample_npbuf, setup=sample_npbuf, number=1000))
    print('tree', timeit(sample_tree, setup=sample_tree, number=1000))
Collaborator

Would it be better to make this a separate function?

Comment on lines +217 to +219
# prio buffer update
if isinstance(buffer, PrioritizedReplayBuffer):
batch.update_weight = buffer.update_weight
Collaborator

Attaching the function buffer.update_weight as a field of the batch object is a hack and should be avoided. Since it has been here for a while and this PR is already very large, I will open a new PR to deal with it.

The solution may be something like the following: add two methods, BasePolicy.update and BasePolicy.post_process_fn. Writing the updated weights back into the buffer can then be done in BasePolicy.post_process_fn, and the trainer functions just have to call BasePolicy.update.

def update(self, buffer, batch_size):
    # sample a batch together with the buffer indices it came from
    batch, indices = buffer.sample(batch_size)
    # pre-processing, e.g. computing the n-step return
    batch = self.process_fn(batch, buffer, indices)
    self.learn(batch)
    # post-processing, e.g. writing new priorities back into a PER buffer
    self.post_process_fn(batch, buffer, indices)
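
Under this proposal, the PER weight update could live in the post-processing hook rather than on the batch. A hypothetical sketch (the batch.td field holding the freshly computed TD errors is an assumed name, not an existing attribute):

def post_process_fn(self, batch, buffer, indices):
    # hypothetical: only PER buffers need their priorities refreshed
    if isinstance(buffer, PrioritizedReplayBuffer):
        # assumes learn() stored the new TD errors on the batch as batch.td
        buffer.update_weight(indices, batch.td)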

youkaichao previously approved these changes Aug 4, 2020
youkaichao (Collaborator) left a comment

A nice PR to improve the efficiency of prioritized buffer!

@Trinkle23897 Trinkle23897 requested a review from duburcqa August 4, 2020 04:50
Trinkle23897 (Collaborator, Author)

@duburcqa it should be okay now, please have a check.

@Trinkle23897 Trinkle23897 merged commit 140b1c2 into thu-ml:dev Aug 6, 2020
@Trinkle23897 Trinkle23897 deleted the prio-buffer branch August 6, 2020 02:27
BFAnas pushed a commit to BFAnas/tianshou that referenced this pull request May 5, 2024
- Use a segment tree to rewrite the previous PrioReplayBuffer code and add tests

- Enable all Q-learning algorithms to use PER