diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index b3583e2819ee4..d3d68ff1c6ed2 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -2765,6 +2765,18 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
                                            Type *ResTy, VectorType *Ty,
                                            FastMathFlags FMF,
                                            TTI::TargetCostKind CostKind) {
+    if (auto *FTy = dyn_cast<FixedVectorType>(Ty);
+        FTy && IsUnsigned && Opcode == Instruction::Add &&
+        FTy->getElementType() == IntegerType::getInt1Ty(Ty->getContext())) {
+      // Represent vector_reduce_add(ZExt(<n x i1>)) as
+      // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
+      auto *IntTy =
+          IntegerType::get(ResTy->getContext(), FTy->getNumElements());
+      IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
+      return thisT()->getCastInstrCost(Instruction::BitCast, IntTy, FTy,
+                                       TTI::CastContextHint::None, CostKind) +
+             thisT()->getIntrinsicInstrCost(ICA, CostKind);
+    }
    // Without any native support, this is equivalent to the cost of
     // vecreduce.opcode(ext(Ty A)).
     VectorType *ExtTy = VectorType::get(ResTy, Ty);
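For readers of the generic change: the comment in the hunk describes an algebraic identity, namely that an unsigned add reduction of an `<n x i1>` vector equals the population count of the mask bits. The sketch below is illustrative only (the function names are invented, not from the patch); it shows the pattern being costed and the cheaper form whose cost is now returned:

```llvm
declare i32 @llvm.vector.reduce.add.v8i32(<8 x i32>)
declare i8 @llvm.ctpop.i8(i8)

; The pattern being costed: vector_reduce_add(ZExt(<8 x i1>)).
define i32 @reduce_of_mask(<8 x i1> %m) {
  %ext = zext <8 x i1> %m to <8 x i32>
  %sum = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %ext)
  ret i32 %sum
}

; The equivalent form whose cost is returned:
; ZExtOrTrunc(ctpop(bitcast <8 x i1> to i8)).
define i32 @reduce_of_mask_as_ctpop(<8 x i1> %m) {
  %bits = bitcast <8 x i1> %m to i8
  %pop = call i8 @llvm.ctpop.i8(i8 %bits)
  %sum = zext i8 %pop to i32
  ret i32 %sum
}
```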
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index 2b16dcbcd8695..d9729e06f7aea 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1620,6 +1620,14 @@ InstructionCost RISCVTTIImpl::getExtendedReductionCost(
 
   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
 
+  if (IsUnsigned && Opcode == Instruction::Add &&
+      LT.second.isFixedLengthVector() && LT.second.getScalarType() == MVT::i1) {
+    // Represent vector_reduce_add(ZExt(<n x i1>)) as
+    // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
+    return LT.first *
+           getRISCVInstructionCost(RISCV::VCPOP_M, LT.second, CostKind);
+  }
+
   if (ResTy->getScalarSizeInBits() != 2 * LT.second.getScalarSizeInBits())
     return BaseT::getExtendedReductionCost(Opcode, IsUnsigned, ResTy, ValTy,
                                            FMF, CostKind);
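On RISC-V with the vector extension, the i1 case legalizes to a single mask popcount, which is why the hunk prices it as `LT.first` copies of `VCPOP_M` (`vcpop.m`). The fall-through path still handles the classic widening case, where the result type is exactly twice the source element width. A minimal sketch of that second pattern, assuming an rv64gcv-style target (the function name is invented):

```llvm
declare i16 @llvm.vector.reduce.add.v4i16(<4 x i16>)

; zext from i8 to i16 followed by an add reduction; with +v this is
; expected to map onto a widening sum reduction (vwredsumu.vs), so the
; extend should not be priced as a separate cast instruction.
define i16 @widening_sum(<4 x i8> %v) {
  %ext = zext <4 x i8> %v to <4 x i16>
  %sum = call i16 @llvm.vector.reduce.add.v4i16(<4 x i16> %ext)
  ret i16 %sum
}
```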
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 8e0ca2677bf0a..46ae908f57ab8 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1371,22 +1371,46 @@ class BoUpSLP {
     return VectorizableTree.front()->Scalars;
   }
 
+  /// Returns the type/is-signed info for the root node in the graph without
+  /// casting.
+  std::optional<std::pair<Type *, bool>> getRootNodeTypeWithNoCast() const {
+    const TreeEntry &Root = *VectorizableTree.front().get();
+    if (Root.State != TreeEntry::Vectorize || Root.isAltShuffle() ||
+        !Root.Scalars.front()->getType()->isIntegerTy())
+      return std::nullopt;
+    auto It = MinBWs.find(&Root);
+    if (It != MinBWs.end())
+      return std::make_pair(IntegerType::get(Root.Scalars.front()->getContext(),
+                                             It->second.first),
+                            It->second.second);
+    if (Root.getOpcode() == Instruction::ZExt ||
+        Root.getOpcode() == Instruction::SExt)
+      return std::make_pair(cast<CastInst>(Root.getMainOp())->getSrcTy(),
+                            Root.getOpcode() == Instruction::SExt);
+    return std::nullopt;
+  }
+
   /// Checks if the root graph node can be emitted with narrower bitwidth at
   /// codegen and returns its signedness, if so.
   bool isSignedMinBitwidthRootNode() const {
     return MinBWs.at(VectorizableTree.front().get()).second;
   }
 
-  /// Returns reduction bitwidth and signedness, if it does not match the
-  /// original requested size.
-  std::optional<std::pair<unsigned, bool>> getReductionBitWidthAndSign() const {
+  /// Returns reduction type after minbitwidth analysis.
+  FixedVectorType *getReductionType() const {
     if (ReductionBitWidth == 0 ||
+        !VectorizableTree.front()->Scalars.front()->getType()->isIntegerTy() ||
         ReductionBitWidth >=
             DL->getTypeSizeInBits(
                 VectorizableTree.front()->Scalars.front()->getType()))
-      return std::nullopt;
-    return std::make_pair(ReductionBitWidth,
-                          MinBWs.at(VectorizableTree.front().get()).second);
+      return getWidenedType(
+          VectorizableTree.front()->Scalars.front()->getType(),
+          VectorizableTree.front()->getVectorFactor());
+    return getWidenedType(
+        IntegerType::get(
+            VectorizableTree.front()->Scalars.front()->getContext(),
+            ReductionBitWidth),
+        VectorizableTree.front()->getVectorFactor());
   }
 
   /// Builds external uses of the vectorized scalars, i.e. the list of
@@ -11297,6 +11321,20 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
         return CommonCost;
       auto *VI = VL0->getOpcode() == Opcode ? VL0 : nullptr;
       TTI::CastContextHint CCH = GetCastContextHint(VL0->getOperand(0));
+
+      bool IsArithmeticExtendedReduction =
+          E->Idx == 0 && UserIgnoreList &&
+          all_of(*UserIgnoreList, [](Value *V) {
+            auto *I = cast<Instruction>(V);
+            return is_contained({Instruction::Add, Instruction::FAdd,
+                                 Instruction::Mul, Instruction::FMul,
+                                 Instruction::And, Instruction::Or,
+                                 Instruction::Xor},
+                                I->getOpcode());
+          });
+      if (IsArithmeticExtendedReduction &&
+          (VecOpcode == Instruction::ZExt || VecOpcode == Instruction::SExt))
+        return CommonCost;
       return CommonCost +
              TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH, CostKind,
                                    VecOpcode == Opcode ? VI : nullptr);
@@ -12652,32 +12690,48 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
       unsigned SrcSize = It->second.first;
       unsigned DstSize = ReductionBitWidth;
       unsigned Opcode = Instruction::Trunc;
-      if (SrcSize < DstSize)
-        Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
-      auto *SrcVecTy =
-          getWidenedType(Builder.getIntNTy(SrcSize), E.getVectorFactor());
-      auto *DstVecTy =
-          getWidenedType(Builder.getIntNTy(DstSize), E.getVectorFactor());
-      TTI::CastContextHint CCH = getCastContextHint(E);
-      InstructionCost CastCost;
-      switch (E.getOpcode()) {
-      case Instruction::SExt:
-      case Instruction::ZExt:
-      case Instruction::Trunc: {
-        const TreeEntry *OpTE = getOperandEntry(&E, 0);
-        CCH = getCastContextHint(*OpTE);
-        break;
-      }
-      default:
-        break;
+      if (SrcSize < DstSize) {
+        bool IsArithmeticExtendedReduction =
+            all_of(*UserIgnoreList, [](Value *V) {
+              auto *I = cast<Instruction>(V);
+              return is_contained({Instruction::Add, Instruction::FAdd,
+                                   Instruction::Mul, Instruction::FMul,
+                                   Instruction::And, Instruction::Or,
+                                   Instruction::Xor},
+                                  I->getOpcode());
+            });
+        if (IsArithmeticExtendedReduction)
+          Opcode =
+              Instruction::BitCast; // Handle it by getExtendedReductionCost
+        else
+          Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
+      }
+      if (Opcode != Instruction::BitCast) {
+        auto *SrcVecTy =
+            getWidenedType(Builder.getIntNTy(SrcSize), E.getVectorFactor());
+        auto *DstVecTy =
+            getWidenedType(Builder.getIntNTy(DstSize), E.getVectorFactor());
+        TTI::CastContextHint CCH = getCastContextHint(E);
+        InstructionCost CastCost;
+        switch (E.getOpcode()) {
+        case Instruction::SExt:
+        case Instruction::ZExt:
+        case Instruction::Trunc: {
+          const TreeEntry *OpTE = getOperandEntry(&E, 0);
+          CCH = getCastContextHint(*OpTE);
+          break;
+        }
+        default:
+          break;
+        }
+        CastCost += TTI->getCastInstrCost(Opcode, DstVecTy, SrcVecTy, CCH,
+                                          TTI::TCK_RecipThroughput);
+        Cost += CastCost;
+        LLVM_DEBUG(dbgs() << "SLP: Adding cost " << CastCost
+                          << " for final resize for reduction from " << SrcVecTy
+                          << " to " << DstVecTy << "\n";
+                   dbgs() << "SLP: Current total cost = " << Cost << "\n");
       }
-      CastCost += TTI->getCastInstrCost(Opcode, DstVecTy, SrcVecTy, CCH,
-                                        TTI::TCK_RecipThroughput);
-      Cost += CastCost;
-      LLVM_DEBUG(dbgs() << "SLP: Adding cost " << CastCost
-                        << " for final resize for reduction from " << SrcVecTy
-                        << " to " << DstVecTy << "\n";
-                 dbgs() << "SLP: Current total cost = " << Cost << "\n");
     }
   }
 
@@ -19815,8 +19869,8 @@ class HorizontalReduction {
 
       // Estimate cost.
       InstructionCost TreeCost = V.getTreeCost(VL);
-      InstructionCost ReductionCost = getReductionCost(
-          TTI, VL, IsCmpSelMinMax, RdxFMF, V.getReductionBitWidthAndSign());
+      InstructionCost ReductionCost =
+          getReductionCost(TTI, VL, IsCmpSelMinMax, RdxFMF, V);
       InstructionCost Cost = TreeCost + ReductionCost;
       LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
                         << " for reduction\n");
@@ -20107,14 +20161,14 @@
 
 private:
   /// Calculate the cost of a reduction.
-  InstructionCost getReductionCost(
-      TargetTransformInfo *TTI, ArrayRef<Value *> ReducedVals,
-      bool IsCmpSelMinMax, FastMathFlags FMF,
-      const std::optional<std::pair<unsigned, bool>> BitwidthAndSign) {
+  InstructionCost getReductionCost(TargetTransformInfo *TTI,
+                                   ArrayRef<Value *> ReducedVals,
+                                   bool IsCmpSelMinMax, FastMathFlags FMF,
+                                   const BoUpSLP &R) {
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     Type *ScalarTy = ReducedVals.front()->getType();
     unsigned ReduxWidth = ReducedVals.size();
-    FixedVectorType *VectorTy = getWidenedType(ScalarTy, ReduxWidth);
+    FixedVectorType *VectorTy = R.getReductionType();
     InstructionCost VectorCost = 0, ScalarCost;
     // If all of the reduced values are constant, the vector cost is 0, since
     // the reduction value can be calculated at the compile time.
@@ -20172,21 +20226,16 @@ class HorizontalReduction {
             VecTy, APInt::getAllOnes(ScalarTyNumElements),
             /*Insert*/ true, /*Extract*/ false, TTI::TCK_RecipThroughput);
       } else {
-        auto [Bitwidth, IsSigned] =
-            BitwidthAndSign.value_or(std::make_pair(0u, false));
-        if (RdxKind == RecurKind::Add && Bitwidth == 1) {
-          // Represent vector_reduce_add(ZExt(<n x i1>)) to
-          // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
-          auto *IntTy = IntegerType::get(ScalarTy->getContext(), ReduxWidth);
-          IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
-          VectorCost =
-              TTI->getCastInstrCost(Instruction::BitCast, IntTy,
-                                    getWidenedType(ScalarTy, ReduxWidth),
-                                    TTI::CastContextHint::None, CostKind) +
-              TTI->getIntrinsicInstrCost(ICA, CostKind);
-        } else {
+        Type *RedTy = VectorTy->getElementType();
+        auto [RType, IsSigned] = R.getRootNodeTypeWithNoCast().value_or(
+            std::make_pair(RedTy, true));
+        if (RType == RedTy) {
           VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
                                                        FMF, CostKind);
+        } else {
+          VectorCost = TTI->getExtendedReductionCost(
+              RdxOpcode, !IsSigned, RedTy, getWidenedType(RType, ReduxWidth),
+              FMF, CostKind);
+        }
       }
     }
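Net effect of the SLPVectorizer changes: a root `zext`/`sext` that only feeds an arithmetic reduction is no longer charged as a standalone cast (the `Instruction::BitCast` opcode is used as a marker meaning "priced by `getExtendedReductionCost` instead"), and `getReductionCost` now derives the vector type from the min-bitwidth analysis via `getReductionType`/`getRootNodeTypeWithNoCast`. A sketch of the shape the cost model now reasons about, mirroring the test update below (illustrative; the function name is invented):

```llvm
declare i16 @llvm.vector.reduce.add.v4i16(<4 x i16>)

; MinBWs shrinks the i64 reduction to i16; the vector zext is folded into
; the extended reduction, and only the final scalar result is widened back
; to the originally requested width.
define i64 @extended_reduction_shape(<4 x i8> %v) {
  %narrow = zext <4 x i8> %v to <4 x i16>
  %red = call i16 @llvm.vector.reduce.add.v4i16(<4 x i16> %narrow)
  %wide = zext i16 %red to i64
  ret i64 %wide
}
```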
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
index bc24a44cecbe3..85131758853b3 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
@@ -877,20 +877,10 @@ entry:
 define i64 @red_zext_ld_4xi64(ptr %ptr) {
 ; CHECK-LABEL: @red_zext_ld_4xi64(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[LD0:%.*]] = load i8, ptr [[PTR:%.*]], align 1
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext i8 [[LD0]] to i64
-; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 1
-; CHECK-NEXT:    [[LD1:%.*]] = load i8, ptr [[GEP]], align 1
-; CHECK-NEXT:    [[ZEXT_1:%.*]] = zext i8 [[LD1]] to i64
-; CHECK-NEXT:    [[ADD_1:%.*]] = add nuw nsw i64 [[ZEXT]], [[ZEXT_1]]
-; CHECK-NEXT:    [[GEP_1:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 2
-; CHECK-NEXT:    [[LD2:%.*]] = load i8, ptr [[GEP_1]], align 1
-; CHECK-NEXT:    [[ZEXT_2:%.*]] = zext i8 [[LD2]] to i64
-; CHECK-NEXT:    [[ADD_2:%.*]] = add nuw nsw i64 [[ADD_1]], [[ZEXT_2]]
-; CHECK-NEXT:    [[GEP_2:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 3
-; CHECK-NEXT:    [[LD3:%.*]] = load i8, ptr [[GEP_2]], align 1
-; CHECK-NEXT:    [[ZEXT_3:%.*]] = zext i8 [[LD3]] to i64
-; CHECK-NEXT:    [[ADD_3:%.*]] = add nuw nsw i64 [[ADD_2]], [[ZEXT_3]]
+; CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i8>, ptr [[PTR:%.*]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = zext <4 x i8> [[TMP0]] to <4 x i16>
+; CHECK-NEXT:    [[TMP2:%.*]] = call i16 @llvm.vector.reduce.add.v4i16(<4 x i16> [[TMP1]])
+; CHECK-NEXT:    [[ADD_3:%.*]] = zext i16 [[TMP2]] to i64
 ; CHECK-NEXT:    ret i64 [[ADD_3]]
 ;
 entry:
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll
index e4d20a6db8fa6..09c11bbefd4a3 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/remark-zext-incoming-for-neg-icmp.ll
@@ -8,7 +8,7 @@
 ; YAML-NEXT:  Function:        test
 ; YAML-NEXT:  Args:
 ; YAML-NEXT:    - String:          'Vectorized horizontal reduction with cost '
-; YAML-NEXT:    - Cost:            '-1'
+; YAML-NEXT:    - Cost:            '-10'
 ; YAML-NEXT:    - String:          ' and with tree size '
 ; YAML-NEXT:    - TreeSize:        '8'
 ; YAML-NEXT: ...
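The remark cost dropping from -1 to -10 reflects the same modeling change: the vectorized form is now priced as a single extended reduction rather than a vector extend plus a reduction, so vectorization looks substantially more profitable. For reference, a minimal sketch of how such expectations are usually exercised; the RUN line is an assumption modeled on the existing SLP RISC-V tests, not copied from the files above:

```llvm
; RUN: opt -passes=slp-vectorizer -mtriple=riscv64 -mattr=+v -S %s | FileCheck %s
; (Assumed RUN line; full CHECK bodies would be regenerated with
; llvm/utils/update_test_checks.py after a cost-model change like this one.)

define i64 @red_zext_pair(ptr %p) {
; CHECK-LABEL: @red_zext_pair(
entry:
  %ld0 = load i8, ptr %p, align 1
  %z0 = zext i8 %ld0 to i64
  %gep = getelementptr inbounds i8, ptr %p, i64 1
  %ld1 = load i8, ptr %gep, align 1
  %z1 = zext i8 %ld1 to i64
  %sum = add nuw nsw i64 %z0, %z1
  ret i64 %sum
}
```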