这是indexloc提供的服务,不要输入任何密码
Skip to content

Add Bazel TPU tests in JAX presubmit and postsubmit. #2163

Add Bazel TPU tests in JAX presubmit and postsubmit.

Add Bazel TPU tests in JAX presubmit and postsubmit. #2163

# CI - Wheel Tests (Continuous)
#
# This workflow builds the jax and jaxlib artifacts and runs the TPU tests. The CPU/CUDA build
# and test jobs described below are currently commented out.
#
# It orchestrates the following:
# 1. build-jaxlib-artifact: Calls the `build_artifacts.yml` workflow to build jaxlib and
# uploads it to a GCS bucket.
# 2. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the jaxlib wheel
# that was built in the previous step and runs CPU tests.
# 3. build-cuda-artifacts: Calls the `build_artifacts.yml` workflow to build CUDA artifacts and
# uploads them to a GCS bucket.
# 4. run-bazel-test-cpu-py-import: Calls the `bazel_cpu_py_import_rbe.yml` workflow which
# runs Bazel CPU tests with py_import on RBE.
# 5. run-bazel-test-cuda-py-import: Calls the `bazel_cuda_non_rbe_py_import.yml` workflow which
# runs Bazel CUDA tests with py_import on non-RBE.
# 6. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA
# artifacts that were built in the previous steps and runs the CUDA tests.
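# 7. run-bazel-test-cuda: Calls the `bazel_cuda_non_rbe.yml` workflow which downloads the jaxlib
# and CUDA artifacts that were built in the previous steps and runs the
# CUDA tests using Bazel.
# 8. run-pytest-tpu: Calls the `pytest_tpu.yml` workflow, passing it the GCS URI reported by the
# jaxlib build job, for each TPU type / libtpu version combination in its matrix.
# 9. run-bazel-test-tpu: Calls the `bazel_test_tpu.yml` workflow, passing it the GCS URI reported
# by the jaxlib build job, for each TPU type / libtpu version combination in its matrix.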
# Workflow name as displayed in the GitHub Actions UI.
name: "CI - Wheel Tests (Continuous)"

# Alternative schedule-based configuration, currently disabled:
# permissions:
#   contents: read
# on:
#   schedule:
#     - cron: "0 */3 * * *" # Run once every 3 hours
#   workflow_dispatch: # allows triggering the workflow run manually
# concurrency:
#   group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
#   cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'main' }}

# Run on pushes to main and on pull requests that target main.
on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

# No token permissions are granted at the workflow level.
permissions: {}

concurrency:
  # One group per workflow and PR head branch (or the pushed ref when head_ref is empty).
  group: "${{ github.workflow }}-${{ github.head_ref || github.ref }}"
  # A newer run in the same group cancels the run that is still in progress.
  cancel-in-progress: true
jobs:
build-jax-artifact:
  # Calls the reusable build_artifacts.yml workflow for the "jax" artifact.
  name: 'Build jax artifact'
  uses: ./.github/workflows/build_artifacts.yml
  with:
    artifact: 'jax'
    # jax is a pure Python package, so the runner OS and Python version do not matter here;
    # cloning main XLA has no effect on it either.
    runner: 'linux-x86-n2-16'
    upload_artifacts_to_gcs: true
    # Destination is keyed by workflow name, run number and run attempt.
    gcs_upload_uri: "gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}"
build-jaxlib-artifact:
  # Calls the reusable build_artifacts.yml workflow once per matrix combination to build jaxlib.
  #
  # Note: For reasons unknown, GitHub Actions groups jobs with the same top-level name in the
  # dashboard only if the "name" field contains an expression. Otherwise it appends the matrix
  # values to the name and creates a separate entry for each matrix combination.
  name: "Build ${{ format('{0}', 'jaxlib') }} artifacts"
  uses: ./.github/workflows/build_artifacts.yml
  strategy:
    fail-fast: false  # don't cancel all jobs on failure
    matrix:
      # Runner OS and Python values need to match the matrix strategy in the CPU tests job
      runner:
        - 'linux-x86-n2-16'
      artifact:
        - 'jaxlib'
      python:
        - '3.11'
  with:
    artifact: ${{ matrix.artifact }}
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    clone_main_xla: 1
    upload_artifacts_to_gcs: true
    # Destination is keyed by workflow name, run number and run attempt.
    gcs_upload_uri: "gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}"
# build-cuda-artifacts:
# uses: ./.github/workflows/build_artifacts.yml
# strategy:
# fail-fast: false # don't cancel all jobs on failure
# matrix:
# # Python values need to match the matrix strategy in the CUDA tests job below
# runner: ["linux-x86-n2-16"]
# artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"]
# python: ["3.11",]
# name: "Build ${{ format('{0}', 'CUDA') }} artifacts"
# with:
# runner: ${{ matrix.runner }}
# artifact: ${{ matrix.artifact }}
# python: ${{ matrix.python }}
# clone_main_xla: 1
# upload_artifacts_to_gcs: true
# gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
# run-pytest-cpu:
# # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
# # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# # still want to run the tests for other platforms.
# if: ${{ !cancelled() }}
# needs: [build-jax-artifact, build-jaxlib-artifact]
# uses: ./.github/workflows/pytest_cpu.yml
# strategy:
# fail-fast: false # don't cancel all jobs on failure
# matrix:
# # Runner OS and Python values need to match the matrix strategy in the
# # build_jaxlib_artifact job above
# runner: ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"]
# python: ["3.11",]
# enable-x64: [1, 0]
# name: "Pytest CPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
# with:
# runner: ${{ matrix.runner }}
# python: ${{ matrix.python }}
# enable-x64: ${{ matrix.enable-x64 }}
# gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
# run-pytest-cuda:
# # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
# # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# # still want to run the tests for other platforms.
# if: ${{ !cancelled() }}
# needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
# uses: ./.github/workflows/pytest_cuda.yml
# strategy:
# fail-fast: false # don't cancel all jobs on failure
# matrix:
# # Python values need to match the matrix strategy in the artifact build jobs above
# # See exclusions for what is fully tested
# runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu", "linux-x86-a4-224-b200-1gpu"]
# python: ["3.11",]
# cuda: [
# {version: "12.1", use-nvidia-pip-wheels: false},
# {version: "12.8", use-nvidia-pip-wheels: true},
# ]
# enable-x64: [1, 0]
# exclude:
# # H100 runs only a single config, CUDA 12.8 Enable x64 1
# - runner: "linux-x86-a3-8g-h100-8gpu"
# cuda:
# version: "12.1"
# - runner: "linux-x86-a3-8g-h100-8gpu"
# enable-x64: "0"
# # B200 runs only a single config, CUDA 12.8 Enable x64 1
# - runner: "linux-x86-a4-224-b200-1gpu"
# cuda:
# version: "12.1"
# - runner: "linux-x86-a4-224-b200-1gpu"
# enable-x64: "0"
# name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }}, CUDA Pip packages = ${{ matrix.cuda.use-nvidia-pip-wheels }})"
# with:
# runner: ${{ matrix.runner }}
# python: ${{ matrix.python }}
# cuda-version: ${{ matrix.cuda.version }}
# use-nvidia-pip-wheels: ${{ matrix.cuda.use-nvidia-pip-wheels }}
# enable-x64: ${{ matrix.enable-x64 }}
# # GCS upload URI is the same for both artifact build jobs
# gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
# run-bazel-test-cpu-py-import:
# uses: ./.github/workflows/bazel_cpu_py_import_rbe.yml
# strategy:
# fail-fast: false # don't cancel all jobs on failure
# matrix:
# runner: ["linux-x86-n2-16", "linux-arm64-t2a-48"]
# python: ["3.11",]
# enable-x64: [1, 0]
# name: "Bazel CPU tests with ${{ format('{0}', 'py_import') }}"
# with:
# runner: ${{ matrix.runner }}
# python: ${{ matrix.python }}
# enable-x64: ${{ matrix.enable-x64 }}
# run-bazel-test-cuda:
# # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
# # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# # still want to run the tests for other platforms.
# if: ${{ !cancelled() }}
# needs: [build-jax-artifact, build-jaxlib-artifact, build-cuda-artifacts]
# uses: ./.github/workflows/bazel_cuda_non_rbe.yml
# strategy:
# fail-fast: false # don't cancel all jobs on failure
# matrix:
# # Python values need to match the matrix strategy in the build artifacts job above
# runner: ["linux-x86-g2-48-l4-4gpu",]
# python: ["3.11",]
# jaxlib-version: ["head", "pypi_latest"]
# enable-x64: [1, 0]
# name: "Bazel CUDA Non-RBE (jax version = ${{ format('{0}', 'head') }})"
# with:
# runner: ${{ matrix.runner }}
# python: ${{ matrix.python }}
# enable-x64: ${{ matrix.enable-x64 }}
# jaxlib-version: ${{ matrix.jaxlib-version }}
# # GCS upload URI is the same for both artifact build jobs
# gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
# run-bazel-test-cuda-py-import:
# # Run test jobs even if the build job fails. Avoids losing test coverage if a single unrelated
# # build job fails. E.g Windows build job fails but everything else succeeds. In this case, we
# # still want to run the tests for other platforms.
# if: ${{ !cancelled() }}
# uses: ./.github/workflows/bazel_cuda_non_rbe_py_import.yml
# strategy:
# fail-fast: false # don't cancel all jobs on failure
# matrix:
# # Python values need to match the matrix strategy in the build artifacts job above
# runner: ["linux-x86-g2-48-l4-4gpu",]
# python: ["3.11"]
# enable-x64: [1]
# name: "Bazel CUDA Non-RBE with ${{ format('{0}', 'py_import') }}"
# with:
# runner: ${{ matrix.runner }}
# python: ${{ matrix.python }}
# enable-x64: ${{ matrix.enable-x64 }}
run-pytest-tpu:
  # Calls the reusable pytest_tpu.yml workflow once per matrix combination.
  name: "Pytest TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
  needs: [build-jax-artifact, build-jaxlib-artifact]
  # Still run when a build job failed (only skip when the run was cancelled), so that one
  # unrelated build failure, e.g. a Windows build, does not cost test coverage on the other
  # platforms.
  if: ${{ !cancelled() }}
  uses: ./.github/workflows/pytest_tpu.yml
  strategy:
    fail-fast: false  # don't cancel all jobs on failure
    matrix:
      python: ['3.11']
      # Each entry pairs a TPU type with its core count and the runner label to use.
      tpu-specs:
        # - {type: "v3-8", cores: "4"}  # Enable when we have the v3 type available
        - type: 'v4-8'
          cores: '4'
          runner: 'linux-x86-ct4p-240-4tpu'
        - type: 'v5e-8'
          cores: '8'
          runner: 'linux-x86-ct5lp-224-8tpu'
        - type: 'v6e-8'
          cores: '8'
          runner: 'linux-x86-ct6e-180-8tpu'
      libtpu-version-type:
        - 'nightly'
        - 'oldest_supported_libtpu'
      exclude:
        # Run a single config for oldest_supported_libtpu: with v4-8 and v6e-8 excluded,
        # only v5e-8 remains for that libtpu version type.
        - libtpu-version-type: 'oldest_supported_libtpu'
          tpu-specs: {type: 'v4-8'}
        - libtpu-version-type: 'oldest_supported_libtpu'
          tpu-specs: {type: 'v6e-8'}
  with:
    runner: ${{ matrix.tpu-specs.runner }}
    cores: ${{ matrix.tpu-specs.cores }}
    tpu-type: ${{ matrix.tpu-specs.type }}
    python: ${{ matrix.python }}
    run-full-tpu-test-suite: '1'
    libtpu-version-type: ${{ matrix.libtpu-version-type }}
    # URI reported as an output of the jaxlib build job.
    gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
run-bazel-test-tpu:
  # Calls the reusable bazel_test_tpu.yml workflow once per matrix combination.
  name: "Bazel tests TPU (JAX artifacts version = ${{ format('{0}', 'head') }})"
  needs: [build-jax-artifact, build-jaxlib-artifact]
  # Still run when a build job failed (only skip when the run was cancelled), so that one
  # unrelated build failure, e.g. a Windows build, does not cost test coverage on the other
  # platforms.
  if: ${{ !cancelled() }}
  uses: ./.github/workflows/bazel_test_tpu.yml
  strategy:
    fail-fast: false  # don't cancel all jobs on failure
    matrix:
      python: ['3.11']
      # Each entry pairs a TPU type with its core count and the runner label to use.
      tpu-specs:
        # - {type: "v3-8", cores: "4"}  # Enable when we have the v3 type available
        - type: 'v4-8'
          cores: '4'
          runner: 'linux-x86-ct4p-240-4tpu'
        - type: 'v5e-8'
          cores: '8'
          runner: 'linux-x86-ct5lp-224-8tpu'
        - type: 'v6e-8'
          cores: '8'
          runner: 'linux-x86-ct6e-180-8tpu'
      libtpu-version-type:
        - 'nightly'
        - 'oldest_supported_libtpu'
      exclude:
        # Run a single config for oldest_supported_libtpu: with v4-8 and v6e-8 excluded,
        # only v5e-8 remains for that libtpu version type.
        - libtpu-version-type: 'oldest_supported_libtpu'
          tpu-specs: {type: 'v4-8'}
        - libtpu-version-type: 'oldest_supported_libtpu'
          tpu-specs: {type: 'v6e-8'}
  with:
    runner: ${{ matrix.tpu-specs.runner }}
    cores: ${{ matrix.tpu-specs.cores }}
    tpu-type: ${{ matrix.tpu-specs.type }}
    python: ${{ matrix.python }}
    run-full-tpu-test-suite: '1'
    libtpu-version-type: ${{ matrix.libtpu-version-type }}
    # URI reported as an output of the jaxlib build job.
    gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}