From c0334a95251ca6ef700647a85e5de96c1c2cd12e Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Fri, 9 Aug 2024 06:58:58 +0100 Subject: [PATCH 1/2] [nlir][vector] Disable `vector.matrix_multiply` for scalable vectors Disables `vector.matrix_multiply` for scalable vectors. As per the docs: > This is the counterpart of llvm.matrix.multiply in MLIR I'm not aware of any use of matrix-multiply intrinsics in the context of scalable vectors, hence disabling. --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 11 +++++++---- mlir/include/mlir/IR/CommonTypeConstraints.td | 8 ++++++++ mlir/test/Dialect/Vector/invalid.mlir | 13 +++++++++++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index bc97a5ae7d2f7..b8559efda13e9 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2683,13 +2683,13 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [Pure, TCresVTEtIsSameAsOpBase<0, 1>>]>, Arguments<( // TODO: tighten vector element types that make sense. - ins VectorOfRankAndType<[1], + ins FixedVectorOfRankAndType<[1], [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$lhs, - VectorOfRankAndType<[1], + FixedVectorOfRankAndType<[1], [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$rhs, I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>, Results<( - outs VectorOfRankAndType<[1], + outs FixedVectorOfRankAndType<[1], [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)> { let summary = "Vector matrix multiplication op that operates on flattened 1-D" @@ -2707,7 +2707,10 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [Pure, and multiplies them. The result matrix is returned embedded in the result vector. - Also see: + Note, the semantics of the corresponding LLVM intrinsic, + `@llvm.matrix.multiply.*`, are not clear in the context of scalable + vectors. Hence, this Op is only available for fixed-width vectors. Also + see: http://llvm.org/docs/LangRef.html#llvm-matrix-multiply-intrinsic diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 5b6ec167fa242..2eec2c6073bbf 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -494,6 +494,14 @@ class VectorOfRankAndType allowedRanks, VectorOf.summary # VectorOfRank.summary, "::mlir::VectorType">; +// Fixed-width vector where the rank is from the given `allowedRanks` list and +// the type is from the given `allowedTypes` list +class FixedVectorOfRankAndType allowedRanks, + list allowedTypes> : AllOfType< + [FixedVectorOf, VectorOfRank], + FixedVectorOf.summary # VectorOfRank.summary, + "::mlir::VectorType">; + // Whether the number of elements of a vector is from the given // `allowedLengths` list class IsVectorOfLengthPred allowedLengths> : diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index ba1efe8b3c2d3..6e077a2fb4cee 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1862,3 +1862,16 @@ func.func @invalid_step_2d() { vector.step : vector<2x4xf32> return } + +// ----- + +func.func @matrix_matmul_scalable(%a: vector<[4]xf64>, %b: vector<4xf64>) { + // expected-error @+1 {{'vector.matrix_multiply' op operand #0 must be fixed-length vector of signless integer or signed integer or index or floating-point values of ranks 1, but got 'vector<[4]xf64>'}} + %c = vector.matrix_multiply %a, %b { + lhs_rows = 2: i32, + lhs_columns = 2: i32 , + rhs_columns = 2: i32 } + : (vector<[4]xf64>, vector<4xf64>) -> vector<4xf64> + + return +} From e6d8031e6d378f14be015d2d7dd33e8343427cf7 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Fri, 9 Aug 2024 10:24:19 +0100 Subject: [PATCH 2/2] fixup! [nlir][vector] Disable `vector.matrix_multiply` for scalable vectors Include suggestions from Cullen --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 7 +++---- mlir/include/mlir/IR/CommonTypeConstraints.td | 2 +- mlir/test/Dialect/Vector/invalid.mlir | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index b8559efda13e9..a2a317109e29d 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2707,10 +2707,9 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [Pure, and multiplies them. The result matrix is returned embedded in the result vector. - Note, the semantics of the corresponding LLVM intrinsic, - `@llvm.matrix.multiply.*`, are not clear in the context of scalable - vectors. Hence, this Op is only available for fixed-width vectors. Also - see: + Note, the corresponding LLVM intrinsic, `@llvm.matrix.multiply.*`, does not + support scalable vectors. Hence, this Op is only available for fixed-width + vectors. Also see: http://llvm.org/docs/LangRef.html#llvm-matrix-multiply-intrinsic diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 2eec2c6073bbf..2493f212a356a 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -600,7 +600,7 @@ class VectorOfLengthAndType allowedLengths, // Any fixed-length vector where the number of elements is from the given // `allowedLengths` list and the type is from the given `allowedTypes` list class FixedVectorOfLengthAndType allowedLengths, - list allowedTypes> : AllOfType< + list allowedTypes> : AllOfType< [FixedVectorOf, FixedVectorOfLength], FixedVectorOf.summary # FixedVectorOfLength.summary, diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 6e077a2fb4cee..c95b8bd5ed614 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1865,7 +1865,7 @@ func.func @invalid_step_2d() { // ----- -func.func @matrix_matmul_scalable(%a: vector<[4]xf64>, %b: vector<4xf64>) { +func.func @matrix_multiply_scalable(%a: vector<[4]xf64>, %b: vector<4xf64>) { // expected-error @+1 {{'vector.matrix_multiply' op operand #0 must be fixed-length vector of signless integer or signed integer or index or floating-point values of ranks 1, but got 'vector<[4]xf64>'}} %c = vector.matrix_multiply %a, %b { lhs_rows = 2: i32,