Skip to content

Conversation

@amrami
Copy link
Contributor

@amrami amrami commented Sep 7, 2025

…e dynamic

@llvmbot
Copy link
Member

llvmbot commented Sep 7, 2025

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: Maya Amrami (amrami)

Changes

…e dynamic


Full diff: https://github.com/llvm/llvm-project/pull/157330.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+5-5)
  • (modified) mlir/test/Dialect/MemRef/ops.mlir (+6-1)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b59d73d1291c8..3bdeaea300659 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2401,6 +2401,11 @@ computeCollapsedLayoutMap(MemRefType srcType,
     auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
     auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
     for (int64_t idx : llvm::reverse(trailingReassocs)) {
+      // Dimensions of size 1 should be skipped, because their strides are
+      // meaningless and could have any arbitrary value.
+      if (srcShape[idx - 1] == 1)
+        continue;
+
       stride = stride * SaturatedInteger::wrap(srcShape[idx]);
 
       // Both source and result stride must have the same static value. In that
@@ -2415,11 +2420,6 @@ computeCollapsedLayoutMap(MemRefType srcType,
       if (strict && (stride.saturated || srcStride.saturated))
         return failure();
 
-      // Dimensions of size 1 should be skipped, because their strides are
-      // meaningless and could have any arbitrary value.
-      if (srcShape[idx - 1] == 1)
-        continue;
-
       if (!stride.saturated && !srcStride.saturated && stride != srcStride)
         return failure();
     }
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 6c2298a3f8acb..50683761db5bf 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -431,7 +431,8 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
          %arg4: index,
          %arg5: index,
          %arg6: index,
-         %arg7: memref<4x?x4xf32>) {
+         %arg7: memref<4x?x4xf32>,
+         %arg8: memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>>) {
 //       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
 //  CHECK-SAME:     memref<?x?x?xf32> into memref<?x?xf32>
   %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
@@ -480,6 +481,10 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
 //       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
   %4 = memref.expand_shape %arg7 [[0, 1], [2], [3, 4]] output_shape [2, 2, %arg4, 2, 2]
         : memref<4x?x4xf32> into memref<2x2x?x2x2xf32>
+
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3]]
+//  CHECK-SAME:     memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xsi8, strided<[?, ?, 1], offset: ?>>
+  %5 = memref.collapse_shape %arg8 [[0, 1], [2], [3]] : memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xsi8, strided<[?, ?, 1], offset: ?>>
   return
 }
 

@amrami
Copy link
Contributor Author

amrami commented Sep 14, 2025

@matthias-springer Can you take a look? 😊

Copy link
Member

@matthias-springer matthias-springer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see a problem with this PR, but I'm not working with collapse_shape/expand_shape anymore. @MaheshRavishankar is this safe to merge?

@amrami
Copy link
Contributor Author

amrami commented Sep 28, 2025

@MaheshRavishankar @Groverkss Can you take a look? :)

@Groverkss Groverkss requested a review from krzysz00 September 29, 2025 09:25
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM here

@krzysz00
Copy link
Contributor

krzysz00 commented Nov 3, 2025

Do we need something further before landing this?

@amrami amrami merged commit f74e909 into llvm:main Nov 4, 2025
10 checks passed
@hanhanW
Copy link
Contributor

hanhanW commented Nov 4, 2025

Hi, I'd like to signal that this breaks IREE downstream lit tests.

With the change, the below memref is no longer collapsable, but it should be okay to collapse?

memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>>

The below method starts returning false with the change.

bool CollapseShapeOp::isGuaranteedCollapsible(
MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
// MemRefs with identity layout are always collapsible.
if (srcType.getLayout().isIdentity())
return true;
return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
/*strict=*/true));
}

which does not compute stride properly in the change.

hanhanW added a commit to iree-org/llvm-project that referenced this pull request Nov 4, 2025
@hanhanW
Copy link
Contributor

hanhanW commented Nov 4, 2025

It fails in the below check. stride=1 mismatch srcStride=96. I think stride should be computed as 96 like the old behavior.

if (!stride.saturated && !srcStride.saturated && stride != srcStride)
return failure();

%arg6: index,
%arg7: memref<4x?x4xf32>) {
%arg7: memref<4x?x4xf32>,
%arg8: memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>>) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO, the element type does not matter. We should probably just use f32 or i8 for consistency in this file.

hanhanW added a commit that referenced this pull request Nov 4, 2025
… dynamic" (#166448)

Reverts #157330

The original revision introduces a bug in `isGuaranteedCollapsible`. The
`memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>>` is no
longer collapsable with the change. The revision reverts the change to
bring back correct behavior. `stride` should be computed as `96` like
the old behavior in the failed iteration.


https://github.com/llvm/llvm-project/blob/92a1eb37122fa24e3045fbabdea2bf87127cace5/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp#L2597-L2605
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Nov 4, 2025
…strides are dynamic" (#166448)

Reverts llvm/llvm-project#157330

The original revision introduces a bug in `isGuaranteedCollapsible`. The
`memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>>` is no
longer collapsable with the change. The revision reverts the change to
bring back correct behavior. `stride` should be computed as `96` like
the old behavior in the failed iteration.

https://github.com/llvm/llvm-project/blob/92a1eb37122fa24e3045fbabdea2bf87127cace5/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp#L2597-L2605
hanhanW added a commit to iree-org/llvm-project that referenced this pull request Nov 5, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants