Skip to content

Commit 39467a8

Browse files
committed
Introducing MIR2Vec
1 parent 6bfa56a commit 39467a8

File tree

14 files changed

+15218
-64
lines changed

14 files changed

+15218
-64
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,10 @@ class VocabStorage {
210210
const_iterator end() const {
211211
return const_iterator(this, getNumSections(), 0);
212212
}
213+
using VocabMap = std::map<std::string, ir2vec::Embedding>;
214+
static Error parseVocabSection(StringRef Key,
215+
const json::Value &ParsedVocabValue,
216+
VocabMap &TargetVocab, unsigned &Dim);
213217
};
214218

215219
/// Class for storing and accessing the IR2Vec vocabulary.
@@ -600,8 +604,6 @@ class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
600604

601605
Error readVocabulary(VocabMap &OpcVocab, VocabMap &TypeVocab,
602606
VocabMap &ArgVocab);
603-
Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue,
604-
VocabMap &TargetVocab, unsigned &Dim);
605607
void generateVocabStorage(VocabMap &OpcVocab, VocabMap &TypeVocab,
606608
VocabMap &ArgVocab);
607609
void emitError(Error Err, LLVMContext &Ctx);
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
//===- MIR2Vec.h - Implementation of MIR2Vec ------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4+
// Exceptions. See the LICENSE file for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
///
9+
/// \file
10+
/// This file defines the MIR2Vec vocabulary
11+
/// analysis (MIR2VecVocabLegacyAnalysis), the core mir2vec::Embedder interface
12+
/// for generating Machine IR embeddings, and related utilities.
13+
///
14+
/// MIR2Vec extends IR2Vec to support Machine IR embeddings. It represents the
15+
/// LLVM Machine IR as embeddings which can be used as input to machine learning
16+
/// algorithms.
17+
///
18+
/// The original idea of MIR2Vec is described in the following paper:
19+
///
20+
/// RL4ReAl: Reinforcement Learning for Register Allocation. S. VenkataKeerthy,
21+
/// Siddharth Jain, Anilava Kundu, Rohit Aggarwal, Albert Cohen, and Ramakrishna
22+
/// Upadrasta. 2023. RL4ReAl: Reinforcement Learning for Register Allocation.
23+
/// Proceedings of the 32nd ACM SIGPLAN International Conference on Compiler
24+
/// Construction (CC 2023). https://doi.org/10.1145/3578360.3580273.
25+
/// https://arxiv.org/abs/2204.02013
26+
///
27+
//===----------------------------------------------------------------------===//
28+
29+
#ifndef LLVM_CODEGEN_MIR2VEC_H
30+
#define LLVM_CODEGEN_MIR2VEC_H
31+
32+
#include "llvm/Analysis/IR2Vec.h"
33+
#include "llvm/CodeGen/MachineBasicBlock.h"
34+
#include "llvm/CodeGen/MachineFunction.h"
35+
#include "llvm/CodeGen/MachineFunctionPass.h"
36+
#include "llvm/CodeGen/MachineInstr.h"
37+
#include "llvm/CodeGen/MachineModuleInfo.h"
38+
#include "llvm/IR/PassManager.h"
39+
#include "llvm/Pass.h"
40+
#include "llvm/Support/CommandLine.h"
41+
#include "llvm/Support/ErrorOr.h"
42+
#include <map>
43+
#include <set>
44+
#include <string>
45+
46+
namespace llvm {
47+
48+
class Module;
49+
class raw_ostream;
50+
class LLVMContext;
51+
class MIR2VecVocabLegacyAnalysis;
52+
class TargetInstrInfo;
53+
54+
namespace mir2vec {
55+
extern llvm::cl::OptionCategory MIR2VecCategory;
56+
extern cl::opt<float> OpcWeight;
57+
58+
using Embedding = ir2vec::Embedding;
59+
60+
/// Class for storing and accessing the MIR2Vec vocabulary.
61+
/// The MIRVocabulary class manages seed embeddings for LLVM Machine IR
62+
class MIRVocabulary {
63+
friend class llvm::MIR2VecVocabLegacyAnalysis;
64+
using VocabMap = std::map<std::string, ir2vec::Embedding>;
65+
66+
private:
67+
// Define vocabulary layout - adapted for MIR
68+
struct {
69+
unsigned OpcodeBase = 0;
70+
unsigned OperandBase = 0;
71+
unsigned TotalEntries = 0;
72+
} Layout;
73+
74+
ir2vec::VocabStorage Storage;
75+
mutable std::set<std::string> UniqueBaseOpcodeNames;
76+
void generateStorage(const VocabMap &OpcodeMap, const TargetInstrInfo &TII);
77+
void buildCanonicalOpcodeMapping(const TargetInstrInfo &TII);
78+
79+
public:
80+
/// Static helper method for extracting base opcode names (public for testing)
81+
static std::string extractBaseOpcodeName(StringRef InstrName);
82+
83+
/// Helper method for getting canonical index for base name (public for
84+
/// testing)
85+
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const;
86+
87+
/// Get the string key for a vocabulary entry at the given position
88+
std::string getStringKey(unsigned Pos) const;
89+
90+
MIRVocabulary() = default;
91+
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII);
92+
MIRVocabulary(ir2vec::VocabStorage &&Storage) : Storage(std::move(Storage)) {}
93+
94+
bool isValid() const {
95+
return UniqueBaseOpcodeNames.size() > 0 &&
96+
Layout.TotalEntries == Storage.size() && Storage.isValid();
97+
}
98+
99+
unsigned getDimension() const {
100+
if (!isValid())
101+
return 0;
102+
return Storage.getDimension();
103+
}
104+
105+
// Accessor methods
106+
const Embedding &operator[](unsigned Index) const {
107+
assert(isValid() && "MIR2Vec Vocabulary is invalid");
108+
assert(Index < Layout.TotalEntries && "Index out of bounds");
109+
// Fixme: For now, use section 0 for all entries
110+
return Storage[0][Index];
111+
}
112+
113+
// Iterator access
114+
using const_iterator = ir2vec::VocabStorage::const_iterator;
115+
const_iterator begin() const {
116+
assert(isValid() && "MIR2Vec Vocabulary is invalid");
117+
return Storage.begin();
118+
}
119+
120+
const_iterator end() const {
121+
assert(isValid() && "MIR2Vec Vocabulary is invalid");
122+
return Storage.end();
123+
}
124+
125+
/// Total number of entries in the vocabulary
126+
size_t getCanonicalSize() const {
127+
assert(isValid() && "Invalid vocabulary");
128+
return Storage.size();
129+
}
130+
};
131+
132+
} // namespace mir2vec
133+
134+
/// Pass to analyze and populate MIR2Vec vocabulary from a module
135+
class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
136+
using VocabVector = std::vector<mir2vec::Embedding>;
137+
using VocabMap = std::map<std::string, mir2vec::Embedding>;
138+
VocabMap StrVocabMap;
139+
VocabVector Vocab;
140+
141+
StringRef getPassName() const override;
142+
Error readVocabulary();
143+
void emitError(Error Err, LLVMContext &Ctx);
144+
145+
protected:
146+
void getAnalysisUsage(AnalysisUsage &AU) const override {
147+
AU.addRequired<MachineModuleInfoWrapperPass>();
148+
AU.setPreservesAll();
149+
}
150+
151+
public:
152+
static char ID;
153+
MIR2VecVocabLegacyAnalysis() : ImmutablePass(ID) {}
154+
mir2vec::MIRVocabulary getMIR2VecVocabulary(const Module &M);
155+
};
156+
157+
/// This pass prints the embeddings in the MIR2Vec vocabulary
158+
class MIR2VecVocabPrinterLegacyPass : public MachineFunctionPass {
159+
raw_ostream &OS;
160+
161+
public:
162+
static char ID;
163+
explicit MIR2VecVocabPrinterLegacyPass(raw_ostream &OS)
164+
: MachineFunctionPass(ID), OS(OS) {}
165+
166+
bool runOnMachineFunction(MachineFunction &MF) override;
167+
bool doFinalization(Module &M) override;
168+
void getAnalysisUsage(AnalysisUsage &AU) const override {
169+
AU.addRequired<MIR2VecVocabLegacyAnalysis>();
170+
AU.setPreservesAll();
171+
MachineFunctionPass::getAnalysisUsage(AU);
172+
}
173+
174+
StringRef getPassName() const override {
175+
return "MIR2Vec Vocabulary Printer Pass";
176+
}
177+
};
178+
179+
} // namespace llvm
180+
181+
#endif // LLVM_CODEGEN_MIR2VEC_H

llvm/include/llvm/CodeGen/Passes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,11 @@ LLVM_ABI MachineFunctionPass *
8888
createMachineFunctionPrinterPass(raw_ostream &OS,
8989
const std::string &Banner = "");
9090

91+
/// MIR2VecVocabPrinter pass - This pass prints out the MIR2Vec vocabulary
92+
/// contents to the given stream as a debugging tool.
93+
LLVM_ABI MachineFunctionPass *
94+
createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS);
95+
9196
/// StackFramePrinter pass - This pass prints out the machine function's
9297
/// stack frame to the given stream as a debugging tool.
9398
LLVM_ABI MachineFunctionPass *createStackFrameLayoutAnalysisPass();

llvm/include/llvm/InitializePasses.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,8 @@ LLVM_ABI void initializeMachinePostDominatorTreeWrapperPassPass(PassRegistry &);
220220
LLVM_ABI void initializeMachineRegionInfoPassPass(PassRegistry &);
221221
LLVM_ABI void
222222
initializeMachineSanitizerBinaryMetadataLegacyPass(PassRegistry &);
223+
LLVM_ABI void initializeMIR2VecVocabLegacyAnalysisPass(PassRegistry &);
224+
LLVM_ABI void initializeMIR2VecVocabPrinterLegacyPassPass(PassRegistry &);
223225
LLVM_ABI void initializeMachineSchedulerLegacyPass(PassRegistry &);
224226
LLVM_ABI void initializeMachineSinkingLegacyPass(PassRegistry &);
225227
LLVM_ABI void initializeMachineTraceMetricsWrapperPassPass(PassRegistry &);

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,43 @@ bool VocabStorage::const_iterator::operator!=(
330330
return !(*this == Other);
331331
}
332332

333+
/// Parse the section named \p Key of the vocabulary JSON document
/// \p ParsedVocabValue into \p TargetVocab, and report the embedding
/// dimension of that section through \p Dim.
///
/// \returns Error::success() when the section was read and every entry has
/// the same non-zero dimension. Otherwise returns a string error when:
///   - the JSON root is not an object,
///   - the \p Key section is missing,
///   - the section cannot be converted into \p TargetVocab,
///   - the section contains no entries (\p Dim is set to 0),
///   - the first entry has dimension zero, or
///   - the entries do not all share the same dimension.
Error VocabStorage::parseVocabSection(StringRef Key,
                                      const json::Value &ParsedVocabValue,
                                      VocabMap &TargetVocab, unsigned &Dim) {
  json::Path::Root Path("");
  const json::Object *RootObj = ParsedVocabValue.getAsObject();
  if (!RootObj)
    return createStringError(errc::invalid_argument,
                             "JSON root is not an object");

  const json::Value *SectionValue = RootObj->get(Key);
  if (!SectionValue)
    return createStringError(errc::invalid_argument,
                             "Missing '" + std::string(Key) +
                                 "' section in vocabulary file");
  if (!json::fromJSON(*SectionValue, TargetVocab, Path))
    return createStringError(errc::illegal_byte_sequence,
                             "Unable to parse '" + std::string(Key) +
                                 "' section from vocabulary");

  // If the section yielded no entries, begin() == end() and must not be
  // dereferenced; report the problem instead of reading past the map.
  if (TargetVocab.empty()) {
    Dim = 0;
    return createStringError(errc::illegal_byte_sequence,
                             "'" + std::string(Key) +
                                 "' section of the vocabulary is empty");
  }

  Dim = TargetVocab.begin()->second.size();
  if (Dim == 0)
    return createStringError(errc::illegal_byte_sequence,
                             "Dimension of '" + std::string(Key) +
                                 "' section of the vocabulary is zero");

  // Bind to the map's own value_type: a differently-typed pair parameter
  // would force a converting temporary (and an embedding copy) per entry.
  if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
                   [Dim](const VocabMap::value_type &Entry) {
                     return Entry.second.size() == Dim;
                   }))
    return createStringError(
        errc::illegal_byte_sequence,
        "All vectors in the '" + std::string(Key) +
            "' section of the vocabulary are not of the same dimension");

  return Error::success();
}
369+
333370
//===----------------------------------------------------------------------===//
334371
// Vocabulary
335372
//===----------------------------------------------------------------------===//
@@ -460,43 +497,6 @@ VocabStorage Vocabulary::createDummyVocabForTest(unsigned Dim) {
460497
// IR2VecVocabAnalysis
461498
//===----------------------------------------------------------------------===//
462499

463-
/// Parse the section named \p Key of the vocabulary JSON document
/// \p ParsedVocabValue into \p TargetVocab, and report the embedding
/// dimension of that section through \p Dim.
///
/// \returns Error::success() when the section was read and every entry has
/// the same non-zero dimension. Otherwise returns a string error when:
///   - the JSON root is not an object,
///   - the \p Key section is missing,
///   - the section cannot be converted into \p TargetVocab,
///   - the section contains no entries (\p Dim is set to 0),
///   - the first entry has dimension zero, or
///   - the entries do not all share the same dimension.
Error IR2VecVocabAnalysis::parseVocabSection(
    StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab,
    unsigned &Dim) {
  json::Path::Root Path("");
  const json::Object *RootObj = ParsedVocabValue.getAsObject();
  if (!RootObj)
    return createStringError(errc::invalid_argument,
                             "JSON root is not an object");

  const json::Value *SectionValue = RootObj->get(Key);
  if (!SectionValue)
    return createStringError(errc::invalid_argument,
                             "Missing '" + std::string(Key) +
                                 "' section in vocabulary file");
  if (!json::fromJSON(*SectionValue, TargetVocab, Path))
    return createStringError(errc::illegal_byte_sequence,
                             "Unable to parse '" + std::string(Key) +
                                 "' section from vocabulary");

  // If the section yielded no entries, begin() == end() and must not be
  // dereferenced; report the problem instead of reading past the map.
  if (TargetVocab.empty()) {
    Dim = 0;
    return createStringError(errc::illegal_byte_sequence,
                             "'" + std::string(Key) +
                                 "' section of the vocabulary is empty");
  }

  Dim = TargetVocab.begin()->second.size();
  if (Dim == 0)
    return createStringError(errc::illegal_byte_sequence,
                             "Dimension of '" + std::string(Key) +
                                 "' section of the vocabulary is zero");

  // Bind to the map's own value_type: a differently-typed pair parameter
  // would force a converting temporary (and an embedding copy) per entry.
  if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
                   [Dim](const VocabMap::value_type &Entry) {
                     return Entry.second.size() == Dim;
                   }))
    return createStringError(
        errc::illegal_byte_sequence,
        "All vectors in the '" + std::string(Key) +
            "' section of the vocabulary are not of the same dimension");

  return Error::success();
}
499-
500500
// FIXME: Make this optional. We can avoid file reads
501501
// by auto-generating a default vocabulary during the build time.
502502
Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab,
@@ -513,16 +513,16 @@ Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab,
513513
return ParsedVocabValue.takeError();
514514

515515
unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0;
516-
if (auto Err =
517-
parseVocabSection("Opcodes", *ParsedVocabValue, OpcVocab, OpcodeDim))
516+
if (auto Err = VocabStorage::parseVocabSection("Opcodes", *ParsedVocabValue,
517+
OpcVocab, OpcodeDim))
518518
return Err;
519519

520-
if (auto Err =
521-
parseVocabSection("Types", *ParsedVocabValue, TypeVocab, TypeDim))
520+
if (auto Err = VocabStorage::parseVocabSection("Types", *ParsedVocabValue,
521+
TypeVocab, TypeDim))
522522
return Err;
523523

524-
if (auto Err =
525-
parseVocabSection("Arguments", *ParsedVocabValue, ArgVocab, ArgDim))
524+
if (auto Err = VocabStorage::parseVocabSection("Arguments", *ParsedVocabValue,
525+
ArgVocab, ArgDim))
526526
return Err;
527527

528528
if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))

0 commit comments

Comments
 (0)