Skip to content

[DAG] SDPatternMatch m_Zero/m_One/m_AllOnes have inconsistent undef h… #147044

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

woruyu
Copy link
Contributor

@woruyu woruyu commented Jul 4, 2025

Summary

This PR resolves #146871
This PR resolves #140745

Refactor m_Zero/m_One/m_AllOnes all use struct template function to match and AllowUndefs=false as default.

@llvmbot llvmbot added backend:X86 llvm:SelectionDAG SelectionDAGISel as well labels Jul 4, 2025
@llvmbot
Copy link
Member

llvmbot commented Jul 4, 2025

@llvm/pr-subscribers-backend-x86

Author: woruyu (woruyu)

Changes

This PR resolves #146871


Full diff: https://github.com/llvm/llvm-project/pull/147044.diff

5 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/SDPatternMatch.h (+33-6)
  • (modified) llvm/include/llvm/CodeGen/SelectionDAGNodes.h (+4)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+1-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+12)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+9-11)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 35322c32a8283..7c5cdbbeb0ca8 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -1100,19 +1100,46 @@ inline SpecificInt_match m_SpecificInt(uint64_t V) {
   return SpecificInt_match(APInt(64, V));
 }
 
-inline SpecificInt_match m_Zero() { return m_SpecificInt(0U); }
-inline SpecificInt_match m_One() { return m_SpecificInt(1U); }
+struct Zero_match {
+  bool AllowUndefs;
+
+  explicit Zero_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
+
+  template <typename MatchContext>
+  bool match(const MatchContext &, SDValue N) const {
+    return isZeroOrZeroSplat(N, AllowUndefs);
+  }
+};
+
+struct Ones_match {
+  bool AllowUndefs;
+
+  Ones_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
+
+  template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
+    return isOnesOrOnesSplat(N, AllowUndefs);
+  }
+};
 
 struct AllOnes_match {
+  bool AllowUndefs;
 
-  AllOnes_match() = default;
+  AllOnes_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
 
   template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
-    return isAllOnesOrAllOnesSplat(N);
+    return isAllOnesOrAllOnesSplat(N, AllowUndefs);
   }
 };
 
-inline AllOnes_match m_AllOnes() { return AllOnes_match(); }
+inline Ones_match m_One(bool AllowUndefs = false) {
+  return Ones_match(AllowUndefs);
+}
+inline Zero_match m_Zero(bool AllowUndefs = false) {
+  return Zero_match(AllowUndefs);
+}
+inline AllOnes_match m_AllOnes(bool AllowUndefs = false) {
+  return AllOnes_match(AllowUndefs);
+}
 
 /// Match true boolean value based on the information provided by
 /// TargetLowering.
@@ -1189,7 +1216,7 @@ inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) {
 
 /// Match a negate as a sub(0, v)
 template <typename ValTy>
-inline BinaryOpc_match<SpecificInt_match, ValTy> m_Neg(const ValTy &V) {
+inline BinaryOpc_match<Zero_match, ValTy, false> m_Neg(const ValTy &V) {
   return m_Sub(m_Zero(), V);
 }
 
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index a3675eecfea3f..6bfc40afeb55e 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1937,6 +1937,10 @@ LLVM_ABI bool isOneOrOneSplat(SDValue V, bool AllowUndefs = false);
 /// Does not permit build vector implicit truncation.
 LLVM_ABI bool isAllOnesOrAllOnesSplat(SDValue V, bool AllowUndefs = false);
 
+LLVM_ABI bool isOnesOrOnesSplat(SDValue N, bool AllowUndefs = false);
+
+LLVM_ABI bool isZeroOrZeroSplat(SDValue N, bool AllowUndefs = false);
+
 /// Return true if \p V is either a integer or FP constant.
 inline bool isIntOrFPConstant(SDValue V) {
   return isa<ConstantSDNode>(V) || isa<ConstantFPSDNode>(V);
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 586eb2f3cf45e..db53fb92ae08b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -4281,7 +4281,7 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
     return V;
 
   // (A - B) - 1  ->  add (xor B, -1), A
-  if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One())))
+  if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One(true))))
     return DAG.getNode(ISD::ADD, DL, VT, A, DAG.getNOT(DL, B, VT));
 
   // Look for:
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 2a3c8e2b011ad..d6605c3ec77dd 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -12569,6 +12569,18 @@ bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) {
   return C && C->isAllOnes() && C->getValueSizeInBits(0) == BitWidth;
 }
 
+bool llvm::isOnesOrOnesSplat(SDValue N, bool AllowUndefs) {
+  N = peekThroughBitcasts(N);
+  ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs);
+  return C && C->getAPIntValue() == 1;
+}
+
+bool llvm::isZeroOrZeroSplat(SDValue N, bool AllowUndefs) {
+  N = peekThroughBitcasts(N);
+  ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs, true);
+  return C && C->isZero();
+}
+
 HandleSDNode::~HandleSDNode() {
   DropOperands();
 }
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 6edbb7b1bae95..1128406236a20 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -57923,22 +57923,20 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
     }
   }
 
+  SDValue X, Y;
+
   // add(psadbw(X,0),psadbw(Y,0)) -> psadbw(add(X,Y),0)
   // iff X and Y won't overflow.
-  if (Op0.getOpcode() == X86ISD::PSADBW && Op1.getOpcode() == X86ISD::PSADBW &&
-      ISD::isBuildVectorAllZeros(Op0.getOperand(1).getNode()) &&
-      ISD::isBuildVectorAllZeros(Op1.getOperand(1).getNode())) {
-    if (DAG.willNotOverflowAdd(false, Op0.getOperand(0), Op1.getOperand(0))) {
-      MVT OpVT = Op0.getOperand(1).getSimpleValueType();
-      SDValue Sum =
-          DAG.getNode(ISD::ADD, DL, OpVT, Op0.getOperand(0), Op1.getOperand(0));
-      return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum,
-                         getZeroVector(OpVT, Subtarget, DAG, DL));
-    }
+  if (sd_match(Op0, m_c_BinOp(X86ISD::PSADBW, m_Value(X), m_Zero())) &&
+      sd_match(Op1, m_c_BinOp(X86ISD::PSADBW, m_Value(Y), m_Zero())) &&
+      DAG.willNotOverflowAdd(/*IsSigned=*/false, X, Y)) {
+    MVT OpVT = X.getSimpleValueType();
+    SDValue Sum = DAG.getNode(ISD::ADD, DL, OpVT, X, Y);
+    return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum,
+                       getZeroVector(OpVT, Subtarget, DAG, DL));
   }
 
   if (VT.isVector()) {
-    SDValue X, Y;
     EVT BoolVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
                                   VT.getVectorElementCount());
 

@llvmbot
Copy link
Member

llvmbot commented Jul 4, 2025

@llvm/pr-subscribers-llvm-selectiondag

Author: woruyu (woruyu)

Changes

This PR resolves #146871


Full diff: https://github.com/llvm/llvm-project/pull/147044.diff

5 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/SDPatternMatch.h (+33-6)
  • (modified) llvm/include/llvm/CodeGen/SelectionDAGNodes.h (+4)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+1-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+12)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+9-11)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 35322c32a8283..7c5cdbbeb0ca8 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -1100,19 +1100,46 @@ inline SpecificInt_match m_SpecificInt(uint64_t V) {
   return SpecificInt_match(APInt(64, V));
 }
 
-inline SpecificInt_match m_Zero() { return m_SpecificInt(0U); }
-inline SpecificInt_match m_One() { return m_SpecificInt(1U); }
+struct Zero_match {
+  bool AllowUndefs;
+
+  explicit Zero_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
+
+  template <typename MatchContext>
+  bool match(const MatchContext &, SDValue N) const {
+    return isZeroOrZeroSplat(N, AllowUndefs);
+  }
+};
+
+struct Ones_match {
+  bool AllowUndefs;
+
+  Ones_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
+
+  template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
+    return isOnesOrOnesSplat(N, AllowUndefs);
+  }
+};
 
 struct AllOnes_match {
+  bool AllowUndefs;
 
-  AllOnes_match() = default;
+  AllOnes_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
 
   template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
-    return isAllOnesOrAllOnesSplat(N);
+    return isAllOnesOrAllOnesSplat(N, AllowUndefs);
   }
 };
 
-inline AllOnes_match m_AllOnes() { return AllOnes_match(); }
+inline Ones_match m_One(bool AllowUndefs = false) {
+  return Ones_match(AllowUndefs);
+}
+inline Zero_match m_Zero(bool AllowUndefs = false) {
+  return Zero_match(AllowUndefs);
+}
+inline AllOnes_match m_AllOnes(bool AllowUndefs = false) {
+  return AllOnes_match(AllowUndefs);
+}
 
 /// Match true boolean value based on the information provided by
 /// TargetLowering.
@@ -1189,7 +1216,7 @@ inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) {
 
 /// Match a negate as a sub(0, v)
 template <typename ValTy>
-inline BinaryOpc_match<SpecificInt_match, ValTy> m_Neg(const ValTy &V) {
+inline BinaryOpc_match<Zero_match, ValTy, false> m_Neg(const ValTy &V) {
   return m_Sub(m_Zero(), V);
 }
 
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index a3675eecfea3f..6bfc40afeb55e 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1937,6 +1937,10 @@ LLVM_ABI bool isOneOrOneSplat(SDValue V, bool AllowUndefs = false);
 /// Does not permit build vector implicit truncation.
 LLVM_ABI bool isAllOnesOrAllOnesSplat(SDValue V, bool AllowUndefs = false);
 
+LLVM_ABI bool isOnesOrOnesSplat(SDValue N, bool AllowUndefs = false);
+
+LLVM_ABI bool isZeroOrZeroSplat(SDValue N, bool AllowUndefs = false);
+
 /// Return true if \p V is either a integer or FP constant.
 inline bool isIntOrFPConstant(SDValue V) {
   return isa<ConstantSDNode>(V) || isa<ConstantFPSDNode>(V);
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 586eb2f3cf45e..db53fb92ae08b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -4281,7 +4281,7 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
     return V;
 
   // (A - B) - 1  ->  add (xor B, -1), A
-  if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One())))
+  if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One(true))))
     return DAG.getNode(ISD::ADD, DL, VT, A, DAG.getNOT(DL, B, VT));
 
   // Look for:
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 2a3c8e2b011ad..d6605c3ec77dd 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -12569,6 +12569,18 @@ bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) {
   return C && C->isAllOnes() && C->getValueSizeInBits(0) == BitWidth;
 }
 
+bool llvm::isOnesOrOnesSplat(SDValue N, bool AllowUndefs) {
+  N = peekThroughBitcasts(N);
+  ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs);
+  return C && C->getAPIntValue() == 1;
+}
+
+bool llvm::isZeroOrZeroSplat(SDValue N, bool AllowUndefs) {
+  N = peekThroughBitcasts(N);
+  ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs, true);
+  return C && C->isZero();
+}
+
 HandleSDNode::~HandleSDNode() {
   DropOperands();
 }
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 6edbb7b1bae95..1128406236a20 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -57923,22 +57923,20 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
     }
   }
 
+  SDValue X, Y;
+
   // add(psadbw(X,0),psadbw(Y,0)) -> psadbw(add(X,Y),0)
   // iff X and Y won't overflow.
-  if (Op0.getOpcode() == X86ISD::PSADBW && Op1.getOpcode() == X86ISD::PSADBW &&
-      ISD::isBuildVectorAllZeros(Op0.getOperand(1).getNode()) &&
-      ISD::isBuildVectorAllZeros(Op1.getOperand(1).getNode())) {
-    if (DAG.willNotOverflowAdd(false, Op0.getOperand(0), Op1.getOperand(0))) {
-      MVT OpVT = Op0.getOperand(1).getSimpleValueType();
-      SDValue Sum =
-          DAG.getNode(ISD::ADD, DL, OpVT, Op0.getOperand(0), Op1.getOperand(0));
-      return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum,
-                         getZeroVector(OpVT, Subtarget, DAG, DL));
-    }
+  if (sd_match(Op0, m_c_BinOp(X86ISD::PSADBW, m_Value(X), m_Zero())) &&
+      sd_match(Op1, m_c_BinOp(X86ISD::PSADBW, m_Value(Y), m_Zero())) &&
+      DAG.willNotOverflowAdd(/*IsSigned=*/false, X, Y)) {
+    MVT OpVT = X.getSimpleValueType();
+    SDValue Sum = DAG.getNode(ISD::ADD, DL, OpVT, X, Y);
+    return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum,
+                       getZeroVector(OpVT, Subtarget, DAG, DL));
   }
 
   if (VT.isVector()) {
-    SDValue X, Y;
     EVT BoolVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
                                   VT.getVectorElementCount());
 

@woruyu woruyu force-pushed the feat/SDPatternMatch-m_Zero-m_One-m_AllOnes branch from e71c622 to 229ec28 Compare July 7, 2025 02:18
@woruyu
Copy link
Contributor Author

woruyu commented Jul 7, 2025

Hello, any suggestion for this PR, thank you! @RKSimon @mshockwave

Copy link
Contributor

@arsenm arsenm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably should get tests in unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

@woruyu
Copy link
Contributor Author

woruyu commented Jul 7, 2025

Probably should get tests in unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

That makes sense,I think adding test for m_Zero/m_One/m_AllOnes default behavior and supportting peekthrough bitcast is a better choice, I will add it!

Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM - cheers

@RKSimon RKSimon merged commit c80fa23 into llvm:main Jul 7, 2025
9 checks passed
mshockwave pushed a commit that referenced this pull request Jul 8, 2025
…hTest (#147443)

### Summary
This PR remove the extra llvm::SDPatternMatch prefix in
#147044
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Jul 8, 2025
…PatternMatchTest (#147443)

### Summary
This PR remove the extra llvm::SDPatternMatch prefix in
llvm/llvm-project#147044
@RKSimon
Copy link
Collaborator

RKSimon commented Jul 8, 2025

@woruyu do you have the add(psadbw(X,0),psadbw(Y,0)) -> psadbw(add(X,Y),0) patch somewhere?

@woruyu
Copy link
Contributor Author

woruyu commented Jul 8, 2025

@woruyu do you have the add(psadbw(X,0),psadbw(Y,0)) -> psadbw(add(X,Y),0) patch somewhere?

I originally added the SDPattern for add(psadbw(X,0), psadbw(Y,0)) -> psadbw(add(X,Y),0) in commit e71c622, but removed it during a force-push in 229ec28 based on review feedback. It’s not included in the final PR.

@RKSimon
Copy link
Collaborator

RKSimon commented Jul 8, 2025

Sorry for the confusion - I meant that I asked you in the review to remove it from #147044 but then create a separate PR for it as a followup

@woruyu
Copy link
Contributor Author

woruyu commented Jul 8, 2025

Sorry for the confusion - I meant that I asked you in the review to remove it from #147044 but then create a separate PR for it as a followup

Thanks a lot for the clarification — understand, I’d be happy to submit it as a follow-up patch.

@woruyu
Copy link
Contributor Author

woruyu commented Jul 9, 2025

Sorry for the confusion - I meant that I asked you in the review to remove it from #147044 but then create a separate PR for it as a followup

The related PR is #147637

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:X86 llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[DAG] SDPatternMatch m_Zero/m_One/m_AllOnes have inconsistent undef handling [DAG] SDPatternMatch - m_Zero can't see through bitcasts
5 participants