From 7d87244a2bce20d891135ba64ee408bb5d23c6cd Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 16 Feb 2016 21:36:29 -0800 Subject: [PATCH] push filter throughout outer join --- .../sql/catalyst/optimizer/Optimizer.scala | 32 +++++++++++- .../optimizer/FilterPushdownSuite.scala | 50 +++++++++---------- 2 files changed, 55 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 567010f23fc8a..ace2cc67108ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -947,6 +947,32 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { (leftEvaluateCondition, rightEvaluateCondition, commonCondition) } + /** + * Returns whether the expression returns null or false when all inputs are nulls. + */ + private def canFilterOutNull(e: Expression): Boolean = { + val attributes = e.references.toSeq + val emptyRow = new GenericInternalRow(attributes.length) + val v = BindReferences.bindReference(e, attributes).eval(emptyRow) + v == null || v == false + } + + /** + * Returns whether the join could be inner join or not. + * + * If a left/right outer join followed by a filter with a condition that could filter out rows + * with null from right/left, the left/right outer join has the same result as inner join, + * should be rewritten as inner join. + */ + private def isInnerJoin( + joinType: JoinType, + leftCond: Seq[Expression], + rightCond: Seq[Expression]): Boolean = { + joinType == Inner || + joinType == RightOuter && leftCond.exists(canFilterOutNull) || + joinType == LeftOuter && rightCond.exists(canFilterOutNull) + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { // push the where condition down into join filter case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) => @@ -954,7 +980,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { split(splitConjunctivePredicates(filterCondition), left, right) joinType match { - case Inner => + case _ if isInnerJoin(joinType, leftFilterConditions, rightFilterConditions) => // push down the single side `where` condition into respective sides val newLeft = leftFilterConditions. reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) @@ -963,6 +989,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { val newJoinCond = (commonFilterCondition ++ joinCondition).reduceLeftOption(And) Join(newLeft, newRight, Inner, newJoinCond) + case RightOuter => // push down the right side only `where` condition val newLeft = left @@ -973,6 +1000,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { (leftFilterConditions ++ commonFilterCondition). reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) + case _ @ (LeftOuter | LeftSemi) => // push down the left side only `where` condition val newLeft = leftFilterConditions. @@ -1080,7 +1108,7 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { * [[org.apache.spark.sql.catalyst.analysis.DecimalPrecision]]. */ object DecimalAggregates extends Rule[LogicalPlan] { - import Decimal.MAX_LONG_DIGITS + import org.apache.spark.sql.types.Decimal.MAX_LONG_DIGITS /** Maximum number of decimal digits representable precisely in a Double */ private val MAX_DOUBLE_DIGITS = 15 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index b49ca928b6292..c5255c2129f87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{LeftOuter, LeftSemi, PlanTest, RightOuter} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types.IntegerType @@ -300,13 +300,13 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { x.join(y, LeftOuter) - .where("x.b".attr === 1 && "y.b".attr === 2) + .where("x.b".attr === 1 && "y.b".attr.isNull) } val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 1) val correctAnswer = - left.join(y, LeftOuter).where("y.b".attr === 2).analyze + left.join(y, LeftOuter).where("y.b".attr.isNull).analyze comparePlans(optimized, correctAnswer) } @@ -317,13 +317,13 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { x.join(y, RightOuter) - .where("x.b".attr === 1 && "y.b".attr === 2) + .where("x.b".attr.isNull && "y.b".attr === 2) } val optimized = Optimize.execute(originalQuery.analyze) val right = testRelation.where('b === 2).subquery('d) val correctAnswer = - x.join(right, RightOuter).where("x.b".attr === 1).analyze + x.join(right, RightOuter).where("x.b".attr.isNull).analyze comparePlans(optimized, correctAnswer) } @@ -334,13 +334,13 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { x.join(y, LeftOuter, Some("x.b".attr === 1)) - .where("x.b".attr === 2 && "y.b".attr === 2) + .where("x.b".attr === 2 && "y.b".attr.isNull) } val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 2).subquery('d) val correctAnswer = - left.join(y, LeftOuter, Some("d.b".attr === 1)).where("y.b".attr === 2).analyze + left.join(y, LeftOuter, Some("d.b".attr === 1)).where("y.b".attr.isNull).analyze comparePlans(optimized, correctAnswer) } @@ -351,13 +351,13 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { x.join(y, RightOuter, Some("y.b".attr === 1)) - .where("x.b".attr === 2 && "y.b".attr === 2) + .where("x.b".attr.isNull && "y.b".attr === 2) } val optimized = Optimize.execute(originalQuery.analyze) val right = testRelation.where('b === 2).subquery('d) val correctAnswer = - x.join(right, RightOuter, Some("d.b".attr === 1)).where("x.b".attr === 2).analyze + x.join(right, RightOuter, Some("d.b".attr === 1)).where("x.b".attr.isNull).analyze comparePlans(optimized, correctAnswer) } @@ -368,14 +368,14 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { x.join(y, LeftOuter, Some("y.b".attr === 1)) - .where("x.b".attr === 2 && "y.b".attr === 2) + .where("x.b".attr === 2 && "y.b".attr.isNull) } val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 2).subquery('l) val right = testRelation.where('b === 1).subquery('r) val correctAnswer = - left.join(right, LeftOuter).where("r.b".attr === 2).analyze + left.join(right, LeftOuter).where("r.b".attr.isNull).analyze comparePlans(optimized, correctAnswer) } @@ -386,13 +386,13 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { x.join(y, RightOuter, Some("y.b".attr === 1)) - .where("x.b".attr === 2 && "y.b".attr === 2) + .where("x.b".attr.isNull && "y.b".attr === 2) } val optimized = Optimize.execute(originalQuery.analyze) val right = testRelation.where('b === 2).subquery('r) val correctAnswer = - x.join(right, RightOuter, Some("r.b".attr === 1)).where("x.b".attr === 2).analyze + x.join(right, RightOuter, Some("r.b".attr === 1)).where("x.b".attr.isNull).analyze comparePlans(optimized, correctAnswer) } @@ -403,14 +403,14 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { x.join(y, LeftOuter, Some("y.b".attr === 1)) - .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr) + .where("x.b".attr === 2 && "y.b".attr.isNull && "x.c".attr === "y.c".attr) } val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 2).subquery('l) val right = testRelation.where('b === 1).subquery('r) val correctAnswer = - left.join(right, LeftOuter).where("r.b".attr === 2 && "l.c".attr === "r.c".attr).analyze + left.join(right, LeftOuter).where("r.b".attr.isNull && "l.c".attr === "r.c".attr).analyze comparePlans(optimized, correctAnswer) } @@ -421,7 +421,7 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { x.join(y, RightOuter, Some("y.b".attr === 1)) - .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr) + .where("x.b".attr.isNull && "y.b".attr === 2 && "x.c".attr === "y.c".attr) } val optimized = Optimize.execute(originalQuery.analyze) @@ -429,7 +429,7 @@ class FilterPushdownSuite extends PlanTest { val right = testRelation.where('b === 2).subquery('r) val correctAnswer = left.join(right, RightOuter, Some("r.b".attr === 1)). - where("l.b".attr === 2 && "l.c".attr === "r.c".attr).analyze + where("l.b".attr.isNull && "l.c".attr === "r.c".attr).analyze comparePlans(optimized, correctAnswer) } @@ -439,16 +439,16 @@ class FilterPushdownSuite extends PlanTest { val y = testRelation.subquery('y) val originalQuery = { - x.join(y, LeftOuter, Some("y.b".attr === 1 && "x.a".attr === 3)) + x.join(y, LeftOuter, Some("y.a".attr === 1 && "x.a".attr === 3)) .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr) } val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 2).subquery('l) - val right = testRelation.where('b === 1).subquery('r) + val right = testRelation.where('b === 2).subquery('r) val correctAnswer = - left.join(right, LeftOuter, Some("l.a".attr===3)). - where("r.b".attr === 2 && "l.c".attr === "r.c".attr).analyze + left.join(right, Inner, + Some("l.c".attr === "r.c".attr && ("r.a".attr === 1 && "l.a".attr === 3))).analyze comparePlans(optimized, correctAnswer) } @@ -458,16 +458,16 @@ class FilterPushdownSuite extends PlanTest { val y = testRelation.subquery('y) val originalQuery = { - x.join(y, RightOuter, Some("y.b".attr === 1 && "x.a".attr === 3)) + x.join(y, RightOuter, Some("y.a".attr === 1 && "x.a".attr === 3)) .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr) } val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.where('a === 3).subquery('l) + val left = testRelation.where('b === 2).subquery('l) val right = testRelation.where('b === 2).subquery('r) val correctAnswer = - left.join(right, RightOuter, Some("r.b".attr === 1)). - where("l.b".attr === 2 && "l.c".attr === "r.c".attr).analyze + left.join(right, Inner, + Some("l.c".attr === "r.c".attr && ("r.a".attr === 1 && "l.a".attr === 3))).analyze comparePlans(optimized, correctAnswer) }