Skip to content

[CIR][CIRGen][Builtin][Neon] Lower BI__builtin_neon_vmovn_v #909

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2218,13 +2218,31 @@ mlir::Value CIRGenFunction::buildCommonNeonBuiltinExpr(

mlir::cir::VectorType vTy = GetNeonType(this, neonType, hasLegalHalfType,
false, allowBFloatArgsAndRet);
if (!vTy)
mlir::Type ty = vTy;
if (!ty)
return nullptr;

unsigned intrinicId = llvmIntrinsic;
if ((modifier & UnsignedAlts) && !isUnsigned)
intrinicId = altLLVMIntrinsic;

// This first switch is for the intrinsics that cannot have a more generic
// codegen solution.
switch (builtinID) {
default:
break;
case NEON::BI__builtin_neon_vmovn_v: {
mlir::cir::VectorType qTy = builder.getExtendedElementVectorType(
vTy, mlir::cast<mlir::cir::IntType>(vTy.getEltType()).isSigned());
ops[0] = builder.createBitcast(ops[0], qTy);
// It really is truncation in this context.
// In CIR, integral cast op supports vector of int type truncating.
return builder.createIntCast(ops[0], ty);
}
}

// This second switch is for the intrinsics that might have a more generic
// codegen solution so we can use the common codegen in future.
switch (builtinID) {
default:
llvm::errs() << getAArch64SIMDIntrinsicString(builtinID) << " ";
Expand Down
26 changes: 12 additions & 14 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -542,11 +542,9 @@ class CIRMemCpyOpLowering
};

static mlir::Value getLLVMIntCast(mlir::ConversionPatternRewriter &rewriter,
mlir::Value llvmSrc,
mlir::IntegerType llvmDstIntTy,
bool isUnsigned, uint64_t cirDstIntWidth) {
auto cirSrcWidth =
mlir::cast<mlir::IntegerType>(llvmSrc.getType()).getWidth();
mlir::Value llvmSrc, mlir::Type llvmDstIntTy,
bool isUnsigned, uint64_t cirSrcWidth,
uint64_t cirDstIntWidth) {
if (cirSrcWidth == cirDstIntWidth)
return llvmSrc;

Expand Down Expand Up @@ -604,7 +602,7 @@ class CIRPtrStrideOpLowering
auto llvmDstType = mlir::IntegerType::get(ctx, *layoutWidth);
index = getLLVMIntCast(rewriter, index, llvmDstType,
ptrStrideOp.getStride().getType().isUnsigned(),
*layoutWidth);
width, *layoutWidth);

// Rewrite the sub in front of extensions/trunc
if (rewriteSub) {
Expand Down Expand Up @@ -709,10 +707,9 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
mlir::cir::IntType dstIntType =
mlir::cast<mlir::cir::IntType>(elementTypeIfVector(dstType));
rewriter.replaceOp(
castOp,
getLLVMIntCast(rewriter, llvmSrcVal,
mlir::cast<mlir::IntegerType>(llvmDstType),
srcIntType.isUnsigned(), dstIntType.getWidth()));
castOp, getLLVMIntCast(rewriter, llvmSrcVal, llvmDstType,
srcIntType.isUnsigned(), srcIntType.getWidth(),
dstIntType.getWidth()));
break;
}
case mlir::cir::CastKind::floating: {
Expand Down Expand Up @@ -2485,7 +2482,8 @@ class CIRShiftOpLowering
// Ensure shift amount is the same type as the value. Some undefined
// behavior might occur in the casts below as per [C99 6.5.7.3].
amt = getLLVMIntCast(rewriter, amt, mlir::cast<mlir::IntegerType>(llvmTy),
!cirAmtTy.isSigned(), cirValTy.getWidth());
!cirAmtTy.isSigned(), cirAmtTy.getWidth(),
cirValTy.getWidth());

// Lower to the proper LLVM shift operation.
if (op.getIsShiftleft())
Expand Down Expand Up @@ -2617,9 +2615,9 @@ static mlir::Value createLLVMBitOp(mlir::Location loc,
operand.getType(), operand);
}

return getLLVMIntCast(rewriter, op->getResult(0),
mlir::cast<mlir::IntegerType>(resultTy),
/*isUnsigned=*/true, resultIntTy.getWidth());
return getLLVMIntCast(
rewriter, op->getResult(0), mlir::cast<mlir::IntegerType>(resultTy),
/*isUnsigned=*/true, operandIntTy.getWidth(), resultIntTy.getWidth());
}

class CIRBitClrsbOpLowering
Expand Down
72 changes: 72 additions & 0 deletions clang/test/CIR/CodeGen/AArch64/neon.c
Original file line number Diff line number Diff line change
Expand Up @@ -17511,3 +17511,75 @@ void test_vst1q_s64(int64_t *a, int64x2_t b) {
// uint64_t test_vaddlv_u32(uint32x2_t a) {
// return vaddlv_u32(a);
// }

uint8x8_t test_vmovn_u16(uint16x8_t a) {
return vmovn_u16(a);
// CIR-LABEL: vmovn_u16
// CIR: [[ARG:%.*]] = cir.cast(bitcast, {{%.*}} : !cir.vector<!s8i x 16>), !cir.vector<!u16i x 8>
// CIR: {{%.*}} = cir.cast(integral, [[ARG]] : !cir.vector<!u16i x 8>), !cir.vector<!u8i x 8>

// LLVM: {{.*}}@test_vmovn_u16(<8 x i16>{{.*}}[[A:%.*]])
// LLVM: [[VMOVN_1:%.*]] = bitcast <8 x i16> [[A]] to <16 x i8>
// LLVM: [[VMOVN_I:%.*]] = trunc <8 x i16> [[A]] to <8 x i8>
// LLVM: ret <8 x i8> [[VMOVN_I]]
}

uint16x4_t test_vmovn_u32(uint32x4_t a) {
return vmovn_u32(a);
// CIR-LABEL: vmovn_u32
// CIR: [[ARG:%.*]] = cir.cast(bitcast, {{%.*}} : !cir.vector<!s8i x 16>), !cir.vector<!u32i x 4>
// CIR: {{%.*}} = cir.cast(integral, [[ARG]] : !cir.vector<!u32i x 4>), !cir.vector<!u16i x 4>

// LLVM: {{.*}}@test_vmovn_u32(<4 x i32>{{.*}}[[A:%.*]])
// LLVM: [[VMOVN_1:%.*]] = bitcast <4 x i32> [[A]] to <16 x i8>
// LLVM: [[VMOVN_I:%.*]] = trunc <4 x i32> [[A]] to <4 x i16>
// LLVM: ret <4 x i16> [[VMOVN_I]]
}

uint32x2_t test_vmovn_u64(uint64x2_t a) {
return vmovn_u64(a);
// CIR-LABEL: vmovn_u64
// CIR: [[ARG:%.*]] = cir.cast(bitcast, {{%.*}} : !cir.vector<!s8i x 16>), !cir.vector<!u64i x 2>
// CIR: {{%.*}} = cir.cast(integral, [[ARG]] : !cir.vector<!u64i x 2>), !cir.vector<!u32i x 2>

// LLVM: {{.*}}@test_vmovn_u64(<2 x i64>{{.*}}[[A:%.*]])
// LLVM: [[VMOVN_1:%.*]] = bitcast <2 x i64> [[A]] to <16 x i8>
// LLVM: [[VMOVN_I:%.*]] = trunc <2 x i64> [[A]] to <2 x i32>
// LLVM: ret <2 x i32> [[VMOVN_I]]
}

int8x8_t test_vmovn_s16(int16x8_t a) {
return vmovn_s16(a);
// CIR-LABEL: vmovn_s16
// CIR: [[ARG:%.*]] = cir.cast(bitcast, {{%.*}} : !cir.vector<!s8i x 16>), !cir.vector<!s16i x 8>
// CIR: {{%.*}} = cir.cast(integral, [[ARG]] : !cir.vector<!s16i x 8>), !cir.vector<!s8i x 8>

// LLVM: {{.*}}@test_vmovn_s16(<8 x i16>{{.*}}[[A:%.*]])
// LLVM: [[VMOVN_1:%.*]] = bitcast <8 x i16> [[A]] to <16 x i8>
// LLVM: [[VMOVN_I:%.*]] = trunc <8 x i16> [[A]] to <8 x i8>
// LLVM: ret <8 x i8> [[VMOVN_I]]
}

int16x4_t test_vmovn_s32(int32x4_t a) {
return vmovn_s32(a);
// CIR-LABEL: vmovn_s32
// CIR: [[ARG:%.*]] = cir.cast(bitcast, {{%.*}} : !cir.vector<!s8i x 16>), !cir.vector<!s32i x 4>
// CIR: {{%.*}} = cir.cast(integral, [[ARG]] : !cir.vector<!s32i x 4>), !cir.vector<!s16i x 4>

// LLVM: {{.*}}@test_vmovn_s32(<4 x i32>{{.*}}[[A:%.*]])
// LLVM: [[VMOVN_1:%.*]] = bitcast <4 x i32> [[A]] to <16 x i8>
// LLVM: [[VMOVN_I:%.*]] = trunc <4 x i32> [[A]] to <4 x i16>
// LLVM: ret <4 x i16> [[VMOVN_I]]
}

int32x2_t test_vmovn_s64(int64x2_t a) {
return vmovn_s64(a);
// CIR-LABEL: vmovn_s64
// CIR: [[ARG:%.*]] = cir.cast(bitcast, {{%.*}} : !cir.vector<!s8i x 16>), !cir.vector<!s64i x 2>
// CIR: {{%.*}} = cir.cast(integral, [[ARG]] : !cir.vector<!s64i x 2>), !cir.vector<!s32i x 2>

// LLVM: {{.*}}@test_vmovn_s64(<2 x i64>{{.*}}[[A:%.*]])
// LLVM: [[VMOVN_1:%.*]] = bitcast <2 x i64> [[A]] to <16 x i8>
// LLVM: [[VMOVN_I:%.*]] = trunc <2 x i64> [[A]] to <2 x i32>
// LLVM: ret <2 x i32> [[VMOVN_I]]
}