diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 3b1fdb69e8ef1..aa566c0086a2f 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -299,19 +299,17 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, // this utility). if (numSourceDims <= numTargetDims) return std::nullopt; - // Early handling for scalar target types. + // Early handling for scalar target types. We should report an invalid + // reassociation for non-unit static dimensions - no chance to collapse these + // into a scalar. if (numTargetDims == 0) { - ReassociationIndices allSourceIndices; - allSourceIndices.reserve(numSourceDims); for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims; ++sourceDimIdx) { int64_t sourceSize = sourceShape[sourceDimIdx]; - // All source dimensions must be unit or dynamic. if (sourceSize != 1 && sourceSize != ShapedType::kDynamic) return std::nullopt; - allSourceIndices.push_back(sourceDimIdx); } - return SmallVector{allSourceIndices}; + return SmallVector{}; } // Collect source ranges by iterating over the target shape left-to-right. diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp index db1a87a4de2d5..05f97e875e2dc 100644 --- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp +++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp @@ -23,16 +23,16 @@ makeOptionalIndices(std::initializer_list list) { TEST(ReassociationIndicesForCollapse, ScalarTest) { EXPECT_EQ(getReassociationIndicesForCollapse({1}, {}), - makeOptionalIndices({{0}})); + makeOptionalIndices({})); EXPECT_EQ(getReassociationIndicesForCollapse({1, 1}, {}), - makeOptionalIndices({{0, 1}})); + makeOptionalIndices({})); EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic}, {}), - makeOptionalIndices({{0}})); + makeOptionalIndices({})); EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, ShapedType::kDynamic, 1, ShapedType::kDynamic}, {}), - makeOptionalIndices({{0, 1, 2, 3, 4}})); + makeOptionalIndices({})); } TEST(ReassociationIndicesForCollapse, ScalarTestFailure) {