From 4730954aa66fe31516b37e04f6b1c609236718f7 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 27 Nov 2023 23:30:35 +0100 Subject: [PATCH] [mlir][transform] Fix and improve "cached names" check When running with "expensive checks", the transform dialect interpreter maintains a payload `Operation *` -> `OperationName` cache. This cache is used to detect invalid API usage such as missing/incorrect handle consumption/production side effects and/or payload IR modifications that bypass the rewriter. There was a bug in the check that can cause issues such as #72931. (#72986 was just a workaround and did not really fix the underlying issue.) - Payload ops mapped to newly produced handles are now added to the cache. This is in addition to adding/checking all mapped payload ops at the beginning of each transform op, for extra safety. - Remove consumed ops (and their children) before applying the transform op. This used to happen after applying the transform op, which is incorrect in cases such as: (1) transform op replaces a consumed payload op with another op, (2) the new op reuses the same memory pointer and (3) the new op is added to a newly produced handle. In such a case the previous implementation removed the newly created op from the cache. - No assumptions can be made about whether an op should be in the cache or not. The code previously reported an error when an op was not found in the cache. E.g., this is problematic in cases such as: (1) the transform op consumes the handle mapped to a payload op A and (2) the implementation of the payload op removes/replaces a nested op with A, which is mapped to another handle. This triggers a listener notification, which removes the nested op from the cache. However, because consumed ops (and their children) are removed from the cache before applying the transform op, the nested op will not be in cache and making such an assumption would be incorrect. --- .../Transform/IR/TransformInterfaces.cpp | 108 +++++++----------- 1 file changed, 41 insertions(+), 67 deletions(-) diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index d0cd879d560c8..8040fa283bdcd 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -228,6 +228,22 @@ transform::TransformState::setPayloadOps(Value value, for (Operation *op : targets) mappings.reverse[op].push_back(value); +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + if (options.getExpensiveChecksEnabled()) { + for (Operation *op : targets) { + auto insertion = cachedNames.insert({op, op->getName()}); + if (!insertion.second) { + if (insertion.first->second != op->getName()) { + // Operation is already in the cache, but with a different name. + return emitError(value.getLoc()) + << "expensive checks failure: operation mismatch, expected " + << insertion.first->second; + } + } + } + } +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + return success(); } @@ -389,15 +405,20 @@ transform::TransformState::replacePayloadOp(Operation *op, #if LLVM_ENABLE_ABI_BREAKING_CHECKS if (options.getExpensiveChecksEnabled()) { auto it = cachedNames.find(op); - assert(it != cachedNames.end() && "entry not found"); - assert(it->second == op->getName() && "operation name mismatch"); - cachedNames.erase(it); - if (replacement) { - auto insertion = - cachedNames.insert({replacement, replacement->getName()}); - if (!insertion.second) { - assert(insertion.first->second == replacement->getName() && - "operation is already cached with a different name"); + // Payload ops (and their children) mapped to consumed handles were already + // removed from the cache. We can make no assumption about which ops are in + // the cache and which are not. But if an op is in the cache, the name must + // match. + if (it != cachedNames.end()) { + assert(it->second == op->getName() && "operation name mismatch"); + cachedNames.erase(it); + if (replacement) { + auto insertion = + cachedNames.insert({replacement, replacement->getName()}); + if (!insertion.second) { + assert(insertion.first->second == replacement->getName() && + "operation is already cached with a different name"); + } } } } @@ -908,9 +929,6 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { // IR after that. SmallVector origOpFlatResults; SmallVector origAssociatedOps; -#if LLVM_ENABLE_ABI_BREAKING_CHECKS - DenseSet consumedPayloadOps; -#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS for (OpOperand *opOperand : consumedOperands) { Value operand = opOperand->get(); if (llvm::isa(operand.getType())) { @@ -918,10 +936,9 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { llvm::append_range(origOpFlatResults, payloadOp->getResults()); #if LLVM_ENABLE_ABI_BREAKING_CHECKS if (options.getExpensiveChecksEnabled()) { - // Store all consumed payload ops (and their nested ops) in a set for - // extra error checking. - payloadOp->walk( - [&](Operation *op) { consumedPayloadOps.insert(op); }); + // Remove all consumed payload ops (and their nested ops) from the + // name cache. + payloadOp->walk([&](Operation *op) { cachedNames.erase(op); }); } #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS } @@ -1004,46 +1021,23 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { } } + if (failed(updateStateFromResults(results, transform->getResults()))) + return DiagnosedSilenceableFailure::definiteFailure(); + #if LLVM_ENABLE_ABI_BREAKING_CHECKS if (options.getExpensiveChecksEnabled()) { - // Remove erased ops from the transform state. - for (Operation *op : consumedPayloadOps) { - // This payload op was consumed but it may still be mapped to one or - // multiple handles. Forget all handles that are mapped to the op, so that - // there are no dangling pointers in the transform dialect state. This is - // necessary so that the `cachedNames`-based checks work correctly. - // - // Note: Dangling pointers to erased payload ops are allowed if the - // corresponding handles are not used anymore. There is another - // "expensive-check" that looks for future uses of dangling payload op - // pointers (through arbitrary handles). Removing handles to erased ops - // does not interfere with the other expensive checks: handle invalidation - // happens earlier and keeps track of invalidated handles with - // pre-generated error messages, so we do not need the association to - // still be there when the invalidated handle is accessed. - SmallVector handles; - (void)getHandlesForPayloadOp(op, handles, /*includeOutOfScope=*/true); - for (Value handle : handles) - forgetMapping(handle, /*origOpFlatResults=*/ValueRange(), - /*allowOutOfScope=*/true); - cachedNames.erase(op); - } - // Check cached operation names. for (std::unique_ptr &mapping : llvm::make_second_range(mappings)) { for (Operation *op : llvm::make_first_range(mapping->reverse)) { // Make sure that the name of the op has not changed. If it has changed, // the op was removed and a new op was allocated at the same memory - // location. This means that we are missing op tracking somewhere. + // location. This means that we are missing op tracking somewhere. We + // can make no assumption about which ops are in the cache and which are + // not. But if an op is in the cache, the name must match. auto cacheIt = cachedNames.find(op); - if (cacheIt == cachedNames.end()) { - DiagnosedDefiniteFailure diag = - emitDefiniteFailure(transform->getLoc()) - << "expensive checks failure: operation not found in cache"; - diag.attachNote(op->getLoc()) << "payload op"; - return diag; - } + if (cacheIt == cachedNames.end()) + continue; // If the `getName` call (or the above `attachNote`) is crashing, we // have a dangling pointer. This usually means that an op was erased but // the transform dialect was not made aware of that; e.g., missing @@ -1061,9 +1055,6 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { } #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS - if (failed(updateStateFromResults(results, transform->getResults()))) - return DiagnosedSilenceableFailure::definiteFailure(); - printOnFailureRAII.release(); DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, { DBGS() << "Top-level payload:\n"; @@ -1140,26 +1131,9 @@ transform::TransformState::RegionScope::~RegionScope() { } } -#if LLVM_ENABLE_ABI_BREAKING_CHECKS - // Remember pointers to payload ops referenced by the handles going out of - // scope. - SmallVector referencedOps = - llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse)); -#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS - state.mappings.erase(region); #if LLVM_ENABLE_ABI_BREAKING_CHECKS - // If the last handle to a payload op has gone out of scope, we no longer - // need to store the cached name. Pointers may get reused, leading to - // incorrect associations in the cache. - for (Operation *op : referencedOps) { - SmallVector handles; - if (succeeded(state.getHandlesForPayloadOp(op, handles))) - continue; - state.cachedNames.erase(op); - } - state.regionStack.pop_back(); #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS }