diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 741de84cc5840..543180e68190f 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -107,6 +107,7 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
     LogicalResult verifyOutputZeroPoint(int64_t zp);
   }];
 
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
@@ -153,6 +154,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
   }];
 
   let builders = [Tosa_ConvOpQuantInfoBuilder];
+
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
 
@@ -244,6 +247,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
   }];
 
   let builders = [Tosa_ConvOpQuantInfoBuilder];
+
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 09d2c5d35263c..c4ef7d0bb9ff5 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -39,6 +39,280 @@ using namespace mlir::tosa;
 // Operator Canonicalizers.
 //===----------------------------------------------------------------------===//
 
+//===----------------------------------------------------------------------===//
+// Tensor Data Engine Operators.
+//===----------------------------------------------------------------------===//
+
+// Check that the zero point of the tensor and padding operations are aligned.
+bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
+  // Check that padConst is a constant value and a scalar tensor
+  DenseElementsAttr padConstAttr;
+  if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
+      (padConstAttr.size() != 1)) {
+    return false;
+  }
+
+  // Check that a floating-point pad constant is zero
+  if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
+    float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
+    return padConstVal == 0.0f;
+  }
+
+  // Check that the zp and padConst align for the integer (quantized) case
+  if (auto padConstIntAttr =
+          mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
+    DenseIntElementsAttr zpAttr;
+    // Check that zp is a constant value and a scalar tensor
+    if (!matchPattern(zp, m_Constant(&zpAttr)) || (zpAttr.size() != 1)) {
+      return false;
+    }
+
+    // Check equality
+    int64_t zpVal = (*zpAttr.begin()).getSExtValue();
+    int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
+    return zpVal == padConstVal;
+  }
+
+  // Bail out on unsupported types
+  return false;
+}
+
+namespace {
+template <typename OpTy>
+struct PoolPadFoldAdaptor;
+
+template <>
+struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
+  using OpTy = tosa::AvgPool2dOp;
+  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
+    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
+    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
+        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
+      return false;
+    return true;
+  }
+  static bool checkPadConstCompliance(OpTy op, Value padConst) {
+    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
+  }
+  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                  Value padInput, ArrayRef<int64_t> newPad) {
+    rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
+        op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
+        op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
+        op.getAccType());
+  }
+};
+
+template <>
+struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
+  using OpTy = tosa::MaxPool2dOp;
+  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
+    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
+    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
+        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
+      return false;
+    return true;
+  }
+  static bool checkPadConstCompliance(OpTy, Value padConst) {
+    // Check that padConst is a constant value and a scalar tensor
+    DenseElementsAttr padConstAttr;
+    if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
+        padConstAttr.size() != 1) {
+      return false;
+    }
+
+    // The pad constant needs to be the lowest (minimum) value to be able to
+    // merge it into the max_pool
+    if (auto padConstFpAttr =
+            mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
+      const APFloat padConstVal = *padConstFpAttr.begin();
+      const APFloat lowestVal =
+          APFloat::getLargest(padConstVal.getSemantics(), true);
+      return padConstVal == lowestVal;
+    } else if (auto padConstIntAttr =
+                   mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
+      const APInt padConstVal = *padConstIntAttr.begin();
+      const unsigned int bitWidth = padConstVal.getBitWidth();
+      const APInt lowestVal =
+          padConstIntAttr.getElementType().isUnsignedInteger()
+              ? APInt::getZero(bitWidth)
+              : APInt::getSignedMinValue(bitWidth);
+      return padConstVal == lowestVal;
+    }
+
+    // Bail out on unsupported types
+    return false;
+  }
+  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                  Value padInput, ArrayRef<int64_t> newPad) {
+    rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
+        op, op.getType(), padInput, op.getKernel(), op.getStride(),
+        rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
+  }
+};
+
+template <typename OpTy>
+struct ConvPadFoldAdaptor {
+  static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
+    return true;
+  }
+  static bool checkPadConstCompliance(OpTy op, Value padConst) {
+    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
+  }
+  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                  Value padInput, ArrayRef<int64_t> newPad) {
+    rewriter.replaceOpWithNewOp<OpTy>(
+        op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
+        op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
+        op.getDilationAttr(), op.getAccType(), op.getLocalBound());
+  }
+};
+
+// Pattern attempts to fold a `tosa.pad` operator into a following tensor
+// operation like `tosa.conv2d` by merging the padding associated with the
+// pad operator directly into the implicit padding of the tensor operation.
+// This helps eliminate the explicit padding operator if it becomes unused.
+template <typename OpTy, typename AdaptorTy>
+struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy tensorOp,
+                                PatternRewriter &rewriter) const override {
+    // Check that the producer is a tosa::PadOp
+    auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
+    if (!padOp)
+      return rewriter.notifyMatchFailure(tensorOp,
+                                         "Producer must be a tosa::PadOp.");
+
+    // Validate that the tensor operation has sane padding
+    const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
+    if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Tensor operation padding shall have 4 elements.");
+
+    // Validate tosa::PadOp padding
+    DenseIntElementsAttr padOpPadding;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
+      return rewriter.notifyMatchFailure(
+          tensorOp,
+          "The `padding` input specified on the tosa::PadOp must be constant.");
+    }
+    // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
+    // C_after
+    if (padOpPadding.size() != 8)
+      return rewriter.notifyMatchFailure(tensorOp,
+                                         "Pad padding should have 8 elements.");
+    int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
+    int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
+    int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
+    int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
+    int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
+    int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
+    int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
+    int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
+
+    if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Folding padding in N or C dimensions is not supported.");
+
+    // Fold the padding from the Pad into the tensor operation
+    // 4 elements - pad_top, pad_bottom, pad_left, pad_right
+    SmallVector<int64_t> foldedPad(tensorOpPad.size());
+    foldedPad[0] = padHBefore + tensorOpPad[0];
+    foldedPad[1] = padHAfter + tensorOpPad[1];
+    foldedPad[2] = padWBefore + tensorOpPad[2];
+    foldedPad[3] = padWAfter + tensorOpPad[3];
+
+    // Check kernel related restrictions
+    if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Padding size not aligned with kernel restrictions.");
+    }
+
+    // Check padding constant restrictions
+    if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
+      return rewriter.notifyMatchFailure(
+          tensorOp,
+          "Padding constant is not aligned with operator zero-point.");
+    }
+
+    // Check that the folded padding does not exceed the 8K level limit (8192)
+    if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Padding size exceeds the 8K level limit.");
+    }
+
+    // Create the new operator
+    AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
+                                   foldedPad);
+
+    return success();
+  }
+};
+} // namespace
+
+void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
+                                PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
+      context);
+}
+
+void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<
+      FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
+      context);
+}
+
+void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                    MLIRContext *context) {
+  results.add<FoldPadToTensorOp<
+      tosa::DepthwiseConv2DOp, ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
+      context);
+}
+
+struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.getInput();
+    Value output = op.getOutput();
+    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
+    ShapedType outputType = llvm::cast<ShapedType>(output.getType());
+
+    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
+      return failure();
+    }
+
+    // If the output and input shapes are 1x1, then this is a no op.
+    ArrayRef<int64_t> outputShape = outputType.getShape();
+    if (outputShape[1] != 1 || outputShape[2] != 1) {
+      return failure();
+    }
+
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    if (inputShape[1] != 1 || inputShape[2] != 1) {
+      return failure();
+    }
+
+    rewriter.replaceOp(op, input);
+    return success();
+  }
+};
+
+void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<MaxPool2dIsNoOp,
+              FoldPadToTensorOp<tosa::MaxPool2dOp,
+                                PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
+      context);
+}
+
+//===----------------------------------------------------------------------===//
+// Data Layout / Memory Reinterpretation.
+//===----------------------------------------------------------------------===//
+
 struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
   using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
 
@@ -175,41 +449,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
 }
 
-struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.getInput();
-    Value output = op.getOutput();
-    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
-    ShapedType outputType = llvm::cast<ShapedType>(output.getType());
-
-    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
-      return failure();
-    }
-
-    // If the output and input shapes are 1x1, then this is a no op.
-    ArrayRef<int64_t> outputShape = outputType.getShape();
-    if (outputShape[1] != 1 || outputShape[2] != 1) {
-      return failure();
-    }
-
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    if (inputShape[1] != 1 || inputShape[2] != 1) {
-      return failure();
-    }
-
-    rewriter.replaceOp(op, input);
-    return success();
-  }
-};
-
-void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                              MLIRContext *context) {
-  results.add<MaxPool2dIsNoOp>(context);
-}
-
 struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
   using OpRewritePattern::OpRewritePattern;
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 077a6cee0a1bb..3a0985f6e1868 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -9,6 +9,158 @@ func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<1xi32> {
 
 // -----
 
+// CHECK-LABEL: @pad_wh_avg_pool2d_fold
+func.func @pad_wh_avg_pool2d_fold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> {
+  // CHECK-NOT: tosa.pad
+  // CHECK: tosa.avg_pool2d
+  // CHECK-SAME: pad = array<i64: 1, 1, 1, 1>
+  %pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 0, 1, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %output_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x10x8x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x11x9x3xf32>
+  %pool = tosa.avg_pool2d %padded, %input_zp, %output_zp {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 2, 2>} : (tensor<1x11x9x3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x6x5x3xf32>
+  return %pool : tensor<1x6x5x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_wh_avg_pool2d_nofold_pad_const
+func.func @pad_wh_avg_pool2d_nofold_pad_const(%input: tensor<1x10x8x3xi8>) -> tensor<1x6x5x3xi8> {
+  // CHECK: tosa.pad
+  // CHECK: tosa.avg_pool2d
+  // CHECK-SAME: pad = array<i64: 0, 1, 0, 1>
+  %pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 0, 1, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %pad_const = "tosa.const"() <{values = dense<15> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() <{values = dense<10> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() <{values = dense<20> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x10x8x3xi8>, !tosa.shape<8>, tensor<1xi8>) -> tensor<1x11x9x3xi8>
+  %pool = tosa.avg_pool2d %padded, %input_zp, %output_zp {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 2, 2>} : (tensor<1x11x9x3xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x6x5x3xi8>
+  return %pool : tensor<1x6x5x3xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_wh_avg_pool2d_nofold_pad_larger_than_kernel
+func.func @pad_wh_avg_pool2d_nofold_pad_larger_than_kernel(%input: tensor<1x10x8x3xf32>) -> tensor<1x7x5x3xf32> {
+  // CHECK: tosa.pad
+  // CHECK: tosa.avg_pool2d
+  %pad_shape = tosa.const_shape { values = dense<[0, 0, 3, 0, 1, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %output_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x10x8x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x13x9x3xf32>
+  %pool = tosa.avg_pool2d %padded, %input_zp, %output_zp {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 2, 2>} : (tensor<1x13x9x3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x5x3xf32>
+  return %pool : tensor<1x7x5x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_wh_conv2d_fold
+func.func @pad_wh_conv2d_fold(%input: tensor<1x8x4x3xf32>, %weight: tensor<1x3x3x3xf32>, %bias: tensor<1xf32>) -> tensor<1x10x8x1xf32> {
+  // CHECK-NOT: tosa.pad
+  // CHECK: tosa.conv2d
+  // CHECK-SAME: pad = array<i64: 2, 2, 3, 3>
+  %pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 1, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x8x4x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x10x8x3xf32>
+  %conv = tosa.conv2d %padded, %weight, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<1x10x8x3xf32>, tensor<1x3x3x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x8x1xf32>
+  return %conv : tensor<1x10x8x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_bwh_conv2d_nofold
+func.func @pad_bwh_conv2d_nofold(%input: tensor<1x8x4x3xf32>, %weight: tensor<1x3x3x3xf32>, %bias: tensor<1xf32>) -> tensor<3x10x8x1xf32> {
+  // CHECK: tosa.pad
+  // CHECK: tosa.conv2d
+  // CHECK-SAME: pad = array<i64: 1, 1, 1, 1>
+  %pad_shape = tosa.const_shape { values = dense<[1, 1, 1, 1, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x8x4x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<3x10x8x3xf32>
+  %conv = tosa.conv2d %padded, %weight, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<3x10x8x3xf32>, tensor<1x3x3x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<3x10x8x1xf32>
+  return %conv : tensor<3x10x8x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_wh_conv2d_nofold_pad_const
+func.func @pad_wh_conv2d_nofold_pad_const(%input: tensor<1x8x4x3xf32>, %weight: tensor<1x3x3x3xf32>, %bias: tensor<1xf32>) -> tensor<1x10x8x1xf32> {
+  // CHECK: tosa.pad
+  // CHECK: tosa.conv2d
+  // CHECK-SAME: pad = array<i64: 1, 1, 1, 1>
+  %pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 1, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %pad_const = "tosa.const"() <{values = dense<1.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x8x4x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x10x8x3xf32>
+  %conv = tosa.conv2d %padded, %weight, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<1x10x8x3xf32>, tensor<1x3x3x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x8x1xf32>
+  return %conv : tensor<1x10x8x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_wh_depthwise_conv2d_fold
+func.func @pad_wh_depthwise_conv2d_fold(%input: tensor<1x8x4x3xf32>, %weight: tensor<3x3x3x1xf32>, %bias: tensor<3xf32>) -> tensor<1x10x8x3xf32> {
+  // CHECK-NOT: tosa.pad
+  // CHECK: tosa.depthwise_conv2d
+  // CHECK-SAME: pad = array<i64: 2, 2, 3, 3>
+  %pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 1, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x8x4x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x10x8x3xf32>
+  %conv = tosa.depthwise_conv2d %padded, %weight, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<1x10x8x3xf32>, tensor<3x3x3x1xf32>, tensor<3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x8x3xf32>
+  return %conv : tensor<1x10x8x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_wh_max_pool2d_fold
+func.func @pad_wh_max_pool2d_fold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> {
+  // CHECK-NOT: tosa.pad
+  // CHECK: tosa.max_pool2d
+  // CHECK-SAME: pad = array<i64: 1, 1, 1, 1>
+  %pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 0, 1, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %pad_const = "tosa.const"() <{values = dense<-3.4028235e+38> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x10x8x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x11x9x3xf32>
+  %pool = tosa.max_pool2d %padded {kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 2, 2>} : (tensor<1x11x9x3xf32>) -> tensor<1x6x5x3xf32>
+  return %pool : tensor<1x6x5x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_wh_max_pool2d_nofold_pad_const
+func.func @pad_wh_max_pool2d_nofold_pad_const(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> {
+  // CHECK: tosa.pad
+  // CHECK: tosa.max_pool2d
+  // CHECK-SAME: pad = array<i64: 0, 1, 0, 1>
+  %pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 0, 1, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x10x8x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x11x9x3xf32>
+  %pool = tosa.max_pool2d %padded {kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 2, 2>} : (tensor<1x11x9x3xf32>) -> tensor<1x6x5x3xf32>
+  return %pool : tensor<1x6x5x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @pad_wh_max_pool2d_no_fold_8k_limit
+func.func @pad_wh_max_pool2d_no_fold_8k_limit(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x4101x3xf32> {
+  // CHECK: tosa.pad
+  // CHECK: tosa.max_pool2d
+  %pad_shape = tosa.const_shape { values = dense<[0, 0, 1, 0, 8193, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %pad_const = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x10x8x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x11x8201x3xf32>
+  %pool = tosa.max_pool2d %padded {kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 2, 2>} : (tensor<1x11x8201x3xf32>) -> tensor<1x6x4101x3xf32>
+  return %pool : tensor<1x6x4101x3xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @add_bcast_zero_int
 func.func @add_bcast_zero_int(%arg0: tensor<4x2x3xi32>) -> tensor<4x2x3xi32> {
   // CHECK-NOT: tosa.add
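
For illustration only (not part of the patch): a minimal before/after sketch of the new fold, derived from the @pad_wh_conv2d_fold test above. The explicit tosa.pad feeding the convolution is removed and its H/W padding is merged into the conv2d's implicit pad attribute; the shapes and attribute values are assumed to match that test.

  // Before canonicalization: explicit tosa.pad producer.
  %padded = tosa.pad %input, %pad_shape, %pad_const : (tensor<1x8x4x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x10x8x3xf32>
  %conv = tosa.conv2d %padded, %weight, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<1x10x8x3xf32>, tensor<1x3x3x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x8x1xf32>

  // After canonicalization: the pad is folded into the conv2d's implicit padding.
  %conv = tosa.conv2d %input, %weight, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 2, 2, 3, 3>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<1x8x4x3xf32>, tensor<1x3x3x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x8x1xf32>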