Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7002,12 +7002,12 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
if (Instruction *UI = GetInstructionForCost(&R)) {
// If we adjusted the predicate of the recipe, the cost in the legacy
// cost model may be different.
if (auto *WidenCmp = dyn_cast<VPWidenRecipe>(&R)) {
if ((WidenCmp->getOpcode() == Instruction::ICmp ||
WidenCmp->getOpcode() == Instruction::FCmp) &&
WidenCmp->getPredicate() != cast<CmpInst>(UI)->getPredicate())
return true;
}
using namespace VPlanPatternMatch;
CmpPredicate Pred;
if (match(&R, m_Cmp(Pred, m_VPValue(), m_VPValue())) &&
cast<VPRecipeWithIRFlags>(R).getPredicate() !=
cast<CmpInst>(UI)->getPredicate())
return true;
SeenInstrs.insert(UI);
}
}
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,9 @@ class VPIRFlags {

/// Returns the GEP no-wrap flags stored in GEPFlags.
GEPNoWrapFlags getGEPNoWrapFlags() const {
  return GEPFlags;
}

/// Reports whether the recipe carries a comparison predicate, i.e. whether its
/// operation type is OperationType::Cmp.
bool hasPredicate() const {
  const bool IsCompare = OpType == OperationType::Cmp;
  return IsCompare;
}

/// Reports whether the recipe carries fast-math flags, i.e. whether its
/// operation type is OperationType::FPMathOp.
bool hasFastMathFlags() const {
  const bool IsFPMathOp = OpType == OperationType::FPMathOp;
  return IsFPMathOp;
}

Expand Down
67 changes: 50 additions & 17 deletions llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -392,24 +392,32 @@ m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
return m_c_Binary<Instruction::Or, Op0_t, Op1_t>(Op0, Op1);
}

/// ICmp_match is a variant of BinaryRecipe_match that also binds the comparison
/// predicate.
template <typename Op0_t, typename Op1_t> struct ICmp_match {
/// Cmp_match is a variant of BinaryRecipe_match that also binds the comparison
/// predicate. Opcodes must be Instruction::ICmp, Instruction::FCmp, or both.
template <typename Op0_t, typename Op1_t, unsigned... Opcodes>
struct Cmp_match {
static_assert((sizeof...(Opcodes) == 1 || sizeof...(Opcodes) == 2) &&
"Expected one or two opcodes");
static_assert(
((Opcodes == Instruction::ICmp || Opcodes == Instruction::FCmp) && ...) &&
"Expected a compare instruction opcode");

CmpPredicate *Predicate = nullptr;
Op0_t Op0;
Op1_t Op1;

ICmp_match(CmpPredicate &Pred, const Op0_t &Op0, const Op1_t &Op1)
Cmp_match(CmpPredicate &Pred, const Op0_t &Op0, const Op1_t &Op1)
: Predicate(&Pred), Op0(Op0), Op1(Op1) {}
ICmp_match(const Op0_t &Op0, const Op1_t &Op1) : Op0(Op0), Op1(Op1) {}
Cmp_match(const Op0_t &Op0, const Op1_t &Op1) : Op0(Op0), Op1(Op1) {}

bool match(const VPValue *V) const {
auto *DefR = V->getDefiningRecipe();
return DefR && match(DefR);
}

bool match(const VPRecipeBase *V) const {
if (m_Binary<Instruction::ICmp>(Op0, Op1).match(V)) {
if ((m_Binary<Opcodes>(Op0, Op1).match(V) || ...)) {
if (Predicate)
*Predicate = cast<VPRecipeWithIRFlags>(V)->getPredicate();
return true;
Expand All @@ -418,38 +426,63 @@ template <typename Op0_t, typename Op1_t> struct ICmp_match {
}
};

/// SpecificICmp_match is a variant of ICmp_match that matches the comparison
/// SpecificCmp_match is a variant of Cmp_match that matches the comparison
/// predicate, instead of binding it.
template <typename Op0_t, typename Op1_t> struct SpecificICmp_match {
template <typename Op0_t, typename Op1_t, unsigned... Opcodes>
struct SpecificCmp_match {
const CmpPredicate Predicate;
Op0_t Op0;
Op1_t Op1;

SpecificICmp_match(CmpPredicate Pred, const Op0_t &LHS, const Op1_t &RHS)
SpecificCmp_match(CmpPredicate Pred, const Op0_t &LHS, const Op1_t &RHS)
: Predicate(Pred), Op0(LHS), Op1(RHS) {}

bool match(const VPValue *V) const {
CmpPredicate CurrentPred;
return ICmp_match<Op0_t, Op1_t>(CurrentPred, Op0, Op1).match(V) &&
return Cmp_match<Op0_t, Op1_t, Opcodes...>(CurrentPred, Op0, Op1)
.match(V) &&
CmpPredicate::getMatching(CurrentPred, Predicate);
}
};

template <typename Op0_t, typename Op1_t>
inline ICmp_match<Op0_t, Op1_t> m_ICmp(const Op0_t &Op0, const Op1_t &Op1) {
return ICmp_match<Op0_t, Op1_t>(Op0, Op1);
inline Cmp_match<Op0_t, Op1_t, Instruction::ICmp> m_ICmp(const Op0_t &Op0,
const Op1_t &Op1) {
return Cmp_match<Op0_t, Op1_t, Instruction::ICmp>(Op0, Op1);
}

template <typename Op0_t, typename Op1_t>
inline ICmp_match<Op0_t, Op1_t> m_ICmp(CmpPredicate &Pred, const Op0_t &Op0,
const Op1_t &Op1) {
return ICmp_match<Op0_t, Op1_t>(Pred, Op0, Op1);
inline Cmp_match<Op0_t, Op1_t, Instruction::ICmp>
m_ICmp(CmpPredicate &Pred, const Op0_t &Op0, const Op1_t &Op1) {
return Cmp_match<Op0_t, Op1_t, Instruction::ICmp>(Pred, Op0, Op1);
}

template <typename Op0_t, typename Op1_t>
inline SpecificICmp_match<Op0_t, Op1_t>
inline SpecificCmp_match<Op0_t, Op1_t, Instruction::ICmp>
m_SpecificICmp(CmpPredicate MatchPred, const Op0_t &Op0, const Op1_t &Op1) {
return SpecificICmp_match<Op0_t, Op1_t>(MatchPred, Op0, Op1);
return SpecificCmp_match<Op0_t, Op1_t, Instruction::ICmp>(MatchPred, Op0,
Op1);
}

/// Builds a matcher for an integer or floating-point compare whose operands
/// match \p Op0 and \p Op1. The comparison predicate is not bound.
template <typename Op0_t, typename Op1_t>
inline Cmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>
m_Cmp(const Op0_t &Op0, const Op1_t &Op1) {
  using AnyCmpMatcher =
      Cmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>;
  return AnyCmpMatcher(Op0, Op1);
}

/// Builds a matcher for an integer or floating-point compare whose operands
/// match \p Op0 and \p Op1. On a successful match the predicate of the matched
/// recipe is written to \p Pred.
template <typename Op0_t, typename Op1_t>
inline Cmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>
m_Cmp(CmpPredicate &Pred, const Op0_t &Op0, const Op1_t &Op1) {
  using AnyCmpMatcher =
      Cmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>;
  return AnyCmpMatcher(Pred, Op0, Op1);
}

/// Builds a matcher for an integer or floating-point compare whose operands
/// match \p Op0 and \p Op1 and whose predicate matches \p MatchPred (as decided
/// by CmpPredicate::getMatching) instead of being bound.
template <typename Op0_t, typename Op1_t>
inline SpecificCmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>
m_SpecificCmp(CmpPredicate MatchPred, const Op0_t &Op0, const Op1_t &Op1) {
  using AnySpecificCmpMatcher =
      SpecificCmp_match<Op0_t, Op1_t, Instruction::ICmp, Instruction::FCmp>;
  return AnySpecificCmpMatcher(MatchPred, Op0, Op1);
}

template <typename Op0_t, typename Op1_t>
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2930,6 +2930,9 @@ static void scalarizeInstruction(const Instruction *Instr,
RepRecipe->applyFlags(*Cloned);
RepRecipe->applyMetadata(*Cloned);

if (RepRecipe->hasPredicate())
cast<CmpInst>(Cloned)->setPredicate(RepRecipe->getPredicate());

if (auto DL = RepRecipe->getDebugLoc())
State.setDebugLocFrom(DL);

Expand Down
34 changes: 16 additions & 18 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1106,33 +1106,31 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
return Def->replaceAllUsesWith(A);

// Try to fold Not into compares by adjusting the predicate in-place.
if (auto *WideCmp = dyn_cast<VPWidenRecipe>(A)) {
if ((WideCmp->getOpcode() == Instruction::ICmp ||
WideCmp->getOpcode() == Instruction::FCmp) &&
all_of(WideCmp->users(), [&WideCmp](VPUser *U) {
return match(U, m_CombineOr(m_Not(m_Specific(WideCmp)),
m_Select(m_Specific(WideCmp),
m_VPValue(), m_VPValue())));
CmpPredicate Pred;
if (match(A, m_Cmp(Pred, m_VPValue(), m_VPValue()))) {
auto *Cmp = cast<VPRecipeWithIRFlags>(A);
if (all_of(Cmp->users(), [&Cmp](VPUser *U) {
return match(U, m_CombineOr(m_Not(m_Specific(Cmp)),
m_Select(m_Specific(Cmp), m_VPValue(),
m_VPValue())));
})) {
WideCmp->setPredicate(
CmpInst::getInversePredicate(WideCmp->getPredicate()));
for (VPUser *U : to_vector(WideCmp->users())) {
Cmp->setPredicate(CmpInst::getInversePredicate(Pred));
for (VPUser *U : to_vector(Cmp->users())) {
auto *R = cast<VPSingleDefRecipe>(U);
if (match(R, m_Select(m_Specific(WideCmp), m_VPValue(X),
m_VPValue(Y)))) {
if (match(R, m_Select(m_Specific(Cmp), m_VPValue(X), m_VPValue(Y)))) {
// select (cmp pred), x, y -> select (cmp inv_pred), y, x
R->setOperand(1, Y);
R->setOperand(2, X);
} else {
// not (cmp pred) -> cmp inv_pred
assert(match(R, m_Not(m_Specific(WideCmp))) && "Unexpected user");
R->replaceAllUsesWith(WideCmp);
assert(match(R, m_Not(m_Specific(Cmp))) && "Unexpected user");
R->replaceAllUsesWith(Cmp);
}
}
// If WideCmp doesn't have a debug location, use the one from the
// negation, to preserve the location.
if (!WideCmp->getDebugLoc() && R.getDebugLoc())
WideCmp->setDebugLoc(R.getDebugLoc());
// If Cmp doesn't have a debug location, use the one from the negation,
// to preserve the location.
if (!Cmp->getDebugLoc() && R.getDebugLoc())
Cmp->setDebugLoc(R.getDebugLoc());
}
}
}
Expand Down
24 changes: 11 additions & 13 deletions llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll
Original file line number Diff line number Diff line change
Expand Up @@ -973,18 +973,16 @@ define void @test_widen_exp_v2(ptr noalias %p2, ptr noalias %p, i64 %n) #5 {
; TFA_INTERLEAVE-NEXT: [[ACTIVE_LANE_MASK_ENTRY1:%.*]] = icmp ult i64 1, [[TMP0]]
; TFA_INTERLEAVE-NEXT: br label %[[VECTOR_BODY:.*]]
; TFA_INTERLEAVE: [[VECTOR_BODY]]:
; TFA_INTERLEAVE-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[INDEX_NEXT:%.*]], %[[TMP19:.*]] ]
; TFA_INTERLEAVE-NEXT: [[ACTIVE_LANE_MASK:%.*]] = phi i1 [ [[ACTIVE_LANE_MASK_ENTRY]], %[[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], %[[TMP19]] ]
; TFA_INTERLEAVE-NEXT: [[ACTIVE_LANE_MASK2:%.*]] = phi i1 [ [[ACTIVE_LANE_MASK_ENTRY1]], %[[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT6:%.*]], %[[TMP19]] ]
; TFA_INTERLEAVE-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[INDEX_NEXT:%.*]], %[[TMP18:.*]] ]
; TFA_INTERLEAVE-NEXT: [[ACTIVE_LANE_MASK:%.*]] = phi i1 [ [[ACTIVE_LANE_MASK_ENTRY]], %[[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT:%.*]], %[[TMP18]] ]
; TFA_INTERLEAVE-NEXT: [[ACTIVE_LANE_MASK2:%.*]] = phi i1 [ [[ACTIVE_LANE_MASK_ENTRY1]], %[[ENTRY]] ], [ [[ACTIVE_LANE_MASK_NEXT6:%.*]], %[[TMP18]] ]
; TFA_INTERLEAVE-NEXT: [[TMP4:%.*]] = load double, ptr [[P2]], align 8
; TFA_INTERLEAVE-NEXT: [[TMP5:%.*]] = tail call double @llvm.exp.f64(double [[TMP4]]) #[[ATTR7:[0-9]+]]
; TFA_INTERLEAVE-NEXT: [[TMP6:%.*]] = tail call double @llvm.exp.f64(double [[TMP4]]) #[[ATTR7]]
; TFA_INTERLEAVE-NEXT: [[TMP7:%.*]] = fcmp ogt double [[TMP5]], 0.000000e+00
; TFA_INTERLEAVE-NEXT: [[TMP8:%.*]] = fcmp ogt double [[TMP6]], 0.000000e+00
; TFA_INTERLEAVE-NEXT: [[TMP9:%.*]] = xor i1 [[TMP7]], true
; TFA_INTERLEAVE-NEXT: [[TMP10:%.*]] = xor i1 [[TMP8]], true
; TFA_INTERLEAVE-NEXT: [[TMP11:%.*]] = select i1 [[ACTIVE_LANE_MASK]], i1 [[TMP9]], i1 false
; TFA_INTERLEAVE-NEXT: [[TMP12:%.*]] = select i1 [[ACTIVE_LANE_MASK2]], i1 [[TMP10]], i1 false
; TFA_INTERLEAVE-NEXT: [[TMP7:%.*]] = fcmp ule double [[TMP5]], 0.000000e+00
; TFA_INTERLEAVE-NEXT: [[TMP8:%.*]] = fcmp ule double [[TMP6]], 0.000000e+00
; TFA_INTERLEAVE-NEXT: [[TMP11:%.*]] = select i1 [[ACTIVE_LANE_MASK]], i1 [[TMP7]], i1 false
; TFA_INTERLEAVE-NEXT: [[TMP12:%.*]] = select i1 [[ACTIVE_LANE_MASK2]], i1 [[TMP8]], i1 false
; TFA_INTERLEAVE-NEXT: [[PREDPHI:%.*]] = select i1 [[TMP11]], double 1.000000e+00, double 0.000000e+00
; TFA_INTERLEAVE-NEXT: [[PREDPHI3:%.*]] = select i1 [[TMP12]], double 1.000000e+00, double 0.000000e+00
; TFA_INTERLEAVE-NEXT: [[SPEC_SELECT:%.*]] = select i1 [[ACTIVE_LANE_MASK2]], double [[PREDPHI3]], double [[PREDPHI]]
Expand All @@ -993,11 +991,11 @@ define void @test_widen_exp_v2(ptr noalias %p2, ptr noalias %p, i64 %n) #5 {
; TFA_INTERLEAVE-NEXT: [[TMP15:%.*]] = xor i1 [[TMP13]], true
; TFA_INTERLEAVE-NEXT: [[TMP16:%.*]] = xor i1 [[TMP14]], true
; TFA_INTERLEAVE-NEXT: [[TMP17:%.*]] = or i1 [[TMP15]], [[TMP16]]
; TFA_INTERLEAVE-NEXT: br i1 [[TMP17]], label %[[BB18:.*]], label %[[TMP19]]
; TFA_INTERLEAVE: [[BB18]]:
; TFA_INTERLEAVE-NEXT: br i1 [[TMP17]], label %[[BB16:.*]], label %[[TMP18]]
; TFA_INTERLEAVE: [[BB16]]:
; TFA_INTERLEAVE-NEXT: store double [[SPEC_SELECT]], ptr [[P]], align 8
; TFA_INTERLEAVE-NEXT: br label %[[TMP19]]
; TFA_INTERLEAVE: [[TMP19]]:
; TFA_INTERLEAVE-NEXT: br label %[[TMP18]]
; TFA_INTERLEAVE: [[TMP18]]:
; TFA_INTERLEAVE-NEXT: [[INDEX_NEXT]] = add i64 [[INDEX]], 2
; TFA_INTERLEAVE-NEXT: [[TMP20:%.*]] = add i64 [[INDEX]], 1
; TFA_INTERLEAVE-NEXT: [[ACTIVE_LANE_MASK_NEXT]] = icmp ult i64 [[INDEX]], [[TMP3]]
Expand Down
Loading