Adding a hybrid GPU kernel for the Schur decomposition #30288
Unanswered · sjvenditto asked this question in Ideas
Hello,
I recently ran into the issue outlined in #28927 when trying to run the Schur decomposition on the GPU.
I noticed that other linear algebra methods have a hybrid GPU kernel that calls the corresponding LAPACK routine on the host CPU (e.g. `jax.lax.linalg.eig()`), so I went ahead and added a similarly structured kernel for Schur, which avoids doing the device-to-host exchange on the Python side. I didn't add a MAGMA implementation because MAGMA doesn't appear to provide the `gees` routine yet, but I've structured the kernel like `eig` so that one can be added in the future.

All the changes I've made are in my local fork here: https://github.com/sjvenditto/jax/tree/schur_hybrid_gpu
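For reference, here's a minimal sketch of the call this enables (assuming a CUDA backend is the default device; per #28927, this currently only works on CPU without a kernel like this):

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import schur

# Random real test matrix; with a GPU backend this lives on the device.
a = jax.random.normal(jax.random.key(0), (4, 4))

# Real Schur decomposition a = Z @ T @ Z^T, with T quasi-upper triangular.
# The hybrid kernel does the device-to-host round trip to LAPACK internally,
# instead of requiring it in user code.
T, Z = schur(a)

# Sanity check: the factors should reconstruct the input.
assert jnp.allclose(Z @ T @ Z.T, a, atol=1e-5)
```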
I updated the existing Schur tests (and some tests that call `schur`, like `funm` and `sqrtm`) to run on the GPU, and they are all passing locally (using CUDA 12); a quick check of one of those downstream functions is sketched below.
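A small sanity check of one of those downstream functions (a sketch; the values are illustrative):

```python
import jax.numpy as jnp
from jax.scipy.linalg import sqrtm

# Diagonal SPD matrix with known square root diag(2, 3).
a = jnp.array([[4.0, 0.0],
               [0.0, 9.0]])

s = sqrtm(a)  # internally goes through the Schur form

# s @ s should reconstruct a (the result may come back complex-valued).
assert jnp.allclose(s @ s, a, atol=1e-5)
```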
Is this something that would be useful as a PR?

Thanks!