31
31
32
32
#include " llvm/ADT/DenseMap.h"
33
33
#include " llvm/IR/PassManager.h"
34
+ #include " llvm/IR/Type.h"
34
35
#include " llvm/Support/CommandLine.h"
35
36
#include " llvm/Support/Compiler.h"
36
37
#include " llvm/Support/ErrorOr.h"
@@ -43,10 +44,10 @@ class Module;
43
44
class BasicBlock ;
44
45
class Instruction ;
45
46
class Function ;
46
- class Type ;
47
47
class Value ;
48
48
class raw_ostream ;
49
49
class LLVMContext ;
50
+ class IR2VecVocabAnalysis ;
50
51
51
52
// / IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
52
53
// / Symbolic embeddings capture the "syntactic" and "statistical correlation"
@@ -128,9 +129,94 @@ struct Embedding {
128
129
129
130
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
130
131
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
131
- // FIXME: Current the keys are strings. This can be changed to
132
- // use integers for cheaper lookups.
133
- using Vocab = std::map<std::string, Embedding>;
132
+
133
+ // / Class for storing and accessing the IR2Vec vocabulary.
134
+ // / Encapsulates all vocabulary-related constants, logic, and access methods.
135
+ class Vocabulary {
136
+ friend class llvm ::IR2VecVocabAnalysis;
137
+ using VocabVector = std::vector<ir2vec::Embedding>;
138
+ VocabVector Vocab;
139
+ bool Valid = false ;
140
+
141
+ // / Operand kinds supported by IR2Vec Vocabulary
142
+ enum class OperandKind : unsigned {
143
+ FunctionID,
144
+ PointerID,
145
+ ConstantID,
146
+ VariableID,
147
+ MaxOperandKind
148
+ };
149
+ // / String mappings for OperandKind values
150
+ static constexpr StringLiteral OperandKindNames[] = {" Function" , " Pointer" ,
151
+ " Constant" , " Variable" };
152
+ static_assert (std::size(OperandKindNames) ==
153
+ static_cast <unsigned >(OperandKind::MaxOperandKind),
154
+ " OperandKindNames array size must match MaxOperandKind" );
155
+
156
+ // / Vocabulary layout constants
157
+ #define LAST_OTHER_INST (NUM ) static constexpr unsigned MaxOpcodes = NUM;
158
+ #include " llvm/IR/Instruction.def"
159
+ #undef LAST_OTHER_INST
160
+
161
+ static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1 ;
162
+ static constexpr unsigned MaxOperandKinds =
163
+ static_cast <unsigned >(OperandKind::MaxOperandKind);
164
+
165
+ // / Helper function to get vocabulary key for a given OperandKind
166
+ static StringRef getVocabKeyForOperandKind (OperandKind Kind);
167
+
168
+ // / Helper function to classify an operand into OperandKind
169
+ static OperandKind getOperandKind (const Value *Op);
170
+
171
+ // / Helper function to get vocabulary key for a given TypeID
172
+ static StringRef getVocabKeyForTypeID (Type::TypeID TypeID);
173
+
174
+ public:
175
+ Vocabulary () = default ;
176
+ Vocabulary (VocabVector &&Vocab);
177
+
178
+ bool isValid () const ;
179
+ unsigned getDimension () const ;
180
+ size_t size () const ;
181
+
182
+ // / Accessors to get the embedding for a given entity.
183
+ const ir2vec::Embedding &operator [](unsigned Opcode) const ;
184
+ const ir2vec::Embedding &operator [](Type::TypeID TypeId) const ;
185
+ const ir2vec::Embedding &operator [](const Value *Arg) const ;
186
+
187
+ // / Const Iterator type aliases
188
+ using const_iterator = VocabVector::const_iterator;
189
+ const_iterator begin () const {
190
+ assert (Valid && " IR2Vec Vocabulary is invalid" );
191
+ return Vocab.begin ();
192
+ }
193
+
194
+ const_iterator cbegin () const {
195
+ assert (Valid && " IR2Vec Vocabulary is invalid" );
196
+ return Vocab.cbegin ();
197
+ }
198
+
199
+ const_iterator end () const {
200
+ assert (Valid && " IR2Vec Vocabulary is invalid" );
201
+ return Vocab.end ();
202
+ }
203
+
204
+ const_iterator cend () const {
205
+ assert (Valid && " IR2Vec Vocabulary is invalid" );
206
+ return Vocab.cend ();
207
+ }
208
+
209
+ // / Returns the string key for a given index position in the vocabulary.
210
+ // / This is useful for debugging or printing the vocabulary. Do not use this
211
+ // / for embedding generation as string based lookups are inefficient.
212
+ static StringRef getStringKey (unsigned Pos);
213
+
214
+ // / Create a dummy vocabulary for testing purposes.
215
+ static VocabVector createDummyVocabForTest (unsigned Dim = 1 );
216
+
217
+ bool invalidate (Module &M, const PreservedAnalyses &PA,
218
+ ModuleAnalysisManager::Invalidator &Inv) const ;
219
+ };
134
220
135
221
// / Embedder provides the interface to generate embeddings (vector
136
222
// / representations) for instructions, basic blocks, and functions. The
@@ -141,7 +227,7 @@ using Vocab = std::map<std::string, Embedding>;
141
227
class Embedder {
142
228
protected:
143
229
const Function &F;
144
- const Vocab &Vocabulary ;
230
+ const Vocabulary &Vocab ;
145
231
146
232
// / Dimension of the vector representation; captured from the input vocabulary
147
233
const unsigned Dimension;
@@ -156,7 +242,7 @@ class Embedder {
156
242
mutable BBEmbeddingsMap BBVecMap;
157
243
mutable InstEmbeddingsMap InstVecMap;
158
244
159
- LLVM_ABI Embedder (const Function &F, const Vocab &Vocabulary );
245
+ LLVM_ABI Embedder (const Function &F, const Vocabulary &Vocab );
160
246
161
247
// / Helper function to compute embeddings. It generates embeddings for all
162
248
// / the instructions and basic blocks in the function F. Logic of computing
@@ -167,16 +253,12 @@ class Embedder {
167
253
// / Specific to the kind of embeddings being computed.
168
254
virtual void computeEmbeddings (const BasicBlock &BB) const = 0;
169
255
170
- // / Lookup vocabulary for a given Key. If the key is not found, it returns a
171
- // / zero vector.
172
- LLVM_ABI Embedding lookupVocab (const std::string &Key) const ;
173
-
174
256
public:
175
257
virtual ~Embedder () = default ;
176
258
177
259
// / Factory method to create an Embedder object.
178
260
LLVM_ABI static std::unique_ptr<Embedder>
179
- create (IR2VecKind Mode, const Function &F, const Vocab &Vocabulary );
261
+ create (IR2VecKind Mode, const Function &F, const Vocabulary &Vocab );
180
262
181
263
// / Returns a map containing instructions and the corresponding embeddings for
182
264
// / the function F if it has been computed. If not, it computes the embeddings
@@ -202,56 +284,39 @@ class Embedder {
202
284
// / representations obtained from the Vocabulary.
203
285
class LLVM_ABI SymbolicEmbedder : public Embedder {
204
286
private:
205
- // / Utility function to compute the embedding for a given type.
206
- Embedding getTypeEmbedding (const Type *Ty) const ;
207
-
208
- // / Utility function to compute the embedding for a given operand.
209
- Embedding getOperandEmbedding (const Value *Op) const ;
210
-
211
287
void computeEmbeddings () const override ;
212
288
void computeEmbeddings (const BasicBlock &BB) const override ;
213
289
214
290
public:
215
- SymbolicEmbedder (const Function &F, const Vocab &Vocabulary )
216
- : Embedder(F, Vocabulary ) {
291
+ SymbolicEmbedder (const Function &F, const Vocabulary &Vocab )
292
+ : Embedder(F, Vocab ) {
217
293
FuncVector = Embedding (Dimension, 0 );
218
294
}
219
295
};
220
296
221
297
} // namespace ir2vec
222
298
223
- // / Class for storing the result of the IR2VecVocabAnalysis.
224
- class IR2VecVocabResult {
225
- ir2vec::Vocab Vocabulary;
226
- bool Valid = false ;
227
-
228
- public:
229
- IR2VecVocabResult () = default ;
230
- LLVM_ABI IR2VecVocabResult (ir2vec::Vocab &&Vocabulary);
231
-
232
- bool isValid () const { return Valid; }
233
- LLVM_ABI const ir2vec::Vocab &getVocabulary () const ;
234
- LLVM_ABI unsigned getDimension () const ;
235
- LLVM_ABI bool invalidate (Module &M, const PreservedAnalyses &PA,
236
- ModuleAnalysisManager::Invalidator &Inv) const ;
237
- };
238
-
239
299
// / This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
240
300
// / mapping between an entity of the IR (like opcode, type, argument, etc.) and
241
301
// / its corresponding embedding.
242
302
class IR2VecVocabAnalysis : public AnalysisInfoMixin <IR2VecVocabAnalysis> {
243
- ir2vec::Vocab Vocabulary;
303
+ using VocabVector = std::vector<ir2vec::Embedding>;
304
+ using VocabMap = std::map<std::string, ir2vec::Embedding>;
305
+ VocabMap OpcVocab, TypeVocab, ArgVocab;
306
+ VocabVector Vocab;
307
+
244
308
Error readVocabulary ();
245
309
Error parseVocabSection (StringRef Key, const json::Value &ParsedVocabValue,
246
- ir2vec::Vocab &TargetVocab, unsigned &Dim);
310
+ VocabMap &TargetVocab, unsigned &Dim);
311
+ void generateNumMappedVocab ();
247
312
void emitError (Error Err, LLVMContext &Ctx);
248
313
249
314
public:
250
315
LLVM_ABI static AnalysisKey Key;
251
316
IR2VecVocabAnalysis () = default ;
252
- LLVM_ABI explicit IR2VecVocabAnalysis (const ir2vec::Vocab &Vocab);
253
- LLVM_ABI explicit IR2VecVocabAnalysis (ir2vec::Vocab &&Vocab);
254
- using Result = IR2VecVocabResult ;
317
+ LLVM_ABI explicit IR2VecVocabAnalysis (const VocabVector &Vocab);
318
+ LLVM_ABI explicit IR2VecVocabAnalysis (VocabVector &&Vocab);
319
+ using Result = ir2vec::Vocabulary ;
255
320
LLVM_ABI Result run (Module &M, ModuleAnalysisManager &MAM);
256
321
};
257
322
0 commit comments