diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 3ddf182149e57..9bbac7ed86308 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -17,6 +17,7 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/PatternMatch.h" @@ -345,6 +346,45 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, Value *&D, Value *&E, std::make_pair(LeftType, RightType)); } +static Instruction *foldBinOpIntoSelectWhenConditionIsOperand(Instruction &I) { + Type *Ty = I.getType(); + Value *CommonInput; + ConstantInt *SelLeft, *SelRight; + bool MatchRes = + match(&I, m_c_BinOp(m_Select(m_Value(CommonInput), m_ConstantInt(SelLeft), + m_ConstantInt(SelRight)), + m_ZExtOrTruncOrSelf(m_Deferred(CommonInput)))); + if (!MatchRes) + return nullptr; + + switch (I.getOpcode()) { + case Instruction::Or: + return SelectInst::Create( + CommonInput, ConstantInt::get(Ty, SelLeft->getValue() | 1), SelRight); + case Instruction::And: + return SelectInst::Create(CommonInput, + ConstantInt::get(Ty, SelLeft->getValue() & 1), + ConstantInt::getNullValue(Ty)); + case Instruction::Xor: + return SelectInst::Create(CommonInput, + ConstantInt::get(Ty, SelLeft->getValue() ^ 1), + ConstantInt::get(Ty, SelRight->getValue() ^ 0)); + case Instruction::Add: + return SelectInst::Create(CommonInput, + ConstantInt::get(Ty, SelLeft->getValue() + 1), + SelRight); + case Instruction::Sub: + return SelectInst::Create(CommonInput, + ConstantInt::get(Ty, SelLeft->getValue() - 1), + SelRight); + case Instruction::Mul: + return SelectInst::Create(CommonInput, + ConstantInt::get(Ty, SelLeft->getValue()), + ConstantInt::getNullValue(Ty)); + } + return nullptr; +} + /// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) into a single /// (icmp(A & X) ==/!= Y), where the left-hand side is of type Mask_NotAllZeros /// and the right hand side is of type BMask_Mixed. For example, @@ -2470,6 +2510,10 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), Y); } + + if (Instruction *Res = foldBinOpIntoSelectWhenConditionIsOperand(I)) + return Res; + // Canonicalize: // (X +/- Y) & Y --> ~X & Y when Y is a power of 2. if (match(&I, m_c_And(m_Value(Y), m_OneUse(m_CombineOr( @@ -4099,6 +4143,10 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { return BinaryOperator::CreateXor(Or, ConstantInt::get(Ty, *CV)); } + if (Instruction *Res = foldBinOpIntoSelectWhenConditionIsOperand(I)) + return Res; + + // If the operands have no common bits set: // or (mul X, Y), X --> add (mul X, Y), X --> mul X, (Y + 1) if (match(&I, m_c_DisjointOr(m_OneUse(m_Mul(m_Value(X), m_Value(Y))), @@ -5196,6 +5244,10 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { return BinaryOperator::CreateXor(XorBC, X); } + if (Instruction *Res = foldBinOpIntoSelectWhenConditionIsOperand(I)) + return Res; + + // Fold (X & M) ^ (Y & ~M) -> (X & M) | (Y & ~M) // This it a special case in haveNoCommonBitsSet, but the computeKnownBits // calls in there are unnecessary as SimplifyDemandedInstructionBits should diff --git a/llvm/test/Transforms/InstCombine/binop-select.ll b/llvm/test/Transforms/InstCombine/binop-select.ll index 25f624ee13412..7c703a036805d 100644 --- a/llvm/test/Transforms/InstCombine/binop-select.ll +++ b/llvm/test/Transforms/InstCombine/binop-select.ll @@ -403,3 +403,146 @@ define i32 @ashr_sel_op1_use(i1 %b) { %r = ashr i32 -2, %s ret i32 %r } + +define i8 @commonArgWithOr0(i1 %arg0) { +; CHECK-LABEL: @commonArgWithOr0( +; CHECK-NEXT: [[V3:%.*]] = select i1 [[ARG0:%.*]], i8 17, i8 24 +; CHECK-NEXT: ret i8 [[V3]] +; + %v0 = zext i1 %arg0 to i8 + %v1 = select i1 %arg0, i8 0, i8 8 + %v2 = or i8 %v1, %v0 + %v3 = or i8 %v2, 16 + ret i8 %v3 +} + +define i8 @commonArgWithOr1(i1 %arg0) { +; CHECK-LABEL: @commonArgWithOr1( +; CHECK-NEXT: [[V3:%.*]] = select i1 [[ARG0:%.*]], i8 17, i8 23 +; CHECK-NEXT: ret i8 [[V3]] +; + %v0 = zext i1 %arg0 to i8 + %v1 = select i1 %arg0, i8 1, i8 7 + %v2 = or i8 %v1, %v0 + %v3 = or i8 %v2, 16 + ret i8 %v3 +} + +define i8 @commonArgWithOr2(i1 %arg0) { +; CHECK-LABEL: @commonArgWithOr2( +; CHECK-NEXT: [[V3:%.*]] = select i1 [[ARG0:%.*]], i8 21, i8 58 +; CHECK-NEXT: ret i8 [[V3]] +; + %v0 = zext i1 %arg0 to i8 + %v1 = select i1 %arg0, i8 21, i8 42 + %v2 = or i8 %v1, %v0 + %v3 = or i8 %v2, 16 + ret i8 %v3 +} + +define i8 @commonArgWithAnd0(i1 %arg0) { +; CHECK-LABEL: @commonArgWithAnd0( +; CHECK-NEXT: ret i8 16 +; + %v0 = zext i1 %arg0 to i8 + %v1 = select i1 %arg0, i8 0, i8 8 + %v2 = and i8 %v1, %v0 + %v3 = or i8 %v2, 16 + ret i8 %v3 +} + +define i8 @commonArgWithAnd1(i1 %arg0) { +; CHECK-LABEL: @commonArgWithAnd1( +; CHECK-NEXT: ret i8 16 +; + %v0 = zext i1 %arg0 to i8 + %v1 = select i1 %arg0, i8 8, i8 1 + %v2 = and i8 %v1, %v0 + %v3 = or i8 %v2, 16 + ret i8 %v3 +} + +define i8 @commonArgWithAnd2(i1 %arg0) { +; CHECK-LABEL: @commonArgWithAnd2( +; CHECK-NEXT: [[V2:%.*]] = zext i1 [[ARG0:%.*]] to i8 +; CHECK-NEXT: [[V3:%.*]] = or disjoint i8 [[V2]], 16 +; CHECK-NEXT: ret i8 [[V3]] +; + %v0 = zext i1 %arg0 to i8 + %v1 = select i1 %arg0, i8 1, i8 7 + %v2 = and i8 %v1, %v0 + %v3 = or i8 %v2, 16 + ret i8 %v3 +} + +define i8 @commonArgWithAnd3(i1 %arg0) { +; CHECK-LABEL: @commonArgWithAnd3( +; CHECK-NEXT: [[V2:%.*]] = zext i1 [[ARG0:%.*]] to i8 +; CHECK-NEXT: [[V3:%.*]] = or disjoint i8 [[V2]], 16 +; CHECK-NEXT: ret i8 [[V3]] +; + %v0 = zext i1 %arg0 to i8 + %v1 = select i1 %arg0, i8 21, i8 42 + %v2 = and i8 %v1, %v0 + %v3 = or i8 %v2, 16 + ret i8 %v3 +} + +define i8 @commonArgWithXor0(i1 %arg0) { +; CHECK-LABEL: @commonArgWithXor0( +; CHECK-NEXT: [[V3:%.*]] = select i1 [[ARG0:%.*]], i8 17, i8 24 +; CHECK-NEXT: ret i8 [[V3]] +; + %v0 = zext i1 %arg0 to i8 + %v1 = select i1 %arg0, i8 0, i8 8 + %v2 = xor i8 %v1, %v0 + %v3 = or i8 %v2, 16 + ret i8 %v3 +} + +define i8 @commonArgWithXor1(i1 %arg0) { +; CHECK-LABEL: @commonArgWithXor1( +; CHECK-NEXT: [[V2:%.*]] = select i1 [[ARG0:%.*]], i8 8, i8 1 +; CHECK-NEXT: ret i8 [[V2]] +; + %v0 = zext i1 %arg0 to i8 + %v1 = select i1 %arg0, i8 9, i8 1 + %v2 = xor i8 %v1, %v0 + ret i8 %v2 +} + +define i8 @commonArgWithXor2(i1 %arg0) { +; CHECK-LABEL: @commonArgWithXor2( +; CHECK-NEXT: [[V3:%.*]] = select i1 [[ARG0:%.*]], i8 16, i8 23 +; CHECK-NEXT: ret i8 [[V3]] +; + %v0 = zext i1 %arg0 to i8 + %v1 = select i1 %arg0, i8 1, i8 7 + %v2 = xor i8 %v1, %v0 + %v3 = or i8 %v2, 16 + ret i8 %v3 +} + +define i8 @commonArgWithXor3(i1 %arg0) { +; CHECK-LABEL: @commonArgWithXor3( +; CHECK-NEXT: [[V3:%.*]] = select i1 [[ARG0:%.*]], i8 20, i8 61 +; CHECK-NEXT: ret i8 [[V3]] +; + %v0 = zext i1 %arg0 to i8 + %v1 = select i1 %arg0, i8 21, i8 45 + %v2 = xor i8 %v1, %v0 + %v3 = or i8 %v2, 16 + ret i8 %v3 +} + +define i8 @commonArgWithAdd0(i1 %arg0) { +; CHECK-LABEL: @commonArgWithXor3( +; CHECK-NEXT: [[V3:%.*]] = select i1 [[ARG0:%.*]], i8 20, i8 61 +; CHECK-NEXT: ret i8 [[V3]] +; + %v0 = zext i1 %arg0 to i8 + %v1 = select i1 %arg0, i8 21, i8 45 + %v2 = add i8 %v1, %v0 + %v3 = or i8 %v2, 16 + ret i8 %v3 +}