diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index d0cd879d560c8..8040fa283bdcd 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -228,6 +228,22 @@ transform::TransformState::setPayloadOps(Value value, for (Operation *op : targets) mappings.reverse[op].push_back(value); +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + if (options.getExpensiveChecksEnabled()) { + for (Operation *op : targets) { + auto insertion = cachedNames.insert({op, op->getName()}); + if (!insertion.second) { + if (insertion.first->second != op->getName()) { + // Operation is already in the cache, but with a different name. + return emitError(value.getLoc()) + << "expensive checks failure: operation mismatch, expected " + << insertion.first->second; + } + } + } + } +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + return success(); } @@ -389,15 +405,20 @@ transform::TransformState::replacePayloadOp(Operation *op, #if LLVM_ENABLE_ABI_BREAKING_CHECKS if (options.getExpensiveChecksEnabled()) { auto it = cachedNames.find(op); - assert(it != cachedNames.end() && "entry not found"); - assert(it->second == op->getName() && "operation name mismatch"); - cachedNames.erase(it); - if (replacement) { - auto insertion = - cachedNames.insert({replacement, replacement->getName()}); - if (!insertion.second) { - assert(insertion.first->second == replacement->getName() && - "operation is already cached with a different name"); + // Payload ops (and their children) mapped to consumed handles were already + // removed from the cache. We can make no assumption about which ops are in + // the cache and which are not. But if an op is in the cache, the name must + // match. + if (it != cachedNames.end()) { + assert(it->second == op->getName() && "operation name mismatch"); + cachedNames.erase(it); + if (replacement) { + auto insertion = + cachedNames.insert({replacement, replacement->getName()}); + if (!insertion.second) { + assert(insertion.first->second == replacement->getName() && + "operation is already cached with a different name"); + } } } } @@ -908,9 +929,6 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { // IR after that. SmallVector origOpFlatResults; SmallVector origAssociatedOps; -#if LLVM_ENABLE_ABI_BREAKING_CHECKS - DenseSet consumedPayloadOps; -#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS for (OpOperand *opOperand : consumedOperands) { Value operand = opOperand->get(); if (llvm::isa(operand.getType())) { @@ -918,10 +936,9 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { llvm::append_range(origOpFlatResults, payloadOp->getResults()); #if LLVM_ENABLE_ABI_BREAKING_CHECKS if (options.getExpensiveChecksEnabled()) { - // Store all consumed payload ops (and their nested ops) in a set for - // extra error checking. - payloadOp->walk( - [&](Operation *op) { consumedPayloadOps.insert(op); }); + // Remove all consumed payload ops (and their nested ops) from the + // name cache. + payloadOp->walk([&](Operation *op) { cachedNames.erase(op); }); } #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS } @@ -1004,46 +1021,23 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { } } + if (failed(updateStateFromResults(results, transform->getResults()))) + return DiagnosedSilenceableFailure::definiteFailure(); + #if LLVM_ENABLE_ABI_BREAKING_CHECKS if (options.getExpensiveChecksEnabled()) { - // Remove erased ops from the transform state. - for (Operation *op : consumedPayloadOps) { - // This payload op was consumed but it may still be mapped to one or - // multiple handles. Forget all handles that are mapped to the op, so that - // there are no dangling pointers in the transform dialect state. This is - // necessary so that the `cachedNames`-based checks work correctly. - // - // Note: Dangling pointers to erased payload ops are allowed if the - // corresponding handles are not used anymore. There is another - // "expensive-check" that looks for future uses of dangling payload op - // pointers (through arbitrary handles). Removing handles to erased ops - // does not interfere with the other expensive checks: handle invalidation - // happens earlier and keeps track of invalidated handles with - // pre-generated error messages, so we do not need the association to - // still be there when the invalidated handle is accessed. - SmallVector handles; - (void)getHandlesForPayloadOp(op, handles, /*includeOutOfScope=*/true); - for (Value handle : handles) - forgetMapping(handle, /*origOpFlatResults=*/ValueRange(), - /*allowOutOfScope=*/true); - cachedNames.erase(op); - } - // Check cached operation names. for (std::unique_ptr &mapping : llvm::make_second_range(mappings)) { for (Operation *op : llvm::make_first_range(mapping->reverse)) { // Make sure that the name of the op has not changed. If it has changed, // the op was removed and a new op was allocated at the same memory - // location. This means that we are missing op tracking somewhere. + // location. This means that we are missing op tracking somewhere. We + // can make no assumption about which ops are in the cache and which are + // not. But if an op is in the cache, the name must match. auto cacheIt = cachedNames.find(op); - if (cacheIt == cachedNames.end()) { - DiagnosedDefiniteFailure diag = - emitDefiniteFailure(transform->getLoc()) - << "expensive checks failure: operation not found in cache"; - diag.attachNote(op->getLoc()) << "payload op"; - return diag; - } + if (cacheIt == cachedNames.end()) + continue; // If the `getName` call (or the above `attachNote`) is crashing, we // have a dangling pointer. This usually means that an op was erased but // the transform dialect was not made aware of that; e.g., missing @@ -1061,9 +1055,6 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { } #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS - if (failed(updateStateFromResults(results, transform->getResults()))) - return DiagnosedSilenceableFailure::definiteFailure(); - printOnFailureRAII.release(); DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, { DBGS() << "Top-level payload:\n"; @@ -1140,26 +1131,9 @@ transform::TransformState::RegionScope::~RegionScope() { } } -#if LLVM_ENABLE_ABI_BREAKING_CHECKS - // Remember pointers to payload ops referenced by the handles going out of - // scope. - SmallVector referencedOps = - llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse)); -#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS - state.mappings.erase(region); #if LLVM_ENABLE_ABI_BREAKING_CHECKS - // If the last handle to a payload op has gone out of scope, we no longer - // need to store the cached name. Pointers may get reused, leading to - // incorrect associations in the cache. - for (Operation *op : referencedOps) { - SmallVector handles; - if (succeeded(state.getHandlesForPayloadOp(op, handles))) - continue; - state.cachedNames.erase(op); - } - state.regionStack.pop_back(); #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS }