Skip to content

Commit 88d9afc

Browse files
[mlir][Transforms] Add support for ConversionPatternRewriter::replaceAllUsesWith
1 parent 769d5c2 commit 88d9afc

File tree

5 files changed

+112
-72
lines changed

5 files changed

+112
-72
lines changed

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,7 @@ class RewriterBase : public OpBuilder {
633633

634634
/// Find uses of `from` and replace them with `to`. Also notify the listener
635635
/// about every in-place op modification (for every use that was replaced).
636-
void replaceAllUsesWith(Value from, Value to) {
636+
virtual void replaceAllUsesWith(Value from, Value to) {
637637
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
638638
Operation *op = operand.getOwner();
639639
modifyOpInPlace(op, [&]() { operand.set(to); });

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -780,15 +780,18 @@ class ConversionPatternRewriter final : public PatternRewriter {
780780
Region *region, const TypeConverter &converter,
781781
TypeConverter::SignatureConversion *entryConversion = nullptr);
782782

783-
/// Replace all the uses of the block argument `from` with `to`. This
784-
/// function supports both 1:1 and 1:N replacements.
783+
/// Replace all the uses of `from` with `to`. This function supports both 1:1
784+
/// and 1:N replacements.
785785
///
786786
/// Note: If `allowPatternRollback` is set to "true", this function replaces
787-
/// all current and future uses of the block argument. This same block
788-
/// block argument must not be replaced multiple times. Uses are not replaced
789-
/// immediately but in a delayed fashion. Patterns may still see the original
790-
/// uses when inspecting IR.
791-
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
787+
/// all current and future uses of the `from` value. This same value must not
788+
/// be replaced multiple times. Uses are not replaced immediately but in a
789+
/// delayed fashion. Patterns may still see the original uses when inspecting
790+
/// IR.
791+
void replaceAllUsesWith(Value from, ValueRange to);
792+
void replaceAllUsesWith(Value from, Value to) override {
793+
replaceAllUsesWith(from, ValueRange{to});
794+
}
792795

793796
/// Return the converted value of 'key' with a type defined by the type
794797
/// converter of the currently executing pattern. Return nullptr in the case

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ static void restoreByValRefArgumentType(
284284
cast<TypeAttr>(byValRefAttr->getValue()).getValue());
285285

286286
Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
287-
rewriter.replaceUsesOfBlockArgument(arg, valueArg);
287+
rewriter.replaceAllUsesWith(arg, valueArg);
288288
}
289289
}
290290

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 98 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -277,13 +277,14 @@ class IRRewrite {
277277
InlineBlock,
278278
MoveBlock,
279279
BlockTypeConversion,
280-
ReplaceBlockArg,
281280
// Operation rewrites
282281
MoveOperation,
283282
ModifyOperation,
284283
ReplaceOperation,
285284
CreateOperation,
286-
UnresolvedMaterialization
285+
UnresolvedMaterialization,
286+
// Value rewrites
287+
ReplaceValue
287288
};
288289

289290
virtual ~IRRewrite() = default;
@@ -330,7 +331,7 @@ class BlockRewrite : public IRRewrite {
330331

331332
static bool classof(const IRRewrite *rewrite) {
332333
return rewrite->getKind() >= Kind::CreateBlock &&
333-
rewrite->getKind() <= Kind::ReplaceBlockArg;
334+
rewrite->getKind() <= Kind::BlockTypeConversion;
334335
}
335336

336337
protected:
@@ -342,6 +343,25 @@ class BlockRewrite : public IRRewrite {
342343
Block *block;
343344
};
344345

346+
/// A value rewrite.
347+
class ValueRewrite : public IRRewrite {
348+
public:
349+
/// Return the value that this rewrite operates on.
350+
Value getValue() const { return value; }
351+
352+
static bool classof(const IRRewrite *rewrite) {
353+
return rewrite->getKind() == Kind::ReplaceValue;
354+
}
355+
356+
protected:
357+
ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
358+
Value value)
359+
: IRRewrite(kind, rewriterImpl), value(value) {}
360+
361+
// The value that this rewrite operates on.
362+
Value value;
363+
};
364+
345365
/// Creation of a block. Block creations are immediately reflected in the IR.
346366
/// There is no extra work to commit the rewrite. During rollback, the newly
347367
/// created block is erased.
@@ -548,29 +568,26 @@ class BlockTypeConversionRewrite : public BlockRewrite {
548568
Block *newBlock;
549569
};
550570

551-
/// Replacing a block argument. This rewrite is not immediately reflected in the
571+
/// Replacing a value. This rewrite is not immediately reflected in the
552572
/// IR. An internal IR mapping is updated, but the actual replacement is delayed
553573
/// until the rewrite is committed.
554-
class ReplaceBlockArgRewrite : public BlockRewrite {
574+
class ReplaceValueRewrite : public ValueRewrite {
555575
public:
556-
ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
557-
Block *block, BlockArgument arg,
558-
const TypeConverter *converter)
559-
: BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
576+
ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value,
577+
const TypeConverter *converter)
578+
: ValueRewrite(Kind::ReplaceValue, rewriterImpl, value),
560579
converter(converter) {}
561580

562581
static bool classof(const IRRewrite *rewrite) {
563-
return rewrite->getKind() == Kind::ReplaceBlockArg;
582+
return rewrite->getKind() == Kind::ReplaceValue;
564583
}
565584

566585
void commit(RewriterBase &rewriter) override;
567586

568587
void rollback() override;
569588

570589
private:
571-
BlockArgument arg;
572-
573-
/// The current type converter when the block argument was replaced.
590+
/// The current type converter when the value was replaced.
574591
const TypeConverter *converter;
575592
};
576593

@@ -942,10 +959,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
942959
/// uses.
943960
void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);
944961

945-
/// Replace the given block argument with the given values. The specified
962+
/// Replace the uses of the given value with the given values. The specified
946963
/// converter is used to build materializations (if necessary).
947-
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to,
948-
const TypeConverter *converter);
964+
void replaceAllUsesWith(Value from, ValueRange to,
965+
const TypeConverter *converter);
949966

950967
/// Erase the given block and its contents.
951968
void eraseBlock(Block *block);
@@ -1132,10 +1149,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
11321149
IRRewriter notifyingRewriter;
11331150

11341151
#ifndef NDEBUG
1135-
/// A set of replaced block arguments. This set is for debugging purposes
1136-
/// only and it is maintained only if `allowPatternRollback` is set to
1137-
/// "true".
1138-
DenseSet<BlockArgument> replacedArgs;
1152+
/// A set of replaced values. This set is for debugging purposes only and it
1153+
/// is maintained only if `allowPatternRollback` is set to "true".
1154+
DenseSet<Value> replacedValues;
11391155

11401156
/// A set of operations that have pending updates. This tracking isn't
11411157
/// strictly necessary, and is thus only active during debug builds for extra
@@ -1172,32 +1188,54 @@ void BlockTypeConversionRewrite::rollback() {
11721188
getNewBlock()->replaceAllUsesWith(getOrigBlock());
11731189
}
11741190

1175-
static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg,
1176-
Value repl) {
1191+
/// Replace all uses of `from` with `repl`.
1192+
static void performReplaceValue(RewriterBase &rewriter, Value from,
1193+
Value repl) {
11771194
if (isa<BlockArgument>(repl)) {
1178-
rewriter.replaceAllUsesWith(arg, repl);
1195+
// `repl` is a block argument. Directly replace all uses.
1196+
rewriter.replaceAllUsesWith(from, repl);
11791197
return;
11801198
}
11811199

1182-
// If the replacement value is an operation, we check to make sure that we
1183-
// don't replace uses that are within the parent operation of the
1184-
// replacement value.
1185-
Operation *replOp = cast<OpResult>(repl).getOwner();
1200+
// If the replacement value is an operation, only replace those uses that:
1201+
// - are in a different block than the replacement operation, or
1202+
// - are in the same block but after the replacement operation.
1203+
//
1204+
// Example:
1205+
// ^bb0(%arg0: i32):
1206+
// %0 = "consumer"(%arg0) : (i32) -> (i32)
1207+
// "another_consumer"(%arg0) : (i32) -> ()
1208+
//
1209+
// In the above example, replaceAllUsesWith(%arg0, %0) will replace the
1210+
// use in "another_consumer" but not the use in "consumer". When using the
1211+
// normal RewriterBase API, this would typically be done with
1212+
// `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not
1213+
// supported by the `ConversionPatternRewriter`. Due to the mapping mechanism
1214+
// it cannot be supported efficiently with `allowPatternRollback` set to
1215+
// "true". Therefore, the conversion driver is trying to be smart and replaces
1216+
// only those uses that do not lead to a dominance violation. E.g., the
1217+
// FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this
1218+
// behavior.
1219+
//
1220+
// TODO: As we move more and more towards `allowPatternRollback` set to
1221+
// "false", we should remove this special handling, in order to align the
1222+
// `ConversionPatternRewriter` API with the normal `RewriterBase` API.
1223+
Operation *replOp = repl.getDefiningOp();
11861224
Block *replBlock = replOp->getBlock();
1187-
rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
1225+
rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
11881226
Operation *user = operand.getOwner();
11891227
return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
11901228
});
11911229
}
11921230

1193-
void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
1194-
Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
1231+
void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
1232+
Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
11951233
if (!repl)
11961234
return;
1197-
performReplaceBlockArg(rewriter, arg, repl);
1235+
performReplaceValue(rewriter, value, repl);
11981236
}
11991237

1200-
void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
1238+
void ReplaceValueRewrite::rollback() { rewriterImpl.mapping.erase({value}); }
12011239

12021240
void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
12031241
auto *listener =
@@ -1590,7 +1628,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
15901628
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
15911629
/*castOp=*/nullptr, /*isPureTypeConversion=*/false)
15921630
.front();
1593-
replaceUsesOfBlockArgument(origArg, mat, converter);
1631+
replaceAllUsesWith(origArg, mat, converter);
15941632
continue;
15951633
}
15961634

@@ -1599,15 +1637,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
15991637
assert(inputMap->size == 0 &&
16001638
"invalid to provide a replacement value when the argument isn't "
16011639
"dropped");
1602-
replaceUsesOfBlockArgument(origArg, inputMap->replacementValues,
1603-
converter);
1640+
replaceAllUsesWith(origArg, inputMap->replacementValues, converter);
16041641
continue;
16051642
}
16061643

16071644
// This is a 1->1+ mapping.
16081645
auto replArgs =
16091646
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1610-
replaceUsesOfBlockArgument(origArg, replArgs, converter);
1647+
replaceAllUsesWith(origArg, replArgs, converter);
16111648
}
16121649

16131650
if (config.allowPatternRollback)
@@ -1882,8 +1919,8 @@ void ConversionPatternRewriterImpl::replaceOp(
18821919
op->walk([&](Operation *op) { replacedOps.insert(op); });
18831920
}
18841921

1885-
void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
1886-
BlockArgument from, ValueRange to, const TypeConverter *converter) {
1922+
void ConversionPatternRewriterImpl::replaceAllUsesWith(
1923+
Value from, ValueRange to, const TypeConverter *converter) {
18871924
if (!config.allowPatternRollback) {
18881925
SmallVector<Value> toConv = llvm::to_vector(to);
18891926
SmallVector<Value> repls =
@@ -1893,25 +1930,25 @@ void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
18931930
if (!repl)
18941931
return;
18951932

1896-
performReplaceBlockArg(r, from, repl);
1933+
performReplaceValue(r, from, repl);
18971934
return;
18981935
}
18991936

19001937
#ifndef NDEBUG
1901-
// Make sure that a block argument is not replaced multiple times. In
1902-
// rollback mode, `replaceUsesOfBlockArgument` replaces not only all current
1903-
// uses of the given block argument, but also all future uses that may be
1904-
// introduced by future pattern applications. Therefore, it does not make
1905-
// sense to call `replaceUsesOfBlockArgument` multiple times with the same
1906-
// block argument. Doing so would overwrite the mapping and mess with the
1907-
// internal state of the dialect conversion driver.
1908-
assert(!replacedArgs.contains(from) &&
1909-
"attempting to replace a block argument that was already replaced");
1910-
replacedArgs.insert(from);
1938+
// Make sure that a value is not replaced multiple times. In rollback mode,
1939+
// `replaceAllUsesWith` replaces not only all current uses of the given value,
1940+
// but also all future uses that may be introduced by future pattern
1941+
// applications. Therefore, it does not make sense to call
1942+
// `replaceAllUsesWith` multiple times with the same value. Doing so would
1943+
// overwrite the mapping and mess with the internal state of the dialect
1944+
// conversion driver.
1945+
assert(!replacedValues.contains(from) &&
1946+
"attempting to replace a value that was already replaced");
1947+
replacedValues.insert(from);
19111948
#endif // NDEBUG
19121949

1913-
appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
19141950
mapping.map(from, to);
1951+
appendRewrite<ReplaceValueRewrite>(from, converter);
19151952
}
19161953

19171954
void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
@@ -2116,18 +2153,19 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
21162153
return impl->convertRegionTypes(*this, region, converter, entryConversion);
21172154
}
21182155

2119-
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
2120-
ValueRange to) {
2156+
void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
21212157
LLVM_DEBUG({
2122-
impl->logger.startLine() << "** Replace Argument : '" << from << "'";
2123-
if (Operation *parentOp = from.getOwner()->getParentOp()) {
2124-
impl->logger.getOStream() << " (in region of '" << parentOp->getName()
2125-
<< "' (" << parentOp << ")\n";
2126-
} else {
2127-
impl->logger.getOStream() << " (unlinked block)\n";
2158+
impl->logger.startLine() << "** Replace Value : '" << from << "'";
2159+
if (auto blockArg = dyn_cast<BlockArgument>(from)) {
2160+
if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
2161+
impl->logger.getOStream() << " (in region of '" << parentOp->getName()
2162+
<< "' (" << parentOp << ")\n";
2163+
} else {
2164+
impl->logger.getOStream() << " (unlinked block)\n";
2165+
}
21282166
}
21292167
});
2130-
impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter);
2168+
impl->replaceAllUsesWith(from, to, impl->currentTypeConverter);
21312169
}
21322170

21332171
Value ConversionPatternRewriter::getRemappedValue(Value key) {
@@ -2185,7 +2223,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
21852223

21862224
// Replace all uses of block arguments.
21872225
for (auto it : llvm::zip(source->getArguments(), argValues))
2188-
replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
2226+
replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
21892227

21902228
if (fastPath) {
21912229
// Move all ops at once.

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -951,7 +951,7 @@ struct TestCreateIllegalBlock : public RewritePattern {
951951
}
952952
};
953953

954-
/// A simple pattern that tests the "replaceUsesOfBlockArgument" API.
954+
/// A simple pattern that tests the "replaceAllUsesWith" API.
955955
struct TestBlockArgReplace : public ConversionPattern {
956956
TestBlockArgReplace(MLIRContext *ctx, const TypeConverter &converter)
957957
: ConversionPattern(converter, "test.block_arg_replace", /*benefit=*/1,
@@ -962,8 +962,7 @@ struct TestBlockArgReplace : public ConversionPattern {
962962
ConversionPatternRewriter &rewriter) const final {
963963
// Replace the first block argument with 2x the second block argument.
964964
Value repl = op->getRegion(0).getArgument(1);
965-
rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
966-
{repl, repl});
965+
rewriter.replaceAllUsesWith(op->getRegion(0).getArgument(0), {repl, repl});
967966
rewriter.modifyOpInPlace(op, [&] {
968967
// If the "trigger_rollback" attribute is set, keep the op illegal, so
969968
// that a rollback is triggered.

0 commit comments

Comments
 (0)