[MLIR] Add pattern to bubble up tensor.extract_slice #126898
@@ -6,10 +6,14 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"

using namespace mlir;
using namespace mlir::tensor;

@@ -210,6 +214,214 @@ struct BubbleUpExpandThroughParallelCollapse
  }
};

/// Converts `tensor.extract_slice(tensor.expand_shape)` to
/// `tensor.expand_shape(tensor.extract_slice)`.
///
/// For this transformation to be possible, the slice must be fully contiguous
/// within each reassociation group of the expand_shape. A slice is defined as
/// fully contiguous within a reassociation group if, after flattening the
/// reassociation group to a single 1D range, the slice taken out of the group
/// can be defined as a single contiguous subrange within that range.
///
/// Rank-reducing slices are not supported.
///
/// Example:
/// The transformation is possible because each reassociation group has a
/// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
/// ```
/// BEFORE:
/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
///     tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
/// %slice = tensor.extract_slice %reshape ...
///     tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
///
/// AFTER:
/// %slice = tensor.extract_slice %in ...
///     tensor<8x16x32xf32> to tensor<8x5x4xf32>
/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
///     tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
/// ```
///
/// Note - this pattern could be extended to be a swap pattern between
/// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
/// implemented only as a bubble-up pattern for `tensor.extract_slice`.
struct BubbleUpExpandShapeThroughExtractSlice

Contributor: I'm coming from #153675, which performs the refactoring. I'm sorry that I did not ask this question in the first place. May I ask why you name it as
Contributor (author): You are asking about the order of operations in the name, and not about camel case vs. "_", right?
Contributor: Yes. I also checked other patterns upstream, but some of them were created by me. I'm not a native speaker, so I'm not sure whether I made a mistake. Thanks for checking; I'll prepare a PR after we land the refactoring change.

    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp =
        sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();

    if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
                                                 rewriter)
            .failed())
      return failure();

    // The tensor.extract_slice before applying the pattern works on the
    // result of the tensor.expand_shape, so variables (i.e. inputs for
    // ExtractSliceOp) referring to the state before applying the pattern are
    // named with the prefix "expanded", and ones referring to the state after
    // applying the pattern are named with the prefix "collapsed".
    SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
    SmallVector<OpFoldResult> expandedShape =
        getMixedValues(expandShapeOp.getStaticOutputShape(),
                       expandShapeOp.getOutputShape(), rewriter);

    // Helper variables and function for accumulating the size values.
    Location loc = expandShapeOp->getLoc();
    AffineExpr d0, d1, d2;
    bindDims(rewriter.getContext(), d0, d1, d2);
    // Multiply two integers.
    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
      auto mulMap = AffineMap::get(2, 0, {d0 * d1});
      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
                                                   {v1, v2});
    };
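    // For example, mul(2, 4) folds to the index attribute 8 when both
    // operands are static; with a dynamic operand it materializes an
    // affine.apply over the map (d0, d1) -> (d0 * d1).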

    // Compute new offsets, sizes, and strides for tensor.extract_slice.
    // The new tensor.extract_slice will work on a tensor that has a rank of
    // ReassociationIndices.size(). In the loop a single offset, size, and
    // stride value is computed per reassociation group.
    SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
        collapsedStrides;
    for (const ReassociationIndices &indices :
         expandShapeOp.getReassociationIndices()) {
      // collapsedSize will hold the size of the single dim that represents
      // the reassociation group in the non-expanded tensor.
      OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
      // The reassocGroupSizes and reassocGroupOffsets are used to create an
      // affine.linearize_index op that linearizes the single offset value
      // required for this reassociation group.
      SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;

      for (long expandedDim : indices) {
        // reassocGroupSizes and reassocGroupOffsets can be obtained directly
        // from the expanded state, but the collapsed size requires
        // calculation as it did not previously exist.
        reassocGroupSizes.push_back(expandedShape[expandedDim]);
        reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
        collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
      }
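      // At this point collapsedSize is the product of the extracted sizes in
      // the group; e.g. in the doc example above, extracted sizes [1, 5]
      // collapse to 5, and [2, 4] to 8.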

      SmallVector<Value> offsetVals =
          llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
            return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
          });
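      // affine.linearize_index computes the row-major linearization of the
      // offsets against the group sizes; e.g. basis (2, 8) with offsets
      // (%i, %j) yields %i * 8 + %j. The `disjoint` flag asserts that each
      // offset is in bounds for its basis entry.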
      OpFoldResult collapsedOffset =
          rewriter
              .create<affine::AffineLinearizeIndexOp>(loc, offsetVals,
                                                      reassocGroupSizes,
                                                      /*disjoint=*/true)
              .getResult();
      collapsedOffsets.push_back(collapsedOffset);
      collapsedSizes.push_back(collapsedSize);

      // Only unit stride is supported.
      collapsedStrides.push_back(rewriter.getIndexAttr(1));
    }

    // The shape of the result can be obtained from the sizes passed in.
    SmallVector<Value> dynDims;
    SmallVector<int64_t> shape;
    dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
    RankedTensorType resultType = RankedTensorType::get(
        shape, expandShapeOp.getResultType().getElementType());
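    // In the running example, resultType is tensor<2x4x1x5x1x1x4xf32>: the
    // expanded rank is preserved and only the sliced extents change.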

    // Create a new ExtractSliceOp and ExpandShapeOp.
    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
        collapsedStrides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSliceOp,
        expandShapeOp.getReassociationIndices(), expandedSizes);
    return success();
  }

  // Helper function to check if all the required conditions for the
  // tensor.extract_slice to be bubbled up through the tensor.expand_shape
  // are met.
  LogicalResult
  checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
                                           tensor::ExpandShapeOp expandShapeOp,
                                           PatternRewriter &rewriter) const {

    if (!expandShapeOp) {
      return rewriter.notifyMatchFailure(
          sliceOp, "tensor.extract_slice source not produced by expand_shape");
    }

    if (!sliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
                   "be supported in this transformation.");
    }

    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();

    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
        sizes.size()) {
      return rewriter.notifyMatchFailure(sliceOp,
                                         "unimplemented: rank reducing slice");
    }

    SmallVector<OpFoldResult> outputShape =
        getMixedValues(expandShapeOp.getStaticOutputShape(),
                       expandShapeOp.getOutputShape(), rewriter);

    std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
        isZeroOffsetAndFullSize =
            [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
              if (!isConstantIntValue(offset, 0))
                return false;
              FailureOr<bool> maybeEqual =
                  ValueBoundsConstraintSet::areEqual(sliceSize, size);
              return llvm::succeeded(maybeEqual) && maybeEqual.value();
            };
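    // E.g. offset 0 with sliceSize 10 over a dimension of size 10 passes,
    // while a non-zero offset, or sliceSize 5 over size 10, fails.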

    // Check that the slice is contiguous within each reassociation group.
    // The slice is contiguous only if, after the first dimension where a
    // non-unit slice is taken, the slice size on all subsequent dimensions of
    // the group is equal to the entire size of the dimension.
    // Examples of contiguous slices:
    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
    //   full sizes: [5, 10]    slice offsets: [3, 0]    slice sizes: [2, 10]
    // Examples of non-contiguous slices:
    //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
    //   full sizes: [5, 10]    slice offsets: [0, 4]    slice sizes: [2, 5]
    for (const ReassociationIndices &indices :
         expandShapeOp.getReassociationIndices()) {
      int64_t i = 0;
      int64_t e = indices.size();
      // Find the first expanded dim after the first dim with non-unit
      // extracted size.
      for (; i < e; ++i) {
        if (!isConstantIntValue(sizes[indices[i]], 1)) {
          // +1 to skip the first non-unit size dim.
          i++;
          break;
        }
      }
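      // If every extracted size in the group is 1, the loop above runs to
      // the end (i == e) and the check below is vacuously satisfied.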

      // Verify that all subsequent dimensions extract the full size of the
      // source tensor.
      for (; i < e; ++i) {
        int64_t expandedDim = indices[i];
        if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
                                     outputShape[expandedDim])) {
          return rewriter.notifyMatchFailure(
              sliceOp, "Not a contiguous slice of the expanded tensor.");
        }
      }
    }

    return success();
  }
};

} // namespace

void mlir::tensor::populateReassociativeReshapeFoldingPatterns(

@@ -227,3 +439,8 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}

void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext());
}
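
For context, here is a minimal sketch of how the new entry point might be driven from a pass, assuming the standard greedy rewrite driver (`applyPatternsGreedily`; `applyPatternsAndFoldGreedily` on older MLIR). The wrapper function `runBubbleUpExtractSlice` is illustrative and not part of this PR:

```cpp
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical driver: collect the pattern added by this PR and apply it
// greedily under `root` (e.g. a func.func). Only
// populateBubbleUpExtractSliceOpPatterns comes from this PR; the rest is
// standard MLIR pass scaffolding.
static mlir::LogicalResult runBubbleUpExtractSlice(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
  // Rewrites to a fixed point; extract_slice ops bubble up through
  // expand_shape wherever the contiguity precondition holds.
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}
```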