From 7f71eba0a72ed92be61d03284e61e003d90d25f1 Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Fri, 6 Oct 2023 18:01:58 +0000 Subject: [PATCH 1/2] [mlir][sparse] introduce a pass to stage complex sparse operations into simple steps --- .../Dialect/SparseTensor/Transforms/Passes.h | 9 +++++++++ .../Dialect/SparseTensor/Transforms/Passes.td | 12 ++++++++++++ .../SparseTensor/Transforms/CMakeLists.txt | 1 + .../Transforms/SparseTensorPasses.cpp | 17 +++++++++++++++++ .../Transforms/StageSparseOperations.cpp | 4 ++++ 5 files changed, 43 insertions(+) create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h index c1e217675020f..c537e92a51d53 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -87,6 +87,15 @@ std::unique_ptr createSparsificationPass(); std::unique_ptr createSparsificationPass(const SparsificationOptions &options); +//===----------------------------------------------------------------------===// +// The StageSparseOperations pass. +//===----------------------------------------------------------------------===// + +/// Sets up StageSparseOperation rewriting rules. +void populateStageSparseOperationsPatterns(RewritePatternSet &patterns); + +std::unique_ptr createStageSparseOperationsPass(); + //===----------------------------------------------------------------------===// // The PostSparsificationRewriting pass. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td index d8d5dbb5ad3ce..7071c3091d33f 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -123,6 +123,18 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> { ]; } +def StageSparseOperations : Pass<"stage-sparse-ops", "func::FuncOp"> { + let summary = "Decompose a complex sparse operations into multiple stages"; + let description = [{ + A pass that decomposes a complex sparse operations into multiple stages. + E.g., CSR -> CSC conversion is staged into CSR -> COO (unordered) -> sort -> CSC. + }]; + let constructor = "mlir::createStageSparseOperationsPass()"; + let dependentDialects = [ + "sparse_tensor::SparseTensorDialect", + ]; +} + def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp"> { let summary = "Applies sparse tensor rewriting rules after sparsification"; let description = [{ diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt index 5ef9d906f0e8b..0ca6668c8c747 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt @@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms SparseVectorization.cpp Sparsification.cpp SparsificationAndBufferizationPass.cpp + StageSparseOperations.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp index f50d3d4606554..e1f88ad9c0e11 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -30,6 +30,7 @@ namespace mlir { #define GEN_PASS_DEF_SPARSEBUFFERREWRITE #define GEN_PASS_DEF_SPARSEVECTORIZATION #define GEN_PASS_DEF_SPARSEGPUCODEGEN +#define GEN_PASS_DEF_STAGESPARSEOPERATIONS #define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" } // namespace mlir @@ -92,6 +93,18 @@ struct SparsificationPass } }; +struct StageSparseOperationsPass + : public impl::StageSparseOperationsBase { + StageSparseOperationsPass() = default; + StageSparseOperationsPass(const StageSparseOperationsPass &pass) = default; + void runOnOperation() override { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + populateStageSparseOperationsPatterns(patterns); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + struct PostSparsificationRewritePass : public impl::PostSparsificationRewriteBase< PostSparsificationRewritePass> { @@ -384,6 +397,10 @@ mlir::createSparsificationPass(const SparsificationOptions &options) { return std::make_unique(options); } +std::unique_ptr mlir::createStageSparseOperationsPass() { + return std::make_unique(); +} + std::unique_ptr mlir::createPostSparsificationRewritePass() { return std::make_unique(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp new file mode 100644 index 0000000000000..4adc4d131198c --- /dev/null +++ b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp @@ -0,0 +1,4 @@ +#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" + +void mlir::populateStageSparseOperationsPatterns( + RewritePatternSet & /*patterns*/) {} From b1e8f710bd2e7f4df6dd0d4be9f28b13388a0e6b Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Fri, 6 Oct 2023 21:08:09 +0000 Subject: [PATCH 2/2] address comments --- mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td index 7071c3091d33f..8f116bff9b185 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td @@ -124,10 +124,10 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> { } def StageSparseOperations : Pass<"stage-sparse-ops", "func::FuncOp"> { - let summary = "Decompose a complex sparse operations into multiple stages"; + let summary = "Decompose a complex sparse operation into multiple stages"; let description = [{ - A pass that decomposes a complex sparse operations into multiple stages. - E.g., CSR -> CSC conversion is staged into CSR -> COO (unordered) -> sort -> CSC. + A pass that decomposes a complex sparse operation into multiple stages. + E.g., CSR -> CSC is staged into CSR -> COO (unordered) -> sort -> CSC. }]; let constructor = "mlir::createStageSparseOperationsPass()"; let dependentDialects = [