diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp index a15bf891dd596..ca3c366ccec5e 100644 --- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp @@ -98,6 +98,27 @@ struct RankOpInterface } }; +struct CollapseShapeOpInterface + : public ValueBoundsOpInterface::ExternalModel { + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + auto collapseOp = cast(op); + assert(value == collapseOp.getResult() && "invalid value"); + + // Multiply the expressions for the dimensions in the reassociation group. + const ReassociationIndices &reassocIndices = + collapseOp.getReassociationIndices()[dim]; + AffineExpr productExpr = + cstr.getExpr(collapseOp.getSrc(), reassocIndices[0]); + for (size_t i = 1; i < reassocIndices.size(); ++i) { + productExpr = + productExpr * cstr.getExpr(collapseOp.getSrc(), reassocIndices[i]); + } + cstr.bound(value)[dim] == productExpr; + } +}; + struct SubViewOpInterface : public ValueBoundsOpInterface::ExternalModel { @@ -134,6 +155,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels( memref::AllocOpInterface>(*ctx); memref::CastOp::attachInterface(*ctx); memref::DimOp::attachInterface(*ctx); + memref::CollapseShapeOp::attachInterface( + *ctx); memref::ExpandShapeOp::attachInterface( *ctx); memref::GetGlobalOp::attachInterface(*ctx); diff --git a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir index ac1f22b68b1e1..700535a3c21ff 100644 --- a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir @@ -77,6 +77,24 @@ func.func @memref_expand(%m: memref, %sz: index) -> (index, index) { // ----- +// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)> +// CHECK-LABEL: func @memref_collapse( +// CHECK-SAME: %[[sz0:.*]]: index +// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[c12:.*]] = arith.constant 12 : index +// CHECK: %[[dim:.*]] = memref.dim %{{.*}}, %[[c2]] : memref<3x4x?x2xf32> +// CHECK: %[[mul:.*]] = affine.apply #[[$MAP]]()[%[[dim]]] +// CHECK: return %[[c12]], %[[mul]] +func.func @memref_collapse(%sz0: index) -> (index, index) { + %0 = memref.alloc(%sz0) : memref<3x4x?x2xf32> + %1 = memref.collapse_shape %0 [[0, 1], [2, 3]] : memref<3x4x?x2xf32> into memref<12x?xf32> + %2 = "test.reify_bound"(%1) {dim = 0} : (memref<12x?xf32>) -> (index) + %3 = "test.reify_bound"(%1) {dim = 1} : (memref<12x?xf32>) -> (index) + return %2, %3 : index, index +} + +// ----- + // CHECK-LABEL: func @memref_get_global( // CHECK: %[[c4:.*]] = arith.constant 4 : index // CHECK: return %[[c4]]