diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b4670e270141f..b606de022daf0 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -47081,7 +47081,8 @@ static SDValue combineArithReduction(SDNode *ExtElt, SelectionDAG &DAG,
 /// scalars back, while for x64 we should use 64-bit extracts and shifts.
 static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
                                        TargetLowering::DAGCombinerInfo &DCI,
-                                       const X86Subtarget &Subtarget) {
+                                       const X86Subtarget &Subtarget,
+                                       bool &TransformedBinOpReduction) {
   if (SDValue NewOp = combineExtractWithShuffle(N, DAG, DCI, Subtarget))
     return NewOp;
 
@@ -47169,23 +47170,33 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
   // Check whether this extract is the root of a sum of absolute differences
   // pattern. This has to be done here because we really want it to happen
   // pre-legalization,
-  if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget))
+  if (SDValue SAD = combineBasicSADPattern(N, DAG, Subtarget)) {
+    TransformedBinOpReduction = true;
     return SAD;
+  }
 
-  if (SDValue VPDPBUSD = combineVPDPBUSDPattern(N, DAG, Subtarget))
+  if (SDValue VPDPBUSD = combineVPDPBUSDPattern(N, DAG, Subtarget)) {
+    TransformedBinOpReduction = true;
     return VPDPBUSD;
+  }
 
   // Attempt to replace an all_of/any_of horizontal reduction with a MOVMSK.
-  if (SDValue Cmp = combinePredicateReduction(N, DAG, Subtarget))
+  if (SDValue Cmp = combinePredicateReduction(N, DAG, Subtarget)) {
+    TransformedBinOpReduction = true;
     return Cmp;
+  }
 
   // Attempt to replace min/max v8i16/v16i8 reductions with PHMINPOSUW.
-  if (SDValue MinMax = combineMinMaxReduction(N, DAG, Subtarget))
+  if (SDValue MinMax = combineMinMaxReduction(N, DAG, Subtarget)) {
+    TransformedBinOpReduction = true;
     return MinMax;
+  }
 
   // Attempt to optimize ADD/FADD/MUL reductions with HADD, promotion etc..
-  if (SDValue V = combineArithReduction(N, DAG, Subtarget))
+  if (SDValue V = combineArithReduction(N, DAG, Subtarget)) {
+    TransformedBinOpReduction = true;
     return V;
+  }
 
   if (SDValue V = scalarizeExtEltFP(N, DAG, Subtarget, DCI))
     return V;
@@ -47255,6 +47266,36 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+static SDValue
+combineExtractVectorEltAndOperand(SDNode *N, SelectionDAG &DAG,
+                                  TargetLowering::DAGCombinerInfo &DCI,
+                                  const X86Subtarget &Subtarget) {
+  bool TransformedBinOpReduction = false;
+  SDValue Op = combineExtractVectorElt(N, DAG, DCI, Subtarget,
+                                       TransformedBinOpReduction);
+
+  if (TransformedBinOpReduction) {
+    // If we simplified N = extract_vector_elt(V, 0) into Op and V resulted
+    // from a reduction, then replace all uses of V with scalar_to_vector(Op)
+    // so that the binop + shuffle pyramid feeding the extract is eliminated
+    // as well. This is safe because the elements of V are undefined except
+    // for the zeroth element, and Op does not depend on V.
+    SDValue OldV = N->getOperand(0);
+    assert(!Op.getNode()->hasPredecessor(OldV.getNode()) &&
+           "Op must not depend on the converted reduction");
+
+    SDValue NewV = DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N),
+                               OldV->getValueType(0), Op);
+
+    SDValue NV = DCI.CombineTo(N, Op);
+    DCI.CombineTo(OldV.getNode(), NewV);
+
+    Op = NV; // Return N so it doesn't get rechecked!
+  }
+
+  return Op;
+}
+
 // Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)).
 // This is more or less the reverse of combineBitcastvxi1.
 static SDValue combineToExtendBoolVectorInReg(
@@ -60702,7 +60743,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::EXTRACT_VECTOR_ELT:
   case X86ISD::PEXTRW:
   case X86ISD::PEXTRB:
-    return combineExtractVectorElt(N, DAG, DCI, Subtarget);
+    return combineExtractVectorEltAndOperand(N, DAG, DCI, Subtarget);
   case ISD::CONCAT_VECTORS:
     return combineCONCAT_VECTORS(N, DAG, DCI, Subtarget);
   case ISD::INSERT_SUBVECTOR:
diff --git a/llvm/test/CodeGen/X86/optimize-reduction.ll b/llvm/test/CodeGen/X86/optimize-reduction.ll
new file mode 100644
index 0000000000000..e51ac1bd3c13c
--- /dev/null
+++ b/llvm/test/CodeGen/X86/optimize-reduction.ll
@@ -0,0 +1,114 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-- -mattr=+sse4.1,+fast-hops | FileCheck %s --check-prefixes=SSE41
+; RUN: llc < %s -mtriple=x86_64-- -mattr=+avx2,+fast-hops | FileCheck %s --check-prefixes=AVX2
+
+define { i16, i16 } @test_reduce_v16i16_with_umin(<16 x i16> %x, <16 x i16> %y) {
+; SSE41-LABEL: test_reduce_v16i16_with_umin:
+; SSE41:       # %bb.0:
+; SSE41-NEXT:    movdqa %xmm0, %xmm4
+; SSE41-NEXT:    pminuw %xmm1, %xmm4
+; SSE41-NEXT:    phminposuw %xmm4, %xmm4
+; SSE41-NEXT:    movd %xmm4, %eax
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm4 = xmm4[0,0,0,0,4,5,6,7]
+; SSE41-NEXT:    pshufd {{.*#+}} xmm4 = xmm4[0,1,0,1]
+; SSE41-NEXT:    pcmpeqw %xmm4, %xmm1
+; SSE41-NEXT:    pcmpeqd %xmm5, %xmm5
+; SSE41-NEXT:    pxor %xmm5, %xmm1
+; SSE41-NEXT:    por %xmm3, %xmm1
+; SSE41-NEXT:    pcmpeqw %xmm4, %xmm0
+; SSE41-NEXT:    pxor %xmm5, %xmm0
+; SSE41-NEXT:    por %xmm2, %xmm0
+; SSE41-NEXT:    pminuw %xmm1, %xmm0
+; SSE41-NEXT:    phminposuw %xmm0, %xmm0
+; SSE41-NEXT:    movd %xmm0, %edx
+; SSE41-NEXT:    # kill: def $ax killed $ax killed $eax
+; SSE41-NEXT:    # kill: def $dx killed $dx killed $edx
+; SSE41-NEXT:    retq
+;
+; AVX2-LABEL: test_reduce_v16i16_with_umin:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm2
+; AVX2-NEXT:    vpminuw %xmm2, %xmm0, %xmm2
+; AVX2-NEXT:    vphminposuw %xmm2, %xmm2
+; AVX2-NEXT:    vmovd %xmm2, %eax
+; AVX2-NEXT:    vpbroadcastw %xmm2, %ymm2
+; AVX2-NEXT:    vpcmpeqw %ymm2, %ymm0, %ymm0
+; AVX2-NEXT:    vpcmpeqd %ymm2, %ymm2, %ymm2
+; AVX2-NEXT:    vpxor %ymm2, %ymm0, %ymm0
+; AVX2-NEXT:    vpor %ymm1, %ymm0, %ymm0
+; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; AVX2-NEXT:    vpminuw %xmm1, %xmm0, %xmm0
+; AVX2-NEXT:    vphminposuw %xmm0, %xmm0
+; AVX2-NEXT:    vmovd %xmm0, %edx
+; AVX2-NEXT:    # kill: def $ax killed $ax killed $eax
+; AVX2-NEXT:    # kill: def $dx killed $dx killed $edx
+; AVX2-NEXT:    vzeroupper
+; AVX2-NEXT:    retq
+  %min_x = tail call i16 @llvm.vector.reduce.umin.v16i16(<16 x i16> %x)
+  %min_x_vec = insertelement <1 x i16> poison, i16 %min_x, i64 0
+  %min_x_splat = shufflevector <1 x i16> %min_x_vec, <1 x i16> poison, <16 x i32> zeroinitializer
+  %cmp = icmp eq <16 x i16> %x, %min_x_splat
+  %select = select <16 x i1> %cmp, <16 x i16> %y, <16 x i16> splat (i16 -1)
+  %select_min = tail call i16 @llvm.vector.reduce.umin.v16i16(<16 x i16> %select)
+  %ret_0 = insertvalue { i16, i16 } poison, i16 %min_x, 0
+  %ret = insertvalue { i16, i16 } %ret_0, i16 %select_min, 1
+  ret { i16, i16 } %ret
+}
+
+define { i16, i16 } @test_reduce_v16i16_with_add(<16 x i16> %x, <16 x i16> %y) {
+; SSE41-LABEL: test_reduce_v16i16_with_add:
+; SSE41:       # %bb.0: # %start
+; SSE41-NEXT:    movdqa %xmm1, %xmm4
+; SSE41-NEXT:    phaddw %xmm0, %xmm4
+; SSE41-NEXT:    phaddw %xmm4, %xmm4
+; SSE41-NEXT:    phaddw %xmm4, %xmm4
+; SSE41-NEXT:    phaddw %xmm4, %xmm4
+; SSE41-NEXT:    movd %xmm4, %eax
+; SSE41-NEXT:    pshuflw {{.*#+}} xmm4 = xmm4[0,0,0,0,4,5,6,7]
+; SSE41-NEXT:    pshufd {{.*#+}} xmm4 = xmm4[0,1,0,1]
+; SSE41-NEXT:    pcmpeqw %xmm4, %xmm1
+; SSE41-NEXT:    pcmpeqd %xmm5, %xmm5
+; SSE41-NEXT:    pxor %xmm5, %xmm1
+; SSE41-NEXT:    por %xmm3, %xmm1
+; SSE41-NEXT:    pcmpeqw %xmm4, %xmm0
+; SSE41-NEXT:    pxor %xmm5, %xmm0
+; SSE41-NEXT:    por %xmm2, %xmm0
+; SSE41-NEXT:    pminuw %xmm1, %xmm0
+; SSE41-NEXT:    phminposuw %xmm0, %xmm0
+; SSE41-NEXT:    movd %xmm0, %edx
+; SSE41-NEXT:    # kill: def $ax killed $ax killed $eax
+; SSE41-NEXT:    # kill: def $dx killed $dx killed $edx
+; SSE41-NEXT:    retq
+;
+; AVX2-LABEL: test_reduce_v16i16_with_add:
+; AVX2:       # %bb.0: # %start
+; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm2
+; AVX2-NEXT:    vphaddw %xmm0, %xmm2, %xmm2
+; AVX2-NEXT:    vphaddw %xmm2, %xmm2, %xmm2
+; AVX2-NEXT:    vphaddw %xmm2, %xmm2, %xmm2
+; AVX2-NEXT:    vphaddw %xmm2, %xmm2, %xmm2
+; AVX2-NEXT:    vmovd %xmm2, %eax
+; AVX2-NEXT:    vpbroadcastw %xmm2, %ymm2
+; AVX2-NEXT:    vpcmpeqw %ymm2, %ymm0, %ymm0
+; AVX2-NEXT:    vpcmpeqd %ymm2, %ymm2, %ymm2
+; AVX2-NEXT:    vpxor %ymm2, %ymm0, %ymm0
+; AVX2-NEXT:    vpor %ymm1, %ymm0, %ymm0
+; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm1
+; AVX2-NEXT:    vpminuw %xmm1, %xmm0, %xmm0
+; AVX2-NEXT:    vphminposuw %xmm0, %xmm0
+; AVX2-NEXT:    vmovd %xmm0, %edx
+; AVX2-NEXT:    # kill: def $ax killed $ax killed $eax
+; AVX2-NEXT:    # kill: def $dx killed $dx killed $edx
+; AVX2-NEXT:    vzeroupper
+; AVX2-NEXT:    retq
+start:
+  %sum_x = tail call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> %x)
+  %sum_x_vec = insertelement <1 x i16> poison, i16 %sum_x, i64 0
+  %sum_x_splat = shufflevector <1 x i16> %sum_x_vec, <1 x i16> poison, <16 x i32> zeroinitializer
+  %cmp = icmp eq <16 x i16> %x, %sum_x_splat
+  %select = select <16 x i1> %cmp, <16 x i16> %y, <16 x i16> splat (i16 -1)
+  %select_min = tail call i16 @llvm.vector.reduce.umin.v16i16(<16 x i16> %select)
+  %ret_0 = insertvalue { i16, i16 } poison, i16 %sum_x, 0
+  %ret = insertvalue { i16, i16 } %ret_0, i16 %select_min, 1
+  ret { i16, i16 } %ret
+}
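
Reviewer note, not part of the patch: a minimal IR sketch of the shape this change targets, distilled from test_reduce_v16i16_with_umin above. The function name and the v8i16 types below are illustrative only.

; When the reduced scalar is splatted back and compared against the source
; vector, the splat can end up reading the reduction's binop + shuffle
; pyramid directly, so the pyramid keeps a user even after the extract has
; been combined (e.g. into PHMINPOSUW). Replacing all uses of the reduction
; vector with scalar_to_vector of the combined result lets both users share
; the combined value instead of keeping the pyramid alive.
define <8 x i16> @umin_splat_compare_sketch(<8 x i16> %x) {
  %m = tail call i16 @llvm.vector.reduce.umin.v8i16(<8 x i16> %x)
  %m.ins = insertelement <8 x i16> poison, i16 %m, i64 0
  %m.splat = shufflevector <8 x i16> %m.ins, <8 x i16> poison, <8 x i32> zeroinitializer
  %is.min = icmp eq <8 x i16> %x, %m.splat
  %masked = select <8 x i1> %is.min, <8 x i16> %x, <8 x i16> zeroinitializer
  ret <8 x i16> %masked
}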