-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[mlir] Dialect Conversion: Add support for post-order legalization order #166292
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] Dialect Conversion: Add support for post-order legalization order #166292
Conversation
|
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesBy default, the dialect conversion driver processes operations in pre-order: the initial worklist is populated pre-order. (New/modified operations are immediately legalized recursively.) This commit adds a new API for selective post-order legalization. Patterns can request an operation / region legalization via Note: In rollback mode, a failed recursive legalization typically leads to a conversion failure. Since recursive legalization is performed by separate pattern applications, there is no way for the original pattern to recover from such a failure. Full diff: https://github.com/llvm/llvm-project/pull/166292.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index ed7e2a08ebfd9..db903dc337c46 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -981,6 +981,20 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// Return a reference to the internal implementation.
detail::ConversionPatternRewriterImpl &getImpl();
+ /// Attempt to legalize the given operation. This can be used within
+ /// conversion patterns to change the default post-order legalization order.
+ /// Returns "success" if the operation was legalized, "failure" otherwise.
+ LogicalResult legalize(Operation *op);
+
+ /// Attempt to legalize the given region. This can be used within
+ /// conversion patterns to change the default post-order legalization order.
+ /// Returns "success" if the region was legalized, "failure" otherwise.
+ ///
+ /// If the current pattern runs with a type converter, the entry block
+ /// signature will be converted before legalizing the operations in the
+ /// region.
+ LogicalResult legalize(Region *r);
+
private:
// Allow OperationConverter to construct new rewriters.
friend struct OperationConverter;
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 2fe06970eb568..c56c3b6e013ef 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -862,8 +862,11 @@ static bool hasRewrite(R &&rewrites, Block *block) {
//===----------------------------------------------------------------------===//
// ConversionPatternRewriterImpl
//===----------------------------------------------------------------------===//
+
namespace mlir {
namespace detail {
+class OperationLegalizer;
+
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
const ConversionConfig &config)
@@ -915,6 +918,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Return "true" if the given operation was replaced or erased.
bool wasOpReplaced(Operation *op) const;
+ /// Set the operation legalizer to use for recursive legalization.
+ void setOperationLegalizer(OperationLegalizer *legalizer) {
+ opLegalizer = legalizer;
+ }
+
/// Lookup the most recently mapped values with the desired types in the
/// mapping, taking into account only replacements. Perform a best-effort
/// search for existing materializations with the desired types.
@@ -1121,6 +1129,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// converting the arguments of blocks within that region.
DenseMap<Region *, const TypeConverter *> regionToConverter;
+ /// The operation legalizer to use for recursive legalization. This is set by
+ /// the OperationConverter when the rewriter is created.
+ OperationLegalizer *opLegalizer = nullptr;
+
/// Dialect conversion configuration.
const ConversionConfig &config;
@@ -2357,7 +2369,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
// OperationLegalizer
//===----------------------------------------------------------------------===//
-namespace {
+namespace mlir::detail {
/// A set of rewrite patterns that can be used to legalize a given operation.
using LegalizationPatterns = SmallVector<const Pattern *, 1>;
@@ -2454,7 +2466,7 @@ class OperationLegalizer {
/// The pattern applicator to use for conversions.
PatternApplicator applicator;
};
-} // namespace
+} // namespace mlir::detail
OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
const ConversionTarget &targetInfo,
@@ -2854,6 +2866,41 @@ LogicalResult OperationLegalizer::legalizePatternRootUpdates(
return success();
}
+LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
+ return impl->opLegalizer->legalize(op);
+}
+
+LogicalResult ConversionPatternRewriter::legalize(Region *r) {
+ // Fast path: If the region is empty, there is nothing to legalize.
+ if (r->empty())
+ return success();
+
+ // Gather a list of all operations to legalize. This is done before
+ // converting the entry block signature because unrealized_conversion_cast
+ // ops should not be included.
+ SmallVector<Operation *> ops;
+ for (Block &b : *r)
+ for (Operation &op : b)
+ ops.push_back(&op);
+
+ // If the current pattern runs with a type converter, convert the entry block
+ // signature.
+ if (const TypeConverter *converter = impl->currentTypeConverter) {
+ std::optional<TypeConverter::SignatureConversion> conversion =
+ converter->convertBlockSignature(&r->front());
+ if (!conversion)
+ return failure();
+ applySignatureConversion(&r->front(), *conversion, converter);
+ }
+
+ // Legalize all operations in the region.
+ for (Operation *op : ops)
+ if (failed(legalize(op)))
+ return failure();
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Cost Model
//===----------------------------------------------------------------------===//
@@ -3218,7 +3265,10 @@ struct OperationConverter {
const ConversionConfig &config,
OpConversionMode mode)
: rewriter(ctx, config), opLegalizer(rewriter, target, patterns),
- mode(mode) {}
+ mode(mode) {
+ // Set the legalizer in the rewriter so patterns can recursively legalize.
+ rewriter.getImpl().setOperationLegalizer(&opLegalizer);
+ }
/// Converts the given operations to the conversion target.
LogicalResult convertOperations(ArrayRef<Operation *> ops);
diff --git a/mlir/test/Transforms/test-legalizer-rollback.mlir b/mlir/test/Transforms/test-legalizer-rollback.mlir
index 71e11782e14b0..4bcca6b7e5228 100644
--- a/mlir/test/Transforms/test-legalizer-rollback.mlir
+++ b/mlir/test/Transforms/test-legalizer-rollback.mlir
@@ -163,3 +163,22 @@ func.func @create_unregistered_op_in_pattern() -> i32 {
"test.return"(%0) : (i32) -> ()
}
}
+
+// -----
+
+// CHECK-LABEL: func @test_failed_preorder_legalization
+// CHECK: "test.post_order_legalization"() ({
+// CHECK: %[[r:.*]] = "test.illegal_op_g"() : () -> i32
+// CHECK: "test.return"(%[[r]]) : (i32) -> ()
+// CHECK: }) : () -> ()
+// expected-remark @+1 {{applyPartialConversion failed}}
+module {
+func.func @test_failed_preorder_legalization() {
+ // expected-error @+1 {{failed to legalize operation 'test.post_order_legalization' that was explicitly marked illegal}}
+ "test.post_order_legalization"() ({
+ %0 = "test.illegal_op_g"() : () -> (i32)
+ "test.return"(%0) : (i32) -> ()
+ }) : () -> ()
+ return
+}
+}
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 7c43bb7bface0..5c6d4cd1dd205 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -448,3 +448,29 @@ func.func @test_working_1to1_pattern(%arg0: f16) {
"test.type_consumer"(%arg0) : (f16) -> ()
"test.return"() : () -> ()
}
+
+// -----
+
+// The region of "test.post_order_legalization" is converted before the op.
+
+// CHECK: notifyBlockInserted into test.post_order_legalization: was unlinked
+// CHECK: notifyOperationInserted: test.invalid
+// CHECK: notifyBlockErased
+// CHECK: notifyOperationInserted: test.valid, was unlinked
+// CHECK: notifyOperationReplaced: test.invalid
+// CHECK: notifyOperationErased: test.invalid
+// CHECK: notifyOperationModified: test.post_order_legalization
+
+// CHECK-LABEL: func @test_preorder_legalization
+// CHECK: "test.post_order_legalization"() ({
+// CHECK: ^{{.*}}(%[[arg0:.*]]: f64):
+// CHECK: "test.valid"(%[[arg0]]) : (f64) -> ()
+// CHECK: }) {is_legal} : () -> ()
+func.func @test_preorder_legalization() {
+ "test.post_order_legalization"() ({
+ ^bb0(%arg0: i64):
+ "test.invalid"(%arg0) : (i64) -> ()
+ }) : () -> ()
+ // expected-remark @+1 {{is not legalizable}}
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 12edecc113495..9b64bc691588d 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1418,6 +1418,22 @@ class TestTypeConsumerOpPattern
}
};
+class TestPostOrderLegalization : public ConversionPattern {
+public:
+ TestPostOrderLegalization(MLIRContext *ctx, const TypeConverter &converter)
+ : ConversionPattern(converter, "test.post_order_legalization", 1, ctx) {}
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ for (Region &r : op->getRegions())
+ if (failed(rewriter.legalize(&r)))
+ return failure();
+ rewriter.modifyOpInPlace(
+ op, [&]() { op->setAttr("is_legal", rewriter.getUnitAttr()); });
+ return success();
+ }
+};
+
/// Test unambiguous overload resolution of replaceOpWithMultiple. This
/// function is just to trigger compiler errors. It is never executed.
[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
@@ -1532,7 +1548,8 @@ struct TestLegalizePatternDriver
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
TestValueReplace, TestReplaceWithValidConsumer,
- TestTypeConsumerOpPattern>(&getContext(), converter);
+ TestTypeConsumerOpPattern, TestPostOrderLegalization>(
+ &getContext(), converter);
patterns.add<TestConvertBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
@@ -1560,6 +1577,9 @@ struct TestLegalizePatternDriver
target.addDynamicallyLegalOp(
OperationName("test.value_replace", &getContext()),
[](Operation *op) { return op->hasAttr("is_legal"); });
+ target.addDynamicallyLegalOp(
+ OperationName("test.post_order_legalization", &getContext()),
+ [](Operation *op) { return op->hasAttr("is_legal"); });
// TestCreateUnregisteredOp creates `arith.constant` operation,
// which was not added to target intentionally to test
|
63eac17 to
c949a6c
Compare
c949a6c to
b4c8f69
Compare
| ops.push_back(&op); | ||
|
|
||
| // If the current pattern runs with a type converter, convert the entry block | ||
| // signature. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this case tested right now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, by this test case.
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/20572 Here is the relevant piece of the build log for the reference |
By default, the dialect conversion driver processes operations in pre-order: the initial worklist is populated pre-order. (New/modified operations are immediately legalized recursively.)
This commit adds a new API for selective post-order legalization. Patterns can request an operation / region legalization via
ConversionPatternRewriter::legalize. They can call these helper functions on nested regions before rewriting the operation itself.Note: In rollback mode, a failed recursive legalization typically leads to a conversion failure. Since recursive legalization is performed by separate pattern applications, there is no way for the original pattern to recover from such a failure.