|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql.catalyst.optimizer |
19 | 19 |
|
| 20 | +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys |
| 21 | + |
20 | 22 | import scala.collection.immutable.HashSet |
21 | 23 |
|
22 | 24 | import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries} |
@@ -987,41 +989,28 @@ object JoinSkewOptimizer extends Rule[LogicalPlan] with PredicateHelper { |
987 | 989 | /** |
988 | 990 | * Adds a null filter on given columns, if any |
989 | 991 | */ |
990 | | - def addNullFilter(columns: AttributeSet, expr: LogicalPlan): LogicalPlan = { |
991 | | - columns.map(IsNotNull(_)) |
| 992 | + def addNullFilter(columns: Seq[Expression], expr: LogicalPlan): LogicalPlan = { |
| 993 | + columns.map(IsNotNull) |
992 | 994 | .reduceLeftOption(And) |
993 | 995 | .map(Filter(_, expr)) |
994 | 996 | .getOrElse(expr) |
995 | 997 | } |
996 | 998 |
|
997 | | - def apply(plan: LogicalPlan): LogicalPlan = plan transform { |
998 | | - case f@Join(left, right, joinType, joinCondition) => |
999 | | - // get "real" join conditions, which refer both left and right |
1000 | | - val joinConditionsOnBothRelations = joinCondition |
1001 | | - .map(splitConjunctivePredicates).getOrElse(Nil) |
1002 | | - .filter(_.isInstanceOf[EqualTo]) |
1003 | | - .filter(cond => !canEvaluate(cond, left) && !canEvaluate(cond, right)) |
1004 | | - |
1005 | | - def nullableJoinKeys(leftOrRight: LogicalPlan) = { |
1006 | | - val joinKeys = leftOrRight.outputSet.intersect( |
1007 | | - joinConditionsOnBothRelations |
1008 | | - .map(_.references) |
1009 | | - .reduceLeftOption(_ ++ _).getOrElse(AttributeSet.empty) |
1010 | | - ) |
1011 | | - joinKeys.filter(_.nullable) |
1012 | | - } |
1013 | | - |
1014 | | - def hasNullableKeys = Seq(left, right).exists(nullableJoinKeys(_).nonEmpty) |
1015 | | - |
1016 | | - joinType match { |
1017 | | - case _ @ (Inner | LeftSemi) if hasNullableKeys => |
1018 | | - // add a non-null keys filter for both sides sub queries |
1019 | | - val newLeft = addNullFilter(nullableJoinKeys(left), left) |
1020 | | - val newRight = addNullFilter(nullableJoinKeys(right), right) |
1021 | | - |
1022 | | - Join(newLeft, newRight, joinType, joinCondition) |
| 999 | + private def hasNullableKeys(leftKeys: Seq[Expression], rightKeys: Seq[Expression]) = { |
| 1000 | + leftKeys.exists(_.nullable) || rightKeys.exists(_.nullable) |
| 1001 | + } |
1023 | 1002 |
|
1024 | | - case _ => f |
| 1003 | + def apply(plan: LogicalPlan): LogicalPlan = plan transform { |
| 1004 | + case join @ Join(left, right, joinType, originalJoinCondition) => |
| 1005 | + join match { |
| 1006 | + case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _) |
| 1007 | + if hasNullableKeys(leftKeys, rightKeys) && Seq(Inner, LeftSemi).contains(joinType) => |
| 1008 | + // add a non-null join-key filter on both sides of join |
| 1009 | + val newLeft = addNullFilter(leftKeys.filter(_.nullable), left) |
| 1010 | + val newRight = addNullFilter(rightKeys.filter(_.nullable), right) |
| 1011 | + Join(newLeft, newRight, joinType, originalJoinCondition) |
| 1012 | + |
| 1013 | + case _ => join |
1025 | 1014 | } |
1026 | 1015 | } |
1027 | 1016 | } |
0 commit comments