Skip to content

Commit e9b7f55

Browse files
committed
[CIR] Upstream splat op for VectorType
1 parent 377cb7f commit e9b7f55

File tree

7 files changed

+249
-3
lines changed

7 files changed

+249
-3
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2037,4 +2037,37 @@ def VecExtractOp : CIR_Op<"vec.extract", [Pure,
20372037
let hasFolder = 1;
20382038
}
20392039

2040+
2041+
//===----------------------------------------------------------------------===//
2042+
// VecSplat
2043+
//===----------------------------------------------------------------------===//
2044+
2045+
// cir.vec.splat is a separate operation from cir.vec.create because more
2046+
// efficient LLVM IR can be generated for it, and because some optimization and
2047+
// analysis passes can benefit from knowing that all elements of the vector
2048+
// have the same value.
2049+
2050+
def VecSplatOp : CIR_Op<"vec.splat", [Pure,
2051+
TypesMatchWith<"type of 'value' matches element type of 'result'", "result",
2052+
"value", "cast<VectorType>($_self).getElementType()">]> {
2053+
2054+
let summary = "Convert a scalar into a vector";
2055+
let description = [{
2056+
The `cir.vec.splat` operation creates a vector value from a scalar value.
2057+
All elements of the vector have the same value, that of the given scalar.
2058+
2059+
```mlir
2060+
%value = cir.const #cir.int<3> : !s32i
2061+
%value_vec = cir.vec.splat %value : !s32i, !cir.vector<4 x !s32i>
2062+
```
2063+
}];
2064+
2065+
let arguments = (ins CIR_AnyType:$value);
2066+
let results = (outs CIR_VectorType:$result);
2067+
2068+
let assemblyFormat = [{
2069+
$value `:` type($value) `,` qualified(type($result)) attr-dict
2070+
}];
2071+
}
2072+
20402073
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,6 +1650,14 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *ce) {
16501650
cgf.convertType(destTy));
16511651
}
16521652

1653+
case CK_VectorSplat: {
1654+
// Create a vector object and fill all elements with the same scalar value.
1655+
assert(destTy->isVectorType() && "CK_VectorSplat to non-vector type");
1656+
return cgf.getBuilder().create<cir::VecSplatOp>(
1657+
cgf.getLoc(subExpr->getSourceRange()), cgf.convertType(destTy),
1658+
Visit(subExpr));
1659+
}
1660+
16531661
default:
16541662
cgf.getCIRGenModule().errorNYI(subExpr->getSourceRange(),
16551663
"CastExpr: ", ce->getCastKindName());

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1648,7 +1648,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
16481648
CIRToLLVMUnaryOpLowering,
16491649
CIRToLLVMVecCreateOpLowering,
16501650
CIRToLLVMVecExtractOpLowering,
1651-
CIRToLLVMVecInsertOpLowering
1651+
CIRToLLVMVecInsertOpLowering,
1652+
CIRToLLVMVecSplatOpLowering
16521653
// clang-format on
16531654
>(converter, patterns.getContext());
16541655

@@ -1773,6 +1774,38 @@ mlir::LogicalResult CIRToLLVMVecInsertOpLowering::matchAndRewrite(
17731774
return mlir::success();
17741775
}
17751776

1777+
mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite(
1778+
cir::VecSplatOp op, OpAdaptor adaptor,
1779+
mlir::ConversionPatternRewriter &rewriter) const {
1780+
// Vector splat can be implemented with an `insertelement` and a
1781+
// `shufflevector`, which is better than an `insertelement` for each
1782+
// element in the vector. Start with an undef vector. Insert the value into
1783+
// the first element. Then use a `shufflevector` with a mask of all 0 to
1784+
// fill out the entire vector with that value.
1785+
const auto vecTy = mlir::cast<cir::VectorType>(op.getType());
1786+
const mlir::Type llvmTy = typeConverter->convertType(vecTy);
1787+
const mlir::Location loc = op.getLoc();
1788+
const mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
1789+
1790+
const mlir::Value elementValue = adaptor.getValue();
1791+
if (mlir::isa<mlir::LLVM::PoisonOp>(elementValue.getDefiningOp())) {
1792+
// If the splat value is poison, then we can just use poison value
1793+
// for the entire vector.
1794+
rewriter.replaceOp(op, poison);
1795+
return mlir::success();
1796+
}
1797+
1798+
const mlir::Value indexValue =
1799+
rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0);
1800+
const mlir::Value oneElement = rewriter.create<mlir::LLVM::InsertElementOp>(
1801+
loc, poison, elementValue, indexValue);
1802+
const SmallVector<int32_t> zeroValues(vecTy.getSize(), 0);
1803+
const mlir::Value shuffled = rewriter.create<mlir::LLVM::ShuffleVectorOp>(
1804+
loc, oneElement, poison, zeroValues);
1805+
rewriter.replaceOp(op, shuffled);
1806+
return mlir::success();
1807+
}
1808+
17761809
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
17771810
return std::make_unique<ConvertCIRToLLVMPass>();
17781811
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,16 @@ class CIRToLLVMVecInsertOpLowering
332332
mlir::ConversionPatternRewriter &) const override;
333333
};
334334

335+
class CIRToLLVMVecSplatOpLowering
336+
: public mlir::OpConversionPattern<cir::VecSplatOp> {
337+
public:
338+
using mlir::OpConversionPattern<cir::VecSplatOp>::OpConversionPattern;
339+
340+
mlir::LogicalResult
341+
matchAndRewrite(cir::VecSplatOp op, OpAdaptor,
342+
mlir::ConversionPatternRewriter &) const override;
343+
};
344+
335345
} // namespace direct
336346
} // namespace cir
337347

clang/test/CIR/CodeGen/vector-ext.cpp

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG
77

88
typedef int vi4 __attribute__((ext_vector_type(4)));
9+
typedef unsigned int uvi4 __attribute__((ext_vector_type(4)));
910
typedef int vi3 __attribute__((ext_vector_type(3)));
1011
typedef int vi2 __attribute__((ext_vector_type(2)));
1112
typedef double vd2 __attribute__((ext_vector_type(2)));
@@ -400,4 +401,67 @@ void foo9() {
400401
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
401402
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
402403
// OGCG: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
403-
// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
404+
// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
405+
406+
void foo11() {
407+
vi4 a = {1, 2, 3, 4};
408+
vi4 shl = a << 3;
409+
410+
uvi4 b = {1u, 2u, 3u, 4u};
411+
uvi4 shr = b >> 3u;
412+
}
413+
414+
// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
415+
// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
416+
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["b", init]
417+
// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["shr", init]
418+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
419+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
420+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
421+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
422+
// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
423+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
424+
// CIR: cir.store %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
425+
// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
426+
// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !s32i
427+
// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !s32i, !cir.vector<4 x !s32i>
428+
// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
429+
// CIR: cir.store %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
430+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !u32i
431+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !u32i
432+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !u32i
433+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !u32i
434+
// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
435+
// CIR-SAME: !u32i, !u32i, !u32i, !u32i) : !cir.vector<4 x !u32i>
436+
// CIR: cir.store %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
437+
// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !u32i>>, !cir.vector<4 x !u32i>
438+
// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !u32i
439+
// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !u32i, !cir.vector<4 x !u32i>
440+
// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_B]] : !cir.vector<4 x !u32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !u32i>) -> !cir.vector<4 x !u32i>
441+
// CIR: cir.store %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
442+
443+
// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
444+
// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
445+
// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
446+
// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
447+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
448+
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
449+
// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
450+
// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
451+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
452+
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
453+
// LLVM: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
454+
// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
455+
456+
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
457+
// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
458+
// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
459+
// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
460+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
461+
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
462+
// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
463+
// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
464+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
465+
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
466+
// OGCG: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
467+
// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16

clang/test/CIR/CodeGen/vector.cpp

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
// RUN: FileCheck --input-file=%t.ll %s -check-prefix=OGCG
77

88
typedef int vi4 __attribute__((vector_size(16)));
9+
typedef unsigned int uvi4 __attribute__((vector_size(16)));
910
typedef double vd2 __attribute__((vector_size(16)));
1011
typedef long long vll2 __attribute__((vector_size(16)));
1112

@@ -388,4 +389,67 @@ void foo9() {
388389
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
389390
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
390391
// OGCG: %[[SHR:.*]] = ashr <4 x i32> %[[TMP_A]], %[[TMP_B]]
391-
// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
392+
// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
393+
394+
void foo11() {
395+
vi4 a = {1, 2, 3, 4};
396+
vi4 shl = a << 3;
397+
398+
uvi4 b = {1u, 2u, 3u, 4u};
399+
uvi4 shr = b >> 3u;
400+
}
401+
402+
// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
403+
// CIR: %[[SHL_RES:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
404+
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["b", init]
405+
// CIR: %[[SHR_RES:.*]] = cir.alloca !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>, ["shr", init]
406+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
407+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
408+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
409+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
410+
// CIR: %[[VEC_A_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
411+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
412+
// CIR: cir.store %[[VEC_A_VAL]], %[[VEC_A]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
413+
// CIR: %[[TMP_A:.*]] = cir.load %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
414+
// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !s32i
415+
// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !s32i, !cir.vector<4 x !s32i>
416+
// CIR: %[[SHL:.*]] = cir.shift(left, %[[TMP_A]] : !cir.vector<4 x !s32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
417+
// CIR: cir.store %[[SHL]], %[[SHL_RES]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
418+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !u32i
419+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !u32i
420+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !u32i
421+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !u32i
422+
// CIR: %[[VEC_B_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
423+
// CIR-SAME: !u32i, !u32i, !u32i, !u32i) : !cir.vector<4 x !u32i>
424+
// CIR: cir.store %[[VEC_B_VAL]], %[[VEC_B]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
425+
// CIR: %[[TMP_B:.*]] = cir.load %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !u32i>>, !cir.vector<4 x !u32i>
426+
// CIR: %[[SH_AMOUNT:.*]] = cir.const #cir.int<3> : !u32i
427+
// CIR: %[[SPLAT_VEC:.*]] = cir.vec.splat %[[SH_AMOUNT]] : !u32i, !cir.vector<4 x !u32i>
428+
// CIR: %[[SHR:.*]] = cir.shift(right, %[[TMP_B]] : !cir.vector<4 x !u32i>, %[[SPLAT_VEC]] : !cir.vector<4 x !u32i>) -> !cir.vector<4 x !u32i>
429+
// CIR: cir.store %[[SHR]], %[[SHR_RES]] : !cir.vector<4 x !u32i>, !cir.ptr<!cir.vector<4 x !u32i>>
430+
431+
// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
432+
// LLVM: %[[SHL_RES:.*]] = alloca <4 x i32>, i64 1, align 16
433+
// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
434+
// LLVM: %[[SHR_RES:.*]] = alloca <4 x i32>, i64 1, align 16
435+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
436+
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
437+
// LLVM: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
438+
// LLVM: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
439+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
440+
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
441+
// LLVM: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
442+
// LLVM: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16
443+
444+
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
445+
// OGCG: %[[SHL_RES:.*]] = alloca <4 x i32>, align 16
446+
// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
447+
// OGCG: %[[SHR_RES:.*]] = alloca <4 x i32>, align 16
448+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_A]], align 16
449+
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
450+
// OGCG: %[[SHL:.*]] = shl <4 x i32> %[[TMP_A]], splat (i32 3)
451+
// OGCG: store <4 x i32> %[[SHL]], ptr %[[SHL_RES]], align 16
452+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC_B]], align 16
453+
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
454+
// OGCG: %[[SHR:.*]] = lshr <4 x i32> %[[TMP_B]], splat (i32 3)
455+
// OGCG: store <4 x i32> %[[SHR]], ptr %[[SHR_RES]], align 16

clang/test/CIR/IR/vector.cir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,4 +135,38 @@ cir.func @vector_insert_element_test() {
135135
// CHECK: cir.return
136136
// CHECK: }
137137

138+
cir.func @vector_splat_test() {
139+
%0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
140+
%1 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
141+
%2 = cir.const #cir.int<1> : !s32i
142+
%3 = cir.const #cir.int<2> : !s32i
143+
%4 = cir.const #cir.int<3> : !s32i
144+
%5 = cir.const #cir.int<4> : !s32i
145+
%6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
146+
cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
147+
%7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
148+
%8 = cir.const #cir.int<3> : !s32i
149+
%9 = cir.vec.splat %8 : !s32i, !cir.vector<4 x !s32i>
150+
%10 = cir.shift(left, %7 : !cir.vector<4 x !s32i>, %9 : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
151+
cir.store %10, %1 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
152+
cir.return
153+
}
154+
155+
// CHECK: cir.func @vector_splat_test() {
156+
// CHECK: %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
157+
// CHECK: %1 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["shl", init]
158+
// CHECK: %2 = cir.const #cir.int<1> : !s32i
159+
// CHECK: %3 = cir.const #cir.int<2> : !s32i
160+
// CHECK: %4 = cir.const #cir.int<3> : !s32i
161+
// CHECK: %5 = cir.const #cir.int<4> : !s32i
162+
// CHECK: %6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
163+
// CHECK: cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
164+
// CHECK: %7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
165+
// CHECK: %8 = cir.const #cir.int<3> : !s32i
166+
// CHECK: %9 = cir.vec.splat %8 : !s32i, !cir.vector<4 x !s32i>
167+
// CHECK: %10 = cir.shift(left, %7 : !cir.vector<4 x !s32i>, %9 : !cir.vector<4 x !s32i>) -> !cir.vector<4 x !s32i>
168+
// CHECK: cir.store %10, %1 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
169+
// CHECK: cir.return
170+
// CHECK: }
171+
138172
}

0 commit comments

Comments
 (0)