Skip to content

Commit ffc59e7

Browse files
committed
MIRVocabulary changes
1 parent 3f44240 commit ffc59e7

File tree

3 files changed

+52
-30
lines changed

3 files changed

+52
-30
lines changed

llvm/include/llvm/CodeGen/MIR2Vec.h

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
///
99
/// \file
1010
/// This file defines the MIR2Vec vocabulary
11-
/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::Embedder interface
12-
/// for generating Machine IR embeddings, and related utilities.
11+
/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::MIREmbedder
12+
/// interface for generating Machine IR embeddings, and related utilities.
1313
///
1414
/// MIR2Vec extends IR2Vec to support Machine IR embeddings. It represents the
1515
/// LLVM Machine IR as embeddings which can be used as input to machine learning
@@ -71,25 +71,31 @@ class MIRVocabulary {
7171
size_t TotalEntries = 0;
7272
} Layout;
7373

74+
enum class Section : unsigned { Opcodes = 0, MaxSections };
75+
7476
ir2vec::VocabStorage Storage;
7577
mutable std::set<std::string> UniqueBaseOpcodeNames;
76-
void generateStorage(const VocabMap &OpcodeMap, const TargetInstrInfo &TII);
77-
void buildCanonicalOpcodeMapping(const TargetInstrInfo &TII);
78+
const TargetInstrInfo &TII;
79+
void generateStorage(const VocabMap &OpcodeMap);
80+
void buildCanonicalOpcodeMapping();
81+
82+
/// Get canonical index for a machine opcode
83+
unsigned getCanonicalOpcodeIndex(unsigned Opcode) const;
7884

7985
public:
80-
/// Static helper method for extracting base opcode names (public for testing)
86+
/// Static method for extracting base opcode names (public for testing)
8187
static std::string extractBaseOpcodeName(StringRef InstrName);
8288

83-
/// Helper method for getting canonical index for base name (public for
84-
/// testing)
89+
/// Get canonical index for base name (public for testing)
8590
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const;
8691

8792
/// Get the string key for a vocabulary entry at the given position
8893
std::string getStringKey(unsigned Pos) const;
8994

90-
MIRVocabulary() = default;
95+
MIRVocabulary() = delete;
9196
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII);
92-
MIRVocabulary(ir2vec::VocabStorage &&Storage) : Storage(std::move(Storage)) {}
97+
MIRVocabulary(ir2vec::VocabStorage &&Storage, const TargetInstrInfo &TII)
98+
: Storage(std::move(Storage)), TII(TII) {}
9399

94100
bool isValid() const {
95101
return UniqueBaseOpcodeNames.size() > 0 &&
@@ -103,11 +109,10 @@ class MIRVocabulary {
103109
}
104110

105111
// Accessor methods
106-
const Embedding &operator[](unsigned Index) const {
112+
const Embedding &operator[](unsigned Opcode) const {
107113
assert(isValid() && "MIR2Vec Vocabulary is invalid");
108-
assert(Index < Layout.TotalEntries && "Index out of bounds");
109-
// Fixme: For now, use section 0 for all entries
110-
return Storage[0][Index];
114+
unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode);
115+
return Storage[static_cast<unsigned>(Section::Opcodes)][LocalIndex];
111116
}
112117

113118
// Iterator access

llvm/lib/CodeGen/MIR2Vec.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,19 @@ cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
4949
//===----------------------------------------------------------------------===//
5050

5151
MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
52-
const TargetInstrInfo *TII) {
52+
const TargetInstrInfo *TII)
53+
: TII(*TII) {
5354
// Early return for invalid inputs - creates empty/invalid vocabulary
5455
if (!TII || OpcodeEntries.empty())
5556
return;
5657

57-
buildCanonicalOpcodeMapping(*TII);
58+
buildCanonicalOpcodeMapping();
5859

5960
unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size();
6061
assert(CanonicalOpcodeCount > 0 &&
6162
"No canonical opcodes found for target - invalid vocabulary");
6263
Layout.OperandBase = CanonicalOpcodeCount;
63-
generateStorage(OpcodeEntries, *TII);
64+
generateStorage(OpcodeEntries);
6465
Layout.TotalEntries = Storage.size();
6566
}
6667

@@ -103,6 +104,12 @@ unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const {
103104
return std::distance(UniqueBaseOpcodeNames.begin(), It);
104105
}
105106

107+
unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const {
108+
assert(isValid() && "MIR2Vec Vocabulary is invalid");
109+
auto BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
110+
return getCanonicalIndexForBaseName(BaseOpcode);
111+
}
112+
106113
std::string MIRVocabulary::getStringKey(unsigned Pos) const {
107114
assert(isValid() && "MIR2Vec Vocabulary is invalid");
108115
assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary");
@@ -119,8 +126,7 @@ std::string MIRVocabulary::getStringKey(unsigned Pos) const {
119126
return "";
120127
}
121128

122-
void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap,
123-
const TargetInstrInfo &TII) {
129+
void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap) {
124130

125131
// Helper for handling missing entities in the vocabulary.
126132
// Currently, we use a zero vector. In the future, we will throw an error to
@@ -168,7 +174,7 @@ void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap,
168174
new (&Storage) ir2vec::VocabStorage(std::move(Sections));
169175
}
170176

171-
void MIRVocabulary::buildCanonicalOpcodeMapping(const TargetInstrInfo &TII) {
177+
void MIRVocabulary::buildCanonicalOpcodeMapping() {
172178
// Check if already built
173179
if (!UniqueBaseOpcodeNames.empty())
174180
return;

llvm/unittests/CodeGen/MIR2VecTest.cpp

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,15 @@ class MIR2VecVocabTestFixture : public ::testing::Test {
8787
}
8888
};
8989

90+
// Function to find an opcode by name
91+
static int findOpcodeByName(const TargetInstrInfo *TII, StringRef Name) {
92+
for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) {
93+
if (TII->getName(Opcode) == Name)
94+
return Opcode;
95+
}
96+
return -1; // Not found
97+
}
98+
9099
TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
91100
// Test that same base opcodes get same canonical indices
92101
std::string BaseName1 = MIRVocabulary::extractBaseOpcodeName("ADD16ri");
@@ -132,9 +141,19 @@ TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
132141
6880u); // X86 has >6880 unique base opcodes
133142

134143
// Check that the embeddings for opcodes not in the vocab are zero vectors
135-
EXPECT_TRUE(TestVocab[AddIndex].approximatelyEquals(Val));
136-
EXPECT_TRUE(TestVocab[SubIndex].approximatelyEquals(Embedding(64, 0.0f)));
137-
EXPECT_TRUE(TestVocab[MovIndex].approximatelyEquals(Embedding(64, 0.0f)));
144+
int Add32rrOpcode = findOpcodeByName(TII, "ADD32rr");
145+
ASSERT_NE(Add32rrOpcode, -1) << "ADD32rr opcode not found";
146+
EXPECT_TRUE(TestVocab[Add32rrOpcode].approximatelyEquals(Val));
147+
148+
int Sub32rrOpcode = findOpcodeByName(TII, "SUB32rr");
149+
ASSERT_NE(Sub32rrOpcode, -1) << "SUB32rr opcode not found";
150+
EXPECT_TRUE(
151+
TestVocab[Sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
152+
153+
int Mov32rrOpcode = findOpcodeByName(TII, "MOV32rr");
154+
ASSERT_NE(Mov32rrOpcode, -1) << "MOV32rr opcode not found";
155+
EXPECT_TRUE(
156+
TestVocab[Mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
138157
}
139158

140159
// Test deterministic mapping
@@ -164,15 +183,7 @@ TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
164183

165184
// Test MIRVocabulary construction
166185
TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
167-
// Test empty MIRVocabulary
168-
MIRVocabulary EmptyVocab;
169-
EXPECT_FALSE(EmptyVocab.isValid());
170-
171-
// Test MIRVocabulary with embeddings via VocabMap
172186
VocabMap VM;
173-
VM["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
174-
VM["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
175-
176187
MIRVocabulary Vocab(std::move(VM), TII);
177188
EXPECT_TRUE(Vocab.isValid());
178189
EXPECT_EQ(Vocab.getDimension(), 128u);

0 commit comments

Comments
 (0)