diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 7a4dd78703155..d78cf00a5a2fc 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44957,6 +44957,7 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
     KnownBits KnownOp0, KnownOp1;
     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);
+    SDValue Op2 = Op.getOperand(2);
     // Only demand the lower 52-bits of operands 0 / 1 (and all 64-bits of
     // operand 2).
     APInt Low52Bits = APInt::getLowBitsSet(BitWidth, 52);
@@ -44967,6 +44968,13 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
     if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts, KnownOp1,
                              TLO, Depth + 1))
       return true;
+
+    // X * 0 + Y --> Y
+    // TODO: Handle cases where the lower/higher 52 bits of Op0 * Op1 are
+    // known zeroes.
+    if (KnownOp0.trunc(52).isZero() || KnownOp1.trunc(52).isZero())
+      return TLO.CombineTo(Op, Op2);
+
     // TODO: Compute the known bits for VPMADD52L/VPMADD52H.
     break;
   }
diff --git a/llvm/test/CodeGen/X86/combine-vpmadd52.ll b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
index 004db995ee584..fd295ea31c55c 100644
--- a/llvm/test/CodeGen/X86/combine-vpmadd52.ll
+++ b/llvm/test/CodeGen/X86/combine-vpmadd52.ll
@@ -102,5 +102,84 @@ define <2 x i64> @test_vpmadd52h(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> %x2) {
   %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> %and, <2 x i64> %or)
   ret <2 x i64> %1
 }
-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; CHECK: {{.*}}
+
+; Test the fold x * 0 + y -> y
+define <2 x i64> @test_vpmadd52l_mul_zero(<2 x i64> %x0, <2 x i64> %x1) {
+; CHECK-LABEL: test_vpmadd52l_mul_zero:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> <i64 0, i64 0>, <2 x i64> %x1)
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52h_mul_zero(<2 x i64> %x0, <2 x i64> %x1) {
+; CHECK-LABEL: test_vpmadd52h_mul_zero:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52h.uq.128(<2 x i64> %x0, <2 x i64> <i64 0, i64 0>, <2 x i64> %x1)
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_zero_commuted(<2 x i64> %x0, <2 x i64> %x1) {
+; CHECK-LABEL: test_vpmadd52l_mul_zero_commuted:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> %x1, <2 x i64> <i64 0, i64 0>)
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_zero_both(<2 x i64> %x0) {
+; CHECK-LABEL: test_vpmadd52l_mul_zero_both:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> <i64 0, i64 0>, <2 x i64> <i64 0, i64 0>)
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_zero_in_52bits(<2 x i64> %x0, <2 x i64> %x1) {
+; CHECK-LABEL: test_vpmadd52l_mul_zero_in_52bits:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    retq
+
+  ; mul by (1 << 52)
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> splat (i64 4503599627370496), <2 x i64> %x1)
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_add_zero(<2 x i64> %x0, <2 x i64> %x1) {
+; AVX512-LABEL: test_vpmadd52l_add_zero:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX512-NEXT:    vpmadd52luq %xmm1, %xmm0, %xmm2
+; AVX512-NEXT:    vmovdqa %xmm2, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52l_add_zero:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX-NEXT:    {vex} vpmadd52luq %xmm1, %xmm0, %xmm2
+; AVX-NEXT:    vmovdqa %xmm2, %xmm0
+; AVX-NEXT:    retq
+
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> <i64 0, i64 0>, <2 x i64> %x0, <2 x i64> %x1)
+  ret <2 x i64> %1
+}
+
+define <2 x i64> @test_vpmadd52l_mul_zero_scalar(<2 x i64> %x0, <2 x i64> %x1) {
+; AVX512-LABEL: test_vpmadd52l_mul_zero_scalar:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
+; AVX512-NEXT:    retq
+;
+; AVX-LABEL: test_vpmadd52l_mul_zero_scalar:
+; AVX:       # %bb.0:
+; AVX-NEXT:    {vex} vpmadd52luq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
+; AVX-NEXT:    retq
+
+  %1 = call <2 x i64> @llvm.x86.avx512.vpmadd52l.uq.128(<2 x i64> %x0, <2 x i64> <i64 0, i64 1>, <2 x i64> %x1)
+  ret <2 x i64> %1
+}
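
Note on why the new fold is sound: VPMADD52L/VPMADD52H read only the low 52 bits of each 64-bit lane of operands 0 and 1, form their 104-bit product, and add the low (L) or high (H) 52 bits of that product to operand 2 with a 64-bit wrap-around add. The scalar model of one lane below is a minimal illustrative sketch, not code from this patch; the helper name madd52Lane and the use of unsigned __int128 are assumptions made only for the example.

#include <cstdint>

// Illustrative one-lane reference model of VPMADD52LUQ / VPMADD52HUQ
// (hypothetical helper, not part of the patch).
static uint64_t madd52Lane(uint64_t Src0, uint64_t Src1, uint64_t Src2,
                           bool High) {
  const uint64_t Mask52 = (1ULL << 52) - 1;
  // Only the low 52 bits of each multiplicand participate; the product is a
  // 104-bit value.
  unsigned __int128 Prod =
      (unsigned __int128)(Src0 & Mask52) * (unsigned __int128)(Src1 & Mask52);
  // L form takes the low 52 bits of the product, H form the high 52 bits.
  uint64_t Half = High ? (uint64_t)(Prod >> 52) & Mask52
                       : (uint64_t)Prod & Mask52;
  // The selected half is added to the 64-bit accumulator modulo 2^64.
  return Src2 + Half;
}

Under this model, if the low 52 bits of either multiplicand are all zero, the product is zero, so both of its 52-bit halves are zero and the lane result is exactly the accumulator. That is precisely the condition the patch tests with KnownOp0.trunc(52).isZero() || KnownOp1.trunc(52).isZero() before rewriting the node to Op2.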