@@ -152,17 +152,12 @@ namespace {
152152// / This class contains a snapshot of the current conversion rewriter state.
153153// / This is useful when saving and undoing a set of rewrites.
154154struct RewriterState {
155- RewriterState (unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
156- unsigned numRewrites, unsigned numIgnoredOperations,
157- unsigned numErased)
158- : numCreatedOps(numCreatedOps),
159- numUnresolvedMaterializations (numUnresolvedMaterializations),
155+ RewriterState (unsigned numUnresolvedMaterializations, unsigned numRewrites,
156+ unsigned numIgnoredOperations, unsigned numErased)
157+ : numUnresolvedMaterializations(numUnresolvedMaterializations),
160158 numRewrites (numRewrites), numIgnoredOperations(numIgnoredOperations),
161159 numErased(numErased) {}
162160
163- // / The current number of created operations.
164- unsigned numCreatedOps;
165-
166161 // / The current number of unresolved materializations.
167162 unsigned numUnresolvedMaterializations;
168163
@@ -303,7 +298,8 @@ class IRRewrite {
303298 // Operation rewrites
304299 MoveOperation,
305300 ModifyOperation,
306- ReplaceOperation
301+ ReplaceOperation,
302+ CreateOperation
307303 };
308304
309305 virtual ~IRRewrite () = default ;
@@ -376,7 +372,10 @@ class CreateBlockRewrite : public BlockRewrite {
376372 auto &blockOps = block->getOperations ();
377373 while (!blockOps.empty ())
378374 blockOps.remove (blockOps.begin ());
379- eraseBlock (block);
375+ if (block->getParent ())
376+ eraseBlock (block);
377+ else
378+ delete block;
380379 }
381380};
382381
@@ -606,7 +605,7 @@ class OperationRewrite : public IRRewrite {
606605
607606 static bool classof (const IRRewrite *rewrite) {
608607 return rewrite->getKind () >= Kind::MoveOperation &&
609- rewrite->getKind () <= Kind::ReplaceOperation ;
608+ rewrite->getKind () <= Kind::CreateOperation ;
610609 }
611610
612611protected:
@@ -740,6 +739,19 @@ class ReplaceOperationRewrite : public OperationRewrite {
740739 // / A boolean flag that indicates whether result types have changed or not.
741740 bool changedResults;
742741};
742+
743+ class CreateOperationRewrite : public OperationRewrite {
744+ public:
745+ CreateOperationRewrite (ConversionPatternRewriterImpl &rewriterImpl,
746+ Operation *op)
747+ : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
748+
749+ static bool classof (const IRRewrite *rewrite) {
750+ return rewrite->getKind () == Kind::CreateOperation;
751+ }
752+
753+ void rollback () override ;
754+ };
743755} // namespace
744756
745757// / Return "true" if there is an operation rewrite that matches the specified
@@ -957,9 +969,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
957969 // replacing a value with one of a different type.
958970 ConversionValueMapping mapping;
959971
960- // / Ordered vector of all of the newly created operations during conversion.
961- SmallVector<Operation *> createdOps;
962-
963972 // / Ordered vector of all unresolved type conversion materializations during
964973 // / conversion.
965974 SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
@@ -1144,6 +1153,15 @@ void ReplaceOperationRewrite::rollback() {
11441153
11451154void ReplaceOperationRewrite::cleanup () { eraseOp (op); }
11461155
1156+ void CreateOperationRewrite::rollback () {
1157+ for (Region ®ion : op->getRegions ()) {
1158+ while (!region.getBlocks ().empty ())
1159+ region.getBlocks ().remove (region.getBlocks ().begin ());
1160+ }
1161+ op->dropAllUses ();
1162+ eraseOp (op);
1163+ }
1164+
11471165void ConversionPatternRewriterImpl::detachNestedAndErase (Operation *op) {
11481166 for (Region ®ion : op->getRegions ()) {
11491167 for (Block &block : region.getBlocks ()) {
@@ -1161,8 +1179,6 @@ void ConversionPatternRewriterImpl::discardRewrites() {
11611179 // Remove any newly created ops.
11621180 for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
11631181 detachNestedAndErase (materialization.getOp ());
1164- for (auto *op : llvm::reverse (createdOps))
1165- detachNestedAndErase (op);
11661182}
11671183
11681184void ConversionPatternRewriterImpl::applyRewrites () {
@@ -1182,9 +1198,8 @@ void ConversionPatternRewriterImpl::applyRewrites() {
11821198// State Management
11831199
11841200RewriterState ConversionPatternRewriterImpl::getCurrentState () {
1185- return RewriterState (createdOps.size (), unresolvedMaterializations.size (),
1186- rewrites.size (), ignoredOps.size (),
1187- eraseRewriter.erased .size ());
1201+ return RewriterState (unresolvedMaterializations.size (), rewrites.size (),
1202+ ignoredOps.size (), eraseRewriter.erased .size ());
11881203}
11891204
11901205void ConversionPatternRewriterImpl::resetState (RewriterState state) {
@@ -1205,12 +1220,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
12051220 detachNestedAndErase (op);
12061221 }
12071222
1208- // Pop all of the newly created operations.
1209- while (createdOps.size () != state.numCreatedOps ) {
1210- detachNestedAndErase (createdOps.back ());
1211- createdOps.pop_back ();
1212- }
1213-
12141223 // Pop all of the recorded ignored operations that are no longer valid.
12151224 while (ignoredOps.size () != state.numIgnoredOperations )
12161225 ignoredOps.pop_back ();
@@ -1478,7 +1487,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
14781487 });
14791488 if (!previous.isSet ()) {
14801489 // This is a newly created op.
1481- createdOps. push_back (op);
1490+ appendRewrite<CreateOperationRewrite> (op);
14821491 return ;
14831492 }
14841493 Operation *prevOp = previous.getPoint () == previous.getBlock ()->end ()
@@ -1979,13 +1988,16 @@ OperationLegalizer::legalizeWithFold(Operation *op,
19791988 rewriter.replaceOp (op, replacementValues);
19801989
19811990 // Recursively legalize any new constant operations.
1982- for (unsigned i = curState.numCreatedOps , e = rewriterImpl.createdOps .size ();
1991+ for (unsigned i = curState.numRewrites , e = rewriterImpl.rewrites .size ();
19831992 i != e; ++i) {
1984- Operation *cstOp = rewriterImpl.createdOps [i];
1985- if (failed (legalize (cstOp, rewriter))) {
1993+ auto *createOp =
1994+ dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites [i].get ());
1995+ if (!createOp)
1996+ continue ;
1997+ if (failed (legalize (createOp->getOperation (), rewriter))) {
19861998 LLVM_DEBUG (logFailure (rewriterImpl.logger ,
19871999 " failed to legalize generated constant '{0}'" ,
1988- cstOp ->getName ()));
2000+ createOp-> getOperation () ->getName ()));
19892001 rewriterImpl.resetState (curState);
19902002 return failure ();
19912003 }
@@ -2132,9 +2144,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
21322144 // blocks in regions created by this pattern will already be legalized later
21332145 // on. If we haven't built the set yet, build it now.
21342146 if (operationsToIgnore.empty ()) {
2135- auto createdOps = ArrayRef<Operation *>(impl.createdOps )
2136- .drop_front (state.numCreatedOps );
2137- operationsToIgnore.insert (createdOps.begin (), createdOps.end ());
2147+ for (unsigned i = state.numRewrites , e = impl.rewrites .size (); i != e;
2148+ ++i) {
2149+ auto *createOp =
2150+ dyn_cast<CreateOperationRewrite>(impl.rewrites [i].get ());
2151+ if (!createOp)
2152+ continue ;
2153+ operationsToIgnore.insert (createOp->getOperation ());
2154+ }
21382155 }
21392156
21402157 // If this operation should be considered for re-legalization, try it.
@@ -2152,8 +2169,11 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
21522169LogicalResult OperationLegalizer::legalizePatternCreatedOperations (
21532170 ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
21542171 RewriterState &state, RewriterState &newState) {
2155- for (int i = state.numCreatedOps , e = newState.numCreatedOps ; i != e; ++i) {
2156- Operation *op = impl.createdOps [i];
2172+ for (int i = state.numRewrites , e = newState.numRewrites ; i != e; ++i) {
2173+ auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites [i].get ());
2174+ if (!createOp)
2175+ continue ;
2176+ Operation *op = createOp->getOperation ();
21572177 if (failed (legalize (op, rewriter))) {
21582178 LLVM_DEBUG (logFailure (impl.logger ,
21592179 " failed to legalize generated operation '{0}'({1})" ,
@@ -2583,10 +2603,16 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
25832603 });
25842604 return liveUserIt == val.user_end () ? nullptr : *liveUserIt;
25852605 };
2586- for (auto &r : rewriterImpl.rewrites )
2587- if (auto *rewrite = dyn_cast<BlockTypeConversionRewrite>(r.get ()))
2588- if (failed (rewrite->materializeLiveConversions (findLiveUser)))
2606+ // Note: `rewrites` may be reallocated as the loop is running.
2607+ for (int64_t i = 0 ; i < static_cast <int64_t >(rewriterImpl.rewrites .size ());
2608+ ++i) {
2609+ auto &rewrite = rewriterImpl.rewrites [i];
2610+ if (auto *blockTypeConversionRewrite =
2611+ dyn_cast<BlockTypeConversionRewrite>(rewrite.get ()))
2612+ if (failed (blockTypeConversionRewrite->materializeLiveConversions (
2613+ findLiveUser)))
25892614 return failure ();
2615+ }
25902616 return success ();
25912617}
25922618
0 commit comments