From 1f421b4c3b4cd67859a275b7a1770e3691f38515 Mon Sep 17 00:00:00 2001 From: Artem Gindinson Date: Fri, 13 Jun 2025 16:44:19 +0000 Subject: [PATCH] [mlir][tensor] Fix `getReassociationForCollapse` for tensor/scalar reshapes Commit 6e5a142 changed the behavior of the function when computing reassociations between tensors (consisting of unit/dynamic dimensions) and scalars/0d vectors. The IR representation for such reshapes actually expects an empty reassociation, like so: ``` func.func @example(%arg0 : tensor) -> tensor { %0 = tensor.collapse_shape %arg0 [] : tensor into tensor } ``` Restore the original behavior - the routine should resort to reporting failures when compile time-known non-unit dimensions are part of the attempted reassociation. Signed-off-by: Artem Gindinson --- mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 10 ++++------ mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp | 8 ++++---- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 3b1fdb69e8ef1..aa566c0086a2f 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -299,19 +299,17 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, // this utility). if (numSourceDims <= numTargetDims) return std::nullopt; - // Early handling for scalar target types. + // Early handling for scalar target types. We should report an invalid + // reassociation for non-unit static dimensions - no chance to collapse these + // into a scalar. if (numTargetDims == 0) { - ReassociationIndices allSourceIndices; - allSourceIndices.reserve(numSourceDims); for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims; ++sourceDimIdx) { int64_t sourceSize = sourceShape[sourceDimIdx]; - // All source dimensions must be unit or dynamic. if (sourceSize != 1 && sourceSize != ShapedType::kDynamic) return std::nullopt; - allSourceIndices.push_back(sourceDimIdx); } - return SmallVector{allSourceIndices}; + return SmallVector{}; } // Collect source ranges by iterating over the target shape left-to-right. diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp index db1a87a4de2d5..05f97e875e2dc 100644 --- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp +++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp @@ -23,16 +23,16 @@ makeOptionalIndices(std::initializer_list list) { TEST(ReassociationIndicesForCollapse, ScalarTest) { EXPECT_EQ(getReassociationIndicesForCollapse({1}, {}), - makeOptionalIndices({{0}})); + makeOptionalIndices({})); EXPECT_EQ(getReassociationIndicesForCollapse({1, 1}, {}), - makeOptionalIndices({{0, 1}})); + makeOptionalIndices({})); EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic}, {}), - makeOptionalIndices({{0}})); + makeOptionalIndices({})); EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, ShapedType::kDynamic, 1, ShapedType::kDynamic}, {}), - makeOptionalIndices({{0, 1, 2, 3, 4}})); + makeOptionalIndices({})); } TEST(ReassociationIndicesForCollapse, ScalarTestFailure) {