[mlir][tensor][NFC] Refactor common methods for bubbling extract_slice op #153675
Conversation
Signed-off-by: Ian Wood <[email protected]>
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tensor

Author: Ian Wood (IanWood1)

Changes: Exposes the `tensor.extract_slice` reshaping logic in `BubbleUpExpandShapeThroughExtractSlice` and `BubbleUpCollapseShapeThroughExtractSlice` through two corresponding utility functions. These compute the offsets/sizes/strides of an extract slice after either collapsing or expanding.

This should also make it easier to implement the two other bubbling cases: (1) the `collapse_shape` is a consumer or (2) the `expand_shape` is a consumer.

Patch is 30.98 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/153675.diff

2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 87deef9ca7466..2602252916388 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -142,6 +142,34 @@ FailureOr<Value> buildIndependentOp(OpBuilder &b, tensor::PadOp padOp,
FailureOr<Value> buildIndependentOp(OpBuilder &b, tensor::EmptyOp emptyOp,
ValueRange independencies);
+/// Computes the offsets, sizes, and strides needed to build a collapsed
+/// `sliceOp`. The dimensions to collapse are specified by `reassociation`.
+///
+/// This fails when the specified collapse cannot be represented by a valid
+/// ExtractSliceOp.
+LogicalResult
+getCollapsedExtractSliceInfo(tensor::ExtractSliceOp sliceOp,
+ ArrayRef<ReassociationIndices> reassociation,
+ SmallVectorImpl<OpFoldResult> &collapsedOffsets,
+ SmallVectorImpl<OpFoldResult> &collapsedSizes,
+ SmallVectorImpl<OpFoldResult> &collapsedStrides,
+ OpBuilder &b);
+
+/// Computes the offsets, sizes, and strides needed to build an expanded
+/// `sliceOp`. The dimensions to expand are specified by `reassociation` and
+/// `expandedShape`.
+///
+/// This fails when the specified expansion cannot be represented by a valid
+/// ExtractSliceOp.
+LogicalResult
+getExpandedExtractSliceInfo(tensor::ExtractSliceOp sliceOp,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<int64_t> expandedShape,
+ SmallVectorImpl<OpFoldResult> &expandedOffsets,
+ SmallVectorImpl<OpFoldResult> &expandedSizes,
+ SmallVectorImpl<OpFoldResult> &expandedStrides,
+ OpBuilder &b);
+
} // namespace tensor
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 2ec23e1fb35ce..a93681b1fce92 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -327,172 +327,31 @@ struct BubbleUpExpandShapeThroughExtractSlice
PatternRewriter &rewriter) const override {
auto expandShapeOp =
sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+ if (!expandShapeOp) {
+ return rewriter.notifyMatchFailure(
+ sliceOp, "tensor.extract_slice source not produced by expand_shape");
+ }
+ SmallVector<ReassociationIndices> reassociation =
+ expandShapeOp.getReassociationIndices();
- if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
- rewriter)
- .failed())
+ SmallVector<OpFoldResult> offsets, sizes, strides;
+ if (failed(getCollapsedExtractSliceInfo(sliceOp, reassociation, offsets,
+ sizes, strides, rewriter)))
return failure();
- // The tensor.extract_slice before applying the pattern works on the result
- // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
- // referring to the state before applying the pattern are named with the
- // prefix "expanded", and ones referring to the state after applying the
- // pattern are named with the prefix "collapsed".
- SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
- SmallVector<OpFoldResult> expandedShape =
- getMixedValues(expandShapeOp.getStaticOutputShape(),
- expandShapeOp.getOutputShape(), rewriter);
-
- // Helper variables and function for accumulating the size values.
- Location loc = expandShapeOp->getLoc();
- AffineExpr d0, d1, d2;
- bindDims(rewriter.getContext(), d0, d1, d2);
- // Multiply two integers.
- auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
- auto mulMap = AffineMap::get(2, 0, {d0 * d1});
- return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
- {v1, v2});
- };
-
- // Compute new offsets, sizes, and strides for tensor.extract_slice.
- // The new tensor.extract_slice will work on a tensor that has has a rank of
- // ReassociationIndices.size(). In the loop a single offset, size, and
- // stride value is computed per reassociation group.
- SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
- collapsedStrides;
- for (const ReassociationIndices &indices :
- expandShapeOp.getReassociationIndices()) {
- // collapsedSize will hold the size of the single dim that represents the
- // reassociation group in the non expanded tensor.
- OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
- // The reassocGroupSizes and reassocGroupOffsets are used to create an
- // affine.linearize_index op to linearize the single offset value required
- // for this reassociation group.
- SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
-
- for (long expandedDim : indices) {
- // reassocGroupSizes and reassocGroupOffsets can be obtained directly
- // from the expanded state, but the collapsed size requires calculation
- // as it did not previously exist.
- reassocGroupSizes.push_back(expandedShape[expandedDim]);
- reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
- collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
- }
-
- SmallVector<Value> offsetVals =
- llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
- return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
- });
- OpFoldResult collapsedOffset =
- affine::AffineLinearizeIndexOp::create(rewriter, loc, offsetVals,
- reassocGroupSizes,
- /*disjoint=*/true)
- .getResult();
- collapsedOffsets.push_back(collapsedOffset);
- collapsedSizes.push_back(collapsedSize);
-
- // Only unit stride is supported.
- collapsedStrides.push_back(rewriter.getIndexAttr(1));
- }
-
// The shape of the result can be obtained from the sizes passed in.
- SmallVector<Value> dynDims;
- SmallVector<int64_t> shape;
- dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
- RankedTensorType resultType = RankedTensorType::get(
- shape, expandShapeOp.getResultType().getElementType());
+ SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+ RankedTensorType resultType = sliceOp.getResultType();
// Create a new ExtractSliceOp and ExpandShapeOp.
+ Location loc = sliceOp.getLoc();
Value newSliceOp = tensor::ExtractSliceOp::create(
- rewriter, loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
- collapsedStrides);
+ rewriter, loc, expandShapeOp.getSrc(), offsets, sizes, strides);
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
sliceOp, resultType, newSliceOp,
expandShapeOp.getReassociationIndices(), expandedSizes);
return success();
}
-
- // Helper function to check if all the required conditions for the
- // tensor.extract_slice to be bubbled up through the tensor.expand_shape are
- // met.
- LogicalResult
- checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
- tensor::ExpandShapeOp expandShapeOp,
- PatternRewriter &rewriter) const {
-
- if (!expandShapeOp) {
- return rewriter.notifyMatchFailure(
- sliceOp, "tensor.extract_slice source not produced by expand_shape");
- }
-
- if (!sliceOp.hasUnitStride()) {
- return rewriter.notifyMatchFailure(
- sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
- "be supported in this transformation.");
- }
-
- SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
-
- if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
- sizes.size()) {
- return rewriter.notifyMatchFailure(sliceOp,
- "unimplemented: rank reducing slice");
- }
-
- SmallVector<OpFoldResult> outputShape =
- getMixedValues(expandShapeOp.getStaticOutputShape(),
- expandShapeOp.getOutputShape(), rewriter);
-
- std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
- isZeroOffsetAndFullSize =
- [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
- if (!isZeroInteger(offset))
- return false;
- FailureOr<bool> maybeEqual =
- ValueBoundsConstraintSet::areEqual(sliceSize, size);
- return llvm::succeeded(maybeEqual) && maybeEqual.value();
- };
-
- // Check that the slice is contiguous within each reassociation group.
- // The slice is contiguous only if after the first dimension where a non
- // unit slice is taken, the slice size on all subsequent dimensions of the
- // group is equal to the entire size of the dimension.
- // Examples of contiguous slices:
- // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
- // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
- // Examples of non contiguous slices:
- // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
- // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
- for (const ReassociationIndices &indices :
- expandShapeOp.getReassociationIndices()) {
- int64_t i = 0;
- int64_t e = indices.size();
- // Find the first expanded dim after the first dim with non-unit extracted
- // size.
- for (; i < e; ++i) {
- if (!isOneInteger(sizes[indices[i]])) {
- // +1 to skip the first non-unit size dim.
- i++;
- break;
- }
- }
-
- // Verify that all subsequent dimensions extract the full size of the
- // source tensor.
- for (; i < e; ++i) {
- int64_t expandedDim = indices[i];
- if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
- outputShape[expandedDim])) {
- return rewriter.notifyMatchFailure(
- sliceOp, "Not a contiguous slice of the expanded tensor.");
- }
- }
- }
-
- return success();
- }
};
/// Converts `tensor.extract_slice(tensor.collapse_shape)` to
@@ -582,170 +441,282 @@ struct BubbleUpCollapseShapeThroughExtractSlice
"tensor.extract_slice source not produced by tensor.collapse_shape");
}
- if (!sliceOp.hasUnitStride()) {
- return rewriter.notifyMatchFailure(
- sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
- "be supported in this transformation.");
- }
+ SmallVector<OpFoldResult> offsets, sizes, strides;
+ if (failed(getExpandedExtractSliceInfo(
+ sliceOp, collapseShapeOp.getReassociationIndices(),
+ collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides,
+ rewriter)))
+ return failure();
- // The tensor.extract_slice before applying the pattern works on the result
- // of the tensor.collapse_shape, so variables (i.e. inputs for
- // ExtractSliceOp) referring to the state before applying the pattern are
- // named with the prefix "collapsed", and ones referring to the state after
- // applying the pattern are named with the prefix "expanded".
- SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
-
- if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
- collapsedSizes.size()) {
- return rewriter.notifyMatchFailure(sliceOp,
- "unimplemented: rank reducing slice");
- }
+ Value newSliceOp = tensor::ExtractSliceOp::create(
+ rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), offsets,
+ sizes, strides);
+ rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+ sliceOp, sliceOp.getResultType(), newSliceOp,
+ collapseShapeOp.getReassociationIndices());
- ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
- SmallVector<ReassociationIndices, 4> reassociationIndices =
- collapseShapeOp.getReassociationIndices();
-
- // Compute new offsets, sizes, and strides for tensor.extract_slice.
- // The new tensor.extract_slice will work on a tensor that has has a rank
- // equal to the rank of the src of the collapse_shape. In each iteration of
- // the loop, the offsets and sizes will be computed per reassociation group.
- SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
- SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
- rewriter.getIndexAttr(1));
-
- for (auto [collapsedSize, collapsedOffset, reassocIndices] :
- llvm::zip_equal(collapsedSizes, collapsedOffsets,
- collapseShapeOp.getReassociationIndices())) {
- // CASE #1 - size and/or offset are dynamic.
- // In this case, the slice can be represented as a contiguous slice only
- // if there is a single dimension in the reassociation group that has a
- // size not equal to 1.
- if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
- int nonUnitSizeCount = 0;
- for (int64_t expandedShapeIdx : reassocIndices) {
- if (srcShape[expandedShapeIdx] != 1) {
- nonUnitSizeCount++;
- expandedSizes.push_back(collapsedSize);
- expandedOffsets.push_back(collapsedOffset);
- continue;
- }
-
- expandedSizes.push_back(rewriter.getIndexAttr(1));
- expandedOffsets.push_back(rewriter.getIndexAttr(0));
- }
+ return success();
+ }
+};
- if (nonUnitSizeCount != 1) {
- return rewriter.notifyMatchFailure(
- sliceOp,
- "unsupported: slice cannot be verified to be contiguous");
- }
- continue;
- }
+} // namespace
- // CASE #2 = size and offset are static.
- // Verify that the slice can be represented as a contiguous slice of the
- // src of the collapse_shape.
- // Checking this is done on order of most internal dimensions first,
- // so traversal is done in reverse order of the reassociation group.
- // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
- // ...,An] then we first find the size and offset for n...k+1 then for k
- // and then for k-1...0.
-
- // currentCollapsedsize and currentCollapsedOffset are initialized with
- // the original collapsed size and offset and divided by the expanded
- // shape size in each dimension as we go along the reassociation group.
- // In essence we are spreading the original collapsed size and offset over
- // the various expanded slice dimensions.
- // The variables are used both to check the validity of the slice and to
- // compute the expanded sizes and offsets.
- int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
- int64_t currentCollapsedOffset =
- getConstantIntValue(collapsedOffset).value();
-
- SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
-
- ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
- reassocIndices.rend());
- int64_t idx = 0;
- int64_t reassocGroupSize = reassocIndices.size();
-
- // First handle the trailing dimensions where the slice size should be
- // equal to the tensor shape and the offset should be 0 (n...k+1).
- for (; idx < reassocGroupSize; ++idx) {
- int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
-
- if (currentCollapsedsize < expandedShapeSize)
- break;
-
- // We need to make sure that the slice size can be set to the shape size
- // and the offset to 0.
- if ((currentCollapsedsize % expandedShapeSize) != 0 ||
- (currentCollapsedOffset % expandedShapeSize) != 0) {
- return rewriter.notifyMatchFailure(
- sliceOp, "unsupported: cannot be extracted as a contiguous slice "
- "of the src of the collapse_shape");
- }
+LogicalResult mlir::tensor::getCollapsedExtractSliceInfo(
+ tensor::ExtractSliceOp sliceOp,
+ ArrayRef<ReassociationIndices> reassociation,
+ SmallVectorImpl<OpFoldResult> &collapsedOffsets,
+ SmallVectorImpl<OpFoldResult> &collapsedSizes,
+ SmallVectorImpl<OpFoldResult> &collapsedStrides, OpBuilder &b) {
+ if (!sliceOp.hasUnitStride()) {
+ return failure();
+ }
+
+ SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
- groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
- groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
+ if (static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) {
+ return failure();
+ }
- currentCollapsedsize /= expandedShapeSize;
- currentCollapsedOffset /= expandedShapeSize;
+ auto isZeroOffsetAndFullSize = [&](OpFoldResult offset,
+ OpFoldResult sliceSize, int64_t inputDim) {
+ if (!isZeroInteger(offset))
+ return false;
+ ValueBoundsConstraintSet::Variable inputSize(sliceOp.getSource(), inputDim);
+ FailureOr<bool> maybeEqual =
+ ValueBoundsConstraintSet::areEqual(sliceSize, inputSize);
+ return llvm::succeeded(maybeEqual) && maybeEqual.value();
+ };
+
+ // Check that the slice is contiguous within each reassociation group.
+ // The slice is contiguous only if after the first dimension where a non
+ // unit slice is taken, the slice size on all subsequent dimensions of the
+ // group is equal to the entire size of the dimension.
+ // Examples of contiguous slices:
+ // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
+ // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
+ // Examples of non contiguous slices:
+ // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
+ // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
+ for (const ReassociationIndices &indices : reassociation) {
+ int64_t i = 0;
+ int64_t e = indices.size();
+ // Find the first expanded dim after the first dim with non-unit extracted
+ // size.
+ for (; i < e; ++i) {
+ if (!isOneInteger(sizes[indices[i]])) {
+ // +1 to skip the first non-unit size dim.
+ i++;
+ break;
}
+ }
+
+ // Verify that all subsequent dimensions extract the full size of the
+ // source tensor.
+ for (; i < e; ++i) {
+ int64_t expandedDim = indices[i];
+ if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
+ expandedDim)) {
+ return failure();
+ }
+ }
+ }
+
+ // The tensor.extract_slice before applying the pattern works on the result
+ // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
+ // referring to the state before applying the pattern are named with the
+ // prefix "expanded", and ones referring to the state after applying the
+ // pattern are named with the prefix "collapsed".
+ Location loc = sliceOp.getLoc();
+ SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> expandedShape =
+ getMixedSizes(b, loc, sliceOp.getSource());
+
+ // Helper variables and function for accumulat...
[truncated]
MaheshRavishankar left a comment:
Looks fine to me. I am assuming this is mostly a move of the logic and that github is messing up the diff?
The only real change here was to remove the calls to
Signed-off-by: Ian Wood <[email protected]>
@hanhanW @nicolasvasilache would either of you mind reviewing?
LGTM, maybe rename the title to
Exposes the `tensor.extract_slice` reshaping logic in `BubbleUpExpandShapeThroughExtractSlice` and `BubbleUpCollapseShapeThroughExtractSlice` through two corresponding utility functions. These compute the offsets/sizes/strides of an extract slice after either collapsing or expanding.

This should also make it easier to implement the two other bubbling cases: (1) the `collapse_shape` is a consumer or (2) the `expand_shape` is a consumer.