diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 1f05f2065c949..6165663cd0436 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -945,8 +945,10 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { } /** - * Elimination of outer joins, if the predicates can restrict the result sets so that - * all null-supplying rows are eliminated + * Elimination of Outer Joins + * + * Rule Set 1: checking the Filter Condition if the predicates can restrict the result sets + * so that all null-supplying rows are eliminated * * - full outer -> inner if both sides have such predicates * - left outer -> inner if the right side has such predicates @@ -954,7 +956,21 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { * - full outer -> left outer if only the left side has such predicates * - full outer -> right outer if only the right side has such predicates * - * This rule should be executed before pushing down the Filter + * This rule set should be executed before pushing down the Filter + * + * Rule Set 2: given an outer join is involved in another join (called parent join), when the join + * type of the parent join is inner, left-semi, left-outer and right-outer, checking if the join + * condition of the parent join satisfies the following two conditions: + * 1) there exist null filtering predicates against the columns in the null-supplying side of + * parent join. + * 2) these columns are from the child join. + * + * If having such join predicates, execute the elimination rules: + * - full outer -> inner if both sides of the child join have such predicates + * - left outer -> inner if the right side of the child join has such predicates + * - right outer -> inner if the left side of the child join has such predicates + * - full outer -> left outer if only the left side of the child join has such predicates + * - full outer -> right outer if only the right side of the child join has such predicates */ object OuterJoinElimination extends Rule[LogicalPlan] with PredicateHelper { @@ -969,18 +985,21 @@ object OuterJoinElimination extends Rule[LogicalPlan] with PredicateHelper { v == null || v == false } - private def buildNewJoinType(filter: Filter, join: Join): JoinType = { - val splitConjunctiveConditions: Seq[Expression] = splitConjunctivePredicates(filter.condition) + private def buildNewJoinType( + condition: Expression, + constraints: Set[Expression], + join: Join): JoinType = { + val splitConjunctiveConditions: Seq[Expression] = splitConjunctivePredicates(condition) val leftConditions = splitConjunctiveConditions .filter(_.references.subsetOf(join.left.outputSet)) val rightConditions = splitConjunctiveConditions .filter(_.references.subsetOf(join.right.outputSet)) val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull) || - filter.constraints.filter(_.isInstanceOf[IsNotNull]) + constraints.filter(_.isInstanceOf[IsNotNull]) .exists(expr => join.left.outputSet.intersect(expr.references).nonEmpty) val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull) || - filter.constraints.filter(_.isInstanceOf[IsNotNull]) + constraints.filter(_.isInstanceOf[IsNotNull]) .exists(expr => join.right.outputSet.intersect(expr.references).nonEmpty) join.joinType match { @@ -994,9 +1013,37 @@ object OuterJoinElimination extends Rule[LogicalPlan] with PredicateHelper { } def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // Rule Set 1: elimination using Filter conditions/constraints case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | FullOuter, _)) => - val newJoinType = buildNewJoinType(f, j) + val newJoinType = buildNewJoinType(f.condition, f.constraints, j) if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) + + // Rule Set 2: elimination using Parent Join conditions/constraints + // Case 1: when parent join is Inner|LeftSemi|LeftOuter and the child join is on the right side + case pj @ Join( + _, + j @ Join(left, right, RightOuter|LeftOuter|FullOuter, condition), + Inner | LeftSemi | LeftOuter, + Some(pJoinCond)) => + val newJoinType = buildNewJoinType(pJoinCond, pj.constraints, j) + if (j.joinType == newJoinType) { + pj + } else { + Join(pj.left, j.copy(joinType = newJoinType), pj.joinType, pj.condition) + } + + // Case 2: when parent join is Inner|LeftSemi|RightOuter and the child join is on the left side + case pj @ Join( + j @ Join(left, right, RightOuter|LeftOuter|FullOuter, condition), + _, + Inner | LeftSemi | RightOuter, + Some(pJoinCond)) => + val newJoinType = buildNewJoinType(pJoinCond, pj.constraints, j) + if (j.joinType == newJoinType) { + pj + } else { + Join(j.copy(joinType = newJoinType), pj.right, pj.joinType, pj.condition) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 067a62d011ec4..514d53fb5e020 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{Join, Project} import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -204,4 +204,157 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { leftJoin2Inner, Row(1, 2, "1", 1, 3, "1") :: Nil) } + + test("join - left outer to inner by the parent join's join condition") { + val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str").as("a") + val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b") + val df3 = Seq((1, 3, "1"), (3, 6, "5")).toDF("int", "int2", "str").as("c") + + // Left -> Inner + val right = df.join(df2, $"a.int" === $"b.int", "left") + val left2Inner = + df3.join(right, $"c.int" === $"b.int", "inner").select($"a.*", $"b.*", $"c.*") + + left2Inner.explain(true) + + // The order before conversion: Left Then Inner + assert(left2Inner.queryExecution.analyzed.collect { + case j @ Join(_, Join(_, _, LeftOuter, _), Inner, _) => j + }.size === 1) + + // The order after conversion: Inner Then Inner + assert(left2Inner.queryExecution.optimizedPlan.collect { + case j @ Join(_, Join(_, _, Inner, _), Inner, _) => j + }.size === 1) + + checkAnswer( + left2Inner, + Row(1, 2, "1", 1, 3, "1", 1, 3, "1") :: Nil) + } + + test("join - right outer to inner by the parent join's join condition") { + val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str").as("a") + val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b") + val df3 = Seq((1, 9, "8"), (5, 0, "4")).toDF("int", "int2", "str").as("c") + + // Right Then Inner -> Inner Then Right + val right2Inner = df.join(df2, $"a.int" === $"b.int", "right") + .join(df3, $"a.int" === $"b.int", "inner").select($"a.*", $"b.*", $"c.*") + + // The order before conversion: Left Then Inner + assert(right2Inner.queryExecution.analyzed.collect { + case j @ Join(Join(_, _, RightOuter, _), _, Inner, _) => j + }.size === 1) + + // The order after conversion: Inner Then Inner + assert(right2Inner.queryExecution.optimizedPlan.collect { + case j @ Join(Join(_, _, Inner, _), _, Inner, _) => j + }.size === 1) + + checkAnswer( + right2Inner, + Row(1, 2, "1", 1, 3, "1", 1, 9, "8") :: + Row(1, 2, "1", 1, 3, "1", 5, 0, "4") :: Nil) + } + + test("join - full outer to inner by the parent join's join condition") { + val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str").as("a") + val df2 = Seq((1, 2, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b") + val df3 = Seq((1, 3, "1"), (3, 6, "5")).toDF("int", "int2", "str").as("c") + + // Full -> Inner + val right = df.join(df2, $"a.int" === $"b.int", "full") + val full2Inner = df3.join(right, $"c.int" === $"a.int" && $"b.int" === 1, "inner") + .select($"a.*", $"b.*", $"c.*") + + // The order before conversion: Left Then Inner + assert(full2Inner.queryExecution.analyzed.collect { + case j @ Join(_, Join(_, _, FullOuter, _), Inner, _) => j + }.size === 1) + + // The order after conversion: Inner Then Inner + assert(full2Inner.queryExecution.optimizedPlan.collect { + case j @ Join(_, Join(_, _, Inner, _), Inner, _) => j + }.size === 1) + + checkAnswer( + full2Inner, + Row(1, 2, "1", 1, 2, "1", 1, 3, "1") :: Nil) + } + + test("join - full outer to right by the parent join's join condition") { + val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str").as("a") + val df2 = Seq((1, 2, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b") + val df3 = Seq((1, 3, "1"), (3, 6, "5")).toDF("int", "int2", "str").as("c") + + // Full -> Right + val right = df.join(df2, $"a.int" === $"b.int", "full") + val full2Right = df3.join(right, $"b.int" === 1, "leftsemi") + + // The order before conversion: Left Then Inner + assert(full2Right.queryExecution.analyzed.collect { + case j @ Join(_, Join(_, _, FullOuter, _), LeftSemi, _) => j + }.size === 1) + + // The order after conversion: Inner Then Inner + assert(full2Right.queryExecution.optimizedPlan.collect { + case j @ Join(_, Project(_, Join(_, _, RightOuter, _)), LeftSemi, _) => j + }.size === 1) + + checkAnswer( + full2Right, + Row(1, 3, "1") :: Row(3, 6, "5") :: Nil) + } + + + test("join - full outer to left by the parent join's join condition #1") { + val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str").as("a") + val df2 = Seq((1, 2, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b") + val df3 = Seq((1, 3, "1"), (4, 6, "5")).toDF("int", "int2", "str").as("c") + + // Full -> Left + val right = df.join(df2, $"a.int" === $"b.int", "full") + val full2Left = df3.join(right, lit(3) === $"a.int", "left") + .select($"a.*", $"b.*", $"c.*") + + // The order before conversion: Full Then Left + assert(full2Left.queryExecution.analyzed.collect { + case j @ Join(_, Join(_, _, FullOuter, _), LeftOuter, _) => j + }.size === 1) + + // The order after conversion: Left Then Left + assert(full2Left.queryExecution.optimizedPlan.collect { + case j @ Join(_, Join(_, _, LeftOuter, _), LeftOuter, _) => j + }.size === 1) + + checkAnswer( + full2Left, + Row(3, 4, "3", null, null, null, 1, 3, "1") :: + Row(3, 4, "3", null, null, null, 4, 6, "5") :: Nil) + } + + test("join - full outer to left by the parent join's join condition #2") { + val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str").as("a") + val df2 = Seq((1, 2, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b") + val df3 = Seq((1, 3, "1"), (4, 6, "5")).toDF("int", "int2", "str").as("c") + + // Full -> Left + val full2Left = df.join(df2, $"a.int" === $"b.int", "full") + .join(df3, lit(3) === $"a.int", "right").select($"a.*", $"b.*", $"c.*") + + // The order before conversion: Full Then Right + assert(full2Left.queryExecution.analyzed.collect { + case j @ Join(Join(_, _, FullOuter, _), _, RightOuter, _) => j + }.size === 1) + + // The order after conversion: Left Then Right + assert(full2Left.queryExecution.optimizedPlan.collect { + case j @ Join(Join(_, _, LeftOuter, _), _, RightOuter, _) => j + }.size === 1) + + checkAnswer( + full2Left, + Row(3, 4, "3", null, null, null, 1, 3, "1") :: + Row(3, 4, "3", null, null, null, 4, 6, "5") :: Nil) + } }