@@ -804,26 +804,51 @@ static BranchInst *getExpectedExitLoopLatchBranch(Loop *L) {
804804 return LatchBR;
805805}
806806
807- // / Return the estimated trip count for any exiting branch which dominates
808- // / the loop latch.
809- static std::optional<unsigned > getEstimatedTripCount (BranchInst *ExitingBranch,
810- Loop *L,
811- uint64_t &OrigExitWeight) {
807+ struct DbgLoop {
808+ const Loop *L;
809+ explicit DbgLoop (const Loop *L) : L(L) {}
810+ };
811+
812+ #ifndef NDEBUG
813+ static inline raw_ostream &operator <<(raw_ostream &OS, DbgLoop D) {
814+ OS << " function " ;
815+ D.L ->getHeader ()->getParent ()->printAsOperand (OS, /* PrintType=*/ false );
816+ return OS << " " << *D.L ;
817+ }
818+ #endif // NDEBUG
819+
820+ static std::optional<unsigned > estimateLoopTripCount (Loop *L) {
821+ // Currently we take the estimated exit count only from the loop latch,
822+ // ignoring other exiting blocks. This can overestimate the trip count
823+ // if we exit through another exit, but can never underestimate it.
824+ // TODO: incorporate information from other exits
825+ BranchInst *ExitingBranch = getExpectedExitLoopLatchBranch (L);
826+ if (!ExitingBranch) {
827+ LLVM_DEBUG (dbgs () << " estimateLoopTripCount: Failed to find exiting "
828+ << " latch branch of required form in " << DbgLoop (L)
829+ << " \n " );
830+ return std::nullopt ;
831+ }
832+
812833 // To estimate the number of times the loop body was executed, we want to
813834 // know the number of times the backedge was taken, vs. the number of times
814835 // we exited the loop.
815836 uint64_t LoopWeight, ExitWeight;
816- if (!extractBranchWeights (*ExitingBranch, LoopWeight, ExitWeight))
837+ if (!extractBranchWeights (*ExitingBranch, LoopWeight, ExitWeight)) {
838+ LLVM_DEBUG (dbgs () << " estimateLoopTripCount: Failed to extract branch "
839+ << " weights for " << DbgLoop (L) << " \n " );
817840 return std::nullopt ;
841+ }
818842
819843 if (L->contains (ExitingBranch->getSuccessor (1 )))
820844 std::swap (LoopWeight, ExitWeight);
821845
822- if (!ExitWeight)
846+ if (!ExitWeight) {
823847 // Don't have a way to return predicated infinite
848+ LLVM_DEBUG (dbgs () << " estimateLoopTripCount: Failed because of zero exit "
849+ << " probability for " << DbgLoop (L) << " \n " );
824850 return std::nullopt ;
825-
826- OrigExitWeight = ExitWeight;
851+ }
827852
828853 // Estimated exit count is a ratio of the loop weight by the weight of the
829854 // edge exiting the loop, rounded to nearest.
@@ -834,43 +859,102 @@ static std::optional<unsigned> getEstimatedTripCount(BranchInst *ExitingBranch,
834859 return std::numeric_limits<unsigned >::max ();
835860
836861 // Estimated trip count is one plus estimated exit count.
837- return ExitCount + 1 ;
862+ uint64_t TC = ExitCount + 1 ;
863+ LLVM_DEBUG (dbgs () << " estimateLoopTripCount: Estimated trip count of " << TC
864+ << " for " << DbgLoop (L) << " \n " );
865+ return TC;
838866}
839867
840868std::optional<unsigned >
841869llvm::getLoopEstimatedTripCount (Loop *L,
842870 unsigned *EstimatedLoopInvocationWeight) {
843- // Currently we take the estimate exit count only from the loop latch,
844- // ignoring other exiting blocks. This can overestimate the trip count
845- // if we exit through another exit, but can never underestimate it.
846- // TODO: incorporate information from other exits
847- if (BranchInst *LatchBranch = getExpectedExitLoopLatchBranch (L)) {
848- uint64_t ExitWeight;
849- if (std::optional<uint64_t > EstTripCount =
850- getEstimatedTripCount (LatchBranch, L, ExitWeight)) {
851- if (EstimatedLoopInvocationWeight)
852- *EstimatedLoopInvocationWeight = ExitWeight;
853- return *EstTripCount;
854- }
871+ // If EstimatedLoopInvocationWeight is non-null, we do not support this loop
872+ // if getExpectedExitLoopLatchBranch returns nullptr.
873+ //
874+ // FIXME: Also, this is a stop-gap solution for nested loops. It avoids
875+ // mistaking LLVMLoopEstimatedTripCount metadata to be for an outer loop when
876+ // it was created for an inner loop. The problem is that loop metadata is
877+ // attached to the branch instruction in the loop latch block, but that can be
878+ // shared by the loops. A solution is to attach loop metadata to loop headers
879+ // instead, but that would be a large change to LLVM.
880+ //
881+ // Until that happens, we work around the problem as follows.
882+ // getExpectedExitLoopLatchBranch (which also guards
883+ // setLoopEstimatedTripCount) returns nullptr for a loop unless the loop has
884+ // one latch and that latch has exactly two successors one of which is an exit
885+ // from the loop. If the latch is shared by nested loops, then that condition
886+ // might hold for the inner loop but cannot hold for the outer loop:
887+ // - Because the latch is shared, it must have at least two successors: the
888+ // inner loop header and the outer loop header, which is also an exit for
889+ // the inner loop. That satisfies the condition for the inner loop.
890+ // - To satisfy the condition for the outer loop, the latch must have a third
891+ // successor that is an exit for the outer loop. But that violates the
892+ // condition for both loops.
893+ BranchInst *ExitingBranch = getExpectedExitLoopLatchBranch (L);
894+ if (!ExitingBranch)
895+ return std::nullopt ;
896+
897+ // If requested, either compute *EstimatedLoopInvocationWeight or return
898+ // nullopt if we cannot.
899+ //
900+ // TODO: Eventually, once all passes have migrated away from setting branch
901+ // weights to indicate estimated trip counts, this function will drop the
902+ // EstimatedLoopInvocationWeight parameter.
903+ if (EstimatedLoopInvocationWeight) {
904+ uint64_t LoopWeight = 0 , ExitWeight = 0 ; // Inits expected to be unused.
905+ if (!extractBranchWeights (*ExitingBranch, LoopWeight, ExitWeight))
906+ return std::nullopt ;
907+ if (L->contains (ExitingBranch->getSuccessor (1 )))
908+ std::swap (LoopWeight, ExitWeight);
909+ if (!ExitWeight)
910+ return std::nullopt ;
911+ *EstimatedLoopInvocationWeight = ExitWeight;
855912 }
856- return std::nullopt ;
913+
914+ // Return the estimated trip count from metadata unless the metadata is
915+ // missing or has no value.
916+ if (auto TC = getOptionalIntLoopAttribute (L, LLVMLoopEstimatedTripCount)) {
917+ LLVM_DEBUG (dbgs () << " getLoopEstimatedTripCount: "
918+ << LLVMLoopEstimatedTripCount << " metadata has trip "
919+ << " count of " << *TC << " for " << DbgLoop (L) << " \n " );
920+ return TC;
921+ }
922+
923+ // Estimate the trip count from latch branch weights.
924+ return estimateLoopTripCount (L);
857925}
858926
859- bool llvm::setLoopEstimatedTripCount (Loop *L, unsigned EstimatedTripCount,
860- unsigned EstimatedloopInvocationWeight) {
861- // At the moment, we currently support changing the estimate trip count of
862- // the latch branch only. We could extend this API to manipulate estimated
863- // trip counts for any exit.
927+ bool llvm::setLoopEstimatedTripCount (
928+ Loop *L, unsigned EstimatedTripCount,
929+ std::optional<unsigned > EstimatedloopInvocationWeight) {
930+ // If EstimatedloopInvocationWeight has a value, we do not support this loop
931+ // if getExpectedExitLoopLatchBranch returns nullptr.
932+ //
933+ // FIXME: See comments in getLoopEstimatedTripCount for why this is required
934+ // here regardless of EstimatedLoopInvocationWeight.
864935 BranchInst *LatchBranch = getExpectedExitLoopLatchBranch (L);
865936 if (!LatchBranch)
866937 return false ;
867938
939+ // Set the metadata.
940+ addStringMetadataToLoop (L, LLVMLoopEstimatedTripCount, EstimatedTripCount);
941+
942+ // At the moment, we support changing the estimated trip count in
943+ // the latch branch's branch weights only. We could extend this API to
944+ // manipulate estimated trip counts for any exit.
945+ //
946+ // TODO: Eventually, once all passes have migrated away from setting branch
947+ // weights to indicate estimated trip counts, we will not set branch weights
948+ // here at all.
949+ if (!EstimatedloopInvocationWeight)
950+ return true ;
951+
868952 // Calculate taken and exit weights.
869953 unsigned LatchExitWeight = 0 ;
870954 unsigned BackedgeTakenWeight = 0 ;
871955
872- if (EstimatedTripCount > 0 ) {
873- LatchExitWeight = EstimatedloopInvocationWeight;
956+ if (EstimatedTripCount != 0 ) {
957+ LatchExitWeight = * EstimatedloopInvocationWeight;
874958 BackedgeTakenWeight = (EstimatedTripCount - 1 ) * LatchExitWeight;
875959 }
876960
0 commit comments