Skip to content

Commit 45c5498

Browse files
authored
[IR2Vec] Refactor vocabulary to use canonical type IDs (#155323)
Refactor the IR2Vec vocabulary to use canonical type IDs, improving the embedding representation for LLVM IR types. The previous implementation used raw Type::TypeID values directly in the vocabulary, which led to redundant entries (e.g., all float variants mapped to "FloatTy" but had separate slots).

This change improves the vocabulary by:
1. Making the type representation more consistent by properly canonicalizing types
2. Reducing vocabulary size by eliminating redundant entries
3. Improving the embedding quality by ensuring similar types share the same representation

(Tracking issue - #141817)
1 parent e317c7e commit 45c5498

File tree

9 files changed

+350
-308
lines changed

9 files changed

+350
-308
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 111 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "llvm/Support/Compiler.h"
3737
#include "llvm/Support/ErrorOr.h"
3838
#include "llvm/Support/JSON.h"
39+
#include <array>
3940
#include <map>
4041

4142
namespace llvm {
@@ -137,13 +138,51 @@ using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
137138
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
138139

139140
/// Class for storing and accessing the IR2Vec vocabulary.
140-
/// Encapsulates all vocabulary-related constants, logic, and access methods.
141+
/// The Vocabulary class manages seed embeddings for LLVM IR entities. The
142+
/// seed embeddings are the initial learned representations of the entities
143+
/// of LLVM IR. The IR2Vec representation for a given IR is derived from these
144+
/// seed embeddings.
145+
///
146+
/// The vocabulary contains the seed embeddings for three types of entities:
147+
/// instruction opcodes, types, and operands. Types are grouped/canonicalized
148+
/// for better learning (e.g., all float variants map to FloatTy). The
149+
/// vocabulary abstracts away the canonicalization effectively, the exposed APIs
150+
/// handle all the known LLVM IR opcodes, types and operands.
151+
///
152+
/// This class helps populate the seed embeddings in an internal vector-based
153+
/// ADT. It provides logic to map every IR entity to a specific slot index or
154+
/// position in this vector, enabling O(1) embedding lookup while avoiding
155+
/// unnecessary computations involving string based lookups while generating the
156+
/// embeddings.
141157
class Vocabulary {
142158
friend class llvm::IR2VecVocabAnalysis;
143159
using VocabVector = std::vector<ir2vec::Embedding>;
144160
VocabVector Vocab;
145161
bool Valid = false;
146162

163+
public:
164+
// Slot layout:
165+
// [0 .. MaxOpcodes-1] => Instruction opcodes
166+
// [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
167+
// [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
168+
169+
/// Canonical type IDs supported by IR2Vec Vocabulary
170+
enum class CanonicalTypeID : unsigned {
171+
FloatTy,
172+
VoidTy,
173+
LabelTy,
174+
MetadataTy,
175+
VectorTy,
176+
TokenTy,
177+
IntegerTy,
178+
FunctionTy,
179+
PointerTy,
180+
StructTy,
181+
ArrayTy,
182+
UnknownTy,
183+
MaxCanonicalType
184+
};
185+
147186
/// Operand kinds supported by IR2Vec Vocabulary
148187
enum class OperandKind : unsigned {
149188
FunctionID,
@@ -152,20 +191,15 @@ class Vocabulary {
152191
VariableID,
153192
MaxOperandKind
154193
};
155-
/// String mappings for OperandKind values
156-
static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",
157-
"Constant", "Variable"};
158-
static_assert(std::size(OperandKindNames) ==
159-
static_cast<unsigned>(OperandKind::MaxOperandKind),
160-
"OperandKindNames array size must match MaxOperandKind");
161194

162-
public:
163195
/// Vocabulary layout constants
164196
#define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM;
165197
#include "llvm/IR/Instruction.def"
166198
#undef LAST_OTHER_INST
167199

168200
static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1;
201+
static constexpr unsigned MaxCanonicalTypeIDs =
202+
static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
169203
static constexpr unsigned MaxOperandKinds =
170204
static_cast<unsigned>(OperandKind::MaxOperandKind);
171205

@@ -174,33 +208,31 @@ class Vocabulary {
174208

175209
LLVM_ABI bool isValid() const;
176210
LLVM_ABI unsigned getDimension() const;
177-
LLVM_ABI size_t size() const;
211+
/// Total number of entries (opcodes + canonicalized types + operand kinds)
212+
static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
178213

179-
static size_t expectedSize() {
180-
return MaxOpcodes + MaxTypeIDs + MaxOperandKinds;
181-
}
182-
183-
/// Helper function to get vocabulary key for a given Opcode
214+
/// Function to get vocabulary key for a given Opcode
184215
LLVM_ABI static StringRef getVocabKeyForOpcode(unsigned Opcode);
185216

186-
/// Helper function to get vocabulary key for a given TypeID
217+
/// Function to get vocabulary key for a given TypeID
187218
LLVM_ABI static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
188219

189-
/// Helper function to get vocabulary key for a given OperandKind
220+
/// Function to get vocabulary key for a given OperandKind
190221
LLVM_ABI static StringRef getVocabKeyForOperandKind(OperandKind Kind);
191222

192-
/// Helper function to classify an operand into OperandKind
223+
/// Function to classify an operand into OperandKind
193224
LLVM_ABI static OperandKind getOperandKind(const Value *Op);
194225

195-
/// Helpers to return the IDs of a given Opcode, TypeID, or OperandKind
196-
LLVM_ABI static unsigned getNumericID(unsigned Opcode);
197-
LLVM_ABI static unsigned getNumericID(Type::TypeID TypeID);
198-
LLVM_ABI static unsigned getNumericID(const Value *Op);
226+
/// Functions to return the slot index or position of a given Opcode, TypeID,
227+
/// or OperandKind in the vocabulary.
228+
LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
229+
LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
230+
LLVM_ABI static unsigned getSlotIndex(const Value *Op);
199231

200232
/// Accessors to get the embedding for a given entity.
201233
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
202234
LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
203-
LLVM_ABI const ir2vec::Embedding &operator[](const Value *Arg) const;
235+
LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const;
204236

205237
/// Const Iterator type aliases
206238
using const_iterator = VocabVector::const_iterator;
@@ -234,6 +266,61 @@ class Vocabulary {
234266

235267
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA,
236268
ModuleAnalysisManager::Invalidator &Inv) const;
269+
270+
private:
271+
constexpr static unsigned NumCanonicalEntries =
272+
MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
273+
274+
/// String mappings for CanonicalTypeID values
275+
static constexpr StringLiteral CanonicalTypeNames[] = {
276+
"FloatTy", "VoidTy", "LabelTy", "MetadataTy",
277+
"VectorTy", "TokenTy", "IntegerTy", "FunctionTy",
278+
"PointerTy", "StructTy", "ArrayTy", "UnknownTy"};
279+
static_assert(std::size(CanonicalTypeNames) ==
280+
static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType),
281+
"CanonicalTypeNames array size must match MaxCanonicalType");
282+
283+
/// String mappings for OperandKind values
284+
static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",
285+
"Constant", "Variable"};
286+
static_assert(std::size(OperandKindNames) ==
287+
static_cast<unsigned>(OperandKind::MaxOperandKind),
288+
"OperandKindNames array size must match MaxOperandKind");
289+
290+
/// Every known TypeID defined in llvm/IR/Type.h is expected to have a
291+
/// corresponding mapping here in the same order as enum Type::TypeID.
292+
static constexpr std::array<CanonicalTypeID, MaxTypeIDs> TypeIDMapping = {{
293+
CanonicalTypeID::FloatTy, // HalfTyID = 0
294+
CanonicalTypeID::FloatTy, // BFloatTyID
295+
CanonicalTypeID::FloatTy, // FloatTyID
296+
CanonicalTypeID::FloatTy, // DoubleTyID
297+
CanonicalTypeID::FloatTy, // X86_FP80TyID
298+
CanonicalTypeID::FloatTy, // FP128TyID
299+
CanonicalTypeID::FloatTy, // PPC_FP128TyID
300+
CanonicalTypeID::VoidTy, // VoidTyID
301+
CanonicalTypeID::LabelTy, // LabelTyID
302+
CanonicalTypeID::MetadataTy, // MetadataTyID
303+
CanonicalTypeID::VectorTy, // X86_AMXTyID
304+
CanonicalTypeID::TokenTy, // TokenTyID
305+
CanonicalTypeID::IntegerTy, // IntegerTyID
306+
CanonicalTypeID::FunctionTy, // FunctionTyID
307+
CanonicalTypeID::PointerTy, // PointerTyID
308+
CanonicalTypeID::StructTy, // StructTyID
309+
CanonicalTypeID::ArrayTy, // ArrayTyID
310+
CanonicalTypeID::VectorTy, // FixedVectorTyID
311+
CanonicalTypeID::VectorTy, // ScalableVectorTyID
312+
CanonicalTypeID::PointerTy, // TypedPointerTyID
313+
CanonicalTypeID::UnknownTy // TargetExtTyID
314+
}};
315+
static_assert(TypeIDMapping.size() == MaxTypeIDs,
316+
"TypeIDMapping must cover all Type::TypeID values");
317+
318+
/// Function to get vocabulary key for canonical type by enum
319+
LLVM_ABI static StringRef
320+
getVocabKeyForCanonicalTypeID(CanonicalTypeID CType);
321+
322+
/// Function to convert TypeID to CanonicalTypeID
323+
LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
237324
};
238325

239326
/// Embedder provides the interface to generate embeddings (vector
@@ -262,11 +349,11 @@ class Embedder {
262349

263350
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab);
264351

265-
/// Helper function to compute embeddings. It generates embeddings for all
352+
/// Function to compute embeddings. It generates embeddings for all
266353
/// the instructions and basic blocks in the function F.
267354
void computeEmbeddings() const;
268355

269-
/// Helper function to compute the embedding for a given basic block.
356+
/// Function to compute the embedding for a given basic block.
270357
/// Specific to the kind of embeddings being computed.
271358
virtual void computeEmbeddings(const BasicBlock &BB) const = 0;
272359

0 commit comments

Comments
 (0)