From efc29a7ab095bfbd2e2b97b76e929e313939c633 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Sun, 29 Dec 2024 18:52:12 +0000 Subject: [PATCH 1/4] [mlir][vector] Restrict use of 0-D vectors in vector.insert/vector.extract MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This patch enforces a restriction in the Vector dialect: the non-indexed operands of `vector.insert` and `vector.extract` must no longer be 0-D vectors. In other words, rank-0 vector types like `vector<f32>` are disallowed as the source or result. EXAMPLES -------- The following are now **illegal** (note the use of `vector<f32>`): ```mlir %0 = vector.insert %v, %dst[0, 0] : vector<f32> into vector<2x2xf32> %1 = vector.extract %src[0, 0] : vector<f32> from vector<2x2xf32> ``` Instead, use scalars as the source and result types: ```mlir %0 = vector.insert %v, %dst[0, 0] : f32 into vector<2x2xf32> %1 = vector.extract %src[0, 0] : f32 from vector<2x2xf32> ``` This change serves three goals: 1. REDUCED AMBIGUITY -------------------- By enforcing scalar-only semantics when n-k = 0, we eliminate ambiguity in interpretation. Prior to this patch, both `f32` and `vector<f32>` were accepted in practice, though only scalars were intended. 2. MATCH IMPLEMENTATION TO DOCUMENTATION ---------------------------------------- The current behavior contradicts the documented intent. For example, vector.extract states: > Degenerates to an element type if n-k is zero. This patch enforces that intent in code. 3. ENSURE SYMMETRY BETWEEN INSERT AND EXTRACT --------------------------------------------- With the stricter semantics in place, it’s natural and consistent to make `vector.insert` behave symmetrically to `vector.extract`, i.e., degenerate the source type to a scalar when n = 0. NOTES FOR REVIEWERS ------------------- 1. Main change is in "VectorOps.cpp", where stricter type checks are implemented. 2. 
Test updates in "invalid.mlir" and "ops.mlir" are minor cleanups to remove now-illegal examples. 3. Lowering changes in "VectorToSCF.cpp" are the main trade-off: we now avoid using `vector.transfer_read` for scalar loads and instead rely on `memref.load` / `tensor.extract`. RELATED RFC ----------- * https://discourse.llvm.org/t/rfc-should-we-restrict-the-usage-of-0-d-vectors-in-the-vector-dialect --- .../mlir/Dialect/Vector/IR/VectorOps.td | 16 ++++---- .../Conversion/VectorToSCF/VectorToSCF.cpp | 39 ++++++++++++++++--- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 10 +++++ mlir/test/Dialect/Vector/invalid.mlir | 19 +++------ mlir/test/Dialect/Vector/ops.mlir | 6 +-- 5 files changed, 59 insertions(+), 31 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 8353314ed958b..cd6b3e7ad82dc 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -691,8 +691,9 @@ def Vector_ExtractOp : InferTypeOpAdaptorWithIsCompatible]> { let summary = "extract operation"; let description = [{ - Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at - the proper position. Degenerates to an element type if n-k is zero. + Extracts an (n − k)-D subvector (the result) from an n-D vector at a + specified k-D position. When n = k, the result degenerates to a scalar + element. Static and dynamic indices must be greater or equal to zero and less than the size of the corresponding dimension. 
The result is undefined if any @@ -704,7 +705,6 @@ def Vector_ExtractOp : ```mlir %1 = vector.extract %0[3]: vector<8x16xf32> from vector<4x8x16xf32> %2 = vector.extract %0[2, 1, 3]: f32 from vector<4x8x16xf32> - %3 = vector.extract %1[]: vector from vector %4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32> %5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32> %6 = vector.extract %10[-1, %c]: f32 from vector<4x16xf32> @@ -886,9 +886,10 @@ def Vector_InsertOp : AllTypesMatch<["dest", "result"]>]> { let summary = "insert operation"; let description = [{ - Takes an n-D source vector, an (n+k)-D destination vector and a k-D position - and inserts the n-D source into the (n+k)-D destination at the proper - position. Degenerates to a scalar or a 0-d vector source type when n = 0. + Inserts an n-D source vector (the value to store) into an (n + k)-D + destination vector at a specified k-D position. When n = 0, the source + degenerates to a scalar element inserted into the (0 + k)-D destination + vector. Static and dynamic indices must be greater or equal to zero and less than the size of the corresponding dimension. 
The result is undefined if any @@ -900,8 +901,7 @@ def Vector_InsertOp : ```mlir %2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32> %5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32> - %8 = vector.insert %6, %7[] : f32 into vector - %11 = vector.insert %9, %10[%a, %b, %c] : vector into vector<4x8x16xf32> + %11 = vector.insert %9, %10[%a, %b, %c] : f32 into vector<4x8x16xf32> %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32> %13 = vector.insert %20, %1[-1, %c] : f32 into vector<4x16xf32> ``` diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index cc5623068ab10..08f398a1c8ba6 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -1294,6 +1294,10 @@ struct UnrollTransferReadConversion /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds /// accesses, and broadcasts and transposes in permutation maps. + /// + /// When unpacking rank-1 vectors (i.e. when the target rank is 0), replaces + /// `vector.transfer_read` with either `memref.load` or `tensor.extract` (for + /// MemRef and Tensor source, respectively). LogicalResult matchAndRewrite(TransferReadOp xferOp, PatternRewriter &rewriter) const override { if (xferOp.getVectorType().getRank() <= options.targetRank) @@ -1324,6 +1328,8 @@ struct UnrollTransferReadConversion for (int64_t i = 0; i < dimSize; ++i) { Value iv = rewriter.create(loc, i); + // FIXME: Rename this lambda - it does much more than just + // in-bounds-check generation. 
vec = generateInBoundsCheck( rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType), /*inBoundsCase=*/ @@ -1338,12 +1344,33 @@ struct UnrollTransferReadConversion insertionIndices.push_back(rewriter.getIndexAttr(i)); auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); - auto newXferOp = b.create( - loc, newXferVecType, xferOp.getBase(), xferIndices, - AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), - xferOp.getPadding(), Value(), inBoundsAttr); - maybeAssignMask(b, xferOp, newXferOp, i); - return b.create(loc, newXferOp, vec, + + // A value that's read after rank-reducing the original + // vector.transfer_read Op. + Value unpackedReadRes; + if (newXferVecType.getRank() != 0) { + // Unpacking Vector that's rank > 2 + // (use vector.transfer_read to load a rank-reduced vector) + unpackedReadRes = b.create( + loc, newXferVecType, xferOp.getBase(), xferIndices, + AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), + xferOp.getPadding(), Value(), inBoundsAttr); + maybeAssignMask(b, xferOp, + dyn_cast( + unpackedReadRes.getDefiningOp()), + i); + } else { + // Unpacking Vector that's rank == 1 + // (use memref.load/tensor.extract to load a scalar) + unpackedReadRes = dyn_cast(xferOp.getBase().getType()) + ? 
b.create( + loc, xferOp.getBase(), xferIndices) + .getResult() + : b.create( + loc, xferOp.getBase(), xferIndices) + .getResult(); + } + return b.create(loc, unpackedReadRes, vec, insertionIndices); }, /*outOfBoundsCase=*/ diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 2a2357319bd23..dc4bcd9b6bd84 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1383,6 +1383,11 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { } LogicalResult vector::ExtractOp::verify() { + if (auto resTy = dyn_cast(getResult().getType())) + if (resTy.getRank() == 0) + return emitError( + "expected a scalar instead of a 0-d vector as the result type"); + // Note: This check must come before getMixedPosition() to prevent a crash. auto dynamicMarkersCount = llvm::count_if(getStaticPosition(), ShapedType::isDynamic); @@ -3122,6 +3127,11 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result, } LogicalResult InsertOp::verify() { + if (auto srcTy = dyn_cast(getValueToStoreType())) + if (srcTy.getRank() == 0) + return emitError( + "expected a scalar instead of a 0-d vector as the source operand"); + SmallVector position = getMixedPosition(); auto destVectorType = getDestVectorType(); if (position.size() > static_cast(destVectorType.getRank())) diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 04810ed52584f..a2622c06fa71c 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -178,9 +178,9 @@ func.func @extract_precise_position_overflow(%arg0: vector<4x8x16xf32>) { // ----- -func.func @extract_0d(%arg0: vector) { - // expected-error@+1 {{expected position attribute of rank no greater than vector rank}} - %1 = vector.extract %arg0[0] : f32 from vector +func.func @extract_0d_result(%arg0: vector) { + // expected-error@+1 {{expected a scalar instead of a 0-d vector as 
the result type}} + %1 = vector.extract %arg0[] : vector from vector } // ----- @@ -259,16 +259,9 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) { // ----- -func.func @insert_0d(%a: vector, %b: vector<4x8x16xf32>) { - // expected-error@+1 {{expected position attribute rank + source rank to match dest vector rank}} - %1 = vector.insert %a, %b[2, 6] : vector into vector<4x8x16xf32> -} - -// ----- - -func.func @insert_0d(%a: f32, %b: vector) { - // expected-error@+1 {{expected position attribute of rank no greater than dest vector rank}} - %1 = vector.insert %a, %b[0] : f32 into vector +func.func @insert_0d_value_to_store(%a: vector, %b: vector<4x8x16xf32>) { + // expected-error@+1 {{expected a scalar instead of a 0-d vector as the source operand}} + %1 = vector.insert %a, %b[0, 0, 0] : vector into vector<4x8x16xf32> } // ----- diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index f3220aed4360c..7d43f2a84dc77 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -298,12 +298,10 @@ func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, } // CHECK-LABEL: @insert_0d -func.func @insert_0d(%a: f32, %b: vector, %c: vector<2x3xf32>) -> (vector, vector<2x3xf32>) { +func.func @insert_0d(%a: f32, %b: vector) -> vector { // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector %1 = vector.insert %a, %b[] : f32 into vector - // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector into vector<2x3xf32> - %2 = vector.insert %b, %c[0, 1] : vector into vector<2x3xf32> - return %1, %2 : vector, vector<2x3xf32> + return %1 : vector } // CHECK-LABEL: @insert_poison_idx From c15e7dddaea765eab4f9ed73e79b762138dc4ac0 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Thu, 19 Jun 2025 19:33:43 +0100 Subject: [PATCH 2/4] fixup! 
[mlir][vector] Restrict use of 0-D vectors in vector.insert/vector.extract Address Kunwar's comments --- .../mlir/Dialect/Vector/IR/VectorOps.td | 9 ++-- .../Conversion/VectorToSCF/VectorToSCF.cpp | 42 ++++++------------- 2 files changed, 17 insertions(+), 34 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index cd6b3e7ad82dc..a038ed53f327a 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -691,7 +691,7 @@ def Vector_ExtractOp : InferTypeOpAdaptorWithIsCompatible]> { let summary = "extract operation"; let description = [{ - Extracts an (n − k)-D subvector (the result) from an n-D vector at a + Extracts an (n − k)-D result sub-vector from an n-D source vector at a specified k-D position. When n = k, the result degenerates to a scalar element. @@ -886,10 +886,9 @@ def Vector_InsertOp : AllTypesMatch<["dest", "result"]>]> { let summary = "insert operation"; let description = [{ - Inserts an n-D source vector (the value to store) into an (n + k)-D - destination vector at a specified k-D position. When n = 0, the source - degenerates to a scalar element inserted into the (0 + k)-D destination - vector. + Inserts an n-D value-to-store vector into an (n + k)-D destination vector + at a specified k-D position. When n = 0, value-to-store degenerates to + a scalar element inserted into the (0 + k)-D destination vector. Static and dynamic indices must be greater or equal to zero and less than the size of the corresponding dimension. The result is undefined if any diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 08f398a1c8ba6..9b95b0a5d6050 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -1294,10 +1294,6 @@ struct UnrollTransferReadConversion /// Rewrite the op: Unpack one dimension. 
Can handle masks, out-of-bounds /// accesses, and broadcasts and transposes in permutation maps. - /// - /// When unpacking rank-1 vectors (i.e. when the target rank is 0), replaces - /// `vector.transfer_read` with either `memref.load` or `tensor.extract` (for - /// MemRef and Tensor source, respectively). LogicalResult matchAndRewrite(TransferReadOp xferOp, PatternRewriter &rewriter) const override { if (xferOp.getVectorType().getRank() <= options.targetRank) @@ -1345,32 +1341,20 @@ struct UnrollTransferReadConversion auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); - // A value that's read after rank-reducing the original - // vector.transfer_read Op. - Value unpackedReadRes; - if (newXferVecType.getRank() != 0) { - // Unpacking Vector that's rank > 2 - // (use vector.transfer_read to load a rank-reduced vector) - unpackedReadRes = b.create( - loc, newXferVecType, xferOp.getBase(), xferIndices, - AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), - xferOp.getPadding(), Value(), inBoundsAttr); - maybeAssignMask(b, xferOp, - dyn_cast( - unpackedReadRes.getDefiningOp()), - i); - } else { - // Unpacking Vector that's rank == 1 - // (use memref.load/tensor.extract to load a scalar) - unpackedReadRes = dyn_cast(xferOp.getBase().getType()) - ? b.create( - loc, xferOp.getBase(), xferIndices) - .getResult() - : b.create( - loc, xferOp.getBase(), xferIndices) - .getResult(); + auto newXferOp = b.create( + loc, newXferVecType, xferOp.getBase(), xferIndices, + AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), + xferOp.getPadding(), Value(), inBoundsAttr); + maybeAssignMask(b, xferOp, newXferOp, i); + + Value valToInser = newXferOp.getResult(); + if (newXferVecType.getRank() == 0) { + // vector.insert does not accept rank-0 as the non-indexed + // argument. Extract the scalar before inserting. 
+ valToInser = b.create(loc, valToInser, + SmallVector()); } - return b.create(loc, unpackedReadRes, vec, + return b.create(loc, valToInser, vec, insertionIndices); }, /*outOfBoundsCase=*/ From c62abf4d15500182b66e5957f945f790f5d8ad72 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Fri, 20 Jun 2025 09:38:10 +0100 Subject: [PATCH 3/4] fixup! fixup! [mlir][vector] Restrict use of 0-D vectors in vector.insert/vector.extract Apply clang-format --- mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 9b95b0a5d6050..002dfebd2b602 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -1352,7 +1352,7 @@ struct UnrollTransferReadConversion // vector.insert does not accept rank-0 as the non-indexed // argument. Extract the scalar before inserting. valToInser = b.create(loc, valToInser, - SmallVector()); + SmallVector()); } return b.create(loc, valToInser, vec, insertionIndices); From f1babc3b1d00333e872fafb8b1bb926061d00b27 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Wed, 25 Jun 2025 14:32:18 +0100 Subject: [PATCH 4/4] fixup! fixup! fixup! 
[mlir][vector] Restrict use of 0-D vectors in vector.insert/vector.extract Update the docs as suggested by Kunwar --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index a038ed53f327a..d47206f52def8 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -886,9 +886,9 @@ def Vector_InsertOp : AllTypesMatch<["dest", "result"]>]> { let summary = "insert operation"; let description = [{ - Inserts an n-D value-to-store vector into an (n + k)-D destination vector - at a specified k-D position. When n = 0, value-to-store degenerates to - a scalar element inserted into the (0 + k)-D destination vector. + Inserts an (n - k)-D sub-vector (value-to-store) into an n-D destination + vector at a specified k-D position. When n = k, value-to-store degenerates + to a scalar element inserted into the n-D destination vector. Static and dynamic indices must be greater or equal to zero and less than the size of the corresponding dimension. The result is undefined if any