4545#include " llvm/Support/JSON.h"
4646#include < array>
4747#include < map>
48+ #include < optional>
4849
4950namespace llvm {
5051
@@ -144,6 +145,73 @@ struct Embedding {
144145using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
145146using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
146147
148+ // / Generic storage class for section-based vocabularies.
149+ // / VocabStorage provides a generic foundation for storing and accessing
150+ // / embeddings organized into sections.
151+ class VocabStorage {
152+ private:
153+ // / Section-based storage
154+ std::vector<std::vector<Embedding>> Sections;
155+
156+ size_t TotalSize = 0 ;
157+ unsigned Dimension = 0 ;
158+
159+ public:
160+ // / Default constructor creates empty storage (invalid state)
161+ VocabStorage () : Sections(), TotalSize(0 ), Dimension(0 ) {}
162+
163+ // / Create a VocabStorage with pre-organized section data
164+ VocabStorage (std::vector<std::vector<Embedding>> &&SectionData);
165+
166+ VocabStorage (VocabStorage &&) = default ;
167+ VocabStorage &operator =(VocabStorage &&Other);
168+
169+ VocabStorage (const VocabStorage &) = delete ;
170+ VocabStorage &operator =(const VocabStorage &) = delete ;
171+
172+ // / Get total number of entries across all sections
173+ size_t size () const { return TotalSize; }
174+
175+ // / Get number of sections
176+ unsigned getNumSections () const {
177+ return static_cast <unsigned >(Sections.size ());
178+ }
179+
180+ // / Section-based access: Storage[sectionId][localIndex]
181+ const std::vector<Embedding> &operator [](unsigned SectionId) const {
182+ assert (SectionId < Sections.size () && " Invalid section ID" );
183+ return Sections[SectionId];
184+ }
185+
186+ // / Get vocabulary dimension
187+ unsigned getDimension () const { return Dimension; }
188+
189+ // / Check if vocabulary is valid (has data)
190+ bool isValid () const { return TotalSize > 0 ; }
191+
192+ // / Iterator support for section-based access
193+ class const_iterator {
194+ const VocabStorage *Storage;
195+ unsigned SectionId;
196+ size_t LocalIndex;
197+
198+ public:
199+ const_iterator (const VocabStorage *Storage, unsigned SectionId,
200+ size_t LocalIndex)
201+ : Storage(Storage), SectionId(SectionId), LocalIndex(LocalIndex) {}
202+
203+ LLVM_ABI const Embedding &operator *() const ;
204+ LLVM_ABI const_iterator &operator ++();
205+ LLVM_ABI bool operator ==(const const_iterator &Other) const ;
206+ LLVM_ABI bool operator !=(const const_iterator &Other) const ;
207+ };
208+
209+ const_iterator begin () const { return const_iterator (this , 0 , 0 ); }
210+ const_iterator end () const {
211+ return const_iterator (this , getNumSections (), 0 );
212+ }
213+ };
214+
147215// / Class for storing and accessing the IR2Vec vocabulary.
148216// / The Vocabulary class manages seed embeddings for LLVM IR entities. The
149217// / seed embeddings are the initial learned representations of the entities
@@ -164,7 +232,7 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
164232class Vocabulary {
165233 friend class llvm ::IR2VecVocabAnalysis;
166234
167- // Vocabulary Slot Layout:
235+ // Vocabulary Layout:
168236 // +----------------+------------------------------------------------------+
169237 // | Entity Type | Index Range |
170238 // +----------------+------------------------------------------------------+
@@ -175,8 +243,16 @@ class Vocabulary {
175243 // Note: "Similar" LLVM Types are grouped/canonicalized together.
176244 // Operands include Comparison predicates (ICmp/FCmp).
177245 // This can be extended to include other specializations in future.
178- using VocabVector = std::vector<ir2vec::Embedding>;
179- VocabVector Vocab;
246+ enum class Section : unsigned {
247+ Opcodes = 0 ,
248+ CanonicalTypes = 1 ,
249+ Operands = 2 ,
250+ Predicates = 3 ,
251+ MaxSections
252+ };
253+
254+ // Use section-based storage for better organization and efficiency
255+ VocabStorage Storage;
180256
181257 static constexpr unsigned NumICmpPredicates =
182258 static_cast <unsigned >(CmpInst::LAST_ICMP_PREDICATE) -
@@ -228,9 +304,18 @@ class Vocabulary {
228304 NumICmpPredicates + NumFCmpPredicates;
229305
230306 Vocabulary () = default ;
231- LLVM_ABI Vocabulary (VocabVector &&Vocab) : Vocab(std::move(Vocab)) {}
307+ LLVM_ABI Vocabulary (VocabStorage &&Storage) : Storage(std::move(Storage)) {}
308+
309+ Vocabulary (const Vocabulary &) = delete ;
310+ Vocabulary &operator =(const Vocabulary &) = delete ;
311+
312+ Vocabulary (Vocabulary &&) = default ;
313+ Vocabulary &operator =(Vocabulary &&Other);
314+
315+ LLVM_ABI bool isValid () const {
316+ return Storage.size () == NumCanonicalEntries;
317+ }
232318
233- LLVM_ABI bool isValid () const { return Vocab.size () == NumCanonicalEntries; };
234319 LLVM_ABI unsigned getDimension () const ;
235320 // / Total number of entries (opcodes + canonicalized types + operand kinds +
236321 // / predicates)
@@ -251,12 +336,11 @@ class Vocabulary {
251336 // / Function to get vocabulary key for a given predicate
252337 LLVM_ABI static StringRef getVocabKeyForPredicate (CmpInst::Predicate P);
253338
254- // / Functions to return the slot index or position of a given Opcode, TypeID,
255- // / or OperandKind in the vocabulary.
256- LLVM_ABI static unsigned getSlotIndex (unsigned Opcode);
257- LLVM_ABI static unsigned getSlotIndex (Type::TypeID TypeID);
258- LLVM_ABI static unsigned getSlotIndex (const Value &Op);
259- LLVM_ABI static unsigned getSlotIndex (CmpInst::Predicate P);
339+ // / Functions to return flat index
340+ LLVM_ABI static unsigned getIndex (unsigned Opcode);
341+ LLVM_ABI static unsigned getIndex (Type::TypeID TypeID);
342+ LLVM_ABI static unsigned getIndex (const Value &Op);
343+ LLVM_ABI static unsigned getIndex (CmpInst::Predicate P);
260344
261345 // / Accessors to get the embedding for a given entity.
262346 LLVM_ABI const ir2vec::Embedding &operator [](unsigned Opcode) const ;
@@ -265,34 +349,29 @@ class Vocabulary {
265349 LLVM_ABI const ir2vec::Embedding &operator [](CmpInst::Predicate P) const ;
266350
267351 // / Const Iterator type aliases
268- using const_iterator = VocabVector::const_iterator;
352+ using const_iterator = VocabStorage::const_iterator;
353+
269354 const_iterator begin () const {
270355 assert (isValid () && " IR2Vec Vocabulary is invalid" );
271- return Vocab .begin ();
356+ return Storage .begin ();
272357 }
273358
274- const_iterator cbegin () const {
275- assert (isValid () && " IR2Vec Vocabulary is invalid" );
276- return Vocab.cbegin ();
277- }
359+ const_iterator cbegin () const { return begin (); }
278360
279361 const_iterator end () const {
280362 assert (isValid () && " IR2Vec Vocabulary is invalid" );
281- return Vocab .end ();
363+ return Storage .end ();
282364 }
283365
284- const_iterator cend () const {
285- assert (isValid () && " IR2Vec Vocabulary is invalid" );
286- return Vocab.cend ();
287- }
366+ const_iterator cend () const { return end (); }
288367
289368 // / Returns the string key for a given index position in the vocabulary.
290369 // / This is useful for debugging or printing the vocabulary. Do not use this
291370 // / for embedding generation as string based lookups are inefficient.
292371 LLVM_ABI static StringRef getStringKey (unsigned Pos);
293372
294373 // / Create a dummy vocabulary for testing purposes.
295- LLVM_ABI static VocabVector createDummyVocabForTest (unsigned Dim = 1 );
374+ LLVM_ABI static VocabStorage createDummyVocabForTest (unsigned Dim = 1 );
296375
297376 LLVM_ABI bool invalidate (Module &M, const PreservedAnalyses &PA,
298377 ModuleAnalysisManager::Invalidator &Inv) const ;
@@ -301,12 +380,16 @@ class Vocabulary {
301380 constexpr static unsigned NumCanonicalEntries =
302381 MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
303382
304- // Base offsets for slot layout to simplify index computation
383+ // Base offsets for flat index computation
305384 constexpr static unsigned OperandBaseOffset =
306385 MaxOpcodes + MaxCanonicalTypeIDs;
307386 constexpr static unsigned PredicateBaseOffset =
308387 OperandBaseOffset + MaxOperandKinds;
309388
389+ // / Functions for predicate index calculations
390+ static unsigned getPredicateLocalIndex (CmpInst::Predicate P);
391+ static CmpInst::Predicate getPredicateFromLocalIndex (unsigned LocalIndex);
392+
310393 // / String mappings for CanonicalTypeID values
311394 static constexpr StringLiteral CanonicalTypeNames[] = {
312395 " FloatTy" , " VoidTy" , " LabelTy" , " MetadataTy" ,
@@ -452,22 +535,22 @@ class LLVM_ABI FlowAwareEmbedder : public Embedder {
452535// / mapping between an entity of the IR (like opcode, type, argument, etc.) and
453536// / its corresponding embedding.
454537class IR2VecVocabAnalysis : public AnalysisInfoMixin <IR2VecVocabAnalysis> {
455- using VocabVector = std::vector<ir2vec::Embedding>;
456538 using VocabMap = std::map<std::string, ir2vec::Embedding>;
457- VocabMap OpcVocab, TypeVocab, ArgVocab;
458- VocabVector Vocab;
539+ std::optional<ir2vec::VocabStorage> Vocab;
459540
460- Error readVocabulary ();
541+ Error readVocabulary (VocabMap &OpcVocab, VocabMap &TypeVocab,
542+ VocabMap &ArgVocab);
461543 Error parseVocabSection (StringRef Key, const json::Value &ParsedVocabValue,
462544 VocabMap &TargetVocab, unsigned &Dim);
463- void generateNumMappedVocab ();
545+ void generateVocabStorage (VocabMap &OpcVocab, VocabMap &TypeVocab,
546+ VocabMap &ArgVocab);
464547 void emitError (Error Err, LLVMContext &Ctx);
465548
466549public:
467550 LLVM_ABI static AnalysisKey Key;
468551 IR2VecVocabAnalysis () = default ;
469- LLVM_ABI explicit IR2VecVocabAnalysis (const VocabVector & Vocab);
470- LLVM_ABI explicit IR2VecVocabAnalysis (VocabVector && Vocab);
552+ LLVM_ABI explicit IR2VecVocabAnalysis (ir2vec::VocabStorage && Vocab)
553+ : Vocab(std::move(Vocab)) {}
471554 using Result = ir2vec::Vocabulary;
472555 LLVM_ABI Result run (Module &M, ModuleAnalysisManager &MAM);
473556};
0 commit comments