Add quantization utility functions for PropagateQSVPass. #97470

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged 1 commit on Jul 26, 2025
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/lite/BUILD
@@ -1365,6 +1365,7 @@ cc_library(
"transforms/prepare_quantize_dynamic_range.cc",
"transforms/prepare_quantize_helper.cc",
"transforms/quantization/propagate_qsv_pass.cc",
"transforms/quantization/quant_utils.cc",
"transforms/quantize.cc",
"transforms/quantize_variables.cc",
"utils/generated_op_quant_spec_getters.inc",
@@ -1373,6 +1374,7 @@ cc_library(
"transforms/lower_quant_annotations_helper.h",
"transforms/passes.h",
"transforms/prepare_quantize_helper.h",
"transforms/quantization/quant_utils.h",
],
deps = [
"convert_type",
131 changes: 131 additions & 0 deletions tensorflow/compiler/mlir/lite/transforms/quantization/quant_utils.cc
@@ -0,0 +1,131 @@
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/lite/transforms/quantization/quant_utils.h"

#include <optional>

#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project
#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_interface.h.inc"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h" // IWYU pragma: keep

namespace mlir::TFL {

std::optional<quant::QuantizedType> GetQTypeFromDefiningDequantize(
mlir::Value value) {
mlir::Operation* op = value.getDefiningOp();
TFL::DequantizeOp dq_op = dyn_cast_or_null<TFL::DequantizeOp>(op);
if (!dq_op) {
return std::nullopt;
}
return cast<quant::QuantizedType>(
getElementTypeOrSelf(dq_op.getInput().getType()));
}

std::optional<quant::QuantizedType> GetQTypeFromConsumingQuantize(
mlir::Value value) {
if (!value.hasOneUse()) {
return std::nullopt;
}
mlir::Operation* op = value.use_begin().getUser();
TFL::QuantizeOp q_op = dyn_cast_or_null<TFL::QuantizeOp>(op);
if (!q_op) {
return std::nullopt;
}
// The element type of the result of the quantize op is the quantized type.
return cast<quant::QuantizedType>(getElementTypeOrSelf(q_op.getType()));
}

// Sets the insertion point for the rewriter safely. If the value is an op
// result, it sets the insertion point after the defining op. If the value is a
// block argument, it sets the insertion point to the start of the block.
static LogicalResult SetInsertionPointAfterDefiningOp(
mlir::Value value, PatternRewriter& rewriter) {
mlir::Operation* defining_op = value.getDefiningOp();
if (defining_op) {
// It's an operation result, insert after the defining op.
rewriter.setInsertionPointAfter(defining_op);
} else if (auto block_arg = dyn_cast<BlockArgument>(value)) {
// It's a block argument, insert at the start of its owner block.
rewriter.setInsertionPointToStart(block_arg.getOwner());
} else {
// Neither an op result nor a block argument; emit an error and fail.
emitError(value.getLoc(),
"Value is neither an op result nor a block argument.");
return failure();
}
return success();
}

LogicalResult InsertQDQ(mlir::Value value, quant::QuantizedType qtype,
PatternRewriter& rewriter, mlir::Operation* target_op) {
if (failed(SetInsertionPointAfterDefiningOp(value, rewriter))) {
return failure();
}

// The new RankedTensorType with the element type being the quantized type.
auto shaped_type = dyn_cast<mlir::ShapedType>(value.getType());
if (!shaped_type) {
return failure();
}
RankedTensorType result_type =
RankedTensorType::get(shaped_type.getShape(), qtype);

auto quantize = rewriter.create<TFL::QuantizeOp>(
value.getLoc(), result_type, value, TypeAttr::get(result_type));
// Mark this quantize as a propagated QuantizeOp.
quantize->setAttr(kPropagatedQuantizeOpAttr, rewriter.getUnitAttr());

auto dequantize = rewriter.create<TFL::DequantizeOp>(
value.getLoc(), value.getType(), quantize);

rewriter.replaceUsesWithIf(value, dequantize, [&](OpOperand& use) {
// We have value -> Q -> dequantize, so Q is already a "use" which we
// need to keep.
if (use.getOwner() == quantize) {
return false;
}
// If a target_op is set, only replace the uses on that target.
// This is helpful in the following case:
// const -> [value] -> [use1] -> op1
// \
// \- [use2] -> op2
//
// In this case, we need to be able to insert a QDQ on only one of the
// uses and not necessarily all of them.
if (target_op && use.getOwner() != target_op) {
return false;
}
return true;
});
return success();
}
} // namespace mlir::TFL
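
For context, here is a minimal sketch of how a rewrite pattern in PropagateQSVPass might combine these helpers. This is not part of the diff; the pattern name and the choice of tfl.reshape as the propagation target are illustrative assumptions:

#include <optional>

#include "mlir/IR/PatternMatch.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/quantization/quant_utils.h"

namespace mlir::TFL {

// Hypothetical pattern: if the input of a shape-preserving op comes from a
// DequantizeOp, propagate its quantized type to the op's result by wrapping
// the result in a new QDQ pair.
struct PropagateThroughReshape : OpRewritePattern<TFL::ReshapeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::ReshapeOp op,
                                PatternRewriter& rewriter) const override {
    // Only fire when the input carries a QSV via a defining DequantizeOp.
    std::optional<quant::QuantizedType> qtype =
        GetQTypeFromDefiningDequantize(op.getInput());
    if (!qtype) return failure();
    // Already annotated: the result's single use is a QuantizeOp of the same
    // type. This also makes the pattern converge after InsertQDQ has run.
    if (GetQTypeFromConsumingQuantize(op.getResult()) == qtype) {
      return failure();
    }
    // Wrap every use of the result in the propagated QDQ pair.
    return InsertQDQ(op.getResult(), *qtype, rewriter);
  }
};

}  // namespace mlir::TFL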
65 changes: 65 additions & 0 deletions tensorflow/compiler/mlir/lite/transforms/quantization/quant_utils.h
@@ -0,0 +1,65 @@
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_QUANTIZATION_QUANT_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_QUANTIZATION_QUANT_UTILS_H_

#include <optional>

#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project
#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_traits.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h" // IWYU pragma: keep

namespace mlir::TFL {

inline constexpr char kPropagatedQuantizeOpAttr[] = "propagated";

// If `value` is the result of a DequantizeOp, returns the quantized type of
// the DequantizeOp's input. Otherwise, returns std::nullopt.
// The IR pattern looks like:
// ... -> [quantized type] -> DequantizeOp -> [value]
std::optional<quant::QuantizedType> GetQTypeFromDefiningDequantize(
mlir::Value value);

// If `value` has only one use and that use is a QuantizeOp, returns the
// quantized type of the QuantizeOp's result. Otherwise, returns std::nullopt.
// The single-use check is to avoid ambiguity in cases of fan-out.
// The IR pattern looks like:
// [value] -> QuantizeOp -> ...
std::optional<quant::QuantizedType> GetQTypeFromConsumingQuantize(
mlir::Value value);

// Inserts a Quantize-Dequantize (QDQ) pair for a value.
// If `target_op` is provided, it only replaces the uses of `value` within
// `target_op`. Otherwise, it replaces all uses of `value` (except for the
// newly created Quantize op).
LogicalResult InsertQDQ(mlir::Value value, quant::QuantizedType qtype,
PatternRewriter& rewriter,
mlir::Operation* target_op = nullptr);

} // namespace mlir::TFL

#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_QUANTIZATION_QUANT_UTILS_H_
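
Two small usage notes, again as assumed sketches rather than code from this PR: `IsPropagatedQuantize` shows how the `propagated` marker written by InsertQDQ can be detected later, and `AnnotateSingleUse` exercises the `target_op` parameter for the fan-out case described in quant_utils.cc (`shared_value`, `user_op`, and both helper names are hypothetical):

#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/quantization/quant_utils.h"

// Returns true if `op` is a QuantizeOp inserted by propagation, i.e. it
// carries the unit attribute that InsertQDQ sets.
bool IsPropagatedQuantize(mlir::Operation* op) {
  return mlir::isa<mlir::TFL::QuantizeOp>(op) &&
         op->hasAttr(mlir::TFL::kPropagatedQuantizeOpAttr);
}

// Fan-out case: `shared_value` feeds several ops, but only the uses owned by
// `user_op` are rewired to the new DequantizeOp; other consumers keep seeing
// the original float value.
mlir::LogicalResult AnnotateSingleUse(mlir::Value shared_value,
                                      mlir::quant::QuantizedType qtype,
                                      mlir::PatternRewriter& rewriter,
                                      mlir::Operation* user_op) {
  return mlir::TFL::InsertQDQ(shared_value, qtype, rewriter,
                              /*target_op=*/user_op);
}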