88// /
99// / \file
1010// / This file defines the MIR2Vec vocabulary
11- // / analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::Embedder interface
12- // / for generating Machine IR embeddings, and related utilities.
11+ // / analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::MIREmbedder
12+ // / interface for generating Machine IR embeddings, and related utilities.
1313// /
1414// / MIR2Vec extends IR2Vec to support Machine IR embeddings. It represents the
1515// / LLVM Machine IR as embeddings which can be used as input to machine learning
@@ -71,25 +71,31 @@ class MIRVocabulary {
7171 size_t TotalEntries = 0 ;
7272 } Layout;
7373
74+ enum class Section : unsigned { Opcodes = 0 , MaxSections };
75+
7476 ir2vec::VocabStorage Storage;
7577 mutable std::set<std::string> UniqueBaseOpcodeNames;
76- void generateStorage (const VocabMap &OpcodeMap, const TargetInstrInfo &TII);
77- void buildCanonicalOpcodeMapping (const TargetInstrInfo &TII);
78+ const TargetInstrInfo &TII;
79+ void generateStorage (const VocabMap &OpcodeMap);
80+ void buildCanonicalOpcodeMapping ();
81+
82+ // / Get canonical index for a machine opcode
83+ unsigned getCanonicalOpcodeIndex (unsigned Opcode) const ;
7884
7985public:
80- // / Static helper method for extracting base opcode names (public for testing)
86+ // / Static method for extracting base opcode names (public for testing)
8187 static std::string extractBaseOpcodeName (StringRef InstrName);
8288
83- // / Helper method for getting canonical index for base name (public for
84- // / testing)
89+ // / Get canonical index for base name (public for testing)
8590 unsigned getCanonicalIndexForBaseName (StringRef BaseName) const ;
8691
8792 // / Get the string key for a vocabulary entry at the given position
8893 std::string getStringKey (unsigned Pos) const ;
8994
90- MIRVocabulary () = default ;
95+ MIRVocabulary () = delete ;
9196 MIRVocabulary (VocabMap &&Entries, const TargetInstrInfo *TII);
92- MIRVocabulary (ir2vec::VocabStorage &&Storage) : Storage(std::move(Storage)) {}
97+ MIRVocabulary (ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
98+ : Storage(std::move(Storage)), TII(TII) {}
9399
94100 bool isValid () const {
95101 return UniqueBaseOpcodeNames.size () > 0 &&
@@ -103,11 +109,10 @@ class MIRVocabulary {
103109 }
104110
105111 // Accessor methods
106- const Embedding &operator [](unsigned Index ) const {
112+ const Embedding &operator [](unsigned Opcode ) const {
107113 assert (isValid () && " MIR2Vec Vocabulary is invalid" );
108- assert (Index < Layout.TotalEntries && " Index out of bounds" );
109- // Fixme: For now, use section 0 for all entries
110- return Storage[0 ][Index];
114+ unsigned LocalIndex = getCanonicalOpcodeIndex (Opcode);
115+ return Storage[static_cast <unsigned >(Section::Opcodes)][LocalIndex];
111116 }
112117
113118 // Iterator access
0 commit comments