@@ -41,18 +41,38 @@ namespace bufferization {
4141
4242using namespace mlir ;
4343
44- // / Return the unique ReturnOp that terminates `funcOp`.
45- // / Return nullptr if there is no such unique ReturnOp.
46- static func::ReturnOp getAssumedUniqueReturnOp (func::FuncOp funcOp) {
47- func::ReturnOp returnOp;
44+ // / Get all the ReturnOp in the funcOp.
45+ static SmallVector<func::ReturnOp> getReturnOps (func::FuncOp funcOp) {
46+ SmallVector<func::ReturnOp> returnOps;
4847 for (Block &b : funcOp.getBody ()) {
4948 if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator ())) {
50- if (returnOp)
51- return nullptr ;
52- returnOp = candidateOp;
49+ returnOps.push_back (candidateOp);
5350 }
5451 }
55- return returnOp;
52+ return returnOps;
53+ }
54+
55+ // / Get the values at the same position in the `returnOps`.
56+ static SmallVector<Value>
57+ getReturnOpsOperandInPos (ArrayRef<func::ReturnOp> returnOps, size_t pos) {
58+ SmallVector<Value> operands;
59+ for (func::ReturnOp returnOp : returnOps) {
60+ operands.push_back (returnOp.getOperand (pos));
61+ }
62+ return operands;
63+ }
64+
65+ // / Check if the value in operands is equal to the argument.
66+ static bool operandsEqualFuncArgument (ArrayRef<Value> operands,
67+ BlockArgument argument) {
68+ for (Value val : operands) {
69+ while (auto castOp = val.getDefiningOp <memref::CastOp>())
70+ val = castOp.getSource ();
71+
72+ if (val != argument)
73+ return false ;
74+ }
75+ return true ;
5676}
5777
5878LogicalResult
@@ -72,40 +92,44 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
7292 for (auto funcOp : module .getOps <func::FuncOp>()) {
7393 if (funcOp.isExternal () || funcOp.isPublic ())
7494 continue ;
75- func::ReturnOp returnOp = getAssumedUniqueReturnOp (funcOp);
76- // TODO: Support functions with multiple blocks.
77- if (!returnOp)
95+ SmallVector<func::ReturnOp> returnOps = getReturnOps (funcOp);
96+ if (returnOps.empty ())
7897 continue ;
98+ func::ReturnOp returnOp = returnOps.front ();
7999
80100 // Compute erased results.
81- SmallVector<Value> newReturnValues;
101+ SmallVector<SmallVector< Value>> newReturnValues (returnOps. size ()) ;
82102 BitVector erasedResultIndices (funcOp.getFunctionType ().getNumResults ());
83103 DenseMap<int64_t , int64_t > resultToArgs;
84- for (const auto &it : llvm::enumerate ( returnOp.getOperands ()) ) {
104+ for (size_t i = 0 , e = returnOp.getOperands (). size (); i < e; ++i ) {
85105 bool erased = false ;
106+ SmallVector<Value> returnOperands =
107+ getReturnOpsOperandInPos (returnOps, i);
86108 for (BlockArgument bbArg : funcOp.getArguments ()) {
87- Value val = it.value ();
88- while (auto castOp = val.getDefiningOp <memref::CastOp>())
89- val = castOp.getSource ();
90-
91- if (val == bbArg) {
92- resultToArgs[it.index ()] = bbArg.getArgNumber ();
109+ if (operandsEqualFuncArgument (returnOperands, bbArg)) {
110+ resultToArgs[i] = bbArg.getArgNumber ();
93111 erased = true ;
94112 break ;
95113 }
96114 }
97115
98116 if (erased) {
99- erasedResultIndices.set (it. index () );
117+ erasedResultIndices.set (i );
100118 } else {
101- newReturnValues.push_back (it.value ());
119+ for (auto [newReturnValue, operand] :
120+ llvm::zip (newReturnValues, returnOperands)) {
121+ newReturnValue.push_back (operand);
122+ }
102123 }
103124 }
104125
105126 // Update function.
106127 if (failed (funcOp.eraseResults (erasedResultIndices)))
107128 return failure ();
108- returnOp.getOperandsMutable ().assign (newReturnValues);
129+
130+ for (auto [returnOp, newReturnValue] :
131+ llvm::zip (returnOps, newReturnValues))
132+ returnOp.getOperandsMutable ().assign (newReturnValue);
109133
110134 // Update function calls.
111135 for (func::CallOp callOp : callerMap[funcOp]) {
0 commit comments