@@ -117,8 +117,8 @@ class MemCmpExpansion {
117117 Value *Lhs = nullptr ;
118118 Value *Rhs = nullptr ;
119119 };
120- LoadPair getLoadPair (Type *LoadSizeType, bool NeedsBSwap, Type *CmpSizeType ,
121- unsigned OffsetBytes);
120+ LoadPair getLoadPair (Type *LoadSizeType, Type *BSwapSizeType ,
121+ Type *CmpSizeType, unsigned OffsetBytes);
122122
123123 static LoadEntryVector
124124 computeGreedyLoadSequence (uint64_t Size, llvm::ArrayRef<unsigned > LoadSizes,
@@ -128,6 +128,11 @@ class MemCmpExpansion {
128128 unsigned MaxNumLoads,
129129 unsigned &NumLoadsNonOneByte);
130130
131+ static void optimiseLoadSequence (
132+ LoadEntryVector &LoadSequence,
133+ const TargetTransformInfo::MemCmpExpansionOptions &Options,
134+ bool IsUsedForZeroCmp);
135+
131136public:
132137 MemCmpExpansion (CallInst *CI, uint64_t Size,
133138 const TargetTransformInfo::MemCmpExpansionOptions &Options,
@@ -210,6 +215,37 @@ MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size,
210215 return LoadSequence;
211216}
212217
218+ void MemCmpExpansion::optimiseLoadSequence (
219+ LoadEntryVector &LoadSequence,
220+ const TargetTransformInfo::MemCmpExpansionOptions &Options,
221+ bool IsUsedForZeroCmp) {
222+ // This part of code attempts to optimize the LoadSequence by merging allowed
223+ // subsequences into single loads of allowed sizes from
224+ // `MemCmpExpansionOptions::AllowedTailExpansions`. If it is for zero
225+ // comparison or if no allowed tail expansions are specified, we exit early.
226+ if (IsUsedForZeroCmp || Options.AllowedTailExpansions .empty ())
227+ return ;
228+
229+ while (LoadSequence.size () >= 2 ) {
230+ auto Last = LoadSequence[LoadSequence.size () - 1 ];
231+ auto PreLast = LoadSequence[LoadSequence.size () - 2 ];
232+
233+ // Exit the loop if the two sequences are not contiguous
234+ if (PreLast.Offset + PreLast.LoadSize != Last.Offset )
235+ break ;
236+
237+ auto LoadSize = Last.LoadSize + PreLast.LoadSize ;
238+ if (find (Options.AllowedTailExpansions , LoadSize) ==
239+ Options.AllowedTailExpansions .end ())
240+ break ;
241+
242+ // Remove the last two sequences and replace with the combined sequence
243+ LoadSequence.pop_back ();
244+ LoadSequence.pop_back ();
245+ LoadSequence.emplace_back (PreLast.Offset , LoadSize);
246+ }
247+ }
248+
213249// Initialize the basic block structure required for expansion of memcmp call
214250// with given maximum load size and memcmp size parameter.
215251// This structure includes:
@@ -255,6 +291,7 @@ MemCmpExpansion::MemCmpExpansion(
255291 }
256292 }
257293 assert (LoadSequence.size () <= Options.MaxNumLoads && " broken invariant" );
294+ optimiseLoadSequence (LoadSequence, Options, IsUsedForZeroCmp);
258295}
259296
260297unsigned MemCmpExpansion::getNumBlocks () {
@@ -278,7 +315,7 @@ void MemCmpExpansion::createResultBlock() {
278315}
279316
280317MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair (Type *LoadSizeType,
281- bool NeedsBSwap ,
318+ Type *BSwapSizeType ,
282319 Type *CmpSizeType,
283320 unsigned OffsetBytes) {
284321 // Get the memory source at offset `OffsetBytes`.
@@ -307,16 +344,22 @@ MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
307344 if (!Rhs)
308345 Rhs = Builder.CreateAlignedLoad (LoadSizeType, RhsSource, RhsAlign);
309346
347+ // Zero extend if Byte Swap intrinsic has different type
348+ if (BSwapSizeType && LoadSizeType != BSwapSizeType) {
349+ Lhs = Builder.CreateZExt (Lhs, BSwapSizeType);
350+ Rhs = Builder.CreateZExt (Rhs, BSwapSizeType);
351+ }
352+
310353 // Swap bytes if required.
311- if (NeedsBSwap ) {
312- Function *Bswap = Intrinsic::getDeclaration (CI-> getModule (),
313- Intrinsic::bswap, LoadSizeType );
354+ if (BSwapSizeType ) {
355+ Function *Bswap = Intrinsic::getDeclaration (
356+ CI-> getModule (), Intrinsic::bswap, BSwapSizeType );
314357 Lhs = Builder.CreateCall (Bswap, Lhs);
315358 Rhs = Builder.CreateCall (Bswap, Rhs);
316359 }
317360
318361 // Zero extend if required.
319- if (CmpSizeType != nullptr && CmpSizeType != LoadSizeType ) {
362+ if (CmpSizeType != nullptr && CmpSizeType != Lhs-> getType () ) {
320363 Lhs = Builder.CreateZExt (Lhs, CmpSizeType);
321364 Rhs = Builder.CreateZExt (Rhs, CmpSizeType);
322365 }
@@ -332,7 +375,7 @@ void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
332375 BasicBlock *BB = LoadCmpBlocks[BlockIndex];
333376 Builder.SetInsertPoint (BB);
334377 const LoadPair Loads =
335- getLoadPair (Type::getInt8Ty (CI->getContext ()), /* NeedsBSwap= */ false ,
378+ getLoadPair (Type::getInt8Ty (CI->getContext ()), nullptr ,
336379 Type::getInt32Ty (CI->getContext ()), OffsetBytes);
337380 Value *Diff = Builder.CreateSub (Loads.Lhs , Loads.Rhs );
338381
@@ -385,11 +428,12 @@ Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
385428 IntegerType *const MaxLoadType =
386429 NumLoads == 1 ? nullptr
387430 : IntegerType::get (CI->getContext (), MaxLoadSize * 8 );
431+
388432 for (unsigned i = 0 ; i < NumLoads; ++i, ++LoadIndex) {
389433 const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
390434 const LoadPair Loads = getLoadPair (
391- IntegerType::get (CI->getContext (), CurLoadEntry.LoadSize * 8 ),
392- /* NeedsBSwap= */ false , MaxLoadType, CurLoadEntry.Offset );
435+ IntegerType::get (CI->getContext (), CurLoadEntry.LoadSize * 8 ), nullptr ,
436+ MaxLoadType, CurLoadEntry.Offset );
393437
394438 if (NumLoads != 1 ) {
395439 // If we have multiple loads per block, we need to generate a composite
@@ -475,14 +519,20 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
475519
476520 Type *LoadSizeType =
477521 IntegerType::get (CI->getContext (), CurLoadEntry.LoadSize * 8 );
478- Type *MaxLoadType = IntegerType::get (CI->getContext (), MaxLoadSize * 8 );
522+ Type *BSwapSizeType =
523+ DL.isLittleEndian ()
524+ ? IntegerType::get (CI->getContext (),
525+ PowerOf2Ceil (CurLoadEntry.LoadSize * 8 ))
526+ : nullptr ;
527+ Type *MaxLoadType = IntegerType::get (
528+ CI->getContext (),
529+ std::max (MaxLoadSize, (unsigned )PowerOf2Ceil (CurLoadEntry.LoadSize )) * 8 );
479530 assert (CurLoadEntry.LoadSize <= MaxLoadSize && " Unexpected load type" );
480531
481532 Builder.SetInsertPoint (LoadCmpBlocks[BlockIndex]);
482533
483- const LoadPair Loads =
484- getLoadPair (LoadSizeType, /* NeedsBSwap=*/ DL.isLittleEndian (), MaxLoadType,
485- CurLoadEntry.Offset );
534+ const LoadPair Loads = getLoadPair (LoadSizeType, BSwapSizeType, MaxLoadType,
535+ CurLoadEntry.Offset );
486536
487537 // Add the loaded values to the phi nodes for calculating memcmp result only
488538 // if result is not used in a zero equality.
@@ -587,19 +637,24 @@ Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
587637// / A memcmp expansion that only has one block of load and compare can bypass
588638// / the compare, branch, and phi IR that is required in the general case.
589639Value *MemCmpExpansion::getMemCmpOneBlock () {
590- Type *LoadSizeType = IntegerType::get (CI->getContext (), Size * 8 );
591640 bool NeedsBSwap = DL.isLittleEndian () && Size != 1 ;
641+ Type *LoadSizeType = IntegerType::get (CI->getContext (), Size * 8 );
642+ Type *BSwapSizeType =
643+ NeedsBSwap ? IntegerType::get (CI->getContext (), PowerOf2Ceil (Size * 8 ))
644+ : nullptr ;
645+ Type *MaxLoadType =
646+ IntegerType::get (CI->getContext (),
647+ std::max (MaxLoadSize, (unsigned )PowerOf2Ceil (Size)) * 8 );
592648
593649 // The i8 and i16 cases don't need compares. We zext the loaded values and
594650 // subtract them to get the suitable negative, zero, or positive i32 result.
595651 if (Size < 4 ) {
596- const LoadPair Loads =
597- getLoadPair (LoadSizeType, NeedsBSwap, Builder.getInt32Ty (),
598- /* Offset*/ 0 );
652+ const LoadPair Loads = getLoadPair (LoadSizeType, BSwapSizeType,
653+ Builder.getInt32Ty (), /* Offset*/ 0 );
599654 return Builder.CreateSub (Loads.Lhs , Loads.Rhs );
600655 }
601656
602- const LoadPair Loads = getLoadPair (LoadSizeType, NeedsBSwap, LoadSizeType ,
657+ const LoadPair Loads = getLoadPair (LoadSizeType, BSwapSizeType, MaxLoadType ,
603658 /* Offset*/ 0 );
604659 // The result of memcmp is negative, zero, or positive, so produce that by
605660 // subtracting 2 extended compare bits: sub (ugt, ult).
0 commit comments