Enable pivoted QR on GPU via MAGMA. #25955

Merged: 1 commit, Feb 12, 2025

@tttc3 tttc3 commented Jan 17, 2025

Originally noted in #20282, this commit provides a GPU-compatible implementation of geqp3 via MAGMA.
The MAGMA implementation is based on @dfm's implementation of eig in ccb3317.

Maybe closes #12897?

To reduce code duplication I've moved AllocateWorkspace from solver_kernels_ffi.cc into ffi_helpers.h.
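
For context, geqp3 is the LAPACK routine for QR with column pivoting: it factors a matrix so that `A[:, P] = Q @ R` for a column permutation `P`. The sketch below illustrates that contract on the CPU with SciPy's `scipy.linalg.qr(..., pivoting=True)`. It is not the MAGMA code path added in this PR, and the matrix `A` is an arbitrary example chosen for illustration.

```python
import numpy as np
from scipy.linalg import qr

# Arbitrary example matrix (an assumption for illustration only).
rng = np.random.default_rng(0)
A = rng.standard_normal((4, 3))

# Column-pivoted QR: the columns of A, reordered by P, equal Q @ R.
Q, R, P = qr(A, mode="economic", pivoting=True)
pivoted_ok = np.allclose(A[:, P], Q @ R)

# Undoing the pivoting needs the inverse permutation of P.
inverse_P = np.argsort(P)
restored_ok = np.allclose((Q @ R)[:, inverse_P], A)
```

Both checks should hold up to floating-point tolerance; the second one relies on the inverse permutation, which is the detail that comes up later in this thread.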

@dfm dfm self-assigned this Jan 17, 2025
@tttc3 tttc3 force-pushed the magma_qr branch 2 times, most recently from 47dc68d to ac4aeb0 Compare January 18, 2025 10:07

@dfm dfm left a comment

Thanks! I've approved so that we can run all the tests, but also left some small inline comments.

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Jan 27, 2025

dfm commented Jan 27, 2025

The TPU failures are unrelated, but can you rebase onto the current main branch after making your edits to fix them? Thanks!

@dfm dfm left a comment

Thanks for all the work on this. I think we're there! One tiny last comment. Then can you rebase onto main (there's a merge conflict with the CHANGELOG)? Then, I can merge. Thanks again!

tttc3 commented Jan 29, 2025

MAGMA 2.9.0 was released yesterday and now supports a workspace query for lwork (in addition to an expert interface to geqp3_gpu). The workspace requirements appear unchanged (marked backward compatible) so we can continue to use the manual calculation implemented here (with the added benefit that it will work for MAGMA versions below 2.9.0).

Would you prefer I update this pull to make a workspace query and require magma>=2.9.0 or are you happy to leave it as it stands?

dfm commented Jan 29, 2025

> Would you prefer I update this pull to make a workspace query and require magma>=2.9.0 or are you happy to leave it as it stands?

Let's leave this as is!

So, we're getting quite a few JVP failures in testScipyQrModes with complex64 dtypes in our internal CI, and I can reproduce those when I run the tests myself. The errors seem too large to be just numerics. Can you try running those tests yourself to see if you can reproduce and debug?

tttc3 commented Jan 29, 2025

> So, we're getting quite a few JVP failures in testScipyQrModes with complex64 dtypes in our internal CI, and I can reproduce those when I run myself. The errors seem large enough that they don't seem to just be numerics. Can you try running those tests yourself to see if you can reproduce and debug?

I can reproduce these errors. I'm pretty sure I implemented the pivot inversion in qr_and_mul incorrectly. Setting inverted_pivots = jnp.argsort(p[0]) instead of inverted_pivots = p[0][p[0]] should fix any JVP-specific issues (I verified this by checking that the output of qr_and_mul is indeed the identity, which it turns out it wasn't before).
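
A minimal sketch of why the two expressions differ, using NumPy and a made-up permutation `p` (not taken from the test suite; `jnp.argsort` should behave the same way for this purpose). Indexing a permutation with itself composes it with itself, which matches the inverse only in special cases, whereas `argsort` returns the inverse permutation in general.

```python
import numpy as np

# Hypothetical pivot vector, chosen only for illustration.
p = np.array([2, 0, 3, 1])
x = np.array([10, 20, 30, 40])

# argsort yields the inverse permutation: applying it after p restores x.
inverse_p = np.argsort(p)
round_trip_argsort = x[p][inverse_p]

# p[p] is p composed with itself, not the inverse, so x is not restored.
composed_p = p[p]
round_trip_composed = x[p][composed_p]
```

Here `round_trip_argsort` equals `x`, while `round_trip_composed` does not, which is consistent with the product in qr_and_mul not reconstructing the input before the fix.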

tttc3 commented Feb 3, 2025

@dfm does this change rectify the internal test failures for you?

dfm commented Feb 3, 2025

Hmm... I'm still seeing issues. Let me try and debug a bit and report back!

tttc3 commented Feb 10, 2025

@dfm did you have a chance to look into this? I'm happy to look further if you could share details of the test failures.

dfm commented Feb 11, 2025

Whoops! I dropped the ball on this last week. Could you rebase onto main and fix the merge conflicts? I'll get to the bottom of this today!

tttc3 commented Feb 11, 2025

Looks like I made a mistake in the merge, sorry! I've just fixed it and simplified the GPU lowering to use the new _lapack_ffi_lowering function.

dfm commented Feb 11, 2025

Thanks for the updates here! Here are the failures that I'm consistently finding:

ScipyLinalgTest.testScipyQrModes57 (shape=(4, 3), dtype=<class 'numpy.complex64'>, mode='economic', pivoting=True)
ScipyLinalgTest.testScipyQrModes61 (shape=(3, 3), dtype=<class 'numpy.complex64'>, mode='full', pivoting=True)
ScipyLinalgTest.testScipyQrModes62 (shape=(3, 3), dtype=<class 'numpy.complex64'>, mode='economic', pivoting=True)

And they all fail when checking the JVP with large discrepancies that don't seem to be just caused by numerical noise. In all these tests we're not using MAGMA, so the issue seems to be with the host kernel.

Can you take a look at these and see if you can reproduce them?

tttc3 commented Feb 12, 2025

I cannot reproduce these test failures on a CUDA device. Do the tests pass if you manually transfer the data to the host, before calling jtu.check_jvp? This will at least tell us if the issue is with the entire host implementation or just the GPU handoff to and from the host (although I don't see any errors when doing this either).

As a sanity check, can you check that qr_and_mul uses the corrected pivoting inversion in the internal tests? It might also be worth seeing if you can reproduce the test failures when using the custom build arguments given below?

Key information about the test machine that may be useful:

Host Device Microarchitecture: Haswell
CUDA Device Microarchitecture: Turing
CUDA Driver Version: 535.161.08
CUDA Version: 12.2 

Python Version: 3.13.2
Scipy Version: 1.15.1
BLAS/LAPACK: scipy-openblas 0.3.28

Custom arguments to `build.py` (required for successful compilation on the machine):
--bazel_options='--define xnn_enable_avxvnni=false --define xnn_enable_avx512amx=false --define xnn_enable_avx512fp16=false'

@@ -1684,9 +1684,11 @@ def osp_fun(A):
pivoting=[False, True]
)
def testScipyQrModes(self, shape, dtype, mode, pivoting):
Contributor

I figured it out! We need to add:

@jax.default_matmul_precision("float32")

Otherwise we lose precision on some hardware!
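
As a rough illustration of the setting mentioned above, `jax.default_matmul_precision` can also be used as a context manager around the matrix products whose accuracy matters. The sketch below assumes JAX is installed; the input `x` is an arbitrary example, and on CPU the setting is unlikely to change the result, since reduced default matmul precision is mainly a concern on accelerators.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Arbitrary float32 input (an assumption for illustration only).
x = jnp.arange(12, dtype=jnp.float32).reshape(4, 3)

# Request float32 matmul precision for everything computed in this scope.
with jax.default_matmul_precision("float32"):
    gram = jnp.matmul(x.T, x)

# Reference value computed in float64 with NumPy for comparison.
x64 = np.arange(12, dtype=np.float64).reshape(4, 3)
expected = x64.T @ x64
matches_reference = np.allclose(np.asarray(gram), expected)
```

In the test itself the same setting is applied as a decorator on the test method, so every matmul inside the test runs at the requested precision.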

Contributor Author

Nice one! I've added this in and fixed the merge conflict.

dfm commented Feb 12, 2025

Can you add that matmul precision decorator then fix the merge conflict? I'm confident that we'll be good to go after this!

Originally noted in jax-ml#20282, this commit provides a GPU compatible
implementation of `geqp3` via MAGMA.

dfm commented Feb 12, 2025

All green now - should be merged soon. Thanks for all your work on this!

@copybara-service copybara-service bot merged commit f7e2901 into jax-ml:main Feb 12, 2025
21 checks passed
@tttc3 tttc3 deleted the magma_qr branch June 16, 2025 17:32
Development

Successfully merging this pull request may close these issues.

Column-Pivoted QR Decomposition