From 324cca20a5d025ee1b8aaaa8112e9d5578107356 Mon Sep 17 00:00:00 2001 From: Michael-Chen-NJU <2802328816@qq.com> Date: Wed, 29 Oct 2025 14:57:06 +0800 Subject: [PATCH 1/4] [DAG] Add generic m_TernaryOp() / m_c_TernaryOp() matchers and corresponding tests --- llvm/include/llvm/CodeGen/SDPatternMatch.h | 12 ++++ .../CodeGen/SelectionDAGPatternMatchTest.cpp | 70 +++++++++++++++++++ 2 files changed, 82 insertions(+) diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h index 0dcf400962393..9a6bf5ffdd227 100644 --- a/llvm/include/llvm/CodeGen/SDPatternMatch.h +++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h @@ -583,6 +583,18 @@ m_InsertSubvector(const LHS &Base, const RHS &Sub, const IDX &Idx) { return TernaryOpc_match(ISD::INSERT_SUBVECTOR, Base, Sub, Idx); } +template +inline TernaryOpc_match +m_TernaryOp(unsigned Opc, const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) { + return TernaryOpc_match(Opc, Op0, Op1, Op2); +} + +template +inline TernaryOpc_match +m_c_TernaryOp(unsigned Opc, const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) { + return TernaryOpc_match(Opc, Op0, Op1, Op2); +} + template inline auto m_SelectCC(const LTy &L, const RTy &R, const TTy &T, const FTy &F, const CCTy &CC) { diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp index aa56aafa2812c..ceaee52a3948b 100644 --- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp +++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp @@ -354,6 +354,76 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) { sd_match(InsertELT, m_InsertElt(m_Value(), m_Value(), m_SpecificInt(1)))); } +TEST_F(SelectionDAGPatternMatchTest, matchGenericTernaryOp) { + SDLoc DL; + auto Float32VT = EVT::getFloatingPointVT(32); + + SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Float32VT); + SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Float32VT); + SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Float32VT); + + SDValue FMA = DAG->getNode(ISD::FMA, DL, Float32VT, Op0, Op1, Op2); + SDValue FAdd = DAG->getNode(ISD::FADD, DL, Float32VT, Op0, Op1); + + using namespace SDPatternMatch; + SDValue A, B, C; + + EXPECT_TRUE(sd_match(FMA, m_TernaryOp(ISD::FMA, m_Specific(Op0), + m_Specific(Op1), m_Specific(Op2)))); + EXPECT_FALSE(sd_match(FMA, m_TernaryOp(ISD::FADD, m_Specific(Op0), + m_Specific(Op1), m_Specific(Op2)))); + EXPECT_FALSE( + sd_match(FAdd, m_TernaryOp(ISD::FMA, m_Value(), m_Value(), m_Value()))); + EXPECT_FALSE(sd_match(FMA, m_TernaryOp(ISD::FMA, m_Specific(Op1), + m_Specific(Op0), m_Specific(Op2)))); + + EXPECT_TRUE( + sd_match(FMA, m_TernaryOp(ISD::FMA, m_Value(A), m_Value(B), m_Value(C)))); + EXPECT_EQ(A, Op0); + EXPECT_EQ(B, Op1); + EXPECT_EQ(C, Op2); + + A = B = C = SDValue(); + + EXPECT_TRUE(sd_match(FMA, m_c_TernaryOp(ISD::FMA, m_Specific(Op0), + m_Specific(Op1), m_Specific(Op2)))); + EXPECT_TRUE(sd_match(FMA, m_c_TernaryOp(ISD::FMA, m_Specific(Op1), + m_Specific(Op0), m_Specific(Op2)))); + + EXPECT_FALSE(sd_match(FMA, m_c_TernaryOp(ISD::FMA, m_Specific(Op2), + m_Specific(Op1), m_Specific(Op0)))); + EXPECT_FALSE(sd_match(FMA, m_c_TernaryOp(ISD::FMA, m_Specific(Op2), + m_Specific(Op0), m_Specific(Op1)))); + + EXPECT_FALSE(sd_match(FMA, m_c_TernaryOp(ISD::FMA, m_Specific(Op0), + m_Specific(Op2), m_Specific(Op1)))); + EXPECT_FALSE(sd_match(FMA, m_c_TernaryOp(ISD::FMA, m_Specific(Op1), + m_Specific(Op2), m_Specific(Op0)))); + + EXPECT_TRUE(sd_match( + FMA, m_c_TernaryOp(ISD::FMA, m_Value(A), m_Value(B), m_Value(C)))); + EXPECT_EQ(A, Op0); + EXPECT_EQ(B, Op1); + EXPECT_EQ(C, Op2); + + A = B = C = SDValue(); + EXPECT_TRUE(sd_match( + FMA, m_c_TernaryOp(ISD::FMA, m_Value(B), m_Value(A), m_Value(C)))); + EXPECT_EQ(A, Op1); + EXPECT_EQ(B, Op0); + EXPECT_EQ(C, Op2); + + A = B = C = SDValue(); + EXPECT_TRUE(sd_match( + FMA, m_c_TernaryOp(ISD::FMA, m_Value(A), m_Value(B), m_Value(C)))); + EXPECT_EQ(A, Op0); + EXPECT_EQ(B, Op1); + EXPECT_EQ(C, Op2); + + EXPECT_FALSE( + sd_match(FAdd, m_c_TernaryOp(ISD::FMA, m_Value(), m_Value(), m_Value()))); +} + TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) { SDLoc DL; auto Int32VT = EVT::getIntegerVT(Context, 32); From de8c30dcc3f78c497bfe8448ae5a56b141c80800 Mon Sep 17 00:00:00 2001 From: Michael-Chen-NJU <2802328816@qq.com> Date: Wed, 29 Oct 2025 17:01:00 +0800 Subject: [PATCH 2/4] [DAG] Optimize FMA handling for constant 1.0 by transforming to FADD --- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index cf221bba1e3a3..c9ed7b8e4a7d3 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -18369,11 +18369,14 @@ template SDValue DAGCombiner::visitFMA(SDNode *N) { } } - // FIXME: Support splat of constant. - if (N0CFP && N0CFP->isExactlyValue(1.0)) - return matcher.getNode(ISD::FADD, DL, VT, N1, N2); - if (N1CFP && N1CFP->isExactlyValue(1.0)) - return matcher.getNode(ISD::FADD, DL, VT, N0, N2); + using namespace SDPatternMatch; + SDValue X, Y, Cst; + + // (fma 1.0, X, Y) or (fma X, 1.0, Y) -> (fadd X, Y) + SDValue C1 = DAG.getConstantFP(1.0, DL, VT); + if (sd_match(N, + m_c_TernaryOp(ISD::FMA, m_Specific(C1), m_Value(X), m_Value(Y)))) + return matcher.getNode(ISD::FADD, DL, VT, X, Y); // Canonicalize (fma c, x, y) -> (fma x, c, y) if (DAG.isConstantFPBuildVectorOrConstantFP(N0) && From c60fccef3fc5ad6d3ddfdd69463d2b76f65523fe Mon Sep 17 00:00:00 2001 From: Michael-Chen-NJU <2802328816@qq.com> Date: Wed, 29 Oct 2025 21:18:47 +0800 Subject: [PATCH 3/4] [DAG] Remove unnecessary code --- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index c9ed7b8e4a7d3..450850e784dee 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -18369,8 +18369,7 @@ template SDValue DAGCombiner::visitFMA(SDNode *N) { } } - using namespace SDPatternMatch; - SDValue X, Y, Cst; + SDValue X, Y; // (fma 1.0, X, Y) or (fma X, 1.0, Y) -> (fadd X, Y) SDValue C1 = DAG.getConstantFP(1.0, DL, VT); From fa16ae7ca2de62655ec4164b452867dbcbce65fb Mon Sep 17 00:00:00 2001 From: Michael-Chen-NJU <2802328816@qq.com> Date: Wed, 29 Oct 2025 22:21:34 +0800 Subject: [PATCH 4/4] [DAG] revert --- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 450850e784dee..cf221bba1e3a3 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -18369,13 +18369,11 @@ template SDValue DAGCombiner::visitFMA(SDNode *N) { } } - SDValue X, Y; - - // (fma 1.0, X, Y) or (fma X, 1.0, Y) -> (fadd X, Y) - SDValue C1 = DAG.getConstantFP(1.0, DL, VT); - if (sd_match(N, - m_c_TernaryOp(ISD::FMA, m_Specific(C1), m_Value(X), m_Value(Y)))) - return matcher.getNode(ISD::FADD, DL, VT, X, Y); + // FIXME: Support splat of constant. + if (N0CFP && N0CFP->isExactlyValue(1.0)) + return matcher.getNode(ISD::FADD, DL, VT, N1, N2); + if (N1CFP && N1CFP->isExactlyValue(1.0)) + return matcher.getNode(ISD::FADD, DL, VT, N0, N2); // Canonicalize (fma c, x, y) -> (fma x, c, y) if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&