diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index 7fa787bc9befd..9a50a50cee66c 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -7954,9 +7954,9 @@ bool VPRecipeBuilder::getScaledReductions( auto CollectExtInfo = [this, &Exts, &ExtOpTypes, &ExtKinds](SmallVectorImpl &Ops) -> bool { for (const auto &[I, OpI] : enumerate(Ops)) { - auto *CI = dyn_cast(OpI); - if (I > 0 && CI && - canConstantBeExtended(CI, ExtOpTypes[0], ExtKinds[0])) { + const APInt *C; + if (I > 0 && match(OpI, m_APInt(C)) && + canConstantBeExtended(C, ExtOpTypes[0], ExtKinds[0])) { ExtOpTypes[I] = ExtOpTypes[0]; ExtKinds[I] = ExtKinds[0]; continue; diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp index 07b191a787806..df52f7864c104 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp @@ -1753,14 +1753,14 @@ void LoopVectorizationPlanner::printPlans(raw_ostream &O) { } #endif -bool llvm::canConstantBeExtended(const ConstantInt *CI, Type *NarrowType, +bool llvm::canConstantBeExtended(const APInt *C, Type *NarrowType, TTI::PartialReductionExtendKind ExtKind) { - APInt TruncatedVal = CI->getValue().trunc(NarrowType->getScalarSizeInBits()); - unsigned WideSize = CI->getType()->getScalarSizeInBits(); + APInt TruncatedVal = C->trunc(NarrowType->getScalarSizeInBits()); + unsigned WideSize = C->getBitWidth(); APInt ExtendedVal = ExtKind == TTI::PR_SignExtend ? TruncatedVal.sext(WideSize) : TruncatedVal.zext(WideSize); - return ExtendedVal == CI->getValue(); + return ExtendedVal == *C; } TargetTransformInfo::OperandValueInfo diff --git a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h index fc1a09e9850f6..dea8f6723a767 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanHelpers.h +++ b/llvm/lib/Transforms/Vectorize/VPlanHelpers.h @@ -470,7 +470,7 @@ class VPlanPrinter { /// Check if a constant \p CI can be safely treated as having been extended /// from a narrower type with the given extension kind. -bool canConstantBeExtended(const ConstantInt *CI, Type *NarrowType, +bool canConstantBeExtended(const APInt *C, Type *NarrowType, TTI::PartialReductionExtendKind ExtKind); } // end namespace llvm diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h index 555efea1ea840..49dd631a07156 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h +++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h @@ -173,10 +173,10 @@ inline int_pred_ty m_ZeroInt() { /// For vectors, this includes constants with undefined elements. inline int_pred_ty m_One() { return int_pred_ty(); } -struct bind_const_int { - uint64_t &Res; +struct bind_apint { + const APInt *&Res; - bind_const_int(uint64_t &Res) : Res(Res) {} + bind_apint(const APInt *&Res) : Res(Res) {} bool match(VPValue *VPV) const { if (!VPV->isLiveIn()) @@ -188,7 +188,23 @@ struct bind_const_int { const auto *CI = dyn_cast(V); if (!CI) return false; - if (auto C = CI->getValue().tryZExtValue()) { + Res = &CI->getValue(); + return true; + } +}; + +inline bind_apint m_APInt(const APInt *&C) { return C; } + +struct bind_const_int { + uint64_t &Res; + + bind_const_int(uint64_t &Res) : Res(Res) {} + + bool match(VPValue *VPV) const { + const APInt *APConst; + if (!bind_apint(APConst).match(VPV)) + return false; + if (auto C = APConst->tryZExtValue()) { Res = *C; return true; } diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 67b9244e9dc72..3080e9dc394bd 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -341,12 +341,12 @@ VPPartialReductionRecipe::computeCost(ElementCount VF, ExtAType = GetExtendKind(ExtAR); ExtBType = GetExtendKind(ExtBR); - if (!ExtBR && Widen->getOperand(1)->isLiveIn()) { - auto *CI = cast(Widen->getOperand(1)->getLiveInIRValue()); - if (canConstantBeExtended(CI, InputTypeA, ExtAType)) { - InputTypeB = InputTypeA; - ExtBType = ExtAType; - } + using namespace VPlanPatternMatch; + const APInt *C; + if (!ExtBR && match(Widen->getOperand(1), m_APInt(C)) && + canConstantBeExtended(C, InputTypeA, ExtAType)) { + InputTypeB = InputTypeA; + ExtBType = ExtAType; } };