这是indexloc提供的服务,不要输入任何密码
Skip to content

Use stablehlo precision config conversion for stablehlo ops #97233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1383,7 +1383,7 @@ LogicalResult ExportXlaOp(mlir::stablehlo::ConvolutionOp op,
xla::ConvertConvDimensionNumbers(op.getDimensionNumbers()),
Convertuint64_t(op.getFeatureGroupCount()),
Convertuint64_t(op.getBatchGroupCount()),
Unwrap(Convert_precision_config(op.getPrecisionConfig())),
Unwrap(Convert_precision_config_stablehlo(op.getPrecisionConfig())),
preferred_element_type, op.getWindowReversal());
value_map[op] = xla_result;
return mlir::success();
Expand Down Expand Up @@ -2644,7 +2644,8 @@ LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) {
return success();
}

LogicalResult ExportXlaOp(DotGeneralOp op, OpLoweringContext ctx) {
LogicalResult ExportXlaOp(mlir::stablehlo::DotGeneralOp op,
OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaOp lhs, rhs;
if (failed(GetXlaOp(op.getLhs(), value_map, &lhs, op)))
Expand All @@ -2655,7 +2656,8 @@ LogicalResult ExportXlaOp(DotGeneralOp op, OpLoweringContext ctx) {
xla::ConvertMlirTypeToPrimitiveType(getElementTypeOrSelf(op.getType()));

// Precision Config / Algorithm
auto precision_config = Convert_precision_config(op.getPrecisionConfig());
auto precision_config =
Convert_precision_config_stablehlo(op.getPrecisionConfig());
if (op.getAlgorithmAttr()) {
absl::StatusOr<xla::PrecisionConfig::Algorithm> algorithm =
xla::ConvertDotAlgorithm(op.getAlgorithmAttr());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,47 @@ func.func @main(%arg0: tensor<?xf32>, %arg1: tensor<1xindex>, %arg2: tensor<1xin
HasSubstr("real_dynamic_slice"))));
}

// Exports a stablehlo.dot_general carrying `precision = [HIGHEST, HIGHEST]`
// and checks that the precision config is present on the exported HLO module.
// Checking only for an OK status would not catch the precision attribute being
// silently dropped during export.
TEST(ConvertMlirHloToHloModuleTest, ConvertsDotGeneralPrecisionConfig) {
  const std::string mlir_source = R"mlir(
func.func @main(%arg0: tensor<5x10xbf16>, %arg1: tensor<10x5xbf16>) -> tensor<5x5xbf16> {
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [HIGHEST, HIGHEST] : (tensor<5x10xbf16>, tensor<10x5xbf16>) -> tensor<5x5xbf16>
return %0 : tensor<5x5xbf16>
}
)mlir";

  mlir::DialectRegistry registry;
  registry.insert<mlir::func::FuncDialect, mlir::shape::ShapeDialect>();
  mlir::stablehlo::registerAllDialects(registry);
  mlir::MLIRContext context(registry);
  mlir::OwningOpRef<mlir::ModuleOp> module;
  {
    // Surface any parse diagnostics as a test failure.
    mlir::BaseScopedDiagnosticHandler handler(&context);
    module = mlir::parseSourceString<mlir::ModuleOp>(mlir_source, &context);
    TF_ASSERT_OK(handler.ConsumeStatus());
  }

  auto hlo_module = ConvertMlirHloToHloModule(*module);
  TF_ASSERT_OK(hlo_module.status());
  // Assumes the HLO text printer renders non-default operand precisions as
  // `operand_precision={...}` in lowercase on the dot instruction.
  EXPECT_THAT((*hlo_module)->ToString(),
              ::testing::HasSubstr("operand_precision={highest,highest}"));
}

// Exports a stablehlo.convolution carrying
// `precision_config = [HIGHEST, HIGHEST]` and checks that the precision config
// is present on the exported HLO module. Checking only for an OK status would
// not catch the precision attribute being silently dropped during export.
TEST(ConvertMlirHloToHloModuleTest, ConvertsConvolutionPrecisionConfig) {
  const std::string mlir_source = R"mlir(
func.func @main(%arg0: tensor<3x3x3x3xf32>, %arg1: tensor<3x3x3x3xf32>) -> tensor<3x3x3x3xf32> {
%0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision HIGHEST>, #stablehlo<precision HIGHEST>]} : (tensor<3x3x3x3xf32>, tensor<3x3x3x3xf32>) -> tensor<3x3x3x3xf32>
return %0 : tensor<3x3x3x3xf32>
}
)mlir";

  mlir::DialectRegistry registry;
  registry.insert<mlir::func::FuncDialect, mlir::shape::ShapeDialect>();
  mlir::stablehlo::registerAllDialects(registry);
  mlir::MLIRContext context(registry);
  mlir::OwningOpRef<mlir::ModuleOp> module;
  {
    // Surface any parse diagnostics as a test failure.
    mlir::BaseScopedDiagnosticHandler handler(&context);
    module = mlir::parseSourceString<mlir::ModuleOp>(mlir_source, &context);
    TF_ASSERT_OK(handler.ConsumeStatus());
  }

  auto hlo_module = ConvertMlirHloToHloModule(*module);
  TF_ASSERT_OK(hlo_module.status());
  // Assumes the HLO text printer renders non-default operand precisions as
  // `operand_precision={...}` in lowercase on the convolution instruction.
  EXPECT_THAT((*hlo_module)->ToString(),
              ::testing::HasSubstr("operand_precision={highest,highest}"));
}
} // namespace
} // namespace mlir