diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 48978eb7663d5..d75bba6452dad 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -110,6 +110,17 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to);
 std::optional<SmallVector<ReassociationIndices>>
 getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
 
+//===----------------------------------------------------------------------===//
+// Convolution matcher utility
+//===----------------------------------------------------------------------===//
+
+/// Given a linalg `op`, returns true if it is a convolution op of type
+/// `ConvOpTy` and populates `dilations` and `strides` with values inferred
+/// from its indexing maps.
+template <typename ConvOpTy>
+bool isaConvolutionOpOfType(LinalgOp op, SmallVector<int64_t> *dilations,
+                            SmallVector<int64_t> *strides);
+
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 40fc0d68e358f..56a8ee9b96db9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,6 +237,51 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   return replaceWithMatmulVariant(rewriter, genericOp);
 }
 
+/// Utility to specialize a `genericOp` into a convolution op of type
+/// `ConvOpTy` with the given `dilations` and `strides`.
+template <typename ConvOpTy>
+static FailureOr<LinalgOp>
+specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
+                   ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
+  SmallVector<Value> inputs = genericOp.getDpsInputs();
+  ValueRange outputs = genericOp.getDpsInits();
+  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+  SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
+                                      ? TypeRange(ValueRange(outputs))
+                                      : TypeRange{};
+  Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
+  Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
+  LinalgOp namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
+      genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+  return namedOp;
+}
+
+// Converts linalg.generic to named linalg.*conv/pooling* ops where possible.
+static FailureOr<LinalgOp>
+specializeLinalgConvolutions(RewriterBase &rewriter, GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+#define CONV_OP_SPECIALIZER(ConvOpTy)                                          \
+  if (isaConvolutionOpOfType<ConvOpTy>(genericOp, &dilations, &strides))       \
+    return specializeToConvOp<ConvOpTy>(rewriter, genericOp, dilations,        \
+                                        strides);
+  // -----------------------------
+  // Depthwise Convolution ops.
+  // -----------------------------
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
+  // -----------------------------
+  // Pooling ops.
+  // -----------------------------
+  CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNhwcSumOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxUnsignedOp);
+  CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinUnsignedOp);
+#undef CONV_OP_SPECIALIZER
+  return failure();
+}
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -316,6 +361,11 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
   if (isaContractionOpInterface(genericOp)) {
     return specializeLinalgContractions(rewriter, genericOp);
   }
+
+  // Convolution - e.g. *conv/pooling*
+  if (isaConvolutionOpInterface(genericOp)) {
+    return specializeLinalgConvolutions(rewriter, genericOp);
+  }
   return failure();
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 24d3722cf5426..57e97be9eeea0 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -240,6 +240,642 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
   return iteratorType == utils::IteratorType::reduction;
 }
 
+//===----------------------------------------------------------------------===//
+// Convolution matcher utilities
+//===----------------------------------------------------------------------===//
+
+/// Returns the BlockArgument that leads to `val`, if any. Traverses optional
+/// ext* ops.
+static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) {
+  BlockArgument blockArg;
+  if (!(blockArg = dyn_cast<BlockArgument>(val))) {
+    Operation *defOp = val.getDefiningOp();
+    if (!dyn_cast_if_present<arith::ExtFOp>(defOp) &&
+        !dyn_cast_if_present<arith::ExtSIOp>(defOp) &&
+        !dyn_cast_if_present<arith::ExtUIOp>(defOp)) {
+      return nullptr;
+    }
+    blockArg = dyn_cast<BlockArgument>(defOp->getOperand(0));
+  }
+  return blockArg;
+}
+
+/// Utility to match the block body of a convolution op.
+/// The body is expected to yield:
+///   %out + (%lhs * %rhs)
+/// where %lhs, %rhs and %out are block arguments, and %lhs and %rhs may be
+/// wrapped in an optional upcast operation.
+static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) {
+  Operation *addOp = yieldVal.getDefiningOp();
+  if (!isa_and_present<arith::AddFOp, arith::AddIOp>(addOp))
+    return false;
+
+  Operation *mulOp = addOp->getOperand(1).getDefiningOp();
+  if (!isa_and_present<arith::MulFOp, arith::MulIOp>(mulOp))
+    return false;
+
+  BlockArgument lhsBlockArg =
+      getBlockArgumentWithOptionalExtOps(mulOp->getOperand(0));
+  BlockArgument rhsBlockArg =
+      getBlockArgumentWithOptionalExtOps(mulOp->getOperand(1));
+  BlockArgument outBlockArg =
+      getBlockArgumentWithOptionalExtOps(addOp->getOperand(0));
+  if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
+      lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body ||
+      outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 ||
+      rhsBlockArg.getArgNumber() != 1 || outBlockArg.getArgNumber() != 2)
+    return false;
+  return true;
+}
+
+/// Utility to match the block body of a linalg.pool* op.
+template <typename... OpTypes>
+static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
+  Operation *defOp = yieldVal.getDefiningOp();
+  if (!(isa_and_present<OpTypes>(defOp) || ...))
+    return false;
+
+  BlockArgument lhsArg =
+      getBlockArgumentWithOptionalExtOps(defOp->getOperand(0));
+  BlockArgument rhsArg =
+      getBlockArgumentWithOptionalExtOps(defOp->getOperand(1));
+  if (!lhsArg || !rhsArg || lhsArg.getOwner() != body ||
+      rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 ||
+      rhsArg.getArgNumber() != 0)
+    return false;
+  return true;
+}
+
+static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal,
+                                                                  body);
+}
+
+// max_unsigned ops should not allow float data types.
+// TODO: Retire OPDSL logic. Refer to:
+// https://github.com/llvm/llvm-project/issues/164800
+static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal,
+                                                                  body);
+}
+
+static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal,
+                                                                  body);
+}
+
+// min_unsigned ops should not allow float data types.
+// TODO: Retire OPDSL logic. Refer to:
+// https://github.com/llvm/llvm-project/issues/164800
+static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal,
+                                                                  body);
+}
+
+static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
+  return bodyMatcherForPoolOps<arith::AddFOp, arith::AddIOp>(yieldVal, body);
+}
+
+static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
+                                        uint32_t mapIndex, uint32_t dimIndex) {
+  auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
+  if (dimIndex < affineMap.getNumResults())
+    return affineMap.getResult(dimIndex);
+  return nullptr;
+}
+
+/// Check if `expr` is either:
+///   - a dimension expr alone (implying multiplication by 1), or
+///   - a multiplication of a dimension expr by any positive constant != 1.
+/// In both cases the dimension expression is captured into `dim` and the
+/// constant multiplier is returned. Returns -1 in case of a match failure.
+static int64_t isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim) {
+  if ((dim = dyn_cast<AffineDimExpr>(expr)))
+    return 1;
+
+  auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
+  if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
+    return -1;
+
+  AffineExpr lhs = mulExpr.getLHS();
+  AffineExpr rhs = mulExpr.getRHS();
+
+  AffineConstantExpr cst = nullptr;
+  if (((dim = dyn_cast<AffineDimExpr>(lhs)) &&
+       (cst = dyn_cast<AffineConstantExpr>(rhs))) ||
+      ((dim = dyn_cast<AffineDimExpr>(rhs)) &&
+       (cst = dyn_cast<AffineConstantExpr>(lhs))))
+    return cst.getValue();
+  return -1;
+}
+
+/// Given an array of AffineMaps `indexingMaps`, verify the following
+/// (commutatively):
+///   indexingMaps[0].getResult(iDim) ==
+///       indexingMaps[1].getResult(fDim) * c0 +
+///       indexingMaps[n-1].getResult(oDim) * c1
+/// where,
+/// - c0 and c1 can be any constant,
+/// - n is the size of the indexingMaps' array,
+/// - 0, 1 and n-1 are the input, filter and output map indices respectively,
+/// - iDim, fDim and oDim are the input, filter and output dimension
+///   indices in their respective indexing maps.
+/// Example:
+///   #inputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6)
+///                          -> (d0, d1 * 2 + d4 * 3, d2 + d5, d6)>
+///   #filterMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+///   #outputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+///
+/// Here,
+///   #inputMap[1] = #outputMap[1] * 2 + #filterMap[0] * 3
+/// Therefore,
+///   matchConvDimAddExprPattern(indexingMaps, 1, 0, 1, dilation, stride)
+/// would return true and update dilation = 3 and stride = 2.
+static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
+                                       unsigned fDim, unsigned oDim,
+                                       int64_t &dilation, int64_t &stride) {
+  unsigned inputMapIdx = 0, filterMapIdx = 1,
+           outputMapIdx = indexingMaps.size() - 1;
+  AffineExpr inpExpr = getAffineMapDim(indexingMaps, inputMapIdx, iDim);
+  auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
+  if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
+    return false;
+
+  AffineExpr dim0, dim1;
+  int64_t c0 = isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0);
+  int64_t c1 = isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1);
+
+  if (c0 != -1 && c1 != -1) {
+    // Pattern matched with dims and constants extracted.
+    AffineExpr fExpr = getAffineMapDim(indexingMaps, filterMapIdx, fDim);
+    AffineExpr oExpr = getAffineMapDim(indexingMaps, outputMapIdx, oDim);
+    if (dim0 == fExpr && dim1 == oExpr) {
+      dilation = c0;
+      stride = c1;
+      return true;
+    }
+    if (dim1 == fExpr && dim0 == oExpr) {
+      dilation = c1;
+      stride = c0;
+      return true;
+    }
+  }
+  return false;
+}
+
+/// Given an array of AffineMaps, verify that each map has the corresponding
+/// expected number of results in `expectedSizes`.
+static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps,
+                                       ArrayRef<unsigned> expectedSizes) {
+  if (indexingMaps.size() != expectedSizes.size())
+    return false;
+
+  for (auto [indexingMap, expectedSize] :
+       llvm::zip_equal(indexingMaps, expectedSizes)) {
+    auto affineMap = cast<AffineMapAttr>(indexingMap).getValue();
+    if (affineMap.getNumResults() != expectedSize)
+      return false;
+  }
+  return true;
+}
+
+// ---------------------------------------------
+// Matchers for specific convolution operations.
+// ---------------------------------------------
+
+/// Returns true if the given indexing maps match the expected indexing maps.
+static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
+                              ArrayAttr indexingMaps, MLIRContext *context) {
+  SmallVector<AffineMap, 4> expectedIndexingMaps =
+      AffineMap::inferFromExprList(mapListExpected, context);
+  return indexingMaps ==
+         ArrayAttr::get(
+             context, llvm::to_vector<4>(llvm::map_range(
+                          expectedIndexingMaps, [&](AffineMap m) -> Attribute {
+                            return AffineMapAttr::get(m);
+                          })));
+}
+
+// #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)>
+// #filterMap = affine_map<(N, W, C, w) -> (w, C)>
+// #outputMap = affine_map<(N, W, C, w) -> (N, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3}))
+    return false;
+
+  *dilations = SmallVector<int64_t>(1, 1);
+  *strides = SmallVector<int64_t>(1, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr W = getAffineDimExpr(1, context);
+  AffineExpr C = getAffineDimExpr(2, context);
+  AffineExpr w = getAffineDimExpr(3, context);
+  // First fetch dilations/strides:
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
+           /*filterMap=*/{w, C},
+           /*outputMap=*/{N, W, C}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  if (!bodyMatcherForConvolutionOps(yieldVal, body))
+    return false;
+  return true;
+}
+
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (C, h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4}))
+    return false;
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr C = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  // First fetch dilations/strides:
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+                                  /*oDim=*/3, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1]},
+           /*filterMap=*/{C, h, w},
+           /*outputMap=*/{N, C, H, W}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  if (!bodyMatcherForConvolutionOps(yieldVal, body))
+    return false;
+  return true;
+}
+
+// #inputMap = affine_map<(N, D, H, W, CM, d, h, w, C)
+//              -> (N, D + d, H + h, W + w, C)>
+// #filterMap = affine_map<(N, D, H, W, CM, d, h, w, C)
+//              -> (d, h, w, C, CM)>
+// #outputMap = affine_map<(N, D, H, W, CM, d, h, w, C)
+//              -> (N, D, H, W, C, CM)>
+template <>
+bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6}))
+    return false;
+
+  *dilations = SmallVector<int64_t>(3, 1);
+  *strides = SmallVector<int64_t>(3, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr D = getAffineDimExpr(1, context);
+  AffineExpr H = getAffineDimExpr(2, context);
+  AffineExpr W = getAffineDimExpr(3, context);
+  AffineExpr CM = getAffineDimExpr(4, context);
+  AffineExpr d = getAffineDimExpr(5, context);
+  AffineExpr h = getAffineDimExpr(6, context);
+  AffineExpr w = getAffineDimExpr(7, context);
+  AffineExpr C = getAffineDimExpr(8, context);
+  // First fetch dilations/strides:
+  // Match: D * stride + d * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+                                  /*oDim=*/3, (*dilations)[2], (*strides)[2]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, D * (*strides)[0] + d * (*dilations)[0],
+                         H * (*strides)[1] + h * (*dilations)[1],
+                         W * (*strides)[2] + w * (*dilations)[2], C},
+           /*filterMap=*/{d, h, w, C, CM},
+           /*outputMap=*/{N, D, H, W, C, CM}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  if (!bodyMatcherForConvolutionOps(yieldVal, body))
+    return false;
+  return true;
+}
+
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNhwcMaxOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+    return false;
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr C = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  // First fetch dilations/strides:
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], C},
+           /*filterMap=*/{h, w},
+           /*outputMap=*/{N, H, W, C}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  if (!bodyMatcherForMaxSignedPoolOps(yieldVal, body))
+    return false;
+  return true;
+}
+
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNhwcMinOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+    return false;
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr C = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  // First fetch dilations/strides:
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], C},
+           /*filterMap=*/{h, w},
+           /*outputMap=*/{N, H, W, C}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  if (!bodyMatcherForMinSignedPoolOps(yieldVal, body))
+    return false;
+  return true;
+}
+
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNhwcSumOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+    return false;
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr C = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  // First fetch dilations/strides:
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], C},
+           /*filterMap=*/{h, w},
+           /*outputMap=*/{N, H, W, C}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  if (!bodyMatcherForSumPoolOps(yieldVal, body))
+    return false;
+  return true;
+}
+
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+    return false;
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr C = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  // First fetch dilations/strides:
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], C},
+           /*filterMap=*/{h, w},
+           /*outputMap=*/{N, H, W, C}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  if (!bodyMatcherForMaxUnsignedPoolOps(yieldVal, body))
+    return false;
+  return true;
+}
+
+// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)>
+// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)>
+// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)>
+template <>
+bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides) {
+  if (isa<linalg::PoolingNhwcMinUnsignedOp>(op))
+    return true;
+
+  assert(isaConvolutionOpInterface(op) &&
+         "expected op to implement ConvolutionOpInterface");
+
+  ArrayAttr indexingMaps = op.getIndexingMaps();
+  if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
+    return false;
+
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr C = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  // First fetch dilations/strides:
+  // Match: H * stride + h * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
+    return false;
+  // Match: W * stride + w * dilation
+  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
+    return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], C},
+           /*filterMap=*/{h, w},
+           /*outputMap=*/{N, H, W, C}},
+          indexingMaps, context))
+    return false;
+  // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
+  if (!bodyMatcherForMinUnsignedPoolOps(yieldVal, body))
+    return false;
+  return true;
+}
+
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                             Value source, Value pad, bool nofold,
                             ValueRange typeDynDims) {
diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
new file mode 100644
index 0000000000000..3620256b0956e
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
@@ -0,0 +1,118 @@
+// The following tests lower examples of linalg convolution named ops to
+// linalg.generic and then lift them back up to the named op.
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s --implicit-check-not=linalg.generic
+
+func.func @depthwise_conv_1d_nwc_wc(%input: tensor<1x25x8xi8>, %filter: tensor<3x8xi8>, %output: tensor<1x10x8xi32>) -> tensor<1x10x8xi32> {
+  %0 = linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
+    ins (%input, %filter: tensor<1x25x8xi8>, tensor<3x8xi8>)
+    outs (%output: tensor<1x10x8xi32>) -> tensor<1x10x8xi32>
+  return %0 : tensor<1x10x8xi32>
+}
+// CHECK: @depthwise_conv_1d_nwc_wc
+// CHECK: linalg.depthwise_conv_1d_nwc_wc
+// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
+
+// -----
+
+func.func @depthwise_conv_2d_nchw_chw(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_2d_nchw_chw
+    {dilations = dense<[2,3]> : vector<2xi64>, strides = dense<[4,5]> : vector<2xi64>}
+    ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?xf32>)
+    outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @depthwise_conv_2d_nchw_chw
+// CHECK: linalg.depthwise_conv_2d_nchw_chw
+// CHECK-SAME: dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[4, 5]> : tensor<2xi64>
+
+// -----
+
+func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %output: tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32> {
+  %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm
+    {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+    ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+    outs (%output: tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?x?xf32>
+}
+// CHECK: @depthwise_conv_3d_ndhwc_dhwcm
+// CHECK: linalg.depthwise_conv_3d_ndhwc_dhwcm
+// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
+
+// -----
+
+func.func @pooling_nhwc_max(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nhwc_max
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+    outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @pooling_nhwc_max
+// CHECK: linalg.pooling_nhwc_max
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @pooling_nhwc_min(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nhwc_min
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+    outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @pooling_nhwc_min
+// CHECK: linalg.pooling_nhwc_min
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @pooling_nhwc_sum(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nhwc_sum
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+    outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @pooling_nhwc_sum
+// CHECK: linalg.pooling_nhwc_sum
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @pooling_nhwc_max_unsigned(%input: tensor<?x?x?x?xi32>, %filter: tensor<?x?xi32>, %output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> {
+  %0 = linalg.pooling_nhwc_max_unsigned
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x?x?x?xi32>, tensor<?x?xi32>)
+    outs (%output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+  return %0 : tensor<?x?x?x?xi32>
+}
+// CHECK: @pooling_nhwc_max_unsigned
+// CHECK: linalg.pooling_nhwc_max_unsigned
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @pooling_nhwc_min_unsigned_integer(%input: tensor<?x?x?x?xi32>, %filter: tensor<?x?xi32>, %output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> {
+  %0 = linalg.pooling_nhwc_min_unsigned
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x?x?x?xi32>, tensor<?x?xi32>)
+    outs (%output: tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+  return %0 : tensor<?x?x?x?xi32>
+}
+// CHECK: @pooling_nhwc_min_unsigned_integer
+// CHECK: linalg.pooling_nhwc_min_unsigned
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
+
+// -----
+
+func.func @pooling_nhwc_min_unsigned_float(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %0 = linalg.pooling_nhwc_min_unsigned
+    {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+    outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: @pooling_nhwc_min_unsigned_float
+// CHECK: linalg.pooling_nhwc_min
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
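
Reviewer note (not part of the patch): below is a minimal sketch of how the new matcher could be driven outside the specialization pass, e.g. from a custom analysis. The wrapper `reportDepthwiseConv1D` and its printing are hypothetical; only `isaConvolutionOpOfType`, `isaConvolutionOpInterface`, and `linalg::DepthwiseConv1DNwcWcOp` come from this patch and the existing Linalg dialect. Note that the matcher asserts `ConvolutionOpInterface` on its input, so callers should gate on `isaConvolutionOpInterface` first, exactly as `specializeLinalgConvolutions` does.

```cpp
// Hypothetical driver, for illustration only: walks an IR tree and queries
// the matcher added by this patch on every linalg.generic that looks like a
// convolution.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

static void reportDepthwiseConv1D(Operation *root) {
  root->walk([](linalg::GenericOp genericOp) {
    auto op = cast<linalg::LinalgOp>(genericOp.getOperation());
    // The matcher asserts ConvolutionOpInterface, so check it first.
    if (!linalg::isaConvolutionOpInterface(op))
      return;
    SmallVector<int64_t> dilations, strides;
    if (linalg::isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
            op, &dilations, &strides))
      llvm::outs() << "depthwise_conv_1d_nwc_wc: stride=" << strides[0]
                   << ", dilation=" << dilations[0] << "\n";
  });
}
```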