Skip to content

Commit ee551cf

Browse files
committed
[VPlan] Introduce m_Cmp; match more compares
Extend [Specific]Cmp_match to handle floating-point compares, and introduce m_Cmp that matches both integer and floating-point compares. Use it in simplifyRecipe to match and simplify the general case of compares.
1 parent 2975e67 commit ee551cf

File tree

6 files changed

+120
-102
lines changed

6 files changed

+120
-102
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,9 @@ class VPIRFlags {
805805

806806
GEPNoWrapFlags getGEPNoWrapFlags() const { return GEPFlags; }
807807

808+
/// Returns true if the recipe has a comparison predicate.
809+
bool hasPredicate() const { return OpType == OperationType::Cmp; }
810+
808811
/// Returns true if the recipe has fast-math flags.
809812
bool hasFastMathFlags() const { return OpType == OperationType::FPMathOp; }
810813

llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -392,24 +392,32 @@ m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
392392
return m_c_Binary<Instruction::Or, Op0_t, Op1_t>(Op0, Op1);
393393
}
394394

395-
/// ICmp_match is a variant of BinaryRecipe_match that also binds the comparison
396-
/// predicate.
397-
template <typename Op0_t, typename Op1_t> struct ICmp_match {
395+
/// Cmp_match is a variant of BinaryRecipe_match that also binds the comparison
396+
/// predicate. Opcodes must either be Instruction::ICmp or Instruction::FCmp, or
397+
/// both.
398+
template <typename Op0_t, typename Op1_t, unsigned... Opcodes>
399+
struct Cmp_match {
400+
static_assert((sizeof...(Opcodes) == 1 || sizeof...(Opcodes) == 2) &&
401+
"Expected one or two opcodes");
402+
static_assert(
403+
((Opcodes == Instruction::ICmp || Opcodes == Instruction::FCmp) && ...) &&
404+
"Expected a compare instruction opcode");
405+
398406
CmpPredicate *Predicate = nullptr;
399407
Op0_t Op0;
400408
Op1_t Op1;
401409

402-
ICmp_match(CmpPredicate &Pred, const Op0_t &Op0, const Op1_t &Op1)
410+
Cmp_match(CmpPredicate &Pred, const Op0_t &Op0, const Op1_t &Op1)
403411
: Predicate(&Pred), Op0(Op0), Op1(Op1) {}
404-
ICmp_match(const Op0_t &Op0, const Op1_t &Op1) : Op0(Op0), Op1(Op1) {}
412+
Cmp_match(const Op0_t &Op0, const Op1_t &Op1) : Op0(Op0), Op1(Op1) {}
405413

406414
bool match(const VPValue *V) const {
407415
auto *DefR = V->getDefiningRecipe();
408416
return DefR && match(DefR);
409417
}
410418

411419
bool match(const VPRecipeBase *V) const {
412-
if (m_Binary<Instruction::ICmp>(Op0, Op1).match(V)) {
420+
if ((m_Binary<Opcodes>(Op0, Op1).match(V) || ...)) {
413421
if (Predicate)
414422
*Predicate = cast<VPRecipeWithIRFlags>(V)->getPredicate();
415423
return true;
@@ -418,38 +426,63 @@ template <typename Op0_t, typename Op1_t> struct ICmp_match {
418426
}
419427
};
420428

421-
/// SpecificICmp_match is a variant of ICmp_match that matches the comparison
429+
/// SpecificCmp_match is a variant of Cmp_match that matches the comparison
422430
/// predicate, instead of binding it.
423-
template <typename Op0_t, typename Op1_t> struct SpecificICmp_match {
431+
template <typename Op0_t, typename Op1_t, unsigned... Opcodes>
432+
struct SpecificCmp_match {
424433
const CmpPredicate Predicate;
425434
Op0_t Op0;
426435
Op1_t Op1;
427436

428-
SpecificICmp_match(CmpPredicate Pred, const Op0_t &LHS, const Op1_t &RHS)
437+
SpecificCmp_match(CmpPredicate Pred, const Op0_t &LHS, const Op1_t &RHS)
429438
: Predicate(Pred), Op0(LHS), Op1(RHS) {}
430439

431440
bool match(const VPValue *V) const {
432441
CmpPredicate CurrentPred;
433-
return ICmp_match<Op0_t, Op1_t>(CurrentPred, Op0, Op1).match(V) &&
442+
return Cmp_match<Op0_t, Op1_t, Opcodes...>(CurrentPred, Op0, Op1)
443+
.match(V) &&
434444
CmpPredicate::getMatching(CurrentPred, Predicate);
435445
}
436446
};
437447

438448
template <typename Op0_t, typename Op1_t>
439-
inline ICmp_match<Op0_t, Op1_t> m_ICmp(const Op0_t &Op0, const Op1_t &Op1) {
440-
return ICmp_match<Op0_t, Op1_t>(Op0, Op1);
449+
inline Cmp_match<Op0_t, Op1_t, Instruction::ICmp> m_ICmp(const Op0_t &Op0,
450+
const Op1_t &Op1) {
451+
return Cmp_match<Op0_t, Op1_t, Instruction::ICmp>(Op0, Op1);
441452
}
442453

443454
template <typename Op0_t, typename Op1_t>
444-
inline ICmp_match<Op0_t, Op1_t> m_ICmp(CmpPredicate &Pred, const Op0_t &Op0,
445-
const Op1_t &Op1) {
446-
return ICmp_match<Op0_t, Op1_t>(Pred, Op0, Op1);
455+
inline Cmp_match<Op0_t, Op1_t, Instruction::ICmp>
456+
m_ICmp(CmpPredicate &Pred, const Op0_t &Op0, const Op1_t &Op1) {
457+
return Cmp_match<Op0_t, Op1_t, Instruction::ICmp>(Pred, Op0, Op1);
447458
}
448459

449460
template <typename Op0_t, typename Op1_t>
450-
inline SpecificICmp_match<Op0_t, Op1_t>
461+
inline SpecificCmp_match<Op0_t, Op1_t, Instruction::ICmp>
451462
m_SpecificICmp(CmpPredicate MatchPred, const Op0_t &Op0, const Op1_t &Op1) {
452-
return SpecificICmp_match<Op0_t, Op1_t>(MatchPred, Op0, Op1);
463+
return SpecificCmp_match<Op0_t, Op1_t, Instruction::ICmp>(MatchPred, Op0,
464+
Op1);
465+
}
466+
467+
template <typename Op0_t, typename Op1_t>
468+
inline Cmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>
469+
m_Cmp(const Op0_t &Op0, const Op1_t &Op1) {
470+
return Cmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>(Op0,
471+
Op1);
472+
}
473+
474+
template <typename Op0_t, typename Op1_t>
475+
inline Cmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>
476+
m_Cmp(CmpPredicate &Pred, const Op0_t &Op0, const Op1_t &Op1) {
477+
return Cmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>(
478+
Pred, Op0, Op1);
479+
}
480+
481+
template <typename Op0_t, typename Op1_t>
482+
inline SpecificCmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>
483+
m_SpecificCmp(CmpPredicate MatchPred, const Op0_t &Op0, const Op1_t &Op1) {
484+
return SpecificCmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>(
485+
MatchPred, Op0, Op1);
453486
}
454487

455488
template <typename Op0_t, typename Op1_t>

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2930,6 +2930,9 @@ static void scalarizeInstruction(const Instruction *Instr,
29302930
RepRecipe->applyFlags(*Cloned);
29312931
RepRecipe->applyMetadata(*Cloned);
29322932

2933+
if (RepRecipe->hasPredicate())
2934+
cast<CmpInst>(Cloned)->setPredicate(RepRecipe->getPredicate());
2935+
29332936
if (auto DL = RepRecipe->getDebugLoc())
29342937
State.setDebugLocFrom(DL);
29352938

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,33 +1106,31 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
11061106
return Def->replaceAllUsesWith(A);
11071107

11081108
// Try to fold Not into compares by adjusting the predicate in-place.
1109-
if (auto *WideCmp = dyn_cast<VPWidenRecipe>(A)) {
1110-
if ((WideCmp->getOpcode() == Instruction::ICmp ||
1111-
WideCmp->getOpcode() == Instruction::FCmp) &&
1112-
all_of(WideCmp->users(), [&WideCmp](VPUser *U) {
1113-
return match(U, m_CombineOr(m_Not(m_Specific(WideCmp)),
1114-
m_Select(m_Specific(WideCmp),
1115-
m_VPValue(), m_VPValue())));
1109+
CmpPredicate Pred;
1110+
if (match(A, m_Cmp(Pred, m_VPValue(), m_VPValue()))) {
1111+
auto *Cmp = cast<VPRecipeWithIRFlags>(A);
1112+
if (all_of(Cmp->users(), [&Cmp](VPUser *U) {
1113+
return match(U, m_CombineOr(m_Not(m_Specific(Cmp)),
1114+
m_Select(m_Specific(Cmp), m_VPValue(),
1115+
m_VPValue())));
11161116
})) {
1117-
WideCmp->setPredicate(
1118-
CmpInst::getInversePredicate(WideCmp->getPredicate()));
1119-
for (VPUser *U : to_vector(WideCmp->users())) {
1117+
Cmp->setPredicate(CmpInst::getInversePredicate(Pred));
1118+
for (VPUser *U : to_vector(Cmp->users())) {
11201119
auto *R = cast<VPSingleDefRecipe>(U);
1121-
if (match(R, m_Select(m_Specific(WideCmp), m_VPValue(X),
1122-
m_VPValue(Y)))) {
1120+
if (match(R, m_Select(m_Specific(Cmp), m_VPValue(X), m_VPValue(Y)))) {
11231121
// select (cmp pred), x, y -> select (cmp inv_pred), y, x
11241122
R->setOperand(1, Y);
11251123
R->setOperand(2, X);
11261124
} else {
11271125
// not (cmp pred) -> cmp inv_pred
1128-
assert(match(R, m_Not(m_Specific(WideCmp))) && "Unexpected user");
1129-
R->replaceAllUsesWith(WideCmp);
1126+
assert(match(R, m_Not(m_Specific(Cmp))) && "Unexpected user");
1127+
R->replaceAllUsesWith(Cmp);
11301128
}
11311129
}
1132-
// If WideCmp doesn't have a debug location, use the one from the
1133-
// negation, to preserve the location.
1134-
if (!WideCmp->getDebugLoc() && R.getDebugLoc())
1135-
WideCmp->setDebugLoc(R.getDebugLoc());
1130+
// If Cmp doesn't have a debug location, use the one from the negation,
1131+
// to preserve the location.
1132+
if (!Cmp->getDebugLoc() && R.getDebugLoc())
1133+
Cmp->setDebugLoc(R.getDebugLoc());
11361134
}
11371135
}
11381136
}

llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -973,18 +973,16 @@ define void @test_widen_exp_v2(ptr noalias %p2, ptr noalias %p, i64 %n) #5 {
973973
; TFA_INTERLEAVE-NEXT: [[ACTIVE_LANE_MASK_ENTRY1:%.*]] = icmp ult i64 1, [[TMP0]]
974974
; TFA_INTERLEAVE-NEXT: br label %[[VECTOR_BODY:.*]]
975975
; TFA_INTERLEAVE: [[VECTOR_BODY]]:
976-
; TFA_INTERLEAVE-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[INDEX_NEXT:%.*]], %[[TMP19:.*]] ]
977-
; TFA_INTERLEAVE-NEXT: [[ACTIVE_LANE_MASK:%.*]] = phi i1 [ [[ACTIVE_LANE_MASK_ENTRY]], %[[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], %[[TMP19]] ]
978-
; TFA_INTERLEAVE-NEXT: [[ACTIVE_LANE_MASK2:%.*]] = phi i1 [ [[ACTIVE_LANE_MASK_ENTRY1]], %[[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT6:%.*]], %[[TMP19]] ]
976+
; TFA_INTERLEAVE-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[INDEX_NEXT:%.*]], %[[TMP18:.*]] ]
977+
; TFA_INTERLEAVE-NEXT: [[ACTIVE_LANE_MASK:%.*]] = phi i1 [ [[ACTIVE_LANE_MASK_ENTRY]], %[[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], %[[TMP18]] ]
978+
; TFA_INTERLEAVE-NEXT: [[ACTIVE_LANE_MASK2:%.*]] = phi i1 [ [[ACTIVE_LANE_MASK_ENTRY1]], %[[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT6:%.*]], %[[TMP18]] ]
979979
; TFA_INTERLEAVE-NEXT: [[TMP4:%.*]] = load double, ptr [[P2]], align 8
980980
; TFA_INTERLEAVE-NEXT: [[TMP5:%.*]] = tail call double @llvm.exp.f64(double [[TMP4]]) #[[ATTR7:[0-9]+]]
981981
; TFA_INTERLEAVE-NEXT: [[TMP6:%.*]] = tail call double @llvm.exp.f64(double [[TMP4]]) #[[ATTR7]]
982-
; TFA_INTERLEAVE-NEXT: [[TMP7:%.*]] = fcmp ogt double [[TMP5]], 0.000000e+00
983-
; TFA_INTERLEAVE-NEXT: [[TMP8:%.*]] = fcmp ogt double [[TMP6]], 0.000000e+00
984-
; TFA_INTERLEAVE-NEXT: [[TMP9:%.*]] = xor i1 [[TMP7]], true
985-
; TFA_INTERLEAVE-NEXT: [[TMP10:%.*]] = xor i1 [[TMP8]], true
986-
; TFA_INTERLEAVE-NEXT: [[TMP11:%.*]] = select i1 [[ACTIVE_LANE_MASK]], i1 [[TMP9]], i1 false
987-
; TFA_INTERLEAVE-NEXT: [[TMP12:%.*]] = select i1 [[ACTIVE_LANE_MASK2]], i1 [[TMP10]], i1 false
982+
; TFA_INTERLEAVE-NEXT: [[TMP7:%.*]] = fcmp ule double [[TMP5]], 0.000000e+00
983+
; TFA_INTERLEAVE-NEXT: [[TMP8:%.*]] = fcmp ule double [[TMP6]], 0.000000e+00
984+
; TFA_INTERLEAVE-NEXT: [[TMP11:%.*]] = select i1 [[ACTIVE_LANE_MASK]], i1 [[TMP7]], i1 false
985+
; TFA_INTERLEAVE-NEXT: [[TMP12:%.*]] = select i1 [[ACTIVE_LANE_MASK2]], i1 [[TMP8]], i1 false
988986
; TFA_INTERLEAVE-NEXT: [[PREDPHI:%.*]] = select i1 [[TMP11]], double 1.000000e+00, double 0.000000e+00
989987
; TFA_INTERLEAVE-NEXT: [[PREDPHI3:%.*]] = select i1 [[TMP12]], double 1.000000e+00, double 0.000000e+00
990988
; TFA_INTERLEAVE-NEXT: [[SPEC_SELECT:%.*]] = select i1 [[ACTIVE_LANE_MASK2]], double [[PREDPHI3]], double [[PREDPHI]]
@@ -993,11 +991,11 @@ define void @test_widen_exp_v2(ptr noalias %p2, ptr noalias %p, i64 %n) #5 {
993991
; TFA_INTERLEAVE-NEXT: [[TMP15:%.*]] = xor i1 [[TMP13]], true
994992
; TFA_INTERLEAVE-NEXT: [[TMP16:%.*]] = xor i1 [[TMP14]], true
995993
; TFA_INTERLEAVE-NEXT: [[TMP17:%.*]] = or i1 [[TMP15]], [[TMP16]]
996-
; TFA_INTERLEAVE-NEXT: br i1 [[TMP17]], label %[[BB18:.*]], label %[[TMP19]]
997-
; TFA_INTERLEAVE: [[BB18]]:
994+
; TFA_INTERLEAVE-NEXT: br i1 [[TMP17]], label %[[BB16:.*]], label %[[TMP18]]
995+
; TFA_INTERLEAVE: [[BB16]]:
998996
; TFA_INTERLEAVE-NEXT: store double [[SPEC_SELECT]], ptr [[P]], align 8
999-
; TFA_INTERLEAVE-NEXT: br label %[[TMP19]]
1000-
; TFA_INTERLEAVE: [[TMP19]]:
997+
; TFA_INTERLEAVE-NEXT: br label %[[TMP18]]
998+
; TFA_INTERLEAVE: [[TMP18]]:
1001999
; TFA_INTERLEAVE-NEXT: [[INDEX_NEXT]] = add i64 [[INDEX]], 2
10021000
; TFA_INTERLEAVE-NEXT: [[TMP20:%.*]] = add i64 [[INDEX]], 1
10031001
; TFA_INTERLEAVE-NEXT: [[ACTIVE_LANE_MASK_NEXT]] = icmp ult i64 [[INDEX]], [[TMP3]]

0 commit comments

Comments
 (0)