Skip to content

Commit e04ab76

Browse files
committed
[VPlan] Improve code around canConstantBeExtended (NFC)
Follow up on 7c4f188 ([LV] Support multiplies by constants when forming scaled reductions), introducing m_APInt, and improving code around canConstantBeExtended.
1 parent f3f9e7b commit e04ab76

File tree

5 files changed

+34
-18
lines changed

5 files changed

+34
-18
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7954,9 +7954,9 @@ bool VPRecipeBuilder::getScaledReductions(
79547954
auto CollectExtInfo = [this, &Exts, &ExtOpTypes,
79557955
&ExtKinds](SmallVectorImpl<Value *> &Ops) -> bool {
79567956
for (const auto &[I, OpI] : enumerate(Ops)) {
7957-
auto *CI = dyn_cast<ConstantInt>(OpI);
7958-
if (I > 0 && CI &&
7959-
canConstantBeExtended(CI, ExtOpTypes[0], ExtKinds[0])) {
7957+
const APInt *C;
7958+
if (I > 0 && match(OpI, m_APInt(C)) &&
7959+
canConstantBeExtended(C, ExtOpTypes[0], ExtKinds[0])) {
79607960
ExtOpTypes[I] = ExtOpTypes[0];
79617961
ExtKinds[I] = ExtKinds[0];
79627962
continue;

llvm/lib/Transforms/Vectorize/VPlan.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1753,14 +1753,14 @@ void LoopVectorizationPlanner::printPlans(raw_ostream &O) {
17531753
}
17541754
#endif
17551755

1756-
bool llvm::canConstantBeExtended(const ConstantInt *CI, Type *NarrowType,
1756+
bool llvm::canConstantBeExtended(const APInt *C, Type *NarrowType,
17571757
TTI::PartialReductionExtendKind ExtKind) {
1758-
APInt TruncatedVal = CI->getValue().trunc(NarrowType->getScalarSizeInBits());
1759-
unsigned WideSize = CI->getType()->getScalarSizeInBits();
1758+
APInt TruncatedVal = C->trunc(NarrowType->getScalarSizeInBits());
1759+
unsigned WideSize = C->getBitWidth();
17601760
APInt ExtendedVal = ExtKind == TTI::PR_SignExtend
17611761
? TruncatedVal.sext(WideSize)
17621762
: TruncatedVal.zext(WideSize);
1763-
return ExtendedVal == CI->getValue();
1763+
return ExtendedVal == *C;
17641764
}
17651765

17661766
TargetTransformInfo::OperandValueInfo

llvm/lib/Transforms/Vectorize/VPlanHelpers.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ class VPlanPrinter {
470470

471471
/// Check if a constant \p CI can be safely treated as having been extended
472472
/// from a narrower type with the given extension kind.
473-
bool canConstantBeExtended(const ConstantInt *CI, Type *NarrowType,
473+
bool canConstantBeExtended(const APInt *C, Type *NarrowType,
474474
TTI::PartialReductionExtendKind ExtKind);
475475
} // end namespace llvm
476476

llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,10 @@ inline int_pred_ty<is_zero_int> m_ZeroInt() {
173173
/// For vectors, this includes constants with undefined elements.
174174
inline int_pred_ty<is_one> m_One() { return int_pred_ty<is_one>(); }
175175

176-
struct bind_const_int {
177-
uint64_t &Res;
176+
struct bind_apint {
177+
const APInt *&Res;
178178

179-
bind_const_int(uint64_t &Res) : Res(Res) {}
179+
bind_apint(const APInt *&Res) : Res(Res) {}
180180

181181
bool match(VPValue *VPV) const {
182182
if (!VPV->isLiveIn())
@@ -188,7 +188,23 @@ struct bind_const_int {
188188
const auto *CI = dyn_cast<ConstantInt>(V);
189189
if (!CI)
190190
return false;
191-
if (auto C = CI->getValue().tryZExtValue()) {
191+
Res = &CI->getValue();
192+
return true;
193+
}
194+
};
195+
196+
inline bind_apint m_APInt(const APInt *&C) { return C; }
197+
198+
struct bind_const_int {
199+
uint64_t &Res;
200+
201+
bind_const_int(uint64_t &Res) : Res(Res) {}
202+
203+
bool match(VPValue *VPV) const {
204+
const APInt *APConst;
205+
if (!bind_apint(APConst).match(VPV))
206+
return false;
207+
if (auto C = APConst->tryZExtValue()) {
192208
Res = *C;
193209
return true;
194210
}

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -341,12 +341,12 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
341341
ExtAType = GetExtendKind(ExtAR);
342342
ExtBType = GetExtendKind(ExtBR);
343343

344-
if (!ExtBR && Widen->getOperand(1)->isLiveIn()) {
345-
auto *CI = cast<ConstantInt>(Widen->getOperand(1)->getLiveInIRValue());
346-
if (canConstantBeExtended(CI, InputTypeA, ExtAType)) {
347-
InputTypeB = InputTypeA;
348-
ExtBType = ExtAType;
349-
}
344+
using namespace VPlanPatternMatch;
345+
const APInt *C;
346+
if (!ExtBR && match(Widen->getOperand(1), m_APInt(C)) &&
347+
canConstantBeExtended(C, InputTypeA, ExtAType)) {
348+
InputTypeB = InputTypeA;
349+
ExtBType = ExtAType;
350350
}
351351
};
352352

0 commit comments

Comments
 (0)