From 72fbf710523e18d544b713652a6ee020f3063161 Mon Sep 17 00:00:00 2001
From: Ofri Frishman
Date: Wed, 19 Mar 2025 09:02:42 +0200
Subject: [PATCH 1/6] [MLIR] Bubble up tensor.extract_slice through
 tensor.collapse_shape

Add a pattern that bubbles up tensor.extract_slice through
tensor.collapse_shape. The pattern is registered in a pattern population
function that is used by the transform op
transform.apply_patterns.tensor.bubble_up_extract_slice and by the
transform op transform.structured.fuse as a cleanup pattern. This
pattern enables tiling and fusing op chains which contain
tensor.collapse_shape if added as a cleanup pattern of tile and fuse
utility. Without this pattern that would not be possible, as
tensor.collapse_shape does not implement the tiling interface. This is
an additional pattern to the one added in PR #126898
---
 .../Tensor/Transforms/ReshapePatterns.cpp     | 189 +++++++++++++++++-
 .../Dialect/Linalg/transform-op-fuse.mlir     |  50 +++++
 .../Tensor/bubble-up-extract-slice-op.mlir    | 153 ++++++++++++++
 3 files changed, 391 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index acedf51d0e240..efa4d10817e39 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -12,8 +12,10 @@
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/LogicalResult.h"
+#include
 
 using namespace mlir;
 using namespace mlir::tensor;
@@ -428,6 +430,190 @@ struct BubbleUpExpandShapeThroughExtractSlice
   }
 };
 
+/// Converts `tensor.collapse_shape(tensor.extract_slice)` to
+/// `tensor.extract_slice(tensor.collapse_shape)`.
+///
+/// For this transformation to be possible, the slice must be representable as a
+/// contiguous slice within each reassociation group of the src.
+///
+/// In case the size and offset extracted are static then this is possible if
+/// the following conditions are met:
+/// Let T be a tensor of shape [A0, A1, ..., An], and let S = [S0, S1, ..., Sn]
+/// be the shape of a desired slice. A slice of shape S can be extracted as a
+/// contiguous block of memory if and only if there exists an index k in {0, 1,
+/// ..., n} such that:
+///      S_i = 1 for all i < k (that is, all leading dimensions are singleton),
+///      1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly
+///                       one dimension),
+///      S_i = A_i for all i > k (that is, all trailing dimensions are preserved
+///      in full).
+/// In other words, the slice shape S must be of the form:
+///      [ 1, 1, ..., 1, Sk, Ak + 1, Ak + 2, ..., An ]
+///
+/// In case the size and/or offset extracted are dynamic then this is possible
+/// only if there is a single dimension in the reassociation group that has a
+/// size not equal to 1.
+/// In other words, the tensor shape must be of the form:
+///      [ 1, 1, ..., 1, A, 1, ..., 1 ]
+/// Note - it might be possible to enable this pattern for more cases when the
+/// size/offset are dynamic via performing an analysis of the possible values
+/// that could be given to the size/offset.
+///
+/// Example:
+/// The transformation is possible because each reassociation group can be
+/// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
+/// [20->10]).
+/// ```
+/// BEFORE:
+/// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
tensor<8x16x1x7x20xf32> to tensor<128x7x20xf32>
+/// %slice = tensor.extract_slice %collapse [0, 0, 0][32, %size, 10][1, 1, 1]
+///     tensor<128x7x20xf32> to tensor<32x?x10xf32>
+///
+/// AFTER:
+/// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
+///     [1, 1, 1, 1, 1] : tensor<8x16x1x7x20xf32> to tensor<2x16x1x?x10xf32>
+/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
+///     tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
+/// ```
+struct BubbleUpCollapseShapeThroughExtractSlice
+    : public OpRewritePattern {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseShapeOp =
+        sliceOp.getSource().getDefiningOp();
+    if (!collapseShapeOp)
+      return rewriter.notifyMatchFailure(
+          sliceOp,
+          "tensor.extract_slice source not produced by tensor.collapse_shape");
+
+    if (!sliceOp.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
+                   "be supported in this transformation.");
+    }
+
+    // The tensor.extract_slice before applying the pattern works on the result
+    // of the tensor.collapse_shape, so variables (i.e. inputs for
+    // ExtractSliceOp) referring to the state before applying the pattern are
+    // named with the prefix "collapsed", and ones referring to the state after
+    // applying the pattern are named with the prefix "expanded".
+    SmallVector collapsedOffsets = sliceOp.getMixedOffsets();
+    SmallVector collapsedSizes = sliceOp.getMixedSizes();
+
+    if (static_cast(sliceOp.getResultType().getRank()) !=
+        collapsedSizes.size())
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "unimplemented: rank reducing slice");
+
+    ArrayRef srcShape = collapseShapeOp.getSrcType().getShape();
+    SmallVector reassociationIndices =
+        collapseShapeOp.getReassociationIndices();
+
+    // Compute new offsets, sizes, and strides for tensor.extract_slice.
+    // The new tensor.extract_slice will work on a tensor that has a rank
+    // equal to the rank of the src of the collapse_shape. In each iteration of
+    // the loop, the offsets and sizes will be computed per reassociation group.
+    SmallVector expandedOffsets, expandedSizes;
+    SmallVector expandedStrides(srcShape.size(),
+                                rewriter.getIndexAttr(1));
+
+    for (auto [groupIdx, reassocIndices] :
+         enumerate(collapseShapeOp.getReassociationIndices())) {
+      OpFoldResult collapsedSize = collapsedSizes[groupIdx];
+      OpFoldResult collapsedOffset = collapsedOffsets[groupIdx];
+      // Case #1 - size and/or offset are dynamic.
+      // In this case, the slice can be represented as a contiguous slice only
+      // if there is a single dimension in the reassociation group that has a
+      // size not equal to 1.
+      if (isa(collapsedSize) || isa(collapsedOffset)) {
+        int nonUnitSizeCount = 0;
+        for (int64_t expandedShapeIdx : reassocIndices) {
+          if (srcShape[expandedShapeIdx] != 1) {
+            nonUnitSizeCount++;
+            expandedSizes.emplace_back(collapsedSize);
+            expandedOffsets.emplace_back(collapsedOffset);
+            continue;
+          }
+
+          expandedSizes.emplace_back(rewriter.getIndexAttr(1));
+          expandedOffsets.emplace_back(rewriter.getIndexAttr(0));
+        }
+
+        if (nonUnitSizeCount != 1) {
+          return rewriter.notifyMatchFailure(
+              sliceOp,
+              "unsupported: slice cannot be verified to be contiguous");
+        }
+        continue;
+      }
+
+      // Case #2 = size and offset are static.
+      // Verify that the slice can be represented as a contiguous slice of the
+      // src of the collapse_shape.
+      // Checking this must be done on order of most
+      // internal dimensions first, so traversal is done in reverse order of the
+      // reassociation group.
+ int64_t collapsedSizeValue = getConstantIntValue(collapsedSize).value(); + int64_t collapsedOffsetValue = + getConstantIntValue(collapsedOffset).value(); + + SmallVector groupExpandedSizes, groupExpandedOffsets; + + for (int64_t expandedShapeIdx : llvm::reverse(reassocIndices)) { + int64_t expandedShapeSize = srcShape[expandedShapeIdx]; + + // This is a dimension that slicing will occur on, so need to make sure + // that the slice size can be set to the shape size and the offset to 0. + if (collapsedSizeValue >= expandedShapeSize && + (collapsedSizeValue % expandedShapeSize != 0 || + collapsedOffsetValue % expandedShapeSize != 0)) { + return rewriter.notifyMatchFailure( + sliceOp, "unsupported: cannot be extracted as a contiguous slice " + "of the src of the collapse_shape"); + } + + int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize; + + // This is the dimension that slicing will occur along, so need to make + // sure that the slice size + offset will not exceed the shape size. + if (collapsedSizeValue < expandedShapeSize && + (collapsedSizeValue + offsetInDim) >= expandedShapeSize) { + return rewriter.notifyMatchFailure( + sliceOp, "unsupported: slice cannot be extracted as a contiguous " + "slice of the src of the collapse_shape"); + } + + groupExpandedSizes.emplace_back(rewriter.getIndexAttr( + std::min(collapsedSizeValue, expandedShapeSize))); + groupExpandedOffsets.emplace_back(rewriter.getIndexAttr(offsetInDim)); + + // Remove the size and offset of trailing dimensions from the size and + // offset of the slice. 
+ collapsedSizeValue /= expandedShapeSize; + collapsedSizeValue = std::max(collapsedSizeValue, 1); + collapsedOffsetValue /= expandedShapeSize; + } + + expandedSizes.append(groupExpandedSizes.rbegin(), + groupExpandedSizes.rend()); + expandedOffsets.append(groupExpandedOffsets.rbegin(), + groupExpandedOffsets.rend()); + } + + Value newSliceOp = rewriter.create( + collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets, + expandedSizes, expandedStrides); + rewriter.replaceOpWithNewOp( + sliceOp, sliceOp.getResultType(), newSliceOp, + collapseShapeOp.getReassociationIndices()); + + return success(); + } +}; + } // namespace void mlir::tensor::populateReassociativeReshapeFoldingPatterns( @@ -448,5 +634,6 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns( void mlir::tensor::populateBubbleUpExtractSliceOpPatterns( RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); } diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir index 9bcc125ce1ba9..441020f1cddfc 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir @@ -438,3 +438,53 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape( +// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}} -> (tensor<8x1800x32xf32>) { +// CHECK: %[[EXTRACT1:.*]] = tensor.extract_slice +// CHECK: %[[COLLAPSE1:.*]] = tensor.collapse_shape %[[EXTRACT1]] +// CHECK: %[[EXP1:.*]] = linalg.exp ins(%[[COLLAPSE1]] +func.func @bubble_up_extract_slice_through_collapse_shape(%0: tensor<1x8x1800x32xf32>) -> tensor<8x1800x32xf32> { + %expand = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<1x8x1800x32xf32> into tensor<8x1800x32xf32> + %empty = tensor.empty() : tensor<8x1800x32xf32> + %exp = linalg.exp ins(%expand : tensor<8x1800x32xf32>) outs(%empty : 
tensor<8x1800x32xf32>) -> tensor<8x1800x32xf32> + return %exp : tensor<8x1800x32xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %transformed, %loops:1 = transform.structured.fuse %0 [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true : + (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">) + transform.yield + } +} + + +// ----- + +// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer( +// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}} +// CHECK: %[[VAL_9:.*]] = tensor.extract_slice +// CHECK: %[[VAL_11:.*]] = linalg.abs ins(%[[VAL_9]] +// CHECK: %[[VAL_12:.*]] = tensor.collapse_shape %[[VAL_11]] +// CHECK: %[[VAL_14:.*]] = linalg.exp ins(%[[VAL_12]] +func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer(%0: tensor<1x8x1800x32xf32>) -> tensor<8x1800x32xf32> { + %empty1 = tensor.empty() : tensor<1x8x1800x32xf32> + %abs = linalg.abs ins(%0 : tensor<1x8x1800x32xf32>) outs(%empty1 : tensor<1x8x1800x32xf32>) -> tensor<1x8x1800x32xf32> + %expand = tensor.collapse_shape %abs [[0, 1], [2], [3]] : tensor<1x8x1800x32xf32> into tensor<8x1800x32xf32> + %empty2 = tensor.empty() : tensor<8x1800x32xf32> + %exp = linalg.exp ins(%expand : tensor<8x1800x32xf32>) outs(%empty2 : tensor<8x1800x32xf32>) -> tensor<8x1800x32xf32> + return %exp : tensor<8x1800x32xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %transformed, %loops:1 = transform.structured.fuse %0 [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true : + (!transform.any_op) -> (!transform.any_op, 
!transform.op<"scf.for">) + transform.yield + } +} diff --git a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir index 3900bc56f433d..d05bf1bf76f29 100644 --- a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir +++ b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir @@ -113,6 +113,159 @@ func.func @bubble_up_extract_slice_affine_apply_not_folded(%src: tensor<60xf32>, return %extract : tensor } +// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_single_reassoc_group( +// CHECK-SAME: %[[SRC:.*]]: tensor<6x5x2xf32>) -> tensor<1xf32> { +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0] [1, 1, 1] [1, 1, 1] +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]] +// CHECK: return %[[COLLAPSE]] +func.func @bubble_up_extract_slice_through_collapse_shape_single_reassoc_group(%src: tensor<6x5x2xf32>) -> tensor<1xf32> { + %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<6x5x2xf32> into tensor<60xf32> + %extract = tensor.extract_slice %collapse[0][1][1] : tensor<60xf32> to tensor<1xf32> + return %extract : tensor<1xf32> +} + +// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_multiple_reassoc_group( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<6x5x3x10xf32>) -> tensor<15x10xf32> { +// CHECK: %[[VAL_1:.*]] = tensor.extract_slice %[[VAL_0]][1, 0, 1, 0] [3, 5, 1, 10] [1, 1, 1, 1] +// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0, 1], [2, 3]] +// CHECK: return %[[VAL_2]] +func.func @bubble_up_extract_slice_through_collapse_shape_multiple_reassoc_group(%src: tensor<6x5x3x10xf32>) -> tensor<15x10xf32> { + %collapse = tensor.collapse_shape %src [[0, 1], [2, 3]] : tensor<6x5x3x10xf32> into tensor<30x30xf32> + %extract = tensor.extract_slice %collapse[5, 10][15, 10][1, 1] : tensor<30x30xf32> to tensor<15x10xf32> + return %extract : tensor<15x10xf32> +} + +// CHECK-LABEL: func.func 
@bubble_up_extract_slice_through_collapse_shape_offset_on_leading_dim( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<6x5x2xf32>) -> tensor<4xf32> { +// CHECK: %[[VAL_1:.*]] = tensor.extract_slice %[[VAL_0]][2, 0, 0] [1, 2, 2] [1, 1, 1] +// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0, 1, 2]] +// CHECK: return %[[VAL_2]] +func.func @bubble_up_extract_slice_through_collapse_shape_offset_on_leading_dim(%src: tensor<6x5x2xf32>) -> tensor<4xf32> { + %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<6x5x2xf32> into tensor<60xf32> + %extract = tensor.extract_slice %collapse[20][4][1] : tensor<60xf32> to tensor<4xf32> + return %extract : tensor<4xf32> +} + +// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size( +// CHECK-SAME: %[[SRC:.*]]: tensor<1x5x1xf32>, +// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor { +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0] [1, %[[SIZE]], 1] [1, 1, 1] +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]] +// CHECK: return %[[COLLAPSE]] +func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size(%src: tensor<1x5x1xf32>, %size : index) -> tensor { + %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<1x5x1xf32> into tensor<5xf32> + %extract = tensor.extract_slice %collapse[0][%size][1] : tensor<5xf32> to tensor + return %extract : tensor +} + +// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size_and_src( +// CHECK-SAME: %[[SRC:.*]]: tensor<1x?x1xf32>, +// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor { +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0] [1, %[[SIZE]], 1] [1, 1, 1] +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]] +// CHECK: return %[[COLLAPSE]] +func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size_and_src(%src: tensor<1x?x1xf32>, %size : index) -> tensor { + %collapse = tensor.collapse_shape %src [[0, 1, 2]] : 
tensor<1x?x1xf32> into tensor + %extract = tensor.extract_slice %collapse[0][%size][1] : tensor to tensor + return %extract : tensor +} + + +// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset( +// CHECK-SAME: %[[SRC:.*]]: tensor<1x5x1xf32>, +// CHECK-SAME: %[[OFFSET:.*]]: index) -> tensor<3xf32> { +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, %[[OFFSET]], 0] [1, 3, 1] [1, 1, 1] +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]] +// CHECK: return %[[COLLAPSE]] +func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset(%src: tensor<1x5x1xf32>, %offset : index) -> tensor<3xf32> { + %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<1x5x1xf32> into tensor<5xf32> + %extract = tensor.extract_slice %collapse[%offset][3][1] : tensor<5xf32> to tensor<3xf32> + return %extract : tensor<3xf32> +} + +// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset_and_size( +// CHECK-SAME: %[[SRC:.*]]: tensor<14x1xf32>, +// CHECK-SAME: %[[OFFSET:.*]]: index, +// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor { +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]]{{\[}}%[[OFFSET]], 0] {{\[}}%[[SIZE]], 1] [1, 1] +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1]] +// CHECK: return %[[COLLAPSE]] +func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset_and_size(%src: tensor<14x1xf32>, %offset : index, %size : index) -> tensor { + %collapse = tensor.collapse_shape %src [[0, 1]] : tensor<14x1xf32> into tensor<14xf32> + %extract = tensor.extract_slice %collapse[%offset][%size][1] : tensor<14xf32> to tensor + return %extract : tensor +} + +// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_and_static_groups( +// CHECK-SAME: %[[SRC:.*]]: tensor<5x10x1x1x40xf32>, +// CHECK-SAME: %[[OFFSET:.*]]: index, +// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<20x?xf32> { +// CHECK: 
%[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][1, 0, 0, 0, %[[OFFSET]]] [2, 10, 1, 1, %[[SIZE]]] [1, 1, 1, 1, 1] +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1], [2, 3, 4]] +// CHECK: return %[[COLLAPSE]] +func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_and_static_groups(%src: tensor<5x10x1x1x40xf32>, %offset : index, %size : index) -> tensor<20x?xf32> { + %collapse = tensor.collapse_shape %src [[0, 1], [2, 3, 4]] : tensor<5x10x1x1x40xf32> into tensor<50x40xf32> + %extract = tensor.extract_slice %collapse[10, %offset][20, %size][1, 1] : tensor<50x40xf32> to tensor<20x?xf32> + return %extract : tensor<20x?xf32> +} + +// This is a case where the bubble up cannot occur because the contiguous size extracted from the collapsed +// shape cannot be defined as a contiguous size in the expanded shape due to size extracted not being suited +// for the expanded shape. +// CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_1( +// CHECK-SAME: %[[SRC:.*]]: tensor<2x3x10xf32>) -> tensor<15xf32> { +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice +func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_1(%src: tensor<2x3x10xf32>) -> tensor<15xf32> { + %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into tensor<60xf32> + %extract = tensor.extract_slice %collapse[0][15][1] : tensor<60xf32> to tensor<15xf32> + return %extract : tensor<15xf32> +} + +// This is a case where the bubble up cannot occur because the contiguous size extracted from the collapsed +// shape cannot be defined as a contiguous size in the expanded shape due to an unsuitable offset even though +// the size extracted is suited for the expanded shape. 
+// CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_2( +// CHECK-SAME: %[[SRC:.*]]: tensor<2x3x10xf32>) -> tensor<20xf32> { +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice +func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_2(%src: tensor<2x3x10xf32>) -> tensor<20xf32> { + %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into tensor<60xf32> + %extract = tensor.extract_slice %collapse[20][20][1] : tensor<60xf32> to tensor<20xf32> + return %extract : tensor<20xf32> +} + +// CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_collapse_shape_on_stride( +// CHECK-SAME: %[[SRC:.*]]: tensor<2x3x10xf32>) -> tensor<5xf32> { +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice +func.func @no_bubble_up_extract_slice_through_collapse_shape_on_stride(%src: tensor<2x3x10xf32>) -> tensor<5xf32> { + %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into tensor<60xf32> + %extract = tensor.extract_slice %collapse[0][5][2] : tensor<60xf32> to tensor<5xf32> + return %extract : tensor<5xf32> +} + +// CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_collapse_shape_on_rank_reducing( +// CHECK-SAME: %[[SRC:.*]]: tensor<6x5x2x1xf32>) -> tensor<1xf32> { +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice +func.func @no_bubble_up_extract_slice_through_collapse_shape_on_rank_reducing(%src: tensor<6x5x2x1xf32>) -> tensor<1xf32> { + %collapse = tensor.collapse_shape %src [[0, 1, 2], [3]] : tensor<6x5x2x1xf32> into tensor<60x1xf32> + %extract = tensor.extract_slice %collapse[0, 0][1, 1][1, 1] : tensor<60x1xf32> to tensor<1xf32> + return %extract : tensor<1xf32> +} + +// CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_collapse_shape_on_unsupported_dynamic( +// CHECK-SAME: %[[SRC:.*]]: tensor<1x5x2xf32>, 
+// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor { +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice +func.func @no_bubble_up_extract_slice_through_collapse_shape_on_unsupported_dynamic(%src: tensor<1x5x2xf32>, %size : index) -> tensor { + %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<1x5x2xf32> into tensor<10xf32> + %extract = tensor.extract_slice %collapse[0][%size][1] : tensor<10xf32> to tensor + return %extract : tensor +} + module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) { %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func"> From c0291d068b66e6b450521813ebb2ad9a3a4c75e2 Mon Sep 17 00:00:00 2001 From: Ofri Frishman Date: Sun, 23 Mar 2025 12:08:23 +0200 Subject: [PATCH 2/6] Updates for code review --- .../Tensor/Transforms/ReshapePatterns.cpp | 13 ++++---- .../Dialect/Linalg/transform-op-fuse.mlir | 14 ++++---- .../Tensor/bubble-up-extract-slice-op.mlir | 33 +++++++++++++++---- 3 files changed, 40 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp index efa4d10817e39..cc73011151b17 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -15,7 +15,6 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include "llvm/Support/LogicalResult.h" -#include using namespace mlir; using namespace mlir::tensor; @@ -533,13 +532,13 @@ struct BubbleUpCollapseShapeThroughExtractSlice for (int64_t expandedShapeIdx : reassocIndices) { if (srcShape[expandedShapeIdx] != 1) { nonUnitSizeCount++; - expandedSizes.emplace_back(collapsedSize); - expandedOffsets.emplace_back(collapsedOffset); + expandedSizes.push_back(collapsedSize); + 
expandedOffsets.push_back(collapsedOffset); continue; } - expandedSizes.emplace_back(rewriter.getIndexAttr(1)); - expandedOffsets.emplace_back(rewriter.getIndexAttr(0)); + expandedSizes.push_back(rewriter.getIndexAttr(1)); + expandedOffsets.push_back(rewriter.getIndexAttr(0)); } if (nonUnitSizeCount != 1) { @@ -586,9 +585,9 @@ struct BubbleUpCollapseShapeThroughExtractSlice "slice of the src of the collapse_shape"); } - groupExpandedSizes.emplace_back(rewriter.getIndexAttr( + groupExpandedSizes.push_back(rewriter.getIndexAttr( std::min(collapsedSizeValue, expandedShapeSize))); - groupExpandedOffsets.emplace_back(rewriter.getIndexAttr(offsetInDim)); + groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim)); // Remove the size and offset of trailing dimensions from the size and // offset of the slice. diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir index 441020f1cddfc..d7339fa3c0be4 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir @@ -443,9 +443,9 @@ module attributes {transform.with_named_sequence} { // CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape( // CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}} -> (tensor<8x1800x32xf32>) { -// CHECK: %[[EXTRACT1:.*]] = tensor.extract_slice -// CHECK: %[[COLLAPSE1:.*]] = tensor.collapse_shape %[[EXTRACT1]] -// CHECK: %[[EXP1:.*]] = linalg.exp ins(%[[COLLAPSE1]] +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] +// CHECK: %[[EXP1:.*]] = linalg.exp ins(%[[COLLAPSE]] func.func @bubble_up_extract_slice_through_collapse_shape(%0: tensor<1x8x1800x32xf32>) -> tensor<8x1800x32xf32> { %expand = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<1x8x1800x32xf32> into tensor<8x1800x32xf32> %empty = tensor.empty() : tensor<8x1800x32xf32> @@ -467,10 +467,10 @@ module attributes {transform.with_named_sequence} { // 
CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer( // CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}} -// CHECK: %[[VAL_9:.*]] = tensor.extract_slice -// CHECK: %[[VAL_11:.*]] = linalg.abs ins(%[[VAL_9]] -// CHECK: %[[VAL_12:.*]] = tensor.collapse_shape %[[VAL_11]] -// CHECK: %[[VAL_14:.*]] = linalg.exp ins(%[[VAL_12]] +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice +// CHECK: %[[ABS:.*]] = linalg.abs ins(%[[EXTRACT]] +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[ABS]] +// CHECK: %[[EXP:.*]] = linalg.exp ins(%[[COLLAPSE]] func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer(%0: tensor<1x8x1800x32xf32>) -> tensor<8x1800x32xf32> { %empty1 = tensor.empty() : tensor<1x8x1800x32xf32> %abs = linalg.abs ins(%0 : tensor<1x8x1800x32xf32>) outs(%empty1 : tensor<1x8x1800x32xf32>) -> tensor<1x8x1800x32xf32> diff --git a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir index d05bf1bf76f29..c0755d7125091 100644 --- a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir +++ b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir @@ -1,5 +1,15 @@ // RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s +///---------------------------------------------------------------------------------------- +/// [Pattern: BubbleUpExpandShapeThroughExtractSlice] +/// +/// IN: tensor.expand_shape(tensor.extract_slice) +/// OUT:tensor.extract_slice(tensor.expand_shape) +/// +/// Note: tensor.extract_slice is bubbled up to be before tensor.expand_shape. +/// Some tests are negative tests for cases where the pattern cannot be applied. 
+///---------------------------------------------------------------------------------------- + // CHECK-LABEL: func.func @bubble_up_extract_slice_through_expand_shape( // CHECK-SAME: %[[SRC:.*]]: tensor<60xf32>) -> tensor<1x1x5xf32> { // CHECK: %[[C1:.+]] = arith.constant 5 : index @@ -113,6 +123,16 @@ func.func @bubble_up_extract_slice_affine_apply_not_folded(%src: tensor<60xf32>, return %extract : tensor } +///---------------------------------------------------------------------------------------- +/// [Pattern: BubbleUpCollapseShapeThroughExtractSlice] +/// +/// IN: tensor.collapse_shape(tensor.extract_slice) +/// OUT:tensor.extract_slice(tensor.collapse_shape) +/// +/// Note: tensor.extract_slice is bubbled up to be before tensor.collapse_shape. +/// Some tests are negative tests for cases where the pattern cannot be applied. +///---------------------------------------------------------------------------------------- + // CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_single_reassoc_group( // CHECK-SAME: %[[SRC:.*]]: tensor<6x5x2xf32>) -> tensor<1xf32> { // CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0] [1, 1, 1] [1, 1, 1] @@ -209,9 +229,13 @@ func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_and_static_gro return %extract : tensor<20x?xf32> } -// This is a case where the bubble up cannot occur because the contiguous size extracted from the collapsed -// shape cannot be defined as a contiguous size in the expanded shape due to size extracted not being suited -// for the expanded shape. +/// The 2 following tests are cases where the bubble up cannot occur because the contiguous size extracted +/// from the collapsed shape cannot be expressed via a single extract_slice op. +/// In the first test it is because the size extracted cannot be expressed as a slice +/// of the form [ 1, 1, ..., 1, Sk, Ak + 1, Ak + 2, ...,An ] (see the pattern documentation for more details). 
+/// In the second test, the size can be expressed as the required form, but the offset is such that the pattern +/// cannot be applied. + // CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_1( // CHECK-SAME: %[[SRC:.*]]: tensor<2x3x10xf32>) -> tensor<15xf32> { // CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape @@ -222,9 +246,6 @@ func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_1 return %extract : tensor<15xf32> } -// This is a case where the bubble up cannot occur because the contiguous size extracted from the collapsed -// shape cannot be defined as a contiguous size in the expanded shape due to an unsuitable offset even though -// the size extracted is suited for the expanded shape. // CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_2( // CHECK-SAME: %[[SRC:.*]]: tensor<2x3x10xf32>) -> tensor<20xf32> { // CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape From 72b0be3bd808f94bcf062cdcfee684be553bf507 Mon Sep 17 00:00:00 2001 From: Ofri Frishman Date: Mon, 24 Mar 2025 12:13:56 +0200 Subject: [PATCH 3/6] Updates for CR --- .../Tensor/Transforms/ReshapePatterns.cpp | 111 +++++++++++++----- .../Dialect/Linalg/transform-op-fuse.mlir | 1 - 2 files changed, 80 insertions(+), 32 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp index cc73011151b17..6ec7c58f85b4e 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -429,17 +429,19 @@ struct BubbleUpExpandShapeThroughExtractSlice } }; -/// Converts `tensor.collapse_shape(tensor.extract_slice)` to -/// `tensor.extract_slice(tensor.collapse_shape)`. +/// Converts `tensor.extract_slice(tensor.collapse_shape)` to +/// `tensor.collapse_shape(tensor.extract_slice)`. 
/// -/// For this transformation to be possible, the slice must be representable as a -/// contiguous slice within each reassociation group of the src. +/// For this transformation to be possible - after bubbling up, the extraction +/// of the contiguous slice must be representable as a single slice obtained via +/// tensor.extract_slice within each reassociation group of the src. /// /// In case the size and offset extracted are static then this is possible if -/// the following conditions are met: -/// Let T be a tensor of shape [A0, A1, ..., An], and let S = [S0, S1, ..., Sn] -/// be the shape of a desired slice. A slice of shape S can be extracted as a -/// contiguous block of memory if and only if there exists an index k in {0, 1, +/// the following conditions are met within each reassociation group: +/// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the +/// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the +/// shape of a desired slice. A slice of shape S can be extracted as a +/// contiguous span of elements if and only if there exists an index k in {0, 1, /// ..., n} such that: /// S_i = 1 for all i < k (that is, all leading dimensions are singleton), /// 1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly @@ -475,6 +477,31 @@ struct BubbleUpExpandShapeThroughExtractSlice /// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ... /// tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32> /// ``` +/// +/// Negative example: +/// The transformation is not possible because we cannot use a single slice to +/// represent the reassociation group [2x3x10->???]. If we would want the +/// collapse to be after the extraction, we would need to extract multiple +/// slices and concat them together. 
+/// ```
+/// %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into
+/// tensor<60xf32> %extract = tensor.extract_slice %collapse[0][15][1] :
+/// tensor<60xf32> to tensor<15xf32>
+/// ```
+/// If we wanted the collapse to be after the extraction, a possible
+/// alternative transformation would be to extract multiple slices and concat
+/// them together:
+/// ```
+/// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] :
+/// tensor<2x3x10xf32> to tensor<1x1x10xf32>
+/// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] :
+/// tensor<2x3x10xf32> to tensor<1x1x5xf32>
+/// %concat = tosa.concat %extract_1, %extract_2 {axis = 0 : i32} :
+/// (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32>
+/// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32>
+/// to tensor<15xf32>
+/// ```
+/// But this is not the intended purpose of the transformation.
 struct BubbleUpCollapseShapeThroughExtractSlice
     : public OpRewritePattern {
   using OpRewritePattern::OpRewritePattern;
@@ -552,47 +579,69 @@ struct BubbleUpCollapseShapeThroughExtractSlice
     // Case #2 = size and offset are static.
     // Verify that the slice can be represented as a contiguous slice of the
     // src of the collapse_shape.
-    // Checking this must be done on order of most
-    // internal dimensions first, so traversal is done in reverse order of the
-    // reassociation group.
+    // Checking this is done on order of most internal dimensions first,
+    // so traversal is done in reverse order of the reassociation group.
+    // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
+    // ...,An] then we first find the size and offset for n...k+1 then for k
+    // and then for k-1...0.
int64_t collapsedSizeValue = getConstantIntValue(collapsedSize).value(); int64_t collapsedOffsetValue = getConstantIntValue(collapsedOffset).value(); SmallVector groupExpandedSizes, groupExpandedOffsets; - for (int64_t expandedShapeIdx : llvm::reverse(reassocIndices)) { - int64_t expandedShapeSize = srcShape[expandedShapeIdx]; + ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(), + reassocIndices.rend()); + int64_t idx = 0; + int64_t reassocGroupSize = reassocIndices.size(); + + // First handle the trailing dimensions where the slice size should be + // equal to the tensor shape and the offset should be 0 (n...k+1). + for (; idx < reassocGroupSize; ++idx) { + int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; - // This is a dimension that slicing will occur on, so need to make sure - // that the slice size can be set to the shape size and the offset to 0. - if (collapsedSizeValue >= expandedShapeSize && - (collapsedSizeValue % expandedShapeSize != 0 || - collapsedOffsetValue % expandedShapeSize != 0)) { + if (collapsedSizeValue < expandedShapeSize) + break; + + // We need to make sure that the slice size can be set to the shape size + // and the offset to 0. + if ((collapsedSizeValue % expandedShapeSize) != 0 || + (collapsedOffsetValue % expandedShapeSize) != 0) return rewriter.notifyMatchFailure( sliceOp, "unsupported: cannot be extracted as a contiguous slice " "of the src of the collapse_shape"); - } - int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize; + groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize)); + groupExpandedOffsets.push_back(rewriter.getIndexAttr(0)); + + collapsedSizeValue /= expandedShapeSize; + collapsedOffsetValue /= expandedShapeSize; + } - // This is the dimension that slicing will occur along, so need to make - // sure that the slice size + offset will not exceed the shape size. 
- if (collapsedSizeValue < expandedShapeSize && - (collapsedSizeValue + offsetInDim) >= expandedShapeSize) { + // Now handle the first dim where slicing occurs on (k). + if (idx < reassocGroupSize) { + int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; + int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize; + // We need to make sure that the slice size in this dim + offset will + // not exceed the shape size. + if ((collapsedSizeValue + offsetInDim) >= expandedShapeSize) return rewriter.notifyMatchFailure( sliceOp, "unsupported: slice cannot be extracted as a contiguous " "slice of the src of the collapse_shape"); - } - groupExpandedSizes.push_back(rewriter.getIndexAttr( - std::min(collapsedSizeValue, expandedShapeSize))); + groupExpandedSizes.push_back(rewriter.getIndexAttr(collapsedSizeValue)); groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim)); - // Remove the size and offset of trailing dimensions from the size and - // offset of the slice. - collapsedSizeValue /= expandedShapeSize; - collapsedSizeValue = std::max(collapsedSizeValue, 1); + collapsedOffsetValue /= expandedShapeSize; + } + + // Now handle the leading dimensions where the slice size is equal to 1 + // (k-1...0). 
+ for (idx++; idx < reassocGroupSize; ++idx) { + int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; + int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize; + groupExpandedSizes.push_back(rewriter.getIndexAttr(1)); + groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim)); collapsedOffsetValue /= expandedShapeSize; } diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir index d7339fa3c0be4..962858076db93 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir @@ -462,7 +462,6 @@ module attributes {transform.with_named_sequence} { } } - // ----- // CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer( From 1aaf3c907272a1bf1d4a2a870e93565079390d07 Mon Sep 17 00:00:00 2001 From: Ofri Frishman Date: Tue, 25 Mar 2025 10:45:32 +0200 Subject: [PATCH 4/6] Fix lit test checker variable names --- .../Tensor/bubble-up-extract-slice-op.mlir | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir index c0755d7125091..34128d6a5ec8b 100644 --- a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir +++ b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir @@ -145,10 +145,10 @@ func.func @bubble_up_extract_slice_through_collapse_shape_single_reassoc_group(% } // CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_multiple_reassoc_group( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<6x5x3x10xf32>) -> tensor<15x10xf32> { -// CHECK: %[[VAL_1:.*]] = tensor.extract_slice %[[VAL_0]][1, 0, 1, 0] [3, 5, 1, 10] [1, 1, 1, 1] -// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0, 1], [2, 3]] -// CHECK: return %[[VAL_2]] +// CHECK-SAME: %[[SRC:.*]]: tensor<6x5x3x10xf32>) -> tensor<15x10xf32> { +// CHECK: %[[EXTRACT:.*]] = 
tensor.extract_slice %[[SRC]][1, 0, 1, 0] [3, 5, 1, 10] [1, 1, 1, 1] +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1], [2, 3]] +// CHECK: return %[[COLLAPSE]] func.func @bubble_up_extract_slice_through_collapse_shape_multiple_reassoc_group(%src: tensor<6x5x3x10xf32>) -> tensor<15x10xf32> { %collapse = tensor.collapse_shape %src [[0, 1], [2, 3]] : tensor<6x5x3x10xf32> into tensor<30x30xf32> %extract = tensor.extract_slice %collapse[5, 10][15, 10][1, 1] : tensor<30x30xf32> to tensor<15x10xf32> @@ -156,10 +156,10 @@ func.func @bubble_up_extract_slice_through_collapse_shape_multiple_reassoc_group } // CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_offset_on_leading_dim( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<6x5x2xf32>) -> tensor<4xf32> { -// CHECK: %[[VAL_1:.*]] = tensor.extract_slice %[[VAL_0]][2, 0, 0] [1, 2, 2] [1, 1, 1] -// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0, 1, 2]] -// CHECK: return %[[VAL_2]] +// CHECK-SAME: %[[SRC:.*]]: tensor<6x5x2xf32>) -> tensor<4xf32> { +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][2, 0, 0] [1, 2, 2] [1, 1, 1] +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]] +// CHECK: return %[[COLLAPSE]] func.func @bubble_up_extract_slice_through_collapse_shape_offset_on_leading_dim(%src: tensor<6x5x2xf32>) -> tensor<4xf32> { %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<6x5x2xf32> into tensor<60xf32> %extract = tensor.extract_slice %collapse[20][4][1] : tensor<60xf32> to tensor<4xf32> From 5845db6a8de787a5df0c684f85286ae3a4ddf4f4 Mon Sep 17 00:00:00 2001 From: Ofri Frishman Date: Tue, 25 Mar 2025 16:08:09 +0200 Subject: [PATCH 5/6] Additional clarifications --- .../Tensor/Transforms/ReshapePatterns.cpp | 44 ++++++++++++------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp index 
6ec7c58f85b4e..0a7fcba7a71cd 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -550,7 +550,7 @@ struct BubbleUpCollapseShapeThroughExtractSlice enumerate(collapseShapeOp.getReassociationIndices())) { OpFoldResult collapsedSize = collapsedSizes[groupIdx]; OpFoldResult collapsedOffset = collapsedOffsets[groupIdx]; - // Case #1 - size and/or offset are dynamic. + // CASE #1 - size and/or offset are dynamic. // In this case, the slice can be represented as a contiguous slice only // if there is a single dimension in the reassociation group that has a // size not equal to 1. @@ -576,7 +576,7 @@ struct BubbleUpCollapseShapeThroughExtractSlice continue; } - // Case #2 = size and offset are static. + // CASE #2 = size and offset are static. // Verify that the slice can be represented as a contiguous slice of the // src of the collapse_shape. // Checking this is done on order of most internal dimensions first, @@ -584,8 +584,16 @@ struct BubbleUpCollapseShapeThroughExtractSlice // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2, // ...,An] then we first find the size and offset for n...k+1 then for k // and then for k-1...0. - int64_t collapsedSizeValue = getConstantIntValue(collapsedSize).value(); - int64_t collapsedOffsetValue = + + // currentCollapsedsize and currentCollapsedOffset are initialized with + // the original collapsed size and offset and divided by the expanded + // shape size in each dimension as we go along the reassociation group. + // In essence we are spreading the original collapsed size and offset over + // the various expanded slice dimensions. + // The variables are used both to check the validity of the slice and to + // compute the expanded sizes and offsets. 
+ int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value(); + int64_t currentCollapsedOffset = getConstantIntValue(collapsedOffset).value(); SmallVector groupExpandedSizes, groupExpandedOffsets; @@ -600,13 +608,13 @@ struct BubbleUpCollapseShapeThroughExtractSlice for (; idx < reassocGroupSize; ++idx) { int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; - if (collapsedSizeValue < expandedShapeSize) + if (currentCollapsedsize < expandedShapeSize) break; // We need to make sure that the slice size can be set to the shape size // and the offset to 0. - if ((collapsedSizeValue % expandedShapeSize) != 0 || - (collapsedOffsetValue % expandedShapeSize) != 0) + if ((currentCollapsedsize % expandedShapeSize) != 0 || + (currentCollapsedOffset % expandedShapeSize) != 0) return rewriter.notifyMatchFailure( sliceOp, "unsupported: cannot be extracted as a contiguous slice " "of the src of the collapse_shape"); @@ -614,35 +622,41 @@ struct BubbleUpCollapseShapeThroughExtractSlice groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize)); groupExpandedOffsets.push_back(rewriter.getIndexAttr(0)); - collapsedSizeValue /= expandedShapeSize; - collapsedOffsetValue /= expandedShapeSize; + currentCollapsedsize /= expandedShapeSize; + currentCollapsedOffset /= expandedShapeSize; } // Now handle the first dim where slicing occurs on (k). if (idx < reassocGroupSize) { int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; - int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize; + int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize; // We need to make sure that the slice size in this dim + offset will // not exceed the shape size. 
- if ((collapsedSizeValue + offsetInDim) >= expandedShapeSize) + if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) return rewriter.notifyMatchFailure( sliceOp, "unsupported: slice cannot be extracted as a contiguous " "slice of the src of the collapse_shape"); - groupExpandedSizes.push_back(rewriter.getIndexAttr(collapsedSizeValue)); + groupExpandedSizes.push_back( + rewriter.getIndexAttr(currentCollapsedsize)); groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim)); - collapsedOffsetValue /= expandedShapeSize; + currentCollapsedOffset /= expandedShapeSize; } // Now handle the leading dimensions where the slice size is equal to 1 // (k-1...0). + // The size for these dimensions must be 1 because of how we constructed + // the slice size of the expanded shape. We spread the original collapsed + // size over the expanded shape sizes until we reached dimension k where + // the remaining size was smaller than the expanded shape size, and spread + // the remaining size on it. So, now we are left with only 1s. 
for (idx++; idx < reassocGroupSize; ++idx) { int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; - int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize; + int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize; groupExpandedSizes.push_back(rewriter.getIndexAttr(1)); groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim)); - collapsedOffsetValue /= expandedShapeSize; + currentCollapsedOffset /= expandedShapeSize; } expandedSizes.append(groupExpandedSizes.rbegin(), From 3c69390e03c69b8137da78e5a660d70273be5df2 Mon Sep 17 00:00:00 2001 From: Ofri Frishman Date: Wed, 2 Apr 2025 13:20:55 +0300 Subject: [PATCH 6/6] CR fixes --- .../Tensor/Transforms/ReshapePatterns.cpp | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp index 0a7fcba7a71cd..eed44e60d6591 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -510,10 +510,11 @@ struct BubbleUpCollapseShapeThroughExtractSlice PatternRewriter &rewriter) const override { auto collapseShapeOp = sliceOp.getSource().getDefiningOp(); - if (!collapseShapeOp) + if (!collapseShapeOp) { return rewriter.notifyMatchFailure( sliceOp, "tensor.extract_slice source not produced by tensor.collapse_shape"); + } if (!sliceOp.hasUnitStride()) { return rewriter.notifyMatchFailure( @@ -530,9 +531,10 @@ struct BubbleUpCollapseShapeThroughExtractSlice SmallVector collapsedSizes = sliceOp.getMixedSizes(); if (static_cast(sliceOp.getResultType().getRank()) != - collapsedSizes.size()) + collapsedSizes.size()) { return rewriter.notifyMatchFailure(sliceOp, "unimplemented: rank reducing slice"); + } ArrayRef srcShape = collapseShapeOp.getSrcType().getShape(); SmallVector reassociationIndices = @@ -546,10 +548,9 @@ struct BubbleUpCollapseShapeThroughExtractSlice SmallVector 
expandedStrides(srcShape.size(), rewriter.getIndexAttr(1)); - for (auto [groupIdx, reassocIndices] : - enumerate(collapseShapeOp.getReassociationIndices())) { - OpFoldResult collapsedSize = collapsedSizes[groupIdx]; - OpFoldResult collapsedOffset = collapsedOffsets[groupIdx]; + for (auto [collapsedSize, collapsedOffset, reassocIndices] : + llvm::zip_equal(collapsedSizes, collapsedOffsets, + collapseShapeOp.getReassociationIndices())) { // CASE #1 - size and/or offset are dynamic. // In this case, the slice can be represented as a contiguous slice only // if there is a single dimension in the reassociation group that has a @@ -614,10 +615,11 @@ struct BubbleUpCollapseShapeThroughExtractSlice // We need to make sure that the slice size can be set to the shape size // and the offset to 0. if ((currentCollapsedsize % expandedShapeSize) != 0 || - (currentCollapsedOffset % expandedShapeSize) != 0) + (currentCollapsedOffset % expandedShapeSize) != 0) { return rewriter.notifyMatchFailure( sliceOp, "unsupported: cannot be extracted as a contiguous slice " "of the src of the collapse_shape"); + } groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize)); groupExpandedOffsets.push_back(rewriter.getIndexAttr(0)); @@ -632,10 +634,11 @@ struct BubbleUpCollapseShapeThroughExtractSlice int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize; // We need to make sure that the slice size in this dim + offset will // not exceed the shape size. - if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) + if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) { return rewriter.notifyMatchFailure( sliceOp, "unsupported: slice cannot be extracted as a contiguous " "slice of the src of the collapse_shape"); + } groupExpandedSizes.push_back( rewriter.getIndexAttr(currentCollapsedsize));