@@ -446,6 +446,7 @@ static bool allCallersPassValidPointerForArgument(Argument *Arg,
446446// / parts it can be promoted into.
447447static bool findArgParts (Argument *Arg, const DataLayout &DL, AAResults &AAR,
448448 unsigned MaxElements, bool IsRecursive,
449+ bool IsSelfRecursive,
449450 SmallVectorImpl<OffsetAndArgPart> &ArgPartsVec) {
450451 // Quick exit for unused arguments
451452 if (Arg->use_empty ())
@@ -610,13 +611,61 @@ static bool findArgParts(Argument *Arg, const DataLayout &DL, AAResults &AAR,
610611 // unknown users
611612 }
612613
614+ auto *CB = dyn_cast<CallBase>(V);
615+ Value *PtrArg = dyn_cast<Value>(U);
616+ if (IsSelfRecursive && CB && PtrArg) {
617+ Type *PtrTy = PtrArg->getType ();
618+ Align PtrAlign = PtrArg->getPointerAlignment (DL);
619+ APInt Offset (DL.getIndexTypeSizeInBits (PtrArg->getType ()), 0 );
620+ PtrArg = PtrArg->stripAndAccumulateConstantOffsets (
621+ DL, Offset,
622+ /* AllowNonInbounds= */ true );
623+ if (PtrArg != Arg)
624+ return false ;
625+
626+ if (Offset.getSignificantBits () >= 64 )
627+ return false ;
628+
629+ int64_t Off = Offset.getSExtValue ();
630+ auto Pair = ArgParts.try_emplace (Off, ArgPart{PtrTy, PtrAlign, nullptr });
631+ ArgPart &Part = Pair.first ->second ;
632+
633+ // We limit promotion to only promoting up to a fixed number of elements
634+ // of the aggregate.
635+ if (MaxElements > 0 && ArgParts.size () > MaxElements) {
636+ LLVM_DEBUG (dbgs () << " ArgPromotion of " << *Arg << " failed: "
637+ << " more than " << MaxElements << " parts\n " );
638+ return false ;
639+ }
640+
641+ Part.Alignment = std::max (Part.Alignment , PtrAlign);
642+ continue ;
643+ }
613644 // Unknown user.
614645 LLVM_DEBUG (dbgs () << " ArgPromotion of " << *Arg << " failed: "
615646 << " unknown user " << *V << " \n " );
616647 return false ;
617648 }
618649
619- if (NeededDerefBytes || NeededAlign > 1 ) {
650+ // In case of functions with recursive calls, this check will fail when it
651+ // tries to look at the first caller of this function. The caller may or may
652+ // not have a load; in case it doesn't load the pointer being passed, this
653+ // check will fail. So, it's safe to skip the check in case we know that we
654+ // are dealing with a recursive call.
655+ //
656+ // def fun(ptr %a) {
657+ // ...
658+ // %loadres = load i32, ptr %a, align 4
659+ // %res = call i32 @fun(ptr %a)
660+ // ...
661+ // }
662+ //
663+ // def bar(ptr %x) {
664+ // ...
665+ // %resbar = call i32 @fun(ptr %x)
666+ // ...
667+ // }
668+ if (!IsRecursive && (NeededDerefBytes || NeededAlign > 1 )) {
620669 // Try to prove a required deref / aligned requirement.
621670 if (!allCallersPassValidPointerForArgument (Arg, NeededAlign,
622671 NeededDerefBytes)) {
@@ -699,6 +748,10 @@ static bool areTypesABICompatible(ArrayRef<Type *> Types, const Function &F,
699748// / calls the DoPromotion method.
700749static Function *promoteArguments (Function *F, FunctionAnalysisManager &FAM,
701750 unsigned MaxElements, bool IsRecursive) {
751+ // Due to the complexity of handling cases where the SCC has more than one
752+ // component, we want to limit argument promotion of recursive calls to
753+ // just functions that directly call themselves.
754+ bool IsSelfRecursive = false ;
702755 // Don't perform argument promotion for naked functions; otherwise we can end
703756 // up removing parameters that are seemingly 'not used' as they are referred
704757 // to in the assembly.
@@ -744,8 +797,10 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
744797 if (CB->isMustTailCall ())
745798 return nullptr ;
746799
747- if (CB->getFunction () == F)
800+ if (CB->getFunction () == F) {
748801 IsRecursive = true ;
802+ IsSelfRecursive = true ;
803+ }
749804 }
750805
751806 // Can't change signature of musttail caller
@@ -779,7 +834,8 @@ static Function *promoteArguments(Function *F, FunctionAnalysisManager &FAM,
779834 // If we can promote the pointer to its value.
780835 SmallVector<OffsetAndArgPart, 4 > ArgParts;
781836
782- if (findArgParts (PtrArg, DL, AAR, MaxElements, IsRecursive, ArgParts)) {
837+ if (findArgParts (PtrArg, DL, AAR, MaxElements, IsRecursive, IsSelfRecursive,
838+ ArgParts)) {
783839 SmallVector<Type *, 4 > Types;
784840 for (const auto &Pair : ArgParts)
785841 Types.push_back (Pair.second .Ty );
0 commit comments