Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys

import scala.collection.immutable.HashSet

import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries}
Expand All @@ -40,6 +42,9 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
Batch("Aggregate", FixedPoint(100),
ReplaceDistinctWithAggregate,
RemoveLiteralFromGroupExpressions) ::
// run only once, and before other condition push-down optimizations
Batch("Join Skew optimization", FixedPoint(1),
JoinSkewOptimizer) ::
Batch("Operator Optimizations", FixedPoint(100),
// Operator push down
SetOperationPushDown,
Expand Down Expand Up @@ -976,3 +981,29 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
a.copy(groupingExpressions = newGrouping)
}
}

/**
* For an inner join - remove rows with null keys on both sides
*/
object JoinSkewOptimizer extends Rule[LogicalPlan] with PredicateHelper {
private def hasNullableKeys(leftKeys: Seq[Expression], rightKeys: Seq[Expression]) = {
leftKeys.exists(_.nullable) || rightKeys.exists(_.nullable)
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case join @ Join(left, right, joinType, originalJoinCondition) =>
join match {
// add a non-null join-key filter on both sides of Inner or LeftSemi join
case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _)
if Seq(Inner, LeftSemi).contains(joinType) && hasNullableKeys(leftKeys, rightKeys) =>
val nullFilters = (leftKeys ++ rightKeys)
.filter(_.nullable)
.map(IsNotNull)
val newJoinCondition = (originalJoinCondition ++ nullFilters).reduceLeftOption(And)

Join(left, right, joinType, newJoinCondition)

case _ => join
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class FilterPushdownSuite extends PlanTest {
val batches =
Batch("Subqueries", Once,
EliminateSubQueries) ::
Batch("Join Skew optimization", FixedPoint(1),
JoinSkewOptimizer) ::
Batch("Filter Pushdown", Once,
SamplePushDown,
CombineFilters,
Expand Down Expand Up @@ -279,17 +281,34 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("joins: push down left semi join") {
test("joins: push down left semi join, do NOT add null skew filter for <=>") {
val x = testRelation.subquery('x)
val y = testRelation1.subquery('y)

val originalQuery = {
x.join(y, LeftSemi, Option("x.a".attr === "y.d".attr && "x.b".attr >= 1 && "y.d".attr >= 2))
x.join(y, LeftSemi, Option("x.a".attr <=> "y.d".attr && "x.b".attr >= 1 && "y.d".attr >= 2))
}

val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.where('b >= 1)
val right = testRelation1.where('d >= 2)
val correctAnswer =
left.join(right, LeftSemi, Option("a".attr <=> "d".attr)).analyze

comparePlans(optimized, correctAnswer)
}

test("joins: push down left semi join, and add null filter") {
val x = testRelation.subquery('x)
val y = testRelation1.subquery('y)

val originalQuery = {
x.join(y, LeftSemi, Option("x.a".attr === "y.d".attr && "x.b".attr >= 1 && "y.d".attr >= 2))
}

val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.where('b >= 1 && 'a.isNotNull)
val right = testRelation1.where('d >= 2 && 'd.isNotNull)
val correctAnswer =
left.join(right, LeftSemi, Option("a".attr === "d".attr)).analyze

Expand Down Expand Up @@ -474,7 +493,7 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("joins: can't push down") {
test("joins: can't push down query filters, but inner join can be optimized for null skew") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)

Expand All @@ -483,7 +502,11 @@ class FilterPushdownSuite extends PlanTest {
}
val optimized = Optimize.execute(originalQuery.analyze)

comparePlans(analysis.EliminateSubQueries(originalQuery.analyze), optimized)
val expectedQueryWithNullFilters = {
x.where('b.isNotNull)
.join(y.where('b.isNotNull), condition = Some("x.b".attr === "y.b".attr))
}
comparePlans(analysis.EliminateSubQueries(expectedQueryWithNullFilters.analyze), optimized)
}

test("joins: conjunctive predicates") {
Expand Down