1212// ===----------------------------------------------------------------------===//
1313
1414#include " llvm/CodeGen/MIR2Vec.h"
15+ #include " llvm/ADT/DepthFirstIterator.h"
1516#include " llvm/ADT/Statistic.h"
1617#include " llvm/CodeGen/TargetInstrInfo.h"
1718#include " llvm/IR/Module.h"
@@ -29,20 +30,30 @@ using namespace mir2vec;
2930STATISTIC (MIRVocabMissCounter,
3031 " Number of lookups to MIR entities not present in the vocabulary" );
3132
32- cl::OptionCategory llvm::mir2vec::MIR2VecCategory (" MIR2Vec Options" );
33+ namespace llvm {
34+ namespace mir2vec {
35+ cl::OptionCategory MIR2VecCategory (" MIR2Vec Options" );
3336
3437// FIXME: Use a default vocab when not specified
3538static cl::opt<std::string>
3639 VocabFile (" mir2vec-vocab-path" , cl::Optional,
3740 cl::desc (" Path to the vocabulary file for MIR2Vec" ), cl::init(" " ),
3841 cl::cat(MIR2VecCategory));
39- cl::opt<float >
40- llvm::mir2vec::OpcWeight (" mir2vec-opc-weight" , cl::Optional, cl::init(1.0 ),
41- cl::desc(" Weight for machine opcode embeddings" ),
42- cl::cat(MIR2VecCategory));
42+ cl::opt<float > OpcWeight (" mir2vec-opc-weight" , cl::Optional, cl::init(1.0 ),
43+ cl::desc(" Weight for machine opcode embeddings" ),
44+ cl::cat(MIR2VecCategory));
45+ cl::opt<MIR2VecKind> MIR2VecEmbeddingKind (
46+ " mir2vec-kind" , cl::Optional,
47+ cl::values (clEnumValN(MIR2VecKind::Symbolic, " symbolic" ,
48+ " Generate symbolic embeddings for MIR" )),
49+ cl::init(MIR2VecKind::Symbolic), cl::desc(" MIR2Vec embedding kind" ),
50+ cl::cat(MIR2VecCategory));
51+
52+ } // namespace mir2vec
53+ } // namespace llvm
4354
4455// ===----------------------------------------------------------------------===//
45- // Vocabulary Implementation
56+ // Vocabulary
4657// ===----------------------------------------------------------------------===//
4758
4859MIRVocabulary::MIRVocabulary (VocabMap &&OpcodeEntries,
@@ -188,6 +199,30 @@ void MIRVocabulary::buildCanonicalOpcodeMapping() {
188199 << " unique base opcodes\n " );
189200}
190201
202+ Expected<MIRVocabulary>
203+ MIRVocabulary::createDummyVocabForTest (const TargetInstrInfo &TII,
204+ unsigned Dim) {
205+ assert (Dim > 0 && " Dimension must be greater than zero" );
206+
207+ float DummyVal = 0 .1f ;
208+
209+ // Create a temporary vocabulary instance to build canonical mapping
210+ MIRVocabulary TempVocab ({}, TII);
211+ TempVocab.buildCanonicalOpcodeMapping ();
212+
213+ // Create dummy embeddings for all canonical opcode names
214+ VocabMap DummyVocabMap;
215+ for (const auto &COpcodeName : TempVocab.UniqueBaseOpcodeNames ) {
216+ // Create dummy embedding filled with DummyVal
217+ Embedding DummyEmbedding (Dim, DummyVal);
218+ DummyVocabMap[COpcodeName] = DummyEmbedding;
219+ DummyVal += 0 .1f ;
220+ }
221+
222+ // Create and return vocabulary with dummy embeddings
223+ return MIRVocabulary::create (std::move (DummyVocabMap), TII);
224+ }
225+
191226// ===----------------------------------------------------------------------===//
192227// MIR2VecVocabLegacyAnalysis Implementation
193228// ===----------------------------------------------------------------------===//
@@ -258,7 +293,73 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
258293}
259294
260295// ===----------------------------------------------------------------------===//
261- // Printer Passes Implementation
296+ // MIREmbedder and its subclasses
297+ // ===----------------------------------------------------------------------===//
298+
299+ std::unique_ptr<MIREmbedder> MIREmbedder::create (MIR2VecKind Mode,
300+ const MachineFunction &MF,
301+ const MIRVocabulary &Vocab) {
302+ switch (Mode) {
303+ case MIR2VecKind::Symbolic:
304+ return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
305+ }
306+ return nullptr ;
307+ }
308+
309+ Embedding MIREmbedder::computeEmbeddings (const MachineBasicBlock &MBB) const {
310+ Embedding MBBVector (Dimension, 0 );
311+
312+ // Get instruction info for opcode name resolution
313+ const auto &Subtarget = MF.getSubtarget ();
314+ const auto *TII = Subtarget.getInstrInfo ();
315+ if (!TII) {
316+ MF.getFunction ().getContext ().emitError (
317+ " MIR2Vec: No TargetInstrInfo available; cannot compute embeddings" );
318+ return MBBVector;
319+ }
320+
321+ // Process each machine instruction in the basic block
322+ for (const auto &MI : MBB) {
323+ // Skip debug instructions and other metadata
324+ if (MI.isDebugInstr ())
325+ continue ;
326+ MBBVector += computeEmbeddings (MI);
327+ }
328+
329+ return MBBVector;
330+ }
331+
332+ Embedding MIREmbedder::computeEmbeddings () const {
333+ Embedding MFuncVector (Dimension, 0 );
334+
335+ // Consider all reachable machine basic blocks in the function
336+ for (const auto *MBB : depth_first (&MF))
337+ MFuncVector += computeEmbeddings (*MBB);
338+ return MFuncVector;
339+ }
340+
341+ SymbolicMIREmbedder::SymbolicMIREmbedder (const MachineFunction &MF,
342+ const MIRVocabulary &Vocab)
343+ : MIREmbedder(MF, Vocab) {}
344+
345+ std::unique_ptr<SymbolicMIREmbedder>
346+ SymbolicMIREmbedder::create (const MachineFunction &MF,
347+ const MIRVocabulary &Vocab) {
348+ return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
349+ }
350+
351+ Embedding SymbolicMIREmbedder::computeEmbeddings (const MachineInstr &MI) const {
352+ // Skip debug instructions and other metadata
353+ if (MI.isDebugInstr ())
354+ return Embedding (Dimension, 0 );
355+
356+ // Todo: Add operand/argument contributions
357+
358+ return Vocab[MI.getOpcode ()];
359+ }
360+
361+ // ===----------------------------------------------------------------------===//
362+ // Printer Passes
262363// ===----------------------------------------------------------------------===//
263364
264365char MIR2VecVocabPrinterLegacyPass::ID = 0 ;
@@ -297,3 +398,56 @@ MachineFunctionPass *
297398llvm::createMIR2VecVocabPrinterLegacyPass (raw_ostream &OS) {
298399 return new MIR2VecVocabPrinterLegacyPass (OS);
299400}
401+
402+ char MIR2VecPrinterLegacyPass::ID = 0 ;
403+ INITIALIZE_PASS_BEGIN (MIR2VecPrinterLegacyPass, " print-mir2vec" ,
404+ " MIR2Vec Embedder Printer Pass" , false , true )
405+ INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis)
406+ INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
407+ INITIALIZE_PASS_END(MIR2VecPrinterLegacyPass, " print-mir2vec" ,
408+ " MIR2Vec Embedder Printer Pass" , false , true )
409+
410+ bool MIR2VecPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
411+ auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
412+ auto VocabOrErr =
413+ Analysis.getMIR2VecVocabulary (*MF.getFunction ().getParent ());
414+ assert (VocabOrErr && " Failed to get MIR2Vec vocabulary" );
415+ auto &MIRVocab = *VocabOrErr;
416+
417+ auto Emb = mir2vec::MIREmbedder::create (MIR2VecEmbeddingKind, MF, MIRVocab);
418+ if (!Emb) {
419+ OS << " Error creating MIR2Vec embeddings for function " << MF.getName ()
420+ << " \n " ;
421+ return false ;
422+ }
423+
424+ OS << " MIR2Vec embeddings for machine function " << MF.getName () << " :\n " ;
425+ OS << " Machine Function vector: " ;
426+ Emb->getMFunctionVector ().print (OS);
427+
428+ OS << " Machine basic block vectors:\n " ;
429+ for (const MachineBasicBlock &MBB : MF) {
430+ OS << " Machine basic block: " << MBB.getFullName () << " :\n " ;
431+ Emb->getMBBVector (MBB).print (OS);
432+ }
433+
434+ OS << " Machine instruction vectors:\n " ;
435+ for (const MachineBasicBlock &MBB : MF) {
436+ for (const MachineInstr &MI : MBB) {
437+ // Skip debug instructions as they are not
438+ // embedded
439+ if (MI.isDebugInstr ())
440+ continue ;
441+
442+ OS << " Machine instruction: " ;
443+ MI.print (OS);
444+ Emb->getMInstVector (MI).print (OS);
445+ }
446+ }
447+
448+ return false ;
449+ }
450+
451+ MachineFunctionPass *llvm::createMIR2VecPrinterLegacyPass (raw_ostream &OS) {
452+ return new MIR2VecPrinterLegacyPass (OS);
453+ }
0 commit comments