-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[DAG] SDPatternMatch m_Zero/m_One/m_AllOnes have inconsistent undef h… #147044
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DAG] SDPatternMatch m_Zero/m_One/m_AllOnes have inconsistent undef h… #147044
Conversation
@llvm/pr-subscribers-backend-x86 Author: woruyu (woruyu) ChangesThis PR resolves #146871 Full diff: https://github.com/llvm/llvm-project/pull/147044.diff 5 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 35322c32a8283..7c5cdbbeb0ca8 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -1100,19 +1100,46 @@ inline SpecificInt_match m_SpecificInt(uint64_t V) {
return SpecificInt_match(APInt(64, V));
}
-inline SpecificInt_match m_Zero() { return m_SpecificInt(0U); }
-inline SpecificInt_match m_One() { return m_SpecificInt(1U); }
+struct Zero_match {
+ bool AllowUndefs;
+
+ explicit Zero_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
+
+ template <typename MatchContext>
+ bool match(const MatchContext &, SDValue N) const {
+ return isZeroOrZeroSplat(N, AllowUndefs);
+ }
+};
+
+struct Ones_match {
+ bool AllowUndefs;
+
+ Ones_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
+
+ template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
+ return isOnesOrOnesSplat(N, AllowUndefs);
+ }
+};
struct AllOnes_match {
+ bool AllowUndefs;
- AllOnes_match() = default;
+ AllOnes_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
- return isAllOnesOrAllOnesSplat(N);
+ return isAllOnesOrAllOnesSplat(N, AllowUndefs);
}
};
-inline AllOnes_match m_AllOnes() { return AllOnes_match(); }
+inline Ones_match m_One(bool AllowUndefs = false) {
+ return Ones_match(AllowUndefs);
+}
+inline Zero_match m_Zero(bool AllowUndefs = false) {
+ return Zero_match(AllowUndefs);
+}
+inline AllOnes_match m_AllOnes(bool AllowUndefs = false) {
+ return AllOnes_match(AllowUndefs);
+}
/// Match true boolean value based on the information provided by
/// TargetLowering.
@@ -1189,7 +1216,7 @@ inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) {
/// Match a negate as a sub(0, v)
template <typename ValTy>
-inline BinaryOpc_match<SpecificInt_match, ValTy> m_Neg(const ValTy &V) {
+inline BinaryOpc_match<Zero_match, ValTy, false> m_Neg(const ValTy &V) {
return m_Sub(m_Zero(), V);
}
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index a3675eecfea3f..6bfc40afeb55e 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1937,6 +1937,10 @@ LLVM_ABI bool isOneOrOneSplat(SDValue V, bool AllowUndefs = false);
/// Does not permit build vector implicit truncation.
LLVM_ABI bool isAllOnesOrAllOnesSplat(SDValue V, bool AllowUndefs = false);
+LLVM_ABI bool isOnesOrOnesSplat(SDValue N, bool AllowUndefs = false);
+
+LLVM_ABI bool isZeroOrZeroSplat(SDValue N, bool AllowUndefs = false);
+
/// Return true if \p V is either a integer or FP constant.
inline bool isIntOrFPConstant(SDValue V) {
return isa<ConstantSDNode>(V) || isa<ConstantFPSDNode>(V);
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 586eb2f3cf45e..db53fb92ae08b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -4281,7 +4281,7 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
return V;
// (A - B) - 1 -> add (xor B, -1), A
- if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One())))
+ if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One(true))))
return DAG.getNode(ISD::ADD, DL, VT, A, DAG.getNOT(DL, B, VT));
// Look for:
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 2a3c8e2b011ad..d6605c3ec77dd 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -12569,6 +12569,18 @@ bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) {
return C && C->isAllOnes() && C->getValueSizeInBits(0) == BitWidth;
}
+bool llvm::isOnesOrOnesSplat(SDValue N, bool AllowUndefs) {
+ N = peekThroughBitcasts(N);
+ ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs);
+ return C && C->getAPIntValue() == 1;
+}
+
+bool llvm::isZeroOrZeroSplat(SDValue N, bool AllowUndefs) {
+ N = peekThroughBitcasts(N);
+ ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs, true);
+ return C && C->isZero();
+}
+
HandleSDNode::~HandleSDNode() {
DropOperands();
}
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 6edbb7b1bae95..1128406236a20 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -57923,22 +57923,20 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
}
}
+ SDValue X, Y;
+
// add(psadbw(X,0),psadbw(Y,0)) -> psadbw(add(X,Y),0)
// iff X and Y won't overflow.
- if (Op0.getOpcode() == X86ISD::PSADBW && Op1.getOpcode() == X86ISD::PSADBW &&
- ISD::isBuildVectorAllZeros(Op0.getOperand(1).getNode()) &&
- ISD::isBuildVectorAllZeros(Op1.getOperand(1).getNode())) {
- if (DAG.willNotOverflowAdd(false, Op0.getOperand(0), Op1.getOperand(0))) {
- MVT OpVT = Op0.getOperand(1).getSimpleValueType();
- SDValue Sum =
- DAG.getNode(ISD::ADD, DL, OpVT, Op0.getOperand(0), Op1.getOperand(0));
- return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum,
- getZeroVector(OpVT, Subtarget, DAG, DL));
- }
+ if (sd_match(Op0, m_c_BinOp(X86ISD::PSADBW, m_Value(X), m_Zero())) &&
+ sd_match(Op1, m_c_BinOp(X86ISD::PSADBW, m_Value(Y), m_Zero())) &&
+ DAG.willNotOverflowAdd(/*IsSigned=*/false, X, Y)) {
+ MVT OpVT = X.getSimpleValueType();
+ SDValue Sum = DAG.getNode(ISD::ADD, DL, OpVT, X, Y);
+ return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum,
+ getZeroVector(OpVT, Subtarget, DAG, DL));
}
if (VT.isVector()) {
- SDValue X, Y;
EVT BoolVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
VT.getVectorElementCount());
|
@llvm/pr-subscribers-llvm-selectiondag Author: woruyu (woruyu) ChangesThis PR resolves #146871 Full diff: https://github.com/llvm/llvm-project/pull/147044.diff 5 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 35322c32a8283..7c5cdbbeb0ca8 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -1100,19 +1100,46 @@ inline SpecificInt_match m_SpecificInt(uint64_t V) {
return SpecificInt_match(APInt(64, V));
}
-inline SpecificInt_match m_Zero() { return m_SpecificInt(0U); }
-inline SpecificInt_match m_One() { return m_SpecificInt(1U); }
+struct Zero_match {
+ bool AllowUndefs;
+
+ explicit Zero_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
+
+ template <typename MatchContext>
+ bool match(const MatchContext &, SDValue N) const {
+ return isZeroOrZeroSplat(N, AllowUndefs);
+ }
+};
+
+struct Ones_match {
+ bool AllowUndefs;
+
+ Ones_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
+
+ template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
+ return isOnesOrOnesSplat(N, AllowUndefs);
+ }
+};
struct AllOnes_match {
+ bool AllowUndefs;
- AllOnes_match() = default;
+ AllOnes_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
- return isAllOnesOrAllOnesSplat(N);
+ return isAllOnesOrAllOnesSplat(N, AllowUndefs);
}
};
-inline AllOnes_match m_AllOnes() { return AllOnes_match(); }
+inline Ones_match m_One(bool AllowUndefs = false) {
+ return Ones_match(AllowUndefs);
+}
+inline Zero_match m_Zero(bool AllowUndefs = false) {
+ return Zero_match(AllowUndefs);
+}
+inline AllOnes_match m_AllOnes(bool AllowUndefs = false) {
+ return AllOnes_match(AllowUndefs);
+}
/// Match true boolean value based on the information provided by
/// TargetLowering.
@@ -1189,7 +1216,7 @@ inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) {
/// Match a negate as a sub(0, v)
template <typename ValTy>
-inline BinaryOpc_match<SpecificInt_match, ValTy> m_Neg(const ValTy &V) {
+inline BinaryOpc_match<Zero_match, ValTy, false> m_Neg(const ValTy &V) {
return m_Sub(m_Zero(), V);
}
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index a3675eecfea3f..6bfc40afeb55e 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1937,6 +1937,10 @@ LLVM_ABI bool isOneOrOneSplat(SDValue V, bool AllowUndefs = false);
/// Does not permit build vector implicit truncation.
LLVM_ABI bool isAllOnesOrAllOnesSplat(SDValue V, bool AllowUndefs = false);
+LLVM_ABI bool isOnesOrOnesSplat(SDValue N, bool AllowUndefs = false);
+
+LLVM_ABI bool isZeroOrZeroSplat(SDValue N, bool AllowUndefs = false);
+
/// Return true if \p V is either a integer or FP constant.
inline bool isIntOrFPConstant(SDValue V) {
return isa<ConstantSDNode>(V) || isa<ConstantFPSDNode>(V);
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 586eb2f3cf45e..db53fb92ae08b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -4281,7 +4281,7 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
return V;
// (A - B) - 1 -> add (xor B, -1), A
- if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One())))
+ if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One(true))))
return DAG.getNode(ISD::ADD, DL, VT, A, DAG.getNOT(DL, B, VT));
// Look for:
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 2a3c8e2b011ad..d6605c3ec77dd 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -12569,6 +12569,18 @@ bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) {
return C && C->isAllOnes() && C->getValueSizeInBits(0) == BitWidth;
}
+bool llvm::isOnesOrOnesSplat(SDValue N, bool AllowUndefs) {
+ N = peekThroughBitcasts(N);
+ ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs);
+ return C && C->getAPIntValue() == 1;
+}
+
+bool llvm::isZeroOrZeroSplat(SDValue N, bool AllowUndefs) {
+ N = peekThroughBitcasts(N);
+ ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs, true);
+ return C && C->isZero();
+}
+
HandleSDNode::~HandleSDNode() {
DropOperands();
}
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 6edbb7b1bae95..1128406236a20 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -57923,22 +57923,20 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
}
}
+ SDValue X, Y;
+
// add(psadbw(X,0),psadbw(Y,0)) -> psadbw(add(X,Y),0)
// iff X and Y won't overflow.
- if (Op0.getOpcode() == X86ISD::PSADBW && Op1.getOpcode() == X86ISD::PSADBW &&
- ISD::isBuildVectorAllZeros(Op0.getOperand(1).getNode()) &&
- ISD::isBuildVectorAllZeros(Op1.getOperand(1).getNode())) {
- if (DAG.willNotOverflowAdd(false, Op0.getOperand(0), Op1.getOperand(0))) {
- MVT OpVT = Op0.getOperand(1).getSimpleValueType();
- SDValue Sum =
- DAG.getNode(ISD::ADD, DL, OpVT, Op0.getOperand(0), Op1.getOperand(0));
- return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum,
- getZeroVector(OpVT, Subtarget, DAG, DL));
- }
+ if (sd_match(Op0, m_c_BinOp(X86ISD::PSADBW, m_Value(X), m_Zero())) &&
+ sd_match(Op1, m_c_BinOp(X86ISD::PSADBW, m_Value(Y), m_Zero())) &&
+ DAG.willNotOverflowAdd(/*IsSigned=*/false, X, Y)) {
+ MVT OpVT = X.getSimpleValueType();
+ SDValue Sum = DAG.getNode(ISD::ADD, DL, OpVT, X, Y);
+ return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum,
+ getZeroVector(OpVT, Subtarget, DAG, DL));
}
if (VT.isVector()) {
- SDValue X, Y;
EVT BoolVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
VT.getVectorElementCount());
|
e71c622
to
229ec28
Compare
Hello, any suggestion for this PR, thank you! @RKSimon @mshockwave |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably should get tests in unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
That makes sense,I think adding test for m_Zero/m_One/m_AllOnes default behavior and supportting peekthrough bitcast is a better choice, I will add it! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM - cheers
…PatternMatchTest (#147443) ### Summary This PR remove the extra llvm::SDPatternMatch prefix in llvm/llvm-project#147044
@woruyu do you have the |
I originally added the SDPattern for add(psadbw(X,0), psadbw(Y,0)) -> psadbw(add(X,Y),0) in commit e71c622, but removed it during a force-push in 229ec28 based on review feedback. It’s not included in the final PR. |
Sorry for the confusion - I meant that I asked you in the review to remove it from #147044 but then create a separate PR for it as a followup |
Thanks a lot for the clarification — understand, I’d be happy to submit it as a follow-up patch. |
Summary
This PR resolves #146871
This PR resolves #140745
Refactor m_Zero/m_One/m_AllOnes all use struct template function to match and AllowUndefs=false as default.