diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h index 0dcf400962393..9a6bf5ffdd227 100644 --- a/llvm/include/llvm/CodeGen/SDPatternMatch.h +++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h @@ -583,6 +583,18 @@ m_InsertSubvector(const LHS &Base, const RHS &Sub, const IDX &Idx) { return TernaryOpc_match(ISD::INSERT_SUBVECTOR, Base, Sub, Idx); } +template +inline TernaryOpc_match +m_TernaryOp(unsigned Opc, const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) { + return TernaryOpc_match(Opc, Op0, Op1, Op2); +} + +template +inline TernaryOpc_match +m_c_TernaryOp(unsigned Opc, const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) { + return TernaryOpc_match(Opc, Op0, Op1, Op2); +} + template inline auto m_SelectCC(const LTy &L, const RTy &R, const TTy &T, const FTy &F, const CCTy &CC) { diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp index aa56aafa2812c..ceaee52a3948b 100644 --- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp +++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp @@ -354,6 +354,76 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) { sd_match(InsertELT, m_InsertElt(m_Value(), m_Value(), m_SpecificInt(1)))); } +TEST_F(SelectionDAGPatternMatchTest, matchGenericTernaryOp) { + SDLoc DL; + auto Float32VT = EVT::getFloatingPointVT(32); + + SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Float32VT); + SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Float32VT); + SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Float32VT); + + SDValue FMA = DAG->getNode(ISD::FMA, DL, Float32VT, Op0, Op1, Op2); + SDValue FAdd = DAG->getNode(ISD::FADD, DL, Float32VT, Op0, Op1); + + using namespace SDPatternMatch; + SDValue A, B, C; + + EXPECT_TRUE(sd_match(FMA, m_TernaryOp(ISD::FMA, m_Specific(Op0), + m_Specific(Op1), m_Specific(Op2)))); + EXPECT_FALSE(sd_match(FMA, m_TernaryOp(ISD::FADD, m_Specific(Op0), + m_Specific(Op1), m_Specific(Op2)))); + EXPECT_FALSE( + sd_match(FAdd, m_TernaryOp(ISD::FMA, m_Value(), m_Value(), m_Value()))); + EXPECT_FALSE(sd_match(FMA, m_TernaryOp(ISD::FMA, m_Specific(Op1), + m_Specific(Op0), m_Specific(Op2)))); + + EXPECT_TRUE( + sd_match(FMA, m_TernaryOp(ISD::FMA, m_Value(A), m_Value(B), m_Value(C)))); + EXPECT_EQ(A, Op0); + EXPECT_EQ(B, Op1); + EXPECT_EQ(C, Op2); + + A = B = C = SDValue(); + + EXPECT_TRUE(sd_match(FMA, m_c_TernaryOp(ISD::FMA, m_Specific(Op0), + m_Specific(Op1), m_Specific(Op2)))); + EXPECT_TRUE(sd_match(FMA, m_c_TernaryOp(ISD::FMA, m_Specific(Op1), + m_Specific(Op0), m_Specific(Op2)))); + + EXPECT_FALSE(sd_match(FMA, m_c_TernaryOp(ISD::FMA, m_Specific(Op2), + m_Specific(Op1), m_Specific(Op0)))); + EXPECT_FALSE(sd_match(FMA, m_c_TernaryOp(ISD::FMA, m_Specific(Op2), + m_Specific(Op0), m_Specific(Op1)))); + + EXPECT_FALSE(sd_match(FMA, m_c_TernaryOp(ISD::FMA, m_Specific(Op0), + m_Specific(Op2), m_Specific(Op1)))); + EXPECT_FALSE(sd_match(FMA, m_c_TernaryOp(ISD::FMA, m_Specific(Op1), + m_Specific(Op2), m_Specific(Op0)))); + + EXPECT_TRUE(sd_match( + FMA, m_c_TernaryOp(ISD::FMA, m_Value(A), m_Value(B), m_Value(C)))); + EXPECT_EQ(A, Op0); + EXPECT_EQ(B, Op1); + EXPECT_EQ(C, Op2); + + A = B = C = SDValue(); + EXPECT_TRUE(sd_match( + FMA, m_c_TernaryOp(ISD::FMA, m_Value(B), m_Value(A), m_Value(C)))); + EXPECT_EQ(A, Op1); + EXPECT_EQ(B, Op0); + EXPECT_EQ(C, Op2); + + A = B = C = SDValue(); + EXPECT_TRUE(sd_match( + FMA, m_c_TernaryOp(ISD::FMA, m_Value(A), m_Value(B), m_Value(C)))); + EXPECT_EQ(A, Op0); + EXPECT_EQ(B, Op1); + EXPECT_EQ(C, Op2); + + EXPECT_FALSE( + sd_match(FAdd, m_c_TernaryOp(ISD::FMA, m_Value(), m_Value(), m_Value()))); +} + TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) { SDLoc DL; auto Int32VT = EVT::getIntegerVT(Context, 32);