
Commit e45210a

[msan] Handle AVX512 VCVTPS2PH (#154460)

This extends handleAVX512VectorConvertFPToInt() from 556c846 (#147377) to handle AVX512 VCVTPS2PH.
1 parent 1d05d69 commit e45210a

3 files changed (+239, -104 lines)

llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp

Lines changed: 102 additions & 30 deletions
@@ -3429,26 +3429,30 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     return ShadowType;
   }
 
-  /// Doubles the length of a vector shadow (filled with zeros) if necessary to
-  /// match the length of the shadow for the instruction.
+  /// Doubles the length of a vector shadow (extending with zeros) if necessary
+  /// to match the length of the shadow for the instruction.
+  /// If scalar types of the vectors are different, it will use the type of the
+  /// input vector.
   /// This is more type-safe than CreateShadowCast().
   Value *maybeExtendVectorShadowWithZeros(Value *Shadow, IntrinsicInst &I) {
     IRBuilder<> IRB(&I);
     assert(isa<FixedVectorType>(Shadow->getType()));
     assert(isa<FixedVectorType>(I.getType()));
 
     Value *FullShadow = getCleanShadow(&I);
-    assert(cast<FixedVectorType>(Shadow->getType())->getNumElements() <=
-           cast<FixedVectorType>(FullShadow->getType())->getNumElements());
-    assert(cast<FixedVectorType>(Shadow->getType())->getScalarType() ==
-           cast<FixedVectorType>(FullShadow->getType())->getScalarType());
+    unsigned ShadowNumElems =
+        cast<FixedVectorType>(Shadow->getType())->getNumElements();
+    unsigned FullShadowNumElems =
+        cast<FixedVectorType>(FullShadow->getType())->getNumElements();
 
-    if (Shadow->getType() == FullShadow->getType()) {
+    assert((ShadowNumElems == FullShadowNumElems) ||
+           (ShadowNumElems * 2 == FullShadowNumElems));
+
+    if (ShadowNumElems == FullShadowNumElems) {
       FullShadow = Shadow;
     } else {
       // TODO: generalize beyond 2x?
-      SmallVector<int, 32> ShadowMask(
-          cast<FixedVectorType>(FullShadow->getType())->getNumElements());
+      SmallVector<int, 32> ShadowMask(FullShadowNumElems);
       std::iota(ShadowMask.begin(), ShadowMask.end(), 0);
 
       // Append zeros
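
The hunk above cuts off at the shuffle-mask setup; the visible lines build an identity mask over the full width, and the trailing "// Append zeros" comment indicates the rest: presumably a shufflevector whose second operand is a clean (all-zero) shadow, so mask indices past the input's length read zeros. The following standalone sketch models that mechanism in plain C++ under that assumption. It is an illustration of LLVM's shufflevector semantics, not the pass's actual IRBuilder code, and the helper names shuffle() and extendWithZeros() are hypothetical.

#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// shufflevector semantics: mask index i < N selects a[i]; index i >= N
// selects b[i - N], where N is the element count of a.
static std::vector<uint32_t> shuffle(const std::vector<uint32_t> &a,
                                     const std::vector<uint32_t> &b,
                                     const std::vector<int> &mask) {
  std::vector<uint32_t> out;
  for (int idx : mask)
    out.push_back(idx < (int)a.size() ? a[idx] : b[idx - a.size()]);
  return out;
}

// Doubles a shadow vector by appending zero (i.e. fully initialized) elements.
static std::vector<uint32_t> extendWithZeros(const std::vector<uint32_t> &shadow) {
  std::vector<uint32_t> zeros(shadow.size(), 0); // the "clean" shadow
  std::vector<int> mask(shadow.size() * 2);
  std::iota(mask.begin(), mask.end(), 0);        // [0, 1, ..., 2N-1]
  return shuffle(shadow, zeros, mask);           // shadow elements, then zeros
}

int main() {
  std::vector<uint32_t> narrow = {0xFFFFFFFF, 0, 0xFF, 0}; // <4 x i32> shadow
  std::vector<uint32_t> wide = extendWithZeros(narrow);
  assert(wide.size() == 8 && wide[0] == 0xFFFFFFFF && wide[4] == 0 && wide[7] == 0);
  return 0;
}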
@@ -4528,58 +4532,102 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     return isFixedFPVectorTy(V->getType());
   }
 
-  // e.g., call <16 x i32> @llvm.x86.avx512.mask.cvtps2dq.512
-  //                 (<16 x float> a, <16 x i32> writethru, i16 mask,
-  //                  i32 rounding)
+  // e.g., <16 x i32> @llvm.x86.avx512.mask.cvtps2dq.512
+  //           (<16 x float> a, <16 x i32> writethru, i16 mask,
+  //            i32 rounding)
+  //
+  // Inconveniently, some similar intrinsics have a different operand order:
+  //        <16 x i16> @llvm.x86.avx512.mask.vcvtps2ph.512
+  //           (<16 x float> a, i32 rounding, <16 x i16> writethru,
+  //            i16 mask)
+  //
+  // If the return type has more elements than A, the excess elements are
+  // zeroed (and the corresponding shadow is initialized).
+  //        <8 x i16> @llvm.x86.avx512.mask.vcvtps2ph.128
+  //           (<4 x float> a, i32 rounding, <8 x i16> writethru,
+  //            i8 mask)
   //
   // dst[i] = mask[i] ? convert(a[i]) : writethru[i]
   // dst_shadow[i] = mask[i] ? all_or_nothing(a_shadow[i]) : writethru_shadow[i]
   //   where all_or_nothing(x) is fully uninitialized if x has any
   //   uninitialized bits
-  void handleAVX512VectorConvertFPToInt(IntrinsicInst &I) {
+  void handleAVX512VectorConvertFPToInt(IntrinsicInst &I, bool LastMask) {
     IRBuilder<> IRB(&I);
 
     assert(I.arg_size() == 4);
     Value *A = I.getOperand(0);
-    Value *WriteThrough = I.getOperand(1);
-    Value *Mask = I.getOperand(2);
-    Value *RoundingMode = I.getOperand(3);
+    Value *WriteThrough;
+    Value *Mask;
+    Value *RoundingMode;
+    if (LastMask) {
+      WriteThrough = I.getOperand(2);
+      Mask = I.getOperand(3);
+      RoundingMode = I.getOperand(1);
+    } else {
+      WriteThrough = I.getOperand(1);
+      Mask = I.getOperand(2);
+      RoundingMode = I.getOperand(3);
+    }
 
     assert(isFixedFPVector(A));
     assert(isFixedIntVector(WriteThrough));
 
     unsigned ANumElements =
         cast<FixedVectorType>(A->getType())->getNumElements();
-    assert(ANumElements ==
-           cast<FixedVectorType>(WriteThrough->getType())->getNumElements());
+    [[maybe_unused]] unsigned WriteThruNumElements =
+        cast<FixedVectorType>(WriteThrough->getType())->getNumElements();
+    assert(ANumElements == WriteThruNumElements ||
+           ANumElements * 2 == WriteThruNumElements);
 
     assert(Mask->getType()->isIntegerTy());
-    assert(Mask->getType()->getScalarSizeInBits() == ANumElements);
+    unsigned MaskNumElements = Mask->getType()->getScalarSizeInBits();
+    assert(ANumElements == MaskNumElements ||
+           ANumElements * 2 == MaskNumElements);
+
+    assert(WriteThruNumElements == MaskNumElements);
+
+    // Some bits of the mask may be unused, though it's unusual to have partly
+    // uninitialized bits.
     insertCheckShadowOf(Mask, &I);
 
     assert(RoundingMode->getType()->isIntegerTy());
-    // Only four bits of the rounding mode are used, though it's very
+    // Only some bits of the rounding mode are used, though it's very
     // unusual to have uninitialized bits there (more commonly, it's a
     // constant).
     insertCheckShadowOf(RoundingMode, &I);
 
     assert(I.getType() == WriteThrough->getType());
 
+    Value *AShadow = getShadow(A);
+    AShadow = maybeExtendVectorShadowWithZeros(AShadow, I);
+
+    if (ANumElements * 2 == MaskNumElements) {
+      // Ensure that the irrelevant bits of the mask are zero, hence selecting
+      // from the zeroed shadow instead of the writethrough's shadow.
+      Mask =
+          IRB.CreateTrunc(Mask, IRB.getIntNTy(ANumElements), "_ms_mask_trunc");
+      Mask =
+          IRB.CreateZExt(Mask, IRB.getIntNTy(MaskNumElements), "_ms_mask_zext");
+    }
+
     // Convert i16 mask to <16 x i1>
     Mask = IRB.CreateBitCast(
-        Mask, FixedVectorType::get(IRB.getInt1Ty(), ANumElements));
+        Mask, FixedVectorType::get(IRB.getInt1Ty(), MaskNumElements),
+        "_ms_mask_bitcast");
 
-    Value *AShadow = getShadow(A);
-    /// For scalars:
-    /// Since they are converting from floating-point, the output is:
+    /// For floating-point to integer conversion, the output is:
     /// - fully uninitialized if *any* bit of the input is uninitialized
     /// - fully ininitialized if all bits of the input are ininitialized
     /// We apply the same principle on a per-element basis for vectors.
-    AShadow = IRB.CreateSExt(IRB.CreateICmpNE(AShadow, getCleanShadow(A)),
-                             getShadowTy(A));
+    ///
+    /// We use the scalar width of the return type instead of A's.
+    AShadow = IRB.CreateSExt(
+        IRB.CreateICmpNE(AShadow, getCleanShadow(AShadow->getType())),
+        getShadowTy(&I), "_ms_a_shadow");
 
     Value *WriteThroughShadow = getShadow(WriteThrough);
-    Value *Shadow = IRB.CreateSelect(Mask, AShadow, WriteThroughShadow);
+    Value *Shadow = IRB.CreateSelect(Mask, AShadow, WriteThroughShadow,
+                                     "_ms_writethru_select");
 
     setShadow(&I, Shadow);
     setOriginForNaryOp(I);
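
The comment block above fully determines the propagation rule: dst_shadow[i] = mask[i] ? all_or_nothing(a_shadow[i]) : writethru_shadow[i]. As a cross-check, here is a scalar model of the 512-bit case (sixteen lanes, i16 mask, so the truncation path is not exercised). This is a plain-C++ sketch with hypothetical names; the real handler emits IR (icmp, sext, select), not loops.

#include <cassert>
#include <cstdint>

// The converted element is fully uninitialized if *any* input bit is
// uninitialized, fully initialized otherwise; the result is produced at the
// scalar width of the return type (i16 here).
static uint16_t allOrNothing(uint32_t aShadow) {
  return aShadow != 0 ? 0xFFFF : 0x0000;
}

// dst_shadow[i] = mask[i] ? all_or_nothing(a_shadow[i]) : writethru_shadow[i]
static void vcvtps2ph512Shadow(const uint32_t aShadow[16],
                               const uint16_t writeThruShadow[16],
                               uint16_t mask, uint16_t dstShadow[16]) {
  for (int i = 0; i < 16; ++i)
    dstShadow[i] = ((mask >> i) & 1) ? allOrNothing(aShadow[i])
                                     : writeThruShadow[i];
}

int main() {
  uint32_t aShadow[16] = {};
  aShadow[3] = 0x1;         // a single poisoned bit in input element 3
  uint16_t wtShadow[16] = {};
  wtShadow[5] = 0xFFFF;     // writethrough element 5 is uninitialized
  uint16_t dst[16];
  vcvtps2ph512Shadow(aShadow, wtShadow, /*mask=*/0x000F, dst);
  assert(dst[3] == 0xFFFF); // selected lane: any poisoned bit poisons all
  assert(dst[5] == 0xFFFF); // unselected lane: keeps writethrough's shadow
  assert(dst[0] == 0x0000); // selected clean lane stays initialized
  return 0;
}

This mirrors the icmp ne/sext/select sequence visible in the updated test further below; the final or <16 x i16> there is ordinary MSan propagation for the add of the two results, not part of this handler.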
@@ -5300,6 +5348,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     case Intrinsic::x86_sse_ldmxcsr:
       handleLdmxcsr(I);
       break;
+
+    // Convert Scalar Double Precision Floating-Point Value
+    // to Unsigned Doubleword Integer
+    // etc.
     case Intrinsic::x86_avx512_vcvtsd2usi64:
     case Intrinsic::x86_avx512_vcvtsd2usi32:
     case Intrinsic::x86_avx512_vcvtss2usi64:
@@ -5340,6 +5392,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
       break;
     }
 
+    // Convert Packed Double Precision Floating-Point Values
+    // to Packed Single Precision Floating-Point Values
     case Intrinsic::x86_sse2_cvtpd2ps:
     case Intrinsic::x86_sse2_cvtps2dq:
     case Intrinsic::x86_sse2_cvtpd2dq:
@@ -5354,6 +5408,20 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
       break;
     }
 
+    // Convert Single-Precision FP Value to 16-bit FP Value
+    // <16 x i16> @llvm.x86.avx512.mask.vcvtps2ph.512
+    //     (<16 x float>, i32, <16 x i16>, i16)
+    // <8 x i16> @llvm.x86.avx512.mask.vcvtps2ph.128
+    //     (<4 x float>, i32, <8 x i16>, i8)
+    // <8 x i16> @llvm.x86.avx512.mask.vcvtps2ph.256
+    //     (<8 x float>, i32, <8 x i16>, i8)
+    case Intrinsic::x86_avx512_mask_vcvtps2ph_512:
+    case Intrinsic::x86_avx512_mask_vcvtps2ph_256:
+    case Intrinsic::x86_avx512_mask_vcvtps2ph_128:
+      handleAVX512VectorConvertFPToInt(I, /*LastMask=*/true);
+      break;
+
+    // Shift Packed Data (Left Logical, Right Arithmetic, Right Logical)
     case Intrinsic::x86_avx512_psll_w_512:
     case Intrinsic::x86_avx512_psll_d_512:
     case Intrinsic::x86_avx512_psll_q_512:
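
For context, these IR intrinsics are what clang typically emits for the AVX512F cvtps_ph family, so user code like the sketch below now gets per-element shadow propagation instead of a strict check of the entire input. The lowering of _mm512_cvtps_ph and _mm512_mask_cvtps_ph to @llvm.x86.avx512.mask.vcvtps2ph.512 is an assumption about the frontend, not stated in this commit, and the function names are illustrative.

// Hedged usage sketch; compile with, e.g., -mavx512f -fsanitize=memory.
#include <immintrin.h>

// Unmasked conversion: the builtin supplies a zero writethru and an all-ones
// mask (i16 -1), like RES1 in the updated test.
__m256i cvt_all(__m512 a) {
  return _mm512_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
}

// Masked conversion: unselected lanes keep src, so MSan must select between
// the converted input's shadow and src's shadow ("_ms_writethru_select").
__m256i cvt_masked(__m256i src, __mmask16 k, __m512 a) {
  return _mm512_mask_cvtps_ph(src, k, a,
                              _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
}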
@@ -5920,10 +5988,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
           /*trailingVerbatimArgs=*/1);
       break;
 
-    case Intrinsic::x86_avx512_mask_cvtps2dq_512: {
-      handleAVX512VectorConvertFPToInt(I);
+    // Convert Packed Single Precision Floating-Point Values
+    // to Packed Signed Doubleword Integer Values
+    //
+    // <16 x i32> @llvm.x86.avx512.mask.cvtps2dq.512
+    //     (<16 x float>, <16 x i32>, i16, i32)
+    case Intrinsic::x86_avx512_mask_cvtps2dq_512:
+      handleAVX512VectorConvertFPToInt(I, /*LastMask=*/false);
       break;
-    }
 
     // AVX512 PMOV: Packed MOV, with truncation
     // Precisely handled by applying the same intrinsic to the shadow

llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll

Lines changed: 26 additions & 31 deletions
@@ -24,7 +24,6 @@
 ; - llvm.x86.avx512.mask.rndscale.pd.512, llvm.x86.avx512.mask.rndscale.ps.512, llvm.x86.avx512.mask.rndscale.sd, llvm.x86.avx512.mask.rndscale.ss
 ; - llvm.x86.avx512.mask.scalef.pd.512, llvm.x86.avx512.mask.scalef.ps.512
 ; - llvm.x86.avx512.mask.sqrt.sd, llvm.x86.avx512.mask.sqrt.ss
-; - llvm.x86.avx512.mask.vcvtps2ph.512
 ; - llvm.x86.avx512.maskz.fixupimm.pd.512, llvm.x86.avx512.maskz.fixupimm.ps.512, llvm.x86.avx512.maskz.fixupimm.sd, llvm.x86.avx512.maskz.fixupimm.ss
 ; - llvm.x86.avx512.mul.pd.512, llvm.x86.avx512.mul.ps.512
 ; - llvm.x86.avx512.permvar.df.512, llvm.x86.avx512.permvar.sf.512
@@ -1903,50 +1902,46 @@ define <16 x i16> @test_x86_vcvtps2ph_256(<16 x float> %a0, <16 x i16> %src, i16
 ; CHECK-NEXT:    [[TMP3:%.*]] = load <16 x i16>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 64) to ptr), align 8
 ; CHECK-NEXT:    [[TMP4:%.*]] = load i64, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 104) to ptr), align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <16 x i32> [[TMP1]] to i512
-; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i512 [[TMP5]], 0
-; CHECK-NEXT:    br i1 [[_MSCMP]], label [[TMP6:%.*]], label [[TMP7:%.*]], !prof [[PROF1]]
-; CHECK:       6:
-; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR10]]
-; CHECK-NEXT:    unreachable
-; CHECK:       7:
+; CHECK-NEXT:    [[TMP6:%.*]] = icmp ne <16 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP7:%.*]] = sext <16 x i1> [[TMP6]] to <16 x i16>
+; CHECK-NEXT:    [[TMP8:%.*]] = select <16 x i1> splat (i1 true), <16 x i16> [[TMP7]], <16 x i16> zeroinitializer
 ; CHECK-NEXT:    [[RES1:%.*]] = call <16 x i16> @llvm.x86.avx512.mask.vcvtps2ph.512(<16 x float> [[A0:%.*]], i32 2, <16 x i16> zeroinitializer, i16 -1)
-; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <16 x i32> [[TMP1]] to i512
-; CHECK-NEXT:    [[_MSCMP1:%.*]] = icmp ne i512 [[TMP8]], 0
+; CHECK-NEXT:    [[TMP10:%.*]] = bitcast i16 [[MASK:%.*]] to <16 x i1>
+; CHECK-NEXT:    [[TMP11:%.*]] = icmp ne <16 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP12:%.*]] = sext <16 x i1> [[TMP11]] to <16 x i16>
+; CHECK-NEXT:    [[TMP13:%.*]] = select <16 x i1> [[TMP10]], <16 x i16> [[TMP12]], <16 x i16> zeroinitializer
 ; CHECK-NEXT:    [[_MSCMP2:%.*]] = icmp ne i16 [[TMP2]], 0
-; CHECK-NEXT:    [[_MSOR:%.*]] = or i1 [[_MSCMP1]], [[_MSCMP2]]
-; CHECK-NEXT:    br i1 [[_MSOR]], label [[TMP9:%.*]], label [[TMP10:%.*]], !prof [[PROF1]]
-; CHECK:       9:
+; CHECK-NEXT:    br i1 [[_MSCMP2]], label [[TMP9:%.*]], label [[TMP14:%.*]], !prof [[PROF1]]
+; CHECK:       7:
 ; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR10]]
 ; CHECK-NEXT:    unreachable
-; CHECK:       10:
-; CHECK-NEXT:    [[RES2:%.*]] = call <16 x i16> @llvm.x86.avx512.mask.vcvtps2ph.512(<16 x float> [[A0]], i32 11, <16 x i16> zeroinitializer, i16 [[MASK:%.*]])
-; CHECK-NEXT:    [[TMP11:%.*]] = bitcast <16 x i32> [[TMP1]] to i512
-; CHECK-NEXT:    [[_MSCMP3:%.*]] = icmp ne i512 [[TMP11]], 0
-; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <16 x i16> [[TMP3]] to i256
-; CHECK-NEXT:    [[_MSCMP4:%.*]] = icmp ne i256 [[TMP12]], 0
-; CHECK-NEXT:    [[_MSOR5:%.*]] = or i1 [[_MSCMP3]], [[_MSCMP4]]
+; CHECK:       8:
+; CHECK-NEXT:    [[RES2:%.*]] = call <16 x i16> @llvm.x86.avx512.mask.vcvtps2ph.512(<16 x float> [[A0]], i32 11, <16 x i16> zeroinitializer, i16 [[MASK]])
+; CHECK-NEXT:    [[TMP25:%.*]] = bitcast i16 [[MASK]] to <16 x i1>
+; CHECK-NEXT:    [[TMP26:%.*]] = icmp ne <16 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP27:%.*]] = sext <16 x i1> [[TMP26]] to <16 x i16>
+; CHECK-NEXT:    [[TMP20:%.*]] = select <16 x i1> [[TMP25]], <16 x i16> [[TMP27]], <16 x i16> [[TMP3]]
 ; CHECK-NEXT:    [[_MSCMP6:%.*]] = icmp ne i16 [[TMP2]], 0
-; CHECK-NEXT:    [[_MSOR7:%.*]] = or i1 [[_MSOR5]], [[_MSCMP6]]
-; CHECK-NEXT:    br i1 [[_MSOR7]], label [[TMP13:%.*]], label [[TMP14:%.*]], !prof [[PROF1]]
-; CHECK:       13:
+; CHECK-NEXT:    br i1 [[_MSCMP6]], label [[TMP15:%.*]], label [[TMP16:%.*]], !prof [[PROF1]]
+; CHECK:       10:
 ; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR10]]
 ; CHECK-NEXT:    unreachable
-; CHECK:       14:
+; CHECK:       11:
 ; CHECK-NEXT:    [[RES3:%.*]] = call <16 x i16> @llvm.x86.avx512.mask.vcvtps2ph.512(<16 x float> [[A0]], i32 12, <16 x i16> [[SRC:%.*]], i16 [[MASK]])
 ; CHECK-NEXT:    [[_MSCMP8:%.*]] = icmp ne i64 [[TMP4]], 0
-; CHECK-NEXT:    br i1 [[_MSCMP8]], label [[TMP15:%.*]], label [[TMP16:%.*]], !prof [[PROF1]]
-; CHECK:       15:
+; CHECK-NEXT:    br i1 [[_MSCMP8]], label [[TMP21:%.*]], label [[TMP22:%.*]], !prof [[PROF1]]
+; CHECK:       12:
 ; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR10]]
 ; CHECK-NEXT:    unreachable
-; CHECK:       16:
+; CHECK:       13:
 ; CHECK-NEXT:    [[TMP17:%.*]] = ptrtoint ptr [[DST:%.*]] to i64
 ; CHECK-NEXT:    [[TMP18:%.*]] = xor i64 [[TMP17]], 87960930222080
 ; CHECK-NEXT:    [[TMP19:%.*]] = inttoptr i64 [[TMP18]] to ptr
-; CHECK-NEXT:    store <16 x i16> zeroinitializer, ptr [[TMP19]], align 32
+; CHECK-NEXT:    store <16 x i16> [[TMP8]], ptr [[TMP19]], align 32
 ; CHECK-NEXT:    store <16 x i16> [[RES1]], ptr [[DST]], align 32
+; CHECK-NEXT:    [[_MSPROP:%.*]] = or <16 x i16> [[TMP13]], [[TMP20]]
 ; CHECK-NEXT:    [[RES:%.*]] = add <16 x i16> [[RES2]], [[RES3]]
-; CHECK-NEXT:    store <16 x i16> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT:    store <16 x i16> [[_MSPROP]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <16 x i16> [[RES]]
 ;
 %res1 = call <16 x i16> @llvm.x86.avx512.mask.vcvtps2ph.512(<16 x float> %a0, i32 2, <16 x i16> zeroinitializer, i16 -1)
@@ -7451,10 +7446,10 @@ define <16 x i32>@test_int_x86_avx512_mask_cvt_ps2dq_512(<16 x float> %x0, <16 x
 ; CHECK-NEXT:    [[TMP6:%.*]] = select <16 x i1> [[TMP3]], <16 x i32> [[TMP5]], <16 x i32> [[TMP2]]
 ; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i16 [[TMP10]], 0
 ; CHECK-NEXT:    br i1 [[_MSCMP]], label [[TMP11:%.*]], label [[TMP12:%.*]], !prof [[PROF1]]
-; CHECK:       8:
+; CHECK:       5:
 ; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR10]]
 ; CHECK-NEXT:    unreachable
-; CHECK:       9:
+; CHECK:       6:
 ; CHECK-NEXT:    [[RES:%.*]] = call <16 x i32> @llvm.x86.avx512.mask.cvtps2dq.512(<16 x float> [[X0:%.*]], <16 x i32> [[X1:%.*]], i16 [[X2]], i32 10)
 ; CHECK-NEXT:    [[TMP7:%.*]] = icmp ne <16 x i32> [[TMP1]], zeroinitializer
 ; CHECK-NEXT:    [[TMP8:%.*]] = sext <16 x i1> [[TMP7]] to <16 x i32>
