3636#include " llvm/Support/Compiler.h"
3737#include " llvm/Support/ErrorOr.h"
3838#include " llvm/Support/JSON.h"
39+ #include < array>
3940#include < map>
4041
4142namespace llvm {
@@ -137,13 +138,51 @@ using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
137138using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
138139
139140// / Class for storing and accessing the IR2Vec vocabulary.
140- // / Encapsulates all vocabulary-related constants, logic, and access methods.
141+ // / The Vocabulary class manages seed embeddings for LLVM IR entities. The
142+ // / seed embeddings are the initial learned representations of the entities
143+ // / of LLVM IR. The IR2Vec representation for a given IR is derived from these
144+ // / seed embeddings.
145+ // /
146+ // / The vocabulary contains the seed embeddings for three types of entities:
147+ // / instruction opcodes, types, and operands. Types are grouped/canonicalized
148+ // / for better learning (e.g., all float variants map to FloatTy). The
149+ // / vocabulary abstracts away the canonicalization; effectively, the exposed
150+ // / APIs handle all the known LLVM IR opcodes, types, and operands.
151+ // /
152+ // / This class helps populate the seed embeddings in an internal vector-based
153+ // / ADT. It provides logic to map every IR entity to a specific slot index or
154+ // / position in this vector, enabling O(1) embedding lookup and avoiding
155+ // / unnecessary computations involving string-based lookups when generating the
156+ // / embeddings.
141157class Vocabulary {
142158 friend class llvm ::IR2VecVocabAnalysis;
143159 using VocabVector = std::vector<ir2vec::Embedding>;
144160 VocabVector Vocab;
145161 bool Valid = false ;
146162
163+ public:
164+ // Slot layout:
165+ // [0 .. MaxOpcodes-1] => Instruction opcodes
166+ // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
167+ // [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
168+
169+ // / Canonical type IDs supported by IR2Vec Vocabulary. Several Type::TypeID
// / values can share one canonical ID (TypeIDMapping sends all seven
// / floating-point TypeIDs to FloatTy, for instance). CanonicalTypeNames lists
// / its strings in this same enumerator order (presumably it is indexed by
// / these values -- confirm in the implementation), so keep the two in sync.
170+ enum class CanonicalTypeID : unsigned {
171+ FloatTy,
172+ VoidTy,
173+ LabelTy,
174+ MetadataTy,
175+ VectorTy,
176+ TokenTy,
177+ IntegerTy,
178+ FunctionTy,
179+ PointerTy,
180+ StructTy,
181+ ArrayTy,
182+ UnknownTy, // Used by TypeIDMapping for TargetExtTyID.
183+ MaxCanonicalType // Count sentinel (feeds MaxCanonicalTypeIDs); not a type.
184+ };
185+
147186 // / Operand kinds supported by IR2Vec Vocabulary
148187 enum class OperandKind : unsigned {
149188 FunctionID,
@@ -152,20 +191,15 @@ class Vocabulary {
152191 VariableID,
153192 MaxOperandKind
154193 };
155- // / String mappings for OperandKind values
156- static constexpr StringLiteral OperandKindNames[] = {" Function" , " Pointer" ,
157- " Constant" , " Variable" };
158- static_assert (std::size(OperandKindNames) ==
159- static_cast <unsigned >(OperandKind::MaxOperandKind),
160- " OperandKindNames array size must match MaxOperandKind" );
161194
162- public:
163195 // / Vocabulary layout constants
// / MaxOpcodes takes the value that Instruction.def passes to LAST_OTHER_INST
// / (presumably the highest opcode number -- confirm against Instruction.def).
164196#define LAST_OTHER_INST (NUM ) static constexpr unsigned MaxOpcodes = NUM;
165197#include " llvm/IR/Instruction.def"
166198#undef LAST_OTHER_INST
167199
// / Number of Type::TypeID values. Assumes TargetExtTyID is the last
// / enumerator of Type::TypeID -- TODO confirm against llvm/IR/Type.h.
168200 static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1 ;
// / Number of canonical type slots, taken from the CanonicalTypeID sentinel.
201+ static constexpr unsigned MaxCanonicalTypeIDs =
202+ static_cast <unsigned >(CanonicalTypeID::MaxCanonicalType);
// / Number of operand-kind slots, taken from the OperandKind sentinel.
169203 static constexpr unsigned MaxOperandKinds =
170204 static_cast <unsigned >(OperandKind::MaxOperandKind);
171205
@@ -174,33 +208,31 @@ class Vocabulary {
174208
175209 LLVM_ABI bool isValid () const ;
176210 LLVM_ABI unsigned getDimension () const ;
177- LLVM_ABI size_t size () const ;
211+ // / Total number of entries (opcodes + canonicalized types + operand kinds)
212+ static constexpr size_t getCanonicalSize () { return NumCanonicalEntries; }
178213
179- static size_t expectedSize () {
180- return MaxOpcodes + MaxTypeIDs + MaxOperandKinds;
181- }
182-
183- // / Helper function to get vocabulary key for a given Opcode
214+ // / Function to get vocabulary key for a given Opcode
184215 LLVM_ABI static StringRef getVocabKeyForOpcode (unsigned Opcode);
185216
186- // / Helper function to get vocabulary key for a given TypeID
217+ // / Function to get vocabulary key for a given TypeID
187218 LLVM_ABI static StringRef getVocabKeyForTypeID (Type::TypeID TypeID);
188219
189- // / Helper function to get vocabulary key for a given OperandKind
220+ // / Function to get vocabulary key for a given OperandKind
190221 LLVM_ABI static StringRef getVocabKeyForOperandKind (OperandKind Kind);
191222
192- // / Helper function to classify an operand into OperandKind
223+ // / Function to classify an operand into OperandKind
193224 LLVM_ABI static OperandKind getOperandKind (const Value *Op);
194225
195- // / Helpers to return the IDs of a given Opcode, TypeID, or OperandKind
196- LLVM_ABI static unsigned getNumericID (unsigned Opcode);
197- LLVM_ABI static unsigned getNumericID (Type::TypeID TypeID);
198- LLVM_ABI static unsigned getNumericID (const Value *Op);
226+ // / Functions to return the slot index or position of a given Opcode, TypeID,
227+ // / or OperandKind in the vocabulary.
228+ LLVM_ABI static unsigned getSlotIndex (unsigned Opcode);
229+ LLVM_ABI static unsigned getSlotIndex (Type::TypeID TypeID);
230+ LLVM_ABI static unsigned getSlotIndex (const Value *Op);
199231
200232 // / Accessors to get the embedding for a given entity.
201233 LLVM_ABI const ir2vec::Embedding &operator [](unsigned Opcode) const ;
202234 LLVM_ABI const ir2vec::Embedding &operator [](Type::TypeID TypeId) const ;
203- LLVM_ABI const ir2vec::Embedding &operator [](const Value * Arg) const ;
235+ LLVM_ABI const ir2vec::Embedding &operator [](const Value & Arg) const ;
204236
205237 // / Const Iterator type aliases
206238 using const_iterator = VocabVector::const_iterator;
@@ -234,6 +266,61 @@ class Vocabulary {
234266
235267 LLVM_ABI bool invalidate (Module &M, const PreservedAnalyses &PA,
236268 ModuleAnalysisManager::Invalidator &Inv) const ;
269+
270+ private:
// / Total number of vocabulary slots: the opcode, canonical-type, and
// / operand-kind segments added together. Returned by getCanonicalSize().
271+ constexpr static unsigned NumCanonicalEntries =
272+ MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
273+
274+ // / String mappings for CanonicalTypeID values
// / Entries are listed in CanonicalTypeID enumerator order. The static_assert
// / below checks only the element count, so the ordering has to be kept in
// / sync with the enum by hand.
275+ static constexpr StringLiteral CanonicalTypeNames[] = {
276+ " FloatTy" , " VoidTy" , " LabelTy" , " MetadataTy" ,
277+ " VectorTy" , " TokenTy" , " IntegerTy" , " FunctionTy" ,
278+ " PointerTy" , " StructTy" , " ArrayTy" , " UnknownTy" };
279+ static_assert (std::size(CanonicalTypeNames) ==
280+ static_cast <unsigned >(CanonicalTypeID::MaxCanonicalType),
281+ " CanonicalTypeNames array size must match MaxCanonicalType" );
282+
283+ // / String mappings for OperandKind values
284+ static constexpr StringLiteral OperandKindNames[] = {" Function" , " Pointer" ,
285+ " Constant" , " Variable" };
286+ static_assert (std::size(OperandKindNames) ==
287+ static_cast <unsigned >(OperandKind::MaxOperandKind),
288+ " OperandKindNames array size must match MaxOperandKind" );
289+
290+ // / Every known TypeID defined in llvm/IR/Type.h is expected to have a
291+ // / corresponding mapping here in the same order as enum Type::TypeID.
// / NOTE(review): the array extent is spelled as MaxTypeIDs in the type, so
// / size() equals MaxTypeIDs by construction and the static_assert below can
// / never fail. If a TypeID is added without a new row here, the trailing
// / element is value-initialized -- i.e. CanonicalTypeID::FloatTy, the
// / zero-valued enumerator -- instead of being diagnosed. Deducing the extent
// / from the initializer list would make the check meaningful.
292+ static constexpr std::array<CanonicalTypeID, MaxTypeIDs> TypeIDMapping = {{
293+ CanonicalTypeID::FloatTy, // HalfTyID = 0
294+ CanonicalTypeID::FloatTy, // BFloatTyID
295+ CanonicalTypeID::FloatTy, // FloatTyID
296+ CanonicalTypeID::FloatTy, // DoubleTyID
297+ CanonicalTypeID::FloatTy, // X86_FP80TyID
298+ CanonicalTypeID::FloatTy, // FP128TyID
299+ CanonicalTypeID::FloatTy, // PPC_FP128TyID
300+ CanonicalTypeID::VoidTy, // VoidTyID
301+ CanonicalTypeID::LabelTy, // LabelTyID
302+ CanonicalTypeID::MetadataTy, // MetadataTyID
303+ CanonicalTypeID::VectorTy, // X86_AMXTyID
304+ CanonicalTypeID::TokenTy, // TokenTyID
305+ CanonicalTypeID::IntegerTy, // IntegerTyID
306+ CanonicalTypeID::FunctionTy, // FunctionTyID
307+ CanonicalTypeID::PointerTy, // PointerTyID
308+ CanonicalTypeID::StructTy, // StructTyID
309+ CanonicalTypeID::ArrayTy, // ArrayTyID
310+ CanonicalTypeID::VectorTy, // FixedVectorTyID
311+ CanonicalTypeID::VectorTy, // ScalableVectorTyID
312+ CanonicalTypeID::PointerTy, // TypedPointerTyID
313+ CanonicalTypeID::UnknownTy // TargetExtTyID
314+ }};
315+ static_assert (TypeIDMapping.size() == MaxTypeIDs,
316+ " TypeIDMapping must cover all Type::TypeID values" );
317+
318+ // / Function to get vocabulary key for canonical type by enum
319+ LLVM_ABI static StringRef
320+ getVocabKeyForCanonicalTypeID (CanonicalTypeID CType);
321+
322+ // / Function to convert TypeID to CanonicalTypeID
323+ LLVM_ABI static CanonicalTypeID getCanonicalTypeID (Type::TypeID TypeID);
237324};
238325
239326// / Embedder provides the interface to generate embeddings (vector
@@ -262,11 +349,11 @@ class Embedder {
262349
263350 LLVM_ABI Embedder (const Function &F, const Vocabulary &Vocab);
264351
265- // / Helper function to compute embeddings. It generates embeddings for all
352+ // / Function to compute embeddings. It generates embeddings for all
266353 // / the instructions and basic blocks in the function F.
267354 void computeEmbeddings () const ;
268355
269- // / Helper function to compute the embedding for a given basic block.
356+ // / Function to compute the embedding for a given basic block.
270357 // / Specific to the kind of embeddings being computed.
271358 virtual void computeEmbeddings (const BasicBlock &BB) const = 0;
272359
0 commit comments