@@ -52,6 +52,15 @@ cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
5252cl::opt<float > ArgWeight (" ir2vec-arg-weight" , cl::Optional, cl::init(0.2 ),
5353 cl::desc(" Weight for argument embeddings" ),
5454 cl::cat(IR2VecCategory));
55+ cl::opt<IR2VecKind> IR2VecEmbeddingKind (
56+ " ir2vec-kind" , cl::Optional,
57+ cl::values (clEnumValN(IR2VecKind::Symbolic, " symbolic" ,
58+ " Generate symbolic embeddings" ),
59+ clEnumValN(IR2VecKind::FlowAware, " flow-aware" ,
60+ " Generate flow-aware embeddings" )),
61+ cl::init(IR2VecKind::Symbolic), cl::desc(" IR2Vec embedding kind" ),
62+ cl::cat(IR2VecCategory));
63+
5564} // namespace ir2vec
5665} // namespace llvm
5766
@@ -123,8 +132,12 @@ bool Embedding::approximatelyEquals(const Embedding &RHS,
123132 double Tolerance) const {
124133 assert (this ->size () == RHS.size () && " Vectors must have the same dimension" );
125134 for (size_t Itr = 0 ; Itr < this ->size (); ++Itr)
126- if (std::abs ((*this )[Itr] - RHS[Itr]) > Tolerance)
135+ if (std::abs ((*this )[Itr] - RHS[Itr]) > Tolerance) {
136+ LLVM_DEBUG (errs () << " Embedding mismatch at index " << Itr << " : "
137+ << (*this )[Itr] << " vs " << RHS[Itr]
138+ << " ; Tolerance: " << Tolerance << " \n " );
127139 return false ;
140+ }
128141 return true ;
129142}
130143
@@ -141,14 +154,16 @@ void Embedding::print(raw_ostream &OS) const {
141154
142155Embedder::Embedder (const Function &F, const Vocabulary &Vocab)
143156 : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
144- OpcWeight (::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {
145- }
157+ OpcWeight (::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight),
158+ FuncVector(Embedding(Dimension, 0 )) { }
146159
147160std::unique_ptr<Embedder> Embedder::create (IR2VecKind Mode, const Function &F,
148161 const Vocabulary &Vocab) {
149162 switch (Mode) {
150163 case IR2VecKind::Symbolic:
151164 return std::make_unique<SymbolicEmbedder>(F, Vocab);
165+ case IR2VecKind::FlowAware:
166+ return std::make_unique<FlowAwareEmbedder>(F, Vocab);
152167 }
153168 return nullptr ;
154169}
@@ -180,6 +195,17 @@ const Embedding &Embedder::getFunctionVector() const {
180195 return FuncVector;
181196}
182197
198+ void Embedder::computeEmbeddings () const {
199+ if (F.isDeclaration ())
200+ return ;
201+
202+ // Consider only the basic blocks that are reachable from entry
203+ for (const BasicBlock *BB : depth_first (&F)) {
204+ computeEmbeddings (*BB);
205+ FuncVector += BBVecMap[BB];
206+ }
207+ }
208+
183209void SymbolicEmbedder::computeEmbeddings (const BasicBlock &BB) const {
184210 Embedding BBVector (Dimension, 0 );
185211
@@ -196,15 +222,38 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
196222 BBVecMap[&BB] = BBVector;
197223}
198224
199- void SymbolicEmbedder::computeEmbeddings () const {
200- if (F.isDeclaration ())
201- return ;
225+ void FlowAwareEmbedder::computeEmbeddings (const BasicBlock &BB) const {
226+ Embedding BBVector (Dimension, 0 );
202227
203- // Consider only the basic blocks that are reachable from entry
204- for (const BasicBlock *BB : depth_first (&F)) {
205- computeEmbeddings (*BB);
206- FuncVector += BBVecMap[BB];
228+ // We consider only the non-debug and non-pseudo instructions
229+ for (const auto &I : BB.instructionsWithoutDebug ()) {
230+ // TODO: Handle call instructions differently.
231+ // For now, we treat them like other instructions
232+ Embedding ArgEmb (Dimension, 0 );
233+ for (const auto &Op : I.operands ()) {
234+ // If the operand is defined elsewhere, we use its embedding
235+ if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
236+ auto DefIt = InstVecMap.find (DefInst);
237+ assert (DefIt != InstVecMap.end () &&
238+ " Instruction should have been processed before its operands" );
239+ ArgEmb += DefIt->second ;
240+ continue ;
241+ }
242+ // If the operand is not defined by an instruction, we use the vocabulary
243+ else {
244+ LLVM_DEBUG (errs () << " Using embedding from vocabulary for operand: "
245+ << *Op << " =" << Vocab[Op][0 ] << " \n " );
246+ ArgEmb += Vocab[Op];
247+ }
248+ }
249+ // Create the instruction vector by combining opcode, type, and arguments
250+ // embeddings
251+ auto InstVector =
252+ Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
253+ InstVecMap[&I] = InstVector;
254+ BBVector += InstVector;
207255 }
256+ BBVecMap[&BB] = BBVector;
208257}
209258
210259// ==----------------------------------------------------------------------===//
@@ -552,8 +601,17 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
552601 assert (Vocabulary.isValid () && " IR2Vec Vocabulary is invalid" );
553602
554603 for (Function &F : M) {
555- std::unique_ptr<Embedder> Emb =
556- Embedder::create (IR2VecKind::Symbolic, F, Vocabulary);
604+ std::unique_ptr<Embedder> Emb;
605+ switch (IR2VecEmbeddingKind) {
606+ case IR2VecKind::Symbolic:
607+ Emb = std::make_unique<SymbolicEmbedder>(F, Vocabulary);
608+ break ;
609+ case IR2VecKind::FlowAware:
610+ Emb = std::make_unique<FlowAwareEmbedder>(F, Vocabulary);
611+ break ;
612+ default :
613+ llvm_unreachable (" Unknown IR2Vec embedding kind" );
614+ }
557615 if (!Emb) {
558616 OS << " Error creating IR2Vec embeddings \n " ;
559617 continue ;
0 commit comments