@@ -277,13 +277,14 @@ class IRRewrite {
277277 InlineBlock,
278278 MoveBlock,
279279 BlockTypeConversion,
280- ReplaceBlockArg,
281280 // Operation rewrites
282281 MoveOperation,
283282 ModifyOperation,
284283 ReplaceOperation,
285284 CreateOperation,
286- UnresolvedMaterialization
285+ UnresolvedMaterialization,
286+ // Value rewrites
287+ ReplaceValue
287288 };
288289
289290 virtual ~IRRewrite () = default ;
@@ -330,7 +331,7 @@ class BlockRewrite : public IRRewrite {
330331
331332 static bool classof (const IRRewrite *rewrite) {
332333 return rewrite->getKind () >= Kind::CreateBlock &&
333- rewrite->getKind () <= Kind::ReplaceBlockArg ;
334+ rewrite->getKind () <= Kind::BlockTypeConversion ;
334335 }
335336
336337protected:
@@ -342,6 +343,25 @@ class BlockRewrite : public IRRewrite {
342343 Block *block;
343344};
344345
346+ // / A value rewrite.
347+ class ValueRewrite : public IRRewrite {
348+ public:
349+ // / Return the value that this rewrite operates on.
350+ Value getValue () const { return value; }
351+
352+ static bool classof (const IRRewrite *rewrite) {
353+ return rewrite->getKind () == Kind::ReplaceValue;
354+ }
355+
356+ protected:
357+ ValueRewrite (Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
358+ Value value)
359+ : IRRewrite(kind, rewriterImpl), value(value) {}
360+
361+ // The value that this rewrite operates on.
362+ Value value;
363+ };
364+
345365// / Creation of a block. Block creations are immediately reflected in the IR.
346366// / There is no extra work to commit the rewrite. During rollback, the newly
347367// / created block is erased.
@@ -548,29 +568,26 @@ class BlockTypeConversionRewrite : public BlockRewrite {
548568 Block *newBlock;
549569};
550570
551- // / Replacing a block argument . This rewrite is not immediately reflected in the
571+ // / Replacing a value . This rewrite is not immediately reflected in the
552572// / IR. An internal IR mapping is updated, but the actual replacement is delayed
553573// / until the rewrite is committed.
554- class ReplaceBlockArgRewrite : public BlockRewrite {
574+ class ReplaceValueRewrite : public ValueRewrite {
555575public:
556- ReplaceBlockArgRewrite (ConversionPatternRewriterImpl &rewriterImpl,
557- Block *block, BlockArgument arg,
558- const TypeConverter *converter)
559- : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
576+ ReplaceValueRewrite (ConversionPatternRewriterImpl &rewriterImpl, Value value,
577+ const TypeConverter *converter)
578+ : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
560579 converter (converter) {}
561580
562581 static bool classof (const IRRewrite *rewrite) {
563- return rewrite->getKind () == Kind::ReplaceBlockArg ;
582+ return rewrite->getKind () == Kind::ReplaceValue ;
564583 }
565584
566585 void commit (RewriterBase &rewriter) override ;
567586
568587 void rollback () override ;
569588
570589private:
571- BlockArgument arg;
572-
573- // / The current type converter when the block argument was replaced.
590+ // / The current type converter when the value was replaced.
574591 const TypeConverter *converter;
575592};
576593
@@ -942,10 +959,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
942959 // / uses.
943960 void replaceOp (Operation *op, SmallVector<SmallVector<Value>> &&newValues);
944961
945- // / Replace the given block argument with the given values. The specified
962+ // / Replace the uses of the given value with the given values. The specified
946963 // / converter is used to build materializations (if necessary).
947- void replaceUsesOfBlockArgument (BlockArgument from, ValueRange to,
948- const TypeConverter *converter);
964+ void replaceAllUsesWith (Value from, ValueRange to,
965+ const TypeConverter *converter);
949966
950967 // / Erase the given block and its contents.
951968 void eraseBlock (Block *block);
@@ -1132,10 +1149,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
11321149 IRRewriter notifyingRewriter;
11331150
11341151#ifndef NDEBUG
1135- // / A set of replaced block arguments. This set is for debugging purposes
1136- // / only and it is maintained only if `allowPatternRollback` is set to
1137- // / "true".
1138- DenseSet<BlockArgument> replacedArgs;
1152+ // / A set of replaced values. This set is for debugging purposes only and it
1153+ // / is maintained only if `allowPatternRollback` is set to "true".
1154+ DenseSet<Value> replacedValues;
11391155
11401156 // / A set of operations that have pending updates. This tracking isn't
11411157 // / strictly necessary, and is thus only active during debug builds for extra
@@ -1172,32 +1188,54 @@ void BlockTypeConversionRewrite::rollback() {
11721188 getNewBlock ()->replaceAllUsesWith (getOrigBlock ());
11731189}
11741190
1175- static void performReplaceBlockArg (RewriterBase &rewriter, BlockArgument arg,
1176- Value repl) {
1191+ // / Replace all uses of `from` with `repl`.
1192+ static void performReplaceValue (RewriterBase &rewriter, Value from,
1193+ Value repl) {
11771194 if (isa<BlockArgument>(repl)) {
1178- rewriter.replaceAllUsesWith (arg, repl);
1195+ // `repl` is a block argument. Directly replace all uses.
1196+ rewriter.replaceAllUsesWith (from, repl);
11791197 return ;
11801198 }
11811199
1182- // If the replacement value is an operation, we check to make sure that we
1183- // don't replace uses that are within the parent operation of the
1184- // replacement value.
1185- Operation *replOp = cast<OpResult>(repl).getOwner ();
1200+ // If the replacement value is an operation, only replace those uses that:
1201+ // - are in a different block than the replacement operation, or
1202+ // - are in the same block but after the replacement operation.
1203+ //
1204+ // Example:
1205+ // ^bb0(%arg0: i32):
1206+ // %0 = "consumer"(%arg0) : (i32) -> (i32)
1207+ // "another_consumer"(%arg0) : (i32) -> ()
1208+ //
1209+ // In the above example, replaceAllUsesWith(%arg0, %0) will replace the
1210+ // use in "another_consumer" but not the use in "consumer". When using the
1211+ // normal RewriterBase API, this would typically be done with
1212+ // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
1213+ // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
1214+ // it cannot be supported efficiently with `allowPatternRollback` set to
1215+ // "true". Therefore, the conversion driver is trying to be smart and replaces
1216+ // only those uses that do not lead to a dominance violation. E.g., the
1217+ // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
1218+ // behavior.
1219+ //
1220+ // TODO: As we move more and more towards `allowPatternRollback` set to
1221+ // "false", we should remove this special handling, in order to align the
1222+ // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
1223+ Operation *replOp = repl.getDefiningOp ();
11861224 Block *replBlock = replOp->getBlock ();
1187- rewriter.replaceUsesWithIf (arg , repl, [&](OpOperand &operand) {
1225+ rewriter.replaceUsesWithIf (from , repl, [&](OpOperand &operand) {
11881226 Operation *user = operand.getOwner ();
11891227 return user->getBlock () != replBlock || replOp->isBeforeInBlock (user);
11901228 });
11911229}
11921230
1193- void ReplaceBlockArgRewrite ::commit (RewriterBase &rewriter) {
1194- Value repl = rewriterImpl.findOrBuildReplacementValue (arg , converter);
1231+ void ReplaceValueRewrite ::commit (RewriterBase &rewriter) {
1232+ Value repl = rewriterImpl.findOrBuildReplacementValue (value , converter);
11951233 if (!repl)
11961234 return ;
1197- performReplaceBlockArg (rewriter, arg , repl);
1235+ performReplaceValue (rewriter, value , repl);
11981236}
11991237
1200- void ReplaceBlockArgRewrite ::rollback () { rewriterImpl.mapping .erase ({arg }); }
1238+ void ReplaceValueRewrite ::rollback () { rewriterImpl.mapping .erase ({value }); }
12011239
12021240void ReplaceOperationRewrite::commit (RewriterBase &rewriter) {
12031241 auto *listener =
@@ -1590,7 +1628,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
15901628 /* outputTypes=*/ origArgType, /* originalType=*/ Type (), converter,
15911629 /* castOp=*/ nullptr , /* isPureTypeConversion=*/ false )
15921630 .front ();
1593- replaceUsesOfBlockArgument (origArg, mat, converter);
1631+ replaceAllUsesWith (origArg, mat, converter);
15941632 continue ;
15951633 }
15961634
@@ -1599,15 +1637,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
15991637 assert (inputMap->size == 0 &&
16001638 " invalid to provide a replacement value when the argument isn't "
16011639 " dropped" );
1602- replaceUsesOfBlockArgument (origArg, inputMap->replacementValues ,
1603- converter);
1640+ replaceAllUsesWith (origArg, inputMap->replacementValues , converter);
16041641 continue ;
16051642 }
16061643
16071644 // This is a 1->1+ mapping.
16081645 auto replArgs =
16091646 newBlock->getArguments ().slice (inputMap->inputNo , inputMap->size );
1610- replaceUsesOfBlockArgument (origArg, replArgs, converter);
1647+ replaceAllUsesWith (origArg, replArgs, converter);
16111648 }
16121649
16131650 if (config.allowPatternRollback )
@@ -1882,8 +1919,8 @@ void ConversionPatternRewriterImpl::replaceOp(
18821919 op->walk ([&](Operation *op) { replacedOps.insert (op); });
18831920}
18841921
1885- void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument (
1886- BlockArgument from, ValueRange to, const TypeConverter *converter) {
1922+ void ConversionPatternRewriterImpl::replaceAllUsesWith (
1923+ Value from, ValueRange to, const TypeConverter *converter) {
18871924 if (!config.allowPatternRollback ) {
18881925 SmallVector<Value> toConv = llvm::to_vector (to);
18891926 SmallVector<Value> repls =
@@ -1893,25 +1930,25 @@ void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
18931930 if (!repl)
18941931 return ;
18951932
1896- performReplaceBlockArg (r, from, repl);
1933+ performReplaceValue (r, from, repl);
18971934 return ;
18981935 }
18991936
19001937#ifndef NDEBUG
1901- // Make sure that a block argument is not replaced multiple times. In
1902- // rollback mode, `replaceUsesOfBlockArgument ` replaces not only all current
1903- // uses of the given block argument, but also all future uses that may be
1904- // introduced by future pattern applications. Therefore, it does not make
1905- // sense to call `replaceUsesOfBlockArgument ` multiple times with the same
1906- // block argument. Doing so would overwrite the mapping and mess with the
1907- // internal state of the dialect conversion driver.
1908- assert (!replacedArgs .contains (from) &&
1909- " attempting to replace a block argument that was already replaced" );
1910- replacedArgs .insert (from);
1938+ // Make sure that a value is not replaced multiple times. In rollback mode,
1939+ // `replaceAllUsesWith ` replaces not only all current uses of the given value,
1940+ // but also all future uses that may be introduced by future pattern
1941+ // applications. Therefore, it does not make sense to call
1942+ // `replaceAllUsesWith ` multiple times with the same value. Doing so would
1943+ // overwrite the mapping and mess with the internal state of the dialect
1944+ // conversion driver.
1945+ assert (!replacedValues .contains (from) &&
1946+ " attempting to replace a value that was already replaced" );
1947+ replacedValues .insert (from);
19111948#endif // NDEBUG
19121949
1913- appendRewrite<ReplaceBlockArgRewrite>(from.getOwner (), from, converter);
19141950 mapping.map (from, to);
1951+ appendRewrite<ReplaceValueRewrite>(from, converter);
19151952}
19161953
19171954void ConversionPatternRewriterImpl::eraseBlock (Block *block) {
@@ -2116,18 +2153,19 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
21162153 return impl->convertRegionTypes (*this , region, converter, entryConversion);
21172154}
21182155
2119- void ConversionPatternRewriter::replaceUsesOfBlockArgument (BlockArgument from,
2120- ValueRange to) {
2156+ void ConversionPatternRewriter::replaceAllUsesWith (Value from, ValueRange to) {
21212157 LLVM_DEBUG ({
2122- impl->logger .startLine () << " ** Replace Argument : '" << from << " '" ;
2123- if (Operation *parentOp = from.getOwner ()->getParentOp ()) {
2124- impl->logger .getOStream () << " (in region of '" << parentOp->getName ()
2125- << " ' (" << parentOp << " )\n " ;
2126- } else {
2127- impl->logger .getOStream () << " (unlinked block)\n " ;
2158+ impl->logger .startLine () << " ** Replace Value : '" << from << " '" ;
2159+ if (auto blockArg = dyn_cast<BlockArgument>(from)) {
2160+ if (Operation *parentOp = blockArg.getOwner ()->getParentOp ()) {
2161+ impl->logger .getOStream () << " (in region of '" << parentOp->getName ()
2162+ << " ' (" << parentOp << " )\n " ;
2163+ } else {
2164+ impl->logger .getOStream () << " (unlinked block)\n " ;
2165+ }
21282166 }
21292167 });
2130- impl->replaceUsesOfBlockArgument (from, to, impl->currentTypeConverter );
2168+ impl->replaceAllUsesWith (from, to, impl->currentTypeConverter );
21312169}
21322170
21332171Value ConversionPatternRewriter::getRemappedValue (Value key) {
@@ -2185,7 +2223,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
21852223
21862224 // Replace all uses of block arguments.
21872225 for (auto it : llvm::zip (source->getArguments (), argValues))
2188- replaceUsesOfBlockArgument (std::get<0 >(it), std::get<1 >(it));
2226+ replaceAllUsesWith (std::get<0 >(it), std::get<1 >(it));
21892227
21902228 if (fastPath) {
21912229 // Move all ops at once.
0 commit comments