A simple suggestion based on what you already have: You can try using the |
Hi! First, thanks for this great package!
I use JAX heavily for my research, for which I implemented ray tracing utilities.
One critical operation among many is the ray-triangle intersection test. We usually have millions of rays and millions of triangles, and for each ray we would like to test whether it intersects any triangle. A common approach is the Möller-Trumbore algorithm, which I implemented using JAX.
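For reference, a minimal JAX sketch of a Möller-Trumbore any-hit test could look like the following. It assumes rays are given as `(N, 3)` origin and direction arrays and triangles as a `(T, 3, 3)` array of vertices; `ray_triangle_hit` and `any_hit` are hypothetical names, not the code from this post. The nested `vmap` is the dense rays × triangles formulation, which is exactly where memory pressure can appear if XLA does not fuse the `any` reduction into the intersection math.

```python
import jax
import jax.numpy as jnp


def ray_triangle_hit(origin, direction, v0, v1, v2, eps=1e-8):
    """Möller-Trumbore test for a single ray and a single triangle.

    Returns (hit, t): whether the ray hits the triangle in front of its
    origin, and the distance t along `direction` (meaningless if no hit).
    """
    edge1 = v1 - v0
    edge2 = v2 - v0
    pvec = jnp.cross(direction, edge2)
    det = jnp.dot(edge1, pvec)
    parallel = jnp.abs(det) < eps  # ray (nearly) parallel to the triangle plane
    # Guard the division; parallel cases are masked out of `hit` below.
    inv_det = 1.0 / jnp.where(parallel, 1.0, det)
    tvec = origin - v0
    u = jnp.dot(tvec, pvec) * inv_det  # first barycentric coordinate
    qvec = jnp.cross(tvec, edge1)
    v = jnp.dot(direction, qvec) * inv_det  # second barycentric coordinate
    t = jnp.dot(edge2, qvec) * inv_det  # distance along the ray
    hit = (~parallel) & (u >= 0.0) & (v >= 0.0) & (u + v <= 1.0) & (t > eps)
    return hit, t


@jax.jit
def any_hit(origins, directions, triangles):
    """For each ray, True if it hits at least one triangle.

    origins, directions: (N, 3); triangles: (T, 3, 3) indexed as (vertex, xyz).
    """
    # Inner vmap: one ray against every triangle.
    test_all = jax.vmap(ray_triangle_hit, in_axes=(None, None, 0, 0, 0))

    def per_ray(o, d):
        hits, _ = test_all(o, d, triangles[:, 0], triangles[:, 1], triangles[:, 2])
        return jnp.any(hits)

    # Outer vmap: every ray.
    return jax.vmap(per_ray)(origins, directions)
```

For example, with the unit triangle `(0,0,0), (1,0,0), (0,1,0)`, a ray from `(0.2, 0.2, -1)` along `+z` hits it at `t = 1`, while rays starting at `(2, 2, -1)` or pointing away from the triangle do not.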
However, I am struggling to improve the performance of this function, and I am reaching out to see if anyone has any suggestions. I am considering moving to Mitsuba (or using DrJit directly) or implementing a C++ (or Rust) version, but I hope I can find a way to improve the code without leaving JAX.
After reading some related discussions and issues (https://stackoverflow.com/questions/77527847/jax-vmap-limit-memory, #11319, https://stackoverflow.com/questions/77659069/how-to-understand-and-debug-memory-usage-with-jax), I realized that JIT compilation could optimize away the need for a large allocation.
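One way to bound peak memory regardless of what XLA manages to fuse, in the spirit of the linked "vmap limit memory" question, is to process the rays in fixed-size chunks with `jax.lax.map`, which loops over its leading axis sequentially instead of vectorizing it. The helper below is a hypothetical sketch (`map_in_chunks` is not a JAX API); it assumes `fn` maps a `(chunk_size, ...)` batch to a `(chunk_size, ...)` result and that `chunk_size` is a static Python int.

```python
import jax
import jax.numpy as jnp


def map_in_chunks(fn, xs, chunk_size):
    """Apply `fn` to consecutive chunks of `xs` along axis 0, one chunk at a time.

    Hypothetical helper: `fn` must map an array of shape (chunk_size, ...) to
    one of shape (chunk_size, ...). Because `jax.lax.map` runs the chunks
    sequentially, only one chunk's intermediates need to be live at once.
    """
    n = xs.shape[0]
    n_chunks = -(-n // chunk_size)  # ceiling division
    pad = n_chunks * chunk_size - n
    # Pad by repeating the last row so `fn` only sees valid inputs;
    # the padded results are sliced off at the end.
    padded = jnp.pad(xs, [(0, pad)] + [(0, 0)] * (xs.ndim - 1), mode="edge")
    chunks = padded.reshape((n_chunks, chunk_size) + xs.shape[1:])
    out = jax.lax.map(fn, chunks)  # shape (n_chunks, chunk_size, ...)
    return out.reshape((n_chunks * chunk_size,) + out.shape[2:])[:n]
```

For the any-hit case, each ray could be packed as a `(2, 3)` array `[origin, direction]` so that `xs` has shape `(N, 2, 3)`, with `fn` vmapping a per-ray test over `rays[:, 0]` and `rays[:, 1]`. The dense intermediates then scale with `chunk_size × T` rather than `N × T`, at the cost of less parallelism per step. Recent JAX releases may also accept a `batch_size` argument to `jax.lax.map` that performs similar chunking; check the documentation for your installed version.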
Here is the code to reproduce (`uv run file.py` works):

FYI, I have an NVIDIA RTX 3070 with 8192 MiB of memory. That is not enormous, but it should be enough, since the computation should not need to allocate a matrix too large to fit in it.
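A quick back-of-the-envelope check shows why the dense formulation cannot fit if XLA does materialize it: one float32 value per ray-triangle pair for 10^6 rays and 10^6 triangles is about 3.6 TiB, far beyond 8 GiB. The sketch below also estimates how many rays per chunk would fit in a memory budget. The helper names, the float32 element size, the half-of-GPU budget, and `live_arrays=4` (a guess at how many per-pair intermediates such as `det`, `u`, `v`, `t` coexist) are all assumptions for illustration.

```python
def dense_bytes(n_rays, n_triangles, bytes_per_elem=4):
    """Bytes needed to materialize one dense (n_rays, n_triangles) array."""
    return n_rays * n_triangles * bytes_per_elem


def max_ray_chunk(n_triangles, budget_bytes, bytes_per_elem=4, live_arrays=1):
    """Largest number of rays whose (chunk, n_triangles) intermediates fit in budget.

    `live_arrays` is an assumed count of such arrays alive at the same time.
    """
    return budget_bytes // (n_triangles * bytes_per_elem * live_arrays)


gpu_bytes = 8192 * 2**20          # 8192 MiB = 8 GiB
full = dense_bytes(10**6, 10**6)  # one float32 value per ray-triangle pair
print(f"{full / 2**40:.2f} TiB needed vs {gpu_bytes / 2**30:.0f} GiB available")
# -> 3.64 TiB needed vs 8 GiB available
print(max_ray_chunk(10**6, gpu_bytes // 2, live_arrays=4))
# -> 268
```

Under these assumptions, chunks of a few hundred rays against all 10^6 triangles would stay within half of the GPU's memory.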
N.B.: I am aware that I could use a hierarchical representation of the scene to accelerate the intersection check, but I couldn't find any JAX-compatible library for that, so I would prefer to avoid this solution :'-)
Thanks in advance for your help!