diff --git a/llvm/include/llvm/Analysis/CFGPrinter.h b/llvm/include/llvm/Analysis/CFGPrinter.h index ec26da87eb916..aa711642a3a6d 100644 --- a/llvm/include/llvm/Analysis/CFGPrinter.h +++ b/llvm/include/llvm/Analysis/CFGPrinter.h @@ -31,6 +31,9 @@ #include "llvm/Support/DOTGraphTraits.h" #include "llvm/Support/FormatVariadic.h" +#include +#include + namespace llvm { class ModuleSlotTracker; @@ -69,13 +72,18 @@ class DOTFuncInfo { bool ShowHeat; bool EdgeWeights; bool RawWeights; + using NodeIdFormatterTy = + std::function(const BasicBlock *)>; + std::optional NodeIdFormatter; public: DOTFuncInfo(const Function *F) : DOTFuncInfo(F, nullptr, nullptr, 0) {} LLVM_ABI ~DOTFuncInfo(); - LLVM_ABI DOTFuncInfo(const Function *F, const BlockFrequencyInfo *BFI, - const BranchProbabilityInfo *BPI, uint64_t MaxFreq); + LLVM_ABI + DOTFuncInfo(const Function *F, const BlockFrequencyInfo *BFI, + const BranchProbabilityInfo *BPI, uint64_t MaxFreq, + std::optional NodeIdFormatter = std::nullopt); const BlockFrequencyInfo *getBFI() const { return BFI; } @@ -102,6 +110,10 @@ class DOTFuncInfo { void setEdgeWeights(bool EdgeWeights) { this->EdgeWeights = EdgeWeights; } bool showEdgeWeights() { return EdgeWeights; } + + std::optional getNodeIdFormatter() { + return NodeIdFormatter; + } }; template <> @@ -311,21 +323,27 @@ struct DOTGraphTraits : public DefaultDOTGraphTraits { } std::string getNodeAttributes(const BasicBlock *Node, DOTFuncInfo *CFGInfo) { + std::stringstream Attrs; + + if (auto NodeIdFmt = CFGInfo->getNodeIdFormatter()) + if (auto NodeId = (*NodeIdFmt)(Node)) + Attrs << "id=\"" << *NodeId << "\""; + + if (CFGInfo->showHeatColors()) { + uint64_t Freq = CFGInfo->getFreq(Node); + std::string Color = getHeatColor(Freq, CFGInfo->getMaxFreq()); + std::string EdgeColor = (Freq <= (CFGInfo->getMaxFreq() / 2)) + ? (getHeatColor(0)) + : (getHeatColor(1)); + if (!Attrs.str().empty()) + Attrs << ","; + Attrs << "color=\"" << EdgeColor << "ff\", style=filled, " + << "fillcolor=\"" << Color << "70\", " << "fontname=\"Courier\""; + } - if (!CFGInfo->showHeatColors()) - return ""; - - uint64_t Freq = CFGInfo->getFreq(Node); - std::string Color = getHeatColor(Freq, CFGInfo->getMaxFreq()); - std::string EdgeColor = (Freq <= (CFGInfo->getMaxFreq() / 2)) - ? (getHeatColor(0)) - : (getHeatColor(1)); - - std::string Attrs = "color=\"" + EdgeColor + "ff\", style=filled," + - " fillcolor=\"" + Color + "70\"" + - " fontname=\"Courier\""; - return Attrs; + return Attrs.str(); } + LLVM_ABI bool isNodeHidden(const BasicBlock *Node, const DOTFuncInfo *CFGInfo); LLVM_ABI void computeDeoptOrUnreachablePaths(const Function *F); diff --git a/llvm/lib/Analysis/CFGPrinter.cpp b/llvm/lib/Analysis/CFGPrinter.cpp index 38aad849755be..39108a906f081 100644 --- a/llvm/lib/Analysis/CFGPrinter.cpp +++ b/llvm/lib/Analysis/CFGPrinter.cpp @@ -92,8 +92,10 @@ static void viewCFG(Function &F, const BlockFrequencyInfo *BFI, } DOTFuncInfo::DOTFuncInfo(const Function *F, const BlockFrequencyInfo *BFI, - const BranchProbabilityInfo *BPI, uint64_t MaxFreq) - : F(F), BFI(BFI), BPI(BPI), MaxFreq(MaxFreq) { + const BranchProbabilityInfo *BPI, uint64_t MaxFreq, + std::optional NodeIdFormatter) + : F(F), BFI(BFI), BPI(BPI), MaxFreq(MaxFreq), + NodeIdFormatter(NodeIdFormatter) { ShowHeat = false; EdgeWeights = !!BPI; // Print EdgeWeights when BPI is available. RawWeights = !!BFI; // Print RawWeights when BFI is available.