Skip to content

Commit 410bcaf

Browse files
[mlir][Transforms] Support replaceAllUsesWith in dialect conversion
This commit adds support for `RewriterBase::replaceAllUsesWith` to the dialect conversion. Uses are not immediately replaced, but in a delayed fashion during the "commit" phase. No type conversions are performed; this is consistent with `ConversionPatternRewriter::replaceUsesOfBlockArgument`. - `RewriterBase::replaceAllUsesWith` is now virtual, so that it can be overridden in the dialect conversion. Note: `RewriterBase::replaceOp` can now be turned into a non-virtual function in a follow-up commit. - `ConversionPatternRewriter::replaceUsesOfBlockArgument` is generalized to `ConversionPatternRewriter::replaceAllUsesWith`, following the same implementation strategy. - A new kind of "IR rewrite" is added: `ValueRewrite` with `ReplaceAllUsesRewrite` (replacing `ReplaceBlockArgRewrite`) as the only value rewrite for now. - `replacedOps` is renamed to `erasedOps` to better capture its meaning. BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC
1 parent 9067f54 commit 410bcaf

File tree

8 files changed

+172
-97
lines changed

8 files changed

+172
-97
lines changed

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,7 @@ class RewriterBase : public OpBuilder {
634634

635635
/// Find uses of `from` and replace them with `to`. Also notify the listener
636636
/// about every in-place op modification (for every use that was replaced).
637-
void replaceAllUsesWith(Value from, Value to) {
637+
virtual void replaceAllUsesWith(Value from, Value to) {
638638
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
639639
Operation *op = operand.getOwner();
640640
modifyOpInPlace(op, [&]() { operand.set(to); });

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -697,9 +697,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
697697
Region *region, const TypeConverter &converter,
698698
ArrayRef<TypeConverter::SignatureConversion> blockConversions);
699699

700-
/// Replace all the uses of the block argument `from` with value `to`.
701-
void replaceUsesOfBlockArgument(BlockArgument from, Value to);
702-
703700
/// Return the converted value of 'key' with a type defined by the type
704701
/// converter of the currently executing pattern. Return nullptr in the case
705702
/// of failure, the remapped value otherwise.
@@ -720,6 +717,11 @@ class ConversionPatternRewriter final : public PatternRewriter {
720717
/// patterns even if a failure is encountered during the rewrite step.
721718
bool canRecoverFromRewriteFailure() const override { return true; }
722719

720+
/// Find uses of `from` and replace them with `to`.
721+
///
722+
/// Note: This function does not convert types.
723+
void replaceAllUsesWith(Value from, Value to) override;
724+
723725
/// PatternRewriter hook for replacing an operation.
724726
void replaceOp(Operation *op, ValueRange newValues) override;
725727

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ static void modifyFuncOpToUseBarePtrCallingConv(
310310
Location loc = funcOp.getLoc();
311311
auto placeholder = rewriter.create<LLVM::UndefOp>(
312312
loc, typeConverter.convertType(memrefTy));
313-
rewriter.replaceUsesOfBlockArgument(arg, placeholder);
313+
rewriter.replaceAllUsesWith(arg, placeholder);
314314

315315
Value desc = MemRefDescriptor::fromStaticShape(rewriter, loc, typeConverter,
316316
memrefTy, arg);

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
201201
llvmFuncOp.getBody().getArgument(remapping->inputNo);
202202
auto placeholder = rewriter.create<LLVM::UndefOp>(
203203
loc, getTypeConverter()->convertType(memrefTy));
204-
rewriter.replaceUsesOfBlockArgument(newArg, placeholder);
204+
rewriter.replaceAllUsesWith(newArg, placeholder);
205205
Value desc = MemRefDescriptor::fromStaticShape(
206206
rewriter, loc, *getTypeConverter(), memrefTy, newArg);
207207
rewriter.replaceOp(placeholder, {desc});

0 commit comments

Comments
 (0)