Skip to content

Commit cd8ca34

Browse files
author
vidmantas zemleris
committed
Refactor to add null filter to joinConditions
It will be pushed down by other rules, such as PushPredicateThroughJoin.
1 parent eaa12bc commit cd8ca34

File tree

2 files changed

+10
-17
lines changed

2 files changed

+10
-17
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -986,29 +986,22 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
986986
* For an inner join - remove rows with null keys on both sides
987987
*/
988988
object JoinSkewOptimizer extends Rule[LogicalPlan] with PredicateHelper {
989-
/**
990-
* Adds a null filter on given columns, if any
991-
*/
992-
def addNullFilter(columns: Seq[Expression], expr: LogicalPlan): LogicalPlan = {
993-
columns.map(IsNotNull)
994-
.reduceLeftOption(And)
995-
.map(Filter(_, expr))
996-
.getOrElse(expr)
997-
}
998-
999989
private def hasNullableKeys(leftKeys: Seq[Expression], rightKeys: Seq[Expression]) = {
1000990
leftKeys.exists(_.nullable) || rightKeys.exists(_.nullable)
1001991
}
1002992

1003993
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
1004994
case join @ Join(left, right, joinType, originalJoinCondition) =>
1005995
join match {
996+
// add a non-null join-key filter on both sides of Inner or LeftSemi join
1006997
case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _)
1007-
if hasNullableKeys(leftKeys, rightKeys) && Seq(Inner, LeftSemi).contains(joinType) =>
1008-
// add a non-null join-key filter on both sides of join
1009-
val newLeft = addNullFilter(leftKeys.filter(_.nullable), left)
1010-
val newRight = addNullFilter(rightKeys.filter(_.nullable), right)
1011-
Join(newLeft, newRight, joinType, originalJoinCondition)
998+
if Seq(Inner, LeftSemi).contains(joinType) && hasNullableKeys(leftKeys, rightKeys) =>
999+
val nullFilters = (leftKeys ++ rightKeys)
1000+
.filter(_.nullable)
1001+
.map(IsNotNull)
1002+
val newJoinCondition = (originalJoinCondition ++ nullFilters).reduceLeftOption(And)
1003+
1004+
Join(left, right, joinType, newJoinCondition)
10121005

10131006
case _ => join
10141007
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,8 +307,8 @@ class FilterPushdownSuite extends PlanTest {
307307
}
308308

309309
val optimized = Optimize.execute(originalQuery.analyze)
310-
val left = testRelation.where('a.isNotNull).where('b >= 1)
311-
val right = testRelation1.where('d.isNotNull).where('d >= 2)
310+
val left = testRelation.where('b >= 1 && 'a.isNotNull)
311+
val right = testRelation1.where('d >= 2 && 'd.isNotNull)
312312
val correctAnswer =
313313
left.join(right, LeftSemi, Option("a".attr === "d".attr)).analyze
314314

0 commit comments

Comments (0)