From ea2bc3a75ba14db5770956219529c0cb0126e8d3 Mon Sep 17 00:00:00 2001 From: Luke Lau Date: Wed, 6 Mar 2024 14:23:26 +0800 Subject: [PATCH 1/2] [RISCV] Don't run combineBinOp_VLToVWBinOp_VL until after legalize types. NFCI I noticed this from a discrepancy in fillUpExtensionSupport between how we apparently need to check for legal types for ISD::{ZERO,SIGN}_EXTEND, but we don't need to for RISCVISD::V{Z,S}EXT_VL. Prior to #72340, combineBinOp_VLToVWBinOp_VL only ran after type legalization because it only operated on _VL nodes. _VL nodes are only emitted during op legalization, which takes place **after** type legalization, which is presumably why the existing code didn't need to check for legal types. After #72340 we now handle generic ops like ISD::ADD that exist before op legalization and thus **before** type legalization. This meant that we needed to add extra checks that the narrow type was legal in #76785. I think the easiest thing to do here is to just maintain the invariant that the types are legal and only run the combine after type legalization. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 24 +++++++++------------ 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index fa37306a49990..2f624a9c1505d 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -13657,10 +13657,6 @@ struct NodeExtensionHelper { unsigned ScalarBits = VT.getScalarSizeInBits(); unsigned NarrowScalarBits = NarrowVT.getScalarSizeInBits(); - // Ensure the narrowing element type is legal - if (!Subtarget.getTargetLowering()->isTypeLegal(NarrowElt.getValueType())) - break; - // Ensure the extension's semantic is equivalent to rvv vzext or vsext. if (ScalarBits != NarrowScalarBits * 2) break; @@ -13732,14 +13728,11 @@ struct NodeExtensionHelper { } /// Check if \p Root supports any extension folding combines. - static bool isSupportedRoot(const SDNode *Root, const SelectionDAG &DAG) { - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + static bool isSupportedRoot(const SDNode *Root) { switch (Root->getOpcode()) { case ISD::ADD: case ISD::SUB: case ISD::MUL: { - if (!TLI.isTypeLegal(Root->getValueType(0))) - return false; return Root->getValueType(0).isScalableVector(); } // Vector Widening Integer Add/Sub/Mul Instructions @@ -13756,7 +13749,7 @@ struct NodeExtensionHelper { case RISCVISD::FMUL_VL: case RISCVISD::VFWADD_W_VL: case RISCVISD::VFWSUB_W_VL: - return TLI.isTypeLegal(Root->getValueType(0)); + return true; default: return false; } @@ -13765,9 +13758,10 @@ struct NodeExtensionHelper { /// Build a NodeExtensionHelper for \p Root.getOperand(\p OperandIdx). NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - assert(isSupportedRoot(Root, DAG) && "Trying to build an helper with an " - "unsupported root"); + assert(isSupportedRoot(Root) && "Trying to build an helper with an " + "unsupported root"); assert(OperandIdx < 2 && "Requesting something else than LHS or RHS"); + assert(DAG.getTargetLoweringInfo().isTypeLegal(Root->getValueType(0))); OrigOperand = Root->getOperand(OperandIdx); unsigned Opc = Root->getOpcode(); @@ -13817,7 +13811,7 @@ struct NodeExtensionHelper { static std::pair getMaskAndVL(const SDNode *Root, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { - assert(isSupportedRoot(Root, DAG) && "Unexpected root"); + assert(isSupportedRoot(Root) && "Unexpected root"); switch (Root->getOpcode()) { case ISD::ADD: case ISD::SUB: @@ -14117,8 +14111,10 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const RISCVSubtarget &Subtarget) { SelectionDAG &DAG = DCI.DAG; + if (DCI.isBeforeLegalize()) + return SDValue(); - if (!NodeExtensionHelper::isSupportedRoot(N, DAG)) + if (!NodeExtensionHelper::isSupportedRoot(N)) return SDValue(); SmallVector Worklist; @@ -14129,7 +14125,7 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N, while (!Worklist.empty()) { SDNode *Root = Worklist.pop_back_val(); - if (!NodeExtensionHelper::isSupportedRoot(Root, DAG)) + if (!NodeExtensionHelper::isSupportedRoot(Root)) return SDValue(); NodeExtensionHelper LHS(N, 0, DAG, Subtarget); From a617f3efed49268084d4942274a7776a39091a1e Mon Sep 17 00:00:00 2001 From: Luke Lau Date: Mon, 11 Mar 2024 17:06:39 +0800 Subject: [PATCH 2/2] Add assert --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 2f624a9c1505d..71759fdde9af0 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -13657,6 +13657,9 @@ struct NodeExtensionHelper { unsigned ScalarBits = VT.getScalarSizeInBits(); unsigned NarrowScalarBits = NarrowVT.getScalarSizeInBits(); + assert( + Subtarget.getTargetLowering()->isTypeLegal(NarrowElt.getValueType())); + // Ensure the extension's semantic is equivalent to rvv vzext or vsext. if (ScalarBits != NarrowScalarBits * 2) break;