Conversation

@IanWood1
Contributor

Exposes the tensor.extract_slice reshaping logic in BubbleUpExpandShapeThroughExtractSlice and BubbleUpCollapseShapeThroughExtractSlice through two corresponding utility functions. These compute the offsets, sizes, and strides of an extract_slice after either collapsing or expanding its source tensor.

This should also make it easier to implement the two other bubbling cases: (1) the collapse_shape is a consumer or (2) the expand_shape is a consumer.

@IanWood1 IanWood1 marked this pull request as ready for review August 14, 2025 20:45
@llvmbot
Member

llvmbot commented Aug 14, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Ian Wood (IanWood1)

Changes

Exposes the tensor.extract_slice reshaping logic in BubbleUpExpandShapeThroughExtractSlice and BubbleUpCollapseShapeThroughExtractSlice through two corresponding utility functions. These compute the offsets, sizes, and strides of an extract_slice after either collapsing or expanding its source tensor.

This should also make it easier to implement the two other bubbling cases: (1) the collapse_shape is a consumer or (2) the expand_shape is a consumer.


Patch is 30.98 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153675.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h (+28)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp (+270-299)
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 87deef9ca7466..2602252916388 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -142,6 +142,34 @@ FailureOr<Value> buildIndependentOp(OpBuilder &b, tensor::PadOp padOp,
 FailureOr<Value> buildIndependentOp(OpBuilder &b, tensor::EmptyOp emptyOp,
                                     ValueRange independencies);
 
+/// Computes the offsets, sizes, and strides needed to build a collapsed
+/// `sliceOp`. The dimensions to collapse are specified by `reassociation`.
+///
+/// This fails when the specified collapse cannot be represented by a valid
+/// ExtractSliceOp.
+LogicalResult
+getCollapsedExtractSliceInfo(tensor::ExtractSliceOp sliceOp,
+                             ArrayRef<ReassociationIndices> reassociation,
+                             SmallVectorImpl<OpFoldResult> &collapsedOffsets,
+                             SmallVectorImpl<OpFoldResult> &collapsedSizes,
+                             SmallVectorImpl<OpFoldResult> &collapsedStrides,
+                             OpBuilder &b);
+
+/// Computes the offsets, sizes, and strides needed to build an expanded
+/// `sliceOp`. The dimensions to expand are specified by `reassociation` and
+/// `expandedShape`.
+///
+/// This fails when the specified expansion cannot be represented by a valid
+/// ExtractSliceOp.
+LogicalResult
+getExpandedExtractSliceInfo(tensor::ExtractSliceOp sliceOp,
+                            ArrayRef<ReassociationIndices> reassociation,
+                            ArrayRef<int64_t> expandedShape,
+                            SmallVectorImpl<OpFoldResult> &expandedOffsets,
+                            SmallVectorImpl<OpFoldResult> &expandedSizes,
+                            SmallVectorImpl<OpFoldResult> &expandedStrides,
+                            OpBuilder &b);
+
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 2ec23e1fb35ce..a93681b1fce92 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -327,172 +327,31 @@ struct BubbleUpExpandShapeThroughExtractSlice
                                 PatternRewriter &rewriter) const override {
     auto expandShapeOp =
         sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+    if (!expandShapeOp) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "tensor.extract_slice source not produced by expand_shape");
+    }
+    SmallVector<ReassociationIndices> reassociation =
+        expandShapeOp.getReassociationIndices();
 
-    if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
-                                                 rewriter)
-            .failed())
+    SmallVector<OpFoldResult> offsets, sizes, strides;
+    if (failed(getCollapsedExtractSliceInfo(sliceOp, reassociation, offsets,
+                                            sizes, strides, rewriter)))
       return failure();
 
-    // The tensor.extract_slice before applying the pattern works on the result
-    // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
-    // referring to the state before applying the pattern are named with the
-    // prefix "expanded", and ones referring to the state after applying the
-    // pattern are named with the prefix "collapsed".
-    SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
-    SmallVector<OpFoldResult> expandedShape =
-        getMixedValues(expandShapeOp.getStaticOutputShape(),
-                       expandShapeOp.getOutputShape(), rewriter);
-
-    // Helper variables and function for accumulating the size values.
-    Location loc = expandShapeOp->getLoc();
-    AffineExpr d0, d1, d2;
-    bindDims(rewriter.getContext(), d0, d1, d2);
-    // Multiply two integers.
-    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
-      auto mulMap = AffineMap::get(2, 0, {d0 * d1});
-      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
-                                                   {v1, v2});
-    };
-
-    // Compute new offsets, sizes, and strides for tensor.extract_slice.
-    // The new tensor.extract_slice will work on a tensor that has has a rank of
-    // ReassociationIndices.size(). In the loop a single offset, size, and
-    // stride value is computed per reassociation group.
-    SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
-        collapsedStrides;
-    for (const ReassociationIndices &indices :
-         expandShapeOp.getReassociationIndices()) {
-      // collapsedSize will hold the size of the single dim that represents the
-      // reassociation group in the non expanded tensor.
-      OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
-      // The reassocGroupSizes and reassocGroupOffsets are used to create an
-      // affine.linearize_index op to linearize the single offset value required
-      // for this reassociation group.
-      SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
-
-      for (long expandedDim : indices) {
-        // reassocGroupSizes and reassocGroupOffsets can be obtained directly
-        // from the expanded state, but the collapsed size requires calculation
-        // as it did not previously exist.
-        reassocGroupSizes.push_back(expandedShape[expandedDim]);
-        reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
-        collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
-      }
-
-      SmallVector<Value> offsetVals =
-          llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
-            return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
-          });
-      OpFoldResult collapsedOffset =
-          affine::AffineLinearizeIndexOp::create(rewriter, loc, offsetVals,
-                                                 reassocGroupSizes,
-                                                 /*disjoint=*/true)
-              .getResult();
-      collapsedOffsets.push_back(collapsedOffset);
-      collapsedSizes.push_back(collapsedSize);
-
-      // Only unit stride is supported.
-      collapsedStrides.push_back(rewriter.getIndexAttr(1));
-    }
-
     // The shape of the result can be obtained from the sizes passed in.
-    SmallVector<Value> dynDims;
-    SmallVector<int64_t> shape;
-    dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
-    RankedTensorType resultType = RankedTensorType::get(
-        shape, expandShapeOp.getResultType().getElementType());
+    SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+    RankedTensorType resultType = sliceOp.getResultType();
 
     // Create a new ExtractSliceOp and ExpandShapeOp.
+    Location loc = sliceOp.getLoc();
     Value newSliceOp = tensor::ExtractSliceOp::create(
-        rewriter, loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
-        collapsedStrides);
+        rewriter, loc, expandShapeOp.getSrc(), offsets, sizes, strides);
     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
         sliceOp, resultType, newSliceOp,
         expandShapeOp.getReassociationIndices(), expandedSizes);
     return success();
   }
-
-  // Helper function to check if all the required conditions for the
-  // tensor.extract_slice to be bubbled up through the tensor.expand_shape are
-  // met.
-  LogicalResult
-  checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
-                                           tensor::ExpandShapeOp expandShapeOp,
-                                           PatternRewriter &rewriter) const {
-
-    if (!expandShapeOp) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "tensor.extract_slice source not produced by expand_shape");
-    }
-
-    if (!sliceOp.hasUnitStride()) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
-                   "be supported in this transformation.");
-    }
-
-    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
-
-    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
-        sizes.size()) {
-      return rewriter.notifyMatchFailure(sliceOp,
-                                         "unimplemented: rank reducing slice");
-    }
-
-    SmallVector<OpFoldResult> outputShape =
-        getMixedValues(expandShapeOp.getStaticOutputShape(),
-                       expandShapeOp.getOutputShape(), rewriter);
-
-    std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
-        isZeroOffsetAndFullSize =
-            [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
-              if (!isZeroInteger(offset))
-                return false;
-              FailureOr<bool> maybeEqual =
-                  ValueBoundsConstraintSet::areEqual(sliceSize, size);
-              return llvm::succeeded(maybeEqual) && maybeEqual.value();
-            };
-
-    // Check that the slice is contiguous within each reassociation group.
-    // The slice is contiguous only if after the first dimension where a non
-    // unit slice is taken, the slice size on all subsequent dimensions of the
-    // group is equal to the entire size of the dimension.
-    // Examples of contiguous slices:
-    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
-    //   full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
-    // Examples of non contiguous slices:
-    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
-    //   full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
-    for (const ReassociationIndices &indices :
-         expandShapeOp.getReassociationIndices()) {
-      int64_t i = 0;
-      int64_t e = indices.size();
-      // Find the first expanded dim after the first dim with non-unit extracted
-      // size.
-      for (; i < e; ++i) {
-        if (!isOneInteger(sizes[indices[i]])) {
-          // +1 to skip the first non-unit size dim.
-          i++;
-          break;
-        }
-      }
-
-      // Verify that all subsequent dimensions extract the full size of the
-      // source tensor.
-      for (; i < e; ++i) {
-        int64_t expandedDim = indices[i];
-        if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
-                                     outputShape[expandedDim])) {
-          return rewriter.notifyMatchFailure(
-              sliceOp, "Not a contiguous slice of the expanded tensor.");
-        }
-      }
-    }
-
-    return success();
-  }
 };
 
 /// Converts `tensor.extract_slice(tensor.collapse_shape)` to
@@ -582,170 +441,282 @@ struct BubbleUpCollapseShapeThroughExtractSlice
           "tensor.extract_slice source not produced by tensor.collapse_shape");
     }
 
-    if (!sliceOp.hasUnitStride()) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
-                   "be supported in this transformation.");
-    }
+    SmallVector<OpFoldResult> offsets, sizes, strides;
+    if (failed(getExpandedExtractSliceInfo(
+            sliceOp, collapseShapeOp.getReassociationIndices(),
+            collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides,
+            rewriter)))
+      return failure();
 
-    // The tensor.extract_slice before applying the pattern works on the result
-    // of the tensor.collapse_shape, so variables (i.e. inputs for
-    // ExtractSliceOp) referring to the state before applying the pattern are
-    // named with the prefix "collapsed", and ones referring to the state after
-    // applying the pattern are named with the prefix "expanded".
-    SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
-
-    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
-        collapsedSizes.size()) {
-      return rewriter.notifyMatchFailure(sliceOp,
-                                         "unimplemented: rank reducing slice");
-    }
+    Value newSliceOp = tensor::ExtractSliceOp::create(
+        rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), offsets,
+        sizes, strides);
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        sliceOp, sliceOp.getResultType(), newSliceOp,
+        collapseShapeOp.getReassociationIndices());
 
-    ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
-    SmallVector<ReassociationIndices, 4> reassociationIndices =
-        collapseShapeOp.getReassociationIndices();
-
-    // Compute new offsets, sizes, and strides for tensor.extract_slice.
-    // The new tensor.extract_slice will work on a tensor that has has a rank
-    // equal to the rank of the src of the collapse_shape. In each iteration of
-    // the loop, the offsets and sizes will be computed per reassociation group.
-    SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
-    SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
-                                              rewriter.getIndexAttr(1));
-
-    for (auto [collapsedSize, collapsedOffset, reassocIndices] :
-         llvm::zip_equal(collapsedSizes, collapsedOffsets,
-                         collapseShapeOp.getReassociationIndices())) {
-      // CASE #1 - size and/or offset are dynamic.
-      // In this case, the slice can be represented as a contiguous slice only
-      // if there is a single dimension in the reassociation group that has a
-      // size not equal to 1.
-      if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
-        int nonUnitSizeCount = 0;
-        for (int64_t expandedShapeIdx : reassocIndices) {
-          if (srcShape[expandedShapeIdx] != 1) {
-            nonUnitSizeCount++;
-            expandedSizes.push_back(collapsedSize);
-            expandedOffsets.push_back(collapsedOffset);
-            continue;
-          }
-
-          expandedSizes.push_back(rewriter.getIndexAttr(1));
-          expandedOffsets.push_back(rewriter.getIndexAttr(0));
-        }
+    return success();
+  }
+};
 
-        if (nonUnitSizeCount != 1) {
-          return rewriter.notifyMatchFailure(
-              sliceOp,
-              "unsupported: slice cannot be verified to be contiguous");
-        }
-        continue;
-      }
+} // namespace
 
-      // CASE #2 = size and offset are static.
-      // Verify that the slice can be represented as a contiguous slice of the
-      // src of the collapse_shape.
-      // Checking this is done on order of most internal dimensions first,
-      // so traversal is done in reverse order of the reassociation group.
-      // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
-      // ...,An] then we first find the size and offset for n...k+1 then for k
-      // and then for k-1...0.
-
-      // currentCollapsedsize and currentCollapsedOffset are initialized with
-      // the original collapsed size and offset and divided by the expanded
-      // shape size in each dimension as we go along the reassociation group.
-      // In essence we are spreading the original collapsed size and offset over
-      // the various expanded slice dimensions.
-      // The variables are used both to check the validity of the slice and to
-      // compute the expanded sizes and offsets.
-      int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
-      int64_t currentCollapsedOffset =
-          getConstantIntValue(collapsedOffset).value();
-
-      SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
-
-      ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
-                                                  reassocIndices.rend());
-      int64_t idx = 0;
-      int64_t reassocGroupSize = reassocIndices.size();
-
-      // First handle the trailing dimensions where the slice size should be
-      // equal to the tensor shape and the offset should be 0 (n...k+1).
-      for (; idx < reassocGroupSize; ++idx) {
-        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
-
-        if (currentCollapsedsize < expandedShapeSize)
-          break;
-
-        // We need to make sure that the slice size can be set to the shape size
-        // and the offset to 0.
-        if ((currentCollapsedsize % expandedShapeSize) != 0 ||
-            (currentCollapsedOffset % expandedShapeSize) != 0) {
-          return rewriter.notifyMatchFailure(
-              sliceOp, "unsupported: cannot be extracted as a contiguous slice "
-                       "of the src of the collapse_shape");
-        }
+LogicalResult mlir::tensor::getCollapsedExtractSliceInfo(
+    tensor::ExtractSliceOp sliceOp,
+    ArrayRef<ReassociationIndices> reassociation,
+    SmallVectorImpl<OpFoldResult> &collapsedOffsets,
+    SmallVectorImpl<OpFoldResult> &collapsedSizes,
+    SmallVectorImpl<OpFoldResult> &collapsedStrides, OpBuilder &b) {
+  if (!sliceOp.hasUnitStride()) {
+    return failure();
+  }
+
+  SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
 
-        groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
-        groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
+  if (static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) {
+    return failure();
+  }
 
-        currentCollapsedsize /= expandedShapeSize;
-        currentCollapsedOffset /= expandedShapeSize;
+  auto isZeroOffsetAndFullSize = [&](OpFoldResult offset,
+                                     OpFoldResult sliceSize, int64_t inputDim) {
+    if (!isZeroInteger(offset))
+      return false;
+    ValueBoundsConstraintSet::Variable inputSize(sliceOp.getSource(), inputDim);
+    FailureOr<bool> maybeEqual =
+        ValueBoundsConstraintSet::areEqual(sliceSize, inputSize);
+    return llvm::succeeded(maybeEqual) && maybeEqual.value();
+  };
+
+  // Check that the slice is contiguous within each reassociation group.
+  // The slice is contiguous only if after the first dimension where a non
+  // unit slice is taken, the slice size on all subsequent dimensions of the
+  // group is equal to the entire size of the dimension.
+  // Examples of contiguous slices:
+  //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
+  //   full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
+  // Examples of non contiguous slices:
+  //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
+  //   full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
+  for (const ReassociationIndices &indices : reassociation) {
+    int64_t i = 0;
+    int64_t e = indices.size();
+    // Find the first expanded dim after the first dim with non-unit extracted
+    // size.
+    for (; i < e; ++i) {
+      if (!isOneInteger(sizes[indices[i]])) {
+        // +1 to skip the first non-unit size dim.
+        i++;
+        break;
       }
+    }
+
+    // Verify that all subsequent dimensions extract the full size of the
+    // source tensor.
+    for (; i < e; ++i) {
+      int64_t expandedDim = indices[i];
+      if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
+                                   expandedDim)) {
+        return failure();
+      }
+    }
+  }
+
+  // The tensor.extract_slice before applying the pattern works on the result
+  // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
+  // referring to the state before applying the pattern are named with the
+  // prefix "expanded", and ones referring to the state after applying the
+  // pattern are named with the prefix "collapsed".
+  Location loc = sliceOp.getLoc();
+  SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+  SmallVector<OpFoldResult> expandedShape =
+      getMixedSizes(b, loc, sliceOp.getSource());
+
+  // Helper variables and function for accumulat...
[truncated]

@MaheshRavishankar MaheshRavishankar left a comment


Looks fine to me. I am assuming this is mostly a move of the logic and that github is messing up the diff?

@IanWood1
Contributor Author

Looks fine to me. I am assuming this is mostly a move of the logic and that github is messing up the diff?

The only real change here was to remove the calls to rewriter.notifyMatchFailure and replace them with return failure(); since the standalone functions use an OpBuilder instead of a PatternRewriter.

@IanWood1
Contributor Author

@hanhanW @nicolasvasilache would either of you mind reviewing?

@hanhanW
Contributor

hanhanW commented Aug 19, 2025

LGTM, maybe rename the title to [mlir][tensor][NFC] Refactor common methods for bubbling extract_slice op

@IanWood1 IanWood1 changed the title [mlir][tensor] Move extract_slice reshaping into two functions [mlir][tensor][NFC] Refactor common methods for bubbling extract_slice op Aug 19, 2025
@IanWood1 IanWood1 enabled auto-merge (squash) August 19, 2025 19:21
@IanWood1 IanWood1 merged commit 961b052 into llvm:main Aug 19, 2025
9 checks passed