3030
3131using namespace mlir ;
3232
33+ // ===----------------------------------------------------------------------===//
34+ // Helper functions
35+ // ===----------------------------------------------------------------------===//
36+
37+ // / Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
38+ // / properly dominates `b` and `b` is not inside `a`.
39+ static bool happensBefore (Operation *a, Operation *b) {
40+ do {
41+ if (a->isProperAncestor (b))
42+ return false ;
43+ if (Operation *bAncestor = a->getBlock ()->findAncestorOpInBlock (*b)) {
44+ return a->isBeforeInBlock (bAncestor);
45+ }
46+ } while ((a = a->getParentOp ()));
47+ return false ;
48+ }
49+
3350// ===----------------------------------------------------------------------===//
3451// TransformState
3552// ===----------------------------------------------------------------------===//
@@ -44,14 +61,10 @@ transform::TransformState::TransformState(
4461 topLevelMappedValues.reserve (extraMappings.size ());
4562 for (ArrayRef<MappedValue> mapping : extraMappings)
4663 topLevelMappedValues.push_back (mapping);
47-
48- auto result =
49- mappings.insert (std::make_pair (region, std::make_unique<Mappings>()));
50- assert (result.second && " the region scope is already present" );
51- (void )result;
52- #if LLVM_ENABLE_ABI_BREAKING_CHECKS
53- regionStack.push_back (region);
54- #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
64+ if (region) {
65+ RegionScope *scope = new RegionScope (*this , *region);
66+ topLevelRegionScope.reset (scope);
67+ }
5568}
5669
5770Operation *transform::TransformState::getTopLevel () const { return topLevel; }
@@ -811,6 +824,11 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
811824 LLVM_DEBUG (DBGS () << " Failing Top-level payload:\n " ; getTopLevel ()->print (
812825 llvm::dbgs (), mlir::OpPrintingFlags ().printGenericOpForm ()););
813826 });
827+
828+ // Set current transform op.
829+ regionStack.back ()->currentTransform = transform;
830+
831+ // Expensive checks to detect invalid transform IR.
814832 if (options.getExpensiveChecksEnabled ()) {
815833 FULL_LDBG (" ExpensiveChecksEnabled\n " );
816834 if (failed (checkAndRecordHandleInvalidation (transform)))
@@ -899,7 +917,24 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
899917 }
900918
901919 // Prepare rewriter and listener.
902- transform::ErrorCheckingTrackingListener trackingListener (*this , transform);
920+ TrackingListener::SkipHandleFn skipHandleFn = [&](Value handle) {
921+ // Skip handle if it is dead.
922+ auto scopeIt =
923+ llvm::find_if (llvm::reverse (regionStack), [&](RegionScope *scope) {
924+ return handle.getParentRegion () == scope->region ;
925+ });
926+ assert (scopeIt != regionStack.rend () &&
927+ " could not find region scope for handle" );
928+ RegionScope *scope = *scopeIt;
929+ for (Operation *user : handle.getUsers ()) {
930+ if (user != scope->currentTransform &&
931+ !happensBefore (user, scope->currentTransform ))
932+ return false ;
933+ }
934+ return true ;
935+ };
936+ transform::ErrorCheckingTrackingListener trackingListener (*this , transform,
937+ skipHandleFn);
903938 transform::TransformRewriter rewriter (transform->getContext (),
904939 &trackingListener);
905940
@@ -1040,10 +1075,7 @@ transform::TransformState::RegionScope::~RegionScope() {
10401075#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
10411076
10421077 state.mappings .erase (region);
1043-
1044- #if LLVM_ENABLE_ABI_BREAKING_CHECKS
10451078 state.regionStack .pop_back ();
1046- #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
10471079}
10481080
10491081// ===----------------------------------------------------------------------===//
@@ -1150,8 +1182,10 @@ bool transform::TransformResults::isSet(unsigned resultNumber) const {
11501182// ===----------------------------------------------------------------------===//
11511183
11521184transform::TrackingListener::TrackingListener (TransformState &state,
1153- TransformOpInterface op)
1154- : TransformState::Extension(state), transformOp(op) {
1185+ TransformOpInterface op,
1186+ SkipHandleFn skipHandleFn)
1187+ : TransformState::Extension(state), transformOp(op),
1188+ skipHandleFn(skipHandleFn) {
11551189 if (op) {
11561190 for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands ()) {
11571191 consumedHandles.insert (opOperand->get ());
@@ -1251,19 +1285,6 @@ void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
12511285 });
12521286}
12531287
1254- // / Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
1255- // / properly dominates `b` and `b` is not inside `a`.
1256- static bool happensBefore (Operation *a, Operation *b) {
1257- do {
1258- if (a->isProperAncestor (b))
1259- return false ;
1260- if (Operation *bAncestor = a->getBlock ()->findAncestorOpInBlock (*b)) {
1261- return a->isBeforeInBlock (bAncestor);
1262- }
1263- } while ((a = a->getParentOp ()));
1264- return false ;
1265- }
1266-
12671288void transform::TrackingListener::notifyOperationReplaced (
12681289 Operation *op, ValueRange newValues) {
12691290 assert (op->getNumResults () == newValues.size () &&
@@ -1295,18 +1316,17 @@ void transform::TrackingListener::notifyOperationReplaced(
12951316 [&](Value h) { return consumedHandles.contains (h); });
12961317 };
12971318
1298- // Helper function to check if the handle is alive.
1299- auto firstAliveUser = [&]() -> std::optional<OpOperand *> {
1300- for (Value v : opHandles) {
1301- for (OpOperand &use : v.getUses ())
1302- if (use.getOwner () != transformOp &&
1303- !happensBefore (use.getOwner (), transformOp))
1304- return &use;
1305- }
1306- return std::nullopt ;
1307- }();
1308-
1309- if (!firstAliveUser.has_value () || handleWasConsumed ()) {
1319+ // Check if there are any handles that must be updated.
1320+ Value aliveHandle;
1321+ if (skipHandleFn) {
1322+ auto it =
1323+ llvm::find_if (opHandles, [&](Value v) { return !skipHandleFn (v); });
1324+ if (it != opHandles.end ())
1325+ aliveHandle = *it;
1326+ } else if (!opHandles.empty ()) {
1327+ aliveHandle = opHandles.front ();
1328+ }
1329+ if (!aliveHandle || handleWasConsumed ()) {
13101330 // The op is tracked but the corresponding handles are dead or were
13111331 // consumed. Drop the op form the mapping.
13121332 (void )replacePayloadOp (op, nullptr );
@@ -1319,10 +1339,8 @@ void transform::TrackingListener::notifyOperationReplaced(
13191339 // If the op is tracked but no replacement op was found, send a
13201340 // notification.
13211341 if (!diag.succeeded ()) {
1322- diag.attachNote ((*firstAliveUser)->getOwner ()->getLoc ())
1323- << " replacement is required because alive handle(s) exist "
1324- << " (first use in this op as operand number "
1325- << (*firstAliveUser)->getOperandNumber () << " )" ;
1342+ diag.attachNote (aliveHandle.getLoc ())
1343+ << " replacement is required because this handle must be updated" ;
13261344 notifyPayloadReplacementNotFound (op, newValues, std::move (diag));
13271345 (void )replacePayloadOp (op, nullptr );
13281346 return ;
0 commit comments