Skip to content

Commit 52875ac

Browse files
committed
Support predicates
1 parent 69f1aeb commit 52875ac

File tree

13 files changed

+344
-24
lines changed

13 files changed

+344
-24
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#define LLVM_ANALYSIS_IR2VEC_H
3737

3838
#include "llvm/ADT/DenseMap.h"
39+
#include "llvm/IR/Instructions.h"
3940
#include "llvm/IR/PassManager.h"
4041
#include "llvm/IR/Type.h"
4142
#include "llvm/Support/CommandLine.h"
@@ -162,15 +163,29 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
162163
/// embeddings.
163164
class Vocabulary {
164165
friend class llvm::IR2VecVocabAnalysis;
166+
167+
// Vocabulary Slot Layout:
168+
// +----------------+------------------------------------------------------+
169+
// | Entity Type | Index Range |
170+
// +----------------+------------------------------------------------------+
171+
// | Opcodes | [0 .. (MaxOpcodes-1)] |
172+
// | Canonical Types| [MaxOpcodes .. (MaxOpcodes+MaxCanonicalTypeIDs-1)] |
173+
// | Operands | [(MaxOpcodes+MaxCanonicalTypeIDs) .. (NumCanonicalEntries-1)] |
174+
// +----------------+------------------------------------------------------+
175+
// Note: "Similar" LLVM Types are grouped/canonicalized together.
176+
// Operands include Comparison predicates (ICmp/FCmp).
177+
// This can be extended to include other specializations in future.
165178
using VocabVector = std::vector<ir2vec::Embedding>;
166179
VocabVector Vocab;
167180

168-
public:
169-
// Slot layout:
170-
// [0 .. MaxOpcodes-1] => Instruction opcodes
171-
// [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
172-
// [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
181+
static constexpr unsigned NumICmpPredicates =
182+
static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE) -
183+
static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) + 1;
184+
static constexpr unsigned NumFCmpPredicates =
185+
static_cast<unsigned>(CmpInst::LAST_FCMP_PREDICATE) -
186+
static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + 1;
173187

188+
public:
174189
/// Canonical type IDs supported by IR2Vec Vocabulary
175190
enum class CanonicalTypeID : unsigned {
176191
FloatTy,
@@ -207,13 +222,18 @@ class Vocabulary {
207222
static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
208223
static constexpr unsigned MaxOperandKinds =
209224
static_cast<unsigned>(OperandKind::MaxOperandKind);
225+
// CmpInst::Predicate has gaps. We want the vocabulary to be dense without
226+
// empty slots.
227+
static constexpr unsigned MaxPredicateKinds =
228+
NumICmpPredicates + NumFCmpPredicates;
210229

211230
Vocabulary() = default;
212231
LLVM_ABI Vocabulary(VocabVector &&Vocab) : Vocab(std::move(Vocab)) {}
213232

214233
/// Returns true when the vocabulary holds exactly NumCanonicalEntries
/// embeddings, i.e. one per slot of the layout.
LLVM_ABI bool isValid() const { return Vocab.size() == NumCanonicalEntries; }
215234
LLVM_ABI unsigned getDimension() const;
216-
/// Total number of entries (opcodes + canonicalized types + operand kinds)
235+
/// Total number of entries (opcodes + canonicalized types + operand kinds +
236+
/// predicates)
217237
static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
218238

219239
/// Function to get vocabulary key for a given Opcode
@@ -228,16 +248,21 @@ class Vocabulary {
228248
/// Function to classify an operand into OperandKind
229249
LLVM_ABI static OperandKind getOperandKind(const Value *Op);
230250

251+
/// Function to get vocabulary key for a given predicate
252+
LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P);
253+
231254
/// Functions to return the slot index or position of a given Opcode, TypeID,
232255
/// OperandKind, or comparison predicate in the vocabulary.
233256
LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
234257
LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
235258
LLVM_ABI static unsigned getSlotIndex(const Value &Op);
259+
LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P);
236260

237261
/// Accessors to get the embedding for a given entity.
238262
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
239263
LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
240264
LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const;
265+
LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const;
241266

242267
/// Const Iterator type aliases
243268
using const_iterator = VocabVector::const_iterator;
@@ -274,7 +299,13 @@ class Vocabulary {
274299

275300
private:
276301
constexpr static unsigned NumCanonicalEntries =
277-
MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
302+
MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
303+
304+
// Base offsets for slot layout to simplify index computation
305+
constexpr static unsigned OperandBaseOffset =
306+
MaxOpcodes + MaxCanonicalTypeIDs;
307+
constexpr static unsigned PredicateBaseOffset =
308+
OperandBaseOffset + MaxOperandKinds;
278309

279310
/// String mappings for CanonicalTypeID values
280311
static constexpr StringLiteral CanonicalTypeNames[] = {
@@ -326,6 +357,9 @@ class Vocabulary {
326357

327358
/// Function to convert TypeID to CanonicalTypeID
328359
LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
360+
361+
/// Function to get the predicate enum value for a given index
362+
LLVM_ABI static CmpInst::Predicate getPredicate(unsigned Index);
329363
};
330364

331365
/// Embedder provides the interface to generate embeddings (vector

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 67 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,8 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
216216
ArgEmb += Vocab[*Op];
217217
auto InstVector =
218218
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
219+
if (const auto *IC = dyn_cast<CmpInst>(&I))
220+
InstVector += Vocab[IC->getPredicate()];
219221
InstVecMap[&I] = InstVector;
220222
BBVector += InstVector;
221223
}
@@ -250,6 +252,9 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
250252
// embeddings
251253
auto InstVector =
252254
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
255+
// Add compare predicate embedding as an additional operand if applicable
256+
if (const auto *IC = dyn_cast<CmpInst>(&I))
257+
InstVector += Vocab[IC->getPredicate()];
253258
InstVecMap[&I] = InstVector;
254259
BBVector += InstVector;
255260
}
@@ -278,7 +283,17 @@ unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
278283
unsigned Vocabulary::getSlotIndex(const Value &Op) {
279284
unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
280285
assert(Index < MaxOperandKinds && "Invalid OperandKind");
281-
return MaxOpcodes + MaxCanonicalTypeIDs + Index;
286+
return OperandBaseOffset + Index;
287+
}
288+
289+
/// Returns the vocabulary slot index for the comparison predicate \p P.
///
/// CmpInst::Predicate values are not contiguous, so the predicate section is
/// packed densely: FCmp predicates take the first NumFCmpPredicates slots
/// starting at PredicateBaseOffset, and ICmp predicates follow immediately.
/// \p P must be a real FCmp or ICmp predicate; this is checked by an assert.
unsigned Vocabulary::getSlotIndex(CmpInst::Predicate P) {
  const unsigned PU = static_cast<unsigned>(P);
  const unsigned FirstFC = static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE);
  const unsigned FirstIC = static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE);
  [[maybe_unused]] const unsigned LastFC =
      static_cast<unsigned>(CmpInst::LAST_FCMP_PREDICATE);
  [[maybe_unused]] const unsigned LastIC =
      static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE);

  // Values outside both ranges would otherwise alias another predicate's slot
  // or index past the end of the vocabulary.
  assert(((PU >= FirstFC && PU <= LastFC) ||
          (PU >= FirstIC && PU <= LastIC)) &&
         "Invalid CmpInst predicate");

  // ICmp predicates are shifted down so that they directly follow the FCmp
  // block; FCmp predicates are only rebased to zero.
  const unsigned PredIdx =
      (PU >= FirstIC) ? (NumFCmpPredicates + (PU - FirstIC)) : (PU - FirstFC);
  return PredicateBaseOffset + PredIdx;
}
283298

284299
const Embedding &Vocabulary::operator[](unsigned Opcode) const {
@@ -293,6 +308,10 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
293308
return Vocab[getSlotIndex(Arg)];
294309
}
295310

311+
/// Accessor for the embedding stored in the slot of comparison predicate
/// \p P.
const ir2vec::Embedding &Vocabulary::operator[](CmpInst::Predicate P) const {
  const unsigned Slot = getSlotIndex(P);
  return Vocab[Slot];
}
314+
296315
StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
297316
assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
298317
#define HANDLE_INST(NUM, OPCODE, CLASS) \
@@ -338,18 +357,41 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
338357
return OperandKind::VariableID;
339358
}
340359

360+
/// Maps a dense predicate index in [0, MaxPredicateKinds) back to the
/// corresponding CmpInst::Predicate value. Indices below NumFCmpPredicates
/// denote FCmp predicates; the remaining indices denote ICmp predicates.
CmpInst::Predicate Vocabulary::getPredicate(unsigned Index) {
  assert(Index < MaxPredicateKinds && "Invalid predicate index");

  // FCmp block: offset directly from the first FCmp enumerator.
  if (Index < NumFCmpPredicates) {
    const unsigned FCmpBase =
        static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE);
    return static_cast<CmpInst::Predicate>(FCmpBase + Index);
  }

  // ICmp block: drop the FCmp slots first, then offset from the first ICmp
  // enumerator.
  const unsigned ICmpBase =
      static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE);
  const unsigned ICmpOffset = Index - NumFCmpPredicates;
  return static_cast<CmpInst::Predicate>(ICmpBase + ICmpOffset);
}
369+
370+
/// Returns the vocabulary key for the comparison predicate \p Pred: the name
/// reported by CmpInst::getPredicateName, prefixed with "FCMP_" for values
/// below FIRST_ICMP_PREDICATE and with "ICMP_" otherwise.
///
/// The keys live in a table that is built once and never modified afterwards,
/// so a returned StringRef is not overwritten by later calls (which a single
/// shared scratch buffer would do).
StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) {
  // Highest raw enum value that gets its own entry: the sentinel that directly
  // follows the last ICmp predicate.
  static constexpr unsigned LastRaw =
      static_cast<unsigned>(CmpInst::BAD_ICMP_PREDICATE);

  // One key per raw enum value (gap values included), each composed exactly
  // as a direct prefix + getPredicateName() concatenation would be.
  static const std::vector<std::string> Keys = [] {
    std::vector<std::string> Table;
    Table.reserve(LastRaw + 1);
    for (unsigned Raw = 0; Raw <= LastRaw; ++Raw) {
      const auto P = static_cast<CmpInst::Predicate>(Raw);
      Table.emplace_back(P < CmpInst::FIRST_ICMP_PREDICATE ? "FCMP_" : "ICMP_");
      const StringRef Name = CmpInst::getPredicateName(P);
      Table.back().append(Name.data(), Name.size());
    }
    return Table;
  }();

  // Out-of-range raw values are clamped to the sentinel entry instead of
  // indexing past the table.
  const unsigned Raw = static_cast<unsigned>(Pred);
  return Keys[Raw <= LastRaw ? Raw : LastRaw];
}
379+
341380
StringRef Vocabulary::getStringKey(unsigned Pos) {
342381
assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
343382
// Opcode
344383
if (Pos < MaxOpcodes)
345384
return getVocabKeyForOpcode(Pos + 1);
346385
// Type
347-
if (Pos < MaxOpcodes + MaxCanonicalTypeIDs)
386+
if (Pos < OperandBaseOffset)
348387
return getVocabKeyForCanonicalTypeID(
349388
static_cast<CanonicalTypeID>(Pos - MaxOpcodes));
350389
// Operand
351-
return getVocabKeyForOperandKind(
352-
static_cast<OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs));
390+
if (Pos < PredicateBaseOffset)
391+
return getVocabKeyForOperandKind(
392+
static_cast<OperandKind>(Pos - OperandBaseOffset));
393+
// Predicates
394+
return getVocabKeyForPredicate(getPredicate(Pos - PredicateBaseOffset));
353395
}
354396

355397
// For now, assume vocabulary is stable unless explicitly invalidated.
@@ -363,11 +405,9 @@ Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
363405
VocabVector DummyVocab;
364406
DummyVocab.reserve(NumCanonicalEntries);
365407
float DummyVal = 0.1f;
366-
// Create a dummy vocabulary with entries for all opcodes, types, and
367-
// operands
368-
for ([[maybe_unused]] unsigned _ :
369-
seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs +
370-
Vocabulary::MaxOperandKinds)) {
408+
// Create a dummy vocabulary with entries for all opcodes, types, operands
409+
// and predicates
410+
for ([[maybe_unused]] unsigned _ : seq(0u, Vocabulary::NumCanonicalEntries)) {
371411
DummyVocab.push_back(Embedding(Dim, DummyVal));
372412
DummyVal += 0.1f;
373413
}
@@ -510,6 +550,24 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
510550
}
511551
Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(),
512552
NumericArgEmbeddings.end());
553+
554+
// Handle Predicates: part of Operands section. We look up predicate keys
555+
// in ArgVocab.
556+
std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds,
557+
Embedding(Dim, 0));
558+
NumericPredEmbeddings.reserve(Vocabulary::MaxPredicateKinds);
559+
for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) {
560+
StringRef VocabKey =
561+
Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK));
562+
auto It = ArgVocab.find(VocabKey.str());
563+
if (It != ArgVocab.end()) {
564+
NumericPredEmbeddings[PK] = It->second;
565+
continue;
566+
}
567+
handleMissingEntity(VocabKey.str());
568+
}
569+
Vocab.insert(Vocab.end(), NumericPredEmbeddings.begin(),
570+
NumericPredEmbeddings.end());
513571
}
514572

515573
IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab)

llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,32 @@
8787
"Function": [1, 2],
8888
"Pointer": [3, 4],
8989
"Constant": [5, 6],
90-
"Variable": [7, 8]
90+
"Variable": [7, 8],
91+
"FCMP_false": [9, 10],
92+
"FCMP_oeq": [11, 12],
93+
"FCMP_ogt": [13, 14],
94+
"FCMP_oge": [15, 16],
95+
"FCMP_olt": [17, 18],
96+
"FCMP_ole": [19, 20],
97+
"FCMP_one": [21, 22],
98+
"FCMP_ord": [23, 24],
99+
"FCMP_uno": [25, 26],
100+
"FCMP_ueq": [27, 28],
101+
"FCMP_ugt": [29, 30],
102+
"FCMP_uge": [31, 32],
103+
"FCMP_ult": [33, 34],
104+
"FCMP_ule": [35, 36],
105+
"FCMP_une": [37, 38],
106+
"FCMP_true": [39, 40],
107+
"ICMP_eq": [41, 42],
108+
"ICMP_ne": [43, 44],
109+
"ICMP_ugt": [45, 46],
110+
"ICMP_uge": [47, 48],
111+
"ICMP_ult": [49, 50],
112+
"ICMP_ule": [51, 52],
113+
"ICMP_sgt": [53, 54],
114+
"ICMP_sge": [55, 56],
115+
"ICMP_slt": [57, 58],
116+
"ICMP_sle": [59, 60]
91117
}
92118
}

llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,32 @@
8686
"Function": [1, 2, 3],
8787
"Pointer": [4, 5, 6],
8888
"Constant": [7, 8, 9],
89-
"Variable": [10, 11, 12]
89+
"Variable": [10, 11, 12],
90+
"FCMP_false": [13, 14, 15],
91+
"FCMP_oeq": [16, 17, 18],
92+
"FCMP_ogt": [19, 20, 21],
93+
"FCMP_oge": [22, 23, 24],
94+
"FCMP_olt": [25, 26, 27],
95+
"FCMP_ole": [28, 29, 30],
96+
"FCMP_one": [31, 32, 33],
97+
"FCMP_ord": [34, 35, 36],
98+
"FCMP_uno": [37, 38, 39],
99+
"FCMP_ueq": [40, 41, 42],
100+
"FCMP_ugt": [43, 44, 45],
101+
"FCMP_uge": [46, 47, 48],
102+
"FCMP_ult": [49, 50, 51],
103+
"FCMP_ule": [52, 53, 54],
104+
"FCMP_une": [55, 56, 57],
105+
"FCMP_true": [58, 59, 60],
106+
"ICMP_eq": [61, 62, 63],
107+
"ICMP_ne": [64, 65, 66],
108+
"ICMP_ugt": [67, 68, 69],
109+
"ICMP_uge": [70, 71, 72],
110+
"ICMP_ult": [73, 74, 75],
111+
"ICMP_ule": [76, 77, 78],
112+
"ICMP_sgt": [79, 80, 81],
113+
"ICMP_sge": [82, 83, 84],
114+
"ICMP_slt": [85, 86, 87],
115+
"ICMP_sle": [88, 89, 90]
90116
}
91117
}

llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
"FPTrunc": [133, 134, 135],
4848
"FPExt": [136, 137, 138],
4949
"PtrToInt": [139, 140, 141],
50+
"PtrToAddr": [202, 203, 204],
5051
"IntToPtr": [142, 143, 144],
5152
"BitCast": [145, 146, 147],
5253
"AddrSpaceCast": [148, 149, 150],
@@ -86,6 +87,32 @@
8687
"Function": [0, 0, 0],
8788
"Pointer": [0, 0, 0],
8889
"Constant": [0, 0, 0],
89-
"Variable": [0, 0, 0]
90+
"Variable": [0, 0, 0],
91+
"FCMP_false": [0, 0, 0],
92+
"FCMP_oeq": [0, 0, 0],
93+
"FCMP_ogt": [0, 0, 0],
94+
"FCMP_oge": [0, 0, 0],
95+
"FCMP_olt": [0, 0, 0],
96+
"FCMP_ole": [0, 0, 0],
97+
"FCMP_one": [0, 0, 0],
98+
"FCMP_ord": [0, 0, 0],
99+
"FCMP_uno": [0, 0, 0],
100+
"FCMP_ueq": [0, 0, 0],
101+
"FCMP_ugt": [0, 0, 0],
102+
"FCMP_uge": [0, 0, 0],
103+
"FCMP_ult": [0, 0, 0],
104+
"FCMP_ule": [0, 0, 0],
105+
"FCMP_une": [0, 0, 0],
106+
"FCMP_true": [0, 0, 0],
107+
"ICMP_eq": [0, 0, 0],
108+
"ICMP_ne": [0, 0, 0],
109+
"ICMP_ugt": [0, 0, 0],
110+
"ICMP_uge": [0, 0, 0],
111+
"ICMP_ult": [0, 0, 0],
112+
"ICMP_ule": [0, 0, 0],
113+
"ICMP_sgt": [1, 1, 1],
114+
"ICMP_sge": [0, 0, 0],
115+
"ICMP_slt": [0, 0, 0],
116+
"ICMP_sle": [0, 0, 0]
90117
}
91118
}

llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,29 @@ Key: Function: [ 0.20 0.40 ]
8282
Key: Pointer: [ 0.60 0.80 ]
8383
Key: Constant: [ 1.00 1.20 ]
8484
Key: Variable: [ 1.40 1.60 ]
85+
Key: FCMP_false: [ 1.80 2.00 ]
86+
Key: FCMP_oeq: [ 2.20 2.40 ]
87+
Key: FCMP_ogt: [ 2.60 2.80 ]
88+
Key: FCMP_oge: [ 3.00 3.20 ]
89+
Key: FCMP_olt: [ 3.40 3.60 ]
90+
Key: FCMP_ole: [ 3.80 4.00 ]
91+
Key: FCMP_one: [ 4.20 4.40 ]
92+
Key: FCMP_ord: [ 4.60 4.80 ]
93+
Key: FCMP_uno: [ 5.00 5.20 ]
94+
Key: FCMP_ueq: [ 5.40 5.60 ]
95+
Key: FCMP_ugt: [ 5.80 6.00 ]
96+
Key: FCMP_uge: [ 6.20 6.40 ]
97+
Key: FCMP_ult: [ 6.60 6.80 ]
98+
Key: FCMP_ule: [ 7.00 7.20 ]
99+
Key: FCMP_une: [ 7.40 7.60 ]
100+
Key: FCMP_true: [ 7.80 8.00 ]
101+
Key: ICMP_eq: [ 8.20 8.40 ]
102+
Key: ICMP_ne: [ 8.60 8.80 ]
103+
Key: ICMP_ugt: [ 9.00 9.20 ]
104+
Key: ICMP_uge: [ 9.40 9.60 ]
105+
Key: ICMP_ult: [ 9.80 10.00 ]
106+
Key: ICMP_ule: [ 10.20 10.40 ]
107+
Key: ICMP_sgt: [ 10.60 10.80 ]
108+
Key: ICMP_sge: [ 11.00 11.20 ]
109+
Key: ICMP_slt: [ 11.40 11.60 ]
110+
Key: ICMP_sle: [ 11.80 12.00 ]

0 commit comments

Comments
 (0)