Skip to content

Conversation

@matthias-springer
Copy link
Member

By default, the dialect conversion driver processes operations in pre-order: the initial worklist is populated pre-order. (New/modified operations are immediately legalized recursively.)

This commit adds a new API for selective post-order legalization. Patterns can request an operation / region legalization via ConversionPatternRewriter::legalize. They can call these helper functions on nested regions before rewriting the operation itself.

Note: In rollback mode, a failed recursive legalization typically leads to a conversion failure. Since recursive legalization is performed by separate pattern applications, there is no way for the original pattern to recover from such a failure.

@matthias-springer matthias-springer requested review from j2kun and zero9178 and removed request for zero9178 November 4, 2025 03:16
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Nov 4, 2025
@llvmbot
Copy link
Member

llvmbot commented Nov 4, 2025

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

By default, the dialect conversion driver processes operations in pre-order: the initial worklist is populated pre-order. (New/modified operations are immediately legalized recursively.)

This commit adds a new API for selective post-order legalization. Patterns can request an operation / region legalization via ConversionPatternRewriter::legalize. They can call these helper functions on nested regions before rewriting the operation itself.

Note: In rollback mode, a failed recursive legalization typically leads to a conversion failure. Since recursive legalization is performed by separate pattern applications, there is no way for the original pattern to recover from such a failure.


Full diff: https://github.com/llvm/llvm-project/pull/166292.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+14)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+53-3)
  • (modified) mlir/test/Transforms/test-legalizer-rollback.mlir (+19)
  • (modified) mlir/test/Transforms/test-legalizer.mlir (+26)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+21-1)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index ed7e2a08ebfd9..db903dc337c46 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -981,6 +981,20 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// Return a reference to the internal implementation.
   detail::ConversionPatternRewriterImpl &getImpl();
 
+  /// Attempt to legalize the given operation. This can be used within
+  /// conversion patterns to change the default post-order legalization order.
+  /// Returns "success" if the operation was legalized, "failure" otherwise.
+  LogicalResult legalize(Operation *op);
+
+  /// Attempt to legalize the given region. This can be used within
+  /// conversion patterns to change the default post-order legalization order.
+  /// Returns "success" if the region was legalized, "failure" otherwise.
+  ///
+  /// If the current pattern runs with a type converter, the entry block
+  /// signature will be converted before legalizing the operations in the
+  /// region.
+  LogicalResult legalize(Region *r);
+
 private:
   // Allow OperationConverter to construct new rewriters.
   friend struct OperationConverter;
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 2fe06970eb568..c56c3b6e013ef 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -862,8 +862,11 @@ static bool hasRewrite(R &&rewrites, Block *block) {
 //===----------------------------------------------------------------------===//
 // ConversionPatternRewriterImpl
 //===----------------------------------------------------------------------===//
+
 namespace mlir {
 namespace detail {
+class OperationLegalizer;
+
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
                                          const ConversionConfig &config)
@@ -915,6 +918,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// Return "true" if the given operation was replaced or erased.
   bool wasOpReplaced(Operation *op) const;
 
+  /// Set the operation legalizer to use for recursive legalization.
+  void setOperationLegalizer(OperationLegalizer *legalizer) {
+    opLegalizer = legalizer;
+  }
+
   /// Lookup the most recently mapped values with the desired types in the
   /// mapping, taking into account only replacements. Perform a best-effort
   /// search for existing materializations with the desired types.
@@ -1121,6 +1129,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// converting the arguments of blocks within that region.
   DenseMap<Region *, const TypeConverter *> regionToConverter;
 
+  /// The operation legalizer to use for recursive legalization. This is set by
+  /// the OperationConverter when the rewriter is created.
+  OperationLegalizer *opLegalizer = nullptr;
+
   /// Dialect conversion configuration.
   const ConversionConfig &config;
 
@@ -2357,7 +2369,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
 // OperationLegalizer
 //===----------------------------------------------------------------------===//
 
-namespace {
+namespace mlir::detail {
 /// A set of rewrite patterns that can be used to legalize a given operation.
 using LegalizationPatterns = SmallVector<const Pattern *, 1>;
 
@@ -2454,7 +2466,7 @@ class OperationLegalizer {
   /// The pattern applicator to use for conversions.
   PatternApplicator applicator;
 };
-} // namespace
+} // namespace mlir::detail
 
 OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
                                        const ConversionTarget &targetInfo,
@@ -2854,6 +2866,41 @@ LogicalResult OperationLegalizer::legalizePatternRootUpdates(
   return success();
 }
 
+LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
+  return impl->opLegalizer->legalize(op);
+}
+
+LogicalResult ConversionPatternRewriter::legalize(Region *r) {
+  // Fast path: If the region is empty, there is nothing to legalize.
+  if (r->empty())
+    return success();
+
+  // Gather a list of all operations to legalize. This is done before
+  // converting the entry block signature because unrealized_conversion_cast
+  // ops should not be included.
+  SmallVector<Operation *> ops;
+  for (Block &b : *r)
+    for (Operation &op : b)
+      ops.push_back(&op);
+
+  // If the current pattern runs with a type converter, convert the entry block
+  // signature.
+  if (const TypeConverter *converter = impl->currentTypeConverter) {
+    std::optional<TypeConverter::SignatureConversion> conversion =
+        converter->convertBlockSignature(&r->front());
+    if (!conversion)
+      return failure();
+    applySignatureConversion(&r->front(), *conversion, converter);
+  }
+
+  // Legalize all operations in the region.
+  for (Operation *op : ops)
+    if (failed(legalize(op)))
+      return failure();
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Cost Model
 //===----------------------------------------------------------------------===//
@@ -3218,7 +3265,10 @@ struct OperationConverter {
                               const ConversionConfig &config,
                               OpConversionMode mode)
       : rewriter(ctx, config), opLegalizer(rewriter, target, patterns),
-        mode(mode) {}
+        mode(mode) {
+    // Set the legalizer in the rewriter so patterns can recursively legalize.
+    rewriter.getImpl().setOperationLegalizer(&opLegalizer);
+  }
 
   /// Converts the given operations to the conversion target.
   LogicalResult convertOperations(ArrayRef<Operation *> ops);
diff --git a/mlir/test/Transforms/test-legalizer-rollback.mlir b/mlir/test/Transforms/test-legalizer-rollback.mlir
index 71e11782e14b0..4bcca6b7e5228 100644
--- a/mlir/test/Transforms/test-legalizer-rollback.mlir
+++ b/mlir/test/Transforms/test-legalizer-rollback.mlir
@@ -163,3 +163,22 @@ func.func @create_unregistered_op_in_pattern() -> i32 {
   "test.return"(%0) : (i32) -> ()
 }
 }
+
+// -----
+
+// CHECK-LABEL: func @test_failed_preorder_legalization
+//       CHECK:   "test.post_order_legalization"() ({
+//       CHECK:     %[[r:.*]] = "test.illegal_op_g"() : () -> i32
+//       CHECK:     "test.return"(%[[r]]) : (i32) -> ()
+//       CHECK:   }) : () -> ()
+// expected-remark @+1 {{applyPartialConversion failed}}
+module {
+func.func @test_failed_preorder_legalization() {
+  // expected-error @+1 {{failed to legalize operation 'test.post_order_legalization' that was explicitly marked illegal}}
+  "test.post_order_legalization"() ({
+    %0 = "test.illegal_op_g"() : () -> (i32)
+    "test.return"(%0) : (i32) -> ()
+  }) : () -> ()
+  return
+}
+}
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 7c43bb7bface0..5c6d4cd1dd205 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -448,3 +448,29 @@ func.func @test_working_1to1_pattern(%arg0: f16) {
   "test.type_consumer"(%arg0) : (f16) -> ()
   "test.return"() : () -> ()
 }
+
+// -----
+
+// The region of "test.post_order_legalization" is converted before the op.
+
+// CHECK: notifyBlockInserted into test.post_order_legalization: was unlinked
+// CHECK: notifyOperationInserted: test.invalid
+// CHECK: notifyBlockErased
+// CHECK: notifyOperationInserted: test.valid, was unlinked
+// CHECK: notifyOperationReplaced: test.invalid
+// CHECK: notifyOperationErased: test.invalid
+// CHECK: notifyOperationModified: test.post_order_legalization
+
+// CHECK-LABEL: func @test_preorder_legalization
+//       CHECK:   "test.post_order_legalization"() ({
+//       CHECK:   ^{{.*}}(%[[arg0:.*]]: f64):
+//       CHECK:     "test.valid"(%[[arg0]]) : (f64) -> ()
+//       CHECK:   }) {is_legal} : () -> ()
+func.func @test_preorder_legalization() {
+  "test.post_order_legalization"() ({
+  ^bb0(%arg0: i64):
+    "test.invalid"(%arg0) : (i64) -> ()
+  }) : () -> ()
+  // expected-remark @+1 {{is not legalizable}}
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 12edecc113495..9b64bc691588d 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1418,6 +1418,22 @@ class TestTypeConsumerOpPattern
   }
 };
 
+class TestPostOrderLegalization : public ConversionPattern {
+public:
+  TestPostOrderLegalization(MLIRContext *ctx, const TypeConverter &converter)
+      : ConversionPattern(converter, "test.post_order_legalization", 1, ctx) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    for (Region &r : op->getRegions())
+      if (failed(rewriter.legalize(&r)))
+        return failure();
+    rewriter.modifyOpInPlace(
+        op, [&]() { op->setAttr("is_legal", rewriter.getUnitAttr()); });
+    return success();
+  }
+};
+
 /// Test unambiguous overload resolution of replaceOpWithMultiple. This
 /// function is just to trigger compiler errors. It is never executed.
 [[maybe_unused]] void testReplaceOpWithMultipleOverloads(
@@ -1532,7 +1548,8 @@ struct TestLegalizePatternDriver
     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
                  TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
                  TestValueReplace, TestReplaceWithValidConsumer,
-                 TestTypeConsumerOpPattern>(&getContext(), converter);
+                 TestTypeConsumerOpPattern, TestPostOrderLegalization>(
+        &getContext(), converter);
     patterns.add<TestConvertBlockArgs>(converter, &getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
@@ -1560,6 +1577,9 @@ struct TestLegalizePatternDriver
     target.addDynamicallyLegalOp(
         OperationName("test.value_replace", &getContext()),
         [](Operation *op) { return op->hasAttr("is_legal"); });
+    target.addDynamicallyLegalOp(
+        OperationName("test.post_order_legalization", &getContext()),
+        [](Operation *op) { return op->hasAttr("is_legal"); });
 
     // TestCreateUnregisteredOp creates `arith.constant` operation,
     // which was not added to target intentionally to test

@matthias-springer matthias-springer force-pushed the users/matthias-springer/post_order_legalization branch 2 times, most recently from 63eac17 to c949a6c Compare November 4, 2025 04:04
@matthias-springer matthias-springer force-pushed the users/matthias-springer/post_order_legalization branch from c949a6c to b4c8f69 Compare November 4, 2025 04:06
ops.push_back(&op);

// If the current pattern runs with a type converter, convert the entry block
// signature.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this case tested right now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, by this test case.

@matthias-springer matthias-springer merged commit a38e094 into main Nov 5, 2025
10 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/post_order_legalization branch November 5, 2025 12:04
@llvm-ci
Copy link
Collaborator

llvm-ci commented Nov 5, 2025

LLVM Buildbot has detected a new failure on builder mlir-nvidia-gcc7 running on mlir-nvidia while building mlir at step 7 "test-build-check-mlir-build-only-check-mlir".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/20572

Here is the relevant piece of the build log for the reference
Step 7 (test-build-check-mlir-build-only-check-mlir) failure: test (failure)
******************** TEST 'MLIR :: Integration/GPU/CUDA/async.mlir' FAILED ********************
Exit Code: 1

Command Output (stdout):
--
# RUN: at line 1
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -gpu-kernel-outlining  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm),nvvm-attach-target)'  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -gpu-async-region -gpu-to-llvm -reconcile-unrealized-casts -gpu-module-to-binary="format=fatbin"  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -async-to-async-runtime -async-runtime-ref-counting  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -convert-async-to-llvm -convert-func-to-llvm -convert-arith-to-llvm -convert-cf-to-llvm -reconcile-unrealized-casts  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-runner    --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_cuda_runtime.so    --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_async_runtime.so    --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_runner_utils.so    --entry-point-result=void -O0  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -gpu-kernel-outlining
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt '-pass-pipeline=builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm),nvvm-attach-target)'
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -gpu-async-region -gpu-to-llvm -reconcile-unrealized-casts -gpu-module-to-binary=format=fatbin
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -async-to-async-runtime -async-runtime-ref-counting
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -convert-async-to-llvm -convert-func-to-llvm -convert-arith-to-llvm -convert-cf-to-llvm -reconcile-unrealized-casts
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-runner --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_cuda_runtime.so --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_async_runtime.so --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_runner_utils.so --entry-point-result=void -O0
# .---command stderr------------
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventSynchronize(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# `-----------------------------
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# .---command stderr------------
# | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir:68:12: error: CHECK: expected string not found in input
# |  // CHECK: [84, 84]
# |            ^
# | <stdin>:1:1: note: scanning from here
# | Unranked Memref base@ = 0x57954c5010e0 rank = 1 offset = 0 sizes = [2] strides = [1] data = 
# | ^
# | <stdin>:2:1: note: possible intended match here
# | [42, 42]
# | ^
# | 
# | Input file: <stdin>
# | Check file: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# | 
# | -dump-input=help explains the following input dump.
# | 
# | Input was:
# | <<<<<<
# |             1: Unranked Memref base@ = 0x57954c5010e0 rank = 1 offset = 0 sizes = [2] strides = [1] data =  
# | check:68'0     X~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ error: no match found
# |             2: [42, 42] 
# | check:68'0     ~~~~~~~~~
# | check:68'1     ?         possible intended match
...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:core MLIR Core Infrastructure mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants