diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 14b8d95ea15b4..5738b6ca51c12 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1578,7 +1578,8 @@ class MemRef_ReassociativeReshapeOp traits = []> : } def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [ - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let summary = "operation to produce a memref with a higher rank."; let description = [{ The `memref.expand_shape` op produces a new view with a higher rank whose diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index b969d41d934d4..393f73dc65cd8 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2079,6 +2079,13 @@ void ExpandShapeOp::getAsmResultNames( setNameFn(getResult(), "expand_shape"); } +LogicalResult ExpandShapeOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) { + reifiedResultShapes = { + getMixedValues(getStaticOutputShape(), getOutputShape(), builder)}; + return success(); +} + /// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp /// result and operand. Layout maps are verified separately. /// diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir index 40f88de01b8bd..85a4853972457 100644 --- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir +++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir @@ -53,3 +53,21 @@ func.func @static_dim_of_transpose_op(%arg0: tensor<1x100x?x8xi8>) -> index { %dim = tensor.dim %1, %c2 : tensor<1x8x100x?xi8> return %dim : index } + +// ----- + +// Test case: Folding of memref.dim(memref.expand_shape) +// CHECK-LABEL: func @dim_of_memref_expand_shape( +// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref +// CHECK-NEXT: %[[IDX:.*]] = arith.constant 0 +// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[MEM]], %[[IDX]] : memref +// CHECK: return %[[DIM]] : index +func.func @dim_of_memref_expand_shape(%arg0: memref) + -> index { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %s = memref.dim %arg0, %c0 : memref + %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, %s, 2, 4]: memref into memref<1x?x2x4xi32> + %1 = memref.dim %0, %c1 : memref<1x?x2x4xi32> + return %1 : index +}