Skip to content

Commit 7f2012c

Browse files
committed
Vocab changes1
1 parent 6817aa9 commit 7f2012c

File tree

3 files changed

+163
-66
lines changed

3 files changed

+163
-66
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131

3232
#include "llvm/ADT/DenseMap.h"
3333
#include "llvm/IR/PassManager.h"
34+
#include "llvm/Support/CommandLine.h"
3435
#include "llvm/Support/ErrorOr.h"
36+
#include "llvm/Support/JSON.h"
3537
#include <map>
3638

3739
namespace llvm {
@@ -43,6 +45,7 @@ class Function;
4345
class Type;
4446
class Value;
4547
class raw_ostream;
48+
class LLVMContext;
4649

4750
/// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
4851
/// Symbolic embeddings capture the "syntactic" and "statistical correlation"
@@ -53,6 +56,11 @@ class raw_ostream;
5356
enum class IR2VecKind { Symbolic };
5457

5558
namespace ir2vec {
59+
60+
LLVM_ABI extern cl::opt<float> OpcWeight;
61+
LLVM_ABI extern cl::opt<float> TypeWeight;
62+
LLVM_ABI extern cl::opt<float> ArgWeight;
63+
5664
/// Embedding is a ADT that wraps std::vector<double>. It provides
5765
/// additional functionality for arithmetic and comparison operations.
5866
/// It is meant to be used *like* std::vector<double> but is more restrictive
@@ -224,10 +232,12 @@ class IR2VecVocabResult {
224232
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
225233
ir2vec::Vocab Vocabulary;
226234
Error readVocabulary();
235+
void emitError(Error Err, LLVMContext &Ctx);
227236

228237
public:
229238
static AnalysisKey Key;
230239
IR2VecVocabAnalysis() = default;
240+
explicit IR2VecVocabAnalysis(ir2vec::Vocab &&Vocab);
231241
using Result = IR2VecVocabResult;
232242
Result run(Module &M, ModuleAnalysisManager &MAM);
233243
};

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 51 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,11 @@
1616
#include "llvm/ADT/Statistic.h"
1717
#include "llvm/IR/Module.h"
1818
#include "llvm/IR/PassManager.h"
19-
#include "llvm/Support/CommandLine.h"
2019
#include "llvm/Support/Debug.h"
2120
#include "llvm/Support/Errc.h"
2221
#include "llvm/Support/Error.h"
2322
#include "llvm/Support/ErrorHandling.h"
2423
#include "llvm/Support/Format.h"
25-
#include "llvm/Support/JSON.h"
2624
#include "llvm/Support/MemoryBuffer.h"
2725

2826
using namespace llvm;
@@ -33,25 +31,29 @@ using namespace ir2vec;
3331
STATISTIC(VocabMissCounter,
3432
"Number of lookups to entites not present in the vocabulary");
3533

34+
namespace llvm {
35+
namespace ir2vec {
3636
static cl::OptionCategory IR2VecCategory("IR2Vec Options");
3737

3838
// FIXME: Use a default vocab when not specified
3939
static cl::opt<std::string>
4040
VocabFile("ir2vec-vocab-path", cl::Optional,
4141
cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
4242
cl::cat(IR2VecCategory));
43-
static cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional,
44-
cl::init(1.0),
45-
cl::desc("Weight for opcode embeddings"),
46-
cl::cat(IR2VecCategory));
47-
static cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional,
48-
cl::init(0.5),
49-
cl::desc("Weight for type embeddings"),
50-
cl::cat(IR2VecCategory));
51-
static cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
52-
cl::init(0.2),
53-
cl::desc("Weight for argument embeddings"),
54-
cl::cat(IR2VecCategory));
43+
LLVM_ABI cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional,
44+
cl::init(1.0),
45+
cl::desc("Weight for opcode embeddings"),
46+
cl::cat(IR2VecCategory));
47+
LLVM_ABI cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional,
48+
cl::init(0.5),
49+
cl::desc("Weight for type embeddings"),
50+
cl::cat(IR2VecCategory));
51+
LLVM_ABI cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
52+
cl::init(0.2),
53+
cl::desc("Weight for argument embeddings"),
54+
cl::cat(IR2VecCategory));
55+
} // namespace ir2vec
56+
} // namespace llvm
5557

5658
AnalysisKey IR2VecVocabAnalysis::Key;
5759

@@ -251,49 +253,67 @@ bool IR2VecVocabResult::invalidate(
251253
// by auto-generating a default vocabulary during the build time.
252254
Error IR2VecVocabAnalysis::readVocabulary() {
253255
auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
254-
if (!BufOrError) {
256+
if (!BufOrError)
255257
return createFileError(VocabFile, BufOrError.getError());
256-
}
258+
257259
auto Content = BufOrError.get()->getBuffer();
258260
json::Path::Root Path("");
259261
Expected<json::Value> ParsedVocabValue = json::parse(Content);
260262
if (!ParsedVocabValue)
261263
return ParsedVocabValue.takeError();
262264

263265
bool Res = json::fromJSON(*ParsedVocabValue, Vocabulary, Path);
264-
if (!Res) {
266+
if (!Res)
265267
return createStringError(errc::illegal_byte_sequence,
266268
"Unable to parse the vocabulary");
267-
}
268-
assert(Vocabulary.size() > 0 && "Vocabulary is empty");
269+
270+
if (Vocabulary.empty())
271+
return createStringError(errc::illegal_byte_sequence,
272+
"Vocabulary is empty");
269273

270274
unsigned Dim = Vocabulary.begin()->second.size();
271-
assert(Dim > 0 && "Dimension of vocabulary is zero");
272-
(void)Dim;
273-
assert(std::all_of(Vocabulary.begin(), Vocabulary.end(),
274-
[Dim](const std::pair<StringRef, Embedding> &Entry) {
275-
return Entry.second.size() == Dim;
276-
}) &&
277-
"All vectors in the vocabulary are not of the same dimension");
275+
if (Dim == 0)
276+
return createStringError(errc::illegal_byte_sequence,
277+
"Dimension of vocabulary is zero");
278+
279+
if (!std::all_of(Vocabulary.begin(), Vocabulary.end(),
280+
[Dim](const std::pair<StringRef, Embedding> &Entry) {
281+
return Entry.second.size() == Dim;
282+
}))
283+
return createStringError(
284+
errc::illegal_byte_sequence,
285+
"All vectors in the vocabulary are not of the same dimension");
286+
278287
return Error::success();
279288
}
280289

290+
IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary)
291+
: Vocabulary(std::move(Vocabulary)) {}
292+
293+
void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
294+
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
295+
Ctx.emitError("Error reading vocabulary: " + EI.message());
296+
});
297+
}
298+
281299
IR2VecVocabAnalysis::Result
282300
IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
283301
auto Ctx = &M.getContext();
302+
// FIXME: Scale the vocabulary once. This would avoid scaling per use later.
303+
// If vocabulary is already populated by the constructor, use it.
304+
if (!Vocabulary.empty())
305+
return IR2VecVocabResult(std::move(Vocabulary));
306+
307+
// Otherwise, try to read from the vocabulary file.
284308
if (VocabFile.empty()) {
285309
// FIXME: Use default vocabulary
286310
Ctx->emitError("IR2Vec vocabulary file path not specified");
287311
return IR2VecVocabResult(); // Return invalid result
288312
}
289313
if (auto Err = readVocabulary()) {
290-
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
291-
Ctx->emitError("Error reading vocabulary: " + EI.message());
292-
});
314+
emitError(std::move(Err), *Ctx);
293315
return IR2VecVocabResult();
294316
}
295-
// FIXME: Scale the vocabulary here once. This would avoid scaling per use
296-
// later.
297317
return IR2VecVocabResult(std::move(Vocabulary));
298318
}
299319

llvm/unittests/Analysis/IR2VecTest.cpp

Lines changed: 102 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -261,25 +261,30 @@ TEST(IR2VecTest, IR2VecVocabResultValidity) {
261261
EXPECT_EQ(validResult.getDimension(), 2u);
262262
}
263263

264-
// Helper to create a minimal function and embedder for getter tests
265-
struct GetterTestEnv {
266-
Vocab V = {};
264+
// Fixture for IR2Vec tests requiring IR setup and weight management.
265+
class IR2VecTestFixture : public ::testing::Test {
266+
protected:
267+
Vocab V;
267268
LLVMContext Ctx;
268-
std::unique_ptr<Module> M = nullptr;
269+
std::unique_ptr<Module> M;
269270
Function *F = nullptr;
270271
BasicBlock *BB = nullptr;
271-
Instruction *Add = nullptr;
272-
Instruction *Ret = nullptr;
273-
std::unique_ptr<Embedder> Emb = nullptr;
272+
Instruction *AddInst = nullptr;
273+
Instruction *RetInst = nullptr;
274274

275-
GetterTestEnv() {
275+
float OriginalOpcWeight = ::OpcWeight;
276+
float OriginalTypeWeight = ::TypeWeight;
277+
float OriginalArgWeight = ::ArgWeight;
278+
279+
void SetUp() override {
276280
V = {{"add", {1.0, 2.0}},
277281
{"integerTy", {0.5, 0.5}},
278282
{"constant", {0.2, 0.3}},
279283
{"variable", {0.0, 0.0}},
280284
{"unknownTy", {0.0, 0.0}}};
281285

282-
M = std::make_unique<Module>("M", Ctx);
286+
// Setup IR
287+
M = std::make_unique<Module>("TestM", Ctx);
283288
FunctionType *FTy = FunctionType::get(
284289
Type::getInt32Ty(Ctx), {Type::getInt32Ty(Ctx), Type::getInt32Ty(Ctx)},
285290
false);
@@ -288,61 +293,82 @@ struct GetterTestEnv {
288293
Argument *Arg = F->getArg(0);
289294
llvm::Value *Const = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
290295

291-
Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
292-
Ret = ReturnInst::Create(Ctx, Add, BB);
296+
AddInst = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
297+
RetInst = ReturnInst::Create(Ctx, AddInst, BB);
298+
}
299+
300+
void setWeights(float OpcWeight, float TypeWeight, float ArgWeight) {
301+
::OpcWeight = OpcWeight;
302+
::TypeWeight = TypeWeight;
303+
::ArgWeight = ArgWeight;
304+
}
293305

294-
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
295-
EXPECT_TRUE(static_cast<bool>(Result));
296-
Emb = std::move(*Result);
306+
void TearDown() override {
307+
// Restore original global weights
308+
::OpcWeight = OriginalOpcWeight;
309+
::TypeWeight = OriginalTypeWeight;
310+
::ArgWeight = OriginalArgWeight;
297311
}
298312
};
299313

300-
TEST(IR2VecTest, GetInstVecMap) {
301-
GetterTestEnv Env;
302-
const auto &InstMap = Env.Emb->getInstVecMap();
314+
TEST_F(IR2VecTestFixture, GetInstVecMap) {
315+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
316+
ASSERT_TRUE(static_cast<bool>(Result));
317+
auto Emb = std::move(*Result);
318+
319+
const auto &InstMap = Emb->getInstVecMap();
303320

304321
EXPECT_EQ(InstMap.size(), 2u);
305-
EXPECT_TRUE(InstMap.count(Env.Add));
306-
EXPECT_TRUE(InstMap.count(Env.Ret));
322+
EXPECT_TRUE(InstMap.count(AddInst));
323+
EXPECT_TRUE(InstMap.count(RetInst));
307324

308-
EXPECT_EQ(InstMap.at(Env.Add).size(), 2u);
309-
EXPECT_EQ(InstMap.at(Env.Ret).size(), 2u);
325+
EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
326+
EXPECT_EQ(InstMap.at(RetInst).size(), 2u);
310327

311328
// Check values for add: {1.29, 2.31}
312-
EXPECT_THAT(InstMap.at(Env.Add),
329+
EXPECT_THAT(InstMap.at(AddInst),
313330
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
314331

315332
// Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in
316333
// vocab
317-
EXPECT_THAT(InstMap.at(Env.Ret), ElementsAre(0.0, 0.0));
334+
EXPECT_THAT(InstMap.at(RetInst), ElementsAre(0.0, 0.0));
318335
}
319336

320-
TEST(IR2VecTest, GetBBVecMap) {
321-
GetterTestEnv Env;
322-
const auto &BBMap = Env.Emb->getBBVecMap();
337+
TEST_F(IR2VecTestFixture, GetBBVecMap) {
338+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
339+
ASSERT_TRUE(static_cast<bool>(Result));
340+
auto Emb = std::move(*Result);
341+
342+
const auto &BBMap = Emb->getBBVecMap();
323343

324344
EXPECT_EQ(BBMap.size(), 1u);
325-
EXPECT_TRUE(BBMap.count(Env.BB));
326-
EXPECT_EQ(BBMap.at(Env.BB).size(), 2u);
345+
EXPECT_TRUE(BBMap.count(BB));
346+
EXPECT_EQ(BBMap.at(BB).size(), 2u);
327347

328348
// BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
329349
// {1.29, 2.31}
330-
EXPECT_THAT(BBMap.at(Env.BB),
350+
EXPECT_THAT(BBMap.at(BB),
331351
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
332352
}
333353

334-
TEST(IR2VecTest, GetBBVector) {
335-
GetterTestEnv Env;
336-
const auto &BBVec = Env.Emb->getBBVector(*Env.BB);
354+
TEST_F(IR2VecTestFixture, GetBBVector) {
355+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
356+
ASSERT_TRUE(static_cast<bool>(Result));
357+
auto Emb = std::move(*Result);
358+
359+
const auto &BBVec = Emb->getBBVector(*BB);
337360

338361
EXPECT_EQ(BBVec.size(), 2u);
339362
EXPECT_THAT(BBVec,
340363
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
341364
}
342365

343-
TEST(IR2VecTest, GetFunctionVector) {
344-
GetterTestEnv Env;
345-
const auto &FuncVec = Env.Emb->getFunctionVector();
366+
TEST_F(IR2VecTestFixture, GetFunctionVector) {
367+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
368+
ASSERT_TRUE(static_cast<bool>(Result));
369+
auto Emb = std::move(*Result);
370+
371+
const auto &FuncVec = Emb->getFunctionVector();
346372

347373
EXPECT_EQ(FuncVec.size(), 2u);
348374

@@ -351,4 +377,45 @@ TEST(IR2VecTest, GetFunctionVector) {
351377
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
352378
}
353379

380+
TEST_F(IR2VecTestFixture, GetFunctionVectorWithCustomWeights) {
381+
setWeights(1.0, 1.0, 1.0);
382+
383+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
384+
ASSERT_TRUE(static_cast<bool>(Result));
385+
auto Emb = std::move(*Result);
386+
387+
const auto &FuncVec = Emb->getFunctionVector();
388+
389+
EXPECT_EQ(FuncVec.size(), 2u);
390+
391+
// Expected: 1*([1.0 2.0] + [0.0 0.0]) + 1*([0.5 0.5] + [0.0 0.0]) + 1*([0.2
392+
// 0.3] + [0.0 0.0])
393+
EXPECT_THAT(FuncVec,
394+
ElementsAre(DoubleNear(1.7, 1e-6), DoubleNear(2.8, 1e-6)));
395+
}
396+
397+
TEST(IR2VecTest, IR2VecVocabAnalysisWithPrepopulatedVocab) {
398+
Vocab InitialVocab = {{"key1", {1.1, 2.2}}, {"key2", {3.3, 4.4}}};
399+
Vocab ExpectedVocab = InitialVocab;
400+
unsigned ExpectedDim = InitialVocab.begin()->second.size();
401+
402+
IR2VecVocabAnalysis VocabAnalysis(std::move(InitialVocab));
403+
404+
LLVMContext TestCtx;
405+
Module TestMod("TestModuleForVocabAnalysis", TestCtx);
406+
ModuleAnalysisManager MAM;
407+
IR2VecVocabResult Result = VocabAnalysis.run(TestMod, MAM);
408+
409+
EXPECT_TRUE(Result.isValid());
410+
ASSERT_FALSE(Result.getVocabulary().empty());
411+
EXPECT_EQ(Result.getDimension(), ExpectedDim);
412+
413+
const auto &ResultVocab = Result.getVocabulary();
414+
EXPECT_EQ(ResultVocab.size(), ExpectedVocab.size());
415+
for (const auto &pair : ExpectedVocab) {
416+
EXPECT_TRUE(ResultVocab.count(pair.first));
417+
EXPECT_THAT(ResultVocab.at(pair.first), ElementsAreArray(pair.second));
418+
}
419+
}
420+
354421
} // end anonymous namespace

0 commit comments

Comments
 (0)