Skip to content

Commit 191d6f0

Browse files
make drop-equivalent-buffer-results support mult blocks.
1 parent 24c5926 commit 191d6f0

File tree

2 files changed

+76
-22
lines changed

2 files changed

+76
-22
lines changed

mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,38 @@ namespace bufferization {
4141

4242
using namespace mlir;
4343

44-
/// Return the unique ReturnOp that terminates `funcOp`.
45-
/// Return nullptr if there is no such unique ReturnOp.
46-
static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
47-
func::ReturnOp returnOp;
44+
/// Get all the ReturnOp in the funcOp.
45+
static SmallVector<func::ReturnOp> getReturnOps(func::FuncOp funcOp) {
46+
SmallVector<func::ReturnOp> returnOps;
4847
for (Block &b : funcOp.getBody()) {
4948
if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
50-
if (returnOp)
51-
return nullptr;
52-
returnOp = candidateOp;
49+
returnOps.push_back(candidateOp);
5350
}
5451
}
55-
return returnOp;
52+
return returnOps;
53+
}
54+
55+
/// Get the values at the same position in the `returnOps`.
56+
static SmallVector<Value>
57+
getReturnOpsOperandInPos(ArrayRef<func::ReturnOp> returnOps, size_t pos) {
58+
SmallVector<Value> operands;
59+
for (func::ReturnOp returnOp : returnOps) {
60+
operands.push_back(returnOp.getOperand(pos));
61+
}
62+
return operands;
63+
}
64+
65+
/// Check if the value in operands is equal to the argument.
66+
static bool operandsEqualFuncArgument(ArrayRef<Value> operands,
67+
BlockArgument argument) {
68+
for (Value val : operands) {
69+
while (auto castOp = val.getDefiningOp<memref::CastOp>())
70+
val = castOp.getSource();
71+
72+
if (val != argument)
73+
return false;
74+
}
75+
return true;
5676
}
5777

5878
LogicalResult
@@ -72,40 +92,44 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
7292
for (auto funcOp : module.getOps<func::FuncOp>()) {
7393
if (funcOp.isExternal() || funcOp.isPublic())
7494
continue;
75-
func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
76-
// TODO: Support functions with multiple blocks.
77-
if (!returnOp)
95+
SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
96+
if (returnOps.empty())
7897
continue;
98+
func::ReturnOp returnOp = returnOps.front();
7999

80100
// Compute erased results.
81-
SmallVector<Value> newReturnValues;
101+
SmallVector<SmallVector<Value>> newReturnValues(returnOps.size());
82102
BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
83103
DenseMap<int64_t, int64_t> resultToArgs;
84-
for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
104+
for (size_t i = 0, e = returnOp.getOperands().size(); i < e; ++i) {
85105
bool erased = false;
106+
SmallVector<Value> returnOperands =
107+
getReturnOpsOperandInPos(returnOps, i);
86108
for (BlockArgument bbArg : funcOp.getArguments()) {
87-
Value val = it.value();
88-
while (auto castOp = val.getDefiningOp<memref::CastOp>())
89-
val = castOp.getSource();
90-
91-
if (val == bbArg) {
92-
resultToArgs[it.index()] = bbArg.getArgNumber();
109+
if (operandsEqualFuncArgument(returnOperands, bbArg)) {
110+
resultToArgs[i] = bbArg.getArgNumber();
93111
erased = true;
94112
break;
95113
}
96114
}
97115

98116
if (erased) {
99-
erasedResultIndices.set(it.index());
117+
erasedResultIndices.set(i);
100118
} else {
101-
newReturnValues.push_back(it.value());
119+
for (auto [newReturnValue, operand] :
120+
llvm::zip(newReturnValues, returnOperands)) {
121+
newReturnValue.push_back(operand);
122+
}
102123
}
103124
}
104125

105126
// Update function.
106127
if (failed(funcOp.eraseResults(erasedResultIndices)))
107128
return failure();
108-
returnOp.getOperandsMutable().assign(newReturnValues);
129+
130+
for (auto [returnOp, newReturnValue] :
131+
llvm::zip(returnOps, newReturnValues))
132+
returnOp.getOperandsMutable().assign(newReturnValue);
109133

110134
// Update function calls.
111135
for (func::CallOp callOp : callerMap[funcOp]) {

mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,3 +490,33 @@ func.func @collapse_shape_regression(
490490
tensor.collapse_shape %0[[0, 1]] : tensor<5x6xf32> into tensor<30xf32>
491491
return
492492
}
493+
494+
// -----
495+
496+
497+
// CHECK-LABEL: func private @mult_return_callee(
498+
// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
499+
// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> index {
500+
func.func private @mult_return_callee(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
501+
%casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
502+
// CHECK: cf.cond_br %[[COND]], ^bb1, ^bb2
503+
// CHECK: ^bb1:
504+
// CHECK: return %[[A]] : index
505+
// CHECK: ^bb2:
506+
// CHECK: return %[[B]] : index
507+
cf.cond_br %cond,^a, ^b
508+
^a:
509+
return %casted, %a : tensor<10xf32>, index
510+
^b:
511+
return %casted, %b : tensor<10xf32>, index
512+
}
513+
514+
// CHECK-LABEL: func @mult_return(
515+
// CHECK-SAME: %[[T:.*]]: memref<?xf32, strided<[?], offset: ?>>, %[[COND:.*]]: i1,
516+
// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index) -> (memref<?xf32, strided<[?], offset: ?>>, index) {
517+
func.func @mult_return(%t: tensor<?xf32>, %cond:i1, %a: index, %b: index) -> (tensor<10xf32>, index) {
518+
// CHECK: %[[RET:.*]] = call @mult_return_callee(%[[T]], %[[COND]], %[[A]], %[[B]]) : (memref<?xf32, strided<[?], offset: ?>>, i1, index, index) -> index
519+
// CHECK: return %[[T]], %[[RET]] : memref<?xf32, strided<[?], offset: ?>>, index
520+
%t_res, %v = func.call @mult_return_callee(%t, %cond, %a, %b) : (tensor<?xf32>, i1, index, index) -> (tensor<10xf32>, index)
521+
return %t_res, %v : tensor<10xf32>, index
522+
}

0 commit comments

Comments
 (0)