Skip to content

Commit 0d2a0b4

Browse files
committed
[VectorCombine] scalarize binop of inserted elements into vector constants
As with the extractelement patterns that are currently in vector-combine, there are going to be several possible variations on this theme. This should be the clearest, simplest example. Scalarization is the right direction for target-independent canonicalization, and InstCombine has some of those folds already, but it doesn't do this. I proposed a similar transform in D50992. Here in vector-combine, we can check the cost model to be sure it's profitable, so there should be less risk. Differential Revision: https://reviews.llvm.org/D79452
1 parent eb7d32e commit 0d2a0b4

File tree

2 files changed

+93
-24
lines changed

2 files changed

+93
-24
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ using namespace llvm::PatternMatch;
3434
#define DEBUG_TYPE "vector-combine"
3535
STATISTIC(NumVecCmp, "Number of vector compares formed");
3636
STATISTIC(NumVecBO, "Number of vector binops formed");
37+
STATISTIC(NumScalarBO, "Number of scalar binops formed");
3738

3839
static cl::opt<bool> DisableVectorCombine(
3940
"disable-vector-combine", cl::init(false), cl::Hidden,
@@ -308,6 +309,64 @@ static bool foldBitcastShuf(Instruction &I, const TargetTransformInfo &TTI) {
308309
return true;
309310
}
310311

312+
/// Match a vector binop instruction with inserted scalar operands and convert
313+
/// to scalar binop followed by insertelement.
314+
static bool scalarizeBinop(Instruction &I, const TargetTransformInfo &TTI) {
315+
Instruction *Ins0, *Ins1;
316+
if (!match(&I, m_BinOp(m_Instruction(Ins0), m_Instruction(Ins1))))
317+
return false;
318+
319+
// TODO: Loosen restriction for one-use by adjusting cost equation.
320+
// TODO: Deal with mismatched index constants and variable indexes?
321+
Constant *VecC0, *VecC1;
322+
Value *V0, *V1;
323+
uint64_t Index;
324+
if (!match(Ins0, m_OneUse(m_InsertElement(m_Constant(VecC0), m_Value(V0),
325+
m_ConstantInt(Index)))) ||
326+
!match(Ins1, m_OneUse(m_InsertElement(m_Constant(VecC1), m_Value(V1),
327+
m_SpecificInt(Index)))))
328+
return false;
329+
330+
Type *ScalarTy = V0->getType();
331+
Type *VecTy = I.getType();
332+
assert(VecTy->isVectorTy() && ScalarTy == V1->getType() &&
333+
(ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy()) &&
334+
"Unexpected types for insert into binop");
335+
336+
Instruction::BinaryOps Opcode = cast<BinaryOperator>(&I)->getOpcode();
337+
int ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
338+
int VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
339+
340+
// Get cost estimate for the insert element. This cost will factor into
341+
// both sequences.
342+
int InsertCost =
343+
TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, Index);
344+
int OldCost = InsertCost + InsertCost + VectorOpCost;
345+
int NewCost = ScalarOpCost + InsertCost;
346+
347+
// We want to scalarize unless the vector variant actually has lower cost.
348+
if (OldCost < NewCost)
349+
return false;
350+
351+
// vec_bo (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
352+
// inselt NewVecC, (scalar_bo V0, V1), Index
353+
++NumScalarBO;
354+
IRBuilder<> Builder(&I);
355+
Value *Scalar = Builder.CreateBinOp(Opcode, V0, V1, I.getName() + ".scalar");
356+
357+
// All IR flags are safe to back-propagate. There is no potential for extra
358+
// poison to be created by the scalar instruction.
359+
if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
360+
ScalarInst->copyIRFlags(&I);
361+
362+
// Fold the vector constants in the original vectors into a new base vector.
363+
Constant *NewVecC = ConstantExpr::get(Opcode, VecC0, VecC1);
364+
Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
365+
I.replaceAllUsesWith(Insert);
366+
Insert->takeName(&I);
367+
return true;
368+
}
369+
311370
/// This is the entry point for all transforms. Pass manager differences are
312371
/// handled in the callers of this function.
313372
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
@@ -330,6 +389,7 @@ static bool runImpl(Function &F, const TargetTransformInfo &TTI,
330389
continue;
331390
MadeChange |= foldExtractExtract(I, TTI);
332391
MadeChange |= foldBitcastShuf(I, TTI);
392+
MadeChange |= scalarizeBinop(I, TTI);
333393
}
334394
}
335395

llvm/test/Transforms/VectorCombine/X86/insert-binop.ll

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@ declare void @use(<4 x i32>)
88

99
define <16 x i8> @ins0_ins0_add(i8 %x, i8 %y) {
1010
; CHECK-LABEL: @ins0_ins0_add(
11-
; CHECK-NEXT: [[I0:%.*]] = insertelement <16 x i8> undef, i8 [[X:%.*]], i32 0
12-
; CHECK-NEXT: [[I1:%.*]] = insertelement <16 x i8> undef, i8 [[Y:%.*]], i32 0
13-
; CHECK-NEXT: [[R:%.*]] = add <16 x i8> [[I0]], [[I1]]
11+
; CHECK-NEXT: [[R_SCALAR:%.*]] = add i8 [[X:%.*]], [[Y:%.*]]
12+
; CHECK-NEXT: [[R:%.*]] = insertelement <16 x i8> undef, i8 [[R_SCALAR]], i64 0
1413
; CHECK-NEXT: ret <16 x i8> [[R]]
1514
;
1615
%i0 = insertelement <16 x i8> undef, i8 %x, i32 0
@@ -23,9 +22,8 @@ define <16 x i8> @ins0_ins0_add(i8 %x, i8 %y) {
2322

2423
define <8 x i16> @ins0_ins0_sub_flags(i16 %x, i16 %y) {
2524
; CHECK-LABEL: @ins0_ins0_sub_flags(
26-
; CHECK-NEXT: [[I0:%.*]] = insertelement <8 x i16> undef, i16 [[X:%.*]], i8 5
27-
; CHECK-NEXT: [[I1:%.*]] = insertelement <8 x i16> undef, i16 [[Y:%.*]], i32 5
28-
; CHECK-NEXT: [[R:%.*]] = sub nuw nsw <8 x i16> [[I0]], [[I1]]
25+
; CHECK-NEXT: [[R_SCALAR:%.*]] = sub nuw nsw i16 [[X:%.*]], [[Y:%.*]]
26+
; CHECK-NEXT: [[R:%.*]] = insertelement <8 x i16> undef, i16 [[R_SCALAR]], i64 5
2927
; CHECK-NEXT: ret <8 x i16> [[R]]
3028
;
3129
%i0 = insertelement <8 x i16> undef, i16 %x, i8 5
@@ -34,11 +32,13 @@ define <8 x i16> @ins0_ins0_sub_flags(i16 %x, i16 %y) {
3432
ret <8 x i16> %r
3533
}
3634

35+
; The new vector constant is calculated by constant folding.
36+
; This is conservatively created as zero rather than undef for 'undef ^ undef'.
37+
3738
define <2 x i64> @ins1_ins1_xor(i64 %x, i64 %y) {
3839
; CHECK-LABEL: @ins1_ins1_xor(
39-
; CHECK-NEXT: [[I0:%.*]] = insertelement <2 x i64> undef, i64 [[X:%.*]], i64 1
40-
; CHECK-NEXT: [[I1:%.*]] = insertelement <2 x i64> undef, i64 [[Y:%.*]], i32 1
41-
; CHECK-NEXT: [[R:%.*]] = xor <2 x i64> [[I0]], [[I1]]
40+
; CHECK-NEXT: [[R_SCALAR:%.*]] = xor i64 [[X:%.*]], [[Y:%.*]]
41+
; CHECK-NEXT: [[R:%.*]] = insertelement <2 x i64> zeroinitializer, i64 [[R_SCALAR]], i64 1
4242
; CHECK-NEXT: ret <2 x i64> [[R]]
4343
;
4444
%i0 = insertelement <2 x i64> undef, i64 %x, i64 1
@@ -51,9 +51,8 @@ define <2 x i64> @ins1_ins1_xor(i64 %x, i64 %y) {
5151

5252
define <2 x double> @ins0_ins0_fadd(double %x, double %y) {
5353
; CHECK-LABEL: @ins0_ins0_fadd(
54-
; CHECK-NEXT: [[I0:%.*]] = insertelement <2 x double> undef, double [[X:%.*]], i32 0
55-
; CHECK-NEXT: [[I1:%.*]] = insertelement <2 x double> undef, double [[Y:%.*]], i32 0
56-
; CHECK-NEXT: [[R:%.*]] = fadd reassoc nsz <2 x double> [[I0]], [[I1]]
54+
; CHECK-NEXT: [[R_SCALAR:%.*]] = fadd reassoc nsz double [[X:%.*]], [[Y:%.*]]
55+
; CHECK-NEXT: [[R:%.*]] = insertelement <2 x double> undef, double [[R_SCALAR]], i64 0
5756
; CHECK-NEXT: ret <2 x double> [[R]]
5857
;
5958
%i0 = insertelement <2 x double> undef, double %x, i32 0
@@ -62,6 +61,8 @@ define <2 x double> @ins0_ins0_fadd(double %x, double %y) {
6261
ret <2 x double> %r
6362
}
6463

64+
; Negative test - mismatched indexes (but could fold this).
65+
6566
define <16 x i8> @ins1_ins0_add(i8 %x, i8 %y) {
6667
; CHECK-LABEL: @ins1_ins0_add(
6768
; CHECK-NEXT: [[I0:%.*]] = insertelement <16 x i8> undef, i8 [[X:%.*]], i32 1
@@ -75,11 +76,12 @@ define <16 x i8> @ins1_ins0_add(i8 %x, i8 %y) {
7576
ret <16 x i8> %r
7677
}
7778

79+
; Base vector does not have to be undef.
80+
7881
define <4 x i32> @ins0_ins0_mul(i32 %x, i32 %y) {
7982
; CHECK-LABEL: @ins0_ins0_mul(
80-
; CHECK-NEXT: [[I0:%.*]] = insertelement <4 x i32> zeroinitializer, i32 [[X:%.*]], i32 0
81-
; CHECK-NEXT: [[I1:%.*]] = insertelement <4 x i32> undef, i32 [[Y:%.*]], i32 0
82-
; CHECK-NEXT: [[R:%.*]] = mul <4 x i32> [[I0]], [[I1]]
83+
; CHECK-NEXT: [[R_SCALAR:%.*]] = mul i32 [[X:%.*]], [[Y:%.*]]
84+
; CHECK-NEXT: [[R:%.*]] = insertelement <4 x i32> zeroinitializer, i32 [[R_SCALAR]], i64 0
8385
; CHECK-NEXT: ret <4 x i32> [[R]]
8486
;
8587
%i0 = insertelement <4 x i32> zeroinitializer, i32 %x, i32 0
@@ -88,11 +90,12 @@ define <4 x i32> @ins0_ins0_mul(i32 %x, i32 %y) {
8890
ret <4 x i32> %r
8991
}
9092

93+
; It is safe to scalarize any binop (no extra UB/poison danger).
94+
9195
define <2 x i64> @ins1_ins1_sdiv(i64 %x, i64 %y) {
9296
; CHECK-LABEL: @ins1_ins1_sdiv(
93-
; CHECK-NEXT: [[I0:%.*]] = insertelement <2 x i64> <i64 42, i64 -42>, i64 [[X:%.*]], i64 1
94-
; CHECK-NEXT: [[I1:%.*]] = insertelement <2 x i64> <i64 -7, i64 128>, i64 [[Y:%.*]], i32 1
95-
; CHECK-NEXT: [[R:%.*]] = sdiv <2 x i64> [[I0]], [[I1]]
97+
; CHECK-NEXT: [[R_SCALAR:%.*]] = sdiv i64 [[X:%.*]], [[Y:%.*]]
98+
; CHECK-NEXT: [[R:%.*]] = insertelement <2 x i64> <i64 -6, i64 0>, i64 [[R_SCALAR]], i64 1
9699
; CHECK-NEXT: ret <2 x i64> [[R]]
97100
;
98101
%i0 = insertelement <2 x i64> <i64 42, i64 -42>, i64 %x, i64 1
@@ -101,11 +104,12 @@ define <2 x i64> @ins1_ins1_sdiv(i64 %x, i64 %y) {
101104
ret <2 x i64> %r
102105
}
103106

107+
; Constant folding deals with undef per element - the entire value does not become undef.
108+
104109
define <2 x i64> @ins1_ins1_udiv(i64 %x, i64 %y) {
105110
; CHECK-LABEL: @ins1_ins1_udiv(
106-
; CHECK-NEXT: [[I0:%.*]] = insertelement <2 x i64> <i64 42, i64 undef>, i64 [[X:%.*]], i32 1
107-
; CHECK-NEXT: [[I1:%.*]] = insertelement <2 x i64> <i64 7, i64 undef>, i64 [[Y:%.*]], i32 1
108-
; CHECK-NEXT: [[R:%.*]] = udiv <2 x i64> [[I0]], [[I1]]
111+
; CHECK-NEXT: [[R_SCALAR:%.*]] = udiv i64 [[X:%.*]], [[Y:%.*]]
112+
; CHECK-NEXT: [[R:%.*]] = insertelement <2 x i64> <i64 6, i64 undef>, i64 [[R_SCALAR]], i64 1
109113
; CHECK-NEXT: ret <2 x i64> [[R]]
110114
;
111115
%i0 = insertelement <2 x i64> <i64 42, i64 undef>, i64 %x, i32 1
@@ -114,11 +118,13 @@ define <2 x i64> @ins1_ins1_udiv(i64 %x, i64 %y) {
114118
ret <2 x i64> %r
115119
}
116120

121+
; This could be simplified -- creates immediate UB without the transform because
122+
; divisor has an undef element -- but that is hidden after the transform.
123+
117124
define <2 x i64> @ins1_ins1_urem(i64 %x, i64 %y) {
118125
; CHECK-LABEL: @ins1_ins1_urem(
119-
; CHECK-NEXT: [[I0:%.*]] = insertelement <2 x i64> <i64 42, i64 undef>, i64 [[X:%.*]], i64 1
120-
; CHECK-NEXT: [[I1:%.*]] = insertelement <2 x i64> <i64 undef, i64 128>, i64 [[Y:%.*]], i32 1
121-
; CHECK-NEXT: [[R:%.*]] = urem <2 x i64> [[I0]], [[I1]]
126+
; CHECK-NEXT: [[R_SCALAR:%.*]] = urem i64 [[X:%.*]], [[Y:%.*]]
127+
; CHECK-NEXT: [[R:%.*]] = insertelement <2 x i64> <i64 undef, i64 0>, i64 [[R_SCALAR]], i64 1
122128
; CHECK-NEXT: ret <2 x i64> [[R]]
123129
;
124130
%i0 = insertelement <2 x i64> <i64 42, i64 undef>, i64 %x, i64 1
@@ -127,6 +133,9 @@ define <2 x i64> @ins1_ins1_urem(i64 %x, i64 %y) {
127133
ret <2 x i64> %r
128134
}
129135

136+
; Negative test
137+
; TODO: extra use can be accounted for in cost calculation.
138+
130139
define <4 x i32> @ins0_ins0_xor(i32 %x, i32 %y) {
131140
; CHECK-LABEL: @ins0_ins0_xor(
132141
; CHECK-NEXT: [[I0:%.*]] = insertelement <4 x i32> undef, i32 [[X:%.*]], i32 0

0 commit comments

Comments
 (0)