diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 038a59b8ff4eb..568861cea9f92 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -2199,7 +2199,9 @@ def VecShuffleOp : CIR_Op<"vec.shuffle", `(` $vec1 `,` $vec2 `:` qualified(type($vec1)) `)` $indices `:` qualified(type($result)) attr-dict }]; + let hasVerifier = 1; + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index bfd3a0a62a8e7..a6cf0a6b5d75e 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -1580,9 +1580,43 @@ OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) { } //===----------------------------------------------------------------------===// -// VecShuffle +// VecShuffleOp //===----------------------------------------------------------------------===// +OpFoldResult cir::VecShuffleOp::fold(FoldAdaptor adaptor) { + auto vec1Attr = + mlir::dyn_cast_if_present(adaptor.getVec1()); + auto vec2Attr = + mlir::dyn_cast_if_present(adaptor.getVec2()); + if (!vec1Attr || !vec2Attr) + return {}; + + mlir::Type vec1ElemTy = + mlir::cast(vec1Attr.getType()).getElementType(); + + mlir::ArrayAttr vec1Elts = vec1Attr.getElts(); + mlir::ArrayAttr vec2Elts = vec2Attr.getElts(); + mlir::ArrayAttr indicesElts = adaptor.getIndices(); + + SmallVector elements; + elements.reserve(indicesElts.size()); + + uint64_t vec1Size = vec1Elts.size(); + for (const auto &idxAttr : indicesElts.getAsRange()) { + if (idxAttr.getSInt() == -1) { + elements.push_back(cir::UndefAttr::get(vec1ElemTy)); + continue; + } + + uint64_t idxValue = idxAttr.getUInt(); + elements.push_back(idxValue < vec1Size ? vec1Elts[idxValue] + : vec2Elts[idxValue - vec1Size]); + } + + return cir::ConstVectorAttr::get( + getType(), mlir::ArrayAttr::get(getContext(), elements)); +} + LogicalResult cir::VecShuffleOp::verify() { // The number of elements in the indices array must match the number of // elements in the result type. @@ -1613,7 +1647,6 @@ OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) { mlir::isa_and_nonnull(indices)) { auto vecAttr = mlir::cast(vec); auto indicesAttr = mlir::cast(indices); - auto vecTy = mlir::cast(vecAttr.getType()); mlir::ArrayAttr vecElts = vecAttr.getElts(); mlir::ArrayAttr indicesElts = indicesAttr.getElts(); @@ -1631,7 +1664,7 @@ OpFoldResult cir::VecShuffleDynamicOp::fold(FoldAdaptor adaptor) { } return cir::ConstVectorAttr::get( - vecTy, mlir::ArrayAttr::get(getContext(), elements)); + getType(), mlir::ArrayAttr::get(getContext(), elements)); } return {}; diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp index 33881c69eec5f..29f9942638964 100644 --- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp +++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp @@ -142,7 +142,7 @@ void CIRCanonicalizePass::runOnOperation() { // Many operations are here to perform a manual `fold` in // applyOpPatternsGreedily. if (isa(op)) + VecExtractOp, VecShuffleOp, VecShuffleDynamicOp, VecTernaryOp>(op)) ops.push_back(op); }); diff --git a/clang/test/CIR/Transforms/vector-shuffle-fold.cir b/clang/test/CIR/Transforms/vector-shuffle-fold.cir new file mode 100644 index 0000000000000..87d409728989b --- /dev/null +++ b/clang/test/CIR/Transforms/vector-shuffle-fold.cir @@ -0,0 +1,59 @@ +// RUN: cir-opt %s -cir-canonicalize -o - -split-input-file | FileCheck %s + +!s32i = !cir.int +!s64i = !cir.int + +module { + cir.func @fold_shuffle_vector_op_test() -> !cir.vector<4 x !s32i> { + %vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i> + %vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i> + %new_vec = cir.vec.shuffle(%vec_1, %vec_2 : !cir.vector<4 x !s32i>) [#cir.int<0> : !s64i, #cir.int<4> : !s64i, + #cir.int<1> : !s64i, #cir.int<5> : !s64i] : !cir.vector<4 x !s32i> + cir.return %new_vec : !cir.vector<4 x !s32i> + } + + // CHECK: cir.func @fold_shuffle_vector_op_test() -> !cir.vector<4 x !s32i> { + // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, + // CHECK-SAME: #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i> + // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i> +} + +// ----- + +!s32i = !cir.int +!s64i = !cir.int + +module { + cir.func @fold_shuffle_vector_op_test() -> !cir.vector<6 x !s32i> { + %vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i> + %vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i> + %new_vec = cir.vec.shuffle(%vec_1, %vec_2 : !cir.vector<4 x !s32i>) [#cir.int<0> : !s64i, #cir.int<4> : !s64i, + #cir.int<1> : !s64i, #cir.int<5> : !s64i, #cir.int<2> : !s64i, #cir.int<6> : !s64i] : !cir.vector<6 x !s32i> + cir.return %new_vec : !cir.vector<6 x !s32i> + } + + // CHECK: cir.func @fold_shuffle_vector_op_test() -> !cir.vector<6 x !s32i> { + // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, + // CHECK-SAME: #cir.int<4> : !s32i, #cir.int<5> : !s32i, #cir.int<6> : !s32i]> : !cir.vector<6 x !s32i> + // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<6 x !s32i> +} + +// ----- + +!s32i = !cir.int +!s64i = !cir.int + +module { + cir.func @fold_shuffle_vector_op_test() -> !cir.vector<4 x !s32i> { + %vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i> + %vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i> + %new_vec = cir.vec.shuffle(%vec_1, %vec_2 : !cir.vector<4 x !s32i>) [#cir.int<-1> : !s64i, #cir.int<4> : !s64i, + #cir.int<1> : !s64i, #cir.int<5> : !s64i] : !cir.vector<4 x !s32i> + cir.return %new_vec : !cir.vector<4 x !s32i> + } + + // CHECK: cir.func @fold_shuffle_vector_op_test() -> !cir.vector<4 x !s32i> { + // CHECK: cir.const #cir.const_vector<[#cir.undef : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, + // CHECK-SAME: #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i> + // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i> +}