@@ -986,29 +986,22 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
986986 * For an inner join - remove rows with null keys on both sides
987987 */
988988object JoinSkewOptimizer extends Rule [LogicalPlan ] with PredicateHelper {
989- /**
990- * Adds a null filter on given columns, if any
991- */
992- def addNullFilter (columns : Seq [Expression ], expr : LogicalPlan ): LogicalPlan = {
993- columns.map(IsNotNull )
994- .reduceLeftOption(And )
995- .map(Filter (_, expr))
996- .getOrElse(expr)
997- }
998-
/**
 * Reports whether any join key, on either side of the join, is nullable.
 *
 * @param leftKeys  key expressions belonging to the left side of the join
 * @param rightKeys key expressions belonging to the right side of the join
 * @return true when at least one expression in either sequence has `nullable` set
 */
private def hasNullableKeys(leftKeys: Seq[Expression], rightKeys: Seq[Expression]) =
  // Left keys are inspected first; the right keys are only consulted if none of them is nullable.
  Seq(leftKeys, rightKeys).exists(_.exists(_.nullable))
1002992
1003993 def apply (plan : LogicalPlan ): LogicalPlan = plan transform {
1004994 case join @ Join (left, right, joinType, originalJoinCondition) =>
1005995 join match {
996+ // fold IsNotNull checks on the nullable join keys into the join condition of an Inner or LeftSemi join
1006997 case ExtractEquiJoinKeys (_, leftKeys, rightKeys, _, _, _)
1007- if hasNullableKeys(leftKeys, rightKeys) && Seq (Inner , LeftSemi ).contains(joinType) =>
1008- // add a non-null join-key filter on both sides of join
1009- val newLeft = addNullFilter(leftKeys.filter(_.nullable), left)
1010- val newRight = addNullFilter(rightKeys.filter(_.nullable), right)
1011- Join (newLeft, newRight, joinType, originalJoinCondition)
998+ if Seq (Inner , LeftSemi ).contains(joinType) && hasNullableKeys(leftKeys, rightKeys) =>
999+ val nullFilters = (leftKeys ++ rightKeys)
1000+ .filter(_.nullable)
1001+ .map(IsNotNull )
1002+ val newJoinCondition = (originalJoinCondition ++ nullFilters).reduceLeftOption(And )
1003+
1004+ Join (left, right, joinType, newJoinCondition)
10121005
10131006 case _ => join
10141007 }
0 commit comments