Skip to content

Conversation

Groverkss
Copy link
Member

We already have the output shape available in the operation, so there is no need to do any arithmetic to figure it out. This PR makes the shape inference directly use the available output shape.

@llvmbot
Copy link
Member

llvmbot commented Dec 1, 2024

@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

Changes

We already have the output shape available in the operation, so there is no need to do any arithmetic to figure it out. This PR makes the shape inference directly use the available output shape.


Full diff: https://github.com/llvm/llvm-project/pull/118202.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp (+23-96)
  • (modified) mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir (+2-5)
  • (modified) mlir/test/Dialect/Tensor/fold-empty-op.mlir (+2-6)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 7ff435a033985c..ebb88bf695d4c2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -16,24 +16,6 @@
 using namespace mlir;
 using namespace mlir::tensor;
 
-/// Compute a map that for a given dimension of the expanded type gives the
-/// dimension in the collapsed type it maps to. Essentially its the inverse of
-/// the `reassocation` maps.
-static llvm::DenseMap<int64_t, int64_t>
-getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
-  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
-  for (const auto &map : enumerate(reassociation)) {
-    unsigned startPos =
-        cast<AffineDimExpr>(map.value().getResults().front()).getPosition();
-    unsigned endPos =
-        cast<AffineDimExpr>(map.value().getResults().back()).getPosition();
-    for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
-      expandedDimToCollapsedDim[dim] = map.index();
-    }
-  }
-  return expandedDimToCollapsedDim;
-}
-
 /// For reshape op compute the shape at dimension `dimIndex` of the output in
 /// terms of shape of the `src`, when the reshape op is a collapsing
 /// operation. It is the product of the shape of the collapsed dimensions of the
@@ -76,86 +58,33 @@ static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape(
       }));
 }
 
-/// For an expanding reshape op, compute the value for a dimension of the output
-/// from the shape of the input.
-static OpFoldResult getExpandedOutputDimFromInputShape(
-    OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
-    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
-    llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
-  if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
-    // Static dimension: return Attribute.
-    return builder.getIndexAttr(dstStaticShape[dimIndex]);
-  }
-  unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
-  unsigned startPos =
-      cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().front())
-          .getPosition();
-  unsigned endPos =
-      cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().back())
-          .getPosition();
-  int64_t linearizedStaticDim = 1;
-  for (auto d :
-       llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
-    if (d.index() + startPos == static_cast<unsigned>(dimIndex))
-      continue;
-    assert(!ShapedType::isDynamic(d.value()) &&
-           "single dimension cannot be expanded into multiple dynamic "
-           "dimensions");
-    linearizedStaticDim *= d.value();
+struct ReifyCollapseShapeOp
+    : public ReifyRankedShapedTypeOpInterface::ExternalModel<
+          ReifyCollapseShapeOp, CollapseShapeOp> {
+  LogicalResult
+  reifyResultShapes(Operation *op, OpBuilder &b,
+                    ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
+    auto loc = op->getLoc();
+    auto collapseShape = cast<CollapseShapeOp>(op);
+    reifiedReturnShapes.push_back(getCollapsedOutputShapeFromInputShape(
+        b, loc, collapseShape.getSrc(),
+        collapseShape.getResultType().getShape(),
+        collapseShape.getReassociationMaps()));
+    return success();
   }
-  OpFoldResult sourceDim =
-      builder.create<tensor::DimOp>(loc, src, sourceDimPos).getResult();
-
-  // Dynamic dimension: return Value.
-  return affine::makeComposedAffineApply(
-             builder, loc,
-             AffineMap::get(
-                 0, 1,
-                 builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
-             sourceDim)
-      ->getResult(0);
-}
-
-/// Given the `src` of an expanding reshape op, the reassociation maps and the
-/// result type, compute the shape of the result of the reshape.
-static SmallVector<OpFoldResult, 4> getExpandedOutputShapeFromInputShape(
-    OpBuilder &builder, Location loc, Value src,
-    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
-  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
-      getExpandedDimToCollapsedDimMap(reassociation);
-  return llvm::to_vector<4>(llvm::map_range(
-      llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
-        return getExpandedOutputDimFromInputShape(builder, loc, dim, src,
-                                                  dstStaticShape, reassociation,
-                                                  expandedDimToCollapsedDim);
-      }));
-}
-
-static SmallVector<OpFoldResult, 4>
-getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
-                                    ArrayRef<int64_t> dstStaticShape,
-                                    ArrayRef<AffineMap> reassocation) {
-  return dstStaticShape.size() >
-                 static_cast<size_t>(
-                     llvm::cast<ShapedType>(src.getType()).getRank())
-             ? getExpandedOutputShapeFromInputShape(
-                   builder, loc, src, dstStaticShape, reassocation)
-             : getCollapsedOutputShapeFromInputShape(
-                   builder, loc, src, dstStaticShape, reassocation);
-}
+};
 
-template <typename OpTy>
-struct ReifyExpandOrCollapseShapeOp
-    : public ReifyRankedShapedTypeOpInterface::ExternalModel<
-          ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
+struct ReifyExpandShapeOp
+    : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
+                                                             ExpandShapeOp> {
   LogicalResult
   reifyResultShapes(Operation *op, OpBuilder &b,
                     ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
     auto loc = op->getLoc();
-    auto reshapeOp = cast<OpTy>(op);
-    reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
-        b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
-        reshapeOp.getReassociationMaps()));
+    auto expandShape = cast<ExpandShapeOp>(op);
+    SmallVector<OpFoldResult> outputShape = getMixedValues(
+        expandShape.getStaticOutputShape(), expandShape.getOutputShape(), b);
+    reifiedReturnShapes.push_back(outputShape);
     return success();
   }
 };
@@ -202,10 +131,8 @@ struct ReifyPadOp
 void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
-    ExpandShapeOp::attachInterface<
-        ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
-    CollapseShapeOp::attachInterface<
-        ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
+    ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
+    CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
     PadOp::attachInterface<ReifyPadOp>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
index 8fb84248c9613b..0595ac2492c97a 100644
--- a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
@@ -210,15 +210,12 @@ func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>, %sz0: index) -> (ind
   %3 = tensor.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
   return %1, %2, %3 : index, index, index
 }
-//      CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
 //      CHECK: func @dim_reshape_expansion
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
-//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
 //  CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
 //  CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
-//      CHECK:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
-//      CHECK:   %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-//      CHECK:   return %[[C3]], %[[C4]], %[[D1]]
+//      CHECK:   return %[[C3]], %[[C4]], %[[ARG1]]
 
 // -----
 
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index 65ceb4ff3e3df4..d3889f23e7d742 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -10,7 +10,6 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
 // CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>
 
 func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> {
@@ -19,11 +18,8 @@ func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4
   return %1 : tensor<2x3x5x4x?x7xf32>
 }
 // CHECK-LABEL: func @empty_reshape_expansion
-// CHECK-SAME:     %[[ARG0:.+]]: index
-// CHECK:        %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<6x5x?xf32>
-// CHECK-NEXT:   %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
-// CHECK-NEXT:   %[[D:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
-// CHECK-NEXT:   %[[INIT:.+]] = tensor.empty(%[[D]])
+// CHECK-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: index
+// CHECK-NEXT:   %[[INIT:.+]] = tensor.empty(%[[ARG1]])
 // CHECK-NEXT:   return %[[INIT]]
 
 func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {

Copy link
Contributor

@qedawkins qedawkins left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice cleanup, LGTM

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a few follow ups here. Please wait for me to get back to this.

@Groverkss
Copy link
Member Author

Landed as #113501

@Groverkss Groverkss closed this Feb 9, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants