@@ -243,6 +243,17 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
243
243
return Vocab[MaxOpcodes + MaxTypeIDs + static_cast <unsigned >(ArgKind)];
244
244
}
245
245
246
+ StringRef Vocabulary::getVocabKeyForOpcode (unsigned Opcode) {
247
+ assert (Opcode >= 1 && Opcode <= MaxOpcodes && " Invalid opcode" );
248
+ #define HANDLE_INST (NUM, OPCODE, CLASS ) \
249
+ if (Opcode == NUM) { \
250
+ return #OPCODE; \
251
+ }
252
+ #include " llvm/IR/Instruction.def"
253
+ #undef HANDLE_INST
254
+ return " UnknownOpcode" ;
255
+ }
256
+
246
257
StringRef Vocabulary::getVocabKeyForTypeID (Type::TypeID TypeID) {
247
258
switch (TypeID) {
248
259
case Type::VoidTyID:
@@ -279,6 +290,7 @@ StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
279
290
case Type::TargetExtTyID:
280
291
return " UnknownTy" ;
281
292
}
293
+ return " UnknownTy" ;
282
294
}
283
295
284
296
StringRef Vocabulary::getVocabKeyForOperandKind (Vocabulary::OperandKind Kind) {
@@ -315,14 +327,8 @@ StringRef Vocabulary::getStringKey(unsigned Pos) {
315
327
assert (Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&
316
328
" Position out of bounds in vocabulary" );
317
329
// Opcode
318
- if (Pos < MaxOpcodes) {
319
- #define HANDLE_INST (NUM, OPCODE, CLASS ) \
320
- if (Pos == NUM - 1 ) { \
321
- return #OPCODE; \
322
- }
323
- #include " llvm/IR/Instruction.def"
324
- #undef HANDLE_INST
325
- }
330
+ if (Pos < MaxOpcodes)
331
+ return getVocabKeyForOpcode (Pos + 1 );
326
332
// Type
327
333
if (Pos < MaxOpcodes + MaxTypeIDs)
328
334
return getVocabKeyForTypeID (static_cast <Type::TypeID>(Pos - MaxOpcodes));
@@ -430,21 +436,18 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
430
436
// Handle Opcodes
431
437
std::vector<Embedding> NumericOpcodeEmbeddings (Vocabulary::MaxOpcodes,
432
438
Embedding (Dim, 0 ));
433
- # define HANDLE_INST ( NUM, OPCODE, CLASS ) \
434
- { \
435
- auto It = OpcVocab.find (#OPCODE); \
436
- if (It != OpcVocab.end ()) \
437
- NumericOpcodeEmbeddings[NUM - 1 ] = It->second ; \
438
- else \
439
- handleMissingEntity (#OPCODE); \
439
+ for ( unsigned Opcode : seq ( 0u , Vocabulary::MaxOpcodes)) {
440
+ StringRef VocabKey = Vocabulary::getVocabKeyForOpcode (Opcode + 1 );
441
+ auto It = OpcVocab.find (VocabKey. str ());
442
+ if (It != OpcVocab.end ())
443
+ NumericOpcodeEmbeddings[Opcode ] = It->second ;
444
+ else
445
+ handleMissingEntity (VocabKey. str ());
440
446
}
441
- #include " llvm/IR/Instruction.def"
442
- #undef HANDLE_INST
443
447
Vocab.insert (Vocab.end (), NumericOpcodeEmbeddings.begin (),
444
448
NumericOpcodeEmbeddings.end ());
445
449
446
- // Handle Types using direct iteration through TypeID enum
447
- // We iterate through all possible TypeID values and map them to embeddings
450
+ // Handle Types
448
451
std::vector<Embedding> NumericTypeEmbeddings (Vocabulary::MaxTypeIDs,
449
452
Embedding (Dim, 0 ));
450
453
for (unsigned TypeID : seq (0u , Vocabulary::MaxTypeIDs)) {
0 commit comments