@@ -41,11 +41,18 @@ static cl::opt<std::string>
// Command-line option holding the weight for machine opcode embeddings.
// It is optional, is initialized to 1.0, and is grouped under MIR2VecCategory.
4141cl::opt<float > OpcWeight (" mir2vec-opc-weight" , cl::Optional, cl::init(1.0 ),
4242 cl::desc(" Weight for machine opcode embeddings" ),
4343 cl::cat(MIR2VecCategory));
// Command-line option selecting the MIR2Vec embedding kind. It is optional and
// grouped under MIR2VecCategory. The only value registered here is
// MIR2VecKind::Symbolic, which is also the initial value.
44+ cl::opt<MIR2VecKind> MIR2VecEmbeddingKind (
45+ " mir2vec-kind" , cl::Optional,
46+ cl::values (clEnumValN(MIR2VecKind::Symbolic, " symbolic" ,
47+ " Generate symbolic embeddings for MIR" )),
48+ cl::init(MIR2VecKind::Symbolic), cl::desc(" MIR2Vec embedding kind" ),
49+ cl::cat(MIR2VecCategory));
50+
4451} // namespace mir2vec
4552} // namespace llvm
4653
4754// ===----------------------------------------------------------------------===//
48- // Vocabulary Implementation
55+ // Vocabulary
4956// ===----------------------------------------------------------------------===//
5057
5158MIRVocabulary::MIRVocabulary (VocabMap &&OpcodeEntries,
@@ -190,6 +197,29 @@ void MIRVocabulary::buildCanonicalOpcodeMapping() {
190197 << " unique base opcodes\n " );
191198}
192199
200+ MIRVocabulary MIRVocabulary::createDummyVocabForTest (const TargetInstrInfo &TII,
201+ unsigned Dim) {
202+ assert (Dim > 0 && " Dimension must be greater than zero" );
203+
204+ float DummyVal = 0 .1f ;
205+
206+ // Create a temporary vocabulary instance to build canonical mapping
207+ MIRVocabulary TempVocab ({}, &TII);
208+ TempVocab.buildCanonicalOpcodeMapping ();
209+
210+ // Create dummy embeddings for all canonical opcode names
211+ VocabMap DummyVocabMap;
212+ for (const auto &COpcodeName : TempVocab.UniqueBaseOpcodeNames ) {
213+ // Create dummy embedding filled with DummyVal
214+ Embedding DummyEmbedding (Dim, DummyVal);
215+ DummyVocabMap[COpcodeName] = DummyEmbedding;
216+ DummyVal += 0 .1f ;
217+ }
218+
219+ // Create and return vocabulary with dummy embeddings
220+ return MIRVocabulary (std::move (DummyVocabMap), &TII);
221+ }
222+
193223// ===----------------------------------------------------------------------===//
194224// MIR2VecVocabLegacyAnalysis Implementation
195225// ===----------------------------------------------------------------------===//
@@ -267,7 +297,104 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
267297}
268298
269299// ===----------------------------------------------------------------------===//
270- // Printer Passes Implementation
300+ // MIREmbedder and its subclasses
301+ // ===----------------------------------------------------------------------===//
302+
303+ MIREmbedder::MIREmbedder (const MachineFunction &MF, const MIRVocabulary &Vocab)
304+ : MF(MF), Vocab(Vocab), Dimension(Vocab.getDimension()),
305+ OpcWeight(::OpcWeight), MFuncVector(Embedding(Dimension)) {}
306+
307+ std::unique_ptr<MIREmbedder> MIREmbedder::create (MIR2VecKind Mode,
308+ const MachineFunction &MF,
309+ const MIRVocabulary &Vocab) {
310+ switch (Mode) {
311+ case MIR2VecKind::Symbolic:
312+ return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
313+ }
314+ return nullptr ;
315+ }
316+
317+ const MachineInstEmbeddingsMap &MIREmbedder::getMInstVecMap () const {
318+ if (MInstVecMap.empty ())
319+ computeEmbeddings ();
320+ return MInstVecMap;
321+ }
322+
323+ const MachineBlockEmbeddingsMap &MIREmbedder::getMBBVecMap () const {
324+ if (MBBVecMap.empty ())
325+ computeEmbeddings ();
326+ return MBBVecMap;
327+ }
328+
329+ const Embedding &MIREmbedder::getMBBVector (const MachineBasicBlock &BB) const {
330+ auto It = MBBVecMap.find (&BB);
331+ if (It != MBBVecMap.end ())
332+ return It->second ;
333+ computeEmbeddings (BB);
334+ return MBBVecMap[&BB];
335+ }
336+
337+ const Embedding &MIREmbedder::getMFunctionVector () const {
338+ // Currently, we always (re)compute the embeddings for the function.
339+ // This is cheaper than caching the vector.
340+ computeEmbeddings ();
341+ return MFuncVector;
342+ }
343+
344+ void MIREmbedder::computeEmbeddings () const {
345+ // Reset function vector to zero before recomputing
346+ MFuncVector = Embedding (Dimension, 0.0 );
347+
348+ // Consider all machine basic blocks in the function
349+ for (const auto &MBB : MF) {
350+ computeEmbeddings (MBB);
351+ MFuncVector += MBBVecMap[&MBB];
352+ }
353+ }
354+
355+ SymbolicMIREmbedder::SymbolicMIREmbedder (const MachineFunction &MF,
356+ const MIRVocabulary &Vocab)
357+ : MIREmbedder(MF, Vocab) {}
358+
359+ std::unique_ptr<SymbolicMIREmbedder>
360+ SymbolicMIREmbedder::create (const MachineFunction &MF,
361+ const MIRVocabulary &Vocab) {
362+ return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
363+ }
364+
365+ void SymbolicMIREmbedder::computeEmbeddings (
366+ const MachineBasicBlock &MBB) const {
367+ Embedding MBBVector (Dimension, 0 );
368+
369+ // Get instruction info for opcode name resolution
370+ const auto &Subtarget = MF.getSubtarget ();
371+ const auto *TII = Subtarget.getInstrInfo ();
372+ if (!TII) {
373+ MF.getFunction ().getContext ().emitError (
374+ " MIR2Vec: No TargetInstrInfo available; cannot compute embeddings" );
375+ return ;
376+ }
377+
378+ // Process each machine instruction in the basic block
379+ for (const auto &MI : MBB) {
380+ // Skip debug instructions and other metadata
381+ if (MI.isDebugInstr ())
382+ continue ;
383+
384+ // Todo: Add operand/argument contributions
385+
386+ // Store the instruction embedding
387+ auto InstVector = Vocab[MI.getOpcode ()];
388+ MInstVecMap[&MI] = InstVector;
389+ MBBVector += InstVector;
390+ }
391+
392+ // Store the basic block embedding
393+ MBBVecMap[&MBB] = MBBVector;
394+ }
395+
396+ // ===----------------------------------------------------------------------===//
397+ // Printer Passes
271398// ===----------------------------------------------------------------------===//
272399
273400char MIR2VecVocabPrinterLegacyPass::ID = 0 ;
@@ -304,3 +431,67 @@ MachineFunctionPass *
304431llvm::createMIR2VecVocabPrinterLegacyPass (raw_ostream &OS) {
305432 return new MIR2VecVocabPrinterLegacyPass (OS);
306433}
434+
// Defines the static ID member of MIR2VecPrinterLegacyPass and registers the
// pass through the INITIALIZE_PASS_* macros, declaring
// MIR2VecVocabLegacyAnalysis and MachineModuleInfoWrapperPass as dependencies.
435+ char MIR2VecPrinterLegacyPass::ID = 0 ;
436+ INITIALIZE_PASS_BEGIN (MIR2VecPrinterLegacyPass, " print-mir2vec" ,
437+ " MIR2Vec Embedder Printer Pass" , false , true )
438+ INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis)
439+ INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
440+ INITIALIZE_PASS_END(MIR2VecPrinterLegacyPass, " print-mir2vec" ,
441+ " MIR2Vec Embedder Printer Pass" , false , true )
442+
443+ bool MIR2VecPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
444+ auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
445+ auto MIRVocab = Analysis.getMIR2VecVocabulary (*MF.getFunction ().getParent ());
446+
447+ if (!MIRVocab.isValid ()) {
448+ OS << " MIR2Vec Embedder Printer: Invalid vocabulary for function "
449+ << MF.getName () << " \n " ;
450+ return false ;
451+ }
452+
453+ auto Emb = mir2vec::MIREmbedder::create (MIR2VecEmbeddingKind, MF, MIRVocab);
454+ if (!Emb) {
455+ OS << " Error creating MIR2Vec embeddings for function " << MF.getName ()
456+ << " \n " ;
457+ return false ;
458+ }
459+
460+ OS << " MIR2Vec embeddings for machine function " << MF.getName () << " :\n " ;
461+ OS << " Machine Function vector: " ;
462+ Emb->getMFunctionVector ().print (OS);
463+
464+ OS << " Machine basic block vectors:\n " ;
465+ const auto &MBBMap = Emb->getMBBVecMap ();
466+ for (const MachineBasicBlock &MBB : MF) {
467+ auto It = MBBMap.find (&MBB);
468+ if (It != MBBMap.end ()) {
469+ OS << " Machine basic block: " << MBB.getFullName () << " :\n " ;
470+ It->second .print (OS);
471+ }
472+ }
473+
474+ OS << " Machine instruction vectors:\n " ;
475+ const auto &MInstMap = Emb->getMInstVecMap ();
476+ for (const MachineBasicBlock &MBB : MF) {
477+ for (const MachineInstr &MI : MBB) {
478+ // Skip debug instructions as they are not
479+ // embedded
480+ if (MI.isDebugInstr ())
481+ continue ;
482+
483+ auto It = MInstMap.find (&MI);
484+ if (It != MInstMap.end ()) {
485+ OS << " Machine instruction: " ;
486+ MI.print (OS);
487+ It->second .print (OS);
488+ }
489+ }
490+ }
491+
492+ return false ;
493+ }
494+
495+ MachineFunctionPass *llvm::createMIR2VecPrinterLegacyPass (raw_ostream &OS) {
496+ return new MIR2VecPrinterLegacyPass (OS);
497+ }
0 commit comments