From 4f528770ecf4a2ae780d6514fdc8c5e7cf899288 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Mon, 3 Aug 2015 22:33:59 -0700 Subject: [PATCH 01/23] Eliminate outer join before project --- .../sql/catalyst/optimizer/Optimizer.scala | 25 ++++++ ...EliminateOuterJoinBeforeProjectSuite.scala | 82 +++++++++++++++++++ 2 files changed, 107 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateOuterJoinBeforeProjectSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 29d706dcb39a7..95f6184cbff88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -49,6 +49,7 @@ object DefaultOptimizer extends Optimizer { ColumnPruning, // Operator combine ProjectCollapsing, + EliminateOuterJoinBeforeProject, CombineFilters, CombineLimits, // Constant folding @@ -266,6 +267,30 @@ object ProjectCollapsing extends Rule[LogicalPlan] { } } +/** + * Eliminates [[LeftOuter]] and [[RightOuter]] joins when followed by a [[Project]] that keeps only + * the left or right columns, respectively. + */ +object EliminateOuterJoinBeforeProject extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case p @ Project(projectList, + j @ Join(left, right, joinType @ (LeftOuter | RightOuter), condition)) => + val projectReferences = AttributeSet(projectList) + + val child = joinType match { + case LeftOuter => left + case RightOuter => right + } + val joinList = child.outputSet + + if (projectReferences.subsetOf(joinList)) { + Project(projectList, child) + } else { + p + } + } +} + /** * Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition. * For example, when the expression is just checking to see if a string starts with a given diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateOuterJoinBeforeProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateOuterJoinBeforeProjectSuite.scala new file mode 100644 index 0000000000000..805ea3003a748 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateOuterJoinBeforeProjectSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.RightOuter +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class EliminateOuterJoinBeforeProjectSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", FixedPoint(10), EliminateSubQueries) :: + Batch("EliminateOuterJoinBeforeProject", Once, EliminateOuterJoinBeforeProject) :: Nil + } + + val testRelation1 = LocalRelation('a.int, 'b.int) + val testRelation2 = LocalRelation('c.int, 'd.int) + + test("collapse left outer join followed by subset project") { + val query = testRelation1 + .join(testRelation2, LeftOuter, Some('a === 'c)) + .select('a, 'b) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation1.select('a, 'b).analyze + + comparePlans(optimized, correctAnswer) + } + + test("collapse right outer join followed by subset project") { + val query = testRelation1 + .join(testRelation2, RightOuter, Some('a === 'c)) + .select('c, 'd) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation2.select('c, 'd).analyze + + comparePlans(optimized, correctAnswer) + } + + test("collapse outer join followed by subset project with expressions") { + val query = testRelation1 + .join(testRelation2, LeftOuter, Some('a === 'c)) + .select(('a + 1).as('a), ('b + 2).as('b)) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation1.select(('a + 1).as('a), ('b + 2).as('b)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("do not collapse non-subset project") { + val query = testRelation1 + .join(testRelation2, LeftOuter, Some('a === 'c)) + .select('a, 'b, ('c + 1).as('c), 'd) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = query.analyze + + comparePlans(optimized, correctAnswer) + } +} From ae46ab0891e974f6491d4b266f08d95d7a1c1382 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Wed, 12 Aug 2015 13:15:50 -0700 Subject: [PATCH 02/23] Use KeyHint to do join elimination --- .../sql/catalyst/optimizer/Optimizer.scala | 38 +++++++++---------- .../catalyst/plans/logical/LogicalPlan.scala | 2 + .../plans/logical/basicOperators.scala | 18 +++++++++ ...Suite.scala => JoinEliminationSuite.scala} | 29 ++++++++++---- .../org/apache/spark/sql/DataFrame.scala | 6 +++ .../spark/sql/execution/SparkStrategies.scala | 1 + .../org/apache/spark/sql/DataFrameSuite.scala | 4 ++ 7 files changed, 71 insertions(+), 27 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{EliminateOuterJoinBeforeProjectSuite.scala => JoinEliminationSuite.scala} (75%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 95f6184cbff88..8df2041929ae0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -21,6 +21,7 @@ import scala.collection.immutable.HashSet import 
org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.FullOuter import org.apache.spark.sql.catalyst.plans.LeftOuter import org.apache.spark.sql.catalyst.plans.RightOuter @@ -49,7 +50,7 @@ object DefaultOptimizer extends Optimizer { ColumnPruning, // Operator combine ProjectCollapsing, - EliminateOuterJoinBeforeProject, + JoinElimination, CombineFilters, CombineLimits, // Constant folding @@ -268,27 +269,24 @@ object ProjectCollapsing extends Rule[LogicalPlan] { } /** - * Eliminates [[LeftOuter]] and [[RightOuter]] joins when followed by a [[Project]] that keeps only - * the left or right columns, respectively. + * Eliminates keyed equi-joins when followed by a [[Project]] that keeps only columns from one side. */ -object EliminateOuterJoinBeforeProject extends Rule[LogicalPlan] { +object JoinElimination extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p @ Project(projectList, - j @ Join(left, right, joinType @ (LeftOuter | RightOuter), condition)) => - val projectReferences = AttributeSet(projectList) - - val child = joinType match { - case LeftOuter => left - case RightOuter => right - } - val joinList = child.outputSet - - if (projectReferences.subsetOf(joinList)) { - Project(projectList, child) - } else { - p - } - } + // Left outer join where only the left columns are kept, and a key from the right is involved in + // the join so no duplicates are generated + case Project(projectList, ExtractEquiJoinKeys(LeftOuter, _, rightKeys, _, left, right)) + if AttributeSet(projectList).subsetOf(left.outputSet) + && AttributeSet(right.keys).intersect(AttributeSet(rightKeys)).nonEmpty => + Project(projectList, left) + + // Right outer join where only the right columns are kept, and a key from the left is involved in + // the join so no duplicates are generated + case Project(projectList, ExtractEquiJoinKeys(RightOuter, leftKeys, _, _, left, right)) + if AttributeSet(projectList).subsetOf(right.outputSet) + && AttributeSet(left.keys).intersect(AttributeSet(leftKeys)).nonEmpty => + Project(projectList, right) +} } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index bedeaf06adf12..d7d78e160fd53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.trees.TreeNode abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { + def keys: Seq[Expression] = Seq.empty + /** * Computes [[Statistics]] for this plan. The default implementation assumes the output * cardinality is the product of of all child plan's cardinality, i.e. 
applies in the case diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index aacfc86ab0e49..326b39d260a6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -23,9 +23,18 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet +case class KeyHint( + override val keys: Seq[Expression], + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) + override def keys: Seq[Expression] = + child.keys.filter(k => projectList.exists(_.semanticEquals(k))) + override lazy val resolved: Boolean = { val hasSpecialExpressions = projectList.exists ( _.collect { case agg: AggregateExpression => agg @@ -87,6 +96,7 @@ case class Generate( case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + override def keys: Seq[Expression] = child.keys } case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { @@ -140,12 +150,16 @@ case class Join( */ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + + override def keys: Seq[Expression] = child.keys } case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output + override def keys: Seq[Expression] = left.keys + override lazy val resolved: Boolean = childrenResolved && left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } @@ -199,6 +213,8 @@ case class Sort( child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + override def keys: Seq[Expression] = child.keys + def hasNoEvaluation: Boolean = order.forall(_.child.isInstanceOf[AttributeReference]) override lazy val resolved: Boolean = @@ -416,6 +432,7 @@ case class Sample( */ case class Distinct(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + override def keys: Seq[Expression] = child.keys } /** @@ -427,6 +444,7 @@ case class Distinct(child: LogicalPlan) extends UnaryNode { case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + override def keys: Seq[Expression] = child.keys } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateOuterJoinBeforeProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala similarity index 75% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateOuterJoinBeforeProjectSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala index 805ea3003a748..f7f0b8be53985 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateOuterJoinBeforeProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala @@ -18,27 
+18,31 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.LeftOuter import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.RightOuter -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.KeyHint +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor -class EliminateOuterJoinBeforeProjectSuite extends PlanTest { +class JoinEliminationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", FixedPoint(10), EliminateSubQueries) :: - Batch("EliminateOuterJoinBeforeProject", Once, EliminateOuterJoinBeforeProject) :: Nil + Batch("JoinElimination", Once, JoinElimination) :: Nil } val testRelation1 = LocalRelation('a.int, 'b.int) val testRelation2 = LocalRelation('c.int, 'd.int) + val testRelation1K = KeyHint(List(testRelation1.output.head), testRelation1) + val testRelation2K = KeyHint(List(testRelation2.output.head), testRelation2) test("collapse left outer join followed by subset project") { val query = testRelation1 - .join(testRelation2, LeftOuter, Some('a === 'c)) + .join(testRelation2K, LeftOuter, Some('a === 'c)) .select('a, 'b) val optimized = Optimize.execute(query.analyze) @@ -48,7 +52,7 @@ class EliminateOuterJoinBeforeProjectSuite extends PlanTest { } test("collapse right outer join followed by subset project") { - val query = testRelation1 + val query = testRelation1K .join(testRelation2, RightOuter, Some('a === 'c)) .select('c, 'd) @@ -60,7 +64,7 @@ class EliminateOuterJoinBeforeProjectSuite extends PlanTest { test("collapse outer join followed by subset project with expressions") { val query = testRelation1 - .join(testRelation2, LeftOuter, Some('a === 'c)) + .join(testRelation2K, LeftOuter, Some('a === 'c)) .select(('a + 1).as('a), ('b + 2).as('b)) val optimized = Optimize.execute(query.analyze) @@ -70,6 +74,17 @@ class EliminateOuterJoinBeforeProjectSuite extends PlanTest { } test("do not collapse non-subset project") { + val query = testRelation1 + .join(testRelation2K, LeftOuter, Some('a === 'c)) + .select('a, 'b, ('c + 1).as('c), 'd) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = query.analyze + + comparePlans(optimized, correctAnswer) + } + + test("do not collapse non-keyed join") { val query = testRelation1 .join(testRelation2, LeftOuter, Some('a === 'c)) .select('a, 'b, ('c + 1).as('c), 'd) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 3ea0f9ed3bddd..a75734cf785a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -666,6 +666,12 @@ class DataFrame private[sql]( */ def as(alias: Symbol): DataFrame = as(alias.name) + @scala.annotation.varargs + def key(cols: Column*): DataFrame = KeyHint(cols.map(_.expr), logicalPlan) + + @scala.annotation.varargs + def key(col: String, cols: String*): DataFrame = key((col +: cols).map(Column(_)) : _*) + /** * Selects a set of column based expressions. 
* {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4aff52d992e6b..0560011d37071 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -415,6 +415,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil case BroadcastHint(child) => apply(child) + case logical.KeyHint(_, child) => apply(child) case _ => Nil } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index aef940a526675..e970375d2a909 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -98,6 +98,10 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { testData.collect().toSeq) } + test("key hint") { + checkAnswer(testData.key("key"), testData.collect().toSeq) + } + test("empty data frame") { assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) assert(sqlContext.emptyDataFrame.count() === 0) From df9ef1421cee2f8f94dac24a8116ad504a009a20 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Wed, 12 Aug 2015 16:25:30 -0700 Subject: [PATCH 03/23] Add foreign keys --- .../sql/catalyst/optimizer/Optimizer.scala | 48 ++++++++++++++----- .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../plans/logical/basicOperators.scala | 25 +++++----- .../sql/catalyst/plans/logical/keys.scala | 24 ++++++++++ .../optimizer/JoinEliminationSuite.scala | 21 ++++++-- .../org/apache/spark/sql/DataFrame.scala | 11 +++-- .../org/apache/spark/sql/DataFrameSuite.scala | 11 ++++- 7 files changed, 109 insertions(+), 33 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/keys.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 8df2041929ae0..821091b45747c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -275,18 +275,42 @@ object JoinElimination extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Left outer join where only the left columns are kept, and a key from the right is involved in // the join so no duplicates are generated - case Project(projectList, ExtractEquiJoinKeys(LeftOuter, _, rightKeys, _, left, right)) - if AttributeSet(projectList).subsetOf(left.outputSet) - && AttributeSet(right.keys).intersect(AttributeSet(rightKeys)).nonEmpty => - Project(projectList, left) - - // Right outer join where only the right columns are kept, and a key from the left is involved in - // the join so no duplicates are generated - case Project(projectList, ExtractEquiJoinKeys(RightOuter, leftKeys, _, _, left, right)) - if AttributeSet(projectList).subsetOf(right.outputSet) - && AttributeSet(left.keys).intersect(AttributeSet(leftKeys)).nonEmpty => - Project(projectList, right) -} + case p @ Project(projectList, ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, _, left, right)) => + // A 
unique key from the right must be involved in the join so no duplicates are generated + val rightUniqueKeys = AttributeSet(right.keys.collect { case UniqueKey(attr) => attr }) + val noDups = rightUniqueKeys.intersect(AttributeSet(rightKeys)).nonEmpty + + // Only the left columns (or right columns that are equal to left columns via an equijoin + // predicate) may be kept + val onlyLeftCols = AttributeSet(projectList).subsetOf(left.outputSet) + val onlyLeftOrEquiCols = + AttributeSet(projectList).subsetOf(left.outputSet ++ AttributeSet(rightKeys)) + + // If right columns were kept, they must be referenced by a foreign key constraint on the + // corresponding left column so no nonexistent values are generated + val aliasMap = AttributeMap(rightKeys.zip(leftKeys).collect { + case (a: NamedExpression, b: NamedExpression) => (a.toAttribute, b.toAttribute) + }) + val foreignKeyConstraintsSatisfied = projectList.flatMap(_.collect { + case a: Attribute if aliasMap.contains(a) && left.keys.collect { + case ForeignKey(leftAttr, referencedAttr) + if leftAttr == aliasMap(a) && referencedAttr == a => true + }.nonEmpty => true + }).nonEmpty + + if (onlyLeftCols) { + Project(projectList, left) + } else if (onlyLeftOrEquiCols && foreignKeyConstraintsSatisfied) { + val substitutedProjection = projectList.map(_.transform { + case a: Attribute => + if (aliasMap.contains(a)) Alias(aliasMap(a), a.name)() + else a + }).asInstanceOf[Seq[NamedExpression]] + Project(substitutedProjection, left) + } else { + p + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index d7d78e160fd53..8a0cd3a20e6c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.trees.TreeNode abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { - def keys: Seq[Expression] = Seq.empty + def keys: Seq[Key] = Seq.empty /** * Computes [[Statistics]] for this plan. 
The default implementation assumes the output diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 326b39d260a6d..560080a00a4fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet case class KeyHint( - override val keys: Seq[Expression], + override val keys: Seq[Key], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } @@ -32,8 +32,18 @@ case class KeyHint( case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) - override def keys: Seq[Expression] = - child.keys.filter(k => projectList.exists(_.semanticEquals(k))) + override def keys: Seq[Key] = { + val aliasMap = AttributeMap(projectList.collect { + case a @ Alias(old: AttributeReference, _) => (old, a.toAttribute) + case r: AttributeReference => (r, r) + }) + child.keys.collect { + case UniqueKey(attr) if aliasMap.contains(attr) => + UniqueKey(aliasMap(attr)) + case ForeignKey(attr, referencedAttr) if aliasMap.contains(attr) => + ForeignKey(aliasMap(attr), referencedAttr) + } + } override lazy val resolved: Boolean = { val hasSpecialExpressions = projectList.exists ( _.collect { @@ -96,7 +106,6 @@ case class Generate( case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - override def keys: Seq[Expression] = child.keys } case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { @@ -150,16 +159,12 @@ case class Join( */ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - - override def keys: Seq[Expression] = child.keys } case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output - override def keys: Seq[Expression] = left.keys - override lazy val resolved: Boolean = childrenResolved && left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } @@ -213,8 +218,6 @@ case class Sort( child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - override def keys: Seq[Expression] = child.keys - def hasNoEvaluation: Boolean = order.forall(_.child.isInstanceOf[AttributeReference]) override lazy val resolved: Boolean = @@ -432,7 +435,6 @@ case class Sample( */ case class Distinct(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - override def keys: Seq[Expression] = child.keys } /** @@ -444,7 +446,6 @@ case class Distinct(child: LogicalPlan) extends UnaryNode { case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - override def keys: Seq[Expression] = child.keys } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/keys.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/keys.scala new file mode 100644 index 0000000000000..dbfbef4b26749 --- /dev/null +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/keys.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute + +sealed trait Key +case class UniqueKey(attr: Attribute) extends Key +case class ForeignKey(attr: Attribute, referencedAttr: Attribute) extends Key diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala index f7f0b8be53985..132a668f191e3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala @@ -23,9 +23,11 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.LeftOuter import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.RightOuter +import org.apache.spark.sql.catalyst.plans.logical.ForeignKey import org.apache.spark.sql.catalyst.plans.logical.KeyHint import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.UniqueKey import org.apache.spark.sql.catalyst.rules.RuleExecutor class JoinEliminationSuite extends PlanTest { @@ -37,8 +39,10 @@ class JoinEliminationSuite extends PlanTest { val testRelation1 = LocalRelation('a.int, 'b.int) val testRelation2 = LocalRelation('c.int, 'd.int) - val testRelation1K = KeyHint(List(testRelation1.output.head), testRelation1) - val testRelation2K = KeyHint(List(testRelation2.output.head), testRelation2) + val testRelation1K = KeyHint(List(UniqueKey(testRelation1.output.head)), testRelation1) + val testRelation2K = KeyHint(List(UniqueKey(testRelation2.output.head)), testRelation2) + val testRelation3 = LocalRelation('e.int, 'f.int) + val testRelation3K = KeyHint(List(ForeignKey(testRelation3.output.head, testRelation1.output.head)), testRelation3) test("collapse left outer join followed by subset project") { val query = testRelation1 @@ -51,7 +55,7 @@ class JoinEliminationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("collapse right outer join followed by subset project") { + ignore("collapse right outer join followed by subset project") { val query = testRelation1K .join(testRelation2, RightOuter, Some('a === 'c)) .select('c, 'd) @@ -73,6 +77,17 @@ class JoinEliminationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("collapse outer join with cross-table aliasing") { + val query = testRelation3K 
+ .join(testRelation1, LeftOuter, Some('e === 'a)) + .select('a, 'f) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation3K.select('e.as('a), 'f).analyze + + comparePlans(optimized, correctAnswer) + } + test("do not collapse non-subset project") { val query = testRelation1 .join(testRelation2K, LeftOuter, Some('a === 'c)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index a75734cf785a4..c2ed23f612aea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -666,11 +666,14 @@ class DataFrame private[sql]( */ def as(alias: Symbol): DataFrame = as(alias.name) - @scala.annotation.varargs - def key(cols: Column*): DataFrame = KeyHint(cols.map(_.expr), logicalPlan) + def uniqueKey(col: String): DataFrame = + KeyHint(List(UniqueKey(logicalPlan.output.find(_.name == col).get)), logicalPlan) - @scala.annotation.varargs - def key(col: String, cols: String*): DataFrame = key((col +: cols).map(Column(_)) : _*) + def foreignKey(col: String, referencedTable: DataFrame, referencedCol: String): DataFrame = { + val colAttr = logicalPlan.output.find(_.name == col).get + val referencedAttr = referencedTable.logicalPlan.output.find(_.name == referencedCol).get + KeyHint(List(ForeignKey(colAttr, referencedAttr)), logicalPlan) + } /** * Selects a set of column based expressions. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e970375d2a909..d1c43fedd629f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -99,7 +99,16 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("key hint") { - checkAnswer(testData.key("key"), testData.collect().toSeq) + import org.apache.spark.sql.catalyst.plans.logical.Join + + val personK = person.uniqueKey("id") + val salaryK = salary.foreignKey("personId", personK, "id") + val salaries = salaryK.join(personK, salaryK("personId") === person("id"), "left_outer") + .select(person("id"), salary("salary")) + checkAnswer(salaries, salary.collect().toSeq) + assert(salaries.queryExecution.optimizedPlan.collect { + case j: Join => j + }.isEmpty) } test("empty data frame") { From b22f7025860fed1b3f7bd5147691f5ef887bca01 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Wed, 12 Aug 2015 19:49:26 -0700 Subject: [PATCH 04/23] Alias-aware join elimination + bugfixes --- .../sql/catalyst/optimizer/Optimizer.scala | 49 ++++++++++++++++--- .../plans/logical/basicOperators.scala | 6 +-- .../optimizer/JoinEliminationSuite.scala | 15 +++++- .../org/apache/spark/sql/DataFrameSuite.scala | 5 +- 4 files changed, 60 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 821091b45747c..5c87d84174ac8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -291,19 +291,20 @@ object JoinElimination extends Rule[LogicalPlan] { val aliasMap = AttributeMap(rightKeys.zip(leftKeys).collect { case (a: NamedExpression, b: NamedExpression) => (a.toAttribute, b.toAttribute) }) - val 
foreignKeyConstraintsSatisfied = projectList.flatMap(_.collect { - case a: Attribute if aliasMap.contains(a) && left.keys.collect { + val eq = equivalences(right) + val foreignKeyConstraintsSatisfied = AttributeSet(projectList).forall(a => + !aliasMap.contains(a) || left.keys.exists { case ForeignKey(leftAttr, referencedAttr) - if leftAttr == aliasMap(a) && referencedAttr == a => true - }.nonEmpty => true - }).nonEmpty + if leftAttr == aliasMap(a) && eq.query(referencedAttr, a) => true + case _ => false + }) - if (onlyLeftCols) { + if (noDups && onlyLeftCols) { Project(projectList, left) - } else if (onlyLeftOrEquiCols && foreignKeyConstraintsSatisfied) { + } else if (noDups && onlyLeftOrEquiCols && foreignKeyConstraintsSatisfied) { val substitutedProjection = projectList.map(_.transform { case a: Attribute => - if (aliasMap.contains(a)) Alias(aliasMap(a), a.name)() + if (aliasMap.contains(a)) Alias(aliasMap(a), a.name)(a.exprId) else a }).asInstanceOf[Seq[NamedExpression]] Project(substitutedProjection, left) @@ -311,6 +312,38 @@ object JoinElimination extends Rule[LogicalPlan] { p } } + + private def equivalences(plan: LogicalPlan): MutableDisjointSet[Attribute] = { + val s = new MutableDisjointSet[Attribute] + plan.collect { + case Project(projectList, _) => projectList.collect { + case a @ Alias(old: Attribute, _) => s.union(old, a.toAttribute) + } + } + s + } + + private class MutableDisjointSet[A]() { + import scala.collection.mutable.Set + private var sets = Set[Set[A]]() + def add(x: A): Unit = { + if (!sets.exists(_.contains(x))) { + sets += Set(x) + } + } + def union(x: A, y: A): Unit = { + add(x) + add(y) + val xSet = sets.find(_.contains(x)).get + val ySet = sets.find(_.contains(y)).get + sets -= xSet + sets -= ySet + sets += (xSet ++ ySet) + } + def query(x: A, y: A): Boolean = { + x == y || sets.exists(s => s.contains(x) && s.contains(y)) + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 560080a00a4fb..fedc663876922 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -23,10 +23,10 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet -case class KeyHint( - override val keys: Seq[Key], - child: LogicalPlan) extends UnaryNode { +case class KeyHint(newKeys: Seq[Key], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + + override def keys: Seq[Key] = newKeys ++ child.keys } case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala index 132a668f191e3..5c45d50a5a695 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala @@ -77,9 +77,9 @@ class JoinEliminationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("collapse outer join with cross-table aliasing") { + test("collapse outer join with foreign key") { val query = 
testRelation3K - .join(testRelation1, LeftOuter, Some('e === 'a)) + .join(testRelation1K, LeftOuter, Some('e === 'a)) .select('a, 'f) val optimized = Optimize.execute(query.analyze) @@ -88,6 +88,17 @@ class JoinEliminationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("collapse outer join with foreign key despite alias") { + val query = testRelation3K + .join(testRelation1K.select('a.as('g), 'b), LeftOuter, Some('e === 'g)) + .select('g, 'f) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation3K.select('e.as('g), 'f).analyze + + comparePlans(optimized, correctAnswer) + } + test("do not collapse non-subset project") { val query = testRelation1 .join(testRelation2K, LeftOuter, Some('a === 'c)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d1c43fedd629f..a9c3fca9aff4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -102,9 +102,10 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { import org.apache.spark.sql.catalyst.plans.logical.Join val personK = person.uniqueKey("id") + val personKA = personK.select($"id".as("p_id")) val salaryK = salary.foreignKey("personId", personK, "id") - val salaries = salaryK.join(personK, salaryK("personId") === person("id"), "left_outer") - .select(person("id"), salary("salary")) + val salaries = salaryK.join(personKA, salaryK("personId") === personKA("p_id"), "left_outer") + .select(personKA("p_id"), salary("salary")) checkAnswer(salaries, salary.collect().toSeq) assert(salaries.queryExecution.optimizedPlan.collect { case j: Join => j From 9072cb70872b156027cb2e673a397cc01f326128 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Wed, 12 Aug 2015 20:22:55 -0700 Subject: [PATCH 05/23] Propagate foreign keys through Join operator --- .../sql/catalyst/plans/logical/basicOperators.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index fedc663876922..8c27b70bf6f3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -143,6 +143,16 @@ case class Join( } } + override def keys: Seq[Key] = { + // TODO: try to propagate unique keys as well as foreign keys + def fk(keys: Seq[Key]): Seq[ForeignKey] = keys.collect { case k: ForeignKey => k } + joinType match { + case LeftSemi | LeftOuter => fk(left.keys) + case RightOuter => fk(right.keys) + case _ => fk(left.keys ++ right.keys) + } + } + def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty // Joins are only resolved if they don't introduce ambiguous expression ids. 
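
An end-to-end sketch of how the hints introduced so far are meant to compose (hypothetical
code, not part of any patch; it mirrors the updated "key hint" test in DataFrameSuite and
assumes DataFrames person(id, name) and salary(personId, salary)):

    // Declare person.id unique, and salary.personId a foreign key referencing it.
    val personK = person.uniqueKey("id")
    val salaryK = salary.foreignKey("personId", personK, "id")

    // The projection keeps only salary's columns plus the referenced key, so
    // JoinElimination can drop the join: the plan reads salaryK alone and
    // substitutes salaryK("personId") for personK("id") in the project list.
    val salaries = salaryK
      .join(personK, salaryK("personId") === personK("id"), "left_outer")
      .select(personK("id"), salaryK("salary"))
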
From f430ea2c6413879403973fc4fdd4217dde9d27ec Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Wed, 12 Aug 2015 20:43:06 -0700 Subject: [PATCH 06/23] Remove key hints after join elimination --- .../sql/catalyst/optimizer/Optimizer.scala | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 5c87d84174ac8..36e5f9fc15c85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -33,14 +33,7 @@ import org.apache.spark.sql.types._ abstract class Optimizer extends RuleExecutor[LogicalPlan] object DefaultOptimizer extends Optimizer { - val batches = - // SubQueries are only needed for analysis and can be removed before execution. - Batch("Remove SubQueries", FixedPoint(100), - EliminateSubQueries) :: - Batch("Aggregate", FixedPoint(100), - ReplaceDistinctWithAggregate, - RemoveLiteralFromGroupExpressions) :: - Batch("Operator Optimizations", FixedPoint(100), + val operatorOptimizations = Seq( // Operator push down SetOperationPushDown, SamplePushDown, @@ -62,7 +55,19 @@ object DefaultOptimizer extends Optimizer { RemovePositive, SimplifyFilters, SimplifyCasts, - SimplifyCaseConversionExpressions) :: + SimplifyCaseConversionExpressions) + + val batches = + // SubQueries are only needed for analysis and can be removed before execution. + Batch("Remove SubQueries", FixedPoint(100), + EliminateSubQueries) :: + Batch("Aggregate", FixedPoint(100), + ReplaceDistinctWithAggregate, + RemoveLiteralFromGroupExpressions) :: + Batch("Operator Optimizations", FixedPoint(100), + operatorOptimizations: _*) :: + Batch("Remove Hints", FixedPoint(100), + (RemoveKeyHints +: operatorOptimizations): _*) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: Batch("LocalRelation", FixedPoint(100), @@ -346,6 +351,12 @@ object JoinElimination extends Rule[LogicalPlan] { } } +object RemoveKeyHints extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case KeyHint(_, child) => child + } +} + /** * Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition. 
* For example, when the expression is just checking to see if a string starts with a given From 130253101f2db627c42ea4f8759dfeef6c62e574 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Sun, 16 Aug 2015 18:55:36 -0700 Subject: [PATCH 07/23] Support inner joins based on referential integrity --- .../sql/catalyst/optimizer/Optimizer.scala | 95 +++------- .../optimizer/joinEliminationPatterns.scala | 177 ++++++++++++++++++ .../optimizer/JoinEliminationSuite.scala | 64 ++++++- 3 files changed, 267 insertions(+), 69 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 36e5f9fc15c85..62dfdc6fcc267 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -274,81 +274,40 @@ object ProjectCollapsing extends Rule[LogicalPlan] { } /** - * Eliminates keyed equi-joins when followed by a [[Project]] that keeps only columns from one side. + * Eliminates keyed equi-joins when followed by a [[Project]] that only keeps columns from one side. + * + * See [[http://www.info.teradata.com/HTMLPubs/DB_TTU_14_00/index.html#page/SQL_Reference/B035_1142_111A/ch02.124.042.html#ww17434326]]. */ object JoinElimination extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Left outer join where only the left columns are kept, and a key from the right is involved in - // the join so no duplicates are generated - case p @ Project(projectList, ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, _, left, right)) => - // A unique key from the right must be involved in the join so no duplicates are generated - val rightUniqueKeys = AttributeSet(right.keys.collect { case UniqueKey(attr) => attr }) - val noDups = rightUniqueKeys.intersect(AttributeSet(rightKeys)).nonEmpty - - // Only the left columns (or right columns that are equal to left columns via an equijoin - // predicate) may be kept - val onlyLeftCols = AttributeSet(projectList).subsetOf(left.outputSet) - val onlyLeftOrEquiCols = - AttributeSet(projectList).subsetOf(left.outputSet ++ AttributeSet(rightKeys)) - - // If right columns were kept, they must be referenced by a foreign key constraint on the - // corresponding left column so no nonexistent values are generated - val aliasMap = AttributeMap(rightKeys.zip(leftKeys).collect { - case (a: NamedExpression, b: NamedExpression) => (a.toAttribute, b.toAttribute) - }) - val eq = equivalences(right) - val foreignKeyConstraintsSatisfied = AttributeSet(projectList).forall(a => - !aliasMap.contains(a) || left.keys.exists { - case ForeignKey(leftAttr, referencedAttr) - if leftAttr == aliasMap(a) && eq.query(referencedAttr, a) => true - case _ => false - }) - - if (noDups && onlyLeftCols) { - Project(projectList, left) - } else if (noDups && onlyLeftOrEquiCols && foreignKeyConstraintsSatisfied) { - val substitutedProjection = projectList.map(_.transform { - case a: Attribute => - if (aliasMap.contains(a)) Alias(aliasMap(a), a.name)(a.exprId) - else a - }).asInstanceOf[Seq[NamedExpression]] - Project(substitutedProjection, left) - } else { - p - } + // Outer join where only the outer table's columns are kept, and a key from the inner table is + // involved in the join so no duplicates would be 
generated + case CanEliminateUniqueKeyOuterJoin(outer, projectList) => + Project(projectList, outer) + + // Any kind of join based on referential integrity + case CanEliminateReferentialIntegrityEquiJoin( + _, parent, child, primaryForeignMap, projectList) => + Project(substituteParentForChild(projectList, parent, primaryForeignMap), child) } - private def equivalences(plan: LogicalPlan): MutableDisjointSet[Attribute] = { - val s = new MutableDisjointSet[Attribute] - plan.collect { - case Project(projectList, _) => projectList.collect { - case a @ Alias(old: Attribute, _) => s.union(old, a.toAttribute) - } - } - s + /** + * In the given expressions, substitute all references to parent columns with references to the + * corresponding child columns. The `primaryForeignMap` contains these equivalences, extracted + * from the equality join expressions. + */ + private def substituteParentForChild( + expressions: Seq[NamedExpression], + parent: LogicalPlan, + primaryForeignMap: AttributeMap[Attribute]) + : Seq[NamedExpression] = { + expressions.map(_.transform { + case a: Attribute => + if (parent.outputSet.contains(a)) Alias(primaryForeignMap(a), a.name)(a.exprId) + else a + }.asInstanceOf[NamedExpression]) } - private class MutableDisjointSet[A]() { - import scala.collection.mutable.Set - private var sets = Set[Set[A]]() - def add(x: A): Unit = { - if (!sets.exists(_.contains(x))) { - sets += Set(x) - } - } - def union(x: A, y: A): Unit = { - add(x) - add(y) - val xSet = sets.find(_.contains(x)).get - val ySet = sets.find(_.contains(y)).get - sets -= xSet - sets -= ySet - sets += (xSet ++ ySet) - } - def query(x: A, y: A): Boolean = { - x == y || sets.exists(s => s.contains(x) && s.contains(y)) - } - } } object RemoveKeyHints extends Rule[LogicalPlan] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala new file mode 100644 index 0000000000000..9cd7859587be1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.catalyst.plans.RightOuter +import org.apache.spark.sql.catalyst.plans.logical._ + +/** + * Finds outer joins where only the outer table's columns are kept, and a key from the inner table + * is involved in the join so no duplicates would be generated. 
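+ *
+ * For example (a hypothetical sketch, assuming `inner` carries a `UniqueKey` on `'id`):
+ * {{{
+ *   outer.join(inner, LeftOuter, Some('fk === 'id)).select('outerCol)
+ * }}}
+ * keeps only `outer`'s columns and joins on the inner unique key, so it can be rewritten to
+ * {{{
+ *   outer.select('outerCol)
+ * }}}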
+ */ +object CanEliminateUniqueKeyOuterJoin { + /** (outer, projectList) */ + type ReturnType = (LogicalPlan, Seq[NamedExpression]) + + def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { + case p @ Project(projectList, + ExtractEquiJoinKeys( + joinType @ (LeftOuter | RightOuter), leftJoinExprs, rightJoinExprs, _, left, right)) => + val (outer, inner, innerJoinExprs) = (joinType: @unchecked) match { + case LeftOuter => (left, right, rightJoinExprs) + case RightOuter => (right, left, leftJoinExprs) + } + + val onlyOuterColsKept = AttributeSet(projectList).subsetOf(outer.outputSet) + + val innerUniqueKeys = AttributeSet(inner.keys.collect { case UniqueKey(attr) => attr }) + val innerKeyIsInvolved = innerUniqueKeys.intersect(AttributeSet(innerJoinExprs)).nonEmpty + + if (onlyOuterColsKept && innerKeyIsInvolved) { + Some((outer, projectList)) + } else { + None + } + + case _ => None + } +} + +/** + * Finds equijoins based on foreign-key referential integrity, followed by [[Project]]s that + * reference no columns from the parent table other than the referenced unique keys. + * + * The table containing the foreign key is referred to as the child table, while the table + * containing the referenced unique key is referred to as the parent table. Such equijoins can be + * eliminated and replaced by the child table. + * + * See [[http://www.info.teradata.com/HTMLPubs/DB_TTU_14_00/index.html#page/SQL_Reference/B035_1142_111A/ch02.124.045.html]]. + */ +object CanEliminateReferentialIntegrityEquiJoin { + /** (joinType, parent, child, primaryForeignMap, projectList) */ + type ReturnType = + (JoinType, LogicalPlan, LogicalPlan, AttributeMap[Attribute], Seq[NamedExpression]) + + def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { + case Project(projectList, + ExtractEquiJoinKeys(joinType, leftJoinExprs, rightJoinExprs, _, left, right)) => + val leftParentPFM = getPrimaryForeignMap(left, right, leftJoinExprs, rightJoinExprs) + val leftIsParent = + leftParentPFM.nonEmpty && onlyPrimaryKeysKept(projectList, leftParentPFM, left) + val rightParentPFM = getPrimaryForeignMap(right, left, rightJoinExprs, leftJoinExprs) + val rightIsParent = + rightParentPFM.nonEmpty && onlyPrimaryKeysKept(projectList, rightParentPFM, right) + + if (leftIsParent) { + Some((joinType, left, right, leftParentPFM, projectList)) + } else if (rightIsParent) { + Some((joinType, right, left, rightParentPFM, projectList)) + } else { + None + } + + case _ => None + } + + /** + * Return a map where, for each PK=FK join expression based on referential integrity between + * `parent` and `child`, the unique key from `parent` is mapped to its corresponding foreign + * key from `child`. + */ + private def getPrimaryForeignMap( + parent: LogicalPlan, + child: LogicalPlan, + parentJoinExprs: Seq[Expression], + childJoinExprs: Seq[Expression]) + : AttributeMap[Attribute] = { + val primaryKeys = AttributeSet(parent.keys.collect { case UniqueKey(attr) => attr }) + val foreignKeys = new ForeignKeyFinder(child, parent) + AttributeMap(parentJoinExprs.zip(childJoinExprs).collect { + case (parentExpr: NamedExpression, childExpr: NamedExpression) + if primaryKeys.contains(parentExpr.toAttribute) + && foreignKeys.foreignKeyExists(childExpr.toAttribute, parentExpr.toAttribute) => + (parentExpr.toAttribute, childExpr.toAttribute) + }) + } + + /** + * Return true if `kept` references no columns from `parent` except those involved in a PK=FK + * join expression. Such join expressions are stored in `primaryForeignMap`. 
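+ *
+ * For example, given the (hypothetical) equijoin `parent JOIN child ON parent.pk = child.fk`,
+ * keeping `parent.pk` is allowed because it can later be substituted by `child.fk`, whereas
+ * keeping any other `parent` column would prevent eliminating the join.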
+ */ + private def onlyPrimaryKeysKept( + kept: Seq[NamedExpression], + primaryForeignMap: AttributeMap[Attribute], + parent: LogicalPlan) + : Boolean = { + AttributeSet(kept).forall { keptAttr => + if (parent.outputSet.contains(keptAttr)) { + primaryForeignMap.contains(keptAttr) + } else { + true + } + } + } +} + +private class ForeignKeyFinder(plan: LogicalPlan, referencedPlan: LogicalPlan) { + val equivalent = equivalences(referencedPlan) + + def foreignKeyExists(attr: Attribute, referencedAttr: Attribute): Boolean = { + plan.keys.exists { + case ForeignKey(attr2, referencedAttr2) if attr == attr2 && equivalent.query(referencedAttr, referencedAttr2) => true + case _ => false + } + } + + private def equivalences(plan: LogicalPlan): MutableDisjointSet[Attribute] = { + val s = new MutableDisjointSet[Attribute] + plan.collect { + case Project(projectList, _) => projectList.collect { + case a @ Alias(old: Attribute, _) => s.union(old, a.toAttribute) + } + } + s + } + +} + +private class MutableDisjointSet[A]() { + import scala.collection.mutable.Set + private var sets = Set[Set[A]]() + def add(x: A): Unit = { + if (!sets.exists(_.contains(x))) { + sets += Set(x) + } + } + def union(x: A, y: A): Unit = { + add(x) + add(y) + val xSet = sets.find(_.contains(x)).get + val ySet = sets.find(_.contains(y)).get + sets -= xSet + sets -= ySet + sets += (xSet ++ ySet) + } + def query(x: A, y: A): Boolean = { + x == y || sets.exists(s => s.contains(x) && s.contains(y)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala index 5c45d50a5a695..d70b3b9922264 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.LeftOuter import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.RightOuter @@ -55,7 +56,7 @@ class JoinEliminationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - ignore("collapse right outer join followed by subset project") { + test("collapse right outer join followed by subset project") { val query = testRelation1K .join(testRelation2, RightOuter, Some('a === 'c)) .select('c, 'd) @@ -120,4 +121,65 @@ class JoinEliminationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("collapse join preceded by join") { + val query1 = testRelation3K + .join(testRelation1, LeftOuter, Some('e === 'a)) // will not be eliminated + val query2 = query1 + .join(testRelation2K, LeftOuter, Some('a === 'c)) // should be eliminated + .select('a, 'b, 'e, 'f) + + val optimized = Optimize.execute(query2.analyze) + val correctAnswer = query1.select('a, 'b, 'e, 'f).analyze + + comparePlans(optimized, correctAnswer) + } + + test("eliminate inner join - fk on right, no pk columns kept") { + val query = testRelation1K + .join(testRelation3K, Inner, Some('a === 'e)) + .select('e, 'f) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation3K.select('e, 'f).analyze + + 
comparePlans(optimized, correctAnswer) + } + + test("eliminate inner join - fk on right, pk columns kept") { + val query = testRelation1K + .join(testRelation3K, Inner, Some('a === 'e)) + .select('a, 'f) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation3K.select('e.as('a), 'f).analyze + + comparePlans(optimized, correctAnswer) + } + + test("eliminate inner join - fk on left, pk columns kept") { + val query = testRelation3K + .join(testRelation1K, Inner, Some('a === 'e)) + .select('a, 'f) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation3K.select('e.as('a), 'f).analyze + + comparePlans(optimized, correctAnswer) + } + + test("triangles") { + val e0 = LocalRelation('srcId.int, 'dstId.int) + val v0 = LocalRelation('id.int, 'attr.int) + val e = KeyHint(List(ForeignKey(e0.output.head, v0.output.head), ForeignKey(e0.output.last, v0.output.head)), e0) + val v = KeyHint(List(UniqueKey(v0.output.head)), v0) + + val query = e.join(v, LeftOuter, Some('dstId === 'id)).select('srcId, 'id) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = e.select('srcId, 'dstId.as('id)).analyze + + comparePlans(optimized, correctAnswer) + + } } From 35949f54c53357a86e0a2e2aeb0e5524a8285ce5 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Mon, 17 Aug 2015 23:38:30 -0700 Subject: [PATCH 08/23] Correctness fixes for join elimination Do not eliminate referential integrity full outer joins, or inner joins where foreign key is nullable. Require foreign keys to reference unique columns. --- .../sql/catalyst/optimizer/Optimizer.scala | 7 +- .../optimizer/joinEliminationPatterns.scala | 41 ++-- .../org/apache/spark/sql/DataFrame.scala | 7 + .../org/apache/spark/sql/KeyHintSuite.scala | 232 ++++++++++++++++++ 4 files changed, 265 insertions(+), 22 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 62dfdc6fcc267..0872318006acc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -280,14 +280,9 @@ object ProjectCollapsing extends Rule[LogicalPlan] { */ object JoinElimination extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Outer join where only the outer table's columns are kept, and a key from the inner table is - // involved in the join so no duplicates would be generated case CanEliminateUniqueKeyOuterJoin(outer, projectList) => Project(projectList, outer) - - // Any kind of join based on referential integrity - case CanEliminateReferentialIntegrityEquiJoin( - _, parent, child, primaryForeignMap, projectList) => + case CanEliminateReferentialIntegrityJoin(parent, child, primaryForeignMap, projectList) => Project(substituteParentForChild(projectList, parent, primaryForeignMap), child) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala index 9cd7859587be1..d1b153b1b3390 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala @@ -19,9 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys -import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.LeftOuter -import org.apache.spark.sql.catalyst.plans.RightOuter +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ /** @@ -58,33 +56,43 @@ object CanEliminateUniqueKeyOuterJoin { /** * Finds equijoins based on foreign-key referential integrity, followed by [[Project]]s that - * reference no columns from the parent table other than the referenced unique keys. + * reference no columns from the parent table other than the referenced unique keys. Such equijoins + * can be eliminated and replaced by the child table. * * The table containing the foreign key is referred to as the child table, while the table - * containing the referenced unique key is referred to as the parent table. Such equijoins can be - * eliminated and replaced by the child table. + * containing the referenced unique key is referred to as the parent table. + * + * For inner joins, all involved foreign keys must be non-nullable. * * See [[http://www.info.teradata.com/HTMLPubs/DB_TTU_14_00/index.html#page/SQL_Reference/B035_1142_111A/ch02.124.045.html]]. */ -object CanEliminateReferentialIntegrityEquiJoin { - /** (joinType, parent, child, primaryForeignMap, projectList) */ +object CanEliminateReferentialIntegrityJoin { + /** (parent, child, primaryForeignMap, projectList) */ type ReturnType = - (JoinType, LogicalPlan, LogicalPlan, AttributeMap[Attribute], Seq[NamedExpression]) + (LogicalPlan, LogicalPlan, AttributeMap[Attribute], Seq[NamedExpression]) def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { - case Project(projectList, - ExtractEquiJoinKeys(joinType, leftJoinExprs, rightJoinExprs, _, left, right)) => + case p @ Project(projectList, ExtractEquiJoinKeys( + joinType @ (Inner | LeftOuter | RightOuter), + leftJoinExprs, rightJoinExprs, _, left, right)) => + val innerJoin = joinType == Inner + val leftParentPFM = getPrimaryForeignMap(left, right, leftJoinExprs, rightJoinExprs) + val rightForeignKeysAreNonNullable = leftParentPFM.values.forall(!_.nullable) val leftIsParent = - leftParentPFM.nonEmpty && onlyPrimaryKeysKept(projectList, leftParentPFM, left) + (leftParentPFM.nonEmpty && onlyPrimaryKeysKept(projectList, leftParentPFM, left) + && (!innerJoin || rightForeignKeysAreNonNullable)) + val rightParentPFM = getPrimaryForeignMap(right, left, rightJoinExprs, leftJoinExprs) + val leftForeignKeysAreNonNullable = rightParentPFM.values.forall(!_.nullable) val rightIsParent = - rightParentPFM.nonEmpty && onlyPrimaryKeysKept(projectList, rightParentPFM, right) + (rightParentPFM.nonEmpty && onlyPrimaryKeysKept(projectList, rightParentPFM, right) + && (!innerJoin || leftForeignKeysAreNonNullable)) if (leftIsParent) { - Some((joinType, left, right, leftParentPFM, projectList)) + Some((left, right, leftParentPFM, projectList)) } else if (rightIsParent) { - Some((joinType, right, left, rightParentPFM, projectList)) + Some((right, left, rightParentPFM, projectList)) } else { None } @@ -137,7 +145,8 @@ private class ForeignKeyFinder(plan: LogicalPlan, referencedPlan: LogicalPlan) { def foreignKeyExists(attr: Attribute, referencedAttr: Attribute): Boolean = { plan.keys.exists { - case 
ForeignKey(attr2, referencedAttr2) if attr == attr2 && equivalent.query(referencedAttr, referencedAttr2) => true + case ForeignKey(attr2, referencedAttr2) + if attr == attr2 && equivalent.query(referencedAttr, referencedAttr2) => true case _ => false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index c2ed23f612aea..7f525f8272c8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -672,6 +672,13 @@ class DataFrame private[sql]( def foreignKey(col: String, referencedTable: DataFrame, referencedCol: String): DataFrame = { val colAttr = logicalPlan.output.find(_.name == col).get val referencedAttr = referencedTable.logicalPlan.output.find(_.name == referencedCol).get + val referencedAttrIsUnique = referencedTable.logicalPlan.keys.exists { + case UniqueKey(attr) if attr == referencedAttr => true + case _ => false + } + require(referencedAttrIsUnique, + s"Foreign keys can only reference unique keys, but $referencedAttr is not unique.\n" + + "Try calling referencedTable.uniqueKey(\"" + referencedAttr + "\").") KeyHint(List(ForeignKey(colAttr, referencedAttr)), logicalPlan) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala new file mode 100644 index 0000000000000..b68039a0aaa8a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.plans.logical.Join + +private object KeyHintTestData { + val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + case class Customer(id: Int, name: String) + case class Employee(id: Int, name: String) + case class Order(id: Int, customerId: Int, employeeId: Option[Int]) + case class Manager(managerId: Int, subordinateId: Int) + case class BannedCustomer(name: String) + + val customer = ctx.sparkContext.parallelize(Seq( + Customer(0, "alice"), + Customer(1, "bob"), + Customer(2, "alice"))).toDF() + .uniqueKey("id") + val employee = ctx.sparkContext.parallelize(Seq( + Employee(0, "charlie"), + Employee(1, "dan"))).toDF() + .uniqueKey("id") + val order = ctx.sparkContext.parallelize(Seq( + Order(0, 0, Some(0)), + Order(1, 1, None))).toDF() + .foreignKey("customerId", customer, "id") + .foreignKey("employeeId", employee, "id") + val manager = ctx.sparkContext.parallelize(Seq( + Manager(0, 1))).toDF() + .foreignKey("managerId", employee, "id") + .foreignKey("subordinateId", employee, "id") + val bannedCustomer = ctx.sparkContext.parallelize(Seq( + BannedCustomer("alice"), + BannedCustomer("eve"))).toDF() + .uniqueKey("name") + + // Joins involving referential integrity (a foreign key referencing a unique key) + val orderInnerJoinView = order + .join(customer, order("customerId") === customer("id")) + .join(employee, order("employeeId") === employee("id")) + + val orderLeftOuterJoinView = order + .join(customer, order("customerId") === customer("id"), "left_outer") + .join(employee, order("employeeId") === employee("id"), "left_outer") + + val orderRightOuterJoinView = employee.join( + customer.join(order, order("customerId") === customer("id"), "right_outer"), + order("employeeId") === employee("id"), "right_outer") + + val orderCustomerFullOuterJoinView = order + .join(customer, order("customerId") === customer("id"), "full_outer") + + val orderEmployeeFullOuterJoinView = order + .join(employee, order("employeeId") === employee("id"), "full_outer") + + // Joins involving only a unique key + val bannedCustomerInnerJoinView = customer + .join(bannedCustomer, bannedCustomer("name") === customer("name")) + + val bannedCustomerLeftOuterJoinView = customer + .join(bannedCustomer, bannedCustomer("name") === customer("name"), "left_outer") + + val bannedCustomerFullOuterJoinView = customer + .join(bannedCustomer, bannedCustomer("name") === customer("name"), "full_outer") +} + +class KeyHintSuite extends QueryTest { + + import KeyHintTestData._ + + def checkJoinCount(df: DataFrame, joinCount: Int): Unit = { + val joins = df.queryExecution.optimizedPlan.collect { + case j: Join => j + } + assert(joins.size == joinCount) + } + + def checkJoinsEliminated(df: DataFrame): Unit = checkJoinCount(df, 0) + + test("no elimination") { + val orderInnerJoin = orderInnerJoinView + .select(order("id"), order("customerId"), customer("name"), + order("employeeId"), employee("name")) + checkAnswer(orderInnerJoin, Seq( + Row(0, 0, "alice", 0, "charlie"))) + + val orderLeftOuterJoin = orderLeftOuterJoinView + .select(order("id"), order("customerId"), customer("name"), + order("employeeId"), employee("name")) + checkAnswer(orderLeftOuterJoin, Seq( + Row(0, 0, "alice", 0, "charlie"), + Row(1, 1, "bob", null, null))) + + val orderRightOuterJoin = orderRightOuterJoinView + .select(order("id"), order("customerId"), customer("name"), + order("employeeId"), employee("name")) + checkAnswer(orderRightOuterJoin, Seq( + Row(0, 0, 
"alice", 0, "charlie"), + Row(1, 1, "bob", null, null))) + + val orderCustomerFullOuterJoin = orderCustomerFullOuterJoinView + .select(order("id"), customer("id"), customer("name")) + checkAnswer(orderCustomerFullOuterJoin, Seq( + Row(0, 0, "alice"), + Row(1, 1, "bob"), + Row(null, 2, "alice"))) + + val orderEmployeeFullOuterJoin = orderEmployeeFullOuterJoinView + .select(order("id"), employee("id"), employee("name")) + checkAnswer(orderEmployeeFullOuterJoin, Seq( + Row(0, 0, "charlie"), + Row(1, null, null), + Row(null, 1, "dan"))) + + val bannedCustomerInnerJoin = bannedCustomerInnerJoinView + .select(customer("id"), bannedCustomer("name")) + checkAnswer(bannedCustomerInnerJoin, Seq( + Row(0, "alice"), + Row(2, "alice"))) + + val bannedCustomerLeftOuterJoin = bannedCustomerLeftOuterJoinView + .select(customer("id"), bannedCustomer("name")) + checkAnswer(bannedCustomerLeftOuterJoin, Seq( + Row(0, "alice"), + Row(1, null), + Row(2, "alice"))) + + val bannedCustomerFullOuterJoin = bannedCustomerFullOuterJoinView + .select(customer("id"), bannedCustomer("name")) + checkAnswer(bannedCustomerFullOuterJoin, Seq( + Row(0, "alice"), + Row(1, null), + Row(2, "alice"), + Row(null, "eve"))) + } + + test("can't create foreign key referencing non-unique column") { + intercept[IllegalArgumentException] { + bannedCustomer.foreignKey("name", customer, "name") + } + } + + test("eliminate unique key left outer join") { + val bannedCustomerJoinEliminated = bannedCustomerLeftOuterJoinView + .select(customer("id"), customer("name")) + checkAnswer(bannedCustomerJoinEliminated, customer) + checkJoinsEliminated(bannedCustomerJoinEliminated) + } + + test("do not eliminate unique key inner/full outer join") { + val bannedCustomerInnerJoinNotEliminated = bannedCustomerInnerJoinView + .select(customer("id"), customer("name")) + checkAnswer(bannedCustomerInnerJoinNotEliminated, Seq( + Row(0, "alice"), + Row(2, "alice"))) + + val bannedCustomerFullOuterJoinNotEliminated = bannedCustomerFullOuterJoinView + .select(customer("id"), customer("name")) + checkAnswer(bannedCustomerFullOuterJoinNotEliminated, Seq( + Row(0, "alice"), + Row(1, "bob"), + Row(2, "alice"), + Row(null, null))) + } + + test("do not eliminate referential integrity inner joins where foreign key is nullable") { + val orderInnerJoin = orderInnerJoinView + .select(order("id"), customer("id"), employee("id")) + checkAnswer(orderInnerJoin, Seq( + Row(0, 0, 0))) + // Only the customer join should be eliminated + checkJoinCount(orderInnerJoinView, 2) + checkJoinCount(orderInnerJoin, 1) + } + + test("eliminate referential integrity joins") { + val orderLeftOuterJoinEliminated = orderLeftOuterJoinView + .select(order("id"), customer("id"), employee("id")) + checkAnswer(orderLeftOuterJoinEliminated, Seq( + Row(0, 0, 0), + Row(1, 1, null))) + checkJoinsEliminated(orderLeftOuterJoinEliminated) + + val orderRightOuterJoinEliminated = orderRightOuterJoinView + .select(order("id"), customer("id"), employee("id")) + checkAnswer(orderRightOuterJoinEliminated, Seq( + Row(0, 0, 0), + Row(1, 1, null))) + checkJoinsEliminated(orderRightOuterJoinEliminated) + } + + test("do not eliminate referential integrity full outer joins") { + val orderCustomerFullOuterJoinNotEliminated = orderCustomerFullOuterJoinView + .select(order("id"), order("customerId"), customer("id")) + checkAnswer(orderCustomerFullOuterJoinNotEliminated, Seq( + Row(0, 0, 0), + Row(1, 1, 1), + Row(null, null, 2))) + + val orderEmployeeFullOuterJoinNotEliminated = orderEmployeeFullOuterJoinView + 
.select(order("id"), order("employeeId"), employee("id")) + checkAnswer(orderEmployeeFullOuterJoinNotEliminated, Seq( + Row(0, 0, 0), + Row(1, null, null), + Row(null, null, 1))) + } + + test("do not eliminate non-unique key outer joins") {} + + test("self joins") {} + + test("multiple foreign keys with same referent") {} + +} From 945e5231e900621c4a2bbf103816385d68abd5e0 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Tue, 18 Aug 2015 23:15:31 -0700 Subject: [PATCH 09/23] Do key hint resolution during analysis This is necessary to support aliased self joins and multiple foreign keys with the same referent. --- .../sql/catalyst/analysis/Analyzer.scala | 53 ++++++++++++++- .../sql/catalyst/optimizer/Optimizer.scala | 8 +++ .../optimizer/joinEliminationPatterns.scala | 10 +-- .../plans/logical/basicOperators.scala | 20 ++++++ .../sql/catalyst/plans/logical/keys.scala | 22 ++++++- .../org/apache/spark/sql/DataFrame.scala | 30 +++++---- .../org/apache/spark/sql/DataFrameSuite.scala | 14 ---- .../org/apache/spark/sql/KeyHintSuite.scala | 65 ++++++++++++++++--- 8 files changed, 175 insertions(+), 47 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f5daba1543da9..f3e7c6734fd88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -370,7 +370,23 @@ class Analyzer( case a: Attribute => attributeRewrites.get(a).getOrElse(a) } } - j.copy(right = newRight) + // In case there are foreign keys on the left side that reference attributes on the + // right, duplicate them so they also refer to the new attributes + val newLeft = + if (left.keys.nonEmpty) { + left.transform { + case KeyHint(keys, child) => + val newKeys = keys.collect { + case ForeignKey(attr, referencedAttr) => + ForeignKey(attr, attributeRewrites.get(referencedAttr).getOrElse(referencedAttr)) + case other => other + } + KeyHint((keys ++ newKeys).distinct, child) + } + } else { + left + } + j.copy(left = newLeft, right = newRight) } // When resolve `SortOrder`s in Sort based on child, don't report errors as @@ -379,6 +395,41 @@ class Analyzer( val newOrdering = resolveSortOrders(ordering, child, throws = false) Sort(newOrdering, global, child) + // Special handling for foreign key references - look them up in the catalog + case h @ KeyHint(keys, child) if child.resolved && !h.foreignKeyReferencesResolved => + KeyHint(keys.map { + case ForeignKey(k, u @ UnresolvedAttribute(nameParts)) => + ForeignKey(k, withPosition(u) { + // Resolve the target u of the foreign key as a column of this table or a table from + // the catalog + val (relation, referencedAttr) = + if (nameParts.length > 1) { + val relationName = nameParts.init + val referencedAttrName = nameParts.last + val relation = catalog.lookupRelation(relationName) + val referencedAttr = + relation.resolve(Seq(referencedAttrName), resolver).getOrElse(u).toAttribute + (relation, referencedAttr) + } else { + (h, h.resolve(nameParts, resolver).getOrElse(u).toAttribute) + } + + // Enforce the constraint that foreign keys can only reference unique keys + val referencedAttrIsUnique = relation.keys.exists { + case UniqueKey(attr) if attr == referencedAttr => true + case _ => false + } + if (referencedAttr.resolved && !referencedAttrIsUnique) { + failAnalysis("Foreign keys can only reference unique keys, but " + + s"$k 
references $u which is not unique.") + } + + referencedAttr + }) + + case otherKey => otherKey + }, child) + case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressionsUp { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0872318006acc..3eb995d53d3ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -43,6 +43,7 @@ object DefaultOptimizer extends Optimizer { ColumnPruning, // Operator combine ProjectCollapsing, + KeyHintCollapsing, JoinElimination, CombineFilters, CombineLimits, @@ -273,6 +274,13 @@ object ProjectCollapsing extends Rule[LogicalPlan] { } } +object KeyHintCollapsing extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case KeyHint(keys1, KeyHint(keys2, child)) => + KeyHint((keys1 ++ keys2).distinct, child) + } +} + /** * Eliminates keyed equi-joins when followed by a [[Project]] that only keeps columns from one side. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala index d1b153b1b3390..7b8b0398907a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ /** - * Finds outer joins where only the outer table's columns are kept, and a key from the inner table - * is involved in the join so no duplicates would be generated. + * Finds left or right outer joins where only the outer table's columns are kept, and a key from the + * inner table is involved in the join so no duplicates would be generated. */ object CanEliminateUniqueKeyOuterJoin { /** (outer, projectList) */ @@ -55,9 +55,9 @@ object CanEliminateUniqueKeyOuterJoin { } /** - * Finds equijoins based on foreign-key referential integrity, followed by [[Project]]s that - * reference no columns from the parent table other than the referenced unique keys. Such equijoins - * can be eliminated and replaced by the child table. + * Finds joins based on foreign-key referential integrity, followed by [[Project]]s that reference + * no columns from the parent table other than the referenced unique keys. Such joins can be + * eliminated and replaced by the child table. * * The table containing the foreign key is referred to as the child table, while the table * containing the referenced unique key is referred to as the parent table. 
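As a sketch of the rewrite this pattern licenses (mirroring the `order`/`customer` relations that the reworked JoinEliminationSuite exercises later in this series; plans are written in the Catalyst test DSL):
{{{
// The child table `order` carries the foreign key 'o_customerId, which
// references the unique key 'customerId of the parent table `customer`.
// The projection keeps only child columns plus the referenced unique key:
order.join(customer, LeftOuter, Some('customerId === 'o_customerId))
  .select('orderId, 'customerId)
// ...so the join can be dropped and the unique key recovered from the
// foreign key on the child side:
order.select('orderId, 'o_customerId.as('customerId))
}}}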
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 8c27b70bf6f3c..9d28fc4ef1236 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -27,6 +27,25 @@ case class KeyHint(newKeys: Seq[Key], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output override def keys: Seq[Key] = newKeys ++ child.keys + + override lazy val resolved: Boolean = newKeys.forall(_.resolved) && childrenResolved + + def foreignKeyReferencesResolved: Boolean = newKeys.forall { + case ForeignKey(_, referencedAttr) => referencedAttr.resolved + case _ => true + } + + override def transformExpressionsDown( + rule: PartialFunction[Expression, Expression]): this.type = { + KeyHint(newKeys.map(_.transformAttribute(rule.andThen(_.asInstanceOf[Attribute]))), child) + .asInstanceOf[this.type] + } + + override def transformExpressionsUp( + rule: PartialFunction[Expression, Expression]): this.type = { + KeyHint(newKeys.map(_.transformAttribute(rule.andThen(_.asInstanceOf[Attribute]))), child) + .asInstanceOf[this.type] + } } case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { @@ -418,6 +437,7 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil)) + override def keys: Seq[Key] = child.keys } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/keys.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/keys.scala index dbfbef4b26749..15d62f1550c31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/keys.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/keys.scala @@ -19,6 +19,22 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions.Attribute -sealed trait Key -case class UniqueKey(attr: Attribute) extends Key -case class ForeignKey(attr: Attribute, referencedAttr: Attribute) extends Key +sealed abstract class Key { + def transformAttribute(rule: PartialFunction[Attribute, Attribute]): Key + def resolved: Boolean +} + +case class UniqueKey(attr: Attribute) extends Key { + override def transformAttribute(rule: PartialFunction[Attribute, Attribute]): Key = + UniqueKey(rule.applyOrElse(attr, identity[Attribute])) + + override def resolved: Boolean = attr.resolved +} + +/** Referenced column must be unique. 
*/ +case class ForeignKey(attr: Attribute, referencedAttr: Attribute) extends Key { + override def transformAttribute(rule: PartialFunction[Attribute, Attribute]): Key = + ForeignKey(rule.applyOrElse(attr, identity[Attribute]), referencedAttr) + + override def resolved: Boolean = attr.resolved && referencedAttr.resolved +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 7f525f8272c8c..c602d712891b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -36,6 +36,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer.KeyHintCollapsing import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} @@ -666,22 +667,23 @@ class DataFrame private[sql]( */ def as(alias: Symbol): DataFrame = as(alias.name) - def uniqueKey(col: String): DataFrame = - KeyHint(List(UniqueKey(logicalPlan.output.find(_.name == col).get)), logicalPlan) - - def foreignKey(col: String, referencedTable: DataFrame, referencedCol: String): DataFrame = { - val colAttr = logicalPlan.output.find(_.name == col).get - val referencedAttr = referencedTable.logicalPlan.output.find(_.name == referencedCol).get - val referencedAttrIsUnique = referencedTable.logicalPlan.keys.exists { - case UniqueKey(attr) if attr == referencedAttr => true - case _ => false - } - require(referencedAttrIsUnique, - s"Foreign keys can only reference unique keys, but $referencedAttr is not unique.\n" + - "Try calling referencedTable.uniqueKey(\"" + referencedAttr + "\").") - KeyHint(List(ForeignKey(colAttr, referencedAttr)), logicalPlan) + def uniqueKey(col: String): DataFrame = { + KeyHintCollapsing(KeyHint(List(UniqueKey(UnresolvedAttribute(col))), logicalPlan)) } + /** + * Declares a foreign key referencing a key from this or another DataFrame. The referenced key + * must be declared as a unique key. + * {{{ + * department.uniqueKey("id").registerTempTable("department") + * employee.foreignKey("departmentId", "department.id") + * }}} + */ + def foreignKey(col: String, referencedCol: String): DataFrame = + KeyHintCollapsing( + KeyHint(List(ForeignKey(UnresolvedAttribute(col), UnresolvedAttribute(referencedCol))), + logicalPlan)) + /** * Selects a set of column based expressions. 
* {{{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index a9c3fca9aff4f..aef940a526675 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -98,20 +98,6 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { testData.collect().toSeq) } - test("key hint") { - import org.apache.spark.sql.catalyst.plans.logical.Join - - val personK = person.uniqueKey("id") - val personKA = personK.select($"id".as("p_id")) - val salaryK = salary.foreignKey("personId", personK, "id") - val salaries = salaryK.join(personKA, salaryK("personId") === personKA("p_id"), "left_outer") - .select(personKA("p_id"), salary("salary")) - checkAnswer(salaries, salary.collect().toSeq) - assert(salaries.queryExecution.optimizedPlan.collect { - case j: Join => j - }.isEmpty) - } - test("empty data frame") { assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) assert(sqlContext.emptyDataFrame.count() === 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala index b68039a0aaa8a..14755c21dbfe5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala @@ -27,6 +27,7 @@ private object KeyHintTestData { case class Employee(id: Int, name: String) case class Order(id: Int, customerId: Int, employeeId: Option[Int]) case class Manager(managerId: Int, subordinateId: Int) + case class BestFriend(id: Int, friendId: Int) case class BannedCustomer(name: String) val customer = ctx.sparkContext.parallelize(Seq( @@ -34,19 +35,30 @@ private object KeyHintTestData { Customer(1, "bob"), Customer(2, "alice"))).toDF() .uniqueKey("id") + customer.registerTempTable("customer") val employee = ctx.sparkContext.parallelize(Seq( Employee(0, "charlie"), Employee(1, "dan"))).toDF() .uniqueKey("id") + employee.registerTempTable("employee") val order = ctx.sparkContext.parallelize(Seq( Order(0, 0, Some(0)), Order(1, 1, None))).toDF() - .foreignKey("customerId", customer, "id") - .foreignKey("employeeId", employee, "id") + .foreignKey("customerId", "customer.id") + .foreignKey("employeeId", "employee.id") val manager = ctx.sparkContext.parallelize(Seq( Manager(0, 1))).toDF() - .foreignKey("managerId", employee, "id") - .foreignKey("subordinateId", employee, "id") + .foreignKey("managerId", "employee.id") + .foreignKey("subordinateId", "employee.id") + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false) + val bestFriend = ctx.sparkContext.parallelize(Seq( + BestFriend(0, 1), + BestFriend(1, 2), + BestFriend(2, 0))).toDF() + .uniqueKey("id") + .foreignKey("friendId", "bestFriend.id") + bestFriend.registerTempTable("bestFriend") + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true) val bannedCustomer = ctx.sparkContext.parallelize(Seq( BannedCustomer("alice"), BannedCustomer("eve"))).toDF() @@ -71,6 +83,13 @@ private object KeyHintTestData { val orderEmployeeFullOuterJoinView = order .join(employee, order("employeeId") === employee("id"), "full_outer") + val managerInnerJoinView = manager + .join(employee.as("emp_manager"), manager("managerId") === $"emp_manager.id") + .join(employee.as("emp_subordinate"), manager("subordinateId") === $"emp_subordinate.id") + + val bestFriendInnerJoinView = bestFriend + .join(bestFriend.as("bestFriend2"), bestFriend("friendId") 
=== $"bestFriend2.id") + // Joins involving only a unique key val bannedCustomerInnerJoinView = customer .join(bannedCustomer, bannedCustomer("name") === customer("name")) @@ -85,6 +104,7 @@ private object KeyHintTestData { class KeyHintSuite extends QueryTest { import KeyHintTestData._ + import KeyHintTestData.ctx.implicits._ def checkJoinCount(df: DataFrame, joinCount: Int): Unit = { val joins = df.queryExecution.optimizedPlan.collect { @@ -130,6 +150,19 @@ class KeyHintSuite extends QueryTest { Row(1, null, null), Row(null, 1, "dan"))) + val managerInnerJoin = managerInnerJoinView + .select(manager("managerId"), $"emp_manager.name", + manager("subordinateId"), $"emp_subordinate.name") + checkAnswer(managerInnerJoin, Seq( + Row(0, "charlie", 1, "dan"))) + + val bestFriendInnerJoin = bestFriendInnerJoinView + .select(bestFriend("id"), $"bestFriend2.id", $"bestFriend2.friendId") + checkAnswer(bestFriendInnerJoin, Seq( + Row(0, 1, 2), + Row(1, 2, 0), + Row(2, 0, 1))) + val bannedCustomerInnerJoin = bannedCustomerInnerJoinView .select(customer("id"), bannedCustomer("name")) checkAnswer(bannedCustomerInnerJoin, Seq( @@ -153,8 +186,8 @@ class KeyHintSuite extends QueryTest { } test("can't create foreign key referencing non-unique column") { - intercept[IllegalArgumentException] { - bannedCustomer.foreignKey("name", customer, "name") + intercept[AnalysisException] { + bannedCustomer.foreignKey("name", "customer.name") } } @@ -223,10 +256,22 @@ class KeyHintSuite extends QueryTest { Row(null, null, 1))) } - test("do not eliminate non-unique key outer joins") {} - - test("self joins") {} + test("eliminate referential integrity join despite multiple foreign keys with same referent") { + val managerInnerJoinEliminated = managerInnerJoinView + .select($"emp_manager.id", $"emp_subordinate.id") + checkAnswer(managerInnerJoinEliminated, manager) + checkJoinsEliminated(managerInnerJoinEliminated) + } - test("multiple foreign keys with same referent") {} + test("eliminate referential integrity self-join") { + val bestFriendInnerJoinEliminated = bestFriendInnerJoinView + .select(bestFriend("id"), $"bestFriend2.id") + checkAnswer(bestFriendInnerJoinEliminated, Seq( + Row(0, 1), + Row(1, 2), + Row(2, 0))) + checkJoinsEliminated(bestFriendInnerJoinEliminated) + } + test("join followed by join") {} } From 504c9d858b8b35ed788e31bf99fc5f6506be792d Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Tue, 18 Aug 2015 23:18:02 -0700 Subject: [PATCH 10/23] Don't crash when foreign key refers to unresolved relation Instead just leave the KeyHint unresolved. 
--- .../sql/catalyst/analysis/Analyzer.scala | 28 +++++++++++-------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f3e7c6734fd88..df220ab213b4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -406,22 +406,28 @@ class Analyzer( if (nameParts.length > 1) { val relationName = nameParts.init val referencedAttrName = nameParts.last - val relation = catalog.lookupRelation(relationName) - val referencedAttr = - relation.resolve(Seq(referencedAttrName), resolver).getOrElse(u).toAttribute - (relation, referencedAttr) + if (catalog.tableExists(relationName)) { + val relation = catalog.lookupRelation(relationName) + val referencedAttr = + relation.resolve(Seq(referencedAttrName), resolver).getOrElse(u).toAttribute + (relation, referencedAttr) + } else { + (UnresolvedRelation(relationName), u) + } } else { (h, h.resolve(nameParts, resolver).getOrElse(u).toAttribute) } // Enforce the constraint that foreign keys can only reference unique keys - val referencedAttrIsUnique = relation.keys.exists { - case UniqueKey(attr) if attr == referencedAttr => true - case _ => false - } - if (referencedAttr.resolved && !referencedAttrIsUnique) { - failAnalysis("Foreign keys can only reference unique keys, but " + - s"$k references $u which is not unique.") + if (referencedAttr.resolved) { + val referencedAttrIsUnique = relation.keys.exists { + case UniqueKey(attr) if attr == referencedAttr => true + case _ => false + } + if (!referencedAttrIsUnique) { + failAnalysis("Foreign keys can only reference unique keys, but " + + s"$k references $u which is not unique.") + } } referencedAttr From 83c8ff913dc06f79ce059906e62b0e744967c1e4 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Wed, 19 Aug 2015 00:42:04 -0700 Subject: [PATCH 11/23] Fix JoinEliminationSuite --- .../optimizer/JoinEliminationSuite.scala | 203 ++++++++---------- .../org/apache/spark/sql/KeyHintSuite.scala | 8 +- 2 files changed, 90 insertions(+), 121 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala index d70b3b9922264..3e30110f6c199 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.catalyst.plans.FullOuter import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.RightOuter import org.apache.spark.sql.catalyst.plans.logical.ForeignKey @@ -38,148 +39,118 @@ class JoinEliminationSuite extends PlanTest { Batch("JoinElimination", Once, JoinElimination) :: Nil } - val testRelation1 = LocalRelation('a.int, 'b.int) - val testRelation2 = LocalRelation('c.int, 'd.int) - val testRelation1K = KeyHint(List(UniqueKey(testRelation1.output.head)), testRelation1) - val testRelation2K = KeyHint(List(UniqueKey(testRelation2.output.head)), 
testRelation2) - val testRelation3 = LocalRelation('e.int, 'f.int) - val testRelation3K = KeyHint(List(ForeignKey(testRelation3.output.head, testRelation1.output.head)), testRelation3) - - test("collapse left outer join followed by subset project") { - val query = testRelation1 - .join(testRelation2K, LeftOuter, Some('a === 'c)) - .select('a, 'b) - - val optimized = Optimize.execute(query.analyze) - val correctAnswer = testRelation1.select('a, 'b).analyze - - comparePlans(optimized, correctAnswer) + val customer = { + val r = LocalRelation('customerId.int.notNull, 'customerName.string) + KeyHint(List(UniqueKey(r.output(0))), r) } - - test("collapse right outer join followed by subset project") { - val query = testRelation1K - .join(testRelation2, RightOuter, Some('a === 'c)) - .select('c, 'd) - - val optimized = Optimize.execute(query.analyze) - val correctAnswer = testRelation2.select('c, 'd).analyze - - comparePlans(optimized, correctAnswer) + val employee = { + val r = LocalRelation('employeeId.int.notNull, 'employeeName.string) + KeyHint(List(UniqueKey(r.output(0))), r) } - - test("collapse outer join followed by subset project with expressions") { - val query = testRelation1 - .join(testRelation2K, LeftOuter, Some('a === 'c)) - .select(('a + 1).as('a), ('b + 2).as('b)) - - val optimized = Optimize.execute(query.analyze) - val correctAnswer = testRelation1.select(('a + 1).as('a), ('b + 2).as('b)).analyze - - comparePlans(optimized, correctAnswer) + val order = { + val r = LocalRelation( + 'orderId.int.notNull, 'o_customerId.int.notNull, 'o_employeeId.int) + KeyHint(List( + UniqueKey(r.output(0)), + ForeignKey(r.output(1), customer.output(0)), + ForeignKey(r.output(2), employee.output(0))), r) } - - test("collapse outer join with foreign key") { - val query = testRelation3K - .join(testRelation1K, LeftOuter, Some('e === 'a)) - .select('a, 'f) - - val optimized = Optimize.execute(query.analyze) - val correctAnswer = testRelation3K.select('e.as('a), 'f).analyze - - comparePlans(optimized, correctAnswer) + val bannedCustomer = { + val r = LocalRelation('bannedCustomerName.string.notNull) + KeyHint(List(UniqueKey(r.output(0))), r) } - test("collapse outer join with foreign key despite alias") { - val query = testRelation3K - .join(testRelation1K.select('a.as('g), 'b), LeftOuter, Some('e === 'g)) - .select('g, 'f) - + def checkJoinEliminated( + base: LogicalPlan, + join: LogicalPlan => LogicalPlan, + project: LogicalPlan => LogicalPlan, + projectAfterElimination: LogicalPlan => LogicalPlan): Unit = { + val query = project(join(base)) val optimized = Optimize.execute(query.analyze) - val correctAnswer = testRelation3K.select('e.as('g), 'f).analyze - + val correctAnswer = projectAfterElimination(base).analyze comparePlans(optimized, correctAnswer) } - test("do not collapse non-subset project") { - val query = testRelation1 - .join(testRelation2K, LeftOuter, Some('a === 'c)) - .select('a, 'b, ('c + 1).as('c), 'd) - - val optimized = Optimize.execute(query.analyze) - val correctAnswer = query.analyze - - comparePlans(optimized, correctAnswer) + def checkJoinEliminated( + base: LogicalPlan, + join: LogicalPlan => LogicalPlan, + project: LogicalPlan => LogicalPlan): Unit = { + checkJoinEliminated(base, join, project, project) } - test("do not collapse non-keyed join") { - val query = testRelation1 - .join(testRelation2, LeftOuter, Some('a === 'c)) - .select('a, 'b, ('c + 1).as('c), 'd) - + def checkJoinNotEliminated( + base: LogicalPlan, + join: LogicalPlan => LogicalPlan, + project: LogicalPlan 
=> LogicalPlan): Unit = { + val query = project(join(base)) val optimized = Optimize.execute(query.analyze) val correctAnswer = query.analyze - comparePlans(optimized, correctAnswer) } - test("collapse join preceded by join") { - val query1 = testRelation3K - .join(testRelation1, LeftOuter, Some('e === 'a)) // will not be eliminated - val query2 = query1 - .join(testRelation2K, LeftOuter, Some('a === 'c)) // should be eliminated - .select('a, 'b, 'e, 'f) - - val optimized = Optimize.execute(query2.analyze) - val correctAnswer = query1.select('a, 'b, 'e, 'f).analyze - - comparePlans(optimized, correctAnswer) + test("eliminate unique key left outer join") { + checkJoinEliminated( + customer, + _.join(bannedCustomer, LeftOuter, Some('customerName === 'bannedCustomerName)), + _.select('customerId, 'customerName)) } - test("eliminate inner join - fk on right, no pk columns kept") { - val query = testRelation1K - .join(testRelation3K, Inner, Some('a === 'e)) - .select('e, 'f) - - val optimized = Optimize.execute(query.analyze) - val correctAnswer = testRelation3K.select('e, 'f).analyze - - comparePlans(optimized, correctAnswer) + test("do not eliminate unique key inner join") { + checkJoinNotEliminated( + customer, + _.join(bannedCustomer, Inner, Some('customerName === 'bannedCustomerName)), + _.select('customerId, 'customerName)) } - test("eliminate inner join - fk on right, pk columns kept") { - val query = testRelation1K - .join(testRelation3K, Inner, Some('a === 'e)) - .select('a, 'f) - - val optimized = Optimize.execute(query.analyze) - val correctAnswer = testRelation3K.select('e.as('a), 'f).analyze - - comparePlans(optimized, correctAnswer) + test("do not eliminate unique key full outer join") { + checkJoinNotEliminated( + customer, + _.join(bannedCustomer, FullOuter, Some('customerName === 'bannedCustomerName)), + _.select('customerId, 'customerName)) } - test("eliminate inner join - fk on left, pk columns kept") { - val query = testRelation3K - .join(testRelation1K, Inner, Some('a === 'e)) - .select('a, 'f) - - val optimized = Optimize.execute(query.analyze) - val correctAnswer = testRelation3K.select('e.as('a), 'f).analyze - - comparePlans(optimized, correctAnswer) + test("do not eliminate referential integrity inner join where foreign key is nullable") { + checkJoinNotEliminated( + order, + _.join(employee, Inner, Some('employeeId === 'o_employeeId)), + _.select('orderId, 'employeeId)) } - test("triangles") { - val e0 = LocalRelation('srcId.int, 'dstId.int) - val v0 = LocalRelation('id.int, 'attr.int) - val e = KeyHint(List(ForeignKey(e0.output.head, v0.output.head), ForeignKey(e0.output.last, v0.output.head)), e0) - val v = KeyHint(List(UniqueKey(v0.output.head)), v0) - - val query = e.join(v, LeftOuter, Some('dstId === 'id)).select('srcId, 'id) + test("eliminate referential integrity inner join when foreign key is not null") { + checkJoinEliminated( + order, + _.join(customer, Inner, Some('customerId === 'o_customerId)), + _.select('orderId, 'customerId), + _.select('orderId, 'o_customerId.as('customerId))) + } - val optimized = Optimize.execute(query.analyze) - val correctAnswer = e.select('srcId, 'dstId.as('id)).analyze + test("eliminate referential integrity left/right outer join when foreign key is not null") { + checkJoinEliminated( + order, + _.join(customer, LeftOuter, Some('customerId === 'o_customerId)), + _.select('orderId, 'customerId), + _.select('orderId, 'o_customerId.as('customerId))) + + checkJoinEliminated( + order, + customer.join(_, RightOuter, Some('customerId 
=== 'o_customerId)), + _.select('orderId, 'customerId), + _.select('orderId, 'o_customerId.as('customerId))) + } - comparePlans(optimized, correctAnswer) + test("do not eliminate referential integrity full outer join") { + checkJoinNotEliminated( + order, + _.join(customer, FullOuter, Some('customerId === 'o_customerId)), + _.select('orderId, 'customerId)) + } + test("eliminate referential integrity outer join despite alias") { + checkJoinEliminated( + order, + _.join(customer.select('customerId.as('customerId_alias), 'customerName), + LeftOuter, Some('customerId_alias === 'o_customerId)), + _.select('orderId, 'customerId_alias), + _.select('orderId, 'o_customerId.as('customerId_alias))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala index 14755c21dbfe5..7de8c5c5040ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala @@ -214,7 +214,7 @@ class KeyHintSuite extends QueryTest { Row(null, null))) } - test("do not eliminate referential integrity inner joins where foreign key is nullable") { + test("do not eliminate referential integrity inner join where foreign key is nullable") { val orderInnerJoin = orderInnerJoinView .select(order("id"), customer("id"), employee("id")) checkAnswer(orderInnerJoin, Seq( @@ -224,7 +224,7 @@ class KeyHintSuite extends QueryTest { checkJoinCount(orderInnerJoin, 1) } - test("eliminate referential integrity joins") { + test("eliminate referential integrity join") { val orderLeftOuterJoinEliminated = orderLeftOuterJoinView .select(order("id"), customer("id"), employee("id")) checkAnswer(orderLeftOuterJoinEliminated, Seq( @@ -240,7 +240,7 @@ class KeyHintSuite extends QueryTest { checkJoinsEliminated(orderRightOuterJoinEliminated) } - test("do not eliminate referential integrity full outer joins") { + test("do not eliminate referential integrity full outer join") { val orderCustomerFullOuterJoinNotEliminated = orderCustomerFullOuterJoinView .select(order("id"), order("customerId"), customer("id")) checkAnswer(orderCustomerFullOuterJoinNotEliminated, Seq( @@ -272,6 +272,4 @@ class KeyHintSuite extends QueryTest { Row(2, 0))) checkJoinsEliminated(bestFriendInnerJoinEliminated) } - - test("join followed by join") {} } From 9150ddaf2d598314ff3ea1fe4a434de37325d213 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Wed, 19 Aug 2015 05:14:53 -0700 Subject: [PATCH 12/23] Fix KeyHintSuite after merge --- .../org/apache/spark/sql/KeyHintSuite.scala | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala index 7de8c5c5040ec..677fa58ad98a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala @@ -18,17 +18,20 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.test.TestSQLContext private object KeyHintTestData { - val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - case class Customer(id: Int, name: String) case class Employee(id: Int, name: String) case class Order(id: Int, customerId: Int, employeeId: Option[Int]) case class Manager(managerId: Int, subordinateId: Int) case class BestFriend(id: Int, friendId: Int) case class BannedCustomer(name: 
String) +} + +private class KeyHintTestData(ctx: SQLContext) { + import ctx.implicits._ + import KeyHintTestData._ val customer = ctx.sparkContext.parallelize(Seq( Customer(0, "alice"), @@ -103,8 +106,11 @@ private object KeyHintTestData { class KeyHintSuite extends QueryTest { - import KeyHintTestData._ - import KeyHintTestData.ctx.implicits._ + val ctx = new TestSQLContext() + private val data = new KeyHintTestData(ctx) + + import data._ + import ctx.implicits._ def checkJoinCount(df: DataFrame, joinCount: Int): Unit = { val joins = df.queryExecution.optimizedPlan.collect { From 873b3224b043875718959c645146743ed78084da Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Mon, 12 Oct 2015 18:47:47 -0700 Subject: [PATCH 13/23] In ForeignKey, store referencedRelation as logical plan Previously we stored its name as part of referencedAttr, requiring a catalog lookup. --- .../sql/catalyst/analysis/Analyzer.scala | 78 +++++++++---------- .../optimizer/joinEliminationPatterns.scala | 2 +- .../plans/logical/basicOperators.scala | 6 +- .../sql/catalyst/plans/logical/keys.scala | 11 ++- .../optimizer/JoinEliminationSuite.scala | 4 +- .../org/apache/spark/sql/DataFrame.scala | 11 ++- .../org/apache/spark/sql/KeyHintSuite.scala | 29 ++++--- 7 files changed, 69 insertions(+), 72 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6731b88bd19d1..87670ab49d69d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -362,22 +362,27 @@ class Analyzer( j case Some((oldRelation, newRelation)) => val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) - val newRight = right transformUp { - case r if r == oldRelation => newRelation - } transformUp { - case other => other transformExpressions { - case a: Attribute => attributeRewrites.get(a).getOrElse(a) + def applyRewrites(plan: LogicalPlan): LogicalPlan = + plan transformUp { + case r if r == oldRelation => newRelation + } transformUp { + case other => other transformExpressions { + case a: Attribute => attributeRewrites.get(a).getOrElse(a) + } } - } - // In case there are foreign keys on the left side that reference attributes on the - // right, duplicate them so they also refer to the new attributes + val newRight = applyRewrites(right) + // Also apply the rewrites to foreign keys on the left side, because these are meant to + // reference the right side. (TODO: Why duplicate them instead of replacing?) val newLeft = if (left.keys.nonEmpty) { left.transform { case KeyHint(keys, child) => val newKeys = keys.collect { - case ForeignKey(attr, referencedAttr) => - ForeignKey(attr, attributeRewrites.get(referencedAttr).getOrElse(referencedAttr)) + case ForeignKey(attr, referencedRelation, referencedAttr) => + ForeignKey( + attr, + applyRewrites(referencedRelation), + attributeRewrites.get(referencedAttr).getOrElse(referencedAttr)) case other => other } KeyHint((keys ++ newKeys).distinct, child) @@ -394,43 +399,30 @@ class Analyzer( val newOrdering = resolveSortOrders(ordering, child, throws = false) Sort(newOrdering, global, child) - // Special handling for foreign key references - look them up in the catalog + // Resolve referenced attributes of foreign keys using the referenced relation + // TODO: move this to its own rule? 
case h @ KeyHint(keys, child) if child.resolved && !h.foreignKeyReferencesResolved => KeyHint(keys.map { - case ForeignKey(k, u @ UnresolvedAttribute(nameParts)) => - ForeignKey(k, withPosition(u) { - // Resolve the target u of the foreign key as a column of this table or a table from - // the catalog - val (relation, referencedAttr) = - if (nameParts.length > 1) { - val relationName = nameParts.init - val referencedAttrName = nameParts.last - if (catalog.tableExists(relationName)) { - val relation = catalog.lookupRelation(relationName) - val referencedAttr = - relation.resolve(Seq(referencedAttrName), resolver).getOrElse(u).toAttribute - (relation, referencedAttr) - } else { - (UnresolvedRelation(relationName), u) - } - } else { - (h, h.resolve(nameParts, resolver).getOrElse(u).toAttribute) - } - - // Enforce the constraint that foreign keys can only reference unique keys - if (referencedAttr.resolved) { - val referencedAttrIsUnique = relation.keys.exists { - case UniqueKey(attr) if attr == referencedAttr => true - case _ => false - } - if (!referencedAttrIsUnique) { - failAnalysis("Foreign keys can only reference unique keys, but " + - s"$k references $u which is not unique.") - } + case ForeignKey(k, r, u @ UnresolvedAttribute(nameParts)) => withPosition(u) { + // The referenced relation r is itself guaranteed to be resolved already, so we can + // resolve u against it + val referencedAttr = r.resolve(nameParts, resolver).getOrElse(u).toAttribute + + // Enforce the constraint that foreign keys can only reference unique keys + if (referencedAttr.resolved) { + val referencedAttrIsUnique = r.keys.exists { + // TODO: use semanticEquals + case UniqueKey(attr) if attr == referencedAttr => true + case _ => false + } + if (!referencedAttrIsUnique) { + failAnalysis("Foreign keys can only reference unique keys, but " + + s"$k references $referencedAttr which is not unique.") } + } - referencedAttr - }) + ForeignKey(k, r, referencedAttr) + } case otherKey => otherKey }, child) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala index 7b8b0398907a5..c2f8eb4fab52b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala @@ -145,7 +145,7 @@ private class ForeignKeyFinder(plan: LogicalPlan, referencedPlan: LogicalPlan) { def foreignKeyExists(attr: Attribute, referencedAttr: Attribute): Boolean = { plan.keys.exists { - case ForeignKey(attr2, referencedAttr2) + case ForeignKey(attr2, _, referencedAttr2) if attr == attr2 && equivalent.query(referencedAttr, referencedAttr2) => true case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 84f199a2529e8..d00a10525f3af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -31,7 +31,7 @@ case class KeyHint(newKeys: Seq[Key], child: LogicalPlan) extends UnaryNode { override lazy val resolved: Boolean = newKeys.forall(_.resolved) && childrenResolved def foreignKeyReferencesResolved: Boolean = newKeys.forall { - case ForeignKey(_, 
referencedAttr) => referencedAttr.resolved + case ForeignKey(_, _, referencedAttr) => referencedAttr.resolved case _ => true } @@ -59,8 +59,8 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend child.keys.collect { case UniqueKey(attr) if aliasMap.contains(attr) => UniqueKey(aliasMap(attr)) - case ForeignKey(attr, referencedAttr) if aliasMap.contains(attr) => - ForeignKey(aliasMap(attr), referencedAttr) + case ForeignKey(attr, referencedRelation, referencedAttr) if aliasMap.contains(attr) => + ForeignKey(aliasMap(attr), referencedRelation, referencedAttr) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/keys.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/keys.scala index 15d62f1550c31..381d9ae764e79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/keys.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/keys.scala @@ -31,10 +31,15 @@ case class UniqueKey(attr: Attribute) extends Key { override def resolved: Boolean = attr.resolved } -/** Referenced column must be unique. */ -case class ForeignKey(attr: Attribute, referencedAttr: Attribute) extends Key { +/** Referenced column must be unique. Referenced relation must already be resolved. */ +case class ForeignKey( + attr: Attribute, + referencedRelation: LogicalPlan, + referencedAttr: Attribute) extends Key { + assert(referencedRelation.resolved) + override def transformAttribute(rule: PartialFunction[Attribute, Attribute]): Key = - ForeignKey(rule.applyOrElse(attr, identity[Attribute]), referencedAttr) + ForeignKey(rule.applyOrElse(attr, identity[Attribute]), referencedRelation, referencedAttr) override def resolved: Boolean = attr.resolved && referencedAttr.resolved } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala index 3e30110f6c199..5426f96aab8cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala @@ -52,8 +52,8 @@ class JoinEliminationSuite extends PlanTest { 'orderId.int.notNull, 'o_customerId.int.notNull, 'o_employeeId.int) KeyHint(List( UniqueKey(r.output(0)), - ForeignKey(r.output(1), customer.output(0)), - ForeignKey(r.output(2), employee.output(0))), r) + ForeignKey(r.output(1), customer, customer.output(0)), + ForeignKey(r.output(2), employee, employee.output(0))), r) } val bannedCustomer = { val r = LocalRelation('bannedCustomerName.string.notNull) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 7794ec809f612..56496366c727c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -677,13 +677,16 @@ class DataFrame private[sql]( * Declares a foreign key referencing a key from this or another DataFrame. The referenced key * must be declared as a unique key. 
* {{{ - * department.uniqueKey("id").registerTempTable("department") - * employee.foreignKey("departmentId", "department.id") + * val department = dept.uniqueKey("id") + * employee.foreignKey("departmentId", department, "id") * }}} */ - def foreignKey(col: String, referencedCol: String): DataFrame = + def foreignKey(col: String, referencedDF: DataFrame, referencedCol: String): DataFrame = KeyHintCollapsing( - KeyHint(List(ForeignKey(UnresolvedAttribute(col), UnresolvedAttribute(referencedCol))), + KeyHint(List(ForeignKey( + UnresolvedAttribute(col), + referencedDF.logicalPlan, + UnresolvedAttribute(referencedCol))), logicalPlan)) /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala index 677fa58ad98a0..e7acc40fa0c31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala @@ -38,30 +38,27 @@ private class KeyHintTestData(ctx: SQLContext) { Customer(1, "bob"), Customer(2, "alice"))).toDF() .uniqueKey("id") - customer.registerTempTable("customer") val employee = ctx.sparkContext.parallelize(Seq( Employee(0, "charlie"), Employee(1, "dan"))).toDF() .uniqueKey("id") - employee.registerTempTable("employee") val order = ctx.sparkContext.parallelize(Seq( Order(0, 0, Some(0)), Order(1, 1, None))).toDF() - .foreignKey("customerId", "customer.id") - .foreignKey("employeeId", "employee.id") + .foreignKey("customerId", customer, "id") + .foreignKey("employeeId", employee, "id") val manager = ctx.sparkContext.parallelize(Seq( Manager(0, 1))).toDF() - .foreignKey("managerId", "employee.id") - .foreignKey("subordinateId", "employee.id") - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false) - val bestFriend = ctx.sparkContext.parallelize(Seq( - BestFriend(0, 1), - BestFriend(1, 2), - BestFriend(2, 0))).toDF() - .uniqueKey("id") - .foreignKey("friendId", "bestFriend.id") - bestFriend.registerTempTable("bestFriend") - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true) + .foreignKey("managerId", employee, "id") + .foreignKey("subordinateId", employee, "id") + val bestFriend = { + val tmp = ctx.sparkContext.parallelize(Seq( + BestFriend(0, 1), + BestFriend(1, 2), + BestFriend(2, 0))).toDF() + .uniqueKey("id") + tmp.foreignKey("friendId", tmp, "id") + } val bannedCustomer = ctx.sparkContext.parallelize(Seq( BannedCustomer("alice"), BannedCustomer("eve"))).toDF() @@ -193,7 +190,7 @@ class KeyHintSuite extends QueryTest { test("can't create foreign key referencing non-unique column") { intercept[AnalysisException] { - bannedCustomer.foreignKey("name", "customer.name") + bannedCustomer.foreignKey("name", customer, "name") } } From 98e0b5e316b1692a188dedc6b49daaa5854a064b Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Mon, 12 Oct 2015 19:45:21 -0700 Subject: [PATCH 14/23] Use semanticEquals for Attributes --- .../sql/catalyst/analysis/Analyzer.scala | 3 +-- .../optimizer/joinEliminationPatterns.scala | 22 +++++++++---------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 87670ab49d69d..bcad2c90ff0e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -411,8 +411,7 @@ class Analyzer( // Enforce the constraint that 
foreign keys can only reference unique keys if (referencedAttr.resolved) { val referencedAttrIsUnique = r.keys.exists { - // TODO: use semanticEquals - case UniqueKey(attr) if attr == referencedAttr => true + case UniqueKey(attr) if attr semanticEquals referencedAttr => true case _ => false } if (!referencedAttrIsUnique) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala index c2f8eb4fab52b..19b3cadd438f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala @@ -146,13 +146,14 @@ private class ForeignKeyFinder(plan: LogicalPlan, referencedPlan: LogicalPlan) { def foreignKeyExists(attr: Attribute, referencedAttr: Attribute): Boolean = { plan.keys.exists { case ForeignKey(attr2, _, referencedAttr2) - if attr == attr2 && equivalent.query(referencedAttr, referencedAttr2) => true + if (attr semanticEquals attr2) + && equivalent.query(referencedAttr, referencedAttr2) => true case _ => false } } - private def equivalences(plan: LogicalPlan): MutableDisjointSet[Attribute] = { - val s = new MutableDisjointSet[Attribute] + private def equivalences(plan: LogicalPlan): MutableDisjointAttributeSets = { + val s = new MutableDisjointAttributeSets plan.collect { case Project(projectList, _) => projectList.collect { case a @ Alias(old: Attribute, _) => s.union(old, a.toAttribute) @@ -163,15 +164,14 @@ private class ForeignKeyFinder(plan: LogicalPlan, referencedPlan: LogicalPlan) { } -private class MutableDisjointSet[A]() { - import scala.collection.mutable.Set - private var sets = Set[Set[A]]() - def add(x: A): Unit = { +private class MutableDisjointAttributeSets() { + private var sets = Set[AttributeSet]() + def add(x: Attribute): Unit = { if (!sets.exists(_.contains(x))) { - sets += Set(x) + sets += AttributeSet(x) } } - def union(x: A, y: A): Unit = { + def union(x: Attribute, y: Attribute): Unit = { add(x) add(y) val xSet = sets.find(_.contains(x)).get @@ -180,7 +180,7 @@ private class MutableDisjointSet[A]() { sets -= ySet sets += (xSet ++ ySet) } - def query(x: A, y: A): Boolean = { - x == y || sets.exists(s => s.contains(x) && s.contains(y)) + def query(x: Attribute, y: Attribute): Boolean = { + (x semanticEquals y) || sets.exists(s => s.contains(x) && s.contains(y)) } } From d43a2c005b091e571a9d5dc3cc7d22e22a29ffd0 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Mon, 12 Oct 2015 20:37:35 -0700 Subject: [PATCH 15/23] Remove TODOs --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 3 +-- .../spark/sql/catalyst/optimizer/joinEliminationPatterns.scala | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index bcad2c90ff0e2..f828f19e2b883 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -372,7 +372,7 @@ class Analyzer( } val newRight = applyRewrites(right) // Also apply the rewrites to foreign keys on the left side, because these are meant to - // reference the right side. (TODO: Why duplicate them instead of replacing?) + // reference the right side. 
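+        // (applyRewrites and attributeRewrites map the stale right-side expression ids to the
+        // fresh copies generated for the deduplicated right child.)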
val newLeft = if (left.keys.nonEmpty) { left.transform { @@ -400,7 +400,6 @@ class Analyzer( Sort(newOrdering, global, child) // Resolve referenced attributes of foreign keys using the referenced relation - // TODO: move this to its own rule? case h @ KeyHint(keys, child) if child.resolved && !h.foreignKeyReferencesResolved => KeyHint(keys.map { case ForeignKey(k, r, u @ UnresolvedAttribute(nameParts)) => withPosition(u) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala index 19b3cadd438f9..ad26169f99787 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala @@ -161,7 +161,6 @@ private class ForeignKeyFinder(plan: LogicalPlan, referencedPlan: LogicalPlan) { } s } - } private class MutableDisjointAttributeSets() { From f4e7e0140865df27f3c0b000f22d69117316070e Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Mon, 12 Oct 2015 21:02:02 -0700 Subject: [PATCH 16/23] Add more comments --- .../sql/catalyst/optimizer/Optimizer.scala | 20 +++++-- .../catalyst/plans/logical/LogicalPlan.scala | 8 ++- .../plans/logical/basicOperators.scala | 55 ++++++++++--------- .../sql/catalyst/plans/logical/keys.scala | 12 +++- .../org/apache/spark/sql/DataFrame.scala | 8 ++- 5 files changed, 67 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2117d141db500..341c5e5525f06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -65,6 +65,8 @@ object DefaultOptimizer extends Optimizer { Batch("Aggregate", FixedPoint(100), ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: + // Hints are necessary for some operator optimizations but interfere with others, so we run the + // rules with them, then remove them and run the rules again. Batch("Operator Optimizations", FixedPoint(100), operatorOptimizations: _*) :: Batch("Remove Hints", FixedPoint(100), @@ -277,6 +279,9 @@ object ProjectCollapsing extends Rule[LogicalPlan] { } } +/** + * Combines two adjacent [[KeyHint]]s into one by merging their key lists. + */ object KeyHintCollapsing extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case KeyHint(keys1, KeyHint(keys2, child)) => @@ -316,12 +321,6 @@ object JoinElimination extends Rule[LogicalPlan] { } -object RemoveKeyHints extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case KeyHint(_, child) => child - } -} - /** * Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition. * For example, when the expression is just checking to see if a string starts with a given @@ -809,6 +808,15 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { } } +/** + * Removes [[KeyHint]]s from the plan to avoid interfering with other rules. 
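+ * This rule runs in its own batch after the hint-aware operator optimizations, so that those
+ * optimizations can be re-run on a plan free of [[KeyHint]] nodes (see the batch ordering in
+ * [[DefaultOptimizer]]).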
+ */ +object RemoveKeyHints extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case KeyHint(_, child) => child + } +} + /** * Speeds up aggregates on fixed-precision decimals by executing them on unscaled Long values. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 1107d60d1a634..5903da78fe1c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode} abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { - def keys: Seq[Key] = Seq.empty - private var _analyzed: Boolean = false /** @@ -78,6 +76,12 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { } } + /** + * The unique and foreign key constraints that will hold for the output of this plan. Specific + * plan nodes can override this to introduce or propagate keys. + */ + def keys: Seq[Key] = Seq.empty + /** * Computes [[Statistics]] for this plan. The default implementation assumes the output * cardinality is the product of of all child plan's cardinality, i.e. applies in the case diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index d00a10525f3af..9439b1efca303 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -23,31 +23,6 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet -case class KeyHint(newKeys: Seq[Key], child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output - - override def keys: Seq[Key] = newKeys ++ child.keys - - override lazy val resolved: Boolean = newKeys.forall(_.resolved) && childrenResolved - - def foreignKeyReferencesResolved: Boolean = newKeys.forall { - case ForeignKey(_, _, referencedAttr) => referencedAttr.resolved - case _ => true - } - - override def transformExpressionsDown( - rule: PartialFunction[Expression, Expression]): this.type = { - KeyHint(newKeys.map(_.transformAttribute(rule.andThen(_.asInstanceOf[Attribute]))), child) - .asInstanceOf[this.type] - } - - override def transformExpressionsUp( - rule: PartialFunction[Expression, Expression]): this.type = { - KeyHint(newKeys.map(_.transformAttribute(rule.andThen(_.asInstanceOf[Attribute]))), child) - .asInstanceOf[this.type] - } -} - case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) @@ -496,3 +471,33 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { childrenResolved && left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } } + +/** + * A hint to the optimizer that the given key constraints hold for the output of the child plan. 
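+ * The hint does not change what the child computes: it forwards the child's output unchanged,
+ * and the Remove Hints batch strips it via [[RemoveKeyHints]] once the hint-aware rules have
+ * run.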
+ */ +case class KeyHint(newKeys: Seq[Key], child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output + + override def keys: Seq[Key] = newKeys ++ child.keys + + override lazy val resolved: Boolean = newKeys.forall(_.resolved) && childrenResolved + + def foreignKeyReferencesResolved: Boolean = newKeys.forall { + case ForeignKey(_, _, referencedAttr) => referencedAttr.resolved + case _ => true + } + + /** Overridden here to apply `rule` to the keys as well as the child plan. */ + override def transformExpressionsDown( + rule: PartialFunction[Expression, Expression]): this.type = { + KeyHint(newKeys.map(_.transformAttribute(rule.andThen(_.asInstanceOf[Attribute]))), child) + .asInstanceOf[this.type] + } + + /** Overridden here to apply `rule` to the keys as well as the child plan. */ + override def transformExpressionsUp( + rule: PartialFunction[Expression, Expression]): this.type = { + KeyHint(newKeys.map(_.transformAttribute(rule.andThen(_.asInstanceOf[Attribute]))), child) + .asInstanceOf[this.type] + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/keys.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/keys.scala index 381d9ae764e79..e9ea9fe3eabbf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/keys.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/keys.scala @@ -19,11 +19,17 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions.Attribute +/** + * A key constraint on the output of a [[LogicalPlan]]. + */ sealed abstract class Key { def transformAttribute(rule: PartialFunction[Attribute, Attribute]): Key def resolved: Boolean } +/** + * Declares that the values of `attr` are unique. + */ case class UniqueKey(attr: Attribute) extends Key { override def transformAttribute(rule: PartialFunction[Attribute, Attribute]): Key = UniqueKey(rule.applyOrElse(attr, identity[Attribute])) @@ -31,7 +37,11 @@ case class UniqueKey(attr: Attribute) extends Key { override def resolved: Boolean = attr.resolved } -/** Referenced column must be unique. Referenced relation must already be resolved. */ +/** + * Declares that the values of `attr` reference `referencedAttr`, which is a unique key in + * `referencedRelation`. Note that the `referencedRelation` plan must contain a unique key + * constraint on `referencedAttr`, and it must be resolved. + */ case class ForeignKey( attr: Attribute, referencedRelation: LogicalPlan, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 56496366c727c..bb5315af10771 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -669,13 +669,17 @@ class DataFrame private[sql]( */ def as(alias: Symbol): DataFrame = as(alias.name) + /** + * Declares that the values of the given column are unique. + */ def uniqueKey(col: String): DataFrame = { KeyHintCollapsing(KeyHint(List(UniqueKey(UnresolvedAttribute(col))), logicalPlan)) } /** - * Declares a foreign key referencing a key from this or another DataFrame. The referenced key - * must be declared as a unique key. + * Declares that the values of the given column reference a unique column from another + * [[DataFrame]]. 
The referenced column must be declared as a unique key within the referenced + * [[DataFrame]]: * {{{ * val department = dept.uniqueKey("id") * employee.foreignKey("departmentId", department, "id") From 578797c456e20d0fb07bf10cb3e64f09065948f9 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Mon, 12 Oct 2015 21:38:46 -0700 Subject: [PATCH 17/23] Use SharedSQLContext in KeyHintSuite --- .../org/apache/spark/sql/KeyHintSuite.scala | 52 ++++++++----------- 1 file changed, 22 insertions(+), 30 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala index e7acc40fa0c31..bda1828b6b183 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.TestSQLContext -private object KeyHintTestData { +private object KeyHintSuite { case class Customer(id: Int, name: String) case class Employee(id: Int, name: String) case class Order(id: Int, customerId: Int, employeeId: Option[Int]) @@ -29,85 +30,76 @@ private object KeyHintTestData { case class BannedCustomer(name: String) } -private class KeyHintTestData(ctx: SQLContext) { - import ctx.implicits._ - import KeyHintTestData._ +class KeyHintSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + import KeyHintSuite._ - val customer = ctx.sparkContext.parallelize(Seq( + lazy val customer = sqlContext.sparkContext.parallelize(Seq( Customer(0, "alice"), Customer(1, "bob"), Customer(2, "alice"))).toDF() .uniqueKey("id") - val employee = ctx.sparkContext.parallelize(Seq( + lazy val employee = sqlContext.sparkContext.parallelize(Seq( Employee(0, "charlie"), Employee(1, "dan"))).toDF() .uniqueKey("id") - val order = ctx.sparkContext.parallelize(Seq( + lazy val order = sqlContext.sparkContext.parallelize(Seq( Order(0, 0, Some(0)), Order(1, 1, None))).toDF() .foreignKey("customerId", customer, "id") .foreignKey("employeeId", employee, "id") - val manager = ctx.sparkContext.parallelize(Seq( + lazy val manager = sqlContext.sparkContext.parallelize(Seq( Manager(0, 1))).toDF() .foreignKey("managerId", employee, "id") .foreignKey("subordinateId", employee, "id") - val bestFriend = { - val tmp = ctx.sparkContext.parallelize(Seq( + lazy val bestFriend = { + val tmp = sqlContext.sparkContext.parallelize(Seq( BestFriend(0, 1), BestFriend(1, 2), BestFriend(2, 0))).toDF() .uniqueKey("id") tmp.foreignKey("friendId", tmp, "id") } - val bannedCustomer = ctx.sparkContext.parallelize(Seq( + lazy val bannedCustomer = sqlContext.sparkContext.parallelize(Seq( BannedCustomer("alice"), BannedCustomer("eve"))).toDF() .uniqueKey("name") // Joins involving referential integrity (a foreign key referencing a unique key) - val orderInnerJoinView = order + lazy val orderInnerJoinView = order .join(customer, order("customerId") === customer("id")) .join(employee, order("employeeId") === employee("id")) - val orderLeftOuterJoinView = order + lazy val orderLeftOuterJoinView = order .join(customer, order("customerId") === customer("id"), "left_outer") .join(employee, order("employeeId") === employee("id"), "left_outer") - val orderRightOuterJoinView = employee.join( + lazy val orderRightOuterJoinView = employee.join( customer.join(order, order("customerId") === customer("id"), "right_outer"), 
order("employeeId") === employee("id"), "right_outer") - val orderCustomerFullOuterJoinView = order + lazy val orderCustomerFullOuterJoinView = order .join(customer, order("customerId") === customer("id"), "full_outer") - val orderEmployeeFullOuterJoinView = order + lazy val orderEmployeeFullOuterJoinView = order .join(employee, order("employeeId") === employee("id"), "full_outer") - val managerInnerJoinView = manager + lazy val managerInnerJoinView = manager .join(employee.as("emp_manager"), manager("managerId") === $"emp_manager.id") .join(employee.as("emp_subordinate"), manager("subordinateId") === $"emp_subordinate.id") - val bestFriendInnerJoinView = bestFriend + lazy val bestFriendInnerJoinView = bestFriend .join(bestFriend.as("bestFriend2"), bestFriend("friendId") === $"bestFriend2.id") // Joins involving only a unique key - val bannedCustomerInnerJoinView = customer + lazy val bannedCustomerInnerJoinView = customer .join(bannedCustomer, bannedCustomer("name") === customer("name")) - val bannedCustomerLeftOuterJoinView = customer + lazy val bannedCustomerLeftOuterJoinView = customer .join(bannedCustomer, bannedCustomer("name") === customer("name"), "left_outer") - val bannedCustomerFullOuterJoinView = customer + lazy val bannedCustomerFullOuterJoinView = customer .join(bannedCustomer, bannedCustomer("name") === customer("name"), "full_outer") -} - -class KeyHintSuite extends QueryTest { - - val ctx = new TestSQLContext() - private val data = new KeyHintTestData(ctx) - - import data._ - import ctx.implicits._ def checkJoinCount(df: DataFrame, joinCount: Int): Unit = { val joins = df.queryExecution.optimizedPlan.collect { From 7c7357bf9c1e8bab3f2d828dd8bc3d6f7d851196 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Mon, 12 Oct 2015 21:53:00 -0700 Subject: [PATCH 18/23] Remove long URLs They were references to the join elimination logic in Teradata, which is really just a standard optimization rule. --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 2 -- .../spark/sql/catalyst/optimizer/joinEliminationPatterns.scala | 2 -- 2 files changed, 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 35984552d1609..9578b62d77b10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -350,8 +350,6 @@ object KeyHintCollapsing extends Rule[LogicalPlan] { /** * Eliminates keyed equi-joins when followed by a [[Project]] that only keeps columns from one side. - * - * See [[http://www.info.teradata.com/HTMLPubs/DB_TTU_14_00/index.html#page/SQL_Reference/B035_1142_111A/ch02.124.042.html#ww17434326]]. */ object JoinElimination extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala index ad26169f99787..668bc5a4a8dcc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala @@ -63,8 +63,6 @@ object CanEliminateUniqueKeyOuterJoin { * containing the referenced unique key is referred to as the parent table. 
* * For inner joins, all involved foreign keys must be non-nullable. - * - * See [[http://www.info.teradata.com/HTMLPubs/DB_TTU_14_00/index.html#page/SQL_Reference/B035_1142_111A/ch02.124.045.html]]. */ object CanEliminateReferentialIntegrityJoin { /** (parent, child, primaryForeignMap, projectList) */ From 50717599f1eb5bf2184a6b1df2e0aebabdebddec Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Tue, 13 Oct 2015 00:36:00 -0700 Subject: [PATCH 19/23] Fix override of KeyHint#transformExpressions{Up,Down} --- .../sql/catalyst/plans/logical/basicOperators.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 9cc24f98f1152..84e398ff65cd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -485,14 +485,16 @@ case class KeyHint(newKeys: Seq[Key], child: LogicalPlan) extends UnaryNode { /** Overridden here to apply `rule` to the keys as well as the child plan. */ override def transformExpressionsDown( rule: PartialFunction[Expression, Expression]): this.type = { - KeyHint(newKeys.map(_.transformAttribute(rule.andThen(_.asInstanceOf[Attribute]))), child) - .asInstanceOf[this.type] + KeyHint( + newKeys.map(_.transformAttribute(rule.andThen(_.asInstanceOf[Attribute]))), + child.transformExpressionsDown(rule)).asInstanceOf[this.type] } /** Overridden here to apply `rule` to the keys as well as the child plan. */ override def transformExpressionsUp( rule: PartialFunction[Expression, Expression]): this.type = { - KeyHint(newKeys.map(_.transformAttribute(rule.andThen(_.asInstanceOf[Attribute]))), child) - .asInstanceOf[this.type] + KeyHint( + newKeys.map(_.transformAttribute(rule.andThen(_.asInstanceOf[Attribute]))), + child.transformExpressionsUp(rule)).asInstanceOf[this.type] } } From ec2b80bff89c79856bed68e4ed367eee0cacf8d2 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Tue, 13 Oct 2015 00:37:17 -0700 Subject: [PATCH 20/23] Declare new DataFrame methods extra-experimental --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 7f3cb353e4833..f4cfabc7634b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -689,13 +689,16 @@ class DataFrame private[sql]( def as(alias: Symbol): DataFrame = as(alias.name) /** + * :: Experimental :: * Declares that the values of the given column are unique. */ + @Experimental def uniqueKey(col: String): DataFrame = { KeyHintCollapsing(KeyHint(List(UniqueKey(UnresolvedAttribute(col))), logicalPlan)) } /** + * :: Experimental :: * Declares that the values of the given column reference a unique column from another * [[DataFrame]]. 
The referenced column must be declared as a unique key within the referenced * [[DataFrame]]: @@ -704,6 +707,7 @@ class DataFrame private[sql]( * employee.foreignKey("departmentId", department, "id") * }}} */ + @Experimental def foreignKey(col: String, referencedDF: DataFrame, referencedCol: String): DataFrame = KeyHintCollapsing( KeyHint(List(ForeignKey( From 55bb1354efcef98944caf96f8d59dc2f4a6459c0 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Tue, 13 Oct 2015 00:50:22 -0700 Subject: [PATCH 21/23] Explain why we keep old keys in self-join rewrite --- .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 697f2faeceabb..c20806dfc29ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -383,8 +383,8 @@ class Analyzer( attr, applyRewrites(referencedRelation), attributeRewrites.get(referencedAttr).getOrElse(referencedAttr)) - case other => other } + // Keep the old keys as well to accommodate future self-joins KeyHint((keys ++ newKeys).distinct, child) } } else { From e1ec23da83d02adafbe1fdc7852e258f9289d293 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Wed, 14 Oct 2015 17:14:45 -0700 Subject: [PATCH 22/23] Revert "Fix override of KeyHint#transformExpressions{Up,Down}" This reverts commit 50717599f1eb5bf2184a6b1df2e0aebabdebddec. --- .../sql/catalyst/plans/logical/basicOperators.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 84e398ff65cd8..9cc24f98f1152 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -485,16 +485,14 @@ case class KeyHint(newKeys: Seq[Key], child: LogicalPlan) extends UnaryNode { /** Overridden here to apply `rule` to the keys as well as the child plan. */ override def transformExpressionsDown( rule: PartialFunction[Expression, Expression]): this.type = { - KeyHint( - newKeys.map(_.transformAttribute(rule.andThen(_.asInstanceOf[Attribute]))), - child.transformExpressionsDown(rule)).asInstanceOf[this.type] + KeyHint(newKeys.map(_.transformAttribute(rule.andThen(_.asInstanceOf[Attribute]))), child) + .asInstanceOf[this.type] } /** Overridden here to apply `rule` to the keys as well as the child plan. 
*/ override def transformExpressionsUp( rule: PartialFunction[Expression, Expression]): this.type = { - KeyHint( - newKeys.map(_.transformAttribute(rule.andThen(_.asInstanceOf[Attribute]))), - child.transformExpressionsUp(rule)).asInstanceOf[this.type] + KeyHint(newKeys.map(_.transformAttribute(rule.andThen(_.asInstanceOf[Attribute]))), child) + .asInstanceOf[this.type] } } From 0cd8a9185e96d0f21a8bd9a437c124566b9f2ce1 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Wed, 14 Oct 2015 17:23:49 -0700 Subject: [PATCH 23/23] Update transformExpressions override comments --- .../spark/sql/catalyst/plans/logical/basicOperators.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 9cc24f98f1152..260c107b8eb10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -482,14 +482,14 @@ case class KeyHint(newKeys: Seq[Key], child: LogicalPlan) extends UnaryNode { case _ => true } - /** Overridden here to apply `rule` to the keys as well as the child plan. */ + /** Overridden here to apply `rule` to the keys. */ override def transformExpressionsDown( rule: PartialFunction[Expression, Expression]): this.type = { KeyHint(newKeys.map(_.transformAttribute(rule.andThen(_.asInstanceOf[Attribute]))), child) .asInstanceOf[this.type] } - /** Overridden here to apply `rule` to the keys as well as the child plan. */ + /** Overridden here to apply `rule` to the keys. */ override def transformExpressionsUp( rule: PartialFunction[Expression, Expression]): this.type = { KeyHint(newKeys.map(_.transformAttribute(rule.andThen(_.asInstanceOf[Attribute]))), child)
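
Taken together, the API this series settles on reads as in the sketch below. The sketch is illustrative only: `dept` and `emp` are hypothetical relations and the surrounding `sqlContext` is assumed, while `uniqueKey`, `foreignKey`, and the join-elimination behavior are the pieces introduced by the patches above.

    // Hypothetical relations; uniqueKey/foreignKey are the DataFrame methods added above.
    import sqlContext.implicits._

    val dept = sqlContext.sparkContext.parallelize(Seq((0, "eng"), (1, "ops")))
      .toDF("id", "name")
      .uniqueKey("id")                  // hint: dept("id") is unique
    val emp = sqlContext.sparkContext.parallelize(Seq(("alice", 0), ("bob", 1)))
      .toDF("name", "deptId")
      .foreignKey("deptId", dept, "id") // hint: emp("deptId") references dept("id")

    // deptId is a non-nullable foreign key and the projection keeps only emp's columns,
    // so JoinElimination can rewrite this plan to scan emp alone.
    val names = emp.join(dept, emp("deptId") === dept("id")).select(emp("name"))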