-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][memref]: Collapse strided unit dim even if strides are dynamic #157330
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir Author: Maya Amrami (amrami) Changes…e dynamic Full diff: https://github.com/llvm/llvm-project/pull/157330.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b59d73d1291c8..3bdeaea300659 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2401,6 +2401,11 @@ computeCollapsedLayoutMap(MemRefType srcType,
auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
for (int64_t idx : llvm::reverse(trailingReassocs)) {
+ // Dimensions of size 1 should be skipped, because their strides are
+ // meaningless and could have any arbitrary value.
+ if (srcShape[idx - 1] == 1)
+ continue;
+
stride = stride * SaturatedInteger::wrap(srcShape[idx]);
// Both source and result stride must have the same static value. In that
@@ -2415,11 +2420,6 @@ computeCollapsedLayoutMap(MemRefType srcType,
if (strict && (stride.saturated || srcStride.saturated))
return failure();
- // Dimensions of size 1 should be skipped, because their strides are
- // meaningless and could have any arbitrary value.
- if (srcShape[idx - 1] == 1)
- continue;
-
if (!stride.saturated && !srcStride.saturated && stride != srcStride)
return failure();
}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 6c2298a3f8acb..50683761db5bf 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -431,7 +431,8 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
%arg4: index,
%arg5: index,
%arg6: index,
- %arg7: memref<4x?x4xf32>) {
+ %arg7: memref<4x?x4xf32>,
+ %arg8: memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>>) {
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
// CHECK-SAME: memref<?x?x?xf32> into memref<?x?xf32>
%0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
@@ -480,6 +481,10 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
%4 = memref.expand_shape %arg7 [[0, 1], [2], [3, 4]] output_shape [2, 2, %arg4, 2, 2]
: memref<4x?x4xf32> into memref<2x2x?x2x2xf32>
+
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3]]
+// CHECK-SAME: memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xsi8, strided<[?, ?, 1], offset: ?>>
+ %5 = memref.collapse_shape %arg8 [[0, 1], [2], [3]] : memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xsi8, strided<[?, ?, 1], offset: ?>>
return
}
|
|
@matthias-springer Can you take a look? 😊 |
matthias-springer
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see a problem with this PR, but I'm not working with collapse_shape/expand_shape anymore. @MaheshRavishankar is this safe to merge?
|
@MaheshRavishankar @Groverkss Can you take a look? :) |
krzysz00
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM here
|
Do we need something further before landing this? |
|
Hi, I'd like to signal that this breaks IREE downstream lit tests. With the change, the below memref is no longer collapsable, but it should be okay to collapse? memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>>The below method starts returning false with the change. llvm-project/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp Lines 2597 to 2605 in 92a1eb3
which does not compute |
… dynamic (llvm#157330)" This reverts commit f74e909.
|
It fails in the below check. llvm-project/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp Lines 2590 to 2591 in 8f683c3
|
| %arg6: index, | ||
| %arg7: memref<4x?x4xf32>) { | ||
| %arg7: memref<4x?x4xf32>, | ||
| %arg8: memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>>) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMHO, the element type does not matter. We should probably just use f32 or i8 for consistency in this file.
… dynamic" (#166448) Reverts #157330 The original revision introduces a bug in `isGuaranteedCollapsible`. The `memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>>` is no longer collapsable with the change. The revision reverts the change to bring back correct behavior. `stride` should be computed as `96` like the old behavior in the failed iteration. https://github.com/llvm/llvm-project/blob/92a1eb37122fa24e3045fbabdea2bf87127cace5/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp#L2597-L2605
…strides are dynamic" (#166448) Reverts llvm/llvm-project#157330 The original revision introduces a bug in `isGuaranteedCollapsible`. The `memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>>` is no longer collapsable with the change. The revision reverts the change to bring back correct behavior. `stride` should be computed as `96` like the old behavior in the failed iteration. https://github.com/llvm/llvm-project/blob/92a1eb37122fa24e3045fbabdea2bf87127cace5/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp#L2597-L2605
… dynamic (llvm#157330)" This reverts commit f74e909.
…e dynamic