Skip to content

Commit 237e4d2

Browse files
committed
Vocab Changes
1 parent 3b05edf commit 237e4d2

19 files changed

+1290
-404
lines changed

llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@ class FunctionPropertiesInfo {
3434
void reIncludeBB(const BasicBlock &BB);
3535

3636
ir2vec::Embedding FunctionEmbedding = ir2vec::Embedding(0.0);
37-
std::optional<ir2vec::Vocab> IR2VecVocab;
37+
const ir2vec::Vocabulary *IR2VecVocab = nullptr;
3838

3939
public:
4040
LLVM_ABI static FunctionPropertiesInfo
4141
getFunctionPropertiesInfo(const Function &F, const DominatorTree &DT,
4242
const LoopInfo &LI,
43-
const IR2VecVocabResult *VocabResult);
43+
const ir2vec::Vocabulary *Vocabulary);
4444

4545
LLVM_ABI static FunctionPropertiesInfo
4646
getFunctionPropertiesInfo(Function &F, FunctionAnalysisManager &FAM);
@@ -145,9 +145,7 @@ class FunctionPropertiesInfo {
145145
return FunctionEmbedding;
146146
}
147147

148-
const std::optional<ir2vec::Vocab> &getIR2VecVocab() const {
149-
return IR2VecVocab;
150-
}
148+
const ir2vec::Vocabulary *getIR2VecVocab() const { return IR2VecVocab; }
151149

152150
// Helper intended to be useful for unittests
153151
void setFunctionEmbeddingForTest(const ir2vec::Embedding &Embedding) {

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 105 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
#include "llvm/ADT/DenseMap.h"
3333
#include "llvm/IR/PassManager.h"
34+
#include "llvm/IR/Type.h"
3435
#include "llvm/Support/CommandLine.h"
3536
#include "llvm/Support/Compiler.h"
3637
#include "llvm/Support/ErrorOr.h"
@@ -43,10 +44,10 @@ class Module;
4344
class BasicBlock;
4445
class Instruction;
4546
class Function;
46-
class Type;
4747
class Value;
4848
class raw_ostream;
4949
class LLVMContext;
50+
class IR2VecVocabAnalysis;
5051

5152
/// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
5253
/// Symbolic embeddings capture the "syntactic" and "statistical correlation"
@@ -128,9 +129,94 @@ struct Embedding {
128129

129130
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
130131
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
131-
// FIXME: Current the keys are strings. This can be changed to
132-
// use integers for cheaper lookups.
133-
using Vocab = std::map<std::string, Embedding>;
132+
133+
/// Class for storing and accessing the IR2Vec vocabulary.
134+
/// Encapsulates all vocabulary-related constants, logic, and access methods.
135+
class Vocabulary {
136+
friend class llvm::IR2VecVocabAnalysis;
137+
using VocabVector = std::vector<ir2vec::Embedding>;
138+
VocabVector Vocab;
139+
bool Valid = false;
140+
141+
/// Operand kinds supported by IR2Vec Vocabulary
142+
enum class OperandKind : unsigned {
143+
FunctionID,
144+
PointerID,
145+
ConstantID,
146+
VariableID,
147+
MaxOperandKind
148+
};
149+
/// String mappings for OperandKind values
150+
static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",
151+
"Constant", "Variable"};
152+
static_assert(std::size(OperandKindNames) ==
153+
static_cast<unsigned>(OperandKind::MaxOperandKind),
154+
"OperandKindNames array size must match MaxOperandKind");
155+
156+
/// Vocabulary layout constants
157+
#define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM;
158+
#include "llvm/IR/Instruction.def"
159+
#undef LAST_OTHER_INST
160+
161+
static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1;
162+
static constexpr unsigned MaxOperandKinds =
163+
static_cast<unsigned>(OperandKind::MaxOperandKind);
164+
165+
/// Helper function to get vocabulary key for a given OperandKind
166+
static StringRef getVocabKeyForOperandKind(OperandKind Kind);
167+
168+
/// Helper function to classify an operand into OperandKind
169+
static OperandKind getOperandKind(const Value *Op);
170+
171+
/// Helper function to get vocabulary key for a given TypeID
172+
static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
173+
174+
public:
175+
Vocabulary() = default;
176+
Vocabulary(VocabVector &&Vocab);
177+
178+
bool isValid() const;
179+
unsigned getDimension() const;
180+
size_t size() const;
181+
182+
/// Accessors to get the embedding for a given entity.
183+
const ir2vec::Embedding &operator[](unsigned Opcode) const;
184+
const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
185+
const ir2vec::Embedding &operator[](const Value *Arg) const;
186+
187+
/// Const Iterator type aliases
188+
using const_iterator = VocabVector::const_iterator;
189+
const_iterator begin() const {
190+
assert(Valid && "IR2Vec Vocabulary is invalid");
191+
return Vocab.begin();
192+
}
193+
194+
const_iterator cbegin() const {
195+
assert(Valid && "IR2Vec Vocabulary is invalid");
196+
return Vocab.cbegin();
197+
}
198+
199+
const_iterator end() const {
200+
assert(Valid && "IR2Vec Vocabulary is invalid");
201+
return Vocab.end();
202+
}
203+
204+
const_iterator cend() const {
205+
assert(Valid && "IR2Vec Vocabulary is invalid");
206+
return Vocab.cend();
207+
}
208+
209+
/// Returns the string key for a given index position in the vocabulary.
210+
/// This is useful for debugging or printing the vocabulary. Do not use this
211+
/// for embedding generation as string based lookups are inefficient.
212+
static StringRef getStringKey(unsigned Pos);
213+
214+
/// Create a dummy vocabulary for testing purposes.
215+
static VocabVector createDummyVocabForTest(unsigned Dim = 1);
216+
217+
bool invalidate(Module &M, const PreservedAnalyses &PA,
218+
ModuleAnalysisManager::Invalidator &Inv) const;
219+
};
134220

135221
/// Embedder provides the interface to generate embeddings (vector
136222
/// representations) for instructions, basic blocks, and functions. The
@@ -141,7 +227,7 @@ using Vocab = std::map<std::string, Embedding>;
141227
class Embedder {
142228
protected:
143229
const Function &F;
144-
const Vocab &Vocabulary;
230+
const Vocabulary &Vocab;
145231

146232
/// Dimension of the vector representation; captured from the input vocabulary
147233
const unsigned Dimension;
@@ -156,7 +242,7 @@ class Embedder {
156242
mutable BBEmbeddingsMap BBVecMap;
157243
mutable InstEmbeddingsMap InstVecMap;
158244

159-
LLVM_ABI Embedder(const Function &F, const Vocab &Vocabulary);
245+
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab);
160246

161247
/// Helper function to compute embeddings. It generates embeddings for all
162248
/// the instructions and basic blocks in the function F. Logic of computing
@@ -167,16 +253,12 @@ class Embedder {
167253
/// Specific to the kind of embeddings being computed.
168254
virtual void computeEmbeddings(const BasicBlock &BB) const = 0;
169255

170-
/// Lookup vocabulary for a given Key. If the key is not found, it returns a
171-
/// zero vector.
172-
LLVM_ABI Embedding lookupVocab(const std::string &Key) const;
173-
174256
public:
175257
virtual ~Embedder() = default;
176258

177259
/// Factory method to create an Embedder object.
178260
LLVM_ABI static std::unique_ptr<Embedder>
179-
create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary);
261+
create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab);
180262

181263
/// Returns a map containing instructions and the corresponding embeddings for
182264
/// the function F if it has been computed. If not, it computes the embeddings
@@ -202,56 +284,39 @@ class Embedder {
202284
/// representations obtained from the Vocabulary.
203285
class LLVM_ABI SymbolicEmbedder : public Embedder {
204286
private:
205-
/// Utility function to compute the embedding for a given type.
206-
Embedding getTypeEmbedding(const Type *Ty) const;
207-
208-
/// Utility function to compute the embedding for a given operand.
209-
Embedding getOperandEmbedding(const Value *Op) const;
210-
211287
void computeEmbeddings() const override;
212288
void computeEmbeddings(const BasicBlock &BB) const override;
213289

214290
public:
215-
SymbolicEmbedder(const Function &F, const Vocab &Vocabulary)
216-
: Embedder(F, Vocabulary) {
291+
SymbolicEmbedder(const Function &F, const Vocabulary &Vocab)
292+
: Embedder(F, Vocab) {
217293
FuncVector = Embedding(Dimension, 0);
218294
}
219295
};
220296

221297
} // namespace ir2vec
222298

223-
/// Class for storing the result of the IR2VecVocabAnalysis.
224-
class IR2VecVocabResult {
225-
ir2vec::Vocab Vocabulary;
226-
bool Valid = false;
227-
228-
public:
229-
IR2VecVocabResult() = default;
230-
LLVM_ABI IR2VecVocabResult(ir2vec::Vocab &&Vocabulary);
231-
232-
bool isValid() const { return Valid; }
233-
LLVM_ABI const ir2vec::Vocab &getVocabulary() const;
234-
LLVM_ABI unsigned getDimension() const;
235-
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA,
236-
ModuleAnalysisManager::Invalidator &Inv) const;
237-
};
238-
239299
/// This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
240300
/// mapping between an entity of the IR (like opcode, type, argument, etc.) and
241301
/// its corresponding embedding.
242302
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
243-
ir2vec::Vocab Vocabulary;
303+
using VocabVector = std::vector<ir2vec::Embedding>;
304+
using VocabMap = std::map<std::string, ir2vec::Embedding>;
305+
VocabMap OpcVocab, TypeVocab, ArgVocab;
306+
VocabVector Vocab;
307+
244308
Error readVocabulary();
245309
Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue,
246-
ir2vec::Vocab &TargetVocab, unsigned &Dim);
310+
VocabMap &TargetVocab, unsigned &Dim);
311+
void generateNumMappedVocab();
247312
void emitError(Error Err, LLVMContext &Ctx);
248313

249314
public:
250315
LLVM_ABI static AnalysisKey Key;
251316
IR2VecVocabAnalysis() = default;
252-
LLVM_ABI explicit IR2VecVocabAnalysis(const ir2vec::Vocab &Vocab);
253-
LLVM_ABI explicit IR2VecVocabAnalysis(ir2vec::Vocab &&Vocab);
254-
using Result = IR2VecVocabResult;
317+
LLVM_ABI explicit IR2VecVocabAnalysis(const VocabVector &Vocab);
318+
LLVM_ABI explicit IR2VecVocabAnalysis(VocabVector &&Vocab);
319+
using Result = ir2vec::Vocabulary;
255320
LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM);
256321
};
257322

llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -242,20 +242,20 @@ FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
242242
// We use the cached result of the IR2VecVocabAnalysis run by
243243
// InlineAdvisorAnalysis. If the IR2VecVocabAnalysis is not run, we don't
244244
// use IR2Vec embeddings.
245-
auto VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
246-
.getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
245+
auto Vocabulary = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
246+
.getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
247247
return getFunctionPropertiesInfo(F, FAM.getResult<DominatorTreeAnalysis>(F),
248-
FAM.getResult<LoopAnalysis>(F), VocabResult);
248+
FAM.getResult<LoopAnalysis>(F), Vocabulary);
249249
}
250250

251251
FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
252252
const Function &F, const DominatorTree &DT, const LoopInfo &LI,
253-
const IR2VecVocabResult *VocabResult) {
253+
const ir2vec::Vocabulary *Vocabulary) {
254254

255255
FunctionPropertiesInfo FPI;
256-
if (VocabResult && VocabResult->isValid()) {
257-
FPI.IR2VecVocab = VocabResult->getVocabulary();
258-
FPI.FunctionEmbedding = ir2vec::Embedding(VocabResult->getDimension(), 0.0);
256+
if (Vocabulary && Vocabulary->isValid()) {
257+
FPI.IR2VecVocab = Vocabulary;
258+
FPI.FunctionEmbedding = ir2vec::Embedding(Vocabulary->getDimension(), 0.0);
259259
}
260260
for (const auto &BB : F)
261261
if (DT.isReachableFromEntry(&BB))
@@ -588,9 +588,9 @@ bool FunctionPropertiesUpdater::isUpdateValid(Function &F,
588588
return false;
589589
DominatorTree DT(F);
590590
LoopInfo LI(DT);
591-
auto VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
592-
.getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
591+
auto Vocabulary = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
592+
.getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
593593
auto Fresh =
594-
FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI, VocabResult);
594+
FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI, Vocabulary);
595595
return FPI == Fresh;
596596
}

0 commit comments

Comments
 (0)