Skip to content

Commit 049be7b

Browse files
committed
Introducing MIR2Vec
1 parent de9b3ca commit 049be7b

18 files changed

+15229
-19
lines changed
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
//===- MIR2Vec.h - Implementation of MIR2Vec ------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4+
// Exceptions. See the LICENSE file for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
///
9+
/// \file
10+
/// This file defines the MIR2Vec vocabulary
11+
/// analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::Embedder interface
12+
/// for generating Machine IR embeddings, and related utilities.
13+
///
14+
/// MIR2Vec extends IR2Vec to support Machine IR embeddings. It represents the
15+
/// LLVM Machine IR as embeddings which can be used as input to machine learning
16+
/// algorithms.
17+
///
18+
/// The original idea of MIR2Vec is described in the following paper:
19+
///
20+
/// RL4ReAl: Reinforcement Learning for Register Allocation. S. VenkataKeerthy,
21+
/// Siddharth Jain, Anilava Kundu, Rohit Aggarwal, Albert Cohen, and Ramakrishna
22+
/// Upadrasta. 2023. RL4ReAl: Reinforcement Learning for Register Allocation.
23+
/// Proceedings of the 32nd ACM SIGPLAN International Conference on Compiler
24+
/// Construction (CC 2023). https://doi.org/10.1145/3578360.3580273.
25+
/// https://arxiv.org/abs/2204.02013
26+
///
27+
//===----------------------------------------------------------------------===//
28+
29+
#ifndef LLVM_CODEGEN_MIR2VEC_H
30+
#define LLVM_CODEGEN_MIR2VEC_H
31+
32+
#include "llvm/Analysis/IR2Vec.h"
33+
#include "llvm/CodeGen/MachineBasicBlock.h"
34+
#include "llvm/CodeGen/MachineFunction.h"
35+
#include "llvm/CodeGen/MachineFunctionPass.h"
36+
#include "llvm/CodeGen/MachineInstr.h"
37+
#include "llvm/CodeGen/MachineModuleInfo.h"
38+
#include "llvm/IR/PassManager.h"
39+
#include "llvm/Pass.h"
40+
#include "llvm/Support/CommandLine.h"
41+
#include "llvm/Support/ErrorOr.h"
42+
#include <map>
43+
#include <set>
44+
#include <string>
45+
46+
namespace llvm {
47+
48+
class Module;
49+
class raw_ostream;
50+
class LLVMContext;
51+
class MIR2VecVocabLegacyAnalysis;
52+
class TargetInstrInfo;
53+
54+
namespace mir2vec {
55+
extern llvm::cl::OptionCategory MIR2VecCategory;
56+
extern cl::opt<float> OpcWeight;
57+
58+
using Embedding = ir2vec::Embedding;
59+
60+
/// Class for storing and accessing the MIR2Vec vocabulary.
61+
/// The MIRVocabulary class manages seed embeddings for LLVM Machine IR
62+
class MIRVocabulary {
63+
friend class llvm::MIR2VecVocabLegacyAnalysis;
64+
using VocabMap = std::map<std::string, ir2vec::Embedding>;
65+
66+
private:
67+
// Define vocabulary layout - adapted for MIR
68+
struct {
69+
size_t OpcodeBase = 0;
70+
size_t OperandBase = 0;
71+
size_t TotalEntries = 0;
72+
} Layout;
73+
74+
ir2vec::VocabStorage Storage;
75+
mutable std::set<std::string> UniqueBaseOpcodeNames;
76+
void generateStorage(const VocabMap &OpcodeMap, const TargetInstrInfo &TII);
77+
void buildCanonicalOpcodeMapping(const TargetInstrInfo &TII);
78+
79+
public:
80+
/// Static helper method for extracting base opcode names (public for testing)
81+
static std::string extractBaseOpcodeName(StringRef InstrName);
82+
83+
/// Helper method for getting canonical index for base name (public for
84+
/// testing)
85+
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const;
86+
87+
/// Get the string key for a vocabulary entry at the given position
88+
std::string getStringKey(unsigned Pos) const;
89+
90+
MIRVocabulary() = default;
91+
MIRVocabulary(VocabMap &&Entries, const TargetInstrInfo *TII);
92+
MIRVocabulary(ir2vec::VocabStorage &&Storage) : Storage(std::move(Storage)) {}
93+
94+
bool isValid() const {
95+
return UniqueBaseOpcodeNames.size() > 0 &&
96+
Layout.TotalEntries == Storage.size() && Storage.isValid();
97+
}
98+
99+
unsigned getDimension() const {
100+
if (!isValid())
101+
return 0;
102+
return Storage.getDimension();
103+
}
104+
105+
// Accessor methods
106+
const Embedding &operator[](unsigned Index) const {
107+
assert(isValid() && "MIR2Vec Vocabulary is invalid");
108+
assert(Index < Layout.TotalEntries && "Index out of bounds");
109+
// Fixme: For now, use section 0 for all entries
110+
return Storage[0][Index];
111+
}
112+
113+
// Iterator access
114+
using const_iterator = ir2vec::VocabStorage::const_iterator;
115+
const_iterator begin() const {
116+
assert(isValid() && "MIR2Vec Vocabulary is invalid");
117+
return Storage.begin();
118+
}
119+
120+
const_iterator end() const {
121+
assert(isValid() && "MIR2Vec Vocabulary is invalid");
122+
return Storage.end();
123+
}
124+
125+
/// Total number of entries in the vocabulary
126+
size_t getCanonicalSize() const {
127+
assert(isValid() && "Invalid vocabulary");
128+
return Storage.size();
129+
}
130+
};
131+
132+
} // namespace mir2vec
133+
134+
/// Pass to analyze and populate MIR2Vec vocabulary from a module
135+
class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
136+
using VocabVector = std::vector<mir2vec::Embedding>;
137+
using VocabMap = std::map<std::string, mir2vec::Embedding>;
138+
VocabMap StrVocabMap;
139+
VocabVector Vocab;
140+
141+
StringRef getPassName() const override;
142+
Error readVocabulary();
143+
void emitError(Error Err, LLVMContext &Ctx);
144+
145+
protected:
146+
void getAnalysisUsage(AnalysisUsage &AU) const override {
147+
AU.addRequired<MachineModuleInfoWrapperPass>();
148+
AU.setPreservesAll();
149+
}
150+
151+
public:
152+
static char ID;
153+
MIR2VecVocabLegacyAnalysis() : ImmutablePass(ID) {}
154+
mir2vec::MIRVocabulary getMIR2VecVocabulary(const Module &M);
155+
};
156+
157+
/// This pass prints the embeddings in the MIR2Vec vocabulary
158+
class MIR2VecVocabPrinterLegacyPass : public MachineFunctionPass {
159+
raw_ostream &OS;
160+
161+
public:
162+
static char ID;
163+
explicit MIR2VecVocabPrinterLegacyPass(raw_ostream &OS)
164+
: MachineFunctionPass(ID), OS(OS) {}
165+
166+
bool runOnMachineFunction(MachineFunction &MF) override;
167+
bool doFinalization(Module &M) override;
168+
void getAnalysisUsage(AnalysisUsage &AU) const override {
169+
AU.addRequired<MIR2VecVocabLegacyAnalysis>();
170+
AU.setPreservesAll();
171+
MachineFunctionPass::getAnalysisUsage(AU);
172+
}
173+
174+
StringRef getPassName() const override {
175+
return "MIR2Vec Vocabulary Printer Pass";
176+
}
177+
};
178+
179+
} // namespace llvm
180+
181+
#endif // LLVM_CODEGEN_MIR2VEC_H

llvm/include/llvm/CodeGen/Passes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,11 @@ LLVM_ABI MachineFunctionPass *
8888
createMachineFunctionPrinterPass(raw_ostream &OS,
8989
const std::string &Banner = "");
9090

91+
/// MIR2VecVocabPrinter pass - This pass prints out the MIR2Vec vocabulary
92+
/// contents to the given stream as a debugging tool.
93+
LLVM_ABI MachineFunctionPass *
94+
createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS);
95+
9196
/// StackFramePrinter pass - This pass prints out the machine function's
9297
/// stack frame to the given stream as a debugging tool.
9398
LLVM_ABI MachineFunctionPass *createStackFrameLayoutAnalysisPass();

llvm/include/llvm/InitializePasses.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,8 @@ LLVM_ABI void initializeMachinePostDominatorTreeWrapperPassPass(PassRegistry &);
220220
LLVM_ABI void initializeMachineRegionInfoPassPass(PassRegistry &);
221221
LLVM_ABI void
222222
initializeMachineSanitizerBinaryMetadataLegacyPass(PassRegistry &);
223+
LLVM_ABI void initializeMIR2VecVocabLegacyAnalysisPass(PassRegistry &);
224+
LLVM_ABI void initializeMIR2VecVocabPrinterLegacyPassPass(PassRegistry &);
223225
LLVM_ABI void initializeMachineSchedulerLegacyPass(PassRegistry &);
224226
LLVM_ABI void initializeMachineSinkingLegacyPass(PassRegistry &);
225227
LLVM_ABI void initializeMachineTraceMetricsWrapperPassPass(PassRegistry &);

0 commit comments

Comments
 (0)