From c50459de767b6d3b34fd2ccd1ee08a8dda5d76ec Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Tue, 12 Nov 2024 20:33:29 +0000 Subject: [PATCH] [mlir] Add apply_patterns.linalg.generalize_pack_unpack Transform Dialect Op This PR introduces populateGeneralizePatterns, which collects the following patterns: * `GeneralizeOuterUnitDimsPackOpPattern`, * `GeneralizeOuterUnitDimsUnPackOpPattern` (currently a TODO). These patterns are wrapped in a new Transform Dialect Op: `apply_patterns.linalg.generalize_pack_unpack`. This Op facilitates creating more involved end-to-end compilation pipelines for `tensor.pack` and `tensor.unpack` operations. It will be required in an upcoming PR building on top of #115698. No new tests are added in this PR. Instead, existing tests from: * "generalize-tensor-pack.mlir" are reused. To achieve this: * I've updated the test to use `transform.apply_patterns.linalg.generalize_pack_unpack` instead of the flag `--test-linalg-transform-patterns="test-generalize-tensor-pack"`, avoiding artificial tests solely for the TD Op. * The TD sequence is saved to a new file, "generalize_pack.mlir", and pre-loaded using the option: `--transform-preload-library='transform-library-paths=%p/td/generalize_pack.mlir'` This avoids duplicating the sequence for every "split" in the input file. * Added lit.local.cfg to exclude the "test/Dialect/Linalg/td" directory from test discovery, ensuring "generalize_pack.mlir" is not treated as a test file. --- .../Linalg/TransformOps/LinalgTransformOps.td | 12 ++++++++++++ .../mlir/Dialect/Linalg/Transforms/Transforms.h | 9 +++++++-- .../Linalg/TransformOps/LinalgTransformOps.cpp | 5 +++++ mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 5 +++++ mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir | 3 +-- mlir/test/Dialect/Linalg/lit.local.cfg | 2 ++ mlir/test/Dialect/Linalg/td/generalize-pack.mlir | 12 ++++++++++++ 7 files changed, 44 insertions(+), 4 deletions(-) create mode 100644 mlir/test/Dialect/Linalg/lit.local.cfg create mode 100644 mlir/test/Dialect/Linalg/td/generalize-pack.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index f256af2f6b12b..42057d8d0c910 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -41,6 +41,18 @@ def ApplyEraseUnnecessaryInputsPatternsOp : Op]> { + let description = [{ + Collect patterns to generalize tensor.pack and tensor.unpack (i.e. to + decompose it into e.g. tensor::PadOp, linalg::transposeOp etc). Requires + all outer dims to be unit. + }]; + + let assemblyFormat = "attr-dict"; +} + def ApplyFoldUnitExtentDimsViaReshapesPatternsOp : Op]> { diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 89e9a3b70d2ab..0b55a76f88433 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1516,8 +1516,8 @@ struct GeneralizePadOpPattern : public OpRewritePattern { }; /// Rewrites a tensor::PackOp into a sequence of: -/// * tensor::PadOp + linalg::TransposeOp + -/// tensor::EmptyOp + tensor::InsertSliceOp ops. +/// * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp + +/// tensor::InsertSliceOp ops. /// /// Required that all the outer dims of the input tensor::PackOp are 1. /// @@ -1683,6 +1683,11 @@ void populateLinalgGenericOpsSpecializationPatterns( void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); +/// Populates patterns to decompose tensor.pack and tensor.unpack Ops into e.g. +/// tensor.pad, linalg.transpose, tensor.{insert|extract}_slice. Require all +/// outer dims to be unit. +void populateGeneralizePatterns(RewritePatternSet &patterns); + /// Populates patterns to transform linalg.conv_2d_xxx operations into /// linalg.generic (for img2col packing) and linalg.matmul. /// \see rewriteInIm2Col for more details. diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 1956fc634ef39..a00c609779c3a 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -229,6 +229,11 @@ void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns( linalg::populateEraseUnnecessaryInputsPatterns(patterns); } +void transform::ApplyGeneralizeTensorPackUnpackPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + linalg::populateGeneralizePatterns(patterns); +} + void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns( RewritePatternSet &patterns) { linalg::ControlDropUnitDims options; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index ed9ebca4f306a..c9eac66367559 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1618,3 +1618,8 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, DownscaleSizeOneWindowed2DConvolution>( patterns.getContext(), benefit); } + +void linalg::populateGeneralizePatterns(RewritePatternSet &patterns) { + // TODO: Add and test patterns for tensor.unpack + patterns.add(patterns.getContext()); +} diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir index f4b1d9a55f091..ad20541e301d3 100644 --- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir +++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir @@ -1,5 +1,4 @@ -// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-tensor-pack" %s | FileCheck %s - +// RUN: mlir-opt --transform-preload-library='transform-library-paths=%p/td/generalize-pack.mlir' -split-input-file --transform-interpreter %s | FileCheck %s func.func @simple_KCRS_to_KCRSsr(%arg0: tensor, %arg1: tensor<1x1x?x1xi32>) -> tensor<1x1x?x1xi32> { %c8 = arith.constant 8 : index diff --git a/mlir/test/Dialect/Linalg/lit.local.cfg b/mlir/test/Dialect/Linalg/lit.local.cfg new file mode 100644 index 0000000000000..62743008a3e3a --- /dev/null +++ b/mlir/test/Dialect/Linalg/lit.local.cfg @@ -0,0 +1,2 @@ +# Skip the directory with input TD sequences +config.excludes = ["td"] diff --git a/mlir/test/Dialect/Linalg/td/generalize-pack.mlir b/mlir/test/Dialect/Linalg/td/generalize-pack.mlir new file mode 100644 index 0000000000000..62e5b779ff361 --- /dev/null +++ b/mlir/test/Dialect/Linalg/td/generalize-pack.mlir @@ -0,0 +1,12 @@ +module @transforms attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { + %pack = transform.structured.match ops{["tensor.pack"]} in %module : (!transform.any_op) -> !transform.any_op + + %1 = transform.get_parent_op %pack {isolated_from_above} : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %1 { + transform.apply_patterns.linalg.generalize_pack_unpack + } : !transform.any_op + + transform.yield + } +}