diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h index 1822016fc88fe..a1eb22eba6987 100644 --- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h +++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h @@ -52,7 +52,8 @@ void populateTosaToLinalgConversionPatterns(const TypeConverter &converter, /// Populates conversion passes from TOSA dialect to Linalg named operations. void populateTosaToLinalgNamedConversionPatterns( - RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options); + const TypeConverter &converter, RewritePatternSet *patterns, + const TosaToLinalgNamedOptions &options); } // namespace tosa } // namespace mlir diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index d537aef579103..b7af37d293ac1 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -695,17 +695,18 @@ class FullyConnectedConverter } }; -class MaxPool2dConverter : public OpRewritePattern { +class MaxPool2dConverter : public OpConversionPattern { public: - using OpRewritePattern::OpRewritePattern; + using OpConversionPattern::OpConversionPattern; // Compute the dynamic output sizes of the maxpool operation. static SmallVector - computeDynamicOutputSizes(tosa::MaxPool2dOp op, PatternRewriter &rewriter) { + computeDynamicOutputSizes(tosa::MaxPool2dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) { TensorType resultTy = op.getType(); Location loc = op.getLoc(); - TypedValue input = op.getInput(); + Value input = adaptor.getInput(); ArrayRef kernel = op.getKernel(); ArrayRef pad = op.getPad(); ArrayRef stride = op.getStride(); @@ -744,16 +745,22 @@ class MaxPool2dConverter : public OpRewritePattern { return dynamicDims; } - LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { Location loc = op.getLoc(); - TypedValue input = op.getInput(); - ShapedType inputTy = input.getType(); + Value input = adaptor.getInput(); + ShapedType inputTy = cast(input.getType()); - ShapedType resultTy = op.getType(); + bool isUnsigned = op.getType().getElementType().isUnsignedInteger(); + ShapedType resultTy = + cast(getTypeConverter()->convertType(op.getType())); + if (!resultTy) + return rewriter.notifyMatchFailure(op, "failed to convert type"); Type resultETy = inputTy.getElementType(); - SmallVector dynamicDims = computeDynamicOutputSizes(op, rewriter); + SmallVector dynamicDims = + computeDynamicOutputSizes(op, adaptor, rewriter); // Determine what the initial value needs to be for the max pool op. TypedAttr initialAttr; @@ -762,7 +769,10 @@ class MaxPool2dConverter : public OpRewritePattern { resultETy, APFloat::getLargest( cast(resultETy).getFloatSemantics(), true)); - if (isa(resultETy)) + else if (isUnsigned) + initialAttr = rewriter.getIntegerAttr( + resultETy, APInt::getZero(resultETy.getIntOrFloatBitWidth())); + else if (isa(resultETy)) initialAttr = rewriter.getIntegerAttr( resultETy, APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth())); @@ -798,9 +808,15 @@ class MaxPool2dConverter : public OpRewritePattern { Value fakeWindowDims = rewriter.create(loc, kernel, resultETy); - rewriter.replaceOpWithNewOp( - op, ArrayRef{resultTy}, ValueRange{paddedInput, fakeWindowDims}, - filledEmptyTensor, strideAttr, dilationAttr); + if (isUnsigned) { + rewriter.replaceOpWithNewOp( + op, ArrayRef{resultTy}, ValueRange{paddedInput, fakeWindowDims}, + filledEmptyTensor, strideAttr, dilationAttr); + } else { + rewriter.replaceOpWithNewOp( + op, ArrayRef{resultTy}, ValueRange{paddedInput, fakeWindowDims}, + filledEmptyTensor, strideAttr, dilationAttr); + } return success(); } }; @@ -1070,7 +1086,8 @@ class TransposeConverter : public OpRewritePattern { } // namespace void mlir::tosa::populateTosaToLinalgNamedConversionPatterns( - RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options) { + const TypeConverter &converter, RewritePatternSet *patterns, + const TosaToLinalgNamedOptions &options) { if (options.preferConv2DKernelLayoutHWCF) { patterns->add>( @@ -1085,10 +1102,13 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns( ConvConverter, DepthwiseConvConverter, MatMulConverter, - MaxPool2dConverter, AvgPool2dConverter, FullyConnectedConverter, TransposeConverter >(patterns->getContext()); + + patterns->add< + MaxPool2dConverter + >(converter, patterns->getContext()); // clang-format on } diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp index 096969391e51b..7d943b3779fb0 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp @@ -47,6 +47,9 @@ struct TosaToLinalgNamed } void runOnOperation() override { + TypeConverter converter; + tosa::populateTosaTypeConversion(converter); + RewritePatternSet patterns(&getContext()); ConversionTarget target(getContext()); target.addLegalDialect) -> () { return } +// CHECK-LABEL: @max_pool_ui8 +func.func @max_pool_ui8(%arg0: tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> { + // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x6x34x62xui8> to tensor<1x6x34x62xi8> + // CHECK: arith.constant 0 + // CHECK: linalg.pooling_nhwc_max_unsigned + // CHECK-SAME: ins({{.*}} : tensor<1x6x34x62xi8>, tensor<3x3xi8>) + // CHECK-SAME: outs({{.*}} : tensor<1x4x32x62xi8>) + // CHECK-SAME: -> tensor<1x4x32x62xi8> + // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x4x32x62xi8> to tensor<1x4x32x62xui8> + %0 = tosa.max_pool2d %arg0 {pad = array, kernel = array, stride = array} : (tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> + return %0 : tensor<1x4x32x62xui8> +} + // CHECK-LABEL: @max_pool_i16 func.func @max_pool_i16(%arg0: tensor<1x6x34x62xi16>) -> () { // CHECK: arith.constant -32768