Skip to content

Commit d585084

Browse files
committed
Remove extend cost for fdot and fix for Invalid costs
1 parent d7e60d3 commit d585084

File tree

3 files changed

+16
-1
lines changed

3 files changed

+16
-1
lines changed

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5844,6 +5844,14 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
58445844
return Cost;
58455845
}
58465846

5847+
// f16 -> f32 is natively supported for fdot
5848+
if (ST->isSVEorStreamingSVEAvailable() && ST->hasSVE2p1() &&
5849+
Opcode == Instruction::FAdd) {
5850+
if (AccumLT.second.getScalarType() == MVT::f32 &&
5851+
InputLT.second.getScalarType() == MVT::f16)
5852+
return Cost;
5853+
}
5854+
58475855
// Add additional cost for the extends that would need to be inserted.
58485856
return Cost + 2;
58495857
}

llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,12 @@ m_c_Mul(const Op0_t &Op0, const Op1_t &Op1) {
498498
return m_c_Binary<Instruction::Mul, Op0_t, Op1_t>(Op0, Op1);
499499
}
500500

501+
template <typename Op0_t, typename Op1_t>
502+
inline AllRecipe_match<Instruction::FMul, Op0_t, Op1_t>
503+
m_FMul(const Op0_t &Op0, const Op1_t &Op1) {
504+
return m_Binary<Instruction::FMul, Op0_t, Op1_t>(Op0, Op1);
505+
}
506+
501507
/// Match a binary AND operation.
502508
template <typename Op0_t, typename Op1_t>
503509
inline AllRecipe_commutative_match<Instruction::And, Op0_t, Op1_t>

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,8 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
317317
// FIXME: Replace the entire function with this once all partial reduction
318318
// variants are bundled into VPExpressionRecipe.
319319
if (!match(Op, m_Select(m_VPValue(), m_VPValue(Op), m_VPValue())) &&
320-
!match(Op, m_Mul(m_VPValue(), m_ConstantInt(MulConst)))) {
320+
!match(Op, m_Mul(m_VPValue(), m_ConstantInt(MulConst))) &&
321+
!match(Op, m_FMul(m_VPValue(), m_VPValue()))) {
321322
auto *PhiType = Ctx.Types.inferScalarType(getChainOp());
322323
auto *InputType = Ctx.Types.inferScalarType(getVecOp());
323324
return Ctx.TTI.getPartialReductionCost(getOpcode(), InputType, InputType,

0 commit comments

Comments
 (0)