diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h index 809f03407258a..a729bc99b987c 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -148,7 +148,7 @@ std::unique_ptr createBufferLoopHoistingPass(); // Options struct for BufferResultsToOutParams pass. // Note: defined only here, not in tablegen. -struct BufferResultsToOutParamsOptions { +struct BufferResultsToOutParamsOpts { /// Memcpy function: Generate a memcpy between two memrefs. using MemCpyFn = std::function; @@ -162,17 +162,21 @@ struct BufferResultsToOutParamsOptions { /// Memcpy function; used to create a copy between two memrefs. /// If this is empty, memref.copy is used. std::optional memCpyFn; + + /// If true, the pass adds a "bufferize.result" attribute to each output + /// parameter. + bool addResultAttribute = false; }; /// Creates a pass that converts memref function results to out-params. std::unique_ptr createBufferResultsToOutParamsPass( - const BufferResultsToOutParamsOptions &options = {}); + const BufferResultsToOutParamsOpts &options = {}); /// Replace buffers that are returned from a function with an out parameter. /// Also update all call sites. LogicalResult promoteBufferResultsToOutParams(ModuleOp module, - const BufferResultsToOutParamsOptions &options); + const BufferResultsToOutParamsOpts &options); /// Creates a pass that drops memref function results that are equivalent to a /// function argument. diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td index e01f36b8daa18..1c3cdec81a39e 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td @@ -316,6 +316,11 @@ def BufferResultsToOutParams : Pass<"buffer-results-to-out-params", "ModuleOp"> buffers for results need to be allocated in the caller. This currently only works for static shaped memrefs. }]; + let options = [ + Option<"addResultAttribute", "add-result-attr", "bool", + /*default=*/"false", + "Add the attribute 'bufferize.result' to all output parameters.">, + ]; let constructor = "mlir::bufferization::createBufferResultsToOutParamsPass()"; let dependentDialects = ["memref::MemRefDialect"]; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp index 930f035339c1d..a2222e169c4d6 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -21,7 +21,7 @@ namespace bufferization { } // namespace mlir using namespace mlir; -using MemCpyFn = bufferization::BufferResultsToOutParamsOptions::MemCpyFn; +using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn; /// Return `true` if the given MemRef type has a fully dynamic layout. static bool hasFullyDynamicLayoutMap(MemRefType type) { @@ -45,9 +45,12 @@ static bool hasStaticIdentityLayout(MemRefType type) { // Updates the func op and entry block. // // Any args appended to the entry block are added to `appendedEntryArgs`. +// If `addResultAttribute` is true, adds the unit attribute `bufferize.result` +// to each newly created function argument. static LogicalResult updateFuncOp(func::FuncOp func, - SmallVectorImpl &appendedEntryArgs) { + SmallVectorImpl &appendedEntryArgs, + bool addResultAttribute) { auto functionType = func.getFunctionType(); // Collect information about the results will become appended arguments. @@ -80,6 +83,10 @@ updateFuncOp(func::FuncOp func, for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) { func.setArgAttrs(functionType.getNumInputs() + i, func.getResultAttrs(*erasedIndicesIt)); + if (addResultAttribute) + func.setArgAttr(functionType.getNumInputs() + i, + StringAttr::get(func.getContext(), "bufferize.result"), + UnitAttr::get(func.getContext())); } // Erase the results. @@ -127,7 +134,7 @@ static LogicalResult updateReturnOps(func::FuncOp func, // temporary buffers for newly introduced out params. static LogicalResult updateCalls(ModuleOp module, - const bufferization::BufferResultsToOutParamsOptions &options) { + const bufferization::BufferResultsToOutParamsOpts &options) { bool didFail = false; SymbolTable symtab(module); module.walk([&](func::CallOp op) { @@ -189,12 +196,13 @@ updateCalls(ModuleOp module, LogicalResult mlir::bufferization::promoteBufferResultsToOutParams( ModuleOp module, - const bufferization::BufferResultsToOutParamsOptions &options) { + const bufferization::BufferResultsToOutParamsOpts &options) { for (auto func : module.getOps()) { if (!options.filterFn(&func)) continue; SmallVector appendedEntryArgs; - if (failed(updateFuncOp(func, appendedEntryArgs))) + if (failed( + updateFuncOp(func, appendedEntryArgs, options.addResultAttribute))) return failure(); if (func.isExternal()) continue; @@ -218,21 +226,25 @@ struct BufferResultsToOutParamsPass : bufferization::impl::BufferResultsToOutParamsBase< BufferResultsToOutParamsPass> { explicit BufferResultsToOutParamsPass( - const bufferization::BufferResultsToOutParamsOptions &options) + const bufferization::BufferResultsToOutParamsOpts &options) : options(options) {} void runOnOperation() override { + // Convert from pass options in tablegen to BufferResultsToOutParamsOpts. + if (addResultAttribute) + options.addResultAttribute = true; + if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(), options))) return signalPassFailure(); } private: - bufferization::BufferResultsToOutParamsOptions options; + bufferization::BufferResultsToOutParamsOpts options; }; } // namespace std::unique_ptr mlir::bufferization::createBufferResultsToOutParamsPass( - const bufferization::BufferResultsToOutParamsOptions &options) { + const bufferization::BufferResultsToOutParamsOpts &options) { return std::make_unique(options); } diff --git a/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir b/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir new file mode 100644 index 0000000000000..f4a95c73e2953 --- /dev/null +++ b/mlir/test/Transforms/buffer-results-to-out-params-add-result-attr.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-opt -p 'builtin.module(buffer-results-to-out-params{add-result-attr})' -split-input-file -verify-diagnostics %s | FileCheck %s + +// CHECK-LABEL: @basic({{.*}}: memref {bufferize.result}) +func.func @basic() -> (memref) { + %0 = "test.source"() : () -> (memref) + return %0 : memref +} + +// ----- + +// CHECK-LABEL: multiple_results +// CHECK-SAME: memref<1xf32> {bufferize.result} +// CHECK-SAME: memref<2xf32> {bufferize.result} +func.func @multiple_results() -> (memref<1xf32>, memref<2xf32>) { + %0, %1 = "test.source"() : () -> (memref<1xf32>, memref<2xf32>) + return %0, %1 : memref<1xf32>, memref<2xf32> +}