From f3b2a0d4135c6c00a0803cfb0df31c26774501d0 Mon Sep 17 00:00:00 2001 From: AmrDeveloper Date: Sun, 22 Jun 2025 16:25:48 +0200 Subject: [PATCH] [CIR] Backport VecSplatOp simplifier --- .../CIR/Dialect/Transforms/CIRSimplify.cpp | 30 +++++++++++++++++-- clang/test/CIR/Transforms/vector-splat.cir | 16 ++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) create mode 100644 clang/test/CIR/Transforms/vector-splat.cir diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp index 4cc0021ee287..7913a0ccebac 100644 --- a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp +++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp @@ -141,6 +141,31 @@ struct SimplifySelect : public OpRewritePattern { } }; +struct SimplifyVecSplat : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(VecSplatOp op, + PatternRewriter &rewriter) const override { + mlir::Value splatValue = op.getValue(); + auto constant = + mlir::dyn_cast_if_present(splatValue.getDefiningOp()); + if (!constant) + return mlir::failure(); + + auto value = constant.getValue(); + if (!mlir::isa_and_nonnull(value) && + !mlir::isa_and_nonnull(value)) + return mlir::failure(); + + cir::VectorType resultType = op.getResult().getType(); + SmallVector elements(resultType.getSize(), value); + auto constVecAttr = cir::ConstVectorAttr::get( + resultType, mlir::ArrayAttr::get(getContext(), elements)); + + rewriter.replaceOpWithNewOp(op, constVecAttr); + return mlir::success(); + } +}; + //===----------------------------------------------------------------------===// // CIRSimplifyPass //===----------------------------------------------------------------------===// @@ -155,7 +180,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) { // clang-format off patterns.add< SimplifyTernary, - SimplifySelect + SimplifySelect, + SimplifyVecSplat >(patterns.getContext()); // clang-format on } @@ -168,7 +194,7 @@ void CIRSimplifyPass::runOnOperation() { // Collect operations to apply patterns. llvm::SmallVector ops; getOperation()->walk([&](Operation *op) { - if (isa(op)) + if (isa(op)) ops.push_back(op); }); diff --git a/clang/test/CIR/Transforms/vector-splat.cir b/clang/test/CIR/Transforms/vector-splat.cir new file mode 100644 index 000000000000..76195c8a289e --- /dev/null +++ b/clang/test/CIR/Transforms/vector-splat.cir @@ -0,0 +1,16 @@ +// RUN: cir-opt %s -cir-simplify -o - | FileCheck %s + +!s32i = !cir.int + +module { + cir.func @fold_splat_vector_op_test() -> !cir.vector { + %v = cir.const #cir.int<3> : !s32i + %vec = cir.vec.splat %v : !s32i, !cir.vector + cir.return %vec : !cir.vector + } + + // CHECK: cir.func @fold_splat_vector_op_test() -> !cir.vector { + // CHECK-NEXT: %0 = cir.const #cir.const_vector<[#cir.int<3> : !s32i, #cir.int<3> : !s32i, + // CHECK-SAME: #cir.int<3> : !s32i, #cir.int<3> : !s32i]> : !cir.vector + // CHECK-NEXT: cir.return %0 : !cir.vector +}