Skip to content

Commit cf15515

Browse files
committed
[AggressiveInstCombine] Add support for ICmp instr that feeds a select intsr's condition operand.
1 parent d3e7816 commit cf15515

File tree

1 file changed

+78
-9
lines changed

1 file changed

+78
-9
lines changed

llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp

Lines changed: 78 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,24 @@ using namespace llvm;
3636

3737
#define DEBUG_TYPE "aggressive-instcombine"
3838

39+
// This function returns true if Value V is a constant or if it's a type
40+
// extension node.
41+
static bool isConstOrExt(Value *V) {
42+
if (isa<Constant>(V))
43+
return true;
44+
45+
if (Instruction *I = dyn_cast<Instruction>(V)) {
46+
switch(I->getOpcode()) {
47+
case Instruction::ZExt:
48+
case Instruction::SExt:
49+
return true;
50+
default:
51+
return false;
52+
}
53+
}
54+
return false;
55+
}
56+
3957
/// Given an instruction and a container, it fills all the relevant operands of
4058
/// that instruction, with respect to the Trunc expression dag optimizaton.
4159
static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
@@ -53,12 +71,20 @@ static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
5371
case Instruction::And:
5472
case Instruction::Or:
5573
case Instruction::Xor:
74+
case Instruction::ICmp:
5675
Ops.push_back(I->getOperand(0));
5776
Ops.push_back(I->getOperand(1));
5877
break;
5978
case Instruction::Select:
79+
Value *Op0 = I->getOperand(0);
6080
Ops.push_back(I->getOperand(1));
6181
Ops.push_back(I->getOperand(2));
82+
// In case the condition is a compare instruction, that both of its operands
83+
// are a type extension/truncate or a constant, that can be shrinked without
84+
// loosing information in the compare instruction, add them as well.
85+
if (CmpInst *C = dyn_cast<CmpInst>(Op0))
86+
if (isConstOrExt(C->getOperand(0)) && isConstOrExt(C->getOperand(1)))
87+
Ops.push_back(Op0);
6288
break;
6389
default:
6490
llvm_unreachable("Unreachable!");
@@ -119,7 +145,8 @@ bool TruncInstCombine::buildTruncExpressionDag() {
119145
case Instruction::And:
120146
case Instruction::Or:
121147
case Instruction::Xor:
122-
case Instruction::Select: {
148+
case Instruction::Select:
149+
case Instruction::ICmp: {
123150
SmallVector<Value *, 2> Operands;
124151
getRelevantOperands(I, Operands);
125152
for (Value *Operand : Operands)
@@ -139,6 +166,21 @@ bool TruncInstCombine::buildTruncExpressionDag() {
139166
return true;
140167
}
141168

169+
// Get the minimum number of bits needed for the given constant.
170+
static unsigned getConstMinBitWidth(bool IsSigned, ConstantInt *C) {
171+
// If the const value is signed and negative, count the leading ones.
172+
if (IsSigned) {
173+
int64_t Val = C->getSExtValue();
174+
uint64_t UVal = (uint64_t)Val;
175+
if (Val < 0)
176+
return sizeof(UVal)*8 - countLeadingOnes(UVal) + 1;
177+
}
178+
// Otherwise, count leading zeroes.
179+
uint64_t Val = C->getZExtValue();
180+
auto MinBits = sizeof(Val)*8 - countLeadingZeros(Val);
181+
return IsSigned ? MinBits + 1 : MinBits;
182+
}
183+
142184
unsigned TruncInstCombine::getMinBitWidth() {
143185
SmallVector<Value *, 8> Worklist;
144186
SmallVector<Instruction *, 8> Stack;
@@ -180,6 +222,13 @@ unsigned TruncInstCombine::getMinBitWidth() {
180222
if (auto *IOp = dyn_cast<Instruction>(Operand))
181223
Info.MinBitWidth =
182224
std::max(Info.MinBitWidth, InstInfoMap[IOp].MinBitWidth);
225+
else if (auto *C = dyn_cast<ConstantInt>(Operand)) {
226+
// In case of Cmp instruction, make sure the constant can be truncated
227+
// without losing information.
228+
if (CmpInst *Cmp = dyn_cast<CmpInst>(I))
229+
Info.MinBitWidth = std::max(
230+
Info.MinBitWidth, getConstMinBitWidth(Cmp->isSigned(), C));
231+
}
183232
continue;
184233
}
185234

@@ -193,14 +242,27 @@ unsigned TruncInstCombine::getMinBitWidth() {
193242

194243
for (auto *Operand : Operands)
195244
if (auto *IOp = dyn_cast<Instruction>(Operand)) {
196-
// If we already calculated the minimum bit-width for this valid
197-
// bit-width, or for a smaller valid bit-width, then just keep the
198-
// answer we already calculated.
199-
unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth;
200-
if (IOpBitwidth >= ValidBitWidth)
201-
continue;
202-
InstInfoMap[IOp].ValidBitWidth = ValidBitWidth;
203-
Worklist.push_back(IOp);
245+
if (isa<CmpInst>(I)) {
246+
// Cmp instructions kind of resets the valid bits analysis for its
247+
// operands, as it does not continue with the same calculation chain
248+
// but rather creates a new chain of its own.
249+
switch (IOp->getOpcode()) {
250+
case Instruction::SExt:
251+
case Instruction::ZExt:
252+
InstInfoMap[IOp].ValidBitWidth =
253+
cast<CastInst>(IOp)->getSrcTy()->getScalarSizeInBits();
254+
break;
255+
}
256+
} else {
257+
// If we already calculated the minimum bit-width for this valid
258+
// bit-width, or for a smaller valid bit-width, then just keep the
259+
// answer we already calculated.
260+
unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth;
261+
if (IOpBitwidth >= ValidBitWidth)
262+
continue;
263+
InstInfoMap[IOp].ValidBitWidth = ValidBitWidth;
264+
Worklist.push_back(IOp);
265+
}
204266
}
205267
}
206268
unsigned MinBitWidth = InstInfoMap.lookup(cast<Instruction>(Src)).MinBitWidth;
@@ -363,6 +425,13 @@ void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
363425
Res = Builder.CreateSelect(Op0, LHS, RHS);
364426
break;
365427
}
428+
case Instruction::ICmp: {
429+
auto ICmp = cast<ICmpInst>(I);
430+
Value *LHS = getReducedOperand(ICmp->getOperand(0), SclTy);
431+
Value *RHS = getReducedOperand(ICmp->getOperand(1), SclTy);
432+
Res = Builder.CreateICmp(ICmp->getPredicate(), LHS, RHS);
433+
break;
434+
}
366435
default:
367436
llvm_unreachable("Unhandled instruction");
368437
}

0 commit comments

Comments
 (0)