3636#include " llvm/Support/Compiler.h"
3737#include " llvm/Support/ErrorOr.h"
3838#include " llvm/Support/JSON.h"
39+ #include < array>
3940#include < map>
4041
4142namespace llvm {
@@ -137,13 +138,48 @@ using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
137138using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
138139
139140// / Class for storing and accessing the IR2Vec vocabulary.
140- // / Encapsulates all vocabulary-related constants, logic, and access methods.
141+ // /
142+ // / The Vocabulary class manages seed embeddings for LLVM IR entities. It
143+ // / contains the seed embeddings for three types of entities: instruction
144+ // / opcodes, types, and operands. Types are grouped/canonicalized for better
145+ // / learning (e.g., all float variants map to FloatTy). The vocabulary abstracts
146+ // / away the canonicalization, so the exposed APIs handle all the known
147+ // / LLVM IR opcodes, types and operands.
148+ // /
149+ // / This class helps populate the seed embeddings in an internal vector-based
150+ // / ADT. It provides logic to map every IR entity to a specific slot index or
151+ // / position in this vector, enabling O(1) embedding lookup and avoiding
152+ // / unnecessary string-based lookups while generating the
153+ // / embeddings.
141154class Vocabulary {
142155 friend class llvm ::IR2VecVocabAnalysis;
143156 using VocabVector = std::vector<ir2vec::Embedding>;
144157 VocabVector Vocab;
145158 bool Valid = false ;
146159
160+ public:
161+ // Slot layout:
162+ // [0 .. MaxOpcodes-1] => Instruction opcodes
163+ // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
164+ // [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
165+
166+ // / Canonical type IDs supported by IR2Vec Vocabulary
// Several Type::TypeID values fold onto one canonical ID; the per-enumerator
// notes below list the TypeIDs that TypeIDMapping assigns to each entry.
// The enumerators are listed in the same order as CanonicalTypeNames.
167+ enum class CanonicalTypeID : unsigned {
168+ FloatTy, // Half, BFloat, Float, Double, X86_FP80, FP128, PPC_FP128
169+ VoidTy, // VoidTyID
170+ LabelTy, // LabelTyID
171+ MetadataTy, // MetadataTyID
172+ VectorTy, // X86_AMX, FixedVector, ScalableVector
173+ TokenTy, // TokenTyID
174+ IntegerTy, // IntegerTyID
175+ FunctionTy, // FunctionTyID
176+ PointerTy, // PointerTyID, TypedPointerTyID
177+ StructTy, // StructTyID
178+ ArrayTy, // ArrayTyID
179+ UnknownTy, // TargetExtTyID
180+ MaxCanonicalType // Sentinel: count of canonical IDs (feeds MaxCanonicalTypeIDs)
181+ };
182+
147183 // / Operand kinds supported by IR2Vec Vocabulary
148184 enum class OperandKind : unsigned {
149185 FunctionID,
@@ -152,20 +188,15 @@ class Vocabulary {
152188 VariableID,
153189 MaxOperandKind
154190 };
155- // / String mappings for OperandKind values
156- static constexpr StringLiteral OperandKindNames[] = {" Function" , " Pointer" ,
157- " Constant" , " Variable" };
158- static_assert (std::size(OperandKindNames) ==
159- static_cast <unsigned >(OperandKind::MaxOperandKind),
160- " OperandKindNames array size must match MaxOperandKind" );
161191
162- public:
163192 // / Vocabulary layout constants
// MaxOpcodes is emitted by defining LAST_OTHER_INST before including
// Instruction.def and undefining it right after. Presumably the .def file
// expands the macro exactly once with the number of the last opcode --
// TODO confirm against llvm/IR/Instruction.def.
164193#define LAST_OTHER_INST (NUM ) static constexpr unsigned MaxOpcodes = NUM;
165194#include " llvm/IR/Instruction.def"
166195#undef LAST_OTHER_INST
167196
// NOTE(review): assumes TargetExtTyID is the last enumerator of Type::TypeID,
// so that +1 yields the number of TypeIDs -- confirm when TypeIDs are added.
168197 static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1 ;
// Number of canonical type slots; this value (not MaxTypeIDs) is what
// NumCanonicalEntries sums together with MaxOpcodes and MaxOperandKinds.
198+ static constexpr unsigned MaxCanonicalTypeIDs =
199+ static_cast <unsigned >(CanonicalTypeID::MaxCanonicalType);
// Number of operand-kind slots, taken from the OperandKind sentinel.
169200 static constexpr unsigned MaxOperandKinds =
170201 static_cast <unsigned >(OperandKind::MaxOperandKind);
171202
@@ -174,11 +205,8 @@ class Vocabulary {
174205
175206 LLVM_ABI bool isValid () const ;
176207 LLVM_ABI unsigned getDimension () const ;
177- LLVM_ABI size_t size () const ;
178-
179- static size_t expectedSize () {
180- return MaxOpcodes + MaxTypeIDs + MaxOperandKinds;
181- }
208+ // / Total number of entries (opcodes + canonicalized types + operand kinds)
209+ static constexpr size_t getCanonicalSize () { return NumCanonicalEntries; }
182210
183211 // / Helper function to get vocabulary key for a given Opcode
184212 LLVM_ABI static StringRef getVocabKeyForOpcode (unsigned Opcode);
@@ -192,10 +220,11 @@ class Vocabulary {
192220 // / Helper function to classify an operand into OperandKind
193221 LLVM_ABI static OperandKind getOperandKind (const Value *Op);
194222
195- // / Helpers to return the IDs of a given Opcode, TypeID, or OperandKind
196- LLVM_ABI static unsigned getNumericID (unsigned Opcode);
197- LLVM_ABI static unsigned getNumericID (Type::TypeID TypeID);
198- LLVM_ABI static unsigned getNumericID (const Value *Op);
223+ // / Helpers to return the slot index or position of a given Opcode, TypeID, or
224+ // / OperandKind in the vocabulary.
225+ LLVM_ABI static unsigned getSlotIdx (unsigned Opcode);
226+ LLVM_ABI static unsigned getSlotIdx (Type::TypeID TypeID);
227+ LLVM_ABI static unsigned getSlotIdx (const Value *Op);
199228
200229 // / Accessors to get the embedding for a given entity.
201230 LLVM_ABI const ir2vec::Embedding &operator [](unsigned Opcode) const ;
@@ -234,6 +263,61 @@ class Vocabulary {
234263
235264 LLVM_ABI bool invalidate (Module &M, const PreservedAnalyses &PA,
236265 ModuleAnalysisManager::Invalidator &Inv) const ;
266+
267+ private:
268+ constexpr static unsigned NumCanonicalEntries =
269+ MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
270+
271+ // / String mappings for CanonicalTypeID values
// Entries are listed in the same order as the CanonicalTypeID enumerators.
// Presumably indexed by the enumerator's numeric value in
// getVocabKeyForCanonicalTypeID -- TODO confirm in the implementation file.
272+ static constexpr StringLiteral CanonicalTypeNames[] = {
273+ " FloatTy" , " VoidTy" , " LabelTy" , " MetadataTy" ,
274+ " VectorTy" , " TokenTy" , " IntegerTy" , " FunctionTy" ,
275+ " PointerTy" , " StructTy" , " ArrayTy" , " UnknownTy" };
// Guards the element count only; a reordering of names relative to the enum
// would not be detected here.
276+ static_assert (std::size(CanonicalTypeNames) ==
277+ static_cast <unsigned >(CanonicalTypeID::MaxCanonicalType),
278+ " CanonicalTypeNames array size must match MaxCanonicalType" );
279+
280+ // / String mappings for OperandKind values
// One name per OperandKind enumerator (excluding the MaxOperandKind sentinel).
// Presumably indexed by the enumerator's numeric value -- TODO confirm.
281+ static constexpr StringLiteral OperandKindNames[] = {" Function" , " Pointer" ,
282+ " Constant" , " Variable" };
// Guards the element count only, not the order of the names.
283+ static_assert (std::size(OperandKindNames) ==
284+ static_cast <unsigned >(OperandKind::MaxOperandKind),
285+ " OperandKindNames array size must match MaxOperandKind" );
286+
287+ // / Every known TypeID defined in llvm/IR/Type.h is expected to have a
288+ // / corresponding mapping here in the same order as enum Type::TypeID.
// NOTE(review): because the array type is spelled std::array<..., MaxTypeIDs>,
// its size() is MaxTypeIDs by construction. Supplying fewer initializers than
// MaxTypeIDs still compiles; the remaining elements are value-initialized,
// i.e. they become CanonicalTypeID::FloatTy (enumerator value 0). The
// static_assert after the table therefore cannot catch a missing entry; only
// too many initializers are rejected by the compiler. Keep this table in sync
// with Type::TypeID by hand.
289+ static constexpr std::array<CanonicalTypeID, MaxTypeIDs> TypeIDMapping = {{
290+ CanonicalTypeID::FloatTy, // HalfTyID = 0
291+ CanonicalTypeID::FloatTy, // BFloatTyID
292+ CanonicalTypeID::FloatTy, // FloatTyID
293+ CanonicalTypeID::FloatTy, // DoubleTyID
294+ CanonicalTypeID::FloatTy, // X86_FP80TyID
295+ CanonicalTypeID::FloatTy, // FP128TyID
296+ CanonicalTypeID::FloatTy, // PPC_FP128TyID
297+ CanonicalTypeID::VoidTy, // VoidTyID
298+ CanonicalTypeID::LabelTy, // LabelTyID
299+ CanonicalTypeID::MetadataTy, // MetadataTyID
300+ CanonicalTypeID::VectorTy, // X86_AMXTyID
301+ CanonicalTypeID::TokenTy, // TokenTyID
302+ CanonicalTypeID::IntegerTy, // IntegerTyID
303+ CanonicalTypeID::FunctionTy, // FunctionTyID
304+ CanonicalTypeID::PointerTy, // PointerTyID
305+ CanonicalTypeID::StructTy, // StructTyID
306+ CanonicalTypeID::ArrayTy, // ArrayTyID
307+ CanonicalTypeID::VectorTy, // FixedVectorTyID
308+ CanonicalTypeID::VectorTy, // ScalableVectorTyID
309+ CanonicalTypeID::PointerTy, // TypedPointerTyID
310+ CanonicalTypeID::UnknownTy // TargetExtTyID
311+ }};
// Always true for this declaration (see the note above the table).
312+ static_assert (TypeIDMapping.size() == MaxTypeIDs,
313+ " TypeIDMapping must cover all Type::TypeID values" );
314+
315+ // / Helper function to get vocabulary key for canonical type by enum
316+ LLVM_ABI static StringRef
317+ getVocabKeyForCanonicalTypeID (CanonicalTypeID CType);
318+
319+ // / Helper function to convert TypeID to CanonicalTypeID
320+ LLVM_ABI static CanonicalTypeID getCanonicalTypeID (Type::TypeID TypeID);
237321};
238322
239323// / Embedder provides the interface to generate embeddings (vector
0 commit comments