Skip to content

Commit 66ad8aa

Browse files
committed
Canonicalized type
1 parent 8ce3c84 commit 66ad8aa

File tree

9 files changed

+335
-297
lines changed

9 files changed

+335
-297
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 101 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "llvm/Support/Compiler.h"
3737
#include "llvm/Support/ErrorOr.h"
3838
#include "llvm/Support/JSON.h"
39+
#include <array>
3940
#include <map>
4041

4142
namespace llvm {
@@ -137,13 +138,48 @@ using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
137138
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
138139

139140
/// Class for storing and accessing the IR2Vec vocabulary.
140-
/// Encapsulates all vocabulary-related constants, logic, and access methods.
141+
///
142+
/// The Vocabulary class manages seed embeddings for LLVM IR entities. It
143+
/// contains the seed embeddings for three types of entities: instruction
144+
/// opcodes, types, and operands. Types are grouped/canonicalized for better
145+
/// learning (e.g., all float variants map to FloatTy). The vocabulary abstracts
146+
/// away the canonicalization effectively, the exposed APIs handle all the known
147+
/// LLVM IR opcodes, types and operands.
148+
///
149+
/// This class helps populate the seed embeddings in an internal vector-based
150+
/// ADT. It provides logic to map every IR entity to a specific slot index or
151+
/// position in this vector, enabling O(1) embedding lookup while avoiding
152+
/// unnecessary computations involving string based lookups while generating the
153+
/// embeddings.
141154
class Vocabulary {
142155
friend class llvm::IR2VecVocabAnalysis;
143156
using VocabVector = std::vector<ir2vec::Embedding>;
144157
VocabVector Vocab;
145158
bool Valid = false;
146159

160+
public:
161+
// Slot layout:
162+
// [0 .. MaxOpcodes-1] => Instruction opcodes
163+
// [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
164+
// [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
165+
166+
/// Canonical type IDs supported by IR2Vec Vocabulary
167+
enum class CanonicalTypeID : unsigned {
168+
FloatTy,
169+
VoidTy,
170+
LabelTy,
171+
MetadataTy,
172+
VectorTy,
173+
TokenTy,
174+
IntegerTy,
175+
FunctionTy,
176+
PointerTy,
177+
StructTy,
178+
ArrayTy,
179+
UnknownTy,
180+
MaxCanonicalType
181+
};
182+
147183
/// Operand kinds supported by IR2Vec Vocabulary
148184
enum class OperandKind : unsigned {
149185
FunctionID,
@@ -152,20 +188,15 @@ class Vocabulary {
152188
VariableID,
153189
MaxOperandKind
154190
};
155-
/// String mappings for OperandKind values
156-
static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",
157-
"Constant", "Variable"};
158-
static_assert(std::size(OperandKindNames) ==
159-
static_cast<unsigned>(OperandKind::MaxOperandKind),
160-
"OperandKindNames array size must match MaxOperandKind");
161191

162-
public:
163192
/// Vocabulary layout constants
164193
#define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM;
165194
#include "llvm/IR/Instruction.def"
166195
#undef LAST_OTHER_INST
167196

168197
static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1;
198+
static constexpr unsigned MaxCanonicalTypeIDs =
199+
static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
169200
static constexpr unsigned MaxOperandKinds =
170201
static_cast<unsigned>(OperandKind::MaxOperandKind);
171202

@@ -174,11 +205,8 @@ class Vocabulary {
174205

175206
LLVM_ABI bool isValid() const;
176207
LLVM_ABI unsigned getDimension() const;
177-
LLVM_ABI size_t size() const;
178-
179-
static size_t expectedSize() {
180-
return MaxOpcodes + MaxTypeIDs + MaxOperandKinds;
181-
}
208+
/// Total number of entries (opcodes + canonicalized types + operand kinds)
209+
static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
182210

183211
/// Helper function to get vocabulary key for a given Opcode
184212
LLVM_ABI static StringRef getVocabKeyForOpcode(unsigned Opcode);
@@ -192,10 +220,11 @@ class Vocabulary {
192220
/// Helper function to classify an operand into OperandKind
193221
LLVM_ABI static OperandKind getOperandKind(const Value *Op);
194222

195-
/// Helpers to return the IDs of a given Opcode, TypeID, or OperandKind
196-
LLVM_ABI static unsigned getNumericID(unsigned Opcode);
197-
LLVM_ABI static unsigned getNumericID(Type::TypeID TypeID);
198-
LLVM_ABI static unsigned getNumericID(const Value *Op);
223+
/// Helpers to return the slot index or position of a given Opcode, TypeID, or
224+
/// OperandKind in the vocabulary.
225+
LLVM_ABI static unsigned getSlotIdx(unsigned Opcode);
226+
LLVM_ABI static unsigned getSlotIdx(Type::TypeID TypeID);
227+
LLVM_ABI static unsigned getSlotIdx(const Value *Op);
199228

200229
/// Accessors to get the embedding for a given entity.
201230
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
@@ -234,6 +263,61 @@ class Vocabulary {
234263

235264
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA,
236265
ModuleAnalysisManager::Invalidator &Inv) const;
266+
267+
private:
268+
constexpr static unsigned NumCanonicalEntries =
269+
MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
270+
271+
/// String mappings for CanonicalTypeID values
272+
static constexpr StringLiteral CanonicalTypeNames[] = {
273+
"FloatTy", "VoidTy", "LabelTy", "MetadataTy",
274+
"VectorTy", "TokenTy", "IntegerTy", "FunctionTy",
275+
"PointerTy", "StructTy", "ArrayTy", "UnknownTy"};
276+
static_assert(std::size(CanonicalTypeNames) ==
277+
static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType),
278+
"CanonicalTypeNames array size must match MaxCanonicalType");
279+
280+
/// String mappings for OperandKind values
281+
static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",
282+
"Constant", "Variable"};
283+
static_assert(std::size(OperandKindNames) ==
284+
static_cast<unsigned>(OperandKind::MaxOperandKind),
285+
"OperandKindNames array size must match MaxOperandKind");
286+
287+
/// Every known TypeID defined in llvm/IR/Type.h is expected to have a
288+
/// corresponding mapping here in the same order as enum Type::TypeID.
289+
static constexpr std::array<CanonicalTypeID, MaxTypeIDs> TypeIDMapping = {{
290+
CanonicalTypeID::FloatTy, // HalfTyID = 0
291+
CanonicalTypeID::FloatTy, // BFloatTyID
292+
CanonicalTypeID::FloatTy, // FloatTyID
293+
CanonicalTypeID::FloatTy, // DoubleTyID
294+
CanonicalTypeID::FloatTy, // X86_FP80TyID
295+
CanonicalTypeID::FloatTy, // FP128TyID
296+
CanonicalTypeID::FloatTy, // PPC_FP128TyID
297+
CanonicalTypeID::VoidTy, // VoidTyID
298+
CanonicalTypeID::LabelTy, // LabelTyID
299+
CanonicalTypeID::MetadataTy, // MetadataTyID
300+
CanonicalTypeID::VectorTy, // X86_AMXTyID
301+
CanonicalTypeID::TokenTy, // TokenTyID
302+
CanonicalTypeID::IntegerTy, // IntegerTyID
303+
CanonicalTypeID::FunctionTy, // FunctionTyID
304+
CanonicalTypeID::PointerTy, // PointerTyID
305+
CanonicalTypeID::StructTy, // StructTyID
306+
CanonicalTypeID::ArrayTy, // ArrayTyID
307+
CanonicalTypeID::VectorTy, // FixedVectorTyID
308+
CanonicalTypeID::VectorTy, // ScalableVectorTyID
309+
CanonicalTypeID::PointerTy, // TypedPointerTyID
310+
CanonicalTypeID::UnknownTy // TargetExtTyID
311+
}};
312+
static_assert(TypeIDMapping.size() == MaxTypeIDs,
313+
"TypeIDMapping must cover all Type::TypeID values");
314+
315+
/// Helper function to get vocabulary key for canonical type by enum
316+
LLVM_ABI static StringRef
317+
getVocabKeyForCanonicalTypeID(CanonicalTypeID CType);
318+
319+
/// Helper function to convert TypeID to CanonicalTypeID
320+
LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
237321
};
238322

239323
/// Embedder provides the interface to generate embeddings (vector

0 commit comments

Comments
 (0)