Skip to content

Commit a466c28

Browse files
committed
[NFC][IR2Vec] Minor refactoring of opcode access in vocabulary
1 parent d1db176 commit a466c28

File tree

2 files changed

+28
-22
lines changed

2 files changed

+28
-22
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,15 +162,18 @@ class Vocabulary {
162162
static constexpr unsigned MaxOperandKinds =
163163
static_cast<unsigned>(OperandKind::MaxOperandKind);
164164

165+
/// Helper function to get vocabulary key for a given Opcode
166+
static StringRef getVocabKeyForOpcode(unsigned Opcode);
167+
168+
/// Helper function to get vocabulary key for a given TypeID
169+
static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
170+
165171
/// Helper function to get vocabulary key for a given OperandKind
166172
static StringRef getVocabKeyForOperandKind(OperandKind Kind);
167173

168174
/// Helper function to classify an operand into OperandKind
169175
static OperandKind getOperandKind(const Value *Op);
170176

171-
/// Helper function to get vocabulary key for a given TypeID
172-
static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
173-
174177
public:
175178
Vocabulary() = default;
176179
Vocabulary(VocabVector &&Vocab);

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,17 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
243243
return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)];
244244
}
245245

246+
StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
247+
assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
248+
#define HANDLE_INST(NUM, OPCODE, CLASS) \
249+
if (Opcode == NUM) { \
250+
return #OPCODE; \
251+
}
252+
#include "llvm/IR/Instruction.def"
253+
#undef HANDLE_INST
254+
return "UnknownOpcode";
255+
}
256+
246257
StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
247258
switch (TypeID) {
248259
case Type::VoidTyID:
@@ -279,6 +290,7 @@ StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
279290
case Type::TargetExtTyID:
280291
return "UnknownTy";
281292
}
293+
return "UnknownTy";
282294
}
283295

284296
StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
@@ -315,14 +327,8 @@ StringRef Vocabulary::getStringKey(unsigned Pos) {
315327
assert(Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&
316328
"Position out of bounds in vocabulary");
317329
// Opcode
318-
if (Pos < MaxOpcodes) {
319-
#define HANDLE_INST(NUM, OPCODE, CLASS) \
320-
if (Pos == NUM - 1) { \
321-
return #OPCODE; \
322-
}
323-
#include "llvm/IR/Instruction.def"
324-
#undef HANDLE_INST
325-
}
330+
if (Pos < MaxOpcodes)
331+
return getVocabKeyForOpcode(Pos + 1);
326332
// Type
327333
if (Pos < MaxOpcodes + MaxTypeIDs)
328334
return getVocabKeyForTypeID(static_cast<Type::TypeID>(Pos - MaxOpcodes));
@@ -430,21 +436,18 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
430436
// Handle Opcodes
431437
std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
432438
Embedding(Dim, 0));
433-
#define HANDLE_INST(NUM, OPCODE, CLASS) \
434-
{ \
435-
auto It = OpcVocab.find(#OPCODE); \
436-
if (It != OpcVocab.end()) \
437-
NumericOpcodeEmbeddings[NUM - 1] = It->second; \
438-
else \
439-
handleMissingEntity(#OPCODE); \
439+
for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
440+
StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
441+
auto It = OpcVocab.find(VocabKey.str());
442+
if (It != OpcVocab.end())
443+
NumericOpcodeEmbeddings[Opcode] = It->second;
444+
else
445+
handleMissingEntity(VocabKey.str());
440446
}
441-
#include "llvm/IR/Instruction.def"
442-
#undef HANDLE_INST
443447
Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(),
444448
NumericOpcodeEmbeddings.end());
445449

446-
// Handle Types using direct iteration through TypeID enum
447-
// We iterate through all possible TypeID values and map them to embeddings
450+
// Handle Types
448451
std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxTypeIDs,
449452
Embedding(Dim, 0));
450453
for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) {

0 commit comments

Comments
 (0)