diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 857b601acbc35..4165e0a52428f 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1229,9 +1229,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( } bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { - // Check to see if this operation was replaced or its parent ignored. - return ignoredOps.count(op->getParentOp()) || - hasRewrite(rewrites, op); + // Check to see if this operation or the parent operation is ignored. + return ignoredOps.count(op->getParentOp()) || ignoredOps.count(op); } void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) { @@ -1480,12 +1479,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, ValueRange newValues) { assert(newValues.size() == op->getNumResults()); -#ifndef NDEBUG - for (auto &rewrite : rewrites) - if (auto *opReplacement = dyn_cast(rewrite.get())) - assert(opReplacement->getOperation() != op && - "operation was already replaced"); -#endif // NDEBUG + assert(!ignoredOps.contains(op) && "operation was already replaced"); // Track if any of the results changed, e.g. erased and replaced with null. bool resultChanged = false; @@ -1506,6 +1500,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, // Mark this operation as recursively ignored so that we don't need to // convert any nested operations. + ignoredOps.insert(op); markNestedOpsIgnored(op); }