diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0b1c74293bb8..92da389e20e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys + import scala.collection.immutable.HashSet import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries} @@ -40,6 +42,9 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { Batch("Aggregate", FixedPoint(100), ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: + // run only once, and before other condition push-down optimizations + Batch("Join Skew optimization", FixedPoint(1), + JoinSkewOptimizer) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down SetOperationPushDown, @@ -976,3 +981,29 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { a.copy(groupingExpressions = newGrouping) } } + +/** + * For an inner join - remove rows with null keys on both sides + */ +object JoinSkewOptimizer extends Rule[LogicalPlan] with PredicateHelper { + private def hasNullableKeys(leftKeys: Seq[Expression], rightKeys: Seq[Expression]) = { + leftKeys.exists(_.nullable) || rightKeys.exists(_.nullable) + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case join @ Join(left, right, joinType, originalJoinCondition) => + join match { + // add a non-null join-key filter on both sides of Inner or LeftSemi join + case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _) + if Seq(Inner, LeftSemi).contains(joinType) && hasNullableKeys(leftKeys, rightKeys) => + val nullFilters = (leftKeys ++ rightKeys) + .filter(_.nullable) + .map(IsNotNull) + val newJoinCondition = (originalJoinCondition ++ nullFilters).reduceLeftOption(And) + + Join(left, right, joinType, newJoinCondition) + + case _ => join + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index fba4c5ca77d6..b1e6851f17ee 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -33,6 +33,8 @@ class FilterPushdownSuite extends PlanTest { val batches = Batch("Subqueries", Once, EliminateSubQueries) :: + Batch("Join Skew optimization", FixedPoint(1), + JoinSkewOptimizer) :: Batch("Filter Pushdown", Once, SamplePushDown, CombineFilters, @@ -279,17 +281,34 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("joins: push down left semi join") { + test("joins: push down left semi join, do NOT add null skew filter for <=>") { val x = testRelation.subquery('x) val y = testRelation1.subquery('y) val originalQuery = { - x.join(y, LeftSemi, Option("x.a".attr === "y.d".attr && "x.b".attr >= 1 && "y.d".attr >= 2)) + x.join(y, LeftSemi, Option("x.a".attr <=> "y.d".attr && "x.b".attr >= 1 && "y.d".attr >= 2)) } val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b >= 1) val right = testRelation1.where('d >= 2) + val correctAnswer = + left.join(right, LeftSemi, Option("a".attr <=> "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: push down left semi join, and add null filter") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = { + x.join(y, LeftSemi, Option("x.a".attr === "y.d".attr && "x.b".attr >= 1 && "y.d".attr >= 2)) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where('b >= 1 && 'a.isNotNull) + val right = testRelation1.where('d >= 2 && 'd.isNotNull) val correctAnswer = left.join(right, LeftSemi, Option("a".attr === "d".attr)).analyze @@ -474,7 +493,7 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("joins: can't push down") { + test("joins: can't push down query filters, but inner join can be optimized for null skew") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) @@ -483,7 +502,11 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize.execute(originalQuery.analyze) - comparePlans(analysis.EliminateSubQueries(originalQuery.analyze), optimized) + val expectedQueryWithNullFilters = { + x.where('b.isNotNull) + .join(y.where('b.isNotNull), condition = Some("x.b".attr === "y.b".attr)) + } + comparePlans(analysis.EliminateSubQueries(expectedQueryWithNullFilters.analyze), optimized) } test("joins: conjunctive predicates") {