Skip to content

Commit 93ad0c4

Browse files
[mlir][Transforms] Add support for ConversionPatternRewriter::replaceAllUsesWith
1 parent f4dbd0d commit 93ad0c4

File tree

5 files changed

+112
-72
lines changed

5 files changed

+112
-72
lines changed

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,7 @@ class RewriterBase : public OpBuilder {
633633

634634
/// Find uses of `from` and replace them with `to`. Also notify the listener
635635
/// about every in-place op modification (for every use that was replaced).
636-
void replaceAllUsesWith(Value from, Value to) {
636+
virtual void replaceAllUsesWith(Value from, Value to) {
637637
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
638638
Operation *op = operand.getOwner();
639639
modifyOpInPlace(op, [&]() { operand.set(to); });

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -854,15 +854,18 @@ class ConversionPatternRewriter final : public PatternRewriter {
854854
Region *region, const TypeConverter &converter,
855855
TypeConverter::SignatureConversion *entryConversion = nullptr);
856856

857-
/// Replace all the uses of the block argument `from` with `to`. This
858-
/// function supports both 1:1 and 1:N replacements.
857+
/// Replace all the uses of `from` with `to`. This function supports both 1:1
858+
/// and 1:N replacements.
859859
///
860860
/// Note: If `allowPatternRollback` is set to "true", this function replaces
861-
/// all current and future uses of the block argument. This same block
862-
/// block argument must not be replaced multiple times. Uses are not replaced
863-
/// immediately but in a delayed fashion. Patterns may still see the original
864-
/// uses when inspecting IR.
865-
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
861+
/// all current and future uses of the `from` value. This same value must not
862+
/// be replaced multiple times. Uses are not replaced immediately but in a
863+
/// delayed fashion. Patterns may still see the original uses when inspecting
864+
/// IR.
865+
void replaceAllUsesWith(Value from, ValueRange to);
866+
void replaceAllUsesWith(Value from, Value to) override {
867+
replaceAllUsesWith(from, ValueRange{to});
868+
}
866869

867870
/// Return the converted value of 'key' with a type defined by the type
868871
/// converter of the currently executing pattern. Return nullptr in the case

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ static void restoreByValRefArgumentType(
284284
cast<TypeAttr>(byValRefAttr->getValue()).getValue());
285285

286286
Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
287-
rewriter.replaceUsesOfBlockArgument(arg, valueArg);
287+
rewriter.replaceAllUsesWith(arg, valueArg);
288288
}
289289
}
290290

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 98 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -277,13 +277,14 @@ class IRRewrite {
277277
InlineBlock,
278278
MoveBlock,
279279
BlockTypeConversion,
280-
ReplaceBlockArg,
281280
// Operation rewrites
282281
MoveOperation,
283282
ModifyOperation,
284283
ReplaceOperation,
285284
CreateOperation,
286-
UnresolvedMaterialization
285+
UnresolvedMaterialization,
286+
// Value rewrites
287+
ReplaceValue
287288
};
288289

289290
virtual ~IRRewrite() = default;
@@ -330,7 +331,7 @@ class BlockRewrite : public IRRewrite {
330331

331332
static bool classof(const IRRewrite *rewrite) {
332333
return rewrite->getKind() >= Kind::CreateBlock &&
333-
rewrite->getKind() <= Kind::ReplaceBlockArg;
334+
rewrite->getKind() <= Kind::BlockTypeConversion;
334335
}
335336

336337
protected:
@@ -342,6 +343,25 @@ class BlockRewrite : public IRRewrite {
342343
Block *block;
343344
};
344345

346+
/// A value rewrite.
347+
class ValueRewrite : public IRRewrite {
348+
public:
349+
/// Return the value that this rewrite operates on.
350+
Value getValue() const { return value; }
351+
352+
static bool classof(const IRRewrite *rewrite) {
353+
return rewrite->getKind() == Kind::ReplaceValue;
354+
}
355+
356+
protected:
357+
ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
358+
Value value)
359+
: IRRewrite(kind, rewriterImpl), value(value) {}
360+
361+
// The value that this rewrite operates on.
362+
Value value;
363+
};
364+
345365
/// Creation of a block. Block creations are immediately reflected in the IR.
346366
/// There is no extra work to commit the rewrite. During rollback, the newly
347367
/// created block is erased.
@@ -548,29 +568,26 @@ class BlockTypeConversionRewrite : public BlockRewrite {
548568
Block *newBlock;
549569
};
550570

551-
/// Replacing a block argument. This rewrite is not immediately reflected in the
571+
/// Replacing a value. This rewrite is not immediately reflected in the
552572
/// IR. An internal IR mapping is updated, but the actual replacement is delayed
553573
/// until the rewrite is committed.
554-
class ReplaceBlockArgRewrite : public BlockRewrite {
574+
class ReplaceValueRewrite : public ValueRewrite {
555575
public:
556-
ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
557-
Block *block, BlockArgument arg,
558-
const TypeConverter *converter)
559-
: BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
576+
ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
577+
const TypeConverter *converter)
578+
: ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
560579
converter(converter) {}
561580

562581
static bool classof(const IRRewrite *rewrite) {
563-
return rewrite->getKind() == Kind::ReplaceBlockArg;
582+
return rewrite->getKind() == Kind::ReplaceValue;
564583
}
565584

566585
void commit(RewriterBase &rewriter) override;
567586

568587
void rollback() override;
569588

570589
private:
571-
BlockArgument arg;
572-
573-
/// The current type converter when the block argument was replaced.
590+
/// The current type converter when the value was replaced.
574591
const TypeConverter *converter;
575592
};
576593

@@ -940,10 +957,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
940957
/// uses.
941958
void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
942959

943-
/// Replace the given block argument with the given values. The specified
960+
/// Replace the uses of the given value with the given values. The specified
944961
/// converter is used to build materializations (if necessary).
945-
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to,
946-
const TypeConverter *converter);
962+
void replaceAllUsesWith(Value from, ValueRange to,
963+
const TypeConverter *converter);
947964

948965
/// Erase the given block and its contents.
949966
void eraseBlock(Block *block);
@@ -1129,10 +1146,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
11291146
IRRewriter notifyingRewriter;
11301147

11311148
#ifndef NDEBUG
1132-
/// A set of replaced block arguments. This set is for debugging purposes
1133-
/// only and it is maintained only if `allowPatternRollback` is set to
1134-
/// "true".
1135-
DenseSet<BlockArgument> replacedArgs;
1149+
/// A set of replaced values. This set is for debugging purposes only and it
1150+
/// is maintained only if `allowPatternRollback` is set to "true".
1151+
DenseSet<Value> replacedValues;
11361152

11371153
/// A set of operations that have pending updates. This tracking isn't
11381154
/// strictly necessary, and is thus only active during debug builds for extra
@@ -1169,32 +1185,54 @@ void BlockTypeConversionRewrite::rollback() {
11691185
getNewBlock()->replaceAllUsesWith(getOrigBlock());
11701186
}
11711187

1172-
static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg,
1173-
Value repl) {
1188+
/// Replace all uses of `from` with `repl`.
1189+
static void performReplaceValue(RewriterBase &rewriter, Value from,
1190+
Value repl) {
11741191
if (isa<BlockArgument>(repl)) {
1175-
rewriter.replaceAllUsesWith(arg, repl);
1192+
// `repl` is a block argument. Directly replace all uses.
1193+
rewriter.replaceAllUsesWith(from, repl);
11761194
return;
11771195
}
11781196

1179-
// If the replacement value is an operation, we check to make sure that we
1180-
// don't replace uses that are within the parent operation of the
1181-
// replacement value.
1182-
Operation *replOp = cast<OpResult>(repl).getOwner();
1197+
// If the replacement value is an operation, only replace those uses that:
1198+
// - are in a different block than the replacement operation, or
1199+
// - are in the same block but after the replacement operation.
1200+
//
1201+
// Example:
1202+
// ^bb0(%arg0: i32):
1203+
// %0 = "consumer"(%arg0) : (i32) -> (i32)
1204+
// "another_consumer"(%arg0) : (i32) -> ()
1205+
//
1206+
// In the above example, replaceAllUsesWith(%arg0, %0) will replace the
1207+
// use in "another_consumer" but not the use in "consumer". When using the
1208+
// normal RewriterBase API, this would typically be done with
1209+
// `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
1210+
// supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
1211+
// it cannot be supported efficiently with `allowPatternRollback` set to
1212+
// "true". Therefore, the conversion driver is trying to be smart and replaces
1213+
// only those uses that do not lead to a dominance violation. E.g., the
1214+
// FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
1215+
// behavior.
1216+
//
1217+
// TODO: As we move more and more towards `allowPatternRollback` set to
1218+
// "false", we should remove this special handling, in order to align the
1219+
// `ConversionPatternRewriter` API with the normal `RewriterBase` API.
1220+
Operation *replOp = repl.getDefiningOp();
11831221
Block *replBlock = replOp->getBlock();
1184-
rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
1222+
rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
11851223
Operation *user = operand.getOwner();
11861224
return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
11871225
});
11881226
}
11891227

1190-
void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
1191-
Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
1228+
void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1229+
Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
11921230
if (!repl)
11931231
return;
1194-
performReplaceBlockArg(rewriter, arg, repl);
1232+
performReplaceValue(rewriter, value, repl);
11951233
}
11961234

1197-
void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
1235+
void ReplaceValueRewrite::rollback() { rewriterImpl.mapping.erase({value}); }
11981236

11991237
void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
12001238
auto *listener =
@@ -1584,7 +1622,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
15841622
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
15851623
/*isPureTypeConversion=*/false)
15861624
.front();
1587-
replaceUsesOfBlockArgument(origArg, mat, converter);
1625+
replaceAllUsesWith(origArg, mat, converter);
15881626
continue;
15891627
}
15901628

@@ -1593,15 +1631,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
15931631
assert(inputMap->size == 0 &&
15941632
"invalid to provide a replacement value when the argument isn't "
15951633
"dropped");
1596-
replaceUsesOfBlockArgument(origArg, inputMap->replacementValues,
1597-
converter);
1634+
replaceAllUsesWith(origArg, inputMap->replacementValues, converter);
15981635
continue;
15991636
}
16001637

16011638
// This is a 1->1+ mapping.
16021639
auto replArgs =
16031640
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1604-
replaceUsesOfBlockArgument(origArg, replArgs, converter);
1641+
replaceAllUsesWith(origArg, replArgs, converter);
16051642
}
16061643

16071644
if (config.allowPatternRollback)
@@ -1873,8 +1910,8 @@ void ConversionPatternRewriterImpl::replaceOp(
18731910
op->walk([&](Operation *op) { replacedOps.insert(op); });
18741911
}
18751912

1876-
void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
1877-
BlockArgument from, ValueRange to, const TypeConverter *converter) {
1913+
void ConversionPatternRewriterImpl::replaceAllUsesWith(
1914+
Value from, ValueRange to, const TypeConverter *converter) {
18781915
if (!config.allowPatternRollback) {
18791916
SmallVector<Value> toConv = llvm::to_vector(to);
18801917
SmallVector<Value> repls =
@@ -1884,25 +1921,25 @@ void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
18841921
if (!repl)
18851922
return;
18861923

1887-
performReplaceBlockArg(r, from, repl);
1924+
performReplaceValue(r, from, repl);
18881925
return;
18891926
}
18901927

18911928
#ifndef NDEBUG
1892-
// Make sure that a block argument is not replaced multiple times. In
1893-
// rollback mode, `replaceUsesOfBlockArgument` replaces not only all current
1894-
// uses of the given block argument, but also all future uses that may be
1895-
// introduced by future pattern applications. Therefore, it does not make
1896-
// sense to call `replaceUsesOfBlockArgument` multiple times with the same
1897-
// block argument. Doing so would overwrite the mapping and mess with the
1898-
// internal state of the dialect conversion driver.
1899-
assert(!replacedArgs.contains(from) &&
1900-
"attempting to replace a block argument that was already replaced");
1901-
replacedArgs.insert(from);
1929+
// Make sure that a value is not replaced multiple times. In rollback mode,
1930+
// `replaceAllUsesWith` replaces not only all current uses of the given value,
1931+
// but also all future uses that may be introduced by future pattern
1932+
// applications. Therefore, it does not make sense to call
1933+
// `replaceAllUsesWith` multiple times with the same value. Doing so would
1934+
// overwrite the mapping and mess with the internal state of the dialect
1935+
// conversion driver.
1936+
assert(!replacedValues.contains(from) &&
1937+
"attempting to replace a value that was already replaced");
1938+
replacedValues.insert(from);
19021939
#endif // NDEBUG
19031940

1904-
appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
19051941
mapping.map(from, to);
1942+
appendRewrite<ReplaceValueRewrite>(from, converter);
19061943
}
19071944

19081945
void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
@@ -2107,18 +2144,19 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
21072144
return impl->convertRegionTypes(region, converter, entryConversion);
21082145
}
21092146

2110-
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
2111-
ValueRange to) {
2147+
void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
21122148
LLVM_DEBUG({
2113-
impl->logger.startLine() << "** Replace Argument : '" << from << "'";
2114-
if (Operation *parentOp = from.getOwner()->getParentOp()) {
2115-
impl->logger.getOStream() << " (in region of '" << parentOp->getName()
2116-
<< "' (" << parentOp << ")\n";
2117-
} else {
2118-
impl->logger.getOStream() << " (unlinked block)\n";
2149+
impl->logger.startLine() << "** Replace Value : '" << from << "'";
2150+
if (auto blockArg = dyn_cast<BlockArgument>(from)) {
2151+
if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
2152+
impl->logger.getOStream() << " (in region of '" << parentOp->getName()
2153+
<< "' (" << parentOp << ")\n";
2154+
} else {
2155+
impl->logger.getOStream() << " (unlinked block)\n";
2156+
}
21192157
}
21202158
});
2121-
impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter);
2159+
impl->replaceAllUsesWith(from, to, impl->currentTypeConverter);
21222160
}
21232161

21242162
Value ConversionPatternRewriter::getRemappedValue(Value key) {
@@ -2176,7 +2214,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
21762214

21772215
// Replace all uses of block arguments.
21782216
for (auto it : llvm::zip(source->getArguments(), argValues))
2179-
replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
2217+
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
21802218

21812219
if (fastPath) {
21822220
// Move all ops at once.

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -952,7 +952,7 @@ struct TestCreateIllegalBlock : public RewritePattern {
952952
}
953953
};
954954

955-
/// A simple pattern that tests the "replaceUsesOfBlockArgument" API.
955+
/// A simple pattern that tests the "replaceAllUsesWith" API.
956956
struct TestBlockArgReplace : public ConversionPattern {
957957
TestBlockArgReplace(MLIRContext *ctx, const TypeConverter &converter)
958958
: ConversionPattern(converter, "test.block_arg_replace", /*benefit=*/1,
@@ -963,8 +963,7 @@ struct TestBlockArgReplace : public ConversionPattern {
963963
ConversionPatternRewriter &rewriter) const final {
964964
// Replace the first block argument with 2x the second block argument.
965965
Value repl = op->getRegion(0).getArgument(1);
966-
rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
967-
{repl, repl});
966+
rewriter.replaceAllUsesWith(op->getRegion(0).getArgument(0), {repl, repl});
968967
rewriter.modifyOpInPlace(op, [&] {
969968
// If the "trigger_rollback" attribute is set, keep the op illegal, so
970969
// that a rollback is triggered.

0 commit comments

Comments
 (0)