From ea0edd46e080cd0a1c6a1d41374563c149a030f7 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Tue, 8 Dec 2015 15:50:25 -0800 Subject: [PATCH 1/3] [SPARK-9372] [SQL] For joins, insert IS NOT NULL filters to children. Some join types and conditions imply that the join keys cannot be NULL, so rows with NULL join keys can be filtered out early, in the join's children. This patch does this for inner joins and introduces a mechanism for generating such predicates. The subtle part is making the transformation stable: we must avoid generating a filter in the join, having it pushed down into a child, and then having the join regenerate the same filter on the next optimizer pass. This patch solves that by having the join remember the predicates it has generated. The mechanism should be general enough to infer other predicates too; for example, "a join b where a.id = b.id AND a.id = 10" could also use it to generate the predicate "b.id = 10". --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 8 +- .../expressions/EquivalentExpressions.scala | 7 + .../sql/catalyst/optimizer/Optimizer.scala | 73 ++++++++-- .../sql/catalyst/planning/patterns.scala | 37 ++++- .../catalyst/plans/logical/LogicalPlan.scala | 7 + .../plans/logical/basicOperators.scala | 24 +++- .../catalyst/optimizer/JoinFilterSuite.scala | 134 ++++++++++++++++++ .../spark/sql/catalyst/plans/PlanTest.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 12 +- 10 files changed, 276 insertions(+), 30 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinFilterSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ca00a5e49f66..06c90ea5b49e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -389,7 +389,7 @@ class Analyzer( a.copy(aggregateExpressions = expanded) // Special handling for cases when self-join introduce duplicate expression ids.
- case j @ Join(left, right, _, _) if !j.selfJoinResolved => + case j @ Join(left, right, _, _, _) if !j.selfJoinResolved => val conflictingAttributes = left.outputSet.intersect(right.outputSet) logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 7b2c93d63d67..f6ec7651e74d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, AggregateExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -87,12 +87,12 @@ trait CheckAnalysis { s"filter expression '${f.condition.prettyString}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") - case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => + case j @ Join(_, _, _, Some(condition), _) if condition.dataType != BooleanType => failAnalysis( s"join condition '${condition.prettyString}' " + s"of type ${condition.dataType.simpleString} is not a boolean.") - case j @ Join(_, _, _, Some(condition)) => + case j @ Join(_, _, _, Some(condition), _) => def checkValidJoinConditionExprs(expr: Expression): Unit = expr match { case p: Predicate => p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs) @@ -190,7 +190,7 @@ trait CheckAnalysis { | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin) // Special handling for cases when self-join introduce duplicate expression ids. - case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => + case j @ Join(left, right, _, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => val conflictingAttributes = left.outputSet.intersect(right.outputSet) failAnalysis( s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index f7162e420d19..2ca572bd8540 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -87,6 +87,13 @@ class EquivalentExpressions { equivalenceMap.values.map(_.toSeq).toSeq } + /** + * Returns true if an expression equivalent to `e` has been added. + */ + def contains(e: Expression): Boolean = { + equivalenceMap.contains(Expr(e)) + } + /** * Returns the state of the data structure as a string. If `all` is false, skips sets of * equivalent expressions with cardinality 1.
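As a quick, REPL-style illustration of how the new `contains` and the existing `addExpr` give the optimizer a stable fixed point (a minimal sketch, not part of the patch; the attribute name is illustrative):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions._

    val generated = new EquivalentExpressions
    val key: Expression = 'a.int   // an illustrative join key attribute

    // First optimizer pass: the key has not been seen yet, so remember it and
    // build the IsNotNull predicate that becomes Filter(predicate, child).
    if (!generated.contains(key)) {
      generated.addExpr(key)
      val predicate = IsNotNull(key)
    }

    // Every later pass: contains(key) is now true, so the rule generates
    // nothing new and the plan reaches a fixed point.
    assert(generated.contains(key))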
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f6088695a927..0b3e840eaab2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueri import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.planning.ExtractNonNullableAttributes import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -52,6 +53,8 @@ object DefaultOptimizer extends Optimizer { ProjectCollapsing, CombineFilters, CombineLimits, + // Predicate inference + AddJoinKeyNullabilityFilters, // Constant folding NullPropagation, OptimizeIn, @@ -233,7 +236,7 @@ object ColumnPruning extends Rule[LogicalPlan] { child)) // Eliminate unneeded attributes from either side of a Join. - case Project(projectList, Join(left, right, joinType, condition)) => + case Project(projectList, Join(left, right, joinType, condition, generated)) => // Collect the list of all references required either above or to evaluate the condition. val allReferences: AttributeSet = AttributeSet( @@ -243,15 +246,16 @@ object ColumnPruning extends Rule[LogicalPlan] { /** Applies a projection only when the child is producing unnecessary attributes */ def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences) - Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition)) + Project(projectList, + Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition, generated)) // Eliminate unneeded attributes from right side of a LeftSemiJoin. - case Join(left, right, LeftSemi, condition) => + case Join(left, right, LeftSemi, condition, generated) => // Collect the list of all references required to evaluate the condition. val allReferences: AttributeSet = condition.map(_.references).getOrElse(AttributeSet(Seq.empty)) - Join(left, prunedChild(right, allReferences), LeftSemi, condition) + Join(left, prunedChild(right, allReferences), LeftSemi, condition, generated) // Push down project through limit, so that we may have chance to push it further. case Project(projectList, Limit(exp, child)) => @@ -355,6 +359,51 @@ object LikeSimplification extends Rule[LogicalPlan] { } } +/** + * This rule adds IsNotNull predicates inferred from join keys. If the join condition compares + * an attribute `a` to any expression with a binary comparison other than null-safe equality, + * then `a` cannot be NULL, so this rule adds IsNotNull filters for such attributes to the + * children. + * + * To avoid repeatedly generating the same IsNotNull predicates, the join operator + * remembers all the expressions it has generated.
+ */ +object AddJoinKeyNullabilityFilters extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case j @ Join(left, right, Inner, Some(condition), generated) => { + val nonNullableKeys = ExtractNonNullableAttributes.unapply(condition) + val newKeys = nonNullableKeys.filter { k => generated.isEmpty || !generated.get.contains(k) } + + if (newKeys.isEmpty) { + j + } else { + val leftKeys = newKeys.filter { canEvaluate(_, left) } + val rightKeys = newKeys.filter { canEvaluate(_, right) } + + if (leftKeys.nonEmpty || rightKeys.nonEmpty) { + val newGenerated = + if (j.generatedExpressions.isDefined) j.generatedExpressions.get + else new EquivalentExpressions + + var newLeft: LogicalPlan = left + var newRight: LogicalPlan = right + + if (leftKeys.nonEmpty) { + newLeft = Filter(leftKeys.map(IsNotNull(_)).reduce(And), left) + leftKeys.foreach { e => newGenerated.addExpr(e) } + } + if (rightKeys.nonEmpty) { + newRight = Filter(rightKeys.map(IsNotNull(_)).reduce(And), right) + rightKeys.foreach { e => newGenerated.addExpr(e) } + } + + Join(newLeft, newRight, Inner, Some(condition), Some(newGenerated)) + } else { + j + } + } + } + } +} + /** * Replaces [[Expression Expressions]] that can be statically evaluated with * equivalent [[Literal]] values. This rule is more specific with @@ -784,7 +833,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // push the where condition down into join filter - case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) => + case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition, generated)) => val (leftFilterConditions, rightFilterConditions, commonFilterCondition) = split(splitConjunctivePredicates(filterCondition), left, right) @@ -797,14 +846,14 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = (commonFilterCondition ++ joinCondition).reduceLeftOption(And) - Join(newLeft, newRight, Inner, newJoinCond) + Join(newLeft, newRight, Inner, newJoinCond, generated) case RightOuter => // push down the right side only `where` condition val newLeft = left val newRight = rightFilterConditions. reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = joinCondition - val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond) + val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond, generated) (leftFilterConditions ++ commonFilterCondition). reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) @@ -814,7 +863,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) val newRight = right val newJoinCond = joinCondition - val newJoin = Join(newLeft, newRight, joinType, newJoinCond) + val newJoin = Join(newLeft, newRight, joinType, newJoinCond, generated) (rightFilterConditions ++ commonFilterCondition).
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) @@ -822,7 +871,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { } // push down the join filter into sub query scanning if applicable - case f @ Join(left, right, joinType, joinCondition) => + case f @ Join(left, right, joinType, joinCondition, generated) => val (leftJoinConditions, rightJoinConditions, commonJoinCondition) = split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right) @@ -835,7 +884,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = commonJoinCondition.reduceLeftOption(And) - Join(newLeft, newRight, joinType, newJoinCond) + Join(newLeft, newRight, joinType, newJoinCond, generated) case RightOuter => // push down the left side only join filter for left side sub query val newLeft = leftJoinConditions. @@ -843,7 +892,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { val newRight = right val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And) - Join(newLeft, newRight, RightOuter, newJoinCond) + Join(newLeft, newRight, RightOuter, newJoinCond, generated) case LeftOuter => // push down the right side only join filter for right sub query val newLeft = left @@ -851,7 +900,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And) - Join(newLeft, newRight, LeftOuter, newJoinCond) + Join(newLeft, newRight, LeftOuter, newJoinCond, generated) case FullOuter => f } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index cd3f15cbe107..779fff5f53f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.planning +import scala.collection.mutable + import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -95,7 +97,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { (JoinType, Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan) def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { - case join @ Join(left, right, joinType, condition) => + case join @ Join(left, right, joinType, condition, _) => logDebug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. 
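Every pattern here that used to match a 4-field Join now threads the new fifth field through untouched. As a hedged sketch of the calling convention (not part of the patch; relation and attribute names are illustrative), the 4-argument companion apply added to basicOperators.scala below keeps existing call sites source-compatible by filling in generatedExpressions = None, while rules that rebuild a join must re-attach the field:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.plans.Inner
    import org.apache.spark.sql.catalyst.plans.logical.{Join, LocalRelation}

    val x = LocalRelation('a.int)
    val y = LocalRelation('b.int)

    // 4-arg companion apply: generatedExpressions defaults to None.
    val j = Join(x, y, Inner, Some("a".attr === "b".attr))

    // A rewriting rule must match all five fields and preserve the last one;
    // dropping it would discard the bookkeeping and let the nullability rule
    // re-add the same IsNotNull filters on every optimizer pass.
    val rebuilt = j match {
      case Join(l, r, jt, cond, generated) => Join(l, r, jt, cond, generated)
    }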
@@ -150,11 +152,11 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { // flatten all inner joins, which are next to each other def flattenJoin(plan: LogicalPlan): (Seq[LogicalPlan], Seq[Expression]) = plan match { - case Join(left, right, Inner, cond) => + case Join(left, right, Inner, cond, _) => val (plans, conditions) = flattenJoin(left) (plans ++ Seq(right), conditions ++ cond.toSeq) - case Filter(filterCondition, j @ Join(left, right, Inner, joinCondition)) => + case Filter(filterCondition, j @ Join(left, right, Inner, joinCondition, _)) => val (plans, conditions) = flattenJoin(j) (plans, conditions ++ splitConjunctivePredicates(filterCondition)) @@ -162,9 +164,9 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { def unapply(plan: LogicalPlan): Option[(Seq[LogicalPlan], Seq[Expression])] = plan match { - case f @ Filter(filterCondition, j @ Join(_, _, Inner, _)) => + case f @ Filter(filterCondition, j @ Join(_, _, Inner, _, _)) => Some(flattenJoin(f)) - case j @ Join(_, _, Inner, _) => + case j @ Join(_, _, Inner, _, _) => Some(flattenJoin(j)) case _ => None } @@ -184,3 +186,28 @@ object Unions { case other => other :: Nil } } + +/** + * A pattern that finds all attributes in `condition` that cannot be null. + */ +object ExtractNonNullableAttributes extends Logging with PredicateHelper { + def unapply(condition: Expression): Set[Attribute] = { + val predicates = splitConjunctivePredicates(condition) + + val result = mutable.HashSet.empty[Attribute] + def extract(e: Expression): Unit = e match { + case IsNotNull(a: Attribute) => result.add(a) + case BinaryComparison(a: Attribute, b: Attribute) => { + if (!e.isInstanceOf[EqualNullSafe]) { + result.add(a) + result.add(b) + } + } + case BinaryComparison(a: Attribute, _) => if (!e.isInstanceOf[EqualNullSafe]) result.add(a) + case BinaryComparison(_, a: Attribute) => if (!e.isInstanceOf[EqualNullSafe]) result.add(a) + case Not(child) => extract(child) + } + predicates.foreach { extract(_) } + result.toSet + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 8f8747e10593..b8dbc4a78790 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -99,6 +99,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ lazy val resolved: Boolean = expressions.forall(_.resolved) && childrenResolved + /** + * Returns true if the two plans are semantically equal, ignoring state that is generated + * during planning purely to aid the planning process. + * TODO: implement this as a pass that canonicalizes the plan tree instead?
+ */ + def semanticEquals(other: LogicalPlan): Boolean = this == other + override protected def statePrefix = if (!resolved) "'" else super.statePrefix /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 5665fd7e5f41..65bd4d7904f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -122,11 +122,22 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le override def output: Seq[Attribute] = left.output } +object Join { + def apply( + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + condition: Option[Expression]): Join = { + Join(left, right, joinType, condition, None) + } +} + case class Join( left: LogicalPlan, right: LogicalPlan, joinType: JoinType, - condition: Option[Expression]) extends BinaryNode { + condition: Option[Expression], + generatedExpressions: Option[EquivalentExpressions]) extends BinaryNode { override def output: Seq[Attribute] = { joinType match { @@ -152,6 +163,17 @@ case class Join( selfJoinResolved && condition.forall(_.dataType == BooleanType) } + + override def simpleString: String = s"$nodeName $joinType, $condition".trim + + override def semanticEquals(other: LogicalPlan): Boolean = { + other match { + case Join(l, r, otherJoinType, otherCondition, _) => + left.semanticEquals(l) && right.semanticEquals(r) && + joinType == otherJoinType && condition == otherCondition + case _ => false + } + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinFilterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinFilterSuite.scala new file mode 100644 index 000000000000..51b89775c535 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinFilterSuite.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class JoinFilterSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubQueries) :: + Batch("Filter Pushdown", Once, + SamplePushDown, + CombineFilters, + PushPredicateThroughProject, + BooleanSimplification, + PushPredicateThroughJoin, + PushPredicateThroughGenerate, + PushPredicateThroughAggregate, + ColumnPruning, + ProjectCollapsing, + AddJoinKeyNullabilityFilters) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + val testRelation1 = LocalRelation('d.int) + + test("joins infer is NOT NULL on equijoin keys") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = x.join(y). + where("x.b".attr === "y.b".attr) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + Filter(IsNotNull("x.b".attr), x).join( + Filter(IsNotNull("y.b".attr), y), Inner, Some("x.b".attr === "y.b".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins infer is NOT NULL on join keys") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = x.join(y). + where("x.b".attr >= "y.b".attr).where("x.b".attr < "y.b".attr) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + Filter(IsNotNull("x.b".attr), x).join( + Filter(IsNotNull("y.b".attr), y), Inner, + Some(And("x.b".attr >= "y.b".attr, "x.b".attr < "y.b".attr))).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins infer is NOT NULL on not equal keys") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = x.join(y). + where("x.b".attr !== "y.a".attr) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + Filter(IsNotNull("x.b".attr), x).join( + Filter(IsNotNull("y.a".attr), y), Inner, Some("x.b".attr !== "y.a".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins infer is NOT NULL with multiple joins") { + val t1 = testRelation.subquery('t1) + val t2 = testRelation.subquery('t2) + val t3 = testRelation.subquery('t3) + val t4 = testRelation.subquery('t4) + + val originalQuery = t1.join(t2).join(t3).join(t4) + .where("t1.b".attr === "t2.b".attr) + .where("t1.b".attr === "t3.b".attr) + .where("t1.b".attr === "t4.b".attr) + .where("t2.b".attr === "t3.b".attr) + .where("t2.b".attr === "t4.b".attr) + .where("t3.b".attr === "t4.b".attr) + + val optimized = Optimize.execute(originalQuery.analyze) + var numFilters = 0 + optimized.foreach { p => p match { + case Filter(condition, _) => { + var containsIsNotNull = false + condition.foreach { c => + if (c.isInstanceOf[IsNotNull]) containsIsNotNull = true + } + if (containsIsNotNull) { + // If the condition contained IsNotNull, then it must be generated. Verify the entire + // condition is IsNotNull (to verify IsNotNull is not repeated) and that we generate the + // expected number of filters. 
+ assert(condition.isInstanceOf[IsNotNull]) + numFilters += 1 + } + } + case _ => + }} + assertResult(4)(numFilters) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 2efee1fc5470..c51a7b375fd8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -43,7 +43,7 @@ abstract class PlanTest extends SparkFunSuite { protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { val normalized1 = normalizeExprIds(plan1) val normalized2 = normalizeExprIds(plan2) - if (normalized1 != normalized2) { + if (!normalized1.semanticEquals(normalized2)) { fail( s""" |== FAIL: Plans do not match === diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 25e98c0bdd43..5e0824d58180 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -43,7 +43,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { joins.LeftSemiJoinHash( leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil // no predicate can be evaluated by matching hash keys - case logical.Join(left, right, LeftSemi, condition) => + case logical.Join(left, right, LeftSemi, condition, _) => joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil case _ => Nil } @@ -234,11 +234,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object BroadcastNestedLoop extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join( - CanBroadcast(left), right, joinType, condition) if joinType != LeftSemi => + CanBroadcast(left), right, joinType, condition, _) if joinType != LeftSemi => execution.joins.BroadcastNestedLoopJoin( planLater(left), planLater(right), joins.BuildLeft, joinType, condition) :: Nil case logical.Join( - left, CanBroadcast(right), joinType, condition) if joinType != LeftSemi => + left, CanBroadcast(right), joinType, condition, _) if joinType != LeftSemi => execution.joins.BroadcastNestedLoopJoin( planLater(left), planLater(right), joins.BuildRight, joinType, condition) :: Nil case _ => Nil @@ -248,9 +248,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // TODO CartesianProduct doesn't support the Left Semi Join - case logical.Join(left, right, joinType, None) if joinType != LeftSemi => + case logical.Join(left, right, joinType, None, _) if joinType != LeftSemi => execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil - case logical.Join(left, right, Inner, Some(condition)) => + case logical.Join(left, right, Inner, Some(condition), _) => execution.Filter(condition, execution.joins.CartesianProduct(planLater(left), planLater(right))) :: Nil case _ => Nil @@ -259,7 +259,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object DefaultJoin extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Join(left, right, joinType, condition) => + case logical.Join(left, right, joinType, condition, _) => val 
buildSide = if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { joins.BuildRight From 00e957c83ca2983941170e503e926535f2c8a91c Mon Sep 17 00:00:00 2001 From: Nong Li Date: Wed, 9 Dec 2015 21:54:41 -0800 Subject: [PATCH 2/3] Long line fix. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index f6ec7651e74d..e81ebe60e404 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -190,14 +190,15 @@ trait CheckAnalysis { | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin) // Special handling for cases when self-join introduce duplicate expression ids. - case j @ Join(left, right, _, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => - val conflictingAttributes = left.outputSet.intersect(right.outputSet) - failAnalysis( - s""" - |Failure when resolving conflicting references in Join: - |$plan - |Conflicting attributes: ${conflictingAttributes.mkString(",")} - |""".stripMargin) + case j @ Join(left, right, _, _, _) + if left.outputSet.intersect(right.outputSet).nonEmpty => + val conflictingAttributes = left.outputSet.intersect(right.outputSet) + failAnalysis( + s""" + |Failure when resolving conflicting references in Join: + |$plan + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + |""".stripMargin) case o if !o.resolved => failAnalysis( From fb562fb67a761276456b14a81513f3fc69a6ead8 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Thu, 10 Dec 2015 00:02:39 -0800 Subject: [PATCH 3/3] Fix pattern with casts and add more test cases. 
--- .../sql/catalyst/planning/patterns.scala | 11 +++++++ .../catalyst/optimizer/JoinFilterSuite.scala | 32 +++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 779fff5f53f9..ed4fcb2c5f1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -203,9 +203,20 @@ object ExtractNonNullableAttributes extends Logging with PredicateHelper { result.add(b) } } + case BinaryComparison(Cast(a: Attribute, _), Cast(b: Attribute, _)) => { + if (!e.isInstanceOf[EqualNullSafe]) { + result.add(a) + result.add(b) + } + } case BinaryComparison(a: Attribute, _) => if (!e.isInstanceOf[EqualNullSafe]) result.add(a) case BinaryComparison(_, a: Attribute) => if (!e.isInstanceOf[EqualNullSafe]) result.add(a) + case BinaryComparison(Cast(a: Attribute, _), _) => + if (!e.isInstanceOf[EqualNullSafe]) result.add(a) + case BinaryComparison(_, Cast(a: Attribute, _)) => + if (!e.isInstanceOf[EqualNullSafe]) result.add(a) case Not(child) => extract(child) + case _ => } predicates.foreach { extract(_) } result.toSet diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinFilterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinFilterSuite.scala index 51b89775c535..bff414adbb57 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinFilterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinFilterSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.DoubleType class JoinFilterSuite extends PlanTest { @@ -64,6 +65,37 @@ class JoinFilterSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("joins infer is NOT NULL one key") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = x.join(y). + where("x.b".attr + 1 === "y.b".attr) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = x.join( + Filter(IsNotNull("y.b".attr), y), Inner, Some("x.b".attr + 1 === "y.b".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins infer is NOT NULL for cast") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = x.join(y). + where(Cast("x.b".attr, DoubleType) === "y.b".attr) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + Filter(IsNotNull("x.b".attr), x).join( + Filter(IsNotNull("y.b".attr), y), Inner, + Some(Cast("x.b".attr, DoubleType) === Cast("y.b".attr, DoubleType))).analyze + comparePlans(optimized, correctAnswer) + } + test("joins infer is NOT NULL on join keys") { val x = testRelation.subquery('x) val y = testRelation.subquery('y)
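Taken together, the extractor's intended semantics after this series can be summarized with a REPL-style sketch (a hedged illustration, not part of the patch; attribute names are illustrative):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions.{Cast, EqualNullSafe}
    import org.apache.spark.sql.catalyst.planning.ExtractNonNullableAttributes
    import org.apache.spark.sql.types.DoubleType

    val a = 'a.int
    val b = 'b.int

    // Any non null-safe binary comparison filters out NULLs on both attribute
    // sides, so both attributes are reported as non-nullable.
    ExtractNonNullableAttributes.unapply(a === b)              // Set(a, b)

    // Null-safe equality matches NULLs, so nothing can be inferred from it.
    ExtractNonNullableAttributes.unapply(EqualNullSafe(a, b))  // Set()

    // With the Cast cases added above, comparisons that type coercion has
    // wrapped in casts are recognized as well.
    ExtractNonNullableAttributes.unapply(
      Cast(a, DoubleType) === Cast(b, DoubleType))             // Set(a, b)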