diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h index 1c60eae7f2f85..f54849aef342f 100644 --- a/llvm/include/llvm/IR/InstrTypes.h +++ b/llvm/include/llvm/IR/InstrTypes.h @@ -1744,6 +1744,9 @@ class CallBase : public Instruction { paramHasAttr(ArgNo, Attribute::DereferenceableOrNull); } + /// Drop parameter attributes that may cause this instruction to cause UB. + void dropPoisonGeneratingAndUBImplyingParamAttrs(unsigned ArgNo); + /// Determine if there are is an inalloca argument. Only the last argument can /// have the inalloca attribute. bool hasInAllocaArgument() const { diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h index 3075b7ebae59e..0dee153a33dbe 100644 --- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h +++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h @@ -423,6 +423,14 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner { return &I; } + /// Replace operand of a call-like instruction and add old operand to the + /// worklist. Also drop poison generating and UB implying parameter + /// attributes. + Instruction *replaceArgOperand(CallBase &I, unsigned OpNum, Value *V) { + I.dropPoisonGeneratingAndUBImplyingParamAttrs(OpNum); + return replaceOperand(I, OpNum, V); + } + /// Replace use and add the previously used value to the worklist. void replaceUse(Use &U, Value *NewValue) { Value *OldOp = U; diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp index 05e340ffa20a0..158a172c33111 100644 --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Twine.h" +#include "llvm/IR/AttributeMask.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" @@ -330,6 +331,16 @@ unsigned CallBase::getNumSubclassExtraOperandsDynamic() const { return cast(this)->getNumIndirectDests() + 1; } +void CallBase::dropPoisonGeneratingAndUBImplyingParamAttrs(unsigned ArgNo) { + AttributeMask AM = AttributeFuncs::getUBImplyingAttributes(); + // TODO: Add a helper AttributeFuncs::getPoisonGeneratingAttributes + AM.addAttribute(Attribute::NoFPClass); + AM.addAttribute(Attribute::Range); + AM.addAttribute(Attribute::Alignment); + AM.addAttribute(Attribute::NonNull); + removeParamAttrs(ArgNo, AM); +} + bool CallBase::isIndirectCall() const { const Value *V = getCalledOperand(); if (isa(V) || isa(V)) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 6cff3c7af91e3..717de9c39ab4e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -345,7 +345,7 @@ Instruction *InstCombinerImpl::simplifyMaskedStore(IntrinsicInst &II) { APInt PoisonElts(DemandedElts.getBitWidth(), 0); if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0), DemandedElts, PoisonElts)) - return replaceOperand(II, 0, V); + return replaceArgOperand(II, 0, V); return nullptr; } @@ -430,10 +430,10 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) { APInt PoisonElts(DemandedElts.getBitWidth(), 0); if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0), DemandedElts, PoisonElts)) - return replaceOperand(II, 0, V); + return replaceArgOperand(II, 0, V); if (Value *V = SimplifyDemandedVectorElts(II.getOperand(1), DemandedElts, PoisonElts)) - return replaceOperand(II, 1, V); + return replaceArgOperand(II, 1, V); return nullptr; } @@ -513,11 +513,11 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) { if (IsTZ) { // cttz(-x) -> cttz(x) if (match(Op0, m_Neg(m_Value(X)))) - return IC.replaceOperand(II, 0, X); + return IC.replaceArgOperand(II, 0, X); // cttz(-x & x) -> cttz(x) if (match(Op0, m_c_And(m_Neg(m_Value(X)), m_Deferred(X)))) - return IC.replaceOperand(II, 0, X); + return IC.replaceArgOperand(II, 0, X); // cttz(sext(x)) -> cttz(zext(x)) if (match(Op0, m_OneUse(m_SExt(m_Value(X))))) { @@ -541,10 +541,10 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) { Value *Y; SelectPatternFlavor SPF = matchSelectPattern(Op0, X, Y).Flavor; if (SPF == SPF_ABS || SPF == SPF_NABS) - return IC.replaceOperand(II, 0, X); + return IC.replaceArgOperand(II, 0, X); if (match(Op0, m_Intrinsic(m_Value(X)))) - return IC.replaceOperand(II, 0, X); + return IC.replaceArgOperand(II, 0, X); // cttz(shl(%const, %val), 1) --> add(cttz(%const, 1), %val) if (match(Op0, m_Shl(m_ImmConstant(C), m_Value(X))) && @@ -636,13 +636,13 @@ static Instruction *foldCtpop(IntrinsicInst &II, InstCombinerImpl &IC) { // ctpop(bitreverse(x)) -> ctpop(x) // ctpop(bswap(x)) -> ctpop(x) if (match(Op0, m_BitReverse(m_Value(X))) || match(Op0, m_BSwap(m_Value(X)))) - return IC.replaceOperand(II, 0, X); + return IC.replaceArgOperand(II, 0, X); // ctpop(rot(x)) -> ctpop(x) if ((match(Op0, m_FShl(m_Value(X), m_Value(Y), m_Value())) || match(Op0, m_FShr(m_Value(X), m_Value(Y), m_Value()))) && X == Y) - return IC.replaceOperand(II, 0, X); + return IC.replaceArgOperand(II, 0, X); // ctpop(x | -x) -> bitwidth - cttz(x, false) if (Op0->hasOneUse() && @@ -814,6 +814,15 @@ static CallInst *canonicalizeConstantArg0ToArg1(CallInst &Call) { if (isa(Arg0) && !isa(Arg1)) { Call.setArgOperand(0, Arg1); Call.setArgOperand(1, Arg0); + auto CallAttr = Call.getAttributes(); + auto LHSAttr = CallAttr.getParamAttrs(0); + auto RHSAttr = CallAttr.getParamAttrs(1); + LLVMContext &Ctx = Call.getContext(); + Call.setAttributes( + CallAttr.removeAttributesAtIndex(Ctx, 0) + .removeAttributesAtIndex(Ctx, 1) + .addParamAttributes(Ctx, 0, AttrBuilder(Ctx, RHSAttr)) + .addParamAttributes(Ctx, 1, AttrBuilder(Ctx, LHSAttr))); return &Call; } return nullptr; @@ -929,13 +938,13 @@ Instruction *InstCombinerImpl::foldIntrinsicIsFPClass(IntrinsicInst &II) { // is.fpclass (fneg x), mask -> is.fpclass x, (fneg mask) II.setArgOperand(1, ConstantInt::get(Src1->getType(), fneg(Mask))); - return replaceOperand(II, 0, FNegSrc); + return replaceArgOperand(II, 0, FNegSrc); } Value *FAbsSrc; if (match(Src0, m_FAbs(m_Value(FAbsSrc)))) { II.setArgOperand(1, ConstantInt::get(Src1->getType(), inverse_fabs(Mask))); - return replaceOperand(II, 0, FAbsSrc); + return replaceArgOperand(II, 0, FAbsSrc); } if ((OrderedMask == fcInf || OrderedInvertedMask == fcInf) && @@ -1695,8 +1704,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (II->isCommutative()) { if (auto Pair = matchSymmetricPair(II->getOperand(0), II->getOperand(1))) { - replaceOperand(*II, 0, Pair->first); - replaceOperand(*II, 1, Pair->second); + replaceArgOperand(*II, 0, Pair->first); + replaceArgOperand(*II, 1, Pair->second); return II; } @@ -1733,11 +1742,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // TODO: Copy nsw if it was present on the neg? Value *X; if (match(IIOperand, m_Neg(m_Value(X)))) - return replaceOperand(*II, 0, X); + return replaceArgOperand(*II, 0, X); if (match(IIOperand, m_Select(m_Value(), m_Value(X), m_Neg(m_Deferred(X))))) - return replaceOperand(*II, 0, X); + return replaceArgOperand(*II, 0, X); if (match(IIOperand, m_Select(m_Value(), m_Neg(m_Value(X)), m_Deferred(X)))) - return replaceOperand(*II, 0, X); + return replaceArgOperand(*II, 0, X); Value *Y; // abs(a * abs(b)) -> abs(a * b) @@ -1747,7 +1756,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { bool NSW = cast(IIOperand)->hasNoSignedWrap() && IntMinIsPoison; auto *XY = NSW ? Builder.CreateNSWMul(X, Y) : Builder.CreateMul(X, Y); - return replaceOperand(*II, 0, XY); + return replaceArgOperand(*II, 0, XY); } if (std::optional Known = @@ -2122,7 +2131,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { match(II->getArgOperand(0), m_FAbs(m_Value(X))) || match(II->getArgOperand(0), m_Intrinsic(m_Value(X), m_Value()))) - return replaceOperand(*II, 0, X); + return replaceArgOperand(*II, 0, X); } } break; @@ -2152,7 +2161,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (!ModuloC) return nullptr; if (ModuloC != ShAmtC) - return replaceOperand(*II, 2, ModuloC); + return replaceArgOperand(*II, 2, ModuloC); assert(match(ConstantFoldCompareInstOperands(ICmpInst::ICMP_UGT, WidthC, ShAmtC, DL), @@ -2234,8 +2243,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // TODO: If InnerMask == Op1, we could copy attributes from inner // callsite -> outer callsite. Value *NewMask = Builder.CreateAnd(II->getArgOperand(1), InnerMask); - replaceOperand(CI, 0, InnerPtr); - replaceOperand(CI, 1, NewMask); + replaceArgOperand(CI, 0, InnerPtr); + replaceArgOperand(CI, 1, NewMask); Changed = true; } @@ -2520,8 +2529,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Value *A, *B; if (match(II->getArgOperand(0), m_FNeg(m_Value(A))) && match(II->getArgOperand(1), m_FNeg(m_Value(B)))) { - replaceOperand(*II, 0, A); - replaceOperand(*II, 1, B); + replaceArgOperand(*II, 0, A); + replaceArgOperand(*II, 1, B); return II; } @@ -2556,8 +2565,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (ElementCount::isKnownGT(NegatedCount, OtherCount) && ElementCount::isKnownLT(OtherCount, RetCount)) { Value *InverseOtherOp = Builder.CreateFNeg(OtherOp); - replaceOperand(*II, NegatedOpArg, OpNotNeg); - replaceOperand(*II, OtherOpArg, InverseOtherOp); + replaceArgOperand(*II, NegatedOpArg, OpNotNeg); + replaceArgOperand(*II, OtherOpArg, InverseOtherOp); return II; } // (-A) * B -> -(A * B), if it is cheaper to negate the result @@ -2589,16 +2598,16 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Value *Src2 = II->getArgOperand(2); Value *X, *Y; if (match(Src0, m_FNeg(m_Value(X))) && match(Src1, m_FNeg(m_Value(Y)))) { - replaceOperand(*II, 0, X); - replaceOperand(*II, 1, Y); + replaceArgOperand(*II, 0, X); + replaceArgOperand(*II, 1, Y); return II; } // fma fabs(x), fabs(x), z -> fma x, x, z if (match(Src0, m_FAbs(m_Value(X))) && match(Src1, m_FAbs(m_Specific(X)))) { - replaceOperand(*II, 0, X); - replaceOperand(*II, 1, X); + replaceArgOperand(*II, 0, X); + replaceArgOperand(*II, 1, X); return II; } @@ -2645,7 +2654,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // copysign Mag, (copysign ?, X) --> copysign Mag, X Value *X; if (match(Sign, m_Intrinsic(m_Value(), m_Value(X)))) - return replaceOperand(*II, 1, X); + return replaceArgOperand(*II, 1, X); // Clear sign-bit of constant magnitude: // copysign -MagC, X --> copysign MagC, X @@ -2654,14 +2663,15 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (match(Mag, m_APFloat(MagC)) && MagC->isNegative()) { APFloat PosMagC = *MagC; PosMagC.clearSign(); - return replaceOperand(*II, 0, ConstantFP::get(Mag->getType(), PosMagC)); + return replaceArgOperand(*II, 0, + ConstantFP::get(Mag->getType(), PosMagC)); } // Peek through changes of magnitude's sign-bit. This call rewrites those: // copysign (fabs X), Sign --> copysign X, Sign // copysign (fneg X), Sign --> copysign X, Sign if (match(Mag, m_FAbs(m_Value(X))) || match(Mag, m_FNeg(m_Value(X)))) - return replaceOperand(*II, 0, X); + return replaceArgOperand(*II, 0, X); break; } @@ -2689,10 +2699,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } // fabs (select Cond, -FVal, FVal) --> fabs FVal if (match(TVal, m_FNeg(m_Specific(FVal)))) - return replaceOperand(*II, 0, FVal); + return replaceArgOperand(*II, 0, FVal); // fabs (select Cond, TVal, -TVal) --> fabs TVal if (match(FVal, m_FNeg(m_Specific(TVal)))) - return replaceOperand(*II, 0, TVal); + return replaceArgOperand(*II, 0, TVal); } Value *Magnitude, *Sign; @@ -2731,7 +2741,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // cos(-x) --> cos(x) // cos(fabs(x)) --> cos(x) // cos(copysign(x, y)) --> cos(x) - return replaceOperand(*II, 0, X); + return replaceArgOperand(*II, 0, X); } break; } @@ -2774,8 +2784,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // width. Value *NewExp = Builder.CreateAdd(InnerExp, Exp); II->setArgOperand(1, NewExp); + II->dropPoisonGeneratingAndUBImplyingParamAttrs(1); II->setFastMathFlags(InnerFlags); // Or the inner flags. - return replaceOperand(*II, 0, InnerSrc); + return replaceArgOperand(*II, 0, InnerSrc); } } @@ -3461,11 +3472,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Value *Arg = II->getArgOperand(0); Value *Vect; - if (Value *NewOp = - simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) { - replaceUse(II->getOperandUse(0), NewOp); - return II; - } + if (Value *NewOp = simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) + return replaceArgOperand(*II, 0, NewOp); if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { if (auto *FTy = dyn_cast(Vect->getType())) @@ -3501,8 +3509,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (Value *NewOp = simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) { - replaceUse(II->getOperandUse(0), NewOp); - return II; + return replaceArgOperand(*II, 0, NewOp); } if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { @@ -3535,10 +3542,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Value *Vect; if (Value *NewOp = - simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) { - replaceUse(II->getOperandUse(0), NewOp); - return II; - } + simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) + return replaceArgOperand(*II, 0, NewOp); if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { if (auto *VTy = dyn_cast(Vect->getType())) @@ -3566,8 +3571,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (Value *NewOp = simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) { - replaceUse(II->getOperandUse(0), NewOp); - return II; + return replaceArgOperand(*II, 0, NewOp); } if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { @@ -3597,8 +3601,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (Value *NewOp = simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) { - replaceUse(II->getOperandUse(0), NewOp); - return II; + return replaceArgOperand(*II, 0, NewOp); } if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { @@ -3639,8 +3642,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (Value *NewOp = simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) { - replaceUse(II->getOperandUse(0), NewOp); - return II; + return replaceArgOperand(*II, 0, NewOp); } if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { @@ -3674,8 +3676,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { : 0; Value *Arg = II->getArgOperand(ArgIdx); if (Value *NewOp = simplifyReductionOperand(Arg, CanReorderLanes)) { - replaceUse(II->getOperandUse(ArgIdx), NewOp); - return nullptr; + return replaceArgOperand(*II, ArgIdx, NewOp); } break; } diff --git a/llvm/test/Transforms/InstCombine/abs-intrinsic.ll b/llvm/test/Transforms/InstCombine/abs-intrinsic.ll index 022d60d2f501b..a082375c778f5 100644 --- a/llvm/test/Transforms/InstCombine/abs-intrinsic.ll +++ b/llvm/test/Transforms/InstCombine/abs-intrinsic.ll @@ -227,6 +227,16 @@ define i32 @abs_of_neg(i32 %x) { ret i32 %b } +define i32 @abs_of_neg_range(i32 %x) { +; CHECK-LABEL: @abs_of_neg_range( +; CHECK-NEXT: [[B:%.*]] = call i32 @llvm.abs.i32(i32 [[X:%.*]], i1 false) +; CHECK-NEXT: ret i32 [[B]] +; + %a = sub i32 0, %x + %b = call i32 @llvm.abs.i32(i32 range(i32 -10, 0) %a, i1 false) + ret i32 %b +} + define <4 x i32> @abs_of_neg_vec(<4 x i32> %x) { ; CHECK-LABEL: @abs_of_neg_vec( ; CHECK-NEXT: [[B:%.*]] = call <4 x i32> @llvm.abs.v4i32(<4 x i32> [[X:%.*]], i1 false) diff --git a/llvm/test/Transforms/InstCombine/minmax-fold.ll b/llvm/test/Transforms/InstCombine/minmax-fold.ll index ccdf4400b16b5..21474a1faffdf 100644 --- a/llvm/test/Transforms/InstCombine/minmax-fold.ll +++ b/llvm/test/Transforms/InstCombine/minmax-fold.ll @@ -1541,3 +1541,14 @@ define <2 x i32> @test_umax_smax_vec_neg(<2 x i32> %x) { %umax = call <2 x i32> @llvm.umax.v2i32(<2 x i32> %smax, <2 x i32> ) ret <2 x i32> %umax } + +; Make sure that poison-generating/UB-implying parameters are dropped + +define i32 @umax_commute_operand_drop_attrs(i32 %x) { +; CHECK-LABEL: @umax_commute_operand_drop_attrs( +; CHECK-NEXT: [[RET:%.*]] = call i32 @llvm.umax.i32(i32 [[X:%.*]], i32 noundef range(i32 -10, -8) -10) +; CHECK-NEXT: ret i32 [[RET]] +; + %ret = call range(i32 -10, 0) i32 @llvm.umax.i32(i32 noundef range(i32 -10, -8) -10, i32 %x) + ret i32 %ret +}