Enable pivoted QR on GPU via MAGMA. #25955
Thanks! I've approved so that we can run all the tests, but also left some small inline comments.
The TPU failures are unrelated, but can you rebase onto the current
Thanks for all the work on this. I think we're there! One tiny last comment. Then can you rebase onto main
(there's a merge conflict with the CHANGELOG)? Then, I can merge. Thanks again!
MAGMA 2.9.0 was released yesterday and now supports a workspace query for
Would you prefer I update this pull to make a workspace query and require
Let's leave this as is!
So, we're getting quite a few JVP failures in
I can reproduce these errors. I'm pretty sure I've implemented the pivot inversion in
@dfm does this change rectify the internal test failures for you?
Hmm... I'm still seeing issues. Let me try and debug a bit and report back!
@dfm did you have a chance to look into this? I'm happy to look further if you could share details of the test failures.
Whoops! I dropped the ball on this last week. Could you rebase onto main and fix the merge conflicts? I'll get to the bottom of this today!
Looks like I made a mistake in the merge, sorry! I've just fixed it and simplified the gpu lowering to use the new
Thanks for the updates here! Here are the failures that I'm consistently finding:
And they all fail when checking the JVP, with large discrepancies that don't seem to be caused by numerical noise alone. In all these tests we're not using MAGMA, so the issue seems to be with the host kernel. Can you take a look at these and see if you can reproduce them?
I cannot reproduce these test failures on a CUDA device. Do the tests pass if you manually transfer the data to the host, before calling
As a sanity check, can you check that
Key information about the test machine that may be useful:
Host Device Microarchitecture: Haswell
CUDA Device Microarchitecture: Turing
CUDA Driver Version: 535.161.08
CUDA Version: 12.2
Python Version: 3.13.2
Scipy Version: 1.15.1
BLAS/LAPACK: scipy-openblas 0.3.28
Custom arguments to `build.py` (required for successful compilation on the machine):
--bazel_options='--define xnn_enable_avxvnni=false --define xnn_enable_avx512amx=false --define xnn_enable_avx512fp16=false'
@@ -1684,9 +1684,11 @@ def osp_fun(A):
pivoting=[False, True]
)
def testScipyQrModes(self, shape, dtype, mode, pivoting):
I figured it out! We need to add:
@jax.default_matmul_precision("float32")
Otherwise we lose precision on some hardware!
Nice one! I've added this in and fixed the merge conflict.
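For readers following the thread, here is a minimal sketch of how the decorator suggested above can be applied. It is not the test from this pull request: the function `reconstruction_error`, the matrix shape, and the tolerance are hypothetical and chosen only for illustration. It uses unpivoted `jnp.linalg.qr` so that it does not depend on the feature under review.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical helper (not part of the PR). The decorator sets the default
# matmul precision to float32 for everything executed inside the function,
# so the q @ r product below is not computed at a reduced precision on
# hardware that would otherwise pick one.
@jax.default_matmul_precision("float32")
def reconstruction_error(a):
    q, r = jnp.linalg.qr(a)  # unpivoted, reduced QR
    return float(jnp.max(jnp.abs(q @ r - a)))


# Illustrative input: a small float32 matrix with a fixed seed.
a = jnp.asarray(
    np.random.default_rng(0).standard_normal((8, 5)), dtype=jnp.float32
)
err = reconstruction_error(a)
print(err)
```

The same object also works as a context manager (`with jax.default_matmul_precision("float32"): ...`), which may be preferable when only part of a test needs the higher precision.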
Can you add that matmul precision decorator and then fix the merge conflict? I'm confident that we'll be good to go after this!
Originally noted in jax-ml#20282, this commit provides a GPU compatible implementation of `geqp3` via MAGMA.
All green now, so this should be merged soon. Thanks for all your work on this!
Originally noted in #20282, this commit provides a GPU compatible implementation of `geqp3` via MAGMA.
MAGMA implementation is based on @dfm's implementation of `eig` in ccb3317.
Maybe closes #12897?
To reduce code duplication I've moved `AllocateWorkspace` from `solver_kernels_ffi.cc` into `ffi_helpers.h`.
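As background on what a pivoted QR routine such as `geqp3` returns, the sketch below uses SciPy's CPU implementation, `scipy.linalg.qr(..., pivoting=True)`, rather than the GPU path added in this pull request. The matrix shape and seed are arbitrary assumptions. It checks the defining relation `A[:, p] = Q R` and shows that the inverse of the pivot permutation restores the original column order.

```python
import numpy as np
from scipy.linalg import qr

# Illustrative input only; shape and seed are arbitrary.
rng = np.random.default_rng(0)
a = rng.standard_normal((6, 4))

# Column-pivoted QR: p holds the column permutation chosen by the routine.
q, r, p = qr(a, mode="economic", pivoting=True)

# Defining relation of pivoted QR: the permuted columns of A equal Q @ R.
recon_ok = np.allclose(a[:, p], q @ r)

# Inverting the permutation (argsort of p) undoes the column reordering.
inv_p = np.argsort(p)
undo_ok = np.allclose((q @ r)[:, inv_p], a)

print(recon_ok, undo_ok)
```

Any code that differentiates through or post-processes the factorization has to account for `p` in this way, since `Q @ R` alone reproduces the permuted matrix, not `A`.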