diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 899ee67352df..ca527d2d8127 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -383,14 +383,35 @@ class Analyzer( j case Some((oldRelation, newRelation)) => val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) - val newRight = right transformUp { - case r if r == oldRelation => newRelation - } transformUp { - case other => other transformExpressions { - case a: Attribute => attributeRewrites.get(a).getOrElse(a) + def applyRewrites(plan: LogicalPlan): LogicalPlan = + plan transformUp { + case r if r == oldRelation => newRelation + } transformUp { + case other => other transformExpressions { + case a: Attribute => attributeRewrites.get(a).getOrElse(a) + } } - } - j.copy(right = newRight) + val newRight = applyRewrites(right) + // Also apply the rewrites to foreign keys on the left side, because these are meant to + // reference the right side. + val newLeft = + if (left.keys.nonEmpty) { + left.transform { + case KeyHint(keys, child) => + val newKeys = keys.collect { + case ForeignKey(attr, referencedRelation, referencedAttr) => + ForeignKey( + attr, + applyRewrites(referencedRelation), + attributeRewrites.get(referencedAttr).getOrElse(referencedAttr)) + } + // Keep the old keys as well to accommodate future self-joins + KeyHint((keys ++ newKeys).distinct, child) + } + } else { + left + } + j.copy(left = newLeft, right = newRight) } // When resolve `SortOrder`s in Sort based on child, don't report errors as @@ -399,6 +420,32 @@ class Analyzer( val newOrdering = resolveSortOrders(ordering, child, throws = false) Sort(newOrdering, global, child) + // Resolve referenced attributes of foreign keys using the referenced relation + case h @ KeyHint(keys, child) if child.resolved && !h.foreignKeyReferencesResolved => + KeyHint(keys.map { + case ForeignKey(k, r, u @ UnresolvedAttribute(nameParts)) => withPosition(u) { + // The referenced relation r is itself guaranteed to be resolved already, so we can + // resolve u against it + val referencedAttr = r.resolve(nameParts, resolver).getOrElse(u).toAttribute + + // Enforce the constraint that foreign keys can only reference unique keys + if (referencedAttr.resolved) { + val referencedAttrIsUnique = r.keys.exists { + case UniqueKey(attr) if attr semanticEquals referencedAttr => true + case _ => false + } + if (!referencedAttrIsUnique) { + failAnalysis("Foreign keys can only reference unique keys, but " + + s"$k references $referencedAttr which is not unique.") + } + } + + ForeignKey(k, r, referencedAttr) + } + + case otherKey => otherKey + }, child) + // A special case for Generate, because the output of Generate should not be resolved by // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate. 
case g @ Generate(generator, join, outer, qualifier, output, child) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 338c5193cb7a..40cc0f62267a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -21,6 +21,7 @@ import scala.collection.immutable.HashSet import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.FullOuter import org.apache.spark.sql.catalyst.plans.LeftOuter import org.apache.spark.sql.catalyst.plans.RightOuter @@ -32,14 +33,7 @@ import org.apache.spark.sql.types._ abstract class Optimizer extends RuleExecutor[LogicalPlan] object DefaultOptimizer extends Optimizer { - val batches = - // SubQueries are only needed for analysis and can be removed before execution. - Batch("Remove SubQueries", FixedPoint(100), - EliminateSubQueries) :: - Batch("Aggregate", FixedPoint(100), - ReplaceDistinctWithAggregate, - RemoveLiteralFromGroupExpressions) :: - Batch("Operator Optimizations", FixedPoint(100), + val operatorOptimizations = Seq( // Operator push down SetOperationPushDown, SamplePushDown, @@ -50,6 +44,8 @@ object DefaultOptimizer extends Optimizer { ColumnPruning, // Operator combine ProjectCollapsing, + KeyHintCollapsing, + JoinElimination, CombineFilters, CombineLimits, // Constant folding @@ -61,7 +57,21 @@ object DefaultOptimizer extends Optimizer { RemoveDispensable, SimplifyFilters, SimplifyCasts, - SimplifyCaseConversionExpressions) :: + SimplifyCaseConversionExpressions) + + val batches = + // SubQueries are only needed for analysis and can be removed before execution. + Batch("Remove SubQueries", FixedPoint(100), + EliminateSubQueries) :: + Batch("Aggregate", FixedPoint(100), + ReplaceDistinctWithAggregate, + RemoveLiteralFromGroupExpressions) :: + // Hints are necessary for some operator optimizations but interfere with others, so we run the + // rules with them, then remove them and run the rules again. + Batch("Operator Optimizations", FixedPoint(100), + operatorOptimizations: _*) :: + Batch("Remove Hints", FixedPoint(100), + (RemoveKeyHints +: operatorOptimizations): _*) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: Batch("LocalRelation", FixedPoint(100), @@ -325,6 +335,46 @@ object ProjectCollapsing extends Rule[LogicalPlan] { } } +/** + * Combines two adjacent [[KeyHint]]s into one by merging their key lists. + */ +object KeyHintCollapsing extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case KeyHint(keys1, KeyHint(keys2, child)) => + KeyHint((keys1 ++ keys2).distinct, child) + } +} + +/** + * Eliminates keyed equi-joins when followed by a [[Project]] that only keeps columns from one side. 
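+ *
+ * A sketch of the intended rewrite, reusing the relations defined in JoinEliminationSuite
+ * (the names are illustrative, not part of the rule):
+ * {{{
+ *   // bannedCustomer carries UniqueKey('bannedCustomerName), so this left outer join cannot
+ *   // duplicate customer rows; only customer columns survive the Project, so the join is
+ *   // dropped and the query is answered from customer alone:
+ *   customer
+ *     .join(bannedCustomer, LeftOuter, Some('customerName === 'bannedCustomerName))
+ *     .select('customerId, 'customerName)
+ *   // ==> customer.select('customerId, 'customerName)
+ * }}}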
+ */ +object JoinElimination extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case CanEliminateUniqueKeyOuterJoin(outer, projectList) => + Project(projectList, outer) + case CanEliminateReferentialIntegrityJoin(parent, child, primaryForeignMap, projectList) => + Project(substituteParentForChild(projectList, parent, primaryForeignMap), child) + } + + /** + * In the given expressions, substitute all references to parent columns with references to the + * corresponding child columns. The `primaryForeignMap` contains these equivalences, extracted + * from the equality join expressions. + */ + private def substituteParentForChild( + expressions: Seq[NamedExpression], + parent: LogicalPlan, + primaryForeignMap: AttributeMap[Attribute]) + : Seq[NamedExpression] = { + expressions.map(_.transform { + case a: Attribute => + if (parent.outputSet.contains(a)) Alias(primaryForeignMap(a), a.name)(a.exprId) + else a + }.asInstanceOf[NamedExpression]) + } + +} + /** * Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition. * For example, when the expression is just checking to see if a string starts with a given @@ -844,6 +894,15 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { } } +/** + * Removes [[KeyHint]]s from the plan to avoid interfering with other rules. + */ +object RemoveKeyHints extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case KeyHint(_, child) => child + } +} + /** * Speeds up aggregates on fixed-precision decimals by executing them on unscaled Long values. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala new file mode 100644 index 000000000000..668bc5a4a8dc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joinEliminationPatterns.scala @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ + +/** + * Finds left or right outer joins where only the outer table's columns are kept, and a key from the + * inner table is involved in the join so no duplicates would be generated. 
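+ *
+ * For example (attribute names are illustrative): in
+ * `outer LEFT OUTER JOIN inner ON outer.a = inner.k`, where `inner.k` carries a [[UniqueKey]]
+ * hint, each `outer` row matches at most one `inner` row, so a following [[Project]] that keeps
+ * only `outer` columns can be computed from `outer` directly.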
+ */ +object CanEliminateUniqueKeyOuterJoin { + /** (outer, projectList) */ + type ReturnType = (LogicalPlan, Seq[NamedExpression]) + + def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { + case p @ Project(projectList, + ExtractEquiJoinKeys( + joinType @ (LeftOuter | RightOuter), leftJoinExprs, rightJoinExprs, _, left, right)) => + val (outer, inner, innerJoinExprs) = (joinType: @unchecked) match { + case LeftOuter => (left, right, rightJoinExprs) + case RightOuter => (right, left, leftJoinExprs) + } + + val onlyOuterColsKept = AttributeSet(projectList).subsetOf(outer.outputSet) + + val innerUniqueKeys = AttributeSet(inner.keys.collect { case UniqueKey(attr) => attr }) + val innerKeyIsInvolved = innerUniqueKeys.intersect(AttributeSet(innerJoinExprs)).nonEmpty + + if (onlyOuterColsKept && innerKeyIsInvolved) { + Some((outer, projectList)) + } else { + None + } + + case _ => None + } +} + +/** + * Finds joins based on foreign-key referential integrity, followed by [[Project]]s that reference + * no columns from the parent table other than the referenced unique keys. Such joins can be + * eliminated and replaced by the child table. + * + * The table containing the foreign key is referred to as the child table, while the table + * containing the referenced unique key is referred to as the parent table. + * + * For inner joins, all involved foreign keys must be non-nullable. + */ +object CanEliminateReferentialIntegrityJoin { + /** (parent, child, primaryForeignMap, projectList) */ + type ReturnType = + (LogicalPlan, LogicalPlan, AttributeMap[Attribute], Seq[NamedExpression]) + + def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { + case p @ Project(projectList, ExtractEquiJoinKeys( + joinType @ (Inner | LeftOuter | RightOuter), + leftJoinExprs, rightJoinExprs, _, left, right)) => + val innerJoin = joinType == Inner + + val leftParentPFM = getPrimaryForeignMap(left, right, leftJoinExprs, rightJoinExprs) + val rightForeignKeysAreNonNullable = leftParentPFM.values.forall(!_.nullable) + val leftIsParent = + (leftParentPFM.nonEmpty && onlyPrimaryKeysKept(projectList, leftParentPFM, left) + && (!innerJoin || rightForeignKeysAreNonNullable)) + + val rightParentPFM = getPrimaryForeignMap(right, left, rightJoinExprs, leftJoinExprs) + val leftForeignKeysAreNonNullable = rightParentPFM.values.forall(!_.nullable) + val rightIsParent = + (rightParentPFM.nonEmpty && onlyPrimaryKeysKept(projectList, rightParentPFM, right) + && (!innerJoin || leftForeignKeysAreNonNullable)) + + if (leftIsParent) { + Some((left, right, leftParentPFM, projectList)) + } else if (rightIsParent) { + Some((right, left, rightParentPFM, projectList)) + } else { + None + } + + case _ => None + } + + /** + * Return a map where, for each PK=FK join expression based on referential integrity between + * `parent` and `child`, the unique key from `parent` is mapped to its corresponding foreign + * key from `child`. 
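+   *
+   * For example (names borrowed from JoinEliminationSuite, for illustration only): given the
+   * join expression `customer.customerId = order.o_customerId`, where `order.o_customerId` is
+   * a [[ForeignKey]] referencing the [[UniqueKey]] `customer.customerId`, the returned map
+   * contains `customer.customerId -> order.o_customerId`.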
+ */ + private def getPrimaryForeignMap( + parent: LogicalPlan, + child: LogicalPlan, + parentJoinExprs: Seq[Expression], + childJoinExprs: Seq[Expression]) + : AttributeMap[Attribute] = { + val primaryKeys = AttributeSet(parent.keys.collect { case UniqueKey(attr) => attr }) + val foreignKeys = new ForeignKeyFinder(child, parent) + AttributeMap(parentJoinExprs.zip(childJoinExprs).collect { + case (parentExpr: NamedExpression, childExpr: NamedExpression) + if primaryKeys.contains(parentExpr.toAttribute) + && foreignKeys.foreignKeyExists(childExpr.toAttribute, parentExpr.toAttribute) => + (parentExpr.toAttribute, childExpr.toAttribute) + }) + } + + /** + * Return true if `kept` references no columns from `parent` except those involved in a PK=FK + * join expression. Such join expressions are stored in `primaryForeignMap`. + */ + private def onlyPrimaryKeysKept( + kept: Seq[NamedExpression], + primaryForeignMap: AttributeMap[Attribute], + parent: LogicalPlan) + : Boolean = { + AttributeSet(kept).forall { keptAttr => + if (parent.outputSet.contains(keptAttr)) { + primaryForeignMap.contains(keptAttr) + } else { + true + } + } + } +} + +private class ForeignKeyFinder(plan: LogicalPlan, referencedPlan: LogicalPlan) { + val equivalent = equivalences(referencedPlan) + + def foreignKeyExists(attr: Attribute, referencedAttr: Attribute): Boolean = { + plan.keys.exists { + case ForeignKey(attr2, _, referencedAttr2) + if (attr semanticEquals attr2) + && equivalent.query(referencedAttr, referencedAttr2) => true + case _ => false + } + } + + private def equivalences(plan: LogicalPlan): MutableDisjointAttributeSets = { + val s = new MutableDisjointAttributeSets + plan.collect { + case Project(projectList, _) => projectList.collect { + case a @ Alias(old: Attribute, _) => s.union(old, a.toAttribute) + } + } + s + } +} + +private class MutableDisjointAttributeSets() { + private var sets = Set[AttributeSet]() + def add(x: Attribute): Unit = { + if (!sets.exists(_.contains(x))) { + sets += AttributeSet(x) + } + } + def union(x: Attribute, y: Attribute): Unit = { + add(x) + add(y) + val xSet = sets.find(_.contains(x)).get + val ySet = sets.find(_.contains(y)).get + sets -= xSet + sets -= ySet + sets += (xSet ++ ySet) + } + def query(x: Attribute, y: Attribute): Boolean = { + (x semanticEquals y) || sets.exists(s => s.contains(x) && s.contains(y)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 8f8747e10593..7c680b7c9bcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -76,6 +76,12 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { } } + /** + * The unique and foreign key constraints that will hold for the output of this plan. Specific + * plan nodes can override this to introduce or propagate keys. + */ + def keys: Seq[Key] = Seq.empty + /** * Computes [[Statistics]] for this plan. The default implementation assumes the output * cardinality is the product of of all child plan's cardinality, i.e. 
applies in the case
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 4cb67aacf33e..6904205ad657 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -27,6 +27,19 @@ import org.apache.spark.util.collection.OpenHashSet
 case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = projectList.map(_.toAttribute)

+  override def keys: Seq[Key] = {
+    val aliasMap = AttributeMap(projectList.collect {
+      case a @ Alias(old: AttributeReference, _) => (old, a.toAttribute)
+      case r: AttributeReference => (r, r)
+    })
+    child.keys.collect {
+      case UniqueKey(attr) if aliasMap.contains(attr) =>
+        UniqueKey(aliasMap(attr))
+      case ForeignKey(attr, referencedRelation, referencedAttr) if aliasMap.contains(attr) =>
+        ForeignKey(aliasMap(attr), referencedRelation, referencedAttr)
+    }
+  }
+
   override lazy val resolved: Boolean = {
     val hasSpecialExpressions = projectList.exists ( _.collect {
         case agg: AggregateExpression => agg
@@ -137,6 +150,16 @@ case class Join(
     }
   }

+  override def keys: Seq[Key] = {
+    // TODO: try to propagate unique keys as well as foreign keys
+    def fk(keys: Seq[Key]): Seq[ForeignKey] = keys.collect { case k: ForeignKey => k }
+    joinType match {
+      case LeftSemi | LeftOuter => fk(left.keys)
+      case RightOuter => fk(right.keys)
+      case _ => fk(left.keys ++ right.keys)
+    }
+  }
+
   def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty

   // Joins are only resolved if they don't introduce ambiguous expression ids.
@@ -388,6 +411,7 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {

 case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil))
+  override def keys: Seq[Key] = child.keys
 }

 /**
@@ -445,6 +469,36 @@ case object OneRowRelation extends LeafNode {
 }

 /**
+ * A hint to the optimizer that the given key constraints hold for the output of the child plan.
+ */
+case class KeyHint(newKeys: Seq[Key], child: LogicalPlan) extends UnaryNode {
+  override def output: Seq[Attribute] = child.output
+
+  override def keys: Seq[Key] = newKeys ++ child.keys
+
+  override lazy val resolved: Boolean = newKeys.forall(_.resolved) && childrenResolved
+
+  def foreignKeyReferencesResolved: Boolean = newKeys.forall {
+    case ForeignKey(_, _, referencedAttr) => referencedAttr.resolved
+    case _ => true
+  }
+
+  /** Overridden here to apply `rule` to the keys. */
+  override def transformExpressionsDown(
+      rule: PartialFunction[Expression, Expression]): this.type = {
+    KeyHint(newKeys.map(_.transformAttribute(rule.andThen(_.asInstanceOf[Attribute]))), child)
+      .asInstanceOf[this.type]
+  }
+
+  /** Overridden here to apply `rule` to the keys. */
+  override def transformExpressionsUp(
+      rule: PartialFunction[Expression, Expression]): this.type = {
+    KeyHint(newKeys.map(_.transformAttribute(rule.andThen(_.asInstanceOf[Attribute]))), child)
+      .asInstanceOf[this.type]
+  }
+}
+
+/**
  * A relation produced by applying `func` to each partition of the `child`.
tEncoder/uEncoder are * used respectively to decode/encode from the JVM object representation expected by `func.` */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/keys.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/keys.scala new file mode 100644 index 000000000000..e9ea9fe3eabb --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/keys.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute + +/** + * A key constraint on the output of a [[LogicalPlan]]. + */ +sealed abstract class Key { + def transformAttribute(rule: PartialFunction[Attribute, Attribute]): Key + def resolved: Boolean +} + +/** + * Declares that the values of `attr` are unique. + */ +case class UniqueKey(attr: Attribute) extends Key { + override def transformAttribute(rule: PartialFunction[Attribute, Attribute]): Key = + UniqueKey(rule.applyOrElse(attr, identity[Attribute])) + + override def resolved: Boolean = attr.resolved +} + +/** + * Declares that the values of `attr` reference `referencedAttr`, which is a unique key in + * `referencedRelation`. Note that the `referencedRelation` plan must contain a unique key + * constraint on `referencedAttr`, and it must be resolved. + */ +case class ForeignKey( + attr: Attribute, + referencedRelation: LogicalPlan, + referencedAttr: Attribute) extends Key { + assert(referencedRelation.resolved) + + override def transformAttribute(rule: PartialFunction[Attribute, Attribute]): Key = + ForeignKey(rule.applyOrElse(attr, identity[Attribute]), referencedRelation, referencedAttr) + + override def resolved: Boolean = attr.resolved && referencedAttr.resolved +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala new file mode 100644 index 000000000000..5426f96aab8c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinEliminationSuite.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.catalyst.plans.FullOuter +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.RightOuter +import org.apache.spark.sql.catalyst.plans.logical.ForeignKey +import org.apache.spark.sql.catalyst.plans.logical.KeyHint +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.UniqueKey +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class JoinEliminationSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", FixedPoint(10), EliminateSubQueries) :: + Batch("JoinElimination", Once, JoinElimination) :: Nil + } + + val customer = { + val r = LocalRelation('customerId.int.notNull, 'customerName.string) + KeyHint(List(UniqueKey(r.output(0))), r) + } + val employee = { + val r = LocalRelation('employeeId.int.notNull, 'employeeName.string) + KeyHint(List(UniqueKey(r.output(0))), r) + } + val order = { + val r = LocalRelation( + 'orderId.int.notNull, 'o_customerId.int.notNull, 'o_employeeId.int) + KeyHint(List( + UniqueKey(r.output(0)), + ForeignKey(r.output(1), customer, customer.output(0)), + ForeignKey(r.output(2), employee, employee.output(0))), r) + } + val bannedCustomer = { + val r = LocalRelation('bannedCustomerName.string.notNull) + KeyHint(List(UniqueKey(r.output(0))), r) + } + + def checkJoinEliminated( + base: LogicalPlan, + join: LogicalPlan => LogicalPlan, + project: LogicalPlan => LogicalPlan, + projectAfterElimination: LogicalPlan => LogicalPlan): Unit = { + val query = project(join(base)) + val optimized = Optimize.execute(query.analyze) + val correctAnswer = projectAfterElimination(base).analyze + comparePlans(optimized, correctAnswer) + } + + def checkJoinEliminated( + base: LogicalPlan, + join: LogicalPlan => LogicalPlan, + project: LogicalPlan => LogicalPlan): Unit = { + checkJoinEliminated(base, join, project, project) + } + + def checkJoinNotEliminated( + base: LogicalPlan, + join: LogicalPlan => LogicalPlan, + project: LogicalPlan => LogicalPlan): Unit = { + val query = project(join(base)) + val optimized = Optimize.execute(query.analyze) + val correctAnswer = query.analyze + comparePlans(optimized, correctAnswer) + } + + test("eliminate unique key left outer join") { + checkJoinEliminated( + customer, + _.join(bannedCustomer, LeftOuter, Some('customerName === 'bannedCustomerName)), + _.select('customerId, 'customerName)) + } + + test("do not eliminate unique key inner join") { + checkJoinNotEliminated( + customer, + _.join(bannedCustomer, Inner, Some('customerName === 'bannedCustomerName)), + _.select('customerId, 'customerName)) + } + + test("do not eliminate unique key 
full outer join") { + checkJoinNotEliminated( + customer, + _.join(bannedCustomer, FullOuter, Some('customerName === 'bannedCustomerName)), + _.select('customerId, 'customerName)) + } + + test("do not eliminate referential integrity inner join where foreign key is nullable") { + checkJoinNotEliminated( + order, + _.join(employee, Inner, Some('employeeId === 'o_employeeId)), + _.select('orderId, 'employeeId)) + } + + test("eliminate referential integrity inner join when foreign key is not null") { + checkJoinEliminated( + order, + _.join(customer, Inner, Some('customerId === 'o_customerId)), + _.select('orderId, 'customerId), + _.select('orderId, 'o_customerId.as('customerId))) + } + + test("eliminate referential integrity left/right outer join when foreign key is not null") { + checkJoinEliminated( + order, + _.join(customer, LeftOuter, Some('customerId === 'o_customerId)), + _.select('orderId, 'customerId), + _.select('orderId, 'o_customerId.as('customerId))) + + checkJoinEliminated( + order, + customer.join(_, RightOuter, Some('customerId === 'o_customerId)), + _.select('orderId, 'customerId), + _.select('orderId, 'o_customerId.as('customerId))) + } + + test("do not eliminate referential integrity full outer join") { + checkJoinNotEliminated( + order, + _.join(customer, FullOuter, Some('customerId === 'o_customerId)), + _.select('orderId, 'customerId)) + } + + test("eliminate referential integrity outer join despite alias") { + checkJoinEliminated( + order, + _.join(customer.select('customerId.as('customerId_alias), 'customerName), + LeftOuter, Some('customerId_alias === 'o_customerId)), + _.select('orderId, 'customerId_alias), + _.select('orderId, 'o_customerId.as('customerId_alias))) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index f2d4db555027..dbe3ef83e7fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.encoders.Encoder +import org.apache.spark.sql.catalyst.optimizer.KeyHintCollapsing import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} @@ -720,6 +721,36 @@ class DataFrame private[sql]( */ def as(alias: Symbol): DataFrame = as(alias.name) + /** + * :: Experimental :: + * Declares that the values of the given column are unique. + */ + @Experimental + def uniqueKey(col: String): DataFrame = withPlan { + KeyHintCollapsing(KeyHint(List(UniqueKey(UnresolvedAttribute(col))), logicalPlan)) + } + + /** + * :: Experimental :: + * Declares that the values of the given column reference a unique column from another + * [[DataFrame]]. 
The referenced column must be declared as a unique key within the referenced + * [[DataFrame]]: + * {{{ + * val department = dept.uniqueKey("id") + * employee.foreignKey("departmentId", department, "id") + * }}} + */ + @Experimental + def foreignKey(col: String, referencedDF: DataFrame, referencedCol: String): DataFrame = + withPlan { + KeyHintCollapsing( + KeyHint(List(ForeignKey( + UnresolvedAttribute(col), + referencedDF.logicalPlan, + UnresolvedAttribute(referencedCol))), + logicalPlan)) + } + /** * Returns a new [[DataFrame]] with an alias set. Same as `as`. * @group dfops diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f4464e0b916f..4811e55948c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -461,6 +461,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "PhysicalRDD") :: Nil case BroadcastHint(child) => apply(child) + case logical.KeyHint(_, child) => apply(child) case _ => Nil } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala new file mode 100644 index 000000000000..bda1828b6b18 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/KeyHintSuite.scala @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.TestSQLContext + +private object KeyHintSuite { + case class Customer(id: Int, name: String) + case class Employee(id: Int, name: String) + case class Order(id: Int, customerId: Int, employeeId: Option[Int]) + case class Manager(managerId: Int, subordinateId: Int) + case class BestFriend(id: Int, friendId: Int) + case class BannedCustomer(name: String) +} + +class KeyHintSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + import KeyHintSuite._ + + lazy val customer = sqlContext.sparkContext.parallelize(Seq( + Customer(0, "alice"), + Customer(1, "bob"), + Customer(2, "alice"))).toDF() + .uniqueKey("id") + lazy val employee = sqlContext.sparkContext.parallelize(Seq( + Employee(0, "charlie"), + Employee(1, "dan"))).toDF() + .uniqueKey("id") + lazy val order = sqlContext.sparkContext.parallelize(Seq( + Order(0, 0, Some(0)), + Order(1, 1, None))).toDF() + .foreignKey("customerId", customer, "id") + .foreignKey("employeeId", employee, "id") + lazy val manager = sqlContext.sparkContext.parallelize(Seq( + Manager(0, 1))).toDF() + .foreignKey("managerId", employee, "id") + .foreignKey("subordinateId", employee, "id") + lazy val bestFriend = { + val tmp = sqlContext.sparkContext.parallelize(Seq( + BestFriend(0, 1), + BestFriend(1, 2), + BestFriend(2, 0))).toDF() + .uniqueKey("id") + tmp.foreignKey("friendId", tmp, "id") + } + lazy val bannedCustomer = sqlContext.sparkContext.parallelize(Seq( + BannedCustomer("alice"), + BannedCustomer("eve"))).toDF() + .uniqueKey("name") + + // Joins involving referential integrity (a foreign key referencing a unique key) + lazy val orderInnerJoinView = order + .join(customer, order("customerId") === customer("id")) + .join(employee, order("employeeId") === employee("id")) + + lazy val orderLeftOuterJoinView = order + .join(customer, order("customerId") === customer("id"), "left_outer") + .join(employee, order("employeeId") === employee("id"), "left_outer") + + lazy val orderRightOuterJoinView = employee.join( + customer.join(order, order("customerId") === customer("id"), "right_outer"), + order("employeeId") === employee("id"), "right_outer") + + lazy val orderCustomerFullOuterJoinView = order + .join(customer, order("customerId") === customer("id"), "full_outer") + + lazy val orderEmployeeFullOuterJoinView = order + .join(employee, order("employeeId") === employee("id"), "full_outer") + + lazy val managerInnerJoinView = manager + .join(employee.as("emp_manager"), manager("managerId") === $"emp_manager.id") + .join(employee.as("emp_subordinate"), manager("subordinateId") === $"emp_subordinate.id") + + lazy val bestFriendInnerJoinView = bestFriend + .join(bestFriend.as("bestFriend2"), bestFriend("friendId") === $"bestFriend2.id") + + // Joins involving only a unique key + lazy val bannedCustomerInnerJoinView = customer + .join(bannedCustomer, bannedCustomer("name") === customer("name")) + + lazy val bannedCustomerLeftOuterJoinView = customer + .join(bannedCustomer, bannedCustomer("name") === customer("name"), "left_outer") + + lazy val bannedCustomerFullOuterJoinView = customer + .join(bannedCustomer, bannedCustomer("name") === customer("name"), "full_outer") + + def checkJoinCount(df: DataFrame, joinCount: Int): Unit = { + val joins = df.queryExecution.optimizedPlan.collect { + case j: Join => j + } + assert(joins.size == joinCount) + } + + def 
checkJoinsEliminated(df: DataFrame): Unit = checkJoinCount(df, 0) + + test("no elimination") { + val orderInnerJoin = orderInnerJoinView + .select(order("id"), order("customerId"), customer("name"), + order("employeeId"), employee("name")) + checkAnswer(orderInnerJoin, Seq( + Row(0, 0, "alice", 0, "charlie"))) + + val orderLeftOuterJoin = orderLeftOuterJoinView + .select(order("id"), order("customerId"), customer("name"), + order("employeeId"), employee("name")) + checkAnswer(orderLeftOuterJoin, Seq( + Row(0, 0, "alice", 0, "charlie"), + Row(1, 1, "bob", null, null))) + + val orderRightOuterJoin = orderRightOuterJoinView + .select(order("id"), order("customerId"), customer("name"), + order("employeeId"), employee("name")) + checkAnswer(orderRightOuterJoin, Seq( + Row(0, 0, "alice", 0, "charlie"), + Row(1, 1, "bob", null, null))) + + val orderCustomerFullOuterJoin = orderCustomerFullOuterJoinView + .select(order("id"), customer("id"), customer("name")) + checkAnswer(orderCustomerFullOuterJoin, Seq( + Row(0, 0, "alice"), + Row(1, 1, "bob"), + Row(null, 2, "alice"))) + + val orderEmployeeFullOuterJoin = orderEmployeeFullOuterJoinView + .select(order("id"), employee("id"), employee("name")) + checkAnswer(orderEmployeeFullOuterJoin, Seq( + Row(0, 0, "charlie"), + Row(1, null, null), + Row(null, 1, "dan"))) + + val managerInnerJoin = managerInnerJoinView + .select(manager("managerId"), $"emp_manager.name", + manager("subordinateId"), $"emp_subordinate.name") + checkAnswer(managerInnerJoin, Seq( + Row(0, "charlie", 1, "dan"))) + + val bestFriendInnerJoin = bestFriendInnerJoinView + .select(bestFriend("id"), $"bestFriend2.id", $"bestFriend2.friendId") + checkAnswer(bestFriendInnerJoin, Seq( + Row(0, 1, 2), + Row(1, 2, 0), + Row(2, 0, 1))) + + val bannedCustomerInnerJoin = bannedCustomerInnerJoinView + .select(customer("id"), bannedCustomer("name")) + checkAnswer(bannedCustomerInnerJoin, Seq( + Row(0, "alice"), + Row(2, "alice"))) + + val bannedCustomerLeftOuterJoin = bannedCustomerLeftOuterJoinView + .select(customer("id"), bannedCustomer("name")) + checkAnswer(bannedCustomerLeftOuterJoin, Seq( + Row(0, "alice"), + Row(1, null), + Row(2, "alice"))) + + val bannedCustomerFullOuterJoin = bannedCustomerFullOuterJoinView + .select(customer("id"), bannedCustomer("name")) + checkAnswer(bannedCustomerFullOuterJoin, Seq( + Row(0, "alice"), + Row(1, null), + Row(2, "alice"), + Row(null, "eve"))) + } + + test("can't create foreign key referencing non-unique column") { + intercept[AnalysisException] { + bannedCustomer.foreignKey("name", customer, "name") + } + } + + test("eliminate unique key left outer join") { + val bannedCustomerJoinEliminated = bannedCustomerLeftOuterJoinView + .select(customer("id"), customer("name")) + checkAnswer(bannedCustomerJoinEliminated, customer) + checkJoinsEliminated(bannedCustomerJoinEliminated) + } + + test("do not eliminate unique key inner/full outer join") { + val bannedCustomerInnerJoinNotEliminated = bannedCustomerInnerJoinView + .select(customer("id"), customer("name")) + checkAnswer(bannedCustomerInnerJoinNotEliminated, Seq( + Row(0, "alice"), + Row(2, "alice"))) + + val bannedCustomerFullOuterJoinNotEliminated = bannedCustomerFullOuterJoinView + .select(customer("id"), customer("name")) + checkAnswer(bannedCustomerFullOuterJoinNotEliminated, Seq( + Row(0, "alice"), + Row(1, "bob"), + Row(2, "alice"), + Row(null, null))) + } + + test("do not eliminate referential integrity inner join where foreign key is nullable") { + val orderInnerJoin = orderInnerJoinView + 
.select(order("id"), customer("id"), employee("id")) + checkAnswer(orderInnerJoin, Seq( + Row(0, 0, 0))) + // Only the customer join should be eliminated + checkJoinCount(orderInnerJoinView, 2) + checkJoinCount(orderInnerJoin, 1) + } + + test("eliminate referential integrity join") { + val orderLeftOuterJoinEliminated = orderLeftOuterJoinView + .select(order("id"), customer("id"), employee("id")) + checkAnswer(orderLeftOuterJoinEliminated, Seq( + Row(0, 0, 0), + Row(1, 1, null))) + checkJoinsEliminated(orderLeftOuterJoinEliminated) + + val orderRightOuterJoinEliminated = orderRightOuterJoinView + .select(order("id"), customer("id"), employee("id")) + checkAnswer(orderRightOuterJoinEliminated, Seq( + Row(0, 0, 0), + Row(1, 1, null))) + checkJoinsEliminated(orderRightOuterJoinEliminated) + } + + test("do not eliminate referential integrity full outer join") { + val orderCustomerFullOuterJoinNotEliminated = orderCustomerFullOuterJoinView + .select(order("id"), order("customerId"), customer("id")) + checkAnswer(orderCustomerFullOuterJoinNotEliminated, Seq( + Row(0, 0, 0), + Row(1, 1, 1), + Row(null, null, 2))) + + val orderEmployeeFullOuterJoinNotEliminated = orderEmployeeFullOuterJoinView + .select(order("id"), order("employeeId"), employee("id")) + checkAnswer(orderEmployeeFullOuterJoinNotEliminated, Seq( + Row(0, 0, 0), + Row(1, null, null), + Row(null, null, 1))) + } + + test("eliminate referential integrity join despite multiple foreign keys with same referent") { + val managerInnerJoinEliminated = managerInnerJoinView + .select($"emp_manager.id", $"emp_subordinate.id") + checkAnswer(managerInnerJoinEliminated, manager) + checkJoinsEliminated(managerInnerJoinEliminated) + } + + test("eliminate referential integrity self-join") { + val bestFriendInnerJoinEliminated = bestFriendInnerJoinView + .select(bestFriend("id"), $"bestFriend2.id") + checkAnswer(bestFriendInnerJoinEliminated, Seq( + Row(0, 1), + Row(1, 2), + Row(2, 0))) + checkJoinsEliminated(bestFriendInnerJoinEliminated) + } +}
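A minimal end-to-end sketch of the DataFrame API introduced above (assumptions: a SQLContext named `sqlContext` with its implicits in scope; the case classes mirror KeyHintSuite and are purely illustrative):

    import sqlContext.implicits._

    case class Customer(id: Int, name: String)
    case class Order(id: Int, customerId: Int)

    // Declare the key hints: customer.id is unique, and order.customerId references it.
    val customer = Seq(Customer(0, "alice"), Customer(1, "bob")).toDF().uniqueKey("id")
    val order = Seq(Order(0, 0), Order(1, 1)).toDF()
      .foreignKey("customerId", customer, "id")

    // Only order columns and the referenced unique key survive the select, so JoinElimination
    // drops the join and answers the query from order alone, substituting order("customerId")
    // for customer("id").
    val eliminated = order
      .join(customer, order("customerId") === customer("id"), "left_outer")
      .select(order("id"), customer("id"))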