1+ // ===- MIR2Vec.h - Implementation of MIR2Vec ------------------*- C++ -*-===//
2+ //
3+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4+ // Exceptions. See the LICENSE file for license information.
5+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+ //
7+ // ===----------------------------------------------------------------------===//
8+ // /
9+ // / \file
10+ // / This file defines the MIR2Vec vocabulary
11+ // / analysis(MIR2VecVocabLegacyAnalysis), the core mir2vec::Embedder interface
12+ // / for generating Machine IR embeddings, and related utilities.
13+ // /
14+ // / MIR2Vec extends IR2Vec to support Machine IR embeddings. It represents the
15+ // / LLVM Machine IR as embeddings which can be used as input to machine learning
16+ // / algorithms.
17+ // /
18+ // / The original idea of MIR2Vec is described in the following paper:
19+ // /
20+ // / RL4ReAl: Reinforcement Learning for Register Allocation. S. VenkataKeerthy,
21+ // / Siddharth Jain, Anilava Kundu, Rohit Aggarwal, Albert Cohen, and Ramakrishna
22+ // / Upadrasta. 2023. RL4ReAl: Reinforcement Learning for Register Allocation.
23+ // / Proceedings of the 32nd ACM SIGPLAN International Conference on Compiler
24+ // / Construction (CC 2023). https://doi.org/10.1145/3578360.3580273.
25+ // / https://arxiv.org/abs/2204.02013
26+ // /
27+ // ===----------------------------------------------------------------------===//
28+
29+ #ifndef LLVM_CODEGEN_MIR2VEC_H
30+ #define LLVM_CODEGEN_MIR2VEC_H
31+
32+ #include " llvm/Analysis/IR2Vec.h"
33+ #include " llvm/CodeGen/MachineBasicBlock.h"
34+ #include " llvm/CodeGen/MachineFunction.h"
35+ #include " llvm/CodeGen/MachineFunctionPass.h"
36+ #include " llvm/CodeGen/MachineInstr.h"
37+ #include " llvm/CodeGen/MachineModuleInfo.h"
38+ #include " llvm/IR/PassManager.h"
39+ #include " llvm/Pass.h"
40+ #include " llvm/Support/CommandLine.h"
41+ #include " llvm/Support/ErrorOr.h"
42+ #include < map>
43+ #include < set>
44+ #include < string>
45+
46+ namespace llvm {
47+
48+ class Module ;
49+ class raw_ostream ;
50+ class LLVMContext ;
51+ class MIR2VecVocabLegacyAnalysis ;
52+ class TargetInstrInfo ;
53+
54+ namespace mir2vec {
55+ extern llvm::cl::OptionCategory MIR2VecCategory;
56+ extern cl::opt<float > OpcWeight;
57+
58+ using Embedding = ir2vec::Embedding;
59+
60+ // / Class for storing and accessing the MIR2Vec vocabulary.
61+ // / The MIRVocabulary class manages seed embeddings for LLVM Machine IR
62+ class MIRVocabulary {
63+ friend class llvm ::MIR2VecVocabLegacyAnalysis;
64+ using VocabMap = std::map<std::string, ir2vec::Embedding>;
65+
66+ private:
67+ // Define vocabulary layout - adapted for MIR
68+ struct {
69+ size_t OpcodeBase = 0 ;
70+ size_t OperandBase = 0 ;
71+ size_t TotalEntries = 0 ;
72+ } Layout;
73+
74+ ir2vec::VocabStorage Storage;
75+ mutable std::set<std::string> UniqueBaseOpcodeNames;
76+ void generateStorage (const VocabMap &OpcodeMap, const TargetInstrInfo &TII);
77+ void buildCanonicalOpcodeMapping (const TargetInstrInfo &TII);
78+
79+ public:
80+ // / Static helper method for extracting base opcode names (public for testing)
81+ static std::string extractBaseOpcodeName (StringRef InstrName);
82+
83+ // / Helper method for getting canonical index for base name (public for
84+ // / testing)
85+ unsigned getCanonicalIndexForBaseName (StringRef BaseName) const ;
86+
87+ // / Get the string key for a vocabulary entry at the given position
88+ std::string getStringKey (unsigned Pos) const ;
89+
90+ MIRVocabulary () = default ;
91+ MIRVocabulary (VocabMap &&Entries, const TargetInstrInfo *TII);
92+ MIRVocabulary (ir2vec::VocabStorage &&Storage) : Storage(std::move(Storage)) {}
93+
94+ bool isValid () const {
95+ return UniqueBaseOpcodeNames.size () > 0 &&
96+ Layout.TotalEntries == Storage.size () && Storage.isValid ();
97+ }
98+
99+ unsigned getDimension () const {
100+ if (!isValid ())
101+ return 0 ;
102+ return Storage.getDimension ();
103+ }
104+
105+ // Accessor methods
106+ const Embedding &operator [](unsigned Index) const {
107+ assert (isValid () && " MIR2Vec Vocabulary is invalid" );
108+ assert (Index < Layout.TotalEntries && " Index out of bounds" );
109+ // Fixme: For now, use section 0 for all entries
110+ return Storage[0 ][Index];
111+ }
112+
113+ // Iterator access
114+ using const_iterator = ir2vec::VocabStorage::const_iterator;
115+ const_iterator begin () const {
116+ assert (isValid () && " MIR2Vec Vocabulary is invalid" );
117+ return Storage.begin ();
118+ }
119+
120+ const_iterator end () const {
121+ assert (isValid () && " MIR2Vec Vocabulary is invalid" );
122+ return Storage.end ();
123+ }
124+
125+ // / Total number of entries in the vocabulary
126+ size_t getCanonicalSize () const {
127+ assert (isValid () && " Invalid vocabulary" );
128+ return Storage.size ();
129+ }
130+ };
131+
132+ } // namespace mir2vec
133+
134+ // / Pass to analyze and populate MIR2Vec vocabulary from a module
135+ class MIR2VecVocabLegacyAnalysis : public ImmutablePass {
136+ using VocabVector = std::vector<mir2vec::Embedding>;
137+ using VocabMap = std::map<std::string, mir2vec::Embedding>;
138+ VocabMap StrVocabMap;
139+ VocabVector Vocab;
140+
141+ StringRef getPassName () const override ;
142+ Error readVocabulary ();
143+ void emitError (Error Err, LLVMContext &Ctx);
144+
145+ protected:
146+ void getAnalysisUsage (AnalysisUsage &AU) const override {
147+ AU.addRequired <MachineModuleInfoWrapperPass>();
148+ AU.setPreservesAll ();
149+ }
150+
151+ public:
152+ static char ID;
153+ MIR2VecVocabLegacyAnalysis () : ImmutablePass(ID) {}
154+ mir2vec::MIRVocabulary getMIR2VecVocabulary (const Module &M);
155+ };
156+
157+ // / This pass prints the embeddings in the MIR2Vec vocabulary
158+ class MIR2VecVocabPrinterLegacyPass : public MachineFunctionPass {
159+ raw_ostream &OS;
160+
161+ public:
162+ static char ID;
163+ explicit MIR2VecVocabPrinterLegacyPass (raw_ostream &OS)
164+ : MachineFunctionPass(ID), OS(OS) {}
165+
166+ bool runOnMachineFunction (MachineFunction &MF) override ;
167+ bool doFinalization (Module &M) override ;
168+ void getAnalysisUsage (AnalysisUsage &AU) const override {
169+ AU.addRequired <MIR2VecVocabLegacyAnalysis>();
170+ AU.setPreservesAll ();
171+ MachineFunctionPass::getAnalysisUsage (AU);
172+ }
173+
174+ StringRef getPassName () const override {
175+ return " MIR2Vec Vocabulary Printer Pass" ;
176+ }
177+ };
178+
179+ } // namespace llvm
180+
181+ #endif // LLVM_CODEGEN_MIR2VEC_H
0 commit comments