这是indexloc提供的服务,不要输入任何密码
Skip to content

#sdy define the utils that JAX jaxlib will use to allow for falling back to GSPMD when loading an old checkpoint. #97130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions third_party/xla/xla/service/spmd/shardy/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,15 @@ limitations under the License.
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/WalkResult.h"
#include "shardy/dialect/sdy/ir/constants.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/register.h"
#include "shardy/dialect/sdy/ir/utils.h"
Expand Down Expand Up @@ -300,5 +304,80 @@ SmallVector<AxisRefAttr> getOrderedAxisRefs(Attribute shardingOrAxisList,
return axisRefs;
}

namespace {

// Returns true when result `resultIndex` of `func` is meant for Shardy, i.e.
// either:
//  - the result carries an `mlir::sdy::kShardingAttr` attribute, or
//  - the value returned for that result is produced by a `CustomCallOp` whose
//    call target name equals `sdy::kFuncResultShardingTargetName`.
bool isFuncResultForShardy(FuncOp func, int64_t resultIndex) {
  // An explicit Shardy sharding attribute on the result settles it.
  if (func.getResultAttr(resultIndex, mlir::sdy::kShardingAttr)) {
    return true;
  }
  // Otherwise inspect whatever op produces the returned value. There is no
  // producer op when `getDefiningOp()` returns null.
  Operation* producer =
      mlir::sdy::getBodyTerminatorOperand(func, resultIndex).getDefiningOp();
  if (producer) {
    if (auto shardingCall = mlir::dyn_cast<CustomCallOp>(producer)) {
      return shardingCall.getCallTargetName() ==
             sdy::kFuncResultShardingTargetName;
    }
  }
  return false;
}

// Returns true if at least one result of `func` has an
// `sdy::kXlaShardingAttr` attribute and is not meant for Shardy (see
// `isFuncResultForShardy`). Note that a single such result is enough; the
// remaining results are not inspected once one is found.
bool areFuncResultShardingsForGspmd(FuncOp func) {
  const int64_t numResults = func.getNumResults();
  for (int64_t idx = 0; idx != numResults; ++idx) {
    // Results without an XLA sharding attribute are irrelevant here.
    if (!func.getResultAttr(idx, sdy::kXlaShardingAttr)) {
      continue;
    }
    // Results claimed by Shardy do not count as GSPMD.
    if (isFuncResultForShardy(func, idx)) {
      continue;
    }
    return true;
  }
  return false;
}

} // namespace

// Scans every function of `module` other than `main` and returns true as soon
// as one of the following is found:
//  1. an argument with `sdy::kXlaShardingAttr` but neither
//     `mlir::sdy::kShardingAttr` nor a `sdy::kShardingRoundTripAttr` entry in
//     its frontend attributes;
//  2. a result for which `areFuncResultShardingsForGspmd` holds;
//  3. a stablehlo custom call targeting `sdy::kShardingCustomCallTargetName`
//     that has `sdy::kXlaShardingAttr` but neither `mlir::sdy::kShardingAttr`
//     nor a `sdy::kShardingRoundTripAttr` frontend attribute.
// Returns false if no function matches.
bool hasGspmdAttrsOrOps(mlir::ModuleOp module) {
  for (mlir::func::FuncOp func : module.getOps<mlir::func::FuncOp>()) {
    // The loaded module that could be targeting GSPMD is not the main
    // function, so `main` is skipped.
    if (func.getSymName() == "main") {
      continue;
    }

    // 1. Function arguments.
    const int64_t numArgs = func.getNumArguments();
    for (int64_t idx = 0; idx != numArgs; ++idx) {
      const bool gspmdOnlyArg =
          func.getArgAttr(idx, sdy::kXlaShardingAttr) &&
          !func.getArgAttr(idx, mlir::sdy::kShardingAttr) &&
          !hasKey(sdy::getFuncArgFrontendAttrs(func, idx),
                  sdy::kShardingRoundTripAttr);
      if (gspmdOnlyArg) {
        return true;
      }
    }

    // 2. Function results.
    if (areFuncResultShardingsForGspmd(func)) {
      return true;
    }

    // 3. `Sharding` custom calls in the function. The walk is interrupted on
    // the first match, so an interrupted walk means a match was found.
    const bool foundGspmdCustomCall =
        func->walk([](mlir::stablehlo::CustomCallOp customCall) {
              const bool gspmdOnlyCall =
                  customCall.getCallTargetName() ==
                      sdy::kShardingCustomCallTargetName &&
                  customCall->hasAttr(sdy::kXlaShardingAttr) &&
                  !customCall->hasAttr(mlir::sdy::kShardingAttr) &&
                  !hasFrontendAttr(customCall, sdy::kShardingRoundTripAttr);
              return gspmdOnlyCall ? mlir::WalkResult::interrupt()
                                   : mlir::WalkResult::advance();
            })
            .wasInterrupted();
    if (foundGspmdCustomCall) {
      return true;
    }
  }
  return false;
}

// Returns true if iterating `module.getOps<mlir::sdy::MeshOp>()` yields at
// least one op.
bool hasShardyMesh(mlir::ModuleOp module) {
  auto meshOps = module.getOps<mlir::sdy::MeshOp>();
  return meshOps.begin() != meshOps.end();
}

} // namespace sdy
} // namespace xla
14 changes: 13 additions & 1 deletion third_party/xla/xla/service/spmd/shardy/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License.
#define XLA_SERVICE_SPMD_SHARDY_UTILS_H_

#include <cstdint>
#include <functional>
#include <optional>
#include <string>

Expand All @@ -29,6 +28,7 @@ limitations under the License.
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
Expand Down Expand Up @@ -139,6 +139,18 @@ std::string duplicateShardingsAtIndices(
mlir::SmallVector<mlir::sdy::AxisRefAttr> getOrderedAxisRefs(
mlir::Attribute shardingOrAxisList, mlir::sdy::MeshAttr mesh);

// Returns true if the module has at least one GSPMD attribute or op, like an
// `mhlo.sharding` attribute or `Sharding` custom call.
//
// The function named `main` is skipped; only the other functions in the
// module are inspected. Within those, an argument or `Sharding` custom call
// counts as GSPMD when it has the XLA sharding attribute but neither a Shardy
// sharding attribute nor the Shardy round-trip frontend attribute. A function
// result counts when it has the XLA sharding attribute and is not meant for
// Shardy.
// TODO(b/420837831): delete this once we don't fall back to GSPMD.
bool hasGspmdAttrsOrOps(mlir::ModuleOp module);

// Check if the module has any sort of Shardy mesh:
// - `mesh`
// - `maximal_mesh_{X}`
// - `empty_mesh`
// Returns true if `module.getOps<mlir::sdy::MeshOp>()` yields at least one
// op, false otherwise.
// TODO(b/420837831): delete this once we don't fall back to GSPMD.
bool hasShardyMesh(mlir::ModuleOp module);

} // namespace sdy
} // namespace xla

Expand Down