@@ -277,13 +277,14 @@ class IRRewrite {
277277 InlineBlock,
278278 MoveBlock,
279279 BlockTypeConversion,
280- ReplaceBlockArg,
281280 // Operation rewrites
282281 MoveOperation,
283282 ModifyOperation,
284283 ReplaceOperation,
285284 CreateOperation,
286- UnresolvedMaterialization
285+ UnresolvedMaterialization,
286+ // Value rewrites
287+ ReplaceValue
287288 };
288289
289290 virtual ~IRRewrite () = default ;
@@ -330,7 +331,7 @@ class BlockRewrite : public IRRewrite {
330331
331332 static bool classof (const IRRewrite *rewrite) {
332333 return rewrite->getKind () >= Kind::CreateBlock &&
333- rewrite->getKind () <= Kind::ReplaceBlockArg ;
334+ rewrite->getKind () <= Kind::BlockTypeConversion ;
334335 }
335336
336337protected:
@@ -342,6 +343,25 @@ class BlockRewrite : public IRRewrite {
342343 Block *block;
343344};
344345
346+ // / A value rewrite.
347+ class ValueRewrite : public IRRewrite {
348+ public:
349+ // / Return the value that this rewrite operates on.
350+ Value getValue () const { return value; }
351+
352+ static bool classof (const IRRewrite *rewrite) {
353+ return rewrite->getKind () == Kind::ReplaceValue;
354+ }
355+
356+ protected:
357+ ValueRewrite (Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
358+ Value value)
359+ : IRRewrite(kind, rewriterImpl), value(value) {}
360+
361+ // The value that this rewrite operates on.
362+ Value value;
363+ };
364+
345365// / Creation of a block. Block creations are immediately reflected in the IR.
346366// / There is no extra work to commit the rewrite. During rollback, the newly
347367// / created block is erased.
@@ -548,29 +568,26 @@ class BlockTypeConversionRewrite : public BlockRewrite {
548568 Block *newBlock;
549569};
550570
551- // / Replacing a block argument . This rewrite is not immediately reflected in the
571+ // / Replacing a value . This rewrite is not immediately reflected in the
552572// / IR. An internal IR mapping is updated, but the actual replacement is delayed
553573// / until the rewrite is committed.
554- class ReplaceBlockArgRewrite : public BlockRewrite {
574+ class ReplaceValueRewrite : public ValueRewrite {
555575public:
556- ReplaceBlockArgRewrite (ConversionPatternRewriterImpl &rewriterImpl,
557- Block *block, BlockArgument arg,
558- const TypeConverter *converter)
559- : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
576+ ReplaceValueRewrite (ConversionPatternRewriterImpl &rewriterImpl, Value value,
577+ const TypeConverter *converter)
578+ : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
560579 converter (converter) {}
561580
562581 static bool classof (const IRRewrite *rewrite) {
563- return rewrite->getKind () == Kind::ReplaceBlockArg ;
582+ return rewrite->getKind () == Kind::ReplaceValue ;
564583 }
565584
566585 void commit (RewriterBase &rewriter) override ;
567586
568587 void rollback () override ;
569588
570589private:
571- BlockArgument arg;
572-
573- // / The current type converter when the block argument was replaced.
590+ // / The current type converter when the value was replaced.
574591 const TypeConverter *converter;
575592};
576593
@@ -940,10 +957,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
940957 // / uses.
941958 void replaceOp (Operation *op, SmallVector<SmallVector<Value>> &&newValues);
942959
943- // / Replace the given block argument with the given values. The specified
960+ // / Replace the uses of the given value with the given values. The specified
944961 // / converter is used to build materializations (if necessary).
945- void replaceUsesOfBlockArgument (BlockArgument from, ValueRange to,
946- const TypeConverter *converter);
962+ void replaceAllUsesWith (Value from, ValueRange to,
963+ const TypeConverter *converter);
947964
948965 // / Erase the given block and its contents.
949966 void eraseBlock (Block *block);
@@ -1129,10 +1146,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
11291146 IRRewriter notifyingRewriter;
11301147
11311148#ifndef NDEBUG
1132- // / A set of replaced block arguments. This set is for debugging purposes
1133- // / only and it is maintained only if `allowPatternRollback` is set to
1134- // / "true".
1135- DenseSet<BlockArgument> replacedArgs;
1149+ // / A set of replaced values. This set is for debugging purposes only and it
1150+ // / is maintained only if `allowPatternRollback` is set to "true".
1151+ DenseSet<Value> replacedValues;
11361152
11371153 // / A set of operations that have pending updates. This tracking isn't
11381154 // / strictly necessary, and is thus only active during debug builds for extra
@@ -1169,32 +1185,54 @@ void BlockTypeConversionRewrite::rollback() {
11691185 getNewBlock ()->replaceAllUsesWith (getOrigBlock ());
11701186}
11711187
1172- static void performReplaceBlockArg (RewriterBase &rewriter, BlockArgument arg,
1173- Value repl) {
1188+ // / Replace all uses of `from` with `repl`.
1189+ static void performReplaceValue (RewriterBase &rewriter, Value from,
1190+ Value repl) {
11741191 if (isa<BlockArgument>(repl)) {
1175- rewriter.replaceAllUsesWith (arg, repl);
1192+ // `repl` is a block argument. Directly replace all uses.
1193+ rewriter.replaceAllUsesWith (from, repl);
11761194 return ;
11771195 }
11781196
1179- // If the replacement value is an operation, we check to make sure that we
1180- // don't replace uses that are within the parent operation of the
1181- // replacement value.
1182- Operation *replOp = cast<OpResult>(repl).getOwner ();
1197+ // If the replacement value is an operation, only replace those uses that:
1198+ // - are in a different block than the replacement operation, or
1199+ // - are in the same block but after the replacement operation.
1200+ //
1201+ // Example:
1202+ // ^bb0(%arg0: i32):
1203+ // %0 = "consumer"(%arg0) : (i32) -> (i32)
1204+ // "another_consumer"(%arg0) : (i32) -> ()
1205+ //
1206+ // In the above example, replaceAllUsesWith(%arg0, %0) will replace the
1207+ // use in "another_consumer" but not the use in "consumer". When using the
1208+ // normal RewriterBase API, this would typically be done with
1209+ // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
1210+ // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
1211+ // it cannot be supported efficiently with `allowPatternRollback` set to
1212+ // "true". Therefore, the conversion driver is trying to be smart and replaces
1213+ // only those uses that do not lead to a dominance violation. E.g., the
1214+ // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
1215+ // behavior.
1216+ //
1217+ // TODO: As we move more and more towards `allowPatternRollback` set to
1218+ // "false", we should remove this special handling, in order to align the
1219+ // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
1220+ Operation *replOp = repl.getDefiningOp ();
11831221 Block *replBlock = replOp->getBlock ();
1184- rewriter.replaceUsesWithIf (arg , repl, [&](OpOperand &operand) {
1222+ rewriter.replaceUsesWithIf (from , repl, [&](OpOperand &operand) {
11851223 Operation *user = operand.getOwner ();
11861224 return user->getBlock () != replBlock || replOp->isBeforeInBlock (user);
11871225 });
11881226}
11891227
1190- void ReplaceBlockArgRewrite ::commit (RewriterBase &rewriter) {
1191- Value repl = rewriterImpl.findOrBuildReplacementValue (arg , converter);
1228+ void ReplaceValueRewrite ::commit (RewriterBase &rewriter) {
1229+ Value repl = rewriterImpl.findOrBuildReplacementValue (value , converter);
11921230 if (!repl)
11931231 return ;
1194- performReplaceBlockArg (rewriter, arg , repl);
1232+ performReplaceValue (rewriter, value , repl);
11951233}
11961234
1197- void ReplaceBlockArgRewrite ::rollback () { rewriterImpl.mapping .erase ({arg }); }
1235+ void ReplaceValueRewrite ::rollback () { rewriterImpl.mapping .erase ({value }); }
11981236
11991237void ReplaceOperationRewrite::commit (RewriterBase &rewriter) {
12001238 auto *listener =
@@ -1584,7 +1622,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
15841622 /* outputTypes=*/ origArgType, /* originalType=*/ Type (), converter,
15851623 /* isPureTypeConversion=*/ false )
15861624 .front ();
1587- replaceUsesOfBlockArgument (origArg, mat, converter);
1625+ replaceAllUsesWith (origArg, mat, converter);
15881626 continue ;
15891627 }
15901628
@@ -1593,15 +1631,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
15931631 assert (inputMap->size == 0 &&
15941632 " invalid to provide a replacement value when the argument isn't "
15951633 " dropped" );
1596- replaceUsesOfBlockArgument (origArg, inputMap->replacementValues ,
1597- converter);
1634+ replaceAllUsesWith (origArg, inputMap->replacementValues , converter);
15981635 continue ;
15991636 }
16001637
16011638 // This is a 1->1+ mapping.
16021639 auto replArgs =
16031640 newBlock->getArguments ().slice (inputMap->inputNo , inputMap->size );
1604- replaceUsesOfBlockArgument (origArg, replArgs, converter);
1641+ replaceAllUsesWith (origArg, replArgs, converter);
16051642 }
16061643
16071644 if (config.allowPatternRollback )
@@ -1873,8 +1910,8 @@ void ConversionPatternRewriterImpl::replaceOp(
18731910 op->walk ([&](Operation *op) { replacedOps.insert (op); });
18741911}
18751912
1876- void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument (
1877- BlockArgument from, ValueRange to, const TypeConverter *converter) {
1913+ void ConversionPatternRewriterImpl::replaceAllUsesWith (
1914+ Value from, ValueRange to, const TypeConverter *converter) {
18781915 if (!config.allowPatternRollback ) {
18791916 SmallVector<Value> toConv = llvm::to_vector (to);
18801917 SmallVector<Value> repls =
@@ -1884,25 +1921,25 @@ void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
18841921 if (!repl)
18851922 return ;
18861923
1887- performReplaceBlockArg (r, from, repl);
1924+ performReplaceValue (r, from, repl);
18881925 return ;
18891926 }
18901927
18911928#ifndef NDEBUG
1892- // Make sure that a block argument is not replaced multiple times. In
1893- // rollback mode, `replaceUsesOfBlockArgument ` replaces not only all current
1894- // uses of the given block argument, but also all future uses that may be
1895- // introduced by future pattern applications. Therefore, it does not make
1896- // sense to call `replaceUsesOfBlockArgument ` multiple times with the same
1897- // block argument. Doing so would overwrite the mapping and mess with the
1898- // internal state of the dialect conversion driver.
1899- assert (!replacedArgs .contains (from) &&
1900- " attempting to replace a block argument that was already replaced" );
1901- replacedArgs .insert (from);
1929+ // Make sure that a value is not replaced multiple times. In rollback mode,
1930+ // `replaceAllUsesWith ` replaces not only all current uses of the given value,
1931+ // but also all future uses that may be introduced by future pattern
1932+ // applications. Therefore, it does not make sense to call
1933+ // `replaceAllUsesWith ` multiple times with the same value. Doing so would
1934+ // overwrite the mapping and mess with the internal state of the dialect
1935+ // conversion driver.
1936+ assert (!replacedValues .contains (from) &&
1937+ " attempting to replace a value that was already replaced" );
1938+ replacedValues .insert (from);
19021939#endif // NDEBUG
19031940
1904- appendRewrite<ReplaceBlockArgRewrite>(from.getOwner (), from, converter);
19051941 mapping.map (from, to);
1942+ appendRewrite<ReplaceValueRewrite>(from, converter);
19061943}
19071944
19081945void ConversionPatternRewriterImpl::eraseBlock (Block *block) {
@@ -2107,18 +2144,19 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
21072144 return impl->convertRegionTypes (region, converter, entryConversion);
21082145}
21092146
2110- void ConversionPatternRewriter::replaceUsesOfBlockArgument (BlockArgument from,
2111- ValueRange to) {
2147+ void ConversionPatternRewriter::replaceAllUsesWith (Value from, ValueRange to) {
21122148 LLVM_DEBUG ({
2113- impl->logger .startLine () << " ** Replace Argument : '" << from << " '" ;
2114- if (Operation *parentOp = from.getOwner ()->getParentOp ()) {
2115- impl->logger .getOStream () << " (in region of '" << parentOp->getName ()
2116- << " ' (" << parentOp << " )\n " ;
2117- } else {
2118- impl->logger .getOStream () << " (unlinked block)\n " ;
2149+ impl->logger .startLine () << " ** Replace Value : '" << from << " '" ;
2150+ if (auto blockArg = dyn_cast<BlockArgument>(from)) {
2151+ if (Operation *parentOp = blockArg.getOwner ()->getParentOp ()) {
2152+ impl->logger .getOStream () << " (in region of '" << parentOp->getName ()
2153+ << " ' (" << parentOp << " )\n " ;
2154+ } else {
2155+ impl->logger .getOStream () << " (unlinked block)\n " ;
2156+ }
21192157 }
21202158 });
2121- impl->replaceUsesOfBlockArgument (from, to, impl->currentTypeConverter );
2159+ impl->replaceAllUsesWith (from, to, impl->currentTypeConverter );
21222160}
21232161
21242162Value ConversionPatternRewriter::getRemappedValue (Value key) {
@@ -2176,7 +2214,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
21762214
21772215 // Replace all uses of block arguments.
21782216 for (auto it : llvm::zip (source->getArguments (), argValues))
2179- replaceUsesOfBlockArgument (std::get<0 >(it), std::get<1 >(it));
2217+ replaceAllUsesWith (std::get<0 >(it), std::get<1 >(it));
21802218
21812219 if (fastPath) {
21822220 // Move all ops at once.
0 commit comments