[mlir][vector] Fix 0-d vector transfer mask inference #116526
Conversation
When we infer the mask of a transfer operation that results in a single `i1` element, we can represent it using either a `vector<i1>` or a `vector<1xi1>`. To avoid issues with this type mismatch, this PR fixes the mask inference logic to always generate `vector<1xi1>` for these cases. We can enable 0-D masks if they are eventually needed.
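For a concrete picture, here is a minimal sketch of the case being fixed (the function and value names are made up; it mirrors the test added to `ops.mlir` below): every result dimension of the permutation map is a broadcast, so the transfer reads a single element, and the inferred mask is now `vector<1xi1>` rather than the 0-D `vector<i1>`.

```mlir
// All result dims of the permutation map are broadcasts, so only one element
// is read from the source tensor. The mask inferred for the transfer is now
// vector<1xi1> instead of the 0-D vector<i1>.
func.func @broadcast_transfer_mask(%t: tensor<2x4xi32>, %i: index, %j: index,
                                   %m: vector<1xi1>) -> vector<1x1x4xi32> {
  %pad = arith.constant 0 : i32
  %res = vector.mask %m {
    %0 = vector.transfer_read %t[%i, %j], %pad
        {permutation_map = affine_map<(d0, d1) -> (0, 0, 0)>}
        : tensor<2x4xi32>, vector<1x1x4xi32>
    vector.yield %0 : vector<1x1x4xi32>
  } : vector<1xi1> -> vector<1x1x4xi32>
  return %res : vector<1x1x4xi32>
}
```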
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: Diego Caballero (dcaballe)

Changes: When we infer the mask of a transfer operation which results in a single `i1` element, we can use `vector<i1>` or `vector<1xi1>` to represent it. To avoid issues with this type mismatch, this PR fixes the mask inference logic to always generate `vector<1xi1>` for these cases. We can enable 0-D masks if they are eventually needed.

See: #116197

Full diff: https://github.com/llvm/llvm-project/pull/116526.diff

3 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index db199a46e1637c..6ba6d30099ce91 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4122,6 +4122,10 @@ VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
assert(invPermMap && "Inversed permutation map couldn't be computed");
SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
+ // Turn a 0-D mask into a single-element 1-D mask.
+ if (maskShape.empty())
+ maskShape.push_back(1);
+
SmallVector<bool> scalableDims =
applyPermutationMap(invPermMap, vecType.getScalableDims());
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index d591c60acb64e7..bfbdc83e382272 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1752,6 +1752,22 @@ func.func @vector_mask_non_maskable_op(%a : vector<3x4xf32>) -> vector<3x4xf32>
// -----
+// We can support 0-D masks if eventually needed.
+func.func @vector_mask_0d_mask(%arg0: tensor<2x4xi32>,
+ %idx0: index, %idx1: index,
+ %m0: vector<i1>) -> vector<1x1x4xi32> {
+ %cst = arith.constant 0 : i32
+ // expected-error@+1 {{'vector.mask' op operand #0 must be vector of 1-bit signless integer values, but got 'vector<i1>'}}
+ %res = vector.mask %m0 {
+ %0 = vector.transfer_read %arg0[%idx0, %idx1], %cst {permutation_map = affine_map<(d0, d1) -> (0, 0, 0)>}
+ : tensor<2x4xi32>, vector<1x1x4xi32>
+ vector.yield %0 : vector<1x1x4xi32>
+ } : vector<i1> -> vector<1x1x4xi32>
+ return %res : vector<1x1x4xi32>
+}
+
+// -----
+
func.func @vector_scalable_insert_unaligned(%subv: vector<4xi32>, %vec: vector<[16]xi32>) {
// expected-error@+1 {{op failed to verify that position is a multiple of the source length.}}
%0 = vector.scalable.insert %subv, %vec[2] : vector<4xi32> into vector<[16]xi32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 3baacba9b61243..04d9ff0546160a 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1028,6 +1028,20 @@ func.func @vector_mask_empty_return(%m0: vector<16xi1>, %arg0: vector<16xf32>) -
return %0 : vector<16xf32>
}
+// CHECK-LABEL: func @vector_mask_scalar_broadcast_transfer
+func.func @vector_mask_scalar_broadcast_transfer(%arg0: tensor<2x4xi32>,
+ %idx0: index, %idx1: index,
+ %m0: vector<1xi1>) -> vector<1x1x4xi32> {
+ %cst = arith.constant 0 : i32
+ // CHECK: vector.mask %{{.*}} { vector.transfer_read {{.*}} } : vector<1xi1> -> vector<1x1x4xi32>
+ %res = vector.mask %m0 {
+ %0 = vector.transfer_read %arg0[%idx0, %idx1], %cst {permutation_map = affine_map<(d0, d1) -> (0, 0, 0)>}
+ : tensor<2x4xi32>, vector<1x1x4xi32>
+ vector.yield %0 : vector<1x1x4xi32>
+ } : vector<1xi1> -> vector<1x1x4xi32>
+ return %res : vector<1x1x4xi32>
+}
+
// CHECK-LABEL: func @vector_scalable_insert(
// CHECK-SAME: %[[SUB0:.*]]: vector<4xi32>, %[[SUB1:.*]]: vector<8xi32>,
// CHECK-SAME: %[[SUB2:.*]]: vector<[4]xi32>, %[[SV:.*]]: vector<[8]xi32>
banach-space
left a comment
LGTM, thanks for the fix!
Inline comment on the new line in mlir/lib/Dialect/Vector/IR/VectorOps.cpp: `// Turn a 0-D mask into a single-element 1-D mask.`
[nit] Personally, I'd appreciate a note saying that ATM vector.mask does not support 0-D masks (enforced here). And that's basically the reason to "upgrade" this mask (which to me is a very good reason).
As in, document "why" rather than "what" :)
Inline comment on the new line in mlir/test/Dialect/Vector/invalid.mlir: `// We can support 0-D masks if eventually needed.`
[nit] Why not add this comment where this restriction is enforced?
llvm-project/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td, lines 2509-2510 (at 4b50ec4):
let arguments = (ins VectorOf<[I1]>:$mask,
                     Optional<AnyType>:$passthru);
And to the Op docs :)
I don't really understand why we are forcing a vector<1xi1> mask instead of vector<i1> mask. Something seems wrong here.
I'll take 3 examples:
vector.transfer_read %tensor[], permutation_map<() -> (0, 0, 0)> : tensor<f32>, vector<4x4x4xf32>
vector.transfer_read %tensor[%idx], permutation_map<(d0) -> (0, 0, 0)> : tensor<1xf32>, vector<4x4x4xf32>
vector.transfer_read %tensor[%idx, %idx2], permutation_map<(d0, d1) -> (0, 0, 0)> : tensor<1x1xf32>, vector<4x4x4xf32>
With this PR, the mask for all of these vector.transfer_read ops will be vector<1xi1>, which doesn't really make sense to me. The permutation_map attribute specifies a mapping from the memory/tensor space to vector space. The masking needs to be done on the memory/tensor space, and each of these examples has a different dimensionality for the memory/tensor space. I would expect the masks to be:
vector.transfer_read %tensor[], permutation_map<() -> (0, 0, 0)> : tensor<f32>, vector<4x4x4xf32> // mask: vector<i1>
vector.transfer_read %tensor[%idx], permutation_map<(d0) -> (0, 0, 0)> : tensor<1xf32>, vector<4x4x4xf32> // mask: vector<1xi1>
vector.transfer_read %tensor[%idx, %idx2], permutation_map<(d0, d1) -> (0, 0, 0)> : tensor<1x1xf32>, vector<4x4x4xf32> // mask: vector<1x1xi1>
Based on the definition, the mask dimensionality should always match the dimensionality of the domain.
Instead of forcing the mask to always be vector<1xi1>, can we fix the inference to infer a mask whose dimensionality matches the domain of the permutation map, where any dimension not used in the results of the permutation_map is simply 1, since it is broadcast? This would also make the operation much more consistent w.r.t. masking: you can always expect the dimensionality of the mask to be the same as that of the memory/tensor.
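To make the contrast concrete, here is a sketch based on the third example above (function and value names are hypothetical, and it has not been run through the verifier; it follows the same structure as the new ops.mlir test): with this PR the expected mask type drops all broadcast dimensions and is padded to 1-D, i.e. vector<1xi1>, whereas under the proposal in this comment it would match the source rank, i.e. vector<1x1xi1>.

```mlir
// With this PR: all broadcast dims are dropped from the inferred mask shape,
// which is then padded to 1-D, so the expected mask is vector<1xi1>.
// Under the proposal above (not implemented), the mask would instead have one
// dimension per source dimension, i.e. vector<1x1xi1> for tensor<1x1xf32>.
func.func @mask_rank_contrast(%t: tensor<1x1xf32>, %i: index, %j: index,
                              %m: vector<1xi1>) -> vector<4x4x4xf32> {
  %pad = arith.constant 0.0 : f32
  %res = vector.mask %m {
    %0 = vector.transfer_read %t[%i, %j], %pad
        {permutation_map = affine_map<(d0, d1) -> (0, 0, 0)>}
        : tensor<1x1xf32>, vector<4x4x4xf32>
    vector.yield %0 : vector<4x4x4xf32>
  } : vector<1xi1> -> vector<4x4x4xf32>
  return %res : vector<4x4x4xf32>
}
```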
I’d be cautious about concluding that something is broken - at least not yet! 😊
All these examples involve broadcasting. It’s worth referring to this point from the vector.transfer_read documentation (emphasis mine):
This PR strictly follows the documented design, which ensures consistency. Other parts of MLIR should adhere to the same principle as well.
Which definition? :)
Sorry, not the right choice of words. The PR is obviously correct and useful. The only thing that seems weird to me is the lit test; let me quickly write up why it feels that way. I only blocked the PR so we can discuss it more. I'm happy to remove the block in the meantime and explain my concerns better.
Right, I should've been clearer: I don't want to block the change, but I do want to discuss the masking behavior more.
Right, I am following this definition, but I will write up a response explaining myself better. Meanwhile, I'm unblocking; feel free to land it :)
Hey Kunwar, your feedback is definitely welcome. 0-D vector support has been a major source of bugs and trouble in general, even for simple ops, so I don't think we can just enable 0-D vectors here and move on. Honestly, I'm really concerned about the 0-D vector situation and the instability it has been bringing since its introduction. For that reason, as I mentioned in your PR, I was thinking about sending an RFC to propose its removal. It would be good to know beforehand if people have any concerns about that.
@Groverkss, are you OK with Diego proceeding with this? It's the first step towards fixing #116197, hence I'm quite keen :) I will be sending a follow-on soonish.
Yes! I will follow up on the discussion elsewhere; this patch by itself is good.