diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 248a107ded514..427b8bd0e75ab 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -5086,6 +5086,7 @@ BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
           VecLdCost +=
               TTI.getInstructionCost(cast<Instruction>(VL[Idx]), CostKind);
     }
+    unsigned ScalarTyNumElements = getNumElements(ScalarTy);
     auto *SubVecTy = getWidenedType(ScalarTy, VF);
     for (auto [I, LS] : enumerate(States)) {
       auto *LI0 = cast<LoadInst>(VL[I * VF]);
@@ -5109,11 +5110,12 @@ BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
               SubVecTy, APInt::getAllOnes(VF),
               /*Insert=*/true, /*Extract=*/false, CostKind);
         else
-          VectorGEPCost += TTI.getScalarizationOverhead(
-                               SubVecTy, APInt::getOneBitSet(VF, 0),
-                               /*Insert=*/true, /*Extract=*/false, CostKind) +
-                           ::getShuffleCost(TTI, TTI::SK_Broadcast, SubVecTy,
-                                            {}, CostKind);
+          VectorGEPCost +=
+              TTI.getScalarizationOverhead(
+                  SubVecTy, APInt::getOneBitSet(ScalarTyNumElements * VF, 0),
+                  /*Insert=*/true, /*Extract=*/false, CostKind) +
+              ::getShuffleCost(TTI, TTI::SK_Broadcast, SubVecTy, {},
+                               CostKind);
       }
       switch (LS) {
       case LoadsState::Vectorize:
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/revec.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/revec.ll
index 65d0078080d22..0cf4da623a0fe 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/revec.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/revec.ll
@@ -40,3 +40,57 @@ sw.bb509.i: ; preds = %if.then458.i, %if.e
   %5 = phi <2 x i32> [ %1, %if.then458.i ], [ zeroinitializer, %if.end.i87 ], [ zeroinitializer, %if.end.i87 ]
   ret i32 0
 }
+
+define void @test2() {
+; CHECK-LABEL: @test2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr i8, ptr null, i64 132
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i8, ptr null, i64 200
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i8, ptr null, i64 300
+; CHECK-NEXT:    [[TMP3:%.*]] = load <8 x float>, ptr [[TMP1]], align 4
+; CHECK-NEXT:    [[TMP4:%.*]] = load <8 x float>, ptr [[TMP2]], align 4
+; CHECK-NEXT:    [[TMP5:%.*]] = load <16 x float>, ptr [[TMP0]], align 4
+; CHECK-NEXT:    [[TMP6:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> poison, <8 x float> [[TMP4]], i64 0)
+; CHECK-NEXT:    [[TMP7:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> [[TMP6]], <8 x float> [[TMP3]], i64 8)
+; CHECK-NEXT:    [[TMP8:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v16f32(<32 x float> [[TMP7]], <16 x float> [[TMP5]], i64 16)
+; CHECK-NEXT:    [[TMP9:%.*]] = fpext <32 x float> [[TMP8]] to <32 x double>
+; CHECK-NEXT:    [[TMP10:%.*]] = call <32 x double> @llvm.vector.insert.v32f64.v8f64(<32 x double> poison, <8 x double> zeroinitializer, i64 0)
+; CHECK-NEXT:    [[TMP11:%.*]] = call <32 x double> @llvm.vector.insert.v32f64.v8f64(<32 x double> [[TMP10]], <8 x double> zeroinitializer, i64 8)
+; CHECK-NEXT:    [[TMP12:%.*]] = call <32 x double> @llvm.vector.insert.v32f64.v8f64(<32 x double> [[TMP11]], <8 x double> zeroinitializer, i64 16)
+; CHECK-NEXT:    [[TMP13:%.*]] = call <32 x double> @llvm.vector.insert.v32f64.v8f64(<32 x double> [[TMP12]], <8 x double> zeroinitializer, i64 24)
+; CHECK-NEXT:    [[TMP14:%.*]] = fadd <32 x double> [[TMP13]], [[TMP9]]
+; CHECK-NEXT:    [[TMP15:%.*]] = fptrunc <32 x double> [[TMP14]] to <32 x float>
+; CHECK-NEXT:    [[TMP16:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> poison, <8 x float> zeroinitializer, i64 0)
+; CHECK-NEXT:    [[TMP17:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> [[TMP16]], <8 x float> zeroinitializer, i64 8)
+; CHECK-NEXT:    [[TMP18:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> [[TMP17]], <8 x float> zeroinitializer, i64 16)
+; CHECK-NEXT:    [[TMP19:%.*]] = call <32 x float> @llvm.vector.insert.v32f32.v8f32(<32 x float> [[TMP18]], <8 x float> zeroinitializer, i64 24)
+; CHECK-NEXT:    [[TMP20:%.*]] = fcmp ogt <32 x float> [[TMP19]], [[TMP15]]
+; CHECK-NEXT:    ret void
+;
+entry:
+  %0 = getelementptr i8, ptr null, i64 132
+  %1 = getelementptr i8, ptr null, i64 164
+  %2 = getelementptr i8, ptr null, i64 200
+  %3 = getelementptr i8, ptr null, i64 300
+  %4 = load <8 x float>, ptr %0, align 4
+  %5 = load <8 x float>, ptr %1, align 4
+  %6 = load <8 x float>, ptr %2, align 4
+  %7 = load <8 x float>, ptr %3, align 4
+  %8 = fpext <8 x float> %4 to <8 x double>
+  %9 = fpext <8 x float> %5 to <8 x double>
+  %10 = fpext <8 x float> %6 to <8 x double>
+  %11 = fpext <8 x float> %7 to <8 x double>
+  %12 = fadd <8 x double> zeroinitializer, %8
+  %13 = fadd <8 x double> zeroinitializer, %9
+  %14 = fadd <8 x double> zeroinitializer, %10
+  %15 = fadd <8 x double> zeroinitializer, %11
+  %16 = fptrunc <8 x double> %12 to <8 x float>
+  %17 = fptrunc <8 x double> %13 to <8 x float>
+  %18 = fptrunc <8 x double> %14 to <8 x float>
+  %19 = fptrunc <8 x double> %15 to <8 x float>
+  %20 = fcmp ogt <8 x float> zeroinitializer, %16
+  %21 = fcmp ogt <8 x float> zeroinitializer, %17
+  %22 = fcmp ogt <8 x float> zeroinitializer, %18
+  %23 = fcmp ogt <8 x float> zeroinitializer, %19
+  ret void
+}
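Note on the fix (a reviewer sketch, not part of the patch): under REVEC, ScalarTy can itself be a fixed vector (here <8 x float>), so SubVecTy = getWidenedType(ScalarTy, VF) has getNumElements(ScalarTy) * VF scalar elements. The demanded-elements APInt passed to getScalarizationOverhead must be as wide as SubVecTy's element count, so demanding the single broadcast lane needs APInt::getOneBitSet(ScalarTyNumElements * VF, 0) rather than a VF-bit mask. A minimal standalone C++ sketch of just that mask arithmetic, with values chosen hypothetically to match the test above (VF = 4 groups of ScalarTy == <8 x float>):

  // Sketch only: mirrors the mask-width arithmetic from the patch using
  // llvm::APInt directly; it does not call into the SLP vectorizer.
  #include "llvm/ADT/APInt.h"
  #include <cassert>

  int main() {
    const unsigned VF = 4;                  // SLP vectorization factor (lanes)
    const unsigned ScalarTyNumElements = 8; // getNumElements(<8 x float>)

    // SubVecTy = getWidenedType(ScalarTy, VF) has 8 * 4 = 32 elements, so a
    // demanded-elements mask for it must be 32 bits wide.
    llvm::APInt Demanded =
        llvm::APInt::getOneBitSet(ScalarTyNumElements * VF, /*BitNo=*/0);
    assert(Demanded.getBitWidth() == 32 && Demanded[0]);

    // The pre-patch mask, APInt::getOneBitSet(VF, 0), was only VF bits wide,
    // too narrow for SubVecTy once ScalarTy is itself a vector.
    return 0;
  }

getScalarizationOverhead expects the mask width to match the vector type's element count, which is why the narrow VF-bit mask was wrong for ReVec inputs even though it was fine when ScalarTy was a true scalar (ScalarTyNumElements == 1).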