Skip to content

Commit d2a99d1

Browse files
wangyumTongWei1105
authored andcommitted
[CARMEL-7452][SPARK-42500][SQL] ConstantPropagation support more case (apache#256)
### What changes were proposed in this pull request? This PR enhances ConstantPropagation to support more cases. Propagated through other binary comparisons. Propagated across equality comparisons. This can be further optimized to false. ### Why are the changes needed? Improve query performance. [Denodo](https://community.denodo.com/docs/html/browse/latest/en/vdp/administration/optimizing_queries/automatic_simplification_of_queries/removing_redundant_branches_of_queries_partitioned_unions) also has a similar optimization. For example: ``` CREATE TABLE t1(a int, b int) using parquet; CREATE TABLE t2(x int, y int) using parquet; CREATE TEMP VIEW v1 AS SELECT * FROM t1 JOIN t2 WHERE a = x AND a = 0 UNION ALL SELECT * FROM t1 JOIN t2 WHERE a = x AND (a IS NULL OR a <> 0); SELECT * FROM v1 WHERE x > 1; ``` Before this PR: ``` == Optimized Logical Plan == Union false, false :- Project [a#0 AS a#12, b#1 AS b#13, x#2 AS x#14, y#3 AS y#15] : +- Join Inner : :- Filter (isnotnull(a#0) AND (a#0 = 0)) : : +- Relation spark_catalog.default.t1[a#0,b#1] parquet : +- Filter (isnotnull(x#2) AND ((0 = x#2) AND (x#2 > 1))) : +- Relation spark_catalog.default.t2[x#2,y#3] parquet +- Join Inner, (a#16 = x#18) :- Filter ((isnull(a#16) OR NOT (a#16 = 0)) AND ((a#16 > 1) AND isnotnull(a#16))) : +- Relation spark_catalog.default.t1[a#16,b#17] parquet +- Filter ((isnotnull(x#18) AND (x#18 > 1)) AND (isnull(x#18) OR NOT (x#18 = 0))) +- Relation spark_catalog.default.t2[x#18,y#19] parquet ``` After this PR: ``` == Optimized Logical Plan == Join Inner, (a#16 = x#18) :- Filter ((isnull(a#16) OR NOT (a#16 = 0)) AND ((a#16 > 1) AND isnotnull(a#16))) : +- Relation spark_catalog.default.t1[a#16,b#17] parquet +- Filter ((isnotnull(x#18) AND (x#18 > 1)) AND (isnull(x#18) OR NOT (x#18 = 0))) +- Relation spark_catalog.default.t2[x#18,y#19] parquet ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit test. Closes apache#42038 from TongWei1105/SPARK-42500. Authored-by: TongWei1105 <[email protected]> Signed-off-by: Yuming Wang <[email protected]> (cherry picked from commit 74ae1e3) Co-authored-by: TongWei1105 <[email protected]>
1 parent 17c8af4 commit d2a99d1

File tree

2 files changed

+47
-22
lines changed

2 files changed

+47
-22
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,6 @@ object ConstantPropagation extends Rule[LogicalPlan] {
122122
}
123123
}
124124

125-
type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)]
126-
127125
/**
128126
* Traverse a condition as a tree and replace attributes with constant values.
129127
* - On matching [[And]], recursively traverse each children and get propagated mappings.
@@ -140,23 +138,23 @@ object ConstantPropagation extends Rule[LogicalPlan] {
140138
* resulted false
141139
* @return A tuple including:
142140
* 1. Option[Expression]: optional changed condition after traversal
143-
* 2. EqualityPredicates: propagated mapping of attribute => constant
141+
* 2. AttributeMap: propagated mapping of attribute => constant
144142
*/
145143
private def traverse(condition: Expression, replaceChildren: Boolean, nullIsFalse: Boolean)
146-
: (Option[Expression], EqualityPredicates) =
144+
: (Option[Expression], AttributeMap[(Literal, BinaryComparison)]) =
147145
condition match {
148146
case e @ EqualTo(left: AttributeReference, right: Literal)
149147
if safeToReplace(left, nullIsFalse) =>
150-
(None, Seq(((left, right), e)))
148+
(None, AttributeMap(Map(left -> (right, e))))
151149
case e @ EqualTo(left: Literal, right: AttributeReference)
152150
if safeToReplace(right, nullIsFalse) =>
153-
(None, Seq(((right, left), e)))
151+
(None, AttributeMap(Map(right -> (left, e))))
154152
case e @ EqualNullSafe(left: AttributeReference, right: Literal)
155153
if safeToReplace(left, nullIsFalse) =>
156-
(None, Seq(((left, right), e)))
154+
(None, AttributeMap(Map(left -> (right, e))))
157155
case e @ EqualNullSafe(left: Literal, right: AttributeReference)
158156
if safeToReplace(right, nullIsFalse) =>
159-
(None, Seq(((right, left), e)))
157+
(None, AttributeMap(Map(right -> (left, e))))
160158
case a: And =>
161159
val (newLeft, equalityPredicatesLeft) =
162160
traverse(a.left, replaceChildren = false, nullIsFalse)
@@ -183,12 +181,12 @@ object ConstantPropagation extends Rule[LogicalPlan] {
183181
} else {
184182
None
185183
}
186-
(newSelf, Seq.empty)
184+
(newSelf, AttributeMap.empty)
187185
case n: Not =>
188186
// Ignore the EqualityPredicates from children since they are only propagated through And.
189187
val (newChild, _) = traverse(n.child, replaceChildren = true, nullIsFalse = false)
190-
(newChild.map(Not), Seq.empty)
191-
case _ => (None, Seq.empty)
188+
(newChild.map(Not), AttributeMap.empty)
189+
case _ => (None, AttributeMap.empty)
192190
}
193191

194192
// We need to take into account if an attribute is nullable and the context of the conjunctive
@@ -199,16 +197,15 @@ object ConstantPropagation extends Rule[LogicalPlan] {
199197
private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean) =
200198
!ar.nullable || nullIsFalse
201199

202-
private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates)
203-
: Expression = {
204-
val constantsMap = AttributeMap(equalityPredicates.map(_._1))
205-
val predicates = equalityPredicates.map(_._2).toSet
206-
def replaceConstants0(expression: Expression) = expression transform {
207-
case a: AttributeReference => constantsMap.getOrElse(a, a)
208-
}
200+
private def replaceConstants(
201+
condition: Expression,
202+
equalityPredicates: AttributeMap[(Literal, BinaryComparison)]): Expression = {
203+
val constantsMap = AttributeMap(equalityPredicates.map { case (attr, (lit, _)) => attr -> lit })
204+
val predicates = equalityPredicates.values.map(_._2).toSet
209205
condition transform {
210-
case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e)
211-
case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e)
206+
case b: BinaryComparison if !predicates.contains(b) => b transform {
207+
case a: AttributeReference => constantsMap.getOrElse(a, a)
208+
}
212209
}
213210
}
214211
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,9 @@ class ConstantPropagationSuite extends PlanTest {
159159
columnA === Literal(1) && columnA === Literal(2) && columnB === Add(columnA, Literal(3)))
160160

161161
val correctAnswer = testRelation
162-
.select(columnA)
163-
.where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)).analyze
162+
.select(columnA, columnB)
163+
.where(Literal.FalseLiteral)
164+
.select(columnA).analyze
164165

165166
comparePlans(Optimize.execute(query.analyze), correctAnswer)
166167
}
@@ -186,4 +187,31 @@ class ConstantPropagationSuite extends PlanTest {
186187
.analyze
187188
comparePlans(Optimize.execute(query2), correctAnswer2)
188189
}
190+
191+
test("SPARK-42500: ConstantPropagation supports more cases") {
192+
comparePlans(
193+
Optimize.execute(testRelation.where(columnA === 1 && columnB > columnA + 2).analyze),
194+
testRelation.where(columnA === 1 && columnB > 3).analyze)
195+
196+
comparePlans(
197+
Optimize.execute(testRelation.where(columnA === 1 && columnA === 2).analyze),
198+
testRelation.where(Literal.FalseLiteral).analyze)
199+
200+
comparePlans(
201+
Optimize.execute(testRelation.where(columnA === 1 && columnA === columnA + 2).analyze),
202+
testRelation.where(Literal.FalseLiteral).analyze)
203+
204+
comparePlans(
205+
Optimize.execute(
206+
testRelation.where((columnA === 1 || columnB === 2) && columnB === 1).analyze),
207+
testRelation.where(columnA === 1 && columnB === 1).analyze)
208+
209+
comparePlans(
210+
Optimize.execute(testRelation.where(columnA === 1 && columnA === 1).analyze),
211+
testRelation.where(columnA === 1).analyze)
212+
213+
comparePlans(
214+
Optimize.execute(testRelation.where(Not(columnA === 1 && columnA === columnA + 2)).analyze),
215+
testRelation.where(Not(columnA === 1) || Not(columnA === columnA + 2)).analyze)
216+
}
189217
}

0 commit comments

Comments
 (0)