From 9040d7d1ef07b4f446c8d2d3b4ea317b921398ae Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Sat, 25 May 2024 23:03:40 -0400 Subject: [PATCH 1/2] [mlir][vector] Add result type to `interleave` assembly format This is to make it more obvious for what the result type is, especially with some less trivial cases like 0-d inputs resulting in 1-d inputs or interaction with scalable vector types. Note that `vector.deinterleave` uses the same format with explicit result type. Also improve examples and clean up surrounding code. --- .../mlir/Dialect/Vector/IR/VectorOps.td | 38 +++++++++---------- .../Transforms/LowerVectorInterleave.cpp | 15 ++++---- .../Transforms/VectorEmulateNarrowType.cpp | 6 +-- .../VectorToLLVM/vector-to-llvm.mlir | 22 +++++------ .../VectorToSPIRV/vector-to-spirv.mlir | 2 +- mlir/test/Dialect/Vector/canonicalize.mlir | 7 ++-- mlir/test/Dialect/Vector/ops.mlir | 12 +++--- ...vector-interleave-lowering-transforms.mlir | 20 +++++----- .../Vector/vector-interleave-to-shuffle.mlir | 5 +-- .../CPU/ArmSVE/test-scalable-interleave.mlir | 2 +- .../Dialect/Vector/CPU/test-interleave.mlir | 2 +- 11 files changed, 61 insertions(+), 70 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 2bb7540ef0b0f..e043320b56411 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -480,24 +480,25 @@ def Vector_ShuffleOp : let hasCanonicalizer = 1; } -def Vector_InterleaveOp : - Vector_Op<"interleave", [Pure, - AllTypesMatch<["lhs", "rhs"]>, - TypesMatchWith< +def ResultIsDoubleSourceVectorType : TypesMatchWith< "type of 'result' is double the width of the inputs", "lhs", "result", [{ [&]() -> ::mlir::VectorType { - auto vectorType = ::llvm::cast($_self); + auto vectorType = ::llvm::cast<::mlir::VectorType>($_self); ::mlir::VectorType::Builder builder(vectorType); if (vectorType.getRank() == 0) { - static constexpr int64_t v2xty_shape[] = { 2 }; - return builder.setShape(v2xty_shape); + static constexpr int64_t v2xTyShape[] = {2}; + return builder.setShape(v2xTyShape); } auto lastDim = vectorType.getRank() - 1; return builder.setDim(lastDim, vectorType.getDimSize(lastDim) * 2); }() - }]>]> { + }]>; + +def Vector_InterleaveOp : + Vector_Op<"interleave", [Pure, AllTypesMatch<["lhs", "rhs"]>, + ResultIsDoubleSourceVectorType]> { let summary = "constructs a vector by interleaving two input vectors"; let description = [{ The interleave operation constructs a new vector by interleaving the @@ -513,16 +514,15 @@ def Vector_InterleaveOp : Example: ```mlir - %0 = vector.interleave %a, %b - : vector<[4]xi32> ; yields vector<[8]xi32> - %1 = vector.interleave %c, %d - : vector<8xi8> ; yields vector<16xi8> - %2 = vector.interleave %e, %f - : vector ; yields vector<2xf16> - %3 = vector.interleave %g, %h - : vector<2x4x[2]xf64> ; yields vector<2x4x[4]xf64> - %4 = vector.interleave %i, %j - : vector<6x3xf32> ; yields vector<6x6xf32> + %a = arith.constant dense<[0, 1]> : vector<2xi32> + %b = arith.constant dense<[2, 3]> : vector<2xi32> + %0 = vector.interleave %a, %b : vector<2xi32> -> vector<4xi32> + // The value of `%0` is `[0, 2, 1, 3]`. + + %1 = vector.interleave %c, %d : vector -> vector<2xf16> + %2 = vector.interleave %e, %f : vector<6x3xf32> -> vector<6x6xf32> + %3 = vector.interleave %g, %h : vector<[4]xi32> -> vector<[8]xi32> + %4 = vector.interleave %i, %j : vector<2x4x[2]xf64> -> vector<2x4x[4]xf64> ``` }]; @@ -530,7 +530,7 @@ def Vector_InterleaveOp : let results = (outs AnyVector:$result); let assemblyFormat = [{ - $lhs `,` $rhs attr-dict `:` type($lhs) + $lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result) }]; let extraClassDeclaration = [{ diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp index 5326760c9b4eb..77c97b2f1497c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp @@ -30,7 +30,7 @@ namespace { /// Example: /// /// ```mlir -/// vector.interleave %a, %b : vector<1x2x3x4xi64> +/// vector.interleave %a, %b : vector<1x2x3x4xi64> -> vector<1x2x3x8xi64> /// ``` /// Would be unrolled to: /// ```mlir @@ -39,14 +39,15 @@ namespace { /// : vector<4xi64> from vector<1x2x3x4xi64> | /// %1 = vector.extract %b[0, 0, 0] | /// : vector<4xi64> from vector<1x2x3x4xi64> | - Repeated 6x for -/// %2 = vector.interleave %0, %1 : vector<4xi64> | all leading positions +/// %2 = vector.interleave %0, %1 : | all leading positions +/// : vector<4xi64> -> vector<8xi64> | /// %3 = vector.insert %2, %result [0, 0, 0] | /// : vector<8xi64> into vector<1x2x3x8xi64> ┘ /// ``` /// /// Note: If any leading dimension before the `targetRank` is scalable the /// unrolling will stop before the scalable dimension. -class UnrollInterleaveOp : public OpRewritePattern { +class UnrollInterleaveOp final : public OpRewritePattern { public: UnrollInterleaveOp(int64_t targetRank, MLIRContext *context, PatternBenefit benefit = 1) @@ -84,7 +85,7 @@ class UnrollInterleaveOp : public OpRewritePattern { /// Example: /// /// ```mlir -/// vector.interleave %a, %b : vector<7xi16> +/// vector.interleave %a, %b : vector<7xi16> -> vector<14xi16> /// ``` /// /// Is rewritten into: @@ -93,10 +94,8 @@ class UnrollInterleaveOp : public OpRewritePattern { /// vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13] /// : vector<7xi16>, vector<7xi16> /// ``` -class InterleaveToShuffle : public OpRewritePattern { -public: - InterleaveToShuffle(MLIRContext *context, PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit) {}; +struct InterleaveToShuffle final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::InterleaveOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 6025c4ad7c145..59b6cb3ae667a 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -1090,7 +1090,7 @@ struct RewriteExtOfBitCast : OpRewritePattern { /// %1 = arith.shli %0, 4 : vector<4xi8> /// %2 = arith.shrsi %1, 4 : vector<4xi8> /// %3 = arith.shrsi %0, 4 : vector<4xi8> -/// %4 = vector.interleave %2, %3 : vector<4xi8> +/// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8> /// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32> /// /// arith.sitofp %in : vector<8xi4> to vector<8xf32> @@ -1099,7 +1099,7 @@ struct RewriteExtOfBitCast : OpRewritePattern { /// %1 = arith.shli %0, 4 : vector<4xi8> /// %2 = arith.shrsi %1, 4 : vector<4xi8> /// %3 = arith.shrsi %0, 4 : vector<4xi8> -/// %4 = vector.interleave %2, %3 : vector<4xi8> +/// %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8> /// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32> /// /// Example (unsigned): @@ -1108,7 +1108,7 @@ struct RewriteExtOfBitCast : OpRewritePattern { /// %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8> /// %1 = arith.andi %0, 15 : vector<4xi8> /// %2 = arith.shrui %0, 4 : vector<4xi8> -/// %3 = vector.interleave %1, %2 : vector<4xi8> +/// %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8> /// %4 = arith.extui %3 : vector<8xi8> to vector<8xi32> /// template diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 439f1e920e392..a7a0ca3d43b01 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -2495,7 +2495,7 @@ func.func @vector_interleave_0d(%a: vector, %b: vector) -> vector<2xi8> // CHECK: %[[RHS_RANK1:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : vector to vector<1xi8> // CHECK: %[[ZIP:.*]] = llvm.shufflevector %[[LHS_RANK1]], %[[RHS_RANK1]] [0, 1] : vector<1xi8> // CHECK: return %[[ZIP]] - %0 = vector.interleave %a, %b : vector + %0 = vector.interleave %a, %b : vector -> vector<2xi8> return %0 : vector<2xi8> } @@ -2503,11 +2503,10 @@ func.func @vector_interleave_0d(%a: vector, %b: vector) -> vector<2xi8> // CHECK-LABEL: @vector_interleave_1d // CHECK-SAME: %[[LHS:.*]]: vector<8xf32>, %[[RHS:.*]]: vector<8xf32>) -func.func @vector_interleave_1d(%a: vector<8xf32>, %b: vector<8xf32>) -> vector<16xf32> -{ +func.func @vector_interleave_1d(%a: vector<8xf32>, %b: vector<8xf32>) -> vector<16xf32> { // CHECK: %[[ZIP:.*]] = llvm.shufflevector %[[LHS]], %[[RHS]] [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<8xf32> // CHECK: return %[[ZIP]] - %0 = vector.interleave %a, %b : vector<8xf32> + %0 = vector.interleave %a, %b : vector<8xf32> -> vector<16xf32> return %0 : vector<16xf32> } @@ -2515,11 +2514,10 @@ func.func @vector_interleave_1d(%a: vector<8xf32>, %b: vector<8xf32>) -> vector< // CHECK-LABEL: @vector_interleave_1d_scalable // CHECK-SAME: %[[LHS:.*]]: vector<[4]xi32>, %[[RHS:.*]]: vector<[4]xi32>) -func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32>) -> vector<[8]xi32> -{ +func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32>) -> vector<[8]xi32> { // CHECK: %[[ZIP:.*]] = "llvm.intr.vector.interleave2"(%[[LHS]], %[[RHS]]) : (vector<[4]xi32>, vector<[4]xi32>) -> vector<[8]xi32> // CHECK: return %[[ZIP]] - %0 = vector.interleave %a, %b : vector<[4]xi32> + %0 = vector.interleave %a, %b : vector<[4]xi32> -> vector<[8]xi32> return %0 : vector<[8]xi32> } @@ -2527,11 +2525,10 @@ func.func @vector_interleave_1d_scalable(%a: vector<[4]xi32>, %b: vector<[4]xi32 // CHECK-LABEL: @vector_interleave_2d // CHECK-SAME: %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>) -func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8> -{ +func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8> { // CHECK: llvm.shufflevector // CHECK-NOT: vector.interleave {{.*}} : vector<2x3xi8> - %0 = vector.interleave %a, %b : vector<2x3xi8> + %0 = vector.interleave %a, %b : vector<2x3xi8> -> vector<2x6xi8> return %0 : vector<2x6xi8> } @@ -2539,10 +2536,9 @@ func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vecto // CHECK-LABEL: @vector_interleave_2d_scalable // CHECK-SAME: %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>) -func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16> -{ +func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16> { // CHECK: llvm.intr.vector.interleave2 // CHECK-NOT: vector.interleave {{.*}} : vector<2x[8]xi16> - %0 = vector.interleave %a, %b : vector<2x[8]xi16> + %0 = vector.interleave %a, %b : vector<2x[8]xi16> -> vector<2x[16]xi16> return %0 : vector<2x[16]xi16> } diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index a7542086aa766..b24088d951259 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -488,7 +488,7 @@ func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> { // CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32> // CHECK: return %[[SHUFFLE]] func.func @interleave(%a: vector<2xf32>, %b: vector<2xf32>) -> vector<4xf32> { - %0 = vector.interleave %a, %b : vector<2xf32> + %0 = vector.interleave %a, %b : vector<2xf32> -> vector<4xf32> return %0 : vector<4xf32> } diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 61a5f2a96e1c1..22af91e0eb327 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -2576,9 +2576,8 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te // CHECK-LABEL: func.func @rank_0_shuffle_to_interleave( // CHECK-SAME: %[[LHS:.*]]: vector, %[[RHS:.*]]: vector) -func.func @rank_0_shuffle_to_interleave(%arg0: vector, %arg1: vector) -> vector<2xf64> -{ - // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector +func.func @rank_0_shuffle_to_interleave(%arg0: vector, %arg1: vector) -> vector<2xf64> { + // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector -> vector<2xf64> // CHECK: return %[[ZIP]] %0 = vector.shuffle %arg0, %arg1 [0, 1] : vector, vector return %0 : vector<2xf64> @@ -2589,7 +2588,7 @@ func.func @rank_0_shuffle_to_interleave(%arg0: vector, %arg1: vector) // CHECK-LABEL: func.func @rank_1_shuffle_to_interleave( // CHECK-SAME: %[[LHS:.*]]: vector<6xi32>, %[[RHS:.*]]: vector<6xi32>) func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi32>) -> vector<12xi32> { - // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<6xi32> + // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<6xi32> -> vector<12xi32> // CHECK: return %[[ZIP]] %0 = vector.shuffle %arg0, %arg1 [0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11] : vector<6xi32>, vector<6xi32> return %0 : vector<12xi32> diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 9d8101d3eee97..c868c881d079a 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -1084,36 +1084,36 @@ func.func @fastmath(%x: vector<42xf32>) -> f32 { // CHECK-LABEL: @interleave_0d func.func @interleave_0d(%a: vector, %b: vector) -> vector<2xf32> { - // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector - %0 = vector.interleave %a, %b : vector + // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector -> vector<2xf32> + %0 = vector.interleave %a, %b : vector -> vector<2xf32> return %0 : vector<2xf32> } // CHECK-LABEL: @interleave_1d func.func @interleave_1d(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<8xf32> { // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<4xf32> - %0 = vector.interleave %a, %b : vector<4xf32> + %0 = vector.interleave %a, %b : vector<4xf32> -> vector<8xf32> return %0 : vector<8xf32> } // CHECK-LABEL: @interleave_1d_scalable func.func @interleave_1d_scalable(%a: vector<[8]xi16>, %b: vector<[8]xi16>) -> vector<[16]xi16> { // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<[8]xi16> - %0 = vector.interleave %a, %b : vector<[8]xi16> + %0 = vector.interleave %a, %b : vector<[8]xi16> -> vector<[16]xi16> return %0 : vector<[16]xi16> } // CHECK-LABEL: @interleave_2d func.func @interleave_2d(%a: vector<2x8xf32>, %b: vector<2x8xf32>) -> vector<2x16xf32> { // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x8xf32> - %0 = vector.interleave %a, %b : vector<2x8xf32> + %0 = vector.interleave %a, %b : vector<2x8xf32> -> vector<2x16xf32> return %0 : vector<2x16xf32> } // CHECK-LABEL: @interleave_2d_scalable func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>) -> vector<2x[4]xf64> { // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x[2]xf64> - %0 = vector.interleave %a, %b : vector<2x[2]xf64> + %0 = vector.interleave %a, %b : vector<2x[2]xf64> -> vector<2x[4]xf64> return %0 : vector<2x[4]xf64> } diff --git a/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir index 3dd4857860eb1..598f7d70b4f1b 100644 --- a/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir @@ -2,8 +2,7 @@ // CHECK-LABEL: @vector_interleave_2d // CHECK-SAME: %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>) -func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8> -{ +func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8> { // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> // CHECK-DAG: %[[LHS_0:.*]] = vector.extract %[[LHS]][0] // CHECK-DAG: %[[RHS_0:.*]] = vector.extract %[[RHS]][0] @@ -14,14 +13,13 @@ func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vecto // CHECK-DAG: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %[[CST]] [0] // CHECK-DAG: %[[RES_1:.*]] = vector.insert %[[ZIP_1]], %[[RES_0]] [1] // CHECK-NEXT: return %[[RES_1]] : vector<2x6xi8> - %0 = vector.interleave %a, %b : vector<2x3xi8> + %0 = vector.interleave %a, %b : vector<2x3xi8> -> vector<2x6xi8> return %0 : vector<2x6xi8> } // CHECK-LABEL: @vector_interleave_2d_scalable // CHECK-SAME: %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>) -func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16> -{ +func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16> { // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> // CHECK-DAG: %[[LHS_0:.*]] = vector.extract %[[LHS]][0] // CHECK-DAG: %[[RHS_0:.*]] = vector.extract %[[RHS]][0] @@ -32,7 +30,7 @@ func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8] // CHECK-DAG: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %[[CST]] [0] // CHECK-DAG: %[[RES_1:.*]] = vector.insert %[[ZIP_1]], %[[RES_0]] [1] // CHECK-NEXT: return %[[RES_1]] : vector<2x[16]xi16> - %0 = vector.interleave %a, %b : vector<2x[8]xi16> + %0 = vector.interleave %a, %b : vector<2x[8]xi16> -> vector<2x[16]xi16> return %0 : vector<2x[16]xi16> } @@ -44,17 +42,17 @@ func.func @vector_interleave_4d(%a: vector<1x2x3x4xi64>, %b: vector<1x2x3x4xi64> // CHECK: %[[RHS_0:.*]] = vector.extract %[[RHS]][0, 0, 0] : vector<4xi64> from vector<1x2x3x4xi64> // CHECK: %[[ZIP_0:.*]] = vector.interleave %[[LHS_0]], %[[RHS_0]] : vector<4xi64> // CHECK: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %{{.*}} [0, 0, 0] : vector<8xi64> into vector<1x2x3x8xi64> - // CHECK-COUNT-5: vector.interleave %{{.*}}, %{{.*}} : vector<4xi64> - %0 = vector.interleave %a, %b : vector<1x2x3x4xi64> + // CHECK-COUNT-5: vector.interleave %{{.*}}, %{{.*}} : vector<4xi64> -> vector<8xi64> + %0 = vector.interleave %a, %b : vector<1x2x3x4xi64> -> vector<1x2x3x8xi64> return %0 : vector<1x2x3x8xi64> } // CHECK-LABEL: @vector_interleave_nd_with_scalable_dim -func.func @vector_interleave_nd_with_scalable_dim(%a: vector<1x3x[2]x2x3x4xf16>, %b: vector<1x3x[2]x2x3x4xf16>) -> vector<1x3x[2]x2x3x8xf16> -{ +func.func @vector_interleave_nd_with_scalable_dim( + %a: vector<1x3x[2]x2x3x4xf16>, %b: vector<1x3x[2]x2x3x4xf16>) -> vector<1x3x[2]x2x3x8xf16> { // The scalable dim blocks unrolling so only the first two dims are unrolled. // CHECK-COUNT-3: vector.interleave %{{.*}}, %{{.*}} : vector<[2]x2x3x4xf16> - %0 = vector.interleave %a, %b : vector<1x3x[2]x2x3x4xf16> + %0 = vector.interleave %a, %b : vector<1x3x[2]x2x3x4xf16> -> vector<1x3x[2]x2x3x8xf16> return %0 : vector<1x3x[2]x2x3x8xf16> } diff --git a/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir index ed3b3396bf3ea..d59cd4e6765ba 100644 --- a/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir +++ b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir @@ -1,9 +1,8 @@ // RUN: mlir-opt %s --transform-interpreter | FileCheck %s // CHECK-LABEL: @vector_interleave_to_shuffle -func.func @vector_interleave_to_shuffle(%a: vector<7xi16>, %b: vector<7xi16>) -> vector<14xi16> -{ - %0 = vector.interleave %a, %b : vector<7xi16> +func.func @vector_interleave_to_shuffle(%a: vector<7xi16>, %b: vector<7xi16>) -> vector<14xi16> { + %0 = vector.interleave %a, %b : vector<7xi16> -> vector<14xi16> return %0 : vector<14xi16> } // CHECK: vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13] : vector<7xi16>, vector<7xi16> diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir index 07989bd71f501..e9f1bbeafacdd 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir @@ -17,7 +17,7 @@ func.func @entry() { // CHECK: ( 1, 1, 1, 1 // CHECK: ( 2, 2, 2, 2 - %v3 = vector.interleave %v1, %v2 : vector<[4]xf32> + %v3 = vector.interleave %v1, %v2 : vector<[4]xf32> -> vector<[8]xf32> vector.print %v3 : vector<[8]xf32> // CHECK: ( 1, 2, 1, 2, 1, 2, 1, 2 diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir index 0bc78af6aba03..d6962cbe2776a 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir @@ -16,7 +16,7 @@ func.func @entry() { // CHECK: ( ( 1, 1, 1, 1 ), ( 1, 1, 1, 1 ) ) // CHECK: ( ( 2, 2, 2, 2 ), ( 2, 2, 2, 2 ) ) - %v3 = vector.interleave %v1, %v2 : vector<2x4xf32> + %v3 = vector.interleave %v1, %v2 : vector<2x4xf32> -> vector<2x8xf32> vector.print %v3 : vector<2x8xf32> // CHECK: ( ( 1, 2, 1, 2, 1, 2, 1, 2 ), ( 1, 2, 1, 2, 1, 2, 1, 2 ) ) From 8ee982880bf550fbe78e8cff7df22bd195325256 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Mon, 27 May 2024 10:59:29 -0400 Subject: [PATCH 2/2] Improve comments --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index e043320b56411..56d866ac5b40c 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -516,9 +516,10 @@ def Vector_InterleaveOp : ```mlir %a = arith.constant dense<[0, 1]> : vector<2xi32> %b = arith.constant dense<[2, 3]> : vector<2xi32> - %0 = vector.interleave %a, %b : vector<2xi32> -> vector<4xi32> // The value of `%0` is `[0, 2, 1, 3]`. + %0 = vector.interleave %a, %b : vector<2xi32> -> vector<4xi32> + // Examples showing allowed input and result types. %1 = vector.interleave %c, %d : vector -> vector<2xf16> %2 = vector.interleave %e, %f : vector<6x3xf32> -> vector<6x6xf32> %3 = vector.interleave %g, %h : vector<[4]xi32> -> vector<[8]xi32>