diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index dec68048dc1d3..704597148dfac 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -152,17 +152,12 @@ namespace { /// This class contains a snapshot of the current conversion rewriter state. /// This is useful when saving and undoing a set of rewrites. struct RewriterState { - RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations, - unsigned numRewrites, unsigned numIgnoredOperations, - unsigned numErased) - : numCreatedOps(numCreatedOps), - numUnresolvedMaterializations(numUnresolvedMaterializations), + RewriterState(unsigned numUnresolvedMaterializations, unsigned numRewrites, + unsigned numIgnoredOperations, unsigned numErased) + : numUnresolvedMaterializations(numUnresolvedMaterializations), numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations), numErased(numErased) {} - /// The current number of created operations. - unsigned numCreatedOps; - /// The current number of unresolved materializations. unsigned numUnresolvedMaterializations; @@ -303,7 +298,8 @@ class IRRewrite { // Operation rewrites MoveOperation, ModifyOperation, - ReplaceOperation + ReplaceOperation, + CreateOperation }; virtual ~IRRewrite() = default; @@ -376,7 +372,10 @@ class CreateBlockRewrite : public BlockRewrite { auto &blockOps = block->getOperations(); while (!blockOps.empty()) blockOps.remove(blockOps.begin()); - eraseBlock(block); + if (block->getParent()) + eraseBlock(block); + else + delete block; } }; @@ -606,7 +605,7 @@ class OperationRewrite : public IRRewrite { static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() >= Kind::MoveOperation && - rewrite->getKind() <= Kind::ReplaceOperation; + rewrite->getKind() <= Kind::CreateOperation; } protected: @@ -740,6 +739,19 @@ class ReplaceOperationRewrite : public OperationRewrite { /// A boolean flag that indicates whether result types have changed or not. bool changedResults; }; + +class CreateOperationRewrite : public OperationRewrite { +public: + CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, + Operation *op) + : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::CreateOperation; + } + + void rollback() override; +}; } // namespace /// Return "true" if there is an operation rewrite that matches the specified @@ -957,9 +969,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { // replacing a value with one of a different type. ConversionValueMapping mapping; - /// Ordered vector of all of the newly created operations during conversion. - SmallVector createdOps; - /// Ordered vector of all unresolved type conversion materializations during /// conversion. SmallVector unresolvedMaterializations; @@ -1144,6 +1153,15 @@ void ReplaceOperationRewrite::rollback() { void ReplaceOperationRewrite::cleanup() { eraseOp(op); } +void CreateOperationRewrite::rollback() { + for (Region ®ion : op->getRegions()) { + while (!region.getBlocks().empty()) + region.getBlocks().remove(region.getBlocks().begin()); + } + op->dropAllUses(); + eraseOp(op); +} + void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) { for (Region ®ion : op->getRegions()) { for (Block &block : region.getBlocks()) { @@ -1161,8 +1179,6 @@ void ConversionPatternRewriterImpl::discardRewrites() { // Remove any newly created ops. for (UnresolvedMaterialization &materialization : unresolvedMaterializations) detachNestedAndErase(materialization.getOp()); - for (auto *op : llvm::reverse(createdOps)) - detachNestedAndErase(op); } void ConversionPatternRewriterImpl::applyRewrites() { @@ -1182,9 +1198,8 @@ void ConversionPatternRewriterImpl::applyRewrites() { // State Management RewriterState ConversionPatternRewriterImpl::getCurrentState() { - return RewriterState(createdOps.size(), unresolvedMaterializations.size(), - rewrites.size(), ignoredOps.size(), - eraseRewriter.erased.size()); + return RewriterState(unresolvedMaterializations.size(), rewrites.size(), + ignoredOps.size(), eraseRewriter.erased.size()); } void ConversionPatternRewriterImpl::resetState(RewriterState state) { @@ -1205,12 +1220,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) { detachNestedAndErase(op); } - // Pop all of the newly created operations. - while (createdOps.size() != state.numCreatedOps) { - detachNestedAndErase(createdOps.back()); - createdOps.pop_back(); - } - // Pop all of the recorded ignored operations that are no longer valid. while (ignoredOps.size() != state.numIgnoredOperations) ignoredOps.pop_back(); @@ -1478,7 +1487,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( }); if (!previous.isSet()) { // This is a newly created op. - createdOps.push_back(op); + appendRewrite(op); return; } Operation *prevOp = previous.getPoint() == previous.getBlock()->end() @@ -1979,13 +1988,16 @@ OperationLegalizer::legalizeWithFold(Operation *op, rewriter.replaceOp(op, replacementValues); // Recursively legalize any new constant operations. - for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size(); + for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size(); i != e; ++i) { - Operation *cstOp = rewriterImpl.createdOps[i]; - if (failed(legalize(cstOp, rewriter))) { + auto *createOp = + dyn_cast(rewriterImpl.rewrites[i].get()); + if (!createOp) + continue; + if (failed(legalize(createOp->getOperation(), rewriter))) { LLVM_DEBUG(logFailure(rewriterImpl.logger, "failed to legalize generated constant '{0}'", - cstOp->getName())); + createOp->getOperation()->getName())); rewriterImpl.resetState(curState); return failure(); } @@ -2132,9 +2144,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( // blocks in regions created by this pattern will already be legalized later // on. If we haven't built the set yet, build it now. if (operationsToIgnore.empty()) { - auto createdOps = ArrayRef(impl.createdOps) - .drop_front(state.numCreatedOps); - operationsToIgnore.insert(createdOps.begin(), createdOps.end()); + for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e; + ++i) { + auto *createOp = + dyn_cast(impl.rewrites[i].get()); + if (!createOp) + continue; + operationsToIgnore.insert(createOp->getOperation()); + } } // If this operation should be considered for re-legalization, try it. @@ -2152,8 +2169,11 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( LogicalResult OperationLegalizer::legalizePatternCreatedOperations( ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, RewriterState &state, RewriterState &newState) { - for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) { - Operation *op = impl.createdOps[i]; + for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { + auto *createOp = dyn_cast(impl.rewrites[i].get()); + if (!createOp) + continue; + Operation *op = createOp->getOperation(); if (failed(legalize(op, rewriter))) { LLVM_DEBUG(logFailure(impl.logger, "failed to legalize generated operation '{0}'({1})", @@ -2583,10 +2603,16 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes( }); return liveUserIt == val.user_end() ? nullptr : *liveUserIt; }; - for (auto &r : rewriterImpl.rewrites) - if (auto *rewrite = dyn_cast(r.get())) - if (failed(rewrite->materializeLiveConversions(findLiveUser))) + // Note: `rewrites` may be reallocated as the loop is running. + for (int64_t i = 0; i < static_cast(rewriterImpl.rewrites.size()); + ++i) { + auto &rewrite = rewriterImpl.rewrites[i]; + if (auto *blockTypeConversionRewrite = + dyn_cast(rewrite.get())) + if (failed(blockTypeConversionRewrite->materializeLiveConversions( + findLiveUser))) return failure(); + } return success(); }