From 2ae44808cd573c31331d143d687afcdab6986a72 Mon Sep 17 00:00:00 2001 From: Artem Gindinson Date: Mon, 28 Apr 2025 17:24:11 +0000 Subject: [PATCH 01/16] [mlir][tensor] Loosen restrictions on folding dynamic reshapes The main idea behind the change is to allow expand-of-collapse folds for reshapes like `?x?xk` -> `?` (k>1). The rationale here is that the expand op must have a coherent index/affine expression specified in its `output_shape` argument (see example below), and if it doesn't, the IR has already been invalidated at an earlier stage: ``` %c32 = arith.constant 32 : index %div = arith.divsi %, %c32 : index %collapsed = tensor.collapse_shape %41#1 [[0], [1, 2], [3, 4]] : tensor<9x?x32x?x32xf32> into tensor<9x?x?xf32> %affine = affine.apply affine_map<()[s0] -> (s0 * 32)> ()[%div] %expanded = tensor.expand_shape %collapsed [[0], [1, 2], [3]] output_shape [9, %div, 32, %affine] : tensor<9x?x?xf32> into tensor<9x?x32x?xf32> ``` On the above assumption, adjust the routine in `getReassociationIndicesForCollapse()` to allow dynamic reshapes beyond just `?x..?x1x1x..x1` -> `?`. Moreover, the reassociation util was refactored to clearly distinguish between dynamic and static subshapes. A few known caveats were noted as a comment; it doesn't seem possible to fold all qualifying dynamic shape patterns in a deterministic way without looking into affine expressions simultaneously. That would be difficult to maintain in a single general utility. Other implementation ideas/larger refactoring could include: - abandoning the util usage in the `ComposeExpandOfCollapseOp` pattern, employing similar logic to `ComposeCollapseOfExpandOp`; - providing dialect-specific implementations for Linalg/Tensor. Signed-off-by: Artem Gindinson --- mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 103 ++++++++++-------- .../Dialect/Linalg/simplify-pack-unpack.mlir | 4 +- mlir/test/Dialect/Tensor/canonicalize.mlir | 24 +++- 3 files changed, 79 insertions(+), 52 deletions(-) diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index ed40a080441bc..694783849198a 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -31,59 +31,70 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType, std::optional> mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, ArrayRef targetShape) { - if (sourceShape.size() <= targetShape.size()) + unsigned numSourceDims = sourceShape.size(), + numTargetDims = targetShape.size(); + if (numSourceDims <= numTargetDims) return std::nullopt; - unsigned sourceDim = 0; - SmallVector reassociationMap; - reassociationMap.reserve(targetShape.size()); - - ReassociationIndices currIndices; - int64_t prodOfCollapsedDims = 1; - while (sourceDim < sourceShape.size()) { - unsigned targetDim = reassociationMap.size(); - // If we have mapped all the target dimensions stop and handle the remaining - // tail of size-1 dimensions explicitly. - if (targetDim == targetShape.size()) - break; + SmallVector reassociationMap; + reassociationMap.reserve(numTargetDims); + unsigned sourceDim = 0, targetDim = 0; + for (; targetDim < numTargetDims; ++targetDim) { int64_t currTargetShape = targetShape[targetDim]; - while (sourceDim < (sourceShape.size() - 1) && - sourceShape[sourceDim] != ShapedType::kDynamic && - prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) { + ReassociationIndices currIndices; + // 1. Target dimension is dynamic. Source shape should contain at least + // one dynamic dimension. + if (currTargetShape == ShapedType::kDynamic) { + // FIXME: We stop the search with the first dynamic dimension, while in + // fact, we can have a valid pattern like 2x?x?x4x8 -> ?x4x8. It becomes + // indeterministic altogether when we have neighboring dynamic dimensions + // in the target shape. Most of these patterns will be safely rejected, + // however we might achieve more correct folds by taking affine + // expressions into account, if these can be passed on by the call sites. + bool foundDynamic = false; + while (sourceDim < numSourceDims) { + currIndices.push_back(sourceDim); + if (sourceShape[sourceDim++] == ShapedType::kDynamic) { + foundDynamic = true; + break; + } + } + if (!foundDynamic) + return std::nullopt; + + reassociationMap.push_back(currIndices); + continue; + } + // 2. Target dimension is static. The product of dimensions of the expanded + // shape should match the collapsed dimension shape. + int64_t prodOfCollapsedDims = 1; + bool reachedTargetDimSize = false; + while (sourceDim < numSourceDims) { + // Source shape cannot be dynamic if the target dim is static. + if (sourceShape[sourceDim] == ShapedType::kDynamic) + return std::nullopt; prodOfCollapsedDims *= sourceShape[sourceDim]; - currIndices.push_back(sourceDim++); + if (prodOfCollapsedDims > currTargetShape) + break; + else if (prodOfCollapsedDims == currTargetShape) { + currIndices.push_back(sourceDim++); + reachedTargetDimSize = true; + break; + } else // prodOfCollapsedDims < currTargetShape + currIndices.push_back(sourceDim++); } - - // If the current expanded dimension is dynamic, then the collapsed - // dimensions should also be dynamic and product of all previous unprocessed - // dimensions of the expanded shape should be 1. - if (sourceShape[sourceDim] == ShapedType::kDynamic && - (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1)) + if (!reachedTargetDimSize) return std::nullopt; - - // If the collapsed dim is dynamic, the current expanded dim should also - // be dynamic. - if (currTargetShape == ShapedType::kDynamic && - sourceShape[sourceDim] != ShapedType::kDynamic) - return std::nullopt; - - // For static shapes, if the product of dimensions of the expanded shape - // should match the collapsed dimension shape. - if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape) - return std::nullopt; - - currIndices.push_back(sourceDim++); - reassociationMap.emplace_back(ReassociationIndices{}); - std::swap(reassociationMap.back(), currIndices); - prodOfCollapsedDims = 1; + reassociationMap.push_back(currIndices); } - // All the dimensions in the target must have been processed. - if (reassociationMap.size() != targetShape.size()) - return std::nullopt; - // Process any remaining entries in the source shape. They all need to be - // 1 or dynamic. - for (; sourceDim < sourceShape.size(); sourceDim++) { - if (sourceShape[sourceDim] != ShapedType::kDynamic && + // Now that we've mapped all the target dimensions, process any remaining + // entries in the source shape explicitly. Either the last target dimension + // is dynamic, or all remaining source entries need to be 1 or dynamic. Same + // applies when target shape is empty (can be the case for subshape + // reassociations). + for (; sourceDim < numSourceDims; sourceDim++) { + if ((targetShape.empty() || targetShape.back() != ShapedType::kDynamic) && + sourceShape[sourceDim] != ShapedType::kDynamic && sourceShape[sourceDim] != 1) return std::nullopt; // The map is empty when the target type is a scalar. diff --git a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir index 51350e5bc8498..6979770154bab 100644 --- a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir +++ b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir @@ -158,8 +158,8 @@ func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> { // ----- // CHECK-LABEL: func.func @unpack_dynamic -// CHECK-NOT: tensor.collapse -// CHECK: linalg.unpack +// CHECK: tensor.collapse +// CHECK-NOT: linalg.unpack func.func @unpack_dynamic(%arg0: tensor) -> tensor { %c32 = arith.constant 32 : index %c0 = arith.constant 0 : index diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 85bf6fba52aa4..443f931745557 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1068,7 +1068,7 @@ func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf3 // ----- -func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor, %arg1: index, %arg2: index) +func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor, %arg1: index, %arg2: index) -> tensor { %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor into tensor @@ -1076,12 +1076,28 @@ func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor, %arg1: ind : tensor into tensor return %1 : tensor } -// CHECK-LABEL: @fold_expand_of_collapse_dynamic +// CHECK-LABEL: @fold_expand_of_collapse_mixed_subshape // CHECK-NOT: tensor.{{.*}}_shape // ----- -func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor, %arg1: index, %arg2: index, %arg3: index) +func.func @fold_expand_of_collapse_mixed_target_subshape(%arg0 : tensor, %arg1: index, %arg2: index) + -> tensor { + %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] + : tensor into tensor + %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2] + : tensor into tensor + return %1 : tensor +} +// CHECK-LABEL: @fold_expand_of_collapse_mixed_target_subshape +// CHECK-NOT: tensor.expand_shape +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0 {{\[}}[0], [1], [2, 3]] +// CHECK-SAME: : tensor into tensor +// CHECK-NEXT: return %[[COLLAPSE]] + +// ----- + +func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor, %arg1: index, %arg2: index, %arg3: index) -> tensor { %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor into tensor @@ -1089,7 +1105,7 @@ func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor, %arg1: : tensor into tensor return %1 : tensor } -// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic +// CHECK-LABEL: @no_fold_expand_of_collapse_fully_dynamic // CHECK: tensor.collapse_shape // CHECK: %[[EXPAND:.+]] = tensor.expand_shape // CHECK: return %[[EXPAND]] From 52ff4e0a1e81e13282975076d730e741a1da1cae Mon Sep 17 00:00:00 2001 From: Artem Gindinson Date: Fri, 9 May 2025 15:12:21 +0000 Subject: [PATCH 02/16] [fixup] Algorithm rewrite Signed-off-by: Artem Gindinson --- mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 140 ++++++++++++++------- 1 file changed, 93 insertions(+), 47 deletions(-) diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 694783849198a..1cd06a2757363 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -39,67 +39,113 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, reassociationMap.reserve(numTargetDims); unsigned sourceDim = 0, targetDim = 0; - for (; targetDim < numTargetDims; ++targetDim) { - int64_t currTargetShape = targetShape[targetDim]; - ReassociationIndices currIndices; - // 1. Target dimension is dynamic. Source shape should contain at least - // one dynamic dimension. - if (currTargetShape == ShapedType::kDynamic) { - // FIXME: We stop the search with the first dynamic dimension, while in - // fact, we can have a valid pattern like 2x?x?x4x8 -> ?x4x8. It becomes - // indeterministic altogether when we have neighboring dynamic dimensions - // in the target shape. Most of these patterns will be safely rejected, - // however we might achieve more correct folds by taking affine - // expressions into account, if these can be passed on by the call sites. - bool foundDynamic = false; - while (sourceDim < numSourceDims) { - currIndices.push_back(sourceDim); - if (sourceShape[sourceDim++] == ShapedType::kDynamic) { - foundDynamic = true; - break; - } - } - if (!foundDynamic) - return std::nullopt; - - reassociationMap.push_back(currIndices); - continue; - } - // 2. Target dimension is static. The product of dimensions of the expanded - // shape should match the collapsed dimension shape. + // Source dimensions iteration logic for static target dimensions. + // FIXME: Instead of lambda-capturing this function's source shape index "in + // place", consider refactoring this into a separate function. + auto collectSourceIndicesForStaticTargetDim = + [&](int64_t targetShape, + bool mayHaveOffset = false) -> FailureOr { + ReassociationIndices resultIndices; int64_t prodOfCollapsedDims = 1; bool reachedTargetDimSize = false; - while (sourceDim < numSourceDims) { + for (; sourceDim < numSourceDims; ++sourceDim) { // Source shape cannot be dynamic if the target dim is static. if (sourceShape[sourceDim] == ShapedType::kDynamic) - return std::nullopt; + return failure(); prodOfCollapsedDims *= sourceShape[sourceDim]; - if (prodOfCollapsedDims > currTargetShape) - break; - else if (prodOfCollapsedDims == currTargetShape) { - currIndices.push_back(sourceDim++); + resultIndices.push_back(sourceDim); + if (prodOfCollapsedDims > targetShape && !mayHaveOffset) + return failure(); + while (prodOfCollapsedDims > targetShape) { + assert(!resultIndices.empty()); + auto frontOffsetIdx = resultIndices.begin(); + prodOfCollapsedDims /= sourceShape[*frontOffsetIdx]; + resultIndices.erase(frontOffsetIdx); + } + if (prodOfCollapsedDims == targetShape) { reachedTargetDimSize = true; + ++sourceDim; break; - } else // prodOfCollapsedDims < currTargetShape - currIndices.push_back(sourceDim++); + } } if (!reachedTargetDimSize) + return failure(); + return resultIndices; + }; + // Source dimensions iteration logic for dynamic target dimensions. + // FIXME: Instead of lambda-capturing this function's source shape index "in + // place", consider refactoring this into a separate function. + auto collectSourceIndicesForDynamicTargetDim = + [&](bool allowStaticNonOnes, + bool mapConsecutiveDynDims) -> FailureOr { + ReassociationIndices resultIndices; + bool foundFirstDynamic = false; + while (sourceDim < numSourceDims) { + if (sourceShape[sourceDim] == ShapedType::kDynamic) { + if (foundFirstDynamic && !mapConsecutiveDynDims) + break; + foundFirstDynamic |= true; + } else { + if (foundFirstDynamic) + break; + else if (sourceShape[sourceDim] > 1 && !allowStaticNonOnes) + return failure(); + } + resultIndices.push_back(sourceDim++); + } + if (!foundFirstDynamic) + return failure(); + return resultIndices; + }; + // Iterate over target shape. + bool wasLastDimDynamic = false; + for (; targetDim < numTargetDims; ++targetDim) { + int64_t currTargetShape = targetShape[targetDim]; + if (currTargetShape != ShapedType::kDynamic) { + unsigned sourceDimAtStart = sourceDim; + auto indices = collectSourceIndicesForStaticTargetDim( + currTargetShape, /*mayHaveOffset=*/wasLastDimDynamic); + if (failed(indices)) + return std::nullopt; + if (wasLastDimDynamic) { + assert(!reassociationMap.empty()); + auto &previousIndices = reassociationMap.back(); + for (; sourceDimAtStart < indices->front(); ++sourceDimAtStart) + previousIndices.push_back(sourceDimAtStart); + } + reassociationMap.push_back(*indices); + wasLastDimDynamic = false; + continue; + } + + bool isNextDimDynamic = targetDim + 1 < numTargetDims && + targetShape[targetDim + 1] == ShapedType::kDynamic; + auto indices = collectSourceIndicesForDynamicTargetDim( + /*allowStaticNonOnes=*/!wasLastDimDynamic, + /*mapConsecutiveDynDims=*/!wasLastDimDynamic && !isNextDimDynamic); + if (failed(indices)) return std::nullopt; - reassociationMap.push_back(currIndices); + reassociationMap.push_back(*indices); + wasLastDimDynamic = true; } // Now that we've mapped all the target dimensions, process any remaining - // entries in the source shape explicitly. Either the last target dimension - // is dynamic, or all remaining source entries need to be 1 or dynamic. Same - // applies when target shape is empty (can be the case for subshape - // reassociations). + // entries in the source shape explicitly. for (; sourceDim < numSourceDims; sourceDim++) { - if ((targetShape.empty() || targetShape.back() != ShapedType::kDynamic) && - sourceShape[sourceDim] != ShapedType::kDynamic && - sourceShape[sourceDim] != 1) + const bool isOne = sourceShape[sourceDim] == 1, + isDynamic = sourceShape[sourceDim] == ShapedType::kDynamic; + if (targetShape.empty()) { + if (!isOne && !isDynamic) + return std::nullopt; + continue; + } + if (wasLastDimDynamic && isDynamic) + return std::nullopt; + // If the last target dimension is static, only source dimensions of 1 are + // acceptable. + if (!wasLastDimDynamic && !isOne) return std::nullopt; - // The map is empty when the target type is a scalar. - if (!reassociationMap.empty()) - reassociationMap.back().push_back(sourceDim); + assert(!reassociationMap.empty()); + reassociationMap.back().push_back(sourceDim); } return reassociationMap; } From 1c85a6875fecf52e6714b81bbe9bd2da81178c9e Mon Sep 17 00:00:00 2001 From: Artem Gindinson Date: Fri, 9 May 2025 15:12:34 +0000 Subject: [PATCH 03/16] [fixup] Add/expand unit tests Signed-off-by: Artem Gindinson Co-authored-by: Ian Wood --- mlir/test/Dialect/Tensor/canonicalize.mlir | 15 +++ mlir/unittests/Dialect/Utils/CMakeLists.txt | 1 + .../Dialect/Utils/ReshapeOpsUtilsTest.cpp | 125 ++++++++++++++++++ 3 files changed, 141 insertions(+) create mode 100644 mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 443f931745557..035ea850c9102 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1112,6 +1112,21 @@ func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor, % // ----- +func.func @no_fold_expand_of_collapse_adjacent_dynamic(%arg0 : tensor, %arg1: index, %arg2: index) + -> tensor { + %0 = tensor.collapse_shape %arg0 [[0, 1, 2]] + : tensor into tensor + %1 = tensor.expand_shape %0 [[0, 1]] output_shape [%arg1, %arg2] + : tensor into tensor + return %1 : tensor +} +// CHECK-LABEL: @no_fold_expand_of_collapse_adjacent_dynamic +// CHECK: tensor.collapse_shape +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape +// CHECK: return %[[EXPAND]] + +// ----- + func.func @compose_expand_of_collapse_last_two_dims(%arg0: tensor) -> tensor { %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor into tensor %c0 = arith.constant 0 : index diff --git a/mlir/unittests/Dialect/Utils/CMakeLists.txt b/mlir/unittests/Dialect/Utils/CMakeLists.txt index 61b9cdcb3b8f3..e921c8bcfb4e5 100644 --- a/mlir/unittests/Dialect/Utils/CMakeLists.txt +++ b/mlir/unittests/Dialect/Utils/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_unittest(MLIRDialectUtilsTests StructuredOpsUtilsTest.cpp + ReshapeOpsUtilsTest.cpp IndexingUtilsTest.cpp ) mlir_target_link_libraries(MLIRDialectUtilsTests diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp new file mode 100644 index 0000000000000..bfcc70150e2ed --- /dev/null +++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp @@ -0,0 +1,125 @@ +//===- ReshapeOpsUtilsTest.cpp - ReshapeOpsUtils unit tests ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "llvm/ADT/STLExtras.h" +#include "gtest/gtest.h" +#include + +using namespace mlir; + +/// Helper to make constructing +/// `std::optional>` more readable. +static std::optional> +makeOptionalIndices(std::initializer_list list) { + return std::optional>(list); +} + +TEST(ReassociationIndicesForCollapse, StaticTest) { + EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {200}), + makeOptionalIndices({{0, 1}})); + EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {10, 600}), + makeOptionalIndices({{0}, {1, 2}})); + EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {200, 30}), + makeOptionalIndices({{0, 1}, {2}})); +} + +TEST(ReassociationIndicesForCollapse, StaticTestFailure) { + EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10}), std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10, 20}), + std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {200, 300}), + std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {1, 10, 20, 30}), + std::nullopt); +} + +TEST(ReassociationIndicesForCollapse, StaticTestUnitDims) { + EXPECT_EQ(getReassociationIndicesForCollapse({10, 1}, {10}), + makeOptionalIndices({{0, 1}})); + EXPECT_EQ(getReassociationIndicesForCollapse({1, 20, 30}, {600}), + makeOptionalIndices({{0, 1, 2}})); + EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1}, {1}), + makeOptionalIndices({{0, 1, 2}})); + EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1}, {1, 1}), + makeOptionalIndices({{0}, {1, 2}})); +} + +TEST(ReassociationIndicesForCollapse, DynamicTest) { + EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1}, + {ShapedType::kDynamic}), + makeOptionalIndices({{0, 1}})); + EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1, 1}, + {ShapedType::kDynamic}), + makeOptionalIndices({{0, 1, 2}})); + EXPECT_EQ( + getReassociationIndicesForCollapse( + {ShapedType::kDynamic, ShapedType::kDynamic}, {ShapedType::kDynamic}), + makeOptionalIndices({{0, 1}})); + EXPECT_EQ(getReassociationIndicesForCollapse( + {1, ShapedType::kDynamic, ShapedType::kDynamic}, + {1, ShapedType::kDynamic}), + makeOptionalIndices({{0}, {1, 2}})); + + EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10}, + {ShapedType::kDynamic}), + makeOptionalIndices({{0, 1}})); + EXPECT_EQ(getReassociationIndicesForCollapse( + {1, ShapedType::kDynamic, ShapedType::kDynamic}, + {ShapedType::kDynamic}), + makeOptionalIndices({{0, 1, 2}})); + EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic}, + {ShapedType::kDynamic}), + makeOptionalIndices({{0, 1}})); + EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20}, + {ShapedType::kDynamic, 20}), + makeOptionalIndices({{0, 1}, {2}})); + EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic, 20}, + {ShapedType::kDynamic, 20}), + makeOptionalIndices({{0, 1}, {2}})); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, 3, 2, 5, 2}, {ShapedType::kDynamic, 20}), + makeOptionalIndices({{0, 1}, {2, 3, 4}})); + EXPECT_EQ(getReassociationIndicesForCollapse( + {10, ShapedType::kDynamic, 20, ShapedType::kDynamic, 1}, + {ShapedType::kDynamic, 20, ShapedType::kDynamic}), + makeOptionalIndices({{0, 1}, {2}, {3, 4}})); + EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 1}, + {ShapedType::kDynamic}), + makeOptionalIndices({{0, 1, 2}})); + EXPECT_EQ(getReassociationIndicesForCollapse( + {1, ShapedType::kDynamic, ShapedType::kDynamic}, + {ShapedType::kDynamic, ShapedType::kDynamic}), + makeOptionalIndices({{0, 1}, {2}})); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, 1, ShapedType::kDynamic}, + {ShapedType::kDynamic, ShapedType::kDynamic}), + makeOptionalIndices({{0}, {1, 2}})); +} + +TEST(ReassociationIndicesForCollapse, DynamicTestFailure) { + EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20}, + {ShapedType::kDynamic, 10}), + std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, 10, ShapedType::kDynamic}, + {ShapedType::kDynamic, ShapedType::kDynamic}), + std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse( + {20, ShapedType::kDynamic, 10, ShapedType::kDynamic}, + {ShapedType::kDynamic, ShapedType::kDynamic}), + std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, 5, 3, 2, 2}, {ShapedType::kDynamic, 20}), + std::nullopt); + EXPECT_EQ( + getReassociationIndicesForCollapse( + {ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic}, + {ShapedType::kDynamic, ShapedType::kDynamic}), + std::nullopt); +} From 0fe986e217e52cc519823a999e83155f1cd2d3ef Mon Sep 17 00:00:00 2001 From: Artem Gindinson Date: Fri, 9 May 2025 15:15:08 +0000 Subject: [PATCH 04/16] [fixup] variable renaming Signed-off-by: Artem Gindinson --- mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 39 +++++++++++----------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 1cd06a2757363..a6ee21d941e17 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -38,7 +38,7 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, SmallVector reassociationMap; reassociationMap.reserve(numTargetDims); - unsigned sourceDim = 0, targetDim = 0; + unsigned sourceDimIdx = 0, targetDimIdx = 0; // Source dimensions iteration logic for static target dimensions. // FIXME: Instead of lambda-capturing this function's source shape index "in // place", consider refactoring this into a separate function. @@ -48,12 +48,12 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, ReassociationIndices resultIndices; int64_t prodOfCollapsedDims = 1; bool reachedTargetDimSize = false; - for (; sourceDim < numSourceDims; ++sourceDim) { + for (; sourceDimIdx < numSourceDims; ++sourceDimIdx) { // Source shape cannot be dynamic if the target dim is static. - if (sourceShape[sourceDim] == ShapedType::kDynamic) + if (sourceShape[sourceDimIdx] == ShapedType::kDynamic) return failure(); - prodOfCollapsedDims *= sourceShape[sourceDim]; - resultIndices.push_back(sourceDim); + prodOfCollapsedDims *= sourceShape[sourceDimIdx]; + resultIndices.push_back(sourceDimIdx); if (prodOfCollapsedDims > targetShape && !mayHaveOffset) return failure(); while (prodOfCollapsedDims > targetShape) { @@ -64,7 +64,7 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, } if (prodOfCollapsedDims == targetShape) { reachedTargetDimSize = true; - ++sourceDim; + ++sourceDimIdx; break; } } @@ -80,18 +80,18 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, bool mapConsecutiveDynDims) -> FailureOr { ReassociationIndices resultIndices; bool foundFirstDynamic = false; - while (sourceDim < numSourceDims) { - if (sourceShape[sourceDim] == ShapedType::kDynamic) { + while (sourceDimIdx < numSourceDims) { + if (sourceShape[sourceDimIdx] == ShapedType::kDynamic) { if (foundFirstDynamic && !mapConsecutiveDynDims) break; foundFirstDynamic |= true; } else { if (foundFirstDynamic) break; - else if (sourceShape[sourceDim] > 1 && !allowStaticNonOnes) + else if (sourceShape[sourceDimIdx] > 1 && !allowStaticNonOnes) return failure(); } - resultIndices.push_back(sourceDim++); + resultIndices.push_back(sourceDimIdx++); } if (!foundFirstDynamic) return failure(); @@ -99,10 +99,10 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, }; // Iterate over target shape. bool wasLastDimDynamic = false; - for (; targetDim < numTargetDims; ++targetDim) { - int64_t currTargetShape = targetShape[targetDim]; + for (; targetDimIdx < numTargetDims; ++targetDimIdx) { + int64_t currTargetShape = targetShape[targetDimIdx]; if (currTargetShape != ShapedType::kDynamic) { - unsigned sourceDimAtStart = sourceDim; + unsigned sourceDimAtStart = sourceDimIdx; auto indices = collectSourceIndicesForStaticTargetDim( currTargetShape, /*mayHaveOffset=*/wasLastDimDynamic); if (failed(indices)) @@ -118,8 +118,9 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, continue; } - bool isNextDimDynamic = targetDim + 1 < numTargetDims && - targetShape[targetDim + 1] == ShapedType::kDynamic; + bool isNextDimDynamic = + targetDimIdx + 1 < numTargetDims && + targetShape[targetDimIdx + 1] == ShapedType::kDynamic; auto indices = collectSourceIndicesForDynamicTargetDim( /*allowStaticNonOnes=*/!wasLastDimDynamic, /*mapConsecutiveDynDims=*/!wasLastDimDynamic && !isNextDimDynamic); @@ -130,9 +131,9 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, } // Now that we've mapped all the target dimensions, process any remaining // entries in the source shape explicitly. - for (; sourceDim < numSourceDims; sourceDim++) { - const bool isOne = sourceShape[sourceDim] == 1, - isDynamic = sourceShape[sourceDim] == ShapedType::kDynamic; + for (; sourceDimIdx < numSourceDims; sourceDimIdx++) { + const bool isOne = sourceShape[sourceDimIdx] == 1, + isDynamic = sourceShape[sourceDimIdx] == ShapedType::kDynamic; if (targetShape.empty()) { if (!isOne && !isDynamic) return std::nullopt; @@ -145,7 +146,7 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, if (!wasLastDimDynamic && !isOne) return std::nullopt; assert(!reassociationMap.empty()); - reassociationMap.back().push_back(sourceDim); + reassociationMap.back().push_back(sourceDimIdx); } return reassociationMap; } From e3aa2394225bb23cde429cc06bea28df7257070a Mon Sep 17 00:00:00 2001 From: Artem Gindinson Date: Fri, 9 May 2025 15:51:59 +0000 Subject: [PATCH 05/16] [fixup] Additional edge-case Signed-off-by: Artem Gindinson --- mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 10 ++++++++-- mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp | 9 +++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index a6ee21d941e17..8c19f20f446da 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -139,8 +139,14 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, return std::nullopt; continue; } - if (wasLastDimDynamic && isDynamic) - return std::nullopt; + // If the last 2 dimensions in the target were dynamic, the tail in the + // source shape cannot contain a dynamic value. E.g. ?x?->? is valid, + // however ?x?x10x?->?x? would be indeterminate. + if (wasLastDimDynamic && numTargetDims > 1 && + targetShape[numTargetDims - 2] == ShapedType::kDynamic) { + if (isDynamic) + return std::nullopt; + } // If the last target dimension is static, only source dimensions of 1 are // acceptable. if (!wasLastDimDynamic && !isOne) diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp index bfcc70150e2ed..2564866fac493 100644 --- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp +++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp @@ -92,6 +92,10 @@ TEST(ReassociationIndicesForCollapse, DynamicTest) { EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 1}, {ShapedType::kDynamic}), makeOptionalIndices({{0, 1, 2}})); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, ShapedType::kDynamic, 1}, + {ShapedType::kDynamic, ShapedType::kDynamic}), + makeOptionalIndices({{0}, {1, 2}})); EXPECT_EQ(getReassociationIndicesForCollapse( {1, ShapedType::kDynamic, ShapedType::kDynamic}, {ShapedType::kDynamic, ShapedType::kDynamic}), @@ -122,4 +126,9 @@ TEST(ReassociationIndicesForCollapse, DynamicTestFailure) { {ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic}, {ShapedType::kDynamic, ShapedType::kDynamic}), std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, ShapedType::kDynamic, 10, 1, + ShapedType::kDynamic}, + {ShapedType::kDynamic, ShapedType::kDynamic}), + std::nullopt); } From 16a932c8fa45f00e6474dd18bd8b7781a4b2fac8 Mon Sep 17 00:00:00 2001 From: Artem Gindinson Date: Tue, 20 May 2025 20:00:24 +0000 Subject: [PATCH 06/16] [WIP] Current tests pass --- mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 430 +++++++++++++----- .../Dialect/Utils/ReshapeOpsUtilsTest.cpp | 24 + 2 files changed, 337 insertions(+), 117 deletions(-) diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 8c19f20f446da..25dd434fc2122 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -10,6 +10,10 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/LogicalResult.h" #include #include @@ -28,6 +32,257 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType, return std::nullopt; } +namespace { +/// A simple struct to represent ReassociationIndices as an inclusive interval. +/// It's designed to be feasibly minimal, so the call sites should manage the +/// validity of the range manually. +struct ReassociationIndexRange { + /// FIXME: Signed type is used for consistency with ReassociationIndices. + /// We should consider refactoring all reassociation utilities to use unsigned + /// types. + int64_t leftIdx = 0, rightIdx = 0; + + /// Util for manual checks of the range's validity + LogicalResult verify() const { + return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure(); + } + + /// Checks range's containment within another range. Treats the edges + /// non-exclusively. + bool isInRange(const ReassociationIndexRange &outerRange) const { + return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx; + } + + unsigned size() const { + assert(succeeded(verify())); + return rightIdx - leftIdx + 1; + } + bool containsSingleIndex() const { return size() == 1; } + + void expandRight() { ++rightIdx; } + void shrinkLeft() { ++leftIdx; } + + /// Implements arithmetic XOR semantics to get non-overlapping indices between + /// ranges. + ReassociationIndices operator^(ReassociationIndexRange &rhs) const { + ReassociationIndices result; + result.reserve(size() + rhs.size() / 2); // Attempt to amortize + for (int64_t idx = this->leftIdx; idx <= this->rightIdx; ++idx) { + if (idx < rhs.leftIdx || idx > rhs.rightIdx) + result.push_back(idx); + } + for (int64_t rhsIndex = rhs.leftIdx; rhsIndex <= rhs.rightIdx; ++rhsIndex) { + if (rhsIndex < leftIdx || rhsIndex > rightIdx) + result.push_back(rhsIndex); + } + return result; + } + + /// Converts the range into ReassociationIndices. + ReassociationIndices getFullIndices() const { + ReassociationIndices result; + for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) { + result.push_back(idx); + } + return result; + } +}; + +/// Starting from `sourceStartIdx`, searches `sourceShape` for the first +/// sequence that can be collapsed into a dynamic dimension (at least one must +/// be present in the source). +/// By default, lazily returns once the first dynamic dimension has been found. +/// Setting `matchGreedily` as `true` will also mark all subsequent +/// source dimensions for collapsing into the target. +FailureOr +findReassociationRangeForDynamicDim(ArrayRef sourceShape, + int64_t sourceStartIdx, + bool matchGreedily = false) { + ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx}; + const unsigned numSourceDims = sourceShape.size(); + ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1}; + if (!iterationRange.isInRange(sourceShapeAsRange)) + return failure(); + auto resultRange = iterationRange; + + bool foundDynamic = false; + for (; iterationRange.isInRange(sourceShapeAsRange); + iterationRange.expandRight()) { + int64_t sourceSize = sourceShape[iterationRange.rightIdx]; + if (foundDynamic && !matchGreedily) + break; + if (sourceSize == ShapedType::kDynamic) + foundDynamic = true; + resultRange = iterationRange; + } + if (!foundDynamic) + return failure(); + return resultRange; +} + +/// Starting from `sourceStartIdx`, searches `sourceShape` for the first +/// sequence of static dimensions such that their product matches `targetSize`. +/// By default, lazily returns once the product matches the target size. Setting +/// `matchGreedily` as `true` will append all neighboring unit dimensions +/// (dimensions of 1) to the match. +FailureOr +findReassociationRangeForSize(ArrayRef sourceShape, + int64_t sourceStartIdx, int64_t targetSize, + bool matchGreedily = false) { + ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx}; + const unsigned numSourceDims = sourceShape.size(); + ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1}; + if (!iterationRange.isInRange(sourceShapeAsRange)) + return failure(); + auto resultRange = iterationRange; + + int64_t prodOfCollapsedDims = 1; + bool reachedTargetDimSize = false; + while (iterationRange.isInRange(sourceShapeAsRange)) { + int64_t sourceSize = sourceShape[iterationRange.rightIdx]; + if (reachedTargetDimSize && !matchGreedily) + break; + if (sourceSize == ShapedType::kDynamic) { + if (reachedTargetDimSize) + break; + // Reassociation for a static dim cannot include a dynamic dim. Reset + // induction variables to essentially restart the loop from the next + // source dimension. + prodOfCollapsedDims = 1; + resultRange = {iterationRange.rightIdx + 1, iterationRange.rightIdx + 1}; + iterationRange = resultRange; + continue; + } + prodOfCollapsedDims *= sourceSize; + if (prodOfCollapsedDims > targetSize && reachedTargetDimSize) + break; + // If the target size has been exceeded without matching, we need to shift + // the range start right. From the start of the range, roll back the + // multiplication until the target size exceeds the product again. + while (prodOfCollapsedDims > targetSize && + !iterationRange.containsSingleIndex()) { + int64_t frontSourceSize = sourceShape[iterationRange.leftIdx]; + prodOfCollapsedDims /= frontSourceSize; + iterationRange.shrinkLeft(); + } + resultRange = iterationRange; + // We could've reached the target size with the current dimension, + // also as a result of the above shift to right. + if (prodOfCollapsedDims == targetSize) + reachedTargetDimSize = true; + // Increment the iteration range + iterationRange.expandRight(); + } + if (!reachedTargetDimSize) + return failure(); + return resultRange; +} + +/// Attempts to find a valid collapsing reassociation of `sourceShape` into +/// `targetShape` through a simple traversal. If successful, an array of source +/// index ranges is returned, correspondingly to each dimension in the target +/// shape. The resulting indices shall fully cover the `sourceShape` without +/// overlaps. +/// +/// The algorithm is essentially a lazy one, searching for non-greedy matches - +/// it will only yield a greedy match for the last target dimension. +/// FIXME: The algorithm can only backtrack when it needs to append an offset +/// for a static target dimension to the preceding dynamic one (this retains the +/// linear complexity). As feasible, consider adding further backtracking +/// routines to enable more reassociations, e.g.: +/// - ?x2x?x2 into ?x2 +FailureOr> +findReassociationRangesForCollapse(ArrayRef sourceShape, + ArrayRef targetShape) { + unsigned numSourceDims = sourceShape.size(), + numTargetDims = targetShape.size(); + assert(numSourceDims > numTargetDims); + ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1}; + + SmallVector reassocRanges; + reassocRanges.reserve(numTargetDims); + // We'll iterate in strides of 2 to enable pseudo-backtracking for simple + // cases, e.g.: + // - ?x2x3x5 into ?x15 + std::optional prevTargetSize = std::nullopt; + for (unsigned targetDimIdx = 0, sourceDimIdx = 0; + targetDimIdx < numTargetDims; ++targetDimIdx) { + int64_t targetSize = targetShape[targetDimIdx]; + std::optional nextTargetSize = std::nullopt; + + // Simply check if there are any subsequent target dimensions left - if not, + // the match must be made greedily. + bool isLastTargetDim = targetDimIdx == numTargetDims - 1; + bool shouldMatchGreedily = isLastTargetDim; + FailureOr sourceRange; + if (targetSize == ShapedType::kDynamic) { + sourceRange = findReassociationRangeForDynamicDim( + sourceShape, sourceDimIdx, shouldMatchGreedily); + } else { + sourceRange = findReassociationRangeForSize( + sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily); + } + + // Run sanity checks on the returned index range. + if (failed(sourceRange) || failed(sourceRange->verify()) || + !sourceRange->isInRange(sourceShapeAsRange)) + return failure(); + if (sourceRange->leftIdx > sourceDimIdx) { + // If some source dimensions had to be skipped in order to find a match, + // they must be collapsed into the directly preceding dynamic dimension. + if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic) + return failure(); + reassocRanges.back().rightIdx = sourceRange->leftIdx - 1; + } + + // Store the gathered information as required for the next iteration. + prevTargetSize = targetSize; + sourceDimIdx = sourceRange->rightIdx + 1; + reassocRanges.emplace_back(std::move(*sourceRange)); + } + // Fail if the source shape wasn't a full match for the target shape. We only + // need to check the last recorded index - any other gaps should have been + // mended by the main loop. + if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx) + return failure(); + return reassocRanges; +} + +/// A variant of `findReassociationRangesForCollapse(...)` that can also scan +/// the shapes right-to-left. +FailureOr> +findReassociationRangesForCollapse(ArrayRef sourceShape, + ArrayRef targetShape, + bool iterateRightToLeft) { + if (!iterateRightToLeft) + return findReassociationRangesForCollapse(sourceShape, targetShape); + // FIXME: It would be preferable to avoid the expensive copies. At the moment, + // this approach is chosen for readability of the main implementation. + auto sourceToReverse = sourceShape.vec(), targetToReverse = targetShape.vec(); + std::reverse(sourceToReverse.begin(), sourceToReverse.end()); + std::reverse(targetToReverse.begin(), targetToReverse.end()); + auto invertedRanges = + findReassociationRangesForCollapse(sourceToReverse, targetToReverse); + if (failed(invertedRanges)) + return failure(); + auto rangesToInvert = *invertedRanges; + unsigned numSourceDims = sourceShape.size(); + // We have received the ranges for inverted shapes. Now we have to invert + // the ranges back to correspond with the original source shape. + for (auto &range : rangesToInvert) { + if (failed(range.verify())) + return failure(); + int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx; + range.leftIdx = numSourceDims - 1 - invRightIdx; + range.rightIdx = numSourceDims - 1 - invLeftIdx; + } + // Also invert the ordering of the ranges to correspond with the original + // target shape. + std::reverse(rangesToInvert.begin(), rangesToInvert.end()); + return rangesToInvert; +} +} // namespace + std::optional> mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, ArrayRef targetShape) { @@ -35,124 +290,65 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, numTargetDims = targetShape.size(); if (numSourceDims <= numTargetDims) return std::nullopt; - SmallVector reassociationMap; - reassociationMap.reserve(numTargetDims); - - unsigned sourceDimIdx = 0, targetDimIdx = 0; - // Source dimensions iteration logic for static target dimensions. - // FIXME: Instead of lambda-capturing this function's source shape index "in - // place", consider refactoring this into a separate function. - auto collectSourceIndicesForStaticTargetDim = - [&](int64_t targetShape, - bool mayHaveOffset = false) -> FailureOr { - ReassociationIndices resultIndices; - int64_t prodOfCollapsedDims = 1; - bool reachedTargetDimSize = false; - for (; sourceDimIdx < numSourceDims; ++sourceDimIdx) { - // Source shape cannot be dynamic if the target dim is static. - if (sourceShape[sourceDimIdx] == ShapedType::kDynamic) - return failure(); - prodOfCollapsedDims *= sourceShape[sourceDimIdx]; - resultIndices.push_back(sourceDimIdx); - if (prodOfCollapsedDims > targetShape && !mayHaveOffset) - return failure(); - while (prodOfCollapsedDims > targetShape) { - assert(!resultIndices.empty()); - auto frontOffsetIdx = resultIndices.begin(); - prodOfCollapsedDims /= sourceShape[*frontOffsetIdx]; - resultIndices.erase(frontOffsetIdx); - } - if (prodOfCollapsedDims == targetShape) { - reachedTargetDimSize = true; - ++sourceDimIdx; - break; - } - } - if (!reachedTargetDimSize) - return failure(); - return resultIndices; - }; - // Source dimensions iteration logic for dynamic target dimensions. - // FIXME: Instead of lambda-capturing this function's source shape index "in - // place", consider refactoring this into a separate function. - auto collectSourceIndicesForDynamicTargetDim = - [&](bool allowStaticNonOnes, - bool mapConsecutiveDynDims) -> FailureOr { - ReassociationIndices resultIndices; - bool foundFirstDynamic = false; - while (sourceDimIdx < numSourceDims) { - if (sourceShape[sourceDimIdx] == ShapedType::kDynamic) { - if (foundFirstDynamic && !mapConsecutiveDynDims) - break; - foundFirstDynamic |= true; - } else { - if (foundFirstDynamic) - break; - else if (sourceShape[sourceDimIdx] > 1 && !allowStaticNonOnes) - return failure(); - } - resultIndices.push_back(sourceDimIdx++); - } - if (!foundFirstDynamic) - return failure(); - return resultIndices; - }; - // Iterate over target shape. - bool wasLastDimDynamic = false; - for (; targetDimIdx < numTargetDims; ++targetDimIdx) { - int64_t currTargetShape = targetShape[targetDimIdx]; - if (currTargetShape != ShapedType::kDynamic) { - unsigned sourceDimAtStart = sourceDimIdx; - auto indices = collectSourceIndicesForStaticTargetDim( - currTargetShape, /*mayHaveOffset=*/wasLastDimDynamic); - if (failed(indices)) + // Early handling for scalar target types. + if (numTargetDims == 0) { + ReassociationIndices allSourceIndices(numSourceDims); + for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims; + ++sourceDimIdx) { + int64_t sourceSize = sourceShape[sourceDimIdx]; + // All source dimensions must be unit or dynamic. + if (sourceSize != 1 && sourceSize != ShapedType::kDynamic) return std::nullopt; - if (wasLastDimDynamic) { - assert(!reassociationMap.empty()); - auto &previousIndices = reassociationMap.back(); - for (; sourceDimAtStart < indices->front(); ++sourceDimAtStart) - previousIndices.push_back(sourceDimAtStart); - } - reassociationMap.push_back(*indices); - wasLastDimDynamic = false; - continue; + allSourceIndices.emplace_back(sourceDimIdx); } + return SmallVector{allSourceIndices}; + } - bool isNextDimDynamic = - targetDimIdx + 1 < numTargetDims && - targetShape[targetDimIdx + 1] == ShapedType::kDynamic; - auto indices = collectSourceIndicesForDynamicTargetDim( - /*allowStaticNonOnes=*/!wasLastDimDynamic, - /*mapConsecutiveDynDims=*/!wasLastDimDynamic && !isNextDimDynamic); - if (failed(indices)) + // Collect source ranges by iterating over the target shape left-to-right. + auto maybeForwardRanges = + findReassociationRangesForCollapse(sourceShape, targetShape); + if (failed(maybeForwardRanges)) + return std::nullopt; + auto &ranges = *maybeForwardRanges; + // Now do the same in reverse. We need to get another valid reassociation + // through some other strategy, and then compare the results in order to + // disambiguate mixed subshapes, such as: + // ?x?x? into ?x?, ?x2x? into ?x?, ?x2x3x6x? into ?x6x? + // This leads us to lose some of the reassociation opportunities that can only + // be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without + // backtracking, the algorithm will fail right-to-left. However, this is the + // best way to preserve correctness. + // + // NB: The reversed shapes must not be temporary as we're passing through an + // ArrayRef. + auto maybeReverseRanges = findReassociationRangesForCollapse( + sourceShape, targetShape, /*iterateRightToLeft=*/true); + if (failed(maybeReverseRanges)) + return std::nullopt; + auto &reverseRanges = *maybeReverseRanges; + + if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims) + return std::nullopt; + // Now we can check for ambiguity of each target dimension's reassociation. If + // successful, we put the full indices into our result map for the target + // shape. + SmallVector reassociationMap(numTargetDims); + for (unsigned targetDimIdx = 0; targetDimIdx < numTargetDims; + ++targetDimIdx) { + auto &range = ranges[targetDimIdx]; + auto &reverseRange = reverseRanges[targetDimIdx]; + // Get non-overlapping indices between the ranges + ReassociationIndices nonMatchingIndices = range ^ reverseRange; + // The ranges should overlap, at the very least + if (nonMatchingIndices.size() == range.size() + reverseRange.size()) return std::nullopt; - reassociationMap.push_back(*indices); - wasLastDimDynamic = true; - } - // Now that we've mapped all the target dimensions, process any remaining - // entries in the source shape explicitly. - for (; sourceDimIdx < numSourceDims; sourceDimIdx++) { - const bool isOne = sourceShape[sourceDimIdx] == 1, - isDynamic = sourceShape[sourceDimIdx] == ShapedType::kDynamic; - if (targetShape.empty()) { - if (!isOne && !isDynamic) - return std::nullopt; - continue; - } - // If the last 2 dimensions in the target were dynamic, the tail in the - // source shape cannot contain a dynamic value. E.g. ?x?->? is valid, - // however ?x?x10x?->?x? would be indeterminate. - if (wasLastDimDynamic && numTargetDims > 1 && - targetShape[numTargetDims - 2] == ShapedType::kDynamic) { - if (isDynamic) + // Unit dimensions can be collapsed wherever - this is the only ambiguity + // that we allow. + for (int64_t sourceDimIdx : nonMatchingIndices) { + if (sourceShape[sourceDimIdx] != 1) return std::nullopt; } - // If the last target dimension is static, only source dimensions of 1 are - // acceptable. - if (!wasLastDimDynamic && !isOne) - return std::nullopt; - assert(!reassociationMap.empty()); - reassociationMap.back().push_back(sourceDimIdx); + reassociationMap[targetDimIdx] = range.getFullIndices(); } return reassociationMap; } @@ -379,11 +575,11 @@ SmallVector SliceFromCollapseHelper::getExtractSliceParams( // have proven that these are not sliced. In this case we just take // the full extent of each dimension in the reassociation list. if (linearizedDimensions[it.index()]) { - llvm::append_range( - offsetsSizesAndStrides, - llvm::map_range(it.value(), [&](int64_t idx) -> Range { - return {zeroAttr, collapseShapeInputShape[idx], oneAttr}; - })); + llvm::append_range(offsetsSizesAndStrides, + llvm::map_range(it.value(), [&](int64_t idx) -> Range { + return {zeroAttr, collapseShapeInputShape[idx], + oneAttr}; + })); continue; } diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp index 2564866fac493..a179d91129edb 100644 --- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp +++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp @@ -57,6 +57,10 @@ TEST(ReassociationIndicesForCollapse, DynamicTest) { EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1, 1}, {ShapedType::kDynamic}), makeOptionalIndices({{0, 1, 2}})); + EXPECT_EQ(getReassociationIndicesForCollapse( + {1, ShapedType::kDynamic, 1, ShapedType::kDynamic, 1}, + {ShapedType::kDynamic, ShapedType::kDynamic}), + makeOptionalIndices({{0, 1}, {2, 3, 4}})); EXPECT_EQ( getReassociationIndicesForCollapse( {ShapedType::kDynamic, ShapedType::kDynamic}, {ShapedType::kDynamic}), @@ -76,6 +80,10 @@ TEST(ReassociationIndicesForCollapse, DynamicTest) { EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic}, {ShapedType::kDynamic}), makeOptionalIndices({{0, 1}})); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, 1, 2, ShapedType::kDynamic, 10}, + {ShapedType::kDynamic, 10}), + makeOptionalIndices({{0, 1, 2, 3}, {4}})); EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20}, {ShapedType::kDynamic, 20}), makeOptionalIndices({{0, 1}, {2}})); @@ -131,4 +139,20 @@ TEST(ReassociationIndicesForCollapse, DynamicTestFailure) { ShapedType::kDynamic}, {ShapedType::kDynamic, ShapedType::kDynamic}), std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, 10, 10, 10, ShapedType::kDynamic}, + {ShapedType::kDynamic, 10, ShapedType::kDynamic}), + std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, 10, 10, 10, ShapedType::kDynamic}, + {ShapedType::kDynamic, 2, 2, ShapedType::kDynamic}), + std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, 3, 4, 3, ShapedType::kDynamic}, + {ShapedType::kDynamic, 12, ShapedType::kDynamic}), + std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, 8, 4, 2, 16, ShapedType::kDynamic}, + {ShapedType::kDynamic, 32, ShapedType::kDynamic}), + std::nullopt); } From dd36c47d7a6bb402497a3c4c1757f47928132e06 Mon Sep 17 00:00:00 2001 From: Artem Gindinson Date: Tue, 20 May 2025 22:26:12 +0000 Subject: [PATCH 07/16] [WIP] New tests --- .../Dialect/Utils/ReshapeOpsUtilsTest.cpp | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp index a179d91129edb..124c8ce86fc9c 100644 --- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp +++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp @@ -155,4 +155,22 @@ TEST(ReassociationIndicesForCollapse, DynamicTestFailure) { {ShapedType::kDynamic, 8, 4, 2, 16, ShapedType::kDynamic}, {ShapedType::kDynamic, 32, ShapedType::kDynamic}), std::nullopt); + + //===----------------------------------------------------------------------===// + // TODO: Reassociation for the following examples can be computed, but isn't + // supported by `getReassociationIndicesForCollapse`. + //===----------------------------------------------------------------------===// + + // TODO: Fails because there's no backtracking when some source dimensions + // remain unmatched at either edge. + EXPECT_EQ(getReassociationIndicesForCollapse( + {ShapedType::kDynamic, 10, ShapedType::kDynamic, 10}, + {ShapedType::kDynamic, 10}), + std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 2, 2}, + {1, ShapedType::kDynamic, 2}), + std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse({2, 2, ShapedType::kDynamic, 1}, + {2, ShapedType::kDynamic}), + std::nullopt); } From 07ed33d4363d64ed32f85fe0b296ca39cc916124 Mon Sep 17 00:00:00 2001 From: Artem Gindinson Date: Tue, 20 May 2025 22:45:24 +0000 Subject: [PATCH 08/16] [fixup] Add scalar target tests & fix em Signed-off-by: Artem Gindinson --- mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 3 ++- .../Dialect/Utils/ReshapeOpsUtilsTest.cpp | 24 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 25dd434fc2122..209577db3272f 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -292,7 +292,8 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, return std::nullopt; // Early handling for scalar target types. if (numTargetDims == 0) { - ReassociationIndices allSourceIndices(numSourceDims); + ReassociationIndices allSourceIndices; + allSourceIndices.reserve(numSourceDims); for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims; ++sourceDimIdx) { int64_t sourceSize = sourceShape[sourceDimIdx]; diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp index 124c8ce86fc9c..7abdf75c34cda 100644 --- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp +++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "llvm/ADT/STLExtras.h" #include "gtest/gtest.h" #include @@ -20,6 +21,29 @@ makeOptionalIndices(std::initializer_list list) { return std::optional>(list); } +TEST(ReassociationIndicesForCollapse, ScalarTest) { + EXPECT_EQ(getReassociationIndicesForCollapse({1}, {}), + makeOptionalIndices({{0}})); + EXPECT_EQ(getReassociationIndicesForCollapse({1, 1}, {}), + makeOptionalIndices({{0, 1}})); + EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic}, {}), + makeOptionalIndices({{0}})); + EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, + ShapedType::kDynamic, 1, + ShapedType::kDynamic}, + {}), + makeOptionalIndices({{0, 1, 2, 3, 4}})); +} + +TEST(ReassociationIndicesForCollapse, ScalarTestFailure) { + EXPECT_EQ(getReassociationIndicesForCollapse({}, {}), std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse({}, {1}), std::nullopt); + EXPECT_EQ(getReassociationIndicesForCollapse({2}, {}), std::nullopt); + EXPECT_EQ( + getReassociationIndicesForCollapse({1, 2, ShapedType::kDynamic, 1}, {}), + std::nullopt); +} + TEST(ReassociationIndicesForCollapse, StaticTest) { EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {200}), makeOptionalIndices({{0, 1}})); From 6e61a527e0f10aca503324f1ab39e1a787134f0d Mon Sep 17 00:00:00 2001 From: Artem Gindinson Date: Wed, 21 May 2025 05:21:09 +0000 Subject: [PATCH 09/16] [fixup] for self-induced unit dims problem Signed-off-by: Artem Gindinson --- mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 5 ----- mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp | 4 ++-- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 209577db3272f..5e40570f5e341 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -208,8 +208,6 @@ findReassociationRangesForCollapse(ArrayRef sourceShape, for (unsigned targetDimIdx = 0, sourceDimIdx = 0; targetDimIdx < numTargetDims; ++targetDimIdx) { int64_t targetSize = targetShape[targetDimIdx]; - std::optional nextTargetSize = std::nullopt; - // Simply check if there are any subsequent target dimensions left - if not, // the match must be made greedily. bool isLastTargetDim = targetDimIdx == numTargetDims - 1; @@ -340,9 +338,6 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, auto &reverseRange = reverseRanges[targetDimIdx]; // Get non-overlapping indices between the ranges ReassociationIndices nonMatchingIndices = range ^ reverseRange; - // The ranges should overlap, at the very least - if (nonMatchingIndices.size() == range.size() + reverseRange.size()) - return std::nullopt; // Unit dimensions can be collapsed wherever - this is the only ambiguity // that we allow. for (int64_t sourceDimIdx : nonMatchingIndices) { diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp index 7abdf75c34cda..83d720bb88f0f 100644 --- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp +++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp @@ -70,8 +70,8 @@ TEST(ReassociationIndicesForCollapse, StaticTestUnitDims) { makeOptionalIndices({{0, 1, 2}})); EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1}, {1}), makeOptionalIndices({{0, 1, 2}})); - EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1}, {1, 1}), - makeOptionalIndices({{0}, {1, 2}})); + EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1, 1}, {1, 1, 1}), + makeOptionalIndices({{0}, {1}, {2, 3}})); } TEST(ReassociationIndicesForCollapse, DynamicTest) { From a6a18d6e7738e93f05b570f75e77278d0d344fac Mon Sep 17 00:00:00 2001 From: Artem Gindinson Date: Fri, 23 May 2025 16:25:39 +0000 Subject: [PATCH 10/16] [fixup] apply non-functional comments Signed-off-by: Artem Gindinson --- mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 39 +++++++++------------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 5e40570f5e341..2ed51de13e2a3 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -59,12 +59,9 @@ struct ReassociationIndexRange { } bool containsSingleIndex() const { return size() == 1; } - void expandRight() { ++rightIdx; } - void shrinkLeft() { ++leftIdx; } - - /// Implements arithmetic XOR semantics to get non-overlapping indices between - /// ranges. - ReassociationIndices operator^(ReassociationIndexRange &rhs) const { + /// Collects indices that do not overlap between this and another range. + ReassociationIndices + getNonOverlappingIndicesWith(ReassociationIndexRange &rhs) const { ReassociationIndices result; result.reserve(size() + rhs.size() / 2); // Attempt to amortize for (int64_t idx = this->leftIdx; idx <= this->rightIdx; ++idx) { @@ -87,6 +84,7 @@ struct ReassociationIndexRange { return result; } }; +} // namespace /// Starting from `sourceStartIdx`, searches `sourceShape` for the first /// sequence that can be collapsed into a dynamic dimension (at least one must @@ -94,20 +92,18 @@ struct ReassociationIndexRange { /// By default, lazily returns once the first dynamic dimension has been found. /// Setting `matchGreedily` as `true` will also mark all subsequent /// source dimensions for collapsing into the target. -FailureOr +static FailureOr findReassociationRangeForDynamicDim(ArrayRef sourceShape, int64_t sourceStartIdx, bool matchGreedily = false) { ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx}; const unsigned numSourceDims = sourceShape.size(); ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1}; - if (!iterationRange.isInRange(sourceShapeAsRange)) - return failure(); auto resultRange = iterationRange; bool foundDynamic = false; for (; iterationRange.isInRange(sourceShapeAsRange); - iterationRange.expandRight()) { + iterationRange.rightIdx++) { int64_t sourceSize = sourceShape[iterationRange.rightIdx]; if (foundDynamic && !matchGreedily) break; @@ -125,15 +121,13 @@ findReassociationRangeForDynamicDim(ArrayRef sourceShape, /// By default, lazily returns once the product matches the target size. Setting /// `matchGreedily` as `true` will append all neighboring unit dimensions /// (dimensions of 1) to the match. -FailureOr +static FailureOr findReassociationRangeForSize(ArrayRef sourceShape, int64_t sourceStartIdx, int64_t targetSize, bool matchGreedily = false) { ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx}; const unsigned numSourceDims = sourceShape.size(); ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1}; - if (!iterationRange.isInRange(sourceShapeAsRange)) - return failure(); auto resultRange = iterationRange; int64_t prodOfCollapsedDims = 1; @@ -163,7 +157,8 @@ findReassociationRangeForSize(ArrayRef sourceShape, !iterationRange.containsSingleIndex()) { int64_t frontSourceSize = sourceShape[iterationRange.leftIdx]; prodOfCollapsedDims /= frontSourceSize; - iterationRange.shrinkLeft(); + // Shrink the range rightwards + iterationRange.leftIdx++; } resultRange = iterationRange; // We could've reached the target size with the current dimension, @@ -171,7 +166,7 @@ findReassociationRangeForSize(ArrayRef sourceShape, if (prodOfCollapsedDims == targetSize) reachedTargetDimSize = true; // Increment the iteration range - iterationRange.expandRight(); + iterationRange.rightIdx++; } if (!reachedTargetDimSize) return failure(); @@ -191,7 +186,7 @@ findReassociationRangeForSize(ArrayRef sourceShape, /// linear complexity). As feasible, consider adding further backtracking /// routines to enable more reassociations, e.g.: /// - ?x2x?x2 into ?x2 -FailureOr> +static FailureOr> findReassociationRangesForCollapse(ArrayRef sourceShape, ArrayRef targetShape) { unsigned numSourceDims = sourceShape.size(), @@ -236,7 +231,7 @@ findReassociationRangesForCollapse(ArrayRef sourceShape, // Store the gathered information as required for the next iteration. prevTargetSize = targetSize; sourceDimIdx = sourceRange->rightIdx + 1; - reassocRanges.emplace_back(std::move(*sourceRange)); + reassocRanges.push_back(*sourceRange); } // Fail if the source shape wasn't a full match for the target shape. We only // need to check the last recorded index - any other gaps should have been @@ -248,7 +243,7 @@ findReassociationRangesForCollapse(ArrayRef sourceShape, /// A variant of `findReassociationRangesForCollapse(...)` that can also scan /// the shapes right-to-left. -FailureOr> +static FailureOr> findReassociationRangesForCollapse(ArrayRef sourceShape, ArrayRef targetShape, bool iterateRightToLeft) { @@ -268,8 +263,6 @@ findReassociationRangesForCollapse(ArrayRef sourceShape, // We have received the ranges for inverted shapes. Now we have to invert // the ranges back to correspond with the original source shape. for (auto &range : rangesToInvert) { - if (failed(range.verify())) - return failure(); int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx; range.leftIdx = numSourceDims - 1 - invRightIdx; range.rightIdx = numSourceDims - 1 - invLeftIdx; @@ -279,7 +272,6 @@ findReassociationRangesForCollapse(ArrayRef sourceShape, std::reverse(rangesToInvert.begin(), rangesToInvert.end()); return rangesToInvert; } -} // namespace std::optional> mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, @@ -298,7 +290,7 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, // All source dimensions must be unit or dynamic. if (sourceSize != 1 && sourceSize != ShapedType::kDynamic) return std::nullopt; - allSourceIndices.emplace_back(sourceDimIdx); + allSourceIndices.push_back(sourceDimIdx); } return SmallVector{allSourceIndices}; } @@ -337,7 +329,8 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, auto &range = ranges[targetDimIdx]; auto &reverseRange = reverseRanges[targetDimIdx]; // Get non-overlapping indices between the ranges - ReassociationIndices nonMatchingIndices = range ^ reverseRange; + ReassociationIndices nonMatchingIndices = + range.getNonOverlappingIndicesWith(reverseRange); // Unit dimensions can be collapsed wherever - this is the only ambiguity // that we allow. for (int64_t sourceDimIdx : nonMatchingIndices) { From ce007de91645eaeb019158461d236db5eeb12739 Mon Sep 17 00:00:00 2001 From: Artem Gindinson Date: Fri, 23 May 2025 16:46:36 +0000 Subject: [PATCH 11/16] [fixup] apply greedy logic suggestions Signed-off-by: Artem Gindinson --- mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 55 ++++++++++++---------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 2ed51de13e2a3..3212dcae0cc12 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -96,24 +96,24 @@ static FailureOr findReassociationRangeForDynamicDim(ArrayRef sourceShape, int64_t sourceStartIdx, bool matchGreedily = false) { - ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx}; const unsigned numSourceDims = sourceShape.size(); ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1}; - auto resultRange = iterationRange; + std::optional resultRange = std::nullopt; - bool foundDynamic = false; + ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx}; for (; iterationRange.isInRange(sourceShapeAsRange); iterationRange.rightIdx++) { int64_t sourceSize = sourceShape[iterationRange.rightIdx]; - if (foundDynamic && !matchGreedily) + if (sourceSize == ShapedType::kDynamic) { + resultRange = iterationRange; break; - if (sourceSize == ShapedType::kDynamic) - foundDynamic = true; - resultRange = iterationRange; + } } - if (!foundDynamic) + if (!resultRange) return failure(); - return resultRange; + if (matchGreedily) + resultRange->rightIdx = sourceShapeAsRange.rightIdx; + return *resultRange; } /// Starting from `sourceStartIdx`, searches `sourceShape` for the first @@ -125,31 +125,24 @@ static FailureOr findReassociationRangeForSize(ArrayRef sourceShape, int64_t sourceStartIdx, int64_t targetSize, bool matchGreedily = false) { - ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx}; const unsigned numSourceDims = sourceShape.size(); ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1}; - auto resultRange = iterationRange; + std::optional resultRange = std::nullopt; + ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx}; int64_t prodOfCollapsedDims = 1; - bool reachedTargetDimSize = false; while (iterationRange.isInRange(sourceShapeAsRange)) { int64_t sourceSize = sourceShape[iterationRange.rightIdx]; - if (reachedTargetDimSize && !matchGreedily) - break; if (sourceSize == ShapedType::kDynamic) { - if (reachedTargetDimSize) - break; // Reassociation for a static dim cannot include a dynamic dim. Reset // induction variables to essentially restart the loop from the next // source dimension. prodOfCollapsedDims = 1; - resultRange = {iterationRange.rightIdx + 1, iterationRange.rightIdx + 1}; - iterationRange = resultRange; + iterationRange = {iterationRange.rightIdx + 1, + iterationRange.rightIdx + 1}; continue; } prodOfCollapsedDims *= sourceSize; - if (prodOfCollapsedDims > targetSize && reachedTargetDimSize) - break; // If the target size has been exceeded without matching, we need to shift // the range start right. From the start of the range, roll back the // multiplication until the target size exceeds the product again. @@ -160,17 +153,29 @@ findReassociationRangeForSize(ArrayRef sourceShape, // Shrink the range rightwards iterationRange.leftIdx++; } - resultRange = iterationRange; // We could've reached the target size with the current dimension, // also as a result of the above shift to right. - if (prodOfCollapsedDims == targetSize) - reachedTargetDimSize = true; + if (prodOfCollapsedDims == targetSize) { + resultRange = iterationRange; + break; + } // Increment the iteration range iterationRange.rightIdx++; } - if (!reachedTargetDimSize) + if (!resultRange) return failure(); - return resultRange; + if (matchGreedily) { + // We now want to collect all unit dimensions directly after the target + // product match. Advance the iterator to avoid OOB when the product match + // happens at the last element. + iterationRange.rightIdx++; + while (iterationRange.isInRange(sourceShapeAsRange) && + sourceShape[iterationRange.rightIdx] == 1) { + resultRange = iterationRange; + iterationRange.rightIdx++; + } + } + return *resultRange; } /// Attempts to find a valid collapsing reassociation of `sourceShape` into From 15caa2954616d36f178d6809c327d171b16510ae Mon Sep 17 00:00:00 2001 From: Artem Gindinson Date: Wed, 28 May 2025 09:33:42 +0000 Subject: [PATCH 12/16] [fixup] improve `getNonOverlappingIndicesWith(&rhs)` Signed-off-by: Artem Gindinson --- mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 25 ++++++++++++++-------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 3212dcae0cc12..2c5e10f96010b 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -62,16 +62,23 @@ struct ReassociationIndexRange { /// Collects indices that do not overlap between this and another range. ReassociationIndices getNonOverlappingIndicesWith(ReassociationIndexRange &rhs) const { - ReassociationIndices result; - result.reserve(size() + rhs.size() / 2); // Attempt to amortize - for (int64_t idx = this->leftIdx; idx <= this->rightIdx; ++idx) { - if (idx < rhs.leftIdx || idx > rhs.rightIdx) - result.push_back(idx); - } - for (int64_t rhsIndex = rhs.leftIdx; rhsIndex <= rhs.rightIdx; ++rhsIndex) { - if (rhsIndex < leftIdx || rhsIndex > rightIdx) - result.push_back(rhsIndex); + if (rightIdx < rhs.leftIdx) { + // The intervals do not overlap - concatenate the indices from both. + auto jointFullIndices = getFullIndices(); + jointFullIndices.append(rhs.getFullIndices()); + return jointFullIndices; } + ReassociationIndices result; + // Handle the chunk left of the overlapping range. + int64_t leftStart = std::min(leftIdx, rhs.leftIdx); + int64_t leftEnd = std::max(leftIdx, rhs.leftIdx); + llvm::append_range(result, llvm::seq(leftStart, leftEnd)); + // Handle the chunk right of the overlapping range. Symmetrically, we should + // skip the edge of the overlap AND include the rightmost index. + int64_t rightStart = std::min(rightIdx, rhs.rightIdx) + 1; + int64_t rightEnd = std::max(rightIdx, rhs.rightIdx); + if (rightStart < rightEnd) + llvm::append_range(result, llvm::seq_inclusive(rightStart, rightEnd)); return result; } From 880b394e0d46b64b501b934c261656ac91bc228e Mon Sep 17 00:00:00 2001 From: Artem Gindinson Date: Tue, 3 Jun 2025 09:22:59 +0000 Subject: [PATCH 13/16] [fixup] Reduce auto usage, drop obsolete variable Signed-off-by: Artem Gindinson --- mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 2c5e10f96010b..7c857b629701b 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -217,8 +217,7 @@ findReassociationRangesForCollapse(ArrayRef sourceShape, int64_t targetSize = targetShape[targetDimIdx]; // Simply check if there are any subsequent target dimensions left - if not, // the match must be made greedily. - bool isLastTargetDim = targetDimIdx == numTargetDims - 1; - bool shouldMatchGreedily = isLastTargetDim; + bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1; FailureOr sourceRange; if (targetSize == ShapedType::kDynamic) { sourceRange = findReassociationRangeForDynamicDim( @@ -263,14 +262,15 @@ findReassociationRangesForCollapse(ArrayRef sourceShape, return findReassociationRangesForCollapse(sourceShape, targetShape); // FIXME: It would be preferable to avoid the expensive copies. At the moment, // this approach is chosen for readability of the main implementation. - auto sourceToReverse = sourceShape.vec(), targetToReverse = targetShape.vec(); + std::vector sourceToReverse = sourceShape.vec(), + targetToReverse = targetShape.vec(); std::reverse(sourceToReverse.begin(), sourceToReverse.end()); std::reverse(targetToReverse.begin(), targetToReverse.end()); auto invertedRanges = findReassociationRangesForCollapse(sourceToReverse, targetToReverse); if (failed(invertedRanges)) return failure(); - auto rangesToInvert = *invertedRanges; + SmallVector &rangesToInvert = *invertedRanges; unsigned numSourceDims = sourceShape.size(); // We have received the ranges for inverted shapes. Now we have to invert // the ranges back to correspond with the original source shape. @@ -312,7 +312,7 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, findReassociationRangesForCollapse(sourceShape, targetShape); if (failed(maybeForwardRanges)) return std::nullopt; - auto &ranges = *maybeForwardRanges; + SmallVector &ranges = *maybeForwardRanges; // Now do the same in reverse. We need to get another valid reassociation // through some other strategy, and then compare the results in order to // disambiguate mixed subshapes, such as: @@ -328,7 +328,7 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, sourceShape, targetShape, /*iterateRightToLeft=*/true); if (failed(maybeReverseRanges)) return std::nullopt; - auto &reverseRanges = *maybeReverseRanges; + SmallVector &reverseRanges = *maybeReverseRanges; if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims) return std::nullopt; @@ -338,8 +338,8 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, SmallVector reassociationMap(numTargetDims); for (unsigned targetDimIdx = 0; targetDimIdx < numTargetDims; ++targetDimIdx) { - auto &range = ranges[targetDimIdx]; - auto &reverseRange = reverseRanges[targetDimIdx]; + ReassociationIndexRange &range = ranges[targetDimIdx]; + ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx]; // Get non-overlapping indices between the ranges ReassociationIndices nonMatchingIndices = range.getNonOverlappingIndicesWith(reverseRange); From cc6df046e38ce8f6e684099fa62232aa9c86ba9c Mon Sep 17 00:00:00 2001 From: Artem Gindinson Date: Tue, 3 Jun 2025 09:29:01 +0000 Subject: [PATCH 14/16] [fixup] Move a comment to the right place Signed-off-by: Artem Gindinson --- mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 7c857b629701b..a20cde3c4e9d4 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -260,6 +260,9 @@ findReassociationRangesForCollapse(ArrayRef sourceShape, bool iterateRightToLeft) { if (!iterateRightToLeft) return findReassociationRangesForCollapse(sourceShape, targetShape); + // NB: To iterate right-to-left, we currently reverse the shapes and then + // reverse the result back. The reversed shapes must not be temporary, as + // we're passing through an ArrayRef. // FIXME: It would be preferable to avoid the expensive copies. At the moment, // this approach is chosen for readability of the main implementation. std::vector sourceToReverse = sourceShape.vec(), @@ -321,9 +324,6 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, // be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without // backtracking, the algorithm will fail right-to-left. However, this is the // best way to preserve correctness. - // - // NB: The reversed shapes must not be temporary as we're passing through an - // ArrayRef. auto maybeReverseRanges = findReassociationRangesForCollapse( sourceShape, targetShape, /*iterateRightToLeft=*/true); if (failed(maybeReverseRanges)) From ea9161dea6bfbb15eb8c0e150313dc01d33e06fc Mon Sep 17 00:00:00 2001 From: Artem Gindinson Date: Tue, 3 Jun 2025 09:38:31 +0000 Subject: [PATCH 15/16] [fixup] Clarify some early-return cases Signed-off-by: Artem Gindinson --- mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 4 ++++ mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp | 5 ++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index a20cde3c4e9d4..da40e71199e24 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -293,6 +293,10 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, ArrayRef targetShape) { unsigned numSourceDims = sourceShape.size(), numTargetDims = targetShape.size(); + // We're supposed to search for a collapsing reassociation. If the sizes + // match, there's no actual collapsing taking place - it's either a no-op or a + // `tensor.reshape`-style reassociation (that would be beyond the scope of + // this utility). if (numSourceDims <= numTargetDims) return std::nullopt; // Early handling for scalar target types. diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp index 83d720bb88f0f..db1a87a4de2d5 100644 --- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp +++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp @@ -54,11 +54,14 @@ TEST(ReassociationIndicesForCollapse, StaticTest) { } TEST(ReassociationIndicesForCollapse, StaticTestFailure) { - EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10}), std::nullopt); + // No-op reassociation EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10, 20}), std::nullopt); + // Invalid static reassociations + EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10}), std::nullopt); EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {200, 300}), std::nullopt); + // Non-collapsing (expanding) reassociation EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {1, 10, 20, 30}), std::nullopt); } From 54abd87447b221b0dbc418f8fef7de7a535370db Mon Sep 17 00:00:00 2001 From: Artem Gindinson Date: Tue, 3 Jun 2025 09:51:14 +0000 Subject: [PATCH 16/16] [fixup] Improve auto usage further Signed-off-by: Artem Gindinson --- mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index da40e71199e24..3b1fdb69e8ef1 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -315,11 +315,11 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, } // Collect source ranges by iterating over the target shape left-to-right. - auto maybeForwardRanges = + FailureOr> maybeForwardRanges = findReassociationRangesForCollapse(sourceShape, targetShape); if (failed(maybeForwardRanges)) return std::nullopt; - SmallVector &ranges = *maybeForwardRanges; + auto &ranges = *maybeForwardRanges; // Now do the same in reverse. We need to get another valid reassociation // through some other strategy, and then compare the results in order to // disambiguate mixed subshapes, such as: @@ -328,11 +328,12 @@ mlir::getReassociationIndicesForCollapse(ArrayRef sourceShape, // be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without // backtracking, the algorithm will fail right-to-left. However, this is the // best way to preserve correctness. - auto maybeReverseRanges = findReassociationRangesForCollapse( - sourceShape, targetShape, /*iterateRightToLeft=*/true); + FailureOr> maybeReverseRanges = + findReassociationRangesForCollapse(sourceShape, targetShape, + /*iterateRightToLeft=*/true); if (failed(maybeReverseRanges)) return std::nullopt; - SmallVector &reverseRanges = *maybeReverseRanges; + auto &reverseRanges = *maybeReverseRanges; if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims) return std::nullopt;