这是indexloc提供的服务,不要输入任何密码
Skip to content

Allow MHLO passes to operate on mlir::FunctionOpInterface instead of func::FuncOp. #97449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions third_party/xla/xla/mlir_hlo/mhlo/transforms/mhlo_passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ limitations under the License.

include "mlir/Pass/PassBase.td"

def ChloLegalizeToHighLevelMhloPass : Pass<"chlo-legalize-to-high-level-mhlo", "func::FuncOp"> {
def ChloLegalizeToHighLevelMhloPass : Pass<"chlo-legalize-to-high-level-mhlo", "mlir::FunctionOpInterface"> {
let summary = "Legalize CHLO's with XLA counterparts, like TopK and Erf.";
let description = [{
Performs direct legalization of CHLO->MHLO only for high-level (non-basis)
Expand All @@ -25,7 +25,7 @@ def ChloLegalizeToHighLevelMhloPass : Pass<"chlo-legalize-to-high-level-mhlo", "
let dependentDialects = ["mhlo::MhloDialect"];
}

def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "func::FuncOp"> {
def ChloLegalizeToHloPass : Pass<"chlo-legalize-to-hlo", "mlir::FunctionOpInterface"> {
let summary = "Legalize CHLO to MHLO with XLA-supported ops.";
let description = [{
Performs legalization of CHLO->StableHLO->MHLO, while also preserving MHLO
Expand All @@ -44,28 +44,28 @@ def HloLegalizeToArithmeticPass :Pass<"hlo-legalize-to-arithmetic", "ModuleOp">
let constructor = "createLegalizeToArithmeticPass()";
}

def LegalizeDotToDotGeneralPass : Pass<"mhlo-legalize-dot-to-dot-general", "func::FuncOp"> {
def LegalizeDotToDotGeneralPass : InterfacePass<"mhlo-legalize-dot-to-dot-general", "mlir::FunctionOpInterface"> {
  // Anchored with InterfacePass, not Pass: Pass<> expands to
  // ::mlir::OperationPass<OpT>, which requires a concrete op type with a
  // static getOperationName(). An op interface such as FunctionOpInterface
  // must use InterfacePass<>, which runs on any op implementing it.
  let summary = "Legalizes dot ops to dot_general ops.";
  let constructor = "createLegalizeDotToDotGeneralPass()";
}

def LegalizeEinsumToDotGeneralPass : Pass<"mhlo-legalize-einsum-to-dot-general", "func::FuncOp"> {
def LegalizeEinsumToDotGeneralPass : InterfacePass<"mhlo-legalize-einsum-to-dot-general", "mlir::FunctionOpInterface"> {
  // Anchored with InterfacePass, not Pass: Pass<> expands to
  // ::mlir::OperationPass<OpT>, which requires a concrete op type with a
  // static getOperationName(). An op interface such as FunctionOpInterface
  // must use InterfacePass<>, which runs on any op implementing it.
  let summary = "Legalizes einsum ops to dot_general ops.";
  let constructor = "createLegalizeEinsumToDotGeneralPass()";
}

def LegalizeTorchIndexSelectToGatherPass : Pass<"mhlo-legalize-torch-index-select-to-gather", "func::FuncOp"> {
def LegalizeTorchIndexSelectToGatherPass : InterfacePass<"mhlo-legalize-torch-index-select-to-gather", "mlir::FunctionOpInterface"> {
  // Anchored with InterfacePass, not Pass: Pass<> expands to
  // ::mlir::OperationPass<OpT>, which requires a concrete op type with a
  // static getOperationName(). An op interface such as FunctionOpInterface
  // must use InterfacePass<>, which runs on any op implementing it.
  let summary = "Legalizes torch index select to a gather.";
  let constructor = "createLegalizeTorchIndexSelectToGatherPass()";
}


def LegalizeTanhToApproximationPass : Pass<"mhlo-legalize-trigonometric-to-approximation", "func::FuncOp"> {
def LegalizeTanhToApproximationPass : InterfacePass<"mhlo-legalize-trigonometric-to-approximation", "mlir::FunctionOpInterface"> {
  // Anchored with InterfacePass, not Pass: Pass<> expands to
  // ::mlir::OperationPass<OpT>, which requires a concrete op type with a
  // static getOperationName(). An op interface such as FunctionOpInterface
  // must use InterfacePass<>, which runs on any op implementing it.
  //
  // NOTE(review): the record is named "Tanh" while the pass argument and
  // constructor say "trigonometric". The summary also mentions a "standard"
  // dialect, which presumably no longer exists upstream. Confirm, then align
  // the naming and text.
  let summary = "Legalize trigonometric operations from standard dialect to an approximation.";
  let constructor = "createLegalizeTrigonometricToApproximationPass()";
}

def HloLegalizeToLinalgPass : Pass<"hlo-legalize-to-linalg", "func::FuncOp"> {
def HloLegalizeToLinalgPass : Pass<"hlo-legalize-to-linalg", "mlir::FunctionOpInterface"> {
let summary = "Legalize from HLO dialect to Linalg dialect.";
let constructor = "createLegalizeHloToLinalgPass()";
let options = [Option<"enablePrimitiveOps", "enable-primitive-ops", "bool",
Expand All @@ -74,12 +74,12 @@ def HloLegalizeToLinalgPass : Pass<"hlo-legalize-to-linalg", "func::FuncOp"> {
"transpose) when possible, instead of linalg.generic">];
}

def TestMaterializeBroadcastsPass : Pass<"mhlo-test-materialize-broadcasts", "func::FuncOp"> {
def TestMaterializeBroadcastsPass : InterfacePass<"mhlo-test-materialize-broadcasts", "mlir::FunctionOpInterface"> {
  // Anchored with InterfacePass, not Pass: Pass<> expands to
  // ::mlir::OperationPass<OpT>, which requires a concrete op type with a
  // static getOperationName(). An op interface such as FunctionOpInterface
  // must use InterfacePass<>, which runs on any op implementing it.
  let summary = "Test pass for materializing 'broadcast_dimensions' attributes.";
  let constructor = "createTestMaterializeBroadcastsPass()";
}

def SinkConstantsToControlFlowPass : Pass<"mhlo-sink-constants-to-control-flow", "func::FuncOp"> {
def SinkConstantsToControlFlowPass : Pass<"mhlo-sink-constants-to-control-flow", "mlir::FunctionOpInterface"> {
let summary = "Sink constants implicitly captured in control flow regions. This "
"is necessary to export to XLA.";
let constructor = "createSinkConstantsToControlFlowPass()";
Expand Down Expand Up @@ -124,12 +124,12 @@ def SinkConstantsToControlFlowPass : Pass<"mhlo-sink-constants-to-control-flow",
}];
}

def TestInferShapedTypeMethodsPass : Pass<"mhlo-test-infer-shaped-type-methods", "func::FuncOp"> {
def TestInferShapedTypeMethodsPass : InterfacePass<"mhlo-test-infer-shaped-type-methods", "mlir::FunctionOpInterface"> {
  // Anchored with InterfacePass, not Pass: Pass<> expands to
  // ::mlir::OperationPass<OpT>, which requires a concrete op type with a
  // static getOperationName(). An op interface such as FunctionOpInterface
  // must use InterfacePass<>, which runs on any op implementing it.
  let summary = "Uses test ops to invoke InferShapedTypeOpInterface methods.";
  let constructor = "createTestInferShapedTypeMethodsPass()";
}

def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "func::FuncOp"> {
def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "mlir::FunctionOpInterface"> {
let summary = "Test pass for materializing 'broadcast_dimensions' attributes.";
let constructor = "createTestUnfuseBatchNormPass()";

Expand All @@ -147,7 +147,7 @@ def ExpandHloTuplesPass : Pass<"expand-hlo-tuples", "ModuleOp"> {
let dependentDialects = ["mhlo::MhloDialect"];
}

def FlattenTuplePass : Pass<"mhlo-flatten-tuple", "func::FuncOp"> {
def FlattenTuplePass : Pass<"mhlo-flatten-tuple", "mlir::FunctionOpInterface"> {
let summary = "Flatten tuples in operands and results of operators that "
"support both tuple and variadic type.";
let constructor = "createFlattenTuplePass()";
Expand All @@ -159,7 +159,7 @@ def ConvertToSignlessPass : Pass<"convert-to-signless", "ModuleOp"> {
}

def CollapseElementwiseMapPass
    : Pass<"mhlo-collapse-elementwise-map", "func::FuncOp"> {
    : InterfacePass<"mhlo-collapse-elementwise-map", "mlir::FunctionOpInterface"> {
  // Anchored with InterfacePass, not Pass: Pass<> expands to
  // ::mlir::OperationPass<OpT>, which requires a concrete op type with a
  // static getOperationName(). An op interface such as FunctionOpInterface
  // must use InterfacePass<>, which runs on any op implementing it.
  let summary = "Collapse the mhlo.map if the map only has elementwise ops.";
  let constructor = "createCollapseElementwiseMapPass()";
}
Expand Down Expand Up @@ -189,7 +189,7 @@ def StablehloLegalizeToHloPass : Pass<"stablehlo-legalize-to-hlo", "ModuleOp"> {
];
}

def PrepareForExportPass : Pass<"xla-prepare-for-export", "mlir::func::FuncOp"> {
def PrepareForExportPass : Pass<"xla-prepare-for-export", "mlir::FunctionOpInterface"> {
let summary = "Prepare for XLA export";

let description = [{
Expand Down