Skip to content

Commit 8b407f7

Browse files
committed
Introducing MIR2Vec
1 parent 6bcaec6 commit 8b407f7

File tree

12 files changed

+1470
-50
lines changed

12 files changed

+1470
-50
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,10 @@ class VocabStorage {
210210
const_iterator end() const {
211211
return const_iterator(this, getNumSections(), 0);
212212
}
213+
using VocabMap = std::map<std::string, ir2vec::Embedding>;
214+
static Error parseVocabSection(StringRef Key,
215+
const json::Value &ParsedVocabValue,
216+
VocabMap &TargetVocab, unsigned &Dim);
213217
};
214218

215219
/// Class for storing and accessing the IR2Vec vocabulary.
@@ -593,8 +597,6 @@ class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
593597

594598
Error readVocabulary(VocabMap &OpcVocab, VocabMap &TypeVocab,
595599
VocabMap &ArgVocab);
596-
Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue,
597-
VocabMap &TargetVocab, unsigned &Dim);
598600
void generateVocabStorage(VocabMap &OpcVocab, VocabMap &TypeVocab,
599601
VocabMap &ArgVocab);
600602
void emitError(Error Err, LLVMContext &Ctx);
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
//===- MIR2Vec.h - Implementation of MIR2Vec ------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4+
// Exceptions. See the LICENSE file for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
///
9+
/// \file
10+
/// This file defines the MIR2Vec vocabulary analysis(MIR2VecVocabAnalysis),
11+
/// the core mir2vec::Embedder interface for generating Machine IR embeddings,
12+
/// and related utilities.
13+
///
14+
/// MIR2Vec extends IR2Vec to support Machine IR embeddings. It represents the
15+
/// LLVM Machine IR as embeddings which can be used as input to machine learning
16+
/// algorithms.
17+
///
18+
/// The original idea of MIR2Vec is described in the following paper:
19+
///
20+
/// RL4ReAl: Reinforcement Learning for Register Allocation. S. VenkataKeerthy,
21+
/// Siddharth Jain, Anilava Kundu, Rohit Aggarwal, Albert Cohen, and Ramakrishna
22+
/// Upadrasta. 2023. RL4ReAl: Reinforcement Learning for Register Allocation.
23+
/// Proceedings of the 32nd ACM SIGPLAN International Conference on Compiler
24+
/// Construction (CC 2023). https://doi.org/10.1145/3578360.3580273.
25+
/// https://arxiv.org/abs/2204.02013
26+
///
27+
//===----------------------------------------------------------------------===//
28+
29+
#ifndef LLVM_CODEGEN_MIR2VEC_H
30+
#define LLVM_CODEGEN_MIR2VEC_H
31+
32+
#include "llvm/Analysis/IR2Vec.h"
33+
#include "llvm/CodeGen/MachineBasicBlock.h"
34+
#include "llvm/CodeGen/MachineFunction.h"
35+
#include "llvm/CodeGen/MachineFunctionPass.h"
36+
#include "llvm/CodeGen/MachineInstr.h"
37+
#include "llvm/CodeGen/MachineModuleInfo.h"
38+
#include "llvm/IR/PassManager.h"
39+
#include "llvm/Pass.h"
40+
#include "llvm/Support/CommandLine.h"
41+
#include "llvm/Support/ErrorOr.h"
42+
#include <map>
43+
#include <set>
44+
#include <string>
45+
46+
namespace llvm {
47+
48+
class Module;
49+
class raw_ostream;
50+
class LLVMContext;
51+
class MIR2VecVocabAnalysis;
52+
class TargetInstrInfo;
53+
54+
namespace mir2vec {
55+
56+
// Forward declarations
57+
class Embedder;
58+
class SymbolicEmbedder;
59+
class FlowAwareEmbedder;
60+
61+
extern llvm::cl::OptionCategory MIR2VecCategory;
62+
extern cl::opt<float> OpcWeight;
63+
64+
using Embedding = ir2vec::Embedding;
65+
66+
/// Class for storing and accessing the MIR2Vec vocabulary.
67+
/// The Vocabulary class manages seed embeddings for LLVM Machine IR
68+
class Vocabulary {
69+
friend class llvm::MIR2VecVocabAnalysis;
70+
using VocabMap = std::map<std::string, ir2vec::Embedding>;
71+
72+
public:
73+
// Define vocabulary layout - adapted for MIR
74+
struct {
75+
unsigned OpcodeBase = 0;
76+
unsigned OperandBase = 0;
77+
unsigned TotalEntries = 0;
78+
} Layout;
79+
80+
private:
81+
ir2vec::VocabStorage Storage;
82+
mutable std::set<std::string> UniqueBaseOpcodeNames;
83+
void generateStorage(const VocabMap &OpcodeMap, const TargetInstrInfo &TII);
84+
void buildCanonicalOpcodeMapping(const TargetInstrInfo &TII);
85+
86+
public:
87+
/// Static helper method for extracting base opcode names (public for testing)
88+
static std::string extractBaseOpcodeName(StringRef InstrName);
89+
90+
/// Helper method for getting canonical index for base name (public for
91+
/// testing)
92+
unsigned getCanonicalIndexForBaseName(StringRef BaseName) const;
93+
94+
/// Get the string key for a vocabulary entry at the given position
95+
std::string getStringKey(unsigned Pos) const;
96+
97+
Vocabulary() = default;
98+
Vocabulary(VocabMap &&Entries, const TargetInstrInfo *TII);
99+
Vocabulary(ir2vec::VocabStorage &&Storage) : Storage(std::move(Storage)) {}
100+
101+
bool isValid() const;
102+
unsigned getDimension() const;
103+
104+
// Accessor methods
105+
const Embedding &operator[](unsigned Index) const;
106+
107+
// Iterator access
108+
using const_iterator = ir2vec::VocabStorage::const_iterator;
109+
const_iterator begin() const;
110+
const_iterator end() const;
111+
};
112+
113+
} // namespace mir2vec
114+
115+
/// Pass to analyze and populate MIR2Vec vocabulary from a module
116+
class MIR2VecVocabAnalysis : public ImmutablePass {
117+
using VocabVector = std::vector<mir2vec::Embedding>;
118+
using VocabMap = std::map<std::string, mir2vec::Embedding>;
119+
VocabMap StrVocabMap;
120+
VocabVector Vocab;
121+
122+
StringRef getPassName() const override;
123+
Error readVocabulary();
124+
void emitError(Error Err, LLVMContext &Ctx);
125+
126+
protected:
127+
void getAnalysisUsage(AnalysisUsage &AU) const override {
128+
AU.addRequired<MachineModuleInfoWrapperPass>();
129+
AU.setPreservesAll();
130+
}
131+
132+
public:
133+
static char ID;
134+
MIR2VecVocabAnalysis() : ImmutablePass(ID) {}
135+
mir2vec::Vocabulary getMIR2VecVocabulary(const Module &M);
136+
};
137+
138+
/// This pass prints the MIR2Vec embeddings for instructions, basic blocks, and
139+
/// functions.
140+
class MIR2VecPrinterPass : public PassInfoMixin<MIR2VecPrinterPass> {
141+
raw_ostream &OS;
142+
143+
public:
144+
explicit MIR2VecPrinterPass(raw_ostream &OS) : OS(OS) {}
145+
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
146+
static bool isRequired() { return true; }
147+
};
148+
149+
/// This pass prints the embeddings in the MIR2Vec vocabulary
150+
class MIR2VecVocabPrinterPass : public MachineFunctionPass {
151+
raw_ostream &OS;
152+
153+
public:
154+
static char ID;
155+
explicit MIR2VecVocabPrinterPass(raw_ostream &OS)
156+
: MachineFunctionPass(ID), OS(OS) {}
157+
158+
bool runOnMachineFunction(MachineFunction &MF) override;
159+
bool doFinalization(Module &M) override;
160+
void getAnalysisUsage(AnalysisUsage &AU) const override {
161+
AU.addRequired<MIR2VecVocabAnalysis>();
162+
AU.setPreservesAll();
163+
MachineFunctionPass::getAnalysisUsage(AU);
164+
}
165+
166+
StringRef getPassName() const override {
167+
return "MIR2Vec Vocabulary Printer Pass";
168+
}
169+
};
170+
171+
/// Old PM version of the printer pass
172+
class MIR2VecPrinterLegacyPass : public ModulePass {
173+
raw_ostream &OS;
174+
175+
public:
176+
static char ID;
177+
explicit MIR2VecPrinterLegacyPass(raw_ostream &OS) : ModulePass(ID), OS(OS) {}
178+
179+
bool runOnModule(Module &M) override;
180+
void getAnalysisUsage(AnalysisUsage &AU) const override {
181+
AU.setPreservesAll();
182+
AU.addRequired<MIR2VecVocabAnalysis>();
183+
AU.addRequired<MachineModuleInfoWrapperPass>();
184+
}
185+
186+
StringRef getPassName() const override { return "MIR2Vec Printer Pass"; }
187+
};
188+
189+
} // namespace llvm
190+
191+
#endif // LLVM_CODEGEN_MIR2VEC_H

llvm/include/llvm/CodeGen/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ LLVM_ABI MachineFunctionPass *
8787
createMachineFunctionPrinterPass(raw_ostream &OS,
8888
const std::string &Banner = "");
8989

90+
/// MIR2VecVocabPrinter pass - This pass prints out the MIR2Vec vocabulary
91+
/// contents to the given stream as a debugging tool.
92+
LLVM_ABI MachineFunctionPass *createMIR2VecVocabPrinterPass(raw_ostream &OS);
93+
9094
/// StackFramePrinter pass - This pass prints out the machine function's
9195
/// stack frame to the given stream as a debugging tool.
9296
LLVM_ABI MachineFunctionPass *createStackFrameLayoutAnalysisPass();

llvm/include/llvm/InitializePasses.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,8 @@ LLVM_ABI void initializeMachinePostDominatorTreeWrapperPassPass(PassRegistry &);
220220
LLVM_ABI void initializeMachineRegionInfoPassPass(PassRegistry &);
221221
LLVM_ABI void
222222
initializeMachineSanitizerBinaryMetadataLegacyPass(PassRegistry &);
223+
LLVM_ABI void initializeMIR2VecVocabAnalysisPass(PassRegistry &);
224+
LLVM_ABI void initializeMIR2VecVocabPrinterPassPass(PassRegistry &);
223225
LLVM_ABI void initializeMachineSchedulerLegacyPass(PassRegistry &);
224226
LLVM_ABI void initializeMachineSinkingLegacyPass(PassRegistry &);
225227
LLVM_ABI void initializeMachineTraceMetricsWrapperPassPass(PassRegistry &);

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,43 @@ bool VocabStorage::const_iterator::operator!=(
329329
return !(*this == Other);
330330
}
331331

332+
Error VocabStorage::parseVocabSection(StringRef Key,
333+
const json::Value &ParsedVocabValue,
334+
VocabMap &TargetVocab, unsigned &Dim) {
335+
json::Path::Root Path("");
336+
const json::Object *RootObj = ParsedVocabValue.getAsObject();
337+
if (!RootObj)
338+
return createStringError(errc::invalid_argument,
339+
"JSON root is not an object");
340+
341+
const json::Value *SectionValue = RootObj->get(Key);
342+
if (!SectionValue)
343+
return createStringError(errc::invalid_argument,
344+
"Missing '" + std::string(Key) +
345+
"' section in vocabulary file");
346+
if (!json::fromJSON(*SectionValue, TargetVocab, Path))
347+
return createStringError(errc::illegal_byte_sequence,
348+
"Unable to parse '" + std::string(Key) +
349+
"' section from vocabulary");
350+
351+
Dim = TargetVocab.begin()->second.size();
352+
if (Dim == 0)
353+
return createStringError(errc::illegal_byte_sequence,
354+
"Dimension of '" + std::string(Key) +
355+
"' section of the vocabulary is zero");
356+
357+
if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
358+
[Dim](const std::pair<StringRef, Embedding> &Entry) {
359+
return Entry.second.size() == Dim;
360+
}))
361+
return createStringError(
362+
errc::illegal_byte_sequence,
363+
"All vectors in the '" + std::string(Key) +
364+
"' section of the vocabulary are not of the same dimension");
365+
366+
return Error::success();
367+
}
368+
332369
// ==----------------------------------------------------------------------===//
333370
// Vocabulary
334371
//===----------------------------------------------------------------------===//
@@ -459,43 +496,6 @@ VocabStorage Vocabulary::createDummyVocabForTest(unsigned Dim) {
459496
// IR2VecVocabAnalysis
460497
//===----------------------------------------------------------------------===//
461498

462-
Error IR2VecVocabAnalysis::parseVocabSection(
463-
StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab,
464-
unsigned &Dim) {
465-
json::Path::Root Path("");
466-
const json::Object *RootObj = ParsedVocabValue.getAsObject();
467-
if (!RootObj)
468-
return createStringError(errc::invalid_argument,
469-
"JSON root is not an object");
470-
471-
const json::Value *SectionValue = RootObj->get(Key);
472-
if (!SectionValue)
473-
return createStringError(errc::invalid_argument,
474-
"Missing '" + std::string(Key) +
475-
"' section in vocabulary file");
476-
if (!json::fromJSON(*SectionValue, TargetVocab, Path))
477-
return createStringError(errc::illegal_byte_sequence,
478-
"Unable to parse '" + std::string(Key) +
479-
"' section from vocabulary");
480-
481-
Dim = TargetVocab.begin()->second.size();
482-
if (Dim == 0)
483-
return createStringError(errc::illegal_byte_sequence,
484-
"Dimension of '" + std::string(Key) +
485-
"' section of the vocabulary is zero");
486-
487-
if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
488-
[Dim](const std::pair<StringRef, Embedding> &Entry) {
489-
return Entry.second.size() == Dim;
490-
}))
491-
return createStringError(
492-
errc::illegal_byte_sequence,
493-
"All vectors in the '" + std::string(Key) +
494-
"' section of the vocabulary are not of the same dimension");
495-
496-
return Error::success();
497-
}
498-
499499
// FIXME: Make this optional. We can avoid file reads
500500
// by auto-generating a default vocabulary during the build time.
501501
Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab,
@@ -512,16 +512,16 @@ Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab,
512512
return ParsedVocabValue.takeError();
513513

514514
unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0;
515-
if (auto Err =
516-
parseVocabSection("Opcodes", *ParsedVocabValue, OpcVocab, OpcodeDim))
515+
if (auto Err = VocabStorage::parseVocabSection("Opcodes", *ParsedVocabValue,
516+
OpcVocab, OpcodeDim))
517517
return Err;
518518

519-
if (auto Err =
520-
parseVocabSection("Types", *ParsedVocabValue, TypeVocab, TypeDim))
519+
if (auto Err = VocabStorage::parseVocabSection("Types", *ParsedVocabValue,
520+
TypeVocab, TypeDim))
521521
return Err;
522522

523-
if (auto Err =
524-
parseVocabSection("Arguments", *ParsedVocabValue, ArgVocab, ArgDim))
523+
if (auto Err = VocabStorage::parseVocabSection("Arguments", *ParsedVocabValue,
524+
ArgVocab, ArgDim))
525525
return Err;
526526

527527
if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))

0 commit comments

Comments
 (0)