@@ -3846,18 +3846,21 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
     setOriginForNaryOp(I);
   }
 
-  // Instrument multiply-add intrinsics.
+  // Instrument multiply-add(-accumulate)? intrinsics.
   //
   // e.g., Two operands:
   //   <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a, <8 x i16> %b)
   //
   // Two operands which require an EltSizeInBits override:
   //   <1 x i64> @llvm.x86.mmx.pmadd.wd(<1 x i64> %a, <1 x i64> %b)
   //
-  // Three operands are not implemented yet:
+  // Three operands:
   //   <4 x i32> @llvm.x86.avx512.vpdpbusd.128
   //       (<4 x i32> %s, <4 x i32> %a, <4 x i32> %b)
-  //   (the result of multiply-add'ing %a and %b is accumulated with %s)
+  //   (this is equivalent to multiply-add on %a and %b, followed by
+  //    adding/"accumulating" %s. "Accumulation" stores the result in one
+  //    of the source registers, but this accumulate vs. add distinction
+  //    is lost when dealing with LLVM intrinsics.)
   void handleVectorPmaddIntrinsic(IntrinsicInst &I, unsigned ReductionFactor,
                                   unsigned EltSizeInBits = 0) {
     IRBuilder<> IRB(&I);
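
A scalar model of the arithmetic may make the ReductionFactor and EltSizeInBits parameters easier to read. The C++ sketch below is illustrative only: the function names pmadd_wd and vpdpbusd, the fixed 128-bit shapes, and the omission of saturation are simplifications for this note, not code from the patch. The inner reduction loop corresponds to ReductionFactor (2 adjacent products for pmadd.wd, 4 for vpdpbusd), and the narrow element width that the intrinsic's declared <N x i32> operand type hides is what the EltSizeInBits override conveys.

#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Two operands, ReductionFactor == 2 (cf. llvm.x86.sse2.pmadd.wd):
// each i32 output lane is the sum of two adjacent i16 products.
std::array<int32_t, 4> pmadd_wd(const std::array<int16_t, 8> &A,
                                const std::array<int16_t, 8> &B) {
  std::array<int32_t, 4> Out{};
  for (std::size_t i = 0; i < Out.size(); ++i)
    Out[i] = int32_t(A[2 * i]) * B[2 * i] +
             int32_t(A[2 * i + 1]) * B[2 * i + 1];
  return Out;
}

// Three operands, ReductionFactor == 4 (cf. llvm.x86.avx512.vpdpbusd.128):
// each i32 output lane is the accumulator lane %s plus the sum of four
// byte products; the saturating *s variants are not modeled here.
std::array<int32_t, 4> vpdpbusd(const std::array<int32_t, 4> &S,
                                const std::array<uint8_t, 16> &A,
                                const std::array<int8_t, 16> &B) {
  std::array<int32_t, 4> Out{};
  for (std::size_t i = 0; i < Out.size(); ++i) {
    int32_t Sum = S[i];
    for (std::size_t j = 0; j < 4; ++j)
      Sum += int32_t(A[4 * i + j]) * int32_t(B[4 * i + j]);
    Out[i] = Sum;
  }
  return Out;
}

int main() {
  std::array<int16_t, 8> A{1, 2, 3, 4, 5, 6, 7, 8};
  std::array<int16_t, 8> B{1, 1, 1, 1, 1, 1, 1, 1};
  std::printf("pmadd.wd lane 0 = %d\n", int(pmadd_wd(A, B)[0])); // 1*1 + 2*1

  std::array<int32_t, 4> S{100, 200, 300, 400};
  std::array<uint8_t, 16> U;
  std::array<int8_t, 16> V;
  U.fill(2);
  V.fill(3);
  std::printf("vpdpbusd lane 0 = %d\n", int(vpdpbusd(S, U, V)[0])); // 100 + 4*6
  return 0;
}
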
@@ -3866,22 +3869,39 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
         cast<FixedVectorType>(I.getType());
     assert(isa<FixedVectorType>(ReturnType));
 
-    assert(I.arg_size() == 2);
-
     // Vectors A and B, and shadows
-    Value *Va = I.getOperand(0);
-    Value *Vb = I.getOperand(1);
+    Value *Va = nullptr;
+    Value *Vb = nullptr;
+    Value *Sa = nullptr;
+    Value *Sb = nullptr;
 
-    Value *Sa = getShadow(&I, 0);
-    Value *Sb = getShadow(&I, 1);
+    assert(I.arg_size() == 2 || I.arg_size() == 3);
+    if (I.arg_size() == 2) {
+      Va = I.getOperand(0);
+      Vb = I.getOperand(1);
 
-    FixedVectorType *ParamType =
-        cast<FixedVectorType>(I.getArgOperand(0)->getType());
-    assert(ParamType == I.getArgOperand(1)->getType());
+      Sa = getShadow(&I, 0);
+      Sb = getShadow(&I, 1);
+    } else if (I.arg_size() == 3) {
+      // Operand 0 is the accumulator. We will deal with that below.
+      Va = I.getOperand(1);
+      Vb = I.getOperand(2);
+
+      Sa = getShadow(&I, 1);
+      Sb = getShadow(&I, 2);
+    }
+
+    FixedVectorType *ParamType = cast<FixedVectorType>(Va->getType());
+    assert(ParamType == Vb->getType());
 
     assert(ParamType->getPrimitiveSizeInBits() ==
            ReturnType->getPrimitiveSizeInBits());
 
+    if (I.arg_size() == 3) {
+      assert(ParamType == ReturnType);
+      assert(ParamType == I.getArgOperand(0)->getType());
+    }
+
     FixedVectorType *ImplicitReturnType = ReturnType;
     // Step 1: instrument multiplication of corresponding vector elements
     if (EltSizeInBits) {
@@ -3944,10 +3964,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
                          Constant::getNullValue(Horizontal->getType())),
         ImplicitReturnType);
 
-    // For MMX, cast it back to the required fake return type (<1 x i64>).
+    // Cast it back to the required fake return type (<1 x i64>).
     if (EltSizeInBits)
       OutShadow = CreateShadowCast(IRB, OutShadow, getShadowTy(&I));
 
+    // Step 3 (if applicable): instrument accumulator
+    if (I.arg_size() == 3)
+      OutShadow = IRB.CreateOr(OutShadow, getShadow(&I, 0));
+
     setShadow(&I, OutShadow);
     setOriginForNaryOp(I);
   }
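
To summarize the shadow logic above in scalar form: the sketch below renders one output lane of Steps 2 and 3, assuming Step 1 has already produced a per-lane value Horizontal that is nonzero iff any contributing product carried shadow. It is a conceptual model under that assumption, not the IR the pass emits, and LaneShadow and its parameters are hypothetical names.

#include <cstdint>
#include <cstdio>

// Hypothetical scalar model of one result lane's shadow (all-ones == fully
// uninitialized, zero == fully initialized).
int32_t LaneShadow(int32_t Horizontal, int32_t AccumulatorShadow,
                   bool HasAccumulator) {
  // Step 2: "icmp ne 0" followed by sign-extension -- the whole lane is
  // poisoned if anything feeding it was poisoned.
  int32_t OutShadow = (Horizontal != 0) ? -1 : 0;

  // Step 3 (three-operand forms only): OR in the accumulator's shadow lane,
  // mirroring IRB.CreateOr(OutShadow, getShadow(&I, 0)) in the patch.
  if (HasAccumulator)
    OutShadow |= AccumulatorShadow;
  return OutShadow;
}

int main() {
  // Clean products, but one uninitialized bit in the accumulator lane:
  // that bit alone propagates into the result lane's shadow.
  std::printf("%08x\n",
              unsigned(LaneShadow(/*Horizontal=*/0,
                                  /*AccumulatorShadow=*/0x00000010,
                                  /*HasAccumulator=*/true)));
  return 0;
}
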
@@ -5525,6 +5549,143 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
       handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/2, /*EltSize=*/16);
       break;
 
+    // AVX Vector Neural Network Instructions: bytes
+    //
+    // Multiply and Add Packed Signed and Unsigned Bytes
+    //   < 4 x i32> @llvm.x86.avx512.vpdpbusd.128
+    //       (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx512.vpdpbusd.256
+    //       (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //   <16 x i32> @llvm.x86.avx512.vpdpbusd.512
+    //       (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    // Multiply and Add Unsigned and Signed Bytes With Saturation
+    //   < 4 x i32> @llvm.x86.avx512.vpdpbusds.128
+    //       (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx512.vpdpbusds.256
+    //       (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //   <16 x i32> @llvm.x86.avx512.vpdpbusds.512
+    //       (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    //   < 4 x i32> @llvm.x86.avx2.vpdpbssd.128
+    //       (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx2.vpdpbssd.256
+    //       (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //
+    //   < 4 x i32> @llvm.x86.avx2.vpdpbssds.128
+    //       (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx2.vpdpbssds.256
+    //       (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //
+    //   <16 x i32> @llvm.x86.avx10.vpdpbssd.512
+    //       (<16 x i32>, <16 x i32>, <16 x i32>)
+    //   <16 x i32> @llvm.x86.avx10.vpdpbssds.512
+    //       (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    // These intrinsics are auto-upgraded into non-masked forms:
+    //   <4 x i32> @llvm.x86.avx512.mask.vpdpbusd.128
+    //       (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <4 x i32> @llvm.x86.avx512.maskz.vpdpbusd.128
+    //       (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.mask.vpdpbusd.256
+    //       (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.maskz.vpdpbusd.256
+    //       (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <16 x i32> @llvm.x86.avx512.mask.vpdpbusd.512
+    //       (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //   <16 x i32> @llvm.x86.avx512.maskz.vpdpbusd.512
+    //       (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //
+    //   <4 x i32> @llvm.x86.avx512.mask.vpdpbusds.128
+    //       (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <4 x i32> @llvm.x86.avx512.maskz.vpdpbusds.128
+    //       (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.mask.vpdpbusds.256
+    //       (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.maskz.vpdpbusds.256
+    //       (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <16 x i32> @llvm.x86.avx512.mask.vpdpbusds.512
+    //       (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //   <16 x i32> @llvm.x86.avx512.maskz.vpdpbusds.512
+    //       (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    case Intrinsic::x86_avx512_vpdpbusd_128:
+    case Intrinsic::x86_avx512_vpdpbusd_256:
+    case Intrinsic::x86_avx512_vpdpbusd_512:
+    case Intrinsic::x86_avx512_vpdpbusds_128:
+    case Intrinsic::x86_avx512_vpdpbusds_256:
+    case Intrinsic::x86_avx512_vpdpbusds_512:
+    case Intrinsic::x86_avx2_vpdpbssd_128:
+    case Intrinsic::x86_avx2_vpdpbssd_256:
+    case Intrinsic::x86_avx2_vpdpbssds_128:
+    case Intrinsic::x86_avx2_vpdpbssds_256:
+    case Intrinsic::x86_avx10_vpdpbssd_512:
+    case Intrinsic::x86_avx10_vpdpbssds_512:
+      handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/4, /*EltSize=*/8);
+      break;
+
+    // AVX Vector Neural Network Instructions: words
+    //
+    // Multiply and Add Signed Word Integers
+    //   < 4 x i32> @llvm.x86.avx512.vpdpwssd.128
+    //       (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx512.vpdpwssd.256
+    //       (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //   <16 x i32> @llvm.x86.avx512.vpdpwssd.512
+    //       (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    // Multiply and Add Signed Word Integers With Saturation
+    //   < 4 x i32> @llvm.x86.avx512.vpdpwssds.128
+    //       (< 4 x i32>, < 4 x i32>, < 4 x i32>)
+    //   < 8 x i32> @llvm.x86.avx512.vpdpwssds.256
+    //       (< 8 x i32>, < 8 x i32>, < 8 x i32>)
+    //   <16 x i32> @llvm.x86.avx512.vpdpwssds.512
+    //       (<16 x i32>, <16 x i32>, <16 x i32>)
+    //
+    // These intrinsics are auto-upgraded into non-masked forms:
+    //   <4 x i32> @llvm.x86.avx512.mask.vpdpwssd.128
+    //       (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <4 x i32> @llvm.x86.avx512.maskz.vpdpwssd.128
+    //       (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.mask.vpdpwssd.256
+    //       (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.maskz.vpdpwssd.256
+    //       (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <16 x i32> @llvm.x86.avx512.mask.vpdpwssd.512
+    //       (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //   <16 x i32> @llvm.x86.avx512.maskz.vpdpwssd.512
+    //       (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //
+    //   <4 x i32> @llvm.x86.avx512.mask.vpdpwssds.128
+    //       (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <4 x i32> @llvm.x86.avx512.maskz.vpdpwssds.128
+    //       (<4 x i32>, <4 x i32>, <4 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.mask.vpdpwssds.256
+    //       (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <8 x i32> @llvm.x86.avx512.maskz.vpdpwssds.256
+    //       (<8 x i32>, <8 x i32>, <8 x i32>, i8)
+    //   <16 x i32> @llvm.x86.avx512.mask.vpdpwssds.512
+    //       (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    //   <16 x i32> @llvm.x86.avx512.maskz.vpdpwssds.512
+    //       (<16 x i32>, <16 x i32>, <16 x i32>, i16)
+    case Intrinsic::x86_avx512_vpdpwssd_128:
+    case Intrinsic::x86_avx512_vpdpwssd_256:
+    case Intrinsic::x86_avx512_vpdpwssd_512:
+    case Intrinsic::x86_avx512_vpdpwssds_128:
+    case Intrinsic::x86_avx512_vpdpwssds_256:
+    case Intrinsic::x86_avx512_vpdpwssds_512:
+      handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/2, /*EltSize=*/16);
+      break;
+
+    // TODO: Dot Product of BF16 Pairs Accumulated Into Packed Single
+    //       Precision
+    //   <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128
+    //       (<4 x float>, <8 x bfloat>, <8 x bfloat>)
+    //   <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256
+    //       (<8 x float>, <16 x bfloat>, <16 x bfloat>)
+    //   <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512
+    //       (<16 x float>, <32 x bfloat>, <32 x bfloat>)
+    // handleVectorPmaddIntrinsic() currently only handles integer types.
+
     case Intrinsic::x86_sse_cmp_ss:
     case Intrinsic::x86_sse2_cmp_sd:
     case Intrinsic::x86_sse_comieq_ss: