1+ // ===- MIR2Vec.h - Implementation of MIR2Vec ------------------*- C++ -*-===//
2+ //
3+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4+ // Exceptions. See the LICENSE file for license information.
5+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+ //
7+ // ===----------------------------------------------------------------------===//
8+ // /
9+ // / \file
10+ // / This file defines the MIR2Vec vocabulary analysis(MIR2VecVocabAnalysis),
11+ // / the core mir2vec::Embedder interface for generating Machine IR embeddings,
12+ // / and related utilities.
13+ // /
14+ // / MIR2Vec extends IR2Vec to support Machine IR embeddings. It represents the
15+ // / LLVM Machine IR as embeddings which can be used as input to machine learning
16+ // / algorithms.
17+ // /
18+ // / The original idea of MIR2Vec is described in the following paper:
19+ // /
20+ // / RL4ReAl: Reinforcement Learning for Register Allocation. S. VenkataKeerthy,
21+ // / Siddharth Jain, Anilava Kundu, Rohit Aggarwal, Albert Cohen, and Ramakrishna
22+ // / Upadrasta. 2023. RL4ReAl: Reinforcement Learning for Register Allocation.
23+ // / Proceedings of the 32nd ACM SIGPLAN International Conference on Compiler
24+ // / Construction (CC 2023). https://doi.org/10.1145/3578360.3580273.
25+ // / https://arxiv.org/abs/2204.02013
26+ // /
27+ // ===----------------------------------------------------------------------===//
28+
29+ #ifndef LLVM_CODEGEN_MIR2VEC_H
30+ #define LLVM_CODEGEN_MIR2VEC_H
31+
32+ #include " llvm/Analysis/IR2Vec.h"
33+ #include " llvm/CodeGen/MachineBasicBlock.h"
34+ #include " llvm/CodeGen/MachineFunction.h"
35+ #include " llvm/CodeGen/MachineFunctionPass.h"
36+ #include " llvm/CodeGen/MachineInstr.h"
37+ #include " llvm/CodeGen/MachineModuleInfo.h"
38+ #include " llvm/IR/PassManager.h"
39+ #include " llvm/Pass.h"
40+ #include " llvm/Support/CommandLine.h"
41+ #include " llvm/Support/ErrorOr.h"
42+ #include < map>
43+ #include < set>
44+ #include < string>
45+
46+ namespace llvm {
47+
48+ class Module ;
49+ class raw_ostream ;
50+ class LLVMContext ;
51+ class MIR2VecVocabAnalysis ;
52+ class TargetInstrInfo ;
53+
54+ namespace mir2vec {
55+
56+ // Forward declarations
57+ class Embedder ;
58+ class SymbolicEmbedder ;
59+ class FlowAwareEmbedder ;
60+
61+ extern llvm::cl::OptionCategory MIR2VecCategory;
62+ extern cl::opt<float > OpcWeight;
63+
64+ using Embedding = ir2vec::Embedding;
65+
66+ // / Class for storing and accessing the MIR2Vec vocabulary.
67+ // / The Vocabulary class manages seed embeddings for LLVM Machine IR
68+ class Vocabulary {
69+ friend class llvm ::MIR2VecVocabAnalysis;
70+ using VocabMap = std::map<std::string, ir2vec::Embedding>;
71+
72+ public:
73+ // Define vocabulary layout - adapted for MIR
74+ struct {
75+ unsigned OpcodeBase = 0 ;
76+ unsigned OperandBase = 0 ;
77+ unsigned TotalEntries = 0 ;
78+ } Layout;
79+
80+ private:
81+ ir2vec::VocabStorage Storage;
82+ mutable std::set<std::string> UniqueBaseOpcodeNames;
83+ void generateStorage (const VocabMap &OpcodeMap, const TargetInstrInfo &TII);
84+ void buildCanonicalOpcodeMapping (const TargetInstrInfo &TII);
85+
86+ public:
87+ // / Static helper method for extracting base opcode names (public for testing)
88+ static std::string extractBaseOpcodeName (StringRef InstrName);
89+
90+ // / Helper method for getting canonical index for base name (public for
91+ // / testing)
92+ unsigned getCanonicalIndexForBaseName (StringRef BaseName) const ;
93+
94+ // / Get the string key for a vocabulary entry at the given position
95+ std::string getStringKey (unsigned Pos) const ;
96+
97+ Vocabulary () = default ;
98+ Vocabulary (VocabMap &&Entries, const TargetInstrInfo *TII);
99+ Vocabulary (ir2vec::VocabStorage &&Storage) : Storage(std::move(Storage)) {}
100+
101+ bool isValid () const ;
102+ unsigned getDimension () const ;
103+
104+ // Accessor methods
105+ const Embedding &operator [](unsigned Index) const ;
106+
107+ // Iterator access
108+ using const_iterator = ir2vec::VocabStorage::const_iterator;
109+ const_iterator begin () const ;
110+ const_iterator end () const ;
111+ };
112+
113+ } // namespace mir2vec
114+
115+ // / Pass to analyze and populate MIR2Vec vocabulary from a module
116+ class MIR2VecVocabAnalysis : public ImmutablePass {
117+ using VocabVector = std::vector<mir2vec::Embedding>;
118+ using VocabMap = std::map<std::string, mir2vec::Embedding>;
119+ VocabMap StrVocabMap;
120+ VocabVector Vocab;
121+
122+ StringRef getPassName () const override ;
123+ Error readVocabulary ();
124+ void emitError (Error Err, LLVMContext &Ctx);
125+
126+ protected:
127+ void getAnalysisUsage (AnalysisUsage &AU) const override {
128+ AU.addRequired <MachineModuleInfoWrapperPass>();
129+ AU.setPreservesAll ();
130+ }
131+
132+ public:
133+ static char ID;
134+ MIR2VecVocabAnalysis () : ImmutablePass(ID) {}
135+ mir2vec::Vocabulary getMIR2VecVocabulary (const Module &M);
136+ };
137+
138+ // / This pass prints the MIR2Vec embeddings for instructions, basic blocks, and
139+ // / functions.
140+ class MIR2VecPrinterPass : public PassInfoMixin <MIR2VecPrinterPass> {
141+ raw_ostream &OS;
142+
143+ public:
144+ explicit MIR2VecPrinterPass (raw_ostream &OS) : OS(OS) {}
145+ PreservedAnalyses run (Module &M, ModuleAnalysisManager &MAM);
146+ static bool isRequired () { return true ; }
147+ };
148+
149+ // / This pass prints the embeddings in the MIR2Vec vocabulary
150+ class MIR2VecVocabPrinterPass : public MachineFunctionPass {
151+ raw_ostream &OS;
152+
153+ public:
154+ static char ID;
155+ explicit MIR2VecVocabPrinterPass (raw_ostream &OS)
156+ : MachineFunctionPass(ID), OS(OS) {}
157+
158+ bool runOnMachineFunction (MachineFunction &MF) override ;
159+ bool doFinalization (Module &M) override ;
160+ void getAnalysisUsage (AnalysisUsage &AU) const override {
161+ AU.addRequired <MIR2VecVocabAnalysis>();
162+ AU.setPreservesAll ();
163+ MachineFunctionPass::getAnalysisUsage (AU);
164+ }
165+
166+ StringRef getPassName () const override {
167+ return " MIR2Vec Vocabulary Printer Pass" ;
168+ }
169+ };
170+
171+ // / Old PM version of the printer pass
172+ class MIR2VecPrinterLegacyPass : public ModulePass {
173+ raw_ostream &OS;
174+
175+ public:
176+ static char ID;
177+ explicit MIR2VecPrinterLegacyPass (raw_ostream &OS) : ModulePass(ID), OS(OS) {}
178+
179+ bool runOnModule (Module &M) override ;
180+ void getAnalysisUsage (AnalysisUsage &AU) const override {
181+ AU.setPreservesAll ();
182+ AU.addRequired <MIR2VecVocabAnalysis>();
183+ AU.addRequired <MachineModuleInfoWrapperPass>();
184+ }
185+
186+ StringRef getPassName () const override { return " MIR2Vec Printer Pass" ; }
187+ };
188+
189+ } // namespace llvm
190+
191+ #endif // LLVM_CODEGEN_MIR2VEC_H
0 commit comments