Parametrize build system on CUDA major version #28968

Merged: 1 commit merged on Jul 21, 2025

@olupton (Contributor) commented May 23, 2025

No description provided.

@olupton force-pushed the cuda-13-build-system branch from 3c921ab to ba7f731 on May 28, 2025 07:57
@hawkinsp (Collaborator) commented Jun 4, 2025

Try rebasing on head as well, please. Sorry for the slow review.

@olupton added a commit to olupton/jax that referenced this pull request on Jun 6, 2025:
This is a cherry-pick from jax-ml#28968. [ci skip]
@olupton force-pushed the cuda-13-build-system branch from ba7f731 to 60c975d on June 6, 2025 14:01
@olupton force-pushed the cuda-13-build-system branch from 60c975d to 647ff8c on June 16, 2025 14:17
@olupton (Contributor, Author) commented Jun 16, 2025

Rebased again to resolve new conflicts.

@ybaturina (Contributor) left a comment

There was a recent change that has an impact on this implementation too:

```python
# `_load(package, libraries)` is a helper defined elsewhere
# (not shown in this excerpt).
def _load_nvidia_libraries():
  """Attempts to load NVIDIA's libraries.

  We prefer the Python packages, if present. If not, we fall back to loading
  them from LD_LIBRARY_PATH. By loading the libraries here, later lookups will
  find these copies."""
  _load("cuda_runtime", ["libcudart.so.12"])
  # cuda_nvrtc isn't directly a dependency of JAX, but CUDNN appears to need it
  # and at least in CUDA 12.9 has RUNPATHs misconfigured to refer to
  # nvidia/nvrtc instead of nvidia/cuda_nvrtc.
  _load("cuda_nvrtc", ["libnvrtc.so.12"])
  _load("cublas", ["libcublas.so.12", "libcublasLt.so.12"])
  _load("nccl", ["libnccl.so.2"])
  _load("cuda_cupti", ["libcupti.so.12"])
  _load("cusparse", ["libcusparse.so.12"])
  _load("cusolver", ["libcusolver.so.11"])
  _load("cufft", ["libcufft.so.11"])
  _load("nvshmem", ["libnvshmem_host.so.3"])
  _load("cudnn", ["libcudnn.so.9"])
```

I think you need something similar, with try-except, to load either the CUDA 12 or the CUDA 13 wheels.

@olupton force-pushed the cuda-13-build-system branch from 647ff8c to 2817d38 on June 25, 2025 10:38
@olupton (Contributor, Author) commented Jun 25, 2025

> I think you need something similar with try-except to load either CUDA12 or CUDA13 wheels.

Thanks. I think this one can be addressed in a follow-up if that's OK.

@olupton force-pushed the cuda-13-build-system branch from 2817d38 to a5fecb4 on July 11, 2025 11:40
@olupton added a commit to olupton/jax that referenced this pull request on Jul 16, 2025
@MichaelHudgins self-requested a review on July 16, 2025 19:22
@olupton force-pushed the cuda-13-build-system branch from a5fecb4 to a147e50 on July 17, 2025 15:07
@ybaturina (Contributor) left a comment

LGTM (I don't have approval permissions).

@MichaelHudgins (Collaborator) commented

@olupton I reran CI with the latest changes; it looks like there is still an issue in the CUDA test and the linter.

@ybaturina (Contributor) commented

> @olupton reran ci with latest changes, looks like there is still an issue in the cuda test and linter

Sorry, I forgot to mention two more targets whose deps need updating: jax_cuda_plugin_wheel_size_test and jax_cuda_pjrt_wheel_size_test.

@olupton force-pushed the cuda-13-build-system branch from a147e50 to 8c76df5 on July 18, 2025 08:27
@olupton force-pushed the cuda-13-build-system branch from 8c76df5 to a224b30 on July 18, 2025 16:30
@olupton (Contributor, Author) commented Jul 18, 2025

Done, sorry about that. I also tried to fix the mypy linter failure.

@google-ml-butler (bot) added the kokoro:force-run and pull ready (Ready for copybara import and testing) labels on Jul 18, 2025
@MichaelHudgins (Collaborator) commented

I'm getting one additional linter error internally on jax.bzl. I'm going to try to patch it myself.

@copybara-service (bot) merged commit 0b2caf9 into jax-ml:main on Jul 21, 2025
23 checks passed
@olupton deleted the cuda-13-build-system branch on July 21, 2025 19:14
Labels: kokoro:force-run, pull ready (Ready for copybara import and testing)
4 participants