Skip to content

Commit d133558

Browse files
committed
[InstCombine] Enhance pblendvb to select conversion with complex boolean masks
1 parent d5d68a1 commit d133558

File tree

1 file changed

+174
-36
lines changed

1 file changed

+174
-36
lines changed

llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp

Lines changed: 174 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,124 @@ static Value *getBoolVecFromMask(Value *Mask, const DataLayout &DL) {
5252
return nullptr;
5353
}
5454

55+
// Helper function to decompose complex logic on sign-extended i1 vectors
56+
static Value *tryDecomposeVectorLogicMask(Value *Mask, IRBuilderBase &Builder) {
57+
// Look through bitcasts
58+
Mask = InstCombiner::peekThroughBitcast(Mask);
59+
60+
// Direct sign-extension case (should be caught by the main code path)
61+
Value *InnerVal;
62+
if (match(Mask, m_SExt(m_Value(InnerVal))) &&
63+
InnerVal->getType()->isVectorTy() &&
64+
InnerVal->getType()->getScalarType()->isIntegerTy(1))
65+
return InnerVal;
66+
67+
// Handle AND of sign-extended vectors: (sext A) & (sext B) -> sext(A & B)
68+
Value *LHS, *RHS;
69+
Value *LHSInner, *RHSInner;
70+
if (match(Mask, m_And(m_Value(LHS), m_Value(RHS)))) {
71+
LHS = InstCombiner::peekThroughBitcast(LHS);
72+
RHS = InstCombiner::peekThroughBitcast(RHS);
73+
74+
if (match(LHS, m_SExt(m_Value(LHSInner))) &&
75+
LHSInner->getType()->isVectorTy() &&
76+
LHSInner->getType()->getScalarType()->isIntegerTy(1) &&
77+
match(RHS, m_SExt(m_Value(RHSInner))) &&
78+
RHSInner->getType()->isVectorTy() &&
79+
RHSInner->getType()->getScalarType()->isIntegerTy(1) &&
80+
LHSInner->getType() == RHSInner->getType()) {
81+
return Builder.CreateAnd(LHSInner, RHSInner);
82+
}
83+
84+
// Try recursively on each operand
85+
Value *DecomposedLHS = tryDecomposeVectorLogicMask(LHS, Builder);
86+
Value *DecomposedRHS = tryDecomposeVectorLogicMask(RHS, Builder);
87+
if (DecomposedLHS && DecomposedRHS &&
88+
DecomposedLHS->getType() == DecomposedRHS->getType())
89+
return Builder.CreateAnd(DecomposedLHS, DecomposedRHS);
90+
}
91+
92+
// Handle XOR of sign-extended vectors: (sext A) ^ (sext B) -> sext(A ^ B)
93+
if (match(Mask, m_Xor(m_Value(LHS), m_Value(RHS)))) {
94+
LHS = InstCombiner::peekThroughBitcast(LHS);
95+
RHS = InstCombiner::peekThroughBitcast(RHS);
96+
97+
if (match(LHS, m_SExt(m_Value(LHSInner))) &&
98+
LHSInner->getType()->isVectorTy() &&
99+
LHSInner->getType()->getScalarType()->isIntegerTy(1) &&
100+
match(RHS, m_SExt(m_Value(RHSInner))) &&
101+
RHSInner->getType()->isVectorTy() &&
102+
RHSInner->getType()->getScalarType()->isIntegerTy(1) &&
103+
LHSInner->getType() == RHSInner->getType()) {
104+
return Builder.CreateXor(LHSInner, RHSInner);
105+
}
106+
107+
// Try recursively on each operand
108+
Value *DecomposedLHS = tryDecomposeVectorLogicMask(LHS, Builder);
109+
Value *DecomposedRHS = tryDecomposeVectorLogicMask(RHS, Builder);
110+
if (DecomposedLHS && DecomposedRHS &&
111+
DecomposedLHS->getType() == DecomposedRHS->getType())
112+
return Builder.CreateXor(DecomposedLHS, DecomposedRHS);
113+
}
114+
115+
// Handle OR of sign-extended vectors: (sext A) | (sext B) -> sext(A | B)
116+
if (match(Mask, m_Or(m_Value(LHS), m_Value(RHS)))) {
117+
LHS = InstCombiner::peekThroughBitcast(LHS);
118+
RHS = InstCombiner::peekThroughBitcast(RHS);
119+
120+
if (match(LHS, m_SExt(m_Value(LHSInner))) &&
121+
LHSInner->getType()->isVectorTy() &&
122+
LHSInner->getType()->getScalarType()->isIntegerTy(1) &&
123+
match(RHS, m_SExt(m_Value(RHSInner))) &&
124+
RHSInner->getType()->isVectorTy() &&
125+
RHSInner->getType()->getScalarType()->isIntegerTy(1) &&
126+
LHSInner->getType() == RHSInner->getType()) {
127+
return Builder.CreateOr(LHSInner, RHSInner);
128+
}
129+
130+
// Try recursively on each operand
131+
Value *DecomposedLHS = tryDecomposeVectorLogicMask(LHS, Builder);
132+
Value *DecomposedRHS = tryDecomposeVectorLogicMask(RHS, Builder);
133+
if (DecomposedLHS && DecomposedRHS &&
134+
DecomposedLHS->getType() == DecomposedRHS->getType())
135+
return Builder.CreateOr(DecomposedLHS, DecomposedRHS);
136+
}
137+
138+
// Handle AndNot: (sext A) & ~(sext B) -> sext(A & ~B)
139+
Value *NotOp;
140+
if (match(Mask, m_And(m_Value(LHS),
141+
m_Not(m_Value(NotOp))))) {
142+
LHS = InstCombiner::peekThroughBitcast(LHS);
143+
NotOp = InstCombiner::peekThroughBitcast(NotOp);
144+
145+
if (match(LHS, m_SExt(m_Value(LHSInner))) &&
146+
LHSInner->getType()->isVectorTy() &&
147+
LHSInner->getType()->getScalarType()->isIntegerTy(1) &&
148+
match(NotOp, m_SExt(m_Value(RHSInner))) &&
149+
RHSInner->getType()->isVectorTy() &&
150+
RHSInner->getType()->getScalarType()->isIntegerTy(1) &&
151+
LHSInner->getType() == RHSInner->getType()) {
152+
Value *NotRHSInner = Builder.CreateNot(RHSInner);
153+
return Builder.CreateAnd(LHSInner, NotRHSInner);
154+
}
155+
156+
// Try recursively on each operand
157+
Value *DecomposedLHS = tryDecomposeVectorLogicMask(LHS, Builder);
158+
Value *DecomposedNotOp = tryDecomposeVectorLogicMask(NotOp, Builder);
159+
if (DecomposedLHS && DecomposedNotOp &&
160+
DecomposedLHS->getType() == DecomposedNotOp->getType()) {
161+
Value *NotRHS = Builder.CreateNot(DecomposedNotOp);
162+
return Builder.CreateAnd(DecomposedLHS, NotRHS);
163+
}
164+
}
165+
166+
// No matching pattern found
167+
return nullptr;
168+
}
169+
170+
171+
172+
55173
// TODO: If the x86 backend knew how to convert a bool vector mask back to an
56174
// XMM register mask efficiently, we could transform all x86 masked intrinsics
57175
// to LLVM masked intrinsics and remove the x86 masked intrinsic defs.
@@ -2150,6 +2268,52 @@ static bool simplifyX86VPERMMask(Instruction *II, bool IsBinary,
21502268
return IC.SimplifyDemandedBits(II, /*OpNo=*/1, DemandedMask, KnownMask);
21512269
}
21522270

2271+
2272+
static Instruction *createMaskSelect(InstCombiner &IC, CallInst &II,
2273+
Value *BoolVec, Value *Op0, Value *Op1,
2274+
Value *MaskSrc = nullptr,
2275+
ArrayRef<int> ShuffleMask = std::nullopt) {
2276+
auto *MaskTy = cast<FixedVectorType>(II.getArgOperand(2)->getType());
2277+
auto *OpTy = cast<FixedVectorType>(II.getType());
2278+
unsigned NumMaskElts = MaskTy->getNumElements();
2279+
unsigned NumOperandElts = OpTy->getNumElements();
2280+
2281+
// If we peeked through a shuffle, reapply the shuffle to the bool vector.
2282+
if (MaskSrc) {
2283+
unsigned NumMaskSrcElts =
2284+
cast<FixedVectorType>(MaskSrc->getType())->getNumElements();
2285+
NumMaskElts = (ShuffleMask.size() * NumMaskElts) / NumMaskSrcElts;
2286+
// Multiple mask bits maps to the same operand element - bail out.
2287+
if (NumMaskElts > NumOperandElts)
2288+
return nullptr;
2289+
SmallVector<int> ScaledMask;
2290+
if (!llvm::scaleShuffleMaskElts(NumMaskElts, ShuffleMask, ScaledMask))
2291+
return nullptr;
2292+
BoolVec = IC.Builder.CreateShuffleVector(BoolVec, ScaledMask);
2293+
MaskTy = FixedVectorType::get(MaskTy->getElementType(), NumMaskElts);
2294+
}
2295+
2296+
assert(MaskTy->getPrimitiveSizeInBits() ==
2297+
OpTy->getPrimitiveSizeInBits() &&
2298+
"Not expecting mask and operands with different sizes");
2299+
2300+
if (NumMaskElts == NumOperandElts) {
2301+
return SelectInst::Create(BoolVec, Op1, Op0);
2302+
}
2303+
2304+
// If the mask has less elements than the operands, each mask bit maps to
2305+
// multiple elements of the operands. Bitcast back and forth.
2306+
if (NumMaskElts < NumOperandElts) {
2307+
Value *CastOp0 = IC.Builder.CreateBitCast(Op0, MaskTy);
2308+
Value *CastOp1 = IC.Builder.CreateBitCast(Op1, MaskTy);
2309+
Value *Sel = IC.Builder.CreateSelect(BoolVec, CastOp1, CastOp0);
2310+
return new BitCastInst(Sel, II.getType());
2311+
}
2312+
2313+
return nullptr;
2314+
}
2315+
2316+
21532317
std::optional<Instruction *>
21542318
X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
21552319
auto SimplifyDemandedVectorEltsLow = [&IC](Value *Op, unsigned Width,
@@ -2914,42 +3078,16 @@ X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
29143078
if (match(Mask, m_SExt(m_Value(BoolVec))) &&
29153079
BoolVec->getType()->isVectorTy() &&
29163080
BoolVec->getType()->getScalarSizeInBits() == 1) {
2917-
auto *MaskTy = cast<FixedVectorType>(Mask->getType());
2918-
auto *OpTy = cast<FixedVectorType>(II.getType());
2919-
unsigned NumMaskElts = MaskTy->getNumElements();
2920-
unsigned NumOperandElts = OpTy->getNumElements();
2921-
2922-
// If we peeked through a shuffle, reapply the shuffle to the bool vector.
2923-
if (MaskSrc) {
2924-
unsigned NumMaskSrcElts =
2925-
cast<FixedVectorType>(MaskSrc->getType())->getNumElements();
2926-
NumMaskElts = (ShuffleMask.size() * NumMaskElts) / NumMaskSrcElts;
2927-
// Multiple mask bits maps to the same operand element - bail out.
2928-
if (NumMaskElts > NumOperandElts)
2929-
break;
2930-
SmallVector<int> ScaledMask;
2931-
if (!llvm::scaleShuffleMaskElts(NumMaskElts, ShuffleMask, ScaledMask))
2932-
break;
2933-
BoolVec = IC.Builder.CreateShuffleVector(BoolVec, ScaledMask);
2934-
MaskTy = FixedVectorType::get(MaskTy->getElementType(), NumMaskElts);
2935-
}
2936-
assert(MaskTy->getPrimitiveSizeInBits() ==
2937-
OpTy->getPrimitiveSizeInBits() &&
2938-
"Not expecting mask and operands with different sizes");
2939-
2940-
if (NumMaskElts == NumOperandElts) {
2941-
return SelectInst::Create(BoolVec, Op1, Op0);
2942-
}
2943-
2944-
// If the mask has less elements than the operands, each mask bit maps to
2945-
// multiple elements of the operands. Bitcast back and forth.
2946-
if (NumMaskElts < NumOperandElts) {
2947-
Value *CastOp0 = IC.Builder.CreateBitCast(Op0, MaskTy);
2948-
Value *CastOp1 = IC.Builder.CreateBitCast(Op1, MaskTy);
2949-
Value *Sel = IC.Builder.CreateSelect(BoolVec, CastOp1, CastOp0);
2950-
return new BitCastInst(Sel, II.getType());
2951-
}
2952-
}
3081+
Instruction *Select = createMaskSelect(IC, II, BoolVec, Op0, Op1, MaskSrc, ShuffleMask);
3082+
if (Select) return Select;
3083+
} else {
3084+
BoolVec = tryDecomposeVectorLogicMask(Mask,IC.Builder);
3085+
if (BoolVec) {
3086+
Instruction *Select = createMaskSelect(IC, II, BoolVec, Op0, Op1, MaskSrc, ShuffleMask);
3087+
if (Select)
3088+
return Select;
3089+
}
3090+
}
29533091

29543092
break;
29553093
}

0 commit comments

Comments
 (0)