-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][spirv] Fix function signature legalization for n-D vectors #99872
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][spirv] Fix function signature legalization for n-D vectors #99872
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Angel Zhang (angelz913) ChangesFull diff: https://github.com/llvm/llvm-project/pull/99872.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index bf5044437fd09..c146589612b5e 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1098,13 +1098,19 @@ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
// the original operand of illegal type.
auto originalShape =
llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
- SmallVector<int64_t> strides(targetShape->size(), 1);
+ SmallVector<int64_t> strides(originalShape.size(), 1);
+ SmallVector<int64_t> extractShape(originalShape.size(), 1);
+ extractShape.back() = targetShape->back();
SmallVector<Type> newTypes;
Value returnValue = returnOp.getOperand(origResultNo);
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(originalShape, *targetShape)) {
Value result = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, returnValue, offsets, *targetShape, strides);
+ loc, returnValue, offsets, extractShape, strides);
+ SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
+ if (originalShape.size() > 1)
+ result =
+ rewriter.create<vector::ExtractOp>(loc, result, extractIndices);
newOperands.push_back(result);
newTypes.push_back(unrolledType);
}
diff --git a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
index 347d282f9ee0c..c018ccb924983 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
@@ -66,6 +66,28 @@ func.func @simple_vector_8(%arg0 : vector<8xi32>) -> vector<8xi32> {
// -----
+// CHECK-LABEL: @simple_vector_2d
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
+func.func @simple_vector_2d(%arg0 : vector<4x4xi32>) -> vector<4x4xi32> {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4x4xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [1, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+ // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [2, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+ // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [3, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+ // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+ // CHECK: %[[EXTRACT0_1:.*]] = vector.extract %[[EXTRACT0]][0] : vector<4xi32> from vector<1x4xi32>
+ // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+ // CHECK: %[[EXTRACT1_1:.*]] = vector.extract %[[EXTRACT1]][0] : vector<4xi32> from vector<1x4xi32>
+ // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [2, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+ // CHECK: %[[EXTRACT2_1:.*]] = vector.extract %[[EXTRACT2]][0] : vector<4xi32> from vector<1x4xi32>
+ // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [3, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+ // CHECK: %[[EXTRACT3_1:.*]] = vector.extract %[[EXTRACT3]][0] : vector<4xi32> from vector<1x4xi32>
+ // CHECK: return %[[EXTRACT0_1]], %[[EXTRACT1_1]], %[[EXTRACT2_1]], %[[EXTRACT3_1]] : vector<4xi32>, vector<4xi32>, vector<4xi32>, vector<4xi32>
+ return %arg0 : vector<4x4xi32>
+}
+
+// -----
+
// CHECK-LABEL: @vector_6and8
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
func.func @vector_6and8(%arg0 : vector<6xi32>, %arg1 : vector<8xi32>) -> (vector<6xi32>, vector<8xi32>) {
@@ -113,6 +135,28 @@ func.func @scalar_vector(%arg0 : vector<8xi32>, %arg1 : vector<3xi32>, %arg2 : i
// -----
+// CHECK-LABEL: @vector_2dand1d
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<3xi32>, %[[ARG3:.+]]: vector<3xi32>, %[[ARG4:.+]]: vector<4xi32>)
+func.func @vector_2dand1d(%arg0 : vector<2x6xi32>, %arg1 : vector<4xi32>) -> (vector<2x6xi32>, vector<4xi32>) {
+ // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x6xi32>
+ // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+ // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [0, 3], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+ // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [1, 0], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+ // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [1, 3], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+ // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+ // CHECK: %[[EXTRACT0_1:.*]] = vector.extract %[[EXTRACT0]][0] : vector<3xi32> from vector<1x3xi32>
+ // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 3], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+ // CHECK: %[[EXTRACT1_1:.*]] = vector.extract %[[EXTRACT1]][0] : vector<3xi32> from vector<1x3xi32>
+ // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+ // CHECK: %[[EXTRACT2_1:.*]] = vector.extract %[[EXTRACT2]][0] : vector<3xi32> from vector<1x3xi32>
+ // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 3], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+ // CHECK: %[[EXTRACT3_1:.*]] = vector.extract %[[EXTRACT3]][0] : vector<3xi32> from vector<1x3xi32>
+ // CHECK: return %[[EXTRACT0_1]], %[[EXTRACT1_1]], %[[EXTRACT2_1]], %[[EXTRACT3_1]], %[[ARG4]] : vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<4xi32>
+ return %arg0, %arg1 : vector<2x6xi32>, vector<4xi32>
+}
+
+// -----
+
// CHECK-LABEL: @reduction
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>, %[[ARG4:.+]]: i32)
func.func @reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : i32) -> (i32) {
|
kuhar
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you explain what the issue is and what this PR does to fix that?
Updated description |
…100138) ### Description This PR builds on #99872. It implements a minimal version of function body vector unrolling to convert vector types into 1D and with a size supported by SPIR-V (2, 3 or 4 depending on the original dimension). The ops that are currently supported include those with elementwise traits (e.g. `arith.addi`), `vector.reduction` and `vector.transpose`. This PR also includes new LIT tests that only check for vector unrolling. ### Future Plans - Support more ops --------- Co-authored-by: Jakub Kuderski <[email protected]>
|
The changes are included in #100138. |
…100138) Summary: ### Description This PR builds on #99872. It implements a minimal version of function body vector unrolling to convert vector types into 1D and with a size supported by SPIR-V (2, 3 or 4 depending on the original dimension). The ops that are currently supported include those with elementwise traits (e.g. `arith.addi`), `vector.reduction` and `vector.transpose`. This PR also includes new LIT tests that only check for vector unrolling. ### Future Plans - Support more ops --------- Co-authored-by: Jakub Kuderski <[email protected]> Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60250596
This PR makes a fix to the function signature vector unrolling for SPIR-V (introduced in #98337).
Issue
vector.extract_strided_sliceop requires itsoffsets,sizesandstridesoperands to be of the same length. The original implementation did not consider when the source vector has rank greater than 1, so theoffsetscreated usingStaticTileOffsetRangewill have length > 1, while thesizeandstridesstill have length of 1. This caused assertion failures when creating thevector.extract_strided_sliceops for vectors in function results.Fix
This PR addresses the issue by making the
sizesandstridesoperands to match the length ofoffsets(equal to the rank of the original vector). An extractvector.extractop is also created to trim potential leading ones in the result from thevector.extract_strided_sliceops.