From 96eb4d3991d066dce0600549b7b463108a7edb47 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski
Date: Fri, 22 Nov 2024 14:00:32 +0000
Subject: [PATCH] [mlir][linalg] Extract `GeneralizePadOpPattern` into a
 standalone transformation

Currently, `GeneralizePadOpPattern` is grouped under
`populatePadOpVectorizationPatterns`. However, as noted in #111349, this
transformation "decomposes" rather than "vectorizes" `tensor.pad`. As such,
it functions as:
  * a vectorization _pre-processing_ transformation, not
  * a vectorization transformation itself.

To clarify its purpose, this PR turns `GeneralizePadOpPattern` into a
standalone transformation by:
  * introducing a dedicated `populateDecomposePadPatterns` method,
  * adding an `apply_patterns.linalg.decompose_pad` Transform Dialect Op, and
  * removing it from `populatePadOpVectorizationPatterns`.

In addition, to better reflect its role, the pattern is renamed from
"generalization" to "decomposition". This is in line with the recent
renaming of the corresponding patterns for tensor.pack/tensor.unpack Ops
in #116439.
---
 .../Linalg/TransformOps/LinalgTransformOps.td        | 11 +++++++++++
 .../mlir/Dialect/Linalg/Transforms/Transforms.h      |  8 ++++++--
 .../lib/Conversion/TensorToLinalg/TensorToLinalg.cpp |  4 +++-
 .../Linalg/TransformOps/LinalgTransformOps.cpp       | 11 ++++++++++-
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp    | 10 +++++++---
 mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp |  6 ------
 ...ize-pad-tensor.mlir => decompose-pad-tensor.mlir} |  2 +-
 .../Dialect/Linalg/vectorization-pad-patterns.mlir   |  6 ++++++
 .../test/lib/Dialect/Linalg/TestLinalgTransforms.cpp | 12 ++++++------
 9 files changed, 50 insertions(+), 20 deletions(-)
 rename mlir/test/Dialect/Linalg/{generalize-pad-tensor.mlir => decompose-pad-tensor.mlir} (98%)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index e3084530bd11b..dc10f3a1c58ae 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -52,6 +52,17 @@ def ApplyDecomposeTensorPackUnpackPatternsOp
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyDecomposeTensorPadPatternsOp
+    : Op<Transform_Dialect, "apply_patterns.linalg.decompose_pad",
+         [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collect patterns to decompose tensor.pad into e.g. tensor::EmptyOp,
+    linalg::FillOp and tensor::InsertSliceOp.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyFoldUnitExtentDimsViaReshapesPatternsOp
     : Op<Transform_Dialect, "apply_patterns.linalg.fold_unit_extent_dims_via_reshapes",
          [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 51967f83fee37..3c160d55a38e7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1503,8 +1503,8 @@ using OptimizeCopyFn =
 
 /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
 /// InsertSliceOp. For now, only constant padding values are supported.
-struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
-  GeneralizePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
+struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> {
+  DecomposePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
       : OpRewritePattern<tensor::PadOp>(context, benefit) {}
   LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                 PatternRewriter &rewriter) const override;
@@ -1688,6 +1688,10 @@ void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
 /// outer dims to be unit.
 void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns);
 
+/// Populates patterns to decompose tensor.pad into e.g.
+/// tensor.empty, linalg.fill, tensor.insert_slice.
+void populateDecomposePadPatterns(RewritePatternSet &patterns);
+
 /// Populates patterns to transform linalg.conv_2d_xxx operations into
 /// linalg.generic (for img2col packing) and linalg.matmul.
 /// \see rewriteInIm2Col for more details.
diff --git a/mlir/lib/Conversion/TensorToLinalg/TensorToLinalg.cpp b/mlir/lib/Conversion/TensorToLinalg/TensorToLinalg.cpp
index 5bb79d4bc84e2..b0ca0ca13d062 100644
--- a/mlir/lib/Conversion/TensorToLinalg/TensorToLinalg.cpp
+++ b/mlir/lib/Conversion/TensorToLinalg/TensorToLinalg.cpp
@@ -25,5 +25,7 @@ using namespace mlir;
 //===----------------------------------------------------------------------===//
 
 void mlir::populateTensorToLinalgPatterns(RewritePatternSet &patterns) {
-  patterns.add<mlir::linalg::GeneralizePadOpPattern>(patterns.getContext());
+  // TODO: Add the remaining patterns, e.g. to decompose Pack/Unpack Ops.
+  // Alternatively, delete this file.
+  patterns.add<mlir::linalg::DecomposePadOpPattern>(patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ada80deacfdbf..e08be7d2ebd6a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -234,6 +234,11 @@ void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
   linalg::populateDecomposePackUnpackPatterns(patterns);
 }
 
+void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  linalg::populateDecomposePadPatterns(patterns);
+}
+
 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   linalg::ControlDropUnitDims options;
@@ -3491,8 +3496,12 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
   // Add misc. vectorization patterns (e.g. for tensor.insert_slice)
   linalg::populateInsertSliceVectorizationPatterns(patterns);
 
-  if (getVectorizePadding())
+  if (getVectorizePadding()) {
     linalg::populatePadOpVectorizationPatterns(patterns);
+    // This creates an alternative path for lowering tensor.pad - by
+    // decomposing it into e.g. linalg.fill.
+    linalg::populateDecomposePadPatterns(patterns);
+  }
   vector::populateVectorStepLoweringPatterns(patterns);
 
   TrackingListener listener(state, *this);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index d92543d726462..c3e176299317e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -921,7 +921,7 @@ LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
 
 /// Filling `dest` using FillOp constant padding value if possible.
 /// Otherwise, generate a tensor::GenerateOp.
-Value GeneralizePadOpPattern::createFillOrGenerateOp(
+Value DecomposePadOpPattern::createFillOrGenerateOp(
     RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
     const SmallVector<Value> &dynSizes) const {
   auto padValue = padOp.getConstantPaddingValue();
@@ -938,8 +938,8 @@ Value GeneralizePadOpPattern::createFillOrGenerateOp(
 }
 
 LogicalResult
-GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
-                                        PatternRewriter &rewriter) const {
+DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
+                                       PatternRewriter &rewriter) const {
   // Given an OpFoldResult, return an index-typed value.
   auto getIdxValue = [&](OpFoldResult ofr) {
     if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
@@ -1623,3 +1623,7 @@ void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
   // TODO: Add and test patterns for tensor.unpack
   patterns.add<DecomposeOuterUnitDimsPackOpPattern>(patterns.getContext());
 }
+
+void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) {
+  patterns.add<DecomposePadOpPattern>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 23b46a2ee55f8..06bb6c0fb1cac 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2770,12 +2770,6 @@ void mlir::linalg::populateInsertSliceVectorizationPatterns(
 
 void mlir::linalg::populatePadOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
-  // TODO: The following pattern implements "decomposition" and
-  // optional "vectorization". Seperate "decomposition" into a sepereate
-  // pre-processing pattern group.
-  patterns.add<GeneralizePadOpPattern>(patterns.getContext(), baseBenefit);
-
-  // Try these specialized patterns first before resorting to the generic one.
   patterns.add<PadOpVectorizationWithTransferReadPattern,
                PadOpVectorizationWithTransferWritePattern,
                PadOpVectorizationWithInsertSlicePattern>(
diff --git a/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir b/mlir/test/Dialect/Linalg/decompose-pad-tensor.mlir
similarity index 98%
rename from mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir
rename to mlir/test/Dialect/Linalg/decompose-pad-tensor.mlir
index 2beab31b613d5..184361dfb30df 100644
--- a/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-pad-tensor.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-pad-tensor" %s | FileCheck %s
+// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-decompose-pad-tensor" %s | FileCheck %s
 
 // CHECK-LABEL: func @generalize_pad_tensor_static_shape(
 // CHECK-SAME:  %[[IN:.*]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> {
diff --git a/mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir
index 640de85cc5f12..41e480648177f 100644
--- a/mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir
@@ -202,6 +202,8 @@ module attributes {transform.with_named_sequence} {
     %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
 
     transform.apply_patterns to %func_op {
+      // TODO: Split into two tests, one for each pattern
+      transform.apply_patterns.linalg.decompose_pad
       transform.apply_patterns.linalg.pad_vectorization
     } : !transform.op<"func.func">
     transform.yield
@@ -236,6 +238,8 @@ module attributes {transform.with_named_sequence} {
     %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
 
     transform.apply_patterns to %func_op {
+      // TODO: Split into two tests, one for each pattern
+      transform.apply_patterns.linalg.decompose_pad
       transform.apply_patterns.linalg.pad_vectorization
     } : !transform.op<"func.func">
     transform.yield
@@ -270,6 +274,8 @@ module attributes {transform.with_named_sequence} {
     %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
 
     transform.apply_patterns to %func_op {
+      // TODO: Split into two tests, one for each pattern
+      transform.apply_patterns.linalg.decompose_pad
       transform.apply_patterns.linalg.pad_vectorization
     } : !transform.op<"func.func">
     transform.yield
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index c65e68eaf31f0..25aec75c3c14a 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -70,8 +70,8 @@ struct TestLinalgTransforms
       llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
                      "in vector.contract form"),
       llvm::cl::init(false)};
-  Option<bool> testGeneralizePadTensor{
-      *this, "test-generalize-pad-tensor",
+  Option<bool> testDecomposePadTensor{
+      *this, "test-decompose-pad-tensor",
       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
       llvm::cl::init(false)};
   Option<bool> testDecomposeTensorPackOp{
@@ -166,9 +166,9 @@ static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
-static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) {
+static void applyDecomposePadPatterns(func::FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
-  patterns.add<GeneralizePadOpPattern>(funcOp.getContext());
+  patterns.add<DecomposePadOpPattern>(funcOp.getContext());
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
@@ -235,8 +235,8 @@ void TestLinalgTransforms::runOnOperation() {
     return applyVectorTransferForwardingPatterns(getOperation());
   if (testGenericToVectorPattern)
     return applyLinalgToVectorPatterns(getOperation());
-  if (testGeneralizePadTensor)
-    return applyGeneralizePadTensorPatterns(getOperation());
+  if (testDecomposePadTensor)
+    return applyDecomposePadPatterns(getOperation());
   if (testDecomposeTensorPackOp)
     return applyDecomposeTensorPackPatterns(getOperation());
  if (testDecomposeTensorUnPackOp)
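For reference, a minimal sketch of driving the new `apply_patterns.linalg.decompose_pad`
Transform Dialect Op on its own, mirroring the transform sequences in the updated
vectorization-pad-patterns.mlir test above. The payload IR is assumed to contain
tensor.pad Ops nested under func.func; no vectorization happens on this path, the
patterns only rewrite tensor.pad into tensor.empty, linalg.fill and tensor.insert_slice.

    module attributes {transform.with_named_sequence} {
      transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
        %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
        transform.apply_patterns to %func_op {
          // Decomposes tensor.pad into tensor.empty + linalg.fill + tensor.insert_slice.
          transform.apply_patterns.linalg.decompose_pad
        } : !transform.op<"func.func">
        transform.yield
      }
    }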