@@ -21,7 +21,7 @@ namespace bufferization {
2121} // namespace mlir
2222
2323using namespace mlir ;
24- using MemCpyFn = bufferization::BufferResultsToOutParamsOptions ::MemCpyFn;
24+ using MemCpyFn = bufferization::BufferResultsToOutParamsOpts ::MemCpyFn;
2525
2626// / Return `true` if the given MemRef type has a fully dynamic layout.
2727static bool hasFullyDynamicLayoutMap (MemRefType type) {
@@ -45,9 +45,12 @@ static bool hasStaticIdentityLayout(MemRefType type) {
4545// Updates the func op and entry block.
4646//
4747// Any args appended to the entry block are added to `appendedEntryArgs`.
48+ // If `addResultAttribute` is true, adds the unit attribute `bufferize.result`
49+ // to each newly created function argument.
4850static LogicalResult
4951updateFuncOp (func::FuncOp func,
50- SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
52+ SmallVectorImpl<BlockArgument> &appendedEntryArgs,
53+ bool addResultAttribute) {
5154 auto functionType = func.getFunctionType ();
5255
5356 // Collect information about the results will become appended arguments.
@@ -80,6 +83,10 @@ updateFuncOp(func::FuncOp func,
8083 for (int i = 0 , e = erasedResultTypes.size (); i < e; ++i, ++erasedIndicesIt) {
8184 func.setArgAttrs (functionType.getNumInputs () + i,
8285 func.getResultAttrs (*erasedIndicesIt));
86+ if (addResultAttribute)
87+ func.setArgAttr (functionType.getNumInputs () + i,
88+ StringAttr::get (func.getContext (), " bufferize.result" ),
89+ UnitAttr::get (func.getContext ()));
8390 }
8491
8592 // Erase the results.
@@ -127,7 +134,7 @@ static LogicalResult updateReturnOps(func::FuncOp func,
127134// temporary buffers for newly introduced out params.
128135static LogicalResult
129136updateCalls (ModuleOp module ,
130- const bufferization::BufferResultsToOutParamsOptions &options) {
137+ const bufferization::BufferResultsToOutParamsOpts &options) {
131138 bool didFail = false ;
132139 SymbolTable symtab (module );
133140 module .walk ([&](func::CallOp op) {
@@ -189,12 +196,13 @@ updateCalls(ModuleOp module,
189196
190197LogicalResult mlir::bufferization::promoteBufferResultsToOutParams (
191198 ModuleOp module ,
192- const bufferization::BufferResultsToOutParamsOptions &options) {
199+ const bufferization::BufferResultsToOutParamsOpts &options) {
193200 for (auto func : module .getOps <func::FuncOp>()) {
194201 if (!options.filterFn (&func))
195202 continue ;
196203 SmallVector<BlockArgument, 6 > appendedEntryArgs;
197- if (failed (updateFuncOp (func, appendedEntryArgs)))
204+ if (failed (
205+ updateFuncOp (func, appendedEntryArgs, options.addResultAttribute )))
198206 return failure ();
199207 if (func.isExternal ())
200208 continue ;
@@ -218,21 +226,25 @@ struct BufferResultsToOutParamsPass
218226 : bufferization::impl::BufferResultsToOutParamsBase<
219227 BufferResultsToOutParamsPass> {
220228 explicit BufferResultsToOutParamsPass (
221- const bufferization::BufferResultsToOutParamsOptions &options)
229+ const bufferization::BufferResultsToOutParamsOpts &options)
222230 : options(options) {}
223231
224232 void runOnOperation () override {
233+ // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
234+ if (addResultAttribute)
235+ options.addResultAttribute = true ;
236+
225237 if (failed (bufferization::promoteBufferResultsToOutParams (getOperation (),
226238 options)))
227239 return signalPassFailure ();
228240 }
229241
230242private:
231- bufferization::BufferResultsToOutParamsOptions options;
243+ bufferization::BufferResultsToOutParamsOpts options;
232244};
233245} // namespace
234246
235247std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass (
236- const bufferization::BufferResultsToOutParamsOptions &options) {
248+ const bufferization::BufferResultsToOutParamsOpts &options) {
237249 return std::make_unique<BufferResultsToOutParamsPass>(options);
238250}
0 commit comments