4545#include " llvm/Support/JSON.h"
4646#include < array>
4747#include < map>
48+ #include < optional>
4849
4950namespace llvm {
5051
@@ -144,6 +145,73 @@ struct Embedding {
144145using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
145146using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
146147
148+ // / Generic storage class for section-based vocabularies.
149+ // / VocabStorage provides a generic foundation for storing and accessing
150+ // / embeddings organized into sections.
151+ class VocabStorage {
152+ private:
153+ // / Section-based storage
154+ std::vector<std::vector<Embedding>> Sections;
155+
156+ const size_t TotalSize;
157+ const unsigned Dimension;
158+
159+ public:
160+ // / Default constructor creates empty storage (invalid state)
161+ VocabStorage () : Sections(), TotalSize(0 ), Dimension(0 ) {}
162+
163+ // / Create a VocabStorage with pre-organized section data
164+ VocabStorage (std::vector<std::vector<Embedding>> &&SectionData);
165+
166+ VocabStorage (VocabStorage &&) = default ;
167+ VocabStorage &operator =(VocabStorage &&) = delete ;
168+
169+ VocabStorage (const VocabStorage &) = delete ;
170+ VocabStorage &operator =(const VocabStorage &) = delete ;
171+
172+ // / Get total number of entries across all sections
173+ size_t size () const { return TotalSize; }
174+
175+ // / Get number of sections
176+ unsigned getNumSections () const {
177+ return static_cast <unsigned >(Sections.size ());
178+ }
179+
180+ // / Section-based access: Storage[sectionId][localIndex]
181+ const std::vector<Embedding> &operator [](unsigned SectionId) const {
182+ assert (SectionId < Sections.size () && " Invalid section ID" );
183+ return Sections[SectionId];
184+ }
185+
186+ // / Get vocabulary dimension
187+ unsigned getDimension () const { return Dimension; }
188+
189+ // / Check if vocabulary is valid (has data)
190+ bool isValid () const { return TotalSize > 0 ; }
191+
192+ // / Iterator support for section-based access
193+ class const_iterator {
194+ const VocabStorage *Storage;
195+ unsigned SectionId = 0 ;
196+ size_t LocalIndex = 0 ;
197+
198+ public:
199+ const_iterator (const VocabStorage *Storage, unsigned SectionId,
200+ size_t LocalIndex)
201+ : Storage(Storage), SectionId(SectionId), LocalIndex(LocalIndex) {}
202+
203+ LLVM_ABI const Embedding &operator *() const ;
204+ LLVM_ABI const_iterator &operator ++();
205+ LLVM_ABI bool operator ==(const const_iterator &Other) const ;
206+ LLVM_ABI bool operator !=(const const_iterator &Other) const ;
207+ };
208+
209+ const_iterator begin () const { return const_iterator (this , 0 , 0 ); }
210+ const_iterator end () const {
211+ return const_iterator (this , getNumSections (), 0 );
212+ }
213+ };
214+
147215// / Class for storing and accessing the IR2Vec vocabulary.
148216// / The Vocabulary class manages seed embeddings for LLVM IR entities. The
149217// / seed embeddings are the initial learned representations of the entities
@@ -164,7 +232,7 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
164232class Vocabulary {
165233 friend class llvm ::IR2VecVocabAnalysis;
166234
167- // Vocabulary Slot Layout:
235+ // Vocabulary Layout:
168236 // +----------------+------------------------------------------------------+
169237 // | Entity Type | Index Range |
170238 // +----------------+------------------------------------------------------+
@@ -175,8 +243,16 @@ class Vocabulary {
175243 // Note: "Similar" LLVM Types are grouped/canonicalized together.
176244 // Operands include Comparison predicates (ICmp/FCmp).
177245 // This can be extended to include other specializations in future.
178- using VocabVector = std::vector<ir2vec::Embedding>;
179- VocabVector Vocab;
/// Identifies the sections of the vocabulary storage. The enumerator value
/// is the section id used to index the storage; MaxSections is the number of
/// sections and is not itself a section.
enum class Section : unsigned {
  Opcodes = 0,
  CanonicalTypes,
  Operands,
  Predicates,
  MaxSections
};
253+
254+ // Use section-based storage for better organization and efficiency
255+ VocabStorage Storage;
180256
181257 static constexpr unsigned NumICmpPredicates =
182258 static_cast <unsigned >(CmpInst::LAST_ICMP_PREDICATE) -
@@ -228,10 +304,23 @@ class Vocabulary {
228304 NumICmpPredicates + NumFCmpPredicates;
229305
230306 Vocabulary () = default ;
231- LLVM_ABI Vocabulary (VocabVector &&Vocab) : Vocab(std::move(Vocab)) {}
307+ LLVM_ABI Vocabulary (VocabStorage &&Storage) : Storage(std::move(Storage)) {}
308+
309+ Vocabulary (const Vocabulary &) = delete ;
310+ Vocabulary &operator =(const Vocabulary &) = delete ;
311+
312+ Vocabulary (Vocabulary &&) = default ;
313+ Vocabulary &operator =(Vocabulary &&Other) = delete ;
314+
315+ LLVM_ABI bool isValid () const {
316+ return Storage.size () == NumCanonicalEntries;
317+ }
318+
319+ LLVM_ABI unsigned getDimension () const {
320+ assert (isValid () && " IR2Vec Vocabulary is invalid" );
321+ return Storage.getDimension ();
322+ }
232323
233- LLVM_ABI bool isValid () const { return Vocab.size () == NumCanonicalEntries; };
234- LLVM_ABI unsigned getDimension () const ;
235324 // / Total number of entries (opcodes + canonicalized types + operand kinds +
236325 // / predicates)
237326 static constexpr size_t getCanonicalSize () { return NumCanonicalEntries; }
@@ -240,59 +329,91 @@ class Vocabulary {
240329 LLVM_ABI static StringRef getVocabKeyForOpcode (unsigned Opcode);
241330
242331 // / Function to get vocabulary key for a given TypeID
243- LLVM_ABI static StringRef getVocabKeyForTypeID (Type::TypeID TypeID);
332+ LLVM_ABI static StringRef getVocabKeyForTypeID (Type::TypeID TypeID) {
333+ return getVocabKeyForCanonicalTypeID (getCanonicalTypeID (TypeID));
334+ }
244335
245336 // / Function to get vocabulary key for a given OperandKind
246- LLVM_ABI static StringRef getVocabKeyForOperandKind (OperandKind Kind);
337+ LLVM_ABI static StringRef getVocabKeyForOperandKind (OperandKind Kind) {
338+ unsigned Index = static_cast <unsigned >(Kind);
339+ assert (Index < MaxOperandKinds && " Invalid OperandKind" );
340+ return OperandKindNames[Index];
341+ }
247342
248343 // / Function to classify an operand into OperandKind
249344 LLVM_ABI static OperandKind getOperandKind (const Value *Op);
250345
251346 // / Function to get vocabulary key for a given predicate
252347 LLVM_ABI static StringRef getVocabKeyForPredicate (CmpInst::Predicate P);
253348
254- // / Functions to return the slot index or position of a given Opcode, TypeID,
255- // / or OperandKind in the vocabulary.
256- LLVM_ABI static unsigned getSlotIndex (unsigned Opcode);
257- LLVM_ABI static unsigned getSlotIndex (Type::TypeID TypeID);
258- LLVM_ABI static unsigned getSlotIndex (const Value &Op);
259- LLVM_ABI static unsigned getSlotIndex (CmpInst::Predicate P);
349+ // / Functions to return flat index
350+ LLVM_ABI static unsigned getIndex (unsigned Opcode) {
351+ assert (Opcode >= 1 && Opcode <= MaxOpcodes && " Invalid opcode" );
352+ return Opcode - 1 ; // Convert to zero-based index
353+ }
354+
355+ LLVM_ABI static unsigned getIndex (Type::TypeID TypeID) {
356+ assert (static_cast <unsigned >(TypeID) < MaxTypeIDs && " Invalid type ID" );
357+ return MaxOpcodes + static_cast <unsigned >(getCanonicalTypeID (TypeID));
358+ }
359+
360+ LLVM_ABI static unsigned getIndex (const Value &Op) {
361+ unsigned Index = static_cast <unsigned >(getOperandKind (&Op));
362+ assert (Index < MaxOperandKinds && " Invalid OperandKind" );
363+ return OperandBaseOffset + Index;
364+ }
365+
366+ LLVM_ABI static unsigned getIndex (CmpInst::Predicate P) {
367+ return PredicateBaseOffset + getPredicateLocalIndex (P);
368+ }
260369
261370 // / Accessors to get the embedding for a given entity.
262- LLVM_ABI const ir2vec::Embedding &operator [](unsigned Opcode) const ;
263- LLVM_ABI const ir2vec::Embedding &operator [](Type::TypeID TypeId) const ;
264- LLVM_ABI const ir2vec::Embedding &operator [](const Value &Arg) const ;
265- LLVM_ABI const ir2vec::Embedding &operator [](CmpInst::Predicate P) const ;
371+ LLVM_ABI const ir2vec::Embedding &operator [](unsigned Opcode) const {
372+ assert (Opcode >= 1 && Opcode <= MaxOpcodes && " Invalid opcode" );
373+ return Storage[static_cast <unsigned >(Section::Opcodes)][Opcode - 1 ];
374+ }
375+
376+ LLVM_ABI const ir2vec::Embedding &operator [](Type::TypeID TypeID) const {
377+ assert (static_cast <unsigned >(TypeID) < MaxTypeIDs && " Invalid type ID" );
378+ unsigned LocalIndex = static_cast <unsigned >(getCanonicalTypeID (TypeID));
379+ return Storage[static_cast <unsigned >(Section::CanonicalTypes)][LocalIndex];
380+ }
381+
382+ LLVM_ABI const ir2vec::Embedding &operator [](const Value &Arg) const {
383+ unsigned LocalIndex = static_cast <unsigned >(getOperandKind (&Arg));
384+ assert (LocalIndex < MaxOperandKinds && " Invalid OperandKind" );
385+ return Storage[static_cast <unsigned >(Section::Operands)][LocalIndex];
386+ }
387+
388+ LLVM_ABI const ir2vec::Embedding &operator [](CmpInst::Predicate P) const {
389+ unsigned LocalIndex = getPredicateLocalIndex (P);
390+ return Storage[static_cast <unsigned >(Section::Predicates)][LocalIndex];
391+ }
266392
267393 // / Const Iterator type aliases
268- using const_iterator = VocabVector::const_iterator;
394+ using const_iterator = VocabStorage::const_iterator;
395+
269396 const_iterator begin () const {
270397 assert (isValid () && " IR2Vec Vocabulary is invalid" );
271- return Vocab .begin ();
398+ return Storage .begin ();
272399 }
273400
274- const_iterator cbegin () const {
275- assert (isValid () && " IR2Vec Vocabulary is invalid" );
276- return Vocab.cbegin ();
277- }
401+ const_iterator cbegin () const { return begin (); }
278402
279403 const_iterator end () const {
280404 assert (isValid () && " IR2Vec Vocabulary is invalid" );
281- return Vocab .end ();
405+ return Storage .end ();
282406 }
283407
284- const_iterator cend () const {
285- assert (isValid () && " IR2Vec Vocabulary is invalid" );
286- return Vocab.cend ();
287- }
408+ const_iterator cend () const { return end (); }
288409
289410 // / Returns the string key for a given index position in the vocabulary.
290411 // / This is useful for debugging or printing the vocabulary. Do not use this
291412 // / for embedding generation as string based lookups are inefficient.
292413 LLVM_ABI static StringRef getStringKey (unsigned Pos);
293414
294415 // / Create a dummy vocabulary for testing purposes.
295- LLVM_ABI static VocabVector createDummyVocabForTest (unsigned Dim = 1 );
416+ LLVM_ABI static VocabStorage createDummyVocabForTest (unsigned Dim = 1 );
296417
297418 LLVM_ABI bool invalidate (Module &M, const PreservedAnalyses &PA,
298419 ModuleAnalysisManager::Invalidator &Inv) const ;
@@ -301,12 +422,16 @@ class Vocabulary {
301422 constexpr static unsigned NumCanonicalEntries =
302423 MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
303424
304- // Base offsets for slot layout to simplify index computation
425+ // Base offsets for flat index computation
305426 constexpr static unsigned OperandBaseOffset =
306427 MaxOpcodes + MaxCanonicalTypeIDs;
307428 constexpr static unsigned PredicateBaseOffset =
308429 OperandBaseOffset + MaxOperandKinds;
309430
431+ // / Functions for predicate index calculations
432+ static unsigned getPredicateLocalIndex (CmpInst::Predicate P);
433+ static CmpInst::Predicate getPredicateFromLocalIndex (unsigned LocalIndex);
434+
310435 // / String mappings for CanonicalTypeID values
311436 static constexpr StringLiteral CanonicalTypeNames[] = {
312437 " FloatTy" , " VoidTy" , " LabelTy" , " MetadataTy" ,
@@ -353,13 +478,24 @@ class Vocabulary {
353478
354479 // / Function to get vocabulary key for canonical type by enum
355480 LLVM_ABI static StringRef
356- getVocabKeyForCanonicalTypeID (CanonicalTypeID CType);
481+ getVocabKeyForCanonicalTypeID (CanonicalTypeID CType) {
482+ unsigned Index = static_cast <unsigned >(CType);
483+ assert (Index < MaxCanonicalTypeIDs && " Invalid CanonicalTypeID" );
484+ return CanonicalTypeNames[Index];
485+ }
357486
358487 // / Function to convert TypeID to CanonicalTypeID
359- LLVM_ABI static CanonicalTypeID getCanonicalTypeID (Type::TypeID TypeID);
488+ LLVM_ABI static CanonicalTypeID getCanonicalTypeID (Type::TypeID TypeID) {
489+ unsigned Index = static_cast <unsigned >(TypeID);
490+ assert (Index < MaxTypeIDs && " Invalid TypeID" );
491+ return TypeIDMapping[Index];
492+ }
360493
361494 // / Function to get the predicate enum value for a given index
362- LLVM_ABI static CmpInst::Predicate getPredicate (unsigned Index);
495+ LLVM_ABI static CmpInst::Predicate getPredicate (unsigned Index) {
496+ assert (Index < MaxPredicateKinds && " Invalid predicate index" );
497+ return getPredicateFromLocalIndex (Index);
498+ }
363499};
364500
365501// / Embedder provides the interface to generate embeddings (vector
@@ -452,22 +588,22 @@ class LLVM_ABI FlowAwareEmbedder : public Embedder {
452588// / mapping between an entity of the IR (like opcode, type, argument, etc.) and
453589// / its corresponding embedding.
454590class IR2VecVocabAnalysis : public AnalysisInfoMixin <IR2VecVocabAnalysis> {
455- using VocabVector = std::vector<ir2vec::Embedding>;
456591 using VocabMap = std::map<std::string, ir2vec::Embedding>;
457- VocabMap OpcVocab, TypeVocab, ArgVocab;
458- VocabVector Vocab;
592+ std::optional<ir2vec::VocabStorage> Vocab;
459593
460- Error readVocabulary ();
594+ Error readVocabulary (VocabMap &OpcVocab, VocabMap &TypeVocab,
595+ VocabMap &ArgVocab);
461596 Error parseVocabSection (StringRef Key, const json::Value &ParsedVocabValue,
462597 VocabMap &TargetVocab, unsigned &Dim);
463- void generateNumMappedVocab ();
598+ void generateVocabStorage (VocabMap &OpcVocab, VocabMap &TypeVocab,
599+ VocabMap &ArgVocab);
464600 void emitError (Error Err, LLVMContext &Ctx);
465601
466602public:
467603 LLVM_ABI static AnalysisKey Key;
468604 IR2VecVocabAnalysis () = default ;
469- LLVM_ABI explicit IR2VecVocabAnalysis (const VocabVector & Vocab);
470- LLVM_ABI explicit IR2VecVocabAnalysis (VocabVector && Vocab);
605+ LLVM_ABI explicit IR2VecVocabAnalysis (ir2vec::VocabStorage && Vocab)
606+ : Vocab(std::move(Vocab)) {}
471607 using Result = ir2vec::Vocabulary;
472608 LLVM_ABI Result run (Module &M, ModuleAnalysisManager &MAM);
473609};
0 commit comments