Skip to content

Commit 5eaecce

Browse files
committed
[NFC][IR2Vec] Minor refactoring of opcode access in vocabulary
1 parent 237e4d2 commit 5eaecce

File tree

2 files changed

+28
-22
lines changed

2 files changed

+28
-22
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -162,15 +162,18 @@ class Vocabulary {
162162
static constexpr unsigned MaxOperandKinds =
163163
static_cast<unsigned>(OperandKind::MaxOperandKind);
164164

165+
/// Helper function to get vocabulary key for a given Opcode
166+
static StringRef getVocabKeyForOpcode(unsigned Opcode);
167+
168+
/// Helper function to get vocabulary key for a given TypeID
169+
static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
170+
165171
/// Helper function to get vocabulary key for a given OperandKind
166172
static StringRef getVocabKeyForOperandKind(OperandKind Kind);
167173

168174
/// Helper function to classify an operand into OperandKind
169175
static OperandKind getOperandKind(const Value *Op);
170176

171-
/// Helper function to get vocabulary key for a given TypeID
172-
static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
173-
174177
public:
175178
Vocabulary() = default;
176179
Vocabulary(VocabVector &&Vocab);

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 22 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -243,6 +243,17 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
243243
return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)];
244244
}
245245

246+
/// Returns the vocabulary key for the instruction opcode \p Opcode.
///
/// The key is the stringified OPCODE token of the matching HANDLE_INST entry
/// in llvm/IR/Instruction.def. \p Opcode must lie in [1, MaxOpcodes]; this is
/// checked by an assertion. If no entry matches, "UnknownOpcode" is returned.
StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
247+
// Valid opcodes are in the inclusive range [1, MaxOpcodes].
assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
248+
// Each HANDLE_INST entry pulled in by the #include below expands to one
// comparison of Opcode against NUM that returns the stringified OPCODE.
#define HANDLE_INST(NUM, OPCODE, CLASS) \
249+
if (Opcode == NUM) { \
250+
return #OPCODE; \
251+
}
252+
#include "llvm/IR/Instruction.def"
253+
#undef HANDLE_INST
254+
// Reached only when no HANDLE_INST entry matched Opcode.
return "UnknownOpcode";
255+
}
256+
246257
StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
247258
switch (TypeID) {
248259
case Type::VoidTyID:
@@ -280,6 +291,7 @@ StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
280291
default:
281292
return "UnknownTy";
282293
}
294+
return "UnknownTy";
283295
}
284296

285297
StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
@@ -316,14 +328,8 @@ StringRef Vocabulary::getStringKey(unsigned Pos) {
316328
assert(Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&
317329
"Position out of bounds in vocabulary");
318330
// Opcode
319-
if (Pos < MaxOpcodes) {
320-
#define HANDLE_INST(NUM, OPCODE, CLASS) \
321-
if (Pos == NUM - 1) { \
322-
return #OPCODE; \
323-
}
324-
#include "llvm/IR/Instruction.def"
325-
#undef HANDLE_INST
326-
}
331+
if (Pos < MaxOpcodes)
332+
return getVocabKeyForOpcode(Pos + 1);
327333
// Type
328334
if (Pos < MaxOpcodes + MaxTypeIDs)
329335
return getVocabKeyForTypeID(static_cast<Type::TypeID>(Pos - MaxOpcodes));
@@ -431,21 +437,18 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
431437
// Handle Opcodes
432438
std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
433439
Embedding(Dim, 0));
434-
#define HANDLE_INST(NUM, OPCODE, CLASS) \
435-
{ \
436-
auto It = OpcVocab.find(#OPCODE); \
437-
if (It != OpcVocab.end()) \
438-
NumericOpcodeEmbeddings[NUM - 1] = It->second; \
439-
else \
440-
handleMissingEntity(#OPCODE); \
440+
for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
441+
StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
442+
auto It = OpcVocab.find(VocabKey.str());
443+
if (It != OpcVocab.end())
444+
NumericOpcodeEmbeddings[Opcode] = It->second;
445+
else
446+
handleMissingEntity(VocabKey.str());
441447
}
442-
#include "llvm/IR/Instruction.def"
443-
#undef HANDLE_INST
444448
Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(),
445449
NumericOpcodeEmbeddings.end());
446450

447-
// Handle Types using direct iteration through TypeID enum
448-
// We iterate through all possible TypeID values and map them to embeddings
451+
// Handle Types
449452
std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxTypeIDs,
450453
Embedding(Dim, 0));
451454
for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) {

0 commit comments

Comments
 (0)