diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 4fa5b8a4865b4..b59e9062e5a08 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -26,6 +26,9 @@ static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) { // Reject index since getElementTypeBitWidth will abort for Index types. if (!vecType || vecType.getElementType().isIndex()) return false; + // There are no dimension to fold if it is a 0-D vector. + if (vecType.getRank() == 0) + return false; unsigned trailingVecDimBitWidth = vecType.getShape().back() * vecType.getElementTypeBitWidth(); if (trailingVecDimBitWidth >= targetBitWidth) diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index f0e9b3a05c066..212541c79565b 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -146,6 +146,16 @@ func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x // ----- +// ALL-LABEL: func.func @test_0d_vector +func.func @test_0d_vector() -> vector { + // ALL: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector + %0 = arith.constant dense<0.0> : vector + // ALL: return %[[CST]] + return %0 : vector +} + +// ----- + func.func @test_scalable_no_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> { // expected-error@+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}} %0 = arith.constant dense<[[1., 1.], [3., 3.]]> : vector<2x[2]xf32>