|
12 | 12 | //===----------------------------------------------------------------------===// |
13 | 13 |
|
14 | 14 | #include "llvm/CodeGen/MIR2Vec.h" |
| 15 | +#include "llvm/ADT/DepthFirstIterator.h" |
15 | 16 | #include "llvm/ADT/Statistic.h" |
16 | 17 | #include "llvm/CodeGen/TargetInstrInfo.h" |
17 | 18 | #include "llvm/IR/Module.h" |
@@ -41,11 +42,18 @@ static cl::opt<std::string> |
/// Command-line option "mir2vec-opc-weight": weight for machine opcode
/// embeddings. The option is optional and is initialized to 1.0.
cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
                         cl::desc("Weight for machine opcode embeddings"),
                         cl::cat(MIR2VecCategory));
/// Command-line option "mir2vec-kind": selects the MIR2Vec embedding kind.
/// "symbolic" is the only value registered here and is also the initial value.
cl::opt<MIR2VecKind> MIR2VecEmbeddingKind(
    "mir2vec-kind", cl::Optional,
    cl::values(clEnumValN(MIR2VecKind::Symbolic, "symbolic",
                          "Generate symbolic embeddings for MIR")),
    cl::init(MIR2VecKind::Symbolic), cl::desc("MIR2Vec embedding kind"),
    cl::cat(MIR2VecCategory));
| 51 | + |
44 | 52 | } // namespace mir2vec |
45 | 53 | } // namespace llvm |
46 | 54 |
|
47 | 55 | //===----------------------------------------------------------------------===// |
48 | | -// Vocabulary Implementation |
| 56 | +// Vocabulary |
49 | 57 | //===----------------------------------------------------------------------===// |
50 | 58 |
|
51 | 59 | MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries, |
@@ -191,6 +199,30 @@ void MIRVocabulary::buildCanonicalOpcodeMapping() { |
191 | 199 | << " unique base opcodes\n"); |
192 | 200 | } |
193 | 201 |
|
| 202 | +Expected<MIRVocabulary> |
| 203 | +MIRVocabulary::createDummyVocabForTest(const TargetInstrInfo &TII, |
| 204 | + unsigned Dim) { |
| 205 | + assert(Dim > 0 && "Dimension must be greater than zero"); |
| 206 | + |
| 207 | + float DummyVal = 0.1f; |
| 208 | + |
| 209 | + // Create a temporary vocabulary instance to build canonical mapping |
| 210 | + MIRVocabulary TempVocab({}, TII); |
| 211 | + TempVocab.buildCanonicalOpcodeMapping(); |
| 212 | + |
| 213 | + // Create dummy embeddings for all canonical opcode names |
| 214 | + VocabMap DummyVocabMap; |
| 215 | + for (const auto &COpcodeName : TempVocab.UniqueBaseOpcodeNames) { |
| 216 | + // Create dummy embedding filled with DummyVal |
| 217 | + Embedding DummyEmbedding(Dim, DummyVal); |
| 218 | + DummyVocabMap[COpcodeName] = DummyEmbedding; |
| 219 | + DummyVal += 0.1f; |
| 220 | + } |
| 221 | + |
| 222 | + // Create and return vocabulary with dummy embeddings |
| 223 | + return MIRVocabulary::create(std::move(DummyVocabMap), TII); |
| 224 | +} |
| 225 | + |
194 | 226 | //===----------------------------------------------------------------------===// |
195 | 227 | // MIR2VecVocabLegacyAnalysis Implementation |
196 | 228 | //===----------------------------------------------------------------------===// |
@@ -261,7 +293,73 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) { |
261 | 293 | } |
262 | 294 |
|
263 | 295 | //===----------------------------------------------------------------------===// |
264 | | -// Printer Passes Implementation |
| 296 | +// MIREmbedder and its subclasses |
| 297 | +//===----------------------------------------------------------------------===// |
| 298 | + |
| 299 | +std::unique_ptr<MIREmbedder> MIREmbedder::create(MIR2VecKind Mode, |
| 300 | + const MachineFunction &MF, |
| 301 | + const MIRVocabulary &Vocab) { |
| 302 | + switch (Mode) { |
| 303 | + case MIR2VecKind::Symbolic: |
| 304 | + return std::make_unique<SymbolicMIREmbedder>(MF, Vocab); |
| 305 | + } |
| 306 | + return nullptr; |
| 307 | +} |
| 308 | + |
| 309 | +Embedding MIREmbedder::computeEmbeddings(const MachineBasicBlock &MBB) const { |
| 310 | + Embedding MBBVector(Dimension, 0); |
| 311 | + |
| 312 | + // Get instruction info for opcode name resolution |
| 313 | + const auto &Subtarget = MF.getSubtarget(); |
| 314 | + const auto *TII = Subtarget.getInstrInfo(); |
| 315 | + if (!TII) { |
| 316 | + MF.getFunction().getContext().emitError( |
| 317 | + "MIR2Vec: No TargetInstrInfo available; cannot compute embeddings"); |
| 318 | + return MBBVector; |
| 319 | + } |
| 320 | + |
| 321 | + // Process each machine instruction in the basic block |
| 322 | + for (const auto &MI : MBB) { |
| 323 | + // Skip debug instructions and other metadata |
| 324 | + if (MI.isDebugInstr()) |
| 325 | + continue; |
| 326 | + MBBVector += computeEmbeddings(MI); |
| 327 | + } |
| 328 | + |
| 329 | + return MBBVector; |
| 330 | +} |
| 331 | + |
| 332 | +Embedding MIREmbedder::computeEmbeddings() const { |
| 333 | + Embedding MFuncVector(Dimension, 0); |
| 334 | + |
| 335 | + // Consider all reachable machine basic blocks in the function |
| 336 | + for (const auto *MBB : depth_first(&MF)) |
| 337 | + MFuncVector += computeEmbeddings(*MBB); |
| 338 | + return MFuncVector; |
| 339 | +} |
| 340 | + |
/// Construct a symbolic embedder for \p MF using \p Vocab. Both arguments are
/// forwarded unchanged to the MIREmbedder base class; no additional state is
/// initialized here.
SymbolicMIREmbedder::SymbolicMIREmbedder(const MachineFunction &MF,
                                         const MIRVocabulary &Vocab)
    : MIREmbedder(MF, Vocab) {}
| 344 | + |
| 345 | +std::unique_ptr<SymbolicMIREmbedder> |
| 346 | +SymbolicMIREmbedder::create(const MachineFunction &MF, |
| 347 | + const MIRVocabulary &Vocab) { |
| 348 | + return std::make_unique<SymbolicMIREmbedder>(MF, Vocab); |
| 349 | +} |
| 350 | + |
| 351 | +Embedding SymbolicMIREmbedder::computeEmbeddings(const MachineInstr &MI) const { |
| 352 | + // Skip debug instructions and other metadata |
| 353 | + if (MI.isDebugInstr()) |
| 354 | + return Embedding(Dimension, 0); |
| 355 | + |
| 356 | + // Todo: Add operand/argument contributions |
| 357 | + |
| 358 | + return Vocab[MI.getOpcode()]; |
| 359 | +} |
| 360 | + |
| 361 | +//===----------------------------------------------------------------------===// |
| 362 | +// Printer Passes |
265 | 363 | //===----------------------------------------------------------------------===// |
266 | 364 |
|
267 | 365 | char MIR2VecVocabPrinterLegacyPass::ID = 0; |
@@ -300,3 +398,56 @@ MachineFunctionPass * |
300 | 398 | llvm::createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS) { |
301 | 399 | return new MIR2VecVocabPrinterLegacyPass(OS); |
302 | 400 | } |
| 401 | + |
// Pass identification object for the legacy pass manager.
char MIR2VecPrinterLegacyPass::ID = 0;
// Register the printer pass under the command-line name "print-mir2vec".
// It declares dependencies on MIR2VecVocabLegacyAnalysis and
// MachineModuleInfoWrapperPass so that those passes are initialized first.
// NOTE(review): the trailing `false, true` are presumably the macro's
// CFG-only and is-analysis flags -- confirm against the INITIALIZE_PASS
// definition.
INITIALIZE_PASS_BEGIN(MIR2VecPrinterLegacyPass, "print-mir2vec",
                      "MIR2Vec Embedder Printer Pass", false, true)
INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis)
INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
INITIALIZE_PASS_END(MIR2VecPrinterLegacyPass, "print-mir2vec",
                    "MIR2Vec Embedder Printer Pass", false, true)
| 409 | + |
| 410 | +bool MIR2VecPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) { |
| 411 | + auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>(); |
| 412 | + auto VocabOrErr = |
| 413 | + Analysis.getMIR2VecVocabulary(*MF.getFunction().getParent()); |
| 414 | + assert(VocabOrErr && "Failed to get MIR2Vec vocabulary"); |
| 415 | + auto &MIRVocab = *VocabOrErr; |
| 416 | + |
| 417 | + auto Emb = mir2vec::MIREmbedder::create(MIR2VecEmbeddingKind, MF, MIRVocab); |
| 418 | + if (!Emb) { |
| 419 | + OS << "Error creating MIR2Vec embeddings for function " << MF.getName() |
| 420 | + << "\n"; |
| 421 | + return false; |
| 422 | + } |
| 423 | + |
| 424 | + OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n"; |
| 425 | + OS << "Machine Function vector: "; |
| 426 | + Emb->getMFunctionVector().print(OS); |
| 427 | + |
| 428 | + OS << "Machine basic block vectors:\n"; |
| 429 | + for (const MachineBasicBlock &MBB : MF) { |
| 430 | + OS << "Machine basic block: " << MBB.getFullName() << ":\n"; |
| 431 | + Emb->getMBBVector(MBB).print(OS); |
| 432 | + } |
| 433 | + |
| 434 | + OS << "Machine instruction vectors:\n"; |
| 435 | + for (const MachineBasicBlock &MBB : MF) { |
| 436 | + for (const MachineInstr &MI : MBB) { |
| 437 | + // Skip debug instructions as they are not |
| 438 | + // embedded |
| 439 | + if (MI.isDebugInstr()) |
| 440 | + continue; |
| 441 | + |
| 442 | + OS << "Machine instruction: "; |
| 443 | + MI.print(OS); |
| 444 | + Emb->getMInstVector(MI).print(OS); |
| 445 | + } |
| 446 | + } |
| 447 | + |
| 448 | + return false; |
| 449 | +} |
| 450 | + |
| 451 | +MachineFunctionPass *llvm::createMIR2VecPrinterLegacyPass(raw_ostream &OS) { |
| 452 | + return new MIR2VecPrinterLegacyPass(OS); |
| 453 | +} |
0 commit comments