Skip to content

Commit b7a8ff7

Browse files
committed
MIR2Vec embedding
1 parent 0ebc739 commit b7a8ff7

File tree

10 files changed

+884
-47
lines changed

10 files changed

+884
-47
lines changed

llvm/include/llvm/CodeGen/MIR2Vec.h

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,21 @@ class LLVMContext;
5151
class MIR2VecVocabLegacyAnalysis;
5252
class TargetInstrInfo;
5353

54+
/// Kinds of MIR2Vec embeddings that can be requested. Symbolic is the only
/// kind declared here.
enum class MIR2VecKind { Symbolic };
55+
5456
namespace mir2vec {
57+
58+
// Forward declarations
59+
class MIREmbedder;
60+
class SymbolicMIREmbedder;
61+
5562
extern llvm::cl::OptionCategory MIR2VecCategory;
5663
extern cl::opt<float> OpcWeight;
5764

5865
using Embedding = ir2vec::Embedding;
66+
using MachineInstEmbeddingsMap = DenseMap<const MachineInstr *, Embedding>;
67+
using MachineBlockEmbeddingsMap =
68+
DenseMap<const MachineBasicBlock *, Embedding>;
5969

6070
/// Class for storing and accessing the MIR2Vec vocabulary.
6171
/// The MIRVocabulary class manages seed embeddings for LLVM Machine IR
@@ -132,6 +142,79 @@ class MIRVocabulary {
132142
assert(isValid() && "Invalid vocabulary");
133143
return Storage.size();
134144
}
145+
146+
/// Create a dummy vocabulary for testing purposes.
147+
static MIRVocabulary createDummyVocabForTest(const TargetInstrInfo &TII,
148+
unsigned Dim = 1);
149+
};
150+
151+
/// Base class for MIR embedders.
/// Holds the machine function and vocabulary being embedded, together with
/// mutable storage for the function-, basic-block- and instruction-level
/// embeddings. Subclasses supply the per-basic-block computation.
152+
class MIREmbedder {
153+
protected:
154+
/// Machine function whose embeddings are computed.
const MachineFunction &MF;
155+
/// Vocabulary the embeddings are derived from.
const MIRVocabulary &Vocab;
156+
157+
/// Dimension of the embeddings; captured from the vocabulary.
158+
const unsigned Dimension;
159+
160+
/// Weight for opcode embeddings.
161+
const float OpcWeight;
162+
163+
// Utility maps - these are used to store the vector representations of
164+
// instructions, basic blocks and functions. They are declared mutable so
// that the const accessors below can fill them on demand.
165+
mutable Embedding MFuncVector;
166+
mutable MachineBlockEmbeddingsMap MBBVecMap;
167+
mutable MachineInstEmbeddingsMap MInstVecMap;
168+
169+
/// Protected constructor; embedders are created through create() or through
/// a subclass.
MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab);
170+
171+
/// Function to compute embeddings. It generates embeddings for all
172+
/// the instructions and basic blocks in the machine function MF.
173+
void computeEmbeddings() const;
174+
175+
/// Function to compute the embedding for a given basic block.
176+
/// Specific to the kind of embeddings being computed.
177+
virtual void computeEmbeddings(const MachineBasicBlock &MBB) const = 0;
178+
179+
public:
180+
virtual ~MIREmbedder() = default;
181+
182+
/// Factory method to create an Embedder object of the specified kind.
183+
/// Returns nullptr if the requested kind is not supported.
184+
static std::unique_ptr<MIREmbedder> create(MIR2VecKind Mode,
185+
const MachineFunction &MF,
186+
const MIRVocabulary &Vocab);
187+
188+
/// Returns a map containing machine instructions and the corresponding
189+
/// embeddings for the machine function MF if it has been computed. If not, it
190+
/// computes the embeddings for MF and returns the map.
191+
const MachineInstEmbeddingsMap &getMInstVecMap() const;
192+
193+
/// Returns a map containing machine basic block and the corresponding
194+
/// embeddings for the machine function MF if it has been computed. If not, it
195+
/// computes the embeddings for MF and returns the map.
196+
const MachineBlockEmbeddingsMap &getMBBVecMap() const;
197+
198+
/// Returns the embedding for a given machine basic block in the machine
199+
/// function MF if it has been computed. If not, it computes the embedding for
200+
/// MBB and returns it.
201+
const Embedding &getMBBVector(const MachineBasicBlock &MBB) const;
202+
203+
/// Computes and returns the embedding for the current machine function.
204+
const Embedding &getMFunctionVector() const;
205+
};
206+
207+
/// Class for computing Symbolic embeddings.
208+
/// Symbolic embeddings are constructed based on the entity-level
209+
/// representations obtained from the MIR Vocabulary.
210+
class SymbolicMIREmbedder : public MIREmbedder {
211+
private:
212+
/// Computes the embedding of one basic block; overrides the pure virtual
/// hook declared in MIREmbedder.
void computeEmbeddings(const MachineBasicBlock &MBB) const override;
213+
214+
public:
215+
/// Constructs a symbolic embedder for the given machine function and
/// vocabulary.
SymbolicMIREmbedder(const MachineFunction &F, const MIRVocabulary &Vocab);
216+
/// Convenience factory returning a heap-allocated symbolic embedder.
static std::unique_ptr<SymbolicMIREmbedder>
217+
create(const MachineFunction &MF, const MIRVocabulary &Vocab);
135218
};
136219

137220
} // namespace mir2vec
@@ -181,6 +264,31 @@ class MIR2VecVocabPrinterLegacyPass : public MachineFunctionPass {
181264
}
182265
};
183266

267+
/// This pass prints the MIR2Vec embeddings for machine functions, basic blocks,
268+
/// and instructions to the stream supplied at construction.
269+
class MIR2VecPrinterLegacyPass : public MachineFunctionPass {
270+
/// Stream that receives the printed embeddings. Held by reference, so it
/// must outlive the pass.
raw_ostream &OS;
271+
272+
public:
273+
static char ID;
274+
explicit MIR2VecPrinterLegacyPass(raw_ostream &OS)
275+
: MachineFunctionPass(ID), OS(OS) {}
276+
277+
bool runOnMachineFunction(MachineFunction &MF) override;
278+
/// Requires the MIR2Vec vocabulary analysis and declares that this pass
/// preserves all analyses.
void getAnalysisUsage(AnalysisUsage &AU) const override {
279+
AU.addRequired<MIR2VecVocabLegacyAnalysis>();
280+
AU.setPreservesAll();
281+
MachineFunctionPass::getAnalysisUsage(AU);
282+
}
283+
284+
StringRef getPassName() const override {
285+
return "MIR2Vec Embedder Printer Pass";
286+
}
287+
};
288+
289+
/// Create a machine pass that prints MIR2Vec embeddings
290+
MachineFunctionPass *createMIR2VecPrinterLegacyPass(raw_ostream &OS);
291+
184292
} // namespace llvm
185293

186294
#endif // LLVM_CODEGEN_MIR2VEC_H

llvm/include/llvm/CodeGen/Passes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@ createMachineFunctionPrinterPass(raw_ostream &OS,
9393
LLVM_ABI MachineFunctionPass *
9494
createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS);
9595

96+
/// MIR2VecPrinter pass - This pass prints out the MIR2Vec embeddings for
97+
/// machine functions, basic blocks and instructions.
98+
LLVM_ABI MachineFunctionPass *
99+
createMIR2VecPrinterLegacyPass(raw_ostream &OS);
100+
96101
/// StackFramePrinter pass - This pass prints out the machine function's
97102
/// stack frame to the given stream as a debugging tool.
98103
LLVM_ABI MachineFunctionPass *createStackFrameLayoutAnalysisPass();

llvm/include/llvm/InitializePasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ LLVM_ABI void
222222
initializeMachineSanitizerBinaryMetadataLegacyPass(PassRegistry &);
223223
LLVM_ABI void initializeMIR2VecVocabLegacyAnalysisPass(PassRegistry &);
224224
LLVM_ABI void initializeMIR2VecVocabPrinterLegacyPassPass(PassRegistry &);
225+
LLVM_ABI void initializeMIR2VecPrinterLegacyPassPass(PassRegistry &);
225226
LLVM_ABI void initializeMachineSchedulerLegacyPass(PassRegistry &);
226227
LLVM_ABI void initializeMachineSinkingLegacyPass(PassRegistry &);
227228
LLVM_ABI void initializeMachineTraceMetricsWrapperPassPass(PassRegistry &);

llvm/lib/CodeGen/CodeGen.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ void llvm::initializeCodeGen(PassRegistry &Registry) {
9898
initializeMachineUniformityAnalysisPassPass(Registry);
9999
initializeMIR2VecVocabLegacyAnalysisPass(Registry);
100100
initializeMIR2VecVocabPrinterLegacyPassPass(Registry);
101+
initializeMIR2VecPrinterLegacyPassPass(Registry);
101102
initializeMachineUniformityInfoPrinterPassPass(Registry);
102103
initializeMachineVerifierLegacyPassPass(Registry);
103104
initializeObjCARCContractLegacyPassPass(Registry);

llvm/lib/CodeGen/MIR2Vec.cpp

Lines changed: 192 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,18 @@ static cl::opt<std::string>
4141
cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
4242
cl::desc("Weight for machine opcode embeddings"),
4343
cl::cat(MIR2VecCategory));
44+
// Command-line option "mir2vec-kind": selects which kind of MIR2Vec embedding
// to generate. "symbolic" is the only accepted value and also the default.
cl::opt<MIR2VecKind> MIR2VecEmbeddingKind(
45+
"mir2vec-kind", cl::Optional,
46+
cl::values(clEnumValN(MIR2VecKind::Symbolic, "symbolic",
47+
"Generate symbolic embeddings for MIR")),
48+
cl::init(MIR2VecKind::Symbolic), cl::desc("MIR2Vec embedding kind"),
49+
cl::cat(MIR2VecCategory));
50+
4451
} // namespace mir2vec
4552
} // namespace llvm
4653

4754
//===----------------------------------------------------------------------===//
48-
// Vocabulary Implementation
55+
// Vocabulary
4956
//===----------------------------------------------------------------------===//
5057

5158
MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
@@ -190,6 +197,28 @@ void MIRVocabulary::buildCanonicalOpcodeMapping() {
190197
<< " unique base opcodes\n");
191198
}
192199

200+
/// Build a vocabulary for tests in which every canonical base opcode name of
/// \p TII maps to a \p Dim-dimensional embedding filled with 0.1.
MIRVocabulary MIRVocabulary::createDummyVocabForTest(const TargetInstrInfo &TII,
                                                     unsigned Dim) {
  assert(Dim > 0 && "Dimension must be greater than zero");

  // A scratch vocabulary without entries is built only to obtain the set of
  // canonical base opcode names for this target.
  MIRVocabulary Scratch({}, &TII);
  Scratch.buildCanonicalOpcodeMapping();

  const float FillValue = 0.1f;

  // Give every canonical opcode name the same constant embedding.
  VocabMap Entries;
  for (const auto &BaseName : Scratch.UniqueBaseOpcodeNames)
    Entries[BaseName] = Embedding(Dim, FillValue);

  return MIRVocabulary(std::move(Entries), &TII);
}
221+
193222
//===----------------------------------------------------------------------===//
194223
// MIR2VecVocabLegacyAnalysis Implementation
195224
//===----------------------------------------------------------------------===//
@@ -267,7 +296,104 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
267296
}
268297

269298
//===----------------------------------------------------------------------===//
270-
// Printer Passes Implementation
299+
// MIREmbedder and its subclasses
300+
//===----------------------------------------------------------------------===//
301+
302+
/// Bind the embedder to \p MF and \p Vocab, capture the embedding dimension
/// from the vocabulary and the opcode weight from the command-line option,
/// and size the function vector to that dimension.
MIREmbedder::MIREmbedder(const MachineFunction &MF, const MIRVocabulary &Vocab)
    : MF(MF),
      Vocab(Vocab),
      Dimension(Vocab.getDimension()),
      OpcWeight(::OpcWeight),
      MFuncVector(Embedding(Dimension)) {}
305+
306+
/// Create an embedder of kind \p Mode for \p MF using \p Vocab.
/// The result stays null when \p Mode matches none of the handled kinds.
std::unique_ptr<MIREmbedder> MIREmbedder::create(MIR2VecKind Mode,
                                                 const MachineFunction &MF,
                                                 const MIRVocabulary &Vocab) {
  std::unique_ptr<MIREmbedder> Result;
  switch (Mode) {
  case MIR2VecKind::Symbolic:
    Result = std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
    break;
  }
  return Result;
}
315+
316+
/// Return the instruction-embedding map for MF, computing the embeddings
/// first when the cached results do not cover the whole function.
const MachineInstEmbeddingsMap &MIREmbedder::getMInstVecMap() const {
  // Embeddings are produced one basic block at a time, so after a call to
  // getMBBVector() the instruction map can be non-empty yet hold only that
  // block's instructions. A bare empty() check would then hand back a partial
  // map. The caches are treated as complete only when every block of MF has a
  // block embedding.
  // NOTE(review): this assumes the per-block computeEmbeddings() override
  // records both the block vector and its instruction vectors, as the
  // symbolic embedder does.
  if (MInstVecMap.empty() || MBBVecMap.size() != MF.size())
    computeEmbeddings();
  return MInstVecMap;
}
321+
322+
/// Return the basic-block-embedding map for MF, computing the embeddings
/// first when the cached results do not cover the whole function.
const MachineBlockEmbeddingsMap &MIREmbedder::getMBBVecMap() const {
  // getMBBVector() fills the map one block at a time, so a non-empty map is
  // not necessarily complete. Recompute unless there is an entry for every
  // basic block of MF; otherwise callers would receive a partial map.
  if (MBBVecMap.empty() || MBBVecMap.size() != MF.size())
    computeEmbeddings();
  return MBBVecMap;
}
327+
328+
/// Return the embedding of \p BB, computing it on first request.
const Embedding &MIREmbedder::getMBBVector(const MachineBasicBlock &BB) const {
  auto Entry = MBBVecMap.find(&BB);
  if (Entry == MBBVecMap.end()) {
    // Not cached yet: run the per-block computation, then fetch the slot.
    // try_emplace() without value arguments default-constructs the entry if
    // the computation did not record one.
    computeEmbeddings(BB);
    Entry = MBBVecMap.try_emplace(&BB).first;
  }
  return Entry->second;
}
335+
336+
/// Return the embedding of the whole machine function.
const Embedding &MIREmbedder::getMFunctionVector() const {
  // The function vector is not cached: every call recomputes the embeddings
  // for the function and rebuilds MFuncVector.
  computeEmbeddings();
  return MFuncVector;
}
342+
343+
void MIREmbedder::computeEmbeddings() const {
344+
// Reset function vector to zero before recomputing
345+
MFuncVector = Embedding(Dimension, 0.0);
346+
347+
// Consider all machine basic blocks in the function
348+
for (const auto &MBB : MF) {
349+
computeEmbeddings(MBB);
350+
MFuncVector += MBBVecMap[&MBB];
351+
}
352+
}
353+
354+
/// Construct a symbolic embedder; all state lives in the MIREmbedder base.
SymbolicMIREmbedder::SymbolicMIREmbedder(const MachineFunction &MF,
                                         const MIRVocabulary &Vocab)
    : MIREmbedder(MF, Vocab) {
}
357+
358+
std::unique_ptr<SymbolicMIREmbedder>
359+
SymbolicMIREmbedder::create(const MachineFunction &MF,
360+
const MIRVocabulary &Vocab) {
361+
return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
362+
}
363+
364+
void SymbolicMIREmbedder::computeEmbeddings(
365+
const MachineBasicBlock &MBB) const {
366+
Embedding MBBVector(Dimension, 0);
367+
368+
// Get instruction info for opcode name resolution
369+
const auto &Subtarget = MF.getSubtarget();
370+
const auto *TII = Subtarget.getInstrInfo();
371+
if (!TII) {
372+
MF.getFunction().getContext().emitError(
373+
"MIR2Vec: No TargetInstrInfo available; cannot compute embeddings");
374+
return;
375+
}
376+
377+
// Process each machine instruction in the basic block
378+
for (const auto &MI : MBB) {
379+
// Skip debug instructions and other metadata
380+
if (MI.isDebugInstr())
381+
continue;
382+
383+
// Todo: Add operand/argument contributions
384+
385+
// Store the instruction embedding
386+
auto InstVector = Vocab[MI.getOpcode()];
387+
MInstVecMap[&MI] = InstVector;
388+
MBBVector += InstVector;
389+
}
390+
391+
// Store the basic block embedding
392+
MBBVecMap[&MBB] = MBBVector;
393+
}
394+
395+
//===----------------------------------------------------------------------===//
396+
// Printer Passes
271397
//===----------------------------------------------------------------------===//
272398

273399
char MIR2VecVocabPrinterLegacyPass::ID = 0;
@@ -304,3 +430,67 @@ MachineFunctionPass *
304430
llvm::createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS) {
305431
return new MIR2VecVocabPrinterLegacyPass(OS);
306432
}
433+
434+
// Storage for the pass identifier handed to the MachineFunctionPass
// constructor.
char MIR2VecPrinterLegacyPass::ID = 0;
435+
// Register the printer under the name "print-mir2vec" and declare the passes
// it depends on.
INITIALIZE_PASS_BEGIN(MIR2VecPrinterLegacyPass, "print-mir2vec",
436+
"MIR2Vec Embedder Printer Pass", false, true)
437+
INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis)
438+
INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
439+
INITIALIZE_PASS_END(MIR2VecPrinterLegacyPass, "print-mir2vec",
440+
"MIR2Vec Embedder Printer Pass", false, true)
441+
442+
/// Print the MIR2Vec embeddings of \p MF to OS: first the function vector,
/// then one vector per basic block, then one per non-debug instruction.
/// Always returns false.
bool MIR2VecPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
  auto &VocabAnalysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
  auto Vocabulary =
      VocabAnalysis.getMIR2VecVocabulary(*MF.getFunction().getParent());

  // Nothing can be embedded without a usable vocabulary.
  if (!Vocabulary.isValid()) {
    OS << "MIR2Vec Embedder Printer: Invalid vocabulary for function "
       << MF.getName() << "\n";
    return false;
  }

  auto Embedder =
      mir2vec::MIREmbedder::create(MIR2VecEmbeddingKind, MF, Vocabulary);
  if (!Embedder) {
    OS << "Error creating MIR2Vec embeddings for function " << MF.getName()
       << "\n";
    return false;
  }

  // Function-level embedding.
  OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
  OS << "Machine Function vector: ";
  Embedder->getMFunctionVector().print(OS);

  // Block-level embeddings, walked in function order.
  OS << "Machine basic block vectors:\n";
  const auto &BlockVectors = Embedder->getMBBVecMap();
  for (const MachineBasicBlock &MBB : MF) {
    auto BlockIt = BlockVectors.find(&MBB);
    if (BlockIt == BlockVectors.end())
      continue;
    OS << "Machine basic block: " << MBB.getFullName() << ":\n";
    BlockIt->second.print(OS);
  }

  // Instruction-level embeddings; debug instructions are not embedded.
  OS << "Machine instruction vectors:\n";
  const auto &InstVectors = Embedder->getMInstVecMap();
  for (const MachineBasicBlock &MBB : MF) {
    for (const MachineInstr &MI : MBB) {
      if (MI.isDebugInstr())
        continue;

      auto InstIt = InstVectors.find(&MI);
      if (InstIt == InstVectors.end())
        continue;
      OS << "Machine instruction: ";
      MI.print(OS);
      InstIt->second.print(OS);
    }
  }

  return false;
}
493+
494+
/// Allocate a MIR2Vec embedding printer pass that writes to \p OS. The caller
/// receives a raw pointer to the newly allocated pass.
MachineFunctionPass *llvm::createMIR2VecPrinterLegacyPass(raw_ostream &OS) {
  MachineFunctionPass *Printer = new MIR2VecPrinterLegacyPass(OS);
  return Printer;
}

0 commit comments

Comments
 (0)