diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index 2fdc15db9ad85..23871cc16d87d 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -310,10 +310,8 @@ class TransformState { /// with the type of the handle value. LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef operations) { -#if LLVM_ENABLE_ABI_BREAKING_CHECKS - assert(argument.getParentRegion() == regionStack.back() && + assert(argument.getParentRegion() == regionStack.back()->region && "mapping block arguments from a region other than the active one"); -#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS return setPayloadOps(argument, operations); } LogicalResult mapBlockArgument(BlockArgument argument, @@ -350,9 +348,7 @@ class TransformState { std::make_pair(®ion, std::make_unique())); assert(res.second && "the region scope is already present"); (void)res; -#if LLVM_ENABLE_ABI_BREAKING_CHECKS - state.regionStack.push_back(®ion); -#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + state.regionStack.push_back(this); } /// Back-reference to the transform state. @@ -361,7 +357,10 @@ class TransformState { /// The region this scope is associated with. Region *region; - friend RegionScope TransformState::make_region_scope(Region &); + /// The transform op within this region that is currently being applied. + TransformOpInterface currentTransform; + + friend class transform::TransformState; }; friend class RegionScope; @@ -784,12 +783,14 @@ class TransformState { /// location. InvalidatedHandleMap invalidatedHandles; -#if LLVM_ENABLE_ABI_BREAKING_CHECKS /// A stack of nested regions that are being processed in the transform IR. /// Each region must be an ancestor of the following regions in this list. /// These are also the keys for "mappings". - SmallVector regionStack; -#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + SmallVector regionStack; + + /// The top-level region scope. The first (bottom) element of `regionStack` + /// is the top-level region scope object. + std::unique_ptr topLevelRegionScope; }; /// Local mapping between values defined by a specific op implementing the @@ -926,8 +927,14 @@ TransformState::RegionScope TransformState::make_region_scope(Region ®ion) { class TrackingListener : public RewriterBase::Listener, public TransformState::Extension { public: + /// A function that returns "true" for handles that do not have to be updated. + using SkipHandleFn = std::function; + /// Create a new TrackingListener for usage in the specified transform op. - TrackingListener(TransformState &state, TransformOpInterface op); + /// Optionally, a function can be specified to identify handles that should + /// do not have to be updated. + TrackingListener(TransformState &state, TransformOpInterface op, + SkipHandleFn skipHandleFn = nullptr); protected: /// Return a replacement payload op for the given op, which is going to be @@ -1015,6 +1022,10 @@ class TrackingListener : public RewriterBase::Listener, /// The handles that are consumed by the transform op. DenseSet consumedHandles; + + /// Handles for which this function evaluates to "true" do not have to be + /// updated. These are typically dead or consumed handles. + SkipHandleFn skipHandleFn; }; /// A specialized listener that keeps track of cases in which no replacement diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index de5b7a81286bc..cd66a0e566f6c 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -30,6 +30,23 @@ using namespace mlir; +//===----------------------------------------------------------------------===// +// Helper functions +//===----------------------------------------------------------------------===// + +/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors +/// properly dominates `b` and `b` is not inside `a`. +static bool happensBefore(Operation *a, Operation *b) { + do { + if (a->isProperAncestor(b)) + return false; + if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) { + return a->isBeforeInBlock(bAncestor); + } + } while ((a = a->getParentOp())); + return false; +} + //===----------------------------------------------------------------------===// // TransformState //===----------------------------------------------------------------------===// @@ -44,14 +61,10 @@ transform::TransformState::TransformState( topLevelMappedValues.reserve(extraMappings.size()); for (ArrayRef mapping : extraMappings) topLevelMappedValues.push_back(mapping); - - auto result = - mappings.insert(std::make_pair(region, std::make_unique())); - assert(result.second && "the region scope is already present"); - (void)result; -#if LLVM_ENABLE_ABI_BREAKING_CHECKS - regionStack.push_back(region); -#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + if (region) { + RegionScope *scope = new RegionScope(*this, *region); + topLevelRegionScope.reset(scope); + } } Operation *transform::TransformState::getTopLevel() const { return topLevel; } @@ -811,6 +824,11 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print( llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm());); }); + + // Set current transform op. + regionStack.back()->currentTransform = transform; + + // Expensive checks to detect invalid transform IR. if (options.getExpensiveChecksEnabled()) { FULL_LDBG("ExpensiveChecksEnabled\n"); if (failed(checkAndRecordHandleInvalidation(transform))) @@ -899,7 +917,24 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { } // Prepare rewriter and listener. - transform::ErrorCheckingTrackingListener trackingListener(*this, transform); + TrackingListener::SkipHandleFn skipHandleFn = [&](Value handle) { + // Skip handle if it is dead. + auto scopeIt = + llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) { + return handle.getParentRegion() == scope->region; + }); + assert(scopeIt != regionStack.rend() && + "could not find region scope for handle"); + RegionScope *scope = *scopeIt; + for (Operation *user : handle.getUsers()) { + if (user != scope->currentTransform && + !happensBefore(user, scope->currentTransform)) + return false; + } + return true; + }; + transform::ErrorCheckingTrackingListener trackingListener(*this, transform, + skipHandleFn); transform::TransformRewriter rewriter(transform->getContext(), &trackingListener); @@ -1040,10 +1075,7 @@ transform::TransformState::RegionScope::~RegionScope() { #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS state.mappings.erase(region); - -#if LLVM_ENABLE_ABI_BREAKING_CHECKS state.regionStack.pop_back(); -#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS } //===----------------------------------------------------------------------===// @@ -1150,8 +1182,10 @@ bool transform::TransformResults::isSet(unsigned resultNumber) const { //===----------------------------------------------------------------------===// transform::TrackingListener::TrackingListener(TransformState &state, - TransformOpInterface op) - : TransformState::Extension(state), transformOp(op) { + TransformOpInterface op, + SkipHandleFn skipHandleFn) + : TransformState::Extension(state), transformOp(op), + skipHandleFn(skipHandleFn) { if (op) { for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) { consumedHandles.insert(opOperand->get()); @@ -1251,19 +1285,6 @@ void transform::TrackingListener::notifyOperationRemoved(Operation *op) { }); } -/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors -/// properly dominates `b` and `b` is not inside `a`. -static bool happensBefore(Operation *a, Operation *b) { - do { - if (a->isProperAncestor(b)) - return false; - if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) { - return a->isBeforeInBlock(bAncestor); - } - } while ((a = a->getParentOp())); - return false; -} - void transform::TrackingListener::notifyOperationReplaced( Operation *op, ValueRange newValues) { assert(op->getNumResults() == newValues.size() && @@ -1295,18 +1316,17 @@ void transform::TrackingListener::notifyOperationReplaced( [&](Value h) { return consumedHandles.contains(h); }); }; - // Helper function to check if the handle is alive. - auto firstAliveUser = [&]() -> std::optional { - for (Value v : opHandles) { - for (OpOperand &use : v.getUses()) - if (use.getOwner() != transformOp && - !happensBefore(use.getOwner(), transformOp)) - return &use; - } - return std::nullopt; - }(); - - if (!firstAliveUser.has_value() || handleWasConsumed()) { + // Check if there are any handles that must be updated. + Value aliveHandle; + if (skipHandleFn) { + auto it = + llvm::find_if(opHandles, [&](Value v) { return !skipHandleFn(v); }); + if (it != opHandles.end()) + aliveHandle = *it; + } else if (!opHandles.empty()) { + aliveHandle = opHandles.front(); + } + if (!aliveHandle || handleWasConsumed()) { // The op is tracked but the corresponding handles are dead or were // consumed. Drop the op form the mapping. (void)replacePayloadOp(op, nullptr); @@ -1319,10 +1339,8 @@ void transform::TrackingListener::notifyOperationReplaced( // If the op is tracked but no replacement op was found, send a // notification. if (!diag.succeeded()) { - diag.attachNote((*firstAliveUser)->getOwner()->getLoc()) - << "replacement is required because alive handle(s) exist " - << "(first use in this op as operand number " - << (*firstAliveUser)->getOperandNumber() << ")"; + diag.attachNote(aliveHandle.getLoc()) + << "replacement is required because this handle must be updated"; notifyPayloadReplacementNotFound(op, newValues, std::move(diag)); (void)replacePayloadOp(op, nullptr); return; diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir index 2d57d4aa2547f..2fd47c6bae396 100644 --- a/mlir/test/Dialect/Transform/test-pattern-application.mlir +++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir @@ -36,6 +36,7 @@ func.func @replacement_op_not_found() { transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-note @below {{replacement is required because this handle must be updated}} %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op // expected-error @below {{tracking listener failed to find replacement op during application of this transform op}} // expected-note @below {{ran out of suitable replacement values}} @@ -44,7 +45,6 @@ transform.sequence failures(propagate) { } : !transform.any_op // %1 must be used in some way. If no replacement payload op could be found, // an error is thrown only if the handle is not dead. - // expected-note @below {{replacement is required because alive handle(s) exist (first use in this op as operand number 0)}} transform.annotate %1 "annotated" : !transform.any_op } @@ -363,3 +363,31 @@ transform.sequence failures(propagate) { legal_ops = ["func.func", "func.return", "test.new_op"]} : !transform.any_op } + +// ----- + +module attributes { transform.with_named_sequence } { +func.func @replacement_op_not_found() { + // No op replacement can be found, but there are no handles that must be + // updated. No error should be reported. + "test.container"() ({ + %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32) + }) : () -> () + return +} + +transform.named_sequence @patterns(%container: !transform.any_op {transform.readonly}) { + transform.apply_patterns to %container { + transform.apply_patterns.transform.test_patterns + } : !transform.any_op + transform.yield +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.annotate %1 "annotated" : !transform.any_op + transform.include @patterns failures(propagate) (%0) : (!transform.any_op) -> () +} +}