diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index e3084530bd11b..dc10f3a1c58ae 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -52,6 +52,17 @@ def ApplyDecomposeTensorPackUnpackPatternsOp let assemblyFormat = "attr-dict"; } +def ApplyDecomposeTensorPadPatternsOp + : Op]> { + let description = [{ + Collect patterns to decompose tensor.pad into e.g. tensor::EmptyOp, + linalg::FillOp and tensor::InsertSliceOp. + }]; + + let assemblyFormat = "attr-dict"; +} + def ApplyFoldUnitExtentDimsViaReshapesPatternsOp : Op]> { diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 51967f83fee37..3c160d55a38e7 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1503,8 +1503,8 @@ using OptimizeCopyFn = /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and /// InsertSliceOp. For now, only constant padding values are supported. -struct GeneralizePadOpPattern : public OpRewritePattern { - GeneralizePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1) +struct DecomposePadOpPattern : public OpRewritePattern { + DecomposePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit) {} LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const override; @@ -1688,6 +1688,10 @@ void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, /// outer dims to be unit. void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns); +/// Populates patterns to decompose tensor.pad into e.g. +/// tensor.empty, linalg.fill, tensor.insert_slice. +void populateDecomposePadPatterns(RewritePatternSet &patterns); + /// Populates patterns to transform linalg.conv_2d_xxx operations into /// linalg.generic (for img2col packing) and linalg.matmul. /// \see rewriteInIm2Col for more details. diff --git a/mlir/lib/Conversion/TensorToLinalg/TensorToLinalg.cpp b/mlir/lib/Conversion/TensorToLinalg/TensorToLinalg.cpp index 5bb79d4bc84e2..b0ca0ca13d062 100644 --- a/mlir/lib/Conversion/TensorToLinalg/TensorToLinalg.cpp +++ b/mlir/lib/Conversion/TensorToLinalg/TensorToLinalg.cpp @@ -25,5 +25,7 @@ using namespace mlir; //===----------------------------------------------------------------------===// void mlir::populateTensorToLinalgPatterns(RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); + // TODO: Add the remaining patterns, e.g. to decompose Pack/Unpack Ops. + // Alternatively, delete this file. + patterns.add(patterns.getContext()); } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index ada80deacfdbf..e08be7d2ebd6a 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -234,6 +234,11 @@ void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns( linalg::populateDecomposePackUnpackPatterns(patterns); } +void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + linalg::populateDecomposePadPatterns(patterns); +} + void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns( RewritePatternSet &patterns) { linalg::ControlDropUnitDims options; @@ -3491,8 +3496,12 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne( // Add misc. vectorization patterns (e.g. for tensor.insert_slice) linalg::populateInsertSliceVectorizationPatterns(patterns); - if (getVectorizePadding()) + if (getVectorizePadding()) { linalg::populatePadOpVectorizationPatterns(patterns); + // This creates an alternative path for lowering tensor.pad - by + // decomposing it into e.g. linalg.fill. + linalg::populateDecomposePadPatterns(patterns); + } vector::populateVectorStepLoweringPatterns(patterns); TrackingListener listener(state, *this); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index d92543d726462..c3e176299317e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -921,7 +921,7 @@ LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite( /// Filling `dest` using FillOp constant padding value if possible. /// Otherwise, generate a tensor::GenerateOp. -Value GeneralizePadOpPattern::createFillOrGenerateOp( +Value DecomposePadOpPattern::createFillOrGenerateOp( RewriterBase &rewriter, tensor::PadOp padOp, Value dest, const SmallVector &dynSizes) const { auto padValue = padOp.getConstantPaddingValue(); @@ -938,8 +938,8 @@ Value GeneralizePadOpPattern::createFillOrGenerateOp( } LogicalResult -GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp, - PatternRewriter &rewriter) const { +DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp, + PatternRewriter &rewriter) const { // Given an OpFoldResult, return an index-typed value. auto getIdxValue = [&](OpFoldResult ofr) { if (auto val = llvm::dyn_cast_if_present(ofr)) @@ -1623,3 +1623,7 @@ void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) { // TODO: Add and test patterns for tensor.unpack patterns.add(patterns.getContext()); } + +void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 23b46a2ee55f8..06bb6c0fb1cac 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -2770,12 +2770,6 @@ void mlir::linalg::populateInsertSliceVectorizationPatterns( void mlir::linalg::populatePadOpVectorizationPatterns( RewritePatternSet &patterns, PatternBenefit baseBenefit) { - // TODO: The following pattern implements "decomposition" and - // optional "vectorization". Seperate "decomposition" into a sepereate - // pre-processing pattern group. - patterns.add(patterns.getContext(), baseBenefit); - - // Try these specialized patterns first before resorting to the generic one. patterns.add( diff --git a/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir b/mlir/test/Dialect/Linalg/decompose-pad-tensor.mlir similarity index 98% rename from mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir rename to mlir/test/Dialect/Linalg/decompose-pad-tensor.mlir index 2beab31b613d5..184361dfb30df 100644 --- a/mlir/test/Dialect/Linalg/generalize-pad-tensor.mlir +++ b/mlir/test/Dialect/Linalg/decompose-pad-tensor.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-pad-tensor" %s | FileCheck %s +// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-decompose-pad-tensor" %s | FileCheck %s // CHECK-LABEL: func @generalize_pad_tensor_static_shape( // CHECK-SAME: %[[IN:.*]]: tensor<1x28x28x1xf32>) -> tensor<1x32x32x1xf32> { diff --git a/mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir index 640de85cc5f12..41e480648177f 100644 --- a/mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir +++ b/mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir @@ -202,6 +202,8 @@ module attributes {transform.with_named_sequence} { %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func"> transform.apply_patterns to %func_op { + // TODO: Split into two tests, one for each pattern + transform.apply_patterns.linalg.decompose_pad transform.apply_patterns.linalg.pad_vectorization } : !transform.op<"func.func"> transform.yield @@ -236,6 +238,8 @@ module attributes {transform.with_named_sequence} { %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func"> transform.apply_patterns to %func_op { + // TODO: Split into two tests, one for each pattern + transform.apply_patterns.linalg.decompose_pad transform.apply_patterns.linalg.pad_vectorization } : !transform.op<"func.func"> transform.yield @@ -270,6 +274,8 @@ module attributes {transform.with_named_sequence} { %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func"> transform.apply_patterns to %func_op { + // TODO: Split into two tests, one for each pattern + transform.apply_patterns.linalg.decompose_pad transform.apply_patterns.linalg.pad_vectorization } : !transform.op<"func.func"> transform.yield diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp index c65e68eaf31f0..25aec75c3c14a 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -70,8 +70,8 @@ struct TestLinalgTransforms llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " "in vector.contract form"), llvm::cl::init(false)}; - Option testGeneralizePadTensor{ - *this, "test-generalize-pad-tensor", + Option testDecomposePadTensor{ + *this, "test-decompose-pad-tensor", llvm::cl::desc("Test transform pad tensor by copying with generic ops"), llvm::cl::init(false)}; Option testDecomposeTensorPackOp{ @@ -166,9 +166,9 @@ static void applyLinalgToVectorPatterns(func::FuncOp funcOp) { (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } -static void applyGeneralizePadTensorPatterns(func::FuncOp funcOp) { +static void applyDecomposePadPatterns(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); - patterns.add(funcOp.getContext()); + patterns.add(funcOp.getContext()); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } @@ -235,8 +235,8 @@ void TestLinalgTransforms::runOnOperation() { return applyVectorTransferForwardingPatterns(getOperation()); if (testGenericToVectorPattern) return applyLinalgToVectorPatterns(getOperation()); - if (testGeneralizePadTensor) - return applyGeneralizePadTensorPatterns(getOperation()); + if (testDecomposePadTensor) + return applyDecomposePadPatterns(getOperation()); if (testDecomposeTensorPackOp) return applyDecomposeTensorPackPatterns(getOperation()); if (testDecomposeTensorUnPackOp)