@@ -41,11 +41,18 @@ static cl::opt<std::string>
4141cl::opt<float > OpcWeight (" mir2vec-opc-weight" , cl::Optional, cl::init(1.0 ),
4242 cl::desc(" Weight for machine opcode embeddings" ),
4343 cl::cat(MIR2VecCategory));
44+ cl::opt<MIR2VecKind> MIR2VecEmbeddingKind (
45+ " mir2vec-kind" , cl::Optional,
46+ cl::values (clEnumValN(MIR2VecKind::Symbolic, " symbolic" ,
47+ " Generate symbolic embeddings for MIR" )),
48+ cl::init(MIR2VecKind::Symbolic), cl::desc(" MIR2Vec embedding kind" ),
49+ cl::cat(MIR2VecCategory));
50+
4451} // namespace mir2vec
4552} // namespace llvm
4653
4754// ===----------------------------------------------------------------------===//
48- // Vocabulary Implementation
55+ // Vocabulary
4956// ===----------------------------------------------------------------------===//
5057
5158MIRVocabulary::MIRVocabulary (VocabMap &&OpcodeEntries,
@@ -190,6 +197,28 @@ void MIRVocabulary::buildCanonicalOpcodeMapping() {
190197 << " unique base opcodes\n " );
191198}
192199
200+ MIRVocabulary MIRVocabulary::createDummyVocabForTest (const TargetInstrInfo &TII,
201+ unsigned Dim) {
202+ assert (Dim > 0 && " Dimension must be greater than zero" );
203+
204+ float DummyVal = 0 .1f ;
205+
206+ // Create a temporary vocabulary instance to build canonical mapping
207+ MIRVocabulary TempVocab ({}, &TII);
208+ TempVocab.buildCanonicalOpcodeMapping ();
209+
210+ // Create dummy embeddings for all canonical opcode names
211+ VocabMap DummyVocabMap;
212+ for (const auto &COpcodeName : TempVocab.UniqueBaseOpcodeNames ) {
213+ // Create dummy embedding filled with DummyVal
214+ Embedding DummyEmbedding (Dim, DummyVal);
215+ DummyVocabMap[COpcodeName] = DummyEmbedding;
216+ }
217+
218+ // Create and return vocabulary with dummy embeddings
219+ return MIRVocabulary (std::move (DummyVocabMap), &TII);
220+ }
221+
193222// ===----------------------------------------------------------------------===//
194223// MIR2VecVocabLegacyAnalysis Implementation
195224// ===----------------------------------------------------------------------===//
@@ -267,7 +296,104 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
267296}
268297
269298// ===----------------------------------------------------------------------===//
270- // Printer Passes Implementation
299+ // MIREmbedder and its subclasses
300+ // ===----------------------------------------------------------------------===//
301+
302+ MIREmbedder::MIREmbedder (const MachineFunction &MF, const MIRVocabulary &Vocab)
303+ : MF(MF), Vocab(Vocab), Dimension(Vocab.getDimension()),
304+ OpcWeight(::OpcWeight), MFuncVector(Embedding(Dimension)) {}
305+
306+ std::unique_ptr<MIREmbedder> MIREmbedder::create (MIR2VecKind Mode,
307+ const MachineFunction &MF,
308+ const MIRVocabulary &Vocab) {
309+ switch (Mode) {
310+ case MIR2VecKind::Symbolic:
311+ return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
312+ }
313+ return nullptr ;
314+ }
315+
316+ const MachineInstEmbeddingsMap &MIREmbedder::getMInstVecMap () const {
317+ if (MInstVecMap.empty ())
318+ computeEmbeddings ();
319+ return MInstVecMap;
320+ }
321+
322+ const MachineBlockEmbeddingsMap &MIREmbedder::getMBBVecMap () const {
323+ if (MBBVecMap.empty ())
324+ computeEmbeddings ();
325+ return MBBVecMap;
326+ }
327+
328+ const Embedding &MIREmbedder::getMBBVector (const MachineBasicBlock &BB) const {
329+ auto It = MBBVecMap.find (&BB);
330+ if (It != MBBVecMap.end ())
331+ return It->second ;
332+ computeEmbeddings (BB);
333+ return MBBVecMap[&BB];
334+ }
335+
336+ const Embedding &MIREmbedder::getMFunctionVector () const {
337+ // Currently, we always (re)compute the embeddings for the function.
338+ // This is cheaper than caching the vector.
339+ computeEmbeddings ();
340+ return MFuncVector;
341+ }
342+
343+ void MIREmbedder::computeEmbeddings () const {
344+ // Reset function vector to zero before recomputing
345+ MFuncVector = Embedding (Dimension, 0.0 );
346+
347+ // Consider all machine basic blocks in the function
348+ for (const auto &MBB : MF) {
349+ computeEmbeddings (MBB);
350+ MFuncVector += MBBVecMap[&MBB];
351+ }
352+ }
353+
354+ SymbolicMIREmbedder::SymbolicMIREmbedder (const MachineFunction &MF,
355+ const MIRVocabulary &Vocab)
356+ : MIREmbedder(MF, Vocab) {}
357+
358+ std::unique_ptr<SymbolicMIREmbedder>
359+ SymbolicMIREmbedder::create (const MachineFunction &MF,
360+ const MIRVocabulary &Vocab) {
361+ return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
362+ }
363+
364+ void SymbolicMIREmbedder::computeEmbeddings (
365+ const MachineBasicBlock &MBB) const {
366+ Embedding MBBVector (Dimension, 0 );
367+
368+ // Get instruction info for opcode name resolution
369+ const auto &Subtarget = MF.getSubtarget ();
370+ const auto *TII = Subtarget.getInstrInfo ();
371+ if (!TII) {
372+ MF.getFunction ().getContext ().emitError (
373+ " MIR2Vec: No TargetInstrInfo available; cannot compute embeddings" );
374+ return ;
375+ }
376+
377+ // Process each machine instruction in the basic block
378+ for (const auto &MI : MBB) {
379+ // Skip debug instructions and other metadata
380+ if (MI.isDebugInstr ())
381+ continue ;
382+
383+ // Todo: Add operand/argument contributions
384+
385+ // Store the instruction embedding
386+ auto InstVector = Vocab[MI.getOpcode ()];
387+ MInstVecMap[&MI] = InstVector;
388+ MBBVector += InstVector;
389+ }
390+
391+ // Store the basic block embedding
392+ MBBVecMap[&MBB] = MBBVector;
393+ }
394+
395+ // ===----------------------------------------------------------------------===//
396+ // Printer Passes
271397// ===----------------------------------------------------------------------===//
272398
273399char MIR2VecVocabPrinterLegacyPass::ID = 0 ;
@@ -304,3 +430,67 @@ MachineFunctionPass *
304430llvm::createMIR2VecVocabPrinterLegacyPass (raw_ostream &OS) {
305431 return new MIR2VecVocabPrinterLegacyPass (OS);
306432}
433+
434+ char MIR2VecPrinterLegacyPass::ID = 0 ;
435+ INITIALIZE_PASS_BEGIN (MIR2VecPrinterLegacyPass, " print-mir2vec" ,
436+ " MIR2Vec Embedder Printer Pass" , false , true )
437+ INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis)
438+ INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
439+ INITIALIZE_PASS_END(MIR2VecPrinterLegacyPass, " print-mir2vec" ,
440+ " MIR2Vec Embedder Printer Pass" , false , true )
441+
442+ bool MIR2VecPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
443+ auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
444+ auto MIRVocab = Analysis.getMIR2VecVocabulary (*MF.getFunction ().getParent ());
445+
446+ if (!MIRVocab.isValid ()) {
447+ OS << " MIR2Vec Embedder Printer: Invalid vocabulary for function "
448+ << MF.getName () << " \n " ;
449+ return false ;
450+ }
451+
452+ auto Emb = mir2vec::MIREmbedder::create (MIR2VecEmbeddingKind, MF, MIRVocab);
453+ if (!Emb) {
454+ OS << " Error creating MIR2Vec embeddings for function " << MF.getName ()
455+ << " \n " ;
456+ return false ;
457+ }
458+
459+ OS << " MIR2Vec embeddings for machine function " << MF.getName () << " :\n " ;
460+ OS << " Machine Function vector: " ;
461+ Emb->getMFunctionVector ().print (OS);
462+
463+ OS << " Machine basic block vectors:\n " ;
464+ const auto &MBBMap = Emb->getMBBVecMap ();
465+ for (const MachineBasicBlock &MBB : MF) {
466+ auto It = MBBMap.find (&MBB);
467+ if (It != MBBMap.end ()) {
468+ OS << " Machine basic block: " << MBB.getFullName () << " :\n " ;
469+ It->second .print (OS);
470+ }
471+ }
472+
473+ OS << " Machine instruction vectors:\n " ;
474+ const auto &MInstMap = Emb->getMInstVecMap ();
475+ for (const MachineBasicBlock &MBB : MF) {
476+ for (const MachineInstr &MI : MBB) {
477+ // Skip debug instructions as they are not
478+ // embedded
479+ if (MI.isDebugInstr ())
480+ continue ;
481+
482+ auto It = MInstMap.find (&MI);
483+ if (It != MInstMap.end ()) {
484+ OS << " Machine instruction: " ;
485+ MI.print (OS);
486+ It->second .print (OS);
487+ }
488+ }
489+ }
490+
491+ return false ;
492+ }
493+
494+ MachineFunctionPass *llvm::createMIR2VecPrinterLegacyPass (raw_ostream &OS) {
495+ return new MIR2VecPrinterLegacyPass (OS);
496+ }
0 commit comments