@@ -261,25 +261,30 @@ TEST(IR2VecTest, IR2VecVocabResultValidity) {
261
261
EXPECT_EQ (validResult.getDimension (), 2u );
262
262
}
263
263
264
- // Helper to create a minimal function and embedder for getter tests
265
- struct GetterTestEnv {
266
- Vocab V = {};
264
+ // Fixture for IR2Vec tests requiring IR setup and weight management.
265
+ class IR2VecTestFixture : public ::testing::Test {
266
+ protected:
267
+ Vocab V;
267
268
LLVMContext Ctx;
268
- std::unique_ptr<Module> M = nullptr ;
269
+ std::unique_ptr<Module> M;
269
270
Function *F = nullptr ;
270
271
BasicBlock *BB = nullptr ;
271
- Instruction *Add = nullptr ;
272
- Instruction *Ret = nullptr ;
273
- std::unique_ptr<Embedder> Emb = nullptr ;
272
+ Instruction *AddInst = nullptr ;
273
+ Instruction *RetInst = nullptr ;
274
274
275
- GetterTestEnv () {
275
+ float OriginalOpcWeight = ::OpcWeight;
276
+ float OriginalTypeWeight = ::TypeWeight;
277
+ float OriginalArgWeight = ::ArgWeight;
278
+
279
+ void SetUp () override {
276
280
V = {{" add" , {1.0 , 2.0 }},
277
281
{" integerTy" , {0.5 , 0.5 }},
278
282
{" constant" , {0.2 , 0.3 }},
279
283
{" variable" , {0.0 , 0.0 }},
280
284
{" unknownTy" , {0.0 , 0.0 }}};
281
285
282
- M = std::make_unique<Module>(" M" , Ctx);
286
+ // Setup IR
287
+ M = std::make_unique<Module>(" TestM" , Ctx);
283
288
FunctionType *FTy = FunctionType::get (
284
289
Type::getInt32Ty (Ctx), {Type::getInt32Ty (Ctx), Type::getInt32Ty (Ctx)},
285
290
false );
@@ -288,61 +293,82 @@ struct GetterTestEnv {
288
293
Argument *Arg = F->getArg (0 );
289
294
llvm::Value *Const = ConstantInt::get (Type::getInt32Ty (Ctx), 42 );
290
295
291
- Add = BinaryOperator::CreateAdd (Arg, Const, " add" , BB);
292
- Ret = ReturnInst::Create (Ctx, Add, BB);
296
+ AddInst = BinaryOperator::CreateAdd (Arg, Const, " add" , BB);
297
+ RetInst = ReturnInst::Create (Ctx, AddInst, BB);
298
+ }
299
+
300
+ void setWeights (float OpcWeight, float TypeWeight, float ArgWeight) {
301
+ ::OpcWeight = OpcWeight;
302
+ ::TypeWeight = TypeWeight;
303
+ ::ArgWeight = ArgWeight;
304
+ }
293
305
294
- auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
295
- EXPECT_TRUE (static_cast <bool >(Result));
296
- Emb = std::move (*Result);
306
+ void TearDown () override {
307
+ // Restore original global weights
308
+ ::OpcWeight = OriginalOpcWeight;
309
+ ::TypeWeight = OriginalTypeWeight;
310
+ ::ArgWeight = OriginalArgWeight;
297
311
}
298
312
};
299
313
300
- TEST (IR2VecTest, GetInstVecMap) {
301
- GetterTestEnv Env;
302
- const auto &InstMap = Env.Emb ->getInstVecMap ();
314
+ TEST_F (IR2VecTestFixture, GetInstVecMap) {
315
+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
316
+ ASSERT_TRUE (static_cast <bool >(Result));
317
+ auto Emb = std::move (*Result);
318
+
319
+ const auto &InstMap = Emb->getInstVecMap ();
303
320
304
321
EXPECT_EQ (InstMap.size (), 2u );
305
- EXPECT_TRUE (InstMap.count (Env. Add ));
306
- EXPECT_TRUE (InstMap.count (Env. Ret ));
322
+ EXPECT_TRUE (InstMap.count (AddInst ));
323
+ EXPECT_TRUE (InstMap.count (RetInst ));
307
324
308
- EXPECT_EQ (InstMap.at (Env. Add ).size (), 2u );
309
- EXPECT_EQ (InstMap.at (Env. Ret ).size (), 2u );
325
+ EXPECT_EQ (InstMap.at (AddInst ).size (), 2u );
326
+ EXPECT_EQ (InstMap.at (RetInst ).size (), 2u );
310
327
311
328
// Check values for add: {1.29, 2.31}
312
- EXPECT_THAT (InstMap.at (Env. Add ),
329
+ EXPECT_THAT (InstMap.at (AddInst ),
313
330
ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
314
331
315
332
// Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in
316
333
// vocab
317
- EXPECT_THAT (InstMap.at (Env. Ret ), ElementsAre (0.0 , 0.0 ));
334
+ EXPECT_THAT (InstMap.at (RetInst ), ElementsAre (0.0 , 0.0 ));
318
335
}
319
336
320
- TEST (IR2VecTest, GetBBVecMap) {
321
- GetterTestEnv Env;
322
- const auto &BBMap = Env.Emb ->getBBVecMap ();
337
+ TEST_F (IR2VecTestFixture, GetBBVecMap) {
338
+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
339
+ ASSERT_TRUE (static_cast <bool >(Result));
340
+ auto Emb = std::move (*Result);
341
+
342
+ const auto &BBMap = Emb->getBBVecMap ();
323
343
324
344
EXPECT_EQ (BBMap.size (), 1u );
325
- EXPECT_TRUE (BBMap.count (Env. BB ));
326
- EXPECT_EQ (BBMap.at (Env. BB ).size (), 2u );
345
+ EXPECT_TRUE (BBMap.count (BB));
346
+ EXPECT_EQ (BBMap.at (BB).size (), 2u );
327
347
328
348
// BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
329
349
// {1.29, 2.31}
330
- EXPECT_THAT (BBMap.at (Env. BB ),
350
+ EXPECT_THAT (BBMap.at (BB),
331
351
ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
332
352
}
333
353
334
- TEST (IR2VecTest, GetBBVector) {
335
- GetterTestEnv Env;
336
- const auto &BBVec = Env.Emb ->getBBVector (*Env.BB );
354
+ TEST_F (IR2VecTestFixture, GetBBVector) {
355
+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
356
+ ASSERT_TRUE (static_cast <bool >(Result));
357
+ auto Emb = std::move (*Result);
358
+
359
+ const auto &BBVec = Emb->getBBVector (*BB);
337
360
338
361
EXPECT_EQ (BBVec.size (), 2u );
339
362
EXPECT_THAT (BBVec,
340
363
ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
341
364
}
342
365
343
- TEST (IR2VecTest, GetFunctionVector) {
344
- GetterTestEnv Env;
345
- const auto &FuncVec = Env.Emb ->getFunctionVector ();
366
+ TEST_F (IR2VecTestFixture, GetFunctionVector) {
367
+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
368
+ ASSERT_TRUE (static_cast <bool >(Result));
369
+ auto Emb = std::move (*Result);
370
+
371
+ const auto &FuncVec = Emb->getFunctionVector ();
346
372
347
373
EXPECT_EQ (FuncVec.size (), 2u );
348
374
@@ -351,4 +377,45 @@ TEST(IR2VecTest, GetFunctionVector) {
351
377
ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
352
378
}
353
379
380
+ TEST_F (IR2VecTestFixture, GetFunctionVectorWithCustomWeights) {
381
+ setWeights (1.0 , 1.0 , 1.0 );
382
+
383
+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
384
+ ASSERT_TRUE (static_cast <bool >(Result));
385
+ auto Emb = std::move (*Result);
386
+
387
+ const auto &FuncVec = Emb->getFunctionVector ();
388
+
389
+ EXPECT_EQ (FuncVec.size (), 2u );
390
+
391
+ // Expected: 1*([1.0 2.0] + [0.0 0.0]) + 1*([0.5 0.5] + [0.0 0.0]) + 1*([0.2
392
+ // 0.3] + [0.0 0.0])
393
+ EXPECT_THAT (FuncVec,
394
+ ElementsAre (DoubleNear (1.7 , 1e-6 ), DoubleNear (2.8 , 1e-6 )));
395
+ }
396
+
397
+ TEST (IR2VecTest, IR2VecVocabAnalysisWithPrepopulatedVocab) {
398
+ Vocab InitialVocab = {{" key1" , {1.1 , 2.2 }}, {" key2" , {3.3 , 4.4 }}};
399
+ Vocab ExpectedVocab = InitialVocab;
400
+ unsigned ExpectedDim = InitialVocab.begin ()->second .size ();
401
+
402
+ IR2VecVocabAnalysis VocabAnalysis (std::move (InitialVocab));
403
+
404
+ LLVMContext TestCtx;
405
+ Module TestMod (" TestModuleForVocabAnalysis" , TestCtx);
406
+ ModuleAnalysisManager MAM;
407
+ IR2VecVocabResult Result = VocabAnalysis.run (TestMod, MAM);
408
+
409
+ EXPECT_TRUE (Result.isValid ());
410
+ ASSERT_FALSE (Result.getVocabulary ().empty ());
411
+ EXPECT_EQ (Result.getDimension (), ExpectedDim);
412
+
413
+ const auto &ResultVocab = Result.getVocabulary ();
414
+ EXPECT_EQ (ResultVocab.size (), ExpectedVocab.size ());
415
+ for (const auto &pair : ExpectedVocab) {
416
+ EXPECT_TRUE (ResultVocab.count (pair.first ));
417
+ EXPECT_THAT (ResultVocab.at (pair.first ), ElementsAreArray (pair.second ));
418
+ }
419
+ }
420
+
354
421
} // end anonymous namespace
0 commit comments