diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 6067b3b29ea18..a63bf1566007d 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -2935,8 +2935,8 @@ class MaskedGatherScatterSDNode : public MemSDNode {
   const SDValue &getScale() const { return getOperand(5); }
 
   static bool classof(const SDNode *N) {
-    return N->getOpcode() == ISD::MGATHER ||
-           N->getOpcode() == ISD::MSCATTER;
+    return N->getOpcode() == ISD::MGATHER || N->getOpcode() == ISD::MSCATTER ||
+           N->getOpcode() == ISD::EXPERIMENTAL_VECTOR_HISTOGRAM;
   }
 };
 
@@ -2991,17 +2991,15 @@ class MaskedScatterSDNode : public MaskedGatherScatterSDNode {
   }
 };
 
-class MaskedHistogramSDNode : public MemSDNode {
+class MaskedHistogramSDNode : public MaskedGatherScatterSDNode {
 public:
   friend class SelectionDAG;
 
   MaskedHistogramSDNode(unsigned Order, const DebugLoc &DL, SDVTList VTs,
                         EVT MemVT, MachineMemOperand *MMO,
                         ISD::MemIndexType IndexType)
-      : MemSDNode(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, Order, DL, VTs, MemVT,
-                  MMO) {
-    LSBaseSDNodeBits.AddressingMode = IndexType;
-  }
+      : MaskedGatherScatterSDNode(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, Order, DL,
+                                  VTs, MemVT, MMO, IndexType) {}
 
   ISD::MemIndexType getIndexType() const {
     return static_cast<ISD::MemIndexType>(LSBaseSDNodeBits.AddressingMode);
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index c6f6fc2508054..bfe6f4fdc2b82 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -549,6 +549,7 @@ namespace {
     SDValue visitMSTORE(SDNode *N);
     SDValue visitMGATHER(SDNode *N);
     SDValue visitMSCATTER(SDNode *N);
+    SDValue visitMHISTOGRAM(SDNode *N);
    SDValue visitVPGATHER(SDNode *N);
    SDValue visitVPSCATTER(SDNode *N);
    SDValue visitVP_STRIDED_LOAD(SDNode *N);
@@ -1972,6 +1973,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::MLOAD:              return visitMLOAD(N);
   case ISD::MSCATTER:           return visitMSCATTER(N);
   case ISD::MSTORE:             return visitMSTORE(N);
+  case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM: return visitMHISTOGRAM(N);
   case ISD::VECTOR_COMPRESS:    return visitVECTOR_COMPRESS(N);
   case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
   case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
@@ -12353,6 +12355,35 @@ SDValue DAGCombiner::visitMLOAD(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
+  MaskedHistogramSDNode *HG = cast<MaskedHistogramSDNode>(N);
+  SDValue Chain = HG->getChain();
+  SDValue Inc = HG->getInc();
+  SDValue Mask = HG->getMask();
+  SDValue BasePtr = HG->getBasePtr();
+  SDValue Index = HG->getIndex();
+  SDLoc DL(HG);
+
+  EVT MemVT = HG->getMemoryVT();
+  MachineMemOperand *MMO = HG->getMemOperand();
+  ISD::MemIndexType IndexType = HG->getIndexType();
+
+  if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
+    return Chain;
+
+  SDValue Ops[] = {Chain, Inc, Mask, BasePtr, Index,
+                   HG->getScale(), HG->getIntID()};
+  if (refineUniformBase(BasePtr, Index, HG->isIndexScaled(), DAG, DL))
+    return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL, Ops,
+                                  MMO, IndexType);
+
+  EVT DataVT = Index.getValueType();
+  if (refineIndexType(Index, IndexType, DataVT, DAG))
+    return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), MemVT, DL, Ops,
+                                  MMO, IndexType);
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
   auto *SLD = cast<VPStridedLoadSDNode>(N);
   EVT EltVT = SLD->getValueType(0).getVectorElementType();
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 48e1b96d841ef..31bd2a6f7b5a9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1114,7 +1114,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                        ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT,
                        ISD::VECREDUCE_ADD, ISD::STEP_VECTOR});
 
-  setTargetDAGCombine({ISD::MGATHER, ISD::MSCATTER});
+  setTargetDAGCombine(
+      {ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});
 
   setTargetDAGCombine(ISD::FP_EXTEND);
 
@@ -24079,11 +24080,9 @@ static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
 
 static SDValue performMaskedGatherScatterCombine(
     SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
-  MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
-  assert(MGS && "Can only combine gather load or scatter store nodes");
-
   if (!DCI.isBeforeLegalize())
     return SDValue();
 
+  MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
   SDLoc DL(MGS);
   SDValue Chain = MGS->getChain();
@@ -24105,12 +24104,18 @@ static SDValue performMaskedGatherScatterCombine(
         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
         Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
   }
-  auto *MSC = cast<MaskedScatterSDNode>(MGS);
-  SDValue Data = MSC->getValue();
-  SDValue Ops[] = {Chain, Data, Mask, BasePtr, Index, Scale};
-  return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL,
-                              Ops, MSC->getMemOperand(), IndexType,
-                              MSC->isTruncatingStore());
+  if (auto *MSC = dyn_cast<MaskedScatterSDNode>(MGS)) {
+    SDValue Data = MSC->getValue();
+    SDValue Ops[] = {Chain, Data, Mask, BasePtr, Index, Scale};
+    return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
+                                DL, Ops, MSC->getMemOperand(), IndexType,
+                                MSC->isTruncatingStore());
+  }
+  auto *HG = cast<MaskedHistogramSDNode>(MGS);
+  SDValue Ops[] = {Chain, HG->getInc(), Mask, BasePtr,
+                   Index, Scale, HG->getIntID()};
+  return DAG.getMaskedHistogram(DAG.getVTList(MVT::Other), HG->getMemoryVT(),
+                                DL, Ops, HG->getMemOperand(), IndexType);
 }
 
 /// Target-specific DAG combine function for NEON load/store intrinsics
@@ -26277,6 +26282,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performMSTORECombine(N, DCI, DAG, Subtarget);
   case ISD::MGATHER:
   case ISD::MSCATTER:
+  case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     return performMaskedGatherScatterCombine(N, DCI, DAG);
   case ISD::FP_EXTEND:
     return performFPExtendCombine(N, DAG, DCI, Subtarget);
diff --git a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
index dd0b9639a8fc2..06cd65620d1c9 100644
--- a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
@@ -267,5 +267,233 @@ define void @histogram_i16_8_lane(ptr %base, <vscale x 8 x i32> %indices, i16 %i
   ret void
 }
 
+define void @histogram_i8_zext(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i8 %inc) #0 {
+; CHECK-LABEL: histogram_i8_zext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    mov z3.s, w1
+; CHECK-NEXT:    ld1b { z2.s }, p0/z, [x0, z0.s, uxtw]
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1b { z1.s }, p0, [x0, z0.s, uxtw]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  %buckets = getelementptr i8, ptr %base, <vscale x 4 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i8(<vscale x 4 x ptr> %buckets, i8 %inc, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i16_zext(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i16 %inc) #0 {
+; CHECK-LABEL: histogram_i16_zext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    mov z3.s, w1
+; CHECK-NEXT:    ld1h { z2.s }, p0/z, [x0, z0.s, uxtw #1]
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1h { z1.s }, p0, [x0, z0.s, uxtw #1]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  %buckets = getelementptr i16, ptr %base, <vscale x 4 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 %inc, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i32_zext(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_zext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    mov z3.s, #1 // =0x1
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i32_sext(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i32_sext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    mov z3.s, #1 // =0x1
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    ret
+  %extended = sext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_zext_from_i8_to_i64(ptr %base, <vscale x 4 x i8> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_zext_from_i8_to_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    and z0.s, z0.s, #0xff
+; CHECK-NEXT:    mov z3.s, #1 // =0x1
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 4 x i8> %indices to <vscale x 4 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_zext_from_i16_to_i64(ptr %base, <vscale x 4 x i16> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_zext_from_i16_to_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    and z0.s, z0.s, #0xffff
+; CHECK-NEXT:    mov z3.s, #1 // =0x1
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 4 x i16> %indices to <vscale x 4 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_sext_from_i16_to_i64(ptr %base, <vscale x 4 x i16> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_sext_from_i16_to_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mov z3.s, #1 // =0x1
+; CHECK-NEXT:    sxth z0.s, p1/m, z0.s
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    ret
+  %extended = sext <vscale x 4 x i16> %indices to <vscale x 4 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_zext_from_i8_to_i32(ptr %base, <vscale x 4 x i8> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_zext_from_i8_to_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    and z0.s, z0.s, #0xff
+; CHECK-NEXT:    mov z3.s, #1 // =0x1
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 4 x i8> %indices to <vscale x 4 x i32>
+  %buckets = getelementptr i32, ptr %base, <vscale x 4 x i32> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_zext_from_i16_to_i32(ptr %base, <vscale x 4 x i16> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_zext_from_i16_to_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    and z0.s, z0.s, #0xffff
+; CHECK-NEXT:    mov z3.s, #1 // =0x1
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1w { z1.s }, p0, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 4 x i16> %indices to <vscale x 4 x i32>
+  %buckets = getelementptr i32, ptr %base, <vscale x 4 x i32> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_2_lane_zext(ptr %base, <vscale x 2 x i32> %indices, <vscale x 2 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_2_lane_zext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z1.d, z0.d
+; CHECK-NEXT:    mov z3.d, #1 // =0x1
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    ld1w { z2.d }, p0/z, [x0, z0.d, uxtw #2]
+; CHECK-NEXT:    and z1.d, z1.d, #0xffffffff
+; CHECK-NEXT:    histcnt z1.d, p0/z, z1.d, z1.d
+; CHECK-NEXT:    mad z1.d, p1/m, z3.d, z2.d
+; CHECK-NEXT:    st1w { z1.d }, p0, [x0, z0.d, uxtw #2]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 2 x i32> %indices to <vscale x 2 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 2 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv2p0.i32(<vscale x 2 x ptr> %buckets, i32 1, <vscale x 2 x i1> %mask)
+  ret void
+}
+
+define void @histogram_8_lane_zext(ptr %base, <vscale x 8 x i32> %indices, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_8_lane_zext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    mov z4.s, #1 // =0x1
+; CHECK-NEXT:    ptrue p2.s
+; CHECK-NEXT:    histcnt z2.s, p1/z, z0.s, z0.s
+; CHECK-NEXT:    ld1w { z3.s }, p1/z, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    mad z2.s, p2/m, z4.s, z3.s
+; CHECK-NEXT:    st1w { z2.s }, p1, [x0, z0.s, uxtw #2]
+; CHECK-NEXT:    histcnt z0.s, p0/z, z1.s, z1.s
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z1.s, uxtw #2]
+; CHECK-NEXT:    mad z0.s, p2/m, z4.s, z2.s
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0, z1.s, uxtw #2]
+; CHECK-NEXT:    ret
+  %extended = zext <vscale x 8 x i32> %indices to <vscale x 8 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 8 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 1, <vscale x 8 x i1> %mask)
+  ret void
+}
+
+define void @histogram_8_lane_sext(ptr %base, <vscale x 8 x i32> %indices, <vscale x 8 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_8_lane_sext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    mov z4.s, #1 // =0x1
+; CHECK-NEXT:    ptrue p2.s
+; CHECK-NEXT:    histcnt z2.s, p1/z, z0.s, z0.s
+; CHECK-NEXT:    ld1w { z3.s }, p1/z, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    mad z2.s, p2/m, z4.s, z3.s
+; CHECK-NEXT:    st1w { z2.s }, p1, [x0, z0.s, sxtw #2]
+; CHECK-NEXT:    histcnt z0.s, p0/z, z1.s, z1.s
+; CHECK-NEXT:    ld1w { z2.s }, p0/z, [x0, z1.s, sxtw #2]
+; CHECK-NEXT:    mad z0.s, p2/m, z4.s, z2.s
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
+; CHECK-NEXT:    ret
+  %extended = sext <vscale x 8 x i32> %indices to <vscale x 8 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 8 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 1, <vscale x 8 x i1> %mask)
+  ret void
+}
+
+define void @histogram_zero_mask(<vscale x 2 x ptr> %buckets, i64 %inc, <vscale x 2 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_zero_mask:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  call void @llvm.experimental.vector.histogram.add.nxv2p0.i64(<vscale x 2 x ptr> %buckets, i64 %inc, <vscale x 2 x i1> zeroinitializer)
+  ret void
+}
+
+define void @histogram_sext_zero_mask(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_sext_zero_mask:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %extended = sext <vscale x 4 x i32> %indices to <vscale x 4 x i64>
+  %buckets = getelementptr i32, ptr %base, <vscale x 4 x i64> %extended
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i32(<vscale x 4 x ptr> %buckets, i32 1, <vscale x 4 x i1> zeroinitializer)
+  ret void
+}
+
 attributes #0 = { "target-features"="+sve2" vscale_range(1, 16) }