Skip to content

Commit 0ebc739

Browse files
committed
MIRVocabulary changes
1 parent 747fe0f commit 0ebc739

File tree

3 files changed

+52
-25
lines changed

3 files changed

+52
-25
lines changed

llvm/include/llvm/CodeGen/MIR2Vec.h

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
///
99
/// \file
1010
/// This file defines the MIR2Vec vocabulary
11-
/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::Embedder interface
12-
/// for generating Machine IR embeddings, and related utilities.
11+
/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::MIREmbedder
12+
/// interface for generating Machine IR embeddings, and related utilities.
1313
///
1414
/// MIR2Vec extends IR2Vec to support Machine IR embeddings. It represents the
1515
/// LLVM Machine IR as embeddings which can be used as input to machine learning
@@ -71,25 +71,31 @@ class MIRVocabulary {
7171
unsigned TotalEntries = 0;
7272
} Layout;
7373

74+
enum class Section : unsigned { Opcodes = 0, MaxSections };
75+
7476
ir2vec::VocabStorage Storage;
7577
mutable std::set<std::string> UniqueBaseOpcodeNames;
76-
void generateStorage(const VocabMap &OpcodeMap, const TargetInstrInfo &TII);
77-
void buildCanonicalOpcodeMapping(const TargetInstrInfo &TII);
78+
const TargetInstrInfo &TII;
79+
void generateStorage(const VocabMap &OpcodeMap);
80+
void buildCanonicalOpcodeMapping();
81+
82+
/// Get canonical index for a machine opcode
83+
unsigned getCanonicalOpcodeIndex(unsigned Opcode) const;
7884

7985
public:
80-
/// Static helper method for extracting base opcode names (public for testing)
86+
/// Static method for extracting base opcode names (public for testing)
8187
static std::string extractBaseOpcodeName(StringRef InstrName);
8288

83-
/// Helper method for getting canonical index for base name (public for
84-
/// testing)
89+
/// Get canonical index for base name (public for testing)
8590
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const;
8691

8792
/// Get the string key for a vocabulary entry at the given position
8893
std::string getStringKey(unsigned Pos) const;
8994

90-
MIRVocabulary() = default;
95+
MIRVocabulary() = delete;
9196
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII);
92-
MIRVocabulary(ir2vec::VocabStorage &&Storage) : Storage(std::move(Storage)) {}
97+
MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
98+
: Storage(std::move(Storage)), TII(TII) {}
9399

94100
bool isValid() const {
95101
return UniqueBaseOpcodeNames.size() > 0 &&
@@ -103,11 +109,10 @@ class MIRVocabulary {
103109
}
104110

105111
// Accessor methods
106-
const Embedding &operator[](unsigned Index) const {
112+
const Embedding &operator[](unsigned Opcode) const {
107113
assert(isValid() && "MIR2Vec Vocabulary is invalid");
108-
assert(Index < Layout.TotalEntries && "Index out of bounds");
109-
// Fixme: For now, use section 0 for all entries
110-
return Storage[0][Index];
114+
unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode);
115+
return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];
111116
}
112117

113118
// Iterator access

llvm/lib/CodeGen/MIR2Vec.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,19 @@ cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
4949
//===----------------------------------------------------------------------===//
5050

5151
MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
52-
const TargetInstrInfo *TII) {
52+
const TargetInstrInfo *TII)
53+
: TII(*TII) {
5354
// Early return for invalid inputs - creates empty/invalid vocabulary
5455
if (!TII || OpcodeEntries.empty())
5556
return;
5657

57-
buildCanonicalOpcodeMapping(*TII);
58+
buildCanonicalOpcodeMapping();
5859

5960
unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
6061
assert(CanonicalOpcodeCount > 0 &&
6162
"No canonical opcodes found for target - invalid vocabulary");
6263
Layout.OperandBase = CanonicalOpcodeCount;
63-
generateStorage(OpcodeEntries, *TII);
64+
generateStorage(OpcodeEntries);
6465
Layout.TotalEntries = Storage.size();
6566
}
6667

@@ -103,6 +104,12 @@ unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const {
103104
return std::distance(UniqueBaseOpcodeNames.begin(), It);
104105
}
105106

107+
unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const {
108+
assert(isValid() && "MIR2Vec Vocabulary is invalid");
109+
auto BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
110+
return getCanonicalIndexForBaseName(BaseOpcode);
111+
}
112+
106113
std::string MIRVocabulary::getStringKey(unsigned Pos) const {
107114
assert(isValid() && "MIR2Vec Vocabulary is invalid");
108115
assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary");
@@ -119,8 +126,7 @@ std::string MIRVocabulary::getStringKey(unsigned Pos) const {
119126
return "";
120127
}
121128

122-
void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap,
123-
const TargetInstrInfo &TII) {
129+
void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap) {
124130

125131
// Helper for handling missing entities in the vocabulary.
126132
// Currently, we use a zero vector. In the future, we will throw an error to
@@ -168,7 +174,7 @@ void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap,
168174
new (&Storage) ir2vec::VocabStorage(std::move(Sections));
169175
}
170176

171-
void MIRVocabulary::buildCanonicalOpcodeMapping(const TargetInstrInfo &TII) {
177+
void MIRVocabulary::buildCanonicalOpcodeMapping() {
172178
// Check if already built
173179
if (!UniqueBaseOpcodeNames.empty())
174180
return;

llvm/unittests/CodeGen/MIR2VecTest.cpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,15 @@ class MIR2VecVocabTestFixture : public ::testing::Test {
9393
}
9494
};
9595

96+
// Function to find an opcode by name
97+
static int findOpcodeByName(const TargetInstrInfo *TII, StringRef Name) {
98+
for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) {
99+
if (TII->getName(Opcode) == Name)
100+
return Opcode;
101+
}
102+
return -1; // Not found
103+
}
104+
96105
TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
97106
// Test that same base opcodes get same canonical indices
98107
std::string baseName1 = MIRVocabulary::extractBaseOpcodeName("ADD16ri");
@@ -138,9 +147,19 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
138147
6880u); // X86 has >6880 unique base opcodes
139148

140149
// Check that the embeddings for opcodes not in the vocab are zero vectors
141-
EXPECT_TRUE(testVocab[addIndex].approximatelyEquals(Val));
142-
EXPECT_TRUE(testVocab[subIndex].approximatelyEquals(Embedding(64, 0.0f)));
143-
EXPECT_TRUE(testVocab[movIndex].approximatelyEquals(Embedding(64, 0.0f)));
150+
int add32rrOpcode = findOpcodeByName(TII, "ADD32rr");
151+
ASSERT_NE(add32rrOpcode, -1) << "ADD32rr opcode not found";
152+
EXPECT_TRUE(testVocab[add32rrOpcode].approximatelyEquals(Val));
153+
154+
int sub32rrOpcode = findOpcodeByName(TII, "SUB32rr");
155+
ASSERT_NE(sub32rrOpcode, -1) << "SUB32rr opcode not found";
156+
EXPECT_TRUE(
157+
testVocab[sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
158+
159+
int mov32rrOpcode = findOpcodeByName(TII, "MOV32rr");
160+
ASSERT_NE(mov32rrOpcode, -1) << "MOV32rr opcode not found";
161+
EXPECT_TRUE(
162+
testVocab[mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
144163
}
145164

146165
// Test deterministic mapping
@@ -170,9 +189,6 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
170189

171190
// Test MIRVocabulary construction
172191
TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
173-
// Test empty MIRVocabulary
174-
MIRVocabulary emptyVocab;
175-
EXPECT_FALSE(emptyVocab.isValid());
176192

177193
// Test MIRVocabulary with embeddings via VocabMap
178194
VocabMap vocabMap;

0 commit comments

Comments
 (0)