@@ -52,6 +52,124 @@ static Value *getBoolVecFromMask(Value *Mask, const DataLayout &DL) {
52
52
return nullptr ;
53
53
}
54
54
55
+ // Helper function to decompose complex logic on sign-extended i1 vectors
56
+ static Value *tryDecomposeVectorLogicMask (Value *Mask, IRBuilderBase &Builder) {
57
+ // Look through bitcasts
58
+ Mask = InstCombiner::peekThroughBitcast (Mask);
59
+
60
+ // Direct sign-extension case (should be caught by the main code path)
61
+ Value *InnerVal;
62
+ if (match (Mask, m_SExt (m_Value (InnerVal))) &&
63
+ InnerVal->getType ()->isVectorTy () &&
64
+ InnerVal->getType ()->getScalarType ()->isIntegerTy (1 ))
65
+ return InnerVal;
66
+
67
+ // Handle AND of sign-extended vectors: (sext A) & (sext B) -> sext(A & B)
68
+ Value *LHS, *RHS;
69
+ Value *LHSInner, *RHSInner;
70
+ if (match (Mask, m_And (m_Value (LHS), m_Value (RHS)))) {
71
+ LHS = InstCombiner::peekThroughBitcast (LHS);
72
+ RHS = InstCombiner::peekThroughBitcast (RHS);
73
+
74
+ if (match (LHS, m_SExt (m_Value (LHSInner))) &&
75
+ LHSInner->getType ()->isVectorTy () &&
76
+ LHSInner->getType ()->getScalarType ()->isIntegerTy (1 ) &&
77
+ match (RHS, m_SExt (m_Value (RHSInner))) &&
78
+ RHSInner->getType ()->isVectorTy () &&
79
+ RHSInner->getType ()->getScalarType ()->isIntegerTy (1 ) &&
80
+ LHSInner->getType () == RHSInner->getType ()) {
81
+ return Builder.CreateAnd (LHSInner, RHSInner);
82
+ }
83
+
84
+ // Try recursively on each operand
85
+ Value *DecomposedLHS = tryDecomposeVectorLogicMask (LHS, Builder);
86
+ Value *DecomposedRHS = tryDecomposeVectorLogicMask (RHS, Builder);
87
+ if (DecomposedLHS && DecomposedRHS &&
88
+ DecomposedLHS->getType () == DecomposedRHS->getType ())
89
+ return Builder.CreateAnd (DecomposedLHS, DecomposedRHS);
90
+ }
91
+
92
+ // Handle XOR of sign-extended vectors: (sext A) ^ (sext B) -> sext(A ^ B)
93
+ if (match (Mask, m_Xor (m_Value (LHS), m_Value (RHS)))) {
94
+ LHS = InstCombiner::peekThroughBitcast (LHS);
95
+ RHS = InstCombiner::peekThroughBitcast (RHS);
96
+
97
+ if (match (LHS, m_SExt (m_Value (LHSInner))) &&
98
+ LHSInner->getType ()->isVectorTy () &&
99
+ LHSInner->getType ()->getScalarType ()->isIntegerTy (1 ) &&
100
+ match (RHS, m_SExt (m_Value (RHSInner))) &&
101
+ RHSInner->getType ()->isVectorTy () &&
102
+ RHSInner->getType ()->getScalarType ()->isIntegerTy (1 ) &&
103
+ LHSInner->getType () == RHSInner->getType ()) {
104
+ return Builder.CreateXor (LHSInner, RHSInner);
105
+ }
106
+
107
+ // Try recursively on each operand
108
+ Value *DecomposedLHS = tryDecomposeVectorLogicMask (LHS, Builder);
109
+ Value *DecomposedRHS = tryDecomposeVectorLogicMask (RHS, Builder);
110
+ if (DecomposedLHS && DecomposedRHS &&
111
+ DecomposedLHS->getType () == DecomposedRHS->getType ())
112
+ return Builder.CreateXor (DecomposedLHS, DecomposedRHS);
113
+ }
114
+
115
+ // Handle OR of sign-extended vectors: (sext A) | (sext B) -> sext(A | B)
116
+ if (match (Mask, m_Or (m_Value (LHS), m_Value (RHS)))) {
117
+ LHS = InstCombiner::peekThroughBitcast (LHS);
118
+ RHS = InstCombiner::peekThroughBitcast (RHS);
119
+
120
+ if (match (LHS, m_SExt (m_Value (LHSInner))) &&
121
+ LHSInner->getType ()->isVectorTy () &&
122
+ LHSInner->getType ()->getScalarType ()->isIntegerTy (1 ) &&
123
+ match (RHS, m_SExt (m_Value (RHSInner))) &&
124
+ RHSInner->getType ()->isVectorTy () &&
125
+ RHSInner->getType ()->getScalarType ()->isIntegerTy (1 ) &&
126
+ LHSInner->getType () == RHSInner->getType ()) {
127
+ return Builder.CreateOr (LHSInner, RHSInner);
128
+ }
129
+
130
+ // Try recursively on each operand
131
+ Value *DecomposedLHS = tryDecomposeVectorLogicMask (LHS, Builder);
132
+ Value *DecomposedRHS = tryDecomposeVectorLogicMask (RHS, Builder);
133
+ if (DecomposedLHS && DecomposedRHS &&
134
+ DecomposedLHS->getType () == DecomposedRHS->getType ())
135
+ return Builder.CreateOr (DecomposedLHS, DecomposedRHS);
136
+ }
137
+
138
+ // Handle AndNot: (sext A) & ~(sext B) -> sext(A & ~B)
139
+ Value *NotOp;
140
+ if (match (Mask, m_And (m_Value (LHS),
141
+ m_Not (m_Value (NotOp))))) {
142
+ LHS = InstCombiner::peekThroughBitcast (LHS);
143
+ NotOp = InstCombiner::peekThroughBitcast (NotOp);
144
+
145
+ if (match (LHS, m_SExt (m_Value (LHSInner))) &&
146
+ LHSInner->getType ()->isVectorTy () &&
147
+ LHSInner->getType ()->getScalarType ()->isIntegerTy (1 ) &&
148
+ match (NotOp, m_SExt (m_Value (RHSInner))) &&
149
+ RHSInner->getType ()->isVectorTy () &&
150
+ RHSInner->getType ()->getScalarType ()->isIntegerTy (1 ) &&
151
+ LHSInner->getType () == RHSInner->getType ()) {
152
+ Value *NotRHSInner = Builder.CreateNot (RHSInner);
153
+ return Builder.CreateAnd (LHSInner, NotRHSInner);
154
+ }
155
+
156
+ // Try recursively on each operand
157
+ Value *DecomposedLHS = tryDecomposeVectorLogicMask (LHS, Builder);
158
+ Value *DecomposedNotOp = tryDecomposeVectorLogicMask (NotOp, Builder);
159
+ if (DecomposedLHS && DecomposedNotOp &&
160
+ DecomposedLHS->getType () == DecomposedNotOp->getType ()) {
161
+ Value *NotRHS = Builder.CreateNot (DecomposedNotOp);
162
+ return Builder.CreateAnd (DecomposedLHS, NotRHS);
163
+ }
164
+ }
165
+
166
+ // No matching pattern found
167
+ return nullptr ;
168
+ }
169
+
170
+
171
+
172
+
55
173
// TODO: If the x86 backend knew how to convert a bool vector mask back to an
56
174
// XMM register mask efficiently, we could transform all x86 masked intrinsics
57
175
// to LLVM masked intrinsics and remove the x86 masked intrinsic defs.
@@ -2150,6 +2268,52 @@ static bool simplifyX86VPERMMask(Instruction *II, bool IsBinary,
2150
2268
return IC.SimplifyDemandedBits (II, /* OpNo=*/ 1 , DemandedMask, KnownMask);
2151
2269
}
2152
2270
2271
+
2272
+ static Instruction *createMaskSelect (InstCombiner &IC, CallInst &II,
2273
+ Value *BoolVec, Value *Op0, Value *Op1,
2274
+ Value *MaskSrc = nullptr ,
2275
+ ArrayRef<int > ShuffleMask = std::nullopt) {
2276
+ auto *MaskTy = cast<FixedVectorType>(II.getArgOperand (2 )->getType ());
2277
+ auto *OpTy = cast<FixedVectorType>(II.getType ());
2278
+ unsigned NumMaskElts = MaskTy->getNumElements ();
2279
+ unsigned NumOperandElts = OpTy->getNumElements ();
2280
+
2281
+ // If we peeked through a shuffle, reapply the shuffle to the bool vector.
2282
+ if (MaskSrc) {
2283
+ unsigned NumMaskSrcElts =
2284
+ cast<FixedVectorType>(MaskSrc->getType ())->getNumElements ();
2285
+ NumMaskElts = (ShuffleMask.size () * NumMaskElts) / NumMaskSrcElts;
2286
+ // Multiple mask bits maps to the same operand element - bail out.
2287
+ if (NumMaskElts > NumOperandElts)
2288
+ return nullptr ;
2289
+ SmallVector<int > ScaledMask;
2290
+ if (!llvm::scaleShuffleMaskElts (NumMaskElts, ShuffleMask, ScaledMask))
2291
+ return nullptr ;
2292
+ BoolVec = IC.Builder .CreateShuffleVector (BoolVec, ScaledMask);
2293
+ MaskTy = FixedVectorType::get (MaskTy->getElementType (), NumMaskElts);
2294
+ }
2295
+
2296
+ assert (MaskTy->getPrimitiveSizeInBits () ==
2297
+ OpTy->getPrimitiveSizeInBits () &&
2298
+ " Not expecting mask and operands with different sizes" );
2299
+
2300
+ if (NumMaskElts == NumOperandElts) {
2301
+ return SelectInst::Create (BoolVec, Op1, Op0);
2302
+ }
2303
+
2304
+ // If the mask has less elements than the operands, each mask bit maps to
2305
+ // multiple elements of the operands. Bitcast back and forth.
2306
+ if (NumMaskElts < NumOperandElts) {
2307
+ Value *CastOp0 = IC.Builder .CreateBitCast (Op0, MaskTy);
2308
+ Value *CastOp1 = IC.Builder .CreateBitCast (Op1, MaskTy);
2309
+ Value *Sel = IC.Builder .CreateSelect (BoolVec, CastOp1, CastOp0);
2310
+ return new BitCastInst (Sel, II.getType ());
2311
+ }
2312
+
2313
+ return nullptr ;
2314
+ }
2315
+
2316
+
2153
2317
std::optional<Instruction *>
2154
2318
X86TTIImpl::instCombineIntrinsic (InstCombiner &IC, IntrinsicInst &II) const {
2155
2319
auto SimplifyDemandedVectorEltsLow = [&IC](Value *Op, unsigned Width,
@@ -2914,42 +3078,16 @@ X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
2914
3078
if (match (Mask, m_SExt (m_Value (BoolVec))) &&
2915
3079
BoolVec->getType ()->isVectorTy () &&
2916
3080
BoolVec->getType ()->getScalarSizeInBits () == 1 ) {
2917
- auto *MaskTy = cast<FixedVectorType>(Mask->getType ());
2918
- auto *OpTy = cast<FixedVectorType>(II.getType ());
2919
- unsigned NumMaskElts = MaskTy->getNumElements ();
2920
- unsigned NumOperandElts = OpTy->getNumElements ();
2921
-
2922
- // If we peeked through a shuffle, reapply the shuffle to the bool vector.
2923
- if (MaskSrc) {
2924
- unsigned NumMaskSrcElts =
2925
- cast<FixedVectorType>(MaskSrc->getType ())->getNumElements ();
2926
- NumMaskElts = (ShuffleMask.size () * NumMaskElts) / NumMaskSrcElts;
2927
- // Multiple mask bits maps to the same operand element - bail out.
2928
- if (NumMaskElts > NumOperandElts)
2929
- break ;
2930
- SmallVector<int > ScaledMask;
2931
- if (!llvm::scaleShuffleMaskElts (NumMaskElts, ShuffleMask, ScaledMask))
2932
- break ;
2933
- BoolVec = IC.Builder .CreateShuffleVector (BoolVec, ScaledMask);
2934
- MaskTy = FixedVectorType::get (MaskTy->getElementType (), NumMaskElts);
2935
- }
2936
- assert (MaskTy->getPrimitiveSizeInBits () ==
2937
- OpTy->getPrimitiveSizeInBits () &&
2938
- " Not expecting mask and operands with different sizes" );
2939
-
2940
- if (NumMaskElts == NumOperandElts) {
2941
- return SelectInst::Create (BoolVec, Op1, Op0);
2942
- }
2943
-
2944
- // If the mask has less elements than the operands, each mask bit maps to
2945
- // multiple elements of the operands. Bitcast back and forth.
2946
- if (NumMaskElts < NumOperandElts) {
2947
- Value *CastOp0 = IC.Builder .CreateBitCast (Op0, MaskTy);
2948
- Value *CastOp1 = IC.Builder .CreateBitCast (Op1, MaskTy);
2949
- Value *Sel = IC.Builder .CreateSelect (BoolVec, CastOp1, CastOp0);
2950
- return new BitCastInst (Sel, II.getType ());
2951
- }
2952
- }
3081
+ Instruction *Select = createMaskSelect (IC, II, BoolVec, Op0, Op1, MaskSrc, ShuffleMask);
3082
+ if (Select) return Select;
3083
+ } else {
3084
+ BoolVec = tryDecomposeVectorLogicMask (Mask,IC.Builder );
3085
+ if (BoolVec) {
3086
+ Instruction *Select = createMaskSelect (IC, II, BoolVec, Op0, Op1, MaskSrc, ShuffleMask);
3087
+ if (Select)
3088
+ return Select;
3089
+ }
3090
+ }
2953
3091
2954
3092
break ;
2955
3093
}
0 commit comments