Parametrize build system on CUDA major version #28968
Try rebasing on head as well, please. Sorry for the slow review.
This is a cherry-pick from jax-ml#28968. [ci skip]
Rebased again to resolve new conflicts.
There was a recent change that has an impact on this implementation too:
jax/jax_plugins/cuda/__init__.py
Lines 123 to 141 in 09d903f
```python
def _load_nvidia_libraries():
  """Attempts to load NVIDIA's libraries.

  We prefer the Python packages, if present. If not, we fall back to loading
  them from LD_LIBRARY_PATH. By loading the libraries here, later lookups will
  find these copies."""
  _load("cuda_runtime", ["libcudart.so.12"])
  # cuda_nvrtc isn't directly a dependency of JAX, but CUDNN appears to need it
  # and at least in CUDA 12.9 has RUNPATHs misconfigured to refer to
  # nvidia/nvrtc instead of nvidia/cuda_nvrtc.
  _load("cuda_nvrtc", ["libnvrtc.so.12"])
  _load("cublas", ["libcublas.so.12", "libcublasLt.so.12"])
  _load("nccl", ["libnccl.so.2"])
  _load("cuda_cupti", ["libcupti.so.12"])
  _load("cusparse", ["libcusparse.so.12"])
  _load("cusolver", ["libcusolver.so.11"])
  _load("cufft", ["libcufft.so.11"])
  _load("nvshmem", ["libnvshmem_host.so.3"])
  _load("cudnn", ["libcudnn.so.9"])
```
I think you need something similar, with a try-except, to load either the CUDA 12 or the CUDA 13 wheels.
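A minimal sketch of that try-except fallback, assuming the CUDA 13 runtime ships as `libcudart.so.13`. The names `load_first_available`, `cuda_major_from_soname`, and `CUDA_RUNTIME_CANDIDATES` are hypothetical and for illustration only. They are not part of JAX, and the real plugin loads through its own `_load` helper, whose behavior is not reproduced here. The sketch calls `ctypes.CDLL` directly, which raises `OSError` when a library cannot be found.

```python
import ctypes
from typing import Callable, List, Sequence


def load_first_available(
    sonames: Sequence[str],
    loader: Callable[[str], object] = ctypes.CDLL,
) -> str:
    """Tries each shared-library name in order and returns the first that loads.

    Raises OSError, listing every name tried, if none of them can be loaded.
    """
    errors: List[str] = []
    for soname in sonames:
        try:
            loader(soname)  # ctypes.CDLL raises OSError if the library is missing
            return soname
        except OSError as e:
            errors.append(f"{soname}: {e}")
    raise OSError("no candidate library could be loaded: " + "; ".join(errors))


def cuda_major_from_soname(soname: str) -> int:
    """Parses the trailing major version, e.g. 'libcudart.so.12' -> 12."""
    return int(soname.rsplit(".", 1)[1])


# Hypothetical candidate list (assumed sonames): prefer CUDA 13, fall back to CUDA 12.
CUDA_RUNTIME_CANDIDATES = ["libcudart.so.13", "libcudart.so.12"]


if __name__ == "__main__":
    try:
        loaded = load_first_available(CUDA_RUNTIME_CANDIDATES)
        print(f"loaded {loaded} (CUDA {cuda_major_from_soname(loaded)})")
    except OSError as e:
        print(e)
```

One possible design choice: once the runtime's major version is known, the remaining libraries could be loaded using sonames for that same major, so that CUDA 12 and CUDA 13 libraries are never mixed in one process. The exact CUDA 13 sonames for each library would need to be checked against the actual wheels.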
Thanks. I think this one can be addressed in a follow-up, if that's OK.
This is a cherry-pick from jax-ml#28968. [ci skip]
LGTM (I don't have approval permissions).
@olupton I reran CI with the latest changes; it looks like there is still an issue in the CUDA test and the linter.
Sorry, I forgot to mention two more targets that need their deps updated:
Done, sorry about that. Tried to fix the
I'm getting one additional linter error internally on jax.bzl. I'm going to try to patch it myself.