diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ca00a5e49f668..06c90ea5b49e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -389,7 +389,7 @@ class Analyzer( a.copy(aggregateExpressions = expanded) // Special handling for cases when self-join introduce duplicate expression ids. - case j @ Join(left, right, _, _) if !j.selfJoinResolved => + case j @ Join(left, right, _, _, _) if !j.selfJoinResolved => val conflictingAttributes = left.outputSet.intersect(right.outputSet) logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 7b2c93d63d673..e81ebe60e4041 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, AggregateExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -87,12 +87,12 @@ trait CheckAnalysis { s"filter expression '${f.condition.prettyString}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") - case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => + case j @ Join(_, _, _, Some(condition), _) if condition.dataType != BooleanType => failAnalysis( s"join condition '${condition.prettyString}' " + s"of type ${condition.dataType.simpleString} is not a boolean.") - case j @ Join(_, _, _, Some(condition)) => + case j @ Join(_, _, _, Some(condition), _) => def checkValidJoinConditionExprs(expr: Expression): Unit = expr match { case p: Predicate => p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs) @@ -190,14 +190,15 @@ trait CheckAnalysis { | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin) // Special handling for cases when self-join introduce duplicate expression ids. 
- case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => - val conflictingAttributes = left.outputSet.intersect(right.outputSet) - failAnalysis( - s""" - |Failure when resolving conflicting references in Join: - |$plan - |Conflicting attributes: ${conflictingAttributes.mkString(",")} - |""".stripMargin) + case j @ Join(left, right, _, _, _) + if left.outputSet.intersect(right.outputSet).nonEmpty => + val conflictingAttributes = left.outputSet.intersect(right.outputSet) + failAnalysis( + s""" + |Failure when resolving conflicting references in Join: + |$plan + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + |""".stripMargin) case o if !o.resolved => failAnalysis( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index f7162e420d19a..2ca572bd85409 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -87,6 +87,13 @@ class EquivalentExpressions { equivalenceMap.values.map(_.toSeq).toSeq } + /** + * Returns true if e exists. + */ + def contains(e: Expression): Boolean = { + equivalenceMap.contains(Expr(e)) + } + /** * Returns the state of the data structure as a string. If `all` is false, skips sets of * equivalent expressions with cardinality 1. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f6088695a9276..0b3e840eaab21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueri import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.planning.ExtractNonNullableAttributes import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -52,6 +53,8 @@ object DefaultOptimizer extends Optimizer { ProjectCollapsing, CombineFilters, CombineLimits, + // Predicate inference + AddJoinKeyNullabilityFilters, // Constant folding NullPropagation, OptimizeIn, @@ -233,7 +236,7 @@ object ColumnPruning extends Rule[LogicalPlan] { child)) // Eliminate unneeded attributes from either side of a Join. - case Project(projectList, Join(left, right, joinType, condition)) => + case Project(projectList, Join(left, right, joinType, condition, generated)) => // Collect the list of all references required either above or to evaluate the condition. 
       val allReferences: AttributeSet =
         AttributeSet(
@@ -243,15 +246,16 @@ object ColumnPruning extends Rule[LogicalPlan] {
       /** Applies a projection only when the child is producing unnecessary attributes */
       def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences)
 
-      Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition))
+      Project(projectList,
+        Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition, generated))
 
     // Eliminate unneeded attributes from right side of a LeftSemiJoin.
-    case Join(left, right, LeftSemi, condition) =>
+    case Join(left, right, LeftSemi, condition, generated) =>
       // Collect the list of all references required to evaluate the condition.
       val allReferences: AttributeSet =
         condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
 
-      Join(left, prunedChild(right, allReferences), LeftSemi, condition)
+      Join(left, prunedChild(right, allReferences), LeftSemi, condition, generated)
 
     // Push down project through limit, so that we may have chance to push it further.
     case Project(projectList, Limit(exp, child)) =>
@@ -355,6 +359,51 @@ object LikeSimplification extends Rule[LogicalPlan] {
   }
 }
 
+/**
+ * This rule adds IsNotNull filters to the children of an inner join, derived from the join keys:
+ * a condition `a binaryOp *` (except null-safe equality) can only match when `a` is non-null.
+ *
+ * To avoid generating the same IsNotNull predicates over and over, the join operator remembers
+ * all the expressions this rule has generated.
+ */
+object AddJoinKeyNullabilityFilters extends Rule[LogicalPlan] with PredicateHelper {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case j @ Join(left, right, Inner, Some(condition), generated) => {
+      val nonNullableKeys = ExtractNonNullableAttributes.unapply(condition)
+      val newKeys = nonNullableKeys.filter { k => generated.isEmpty || !generated.get.contains(k) }
+
+      if (newKeys.isEmpty) {
+        j
+      } else {
+        val leftKeys = newKeys.filter { canEvaluate(_, left) }
+        val rightKeys = newKeys.filter { canEvaluate(_, right) }
+
+        if (leftKeys.nonEmpty || rightKeys.nonEmpty) {
+          val newGenerated =
+            if (j.generatedExpressions.isDefined) j.generatedExpressions.get
+            else new EquivalentExpressions
+
+          var newLeft: LogicalPlan = left
+          var newRight: LogicalPlan = right
+
+          if (leftKeys.nonEmpty) {
+            newLeft = Filter(leftKeys.map(IsNotNull(_)).reduce(And), left)
+            leftKeys.foreach { e => newGenerated.addExpr(e) }
+          }
+          if (rightKeys.nonEmpty) {
+            newRight = Filter(rightKeys.map(IsNotNull(_)).reduce(And), right)
+            rightKeys.foreach { e => newGenerated.addExpr(e) }
+          }
+
+          Join(newLeft, newRight, Inner, Some(condition), Some(newGenerated))
+        } else {
+          j
+        }
+      }
+    }
+  }
+}
+
 /**
  * Replaces [[Expression Expressions]] that can be statically evaluated with
  * equivalent [[Literal]] values.
This rule is more specific with @@ -784,7 +833,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // push the where condition down into join filter - case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) => + case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition, generated)) => val (leftFilterConditions, rightFilterConditions, commonFilterCondition) = split(splitConjunctivePredicates(filterCondition), left, right) @@ -797,14 +846,14 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = (commonFilterCondition ++ joinCondition).reduceLeftOption(And) - Join(newLeft, newRight, Inner, newJoinCond) + Join(newLeft, newRight, Inner, newJoinCond, generated) case RightOuter => // push down the right side only `where` condition val newLeft = left val newRight = rightFilterConditions. reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = joinCondition - val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond) + val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond, generated) (leftFilterConditions ++ commonFilterCondition). reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) @@ -814,7 +863,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) val newRight = right val newJoinCond = joinCondition - val newJoin = Join(newLeft, newRight, joinType, newJoinCond) + val newJoin = Join(newLeft, newRight, joinType, newJoinCond, generated) (rightFilterConditions ++ commonFilterCondition). reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) @@ -822,7 +871,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { } // push down the join filter into sub query scanning if applicable - case f @ Join(left, right, joinType, joinCondition) => + case f @ Join(left, right, joinType, joinCondition, generated) => val (leftJoinConditions, rightJoinConditions, commonJoinCondition) = split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right) @@ -835,7 +884,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = commonJoinCondition.reduceLeftOption(And) - Join(newLeft, newRight, joinType, newJoinCond) + Join(newLeft, newRight, joinType, newJoinCond, generated) case RightOuter => // push down the left side only join filter for left side sub query val newLeft = leftJoinConditions. 
@@ -843,7 +892,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { val newRight = right val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And) - Join(newLeft, newRight, RightOuter, newJoinCond) + Join(newLeft, newRight, RightOuter, newJoinCond, generated) case LeftOuter => // push down the right side only join filter for right sub query val newLeft = left @@ -851,7 +900,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And) - Join(newLeft, newRight, LeftOuter, newJoinCond) + Join(newLeft, newRight, LeftOuter, newJoinCond, generated) case FullOuter => f } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index cd3f15cbe107b..ed4fcb2c5f1d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.planning +import scala.collection.mutable + import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -95,7 +97,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { (JoinType, Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan) def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { - case join @ Join(left, right, joinType, condition) => + case join @ Join(left, right, joinType, condition, _) => logDebug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. @@ -150,11 +152,11 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { // flatten all inner joins, which are next to each other def flattenJoin(plan: LogicalPlan): (Seq[LogicalPlan], Seq[Expression]) = plan match { - case Join(left, right, Inner, cond) => + case Join(left, right, Inner, cond, _) => val (plans, conditions) = flattenJoin(left) (plans ++ Seq(right), conditions ++ cond.toSeq) - case Filter(filterCondition, j @ Join(left, right, Inner, joinCondition)) => + case Filter(filterCondition, j @ Join(left, right, Inner, joinCondition, _)) => val (plans, conditions) = flattenJoin(j) (plans, conditions ++ splitConjunctivePredicates(filterCondition)) @@ -162,9 +164,9 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper { } def unapply(plan: LogicalPlan): Option[(Seq[LogicalPlan], Seq[Expression])] = plan match { - case f @ Filter(filterCondition, j @ Join(_, _, Inner, _)) => + case f @ Filter(filterCondition, j @ Join(_, _, Inner, _, _)) => Some(flattenJoin(f)) - case j @ Join(_, _, Inner, _) => + case j @ Join(_, _, Inner, _, _) => Some(flattenJoin(j)) case _ => None } @@ -184,3 +186,39 @@ object Unions { case other => other :: Nil } } + +/** + * A pattern that finds all attributes in `expr` that cannot be nullable. 
+ */
+object ExtractNonNullableAttributes extends Logging with PredicateHelper {
+  def unapply(condition: Expression): Set[Attribute] = {
+    val predicates = splitConjunctivePredicates(condition)
+
+    val result = mutable.HashSet.empty[Attribute]
+    def extract(e: Expression): Unit = e match {
+      case IsNotNull(a: Attribute) => result.add(a)
+      case BinaryComparison(a: Attribute, b: Attribute) => {
+        if (!e.isInstanceOf[EqualNullSafe]) {
+          result.add(a)
+          result.add(b)
+        }
+      }
+      case BinaryComparison(Cast(a: Attribute, _), Cast(b: Attribute, _)) => {
+        if (!e.isInstanceOf[EqualNullSafe]) {
+          result.add(a)
+          result.add(b)
+        }
+      }
+      case BinaryComparison(a: Attribute, _) => if (!e.isInstanceOf[EqualNullSafe]) result.add(a)
+      case BinaryComparison(_, a: Attribute) => if (!e.isInstanceOf[EqualNullSafe]) result.add(a)
+      case BinaryComparison(Cast(a: Attribute, _), _) =>
+        if (!e.isInstanceOf[EqualNullSafe]) result.add(a)
+      case BinaryComparison(_, Cast(a: Attribute, _)) =>
+        if (!e.isInstanceOf[EqualNullSafe]) result.add(a)
+      case Not(child) => extract(child)
+      case _ =>
+    }
+    predicates.foreach { extract(_) }
+    result.toSet
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 8f8747e105932..b8dbc4a78790a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -99,6 +99,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
    */
   lazy val resolved: Boolean = expressions.forall(_.resolved) && childrenResolved
 
+  /**
+   * Returns true if the two plans are semantically equal. State that is generated during
+   * planning purely to assist the planning process should be ignored by this comparison.
+   * TODO: implement this as a pass that canonicalizes the plan tree instead?
+ */ + def semanticEquals(other: LogicalPlan): Boolean = this == other + override protected def statePrefix = if (!resolved) "'" else super.statePrefix /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 5665fd7e5f419..65bd4d7904f94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -122,11 +122,22 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le override def output: Seq[Attribute] = left.output } +object Join { + def apply( + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + condition: Option[Expression]): Join = { + Join(left, right, joinType, condition, None) + } +} + case class Join( left: LogicalPlan, right: LogicalPlan, joinType: JoinType, - condition: Option[Expression]) extends BinaryNode { + condition: Option[Expression], + generatedExpressions: Option[EquivalentExpressions]) extends BinaryNode { override def output: Seq[Attribute] = { joinType match { @@ -152,6 +163,17 @@ case class Join( selfJoinResolved && condition.forall(_.dataType == BooleanType) } + + override def simpleString: String = s"$nodeName $joinType, $condition".trim + + override def semanticEquals(other: LogicalPlan): Boolean = { + other match { + case Join (l, r, joinType, condition, _) => { + l == left && r == right && this.joinType == joinType && this.condition == condition + } + case _ => false + } + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinFilterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinFilterSuite.scala new file mode 100644 index 0000000000000..bff414adbb579 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinFilterSuite.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.DoubleType + +class JoinFilterSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubQueries) :: + Batch("Filter Pushdown", Once, + SamplePushDown, + CombineFilters, + PushPredicateThroughProject, + BooleanSimplification, + PushPredicateThroughJoin, + PushPredicateThroughGenerate, + PushPredicateThroughAggregate, + ColumnPruning, + ProjectCollapsing, + AddJoinKeyNullabilityFilters) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + val testRelation1 = LocalRelation('d.int) + + test("joins infer is NOT NULL on equijoin keys") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = x.join(y). + where("x.b".attr === "y.b".attr) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + Filter(IsNotNull("x.b".attr), x).join( + Filter(IsNotNull("y.b".attr), y), Inner, Some("x.b".attr === "y.b".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins infer is NOT NULL one key") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = x.join(y). + where("x.b".attr + 1 === "y.b".attr) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = x.join( + Filter(IsNotNull("y.b".attr), y), Inner, Some("x.b".attr + 1 === "y.b".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins infer is NOT NULL for cast") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = x.join(y). + where(Cast("x.b".attr, DoubleType) === "y.b".attr) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + Filter(IsNotNull("x.b".attr), x).join( + Filter(IsNotNull("y.b".attr), y), Inner, + Some(Cast("x.b".attr, DoubleType) === Cast("y.b".attr, DoubleType))).analyze + comparePlans(optimized, correctAnswer) + } + + test("joins infer is NOT NULL on join keys") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = x.join(y). + where("x.b".attr >= "y.b".attr).where("x.b".attr < "y.b".attr) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + Filter(IsNotNull("x.b".attr), x).join( + Filter(IsNotNull("y.b".attr), y), Inner, + Some(And("x.b".attr >= "y.b".attr, "x.b".attr < "y.b".attr))).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins infer is NOT NULL on not equal keys") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = x.join(y). 
+ where("x.b".attr !== "y.a".attr) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + Filter(IsNotNull("x.b".attr), x).join( + Filter(IsNotNull("y.a".attr), y), Inner, Some("x.b".attr !== "y.a".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins infer is NOT NULL with multiple joins") { + val t1 = testRelation.subquery('t1) + val t2 = testRelation.subquery('t2) + val t3 = testRelation.subquery('t3) + val t4 = testRelation.subquery('t4) + + val originalQuery = t1.join(t2).join(t3).join(t4) + .where("t1.b".attr === "t2.b".attr) + .where("t1.b".attr === "t3.b".attr) + .where("t1.b".attr === "t4.b".attr) + .where("t2.b".attr === "t3.b".attr) + .where("t2.b".attr === "t4.b".attr) + .where("t3.b".attr === "t4.b".attr) + + val optimized = Optimize.execute(originalQuery.analyze) + var numFilters = 0 + optimized.foreach { p => p match { + case Filter(condition, _) => { + var containsIsNotNull = false + condition.foreach { c => + if (c.isInstanceOf[IsNotNull]) containsIsNotNull = true + } + if (containsIsNotNull) { + // If the condition contained IsNotNull, then it must be generated. Verify the entire + // condition is IsNotNull (to verify IsNotNull is not repeated) and that we generate the + // expected number of filters. + assert(condition.isInstanceOf[IsNotNull]) + numFilters += 1 + } + } + case _ => + }} + assertResult(4)(numFilters) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 2efee1fc54706..c51a7b375fd8c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -43,7 +43,7 @@ abstract class PlanTest extends SparkFunSuite { protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { val normalized1 = normalizeExprIds(plan1) val normalized2 = normalizeExprIds(plan2) - if (normalized1 != normalized2) { + if (!normalized1.semanticEquals(normalized2)) { fail( s""" |== FAIL: Plans do not match === diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 25e98c0bdd431..5e0824d58180f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -43,7 +43,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { joins.LeftSemiJoinHash( leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil // no predicate can be evaluated by matching hash keys - case logical.Join(left, right, LeftSemi, condition) => + case logical.Join(left, right, LeftSemi, condition, _) => joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil case _ => Nil } @@ -234,11 +234,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object BroadcastNestedLoop extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join( - CanBroadcast(left), right, joinType, condition) if joinType != LeftSemi => + CanBroadcast(left), right, joinType, condition, _) if joinType != LeftSemi => execution.joins.BroadcastNestedLoopJoin( planLater(left), planLater(right), joins.BuildLeft, joinType, condition) :: Nil case logical.Join( - left, CanBroadcast(right), joinType, 
condition) if joinType != LeftSemi => + left, CanBroadcast(right), joinType, condition, _) if joinType != LeftSemi => execution.joins.BroadcastNestedLoopJoin( planLater(left), planLater(right), joins.BuildRight, joinType, condition) :: Nil case _ => Nil @@ -248,9 +248,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object CartesianProduct extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // TODO CartesianProduct doesn't support the Left Semi Join - case logical.Join(left, right, joinType, None) if joinType != LeftSemi => + case logical.Join(left, right, joinType, None, _) if joinType != LeftSemi => execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil - case logical.Join(left, right, Inner, Some(condition)) => + case logical.Join(left, right, Inner, Some(condition), _) => execution.Filter(condition, execution.joins.CartesianProduct(planLater(left), planLater(right))) :: Nil case _ => Nil @@ -259,7 +259,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object DefaultJoin extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Join(left, right, joinType, condition) => + case logical.Join(left, right, joinType, condition, _) => val buildSide = if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { joins.BuildRight
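For context, here is a minimal sketch of how the new rule and pattern are expected to behave once this patch is applied; it mirrors the cases in JoinFilterSuite above. The object name JoinNullabilityExample and the toy relations/attributes are illustrative only and are not part of the change.

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.optimizer.AddJoinKeyNullabilityFilters
import org.apache.spark.sql.catalyst.planning.ExtractNonNullableAttributes
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation}

object JoinNullabilityExample {
  def main(args: Array[String]): Unit = {
    val x = LocalRelation('a.int, 'b.int).subquery('x)
    val y = LocalRelation('a.int, 'b.int).subquery('y)

    // x.b = y.b can only match when both sides are non-null, so both attributes are extracted.
    val condition = "x.b".attr === "y.b".attr
    assert(ExtractNonNullableAttributes.unapply(condition).size == 2)

    // One application of the rule wraps each child in an IsNotNull filter and records the
    // generated expressions on the Join ...
    val plan = x.join(y, Inner, Some(condition)).analyze
    val once = AddJoinKeyNullabilityFilters(plan)
    assert(once.collect { case f: Filter => f }.size == 2)

    // ... so a second application is a no-op instead of stacking more IsNotNull filters.
    val twice = AddJoinKeyNullabilityFilters(once)
    assert(twice == once)
  }
}

Because Join.semanticEquals ignores the generatedExpressions field, the PlanTest comparisons in the suite keep passing even though the optimized Join carries this extra bookkeeping.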