From 238322164f0c973318ddfcb9d66bcc5b05e7546c Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 26 Feb 2024 19:41:55 +0000 Subject: [PATCH] [mlir][Transforms] Track erased ops separately BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC --- .../Transforms/Utils/DialectConversion.cpp | 23 ++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 4165e0a52428f..f967e8352bf4c 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -153,9 +153,9 @@ namespace { /// This is useful when saving and undoing a set of rewrites. struct RewriterState { RewriterState(unsigned numRewrites, unsigned numIgnoredOperations, - unsigned numErased) + unsigned numErased, unsigned numReplacedOps) : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations), - numErased(numErased) {} + numErased(numErased), numReplacedOps(numReplacedOps) {} /// The current number of rewrites performed. unsigned numRewrites; @@ -165,6 +165,9 @@ struct RewriterState { /// The current number of erased operations/blocks. unsigned numErased; + + /// The current number of replaced ops that are scheduled for erasure. + unsigned numReplacedOps; }; //===----------------------------------------------------------------------===// @@ -954,6 +957,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// operation was ignored. SetVector ignoredOps; + // A set of operations that were erased. + SetVector replacedOps; + /// The current type converter, or nullptr if no type converter is currently /// active. const TypeConverter *currentTypeConverter = nullptr; @@ -1152,7 +1158,7 @@ void ConversionPatternRewriterImpl::applyRewrites() { RewriterState ConversionPatternRewriterImpl::getCurrentState() { return RewriterState(rewrites.size(), ignoredOps.size(), - eraseRewriter.erased.size()); + eraseRewriter.erased.size(), replacedOps.size()); } void ConversionPatternRewriterImpl::resetState(RewriterState state) { @@ -1165,6 +1171,9 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) { while (eraseRewriter.erased.size() != state.numErased) eraseRewriter.erased.pop_back(); + + while (replacedOps.size() != state.numReplacedOps) + replacedOps.pop_back(); } void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) { @@ -1228,9 +1237,11 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( return success(); } +// TODO: This function is a misnomer. It does not actually check if `op` is in +// `ignoredOps`. bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { // Check to see if this operation or the parent operation is ignored. - return ignoredOps.count(op->getParentOp()) || ignoredOps.count(op); + return ignoredOps.count(op->getParentOp()) || replacedOps.count(op); } void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) { @@ -1479,7 +1490,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, ValueRange newValues) { assert(newValues.size() == op->getNumResults()); - assert(!ignoredOps.contains(op) && "operation was already replaced"); + assert(!replacedOps.contains(op) && "operation was already replaced"); // Track if any of the results changed, e.g. erased and replaced with null. bool resultChanged = false; @@ -1500,7 +1511,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, // Mark this operation as recursively ignored so that we don't need to // convert any nested operations. - ignoredOps.insert(op); + replacedOps.insert(op); markNestedOpsIgnored(op); }