From 18ab550e0dabc6eb76aa290dc474ce5fedf9ed75 Mon Sep 17 00:00:00 2001
From: Matthias Springer
Date: Wed, 6 Dec 2023 09:34:52 +0900
Subject: [PATCH] [mlir][shape] Turn `ShapeOfOp` folding into canonicalization
 pattern

The `ShapeOfOp` folder used to generate invalid IR.

Input:
```
%0 = shape.shape_of %arg1 : tensor<f32> -> tensor<?xindex>
```

Output:
```
%0 = "shape.const_shape"() <{shape = dense<> : tensor<0xindex>}> : () -> tensor<?xindex>
error: 'shape.const_shape' op inferred type(s) 'tensor<0xindex>' are incompatible with return type(s) of operation 'tensor<?xindex>'
```

This rewrite cannot be implemented as a folder because the result type may
have to change. In the above example, the original `shape.shape_of` op had a
return type of `tensor<?xindex>`, but the folded attribute (materialized as a
`shape.const_shape` op) must have a type of `tensor<0xindex>` to be valid.

This commit fixes tests such as `mlir/test/Dialect/Shape/canonicalize.mlir`
when verifying the IR after each pattern application (#74270).
---
 .../include/mlir/Dialect/Shape/IR/ShapeOps.td |  1 -
 mlir/lib/Dialect/Shape/IR/Shape.cpp           | 34 ++++++++++++++-----
 mlir/test/Dialect/Shape/canonicalize.mlir     | 12 +++++++
 3 files changed, 37 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 3c9f45366fa2b..08a0398e74b0c 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -566,7 +566,6 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of",
   let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";
 
   let hasCanonicalizer = 1;
-  let hasFolder = 1;
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 2444556a45635..4f829db1305c8 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1678,15 +1678,30 @@ LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
 // ShapeOfOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ShapeOfOp::fold(FoldAdaptor) {
-  auto type = llvm::dyn_cast<ShapedType>(getOperand().getType());
-  if (!type || !type.hasStaticShape())
-    return nullptr;
-  Builder builder(getContext());
-  return builder.getIndexTensorAttr(type.getShape());
-}
-
 namespace {
+/// Replace shape_of(x) where x has a constant shape with a const_shape op.
+struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
+  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
+                                PatternRewriter &rewriter) const override {
+    auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
+    if (!type || !type.hasStaticShape())
+      return failure();
+    Location loc = op.getLoc();
+    Value constShape =
+        rewriter
+            .create<ConstShapeOp>(loc,
+                                  rewriter.getIndexTensorAttr(type.getShape()))
+            .getResult();
+    if (constShape.getType() != op.getResult().getType())
+      constShape = rewriter.create<tensor::CastOp>(
+          loc, op.getResult().getType(), constShape);
+    rewriter.replaceOp(op, constShape);
+    return success();
+  }
+};
+
 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
 
@@ -1739,7 +1754,8 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
   patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
-               ExtractFromShapeOfExtentTensor>(context);
+               ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
+      context);
 }
 
 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 8edbae3baf52e..40b137f1fa36e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1492,3 +1492,15 @@ func.func @add_poison() -> !shape.size {
   %result = shape.add %1, %2 : !shape.size, !shape.size -> !shape.size
   return %result : !shape.size
 }
+
+// -----
+
+// CHECK-LABEL: func @shape_of_0d(
+// CHECK-SAME: %[[arg0:.*]]: tensor<f32>
+// CHECK: %[[const:.*]] = shape.const_shape [] : tensor<0xindex>
+// CHECK: %[[cast:.*]] = tensor.cast %[[const]] : tensor<0xindex> to tensor<?xindex>
+// CHECK: return %[[cast]]
+func.func @shape_of_0d(%arg0: tensor<f32>) -> tensor<?xindex> {
+  %0 = shape.shape_of %arg0 : tensor<f32> -> tensor<?xindex>
+  return %0 : tensor<?xindex>
+}
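
For illustration, here is a sketch of what the new `ShapeOfOpToConstShapeOp`
pattern produces for a statically shaped input; the `tensor<2x3xf32>` example
below is assumed for exposition and is not taken from the test suite. The
shape is materialized as a `shape.const_shape`, and because the constant's
inferred type `tensor<2xindex>` differs from the original result type, a
`tensor.cast` restores `tensor<?xindex>`:

```mlir
// Before canonicalization (hypothetical input):
%shape = shape.shape_of %arg0 : tensor<2x3xf32> -> tensor<?xindex>

// After ShapeOfOpToConstShapeOp has applied: the shape is a constant, and a
// tensor.cast converts it back to the op's original result type.
%cs = shape.const_shape [2, 3] : tensor<2xindex>
%shape = tensor.cast %cs : tensor<2xindex> to tensor<?xindex>
```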