[Linalg] Add basic infra to add matchers for linalg.*conv*/*pool* ops #163724
base: main
Conversation
-- This commit includes the basic infra/utilities to add matchers for linalg.*conv*/*pool* ops, such that, given a `linalg.generic` op, it identifies which linalg.*conv*/*pool* op it is.
-- It adds a few representative linalg.*conv*/*pool* ops to demo the matchers' capability, as part of the `linalg-specialize-generic-ops` pass.
-- The goal is to iteratively address the aim of [[RFC] Op explosion in Linalg](https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863) for `*conv*/*pooling*` ops.
-- This is part 1 of a series of PRs aimed at adding matchers for convolution ops.
-- For further details, refer to llvm#163374 (review)

Signed-off-by: Abhishek Varma <[email protected]>
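For illustration, this is the kind of rewrite the pass enables: a `linalg.generic` whose indexing maps encode a depthwise-conv access pattern gets replaced by the corresponding named op. SSA names and tensor shapes below are hypothetical:

```mlir
// After `linalg-specialize-generic-ops`, the matched generic becomes a named
// depthwise conv; dilations/strides are recovered from the generic op's maps.
%0 = linalg.depthwise_conv_1d_nwc_wc
       {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
       ins(%input, %filter : tensor<1x10x3xf32>, tensor<4x3xf32>)
       outs(%init : tensor<1x7x3xf32>) -> tensor<1x7x3xf32>
```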
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir

Author: Abhishek Varma (Abhishek-Varma)

Patch is 36.60 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/163724.diff

4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 48978eb7663d5..771d753a8bddb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -110,6 +110,15 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to);
std::optional<SmallVector<ReassociationIndices>>
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
+//===----------------------------------------------------------------------===//
+// Convolution matcher utility
+//===----------------------------------------------------------------------===//
+
+template <typename ConvOpTy>
+bool isaConvolutionOpOfType(LinalgOp op,
+ SmallVector<int64_t> *dilations = nullptr,
+ SmallVector<int64_t> *strides = nullptr);
+
//===----------------------------------------------------------------------===//
// Fusion / Tiling utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 40fc0d68e358f..35861002e309e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -237,6 +237,145 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
+/// Utility to replace a `genericOp` with a named convolution op of type
+/// `ConvOpTy`, using the given `dilations` and `strides`.
+template <typename ConvOpTy>
+static FailureOr<LinalgOp>
+specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
+ ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
+ SmallVector<Value> inputs = genericOp.getDpsInputs();
+ ValueRange outputs = genericOp.getDpsInits();
+ SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
+ SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
+ ? TypeRange(ValueRange(outputs))
+ : TypeRange{};
+ LinalgOp namedOp;
+ if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
+ std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
+ std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
+ namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
+ inputs, outputs);
+ } else {
+ Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
+ Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
+ namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
+ genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
+ }
+ return namedOp;
+}
+
+/// TODO(avarma): Convolution ops whose iterator-types array has rank 2 will
+/// be added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank2ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank4ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ SmallVector<int64_t> dilations, strides;
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
+ genericOp, &dilations, &strides))
+ return specializeToConvOp<linalg::DepthwiseConv1DNwcWcOp>(
+ rewriter, genericOp, dilations, strides);
+ return failure();
+}
+
+/// TODO(avarma): Convolution ops whose iterator-types array has rank 5 will
+/// be added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank5ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank6ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ SmallVector<int64_t> dilations, strides;
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
+ genericOp, &dilations, &strides))
+ return specializeToConvOp<linalg::DepthwiseConv2DNchwChwOp>(
+ rewriter, genericOp, dilations, strides);
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(genericOp, &dilations,
+ &strides))
+ return specializeToConvOp<linalg::PoolingNhwcMaxOp>(rewriter, genericOp,
+ dilations, strides);
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(genericOp, &dilations,
+ &strides))
+ return specializeToConvOp<linalg::PoolingNhwcMinOp>(rewriter, genericOp,
+ dilations, strides);
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(genericOp, &dilations,
+ &strides))
+ return specializeToConvOp<linalg::PoolingNhwcSumOp>(rewriter, genericOp,
+ dilations, strides);
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
+ genericOp, &dilations, &strides))
+ return specializeToConvOp<linalg::PoolingNhwcMaxUnsignedOp>(
+ rewriter, genericOp, dilations, strides);
+ if (isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
+ genericOp, &dilations, &strides))
+ return specializeToConvOp<linalg::PoolingNhwcMinUnsignedOp>(
+ rewriter, genericOp, dilations, strides);
+ return failure();
+}
+
+/// TODO(avarma): Convolution ops whose iterator-types array has rank 7 will
+/// be added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank7ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ return failure();
+}
+
+/// TODO(avarma): Convolution ops whose iterator-types array has rank 8 will
+/// be added here incrementally in follow-up PRs.
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank8ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ return failure();
+}
+
+static FailureOr<LinalgOp>
+inferAndSpecializeBasedOnRank9ConvIteratorTypes(RewriterBase &rewriter,
+ GenericOp genericOp) {
+ SmallVector<int64_t> dilations, strides;
+ if (isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+ genericOp, &dilations, &strides))
+ return specializeToConvOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
+ rewriter, genericOp, dilations, strides);
+ return failure();
+}
+
+// Converts linalg.generic to a named linalg.*conv/pooling* op where possible.
+// To speed up the search, the convolution ops have been segregated based on
+// the rank of their iterator-types array.
+static FailureOr<LinalgOp>
+inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
+ SmallVector<utils::IteratorType> iteratorTypes =
+ genericOp.getIteratorTypesArray();
+ unsigned totalIterators = iteratorTypes.size();
+ switch (totalIterators) {
+ case 2:
+ return inferAndSpecializeBasedOnRank2ConvIteratorTypes(rewriter, genericOp);
+ case 4:
+ return inferAndSpecializeBasedOnRank4ConvIteratorTypes(rewriter, genericOp);
+ case 5:
+ return inferAndSpecializeBasedOnRank5ConvIteratorTypes(rewriter, genericOp);
+ case 6:
+ return inferAndSpecializeBasedOnRank6ConvIteratorTypes(rewriter, genericOp);
+ case 7:
+ return inferAndSpecializeBasedOnRank7ConvIteratorTypes(rewriter, genericOp);
+ case 8:
+ return inferAndSpecializeBasedOnRank8ConvIteratorTypes(rewriter, genericOp);
+ case 9:
+ return inferAndSpecializeBasedOnRank9ConvIteratorTypes(rewriter, genericOp);
+ }
+ return failure();
+}
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -316,6 +455,11 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
if (isaContractionOpInterface(genericOp)) {
return specializeLinalgContractions(rewriter, genericOp);
}
+
+ // Convolution - e.g. *conv/pooling*
+ if (isaConvolutionOpInterface(genericOp)) {
+ return inferAndSpecializeToConvolutionOp(rewriter, genericOp);
+ }
return failure();
}
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 24d3722cf5426..c3c2819652129 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -240,6 +240,508 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
return iteratorType == utils::IteratorType::reduction;
}
+//===----------------------------------------------------------------------===//
+// Convolution matcher utilities
+//===----------------------------------------------------------------------===//
+
+/// Utility to match block body for linalg.pool* ops.
+template <typename... OpTypes>
+static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
+ Operation *defOp = yieldVal.getDefiningOp();
+ if (!(isa_and_present<OpTypes>(defOp) || ...))
+ return false;
+
+ BlockArgument lhsArg = dyn_cast<BlockArgument>(defOp->getOperand(0));
+ BlockArgument rhsArg = dyn_cast<BlockArgument>(defOp->getOperand(1));
+ if (!lhsArg || !rhsArg)
+ return false;
+ return true;
+}
+
+static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal,
+ body);
+}
+
+static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal,
+ body);
+}
+
+static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal,
+ body);
+}
+
+static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal,
+ body);
+}
+
+static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
+ return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
+}
+
+static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
+ uint32_t mapIndex, uint32_t dimIndex) {
+ auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
+ if (dimIndex < affineMap.getNumResults())
+ return affineMap.getResult(dimIndex);
+ return nullptr;
+}
+
+// Check if `expr` is either:
+// - a dimension expr alone (implying *1), or
+// - a multiplication of dimension expr by constant.
+static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim,
+ int64_t &constantValue) {
+ if (auto dExpr = dyn_cast<AffineDimExpr>(expr)) {
+ dim = dExpr;
+ constantValue = 1;
+ return true;
+ }
+
+ auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
+ if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
+ return false;
+
+ AffineExpr lhs = mulExpr.getLHS();
+ AffineExpr rhs = mulExpr.getRHS();
+
+ if (auto dExpr = dyn_cast<AffineDimExpr>(lhs)) {
+ if (auto cst = dyn_cast<AffineConstantExpr>(rhs)) {
+ dim = dExpr;
+ constantValue = cst.getValue();
+ return true;
+ }
+ }
+ if (auto cst = dyn_cast<AffineConstantExpr>(lhs)) {
+ if (auto dExpr = dyn_cast<AffineDimExpr>(rhs)) {
+ dim = dExpr;
+ constantValue = cst.getValue();
+ return true;
+ }
+ }
+ return false;
+}
+
+/// Given an array of AffineMaps `indexingMaps`, verify the following:
+/// indexingMaps[0].getResult(iDim) ==
+/// indexingMaps[1].getResult(fDim) * <CST_1> +
+/// indexingMaps[n-1].getResult(oDim) * <CST_2>
+/// where CST_1 and CST_2 can be any constant.
+static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
+ unsigned fDim, unsigned oDim,
+ int64_t &dilation, int64_t &stride) {
+ unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
+ AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim);
+  auto addExpr = dyn_cast_if_present<AffineBinaryOpExpr>(inpExpr);
+ if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
+ return false;
+
+ AffineExpr dim0, dim1;
+ int64_t c0, c1;
+
+ if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) &&
+ isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) {
+ // Pattern matched with dims and constants extracted.
+ AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim);
+ AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim);
+ if (dim0 == fExpr && dim1 == oExpr) {
+ dilation = c0;
+ stride = c1;
+ return true;
+ } else if (dim1 == fExpr && dim0 == oExpr) {
+ dilation = c1;
+ stride = c0;
+ return true;
+ }
+ }
+ return false;
+}
+
+/// Given an array of AffineMaps `indexingMaps`, verify the following:
+/// indexingMaps[aIndex].getResult(aDim) ==
+/// indexingMaps[bIndex].getResult(bDim)
+static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex,
+ unsigned aDim, unsigned bIndex,
+ unsigned bDim) {
+ return getAffineMapDim(indexingMaps, aIndex, aDim) ==
+ getAffineMapDim(indexingMaps, bIndex, bDim);
+}
+
+/// Given an array of AffineMaps, verify that each map has the corresponding
+/// `expectedSize` number of results.
+static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps,
+ ArrayRef<int64_t> expectedSizes) {
+ if (indexingMaps.size() != expectedSizes.size())
+ return false;
+
+ for (auto [indexingMap, expectedSize] :
+ llvm::zip_equal(indexingMaps, expectedSizes)) {
+ auto affineMap = cast<AffineMapAttr>(indexingMap).getValue();
+ if (affineMap.getNumResults() != expectedSize)
+ return false;
+ }
+ return true;
+}
+
+/// Utility to update `dilations` and `strides` by copying the corresponding
+/// data from `tempDilations` and `tempStrides`.
+static bool updateConvDilationsAndStrides(SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides,
+ ArrayRef<int64_t> tempDilations,
+ ArrayRef<int64_t> tempStrides) {
+ if (!(dilations && strides))
+ return true;
+ for (auto [dilation, stride] : llvm::zip(tempDilations, tempStrides)) {
+ dilations->push_back(dilation);
+ strides->push_back(stride);
+ }
+ return true;
+}
+
+static bool isaDepthwiseConv1DNwcWcOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3}))
+ return false;
+
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(1, 1);
+ SmallVector<int64_t> tempStrides(1, 1);
+ // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
+ // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
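For reference, a `linalg.generic` with exactly the maps shown in the comment above would be matched here. SSA names and tensor shapes are hypothetical; with stride and dilation both 1, an input width of 10 and a filter width of 4 yield an output width of 7:

```mlir
#map  = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
// d0, d1, d2 iterate over N, W, C; d3 reduces over the filter window.
%res = linalg.generic
         {indexing_maps = [#map, #map1, #map2],
          iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
         ins(%in, %filt : tensor<1x10x3xf32>, tensor<4x3xf32>)
         outs(%out : tensor<1x7x3xf32>) {
  ^bb0(%i: f32, %f: f32, %o: f32):
    %m = arith.mulf %i, %f : f32
    %a = arith.addf %o, %m : f32
    linalg.yield %a : f32
} -> tensor<1x7x3xf32>
```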
+
+static bool isaDepthwiseConv2DNchwChwOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4}))
+ return false;
+
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(2, 1);
+ SmallVector<int64_t> tempStrides(2, 1);
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+ /*oDim=*/3, tempDilations[1],
+ tempStrides[1]));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+static bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op,
+ SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
+
+ ArrayAttr indexingMaps = op.getIndexingMaps();
+ if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6}))
+ return false;
+
+ unsigned iIndex = 0, fIndex = 1, oIndex = 2;
+
+ SmallVector<int64_t> tempDilations(3, 1);
+ SmallVector<int64_t> tempStrides(3, 1);
+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
+ // -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
+ // -> (d5, d6, d7, d8, d4)>
+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)
+ // -> (d0, d1, d2, d3, d8, d4)>
+ bool returnVal =
+ (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
+ /*oDim=*/1, tempDilations[0],
+ tempStrides[0]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
+ /*oDim=*/2, tempDilations[1],
+ tempStrides[1]) &&
+ matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
+ /*oDim=*/3, tempDilations[2],
+ tempStrides[2]) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
+ matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
+ matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5));
+ return returnVal && updateConvDilationsAndStrides(dilations, strides,
+ tempDilations, tempStrides);
+}
+
+static bool isaPoolingNhwcMaxOp(LinalgOp op, SmallVector<int64_t> *dilations,
+ SmallVector<int64_t> *strides) {
+ if (isa<linalg::PoolingNhwcMaxOp>(op))
+ return true;
+
+ if (!isaConvolutionOpInterface(op))
+ return false;
+
+ ArrayAt...
[truncated]
banach-space
left a comment
Thanks for extracting this! Sharing my first set of comments. This is still quite dense, so I've not read everything yet 😅
hanhanW
left a comment
Thanks for splitting the PR, it is easier to review! I'll take a look at [Utils.cpp] changes once we are aligned on the code structure.
banach-space
left a comment
We are getting there 😅
I've started reviewing the utility functions, see my comments inline.
Just some minor comments in this round - slowly returning to this (apologies for the delay - travelling)
Thanks for all the updates so far 🙏🏻
banach-space
left a comment
Thanks for addressing my comments so far - I am posting a few more. I will have a proper look at the test in my next iteration.
Thanks for all the effort - this is quite involved and very nuanced! 🙏🏻
template <typename ConvOpTy>
bool isaConvolutionOpOfType(LinalgOp op,
                            SmallVector<int64_t> *dilations = nullptr,
                            SmallVector<int64_t> *strides = nullptr);
Could you add a note that dilations and strides are output arguments? Also, nullptr as default implies that these arguments are optional, but the code breaks if you don't specify them.
Could you add a note that dilations and strides are output arguments?
Done.
Also, nullptr as default implies that this arguments are optional, but the code breaks if you don't specify them.
I don't think the code would break if the user doesn't specify them. A user can invoke the API either as isaConvolutionOpOfType<linalg::Conv1DOp>(op) OR as isaConvolutionOpOfType<linalg::Conv1DOp>(op, dilations, strides) and these both should work fine.
Could you point me to the code which seems erroneous to you?
EDIT (10th Nov, '25): I re-read what you wrote - I guess you meant it just as a notion and not that it actually breaks. 😅 I misunderstood it. If that's correct, I guess this stands addressed with the API doc comment now. Please correct me if that's not the case and I'll address it. :)
(note - I wrote this comment before your edit)
I don't think the code would break if the user doesn't specify them.
I think that you are right.
However, ATM every time isaConvolutionOpOfType is invoked, dilations and strides are indeed != nullptr - why not make these arguments non-optional? To me that would make the code simpler:
- It won't change how it's used ATM, as currently every invocation does provide these arguments.
- By making `dilations` and `strides` mandatory, we can avoid `if` statements inside the implementation of `updateConvDilationsAndStrides`.
One if stmt is not a big deal, but there is a lot going on here - I'd rather we focus on the problem at hand than worry about generality. WDYT?
One if stmt is not a big deal, but there is a lot going on here - I'd rather we focus on the problem at hand than worry about generality. WDYT?
Sure, I understand your point and I've made an update to reflect the suggested change.
I usually prefer generalizing an implementation earlier on the best we can.
For instance, while scoping out the work required to update:
- `DownscaleSizeOneWindowed2DConvolution` to work with `linalg.generic`, and
- Vectorization to also use `linalg.generic`,
I figured that our API will need to have dilations/strides optional.
A similar rationale applies for not using assert(isaConvolutionOpInterface(op)...) - these will be necessary checks that the caller of our API will have to ensure anyway.
Two ways to go about this:
1. As discussed earlier, we can remove these in follow-up patches to make the API as well as the caller sites cleaner, since the problem we're trying to deal with through the current patch doesn't warrant such changes.
OR
2. Even without removing these in a follow-up patch, I guess I can make it work. Example:
   - I can add an `if (isaConvolutionOpInterface(op))` before invoking our API at the above caller sites.
   - I can use a dummy `SmallVector<int64_t> dilations, strides;` before invoking our API at the above caller sites.
But the question then would be whether we would actually want those callers to take on the extra effort.
Note: Just penning down my thoughts to prefetch/scope our plan better for the above caller sites in follow-up PRs that I'll take up. Doesn't require immediate response/attention, but we're heading there sooner.
I usually prefer generalizing an implementation earlier on the best we can.
Generality is good if there's use for it. In this PR the generality that you are referring to is not required - it's merely noise. You might be right that it will be required later, but it is hard to see that in this PR. And there is no harm in restoring it later, when the updated context will provide the necessary justification.
Btw, here's the principle that I try to follow for most things: https://en.wikipedia.org/wiki/KISS_principle :)
For instance, while scoping out the work required to update:
- ...
- Vectorization to also use linalg.generic,
Are you planning to make the vectorizer go via linalg.generic rather than linalg.conv? Why do you expect that to be a good choice? ;-)
Btw, here's the principle that I try to follow for most things: https://en.wikipedia.org/wiki/KISS_principle :)
😅 I'll keep this in mind.
Are you planning to make the vectorizer to go via linalg.generic rather than linalg.conv? Why do you expect that do be a good choice? ;-)
So IREE generalizes linalg operations, which improves fusion. So in certain pipelines in IREE where we make use of the above callers, we have generalized linalg ops in the IR. @hanhanW can add more to this. :)
banach-space
left a comment
Posting some comments that I have in-flight.
I was hoping I'd get a chance to read tests (and perhaps other things), but today was busier than I expected and so will be tomorrow morning. So, posting to progress the discussion.
Force-pushed from 87c6047 to c82c3d3
// Match: N
if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
    getAffineMapDim(indexingMaps, outputMapIdx, 0))
  return false;
// Match: C
if (getAffineMapDim(indexingMaps, inputMapIdx, 1) !=
    getAffineMapDim(indexingMaps, filterMapIdx, 0))
  return false;
if (getAffineMapDim(indexingMaps, inputMapIdx, 1) !=
    getAffineMapDim(indexingMaps, outputMapIdx, 1))
  return false;
Shouldn't this also be matching H and W? Also, why:

// Match: C
if (getAffineMapDim(indexingMaps, inputMapIdx, 1) !=
    getAffineMapDim(indexingMaps, outputMapIdx, 1))
  return false;

instead of simply (similar number of characters, but IMHO the intent is clearer):

// Match: C
auto inputMap = indexingMaps[0];
auto outputMap = indexingMaps[2];
if (inputMap.getResult(1) != outputMap.getResult(1))
  return false;

Related to that, why not construct the expected maps and do a map comparison, similarly to:
llvm-project/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp, lines 4089 to 4091 in 71b21b5:

if (layout({/*lhsIndex*/ {w + kw},
            /*rhsIndex*/ {kw},
            /*resIndex*/ {w}}))
In fact, you might be able to re-use that, no?
I've tried updating it. Does it look okay to you?

In fact, you might be able to re-use that, no?

Were you suggesting to use StructuredGenerator for the above layout(...) API?