Skip to content

Commit b05a12e

Browse files
authored
Let memref.expand_shape implement ReifyRankedShapedTypeOpInterface (#90975)
This is a new take on #89111. Now that #90040 is merged, this has become trivial to implement. The added test shows the kind of benefit that we get from this: now dim-of-expand-shape naturally folds without us needing to implement an ad-hoc folding rewrite.
1 parent 5d81b1c commit b05a12e

File tree

3 files changed

+27
-1
lines changed

3 files changed

+27
-1
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1578,7 +1578,8 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
15781578
}
15791579

15801580
def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
1581-
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
1581+
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1582+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
15821583
let summary = "operation to produce a memref with a higher rank.";
15831584
let description = [{
15841585
The `memref.expand_shape` op produces a new view with a higher rank whose

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2079,6 +2079,13 @@ void ExpandShapeOp::getAsmResultNames(
20792079
setNameFn(getResult(), "expand_shape");
20802080
}
20812081

2082+
LogicalResult ExpandShapeOp::reifyResultShapes(
2083+
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2084+
reifiedResultShapes = {
2085+
getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2086+
return success();
2087+
}
2088+
20822089
/// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp
20832090
/// result and operand. Layout maps are verified separately.
20842091
///

mlir/test/Dialect/MemRef/resolve-dim-ops.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,21 @@ func.func @static_dim_of_transpose_op(%arg0: tensor<1x100x?x8xi8>) -> index {
5353
%dim = tensor.dim %1, %c2 : tensor<1x8x100x?xi8>
5454
return %dim : index
5555
}
56+
57+
// -----
58+
59+
// Test case: Folding of memref.dim(memref.expand_shape)
60+
// CHECK-LABEL: func @dim_of_memref_expand_shape(
61+
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<?x8xi32>
62+
// CHECK-NEXT: %[[IDX:.*]] = arith.constant 0
63+
// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[MEM]], %[[IDX]] : memref<?x8xi32>
64+
// CHECK: return %[[DIM]] : index
65+
func.func @dim_of_memref_expand_shape(%arg0: memref<?x8xi32>)
66+
-> index {
67+
%c0 = arith.constant 0 : index
68+
%c1 = arith.constant 1 : index
69+
%s = memref.dim %arg0, %c0 : memref<?x8xi32>
70+
%0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, %s, 2, 4]: memref<?x8xi32> into memref<1x?x2x4xi32>
71+
%1 = memref.dim %0, %c1 : memref<1x?x2x4xi32>
72+
return %1 : index
73+
}

0 commit comments

Comments
 (0)