@@ -21,7 +21,7 @@ namespace bufferization {
2121} // namespace mlir
2222
2323using namespace mlir ;
24- using MemCpyFn = bufferization::BufferResultsToOutParamsOptions ::MemCpyFn;
24+ using MemCpyFn = bufferization::BufferResultsToOutParamsOpts ::MemCpyFn;
2525
2626// / Return `true` if the given MemRef type has a fully dynamic layout.
2727static bool hasFullyDynamicLayoutMap (MemRefType type) {
@@ -47,7 +47,8 @@ static bool hasStaticIdentityLayout(MemRefType type) {
4747// Any args appended to the entry block are added to `appendedEntryArgs`.
4848static LogicalResult
4949updateFuncOp (func::FuncOp func,
50- SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
50+ SmallVectorImpl<BlockArgument> &appendedEntryArgs,
51+ bool addResultAttribute) {
5152 auto functionType = func.getFunctionType ();
5253
5354 // Collect information about the results will become appended arguments.
@@ -80,6 +81,10 @@ updateFuncOp(func::FuncOp func,
8081 for (int i = 0 , e = erasedResultTypes.size (); i < e; ++i, ++erasedIndicesIt) {
8182 func.setArgAttrs (functionType.getNumInputs () + i,
8283 func.getResultAttrs (*erasedIndicesIt));
84+ if (addResultAttribute)
85+ func.setArgAttr (functionType.getNumInputs () + i,
86+ StringAttr::get (func.getContext (), " bufferize.result" ),
87+ UnitAttr::get (func.getContext ()));
8388 }
8489
8590 // Erase the results.
@@ -127,7 +132,7 @@ static LogicalResult updateReturnOps(func::FuncOp func,
127132// temporary buffers for newly introduced out params.
128133static LogicalResult
129134updateCalls (ModuleOp module ,
130- const bufferization::BufferResultsToOutParamsOptions &options) {
135+ const bufferization::BufferResultsToOutParamsOpts &options) {
131136 bool didFail = false ;
132137 SymbolTable symtab (module );
133138 module .walk ([&](func::CallOp op) {
@@ -189,12 +194,13 @@ updateCalls(ModuleOp module,
189194
190195LogicalResult mlir::bufferization::promoteBufferResultsToOutParams (
191196 ModuleOp module ,
192- const bufferization::BufferResultsToOutParamsOptions &options) {
197+ const bufferization::BufferResultsToOutParamsOpts &options) {
193198 for (auto func : module .getOps <func::FuncOp>()) {
194199 if (!options.filterFn (&func))
195200 continue ;
196201 SmallVector<BlockArgument, 6 > appendedEntryArgs;
197- if (failed (updateFuncOp (func, appendedEntryArgs)))
202+ if (failed (
203+ updateFuncOp (func, appendedEntryArgs, options.addResultAttribute )))
198204 return failure ();
199205 if (func.isExternal ())
200206 continue ;
@@ -218,21 +224,25 @@ struct BufferResultsToOutParamsPass
218224 : bufferization::impl::BufferResultsToOutParamsBase<
219225 BufferResultsToOutParamsPass> {
220226 explicit BufferResultsToOutParamsPass (
221- const bufferization::BufferResultsToOutParamsOptions &options)
227+ const bufferization::BufferResultsToOutParamsOpts &options)
222228 : options(options) {}
223229
224230 void runOnOperation () override {
231+ // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
232+ if (addResultAttribute)
233+ options.addResultAttribute = true ;
234+
225235 if (failed (bufferization::promoteBufferResultsToOutParams (getOperation (),
226236 options)))
227237 return signalPassFailure ();
228238 }
229239
230240private:
231- bufferization::BufferResultsToOutParamsOptions options;
241+ bufferization::BufferResultsToOutParamsOpts options;
232242};
233243} // namespace
234244
235245std::unique_ptr<Pass> mlir::bufferization::createBufferResultsToOutParamsPass (
236- const bufferization::BufferResultsToOutParamsOptions &options) {
246+ const bufferization::BufferResultsToOutParamsOpts &options) {
237247 return std::make_unique<BufferResultsToOutParamsPass>(options);
238248}
0 commit comments