@@ -71,6 +71,14 @@ LogicalResult
7171mlir::bufferization::dropEquivalentBufferResults (ModuleOp module ) {
7272 IRRewriter rewriter (module .getContext ());
7373
74+ DenseMap<func::FuncOp, DenseSet<func::CallOp>> callerMap;
75+ // Collect the mapping of functions to their call sites.
76+ module .walk ([&](func::CallOp callOp) {
77+ if (func::FuncOp calledFunc = getCalledFunction (callOp)) {
78+ callerMap[calledFunc].insert (callOp);
79+ }
80+ });
81+
7482 for (auto funcOp : module .getOps <func::FuncOp>()) {
7583 if (funcOp.isExternal ())
7684 continue ;
@@ -109,10 +117,7 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
109117 returnOp.getOperandsMutable ().assign (newReturnValues);
110118
111119 // Update function calls.
112- module .walk ([&](func::CallOp callOp) {
113- if (getCalledFunction (callOp) != funcOp)
114- return WalkResult::skip ();
115-
120+ for (func::CallOp callOp : callerMap[funcOp]) {
116121 rewriter.setInsertionPoint (callOp);
117122 auto newCallOp = rewriter.create <func::CallOp>(callOp.getLoc (), funcOp,
118123 callOp.getOperands ());
@@ -136,8 +141,7 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
136141 newResults.push_back (replacement);
137142 }
138143 rewriter.replaceOp (callOp, newResults);
139- return WalkResult::advance ();
140- });
144+ }
141145 }
142146
143147 return success ();
0 commit comments