@@ -36,6 +36,24 @@ using namespace llvm;
3636
3737#define DEBUG_TYPE " aggressive-instcombine"
3838
39+ // This function returns true if Value V is a constant or if it's a type
40+ // extension node.
41+ static bool isConstOrExt (Value *V) {
42+ if (isa<Constant>(V))
43+ return true ;
44+
45+ if (Instruction *I = dyn_cast<Instruction>(V)) {
46+ switch (I->getOpcode ()) {
47+ case Instruction::ZExt:
48+ case Instruction::SExt:
49+ return true ;
50+ default :
51+ return false ;
52+ }
53+ }
54+ return false ;
55+ }
56+
3957// / Given an instruction and a container, it fills all the relevant operands of
4058// / that instruction, with respect to the Trunc expression dag optimizaton.
4159static void getRelevantOperands (Instruction *I, SmallVectorImpl<Value *> &Ops) {
@@ -53,12 +71,20 @@ static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
5371 case Instruction::And:
5472 case Instruction::Or:
5573 case Instruction::Xor:
74+ case Instruction::ICmp:
5675 Ops.push_back (I->getOperand (0 ));
5776 Ops.push_back (I->getOperand (1 ));
5877 break ;
5978 case Instruction::Select:
79+ Value *Op0 = I->getOperand (0 );
6080 Ops.push_back (I->getOperand (1 ));
6181 Ops.push_back (I->getOperand (2 ));
82+ // In case the condition is a compare instruction, that both of its operands
83+ // are a type extension/truncate or a constant, that can be shrinked without
84+ // loosing information in the compare instruction, add them as well.
85+ if (CmpInst *C = dyn_cast<CmpInst>(Op0))
86+ if (isConstOrExt (C->getOperand (0 )) && isConstOrExt (C->getOperand (1 )))
87+ Ops.push_back (Op0);
6288 break ;
6389 default :
6490 llvm_unreachable (" Unreachable!" );
@@ -119,7 +145,8 @@ bool TruncInstCombine::buildTruncExpressionDag() {
119145 case Instruction::And:
120146 case Instruction::Or:
121147 case Instruction::Xor:
122- case Instruction::Select: {
148+ case Instruction::Select:
149+ case Instruction::ICmp: {
123150 SmallVector<Value *, 2 > Operands;
124151 getRelevantOperands (I, Operands);
125152 for (Value *Operand : Operands)
@@ -139,6 +166,21 @@ bool TruncInstCombine::buildTruncExpressionDag() {
139166 return true ;
140167}
141168
169+ // Get the minimum number of bits needed for the given constant.
170+ static unsigned getConstMinBitWidth (bool IsSigned, ConstantInt *C) {
171+ // If the const value is signed and negative, count the leading ones.
172+ if (IsSigned) {
173+ int64_t Val = C->getSExtValue ();
174+ uint64_t UVal = (uint64_t )Val;
175+ if (Val < 0 )
176+ return sizeof (UVal)*8 - countLeadingOnes (UVal) + 1 ;
177+ }
178+ // Otherwise, count leading zeroes.
179+ uint64_t Val = C->getZExtValue ();
180+ auto MinBits = sizeof (Val)*8 - countLeadingZeros (Val);
181+ return IsSigned ? MinBits + 1 : MinBits;
182+ }
183+
142184unsigned TruncInstCombine::getMinBitWidth () {
143185 SmallVector<Value *, 8 > Worklist;
144186 SmallVector<Instruction *, 8 > Stack;
@@ -180,6 +222,13 @@ unsigned TruncInstCombine::getMinBitWidth() {
180222 if (auto *IOp = dyn_cast<Instruction>(Operand))
181223 Info.MinBitWidth =
182224 std::max (Info.MinBitWidth , InstInfoMap[IOp].MinBitWidth );
225+ else if (auto *C = dyn_cast<ConstantInt>(Operand)) {
226+ // In case of Cmp instruction, make sure the constant can be truncated
227+ // without losing information.
228+ if (CmpInst *Cmp = dyn_cast<CmpInst>(I))
229+ Info.MinBitWidth = std::max (
230+ Info.MinBitWidth , getConstMinBitWidth (Cmp->isSigned (), C));
231+ }
183232 continue ;
184233 }
185234
@@ -193,14 +242,27 @@ unsigned TruncInstCombine::getMinBitWidth() {
193242
194243 for (auto *Operand : Operands)
195244 if (auto *IOp = dyn_cast<Instruction>(Operand)) {
196- // If we already calculated the minimum bit-width for this valid
197- // bit-width, or for a smaller valid bit-width, then just keep the
198- // answer we already calculated.
199- unsigned IOpBitwidth = InstInfoMap.lookup (IOp).ValidBitWidth ;
200- if (IOpBitwidth >= ValidBitWidth)
201- continue ;
202- InstInfoMap[IOp].ValidBitWidth = ValidBitWidth;
203- Worklist.push_back (IOp);
245+ if (isa<CmpInst>(I)) {
246+ // Cmp instructions kind of resets the valid bits analysis for its
247+ // operands, as it does not continue with the same calculation chain
248+ // but rather creates a new chain of its own.
249+ switch (IOp->getOpcode ()) {
250+ case Instruction::SExt:
251+ case Instruction::ZExt:
252+ InstInfoMap[IOp].ValidBitWidth =
253+ cast<CastInst>(IOp)->getSrcTy ()->getScalarSizeInBits ();
254+ break ;
255+ }
256+ } else {
257+ // If we already calculated the minimum bit-width for this valid
258+ // bit-width, or for a smaller valid bit-width, then just keep the
259+ // answer we already calculated.
260+ unsigned IOpBitwidth = InstInfoMap.lookup (IOp).ValidBitWidth ;
261+ if (IOpBitwidth >= ValidBitWidth)
262+ continue ;
263+ InstInfoMap[IOp].ValidBitWidth = ValidBitWidth;
264+ Worklist.push_back (IOp);
265+ }
204266 }
205267 }
206268 unsigned MinBitWidth = InstInfoMap.lookup (cast<Instruction>(Src)).MinBitWidth ;
@@ -363,6 +425,13 @@ void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
363425 Res = Builder.CreateSelect (Op0, LHS, RHS);
364426 break ;
365427 }
428+ case Instruction::ICmp: {
429+ auto ICmp = cast<ICmpInst>(I);
430+ Value *LHS = getReducedOperand (ICmp->getOperand (0 ), SclTy);
431+ Value *RHS = getReducedOperand (ICmp->getOperand (1 ), SclTy);
432+ Res = Builder.CreateICmp (ICmp->getPredicate (), LHS, RHS);
433+ break ;
434+ }
366435 default :
367436 llvm_unreachable (" Unhandled instruction" );
368437 }
0 commit comments