@@ -243,6 +243,17 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
243
243
return Vocab[MaxOpcodes + MaxTypeIDs + static_cast <unsigned >(ArgKind)];
244
244
}
245
245
246
+ StringRef Vocabulary::getVocabKeyForOpcode (unsigned Opcode) {
247
+ assert (Opcode >= 1 && Opcode <= MaxOpcodes && " Invalid opcode" );
248
+ #define HANDLE_INST (NUM, OPCODE, CLASS ) \
249
+ if (Opcode == NUM) { \
250
+ return #OPCODE; \
251
+ }
252
+ #include " llvm/IR/Instruction.def"
253
+ #undef HANDLE_INST
254
+ return " UnknownOpcode" ;
255
+ }
256
+
246
257
StringRef Vocabulary::getVocabKeyForTypeID (Type::TypeID TypeID) {
247
258
switch (TypeID) {
248
259
case Type::VoidTyID:
@@ -280,6 +291,7 @@ StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
280
291
default :
281
292
return " UnknownTy" ;
282
293
}
294
+ return " UnknownTy" ;
283
295
}
284
296
285
297
StringRef Vocabulary::getVocabKeyForOperandKind (Vocabulary::OperandKind Kind) {
@@ -316,14 +328,8 @@ StringRef Vocabulary::getStringKey(unsigned Pos) {
316
328
assert (Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&
317
329
" Position out of bounds in vocabulary" );
318
330
// Opcode
319
- if (Pos < MaxOpcodes) {
320
- #define HANDLE_INST (NUM, OPCODE, CLASS ) \
321
- if (Pos == NUM - 1 ) { \
322
- return #OPCODE; \
323
- }
324
- #include " llvm/IR/Instruction.def"
325
- #undef HANDLE_INST
326
- }
331
+ if (Pos < MaxOpcodes)
332
+ return getVocabKeyForOpcode (Pos + 1 );
327
333
// Type
328
334
if (Pos < MaxOpcodes + MaxTypeIDs)
329
335
return getVocabKeyForTypeID (static_cast <Type::TypeID>(Pos - MaxOpcodes));
@@ -431,21 +437,18 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
431
437
// Handle Opcodes
432
438
std::vector<Embedding> NumericOpcodeEmbeddings (Vocabulary::MaxOpcodes,
433
439
Embedding (Dim, 0 ));
434
- # define HANDLE_INST ( NUM, OPCODE, CLASS ) \
435
- { \
436
- auto It = OpcVocab.find (#OPCODE); \
437
- if (It != OpcVocab.end ()) \
438
- NumericOpcodeEmbeddings[NUM - 1 ] = It->second ; \
439
- else \
440
- handleMissingEntity (#OPCODE); \
440
+ for ( unsigned Opcode : seq ( 0u , Vocabulary::MaxOpcodes)) {
441
+ StringRef VocabKey = Vocabulary::getVocabKeyForOpcode (Opcode + 1 );
442
+ auto It = OpcVocab.find (VocabKey. str ());
443
+ if (It != OpcVocab.end ())
444
+ NumericOpcodeEmbeddings[Opcode ] = It->second ;
445
+ else
446
+ handleMissingEntity (VocabKey. str ());
441
447
}
442
- #include " llvm/IR/Instruction.def"
443
- #undef HANDLE_INST
444
448
Vocab.insert (Vocab.end (), NumericOpcodeEmbeddings.begin (),
445
449
NumericOpcodeEmbeddings.end ());
446
450
447
- // Handle Types using direct iteration through TypeID enum
448
- // We iterate through all possible TypeID values and map them to embeddings
451
+ // Handle Types
449
452
std::vector<Embedding> NumericTypeEmbeddings (Vocabulary::MaxTypeIDs,
450
453
Embedding (Dim, 0 ));
451
454
for (unsigned TypeID : seq (0u , Vocabulary::MaxTypeIDs)) {
0 commit comments