@@ -216,6 +216,8 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
216216 ArgEmb += Vocab[*Op];
217217 auto InstVector =
218218 Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
219+ if (const auto *IC = dyn_cast<CmpInst>(&I))
220+ InstVector += Vocab[IC->getPredicate ()];
219221 InstVecMap[&I] = InstVector;
220222 BBVector += InstVector;
221223 }
@@ -250,6 +252,9 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
250252 // embeddings
251253 auto InstVector =
252254 Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
255+ // Add compare predicate embedding as an additional operand if applicable
256+ if (const auto *IC = dyn_cast<CmpInst>(&I))
257+ InstVector += Vocab[IC->getPredicate ()];
253258 InstVecMap[&I] = InstVector;
254259 BBVector += InstVector;
255260 }
@@ -285,7 +290,17 @@ unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
285290unsigned Vocabulary::getSlotIndex (const Value &Op) {
286291 unsigned Index = static_cast <unsigned >(getOperandKind (&Op));
287292 assert (Index < MaxOperandKinds && " Invalid OperandKind" );
288- return MaxOpcodes + MaxCanonicalTypeIDs + Index;
293+ return OperandBaseOffset + Index;
294+ }
295+
296+ unsigned Vocabulary::getSlotIndex (CmpInst::Predicate P) {
297+ unsigned PU = static_cast <unsigned >(P);
298+ unsigned FirstFC = static_cast <unsigned >(CmpInst::FIRST_FCMP_PREDICATE);
299+ unsigned FirstIC = static_cast <unsigned >(CmpInst::FIRST_ICMP_PREDICATE);
300+
301+ unsigned PredIdx =
302+ (PU >= FirstIC) ? (NumFCmpPredicates + (PU - FirstIC)) : (PU - FirstFC);
303+ return PredicateBaseOffset + PredIdx;
289304}
290305
291306const Embedding &Vocabulary::operator [](unsigned Opcode) const {
@@ -300,6 +315,10 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
300315 return Vocab[getSlotIndex (Arg)];
301316}
302317
318+ const ir2vec::Embedding &Vocabulary::operator [](CmpInst::Predicate P) const {
319+ return Vocab[getSlotIndex (P)];
320+ }
321+
303322StringRef Vocabulary::getVocabKeyForOpcode (unsigned Opcode) {
304323 assert (Opcode >= 1 && Opcode <= MaxOpcodes && " Invalid opcode" );
305324#define HANDLE_INST (NUM, OPCODE, CLASS ) \
@@ -345,18 +364,35 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
345364 return OperandKind::VariableID;
346365}
347366
367+ CmpInst::Predicate Vocabulary::getPredicate (unsigned Index) {
368+ assert (Index < MaxPredicateKinds && " Invalid predicate index" );
369+ unsigned PredEnumVal =
370+ (Index < NumFCmpPredicates)
371+ ? (static_cast <unsigned >(CmpInst::FIRST_FCMP_PREDICATE) + Index)
372+ : (static_cast <unsigned >(CmpInst::FIRST_ICMP_PREDICATE) +
373+ (Index - NumFCmpPredicates));
374+ return static_cast <CmpInst::Predicate>(PredEnumVal);
375+ }
376+
377+ StringRef Vocabulary::getVocabKeyForPredicate (CmpInst::Predicate Pred) {
378+ return CmpInst::getPredicateName (Pred);
379+ }
380+
348381StringRef Vocabulary::getStringKey (unsigned Pos) {
349382 assert (Pos < NumCanonicalEntries && " Position out of bounds in vocabulary" );
350383 // Opcode
351384 if (Pos < MaxOpcodes)
352385 return getVocabKeyForOpcode (Pos + 1 );
353386 // Type
354- if (Pos < MaxOpcodes + MaxCanonicalTypeIDs )
387+ if (Pos < OperandBaseOffset )
355388 return getVocabKeyForCanonicalTypeID (
356389 static_cast <CanonicalTypeID>(Pos - MaxOpcodes));
357390 // Operand
358- return getVocabKeyForOperandKind (
359- static_cast <OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs));
391+ if (Pos < PredicateBaseOffset)
392+ return getVocabKeyForOperandKind (
393+ static_cast <OperandKind>(Pos - OperandBaseOffset));
394+ // Predicates
395+ return getVocabKeyForPredicate (getPredicate (Pos - PredicateBaseOffset));
360396}
361397
362398// For now, assume vocabulary is stable unless explicitly invalidated.
@@ -370,11 +406,9 @@ Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
370406 VocabVector DummyVocab;
371407 DummyVocab.reserve (NumCanonicalEntries);
372408 float DummyVal = 0 .1f ;
373- // Create a dummy vocabulary with entries for all opcodes, types, and
374- // operands
375- for ([[maybe_unused]] unsigned _ :
376- seq (0u , Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs +
377- Vocabulary::MaxOperandKinds)) {
409+ // Create a dummy vocabulary with entries for all opcodes, types, operands
410+ // and predicates
411+ for ([[maybe_unused]] unsigned _ : seq (0u , Vocabulary::NumCanonicalEntries)) {
378412 DummyVocab.push_back (Embedding (Dim, DummyVal));
379413 DummyVal += 0 .1f ;
380414 }
@@ -517,6 +551,24 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
517551 }
518552 Vocab.insert (Vocab.end (), NumericArgEmbeddings.begin (),
519553 NumericArgEmbeddings.end ());
554+
555+ // Handle Predicates: part of Operands section. We look up predicate keys
556+ // in ArgVocab.
557+ std::vector<Embedding> NumericPredEmbeddings (Vocabulary::MaxPredicateKinds,
558+ Embedding (Dim, 0 ));
559+ NumericPredEmbeddings.reserve (Vocabulary::MaxPredicateKinds);
560+ for (unsigned PK : seq (0u , Vocabulary::MaxPredicateKinds)) {
561+ StringRef VocabKey =
562+ Vocabulary::getVocabKeyForPredicate (Vocabulary::getPredicate (PK));
563+ auto It = ArgVocab.find (VocabKey.str ());
564+ if (It != ArgVocab.end ()) {
565+ NumericPredEmbeddings[PK] = It->second ;
566+ continue ;
567+ }
568+ handleMissingEntity (VocabKey.str ());
569+ }
570+ Vocab.insert (Vocab.end (), NumericPredEmbeddings.begin (),
571+ NumericPredEmbeddings.end ());
520572}
521573
522574IR2VecVocabAnalysis::IR2VecVocabAnalysis (const VocabVector &Vocab)
0 commit comments