Skip to content

Commit 6bcaec6

Browse files
committed
VocabStorage
1 parent 52875ac commit 6bcaec6

File tree

6 files changed

+648
-194
lines changed

6 files changed

+648
-194
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 177 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@
4545
#include "llvm/Support/JSON.h"
4646
#include <array>
4747
#include <map>
48+
#include <optional>
4849

4950
namespace llvm {
5051

@@ -144,6 +145,73 @@ struct Embedding {
144145
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
145146
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
146147

148+
/// Generic storage class for section-based vocabularies.
149+
/// VocabStorage provides a generic foundation for storing and accessing
150+
/// embeddings organized into sections.
151+
class VocabStorage {
152+
private:
153+
/// Section-based storage
154+
std::vector<std::vector<Embedding>> Sections;
155+
156+
const size_t TotalSize;
157+
const unsigned Dimension;
158+
159+
public:
160+
/// Default constructor creates an empty storage (zero entries, zero dimension), for which isValid() returns false
161+
VocabStorage() : Sections(), TotalSize(0), Dimension(0) {}
162+
163+
/// Create a VocabStorage with pre-organized section data
164+
VocabStorage(std::vector<std::vector<Embedding>> &&SectionData);
165+
166+
VocabStorage(VocabStorage &&) = default;
167+
VocabStorage &operator=(VocabStorage &&) = delete;
168+
169+
VocabStorage(const VocabStorage &) = delete;
170+
VocabStorage &operator=(const VocabStorage &) = delete;
171+
172+
/// Get total number of entries across all sections
173+
size_t size() const { return TotalSize; }
174+
175+
/// Get number of sections
176+
unsigned getNumSections() const {
177+
return static_cast<unsigned>(Sections.size());
178+
}
179+
180+
/// Section-based access: Storage[SectionId][LocalIndex]. Only SectionId is range-checked (by assert) here.
181+
const std::vector<Embedding> &operator[](unsigned SectionId) const {
182+
assert(SectionId < Sections.size() && "Invalid section ID");
183+
return Sections[SectionId];
184+
}
185+
186+
/// Get vocabulary dimension
187+
unsigned getDimension() const { return Dimension; }
188+
189+
/// Check if the storage is valid, i.e. it holds at least one entry
190+
bool isValid() const { return TotalSize > 0; }
191+
192+
/// Iterator support for section-based access
193+
class const_iterator {
194+
const VocabStorage *Storage;
195+
unsigned SectionId = 0;
196+
size_t LocalIndex = 0;
197+
198+
public:
199+
const_iterator(const VocabStorage *Storage, unsigned SectionId,
200+
size_t LocalIndex)
201+
: Storage(Storage), SectionId(SectionId), LocalIndex(LocalIndex) {}
202+
203+
LLVM_ABI const Embedding &operator*() const;
204+
LLVM_ABI const_iterator &operator++();
205+
LLVM_ABI bool operator==(const const_iterator &Other) const;
206+
LLVM_ABI bool operator!=(const const_iterator &Other) const;
207+
};
208+
209+
const_iterator begin() const { return const_iterator(this, 0, 0); }
210+
const_iterator end() const {
211+
return const_iterator(this, getNumSections(), 0);
212+
}
213+
};
214+
147215
/// Class for storing and accessing the IR2Vec vocabulary.
148216
/// The Vocabulary class manages seed embeddings for LLVM IR entities. The
149217
/// seed embeddings are the initial learned representations of the entities
@@ -164,7 +232,7 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
164232
class Vocabulary {
165233
friend class llvm::IR2VecVocabAnalysis;
166234

167-
// Vocabulary Slot Layout:
235+
// Vocabulary Layout:
168236
// +----------------+------------------------------------------------------+
169237
// | Entity Type | Index Range |
170238
// +----------------+------------------------------------------------------+
@@ -175,8 +243,16 @@ class Vocabulary {
175243
// Note: "Similar" LLVM Types are grouped/canonicalized together.
176244
// Operands include Comparison predicates (ICmp/FCmp).
177245
// This can be extended to include other specializations in future.
178-
using VocabVector = std::vector<ir2vec::Embedding>;
179-
VocabVector Vocab;
246+
enum class Section : unsigned {
247+
Opcodes = 0,
248+
CanonicalTypes = 1,
249+
Operands = 2,
250+
Predicates = 3,
251+
MaxSections
252+
};
253+
254+
// Use section-based storage for better organization and efficiency
255+
VocabStorage Storage;
180256

181257
static constexpr unsigned NumICmpPredicates =
182258
static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE) -
@@ -228,10 +304,23 @@ class Vocabulary {
228304
NumICmpPredicates + NumFCmpPredicates;
229305

230306
Vocabulary() = default;
231-
LLVM_ABI Vocabulary(VocabVector &&Vocab) : Vocab(std::move(Vocab)) {}
307+
LLVM_ABI Vocabulary(VocabStorage &&Storage) : Storage(std::move(Storage)) {}
308+
309+
Vocabulary(const Vocabulary &) = delete;
310+
Vocabulary &operator=(const Vocabulary &) = delete;
311+
312+
Vocabulary(Vocabulary &&) = default;
313+
Vocabulary &operator=(Vocabulary &&Other) = delete;
314+
315+
LLVM_ABI bool isValid() const {
316+
return Storage.size() == NumCanonicalEntries;
317+
}
318+
319+
LLVM_ABI unsigned getDimension() const {
320+
assert(isValid() && "IR2Vec Vocabulary is invalid");
321+
return Storage.getDimension();
322+
}
232323

233-
LLVM_ABI bool isValid() const { return Vocab.size() == NumCanonicalEntries; };
234-
LLVM_ABI unsigned getDimension() const;
235324
/// Total number of entries (opcodes + canonicalized types + operand kinds +
236325
/// predicates)
237326
static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
@@ -240,59 +329,91 @@ class Vocabulary {
240329
LLVM_ABI static StringRef getVocabKeyForOpcode(unsigned Opcode);
241330

242331
/// Function to get vocabulary key for a given TypeID
243-
LLVM_ABI static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
332+
LLVM_ABI static StringRef getVocabKeyForTypeID(Type::TypeID TypeID) {
333+
return getVocabKeyForCanonicalTypeID(getCanonicalTypeID(TypeID));
334+
}
244335

245336
/// Function to get vocabulary key for a given OperandKind
246-
LLVM_ABI static StringRef getVocabKeyForOperandKind(OperandKind Kind);
337+
LLVM_ABI static StringRef getVocabKeyForOperandKind(OperandKind Kind) {
338+
unsigned Index = static_cast<unsigned>(Kind);
339+
assert(Index < MaxOperandKinds && "Invalid OperandKind");
340+
return OperandKindNames[Index];
341+
}
247342

248343
/// Function to classify an operand into OperandKind
249344
LLVM_ABI static OperandKind getOperandKind(const Value *Op);
250345

251346
/// Function to get vocabulary key for a given predicate
252347
LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P);
253348

254-
/// Functions to return the slot index or position of a given Opcode, TypeID,
255-
/// or OperandKind in the vocabulary.
256-
LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
257-
LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
258-
LLVM_ABI static unsigned getSlotIndex(const Value &Op);
259-
LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P);
349+
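/// Functions to return the flat index (position across all sections) of a given Opcode, TypeID, Operand, or Predicate in the vocabulary.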
350+
LLVM_ABI static unsigned getIndex(unsigned Opcode) {
351+
assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
352+
return Opcode - 1; // Convert to zero-based index
353+
}
354+
355+
LLVM_ABI static unsigned getIndex(Type::TypeID TypeID) {
356+
assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
357+
return MaxOpcodes + static_cast<unsigned>(getCanonicalTypeID(TypeID));
358+
}
359+
360+
LLVM_ABI static unsigned getIndex(const Value &Op) {
361+
unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
362+
assert(Index < MaxOperandKinds && "Invalid OperandKind");
363+
return OperandBaseOffset + Index;
364+
}
365+
366+
LLVM_ABI static unsigned getIndex(CmpInst::Predicate P) {
367+
return PredicateBaseOffset + getPredicateLocalIndex(P);
368+
}
260369

261370
/// Accessors to get the embedding for a given entity.
262-
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
263-
LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
264-
LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const;
265-
LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const;
371+
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const {
372+
assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
373+
return Storage[static_cast<unsigned>(Section::Opcodes)][Opcode - 1];
374+
}
375+
376+
LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeID) const {
377+
assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
378+
unsigned LocalIndex = static_cast<unsigned>(getCanonicalTypeID(TypeID));
379+
return Storage[static_cast<unsigned>(Section::CanonicalTypes)][LocalIndex];
380+
}
381+
382+
LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const {
383+
unsigned LocalIndex = static_cast<unsigned>(getOperandKind(&Arg));
384+
assert(LocalIndex < MaxOperandKinds && "Invalid OperandKind");
385+
return Storage[static_cast<unsigned>(Section::Operands)][LocalIndex];
386+
}
387+
388+
LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const {
389+
unsigned LocalIndex = getPredicateLocalIndex(P);
390+
return Storage[static_cast<unsigned>(Section::Predicates)][LocalIndex];
391+
}
266392

267393
/// Const Iterator type aliases
268-
using const_iterator = VocabVector::const_iterator;
394+
using const_iterator = VocabStorage::const_iterator;
395+
269396
const_iterator begin() const {
270397
assert(isValid() && "IR2Vec Vocabulary is invalid");
271-
return Vocab.begin();
398+
return Storage.begin();
272399
}
273400

274-
const_iterator cbegin() const {
275-
assert(isValid() && "IR2Vec Vocabulary is invalid");
276-
return Vocab.cbegin();
277-
}
401+
const_iterator cbegin() const { return begin(); }
278402

279403
const_iterator end() const {
280404
assert(isValid() && "IR2Vec Vocabulary is invalid");
281-
return Vocab.end();
405+
return Storage.end();
282406
}
283407

284-
const_iterator cend() const {
285-
assert(isValid() && "IR2Vec Vocabulary is invalid");
286-
return Vocab.cend();
287-
}
408+
const_iterator cend() const { return end(); }
288409

289410
/// Returns the string key for a given index position in the vocabulary.
290411
/// This is useful for debugging or printing the vocabulary. Do not use this
291412
/// for embedding generation as string based lookups are inefficient.
292413
LLVM_ABI static StringRef getStringKey(unsigned Pos);
293414

294415
/// Create a dummy vocabulary for testing purposes.
295-
LLVM_ABI static VocabVector createDummyVocabForTest(unsigned Dim = 1);
416+
LLVM_ABI static VocabStorage createDummyVocabForTest(unsigned Dim = 1);
296417

297418
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA,
298419
ModuleAnalysisManager::Invalidator &Inv) const;
@@ -301,12 +422,16 @@ class Vocabulary {
301422
constexpr static unsigned NumCanonicalEntries =
302423
MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
303424

304-
// Base offsets for slot layout to simplify index computation
425+
// Base offsets for flat index computation
305426
constexpr static unsigned OperandBaseOffset =
306427
MaxOpcodes + MaxCanonicalTypeIDs;
307428
constexpr static unsigned PredicateBaseOffset =
308429
OperandBaseOffset + MaxOperandKinds;
309430

431+
/// Functions to map a predicate to its local index within the Predicates section, and a local index back to a predicate
432+
static unsigned getPredicateLocalIndex(CmpInst::Predicate P);
433+
static CmpInst::Predicate getPredicateFromLocalIndex(unsigned LocalIndex);
434+
310435
/// String mappings for CanonicalTypeID values
311436
static constexpr StringLiteral CanonicalTypeNames[] = {
312437
"FloatTy", "VoidTy", "LabelTy", "MetadataTy",
@@ -353,13 +478,24 @@ class Vocabulary {
353478

354479
/// Function to get vocabulary key for canonical type by enum
355480
LLVM_ABI static StringRef
356-
getVocabKeyForCanonicalTypeID(CanonicalTypeID CType);
481+
getVocabKeyForCanonicalTypeID(CanonicalTypeID CType) {
482+
unsigned Index = static_cast<unsigned>(CType);
483+
assert(Index < MaxCanonicalTypeIDs && "Invalid CanonicalTypeID");
484+
return CanonicalTypeNames[Index];
485+
}
357486

358487
/// Function to convert TypeID to CanonicalTypeID
359-
LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
488+
LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID) {
489+
unsigned Index = static_cast<unsigned>(TypeID);
490+
assert(Index < MaxTypeIDs && "Invalid TypeID");
491+
return TypeIDMapping[Index];
492+
}
360493

361494
/// Function to get the predicate enum value for a given local predicate index (must be less than MaxPredicateKinds)
362-
LLVM_ABI static CmpInst::Predicate getPredicate(unsigned Index);
495+
LLVM_ABI static CmpInst::Predicate getPredicate(unsigned Index) {
496+
assert(Index < MaxPredicateKinds && "Invalid predicate index");
497+
return getPredicateFromLocalIndex(Index);
498+
}
363499
};
364500

365501
/// Embedder provides the interface to generate embeddings (vector
@@ -452,22 +588,22 @@ class LLVM_ABI FlowAwareEmbedder : public Embedder {
452588
/// mapping between an entity of the IR (like opcode, type, argument, etc.) and
453589
/// its corresponding embedding.
454590
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
455-
using VocabVector = std::vector<ir2vec::Embedding>;
456591
using VocabMap = std::map<std::string, ir2vec::Embedding>;
457-
VocabMap OpcVocab, TypeVocab, ArgVocab;
458-
VocabVector Vocab;
592+
std::optional<ir2vec::VocabStorage> Vocab;
459593

460-
Error readVocabulary();
594+
Error readVocabulary(VocabMap &OpcVocab, VocabMap &TypeVocab,
595+
VocabMap &ArgVocab);
461596
Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue,
462597
VocabMap &TargetVocab, unsigned &Dim);
463-
void generateNumMappedVocab();
598+
void generateVocabStorage(VocabMap &OpcVocab, VocabMap &TypeVocab,
599+
VocabMap &ArgVocab);
464600
void emitError(Error Err, LLVMContext &Ctx);
465601

466602
public:
467603
LLVM_ABI static AnalysisKey Key;
468604
IR2VecVocabAnalysis() = default;
469-
LLVM_ABI explicit IR2VecVocabAnalysis(const VocabVector &Vocab);
470-
LLVM_ABI explicit IR2VecVocabAnalysis(VocabVector &&Vocab);
605+
LLVM_ABI explicit IR2VecVocabAnalysis(ir2vec::VocabStorage &&Vocab)
606+
: Vocab(std::move(Vocab)) {}
471607
using Result = ir2vec::Vocabulary;
472608
LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM);
473609
};

0 commit comments

Comments
 (0)