From 0019711079e7d929b1853748d0f84c22adb04a62 Mon Sep 17 00:00:00 2001 From: Jeffrey Byrnes Date: Thu, 17 Apr 2025 10:11:18 -0700 Subject: [PATCH 1/5] [InstCombine] Extend bitmask->select combine to match and->mul Change-Id: I1cc2acd3804dde50636518f3ef2c9581848ae9f6 --- .../InstCombine/InstCombineAndOrXor.cpp | 122 ++++++++++++------ .../test/Transforms/InstCombine/or-bitmask.ll | 95 ++++++++++++-- 2 files changed, 163 insertions(+), 54 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 59b46ebdb72e2..ea166717d5c05 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -3593,6 +3593,72 @@ static Value *foldOrOfInversions(BinaryOperator &I, return nullptr; } +struct DecomposedBitMaskMul { + Value *X; + APInt Factor; + APInt Mask; +}; + +static std::optional matchBitmaskMul(Value *V) { + Instruction *Op = dyn_cast(V); + if (!Op) + return std::nullopt; + + Value *MulOp = nullptr; + const APInt *MulConst = nullptr; + if (match(Op, m_Mul(m_Value(MulOp), m_APInt(MulConst)))) { + Value *Original = nullptr; + const APInt *Mask = nullptr; + if (!MulConst->isStrictlyPositive()) + return std::nullopt; + + if (match(MulOp, m_And(m_Value(Original), m_APInt(Mask)))) { + if (!Mask->isStrictlyPositive()) + return std::nullopt; + DecomposedBitMaskMul Ret; + Ret.X = Original; + Ret.Mask = *Mask; + Ret.Factor = *MulConst; + return Ret; + } + return std::nullopt; + } + + Value *Cond = nullptr; + const APInt *EqZero = nullptr, *NeZero = nullptr; + + // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C + if (match(Op, m_Select(m_Value(Cond), m_APInt(EqZero), m_APInt(NeZero)))) { + auto ICmpDecompose = + decomposeBitTest(Cond, /*LookThruTrunc=*/true, + /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true); + if (!ICmpDecompose.has_value()) + return std::nullopt; + + if (ICmpDecompose->Pred == ICmpInst::ICMP_NE) + std::swap(EqZero, NeZero); + + if (!EqZero->isZero() || !NeZero->isStrictlyPositive()) + return std::nullopt; + + if (!ICmpInst::isEquality(ICmpDecompose->Pred) || + !ICmpDecompose->C.isZero() || !ICmpDecompose->Mask.isPowerOf2() || + ICmpDecompose->Mask.isNegative()) + return std::nullopt; + + if (!NeZero->urem(ICmpDecompose->Mask).isZero()) + return std::nullopt; + + DecomposedBitMaskMul Ret; + Ret.X = ICmpDecompose->X; + Ret.Mask = ICmpDecompose->Mask; + Ret.Factor = NeZero->udiv(ICmpDecompose->Mask); + return Ret; + } + + return std::nullopt; +} + // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. @@ -3675,49 +3741,19 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { /*NSW=*/true, /*NUW=*/true)) return R; - Value *Cond0 = nullptr, *Cond1 = nullptr; - const APInt *Op0Eq = nullptr, *Op0Ne = nullptr; - const APInt *Op1Eq = nullptr, *Op1Ne = nullptr; - - // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C - if (match(I.getOperand(0), - m_Select(m_Value(Cond0), m_APInt(Op0Eq), m_APInt(Op0Ne))) && - match(I.getOperand(1), - m_Select(m_Value(Cond1), m_APInt(Op1Eq), m_APInt(Op1Ne)))) { - - auto LHSDecompose = - decomposeBitTest(Cond0, /*LookThruTrunc=*/true, - /*AllowNonZeroC=*/false, /*DecomposeAnd=*/true); - auto RHSDecompose = - decomposeBitTest(Cond1, /*LookThruTrunc=*/true, - /*AllowNonZeroC=*/false, /*DecomposeAnd=*/true); - - if (LHSDecompose && RHSDecompose && LHSDecompose->X == RHSDecompose->X && - RHSDecompose->Mask.isPowerOf2() && LHSDecompose->Mask.isPowerOf2() && - LHSDecompose->Mask != RHSDecompose->Mask && - LHSDecompose->Mask.getBitWidth() == Op0Ne->getBitWidth() && - RHSDecompose->Mask.getBitWidth() == Op1Ne->getBitWidth()) { - assert(Op0Ne->getBitWidth() == Op1Ne->getBitWidth()); - assert(ICmpInst::isEquality(LHSDecompose->Pred)); - if (LHSDecompose->Pred == ICmpInst::ICMP_NE) - std::swap(Op0Eq, Op0Ne); - if (RHSDecompose->Pred == ICmpInst::ICMP_NE) - std::swap(Op1Eq, Op1Ne); - - if (!Op0Ne->isZero() && !Op1Ne->isZero() && Op0Eq->isZero() && - Op1Eq->isZero() && Op0Ne->urem(LHSDecompose->Mask).isZero() && - Op1Ne->urem(RHSDecompose->Mask).isZero() && - Op0Ne->udiv(LHSDecompose->Mask) == - Op1Ne->udiv(RHSDecompose->Mask)) { - auto NewAnd = Builder.CreateAnd( - LHSDecompose->X, - ConstantInt::get(LHSDecompose->X->getType(), - (LHSDecompose->Mask + RHSDecompose->Mask))); - - return BinaryOperator::CreateMul( - NewAnd, ConstantInt::get(NewAnd->getType(), - Op0Ne->udiv(LHSDecompose->Mask))); - } + auto Decomp0 = matchBitmaskMul(I.getOperand(0)); + auto Decomp1 = matchBitmaskMul(I.getOperand(1)); + + if (Decomp0 && Decomp1) { + if (Decomp0->X == Decomp1->X && + (Decomp0->Mask & Decomp1->Mask).isZero() && + Decomp0->Factor == Decomp1->Factor) { + auto NewAnd = Builder.CreateAnd( + Decomp0->X, ConstantInt::get(Decomp0->X->getType(), + (Decomp0->Mask + Decomp1->Mask))); + + return BinaryOperator::CreateMul( + NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor)); } } } diff --git a/llvm/test/Transforms/InstCombine/or-bitmask.ll b/llvm/test/Transforms/InstCombine/or-bitmask.ll index 3b482dc1794db..87f0bbf4d37ab 100644 --- a/llvm/test/Transforms/InstCombine/or-bitmask.ll +++ b/llvm/test/Transforms/InstCombine/or-bitmask.ll @@ -36,13 +36,9 @@ define i32 @add_select_cmp_and2(i32 %in) { define i32 @add_select_cmp_and3(i32 %in) { ; CHECK-LABEL: @add_select_cmp_and3( -; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3 +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 7 ; CHECK-NEXT: [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72 -; CHECK-NEXT: [[BITOP2:%.*]] = and i32 [[IN]], 4 -; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i32 [[BITOP2]], 0 -; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i32 0, i32 288 -; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TEMP]], [[SEL2]] -; CHECK-NEXT: ret i32 [[OUT]] +; CHECK-NEXT: ret i32 [[TEMP]] ; %bitop0 = and i32 %in, 1 %cmp0 = icmp eq i32 %bitop0, 0 @@ -60,12 +56,9 @@ define i32 @add_select_cmp_and3(i32 %in) { define i32 @add_select_cmp_and4(i32 %in) { ; CHECK-LABEL: @add_select_cmp_and4( -; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3 -; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72 -; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[IN]], 12 +; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[IN:%.*]], 15 ; CHECK-NEXT: [[TEMP3:%.*]] = mul nuw nsw i32 [[TMP2]], 72 -; CHECK-NEXT: [[OUT1:%.*]] = or disjoint i32 [[OUT]], [[TEMP3]] -; CHECK-NEXT: ret i32 [[OUT1]] +; CHECK-NEXT: ret i32 [[TEMP3]] ; %bitop0 = and i32 %in, 1 %cmp0 = icmp eq i32 %bitop0, 0 @@ -361,6 +354,86 @@ define i64 @mask_select_types_1(i64 %in) { ret i64 %out } +define i32 @add_select_cmp_mixed1(i32 %in) { +; CHECK-LABEL: @add_select_cmp_mixed1( +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3 +; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72 +; CHECK-NEXT: ret i32 [[OUT]] +; + %mask = and i32 %in, 1 + %sel0 = mul i32 %mask, 72 + %bitop1 = and i32 %in, 2 + %cmp1 = icmp eq i32 %bitop1, 0 + %sel1 = select i1 %cmp1, i32 0, i32 144 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + +define i32 @add_select_cmp_mixed2(i32 %in) { +; CHECK-LABEL: @add_select_cmp_mixed2( +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3 +; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72 +; CHECK-NEXT: ret i32 [[OUT]] +; + %bitop0 = and i32 %in, 1 + %cmp0 = icmp eq i32 %bitop0, 0 + %mask = and i32 %in, 2 + %sel0 = select i1 %cmp0, i32 0, i32 72 + %sel1 = mul i32 %mask, 72 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + +define i32 @add_select_cmp_and_mul(i32 %in) { +; CHECK-LABEL: @add_select_cmp_and_mul( +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3 +; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72 +; CHECK-NEXT: ret i32 [[OUT]] +; + %mask0 = and i32 %in, 1 + %sel0 = mul i32 %mask0, 72 + %mask1 = and i32 %in, 2 + %sel1 = mul i32 %mask1, 72 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + +define i32 @add_select_cmp_mixed2_mismatch(i32 %in) { +; CHECK-LABEL: @add_select_cmp_mixed2_mismatch( +; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1 +; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0 +; CHECK-NEXT: [[MASK:%.*]] = and i32 [[IN]], 2 +; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 73 +; CHECK-NEXT: [[SEL1:%.*]] = mul nuw nsw i32 [[MASK]], 72 +; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]] +; CHECK-NEXT: ret i32 [[OUT]] +; + %bitop0 = and i32 %in, 1 + %cmp0 = icmp eq i32 %bitop0, 0 + %mask = and i32 %in, 2 + %sel0 = select i1 %cmp0, i32 0, i32 73 + %sel1 = mul i32 %mask, 72 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + +define i32 @add_select_cmp_and_mul_mismatch(i32 %in) { +; CHECK-LABEL: @add_select_cmp_and_mul_mismatch( +; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[IN:%.*]] to i1 +; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[TMP1]], i32 73, i32 0 +; CHECK-NEXT: [[MASK1:%.*]] = and i32 [[IN]], 2 +; CHECK-NEXT: [[SEL1:%.*]] = mul nuw nsw i32 [[MASK1]], 72 +; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]] +; CHECK-NEXT: ret i32 [[OUT]] +; + %mask0 = and i32 %in, 1 + %sel0 = mul i32 %mask0, 73 + %mask1 = and i32 %in, 2 + %sel1 = mul i32 %mask1, 72 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line: ; CONSTSPLAT: {{.*}} ; CONSTVEC: {{.*}} From 7b63d9b172597da44200f8718a2e3816e436e686 Mon Sep 17 00:00:00 2001 From: Jeffrey Byrnes Date: Thu, 22 May 2025 11:06:24 -0700 Subject: [PATCH 2/5] Review comments + fix some conditions Change-Id: I4b71adfd8bffdda4d2b0d1cba85a3fd73a105a28 --- .../InstCombine/InstCombineAndOrXor.cpp | 52 ++++++++++++------- .../test/Transforms/InstCombine/or-bitmask.ll | 8 +-- 2 files changed, 36 insertions(+), 24 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index ea166717d5c05..62ff45fb24379 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -3593,10 +3593,16 @@ static Value *foldOrOfInversions(BinaryOperator &I, return nullptr; } +// A decomposition of ((A & N) ? 0 : N * C) . Where X = A, Factor = C, Mask = N. +// The NUW / NSW bools +// Note that we can decompose equivalent forms of this expression (e.g. ((A & N) +// * C)) struct DecomposedBitMaskMul { Value *X; APInt Factor; APInt Mask; + bool NUW; + bool NSW; }; static std::optional matchBitmaskMul(Value *V) { @@ -3606,20 +3612,21 @@ static std::optional matchBitmaskMul(Value *V) { Value *MulOp = nullptr; const APInt *MulConst = nullptr; + + // Decompose (A & N) * C) into BitMaskMul if (match(Op, m_Mul(m_Value(MulOp), m_APInt(MulConst)))) { Value *Original = nullptr; const APInt *Mask = nullptr; - if (!MulConst->isStrictlyPositive()) + if (MulConst->isZero()) return std::nullopt; if (match(MulOp, m_And(m_Value(Original), m_APInt(Mask)))) { - if (!Mask->isStrictlyPositive()) + if (Mask->isZero()) return std::nullopt; - DecomposedBitMaskMul Ret; - Ret.X = Original; - Ret.Mask = *Mask; - Ret.Factor = *MulConst; - return Ret; + return std::optional( + {Original, *MulConst, *Mask, + cast(Op)->hasNoUnsignedWrap(), + cast(Op)->hasNoSignedWrap()}); } return std::nullopt; } @@ -3627,7 +3634,7 @@ static std::optional matchBitmaskMul(Value *V) { Value *Cond = nullptr; const APInt *EqZero = nullptr, *NeZero = nullptr; - // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C + // Decompose ((A & N) ? 0 : N * C) into BitMaskMul if (match(Op, m_Select(m_Value(Cond), m_APInt(EqZero), m_APInt(NeZero)))) { auto ICmpDecompose = decomposeBitTest(Cond, /*LookThruTrunc=*/true, @@ -3638,22 +3645,20 @@ static std::optional matchBitmaskMul(Value *V) { if (ICmpDecompose->Pred == ICmpInst::ICMP_NE) std::swap(EqZero, NeZero); - if (!EqZero->isZero() || !NeZero->isStrictlyPositive()) + if (!EqZero->isZero() || NeZero->isZero()) return std::nullopt; if (!ICmpInst::isEquality(ICmpDecompose->Pred) || !ICmpDecompose->C.isZero() || !ICmpDecompose->Mask.isPowerOf2() || - ICmpDecompose->Mask.isNegative()) + ICmpDecompose->Mask.isZero()) return std::nullopt; if (!NeZero->urem(ICmpDecompose->Mask).isZero()) return std::nullopt; - DecomposedBitMaskMul Ret; - Ret.X = ICmpDecompose->X; - Ret.Mask = ICmpDecompose->Mask; - Ret.Factor = NeZero->udiv(ICmpDecompose->Mask); - return Ret; + return std::optional( + {ICmpDecompose->X, NeZero->udiv(ICmpDecompose->Mask), + ICmpDecompose->Mask, /*NUW=*/false, /*NSW=*/false}); } return std::nullopt; @@ -3741,19 +3746,26 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { /*NSW=*/true, /*NUW=*/true)) return R; - auto Decomp0 = matchBitmaskMul(I.getOperand(0)); + // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C + // This also accepts the equivalent mul form of (A & N) ? 0 : N * C) + // expressions i.e. (A & N) * C auto Decomp1 = matchBitmaskMul(I.getOperand(1)); - - if (Decomp0 && Decomp1) { - if (Decomp0->X == Decomp1->X && + if (Decomp1) { + auto Decomp0 = matchBitmaskMul(I.getOperand(0)); + if (Decomp0 && Decomp0->X == Decomp1->X && (Decomp0->Mask & Decomp1->Mask).isZero() && Decomp0->Factor == Decomp1->Factor) { + auto NewAnd = Builder.CreateAnd( Decomp0->X, ConstantInt::get(Decomp0->X->getType(), (Decomp0->Mask + Decomp1->Mask))); - return BinaryOperator::CreateMul( + auto Combined = BinaryOperator::CreateMul( NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor)); + + Combined->setHasNoUnsignedWrap(Decomp0->NUW && Decomp1->NUW); + Combined->setHasNoSignedWrap(Decomp0->NSW && Decomp1->NSW); + return Combined; } } } diff --git a/llvm/test/Transforms/InstCombine/or-bitmask.ll b/llvm/test/Transforms/InstCombine/or-bitmask.ll index 87f0bbf4d37ab..dcfbe171dd08f 100644 --- a/llvm/test/Transforms/InstCombine/or-bitmask.ll +++ b/llvm/test/Transforms/InstCombine/or-bitmask.ll @@ -37,8 +37,8 @@ define i32 @add_select_cmp_and2(i32 %in) { define i32 @add_select_cmp_and3(i32 %in) { ; CHECK-LABEL: @add_select_cmp_and3( ; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 7 -; CHECK-NEXT: [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72 -; CHECK-NEXT: ret i32 [[TEMP]] +; CHECK-NEXT: [[TEMP1:%.*]] = mul nuw nsw i32 [[TMP1]], 72 +; CHECK-NEXT: ret i32 [[TEMP1]] ; %bitop0 = and i32 %in, 1 %cmp0 = icmp eq i32 %bitop0, 0 @@ -57,8 +57,8 @@ define i32 @add_select_cmp_and3(i32 %in) { define i32 @add_select_cmp_and4(i32 %in) { ; CHECK-LABEL: @add_select_cmp_and4( ; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[IN:%.*]], 15 -; CHECK-NEXT: [[TEMP3:%.*]] = mul nuw nsw i32 [[TMP2]], 72 -; CHECK-NEXT: ret i32 [[TEMP3]] +; CHECK-NEXT: [[TEMP2:%.*]] = mul nuw nsw i32 [[TMP2]], 72 +; CHECK-NEXT: ret i32 [[TEMP2]] ; %bitop0 = and i32 %in, 1 %cmp0 = icmp eq i32 %bitop0, 0 From 5fa229ba2432d00512a7d58c3ffa7ec610ee4aa6 Mon Sep 17 00:00:00 2001 From: Jeffrey Byrnes Date: Tue, 27 May 2025 11:03:46 -0700 Subject: [PATCH 3/5] Fix crash due to mismatch APInt bitwidth Change-Id: I12f77aedbf1a2edfe63e4d03cd1e5c1c601365a7 --- llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 62ff45fb24379..e357e3d296cc1 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -3650,7 +3650,8 @@ static std::optional matchBitmaskMul(Value *V) { if (!ICmpInst::isEquality(ICmpDecompose->Pred) || !ICmpDecompose->C.isZero() || !ICmpDecompose->Mask.isPowerOf2() || - ICmpDecompose->Mask.isZero()) + ICmpDecompose->Mask.isZero() || + NeZero->getBitWidth() != ICmpDecompose->Mask.getBitWidth()) return std::nullopt; if (!NeZero->urem(ICmpDecompose->Mask).isZero()) From 9ccf1fa021df9068cd071493942b7c718dd8ad29 Mon Sep 17 00:00:00 2001 From: Jeffrey Byrnes Date: Thu, 5 Jun 2025 09:40:16 -0700 Subject: [PATCH 4/5] Review comments Change-Id: I56a280990a9bae36e59f784a7f48bdbc9f7ca539 --- .../InstCombine/InstCombineAndOrXor.cpp | 37 ++++++++----------- .../test/Transforms/InstCombine/or-bitmask.ll | 17 +++++++++ 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index e357e3d296cc1..de029be1d28ce 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -3593,10 +3593,9 @@ static Value *foldOrOfInversions(BinaryOperator &I, return nullptr; } -// A decomposition of ((A & N) ? 0 : N * C) . Where X = A, Factor = C, Mask = N. -// The NUW / NSW bools -// Note that we can decompose equivalent forms of this expression (e.g. ((A & N) -// * C)) +// A decomposition of ((X & Mask) ? 0 : Mask * Factor) . The NUW / NSW bools +// track these properities for preservation. Note that we can decompose +// equivalent forms of this expression (e.g. ((X & Mask) * Factor)) struct DecomposedBitMaskMul { Value *X; APInt Factor; @@ -3610,25 +3609,20 @@ static std::optional matchBitmaskMul(Value *V) { if (!Op) return std::nullopt; - Value *MulOp = nullptr; const APInt *MulConst = nullptr; // Decompose (A & N) * C) into BitMaskMul - if (match(Op, m_Mul(m_Value(MulOp), m_APInt(MulConst)))) { - Value *Original = nullptr; - const APInt *Mask = nullptr; - if (MulConst->isZero()) + Value *Original = nullptr; + const APInt *Mask = nullptr; + if (match(Op, m_Mul(m_And(m_Value(Original), m_APInt(Mask)), + m_APInt(MulConst)))) { + if (MulConst->isZero() || Mask->isZero()) return std::nullopt; - if (match(MulOp, m_And(m_Value(Original), m_APInt(Mask)))) { - if (Mask->isZero()) - return std::nullopt; - return std::optional( - {Original, *MulConst, *Mask, - cast(Op)->hasNoUnsignedWrap(), - cast(Op)->hasNoSignedWrap()}); - } - return std::nullopt; + return std::optional( + {Original, *MulConst, *Mask, + cast(Op)->hasNoUnsignedWrap(), + cast(Op)->hasNoSignedWrap()}); } Value *Cond = nullptr; @@ -3642,15 +3636,16 @@ static std::optional matchBitmaskMul(Value *V) { if (!ICmpDecompose.has_value()) return std::nullopt; + assert(ICmpInst::isEquality(ICmpDecompose->Pred) && + ICmpDecompose->C.isZero()); + if (ICmpDecompose->Pred == ICmpInst::ICMP_NE) std::swap(EqZero, NeZero); if (!EqZero->isZero() || NeZero->isZero()) return std::nullopt; - if (!ICmpInst::isEquality(ICmpDecompose->Pred) || - !ICmpDecompose->C.isZero() || !ICmpDecompose->Mask.isPowerOf2() || - ICmpDecompose->Mask.isZero() || + if (!ICmpDecompose->Mask.isPowerOf2() || ICmpDecompose->Mask.isZero() || NeZero->getBitWidth() != ICmpDecompose->Mask.getBitWidth()) return std::nullopt; diff --git a/llvm/test/Transforms/InstCombine/or-bitmask.ll b/llvm/test/Transforms/InstCombine/or-bitmask.ll index dcfbe171dd08f..3c992dfea569a 100644 --- a/llvm/test/Transforms/InstCombine/or-bitmask.ll +++ b/llvm/test/Transforms/InstCombine/or-bitmask.ll @@ -434,6 +434,23 @@ define i32 @add_select_cmp_and_mul_mismatch(i32 %in) { ret i32 %out } +define i32 @and_mul_non_disjoint(i32 %in) { +; CHECK-LABEL: @and_mul_non_disjoint( +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 2 +; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72 +; CHECK-NEXT: [[MASK1:%.*]] = and i32 [[IN]], 4 +; CHECK-NEXT: [[SEL1:%.*]] = mul nuw nsw i32 [[MASK1]], 72 +; CHECK-NEXT: [[OUT1:%.*]] = or i32 [[OUT]], [[SEL1]] +; CHECK-NEXT: ret i32 [[OUT1]] +; + %mask0 = and i32 %in, 2 + %sel0 = mul i32 %mask0, 72 + %mask1 = and i32 %in, 4 + %sel1 = mul i32 %mask1, 72 + %out = or i32 %sel0, %sel1 + ret i32 %out +} + ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line: ; CONSTSPLAT: {{.*}} ; CONSTVEC: {{.*}} From acd7e8b406af29f56ef9abfcde14cdec08ecaa00 Mon Sep 17 00:00:00 2001 From: Jeffrey Byrnes Date: Wed, 11 Jun 2025 15:31:26 -0700 Subject: [PATCH 5/5] Review comments 1 Change-Id: I04ff0637b85922561dda9e7e827ba3fe9d9c0cbc --- .../InstCombine/InstCombineAndOrXor.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index de029be1d28ce..292490ea1fca2 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -3593,9 +3593,10 @@ static Value *foldOrOfInversions(BinaryOperator &I, return nullptr; } -// A decomposition of ((X & Mask) ? 0 : Mask * Factor) . The NUW / NSW bools +// A decomposition of ((X & Mask) * Factor). The NUW / NSW bools // track these properities for preservation. Note that we can decompose -// equivalent forms of this expression (e.g. ((X & Mask) * Factor)) +// equivalent select form of this expression (e.g. (!(X & Mask) ? 0 : Mask * +// Factor)) struct DecomposedBitMaskMul { Value *X; APInt Factor; @@ -3609,11 +3610,10 @@ static std::optional matchBitmaskMul(Value *V) { if (!Op) return std::nullopt; - const APInt *MulConst = nullptr; - // Decompose (A & N) * C) into BitMaskMul Value *Original = nullptr; const APInt *Mask = nullptr; + const APInt *MulConst = nullptr; if (match(Op, m_Mul(m_And(m_Value(Original), m_APInt(Mask)), m_APInt(MulConst)))) { if (MulConst->isZero() || Mask->isZero()) @@ -3742,9 +3742,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { /*NSW=*/true, /*NUW=*/true)) return R; - // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C - // This also accepts the equivalent mul form of (A & N) ? 0 : N * C) - // expressions i.e. (A & N) * C + // (A & N) * C + (A & M) * C -> (A & (N + M)) & C + // This also accepts the equivalent select form of (A & N) * C + // expressions i.e. !(A & N) ? 0 : N * C) auto Decomp1 = matchBitmaskMul(I.getOperand(1)); if (Decomp1) { auto Decomp0 = matchBitmaskMul(I.getOperand(0)); @@ -3752,11 +3752,11 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { (Decomp0->Mask & Decomp1->Mask).isZero() && Decomp0->Factor == Decomp1->Factor) { - auto NewAnd = Builder.CreateAnd( + Value *NewAnd = Builder.CreateAnd( Decomp0->X, ConstantInt::get(Decomp0->X->getType(), (Decomp0->Mask + Decomp1->Mask))); - auto Combined = BinaryOperator::CreateMul( + auto *Combined = BinaryOperator::CreateMul( NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor)); Combined->setHasNoUnsignedWrap(Decomp0->NUW && Decomp1->NUW);