#include "llvm/Support/JSON.h"
#include <array>
#include <cstddef>
#include <iterator>
#include <map>
#include <optional>
4849
4950namespace llvm {
5051
@@ -144,6 +145,73 @@ struct Embedding {
144145using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
145146using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
146147
148+ // / Generic storage class for section-based vocabularies.
149+ // / VocabStorage provides a generic foundation for storing and accessing
150+ // / embeddings organized into sections.
151+ class VocabStorage {
152+ private:
153+ // / Section-based storage
154+ std::vector<std::vector<Embedding>> Sections;
155+
156+ const size_t TotalSize;
157+ const unsigned Dimension;
158+
159+ public:
160+ // / Default constructor creates empty storage (invalid state)
161+ VocabStorage () : Sections(), TotalSize(0 ), Dimension(0 ) {}
162+
163+ // / Create a VocabStorage with pre-organized section data
164+ VocabStorage (std::vector<std::vector<Embedding>> &&SectionData);
165+
166+ VocabStorage (VocabStorage &&) = default ;
167+ VocabStorage &operator =(VocabStorage &&) = delete ;
168+
169+ VocabStorage (const VocabStorage &) = delete ;
170+ VocabStorage &operator =(const VocabStorage &) = delete ;
171+
172+ // / Get total number of entries across all sections
173+ size_t size () const { return TotalSize; }
174+
175+ // / Get number of sections
176+ unsigned getNumSections () const {
177+ return static_cast <unsigned >(Sections.size ());
178+ }
179+
180+ // / Section-based access: Storage[sectionId][localIndex]
181+ const std::vector<Embedding> &operator [](unsigned SectionId) const {
182+ assert (SectionId < Sections.size () && " Invalid section ID" );
183+ return Sections[SectionId];
184+ }
185+
186+ // / Get vocabulary dimension
187+ unsigned getDimension () const { return Dimension; }
188+
189+ // / Check if vocabulary is valid (has data)
190+ bool isValid () const { return TotalSize > 0 ; }
191+
192+ // / Iterator support for section-based access
193+ class const_iterator {
194+ const VocabStorage *Storage;
195+ unsigned SectionId = 0 ;
196+ size_t LocalIndex = 0 ;
197+
198+ public:
199+ const_iterator (const VocabStorage *Storage, unsigned SectionId,
200+ size_t LocalIndex)
201+ : Storage(Storage), SectionId(SectionId), LocalIndex(LocalIndex) {}
202+
203+ LLVM_ABI const Embedding &operator *() const ;
204+ LLVM_ABI const_iterator &operator ++();
205+ LLVM_ABI bool operator ==(const const_iterator &Other) const ;
206+ LLVM_ABI bool operator !=(const const_iterator &Other) const ;
207+ };
208+
209+ const_iterator begin () const { return const_iterator (this , 0 , 0 ); }
210+ const_iterator end () const {
211+ return const_iterator (this , getNumSections (), 0 );
212+ }
213+ };
214+
147215// / Class for storing and accessing the IR2Vec vocabulary.
148216// / The Vocabulary class manages seed embeddings for LLVM IR entities. The
149217// / seed embeddings are the initial learned representations of the entities
@@ -164,7 +232,7 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
164232class Vocabulary {
165233 friend class llvm ::IR2VecVocabAnalysis;
166234
167- // Vocabulary Slot Layout:
235+ // Vocabulary Layout:
168236 // +----------------+------------------------------------------------------+
169237 // | Entity Type | Index Range |
170238 // +----------------+------------------------------------------------------+
@@ -180,8 +248,16 @@ class Vocabulary {
180248 // and improves learning. Operands include Comparison predicates
181249 // (ICmp/FCmp) along with other operand types. This can be extended to
182250 // include other specializations in future.
183- using VocabVector = std::vector<ir2vec::Embedding>;
184- VocabVector Vocab;
251+ enum class Section : unsigned {
252+ Opcodes = 0 ,
253+ CanonicalTypes = 1 ,
254+ Operands = 2 ,
255+ Predicates = 3 ,
256+ MaxSections
257+ };
258+
259+ // Use section-based storage for better organization and efficiency
260+ VocabStorage Storage;
185261
186262 static constexpr unsigned NumICmpPredicates =
187263 static_cast <unsigned >(CmpInst::LAST_ICMP_PREDICATE) -
@@ -233,10 +309,23 @@ class Vocabulary {
233309 NumICmpPredicates + NumFCmpPredicates;
234310
235311 Vocabulary () = default ;
236- LLVM_ABI Vocabulary (VocabVector &&Vocab) : Vocab(std::move(Vocab)) {}
312+ LLVM_ABI Vocabulary (VocabStorage &&Storage) : Storage(std::move(Storage)) {}
313+
314+ Vocabulary (const Vocabulary &) = delete ;
315+ Vocabulary &operator =(const Vocabulary &) = delete ;
316+
317+ Vocabulary (Vocabulary &&) = default ;
318+ Vocabulary &operator =(Vocabulary &&Other) = delete ;
319+
320+ LLVM_ABI bool isValid () const {
321+ return Storage.size () == NumCanonicalEntries;
322+ }
323+
324+ LLVM_ABI unsigned getDimension () const {
325+ assert (isValid () && " IR2Vec Vocabulary is invalid" );
326+ return Storage.getDimension ();
327+ }
237328
238- LLVM_ABI bool isValid () const { return Vocab.size () == NumCanonicalEntries; };
239- LLVM_ABI unsigned getDimension () const ;
240329 // / Total number of entries (opcodes + canonicalized types + operand kinds +
241330 // / predicates)
242331 static constexpr size_t getCanonicalSize () { return NumCanonicalEntries; }
@@ -245,59 +334,91 @@ class Vocabulary {
245334 LLVM_ABI static StringRef getVocabKeyForOpcode (unsigned Opcode);
246335
247336 // / Function to get vocabulary key for a given TypeID
248- LLVM_ABI static StringRef getVocabKeyForTypeID (Type::TypeID TypeID);
337+ LLVM_ABI static StringRef getVocabKeyForTypeID (Type::TypeID TypeID) {
338+ return getVocabKeyForCanonicalTypeID (getCanonicalTypeID (TypeID));
339+ }
249340
250341 // / Function to get vocabulary key for a given OperandKind
251- LLVM_ABI static StringRef getVocabKeyForOperandKind (OperandKind Kind);
342+ LLVM_ABI static StringRef getVocabKeyForOperandKind (OperandKind Kind) {
343+ unsigned Index = static_cast <unsigned >(Kind);
344+ assert (Index < MaxOperandKinds && " Invalid OperandKind" );
345+ return OperandKindNames[Index];
346+ }
252347
253348 // / Function to classify an operand into OperandKind
254349 LLVM_ABI static OperandKind getOperandKind (const Value *Op);
255350
256351 // / Function to get vocabulary key for a given predicate
257352 LLVM_ABI static StringRef getVocabKeyForPredicate (CmpInst::Predicate P);
258353
259- // / Functions to return the slot index or position of a given Opcode, TypeID,
260- // / or OperandKind in the vocabulary.
261- LLVM_ABI static unsigned getSlotIndex (unsigned Opcode);
262- LLVM_ABI static unsigned getSlotIndex (Type::TypeID TypeID);
263- LLVM_ABI static unsigned getSlotIndex (const Value &Op);
264- LLVM_ABI static unsigned getSlotIndex (CmpInst::Predicate P);
354+ // / Functions to return flat index
355+ LLVM_ABI static unsigned getIndex (unsigned Opcode) {
356+ assert (Opcode >= 1 && Opcode <= MaxOpcodes && " Invalid opcode" );
357+ return Opcode - 1 ; // Convert to zero-based index
358+ }
359+
360+ LLVM_ABI static unsigned getIndex (Type::TypeID TypeID) {
361+ assert (static_cast <unsigned >(TypeID) < MaxTypeIDs && " Invalid type ID" );
362+ return MaxOpcodes + static_cast <unsigned >(getCanonicalTypeID (TypeID));
363+ }
364+
365+ LLVM_ABI static unsigned getIndex (const Value &Op) {
366+ unsigned Index = static_cast <unsigned >(getOperandKind (&Op));
367+ assert (Index < MaxOperandKinds && " Invalid OperandKind" );
368+ return OperandBaseOffset + Index;
369+ }
370+
371+ LLVM_ABI static unsigned getIndex (CmpInst::Predicate P) {
372+ return PredicateBaseOffset + getPredicateLocalIndex (P);
373+ }
265374
266375 // / Accessors to get the embedding for a given entity.
267- LLVM_ABI const ir2vec::Embedding &operator [](unsigned Opcode) const ;
268- LLVM_ABI const ir2vec::Embedding &operator [](Type::TypeID TypeId) const ;
269- LLVM_ABI const ir2vec::Embedding &operator [](const Value &Arg) const ;
270- LLVM_ABI const ir2vec::Embedding &operator [](CmpInst::Predicate P) const ;
376+ LLVM_ABI const ir2vec::Embedding &operator [](unsigned Opcode) const {
377+ assert (Opcode >= 1 && Opcode <= MaxOpcodes && " Invalid opcode" );
378+ return Storage[static_cast <unsigned >(Section::Opcodes)][Opcode - 1 ];
379+ }
380+
381+ LLVM_ABI const ir2vec::Embedding &operator [](Type::TypeID TypeID) const {
382+ assert (static_cast <unsigned >(TypeID) < MaxTypeIDs && " Invalid type ID" );
383+ unsigned LocalIndex = static_cast <unsigned >(getCanonicalTypeID (TypeID));
384+ return Storage[static_cast <unsigned >(Section::CanonicalTypes)][LocalIndex];
385+ }
386+
387+ LLVM_ABI const ir2vec::Embedding &operator [](const Value &Arg) const {
388+ unsigned LocalIndex = static_cast <unsigned >(getOperandKind (&Arg));
389+ assert (LocalIndex < MaxOperandKinds && " Invalid OperandKind" );
390+ return Storage[static_cast <unsigned >(Section::Operands)][LocalIndex];
391+ }
392+
393+ LLVM_ABI const ir2vec::Embedding &operator [](CmpInst::Predicate P) const {
394+ unsigned LocalIndex = getPredicateLocalIndex (P);
395+ return Storage[static_cast <unsigned >(Section::Predicates)][LocalIndex];
396+ }
271397
272398 // / Const Iterator type aliases
273- using const_iterator = VocabVector::const_iterator;
399+ using const_iterator = VocabStorage::const_iterator;
400+
274401 const_iterator begin () const {
275402 assert (isValid () && " IR2Vec Vocabulary is invalid" );
276- return Vocab .begin ();
403+ return Storage .begin ();
277404 }
278405
279- const_iterator cbegin () const {
280- assert (isValid () && " IR2Vec Vocabulary is invalid" );
281- return Vocab.cbegin ();
282- }
406+ const_iterator cbegin () const { return begin (); }
283407
284408 const_iterator end () const {
285409 assert (isValid () && " IR2Vec Vocabulary is invalid" );
286- return Vocab .end ();
410+ return Storage .end ();
287411 }
288412
289- const_iterator cend () const {
290- assert (isValid () && " IR2Vec Vocabulary is invalid" );
291- return Vocab.cend ();
292- }
413+ const_iterator cend () const { return end (); }
293414
294415 // / Returns the string key for a given index position in the vocabulary.
295416 // / This is useful for debugging or printing the vocabulary. Do not use this
296417 // / for embedding generation as string based lookups are inefficient.
297418 LLVM_ABI static StringRef getStringKey (unsigned Pos);
298419
299420 // / Create a dummy vocabulary for testing purposes.
300- LLVM_ABI static VocabVector createDummyVocabForTest (unsigned Dim = 1 );
421+ LLVM_ABI static VocabStorage createDummyVocabForTest (unsigned Dim = 1 );
301422
302423 LLVM_ABI bool invalidate (Module &M, const PreservedAnalyses &PA,
303424 ModuleAnalysisManager::Invalidator &Inv) const ;
@@ -306,12 +427,16 @@ class Vocabulary {
306427 constexpr static unsigned NumCanonicalEntries =
307428 MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
308429
309- // Base offsets for slot layout to simplify index computation
430+ // Base offsets for flat index computation
310431 constexpr static unsigned OperandBaseOffset =
311432 MaxOpcodes + MaxCanonicalTypeIDs;
312433 constexpr static unsigned PredicateBaseOffset =
313434 OperandBaseOffset + MaxOperandKinds;
314435
436+ // / Functions for predicate index calculations
437+ static unsigned getPredicateLocalIndex (CmpInst::Predicate P);
438+ static CmpInst::Predicate getPredicateFromLocalIndex (unsigned LocalIndex);
439+
315440 // / String mappings for CanonicalTypeID values
316441 static constexpr StringLiteral CanonicalTypeNames[] = {
317442 " FloatTy" , " VoidTy" , " LabelTy" , " MetadataTy" ,
@@ -358,15 +483,26 @@ class Vocabulary {
358483
359484 // / Function to get vocabulary key for canonical type by enum
360485 LLVM_ABI static StringRef
361- getVocabKeyForCanonicalTypeID (CanonicalTypeID CType);
486+ getVocabKeyForCanonicalTypeID (CanonicalTypeID CType) {
487+ unsigned Index = static_cast <unsigned >(CType);
488+ assert (Index < MaxCanonicalTypeIDs && " Invalid CanonicalTypeID" );
489+ return CanonicalTypeNames[Index];
490+ }
362491
363492 // / Function to convert TypeID to CanonicalTypeID
364- LLVM_ABI static CanonicalTypeID getCanonicalTypeID (Type::TypeID TypeID);
493+ LLVM_ABI static CanonicalTypeID getCanonicalTypeID (Type::TypeID TypeID) {
494+ unsigned Index = static_cast <unsigned >(TypeID);
495+ assert (Index < MaxTypeIDs && " Invalid TypeID" );
496+ return TypeIDMapping[Index];
497+ }
365498
366499 // / Function to get the predicate enum value for a given index. Index is
367500 // / relative to the predicates section of the vocabulary. E.g., Index 0
368501 // / corresponds to the first predicate.
369- LLVM_ABI static CmpInst::Predicate getPredicate (unsigned Index);
502+ LLVM_ABI static CmpInst::Predicate getPredicate (unsigned Index) {
503+ assert (Index < MaxPredicateKinds && " Invalid predicate index" );
504+ return getPredicateFromLocalIndex (Index);
505+ }
370506};
371507
372508// / Embedder provides the interface to generate embeddings (vector
@@ -459,22 +595,22 @@ class LLVM_ABI FlowAwareEmbedder : public Embedder {
459595// / mapping between an entity of the IR (like opcode, type, argument, etc.) and
460596// / its corresponding embedding.
461597class IR2VecVocabAnalysis : public AnalysisInfoMixin <IR2VecVocabAnalysis> {
462- using VocabVector = std::vector<ir2vec::Embedding>;
463598 using VocabMap = std::map<std::string, ir2vec::Embedding>;
464- VocabMap OpcVocab, TypeVocab, ArgVocab;
465- VocabVector Vocab;
599+ std::optional<ir2vec::VocabStorage> Vocab;
466600
467- Error readVocabulary ();
601+ Error readVocabulary (VocabMap &OpcVocab, VocabMap &TypeVocab,
602+ VocabMap &ArgVocab);
468603 Error parseVocabSection (StringRef Key, const json::Value &ParsedVocabValue,
469604 VocabMap &TargetVocab, unsigned &Dim);
470- void generateNumMappedVocab ();
605+ void generateVocabStorage (VocabMap &OpcVocab, VocabMap &TypeVocab,
606+ VocabMap &ArgVocab);
471607 void emitError (Error Err, LLVMContext &Ctx);
472608
473609public:
474610 LLVM_ABI static AnalysisKey Key;
475611 IR2VecVocabAnalysis () = default ;
476- LLVM_ABI explicit IR2VecVocabAnalysis (const VocabVector & Vocab);
477- LLVM_ABI explicit IR2VecVocabAnalysis (VocabVector && Vocab);
612+ LLVM_ABI explicit IR2VecVocabAnalysis (ir2vec::VocabStorage && Vocab)
613+ : Vocab(std::move(Vocab)) {}
478614 using Result = ir2vec::Vocabulary;
479615 LLVM_ABI Result run (Module &M, ModuleAnalysisManager &MAM);
480616};
0 commit comments