diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index d12ba8c4c59b3..58af9995548e9 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -803,14 +803,36 @@ struct FoldFillWithPack : public OpRewritePattern { } }; +/// Fold fill with copy. +struct FoldFillWithCopy : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::CopyOp copyOp, + PatternRewriter &rewriter) const override { + if (auto fillOp = copyOp.getInputs().front().getDefiningOp()) { + rewriter.replaceOpWithNewOp(copyOp, copyOp.getResultTypes(), + fillOp.getInputs(), + copyOp.getOutputs()); + return success(); + } + if (auto fillOp = copyOp.getOutputs().front().getDefiningOp()) { + rewriter.replaceOpWithNewOp(copyOp, copyOp.getInputs(), + fillOp.getOutputs()); + return success(); + } + return failure(); + } +}; + } // namespace void FillOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add, - FoldFillWithTensorReshape, - FoldInsertPadIntoFill>(context); + results + .add, + FoldFillWithTensorReshape, + FoldInsertPadIntoFill>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 7793e43558274..e875bae473094 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -972,3 +972,29 @@ func.func @canonicalize_dim_of_dest_style_op(%arg0 : tensor) -> tensor< %3 = linalg.copy ins(%1 : tensor) outs(%2 : tensor) -> tensor return %3: tensor } +// ----- + +// CHECK-LABEL: func @canonicalize_fill_to_copy_input( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor) +// CHECK: %[[ZERO:.+]] = arith.constant 0.0 +// CHECK: linalg.fill ins(%[[ZERO]] : f32) outs(%[[ARG1]] : tensor) +func.func @canonicalize_fill_to_copy_input(%arg0 : tensor, %arg1 : tensor) -> tensor { + %c0 = arith.constant 0.0 : f32 + %fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor) -> tensor + %copy = linalg.copy ins(%fill : tensor) outs(%arg1 : tensor) -> tensor + return %copy : tensor +} + +// ----- + +// CHECK-LABEL: func @canonicalize_fill_to_copy_dest( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor) +// CHECK: linalg.copy ins(%[[ARG1]] : tensor) outs(%[[ARG0]] : tensor) +func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor, %arg1 : tensor) -> tensor { + %c0 = arith.constant 0.0 : f32 + %fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor) -> tensor + %copy = linalg.copy ins(%arg1 : tensor) outs(%fill : tensor) -> tensor + return %copy : tensor +}