Conversation

@dcaballe
Contributor

When we infer the mask of a transfer operation that results in a single `i1` element, we can use `vector<i1>` or `vector<1xi1>` to represent it. To avoid issues with this type mismatch, this PR fixes the mask inference logic to always generate `vector<1xi1>` in these cases. We can enable 0-D masks if they are eventually needed.

See: #116197

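The core of the change can be sketched outside MLIR as follows (a minimal sketch, with `std::vector` standing in for `llvm::SmallVector`, and assuming the inverse-permutation composition has already produced `maskShape`):

```cpp
#include <cstdint>
#include <vector>

// Sketch of the fixed inference: a transfer that yields a single element
// produces an empty (0-D) mask shape after composing the inverted
// permutation map with the vector shape. The fix upgrades that empty shape
// to the 1-D shape {1}, since vector.mask does not accept 0-D vectors.
std::vector<int64_t> upgradeMaskShape(std::vector<int64_t> maskShape) {
  // Turn a 0-D mask into a single-element 1-D mask.
  if (maskShape.empty())
    maskShape.push_back(1);
  return maskShape;
}
```

Shapes that are already 1-D or higher pass through unchanged; only the empty (0-D) case is rewritten.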
@llvmbot
Member

llvmbot commented Nov 17, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Diego Caballero (dcaballe)

Changes

When we infer the mask of a transfer operation that results in a single i1 element, we can use vector<i1> or vector<1xi1> to represent it. To avoid issues with this type mismatch, this PR fixes the mask inference logic to always generate vector<1xi1> in these cases. We can enable 0-D masks if they are eventually needed.

See: #116197


Full diff: https://github.com/llvm/llvm-project/pull/116526.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+4)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+16)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+14)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index db199a46e1637c..6ba6d30099ce91 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4122,6 +4122,10 @@ VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
   assert(invPermMap && "Inversed permutation map couldn't be computed");
   SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
 
+  // Turn a 0-D mask into a single-element 1-D mask.
+  if (maskShape.empty())
+    maskShape.push_back(1);
+
   SmallVector<bool> scalableDims =
       applyPermutationMap(invPermMap, vecType.getScalableDims());
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index d591c60acb64e7..bfbdc83e382272 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1752,6 +1752,22 @@ func.func @vector_mask_non_maskable_op(%a : vector<3x4xf32>) -> vector<3x4xf32>
 
 // -----
 
+// We can support 0-D masks if eventually needed.
+func.func @vector_mask_0d_mask(%arg0: tensor<2x4xi32>,
+                               %idx0: index, %idx1: index,
+                               %m0: vector<i1>) -> vector<1x1x4xi32> {
+  %cst = arith.constant 0 : i32
+  // expected-error@+1 {{'vector.mask' op operand #0 must be vector of 1-bit signless integer values, but got 'vector<i1>'}}
+  %res = vector.mask %m0 {
+    %0 = vector.transfer_read %arg0[%idx0, %idx1], %cst {permutation_map = affine_map<(d0, d1) -> (0, 0, 0)>}
+      : tensor<2x4xi32>, vector<1x1x4xi32>
+    vector.yield %0 : vector<1x1x4xi32>
+  } : vector<i1> -> vector<1x1x4xi32>
+  return %res : vector<1x1x4xi32>
+}
+
+// -----
+
 func.func @vector_scalable_insert_unaligned(%subv: vector<4xi32>, %vec: vector<[16]xi32>) {
   // expected-error@+1 {{op failed to verify that position is a multiple of the source length.}}
   %0 = vector.scalable.insert %subv, %vec[2] : vector<4xi32> into vector<[16]xi32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 3baacba9b61243..04d9ff0546160a 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1028,6 +1028,20 @@ func.func @vector_mask_empty_return(%m0: vector<16xi1>, %arg0: vector<16xf32>) -
   return %0 : vector<16xf32>
 }
 
+// CHECK-LABEL: func @vector_mask_scalar_broadcast_transfer
+func.func @vector_mask_scalar_broadcast_transfer(%arg0: tensor<2x4xi32>,
+                                                 %idx0: index, %idx1: index,
+                                                 %m0: vector<1xi1>) -> vector<1x1x4xi32> {
+  %cst = arith.constant 0 : i32
+  // CHECK: vector.mask %{{.*}} { vector.transfer_read {{.*}} } : vector<1xi1> -> vector<1x1x4xi32>
+  %res = vector.mask %m0 {
+    %0 = vector.transfer_read %arg0[%idx0, %idx1], %cst {permutation_map = affine_map<(d0, d1) -> (0, 0, 0)>}
+      : tensor<2x4xi32>, vector<1x1x4xi32>
+    vector.yield %0 : vector<1x1x4xi32>
+  } : vector<1xi1> -> vector<1x1x4xi32>
+  return %res : vector<1x1x4xi32>
+}
+
 // CHECK-LABEL: func @vector_scalable_insert(
 // CHECK-SAME: %[[SUB0:.*]]: vector<4xi32>, %[[SUB1:.*]]: vector<8xi32>,
 // CHECK-SAME: %[[SUB2:.*]]: vector<[4]xi32>, %[[SV:.*]]: vector<[8]xi32>

Contributor

@banach-space left a comment
LGTM, thanks for the fix!

assert(invPermMap && "Inversed permutation map couldn't be computed");
SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());

// Turn a 0-D mask into a single-element 1-D mask.
Contributor

@banach-space Nov 17, 2024
[nit] Personally, I'd appreciate a note saying that ATM vector.mask does not support 0-D masks (enforced here). And that's basically the reason to "upgrade" this mask (which to me is a very good reason).

As in, document "why" rather than "what" :)


// -----

// We can support 0-D masks if eventually needed.
Contributor
[nit] Why not add this comment where this restriction is enforced?

let arguments = (ins VectorOf<[I1]>:$mask,
Optional<AnyType>:$passthru);

And to the Op docs :)

Member

@Groverkss left a comment
I don't really understand why we are forcing a vector<1xi1> mask instead of vector<i1> mask. Something seems wrong here.

I'll take 3 examples:

vector.transfer_read %tensor[], permutation_map<() -> (0, 0, 0)> : tensor<f32>, vector<4x4x4xf32>
vector.transfer_read %tensor[%idx], permutation_map<(d0) -> (0, 0, 0)> : tensor<1xf32>, vector<4x4x4xf32>
vector.transfer_read %tensor[%idx, %idx2], permutation_map<(d0, d1) -> (0, 0, 0)> : tensor<1x1xf32>, vector<4x4x4xf32>

From this PR, the mask for all of these vector.transfer_read will be vector<1xi1>, which doesn't really make sense to me. The permutation_map attribute specifies a mapping from the memory/tensor space to vector space. The masking needs to be done on the memory/tensor space. Each of these examples has a different dimensionality for the memory/tensor space. I would expect the masks to be:

vector.transfer_read %tensor[], permutation_map<() -> (0, 0, 0)> : tensor<f32>, vector<4x4x4xf32> // mask: vector<i1>
vector.transfer_read %tensor[%idx], permutation_map<(d0) -> (0, 0, 0)> : tensor<1xf32>, vector<4x4x4xf32> // mask: vector<1xi1>
vector.transfer_read %tensor[%idx, %idx2], permutation_map<(d0, d1) -> (0, 0, 0)> : tensor<1x1xf32>, vector<4x4x4xf32> // mask: vector<1x1xi1>

Based on the definition, the mask dimensionality should always match the dimensionality of the domain.

Instead of forcing the mask to always be vector<1xi1>, can we fix the inference to infer a mask whose dimensionality matches the domain of the permutation map, where any dimension not used in the result of the permutation_map is simply 1, since it's broadcasted? This would also make the operation much more consistent w.r.t. masking: you could always expect the dimensionality of the mask to be the same as that of the memory/tensor.
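The alternative inference described above could be sketched as follows (a hypothetical helper, not MLIR's actual API: `results` encodes, per vector dimension, which domain dim it reads from, or -1 for a broadcast result):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch of the proposed inference: the mask's rank equals the
// permutation map's domain rank (i.e. the memref/tensor rank). A domain
// dimension that is actually indexed keeps the corresponding vector extent;
// a domain dimension that only feeds broadcasts gets extent 1.
std::vector<int64_t> proposedMaskShape(int64_t domainRank,
                                       const std::vector<int64_t> &results,
                                       const std::vector<int64_t> &vecShape) {
  std::vector<int64_t> maskShape(domainRank, 1);
  for (std::size_t i = 0; i < results.size(); ++i)
    if (results[i] >= 0)
      maskShape[results[i]] = vecShape[i];
  return maskShape;
}
```

Under this scheme, the three examples above would get mask shapes {} (0-D), {1}, and {1, 1} respectively, tracking the rank of the tensor being read rather than collapsing all of them to {1}.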

@banach-space
Contributor

I don't really understand why we are forcing a vector<1xi1> mask instead of vector<i1> mask. Something seems wrong here.

I’d be cautious about concluding that something is broken - at least not yet! 😊

I'll take 3 examples:

vector.transfer_read %tensor[], permutation_map<() -> (0, 0, 0)> : tensor<f32>, vector<4x4x4xf32>
vector.transfer_read %tensor[%idx], permutation_map<(d0) -> (0, 0, 0)> : tensor<1xf32>, vector<4x4x4xf32>
vector.transfer_read %tensor[%idx, %idx2], permutation_map<(d0, d1) -> (0, 0, 0)> : tensor<1x1xf32>, vector<4x4x4xf32>

All these examples involve broadcasting. It’s worth referring to this point from the vector.transfer_read documentation (emphasis mine):

The mask type is an i1 vector with a shape that matches how elements are read from the MemRef/Tensor, before any permutation or broadcasting.

This PR strictly follows the documented design, which ensures consistency. Other parts of MLIR should adhere to the same principle as well.

Based on the definition, the mask dimensionality should always match the dimensionality of the domain.

Which definition? :)

Using vector<1x1xi1> instead of vector<1xi1> in your 3rd example would be one possibility, but not the only one ;-)

@Groverkss
Member

I’d be cautious about concluding that something is broken - at least not yet! 😊

Sorry, not the right choice of words. The PR is obviously correct and useful. The only thing that seems weird to me is the lit test; let me quickly write up why it feels that way. I only blocked the PR so we can discuss it more. I'm happy to remove the block meanwhile and explain my concerns better.

This PR strictly follows the documented design, which ensures consistency. Other parts of MLIR should adhere to the same principle as well.

Right, I should've been clearer: I don't want to block the change, but I do want to discuss the masking behavior more.

All these examples involve broadcasting. It’s worth referring to this point from the vector.transfer_read documentation (emphasis mine):

Right, I am following this definition, but I will write up a response explaining myself better.

Meanwhile, I'm unblocking, feel free to land it :)

@Groverkss dismissed their stale review November 18, 2024 16:21

explained above

@dcaballe
Contributor Author

Hey Kunwar,

Your feedback is definitely welcome. Currently vector.mask, like some other ops, does not support 0-D vectors at the ODS level. You can take a look at the difference between:

// Vector types.

class VectorOf<list<Type> allowedTypes> :
  ShapedContainerType<allowedTypes, IsVectorTypePred, "vector",
                      "::mlir::VectorType">;

// Temporary vector type clone that allows gradual transition to 0-D vectors.
// TODO: Remove this when all ops support 0-D vectors.
class VectorOfAnyRankOf<list<Type> allowedTypes> :
  ShapedContainerType<allowedTypes, IsVectorOfAnyRankTypePred, "vector",
                      "::mlir::VectorType">;

0-D vector support has been a major source of bugs and trouble in general, even for simple ops, so I don't think we can just enable 0-D vectors here and move on. Honestly, I'm really concerned about the 0-D vector situation and the instability it has brought since its introduction. For that reason, as I mentioned in your PR, I was thinking about sending an RFC to propose its removal. It would be good to know beforehand if people have any concerns about that.

@banach-space
Contributor

@Groverkss , are you OK for Diego to proceed with this?

It's the first step towards fixing #116197, hence I'm quite keen :) I will be sending a follow-on soonish.

@Groverkss
Member

@Groverkss , are you OK for Diego to proceed with this?

It's the first step towards fixing #116197, hence I'm quite keen :) I will be sending a follow-on soonish.

Yes! I will follow up on the discussion elsewhere, this patch by itself is good.

@dcaballe requested a review from kuhar as a code owner November 21, 2024 04:19
@dcaballe merged commit 3291372 into llvm:main Nov 21, 2024
6 of 7 checks passed