Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 41 additions & 67 deletions mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,22 @@ transform::TransformState::setPayloadOps(Value value,
for (Operation *op : targets)
mappings.reverse[op].push_back(value);

#if LLVM_ENABLE_ABI_BREAKING_CHECKS
if (options.getExpensiveChecksEnabled()) {
for (Operation *op : targets) {
auto insertion = cachedNames.insert({op, op->getName()});
if (!insertion.second) {
if (insertion.first->second != op->getName()) {
// Operation is already in the cache, but with a different name.
return emitError(value.getLoc())
<< "expensive checks failure: operation mismatch, expected "
<< insertion.first->second;
}
}
}
}
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS

return success();
}

Expand Down Expand Up @@ -389,15 +405,20 @@ transform::TransformState::replacePayloadOp(Operation *op,
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
if (options.getExpensiveChecksEnabled()) {
auto it = cachedNames.find(op);
assert(it != cachedNames.end() && "entry not found");
assert(it->second == op->getName() && "operation name mismatch");
cachedNames.erase(it);
if (replacement) {
auto insertion =
cachedNames.insert({replacement, replacement->getName()});
if (!insertion.second) {
assert(insertion.first->second == replacement->getName() &&
"operation is already cached with a different name");
// Payload ops (and their children) mapped to consumed handles were already
// removed from the cache. We can make no assumption about which ops are in
// the cache and which are not. But if an op is in the cache, the name must
// match.
if (it != cachedNames.end()) {
assert(it->second == op->getName() && "operation name mismatch");
cachedNames.erase(it);
if (replacement) {
auto insertion =
cachedNames.insert({replacement, replacement->getName()});
if (!insertion.second) {
assert(insertion.first->second == replacement->getName() &&
"operation is already cached with a different name");
}
}
}
}
Expand Down Expand Up @@ -908,20 +929,16 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
// IR after that.
SmallVector<Value> origOpFlatResults;
SmallVector<Operation *> origAssociatedOps;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
DenseSet<Operation *> consumedPayloadOps;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
for (OpOperand *opOperand : consumedOperands) {
Value operand = opOperand->get();
if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
for (Operation *payloadOp : getPayloadOps(operand)) {
llvm::append_range(origOpFlatResults, payloadOp->getResults());
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
if (options.getExpensiveChecksEnabled()) {
// Store all consumed payload ops (and their nested ops) in a set for
// extra error checking.
payloadOp->walk(
[&](Operation *op) { consumedPayloadOps.insert(op); });
// Remove all consumed payload ops (and their nested ops) from the
// name cache.
payloadOp->walk([&](Operation *op) { cachedNames.erase(op); });
}
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}
Expand Down Expand Up @@ -1004,46 +1021,23 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
}
}

if (failed(updateStateFromResults(results, transform->getResults())))
return DiagnosedSilenceableFailure::definiteFailure();

#if LLVM_ENABLE_ABI_BREAKING_CHECKS
if (options.getExpensiveChecksEnabled()) {
// Remove erased ops from the transform state.
for (Operation *op : consumedPayloadOps) {
// This payload op was consumed but it may still be mapped to one or
// multiple handles. Forget all handles that are mapped to the op, so that
// there are no dangling pointers in the transform dialect state. This is
// necessary so that the `cachedNames`-based checks work correctly.
//
// Note: Dangling pointers to erased payload ops are allowed if the
// corresponding handles are not used anymore. There is another
// "expensive-check" that looks for future uses of dangling payload op
// pointers (through arbitrary handles). Removing handles to erased ops
// does not interfere with the other expensive checks: handle invalidation
// happens earlier and keeps track of invalidated handles with
// pre-generated error messages, so we do not need the association to
// still be there when the invalidated handle is accessed.
SmallVector<Value> handles;
(void)getHandlesForPayloadOp(op, handles, /*includeOutOfScope=*/true);
for (Value handle : handles)
forgetMapping(handle, /*origOpFlatResults=*/ValueRange(),
/*allowOutOfScope=*/true);
cachedNames.erase(op);
}

// Check cached operation names.
for (std::unique_ptr<Mappings> &mapping :
llvm::make_second_range(mappings)) {
for (Operation *op : llvm::make_first_range(mapping->reverse)) {
// Make sure that the name of the op has not changed. If it has changed,
// the op was removed and a new op was allocated at the same memory
// location. This means that we are missing op tracking somewhere.
// location. This means that we are missing op tracking somewhere. We
// can make no assumption about which ops are in the cache and which are
// not. But if an op is in the cache, the name must match.
auto cacheIt = cachedNames.find(op);
if (cacheIt == cachedNames.end()) {
DiagnosedDefiniteFailure diag =
emitDefiniteFailure(transform->getLoc())
<< "expensive checks failure: operation not found in cache";
diag.attachNote(op->getLoc()) << "payload op";
return diag;
}
if (cacheIt == cachedNames.end())
continue;
// If the `getName` call (or the above `attachNote`) is crashing, we
// have a dangling pointer. This usually means that an op was erased but
// the transform dialect was not made aware of that; e.g., missing
Expand All @@ -1061,9 +1055,6 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
}
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS

if (failed(updateStateFromResults(results, transform->getResults())))
return DiagnosedSilenceableFailure::definiteFailure();

printOnFailureRAII.release();
DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
DBGS() << "Top-level payload:\n";
Expand Down Expand Up @@ -1140,26 +1131,9 @@ transform::TransformState::RegionScope::~RegionScope() {
}
}

#if LLVM_ENABLE_ABI_BREAKING_CHECKS
// Remember pointers to payload ops referenced by the handles going out of
// scope.
SmallVector<Operation *> referencedOps =
llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS

state.mappings.erase(region);

#if LLVM_ENABLE_ABI_BREAKING_CHECKS
// If the last handle to a payload op has gone out of scope, we no longer
// need to store the cached name. Pointers may get reused, leading to
// incorrect associations in the cache.
for (Operation *op : referencedOps) {
SmallVector<Value> handles;
if (succeeded(state.getHandlesForPayloadOp(op, handles)))
continue;
state.cachedNames.erase(op);
}

state.regionStack.pop_back();
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}
Expand Down