
Enable pivoted QR on GPU via MAGMA. #25955


Merged 1 commit on Feb 12, 2025
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -24,6 +24,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* Added low-level reduction APIs in {mod}`jax.lax`: {func}`jax.lax.reduce_sum`,
{func}`jax.lax.reduce_prod`, {func}`jax.lax.reduce_max`, {func}`jax.lax.reduce_min`,
{func}`jax.lax.reduce_and`, {func}`jax.lax.reduce_or`, and {func}`jax.lax.reduce_xor`.
* {func}`jax.lax.linalg.qr` and {func}`jax.scipy.linalg.qr` now support
column-pivoting on CPU and GPU. See {jax-issue}`#20282` and
{jax-issue}`#25955` for more details.

* Changes
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as
74 changes: 55 additions & 19 deletions jax/_src/lax/linalg.py
@@ -311,20 +311,22 @@ def lu(x: ArrayLike) -> tuple[Array, Array, Array]:

@overload
def qr(x: ArrayLike, *, pivoting: Literal[False], full_matrices: bool = True,
) -> tuple[Array, Array]:
use_magma: bool | None = None) -> tuple[Array, Array]:
...

@overload
def qr(x: ArrayLike, *, pivoting: Literal[True], full_matrices: bool = True,
) -> tuple[Array, Array, Array]:
use_magma: bool | None = None) -> tuple[Array, Array, Array]:
...

@overload
def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
use_magma: bool | None = None
) -> tuple[Array, Array] | tuple[Array, Array, Array]:
...

def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
use_magma: bool | None = None
) -> tuple[Array, Array] | tuple[Array, Array, Array]:
"""QR decomposition.

@@ -341,9 +343,14 @@ def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
pivoting: Allows the QR decomposition to be rank-revealing. If ``True``,
compute the column pivoted decomposition ``A[:, P] = Q @ R``, where ``P``
is chosen such that the diagonal of ``R`` is non-increasing. Currently
supported on CPU backends only.
supported on CPU and GPU backends only.
full_matrices: Determines if full or reduced matrices are returned; see
below.
use_magma: Locally override the ``jax_use_magma`` flag. If ``True``, the
pivoted `qr` factorization is computed using MAGMA. If ``False``, the
computation is done using LAPACK on the host CPU. If ``None`` (default),
the behavior is controlled by the ``jax_use_magma`` flag. This argument is
only used on GPU.

Returns:
A pair of arrays ``(q, r)``, if ``pivoting=False``, otherwise ``(q, r, p)``.
@@ -357,8 +364,16 @@ def qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True,
``full_matrices=False``.

Array ``p`` is an index vector with shape ``[..., n]``.

Notes:
- `MAGMA <https://icl.utk.edu/magma/>`_ support is experimental - see
:func:`jax.lax.linalg.eig` for further assumptions and limitations.
- If ``jax_use_magma`` is set to ``"auto"``, the MAGMA implementation will
be used if the library can be found, and the input matrix is sufficiently
large (has at least 2048 columns).
"""
q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices)
q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices,
use_magma=use_magma)
if pivoting:
return q, r, p[0]
return q, r
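
As a quick orientation for reviewers, here is a minimal usage sketch of the API above. The input values are illustrative, and per the docstring ``use_magma`` is ignored off-GPU:

```python
import jax
import jax.numpy as jnp

x = jnp.array([[3.0, 1.0],
               [4.0, 2.0]])

# Reduced, column-pivoted QR. On GPU, use_magma=False forces the
# LAPACK-on-host path regardless of the jax_use_magma flag.
q, r, p = jax.lax.linalg.qr(x, pivoting=True, full_matrices=False,
                            use_magma=False)

# p is a 0-based permutation: x[:, p] should match q @ r up to
# floating-point error.
```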
@@ -1854,22 +1869,28 @@ def _geqrf_cpu_gpu_lowering(ctx, a, *, target_name_prefix: str):
platform='rocm')


def geqp3(a: ArrayLike, jpvt: ArrayLike) -> tuple[Array, Array, Array]:
def geqp3(a: ArrayLike, jpvt: ArrayLike, *,
use_magma: bool | None = None) -> tuple[Array, Array, Array]:
"""Computes the column-pivoted QR decomposition of a matrix.

Args:
a: a ``[..., m, n]`` batch of matrices, with floating-point or complex type.
jpvt: a ``[..., n]`` batch of column-pivot index vectors with integer type.
use_magma: Locally override the ``jax_use_magma`` flag. If ``True``,
``geqp3`` is computed using MAGMA. If ``False``, the computation is done using
LAPACK on the host CPU. If ``None`` (default), the behavior is controlled
by the ``jax_use_magma`` flag. This argument is only used on GPU.

Returns:
A ``(a, jpvt, taus)`` triple, where ``r`` is in the upper triangle of ``a``,
``q`` is represented in the lower triangle of ``a`` and in ``taus`` as
elementary Householder reflectors, and ``jpvt`` contains the column-pivot indices
such that ``a[:, jpvt] = q @ r``.
"""
a_out, jpvt_out, taus = geqp3_p.bind(a, jpvt)
a_out, jpvt_out, taus = geqp3_p.bind(a, jpvt, use_magma=use_magma)
return a_out, jpvt_out, taus
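
A hedged sketch of how this low-level routine composes with ``householder_product`` to recover an explicit ``q`` and ``r``. This assumes both functions are re-exported via ``jax.lax.linalg``, as other routines in this module are; the values are illustrative:

```python
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jnp.array([[3.0, 1.0],
               [4.0, 2.0]])
# A zero-initialized jpvt lets the backend pivot all columns freely.
jpvt = jnp.zeros(a.shape[-1], dtype=jnp.int32)

a_out, jpvt_out, taus = lax_linalg.geqp3(a, jpvt)
r = jnp.triu(a_out)                              # R sits in the upper triangle
q = lax_linalg.householder_product(a_out, taus)  # assemble Q from reflectors
p = jpvt_out - 1                                 # geqp3 pivots are 1-based

# a[:, p] should be close to q @ r.
```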

def _geqp3_abstract_eval(a, jpvt):
def _geqp3_abstract_eval(a, jpvt, *, use_magma):
del use_magma
if not isinstance(a, ShapedArray) or not isinstance(jpvt, ShapedArray):
raise NotImplementedError("Unsupported aval in geqp3_abstract_eval: "
f"{a.aval}, {jpvt.aval}")
@@ -1882,25 +1903,37 @@ def _geqp3_abstract_eval(a, jpvt):
taus = a.update(shape=(*batch_dims, core.min_dim(m, n)))
return a, jpvt, taus

def _geqp3_batching_rule(batched_args, batch_dims):
def _geqp3_batching_rule(batched_args, batch_dims, *, use_magma):
a, jpvt = batched_args
b_a, b_jpvt = batch_dims
a = batching.moveaxis(a, b_a, 0)
jpvt = batching.moveaxis(jpvt, b_jpvt, 0)
return geqp3(a, jpvt), (0, 0, 0)
return geqp3(a, jpvt, use_magma=use_magma), (0, 0, 0)

def _geqp3_cpu_lowering(ctx, a, jpvt):
def _geqp3_cpu_lowering(ctx, a, jpvt, *, use_magma):
del use_magma
a_aval, _ = ctx.avals_in
target_name = lapack.prepare_lapack_call("geqp3_ffi", a_aval.dtype)
rule = _linalg_ffi_lowering(target_name, operand_output_aliases={0: 0, 1: 1})
return rule(ctx, a, jpvt)

def _geqp3_gpu_lowering(target_name_prefix, ctx, a, jpvt, *, use_magma):
gpu_solver.initialize_hybrid_kernels()
magma = config.gpu_use_magma.value
target_name = f"{target_name_prefix}hybrid_geqp3"
if use_magma is not None:
magma = "on" if use_magma else "off"
rule = _linalg_ffi_lowering(target_name, operand_output_aliases={0: 0, 1: 1})
return rule(ctx, a, jpvt, magma=magma)

geqp3_p = Primitive('geqp3')
geqp3_p.multiple_results = True
geqp3_p.def_impl(partial(dispatch.apply_primitive, geqp3_p))
geqp3_p.def_abstract_eval(_geqp3_abstract_eval)
batching.primitive_batchers[geqp3_p] = _geqp3_batching_rule
mlir.register_lowering(geqp3_p, _geqp3_cpu_lowering, platform="cpu")
mlir.register_lowering(geqp3_p, partial(_geqp3_gpu_lowering, 'cu'), platform="cuda")
mlir.register_lowering(geqp3_p, partial(_geqp3_gpu_lowering, 'hip'), platform="rocm")
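
For context, a sketch of how the per-call override interacts with the global flag, mirroring the lowering above. Treat the exact ``jax.config.update`` call as an assumption; the flag name and the ``"auto"`` semantics come from the ``qr`` docstring:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((4, 4))  # placeholder input

# With "auto", MAGMA is used when the library is found and the input
# has at least 2048 columns; otherwise the host-LAPACK path is taken.
jax.config.update("jax_use_magma", "auto")

q, r, p = jax.lax.linalg.qr(x, pivoting=True)                   # flag decides
q, r, p = jax.lax.linalg.qr(x, pivoting=True, use_magma=True)   # force MAGMA
q, r, p = jax.lax.linalg.qr(x, pivoting=True, use_magma=False)  # force host LAPACK
```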

# householder_product: product of elementary Householder reflectors

@@ -1988,12 +2021,13 @@ def _householder_product_cpu_gpu_lowering(ctx, a, taus, *,
platform='rocm')


def _qr_impl(operand, *, pivoting, full_matrices):
def _qr_impl(operand, *, pivoting, full_matrices, use_magma):
q, r, *p = dispatch.apply_primitive(qr_p, operand, pivoting=pivoting,
full_matrices=full_matrices)
full_matrices=full_matrices, use_magma=use_magma)
return (q, r, p[0]) if pivoting else (q, r)

def _qr_abstract_eval(operand, *, pivoting, full_matrices):
def _qr_abstract_eval(operand, *, pivoting, full_matrices, use_magma):
del use_magma
if isinstance(operand, ShapedArray):
if operand.ndim < 2:
raise ValueError("Argument to QR decomposition must have ndims >= 2")
@@ -2018,11 +2052,11 @@ def _qr_abstract_eval(operand, *, pivoting, full_matrices):
q, r, p = operand, operand, operand
return (q, r, p) if pivoting else (q, r)

def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices):
def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices, use_magma):
# See j-towns.github.io/papers/qr-derivative.pdf for a terse derivation.
x, = primals
dx, = tangents
q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=False)
q, r, *p = qr_p.bind(x, pivoting=pivoting, full_matrices=False, use_magma=use_magma)
*_, m, n = x.shape
if m < n or (full_matrices and m != n):
raise NotImplementedError(
@@ -2043,14 +2077,16 @@ def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices):
return (q, r, p[0]), (dq, dr, dp)
return (q, r), (dq, dr)

def _qr_batching_rule(batched_args, batch_dims, *, pivoting, full_matrices):
def _qr_batching_rule(batched_args, batch_dims, *, pivoting, full_matrices,
use_magma):
x, = batched_args
bd, = batch_dims
x = batching.moveaxis(x, bd, 0)
out_axes = (0, 0, 0) if pivoting else (0, 0)
return qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices), out_axes
return qr_p.bind(x, pivoting=pivoting, full_matrices=full_matrices,
use_magma=use_magma), out_axes

def _qr_lowering(a, *, pivoting, full_matrices):
def _qr_lowering(a, *, pivoting, full_matrices, use_magma):
*batch_dims, m, n = a.shape
if m == 0 or n == 0:
k = m if full_matrices else core.min_dim(m, n)
@@ -2065,7 +2101,7 @@ def _qr_lowering(a, *, pivoting, full_matrices):

if pivoting:
jpvt = lax.full((*batch_dims, n), 0, dtype=np.dtype(np.int32))
r, p, taus = geqp3(a, jpvt)
r, p, taus = geqp3(a, jpvt, use_magma=use_magma)
p -= 1 # Convert geqp3's 1-based pivot indices to 0-based.
else:
r, taus = geqrf(a)
4 changes: 3 additions & 1 deletion jax/_src/scipy/linalg.py
@@ -953,7 +953,9 @@ def qr(a: ArrayLike, overwrite_a: bool = False, lwork: Any = None, mode: str = "
with ``K = min(M, N)``.

Notes:
- At present, pivoting is only implemented on CPU backends.
- At present, pivoting is only implemented on the CPU and GPU backends. For further
details about the GPU implementation, see the documentation for
:func:`jax.lax.linalg.qr`.

See also:
- :func:`jax.numpy.linalg.qr`: NumPy-style QR decomposition API
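A minimal sketch of the scipy-style entry point with pivoting enabled (illustrative values; the ``pivoting`` keyword is the one documented above):

```python
import jax.numpy as jnp
from jax.scipy import linalg as jsp_linalg

a = jnp.array([[3.0, 1.0],
               [4.0, 2.0]])

q, r, p = jsp_linalg.qr(a, mode="economic", pivoting=True)
# Column-pivoted: a[:, p] should be close to q @ r.
```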
13 changes: 13 additions & 0 deletions jaxlib/ffi_helpers.h
@@ -184,6 +184,19 @@ auto AllocateScratchMemory(std::size_t size)
return std::unique_ptr<ValueType[]>(new ValueType[size]);
}

template <typename T>
inline absl::StatusOr<T*> AllocateWorkspace(
::xla::ffi::ScratchAllocator& scratch, int64_t size,
std::string_view name) {
auto maybe_workspace = scratch.Allocate(sizeof(T) * size);
if (!maybe_workspace.has_value()) {
return absl::Status(
absl::StatusCode::kResourceExhausted,
absl::StrFormat("Unable to allocate workspace for %s", name));
}
return static_cast<T*>(maybe_workspace.value());
}

} // namespace jax

#endif // JAXLIB_FFI_HELPERS_H_
5 changes: 5 additions & 0 deletions jaxlib/gpu/hybrid.cc
@@ -47,6 +47,10 @@ void GetLapackKernelsFromScipy() {
lapack_ptr("cgeev"));
AssignKernelFn<EigenvalueDecompositionComplex<ffi::C128>>(
lapack_ptr("zgeev"));
AssignKernelFn<PivotingQrFactorization<ffi::F32>>(lapack_ptr("sgeqp3"));
AssignKernelFn<PivotingQrFactorization<ffi::F64>>(lapack_ptr("dgeqp3"));
AssignKernelFn<PivotingQrFactorization<ffi::C64>>(lapack_ptr("cgeqp3"));
AssignKernelFn<PivotingQrFactorization<ffi::C128>>(lapack_ptr("zgeqp3"));
});
}

@@ -57,6 +61,7 @@ NB_MODULE(_hybrid, m) {
nb::dict dict;
dict[JAX_GPU_PREFIX "hybrid_eig_real"] = EncapsulateFfiHandler(kEigReal);
dict[JAX_GPU_PREFIX "hybrid_eig_comp"] = EncapsulateFfiHandler(kEigComp);
dict[JAX_GPU_PREFIX "hybrid_geqp3"] = EncapsulateFfiHandler(kGeqp3);
return dict;
});
}