Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions llvm/include/llvm/CodeGen/BasicTTIImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2765,6 +2765,18 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
Type *ResTy, VectorType *Ty,
FastMathFlags FMF,
TTI::TargetCostKind CostKind) {
if (auto *FTy = dyn_cast<FixedVectorType>(Ty);
FTy && IsUnsigned && Opcode == Instruction::Add &&
FTy->getElementType() == IntegerType::getInt1Ty(Ty->getContext())) {
// Represent vector_reduce_add(ZExt(<n x i1>)) as
// ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
auto *IntTy =
IntegerType::get(ResTy->getContext(), FTy->getNumElements());
IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
return thisT()->getCastInstrCost(Instruction::BitCast, IntTy, FTy,
TTI::CastContextHint::None, CostKind) +
thisT()->getIntrinsicInstrCost(ICA, CostKind);
}
// Without any native support, this is equivalent to the cost of
// vecreduce.opcode(ext(Ty A)).
VectorType *ExtTy = VectorType::get(ResTy, Ty);
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1620,6 +1620,14 @@ InstructionCost RISCVTTIImpl::getExtendedReductionCost(

std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);

if (IsUnsigned && Opcode == Instruction::Add &&
LT.second.isFixedLengthVector() && LT.second.getScalarType() == MVT::i1) {
// Represent vector_reduce_add(ZExt(<n x i1>)) as
// ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
return LT.first *
getRISCVInstructionCost(RISCV::VCPOP_M, LT.second, CostKind);
}

if (ResTy->getScalarSizeInBits() != 2 * LT.second.getScalarSizeInBits())
return BaseT::getExtendedReductionCost(Opcode, IsUnsigned, ResTy, ValTy,
FMF, CostKind);
Expand Down
151 changes: 100 additions & 51 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1371,22 +1371,46 @@ class BoUpSLP {
return VectorizableTree.front()->Scalars;
}

/// Returns the type/is-signed info for the root node in the graph without
/// casting.
std::optional<std::pair<Type *, bool>> getRootNodeTypeWithNoCast() const {
const TreeEntry &Root = *VectorizableTree.front().get();
if (Root.State != TreeEntry::Vectorize || Root.isAltShuffle() ||
!Root.Scalars.front()->getType()->isIntegerTy())
return std::nullopt;
auto It = MinBWs.find(&Root);
if (It != MinBWs.end())
return std::make_pair(IntegerType::get(Root.Scalars.front()->getContext(),
It->second.first),
It->second.second);
if (Root.getOpcode() == Instruction::ZExt ||
Root.getOpcode() == Instruction::SExt)
return std::make_pair(cast<CastInst>(Root.getMainOp())->getSrcTy(),
Root.getOpcode() == Instruction::SExt);
return std::nullopt;
}

/// Checks if the root graph node can be emitted with narrower bitwidth at
/// codegen and returns its signedness, if so.
bool isSignedMinBitwidthRootNode() const {
  // MinBWs maps a tree entry to (bitwidth, is-signed); report the signedness
  // recorded for the root entry (must be present, per the caller's contract).
  const auto &RootInfo = MinBWs.at(VectorizableTree.front().get());
  return RootInfo.second;
}

/// Returns reduction bitwidth and signedness, if it does not match the
/// original requested size.
std::optional<std::pair<unsigned, bool>> getReductionBitWidthAndSign() const {
/// Returns reduction type after minbitdth analysis.
FixedVectorType *getReductionType() const {
if (ReductionBitWidth == 0 ||
!VectorizableTree.front()->Scalars.front()->getType()->isIntegerTy() ||
ReductionBitWidth >=
DL->getTypeSizeInBits(
VectorizableTree.front()->Scalars.front()->getType()))
return std::nullopt;
return std::make_pair(ReductionBitWidth,
MinBWs.at(VectorizableTree.front().get()).second);
return getWidenedType(
VectorizableTree.front()->Scalars.front()->getType(),
VectorizableTree.front()->getVectorFactor());
return getWidenedType(
IntegerType::get(
VectorizableTree.front()->Scalars.front()->getContext(),
ReductionBitWidth),
VectorizableTree.front()->getVectorFactor());
}

/// Builds external uses of the vectorized scalars, i.e. the list of
Expand Down Expand Up @@ -11297,6 +11321,20 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
return CommonCost;
auto *VI = VL0->getOpcode() == Opcode ? VL0 : nullptr;
TTI::CastContextHint CCH = GetCastContextHint(VL0->getOperand(0));

bool IsArithmeticExtendedReduction =
E->Idx == 0 && UserIgnoreList &&
all_of(*UserIgnoreList, [](Value *V) {
auto *I = cast<Instruction>(V);
return is_contained({Instruction::Add, Instruction::FAdd,
Instruction::Mul, Instruction::FMul,
Instruction::And, Instruction::Or,
Instruction::Xor},
I->getOpcode());
});
if (IsArithmeticExtendedReduction &&
(VecOpcode == Instruction::ZExt || VecOpcode == Instruction::SExt))
return CommonCost;
return CommonCost +
TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH, CostKind,
VecOpcode == Opcode ? VI : nullptr);
Expand Down Expand Up @@ -12652,32 +12690,48 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
unsigned SrcSize = It->second.first;
unsigned DstSize = ReductionBitWidth;
unsigned Opcode = Instruction::Trunc;
if (SrcSize < DstSize)
Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
auto *SrcVecTy =
getWidenedType(Builder.getIntNTy(SrcSize), E.getVectorFactor());
auto *DstVecTy =
getWidenedType(Builder.getIntNTy(DstSize), E.getVectorFactor());
TTI::CastContextHint CCH = getCastContextHint(E);
InstructionCost CastCost;
switch (E.getOpcode()) {
case Instruction::SExt:
case Instruction::ZExt:
case Instruction::Trunc: {
const TreeEntry *OpTE = getOperandEntry(&E, 0);
CCH = getCastContextHint(*OpTE);
break;
}
default:
break;
if (SrcSize < DstSize) {
bool IsArithmeticExtendedReduction =
all_of(*UserIgnoreList, [](Value *V) {
auto *I = cast<Instruction>(V);
return is_contained({Instruction::Add, Instruction::FAdd,
Instruction::Mul, Instruction::FMul,
Instruction::And, Instruction::Or,
Instruction::Xor},
I->getOpcode());
});
if (IsArithmeticExtendedReduction)
Opcode =
Instruction::BitCast; // Handle it by getExtendedReductionCost
else
Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
}
if (Opcode != Instruction::BitCast) {
auto *SrcVecTy =
getWidenedType(Builder.getIntNTy(SrcSize), E.getVectorFactor());
auto *DstVecTy =
getWidenedType(Builder.getIntNTy(DstSize), E.getVectorFactor());
TTI::CastContextHint CCH = getCastContextHint(E);
InstructionCost CastCost;
switch (E.getOpcode()) {
case Instruction::SExt:
case Instruction::ZExt:
case Instruction::Trunc: {
const TreeEntry *OpTE = getOperandEntry(&E, 0);
CCH = getCastContextHint(*OpTE);
break;
}
default:
break;
}
CastCost += TTI->getCastInstrCost(Opcode, DstVecTy, SrcVecTy, CCH,
TTI::TCK_RecipThroughput);
Cost += CastCost;
LLVM_DEBUG(dbgs() << "SLP: Adding cost " << CastCost
<< " for final resize for reduction from " << SrcVecTy
<< " to " << DstVecTy << "\n";
dbgs() << "SLP: Current total cost = " << Cost << "\n");
}
CastCost += TTI->getCastInstrCost(Opcode, DstVecTy, SrcVecTy, CCH,
TTI::TCK_RecipThroughput);
Cost += CastCost;
LLVM_DEBUG(dbgs() << "SLP: Adding cost " << CastCost
<< " for final resize for reduction from " << SrcVecTy
<< " to " << DstVecTy << "\n";
dbgs() << "SLP: Current total cost = " << Cost << "\n");
}
}

Expand Down Expand Up @@ -19815,8 +19869,8 @@ class HorizontalReduction {

// Estimate cost.
InstructionCost TreeCost = V.getTreeCost(VL);
InstructionCost ReductionCost = getReductionCost(
TTI, VL, IsCmpSelMinMax, RdxFMF, V.getReductionBitWidthAndSign());
InstructionCost ReductionCost =
getReductionCost(TTI, VL, IsCmpSelMinMax, RdxFMF, V);
InstructionCost Cost = TreeCost + ReductionCost;
LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
<< " for reduction\n");
Expand Down Expand Up @@ -20107,14 +20161,14 @@ class HorizontalReduction {

private:
/// Calculate the cost of a reduction.
InstructionCost getReductionCost(
TargetTransformInfo *TTI, ArrayRef<Value *> ReducedVals,
bool IsCmpSelMinMax, FastMathFlags FMF,
const std::optional<std::pair<unsigned, bool>> BitwidthAndSign) {
InstructionCost getReductionCost(TargetTransformInfo *TTI,
ArrayRef<Value *> ReducedVals,
bool IsCmpSelMinMax, FastMathFlags FMF,
const BoUpSLP &R) {
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
Type *ScalarTy = ReducedVals.front()->getType();
unsigned ReduxWidth = ReducedVals.size();
FixedVectorType *VectorTy = getWidenedType(ScalarTy, ReduxWidth);
FixedVectorType *VectorTy = R.getReductionType();
InstructionCost VectorCost = 0, ScalarCost;
// If all of the reduced values are constant, the vector cost is 0, since
// the reduction value can be calculated at the compile time.
Expand Down Expand Up @@ -20172,21 +20226,16 @@ class HorizontalReduction {
VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
/*Extract*/ false, TTI::TCK_RecipThroughput);
} else {
auto [Bitwidth, IsSigned] =
BitwidthAndSign.value_or(std::make_pair(0u, false));
if (RdxKind == RecurKind::Add && Bitwidth == 1) {
// Represent vector_reduce_add(ZExt(<n x i1>)) to
// ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
auto *IntTy = IntegerType::get(ScalarTy->getContext(), ReduxWidth);
IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
VectorCost =
TTI->getCastInstrCost(Instruction::BitCast, IntTy,
getWidenedType(ScalarTy, ReduxWidth),
TTI::CastContextHint::None, CostKind) +
TTI->getIntrinsicInstrCost(ICA, CostKind);
} else {
Type *RedTy = VectorTy->getElementType();
auto [RType, IsSigned] = R.getRootNodeTypeWithNoCast().value_or(
std::make_pair(RedTy, true));
if (RType == RedTy) {
VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
FMF, CostKind);
} else {
VectorCost = TTI->getExtendedReductionCost(
RdxOpcode, !IsSigned, RedTy, getWidenedType(RType, ReduxWidth),
FMF, CostKind);
}
}
}
Expand Down
18 changes: 4 additions & 14 deletions llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
Original file line number Diff line number Diff line change
Expand Up @@ -877,20 +877,10 @@ entry:
define i64 @red_zext_ld_4xi64(ptr %ptr) {
; CHECK-LABEL: @red_zext_ld_4xi64(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[LD0:%.*]] = load i8, ptr [[PTR:%.*]], align 1
; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[LD0]] to i64
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 1
; CHECK-NEXT: [[LD1:%.*]] = load i8, ptr [[GEP]], align 1
; CHECK-NEXT: [[ZEXT_1:%.*]] = zext i8 [[LD1]] to i64
; CHECK-NEXT: [[ADD_1:%.*]] = add nuw nsw i64 [[ZEXT]], [[ZEXT_1]]
; CHECK-NEXT: [[GEP_1:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 2
; CHECK-NEXT: [[LD2:%.*]] = load i8, ptr [[GEP_1]], align 1
; CHECK-NEXT: [[ZEXT_2:%.*]] = zext i8 [[LD2]] to i64
; CHECK-NEXT: [[ADD_2:%.*]] = add nuw nsw i64 [[ADD_1]], [[ZEXT_2]]
; CHECK-NEXT: [[GEP_2:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 3
; CHECK-NEXT: [[LD3:%.*]] = load i8, ptr [[GEP_2]], align 1
; CHECK-NEXT: [[ZEXT_3:%.*]] = zext i8 [[LD3]] to i64
; CHECK-NEXT: [[ADD_3:%.*]] = add nuw nsw i64 [[ADD_2]], [[ZEXT_3]]
; CHECK-NEXT: [[TMP0:%.*]] = load <4 x i8>, ptr [[PTR:%.*]], align 1
; CHECK-NEXT: [[TMP1:%.*]] = zext <4 x i8> [[TMP0]] to <4 x i16>
; CHECK-NEXT: [[TMP2:%.*]] = call i16 @llvm.vector.reduce.add.v4i16(<4 x i16> [[TMP1]])
; CHECK-NEXT: [[ADD_3:%.*]] = zext i16 [[TMP2]] to i64
; CHECK-NEXT: ret i64 [[ADD_3]]
;
entry:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
; YAML-NEXT: Function: test
; YAML-NEXT: Args:
; YAML-NEXT: - String: 'Vectorized horizontal reduction with cost '
; YAML-NEXT: - Cost: '-1'
; YAML-NEXT: - Cost: '-10'
; YAML-NEXT: - String: ' and with tree size '
; YAML-NEXT: - TreeSize: '8'
; YAML-NEXT:...
Expand Down
Loading