From c6aea9193db7ad415f2fcde00e8bdcc3d98cfea4 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Thu, 16 Oct 2025 02:54:14 -0500 Subject: [PATCH 01/13] [Linalg] Add basic infra to add matchers for linalg.*conv*/*pool* ops -- This commit includes the basic infra/utilities to add matchers for linalg.*conv*/*pool* ops - such that given a `linalg.generic` op it identifies which linalg.*conv*/*pool* op it is. -- It adds a few representative linalg.*conv*/*pool* ops to demo the matchers' capability and does so as part of `linalg-specialize-generic-ops` pass. -- The goal is directed towards addressing the aim of [[RFC] Op explosion in Linalg](https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863) iteratively for `*conv*/*pooling*` ops. -- This is part-1 of a series of PRs aimed to add matchers for Convolution ops. -- For further details, refer to https://github.com/llvm/llvm-project/pull/163374#pullrequestreview-3341048722 Signed-off-by: Abhishek Varma --- .../include/mlir/Dialect/Linalg/Utils/Utils.h | 9 + .../Dialect/Linalg/Transforms/Specialize.cpp | 144 +++++ mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 502 ++++++++++++++++++ .../convolution/roundtrip-convolution.mlir | 112 ++++ 4 files changed, 767 insertions(+) create mode 100644 mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 48978eb7663d5..771d753a8bddb 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -110,6 +110,15 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to); std::optional> getReassociationMapForFoldingUnitDims(ArrayRef mixedSizes); +//===----------------------------------------------------------------------===// +// Convolution matcher utility +//===----------------------------------------------------------------------===// + +template +bool isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations = nullptr, + SmallVector *strides = nullptr); + //===----------------------------------------------------------------------===// // Fusion / Tiling utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 40fc0d68e358f..35861002e309e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -237,6 +237,145 @@ static FailureOr specializeLinalgContractions(RewriterBase &rewriter, return replaceWithMatmulVariant(rewriter, genericOp); } +/// Utility to create a `genericOp` with a convolution op of type `ConvOpTy` +/// with `dilations` and `strides`. +template +static FailureOr +specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, + ArrayRef dilations, ArrayRef strides) { + SmallVector inputs = genericOp.getDpsInputs(); + ValueRange outputs = genericOp.getDpsInits(); + SmallVector indexingMaps = genericOp.getIndexingMapsArray(); + SmallVector resultTypes = genericOp.hasPureTensorSemantics() + ? 
TypeRange(ValueRange(outputs))
+                                      : TypeRange{};
+  LinalgOp namedOp;
+  if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
+                std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
+                                                    inputs, outputs);
+  } else {
+    Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
+    Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
+    namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
+        genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+  }
+  return namedOp;
+}
+
+/// TODO(avarma): Convolution ops which rank-2 iteratory types array will be
+/// added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
+        rewriter, genericOp, dilations, strides);
+  return failure();
+}
+
+/// TODO(avarma): Convolution ops which rank-5 iteratory types array will be
+/// added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
+        rewriter, genericOp, dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
+                                                       &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
+                                                        dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(genericOp, &dilations,
+                                                       &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp,
+                                                        dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(genericOp, &dilations,
+                                                       &strides))
+    return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp,
+                                                        dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(
+        rewriter, genericOp, dilations, strides);
+  if (isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(
+        rewriter, genericOp, dilations, strides);
+  return failure();
+}
+
+/// TODO(avarma): Convolution ops which rank-7 iteratory types array will be
+/// added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+/// TODO(avarma): Convolution ops which rank-8 iteratory types array will be
+/// added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter,
+                                                GenericOp genericOp) {
+  SmallVector<int64_t> dilations, strides;
+  if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+          genericOp, &dilations, &strides))
+    return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+        rewriter, genericOp, dilations, strides);
+  return failure();
+}
+
+// Converts linalg.generic to named linalg.*conv/pooling* where possible. To
+// improve the search speed, the convolution ops have been segregated based on
+// the rank of iterator types array.
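+//
+// As a rough illustration (an assumed typical input, not an exhaustive
+// contract): linalg.depthwise_conv_2d_nchw_chw generalizes to a linalg.generic
+// with six iterators -- four parallel (N, C, H, W) and two reduction (h, w) --
+// so such a generic is routed to the rank-6 helper:
+//
+//   SmallVector<utils::IteratorType> iters = genericOp.getIteratorTypesArray();
+//   // iters.size() == 6  =>  inferAndSpecializeBasedOnRank6ConvIteratorTypes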
+static FailureOr +inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) { + SmallVector iteratorTypes = + genericOp.getIteratorTypesArray(); + unsigned totalIterators = iteratorTypes.size(); + switch (totalIterators) { + case 2: + return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp); + case 4: + return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp); + case 5: + return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp); + case 6: + return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp); + case 7: + return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp); + case 8: + return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp); + case 9: + return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp); + } + return failure(); +} + } // namespace //===----------------------------------------------------------------------===// @@ -316,6 +455,11 @@ FailureOr mlir::linalg::specializeGenericOp(RewriterBase &rewriter, if (isaContractionOpInterface(genericOp)) { return specializeLinalgContractions(rewriter, genericOp); } + + // Convolution - e.g. *conv/pooling* + if (isaConvolutionOpInterface(genericOp)) { + return inferAndSpecializeToConvolutionOp(rewriter, genericOp); + } return failure(); } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 24d3722cf5426..c3c2819652129 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -240,6 +240,508 @@ bool isReductionIterator(utils::IteratorType iteratorType) { return iteratorType == utils::IteratorType::reduction; } +//===----------------------------------------------------------------------===// +// Convolution matcher utilities +//===----------------------------------------------------------------------===// + +/// Utility to match block body for linalg.pool* ops. +template +static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) { + Operation *defOp = yieldVal.getDefiningOp(); + if (!(isa_and_present(defOp) || ...)) + return false; + + BlockArgument lhsArg = dyn_cast(defOp->getOperand(0)); + BlockArgument rhsArg = dyn_cast(defOp->getOperand(1)); + if (!lhsArg || !rhsArg) + return false; + return true; +} + +static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, + body); +} + +static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, + body); +} + +static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, + body); +} + +static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, + body); +} + +static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) { + return bodyMatcherForPoolOps(yieldVal, body); +} + +static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps, + uint32_t mapIndex, uint32_t dimIndex) { + auto affineMap = cast(indexingMaps[mapIndex]).getValue(); + if (dimIndex < affineMap.getNumResults()) + return affineMap.getResult(dimIndex); + return nullptr; +} + +// Check if `expr` is either: +// - a dimension expr alone (implying *1), or +// - a multiplication of dimension expr by constant. 
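+//
+// For example, in a hypothetical strided/dilated 1-D conv input map
+//   (d0, d1, d2, d3) -> (d0, d1 * 2 + d3 * 3, d2)
+// the sub-expression `d1 * 2` matches with constant 2, `d3 * 3` matches with
+// constant 3, and a bare `d3` would match with constant 1.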
static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim,
+                                        int64_t &constantValue) {
+  if (auto dExpr = dyn_cast<AffineDimExpr>(expr)) {
+    dim = dExpr;
+    constantValue = 1;
+    return true;
+  }
+
+  auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
+  if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
+    return false;
+
+  AffineExpr lhs = mulExpr.getLHS();
+  AffineExpr rhs = mulExpr.getRHS();
+
+  if (auto dExpr = dyn_cast<AffineDimExpr>(lhs)) {
+    if (auto cst = dyn_cast<AffineConstantExpr>(rhs)) {
+      dim = dExpr;
+      constantValue = cst.getValue();
+      return true;
+    }
+  }
+  if (auto cst = dyn_cast<AffineConstantExpr>(lhs)) {
+    if (auto dExpr = dyn_cast<AffineDimExpr>(rhs)) {
+      dim = dExpr;
+      constantValue = cst.getValue();
+      return true;
+    }
+  }
+  return false;
+}
+
+/// Given an array of AffineMaps `indexingMaps` verify the following :-
+///   indexingMaps[0].getResult(iDim) ==
+///       indexingMaps[1].getResult(fDim) * CST_1 +
+///       indexingMaps[n-1].getResult(oDim) * CST_2
+/// where, CST_1 and CST_2 can be any constant.
+static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
+                                       unsigned fDim, unsigned oDim,
+                                       int64_t &dilation, int64_t &stride) {
+  unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
+  AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim);
+  auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
+  if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
+    return false;
+
+  AffineExpr dim0, dim1;
+  int64_t c0, c1;
+
+  if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) &&
+      isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) {
+    // Pattern matched with dims and constants extracted.
+    AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim);
+    AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim);
+    if (dim0 == fExpr && dim1 == oExpr) {
+      dilation = c0;
+      stride = c1;
+      return true;
+    } else if (dim1 == fExpr && dim0 == oExpr) {
+      dilation = c1;
+      stride = c0;
+      return true;
+    }
+  }
+  return false;
+}
+
+/// Given an array of AffineMaps `indexingMaps` verify the following :-
+///   indexingMaps[aIndex].getResult(aDim) ==
+///       indexingMaps[bIndex].getResult(bDim)
+static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex,
+                                    unsigned aDim, unsigned bIndex,
+                                    unsigned bDim) {
+  return getAffineMapDim(indexingMaps, aIndex, aDim) ==
+         getAffineMapDim(indexingMaps, bIndex, bDim);
+}
+
+/// Given an array of AffineMaps, verify that each map has the corresponding
+/// `expectedSize` number of results.
+static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps,
+                                       ArrayRef<int64_t> expectedSizes) {
+  if (indexingMaps.size() != expectedSizes.size())
+    return false;
+
+  for (auto [indexingMap, expectedSize] :
+       llvm::zip_equal(indexingMaps, expectedSizes)) {
+    auto affineMap = cast<AffineMapAttr>(indexingMap).getValue();
+    if (affineMap.getNumResults() != expectedSize)
+      return false;
+  }
+  return true;
+}
+
+/// Utility to update `dilations` and `strides` by copy the corresponding data
+/// from `tempDilations` and `tempStrides`.
+static bool updateConvDilationsAndStrides(SmallVector *dilations, + SmallVector *strides, + ArrayRef tempDilations, + ArrayRef tempStrides) { + if (!(dilations && strides)) + return true; + for (auto [dilation, stride] : llvm::zip(tempDilations, tempStrides)) { + dilations->push_back(dilation); + strides->push_back(stride); + } + return true; +} + +static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(1, 1); + SmallVector tempStrides(1, 1); + // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> + // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)> + // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && + matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0])); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && + matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[1], + tempStrides[1])); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6})) + return false; + + unsigned iIndex = 0, fIndex = 1, oIndex = 2; + + SmallVector tempDilations(3, 1); + SmallVector tempStrides(3, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) + // -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) + // -> (d5, d6, d7, d8, d4)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) + // -> (d0, d1, d2, d3, d8, d4)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && 
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[2], + tempStrides[2]) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && + matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && + matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForMaxSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForMinSignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = 
yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForSumPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForMaxUnsignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + if (!isaConvolutionOpInterface(op)) + return false; + + ArrayAttr indexingMaps = op.getIndexingMaps(); + if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) + return false; + + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned iIndex = 0, oIndex = 2; + + SmallVector tempDilations(2, 1); + SmallVector tempStrides(2, 1); + // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> + // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> + // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + bool returnVal = + (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], + tempStrides[0]) && + matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], + tempStrides[1]) && + matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + bodyMatcherForMinUnsignedPoolOps(yieldVal, body)); + return returnVal && updateConvDilationsAndStrides(dilations, strides, + tempDilations, tempStrides); +} + +template +bool isaConvolutionOpOfType(LinalgOp op, 
SmallVector<int64_t> *dilations,
+                            SmallVector<int64_t> *strides) {
+  if constexpr (std::is_same_v<ConvOpTy, linalg::DepthwiseConv1DNwcWcOp>) {
+    return isaDepthwiseConv1DNwcWcOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv2DNchwChwOp>) {
+    return isaDepthwiseConv2DNchwChwOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::DepthwiseConv3DNdhwcDhwcmOp>) {
+    return isaDepthwiseConv3DNdhwcDhwcmOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNhwcMaxOp>) {
+    return isaPoolingNhwcMaxOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNhwcMinOp>) {
+    return isaPoolingNhwcMinOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy, linalg::PoolingNhwcSumOp>) {
+    return isaPoolingNhwcSumOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::PoolingNhwcMaxUnsignedOp>) {
+    return isaPoolingNhwcMaxUnsignedOp(op, dilations, strides);
+  } else if constexpr (std::is_same_v<ConvOpTy,
+                                      linalg::PoolingNhwcMinUnsignedOp>) {
+    return isaPoolingNhwcMinUnsignedOp(op, dilations, strides);
+  } else {
+    return false;
+  }
+}
+
+template bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides);
+template bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+    LinalgOp op, SmallVector<int64_t> *dilations,
+    SmallVector<int64_t> *strides);
+
 Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                             Value source, Value pad, bool nofold,
                             ValueRange typeDynDims) {
diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
new file mode 100644
index 0000000000000..5a18ca8519be3
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
@@ -0,0 +1,112 @@
+// The following tests check that linalg convolution named ops, once lowered to
+// linalg.generic, are lifted back up to the corresponding named op.
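+//
+// A minimal sketch of the round trip each case below exercises (op name taken
+// from the first test; types elided):
+//   linalg.depthwise_conv_1d_nwc_wc
+//     --linalg-generalize-named-ops-->    linalg.generic
+//     --linalg-specialize-generic-ops-->  linalg.depthwise_conv_1d_nwc_wc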
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s + +func.func @depthwise_conv_1d_nwc_wc(%input: memref, %filter: memref, %output: memref) { + linalg.depthwise_conv_1d_nwc_wc {dilations = dense<3> : tensor<1xi64>, + strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: memref, memref) + outs (%output: memref) + return +} +// CHECK: @depthwise_conv_1d_nwc_wc +// CHECK: linalg.depthwise_conv_1d_nwc_wc +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_2d_nchw_chw(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.depthwise_conv_2d_nchw_chw {dilations = dense<[2,3]> : vector<2xi64>, strides = dense<[4,5]> : vector<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_2d_nchw_chw +// CHECK: linalg.depthwise_conv_2d_nchw_chw +// CHECK-SAME: dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[4, 5]> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm {dilations = dense<1> : tensor<3xi64>, + strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_3d_ndhwc_dhwcm +// CHECK: linalg.depthwise_conv_3d_ndhwc_dhwcm +// CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_max(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_max +// CHECK: linalg.pooling_nhwc_max +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_min(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_min +// CHECK: linalg.pooling_nhwc_min +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_sum(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_sum +// CHECK: linalg.pooling_nhwc_sum +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_max_unsigned(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_max_unsigned +// CHECK: 
linalg.pooling_nhwc_max_unsigned +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic + +// ----- + +func.func @pooling_nhwc_min_unsigned(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_min_unsigned +// CHECK: linalg.pooling_nhwc_min_unsigned +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic From cd1b88a9d7febdd1f933ac22254303f74643f1c2 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Fri, 17 Oct 2025 02:46:56 -0500 Subject: [PATCH 02/13] Review comment v1.0 --- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 143 +++++++++--------- .../convolution/roundtrip-convolution.mlir | 16 +- 2 files changed, 87 insertions(+), 72 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index c3c2819652129..4dfec7b361eab 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -418,9 +418,9 @@ static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, SmallVector tempDilations(1, 1); SmallVector tempStrides(1, 1); - // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)> - // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)> - // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + // #map = affine_map<(N, W, C, w) -> (N, W + w, C)> + // #map1 = affine_map<(N, W, C, w) -> (w, C)> + // #map2 = affine_map<(N, W, C, w) -> (N, W, C)> bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && @@ -449,9 +449,9 @@ static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)> + // #map = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)> + // #map1 = affine_map<(N, H, W, C, h, w) -> (C, h, w)> + // #map2 = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)> bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && @@ -483,12 +483,12 @@ static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, SmallVector tempDilations(3, 1); SmallVector tempStrides(3, 1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) - // -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) - // -> (d5, d6, d7, d8, d4)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) - // -> (d0, d1, d2, d3, d8, d4)> + // #map = affine_map<(N, D, H, W, CM, d, h, w, C) + // -> (N, D + d, H + h, W + w, C)> + // #map1 = affine_map<(N, D, H, W, CM, d, h, w, C) + // -> (d, h, w, C, CM)> + // #map2 = affine_map<(N, D, H, W, CM, d, h, w, C) + // -> (N, D, H, W, C, CM)> bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, @@ -526,9 +526,9 @@ static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector *dilations, SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - // #map = affine_map<(d0, d1, 
d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> + // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)> + // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, @@ -562,9 +562,9 @@ static bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector *dilations, SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> + // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)> + // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, @@ -598,9 +598,9 @@ static bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector *dilations, SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> + // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)> + // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, @@ -635,9 +635,9 @@ static bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> + // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)> + // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, @@ -672,9 +672,9 @@ static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> - // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> - // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> + // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)> + // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> bool returnVal = (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, @@ -689,58 +689,61 @@ static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, tempDilations, tempStrides); } -template -bool isaConvolutionOpOfType(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { - if constexpr (std::is_same_v) { - return isaDepthwiseConv1DNwcWcOp(op, dilations, strides); - } else if 
constexpr (std::is_same_v) { - return isaDepthwiseConv2DNchwChwOp(op, dilations, strides); - } else if constexpr (std::is_same_v) { - return isaDepthwiseConv3DNdhwcDhwcmOp(op, dilations, strides); - } else if constexpr (std::is_same_v) { - return isaPoolingNhwcMaxOp(op, dilations, strides); - } else if constexpr (std::is_same_v) { - return isaPoolingNhwcMinOp(op, dilations, strides); - } else if constexpr (std::is_same_v) { - return isaPoolingNhwcSumOp(op, dilations, strides); - } else if constexpr (std::is_same_v) { - return isaPoolingNhwcMaxUnsignedOp(op, dilations, strides); - } else if constexpr (std::is_same_v) { - return isaPoolingNhwcMinUnsignedOp(op, dilations, strides); - } else { - return false; - } +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + return isaDepthwiseConv1DNwcWcOp(op, dilations, strides); } -template bool isaConvolutionOpOfType( - LinalgOp op, SmallVector *dilations, - SmallVector *strides); -template bool isaConvolutionOpOfType( +template <> +bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, - SmallVector *strides); -template bool isaConvolutionOpOfType( + SmallVector *strides) { + return isaDepthwiseConv2DNchwChwOp(op, dilations, strides); +} + +template <> +bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, - SmallVector *strides); -template bool isaConvolutionOpOfType( + SmallVector *strides) { + return isaDepthwiseConv3DNdhwcDhwcmOp(op, dilations, strides); +} + +template <> +bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, - SmallVector *strides); -template bool isaConvolutionOpOfType( + SmallVector *strides) { + return isaPoolingNhwcMaxOp(op, dilations, strides); +} + +template <> +bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, - SmallVector *strides); -template bool isaConvolutionOpOfType( + SmallVector *strides) { + return isaPoolingNhwcMinOp(op, dilations, strides); +} + +template <> +bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, - SmallVector *strides); -template bool isaConvolutionOpOfType( + SmallVector *strides) { + return isaPoolingNhwcSumOp(op, dilations, strides); +} + +template <> +bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, - SmallVector *strides); -template bool isaConvolutionOpOfType( + SmallVector *strides) { + return isaPoolingNhwcMaxUnsignedOp(op, dilations, strides); +} + +template <> +bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, - SmallVector *strides); + SmallVector *strides) { + return isaPoolingNhwcMinUnsignedOp(op, dilations, strides); +} Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold, diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir index 5a18ca8519be3..06c9a84049d81 100644 --- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir +++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir @@ -99,14 +99,26 @@ func.func @pooling_nhwc_max_unsigned(%input: tensor, %filter: tenso // ----- -func.func @pooling_nhwc_min_unsigned(%input: tensor, %filter: tensor, %init: tensor) -> tensor { +func.func @pooling_nhwc_min_unsigned_integer(%input: tensor, %filter: tensor, %init: tensor) -> tensor { %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins (%input, %filter: tensor, tensor) outs (%init: tensor) -> tensor 
return %0 : tensor } -// CHECK: @pooling_nhwc_min_unsigned +// CHECK: @pooling_nhwc_min_unsigned_integer // CHECK: linalg.pooling_nhwc_min_unsigned // CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> // CHECK-NOT: linalg.generic + +func.func @pooling_nhwc_min_unsigned_float(%input: tensor, %filter: tensor, %init: tensor) -> tensor { + %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, + strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%init: tensor) -> tensor + return %0 : tensor +} +// CHECK: @pooling_nhwc_min_unsigned_float +// CHECK: linalg.pooling_nhwc_min +// CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> +// CHECK-NOT: linalg.generic From 0e9946b47c518867dae394c5221fba7d812c4803 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Wed, 22 Oct 2025 03:28:15 -0500 Subject: [PATCH 03/13] Review comment Hanhan v1.0 --- .../Dialect/Linalg/Transforms/Specialize.cpp | 40 ------- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 109 +++++------------- 2 files changed, 32 insertions(+), 117 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 35861002e309e..2bfa21d9062ee 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -264,14 +264,6 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, return namedOp; } -/// TODO(avarma): Convolution ops which rank-2 iteratory types array will be -/// added here incrementally in follow-up PRs. -static FailureOr -inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter, - GenericOp genericOp) { - return failure(); -} - static FailureOr inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { @@ -283,14 +275,6 @@ inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter, return failure(); } -/// TODO(avarma): Convolution ops which rank-5 iteratory types array will be -/// added here incrementally in follow-up PRs. -static FailureOr -inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter, - GenericOp genericOp) { - return failure(); -} - static FailureOr inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { @@ -322,22 +306,6 @@ inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter, return failure(); } -/// TODO(avarma): Convolution ops which rank-7 iteratory types array will be -/// added here incrementally in follow-up PRs. -static FailureOr -inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter, - GenericOp genericOp) { - return failure(); -} - -/// TODO(avarma): Convolution ops which rank-8 iteratory types array will be -/// added here incrementally in follow-up PRs. 
-static FailureOr -inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter, - GenericOp genericOp) { - return failure(); -} - static FailureOr inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter, GenericOp genericOp) { @@ -358,18 +326,10 @@ inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) { genericOp.getIteratorTypesArray(); unsigned totalIterators = iteratorTypes.size(); switch (totalIterators) { - case 2: - return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp); case 4: return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp); - case 5: - return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp); case 6: return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp); - case 7: - return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp); - case 8: - return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp); case 9: return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp); } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 4dfec7b361eab..23c7fb68a5534 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -401,9 +401,10 @@ static bool updateConvDilationsAndStrides(SmallVector *dilations, return true; } -static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, - SmallVector *dilations, - SmallVector *strides) { +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -432,9 +433,10 @@ static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op, tempDilations, tempStrides); } -static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, - SmallVector *dilations, - SmallVector *strides) { +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -466,9 +468,10 @@ static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op, tempDilations, tempStrides); } -static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, - SmallVector *dilations, - SmallVector *strides) { +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -507,8 +510,10 @@ static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op, tempDilations, tempStrides); } -static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -543,8 +548,10 @@ static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -static bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -579,8 +586,10 @@ static bool isaPoolingNhwcMinOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -static bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector *dilations, - SmallVector *strides) { +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -615,9 +624,10 @@ static bool isaPoolingNhwcSumOp(LinalgOp op, SmallVector *dilations, tempDilations, tempStrides); } -static bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, - 
SmallVector *dilations, - SmallVector *strides) { +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -652,9 +662,10 @@ static bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op, tempDilations, tempStrides); } -static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, - SmallVector *dilations, - SmallVector *strides) { +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { if (isa(op)) return true; @@ -689,62 +700,6 @@ static bool isaPoolingNhwcMinUnsignedOp(LinalgOp op, tempDilations, tempStrides); } -template <> -bool isaConvolutionOpOfType( - LinalgOp op, SmallVector *dilations, - SmallVector *strides) { - return isaDepthwiseConv1DNwcWcOp(op, dilations, strides); -} - -template <> -bool isaConvolutionOpOfType( - LinalgOp op, SmallVector *dilations, - SmallVector *strides) { - return isaDepthwiseConv2DNchwChwOp(op, dilations, strides); -} - -template <> -bool isaConvolutionOpOfType( - LinalgOp op, SmallVector *dilations, - SmallVector *strides) { - return isaDepthwiseConv3DNdhwcDhwcmOp(op, dilations, strides); -} - -template <> -bool isaConvolutionOpOfType( - LinalgOp op, SmallVector *dilations, - SmallVector *strides) { - return isaPoolingNhwcMaxOp(op, dilations, strides); -} - -template <> -bool isaConvolutionOpOfType( - LinalgOp op, SmallVector *dilations, - SmallVector *strides) { - return isaPoolingNhwcMinOp(op, dilations, strides); -} - -template <> -bool isaConvolutionOpOfType( - LinalgOp op, SmallVector *dilations, - SmallVector *strides) { - return isaPoolingNhwcSumOp(op, dilations, strides); -} - -template <> -bool isaConvolutionOpOfType( - LinalgOp op, SmallVector *dilations, - SmallVector *strides) { - return isaPoolingNhwcMaxUnsignedOp(op, dilations, strides); -} - -template <> -bool isaConvolutionOpOfType( - LinalgOp op, SmallVector *dilations, - SmallVector *strides) { - return isaPoolingNhwcMinUnsignedOp(op, dilations, strides); -} - Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold, ValueRange typeDynDims) { From d44cc34ce67daccce72d930f6fea0982ce02a273 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Thu, 23 Oct 2025 06:04:50 -0500 Subject: [PATCH 04/13] Review comment Andrszej v2.0 --- .../Dialect/Linalg/Transforms/Specialize.cpp | 54 ++++--------------- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 6 +++ 2 files changed, 17 insertions(+), 43 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 2bfa21d9062ee..ce3df6a485f92 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -264,25 +264,26 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, return namedOp; } +// Converts linalg.generic to named linalg.*conv/pooling* where possible. To +// improve the search speed, the convolution ops have been segregated based on +// the rank of iterator types array. static FailureOr -inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter, - GenericOp genericOp) { +inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) { SmallVector dilations, strides; + // Depthwise Convolution ops. 
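+  // (For instance -- an illustrative shape, not a requirement -- a generic
+  // whose three indexing maps are
+  //   (N, W, C, w) -> (N, W + w, C), (w, C), (N, W, C)
+  // is rebuilt as linalg.depthwise_conv_1d_nwc_wc by the first check below.)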
if (isaConvolutionOpOfType( genericOp, &dilations, &strides)) return specializeToConvOp( rewriter, genericOp, dilations, strides); - return failure(); -} - -static FailureOr -inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter, - GenericOp genericOp) { - SmallVector dilations, strides; if (isaConvolutionOpOfType( genericOp, &dilations, &strides)) return specializeToConvOp( rewriter, genericOp, dilations, strides); + if (isaConvolutionOpOfType( + genericOp, &dilations, &strides)) + return specializeToConvOp( + rewriter, genericOp, dilations, strides); + // Pooling ops. if (isaConvolutionOpOfType(genericOp, &dilations, &strides)) return specializeToConvOp(rewriter, genericOp, @@ -306,36 +307,6 @@ inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter, return failure(); } -static FailureOr -inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter, - GenericOp genericOp) { - SmallVector dilations, strides; - if (isaConvolutionOpOfType( - genericOp, &dilations, &strides)) - return specializeToConvOp( - rewriter, genericOp, dilations, strides); - return failure(); -} - -// Converts linalg.generic to named linalg.*conv/pooling* where possible. To -// improve the search speed, the convolution ops have been segregated based on -// the rank of iterator types array. -static FailureOr -inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) { - SmallVector iteratorTypes = - genericOp.getIteratorTypesArray(); - unsigned totalIterators = iteratorTypes.size(); - switch (totalIterators) { - case 4: - return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp); - case 6: - return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp); - case 9: - return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp); - } - return failure(); -} - } // namespace //===----------------------------------------------------------------------===// @@ -417,10 +388,7 @@ FailureOr mlir::linalg::specializeGenericOp(RewriterBase &rewriter, } // Convolution - e.g. *conv/pooling* - if (isaConvolutionOpInterface(genericOp)) { - return inferAndSpecializeToConvolutionOp(rewriter, genericOp); - } - return failure(); + return inferAndSpecializeToConvolutionOp(rewriter, genericOp); } namespace { diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 23c7fb68a5534..cd518fc38819e 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -263,6 +263,9 @@ static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) { body); } +// max_unsigned ops should not allow float data type. +// TODO: Retire OPDSL logic. Refer to : +// https://github.com/llvm/llvm-project/pull/163724#discussion_r2438940337 static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) { return bodyMatcherForPoolOps(yieldVal, body); @@ -273,6 +276,9 @@ static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) { body); } +// min_unsigned ops should not allow float data type. +// TODO: Retire OPDSL logic. 
Refer to : +// https://github.com/llvm/llvm-project/pull/163724#discussion_r2438940337 static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) { return bodyMatcherForPoolOps(yieldVal, body); From 7b47d9e56db22366604e8608d099002cba5e9fd6 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Fri, 24 Oct 2025 02:58:28 -0500 Subject: [PATCH 05/13] Review comment Andrszej v3.0 --- .../Dialect/Linalg/Transforms/Specialize.cpp | 13 +++++-- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 39 ++++++++++--------- 2 files changed, 31 insertions(+), 21 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index ce3df6a485f92..c68f7bd88c1ae 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -267,10 +267,12 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, // Converts linalg.generic to named linalg.*conv/pooling* where possible. To // improve the search speed, the convolution ops have been segregated based on // the rank of iterator types array. -static FailureOr -inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) { +static FailureOr specializeLinalgConvolutions(RewriterBase &rewriter, + GenericOp genericOp) { SmallVector dilations, strides; + // ----------------------------- // Depthwise Convolution ops. + //------------------------------ if (isaConvolutionOpOfType( genericOp, &dilations, &strides)) return specializeToConvOp( @@ -283,7 +285,9 @@ inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) { genericOp, &dilations, &strides)) return specializeToConvOp( rewriter, genericOp, dilations, strides); + // ----------------------------- // Pooling ops. + //------------------------------ if (isaConvolutionOpOfType(genericOp, &dilations, &strides)) return specializeToConvOp(rewriter, genericOp, @@ -388,7 +392,10 @@ FailureOr mlir::linalg::specializeGenericOp(RewriterBase &rewriter, } // Convolution - e.g. *conv/pooling* - return inferAndSpecializeToConvolutionOp(rewriter, genericOp); + if (isaConvolutionOpInterface(genericOp)) { + return specializeLinalgConvolutions(rewriter, genericOp); + } + return failure(); } namespace { diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index cd518fc38819e..c5c9e4b2f8387 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -265,7 +265,7 @@ static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) { // max_unsigned ops should not allow float data type. // TODO: Retire OPDSL logic. Refer to : -// https://github.com/llvm/llvm-project/pull/163724#discussion_r2438940337 +// https://github.com/llvm/llvm-project/issues/164800 static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) { return bodyMatcherForPoolOps(yieldVal, body); @@ -278,7 +278,7 @@ static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) { // min_unsigned ops should not allow float data type. // TODO: Retire OPDSL logic. 
Refer to : -// https://github.com/llvm/llvm-project/pull/163724#discussion_r2438940337 +// https://github.com/llvm/llvm-project/issues/164800 static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) { return bodyMatcherForPoolOps(yieldVal, body); @@ -407,6 +407,9 @@ static bool updateConvDilationsAndStrides(SmallVector *dilations, return true; } +// --------------------------------------------- +// Matchers for specific convolution operation. +//---------------------------------------------- template <> bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, @@ -414,8 +417,8 @@ bool isaConvolutionOpOfType( if (isa(op)) return true; - if (!isaConvolutionOpInterface(op)) - return false; + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); ArrayAttr indexingMaps = op.getIndexingMaps(); if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3})) @@ -446,8 +449,8 @@ bool isaConvolutionOpOfType( if (isa(op)) return true; - if (!isaConvolutionOpInterface(op)) - return false; + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); ArrayAttr indexingMaps = op.getIndexingMaps(); if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4})) @@ -481,8 +484,8 @@ bool isaConvolutionOpOfType( if (isa(op)) return true; - if (!isaConvolutionOpInterface(op)) - return false; + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); ArrayAttr indexingMaps = op.getIndexingMaps(); if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6})) @@ -523,8 +526,8 @@ bool isaConvolutionOpOfType( if (isa(op)) return true; - if (!isaConvolutionOpInterface(op)) - return false; + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); ArrayAttr indexingMaps = op.getIndexingMaps(); if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) @@ -561,8 +564,8 @@ bool isaConvolutionOpOfType( if (isa(op)) return true; - if (!isaConvolutionOpInterface(op)) - return false; + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); ArrayAttr indexingMaps = op.getIndexingMaps(); if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) @@ -599,8 +602,8 @@ bool isaConvolutionOpOfType( if (isa(op)) return true; - if (!isaConvolutionOpInterface(op)) - return false; + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); ArrayAttr indexingMaps = op.getIndexingMaps(); if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) @@ -637,8 +640,8 @@ bool isaConvolutionOpOfType( if (isa(op)) return true; - if (!isaConvolutionOpInterface(op)) - return false; + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); ArrayAttr indexingMaps = op.getIndexingMaps(); if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) @@ -675,8 +678,8 @@ bool isaConvolutionOpOfType( if (isa(op)) return true; - if (!isaConvolutionOpInterface(op)) - return false; + assert(isaConvolutionOpInterface(op) && + "expected linalgOp to implement ConvolutionOpInterface"); ArrayAttr indexingMaps = op.getIndexingMaps(); if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4})) From 47b3e34dc9f6bb882a8d91df0bd09fa2f8c684d3 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Mon, 27 Oct 2025 04:02:50 -0500 Subject: [PATCH 06/13] Review comment Andrzej v4.0 --- .../Dialect/Linalg/Transforms/Specialize.cpp | 6 +- 
mlir/lib/Dialect/Linalg/Utils/Utils.cpp      | 237 ++++++++++++------
 2 files changed, 158 insertions(+), 85 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index c68f7bd88c1ae..0b3662c888010 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,7 +237,7 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
   return replaceWithMatmulVariant(rewriter, genericOp);
 }
 
-/// Utility to create a `genericOp` with a convolution op of type `ConvOpTy`
+/// Utility to specialize `genericOp` into a convolution op of type `ConvOpTy`
 /// with `dilations` and `strides`.
 template <typename ConvOpTy>
 static FailureOr<LinalgOp>
 specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
@@ -272,7 +272,7 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
   SmallVector<int64_t> dilations, strides;
   // -----------------------------
   // Depthwise Convolution ops.
-  //------------------------------
+  // -----------------------------
   if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
           genericOp, &dilations, &strides))
     return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
         rewriter, genericOp, dilations, strides);
@@ -287,7 +287,7 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
         rewriter, genericOp, dilations, strides);
   // -----------------------------
   // Pooling ops.
-  //------------------------------
+  // -----------------------------
   if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
                                                        &strides))
     return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index c5c9e4b2f8387..0be2668a9b346 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -244,6 +244,46 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
 // Convolution matcher utilities
 //===----------------------------------------------------------------------===//
 
+/// Returns the BlockArgument that leads to `val`, if any. Traverses optional
+/// ext* ops.
+static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) {
+  BlockArgument blockArg;
+  if (!(blockArg = dyn_cast<BlockArgument>(val))) {
+    Operation *defOp = val.getDefiningOp();
+    if (!dyn_cast_if_present<arith::ExtFOp>(defOp) &&
+        !dyn_cast_if_present<arith::ExtSIOp>(defOp) &&
+        !dyn_cast_if_present<arith::ExtUIOp>(defOp)) {
+      return nullptr;
+    }
+    blockArg = dyn_cast<BlockArgument>(defOp->getOperand(0));
+  }
+  return blockArg;
+}
+
+/// Utility to match block body for matmul-like ops.
+static bool bodyMatcherForMatmulLikeOps(Value yieldVal, Block *body) {
+  Operation *addOp = yieldVal.getDefiningOp();
+  if (!isa_and_present<arith::AddFOp, arith::AddIOp>(addOp))
+    return false;
+
+  Operation *mulOp = addOp->getOperand(1).getDefiningOp();
+  if (!isa_and_present<arith::MulFOp, arith::MulIOp>(mulOp))
+    return false;
+
+  BlockArgument lhsBlockArg =
+      getBlockArgumentWithOptionalExtOps(mulOp->getOperand(0));
+  BlockArgument rhsBlockArg =
+      getBlockArgumentWithOptionalExtOps(mulOp->getOperand(1));
+  BlockArgument outBlockArg =
+      getBlockArgumentWithOptionalExtOps(addOp->getOperand(0));
+  if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
+      lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body ||
+      outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 ||
+      rhsBlockArg.getArgNumber() != 1 || outBlockArg.getArgNumber() != 2)
+    return false;
+  return true;
+}
+
 /// Utility to match block body for linalg.pool* ops.
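// For reference, the region shape such pool-body matchers accept, shown here
// for a max pool on f32 (illustrative IR only; the concrete combining op is
// the matcher's template parameter, e.g. a maximum op in this case):
//   ^bb0(%in: f32, %win: f32, %acc: f32):
//     %0 = arith.maximumf %acc, %in : f32
//     linalg.yield %0 : f32
// Note that the pooling window argument (%win) never appears in the body; it
// only shapes the iteration space.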
template static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) { @@ -253,7 +293,9 @@ static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) { BlockArgument lhsArg = dyn_cast(defOp->getOperand(0)); BlockArgument rhsArg = dyn_cast(defOp->getOperand(1)); - if (!lhsArg || !rhsArg) + if (!lhsArg || !rhsArg || lhsArg.getOwner() != body || + rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 || + rhsArg.getArgNumber() != 0) return false; return true; } @@ -339,8 +381,9 @@ static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim, int64_t &dilation, int64_t &stride) { - unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1; - AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim); + unsigned inputMapIdx = 0, filterMapIdx = 1, + outputMapIdx = indexingMaps.size() - 1; + AffineExpr inpExpr = getAffineMapDim(indexingMaps, inputMapIdx, iDim); auto addExpr = dyn_cast(inpExpr); if (!addExpr || addExpr.getKind() != AffineExprKind::Add) return false; @@ -351,8 +394,8 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) && isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) { // Pattern matched with dims and constants extracted. - AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim); - AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim); + AffineExpr fExpr = getAffineMapDim(indexingMaps, filterMapIdx, fDim); + AffineExpr oExpr = getAffineMapDim(indexingMaps, outputMapIdx, oDim); if (dim0 == fExpr && dim1 == oExpr) { dilation = c0; stride = c1; @@ -394,22 +437,26 @@ static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, /// Utility to update `dilations` and `strides` by copy the corresponding data /// from `tempDilations` and `tempStrides`. -static bool updateConvDilationsAndStrides(SmallVector *dilations, +static void updateConvDilationsAndStrides(SmallVector *dilations, SmallVector *strides, ArrayRef tempDilations, ArrayRef tempStrides) { if (!(dilations && strides)) - return true; + return; for (auto [dilation, stride] : llvm::zip(tempDilations, tempStrides)) { dilations->push_back(dilation); strides->push_back(stride); } - return true; + return; } // --------------------------------------------- // Matchers for specific convolution operation. 
-//---------------------------------------------- +// --------------------------------------------- + +// #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)> +// #filterMap = affine_map<(N, W, C, w) -> (w, C)> +// #outputMap = affine_map<(N, W, C, w) -> (N, W, C)> template <> bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, @@ -424,24 +471,30 @@ bool isaConvolutionOpOfType( if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3})) return false; - unsigned iIndex = 0, fIndex = 1, oIndex = 2; + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned inputMapIdx = 0, filterMapIdx = 1, outputMapIdx = 2; SmallVector tempDilations(1, 1); SmallVector tempStrides(1, 1); - // #map = affine_map<(N, W, C, w) -> (N, W + w, C)> - // #map1 = affine_map<(N, W, C, w) -> (w, C)> - // #map2 = affine_map<(N, W, C, w) -> (N, W, C)> bool returnVal = - (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) && - matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) && + (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && + matchConvDimExprPattern(indexingMaps, inputMapIdx, 2, filterMapIdx, 1) && + matchConvDimExprPattern(indexingMaps, inputMapIdx, 2, outputMapIdx, 2) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], - tempStrides[0])); - return returnVal && updateConvDilationsAndStrides(dilations, strides, - tempDilations, tempStrides); + tempStrides[0]) && + bodyMatcherForMatmulLikeOps(yieldVal, body)); + if (returnVal) + updateConvDilationsAndStrides(dilations, strides, tempDilations, + tempStrides); + return returnVal; } +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (C, h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)> template <> bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, @@ -456,27 +509,36 @@ bool isaConvolutionOpOfType( if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4})) return false; - unsigned iIndex = 0, fIndex = 1, oIndex = 2; + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned inputMapIdx = 0, filterMapIdx = 1, outputMapIdx = 2; SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - // #map = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)> - // #map1 = affine_map<(N, H, W, C, h, w) -> (C, h, w)> - // #map2 = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)> bool returnVal = - (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) && - matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) && + (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && + matchConvDimExprPattern(indexingMaps, inputMapIdx, 1, filterMapIdx, 0) && + matchConvDimExprPattern(indexingMaps, inputMapIdx, 1, outputMapIdx, 1) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[0], tempStrides[0]) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[1], - tempStrides[1])); - return returnVal && updateConvDilationsAndStrides(dilations, strides, - tempDilations, tempStrides); + tempStrides[1]) && + bodyMatcherForMatmulLikeOps(yieldVal, body)); + if (returnVal) + updateConvDilationsAndStrides(dilations, strides, 
tempDilations, + tempStrides); + return returnVal; } +// #inputMap = affine_map<(N, D, H, W, CM, d, h, w, C) +// -> (N, D + d, H + h, W + w, C)> +// #filterMap = affine_map<(N, D, H, W, CM, d, h, w, C) +// -> (d, h, w, C, CM)> +// #outputMap = affine_map<(N, D, H, W, CM, d, h, w, C) +// -> (N, D, H, W, C, CM)> template <> bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, @@ -491,18 +553,15 @@ bool isaConvolutionOpOfType( if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6})) return false; - unsigned iIndex = 0, fIndex = 1, oIndex = 2; + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + unsigned inputMapIdx = 0, filterMapIdx = 1, outputMapIdx = 2; SmallVector tempDilations(3, 1); SmallVector tempStrides(3, 1); - // #map = affine_map<(N, D, H, W, CM, d, h, w, C) - // -> (N, D + d, H + h, W + w, C)> - // #map1 = affine_map<(N, D, H, W, CM, d, h, w, C) - // -> (d, h, w, C, CM)> - // #map2 = affine_map<(N, D, H, W, CM, d, h, w, C) - // -> (N, D, H, W, C, CM)> bool returnVal = - (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && @@ -512,13 +571,20 @@ bool isaConvolutionOpOfType( matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3, tempDilations[2], tempStrides[2]) && - matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) && - matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) && - matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, - tempDilations, tempStrides); + matchConvDimExprPattern(indexingMaps, inputMapIdx, 4, filterMapIdx, 3) && + matchConvDimExprPattern(indexingMaps, inputMapIdx, 4, outputMapIdx, 4) && + matchConvDimExprPattern(indexingMaps, filterMapIdx, 4, outputMapIdx, + 5) && + bodyMatcherForMatmulLikeOps(yieldVal, body)); + if (returnVal) + updateConvDilationsAndStrides(dilations, strides, tempDilations, + tempStrides); + return returnVal; } +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> template <> bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, @@ -536,27 +602,29 @@ bool isaConvolutionOpOfType( Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); - unsigned iIndex = 0, oIndex = 2; + unsigned inputMapIdx = 0, outputMapIdx = 2; SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> - // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)> - // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> bool returnVal = - (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + matchConvDimExprPattern(indexingMaps, inputMapIdx, 3, outputMapIdx, 3) && bodyMatcherForMaxSignedPoolOps(yieldVal, body)); - return returnVal && 
updateConvDilationsAndStrides(dilations, strides, - tempDilations, tempStrides); + if (returnVal) + updateConvDilationsAndStrides(dilations, strides, tempDilations, + tempStrides); + return returnVal; } +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> template <> bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, @@ -574,27 +642,29 @@ bool isaConvolutionOpOfType( Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); - unsigned iIndex = 0, oIndex = 2; + unsigned inputMapIdx = 0, outputMapIdx = 2; SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> - // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)> - // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> bool returnVal = - (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + matchConvDimExprPattern(indexingMaps, inputMapIdx, 3, outputMapIdx, 3) && bodyMatcherForMinSignedPoolOps(yieldVal, body)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, - tempDilations, tempStrides); + if (returnVal) + updateConvDilationsAndStrides(dilations, strides, tempDilations, + tempStrides); + return returnVal; } +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> template <> bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, @@ -612,27 +682,29 @@ bool isaConvolutionOpOfType( Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); - unsigned iIndex = 0, oIndex = 2; + unsigned inputMapIdx = 0, outputMapIdx = 2; SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> - // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)> - // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> bool returnVal = - (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + matchConvDimExprPattern(indexingMaps, inputMapIdx, 3, outputMapIdx, 3) && bodyMatcherForSumPoolOps(yieldVal, body)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, - tempDilations, tempStrides); + if (returnVal) + updateConvDilationsAndStrides(dilations, strides, tempDilations, + tempStrides); + return returnVal; } +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> template <> bool 
isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, @@ -650,27 +722,29 @@ bool isaConvolutionOpOfType( Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); - unsigned iIndex = 0, oIndex = 2; + unsigned inputMapIdx = 0, outputMapIdx = 2; SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> - // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)> - // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> bool returnVal = - (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + matchConvDimExprPattern(indexingMaps, inputMapIdx, 3, outputMapIdx, 3) && bodyMatcherForMaxUnsignedPoolOps(yieldVal, body)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, - tempDilations, tempStrides); + if (returnVal) + updateConvDilationsAndStrides(dilations, strides, tempDilations, + tempStrides); + return returnVal; } +// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> +// #filterMap = affine_map<(N, H, W, C, h, w) -> (h, w)> +// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> template <> bool isaConvolutionOpOfType( LinalgOp op, SmallVector *dilations, @@ -688,25 +762,24 @@ bool isaConvolutionOpOfType( Block *body = op.getBlock(); auto yieldOp = cast(body->getTerminator()); Value yieldVal = yieldOp.getOperand(0); - unsigned iIndex = 0, oIndex = 2; + unsigned inputMapIdx = 0, outputMapIdx = 2; SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - // #map = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> - // #map1 = affine_map<(N, H, W, C, h, w) -> (h, w)> - // #map2 = affine_map<(N, H, W, C, h, w) -> (N, H, W, C)> bool returnVal = - (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) && + (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1, tempDilations[0], tempStrides[0]) && matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2, tempDilations[1], tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) && + matchConvDimExprPattern(indexingMaps, inputMapIdx, 3, outputMapIdx, 3) && bodyMatcherForMinUnsignedPoolOps(yieldVal, body)); - return returnVal && updateConvDilationsAndStrides(dilations, strides, - tempDilations, tempStrides); + if (returnVal) + updateConvDilationsAndStrides(dilations, strides, tempDilations, + tempStrides); + return returnVal; } Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, From ab8eb8f5354aa0d3436f47cabfacd228c5cc5ea4 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Tue, 4 Nov 2025 04:09:18 -0600 Subject: [PATCH 07/13] Doc comment + function signature change --- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 73 ++++++++++++++----------- 1 file changed, 40 insertions(+), 33 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 0be2668a9b346..53669542cdb91 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ 
-338,46 +338,53 @@ static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
   return nullptr;
 }
 
-// Check if `expr` is either:
-// - a dimension expr alone (implying *1), or
-// - a multiplication of dimension expr by constant.
-static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim,
-                                        int64_t &constantValue) {
-  if (auto dExpr = dyn_cast<AffineDimExpr>(expr)) {
-    dim = dExpr;
-    constantValue = 1;
-    return true;
-  }
+/// Check if `expr` is either:
+/// - a dimension expr alone (implying multiplication by 1), or
+/// - a multiplication of a dimension expr by a constant.
+/// In both cases the dimension expression is captured into `dim` and the
+/// constant multiplier is returned. Returns -1 in case of a match failure.
+static int64_t isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim) {
+  if ((dim = dyn_cast<AffineDimExpr>(expr)))
+    return 1;
   auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
   if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
-    return false;
+    return -1;
 
   AffineExpr lhs = mulExpr.getLHS();
   AffineExpr rhs = mulExpr.getRHS();
-  if (auto dExpr = dyn_cast<AffineDimExpr>(lhs)) {
-    if (auto cst = dyn_cast<AffineConstantExpr>(rhs)) {
-      dim = dExpr;
-      constantValue = cst.getValue();
-      return true;
-    }
-  }
-  if (auto cst = dyn_cast<AffineConstantExpr>(lhs)) {
-    if (auto dExpr = dyn_cast<AffineDimExpr>(rhs)) {
-      dim = dExpr;
-      constantValue = cst.getValue();
-      return true;
-    }
-  }
-  return false;
+  AffineConstantExpr cst = nullptr;
+  if (((dim = dyn_cast<AffineDimExpr>(lhs)) &&
+       (cst = dyn_cast<AffineConstantExpr>(rhs))) ||
+      ((dim = dyn_cast<AffineDimExpr>(rhs)) &&
+       (cst = dyn_cast<AffineConstantExpr>(lhs))))
+    return cst.getValue();
+  return -1;
 }
 
-/// Given an array of AffineMaps `indexingMaps` verify the following :-
+/// Given an array of AffineMaps `indexingMaps`, verify the following
+/// (commutatively):
 /// indexingMaps[0].getResult(iDim) ==
-///    indexingMaps[1].getResult(fDim) * <CST_1> +
-///    indexingMaps[n-1].getResult(oDim) * <CST_2>
-/// where, CST_1 and CST_2 can be any constant.
+///    indexingMaps[1].getResult(fDim) * <c0> +
+///    indexingMaps[n-1].getResult(oDim) * <c1>
+/// where,
+///   - c0 and c1 can be any constant,
+///   - n is the size of the indexingMaps' array,
+///   - 0, 1 and n-1 are input, filter and output map indices respectively,
+///   - iDim, fDim and oDim are the input, filter and output dimension
+///     indices in their respective indexing maps
+/// Example:
+///   #inputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6)
+///                            -> (d0, d1 * 2 + d4 * 3, d2 + d5, d6)>
+///   #filterMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
+///   #outputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+///
+/// Here,
+///   #inputMap[1] = #outputMap[1] * 2 + #filterMap[0] * 3
+/// Therefore,
+///   matchConvDimAddExprPattern(indexingMaps, 1, 0, 1, dilation, stride)
+/// would return true and update dilation = 3 and stride = 2.
 static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
                                        unsigned fDim, unsigned oDim,
                                        int64_t &dilation, int64_t &stride) {
@@ -389,10 +396,10 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
     return false;
 
   AffineExpr dim0, dim1;
-  int64_t c0, c1;
+  int64_t c0 = isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0);
+  int64_t c1 = isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1);
 
-  if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) &&
-      isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) {
+  if (c0 != -1 && c1 != -1) {
     // Pattern matched with dims and constants extracted.
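    // Worked instance of the doc-comment example above: for the input-map
    // result `d1 * 2 + d4 * 3`, the LHS call yields c0 = 2 capturing
    // dim0 = d1, and the RHS call yields c1 = 3 capturing dim1 = d4. The
    // dim-to-map pairing below then decides which constant is the dilation
    // (filter side) and which is the stride (output side).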
AffineExpr fExpr = getAffineMapDim(indexingMaps, filterMapIdx, fDim); AffineExpr oExpr = getAffineMapDim(indexingMaps, outputMapIdx, oDim); From c39f831f3bacde21a9c53b164c0a7d7ab2a28a14 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Fri, 7 Nov 2025 14:50:18 -0600 Subject: [PATCH 08/13] Early exit + indirection removal + lit test updates --- .../include/mlir/Dialect/Linalg/Utils/Utils.h | 3 + .../Dialect/Linalg/Transforms/Specialize.cpp | 4 +- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 322 +++++++++++------- .../convolution/roundtrip-convolution.mlir | 91 ++--- 4 files changed, 241 insertions(+), 179 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 771d753a8bddb..87c69a4fc57df 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -114,6 +114,9 @@ getReassociationMapForFoldingUnitDims(ArrayRef mixedSizes); // Convolution matcher utility //===----------------------------------------------------------------------===// +/// Given a linalg `op` this function returns true if it is a convolution op of +/// type `ConvOpTy` and populate the optional `dilations` and `strides` +/// arguments, if present. template bool isaConvolutionOpOfType(LinalgOp op, SmallVector *dilations = nullptr, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 0b3662c888010..77c49019f4eb9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -264,9 +264,7 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, return namedOp; } -// Converts linalg.generic to named linalg.*conv/pooling* where possible. To -// improve the search speed, the convolution ops have been segregated based on -// the rank of iterator types array. +// Converts linalg.generic to named linalg.*conv/pooling* where possible. static FailureOr specializeLinalgConvolutions(RewriterBase &rewriter, GenericOp genericOp) { SmallVector dilations, strides; diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 53669542cdb91..3b301acd079db 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -260,8 +260,12 @@ static BlockArgument getBlockArgumentWithOptionalExtOps(Value val) { return blockArg; } -/// Utility to match block body for matmul-like ops. -static bool bodyMatcherForMatmulLikeOps(Value yieldVal, Block *body) { +/// Utility to match block body for convolution ops. +/// The body is thus expected to yield :- +/// %out + (%lhs * %rhs) +/// where: %lhs, %rhs and %out are block arguments and +/// %lhs and %rhs can have optional upcast operation. 
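// Illustrative region this matcher accepts, assuming an f16 -> f32 widening
// convolution (the casts are optional; per getBlockArgumentWithOptionalExtOps
// they may be any of the arith ext* family):
//   ^bb0(%lhs: f16, %rhs: f16, %out: f32):
//     %0 = arith.extf %lhs : f16 to f32
//     %1 = arith.extf %rhs : f16 to f32
//     %2 = arith.mulf %0, %1 : f32
//     %3 = arith.addf %out, %2 : f32
//     linalg.yield %3 : f32
// %lhs, %rhs and %out must be block arguments 0, 1 and 2 of the body.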
+static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body) { Operation *addOp = yieldVal.getDefiningOp(); if (!isa_and_present(addOp)) return false; @@ -291,8 +295,10 @@ static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) { if (!(isa_and_present(defOp) || ...)) return false; - BlockArgument lhsArg = dyn_cast(defOp->getOperand(0)); - BlockArgument rhsArg = dyn_cast(defOp->getOperand(1)); + BlockArgument lhsArg = + getBlockArgumentWithOptionalExtOps(defOp->getOperand(0)); + BlockArgument rhsArg = + getBlockArgumentWithOptionalExtOps(defOp->getOperand(1)); if (!lhsArg || !rhsArg || lhsArg.getOwner() != body || rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 || rhsArg.getArgNumber() != 0) @@ -416,16 +422,6 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, return false; } -/// Given an array of AffineMaps `indexingMaps` verify the following :- -/// indexingMaps[aIndex].getResult(aDim) == -/// indexingMaps[bIndex].getResult(bDim) -static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, - unsigned aDim, unsigned bIndex, - unsigned bDim) { - return getAffineMapDim(indexingMaps, aIndex, aDim) == - getAffineMapDim(indexingMaps, bIndex, bDim); -} - /// Give an array of AffineMaps, verify each map to be of the corresponding /// `expectedSize`. static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, @@ -485,18 +481,26 @@ bool isaConvolutionOpOfType( SmallVector tempDilations(1, 1); SmallVector tempStrides(1, 1); - bool returnVal = - (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && - matchConvDimExprPattern(indexingMaps, inputMapIdx, 2, filterMapIdx, 1) && - matchConvDimExprPattern(indexingMaps, inputMapIdx, 2, outputMapIdx, 2) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, - /*oDim=*/1, tempDilations[0], - tempStrides[0]) && - bodyMatcherForMatmulLikeOps(yieldVal, body)); - if (returnVal) - updateConvDilationsAndStrides(dilations, strides, tempDilations, - tempStrides); - return returnVal; + // Match: N + if (getAffineMapDim(indexingMaps, inputMapIdx, 0) != + getAffineMapDim(indexingMaps, outputMapIdx, 0)) + return false; + // Match: C + if (getAffineMapDim(indexingMaps, inputMapIdx, 2) != + getAffineMapDim(indexingMaps, filterMapIdx, 1)) + return false; + if (getAffineMapDim(indexingMaps, inputMapIdx, 2) != + getAffineMapDim(indexingMaps, outputMapIdx, 2)) + return false; + // Match: W + w + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], tempStrides[0])) + return false; + // Match body + if (!bodyMatcherForConvolutionOps(yieldVal, body)) + return false; + updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); + return true; } // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)> @@ -523,21 +527,30 @@ bool isaConvolutionOpOfType( SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - bool returnVal = - (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && - matchConvDimExprPattern(indexingMaps, inputMapIdx, 1, filterMapIdx, 0) && - matchConvDimExprPattern(indexingMaps, inputMapIdx, 1, outputMapIdx, 1) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, - /*oDim=*/2, tempDilations[0], - tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, - /*oDim=*/3, tempDilations[1], - tempStrides[1]) && - bodyMatcherForMatmulLikeOps(yieldVal, body)); - if (returnVal) - 
updateConvDilationsAndStrides(dilations, strides, tempDilations, - tempStrides); - return returnVal; + // Match: N + if (getAffineMapDim(indexingMaps, inputMapIdx, 0) != + getAffineMapDim(indexingMaps, outputMapIdx, 0)) + return false; + // Match: C + if (getAffineMapDim(indexingMaps, inputMapIdx, 1) != + getAffineMapDim(indexingMaps, filterMapIdx, 0)) + return false; + if (getAffineMapDim(indexingMaps, inputMapIdx, 1) != + getAffineMapDim(indexingMaps, outputMapIdx, 1)) + return false; + // Match: H + h + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[0], tempStrides[0])) + return false; + // Match: W + w + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[1], tempStrides[1])) + return false; + // Match body + if (!bodyMatcherForConvolutionOps(yieldVal, body)) + return false; + updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); + return true; } // #inputMap = affine_map<(N, D, H, W, CM, d, h, w, C) @@ -567,26 +580,38 @@ bool isaConvolutionOpOfType( SmallVector tempDilations(3, 1); SmallVector tempStrides(3, 1); - bool returnVal = - (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, - /*oDim=*/1, tempDilations[0], - tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, - /*oDim=*/2, tempDilations[1], - tempStrides[1]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, - /*oDim=*/3, tempDilations[2], - tempStrides[2]) && - matchConvDimExprPattern(indexingMaps, inputMapIdx, 4, filterMapIdx, 3) && - matchConvDimExprPattern(indexingMaps, inputMapIdx, 4, outputMapIdx, 4) && - matchConvDimExprPattern(indexingMaps, filterMapIdx, 4, outputMapIdx, - 5) && - bodyMatcherForMatmulLikeOps(yieldVal, body)); - if (returnVal) - updateConvDilationsAndStrides(dilations, strides, tempDilations, - tempStrides); - return returnVal; + // Match: N + if (getAffineMapDim(indexingMaps, inputMapIdx, 0) != + getAffineMapDim(indexingMaps, outputMapIdx, 0)) + return false; + // Match: D + d + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], tempStrides[0])) + return false; + // Match: H + h + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], tempStrides[1])) + return false; + // Match: W + w + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, + /*oDim=*/3, tempDilations[2], tempStrides[2])) + return false; + // Match: C + if (getAffineMapDim(indexingMaps, inputMapIdx, 4) != + getAffineMapDim(indexingMaps, filterMapIdx, 3)) + return false; + if (getAffineMapDim(indexingMaps, inputMapIdx, 4) != + getAffineMapDim(indexingMaps, outputMapIdx, 4)) + return false; + // Match: CM + if (getAffineMapDim(indexingMaps, filterMapIdx, 4) != + getAffineMapDim(indexingMaps, outputMapIdx, 5)) + return false; + // Match body + if (!bodyMatcherForConvolutionOps(yieldVal, body)) + return false; + updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); + return true; } // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> @@ -613,20 +638,27 @@ bool isaConvolutionOpOfType( SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - bool returnVal = - (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, - /*oDim=*/1, 
tempDilations[0], - tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, - /*oDim=*/2, tempDilations[1], - tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, inputMapIdx, 3, outputMapIdx, 3) && - bodyMatcherForMaxSignedPoolOps(yieldVal, body)); - if (returnVal) - updateConvDilationsAndStrides(dilations, strides, tempDilations, - tempStrides); - return returnVal; + // Match: N + if (getAffineMapDim(indexingMaps, inputMapIdx, 0) != + getAffineMapDim(indexingMaps, outputMapIdx, 0)) + return false; + // Match: H + h + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], tempStrides[0])) + return false; + // Match: W + w + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], tempStrides[1])) + return false; + // Match: C + if (getAffineMapDim(indexingMaps, inputMapIdx, 3) != + getAffineMapDim(indexingMaps, outputMapIdx, 3)) + return false; + // Match body + if (!bodyMatcherForMaxSignedPoolOps(yieldVal, body)) + return false; + updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); + return true; } // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> @@ -653,20 +685,27 @@ bool isaConvolutionOpOfType( SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - bool returnVal = - (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, - /*oDim=*/1, tempDilations[0], - tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, - /*oDim=*/2, tempDilations[1], - tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, inputMapIdx, 3, outputMapIdx, 3) && - bodyMatcherForMinSignedPoolOps(yieldVal, body)); - if (returnVal) - updateConvDilationsAndStrides(dilations, strides, tempDilations, - tempStrides); - return returnVal; + // Match: N + if (getAffineMapDim(indexingMaps, inputMapIdx, 0) != + getAffineMapDim(indexingMaps, outputMapIdx, 0)) + return false; + // Match: H + h + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], tempStrides[0])) + return false; + // Match: W + w + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], tempStrides[1])) + return false; + // Match: C + if (getAffineMapDim(indexingMaps, inputMapIdx, 3) != + getAffineMapDim(indexingMaps, outputMapIdx, 3)) + return false; + // Match body + if (!bodyMatcherForMinSignedPoolOps(yieldVal, body)) + return false; + updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); + return true; } // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> @@ -693,20 +732,27 @@ bool isaConvolutionOpOfType( SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - bool returnVal = - (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, - /*oDim=*/1, tempDilations[0], - tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, - /*oDim=*/2, tempDilations[1], - tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, inputMapIdx, 3, outputMapIdx, 3) && - bodyMatcherForSumPoolOps(yieldVal, body)); - if (returnVal) - updateConvDilationsAndStrides(dilations, strides, tempDilations, - tempStrides); - return returnVal; + // Match: N + if (getAffineMapDim(indexingMaps, inputMapIdx, 0) != + 
getAffineMapDim(indexingMaps, outputMapIdx, 0)) + return false; + // Match: H + h + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], tempStrides[0])) + return false; + // Match: W + w + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], tempStrides[1])) + return false; + // Match: C + if (getAffineMapDim(indexingMaps, inputMapIdx, 3) != + getAffineMapDim(indexingMaps, outputMapIdx, 3)) + return false; + // Match body + if (!bodyMatcherForSumPoolOps(yieldVal, body)) + return false; + updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); + return true; } // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> @@ -733,20 +779,27 @@ bool isaConvolutionOpOfType( SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - bool returnVal = - (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, - /*oDim=*/1, tempDilations[0], - tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, - /*oDim=*/2, tempDilations[1], - tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, inputMapIdx, 3, outputMapIdx, 3) && - bodyMatcherForMaxUnsignedPoolOps(yieldVal, body)); - if (returnVal) - updateConvDilationsAndStrides(dilations, strides, tempDilations, - tempStrides); - return returnVal; + // Match: N + if (getAffineMapDim(indexingMaps, inputMapIdx, 0) != + getAffineMapDim(indexingMaps, outputMapIdx, 0)) + return false; + // Match: H + h + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], tempStrides[0])) + return false; + // Match: W + w + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], tempStrides[1])) + return false; + // Match: C + if (getAffineMapDim(indexingMaps, inputMapIdx, 3) != + getAffineMapDim(indexingMaps, outputMapIdx, 3)) + return false; + // Match body + if (!bodyMatcherForMaxUnsignedPoolOps(yieldVal, body)) + return false; + updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); + return true; } // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, H + h, W + w, C)> @@ -773,20 +826,27 @@ bool isaConvolutionOpOfType( SmallVector tempDilations(2, 1); SmallVector tempStrides(2, 1); - bool returnVal = - (matchConvDimExprPattern(indexingMaps, inputMapIdx, 0, outputMapIdx, 0) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, - /*oDim=*/1, tempDilations[0], - tempStrides[0]) && - matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, - /*oDim=*/2, tempDilations[1], - tempStrides[1]) && - matchConvDimExprPattern(indexingMaps, inputMapIdx, 3, outputMapIdx, 3) && - bodyMatcherForMinUnsignedPoolOps(yieldVal, body)); - if (returnVal) - updateConvDilationsAndStrides(dilations, strides, tempDilations, - tempStrides); - return returnVal; + // Match: N + if (getAffineMapDim(indexingMaps, inputMapIdx, 0) != + getAffineMapDim(indexingMaps, outputMapIdx, 0)) + return false; + // Match: H + h + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, tempDilations[0], tempStrides[0])) + return false; + // Match: W + w + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, tempDilations[1], tempStrides[1])) + return false; + // Match: C + if (getAffineMapDim(indexingMaps, inputMapIdx, 3) != + getAffineMapDim(indexingMaps, outputMapIdx, 
3)) + return false; + // Match body + if (!bodyMatcherForMinUnsignedPoolOps(yieldVal, body)) + return false; + updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); + return true; } Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir index 06c9a84049d81..97ca4d30d0d94 100644 --- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir +++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir @@ -2,12 +2,12 @@ // lifted back up to named op. // RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s -func.func @depthwise_conv_1d_nwc_wc(%input: memref, %filter: memref, %output: memref) { - linalg.depthwise_conv_1d_nwc_wc {dilations = dense<3> : tensor<1xi64>, - strides = dense<2> : tensor<1xi64>} - ins (%input, %filter: memref, memref) - outs (%output: memref) - return +func.func @depthwise_conv_1d_nwc_wc(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.depthwise_conv_1d_nwc_wc + {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor } // CHECK: @depthwise_conv_1d_nwc_wc // CHECK: linalg.depthwise_conv_1d_nwc_wc @@ -16,10 +16,11 @@ func.func @depthwise_conv_1d_nwc_wc(%input: memref, %filter: memref, %filter: tensor, %init: tensor) -> tensor { - %0 = linalg.depthwise_conv_2d_nchw_chw {dilations = dense<[2,3]> : vector<2xi64>, strides = dense<[4,5]> : vector<2xi64>} - ins (%input, %filter: tensor, tensor) - outs (%init: tensor) -> tensor +func.func @depthwise_conv_2d_nchw_chw(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.depthwise_conv_2d_nchw_chw + {dilations = dense<[2,3]> : vector<2xi64>, strides = dense<[4,5]> : vector<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor return %0 : tensor } // CHECK: @depthwise_conv_2d_nchw_chw @@ -29,11 +30,11 @@ func.func @depthwise_conv_2d_nchw_chw(%input: tensor, %filter: tens // ----- -func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor, %filter: tensor, %init: tensor) -> tensor { - %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm {dilations = dense<1> : tensor<3xi64>, - strides = dense<1> : tensor<3xi64>} - ins (%input, %filter: tensor, tensor) - outs (%init: tensor) -> tensor +func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm + {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor return %0 : tensor } // CHECK: @depthwise_conv_3d_ndhwc_dhwcm @@ -43,11 +44,11 @@ func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor, %filter: // ----- -func.func @pooling_nhwc_max(%input: tensor, %filter: tensor, %init: tensor) -> tensor { - %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, - strides = dense<1> : tensor<2xi64>} - ins (%input, %filter: tensor, tensor) - outs (%init: tensor) -> tensor +func.func @pooling_nhwc_max(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.pooling_nhwc_max + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor return %0 : tensor } // CHECK: 
@pooling_nhwc_max @@ -57,11 +58,11 @@ func.func @pooling_nhwc_max(%input: tensor, %filter: tensor, %filter: tensor, %init: tensor) -> tensor { - %0 = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>, - strides = dense<1> : tensor<2xi64>} - ins (%input, %filter: tensor, tensor) - outs (%init: tensor) -> tensor +func.func @pooling_nhwc_min(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.pooling_nhwc_min + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor return %0 : tensor } // CHECK: @pooling_nhwc_min @@ -71,11 +72,11 @@ func.func @pooling_nhwc_min(%input: tensor, %filter: tensor, %filter: tensor, %init: tensor) -> tensor { - %0 = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>, - strides = dense<1> : tensor<2xi64>} - ins (%input, %filter: tensor, tensor) - outs (%init: tensor) -> tensor +func.func @pooling_nhwc_sum(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.pooling_nhwc_sum + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor return %0 : tensor } // CHECK: @pooling_nhwc_sum @@ -85,11 +86,11 @@ func.func @pooling_nhwc_sum(%input: tensor, %filter: tensor, %filter: tensor, %init: tensor) -> tensor { - %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>, - strides = dense<1> : tensor<2xi64>} - ins (%input, %filter: tensor, tensor) - outs (%init: tensor) -> tensor +func.func @pooling_nhwc_max_unsigned(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.pooling_nhwc_max_unsigned + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor return %0 : tensor } // CHECK: @pooling_nhwc_max_unsigned @@ -99,11 +100,11 @@ func.func @pooling_nhwc_max_unsigned(%input: tensor, %filter: tenso // ----- -func.func @pooling_nhwc_min_unsigned_integer(%input: tensor, %filter: tensor, %init: tensor) -> tensor { - %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, - strides = dense<1> : tensor<2xi64>} - ins (%input, %filter: tensor, tensor) - outs (%init: tensor) -> tensor +func.func @pooling_nhwc_min_unsigned_integer(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.pooling_nhwc_min_unsigned + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor return %0 : tensor } // CHECK: @pooling_nhwc_min_unsigned_integer @@ -111,11 +112,11 @@ func.func @pooling_nhwc_min_unsigned_integer(%input: tensor, %filte // CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> // CHECK-NOT: linalg.generic -func.func @pooling_nhwc_min_unsigned_float(%input: tensor, %filter: tensor, %init: tensor) -> tensor { - %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, - strides = dense<1> : tensor<2xi64>} - ins (%input, %filter: tensor, tensor) - outs (%init: tensor) -> tensor +func.func @pooling_nhwc_min_unsigned_float(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.pooling_nhwc_min_unsigned + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor return %0 : tensor } // CHECK: 
@pooling_nhwc_min_unsigned_float

From c8b1c57360c726be41e755c6840a786e102b6e4d Mon Sep 17 00:00:00 2001
From: Abhishek Varma
Date: Mon, 10 Nov 2025 01:24:11 -0600
Subject: [PATCH 09/13] Use macro in specialize.cpp pass

---
 .../Dialect/Linalg/Transforms/Specialize.cpp | 45 ++++++------------
 .../convolution/roundtrip-convolution.mlir   |  2 +
 2 files changed, 15 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 77c49019f4eb9..8b0f192a39781 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -268,44 +268,25 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
                                                         GenericOp genericOp) {
   SmallVector<int64_t> dilations, strides;
+#define CONV_OP_SPECIALIZER(ConvOpTy)                                          \
+  if (isaConvolutionOpOfType<ConvOpTy>(genericOp, &dilations, &strides))       \
+    return specializeToConvOp<ConvOpTy>(rewriter, genericOp, dilations,        \
+                                        strides);
   // -----------------------------
   // Depthwise Convolution ops.
   // -----------------------------
-  if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
-          genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
-        rewriter, genericOp, dilations, strides);
-  if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
-          genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
-        rewriter, genericOp, dilations, strides);
-  if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
-          genericOp, &dilations, &strides))
-    return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
-        rewriter, genericOp, dilations, strides);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
+  CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
   // -----------------------------
   // Pooling ops.
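// A common hardening for function-like macros that expand to a bare `if`
// (a sketch, not part of this patch) is to wrap the expansion so it also
// composes safely under an unbraced if/else at the call site:
//
//   #define CONV_OP_SPECIALIZER(ConvOpTy)                                     \
//     do {                                                                    \
//       if (isaConvolutionOpOfType<ConvOpTy>(genericOp, &dilations,           \
//                                            &strides))                       \
//         return specializeToConvOp<ConvOpTy>(rewriter, genericOp, dilations, \
//                                             strides);                       \
//     } while (0)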
// ----------------------------- - if (isaConvolutionOpOfType(genericOp, &dilations, - &strides)) - return specializeToConvOp(rewriter, genericOp, - dilations, strides); - if (isaConvolutionOpOfType(genericOp, &dilations, - &strides)) - return specializeToConvOp(rewriter, genericOp, - dilations, strides); - if (isaConvolutionOpOfType(genericOp, &dilations, - &strides)) - return specializeToConvOp(rewriter, genericOp, - dilations, strides); - if (isaConvolutionOpOfType( - genericOp, &dilations, &strides)) - return specializeToConvOp( - rewriter, genericOp, dilations, strides); - if (isaConvolutionOpOfType( - genericOp, &dilations, &strides)) - return specializeToConvOp( - rewriter, genericOp, dilations, strides); + CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxOp); + CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinOp); + CONV_OP_SPECIALIZER(linalg::PoolingNhwcSumOp); + CONV_OP_SPECIALIZER(linalg::PoolingNhwcMaxUnsignedOp); + CONV_OP_SPECIALIZER(linalg::PoolingNhwcMinUnsignedOp); +#undef CONV_OP_SPECIALIZER return failure(); } diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir index 97ca4d30d0d94..04374fcc2e9ed 100644 --- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir +++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir @@ -112,6 +112,8 @@ func.func @pooling_nhwc_min_unsigned_integer(%input: tensor, %filte // CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> // CHECK-NOT: linalg.generic +// ----- + func.func @pooling_nhwc_min_unsigned_float(%input: tensor, %filter: tensor, %output: tensor) -> tensor { %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} From f557fcae3dfa5450476106dd84382fe8e21d0c42 Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Tue, 11 Nov 2025 03:17:20 -0600 Subject: [PATCH 10/13] Easier dilations/strides update --- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 3b301acd079db..5402820e3b6d5 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -446,10 +446,8 @@ static void updateConvDilationsAndStrides(SmallVector *dilations, ArrayRef tempStrides) { if (!(dilations && strides)) return; - for (auto [dilation, stride] : llvm::zip(tempDilations, tempStrides)) { - dilations->push_back(dilation); - strides->push_back(stride); - } + *dilations = SmallVector(tempDilations); + *strides = SmallVector(tempStrides); return; } From c82c3d38518175ab2068e0d375bddb4c92ba3c1a Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Tue, 11 Nov 2025 03:58:22 -0600 Subject: [PATCH 11/13] Make dilations/strides mandatory --- .../include/mlir/Dialect/Linalg/Utils/Utils.h | 8 ++-- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 37 ++++++++----------- 2 files changed, 19 insertions(+), 26 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 87c69a4fc57df..3c8f222d62525 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -115,12 +115,10 @@ getReassociationMapForFoldingUnitDims(ArrayRef mixedSizes); //===----------------------------------------------------------------------===// /// Given a linalg `op` this function returns true if it is a convolution op of 
-/// type `ConvOpTy` and populate the optional `dilations` and `strides` -/// arguments, if present. +/// type `ConvOpTy` and populate `dilations` and `strides` arguments. template -bool isaConvolutionOpOfType(LinalgOp op, - SmallVector *dilations = nullptr, - SmallVector *strides = nullptr); +bool isaConvolutionOpOfType(LinalgOp op, SmallVector *dilations, + SmallVector *strides); //===----------------------------------------------------------------------===// // Fusion / Tiling utilities diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 5402820e3b6d5..8dc4952172b92 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -438,19 +438,6 @@ static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, return true; } -/// Utility to update `dilations` and `strides` by copy the corresponding data -/// from `tempDilations` and `tempStrides`. -static void updateConvDilationsAndStrides(SmallVector *dilations, - SmallVector *strides, - ArrayRef tempDilations, - ArrayRef tempStrides) { - if (!(dilations && strides)) - return; - *dilations = SmallVector(tempDilations); - *strides = SmallVector(tempStrides); - return; -} - // --------------------------------------------- // Matchers for specific convolution operation. // --------------------------------------------- @@ -497,7 +484,8 @@ bool isaConvolutionOpOfType( // Match body if (!bodyMatcherForConvolutionOps(yieldVal, body)) return false; - updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); + *dilations = SmallVector(tempDilations); + *strides = SmallVector(tempStrides); return true; } @@ -547,7 +535,8 @@ bool isaConvolutionOpOfType( // Match body if (!bodyMatcherForConvolutionOps(yieldVal, body)) return false; - updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); + *dilations = SmallVector(tempDilations); + *strides = SmallVector(tempStrides); return true; } @@ -608,7 +597,8 @@ bool isaConvolutionOpOfType( // Match body if (!bodyMatcherForConvolutionOps(yieldVal, body)) return false; - updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); + *dilations = SmallVector(tempDilations); + *strides = SmallVector(tempStrides); return true; } @@ -655,7 +645,8 @@ bool isaConvolutionOpOfType( // Match body if (!bodyMatcherForMaxSignedPoolOps(yieldVal, body)) return false; - updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); + *dilations = SmallVector(tempDilations); + *strides = SmallVector(tempStrides); return true; } @@ -702,7 +693,8 @@ bool isaConvolutionOpOfType( // Match body if (!bodyMatcherForMinSignedPoolOps(yieldVal, body)) return false; - updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); + *dilations = SmallVector(tempDilations); + *strides = SmallVector(tempStrides); return true; } @@ -749,7 +741,8 @@ bool isaConvolutionOpOfType( // Match body if (!bodyMatcherForSumPoolOps(yieldVal, body)) return false; - updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); + *dilations = SmallVector(tempDilations); + *strides = SmallVector(tempStrides); return true; } @@ -796,7 +789,8 @@ bool isaConvolutionOpOfType( // Match body if (!bodyMatcherForMaxUnsignedPoolOps(yieldVal, body)) return false; - updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); + *dilations = SmallVector(tempDilations); + *strides = SmallVector(tempStrides); return true; } @@ -843,7 +837,8 @@ bool 
isaConvolutionOpOfType( // Match body if (!bodyMatcherForMinUnsignedPoolOps(yieldVal, body)) return false; - updateConvDilationsAndStrides(dilations, strides, tempDilations, tempStrides); + *dilations = SmallVector(tempDilations); + *strides = SmallVector(tempStrides); return true; } From e811d48f8470e38abbe6625db290dc949ef46a0a Mon Sep 17 00:00:00 2001 From: Abhishek Varma Date: Tue, 11 Nov 2025 12:23:51 -0600 Subject: [PATCH 12/13] Tests reviews + clean up more --- .../include/mlir/Dialect/Linalg/Utils/Utils.h | 3 +- .../Dialect/Linalg/Transforms/Specialize.cpp | 16 +- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 144 ++++++++---------- .../convolution/roundtrip-convolution.mlir | 19 +-- 4 files changed, 75 insertions(+), 107 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 3c8f222d62525..d75bba6452dad 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -115,7 +115,8 @@ getReassociationMapForFoldingUnitDims(ArrayRef mixedSizes); //===----------------------------------------------------------------------===// /// Given a linalg `op` this function returns true if it is a convolution op of -/// type `ConvOpTy` and populate `dilations` and `strides` arguments. +/// type `ConvOpTy` and populates `dilations` and `strides` with values inferred +/// from the indexing maps. template bool isaConvolutionOpOfType(LinalgOp op, SmallVector *dilations, SmallVector *strides); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 8b0f192a39781..56a8ee9b96db9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -249,18 +249,10 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, SmallVector resultTypes = genericOp.hasPureTensorSemantics() ? 
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 8dc4952172b92..58666a69ef492 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -453,19 +453,16 @@ bool isaConvolutionOpOfType(
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
-         "expected linalgOp to implement ConvolutionOpInterface");
+         "expected op to implement ConvolutionOpInterface");
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
   if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3}))
     return false;
 
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-
   unsigned inputMapIdx = 0, filterMapIdx = 1, outputMapIdx = 2;
-  SmallVector<int64_t> tempDilations(1, 1);
-  SmallVector<int64_t> tempStrides(1, 1);
+  *dilations = SmallVector<int64_t>(1, 1);
+  *strides = SmallVector<int64_t>(1, 1);
   // Match: N
   if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
       getAffineMapDim(indexingMaps, outputMapIdx, 0))
@@ -479,13 +476,14 @@ bool isaConvolutionOpOfType(
     return false;
   // Match: W + w
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, tempDilations[0], tempStrides[0]))
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
   // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
   if (!bodyMatcherForConvolutionOps(yieldVal, body))
     return false;
-  *dilations = SmallVector<int64_t>(tempDilations);
-  *strides = SmallVector<int64_t>(tempStrides);
   return true;
 }
 
@@ -500,19 +498,16 @@ bool isaConvolutionOpOfType(
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
-         "expected linalgOp to implement ConvolutionOpInterface");
+         "expected op to implement ConvolutionOpInterface");
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
   if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4}))
     return false;
 
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-
   unsigned inputMapIdx = 0, filterMapIdx = 1, outputMapIdx = 2;
-  SmallVector<int64_t> tempDilations(2, 1);
-  SmallVector<int64_t> tempStrides(2, 1);
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
   // Match: N
   if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
       getAffineMapDim(indexingMaps, outputMapIdx, 0))
@@ -526,17 +521,18 @@ bool isaConvolutionOpOfType(
     return false;
   // Match: H + h
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, tempDilations[0], tempStrides[0]))
+                                  /*oDim=*/2, (*dilations)[0], (*strides)[0]))
     return false;
   // Match: W + w
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
-                                  /*oDim=*/3, tempDilations[1], tempStrides[1]))
+                                  /*oDim=*/3, (*dilations)[1], (*strides)[1]))
     return false;
   // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
   if (!bodyMatcherForConvolutionOps(yieldVal, body))
     return false;
-  *dilations = SmallVector<int64_t>(tempDilations);
-  *strides = SmallVector<int64_t>(tempStrides);
   return true;
 }
@@ -554,34 +550,31 @@ bool isaConvolutionOpOfType(
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
-         "expected linalgOp to implement ConvolutionOpInterface");
+         "expected op to implement ConvolutionOpInterface");
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
   if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6}))
     return false;
 
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-
   unsigned inputMapIdx = 0, filterMapIdx = 1, outputMapIdx = 2;
-  SmallVector<int64_t> tempDilations(3, 1);
-  SmallVector<int64_t> tempStrides(3, 1);
+  *dilations = SmallVector<int64_t>(3, 1);
+  *strides = SmallVector<int64_t>(3, 1);
   // Match: N
   if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
       getAffineMapDim(indexingMaps, outputMapIdx, 0))
     return false;
   // Match: D + d
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, tempDilations[0], tempStrides[0]))
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
   // Match: H + h
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, tempDilations[1], tempStrides[1]))
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
   // Match: W + w
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
-                                  /*oDim=*/3, tempDilations[2], tempStrides[2]))
+                                  /*oDim=*/3, (*dilations)[2], (*strides)[2]))
     return false;
   // Match: C
   if (getAffineMapDim(indexingMaps, inputMapIdx, 4) !=
@@ -595,10 +588,11 @@ bool isaConvolutionOpOfType(
       getAffineMapDim(indexingMaps, outputMapIdx, 5))
     return false;
   // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
   if (!bodyMatcherForConvolutionOps(yieldVal, body))
     return false;
-  *dilations = SmallVector<int64_t>(tempDilations);
-  *strides = SmallVector<int64_t>(tempStrides);
   return true;
 }
 
@@ -613,40 +607,38 @@ bool isaConvolutionOpOfType(
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
-         "expected linalgOp to implement ConvolutionOpInterface");
+         "expected op to implement ConvolutionOpInterface");
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
   if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
     return false;
 
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-
   unsigned inputMapIdx = 0, outputMapIdx = 2;
-  SmallVector<int64_t> tempDilations(2, 1);
-  SmallVector<int64_t> tempStrides(2, 1);
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
   // Match: N
   if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
       getAffineMapDim(indexingMaps, outputMapIdx, 0))
     return false;
   // Match: H + h
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, tempDilations[0], tempStrides[0]))
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
   // Match: W + w
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, tempDilations[1], tempStrides[1]))
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
   // Match: C
   if (getAffineMapDim(indexingMaps, inputMapIdx, 3) !=
       getAffineMapDim(indexingMaps, outputMapIdx, 3))
     return false;
   // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
   if (!bodyMatcherForMaxSignedPoolOps(yieldVal, body))
     return false;
-  *dilations = SmallVector<int64_t>(tempDilations);
-  *strides = SmallVector<int64_t>(tempStrides);
   return true;
 }
@@ -661,40 +653,38 @@ bool isaConvolutionOpOfType(
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
-         "expected linalgOp to implement ConvolutionOpInterface");
+         "expected op to implement ConvolutionOpInterface");
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
   if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
     return false;
 
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-
   unsigned inputMapIdx = 0, outputMapIdx = 2;
-  SmallVector<int64_t> tempDilations(2, 1);
-  SmallVector<int64_t> tempStrides(2, 1);
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
   // Match: N
   if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
      getAffineMapDim(indexingMaps, outputMapIdx, 0))
    return false;
   // Match: H + h
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, tempDilations[0], tempStrides[0]))
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
   // Match: W + w
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, tempDilations[1], tempStrides[1]))
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
   // Match: C
   if (getAffineMapDim(indexingMaps, inputMapIdx, 3) !=
       getAffineMapDim(indexingMaps, outputMapIdx, 3))
     return false;
   // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
   if (!bodyMatcherForMinSignedPoolOps(yieldVal, body))
     return false;
-  *dilations = SmallVector<int64_t>(tempDilations);
-  *strides = SmallVector<int64_t>(tempStrides);
   return true;
 }
 
@@ -709,40 +699,38 @@ bool isaConvolutionOpOfType(
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
-         "expected linalgOp to implement ConvolutionOpInterface");
+         "expected op to implement ConvolutionOpInterface");
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
   if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
     return false;
 
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-
   unsigned inputMapIdx = 0, outputMapIdx = 2;
-  SmallVector<int64_t> tempDilations(2, 1);
-  SmallVector<int64_t> tempStrides(2, 1);
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
   // Match: N
   if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
       getAffineMapDim(indexingMaps, outputMapIdx, 0))
     return false;
   // Match: H + h
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, tempDilations[0], tempStrides[0]))
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
   // Match: W + w
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, tempDilations[1], tempStrides[1]))
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
   // Match: C
   if (getAffineMapDim(indexingMaps, inputMapIdx, 3) !=
       getAffineMapDim(indexingMaps, outputMapIdx, 3))
     return false;
   // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
   if (!bodyMatcherForSumPoolOps(yieldVal, body))
     return false;
-  *dilations = SmallVector<int64_t>(tempDilations);
-  *strides = SmallVector<int64_t>(tempStrides);
   return true;
 }
 
@@ -757,40 +745,38 @@ bool isaConvolutionOpOfType(
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
-         "expected linalgOp to implement ConvolutionOpInterface");
+         "expected op to implement ConvolutionOpInterface");
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
   if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
     return false;
 
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-
   unsigned inputMapIdx = 0, outputMapIdx = 2;
-  SmallVector<int64_t> tempDilations(2, 1);
-  SmallVector<int64_t> tempStrides(2, 1);
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
   // Match: N
   if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
       getAffineMapDim(indexingMaps, outputMapIdx, 0))
     return false;
   // Match: H + h
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, tempDilations[0], tempStrides[0]))
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
   // Match: W + w
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, tempDilations[1], tempStrides[1]))
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
   // Match: C
   if (getAffineMapDim(indexingMaps, inputMapIdx, 3) !=
       getAffineMapDim(indexingMaps, outputMapIdx, 3))
     return false;
   // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
   if (!bodyMatcherForMaxUnsignedPoolOps(yieldVal, body))
     return false;
-  *dilations = SmallVector<int64_t>(tempDilations);
-  *strides = SmallVector<int64_t>(tempStrides);
   return true;
 }
@@ -805,40 +791,38 @@ bool isaConvolutionOpOfType(
     return true;
 
   assert(isaConvolutionOpInterface(op) &&
-         "expected linalgOp to implement ConvolutionOpInterface");
+         "expected op to implement ConvolutionOpInterface");
 
   ArrayAttr indexingMaps = op.getIndexingMaps();
   if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
     return false;
 
-  Block *body = op.getBlock();
-  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
-  Value yieldVal = yieldOp.getOperand(0);
-
   unsigned inputMapIdx = 0, outputMapIdx = 2;
-  SmallVector<int64_t> tempDilations(2, 1);
-  SmallVector<int64_t> tempStrides(2, 1);
+  *dilations = SmallVector<int64_t>(2, 1);
+  *strides = SmallVector<int64_t>(2, 1);
   // Match: N
   if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
       getAffineMapDim(indexingMaps, outputMapIdx, 0))
     return false;
   // Match: H + h
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
-                                  /*oDim=*/1, tempDilations[0], tempStrides[0]))
+                                  /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
   // Match: W + w
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
-                                  /*oDim=*/2, tempDilations[1], tempStrides[1]))
+                                  /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
   // Match: C
   if (getAffineMapDim(indexingMaps, inputMapIdx, 3) !=
       getAffineMapDim(indexingMaps, outputMapIdx, 3))
     return false;
   // Match body
+  Block *body = op.getBlock();
+  auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
+  Value yieldVal = yieldOp.getOperand(0);
   if (!bodyMatcherForMinUnsignedPoolOps(yieldVal, body))
     return false;
-  *dilations = SmallVector<int64_t>(tempDilations);
-  *strides = SmallVector<int64_t>(tempStrides);
   return true;
 }
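The eight matchers updated above differ mainly in their body matchers (`bodyMatcherForMaxSignedPoolOps`, `bodyMatcherForMinUnsignedPoolOps`, etc.), whose definitions lie outside this diff. A rough, hypothetical sketch of the kind of check such a helper performs (illustrative only, not the actual implementation):

    // Hypothetical: accept a body whose yielded value comes from an
    // unsigned-max combiner defined inside that same block.
    static bool looksLikeMaxUnsignedCombiner(Value yieldVal, Block *body) {
      auto maxOp = yieldVal.getDefiningOp<arith::MaxUIOp>();
      return maxOp && maxOp->getBlock() == body;
    }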
diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
index 04374fcc2e9ed..3620256b0956e 100644
--- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir
@@ -1,18 +1,17 @@
 // The following test examples of linalg convolution named ops lowered to linalg.generic and then
 // lifted back up to named op.
-// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s --implicit-check-not=linalg.generic
 
-func.func @depthwise_conv_1d_nwc_wc(%input: tensor, %filter: tensor, %output: tensor) -> tensor {
+func.func @depthwise_conv_1d_nwc_wc(%input: tensor<1x25x8xi8>, %filter: tensor<3x8xi8>, %output: tensor<1x10x8xi32>) -> tensor<1x10x8xi32> {
   %0 = linalg.depthwise_conv_1d_nwc_wc {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
-    ins (%input, %filter: tensor, tensor)
-    outs (%output: tensor) -> tensor
-  return %0 : tensor
+    ins (%input, %filter: tensor<1x25x8xi8>, tensor<3x8xi8>)
+    outs (%output: tensor<1x10x8xi32>) -> tensor<1x10x8xi32>
+  return %0 : tensor<1x10x8xi32>
 }
 // CHECK: @depthwise_conv_1d_nwc_wc
 // CHECK: linalg.depthwise_conv_1d_nwc_wc
 // CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
-// CHECK-NOT: linalg.generic
 
 // -----
 
@@ -26,7 +25,6 @@ func.func @depthwise_conv_2d_nchw_chw(%input: tensor, %filter: tens
 // CHECK: @depthwise_conv_2d_nchw_chw
 // CHECK: linalg.depthwise_conv_2d_nchw_chw
 // CHECK-SAME: dilations = dense<[2, 3]> : tensor<2xi64>, strides = dense<[4, 5]> : tensor<2xi64>
-// CHECK-NOT: linalg.generic
 
 // -----
 
@@ -40,7 +38,6 @@ func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor, %filter:
 // CHECK: @depthwise_conv_3d_ndhwc_dhwcm
 // CHECK: linalg.depthwise_conv_3d_ndhwc_dhwcm
 // CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>
-// CHECK-NOT: linalg.generic
 
 // -----
 
@@ -54,7 +51,6 @@ func.func @pooling_nhwc_max(%input: tensor, %filter: tensor
 // CHECK: @pooling_nhwc_max
 // CHECK: linalg.pooling_nhwc_max
 // CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
-// CHECK-NOT: linalg.generic
 
 // -----
 
@@ -68,7 +64,6 @@ func.func @pooling_nhwc_min(%input: tensor, %filter: tensor
 // CHECK: @pooling_nhwc_min
 // CHECK: linalg.pooling_nhwc_min
 // CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
-// CHECK-NOT: linalg.generic
 
 // -----
 
@@ -82,7 +77,6 @@ func.func @pooling_nhwc_sum(%input: tensor, %filter: tensor
 // CHECK: @pooling_nhwc_sum
 // CHECK: linalg.pooling_nhwc_sum
 // CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
-// CHECK-NOT: linalg.generic
 
 // -----
 
@@ -96,7 +90,6 @@ func.func @pooling_nhwc_max_unsigned(%input: tensor, %filter: tensor
 // CHECK: @pooling_nhwc_max_unsigned
 // CHECK: linalg.pooling_nhwc_max_unsigned
 // CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
-// CHECK-NOT: linalg.generic
 
 // -----
 
@@ -110,7 +103,6 @@ func.func @pooling_nhwc_min_unsigned_integer(%input: tensor, %filte
 // CHECK: @pooling_nhwc_min_unsigned_integer
 // CHECK: linalg.pooling_nhwc_min_unsigned
 // CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
-// CHECK-NOT: linalg.generic
 
 // -----
 
@@ -124,4 +116,3 @@ func.func @pooling_nhwc_min_unsigned_float(%input: tensor, %filter:
 // CHECK: @pooling_nhwc_min_unsigned_float
 // CHECK: linalg.pooling_nhwc_min
 // CHECK-SAME: dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>
-// CHECK-NOT: linalg.generic

From 8c828fdae6276895cff129b16082acf6d018afd3 Mon Sep 17 00:00:00 2001
From: Abhishek Varma
Date: Tue, 11 Nov 2025 14:40:06 -0600
Subject: [PATCH 13/13] Change the way indexing maps are compared

---
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 256 ++++++++++++++----------
 1 file changed, 153 insertions(+), 103 deletions(-)
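The gist of this change, as a standalone sketch (the helper name here is hypothetical): rather than comparing dimension positions one at a time, build the full expected maps from `AffineExpr`s and compare them wholesale against the op's `indexing_maps`. For `depthwise_conv_1d_nwc_wc` with stride 2 and dilation 3:

    static SmallVector<AffineMap, 4> expectedConv1DNwcWcMaps(MLIRContext *ctx) {
      AffineExpr N = getAffineDimExpr(0, ctx);
      AffineExpr W = getAffineDimExpr(1, ctx);
      AffineExpr C = getAffineDimExpr(2, ctx);
      AffineExpr w = getAffineDimExpr(3, ctx);
      // input:  (N, W, C, w) -> (N, W * 2 + w * 3, C)
      // filter: (N, W, C, w) -> (w, C)
      // output: (N, W, C, w) -> (N, W, C)
      return AffineMap::inferFromExprList(
          {{N, W * 2 + w * 3, C}, {w, C}, {N, W, C}}, ctx);
    }

Comparing whole maps pins down the complete layout in one shot instead of a dimension at a time.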
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 58666a69ef492..57e97be9eeea0 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -442,6 +442,20 @@ static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps,
 // Matchers for specific convolution operation.
 // ---------------------------------------------
 
+/// Returns true if the given indexing maps match the expected indexing
+/// maps.
+static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
+                              ArrayAttr indexingMaps, MLIRContext *context) {
+  SmallVector<AffineMap> expectedIndexingMaps =
+      AffineMap::inferFromExprList(mapListExpected, context);
+  return indexingMaps ==
+         ArrayAttr::get(
+             context, llvm::to_vector<4>(llvm::map_range(
+                          expectedIndexingMaps, [&](AffineMap m) -> Attribute {
+                            return AffineMapAttr::get(m);
+                          })));
+}
+
 // #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)>
 // #filterMap = affine_map<(N, W, C, w) -> (w, C)>
 // #outputMap = affine_map<(N, W, C, w) -> (N, W, C)>
@@ -459,25 +473,25 @@ bool isaConvolutionOpOfType(
   if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3}))
     return false;
 
-  unsigned inputMapIdx = 0, filterMapIdx = 1, outputMapIdx = 2;
-
   *dilations = SmallVector<int64_t>(1, 1);
   *strides = SmallVector<int64_t>(1, 1);
-  // Match: N
-  if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
-      getAffineMapDim(indexingMaps, outputMapIdx, 0))
-    return false;
-  // Match: C
-  if (getAffineMapDim(indexingMaps, inputMapIdx, 2) !=
-      getAffineMapDim(indexingMaps, filterMapIdx, 1))
-    return false;
-  if (getAffineMapDim(indexingMaps, inputMapIdx, 2) !=
-      getAffineMapDim(indexingMaps, outputMapIdx, 2))
-    return false;
-  // Match: W + w
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr W = getAffineDimExpr(1, context);
+  AffineExpr C = getAffineDimExpr(2, context);
+  AffineExpr w = getAffineDimExpr(3, context);
+  // First fetch dilations/strides:
+  // Match: W * stride + w * dilation
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
                                   /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
+           /*filterMap=*/{w, C},
+           /*outputMap=*/{N, W, C}},
+          indexingMaps, context))
+    return false;
   // Match body
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
@@ -504,29 +518,32 @@ bool isaConvolutionOpOfType(
   if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4}))
     return false;
 
-  unsigned inputMapIdx = 0, filterMapIdx = 1, outputMapIdx = 2;
-
   *dilations = SmallVector<int64_t>(2, 1);
   *strides = SmallVector<int64_t>(2, 1);
-  // Match: N
-  if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
-      getAffineMapDim(indexingMaps, outputMapIdx, 0))
-    return false;
-  // Match: C
-  if (getAffineMapDim(indexingMaps, inputMapIdx, 1) !=
-      getAffineMapDim(indexingMaps, filterMapIdx, 0))
-    return false;
-  if (getAffineMapDim(indexingMaps, inputMapIdx, 1) !=
-      getAffineMapDim(indexingMaps, outputMapIdx, 1))
-    return false;
-  // Match: H + h
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr C = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  // First fetch dilations/strides:
+  // Match: H * stride + h * dilation
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
                                   /*oDim=*/2, (*dilations)[0], (*strides)[0]))
     return false;
-  // Match: W + w
+  // Match: W * stride + w * dilation
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
                                   /*oDim=*/3, (*dilations)[1], (*strides)[1]))
     return false;
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1]},
+           /*filterMap=*/{C, h, w},
+           /*outputMap=*/{N, C, H, W}},
+          indexingMaps, context))
+    return false;
   // Match body
   Block *body = op.getBlock();
   auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
@@ -556,36 +573,39 @@ bool isaConvolutionOpOfType(
   if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6}))
     return false;
 
-  unsigned inputMapIdx = 0, filterMapIdx = 1, outputMapIdx = 2;
-
   *dilations = SmallVector<int64_t>(3, 1);
   *strides = SmallVector<int64_t>(3, 1);
-  // Match: N
-  if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
-      getAffineMapDim(indexingMaps, outputMapIdx, 0))
-    return false;
-  // Match: D + d
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr D = getAffineDimExpr(1, context);
+  AffineExpr H = getAffineDimExpr(2, context);
+  AffineExpr W = getAffineDimExpr(3, context);
+  AffineExpr CM = getAffineDimExpr(4, context);
+  AffineExpr d = getAffineDimExpr(5, context);
+  AffineExpr h = getAffineDimExpr(6, context);
+  AffineExpr w = getAffineDimExpr(7, context);
+  AffineExpr C = getAffineDimExpr(8, context);
+  // First fetch dilations/strides:
+  // Match: D * stride + d * dilation
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
                                   /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
-  // Match: H + h
+  // Match: H * stride + h * dilation
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
                                   /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
-  // Match: W + w
+  // Match: W * stride + w * dilation
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
                                   /*oDim=*/3, (*dilations)[2], (*strides)[2]))
     return false;
-  // Match: C
-  if (getAffineMapDim(indexingMaps, inputMapIdx, 4) !=
-      getAffineMapDim(indexingMaps, filterMapIdx, 3))
-    return false;
-  if (getAffineMapDim(indexingMaps, inputMapIdx, 4) !=
-      getAffineMapDim(indexingMaps, outputMapIdx, 4))
-    return false;
-  // Match: CM
-  if (getAffineMapDim(indexingMaps, filterMapIdx, 4) !=
-      getAffineMapDim(indexingMaps, outputMapIdx, 5))
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, D * (*strides)[0] + d * (*dilations)[0],
+                         H * (*strides)[1] + h * (*dilations)[1],
+                         W * (*strides)[2] + w * (*dilations)[2], C},
+           /*filterMap=*/{d, h, w, C, CM},
+           /*outputMap=*/{N, D, H, W, C, CM}},
+          indexingMaps, context))
     return false;
   // Match body
   Block *body = op.getBlock();
@@ -613,25 +633,31 @@ bool isaConvolutionOpOfType(
   if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
     return false;
 
-  unsigned inputMapIdx = 0, outputMapIdx = 2;
-
   *dilations = SmallVector<int64_t>(2, 1);
   *strides = SmallVector<int64_t>(2, 1);
-  // Match: N
-  if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
-      getAffineMapDim(indexingMaps, outputMapIdx, 0))
-    return false;
-  // Match: H + h
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr C = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  // First fetch dilations/strides:
+  // Match: H * stride + h * dilation
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
                                   /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
-  // Match: W + w
+  // Match: W * stride + w * dilation
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
                                   /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
-  // Match: C
-  if (getAffineMapDim(indexingMaps, inputMapIdx, 3) !=
-      getAffineMapDim(indexingMaps, outputMapIdx, 3))
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], C},
+           /*filterMap=*/{h, w},
+           /*outputMap=*/{N, H, W, C}},
+          indexingMaps, context))
     return false;
   // Match body
   Block *body = op.getBlock();
@@ -659,25 +685,31 @@ bool isaConvolutionOpOfType(
   if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
     return false;
 
-  unsigned inputMapIdx = 0, outputMapIdx = 2;
-
   *dilations = SmallVector<int64_t>(2, 1);
   *strides = SmallVector<int64_t>(2, 1);
-  // Match: N
-  if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
-      getAffineMapDim(indexingMaps, outputMapIdx, 0))
-    return false;
-  // Match: H + h
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr C = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  // First fetch dilations/strides:
+  // Match: H * stride + h * dilation
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
                                   /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
-  // Match: W + w
+  // Match: W * stride + w * dilation
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
                                   /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
-  // Match: C
-  if (getAffineMapDim(indexingMaps, inputMapIdx, 3) !=
-      getAffineMapDim(indexingMaps, outputMapIdx, 3))
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], C},
+           /*filterMap=*/{h, w},
+           /*outputMap=*/{N, H, W, C}},
+          indexingMaps, context))
     return false;
   // Match body
   Block *body = op.getBlock();
@@ -705,25 +737,31 @@ bool isaConvolutionOpOfType(
   if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
     return false;
 
-  unsigned inputMapIdx = 0, outputMapIdx = 2;
-
   *dilations = SmallVector<int64_t>(2, 1);
   *strides = SmallVector<int64_t>(2, 1);
-  // Match: N
-  if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
-      getAffineMapDim(indexingMaps, outputMapIdx, 0))
-    return false;
-  // Match: H + h
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr C = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  // First fetch dilations/strides:
+  // Match: H * stride + h * dilation
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
                                   /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
-  // Match: W + w
+  // Match: W * stride + w * dilation
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
                                   /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
-  // Match: C
-  if (getAffineMapDim(indexingMaps, inputMapIdx, 3) !=
-      getAffineMapDim(indexingMaps, outputMapIdx, 3))
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], C},
+           /*filterMap=*/{h, w},
+           /*outputMap=*/{N, H, W, C}},
+          indexingMaps, context))
     return false;
   // Match body
   Block *body = op.getBlock();
@@ -751,25 +789,31 @@ bool isaConvolutionOpOfType(
   if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
     return false;
 
-  unsigned inputMapIdx = 0, outputMapIdx = 2;
-
   *dilations = SmallVector<int64_t>(2, 1);
   *strides = SmallVector<int64_t>(2, 1);
-  // Match: N
-  if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
-      getAffineMapDim(indexingMaps, outputMapIdx, 0))
-    return false;
-  // Match: H + h
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr C = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  // First fetch dilations/strides:
+  // Match: H * stride + h * dilation
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
                                   /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
-  // Match: W + w
+  // Match: W * stride + w * dilation
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
                                   /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
-  // Match: C
-  if (getAffineMapDim(indexingMaps, inputMapIdx, 3) !=
-      getAffineMapDim(indexingMaps, outputMapIdx, 3))
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], C},
+           /*filterMap=*/{h, w},
+           /*outputMap=*/{N, H, W, C}},
+          indexingMaps, context))
     return false;
   // Match body
   Block *body = op.getBlock();
@@ -797,25 +841,31 @@ bool isaConvolutionOpOfType(
   if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
     return false;
 
-  unsigned inputMapIdx = 0, outputMapIdx = 2;
-
   *dilations = SmallVector<int64_t>(2, 1);
   *strides = SmallVector<int64_t>(2, 1);
-  // Match: N
-  if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
-      getAffineMapDim(indexingMaps, outputMapIdx, 0))
-    return false;
-  // Match: H + h
+  MLIRContext *context = op->getContext();
+  AffineExpr N = getAffineDimExpr(0, context);
+  AffineExpr H = getAffineDimExpr(1, context);
+  AffineExpr W = getAffineDimExpr(2, context);
+  AffineExpr C = getAffineDimExpr(3, context);
+  AffineExpr h = getAffineDimExpr(4, context);
+  AffineExpr w = getAffineDimExpr(5, context);
+  // First fetch dilations/strides:
+  // Match: H * stride + h * dilation
  if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
                                   /*oDim=*/1, (*dilations)[0], (*strides)[0]))
     return false;
-  // Match: W + w
+  // Match: W * stride + w * dilation
   if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
                                   /*oDim=*/2, (*dilations)[1], (*strides)[1]))
     return false;
-  // Match: C
-  if (getAffineMapDim(indexingMaps, inputMapIdx, 3) !=
-      getAffineMapDim(indexingMaps, outputMapIdx, 3))
+  // Match expected indexing maps
+  if (!convLayoutMatches(
+          {/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
+                         W * (*strides)[1] + w * (*dilations)[1], C},
+           /*filterMap=*/{h, w},
+           /*outputMap=*/{N, H, W, C}},
+          indexingMaps, context))
    return false;
   // Match body
   Block *body = op.getBlock();