diff --git a/llvm/include/llvm/ProfileData/MemProf.h b/llvm/include/llvm/ProfileData/MemProf.h index d378c3696f8d0..8b00faf2a219d 100644 --- a/llvm/include/llvm/ProfileData/MemProf.h +++ b/llvm/include/llvm/ProfileData/MemProf.h @@ -737,6 +737,64 @@ class CallStackLookupTrait { // Compute a CallStackId for a given call stack. CallStackId hashCallStack(ArrayRef CS); +namespace detail { +// "Dereference" the iterator from DenseMap or OnDiskChainedHashTable. We have +// to do so in one of two different ways depending on the type of the hash +// table. +template +value_type DerefIterator(IterTy Iter) { + using deref_type = llvm::remove_cvref_t; + if constexpr (std::is_same_v) + return *Iter; + else + return Iter->second; +} +} // namespace detail + +// A function object that returns a frame for a given FrameId. +template struct FrameIdConverter { + std::optional LastUnmappedId; + MapTy ⤅ + + FrameIdConverter() = delete; + FrameIdConverter(MapTy &Map) : Map(Map) {} + + Frame operator()(FrameId Id) { + auto Iter = Map.find(Id); + if (Iter == Map.end()) { + LastUnmappedId = Id; + return Frame(0, 0, 0, false); + } + return detail::DerefIterator(Iter); + } +}; + +// A function object that returns a call stack for a given CallStackId. +template struct CallStackIdConverter { + std::optional LastUnmappedId; + MapTy ⤅ + std::function FrameIdToFrame; + + CallStackIdConverter() = delete; + CallStackIdConverter(MapTy &Map, std::function FrameIdToFrame) + : Map(Map), FrameIdToFrame(FrameIdToFrame) {} + + llvm::SmallVector operator()(CallStackId CSId) { + llvm::SmallVector Frames; + auto CSIter = Map.find(CSId); + if (CSIter == Map.end()) { + LastUnmappedId = CSId; + } else { + llvm::SmallVector CS = + detail::DerefIterator>(CSIter); + Frames.reserve(CS.size()); + for (FrameId Id : CS) + Frames.push_back(FrameIdToFrame(Id)); + } + return Frames; + } +}; + // Verify that each CallStackId is computed with hashCallStack. This function // is intended to help transition from CallStack to CSId in // IndexedAllocationInfo. diff --git a/llvm/include/llvm/ProfileData/MemProfReader.h b/llvm/include/llvm/ProfileData/MemProfReader.h index 444c58e8bdc8b..b42e4f5977740 100644 --- a/llvm/include/llvm/ProfileData/MemProfReader.h +++ b/llvm/include/llvm/ProfileData/MemProfReader.h @@ -76,20 +76,16 @@ class MemProfReader { Callback = std::bind(&MemProfReader::idToFrame, this, std::placeholders::_1); - auto CallStackCallback = [&](CallStackId CSId) { - llvm::SmallVector CallStack; - auto Iter = CSIdToCallStack.find(CSId); - assert(Iter != CSIdToCallStack.end()); - for (FrameId Id : Iter->second) - CallStack.push_back(Callback(Id)); - return CallStack; - }; + memprof::CallStackIdConverter CSIdConv( + CSIdToCallStack, Callback); const IndexedMemProfRecord &IndexedRecord = Iter->second; GuidRecord = { Iter->first, - IndexedRecord.toMemProfRecord(CallStackCallback), + IndexedRecord.toMemProfRecord(CSIdConv), }; + if (CSIdConv.LastUnmappedId) + return make_error(instrprof_error::hash_mismatch); Iter++; return Error::success(); } diff --git a/llvm/lib/ProfileData/InstrProfReader.cpp b/llvm/lib/ProfileData/InstrProfReader.cpp index cefb6af12d002..440be2f255d39 100644 --- a/llvm/lib/ProfileData/InstrProfReader.cpp +++ b/llvm/lib/ProfileData/InstrProfReader.cpp @@ -1520,53 +1520,35 @@ IndexedMemProfReader::getMemProfRecord(const uint64_t FuncNameHash) const { // Setup a callback to convert from frame ids to frame using the on-disk // FrameData hash table. - std::optional LastUnmappedFrameId; - auto IdToFrameCallback = [&](const memprof::FrameId Id) { - auto FrIter = MemProfFrameTable->find(Id); - if (FrIter == MemProfFrameTable->end()) { - LastUnmappedFrameId = Id; - return memprof::Frame(0, 0, 0, false); - } - return *FrIter; - }; + memprof::FrameIdConverter FrameIdConv( + *MemProfFrameTable.get()); // Setup a callback to convert call stack ids to call stacks using the on-disk // hash table. - std::optional LastUnmappedCSId; - auto CSIdToCallStackCallback = [&](memprof::CallStackId CSId) { - llvm::SmallVector Frames; - auto CSIter = MemProfCallStackTable->find(CSId); - if (CSIter == MemProfCallStackTable->end()) { - LastUnmappedCSId = CSId; - } else { - const llvm::SmallVector &CS = *CSIter; - Frames.reserve(CS.size()); - for (memprof::FrameId Id : CS) - Frames.push_back(IdToFrameCallback(Id)); - } - return Frames; - }; + memprof::CallStackIdConverter CSIdConv( + *MemProfCallStackTable.get(), FrameIdConv); const memprof::IndexedMemProfRecord IndexedRecord = *Iter; memprof::MemProfRecord Record; if (MemProfCallStackTable) - Record = IndexedRecord.toMemProfRecord(CSIdToCallStackCallback); + Record = IndexedRecord.toMemProfRecord(CSIdConv); else - Record = memprof::MemProfRecord(IndexedRecord, IdToFrameCallback); + Record = memprof::MemProfRecord(IndexedRecord, FrameIdConv); // Check that all frame ids were successfully converted to frames. - if (LastUnmappedFrameId) { - return make_error(instrprof_error::hash_mismatch, - "memprof frame not found for frame id " + - Twine(*LastUnmappedFrameId)); + if (FrameIdConv.LastUnmappedId) { + return make_error( + instrprof_error::hash_mismatch, + "memprof frame not found for frame id " + + Twine(*FrameIdConv.LastUnmappedId)); } // Check that all call stack ids were successfully converted to call stacks. - if (LastUnmappedCSId) { + if (CSIdConv.LastUnmappedId) { return make_error( instrprof_error::hash_mismatch, "memprof call stack not found for call stack id " + - Twine(*LastUnmappedCSId)); + Twine(*CSIdConv.LastUnmappedId)); } return Record; } diff --git a/llvm/unittests/ProfileData/InstrProfTest.cpp b/llvm/unittests/ProfileData/InstrProfTest.cpp index edc427dcbc454..acc633de11b6b 100644 --- a/llvm/unittests/ProfileData/InstrProfTest.cpp +++ b/llvm/unittests/ProfileData/InstrProfTest.cpp @@ -495,44 +495,6 @@ TEST_F(InstrProfTest, test_memprof_v0) { EXPECT_THAT(WantRecord, EqualsRecord(Record)); } -struct CallStackIdConverter { - std::optional LastUnmappedFrameId; - std::optional LastUnmappedCSId; - - const FrameIdMapTy &IdToFrameMap; - const CallStackIdMapTy &CSIdToCallStackMap; - - CallStackIdConverter() = delete; - CallStackIdConverter(const FrameIdMapTy &IdToFrameMap, - const CallStackIdMapTy &CSIdToCallStackMap) - : IdToFrameMap(IdToFrameMap), CSIdToCallStackMap(CSIdToCallStackMap) {} - - llvm::SmallVector - operator()(::llvm::memprof::CallStackId CSId) { - auto IdToFrameCallback = [&](const memprof::FrameId Id) { - auto Iter = IdToFrameMap.find(Id); - if (Iter == IdToFrameMap.end()) { - LastUnmappedFrameId = Id; - return memprof::Frame(0, 0, 0, false); - } - return Iter->second; - }; - - llvm::SmallVector Frames; - auto CSIter = CSIdToCallStackMap.find(CSId); - if (CSIter == CSIdToCallStackMap.end()) { - LastUnmappedCSId = CSId; - } else { - const ::llvm::SmallVector<::llvm::memprof::FrameId> &CS = - CSIter->getSecond(); - Frames.reserve(CS.size()); - for (::llvm::memprof::FrameId Id : CS) - Frames.push_back(IdToFrameCallback(Id)); - } - return Frames; - } -}; - TEST_F(InstrProfTest, test_memprof_v2_full_schema) { const MemInfoBlock MIB = makeFullMIB(); @@ -562,14 +524,16 @@ TEST_F(InstrProfTest, test_memprof_v2_full_schema) { ASSERT_THAT_ERROR(RecordOr.takeError(), Succeeded()); const memprof::MemProfRecord &Record = RecordOr.get(); - CallStackIdConverter CSIdConv(IdToFrameMap, CSIdToCallStackMap); + memprof::FrameIdConverter FrameIdConv(IdToFrameMap); + memprof::CallStackIdConverter CSIdConv( + CSIdToCallStackMap, FrameIdConv); const ::llvm::memprof::MemProfRecord WantRecord = IndexedMR.toMemProfRecord(CSIdConv); - ASSERT_EQ(CSIdConv.LastUnmappedFrameId, std::nullopt) - << "could not map frame id: " << *CSIdConv.LastUnmappedFrameId; - ASSERT_EQ(CSIdConv.LastUnmappedCSId, std::nullopt) - << "could not map call stack id: " << *CSIdConv.LastUnmappedCSId; + ASSERT_EQ(FrameIdConv.LastUnmappedId, std::nullopt) + << "could not map frame id: " << *FrameIdConv.LastUnmappedId; + ASSERT_EQ(CSIdConv.LastUnmappedId, std::nullopt) + << "could not map call stack id: " << *CSIdConv.LastUnmappedId; EXPECT_THAT(WantRecord, EqualsRecord(Record)); } @@ -602,14 +566,16 @@ TEST_F(InstrProfTest, test_memprof_v2_partial_schema) { ASSERT_THAT_ERROR(RecordOr.takeError(), Succeeded()); const memprof::MemProfRecord &Record = RecordOr.get(); - CallStackIdConverter CSIdConv(IdToFrameMap, CSIdToCallStackMap); + memprof::FrameIdConverter FrameIdConv(IdToFrameMap); + memprof::CallStackIdConverter CSIdConv( + CSIdToCallStackMap, FrameIdConv); const ::llvm::memprof::MemProfRecord WantRecord = IndexedMR.toMemProfRecord(CSIdConv); - ASSERT_EQ(CSIdConv.LastUnmappedFrameId, std::nullopt) - << "could not map frame id: " << *CSIdConv.LastUnmappedFrameId; - ASSERT_EQ(CSIdConv.LastUnmappedCSId, std::nullopt) - << "could not map call stack id: " << *CSIdConv.LastUnmappedCSId; + ASSERT_EQ(FrameIdConv.LastUnmappedId, std::nullopt) + << "could not map frame id: " << *FrameIdConv.LastUnmappedId; + ASSERT_EQ(CSIdConv.LastUnmappedId, std::nullopt) + << "could not map call stack id: " << *CSIdConv.LastUnmappedId; EXPECT_THAT(WantRecord, EqualsRecord(Record)); } diff --git a/llvm/unittests/ProfileData/MemProfTest.cpp b/llvm/unittests/ProfileData/MemProfTest.cpp index 98dacd3511e1d..d031049cea14b 100644 --- a/llvm/unittests/ProfileData/MemProfTest.cpp +++ b/llvm/unittests/ProfileData/MemProfTest.cpp @@ -502,37 +502,15 @@ TEST(MemProf, IndexedMemProfRecordToMemProfRecord) { IndexedRecord.CallSiteIds.push_back(llvm::memprof::hashCallStack(CS3)); IndexedRecord.CallSiteIds.push_back(llvm::memprof::hashCallStack(CS4)); - bool CSIdMissing = false; - bool FrameIdMissing = false; - - auto Callback = [&](CallStackId CSId) -> llvm::SmallVector { - llvm::SmallVector CallStack; - llvm::SmallVector FrameIds; - - auto Iter = CallStackIdMap.find(CSId); - if (Iter == CallStackIdMap.end()) - CSIdMissing = true; - else - FrameIds = Iter->second; - - for (FrameId Id : FrameIds) { - Frame F(0, 0, 0, false); - auto Iter = FrameIdMap.find(Id); - if (Iter == FrameIdMap.end()) - FrameIdMissing = true; - else - F = Iter->second; - CallStack.push_back(F); - } - - return CallStack; - }; - - MemProfRecord Record = IndexedRecord.toMemProfRecord(Callback); + llvm::memprof::FrameIdConverter FrameIdConv(FrameIdMap); + llvm::memprof::CallStackIdConverter CSIdConv( + CallStackIdMap, FrameIdConv); + + MemProfRecord Record = IndexedRecord.toMemProfRecord(CSIdConv); // Make sure that all lookups are successful. - ASSERT_FALSE(CSIdMissing); - ASSERT_FALSE(FrameIdMissing); + ASSERT_EQ(FrameIdConv.LastUnmappedId, std::nullopt); + ASSERT_EQ(CSIdConv.LastUnmappedId, std::nullopt); // Verify the contents of Record. ASSERT_THAT(Record.AllocSites, SizeIs(2));